diff --git a/mev_inspect/arbitrages.py b/mev_inspect/arbitrages.py
index 786fbd8..893c159 100644
--- a/mev_inspect/arbitrages.py
+++ b/mev_inspect/arbitrages.py
@@ -1,5 +1,5 @@
 from itertools import groupby
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 from mev_inspect.schemas.arbitrages import Arbitrage
 from mev_inspect.schemas.swaps import Swap
@@ -45,14 +45,16 @@ def _get_arbitrages_from_swaps(swaps: List[Swap]) -> List[Arbitrage]:
     if len(start_ends) == 0:
         return []
 
-    # for (start, end) in filtered_start_ends:
-    for (start, end) in start_ends:
-        potential_intermediate_swaps = [
-            swap for swap in swaps if swap is not start and swap is not end
-        ]
-        routes = _get_all_routes(start, end, potential_intermediate_swaps)
+    used_swaps: List[Swap] = []
 
-        for route in routes:
+    for (start, ends) in start_ends:
+        if start in used_swaps:
+            continue
+
+        unused_ends = [end for end in ends if end not in used_swaps]
+        route = _get_shortest_route(start, unused_ends, swaps)
+
+        if route is not None:
             start_amount = route[0].token_in_amount
             end_amount = route[-1].token_out_amount
             profit_amount = end_amount - start_amount
@@ -72,7 +74,10 @@ def _get_arbitrages_from_swaps(swaps: List[Swap]) -> List[Arbitrage]:
                 profit_amount=profit_amount,
                 error=error,
             )
+
             all_arbitrages.append(arb)
+            used_swaps.extend(route)
+
     if len(all_arbitrages) == 1:
         return all_arbitrages
     else:
@@ -83,18 +88,63 @@ def _get_arbitrages_from_swaps(swaps: List[Swap]) -> List[Arbitrage]:
         ]
 
 
-def _get_all_start_end_swaps(swaps: List[Swap]) -> List[Tuple[Swap, Swap]]:
+def _get_shortest_route(
+    start_swap: Swap,
+    end_swaps: List[Swap],
+    all_swaps: List[Swap],
+) -> Optional[List[Swap]]:
+    for end_swap in end_swaps:
+        if start_swap.token_out_address == end_swap.token_in_address:
+            return [start_swap, end_swap]
+
+    other_swaps = [
+        swap for swap in all_swaps if (swap is not start_swap and swap not in end_swaps)
+    ]
+
+    if len(other_swaps) == 0:
+        return None
+
+    shortest_remaining_route = None
+
+    for next_swap in other_swaps:
+        if start_swap.token_out_address == next_swap.token_in_address and (
+            start_swap.contract_address == next_swap.from_address
+            or start_swap.to_address == next_swap.contract_address
+            or start_swap.to_address == next_swap.from_address
+        ):
+            shortest_from_next = _get_shortest_route(
+                next_swap,
+                end_swaps,
+                other_swaps,
+            )
+
+            if shortest_from_next is not None and (
+                shortest_remaining_route is None
+                or len(shortest_from_next) < len(shortest_remaining_route)
+            ):
+                shortest_remaining_route = shortest_from_next
+
+    if shortest_remaining_route is None:
+        return None
+    else:
+        return [start_swap] + shortest_remaining_route
+
+
+def _get_all_start_end_swaps(swaps: List[Swap]) -> List[Tuple[Swap, List[Swap]]]:
     """
-    Gets the set of all possible opening and closing swap pairs in an arbitrage via
+    Gets the set of all possible openings and corresponding closing swaps for an arbitrage via
     - swap[start].token_in == swap[end].token_out
     - swap[start].from_address == swap[end].to_address
     - not swap[start].from_address in all_pool_addresses
     - not swap[end].to_address in all_pool_addresses
     """
     pool_addrs = [swap.contract_address for swap in swaps]
