diff --git a/mev_inspect/block.py b/mev_inspect/block.py index cf8e2d4..0ff0341 100644 --- a/mev_inspect/block.py +++ b/mev_inspect/block.py @@ -1,6 +1,7 @@ import asyncio import logging -from typing import List, Optional +from collections.abc import Awaitable +from typing import Callable, List, Optional, TypeVar from sqlalchemy import orm from web3 import Web3 @@ -12,6 +13,7 @@ from mev_inspect.schemas.traces import Trace, TraceType from mev_inspect.utils import hex_to_int logger = logging.getLogger(__name__) +T = TypeVar("T") async def get_latest_block_number(base_provider) -> int: @@ -28,62 +30,33 @@ async def create_from_block_number( block_number: int, trace_db_session: Optional[orm.Session], ) -> Block: - block: Optional[Block] = None + existing_block_timestamp = None + existing_base_fee_per_gas = None + existing_traces = None + existing_receipts = None if trace_db_session is not None: - block = _find_block(trace_db_session, block_number) + existing_block_timestamp = _find_block_timestamp(trace_db_session, block_number) + existing_base_fee_per_gas = _find_base_fee(trace_db_session, block_number) + existing_traces = _find_traces(trace_db_session, block_number) + existing_receipts = _find_receipts(trace_db_session, block_number) - if block is None: - block = await _fetch_block(w3, block_number) - return block - else: - return block - - -async def _fetch_block(w3, block_number: int) -> Block: - block_json, receipts_json, traces_json, base_fee_per_gas = await asyncio.gather( - w3.eth.get_block(block_number), - w3.eth.get_block_receipts(block_number), - w3.eth.trace_block(block_number), - fetch_base_fee_per_gas(w3, block_number), + block_timestamp, receipts, traces, base_fee_per_gas = await asyncio.gather( + _await_if_absent( + lambda: _fetch_block_timestamp(w3, block_number), existing_block_timestamp + ), + _await_if_absent( + lambda: _fetch_block_receipts(w3, block_number), existing_receipts + ), + _await_if_absent( + lambda: _fetch_block_traces(w3, block_number), existing_traces + ), + _await_if_absent( + lambda: fetch_base_fee_per_gas(w3, block_number), existing_base_fee_per_gas + ), ) - receipts: List[Receipt] = [Receipt(**receipt) for receipt in receipts_json] - traces = [Trace(**trace_json) for trace_json in traces_json] - - return Block( - block_number=block_number, - block_timestamp=block_json["timestamp"], - miner=block_json["miner"], - base_fee_per_gas=base_fee_per_gas, - traces=traces, - receipts=receipts, - ) - - -def _find_block( - trace_db_session: orm.Session, - block_number: int, -) -> Optional[Block]: - block_timestamp = _find_block_timestamp(trace_db_session, block_number) - if block_timestamp is None: - return None - - base_fee_per_gas = _find_base_fee(trace_db_session, block_number) - if base_fee_per_gas is None: - return None - - traces = _find_traces(trace_db_session, block_number) - if traces is None: - return None - - receipts = _find_receipts(trace_db_session, block_number) - if receipts is None: - return None - miner_address = _get_miner_address_from_traces(traces) - if miner_address is None: - return None return Block( block_number=block_number, @@ -95,6 +68,30 @@ def _find_block( ) +async def _fetch_block_timestamp(w3, block_number: int) -> int: + block_json = await w3.eth.get_block(block_number) + return block_json["timestamp"] + + +async def _fetch_block_receipts(w3, block_number: int) -> List[Receipt]: + receipts_json = await w3.eth.get_block_receipts(block_number) + return [Receipt(**receipt) for receipt in receipts_json] + + +async def _fetch_block_traces(w3, block_number: int) -> List[Trace]: + traces_json = await w3.eth.trace_block(block_number) + return [Trace(**trace_json) for trace_json in traces_json] + + +async def _await_if_absent( + awaitable: Callable[[], Awaitable[T]], existing: Optional[T] +) -> T: + if existing is not None: + return existing + else: + return await awaitable() + + def _find_block_timestamp( trace_db_session: orm.Session, block_number: int, @@ -116,7 +113,7 @@ def _find_traces( block_number: int, ) -> Optional[List[Trace]]: result = trace_db_session.execute( - "SELECT raw_traces FROM block_traces WHERE block_number = :block_number", + "SELECT raw_traces FROM block_traces WHERE block_number = :block_number LIMIT 1", params={"block_number": block_number}, ).one_or_none()