From 538d62bd4bc0aef26364ab7e49a4a9e24a0b247b Mon Sep 17 00:00:00 2001
From: Emmanuel Evbuomwan
Date: Wed, 29 May 2024 14:41:55 +0200
Subject: [PATCH] refactor: consolidate version resolution and validation logic

---
Review notes (commentary after the three-dash line is not part of the commit):

* SchemaVersionManager.resolve_version now returns ResolvedVersion(int(version)).
  The bounds check already converted with int(version), but the value handed
  back was the caller's original object, so a numeric string such as "1" was
  returned as "1". Presumably ResolvedVersion is a typing.NewType over int
  (its definition is not part of this patch -- confirm in karapace/typing.py),
  in which case calling it performs no conversion at runtime.
* ("1", 1) was added to the resolve_version parametrisation to pin that
  behaviour.
* Both edits stay on existing lines, so the hunk line counts and the diffstat
  below are unchanged. The blob ids on the "index" lines are those of the
  original revision of this patch.

 karapace/schema_models.py        |  38 ++++++++++-
 tests/unit/test_schema_models.py | 111 +++++++++++++++++++++++++++++++
 2 files changed, 146 insertions(+), 3 deletions(-)
 create mode 100644 tests/unit/test_schema_models.py

diff --git a/karapace/schema_models.py b/karapace/schema_models.py
index 46e3832d5..b2c7b6b94 100644
--- a/karapace/schema_models.py
+++ b/karapace/schema_models.py
@@ -10,7 +10,7 @@
 from jsonschema import Draft7Validator
 from jsonschema.exceptions import SchemaError
 from karapace.dependency import Dependency
