Create trace classifier. Use in general run script. Write to the DB

This commit is contained in:
Luke Van Seters 2021-07-27 18:33:40 -04:00
parent 6ca00ff3ba
commit d7f2d120dd
18 changed files with 343 additions and 230 deletions

View File

@ -1,41 +0,0 @@
import argparse
from web3 import Web3
from mev_inspect import block
from mev_inspect.inspectors.uniswap import UniswapInspector
from mev_inspect.processor import Processor
parser = argparse.ArgumentParser(description="Inspect some blocks.")
parser.add_argument(
"-block_number",
metavar="b",
type=int,
nargs="+",
help="the block number you are targetting, eventually this will need to be changed",
)
parser.add_argument(
"-rpc", metavar="r", help="rpc endpoint, this needs to have parity style traces"
)
args = parser.parse_args()
## Set up the base provider, but don't wrap it in web3 so we can make requests to it with make_request()
base_provider = Web3.HTTPProvider(args.rpc)
## Get block data that we need
block_data = block.create_from_block_number(args.block_number[0], base_provider)
print(f"Total traces: {len(block_data.traces)}")
total_transactions = len(
set(t.transaction_hash for t in block_data.traces if t.transaction_hash is not None)
)
print(f"Total transactions: {total_transactions}")
## Build a Uniswap inspector
uniswap_inspector = UniswapInspector(base_provider)
## Create a processor, pass in an ARRAY of inspects
processor = Processor([uniswap_inspector, uniswap_inspector])
classifications = processor.get_transaction_evaluations(block_data)
print(f"Returned {len(classifications)} classifications")

View File

@ -0,0 +1,37 @@
from mev_inspect.schemas.classified_traces import (
Classification,
ClassifierSpec,
Protocol,
)
SUSHISWAP_ROUTER_ADDRESS = "0xd9e1cE17f2641f24aE83637ab66a2cca9C378B9F"
UNISWAP_V2_ROUTER_ADDRESS = "0x7a250d5630B4cF539739dF2C5dAcb4c659F2488D"
CLASSIFIER_SPECS = [
ClassifierSpec(
abi_name="UniswapV2Router",
protocol=Protocol.uniswap_v2,
valid_contract_addresses=[UNISWAP_V2_ROUTER_ADDRESS],
),
ClassifierSpec(
abi_name="UniswapV2Router",
protocol=Protocol.sushiswap,
valid_contract_addresses=[SUSHISWAP_ROUTER_ADDRESS],
),
ClassifierSpec(
abi_name="ERC20",
classifications={
"transferFrom(address,address,uint256)": Classification.transfer,
"transfer(address,uint256)": Classification.transfer,
"burn(address)": Classification.burn,
},
),
ClassifierSpec(
abi_name="UniswapV2Pair",
classifications={
"swap(uint256,uint256,address,bytes)": Classification.swap,
},
),
]

View File

View File

@ -0,0 +1,17 @@
import json
from typing import List
from mev_inspect.models.classified_traces import ClassifiedTraceModel
from mev_inspect.schemas.classified_traces import ClassifiedTrace
def write_classified_traces(
db_connection,
classified_traces: List[ClassifiedTrace],
) -> None:
models = [
ClassifiedTraceModel(**json.loads(trace.json())) for trace in classified_traces
]
db_connection.bulk_save_objects(models)
db_connection.commit()

13
mev_inspect/db.py Normal file
View File

@ -0,0 +1,13 @@
import os
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
def get_engine():
return create_engine(os.getenv("SQLALCHEMY_DATABASE_URI"))
def get_session():
Session = sessionmaker(bind=get_engine())
return Session()

View File

@ -1 +0,0 @@
from .base import Inspector

View File

@ -1,11 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional
from mev_inspect.schemas.blocks import NestedTrace
from mev_inspect.schemas.classifications import Classification
class Inspector(ABC):
@abstractmethod
def inspect(self, nested_trace: NestedTrace) -> Optional[Classification]:
pass

View File

