Use inspector class -- remove global Semaphore and improve error handling

This commit is contained in:
carlomazzaferro 2021-10-28 11:04:24 +01:00
parent e15eef49c1
commit c3475bbd8f
No known key found for this signature in database
GPG Key ID: 0CED3103EF7B2187
4 changed files with 115 additions and 88 deletions

95
cli.py
View File

@ -3,26 +3,17 @@ import logging
import os import os
import signal import signal
import sys import sys
import traceback
from asyncio import CancelledError
from functools import wraps from functools import wraps
import click import click
from web3 import Web3
from web3.eth import AsyncEth
from mev_inspect.classifiers.trace import TraceClassifier from mev_inspect.inspector import MEVInspector
from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspect_block import inspect_block
from mev_inspect.provider import get_base_provider
RPC_URL_ENV = "RPC_URL" RPC_URL_ENV = "RPC_URL"
logging.basicConfig(stream=sys.stdout, level=logging.INFO) logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
semaphore: asyncio.Semaphore
@click.group() @click.group()
def cli(): def cli():
@ -54,24 +45,8 @@ def coro(f):
@click.option("--cache/--no-cache", default=True) @click.option("--cache/--no-cache", default=True)
@coro @coro
async def inspect_block_command(block_number: int, rpc: str, cache: bool): async def inspect_block_command(block_number: int, rpc: str, cache: bool):
inspect_db_session = get_inspect_session() inspector = MEVInspector(rpc=rpc, cache=cache)
trace_db_session = get_trace_session() await inspector.inspect_single_block(block=block_number)
base_provider = get_base_provider(rpc)
w3 = Web3(base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
trace_classifier = TraceClassifier()
if not cache:
logger.info("Skipping cache")
await inspect_block(
inspect_db_session,
base_provider,
w3,
trace_classifier,
block_number,
trace_db_session=trace_db_session,
)
@cli.command() @cli.command()
@ -97,61 +72,15 @@ async def inspect_many_blocks_command(
max_concurrency: int, max_concurrency: int,
request_timeout: int, request_timeout: int,
): ):
global semaphore # pylint: disable=global-statement inspector = MEVInspector(
semaphore = asyncio.Semaphore(max_concurrency) rpc=rpc,
inspect_db_session = get_inspect_session() cache=cache,
trace_db_session = get_trace_session() max_concurrency=max_concurrency,
request_timeout=request_timeout,
base_provider = get_base_provider(rpc, request_timeout=request_timeout) )
w3 = Web3(base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]) await inspector.inspect_many_blocks(
after_block=after_block, before_block=before_block
trace_classifier = TraceClassifier() )
if not cache:
logger.info("Skipping cache")
tasks = []
for block_number in range(after_block, before_block):
tasks.append(
asyncio.ensure_future(
safe_inspect_block(
inspect_db_session,
base_provider,
w3,
trace_classifier,
block_number,
trace_db_session,
)
)
)
logger.info(f"Gathered {len(tasks)} blocks to inspect")
try:
await asyncio.gather(*tasks)
except CancelledError:
logger.info("Requested to exit, cleaning up...")
except Exception as e:
logger.error(f"Existed due to {type(e)}")
traceback.print_exc()
async def safe_inspect_block(
inspect_db_session,
base_provider,
w3,
trace_classifier,
block_number,
trace_db_session,
):
async with semaphore:
return await inspect_block(
inspect_db_session,
base_provider,
w3,
trace_classifier,
block_number,
trace_db_session=trace_db_session,
)
def get_rpc_url() -> str: def get_rpc_url() -> str:

View File

