Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

normalization: add normalization of the options #848

Merged
merged 4 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion karapace/kafka_rest_apis/consumer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,6 @@ async def fetch(self, internal_name: Tuple[str, str], content_type: str, formats
)
# we get to be more in line with the confluent proxy by doing a bunch of fetches each time and
# respecting the max fetch request size
# pylint: disable=protected-access
max_bytes = (
int(query_params["max_bytes"])
if "max_bytes" in query_params
Expand Down
2 changes: 2 additions & 0 deletions karapace/protobuf/option_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@


class OptionElement:
name: str

class Kind(Enum):
STRING = 1
BOOLEAN = 2
Expand Down
220 changes: 220 additions & 0 deletions karapace/protobuf/proto_normalizations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""
Copyright (c) 2024 Aiven Ltd
See LICENSE for details
"""
from karapace.protobuf.enum_constant_element import EnumConstantElement
from karapace.protobuf.enum_element import EnumElement
from karapace.protobuf.extend_element import ExtendElement
from karapace.protobuf.field_element import FieldElement
from karapace.protobuf.group_element import GroupElement
from karapace.protobuf.message_element import MessageElement
from karapace.protobuf.one_of_element import OneOfElement
from karapace.protobuf.option_element import OptionElement
from karapace.protobuf.proto_file_element import ProtoFileElement
from karapace.protobuf.rpc_element import RpcElement
from karapace.protobuf.service_element import ServiceElement
from karapace.protobuf.type_element import TypeElement
from karapace.typing import StrEnum
from typing import List


class ProtobufNormalisationOptions(StrEnum):
eliax1996 marked this conversation as resolved.
Show resolved Hide resolved
sort_options = "sort_options"


def sort_by_name(element: OptionElement) -> str:
return element.name


def type_field_element_with_sorted_options(type_field: FieldElement) -> FieldElement:
sorted_options = None if type_field.options is None else list(sorted(type_field.options, key=sort_by_name))
return FieldElement(
location=type_field.location,
label=type_field.label,
element_type=type_field.element_type,
name=type_field.name,
default_value=type_field.default_value,
json_name=type_field.json_name,
tag=type_field.tag,
documentation=type_field.documentation,
options=sorted_options,
)


def enum_constant_element_with_sorted_options(enum_constant: EnumConstantElement) -> EnumConstantElement:
sorted_options = None if enum_constant.options is None else list(sorted(enum_constant.options, key=sort_by_name))
return EnumConstantElement(
location=enum_constant.location,
name=enum_constant.name,
tag=enum_constant.tag,
documentation=enum_constant.documentation,
options=sorted_options,
)


def enum_element_with_sorted_options(enum_element: EnumElement) -> EnumElement:
aiven-anton marked this conversation as resolved.
Show resolved Hide resolved
sorted_options = None if enum_element.options is None else list(sorted(enum_element.options, key=sort_by_name))
constants_with_sorted_options = (
None
if enum_element.constants is None
else [enum_constant_element_with_sorted_options(constant) for constant in enum_element.constants]
)
return EnumElement(
location=enum_element.location,
name=enum_element.name,
documentation=enum_element.documentation,
options=sorted_options,
constants=constants_with_sorted_options,
)


def groups_with_sorted_options(group: GroupElement) -> GroupElement:
sorted_fields = (
None if group.fields is None else [type_field_element_with_sorted_options(field) for field in group.fields]
)
return GroupElement(
label=group.label,
location=group.location,
name=group.name,
tag=group.tag,
documentation=group.documentation,
fields=sorted_fields,
)


def one_ofs_with_sorted_options(one_ofs: OneOfElement) -> OneOfElement:
sorted_options = None if one_ofs.options is None else list(sorted(one_ofs.options, key=sort_by_name))
sorted_fields = [type_field_element_with_sorted_options(field) for field in one_ofs.fields]
sorted_groups = [groups_with_sorted_options(group) for group in one_ofs.groups]

return OneOfElement(
name=one_ofs.name,
documentation=one_ofs.documentation,
fields=sorted_fields,
groups=sorted_groups,
options=sorted_options,
)


def message_element_with_sorted_options(message_element: MessageElement) -> MessageElement:
sorted_options = None if message_element.options is None else list(sorted(message_element.options, key=sort_by_name))
sorted_neasted_types = [type_element_with_sorted_options(nested_type) for nested_type in message_element.nested_types]
eliax1996 marked this conversation as resolved.
Show resolved Hide resolved
sorted_fields = [type_field_element_with_sorted_options(field) for field in message_element.fields]
sorted_one_ofs = [one_ofs_with_sorted_options(one_of) for one_of in message_element.one_ofs]

return MessageElement(
location=message_element.location,
name=message_element.name,
documentation=message_element.documentation,
nested_types=sorted_neasted_types,
options=sorted_options,
reserveds=message_element.reserveds,
fields=sorted_fields,
one_ofs=sorted_one_ofs,
extensions=message_element.extensions,
groups=message_element.groups,
)


def type_element_with_sorted_options(type_element: TypeElement) -> TypeElement:
sorted_neasted_types: List[TypeElement] = []
eliax1996 marked this conversation as resolved.
Show resolved Hide resolved

for nested_type in type_element.nested_types:
if isinstance(nested_type, EnumElement):
sorted_neasted_types.append(enum_element_with_sorted_options(nested_type))
elif isinstance(nested_type, MessageElement):
sorted_neasted_types.append(message_element_with_sorted_options(nested_type))
else:
raise ValueError("Unknown type element") # tried with assert_never but it did not work
eliax1996 marked this conversation as resolved.
Show resolved Hide resolved

# doing it here since the subtypes do not declare the nested_types property
type_element.nested_types = sorted_neasted_types

if isinstance(type_element, EnumElement):
return enum_element_with_sorted_options(type_element)

if isinstance(type_element, MessageElement):
return message_element_with_sorted_options(type_element)

raise ValueError("Unknown type element") # tried with assert_never but it did not work
eliax1996 marked this conversation as resolved.
Show resolved Hide resolved


def extends_element_with_sorted_options(extend_element: ExtendElement) -> ExtendElement:
sorted_fields = (
None
if extend_element.fields is None
else [type_field_element_with_sorted_options(field) for field in extend_element.fields]
)
return ExtendElement(
location=extend_element.location,
name=extend_element.name,
documentation=extend_element.documentation,
fields=sorted_fields,
)


def rpc_element_with_sorted_options(rpc: RpcElement) -> RpcElement:
sorted_options = None if rpc.options is None else list(sorted(rpc.options, key=sort_by_name))
return RpcElement(
location=rpc.location,
name=rpc.name,
documentation=rpc.documentation,
request_type=rpc.request_type,
response_type=rpc.response_type,
request_streaming=rpc.request_streaming,
response_streaming=rpc.response_streaming,
options=sorted_options,
)


def service_element_with_sorted_options(service_element: ServiceElement) -> ServiceElement:
sorted_options = None if service_element.options is None else list(sorted(service_element.options, key=sort_by_name))
sorted_rpc = (
None if service_element.rpcs is None else [rpc_element_with_sorted_options(rpc) for rpc in service_element.rpcs]
)

return ServiceElement(
location=service_element.location,
name=service_element.name,
documentation=service_element.documentation,
rpcs=sorted_rpc,
options=sorted_options,
)


def normalize_options_ordered(proto_file_element: ProtoFileElement) -> ProtoFileElement:
eliax1996 marked this conversation as resolved.
Show resolved Hide resolved
sorted_types = [type_element_with_sorted_options(type_element) for type_element in proto_file_element.types]
sorted_options = (
None if proto_file_element.options is None else list(sorted(proto_file_element.options, key=sort_by_name))
)
sorted_services = (
None
if proto_file_element.services is None
else [service_element_with_sorted_options(service) for service in proto_file_element.services]
)
sorted_extend_declarations = (
None
if proto_file_element.extend_declarations is None
else [extends_element_with_sorted_options(extend) for extend in proto_file_element.extend_declarations]
)

return ProtoFileElement(
location=proto_file_element.location,
package_name=proto_file_element.package_name,
syntax=proto_file_element.syntax,
imports=proto_file_element.imports,
public_imports=proto_file_element.public_imports,
types=sorted_types,
services=sorted_services,
extend_declarations=sorted_extend_declarations,
options=sorted_options,
)


# if other normalizations are added we will switch to a more generic approach:
# def normalize_parsed_file(proto_file_element: ProtoFileElement,
# normalization: ProtobufNormalisationOptions) -> ProtoFileElement:
# if normalization == ProtobufNormalisationOptions.sort_options:
# return normalize_options_ordered(proto_file_element)
# else:
# assert_never(normalization)
13 changes: 12 additions & 1 deletion karapace/schema_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ProtobufUnresolvedDependencyException,
SchemaParseException as ProtobufSchemaParseException,
)
from karapace.protobuf.proto_normalizations import normalize_options_ordered
from karapace.protobuf.schema import ProtobufSchema
from karapace.schema_references import Reference
from karapace.schema_type import SchemaType
Expand Down Expand Up @@ -62,6 +63,7 @@ def parse_protobuf_schema_definition(
references: Sequence[Reference] | None = None,
dependencies: Mapping[str, Dependency] | None = None,
validate_references: bool = True,
normalize: bool = False,
) -> ProtobufSchema:
"""Parses and validates `schema_definition`.

Expand All @@ -74,6 +76,10 @@ def parse_protobuf_schema_definition(
result = protobuf_schema.verify_schema_dependencies()
if not result.result:
raise ProtobufUnresolvedDependencyException(f"{result.message}")

if protobuf_schema.proto_file_element is not None and normalize:
protobuf_schema.proto_file_element = normalize_options_ordered(protobuf_schema.proto_file_element)

return protobuf_schema


Expand Down Expand Up @@ -179,6 +185,7 @@ def parse(
validate_avro_names: bool,
references: Sequence[Reference] | None = None,
dependencies: Mapping[str, Dependency] | None = None,
normalize: bool = False,
) -> ParsedTypedSchema:
if schema_type not in [SchemaType.AVRO, SchemaType.JSONSCHEMA, SchemaType.PROTOBUF]:
raise InvalidSchema(f"Unknown parser {schema_type} for {schema_str}")
Expand All @@ -203,7 +210,7 @@ def parse(

elif schema_type is SchemaType.PROTOBUF:
try:
parsed_schema = parse_protobuf_schema_definition(schema_str, references, dependencies)
parsed_schema = parse_protobuf_schema_definition(schema_str, references, dependencies, normalize=normalize)
except (
TypeError,
SchemaError,
Expand Down Expand Up @@ -270,6 +277,7 @@ def parse(
schema_str: str,
references: Sequence[Reference] | None = None,
dependencies: Mapping[str, Dependency] | None = None,
normalize: bool = False,
) -> ParsedTypedSchema:
return parse(
schema_type=schema_type,
Expand All @@ -278,6 +286,7 @@ def parse(
validate_avro_names=False,
references=references,
dependencies=dependencies,
normalize=normalize,
)

def __str__(self) -> str:
Expand Down Expand Up @@ -352,6 +361,7 @@ def parse(
schema_str: str,
references: Sequence[Reference] | None = None,
dependencies: Mapping[str, Dependency] | None = None,
normalize: bool = False,
) -> ValidatedTypedSchema:
parsed_schema = parse(
schema_type=schema_type,
Expand All @@ -360,6 +370,7 @@ def parse(
validate_avro_names=True,
references=references,
dependencies=dependencies,
normalize=normalize,
)

return cast(ValidatedTypedSchema, parsed_schema)
Expand Down
10 changes: 9 additions & 1 deletion karapace/schema_registry_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,11 +1097,16 @@ async def subjects_schema_post(
schema_type = self._validate_schema_type(content_type=content_type, data=body)
references = self._validate_references(content_type, schema_type, body)
references, new_schema_dependencies = self.schema_registry.resolve_references(references)
normalize = request.query.get("normalize", "false").lower() == "true"
try:
# When checking if schema is already registered, allow unvalidated schema in as
# there might be stored schemas that are non-compliant from the past.
new_schema = ParsedTypedSchema.parse(
schema_type=schema_type, schema_str=schema_str, references=references, dependencies=new_schema_dependencies
schema_type=schema_type,
schema_str=schema_str,
references=references,
dependencies=new_schema_dependencies,
normalize=normalize,
)
except InvalidSchema:
self.log.warning("Invalid schema: %r", schema_str)
Expand Down Expand Up @@ -1133,6 +1138,7 @@ async def subjects_schema_post(
schema_version.schema.schema_str,
references=other_references,
dependencies=other_dependencies,
normalize=normalize,
)
except InvalidSchema as e:
failed_schema_id = schema_version.schema_id
Expand Down Expand Up @@ -1191,6 +1197,7 @@ async def subject_post(
self._validate_schema_request_body(content_type, body)
schema_type = self._validate_schema_type(content_type, body)
self._validate_schema_key(content_type, body)
normalize = request.query.get("normalize", "false").lower() == "true"
references = self._validate_references(content_type, schema_type, body)

try:
Expand All @@ -1200,6 +1207,7 @@ async def subject_post(
schema_str=body["schema"],
references=references,
dependencies=resolved_dependencies,
normalize=normalize,
)
except (InvalidReferences, InvalidSchema, InvalidSchemaType) as e:
self.log.warning("Invalid schema: %r", body["schema"], exc_info=True)
Expand Down
Loading
Loading