Merge pull request #26 from lukevs/nested-traces

Add NestedTrace
This commit is contained in:
Robert Miller 2021-07-22 10:41:52 -04:00 committed by GitHub
commit 6e8a7b58e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 23890 additions and 208352 deletions

34422
cache/10803840.json vendored

File diff suppressed because it is too large Load Diff

34396
cache/11930296-new.json vendored

File diff suppressed because one or more lines are too long

23494
cache/11935012-new.json vendored

File diff suppressed because one or more lines are too long

32009
cache/12051659.json vendored

File diff suppressed because it is too large Load Diff

40552
cache/12412732.json vendored

File diff suppressed because it is too large Load Diff

View File

@ -3,7 +3,7 @@ from typing import List
from web3 import Web3 from web3 import Web3
from mev_inspect.schemas import Block, BlockCall, BlockCallType from mev_inspect.schemas import Block, Trace, TraceType
cache_directory = "./cache" cache_directory = "./cache"
@ -12,7 +12,7 @@ cache_directory = "./cache"
## Creates a block object, either from the cache or from the chain itself ## Creates a block object, either from the cache or from the chain itself
## Note that you need to pass in the provider, not the web3 wrapped provider object! ## Note that you need to pass in the provider, not the web3 wrapped provider object!
## This is because only the provider allows you to make json rpc requests ## This is because only the provider allows you to make json rpc requests
def createFromBlockNumber(block_number: int, base_provider) -> Block: def create_from_block_number(block_number: int, base_provider) -> Block:
cache_path = _get_cache_path(block_number) cache_path = _get_cache_path(block_number)
if cache_path.is_file(): if cache_path.is_file():
@ -42,8 +42,8 @@ def fetch_block(w3, base_provider, block_number: int) -> Block:
) )
## Trace the whole block, return those calls ## Trace the whole block, return those calls
block_calls_json = w3.parity.trace_block(block_number) traces_json = w3.parity.trace_block(block_number)
block_calls = [BlockCall(**call_json) for call_json in block_calls_json] traces = [Trace(**trace_json) for trace_json in traces_json]
## Get the logs ## Get the logs
block_hash = (block_data.hash).hex() block_hash = (block_data.hash).hex()
@ -64,25 +64,25 @@ def fetch_block(w3, base_provider, block_number: int) -> Block:
"netFeePaid": tx_data["gasPrice"] * tx_receipt["gasUsed"], "netFeePaid": tx_data["gasPrice"] * tx_receipt["gasUsed"],
} }
transaction_hashes = get_transaction_hashes(block_calls) transaction_hashes = get_transaction_hashes(traces)
## Create a new object ## Create a new object
return Block( return Block(
block_number=block_number, block_number=block_number,
data=block_data, data=block_data,
receipts=block_receipts_raw, receipts=block_receipts_raw,
calls=block_calls, traces=traces,
logs=block_logs, logs=block_logs,
transaction_hashes=transaction_hashes, transaction_hashes=transaction_hashes,
txs_gas_data=txs_gas_data, txs_gas_data=txs_gas_data,
) )
def get_transaction_hashes(calls: List[BlockCall]) -> List[str]: def get_transaction_hashes(calls: List[Trace]) -> List[str]:
result = [] result = []
for call in calls: for call in calls:
if call.type != BlockCallType.reward: if call.type != TraceType.reward:
if ( if (
call.transaction_hash is not None call.transaction_hash is not None
and call.transaction_hash not in result and call.transaction_hash not in result

View File

@ -83,10 +83,14 @@ class UniswapInspector:
for call in calls: for call in calls:
print("\n", call) print("\n", call)
if ( if (
call["type"] == "call"
and (
call["action"]["to"] == uniswap_router_address.lower() call["action"]["to"] == uniswap_router_address.lower()
or call["action"]["to"] == sushiswap_router_address.lower() or call["action"]["to"] == sushiswap_router_address.lower()
) and utils.check_call_for_signature( )
and utils.check_trace_for_signature(
call, self.uniswap_router_trade_signatures call, self.uniswap_router_trade_signatures
)
): ):
# print("WIP, here is where there is a call that matches what we are looking for") # print("WIP, here is where there is a call that matches what we are looking for")
1 == 1 1 == 1

View File

@ -8,8 +8,8 @@ class Processor:
def get_transaction_evaluations(self, block_data): def get_transaction_evaluations(self, block_data):
for transaction_hash in block_data.transaction_hashes: for transaction_hash in block_data.transaction_hashes:
calls = block_data.get_filtered_calls(transaction_hash) traces = block_data.get_filtered_traces(transaction_hash)
calls_json = [to_original_json_dict(call) for call in calls] traces_json = [to_original_json_dict(trace) for trace in traces]
for inspector in self.inspectors: for inspector in self.inspectors:
inspector.inspect(calls_json) inspector.inspect(traces_json)

View File

@ -1,2 +1,2 @@
from .abi import ABI from .abi import ABI
from .blocks import Block, BlockCall, BlockCallType from .blocks import Block, NestedTrace, Trace, TraceType

View File

@ -1,10 +1,12 @@
from enum import Enum from enum import Enum
from typing import Dict, List, Optional from typing import Dict, List, Optional
from pydantic import BaseModel
from .utils import CamelModel, Web3Model from .utils import CamelModel, Web3Model
class BlockCallType(Enum): class TraceType(Enum):
call = "call" call = "call"
create = "create" create = "create"
delegate_call = "delegateCall" delegate_call = "delegateCall"
@ -12,7 +14,7 @@ class BlockCallType(Enum):
suicide = "suicide" suicide = "suicide"
class BlockCall(CamelModel): class Trace(CamelModel):
action: dict action: dict
block_hash: str block_hash: str
block_number: int block_number: int
@ -21,18 +23,26 @@ class BlockCall(CamelModel):
trace_address: List[int] trace_address: List[int]
transaction_hash: Optional[str] transaction_hash: Optional[str]
transaction_position: Optional[int] transaction_position: Optional[int]
type: BlockCallType type: TraceType
error: Optional[str] error: Optional[str]
class Block(Web3Model): class Block(Web3Model):
block_number: int block_number: int
calls: List[BlockCall] traces: List[Trace]
data: dict data: dict
logs: List[dict] logs: List[dict]
receipts: dict receipts: dict
transaction_hashes: List[str] transaction_hashes: List[str]
txs_gas_data: Dict[str, dict] txs_gas_data: Dict[str, dict]
def get_filtered_calls(self, hash: str) -> List[BlockCall]: def get_filtered_traces(self, hash: str) -> List[Trace]:
return [call for call in self.calls if call.transaction_hash == hash] return [trace for trace in self.traces if trace.transaction_hash == hash]
class NestedTrace(BaseModel):
trace: Trace
subtraces: List["NestedTrace"]
NestedTrace.update_forward_refs()

View File

@ -1,7 +1,7 @@
import json from typing import List, Optional
from pathlib import Path
from mev_inspect.config import load_config from mev_inspect.config import load_config
from mev_inspect.schemas import Block, Trace, TraceType
config = load_config() config = load_config()
@ -12,30 +12,6 @@ weth_address = config["ADDRESSES"]["WETH"]
cache_directory = "./cache" cache_directory = "./cache"
def get_tx_traces(txHash, blockNo):
# block_calls = w3.parity.trace_block(10803840)
cache_file = "{cacheDirectory}/{blockNumber}-new.json".format(
cacheDirectory=cache_directory, blockNumber=blockNo
)
file_exists = Path(cache_file).is_file()
tx_traces = []
# if have the traces cached
if file_exists:
block_file = open(cache_file)
block_json = json.load(block_file)
for call in block_json["calls"]:
if call["transactionHash"] == txHash:
tx_traces.append(call)
block_file.close()
else:
# todo, fetch and cache traces that don't exist
# depending on the best way to call block.py from here
print("traces do not exist")
return tx_traces
def is_stablecoin_address(address): def is_stablecoin_address(address):
# to look for stablecoin inflow/outflows # to look for stablecoin inflow/outflows
stablecoin_addresses = [ stablecoin_addresses = [
@ -85,41 +61,34 @@ def is_known_router_address(address):
# we're interested in the to address to run token flow on it as well # we're interested in the to address to run token flow on it as well
def get_tx_to_address(txHash, blockNo): def get_tx_to_address(tx_hash, block) -> Optional[str]:
cache_file = "{cacheDirectory}/{blockNumber}-new.json".format( for receipt in block.receipts["result"]:
cacheDirectory=cache_directory, blockNumber=blockNo if receipt["transactionHash"] == tx_hash:
)
block_file = open(cache_file)
block_json = json.load(block_file)
for receipt in block_json["receipts"]["result"]:
if receipt["transactionHash"] == txHash:
block_file.close()
return receipt["to"] return receipt["to"]
return None
def get_tx_proxies(tx_traces, to_address):
def get_tx_proxies(tx_traces: List[Trace], to_address: Optional[str]):
proxies = [] proxies = []
for trace in tx_traces: for trace in tx_traces:
if ( if (
trace["type"] == "call" trace.type == TraceType.call
and trace["action"]["callType"] == "delegatecall" and trace.action["callType"] == "delegatecall"
and trace["action"]["from"] == to_address and trace.action["from"] == to_address
): ):
proxies.append(trace["action"]["to"]) proxies.append(trace.action["to"])
return proxies return proxies
def get_net_gas_used(txHash, blockNo): def get_net_gas_used(tx_hash, block):
cache_file = "{cacheDirectory}/{blockNumber}.json".format( for trace in block.traces:
cacheDirectory=cache_directory, blockNumber=blockNo if trace.transaction_hash == tx_hash:
) gas_used += int(trace.result["gasUsed"], 16)
block_file = open(cache_file)
block_json = json.load(block_file) return gas_used
gas_used = 0
for trace in block_json["calls"]:
if trace["transactionHash"] == txHash:
gas_used = gas_used + int(trace["result"]["gasUsed"], 16)
print(gas_used)
def get_ether_flows(tx_traces, addresses_to_check): def get_ether_flows(tx_traces, addresses_to_check):
@ -127,60 +96,56 @@ def get_ether_flows(tx_traces, addresses_to_check):
eth_outflow = 0 eth_outflow = 0
for trace in tx_traces: for trace in tx_traces:
if trace["type"] == "call": if trace.type == TraceType.call:
value = int( value = int(
trace["action"]["value"], 16 trace.action["value"], 16
) # converting from 0x prefix to decimal ) # converting from 0x prefix to decimal
# ETH_GET # ETH_GET
if ( if (
trace["action"]["callType"] != "delegatecall" trace.action["callType"] != "delegatecall"
and trace["action"]["from"] != weth_address and trace.action["from"] != weth_address
and value > 0 and value > 0
and trace["action"]["to"] in addresses_to_check and trace.action["to"] in addresses_to_check
): ):
eth_inflow = eth_inflow + value eth_inflow = eth_inflow + value
# ETH_GIVE # ETH_GIVE
if ( if (
trace["action"]["callType"] != "delegatecall" trace.action["callType"] != "delegatecall"
and trace["action"]["to"] != weth_address and trace.action["to"] != weth_address
and value > 0 and value > 0
and trace["action"]["from"] in addresses_to_check and trace.action["from"] in addresses_to_check
): ):
eth_outflow = eth_outflow + value eth_outflow = eth_outflow + value
if trace["action"]["to"] == weth_address: if trace.action["to"] == weth_address:
# WETH_GET1 & WETH_GET2 (to account for both 'transfer' and 'transferFrom' methods) # WETH_GET1 & WETH_GET2 (to account for both 'transfer' and 'transferFrom' methods)
# WETH_GIVE1 & WETH_GIVE2 # WETH_GIVE1 & WETH_GIVE2
# transfer(address to,uint256 value) with args # transfer(address to,uint256 value) with args
if len(trace["action"]["input"]) == 138: if len(trace.action["input"]) == 138:
if trace["action"]["input"][2:10] == "a9059cbb": if trace.action["input"][2:10] == "a9059cbb":
transfer_to = "0x" + trace["action"]["input"][34:74] transfer_to = "0x" + trace.action["input"][34:74]
transfer_value = int( transfer_value = int("0x" + trace.action["input"][74:138], 16)
"0x" + trace["action"]["input"][74:138], 16
)
if transfer_to in addresses_to_check: if transfer_to in addresses_to_check:
eth_inflow = eth_inflow + transfer_value eth_inflow = eth_inflow + transfer_value
elif trace["action"]["from"] in addresses_to_check: elif trace.action["from"] in addresses_to_check:
eth_outflow = eth_outflow + transfer_value eth_outflow = eth_outflow + transfer_value
# transferFrom(address from,address to,uint256 value ) # transferFrom(address from,address to,uint256 value )
if len(trace["action"]["input"]) == 202: if len(trace.action["input"]) == 202:
if trace["action"]["input"][2:10] == "23b872dd": if trace.action["input"][2:10] == "23b872dd":
transfer_from = "0x" + trace["action"]["input"][34:74] transfer_from = "0x" + trace.action["input"][34:74]
transfer_to = "0x" + trace["action"]["input"][98:138] transfer_to = "0x" + trace.action["input"][98:138]
transfer_value = int( transfer_value = int("0x" + trace.action["input"][138:202], 16)
"0x" + trace["action"]["input"][138:202], 16
)
if transfer_to in addresses_to_check: if transfer_to in addresses_to_check:
eth_inflow = eth_inflow + transfer_value eth_inflow = eth_inflow + transfer_value
elif transfer_from in addresses_to_check: elif transfer_from in addresses_to_check:
eth_outflow = eth_outflow + transfer_value eth_outflow = eth_outflow + transfer_value
if trace["type"] == "suicide": if trace.type == TraceType.suicide:
if trace["action"]["refundAddress"] in addresses_to_check: if trace.action["refundAddress"] in addresses_to_check:
refund_value = int("0x" + trace["action"]["balance"], 16) refund_value = int("0x" + trace.action["balance"], 16)
eth_inflow = eth_inflow + refund_value eth_inflow = eth_inflow + refund_value
return [eth_inflow, eth_outflow] return [eth_inflow, eth_outflow]
@ -190,30 +155,28 @@ def get_dollar_flows(tx_traces, addresses_to_check):
dollar_inflow = 0 dollar_inflow = 0
dollar_outflow = 0 dollar_outflow = 0
for trace in tx_traces: for trace in tx_traces:
if trace["type"] == "call" and is_stablecoin_address(trace["action"]["to"]): if trace.type == TraceType.call and is_stablecoin_address(trace.action["to"]):
_ = int( _ = int(trace.action["value"], 16) # converting from 0x prefix to decimal
trace["action"]["value"], 16
) # converting from 0x prefix to decimal
# USD_GET1 & USD_GET2 (to account for both 'transfer' and 'transferFrom' methods) # USD_GET1 & USD_GET2 (to account for both 'transfer' and 'transferFrom' methods)
# USD_GIVE1 & USD_GIVE2 # USD_GIVE1 & USD_GIVE2
# transfer(address to,uint256 value) with args # transfer(address to,uint256 value) with args
if len(trace["action"]["input"]) == 138: if len(trace.action["input"]) == 138:
if trace["action"]["input"][2:10] == "a9059cbb": if trace.action["input"][2:10] == "a9059cbb":
transfer_to = "0x" + trace["action"]["input"][34:74] transfer_to = "0x" + trace.action["input"][34:74]
transfer_value = int("0x" + trace["action"]["input"][74:138], 16) transfer_value = int("0x" + trace.action["input"][74:138], 16)
if transfer_to in addresses_to_check: if transfer_to in addresses_to_check:
dollar_inflow = dollar_inflow + transfer_value dollar_inflow = dollar_inflow + transfer_value
elif trace["action"]["from"] in addresses_to_check: elif trace.action["from"] in addresses_to_check:
dollar_outflow = dollar_outflow + transfer_value dollar_outflow = dollar_outflow + transfer_value
# transferFrom(address from,address to,uint256 value ) # transferFrom(address from,address to,uint256 value )
if len(trace["action"]["input"]) == 202: if len(trace.action["input"]) == 202:
if trace["action"]["input"][2:10] == "23b872dd": if trace.action["input"][2:10] == "23b872dd":
transfer_from = "0x" + trace["action"]["input"][34:74] transfer_from = "0x" + trace.action["input"][34:74]
transfer_to = "0x" + trace["action"]["input"][98:138] transfer_to = "0x" + trace.action["input"][98:138]
transfer_value = int("0x" + trace["action"]["input"][138:202], 16) transfer_value = int("0x" + trace.action["input"][138:202], 16)
if transfer_to in addresses_to_check: if transfer_to in addresses_to_check:
dollar_inflow = dollar_inflow + transfer_value dollar_inflow = dollar_inflow + transfer_value
elif transfer_from in addresses_to_check: elif transfer_from in addresses_to_check:
@ -221,13 +184,18 @@ def get_dollar_flows(tx_traces, addresses_to_check):
return [dollar_inflow, dollar_outflow] return [dollar_inflow, dollar_outflow]
def run_tokenflow(txHash, blockNo): def run_tokenflow(tx_hash: str, block: Block):
tx_traces = get_tx_traces(txHash, blockNo) tx_traces = block.get_filtered_traces(tx_hash)
to_address = get_tx_to_address(txHash, blockNo) to_address = get_tx_to_address(tx_hash, block)
if to_address is None:
raise ValueError("No to address found")
addresses_to_check = [] addresses_to_check = []
# check for proxies, add them to addresses to check # check for proxies, add them to addresses to check
proxies = get_tx_proxies(tx_traces, to_address) proxies = get_tx_proxies(tx_traces, to_address)
for proxy in proxies: for proxy in proxies:
addresses_to_check.append(proxy.lower()) addresses_to_check.append(proxy.lower())
@ -247,8 +215,8 @@ def run_tokenflow(txHash, blockNo):
# note: not the gas set by user, only gas consumed upon execution # note: not the gas set by user, only gas consumed upon execution
# def get_gas_used_by_tx(txHash): # def get_gas_used_by_tx(tx_hash):
# # tx_receipt = w3.eth.getTransactionReceipt(txHash) # # tx_receipt = w3.eth.getTransactionReceipt(tx_hash)
# return tx_receipt["gasUsed"] # return tx_receipt["gasUsed"]

80
mev_inspect/traces.py Normal file
View File

@ -0,0 +1,80 @@
from itertools import groupby
from typing import Iterable, List
from mev_inspect.schemas import Trace, NestedTrace
def as_nested_traces(traces: Iterable[Trace]) -> List[NestedTrace]:
nested_traces = []
sorted_by_transaction_hash = sorted(traces, key=_get_transaction_hash)
for _, transaction_traces in groupby(
sorted_by_transaction_hash, _get_transaction_hash
):
nested_traces += _as_nested_traces_by_transaction(transaction_traces)
return nested_traces
def _get_transaction_hash(trace) -> str:
return trace.transaction_hash
def _as_nested_traces_by_transaction(traces: Iterable[Trace]) -> List[NestedTrace]:
"""
Turns a list of Traces into a a tree of NestedTraces
using their trace addresses
Right now this has an exponential (?) runtime because we rescan
most traces at each level of tree depth
TODO to write a better implementation if it becomes a bottleneck
Should be doable in linear time
"""
nested_traces = []
parent = None
children: List[Trace] = []
sorted_traces = sorted(traces, key=lambda t: t.trace_address)
for trace in sorted_traces:
if parent is None:
parent = trace
children = []
continue
elif not _is_subtrace(trace, parent):
nested_traces.append(
NestedTrace(
trace=parent,
subtraces=as_nested_traces(children),
)
)
parent = trace
children = []
else:
children.append(trace)
if parent is not None:
nested_traces.append(
NestedTrace(
trace=parent,
subtraces=as_nested_traces(children),
)
)
return nested_traces
def _is_subtrace(trace: Trace, parent: Trace):
parent_trace_length = len(parent.trace_address)
if len(trace.trace_address) > parent_trace_length:
prefix = trace.trace_address[:parent_trace_length]
return prefix == parent.trace_address
return False

View File

@ -1,22 +1,17 @@
from typing import List
from hexbytes.main import HexBytes from hexbytes.main import HexBytes
def check_call_for_signature(call, signatures): def check_trace_for_signature(trace: dict, signatures: List[str]):
if call["action"]["input"] == None: if trace["action"]["input"] == None:
return False return False
## By default set this to False ## Iterate over all signatures, and if our trace matches any of them set it to True
signature_present_boolean = False
## Iterate over all signatures, and if our call matches any of them set it to True
for signature in signatures: for signature in signatures:
# print("Desired signature:", str(signature)) if HexBytes(trace["action"]["input"]).startswith(signature):
# print("Actual", HexBytes(call['action']['input']))
if HexBytes(call["action"]["input"]).startswith((signature)):
## Note that we are turning the input into hex bytes here, which seems to be fine ## Note that we are turning the input into hex bytes here, which seems to be fine
## Working with strings was doing weird things ## Working with strings was doing weird things
print("hit") return True
signature_present_boolean = True
return signature_present_boolean return False

View File

@ -23,7 +23,7 @@ args = parser.parse_args()
base_provider = Web3.HTTPProvider(args.rpc) base_provider = Web3.HTTPProvider(args.rpc)
## Get block data that we need ## Get block data that we need
block_data = block.createFromBlockNumber(args.block_number[0], base_provider) block_data = block.create_from_block_number(args.block_number[0], base_provider)
## Build a Uniswap inspector ## Build a Uniswap inspector
uniswap_inspector = UniswapInspector(base_provider) uniswap_inspector = UniswapInspector(base_provider)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

10832
tests/blocks/11935012.json Normal file

File diff suppressed because one or more lines are too long

View File

@ -1,30 +1,49 @@
import json
import os
import unittest import unittest
from mev_inspect import tokenflow from mev_inspect import tokenflow
from mev_inspect.schemas.blocks import Block
THIS_FILE_DIRECTORY = os.path.dirname(__file__)
TEST_BLOCKS_DIRECTORY = os.path.join(THIS_FILE_DIRECTORY, "blocks")
class TestTokenFlow(unittest.TestCase): class TestTokenFlow(unittest.TestCase):
def test_simple_arb(self): def test_simple_arb(self):
tx_hash = "0x4121ce805d33e952b2e6103a5024f70c118432fd0370128d6d7845f9b2987922" tx_hash = "0x4121ce805d33e952b2e6103a5024f70c118432fd0370128d6d7845f9b2987922"
block_no = 11930296 block_no = 11930296
res = tokenflow.run_tokenflow(tx_hash, block_no)
block = load_test_block(block_no)
res = tokenflow.run_tokenflow(tx_hash, block)
self.assertEqual(res["ether_flows"], [3547869861992962562, 3499859860420296704]) self.assertEqual(res["ether_flows"], [3547869861992962562, 3499859860420296704])
self.assertEqual(res["dollar_flows"], [0, 0]) self.assertEqual(res["dollar_flows"], [0, 0])
def test_arb_with_stable_flow(self): def test_arb_with_stable_flow(self):
tx_hash = "0x496836e0bd1520388e36c79d587a31d4b3306e4f25352164178ca0667c7f9c29" tx_hash = "0x496836e0bd1520388e36c79d587a31d4b3306e4f25352164178ca0667c7f9c29"
block_no = 11935012 block_no = 11935012
res = tokenflow.run_tokenflow(tx_hash, block_no)
block = load_test_block(block_no)
res = tokenflow.run_tokenflow(tx_hash, block)
self.assertEqual(res["ether_flows"], [597044987302243493, 562445964778930176]) self.assertEqual(res["ether_flows"], [597044987302243493, 562445964778930176])
self.assertEqual(res["dollar_flows"], [871839781, 871839781]) self.assertEqual(res["dollar_flows"], [871839781, 871839781])
def test_complex_cross_arb(self): def test_complex_cross_arb(self):
tx_hash = "0x5ab21bfba50ad3993528c2828c63e311aafe93b40ee934790e545e150cb6ca73" tx_hash = "0x5ab21bfba50ad3993528c2828c63e311aafe93b40ee934790e545e150cb6ca73"
block_no = 11931272 block_no = 11931272
res = tokenflow.run_tokenflow(tx_hash, block_no) block = load_test_block(block_no)
res = tokenflow.run_tokenflow(tx_hash, block)
self.assertEqual(res["ether_flows"], [3636400213125714803, 3559576672903063566]) self.assertEqual(res["ether_flows"], [3636400213125714803, 3559576672903063566])
self.assertEqual(res["dollar_flows"], [0, 0]) self.assertEqual(res["dollar_flows"], [0, 0])
def load_test_block(block_number):
block_path = f"{TEST_BLOCKS_DIRECTORY}/{block_number}.json"
with open(block_path, "r") as block_file:
block_json = json.load(block_file)
return Block(**block_json)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

103
tests/trace_test.py Normal file
View File

@ -0,0 +1,103 @@
import unittest
from typing import List
from mev_inspect.schemas import Trace, TraceType, NestedTrace
from mev_inspect.traces import as_nested_traces
DEFAULT_BLOCK_NUMBER = 123
class TestTraces(unittest.TestCase):
def test_nested_traces(self):
trace_hash_address_pairs = [
("abc", [0, 2]),
("abc", []),
("abc", [2]),
("abc", [0]),
("abc", [0, 0]),
("abc", [0, 1]),
("abc", [1]),
("efg", []),
("abc", [1, 0]),
("abc", [0, 1, 0]),
("efg", [0]),
]
traces = [
build_trace_at_address(hash, address)
for (hash, address) in trace_hash_address_pairs
]
nested_traces = as_nested_traces(traces)
assert len(nested_traces) == 2
abc_trace = nested_traces[0]
efg_trace = nested_traces[1]
# abc
assert abc_trace.trace.transaction_hash == "abc"
assert_trace_address(abc_trace, [])
assert len(abc_trace.subtraces) == 3
[trace_0, trace_1, trace_2] = abc_trace.subtraces
assert_trace_address(trace_0, [0])
assert_trace_address(trace_1, [1])
assert_trace_address(trace_2, [2])
assert len(trace_0.subtraces) == 3
assert len(trace_1.subtraces) == 1
assert len(trace_2.subtraces) == 0
[trace_0_0, trace_0_1, trace_0_2] = trace_0.subtraces
[trace_1_0] = trace_1.subtraces
assert_trace_address(trace_0_0, [0, 0])
assert_trace_address(trace_0_1, [0, 1])
assert_trace_address(trace_0_2, [0, 2])
assert_trace_address(trace_1_0, [1, 0])
assert len(trace_0_0.subtraces) == 0
assert len(trace_0_1.subtraces) == 1
assert len(trace_0_2.subtraces) == 0
assert len(trace_1_0.subtraces) == 0
[trace_0_1_0] = trace_0_1.subtraces
assert_trace_address(trace_0_1_0, [0, 1, 0])
assert len(trace_0_1_0.subtraces) == 0
# efg
assert efg_trace.trace.transaction_hash == "efg"
assert_trace_address(efg_trace, [])
assert len(efg_trace.subtraces) == 1
[efg_subtrace] = efg_trace.subtraces
assert_trace_address(efg_subtrace, [0])
assert len(efg_subtrace.subtraces) == 0
def build_trace_at_address(
transaction_hash: str,
trace_address: List[int],
) -> Trace:
return Trace(
# real values
transaction_hash=transaction_hash,
trace_address=trace_address,
# placeholders
action={},
block_hash="",
block_number=DEFAULT_BLOCK_NUMBER,
result=None,
subtraces=0,
transaction_position=None,
type=TraceType.call,
error=None,
)
def assert_trace_address(nested_trace: NestedTrace, trace_address: List[int]):
assert nested_trace.trace.trace_address == trace_address