diff --git a/mev_inspect/processor.py b/mev_inspect/processor.py index 6bc00f3..7e6bba8 100644 --- a/mev_inspect/processor.py +++ b/mev_inspect/processor.py @@ -1,28 +1,87 @@ -from typing import List +from typing import Dict, List, Optional -from mev_inspect.schemas.blocks import Block, TraceType +from mev_inspect.abi import get_abi +from mev_inspect.decode import ABIDecoder +from mev_inspect.schemas.blocks import Block, Trace, TraceType from mev_inspect.schemas.classifications import ( Classification, ClassificationType, + DecodeSpec, ) class Processor: - def __init__(self) -> None: - pass + def __init__(self, decode_specs: List[DecodeSpec]) -> None: + # TODO - index by contract_addresses for speed + self._decode_specs = decode_specs + self._decoders_by_abi_name: Dict[str, ABIDecoder] = {} + + for spec in self._decode_specs: + abi = get_abi(spec.abi_name) + + if abi is None: + raise ValueError(f"No ABI found for {spec.abi_name}") + + decoder = ABIDecoder(abi) + self._decoders_by_abi_name[spec.abi_name] = decoder def process( self, block: Block, ) -> List[Classification]: return [ - Classification( - transaction_hash=trace.transaction_hash, - block_number=trace.block_number, - trace_type=trace.type, - trace_address=trace.trace_address, - classification_type=ClassificationType.unknown, - ) + self._classify(trace) for trace in block.traces if trace.type != TraceType.reward ] + + def _classify(self, trace: Trace) -> Classification: + if trace.type == TraceType.call: + classification = self._classify_call(trace) + if classification is not None: + return classification + + return self._build_unknown_classification(trace) + + def _classify_call(self, trace) -> Optional[Classification]: + to_address = trace.action["to"] + + for spec in self._decode_specs: + if spec.valid_contract_addresses is not None: + if to_address is None: + continue + + if to_address not in spec.valid_contract_addresses: + continue + + decoder = self._decoders_by_abi_name[spec.abi_name] + call_data = decoder.decode(trace.action["input"]) + + if call_data is not None: + return Classification( + transaction_hash=trace.transaction_hash, + block_number=trace.block_number, + trace_type=trace.type, + trace_address=trace.trace_address, + classification_type=ClassificationType.unknown, + protocol=spec.protocol, + function_name=call_data.function_name, + function_signature=call_data.function_signature, + intputs=call_data.inputs, + ) + + return None + + @staticmethod + def _build_unknown_classification(trace): + return Classification( + transaction_hash=trace.transaction_hash, + block_number=trace.block_number, + trace_type=trace.type, + trace_address=trace.trace_address, + classification_type=ClassificationType.unknown, + protocol=None, + function_name=None, + function_signature=None, + intputs=None, + ) diff --git a/mev_inspect/schemas/classifications.py b/mev_inspect/schemas/classifications.py index 719e5dc..8e92881 100644 --- a/mev_inspect/schemas/classifications.py +++ b/mev_inspect/schemas/classifications.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List +from typing import Any, Dict, List, Optional from pydantic import BaseModel @@ -8,6 +8,12 @@ from .blocks import TraceType class ClassificationType(Enum): unknown = "unknown" + swap = "swap" + + +class Protocol(Enum): + uniswap_v2 = "uniswap_v2" + sushiswap = "sushiswap" class Classification(BaseModel): @@ -16,3 +22,13 @@ class Classification(BaseModel): trace_type: TraceType trace_address: List[int] classification_type: ClassificationType + protocol: Optional[Protocol] + function_name: Optional[str] + function_signature: Optional[str] + inputs: Optional[Dict[str, Any]] + + +class DecodeSpec(BaseModel): + abi_name: str + protocol: Optional[Protocol] = None + valid_contract_addresses: Optional[List[str]] = None diff --git a/testing_file.py b/testing_file.py index 350009f..c3f773f 100644 --- a/testing_file.py +++ b/testing_file.py @@ -1,9 +1,30 @@ import argparse +import json from web3 import Web3 from mev_inspect import block from mev_inspect.processor import Processor +from mev_inspect.schemas.classifications import DecodeSpec, Protocol + + +SUSHISWAP_ROUTER_ADDRESS = "0xd9e1cE17f2641f24aE83637ab66a2cca9C378B9F" +UNISWAP_V2_ROUTER_ADDRESS = "0x7a250d5630B4cF539739dF2C5dAcb4c659F2488D" + + +DECODE_SPECS = [ + DecodeSpec( + abi_name="UniswapV2Router", + protocol=Protocol.uniswap_v2, + valid_contract_addresses=[UNISWAP_V2_ROUTER_ADDRESS], + ), + DecodeSpec( + abi_name="UniswapV2Router", + protocol=Protocol.sushiswap, + valid_contract_addresses=[SUSHISWAP_ROUTER_ADDRESS], + ), + DecodeSpec(abi_name="UniswapV2Pair"), +] def inspect_block(base_provider, block_number): @@ -19,11 +40,24 @@ def inspect_block(base_provider, block_number): ) print(f"Total transactions: {total_transactions}") - processor = Processor() + processor = Processor(DECODE_SPECS) classifications = processor.process(block_data) print(f"Returned {len(classifications)} classifications") + stats = {} + + for classification in classifications: + protocol = classification.protocol + signature = classification.function_signature + + protocol_stats = stats.get(protocol, {}) + signature_count = protocol_stats.get(signature, 0) + protocol_stats[signature] = signature_count + 1 + stats[protocol] = protocol_stats + + print(json.dumps(dict(stats.items()), indent=4)) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Inspect some blocks.")