added RPCType, fixes

This commit is contained in:
Supragya Raj 2021-12-09 13:17:34 +01:00
parent f705bb9f2b
commit b31f5d750e
7 changed files with 142 additions and 119 deletions

34
cli.py
View File

@ -7,6 +7,7 @@ import click
from mev_inspect.concurrency import coro
from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspector import MEVInspector
from mev_inspect.utils import RPCType
RPC_URL_ENV = "RPC_URL"
@ -21,17 +22,29 @@ def cli():
@cli.command()
@click.argument("block_number", type=int)
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@click.option("--geth/--no-geth", default=False)
@click.option(
"--type",
type=click.Choice(list(map(lambda x: x.name, RPCType)), case_sensitive=False),
default=RPCType.parity.name,
)
@coro
async def inspect_block_command(block_number: int, rpc: str, geth: bool):
print("geth", geth)
async def inspect_block_command(block_number: int, rpc: str, type: str):
type_e = convert_str_to_enum(type)
inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session()
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session, geth)
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session, type_e)
await inspector.inspect_single_block(block=block_number)
def convert_str_to_enum(type: str) -> RPCType:
    """Translate the ``--type`` CLI string into its ``RPCType`` member.

    The lookup walks the enum itself instead of a hard-coded branch per
    member, so every name offered through the ``click.Choice`` list (which
    is built from ``RPCType``) is accepted. Matching ignores case because
    the option is declared with ``case_sensitive=False``.

    Args:
        type: Name of an ``RPCType`` member, e.g. ``"parity"`` or ``"geth"``.

    Returns:
        The ``RPCType`` member whose name matches ``type``.

    Raises:
        ValueError: If ``type`` does not name any ``RPCType`` member.
    """
    # Non-string input keeps raising ValueError (as before) rather than
    # failing with AttributeError on ``.lower()``.
    wanted = type.lower() if isinstance(type, str) else None
    for member in RPCType:
        if member.name.lower() == wanted:
            return member
    valid = ", ".join(member.name for member in RPCType)
    raise ValueError(f"Unknown RPC type {type!r}; expected one of: {valid}")
