Refactoring to remove while loops, improve logic, and improve error-handling

This commit is contained in:
elicb 2022-04-24 18:41:45 -07:00 committed by Eli Barbieri
parent 5abfc38e12
commit e717585f23
2 changed files with 170 additions and 103 deletions

View File

@ -1,4 +1,6 @@
from typing import List, Tuple, Union from typing import List, Tuple
from pydantic import BaseModel
from mev_inspect.schemas.jit_liquidity import JITLiquidity from mev_inspect.schemas.jit_liquidity import JITLiquidity
from mev_inspect.schemas.swaps import Swap from mev_inspect.schemas.swaps import Swap
@ -8,7 +10,6 @@ from mev_inspect.schemas.traces import (
DecodedCallTrace, DecodedCallTrace,
Protocol, Protocol,
) )
from mev_inspect.schemas.transfers import Transfer
from mev_inspect.traces import is_child_trace_address from mev_inspect.traces import is_child_trace_address
from mev_inspect.transfers import get_net_transfers from mev_inspect.transfers import get_net_transfers
@ -17,6 +18,16 @@ LIQUIDITY_MINT_ROUTERS = [
] ]
class JITTransferInfo(BaseModel):
token0_address: str
token1_address: str
mint_token0: int
mint_token1: int
burn_token0: int
burn_token1: int
error: bool
def get_jit_liquidity( def get_jit_liquidity(
classified_traces: List[ClassifiedTrace], swaps: List[Swap] classified_traces: List[ClassifiedTrace], swaps: List[Swap]
) -> List[JITLiquidity]: ) -> List[JITLiquidity]:
@ -31,74 +42,73 @@ def get_jit_liquidity(
trace.classification == Classification.liquidity_mint trace.classification == Classification.liquidity_mint
and trace.protocol == Protocol.uniswap_v3 and trace.protocol == Protocol.uniswap_v3
): ):
i = index + 1 for search_trace in classified_traces[index:]:
while i < len(classified_traces): if (
forward_search_trace = classified_traces[i] search_trace.classification == Classification.liquidity_burn
if forward_search_trace.classification == Classification.liquidity_burn: and search_trace.to_address == trace.to_address
if forward_search_trace.to_address == trace.to_address: ):
jit_liquidity = _parse_jit_liquidity_instance(
trace, forward_search_trace, classified_traces, swaps bot_address = _get_bot_address(trace, classified_traces)
transfer_info: JITTransferInfo = _get_transfer_info(
classified_traces,
trace,
search_trace,
)
jit_swaps, token0_volume, token1_volume = _get_swap_info(
swaps, trace, search_trace, transfer_info.token0_address
)
# -- Error Checking Section --
if transfer_info.error or len(jit_swaps) == 0:
continue
jit_liquidity_instances.append(
JITLiquidity(
block_number=trace.block_number,
bot_address=bot_address,
pool_address=trace.to_address,
mint_transaction_hash=trace.transaction_hash,
mint_trace=trace.trace_address,
burn_transaction_hash=search_trace.transaction_hash,
burn_trace=search_trace.trace_address,
swaps=jit_swaps,
token0_address=transfer_info.token0_address,
token1_address=transfer_info.token1_address,
mint_token0_amount=transfer_info.mint_token0,
mint_token1_amount=transfer_info.mint_token1,
burn_token0_amount=transfer_info.burn_token0,
burn_token1_amount=transfer_info.burn_token1,
token0_swap_volume=token0_volume,
token1_swap_volume=token1_volume,
) )
if jit_liquidity is None: )
continue
jit_liquidity_instances.append(jit_liquidity)
i += 1
return jit_liquidity_instances return jit_liquidity_instances
def _parse_jit_liquidity_instance( def _get_token_order(token_a: str, token_b: str) -> Tuple[str, str]:
token_order = True if int(token_a, 16) < int(token_b, 16) else False
return (token_a, token_b) if token_order else (token_b, token_a)
def _get_swap_info(
swaps: List[Swap],
mint_trace: ClassifiedTrace, mint_trace: ClassifiedTrace,
burn_trace: ClassifiedTrace, burn_trace: ClassifiedTrace,
classified_traces: List[ClassifiedTrace], token0_address: str,
swaps: List[Swap], ) -> Tuple[List[Swap], int, int]:
) -> Union[JITLiquidity, None]:
valid_swaps = list(
filter(
lambda t: mint_trace.transaction_position
< t.transaction_position
< burn_trace.transaction_position,
swaps,
)
)
net_transfers = get_net_transfers(
list(
filter(
lambda t: t.transaction_hash
in [mint_trace.transaction_hash, burn_trace.transaction_hash],
classified_traces,
)
)
)
jit_swaps: List[Swap] = [] jit_swaps: List[Swap] = []
token0_swap_volume, token1_swap_volume = 0, 0 token0_swap_volume, token1_swap_volume = 0, 0
mint_transfers: List[Transfer] = list( ordered_swaps = sorted(
filter( swaps, key=lambda s: (s.transaction_position, s.trace_address)
lambda t: t.transaction_hash == mint_trace.transaction_hash
and t.to_address == mint_trace.to_address,
net_transfers,
)
)
burn_transfers: List[Transfer] = list(
filter(
lambda t: t.transaction_hash == burn_trace.transaction_hash
and t.from_address == burn_trace.to_address,
net_transfers,
)
) )
if len(mint_transfers) == 2 and len(burn_transfers) == 2: for swap in ordered_swaps:
token0_address, token1_address = _get_token_order( if swap.transaction_position <= mint_trace.transaction_position:
mint_transfers[0].token_address, mint_transfers[1].token_address continue
) if swap.transaction_position >= burn_trace.transaction_position:
else: break
# This is a failing/skipping case, super weird
return None
bot_address = _get_bot_address(mint_trace, classified_traces)
for swap in valid_swaps:
if swap.contract_address == mint_trace.to_address: if swap.contract_address == mint_trace.to_address:
jit_swaps.append(swap) jit_swaps.append(swap)
token0_swap_volume += ( token0_swap_volume += (
@ -108,44 +118,99 @@ def _parse_jit_liquidity_instance(
0 if swap.token_in_address == token0_address else swap.token_in_amount 0 if swap.token_in_address == token0_address else swap.token_in_amount
) )
token_order = mint_transfers[0].token_address == token0_address return jit_swaps, token0_swap_volume, token1_swap_volume
return JITLiquidity(
block_number=mint_trace.block_number, def _get_transfer_info(
bot_address=bot_address, classified_traces: List[ClassifiedTrace],
pool_address=mint_trace.to_address, mint_trace: ClassifiedTrace,
mint_transaction_hash=mint_trace.transaction_hash, burn_trace: ClassifiedTrace,
mint_trace=mint_trace.trace_address, ) -> JITTransferInfo:
burn_transaction_hash=burn_trace.transaction_hash,
burn_trace=burn_trace.trace_address, error_found = False
swaps=jit_swaps, mint_slice_start, mint_slice_end, burn_slice_start, burn_slice_end = (
None,
None,
None,
None,
)
# This would be cleaner with bisect(), but creates 3.10 dependency
for index, trace in enumerate(classified_traces):
if (
mint_slice_start is None
and trace.transaction_hash == mint_trace.transaction_hash
):
mint_slice_start = index
if (
mint_slice_end is None
and trace.transaction_position > mint_trace.transaction_position
):
mint_slice_end = index
if (
burn_slice_start is None
and trace.transaction_hash == burn_trace.transaction_hash
):
burn_slice_start = index
if (
burn_slice_end is None
and trace.transaction_position > burn_trace.transaction_position
):
burn_slice_end = index
break
mint_net_transfers_full = get_net_transfers(
classified_traces[mint_slice_start:mint_slice_end]
)
burn_net_transfers_full = get_net_transfers(
classified_traces[burn_slice_start:burn_slice_end]
)
mint_net_transfers, burn_net_transfers = [], []
pool_address = mint_trace.to_address
for transfer in mint_net_transfers_full:
if transfer.to_address == pool_address:
mint_net_transfers.append(transfer)
for transfer in burn_net_transfers_full:
if transfer.from_address == pool_address:
burn_net_transfers.append(transfer)
if len(mint_net_transfers) > 2 or len(burn_net_transfers) > 2:
error_found = True
token0_address, token1_address = _get_token_order(
mint_net_transfers[0].token_address, mint_net_transfers[1].token_address
)
if mint_net_transfers[0].token_address == token0_address:
mint_token0 = mint_net_transfers[0].amount
mint_token1 = mint_net_transfers[1].amount
else:
mint_token0 = mint_net_transfers[1].amount
mint_token1 = mint_net_transfers[0].amount
if burn_net_transfers[0].token_address == token0_address:
burn_token0 = burn_net_transfers[0].amount
burn_token1 = burn_net_transfers[1].amount
else:
burn_token0 = burn_net_transfers[1].amount
burn_token1 = burn_net_transfers[0].amount
return JITTransferInfo(
token0_address=token0_address, token0_address=token0_address,
token1_address=token1_address, token1_address=token1_address,
mint_token0_amount=mint_transfers[0].amount mint_token0=mint_token0,
if token_order mint_token1=mint_token1,
else mint_transfers[1].amount, burn_token0=burn_token0,
mint_token1_amount=mint_transfers[1].amount burn_token1=burn_token1,
if token_order error=error_found,
else mint_transfers[0].amount,
burn_token0_amount=burn_transfers[0].amount
if token_order
else burn_transfers[1].amount,
burn_token1_amount=burn_transfers[1].amount
if token_order
else burn_transfers[0].amount,
token0_swap_volume=token0_swap_volume,
token1_swap_volume=token1_swap_volume,
) )
def _get_token_order(token_a: str, token_b: str) -> Tuple[str, str]: def _get_bot_address(
token_order = True if int(token_a, 16) < int(token_b, 16) else False
return (token_a, token_b) if token_order else (token_b, token_a)
def _get_bot_address( # Janky and a half...
mint_trace: ClassifiedTrace, classified_traces: List[ClassifiedTrace] mint_trace: ClassifiedTrace, classified_traces: List[ClassifiedTrace]
) -> Union[str, None]: ) -> str:
if mint_trace.from_address in LIQUIDITY_MINT_ROUTERS: if mint_trace.from_address in LIQUIDITY_MINT_ROUTERS:
bot_trace = list( bot_trace = list(
filter( filter(
@ -154,13 +219,15 @@ def _get_bot_address( # Janky and a half...
classified_traces, classified_traces,
) )
) )
if len(bot_trace) == 1: if len(bot_trace) == 1 or is_child_trace_address(
return _get_bot_address(bot_trace[0], classified_traces)
elif is_child_trace_address(
bot_trace[1].trace_address, bot_trace[0].trace_address bot_trace[1].trace_address, bot_trace[0].trace_address
): ):
return _get_bot_address(bot_trace[0], classified_traces) return _get_bot_address(bot_trace[0], classified_traces)
else: else:
return None return "0x0000000000000000000000000000000000000000"
return mint_trace.from_address # This case is here because from_address is optional in ClassifiedTrace
if type(mint_trace.from_address) == str:
return mint_trace.from_address
else:
return "0x0000000000000000000000000000000000000000"

View File

@ -134,7 +134,7 @@ def get_net_transfers(
""" """
Super Jank... Super Jank...
Returns the net transfers per transaction from a list of Classified Traces. Returns the net transfers per transaction from a list of Classified Traces.
Ex. if a bot transfers 200 WETH to a contract, and the contract transfers the excess WETH back to the bot, Ex. if a bot transfers 200 WETH to a contract, and the contract transfers the excess 50 WETH back to the bot,
the following transfer would be returned (from_address=bot, to_address=contract, amount=150) the following transfer would be returned (from_address=bot, to_address=contract, amount=150)
if the contract transferred 300 WETH back to the bot, the following would be returned if the contract transferred 300 WETH back to the bot, the following would be returned
(from_address=contract, to_address=bot, amount=100). if the contract transferred back 200 WETH, (from_address=contract, to_address=bot, amount=100). if the contract transferred back 200 WETH,
@ -204,19 +204,19 @@ def get_net_transfers(
) )
found_transfers.append(sorted(net_search_info)) found_transfers.append(sorted(net_search_info))
i = 0 process_index = -1
while True: while True:
process_index += 1
try: try:
transfer = return_transfers[i] transfer = return_transfers[process_index]
except IndexError: except IndexError:
break break
if transfer.amount < 0: if transfer.amount < 0:
return_transfers[i].from_address = transfer.to_address return_transfers[process_index].from_address = transfer.to_address
return_transfers[i].to_address = transfer.from_address return_transfers[process_index].to_address = transfer.from_address
return_transfers[i].amount = transfer.amount * -1 return_transfers[process_index].amount = transfer.amount * -1
if transfer.amount == 0: if transfer.amount == 0:
return_transfers.pop(i) return_transfers.pop(process_index)
i -= 1 process_index -= 1
i += 1
return return_transfers return return_transfers