@ -1,103 +0,0 @@
import json
from typing import Optional
from web3 import Web3
from mev_inspect import utils
from mev_inspect.config import load_config
from mev_inspect.schemas.blocks import NestedTrace, TraceType
from mev_inspect.schemas.classifications import Classification
from .base import Inspector
config = load_config()
uniswap_router_abi = json.loads(config["ABI"]["UniswapV2Router"])
uniswap_router_address = config["ADDRESSES"]["UniswapV2Router"]
sushiswap_router_address = config["ADDRESSES"]["SushiswapV2Router"]
uniswap_pair_abi = json.loads(config["ABI"]["UniswapV2Pair"])
class UniswapInspector(Inspector):
def __init__(self, base_provider) -> None:
self.w3 = Web3(base_provider)
self.trading_functions = self.get_trading_functions()
self.uniswap_v2_router_contract = self.w3.eth.contract(
abi=uniswap_router_abi, address=uniswap_router_address
)
self.uniswap_router_trade_signatures = self.get_router_signatures()
self.uniswap_v2_pair_contract = self.w3.eth.contract(abi=uniswap_pair_abi)
self.uniswap_v2_pair_swap_signatures = (
self.uniswap_v2_pair_contract.functions.swap(
0, 0, uniswap_router_address, ""
).selector
) ## Note the address here doesn't matter, but it must be filled out
self.uniswap_v2_pair_reserves_signatures = (
self.uniswap_v2_pair_contract.functions.getReserves().selector
) ## Called "checksigs" in mev-inpsect.ts
print("Built Uniswap inspector")
def get_trading_functions(self):
## Gets all functions used for swapping
result = []
## For each entry in the ABI
for abi in uniswap_router_abi:
## Check to see if the entry is a function and if it is if the function's name starts with swap
if abi["type"] == "function" and abi["name"].startswith("swap"):
## If so add it to our array
result.append(abi["name"])
return result
def get_router_signatures(self):
## Gets the selector / function signatures of all the router swap functions
result = []
## For each entry in the ABI
for abi in uniswap_router_abi:
## Check to see if the entry is a function and if it is if the function's name starts with swap
if abi["type"] == "function" and abi["name"].startswith("swap"):
## Add a parantheses
function = abi["name"] + "("
## For each input in the function's input
for input in abi["inputs"]:
## Concat them into a string with commas
function = function + input["internalType"] + ","
## Take off the last comma, add a ')' to close the parentheses
function = function[:-1] + ")"
## The result looks like this: 'swapETHForExactTokens(uint256,address[],address,uint256)'
## Take the first 4 bytes of the sha3 hash of the above string.
selector = Web3.sha3(text=function)[0:4]
## Add that to an array
result.append(selector)
return result
def inspect(self, nested_trace: NestedTrace) -> Optional[Classification]:
trace = nested_trace.trace
if (
trace.type == TraceType.call
and (
trace.action["to"] == uniswap_router_address.lower()
or trace.action["to"] == sushiswap_router_address.lower()
)
and utils.check_trace_for_signature(
trace, self.uniswap_router_trade_signatures
)
):
# print("WIP, here is where there is a call that matches what we are looking for")
1 == 1
return None

View File

View File

@ -0,0 +1,3 @@
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()

View File

@ -0,0 +1,24 @@
from sqlalchemy import Column, JSON, Numeric, String
from .base import Base
class ClassifiedTraceModel(Base):
__tablename__ = "classified_traces"
transaction_hash = Column(String, primary_key=True)
block_number = Column(Numeric, nullable=False)
classification = Column(String, nullable=False)
trace_type = Column(String, nullable=False)
trace_address = Column(String, nullable=False)
protocol = Column(String, nullable=True)
abi_name = Column(String, nullable=True)
function_name = Column(String, nullable=True)
function_signature = Column(String, nullable=True)
inputs = Column(JSON, nullable=True)
from_address = Column(String, nullable=True)
to_address = Column(String, nullable=True)
gas = Column(Numeric, nullable=True)
value = Column(Numeric, nullable=True)
gas_used = Column(Numeric, nullable=True)
error = Column(String, nullable=True)

View File

