Merge pull request #211 from flashbots/faster-writes

Use COPY to speed up database writes for blocks and traces
This commit is contained in:
Luke Van Seters 2022-01-04 13:17:24 -05:00 committed by GitHub
commit 379bd82f0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 108 additions and 37 deletions

View File

@ -1,6 +1,7 @@
from datetime import datetime from datetime import datetime
from typing import List from typing import List
from mev_inspect.db import write_as_csv
from mev_inspect.schemas.blocks import Block from mev_inspect.schemas.blocks import Block
@ -28,16 +29,11 @@ def write_blocks(
db_session, db_session,
blocks: List[Block], blocks: List[Block],
) -> None: ) -> None:
block_params = [ items_generator = (
{ (
"block_number": block.block_number, block.block_number,
"block_timestamp": datetime.fromtimestamp(block.block_timestamp), datetime.fromtimestamp(block.block_timestamp),
} )
for block in blocks for block in blocks
]
db_session.execute(
"INSERT INTO blocks (block_number, block_timestamp) VALUES (:block_number, :block_timestamp)",
params=block_params,
) )
db_session.commit() write_as_csv(db_session, "blocks", items_generator)

View File

@ -1,6 +1,8 @@
import json import json
from datetime import datetime, timezone
from typing import List from typing import List
from mev_inspect.db import to_postgres_list, write_as_csv
from mev_inspect.models.traces import ClassifiedTraceModel from mev_inspect.models.traces import ClassifiedTraceModel
from mev_inspect.schemas.traces import ClassifiedTrace from mev_inspect.schemas.traces import ClassifiedTrace
@ -26,30 +28,35 @@ def write_classified_traces(
db_session, db_session,
classified_traces: List[ClassifiedTrace], classified_traces: List[ClassifiedTrace],
) -> None: ) -> None:
models = [] classified_at = datetime.now(timezone.utc)
for trace in classified_traces: items = (
inputs_json = (json.loads(trace.json(include={"inputs"}))["inputs"],) (
models.append( classified_at,
ClassifiedTraceModel( trace.transaction_hash,
transaction_hash=trace.transaction_hash, trace.block_number,
transaction_position=trace.transaction_position, trace.classification.value,
block_number=trace.block_number, trace.type.value,
classification=trace.classification.value, str(trace.protocol),
trace_type=trace.type.value, trace.abi_name,
trace_address=trace.trace_address, trace.function_name,
protocol=str(trace.protocol), trace.function_signature,
abi_name=trace.abi_name, _inputs_as_json(trace),
function_name=trace.function_name, trace.from_address,
function_signature=trace.function_signature, trace.to_address,
inputs=inputs_json, trace.gas,
from_address=trace.from_address, trace.value,
to_address=trace.to_address, trace.gas_used,
gas=trace.gas, trace.error,
value=trace.value, to_postgres_list(trace.trace_address),
gas_used=trace.gas_used, trace.transaction_position,
error=trace.error,
)
) )
for trace in classified_traces
)
db_session.bulk_save_objects(models) write_as_csv(db_session, "classified_traces", items)
db_session.commit()
def _inputs_as_json(trace) -> str:
    """Return the trace's ``inputs`` as JSON text wrapped in a one-element array.

    The trace is serialized with ``trace.json(include={"inputs"})``, the
    ``"inputs"`` entry is extracted from that document and re-encoded with
    ``json.dumps``, and the result is surrounded by ``[`` and ``]``.
    """
    serialized_trace = trace.json(include={"inputs"})
    inputs_only = json.loads(serialized_trace)["inputs"]
    return "[" + json.dumps(inputs_only) + "]"

View File

@ -1,9 +1,11 @@
import os import os
from typing import Optional from typing import Any, Iterable, List, Optional
from sqlalchemy import create_engine, orm from sqlalchemy import create_engine, orm
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from mev_inspect.string_io import StringIteratorIO
def get_trace_database_uri() -> Optional[str]: def get_trace_database_uri() -> Optional[str]:
username = os.getenv("TRACE_DB_USER") username = os.getenv("TRACE_DB_USER")
@ -63,3 +65,29 @@ def get_trace_session() -> Optional[orm.Session]:
return Session() return Session()
return None return None
def write_as_csv(
    db_session,
    table_name: str,
    items: Iterable[Iterable[Any]],
) -> None:
    """Bulk-load ``items`` into ``table_name`` through the cursor's ``copy_from``.

    Every item becomes one text line: its values are passed through
    ``_clean_csv_value``, joined with ``|`` and terminated with a newline.
    The lines are produced lazily and exposed to the cursor as a file-like
    object via ``StringIteratorIO``, so the full payload is never built in
    memory.

    No commit is issued here; the cursor is taken from the session's
    underlying connection (presumably a psycopg2 connection -- confirm) and
    committing is left to the caller.

    Args:
        db_session: Session whose ``connection().connection`` provides a
            cursor with a ``copy_from`` method.
        table_name: Target table for the copy.
        items: Rows to write; the values of each row are written in the
            order given.
    """
    delimiter = "|"
    rows = (
        delimiter.join(_clean_csv_value(cell) for cell in row) + "\n"
        for row in items
    )
    stream = StringIteratorIO(rows)

    raw_connection = db_session.connection().connection
    with raw_connection.cursor() as cursor:
        cursor.copy_from(stream, table_name, sep=delimiter)
def _clean_csv_value(value: Optional[Any], delimiter: str = "|") -> str:
    """Render one value as a field for PostgreSQL ``COPY`` text format.

    ``None`` becomes the NULL marker ``\\N``. Any other value is converted
    with ``str`` and the characters that are significant to the text format
    are backslash-escaped: the backslash itself, newline, carriage return and
    the column delimiter. Without this, a value containing the delimiter
    would be split into extra columns and a backslash sequence inside the
    value (for example ``\\"`` in JSON text) would be consumed by COPY.

    Args:
        value: The value to render; ``None`` means SQL NULL.
        delimiter: The single-character column separator used for the copy.
            Defaults to ``|``, the separator used by ``write_as_csv``.

    Returns:
        The escaped text for a single field, without a trailing delimiter.
    """
    if value is None:
        return r"\N"

    # The backslash must be escaped first so the escapes added below are not
    # themselves doubled.
    text = str(value).replace("\\", "\\\\")
    text = text.replace("\n", "\\n").replace("\r", "\\r")
    return text.replace(delimiter, "\\" + delimiter)
def to_postgres_list(values: List[Any]) -> str:
    """Format ``values`` as a brace-delimited, comma-separated literal.

    Each element is converted with ``str`` and the elements are joined with
    commas inside ``{`` and ``}``; an empty list yields ``{}``. Elements are
    not quoted or escaped.
    """
    joined = ",".join(str(value) for value in values)
    return "{" + joined + "}"

40
mev_inspect/string_io.py Normal file
View File

@ -0,0 +1,40 @@
"""This is taken from https://hakibenita.com/fast-load-data-python-postgresql"""
import io
from typing import Iterator, Optional
class StringIteratorIO(io.TextIOBase):
    """Read-only text stream backed by an iterator of string chunks.

    Chunks are pulled from the iterator only when a read needs more data, so
    a large payload can be consumed through a file-like interface without
    being concatenated up front.
    """

    def __init__(self, iter: Iterator[str]):
        """Wrap ``iter``; nothing is consumed until the first read."""
        self._iter = iter
        # Unread remainder of the chunk most recently pulled from the iterator.
        self._buff = ""

    def readable(self) -> bool:
        """Report that the stream supports reading."""
        return True

    def _read1(self, n: Optional[int] = None) -> str:
        """Return up to ``n`` characters drawn from a single chunk.

        With ``n`` of ``None`` the whole remaining chunk is returned. Empty
        chunks are skipped, and an empty string is returned once the
        iterator is exhausted.
        """
        while not self._buff:
            chunk = next(self._iter, None)
            if chunk is None:
                # Iterator exhausted: fall through with an empty buffer.
                break
            self._buff = chunk

        taken = self._buff[:n]
        self._buff = self._buff[len(taken):]
        return taken

    def read(self, n: Optional[int] = None) -> str:
        """Read and return up to ``n`` characters.

        When ``n`` is ``None`` or negative, everything left in the stream is
        returned. Otherwise chunks are gathered until ``n`` characters have
        been collected or the iterator runs dry.
        """
        if n is None or n < 0:
            # Keep pulling whole chunks until _read1 signals exhaustion.
            return "".join(iter(self._read1, ""))

        pieces = []
        remaining = n
        while remaining > 0:
            piece = self._read1(remaining)
            if not piece:
                break
            pieces.append(piece)
            remaining -= len(piece)
        return "".join(pieces)