163 lines
4.1 KiB
Python
163 lines
4.1 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
import signal
|
|
import sys
|
|
import traceback
|
|
from asyncio import CancelledError
|
|
from functools import wraps
|
|
|
|
import click
|
|
from web3 import Web3
|
|
from web3.eth import AsyncEth
|
|
|
|
from mev_inspect.classifiers.trace import TraceClassifier
|
|
from mev_inspect.db import get_inspect_session, get_trace_session
|
|
from mev_inspect.inspect_block import inspect_block
|
|
from mev_inspect.provider import get_base_provider
|
|
|
|
# Environment variable whose value is the default for the --rpc option.
RPC_URL_ENV = "RPC_URL"

# Module logger; the root logging configuration below (done at import time)
# sends INFO-and-above records to stdout.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

# Annotation only: no value is bound here. inspect_many_blocks_command assigns
# it (via ``global``) before scheduling safe_inspect_block, which reads it;
# touching it earlier raises NameError.
semaphore: asyncio.Semaphore
|
|
|
|
|
|
# Root click command group; the commands in this module register on it via
# ``@cli.command()``. Intentionally no docstring: click would surface it as
# the group's --help text.
@click.group()
def cli():
    return None
|
|
|
|
|
|
def _cancel_pending_tasks(loop):
    """Cancel every unfinished task on ``loop`` and run it until they settle.

    Mirrors the shutdown step of ``asyncio.run`` so no task is left pending
    when the loop is closed.
    """
    pending = asyncio.all_tasks(loop)
    if not pending:
        return
    for task in pending:
        task.cancel()
    # return_exceptions=True: collect each task's CancelledError/exception
    # instead of letting the first one abort the wait.
    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))


