diff --git a/mev_inspect/transfers.py b/mev_inspect/transfers.py index 924a8d2..991e3fe 100644 --- a/mev_inspect/transfers.py +++ b/mev_inspect/transfers.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Dict, List, Optional from mev_inspect.schemas.classified_traces import Classification, ClassifiedTrace from mev_inspect.schemas.transfers import Transfer @@ -40,17 +40,23 @@ def filter_transfers( def remove_inner_transfers(transfers: List[Transfer]) -> List[Transfer]: updated_transfers = [] - transfer_trace_addresses: List[List[int]] = [] + transfer_addresses_by_transaction: Dict[str, List[List[int]]] = {} sorted_transfers = sorted(transfers, key=lambda t: t.trace_address) for transfer in sorted_transfers: + existing_addresses = transfer_addresses_by_transaction.get( + transfer.transaction_hash, [] + ) + if not any( is_child_trace_address(transfer.trace_address, parent_address) - for parent_address in transfer_trace_addresses + for parent_address in existing_addresses ): updated_transfers.append(transfer) - transfer_trace_addresses.append(transfer.trace_address) + transfer_addresses_by_transaction[ + transfer.transaction_hash + ] = existing_addresses + [transfer.trace_address] return updated_transfers diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..9391342 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,26 @@ +from hashlib import sha3_256 +from typing import List + +import pytest + + +@pytest.fixture(name="get_transaction_hashes") +def fixture_get_transaction_hashes(): + def _get_transaction_hashes(n: int): + return _hash_with_prefix(n, "transaction_hash") + + return _get_transaction_hashes + + +@pytest.fixture(name="get_addresses") +def fixture_get_addresses(): + def _get_addresses(n: int): + return [f"0x{hash_value[:40]}" for hash_value in _hash_with_prefix(n, "addr")] + + return _get_addresses + + +def _hash_with_prefix(n_hashes: int, prefix: str) -> List[str]: + return [ + sha3_256(f"{prefix}{i}".encode("utf-8")).hexdigest() for i in range(n_hashes) + ] diff --git a/tests/test_traces.py b/tests/test_traces.py index a11a253..489d677 100644 --- a/tests/test_traces.py +++ b/tests/test_traces.py @@ -1,21 +1,10 @@ -from hashlib import sha3_256 from typing import List -import pytest - from mev_inspect.schemas.blocks import TraceType from mev_inspect.schemas.classified_traces import Classification, ClassifiedTrace from mev_inspect.traces import is_child_trace_address, get_child_traces -@pytest.fixture(name="get_transaction_hashes") -def fixture_get_transaction_hashes(): - def _get_transaction_hashes(n: int): - return [sha3_256(str(i).encode("utf-8")).hexdigest() for i in range(n)] - - return _get_transaction_hashes - - def make_unknown_classified_trace( block_number, transaction_hash, diff --git a/tests/test_transfers.py b/tests/test_transfers.py new file mode 100644 index 0000000..d65a753 --- /dev/null +++ b/tests/test_transfers.py @@ -0,0 +1,71 @@ +from mev_inspect.schemas.transfers import Transfer +from mev_inspect.transfers import remove_inner_transfers + + +def test_remove_inner_transfers(get_transaction_hashes, get_addresses): + [transaction_hash, other_transaction_hash] = get_transaction_hashes(2) + + [ + alice_address, + bob_address, + first_token_address, + second_token_address, + third_token_address, + ] = get_addresses(5) + + outer_transfer = Transfer( + transaction_hash=transaction_hash, + trace_address=[0], + from_address=alice_address, + to_address=bob_address, + amount=10, + token_address=first_token_address, + ) + + inner_transfer = Transfer( + **{ + **outer_transfer.dict(), + **dict( + trace_address=[0, 0], + token_address=second_token_address, + ), + } + ) + + other_transfer = Transfer( + transaction_hash=transaction_hash, + trace_address=[1], + from_address=bob_address, + to_address=alice_address, + amount=10, + token_address=third_token_address, + ) + + separate_transaction_transfer = Transfer( + **{ + **inner_transfer.dict(), + **dict(transaction_hash=other_transaction_hash), + } + ) + + transfers = [ + outer_transfer, + inner_transfer, + other_transfer, + separate_transaction_transfer, + ] + + expected_transfers = [ + outer_transfer, + other_transfer, + separate_transaction_transfer, + ] + + removed_transfers = remove_inner_transfers(transfers) + assert _equal_ignoring_order(removed_transfers, expected_transfers) + + +def _equal_ignoring_order(first_list, second_list) -> bool: + return all(first in second_list for first in first_list) and all( + second in first_list for second in second_list + )