-from karapace.errors import InvalidSchema
+from karapace.errors import InvalidSchema, InvalidVersion, VersionNotFoundException
 from karapace.protobuf.exception import (
     Error as ProtobufError,
     IllegalArgumentException,
@@ -23,8 +23,8 @@
 from karapace.protobuf.schema import ProtobufSchema
 from karapace.schema_references import Reference
 from karapace.schema_type import SchemaType
-from karapace.typing import JsonObject, ResolvedVersion, SchemaId, Subject
-from karapace.utils import assert_never, json_decode, json_encode, JSONDecodeError
+from karapace.typing import JsonObject, ResolvedVersion, SchemaId, Subject, Version
+from karapace.utils import assert_never, catch_and_raise_error, json_decode, json_encode, JSONDecodeError
 from typing import Any, cast, Dict, Final, final, Mapping, Sequence
 
 import hashlib
@@ -388,3 +388,35 @@ class SchemaVersion:
     schema_id: SchemaId
     schema: TypedSchema
     references: Sequence[Reference] | None
+
+
+class SchemaVersionManager:
+    LATEST_SCHEMA_VERSION_TAG: Final = "latest"
+    MINUS_1_SCHEMA_VERSION_TAG: Final = "-1"
+
+    @classmethod
+    def latest_schema_tag_condition(cls, version: Version):
+        return (str(version) == cls.LATEST_SCHEMA_VERSION_TAG) or (str(version) == cls.MINUS_1_SCHEMA_VERSION_TAG)
+
+    @classmethod
+    @catch_and_raise_error(to_catch=(ValueError,), to_raise=VersionNotFoundException)
+    def resolve_version(
+        cls,
+        schema_versions: Mapping[ResolvedVersion, SchemaVersion],
+        version: Version,
+    ) -> ResolvedVersion | None:
+        max_version = max(schema_versions)
+        if cls.latest_schema_tag_condition(version):
+            return max_version
+        if (int(version) <= max_version) and (int(version) >= int(cls.MINUS_1_SCHEMA_VERSION_TAG)):
+            return ResolvedVersion(int(version))  # coerce: a numeric str must come back as the int it was checked as
+        return None
+
+    @classmethod
+    @catch_and_raise_error(to_catch=(ValueError,), to_raise=InvalidVersion)
+    def validate_version(cls, version: Version) -> Version | str | None:
+        if cls.latest_schema_tag_condition(version):
+            return cls.LATEST_SCHEMA_VERSION_TAG
+        if int(version) > 0:
+            return version
+        return None
diff --git a/tests/unit/test_schema_models.py b/tests/unit/test_schema_models.py
new file mode 100644
index 000000000..391f12630
--- /dev/null
+++ b/tests/unit/test_schema_models.py
@@ -0,0 +1,111 @@
+"""
+karapace - Test schema models
+
+Copyright (c) 2024 Aiven Ltd
+See LICENSE for details
+"""
+
+from avro.schema import Schema as AvroSchema
+from karapace.errors import InvalidVersion, VersionNotFoundException
+from karapace.schema_models import parse_avro_schema_definition, SchemaVersion, SchemaVersionManager, TypedSchema
+from karapace.schema_type import SchemaType
+from karapace.typing import ResolvedVersion, Version
+from typing import Callable
+
+import pytest
+
+# Schema versions factory fixture type
+SVFCallable = Callable[[None], Callable[[ResolvedVersion, dict[str]], dict[ResolvedVersion, SchemaVersion]]]
+
+
+class TestSchemaVersionManager:
+    @pytest.fixture
+    def avro_schema(self) -> str:
+        return '{"type":"record","name":"testRecord","fields":[{"type":"string","name":"test"}]}'
+
+    @pytest.fixture
+    def avro_schema_parsed(self, avro_schema: str) -> AvroSchema:
+        return parse_avro_schema_definition(avro_schema)
+
+    @pytest.fixture
+    def schema_versions_factory(
+        self,
+        avro_schema: str,
+        avro_schema_parsed: AvroSchema,
+    ) -> Callable[[ResolvedVersion, dict[str]], dict[ResolvedVersion, SchemaVersion]]:
+        def schema_versions(resolved_version: int, schema_version_data: dict[str] | None = None):
+            if schema_version_data is None:
+                schema_version_data = dict()
+            base_schema_version_data = dict(
+                subject="test-topic",
+                version=resolved_version,
+                deleted=False,
+                schema_id=1,
+                schema=TypedSchema(
+                    schema_type=SchemaType.AVRO,
+                    schema_str=avro_schema,
+                    schema=avro_schema_parsed,
+                ),
+                references=None,
+            )
+            return {ResolvedVersion(resolved_version): SchemaVersion(**{**base_schema_version_data, **schema_version_data})}
+
+        return schema_versions
+
+    def test_schema_version_manager_tags(self):
+        assert SchemaVersionManager.LATEST_SCHEMA_VERSION_TAG == "latest"
+        assert SchemaVersionManager.MINUS_1_SCHEMA_VERSION_TAG == "-1"
+
+    @pytest.mark.parametrize(
+        "version, is_latest",
+        [("latest", True), ("-1", True), ("-20", False), (10, False)],
+    )
+    def test_schema_version_manager_latest_schema_tag_condition(
+        self,
+        version: Version,
+        is_latest: bool,
+    ):
+        assert SchemaVersionManager.latest_schema_tag_condition(version) is is_latest
+
+    @pytest.mark.parametrize("invalid_version", ["invalid_version", 0])
+    def test_schema_version_manager_validate_version_invalid(self, invalid_version: str | int):
+        with pytest.raises(InvalidVersion):
+            SchemaVersionManager.validate_version(invalid_version)
+
+    @pytest.mark.parametrize(
+        "version, validated_version",
+        [("latest", "latest"), (-1, "latest"), ("-1", "latest"), (10, 10)],
+    )
+    def test_schema_version_manager_validate_version(
+        self,
+        version: Version,
+        validated_version: Version,
+    ):
+        assert SchemaVersionManager.validate_version(version) == validated_version
+
+    @pytest.mark.parametrize(
+        "version, resolved_version",
+        [("-1", 10), (-1, 10), (1, 1), ("1", 1), (10, 10), ("latest", 10)],
+    )
+    def test_schema_version_manager_resolve_version(
+        self,
+        version: Version,
+        resolved_version: ResolvedVersion,
+        schema_versions_factory: SVFCallable,
+    ):
+        schema_versions = dict()
+        schema_versions.update(schema_versions_factory(1))
+        schema_versions.update(schema_versions_factory(2))
+        schema_versions.update(schema_versions_factory(10))
+        assert SchemaVersionManager.resolve_version(schema_versions, version) == resolved_version
+
+    @pytest.mark.parametrize("invalid_version", ["invalid_version", 0, -20, "-10", "100", 2000])
+    def test_schema_version_manager_resolve_version_invalid(
+        self,
+        invalid_version: str | int,
+        schema_versions_factory: SVFCallable,
+    ):
+        schema_versions = dict()
+        schema_versions.update(schema_versions_factory(1))
+        with pytest.raises(VersionNotFoundException):
+            SchemaVersionManager.resolve_version(schema_versions, invalid_version)