Use decode specs to parse inputs

This commit is contained in:
Luke Van Seters 2021-07-25 20:50:37 -04:00
parent de733ecfb3
commit c2dc095c7d
3 changed files with 122 additions and 13 deletions

View File

@ -1,28 +1,87 @@
from typing import List from typing import Dict, List, Optional
from mev_inspect.schemas.blocks import Block, TraceType from mev_inspect.abi import get_abi
from mev_inspect.decode import ABIDecoder
from mev_inspect.schemas.blocks import Block, Trace, TraceType
from mev_inspect.schemas.classifications import ( from mev_inspect.schemas.classifications import (
Classification, Classification,
ClassificationType, ClassificationType,
DecodeSpec,
) )
class Processor: class Processor:
def __init__(self) -> None: def __init__(self, decode_specs: List[DecodeSpec]) -> None:
pass # TODO - index by contract_addresses for speed
self._decode_specs = decode_specs
self._decoders_by_abi_name: Dict[str, ABIDecoder] = {}
for spec in self._decode_specs:
abi = get_abi(spec.abi_name)
if abi is None:
raise ValueError(f"No ABI found for {spec.abi_name}")
decoder = ABIDecoder(abi)
self._decoders_by_abi_name[spec.abi_name] = decoder
def process( def process(
self, self,
block: Block, block: Block,
) -> List[Classification]: ) -> List[Classification]:
return [ return [
Classification( self._classify(trace)
transaction_hash=trace.transaction_hash,
block_number=trace.block_number,
trace_type=trace.type,
trace_address=trace.trace_address,
classification_type=ClassificationType.unknown,
)
for trace in block.traces for trace in block.traces
if trace.type != TraceType.reward if trace.type != TraceType.reward
] ]
def _classify(self, trace: Trace) -> Classification:
if trace.type == TraceType.call:
classification = self._classify_call(trace)
if classification is not None:
return classification
return self._build_unknown_classification(trace)
def _classify_call(self, trace) -> Optional[Classification]:
to_address = trace.action["to"]
for spec in self._decode_specs:
if spec.valid_contract_addresses is not None:
if to_address is None:
continue
if to_address not in spec.valid_contract_addresses:
continue
decoder = self._decoders_by_abi_name[spec.abi_name]
call_data = decoder.decode(trace.action["input"])
if call_data is not None:
return Classification(
transaction_hash=trace.transaction_hash,
block_number=trace.block_number,
trace_type=trace.type,
trace_address=trace.trace_address,
classification_type=ClassificationType.unknown,
protocol=spec.protocol,
function_name=call_data.function_name,
function_signature=call_data.function_signature,
intputs=call_data.inputs,
)
return None
@staticmethod
def _build_unknown_classification(trace):
return Classification(
transaction_hash=trace.transaction_hash,
block_number=trace.block_number,
trace_type=trace.type,
trace_address=trace.trace_address,
classification_type=ClassificationType.unknown,
protocol=None,
function_name=None,
function_signature=None,
intputs=None,
)

View File

@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import List from typing import Any, Dict, List, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -8,6 +8,12 @@ from .blocks import TraceType
class ClassificationType(Enum): class ClassificationType(Enum):
unknown = "unknown" unknown = "unknown"
swap = "swap"
class Protocol(Enum):
uniswap_v2 = "uniswap_v2"
sushiswap = "sushiswap"
class Classification(BaseModel): class Classification(BaseModel):
@ -16,3 +22,13 @@ class Classification(BaseModel):
trace_type: TraceType trace_type: TraceType
trace_address: List[int] trace_address: List[int]
classification_type: ClassificationType classification_type: ClassificationType
protocol: Optional[Protocol]
function_name: Optional[str]
function_signature: Optional[str]
inputs: Optional[Dict[str, Any]]
class DecodeSpec(BaseModel):
abi_name: str
protocol: Optional[Protocol] = None
valid_contract_addresses: Optional[List[str]] = None

View File

@ -1,9 +1,30 @@
import argparse import argparse
import json
from web3 import Web3 from web3 import Web3
from mev_inspect import block from mev_inspect import block
from mev_inspect.processor import Processor from mev_inspect.processor import Processor
from mev_inspect.schemas.classifications import DecodeSpec, Protocol
SUSHISWAP_ROUTER_ADDRESS = "0xd9e1cE17f2641f24aE83637ab66a2cca9C378B9F"
UNISWAP_V2_ROUTER_ADDRESS = "0x7a250d5630B4cF539739dF2C5dAcb4c659F2488D"
DECODE_SPECS = [
DecodeSpec(
abi_name="UniswapV2Router",
protocol=Protocol.uniswap_v2,
valid_contract_addresses=[UNISWAP_V2_ROUTER_ADDRESS],
),
DecodeSpec(
abi_name="UniswapV2Router",
protocol=Protocol.sushiswap,
valid_contract_addresses=[SUSHISWAP_ROUTER_ADDRESS],
),
DecodeSpec(abi_name="UniswapV2Pair"),
]
def inspect_block(base_provider, block_number): def inspect_block(base_provider, block_number):
@ -19,11 +40,24 @@ def inspect_block(base_provider, block_number):
) )
print(f"Total transactions: {total_transactions}") print(f"Total transactions: {total_transactions}")
processor = Processor() processor = Processor(DECODE_SPECS)
classifications = processor.process(block_data) classifications = processor.process(block_data)
print(f"Returned {len(classifications)} classifications") print(f"Returned {len(classifications)} classifications")
stats = {}
for classification in classifications:
protocol = classification.protocol
signature = classification.function_signature
protocol_stats = stats.get(protocol, {})
signature_count = protocol_stats.get(signature, 0)
protocol_stats[signature] = signature_count + 1
stats[protocol] = protocol_stats
print(json.dumps(dict(stats.items()), indent=4))
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Inspect some blocks.") parser = argparse.ArgumentParser(description="Inspect some blocks.")