Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add ssl options for Cassandra data source #4665

Merged
merged 9 commits into from
Apr 3, 2020
62 changes: 60 additions & 2 deletions redash/query_runner/cass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import logging
import os
import ssl
from base64 import b64decode
from tempfile import NamedTemporaryFile

from redash.query_runner import BaseQueryRunner, register
from redash.utils import JSONEncoder, json_dumps, json_loads
Expand All @@ -15,6 +19,16 @@
enabled = False


def generate_ssl_options_dict(protocol, cert_path=None):
ssl_options = {
'ssl_version': getattr(ssl, protocol)
}
if cert_path is not None:
ssl_options['ca_certs'] = cert_path
ssl_options['cert_reqs'] = ssl.CERT_REQUIRED
return ssl_options


class CassandraJSONEncoder(JSONEncoder):
def default(self, o):
if isinstance(o, sortedset):
Expand Down Expand Up @@ -45,8 +59,27 @@ def configuration_schema(cls):
"default": 3,
},
"timeout": {"type": "number", "title": "Timeout", "default": 10},
"useSsl": {"type": "boolean", "title": "Use SSL", "default": False},
"sslCertificateFile": {
"type": "string",
"title": "SSL Certificate File"
},
"sslProtocol": {
"type": "string",
arihantsurana marked this conversation as resolved.
Show resolved Hide resolved
"title": "SSL Protocol",
"enum": [
"PROTOCOL_SSLv23",
"PROTOCOL_TLS",
"PROTOCOL_TLS_CLIENT",
"PROTOCOL_TLS_SERVER",
"PROTOCOL_TLSv1",
"PROTOCOL_TLSv1_1",
"PROTOCOL_TLSv1_2",
],
},
},
"required": ["keyspace", "host"],
"required": ["keyspace", "host", "useSsl"],
"secret": ["sslCertificateFile"],
}

@classmethod
Expand Down Expand Up @@ -93,7 +126,7 @@ def get_schema(self, get_stats=False):

def run_query(self, query, user):
connection = None

cert_path = self._generate_cert_file()
if self.configuration.get("username", "") and self.configuration.get(
"password", ""
):
Expand All @@ -106,18 +139,21 @@ def run_query(self, query, user):
auth_provider=auth_provider,
port=self.configuration.get("port", ""),
protocol_version=self.configuration.get("protocol", 3),
ssl_options=self._get_ssl_options(cert_path),
)
else:
connection = Cluster(
[self.configuration.get("host", "")],
port=self.configuration.get("port", ""),
protocol_version=self.configuration.get("protocol", 3),
ssl_options=self._get_ssl_options(cert_path),
)
session = connection.connect()
session.set_keyspace(self.configuration["keyspace"])
session.default_timeout = self.configuration.get("timeout", 10)
logger.debug("Cassandra running query: %s", query)
result = session.execute(query)
self._cleanup_cert_file(cert_path)

column_names = result.column_names

Expand All @@ -130,6 +166,28 @@ def run_query(self, query, user):

return json_data, None

def _generate_cert_file(self):
cert_encoded_bytes = self.configuration.get("sslCertificateFile", None)
if cert_encoded_bytes:
with NamedTemporaryFile(mode='w', delete=False) as cert_file:
cert_bytes = b64decode(cert_encoded_bytes)
cert_file.write(cert_bytes.decode("utf-8"))
return cert_file.name
return None

def _cleanup_cert_file(self, cert_path):
if cert_path:
os.remove(cert_path)

def _get_ssl_options(self, cert_path):
ssl_options = None
if self.configuration.get("useSsl", False):
ssl_options = generate_ssl_options_dict(
protocol=self.configuration["sslProtocol"],
cert_path=cert_path
)
return ssl_options


class ScyllaDB(Cassandra):
@classmethod
Expand Down
22 changes: 22 additions & 0 deletions tests/query_runner/test_cass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import shutil
import ssl
from unittest import TestCase

from redash.query_runner.cass import generate_ssl_options_dict


class TestCassandra(TestCase):

def test_generate_ssl_options_dict_creates_plain_protocol_dict(self):
expected = {'ssl_version': ssl.PROTOCOL_TLSv1_2}
actual = generate_ssl_options_dict("PROTOCOL_TLSv1_2")
self.assertDictEqual(expected, actual)

def test_generate_ssl_options_dict_creates_certificate_dict(self):
expected = {
'ssl_version': ssl.PROTOCOL_TLSv1_2,
'ca_certs': 'some/path',
'cert_reqs': ssl.CERT_REQUIRED,
}
actual = generate_ssl_options_dict("PROTOCOL_TLSv1_2", "some/path")
self.assertDictEqual(expected, actual)