Merge pull request #123 from flashbots/listener-async

Support asyncio in listener
This commit is contained in:
Luke Van Seters 2021-11-12 19:02:33 -05:00 committed by GitHub
commit ec49c03484
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 68 additions and 52 deletions

41
cli.py
View File

@ -1,10 +1,9 @@
import asyncio
import os import os
import signal
from functools import wraps
import click import click
from mev_inspect.concurrency import coro
from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspector import MEVInspector from mev_inspect.inspector import MEVInspector
RPC_URL_ENV = "RPC_URL" RPC_URL_ENV = "RPC_URL"
@ -15,31 +14,15 @@ def cli():
pass pass
def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
def cancel_task_callback():
for task in asyncio.all_tasks():
task.cancel()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, cancel_task_callback)
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
return wrapper
@cli.command() @cli.command()
@click.argument("block_number", type=int) @click.argument("block_number", type=int)
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@coro @coro
async def inspect_block_command(block_number: int, rpc: str): async def inspect_block_command(block_number: int, rpc: str):
inspector = MEVInspector(rpc=rpc) inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session()
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session)
await inspector.inspect_single_block(block=block_number) await inspector.inspect_single_block(block=block_number)
@ -48,7 +31,10 @@ async def inspect_block_command(block_number: int, rpc: str):
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@coro @coro
async def fetch_block_command(block_number: int, rpc: str): async def fetch_block_command(block_number: int, rpc: str):
inspector = MEVInspector(rpc=rpc) inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session()
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session)
block = await inspector.create_from_block(block_number=block_number) block = await inspector.create_from_block(block_number=block_number)
print(block.json()) print(block.json())
@ -74,8 +60,13 @@ async def inspect_many_blocks_command(
max_concurrency: int, max_concurrency: int,
request_timeout: int, request_timeout: int,
): ):
inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session()
inspector = MEVInspector( inspector = MEVInspector(
rpc=rpc, rpc,
inspect_db_session,
trace_db_session,
max_concurrency=max_concurrency, max_concurrency=max_concurrency,
request_timeout=request_timeout, request_timeout=request_timeout,
) )

View File

@ -1,17 +1,15 @@
import asyncio
import logging import logging
import os import os
import time
from web3 import Web3
from mev_inspect.block import get_latest_block_number from mev_inspect.block import get_latest_block_number
from mev_inspect.concurrency import coro
from mev_inspect.crud.latest_block_update import ( from mev_inspect.crud.latest_block_update import (
find_latest_block_update, find_latest_block_update,
update_latest_block, update_latest_block,
) )
from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.db import get_inspect_session, get_trace_session from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspect_block import inspect_block from mev_inspect.inspector import MEVInspector
from mev_inspect.provider import get_base_provider from mev_inspect.provider import get_base_provider
from mev_inspect.signal_handler import GracefulKiller from mev_inspect.signal_handler import GracefulKiller
@ -23,7 +21,8 @@ logger = logging.getLogger(__name__)
BLOCK_NUMBER_LAG = 5 BLOCK_NUMBER_LAG = 5
def run(): @coro
async def run():
rpc = os.getenv("RPC_URL") rpc = os.getenv("RPC_URL")
if rpc is None: if rpc is None:
raise RuntimeError("Missing environment variable RPC_URL") raise RuntimeError("Missing environment variable RPC_URL")
@ -34,21 +33,23 @@ def run():
inspect_db_session = get_inspect_session() inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session() trace_db_session = get_trace_session()
trace_classifier = TraceClassifier()
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session)
base_provider = get_base_provider(rpc) base_provider = get_base_provider(rpc)
w3 = Web3(base_provider)
latest_block_number = get_latest_block_number(w3) latest_block_number = await get_latest_block_number(base_provider)
while not killer.kill_now: while not killer.kill_now:
last_written_block = find_latest_block_update(inspect_db_session) last_written_block = find_latest_block_update(inspect_db_session)
logger.info(f"Latest block: {latest_block_number}") logger.info(f"Latest block: {latest_block_number}")
logger.info(f"Last written block: {last_written_block}") logger.info(f"Last written block: {last_written_block}")
if (last_written_block is None) or ( if last_written_block is None:
last_written_block < (latest_block_number - BLOCK_NUMBER_LAG) # maintain lag if no blocks written yet
): last_written_block = latest_block_number - 1
if last_written_block < (latest_block_number - BLOCK_NUMBER_LAG):
block_number = ( block_number = (
latest_block_number latest_block_number
if last_written_block is None if last_written_block is None
@ -57,18 +58,11 @@ def run():
logger.info(f"Writing block: {block_number}") logger.info(f"Writing block: {block_number}")
inspect_block( await inspector.inspect_single_block(block=block_number)
inspect_db_session,
base_provider,
w3,
trace_classifier,
block_number,
trace_db_session=trace_db_session,
)
update_latest_block(inspect_db_session, block_number) update_latest_block(inspect_db_session, block_number)
else: else:
time.sleep(5) await asyncio.sleep(5)
latest_block_number = get_latest_block_number(w3) latest_block_number = await get_latest_block_number(base_provider)
logger.info("Stopping...") logger.info("Stopping...")

View File

@ -11,6 +11,7 @@ from mev_inspect.fees import fetch_base_fee_per_gas
from mev_inspect.schemas.blocks import Block from mev_inspect.schemas.blocks import Block
from mev_inspect.schemas.receipts import Receipt from mev_inspect.schemas.receipts import Receipt
from mev_inspect.schemas.traces import Trace, TraceType from mev_inspect.schemas.traces import Trace, TraceType
from mev_inspect.utils import hex_to_int
cache_directory = "./cache" cache_directory = "./cache"
@ -18,8 +19,13 @@ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_latest_block_number(w3: Web3) -> int: async def get_latest_block_number(base_provider) -> int:
return int(w3.eth.get_block("latest")["number"]) latest_block = await base_provider.make_request(
"eth_getBlockByNumber",
["latest", False],
)
return hex_to_int(latest_block["result"]["number"])
async def create_from_block_number( async def create_from_block_number(

View File

@ -0,0 +1,22 @@
import asyncio
import signal
from functools import wraps
def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
def cancel_task_callback():
for task in asyncio.all_tasks():
task.cancel()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, cancel_task_callback)
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
return wrapper

View File

@ -3,13 +3,14 @@ import logging
import sys import sys
import traceback import traceback
from asyncio import CancelledError from asyncio import CancelledError
from typing import Optional
from sqlalchemy import orm
from web3 import Web3 from web3 import Web3
from web3.eth import AsyncEth from web3.eth import AsyncEth
from mev_inspect.block import create_from_block_number from mev_inspect.block import create_from_block_number
from mev_inspect.classifiers.trace import TraceClassifier from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspect_block import inspect_block from mev_inspect.inspect_block import inspect_block
from mev_inspect.provider import get_base_provider from mev_inspect.provider import get_base_provider
@ -21,11 +22,13 @@ class MEVInspector:
def __init__( def __init__(
self, self,
rpc: str, rpc: str,
inspect_db_session: orm.Session,
trace_db_session: Optional[orm.Session],
max_concurrency: int = 1, max_concurrency: int = 1,
request_timeout: int = 300, request_timeout: int = 300,
): ):
self.inspect_db_session = get_inspect_session() self.inspect_db_session = inspect_db_session
self.trace_db_session = get_trace_session() self.trace_db_session = trace_db_session
self.base_provider = get_base_provider(rpc, request_timeout=request_timeout) self.base_provider = get_base_provider(rpc, request_timeout=request_timeout)
self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]) self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
self.trace_classifier = TraceClassifier() self.trace_classifier = TraceClassifier()