diff --git a/mev_inspect/swaps.py b/mev_inspect/swaps.py index 9181f9d..b447687 100644 --- a/mev_inspect/swaps.py +++ b/mev_inspect/swaps.py @@ -14,11 +14,6 @@ from mev_inspect.transfers import ( ) -UNISWAP_V2_PAIR_ABI_NAME = "UniswapV2Pair" -UNISWAP_V3_POOL_ABI_NAME = "UniswapV3Pool" -BALANCER_V1_POOL_ABI_NAME = "BPool" - - def get_swaps(traces: List[ClassifiedTrace]) -> List[Swap]: swaps = [] @@ -63,10 +58,6 @@ def _parse_swap( child_transfers: List[ERC20Transfer], ) -> Optional[Swap]: pool_address = trace.to_address - recipient_address = _get_recipient_address(trace) - - if recipient_address is None: - return None transfers_to_pool = filter_transfers(prior_transfers, to_address=pool_address) @@ -74,17 +65,15 @@ def _parse_swap( transfers_to_pool = filter_transfers(child_transfers, to_address=pool_address) if len(transfers_to_pool) == 0: - return None + raise RuntimeError("Expected at least one transfer to pool") - transfers_from_pool_to_recipient = filter_transfers( - child_transfers, to_address=recipient_address, from_address=pool_address - ) + transfers_from_pool = filter_transfers(child_transfers, from_address=pool_address) - if len(transfers_from_pool_to_recipient) != 1: - return None + if len(transfers_from_pool) != 1: + raise RuntimeError("Expected exactly one transfer from pool") transfer_in = transfers_to_pool[-1] - transfer_out = transfers_from_pool_to_recipient[0] + transfer_out = transfers_from_pool[0] return Swap( abi_name=trace.abi_name, @@ -100,22 +89,3 @@ def _parse_swap( token_out_amount=transfer_out.amount, error=trace.error, ) - - -def _get_recipient_address(trace: ClassifiedTrace) -> Optional[str]: - if trace.abi_name == UNISWAP_V3_POOL_ABI_NAME: - return ( - trace.inputs["recipient"] - if trace.inputs is not None and "recipient" in trace.inputs - else trace.from_address - ) - elif trace.abi_name == UNISWAP_V2_PAIR_ABI_NAME: - return ( - trace.inputs["to"] - if trace.inputs is not None and "to" in trace.inputs - else trace.from_address - ) - elif trace.abi_name == BALANCER_V1_POOL_ABI_NAME: - return trace.from_address - else: - return None diff --git a/tests/test_arbitrages.py b/tests/test_arbitrages.py index 18af3bd..fdb2200 100644 --- a/tests/test_arbitrages.py +++ b/tests/test_arbitrages.py @@ -1,9 +1,9 @@ from mev_inspect.arbitrages import get_arbitrages from mev_inspect.schemas.swaps import Swap -from mev_inspect.swaps import ( - UNISWAP_V2_PAIR_ABI_NAME, - UNISWAP_V3_POOL_ABI_NAME, -) + + +UNISWAP_V2_PAIR_ABI_NAME = "UniswapV2Pair" +UNISWAP_V3_POOL_ABI_NAME = "UniswapV3Pool" def test_two_pool_arbitrage(get_transaction_hashes, get_addresses): diff --git a/tests/test_swaps.py b/tests/test_swaps.py index e4ec6c8..9ca81f1 100644 --- a/tests/test_swaps.py +++ b/tests/test_swaps.py @@ -1,9 +1,4 @@ -from mev_inspect.swaps import ( - get_swaps, - UNISWAP_V2_PAIR_ABI_NAME, - UNISWAP_V3_POOL_ABI_NAME, - BALANCER_V1_POOL_ABI_NAME, -) +from mev_inspect.swaps import get_swaps from .helpers import ( make_unknown_trace, @@ -12,6 +7,11 @@ from .helpers import ( ) +UNISWAP_V2_PAIR_ABI_NAME = "UniswapV2Pair" +UNISWAP_V3_POOL_ABI_NAME = "UniswapV3Pool" +BALANCER_V1_POOL_ABI_NAME = "BPool" + + def test_swaps( get_transaction_hashes, get_addresses,