Simplify smallest logic. Fix tests

This commit is contained in:
Luke Van Seters 2021-12-22 10:26:26 -05:00
parent 46b768c147
commit 17c9b835ac
2 changed files with 70 additions and 80 deletions

View File

@ -48,7 +48,11 @@ def _get_arbitrages_from_swaps(swaps: List[Swap]) -> List[Arbitrage]:
used_swaps: List[Swap] = []
for (start, ends) in start_ends:
route = _get_shortest_route_from_start(start, ends, swaps, used_swaps)
if start in used_swaps:
continue
unused_ends = [end for end in ends if end not in used_swaps]
route = _get_shortest_route(start, unused_ends, swaps)
if route is not None:
start_amount = route[0].token_in_amount
@ -84,42 +88,47 @@ def _get_arbitrages_from_swaps(swaps: List[Swap]) -> List[Arbitrage]:
]
def _get_shortest_route_from_start(
def _get_shortest_route(
start_swap: Swap,
end_swaps: List[Swap],
all_swaps: List[Swap],
used_swaps: List[Swap],
) -> Optional[List[Swap]]:
if start_swap in used_swaps:
# TODO - add max length
for end_swap in end_swaps:
if start_swap.token_out_address == end_swap.token_in_address:
return [start_swap, end_swap]
other_swaps = [
swap for swap in all_swaps if (swap is not start_swap and swap not in end_swaps)
]
if len(other_swaps) == 0:
return None
shortest_route = None
shortest_route_rest = None
for end_swap in end_swaps:
if end_swap in used_swaps:
continue
potential_intermediate_swaps = [
swap
for swap in all_swaps
if (
swap is not start_swap
and swap is not end_swap
and swap not in used_swaps
for next_swap in other_swaps:
if start_swap.token_out_address == next_swap.token_in_address and (
start_swap.contract_address == next_swap.from_address
or start_swap.to_address == next_swap.contract_address
or start_swap.to_address == next_swap.from_address
):
shortest_from_next = _get_shortest_route(
next_swap,
end_swaps,
other_swaps,
)
]
routes = _get_all_routes(
start_swap,
end_swap,
potential_intermediate_swaps,
)
if shortest_from_next is not None and (
shortest_route_rest is None
or len(shortest_from_next) < len(shortest_route_rest)
):
shortest_route_rest = shortest_from_next
for route in routes:
if shortest_route is None or len(route) < len(shortest_route):
shortest_route = route
return shortest_route
if shortest_route_rest is None:
return None
else:
return [start_swap] + shortest_route_rest
def _get_all_start_end_swaps(swaps: List[Swap]) -> List[Tuple[Swap, List[Swap]]]:
@ -134,7 +143,7 @@ def _get_all_start_end_swaps(swaps: List[Swap]) -> List[Tuple[Swap, List[Swap]]]
valid_start_ends: List[Tuple[Swap, List[Swap]]] = []
for index, potential_start_swap in enumerate(swaps):
ends_for_start = []
ends_for_start: List[Swap] = []
remaining_swaps = swaps[:index] + swaps[index + 1 :]
for potential_end_swap in remaining_swaps:
@ -151,36 +160,3 @@ def _get_all_start_end_swaps(swaps: List[Swap]) -> List[Tuple[Swap, List[Swap]]]
valid_start_ends.append((potential_start_swap, ends_for_start))
return valid_start_ends
def _get_all_routes(
start_swap: Swap, end_swap: Swap, other_swaps: List[Swap]
) -> List[List[Swap]]:
"""
Returns all routes (List[Swap]) from start to finish between a start_swap and an end_swap only accounting for token_address_in and token_address_out.
"""
# If the path is complete, return
if start_swap.token_out_address == end_swap.token_in_address:
return [[start_swap, end_swap]]
elif len(other_swaps) == 0:
return []
# Collect all potential next steps, check if valid, recursively find routes from next_step to end_swap
routes: List[List[Swap]] = []
for potential_next_swap in other_swaps:
if start_swap.token_out_address == potential_next_swap.token_in_address and (
start_swap.contract_address == potential_next_swap.from_address
or start_swap.to_address == potential_next_swap.contract_address
or start_swap.to_address == potential_next_swap.from_address
):
remaining_swaps = [
swap for swap in other_swaps if swap != potential_next_swap
]
next_swap_routes = _get_all_routes(
potential_next_swap, end_swap, remaining_swaps
)
if len(next_swap_routes) > 0:
for next_swap_route in next_swap_routes:
next_swap_route.insert(0, start_swap)
routes.append(next_swap_route)
return routes

View File

@ -1,6 +1,6 @@
from typing import List
from mev_inspect.arbitrages import _get_all_routes, get_arbitrages
from mev_inspect.arbitrages import _get_shortest_route, get_arbitrages
from mev_inspect.classifiers.specs.uniswap import (
UNISWAP_V2_PAIR_ABI_NAME,
UNISWAP_V3_POOL_ABI_NAME,
@ -171,39 +171,46 @@ def test_three_pool_arbitrage(get_transaction_hashes, get_addresses):
assert arbitrage.profit_amount == first_token_out_amount - first_token_in_amount
def test_get_all_routes():
def test_get_shortest_route():
# A -> B, B -> A
start_swap = create_generic_swap("0xa", "0xb")
end_swap = create_generic_swap("0xb", "0xa")
routes = _get_all_routes(start_swap, end_swap, [])
assert len(routes) == 1
route = _get_shortest_route(start_swap, [end_swap], [])
assert route is not None
assert len(route) == 2
# A->B, B->C, C->A
start_swap = create_generic_swap("0xa", "0xb")
other_swaps = [create_generic_swap("0xb", "0xc")]
end_swap = create_generic_swap("0xc", "0xa")
routes = _get_all_routes(start_swap, end_swap, other_swaps)
assert len(routes) == 1
route = _get_shortest_route(start_swap, [end_swap], other_swaps)
assert route is not None
assert len(route) == 3
# A->B, B->C, C->A + A->D
other_swaps.append(create_generic_swap("0xa", "0xd"))
routes = _get_all_routes(start_swap, end_swap, other_swaps)
assert len(routes) == 1
route = _get_shortest_route(start_swap, [end_swap], other_swaps)
assert route is not None
assert len(route) == 3
# A->B, B->C, C->A + A->D B->E
other_swaps.append(create_generic_swap("0xb", "0xe"))
routes = _get_all_routes(start_swap, end_swap, other_swaps)
assert len(routes) == 1
route = _get_shortest_route(start_swap, [end_swap], other_swaps)
assert route is not None
assert len(route) == 3
# A->B, B->A, B->C, C->A
other_swaps = [create_generic_swap("0xb", "0xa"), create_generic_swap("0xb", "0xc")]
routes = _get_all_routes(start_swap, end_swap, other_swaps)
assert len(routes) == 1
expect_simple_route = [["0xa", "0xb"], ["0xb", "0xc"], ["0xc", "0xa"]]
assert len(routes[0]) == len(expect_simple_route)
for i in range(len(expect_simple_route)):
assert expect_simple_route[i][0] == routes[0][i].token_in_address
assert expect_simple_route[i][1] == routes[0][i].token_out_address
route = _get_shortest_route(start_swap, [end_swap], other_swaps)
expected_smallest_route = [["0xa", "0xb"], ["0xb", "0xc"], ["0xc", "0xa"]]
assert route is not None
assert len(route) == len(expected_smallest_route)
for i, [expected_token_in, expected_token_out] in enumerate(
expected_smallest_route
):
assert expected_token_in == route[i].token_in_address
assert expected_token_out == route[i].token_out_address
# A->B, B->C, C->D, D->A, B->D
end_swap = create_generic_swap("0xd", "0xa")
@ -212,8 +219,15 @@ def test_get_all_routes():
create_generic_swap("0xc", "0xd"),
create_generic_swap("0xb", "0xd"),
]
routes = _get_all_routes(start_swap, end_swap, other_swaps)
assert len(routes) == 2
expected_smallest_route = [["0xa", "0xb"], ["0xb", "0xd"], ["0xd", "0xa"]]
route = _get_shortest_route(start_swap, [end_swap], other_swaps)
assert len(route) == 3
for i, [expected_token_in, expected_token_out] in enumerate(
expected_smallest_route
):
assert expected_token_in == route[i].token_in_address
assert expected_token_out == route[i].token_out_address
def create_generic_swap(