added RPCType, fixes

This commit is contained in:
Supragya Raj 2021-12-09 13:17:34 +01:00
parent f705bb9f2b
commit b31f5d750e
7 changed files with 142 additions and 119 deletions

34
cli.py
View File

@ -7,6 +7,7 @@ import click
from mev_inspect.concurrency import coro from mev_inspect.concurrency import coro
from mev_inspect.db import get_inspect_session, get_trace_session from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspector import MEVInspector from mev_inspect.inspector import MEVInspector
from mev_inspect.utils import RPCType
RPC_URL_ENV = "RPC_URL" RPC_URL_ENV = "RPC_URL"
@ -21,17 +22,29 @@ def cli():
@cli.command() @cli.command()
@click.argument("block_number", type=int) @click.argument("block_number", type=int)
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@click.option("--geth/--no-geth", default=False) @click.option(
"--type",
type=click.Choice(list(map(lambda x: x.name, RPCType)), case_sensitive=False),
default=RPCType.parity.name,
)
@coro @coro
async def inspect_block_command(block_number: int, rpc: str, geth: bool): async def inspect_block_command(block_number: int, rpc: str, type: str):
print("geth", geth) type_e = convert_str_to_enum(type)
inspect_db_session = get_inspect_session() inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session() trace_db_session = get_trace_session()
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session, geth) inspector = MEVInspector(rpc, inspect_db_session, trace_db_session, type_e)
await inspector.inspect_single_block(block=block_number) await inspector.inspect_single_block(block=block_number)
def convert_str_to_enum(type: str) -> RPCType:
if type == "parity":
return RPCType.parity
elif type == "geth":
return RPCType.geth
raise ValueError
@cli.command() @cli.command()
@click.argument("block_number", type=int) @click.argument("block_number", type=int)
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@ -40,7 +53,7 @@ async def fetch_block_command(block_number: int, rpc: str):
inspect_db_session = get_inspect_session() inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session() trace_db_session = get_trace_session()
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session, False) inspector = MEVInspector(rpc, inspect_db_session, trace_db_session, RPCType.parity)
block = await inspector.create_from_block(block_number=block_number) block = await inspector.create_from_block(block_number=block_number)
print(block.json()) print(block.json())
@ -49,7 +62,11 @@ async def fetch_block_command(block_number: int, rpc: str):
@click.argument("after_block", type=int) @click.argument("after_block", type=int)
@click.argument("before_block", type=int) @click.argument("before_block", type=int)
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@click.option("--geth/--no-geth", default=False) @click.option(
"--type",
type=click.Choice(list(map(lambda x: x.name, RPCType)), case_sensitive=False),
default=RPCType.parity.name,
)
@click.option( @click.option(
"--max-concurrency", "--max-concurrency",
type=int, type=int,
@ -66,15 +83,16 @@ async def inspect_many_blocks_command(
rpc: str, rpc: str,
max_concurrency: int, max_concurrency: int,
request_timeout: int, request_timeout: int,
geth: bool, type: str,
): ):
type_e = convert_str_to_enum(type)
inspect_db_session = get_inspect_session() inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session() trace_db_session = get_trace_session()
inspector = MEVInspector( inspector = MEVInspector(
rpc, rpc,
inspect_db_session, inspect_db_session,
trace_db_session, trace_db_session,
geth, type_e,
max_concurrency=max_concurrency, max_concurrency=max_concurrency,
request_timeout=request_timeout, request_timeout=request_timeout,
) )

View File

@ -1,8 +1,6 @@
import asyncio import asyncio
import logging import logging
from typing import List, Optional from typing import List, Optional
import json
import aiohttp
from sqlalchemy import orm from sqlalchemy import orm
from web3 import Web3 from web3 import Web3
@ -11,10 +9,17 @@ from mev_inspect.fees import fetch_base_fee_per_gas
from mev_inspect.schemas.blocks import Block from mev_inspect.schemas.blocks import Block
from mev_inspect.schemas.receipts import Receipt from mev_inspect.schemas.receipts import Receipt
from mev_inspect.schemas.traces import Trace, TraceType from mev_inspect.schemas.traces import Trace, TraceType
from mev_inspect.utils import hex_to_int from mev_inspect.utils import RPCType, hex_to_int
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_calltype_mapping = {
"CALL": "call",
"DELEGATECALL": "delegateCall",
"CREATE": "create",
"SUICIDE": "suicide",
"REWARD": "reward",
}
async def get_latest_block_number(base_provider) -> int: async def get_latest_block_number(base_provider) -> int:
@ -29,7 +34,7 @@ async def get_latest_block_number(base_provider) -> int:
async def create_from_block_number( async def create_from_block_number(
base_provider, base_provider,
w3: Web3, w3: Web3,
geth: bool, type: RPCType,
block_number: int, block_number: int,
trace_db_session: Optional[orm.Session], trace_db_session: Optional[orm.Session],
) -> Block: ) -> Block:
@ -39,55 +44,63 @@ async def create_from_block_number(
block = _find_block(trace_db_session, block_number) block = _find_block(trace_db_session, block_number)
if block is None: if block is None:
block = await _fetch_block(w3, base_provider, geth, block_number) if type is RPCType.parity:
return block block = await _fetch_block_parity(w3, base_provider, block_number)
elif type is RPCType.geth:
block = await _fetch_block_geth(w3, base_provider, block_number)
else:
logger.error(f"RPCType not known - {type}")
raise ValueError
return block return block
async def _fetch_block( async def _fetch_block_parity(
w3, base_provider, geth: bool, block_number: int, retries: int = 0 w3, base_provider, block_number: int, retries: int = 0
) -> Block: ) -> Block:
if not geth: block_json, receipts_json, traces_json, base_fee_per_gas = await asyncio.gather(
block_json, receipts_json, traces_json, base_fee_per_gas = await asyncio.gather( w3.eth.get_block(block_number),
w3.eth.get_block(block_number), base_provider.make_request("eth_getBlockReceipts", [block_number]),
base_provider.make_request("eth_getBlockReceipts", [block_number]), base_provider.make_request("trace_block", [block_number]),
base_provider.make_request("trace_block", [block_number]), fetch_base_fee_per_gas(w3, block_number),
fetch_base_fee_per_gas(w3, block_number), )
)
try: try:
receipts: List[Receipt] = [ receipts: List[Receipt] = [
Receipt(**receipt) for receipt in receipts_json["result"] Receipt(**receipt) for receipt in receipts_json["result"]
] ]
traces = [Trace(**trace_json) for trace_json in traces_json["result"]] traces = [Trace(**trace_json) for trace_json in traces_json["result"]]
return Block( return Block(
block_number=block_number, block_number=block_number,
block_timestamp=block_json["timestamp"], block_timestamp=block_json["timestamp"],
miner=block_json["miner"], miner=block_json["miner"],
base_fee_per_gas=base_fee_per_gas, base_fee_per_gas=base_fee_per_gas,
traces=traces, traces=traces,
receipts=receipts, receipts=receipts,
) )
except KeyError as e: except KeyError as e:
logger.warning( logger.warning(
f"Failed to create objects from block: {block_number}: {e}, retrying: {retries + 1} / 3" f"Failed to create objects from block: {block_number}: {e}, retrying: {retries + 1} / 3"
) )
if retries < 3: if retries < 3:
await asyncio.sleep(5) await asyncio.sleep(5)
return await _fetch_block( return await _fetch_block_parity(w3, base_provider, block_number, retries)
w3, base_provider, geth, block_number, retries else:
) raise
else:
raise
else: async def _fetch_block_geth(
# print(block_number) w3, base_provider, block_number: int, retries: int = 0
block_json = await asyncio.gather(w3.eth.get_block(block_number)) ) -> Block:
block_json = await asyncio.gather(w3.eth.get_block(block_number))
try:
# Separate calls to help with load during block tracing
traces = await geth_get_tx_traces_parity_format(base_provider, block_json[0]) traces = await geth_get_tx_traces_parity_format(base_provider, block_json[0])
geth_tx_receipts = await geth_get_tx_receipts_async( geth_tx_receipts = await geth_get_tx_receipts_async(
base_provider.endpoint_uri, block_json[0]["transactions"] base_provider, block_json[0]["transactions"]
) )
receipts = geth_receipts_translator(block_json[0], geth_tx_receipts) receipts = geth_receipts_translator(block_json[0], geth_tx_receipts)
base_fee_per_gas = 0 base_fee_per_gas = 0 # Polygon specific, TODO for other chains
return Block( return Block(
block_number=block_number, block_number=block_number,
@ -97,6 +110,15 @@ async def _fetch_block(
traces=traces, traces=traces,
receipts=receipts, receipts=receipts,
) )
except KeyError as e:
logger.warning(
f"Failed to create objects from block: {block_number}: {e}, retrying: {retries + 1} / 3"
)
if retries < 3:
await asyncio.sleep(5)
return await _fetch_block_geth(w3, base_provider, block_number, retries)
else:
raise
def _find_block( def _find_block(
@ -245,13 +267,6 @@ def unwrap_tx_trace_for_parity(
block_json, tx_pos_in_block, tx_trace, position=[] block_json, tx_pos_in_block, tx_trace, position=[]
) -> List[Trace]: ) -> List[Trace]:
response_list = [] response_list = []
_calltype_mapping = {
"CALL": "call",
"DELEGATECALL": "delegateCall",
"CREATE": "create",
"SUICIDE": "suicide",
"REWARD": "reward",
}
try: try:
if tx_trace["type"] == "STATICCALL": if tx_trace["type"] == "STATICCALL":
return [] return []
@ -279,7 +294,8 @@ def unwrap_tx_trace_for_parity(
type=TraceType(_calltype_mapping[tx_trace["type"]]), type=TraceType(_calltype_mapping[tx_trace["type"]]),
) )
) )
except Exception: except Exception as e:
logger.warn(f"error while unwraping tx trace for parity {e}")
return [] return []
if "calls" in tx_trace.keys(): if "calls" in tx_trace.keys():
@ -292,28 +308,20 @@ def unwrap_tx_trace_for_parity(
return response_list return response_list
async def geth_get_tx_receipts_task(session, endpoint_uri, tx): async def geth_get_tx_receipts_task(base_provider, tx):
data = { receipt = await base_provider.make_request("eth_getTransactionReceipt", [tx.hex()])
"jsonrpc": "2.0", return receipt
"id": "0",
"method": "eth_getTransactionReceipt",
"params": [tx.hex()],
}
async with session.post(endpoint_uri, json=data) as response:
if response.status != 200:
response.raise_for_status()
return await response.text()
async def geth_get_tx_receipts_async(endpoint_uri, transactions): async def geth_get_tx_receipts_async(base_provider, transactions):
geth_tx_receipts = [] geth_tx_receipts = []
async with aiohttp.ClientSession() as session: tasks = [
tasks = [ asyncio.create_task(geth_get_tx_receipts_task(base_provider, tx))
asyncio.create_task(geth_get_tx_receipts_task(session, endpoint_uri, tx)) for tx in transactions
for tx in transactions ]
] geth_tx_receipts = await asyncio.gather(*tasks)
geth_tx_receipts = await asyncio.gather(*tasks) # return [json.loads(tx_receipts) for tx_receipts in geth_tx_receipts]
return [json.loads(tx_receipts) for tx_receipts in geth_tx_receipts] return geth_tx_receipts
def geth_receipts_translator(block_json, geth_tx_receipts) -> List[Receipt]: def geth_receipts_translator(block_json, geth_tx_receipts) -> List[Receipt]:
@ -331,24 +339,18 @@ def geth_receipts_translator(block_json, geth_tx_receipts) -> List[Receipt]:
def unwrap_tx_receipt_for_parity(block_json, tx_pos_in_block, tx_receipt) -> Receipt: def unwrap_tx_receipt_for_parity(block_json, tx_pos_in_block, tx_receipt) -> Receipt:
try: if tx_pos_in_block != int(tx_receipt["transactionIndex"], 16):
if tx_pos_in_block != int(tx_receipt["transactionIndex"], 16): logger.info(
print( "Alert the position of transaction in block is mismatched ",
"Alert the position of transaction in block is mismatched ", tx_pos_in_block,
tx_pos_in_block, tx_receipt["transactionIndex"],
tx_receipt["transactionIndex"],
)
return Receipt(
block_number=block_json["number"],
transaction_hash=tx_receipt["transactionHash"],
transaction_index=tx_pos_in_block,
gas_used=tx_receipt["gasUsed"],
effective_gas_price=tx_receipt["effectiveGasPrice"],
cumulative_gas_used=tx_receipt["cumulativeGasUsed"],
to=tx_receipt["to"],
) )
return Receipt(
except Exception as e: block_number=block_json["number"],
print("error while decoding receipt", tx_receipt, e) transaction_hash=tx_receipt["transactionHash"],
transaction_index=tx_pos_in_block,
return Receipt() gas_used=tx_receipt["gasUsed"],
effective_gas_price=tx_receipt["effectiveGasPrice"],
cumulative_gas_used=tx_receipt["cumulativeGasUsed"],
to=tx_receipt["to"],
)

View File

@ -1,3 +1,7 @@
"""
Modified asynchronous geth_poa_middleware which mirrors functionality of
https://github.com/ethereum/web3.py/blob/master/web3/middleware/geth_poa.py
"""
from typing import ( from typing import (
Any, Any,
Callable, Callable,

View File

@ -34,6 +34,7 @@ from mev_inspect.miner_payments import get_miner_payments
from mev_inspect.swaps import get_swaps from mev_inspect.swaps import get_swaps
from mev_inspect.transfers import get_transfers from mev_inspect.transfers import get_transfers
from mev_inspect.liquidations import get_liquidations from mev_inspect.liquidations import get_liquidations
from mev_inspect.utils import RPCType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -43,7 +44,7 @@ async def inspect_block(
inspect_db_session: orm.Session, inspect_db_session: orm.Session,
base_provider, base_provider,
w3: Web3, w3: Web3,
geth: bool, type: RPCType,
trace_classifier: TraceClassifier, trace_classifier: TraceClassifier,
block_number: int, block_number: int,
trace_db_session: Optional[orm.Session], trace_db_session: Optional[orm.Session],
@ -52,7 +53,7 @@ async def inspect_block(
block = await create_from_block_number( block = await create_from_block_number(
base_provider, base_provider,
w3, w3,
geth, type,
block_number, block_number,
trace_db_session, trace_db_session,
) )

View File

@ -12,6 +12,7 @@ from mev_inspect.block import create_from_block_number
from mev_inspect.classifiers.trace import TraceClassifier from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.inspect_block import inspect_block from mev_inspect.inspect_block import inspect_block
from mev_inspect.provider import get_base_provider from mev_inspect.provider import get_base_provider
from mev_inspect.utils import RPCType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -22,25 +23,15 @@ class MEVInspector:
rpc: str, rpc: str,
inspect_db_session: orm.Session, inspect_db_session: orm.Session,
trace_db_session: Optional[orm.Session], trace_db_session: Optional[orm.Session],
geth: bool = False, type: RPCType = RPCType.parity,
max_concurrency: int = 1, max_concurrency: int = 1,
request_timeout: int = 300, request_timeout: int = 300,
): ):
self.inspect_db_session = inspect_db_session self.inspect_db_session = inspect_db_session
self.trace_db_session = trace_db_session self.trace_db_session = trace_db_session
self.base_provider = get_base_provider(rpc, request_timeout, geth) self.base_provider = get_base_provider(rpc, request_timeout, type)
self.geth = geth self.type = type
self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]) self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
# if geth:
# self.w3 = Web3(
# self.base_provider,
# modules={"eth": (AsyncEth,)},
# middlewares=[],
# )
# else:
# self.w3 = Web3(
# self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]
# )
self.trace_classifier = TraceClassifier() self.trace_classifier = TraceClassifier()
self.max_concurrency = asyncio.Semaphore(max_concurrency) self.max_concurrency = asyncio.Semaphore(max_concurrency)
@ -48,7 +39,7 @@ class MEVInspector:
return await create_from_block_number( return await create_from_block_number(
base_provider=self.base_provider, base_provider=self.base_provider,
w3=self.w3, w3=self.w3,
geth=self.geth, type=self.type,
block_number=block_number, block_number=block_number,
trace_db_session=self.trace_db_session, trace_db_session=self.trace_db_session,
) )
@ -58,7 +49,7 @@ class MEVInspector:
self.inspect_db_session, self.inspect_db_session,
self.base_provider, self.base_provider,
self.w3, self.w3,
self.geth, self.type,
self.trace_classifier, self.trace_classifier,
block, block,
trace_db_session=self.trace_db_session, trace_db_session=self.trace_db_session,
@ -87,7 +78,7 @@ class MEVInspector:
self.inspect_db_session, self.inspect_db_session,
self.base_provider, self.base_provider,
self.w3, self.w3,
self.geth, self.type,
self.trace_classifier, self.trace_classifier,
block_number, block_number,
trace_db_session=self.trace_db_session, trace_db_session=self.trace_db_session,

View File

@ -2,13 +2,14 @@ from web3 import Web3, AsyncHTTPProvider
from mev_inspect.retry import http_retry_with_backoff_request_middleware from mev_inspect.retry import http_retry_with_backoff_request_middleware
from mev_inspect.geth_poa_middleware import geth_poa_middleware from mev_inspect.geth_poa_middleware import geth_poa_middleware
from mev_inspect.utils import RPCType
def get_base_provider( def get_base_provider(
rpc: str, request_timeout: int = 500, geth: bool = False rpc: str, request_timeout: int = 500, type: RPCType = RPCType.parity
) -> Web3.AsyncHTTPProvider: ) -> Web3.AsyncHTTPProvider:
base_provider = AsyncHTTPProvider(rpc, request_kwargs={"timeout": request_timeout}) base_provider = AsyncHTTPProvider(rpc, request_kwargs={"timeout": request_timeout})
if geth: if type is RPCType.geth:
base_provider.middlewares += ( base_provider.middlewares += (
geth_poa_middleware, geth_poa_middleware,
http_retry_with_backoff_request_middleware, http_retry_with_backoff_request_middleware,

View File

@ -1,5 +1,11 @@
from enum import Enum
from hexbytes._utils import hexstr_to_bytes from hexbytes._utils import hexstr_to_bytes
class RPCType(Enum):
parity = 0
geth = 1
def hex_to_int(value: str) -> int: def hex_to_int(value: str) -> int:
return int.from_bytes(hexstr_to_bytes(value), byteorder="big") return int.from_bytes(hexstr_to_bytes(value), byteorder="big")