asyncio-based concurrent backfilling

This commit is contained in:
carlomazzaferro 2021-10-20 17:12:21 +01:00
parent a5e4a2d1d4
commit 4f20c540e6
No known key found for this signature in database
GPG Key ID: 0CED3103EF7B2187
7 changed files with 117 additions and 55 deletions

View File

@ -36,3 +36,9 @@ docker_build_with_restart("mev-inspect-py", ".",
) )
k8s_yaml(helm('./k8s/mev-inspect', name='mev-inspect')) k8s_yaml(helm('./k8s/mev-inspect', name='mev-inspect'))
k8s_resource(workload="mev-inspect", resource_deps=["postgresql-postgresql"]) k8s_resource(workload="mev-inspect", resource_deps=["postgresql-postgresql"])
local_resource(
'pg-port-forward',
serve_cmd='kubectl port-forward --namespace default svc/postgresql 5432:5432',
resource_deps=["postgresql-postgresql"]
)

97
cli.py
View File

@ -1,43 +1,62 @@
import os import asyncio
import logging import logging
import os
import sys import sys
from functools import wraps
import click import click
from web3 import Web3 from web3 import Web3
from web3.eth import AsyncEth
from mev_inspect.classifiers.trace import TraceClassifier from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.db import get_inspect_session, get_trace_session from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspect_block import inspect_block from mev_inspect.inspect_block import inspect_block
from mev_inspect.provider import get_base_provider from mev_inspect.provider import get_base_provider
from mev_inspect.retry import http_retry_with_backoff_request_middleware
RPC_URL_ENV = "RPC_URL" RPC_URL_ENV = "RPC_URL"
logging.basicConfig(stream=sys.stdout, level=logging.INFO) logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
semaphore: asyncio.Semaphore
@click.group() @click.group()
def cli(): def cli():
pass pass
def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
return wrapper
@cli.command() @cli.command()
@click.argument("block_number", type=int) @click.argument("block_number", type=int)
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@click.option("--cache/--no-cache", default=True) @click.option("--cache/--no-cache", default=True)
def inspect_block_command(block_number: int, rpc: str, cache: bool): @coro
async def inspect_block_command(block_number: int, rpc: str, cache: bool):
inspect_db_session = get_inspect_session() inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session() trace_db_session = get_trace_session()
base_provider = get_base_provider(rpc) base_provider = get_base_provider(rpc)
w3 = Web3(base_provider) w3 = Web3(base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
trace_classifier = TraceClassifier() trace_classifier = TraceClassifier()
if not cache: if not cache:
logger.info("Skipping cache") logger.info("Skipping cache")
inspect_block( await inspect_block(
inspect_db_session, inspect_db_session,
base_provider, base_provider,
w3, w3,
@ -52,30 +71,70 @@ def inspect_block_command(block_number: int, rpc: str, cache: bool):
@click.argument("before_block", type=int) @click.argument("before_block", type=int)
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@click.option("--cache/--no-cache", default=True) @click.option("--cache/--no-cache", default=True)
def inspect_many_blocks_command( @click.option(
after_block: int, before_block: int, rpc: str, cache: bool "--max-concurrency",
type=int,
help="maximum number of concurrent connections",
default=5,
)
@click.option(
"--request-timeout", type=int, help="timeout for requests to nodes", default=500
)
@coro
async def inspect_many_blocks_command(
after_block: int,
before_block: int,
rpc: str,
cache: bool,
max_concurrency: int,
request_timeout: int,
): ):
global semaphore # pylint: disable=global-statement
semaphore = asyncio.Semaphore(max_concurrency)
inspect_db_session = get_inspect_session() inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session() trace_db_session = get_trace_session()
base_provider = get_base_provider(rpc) base_provider = get_base_provider(rpc, request_timeout=request_timeout)
w3 = Web3(base_provider) w3 = Web3(
base_provider,
modules={"eth": (AsyncEth,)},
middlewares=[http_retry_with_backoff_request_middleware],
)
trace_classifier = TraceClassifier() trace_classifier = TraceClassifier()
if not cache: if not cache:
logger.info("Skipping cache") logger.info("Skipping cache")
for i, block_number in enumerate(range(after_block, before_block)): tasks = []
block_message = (
f"Running for {block_number} ({i+1}/{before_block - after_block})"
)
dashes = "-" * len(block_message)
logger.info(dashes)
logger.info(block_message)
logger.info(dashes)
inspect_block( for block_number in range(after_block, before_block):
tasks.append(
asyncio.ensure_future(
safe_inspect_block(
inspect_db_session,
base_provider,
w3,
trace_classifier,
block_number,
trace_db_session,
)
)
)
logger.info(f"Gathered {len(tasks)} blocks to inspect")
await asyncio.gather(*tasks)
async def safe_inspect_block(
inspect_db_session,
base_provider,
w3,
trace_classifier,
block_number,
trace_db_session,
):
async with semaphore:
return await inspect_block(
inspect_db_session, inspect_db_session,
base_provider, base_provider,
w3, w3,

View File

@ -16,7 +16,7 @@ def get_latest_block_number(w3: Web3) -> int:
return int(w3.eth.get_block("latest")["number"]) return int(w3.eth.get_block("latest")["number"])
def create_from_block_number( async def create_from_block_number(
base_provider, base_provider,
w3: Web3, w3: Web3,
block_number: int, block_number: int,
@ -28,25 +28,27 @@ def create_from_block_number(
block = _find_block(trace_db_session, block_number) block = _find_block(trace_db_session, block_number)
if block is None: if block is None:
return _fetch_block(w3, base_provider, block_number) block = await _fetch_block(w3, base_provider, block_number)
return block
else: else:
return block return block
def _fetch_block( async def _fetch_block(
w3, w3,
base_provider, base_provider,
block_number: int, block_number: int,
) -> Block: ) -> Block:
block_json = w3.eth.get_block(block_number) block_json = await w3.eth.get_block(block_number)
receipts_json = base_provider.make_request("eth_getBlockReceipts", [block_number]) receipts_json = await base_provider.make_request(
traces_json = w3.parity.trace_block(block_number) "eth_getBlockReceipts", [block_number]
)
traces_json = await base_provider.make_request("trace_block", [block_number])
receipts: List[Receipt] = [ receipts: List[Receipt] = [
Receipt(**receipt) for receipt in receipts_json["result"] Receipt(**receipt) for receipt in receipts_json["result"]
] ]
traces = [Trace(**trace_json) for trace_json in traces_json] traces = [Trace(**trace_json) for trace_json in traces_json["result"]]
base_fee_per_gas = fetch_base_fee_per_gas(w3, block_number) base_fee_per_gas = await fetch_base_fee_per_gas(w3, block_number)
return Block( return Block(
block_number=block_number, block_number=block_number,

View File

@ -1,9 +1,10 @@
from web3 import Web3 from web3 import Web3
def fetch_base_fee_per_gas(w3: Web3, block_number: int) -> int: async def fetch_base_fee_per_gas(w3: Web3, block_number: int) -> int:
base_fees = w3.eth.fee_history(1, block_number)["baseFeePerGas"] base_fees = await w3.eth.fee_history(1, block_number)
if len(base_fees) == 0: base_fees_per_gas = base_fees["baseFeePerGas"]
if len(base_fees_per_gas) == 0:
raise RuntimeError("Unexpected error - no fees returned") raise RuntimeError("Unexpected error - no fees returned")
return base_fees[0] return base_fees_per_gas[0]

View File

@ -35,7 +35,7 @@ from mev_inspect.liquidations import get_liquidations
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def inspect_block( async def inspect_block(
inspect_db_session: orm.Session, inspect_db_session: orm.Session,
base_provider, base_provider,
w3: Web3, w3: Web3,
@ -44,47 +44,49 @@ def inspect_block(
trace_db_session: Optional[orm.Session], trace_db_session: Optional[orm.Session],
should_write_classified_traces: bool = True, should_write_classified_traces: bool = True,
): ):
block = create_from_block_number( block = await create_from_block_number(
base_provider, base_provider,
w3, w3,
block_number, block_number,
trace_db_session, trace_db_session,
) )
logger.info(f"Total traces: {len(block.traces)}") logger.info(f"Block: {block_number} -- Total traces: {len(block.traces)}")
total_transactions = len( total_transactions = len(
set(t.transaction_hash for t in block.traces if t.transaction_hash is not None) set(t.transaction_hash for t in block.traces if t.transaction_hash is not None)
) )
logger.info(f"Total transactions: {total_transactions}") logger.info(f"Block: {block_number} -- Total transactions: {total_transactions}")
classified_traces = trace_clasifier.classify(block.traces) classified_traces = trace_clasifier.classify(block.traces)
logger.info(f"Returned {len(classified_traces)} classified traces") logger.info(
f"Block: {block_number} -- Returned {len(classified_traces)} classified traces"
)
if should_write_classified_traces: if should_write_classified_traces:
delete_classified_traces_for_block(inspect_db_session, block_number) delete_classified_traces_for_block(inspect_db_session, block_number)
write_classified_traces(inspect_db_session, classified_traces) write_classified_traces(inspect_db_session, classified_traces)
transfers = get_transfers(classified_traces) transfers = get_transfers(classified_traces)
logger.info(f"Found {len(transfers)} transfers") logger.info(f"Block: {block_number} -- Found {len(transfers)} transfers")
delete_transfers_for_block(inspect_db_session, block_number) delete_transfers_for_block(inspect_db_session, block_number)
write_transfers(inspect_db_session, transfers) write_transfers(inspect_db_session, transfers)
swaps = get_swaps(classified_traces) swaps = get_swaps(classified_traces)
logger.info(f"Found {len(swaps)} swaps") logger.info(f"Block: {block_number} -- Found {len(swaps)} swaps")
delete_swaps_for_block(inspect_db_session, block_number) delete_swaps_for_block(inspect_db_session, block_number)
write_swaps(inspect_db_session, swaps) write_swaps(inspect_db_session, swaps)
arbitrages = get_arbitrages(swaps) arbitrages = get_arbitrages(swaps)
logger.info(f"Found {len(arbitrages)} arbitrages") logger.info(f"Block: {block_number} -- Found {len(arbitrages)} arbitrages")
delete_arbitrages_for_block(inspect_db_session, block_number) delete_arbitrages_for_block(inspect_db_session, block_number)
write_arbitrages(inspect_db_session, arbitrages) write_arbitrages(inspect_db_session, arbitrages)
liquidations = get_liquidations(classified_traces) liquidations = get_liquidations(classified_traces)
logger.info(f"Found {len(liquidations)} liquidations") logger.info(f"Block: {block_number} -- Found {len(liquidations)} liquidations")
delete_liquidations_for_block(inspect_db_session, block_number) delete_liquidations_for_block(inspect_db_session, block_number)
write_liquidations(inspect_db_session, liquidations) write_liquidations(inspect_db_session, liquidations)

View File

@ -1,14 +1,6 @@
from web3 import Web3 from web3 import Web3, AsyncHTTPProvider
from mev_inspect.retry import http_retry_with_backoff_request_middleware
def get_base_provider(rpc: str) -> Web3.HTTPProvider: def get_base_provider(rpc: str, request_timeout: int = 500) -> Web3.AsyncHTTPProvider:
base_provider = Web3.HTTPProvider(rpc) base_provider = AsyncHTTPProvider(rpc, request_kwargs={"timeout": request_timeout})
base_provider.middlewares.remove("http_retry_request")
base_provider.middlewares.add(
http_retry_with_backoff_request_middleware,
"http_retry_with_backoff",
)
return base_provider return base_provider

View File

@ -20,7 +20,7 @@ from web3.types import (
) )
def exception_retry_with_backoff_middleware( async def exception_retry_with_backoff_middleware(
make_request: Callable[[RPCEndpoint, Any], RPCResponse], make_request: Callable[[RPCEndpoint, Any], RPCResponse],
web3: Web3, # pylint: disable=unused-argument web3: Web3, # pylint: disable=unused-argument
errors: Collection[Type[BaseException]], errors: Collection[Type[BaseException]],
@ -51,9 +51,9 @@ def exception_retry_with_backoff_middleware(
return middleware return middleware
def http_retry_with_backoff_request_middleware( async def http_retry_with_backoff_request_middleware(
make_request: Callable[[RPCEndpoint, Any], Any], web3: Web3 make_request: Callable[[RPCEndpoint, Any], Any], web3: Web3
) -> Callable[[RPCEndpoint, Any], Any]: ) -> Callable[[RPCEndpoint, Any], Any]:
return exception_retry_with_backoff_middleware( return await exception_retry_with_backoff_middleware(
make_request, web3, (ConnectionError, HTTPError, Timeout, TooManyRedirects) make_request, web3, (ConnectionError, HTTPError, Timeout, TooManyRedirects)
) )