Use inspector class -- remove global Semaphore and improve error handling

This commit is contained in:
carlomazzaferro 2021-10-28 11:04:24 +01:00
parent e15eef49c1
commit c3475bbd8f
No known key found for this signature in database
GPG Key ID: 0CED3103EF7B2187
4 changed files with 115 additions and 88 deletions

95
cli.py
View File

@ -3,26 +3,17 @@ import logging
import os
import signal
import sys
import traceback
from asyncio import CancelledError
from functools import wraps
import click
from web3 import Web3
from web3.eth import AsyncEth
from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspect_block import inspect_block
from mev_inspect.provider import get_base_provider
from mev_inspect.inspector import MEVInspector
RPC_URL_ENV = "RPC_URL"
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
semaphore: asyncio.Semaphore
@click.group()
def cli():
@ -54,24 +45,8 @@ def coro(f):
@click.option("--cache/--no-cache", default=True)
@coro
async def inspect_block_command(block_number: int, rpc: str, cache: bool):
inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session()
base_provider = get_base_provider(rpc)
w3 = Web3(base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
trace_classifier = TraceClassifier()
if not cache:
logger.info("Skipping cache")
await inspect_block(
inspect_db_session,
base_provider,
w3,
trace_classifier,
block_number,
trace_db_session=trace_db_session,
)
inspector = MEVInspector(rpc=rpc, cache=cache)
await inspector.inspect_single_block(block=block_number)
@cli.command()
@ -97,61 +72,15 @@ async def inspect_many_blocks_command(
max_concurrency: int,
request_timeout: int,
):
global semaphore # pylint: disable=global-statement
semaphore = asyncio.Semaphore(max_concurrency)
inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session()
base_provider = get_base_provider(rpc, request_timeout=request_timeout)
w3 = Web3(base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
trace_classifier = TraceClassifier()
if not cache:
logger.info("Skipping cache")
tasks = []
for block_number in range(after_block, before_block):
tasks.append(
asyncio.ensure_future(
safe_inspect_block(
inspect_db_session,
base_provider,
w3,
trace_classifier,
block_number,
trace_db_session,
)
)
)
logger.info(f"Gathered {len(tasks)} blocks to inspect")
try:
await asyncio.gather(*tasks)
except CancelledError:
logger.info("Requested to exit, cleaning up...")
except Exception as e:
logger.error(f"Existed due to {type(e)}")
traceback.print_exc()
async def safe_inspect_block(
inspect_db_session,
base_provider,
w3,
trace_classifier,
block_number,
trace_db_session,
):
async with semaphore:
return await inspect_block(
inspect_db_session,
base_provider,
w3,
trace_classifier,
block_number,
trace_db_session=trace_db_session,
)
inspector = MEVInspector(
rpc=rpc,
cache=cache,
max_concurrency=max_concurrency,
request_timeout=request_timeout,
)
await inspector.inspect_many_blocks(
after_block=after_block, before_block=before_block
)
def get_rpc_url() -> str:

View File

@ -1,4 +1,6 @@
import asyncio
import logging
import sys
from pathlib import Path
from typing import List, Optional
@ -11,6 +13,8 @@ from mev_inspect.schemas.receipts import Receipt
cache_directory = "./cache"
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
def get_latest_block_number(w3: Web3) -> int:
@ -47,10 +51,17 @@ async def _fetch_block(
fetch_base_fee_per_gas(w3, block_number),
)
receipts: List[Receipt] = [
Receipt(**receipt) for receipt in receipts_json["result"]
]
traces = [Trace(**trace_json) for trace_json in traces_json["result"]]
try:
receipts: List[Receipt] = [
Receipt(**receipt) for receipt in receipts_json["result"]
]
traces = [Trace(**trace_json) for trace_json in traces_json["result"]]
except KeyError as e:
logger.warning(
f"Failed to create objects from block: {block_number}: {e}, retrying in 5..."
)
await asyncio.sleep(5)
return await _fetch_block(w3, base_provider, block_number)
return Block(
block_number=block_number,

73
mev_inspect/inspector.py Normal file
View File

@ -0,0 +1,73 @@
import asyncio
import logging
import sys
import traceback
from asyncio import CancelledError
from web3 import Web3
from web3.eth import AsyncEth
from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspect_block import inspect_block
from mev_inspect.provider import get_base_provider
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
class MEVInspector:
def __init__(
self,
rpc: str,
cache: bool,
max_concurrency: int = 1,
request_timeout: int = 300,
):
if not cache:
logger.info("Skipping cache")
self.inspect_db_session = get_inspect_session()
self.trace_db_session = get_trace_session()
self.base_provider = get_base_provider(rpc, request_timeout=request_timeout)
self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
self.trace_classifier = TraceClassifier()
self.max_concurrency = asyncio.Semaphore(max_concurrency)
async def inspect_single_block(self, block: int):
return await inspect_block(
self.inspect_db_session,
self.base_provider,
self.w3,
self.trace_classifier,
block,
trace_db_session=self.trace_db_session,
)
async def inspect_many_blocks(self, after_block: int, before_block: int):
tasks = []
for block_number in range(after_block, before_block):
tasks.append(
asyncio.ensure_future(
self.safe_inspect_block(block_number=block_number)
)
)
logger.info(f"Gathered {len(tasks)} blocks to inspect")
try:
await asyncio.gather(*tasks)
except CancelledError:
logger.info("Requested to exit, cleaning up...")
except Exception as e:
logger.error(f"Existed due to {type(e)}")
traceback.print_exc()
async def safe_inspect_block(self, block_number: int):
async with self.max_concurrency:
return await inspect_block(
self.inspect_db_session,
self.base_provider,
self.w3,
self.trace_classifier,
block_number,
trace_db_session=self.trace_db_session,
)

View File

@ -11,6 +11,12 @@ from typing import (
)
from asyncio.exceptions import TimeoutError
from aiohttp.client_exceptions import (
ClientOSError,
ServerDisconnectedError,
ServerTimeoutError,
ClientResponseError,
)
from requests.exceptions import (
ConnectionError,
HTTPError,
@ -25,6 +31,14 @@ from web3.types import (
)
request_exceptions = (ConnectionError, HTTPError, Timeout, TooManyRedirects)
aiohttp_exceptions = (
ClientOSError,
ServerDisconnectedError,
ServerTimeoutError,
ClientResponseError,
)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
@ -74,5 +88,5 @@ async def http_retry_with_backoff_request_middleware(
return await exception_retry_with_backoff_middleware(
make_request,
web3,
(ConnectionError, HTTPError, Timeout, TooManyRedirects, TimeoutError),
(request_exceptions + aiohttp_exceptions + (TimeoutError,)),
)