diff --git a/mev_inspect/block.py b/mev_inspect/block.py index e58d97e..4c9c876 100644 --- a/mev_inspect/block.py +++ b/mev_inspect/block.py @@ -1,7 +1,6 @@ import asyncio import logging -from collections.abc import Awaitable -from typing import Callable, List, Optional, TypeVar +from typing import List, Optional, TypeVar from sqlalchemy import orm from web3 import Web3 @@ -30,30 +29,11 @@ async def create_from_block_number( block_number: int, trace_db_session: Optional[orm.Session], ) -> Block: - existing_block_timestamp = None - existing_base_fee_per_gas = None - existing_traces = None - existing_receipts = None - - if trace_db_session is not None: - existing_block_timestamp = _find_block_timestamp(trace_db_session, block_number) - existing_base_fee_per_gas = _find_base_fee(trace_db_session, block_number) - existing_traces = _find_traces(trace_db_session, block_number) - existing_receipts = _find_receipts(trace_db_session, block_number) - block_timestamp, receipts, traces, base_fee_per_gas = await asyncio.gather( - _await_if_absent( - lambda: _fetch_block_timestamp(w3, block_number), existing_block_timestamp - ), - _await_if_absent( - lambda: _fetch_block_receipts(w3, block_number), existing_receipts - ), - _await_if_absent( - lambda: _fetch_block_traces(w3, block_number), existing_traces - ), - _await_if_absent( - lambda: fetch_base_fee_per_gas(w3, block_number), existing_base_fee_per_gas - ), + _find_or_fetch_block_timestamp(w3, block_number, trace_db_session), + _find_or_fetch_block_receipts(w3, block_number, trace_db_session), + _find_or_fetch_block_traces(w3, block_number, trace_db_session), + _find_or_fetch_base_fee_per_gas(w3, block_number, trace_db_session), ) miner_address = _get_miner_address_from_traces(traces) @@ -68,6 +48,60 @@ async def create_from_block_number( ) +async def _find_or_fetch_block_timestamp( + w3, + block_number: int, + trace_db_session: Optional[orm.Session], +) -> int: + if trace_db_session is not None: + existing_block_timestamp = _find_block_timestamp(trace_db_session, block_number) + if existing_block_timestamp is not None: + return existing_block_timestamp + + return await _fetch_block_timestamp(w3, block_number) + + +async def _find_or_fetch_block_receipts( + w3, + block_number: int, + trace_db_session: Optional[orm.Session], +) -> List[Receipt]: + if trace_db_session is not None: + existing_block_receipts = _find_block_receipts(trace_db_session, block_number) + if existing_block_receipts is not None: + return existing_block_receipts + + return await _fetch_block_receipts(w3, block_number) + + +async def _find_or_fetch_block_traces( + w3, + block_number: int, + trace_db_session: Optional[orm.Session], +) -> List[Trace]: + if trace_db_session is not None: + existing_block_traces = _find_block_traces(trace_db_session, block_number) + if existing_block_traces is not None: + return existing_block_traces + + return await _fetch_block_traces(w3, block_number) + + +async def _find_or_fetch_base_fee_per_gas( + w3, + block_number: int, + trace_db_session: Optional[orm.Session], +) -> int: + if trace_db_session is not None: + existing_base_fee_per_gas = _find_base_fee_per_gas( + trace_db_session, block_number + ) + if existing_base_fee_per_gas is not None: + return existing_base_fee_per_gas + + return await fetch_base_fee_per_gas(w3, block_number) + + async def _fetch_block_timestamp(w3, block_number: int) -> int: block_json = await w3.eth.get_block(block_number) return block_json["timestamp"] @@ -83,15 +117,6 @@ async def _fetch_block_traces(w3, block_number: int) -> List[Trace]: return [Trace(**trace_json) for trace_json in traces_json] -async def _await_if_absent( - awaitable: Callable[[], Awaitable[T]], existing: Optional[T] -) -> T: - if existing is not None: - return existing - else: - return await awaitable() - - def _find_block_timestamp( trace_db_session: orm.Session, block_number: int, @@ -108,7 +133,7 @@ def _find_block_timestamp( return block_timestamp -def _find_traces( +def _find_block_traces( trace_db_session: orm.Session, block_number: int, ) -> Optional[List[Trace]]: @@ -124,7 +149,7 @@ def _find_traces( return [Trace(**trace_json) for trace_json in traces_json] -def _find_receipts( +def _find_block_receipts( trace_db_session: orm.Session, block_number: int, ) -> Optional[List[Receipt]]: @@ -140,7 +165,7 @@ def _find_receipts( return [Receipt(**receipt) for receipt in receipts_json] -def _find_base_fee( +def _find_base_fee_per_gas( trace_db_session: orm.Session, block_number: int, ) -> Optional[int]: