diff --git a/mev_inspect/block.py b/mev_inspect/block.py index 37b1080..abdab2c 100644 --- a/mev_inspect/block.py +++ b/mev_inspect/block.py @@ -2,10 +2,9 @@ import asyncio import logging from typing import List, Optional -from sqlalchemy import orm +from sqlalchemy.ext.asyncio import async_scoped_session from web3 import Web3 -from mev_inspect.db import get_trace_session from mev_inspect.fees import fetch_base_fee_per_gas from mev_inspect.schemas.blocks import Block from mev_inspect.schemas.receipts import Receipt @@ -29,12 +28,13 @@ async def create_from_block_number( base_provider, w3: Web3, block_number: int, + trace_session: Optional[async_scoped_session], ) -> Block: block: Optional[Block] = None - if get_trace_session() is not None: - async with get_trace_session() as session: # type: ignore - block = await _find_block(session, block_number) + if trace_session is not None: + block = await _find_block(trace_session, block_number) + await trace_session.close() if block is None: block = await _fetch_block(w3, base_provider, block_number) return block @@ -76,7 +76,7 @@ async def _fetch_block(w3, base_provider, block_number: int, retries: int = 0) - async def _find_block( - trace_db_session: orm.Session, + trace_db_session: async_scoped_session, block_number: int, ) -> Optional[Block]: traces = await _find_traces(trace_db_session, block_number) @@ -108,7 +108,7 @@ async def _find_block( def _find_block_timestamp( - trace_db_session: orm.Session, + trace_db_session: async_scoped_session, block_number: int, ) -> Optional[int]: result = trace_db_session.execute( @@ -124,7 +124,7 @@ def _find_block_timestamp( async def _find_traces( - trace_db_session: orm.Session, + trace_db_session: async_scoped_session, block_number: int, ) -> Optional[List[Trace]]: result = await trace_db_session.execute( @@ -140,7 +140,7 @@ async def _find_traces( async def _find_receipts( - trace_db_session: orm.Session, + trace_db_session: async_scoped_session, block_number: int, ) -> Optional[List[Receipt]]: result = await trace_db_session.execute( @@ -156,7 +156,7 @@ async def _find_receipts( async def _find_base_fee( - trace_db_session: orm.Session, + trace_db_session: async_scoped_session, block_number: int, ) -> Optional[int]: result = await trace_db_session.execute( diff --git a/mev_inspect/db.py b/mev_inspect/db.py index ac04dd6..a01cb52 100644 --- a/mev_inspect/db.py +++ b/mev_inspect/db.py @@ -1,8 +1,13 @@ import os from typing import Optional +from asyncio import current_task from sqlalchemy.orm import sessionmaker -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.ext.asyncio import ( + create_async_engine, + AsyncSession, + async_scoped_session, +) def get_trace_database_uri() -> Optional[str]: @@ -31,15 +36,15 @@ def _get_engine(uri: str): def _get_session(uri: str): session = sessionmaker(bind=_get_engine(uri), class_=AsyncSession) - return session() + return async_scoped_session(session, scopefunc=current_task) -def get_inspect_session() -> sessionmaker: +def get_inspect_session() -> async_scoped_session: uri = get_inspect_database_uri() return _get_session(uri) -def get_trace_session() -> Optional[sessionmaker]: +def get_trace_session() -> Optional[async_scoped_session]: uri = get_trace_database_uri() if uri is not None: diff --git a/mev_inspect/inspect_block.py b/mev_inspect/inspect_block.py index 356698d..389701f 100644 --- a/mev_inspect/inspect_block.py +++ b/mev_inspect/inspect_block.py @@ -1,6 +1,7 @@ import logging +from typing import Optional -from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import async_scoped_session from web3 import Web3 from mev_inspect.arbitrages import get_arbitrages @@ -37,14 +38,17 @@ logger = logging.getLogger(__name__) async def inspect_block( - inspect_db_session: AsyncSession, + inspect_db_session: async_scoped_session, + trace_db_session: Optional[async_scoped_session], base_provider, w3: Web3, trace_classifier: TraceClassifier, block_number: int, should_write_classified_traces: bool = True, ): - block = await create_from_block_number(base_provider, w3, block_number) + block = await create_from_block_number( + base_provider, w3, block_number, trace_db_session + ) logger.info(f"Block: {block_number} -- Total traces: {len(block.traces)}") @@ -91,4 +95,4 @@ async def inspect_block( ) await delete_miner_payments_for_block(inspect_db_session, block_number) await write_miner_payments(inspect_db_session, miner_payments) - await inspect_db_session.commit() + await inspect_db_session.close() diff --git a/mev_inspect/inspector.py b/mev_inspect/inspector.py index 7550252..dea26b1 100644 --- a/mev_inspect/inspector.py +++ b/mev_inspect/inspector.py @@ -2,13 +2,15 @@ import asyncio import logging import traceback from asyncio import CancelledError +from typing import Tuple, Optional +from sqlalchemy.ext.asyncio import async_scoped_session from web3 import Web3 from web3.eth import AsyncEth from mev_inspect.block import create_from_block_number from mev_inspect.classifiers.trace import TraceClassifier -from mev_inspect.db import get_inspect_session +from mev_inspect.db import get_inspect_session, get_trace_session from mev_inspect.inspect_block import inspect_block from mev_inspect.provider import get_base_provider @@ -28,15 +30,24 @@ class MEVInspector: self.max_concurrency = asyncio.Semaphore(max_concurrency) async def create_from_block(self, block_number: int): + _, trace_session = _get_sessions() return await create_from_block_number( - base_provider=self.base_provider, w3=self.w3, block_number=block_number + base_provider=self.base_provider, + w3=self.w3, + block_number=block_number, + trace_session=trace_session, ) async def inspect_single_block(self, block: int): - async with get_inspect_session() as session: - return await inspect_block( - session, self.base_provider, self.w3, self.trace_classifier, block - ) + inspect_session, trace_session = _get_sessions() + return await inspect_block( + inspect_session, + trace_session, + self.base_provider, + self.w3, + self.trace_classifier, + block, + ) async def inspect_many_blocks(self, after_block: int, before_block: int): tasks = [] @@ -56,12 +67,20 @@ class MEVInspector: traceback.print_exc() async def safe_inspect_block(self, block_number: int): - async with get_inspect_session() as session: - async with self.max_concurrency: - return await inspect_block( - session, - self.base_provider, - self.w3, - self.trace_classifier, - block_number, - ) + inspect_session, trace_session = _get_sessions() + async with self.max_concurrency: + return await inspect_block( + inspect_session, + trace_session, + self.base_provider, + self.w3, + self.trace_classifier, + block_number, + ) + + +def _get_sessions() -> Tuple[async_scoped_session, Optional[async_scoped_session]]: + inspect_db_session = get_inspect_session() + trace_db_session = get_trace_session() + trace_db_session = trace_db_session() if trace_db_session is not None else None + return inspect_db_session, trace_db_session