2021-12-13 18:46:39 -05:00

105 lines
3.4 KiB
Python

from typing import Dict, List, Optional
from mev_inspect.abi import get_abi
from mev_inspect.decode import ABIDecoder
from mev_inspect.schemas.blocks import CallAction, CallResult
from mev_inspect.schemas.traces import (
CallTrace,
Classification,
ClassifiedTrace,
DecodedCallTrace,
Trace,
TraceType,
)
from .specs import ALL_CLASSIFIER_SPECS
class TraceClassifier:
    """Classify traces by decoding call inputs against the classifier specs.

    On construction an ``ABIDecoder`` is built for every spec in
    ``ALL_CLASSIFIER_SPECS``. Call traces are then tried against the specs in
    order; the first spec whose decoder can decode the call input determines
    the resulting ``DecodedCallTrace``.
    """

    def __init__(self) -> None:
        """Build one decoder (and one address filter) per classifier spec.

        Raises:
            ValueError: If ``get_abi`` returns ``None`` for a spec.
        """
        self._classifier_specs = ALL_CLASSIFIER_SPECS

        # Lookup by ABI name only. When two specs share an ABI name the later
        # spec overwrites the earlier one here, so classification does not
        # read from this dict; it uses the per-spec entries below instead.
        self._decoders_by_abi_name: Dict[str, ABIDecoder] = {}

        # One (spec, decoder, valid_addresses) tuple per spec, in spec order.
        # ABIs are loaded per (abi_name, protocol), so each spec keeps the
        # decoder built from its own ABI. valid_addresses is the lowercased
        # address set, or None when the spec does not restrict addresses;
        # it is computed once here rather than for every classified trace.
        self._prepared_specs: List[tuple] = []

        for spec in self._classifier_specs:
            abi = get_abi(spec.abi_name, spec.protocol)
            if abi is None:
                raise ValueError(f"No ABI found for {spec.abi_name}")

            decoder = ABIDecoder(abi)
            self._decoders_by_abi_name[spec.abi_name] = decoder

            valid_addresses = None
            if spec.valid_contract_addresses is not None:
                valid_addresses = frozenset(
                    address.lower() for address in spec.valid_contract_addresses
                )

            self._prepared_specs.append((spec, decoder, valid_addresses))

    def classify(
        self,
        traces: List[Trace],
    ) -> List[ClassifiedTrace]:
        """Classify every trace except those of type ``TraceType.reward``.

        Args:
            traces: Traces to classify.

        Returns:
            One classified trace per non-reward input trace, in input order.
        """
        return [
            self._classify_trace(trace)
            for trace in traces
            if trace.type != TraceType.reward
        ]

    def _classify_trace(self, trace: Trace) -> ClassifiedTrace:
        """Classify a single trace.

        Call traces are delegated to ``_classify_call``. Any other trace (or a
        call for which ``_classify_call`` returns ``None``) is wrapped in a
        ``ClassifiedTrace`` with ``Classification.unknown``.
        """
        if trace.type == TraceType.call:
            classified_trace = self._classify_call(trace)
            if classified_trace is not None:
                return classified_trace

        return ClassifiedTrace(
            **trace.dict(),
            trace_type=trace.type,
            classification=Classification.unknown,
        )

    def _classify_call(self, trace: Trace) -> Optional[ClassifiedTrace]:
        """Classify a call trace by decoding its input with each spec in turn.

        A spec that restricts contract addresses is skipped unless the call
        target is one of them. The first remaining spec whose decoder returns
        call data yields a ``DecodedCallTrace``; its classification comes from
        the spec's classifier for the decoded function signature, or
        ``Classification.unknown`` when the spec has no classifier for it.

        Returns:
            A ``DecodedCallTrace`` when some spec decodes the input, otherwise
            a ``CallTrace`` with ``Classification.unknown``. As written, this
            method never returns ``None``; the ``Optional`` return type is
            kept for signature compatibility.
        """
        action = CallAction(**trace.action)
        result = CallResult(**trace.result) if trace.result is not None else None
        gas_used = result.gas_used if result is not None else None

        for spec, decoder, valid_addresses in self._prepared_specs:
            # NOTE(review): action.to is compared as-is against lowercased
            # addresses; presumably it is already lowercase -- confirm.
            if valid_addresses is not None and action.to not in valid_addresses:
                continue

            call_data = decoder.decode(action.input)
            if call_data is None:
                # This spec's ABI does not describe the call; try the next.
                continue

            signature = call_data.function_signature
            classifier = spec.classifiers.get(signature)
            classification = (
                Classification.unknown
                if classifier is None
                else classifier.get_classification()
            )

            return DecodedCallTrace(
                **trace.dict(),
                trace_type=trace.type,
                classification=classification,
                protocol=spec.protocol,
                abi_name=spec.abi_name,
                function_name=call_data.function_name,
                function_signature=signature,
                inputs=call_data.inputs,
                to_address=action.to,
                from_address=action.from_,
                value=action.value,
                gas=action.gas,
                gas_used=gas_used,
            )

        # No spec could decode the call input.
        return CallTrace(
            **trace.dict(),
            trace_type=trace.type,
            classification=Classification.unknown,
            to_address=action.to,
            from_address=action.from_,
            value=action.value,
            gas=action.gas,
            gas_used=gas_used,
        )