diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 397ce75..8ee03eb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,11 @@ repos: -- repo: https://github.com/ambv/black - rev: 20.8b1 +- repo: local hooks: - - id: black - language_version: python3.9 + - id: system + name: Black + entry: poetry run black . + pass_filenames: false + language: system - repo: local hooks: - id: pylint diff --git a/alembic/env.py b/alembic/env.py index cbb52c1..7e5e6cd 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -5,12 +5,12 @@ from sqlalchemy import pool from alembic import context -from mev_inspect.db import get_sqlalchemy_database_uri +from mev_inspect.db import get_inspect_database_uri # this is the Alembic Config object, which provides # access to the values within the .ini file in use. config = context.config -config.set_main_option("sqlalchemy.url", get_sqlalchemy_database_uri()) +config.set_main_option("sqlalchemy.url", get_inspect_database_uri()) # Interpret the config file for Python logging. # This line sets up loggers basically. diff --git a/cli.py b/cli.py index e7f2386..898fd42 100644 --- a/cli.py +++ b/cli.py @@ -6,7 +6,7 @@ import click from web3 import Web3 from mev_inspect.classifiers.trace import TraceClassifier -from mev_inspect.db import get_session +from mev_inspect.db import get_inspect_session from mev_inspect.inspect_block import inspect_block from mev_inspect.provider import get_base_provider @@ -27,7 +27,7 @@ def cli(): @click.option("--rpc", default=lambda: os.environ.get(RPC_URL_ENV, "")) @click.option("--cache/--no-cache", default=True) def inspect_block_command(block_number: int, rpc: str, cache: bool): - db_session = get_session() + db_session = get_inspect_session() base_provider = get_base_provider(rpc) w3 = Web3(base_provider) trace_classifier = TraceClassifier() @@ -54,7 +54,7 @@ def inspect_many_blocks_command( after_block: int, before_block: int, rpc: str, cache: bool ): - db_session = get_session() + db_session = get_inspect_session() base_provider = get_base_provider(rpc) w3 = Web3(base_provider) trace_classifier = TraceClassifier() diff --git a/listener.py b/listener.py index 88e364c..05cfb38 100644 --- a/listener.py +++ b/listener.py @@ -9,7 +9,7 @@ from mev_inspect.crud.latest_block_update import ( find_latest_block_update, update_latest_block, ) -from mev_inspect.db import get_session +from mev_inspect.db import get_inspect_session from mev_inspect.inspect_block import inspect_block from mev_inspect.provider import get_base_provider from mev_inspect.signal_handler import GracefulKiller @@ -32,7 +32,7 @@ def run(): killer = GracefulKiller() - db_session = get_session() + db_session = get_inspect_session() base_provider = get_base_provider(rpc) w3 = Web3(base_provider) diff --git a/mev_inspect/db.py b/mev_inspect/db.py index df2b7d9..9cdaa48 100644 --- a/mev_inspect/db.py +++ b/mev_inspect/db.py @@ -1,10 +1,23 @@ import os +from typing import Optional -from sqlalchemy import create_engine +from sqlalchemy import create_engine, orm from sqlalchemy.orm import sessionmaker -def get_sqlalchemy_database_uri(): +def get_trace_database_uri() -> Optional[str]: + username = os.getenv("TRACE_DB_USER") + password = os.getenv("TRACE_DB_PASSWORD") + host = os.getenv("TRACE_DB_HOST") + db_name = "trace_db" + + if all(field is not None for field in [username, password, host]): + return f"postgresql://{username}:{password}@{host}/{db_name}" + + return None + + +def get_inspect_database_uri(): username = os.getenv("POSTGRES_USER") password = os.getenv("POSTGRES_PASSWORD") host = os.getenv("POSTGRES_HOST") @@ -12,10 +25,24 @@ def get_sqlalchemy_database_uri(): return f"postgresql://{username}:{password}@{host}/{db_name}" -def get_engine(): - return create_engine(get_sqlalchemy_database_uri()) +def _get_engine(uri: str): + return create_engine(uri) -def get_session(): - Session = sessionmaker(bind=get_engine()) +def _get_session(uri: str): + Session = sessionmaker(bind=_get_engine(uri)) return Session() + + +def get_inspect_session() -> orm.Session: + uri = get_inspect_database_uri() + return _get_session(uri) + + +def get_trace_session() -> Optional[orm.Session]: + uri = get_trace_database_uri() + + if uri is not None: + return _get_session(uri) + + return None