Merge pull request #92 from flashbots/swaps-classifier

Use SwapClassifier for decoding Swap objects
This commit is contained in:
Luke Van Seters 2021-10-08 11:47:27 -04:00 committed by GitHub
commit afcff7c845
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 130 additions and 61 deletions

View File

@ -1,3 +1,8 @@
from typing import Dict, Optional, Tuple, Type
from mev_inspect.schemas.classified_traces import DecodedCallTrace, Protocol
from mev_inspect.schemas.classifiers import ClassifierSpec, Classifier
from .aave import AAVE_CLASSIFIER_SPECS
from .curve import CURVE_CLASSIFIER_SPECS
from .erc20 import ERC20_CLASSIFIER_SPECS
@ -16,3 +21,19 @@ ALL_CLASSIFIER_SPECS = (
+ ZEROX_CLASSIFIER_SPECS
+ BALANCER_CLASSIFIER_SPECS
)
_SPECS_BY_ABI_NAME_AND_PROTOCOL: Dict[
Tuple[str, Optional[Protocol]], ClassifierSpec
] = {(spec.abi_name, spec.protocol): spec for spec in ALL_CLASSIFIER_SPECS}
def get_classifier(
trace: DecodedCallTrace,
) -> Optional[Type[Classifier]]:
abi_name_and_protocol = (trace.abi_name, trace.protocol)
spec = _SPECS_BY_ABI_NAME_AND_PROTOCOL.get(abi_name_and_protocol)
if spec is not None:
return spec.classifiers.get(trace.function_signature)
return None

View File

