Skip to content

Commit

Permalink
Fix schema validation (#175)
Browse files Browse the repository at this point in the history
* Fix schema validation

Stop using the global schema in pure functions. Also now properly allows
custom schema validators.

Next step for fixing schema validation is #166

* Remove unused "validate" function
  • Loading branch information
bcb authored Jul 3, 2021
1 parent cb89a45 commit 96e72a4
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 49 deletions.
58 changes: 27 additions & 31 deletions jsonrpcserver/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing import Any, Callable, Dict, List, Union

from apply_defaults import apply_config # type: ignore
from jsonschema import ValidationError # type: ignore
from jsonschema.validators import validator_for # type: ignore
from pkg_resources import resource_string # type: ignore

Expand All @@ -28,11 +27,14 @@
)
from .result import InvalidParams, InternalError, Result

# Prepare the jsonschema validator
global_schema = json.loads(resource_string(__name__, "request-schema.json"))
klass = validator_for(global_schema)
klass.check_schema(global_schema)
validator = klass(global_schema)
default_deserializer = json.loads

# Prepare the jsonschema validator. This is global so it loads only once, not every
# time dispatch is called.
schema = json.loads(resource_string(__name__, "request-schema.json"))
klass = validator_for(schema)
klass.check_schema(schema)
default_schema_validator = klass(schema).validate

