diff --git a/mev_inspect/arbitrages.py b/mev_inspect/arbitrages.py index 893c159..fe3491b 100644 --- a/mev_inspect/arbitrages.py +++ b/mev_inspect/arbitrages.py @@ -92,11 +92,21 @@ def _get_shortest_route( start_swap: Swap, end_swaps: List[Swap], all_swaps: List[Swap], + max_route_length: Optional[int] = None, ) -> Optional[List[Swap]]: + if len(end_swaps) == 0: + return None + + if max_route_length is not None and max_route_length < 2: + return None + for end_swap in end_swaps: if start_swap.token_out_address == end_swap.token_in_address: return [start_swap, end_swap] + if max_route_length is not None and max_route_length == 2: + return None + other_swaps = [ swap for swap in all_swaps if (swap is not start_swap and swap not in end_swaps) ] @@ -105,6 +115,9 @@ def _get_shortest_route( return None shortest_remaining_route = None + max_remaining_route_length = ( + None if max_route_length is None else max_route_length - 1 + ) for next_swap in other_swaps: if start_swap.token_out_address == next_swap.token_in_address and ( @@ -116,6 +129,7 @@ def _get_shortest_route( next_swap, end_swaps, other_swaps, + max_route_length=max_remaining_route_length, ) if shortest_from_next is not None and ( @@ -123,6 +137,7 @@ def _get_shortest_route( or len(shortest_from_next) < len(shortest_remaining_route) ): shortest_remaining_route = shortest_from_next + max_remaining_route_length = len(shortest_from_next) - 1 if shortest_remaining_route is None: return None