diff --git a/mev_inspect/aave_liquidations.py b/mev_inspect/aave_liquidations.py index 49e4c3c..213bd1c 100644 --- a/mev_inspect/aave_liquidations.py +++ b/mev_inspect/aave_liquidations.py @@ -11,8 +11,8 @@ from mev_inspect.schemas.classified_traces import ( Protocol, ) -from mev_inspect.schemas.transfers import ERC20Transfer from mev_inspect.schemas.liquidations import Liquidation +from mev_inspect.transfers import get_transfer AAVE_CONTRACT_ADDRESSES: List[str] = [ # AAVE Proxy @@ -76,10 +76,12 @@ def _get_liquidator_payback( for child in child_traces: if child.classification == Classification.transfer: - child_transfer = ERC20Transfer.from_trace(child) + child_transfer = get_transfer(child) - if (child_transfer.to_address == liquidator) and ( - child.from_address in AAVE_CONTRACT_ADDRESSES + if ( + child_transfer is not None + and child_transfer.to_address == liquidator + and child.from_address in AAVE_CONTRACT_ADDRESSES ): return child_transfer.amount diff --git a/mev_inspect/classifiers/specs/erc20.py b/mev_inspect/classifiers/specs/erc20.py index aee3179..a84445c 100644 --- a/mev_inspect/classifiers/specs/erc20.py +++ b/mev_inspect/classifiers/specs/erc20.py @@ -3,13 +3,13 @@ from mev_inspect.schemas.classifiers import ( ClassifierSpec, TransferClassifier, ) -from mev_inspect.schemas.transfers import ERC20Transfer +from mev_inspect.schemas.transfers import Transfer class ERC20TransferClassifier(TransferClassifier): @staticmethod - def get_transfer(trace: DecodedCallTrace) -> ERC20Transfer: - return ERC20Transfer( + def get_transfer(trace: DecodedCallTrace) -> Transfer: + return Transfer( block_number=trace.block_number, transaction_hash=trace.transaction_hash, trace_address=trace.trace_address, diff --git a/mev_inspect/classifiers/specs/weth.py b/mev_inspect/classifiers/specs/weth.py index e84f2f3..2184b7b 100644 --- a/mev_inspect/classifiers/specs/weth.py +++ b/mev_inspect/classifiers/specs/weth.py @@ -6,13 +6,13 @@ from mev_inspect.schemas.classifiers import ( DecodedCallTrace, TransferClassifier, ) -from mev_inspect.schemas.transfers import ERC20Transfer +from mev_inspect.schemas.transfers import Transfer class WethTransferClassifier(TransferClassifier): @staticmethod - def get_transfer(trace: DecodedCallTrace) -> ERC20Transfer: - return ERC20Transfer( + def get_transfer(trace: DecodedCallTrace) -> Transfer: + return Transfer( block_number=trace.block_number, transaction_hash=trace.transaction_hash, trace_address=trace.trace_address, diff --git a/mev_inspect/crud/transfers.py b/mev_inspect/crud/transfers.py index eb57ad9..7aa5adb 100644 --- a/mev_inspect/crud/transfers.py +++ b/mev_inspect/crud/transfers.py @@ -2,7 +2,7 @@ import json from typing import List from mev_inspect.models.transfers import TransferModel -from mev_inspect.schemas.transfers import ERC20Transfer +from mev_inspect.schemas.transfers import Transfer def delete_transfers_for_block( @@ -20,7 +20,7 @@ def delete_transfers_for_block( def write_transfers( db_session, - transfers: List[ERC20Transfer], + transfers: List[Transfer], ) -> None: models = [TransferModel(**json.loads(transfer.json())) for transfer in transfers] diff --git a/mev_inspect/schemas/classifiers.py b/mev_inspect/schemas/classifiers.py index 0b8db14..caf4b9f 100644 --- a/mev_inspect/schemas/classifiers.py +++ b/mev_inspect/schemas/classifiers.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Type from pydantic import BaseModel from .classified_traces import Classification, DecodedCallTrace, Protocol -from .transfers import ERC20Transfer +from .transfers import Transfer class Classifier(ABC): @@ -21,7 +21,7 @@ class TransferClassifier(Classifier): @staticmethod @abstractmethod - def get_transfer(trace: DecodedCallTrace) -> ERC20Transfer: + def get_transfer(trace: DecodedCallTrace) -> Transfer: raise NotImplementedError() diff --git a/mev_inspect/schemas/transfers.py b/mev_inspect/schemas/transfers.py index e62d185..9719c2b 100644 --- a/mev_inspect/schemas/transfers.py +++ b/mev_inspect/schemas/transfers.py @@ -1,8 +1,9 @@ -from typing import List, TypeVar +from typing import List from pydantic import BaseModel -from .classified_traces import Classification, ClassifiedTrace, Protocol + +ETH_TOKEN_ADDRESS = "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE" class Transfer(BaseModel): @@ -12,50 +13,4 @@ class Transfer(BaseModel): from_address: str to_address: str amount: int - - -# To preserve the specific Transfer type -TransferGeneric = TypeVar("TransferGeneric", bound="Transfer") - - -class EthTransfer(Transfer): - @classmethod - def from_trace(cls, trace: ClassifiedTrace) -> "EthTransfer": - return cls( - block_number=trace.block_number, - transaction_hash=trace.transaction_hash, - trace_address=trace.trace_address, - amount=trace.value, - to_address=trace.to_address, - from_address=trace.from_address, - ) - - -class ERC20Transfer(Transfer): token_address: str - - @classmethod - def from_trace(cls, trace: ClassifiedTrace) -> "ERC20Transfer": - if trace.classification != Classification.transfer or trace.inputs is None: - raise ValueError("Invalid transfer") - - if trace.protocol == Protocol.weth: - return cls( - block_number=trace.block_number, - transaction_hash=trace.transaction_hash, - trace_address=trace.trace_address, - amount=trace.inputs["wad"], - to_address=trace.inputs["dst"], - from_address=trace.from_address, - token_address=trace.to_address, - ) - else: - return cls( - block_number=trace.block_number, - transaction_hash=trace.transaction_hash, - trace_address=trace.trace_address, - amount=trace.inputs["amount"], - to_address=trace.inputs["recipient"], - from_address=trace.inputs.get("sender", trace.from_address), - token_address=trace.to_address, - ) diff --git a/mev_inspect/swaps.py b/mev_inspect/swaps.py index f6645db..54640e7 100644 --- a/mev_inspect/swaps.py +++ b/mev_inspect/swaps.py @@ -8,11 +8,11 @@ from mev_inspect.schemas.classified_traces import ( ) from mev_inspect.schemas.classifiers import SwapClassifier from mev_inspect.schemas.swaps import Swap -from mev_inspect.schemas.transfers import ERC20Transfer +from mev_inspect.schemas.transfers import Transfer from mev_inspect.traces import get_traces_by_transaction_hash from mev_inspect.transfers import ( get_child_transfers, - get_erc20_transfer, + get_transfer, filter_transfers, remove_child_transfers_of_transfers, ) @@ -31,14 +31,14 @@ def _get_swaps_for_transaction(traces: List[ClassifiedTrace]) -> List[Swap]: ordered_traces = list(sorted(traces, key=lambda t: t.trace_address)) swaps: List[Swap] = [] - prior_transfers: List[ERC20Transfer] = [] + prior_transfers: List[Transfer] = [] for trace in ordered_traces: if not isinstance(trace, DecodedCallTrace): continue elif trace.classification == Classification.transfer: - transfer = get_erc20_transfer(trace) + transfer = get_transfer(trace) if transfer is not None: prior_transfers.append(transfer) @@ -63,8 +63,8 @@ def _get_swaps_for_transaction(traces: List[ClassifiedTrace]) -> List[Swap]: def _parse_swap( trace: DecodedCallTrace, - prior_transfers: List[ERC20Transfer], - child_transfers: List[ERC20Transfer], + prior_transfers: List[Transfer], + child_transfers: List[Transfer], ) -> Optional[Swap]: pool_address = trace.to_address recipient_address = _get_recipient_address(trace) diff --git a/mev_inspect/transfers.py b/mev_inspect/transfers.py index aef588f..865a54e 100644 --- a/mev_inspect/transfers.py +++ b/mev_inspect/transfers.py @@ -3,40 +3,57 @@ from typing import Dict, List, Optional, Sequence from mev_inspect.classifiers.specs import get_classifier from mev_inspect.schemas.classifiers import TransferClassifier from mev_inspect.schemas.classified_traces import ( - Classification, ClassifiedTrace, DecodedCallTrace, ) -from mev_inspect.schemas.transfers import ERC20Transfer, EthTransfer, TransferGeneric +from mev_inspect.schemas.transfers import ETH_TOKEN_ADDRESS, Transfer from mev_inspect.traces import is_child_trace_address, get_child_traces -def get_eth_transfers(traces: List[ClassifiedTrace]) -> List[EthTransfer]: +def get_transfers(traces: List[ClassifiedTrace]) -> List[Transfer]: transfers = [] for trace in traces: - if trace.value is not None and trace.value > 0: - transfers.append(EthTransfer.from_trace(trace)) + transfer = get_transfer(trace) + if transfer is not None: + transfers.append(transfer) return transfers -def get_transfers(traces: List[ClassifiedTrace]) -> List[ERC20Transfer]: - transfers = [] +def get_eth_transfers(traces: List[ClassifiedTrace]) -> List[Transfer]: + transfers = get_transfers(traces) - for trace in traces: - if isinstance(trace, DecodedCallTrace): - transfer = get_erc20_transfer(trace) - if transfer is not None: - transfers.append(transfer) - - return transfers + return [ + transfer + for transfer in transfers + if transfer.token_address == ETH_TOKEN_ADDRESS + ] -def get_erc20_transfer(trace: DecodedCallTrace) -> Optional[ERC20Transfer]: - if not isinstance(trace, DecodedCallTrace): - return None +def get_transfer(trace: ClassifiedTrace) -> Optional[Transfer]: + if trace.value is not None and trace.value > 0: + return _build_eth_transfer(trace) + if isinstance(trace, DecodedCallTrace): + return _build_erc20_transfer(trace) + + return None + + +def _build_eth_transfer(trace: ClassifiedTrace) -> Transfer: + return Transfer( + block_number=trace.block_number, + transaction_hash=trace.transaction_hash, + trace_address=trace.trace_address, + amount=trace.value, + to_address=trace.to_address, + from_address=trace.from_address, + token_address=ETH_TOKEN_ADDRESS, + ) + + +def _build_erc20_transfer(trace: DecodedCallTrace) -> Optional[Transfer]: classifier = get_classifier(trace) if classifier is not None and issubclass(classifier, TransferClassifier): return classifier.get_transfer(trace) @@ -48,25 +65,22 @@ def get_child_transfers( transaction_hash: str, parent_trace_address: List[int], traces: List[ClassifiedTrace], -) -> List[ERC20Transfer]: +) -> List[Transfer]: child_transfers = [] for child_trace in get_child_traces(transaction_hash, parent_trace_address, traces): - if child_trace.classification == Classification.transfer and isinstance( - child_trace, DecodedCallTrace - ): - transfer = get_erc20_transfer(child_trace) - if transfer is not None: - child_transfers.append(transfer) + transfer = get_transfer(child_trace) + if transfer is not None: + child_transfers.append(transfer) return child_transfers def filter_transfers( - transfers: Sequence[TransferGeneric], + transfers: Sequence[Transfer], to_address: Optional[str] = None, from_address: Optional[str] = None, -) -> List[TransferGeneric]: +) -> List[Transfer]: filtered_transfers = [] for transfer in transfers: @@ -82,8 +96,8 @@ def filter_transfers( def remove_child_transfers_of_transfers( - transfers: List[ERC20Transfer], -) -> List[ERC20Transfer]: + transfers: List[Transfer], +) -> List[Transfer]: updated_transfers = [] transfer_addresses_by_transaction: Dict[str, List[List[int]]] = {} diff --git a/tests/test_transfers.py b/tests/test_transfers.py index 6abd889..9f7fc16 100644 --- a/tests/test_transfers.py +++ b/tests/test_transfers.py @@ -1,4 +1,4 @@ -from mev_inspect.schemas.transfers import ERC20Transfer +from mev_inspect.schemas.transfers import Transfer from mev_inspect.transfers import remove_child_transfers_of_transfers @@ -13,7 +13,7 @@ def test_remove_child_transfers_of_transfers(get_transaction_hashes, get_address third_token_address, ] = get_addresses(5) - outer_transfer = ERC20Transfer( + outer_transfer = Transfer( block_number=123, transaction_hash=transaction_hash, trace_address=[0], @@ -23,7 +23,7 @@ def test_remove_child_transfers_of_transfers(get_transaction_hashes, get_address token_address=first_token_address, ) - inner_transfer = ERC20Transfer( + inner_transfer = Transfer( **{ **outer_transfer.dict(), **dict( @@ -33,7 +33,7 @@ def test_remove_child_transfers_of_transfers(get_transaction_hashes, get_address } ) - other_transfer = ERC20Transfer( + other_transfer = Transfer( block_number=123, transaction_hash=transaction_hash, trace_address=[1], @@ -43,7 +43,7 @@ def test_remove_child_transfers_of_transfers(get_transaction_hashes, get_address token_address=third_token_address, ) - separate_transaction_transfer = ERC20Transfer( + separate_transaction_transfer = Transfer( **{ **inner_transfer.dict(), **dict(transaction_hash=other_transaction_hash),