Merge pull request #27 from lukevs/add-inspectors-and-classifications

Add inspectors and classifications
This commit is contained in:
Robert Miller 2021-07-22 13:05:20 -04:00 committed by GitHub
commit 674e4a1c6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 148 additions and 32 deletions

36
mev_inspect/decode.py Normal file
View File

@ -0,0 +1,36 @@
from typing import Dict, Optional
from hexbytes import HexBytes
from eth_abi import decode_abi
from mev_inspect.schemas.abi import ABI, ABIFunctionDescription
from mev_inspect.schemas.call_data import CallData
class ABIDecoder:
def __init__(self, abi: ABI):
self._functions_by_selector: Dict[str, ABIFunctionDescription] = {
description.get_selector(): description
for description in abi
if isinstance(description, ABIFunctionDescription)
}
def decode(self, data: str) -> Optional[CallData]:
hex_data = HexBytes(data)
selector, params = hex_data[:4], hex_data[4:]
func = self._functions_by_selector.get(selector)
if func is None:
return None
names = [input.name for input in func.inputs]
types = [input.type for input in func.inputs]
decoded = decode_abi(types, params)
return CallData(
function_name=func.name,
function_signature=func.get_signature(),
inputs={name: value for name, value in zip(names, decoded)},
)

View File

@ -0,0 +1 @@
from .base import Inspector

View File

@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
from typing import Optional
from mev_inspect.schemas.blocks import NestedTrace
from mev_inspect.schemas.classifications import Classification
class Inspector(ABC):
@abstractmethod
def inspect(self, nested_trace: NestedTrace) -> Optional[Classification]:
pass

View File

@ -1,9 +1,14 @@
import json import json
from typing import Optional
from web3 import Web3 from web3 import Web3
from mev_inspect import utils from mev_inspect import utils
from mev_inspect.config import load_config from mev_inspect.config import load_config
from mev_inspect.schemas.blocks import NestedTrace, TraceType
from mev_inspect.schemas.classifications import Classification
from .base import Inspector
config = load_config() config = load_config()
@ -14,7 +19,7 @@ sushiswap_router_address = config["ADDRESSES"]["SushiswapV2Router"]
uniswap_pair_abi = json.loads(config["ABI"]["UniswapV2Pair"]) uniswap_pair_abi = json.loads(config["ABI"]["UniswapV2Pair"])
class UniswapInspector: class UniswapInspector(Inspector):
def __init__(self, base_provider) -> None: def __init__(self, base_provider) -> None:
self.w3 = Web3(base_provider) self.w3 = Web3(base_provider)
@ -79,18 +84,20 @@ class UniswapInspector:
return result return result
def inspect(self, calls): def inspect(self, nested_trace: NestedTrace) -> Optional[Classification]:
for call in calls: trace = nested_trace.trace
print("\n", call)
if ( if (
call["type"] == "call" trace.type == TraceType.call
and ( and (
call["action"]["to"] == uniswap_router_address.lower() trace.action["to"] == uniswap_router_address.lower()
or call["action"]["to"] == sushiswap_router_address.lower() or trace.action["to"] == sushiswap_router_address.lower()
) )
and utils.check_trace_for_signature( and utils.check_trace_for_signature(
call, self.uniswap_router_trade_signatures trace, self.uniswap_router_trade_signatures
) )
): ):
# print("WIP, here is where there is a call that matches what we are looking for") # print("WIP, here is where there is a call that matches what we are looking for")
1 == 1 1 == 1
return None

View File

