From c912e8aa74616398b35a9c18d5061aed3e65d7ee Mon Sep 17 00:00:00 2001 From: carlomazzaferro Date: Sun, 28 Nov 2021 13:16:25 +0100 Subject: [PATCH] move session creation up the stack --- cli.py | 17 ++++++++++++++--- mev_inspect/inspector.py | 39 ++++++++++++++++++++++++++++++--------- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/cli.py b/cli.py index 0f0376b..9030180 100644 --- a/cli.py +++ b/cli.py @@ -5,6 +5,7 @@ import sys import click from mev_inspect.concurrency import coro +from mev_inspect.db import get_sessions from mev_inspect.inspector import MEVInspector RPC_URL_ENV = "RPC_URL" @@ -23,7 +24,10 @@ def cli(): @coro async def inspect_block_command(block_number: int, rpc: str): inspector = MEVInspector(rpc) - await inspector.inspect_single_block(block=block_number) + inspect_session, trace_session = get_sessions() + await inspector.inspect_single_block( + block=block_number, inspect_session=inspect_session, trace_session=trace_session + ) @cli.command() @@ -32,7 +36,10 @@ async def inspect_block_command(block_number: int, rpc: str): @coro async def fetch_block_command(block_number: int, rpc: str): inspector = MEVInspector(rpc) - block = await inspector.create_from_block(block_number=block_number) + _, trace_session = get_sessions() + block = await inspector.create_from_block( + block_number=block_number, trace_session=trace_session + ) print(block.json()) @@ -62,8 +69,12 @@ async def inspect_many_blocks_command( max_concurrency=max_concurrency, request_timeout=request_timeout, ) + inspect_session, trace_session = get_sessions() await inspector.inspect_many_blocks( - after_block=after_block, before_block=before_block + after_block=after_block, + before_block=before_block, + inspect_session=inspect_session, + trace_session=trace_session, ) diff --git a/mev_inspect/inspector.py b/mev_inspect/inspector.py index 8e3cb26..17cfafb 100644 --- a/mev_inspect/inspector.py +++ b/mev_inspect/inspector.py @@ -2,13 +2,14 @@ import asyncio import logging import traceback from asyncio import CancelledError +from typing import Optional +from sqlalchemy.ext.asyncio import async_scoped_session from web3 import Web3 from web3.eth import AsyncEth from mev_inspect.block import create_from_block_number from mev_inspect.classifiers.trace import TraceClassifier -from mev_inspect.db import get_sessions from mev_inspect.inspect_block import inspect_block from mev_inspect.provider import get_base_provider @@ -27,8 +28,9 @@ class MEVInspector: self.trace_classifier = TraceClassifier() self.max_concurrency = asyncio.Semaphore(max_concurrency) - async def create_from_block(self, block_number: int): - _, trace_session = get_sessions() + async def create_from_block( + self, block_number: int, trace_session: Optional[async_scoped_session] + ): return await create_from_block_number( base_provider=self.base_provider, w3=self.w3, @@ -36,8 +38,12 @@ class MEVInspector: trace_session=trace_session, ) - async def inspect_single_block(self, block: int): - inspect_session, trace_session = get_sessions() + async def inspect_single_block( + self, + block: int, + inspect_session: async_scoped_session, + trace_session: Optional[async_scoped_session], + ): return await inspect_block( inspect_session, trace_session, @@ -47,12 +53,22 @@ class MEVInspector: block, ) - async def inspect_many_blocks(self, after_block: int, before_block: int): + async def inspect_many_blocks( + self, + after_block: int, + before_block: int, + inspect_session: async_scoped_session, + trace_session: Optional[async_scoped_session], + ): tasks = [] for block_number in range(after_block, before_block): tasks.append( asyncio.ensure_future( - self.safe_inspect_block(block_number=block_number) + self.safe_inspect_block( + block_number=block_number, + inspect_session=inspect_session, + trace_session=trace_session, + ) ) ) logger.info(f"Gathered {len(tasks)} blocks to inspect") @@ -64,8 +80,13 @@ class MEVInspector: logger.error(f"Exited due to {type(e)}") traceback.print_exc() - async def safe_inspect_block(self, block_number: int): - inspect_session, trace_session = get_sessions() + async def safe_inspect_block( + self, + block_number: int, + inspect_session: async_scoped_session, + trace_session: Optional[async_scoped_session], + ): + async with self.max_concurrency: return await inspect_block( inspect_session,