diff --git a/mev_inspect/schemas/strategies.py b/mev_inspect/schemas/strategies.py deleted file mode 100644 index f94dda2..0000000 --- a/mev_inspect/schemas/strategies.py +++ /dev/null @@ -1,5 +0,0 @@ -from pydantic import BaseModel - - -class Arbitrage(BaseModel): - pass diff --git a/mev_inspect/strategies/arbitrage.py b/mev_inspect/strategies/arbitrage.py index 49eeae0..605e3d3 100644 --- a/mev_inspect/strategies/arbitrage.py +++ b/mev_inspect/strategies/arbitrage.py @@ -7,7 +7,6 @@ from mev_inspect.schemas.classified_traces import ( Classification, Protocol, ) -from mev_inspect.schemas.strategies import Arbitrage UNISWAP_V2_PAIR_ABI_NAME = "UniswapV2Pair" @@ -26,6 +25,7 @@ class Transfer(BaseModel): class Swap(BaseModel): abi_name: str transaction_hash: str + trace_address: List[int] protocol: Optional[Protocol] pool_address: str from_address: str @@ -36,6 +36,15 @@ class Swap(BaseModel): token_out_amount: int +class Arbitrage(BaseModel): + swaps: List[Swap] + account_address: str + profit_token_address: str + start_amount: int + end_amount: int + profit_amount: int + + def get_arbitrages(traces: List[ClassifiedTrace]) -> List[Arbitrage]: all_arbitrages = [] traces_by_transaction = _group_traces_by_transaction(traces) @@ -65,10 +74,82 @@ def _get_arbitrages_for_transaction( traces: List[ClassifiedTrace], ) -> List[Arbitrage]: swaps = _get_swaps(traces) - print(f"Found {len(swaps)} swaps") + + if len(swaps) > 1: + return _get_arbitrages_from_swaps(swaps) + return [] +def _get_arbitrages_from_swaps(swaps: List[Swap]) -> List[Arbitrage]: + pool_addresses = {swap.pool_address for swap in swaps} + + all_arbitrages = [] + + for index, first_swap in enumerate(swaps): + other_swaps = swaps[:index] + swaps[index + 1 :] + + if first_swap.from_address not in pool_addresses: + arbitrage = _get_arbitrage_starting_with_swap(first_swap, other_swaps) + + if arbitrage is not None: + all_arbitrages.append(arbitrage) + + return all_arbitrages + + +def _get_arbitrage_starting_with_swap( + start_swap: Swap, + other_swaps: List[Swap], +) -> Optional[Arbitrage]: + swap_path = [start_swap] + + current_address = start_swap.to_address + current_token = start_swap.token_out_address + + while True: + swaps_from_current_address = [] + + for swap in other_swaps: + if ( + swap.pool_address == current_address + and swap.token_in_address == current_token + ): + swaps_from_current_address.append(swap) + + if len(swaps_from_current_address) == 0: + return None + + if len(swaps_from_current_address) > 1: + raise RuntimeError("todo") + + latest_swap = swaps_from_current_address[0] + swap_path.append(latest_swap) + + current_address = latest_swap.to_address + current_token = latest_swap.token_out_address + + if ( + current_address == start_swap.from_address + and current_token == start_swap.token_in_address + ): + + start_amount = start_swap.token_in_amount + end_amount = latest_swap.token_out_amount + profit_amount = end_amount - start_amount + + return Arbitrage( + swaps=swap_path, + account_address=start_swap.from_address, + profit_token_address=start_swap.token_in_address, + start_amount=start_amount, + end_amount=end_amount, + profit_amount=profit_amount, + ) + + return None + + def _get_swaps(traces: List[ClassifiedTrace]) -> List[Swap]: ordered_traces = list(sorted(traces, key=lambda t: t.trace_address)) @@ -81,7 +162,11 @@ def _get_swaps(traces: List[ClassifiedTrace]) -> List[Swap]: elif trace.classification == Classification.swap: child_transfers = _get_child_transfers(trace.trace_address, traces) - swap = _build_swap(trace, prior_transfers, child_transfers) + swap = _build_swap( + trace, + _remove_inner_transfers(prior_transfers), + _remove_inner_transfers(child_transfers), + ) if swap is not None: swaps.append(swap) @@ -130,6 +215,7 @@ def _parse_uniswap_v3_swap( return Swap( abi_name=UNISWAP_V3_POOL_ABI_NAME, transaction_hash=trace.transaction_hash, + trace_address=trace.trace_address, pool_address=pool_address, from_address=transfer_in.from_address, to_address=transfer_out.to_address, @@ -170,6 +256,7 @@ def _parse_uniswap_v2_swap( return Swap( abi_name=UNISWAP_V2_PAIR_ABI_NAME, transaction_hash=trace.transaction_hash, + trace_address=trace.trace_address, pool_address=pool_address, from_address=transfer_in.from_address, to_address=transfer_out.to_address,