diff --git a/mev_inspect/strategies/arbitrage.py b/mev_inspect/strategies/arbitrage.py index acf42dd..fdff45c 100644 --- a/mev_inspect/strategies/arbitrage.py +++ b/mev_inspect/strategies/arbitrage.py @@ -12,6 +12,8 @@ from mev_inspect.schemas.strategies import Arbitrage class Transfer(BaseModel): + transaction_hash: str + trace_address: List[int] from_address: str to_address: str amount: int @@ -19,6 +21,7 @@ class Transfer(BaseModel): class Swap(BaseModel): + abi_name: str transaction_hash: str protocol: Optional[Protocol] pool_address: str @@ -67,42 +70,96 @@ def _get_swaps(traces: List[ClassifiedTrace]) -> List[Swap]: ordered_traces = list(sorted(traces, key=lambda t: t.trace_address)) swaps: List[Swap] = [] - prior_transfers = [] + prior_traces: List[ClassifiedTrace] = [] for trace in ordered_traces: - if trace.classification == Classification.transfer: - transfer = _as_transfer(trace) - prior_transfers.append(transfer) - if trace.classification == Classification.swap: child_traces = _get_child_traces(trace.trace_address, traces) - swap = _build_swap(trace, prior_transfers, child_traces) + swap = _build_swap(trace, prior_traces, child_traces) if swap is not None: swaps.append(swap) + prior_traces.append(trace) + return swaps def _build_swap( trace: ClassifiedTrace, - prior_transfers: List[Transfer], + prior_traces: List[ClassifiedTrace], child_traces: List[ClassifiedTrace], ) -> Optional[Swap]: if trace.abi_name == "UniswapV2Pair": - return _parse_uniswap_v2_swap(trace, prior_transfers, child_traces) + return _parse_uniswap_v2_swap(trace, prior_traces, child_traces) + elif trace.abi_name == "UniswapV3Pool": + return _parse_uniswap_v3_swap(trace, child_traces) + + return None + + +def _parse_uniswap_v3_swap( + trace: ClassifiedTrace, + child_traces: List[ClassifiedTrace], +) -> Optional[Swap]: + if trace.inputs is None: + return None + + pool_address = trace.to_address + recipient_address = trace.inputs["recipient"] + + child_transfers = _remove_inner_transfers( + [ + _as_transfer(child_trace) + for child_trace in child_traces + if child_trace.classification == Classification.transfer + ] + ) + + transfers_to_pool = [ + child for child in child_transfers if child.to_address == pool_address + ] + + transfers_from_pool_to_recipient = [ + child + for child in child_transfers + if ( + child.from_address == pool_address and child.to_address == recipient_address + ) + ] + + if len(transfers_to_pool) == 1 and len(transfers_from_pool_to_recipient) == 1: + transfer_in = transfers_to_pool[0] + transfer_out = transfers_from_pool_to_recipient[0] + + return Swap( + abi_name="UniswapV3Pool", + transaction_hash=trace.transaction_hash, + pool_address=pool_address, + from_address=transfer_in.from_address, + to_address=transfer_out.to_address, + token_in_address=transfer_in.token_address, + token_in_amount=transfer_in.amount, + token_out_address=transfer_out.token_address, + token_out_amount=transfer_out.amount, + ) return None def _parse_uniswap_v2_swap( trace: ClassifiedTrace, - prior_transfers: List[Transfer], + prior_traces: List[ClassifiedTrace], child_traces: List[ClassifiedTrace], ) -> Optional[Swap]: pool_address = trace.to_address transfers_to_pool = [ - transfer for transfer in prior_transfers if transfer.to_address == pool_address + _as_transfer(prior_trace) + for prior_trace in prior_traces + if ( + prior_trace.classification == Classification.transfer + and prior_trace.to_address == pool_address + ) ] # expecting a prior transfer to the pool @@ -123,6 +180,7 @@ def _parse_uniswap_v2_swap( pool_internal_transfer = all_pool_internal_transfers[0] return Swap( + abi_name="UniswapV2Pair", transaction_hash=trace.transaction_hash, pool_address=pool_address, from_address=most_recent_transfer_to_pool.from_address, @@ -156,11 +214,36 @@ def _is_child_trace_address( maybe_child_trace_address: List[int], ) -> bool: parent_trace_length = len(parent_trace_address) - return (len(maybe_child_trace_address) == parent_trace_length + 1) and ( + return (len(maybe_child_trace_address) > parent_trace_length) and ( maybe_child_trace_address[:parent_trace_length] == parent_trace_address ) +def _remove_inner_transfers(transfers: List[Transfer]) -> List[Transfer]: + updated_transfers = [] + transfer_trace_addresses: List[List[int]] = [] + + sorted_transfers = sorted(transfers, key=lambda t: t.trace_address) + + for transfer in sorted_transfers: + if not any( + _is_subtrace(parent_address, transfer.trace_address) + for parent_address in transfer_trace_addresses + ): + updated_transfers.append(transfer) + + transfer_trace_addresses.append(transfer.trace_address) + + return updated_transfers + + +def _is_subtrace(parent_trace_address, child_trace_address) -> bool: + return ( + len(child_trace_address) > len(parent_trace_address) + and child_trace_address[: len(parent_trace_address)] == parent_trace_address + ) + + def _as_transfer(trace: ClassifiedTrace) -> Transfer: # todo - this should be enforced at the data level if trace.inputs is None: @@ -168,6 +251,8 @@ def _as_transfer(trace: ClassifiedTrace) -> Transfer: if trace.protocol == Protocol.weth: return Transfer( + transaction_hash=trace.transaction_hash, + trace_address=trace.trace_address, amount=trace.inputs["wad"], to_address=trace.inputs["dst"], from_address=trace.from_address, @@ -175,6 +260,8 @@ def _as_transfer(trace: ClassifiedTrace) -> Transfer: ) else: return Transfer( + transaction_hash=trace.transaction_hash, + trace_address=trace.trace_address, amount=trace.inputs["amount"], to_address=trace.inputs["recipient"], from_address=trace.inputs.get("sender", trace.from_address),