diff --git a/mev_inspect/crud/blocks.py b/mev_inspect/crud/blocks.py index fce9d2e..41199a5 100644 --- a/mev_inspect/crud/blocks.py +++ b/mev_inspect/crud/blocks.py @@ -1,6 +1,7 @@ from datetime import datetime from typing import List +from mev_inspect.db import write_as_csv from mev_inspect.schemas.blocks import Block @@ -28,16 +29,11 @@ def write_blocks( db_session, blocks: List[Block], ) -> None: - block_params = [ - { - "block_number": block.block_number, - "block_timestamp": datetime.fromtimestamp(block.block_timestamp), - } + items_generator = ( + ( + block.block_number, + datetime.fromtimestamp(block.block_timestamp), + ) for block in blocks - ] - - db_session.execute( - "INSERT INTO blocks (block_number, block_timestamp) VALUES (:block_number, :block_timestamp)", - params=block_params, ) - db_session.commit() + write_as_csv(db_session, "blocks", items_generator) diff --git a/mev_inspect/crud/traces.py b/mev_inspect/crud/traces.py index 0f099f6..903026e 100644 --- a/mev_inspect/crud/traces.py +++ b/mev_inspect/crud/traces.py @@ -1,6 +1,8 @@ import json +from datetime import datetime, timezone from typing import List +from mev_inspect.db import to_postgres_list, write_as_csv from mev_inspect.models.traces import ClassifiedTraceModel from mev_inspect.schemas.traces import ClassifiedTrace @@ -26,30 +28,35 @@ def write_classified_traces( db_session, classified_traces: List[ClassifiedTrace], ) -> None: - models = [] - for trace in classified_traces: - inputs_json = (json.loads(trace.json(include={"inputs"}))["inputs"],) - models.append( - ClassifiedTraceModel( - transaction_hash=trace.transaction_hash, - transaction_position=trace.transaction_position, - block_number=trace.block_number, - classification=trace.classification.value, - trace_type=trace.type.value, - trace_address=trace.trace_address, - protocol=str(trace.protocol), - abi_name=trace.abi_name, - function_name=trace.function_name, - function_signature=trace.function_signature, - inputs=inputs_json, - from_address=trace.from_address, - to_address=trace.to_address, - gas=trace.gas, - value=trace.value, - gas_used=trace.gas_used, - error=trace.error, - ) + classified_at = datetime.now(timezone.utc) + items = ( + ( + classified_at, + trace.transaction_hash, + trace.block_number, + trace.classification.value, + trace.type.value, + str(trace.protocol), + trace.abi_name, + trace.function_name, + trace.function_signature, + _inputs_as_json(trace), + trace.from_address, + trace.to_address, + trace.gas, + trace.value, + trace.gas_used, + trace.error, + to_postgres_list(trace.trace_address), + trace.transaction_position, ) + for trace in classified_traces + ) - db_session.bulk_save_objects(models) - db_session.commit() + write_as_csv(db_session, "classified_traces", items) + + +def _inputs_as_json(trace) -> str: + inputs = json.dumps(json.loads(trace.json(include={"inputs"}))["inputs"]) + inputs_with_array = f"[{inputs}]" + return inputs_with_array diff --git a/mev_inspect/db.py b/mev_inspect/db.py index 15ccdc3..dd7c66a 100644 --- a/mev_inspect/db.py +++ b/mev_inspect/db.py @@ -1,9 +1,11 @@ import os -from typing import Optional +from typing import Any, Iterable, List, Optional from sqlalchemy import create_engine, orm from sqlalchemy.orm import sessionmaker +from mev_inspect.string_io import StringIteratorIO + def get_trace_database_uri() -> Optional[str]: username = os.getenv("TRACE_DB_USER") @@ -63,3 +65,29 @@ def get_trace_session() -> Optional[orm.Session]: return Session() return None + + +def write_as_csv( + db_session, + table_name: str, + items: Iterable[Iterable[Any]], +) -> None: + csv_iterator = StringIteratorIO( + ("|".join(map(_clean_csv_value, item)) + "\n" for item in items) + ) + + with db_session.connection().connection.cursor() as cursor: + cursor.copy_from(csv_iterator, table_name, sep="|") + + +def _clean_csv_value(value: Optional[Any]) -> str: + if value is None: + return r"\N" + return str(value).replace("\n", "\\n") + + +def to_postgres_list(values: List[Any]) -> str: + if len(values) == 0: + return "{}" + + return "{" + ",".join(map(str, values)) + "}" diff --git a/mev_inspect/string_io.py b/mev_inspect/string_io.py new file mode 100644 index 0000000..37efb5f --- /dev/null +++ b/mev_inspect/string_io.py @@ -0,0 +1,40 @@ +"""This is taken from https://hakibenita.com/fast-load-data-python-postgresql""" + +import io +from typing import Iterator, Optional + + +class StringIteratorIO(io.TextIOBase): + def __init__(self, iter: Iterator[str]): + self._iter = iter + self._buff = "" + + def readable(self) -> bool: + return True + + def _read1(self, n: Optional[int] = None) -> str: + while not self._buff: + try: + self._buff = next(self._iter) + except StopIteration: + break + ret = self._buff[:n] + self._buff = self._buff[len(ret) :] + return ret + + def read(self, n: Optional[int] = None) -> str: + line = [] + if n is None or n < 0: + while True: + m = self._read1() + if not m: + break + line.append(m) + else: + while n > 0: + m = self._read1(n) + if not m: + break + n -= len(m) + line.append(m) + return "".join(line)