# 163 lines
# 4.1 KiB
# Python

import asyncio
import logging
import os
import signal
import sys
import traceback
from asyncio import CancelledError
from functools import wraps
import click
from web3 import Web3
from web3.eth import AsyncEth
from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspect_block import inspect_block
from mev_inspect.provider import get_base_provider
# Environment variable the --rpc options fall back to when not given.
RPC_URL_ENV = "RPC_URL"

# Send INFO-and-above log records to stdout for the whole process.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger(name=__name__)

# Declared (annotation only, no value) at import time; the concurrency
# limiter is assigned by inspect_many_blocks_command before any use.
semaphore: asyncio.Semaphore
@click.group()
def cli():
    # Root click command group; every subcommand below registers itself
    # through the @cli.command() decorator.
    # Deliberately no docstring: click would publish it as the --help text.
    ...
def coro(f):
    """Make an ``async def`` function callable from synchronous code.

    The click commands in this module are declared ``async`` but are invoked
    like ordinary functions, so this decorator runs the coroutine produced by
    ``f`` to completion on a fresh event loop and returns its result.  While
    it runs, SIGINT and SIGTERM cancel every task on that loop, so commands
    can catch ``CancelledError`` and clean up instead of dying mid-write.

    Args:
        f: The coroutine function to wrap.

    Returns:
        A synchronous wrapper (metadata copied from ``f``) that returns
        whatever the awaited coroutine returns and propagates its exceptions.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):
        # A dedicated loop instead of asyncio.get_event_loop(), which is
        # deprecated (and eventually an error) when no loop is running.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        def cancel_task_callback():
            # Invoked from inside the running loop, so all_tasks(loop) sees
            # every live task, including the one running ``f``.
            for task in asyncio.all_tasks(loop):
                task.cancel()

        installed_signals = []
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, cancel_task_callback)
            except (NotImplementedError, RuntimeError):
                # Windows event loops do not implement add_signal_handler,
                # and it fails outside the main thread: run without graceful
                # cancellation rather than crashing the command.
                continue
            installed_signals.append(sig)

        try:
            return loop.run_until_complete(f(*args, **kwargs))
        finally:
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                # Restore default signal handling and release the loop's
                # resources (selector, self-pipe) even if shutdown failed.
                for sig in installed_signals:
                    loop.remove_signal_handler(sig)
                asyncio.set_event_loop(None)
                loop.close()

    return wrapper
@cli.command()
@click.argument("block_number", type=int)
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@click.option("--cache/--no-cache", default=True)
@coro
async def inspect_block_command(block_number: int, rpc: str, cache: bool):
    # Inspect one block: open both DB sessions, build the async Web3 client
    # and a trace classifier, then delegate to inspect_block.
    # Comments rather than a docstring: click would expose a docstring as
    # this command's --help text.
    #
    # NOTE(review): `cache` only gates the "Skipping cache" log line; it is
    # never forwarded to inspect_block. Confirm whether that is intended.
    inspect_session = get_inspect_session()
    trace_session = get_trace_session()

    provider = get_base_provider(rpc)
    web3_client = Web3(provider, modules={"eth": (AsyncEth,)}, middlewares=[])
    classifier = TraceClassifier()

    skip_cache = not cache
    if skip_cache:
        logger.info("Skipping cache")

    positional_args = (
        inspect_session,
        provider,
        web3_client,
        classifier,
        block_number,
    )
    await inspect_block(*positional_args, trace_db_session=trace_session)
@cli.command()
@click.argument("after_block", type=int)
@click.argument("before_block", type=int)
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@click.option("--cache/--no-cache", default=True)
@click.option(
    "--max-concurrency",
    type=int,
    help="maximum number of concurrent connections",
    default=5,
)
@click.option(
    "--request-timeout", type=int, help="timeout for requests to nodes", default=500
)
@coro
async def inspect_many_blocks_command(
    after_block: int,
    before_block: int,
    rpc: str,
    cache: bool,
    max_concurrency: int,
    request_timeout: int,
):
    # Inspect every block in range(after_block, before_block) -- before_block
    # itself is excluded -- with at most `max_concurrency` in flight at once.
    # Comments rather than a docstring: click would expose a docstring as
    # this command's --help text.
    #
    # NOTE(review): `cache` only gates the "Skipping cache" log line; it is
    # never forwarded to safe_inspect_block/inspect_block. Confirm intent.
    global semaphore  # pylint: disable=global-statement
    # safe_inspect_block acquires this module-level semaphore per block.
    semaphore = asyncio.Semaphore(max_concurrency)

    inspect_db_session = get_inspect_session()
    trace_db_session = get_trace_session()

    base_provider = get_base_provider(rpc, request_timeout=request_timeout)
    w3 = Web3(base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])

    trace_classifier = TraceClassifier()

    if not cache:
        logger.info("Skipping cache")

    tasks = [
        asyncio.ensure_future(
            safe_inspect_block(
                inspect_db_session,
                base_provider,
                w3,
                trace_classifier,
                block_number,
                trace_db_session,
            )
        )
        for block_number in range(after_block, before_block)
    ]
    logger.info("Gathered %d blocks to inspect", len(tasks))

    try:
        await asyncio.gather(*tasks)
    except CancelledError:
        logger.info("Requested to exit, cleaning up...")
    except Exception as e:
        # logger.exception records the traceback through the configured
        # logging handler instead of writing it straight to stderr.
        logger.exception("Exited due to %s", type(e))
    finally:
        # gather() does not cancel the other tasks when one of them fails or
        # is cancelled; without this they would be abandoned still pending
        # when the event loop stops. Cancel them and wait for them to unwind.
        pending = [task for task in tasks if not task.done()]
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
async def safe_inspect_block(
    inspect_db_session,
    base_provider,
    w3,
    trace_classifier,
    block_number,
    trace_db_session,
):
    """Run ``inspect_block`` for one block while holding the module semaphore.

    The module-level ``semaphore`` is only annotated at import time and is
    assigned inside ``inspect_many_blocks_command``; calling this coroutine
    before that assignment raises ``NameError``.

    Returns:
        Whatever ``inspect_block`` returns for ``block_number``.
    """
    limiter = semaphore
    async with limiter:
        result = await inspect_block(
            inspect_db_session,
            base_provider,
            w3,
            trace_classifier,
            block_number,
            trace_db_session=trace_db_session,
        )
    return result
def get_rpc_url() -> str:
    """Return the node RPC URL from the environment.

    Reads the same variable (``RPC_URL_ENV``) that the ``--rpc`` options
    fall back to, so the name is defined in exactly one place.

    Raises:
        KeyError: if the environment variable is not set.
    """
    return os.environ[RPC_URL_ENV]
# Script entry point: when run directly, hand control to the click group,
# which parses argv and dispatches to the selected subcommand.
if __name__ == "__main__":
    cli()