async middleware - geth poa

This commit is contained in:
Supragya Raj 2021-12-02 15:09:13 +01:00
parent 0895a0f1cd
commit d1a1a53101
4 changed files with 136 additions and 32 deletions

View File

@ -80,19 +80,19 @@ async def _fetch_block(
else:
raise
else:
block_json = await w3.eth.get_block(block_number)
print(block_json)
traces = geth_get_tx_traces_parity_format(base_provider, block_json)
geth_tx_receipts = geth_get_tx_receipts(
base_provider, block_json["transactions"]
# print(block_number)
block_json = await asyncio.gather(w3.eth.get_block(block_number))
traces = await geth_get_tx_traces_parity_format(base_provider, block_json[0])
geth_tx_receipts = await geth_get_tx_receipts_async(
base_provider.endpoint_uri, block_json[0]["transactions"]
)
receipts = geth_receipts_translator(block_json, geth_tx_receipts)
receipts = geth_receipts_translator(block_json[0], geth_tx_receipts)
base_fee_per_gas = 0
return Block(
block_number=block_number,
block_timestamp=block_json["timestamp"],
miner=block_json["miner"],
block_timestamp=block_json[0]["timestamp"],
miner=block_json[0]["miner"],
base_fee_per_gas=base_fee_per_gas,
traces=traces,
receipts=receipts,
@ -220,9 +220,11 @@ def get_transaction_hashes(calls: List[Trace]) -> List[str]:
# Geth specific additions
def geth_get_tx_traces_parity_format(base_provider, block_json: dict):
async def geth_get_tx_traces_parity_format(base_provider, block_json: dict):
# print(block_json['hash'].hex())
block_hash = block_json["hash"]
block_trace = geth_get_tx_traces(base_provider, block_hash)
block_trace = await geth_get_tx_traces(base_provider, block_hash)
# print(block_trace)
parity_traces = []
for idx, trace in enumerate(block_trace["result"]):
if "result" in trace:
@ -232,8 +234,8 @@ def geth_get_tx_traces_parity_format(base_provider, block_json: dict):
return parity_traces
def geth_get_tx_traces(base_provider, block_hash):
block_trace = base_provider.make_request(
async def geth_get_tx_traces(base_provider, block_hash):
block_trace = await base_provider.make_request(
"debug_traceBlockByHash", [block_hash.hex(), {"tracer": "callTracer"}]
)
return block_trace
@ -314,12 +316,6 @@ async def geth_get_tx_receipts_async(endpoint_uri, transactions):
return [json.loads(tx_receipts) for tx_receipts in geth_tx_receipts]
def geth_get_tx_receipts(base_provider, transactions):
return asyncio.run(
geth_get_tx_receipts_async(base_provider.endpoint_uri, transactions)
)
def geth_receipts_translator(block_json, geth_tx_receipts) -> List[Receipt]:
json_decoded_receipts = [
tx_receipt["result"]

View File

@ -0,0 +1,99 @@
from typing import (
Any,
Callable,
)
from hexbytes import (
HexBytes,
)
from eth_utils.curried import (
apply_formatter_if,
apply_formatters_to_dict,
apply_key_map,
is_null,
)
from eth_utils.toolz import (
complement,
compose,
assoc,
)
from web3._utils.rpc_abi import (
RPC,
)
from web3.types import (
Formatters,
RPCEndpoint,
RPCResponse,
)
from web3 import Web3 # noqa: F401
async def get_geth_poa_middleware(
make_request: Callable[[RPCEndpoint, Any], RPCResponse],
request_formatters: Formatters = {},
result_formatters: Formatters = {},
error_formatters: Formatters = {},
) -> RPCResponse:
async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse:
if method in request_formatters:
formatter = request_formatters[method]
formatted_params = formatter(params)
response = await make_request(method, formatted_params)
else:
response = await make_request(method, params)
if "result" in response and method in result_formatters:
formatter = result_formatters[method]
formatted_response = assoc(
response,
"result",
formatter(response["result"]),
)
return formatted_response
elif "error" in response and method in error_formatters:
formatter = error_formatters[method]
formatted_response = assoc(
response,
"error",
formatter(response["error"]),
)
return formatted_response
else:
return response
return middleware
is_not_null = complement(is_null)
remap_geth_poa_fields = apply_key_map(
{
"extraData": "proofOfAuthorityData",
}
)
pythonic_geth_poa = apply_formatters_to_dict(
{
"proofOfAuthorityData": HexBytes,
}
)
geth_poa_cleanup = compose(pythonic_geth_poa, remap_geth_poa_fields)
async def geth_poa_middleware(
make_request: Callable[[RPCEndpoint, Any], Any], *_: Web3
):
return await get_geth_poa_middleware(
make_request=make_request,
request_formatters={},
result_formatters={
RPC.eth_getBlockByHash: apply_formatter_if(is_not_null, geth_poa_cleanup),
RPC.eth_getBlockByNumber: apply_formatter_if(is_not_null, geth_poa_cleanup),
},
error_formatters={},
)

View File

@ -7,7 +7,6 @@ from typing import Optional
from sqlalchemy import orm
from web3 import Web3
from web3.eth import AsyncEth
from web3.middleware import geth_poa_middleware
from mev_inspect.block import create_from_block_number
from mev_inspect.classifiers.trace import TraceClassifier
@ -29,18 +28,19 @@ class MEVInspector:
):
self.inspect_db_session = inspect_db_session
self.trace_db_session = trace_db_session
self.base_provider = get_base_provider(rpc, request_timeout=request_timeout)
self.base_provider = get_base_provider(rpc, request_timeout, geth)
self.geth = geth
if geth:
self.w3 = Web3(
self.base_provider,
modules={"eth": (AsyncEth,)},
middlewares=[geth_poa_middleware],
)
else:
self.w3 = Web3(
self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]
)
self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
# if geth:
# self.w3 = Web3(
# self.base_provider,
# modules={"eth": (AsyncEth,)},
# middlewares=[],
# )
# else:
# self.w3 = Web3(
# self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]
# )
self.trace_classifier = TraceClassifier()
self.max_concurrency = asyncio.Semaphore(max_concurrency)

View File

@ -1,9 +1,18 @@
from web3 import Web3, AsyncHTTPProvider
from mev_inspect.retry import http_retry_with_backoff_request_middleware
from mev_inspect.geth_poa_middleware import geth_poa_middleware
def get_base_provider(rpc: str, request_timeout: int = 500) -> Web3.AsyncHTTPProvider:
def get_base_provider(
rpc: str, request_timeout: int = 500, geth: bool = False
) -> Web3.AsyncHTTPProvider:
base_provider = AsyncHTTPProvider(rpc, request_kwargs={"timeout": request_timeout})
base_provider.middlewares += (http_retry_with_backoff_request_middleware,)
if geth:
base_provider.middlewares += (
geth_poa_middleware,
http_retry_with_backoff_request_middleware,
)
else:
base_provider.middlewares += (http_retry_with_backoff_request_middleware,)
return base_provider