Add trace DB session
This commit is contained in:
parent
2935df284d
commit
1993f0a14d
@ -1,9 +1,11 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/ambv/black
|
- repo: local
|
||||||
rev: 20.8b1
|
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: system
|
||||||
language_version: python3.9
|
name: Black
|
||||||
|
entry: poetry run black .
|
||||||
|
pass_filenames: false
|
||||||
|
language: system
|
||||||
- repo: local
|
- repo: local
|
||||||
hooks:
|
hooks:
|
||||||
- id: pylint
|
- id: pylint
|
||||||
|
@ -5,12 +5,12 @@ from sqlalchemy import pool
|
|||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
|
|
||||||
from mev_inspect.db import get_sqlalchemy_database_uri
|
from mev_inspect.db import get_inspect_database_uri
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
# access to the values within the .ini file in use.
|
# access to the values within the .ini file in use.
|
||||||
config = context.config
|
config = context.config
|
||||||
config.set_main_option("sqlalchemy.url", get_sqlalchemy_database_uri())
|
config.set_main_option("sqlalchemy.url", get_inspect_database_uri())
|
||||||
|
|
||||||
# Interpret the config file for Python logging.
|
# Interpret the config file for Python logging.
|
||||||
# This line sets up loggers basically.
|
# This line sets up loggers basically.
|
||||||
|
6
cli.py
6
cli.py
@ -6,7 +6,7 @@ import click
|
|||||||
from web3 import Web3
|
from web3 import Web3
|
||||||
|
|
||||||
from mev_inspect.classifiers.trace import TraceClassifier
|
from mev_inspect.classifiers.trace import TraceClassifier
|
||||||
from mev_inspect.db import get_session
|
from mev_inspect.db import get_inspect_session
|
||||||
from mev_inspect.inspect_block import inspect_block
|
from mev_inspect.inspect_block import inspect_block
|
||||||
from mev_inspect.provider import get_base_provider
|
from mev_inspect.provider import get_base_provider
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ def cli():
|
|||||||
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
|
@click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, ""))
|
||||||
@click.option("--cache/--no-cache", default=True)
|
@click.option("--cache/--no-cache", default=True)
|
||||||
def inspect_block_command(block_number: int, rpc: str, cache: bool):
|
def inspect_block_command(block_number: int, rpc: str, cache: bool):
|
||||||
db_session = get_session()
|
db_session = get_inspect_session()
|
||||||
base_provider = get_base_provider(rpc)
|
base_provider = get_base_provider(rpc)
|
||||||
w3 = Web3(base_provider)
|
w3 = Web3(base_provider)
|
||||||
trace_classifier = TraceClassifier()
|
trace_classifier = TraceClassifier()
|
||||||
@ -54,7 +54,7 @@ def inspect_many_blocks_command(
|
|||||||
after_block: int, before_block: int, rpc: str, cache: bool
|
after_block: int, before_block: int, rpc: str, cache: bool
|
||||||
):
|
):
|
||||||
|
|
||||||
db_session = get_session()
|
db_session = get_inspect_session()
|
||||||
base_provider = get_base_provider(rpc)
|
base_provider = get_base_provider(rpc)
|
||||||
w3 = Web3(base_provider)
|
w3 = Web3(base_provider)
|
||||||
trace_classifier = TraceClassifier()
|
trace_classifier = TraceClassifier()
|
||||||
|
@ -9,7 +9,7 @@ from mev_inspect.crud.latest_block_update import (
|
|||||||
find_latest_block_update,
|
find_latest_block_update,
|
||||||
update_latest_block,
|
update_latest_block,
|
||||||
)
|
)
|
||||||
from mev_inspect.db import get_session
|
from mev_inspect.db import get_inspect_session
|
||||||
from mev_inspect.inspect_block import inspect_block
|
from mev_inspect.inspect_block import inspect_block
|
||||||
from mev_inspect.provider import get_base_provider
|
from mev_inspect.provider import get_base_provider
|
||||||
from mev_inspect.signal_handler import GracefulKiller
|
from mev_inspect.signal_handler import GracefulKiller
|
||||||
@ -32,7 +32,7 @@ def run():
|
|||||||
|
|
||||||
killer = GracefulKiller()
|
killer = GracefulKiller()
|
||||||
|
|
||||||
db_session = get_session()
|
db_session = get_inspect_session()
|
||||||
base_provider = get_base_provider(rpc)
|
base_provider = get_base_provider(rpc)
|
||||||
w3 = Web3(base_provider)
|
w3 = Web3(base_provider)
|
||||||
|
|
||||||
|
@ -1,10 +1,23 @@
|
|||||||
import os
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine, orm
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
|
||||||
def get_sqlalchemy_database_uri():
|
def get_trace_database_uri() -> Optional[str]:
|
||||||
|
username = os.getenv("TRACE_DB_USER")
|
||||||
|
password = os.getenv("TRACE_DB_PASSWORD")
|
||||||
|
host = os.getenv("TRACE_DB_HOST")
|
||||||
|
db_name = "trace_db"
|
||||||
|
|
||||||
|
if all(field is not None for field in [username, password, host]):
|
||||||
|
return f"postgresql://{username}:{password}@{host}/{db_name}"
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_inspect_database_uri():
|
||||||
username = os.getenv("POSTGRES_USER")
|
username = os.getenv("POSTGRES_USER")
|
||||||
password = os.getenv("POSTGRES_PASSWORD")
|
password = os.getenv("POSTGRES_PASSWORD")
|
||||||
host = os.getenv("POSTGRES_HOST")
|
host = os.getenv("POSTGRES_HOST")
|
||||||
@ -12,10 +25,24 @@ def get_sqlalchemy_database_uri():
|
|||||||
return f"postgresql://{username}:{password}@{host}/{db_name}"
|
return f"postgresql://{username}:{password}@{host}/{db_name}"
|
||||||
|
|
||||||
|
|
||||||
def get_engine():
|
def _get_engine(uri: str):
|
||||||
return create_engine(get_sqlalchemy_database_uri())
|
return create_engine(uri)
|
||||||
|
|
||||||
|
|
||||||
def get_session():
|
def _get_session(uri: str):
|
||||||
Session = sessionmaker(bind=get_engine())
|
Session = sessionmaker(bind=_get_engine(uri))
|
||||||
return Session()
|
return Session()
|
||||||
|
|
||||||
|
|
||||||
|
def get_inspect_session() -> orm.Session:
|
||||||
|
uri = get_inspect_database_uri()
|
||||||
|
return _get_session(uri)
|
||||||
|
|
||||||
|
|
||||||
|
def get_trace_session() -> Optional[orm.Session]:
|
||||||
|
uri = get_trace_database_uri()
|
||||||
|
|
||||||
|
if uri is not None:
|
||||||
|
return _get_session(uri)
|
||||||
|
|
||||||
|
return None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user