def coro(f):
    """Adapt an ``async def`` click command so click can call it synchronously.

    The returned wrapper runs ``f(*args, **kwargs)`` to completion on a fresh
    event loop. While it runs, SIGINT and SIGTERM cancel every task on that
    loop, so the coroutine can react (e.g. with ``except CancelledError``).
    Afterwards any still-pending tasks are cancelled and awaited, async
    generators are finalized, and the loop is closed.

    Args:
        f: The coroutine function to wrap.

    Returns:
        A synchronous callable carrying ``f``'s metadata (via ``functools.wraps``)
        that returns whatever the coroutine returns and propagates its exceptions.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):
        # A dedicated loop avoids asyncio.get_event_loop(), which is deprecated
        # when no loop is running and raises on newer Python versions.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        def cancel_task_callback():
            # Signal callbacks run on the loop, so this sees the live tasks.
            for task in asyncio.all_tasks(loop):
                task.cancel()

        installed_signals = []
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, cancel_task_callback)
            except (NotImplementedError, RuntimeError):
                # NotImplementedError: loop type without signal support (e.g.
                # Windows); RuntimeError: not on the main thread. Fall back to
                # the default signal behaviour instead of failing the command.
                continue
            installed_signals.append(sig)

        try:
            return loop.run_until_complete(f(*args, **kwargs))
        finally:
            try:
                _cancel_pending_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                for sig in installed_signals:
                    loop.remove_signal_handler(sig)
                asyncio.set_event_loop(None)
                loop.close()

    return wrapper
|
|
|
|
|
|
@cli.command()
@click.argument("block_number", type=int)
@click.option("--rpc", default=lambda: os.getenv(RPC_URL_ENV, ""))
@click.option("--cache/--no-cache", default=True)
@coro
async def inspect_block_command(block_number: int, rpc: str, cache: bool):
    # Inspect a single block by delegating to inspect_block.
    #
    # Documented with comments on purpose: click turns a command's docstring
    # into its --help text, so adding one would change the CLI's output.
    #
    # NOTE(review): --cache/--no-cache only affects the log line below; it is
    # not forwarded to inspect_block. Confirm whether that is intended.
    inspect_session = get_inspect_session()
    trace_session = get_trace_session()

    provider = get_base_provider(rpc)
    web3_client = Web3(provider, modules={"eth": (AsyncEth,)}, middlewares=[])
    classifier = TraceClassifier()

    if not cache:
        logger.info("Skipping cache")

    positional_args = (
        inspect_session,
        provider,
        web3_client,
        classifier,
        block_number,
    )
    await inspect_block(*positional_args, trace_db_session=trace_session)
|
|
|
|
|
|
@cli.command()
@click.argument("after_block", type=int)
@click.argument("before_block", type=int)
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@click.option("--cache/--no-cache", default=True)
@click.option(
    "--max-concurrency",
    # IntRange(min=1): asyncio.Semaphore(0) would block every block task
    # forever, and a negative value makes asyncio.Semaphore raise ValueError.
    type=click.IntRange(min=1),
    help="maximum number of concurrent connections",
    default=5,
)
@click.option(
    "--request-timeout", type=int, help="timeout for requests to nodes", default=500
)
@coro
async def inspect_many_blocks_command(
    after_block: int,
    before_block: int,
    rpc: str,
    cache: bool,
    max_concurrency: int,
    request_timeout: int,
):
    """Inspect blocks from AFTER_BLOCK (inclusive) up to BEFORE_BLOCK (exclusive).

    Blocks are inspected concurrently, at most --max-concurrency at a time.
    """
    # NOTE(review): --cache/--no-cache only controls the log line below; the
    # flag is not forwarded to inspect_block. Confirm whether that is intended.
    global semaphore  # pylint: disable=global-statement

    # safe_inspect_block reads this module-level semaphore to cap concurrency.
    semaphore = asyncio.Semaphore(max_concurrency)
    inspect_db_session = get_inspect_session()
    trace_db_session = get_trace_session()

    base_provider = get_base_provider(rpc, request_timeout=request_timeout)
    w3 = Web3(base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])

    trace_classifier = TraceClassifier()

    if not cache:
        logger.info("Skipping cache")

    # One task per block is scheduled up front; the semaphore, not the number
    # of tasks, bounds how many inspections run at once.
    # NOTE(review): very large ranges therefore allocate that many tasks at once.
    tasks = [
        asyncio.ensure_future(
            safe_inspect_block(
                inspect_db_session,
                base_provider,
                w3,
                trace_classifier,
                block_number,
                trace_db_session,
            )
        )
        for block_number in range(after_block, before_block)
    ]
    logger.info(f"Gathered {len(tasks)} blocks to inspect")

    try:
        await asyncio.gather(*tasks)
    except CancelledError:
        # Raised when the SIGINT/SIGTERM handler installed by @coro cancels tasks.
        logger.info("Requested to exit, cleaning up...")
    except Exception as e:
        logger.error(f"Exited due to {type(e)}")
        traceback.print_exc()
    finally:
        # gather() does not cancel the remaining tasks when one of them fails,
        # so cancel whatever is still outstanding and wait for it to settle.
        # return_exceptions=True also retrieves their exceptions so asyncio
        # does not report them as never retrieved. Cancelling finished tasks
        # is a no-op, so this is cheap on the success path.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
|
|
async def safe_inspect_block(
    inspect_db_session,
    base_provider,
    w3,
    trace_classifier,
    block_number,
    trace_db_session,
):
    """Run inspect_block for one block while holding the module-level semaphore.

    The semaphore is bound by inspect_many_blocks_command; calling this before
    it has been assigned raises NameError.

    Returns:
        Whatever inspect_block returns.
    """
    # Bind once so the same semaphore object is both acquired and released.
    gate = semaphore
    await gate.acquire()
    try:
        return await inspect_block(
            inspect_db_session,
            base_provider,
            w3,
            trace_classifier,
            block_number,
            trace_db_session=trace_db_session,
        )
    finally:
        gate.release()
|
|
|
|
|
|
def get_rpc_url() -> str:
    """Return the RPC endpoint URL from the ``RPC_URL_ENV`` environment variable.

    Uses the module constant rather than repeating the variable name, so this
    stays in sync with the CLI's ``--rpc`` default.

    Raises:
        KeyError: If the environment variable is not set.
    """
    return os.environ[RPC_URL_ENV]
|
|
|
|
|
|
if __name__ == "__main__":
    # Only when this file is executed directly (not imported): hand control
    # to the click command group.
    cli()
|