@ -1,43 +0,0 @@
from typing import List
from mev_inspect.inspectors import Inspector
from mev_inspect.schemas.blocks import Block, NestedTrace, TraceType
from mev_inspect.schemas.classifications import (
Classification,
UnknownClassification,
)
from mev_inspect.traces import as_nested_traces
class Processor:
def __init__(self, inspectors: List[Inspector]) -> None:
self._inspectors = inspectors
def get_transaction_evaluations(
self,
block: Block,
) -> List[Classification]:
transaction_traces = (
trace for trace in block.traces if trace.type != TraceType.reward
)
return [
self._run_inspectors(nested_trace)
for nested_trace in as_nested_traces(transaction_traces)
]
def _run_inspectors(self, nested_trace: NestedTrace) -> Classification:
for inspector in self._inspectors:
classification = inspector.inspect(nested_trace)
if classification is not None:
return classification
internal_classifications = [
self._run_inspectors(subtrace) for subtrace in nested_trace.subtraces
]
return UnknownClassification(
trace=nested_trace.trace,
internal_classifications=internal_classifications,
)

View File

@ -1,11 +1,39 @@
from enum import Enum from enum import Enum
from typing import Dict, List, Optional from typing import Dict, List, Optional
from pydantic import BaseModel from pydantic import BaseModel, validator
from mev_inspect.utils import hex_to_int
from .utils import CamelModel, Web3Model from .utils import CamelModel, Web3Model
class CallResult(CamelModel):
gas_used: int
@validator("gas_used", pre=True)
def maybe_hex_to_int(v):
if isinstance(v, str):
return hex_to_int(v)
return v
class CallAction(Web3Model):
to: str
from_: str
input: str
value: int
gas: int
@validator("value", "gas", pre=True)
def maybe_hex_to_int(v):
if isinstance(v, str):
return hex_to_int(v)
return v
class Config:
fields = {"from_": "from"}
class TraceType(Enum): class TraceType(Enum):
call = "call" call = "call"
create = "create" create = "create"

View File

@ -1,14 +0,0 @@
from typing import List
from pydantic import BaseModel
from .blocks import Trace
class Classification(BaseModel):
pass
class UnknownClassification(Classification):
trace: Trace
internal_classifications: List[Classification]

View File

@ -0,0 +1,51 @@
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from .blocks import TraceType
class Classification(Enum):
unknown = "unknown"
swap = "swap"
burn = "burn"
transfer = "transfer"
class Protocol(Enum):
uniswap_v2 = "uniswap_v2"
sushiswap = "sushiswap"
class ClassifiedTrace(BaseModel):
transaction_hash: str
block_number: int
trace_type: TraceType
trace_address: List[int]
classification: Classification
protocol: Optional[Protocol]
abi_name: Optional[str]
function_name: Optional[str]
function_signature: Optional[str]
inputs: Optional[Dict[str, Any]]
to_address: Optional[str]
from_address: Optional[str]
gas: Optional[int]
value: Optional[int]
gas_used: Optional[int]
error: Optional[str]
class Config:
json_encoders = {
# a little lazy but fine for now
# this is used for bytes value inputs
bytes: lambda b: b.hex(),
}
class ClassifierSpec(BaseModel):
abi_name: str
protocol: Optional[Protocol] = None
valid_contract_addresses: Optional[List[str]] = None
classifications: Dict[str, Classification] = {}

View File

