diff --git a/mev_inspect/classifiers/trace.py b/mev_inspect/classifiers/trace.py index 92b985c..9968eb6 100644 --- a/mev_inspect/classifiers/trace.py +++ b/mev_inspect/classifiers/trace.py @@ -6,6 +6,8 @@ from mev_inspect.schemas.blocks import CallAction, CallResult, Trace, TraceType from mev_inspect.schemas.classified_traces import ( Classification, ClassifiedTrace, + CallTrace, + DecodedCallTrace, ) from .specs import ALL_CLASSIFIER_SPECS @@ -69,7 +71,7 @@ class TraceClassifier: signature, Classification.unknown ) - return ClassifiedTrace( + return DecodedCallTrace( **trace.dict(), trace_type=trace.type, classification=classification, @@ -85,7 +87,7 @@ class TraceClassifier: gas_used=result.gas_used if result is not None else None, ) - return ClassifiedTrace( + return CallTrace( **trace.dict(), trace_type=trace.type, classification=Classification.unknown, diff --git a/mev_inspect/crud/classified_traces.py b/mev_inspect/crud/classified_traces.py index f0300f5..ad641fb 100644 --- a/mev_inspect/crud/classified_traces.py +++ b/mev_inspect/crud/classified_traces.py @@ -22,9 +22,29 @@ def write_classified_traces( db_session, classified_traces: List[ClassifiedTrace], ) -> None: - models = [ - ClassifiedTraceModel(**json.loads(trace.json())) for trace in classified_traces - ] + models = [] + for trace in classified_traces: + inputs_json = (json.loads(trace.json(include={"inputs"}))["inputs"],) + models.append( + ClassifiedTraceModel( + transaction_hash=trace.transaction_hash, + block_number=trace.block_number, + classification=trace.classification.value, + trace_type=trace.type.value, + trace_address=trace.trace_address, + protocol=str(trace.protocol), + abi_name=trace.abi_name, + function_name=trace.function_name, + function_signature=trace.function_signature, + inputs=inputs_json, + from_address=trace.from_address, + to_address=trace.to_address, + gas=trace.gas, + value=trace.value, + gas_used=trace.gas_used, + error=trace.error, + ) + ) db_session.bulk_save_objects(models) db_session.commit() diff --git a/mev_inspect/schemas/classified_traces.py b/mev_inspect/schemas/classified_traces.py index d8d7f41..c7da09e 100644 --- a/mev_inspect/schemas/classified_traces.py +++ b/mev_inspect/schemas/classified_traces.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel -from .blocks import TraceType +from .blocks import Trace class Classification(Enum): @@ -24,25 +24,25 @@ class Protocol(Enum): zero_ex = "0x" -class ClassifiedTrace(BaseModel): +class ClassifiedTrace(Trace): transaction_hash: str block_number: int - trace_type: TraceType trace_address: List[int] classification: Classification - protocol: Optional[Protocol] - abi_name: Optional[str] - function_name: Optional[str] - function_signature: Optional[str] - inputs: Optional[Dict[str, Any]] + error: Optional[str] to_address: Optional[str] from_address: Optional[str] gas: Optional[int] value: Optional[int] gas_used: Optional[int] - error: Optional[str] + protocol: Optional[Protocol] + function_name: Optional[str] + function_signature: Optional[str] + inputs: Optional[Dict[str, Any]] + abi_name: Optional[str] class Config: + validate_assignment = True json_encoders = { # a little lazy but fine for now # this is used for bytes value inputs @@ -50,6 +50,21 @@ class ClassifiedTrace(BaseModel): } +class CallTrace(ClassifiedTrace): + to_address: str + from_address: str + + +class DecodedCallTrace(CallTrace): + inputs: Dict[str, Any] + abi_name: str + protocol: Optional[Protocol] + gas: Optional[int] + gas_used: Optional[int] + function_name: Optional[str] + function_signature: Optional[str] + + class ClassifierSpec(BaseModel): abi_name: str protocol: Optional[Protocol] = None diff --git a/mev_inspect/schemas/swaps.py b/mev_inspect/schemas/swaps.py index 2350142..f17068a 100644 --- a/mev_inspect/schemas/swaps.py +++ b/mev_inspect/schemas/swaps.py @@ -10,7 +10,6 @@ class Swap(BaseModel): transaction_hash: str block_number: int trace_address: List[int] - protocol: Optional[Protocol] pool_address: str from_address: str to_address: str @@ -18,4 +17,5 @@ class Swap(BaseModel): token_in_amount: int token_out_address: str token_out_amount: int + protocol: Optional[Protocol] error: Optional[str] diff --git a/tests/helpers.py b/tests/helpers.py index 0ebcf0f..540a327 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,7 +1,12 @@ from typing import List from mev_inspect.schemas.blocks import TraceType -from mev_inspect.schemas.classified_traces import Classification, ClassifiedTrace +from mev_inspect.schemas.classified_traces import ( + Classification, + ClassifiedTrace, + CallTrace, + DecodedCallTrace, +) def make_transfer_trace( @@ -13,10 +18,10 @@ def make_transfer_trace( token_address: str, amount: int, ): - return ClassifiedTrace( + return CallTrace( transaction_hash=transaction_hash, block_number=block_number, - trace_type=TraceType.call, + type=TraceType.call, trace_address=trace_address, classification=Classification.transfer, from_address=from_address, @@ -25,6 +30,9 @@ def make_transfer_trace( "recipient": to_address, "amount": amount, }, + block_hash=str(block_number), + action={}, + subtraces=0.0, ) @@ -38,37 +46,43 @@ def make_swap_trace( recipient_address: str, recipient_input_key: str, ): - return ClassifiedTrace( + return DecodedCallTrace( transaction_hash=transaction_hash, block_number=block_number, - trace_type=TraceType.call, + type=TraceType.call, trace_address=trace_address, + action={}, + subtraces=0, classification=Classification.swap, from_address=from_address, to_address=pool_address, inputs={recipient_input_key: recipient_address}, abi_name=abi_name, + block_hash=str(block_number), ) def make_unknown_trace( - block_number, - transaction_hash, - trace_address, + block_number: int, + transaction_hash: str, + trace_address: List[int], ): return ClassifiedTrace( - transaction_hash=transaction_hash, block_number=block_number, - trace_type=TraceType.call, + transaction_hash=transaction_hash, trace_address=trace_address, + action={}, + subtraces=0, + block_hash=str(block_number), + type=TraceType.call, classification=Classification.unknown, ) def make_many_unknown_traces( - block_number, - transaction_hash, - trace_addresses, + block_number: int, + transaction_hash: str, + trace_addresses: List[List[int]], ) -> List[ClassifiedTrace]: return [