Handle multiple transaction hashes

This commit is contained in:
Luke Van Seters 2021-07-20 18:50:43 -04:00
parent 311f265d1b
commit 59eb9ef514
2 changed files with 54 additions and 17 deletions

View File

@ -1,9 +1,26 @@
from itertools import groupby
from typing import Iterable, List
from mev_inspect.schemas import Trace, NestedTrace
def as_nested_traces(traces: Iterable[Trace]) -> List[NestedTrace]:
nested_traces = []
sorted_by_transaction_hash = sorted(traces, key=_get_transaction_hash)
for _, transaction_traces in groupby(
sorted_by_transaction_hash, _get_transaction_hash
):
nested_traces += _as_nested_traces_by_transaction(transaction_traces)
return nested_traces
def _get_transaction_hash(trace) -> str:
return trace.transaction_hash
def _as_nested_traces_by_transaction(traces: Iterable[Trace]) -> List[NestedTrace]:
"""
Turns a list of Traces into a a tree of NestedTraces
using their trace addresses

View File

@ -10,29 +10,38 @@ DEFAULT_BLOCK_NUMBER = 123
class TestTraces(unittest.TestCase):
def test_nested_traces(self):
trace_addresses = [
[0, 2],
[],
[2],
[0],
[0, 0],
[0, 1],
[1],
[1, 0],
[0, 1, 0],
trace_hash_address_pairs = [
("abc", [0, 2]),
("abc", []),
("abc", [2]),
("abc", [0]),
("abc", [0, 0]),
("abc", [0, 1]),
("abc", [1]),
("efg", []),
("abc", [1, 0]),
("abc", [0, 1, 0]),
("efg", [0]),
]
traces = [build_trace_at_address(address) for address in trace_addresses]
traces = [
build_trace_at_address(hash, address)
for (hash, address) in trace_hash_address_pairs
]
nested_traces = as_nested_traces(traces)
assert len(nested_traces) == 1
root_trace = nested_traces[0]
assert len(nested_traces) == 2
assert_trace_address(root_trace, [])
assert len(root_trace.subtraces) == 3
abc_trace = nested_traces[0]
efg_trace = nested_traces[1]
[trace_0, trace_1, trace_2] = root_trace.subtraces
# abc
assert abc_trace.trace.transaction_hash == "abc"
assert_trace_address(abc_trace, [])
assert len(abc_trace.subtraces) == 3
[trace_0, trace_1, trace_2] = abc_trace.subtraces
assert_trace_address(trace_0, [0])
assert_trace_address(trace_1, [1])
@ -59,15 +68,26 @@ class TestTraces(unittest.TestCase):
assert_trace_address(trace_0_1_0, [0, 1, 0])
assert len(trace_0_1_0.subtraces) == 0
# efg
assert efg_trace.trace.transaction_hash == "efg"
assert_trace_address(efg_trace, [])
assert len(efg_trace.subtraces) == 1
[efg_subtrace] = efg_trace.subtraces
assert_trace_address(efg_subtrace, [0])
assert len(efg_subtrace.subtraces) == 0
def build_trace_at_address(
transaction_hash: str,
trace_address: List[int],
) -> Trace:
return Trace(
# real values
transaction_hash=transaction_hash,
trace_address=trace_address,
# placeholders
transaction_hash="",
action={},
block_hash="",
block_number=DEFAULT_BLOCK_NUMBER,