2022-02-02 13:16:36 -05:00

76 lines
2.3 KiB
Python

import asyncio
import logging
from threading import local
from dramatiq.middleware import Middleware
from mev_inspect.db import get_inspect_sessionmaker, get_trace_sessionmaker
from mev_inspect.inspector import MEVInspector
# Module-level logger used by every middleware in this file; records are
# emitted under this module's dotted name.
logger = logging.getLogger(name=__name__)
class DbMiddleware(Middleware):
    """Dramatiq middleware that lazily builds per-thread DB sessionmakers.

    The sessionmakers live in a class-level ``threading.local``. Each worker
    thread builds its own pair on the first message it processes and reuses
    it for later messages. Actors read them through the classmethod accessors.
    """

    # Class-level thread-local storage: an attribute set on it is visible only
    # to the thread that set it.
    STATE = local()
    INSPECT_SESSION_STATE_KEY = "InspectSession"
    TRACE_SESSION_STATE_KEY = "TraceSession"

    @classmethod
    def get_inspect_sessionmaker(cls):
        """Return this thread's inspect-DB sessionmaker, or None if not built yet."""
        return getattr(cls.STATE, cls.INSPECT_SESSION_STATE_KEY, None)

    @classmethod
    def get_trace_sessionmaker(cls):
        """Return this thread's trace-DB sessionmaker, or None if not built yet."""
        return getattr(cls.STATE, cls.TRACE_SESSION_STATE_KEY, None)

    def before_process_message(
        self, _broker, message
    ):  # pylint: disable=unused-argument
        """Build this thread's sessionmakers if it does not have them yet.

        Both sessionmakers are constructed before either one is stored. If a
        factory raises, the thread therefore keeps neither, and the next
        message retries the whole build. It is never treated as "already set"
        while the trace sessionmaker is missing.
        """
        # The inspect key doubles as the "initialized" marker for this thread.
        if hasattr(self.STATE, self.INSPECT_SESSION_STATE_KEY):
            logger.info("Sessionmakers already set")
            return

        logger.info("Building sessionmakers")
        # Inside a method body, these bare names resolve to the module-level
        # imports from mev_inspect.db, not to the same-named classmethods above
        # (class scope is skipped during name lookup).
        inspect_sessionmaker = get_inspect_sessionmaker()
        trace_sessionmaker = get_trace_sessionmaker()

        setattr(self.STATE, self.TRACE_SESSION_STATE_KEY, trace_sessionmaker)
        # Store the marker key last so it is only present once both are ready.
        setattr(self.STATE, self.INSPECT_SESSION_STATE_KEY, inspect_sessionmaker)
class InspectorMiddleware(Middleware):
    """Dramatiq middleware that lazily builds one ``MEVInspector`` per worker thread.

    The inspector is kept in a class-level ``threading.local``. A thread builds
    it on the first message it processes and reuses it afterwards. Actors
    fetch it with ``InspectorMiddleware.get_inspector()``.
    """

    # Class-level thread-local storage: each thread sees only what it set.
    STATE = local()
    INSPECT_STATE_KEY = "inspector"

    def __init__(self, rpc_url, max_concurrency=5, request_timeout=300):
        """Record the inspector settings; the inspector itself is built lazily.

        Args:
            rpc_url: RPC endpoint URL passed to ``MEVInspector``.
            max_concurrency: forwarded unchanged to ``MEVInspector`` as
                ``max_concurrency``. Defaults to 5, the previously fixed value.
            request_timeout: forwarded unchanged to ``MEVInspector`` as
                ``request_timeout``. Defaults to 300, the previously fixed value.
        """
        super().__init__()
        self._rpc_url = rpc_url
        self._max_concurrency = max_concurrency
        self._request_timeout = request_timeout

    @classmethod
    def get_inspector(cls):
        """Return the calling thread's inspector, or None if it has not been built."""
        return getattr(cls.STATE, cls.INSPECT_STATE_KEY, None)

    def before_process_message(
        self, _broker, message
    ):  # pylint: disable=unused-argument
        """Build this thread's ``MEVInspector`` on the first message it handles.

        Dramatiq invokes this hook with ``(broker, message)``; neither argument
        is used here.
        """
        if hasattr(self.STATE, self.INSPECT_STATE_KEY):
            logger.info("Inspector already exists")
            return

        logger.info("Building inspector")
        inspector = MEVInspector(
            self._rpc_url,
            max_concurrency=self._max_concurrency,
            request_timeout=self._request_timeout,
        )
        setattr(self.STATE, self.INSPECT_STATE_KEY, inspector)
class AsyncMiddleware(Middleware):
    """Dramatiq middleware giving each message a fresh asyncio event loop.

    A new loop is created and installed as the calling thread's current event
    loop before each message. It is closed once the message is processed or
    skipped.

    Dramatiq shares one middleware instance across all worker threads, so the
    loop is kept in class-level thread-local storage (like the other
    middlewares in this module) rather than on ``self``. Storing it on
    ``self`` let concurrent threads overwrite each other's loop: a thread
    could close a loop another thread was still running, and the overwritten
    loop was never closed.
    """

    # Class-level thread-local storage: each thread sees only its own loop.
    STATE = local()
    LOOP_STATE_KEY = "loop"

    @property
    def loop(self):
        """The event loop created for the calling thread's current message, or None."""
        return getattr(self.STATE, self.LOOP_STATE_KEY, None)

    def before_process_message(
        self, _broker, message
    ):  # pylint: disable=unused-argument
        """Create a new event loop and make it the calling thread's current loop."""
        # Defensive: if a previous loop on this thread was never cleaned up,
        # close it instead of leaking it.
        self._close_loop()
        new_loop = asyncio.new_event_loop()
        setattr(self.STATE, self.LOOP_STATE_KEY, new_loop)
        asyncio.set_event_loop(new_loop)

    def after_process_message(
        self, _broker, message, *, result=None, exception=None
    ):  # pylint: disable=unused-argument
        """Close this thread's loop after the message succeeded or failed."""
        self._close_loop()

    def after_skip_message(self, _broker, message):  # pylint: disable=unused-argument
        """Close this thread's loop when the message was skipped.

        Dramatiq emits ``after_skip_message`` instead of ``after_process_message``
        for skipped messages. Without this hook, the loop created in
        ``before_process_message`` would never be closed.
        """
        self._close_loop()

    def _close_loop(self):
        """Close and forget this thread's loop, if any, and clear the current loop."""
        current = getattr(self.STATE, self.LOOP_STATE_KEY, None)
        if current is None:
            return
        # Forget the loop first so a failing close() cannot leave a stale entry.
        delattr(self.STATE, self.LOOP_STATE_KEY)
        current.close()
        # Do not leave a closed loop installed as this thread's current loop.
        asyncio.set_event_loop(None)