Skip to content

Commit

Permalink
Fix schema validation
Browse files Browse the repository at this point in the history
Stop using the global schema in pure functions. Also now properly allows
custom schema validators.

Next step for fixing schema validation is #166
  • Loading branch information
bcb committed Jul 3, 2021
1 parent cb89a45 commit f8f5352
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 33 deletions.
45 changes: 30 additions & 15 deletions jsonrpcserver/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing import Any, Callable, Dict, List, Union

from apply_defaults import apply_config # type: ignore
from jsonschema import ValidationError # type: ignore
from jsonschema.validators import validator_for # type: ignore
from pkg_resources import resource_string # type: ignore

Expand All @@ -28,11 +27,14 @@
)
from .result import InvalidParams, InternalError, Result

# Prepare the jsonschema validator
global_schema = json.loads(resource_string(__name__, "request-schema.json"))
klass = validator_for(global_schema)
klass.check_schema(global_schema)
validator = klass(global_schema)
default_deserializer = json.loads

# Prepare the jsonschema validator. This is global so it loads only once, not every
# time dispatch is called.
schema = json.loads(resource_string(__name__, "request-schema.json"))
klass = validator_for(schema)
klass.check_schema(schema)
default_schema_validator = klass(schema).validate

# Read configuration file
config = ConfigParser(default_section="dispatch")
Expand Down Expand Up @@ -140,7 +142,7 @@ def create_requests(requests: Union[Dict, List[Dict]]) -> Union[Request, List[Re
)


def validate(request: Union[Dict, List]) -> Union[Dict, List]:
def validate(validator: Callable, request: Union[Dict, List]) -> Union[Dict, List]:
"""
Wraps jsonschema.validate, returning the same object passed in if successful.
Expand All @@ -153,14 +155,19 @@ def validate(request: Union[Dict, List]) -> Union[Dict, List]:
The same object passed in.
Raises:
jsonschema.ValidationError
An exception,
"""
validator.validate(request)
validator(request)
return request


def dispatch_to_response_pure(
*, methods: Methods, context: Any, deserializer: Callable, request: str
*,
methods: Methods,
context: Any,
schema_validator: Callable,
deserializer: Callable,
request: str,
) -> Union[Response, List[Response], None]:
"""
Dispatch a JSON-serialized request string to methods.
Expand All @@ -186,9 +193,12 @@ def dispatch_to_response_pure(
# will be raised is unknown. Any exception is a parse error.
except Exception as exc:
return ParseErrorResponse(str(exc))
# As above, we don't know which validator will be used, so the specific
# exception that will be raised is unknown. Any exception is an invalid request
# error.
try:
validate(deserialized)
except ValidationError as exc:
schema_validator(deserialized)
except Exception as exc:
return InvalidRequestResponse("The request failed schema validation")
return dispatch_requests(
methods=methods, context=context, requests=create_requests(deserialized)
Expand All @@ -204,7 +214,8 @@ def dispatch_to_response(
methods: Methods = None,
*,
context: Any = None,
deserializer: Callable = json.loads,
schema_validator: Callable = default_schema_validator,
deserializer: Callable = default_deserializer,
) -> Union[Response, List[Response], None]:
"""
Dispatch a JSON-serialized request to methods.
Expand All @@ -218,9 +229,10 @@ def dispatch_to_response(
request: The JSON-RPC request string.
methods: Collection of methods that can be called. If not passed, uses the
internal methods object.
request: The incoming request string.
context: Will be passed to methods as the first param if not None.
schema_validator:
deserialize: Function that is used to deserialize data.
request: The incoming request string.
Returns:
A Response, list of Responses or None.
Expand All @@ -231,13 +243,16 @@ def dispatch_to_response(
return dispatch_to_response_pure(
methods=global_methods if methods is None else methods,
context=context,
schema_validator=schema_validator,
deserializer=deserializer,
request=request,
)


def dispatch_to_json(
*args: Any, serializer: Callable = json.dumps, **kwargs: Any
*args: Any,
serializer: Callable = json.dumps,
**kwargs: Any,
) -> str:
"""
This is the main public method, it goes through the entire JSON-RPC process
Expand Down
68 changes: 50 additions & 18 deletions tests/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from jsonrpcserver import status
from jsonrpcserver.dispatcher import (
create_requests,
default_deserializer,
default_schema_validator,
dispatch_request,
dispatch_to_response,
dispatch_to_response_pure,
Expand Down Expand Up @@ -117,7 +119,8 @@ def test_dispatch_to_response_pure():
response = dispatch_to_response_pure(
methods=Methods(ping),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "ping", "id": 1}',
)
assert isinstance(response, SuccessResponse)
Expand All @@ -129,7 +132,8 @@ def test_dispatch_to_response_pure_notification():
response = dispatch_to_response_pure(
methods=Methods(ping),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "ping"}',
)
assert response is None
Expand All @@ -139,7 +143,8 @@ def test_dispatch_to_response_pure_notification_invalid_jsonrpc():
response = dispatch_to_response_pure(
methods=Methods(ping),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "0", "method": "notify"}',
)
assert isinstance(response, ErrorResponse)
Expand All @@ -148,15 +153,23 @@ def test_dispatch_to_response_pure_notification_invalid_jsonrpc():
def test_dispatch_to_response_pure_invalid_json():
"""Unable to parse, must return an error"""
response = dispatch_to_response_pure(
methods=Methods(ping), context=None, deserializer=json.loads, request="{"
methods=Methods(ping),
context=None,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request="{",
)
assert isinstance(response, ErrorResponse)


def test_dispatch_to_response_pure_invalid_jsonrpc():
"""Invalid JSON-RPC, must return an error. (impossible to determine if notification)"""
response = dispatch_to_response_pure(
methods=Methods(ping), context=None, deserializer=json.loads, request="{}"
methods=Methods(ping),
context=None,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request="{}",
)
assert isinstance(response, ErrorResponse)

Expand All @@ -169,7 +182,8 @@ def foo(colour: str) -> Result:
response = dispatch_to_response_pure(
methods=Methods(foo),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "foo", "params": ["blue"], "id": 1}',
)
assert isinstance(response, ErrorResponse)
Expand All @@ -182,7 +196,8 @@ def foo(colour: str, size: str):
response = dispatch_to_response_pure(
methods=Methods(foo),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "foo", "params": {"colour":"blue"}, "id": 1}',
)
assert isinstance(response, ErrorResponse)
Expand Down Expand Up @@ -216,7 +231,8 @@ def subtract(minuend, subtrahend):
response = dispatch_to_response_pure(
methods=Methods(subtract),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "subtract", "params": [42, 23], "id": 1}',
)
assert isinstance(response, SuccessResponse)
Expand All @@ -226,7 +242,8 @@ def subtract(minuend, subtrahend):
response = dispatch_to_response_pure(
methods=Methods(subtract),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "subtract", "params": [23, 42], "id": 2}',
)
assert isinstance(response, SuccessResponse)
Expand All @@ -240,7 +257,8 @@ def subtract(**kwargs):
response = dispatch_to_response_pure(
methods=Methods(subtract),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "subtract", "params": {"subtrahend": 23, "minuend": 42}, "id": 3}',
)
assert isinstance(response, SuccessResponse)
Expand All @@ -250,7 +268,8 @@ def subtract(**kwargs):
response = dispatch_to_response_pure(
methods=Methods(subtract),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "subtract", "params": {"minuend": 42, "subtrahend": 23}, "id": 4}',
)
assert isinstance(response, SuccessResponse)
Expand All @@ -261,7 +280,8 @@ def test_examples_notification():
response = dispatch_to_response_pure(
methods=Methods(update=lambda: None, foobar=lambda: None),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "update", "params": [1, 2, 3, 4, 5]}',
)
assert response is None
Expand All @@ -270,7 +290,8 @@ def test_examples_notification():
response = dispatch_to_response_pure(
methods=Methods(update=lambda: None, foobar=lambda: None),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='{"jsonrpc": "2.0", "method": "foobar"}',
)
assert response is None
Expand All @@ -280,7 +301,8 @@ def test_examples_invalid_json():
response = dispatch_to_response_pure(
methods=Methods(ping),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request='[{"jsonrpc": "2.0", "method": "sum", "params": [1,2,4], "id": "1"}, {"jsonrpc": "2.0", "method"]',
)
assert isinstance(response, ErrorResponse)
Expand All @@ -293,7 +315,8 @@ def test_examples_empty_array():
request="[]",
methods=Methods(ping),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
)
assert isinstance(response, ErrorResponse)
assert response.code == status.JSONRPC_INVALID_REQUEST_CODE
Expand All @@ -305,7 +328,11 @@ def test_examples_invalid_jsonrpc_batch():
The examples are expecting a batch response full of error responses.
"""
response = dispatch_to_response_pure(
methods=Methods(ping), context=None, deserializer=json.loads, request="[1]"
methods=Methods(ping),
context=None,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request="[1]",
)
assert isinstance(response, ErrorResponse)
assert response.code == status.JSONRPC_INVALID_REQUEST_CODE
Expand All @@ -319,7 +346,8 @@ def test_examples_multiple_invalid_jsonrpc():
response = dispatch_to_response_pure(
methods=Methods(ping),
context=None,
deserializer=json.loads,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request="[1, 2, 3]",
)
assert isinstance(response, ErrorResponse)
Expand Down Expand Up @@ -357,7 +385,11 @@ def test_examples_mixed_requests_and_notifications():
]
)
response = dispatch_to_response_pure(
methods=methods, context=None, deserializer=json.loads, request=requests
methods=methods,
context=None,
schema_validator=default_schema_validator,
deserializer=default_deserializer,
request=requests,
)
expected = [
SuccessResponse(
Expand Down

0 comments on commit f8f5352

Please sign in to comment.