-    valid_start_ends: List[Tuple[Swap, Swap]] = []
+    valid_start_ends: List[Tuple[Swap, List[Swap]]] = []
+
     for index, potential_start_swap in enumerate(swaps):
+        ends_for_start: List[Swap] = []
         remaining_swaps = swaps[:index] + swaps[index + 1 :]
+
         for potential_end_swap in remaining_swaps:
             if (
                 potential_start_swap.token_in_address
@@ -102,38 +152,10 @@ def _get_all_start_end_swaps(swaps: List[Swap]) -> List[Tuple[Swap, Swap]]:
                 and potential_start_swap.from_address == potential_end_swap.to_address
                 and not potential_start_swap.from_address in pool_addrs
             ):
-                valid_start_ends.append((potential_start_swap, potential_end_swap))
+
+                ends_for_start.append(potential_end_swap)
+
+        if len(ends_for_start) > 0:
+            valid_start_ends.append((potential_start_swap, ends_for_start))
+
     return valid_start_ends
-
-
-def _get_all_routes(
-    start_swap: Swap, end_swap: Swap, other_swaps: List[Swap]
-) -> List[List[Swap]]:
-    """
-    Returns all routes (List[Swap]) from start to finish between a start_swap and an end_swap only accounting for token_address_in and token_address_out.
-    """
-    # If the path is complete, return
-    if start_swap.token_out_address == end_swap.token_in_address:
-        return [[start_swap, end_swap]]
-    elif len(other_swaps) == 0:
-        return []
-
-    # Collect all potential next steps, check if valid, recursively find routes from next_step to end_swap
-    routes: List[List[Swap]] = []
-    for potential_next_swap in other_swaps:
-        if start_swap.token_out_address == potential_next_swap.token_in_address and (
-            start_swap.contract_address == potential_next_swap.from_address
-            or start_swap.to_address == potential_next_swap.contract_address
-            or start_swap.to_address == potential_next_swap.from_address
-        ):
-            remaining_swaps = [
-                swap for swap in other_swaps if swap != potential_next_swap
-            ]
-            next_swap_routes = _get_all_routes(
-                potential_next_swap, end_swap, remaining_swaps
-            )
-            if len(next_swap_routes) > 0:
-                for next_swap_route in next_swap_routes:
-                    next_swap_route.insert(0, start_swap)
-                    routes.append(next_swap_route)
-    return routes
diff --git a/tests/test_arbitrage_integration.py b/tests/test_arbitrage_integration.py
index 8437a3c..9dd8b3a 100644
--- a/tests/test_arbitrage_integration.py
+++ b/tests/test_arbitrage_integration.py
@@ -67,7 +67,7 @@ def test_reverting_arbitrage(trace_classifier: TraceClassifier):
     assert len(swaps) == 38
 
     arbitrages = get_arbitrages(list(swaps))
