async middleware - geth poa

This commit is contained in:
Supragya Raj 2021-12-02 15:09:13 +01:00
parent 0895a0f1cd
commit d1a1a53101
4 changed files with 136 additions and 32 deletions

View File

@ -80,19 +80,19 @@ async def _fetch_block(
else: else:
raise raise
else: else:
block_json = await w3.eth.get_block(block_number) # print(block_number)
print(block_json) block_json = await asyncio.gather(w3.eth.get_block(block_number))
traces = geth_get_tx_traces_parity_format(base_provider, block_json) traces = await geth_get_tx_traces_parity_format(base_provider, block_json[0])
geth_tx_receipts = geth_get_tx_receipts( geth_tx_receipts = await geth_get_tx_receipts_async(
base_provider, block_json["transactions"] base_provider.endpoint_uri, block_json[0]["transactions"]
) )
receipts = geth_receipts_translator(block_json, geth_tx_receipts) receipts = geth_receipts_translator(block_json[0], geth_tx_receipts)
base_fee_per_gas = 0 base_fee_per_gas = 0
return Block( return Block(
block_number=block_number, block_number=block_number,
block_timestamp=block_json["timestamp"], block_timestamp=block_json[0]["timestamp"],
miner=block_json["miner"], miner=block_json[0]["miner"],
base_fee_per_gas=base_fee_per_gas, base_fee_per_gas=base_fee_per_gas,
traces=traces, traces=traces,
receipts=receipts, receipts=receipts,
@ -220,9 +220,11 @@ def get_transaction_hashes(calls: List[Trace]) -> List[str]:
# Geth specific additions # Geth specific additions
def geth_get_tx_traces_parity_format(base_provider, block_json: dict): async def geth_get_tx_traces_parity_format(base_provider, block_json: dict):
# print(block_json['hash'].hex())
block_hash = block_json["hash"] block_hash = block_json["hash"]
block_trace = geth_get_tx_traces(base_provider, block_hash) block_trace = await geth_get_tx_traces(base_provider, block_hash)
# print(block_trace)
parity_traces = [] parity_traces = []
for idx, trace in enumerate(block_trace["result"]): for idx, trace in enumerate(block_trace["result"]):
if "result" in trace: if "result" in trace:
@ -232,8 +234,8 @@ def geth_get_tx_traces_parity_format(base_provider, block_json: dict):
return parity_traces return parity_traces
def geth_get_tx_traces(base_provider, block_hash): async def geth_get_tx_traces(base_provider, block_hash):
block_trace = base_provider.make_request( block_trace = await base_provider.make_request(
"debug_traceBlockByHash", [block_hash.hex(), {"tracer": "callTracer"}] "debug_traceBlockByHash", [block_hash.hex(), {"tracer": "callTracer"}]
) )
return block_trace return block_trace
@ -314,12 +316,6 @@ async def geth_get_tx_receipts_async(endpoint_uri, transactions):
return [json.loads(tx_receipts) for tx_receipts in geth_tx_receipts] return [json.loads(tx_receipts) for tx_receipts in geth_tx_receipts]
def geth_get_tx_receipts(base_provider, transactions):
return asyncio.run(
geth_get_tx_receipts_async(base_provider.endpoint_uri, transactions)
)
def geth_receipts_translator(block_json, geth_tx_receipts) -> List[Receipt]: def geth_receipts_translator(block_json, geth_tx_receipts) -> List[Receipt]:
json_decoded_receipts = [ json_decoded_receipts = [
tx_receipt["result"] tx_receipt["result"]

View File

@ -0,0 +1,99 @@
from typing import (
Any,
Callable,
)
from hexbytes import (
HexBytes,
)
from eth_utils.curried import (
apply_formatter_if,
apply_formatters_to_dict,
apply_key_map,
is_null,
)
from eth_utils.toolz import (
complement,
compose,
assoc,
)
from web3._utils.rpc_abi import (
RPC,
)
from web3.types import (
Formatters,
RPCEndpoint,
RPCResponse,
)
from web3 import Web3 # noqa: F401
async def get_geth_poa_middleware(
make_request: Callable[[RPCEndpoint, Any], RPCResponse],
request_formatters: Formatters = {},
result_formatters: Formatters = {},
error_formatters: Formatters = {},
) -> RPCResponse:
async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse:
if method in request_formatters:
formatter = request_formatters[method]
formatted_params = formatter(params)
response = await make_request(method, formatted_params)
else:
response = await make_request(method, params)
if "result" in response and method in result_formatters:
formatter = result_formatters[method]
formatted_response = assoc(
response,
"result",
formatter(response["result"]),
)
return formatted_response
elif "error" in response and method in error_formatters:
formatter = error_formatters[method]
formatted_response = assoc(
response,
"error",
formatter(response["error"]),
)
return formatted_response
else:
return response
return middleware
is_not_null = complement(is_null)
remap_geth_poa_fields = apply_key_map(
{
"extraData": "proofOfAuthorityData",
}
)
pythonic_geth_poa = apply_formatters_to_dict(
{
"proofOfAuthorityData": HexBytes,
}
)
geth_poa_cleanup = compose(pythonic_geth_poa, remap_geth_poa_fields)
async def geth_poa_middleware(
make_request: Callable[[RPCEndpoint, Any], Any], *_: Web3
):
return await get_geth_poa_middleware(
make_request=make_request,
request_formatters={},
result_formatters={
RPC.eth_getBlockByHash: apply_formatter_if(is_not_null, geth_poa_cleanup),
RPC.eth_getBlockByNumber: apply_formatter_if(is_not_null, geth_poa_cleanup),
},
error_formatters={},
)

View File

@ -7,7 +7,6 @@ from typing import Optional
from sqlalchemy import orm from sqlalchemy import orm
from web3 import Web3 from web3 import Web3
from web3.eth import AsyncEth from web3.eth import AsyncEth
from web3.middleware import geth_poa_middleware
from mev_inspect.block import create_from_block_number from mev_inspect.block import create_from_block_number
from mev_inspect.classifiers.trace import TraceClassifier from mev_inspect.classifiers.trace import TraceClassifier
@ -29,18 +28,19 @@ class MEVInspector:
): ):
self.inspect_db_session = inspect_db_session self.inspect_db_session = inspect_db_session
self.trace_db_session = trace_db_session self.trace_db_session = trace_db_session
self.base_provider = get_base_provider(rpc, request_timeout=request_timeout) self.base_provider = get_base_provider(rpc, request_timeout, geth)
self.geth = geth self.geth = geth
if geth: self.w3 = Web3(self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[])
self.w3 = Web3( # if geth:
self.base_provider, # self.w3 = Web3(
modules={"eth": (AsyncEth,)}, # self.base_provider,
middlewares=[geth_poa_middleware], # modules={"eth": (AsyncEth,)},
) # middlewares=[],
else: # )
self.w3 = Web3( # else:
self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[] # self.w3 = Web3(
) # self.base_provider, modules={"eth": (AsyncEth,)}, middlewares=[]
# )
self.trace_classifier = TraceClassifier() self.trace_classifier = TraceClassifier()
self.max_concurrency = asyncio.Semaphore(max_concurrency) self.max_concurrency = asyncio.Semaphore(max_concurrency)

View File

@ -1,9 +1,18 @@
from web3 import Web3, AsyncHTTPProvider from web3 import Web3, AsyncHTTPProvider
from mev_inspect.retry import http_retry_with_backoff_request_middleware from mev_inspect.retry import http_retry_with_backoff_request_middleware
from mev_inspect.geth_poa_middleware import geth_poa_middleware
def get_base_provider(rpc: str, request_timeout: int = 500) -> Web3.AsyncHTTPProvider: def get_base_provider(
rpc: str, request_timeout: int = 500, geth: bool = False
) -> Web3.AsyncHTTPProvider:
base_provider = AsyncHTTPProvider(rpc, request_kwargs={"timeout": request_timeout}) base_provider = AsyncHTTPProvider(rpc, request_kwargs={"timeout": request_timeout})
if geth:
base_provider.middlewares += (
geth_poa_middleware,
http_retry_with_backoff_request_middleware,
)
else:
base_provider.middlewares += (http_retry_with_backoff_request_middleware,) base_provider.middlewares += (http_retry_with_backoff_request_middleware,)
return base_provider return base_provider