Use SwapClassifier for Swap objects

This commit is contained in:
Luke Van Seters 2021-10-07 18:00:13 -04:00
parent 8c6d7ab889
commit 3039f3eed2
6 changed files with 98 additions and 49 deletions

View File

@ -1,3 +1,8 @@
from typing import Dict, Optional, Tuple, Type
from mev_inspect.schemas.classified_traces import DecodedCallTrace, Protocol
from mev_inspect.schemas.classifiers import ClassifierSpec, Classifier
from .aave import AAVE_CLASSIFIER_SPECS
from .curve import CURVE_CLASSIFIER_SPECS
from .erc20 import ERC20_CLASSIFIER_SPECS
@ -16,3 +21,19 @@ ALL_CLASSIFIER_SPECS = (
+ ZEROX_CLASSIFIER_SPECS
+ BALANCER_CLASSIFIER_SPECS
)
_SPECS_BY_ABI_NAME_AND_PROTOCOL: Dict[
Tuple[str, Optional[Protocol]], ClassifierSpec
] = {(spec.abi_name, spec.protocol): spec for spec in ALL_CLASSIFIER_SPECS}
def get_classifier(
trace: DecodedCallTrace,
) -> Optional[Type[Classifier]]:
abi_name_and_protocol = (trace.abi_name, trace.protocol)
spec = _SPECS_BY_ABI_NAME_AND_PROTOCOL.get(abi_name_and_protocol)
if spec is not None:
return spec.classifiers.get(trace.function_signature)
return None

View File