@ -0,0 +1,93 @@
from typing import Dict, List, Optional
from mev_inspect.abi import get_abi
from mev_inspect.decode import ABIDecoder
from mev_inspect.schemas.blocks import CallAction, CallResult, Trace, TraceType
from mev_inspect.schemas.classified_traces import (
Classification,
ClassifiedTrace,
ClassifierSpec,
)
class TraceClassifier:
def __init__(self, classifier_specs: List[ClassifierSpec]) -> None:
# TODO - index by contract_addresses for speed
self._classifier_specs = classifier_specs
self._decoders_by_abi_name: Dict[str, ABIDecoder] = {}
for spec in self._classifier_specs:
abi = get_abi(spec.abi_name)
if abi is None:
raise ValueError(f"No ABI found for {spec.abi_name}")
decoder = ABIDecoder(abi)
self._decoders_by_abi_name[spec.abi_name] = decoder
def classify(
self,
traces: List[Trace],
) -> List[ClassifiedTrace]:
return [
self._classify_trace(trace)
for trace in traces
if trace.type != TraceType.reward
]
def _classify_trace(self, trace: Trace) -> ClassifiedTrace:
if trace.type == TraceType.call:
classified_trace = self._classify_call(trace)
if classified_trace is not None:
return classified_trace
return ClassifiedTrace(
**trace.dict(),
trace_type=trace.type,
classification=Classification.unknown,
)
def _classify_call(self, trace) -> Optional[ClassifiedTrace]:
action = CallAction(**trace.action)
result = CallResult(**trace.result) if trace.result is not None else None
for spec in self._classifier_specs:
if spec.valid_contract_addresses is not None:
if action.to not in spec.valid_contract_addresses:
continue
decoder = self._decoders_by_abi_name[spec.abi_name]
call_data = decoder.decode(action.input)
if call_data is not None:
signature = call_data.function_signature
classification = spec.classifications.get(
signature, Classification.unknown
)
return ClassifiedTrace(
**trace.dict(),
trace_type=trace.type,
classification=classification,
protocol=spec.protocol,
abi_name=spec.abi_name,
function_name=call_data.function_name,
function_signature=signature,
inputs=call_data.inputs,
to_address=action.to,
from_address=action.from_,
value=action.value,
gas=action.gas,
gas_used=result.gas_used if result is not None else None,
)
return ClassifiedTrace(
**trace.dict(),
trace_type=trace.type,
classification=Classification.unknown,
to_address=action.to,
from_address=action.from_,
value=action.value,
gas=action.gas,
gas_used=result.gas_used if result is not None else None,
)

View File

@ -1,19 +1,5 @@
from typing import List
from hexbytes.main import HexBytes from hexbytes.main import HexBytes
from mev_inspect.schemas.blocks import Trace
def hex_to_int(value: str) -> int:
def check_trace_for_signature(trace: Trace, signatures: List[str]): return int.from_bytes(HexBytes(value), byteorder="big")
if trace.action["input"] == None:
return False
## Iterate over all signatures, and if our trace matches any of them set it to True
for signature in signatures:
if HexBytes(trace.action["input"]).startswith(signature):
## Note that we are turning the input into hex bytes here, which seems to be fine
## Working with strings was doing weird things
return True
return False

74
run.py Normal file
View File

@ -0,0 +1,74 @@
import argparse
import json
from web3 import Web3
from mev_inspect import block
from mev_inspect.crud.classified_traces import write_classified_traces
from mev_inspect.db import get_session
from mev_inspect.classifier_specs import CLASSIFIER_SPECS
from mev_inspect.trace_classifier import TraceClassifier
def inspect_block(base_provider, block_number):
block_data = block.create_from_block_number(block_number, base_provider)
print(f"Total traces: {len(block_data.traces)}")
total_transactions = len(
set(
t.transaction_hash
for t in block_data.traces
if t.transaction_hash is not None
)
)
print(f"Total transactions: {total_transactions}")
trace_clasifier = TraceClassifier(CLASSIFIER_SPECS)
classified_traces = trace_clasifier.classify(block_data.traces)
print(f"Returned {len(classified_traces)} classified traces")
db_session = get_session()
write_classified_traces(db_session, classified_traces)
db_session.close()
stats = get_stats(classified_traces)
print(json.dumps(stats, indent=4))
def get_stats(classified_traces) -> dict:
stats: dict = {}
for trace in classified_traces:
abi_name = trace.abi_name
classification = trace.classification.value
signature = trace.function_signature
abi_name_stats = stats.get(abi_name, {})
class_stats = abi_name_stats.get(classification, {})
signature_count = class_stats.get(signature, 0)
class_stats[signature] = signature_count + 1
abi_name_stats[classification] = class_stats
stats[abi_name] = abi_name_stats
return stats
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Inspect some blocks.")
parser.add_argument(
"-block_number",
metavar="b",
type=int,
nargs="+",
help="the block number you are targetting, eventually this will need to be changed",
)
parser.add_argument(
"-rpc", metavar="r", help="rpc endpoint, this needs to have parity style traces"
)
args = parser.parse_args()
w3_base_provider = Web3.HTTPProvider(args.rpc)
inspect_block(w3_base_provider, args.block_number[0])