diff --git a/cli.py b/cli.py index e535c4f..3c65fc4 100644 --- a/cli.py +++ b/cli.py @@ -5,6 +5,7 @@ from functools import wraps import click +from mev_inspect.db import get_inspect_session, get_trace_session from mev_inspect.inspector import MEVInspector RPC_URL_ENV = "RPC_URL" @@ -39,7 +40,10 @@ def coro(f): @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) @coro async def inspect_block_command(block_number: int, rpc: str): - inspector = MEVInspector(rpc=rpc) + inspect_db_session = get_inspect_session() + trace_db_session = get_trace_session() + + inspector = MEVInspector(rpc, inspect_db_session, trace_db_session) await inspector.inspect_single_block(block=block_number) @@ -48,7 +52,10 @@ async def inspect_block_command(block_number: int, rpc: str): @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) @coro async def fetch_block_command(block_number: int, rpc: str): - inspector = MEVInspector(rpc=rpc) + inspect_db_session = get_inspect_session() + trace_db_session = get_trace_session() + + inspector = MEVInspector(rpc, inspect_db_session, trace_db_session) block = await inspector.create_from_block(block_number=block_number) print(block.json()) @@ -74,8 +81,13 @@ async def inspect_many_blocks_command( max_concurrency: int, request_timeout: int, ): + inspect_db_session = get_inspect_session() + trace_db_session = get_trace_session() + inspector = MEVInspector( - rpc=rpc, + rpc, + inspect_db_session, + trace_db_session, max_concurrency=max_concurrency, request_timeout=request_timeout, ) diff --git a/mev_inspect/inspector.py b/mev_inspect/inspector.py index ca4e996..fd729f8 100644 --- a/mev_inspect/inspector.py +++ b/mev_inspect/inspector.py @@ -3,13 +3,14 @@ import logging import sys import traceback from asyncio import CancelledError +from typing import Optional +from sqlalchemy import orm from web3 import Web3 from web3.eth import AsyncEth from mev_inspect.block import create_from_block_number from mev_inspect.classifiers.trace import TraceClassifier -from mev_inspect.db import get_inspect_session, get_trace_session from mev_inspect.inspect_block import inspect_block from mev_inspect.provider import get_base_provider @@ -21,11 +22,13 @@ class MEVInspector: def __init__( self, rpc: str, + inspect_db_session: orm.Session, + trace_db_session: Optional[orm.Session], max_concurrency: int = 1, request_timeout: int = 300, ): - self.inspect_db_session = get_inspect_session() - self.trace_db_session = get_trace_session() + self.inspect_db_session = inspect_db_session + self.trace_db_session = trace_db_session self.base_provider = get_base_provider(rpc, request_timeout=request_timeout) self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]) self.trace_classifier = TraceClassifier()