move session creation up the stack

This commit is contained in:
carlomazzaferro 2021-11-28 13:16:25 +01:00
parent c34485c493
commit c912e8aa74
No known key found for this signature in database
GPG Key ID: 0CED3103EF7B2187
2 changed files with 44 additions and 12 deletions

17
cli.py
View File

@ -5,6 +5,7 @@ import sys
import click
from mev_inspect.concurrency import coro
from mev_inspect.db import get_sessions
from mev_inspect.inspector import MEVInspector
RPC_URL_ENV = "RPC_URL"
@ -23,7 +24,10 @@ def cli():
@coro
async def inspect_block_command(block_number: int, rpc: str):
inspector = MEVInspector(rpc)
await inspector.inspect_single_block(block=block_number)
inspect_session, trace_session = get_sessions()
await inspector.inspect_single_block(
block=block_number, inspect_session=inspect_session, trace_session=trace_session
)
@cli.command()
@ -32,7 +36,10 @@ async def inspect_block_command(block_number: int, rpc: str):
@coro
async def fetch_block_command(block_number: int, rpc: str):
inspector = MEVInspector(rpc)
block = await inspector.create_from_block(block_number=block_number)
_, trace_session = get_sessions()
block = await inspector.create_from_block(
block_number=block_number, trace_session=trace_session
)
print(block.json())
@ -62,8 +69,12 @@ async def inspect_many_blocks_command(
max_concurrency=max_concurrency,
request_timeout=request_timeout,
)
inspect_session, trace_session = get_sessions()
await inspector.inspect_many_blocks(
after_block=after_block, before_block=before_block
after_block=after_block,
before_block=before_block,
inspect_session=inspect_session,
trace_session=trace_session,
)

View File

@ -2,13 +2,14 @@ import asyncio
import logging
import traceback
from asyncio import CancelledError
from typing import Optional
from sqlalchemy.ext.asyncio import async_scoped_session
from web3 import Web3
from web3.eth import AsyncEth
from mev_inspect.block import create_from_block_number
from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.db import get_sessions
from mev_inspect.inspect_block import inspect_block
from mev_inspect.provider import get_base_provider
@ -27,8 +28,9 @@ class MEVInspector:
self.trace_classifier = TraceClassifier()
self.max_concurrency = asyncio.Semaphore(max_concurrency)
async def create_from_block(self, block_number: int):
_, trace_session = get_sessions()
async def create_from_block(
self, block_number: int, trace_session: Optional[async_scoped_session]
):
return await create_from_block_number(
base_provider=self.base_provider,
w3=self.w3,
@ -36,8 +38,12 @@ class MEVInspector:
trace_session=trace_session,
)
async def inspect_single_block(self, block: int):
inspect_session, trace_session = get_sessions()
async def inspect_single_block(
self,
block: int,
inspect_session: async_scoped_session,
trace_session: Optional[async_scoped_session],
):
return await inspect_block(
inspect_session,
trace_session,
@ -47,12 +53,22 @@ class MEVInspector:
block,
)
async def inspect_many_blocks(self, after_block: int, before_block: int):
async def inspect_many_blocks(
self,
after_block: int,
before_block: int,
inspect_session: async_scoped_session,
trace_session: Optional[async_scoped_session],
):
tasks = []
for block_number in range(after_block, before_block):
tasks.append(
asyncio.ensure_future(
self.safe_inspect_block(block_number=block_number)
self.safe_inspect_block(
block_number=block_number,
inspect_session=inspect_session,
trace_session=trace_session,
)
)
)
logger.info(f"Gathered {len(tasks)} blocks to inspect")
@ -64,8 +80,13 @@ class MEVInspector:
logger.error(f"Exited due to {type(e)}")
traceback.print_exc()
async def safe_inspect_block(self, block_number: int):
inspect_session, trace_session = get_sessions()
async def safe_inspect_block(
self,
block_number: int,
inspect_session: async_scoped_session,
trace_session: Optional[async_scoped_session],
):
async with self.max_concurrency:
return await inspect_block(
inspect_session,