@cli.command()
@click.argument("block_number", type=int)
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@ -40,7 +53,7 @@ async def fetch_block_command(block_number: int, rpc: str):
inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session()
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session, False)
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session, RPCType.parity)
block = await inspector.create_from_block(block_number=block_number)
print(block.json())
@ -49,7 +62,11 @@ async def fetch_block_command(block_number: int, rpc: str):
@click.argument("after_block", type=int)
@click.argument("before_block", type=int)
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@click.option("--geth/--no-geth", default=False)
@click.option(
"--type",
type=click.Choice(list(map(lambda x: x.name, RPCType)), case_sensitive=False),
default=RPCType.parity.name,
)
@click.option(
"--max-concurrency",
type=int,
@ -66,15 +83,16 @@ async def inspect_many_blocks_command(
rpc: str,
max_concurrency: int,
request_timeout: int,
geth: bool,
type: str,
):
type_e = convert_str_to_enum(type)
inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session()
inspector = MEVInspector(
rpc,
inspect_db_session,
trace_db_session,
geth,
type_e,
max_concurrency=max_concurrency,
request_timeout=request_timeout,
)

View File

@ -1,8 +1,6 @@
import asyncio
import logging
from typing import List, Optional
import json
import aiohttp
from sqlalchemy import orm
from web3 import Web3
@ -11,10 +9,17 @@ from mev_inspect.fees import fetch_base_fee_per_gas
from mev_inspect.schemas.blocks import Block
from mev_inspect.schemas.receipts import Receipt
from mev_inspect.schemas.traces import Trace, TraceType
from mev_inspect.utils import hex_to_int
from mev_inspect.utils import RPCType, hex_to_int
logger = logging.getLogger(__name__)
_calltype_mapping = {
"CALL": "call",
"DELEGATECALL": "delegateCall",
"CREATE": "create",
"SUICIDE": "suicide",
"REWARD": "reward",
}
async def get_latest_block_number(base_provider) -> int:
@ -29,7 +34,7 @@ async def get_latest_block_number(base_provider) -> int:
async def create_from_block_number(
base_provider,
w3: Web3,
geth: bool,
type: RPCType,
block_number: int,
trace_db_session: Optional[orm.Session],
) -> Block:
@ -39,15 +44,19 @@ async def create_from_block_number(
block = _find_block(trace_db_session, block_number)
if block is None:
block = await _fetch_block(w3, base_provider, geth, block_number)
return block
if type is RPCType.parity:
block = await _fetch_block_parity(w3, base_provider, block_number)
elif type is RPCType.geth:
block = await _fetch_block_geth(w3, base_provider, block_number)
else:
logger.error(f"RPCType not known - {type}")
raise ValueError
return block
async def _fetch_block(
w3, base_provider, geth: bool, block_number: int, retries: int = 0
async def _fetch_block_parity(
w3, base_provider, block_number: int, retries: int = 0
) -> Block:
if not geth:
block_json, receipts_json, traces_json, base_fee_per_gas = await asyncio.gather(
w3.eth.get_block(block_number),
base_provider.make_request("eth_getBlockReceipts", [block_number]),
@ -74,20 +83,24 @@ async def _fetch_block(
)
if retries < 3:
await asyncio.sleep(5)
return await _fetch_block(
w3, base_provider, geth, block_number, retries
)
return await _fetch_block_parity(w3, base_provider, block_number, retries)
else:
raise
else:
# print(block_number)
async def _fetch_block_geth(
w3, base_provider, block_number: int, retries: int = 0
) -> Block:
block_json = await asyncio.gather(w3.eth.get_block(block_number))
try:
# Separate calls to help with load during block tracing
traces = await geth_get_tx_traces_parity_format(base_provider, block_json[0])
geth_tx_receipts = await geth_get_tx_receipts_async(
base_provider.endpoint_uri, block_json[0]["transactions"]
base_provider, block_json[0]["transactions"]
)
receipts = geth_receipts_translator(block_json[0], geth_tx_receipts)
base_fee_per_gas = 0
base_fee_per_gas = 0 # Polygon specific, TODO for other chains
return Block(
block_number=block_number,
@ -97,6 +110,15 @@ async def _fetch_block(
traces=traces,
receipts=receipts,
)
except KeyError as e:
logger.warning(
f"Failed to create objects from block: {block_number}: {e}, retrying: {retries + 1} / 3"
)
if retries < 3:
await asyncio.sleep(5)
return await _fetch_block_geth(w3, base_provider, block_number, retries)
else:
raise
def _find_block(
@ -245,13 +267,6 @@ def unwrap_tx_trace_for_parity(
block_json, tx_pos_in_block, tx_trace, position=[]
) -> List[Trace]:
response_list = []
_calltype_mapping = {
"CALL": "call",
"DELEGATECALL": "delegateCall",
"CREATE": "create",
"SUICIDE": "suicide",
"REWARD": "reward",
}
try:
if tx_trace["type"] == "STATICCALL":
return []
@ -279,7 +294,8 @@ def unwrap_tx_trace_for_parity(
type=TraceType(_calltype_mapping[tx_trace["type"]]),
)
)
except Exception:
except Exception as e:
logger.warn(f"error while unwraping tx trace for parity {e}")
return []
if "calls" in tx_trace.keys():
@ -292,28 +308,20 @@ def unwrap_tx_trace_for_parity(
return response_list
async def geth_get_tx_receipts_task(session, endpoint_uri, tx):
data = {
"jsonrpc": "2.0",
"id": "0",
"method": "eth_getTransactionReceipt",
"params": [tx.hex()],
}
async with session.post(endpoint_uri, json=data) as response:
if response.status != 200:
response.raise_for_status()
return await response.text()
async def geth_get_tx_receipts_task(base_provider, tx):
    """Request the receipt of a single transaction through ``base_provider``.

    Issues ``eth_getTransactionReceipt`` with ``tx.hex()`` as its only
    parameter and hands back whatever ``make_request`` resolves to,
    unmodified.
    """
    return await base_provider.make_request(
        "eth_getTransactionReceipt", [tx.hex()]
    )
async def geth_get_tx_receipts_async(endpoint_uri, transactions):
async def geth_get_tx_receipts_async(base_provider, transactions):
geth_tx_receipts = []
async with aiohttp.ClientSession() as session:
tasks = [
asyncio.create_task(geth_get_tx_receipts_task(session, endpoint_uri, tx))
asyncio.create_task(geth_get_tx_receipts_task(base_provider, tx))
for tx in transactions
]
geth_tx_receipts = await asyncio.gather(*tasks)
return [json.loads(tx_receipts) for tx_receipts in geth_tx_receipts]
# return [json.loads(tx_receipts) for tx_receipts in geth_tx_receipts]
return geth_tx_receipts
def geth_receipts_translator(block_json, geth_tx_receipts) -> List[Receipt]:
@ -331,9 +339,8 @@ def geth_receipts_translator(block_json, geth_tx_receipts) -> List[Receipt]:
def unwrap_tx_receipt_for_parity(block_json, tx_pos_in_block, tx_receipt) -> Receipt:
try:
if tx_pos_in_block != int(tx_receipt["transactionIndex"], 16):
print(
logger.info(
"Alert the position of transaction in block is mismatched ",
tx_pos_in_block,
tx_receipt["transactionIndex"],
@ -347,8 +354,3 @@ def unwrap_tx_receipt_for_parity(block_json, tx_pos_in_block, tx_receipt) -> Rec
cumulative_gas_used=tx_receipt["cumulativeGasUsed"],
to=tx_receipt["to"],
)
except Exception as e:
print("error while decoding receipt", tx_receipt, e)
return Receipt()

View File

@ -1,3 +1,7 @@
"""
Modified asynchronous geth_poa_middleware which mirrors functionality of
https://github.com/ethereum/web3.py/blob/master/web3/middleware/geth_poa.py
"""
from typing import (
Any,
Callable,

View File

@ -34,6 +34,7 @@ from mev_inspect.miner_payments import get_miner_payments
from mev_inspect.swaps import get_swaps
from mev_inspect.transfers import get_transfers
from mev_inspect.liquidations import get_liquidations
from mev_inspect.utils import RPCType
logger = logging.getLogger(__name__)
@ -43,7 +44,7 @@ async def inspect_block(
inspect_db_session: orm.Session,
base_provider,
w3: Web3,
geth: bool,
type: RPCType,
trace_classifier: TraceClassifier,
block_number: int,
trace_db_session: Optional[orm.Session],
@ -52,7 +53,7 @@ async def inspect_block(
block = await create_from_block_number(
base_provider,
w3,
geth,
type,
block_number,
trace_db_session,
)

View File

@ -12,6 +12,7 @@ from mev_inspect.block import create_from_block_number
from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.inspect_block import inspect_block
from mev_inspect.provider import get_base_provider
from mev_inspect.utils import RPCType
logger = logging.getLogger(__name__)
@ -22,25 +23,15 @@ class MEVInspector:
rpc: str,
inspect_db_session: orm.Session,
trace_db_session: Optional[orm.Session],
geth: bool = False,
type: RPCType = RPCType.parity,
max_concurrency: int = 1,
request_timeout: int = 300,
):
self.inspect_db_session = inspect_db_session
self.trace_db_session = trace_db_session
self.base_provider = get_base_provider(rpc, request_timeout, geth)
self.geth = geth
self.base_provider = get_base_provider(rpc, request_timeout, type)
self.type = type
self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
# if geth:
# self.w3 = Web3(
# self.base_provider,
# modules={"eth": (AsyncEth,)},
# middlewares=[],
# )
# else:
# self.w3 = Web3(
# self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]
# )
self.trace_classifier = TraceClassifier()
self.max_concurrency = asyncio.Semaphore(max_concurrency)
@ -48,7 +39,7 @@ class MEVInspector:
return await create_from_block_number(
base_provider=self.base_provider,
w3=self.w3,
geth=self.geth,
type=self.type,
block_number=block_number,
trace_db_session=self.trace_db_session,
)
@ -58,7 +49,7 @@ class MEVInspector:
self.inspect_db_session,
self.base_provider,
self.w3,
self.geth,
self.type,
self.trace_classifier,
block,
trace_db_session=self.trace_db_session,
@ -87,7 +78,7 @@ class MEVInspector:
self.inspect_db_session,
self.base_provider,
self.w3,
self.geth,
self.type,
self.trace_classifier,
block_number,
trace_db_session=self.trace_db_session,

View File

@ -2,13 +2,14 @@ from web3 import Web3, AsyncHTTPProvider
from mev_inspect.retry import http_retry_with_backoff_request_middleware
from mev_inspect.geth_poa_middleware import geth_poa_middleware
from mev_inspect.utils import RPCType
def get_base_provider(
rpc: str, request_timeout: int = 500, geth: bool = False
rpc: str, request_timeout: int = 500, type: RPCType = RPCType.parity
) -> Web3.AsyncHTTPProvider:
base_provider = AsyncHTTPProvider(rpc, request_kwargs={"timeout": request_timeout})
if geth:
if type is RPCType.geth:
base_provider.middlewares += (
geth_poa_middleware,
http_retry_with_backoff_request_middleware,

View File

@ -1,5 +1,11 @@
from enum import Enum
from hexbytes._utils import hexstr_to_bytes
class RPCType(Enum):
    """Kinds of RPC endpoint that callers can select between."""

    # Member names double as the accepted values of the CLI ``--type`` option.
    parity = 0
    geth = 1
def hex_to_int(value: str) -> int:
    """Convert a hex string to an integer, reading its bytes as big-endian."""
    raw_bytes = hexstr_to_bytes(value)
    return int.from_bytes(raw_bytes, byteorder="big")