From f73a34a5ba28e528298286152a79ca46bbc838fc Mon Sep 17 00:00:00 2001 From: Luke Van Seters Date: Mon, 9 Aug 2021 14:21:08 -0400 Subject: [PATCH] Add tests for get_child_traces. Add some helpful fixtures --- mev_inspect/swaps.py | 7 ++- mev_inspect/traces.py | 3 +- mev_inspect/transfers.py | 3 +- tests/test_traces.py | 124 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 133 insertions(+), 4 deletions(-) diff --git a/mev_inspect/swaps.py b/mev_inspect/swaps.py index 6e64f1a..94bf9c5 100644 --- a/mev_inspect/swaps.py +++ b/mev_inspect/swaps.py @@ -44,7 +44,12 @@ def _get_swaps_for_transaction(traces: List[ClassifiedTrace]) -> List[Swap]: prior_transfers.append(Transfer.from_trace(trace)) elif trace.classification == Classification.swap: - child_transfers = get_child_transfers(trace.trace_address, traces) + child_transfers = get_child_transfers( + trace.transaction_hash, + trace.trace_address, + traces, + ) + swap = _parse_swap( trace, remove_inner_transfers(prior_transfers), diff --git a/mev_inspect/traces.py b/mev_inspect/traces.py index 41d2907..459dddc 100644 --- a/mev_inspect/traces.py +++ b/mev_inspect/traces.py @@ -16,6 +16,7 @@ def is_child_trace_address( def get_child_traces( + transaction_hash: str, parent_trace_address: List[int], traces: List[ClassifiedTrace], ) -> List[ClassifiedTrace]: @@ -23,7 +24,7 @@ def get_child_traces( child_traces = [] for trace in ordered_traces: - if is_child_trace_address( + if trace.transaction_hash == transaction_hash and is_child_trace_address( trace.trace_address, parent_trace_address, ): diff --git a/mev_inspect/transfers.py b/mev_inspect/transfers.py index b150e90..924a8d2 100644 --- a/mev_inspect/transfers.py +++ b/mev_inspect/transfers.py @@ -6,12 +6,13 @@ from mev_inspect.traces import is_child_trace_address, get_child_traces def get_child_transfers( + transaction_hash: str, parent_trace_address: List[int], traces: List[ClassifiedTrace], ) -> List[Transfer]: child_transfers = [] - for child_trace in get_child_traces(parent_trace_address, traces): + for child_trace in get_child_traces(transaction_hash, parent_trace_address, traces): if child_trace.classification == Classification.transfer: child_transfers.append(Transfer.from_trace(child_trace)) diff --git a/tests/test_traces.py b/tests/test_traces.py index c0b7978..138120c 100644 --- a/tests/test_traces.py +++ b/tests/test_traces.py @@ -1,4 +1,49 @@ -from mev_inspect.traces import is_child_trace_address +from hashlib import sha3_256 +from typing import List + +import pytest + +from mev_inspect.schemas.blocks import TraceType +from mev_inspect.schemas.classified_traces import Classification, ClassifiedTrace +from mev_inspect.traces import is_child_trace_address, get_child_traces + + +@pytest.fixture(name="get_transaction_hashes") +def fixture_get_transaction_hashes(): + def _get_transaction_hashes(n: int): + return [sha3_256(str(i).encode("utf-8")).hexdigest() for i in range(n)] + + return _get_transaction_hashes + + +def make_unknown_classified_trace( + block_number, + transaction_hash, + trace_address, +): + return ClassifiedTrace( + transaction_hash=transaction_hash, + block_number=block_number, + trace_type=TraceType.call, + trace_address=trace_address, + classification=Classification.unknown, + ) + + +def make_traces( + block_number, + transaction_hash, + trace_addresses, +) -> List[ClassifiedTrace]: + + return [ + make_unknown_classified_trace( + block_number, + transaction_hash, + trace_address, + ) + for trace_address in trace_addresses + ] def test_is_child_trace_address(): @@ -12,3 +57,80 @@ def test_is_child_trace_address(): assert not is_child_trace_address([1], [0]) assert not is_child_trace_address([1, 0], [0]) assert not is_child_trace_address([100, 2, 10], [100, 1]) + + +def test_get_child_traces(get_transaction_hashes): + block_number = 123 + [first_hash, second_hash] = get_transaction_hashes(2) + + traces = [] + + first_hash_trace_addresses = [ + [], + [0], + [0, 0], + [1], + [1, 0], + [1, 0, 0], + [1, 0, 1], + [1, 1], + [1, 2], + ] + + second_hash_trace_addresses = [[], [0], [1], [1, 0], [2]] + + traces += make_traces( + block_number, + first_hash, + first_hash_trace_addresses, + ) + + traces += make_traces( + block_number, + second_hash, + second_hash_trace_addresses, + ) + + assert has_expected_child_traces( + first_hash, + [], + traces, + first_hash_trace_addresses[1:], + ) + + +def has_expected_child_traces( + transaction_hash: str, + parent_trace_address: List[int], + traces: List[ClassifiedTrace], + expected_trace_addresses: List[List[int]], +): + child_traces = get_child_traces( + transaction_hash, + parent_trace_address, + traces, + ) + + distinct_trace_addresses = distinct_lists(expected_trace_addresses) + + if len(child_traces) != len(distinct_trace_addresses): + return False + + for trace in child_traces: + if trace.transaction_hash != transaction_hash: + return False + + if trace.trace_address not in distinct_trace_addresses: + return False + + return True + + +def distinct_lists(list_of_lists: List[List[int]]) -> List[List[int]]: + distinct_so_far = [] + + for list_of_values in list_of_lists: + if list_of_values not in distinct_so_far: + distinct_so_far.append(list_of_values) + + return distinct_so_far