diff --git a/mev_inspect/classifiers/specs/__init__.py b/mev_inspect/classifiers/specs/__init__.py index e6f3fda..b5dd98e 100644 --- a/mev_inspect/classifiers/specs/__init__.py +++ b/mev_inspect/classifiers/specs/__init__.py @@ -1,3 +1,8 @@ +from typing import Dict, Optional, Tuple, Type + +from mev_inspect.schemas.classified_traces import DecodedCallTrace, Protocol +from mev_inspect.schemas.classifiers import ClassifierSpec, Classifier + from .aave import AAVE_CLASSIFIER_SPECS from .curve import CURVE_CLASSIFIER_SPECS from .erc20 import ERC20_CLASSIFIER_SPECS @@ -16,3 +21,19 @@ ALL_CLASSIFIER_SPECS = ( + ZEROX_CLASSIFIER_SPECS + BALANCER_CLASSIFIER_SPECS ) + +_SPECS_BY_ABI_NAME_AND_PROTOCOL: Dict[ + Tuple[str, Optional[Protocol]], ClassifierSpec +] = {(spec.abi_name, spec.protocol): spec for spec in ALL_CLASSIFIER_SPECS} + + +def get_classifier( + trace: DecodedCallTrace, +) -> Optional[Type[Classifier]]: + abi_name_and_protocol = (trace.abi_name, trace.protocol) + spec = _SPECS_BY_ABI_NAME_AND_PROTOCOL.get(abi_name_and_protocol) + + if spec is not None: + return spec.classifiers.get(trace.function_signature) + + return None diff --git a/mev_inspect/classifiers/specs/balancer.py b/mev_inspect/classifiers/specs/balancer.py index 9b05dfd..3d52978 100644 --- a/mev_inspect/classifiers/specs/balancer.py +++ b/mev_inspect/classifiers/specs/balancer.py @@ -1,4 +1,5 @@ from mev_inspect.schemas.classified_traces import ( + DecodedCallTrace, Protocol, ) from mev_inspect.schemas.classifiers import ( @@ -7,13 +8,19 @@ from mev_inspect.schemas.classifiers import ( ) +class BalancerSwapClassifier(SwapClassifier): + @staticmethod + def get_swap_recipient(trace: DecodedCallTrace) -> str: + return trace.from_address + + BALANCER_V1_SPECS = [ ClassifierSpec( abi_name="BPool", protocol=Protocol.balancer_v1, classifiers={ - "swapExactAmountIn(address,uint256,address,uint256,uint256)": SwapClassifier, - "swapExactAmountOut(address,uint256,address,uint256,uint256)": SwapClassifier, + "swapExactAmountIn(address,uint256,address,uint256,uint256)": BalancerSwapClassifier, + "swapExactAmountOut(address,uint256,address,uint256,uint256)": BalancerSwapClassifier, }, ), ] diff --git a/mev_inspect/classifiers/specs/uniswap.py b/mev_inspect/classifiers/specs/uniswap.py index ad6d146..f136174 100644 --- a/mev_inspect/classifiers/specs/uniswap.py +++ b/mev_inspect/classifiers/specs/uniswap.py @@ -1,4 +1,5 @@ from mev_inspect.schemas.classified_traces import ( + DecodedCallTrace, Protocol, ) from mev_inspect.schemas.classifiers import ( @@ -7,6 +8,24 @@ from mev_inspect.schemas.classifiers import ( ) +class UniswapV3SwapClassifier(SwapClassifier): + @staticmethod + def get_swap_recipient(trace: DecodedCallTrace) -> str: + if trace.inputs is not None and "recipient" in trace.inputs: + return trace.inputs["recipient"] + else: + return trace.from_address + + +class UniswapV2SwapClassifier(SwapClassifier): + @staticmethod + def get_swap_recipient(trace: DecodedCallTrace) -> str: + if trace.inputs is not None and "to" in trace.inputs: + return trace.inputs["to"] + else: + return trace.from_address + + UNISWAP_V3_CONTRACT_SPECS = [ ClassifierSpec( abi_name="UniswapV3Factory", @@ -69,7 +88,7 @@ UNISWAP_V3_GENERAL_SPECS = [ ClassifierSpec( abi_name="UniswapV3Pool", classifiers={ - "swap(address,bool,int256,uint160,bytes)": SwapClassifier, + "swap(address,bool,int256,uint160,bytes)": UniswapV3SwapClassifier, }, ), ClassifierSpec( @@ -100,7 +119,7 @@ UNISWAPPY_V2_CONTRACT_SPECS = [ UNISWAPPY_V2_PAIR_SPEC = ClassifierSpec( abi_name="UniswapV2Pair", classifiers={ - "swap(uint256,uint256,address,bytes)": SwapClassifier, + "swap(uint256,uint256,address,bytes)": UniswapV2SwapClassifier, }, ) diff --git a/mev_inspect/schemas/classifiers.py b/mev_inspect/schemas/classifiers.py index 5cee8d0..0b8db14 100644 --- a/mev_inspect/schemas/classifiers.py +++ b/mev_inspect/schemas/classifiers.py @@ -30,6 +30,11 @@ class SwapClassifier(Classifier): def get_classification() -> Classification: return Classification.swap + @staticmethod + @abstractmethod + def get_swap_recipient(trace: DecodedCallTrace) -> str: + raise NotImplementedError() + class LiquidationClassifier(Classifier): @staticmethod diff --git a/mev_inspect/swaps.py b/mev_inspect/swaps.py index 9181f9d..906759b 100644 --- a/mev_inspect/swaps.py +++ b/mev_inspect/swaps.py @@ -1,24 +1,23 @@ from typing import List, Optional +from mev_inspect.classifiers.specs import get_classifier from mev_inspect.schemas.classified_traces import ( ClassifiedTrace, Classification, + DecodedCallTrace, ) +from mev_inspect.schemas.classifiers import SwapClassifier from mev_inspect.schemas.swaps import Swap from mev_inspect.schemas.transfers import ERC20Transfer from mev_inspect.traces import get_traces_by_transaction_hash from mev_inspect.transfers import ( + get_transfer, get_child_transfers, filter_transfers, remove_child_transfers_of_transfers, ) -UNISWAP_V2_PAIR_ABI_NAME = "UniswapV2Pair" -UNISWAP_V3_POOL_ABI_NAME = "UniswapV3Pool" -BALANCER_V1_POOL_ABI_NAME = "BPool" - - def get_swaps(traces: List[ClassifiedTrace]) -> List[Swap]: swaps = [] @@ -35,8 +34,13 @@ def _get_swaps_for_transaction(traces: List[ClassifiedTrace]) -> List[Swap]: prior_transfers: List[ERC20Transfer] = [] for trace in ordered_traces: - if trace.classification == Classification.transfer: - prior_transfers.append(ERC20Transfer.from_trace(trace)) + if not isinstance(trace, DecodedCallTrace): + continue + + elif trace.classification == Classification.transfer: + transfer = get_transfer(trace) + if transfer is not None: + prior_transfers.append(transfer) elif trace.classification == Classification.swap: child_transfers = get_child_transfers( @@ -58,7 +62,7 @@ def _get_swaps_for_transaction(traces: List[ClassifiedTrace]) -> List[Swap]: def _parse_swap( - trace: ClassifiedTrace, + trace: DecodedCallTrace, prior_transfers: List[ERC20Transfer], child_transfers: List[ERC20Transfer], ) -> Optional[Swap]: @@ -102,20 +106,9 @@ def _parse_swap( ) -def _get_recipient_address(trace: ClassifiedTrace) -> Optional[str]: - if trace.abi_name == UNISWAP_V3_POOL_ABI_NAME: - return ( - trace.inputs["recipient"] - if trace.inputs is not None and "recipient" in trace.inputs - else trace.from_address - ) - elif trace.abi_name == UNISWAP_V2_PAIR_ABI_NAME: - return ( - trace.inputs["to"] - if trace.inputs is not None and "to" in trace.inputs - else trace.from_address - ) - elif trace.abi_name == BALANCER_V1_POOL_ABI_NAME: - return trace.from_address - else: - return None +def _get_recipient_address(trace: DecodedCallTrace) -> Optional[str]: + classifier = get_classifier(trace) + if classifier is not None and issubclass(classifier, SwapClassifier): + return classifier.get_swap_recipient(trace) + + return None diff --git a/mev_inspect/transfers.py b/mev_inspect/transfers.py index 90959cd..b1dc11a 100644 --- a/mev_inspect/transfers.py +++ b/mev_inspect/transfers.py @@ -1,12 +1,11 @@ -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence -from mev_inspect.classifiers.specs import ALL_CLASSIFIER_SPECS -from mev_inspect.schemas.classifiers import ClassifierSpec, TransferClassifier +from mev_inspect.classifiers.specs import get_classifier +from mev_inspect.schemas.classifiers import TransferClassifier from mev_inspect.schemas.classified_traces import ( Classification, ClassifiedTrace, DecodedCallTrace, - Protocol, ) from mev_inspect.schemas.transfers import ERC20Transfer, EthTransfer, TransferGeneric from mev_inspect.traces import is_child_trace_address, get_child_traces @@ -25,25 +24,26 @@ def get_eth_transfers(traces: List[ClassifiedTrace]) -> List[EthTransfer]: def get_transfers(traces: List[ClassifiedTrace]) -> List[ERC20Transfer]: transfers = [] - specs_by_abi_name_and_protocol: Dict[ - Tuple[str, Optional[Protocol]], ClassifierSpec - ] = {(spec.abi_name, spec.protocol): spec for spec in ALL_CLASSIFIER_SPECS} - for trace in traces: - if not isinstance(trace, DecodedCallTrace): - continue - - abi_name_and_protocol = (trace.abi_name, trace.protocol) - spec = specs_by_abi_name_and_protocol.get(abi_name_and_protocol) - - if spec is not None: - classifier = spec.classifiers.get(trace.function_signature) - if classifier is not None and issubclass(classifier, TransferClassifier): - transfers.append(classifier.get_transfer(trace)) + if isinstance(trace, DecodedCallTrace): + transfer = get_transfer(trace) + if transfer is not None: + transfers.append(transfer) return transfers +def get_transfer(trace: DecodedCallTrace) -> Optional[ERC20Transfer]: + if not isinstance(trace, DecodedCallTrace): + return None + + classifier = get_classifier(trace) + if classifier is not None and issubclass(classifier, TransferClassifier): + return classifier.get_transfer(trace) + + return None + + def get_child_transfers( transaction_hash: str, parent_trace_address: List[int], @@ -52,8 +52,12 @@ def get_child_transfers( child_transfers = [] for child_trace in get_child_traces(transaction_hash, parent_trace_address, traces): - if child_trace.classification == Classification.transfer: - child_transfers.append(ERC20Transfer.from_trace(child_trace)) + if child_trace.classification == Classification.transfer and isinstance( + child_trace, DecodedCallTrace + ): + transfer = get_transfer(child_trace) + if transfer is not None: + child_transfers.append(transfer) return child_transfers