Merge pull request #70 from flashbots/call-trace-db

Add stricter types to trace classifier
This commit is contained in:
Luke Van Seters 2021-09-16 14:55:28 -06:00 committed by GitHub
commit bff71b01c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 79 additions and 28 deletions

View File

@ -6,6 +6,8 @@ from mev_inspect.schemas.blocks import CallAction, CallResult, Trace, TraceType
from mev_inspect.schemas.classified_traces import (
Classification,
ClassifiedTrace,
CallTrace,
DecodedCallTrace,
)
from .specs import ALL_CLASSIFIER_SPECS
@ -69,7 +71,7 @@ class TraceClassifier:
signature, Classification.unknown
)
return ClassifiedTrace(
return DecodedCallTrace(
**trace.dict(),
trace_type=trace.type,
classification=classification,
@ -85,7 +87,7 @@ class TraceClassifier:
gas_used=result.gas_used if result is not None else None,
)
return ClassifiedTrace(
return CallTrace(
**trace.dict(),
trace_type=trace.type,
classification=Classification.unknown,

View File

@ -22,9 +22,29 @@ def write_classified_traces(
db_session,
classified_traces: List[ClassifiedTrace],
) -> None:
models = [
ClassifiedTraceModel(**json.loads(trace.json())) for trace in classified_traces
]
models = []
for trace in classified_traces:
inputs_json = (json.loads(trace.json(include={"inputs"}))["inputs"],)
models.append(
ClassifiedTraceModel(
transaction_hash=trace.transaction_hash,
block_number=trace.block_number,
classification=trace.classification.value,
trace_type=trace.type.value,
trace_address=trace.trace_address,
protocol=str(trace.protocol),
abi_name=trace.abi_name,
function_name=trace.function_name,
function_signature=trace.function_signature,
inputs=inputs_json,
from_address=trace.from_address,
to_address=trace.to_address,
gas=trace.gas,
value=trace.value,
gas_used=trace.gas_used,
error=trace.error,
)
)
db_session.bulk_save_objects(models)
db_session.commit()

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from .blocks import TraceType
from .blocks import Trace
class Classification(Enum):
@ -24,25 +24,25 @@ class Protocol(Enum):
zero_ex = "0x"
class ClassifiedTrace(BaseModel):
class ClassifiedTrace(Trace):
transaction_hash: str
block_number: int
trace_type: TraceType
trace_address: List[int]
classification: Classification
protocol: Optional[Protocol]
abi_name: Optional[str]
function_name: Optional[str]
function_signature: Optional[str]
inputs: Optional[Dict[str, Any]]
error: Optional[str]
to_address: Optional[str]
from_address: Optional[str]
gas: Optional[int]
value: Optional[int]
gas_used: Optional[int]
error: Optional[str]
protocol: Optional[Protocol]
function_name: Optional[str]
function_signature: Optional[str]
inputs: Optional[Dict[str, Any]]
abi_name: Optional[str]
class Config:
validate_assignment = True
json_encoders = {
# a little lazy but fine for now
# this is used for bytes value inputs
@ -50,6 +50,21 @@ class ClassifiedTrace(BaseModel):
}
class CallTrace(ClassifiedTrace):
to_address: str
from_address: str
class DecodedCallTrace(CallTrace):
inputs: Dict[str, Any]
abi_name: str
protocol: Optional[Protocol]
gas: Optional[int]
gas_used: Optional[int]
function_name: Optional[str]
function_signature: Optional[str]
class ClassifierSpec(BaseModel):
abi_name: str
protocol: Optional[Protocol] = None

View File

@ -10,7 +10,6 @@ class Swap(BaseModel):
transaction_hash: str
block_number: int
trace_address: List[int]
protocol: Optional[Protocol]
pool_address: str
from_address: str
to_address: str
@ -18,4 +17,5 @@ class Swap(BaseModel):
token_in_amount: int
token_out_address: str
token_out_amount: int
protocol: Optional[Protocol]
error: Optional[str]

View File

@ -1,7 +1,12 @@
from typing import List
from mev_inspect.schemas.blocks import TraceType
from mev_inspect.schemas.classified_traces import Classification, ClassifiedTrace
from mev_inspect.schemas.classified_traces import (
Classification,
ClassifiedTrace,
CallTrace,
DecodedCallTrace,
)
def make_transfer_trace(
@ -13,10 +18,10 @@ def make_transfer_trace(
token_address: str,
amount: int,
):
return ClassifiedTrace(
return CallTrace(
transaction_hash=transaction_hash,
block_number=block_number,
trace_type=TraceType.call,
type=TraceType.call,
trace_address=trace_address,
classification=Classification.transfer,
from_address=from_address,
@ -25,6 +30,9 @@ def make_transfer_trace(
"recipient": to_address,
"amount": amount,
},
block_hash=str(block_number),
action={},
subtraces=0.0,
)
@ -38,37 +46,43 @@ def make_swap_trace(
recipient_address: str,
recipient_input_key: str,
):
return ClassifiedTrace(
return DecodedCallTrace(
transaction_hash=transaction_hash,
block_number=block_number,
trace_type=TraceType.call,
type=TraceType.call,
trace_address=trace_address,
action={},
subtraces=0,
classification=Classification.swap,
from_address=from_address,
to_address=pool_address,
inputs={recipient_input_key: recipient_address},
abi_name=abi_name,
block_hash=str(block_number),
)
def make_unknown_trace(
block_number,
transaction_hash,
trace_address,
block_number: int,
transaction_hash: str,
trace_address: List[int],
):
return ClassifiedTrace(
transaction_hash=transaction_hash,
block_number=block_number,
trace_type=TraceType.call,
transaction_hash=transaction_hash,
trace_address=trace_address,
action={},
subtraces=0,
block_hash=str(block_number),
type=TraceType.call,
classification=Classification.unknown,
)
def make_many_unknown_traces(
block_number,
transaction_hash,
trace_addresses,
block_number: int,
transaction_hash: str,
trace_addresses: List[List[int]],
) -> List[ClassifiedTrace]:
return [