diff --git a/client/app/pages/queries/hooks/useDataSourceSchema.js b/client/app/pages/queries/hooks/useDataSourceSchema.js
index 0e7588ab0e..2c1b0ea4d4 100644
--- a/client/app/pages/queries/hooks/useDataSourceSchema.js
+++ b/client/app/pages/queries/hooks/useDataSourceSchema.js
@@ -1,22 +1,36 @@
 import { reduce } from "lodash";
 import { useCallback, useEffect, useRef, useState, useMemo } from "react";
+import { axios } from "@/services/axios";
 import DataSource, { SCHEMA_NOT_SUPPORTED } from "@/services/data-source";
 import notification from "@/services/notification";
 
+function sleep(ms) {
+  return new Promise((resolve) => setTimeout(resolve, ms));
+}
+
 function getSchema(dataSource, refresh = undefined) {
   if (!dataSource) {
     return Promise.resolve([]);
   }
 
+  const fetchSchemaFromJob = (data) => {
+    return sleep(1000).then(() => {
+      return axios.get(`api/jobs/${data.job.id}`).then((data) => {
+        if (data.job.status < 3) {
+          return fetchSchemaFromJob(data);
+        } else if (data.job.status === 3) {
+          return data.job.result;
+        } else if (data.job.status === 4 && data.job.error.code === SCHEMA_NOT_SUPPORTED) {
+          return [];
+        } else {
+          return Promise.reject(new Error(data.job.error));
+        }
+      });
+    });
+  };
+
   return DataSource.fetchSchema(dataSource, refresh)
-    .then(data => {
-      if (data.schema) {
-        return data.schema;
-      } else if (data.error.code === SCHEMA_NOT_SUPPORTED) {
-        return [];
-      }
-      return Promise.reject(new Error("Schema refresh failed."));
-    })
+    .then(fetchSchemaFromJob)
    .catch(() => {
       notification.error("Schema refresh failed.", "Please try again later.");
       return Promise.resolve([]);
@@ -34,11 +48,9 @@ export default function useDataSourceSchema(dataSource) {
 
   const reloadSchema = useCallback(
     (refresh = undefined) => {
-      const refreshToken = Math.random()
-        .toString(36)
-        .substr(2);
+      const refreshToken = Math.random().toString(36).substr(2);
       refreshSchemaTokenRef.current = refreshToken;
-      getSchema(dataSource, refresh).then(data => {
+      getSchema(dataSource, refresh).then((data) => {
         if (refreshSchemaTokenRef.current === refreshToken) {
           setSchema(prepareSchema(data));
         }
diff --git a/client/app/services/data-source.js b/client/app/services/data-source.js
index c7de5d086d..80497baf13 100644
--- a/client/app/services/data-source.js
+++ b/client/app/services/data-source.js
@@ -8,9 +8,9 @@ const DataSource = {
   query: () => axios.get("api/data_sources"),
   get: ({ id }) => axios.get(`api/data_sources/${id}`),
   types: () => axios.get("api/data_sources/types"),
-  create: data => axios.post(`api/data_sources`, data),
-  save: data => axios.post(`api/data_sources/${data.id}`, data),
-  test: data => axios.post(`api/data_sources/${data.id}/test`),
+  create: (data) => axios.post(`api/data_sources`, data),
+  save: (data) => axios.post(`api/data_sources/${data.id}`, data),
+  test: (data) => axios.post(`api/data_sources/${data.id}/test`),
   delete: ({ id }) => axios.delete(`api/data_sources/${id}`),
   fetchSchema: (data, refresh = false) => {
     const params = {};
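With this change the schema endpoint no longer returns the schema inline: it responds with a serialized background job, which `fetchSchemaFromJob` above polls via `api/jobs/<id>` once a second until the job finishes (status 3) or fails (status 4). The same contract, sketched as a standalone Python script — the base URL, API key, and `fetch_schema` helper are illustrative, not part of this patch:

```python
import time
import requests

BASE_URL = "https://redash.example.com"            # placeholder
HEADERS = {"Authorization": "Key <your-api-key>"}  # placeholder


def fetch_schema(data_source_id):
    # The endpoint now returns {"job": {...}} instead of {"schema": [...]}.
    url = "{}/api/data_sources/{}/schema".format(BASE_URL, data_source_id)
    job = requests.get(url, headers=HEADERS).json()["job"]

    while job["status"] < 3:  # 1 = queued, 2 = started
        time.sleep(1)
        job_url = "{}/api/jobs/{}".format(BASE_URL, job["id"])
        job = requests.get(job_url, headers=HEADERS).json()["job"]

    if job["status"] == 3:  # finished: "result" carries the schema itself
        return job["result"]
    raise RuntimeError(job["error"])  # 4 = failed
```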
diff --git a/redash/handlers/data_sources.py b/redash/handlers/data_sources.py
index 94ad0abb11..b60c204b25 100644
--- a/redash/handlers/data_sources.py
+++ b/redash/handlers/data_sources.py
@@ -1,4 +1,5 @@
 import logging
+import time
 
 from flask import make_response, request
 from flask_restful import abort
@@ -20,12 +21,16 @@
 )
 from redash.utils import filter_none
 from redash.utils.configuration import ConfigurationContainer, ValidationError
+from redash.tasks.general import test_connection, get_schema
+from redash.serializers import serialize_job
 
 
 class DataSourceTypeListResource(BaseResource):
     @require_admin
     def get(self):
-        return [q.to_dict() for q in sorted(query_runners.values(), key=lambda q: q.name())]
+        return [
+            q.to_dict() for q in sorted(query_runners.values(), key=lambda q: q.name())
+        ]
 
 
 class DataSourceResource(BaseResource):
@@ -182,19 +187,9 @@ def get(self, data_source_id):
         require_access(data_source, self.current_user, view_only)
 
         refresh = request.args.get("refresh") is not None
-        response = {}
-
-        try:
-            response["schema"] = data_source.get_schema(refresh)
-        except NotSupported:
-            response["error"] = {
-                "code": 1,
-                "message": "Data source type does not support retrieving schema",
-            }
-        except Exception:
-            response["error"] = {"code": 2, "message": "Error retrieving schema."}
+        job = get_schema.delay(data_source.id, refresh)
 
-        return response
+        return serialize_job(job)
 
 
 class DataSourcePauseResource(BaseResource):
@@ -245,10 +240,14 @@ def post(self, data_source_id):
         )
 
         response = {}
-        try:
-            data_source.query_runner.test_connection()
-        except Exception as e:
-            response = {"message": str(e), "ok": False}
+
+        job = test_connection.delay(data_source.id)
+        while not (job.is_finished or job.is_failed):
+            time.sleep(1)
+            job.refresh()
+
+        if isinstance(job.result, Exception):
+            response = {"message": str(job.result), "ok": False}
         else:
             response = {"message": "success", "ok": True}
 
diff --git a/redash/models/__init__.py b/redash/models/__init__.py
index 6304afc44d..1af9661771 100644
--- a/redash/models/__init__.py
+++ b/redash/models/__init__.py
@@ -24,6 +24,7 @@
 )
 from redash.metrics import database  # noqa: F401
 from redash.query_runner import (
+    with_ssh_tunnel,
     get_configuration_schema_for_query_runner_type,
     get_query_runner,
     TYPE_BOOLEAN,
@@ -251,9 +252,18 @@ def update_group_permission(self, group, view_only):
         db.session.add(dsg)
         return dsg
 
+    @property
+    def uses_ssh_tunnel(self):
+        return "ssh_tunnel" in self.options
+
     @property
     def query_runner(self):
-        return get_query_runner(self.type, self.options)
+        query_runner = get_query_runner(self.type, self.options)
+
+        if self.uses_ssh_tunnel:
+            query_runner = with_ssh_tunnel(query_runner, self.options.get("ssh_tunnel"))
+
+        return query_runner
 
     @classmethod
     def get_by_name(cls, name):
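`DataSource.query_runner` now transparently wraps the runner whenever the data source's options contain an `ssh_tunnel` key, so tunnel details live alongside the regular connection settings. Judging from the keys that `with_ssh_tunnel` (added below) reads — `ssh_host`, `ssh_port`, `ssh_username` — the options blob would look roughly like this; all values are illustrative:

```python
# Hypothetical options for a PostgreSQL data source reached through a bastion host.
options = {
    "host": "10.0.0.12",  # resolvable only from inside the bastion's network
    "port": 5432,
    "dbname": "reports",
    "ssh_tunnel": {
        "ssh_host": "bastion.example.com",  # placeholder
        "ssh_port": 22,                     # optional; with_ssh_tunnel defaults to 22
        "ssh_username": "redash",
    },
}
```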
diff --git a/redash/query_runner/__init__.py b/redash/query_runner/__init__.py
index 086a05f680..8235be5842 100644
--- a/redash/query_runner/__init__.py
+++ b/redash/query_runner/__init__.py
@@ -1,8 +1,11 @@
 import logging
+from contextlib import ExitStack
 from dateutil import parser
+from functools import wraps
 import requests
+from sshtunnel import open_tunnel
 
 from redash import settings
 from redash.utils import json_loads
 from rq.timeouts import JobTimeoutException
@@ -70,6 +73,58 @@ def type(cls):
     def enabled(cls):
         return True
 
+    @property
+    def host(self):
+        """Returns this query runner's configured host.
+        This is used primarily for temporarily swapping endpoints when using SSH tunnels to connect to a data source.
+
+        `BaseQueryRunner`'s naïve implementation supports query runner implementations that store endpoints using `host` and `port`
+        configuration values. If your query runner uses a different schema (e.g. a web address), you should override this property.
+        """
+        if "host" in self.configuration:
+            return self.configuration["host"]
+        else:
+            raise NotImplementedError()
+
+    @host.setter
+    def host(self, host):
+        """Sets this query runner's configured host.
+        This is used primarily for temporarily swapping endpoints when using SSH tunnels to connect to a data source.
+
+        `BaseQueryRunner`'s naïve implementation supports query runner implementations that store endpoints using `host` and `port`
+        configuration values. If your query runner uses a different schema (e.g. a web address), you should override this property.
+        """
+        if "host" in self.configuration:
+            self.configuration["host"] = host
+        else:
+            raise NotImplementedError()
+
+    @property
+    def port(self):
+        """Returns this query runner's configured port.
+        This is used primarily for temporarily swapping endpoints when using SSH tunnels to connect to a data source.
+
+        `BaseQueryRunner`'s naïve implementation supports query runner implementations that store endpoints using `host` and `port`
+        configuration values. If your query runner uses a different schema (e.g. a web address), you should override this property.
+        """
+        if "port" in self.configuration:
+            return self.configuration["port"]
+        else:
+            raise NotImplementedError()
+
+    @port.setter
+    def port(self, port):
+        """Sets this query runner's configured port.
+        This is used primarily for temporarily swapping endpoints when using SSH tunnels to connect to a data source.
+
+        `BaseQueryRunner`'s naïve implementation supports query runner implementations that store endpoints using `host` and `port`
+        configuration values. If your query runner uses a different schema (e.g. a web address), you should override this property.
+        """
+        if "port" in self.configuration:
+            self.configuration["port"] = port
+        else:
+            raise NotImplementedError()
+
     @classmethod
     def configuration_schema(cls):
         return {}
@@ -127,7 +182,7 @@ def to_dict(cls):
             "name": cls.name(),
             "type": cls.type(),
             "configuration_schema": cls.configuration_schema(),
-            **({ "deprecated": True } if cls.deprecated else {})
+            **({"deprecated": True} if cls.deprecated else {}),
         }
 
 
@@ -303,3 +358,46 @@ def guess_type_from_string(string_value):
         pass
 
     return TYPE_STRING
+
+
+def with_ssh_tunnel(query_runner, details):
+    def tunnel(f):
+        @wraps(f)
+        def wrapper(*args, **kwargs):
+            try:
+                remote_host, remote_port = query_runner.host, query_runner.port
+            except NotImplementedError:
+                raise NotImplementedError(
+                    "SSH tunneling is not implemented for this query runner yet."
+                )
+
+            stack = ExitStack()
+            try:
+                bastion_address = (details["ssh_host"], details.get("ssh_port", 22))
+                remote_address = (remote_host, remote_port)
+                auth = {
+                    "ssh_username": details["ssh_username"],
+                    **settings.dynamic_settings.ssh_tunnel_auth(),
+                }
+                server = stack.enter_context(
+                    open_tunnel(
+                        bastion_address, remote_bind_address=remote_address, **auth
+                    )
+                )
+            except Exception as error:
+                raise type(error)("SSH tunnel: {}".format(str(error)))
+
+            with stack:
+                try:
+                    query_runner.host, query_runner.port = server.local_bind_address
+                    result = f(*args, **kwargs)
+                finally:
+                    query_runner.host, query_runner.port = remote_host, remote_port
+
+            return result
+
+        return wrapper
+
+    query_runner.run_query = tunnel(query_runner.run_query)
+
+    return query_runner
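Taken together: the decorated `run_query` opens the tunnel, points the runner at the local forwarded port for the duration of the call, and restores the original endpoint in a `finally` block. A rough usage sketch with made-up connection details (it won't reach a real bastion as written; `get_query_runner` and `with_ssh_tunnel` are the functions above):

```python
from redash.query_runner import get_query_runner, with_ssh_tunnel

# Illustrative configuration; real values come from the data source's options.
runner = get_query_runner("pg", {"host": "10.0.0.12", "port": 5432, "dbname": "reports"})
runner = with_ssh_tunnel(runner, {"ssh_host": "bastion.example.com", "ssh_username": "redash"})

# run_query is now wrapped: for the duration of the call the runner connects to the
# 127.0.0.1:<port> address chosen by sshtunnel, while the bastion forwards traffic
# to 10.0.0.12:5432; afterwards host/port are restored.
data, error = runner.run_query("SELECT 1", None)
```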
diff --git a/redash/query_runner/clickhouse.py b/redash/query_runner/clickhouse.py
index b76f812b92..c2a1c6ebb5 100644
--- a/redash/query_runner/clickhouse.py
+++ b/redash/query_runner/clickhouse.py
@@ -1,5 +1,6 @@
 import logging
 import re
+from urllib.parse import urlparse
 
 import requests
 
@@ -42,6 +43,30 @@ def configuration_schema(cls):
     def type(cls):
         return "clickhouse"
 
+    @property
+    def _url(self):
+        return urlparse(self.configuration["url"])
+
+    @_url.setter
+    def _url(self, url):
+        self.configuration["url"] = url.geturl()
+
+    @property
+    def host(self):
+        return self._url.hostname
+
+    @host.setter
+    def host(self, host):
+        self._url = self._url._replace(netloc="{}:{}".format(host, self._url.port))
+
+    @property
+    def port(self):
+        return self._url.port
+
+    @port.setter
+    def port(self, port):
+        self._url = self._url._replace(netloc="{}:{}".format(self._url.hostname, port))
+
     def _get_tables(self, schema):
         query = "SELECT database, table, name FROM system.columns WHERE database NOT IN ('system')"
diff --git a/redash/serializers/__init__.py b/redash/serializers/__init__.py
index 992d7d6b56..b0782ac319 100644
--- a/redash/serializers/__init__.py
+++ b/redash/serializers/__init__.py
@@ -284,7 +284,7 @@ def serialize_job(job):
         updated_at = 0
 
     status = STATUSES[job_status]
-    query_result_id = None
+    result = query_result_id = None
 
     if job.is_cancelled:
         error = "Query cancelled by user."
@@ -292,9 +292,12 @@ def serialize_job(job):
     elif isinstance(job.result, Exception):
         error = str(job.result)
         status = 4
+    elif isinstance(job.result, dict) and "error" in job.result:
+        error = job.result["error"]
+        status = 4
     else:
         error = ""
-        query_result_id = job.result
+        result = query_result_id = job.result
 
     return {
         "job": {
@@ -302,6 +305,7 @@ def serialize_job(job):
             "updated_at": updated_at,
             "status": status,
             "error": error,
+            "result": result,
             "query_result_id": query_result_id,
         }
     }
diff --git a/redash/settings/dynamic_settings.py b/redash/settings/dynamic_settings.py
index 33f3f40277..145308d356 100644
--- a/redash/settings/dynamic_settings.py
+++ b/redash/settings/dynamic_settings.py
@@ -25,3 +25,15 @@ def periodic_jobs():
 # This provides the ability to override the way we store QueryResult's data column.
 # Reference implementation: redash.models.DBPersistence
 QueryResultPersistence = None
+
+
+def ssh_tunnel_auth():
+    """
+    To enable data source connections via SSH tunnels, provide your SSH authentication
+    key here. Return a string pointing at your **private** key's path (the public key
+    will be extracted from it), or a `paramiko.pkey.PKey` instance holding your **public** key.
+    """
+    return {
+        # 'ssh_pkey': 'path_to_private_key',  # or instance of `paramiko.pkey.PKey`
+        # 'ssh_private_key_password': 'optional_passphrase_of_private_key',
+    }
\ No newline at end of file
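A deployment enables tunnel authentication by overriding this dynamic setting; the returned dict is merged into the keyword arguments passed to `sshtunnel.open_tunnel` in `with_ssh_tunnel`. A minimal sketch — the path and passphrase are placeholders:

```python
# Deployment override of redash/settings/dynamic_settings.py — illustrative values only.
def ssh_tunnel_auth():
    return {
        "ssh_pkey": "/home/redash/.ssh/id_rsa",       # path to the private key file
        "ssh_private_key_password": "hunter2",        # only needed if the key is encrypted
    }
```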
+ """ + return { + # 'ssh_pkey': 'path_to_private_key', # or instance of `paramiko.pkey.PKey` + # 'ssh_private_key_password': 'optional_passphrase_of_private_key', + } \ No newline at end of file diff --git a/redash/tasks/general.py b/redash/tasks/general.py index d120576a4f..7b2c7287dd 100644 --- a/redash/tasks/general.py +++ b/redash/tasks/general.py @@ -9,6 +9,8 @@ from redash.models import users from redash.version_check import run_version_check from redash.worker import job, get_job_logger +from redash.tasks.worker import Queue +from redash.query_runner import NotSupported logger = get_job_logger(__name__) @@ -63,6 +65,33 @@ def send_mail(to, subject, html, text): logger.exception("Failed sending message: %s", message.subject) +@job("queries", timeout=30, ttl=90) +def test_connection(data_source_id): + try: + data_source = models.DataSource.get_by_id(data_source_id) + data_source.query_runner.test_connection() + except Exception as e: + return e + else: + return True + + +@job("schemas", queue_class=Queue, at_front=True, timeout=300, ttl=90) +def get_schema(data_source_id, refresh): + try: + data_source = models.DataSource.get_by_id(data_source_id) + return data_source.get_schema(refresh) + except NotSupported: + return { + "error": { + "code": 1, + "message": "Data source type does not support retrieving schema", + } + } + except Exception: + return {"error": {"code": 2, "message": "Error retrieving schema."}} + + def sync_user_details(): users.sync_last_active_at() diff --git a/requirements.txt b/requirements.txt index a5df9e404d..77ffbe2039 100644 --- a/requirements.txt +++ b/requirements.txt @@ -55,6 +55,7 @@ maxminddb-geolite2==2018.703 pypd==1.1.0 disposable-email-domains>=0.0.52 gevent==1.4.0 +sshtunnel==0.1.5 supervisor==4.1.0 supervisor_checks==0.8.1 werkzeug==0.16.1