@ -1,4 +1,6 @@
import asyncio import asyncio
import logging
import sys
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional
@ -11,6 +13,8 @@ from mev_inspect.schemas.receipts import Receipt
cache_directory = "./cache" cache_directory = "./cache"
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
def get_latest_block_number(w3: Web3) -> int: def get_latest_block_number(w3: Web3) -> int:
@ -47,10 +51,17 @@ async def _fetch_block(
fetch_base_fee_per_gas(w3, block_number), fetch_base_fee_per_gas(w3, block_number),
) )
receipts: List[Receipt] = [ try:
Receipt(**receipt) for receipt in receipts_json["result"] receipts: List[Receipt] = [
] Receipt(**receipt) for receipt in receipts_json["result"]
traces = [Trace(**trace_json) for trace_json in traces_json["result"]] ]
traces = [Trace(**trace_json) for trace_json in traces_json["result"]]
except KeyError as e:
logger.warning(
f"Failed to create objects from block: {block_number}: {e}, retrying in 5..."
)
await asyncio.sleep(5)
return await _fetch_block(w3, base_provider, block_number)
return Block( return Block(
block_number=block_number, block_number=block_number,

73
mev_inspect/inspector.py Normal file
View File

@ -0,0 +1,73 @@
import asyncio
import logging
import sys
import traceback
from asyncio import CancelledError
from web3 import Web3
from web3.eth import AsyncEth
from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspect_block import inspect_block
from mev_inspect.provider import get_base_provider
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
class MEVInspector:
    """Coordinates MEV inspection of blocks.

    Owns the DB sessions, async web3 client, trace classifier, and a
    semaphore bounding how many blocks are inspected concurrently —
    replacing the previous module-level global semaphore.
    """

    def __init__(
        self,
        rpc: str,
        cache: bool,
        max_concurrency: int = 1,
        request_timeout: int = 300,
    ):
        if not cache:
            logger.info("Skipping cache")
        self.inspect_db_session = get_inspect_session()
        self.trace_db_session = get_trace_session()
        self.base_provider = get_base_provider(rpc, request_timeout=request_timeout)
        self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
        self.trace_classifier = TraceClassifier()
        # Limits the number of blocks inspected at once in inspect_many_blocks.
        self.max_concurrency = asyncio.Semaphore(max_concurrency)

    async def inspect_single_block(self, block: int):
        """Inspect a single block number and return inspect_block's result."""
        return await inspect_block(
            self.inspect_db_session,
            self.base_provider,
            self.w3,
            self.trace_classifier,
            block,
            trace_db_session=self.trace_db_session,
        )

    async def inspect_many_blocks(self, after_block: int, before_block: int):
        """Inspect every block in [after_block, before_block) concurrently.

        Concurrency is bounded by the semaphore created in __init__.
        Cancellation is logged and swallowed; any other exception is
        logged with a traceback so the CLI can exit cleanly.
        """
        tasks = [
            asyncio.ensure_future(self.safe_inspect_block(block_number=block_number))
            for block_number in range(after_block, before_block)
        ]
        logger.info(f"Gathered {len(tasks)} blocks to inspect")
        try:
            await asyncio.gather(*tasks)
        except CancelledError:
            logger.info("Requested to exit, cleaning up...")
        except Exception as e:
            # Fixed typo: "Existed" -> "Exited".
            logger.error(f"Exited due to {type(e)}")
            traceback.print_exc()

    async def safe_inspect_block(self, block_number: int):
        """Inspect one block while holding the concurrency semaphore."""
        async with self.max_concurrency:
            return await inspect_block(
                self.inspect_db_session,
                self.base_provider,
                self.w3,
                self.trace_classifier,
                block_number,
                trace_db_session=self.trace_db_session,
            )

View File

@ -11,6 +11,12 @@ from typing import (
) )
from asyncio.exceptions import TimeoutError from asyncio.exceptions import TimeoutError
from aiohttp.client_exceptions import (
ClientOSError,
ServerDisconnectedError,
ServerTimeoutError,
ClientResponseError,
)
from requests.exceptions import ( from requests.exceptions import (
ConnectionError, ConnectionError,
HTTPError, HTTPError,
@ -25,6 +31,14 @@ from web3.types import (
) )
request_exceptions = (ConnectionError, HTTPError, Timeout, TooManyRedirects)
aiohttp_exceptions = (
ClientOSError,
ServerDisconnectedError,
ServerTimeoutError,
ClientResponseError,
)
logging.basicConfig(stream=sys.stdout, level=logging.INFO) logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -74,5 +88,5 @@ async def http_retry_with_backoff_request_middleware(
return await exception_retry_with_backoff_middleware( return await exception_retry_with_backoff_middleware(
make_request, make_request,
web3, web3,
(ConnectionError, HTTPError, Timeout, TooManyRedirects, TimeoutError), (request_exceptions + aiohttp_exceptions + (TimeoutError,)),
) )