@ -1,4 +1,5 @@
from mev_inspect.schemas.classified_traces import (
DecodedCallTrace,
Protocol,
)
from mev_inspect.schemas.classifiers import (
@ -7,13 +8,19 @@ from mev_inspect.schemas.classifiers import (
)
class BalancerSwapClassifier(SwapClassifier):
@staticmethod
def get_swap_recipient(trace: DecodedCallTrace) -> str:
return trace.from_address
BALANCER_V1_SPECS = [
ClassifierSpec(
abi_name="BPool",
protocol=Protocol.balancer_v1,
classifiers={
"swapExactAmountIn(address,uint256,address,uint256,uint256)": SwapClassifier,
"swapExactAmountOut(address,uint256,address,uint256,uint256)": SwapClassifier,
"swapExactAmountIn(address,uint256,address,uint256,uint256)": BalancerSwapClassifier,
"swapExactAmountOut(address,uint256,address,uint256,uint256)": BalancerSwapClassifier,
},
),
]

View File

@ -1,4 +1,5 @@
from mev_inspect.schemas.classified_traces import (
DecodedCallTrace,
Protocol,
)
from mev_inspect.schemas.classifiers import (
@ -7,6 +8,24 @@ from mev_inspect.schemas.classifiers import (
)
class UniswapV3SwapClassifier(SwapClassifier):
@staticmethod
def get_swap_recipient(trace: DecodedCallTrace) -> str:
if trace.inputs is not None and "recipient" in trace.inputs:
return trace.inputs["recipient"]
else:
return trace.from_address
class UniswapV2SwapClassifier(SwapClassifier):
@staticmethod
def get_swap_recipient(trace: DecodedCallTrace) -> str:
if trace.inputs is not None and "to" in trace.inputs:
return trace.inputs["to"]
else:
return trace.from_address
UNISWAP_V3_CONTRACT_SPECS = [
ClassifierSpec(
abi_name="UniswapV3Factory",
@ -69,7 +88,7 @@ UNISWAP_V3_GENERAL_SPECS = [
ClassifierSpec(
abi_name="UniswapV3Pool",
classifiers={
"swap(address,bool,int256,uint160,bytes)": SwapClassifier,
"swap(address,bool,int256,uint160,bytes)": UniswapV3SwapClassifier,
},
),
ClassifierSpec(
@ -100,7 +119,7 @@ UNISWAPPY_V2_CONTRACT_SPECS = [
UNISWAPPY_V2_PAIR_SPEC = ClassifierSpec(
abi_name="UniswapV2Pair",
classifiers={
"swap(uint256,uint256,address,bytes)": SwapClassifier,
"swap(uint256,uint256,address,bytes)": UniswapV2SwapClassifier,
},
)

View File

@ -30,6 +30,11 @@ class SwapClassifier(Classifier):
def get_classification() -> Classification:
return Classification.swap
@staticmethod
@abstractmethod
def get_swap_recipient(trace: DecodedCallTrace) -> str:
raise NotImplementedError()
class LiquidationClassifier(Classifier):
@staticmethod

View File

@ -1,24 +1,23 @@
from typing import List, Optional
from mev_inspect.classifiers.specs import get_classifier
from mev_inspect.schemas.classified_traces import (
ClassifiedTrace,
Classification,
DecodedCallTrace,
)
from mev_inspect.schemas.classifiers import SwapClassifier
from mev_inspect.schemas.swaps import Swap
from mev_inspect.schemas.transfers import ERC20Transfer
from mev_inspect.traces import get_traces_by_transaction_hash
from mev_inspect.transfers import (
get_transfer,
get_child_transfers,
filter_transfers,
remove_child_transfers_of_transfers,
)
UNISWAP_V2_PAIR_ABI_NAME = "UniswapV2Pair"
UNISWAP_V3_POOL_ABI_NAME = "UniswapV3Pool"
BALANCER_V1_POOL_ABI_NAME = "BPool"
def get_swaps(traces: List[ClassifiedTrace]) -> List[Swap]:
swaps = []
@ -35,8 +34,13 @@ def _get_swaps_for_transaction(traces: List[ClassifiedTrace]) -> List[Swap]:
prior_transfers: List[ERC20Transfer] = []
for trace in ordered_traces:
if trace.classification == Classification.transfer:
prior_transfers.append(ERC20Transfer.from_trace(trace))
if not isinstance(trace, DecodedCallTrace):
continue
elif trace.classification == Classification.transfer:
transfer = get_transfer(trace)
if transfer is not None:
prior_transfers.append(transfer)
elif trace.classification == Classification.swap:
child_transfers = get_child_transfers(
@ -58,7 +62,7 @@ def _get_swaps_for_transaction(traces: List[ClassifiedTrace]) -> List[Swap]:
def _parse_swap(
trace: ClassifiedTrace,
trace: DecodedCallTrace,
prior_transfers: List[ERC20Transfer],
child_transfers: List[ERC20Transfer],
) -> Optional[Swap]:
@ -102,20 +106,9 @@ def _parse_swap(
)
def _get_recipient_address(trace: ClassifiedTrace) -> Optional[str]:
if trace.abi_name == UNISWAP_V3_POOL_ABI_NAME:
return (
trace.inputs["recipient"]
if trace.inputs is not None and "recipient" in trace.inputs
else trace.from_address
)
elif trace.abi_name == UNISWAP_V2_PAIR_ABI_NAME:
return (
trace.inputs["to"]
if trace.inputs is not None and "to" in trace.inputs
else trace.from_address
)
elif trace.abi_name == BALANCER_V1_POOL_ABI_NAME:
return trace.from_address
else:
return None
def _get_recipient_address(trace: DecodedCallTrace) -> Optional[str]:
classifier = get_classifier(trace)
if classifier is not None and issubclass(classifier, SwapClassifier):
return classifier.get_swap_recipient(trace)
return None

View File

@ -1,12 +1,11 @@
from typing import Dict, List, Optional, Sequence, Tuple
from typing import Dict, List, Optional, Sequence
from mev_inspect.classifiers.specs import ALL_CLASSIFIER_SPECS
from mev_inspect.schemas.classifiers import ClassifierSpec, TransferClassifier
from mev_inspect.classifiers.specs import get_classifier
from mev_inspect.schemas.classifiers import TransferClassifier
from mev_inspect.schemas.classified_traces import (
Classification,
ClassifiedTrace,
DecodedCallTrace,
Protocol,
)
from mev_inspect.schemas.transfers import ERC20Transfer, EthTransfer, TransferGeneric
from mev_inspect.traces import is_child_trace_address, get_child_traces
@ -25,25 +24,26 @@ def get_eth_transfers(traces: List[ClassifiedTrace]) -> List[EthTransfer]:
def get_transfers(traces: List[ClassifiedTrace]) -> List[ERC20Transfer]:
transfers = []
specs_by_abi_name_and_protocol: Dict[
Tuple[str, Optional[Protocol]], ClassifierSpec
] = {(spec.abi_name, spec.protocol): spec for spec in ALL_CLASSIFIER_SPECS}
for trace in traces:
if not isinstance(trace, DecodedCallTrace):
continue
abi_name_and_protocol = (trace.abi_name, trace.protocol)
spec = specs_by_abi_name_and_protocol.get(abi_name_and_protocol)
if spec is not None:
classifier = spec.classifiers.get(trace.function_signature)
if classifier is not None and issubclass(classifier, TransferClassifier):
transfers.append(classifier.get_transfer(trace))
if isinstance(trace, DecodedCallTrace):
transfer = get_transfer(trace)
if transfer is not None:
transfers.append(transfer)
return transfers
def get_transfer(trace: DecodedCallTrace) -> Optional[ERC20Transfer]:
if not isinstance(trace, DecodedCallTrace):
return None
classifier = get_classifier(trace)
if classifier is not None and issubclass(classifier, TransferClassifier):
return classifier.get_transfer(trace)
return None
def get_child_transfers(
transaction_hash: str,
parent_trace_address: List[int],
@ -52,8 +52,12 @@ def get_child_transfers(
child_transfers = []
for child_trace in get_child_traces(transaction_hash, parent_trace_address, traces):
if child_trace.classification == Classification.transfer:
child_transfers.append(ERC20Transfer.from_trace(child_trace))
if child_trace.classification == Classification.transfer and isinstance(
child_trace, DecodedCallTrace
):
transfer = get_transfer(child_trace)
if transfer is not None:
child_transfers.append(transfer)
return child_transfers