Merge pull request #255 from flashbots/split-out-workers-from-task

Separate importing tasks from importing the worker
This commit is contained in:
Luke Van Seters 2022-02-08 13:06:37 -05:00 committed by GitHub
commit 3965c5f7ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 131 additions and 79 deletions

10
cli.py
View File

@ -4,12 +4,15 @@ import sys
from datetime import datetime from datetime import datetime
import click import click
import dramatiq
from mev_inspect.concurrency import coro from mev_inspect.concurrency import coro
from mev_inspect.crud.prices import write_prices from mev_inspect.crud.prices import write_prices
from mev_inspect.db import get_inspect_session, get_trace_session from mev_inspect.db import get_inspect_session, get_trace_session
from mev_inspect.inspector import MEVInspector from mev_inspect.inspector import MEVInspector
from mev_inspect.prices import fetch_prices, fetch_prices_range from mev_inspect.prices import fetch_prices, fetch_prices_range
from mev_inspect.queue.broker import connect_broker
from mev_inspect.queue.tasks import inspect_many_blocks_task
RPC_URL_ENV = "RPC_URL" RPC_URL_ENV = "RPC_URL"
@ -97,14 +100,13 @@ async def inspect_many_blocks_command(
@click.argument("before_block", type=int) @click.argument("before_block", type=int)
@click.argument("batch_size", type=int, default=10) @click.argument("batch_size", type=int, default=10)
def enqueue_many_blocks_command(after_block: int, before_block: int, batch_size: int): def enqueue_many_blocks_command(after_block: int, before_block: int, batch_size: int):
from worker import ( # pylint: disable=import-outside-toplevel broker = connect_broker()
inspect_many_blocks_task, inspect_many_blocks_actor = dramatiq.actor(inspect_many_blocks_task, broker=broker)
)
for batch_after_block in range(after_block, before_block, batch_size): for batch_after_block in range(after_block, before_block, batch_size):
batch_before_block = min(batch_after_block + batch_size, before_block) batch_before_block = min(batch_after_block + batch_size, before_block)
logger.info(f"Sending {batch_after_block} to {batch_before_block}") logger.info(f"Sending {batch_after_block} to {batch_before_block}")
inspect_many_blocks_task.send(batch_after_block, batch_before_block) inspect_many_blocks_actor.send(batch_after_block, batch_before_block)
@cli.command() @cli.command()

View File

View File

@ -0,0 +1,7 @@
import os
from dramatiq.brokers.redis import RedisBroker
def connect_broker():
return RedisBroker(host="redis-master", password=os.environ["REDIS_PASSWORD"])

View File

@ -0,0 +1,75 @@
import asyncio
import logging
from threading import local
from dramatiq.middleware import Middleware
from mev_inspect.db import get_inspect_sessionmaker, get_trace_sessionmaker
from mev_inspect.inspector import MEVInspector
logger = logging.getLogger(__name__)
class DbMiddleware(Middleware):
STATE = local()
INSPECT_SESSION_STATE_KEY = "InspectSession"
TRACE_SESSION_STATE_KEY = "TraceSession"
@classmethod
def get_inspect_sessionmaker(cls):
return getattr(cls.STATE, cls.INSPECT_SESSION_STATE_KEY, None)
@classmethod
def get_trace_sessionmaker(cls):
return getattr(cls.STATE, cls.TRACE_SESSION_STATE_KEY, None)
def before_process_message(self, _broker, message):
if not hasattr(self.STATE, self.INSPECT_SESSION_STATE_KEY):
logger.info("Building sessionmakers")
setattr(
self.STATE, self.INSPECT_SESSION_STATE_KEY, get_inspect_sessionmaker()
)
setattr(self.STATE, self.TRACE_SESSION_STATE_KEY, get_trace_sessionmaker())
else:
logger.info("Sessionmakers already set")
class InspectorMiddleware(Middleware):
STATE = local()
INSPECT_STATE_KEY = "inspector"
def __init__(self, rpc_url):
self._rpc_url = rpc_url
@classmethod
def get_inspector(cls):
return getattr(cls.STATE, cls.INSPECT_STATE_KEY, None)
def before_process_message(
self, _broker, worker
): # pylint: disable=unused-argument
if not hasattr(self.STATE, self.INSPECT_STATE_KEY):
logger.info("Building inspector")
inspector = MEVInspector(
self._rpc_url,
max_concurrency=5,
request_timeout=300,
)
setattr(self.STATE, self.INSPECT_STATE_KEY, inspector)
else:
logger.info("Inspector already exists")
class AsyncMiddleware(Middleware):
def before_process_message(
self, _broker, message
): # pylint: disable=unused-argument
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
def after_process_message(
self, _broker, message, *, result=None, exception=None
): # pylint: disable=unused-argument
if hasattr(self, "loop"):
self.loop.close()

View File

@ -0,0 +1,32 @@
import asyncio
import logging
from contextlib import contextmanager
from .middleware import DbMiddleware, InspectorMiddleware
logger = logging.getLogger(__name__)
def inspect_many_blocks_task(
after_block: int,
before_block: int,
):
with _session_scope(DbMiddleware.get_inspect_sessionmaker()) as inspect_db_session:
with _session_scope(DbMiddleware.get_trace_sessionmaker()) as trace_db_session:
asyncio.run(
InspectorMiddleware.get_inspector().inspect_many_blocks(
inspect_db_session=inspect_db_session,
trace_db_session=trace_db_session,
after_block=after_block,
before_block=before_block,
)
)
@contextmanager
def _session_scope(Session=None):
if Session is None:
yield None
else:
with Session() as session:
yield session

View File

@ -1,87 +1,23 @@
import asyncio
import logging import logging
import os import os
import sys import sys
import threading
from contextlib import contextmanager
import dramatiq import dramatiq
from dramatiq.brokers.redis import RedisBroker
from dramatiq.cli import main as dramatiq_worker
from dramatiq.middleware import Middleware
from mev_inspect.db import get_inspect_sessionmaker, get_trace_sessionmaker from mev_inspect.queue.broker import connect_broker
from mev_inspect.inspector import MEVInspector from mev_inspect.queue.middleware import (
AsyncMiddleware,
DbMiddleware,
InspectorMiddleware,
)
from mev_inspect.queue.tasks import inspect_many_blocks_task
InspectSession = get_inspect_sessionmaker()
TraceSession = get_trace_sessionmaker()
thread_local = threading.local()
logging.basicConfig(stream=sys.stdout, level=logging.INFO) logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
broker = connect_broker()
class AsyncMiddleware(Middleware): broker.add_middleware(DbMiddleware())
def before_process_message(
self, _broker, message
): # pylint: disable=unused-argument
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
def after_process_message(
self, _broker, message, *, result=None, exception=None
): # pylint: disable=unused-argument
self.loop.close()
class InspectorMiddleware(Middleware):
def before_process_message(
self, _broker, worker
): # pylint: disable=unused-argument
rpc = os.environ["RPC_URL"]
if not hasattr(thread_local, "inspector"):
logger.info("Building inspector")
thread_local.inspector = MEVInspector(
rpc,
max_concurrency=5,
request_timeout=300,
)
else:
logger.info("Inspector already exists")
broker = RedisBroker(host="redis-master", password=os.environ["REDIS_PASSWORD"])
broker.add_middleware(AsyncMiddleware()) broker.add_middleware(AsyncMiddleware())
broker.add_middleware(InspectorMiddleware()) broker.add_middleware(InspectorMiddleware(os.environ["RPC_URL"]))
dramatiq.set_broker(broker) dramatiq.set_broker(broker)
dramatiq.actor(inspect_many_blocks_task)
@contextmanager
def session_scope(Session=None):
if Session is None:
yield None
else:
with Session() as session:
yield session
@dramatiq.actor
def inspect_many_blocks_task(
after_block: int,
before_block: int,
):
with session_scope(InspectSession) as inspect_db_session:
with session_scope(TraceSession) as trace_db_session:
asyncio.run(
thread_local.inspector.inspect_many_blocks(
inspect_db_session=inspect_db_session,
trace_db_session=trace_db_session,
after_block=after_block,
before_block=before_block,
)
)
if __name__ == "__main__":
dramatiq_worker(processes=1, threads=1)