From 3039f3eed2c5172488e0d9c713d5068e1077302a Mon Sep 17 00:00:00 2001 From: Luke Van Seters Date: Thu, 7 Oct 2021 18:00:13 -0400 Subject: [PATCH 1/2] Use SwapClassifier for Swap objects --- mev_inspect/classifiers/specs/__init__.py | 21 +++++++++++ mev_inspect/classifiers/specs/balancer.py | 11 ++++-- mev_inspect/classifiers/specs/uniswap.py | 23 ++++++++++-- mev_inspect/schemas/classifiers.py | 5 +++ mev_inspect/swaps.py | 43 ++++++++++------------ mev_inspect/transfers.py | 44 ++++++++++++----------- 6 files changed, 98 insertions(+), 49 deletions(-) diff --git a/mev_inspect/classifiers/specs/__init__.py b/mev_inspect/classifiers/specs/__init__.py index e6f3fda..b5dd98e 100644 --- a/mev_inspect/classifiers/specs/__init__.py +++ b/mev_inspect/classifiers/specs/__init__.py @@ -1,3 +1,8 @@ +from typing import Dict, Optional, Tuple, Type + +from mev_inspect.schemas.classified_traces import DecodedCallTrace, Protocol +from mev_inspect.schemas.classifiers import ClassifierSpec, Classifier + from .aave import AAVE_CLASSIFIER_SPECS from .curve import CURVE_CLASSIFIER_SPECS from .erc20 import ERC20_CLASSIFIER_SPECS @@ -16,3 +21,19 @@ ALL_CLASSIFIER_SPECS = ( + ZEROX_CLASSIFIER_SPECS + BALANCER_CLASSIFIER_SPECS ) + +_SPECS_BY_ABI_NAME_AND_PROTOCOL: Dict[ + Tuple[str, Optional[Protocol]], ClassifierSpec +] = {(spec.abi_name, spec.protocol): spec for spec in ALL_CLASSIFIER_SPECS} + + +def get_classifier( + trace: DecodedCallTrace, +) -> Optional[Type[Classifier]]: + abi_name_and_protocol = (trace.abi_name, trace.protocol) + spec = _SPECS_BY_ABI_NAME_AND_PROTOCOL.get(abi_name_and_protocol) + + if spec is not None: + return spec.classifiers.get(trace.function_signature) + + return None diff --git a/mev_inspect/classifiers/specs/balancer.py b/mev_inspect/classifiers/specs/balancer.py index 9b05dfd..3d52978 100644 --- a/mev_inspect/classifiers/specs/balancer.py +++ b/mev_inspect/classifiers/specs/balancer.py @@ -1,4 +1,5 @@ from mev_inspect.schemas.classified_traces import ( + DecodedCallTrace, Protocol, ) from mev_inspect.schemas.classifiers import ( @@ -7,13 +8,19 @@ from mev_inspect.schemas.classifiers import ( ) +class BalancerSwapClassifier(SwapClassifier): + @staticmethod + def get_swap_recipient(trace: DecodedCallTrace) -> str: + return trace.from_address + + BALANCER_V1_SPECS = [ ClassifierSpec( abi_name="BPool", protocol=Protocol.balancer_v1, classifiers={ - "swapExactAmountIn(address,uint256,address,uint256,uint256)": SwapClassifier, - "swapExactAmountOut(address,uint256,address,uint256,uint256)": SwapClassifier, + "swapExactAmountIn(address,uint256,address,uint256,uint256)": BalancerSwapClassifier, + "swapExactAmountOut(address,uint256,address,uint256,uint256)": BalancerSwapClassifier, }, ), ] diff --git a/mev_inspect/classifiers/specs/uniswap.py b/mev_inspect/classifiers/specs/uniswap.py index ad6d146..f136174 100644 --- a/mev_inspect/classifiers/specs/uniswap.py +++ b/mev_inspect/classifiers/specs/uniswap.py @@ -1,4 +1,5 @@ from mev_inspect.schemas.classified_traces import ( + DecodedCallTrace, Protocol, ) from mev_inspect.schemas.classifiers import ( @@ -7,6 +8,24 @@ from mev_inspect.schemas.classifiers import ( ) +class UniswapV3SwapClassifier(SwapClassifier): + @staticmethod + def get_swap_recipient(trace: DecodedCallTrace) -> str: + if trace.inputs is not None and "recipient" in trace.inputs: + return trace.inputs["recipient"] + else: + return trace.from_address + + +class UniswapV2SwapClassifier(SwapClassifier): + @staticmethod + def get_swap_recipient(trace: DecodedCallTrace) -> str: + if trace.inputs is not None and "to" in trace.inputs: + return trace.inputs["to"] + else: + return trace.from_address + + UNISWAP_V3_CONTRACT_SPECS = [ ClassifierSpec( abi_name="UniswapV3Factory", @@ -69,7 +88,7 @@ UNISWAP_V3_GENERAL_SPECS = [ ClassifierSpec( abi_name="UniswapV3Pool", classifiers={ - "swap(address,bool,int256,uint160,bytes)": SwapClassifier, + "swap(address,bool,int256,uint160,bytes)": UniswapV3SwapClassifier, }, ), ClassifierSpec( @@ -100,7 +119,7 @@ UNISWAPPY_V2_CONTRACT_SPECS = [ UNISWAPPY_V2_PAIR_SPEC = ClassifierSpec( abi_name="UniswapV2Pair", classifiers={ - "swap(uint256,uint256,address,bytes)": SwapClassifier, + "swap(uint256,uint256,address,bytes)": UniswapV2SwapClassifier, }, ) diff --git a/mev_inspect/schemas/classifiers.py b/mev_inspect/schemas/classifiers.py index 5cee8d0..0b8db14 100644 --- a/mev_inspect/schemas/classifiers.py +++ b/mev_inspect/schemas/classifiers.py @@ -30,6 +30,11 @@ class SwapClassifier(Classifier): def get_classification() -> Classification: return Classification.swap + @staticmethod + @abstractmethod + def get_swap_recipient(trace: DecodedCallTrace) -> str: + raise NotImplementedError() + class LiquidationClassifier(Classifier): @staticmethod diff --git a/mev_inspect/swaps.py b/mev_inspect/swaps.py index 9181f9d..906759b 100644 --- a/mev_inspect/swaps.py +++ b/mev_inspect/swaps.py @@ -1,24 +1,23 @@ from typing import List, Optional +from mev_inspect.classifiers.specs import get_classifier from mev_inspect.schemas.classified_traces import ( ClassifiedTrace, Classification, + DecodedCallTrace, ) +from mev_inspect.schemas.classifiers import SwapClassifier from mev_inspect.schemas.swaps import Swap from mev_inspect.schemas.transfers import ERC20Transfer from mev_inspect.traces import get_traces_by_transaction_hash from mev_inspect.transfers import ( + get_transfer, get_child_transfers, filter_transfers, remove_child_transfers_of_transfers, ) -UNISWAP_V2_PAIR_ABI_NAME = "UniswapV2Pair" -UNISWAP_V3_POOL_ABI_NAME = "UniswapV3Pool" -BALANCER_V1_POOL_ABI_NAME = "BPool" - - def get_swaps(traces: List[ClassifiedTrace]) -> List[Swap]: swaps = [] @@ -35,8 +34,13 @@ def _get_swaps_for_transaction(traces: List[ClassifiedTrace]) -> List[Swap]: prior_transfers: List[ERC20Transfer] = [] for trace in ordered_traces: - if trace.classification == Classification.transfer: - prior_transfers.append(ERC20Transfer.from_trace(trace)) + if not isinstance(trace, DecodedCallTrace): + continue + + elif trace.classification == Classification.transfer: + transfer = get_transfer(trace) + if transfer is not None: + prior_transfers.append(transfer) elif trace.classification == Classification.swap: child_transfers = get_child_transfers( @@ -58,7 +62,7 @@ def _get_swaps_for_transaction(traces: List[ClassifiedTrace]) -> List[Swap]: def _parse_swap( - trace: ClassifiedTrace, + trace: DecodedCallTrace, prior_transfers: List[ERC20Transfer], child_transfers: List[ERC20Transfer], ) -> Optional[Swap]: @@ -102,20 +106,9 @@ def _parse_swap( ) -def _get_recipient_address(trace: ClassifiedTrace) -> Optional[str]: - if trace.abi_name == UNISWAP_V3_POOL_ABI_NAME: - return ( - trace.inputs["recipient"] - if trace.inputs is not None and "recipient" in trace.inputs - else trace.from_address - ) - elif trace.abi_name == UNISWAP_V2_PAIR_ABI_NAME: - return ( - trace.inputs["to"] - if trace.inputs is not None and "to" in trace.inputs - else trace.from_address - ) - elif trace.abi_name == BALANCER_V1_POOL_ABI_NAME: - return trace.from_address - else: - return None +def _get_recipient_address(trace: DecodedCallTrace) -> Optional[str]: + classifier = get_classifier(trace) + if classifier is not None and issubclass(classifier, SwapClassifier): + return classifier.get_swap_recipient(trace) + + return None diff --git a/mev_inspect/transfers.py b/mev_inspect/transfers.py index 90959cd..b1dc11a 100644 --- a/mev_inspect/transfers.py +++ b/mev_inspect/transfers.py @@ -1,12 +1,11 @@ -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence -from mev_inspect.classifiers.specs import ALL_CLASSIFIER_SPECS -from mev_inspect.schemas.classifiers import ClassifierSpec, TransferClassifier +from mev_inspect.classifiers.specs import get_classifier +from mev_inspect.schemas.classifiers import TransferClassifier from mev_inspect.schemas.classified_traces import ( Classification, ClassifiedTrace, DecodedCallTrace, - Protocol, ) from mev_inspect.schemas.transfers import ERC20Transfer, EthTransfer, TransferGeneric from mev_inspect.traces import is_child_trace_address, get_child_traces @@ -25,25 +24,26 @@ def get_eth_transfers(traces: List[ClassifiedTrace]) -> List[EthTransfer]: def get_transfers(traces: List[ClassifiedTrace]) -> List[ERC20Transfer]: transfers = [] - specs_by_abi_name_and_protocol: Dict[ - Tuple[str, Optional[Protocol]], ClassifierSpec - ] = {(spec.abi_name, spec.protocol): spec for spec in ALL_CLASSIFIER_SPECS} - for trace in traces: - if not isinstance(trace, DecodedCallTrace): - continue - - abi_name_and_protocol = (trace.abi_name, trace.protocol) - spec = specs_by_abi_name_and_protocol.get(abi_name_and_protocol) - - if spec is not None: - classifier = spec.classifiers.get(trace.function_signature) - if classifier is not None and issubclass(classifier, TransferClassifier): - transfers.append(classifier.get_transfer(trace)) + if isinstance(trace, DecodedCallTrace): + transfer = get_transfer(trace) + if transfer is not None: + transfers.append(transfer) return transfers +def get_transfer(trace: DecodedCallTrace) -> Optional[ERC20Transfer]: + if not isinstance(trace, DecodedCallTrace): + return None + + classifier = get_classifier(trace) + if classifier is not None and issubclass(classifier, TransferClassifier): + return classifier.get_transfer(trace) + + return None + + def get_child_transfers( transaction_hash: str, parent_trace_address: List[int], @@ -52,8 +52,12 @@ def get_child_transfers( child_transfers = [] for child_trace in get_child_traces(transaction_hash, parent_trace_address, traces): - if child_trace.classification == Classification.transfer: - child_transfers.append(ERC20Transfer.from_trace(child_trace)) + if child_trace.classification == Classification.transfer and isinstance( + child_trace, DecodedCallTrace + ): + transfer = get_transfer(child_trace) + if transfer is not None: + child_transfers.append(transfer) return child_transfers From a1fd035de85765e23fa0fe132f8f7c25f2aa8501 Mon Sep 17 00:00:00 2001 From: Luke Van Seters Date: Thu, 7 Oct 2021 18:29:11 -0400 Subject: [PATCH 2/2] Update tests --- mev_inspect/classifiers/specs/balancer.py | 5 ++++- mev_inspect/classifiers/specs/uniswap.py | 8 ++++++-- tests/helpers.py | 14 ++++++++++---- tests/test_arbitrages.py | 4 ++-- tests/test_swaps.py | 13 ++++++++++--- 5 files changed, 32 insertions(+), 12 deletions(-) diff --git a/mev_inspect/classifiers/specs/balancer.py b/mev_inspect/classifiers/specs/balancer.py index 3d52978..01bc463 100644 --- a/mev_inspect/classifiers/specs/balancer.py +++ b/mev_inspect/classifiers/specs/balancer.py @@ -8,6 +8,9 @@ from mev_inspect.schemas.classifiers import ( ) +BALANCER_V1_POOL_ABI_NAME = "BPool" + + class BalancerSwapClassifier(SwapClassifier): @staticmethod def get_swap_recipient(trace: DecodedCallTrace) -> str: @@ -16,7 +19,7 @@ class BalancerSwapClassifier(SwapClassifier): BALANCER_V1_SPECS = [ ClassifierSpec( - abi_name="BPool", + abi_name=BALANCER_V1_POOL_ABI_NAME, protocol=Protocol.balancer_v1, classifiers={ "swapExactAmountIn(address,uint256,address,uint256,uint256)": BalancerSwapClassifier, diff --git a/mev_inspect/classifiers/specs/uniswap.py b/mev_inspect/classifiers/specs/uniswap.py index f136174..1761bee 100644 --- a/mev_inspect/classifiers/specs/uniswap.py +++ b/mev_inspect/classifiers/specs/uniswap.py @@ -8,6 +8,10 @@ from mev_inspect.schemas.classifiers import ( ) +UNISWAP_V2_PAIR_ABI_NAME = "UniswapV2Pair" +UNISWAP_V3_POOL_ABI_NAME = "UniswapV3Pool" + + class UniswapV3SwapClassifier(SwapClassifier): @staticmethod def get_swap_recipient(trace: DecodedCallTrace) -> str: @@ -86,7 +90,7 @@ UNISWAP_V3_CONTRACT_SPECS = [ UNISWAP_V3_GENERAL_SPECS = [ ClassifierSpec( - abi_name="UniswapV3Pool", + abi_name=UNISWAP_V3_POOL_ABI_NAME, classifiers={ "swap(address,bool,int256,uint160,bytes)": UniswapV3SwapClassifier, }, @@ -117,7 +121,7 @@ UNISWAPPY_V2_CONTRACT_SPECS = [ ] UNISWAPPY_V2_PAIR_SPEC = ClassifierSpec( - abi_name="UniswapV2Pair", + abi_name=UNISWAP_V2_PAIR_ABI_NAME, classifiers={ "swap(uint256,uint256,address,bytes)": UniswapV2SwapClassifier, }, diff --git a/tests/helpers.py b/tests/helpers.py index 21f3264..5c2f25d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,11 +1,11 @@ -from typing import List +from typing import List, Optional from mev_inspect.schemas.blocks import TraceType from mev_inspect.schemas.classified_traces import ( Classification, ClassifiedTrace, - CallTrace, DecodedCallTrace, + Protocol, ) @@ -18,7 +18,7 @@ def make_transfer_trace( token_address: str, amount: int, ): - return CallTrace( + return DecodedCallTrace( transaction_hash=transaction_hash, block_number=block_number, type=TraceType.call, @@ -26,6 +26,9 @@ def make_transfer_trace( classification=Classification.transfer, from_address=from_address, to_address=token_address, + abi_name="ERC20", + function_name="transfer", + function_signature="transfer(address,uint256)", inputs={ "recipient": to_address, "amount": amount, @@ -43,6 +46,8 @@ def make_swap_trace( from_address: str, pool_address: str, abi_name: str, + function_signature: str, + protocol: Optional[Protocol], recipient_address: str, recipient_input_key: str, ): @@ -57,9 +62,10 @@ def make_swap_trace( from_address=from_address, to_address=pool_address, function_name="swap", - function_signature="swap()", + function_signature=function_signature, inputs={recipient_input_key: recipient_address}, abi_name=abi_name, + protocol=protocol, block_hash=str(block_number), ) diff --git a/tests/test_arbitrages.py b/tests/test_arbitrages.py index 18af3bd..524aa7f 100644 --- a/tests/test_arbitrages.py +++ b/tests/test_arbitrages.py @@ -1,9 +1,9 @@ from mev_inspect.arbitrages import get_arbitrages -from mev_inspect.schemas.swaps import Swap -from mev_inspect.swaps import ( +from mev_inspect.classifiers.specs.uniswap import ( UNISWAP_V2_PAIR_ABI_NAME, UNISWAP_V3_POOL_ABI_NAME, ) +from mev_inspect.schemas.swaps import Swap def test_two_pool_arbitrage(get_transaction_hashes, get_addresses): diff --git a/tests/test_swaps.py b/tests/test_swaps.py index e4ec6c8..a2a0d29 100644 --- a/tests/test_swaps.py +++ b/tests/test_swaps.py @@ -1,9 +1,10 @@ -from mev_inspect.swaps import ( - get_swaps, +from mev_inspect.swaps import get_swaps +from mev_inspect.classifiers.specs.balancer import BALANCER_V1_POOL_ABI_NAME +from mev_inspect.classifiers.specs.uniswap import ( UNISWAP_V2_PAIR_ABI_NAME, UNISWAP_V3_POOL_ABI_NAME, - BALANCER_V1_POOL_ABI_NAME, ) +from mev_inspect.schemas.classified_traces import Protocol from .helpers import ( make_unknown_trace, @@ -64,6 +65,8 @@ def test_swaps( from_address=alice_address, pool_address=first_pool_address, abi_name=UNISWAP_V2_PAIR_ABI_NAME, + protocol=None, + function_signature="swap(uint256,uint256,address,bytes)", recipient_address=bob_address, recipient_input_key="to", ), @@ -83,6 +86,8 @@ def test_swaps( from_address=bob_address, pool_address=second_pool_address, abi_name=UNISWAP_V3_POOL_ABI_NAME, + protocol=None, + function_signature="swap(address,bool,int256,uint160,bytes)", recipient_address=carl_address, recipient_input_key="recipient", ), @@ -129,6 +134,8 @@ def test_swaps( from_address=bob_address, pool_address=third_pool_address, abi_name=BALANCER_V1_POOL_ABI_NAME, + protocol=Protocol.balancer_v1, + function_signature="swapExactAmountIn(address,uint256,address,uint256,uint256)", recipient_address=bob_address, recipient_input_key="recipient", ),