Skip to content

Commit

Permalink
refactor: use SchemaVersionManager to handle version logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nosahama committed May 29, 2024
1 parent 538d62b commit 59046e7
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 61 deletions.
10 changes: 5 additions & 5 deletions karapace/schema_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from karapace.schema_references import Reference
from karapace.schema_type import SchemaType
from karapace.typing import JsonObject, ResolvedVersion, SchemaId, Subject, Version
from karapace.utils import assert_never, catch_and_raise_error, json_decode, json_encode, JSONDecodeError
from karapace.utils import assert_never, intstr_conversion_guard, json_decode, json_encode, JSONDecodeError
from typing import Any, cast, Dict, Final, final, Mapping, Sequence

import hashlib
Expand Down Expand Up @@ -395,11 +395,11 @@ class SchemaVersionManager:
MINUS_1_SCHEMA_VERSION_TAG: Final = "-1"

@classmethod
def latest_schema_tag_condition(cls, version: Version):
def latest_schema_tag_condition(cls, version: Version) -> bool:
return (str(version) == cls.LATEST_SCHEMA_VERSION_TAG) or (str(version) == cls.MINUS_1_SCHEMA_VERSION_TAG)

@classmethod
@catch_and_raise_error(to_catch=(ValueError,), to_raise=VersionNotFoundException)
@intstr_conversion_guard(to_raise=VersionNotFoundException())
def resolve_version(
cls,
schema_versions: Mapping[ResolvedVersion, SchemaVersion],
Expand All @@ -409,11 +409,11 @@ def resolve_version(
if cls.latest_schema_tag_condition(version):
return max_version
if (int(version) <= max_version) and (int(version) >= int(cls.MINUS_1_SCHEMA_VERSION_TAG)):
return ResolvedVersion(version)
return ResolvedVersion(int(version))
return None

@classmethod
@catch_and_raise_error(to_catch=(ValueError,), to_raise=InvalidVersion)
@intstr_conversion_guard(to_raise=InvalidVersion())
def validate_version(cls, version: Version) -> Version | str | None:
if cls.latest_schema_tag_condition(version):
return cls.LATEST_SCHEMA_VERSION_TAG
Expand Down
48 changes: 15 additions & 33 deletions karapace/schema_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from karapace.dependency import Dependency
from karapace.errors import (
IncompatibleSchema,
InvalidVersion,
ReferenceExistsException,
SchemasNotFoundException,
SchemaVersionNotSoftDeletedException,
Expand All @@ -26,43 +25,25 @@
from karapace.master_coordinator import MasterCoordinator
from karapace.messaging import KarapaceProducer
from karapace.offset_watcher import OffsetWatcher
from karapace.schema_models import ParsedTypedSchema, SchemaType, SchemaVersion, TypedSchema, ValidatedTypedSchema
from karapace.schema_models import (
ParsedTypedSchema,
SchemaType,
SchemaVersion,
SchemaVersionManager,
TypedSchema,
ValidatedTypedSchema,
)
from karapace.schema_reader import KafkaSchemaReader
from karapace.schema_references import LatestVersionReference, Reference
from karapace.typing import JsonObject, Mode, ResolvedVersion, SchemaId, Subject, Version
from typing import Mapping, Sequence
from typing import Sequence

import asyncio
import logging

LOG = logging.getLogger(__name__)


def _resolve_version(
schema_versions: Mapping[ResolvedVersion, SchemaVersion],
version: Version,
) -> ResolvedVersion:
max_version = max(schema_versions)
if isinstance(version, str) and version == "latest":
return max_version
resolved_version = ResolvedVersion(int(version))
if resolved_version <= max_version:
return resolved_version
raise VersionNotFoundException()


def validate_version(version: Version) -> Version:
try:
version_number = int(version)
if version_number > 0:
return version
raise InvalidVersion(f"Invalid version {version_number}")
except ValueError as ex:
if version == "latest":
return version
raise InvalidVersion(f"Invalid version {version}") from ex


class KarapaceSchemaRegistry:
def __init__(self, config: Config) -> None:
self.config = config
Expand All @@ -82,6 +63,7 @@ def __init__(self, config: Config) -> None:
master_coordinator=self.mc,
database=self.database,
)
self.schema_version_manager = SchemaVersionManager()

self.schema_lock = asyncio.Lock()
self._master_lock = asyncio.Lock()
Expand Down Expand Up @@ -222,7 +204,7 @@ async def subject_version_delete_local(self, subject: Subject, version: Version,
for version_id, schema_version in schema_versions.items()
if schema_version.deleted is False
}
resolved_version = _resolve_version(schema_versions=schema_versions, version=version)
resolved_version = self.schema_version_manager.resolve_version(schema_versions=schema_versions, version=version)
schema_version = schema_versions.get(resolved_version, None)

if not schema_version:
Expand Down Expand Up @@ -261,11 +243,11 @@ def subject_get(self, subject: Subject, include_deleted: bool = False) -> dict[R
return schemas

def subject_version_get(self, subject: Subject, version: Version, *, include_deleted: bool = False) -> JsonObject:
validate_version(version)
self.schema_version_manager.validate_version(version)
schema_versions = self.subject_get(subject, include_deleted=include_deleted)
if not schema_versions:
raise SubjectNotFoundException()
resolved_version = _resolve_version(schema_versions=schema_versions, version=version)
resolved_version = self.schema_version_manager.resolve_version(schema_versions=schema_versions, version=version)
schema_data: SchemaVersion | None = schema_versions.get(resolved_version, None)

if not schema_data:
Expand Down Expand Up @@ -293,11 +275,11 @@ def subject_version_get(self, subject: Subject, version: Version, *, include_del
async def subject_version_referencedby_get(
self, subject: Subject, version: Version, *, include_deleted: bool = False
) -> list:
validate_version(version)
self.schema_version_manager.validate_version(version)
schema_versions = self.subject_get(subject, include_deleted=include_deleted)
if not schema_versions:
raise SubjectNotFoundException()
resolved_version = _resolve_version(schema_versions=schema_versions, version=version)
resolved_version = self.schema_version_manager.resolve_version(schema_versions=schema_versions, version=version)
schema_data: SchemaVersion | None = schema_versions.get(resolved_version, None)
if not schema_data:
raise VersionNotFoundException()
Expand Down
13 changes: 10 additions & 3 deletions karapace/schema_registry_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,16 @@
from karapace.karapace import KarapaceBase
from karapace.protobuf.exception import ProtobufUnresolvedDependencyException
from karapace.rapu import HTTPRequest, JSON_CONTENT_TYPE, SERVER_NAME
from karapace.schema_models import ParsedTypedSchema, SchemaType, SchemaVersion, TypedSchema, ValidatedTypedSchema
from karapace.schema_models import (
ParsedTypedSchema,
SchemaType,
SchemaVersion,
SchemaVersionManager,
TypedSchema,
ValidatedTypedSchema,
)
from karapace.schema_references import LatestVersionReference, Reference, reference_from_mapping
from karapace.schema_registry import KarapaceSchemaRegistry, validate_version
from karapace.schema_registry import KarapaceSchemaRegistry
from karapace.typing import JsonData, JsonObject, ResolvedVersion, SchemaId, Subject
from karapace.utils import JSONDecodeError
from typing import Any
Expand Down Expand Up @@ -814,7 +821,7 @@ async def subject_version_delete(
self, content_type: str, *, subject: str, version: str, request: HTTPRequest, user: User | None = None
) -> None:
self._check_authorization(user, Operation.Write, f"Subject:{subject}")
version = validate_version(version)
version = SchemaVersionManager.validate_version(version)
permanent = request.query.get("permanent", "false").lower() == "true"

are_we_master, master_url = await self.schema_registry.get_master()
Expand Down
4 changes: 2 additions & 2 deletions karapace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def convert_to_int(object_: dict, key: str, content_type: str) -> None:
)


def catch_and_raise_error(to_catch: tuple[Exception], to_raise: Exception):
def intstr_conversion_guard(to_raise: BaseException):
def wrapper(f):
@functools.wraps(f)
def catcher(*args, **kwargs):
Expand All @@ -224,7 +224,7 @@ def catcher(*args, **kwargs):
if not value:
raise to_raise
return value
except to_catch as exc:
except ValueError as exc:
raise to_raise from exc

return catcher
Expand Down
11 changes: 5 additions & 6 deletions tests/unit/test_schema_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from karapace.schema_models import parse_avro_schema_definition, SchemaVersion, SchemaVersionManager, TypedSchema
from karapace.schema_type import SchemaType
from karapace.typing import ResolvedVersion, Version
from typing import Callable
from typing import Any, Callable, Dict

import pytest

# Schema versions factory fixture type
SVFCallable = Callable[[None], Callable[[ResolvedVersion, dict[str]], dict[ResolvedVersion, SchemaVersion]]]
SVFCallable = Callable[[None], Callable[[ResolvedVersion, Dict[str, Any]], Dict[ResolvedVersion, SchemaVersion]]]


class TestSchemaVersionManager:
Expand All @@ -32,10 +32,9 @@ def schema_versions_factory(
self,
avro_schema: str,
avro_schema_parsed: AvroSchema,
) -> Callable[[ResolvedVersion, dict[str]], dict[ResolvedVersion, SchemaVersion]]:
def schema_versions(resolved_version: int, schema_version_data: dict[str] | None = None):
if schema_version_data is None:
schema_version_data = dict()
) -> Callable[[ResolvedVersion, Dict[str, Any]], Dict[ResolvedVersion, SchemaVersion]]:
def schema_versions(resolved_version: int, schema_version_data: Dict[str, Any] | None = None):
schema_version_data = schema_version_data or dict()
base_schema_version_data = dict(
subject="test-topic",
version=resolved_version,
Expand Down
17 changes: 5 additions & 12 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,18 @@
See LICENSE for details
"""

from karapace.utils import catch_and_raise_error
from karapace.utils import intstr_conversion_guard

import pytest


def test_catch_and_raise_error():
def test_intstr_conversion_guard():
class RaiseMe(Exception):
pass

@catch_and_raise_error(to_catch=(ValueError,), to_raise=RaiseMe)
def v():
@intstr_conversion_guard(to_raise=RaiseMe)
def raise_value_error():
int("not a number")

with pytest.raises(RaiseMe):
v()

@catch_and_raise_error(to_catch=(ZeroDivisionError,), to_raise=RaiseMe)
def z():
_ = 100 / 0

with pytest.raises(RaiseMe):
z()
raise_value_error()

0 comments on commit 59046e7

Please sign in to comment.