share session object
This commit is contained in:
parent
40e4e2e111
commit
c5ef5c2f45
@ -2,10 +2,9 @@ import asyncio
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import orm
|
||||
from sqlalchemy.ext.asyncio import async_scoped_session
|
||||
from web3 import Web3
|
||||
|
||||
from mev_inspect.db import get_trace_session
|
||||
from mev_inspect.fees import fetch_base_fee_per_gas
|
||||
from mev_inspect.schemas.blocks import Block
|
||||
from mev_inspect.schemas.receipts import Receipt
|
||||
@ -29,12 +28,13 @@ async def create_from_block_number(
|
||||
base_provider,
|
||||
w3: Web3,
|
||||
block_number: int,
|
||||
trace_session: Optional[async_scoped_session],
|
||||
) -> Block:
|
||||
block: Optional[Block] = None
|
||||
|
||||
if get_trace_session() is not None:
|
||||
async with get_trace_session() as session: # type: ignore
|
||||
block = await _find_block(session, block_number)
|
||||
if trace_session is not None:
|
||||
block = await _find_block(trace_session, block_number)
|
||||
await trace_session.close()
|
||||
if block is None:
|
||||
block = await _fetch_block(w3, base_provider, block_number)
|
||||
return block
|
||||
@ -76,7 +76,7 @@ async def _fetch_block(w3, base_provider, block_number: int, retries: int = 0) -
|
||||
|
||||
|
||||
async def _find_block(
|
||||
trace_db_session: orm.Session,
|
||||
trace_db_session: async_scoped_session,
|
||||
block_number: int,
|
||||
) -> Optional[Block]:
|
||||
traces = await _find_traces(trace_db_session, block_number)
|
||||
@ -108,7 +108,7 @@ async def _find_block(
|
||||
|
||||
|
||||
def _find_block_timestamp(
|
||||
trace_db_session: orm.Session,
|
||||
trace_db_session: async_scoped_session,
|
||||
block_number: int,
|
||||
) -> Optional[int]:
|
||||
result = trace_db_session.execute(
|
||||
@ -124,7 +124,7 @@ def _find_block_timestamp(
|
||||
|
||||
|
||||
async def _find_traces(
|
||||
trace_db_session: orm.Session,
|
||||
trace_db_session: async_scoped_session,
|
||||
block_number: int,
|
||||
) -> Optional[List[Trace]]:
|
||||
result = await trace_db_session.execute(
|
||||
@ -140,7 +140,7 @@ async def _find_traces(
|
||||
|
||||
|
||||
async def _find_receipts(
|
||||
trace_db_session: orm.Session,
|
||||
trace_db_session: async_scoped_session,
|
||||
block_number: int,
|
||||
) -> Optional[List[Receipt]]:
|
||||
result = await trace_db_session.execute(
|
||||
@ -156,7 +156,7 @@ async def _find_receipts(
|
||||
|
||||
|
||||
async def _find_base_fee(
|
||||
trace_db_session: orm.Session,
|
||||
trace_db_session: async_scoped_session,
|
||||
block_number: int,
|
||||
) -> Optional[int]:
|
||||
result = await trace_db_session.execute(
|
||||
|
@ -1,8 +1,13 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from asyncio import current_task
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
create_async_engine,
|
||||
AsyncSession,
|
||||
async_scoped_session,
|
||||
)
|
||||
|
||||
|
||||
def get_trace_database_uri() -> Optional[str]:
|
||||
@ -31,15 +36,15 @@ def _get_engine(uri: str):
|
||||
|
||||
def _get_session(uri: str):
|
||||
session = sessionmaker(bind=_get_engine(uri), class_=AsyncSession)
|
||||
return session()
|
||||
return async_scoped_session(session, scopefunc=current_task)
|
||||
|
||||
|
||||
def get_inspect_session() -> sessionmaker:
|
||||
def get_inspect_session() -> async_scoped_session:
|
||||
uri = get_inspect_database_uri()
|
||||
return _get_session(uri)
|
||||
|
||||
|
||||
def get_trace_session() -> Optional[sessionmaker]:
|
||||
def get_trace_session() -> Optional[async_scoped_session]:
|
||||
uri = get_trace_database_uri()
|
||||
|
||||
if uri is not None:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.ext.asyncio import async_scoped_session
|
||||
from web3 import Web3
|
||||
|
||||
from mev_inspect.arbitrages import get_arbitrages
|
||||
@ -37,14 +38,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def inspect_block(
|
||||
inspect_db_session: AsyncSession,
|
||||
inspect_db_session: async_scoped_session,
|
||||
trace_db_session: Optional[async_scoped_session],
|
||||
base_provider,
|
||||
w3: Web3,
|
||||
trace_classifier: TraceClassifier,
|
||||
block_number: int,
|
||||
should_write_classified_traces: bool = True,
|
||||
):
|
||||
block = await create_from_block_number(base_provider, w3, block_number)
|
||||
block = await create_from_block_number(
|
||||
base_provider, w3, block_number, trace_db_session
|
||||
)
|
||||
|
||||
logger.info(f"Block: {block_number} -- Total traces: {len(block.traces)}")
|
||||
|
||||
@ -91,4 +95,4 @@ async def inspect_block(
|
||||
)
|
||||
await delete_miner_payments_for_block(inspect_db_session, block_number)
|
||||
await write_miner_payments(inspect_db_session, miner_payments)
|
||||
await inspect_db_session.commit()
|
||||
await inspect_db_session.close()
|
||||
|
@ -2,13 +2,15 @@ import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
from asyncio import CancelledError
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_scoped_session
|
||||
from web3 import Web3
|
||||
from web3.eth import AsyncEth
|
||||
|
||||
from mev_inspect.block import create_from_block_number
|
||||
from mev_inspect.classifiers.trace import TraceClassifier
|
||||
from mev_inspect.db import get_inspect_session
|
||||
from mev_inspect.db import get_inspect_session, get_trace_session
|
||||
from mev_inspect.inspect_block import inspect_block
|
||||
from mev_inspect.provider import get_base_provider
|
||||
|
||||
@ -28,15 +30,24 @@ class MEVInspector:
|
||||
self.max_concurrency = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
async def create_from_block(self, block_number: int):
|
||||
_, trace_session = _get_sessions()
|
||||
return await create_from_block_number(
|
||||
base_provider=self.base_provider, w3=self.w3, block_number=block_number
|
||||
base_provider=self.base_provider,
|
||||
w3=self.w3,
|
||||
block_number=block_number,
|
||||
trace_session=trace_session,
|
||||
)
|
||||
|
||||
async def inspect_single_block(self, block: int):
|
||||
async with get_inspect_session() as session:
|
||||
return await inspect_block(
|
||||
session, self.base_provider, self.w3, self.trace_classifier, block
|
||||
)
|
||||
inspect_session, trace_session = _get_sessions()
|
||||
return await inspect_block(
|
||||
inspect_session,
|
||||
trace_session,
|
||||
self.base_provider,
|
||||
self.w3,
|
||||
self.trace_classifier,
|
||||
block,
|
||||
)
|
||||
|
||||
async def inspect_many_blocks(self, after_block: int, before_block: int):
|
||||
tasks = []
|
||||
@ -56,12 +67,20 @@ class MEVInspector:
|
||||
traceback.print_exc()
|
||||
|
||||
async def safe_inspect_block(self, block_number: int):
|
||||
async with get_inspect_session() as session:
|
||||
async with self.max_concurrency:
|
||||
return await inspect_block(
|
||||
session,
|
||||
self.base_provider,
|
||||
self.w3,
|
||||
self.trace_classifier,
|
||||
block_number,
|
||||
)
|
||||
inspect_session, trace_session = _get_sessions()
|
||||
async with self.max_concurrency:
|
||||
return await inspect_block(
|
||||
inspect_session,
|
||||
trace_session,
|
||||
self.base_provider,
|
||||
self.w3,
|
||||
self.trace_classifier,
|
||||
block_number,
|
||||
)
|
||||
|
||||
|
||||
def _get_sessions() -> Tuple[async_scoped_session, Optional[async_scoped_session]]:
|
||||
inspect_db_session = get_inspect_session()
|
||||
trace_db_session = get_trace_session()
|
||||
trace_db_session = trace_db_session() if trace_db_session is not None else None
|
||||
return inspect_db_session, trace_db_session
|
||||
|
Loading…
x
Reference in New Issue
Block a user