updates for latest master

This commit is contained in:
Supragya Raj 2021-11-25 10:46:00 +01:00
parent 3fa8655e43
commit 8504ac5cca
4 changed files with 45 additions and 17 deletions

11
cli.py
View File

@ -3,8 +3,6 @@ import os
import sys import sys
import click import click
from web3 import Web3
from web3.middleware import geth_poa_middleware
from mev_inspect.concurrency import coro from mev_inspect.concurrency import coro
from mev_inspect.db import get_inspect_session, get_trace_session from mev_inspect.db import get_inspect_session, get_trace_session
@ -26,6 +24,7 @@ def cli():
@click.option("--geth/--no-geth", default=False) @click.option("--geth/--no-geth", default=False)
@coro @coro
async def inspect_block_command(block_number: int, rpc: str, geth: bool): async def inspect_block_command(block_number: int, rpc: str, geth: bool):
print("geth", geth)
inspect_db_session = get_inspect_session() inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session() trace_db_session = get_trace_session()
@ -41,7 +40,7 @@ async def fetch_block_command(block_number: int, rpc: str):
inspect_db_session = get_inspect_session() inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session() trace_db_session = get_trace_session()
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session, false) inspector = MEVInspector(rpc, inspect_db_session, trace_db_session, False)
block = await inspector.create_from_block(block_number=block_number) block = await inspector.create_from_block(block_number=block_number)
print(block.json()) print(block.json())
@ -51,7 +50,6 @@ async def fetch_block_command(block_number: int, rpc: str):
@click.argument("before_block", type=int) @click.argument("before_block", type=int)
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@click.option("--geth/--no-geth", default=False) @click.option("--geth/--no-geth", default=False)
@click.option( @click.option(
"--max-concurrency", "--max-concurrency",
type=int, type=int,
@ -68,7 +66,7 @@ async def inspect_many_blocks_command(
rpc: str, rpc: str,
max_concurrency: int, max_concurrency: int,
request_timeout: int, request_timeout: int,
geth: bool geth: bool,
): ):
inspect_db_session = get_inspect_session() inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session() trace_db_session = get_trace_session()
@ -76,14 +74,15 @@ async def inspect_many_blocks_command(
rpc, rpc,
inspect_db_session, inspect_db_session,
trace_db_session, trace_db_session,
geth,
max_concurrency=max_concurrency, max_concurrency=max_concurrency,
request_timeout=request_timeout, request_timeout=request_timeout,
geth
) )
await inspector.inspect_many_blocks( await inspector.inspect_many_blocks(
after_block=after_block, before_block=before_block after_block=after_block, before_block=before_block
) )
def get_rpc_url() -> str: def get_rpc_url() -> str:
return os.environ["RPC_URL"] return os.environ["RPC_URL"]

View File

