diff --git a/mev_inspect/block.py b/mev_inspect/block.py index df156df..dcd3554 100644 --- a/mev_inspect/block.py +++ b/mev_inspect/block.py @@ -80,19 +80,19 @@ async def _fetch_block( else: raise else: - block_json = await w3.eth.get_block(block_number) - print(block_json) - traces = geth_get_tx_traces_parity_format(base_provider, block_json) - geth_tx_receipts = geth_get_tx_receipts( - base_provider, block_json["transactions"] + # print(block_number) + block_json = await asyncio.gather(w3.eth.get_block(block_number)) + traces = await geth_get_tx_traces_parity_format(base_provider, block_json[0]) + geth_tx_receipts = await geth_get_tx_receipts_async( + base_provider.endpoint_uri, block_json[0]["transactions"] ) - receipts = geth_receipts_translator(block_json, geth_tx_receipts) + receipts = geth_receipts_translator(block_json[0], geth_tx_receipts) base_fee_per_gas = 0 return Block( block_number=block_number, - block_timestamp=block_json["timestamp"], - miner=block_json["miner"], + block_timestamp=block_json[0]["timestamp"], + miner=block_json[0]["miner"], base_fee_per_gas=base_fee_per_gas, traces=traces, receipts=receipts, @@ -220,9 +220,11 @@ def get_transaction_hashes(calls: List[Trace]) -> List[str]: # Geth specific additions -def geth_get_tx_traces_parity_format(base_provider, block_json: dict): +async def geth_get_tx_traces_parity_format(base_provider, block_json: dict): + # print(block_json['hash'].hex()) block_hash = block_json["hash"] - block_trace = geth_get_tx_traces(base_provider, block_hash) + block_trace = await geth_get_tx_traces(base_provider, block_hash) + # print(block_trace) parity_traces = [] for idx, trace in enumerate(block_trace["result"]): if "result" in trace: @@ -232,8 +234,8 @@ def geth_get_tx_traces_parity_format(base_provider, block_json: dict): return parity_traces -def geth_get_tx_traces(base_provider, block_hash): - block_trace = base_provider.make_request( +async def geth_get_tx_traces(base_provider, block_hash): + block_trace = await base_provider.make_request( "debug_traceBlockByHash", [block_hash.hex(), {"tracer": "callTracer"}] ) return block_trace @@ -314,12 +316,6 @@ async def geth_get_tx_receipts_async(endpoint_uri, transactions): return [json.loads(tx_receipts) for tx_receipts in geth_tx_receipts] -def geth_get_tx_receipts(base_provider, transactions): - return asyncio.run( - geth_get_tx_receipts_async(base_provider.endpoint_uri, transactions) - ) - - def geth_receipts_translator(block_json, geth_tx_receipts) -> List[Receipt]: json_decoded_receipts = [ tx_receipt["result"] diff --git a/mev_inspect/geth_poa_middleware.py b/mev_inspect/geth_poa_middleware.py new file mode 100644 index 0000000..712ac95 --- /dev/null +++ b/mev_inspect/geth_poa_middleware.py @@ -0,0 +1,99 @@ +from typing import ( + Any, + Callable, +) + +from hexbytes import ( + HexBytes, +) + +from eth_utils.curried import ( + apply_formatter_if, + apply_formatters_to_dict, + apply_key_map, + is_null, +) +from eth_utils.toolz import ( + complement, + compose, + assoc, +) + +from web3._utils.rpc_abi import ( + RPC, +) + +from web3.types import ( + Formatters, + RPCEndpoint, + RPCResponse, +) + +from web3 import Web3 # noqa: F401 + + +async def get_geth_poa_middleware( + make_request: Callable[[RPCEndpoint, Any], RPCResponse], + request_formatters: Formatters = {}, + result_formatters: Formatters = {}, + error_formatters: Formatters = {}, +) -> RPCResponse: + async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + if method in request_formatters: + formatter = request_formatters[method] + formatted_params = formatter(params) + response = await make_request(method, formatted_params) + else: + response = await make_request(method, params) + + if "result" in response and method in result_formatters: + formatter = result_formatters[method] + formatted_response = assoc( + response, + "result", + formatter(response["result"]), + ) + return formatted_response + elif "error" in response and method in error_formatters: + formatter = error_formatters[method] + formatted_response = assoc( + response, + "error", + formatter(response["error"]), + ) + return formatted_response + else: + return response + + return middleware + + +is_not_null = complement(is_null) + +remap_geth_poa_fields = apply_key_map( + { + "extraData": "proofOfAuthorityData", + } +) + +pythonic_geth_poa = apply_formatters_to_dict( + { + "proofOfAuthorityData": HexBytes, + } +) + +geth_poa_cleanup = compose(pythonic_geth_poa, remap_geth_poa_fields) + + +async def geth_poa_middleware( + make_request: Callable[[RPCEndpoint, Any], Any], *_: Web3 +): + return await get_geth_poa_middleware( + make_request=make_request, + request_formatters={}, + result_formatters={ + RPC.eth_getBlockByHash: apply_formatter_if(is_not_null, geth_poa_cleanup), + RPC.eth_getBlockByNumber: apply_formatter_if(is_not_null, geth_poa_cleanup), + }, + error_formatters={}, + ) diff --git a/mev_inspect/inspector.py b/mev_inspect/inspector.py index 0fe97f6..51ea585 100644 --- a/mev_inspect/inspector.py +++ b/mev_inspect/inspector.py @@ -7,7 +7,6 @@ from typing import Optional from sqlalchemy import orm from web3 import Web3 from web3.eth import AsyncEth -from web3.middleware import geth_poa_middleware from mev_inspect.block import create_from_block_number from mev_inspect.classifiers.trace import TraceClassifier @@ -29,18 +28,19 @@ class MEVInspector: ): self.inspect_db_session = inspect_db_session self.trace_db_session = trace_db_session - self.base_provider = get_base_provider(rpc, request_timeout=request_timeout) + self.base_provider = get_base_provider(rpc, request_timeout, geth) self.geth = geth - if geth: - self.w3 = Web3( - self.base_provider, - modules={"eth": (AsyncEth,)}, - middlewares=[geth_poa_middleware], - ) - else: - self.w3 = Web3( - self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[] - ) + self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]) + # if geth: + # self.w3 = Web3( + # self.base_provider, + # modules={"eth": (AsyncEth,)}, + # middlewares=[], + # ) + # else: + # self.w3 = Web3( + # self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[] + # ) self.trace_classifier = TraceClassifier() self.max_concurrency = asyncio.Semaphore(max_concurrency) diff --git a/mev_inspect/provider.py b/mev_inspect/provider.py index 3b930ea..9fb20eb 100644 --- a/mev_inspect/provider.py +++ b/mev_inspect/provider.py @@ -1,9 +1,18 @@ from web3 import Web3, AsyncHTTPProvider from mev_inspect.retry import http_retry_with_backoff_request_middleware +from mev_inspect.geth_poa_middleware import geth_poa_middleware -def get_base_provider(rpc: str, request_timeout: int = 500) -> Web3.AsyncHTTPProvider: +def get_base_provider( + rpc: str, request_timeout: int = 500, geth: bool = False +) -> Web3.AsyncHTTPProvider: base_provider = AsyncHTTPProvider(rpc, request_kwargs={"timeout": request_timeout}) - base_provider.middlewares += (http_retry_with_backoff_request_middleware,) + if geth: + base_provider.middlewares += ( + geth_poa_middleware, + http_retry_with_backoff_request_middleware, + ) + else: + base_provider.middlewares += (http_retry_with_backoff_request_middleware,) return base_provider