Clean up async defaults

This commit is contained in:
Luke Van Seters 2022-01-18 11:29:45 -05:00
parent 89d2a718b2
commit 091ddbd9c1

View File

@ -1,7 +1,6 @@
import asyncio import asyncio
import logging import logging
from collections.abc import Awaitable from typing import List, Optional, TypeVar
from typing import Callable, List, Optional, TypeVar
from sqlalchemy import orm from sqlalchemy import orm
from web3 import Web3 from web3 import Web3
@ -30,30 +29,11 @@ async def create_from_block_number(
block_number: int, block_number: int,
trace_db_session: Optional[orm.Session], trace_db_session: Optional[orm.Session],
) -> Block: ) -> Block:
existing_block_timestamp = None
existing_base_fee_per_gas = None
existing_traces = None
existing_receipts = None
if trace_db_session is not None:
existing_block_timestamp = _find_block_timestamp(trace_db_session, block_number)
existing_base_fee_per_gas = _find_base_fee(trace_db_session, block_number)
existing_traces = _find_traces(trace_db_session, block_number)
existing_receipts = _find_receipts(trace_db_session, block_number)
block_timestamp, receipts, traces, base_fee_per_gas = await asyncio.gather( block_timestamp, receipts, traces, base_fee_per_gas = await asyncio.gather(
_await_if_absent( _find_or_fetch_block_timestamp(w3, block_number, trace_db_session),
lambda: _fetch_block_timestamp(w3, block_number), existing_block_timestamp _find_or_fetch_block_receipts(w3, block_number, trace_db_session),
), _find_or_fetch_block_traces(w3, block_number, trace_db_session),
_await_if_absent( _find_or_fetch_base_fee_per_gas(w3, block_number, trace_db_session),
lambda: _fetch_block_receipts(w3, block_number), existing_receipts
),
_await_if_absent(
lambda: _fetch_block_traces(w3, block_number), existing_traces
),
_await_if_absent(
lambda: fetch_base_fee_per_gas(w3, block_number), existing_base_fee_per_gas
),
) )
miner_address = _get_miner_address_from_traces(traces) miner_address = _get_miner_address_from_traces(traces)
@ -68,6 +48,60 @@ async def create_from_block_number(
) )
async def _find_or_fetch_block_timestamp(
w3,
block_number: int,
trace_db_session: Optional[orm.Session],
) -> int:
if trace_db_session is not None:
existing_block_timestamp = _find_block_timestamp(trace_db_session, block_number)
if existing_block_timestamp is not None:
return existing_block_timestamp
return await _fetch_block_timestamp(w3, block_number)
async def _find_or_fetch_block_receipts(
w3,
block_number: int,
trace_db_session: Optional[orm.Session],
) -> List[Receipt]:
if trace_db_session is not None:
existing_block_receipts = _find_block_receipts(trace_db_session, block_number)
if existing_block_receipts is not None:
return existing_block_receipts
return await _fetch_block_receipts(w3, block_number)
async def _find_or_fetch_block_traces(
w3,
block_number: int,
trace_db_session: Optional[orm.Session],
) -> List[Trace]:
if trace_db_session is not None:
existing_block_traces = _find_block_traces(trace_db_session, block_number)
if existing_block_traces is not None:
return existing_block_traces
return await _fetch_block_traces(w3, block_number)
async def _find_or_fetch_base_fee_per_gas(
w3,
block_number: int,
trace_db_session: Optional[orm.Session],
) -> int:
if trace_db_session is not None:
existing_base_fee_per_gas = _find_base_fee_per_gas(
trace_db_session, block_number
)
if existing_base_fee_per_gas is not None:
return existing_base_fee_per_gas
return await fetch_base_fee_per_gas(w3, block_number)
async def _fetch_block_timestamp(w3, block_number: int) -> int: async def _fetch_block_timestamp(w3, block_number: int) -> int:
block_json = await w3.eth.get_block(block_number) block_json = await w3.eth.get_block(block_number)
return block_json["timestamp"] return block_json["timestamp"]
@ -83,15 +117,6 @@ async def _fetch_block_traces(w3, block_number: int) -> List[Trace]:
return [Trace(**trace_json) for trace_json in traces_json] return [Trace(**trace_json) for trace_json in traces_json]
async def _await_if_absent(
awaitable: Callable[[], Awaitable[T]], existing: Optional[T]
) -> T:
if existing is not None:
return existing
else:
return await awaitable()
def _find_block_timestamp( def _find_block_timestamp(
trace_db_session: orm.Session, trace_db_session: orm.Session,
block_number: int, block_number: int,
@ -108,7 +133,7 @@ def _find_block_timestamp(
return block_timestamp return block_timestamp
def _find_traces( def _find_block_traces(
trace_db_session: orm.Session, trace_db_session: orm.Session,
block_number: int, block_number: int,
) -> Optional[List[Trace]]: ) -> Optional[List[Trace]]:
@ -124,7 +149,7 @@ def _find_traces(
return [Trace(**trace_json) for trace_json in traces_json] return [Trace(**trace_json) for trace_json in traces_json]
def _find_receipts( def _find_block_receipts(
trace_db_session: orm.Session, trace_db_session: orm.Session,
block_number: int, block_number: int,
) -> Optional[List[Receipt]]: ) -> Optional[List[Receipt]]:
@ -140,7 +165,7 @@ def _find_receipts(
return [Receipt(**receipt) for receipt in receipts_json] return [Receipt(**receipt) for receipt in receipts_json]
def _find_base_fee( def _find_base_fee_per_gas(
trace_db_session: orm.Session, trace_db_session: orm.Session,
block_number: int, block_number: int,
) -> Optional[int]: ) -> Optional[int]: