asyncio-based concurrent backfilling

commit 4f20c540e6 (parent a5e4a2d1d4)
Author: carlomazzaferro
Date: 2021-10-20 17:12:21 +01:00
7 changed files with 117 additions and 55 deletions

Tiltfile

@@ -36,3 +36,9 @@ docker_build_with_restart("mev-inspect-py", ".",
 )
 k8s_yaml(helm('./k8s/mev-inspect', name='mev-inspect'))
 k8s_resource(workload="mev-inspect", resource_deps=["postgresql-postgresql"])
+local_resource(
+    'pg-port-forward',
+    serve_cmd='kubectl port-forward --namespace default svc/postgresql 5432:5432',
+    resource_deps=["postgresql-postgresql"]
+)
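
The new `local_resource` keeps a `kubectl port-forward` to the in-cluster Postgres running for the lifetime of the Tilt session, so local tools can reach the database on localhost:5432. A minimal connectivity check, assuming psycopg2 is installed and using placeholder credentials (substitute the values from your helm release):

    import psycopg2

    # Hypothetical credentials -- replace with the values for your release.
    conn = psycopg2.connect(
        host="127.0.0.1",
        port=5432,
        user="postgres",
        password="password",
        dbname="mev_inspect",
    )
    print(conn.server_version)  # any response proves the forward is up
    conn.close()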

cli.py

@@ -1,43 +1,62 @@
-import os
+import asyncio
 import logging
+import os
 import sys
+from functools import wraps

 import click
 from web3 import Web3
+from web3.eth import AsyncEth

 from mev_inspect.classifiers.trace import TraceClassifier
 from mev_inspect.db import get_inspect_session, get_trace_session
 from mev_inspect.inspect_block import inspect_block
 from mev_inspect.provider import get_base_provider
+from mev_inspect.retry import http_retry_with_backoff_request_middleware

 RPC_URL_ENV = "RPC_URL"

 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 logger = logging.getLogger(__name__)

+semaphore: asyncio.Semaphore


 @click.group()
 def cli():
     pass


+def coro(f):
+    @wraps(f)
+    def wrapper(*args, **kwargs):
+        loop = asyncio.get_event_loop()
+        try:
+            loop.run_until_complete(f(*args, **kwargs))
+        finally:
+            loop.run_until_complete(loop.shutdown_asyncgens())
+            loop.close()
+
+    return wrapper


 @cli.command()
 @click.argument("block_number", type=int)
 @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
 @click.option("--cache/--no-cache", default=True)
-def inspect_block_command(block_number: int, rpc: str, cache: bool):
+@coro
+async def inspect_block_command(block_number: int, rpc: str, cache: bool):
     inspect_db_session = get_inspect_session()
     trace_db_session = get_trace_session()

     base_provider = get_base_provider(rpc)
-    w3 = Web3(base_provider)
+    w3 = Web3(base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
     trace_classifier = TraceClassifier()

     if not cache:
         logger.info("Skipping cache")

-    inspect_block(
+    await inspect_block(
         inspect_db_session,
         base_provider,
         w3,
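
The `coro` wrapper above is the usual bridge between click's synchronous command dispatch and a coroutine body: it runs the decorated coroutine to completion on the event loop and tears the loop down afterwards. A self-contained sketch of the same pattern (the `hello` command is illustrative only):

    import asyncio
    from functools import wraps

    import click


    def coro(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            loop = asyncio.get_event_loop()
            try:
                loop.run_until_complete(f(*args, **kwargs))
            finally:
                loop.run_until_complete(loop.shutdown_asyncgens())
                loop.close()

        return wrapper


    @click.command()
    @coro
    async def hello():
        await asyncio.sleep(0.1)  # stand-in for real async work
        click.echo("done")
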
@@ -52,30 +71,70 @@ def inspect_block_command(block_number: int, rpc: str, cache: bool):
 @click.argument("before_block", type=int)
 @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
 @click.option("--cache/--no-cache", default=True)
-def inspect_many_blocks_command(
-    after_block: int, before_block: int, rpc: str, cache: bool
+@click.option(
+    "--max-concurrency",
+    type=int,
+    help="maximum number of concurrent connections",
+    default=5,
+)
+@click.option(
+    "--request-timeout", type=int, help="timeout for requests to nodes", default=500
+)
+@coro
+async def inspect_many_blocks_command(
+    after_block: int,
+    before_block: int,
+    rpc: str,
+    cache: bool,
+    max_concurrency: int,
+    request_timeout: int,
 ):
+    global semaphore  # pylint: disable=global-statement
+    semaphore = asyncio.Semaphore(max_concurrency)
     inspect_db_session = get_inspect_session()
     trace_db_session = get_trace_session()

-    base_provider = get_base_provider(rpc)
-    w3 = Web3(base_provider)
+    base_provider = get_base_provider(rpc, request_timeout=request_timeout)
+    w3 = Web3(
+        base_provider,
+        modules={"eth": (AsyncEth,)},
+        middlewares=[http_retry_with_backoff_request_middleware],
+    )
     trace_classifier = TraceClassifier()

     if not cache:
         logger.info("Skipping cache")

-    for i, block_number in enumerate(range(after_block, before_block)):
-        block_message = (
-            f"Running for {block_number} ({i+1}/{before_block - after_block})"
-        )
-        dashes = "-" * len(block_message)
-        logger.info(dashes)
-        logger.info(block_message)
-        logger.info(dashes)
+    tasks = []

-        inspect_block(
+    for block_number in range(after_block, before_block):
+        tasks.append(
+            asyncio.ensure_future(
+                safe_inspect_block(
+                    inspect_db_session,
+                    base_provider,
+                    w3,
+                    trace_classifier,
+                    block_number,
+                    trace_db_session,
+                )
+            )
+        )
+    logger.info(f"Gathered {len(tasks)} blocks to inspect")
+    await asyncio.gather(*tasks)


+async def safe_inspect_block(
+    inspect_db_session,
+    base_provider,
+    w3,
+    trace_classifier,
+    block_number,
+    trace_db_session,
+):
+    async with semaphore:
+        return await inspect_block(
             inspect_db_session,
             base_provider,
             w3,
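
`asyncio.gather` starts every task at once, so the module-level semaphore is what actually caps in-flight RPC requests at `--max-concurrency`; tasks beyond the limit suspend at `async with semaphore` without holding a connection. The pattern reduced to a runnable sketch (block range and sleep are stand-ins):

    import asyncio


    async def bounded_inspect(semaphore: asyncio.Semaphore, block_number: int) -> int:
        # Only max_concurrency coroutines get past this line at a time.
        async with semaphore:
            await asyncio.sleep(0.01)  # stand-in for inspect_block's RPC work
            return block_number


    async def main() -> None:
        semaphore = asyncio.Semaphore(5)  # mirrors the --max-concurrency default
        tasks = [bounded_inspect(semaphore, n) for n in range(10_000, 10_100)]
        results = await asyncio.gather(*tasks)
        print(f"inspected {len(results)} blocks")


    asyncio.run(main())

The commit shares one semaphore through a module-level global instead of passing it down, which works here because each command creates it before any task starts.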

mev_inspect/block.py

@@ -16,7 +16,7 @@ def get_latest_block_number(w3: Web3) -> int:
     return int(w3.eth.get_block("latest")["number"])


-def create_from_block_number(
+async def create_from_block_number(
     base_provider,
     w3: Web3,
     block_number: int,
@@ -28,25 +28,27 @@ def create_from_block_number(
     block = _find_block(trace_db_session, block_number)

     if block is None:
-        return _fetch_block(w3, base_provider, block_number)
+        block = await _fetch_block(w3, base_provider, block_number)
+        return block
     else:
         return block


-def _fetch_block(
+async def _fetch_block(
     w3,
     base_provider,
     block_number: int,
 ) -> Block:
-    block_json = w3.eth.get_block(block_number)
-    receipts_json = base_provider.make_request("eth_getBlockReceipts", [block_number])
-    traces_json = w3.parity.trace_block(block_number)
+    block_json = await w3.eth.get_block(block_number)
+    receipts_json = await base_provider.make_request(
+        "eth_getBlockReceipts", [block_number]
+    )
+    traces_json = await base_provider.make_request("trace_block", [block_number])

     receipts: List[Receipt] = [
         Receipt(**receipt) for receipt in receipts_json["result"]
     ]
-    traces = [Trace(**trace_json) for trace_json in traces_json]
-    base_fee_per_gas = fetch_base_fee_per_gas(w3, block_number)
+    traces = [Trace(**trace_json) for trace_json in traces_json["result"]]
+    base_fee_per_gas = await fetch_base_fee_per_gas(w3, block_number)

     return Block(
         block_number=block_number,
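
`_fetch_block` awaits its three RPC calls one after another; since the block, receipts, and traces requests are independent, they could also be issued concurrently. A hedged sketch of that variation (not what the commit does, same call shapes as the diff):

    import asyncio


    async def _fetch_block_concurrently(w3, base_provider, block_number: int):
        # Let the three independent requests overlap on the wire.
        block_json, receipts_json, traces_json = await asyncio.gather(
            w3.eth.get_block(block_number),
            base_provider.make_request("eth_getBlockReceipts", [block_number]),
            base_provider.make_request("trace_block", [block_number]),
        )
        return block_json, receipts_json, traces_json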

mev_inspect/fees.py

@@ -1,9 +1,10 @@
 from web3 import Web3


-def fetch_base_fee_per_gas(w3: Web3, block_number: int) -> int:
-    base_fees = w3.eth.fee_history(1, block_number)["baseFeePerGas"]
-    if len(base_fees) == 0:
+async def fetch_base_fee_per_gas(w3: Web3, block_number: int) -> int:
+    base_fees = await w3.eth.fee_history(1, block_number)
+    base_fees_per_gas = base_fees["baseFeePerGas"]
+    if len(base_fees_per_gas) == 0:
         raise RuntimeError("Unexpected error - no fees returned")

-    return base_fees[0]
+    return base_fees_per_gas[0]
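
`eth_feeHistory` with a block count of 1 and `block_number` as the newest block returns the base fee of that block plus the computed base fee of the block after it, so index 0 is the requested block's fee. The response shape, roughly (values illustrative, not from a real block):

    # Approximate shape of: await w3.eth.fee_history(1, block_number)
    #
    # {
    #     "oldestBlock": block_number,
    #     "baseFeePerGas": [base_fee_of_block, base_fee_of_next_block],
    #     "gasUsedRatio": [0.53],
    # }
    #
    # base_fees_per_gas[0] is therefore the base fee of block_number itself.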

mev_inspect/inspect_block.py

@@ -35,7 +35,7 @@ from mev_inspect.liquidations import get_liquidations
 logger = logging.getLogger(__name__)


-def inspect_block(
+async def inspect_block(
     inspect_db_session: orm.Session,
     base_provider,
     w3: Web3,
@@ -44,47 +44,49 @@ def inspect_block(
     trace_db_session: Optional[orm.Session],
     should_write_classified_traces: bool = True,
 ):
-    block = create_from_block_number(
+    block = await create_from_block_number(
         base_provider,
         w3,
         block_number,
         trace_db_session,
     )

-    logger.info(f"Total traces: {len(block.traces)}")
+    logger.info(f"Block: {block_number} -- Total traces: {len(block.traces)}")

     total_transactions = len(
         set(t.transaction_hash for t in block.traces if t.transaction_hash is not None)
     )
-    logger.info(f"Total transactions: {total_transactions}")
+    logger.info(f"Block: {block_number} -- Total transactions: {total_transactions}")

     classified_traces = trace_clasifier.classify(block.traces)
-    logger.info(f"Returned {len(classified_traces)} classified traces")
+    logger.info(
+        f"Block: {block_number} -- Returned {len(classified_traces)} classified traces"
+    )

     if should_write_classified_traces:
         delete_classified_traces_for_block(inspect_db_session, block_number)
         write_classified_traces(inspect_db_session, classified_traces)

     transfers = get_transfers(classified_traces)
-    logger.info(f"Found {len(transfers)} transfers")
+    logger.info(f"Block: {block_number} -- Found {len(transfers)} transfers")

     delete_transfers_for_block(inspect_db_session, block_number)
     write_transfers(inspect_db_session, transfers)

     swaps = get_swaps(classified_traces)
-    logger.info(f"Found {len(swaps)} swaps")
+    logger.info(f"Block: {block_number} -- Found {len(swaps)} swaps")

     delete_swaps_for_block(inspect_db_session, block_number)
     write_swaps(inspect_db_session, swaps)

     arbitrages = get_arbitrages(swaps)
-    logger.info(f"Found {len(arbitrages)} arbitrages")
+    logger.info(f"Block: {block_number} -- Found {len(arbitrages)} arbitrages")

     delete_arbitrages_for_block(inspect_db_session, block_number)
     write_arbitrages(inspect_db_session, arbitrages)

     liquidations = get_liquidations(classified_traces)
-    logger.info(f"Found {len(liquidations)} liquidations")
+    logger.info(f"Block: {block_number} -- Found {len(liquidations)} liquidations")

     delete_liquidations_for_block(inspect_db_session, block_number)
     write_liquidations(inspect_db_session, liquidations)
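
With many blocks in flight, log lines from different tasks interleave, hence the `Block: {block_number} --` prefix on every message. One way to avoid repeating the prefix by hand is a `logging.LoggerAdapter`; a sketch of that alternative, not what the commit does:

    import logging

    logger = logging.getLogger(__name__)


    class BlockLoggerAdapter(logging.LoggerAdapter):
        def process(self, msg, kwargs):
            # Prepend the block number carried in `extra` to every message.
            return f"Block: {self.extra['block_number']} -- {msg}", kwargs


    block_logger = BlockLoggerAdapter(logger, {"block_number": 13_449_600})
    block_logger.info("Found 12 swaps")  # logs "Block: 13449600 -- Found 12 swaps"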

mev_inspect/provider.py

@@ -1,14 +1,6 @@
-from web3 import Web3
-from mev_inspect.retry import http_retry_with_backoff_request_middleware
+from web3 import Web3, AsyncHTTPProvider


-def get_base_provider(rpc: str) -> Web3.HTTPProvider:
-    base_provider = Web3.HTTPProvider(rpc)
-    base_provider.middlewares.remove("http_retry_request")
-    base_provider.middlewares.add(
-        http_retry_with_backoff_request_middleware,
-        "http_retry_with_backoff",
-    )
+def get_base_provider(rpc: str, request_timeout: int = 500) -> Web3.AsyncHTTPProvider:
+    base_provider = AsyncHTTPProvider(rpc, request_kwargs={"timeout": request_timeout})
     return base_provider
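
`AsyncHTTPProvider.make_request` is a coroutine, which pairs with the `AsyncEth` module passed to `Web3` in cli.py. A minimal usage sketch, assuming a web3 version whose `AsyncEth` supports `get_block`, with a placeholder RPC endpoint:

    import asyncio

    from web3 import Web3, AsyncHTTPProvider
    from web3.eth import AsyncEth


    async def main() -> None:
        provider = AsyncHTTPProvider(
            "http://localhost:8545",  # placeholder endpoint
            request_kwargs={"timeout": 500},  # matches the new default
        )
        w3 = Web3(provider, modules={"eth": (AsyncEth,)}, middlewares=[])
        latest = await w3.eth.get_block("latest")
        print(latest["number"])


    asyncio.run(main())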

mev_inspect/retry.py

@@ -20,7 +20,7 @@ from web3.types import (
 )


-def exception_retry_with_backoff_middleware(
+async def exception_retry_with_backoff_middleware(
     make_request: Callable[[RPCEndpoint, Any], RPCResponse],
     web3: Web3,  # pylint: disable=unused-argument
     errors: Collection[Type[BaseException]],
@@ -51,9 +51,9 @@ def exception_retry_with_backoff_middleware(
     return middleware


-def http_retry_with_backoff_request_middleware(
+async def http_retry_with_backoff_request_middleware(
     make_request: Callable[[RPCEndpoint, Any], Any], web3: Web3
 ) -> Callable[[RPCEndpoint, Any], Any]:
-    return exception_retry_with_backoff_middleware(
+    return await exception_retry_with_backoff_middleware(
         make_request, web3, (ConnectionError, HTTPError, Timeout, TooManyRedirects)
     )
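
In web3's async stack a middleware factory is itself awaited: it receives `make_request` and `web3` and returns an async `middleware(method, params)` callable. A stripped-down sketch of the retry-with-backoff shape (retry count and delays are illustrative; the error tuple matches the diff):

    import asyncio

    from requests.exceptions import ConnectionError, HTTPError, Timeout, TooManyRedirects


    async def simple_retry_middleware(make_request, web3):
        async def middleware(method, params):
            for attempt in range(5):
                try:
                    return await make_request(method, params)
                except (ConnectionError, HTTPError, Timeout, TooManyRedirects):
                    if attempt == 4:
                        raise
                    await asyncio.sleep(2**attempt)  # 1s, 2s, 4s, 8s backoff

        return middleware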