diff --git a/examples/uniswap_inspect.py b/examples/uniswap_inspect.py deleted file mode 100644 index 080284f..0000000 --- a/examples/uniswap_inspect.py +++ /dev/null @@ -1,41 +0,0 @@ -import argparse - -from web3 import Web3 - -from mev_inspect import block -from mev_inspect.inspectors.uniswap import UniswapInspector -from mev_inspect.processor import Processor - -parser = argparse.ArgumentParser(description="Inspect some blocks.") -parser.add_argument( - "-block_number", - metavar="b", - type=int, - nargs="+", - help="the block number you are targetting, eventually this will need to be changed", -) -parser.add_argument( - "-rpc", metavar="r", help="rpc endpoint, this needs to have parity style traces" -) -args = parser.parse_args() - -## Set up the base provider, but don't wrap it in web3 so we can make requests to it with make_request() -base_provider = Web3.HTTPProvider(args.rpc) - -## Get block data that we need -block_data = block.create_from_block_number(args.block_number[0], base_provider) -print(f"Total traces: {len(block_data.traces)}") - -total_transactions = len( - set(t.transaction_hash for t in block_data.traces if t.transaction_hash is not None) -) -print(f"Total transactions: {total_transactions}") - -## Build a Uniswap inspector -uniswap_inspector = UniswapInspector(base_provider) - -## Create a processor, pass in an ARRAY of inspects -processor = Processor([uniswap_inspector, uniswap_inspector]) - -classifications = processor.get_transaction_evaluations(block_data) -print(f"Returned {len(classifications)} classifications") diff --git a/mev_inspect/classifier_specs.py b/mev_inspect/classifier_specs.py new file mode 100644 index 0000000..aef1917 --- /dev/null +++ b/mev_inspect/classifier_specs.py @@ -0,0 +1,37 @@ +from mev_inspect.schemas.classified_traces import ( + Classification, + ClassifierSpec, + Protocol, +) + + +SUSHISWAP_ROUTER_ADDRESS = "0xd9e1cE17f2641f24aE83637ab66a2cca9C378B9F" +UNISWAP_V2_ROUTER_ADDRESS = "0x7a250d5630B4cF539739dF2C5dAcb4c659F2488D" + + +CLASSIFIER_SPECS = [ + ClassifierSpec( + abi_name="UniswapV2Router", + protocol=Protocol.uniswap_v2, + valid_contract_addresses=[UNISWAP_V2_ROUTER_ADDRESS], + ), + ClassifierSpec( + abi_name="UniswapV2Router", + protocol=Protocol.sushiswap, + valid_contract_addresses=[SUSHISWAP_ROUTER_ADDRESS], + ), + ClassifierSpec( + abi_name="ERC20", + classifications={ + "transferFrom(address,address,uint256)": Classification.transfer, + "transfer(address,uint256)": Classification.transfer, + "burn(address)": Classification.burn, + }, + ), + ClassifierSpec( + abi_name="UniswapV2Pair", + classifications={ + "swap(uint256,uint256,address,bytes)": Classification.swap, + }, + ), +] diff --git a/mev_inspect/crud/__init__.py b/mev_inspect/crud/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mev_inspect/crud/classified_traces.py b/mev_inspect/crud/classified_traces.py new file mode 100644 index 0000000..a175d82 --- /dev/null +++ b/mev_inspect/crud/classified_traces.py @@ -0,0 +1,17 @@ +import json +from typing import List + +from mev_inspect.models.classified_traces import ClassifiedTraceModel +from mev_inspect.schemas.classified_traces import ClassifiedTrace + + +def write_classified_traces( + db_connection, + classified_traces: List[ClassifiedTrace], +) -> None: + models = [ + ClassifiedTraceModel(**json.loads(trace.json())) for trace in classified_traces + ] + + db_connection.bulk_save_objects(models) + db_connection.commit() diff --git a/mev_inspect/db.py b/mev_inspect/db.py new file mode 100644 index 0000000..862fe96 --- /dev/null +++ b/mev_inspect/db.py @@ -0,0 +1,13 @@ +import os + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + + +def get_engine(): + return create_engine(os.getenv("SQLALCHEMY_DATABASE_URI")) + + +def get_session(): + Session = sessionmaker(bind=get_engine()) + return Session() diff --git a/mev_inspect/inspectors/__init__.py b/mev_inspect/inspectors/__init__.py deleted file mode 100644 index 4ece86a..0000000 --- a/mev_inspect/inspectors/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base import Inspector diff --git a/mev_inspect/inspectors/base.py b/mev_inspect/inspectors/base.py deleted file mode 100644 index 2bb315f..0000000 --- a/mev_inspect/inspectors/base.py +++ /dev/null @@ -1,11 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional - -from mev_inspect.schemas.blocks import NestedTrace -from mev_inspect.schemas.classifications import Classification - - -class Inspector(ABC): - @abstractmethod - def inspect(self, nested_trace: NestedTrace) -> Optional[Classification]: - pass diff --git a/mev_inspect/inspectors/uniswap.py b/mev_inspect/inspectors/uniswap.py deleted file mode 100644 index 20aacda..0000000 --- a/mev_inspect/inspectors/uniswap.py +++ /dev/null @@ -1,103 +0,0 @@ -import json -from typing import Optional - -from web3 import Web3 - -from mev_inspect import utils -from mev_inspect.config import load_config -from mev_inspect.schemas.blocks import NestedTrace, TraceType -from mev_inspect.schemas.classifications import Classification - -from .base import Inspector - -config = load_config() - -uniswap_router_abi = json.loads(config["ABI"]["UniswapV2Router"]) -uniswap_router_address = config["ADDRESSES"]["UniswapV2Router"] -sushiswap_router_address = config["ADDRESSES"]["SushiswapV2Router"] - -uniswap_pair_abi = json.loads(config["ABI"]["UniswapV2Pair"]) - - -class UniswapInspector(Inspector): - def __init__(self, base_provider) -> None: - self.w3 = Web3(base_provider) - - self.trading_functions = self.get_trading_functions() - self.uniswap_v2_router_contract = self.w3.eth.contract( - abi=uniswap_router_abi, address=uniswap_router_address - ) - self.uniswap_router_trade_signatures = self.get_router_signatures() - - self.uniswap_v2_pair_contract = self.w3.eth.contract(abi=uniswap_pair_abi) - self.uniswap_v2_pair_swap_signatures = ( - self.uniswap_v2_pair_contract.functions.swap( - 0, 0, uniswap_router_address, "" - ).selector - ) ## Note the address here doesn't matter, but it must be filled out - self.uniswap_v2_pair_reserves_signatures = ( - self.uniswap_v2_pair_contract.functions.getReserves().selector - ) ## Called "checksigs" in mev-inpsect.ts - - print("Built Uniswap inspector") - - def get_trading_functions(self): - ## Gets all functions used for swapping - result = [] - - ## For each entry in the ABI - for abi in uniswap_router_abi: - ## Check to see if the entry is a function and if it is if the function's name starts with swap - if abi["type"] == "function" and abi["name"].startswith("swap"): - ## If so add it to our array - result.append(abi["name"]) - - return result - - def get_router_signatures(self): - ## Gets the selector / function signatures of all the router swap functions - result = [] - - ## For each entry in the ABI - for abi in uniswap_router_abi: - ## Check to see if the entry is a function and if it is if the function's name starts with swap - if abi["type"] == "function" and abi["name"].startswith("swap"): - ## Add a parantheses - function = abi["name"] + "(" - - ## For each input in the function's input - for input in abi["inputs"]: - - ## Concat them into a string with commas - function = function + input["internalType"] + "," - - ## Take off the last comma, add a ')' to close the parentheses - function = function[:-1] + ")" - - ## The result looks like this: 'swapETHForExactTokens(uint256,address[],address,uint256)' - - ## Take the first 4 bytes of the sha3 hash of the above string. - selector = Web3.sha3(text=function)[0:4] - - ## Add that to an array - result.append(selector) - - return result - - def inspect(self, nested_trace: NestedTrace) -> Optional[Classification]: - trace = nested_trace.trace - - if ( - trace.type == TraceType.call - and ( - trace.action["to"] == uniswap_router_address.lower() - or trace.action["to"] == sushiswap_router_address.lower() - ) - and utils.check_trace_for_signature( - trace, self.uniswap_router_trade_signatures - ) - ): - # print("WIP, here is where there is a call that matches what we are looking for") - 1 == 1 - - return None diff --git a/mev_inspect/models/__init__.py b/mev_inspect/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mev_inspect/models/base.py b/mev_inspect/models/base.py new file mode 100644 index 0000000..860e542 --- /dev/null +++ b/mev_inspect/models/base.py @@ -0,0 +1,3 @@ +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() diff --git a/mev_inspect/models/classified_traces.py b/mev_inspect/models/classified_traces.py new file mode 100644 index 0000000..bf898e0 --- /dev/null +++ b/mev_inspect/models/classified_traces.py @@ -0,0 +1,24 @@ +from sqlalchemy import Column, JSON, Numeric, String + +from .base import Base + + +class ClassifiedTraceModel(Base): + __tablename__ = "classified_traces" + + transaction_hash = Column(String, primary_key=True) + block_number = Column(Numeric, nullable=False) + classification = Column(String, nullable=False) + trace_type = Column(String, nullable=False) + trace_address = Column(String, nullable=False) + protocol = Column(String, nullable=True) + abi_name = Column(String, nullable=True) + function_name = Column(String, nullable=True) + function_signature = Column(String, nullable=True) + inputs = Column(JSON, nullable=True) + from_address = Column(String, nullable=True) + to_address = Column(String, nullable=True) + gas = Column(Numeric, nullable=True) + value = Column(Numeric, nullable=True) + gas_used = Column(Numeric, nullable=True) + error = Column(String, nullable=True) diff --git a/mev_inspect/processor.py b/mev_inspect/processor.py deleted file mode 100644 index 6efa84d..0000000 --- a/mev_inspect/processor.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import List - -from mev_inspect.inspectors import Inspector -from mev_inspect.schemas.blocks import Block, NestedTrace, TraceType -from mev_inspect.schemas.classifications import ( - Classification, - UnknownClassification, -) -from mev_inspect.traces import as_nested_traces - - -class Processor: - def __init__(self, inspectors: List[Inspector]) -> None: - self._inspectors = inspectors - - def get_transaction_evaluations( - self, - block: Block, - ) -> List[Classification]: - transaction_traces = ( - trace for trace in block.traces if trace.type != TraceType.reward - ) - - return [ - self._run_inspectors(nested_trace) - for nested_trace in as_nested_traces(transaction_traces) - ] - - def _run_inspectors(self, nested_trace: NestedTrace) -> Classification: - for inspector in self._inspectors: - classification = inspector.inspect(nested_trace) - - if classification is not None: - return classification - - internal_classifications = [ - self._run_inspectors(subtrace) for subtrace in nested_trace.subtraces - ] - - return UnknownClassification( - trace=nested_trace.trace, - internal_classifications=internal_classifications, - ) diff --git a/mev_inspect/schemas/blocks.py b/mev_inspect/schemas/blocks.py index d26e8a6..b6e3635 100644 --- a/mev_inspect/schemas/blocks.py +++ b/mev_inspect/schemas/blocks.py @@ -1,11 +1,39 @@ from enum import Enum from typing import Dict, List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, validator +from mev_inspect.utils import hex_to_int from .utils import CamelModel, Web3Model +class CallResult(CamelModel): + gas_used: int + + @validator("gas_used", pre=True) + def maybe_hex_to_int(v): + if isinstance(v, str): + return hex_to_int(v) + return v + + +class CallAction(Web3Model): + to: str + from_: str + input: str + value: int + gas: int + + @validator("value", "gas", pre=True) + def maybe_hex_to_int(v): + if isinstance(v, str): + return hex_to_int(v) + return v + + class Config: + fields = {"from_": "from"} + + class TraceType(Enum): call = "call" create = "create" diff --git a/mev_inspect/schemas/classifications.py b/mev_inspect/schemas/classifications.py deleted file mode 100644 index 3e0875f..0000000 --- a/mev_inspect/schemas/classifications.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import List - -from pydantic import BaseModel - -from .blocks import Trace - - -class Classification(BaseModel): - pass - - -class UnknownClassification(Classification): - trace: Trace - internal_classifications: List[Classification] diff --git a/mev_inspect/schemas/classified_traces.py b/mev_inspect/schemas/classified_traces.py new file mode 100644 index 0000000..f867508 --- /dev/null +++ b/mev_inspect/schemas/classified_traces.py @@ -0,0 +1,51 @@ +from enum import Enum +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel + +from .blocks import TraceType + + +class Classification(Enum): + unknown = "unknown" + swap = "swap" + burn = "burn" + transfer = "transfer" + + +class Protocol(Enum): + uniswap_v2 = "uniswap_v2" + sushiswap = "sushiswap" + + +class ClassifiedTrace(BaseModel): + transaction_hash: str + block_number: int + trace_type: TraceType + trace_address: List[int] + classification: Classification + protocol: Optional[Protocol] + abi_name: Optional[str] + function_name: Optional[str] + function_signature: Optional[str] + inputs: Optional[Dict[str, Any]] + to_address: Optional[str] + from_address: Optional[str] + gas: Optional[int] + value: Optional[int] + gas_used: Optional[int] + error: Optional[str] + + class Config: + json_encoders = { + # a little lazy but fine for now + # this is used for bytes value inputs + bytes: lambda b: b.hex(), + } + + +class ClassifierSpec(BaseModel): + abi_name: str + protocol: Optional[Protocol] = None + valid_contract_addresses: Optional[List[str]] = None + classifications: Dict[str, Classification] = {} diff --git a/mev_inspect/trace_classifier.py b/mev_inspect/trace_classifier.py new file mode 100644 index 0000000..9dfd816 --- /dev/null +++ b/mev_inspect/trace_classifier.py @@ -0,0 +1,93 @@ +from typing import Dict, List, Optional + +from mev_inspect.abi import get_abi +from mev_inspect.decode import ABIDecoder +from mev_inspect.schemas.blocks import CallAction, CallResult, Trace, TraceType +from mev_inspect.schemas.classified_traces import ( + Classification, + ClassifiedTrace, + ClassifierSpec, +) + + +class TraceClassifier: + def __init__(self, classifier_specs: List[ClassifierSpec]) -> None: + # TODO - index by contract_addresses for speed + self._classifier_specs = classifier_specs + self._decoders_by_abi_name: Dict[str, ABIDecoder] = {} + + for spec in self._classifier_specs: + abi = get_abi(spec.abi_name) + + if abi is None: + raise ValueError(f"No ABI found for {spec.abi_name}") + + decoder = ABIDecoder(abi) + self._decoders_by_abi_name[spec.abi_name] = decoder + + def classify( + self, + traces: List[Trace], + ) -> List[ClassifiedTrace]: + return [ + self._classify_trace(trace) + for trace in traces + if trace.type != TraceType.reward + ] + + def _classify_trace(self, trace: Trace) -> ClassifiedTrace: + if trace.type == TraceType.call: + classified_trace = self._classify_call(trace) + if classified_trace is not None: + return classified_trace + + return ClassifiedTrace( + **trace.dict(), + trace_type=trace.type, + classification=Classification.unknown, + ) + + def _classify_call(self, trace) -> Optional[ClassifiedTrace]: + action = CallAction(**trace.action) + result = CallResult(**trace.result) if trace.result is not None else None + + for spec in self._classifier_specs: + if spec.valid_contract_addresses is not None: + if action.to not in spec.valid_contract_addresses: + continue + + decoder = self._decoders_by_abi_name[spec.abi_name] + call_data = decoder.decode(action.input) + + if call_data is not None: + signature = call_data.function_signature + classification = spec.classifications.get( + signature, Classification.unknown + ) + + return ClassifiedTrace( + **trace.dict(), + trace_type=trace.type, + classification=classification, + protocol=spec.protocol, + abi_name=spec.abi_name, + function_name=call_data.function_name, + function_signature=signature, + inputs=call_data.inputs, + to_address=action.to, + from_address=action.from_, + value=action.value, + gas=action.gas, + gas_used=result.gas_used if result is not None else None, + ) + + return ClassifiedTrace( + **trace.dict(), + trace_type=trace.type, + classification=Classification.unknown, + to_address=action.to, + from_address=action.from_, + value=action.value, + gas=action.gas, + gas_used=result.gas_used if result is not None else None, + ) diff --git a/mev_inspect/utils.py b/mev_inspect/utils.py index 0f7c094..10d331a 100644 --- a/mev_inspect/utils.py +++ b/mev_inspect/utils.py @@ -1,19 +1,5 @@ -from typing import List - from hexbytes.main import HexBytes -from mev_inspect.schemas.blocks import Trace - -def check_trace_for_signature(trace: Trace, signatures: List[str]): - if trace.action["input"] == None: - return False - - ## Iterate over all signatures, and if our trace matches any of them set it to True - for signature in signatures: - if HexBytes(trace.action["input"]).startswith(signature): - ## Note that we are turning the input into hex bytes here, which seems to be fine - ## Working with strings was doing weird things - return True - - return False +def hex_to_int(value: str) -> int: + return int.from_bytes(HexBytes(value), byteorder="big") diff --git a/run.py b/run.py new file mode 100644 index 0000000..065f7d2 --- /dev/null +++ b/run.py @@ -0,0 +1,74 @@ +import argparse +import json + +from web3 import Web3 + +from mev_inspect import block +from mev_inspect.crud.classified_traces import write_classified_traces +from mev_inspect.db import get_session +from mev_inspect.classifier_specs import CLASSIFIER_SPECS +from mev_inspect.trace_classifier import TraceClassifier + + +def inspect_block(base_provider, block_number): + block_data = block.create_from_block_number(block_number, base_provider) + print(f"Total traces: {len(block_data.traces)}") + + total_transactions = len( + set( + t.transaction_hash + for t in block_data.traces + if t.transaction_hash is not None + ) + ) + print(f"Total transactions: {total_transactions}") + + trace_clasifier = TraceClassifier(CLASSIFIER_SPECS) + classified_traces = trace_clasifier.classify(block_data.traces) + print(f"Returned {len(classified_traces)} classified traces") + + db_session = get_session() + write_classified_traces(db_session, classified_traces) + db_session.close() + + stats = get_stats(classified_traces) + print(json.dumps(stats, indent=4)) + + +def get_stats(classified_traces) -> dict: + stats: dict = {} + + for trace in classified_traces: + abi_name = trace.abi_name + classification = trace.classification.value + signature = trace.function_signature + + abi_name_stats = stats.get(abi_name, {}) + class_stats = abi_name_stats.get(classification, {}) + signature_count = class_stats.get(signature, 0) + class_stats[signature] = signature_count + 1 + abi_name_stats[classification] = class_stats + stats[abi_name] = abi_name_stats + + return stats + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Inspect some blocks.") + + parser.add_argument( + "-block_number", + metavar="b", + type=int, + nargs="+", + help="the block number you are targetting, eventually this will need to be changed", + ) + + parser.add_argument( + "-rpc", metavar="r", help="rpc endpoint, this needs to have parity style traces" + ) + + args = parser.parse_args() + + w3_base_provider = Web3.HTTPProvider(args.rpc) + inspect_block(w3_base_provider, args.block_number[0])