From 17c9b835ac87265bf9a9365e6bcea6ac4dcb32dd Mon Sep 17 00:00:00 2001 From: Luke Van Seters Date: Wed, 22 Dec 2021 10:26:26 -0500 Subject: [PATCH] Simplify smallest logic. Fix tests --- mev_inspect/arbitrages.py | 98 +++++++++++++++------------------------ tests/test_arbitrages.py | 52 +++++++++++++-------- 2 files changed, 70 insertions(+), 80 deletions(-) diff --git a/mev_inspect/arbitrages.py b/mev_inspect/arbitrages.py index 5cedb7f..2c7f158 100644 --- a/mev_inspect/arbitrages.py +++ b/mev_inspect/arbitrages.py @@ -48,7 +48,11 @@ def _get_arbitrages_from_swaps(swaps: List[Swap]) -> List[Arbitrage]: used_swaps: List[Swap] = [] for (start, ends) in start_ends: - route = _get_shortest_route_from_start(start, ends, swaps, used_swaps) + if start in used_swaps: + continue + + unused_ends = [end for end in ends if end not in used_swaps] + route = _get_shortest_route(start, unused_ends, swaps) if route is not None: start_amount = route[0].token_in_amount @@ -84,42 +88,47 @@ def _get_arbitrages_from_swaps(swaps: List[Swap]) -> List[Arbitrage]: ] -def _get_shortest_route_from_start( +def _get_shortest_route( start_swap: Swap, end_swaps: List[Swap], all_swaps: List[Swap], - used_swaps: List[Swap], ) -> Optional[List[Swap]]: - if start_swap in used_swaps: + # TODO - add max length + for end_swap in end_swaps: + if start_swap.token_out_address == end_swap.token_in_address: + return [start_swap, end_swap] + + other_swaps = [ + swap for swap in all_swaps if (swap is not start_swap and swap not in end_swaps) + ] + + if len(other_swaps) == 0: return None - shortest_route = None + shortest_route_rest = None - for end_swap in end_swaps: - if end_swap in used_swaps: - continue - - potential_intermediate_swaps = [ - swap - for swap in all_swaps - if ( - swap is not start_swap - and swap is not end_swap - and swap not in used_swaps + for next_swap in other_swaps: + if start_swap.token_out_address == next_swap.token_in_address and ( + start_swap.contract_address == next_swap.from_address + or start_swap.to_address == next_swap.contract_address + or start_swap.to_address == next_swap.from_address + ): + shortest_from_next = _get_shortest_route( + next_swap, + end_swaps, + other_swaps, ) - ] - routes = _get_all_routes( - start_swap, - end_swap, - potential_intermediate_swaps, - ) + if shortest_from_next is not None and ( + shortest_route_rest is None + or len(shortest_from_next) < len(shortest_route_rest) + ): + shortest_route_rest = shortest_from_next - for route in routes: - if shortest_route is None or len(route) < len(shortest_route): - shortest_route = route - - return shortest_route + if shortest_route_rest is None: + return None + else: + return [start_swap] + shortest_route_rest def _get_all_start_end_swaps(swaps: List[Swap]) -> List[Tuple[Swap, List[Swap]]]: @@ -134,7 +143,7 @@ def _get_all_start_end_swaps(swaps: List[Swap]) -> List[Tuple[Swap, List[Swap]]] valid_start_ends: List[Tuple[Swap, List[Swap]]] = [] for index, potential_start_swap in enumerate(swaps): - ends_for_start = [] + ends_for_start: List[Swap] = [] remaining_swaps = swaps[:index] + swaps[index + 1 :] for potential_end_swap in remaining_swaps: @@ -151,36 +160,3 @@ def _get_all_start_end_swaps(swaps: List[Swap]) -> List[Tuple[Swap, List[Swap]]] valid_start_ends.append((potential_start_swap, ends_for_start)) return valid_start_ends - - -def _get_all_routes( - start_swap: Swap, end_swap: Swap, other_swaps: List[Swap] -) -> List[List[Swap]]: - """ - Returns all routes (List[Swap]) from start to finish between a start_swap and an end_swap only accounting for token_address_in and token_address_out. - """ - # If the path is complete, return - if start_swap.token_out_address == end_swap.token_in_address: - return [[start_swap, end_swap]] - elif len(other_swaps) == 0: - return [] - - # Collect all potential next steps, check if valid, recursively find routes from next_step to end_swap - routes: List[List[Swap]] = [] - for potential_next_swap in other_swaps: - if start_swap.token_out_address == potential_next_swap.token_in_address and ( - start_swap.contract_address == potential_next_swap.from_address - or start_swap.to_address == potential_next_swap.contract_address - or start_swap.to_address == potential_next_swap.from_address - ): - remaining_swaps = [ - swap for swap in other_swaps if swap != potential_next_swap - ] - next_swap_routes = _get_all_routes( - potential_next_swap, end_swap, remaining_swaps - ) - if len(next_swap_routes) > 0: - for next_swap_route in next_swap_routes: - next_swap_route.insert(0, start_swap) - routes.append(next_swap_route) - return routes diff --git a/tests/test_arbitrages.py b/tests/test_arbitrages.py index 71bbc44..7be7b0f 100644 --- a/tests/test_arbitrages.py +++ b/tests/test_arbitrages.py @@ -1,6 +1,6 @@ from typing import List -from mev_inspect.arbitrages import _get_all_routes, get_arbitrages +from mev_inspect.arbitrages import _get_shortest_route, get_arbitrages from mev_inspect.classifiers.specs.uniswap import ( UNISWAP_V2_PAIR_ABI_NAME, UNISWAP_V3_POOL_ABI_NAME, @@ -171,39 +171,46 @@ def test_three_pool_arbitrage(get_transaction_hashes, get_addresses): assert arbitrage.profit_amount == first_token_out_amount - first_token_in_amount -def test_get_all_routes(): +def test_get_shortest_route(): # A -> B, B -> A start_swap = create_generic_swap("0xa", "0xb") end_swap = create_generic_swap("0xb", "0xa") - routes = _get_all_routes(start_swap, end_swap, []) - assert len(routes) == 1 + route = _get_shortest_route(start_swap, [end_swap], []) + assert route is not None + assert len(route) == 2 # A->B, B->C, C->A start_swap = create_generic_swap("0xa", "0xb") other_swaps = [create_generic_swap("0xb", "0xc")] end_swap = create_generic_swap("0xc", "0xa") - routes = _get_all_routes(start_swap, end_swap, other_swaps) - assert len(routes) == 1 + route = _get_shortest_route(start_swap, [end_swap], other_swaps) + assert route is not None + assert len(route) == 3 # A->B, B->C, C->A + A->D other_swaps.append(create_generic_swap("0xa", "0xd")) - routes = _get_all_routes(start_swap, end_swap, other_swaps) - assert len(routes) == 1 + route = _get_shortest_route(start_swap, [end_swap], other_swaps) + assert route is not None + assert len(route) == 3 # A->B, B->C, C->A + A->D B->E other_swaps.append(create_generic_swap("0xb", "0xe")) - routes = _get_all_routes(start_swap, end_swap, other_swaps) - assert len(routes) == 1 + route = _get_shortest_route(start_swap, [end_swap], other_swaps) + assert route is not None + assert len(route) == 3 # A->B, B->A, B->C, C->A other_swaps = [create_generic_swap("0xb", "0xa"), create_generic_swap("0xb", "0xc")] - routes = _get_all_routes(start_swap, end_swap, other_swaps) - assert len(routes) == 1 - expect_simple_route = [["0xa", "0xb"], ["0xb", "0xc"], ["0xc", "0xa"]] - assert len(routes[0]) == len(expect_simple_route) - for i in range(len(expect_simple_route)): - assert expect_simple_route[i][0] == routes[0][i].token_in_address - assert expect_simple_route[i][1] == routes[0][i].token_out_address + route = _get_shortest_route(start_swap, [end_swap], other_swaps) + expected_smallest_route = [["0xa", "0xb"], ["0xb", "0xc"], ["0xc", "0xa"]] + + assert route is not None + assert len(route) == len(expected_smallest_route) + for i, [expected_token_in, expected_token_out] in enumerate( + expected_smallest_route + ): + assert expected_token_in == route[i].token_in_address + assert expected_token_out == route[i].token_out_address # A->B, B->C, C->D, D->A, B->D end_swap = create_generic_swap("0xd", "0xa") @@ -212,8 +219,15 @@ def test_get_all_routes(): create_generic_swap("0xc", "0xd"), create_generic_swap("0xb", "0xd"), ] - routes = _get_all_routes(start_swap, end_swap, other_swaps) - assert len(routes) == 2 + expected_smallest_route = [["0xa", "0xb"], ["0xb", "0xd"], ["0xd", "0xa"]] + route = _get_shortest_route(start_swap, [end_swap], other_swaps) + assert len(route) == 3 + + for i, [expected_token_in, expected_token_out] in enumerate( + expected_smallest_route + ): + assert expected_token_in == route[i].token_in_address + assert expected_token_out == route[i].token_out_address def create_generic_swap(