diff --git a/tests/core/eth-module/test_poa.py b/tests/core/eth-module/test_poa.py index d3364ce79a..ce1be2abc2 100644 --- a/tests/core/eth-module/test_poa.py +++ b/tests/core/eth-module/test_poa.py @@ -1,5 +1,8 @@ import pytest +from web3.exceptions import ( + ValidationError, +) from web3.middleware import ( construct_fixture_middleware, geth_poa_middleware, @@ -12,7 +15,7 @@ def test_long_extra_data(web3): 'eth_getBlockByNumber': {'extraData': '0x' + 'ff' * 33}, }) web3.middleware_stack.inject(return_block_with_long_extra_data, layer=0) - with pytest.raises(ValueError): + with pytest.raises(ValidationError): web3.eth.getBlock('latest') diff --git a/web3/middleware/formatting.py b/web3/middleware/formatting.py index 79fcf35b35..989bb71ff3 100644 --- a/web3/middleware/formatting.py +++ b/web3/middleware/formatting.py @@ -1,44 +1,71 @@ +from cytoolz import ( + curry, + merge, +) from cytoolz.dicttoolz import ( assoc, ) -def construct_formatting_middleware(request_formatters=None, - result_formatters=None, - error_formatters=None): - if request_formatters is None: - request_formatters = {} - if result_formatters is None: - result_formatters = {} - if error_formatters is None: - error_formatters = {} - - def formatter_middleware(make_request, web3): - def middleware(method, params): - if method in request_formatters: - formatter = request_formatters[method] - formatted_params = formatter(params) - response = make_request(method, formatted_params) - else: - response = make_request(method, params) - - if 'result' in response and method in result_formatters: - formatter = result_formatters[method] - formatted_response = assoc( - response, - 'result', - formatter(response['result']), - ) - return formatted_response - elif 'error' in response and method in error_formatters: - formatter = error_formatters[method] - formatted_response = assoc( - response, - 'error', - formatter(response['error']), - ) - return formatted_response - else: - return response - return middleware +def construct_formatting_middleware( + request_formatters=None, + result_formatters=None, + error_formatters=None): + def ignore_web3_in_standard_formatters(w3): + return dict( + request_formatters=request_formatters or {}, + result_formatters=result_formatters or {}, + error_formatters=error_formatters or {}, + ) + + return construct_web3_formatting_middleware(ignore_web3_in_standard_formatters) + + +def construct_web3_formatting_middleware(web3_formatters_builder): + def formatter_middleware(make_request, w3): + formatters = merge( + { + 'request_formatters': {}, + 'result_formatters': {}, + 'error_formatters': {}, + }, + web3_formatters_builder(w3), + ) + return apply_formatters(make_request=make_request, **formatters) + return formatter_middleware + + +@curry +def apply_formatters( + method, + params, + make_request, + request_formatters, + result_formatters, + error_formatters): + if method in request_formatters: + formatter = request_formatters[method] + formatted_params = formatter(params) + response = make_request(method, formatted_params) + else: + response = make_request(method, params) + + if 'result' in response and method in result_formatters: + formatter = result_formatters[method] + formatted_response = assoc( + response, + 'result', + formatter(response['result']), + ) + return formatted_response + elif 'error' in response and method in error_formatters: + formatter = error_formatters[method] + formatted_response = assoc( + response, + 'error', + formatter(response['error']), + ) + return formatted_response + else: + return response diff --git a/web3/middleware/validation.py b/web3/middleware/validation.py index 9669ae9a70..c293a341a2 100644 --- a/web3/middleware/validation.py +++ b/web3/middleware/validation.py @@ -1,16 +1,30 @@ from cytoolz import ( + complement, compose, curry, dissoc, ) from eth_utils.curried import ( apply_formatter_at_index, + apply_formatter_if, apply_formatters_to_dict, + is_null, +) +from hexbytes import ( + HexBytes, ) from web3.exceptions import ( ValidationError, ) +from web3.middleware.formatting import ( + construct_web3_formatting_middleware, +) + +MAX_EXTRADATA_LENGTH = 32 + + +is_not_null = complement(is_null) @curry @@ -27,20 +41,72 @@ def validate_chain_id(web3, chain_id): ) +def check_extradata_length(val): + if not isinstance(val, (str, int, bytes)): + return val + result = HexBytes(val) + if len(result) > MAX_EXTRADATA_LENGTH: + raise ValidationError( + "The field extraData is %d bytes, but should be %d. " + "It is quite likely that you are connected to a POA chain. " + "Refer " + "http://web3py.readthedocs.io/en/stable/middleware.html#geth-style-proof-of-authority " + "for more details. The full extraData is: %r" % ( + len(result), MAX_EXTRADATA_LENGTH, result + ) + ) + return val + + def transaction_normalizer(transaction): return dissoc(transaction, 'chainId') -def validation_middleware(make_request, web3): - transaction_validator = apply_formatters_to_dict({ - }) +def transaction_param_validator(web3): + transactions_params_validators = { + 'chainId': apply_formatter_if( + # Bypass `validate_chain_id` if chainId can't be determined + lambda _: is_not_null(web3.net.chainId), + validate_chain_id(web3) + ), + } + return apply_formatter_at_index( + apply_formatters_to_dict(transactions_params_validators), + 0 + ) + + +BLOCK_VALIDATORS = { + 'extraData': check_extradata_length, +} + + +block_validator = apply_formatter_if( + is_not_null, + apply_formatters_to_dict(BLOCK_VALIDATORS) +) + + +@curry +def chain_id_validator(web3): + return compose( + apply_formatter_at_index(transaction_normalizer, 0), + transaction_param_validator(web3) + ) + + +def build_validators_with_web3(w3): + return dict( + request_formatters={ + 'eth_sendTransaction': chain_id_validator(w3), + 'eth_estimateGas': chain_id_validator(w3), + 'eth_call': chain_id_validator(w3), + }, + result_formatters={ + 'eth_getBlockByHash': block_validator, + 'eth_getBlockByNumber': block_validator, + }, + ) - transaction_sanitizer = compose(transaction_normalizer, transaction_validator) - def middleware(method, params): - if method in {'eth_sendTransaction', 'eth_estimateGas', 'eth_call'}: - post_validated_params = apply_formatter_at_index(transaction_sanitizer, 0, params) - return make_request(method, post_validated_params) - else: - return make_request(method, params) - return middleware +validation_middleware = construct_web3_formatting_middleware(build_validators_with_web3)