diff --git a/mev_inspect/classifiers/specs/balancer.py b/mev_inspect/classifiers/specs/balancer.py index 3d52978..01bc463 100644 --- a/mev_inspect/classifiers/specs/balancer.py +++ b/mev_inspect/classifiers/specs/balancer.py @@ -8,6 +8,9 @@ from mev_inspect.schemas.classifiers import ( ) +BALANCER_V1_POOL_ABI_NAME = "BPool" + + class BalancerSwapClassifier(SwapClassifier): @staticmethod def get_swap_recipient(trace: DecodedCallTrace) -> str: @@ -16,7 +19,7 @@ class BalancerSwapClassifier(SwapClassifier): BALANCER_V1_SPECS = [ ClassifierSpec( - abi_name="BPool", + abi_name=BALANCER_V1_POOL_ABI_NAME, protocol=Protocol.balancer_v1, classifiers={ "swapExactAmountIn(address,uint256,address,uint256,uint256)": BalancerSwapClassifier, diff --git a/mev_inspect/classifiers/specs/uniswap.py b/mev_inspect/classifiers/specs/uniswap.py index f136174..1761bee 100644 --- a/mev_inspect/classifiers/specs/uniswap.py +++ b/mev_inspect/classifiers/specs/uniswap.py @@ -8,6 +8,10 @@ from mev_inspect.schemas.classifiers import ( ) +UNISWAP_V2_PAIR_ABI_NAME = "UniswapV2Pair" +UNISWAP_V3_POOL_ABI_NAME = "UniswapV3Pool" + + class UniswapV3SwapClassifier(SwapClassifier): @staticmethod def get_swap_recipient(trace: DecodedCallTrace) -> str: @@ -86,7 +90,7 @@ UNISWAP_V3_CONTRACT_SPECS = [ UNISWAP_V3_GENERAL_SPECS = [ ClassifierSpec( - abi_name="UniswapV3Pool", + abi_name=UNISWAP_V3_POOL_ABI_NAME, classifiers={ "swap(address,bool,int256,uint160,bytes)": UniswapV3SwapClassifier, }, @@ -117,7 +121,7 @@ UNISWAPPY_V2_CONTRACT_SPECS = [ ] UNISWAPPY_V2_PAIR_SPEC = ClassifierSpec( - abi_name="UniswapV2Pair", + abi_name=UNISWAP_V2_PAIR_ABI_NAME, classifiers={ "swap(uint256,uint256,address,bytes)": UniswapV2SwapClassifier, }, diff --git a/tests/helpers.py b/tests/helpers.py index 21f3264..5c2f25d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,11 +1,11 @@ -from typing import List +from typing import List, Optional from mev_inspect.schemas.blocks import TraceType from mev_inspect.schemas.classified_traces import ( Classification, ClassifiedTrace, - CallTrace, DecodedCallTrace, + Protocol, ) @@ -18,7 +18,7 @@ def make_transfer_trace( token_address: str, amount: int, ): - return CallTrace( + return DecodedCallTrace( transaction_hash=transaction_hash, block_number=block_number, type=TraceType.call, @@ -26,6 +26,9 @@ def make_transfer_trace( classification=Classification.transfer, from_address=from_address, to_address=token_address, + abi_name="ERC20", + function_name="transfer", + function_signature="transfer(address,uint256)", inputs={ "recipient": to_address, "amount": amount, @@ -43,6 +46,8 @@ def make_swap_trace( from_address: str, pool_address: str, abi_name: str, + function_signature: str, + protocol: Optional[Protocol], recipient_address: str, recipient_input_key: str, ): @@ -57,9 +62,10 @@ def make_swap_trace( from_address=from_address, to_address=pool_address, function_name="swap", - function_signature="swap()", + function_signature=function_signature, inputs={recipient_input_key: recipient_address}, abi_name=abi_name, + protocol=protocol, block_hash=str(block_number), ) diff --git a/tests/test_arbitrages.py b/tests/test_arbitrages.py index 18af3bd..524aa7f 100644 --- a/tests/test_arbitrages.py +++ b/tests/test_arbitrages.py @@ -1,9 +1,9 @@ from mev_inspect.arbitrages import get_arbitrages -from mev_inspect.schemas.swaps import Swap -from mev_inspect.swaps import ( +from mev_inspect.classifiers.specs.uniswap import ( UNISWAP_V2_PAIR_ABI_NAME, UNISWAP_V3_POOL_ABI_NAME, ) +from mev_inspect.schemas.swaps import Swap def test_two_pool_arbitrage(get_transaction_hashes, get_addresses): diff --git a/tests/test_swaps.py b/tests/test_swaps.py index e4ec6c8..a2a0d29 100644 --- a/tests/test_swaps.py +++ b/tests/test_swaps.py @@ -1,9 +1,10 @@ -from mev_inspect.swaps import ( - get_swaps, +from mev_inspect.swaps import get_swaps +from mev_inspect.classifiers.specs.balancer import BALANCER_V1_POOL_ABI_NAME +from mev_inspect.classifiers.specs.uniswap import ( UNISWAP_V2_PAIR_ABI_NAME, UNISWAP_V3_POOL_ABI_NAME, - BALANCER_V1_POOL_ABI_NAME, ) +from mev_inspect.schemas.classified_traces import Protocol from .helpers import ( make_unknown_trace, @@ -64,6 +65,8 @@ def test_swaps( from_address=alice_address, pool_address=first_pool_address, abi_name=UNISWAP_V2_PAIR_ABI_NAME, + protocol=None, + function_signature="swap(uint256,uint256,address,bytes)", recipient_address=bob_address, recipient_input_key="to", ), @@ -83,6 +86,8 @@ def test_swaps( from_address=bob_address, pool_address=second_pool_address, abi_name=UNISWAP_V3_POOL_ABI_NAME, + protocol=None, + function_signature="swap(address,bool,int256,uint160,bytes)", recipient_address=carl_address, recipient_input_key="recipient", ), @@ -129,6 +134,8 @@ def test_swaps( from_address=bob_address, pool_address=third_pool_address, abi_name=BALANCER_V1_POOL_ABI_NAME, + protocol=Protocol.balancer_v1, + function_signature="swapExactAmountIn(address,uint256,address,uint256,uint256)", recipient_address=bob_address, recipient_input_key="recipient", ),