From 611471fcfbefe32b610819b33a744da4c1f5849f Mon Sep 17 00:00:00 2001 From: Xiong Ding Date: Wed, 10 Apr 2024 09:57:51 -0700 Subject: [PATCH] Fix ssl connection (#178) * Fix ssl connection after wrap_ssl * test * refactor * remove global level * test * revert test * address comments --- kafka/client_async.py | 9 ++++- kafka/conn.py | 4 +-- test/fixtures.py | 9 +++-- test/test_ssl_integration.py | 67 ++++++++++++++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 5 deletions(-) create mode 100644 test/test_ssl_integration.py diff --git a/kafka/client_async.py b/kafka/client_async.py index b46b879f9..984cd81fb 100644 --- a/kafka/client_async.py +++ b/kafka/client_async.py @@ -266,7 +266,14 @@ def _conn_state_change(self, node_id, sock, conn): try: self._selector.register(sock, selectors.EVENT_WRITE, conn) except KeyError: - self._selector.modify(sock, selectors.EVENT_WRITE, conn) + # SSL detaches the original socket, and transfers the + # underlying file descriptor to a new SSLSocket. We should + # explicitly unregister the original socket. + if conn.state == ConnectionStates.HANDSHAKE: + self._selector.unregister(sock) + self._selector.register(sock, selectors.EVENT_WRITE, conn) + else: + self._selector.modify(sock, selectors.EVENT_WRITE, conn) if self.cluster.is_bootstrap(node_id): self._last_bootstrap = time.time() diff --git a/kafka/conn.py b/kafka/conn.py index 745e4bca6..b9ef0e2d9 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -378,10 +378,10 @@ def connect(self): if self.config['security_protocol'] in ('SSL', 'SASL_SSL'): log.debug('%s: initiating SSL handshake', self) - self.state = ConnectionStates.HANDSHAKE - self.config['state_change_callback'](self.node_id, self._sock, self) # _wrap_ssl can alter the connection state -- disconnects on failure self._wrap_ssl() + self.state = ConnectionStates.HANDSHAKE + self.config['state_change_callback'](self.node_id, self._sock, self) elif self.config['security_protocol'] == 'SASL_PLAINTEXT': log.debug('%s: initiating SASL authentication', self) diff --git a/test/fixtures.py b/test/fixtures.py index 4ed515da3..998dc429f 100644 --- a/test/fixtures.py +++ b/test/fixtures.py @@ -38,7 +38,7 @@ def gen_ssl_resources(directory): # Step 1 keytool -keystore kafka.server.keystore.jks -alias localhost -validity 1 \ - -genkey -storepass foobar -keypass foobar \ + -genkey -keyalg RSA -storepass foobar -keypass foobar \ -dname "CN=localhost, OU=kafka-python, O=kafka-python, L=SF, ST=CA, C=US" \ -ext SAN=dns:localhost @@ -289,7 +289,7 @@ def __init__(self, host, port, broker_id, zookeeper, zk_chroot, self.sasl_mechanism = sasl_mechanism.upper() else: self.sasl_mechanism = None - self.ssl_dir = self.test_resource('ssl') + self.ssl_dir = None # TODO: checking for port connection would be better than scanning logs # until then, we need the pattern to work across all supported broker versions @@ -410,6 +410,8 @@ def start(self): jaas_conf = self.tmp_dir.join("kafka_server_jaas.conf") properties_template = self.test_resource("kafka.properties") jaas_conf_template = self.test_resource("kafka_server_jaas.conf") + self.ssl_dir = self.tmp_dir + gen_ssl_resources(self.ssl_dir.strpath) args = self.kafka_run_class_args("kafka.Kafka", properties.strpath) env = self.kafka_run_class_env() @@ -641,6 +643,9 @@ def _enrich_client_params(self, params, **defaults): if self.sasl_mechanism in ('PLAIN', 'SCRAM-SHA-256', 'SCRAM-SHA-512'): params.setdefault('sasl_plain_username', self.broker_user) params.setdefault('sasl_plain_password', self.broker_password) + if self.transport in ["SASL_SSL", "SSL"]: + params.setdefault("ssl_cafile", self.ssl_dir.join('ca-cert').strpath) + params.setdefault("security_protocol", self.transport) return params @staticmethod diff --git a/test/test_ssl_integration.py b/test/test_ssl_integration.py new file mode 100644 index 000000000..8453e7831 --- /dev/null +++ b/test/test_ssl_integration.py @@ -0,0 +1,67 @@ +import logging +import uuid + +import pytest + +from kafka.admin import NewTopic +from kafka.protocol.metadata import MetadataRequest_v1 +from test.testutil import assert_message_count, env_kafka_version, random_string, special_to_underscore + + +@pytest.fixture(scope="module") +def ssl_kafka(request, kafka_broker_factory): + return kafka_broker_factory(transport="SSL")[0] + + +@pytest.mark.skipif(env_kafka_version() < (0, 10), reason="Inter broker SSL was implemented at version 0.9") +def test_admin(request, ssl_kafka): + topic_name = special_to_underscore(request.node.name + random_string(4)) + admin, = ssl_kafka.get_admin_clients(1) + admin.create_topics([NewTopic(topic_name, 1, 1)]) + assert topic_name in ssl_kafka.get_topic_names() + + +@pytest.mark.skipif(env_kafka_version() < (0, 10), reason="Inter broker SSL was implemented at version 0.9") +def test_produce_and_consume(request, ssl_kafka): + topic_name = special_to_underscore(request.node.name + random_string(4)) + ssl_kafka.create_topics([topic_name], num_partitions=2) + producer, = ssl_kafka.get_producers(1) + + messages_and_futures = [] # [(message, produce_future),] + for i in range(100): + encoded_msg = "{}-{}-{}".format(i, request.node.name, uuid.uuid4()).encode("utf-8") + future = producer.send(topic_name, value=encoded_msg, partition=i % 2) + messages_and_futures.append((encoded_msg, future)) + producer.flush() + + for (msg, f) in messages_and_futures: + assert f.succeeded() + + consumer, = ssl_kafka.get_consumers(1, [topic_name]) + messages = {0: [], 1: []} + for i, message in enumerate(consumer, 1): + logging.debug("Consumed message %s", repr(message)) + messages[message.partition].append(message) + if i >= 100: + break + + assert_message_count(messages[0], 50) + assert_message_count(messages[1], 50) + + +@pytest.mark.skipif(env_kafka_version() < (0, 10), reason="Inter broker SSL was implemented at version 0.9") +def test_client(request, ssl_kafka): + topic_name = special_to_underscore(request.node.name + random_string(4)) + ssl_kafka.create_topics([topic_name], num_partitions=1) + + client, = ssl_kafka.get_clients(1) + request = MetadataRequest_v1(None) + client.send(0, request) + for _ in range(10): + result = client.poll(timeout_ms=10000) + if len(result) > 0: + break + else: + raise RuntimeError("Couldn't fetch topic response from Broker.") + result = result[0] + assert topic_name in [t[1] for t in result.topics]