@ -2,7 +2,6 @@ import asyncio
import logging import logging
from typing import List, Optional from typing import List, Optional
import json import json
import asyncio
import aiohttp import aiohttp
from sqlalchemy import orm from sqlalchemy import orm
@ -40,13 +39,14 @@ async def create_from_block_number(
block = _find_block(trace_db_session, block_number) block = _find_block(trace_db_session, block_number)
if block is None: if block is None:
block = await _fetch_block(w3, base_provider, block_number) block = await _fetch_block(w3, base_provider, geth, block_number)
return block return block
else:
return block return block
async def _fetch_block(w3, base_provider, geth, block_number: int, retries: int = 0) -> Block: async def _fetch_block(
w3, base_provider, geth: bool, block_number: int, retries: int = 0
) -> Block:
if not geth: if not geth:
block_json, receipts_json, traces_json, base_fee_per_gas = await asyncio.gather( block_json, receipts_json, traces_json, base_fee_per_gas = await asyncio.gather(
w3.eth.get_block(block_number), w3.eth.get_block(block_number),
@ -60,16 +60,28 @@ async def _fetch_block(w3, base_provider, geth, block_number: int, retries: int
Receipt(**receipt) for receipt in receipts_json["result"] Receipt(**receipt) for receipt in receipts_json["result"]
] ]
traces = [Trace(**trace_json) for trace_json in traces_json["result"]] traces = [Trace(**trace_json) for trace_json in traces_json["result"]]
return Block(
block_number=block_number,
block_timestamp=block_json["timestamp"],
miner=block_json["miner"],
base_fee_per_gas=base_fee_per_gas,
traces=traces,
receipts=receipts,
)
except KeyError as e: except KeyError as e:
logger.warning( logger.warning(
f"Failed to create objects from block: {block_number}: {e}, retrying: {retries + 1} / 3" f"Failed to create objects from block: {block_number}: {e}, retrying: {retries + 1} / 3"
) )
if retries < 3: if retries < 3:
await asyncio.sleep(5) await asyncio.sleep(5)
return await _fetch_block(w3, base_provider, block_number, retries) return await _fetch_block(
w3, base_provider, geth, block_number, retries
)
else: else:
raise raise
else: else:
block_json = await asyncio.gather(w3.eth.get_block(block_number))
print(block_json)
traces = geth_get_tx_traces_parity_format(base_provider, block_json) traces = geth_get_tx_traces_parity_format(base_provider, block_json)
geth_tx_receipts = geth_get_tx_receipts( geth_tx_receipts = geth_get_tx_receipts(
base_provider, block_json["transactions"] base_provider, block_json["transactions"]
@ -118,6 +130,7 @@ def _find_block(
receipts=receipts, receipts=receipts,
) )
def _find_block_timestamp( def _find_block_timestamp(
trace_db_session: orm.Session, trace_db_session: orm.Session,
block_number: int, block_number: int,
@ -203,10 +216,11 @@ def get_transaction_hashes(calls: List[Trace]) -> List[str]:
return result return result
# Geth specific additions # Geth specific additions
def geth_get_tx_traces_parity_format(base_provider, block_json): def geth_get_tx_traces_parity_format(base_provider, block_json: dict):
block_hash = block_json["hash"] block_hash = block_json["hash"]
block_trace = geth_get_tx_traces(base_provider, block_hash) block_trace = geth_get_tx_traces(base_provider, block_hash)
parity_traces = [] parity_traces = []

View File

@ -44,7 +44,7 @@ async def inspect_block(
base_provider, base_provider,
w3: Web3, w3: Web3,
geth: bool, geth: bool,
trace_clasifier: TraceClassifier, trace_classifier: TraceClassifier,
block_number: int, block_number: int,
trace_db_session: Optional[orm.Session], trace_db_session: Optional[orm.Session],
should_write_classified_traces: bool = True, should_write_classified_traces: bool = True,

View File

@ -7,6 +7,7 @@ from typing import Optional
from sqlalchemy import orm from sqlalchemy import orm
from web3 import Web3 from web3 import Web3
from web3.eth import AsyncEth from web3.eth import AsyncEth
from web3.middleware import geth_poa_middleware
from mev_inspect.block import create_from_block_number from mev_inspect.block import create_from_block_number
from mev_inspect.classifiers.trace import TraceClassifier from mev_inspect.classifiers.trace import TraceClassifier
@ -22,13 +23,24 @@ class MEVInspector:
rpc: str, rpc: str,
inspect_db_session: orm.Session, inspect_db_session: orm.Session,
trace_db_session: Optional[orm.Session], trace_db_session: Optional[orm.Session],
geth: bool = False,
max_concurrency: int = 1, max_concurrency: int = 1,
request_timeout: int = 300, request_timeout: int = 300,
): ):
self.inspect_db_session = inspect_db_session self.inspect_db_session = inspect_db_session
self.trace_db_session = trace_db_session self.trace_db_session = trace_db_session
self.base_provider = get_base_provider(rpc, request_timeout=request_timeout) self.base_provider = get_base_provider(rpc, request_timeout=request_timeout)
self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]) self.geth = geth
if geth:
self.w3 = Web3(
self.base_provider,
modules={"eth": (AsyncEth,)},
middlewares=[geth_poa_middleware],
)
else:
self.w3 = Web3(
self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]
)
self.trace_classifier = TraceClassifier() self.trace_classifier = TraceClassifier()
self.max_concurrency = asyncio.Semaphore(max_concurrency) self.max_concurrency = asyncio.Semaphore(max_concurrency)
@ -36,6 +48,7 @@ class MEVInspector:
return await create_from_block_number( return await create_from_block_number(
base_provider=self.base_provider, base_provider=self.base_provider,
w3=self.w3, w3=self.w3,
geth=self.geth,
block_number=block_number, block_number=block_number,
trace_db_session=self.trace_db_session, trace_db_session=self.trace_db_session,
) )
@ -45,6 +58,7 @@ class MEVInspector:
self.inspect_db_session, self.inspect_db_session,
self.base_provider, self.base_provider,
self.w3, self.w3,
self.geth,
self.trace_classifier, self.trace_classifier,
block, block,
trace_db_session=self.trace_db_session, trace_db_session=self.trace_db_session,
@ -73,6 +87,7 @@ class MEVInspector:
self.inspect_db_session, self.inspect_db_session,
self.base_provider, self.base_provider,
self.w3, self.w3,
self.geth,
self.trace_classifier, self.trace_classifier,
block_number, block_number,
trace_db_session=self.trace_db_session, trace_db_session=self.trace_db_session,