share session object

This commit is contained in:
carlomazzaferro 2021-11-25 15:39:02 +01:00
parent 40e4e2e111
commit c5ef5c2f45
No known key found for this signature in database
GPG Key ID: 0CED3103EF7B2187
4 changed files with 61 additions and 33 deletions

View File

@ -2,10 +2,9 @@ import asyncio
import logging import logging
from typing import List, Optional from typing import List, Optional
from sqlalchemy import orm from sqlalchemy.ext.asyncio import async_scoped_session
from web3 import Web3 from web3 import Web3
from mev_inspect.db import get_trace_session
from mev_inspect.fees import fetch_base_fee_per_gas from mev_inspect.fees import fetch_base_fee_per_gas
from mev_inspect.schemas.blocks import Block from mev_inspect.schemas.blocks import Block
from mev_inspect.schemas.receipts import Receipt from mev_inspect.schemas.receipts import Receipt
@ -29,12 +28,13 @@ async def create_from_block_number(
base_provider, base_provider,
w3: Web3, w3: Web3,
block_number: int, block_number: int,
trace_session: Optional[async_scoped_session],
) -> Block: ) -> Block:
block: Optional[Block] = None block: Optional[Block] = None
if get_trace_session() is not None: if trace_session is not None:
async with get_trace_session() as session: # type: ignore block = await _find_block(trace_session, block_number)
block = await _find_block(session, block_number) await trace_session.close()
if block is None: if block is None:
block = await _fetch_block(w3, base_provider, block_number) block = await _fetch_block(w3, base_provider, block_number)
return block return block
@ -76,7 +76,7 @@ async def _fetch_block(w3, base_provider, block_number: int, retries: int = 0) -
async def _find_block( async def _find_block(
trace_db_session: orm.Session, trace_db_session: async_scoped_session,
block_number: int, block_number: int,
) -> Optional[Block]: ) -> Optional[Block]:
traces = await _find_traces(trace_db_session, block_number) traces = await _find_traces(trace_db_session, block_number)
@ -108,7 +108,7 @@ async def _find_block(
def _find_block_timestamp( def _find_block_timestamp(
trace_db_session: orm.Session, trace_db_session: async_scoped_session,
block_number: int, block_number: int,
) -> Optional[int]: ) -> Optional[int]:
result = trace_db_session.execute( result = trace_db_session.execute(
@ -124,7 +124,7 @@ def _find_block_timestamp(
async def _find_traces( async def _find_traces(
trace_db_session: orm.Session, trace_db_session: async_scoped_session,
block_number: int, block_number: int,
) -> Optional[List[Trace]]: ) -> Optional[List[Trace]]:
result = await trace_db_session.execute( result = await trace_db_session.execute(
@ -140,7 +140,7 @@ async def _find_traces(
async def _find_receipts( async def _find_receipts(
trace_db_session: orm.Session, trace_db_session: async_scoped_session,
block_number: int, block_number: int,
) -> Optional[List[Receipt]]: ) -> Optional[List[Receipt]]:
result = await trace_db_session.execute( result = await trace_db_session.execute(
@ -156,7 +156,7 @@ async def _find_receipts(
async def _find_base_fee( async def _find_base_fee(
trace_db_session: orm.Session, trace_db_session: async_scoped_session,
block_number: int, block_number: int,
) -> Optional[int]: ) -> Optional[int]:
result = await trace_db_session.execute( result = await trace_db_session.execute(

View File

@ -1,8 +1,13 @@
import os import os
from typing import Optional from typing import Optional
from asyncio import current_task
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.ext.asyncio import (
create_async_engine,
AsyncSession,
async_scoped_session,
)
def get_trace_database_uri() -> Optional[str]: def get_trace_database_uri() -> Optional[str]:
@ -31,15 +36,15 @@ def _get_engine(uri: str):
def _get_session(uri: str): def _get_session(uri: str):
session = sessionmaker(bind=_get_engine(uri), class_=AsyncSession) session = sessionmaker(bind=_get_engine(uri), class_=AsyncSession)
return session() return async_scoped_session(session, scopefunc=current_task)
def get_inspect_session() -> sessionmaker: def get_inspect_session() -> async_scoped_session:
uri = get_inspect_database_uri() uri = get_inspect_database_uri()
return _get_session(uri) return _get_session(uri)
def get_trace_session() -> Optional[sessionmaker]: def get_trace_session() -> Optional[async_scoped_session]:
uri = get_trace_database_uri() uri = get_trace_database_uri()
if uri is not None: if uri is not None:

View File

@ -1,6 +1,7 @@
import logging import logging
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import async_scoped_session
from web3 import Web3 from web3 import Web3
from mev_inspect.arbitrages import get_arbitrages from mev_inspect.arbitrages import get_arbitrages
@ -37,14 +38,17 @@ logger = logging.getLogger(__name__)
async def inspect_block( async def inspect_block(
inspect_db_session: AsyncSession, inspect_db_session: async_scoped_session,
trace_db_session: Optional[async_scoped_session],
base_provider, base_provider,
w3: Web3, w3: Web3,
trace_classifier: TraceClassifier, trace_classifier: TraceClassifier,
block_number: int, block_number: int,
should_write_classified_traces: bool = True, should_write_classified_traces: bool = True,
): ):
block = await create_from_block_number(base_provider, w3, block_number) block = await create_from_block_number(
base_provider, w3, block_number, trace_db_session
)
logger.info(f"Block: {block_number} -- Total traces: {len(block.traces)}") logger.info(f"Block: {block_number} -- Total traces: {len(block.traces)}")
@ -91,4 +95,4 @@ async def inspect_block(
) )
await delete_miner_payments_for_block(inspect_db_session, block_number) await delete_miner_payments_for_block(inspect_db_session, block_number)
await write_miner_payments(inspect_db_session, miner_payments) await write_miner_payments(inspect_db_session, miner_payments)
await inspect_db_session.commit() await inspect_db_session.close()

View File

@ -2,13 +2,15 @@ import asyncio
import logging import logging
import traceback import traceback
from asyncio import CancelledError from asyncio import CancelledError
from typing import Tuple, Optional
from sqlalchemy.ext.asyncio import async_scoped_session
from web3 import Web3 from web3 import Web3
from web3.eth import AsyncEth from web3.eth import AsyncEth
from mev_inspect.block import create_from_block_number from mev_inspect.block import create_from_block_number
from mev_inspect.classifiers.trace import TraceClassifier from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.db import get_inspect_session from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspect_block import inspect_block from mev_inspect.inspect_block import inspect_block
from mev_inspect.provider import get_base_provider from mev_inspect.provider import get_base_provider
@ -28,15 +30,24 @@ class MEVInspector:
self.max_concurrency = asyncio.Semaphore(max_concurrency) self.max_concurrency = asyncio.Semaphore(max_concurrency)
async def create_from_block(self, block_number: int): async def create_from_block(self, block_number: int):
_, trace_session = _get_sessions()
return await create_from_block_number( return await create_from_block_number(
base_provider=self.base_provider, w3=self.w3, block_number=block_number base_provider=self.base_provider,
w3=self.w3,
block_number=block_number,
trace_session=trace_session,
) )
async def inspect_single_block(self, block: int): async def inspect_single_block(self, block: int):
async with get_inspect_session() as session: inspect_session, trace_session = _get_sessions()
return await inspect_block( return await inspect_block(
session, self.base_provider, self.w3, self.trace_classifier, block inspect_session,
) trace_session,
self.base_provider,
self.w3,
self.trace_classifier,
block,
)
async def inspect_many_blocks(self, after_block: int, before_block: int): async def inspect_many_blocks(self, after_block: int, before_block: int):
tasks = [] tasks = []
@ -56,12 +67,20 @@ class MEVInspector:
traceback.print_exc() traceback.print_exc()
async def safe_inspect_block(self, block_number: int): async def safe_inspect_block(self, block_number: int):
async with get_inspect_session() as session: inspect_session, trace_session = _get_sessions()
async with self.max_concurrency: async with self.max_concurrency:
return await inspect_block( return await inspect_block(
session, inspect_session,
self.base_provider, trace_session,
self.w3, self.base_provider,
self.trace_classifier, self.w3,
block_number, self.trace_classifier,
) block_number,
)
def _get_sessions() -> Tuple[async_scoped_session, Optional[async_scoped_session]]:
inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session()
trace_db_session = trace_db_session() if trace_db_session is not None else None
return inspect_db_session, trace_db_session