From cde6bcaf3ddbd37a499b9a21be011c35f79025e3 Mon Sep 17 00:00:00 2001 From: Luke Van Seters Date: Tue, 3 Aug 2021 09:43:38 -0400 Subject: [PATCH] Save progress before switching directions --- mev_inspect/strategies/arbitrage.py | 196 +++++++++++++++++++++++++++- mev_inspect/traces.py | 11 +- 2 files changed, 202 insertions(+), 5 deletions(-) diff --git a/mev_inspect/strategies/arbitrage.py b/mev_inspect/strategies/arbitrage.py index ccc7410..cf12418 100644 --- a/mev_inspect/strategies/arbitrage.py +++ b/mev_inspect/strategies/arbitrage.py @@ -1,8 +1,200 @@ -from typing import List +from typing import Dict, List, Optional -from mev_inspect.schemas.classified_traces import ClassifiedTrace +from pydantic import BaseModel + +from mev_inspect.schemas.classified_traces import ( + ClassifiedTrace, + Classification, + Protocol, +) from mev_inspect.schemas.strategies import Arbitrage +class EthTransfer(BaseModel): + from_address: str + to_address: str + amount: int + + +class Transfer(BaseModel): + protocol: Optional[Protocol] + from_address: str + to_address: str + token_address: str + amount: int + + +class BalanceTracker: + def __init__(self): + # account address => token address => balance + self._balances: Dict[str, Dict[str, int]] = {} + + def update(self, transfer: Transfer) -> None: + self._change_balance( + transfer.from_address, + transfer.token_address, + -transfer.amount, + ) + + self._change_balance( + transfer.to_address, + transfer.token_address, + transfer.amount, + ) + + def get_all_balances( + self, + account_address: str, + ) -> Dict[str, int]: + return self._balances.get(account_address, {}) + + def get_balance( + self, + account_address: str, + token_address: str, + ) -> int: + return self.get_all_balances(account_address).get(token_address, 0) + + def _change_balance( + self, + account_address: str, + token_address: str, + amount_change: int, + ) -> None: + if account_address not in self._balances: + self._balances[account_address] = { + token_address: amount_change, + } + else: + existing_balances = self._balances[account_address] + existing_token_balance = existing_balances.get( + token_address, + 0, + ) + updated_token_balance = existing_token_balance + amount_change + self._balances[account_address][token_address] = updated_token_balance + + +class Swap(BaseModel): + protocol: Optional[Protocol] + pool_address: str + from_address: str + to_address: str + token_in_address: str + token_in_amount: int + token_out_address: str + token_out_amount: int + + def get_arbitrages(traces: List[ClassifiedTrace]) -> List[Arbitrage]: + all_arbitrages = [] + traces_by_transaction = _group_traces_by_transaction(traces) + + for transaction_traces in traces_by_transaction.values(): + all_arbitrages += _get_arbitrages_for_transaction( + transaction_traces, + ) + + return all_arbitrages + + +def _group_traces_by_transaction( + traces: List[ClassifiedTrace], +) -> Dict[str, List[ClassifiedTrace]]: + grouped_traces: Dict[str, List[ClassifiedTrace]] = {} + + for trace in traces: + hash = trace.transaction_hash + existing_traces = grouped_traces.get(hash, []) + grouped_traces[hash] = existing_traces + [trace] + + return grouped_traces + + +def _get_arbitrages_for_transaction( + traces: List[ClassifiedTrace], +) -> List[Arbitrage]: + swaps = _get_swaps(traces) + print(swaps) return [] + + +def _get_swaps(traces: List[ClassifiedTrace]) -> List[Swap]: + ordered_traces = list(sorted(traces, key=lambda t: t.trace_address)) + + swaps: List[Swap] = [] + + prior_transfers = [] + + for trace in ordered_traces: + if trace.classification == Classification.transfer: + transfer = _as_transfer(trace) + prior_transfers.append(transfer) + + if trace.classification == Classification.swap: + immediate_child_traces = [ + t + for t in ordered_traces + if ( + len(t.trace_address) == len(trace.trace_address) + 1 + and ( + t.trace_address[: len(trace.trace_address)] + == trace.trace_address + ) + ) + ] + + if trace.abi_name == "UniswapV2Pair": + pool_address = trace.to_address + transfer_to_pool = [ + transfer + for transfer in prior_transfers + if transfer.to_address == pool_address + ][ + -1 + ] # todo + + internal_transfer = [ + _as_transfer(child) + for child in immediate_child_traces + if child.classification == Classification.transfer + ][ + 0 + ] # todo + + swap = Swap( + pool_address=pool_address, + from_address=transfer_to_pool.from_address, + to_address=internal_transfer.to_address, + token_in_address=transfer_to_pool.token_address, + token_in_amount=transfer_to_pool.amount, + token_out_address=internal_transfer.token_address, + token_out_amount=internal_transfer.amount, + ) + + swaps.append(swap) + + return swaps + + +def _as_transfer(trace: ClassifiedTrace) -> Transfer: + # todo - this should be enforced at the data level + if trace.inputs is None: + raise ValueError("Invalid transfer") + + if trace.protocol == Protocol.weth: + return Transfer( + protocol=trace.protocol, + amount=trace.inputs["wad"], + to_address=trace.inputs["dst"], + from_address=trace.from_address, + token_address=trace.to_address, + ) + else: + return Transfer( + protocol=trace.protocol, + amount=trace.inputs["amount"], + to_address=trace.inputs["recipient"], + from_address=trace.from_address, + token_address=trace.to_address, + ) diff --git a/mev_inspect/traces.py b/mev_inspect/traces.py index 052d8be..642b1b1 100644 --- a/mev_inspect/traces.py +++ b/mev_inspect/traces.py @@ -20,7 +20,9 @@ def _get_transaction_hash(trace) -> str: return trace.transaction_hash -def _as_nested_traces_by_transaction(traces: Iterable[ClassifiedTrace]) -> List[NestedTrace]: +def _as_nested_traces_by_transaction( + traces: Iterable[ClassifiedTrace], +) -> List[NestedTrace]: """ Turns a list of Traces into a a tree of NestedTraces using their trace addresses @@ -45,7 +47,7 @@ def _as_nested_traces_by_transaction(traces: Iterable[ClassifiedTrace]) -> List[ children = [] continue - elif not _is_subtrace(trace, parent): + elif not is_subtrace(trace, parent): nested_traces.append( NestedTrace( trace=parent, @@ -70,7 +72,10 @@ def _as_nested_traces_by_transaction(traces: Iterable[ClassifiedTrace]) -> List[ return nested_traces -def _is_subtrace(trace: ClassifiedTrace, parent: ClassifiedTrace): +def is_subtrace(trace: ClassifiedTrace, parent: ClassifiedTrace): + if trace.transaction_hash != parent.transaction_hash: + return False + parent_trace_length = len(parent.trace_address) if len(trace.trace_address) > parent_trace_length: