Merge pull request #211 from flashbots/faster-writes

Use COPY to speed up database writes for blocks and traces
This commit is contained in:
Luke Van Seters 2022-01-04 13:17:24 -05:00 committed by GitHub
commit 379bd82f0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 108 additions and 37 deletions

View File

@ -1,6 +1,7 @@
from datetime import datetime
from typing import List
from mev_inspect.db import write_as_csv
from mev_inspect.schemas.blocks import Block
@ -28,16 +29,11 @@ def write_blocks(
db_session,
blocks: List[Block],
) -> None:
block_params = [
{
"block_number": block.block_number,
"block_timestamp": datetime.fromtimestamp(block.block_timestamp),
}
items_generator = (
(
block.block_number,
datetime.fromtimestamp(block.block_timestamp),
)
for block in blocks
]
db_session.execute(
"INSERT INTO blocks (block_number, block_timestamp) VALUES (:block_number, :block_timestamp)",
params=block_params,
)
db_session.commit()
write_as_csv(db_session, "blocks", items_generator)

View File

@ -1,6 +1,8 @@
import json
from datetime import datetime, timezone
from typing import List
from mev_inspect.db import to_postgres_list, write_as_csv
from mev_inspect.models.traces import ClassifiedTraceModel
from mev_inspect.schemas.traces import ClassifiedTrace
@ -26,30 +28,35 @@ def write_classified_traces(
db_session,
classified_traces: List[ClassifiedTrace],
) -> None:
models = []
for trace in classified_traces:
inputs_json = (json.loads(trace.json(include={"inputs"}))["inputs"],)
models.append(
ClassifiedTraceModel(
transaction_hash=trace.transaction_hash,
transaction_position=trace.transaction_position,
block_number=trace.block_number,
classification=trace.classification.value,
trace_type=trace.type.value,
trace_address=trace.trace_address,
protocol=str(trace.protocol),
abi_name=trace.abi_name,
function_name=trace.function_name,
function_signature=trace.function_signature,
inputs=inputs_json,
from_address=trace.from_address,
to_address=trace.to_address,
gas=trace.gas,
value=trace.value,
gas_used=trace.gas_used,
error=trace.error,
)
classified_at = datetime.now(timezone.utc)
items = (
(
classified_at,
trace.transaction_hash,
trace.block_number,
trace.classification.value,
trace.type.value,
str(trace.protocol),
trace.abi_name,
trace.function_name,
trace.function_signature,
_inputs_as_json(trace),
trace.from_address,
trace.to_address,
trace.gas,
trace.value,
trace.gas_used,
trace.error,
to_postgres_list(trace.trace_address),
trace.transaction_position,
)
for trace in classified_traces
)
db_session.bulk_save_objects(models)
db_session.commit()
write_as_csv(db_session, "classified_traces", items)
def _inputs_as_json(trace) -> str:
inputs = json.dumps(json.loads(trace.json(include={"inputs"}))["inputs"])
inputs_with_array = f"[{inputs}]"
return inputs_with_array

View File

@ -1,9 +1,11 @@
import os
from typing import Optional
from typing import Any, Iterable, List, Optional
from sqlalchemy import create_engine, orm
from sqlalchemy.orm import sessionmaker
from mev_inspect.string_io import StringIteratorIO
def get_trace_database_uri() -> Optional[str]:
username = os.getenv("TRACE_DB_USER")
@ -63,3 +65,29 @@ def get_trace_session() -> Optional[orm.Session]:
return Session()
return None
def write_as_csv(
db_session,
table_name: str,
items: Iterable[Iterable[Any]],
) -> None:
csv_iterator = StringIteratorIO(
("|".join(map(_clean_csv_value, item)) + "\n" for item in items)
)
with db_session.connection().connection.cursor() as cursor:
cursor.copy_from(csv_iterator, table_name, sep="|")
def _clean_csv_value(value: Optional[Any]) -> str:
if value is None:
return r"\N"
return str(value).replace("\n", "\\n")
def to_postgres_list(values: List[Any]) -> str:
if len(values) == 0:
return "{}"
return "{" + ",".join(map(str, values)) + "}"

40
mev_inspect/string_io.py Normal file
View File

@ -0,0 +1,40 @@
"""This is taken from https://hakibenita.com/fast-load-data-python-postgresql"""
import io
from typing import Iterator, Optional
class StringIteratorIO(io.TextIOBase):
def __init__(self, iter: Iterator[str]):
self._iter = iter
self._buff = ""
def readable(self) -> bool:
return True
def _read1(self, n: Optional[int] = None) -> str:
while not self._buff:
try:
self._buff = next(self._iter)
except StopIteration:
break
ret = self._buff[:n]
self._buff = self._buff[len(ret) :]
return ret
def read(self, n: Optional[int] = None) -> str:
line = []
if n is None or n < 0:
while True:
m = self._read1()
if not m:
break
line.append(m)
else:
while n > 0:
m = self._read1(n)
if not m:
break
n -= len(m)
line.append(m)
return "".join(line)