diff --git a/mev_inspect/swaps.py b/mev_inspect/swaps.py index dda85e4..6e754b7 100644 --- a/mev_inspect/swaps.py +++ b/mev_inspect/swaps.py @@ -29,7 +29,7 @@ def get_swaps(traces: List[ClassifiedTrace]) -> List[Swap]: elif trace.classification == Classification.swap: child_transfers = get_child_transfers(trace.trace_address, traces) - swap = _build_swap( + swap = _parse_swap( trace, remove_inner_transfers(prior_transfers), remove_inner_transfers(child_transfers), @@ -41,79 +41,29 @@ def get_swaps(traces: List[ClassifiedTrace]) -> List[Swap]: return swaps -def _build_swap( +def _parse_swap( trace: ClassifiedTrace, prior_transfers: List[Transfer], child_transfers: List[Transfer], -) -> Optional[Swap]: - if trace.abi_name == UNISWAP_V2_PAIR_ABI_NAME: - return _parse_uniswap_v2_swap(trace, prior_transfers, child_transfers) - elif trace.abi_name == UNISWAP_V3_POOL_ABI_NAME: - return _parse_uniswap_v3_swap(trace, child_transfers) - - return None - - -def _parse_uniswap_v3_swap( - trace: ClassifiedTrace, - child_transfers: List[Transfer], ) -> Optional[Swap]: pool_address = trace.to_address - recipient_address = ( - trace.inputs["recipient"] - if trace.inputs is not None and "recipient" in trace.inputs - else trace.from_address - ) + recipient_address = _get_recipient_address(trace) - transfers_to_pool = filter_transfers(child_transfers, to_address=pool_address) - transfers_from_pool_to_recipient = filter_transfers( - child_transfers, to_address=recipient_address, from_address=pool_address - ) - - if len(transfers_to_pool) == 0: + if recipient_address is None: return None - if len(transfers_from_pool_to_recipient) != 1: - return None - - transfer_in = transfers_to_pool[-1] - transfer_out = transfers_from_pool_to_recipient[0] - - return Swap( - abi_name=UNISWAP_V3_POOL_ABI_NAME, - transaction_hash=trace.transaction_hash, - trace_address=trace.trace_address, - pool_address=pool_address, - from_address=transfer_in.from_address, - to_address=transfer_out.to_address, - token_in_address=transfer_in.token_address, - token_in_amount=transfer_in.amount, - token_out_address=transfer_out.token_address, - token_out_amount=transfer_out.amount, - ) - - -def _parse_uniswap_v2_swap( - trace: ClassifiedTrace, - prior_transfers: List[Transfer], - child_transfers: List[Transfer], -) -> Optional[Swap]: - - pool_address = trace.to_address - recipient_address = ( - trace.inputs["to"] - if trace.inputs is not None and "to" in trace.inputs - else trace.from_address - ) - transfers_to_pool = filter_transfers(prior_transfers, to_address=pool_address) - transfers_from_pool_to_recipient = filter_transfers( - child_transfers, to_address=recipient_address, from_address=pool_address - ) + + if len(transfers_to_pool) == 0: + transfers_to_pool = filter_transfers(child_transfers, to_address=pool_address) if len(transfers_to_pool) == 0: return None + transfers_from_pool_to_recipient = filter_transfers( + child_transfers, to_address=recipient_address, from_address=pool_address + ) + if len(transfers_from_pool_to_recipient) != 1: return None @@ -121,7 +71,7 @@ def _parse_uniswap_v2_swap( transfer_out = transfers_from_pool_to_recipient[0] return Swap( - abi_name=UNISWAP_V2_PAIR_ABI_NAME, + abi_name=trace.abi_name, transaction_hash=trace.transaction_hash, trace_address=trace.trace_address, pool_address=pool_address, @@ -132,3 +82,20 @@ def _parse_uniswap_v2_swap( token_out_address=transfer_out.token_address, token_out_amount=transfer_out.amount, ) + + +def _get_recipient_address(trace: ClassifiedTrace) -> Optional[str]: + if trace.abi_name == UNISWAP_V3_POOL_ABI_NAME: + return ( + trace.inputs["recipient"] + if trace.inputs is not None and "recipient" in trace.inputs + else trace.from_address + ) + elif trace.abi_name == UNISWAP_V2_PAIR_ABI_NAME: + return ( + trace.inputs["to"] + if trace.inputs is not None and "to" in trace.inputs + else trace.from_address + ) + else: + return None