Merge pull request #123 from flashbots/listener-async

Support asyncio in listener
This commit is contained in:
Luke Van Seters 2021-11-12 19:02:33 -05:00 committed by GitHub
commit ec49c03484
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 68 additions and 52 deletions

41
cli.py
View File

@ -1,10 +1,9 @@
import asyncio
import os
import signal
from functools import wraps
import click
from mev_inspect.concurrency import coro
from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspector import MEVInspector
RPC_URL_ENV = "RPC_URL"
@ -15,31 +14,15 @@ def cli():
pass
def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
def cancel_task_callback():
for task in asyncio.all_tasks():
task.cancel()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, cancel_task_callback)
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
return wrapper
@cli.command()
@click.argument("block_number", type=int)
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@coro
async def inspect_block_command(block_number: int, rpc: str):
inspector = MEVInspector(rpc=rpc)
inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session()
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session)
await inspector.inspect_single_block(block=block_number)
@ -48,7 +31,10 @@ async def inspect_block_command(block_number: int, rpc: str):
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
@coro
async def fetch_block_command(block_number: int, rpc: str):
inspector = MEVInspector(rpc=rpc)
inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session()
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session)
block = await inspector.create_from_block(block_number=block_number)
print(block.json())
@ -74,8 +60,13 @@ async def inspect_many_blocks_command(
max_concurrency: int,
request_timeout: int,
):
inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session()
inspector = MEVInspector(
rpc=rpc,
rpc,
inspect_db_session,
trace_db_session,
max_concurrency=max_concurrency,
request_timeout=request_timeout,
)

View File

@ -1,17 +1,15 @@
import asyncio
import logging
import os
import time
from web3 import Web3
from mev_inspect.block import get_latest_block_number
from mev_inspect.concurrency import coro
from mev_inspect.crud.latest_block_update import (
find_latest_block_update,
update_latest_block,
)
from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspect_block import inspect_block
from mev_inspect.inspector import MEVInspector
from mev_inspect.provider import get_base_provider
from mev_inspect.signal_handler import GracefulKiller
@ -23,7 +21,8 @@ logger = logging.getLogger(__name__)
BLOCK_NUMBER_LAG = 5
def run():
@coro
async def run():
rpc = os.getenv("RPC_URL")
if rpc is None:
raise RuntimeError("Missing environment variable RPC_URL")
@ -34,21 +33,23 @@ def run():
inspect_db_session = get_inspect_session()
trace_db_session = get_trace_session()
trace_classifier = TraceClassifier()
inspector = MEVInspector(rpc, inspect_db_session, trace_db_session)
base_provider = get_base_provider(rpc)
w3 = Web3(base_provider)
latest_block_number = get_latest_block_number(w3)
latest_block_number = await get_latest_block_number(base_provider)
while not killer.kill_now:
last_written_block = find_latest_block_update(inspect_db_session)
logger.info(f"Latest block: {latest_block_number}")
logger.info(f"Last written block: {last_written_block}")
if (last_written_block is None) or (
last_written_block < (latest_block_number - BLOCK_NUMBER_LAG)
):
if last_written_block is None:
# maintain lag if no blocks written yet
last_written_block = latest_block_number - 1
if last_written_block < (latest_block_number - BLOCK_NUMBER_LAG):
block_number = (
latest_block_number
if last_written_block is None
@ -57,18 +58,11 @@ def run():
logger.info(f"Writing block: {block_number}")
inspect_block(
inspect_db_session,
base_provider,
w3,
trace_classifier,
block_number,
trace_db_session=trace_db_session,
)
await inspector.inspect_single_block(block=block_number)
update_latest_block(inspect_db_session, block_number)
else:
time.sleep(5)
latest_block_number = get_latest_block_number(w3)
await asyncio.sleep(5)
latest_block_number = await get_latest_block_number(base_provider)
logger.info("Stopping...")

View File

@ -11,6 +11,7 @@ from mev_inspect.fees import fetch_base_fee_per_gas
from mev_inspect.schemas.blocks import Block
from mev_inspect.schemas.receipts import Receipt
from mev_inspect.schemas.traces import Trace, TraceType
from mev_inspect.utils import hex_to_int
cache_directory = "./cache"
@ -18,8 +19,13 @@ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
def get_latest_block_number(w3: Web3) -> int:
return int(w3.eth.get_block("latest")["number"])
async def get_latest_block_number(base_provider) -> int:
latest_block = await base_provider.make_request(
"eth_getBlockByNumber",
["latest", False],
)
return hex_to_int(latest_block["result"]["number"])
async def create_from_block_number(

View File

@ -0,0 +1,22 @@
import asyncio
import signal
from functools import wraps
def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
def cancel_task_callback():
for task in asyncio.all_tasks():
task.cancel()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, cancel_task_callback)
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
return wrapper

View File

@ -3,13 +3,14 @@ import logging
import sys
import traceback
from asyncio import CancelledError
from typing import Optional
from sqlalchemy import orm
from web3 import Web3
from web3.eth import AsyncEth
from mev_inspect.block import create_from_block_number
from mev_inspect.classifiers.trace import TraceClassifier
from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspect_block import inspect_block
from mev_inspect.provider import get_base_provider
@ -21,11 +22,13 @@ class MEVInspector:
def __init__(
self,
rpc: str,
inspect_db_session: orm.Session,
trace_db_session: Optional[orm.Session],
max_concurrency: int = 1,
request_timeout: int = 300,
):
self.inspect_db_session = get_inspect_session()
self.trace_db_session = get_trace_session()
self.inspect_db_session = inspect_db_session
self.trace_db_session = trace_db_session
self.base_provider = get_base_provider(rpc, request_timeout=request_timeout)
self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
self.trace_classifier = TraceClassifier()