diff --git a/cli.py b/cli.py index 86f9594..7bea3f5 100644 --- a/cli.py +++ b/cli.py @@ -7,6 +7,7 @@ import click from mev_inspect.concurrency import coro from mev_inspect.db import get_inspect_session, get_trace_session from mev_inspect.inspector import MEVInspector +from mev_inspect.utils import RPCType RPC_URL_ENV = "RPC_URL" @@ -21,17 +22,29 @@ def cli(): @cli.command() @click.argument("block_number", type=int) @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) -@click.option("--geth/--no-geth", default=False) +@click.option( + "--type", + type=click.Choice(list(map(lambda x: x.name, RPCType)), case_sensitive=False), + default=RPCType.parity.name, +) @coro -async def inspect_block_command(block_number: int, rpc: str, geth: bool): - print("geth", geth) +async def inspect_block_command(block_number: int, rpc: str, type: str): + type_e = convert_str_to_enum(type) inspect_db_session = get_inspect_session() trace_db_session = get_trace_session() - inspector = MEVInspector(rpc, inspect_db_session, trace_db_session, geth) + inspector = MEVInspector(rpc, inspect_db_session, trace_db_session, type_e) await inspector.inspect_single_block(block=block_number) +def convert_str_to_enum(type: str) -> RPCType: + if type == "parity": + return RPCType.parity + elif type == "geth": + return RPCType.geth + raise ValueError + + @cli.command() @click.argument("block_number", type=int) @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) @@ -40,7 +53,7 @@ async def fetch_block_command(block_number: int, rpc: str): inspect_db_session = get_inspect_session() trace_db_session = get_trace_session() - inspector = MEVInspector(rpc, inspect_db_session, trace_db_session, False) + inspector = MEVInspector(rpc, inspect_db_session, trace_db_session, RPCType.parity) block = await inspector.create_from_block(block_number=block_number) print(block.json()) @@ -49,7 +62,11 @@ async def fetch_block_command(block_number: int, rpc: str): @click.argument("after_block", type=int) @click.argument("before_block", type=int) @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) -@click.option("--geth/--no-geth", default=False) +@click.option( + "--type", + type=click.Choice(list(map(lambda x: x.name, RPCType)), case_sensitive=False), + default=RPCType.parity.name, +) @click.option( "--max-concurrency", type=int, @@ -66,15 +83,16 @@ async def inspect_many_blocks_command( rpc: str, max_concurrency: int, request_timeout: int, - geth: bool, + type: str, ): + type_e = convert_str_to_enum(type) inspect_db_session = get_inspect_session() trace_db_session = get_trace_session() inspector = MEVInspector( rpc, inspect_db_session, trace_db_session, - geth, + type_e, max_concurrency=max_concurrency, request_timeout=request_timeout, ) diff --git a/mev_inspect/block.py b/mev_inspect/block.py index dcd3554..5f6a3fa 100644 --- a/mev_inspect/block.py +++ b/mev_inspect/block.py @@ -1,8 +1,6 @@ import asyncio import logging from typing import List, Optional -import json -import aiohttp from sqlalchemy import orm from web3 import Web3 @@ -11,10 +9,17 @@ from mev_inspect.fees import fetch_base_fee_per_gas from mev_inspect.schemas.blocks import Block from mev_inspect.schemas.receipts import Receipt from mev_inspect.schemas.traces import Trace, TraceType -from mev_inspect.utils import hex_to_int +from mev_inspect.utils import RPCType, hex_to_int logger = logging.getLogger(__name__) +_calltype_mapping = { + "CALL": "call", + "DELEGATECALL": "delegateCall", + "CREATE": "create", + "SUICIDE": "suicide", + "REWARD": "reward", +} async def get_latest_block_number(base_provider) -> int: @@ -29,7 +34,7 @@ async def get_latest_block_number(base_provider) -> int: async def create_from_block_number( base_provider, w3: Web3, - geth: bool, + type: RPCType, block_number: int, trace_db_session: Optional[orm.Session], ) -> Block: @@ -39,55 +44,63 @@ async def create_from_block_number( block = _find_block(trace_db_session, block_number) if block is None: - block = await _fetch_block(w3, base_provider, geth, block_number) - return block + if type is RPCType.parity: + block = await _fetch_block_parity(w3, base_provider, block_number) + elif type is RPCType.geth: + block = await _fetch_block_geth(w3, base_provider, block_number) + else: + logger.error(f"RPCType not known - {type}") + raise ValueError return block -async def _fetch_block( - w3, base_provider, geth: bool, block_number: int, retries: int = 0 +async def _fetch_block_parity( + w3, base_provider, block_number: int, retries: int = 0 ) -> Block: - if not geth: - block_json, receipts_json, traces_json, base_fee_per_gas = await asyncio.gather( - w3.eth.get_block(block_number), - base_provider.make_request("eth_getBlockReceipts", [block_number]), - base_provider.make_request("trace_block", [block_number]), - fetch_base_fee_per_gas(w3, block_number), - ) + block_json, receipts_json, traces_json, base_fee_per_gas = await asyncio.gather( + w3.eth.get_block(block_number), + base_provider.make_request("eth_getBlockReceipts", [block_number]), + base_provider.make_request("trace_block", [block_number]), + fetch_base_fee_per_gas(w3, block_number), + ) - try: - receipts: List[Receipt] = [ - Receipt(**receipt) for receipt in receipts_json["result"] - ] - traces = [Trace(**trace_json) for trace_json in traces_json["result"]] - return Block( - block_number=block_number, - block_timestamp=block_json["timestamp"], - miner=block_json["miner"], - base_fee_per_gas=base_fee_per_gas, - traces=traces, - receipts=receipts, - ) - except KeyError as e: - logger.warning( - f"Failed to create objects from block: {block_number}: {e}, retrying: {retries + 1} / 3" - ) - if retries < 3: - await asyncio.sleep(5) - return await _fetch_block( - w3, base_provider, geth, block_number, retries - ) - else: - raise - else: - # print(block_number) - block_json = await asyncio.gather(w3.eth.get_block(block_number)) + try: + receipts: List[Receipt] = [ + Receipt(**receipt) for receipt in receipts_json["result"] + ] + traces = [Trace(**trace_json) for trace_json in traces_json["result"]] + return Block( + block_number=block_number, + block_timestamp=block_json["timestamp"], + miner=block_json["miner"], + base_fee_per_gas=base_fee_per_gas, + traces=traces, + receipts=receipts, + ) + except KeyError as e: + logger.warning( + f"Failed to create objects from block: {block_number}: {e}, retrying: {retries + 1} / 3" + ) + if retries < 3: + await asyncio.sleep(5) + return await _fetch_block_parity(w3, base_provider, block_number, retries) + else: + raise + + +async def _fetch_block_geth( + w3, base_provider, block_number: int, retries: int = 0 +) -> Block: + block_json = await asyncio.gather(w3.eth.get_block(block_number)) + + try: + # Separate calls to help with load during block tracing traces = await geth_get_tx_traces_parity_format(base_provider, block_json[0]) geth_tx_receipts = await geth_get_tx_receipts_async( - base_provider.endpoint_uri, block_json[0]["transactions"] + base_provider, block_json[0]["transactions"] ) receipts = geth_receipts_translator(block_json[0], geth_tx_receipts) - base_fee_per_gas = 0 + base_fee_per_gas = 0 # Polygon specific, TODO for other chains return Block( block_number=block_number, @@ -97,6 +110,15 @@ async def _fetch_block( traces=traces, receipts=receipts, ) + except KeyError as e: + logger.warning( + f"Failed to create objects from block: {block_number}: {e}, retrying: {retries + 1} / 3" + ) + if retries < 3: + await asyncio.sleep(5) + return await _fetch_block_geth(w3, base_provider, block_number, retries) + else: + raise def _find_block( @@ -245,13 +267,6 @@ def unwrap_tx_trace_for_parity( block_json, tx_pos_in_block, tx_trace, position=[] ) -> List[Trace]: response_list = [] - _calltype_mapping = { - "CALL": "call", - "DELEGATECALL": "delegateCall", - "CREATE": "create", - "SUICIDE": "suicide", - "REWARD": "reward", - } try: if tx_trace["type"] == "STATICCALL": return [] @@ -279,7 +294,8 @@ def unwrap_tx_trace_for_parity( type=TraceType(_calltype_mapping[tx_trace["type"]]), ) ) - except Exception: + except Exception as e: + logger.warn(f"error while unwraping tx trace for parity {e}") return [] if "calls" in tx_trace.keys(): @@ -292,28 +308,20 @@ def unwrap_tx_trace_for_parity( return response_list -async def geth_get_tx_receipts_task(session, endpoint_uri, tx): - data = { - "jsonrpc": "2.0", - "id": "0", - "method": "eth_getTransactionReceipt", - "params": [tx.hex()], - } - async with session.post(endpoint_uri, json=data) as response: - if response.status != 200: - response.raise_for_status() - return await response.text() +async def geth_get_tx_receipts_task(base_provider, tx): + receipt = await base_provider.make_request("eth_getTransactionReceipt", [tx.hex()]) + return receipt -async def geth_get_tx_receipts_async(endpoint_uri, transactions): +async def geth_get_tx_receipts_async(base_provider, transactions): geth_tx_receipts = [] - async with aiohttp.ClientSession() as session: - tasks = [ - asyncio.create_task(geth_get_tx_receipts_task(session, endpoint_uri, tx)) - for tx in transactions - ] - geth_tx_receipts = await asyncio.gather(*tasks) - return [json.loads(tx_receipts) for tx_receipts in geth_tx_receipts] + tasks = [ + asyncio.create_task(geth_get_tx_receipts_task(base_provider, tx)) + for tx in transactions + ] + geth_tx_receipts = await asyncio.gather(*tasks) + # return [json.loads(tx_receipts) for tx_receipts in geth_tx_receipts] + return geth_tx_receipts def geth_receipts_translator(block_json, geth_tx_receipts) -> List[Receipt]: @@ -331,24 +339,18 @@ def geth_receipts_translator(block_json, geth_tx_receipts) -> List[Receipt]: def unwrap_tx_receipt_for_parity(block_json, tx_pos_in_block, tx_receipt) -> Receipt: - try: - if tx_pos_in_block != int(tx_receipt["transactionIndex"], 16): - print( - "Alert the position of transaction in block is mismatched ", - tx_pos_in_block, - tx_receipt["transactionIndex"], - ) - return Receipt( - block_number=block_json["number"], - transaction_hash=tx_receipt["transactionHash"], - transaction_index=tx_pos_in_block, - gas_used=tx_receipt["gasUsed"], - effective_gas_price=tx_receipt["effectiveGasPrice"], - cumulative_gas_used=tx_receipt["cumulativeGasUsed"], - to=tx_receipt["to"], + if tx_pos_in_block != int(tx_receipt["transactionIndex"], 16): + logger.info( + "Alert the position of transaction in block is mismatched ", + tx_pos_in_block, + tx_receipt["transactionIndex"], ) - - except Exception as e: - print("error while decoding receipt", tx_receipt, e) - - return Receipt() + return Receipt( + block_number=block_json["number"], + transaction_hash=tx_receipt["transactionHash"], + transaction_index=tx_pos_in_block, + gas_used=tx_receipt["gasUsed"], + effective_gas_price=tx_receipt["effectiveGasPrice"], + cumulative_gas_used=tx_receipt["cumulativeGasUsed"], + to=tx_receipt["to"], + ) diff --git a/mev_inspect/geth_poa_middleware.py b/mev_inspect/geth_poa_middleware.py index d5ae8c2..6eb103d 100644 --- a/mev_inspect/geth_poa_middleware.py +++ b/mev_inspect/geth_poa_middleware.py @@ -1,3 +1,7 @@ +""" +Modified asynchronous geth_poa_middleware which mirrors functionality of +https://github.com/ethereum/web3.py/blob/master/web3/middleware/geth_poa.py +""" from typing import ( Any, Callable, diff --git a/mev_inspect/inspect_block.py b/mev_inspect/inspect_block.py index 53ad8a4..11c5625 100644 --- a/mev_inspect/inspect_block.py +++ b/mev_inspect/inspect_block.py @@ -34,6 +34,7 @@ from mev_inspect.miner_payments import get_miner_payments from mev_inspect.swaps import get_swaps from mev_inspect.transfers import get_transfers from mev_inspect.liquidations import get_liquidations +from mev_inspect.utils import RPCType logger = logging.getLogger(__name__) @@ -43,7 +44,7 @@ async def inspect_block( inspect_db_session: orm.Session, base_provider, w3: Web3, - geth: bool, + type: RPCType, trace_classifier: TraceClassifier, block_number: int, trace_db_session: Optional[orm.Session], @@ -52,7 +53,7 @@ async def inspect_block( block = await create_from_block_number( base_provider, w3, - geth, + type, block_number, trace_db_session, ) diff --git a/mev_inspect/inspector.py b/mev_inspect/inspector.py index 51ea585..cfc2df1 100644 --- a/mev_inspect/inspector.py +++ b/mev_inspect/inspector.py @@ -12,6 +12,7 @@ from mev_inspect.block import create_from_block_number from mev_inspect.classifiers.trace import TraceClassifier from mev_inspect.inspect_block import inspect_block from mev_inspect.provider import get_base_provider +from mev_inspect.utils import RPCType logger = logging.getLogger(__name__) @@ -22,25 +23,15 @@ class MEVInspector: rpc: str, inspect_db_session: orm.Session, trace_db_session: Optional[orm.Session], - geth: bool = False, + type: RPCType = RPCType.parity, max_concurrency: int = 1, request_timeout: int = 300, ): self.inspect_db_session = inspect_db_session self.trace_db_session = trace_db_session - self.base_provider = get_base_provider(rpc, request_timeout, geth) - self.geth = geth + self.base_provider = get_base_provider(rpc, request_timeout, type) + self.type = type self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]) - # if geth: - # self.w3 = Web3( - # self.base_provider, - # modules={"eth": (AsyncEth,)}, - # middlewares=[], - # ) - # else: - # self.w3 = Web3( - # self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[] - # ) self.trace_classifier = TraceClassifier() self.max_concurrency = asyncio.Semaphore(max_concurrency) @@ -48,7 +39,7 @@ class MEVInspector: return await create_from_block_number( base_provider=self.base_provider, w3=self.w3, - geth=self.geth, + type=self.type, block_number=block_number, trace_db_session=self.trace_db_session, ) @@ -58,7 +49,7 @@ class MEVInspector: self.inspect_db_session, self.base_provider, self.w3, - self.geth, + self.type, self.trace_classifier, block, trace_db_session=self.trace_db_session, @@ -87,7 +78,7 @@ class MEVInspector: self.inspect_db_session, self.base_provider, self.w3, - self.geth, + self.type, self.trace_classifier, block_number, trace_db_session=self.trace_db_session, diff --git a/mev_inspect/provider.py b/mev_inspect/provider.py index 9fb20eb..836d51e 100644 --- a/mev_inspect/provider.py +++ b/mev_inspect/provider.py @@ -2,13 +2,14 @@ from web3 import Web3, AsyncHTTPProvider from mev_inspect.retry import http_retry_with_backoff_request_middleware from mev_inspect.geth_poa_middleware import geth_poa_middleware +from mev_inspect.utils import RPCType def get_base_provider( - rpc: str, request_timeout: int = 500, geth: bool = False + rpc: str, request_timeout: int = 500, type: RPCType = RPCType.parity ) -> Web3.AsyncHTTPProvider: base_provider = AsyncHTTPProvider(rpc, request_kwargs={"timeout": request_timeout}) - if geth: + if type is RPCType.geth: base_provider.middlewares += ( geth_poa_middleware, http_retry_with_backoff_request_middleware, diff --git a/mev_inspect/utils.py b/mev_inspect/utils.py index 922fada..eb78413 100644 --- a/mev_inspect/utils.py +++ b/mev_inspect/utils.py @@ -1,5 +1,11 @@ +from enum import Enum from hexbytes._utils import hexstr_to_bytes +class RPCType(Enum): + parity = 0 + geth = 1 + + def hex_to_int(value: str) -> int: return int.from_bytes(hexstr_to_bytes(value), byteorder="big")