-    assert len(arbitrages) == 21
+    assert len(arbitrages) == 4
 
     arbitrage_1 = [
         arb
diff --git a/tests/test_arbitrages.py b/tests/test_arbitrages.py
index 71bbc44..1a321fa 100644
--- a/tests/test_arbitrages.py
+++ b/tests/test_arbitrages.py
@@ -1,6 +1,6 @@
-from typing import List
+from typing import List, Tuple
 
-from mev_inspect.arbitrages import _get_all_routes, get_arbitrages
+from mev_inspect.arbitrages import _get_shortest_route, get_arbitrages
 from mev_inspect.classifiers.specs.uniswap import (
     UNISWAP_V2_PAIR_ABI_NAME,
     UNISWAP_V3_POOL_ABI_NAME,
@@ -171,39 +171,41 @@ def test_three_pool_arbitrage(get_transaction_hashes, get_addresses):
     assert arbitrage.profit_amount == first_token_out_amount - first_token_in_amount
 
 
-def test_get_all_routes():
+def test_get_shortest_route():
     # A -> B, B -> A
     start_swap = create_generic_swap("0xa", "0xb")
     end_swap = create_generic_swap("0xb", "0xa")
-    routes = _get_all_routes(start_swap, end_swap, [])
-    assert len(routes) == 1
+    shortest_route = _get_shortest_route(start_swap, [end_swap], [])
+    assert shortest_route is not None
+    assert len(shortest_route) == 2
 
     # A->B, B->C, C->A
     start_swap = create_generic_swap("0xa", "0xb")
     other_swaps = [create_generic_swap("0xb", "0xc")]
     end_swap = create_generic_swap("0xc", "0xa")
-    routes = _get_all_routes(start_swap, end_swap, other_swaps)
-    assert len(routes) == 1
+    shortest_route = _get_shortest_route(start_swap, [end_swap], other_swaps)
+    assert shortest_route is not None
+    assert len(shortest_route) == 3
 
     # A->B, B->C, C->A + A->D
     other_swaps.append(create_generic_swap("0xa", "0xd"))
-    routes = _get_all_routes(start_swap, end_swap, other_swaps)
-    assert len(routes) == 1
+    shortest_route = _get_shortest_route(start_swap, [end_swap], other_swaps)
+    assert shortest_route is not None
+    assert len(shortest_route) == 3
 
     # A->B, B->C, C->A + A->D B->E
     other_swaps.append(create_generic_swap("0xb", "0xe"))
-    routes = _get_all_routes(start_swap, end_swap, other_swaps)
-    assert len(routes) == 1
+    shortest_route = _get_shortest_route(start_swap, [end_swap], other_swaps)
+    assert shortest_route is not None
+    assert len(shortest_route) == 3
 
     # A->B, B->A, B->C, C->A
     other_swaps = [create_generic_swap("0xb", "0xa"), create_generic_swap("0xb", "0xc")]
-    routes = _get_all_routes(start_swap, end_swap, other_swaps)
-    assert len(routes) == 1
-    expect_simple_route = [["0xa", "0xb"], ["0xb", "0xc"], ["0xc", "0xa"]]
-    assert len(routes[0]) == len(expect_simple_route)
-    for i in range(len(expect_simple_route)):
-        assert expect_simple_route[i][0] == routes[0][i].token_in_address
-        assert expect_simple_route[i][1] == routes[0][i].token_out_address
+    actual_shortest_route = _get_shortest_route(start_swap, [end_swap], other_swaps)
+    expected_shortest_route = [("0xa", "0xb"), ("0xb", "0xc"), ("0xc", "0xa")]
+
+    assert actual_shortest_route is not None
+    _assert_route_tokens_equal(actual_shortest_route, expected_shortest_route)
 
     # A->B, B->C, C->D, D->A, B->D
     end_swap = create_generic_swap("0xd", "0xa")
@@ -212,8 +214,24 @@ def test_get_all_routes():
         create_generic_swap("0xc", "0xd"),
         create_generic_swap("0xb", "0xd"),
     ]
-    routes = _get_all_routes(start_swap, end_swap, other_swaps)
-    assert len(routes) == 2
+    expected_shortest_route = [("0xa", "0xb"), ("0xb", "0xd"), ("0xd", "0xa")]
+    actual_shortest_route = _get_shortest_route(start_swap, [end_swap], other_swaps)
+
+    assert actual_shortest_route is not None
+    _assert_route_tokens_equal(actual_shortest_route, expected_shortest_route)
+
+
+def _assert_route_tokens_equal(
+    route: List[Swap],
+    expected_token_in_out_pairs: List[Tuple[str, str]],
+) -> None:
+    assert len(route) == len(expected_token_in_out_pairs)
+
+    for i, [expected_token_in, expected_token_out] in enumerate(
+        expected_token_in_out_pairs
+    ):
+        assert expected_token_in == route[i].token_in_address
+        assert expected_token_out == route[i].token_out_address
 
 
 def create_generic_swap(