Merge pull request #70 from flashbots/call-trace-db

Add stricter types to trace classifier
This commit is contained in:
Luke Van Seters 2021-09-16 14:55:28 -06:00 committed by GitHub
commit bff71b01c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 79 additions and 28 deletions

View File

@ -6,6 +6,8 @@ from mev_inspect.schemas.blocks import CallAction, CallResult, Trace, TraceType
from mev_inspect.schemas.classified_traces import ( from mev_inspect.schemas.classified_traces import (
Classification, Classification,
ClassifiedTrace, ClassifiedTrace,
CallTrace,
DecodedCallTrace,
) )
from .specs import ALL_CLASSIFIER_SPECS from .specs import ALL_CLASSIFIER_SPECS
@ -69,7 +71,7 @@ class TraceClassifier:
signature, Classification.unknown signature, Classification.unknown
) )
return ClassifiedTrace( return DecodedCallTrace(
**trace.dict(), **trace.dict(),
trace_type=trace.type, trace_type=trace.type,
classification=classification, classification=classification,
@ -85,7 +87,7 @@ class TraceClassifier:
gas_used=result.gas_used if result is not None else None, gas_used=result.gas_used if result is not None else None,
) )
return ClassifiedTrace( return CallTrace(
**trace.dict(), **trace.dict(),
trace_type=trace.type, trace_type=trace.type,
classification=Classification.unknown, classification=Classification.unknown,

View File

@ -22,9 +22,29 @@ def write_classified_traces(
db_session, db_session,
classified_traces: List[ClassifiedTrace], classified_traces: List[ClassifiedTrace],
) -> None: ) -> None:
models = [ models = []
ClassifiedTraceModel(**json.loads(trace.json())) for trace in classified_traces for trace in classified_traces:
] inputs_json = (json.loads(trace.json(include={"inputs"}))["inputs"],)
models.append(
ClassifiedTraceModel(
transaction_hash=trace.transaction_hash,
block_number=trace.block_number,
classification=trace.classification.value,
trace_type=trace.type.value,
trace_address=trace.trace_address,
protocol=str(trace.protocol),
abi_name=trace.abi_name,
function_name=trace.function_name,
function_signature=trace.function_signature,
inputs=inputs_json,
from_address=trace.from_address,
to_address=trace.to_address,
gas=trace.gas,
value=trace.value,
gas_used=trace.gas_used,
error=trace.error,
)
)
db_session.bulk_save_objects(models) db_session.bulk_save_objects(models)
db_session.commit() db_session.commit()

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from .blocks import TraceType from .blocks import Trace
class Classification(Enum): class Classification(Enum):
@ -24,25 +24,25 @@ class Protocol(Enum):
zero_ex = "0x" zero_ex = "0x"
class ClassifiedTrace(BaseModel): class ClassifiedTrace(Trace):
transaction_hash: str transaction_hash: str
block_number: int block_number: int
trace_type: TraceType
trace_address: List[int] trace_address: List[int]
classification: Classification classification: Classification
protocol: Optional[Protocol] error: Optional[str]
abi_name: Optional[str]
function_name: Optional[str]
function_signature: Optional[str]
inputs: Optional[Dict[str, Any]]
to_address: Optional[str] to_address: Optional[str]
from_address: Optional[str] from_address: Optional[str]
gas: Optional[int] gas: Optional[int]
value: Optional[int] value: Optional[int]
gas_used: Optional[int] gas_used: Optional[int]
error: Optional[str] protocol: Optional[Protocol]
function_name: Optional[str]
function_signature: Optional[str]
inputs: Optional[Dict[str, Any]]
abi_name: Optional[str]
class Config: class Config:
validate_assignment = True
json_encoders = { json_encoders = {
# a little lazy but fine for now # a little lazy but fine for now
# this is used for bytes value inputs # this is used for bytes value inputs
@ -50,6 +50,21 @@ class ClassifiedTrace(BaseModel):
} }
class CallTrace(ClassifiedTrace):
to_address: str
from_address: str
class DecodedCallTrace(CallTrace):
inputs: Dict[str, Any]
abi_name: str
protocol: Optional[Protocol]
gas: Optional[int]
gas_used: Optional[int]
function_name: Optional[str]
function_signature: Optional[str]
class ClassifierSpec(BaseModel): class ClassifierSpec(BaseModel):
abi_name: str abi_name: str
protocol: Optional[Protocol] = None protocol: Optional[Protocol] = None

View File

@ -10,7 +10,6 @@ class Swap(BaseModel):
transaction_hash: str transaction_hash: str
block_number: int block_number: int
trace_address: List[int] trace_address: List[int]
protocol: Optional[Protocol]
pool_address: str pool_address: str
from_address: str from_address: str
to_address: str to_address: str
@ -18,4 +17,5 @@ class Swap(BaseModel):
token_in_amount: int token_in_amount: int
token_out_address: str token_out_address: str
token_out_amount: int token_out_amount: int
protocol: Optional[Protocol]
error: Optional[str] error: Optional[str]

View File

@ -1,7 +1,12 @@
from typing import List from typing import List
from mev_inspect.schemas.blocks import TraceType from mev_inspect.schemas.blocks import TraceType
from mev_inspect.schemas.classified_traces import Classification, ClassifiedTrace from mev_inspect.schemas.classified_traces import (
Classification,
ClassifiedTrace,
CallTrace,
DecodedCallTrace,
)
def make_transfer_trace( def make_transfer_trace(
@ -13,10 +18,10 @@ def make_transfer_trace(
token_address: str, token_address: str,
amount: int, amount: int,
): ):
return ClassifiedTrace( return CallTrace(
transaction_hash=transaction_hash, transaction_hash=transaction_hash,
block_number=block_number, block_number=block_number,
trace_type=TraceType.call, type=TraceType.call,
trace_address=trace_address, trace_address=trace_address,
classification=Classification.transfer, classification=Classification.transfer,
from_address=from_address, from_address=from_address,
@ -25,6 +30,9 @@ def make_transfer_trace(
"recipient": to_address, "recipient": to_address,
"amount": amount, "amount": amount,
}, },
block_hash=str(block_number),
action={},
subtraces=0.0,
) )
@ -38,37 +46,43 @@ def make_swap_trace(
recipient_address: str, recipient_address: str,
recipient_input_key: str, recipient_input_key: str,
): ):
return ClassifiedTrace( return DecodedCallTrace(
transaction_hash=transaction_hash, transaction_hash=transaction_hash,
block_number=block_number, block_number=block_number,
trace_type=TraceType.call, type=TraceType.call,
trace_address=trace_address, trace_address=trace_address,
action={},
subtraces=0,
classification=Classification.swap, classification=Classification.swap,
from_address=from_address, from_address=from_address,
to_address=pool_address, to_address=pool_address,
inputs={recipient_input_key: recipient_address}, inputs={recipient_input_key: recipient_address},
abi_name=abi_name, abi_name=abi_name,
block_hash=str(block_number),
) )
def make_unknown_trace( def make_unknown_trace(
block_number, block_number: int,
transaction_hash, transaction_hash: str,
trace_address, trace_address: List[int],
): ):
return ClassifiedTrace( return ClassifiedTrace(
transaction_hash=transaction_hash,
block_number=block_number, block_number=block_number,
trace_type=TraceType.call, transaction_hash=transaction_hash,
trace_address=trace_address, trace_address=trace_address,
action={},
subtraces=0,
block_hash=str(block_number),
type=TraceType.call,
classification=Classification.unknown, classification=Classification.unknown,
) )
def make_many_unknown_traces( def make_many_unknown_traces(
block_number, block_number: int,
transaction_hash, transaction_hash: str,
trace_addresses, trace_addresses: List[List[int]],
) -> List[ClassifiedTrace]: ) -> List[ClassifiedTrace]:
return [ return [