From bbf365b552bdf43d074354e98818ce3f04d35195 Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Fri, 14 Feb 2020 11:49:02 -0800 Subject: [PATCH] kafka: 2.4.0 support - add support for new message types added in 2.4 (#10000) Signed-off-by: Adam Kotwasinski --- bazel/repository_locations.bzl | 18 +- .../network_filters/kafka_broker_filter.rst | 2 +- source/extensions/filters/network/kafka/BUILD | 2 + .../filters/network/kafka/kafka_request.h | 57 +++- .../network/kafka/kafka_request_parser.cc | 30 +++ .../network/kafka/kafka_request_parser.h | 40 ++- .../filters/network/kafka/kafka_response.h | 54 +++- .../network/kafka/kafka_response_parser.cc | 33 ++- .../network/kafka/kafka_response_parser.h | 18 +- .../kafka/protocol/complex_type_template.j2 | 44 ++-- .../network/kafka/protocol/generator.py | 246 ++++++++++++++---- .../protocol/kafka_request_resolver_cc.j2 | 19 ++ .../protocol/kafka_response_resolver_cc.j2 | 22 ++ .../network/kafka/serialization/generator.py | 2 +- .../filters/network/kafka/tagged_fields.h | 4 +- test/extensions/filters/network/kafka/BUILD | 1 + .../integration_test/zookeeper_properties.j2 | 2 + .../kafka/kafka_request_parser_test.cc | 5 +- .../protocol/request_codec_request_test_cc.j2 | 15 +- .../response_codec_response_test_cc.j2 | 14 +- .../network/kafka/request_codec_unit_test.cc | 4 +- .../network/kafka/serialization_test.cc | 19 ++ 22 files changed, 526 insertions(+), 125 deletions(-) diff --git a/bazel/repository_locations.bzl b/bazel/repository_locations.bzl index 7fab9e774ee3..bff4e023bd96 100644 --- a/bazel/repository_locations.bzl +++ b/bazel/repository_locations.bzl @@ -303,18 +303,18 @@ REPOSITORY_LOCATIONS = dict( urls = ["https://github.com/protocolbuffers/upb/archive/8a3ae1ef3e3e3f26b45dec735c5776737fc7247f.tar.gz"], ), kafka_source = dict( - sha256 = "feaa32e5c42acf42bd587f8f0b1ccce679db227620da97eed013f4c44a44f64d", - strip_prefix = "kafka-2.3.1/clients/src/main/resources/common/message", - urls = ["https://github.com/apache/kafka/archive/2.3.1.zip"], + sha256 = "e7b748a62e432b5770db6dbb3b034c68c0ea212812cb51603ee7f3a8a35f06be", + strip_prefix = "kafka-2.4.0/clients/src/main/resources/common/message", + urls = ["https://github.com/apache/kafka/archive/2.4.0.zip"], ), kafka_server_binary = dict( - sha256 = "5a3ddd4148371284693370d56f6f66c7a86d86dd96c533447d2a94d176768d2e", - strip_prefix = "kafka_2.12-2.3.1", - urls = ["http://us.mirrors.quenda.co/apache/kafka/2.3.1/kafka_2.12-2.3.1.tgz"], + sha256 = "b9582bab0c3e8d131953b1afa72d6885ca1caae0061c2623071e7f396f2ccfee", + strip_prefix = "kafka_2.12-2.4.0", + urls = ["http://us.mirrors.quenda.co/apache/kafka/2.4.0/kafka_2.12-2.4.0.tgz"], ), kafka_python_client = dict( - sha256 = "81f24a5d297531495e0ccb931fbd6c4d1ec96583cf5a730579a3726e63f59c47", - strip_prefix = "kafka-python-1.4.7", - urls = ["https://github.com/dpkp/kafka-python/archive/1.4.7.tar.gz"], + sha256 = "454bf3aafef9348017192417b7f0828a347ec2eaf3efba59336f3a3b68f10094", + strip_prefix = "kafka-python-2.0.0", + urls = ["https://github.com/dpkp/kafka-python/archive/2.0.0.tar.gz"], ), ) diff --git a/docs/root/configuration/listeners/network_filters/kafka_broker_filter.rst b/docs/root/configuration/listeners/network_filters/kafka_broker_filter.rst index 2443136748d1..8f7ef37427fe 100644 --- a/docs/root/configuration/listeners/network_filters/kafka_broker_filter.rst +++ b/docs/root/configuration/listeners/network_filters/kafka_broker_filter.rst @@ -5,7 +5,7 @@ Kafka Broker filter The Apache Kafka broker filter decodes the client protocol for 
 `Apache Kafka <https://kafka.apache.org/>`_, both the requests and responses in the payload.
-The message versions in `Kafka 2.3.1 <http://kafka.apache.org/23/protocol.html#protocol_api_keys>`_
+The message versions in `Kafka 2.4.0 <http://kafka.apache.org/24/protocol.html#protocol_api_keys>`_
 are supported.
 The filter attempts not to influence the communication between client and brokers, so the messages
 that could not be decoded (due to Kafka client or broker running a newer version than supported by
diff --git a/source/extensions/filters/network/kafka/BUILD b/source/extensions/filters/network/kafka/BUILD
index 8564c02edae8..f4588076cfae 100644
--- a/source/extensions/filters/network/kafka/BUILD
+++ b/source/extensions/filters/network/kafka/BUILD
@@ -81,6 +81,7 @@ envoy_cc_library(
     deps = [
         ":kafka_request_lib",
         ":parser_lib",
+        ":tagged_fields_lib",
         "//source/common/common:assert_lib",
         "//source/common/common:minimal_logger_lib",
     ],
@@ -143,6 +144,7 @@ envoy_cc_library(
     deps = [
         ":kafka_response_lib",
         ":parser_lib",
+        ":tagged_fields_lib",
         "//source/common/common:assert_lib",
         "//source/common/common:minimal_logger_lib",
     ],
diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h
index 0c21543cc214..ff9f57282af3 100644
--- a/source/extensions/filters/network/kafka/kafka_request.h
+++ b/source/extensions/filters/network/kafka/kafka_request.h
@@ -4,12 +4,22 @@
 #include "extensions/filters/network/kafka/external/serialization_composite.h"
 #include "extensions/filters/network/kafka/serialization.h"
+#include "extensions/filters/network/kafka/tagged_fields.h"
 
 namespace Envoy {
 namespace Extensions {
 namespace NetworkFilters {
 namespace Kafka {
 
+/**
+ * Decides if request with given api key & version should have tagged fields in header.
+ * This method gets implemented in generated code through 'kafka_request_resolver_cc.j2'.
+ * @param api_key Kafka request key.
+ * @param api_version Kafka request's version.
+ * @return Whether tagged fields should be used for this request.
+ */
+bool requestUsesTaggedFieldsInHeader(const uint16_t api_key, const uint16_t api_version);
+
 /**
  * Represents fields that are present in every Kafka request message.
* @see http://kafka.apache.org/protocol.html#protocol_messages @@ -19,10 +29,45 @@ struct RequestHeader { int16_t api_version_; int32_t correlation_id_; NullableString client_id_; + TaggedFields tagged_fields_; + + RequestHeader(const int16_t api_key, const int16_t api_version, const int32_t correlation_id, + const NullableString& client_id) + : RequestHeader{api_key, api_version, correlation_id, client_id, TaggedFields{}} {}; + + RequestHeader(const int16_t api_key, const int16_t api_version, const int32_t correlation_id, + const NullableString& client_id, const TaggedFields& tagged_fields) + : api_key_{api_key}, api_version_{api_version}, correlation_id_{correlation_id}, + client_id_{client_id}, tagged_fields_{tagged_fields} {}; + + uint32_t computeSize(const EncodingContext& context) const { + uint32_t result{0}; + result += context.computeSize(api_key_); + result += context.computeSize(api_version_); + result += context.computeSize(correlation_id_); + result += context.computeSize(client_id_); + if (requestUsesTaggedFieldsInHeader(api_key_, api_version_)) { + result += context.computeCompactSize(tagged_fields_); + } + return result; + } + + uint32_t encode(Buffer::Instance& dst, EncodingContext& context) const { + uint32_t written{0}; + written += context.encode(api_key_, dst); + written += context.encode(api_version_, dst); + written += context.encode(correlation_id_, dst); + written += context.encode(client_id_, dst); + if (requestUsesTaggedFieldsInHeader(api_key_, api_version_)) { + written += context.encodeCompact(tagged_fields_, dst); + } + return written; + } bool operator==(const RequestHeader& rhs) const { return api_key_ == rhs.api_key_ && api_version_ == rhs.api_version_ && - correlation_id_ == rhs.correlation_id_ && client_id_ == rhs.client_id_; + correlation_id_ == rhs.correlation_id_ && client_id_ == rhs.client_id_ && + tagged_fields_ == rhs.tagged_fields_; }; }; @@ -95,10 +140,7 @@ template class Request : public AbstractRequest { const EncodingContext context{request_header_.api_version_}; uint32_t result{0}; // Compute size of header. - result += context.computeSize(request_header_.api_key_); - result += context.computeSize(request_header_.api_version_); - result += context.computeSize(request_header_.correlation_id_); - result += context.computeSize(request_header_.client_id_); + result += context.computeSize(request_header_); // Compute size of request data. result += context.computeSize(data_); return result; @@ -111,10 +153,7 @@ template class Request : public AbstractRequest { EncodingContext context{request_header_.api_version_}; uint32_t written{0}; // Encode request header. - written += context.encode(request_header_.api_key_, dst); - written += context.encode(request_header_.api_version_, dst); - written += context.encode(request_header_.correlation_id_, dst); - written += context.encode(request_header_.client_id_, dst); + written += context.encode(request_header_, dst); // Encode request-specific data. 
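To make the header change concrete: a flexible-version ("header v2") request differs from the old layout only by the trailing tagged-fields block. The sketch below is a standalone illustration, not the filter's code (which goes through EncodingContext and Buffer::Instance); the api key/version pair 19/5 is assumed here to be one of the flexible-version request types in Kafka 2.4, and all helper names are invented for the example.

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

void writeInt16(int16_t value, std::vector<uint8_t>& out) {
  out.push_back(static_cast<uint8_t>(value >> 8));
  out.push_back(static_cast<uint8_t>(value));
}

void writeInt32(int32_t value, std::vector<uint8_t>& out) {
  for (int shift = 24; shift >= 0; shift -= 8) {
    out.push_back(static_cast<uint8_t>(value >> shift));
  }
}

// Nullable string: int16 length (-1 for null) followed by the raw bytes.
void writeNullableString(const std::optional<std::string>& value, std::vector<uint8_t>& out) {
  if (!value) {
    writeInt16(-1, out);
    return;
  }
  writeInt16(static_cast<int16_t>(value->size()), out);
  out.insert(out.end(), value->begin(), value->end());
}

int main() {
  std::vector<uint8_t> header;
  writeInt16(19, header);                         // Api key (assumed flexible request type).
  writeInt16(5, header);                          // Api version (assumed flexible in 2.4).
  writeInt32(42, header);                         // Correlation id.
  writeNullableString(std::string("id"), header); // Client id.
  // Flexible versions append the tagged-fields block; with no tagged fields set
  // it is the unsigned varint 0, i.e. a single 0x00 byte.
  header.push_back(0x00);
  std::cout << "header size = " << header.size() << " bytes" << std::endl; // 2+2+4+4+1 = 13.
  return 0;
}

Non-flexible requests stop right after the client id, which is exactly the distinction the requestUsesTaggedFieldsInHeader() checks above encode.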
written += context.encode(data_, dst); return written; diff --git a/source/extensions/filters/network/kafka/kafka_request_parser.cc b/source/extensions/filters/network/kafka/kafka_request_parser.cc index 2c6be950f08e..811c051182c3 100644 --- a/source/extensions/filters/network/kafka/kafka_request_parser.cc +++ b/source/extensions/filters/network/kafka/kafka_request_parser.cc @@ -20,6 +20,36 @@ RequestParseResponse RequestStartParser::parse(absl::string_view& data) { } } +uint32_t RequestHeaderDeserializer::feed(absl::string_view& data) { + uint32_t consumed = 0; + + consumed += common_part_deserializer_.feed(data); + if (common_part_deserializer_.ready()) { + const auto request_header = common_part_deserializer_.get(); + if (requestUsesTaggedFieldsInHeader(request_header.api_key_, request_header.api_version_)) { + tagged_fields_present_ = true; + consumed += tagged_fields_deserializer_.feed(data); + } + } + + return consumed; +} + +bool RequestHeaderDeserializer::ready() const { + // Header is only fully parsed after we have processed everything, including tagged fields (if + // they are present). + return common_part_deserializer_.ready() && + (tagged_fields_present_ ? tagged_fields_deserializer_.ready() : true); +} + +RequestHeader RequestHeaderDeserializer::get() const { + auto result = common_part_deserializer_.get(); + if (tagged_fields_present_) { + result.tagged_fields_ = tagged_fields_deserializer_.get(); + } + return result; +} + RequestParseResponse RequestHeaderParser::parse(absl::string_view& data) { context_->remaining_request_size_ -= deserializer_->feed(data); // One of the two needs must have happened when feeding finishes: diff --git a/source/extensions/filters/network/kafka/kafka_request_parser.h b/source/extensions/filters/network/kafka/kafka_request_parser.h index b588b9aff454..15d304b8d187 100644 --- a/source/extensions/filters/network/kafka/kafka_request_parser.h +++ b/source/extensions/filters/network/kafka/kafka_request_parser.h @@ -8,6 +8,7 @@ #include "extensions/filters/network/kafka/kafka_request.h" #include "extensions/filters/network/kafka/parser.h" +#include "extensions/filters/network/kafka/tagged_fields.h" namespace Envoy { namespace Extensions { @@ -22,8 +23,16 @@ using RequestParserSharedPtr = std::shared_ptr; * Context that is shared between parsers that are handling the same single message. */ struct RequestContext { + + /** + * Bytes left to consume. + */ uint32_t remaining_request_size_{0}; - RequestHeader request_header_{}; + + /** + * Request header that gets filled in during the parse. + */ + RequestHeader request_header_{-1, -1, -1, absl::nullopt}; /** * Bytes left to consume. @@ -91,10 +100,31 @@ class RequestStartParser : public RequestParser { * Can throw, as one of the fields (client-id) can throw (nullable string with invalid length). * @see http://kafka.apache.org/protocol.html#protocol_messages */ -class RequestHeaderDeserializer - : public CompositeDeserializerWith4Delegates {}; +class RequestHeaderDeserializer : public Deserializer, + private Logger::Loggable { + + // Request header, no matter what, has at least 4 fields. They are extracted here. + using CommonPartDeserializer = + CompositeDeserializerWith4Delegates; + +public: + RequestHeaderDeserializer() = default; + + uint32_t feed(absl::string_view& data) override; + bool ready() const override; + RequestHeader get() const override; + +private: + // Deserializer for the first 4 fields, that are present in every request header. 
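The deserializers composed above all follow the same feed/ready/get contract: feed() consumes only the bytes it still needs from the view (possibly across several calls, since the codec receives data in arbitrary chunks) and reports how many bytes it took. Below is a minimal standalone analogue of that contract, using std::string_view instead of absl::string_view; it is not the filter's real Int32Deserializer.

#include <cassert>
#include <cstdint>
#include <string_view>

// Hand-written stand-in for a big-endian int32 deserializer.
class SimpleInt32Deserializer {
public:
  // Consumes at most the bytes still needed, advances the view, returns bytes taken.
  uint32_t feed(std::string_view& data) {
    uint32_t consumed = 0;
    while (needed_ > 0 && !data.empty()) {
      value_ = (value_ << 8) | static_cast<uint8_t>(data.front());
      data.remove_prefix(1);
      --needed_;
      ++consumed;
    }
    return consumed;
  }
  bool ready() const { return needed_ == 0; }
  int32_t get() const { return static_cast<int32_t>(value_); }

private:
  uint32_t needed_{4};
  uint32_t value_{0};
};

int main() {
  SimpleInt32Deserializer testee;
  // Data arrives in two chunks, as it could from the network.
  std::string_view first("\x00\x00", 2);
  std::string_view second("\x01\x02\xff", 3);
  testee.feed(first);
  assert(!testee.ready());
  testee.feed(second);
  assert(testee.ready() && testee.get() == 0x0102);
  assert(second.size() == 1); // The trailing byte is left for the next stage.
  return 0;
}

RequestHeaderDeserializer::feed composes blocks like this: it drains the common four-field part first and only then, if the header turns out to be flexible, routes the remaining bytes into the tagged-fields deserializer.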
+ CommonPartDeserializer common_part_deserializer_; + + // Tagged fields are used only in request header v2. + // This flag will be set depending on common part's result (api key & version), and will decide + // whether we want to feed data to tagged fields deserializer. + bool tagged_fields_present_; + TaggedFieldsDeserializer tagged_fields_deserializer_; +}; using RequestHeaderDeserializerPtr = std::unique_ptr; diff --git a/source/extensions/filters/network/kafka/kafka_response.h b/source/extensions/filters/network/kafka/kafka_response.h index 8e6a8e5b35c1..e53ae70de0da 100644 --- a/source/extensions/filters/network/kafka/kafka_response.h +++ b/source/extensions/filters/network/kafka/kafka_response.h @@ -2,28 +2,67 @@ #include "extensions/filters/network/kafka/external/serialization_composite.h" #include "extensions/filters/network/kafka/serialization.h" +#include "extensions/filters/network/kafka/tagged_fields.h" namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace Kafka { +/** + * Decides if response with given api key & version should have tagged fields in header. + * Bear in mind, that ApiVersions responses DO NOT contain tagged fields in header (despite having + * flexible versions) as per + * https://github.com/apache/kafka/blob/2.4.0/clients/src/main/resources/common/message/ApiVersionsResponse.json#L24 + * This method gets implemented in generated code through 'kafka_response_resolver_cc.j2'. + * + * @param api_key Kafka request key. + * @param api_version Kafka request's version. + * @return Whether tagged fields should be used for this request. + */ +bool responseUsesTaggedFieldsInHeader(const uint16_t api_key, const uint16_t api_version); + /** * Represents Kafka response metadata: expected api key, version and correlation id. * @see http://kafka.apache.org/protocol.html#protocol_messages */ struct ResponseMetadata { ResponseMetadata(const int16_t api_key, const int16_t api_version, const int32_t correlation_id) - : api_key_{api_key}, api_version_{api_version}, correlation_id_{correlation_id} {}; + : ResponseMetadata{api_key, api_version, correlation_id, TaggedFields{}} {}; + + ResponseMetadata(const int16_t api_key, const int16_t api_version, const int32_t correlation_id, + const TaggedFields& tagged_fields) + : api_key_{api_key}, api_version_{api_version}, correlation_id_{correlation_id}, + tagged_fields_{tagged_fields} {}; + + uint32_t computeSize(const EncodingContext& context) const { + uint32_t result{0}; + result += context.computeSize(correlation_id_); + if (responseUsesTaggedFieldsInHeader(api_key_, api_version_)) { + result += context.computeCompactSize(tagged_fields_); + } + return result; + } + + uint32_t encode(Buffer::Instance& dst, EncodingContext& context) const { + uint32_t written{0}; + // Encode correlation id (api key / version are not present in responses). 
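One easy-to-miss detail from the comment block above: even when an ApiVersions exchange uses a flexible version, its response header keeps the old format. The reason, as far as I can tell (the patch itself only cites the upstream JSON), is that a client may need to parse an UNSUPPORTED_VERSION ApiVersions response before it has learned which versions the broker supports. The generated responseUsesTaggedFieldsInHeader() therefore carves out api key 18; the predicate below is only an illustrative sketch of that rule, not the generated code.

#include <cstdint>

constexpr int16_t API_VERSIONS_KEY = 18;

// Sketch: tagged fields appear in a response header only for flexible versions,
// and never for ApiVersions responses.
bool sketchResponseUsesTaggedFieldsInHeader(int16_t api_key, bool version_is_flexible) {
  return version_is_flexible && api_key != API_VERSIONS_KEY;
}

int main() {
  // ApiVersions v3 is flexible, yet its response header carries no tagged fields.
  return sketchResponseUsesTaggedFieldsInHeader(API_VERSIONS_KEY, true) ? 1 : 0;
}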
+ written += context.encode(correlation_id_, dst); + if (responseUsesTaggedFieldsInHeader(api_key_, api_version_)) { + written += context.encodeCompact(tagged_fields_, dst); + } + return written; + } bool operator==(const ResponseMetadata& rhs) const { return api_key_ == rhs.api_key_ && api_version_ == rhs.api_version_ && - correlation_id_ == rhs.correlation_id_; + correlation_id_ == rhs.correlation_id_ && tagged_fields_ == rhs.tagged_fields_; }; const int16_t api_key_; const int16_t api_version_; const int32_t correlation_id_; + const TaggedFields tagged_fields_; }; using ResponseMetadataSharedPtr = std::shared_ptr; @@ -77,7 +116,12 @@ template class Response : public AbstractResponse { */ uint32_t computeSize() const override { const EncodingContext context{metadata_.api_version_}; - return context.computeSize(metadata_.correlation_id_) + context.computeSize(data_); + uint32_t result{0}; + // Compute size of header. + result += context.computeSize(metadata_); + // Compute size of response data. + result += context.computeSize(data_); + return result; } /** @@ -86,8 +130,8 @@ template class Response : public AbstractResponse { uint32_t encode(Buffer::Instance& dst) const override { EncodingContext context{metadata_.api_version_}; uint32_t written{0}; - // Encode correlation id (api key / version are not present in responses). - written += context.encode(metadata_.correlation_id_, dst); + // Encode response header. + written += context.encode(metadata_, dst); // Encode response-specific data. written += context.encode(data_, dst); return written; diff --git a/source/extensions/filters/network/kafka/kafka_response_parser.cc b/source/extensions/filters/network/kafka/kafka_response_parser.cc index fb1781446495..3adbb1b28777 100644 --- a/source/extensions/filters/network/kafka/kafka_response_parser.cc +++ b/source/extensions/filters/network/kafka/kafka_response_parser.cc @@ -22,16 +22,33 @@ ResponseParseResponse ResponseHeaderParser::parse(absl::string_view& data) { return ResponseParseResponse::stillWaiting(); } - context_->remaining_response_size_ = length_deserializer_.get(); - context_->remaining_response_size_ -= sizeof(context_->correlation_id_); - context_->correlation_id_ = correlation_id_deserializer_.get(); + if (!context_->api_info_set_) { + // We have consumed first two response header fields: payload length and correlation id. + context_->remaining_response_size_ = length_deserializer_.get(); + context_->remaining_response_size_ -= sizeof(context_->correlation_id_); + context_->correlation_id_ = correlation_id_deserializer_.get(); - const ExpectedResponseSpec spec = getResponseSpec(context_->correlation_id_); - context_->api_key_ = spec.first; - context_->api_version_ = spec.second; - // At this stage, we have setup the context - we know the response's api key & version, so we can - // safely create the payload parser. + // We have correlation id now, so we can see what is the expected response api key & version. + const ExpectedResponseSpec spec = getResponseSpec(context_->correlation_id_); + context_->api_key_ = spec.first; + context_->api_version_ = spec.second; + // Mark that version data has been set, so we do not attempt to re-initialize again. + context_->api_info_set_ = true; + } + + // Depending on response's api key & version, we might need to parse tagged fields element. 
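getResponseSpec() is what makes this work: Kafka responses repeat neither the api key nor the version, so both have to be remembered per correlation id when the matching request is observed. The sketch below shows just that bookkeeping in isolation; ExpectedResponses and registerRequest are illustrative names, not the filter's actual classes, and the real bookkeeping is fed from the request side of the filter.

#include <cstdint>
#include <map>
#include <stdexcept>
#include <utility>

// First element is the api key, second is the api version.
using ExpectedResponseSpec = std::pair<int16_t, int16_t>;

class ExpectedResponses {
public:
  // Invoked when a request with the given correlation id is seen going out.
  void registerRequest(int32_t correlation_id, int16_t api_key, int16_t api_version) {
    responses_[correlation_id] = {api_key, api_version};
  }

  // Invoked when a response header reveals its correlation id.
  ExpectedResponseSpec getResponseSpec(int32_t correlation_id) {
    const auto it = responses_.find(correlation_id);
    if (it == responses_.end()) {
      throw std::runtime_error("response with unknown correlation id");
    }
    const ExpectedResponseSpec result = it->second;
    responses_.erase(it); // Each correlation id is answered exactly once.
    return result;
  }

private:
  std::map<int32_t, ExpectedResponseSpec> responses_;
};

int main() {
  ExpectedResponses tracker;
  tracker.registerRequest(42, /*api_key=*/3, /*api_version=*/9); // E.g. a Metadata v9 request.
  const ExpectedResponseSpec spec = tracker.getResponseSpec(42);
  return (spec.first == 3 && spec.second == 9) ? 0 : 1;
}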
+ if (responseUsesTaggedFieldsInHeader(context_->api_key_, context_->api_version_)) { + context_->remaining_response_size_ -= tagged_fields_deserializer_.feed(data); + if (tagged_fields_deserializer_.ready()) { + context_->tagged_fields_ = tagged_fields_deserializer_.get(); + } else { + return ResponseParseResponse::stillWaiting(); + } + } + + // At this stage, we have fully setup the context - we know the response's api key & version, + // so we can safely create the payload parser. auto next_parser = parser_resolver_.createParser(context_); return ResponseParseResponse::nextParser(next_parser); } diff --git a/source/extensions/filters/network/kafka/kafka_response_parser.h b/source/extensions/filters/network/kafka/kafka_response_parser.h index 335da7f5d9ea..428511bac592 100644 --- a/source/extensions/filters/network/kafka/kafka_response_parser.h +++ b/source/extensions/filters/network/kafka/kafka_response_parser.h @@ -5,6 +5,7 @@ #include "extensions/filters/network/kafka/kafka_response.h" #include "extensions/filters/network/kafka/parser.h" +#include "extensions/filters/network/kafka/tagged_fields.h" namespace Envoy { namespace Extensions { @@ -20,6 +21,11 @@ using ResponseParserSharedPtr = std::shared_ptr; */ struct ResponseContext { + /** + * Whether the 'api_key_' & 'api_version_' fields have been initialized. + */ + bool api_info_set_ = false; + /** * Api key of response that's being parsed. */ @@ -40,6 +46,11 @@ struct ResponseContext { */ int32_t correlation_id_; + /** + * Response's tagged fields. + */ + TaggedFields tagged_fields_; + /** * Bytes left to consume. */ @@ -48,7 +59,9 @@ struct ResponseContext { /** * Returns data needed for construction of parse failure message. */ - const ResponseMetadata asFailureData() const { return {api_key_, api_version_, correlation_id_}; } + const ResponseMetadata asFailureData() const { + return {api_key_, api_version_, correlation_id_, tagged_fields_}; + } }; using ResponseContextSharedPtr = std::shared_ptr; @@ -119,6 +132,7 @@ class ResponseHeaderParser : public ResponseParser { Int32Deserializer length_deserializer_; Int32Deserializer correlation_id_deserializer_; + TaggedFieldsDeserializer tagged_fields_deserializer_; }; /** @@ -166,7 +180,7 @@ class ResponseDataParser : public ResponseParser { if (0 == context_->remaining_response_size_) { // After a successful parse, there should be nothing left - we have consumed all the bytes. const ResponseMetadata metadata = {context_->api_key_, context_->api_version_, - context_->correlation_id_}; + context_->correlation_id_, context_->tagged_fields_}; const AbstractResponseSharedPtr response = std::make_shared>(metadata, deserializer_.get()); return ResponseParseResponse::parsedMessage(response); diff --git a/source/extensions/filters/network/kafka/protocol/complex_type_template.j2 b/source/extensions/filters/network/kafka/protocol/complex_type_template.j2 index a60533777970..744fd8749889 100644 --- a/source/extensions/filters/network/kafka/protocol/complex_type_template.j2 +++ b/source/extensions/filters/network/kafka/protocol/complex_type_template.j2 @@ -25,38 +25,40 @@ struct {{ complex_type.name }} { {{ constructor['full_declaration'] }}{% endfor %} {# For every field that's used in version, just compute its size using an encoder. 
#} - {% if complex_type.fields|length > 0 %} uint32_t computeSize(const EncodingContext& encoder) const { const int16_t api_version = encoder.apiVersion(); - uint32_t written{0};{% for field in complex_type.fields %} - if (api_version >= {{ field.version_usage[0] }} - && api_version < {{ field.version_usage[-1] + 1 }}) { - written += encoder.computeSize({{ field.name }}_); - }{% endfor %} + uint32_t written{0}; + + {% for spec in complex_type.compute_serialization_specs() %} + if (api_version >= {{ spec.versions[0] }} && api_version < {{ spec.versions[-1] + 1 }}) { + written += encoder.{{ spec.compute_size_method_name }}({{ spec.field.name }}_); + } + {% endfor %} + return written; } - {% else %} - uint32_t computeSize(const EncodingContext&) const { - return 0; + + uint32_t computeCompactSize(const EncodingContext& encoder) const { + return computeSize(encoder); } - {% endif %} {# For every field that's used in version, just serialize it. #} - {% if complex_type.fields|length > 0 %} uint32_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { const int16_t api_version = encoder.apiVersion(); - uint32_t written{0};{% for field in complex_type.fields %} - if (api_version >= {{ field.version_usage[0] }} - && api_version < {{ field.version_usage[-1] + 1 }}) { - written += encoder.encode({{ field.name }}_, dst); - }{% endfor %} + uint32_t written{0}; + + {% for spec in complex_type.compute_serialization_specs() %} + if (api_version >= {{ spec.versions[0] }} && api_version < {{ spec.versions[-1] + 1 }}) { + written += encoder.{{ spec.encode_method_name }}({{ spec.field.name }}_, dst); + } + {% endfor %} + return written; } - {% else %} - uint32_t encode(Buffer::Instance&, EncodingContext&) const { - return 0; + + uint32_t encodeCompact(Buffer::Instance& dst, EncodingContext& encoder) const { + return encode(dst, encoder); } - {% endif %} {% if complex_type.fields|length > 0 %} bool operator==(const {{ complex_type.name }}& rhs) const { @@ -77,7 +79,7 @@ class {{ complex_type.name }}V{{ field_list.version }}Deserializer: public CompositeDeserializerWith{{ field_list.field_count() }}Delegates< {{ complex_type.name }} {% for field in field_list.used_fields() %}, - {{ field.deserializer_name_in_version(field_list.version) }} + {{ field.deserializer_name_in_version(field_list.version, field_list.uses_compact_fields) }} {% endfor %}>{}; {% endfor %} diff --git a/source/extensions/filters/network/kafka/protocol/generator.py b/source/extensions/filters/network/kafka/protocol/generator.py index c683821b0ff0..87cd7b1fe8ef 100755 --- a/source/extensions/filters/network/kafka/protocol/generator.py +++ b/source/extensions/filters/network/kafka/protocol/generator.py @@ -26,7 +26,8 @@ def generate_main_code(type, main_header_file, resolver_cc_file, metrics_header_ for message in messages: # For each child structure that is used by request/response, render its matching C++ code. - for dependency in message.declaration_chain: + dependencies = message.compute_declaration_chain() + for dependency in dependencies: main_header_contents += complex_type_template.render(complex_type=dependency) # Each top-level structure (e.g. FetchRequest/FetchResponse) needs corresponding parsers. main_header_contents += parsers_template.render(complex_type=message) @@ -101,6 +102,8 @@ def __init__(self): self.known_types = set() # Name of parent message type that's being processed right now. self.currently_processed_message_type = None + # Common structs declared in this message type. 
+ self.common_structs = {} def parse_messages(self, input_files): """ @@ -114,12 +117,17 @@ def parse_messages(self, input_files): input_files.sort() # For each specification file, remove comments, and parse the remains. for input_file in input_files: - with open(input_file, 'r') as fd: - raw_contents = fd.read() - without_comments = re.sub(r'//.*\n', '', raw_contents) - message_spec = json.loads(without_comments) - message = self.parse_top_level_element(message_spec) - messages.append(message) + try: + with open(input_file, 'r') as fd: + raw_contents = fd.read() + without_comments = re.sub(r'\s*//.*\n', '\n', raw_contents) + without_empty_newlines = re.sub(r'^\s*$', '', without_comments, flags=re.MULTILINE) + message_spec = json.loads(without_empty_newlines) + message = self.parse_top_level_element(message_spec) + messages.append(message) + except Exception as e: + print('could not process %s' % input_file) + raise # Sort messages by api_key. messages.sort(key=lambda x: x.get_extra('api_key')) @@ -132,33 +140,75 @@ def parse_top_level_element(self, spec): named fields, compared to sub-structures in a message. """ self.currently_processed_message_type = spec['name'] + + # Figure out all versions of this message type. versions = Statics.parse_version_string(spec['validVersions'], 2 << 16 - 1) - complex_type = self.parse_complex_type(self.currently_processed_message_type, spec, versions) - # Request / response types need to carry api key version. - return complex_type.with_extra('api_key', spec['apiKey']) + + # Figure out the flexible versions. + flexible_versions_string = spec.get('flexibleVersions', 'none') + if 'none' != flexible_versions_string: + flexible_versions = Statics.parse_version_string(flexible_versions_string, versions[-1]) + else: + flexible_versions = [] + + # Sanity check - all flexible versions need to be versioned. + if [x for x in flexible_versions if x not in versions]: + raise ValueError('invalid flexible versions') + + try: + # In 2.4 some types are declared at top level, and only referenced inside. + # So let's parse them and store them in state. + common_structs = spec.get('commonStructs') + if common_structs is not None: + for common_struct in common_structs: + common_struct_name = common_struct['name'] + common_struct_versions = Statics.parse_version_string(common_struct['versions'], + versions[-1]) + parsed_complex = self.parse_complex_type(common_struct_name, common_struct, + common_struct_versions) + self.common_structs[parsed_complex.name] = parsed_complex + + # Parse the type itself. + complex_type = self.parse_complex_type(self.currently_processed_message_type, spec, versions) + complex_type.register_flexible_versions(flexible_versions) + + # Request / response types need to carry api key version. + result = complex_type.with_extra('api_key', spec['apiKey']) + return result + + finally: + self.common_structs = {} + self.currently_processed_message_type = None def parse_complex_type(self, type_name, field_spec, versions): """ Parse given complex type, returning a structure that holds its name, field specification and allowed versions. """ - fields = [] - for child_field in field_spec['fields']: - child = self.parse_field(child_field, versions[-1]) - fields.append(child) - - # Some of the types repeat multiple times (e.g. AlterableConfig). - # In such a case, every second or later occurrence of the same name is going to be prefixed - # with parent type, e.g. 
we have AlterableConfig (for AlterConfigsRequest) and then - # IncrementalAlterConfigsRequestAlterableConfig (for IncrementalAlterConfigsRequest). - # This keeps names unique, while keeping non-duplicate ones short. - if type_name not in self.known_types: - self.known_types.add(type_name) - else: - type_name = self.currently_processed_message_type + type_name - self.known_types.add(type_name) + fields_el = field_spec.get('fields') + + if fields_el is not None: + fields = [] + for child_field in field_spec['fields']: + child = self.parse_field(child_field, versions[-1]) + if child is not None: + fields.append(child) + + # Some of the types repeat multiple times (e.g. AlterableConfig). + # In such a case, every second or later occurrence of the same name is going to be prefixed + # with parent type, e.g. we have AlterableConfig (for AlterConfigsRequest) and then + # IncrementalAlterConfigsRequestAlterableConfig (for IncrementalAlterConfigsRequest). + # This keeps names unique, while keeping non-duplicate ones short. + if type_name not in self.known_types: + self.known_types.add(type_name) + else: + type_name = self.currently_processed_message_type + type_name + self.known_types.add(type_name) - return Complex(type_name, fields, versions) + return Complex(type_name, fields, versions) + + else: + return self.common_structs[type_name] def parse_field(self, field_spec, highest_possible_version): """ @@ -166,6 +216,9 @@ def parse_field(self, field_spec, highest_possible_version): actually used (nullable or not). Obviously, field cannot be used in version higher than its type's usage. """ + if field_spec.get('tag') is not None: + return None + version_usage = Statics.parse_version_string(field_spec['versions'], highest_possible_version) version_usage_as_nullable = Statics.parse_version_string( field_spec['nullableVersions'], @@ -183,7 +236,7 @@ def parse_type(self, type_name, field_spec, highest_possible_version): underlying_type = self.parse_type(type_name[2:], field_spec, highest_possible_version) return Array(underlying_type) else: - if (type_name in Primitive.PRIMITIVE_TYPE_NAMES): + if (type_name in Primitive.USABLE_PRIMITIVE_TYPE_NAMES): return Primitive(type_name, field_spec.get('default')) else: versions = Statics.parse_version_string(field_spec['versions'], highest_possible_version) @@ -211,11 +264,12 @@ def parse_version_string(raw_versions, highest_possible_version): class FieldList: """ List of fields used by given entity (request or child structure) in given message version - (as fields get added or removed across versions). + (as fields get added or removed across versions and/or they change compaction level). 
""" - def __init__(self, version, fields): + def __init__(self, version, uses_compact_fields, fields): self.version = version + self.uses_compact_fields = uses_compact_fields self.fields = fields def used_fields(self): @@ -327,11 +381,11 @@ def example_value_for_test(self, version): else: return str(self.type.example_value_for_test(version)) - def deserializer_name_in_version(self, version): + def deserializer_name_in_version(self, version, compact): if self.is_nullable_in_version(version): - return 'Nullable%s' % self.type.deserializer_name_in_version(version) + return 'Nullable%s' % self.type.deserializer_name_in_version(version, compact) else: - return self.type.deserializer_name_in_version(version) + return self.type.deserializer_name_in_version(version, compact) def is_printable(self): return self.type.is_printable() @@ -339,7 +393,13 @@ def is_printable(self): class TypeSpecification: - def deserializer_name_in_version(self, version): + def compute_declaration_chain(self): + """ + Computes types that need to be declared before this type can be declared, in C++ sense. + """ + raise NotImplementedError() + + def deserializer_name_in_version(self, version, compact): """ Renders the deserializer name of given type, in message with given version. """ @@ -351,6 +411,12 @@ def default_value(self): """ raise NotImplementedError() + def has_flexible_handling(self): + """ + Whether the given type has special encoding when carrying message is using flexible encoding. + """ + raise NotImplementedError() + def example_value_for_test(self, version): raise NotImplementedError() @@ -367,19 +433,26 @@ class Array(TypeSpecification): def __init__(self, underlying): self.underlying = underlying - self.declaration_chain = self.underlying.declaration_chain @property def name(self): return 'std::vector<%s>' % self.underlying.name - def deserializer_name_in_version(self, version): - return 'ArrayDeserializer<%s, %s>' % (self.underlying.name, - self.underlying.deserializer_name_in_version(version)) + def compute_declaration_chain(self): + # To use an array of type T, we just need to be capable of using type T. + return self.underlying.compute_declaration_chain() + + def deserializer_name_in_version(self, version, compact): + return '%sArrayDeserializer<%s, %s>' % ("Compact" if compact else "", self.underlying.name, + self.underlying.deserializer_name_in_version( + version, compact)) def default_value(self): return 'std::vector<%s>{}' % (self.underlying.name) + def has_flexible_handling(self): + return True + def example_value_for_test(self, version): return 'std::vector<%s>{ %s }' % (self.underlying.name, self.underlying.example_value_for_test(version)) @@ -393,7 +466,7 @@ class Primitive(TypeSpecification): Represents a Kafka primitive value. 
""" - PRIMITIVE_TYPE_NAMES = ['bool', 'int8', 'int16', 'int32', 'int64', 'string', 'bytes'] + USABLE_PRIMITIVE_TYPE_NAMES = ['bool', 'int8', 'int16', 'int32', 'int64', 'string', 'bytes'] KAFKA_TYPE_TO_ENVOY_TYPE = { 'string': 'std::string', @@ -403,6 +476,7 @@ class Primitive(TypeSpecification): 'int32': 'int32_t', 'int64': 'int64_t', 'bytes': 'Bytes', + 'tagged_fields': 'TaggedFields', } KAFKA_TYPE_TO_DESERIALIZER = { @@ -413,6 +487,12 @@ class Primitive(TypeSpecification): 'int32': 'Int32Deserializer', 'int64': 'Int64Deserializer', 'bytes': 'BytesDeserializer', + 'tagged_fields': 'TaggedFieldsDeserializer', + } + + KAFKA_TYPE_TO_COMPACT_DESERIALIZER = { + 'string': 'CompactStringDeserializer', + 'bytes': 'CompactBytesDeserializer' } # See https://github.com/apache/kafka/tree/trunk/clients/src/main/resources/common/message#deserializing-messages @@ -424,25 +504,33 @@ class Primitive(TypeSpecification): 'int32': '0', 'int64': '0', 'bytes': '{}', + 'tagged_fields': 'TaggedFields({})', } # Custom values that make test code more readable. KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST = { - 'string': '"string"', - 'bool': 'false', - 'int8': 'static_cast(8)', - 'int16': 'static_cast(16)', - 'int32': 'static_cast(32)', - 'int64': 'static_cast(64)', - 'bytes': 'Bytes({0, 1, 2, 3})', + 'string': + '"string"', + 'bool': + 'false', + 'int8': + 'static_cast(8)', + 'int16': + 'static_cast(16)', + 'int32': + 'static_cast(32)', + 'int64': + 'static_cast(64)', + 'bytes': + 'Bytes({0, 1, 2, 3})', + 'tagged_fields': + 'TaggedFields{std::vector{{10, Bytes({1, 2, 3})}, {20, Bytes({4, 5, 6})}}}', } def __init__(self, name, custom_default_value): self.original_name = name self.name = Primitive.compute(name, Primitive.KAFKA_TYPE_TO_ENVOY_TYPE) self.custom_default_value = custom_default_value - self.declaration_chain = [] - self.deserializer_name = Primitive.compute(name, Primitive.KAFKA_TYPE_TO_DESERIALIZER) @staticmethod def compute(name, map): @@ -451,8 +539,15 @@ def compute(name, map): else: raise ValueError(name) - def deserializer_name_in_version(self, version): - return self.deserializer_name + def compute_declaration_chain(self): + # Primitives need no declarations. + return [] + + def deserializer_name_in_version(self, version, compact): + if compact and self.original_name in Primitive.KAFKA_TYPE_TO_COMPACT_DESERIALIZER.keys(): + return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_COMPACT_DESERIALIZER) + else: + return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_DESERIALIZER) def default_value(self): if self.custom_default_value is not None: @@ -460,6 +555,9 @@ def default_value(self): else: return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_DEFAULT_VALUE) + def has_flexible_handling(self): + return self.original_name in ['string', 'bytes', 'tagged_fields'] + def example_value_for_test(self, version): return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST) @@ -467,6 +565,15 @@ def is_printable(self): return self.name not in ['Bytes'] +class FieldSerializationSpec(): + + def __init__(self, field, versions, compute_size_method_name, encode_method_name): + self.field = field + self.versions = versions + self.compute_size_method_name = compute_size_method_name + self.encode_method_name = encode_method_name + + class Complex(TypeSpecification): """ Represents a complex type (multiple types aggregated into one). 
@@ -477,17 +584,30 @@ def __init__(self, name, fields, versions): self.name = name self.fields = fields self.versions = versions - self.declaration_chain = self.__compute_declaration_chain() + self.flexible_versions = None # Will be set in 'register_flexible_versions'. self.attributes = {} - def __compute_declaration_chain(self): + def register_flexible_versions(self, flexible_versions): + # If flexible versions are present, so we need to add placeholder 'tagged_fields' field to + # *every* type that's used in by this message type. + for type in self.compute_declaration_chain(): + type.flexible_versions = flexible_versions + if len(flexible_versions) > 0: + tagged_fields_field = FieldSpec('tagged_fields', Primitive('tagged_fields', None), + flexible_versions, []) + type.fields.append(tagged_fields_field) + + def compute_declaration_chain(self): """ Computes all dependencies, what means all non-primitive types used by this type. They need to be declared before this struct is declared. """ result = [] for field in self.fields: - result.extend(field.type.declaration_chain) + field_dependencies = field.type.compute_declaration_chain() + for field_dependency in field_dependencies: + if field_dependency not in result: + result.append(field_dependency) result.append(self) return result @@ -528,11 +648,26 @@ def compute_field_lists(self): """ field_lists = [] for version in self.versions: - field_list = FieldList(version, self.fields) + field_list = FieldList(version, version in self.flexible_versions, self.fields) field_lists.append(field_list) return field_lists - def deserializer_name_in_version(self, version): + def compute_serialization_specs(self): + result = [] + for field in self.fields: + if field.type.has_flexible_handling(): + flexible = [x for x in field.version_usage if x in self.flexible_versions] + non_flexible = [x for x in field.version_usage if x not in flexible] + if non_flexible: + result.append(FieldSerializationSpec(field, non_flexible, 'computeSize', 'encode')) + if flexible: + result.append( + FieldSerializationSpec(field, flexible, 'computeCompactSize', 'encodeCompact')) + else: + result.append(FieldSerializationSpec(field, field.version_usage, 'computeSize', 'encode')) + return result + + def deserializer_name_in_version(self, version, compact): return '%sV%dDeserializer' % (self.name, version) def name_in_c_case(self): @@ -543,6 +678,9 @@ def name_in_c_case(self): def default_value(self): raise NotImplementedError('unable to create default value of complex type') + def has_flexible_handling(self): + return False + def example_value_for_test(self, version): field_list = next(fl for fl in self.compute_field_lists() if fl.version == version) example_values = map(lambda x: x.example_value_for_test(version), field_list.used_fields()) diff --git a/source/extensions/filters/network/kafka/protocol/kafka_request_resolver_cc.j2 b/source/extensions/filters/network/kafka/protocol/kafka_request_resolver_cc.j2 index e754248fc113..d82c23864fd7 100644 --- a/source/extensions/filters/network/kafka/protocol/kafka_request_resolver_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol/kafka_request_resolver_cc.j2 @@ -12,6 +12,25 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { +// Implements declaration from 'kafka_request.h'. 
+bool requestUsesTaggedFieldsInHeader(const uint16_t api_key, const uint16_t api_version) { + switch (api_key) { + {% for message_type in message_types %} + case {{ message_type.get_extra('api_key') }}: + switch (api_version) { + {% for flexible_version in message_type.flexible_versions %} + case {{ flexible_version }}: + return true; + {% endfor %} + default: + return false; + } + {% endfor %} + default: + return false; + } +} + /** * Creates a parser that corresponds to provided key and version. * If corresponding parser cannot be found (what means a newer version of Kafka protocol), diff --git a/source/extensions/filters/network/kafka/protocol/kafka_response_resolver_cc.j2 b/source/extensions/filters/network/kafka/protocol/kafka_response_resolver_cc.j2 index bc107a18dd41..b89df698a5d7 100644 --- a/source/extensions/filters/network/kafka/protocol/kafka_response_resolver_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol/kafka_response_resolver_cc.j2 @@ -11,6 +11,28 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { +// Implements declaration from 'kafka_response.h'. +bool responseUsesTaggedFieldsInHeader(const uint16_t api_key, const uint16_t api_version) { + switch (api_key) { + {% for message_type in message_types %} + case {{ message_type.get_extra('api_key') }}: + switch (api_version) { + {# ApiVersions responses require special handling. #} + {% if message_type.get_extra('api_key') != 18 %} + {% for flexible_version in message_type.flexible_versions %} + case {{ flexible_version }}: + return true; + {% endfor %} + {% endif %} + default: + return false; + } + {% endfor %} + default: + return false; + } +} + /** * Creates a parser that is going to process data specific for given response. * If corresponding parser cannot be found (what means a newer version of Kafka protocol), diff --git a/source/extensions/filters/network/kafka/serialization/generator.py b/source/extensions/filters/network/kafka/serialization/generator.py index 857f39c0eeea..a05012e4837d 100755 --- a/source/extensions/filters/network/kafka/serialization/generator.py +++ b/source/extensions/filters/network/kafka/serialization/generator.py @@ -37,7 +37,7 @@ def get_field_counts(): """ Generate argument counts that should be processed by composite deserializers. 
""" - return range(1, 11) + return range(1, 12) class RenderingHelper: diff --git a/source/extensions/filters/network/kafka/tagged_fields.h b/source/extensions/filters/network/kafka/tagged_fields.h index cda6a7162e4f..7e60952c03ce 100644 --- a/source/extensions/filters/network/kafka/tagged_fields.h +++ b/source/extensions/filters/network/kafka/tagged_fields.h @@ -79,7 +79,7 @@ class TaggedFieldDeserializer : public Deserializer { ready_ = true; } - return consumed; + return consumed + data_consumed; }; bool ready() const override { return ready_; }; @@ -100,7 +100,7 @@ class TaggedFieldDeserializer : public Deserializer { */ struct TaggedFields { - const std::vector fields_; + std::vector fields_; uint32_t computeCompactSize(const EncodingContext& encoder) const { uint32_t result{0}; diff --git a/test/extensions/filters/network/kafka/BUILD b/test/extensions/filters/network/kafka/BUILD index cf59d94584eb..da6612efa710 100644 --- a/test/extensions/filters/network/kafka/BUILD +++ b/test/extensions/filters/network/kafka/BUILD @@ -39,6 +39,7 @@ envoy_extension_cc_test( deps = [ ":serialization_utilities_lib", "//source/extensions/filters/network/kafka:serialization_lib", + "//source/extensions/filters/network/kafka:tagged_fields_lib", "//test/mocks/server:server_mocks", ], ) diff --git a/test/extensions/filters/network/kafka/broker/integration_test/zookeeper_properties.j2 b/test/extensions/filters/network/kafka/broker/integration_test/zookeeper_properties.j2 index 5a563d1f4d2a..be524bea342b 100644 --- a/test/extensions/filters/network/kafka/broker/integration_test/zookeeper_properties.j2 +++ b/test/extensions/filters/network/kafka/broker/integration_test/zookeeper_properties.j2 @@ -1,3 +1,5 @@ clientPort={{ data['zk_port'] }} dataDir={{ data['data_dir'] }} maxClientCnxns=0 +# ZK 3.5 tries to bind 8080 for introspection capacility - we do not need that. +admin.enableServer=false diff --git a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc index e9c72ad3e9d1..7125e5075071 100644 --- a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc +++ b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc @@ -112,7 +112,7 @@ TEST_F(KafkaRequestParserTest, RequestDataParserShouldHandleDeserializerExceptio int32_t get() const override { throw std::runtime_error("should not be invoked at all"); }; }; - RequestContextSharedPtr request_context{new RequestContext{1024, {}}}; + RequestContextSharedPtr request_context{new RequestContext{1024, {0, 0, 0, absl::nullopt}}}; RequestDataParser testee{request_context}; absl::string_view data = putGarbageIntoBuffer(); @@ -146,7 +146,8 @@ TEST_F(KafkaRequestParserTest, RequestDataParserShouldHandleDeserializerReturningReadyButLeavingData) { // given const int32_t request_size = 1024; // There are still 1024 bytes to read to complete the request. 
- RequestContextSharedPtr request_context{new RequestContext{request_size, {}}}; + RequestContextSharedPtr request_context{ + new RequestContext{request_size, {0, 0, 0, absl::nullopt}}}; RequestDataParser testee{request_context}; diff --git a/test/extensions/filters/network/kafka/protocol/request_codec_request_test_cc.j2 b/test/extensions/filters/network/kafka/protocol/request_codec_request_test_cc.j2 index c4f872e253d5..744fe0ac9ebe 100644 --- a/test/extensions/filters/network/kafka/protocol/request_codec_request_test_cc.j2 +++ b/test/extensions/filters/network/kafka/protocol/request_codec_request_test_cc.j2 @@ -42,8 +42,18 @@ TEST_F(RequestCodecRequestTest, shouldHandle{{ message_type.name }}Messages) { {% for field_list in message_type.compute_field_lists() %} for (int i = 0; i < 100; ++i ) { - const RequestHeader header = - { {{ message_type.get_extra('api_key') }}, {{ field_list.version }}, correlation++, "id" }; + {# Request header cannot contain tagged fields if request does not support them. #} + const TaggedFields tagged_fields = requestUsesTaggedFieldsInHeader( + {{ message_type.get_extra('api_key') }}, {{ field_list.version }}) ? + TaggedFields{ { TaggedField{ 10, Bytes{1, 2, 3, 4} } } }: + TaggedFields({}); + const RequestHeader header = { + {{ message_type.get_extra('api_key') }}, + {{ field_list.version }}, + correlation++, + "id", + tagged_fields + }; const {{ message_type.name }} data = { {{ field_list.example_value() }} }; const RequestUnderTest request = {header, data}; putMessageIntoBuffer(request); @@ -64,6 +74,7 @@ TEST_F(RequestCodecRequestTest, shouldHandle{{ message_type.name }}Messages) { // then const std::vector& received = callback->getCapturedMessages(); ASSERT_EQ(received.size(), sent.size()); + ASSERT_EQ(received.size(), correlation); for (size_t i = 0; i < received.size(); ++i) { const std::shared_ptr request = diff --git a/test/extensions/filters/network/kafka/protocol/response_codec_response_test_cc.j2 b/test/extensions/filters/network/kafka/protocol/response_codec_response_test_cc.j2 index 87b977f64e49..c85b9d5044ee 100644 --- a/test/extensions/filters/network/kafka/protocol/response_codec_response_test_cc.j2 +++ b/test/extensions/filters/network/kafka/protocol/response_codec_response_test_cc.j2 @@ -45,8 +45,17 @@ TEST_F(ResponseCodecResponseTest, shouldHandle{{ message_type.name }}Messages) { {% for field_list in message_type.compute_field_lists() %} for (int i = 0; i < 100; ++i ) { - const ResponseMetadata metadata = - { {{ message_type.get_extra('api_key') }}, {{ field_list.version }}, ++correlation_id }; + {# Response header cannot contain tagged fields if response does not support them. #} + const TaggedFields tagged_fields = responseUsesTaggedFieldsInHeader( + {{ message_type.get_extra('api_key') }}, {{ field_list.version }}) ? 
+ TaggedFields{ { TaggedField{ 10, Bytes{1, 2, 3, 4} } } }: + TaggedFields({}); + const ResponseMetadata metadata = { + {{ message_type.get_extra('api_key') }}, + {{ field_list.version }}, + ++correlation_id, + tagged_fields, + }; const {{ message_type.name }} data = { {{ field_list.example_value() }} }; const ResponseUnderTest response = {metadata, data}; putMessageIntoBuffer(response); @@ -62,6 +71,7 @@ TEST_F(ResponseCodecResponseTest, shouldHandle{{ message_type.name }}Messages) { // then const std::vector& received = callback->getCapturedMessages(); ASSERT_EQ(received.size(), sent.size()); + ASSERT_EQ(received.size(), correlation_id); for (size_t i = 0; i < received.size(); ++i) { const std::shared_ptr response = diff --git a/test/extensions/filters/network/kafka/request_codec_unit_test.cc b/test/extensions/filters/network/kafka/request_codec_unit_test.cc index c1f6e94ed792..edbeb57c47d0 100644 --- a/test/extensions/filters/network/kafka/request_codec_unit_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_unit_test.cc @@ -139,7 +139,7 @@ TEST_F(RequestCodecUnitTest, shouldPassParsedMessageToCallbackAndInitializeNextP putGarbageIntoBuffer(); const AbstractRequestSharedPtr parsed_message = - std::make_shared>(RequestHeader(), 0); + std::make_shared>(RequestHeader{0, 0, 0, absl::nullopt}, 0); MockParserSharedPtr parser1 = std::make_shared(); EXPECT_CALL(*parser1, parse(_)) @@ -170,7 +170,7 @@ TEST_F(RequestCodecUnitTest, shouldPassParseFailureDataToCallback) { putGarbageIntoBuffer(); const RequestParseFailureSharedPtr failure_data = - std::make_shared(RequestHeader()); + std::make_shared(RequestHeader{0, 0, 0, absl::nullopt}); MockParserSharedPtr parser = std::make_shared(); auto consume_and_return = [&failure_data](absl::string_view& data) -> RequestParseResponse { diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc index 250d35121f17..903a66470c40 100644 --- a/test/extensions/filters/network/kafka/serialization_test.cc +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -1,3 +1,5 @@ +#include "extensions/filters/network/kafka/tagged_fields.h" + #include "test/extensions/filters/network/kafka/serialization_utilities.h" namespace Envoy { @@ -434,6 +436,23 @@ TEST(NullableCompactArrayDeserializer, ShouldConsumeNullArray) { NullableCompactArrayDeserializer>(value); } +// Tagged fields. + +TEST(TaggedFieldDeserializer, ShouldConsumeCorrectAmountOfData) { + const TaggedField value{200, Bytes{1, 2, 3, 4, 5, 6}}; + serializeCompactThenDeserializeAndCheckEquality(value); +} + +TEST(TaggedFieldsDeserializer, ShouldConsumeCorrectAmountOfData) { + std::vector fields; + for (uint32_t i = 0; i < 200; ++i) { + const TaggedField tagged_field = {i, Bytes{1, 2, 3, 4}}; + fields.push_back(tagged_field); + } + const TaggedFields value{fields}; + serializeCompactThenDeserializeAndCheckEquality(value); +} + } // namespace SerializationTest } // namespace Kafka } // namespace NetworkFilters
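On the wire, the TaggedField{200, Bytes{1, 2, 3, 4, 5, 6}} value exercised by the first test above is framed with unsigned varints: the collection carries a varint field count, and each field a varint tag and a varint payload length (tags in ascending order). Compact strings and arrays in flexible versions use the same varint, storing length + 1 so that 0 can mean null. The sketch below is standalone and uses invented names (SketchTaggedField, encodeTaggedFields); the filter's real types live in tagged_fields.h.

#include <cstdint>
#include <iostream>
#include <vector>

struct SketchTaggedField {
  uint32_t tag_;
  std::vector<uint8_t> data_;
};

// Kafka's unsigned varint: 7 bits per byte, continuation bit on all but the last byte.
void writeUnsignedVarint(uint32_t value, std::vector<uint8_t>& out) {
  while (value >= 0x80) {
    out.push_back(static_cast<uint8_t>(value) | 0x80);
    value >>= 7;
  }
  out.push_back(static_cast<uint8_t>(value));
}

// Tagged-fields section: field count, then (tag, length, payload) per field.
std::vector<uint8_t> encodeTaggedFields(const std::vector<SketchTaggedField>& fields) {
  std::vector<uint8_t> out;
  writeUnsignedVarint(static_cast<uint32_t>(fields.size()), out);
  for (const SketchTaggedField& field : fields) {
    writeUnsignedVarint(field.tag_, out);
    writeUnsignedVarint(static_cast<uint32_t>(field.data_.size()), out);
    out.insert(out.end(), field.data_.begin(), field.data_.end());
  }
  return out;
}

int main() {
  // A one-element collection holding the field from the test above.
  const std::vector<SketchTaggedField> fields = {{200, {1, 2, 3, 4, 5, 6}}};
  // 1 (count) + 2 (tag 200 needs two varint bytes) + 1 (length) + 6 (payload) = 10 bytes.
  std::cout << encodeTaggedFields(fields).size() << std::endl;
  return 0;
}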
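A related detail worth calling out is the small fix in tagged_fields.h above, where TaggedFieldDeserializer::feed now returns consumed + data_consumed rather than consumed alone. The return value feeds arithmetic like context_->remaining_request_size_ -= deserializer_->feed(data) in the header parsers, so under-reporting leaves the parser expecting bytes that were already eaten. The standalone toy below (a plain length-prefixed blob, not the real tagged-field layout) shows the effect of that accounting.

#include <cassert>
#include <cstdint>
#include <string>
#include <string_view>

// Consumes a 1-byte length prefix plus that many payload bytes; 'include_payload'
// toggles between correct accounting and the pre-fix under-reporting.
uint32_t feedLengthPrefixedBlob(std::string_view& data, bool include_payload) {
  uint32_t consumed = 0;
  const uint8_t payload_size = static_cast<uint8_t>(data.front());
  data.remove_prefix(1);
  ++consumed;
  data.remove_prefix(payload_size); // Assume the whole payload has already arrived.
  const uint32_t data_consumed = payload_size;
  return include_payload ? consumed + data_consumed : consumed;
}

int main() {
  const std::string wire = {4, 'a', 'b', 'c', 'd'}; // Length prefix 4 + four payload bytes.

  std::string_view correct(wire);
  uint32_t remaining = static_cast<uint32_t>(wire.size());
  remaining -= feedLengthPrefixedBlob(correct, true);
  assert(remaining == 0); // Accounting lines up: nothing left to wait for.

  std::string_view buggy(wire);
  remaining = static_cast<uint32_t>(wire.size());
  remaining -= feedLengthPrefixedBlob(buggy, false);
  assert(remaining == 4); // The parser would keep waiting for bytes that never arrive.
  return 0;
}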