diff --git a/mev_inspect/aave_liquidations.py b/mev_inspect/aave_liquidations.py index a637c06..ee49048 100644 --- a/mev_inspect/aave_liquidations.py +++ b/mev_inspect/aave_liquidations.py @@ -11,7 +11,7 @@ from mev_inspect.schemas.classified_traces import ( Protocol, ) -from mev_inspect.schemas.transfers import ERC20Transfer +from mev_inspect.schemas.transfers import ERC20Transfer, aTokenTransfer, Transfer from mev_inspect.schemas.liquidations import Liquidation AAVE_CONTRACT_ADDRESSES: List[str] = [ @@ -83,14 +83,30 @@ def _get_payback_token_and_amount( ) -> Tuple[str, int]: """Look for and return liquidator payback from liquidation""" + child: ClassifiedTrace + child_transfer: Transfer + for child in child_traces: - if child.classification == Classification.transfer: + if child.classification == Classification.transfer and isinstance( + child, DecodedCallTrace + ): + if liquidation.inputs["_receiveAToken"]: - child_transfer = ERC20Transfer.from_trace(child) - if ( - child_transfer.to_address == liquidator - ) and child.from_address in AAVE_CONTRACT_ADDRESSES: - return child_transfer.token_address, child_transfer.amount + child_transfer = aTokenTransfer.from_trace(child) + + if ( + child_transfer.to_address == liquidator + ) and child.from_address in AAVE_CONTRACT_ADDRESSES: + return child_transfer.token_address, child_transfer.amount + + else: + + child_transfer = ERC20Transfer.from_trace(child) + + if ( + child_transfer.to_address == liquidator + ) and child.from_address in AAVE_CONTRACT_ADDRESSES: + return child_transfer.token_address, child_transfer.amount return liquidation.inputs["_collateral"], 0 diff --git a/mev_inspect/schemas/transfers.py b/mev_inspect/schemas/transfers.py index e62d185..832b846 100644 --- a/mev_inspect/schemas/transfers.py +++ b/mev_inspect/schemas/transfers.py @@ -2,7 +2,12 @@ from typing import List, TypeVar from pydantic import BaseModel -from .classified_traces import Classification, ClassifiedTrace, Protocol +from .classified_traces import ( + Classification, + ClassifiedTrace, + DecodedCallTrace, + Protocol, +) class Transfer(BaseModel): @@ -31,6 +36,22 @@ class EthTransfer(Transfer): ) +class aTokenTransfer(Transfer): + token_address: str + + @classmethod + def from_trace(cls, trace: DecodedCallTrace) -> "aTokenTransfer": + return cls( + block_number=trace.block_number, + transaction_hash=trace.transaction_hash, + trace_address=trace.trace_address, + amount=trace.inputs["value"], + to_address=trace.inputs["to"], + from_address=trace.inputs["from"], + token_address=trace.to_address, + ) + + class ERC20Transfer(Transfer): token_address: str @@ -39,6 +60,17 @@ class ERC20Transfer(Transfer): if trace.classification != Classification.transfer or trace.inputs is None: raise ValueError("Invalid transfer") + if trace.protocol == Protocol.aave: + return cls( + block_number=trace.block_number, + transaction_hash=trace.transaction_hash, + trace_address=trace.trace_address, + amount=trace.inputs["value"], + to_address=trace.inputs["to"], + from_address=trace.inputs["from"], + token_address=trace.to_address, + ) + if trace.protocol == Protocol.weth: return cls( block_number=trace.block_number, @@ -49,6 +81,7 @@ class ERC20Transfer(Transfer): from_address=trace.from_address, token_address=trace.to_address, ) + else: return cls( block_number=trace.block_number,