# Read configuration file
config = ConfigParser(default_section="dispatch")
Expand Down Expand Up @@ -140,27 +142,13 @@ def create_requests(requests: Union[Dict, List[Dict]]) -> Union[Request, List[Re
)


def validate(request: Union[Dict, List]) -> Union[Dict, List]:
"""
Wraps jsonschema.validate, returning the same object passed in if successful.
Raises an exception if invalid.
Args:
request: The deserialized-from-json request.
Returns:
The same object passed in.
Raises:
jsonschema.ValidationError
"""
validator.validate(request)
return request


def dispatch_to_response_pure(
*, methods: Methods, context: Any, deserializer: Callable, request: str
*,
methods: Methods,
context: Any,
schema_validator: Callable,
deserializer: Callable,
request: str,
) -> Union[Response, List[Response], None]:
"""
Dispatch a JSON-serialized request string to methods.
Expand All @@ -186,9 +174,12 @@ def dispatch_to_response_pure(
# will be raised is unknown. Any exception is a parse error.
except Exception as exc:
return ParseErrorResponse(str(exc))
# As above, we don't know which validator will be used, so the specific
# exception that will be raised is unknown. Any exception is an invalid request
# error.
try:
validate(deserialized)
except ValidationError as exc:
schema_validator(deserialized)
except Exception as exc:
return InvalidRequestResponse("The request failed schema validation")
return dispatch_requests(
methods=methods, context=context, requests=create_requests(deserialized)
Expand All @@ -204,7 +195,8 @@ def dispatch_to_response(
methods: Methods = None,
*,
context: Any = None,
deserializer: Callable = json.loads,
schema_validator: Callable = default_schema_validator,
deserializer: Callable = default_deserializer,
) -> Union[Response, List[Response], None]:
"""
Dispatch a JSON-serialized request to methods.
Expand All @@ -218,9 +210,10 @@ def dispatch_to_response(
request: The JSON-RPC request string.
methods: Collection of methods that can be called. If not passed, uses the
internal methods object.
request: The incoming request string.
context: Will be passed to methods as the first param if not None.
schema_validator:
deserialize: Function that is used to deserialize data.
request: The incoming request string.
Returns:
A Response, list of Responses or None.
Expand All @@ -231,13 +224,16 @@ def dispatch_to_response(
return dispatch_to_response_pure(
methods=global_methods if methods is None else methods,
context=context,
schema_validator=schema_validator,
deserializer=deserializer,
request=request,
)


def dispatch_to_json(
*args: Any, serializer: Callable = json.dumps, **kwargs: Any
*args: Any,
serializer: Callable = json.dumps,
**kwargs: Any,
) -> str:
"""
This is the main public method, it goes through the entire JSON-RPC process
Expand Down
68 changes: 50 additions & 18 deletions tests/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from jsonrpcserver import status
from jsonrpcserver.dispatcher import (
create_requests,
default_deserializer,
default_schema_validator,
dispatch_request,
dispatch_to_response,
dispatch_to_response_pure,
Expand Down Expand Up @@ -117,7 +119,8 @@ def test_dispatch_to_response_pure():
response = dispatch_to_response_pure(
methods=Methods(ping),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "ping", "id": 1}',
)
assert isinstance(response, SuccessResponse)
Expand All @@ -129,7 +132,8 @@ def test_dispatch_to_response_pure_notification():
response = dispatch_to_response_pure(
methods=Methods(ping),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "ping"}',
)
assert response is None
Expand All @@ -139,7 +143,8 @@ def test_dispatch_to_response_pure_notification_invalid_jsonrpc():
response = dispatch_to_response_pure(
methods=Methods(ping),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "0", "method": "notify"}',
)
assert isinstance(response, ErrorResponse)
Expand All @@ -148,15 +153,23 @@ def test_dispatch_to_response_pure_notification_invalid_jsonrpc():
def test_dispatch_to_response_pure_invalid_json():
"""Unable to parse, must return an error"""
response = dispatch_to_response_pure(
methods=Methods(ping), context=None, deserializer=json.loads, request="{"
methods=Methods(ping),
context=None,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request="{",
)
assert isinstance(response, ErrorResponse)


def test_dispatch_to_response_pure_invalid_jsonrpc():
"""Invalid JSON-RPC, must return an error. (impossible to determine if notification)"""
response = dispatch_to_response_pure(
methods=Methods(ping), context=None, deserializer=json.loads, request="{}"
methods=Methods(ping),
context=None,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request="{}",
)
assert isinstance(response, ErrorResponse)

Expand All @@ -169,7 +182,8 @@ def foo(colour: str) -> Result:
response = dispatch_to_response_pure(
methods=Methods(foo),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "foo", "params": ["blue"], "id": 1}',
)
assert isinstance(response, ErrorResponse)
Expand All @@ -182,7 +196,8 @@ def foo(colour: str, size: str):
response = dispatch_to_response_pure(
methods=Methods(foo),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "foo", "params": {"colour":"blue"}, "id": 1}',
)
assert isinstance(response, ErrorResponse)
Expand Down Expand Up @@ -216,7 +231,8 @@ def subtract(minuend, subtrahend):
response = dispatch_to_response_pure(
methods=Methods(subtract),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "subtract", "params": [42, 23], "id": 1}',
)
assert isinstance(response, SuccessResponse)
Expand All @@ -226,7 +242,8 @@ def subtract(minuend, subtrahend):
response = dispatch_to_response_pure(
methods=Methods(subtract),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "subtract", "params": [23, 42], "id": 2}',
)
assert isinstance(response, SuccessResponse)
Expand All @@ -240,7 +257,8 @@ def subtract(**kwargs):
response = dispatch_to_response_pure(
methods=Methods(subtract),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "subtract", "params": {"subtrahend": 23, "minuend": 42}, "id": 3}',
)
assert isinstance(response, SuccessResponse)
Expand All @@ -250,7 +268,8 @@ def subtract(**kwargs):
response = dispatch_to_response_pure(
methods=Methods(subtract),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "subtract", "params": {"minuend": 42, "subtrahend": 23}, "id": 4}',
)
assert isinstance(response, SuccessResponse)
Expand All @@ -261,7 +280,8 @@ def test_examples_notification():
response = dispatch_to_response_pure(
methods=Methods(update=lambda: None, foobar=lambda: None),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "update", "params": [1, 2, 3, 4, 5]}',
)
assert response is None
Expand All @@ -270,7 +290,8 @@ def test_examples_notification():
response = dispatch_to_response_pure(
methods=Methods(update=lambda: None, foobar=lambda: None),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "foobar"}',
)
assert response is None
Expand All @@ -280,7 +301,8 @@ def test_examples_invalid_json():
response = dispatch_to_response_pure(
methods=Methods(ping),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='[{"jsonrpc": "2.0", "method": "sum", "params": [1,2,4], "id": "1"}, {"jsonrpc": "2.0", "method"]',
)
assert isinstance(response, ErrorResponse)
Expand All @@ -293,7 +315,8 @@ def test_examples_empty_array():
request="[]",
methods=Methods(ping),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
)
assert isinstance(response, ErrorResponse)
assert response.code == status.JSONRPC_INVALID_REQUEST_CODE
Expand All @@ -305,7 +328,11 @@ def test_examples_invalid_jsonrpc_batch():
The examples are expecting a batch response full of error responses.
"""
response = dispatch_to_response_pure(
methods=Methods(ping), context=None, deserializer=json.loads, request="[1]"
methods=Methods(ping),
context=None,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request="[1]",
)
assert isinstance(response, ErrorResponse)
assert response.code == status.JSONRPC_INVALID_REQUEST_CODE
Expand All @@ -319,7 +346,8 @@ def test_examples_multiple_invalid_jsonrpc():
response = dispatch_to_response_pure(
methods=Methods(ping),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request="[1, 2, 3]",
)
assert isinstance(response, ErrorResponse)
Expand Down Expand Up @@ -357,7 +385,11 @@ def test_examples_mixed_requests_and_notifications():
]
)
response = dispatch_to_response_pure(
methods=methods, context=None, deserializer=json.loads, request=requests
methods=methods,
context=None,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request=requests,
)
expected = [
SuccessResponse(
Expand Down

0 comments on commit 96e72a4

Please sign in to comment.