@ -1,15 +1,43 @@
from mev_inspect.schemas.utils import to_original_json_dict from typing import List
from mev_inspect.inspectors import Inspector
from mev_inspect.schemas.blocks import Block, NestedTrace, TraceType
from mev_inspect.schemas.classifications import (
Classification,
UnknownClassification,
)
from mev_inspect.traces import as_nested_traces
class Processor: class Processor:
def __init__(self, base_provider, inspectors) -> None: def __init__(self, inspectors: List[Inspector]) -> None:
self.base_provider = base_provider self._inspectors = inspectors
self.inspectors = inspectors
def get_transaction_evaluations(self, block_data): def get_transaction_evaluations(
for transaction_hash in block_data.transaction_hashes: self,
traces = block_data.get_filtered_traces(transaction_hash) block: Block,
traces_json = [to_original_json_dict(trace) for trace in traces] ) -> List[Classification]:
transaction_traces = (
trace for trace in block.traces if trace.type != TraceType.reward
)
for inspector in self.inspectors: return [
inspector.inspect(traces_json) self._run_inspectors(nested_trace)
for nested_trace in as_nested_traces(transaction_traces)
]
def _run_inspectors(self, nested_trace: NestedTrace) -> Classification:
for inspector in self._inspectors:
classification = inspector.inspect(nested_trace)
if classification is not None:
return classification
internal_classifications = [
self._run_inspectors(subtrace) for subtrace in nested_trace.subtraces
]
return UnknownClassification(
trace=nested_trace.trace,
internal_classifications=internal_classifications,
)

View File

@ -24,6 +24,7 @@ NON_FUNCTION_DESCRIPTION_TYPES = Union[
class ABIDescriptionInput(BaseModel): class ABIDescriptionInput(BaseModel):
name: str
type: str type: str

View File

@ -0,0 +1,9 @@
from typing import Any, Dict
from pydantic import BaseModel
class CallData(BaseModel):
function_name: str
function_signature: str
inputs: Dict[str, Any]

View File

@ -0,0 +1,14 @@
from typing import List
from pydantic import BaseModel
from .blocks import Trace
class Classification(BaseModel):
pass
class UnknownClassification(Classification):
trace: Trace
internal_classifications: List[Classification]

View File

@ -2,14 +2,16 @@ from typing import List
from hexbytes.main import HexBytes from hexbytes.main import HexBytes
from mev_inspect.schemas.blocks import Trace
def check_trace_for_signature(trace: dict, signatures: List[str]):
if trace["action"]["input"] == None: def check_trace_for_signature(trace: Trace, signatures: List[str]):
if trace.action["input"] == None:
return False return False
## Iterate over all signatures, and if our trace matches any of them set it to True ## Iterate over all signatures, and if our trace matches any of them set it to True
for signature in signatures: for signature in signatures:
if HexBytes(trace["action"]["input"]).startswith(signature): if HexBytes(trace.action["input"]).startswith(signature):
## Note that we are turning the input into hex bytes here, which seems to be fine ## Note that we are turning the input into hex bytes here, which seems to be fine
## Working with strings was doing weird things ## Working with strings was doing weird things
return True return True

View File

@ -3,7 +3,7 @@ import argparse
from web3 import Web3 from web3 import Web3
from mev_inspect import block from mev_inspect import block
from mev_inspect.inspector_uniswap import UniswapInspector from mev_inspect.inspectors.uniswap import UniswapInspector
from mev_inspect.processor import Processor from mev_inspect.processor import Processor
parser = argparse.ArgumentParser(description="Inspect some blocks.") parser = argparse.ArgumentParser(description="Inspect some blocks.")
@ -24,11 +24,18 @@ base_provider = Web3.HTTPProvider(args.rpc)
## Get block data that we need ## Get block data that we need
block_data = block.create_from_block_number(args.block_number[0], base_provider) block_data = block.create_from_block_number(args.block_number[0], base_provider)
print(f"Total traces: {len(block_data.traces)}")
total_transactions = len(
set(t.transaction_hash for t in block_data.traces if t.transaction_hash is not None)
)
print(f"Total transactions: {total_transactions}")
## Build a Uniswap inspector ## Build a Uniswap inspector
uniswap_inspector = UniswapInspector(base_provider) uniswap_inspector = UniswapInspector(base_provider)
## Create a processor, pass in an ARRAY of inspects ## Create a processor, pass in an ARRAY of inspects
processor = Processor(base_provider, [uniswap_inspector, uniswap_inspector]) processor = Processor([uniswap_inspector, uniswap_inspector])
processor.get_transaction_evaluations(block_data) classifications = processor.get_transaction_evaluations(block_data)
print(f"Returned {len(classifications)} classifications")