diff --git a/mev_inspect/strategies/arbitrage.py b/mev_inspect/strategies/arbitrage.py index 40b68c0..07f8551 100644 --- a/mev_inspect/strategies/arbitrage.py +++ b/mev_inspect/strategies/arbitrage.py @@ -1,4 +1,3 @@ -import json from typing import Dict, List, Optional from pydantic import BaseModel @@ -11,6 +10,10 @@ from mev_inspect.schemas.classified_traces import ( from mev_inspect.schemas.strategies import Arbitrage +UNISWAP_V2_PAIR_ABI_NAME = "UniswapV2Pair" +UNISWAP_V3_POOL_ABI_NAME = "UniswapV3Pool" + + class Transfer(BaseModel): transaction_hash: str trace_address: List[int] @@ -62,7 +65,7 @@ def _get_arbitrages_for_transaction( traces: List[ClassifiedTrace], ) -> List[Arbitrage]: swaps = _get_swaps(traces) - print(json.dumps([swap.dict() for swap in swaps], indent=4)) + print(f"Found {len(swaps)} swaps") return [] @@ -70,41 +73,39 @@ def _get_swaps(traces: List[ClassifiedTrace]) -> List[Swap]: ordered_traces = list(sorted(traces, key=lambda t: t.trace_address)) swaps: List[Swap] = [] - prior_traces: List[ClassifiedTrace] = [] + prior_transfers: List[Transfer] = [] for trace in ordered_traces: - if trace.classification == Classification.swap: - child_traces = _get_child_traces(trace.trace_address, traces) - swap = _build_swap(trace, prior_traces, child_traces) + if trace.classification == Classification.transfer: + prior_transfers.append(_as_transfer(trace)) + + elif trace.classification == Classification.swap: + child_transfers = _get_child_transfers(trace.trace_address, traces) + swap = _build_swap(trace, prior_transfers, child_transfers) if swap is not None: swaps.append(swap) - prior_traces.append(trace) - return swaps def _build_swap( trace: ClassifiedTrace, - prior_traces: List[ClassifiedTrace], - child_traces: List[ClassifiedTrace], + prior_transfers: List[Transfer], + child_transfers: List[Transfer], ) -> Optional[Swap]: - if trace.abi_name == "UniswapV2Pair": - return _parse_uniswap_v2_swap(trace, prior_traces, child_traces) - elif trace.abi_name == "UniswapV3Pool": - return _parse_uniswap_v3_swap(trace, child_traces) + if trace.abi_name == UNISWAP_V2_PAIR_ABI_NAME: + return _parse_uniswap_v2_swap(trace, prior_transfers, child_transfers) + elif trace.abi_name == UNISWAP_V3_POOL_ABI_NAME: + return _parse_uniswap_v3_swap(trace, child_transfers) return None def _parse_uniswap_v3_swap( trace: ClassifiedTrace, - child_traces: List[ClassifiedTrace], + child_transfers: List[Transfer], ) -> Optional[Swap]: - if trace.inputs is None: - return None - pool_address = trace.to_address recipient_address = ( trace.inputs["recipient"] @@ -112,88 +113,14 @@ def _parse_uniswap_v3_swap( else trace.from_address ) - child_transfers = _remove_inner_transfers( - [ - _as_transfer(child_trace) - for child_trace in child_traces - if child_trace.classification == Classification.transfer - ] + transfers_to_pool = _filter_transfers(child_transfers, to_address=pool_address) + transfers_from_pool_to_recipient = _filter_transfers( + child_transfers, to_address=recipient_address, from_address=pool_address ) - transfers_to_pool = [ - child for child in child_transfers if child.to_address == pool_address - ] - - transfers_from_pool_to_recipient = [ - child - for child in child_transfers - if ( - child.from_address == pool_address and child.to_address == recipient_address - ) - ] - - if len(transfers_to_pool) == 1 and len(transfers_from_pool_to_recipient) == 1: - transfer_in = transfers_to_pool[0] - transfer_out = transfers_from_pool_to_recipient[0] - - return Swap( - abi_name="UniswapV3Pool", - transaction_hash=trace.transaction_hash, - pool_address=pool_address, - from_address=transfer_in.from_address, - to_address=transfer_out.to_address, - token_in_address=transfer_in.token_address, - token_in_amount=transfer_in.amount, - token_out_address=transfer_out.token_address, - token_out_amount=transfer_out.amount, - ) - - return None - - -def _parse_uniswap_v2_swap( - trace: ClassifiedTrace, - prior_traces: List[ClassifiedTrace], - child_traces: List[ClassifiedTrace], -) -> Optional[Swap]: - - pool_address = trace.to_address - recipient_address = ( - trace.inputs["to"] - if trace.inputs is not None and "to" in trace.inputs - else trace.from_address - ) - - prior_transfers = [ - _as_transfer(prior_trace) - for prior_trace in prior_traces - if prior_trace.classification == Classification.transfer - ] - - child_transfers = [ - _as_transfer(child) - for child in child_traces - if child.classification == Classification.transfer - ] - - transfers_to_pool = [ - transfer for transfer in prior_transfers if transfer.to_address == pool_address - ] - - transfers_from_pool_to_recipient = [ - transfer - for transfer in child_transfers - if ( - transfer.to_address == recipient_address - and transfer.from_address == pool_address - ) - ] - - # expecting a prior transfer to the pool - if len(transfers_to_pool) == 0: + if len(transfers_to_pool) == 1: return None - # expecting exactly one transfer inside the pool if len(transfers_from_pool_to_recipient) != 1: return None @@ -201,7 +128,7 @@ def _parse_uniswap_v2_swap( transfer_out = transfers_from_pool_to_recipient[0] return Swap( - abi_name="UniswapV2Pair", + abi_name=UNISWAP_V3_POOL_ABI_NAME, transaction_hash=trace.transaction_hash, pool_address=pool_address, from_address=transfer_in.from_address, @@ -213,6 +140,59 @@ def _parse_uniswap_v2_swap( ) +def _parse_uniswap_v2_swap( + trace: ClassifiedTrace, + prior_transfers: List[Transfer], + child_transfers: List[Transfer], +) -> Optional[Swap]: + + pool_address = trace.to_address + recipient_address = ( + trace.inputs["to"] + if trace.inputs is not None and "to" in trace.inputs + else trace.from_address + ) + + transfers_to_pool = _filter_transfers(prior_transfers, to_address=pool_address) + transfers_from_pool_to_recipient = _filter_transfers( + child_transfers, to_address=recipient_address, from_address=pool_address + ) + + if len(transfers_to_pool) == 0: + return None + + if len(transfers_from_pool_to_recipient) != 1: + return None + + transfer_in = transfers_to_pool[-1] + transfer_out = transfers_from_pool_to_recipient[0] + + return Swap( + abi_name=UNISWAP_V2_PAIR_ABI_NAME, + transaction_hash=trace.transaction_hash, + pool_address=pool_address, + from_address=transfer_in.from_address, + to_address=transfer_out.to_address, + token_in_address=transfer_in.token_address, + token_in_amount=transfer_in.amount, + token_out_address=transfer_out.token_address, + token_out_amount=transfer_out.amount, + ) + + +def _get_child_transfers( + parent_trace_address: List[int], + traces: List[ClassifiedTrace], +) -> List[Transfer]: + child_transfers = [] + + for child_trace in _get_child_traces(parent_trace_address, traces): + if child_trace.classification == Classification.transfer: + child_transfers.append(_as_transfer(child_trace)) + + return child_transfers + + def _get_child_traces( parent_trace_address: List[int], traces: List[ClassifiedTrace], @@ -265,6 +245,25 @@ def _is_subtrace(parent_trace_address, child_trace_address) -> bool: ) +def _filter_transfers( + transfers: List[Transfer], + to_address: Optional[str] = None, + from_address: Optional[str] = None, +) -> List[Transfer]: + filtered_transfers = [] + + for transfer in transfers: + if to_address is not None and transfer.to_address != to_address: + continue + + if from_address is not None and transfer.from_address != from_address: + continue + + filtered_transfers.append(transfer) + + return filtered_transfers + + def _as_transfer(trace: ClassifiedTrace) -> Transfer: # todo - this should be enforced at the data level if trace.inputs is None: