Use middleware for trace and receipt methods

This commit is contained in:
Luke Van Seters 2021-12-23 22:21:18 -05:00
parent 2982ff700f
commit fcc453391f
5 changed files with 38 additions and 32 deletions

View File

@ -24,7 +24,6 @@ async def get_latest_block_number(base_provider) -> int:
async def create_from_block_number( async def create_from_block_number(
base_provider,
w3: Web3, w3: Web3,
block_number: int, block_number: int,
trace_db_session: Optional[orm.Session], trace_db_session: Optional[orm.Session],
@ -35,34 +34,22 @@ async def create_from_block_number(
block = _find_block(trace_db_session, block_number) block = _find_block(trace_db_session, block_number)
if block is None: if block is None:
block = await _fetch_block(w3, base_provider, block_number) block = await _fetch_block(w3, block_number)
return block return block
else: else:
return block return block
async def _fetch_block(w3, base_provider, block_number: int, retries: int = 0) -> Block: async def _fetch_block(w3, block_number: int) -> Block:
block_json, receipts_json, traces_json, base_fee_per_gas = await asyncio.gather( block_json, receipts_json, traces_json, base_fee_per_gas = await asyncio.gather(
w3.eth.get_block(block_number), w3.eth.get_block(block_number),
base_provider.make_request("eth_getBlockReceipts", [block_number]), w3.eth.get_block_receipts(block_number),
base_provider.make_request("trace_block", [block_number]), w3.eth.trace_block(block_number),
fetch_base_fee_per_gas(w3, block_number), fetch_base_fee_per_gas(w3, block_number),
) )
try: receipts: List[Receipt] = [Receipt(**receipt) for receipt in receipts_json]
receipts: List[Receipt] = [ traces = [Trace(**trace_json) for trace_json in traces_json]
Receipt(**receipt) for receipt in receipts_json["result"]
]
traces = [Trace(**trace_json) for trace_json in traces_json["result"]]
except KeyError as e:
logger.warning(
f"Failed to create objects from block: {block_number}: {e}, retrying: {retries + 1} / 3"
)
if retries < 3:
await asyncio.sleep(5)
return await _fetch_block(w3, base_provider, block_number, retries)
else:
raise
return Block( return Block(
block_number=block_number, block_number=block_number,

View File

@ -55,7 +55,6 @@ logger = logging.getLogger(__name__)
async def inspect_block( async def inspect_block(
inspect_db_session: orm.Session, inspect_db_session: orm.Session,
base_provider,
w3: Web3, w3: Web3,
trace_classifier: TraceClassifier, trace_classifier: TraceClassifier,
block_number: int, block_number: int,
@ -64,7 +63,6 @@ async def inspect_block(
): ):
await inspect_many_blocks( await inspect_many_blocks(
inspect_db_session, inspect_db_session,
base_provider,
w3, w3,
trace_classifier, trace_classifier,
block_number, block_number,
@ -76,7 +74,6 @@ async def inspect_block(
async def inspect_many_blocks( async def inspect_many_blocks(
inspect_db_session: orm.Session, inspect_db_session: orm.Session,
base_provider,
w3: Web3, w3: Web3,
trace_classifier: TraceClassifier, trace_classifier: TraceClassifier,
after_block_number: int, after_block_number: int,
@ -100,7 +97,6 @@ async def inspect_many_blocks(
for block_number in range(after_block_number, before_block_number): for block_number in range(after_block_number, before_block_number):
block = await create_from_block_number( block = await create_from_block_number(
base_provider,
w3, w3,
block_number, block_number,
trace_db_session, trace_db_session,

View File

@ -11,11 +11,18 @@ from web3.eth import AsyncEth
from mev_inspect.block import create_from_block_number from mev_inspect.block import create_from_block_number
from mev_inspect.classifiers.trace import TraceClassifier from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.inspect_block import inspect_block, inspect_many_blocks from mev_inspect.inspect_block import inspect_block, inspect_many_blocks
from mev_inspect.methods import get_block_receipts, trace_block
from mev_inspect.provider import get_base_provider from mev_inspect.provider import get_base_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# add missing parity methods
# this is a bit gross
AsyncEth.trace_block = trace_block
AsyncEth.get_block_receipts = get_block_receipts
class MEVInspector: class MEVInspector:
def __init__( def __init__(
self, self,
@ -27,14 +34,15 @@ class MEVInspector:
): ):
self.inspect_db_session = inspect_db_session self.inspect_db_session = inspect_db_session
self.trace_db_session = trace_db_session self.trace_db_session = trace_db_session
self.base_provider = get_base_provider(rpc, request_timeout=request_timeout)
self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]) base_provider = get_base_provider(rpc, request_timeout=request_timeout)
self.w3 = Web3(base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
self.trace_classifier = TraceClassifier() self.trace_classifier = TraceClassifier()
self.max_concurrency = asyncio.Semaphore(max_concurrency) self.max_concurrency = asyncio.Semaphore(max_concurrency)
async def create_from_block(self, block_number: int): async def create_from_block(self, block_number: int):
return await create_from_block_number( return await create_from_block_number(
base_provider=self.base_provider,
w3=self.w3, w3=self.w3,
block_number=block_number, block_number=block_number,
trace_db_session=self.trace_db_session, trace_db_session=self.trace_db_session,
@ -43,7 +51,6 @@ class MEVInspector:
async def inspect_single_block(self, block: int): async def inspect_single_block(self, block: int):
return await inspect_block( return await inspect_block(
self.inspect_db_session, self.inspect_db_session,
self.base_provider,
self.w3, self.w3,
self.trace_classifier, self.trace_classifier,
block, block,
@ -87,7 +94,6 @@ class MEVInspector:
async with self.max_concurrency: async with self.max_concurrency:
return await inspect_many_blocks( return await inspect_many_blocks(
self.inspect_db_session, self.inspect_db_session,
self.base_provider,
self.w3, self.w3,
self.trace_classifier, self.trace_classifier,
after_block_number, after_block_number,

16
mev_inspect/methods.py Normal file
View File

@ -0,0 +1,16 @@
from typing import Callable, List
from web3._utils.rpc_abi import RPC
from web3.method import Method, default_root_munger
from web3.types import BlockIdentifier, ParityBlockTrace, RPCEndpoint
trace_block: Method[Callable[[BlockIdentifier], List[ParityBlockTrace]]] = Method(
RPC.trace_block,
mungers=[default_root_munger],
)
get_block_receipts: Method[Callable[[BlockIdentifier], List[dict]]] = Method(
RPCEndpoint("eth_getBlockReceipts"),
mungers=[default_root_munger],
)

View File

@ -1,7 +1,7 @@
import asyncio
import logging import logging
import random import random
from asyncio.exceptions import TimeoutError from asyncio.exceptions import TimeoutError
from time import sleep
from typing import Any, Callable, Collection, Coroutine, Type from typing import Any, Callable, Collection, Coroutine, Type
from aiohttp.client_exceptions import ( from aiohttp.client_exceptions import (
@ -49,13 +49,14 @@ async def exception_retry_with_backoff_middleware(
logger.error( logger.error(
f"Request for method {method}, block: {int(params[0], 16)}, retrying: {i}/{retries}" f"Request for method {method}, block: {int(params[0], 16)}, retrying: {i}/{retries}"
) )
if i < retries - 1: if i < (retries - 1):
backoff_time = backoff_time_seconds * ( backoff_time = backoff_time_seconds * (
random.uniform(5, 10) ** i random.uniform(5, 10) ** i
) )
await asyncio.sleep(backoff_time) # use blocking sleep to prevent new tries on
# concurrent requests during sleep
sleep(backoff_time)
continue continue
else: else:
raise raise
return None return None