@ -1,4 +1,5 @@
from mev_inspect.schemas.classified_traces import (
DecodedCallTrace,
Protocol,
)
from mev_inspect.schemas.classifiers import (
@ -7,13 +8,22 @@ from mev_inspect.schemas.classifiers import (
)
BALANCER_V1_POOL_ABI_NAME = "BPool"
class BalancerSwapClassifier(SwapClassifier):
@staticmethod
def get_swap_recipient(trace: DecodedCallTrace) -> str:
return trace.from_address
BALANCER_V1_SPECS = [
ClassifierSpec(
abi_name="BPool",
abi_name=BALANCER_V1_POOL_ABI_NAME,
protocol=Protocol.balancer_v1,
classifiers={
"swapExactAmountIn(address,uint256,address,uint256,uint256)": SwapClassifier,
"swapExactAmountOut(address,uint256,address,uint256,uint256)": SwapClassifier,
"swapExactAmountIn(address,uint256,address,uint256,uint256)": BalancerSwapClassifier,
"swapExactAmountOut(address,uint256,address,uint256,uint256)": BalancerSwapClassifier,
},
),
]

View File

@ -1,4 +1,5 @@
from mev_inspect.schemas.classified_traces import (
DecodedCallTrace,
Protocol,
)
from mev_inspect.schemas.classifiers import (
@ -7,6 +8,28 @@ from mev_inspect.schemas.classifiers import (
)
UNISWAP_V2_PAIR_ABI_NAME = "UniswapV2Pair"
UNISWAP_V3_POOL_ABI_NAME = "UniswapV3Pool"
class UniswapV3SwapClassifier(SwapClassifier):
@staticmethod
def get_swap_recipient(trace: DecodedCallTrace) -> str:
if trace.inputs is not None and "recipient" in trace.inputs:
return trace.inputs["recipient"]
else:
return trace.from_address
class UniswapV2SwapClassifier(SwapClassifier):
@staticmethod
def get_swap_recipient(trace: DecodedCallTrace) -> str:
if trace.inputs is not None and "to" in trace.inputs:
return trace.inputs["to"]
else:
return trace.from_address
UNISWAP_V3_CONTRACT_SPECS = [
ClassifierSpec(
abi_name="UniswapV3Factory",
@ -67,9 +90,9 @@ UNISWAP_V3_CONTRACT_SPECS = [
UNISWAP_V3_GENERAL_SPECS = [
ClassifierSpec(
abi_name="UniswapV3Pool",
abi_name=UNISWAP_V3_POOL_ABI_NAME,
classifiers={
"swap(address,bool,int256,uint160,bytes)": SwapClassifier,
"swap(address,bool,int256,uint160,bytes)": UniswapV3SwapClassifier,
},
),
ClassifierSpec(
@ -98,9 +121,9 @@ UNISWAPPY_V2_CONTRACT_SPECS = [
]
UNISWAPPY_V2_PAIR_SPEC = ClassifierSpec(
abi_name="UniswapV2Pair",
abi_name=UNISWAP_V2_PAIR_ABI_NAME,
classifiers={
"swap(uint256,uint256,address,bytes)": SwapClassifier,
"swap(uint256,uint256,address,bytes)": UniswapV2SwapClassifier,
},
)

View File

@ -30,6 +30,11 @@ class SwapClassifier(Classifier):
def get_classification() -> Classification:
return Classification.swap
@staticmethod
@abstractmethod
def get_swap_recipient(trace: DecodedCallTrace) -> str:
raise NotImplementedError()
class LiquidationClassifier(Classifier):
@staticmethod

View File

@ -1,24 +1,23 @@
from typing import List, Optional
from mev_inspect.classifiers.specs import get_classifier
from mev_inspect.schemas.classified_traces import (
ClassifiedTrace,
Classification,
DecodedCallTrace,
)
from mev_inspect.schemas.classifiers import SwapClassifier
from mev_inspect.schemas.swaps import Swap
from mev_inspect.schemas.transfers import ERC20Transfer
from mev_inspect.traces import get_traces_by_transaction_hash
from mev_inspect.transfers import (
get_transfer,
get_child_transfers,
filter_transfers,
remove_child_transfers_of_transfers,
)
UNISWAP_V2_PAIR_ABI_NAME = "UniswapV2Pair"
UNISWAP_V3_POOL_ABI_NAME = "UniswapV3Pool"
BALANCER_V1_POOL_ABI_NAME = "BPool"
def get_swaps(traces: List[ClassifiedTrace]) -> List[Swap]:
swaps = []
@ -35,8 +34,13 @@ def _get_swaps_for_transaction(traces: List[ClassifiedTrace]) -> List[Swap]:
prior_transfers: List[ERC20Transfer] = []
for trace in ordered_traces:
if trace.classification == Classification.transfer:
prior_transfers.append(ERC20Transfer.from_trace(trace))
if not isinstance(trace, DecodedCallTrace):
continue
elif trace.classification == Classification.transfer:
transfer = get_transfer(trace)
if transfer is not None:
prior_transfers.append(transfer)
elif trace.classification == Classification.swap:
child_transfers = get_child_transfers(
@ -58,7 +62,7 @@ def _get_swaps_for_transaction(traces: List[ClassifiedTrace]) -> List[Swap]:
def _parse_swap(
trace: ClassifiedTrace,
trace: DecodedCallTrace,
prior_transfers: List[ERC20Transfer],
child_transfers: List[ERC20Transfer],
) -> Optional[Swap]:
@ -102,20 +106,9 @@ def _parse_swap(
)
def _get_recipient_address(trace: ClassifiedTrace) -> Optional[str]:
if trace.abi_name == UNISWAP_V3_POOL_ABI_NAME:
return (
trace.inputs["recipient"]
if trace.inputs is not None and "recipient" in trace.inputs
else trace.from_address
)
elif trace.abi_name == UNISWAP_V2_PAIR_ABI_NAME:
return (
trace.inputs["to"]
if trace.inputs is not None and "to" in trace.inputs
else trace.from_address
)
elif trace.abi_name == BALANCER_V1_POOL_ABI_NAME:
return trace.from_address
else:
return None
def _get_recipient_address(trace: DecodedCallTrace) -> Optional[str]:
classifier = get_classifier(trace)
if classifier is not None and issubclass(classifier, SwapClassifier):
return classifier.get_swap_recipient(trace)
return None

View File

@ -1,12 +1,11 @@
from typing import Dict, List, Optional, Sequence, Tuple
from typing import Dict, List, Optional, Sequence
from mev_inspect.classifiers.specs import ALL_CLASSIFIER_SPECS
from mev_inspect.schemas.classifiers import ClassifierSpec, TransferClassifier
from mev_inspect.classifiers.specs import get_classifier
from mev_inspect.schemas.classifiers import TransferClassifier
from mev_inspect.schemas.classified_traces import (
Classification,
ClassifiedTrace,
DecodedCallTrace,
Protocol,
)
from mev_inspect.schemas.transfers import ERC20Transfer, EthTransfer, TransferGeneric
from mev_inspect.traces import is_child_trace_address, get_child_traces
@ -25,25 +24,26 @@ def get_eth_transfers(traces: List[ClassifiedTrace]) -> List[EthTransfer]:
def get_transfers(traces: List[ClassifiedTrace]) -> List[ERC20Transfer]:
transfers = []
specs_by_abi_name_and_protocol: Dict[
Tuple[str, Optional[Protocol]], ClassifierSpec
] = {(spec.abi_name, spec.protocol): spec for spec in ALL_CLASSIFIER_SPECS}
for trace in traces:
if not isinstance(trace, DecodedCallTrace):
continue
abi_name_and_protocol = (trace.abi_name, trace.protocol)
spec = specs_by_abi_name_and_protocol.get(abi_name_and_protocol)
if spec is not None:
classifier = spec.classifiers.get(trace.function_signature)
if classifier is not None and issubclass(classifier, TransferClassifier):
transfers.append(classifier.get_transfer(trace))
if isinstance(trace, DecodedCallTrace):
transfer = get_transfer(trace)
if transfer is not None:
transfers.append(transfer)
return transfers
def get_transfer(trace: DecodedCallTrace) -> Optional[ERC20Transfer]:
if not isinstance(trace, DecodedCallTrace):
return None
classifier = get_classifier(trace)
if classifier is not None and issubclass(classifier, TransferClassifier):
return classifier.get_transfer(trace)
return None
def get_child_transfers(
transaction_hash: str,
parent_trace_address: List[int],
@ -52,8 +52,12 @@ def get_child_transfers(
child_transfers = []
for child_trace in get_child_traces(transaction_hash, parent_trace_address, traces):
if child_trace.classification == Classification.transfer:
child_transfers.append(ERC20Transfer.from_trace(child_trace))
if child_trace.classification == Classification.transfer and isinstance(
child_trace, DecodedCallTrace
):
transfer = get_transfer(child_trace)
if transfer is not None:
child_transfers.append(transfer)
return child_transfers

View File

@ -1,11 +1,11 @@
from typing import List
from typing import List, Optional
from mev_inspect.schemas.blocks import TraceType
from mev_inspect.schemas.classified_traces import (
Classification,
ClassifiedTrace,
CallTrace,
DecodedCallTrace,
Protocol,
)
@ -18,7 +18,7 @@ def make_transfer_trace(
token_address: str,
amount: int,
):
return CallTrace(
return DecodedCallTrace(
transaction_hash=transaction_hash,
block_number=block_number,
type=TraceType.call,
@ -26,6 +26,9 @@ def make_transfer_trace(
classification=Classification.transfer,
from_address=from_address,
to_address=token_address,
abi_name="ERC20",
function_name="transfer",
function_signature="transfer(address,uint256)",
inputs={
"recipient": to_address,
"amount": amount,
@ -43,6 +46,8 @@ def make_swap_trace(
from_address: str,
pool_address: str,
abi_name: str,
function_signature: str,
protocol: Optional[Protocol],
recipient_address: str,
recipient_input_key: str,
):
@ -57,9 +62,10 @@ def make_swap_trace(
from_address=from_address,
to_address=pool_address,
function_name="swap",
function_signature="swap()",
function_signature=function_signature,
inputs={recipient_input_key: recipient_address},
abi_name=abi_name,
protocol=protocol,
block_hash=str(block_number),
)

View File

@ -1,9 +1,9 @@
from mev_inspect.arbitrages import get_arbitrages
from mev_inspect.schemas.swaps import Swap
from mev_inspect.swaps import (
from mev_inspect.classifiers.specs.uniswap import (
UNISWAP_V2_PAIR_ABI_NAME,
UNISWAP_V3_POOL_ABI_NAME,
)
from mev_inspect.schemas.swaps import Swap
def test_two_pool_arbitrage(get_transaction_hashes, get_addresses):

View File

@ -1,9 +1,10 @@
from mev_inspect.swaps import (
get_swaps,
from mev_inspect.swaps import get_swaps
from mev_inspect.classifiers.specs.balancer import BALANCER_V1_POOL_ABI_NAME
from mev_inspect.classifiers.specs.uniswap import (
UNISWAP_V2_PAIR_ABI_NAME,
UNISWAP_V3_POOL_ABI_NAME,
BALANCER_V1_POOL_ABI_NAME,
)
from mev_inspect.schemas.classified_traces import Protocol
from .helpers import (
make_unknown_trace,
@ -64,6 +65,8 @@ def test_swaps(
from_address=alice_address,
pool_address=first_pool_address,
abi_name=UNISWAP_V2_PAIR_ABI_NAME,
protocol=None,
function_signature="swap(uint256,uint256,address,bytes)",
recipient_address=bob_address,
recipient_input_key="to",
),
@ -83,6 +86,8 @@ def test_swaps(
from_address=bob_address,
pool_address=second_pool_address,
abi_name=UNISWAP_V3_POOL_ABI_NAME,
protocol=None,
function_signature="swap(address,bool,int256,uint160,bytes)",
recipient_address=carl_address,
recipient_input_key="recipient",
),
@ -129,6 +134,8 @@ def test_swaps(
from_address=bob_address,
pool_address=third_pool_address,
abi_name=BALANCER_V1_POOL_ABI_NAME,
protocol=Protocol.balancer_v1,
function_signature="swapExactAmountIn(address,uint256,address,uint256,uint256)",
recipient_address=bob_address,
recipient_input_key="recipient",
),