Merge pull request #195 from flashbots/consistent-middleware

Use middleware for all RPC calls
This commit is contained in:
Luke Van Seters 2021-12-30 10:11:33 -05:00 committed by GitHub
commit 9235020999
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 56 additions and 34 deletions

View File

@ -24,7 +24,6 @@ async def get_latest_block_number(base_provider) -> int:
async def create_from_block_number(
base_provider,
w3: Web3,
block_number: int,
trace_db_session: Optional[orm.Session],
@ -35,34 +34,22 @@ async def create_from_block_number(
block = _find_block(trace_db_session, block_number)
if block is None:
block = await _fetch_block(w3, base_provider, block_number)
block = await _fetch_block(w3, block_number)
return block
else:
return block
async def _fetch_block(w3, base_provider, block_number: int, retries: int = 0) -> Block:
async def _fetch_block(w3, block_number: int) -> Block:
block_json, receipts_json, traces_json, base_fee_per_gas = await asyncio.gather(
w3.eth.get_block(block_number),
base_provider.make_request("eth_getBlockReceipts", [block_number]),
base_provider.make_request("trace_block", [block_number]),
w3.eth.get_block_receipts(block_number),
w3.eth.trace_block(block_number),
fetch_base_fee_per_gas(w3, block_number),
)
try:
receipts: List[Receipt] = [
Receipt(**receipt) for receipt in receipts_json["result"]
]
traces = [Trace(**trace_json) for trace_json in traces_json["result"]]
except KeyError as e:
logger.warning(
f"Failed to create objects from block: {block_number}: {e}, retrying: {retries + 1} / 3"
)
if retries < 3:
await asyncio.sleep(5)
return await _fetch_block(w3, base_provider, block_number, retries)
else:
raise
receipts: List[Receipt] = [Receipt(**receipt) for receipt in receipts_json]
traces = [Trace(**trace_json) for trace_json in traces_json]
return Block(
block_number=block_number,

View File

@ -58,7 +58,6 @@ logger = logging.getLogger(__name__)
async def inspect_block(
inspect_db_session: orm.Session,
base_provider,
w3: Web3,
trace_classifier: TraceClassifier,
block_number: int,
@ -67,7 +66,6 @@ async def inspect_block(
):
await inspect_many_blocks(
inspect_db_session,
base_provider,
w3,
trace_classifier,
block_number,
@ -79,7 +77,6 @@ async def inspect_block(
async def inspect_many_blocks(
inspect_db_session: orm.Session,
base_provider,
w3: Web3,
trace_classifier: TraceClassifier,
after_block_number: int,
@ -105,7 +102,6 @@ async def inspect_many_blocks(
for block_number in range(after_block_number, before_block_number):
block = await create_from_block_number(
base_provider,
w3,
block_number,
trace_db_session,

View File

@ -11,11 +11,18 @@ from web3.eth import AsyncEth
from mev_inspect.block import create_from_block_number
from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.inspect_block import inspect_block, inspect_many_blocks
from mev_inspect.methods import get_block_receipts, trace_block
from mev_inspect.provider import get_base_provider
logger = logging.getLogger(__name__)
# add missing parity methods
# this is a bit gross
AsyncEth.trace_block = trace_block
AsyncEth.get_block_receipts = get_block_receipts
class MEVInspector:
def __init__(
self,
@ -27,14 +34,15 @@ class MEVInspector:
):
self.inspect_db_session = inspect_db_session
self.trace_db_session = trace_db_session
self.base_provider = get_base_provider(rpc, request_timeout=request_timeout)
self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
base_provider = get_base_provider(rpc, request_timeout=request_timeout)
self.w3 = Web3(base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
self.trace_classifier = TraceClassifier()
self.max_concurrency = asyncio.Semaphore(max_concurrency)
async def create_from_block(self, block_number: int):
return await create_from_block_number(
base_provider=self.base_provider,
w3=self.w3,
block_number=block_number,
trace_db_session=self.trace_db_session,
@ -43,7 +51,6 @@ class MEVInspector:
async def inspect_single_block(self, block: int):
return await inspect_block(
self.inspect_db_session,
self.base_provider,
self.w3,
self.trace_classifier,
block,
@ -87,7 +94,6 @@ class MEVInspector:
async with self.max_concurrency:
return await inspect_many_blocks(
self.inspect_db_session,
self.base_provider,
self.w3,
self.trace_classifier,
after_block_number,

16
mev_inspect/methods.py Normal file
View File

@ -0,0 +1,16 @@
from typing import Callable, List
from web3._utils.rpc_abi import RPC
from web3.method import Method, default_root_munger
from web3.types import BlockIdentifier, ParityBlockTrace, RPCEndpoint
trace_block: Method[Callable[[BlockIdentifier], List[ParityBlockTrace]]] = Method(
RPC.trace_block,
mungers=[default_root_munger],
)
get_block_receipts: Method[Callable[[BlockIdentifier], List[dict]]] = Method(
RPCEndpoint("eth_getBlockReceipts"),
mungers=[default_root_munger],
)

View File

@ -5,6 +5,7 @@ from asyncio.exceptions import TimeoutError
from typing import Any, Callable, Collection, Coroutine, Type
from aiohttp.client_exceptions import (
ClientConnectorError,
ClientOSError,
ClientResponseError,
ServerDisconnectedError,
@ -12,20 +13,33 @@ from aiohttp.client_exceptions import (
)
from requests.exceptions import ConnectionError, HTTPError, Timeout, TooManyRedirects
from web3 import Web3
from web3.middleware.exception_retry_request import check_if_retry_on_failure
from web3.middleware.exception_retry_request import whitelist
from web3.types import RPCEndpoint, RPCResponse
request_exceptions = (ConnectionError, HTTPError, Timeout, TooManyRedirects)
aiohttp_exceptions = (
ClientOSError,
ClientResponseError,
ClientConnectorError,
ServerDisconnectedError,
ServerTimeoutError,
ClientResponseError,
)
whitelist_additions = ["eth_getBlockReceipts", "trace_block", "eth_feeHistory"]
logger = logging.getLogger(__name__)
def check_if_retry_on_failure(method: RPCEndpoint) -> bool:
root = method.split("_")[0]
if root in (whitelist + whitelist_additions):
return True
elif method in (whitelist + whitelist_additions):
return True
else:
return False
async def exception_retry_with_backoff_middleware(
make_request: Callable[[RPCEndpoint, Any], Any],
web3: Web3, # pylint: disable=unused-argument
@ -47,15 +61,14 @@ async def exception_retry_with_backoff_middleware(
# https://github.com/python/mypy/issues/5349
except errors: # type: ignore
logger.error(
f"Request for method {method}, block: {int(params[0], 16)}, retrying: {i}/{retries}"
f"Request for method {method}, params: {params}, retrying: {i}/{retries}"
)
if i < retries - 1:
if i < (retries - 1):
backoff_time = backoff_time_seconds * (
random.uniform(5, 10) ** i
)
await asyncio.sleep(backoff_time)
continue
else:
raise
return None
@ -71,5 +84,9 @@ async def http_retry_with_backoff_request_middleware(
return await exception_retry_with_backoff_middleware(
make_request,
web3,
(request_exceptions + aiohttp_exceptions + (TimeoutError,)),
(
request_exceptions
+ aiohttp_exceptions
+ (TimeoutError, ConnectionRefusedError)
),
)