From ea07eb3a8f0d153d5051972037ffe1117a7bbc37 Mon Sep 17 00:00:00 2001 From: Luke Van Seters Date: Tue, 3 Aug 2021 18:38:01 -0400 Subject: [PATCH] More clean up. Support ERC transfers with UniswapV2 --- mev_inspect/strategies/arbitrage.py | 177 ++++++++++++---------------- 1 file changed, 76 insertions(+), 101 deletions(-) diff --git a/mev_inspect/strategies/arbitrage.py b/mev_inspect/strategies/arbitrage.py index cf12418..aa742f5 100644 --- a/mev_inspect/strategies/arbitrage.py +++ b/mev_inspect/strategies/arbitrage.py @@ -10,72 +10,15 @@ from mev_inspect.schemas.classified_traces import ( from mev_inspect.schemas.strategies import Arbitrage -class EthTransfer(BaseModel): - from_address: str - to_address: str - amount: int - - class Transfer(BaseModel): - protocol: Optional[Protocol] from_address: str to_address: str - token_address: str amount: int - - -class BalanceTracker: - def __init__(self): - # account address => token address => balance - self._balances: Dict[str, Dict[str, int]] = {} - - def update(self, transfer: Transfer) -> None: - self._change_balance( - transfer.from_address, - transfer.token_address, - -transfer.amount, - ) - - self._change_balance( - transfer.to_address, - transfer.token_address, - transfer.amount, - ) - - def get_all_balances( - self, - account_address: str, - ) -> Dict[str, int]: - return self._balances.get(account_address, {}) - - def get_balance( - self, - account_address: str, - token_address: str, - ) -> int: - return self.get_all_balances(account_address).get(token_address, 0) - - def _change_balance( - self, - account_address: str, - token_address: str, - amount_change: int, - ) -> None: - if account_address not in self._balances: - self._balances[account_address] = { - token_address: amount_change, - } - else: - existing_balances = self._balances[account_address] - existing_token_balance = existing_balances.get( - token_address, - 0, - ) - updated_token_balance = existing_token_balance + amount_change - self._balances[account_address][token_address] = updated_token_balance + token_address: str class Swap(BaseModel): + transaction_hash: str protocol: Optional[Protocol] pool_address: str from_address: str @@ -123,7 +66,6 @@ def _get_swaps(traces: List[ClassifiedTrace]) -> List[Swap]: ordered_traces = list(sorted(traces, key=lambda t: t.trace_address)) swaps: List[Swap] = [] - prior_transfers = [] for trace in ordered_traces: @@ -132,51 +74,86 @@ def _get_swaps(traces: List[ClassifiedTrace]) -> List[Swap]: prior_transfers.append(transfer) if trace.classification == Classification.swap: - immediate_child_traces = [ - t - for t in ordered_traces - if ( - len(t.trace_address) == len(trace.trace_address) + 1 - and ( - t.trace_address[: len(trace.trace_address)] - == trace.trace_address - ) - ) - ] - - if trace.abi_name == "UniswapV2Pair": - pool_address = trace.to_address - transfer_to_pool = [ - transfer - for transfer in prior_transfers - if transfer.to_address == pool_address - ][ - -1 - ] # todo - - internal_transfer = [ - _as_transfer(child) - for child in immediate_child_traces - if child.classification == Classification.transfer - ][ - 0 - ] # todo - - swap = Swap( - pool_address=pool_address, - from_address=transfer_to_pool.from_address, - to_address=internal_transfer.to_address, - token_in_address=transfer_to_pool.token_address, - token_in_amount=transfer_to_pool.amount, - token_out_address=internal_transfer.token_address, - token_out_amount=internal_transfer.amount, - ) + child_traces = _get_child_traces(trace.trace_address, traces) + swap = _build_swap(trace, prior_transfers, child_traces) + if swap is not None: swaps.append(swap) return swaps +def _build_swap( + trace: ClassifiedTrace, + prior_transfers: List[Transfer], + child_traces: List[ClassifiedTrace], +) -> Optional[Swap]: + if trace.abi_name == "UniswapV2Pair": + pool_address = trace.to_address + transfers_to_pool = [ + transfer + for transfer in prior_transfers + if transfer.to_address == pool_address + ] + + # expecting a prior transfer to the pool + if len(transfers_to_pool) == 0: + return None + + most_recent_transfer_to_pool = transfers_to_pool[-1] + all_pool_internal_transfers = [ + _as_transfer(child) + for child in child_traces + if child.classification == Classification.transfer + ] + + # expecting exactly one transfer inside the pool + if len(all_pool_internal_transfers) != 1: + return None + + pool_internal_transfer = all_pool_internal_transfers[0] + + return Swap( + transaction_hash=trace.transaction_hash, + pool_address=pool_address, + from_address=most_recent_transfer_to_pool.from_address, + to_address=pool_internal_transfer.to_address, + token_in_address=most_recent_transfer_to_pool.token_address, + token_in_amount=most_recent_transfer_to_pool.amount, + token_out_address=pool_internal_transfer.token_address, + token_out_amount=pool_internal_transfer.amount, + ) + + return None + + +def _get_child_traces( + parent_trace_address: List[int], + traces: List[ClassifiedTrace], +) -> List[ClassifiedTrace]: + ordered_traces = sorted(traces, key=lambda t: t.trace_address) + child_traces = [] + + for trace in ordered_traces: + if _is_child_trace_address( + parent_trace_address, + trace.trace_address, + ): + child_traces.append(trace) + + return child_traces + + +def _is_child_trace_address( + parent_trace_address: List[int], + maybe_child_trace_address: List[int], +) -> bool: + parent_trace_length = len(parent_trace_address) + return (len(maybe_child_trace_address) == parent_trace_length + 1) and ( + maybe_child_trace_address[:parent_trace_length] == parent_trace_address + ) + + def _as_transfer(trace: ClassifiedTrace) -> Transfer: # todo - this should be enforced at the data level if trace.inputs is None: @@ -184,7 +161,6 @@ def _as_transfer(trace: ClassifiedTrace) -> Transfer: if trace.protocol == Protocol.weth: return Transfer( - protocol=trace.protocol, amount=trace.inputs["wad"], to_address=trace.inputs["dst"], from_address=trace.from_address, @@ -192,7 +168,6 @@ def _as_transfer(trace: ClassifiedTrace) -> Transfer: ) else: return Transfer( - protocol=trace.protocol, amount=trace.inputs["amount"], to_address=trace.inputs["recipient"], from_address=trace.from_address,