Add tests for get_child_traces. Add some helpful fixtures
This commit is contained in:
parent
f59b2b2b82
commit
f73a34a5ba
@ -44,7 +44,12 @@ def _get_swaps_for_transaction(traces: List[ClassifiedTrace]) -> List[Swap]:
|
||||
prior_transfers.append(Transfer.from_trace(trace))
|
||||
|
||||
elif trace.classification == Classification.swap:
|
||||
child_transfers = get_child_transfers(trace.trace_address, traces)
|
||||
child_transfers = get_child_transfers(
|
||||
trace.transaction_hash,
|
||||
trace.trace_address,
|
||||
traces,
|
||||
)
|
||||
|
||||
swap = _parse_swap(
|
||||
trace,
|
||||
remove_inner_transfers(prior_transfers),
|
||||
|
@ -16,6 +16,7 @@ def is_child_trace_address(
|
||||
|
||||
|
||||
def get_child_traces(
|
||||
transaction_hash: str,
|
||||
parent_trace_address: List[int],
|
||||
traces: List[ClassifiedTrace],
|
||||
) -> List[ClassifiedTrace]:
|
||||
@ -23,7 +24,7 @@ def get_child_traces(
|
||||
child_traces = []
|
||||
|
||||
for trace in ordered_traces:
|
||||
if is_child_trace_address(
|
||||
if trace.transaction_hash == transaction_hash and is_child_trace_address(
|
||||
trace.trace_address,
|
||||
parent_trace_address,
|
||||
):
|
||||
|
@ -6,12 +6,13 @@ from mev_inspect.traces import is_child_trace_address, get_child_traces
|
||||
|
||||
|
||||
def get_child_transfers(
|
||||
transaction_hash: str,
|
||||
parent_trace_address: List[int],
|
||||
traces: List[ClassifiedTrace],
|
||||
) -> List[Transfer]:
|
||||
child_transfers = []
|
||||
|
||||
for child_trace in get_child_traces(parent_trace_address, traces):
|
||||
for child_trace in get_child_traces(transaction_hash, parent_trace_address, traces):
|
||||
if child_trace.classification == Classification.transfer:
|
||||
child_transfers.append(Transfer.from_trace(child_trace))
|
||||
|
||||
|
@ -1,4 +1,49 @@
|
||||
from mev_inspect.traces import is_child_trace_address
|
||||
from hashlib import sha3_256
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from mev_inspect.schemas.blocks import TraceType
|
||||
from mev_inspect.schemas.classified_traces import Classification, ClassifiedTrace
|
||||
from mev_inspect.traces import is_child_trace_address, get_child_traces
|
||||
|
||||
|
||||
@pytest.fixture(name="get_transaction_hashes")
|
||||
def fixture_get_transaction_hashes():
|
||||
def _get_transaction_hashes(n: int):
|
||||
return [sha3_256(str(i).encode("utf-8")).hexdigest() for i in range(n)]
|
||||
|
||||
return _get_transaction_hashes
|
||||
|
||||
|
||||
def make_unknown_classified_trace(
|
||||
block_number,
|
||||
transaction_hash,
|
||||
trace_address,
|
||||
):
|
||||
return ClassifiedTrace(
|
||||
transaction_hash=transaction_hash,
|
||||
block_number=block_number,
|
||||
trace_type=TraceType.call,
|
||||
trace_address=trace_address,
|
||||
classification=Classification.unknown,
|
||||
)
|
||||
|
||||
|
||||
def make_traces(
|
||||
block_number,
|
||||
transaction_hash,
|
||||
trace_addresses,
|
||||
) -> List[ClassifiedTrace]:
|
||||
|
||||
return [
|
||||
make_unknown_classified_trace(
|
||||
block_number,
|
||||
transaction_hash,
|
||||
trace_address,
|
||||
)
|
||||
for trace_address in trace_addresses
|
||||
]
|
||||
|
||||
|
||||
def test_is_child_trace_address():
|
||||
@ -12,3 +57,80 @@ def test_is_child_trace_address():
|
||||
assert not is_child_trace_address([1], [0])
|
||||
assert not is_child_trace_address([1, 0], [0])
|
||||
assert not is_child_trace_address([100, 2, 10], [100, 1])
|
||||
|
||||
|
||||
def test_get_child_traces(get_transaction_hashes):
|
||||
block_number = 123
|
||||
[first_hash, second_hash] = get_transaction_hashes(2)
|
||||
|
||||
traces = []
|
||||
|
||||
first_hash_trace_addresses = [
|
||||
[],
|
||||
[0],
|
||||
[0, 0],
|
||||
[1],
|
||||
[1, 0],
|
||||
[1, 0, 0],
|
||||
[1, 0, 1],
|
||||
[1, 1],
|
||||
[1, 2],
|
||||
]
|
||||
|
||||
second_hash_trace_addresses = [[], [0], [1], [1, 0], [2]]
|
||||
|
||||
traces += make_traces(
|
||||
block_number,
|
||||
first_hash,
|
||||
first_hash_trace_addresses,
|
||||
)
|
||||
|
||||
traces += make_traces(
|
||||
block_number,
|
||||
second_hash,
|
||||
second_hash_trace_addresses,
|
||||
)
|
||||
|
||||
assert has_expected_child_traces(
|
||||
first_hash,
|
||||
[],
|
||||
traces,
|
||||
first_hash_trace_addresses[1:],
|
||||
)
|
||||
|
||||
|
||||
def has_expected_child_traces(
|
||||
transaction_hash: str,
|
||||
parent_trace_address: List[int],
|
||||
traces: List[ClassifiedTrace],
|
||||
expected_trace_addresses: List[List[int]],
|
||||
):
|
||||
child_traces = get_child_traces(
|
||||
transaction_hash,
|
||||
parent_trace_address,
|
||||
traces,
|
||||
)
|
||||
|
||||
distinct_trace_addresses = distinct_lists(expected_trace_addresses)
|
||||
|
||||
if len(child_traces) != len(distinct_trace_addresses):
|
||||
return False
|
||||
|
||||
for trace in child_traces:
|
||||
if trace.transaction_hash != transaction_hash:
|
||||
return False
|
||||
|
||||
if trace.trace_address not in distinct_trace_addresses:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def distinct_lists(list_of_lists: List[List[int]]) -> List[List[int]]:
|
||||
distinct_so_far = []
|
||||
|
||||
for list_of_values in list_of_lists:
|
||||
if list_of_values not in distinct_so_far:
|
||||
distinct_so_far.append(list_of_values)
|
||||
|
||||
return distinct_so_far
|
||||
|
Loading…
x
Reference in New Issue
Block a user