Pass DB sessions into inspector

This commit is contained in:
Luke Van Seters 2021-12-31 15:50:07 -05:00
parent 1ff9e9aa1c
commit 4662a1ecbc
3 changed files with 47 additions and 22 deletions

25
cli.py
View File

@ -29,8 +29,13 @@ async def inspect_block_command(block_number: int, rpc: str):
inspect_db_session = get_inspect_session() inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session() trace_db_session = get_trace_session()
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session) inspector = MEVInspector(rpc)
await inspector.inspect_single_block(block=block_number)
await inspector.inspect_single_block(
inspect_db_session=inspect_db_session,
trace_db_session=trace_db_session,
block=block_number,
)
@cli.command() @cli.command()
@ -38,11 +43,14 @@ async def inspect_block_command(block_number: int, rpc: str):
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@coro @coro
async def fetch_block_command(block_number: int, rpc: str): async def fetch_block_command(block_number: int, rpc: str):
inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session() trace_db_session = get_trace_session()
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session) inspector = MEVInspector(rpc)
block = await inspector.create_from_block(block_number=block_number) block = await inspector.create_from_block(
block_number=block_number,
trace_db_session=trace_db_session,
)
print(block.json()) print(block.json())
@ -72,13 +80,14 @@ async def inspect_many_blocks_command(
inspector = MEVInspector( inspector = MEVInspector(
rpc, rpc,
inspect_db_session,
trace_db_session,
max_concurrency=max_concurrency, max_concurrency=max_concurrency,
request_timeout=request_timeout, request_timeout=request_timeout,
) )
await inspector.inspect_many_blocks( await inspector.inspect_many_blocks(
after_block=after_block, before_block=before_block inspect_db_session=inspect_db_session,
trace_db_session=trace_db_session,
after_block=after_block,
before_block=before_block,
) )

View File

@ -37,13 +37,14 @@ async def run():
inspect_db_session = get_inspect_session() inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session() trace_db_session = get_trace_session()
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session) inspector = MEVInspector(rpc)
base_provider = get_base_provider(rpc) base_provider = get_base_provider(rpc)
while not killer.kill_now: while not killer.kill_now:
await inspect_next_block( await inspect_next_block(
inspector, inspector,
inspect_db_session, inspect_db_session,
trace_db_session,
base_provider, base_provider,
healthcheck_url, healthcheck_url,
) )
@ -54,6 +55,7 @@ async def run():
async def inspect_next_block( async def inspect_next_block(
inspector: MEVInspector, inspector: MEVInspector,
inspect_db_session, inspect_db_session,
trace_db_session,
base_provider, base_provider,
healthcheck_url, healthcheck_url,
): ):
@ -72,7 +74,11 @@ async def inspect_next_block(
logger.info(f"Writing block: {block_number}") logger.info(f"Writing block: {block_number}")
await inspector.inspect_single_block(block=block_number) await inspector.inspect_single_block(
inspect_db_session=inspect_db_session,
trace_db_session=trace_db_session,
block=block_number,
)
update_latest_block(inspect_db_session, block_number) update_latest_block(inspect_db_session, block_number)
if healthcheck_url: if healthcheck_url:

View File

@ -27,38 +27,44 @@ class MEVInspector:
def __init__( def __init__(
self, self,
rpc: str, rpc: str,
inspect_db_session: orm.Session,
trace_db_session: Optional[orm.Session],
max_concurrency: int = 1, max_concurrency: int = 1,
request_timeout: int = 300, request_timeout: int = 300,
): ):
self.inspect_db_session = inspect_db_session
self.trace_db_session = trace_db_session
base_provider = get_base_provider(rpc, request_timeout=request_timeout) base_provider = get_base_provider(rpc, request_timeout=request_timeout)
self.w3 = Web3(base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]) self.w3 = Web3(base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
self.trace_classifier = TraceClassifier() self.trace_classifier = TraceClassifier()
self.max_concurrency = asyncio.Semaphore(max_concurrency) self.max_concurrency = asyncio.Semaphore(max_concurrency)
async def create_from_block(self, block_number: int): async def create_from_block(
self,
trace_db_session: Optional[orm.Session],
block_number: int,
):
return await create_from_block_number( return await create_from_block_number(
w3=self.w3, w3=self.w3,
block_number=block_number, block_number=block_number,
trace_db_session=self.trace_db_session, trace_db_session=trace_db_session,
) )
async def inspect_single_block(self, block: int): async def inspect_single_block(
self,
inspect_db_session: orm.Session,
block: int,
trace_db_session: Optional[orm.Session],
):
return await inspect_block( return await inspect_block(
self.inspect_db_session, inspect_db_session,
self.w3, self.w3,
self.trace_classifier, self.trace_classifier,
block, block,
trace_db_session=self.trace_db_session, trace_db_session=trace_db_session,
) )
async def inspect_many_blocks( async def inspect_many_blocks(
self, self,
inspect_db_session: orm.Session,
trace_db_session: Optional[orm.Session],
after_block: int, after_block: int,
before_block: int, before_block: int,
block_batch_size: int = 10, block_batch_size: int = 10,
@ -71,6 +77,8 @@ class MEVInspector:
tasks.append( tasks.append(
asyncio.ensure_future( asyncio.ensure_future(
self.safe_inspect_many_blocks( self.safe_inspect_many_blocks(
inspect_db_session,
trace_db_session,
after_block_number=batch_after_block, after_block_number=batch_after_block,
before_block_number=batch_before_block, before_block_number=batch_before_block,
) )
@ -88,15 +96,17 @@ class MEVInspector:
async def safe_inspect_many_blocks( async def safe_inspect_many_blocks(
self, self,
inspect_db_session: orm.Session,
trace_db_session: Optional[orm.Session],
after_block_number: int, after_block_number: int,
before_block_number: int, before_block_number: int,
): ):
async with self.max_concurrency: async with self.max_concurrency:
return await inspect_many_blocks( return await inspect_many_blocks(
self.inspect_db_session, inspect_db_session,
self.w3, self.w3,
self.trace_classifier, self.trace_classifier,
after_block_number, after_block_number,
before_block_number, before_block_number,
trace_db_session=self.trace_db_session, trace_db_session=trace_db_session,
) )