diff --git a/redash/query_runner/cass.py b/redash/query_runner/cass.py index 9f513eb5c3..4a49c0d9b1 100644 --- a/redash/query_runner/cass.py +++ b/redash/query_runner/cass.py @@ -1,4 +1,8 @@ import logging +import os +import ssl +from base64 import b64decode +from tempfile import NamedTemporaryFile from redash.query_runner import BaseQueryRunner, register from redash.utils import JSONEncoder, json_dumps, json_loads @@ -15,6 +19,16 @@ enabled = False +def generate_ssl_options_dict(protocol, cert_path=None): + ssl_options = { + 'ssl_version': getattr(ssl, protocol) + } + if cert_path is not None: + ssl_options['ca_certs'] = cert_path + ssl_options['cert_reqs'] = ssl.CERT_REQUIRED + return ssl_options + + class CassandraJSONEncoder(JSONEncoder): def default(self, o): if isinstance(o, sortedset): @@ -45,8 +59,27 @@ def configuration_schema(cls): "default": 3, }, "timeout": {"type": "number", "title": "Timeout", "default": 10}, + "useSsl": {"type": "boolean", "title": "Use SSL", "default": False}, + "sslCertificateFile": { + "type": "string", + "title": "SSL Certificate File" + }, + "sslProtocol": { + "type": "string", + "title": "SSL Protocol", + "enum": [ + "PROTOCOL_SSLv23", + "PROTOCOL_TLS", + "PROTOCOL_TLS_CLIENT", + "PROTOCOL_TLS_SERVER", + "PROTOCOL_TLSv1", + "PROTOCOL_TLSv1_1", + "PROTOCOL_TLSv1_2", + ], + }, }, - "required": ["keyspace", "host"], + "required": ["keyspace", "host", "useSsl"], + "secret": ["sslCertificateFile"], } @classmethod @@ -93,7 +126,7 @@ def get_schema(self, get_stats=False): def run_query(self, query, user): connection = None - + cert_path = self._generate_cert_file() if self.configuration.get("username", "") and self.configuration.get( "password", "" ): @@ -106,18 +139,21 @@ def run_query(self, query, user): auth_provider=auth_provider, port=self.configuration.get("port", ""), protocol_version=self.configuration.get("protocol", 3), + ssl_options=self._get_ssl_options(cert_path), ) else: connection = Cluster( [self.configuration.get("host", "")], port=self.configuration.get("port", ""), protocol_version=self.configuration.get("protocol", 3), + ssl_options=self._get_ssl_options(cert_path), ) session = connection.connect() session.set_keyspace(self.configuration["keyspace"]) session.default_timeout = self.configuration.get("timeout", 10) logger.debug("Cassandra running query: %s", query) result = session.execute(query) + self._cleanup_cert_file(cert_path) column_names = result.column_names @@ -130,6 +166,28 @@ def run_query(self, query, user): return json_data, None + def _generate_cert_file(self): + cert_encoded_bytes = self.configuration.get("sslCertificateFile", None) + if cert_encoded_bytes: + with NamedTemporaryFile(mode='w', delete=False) as cert_file: + cert_bytes = b64decode(cert_encoded_bytes) + cert_file.write(cert_bytes.decode("utf-8")) + return cert_file.name + return None + + def _cleanup_cert_file(self, cert_path): + if cert_path: + os.remove(cert_path) + + def _get_ssl_options(self, cert_path): + ssl_options = None + if self.configuration.get("useSsl", False): + ssl_options = generate_ssl_options_dict( + protocol=self.configuration["sslProtocol"], + cert_path=cert_path + ) + return ssl_options + class ScyllaDB(Cassandra): @classmethod diff --git a/tests/query_runner/test_cass.py b/tests/query_runner/test_cass.py new file mode 100644 index 0000000000..6f7fcca78d --- /dev/null +++ b/tests/query_runner/test_cass.py @@ -0,0 +1,22 @@ +import shutil +import ssl +from unittest import TestCase + +from redash.query_runner.cass import generate_ssl_options_dict + + +class TestCassandra(TestCase): + + def test_generate_ssl_options_dict_creates_plain_protocol_dict(self): + expected = {'ssl_version': ssl.PROTOCOL_TLSv1_2} + actual = generate_ssl_options_dict("PROTOCOL_TLSv1_2") + self.assertDictEqual(expected, actual) + + def test_generate_ssl_options_dict_creates_certificate_dict(self): + expected = { + 'ssl_version': ssl.PROTOCOL_TLSv1_2, + 'ca_certs': 'some/path', + 'cert_reqs': ssl.CERT_REQUIRED, + } + actual = generate_ssl_options_dict("PROTOCOL_TLSv1_2", "some/path") + self.assertDictEqual(expected, actual)