diff --git a/mev_inspect/traces.py b/mev_inspect/traces.py index 29c8acd..792897e 100644 --- a/mev_inspect/traces.py +++ b/mev_inspect/traces.py @@ -1,9 +1,26 @@ +from itertools import groupby from typing import Iterable, List from mev_inspect.schemas import Trace, NestedTrace def as_nested_traces(traces: Iterable[Trace]) -> List[NestedTrace]: + nested_traces = [] + + sorted_by_transaction_hash = sorted(traces, key=_get_transaction_hash) + for _, transaction_traces in groupby( + sorted_by_transaction_hash, _get_transaction_hash + ): + nested_traces += _as_nested_traces_by_transaction(transaction_traces) + + return nested_traces + + +def _get_transaction_hash(trace) -> str: + return trace.transaction_hash + + +def _as_nested_traces_by_transaction(traces: Iterable[Trace]) -> List[NestedTrace]: """ Turns a list of Traces into a a tree of NestedTraces using their trace addresses diff --git a/tests/trace_test.py b/tests/trace_test.py index 24af48f..dd415aa 100644 --- a/tests/trace_test.py +++ b/tests/trace_test.py @@ -10,29 +10,38 @@ DEFAULT_BLOCK_NUMBER = 123 class TestTraces(unittest.TestCase): def test_nested_traces(self): - trace_addresses = [ - [0, 2], - [], - [2], - [0], - [0, 0], - [0, 1], - [1], - [1, 0], - [0, 1, 0], + trace_hash_address_pairs = [ + ("abc", [0, 2]), + ("abc", []), + ("abc", [2]), + ("abc", [0]), + ("abc", [0, 0]), + ("abc", [0, 1]), + ("abc", [1]), + ("efg", []), + ("abc", [1, 0]), + ("abc", [0, 1, 0]), + ("efg", [0]), ] - traces = [build_trace_at_address(address) for address in trace_addresses] + traces = [ + build_trace_at_address(hash, address) + for (hash, address) in trace_hash_address_pairs + ] nested_traces = as_nested_traces(traces) - assert len(nested_traces) == 1 - root_trace = nested_traces[0] + assert len(nested_traces) == 2 - assert_trace_address(root_trace, []) - assert len(root_trace.subtraces) == 3 + abc_trace = nested_traces[0] + efg_trace = nested_traces[1] - [trace_0, trace_1, trace_2] = root_trace.subtraces + # abc + assert abc_trace.trace.transaction_hash == "abc" + assert_trace_address(abc_trace, []) + assert len(abc_trace.subtraces) == 3 + + [trace_0, trace_1, trace_2] = abc_trace.subtraces assert_trace_address(trace_0, [0]) assert_trace_address(trace_1, [1]) @@ -59,15 +68,26 @@ class TestTraces(unittest.TestCase): assert_trace_address(trace_0_1_0, [0, 1, 0]) assert len(trace_0_1_0.subtraces) == 0 + # efg + assert efg_trace.trace.transaction_hash == "efg" + assert_trace_address(efg_trace, []) + assert len(efg_trace.subtraces) == 1 + + [efg_subtrace] = efg_trace.subtraces + + assert_trace_address(efg_subtrace, [0]) + assert len(efg_subtrace.subtraces) == 0 + def build_trace_at_address( + transaction_hash: str, trace_address: List[int], ) -> Trace: return Trace( # real values + transaction_hash=transaction_hash, trace_address=trace_address, # placeholders - transaction_hash="", action={}, block_hash="", block_number=DEFAULT_BLOCK_NUMBER,