diff --git a/.gitignore b/.gitignore index 4c93a874bad0e..eb93aa3c20e7e 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,8 @@ dist caravel.egg-info/ app.db *.bak +.idea +*.sqlite # Node.js, webpack artifacts *.entry.js diff --git a/caravel/bin/caravel b/caravel/bin/caravel index 3ae2b68516da3..51d0f63c4e1fa 100755 --- a/caravel/bin/caravel +++ b/caravel/bin/caravel @@ -5,9 +5,10 @@ from __future__ import print_function from __future__ import unicode_literals import logging +import celery +from celery.bin import worker as celery_worker from datetime import datetime from subprocess import Popen -import textwrap from flask_migrate import MigrateCommand from flask_script import Manager @@ -127,5 +128,24 @@ def refresh_druid(): session.commit() +@manager.command +def worker(): + """Starts a Caravel worker for async SQL query execution.""" + # celery -A tasks worker --loglevel=info + print("Starting SQL Celery worker.") + if config.get('CELERY_CONFIG'): + print("Celery broker url: ") + print(config.get('CELERY_CONFIG').BROKER_URL) + + application = celery.current_app._get_current_object() + c_worker = celery_worker.worker(app=application) + options = { + 'broker': config.get('CELERY_CONFIG').BROKER_URL, + 'loglevel': 'INFO', + 'traceback': True, + } + c_worker.run(**options) + + if __name__ == "__main__": manager.run() diff --git a/caravel/config.py b/caravel/config.py index 79c87d0a9215a..f922b9db02aaa 100644 --- a/caravel/config.py +++ b/caravel/config.py @@ -179,7 +179,22 @@ # Set this API key to enable Mapbox visualizations MAPBOX_API_KEY = "" +# Maximum number of rows returned in the SQL editor +SQL_MAX_ROW = 1000 +# Default celery config is to use SQLA as a broker, in a production setting +# you'll want to use a proper broker as specified here: +# http://docs.celeryproject.org/en/latest/getting-started/brokers/index.html +""" +# Example: +class CeleryConfig(object): + BROKER_URL = 'sqla+sqlite:///celerydb.sqlite' + CELERY_IMPORTS = ('caravel.tasks', ) + 
CELERY_RESULT_BACKEND = 'db+sqlite:///celery_results.sqlite' + CELERY_ANNOTATIONS = {'tasks.add': {'rate_limit': '10/s'}} +CELERY_CONFIG = CeleryConfig +""" +CELERY_CONFIG = None try: from caravel_config import * # noqa @@ -188,3 +203,4 @@ if not CACHE_DEFAULT_TIMEOUT: CACHE_DEFAULT_TIMEOUT = CACHE_CONFIG.get('CACHE_DEFAULT_TIMEOUT') + diff --git a/caravel/migrations/versions/ad82a75afd82_add_query_model.py b/caravel/migrations/versions/ad82a75afd82_add_query_model.py new file mode 100644 index 0000000000000..4794f416de07f --- /dev/null +++ b/caravel/migrations/versions/ad82a75afd82_add_query_model.py @@ -0,0 +1,39 @@ +"""Update models to support storing the queries. + +Revision ID: ad82a75afd82 +Revises: f162a1dea4c4 +Create Date: 2016-07-25 17:48:12.771103 + +""" + +# revision identifiers, used by Alembic. +revision = 'ad82a75afd82' +down_revision = 'f162a1dea4c4' + +from alembic import op +import sqlalchemy as sa + +def upgrade(): + op.create_table('query', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('database_id', sa.Integer(), nullable=False), + sa.Column('tmp_table_name', sa.String(length=64), nullable=True), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.Column('status', sa.String(length=16), nullable=True), + sa.Column('name', sa.String(length=64), nullable=True), + sa.Column('sql', sa.Text, nullable=True), + sa.Column('limit', sa.Integer(), nullable=True), + sa.Column('progress', sa.Integer(), nullable=True), + sa.Column('start_time', sa.DateTime(), nullable=True), + sa.Column('end_time', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['database_id'], [u'dbs.id'], ), + sa.ForeignKeyConstraint(['user_id'], [u'ab_user.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.add_column('dbs', sa.Column('select_as_create_table_as', sa.Boolean(), + nullable=True)) + + +def downgrade(): + op.drop_table('query') + op.drop_column('dbs', 'select_as_create_table_as') diff --git a/caravel/models.py b/caravel/models.py index 
5c46293d6b809..6dbeb53c6a313 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -32,8 +32,9 @@ from pydruid.utils.having import Aggregation from six import string_types from sqlalchemy import ( - Column, Integer, String, ForeignKey, Text, Boolean, DateTime, Date, - Table, create_engine, MetaData, desc, asc, select, and_, func) + Column, Integer, String, ForeignKey, Text, Boolean, DateTime, Date, Table, + create_engine, MetaData, desc, asc, select, and_, func +) from sqlalchemy.engine import reflection from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import relationship @@ -378,6 +379,7 @@ class Database(Model, AuditMixinNullable): sqlalchemy_uri = Column(String(1024)) password = Column(EncryptedType(String(1024), config.get('SECRET_KEY'))) cache_timeout = Column(Integer) + select_as_create_table_as = Column(Boolean, default=True) extra = Column(Text, default=textwrap.dedent("""\ { "metadata_params": {}, @@ -1701,3 +1703,39 @@ class FavStar(Model): class_name = Column(String(50)) obj_id = Column(Integer) dttm = Column(DateTime, default=func.now()) + + +class QueryStatus: + SCHEDULED = 'SCHEDULED' + CANCELLED = 'CANCELLED' + IN_PROGRESS = 'IN_PROGRESS' + FINISHED = 'FINISHED' + TIMED_OUT = 'TIMED_OUT' + FAILED = 'FAILED' + + +class Query(Model): + + """ORM model for SQL query""" + + __tablename__ = 'query' + id = Column(Integer, primary_key=True) + + database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False) + + # Store the tmp table into the DB only if the user asks for it. 
+ tmp_table_name = Column(String(64)) + user_id = Column(Integer, ForeignKey('ab_user.id'), nullable=True) + + # models.QueryStatus + status = Column(String(16)) + + name = Column(String(64)) + sql = Column(Text) + # Could be configured in the caravel config + limit = Column(Integer) + + # 1..100 + progress = Column(Integer) + start_time = Column(DateTime) + end_time = Column(DateTime) diff --git a/caravel/tasks.py b/caravel/tasks.py new file mode 100644 index 0000000000000..c48e66997456a --- /dev/null +++ b/caravel/tasks.py @@ -0,0 +1,219 @@ +import celery +from caravel import models, app, utils +from datetime import datetime +import logging +from sqlalchemy import create_engine, select, text +from sqlalchemy.orm import scoped_session, sessionmaker +from sqlalchemy.sql.expression import TextAsFrom +import sqlparse +import pandas as pd + +celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG')) + + +def is_query_select(sql): + try: + return sqlparse.parse(sql)[0].get_type() == 'SELECT' + # Capture sqlparse exceptions, worker shouldn't fail here. + except Exception: + # TODO(bkyryliuk): add logging here. + return False + + +# if sqlparse provides the stream of tokens but don't provide the API +# to access the table names, more on it: +# https://groups.google.com/forum/#!topic/sqlparse/sL2aAi6dSJU +# https://github.com/andialbrecht/sqlparse/blob/master/examples/ +# extract_table_names.py +# +# Another approach would be to run the EXPLAIN on the sql statement: +# https://prestodb.io/docs/current/sql/explain.html +# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Explain +def get_tables(): + """Retrieves the query names from the query.""" + # TODO(bkyryliuk): implement parsing the sql statement. + pass + + +def add_limit_to_the_query(sql, limit, eng): + # Treat as single sql statement in case of failure. 
+ sql_statements = [sql] + try: + sql_statements = [s for s in sqlparse.split(sql) if s] + except Exception as e: + logging.info( + "Statement " + sql + "failed to be transformed to have the limit " + "with the exception" + e.message) + return sql + if len(sql_statements) == 1 and is_query_select(sql): + qry = select('*').select_from( + TextAsFrom(text(sql_statements[0]), ['*']).alias( + 'inner_qry')).limit(limit) + sql_statement = str(qry.compile( + eng, compile_kwargs={"literal_binds": True})) + return sql_statement + return sql + + +# create table works only for the single statement. +def create_table_as(sql, table_name, override=False): + """Reformats the query into the create table as query. + + Works only for the single select SQL statements, in all other cases + the sql query is not modified. + :param sql: string, sql query that will be executed + :param table_name: string, will contain the results of the query execution + :param override, boolean, table table_name will be dropped if true + :return: string, create table as query + """ + # TODO(bkyryliuk): drop table if allowed, check the namespace and + # the permissions. + # Treat as single sql statement in case of failure. + sql_statements = [sql] + try: + # Filter out empty statements. + sql_statements = [s for s in sqlparse.split(sql) if s] + except Exception as e: + logging.info( + "Statement " + sql + "failed to be transformed as create table as " + "with the exception" + e.message) + return sql + if len(sql_statements) == 1 and is_query_select(sql): + updated_sql = '' + # TODO(bkyryliuk): use sqlalchemy statements for the + # the drop and create operations. 
+ if override: + updated_sql = 'DROP TABLE IF EXISTS {};\n'.format(table_name) + updated_sql += "CREATE TABLE %s AS %s" % ( + table_name, sql_statements[0]) + return updated_sql + return sql + + +def get_session(): + """Creates new SQLAlchemy scoped_session.""" + engine = create_engine( + app.config.get('SQLALCHEMY_DATABASE_URI'), convert_unicode=True) + return scoped_session(sessionmaker( + autocommit=False, autoflush=False, bind=engine)) + + +@celery_app.task +def get_sql_results(database_id, sql, user_id, tmp_table_name="", schema=None): + """Executes the sql query returns the results. + + :param database_id: integer + :param sql: string, query that will be executed + :param user_id: integer + :param tmp_table_name: name of the table for CTA + :param schema: string, name of the schema (used in presto) + :return: dataframe, query result + """ + # Create a separate session, reusing the db.session leads to the + # concurrency issues. + session = get_session() + try: + db_to_query = ( + session.query(models.Database).filter_by(id=database_id).first() + ) + except Exception as e: + return { + 'error': utils.error_msg_from_exception(e), + 'success': False, + } + if not db_to_query: + return { + 'error': "Database with id {0} is missing.".format(database_id), + 'success': False, + } + + # TODO(bkyryliuk): provide a way for the user to name the query. 
+ # TODO(bkyryliuk): run explain query to derive the tables and fill in the + # table_ids + # TODO(bkyryliuk): check the user permissions + # TODO(bkyryliuk): store the tab name in the query model + limit = app.config.get('SQL_MAX_ROW', None) + start_time = datetime.now() + if not tmp_table_name: + tmp_table_name = 'tmp.{}_table_{}'.format(user_id, start_time) + query = models.Query( + user_id=user_id, + database_id=database_id, + limit=limit, + name='{}'.format(start_time), + sql=sql, + start_time=start_time, + tmp_table_name=tmp_table_name, + status=models.QueryStatus.IN_PROGRESS, + ) + session.add(query) + session.commit() + query_result = get_sql_results_as_dict( + db_to_query, sql, query.tmp_table_name, schema=schema) + query.end_time = datetime.now() + if query_result['success']: + query.status = models.QueryStatus.FINISHED + else: + query.status = models.QueryStatus.FAILED + session.commit() + # TODO(bkyryliuk): return the tmp table / query_id + return query_result + + +# TODO(bkyryliuk): merge the changes made in the carapal first +# before merging this PR. +def get_sql_results_as_dict(db_to_query, sql, tmp_table_name, schema=None): + """Get the SQL query results from the give session and db connection. + + :param sql: string, query that will be executed + :param db_to_query: models.Database to query, cannot be None + :param tmp_table_name: name of the table for CTA + :param schema: string, name of the schema (used in presto) + :return: (dataframe, boolean), results and the status + """ + eng = db_to_query.get_sqla_engine(schema=schema) + sql = sql.strip().strip(';') + # TODO(bkyryliuk): fix this case for multiple statements + if app.config.get('SQL_MAX_ROW'): + sql = add_limit_to_the_query( + sql, app.config.get("SQL_MAX_ROW"), eng) + + cta_used = False + if (app.config.get('SQL_SELECT_AS_CTA') and + db_to_query.select_as_create_table_as and is_query_select(sql)): + # TODO(bkyryliuk): figure out if the query is select query. 
+ sql = create_table_as(sql, tmp_table_name) + cta_used = True + + if cta_used: + try: + eng.execute(sql) + return { + 'tmp_table': tmp_table_name, + 'success': True, + } + except Exception as e: + return { + 'error': utils.error_msg_from_exception(e), + 'success': False, + } + + # otherwise run regular SQL query. + # TODO(bkyryliuk): rewrite into eng.execute as queries different from + # select should be permitted too. + try: + df = db_to_query.get_df(sql, schema) + df = df.fillna(0) + return { + 'columns': [c for c in df.columns], + 'data': df.to_dict(orient='records'), + 'success': True, + } + + except Exception as e: + return { + 'error': utils.error_msg_from_exception(e), + 'success': False, + } + + diff --git a/caravel/views.py b/caravel/views.py index da306510e64b3..b7a206b2c660c 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -32,7 +32,9 @@ from wtforms.validators import ValidationError import caravel -from caravel import appbuilder, db, models, viz, utils, app, sm, ascii_art +from caravel import ( + appbuilder, db, models, viz, utils, app, sm, ascii_art, tasks +) config = app.config log_this = models.Log.log_this @@ -1060,7 +1062,7 @@ def activity_per_day(self): @expose("/tables//") def tables(self, db_id, schema): """endpoint to power the calendar heatmap on the welcome page""" - schema = None if schema == 'null' else schema + schema = None if schema in ('null', 'undefined') else schema database = ( db.session .query(models.Database) @@ -1229,7 +1231,7 @@ def sql(self, database_id): @expose("/table////") @log_this def table(self, database_id, table_name, schema): - schema = None if schema == 'null' else schema + schema = None if schema in ('null', 'undefined') else schema mydb = db.session.query(models.Database).filter_by(id=database_id).one() cols = [] t = mydb.get_columns(table_name, schema) @@ -1325,12 +1327,9 @@ def theme(self): @log_this def sql_json(self): """Runs arbitrary sql and returns and json""" - session = db.session() - limit = 1000 
sql = request.form.get('sql') database_id = request.form.get('database_id') schema = request.form.get('schema') - mydb = session.query(models.Database).filter_by(id=database_id).first() if ( not self.can_access( @@ -1338,41 +1337,19 @@ def sql_json(self): raise utils.CaravelSecurityException(_( "This view requires the `all_datasource_access` permission")) - error_msg = "" - if not mydb: - error_msg = "The database selected doesn't seem to exist" - else: - eng = mydb.get_sqla_engine() - if limit: - sql = sql.strip().strip(';') - qry = ( - select('*') - .select_from(TextAsFrom(text(sql), ['*']).alias('inner_qry')) - .limit(limit) - ) - sql = str(qry.compile(eng, compile_kwargs={"literal_binds": True})) - try: - df = mydb.get_df(sql, schema) - df = df.fillna(0) # TODO make sure NULL - except Exception as e: - error_msg = utils.error_msg_from_exception(e) - - session.commit() - if error_msg: - return Response( - json.dumps({ - 'error': error_msg, - }), - status=500, - mimetype="application/json") - else: - data = { - 'columns': [c for c in df.columns], - 'data': df.to_dict(orient='records'), - 'ydata_tpe.to_dict': { - k: '{}'.format(v) for k, v in df.dtypes.to_dict().items()}, - } - return json.dumps(data, default=utils.json_int_dttm_ser, allow_nan=False) + data = tasks.get_sql_results(database_id, sql, g.user.get_id(), + schema=schema) + if 'error' in data: + return Response( + json.dumps(data), + status=500, + mimetype="application/json") + if 'tmp_table' in data: + # TODO(bkyryliuk): add query id to the response and implement the + # endpoint to poll the status and results. 
+ return None + return json.dumps( + data, default=utils.json_int_dttm_ser, allow_nan=False) @has_access @expose("/refresh_datasources/") diff --git a/caravel/viz.py b/caravel/viz.py index 13449b720dd3c..83df0c4729fd0 100755 --- a/caravel/viz.py +++ b/caravel/viz.py @@ -284,7 +284,8 @@ def get_json(self): cached_data = cached_data.decode('utf-8') payload = json.loads(cached_data) except Exception as e: - logging.error("Error reading cache") + logging.error("Error reading cache: " + + utils.error_msg_from_exception(e)) payload = None logging.info("Serving from cache") diff --git a/run_tests.sh b/run_tests.sh index 37ea9249bbbff..faf19251a9ba4 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -1,9 +1,12 @@ #!/usr/bin/env bash echo $DB rm /tmp/caravel_unittests.db +rm /tmp/celerydb.sqlite +rm /tmp/celery_results.sqlite rm -f .coverage export CARAVEL_CONFIG=tests.caravel_test_config set -e caravel/bin/caravel db upgrade caravel/bin/caravel version -v python setup.py nosetests + diff --git a/setup.py b/setup.py index d02d1499f1634..ceb266d9fef18 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ zip_safe=False, scripts=['caravel/bin/caravel'], install_requires=[ + 'celery==3.1.23', 'cryptography==1.4', 'flask-appbuilder==1.8.1', 'flask-cache==0.13.1', diff --git a/tests/caravel_test_config.py b/tests/caravel_test_config.py index a7de4569499e7..03f6b5aede723 100644 --- a/tests/caravel_test_config.py +++ b/tests/caravel_test_config.py @@ -2,11 +2,27 @@ AUTH_USER_REGISTRATION_ROLE = 'alpha' SQLALCHEMY_DATABASE_URI = 'sqlite:////tmp/caravel_unittests.db' +# MySQL connection string for unit tests: +# SQLALCHEMY_DATABASE_URI = 'mysql://root:@localhost/caravel_db' DEBUG = True CARAVEL_WEBSERVER_PORT = 8081 # Allowing SQLALCHEMY_DATABASE_URI to be defined as an env var for # continuous integration if 'CARAVEL__SQLALCHEMY_DATABASE_URI' in os.environ: - SQLALCHEMY_DATABASE_URI = os.environ.get('CARAVEL__SQLALCHEMY_DATABASE_URI') + SQLALCHEMY_DATABASE_URI = os.environ.get( + 
'CARAVEL__SQLALCHEMY_DATABASE_URI') + +SQL_CELERY_DB_FILE_PATH = '/tmp/celerydb.sqlite' +SQL_CELERY_RESULTS_DB_FILE_PATH = '/tmp/celery_results.sqlite' +SQL_SELECT_AS_CTA = True + + +class CeleryConfig(object): + BROKER_URL = 'sqla+sqlite:///' + SQL_CELERY_DB_FILE_PATH + CELERY_IMPORTS = ('caravel.tasks', ) + CELERY_RESULT_BACKEND = 'db+sqlite:///' + SQL_CELERY_RESULTS_DB_FILE_PATH + CELERY_ANNOTATIONS = {'tasks.add': {'rate_limit': '10/s'}} + CONCURRENCY = 1 +CELERY_CONFIG = CeleryConfig diff --git a/tests/celery_tests.py b/tests/celery_tests.py new file mode 100644 index 0000000000000..e88ae0fca1c5b --- /dev/null +++ b/tests/celery_tests.py @@ -0,0 +1,399 @@ +"""Unit tests for Caravel Celery worker""" +import datetime +import imp +import subprocess +import os +import pandas as pd +import time +import unittest + +import caravel +from caravel import app, appbuilder, db, models, tasks, utils + + +class CeleryConfig(object): + BROKER_URL = 'sqla+sqlite:////tmp/celerydb.sqlite' + CELERY_IMPORTS = ('caravel.tasks',) + CELERY_RESULT_BACKEND = 'db+sqlite:////tmp/celery_results.sqlite' + CELERY_ANNOTATIONS = {'tasks.add': {'rate_limit': '10/s'}} +app.config['CELERY_CONFIG'] = CeleryConfig + +BASE_DIR = app.config.get('BASE_DIR') +cli = imp.load_source('cli', BASE_DIR + '/bin/caravel') + + +class UtilityFunctionTests(unittest.TestCase): + def test_create_table_as(self): + select_query = "SELECT * FROM outer_space;" + updated_select_query = tasks.create_table_as(select_query, "tmp") + self.assertEqual( + "CREATE TABLE tmp AS SELECT * FROM outer_space;", + updated_select_query) + + updated_select_query_with_drop = tasks.create_table_as( + select_query, "tmp", override=True) + self.assertEqual( + "DROP TABLE IF EXISTS tmp;\n" + "CREATE TABLE tmp AS SELECT * FROM outer_space;", + updated_select_query_with_drop) + + select_query_no_semicolon = "SELECT * FROM outer_space" + updated_select_query_no_semicolon = tasks.create_table_as( + select_query_no_semicolon, "tmp") + 
self.assertEqual( + "CREATE TABLE tmp AS SELECT * FROM outer_space", + updated_select_query_no_semicolon) + + incorrect_query = "SMTH WRONG SELECT * FROM outer_space" + updated_incorrect_query = tasks.create_table_as(incorrect_query, "tmp") + self.assertEqual(incorrect_query, updated_incorrect_query) + + insert_query = "INSERT INTO stomach VALUES (beer, chips);" + updated_insert_query = tasks.create_table_as(insert_query, "tmp") + self.assertEqual(insert_query, updated_insert_query) + + multi_line_query = ( + "SELECT * FROM planets WHERE\n" + "Luke_Father = 'Darth Vader';") + updated_multi_line_query = tasks.create_table_as( + multi_line_query, "tmp") + expected_updated_multi_line_query = ( + "CREATE TABLE tmp AS SELECT * FROM planets WHERE\n" + "Luke_Father = 'Darth Vader';") + self.assertEqual( + expected_updated_multi_line_query, + updated_multi_line_query) + + updated_multi_line_query_with_drop = tasks.create_table_as( + multi_line_query, "tmp", override=True) + expected_updated_multi_line_query_with_drop = ( + "DROP TABLE IF EXISTS tmp;\n" + "CREATE TABLE tmp AS SELECT * FROM planets WHERE\n" + "Luke_Father = 'Darth Vader';") + self.assertEqual( + expected_updated_multi_line_query_with_drop, + updated_multi_line_query_with_drop) + + delete_query = "DELETE FROM planet WHERE name = 'Earth'" + updated_delete_query = tasks.create_table_as(delete_query, "tmp") + self.assertEqual(delete_query, updated_delete_query) + + create_table_as = ( + "CREATE TABLE pleasure AS SELECT chocolate FROM lindt_store;\n") + updated_create_table_as = tasks.create_table_as( + create_table_as, "tmp") + self.assertEqual(create_table_as, updated_create_table_as) + + sql_procedure = ( + "CREATE PROCEDURE MyMarriage\n " + "BrideGroom Male (25) ,\n " + "Bride Female(20) AS\n " + "BEGIN\n " + "SELECT Bride FROM ukraine_ Brides\n " + "WHERE\n " + "FatherInLaw = 'Millionaire' AND Count(Car) > 20\n" + " AND HouseStatus ='ThreeStoreyed'\n" + " AND BrideEduStatus IN " + "(B.TECH ,BE ,Degree ,MCA 
,MiBA)\n " + "AND Having Brothers= Null AND Sisters =Null" + ) + updated_sql_procedure = tasks.create_table_as(sql_procedure, "tmp") + self.assertEqual(sql_procedure, updated_sql_procedure) + + multiple_statements = """ + DROP HUSBAND; + SELECT * FROM politicians WHERE clue > 0; + INSERT INTO MyCarShed VALUES('BUGATTI'); + SELECT standard_disclaimer, witty_remark FROM company_requirements; + select count(*) from developer_brain; + """ + updated_multiple_statements = tasks.create_table_as( + multiple_statements, "tmp") + self.assertEqual(multiple_statements, updated_multiple_statements) + + +class CeleryTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(CeleryTestCase, self).__init__(*args, **kwargs) + self.client = app.test_client() + utils.init(caravel) + admin = appbuilder.sm.find_user('admin') + if not admin: + appbuilder.sm.add_user( + 'admin', 'admin', ' user', 'admin@fab.org', + appbuilder.sm.find_role('Admin'), + password='general') + utils.init(caravel) + + @classmethod + def setUpClass(cls): + try: + os.remove(app.config.get('SQL_CELERY_DB_FILE_PATH')) + except OSError: + pass + try: + os.remove(app.config.get('SQL_CELERY_RESULTS_DB_FILE_PATH')) + except OSError: + pass + + worker_command = BASE_DIR + '/bin/caravel worker' + subprocess.Popen( + worker_command, shell=True, stdout=subprocess.PIPE) + cli.load_examples(load_test_data=True) + + @classmethod + def tearDownClass(cls): + subprocess.call( + "ps auxww | grep 'celeryd' | awk '{print $2}' | " + "xargs kill -9", + shell=True + ) + subprocess.call( + "ps auxww | grep 'caravel worker' | awk '{print $2}' | " + "xargs kill -9", + shell=True + ) + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_add_limit_to_the_query(self): + query_session = tasks.get_session() + db_to_query = query_session.query(models.Database).filter_by( + id=1).first() + eng = db_to_query.get_sqla_engine() + + select_query = "SELECT * FROM outer_space;" + updated_select_query = 
tasks.add_limit_to_the_query( + select_query, 100, eng) + # Different DB engines have their own spacing while compiling + # the queries, that's why ' '.join(query.split()) is used. + # In addition some of the engines do not include OFFSET 0. + self.assertTrue( + "SELECT * FROM (SELECT * FROM outer_space;) AS inner_qry " + "LIMIT 100" in ' '.join(updated_select_query.split()) + ) + + select_query_no_semicolon = "SELECT * FROM outer_space" + updated_select_query_no_semicolon = tasks.add_limit_to_the_query( + select_query_no_semicolon, 100, eng) + self.assertTrue( + "SELECT * FROM (SELECT * FROM outer_space) AS inner_qry " + "LIMIT 100" in + ' '.join(updated_select_query_no_semicolon.split()) + ) + + incorrect_query = "SMTH WRONG SELECT * FROM outer_space" + updated_incorrect_query = tasks.add_limit_to_the_query( + incorrect_query, 100, eng) + self.assertEqual(incorrect_query, updated_incorrect_query) + + insert_query = "INSERT INTO stomach VALUES (beer, chips);" + updated_insert_query = tasks.add_limit_to_the_query( + insert_query, 100, eng) + self.assertEqual(insert_query, updated_insert_query) + + multi_line_query = ( + "SELECT * FROM planets WHERE\n Luke_Father = 'Darth Vader';" + ) + updated_multi_line_query = tasks.add_limit_to_the_query( + multi_line_query, 100, eng) + self.assertTrue( + "SELECT * FROM (SELECT * FROM planets WHERE " + "Luke_Father = 'Darth Vader';) AS inner_qry LIMIT 100" in + ' '.join(updated_multi_line_query.split()) + ) + + delete_query = "DELETE FROM planet WHERE name = 'Earth'" + updated_delete_query = tasks.add_limit_to_the_query( + delete_query, 100, eng) + self.assertEqual(delete_query, updated_delete_query) + + create_table_as = ( + "CREATE TABLE pleasure AS SELECT chocolate FROM lindt_store;\n") + updated_create_table_as = tasks.add_limit_to_the_query( + create_table_as, 100, eng) + self.assertEqual(create_table_as, updated_create_table_as) + + sql_procedure = ( + "CREATE PROCEDURE MyMarriage\n " + "BrideGroom Male (25) ,\n " + "Bride 
Female(20) AS\n " + "BEGIN\n " + "SELECT Bride FROM ukraine_ Brides\n " + "WHERE\n " + "FatherInLaw = 'Millionaire' AND Count(Car) > 20\n" + " AND HouseStatus ='ThreeStoreyed'\n" + " AND BrideEduStatus IN " + "(B.TECH ,BE ,Degree ,MCA ,MiBA)\n " + "AND Having Brothers= Null AND Sisters = Null" + ) + updated_sql_procedure = tasks.add_limit_to_the_query( + sql_procedure, 100, eng) + self.assertEqual(sql_procedure, updated_sql_procedure) + + def test_run_async_query_delay_get(self): + main_db = db.session.query(models.Database).filter_by( + database_name="main").first() + eng = main_db.get_sqla_engine() + + # Case 1. + # DB #0 doesn't exist. + result1 = tasks.get_sql_results.delay( + 0, 'SELECT * FROM dontexist', 1, tmp_table_name='tmp_1_1').get() + expected_result1 = { + 'error': 'Database with id 0 is missing.', + 'success': False + } + self.assertEqual( + sorted(expected_result1.items()), + sorted(result1.items()) + ) + session1 = db.create_scoped_session() + query1 = session1.query(models.Query).filter_by( + sql='SELECT * FROM dontexist').first() + session1.close() + self.assertIsNone(query1) + + # Case 2. + session2 = db.create_scoped_session() + query2 = session2.query(models.Query).filter_by( + sql='SELECT * FROM dontexist1').first() + self.assertEqual(models.QueryStatus.FAILED, query2.status) + session2.close() + + result2 = tasks.get_sql_results.delay( + 1, 'SELECT * FROM dontexist1', 1, tmp_table_name='tmp_2_1').get() + self.assertTrue('error' in result2) + session2 = db.create_scoped_session() + query2 = session2.query(models.Query).filter_by( + sql='SELECT * FROM dontexist1').first() + self.assertEqual(models.QueryStatus.FAILED, query2.status) + session2.close() + + # Case 3. 
+ where_query = ( + "SELECT name FROM ab_permission WHERE name='can_select_star'") + result3 = tasks.get_sql_results.delay( + 1, where_query, 1, tmp_table_name='tmp_3_1').get() + expected_result3 = { + 'tmp_table': 'tmp_3_1', + 'success': True + } + self.assertEqual( + sorted(expected_result3.items()), + sorted(result3.items()) + ) + session3 = db.create_scoped_session() + query3 = session3.query(models.Query).filter_by( + sql=where_query).first() + session3.close() + df3 = pd.read_sql_query(sql="SELECT * FROM tmp_3_1", con=eng) + data3 = df3.to_dict(orient='records') + self.assertEqual(models.QueryStatus.FINISHED, query3.status) + self.assertEqual([{'name': 'can_select_star'}], data3) + + # Case 4. + result4 = tasks.get_sql_results.delay( + 1, 'SELECT * FROM ab_permission WHERE id=666', 1, + tmp_table_name='tmp_4_1').get() + expected_result4 = { + 'tmp_table': 'tmp_4_1', + 'success': True + } + self.assertEqual( + sorted(expected_result4.items()), + sorted(result4.items()) + ) + session4 = db.create_scoped_session() + query4 = session4.query(models.Query).filter_by( + sql='SELECT * FROM ab_permission WHERE id=666').first() + session4.close() + df4 = pd.read_sql_query(sql="SELECT * FROM tmp_4_1", con=eng) + data4 = df4.to_dict(orient='records') + self.assertEqual(models.QueryStatus.FINISHED, query4.status) + self.assertEqual([], data4) + + # Case 5. + # Return the data directly if DB select_as_create_table_as is False. 
+ main_db.select_as_create_table_as = False + db.session.commit() + result5 = tasks.get_sql_results.delay( + 1, where_query, 1, tmp_table_name='tmp_5_1').get() + expected_result5 = { + 'columns': ['name'], + 'data': [{'name': 'can_select_star'}], + 'success': True + } + self.assertEqual( + sorted(expected_result5.items()), + sorted(result5.items()) + ) + + def test_run_async_query_delay(self): + celery_task1 = tasks.get_sql_results.delay( + 0, 'SELECT * FROM dontexist', 1, tmp_table_name='tmp_1_2') + celery_task2 = tasks.get_sql_results.delay( + 1, 'SELECT * FROM dontexist1', 1, tmp_table_name='tmp_2_2') + where_query = ( + "SELECT name FROM ab_permission WHERE name='can_select_star'") + celery_task3 = tasks.get_sql_results.delay( + 1, where_query, 1, tmp_table_name='tmp_3_2') + celery_task4 = tasks.get_sql_results.delay( + 1, 'SELECT * FROM ab_permission WHERE id=666', 1, + tmp_table_name='tmp_4_2') + + time.sleep(1) + + # DB #0 doesn't exist. + expected_result1 = { + 'error': 'Database with id 0 is missing.', + 'success': False + } + self.assertEqual( + sorted(expected_result1.items()), + sorted(celery_task1.get().items()) + ) + session2 = db.create_scoped_session() + query2 = session2.query(models.Query).filter_by( + sql='SELECT * FROM dontexist1').first() + self.assertEqual(models.QueryStatus.FAILED, query2.status) + self.assertTrue('error' in celery_task2.get()) + expected_result3 = { + 'tmp_table': 'tmp_3_2', + 'success': True + } + self.assertEqual( + sorted(expected_result3.items()), + sorted(celery_task3.get().items()) + ) + expected_result4 = { + 'tmp_table': 'tmp_4_2', + 'success': True + } + self.assertEqual( + sorted(expected_result4.items()), + sorted(celery_task4.get().items()) + ) + + session = db.create_scoped_session() + query1 = session.query(models.Query).filter_by( + sql='SELECT * FROM dontexist').first() + self.assertIsNone(query1) + query2 = session.query(models.Query).filter_by( + sql='SELECT * FROM dontexist1').first() + 
self.assertEqual(models.QueryStatus.FAILED, query2.status) + query3 = session.query(models.Query).filter_by( + sql=where_query).first() + self.assertEqual(models.QueryStatus.FINISHED, query3.status) + query4 = session.query(models.Query).filter_by( + sql='SELECT * FROM ab_permission WHERE id=666').first() + self.assertEqual(models.QueryStatus.FINISHED, query4.status) + session.close() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/core_tests.py b/tests/core_tests.py index d24d5f1207926..08623ab34bb21 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -10,6 +10,7 @@ import imp import os import unittest + from mock import Mock, patch from flask import escape @@ -21,11 +22,15 @@ os.environ['CARAVEL_CONFIG'] = 'tests.caravel_test_config' -app.config['TESTING'] = True +# Disable celery. +app.config['CELERY_CONFIG'] = None app.config['CSRF_ENABLED'] = False +app.config['PUBLIC_ROLE_LIKE_GAMMA'] = True app.config['SECRET_KEY'] = 'thisismyscretkey' +app.config['SQL_SELECT_AS_CTA'] = False +app.config['TESTING'] = True app.config['WTF_CSRF_ENABLED'] = False -app.config['PUBLIC_ROLE_LIKE_GAMMA'] = True + BASE_DIR = app.config.get("BASE_DIR") cli = imp.load_source('cli', BASE_DIR + "/bin/caravel") @@ -41,7 +46,7 @@ def __init__(self, *args, **kwargs): admin = appbuilder.sm.find_user('admin') if not admin: appbuilder.sm.add_user( - 'admin', 'admin',' user', 'admin@fab.org', + 'admin', 'admin', ' user', 'admin@fab.org', appbuilder.sm.find_role('Admin'), password='general') @@ -80,7 +85,7 @@ def setup_public_access_for_dashboard(self, table_name): public_role = appbuilder.sm.find_role('Public') perms = db.session.query(ab_models.PermissionView).all() for perm in perms: - if ( perm.permission.name == 'datasource_access' and + if (perm.permission.name == 'datasource_access' and perm.view_menu and table_name in perm.view_menu.name): appbuilder.sm.add_permission_role(public_role, perm) @@ -88,7 +93,7 @@ def revoke_public_access(self, table_name): 
public_role = appbuilder.sm.find_role('Public') perms = db.session.query(ab_models.PermissionView).all() for perm in perms: - if ( perm.permission.name == 'datasource_access' and + if (perm.permission.name == 'datasource_access' and perm.view_menu and table_name in perm.view_menu.name): appbuilder.sm.del_permission_role(public_role, perm) @@ -96,15 +101,15 @@ def revoke_public_access(self, table_name): class CoreTests(CaravelTestCase): def __init__(self, *args, **kwargs): - # Load examples first, so that we setup proper permission-view relations - # for all example data sources. + # Load examples first, so that we setup proper permission-view + # relations for all example data sources. super(CoreTests, self).__init__(*args, **kwargs) @classmethod def setUpClass(cls): cli.load_examples(load_test_data=True) utils.init(caravel) - cls.table_ids = {tbl.table_name: tbl.id for tbl in ( + cls.table_ids = {tbl.table_name: tbl.id for tbl in ( db.session .query(models.SqlaTable) .all() @@ -127,7 +132,12 @@ def test_save_slice(self): copy_name = "Test Sankey Save" tbl_id = self.table_ids.get('energy_usage') - url = "/caravel/explore/table/{}/?viz_type=sankey&groupby=source&groupby=target&metric=sum__value&row_limit=5000&where=&having=&flt_col_0=source&flt_op_0=in&flt_eq_0=&slice_id={}&slice_name={}&collapsed_fieldsets=&action={}&datasource_name=energy_usage&datasource_id=1&datasource_type=table&previous_viz_type=sankey" + url = ( + "/caravel/explore/table/{}/?viz_type=sankey&groupby=source&" + "groupby=target&metric=sum__value&row_limit=5000&where=&having=&" + "flt_col_0=source&flt_op_0=in&flt_eq_0=&slice_id={}&slice_name={}&" + "collapsed_fieldsets=&action={}&datasource_name=energy_usage&" + "datasource_id=1&datasource_type=table&previous_viz_type=sankey") db.session.commit() resp = self.client.get( @@ -147,7 +157,8 @@ def test_slices(self): for slc in db.session.query(Slc).all(): urls += [ (slc.slice_name, 'slice_url', slc.slice_url), - (slc.slice_name, 'slice_id_endpoint', 
'/caravel/slices/{}'.format(slc.id)), + (slc.slice_name, 'slice_id_endpoint', '/caravel/slices/{}'. + format(slc.id)), (slc.slice_name, 'json_endpoint', slc.viz.json_endpoint), (slc.slice_name, 'csv_endpoint', slc.viz.csv_endpoint), ] @@ -176,13 +187,20 @@ def test_misc(self): def test_shortner(self): self.login(username='admin') - data = "//caravel/explore/table/1/?viz_type=sankey&groupby=source&groupby=target&metric=sum__value&row_limit=5000&where=&having=&flt_col_0=source&flt_op_0=in&flt_eq_0=&slice_id=78&slice_name=Energy+Sankey&collapsed_fieldsets=&action=&datasource_name=energy_usage&datasource_id=1&datasource_type=table&previous_viz_type=sankey" + data = ( + "//caravel/explore/table/1/?viz_type=sankey&groupby=source&" + "groupby=target&metric=sum__value&row_limit=5000&where=&having=&" + "flt_col_0=source&flt_op_0=in&flt_eq_0=&slice_id=78&slice_name=" + "Energy+Sankey&collapsed_fieldsets=&action=&datasource_name=" + "energy_usage&datasource_id=1&datasource_type=table&" + "previous_viz_type=sankey") resp = self.client.post('/r/shortner/', data=data) assert '/r/' in resp.data.decode('utf-8') def test_save_dash(self, username='admin'): self.login(username=username) - dash = db.session.query(models.Dashboard).filter_by(slug="births").first() + dash = db.session.query(models.Dashboard).filter_by( + slug="births").first() positions = [] for i, slc in enumerate(dash.slices): d = { @@ -203,18 +221,24 @@ def test_save_dash(self, username='admin'): def test_add_slices(self, username='admin'): self.login(username=username) - dash = db.session.query(models.Dashboard).filter_by(slug="births").first() - new_slice = db.session.query(models.Slice).filter_by(slice_name="Mapbox Long/Lat").first() - existing_slice = db.session.query(models.Slice).filter_by(slice_name="Name Cloud").first() + dash = db.session.query(models.Dashboard).filter_by( + slug="births").first() + new_slice = db.session.query(models.Slice).filter_by( + slice_name="Mapbox Long/Lat").first() + existing_slice 
= db.session.query(models.Slice).filter_by( + slice_name="Name Cloud").first() data = { - "slice_ids": [new_slice.data["slice_id"], existing_slice.data["slice_id"]] + "slice_ids": [new_slice.data["slice_id"], + existing_slice.data["slice_id"]] } url = '/caravel/add_slices/{}/'.format(dash.id) resp = self.client.post(url, data=dict(data=json.dumps(data))) assert "SLICES ADDED" in resp.data.decode('utf-8') - dash = db.session.query(models.Dashboard).filter_by(slug="births").first() - new_slice = db.session.query(models.Slice).filter_by(slice_name="Mapbox Long/Lat").first() + dash = db.session.query(models.Dashboard).filter_by( + slug="births").first() + new_slice = db.session.query(models.Slice).filter_by( + slice_name="Mapbox Long/Lat").first() assert new_slice in dash.slices assert len(set(dash.slices)) == len(dash.slices) @@ -222,7 +246,10 @@ def test_add_slice_redirect_to_sqla(self, username='admin'): self.login(username=username) url = '/slicemodelview/add' resp = self.client.get(url, follow_redirects=True) - assert "Click on a table link to create a Slice" in resp.data.decode('utf-8') + assert ( + "Click on a table link to create a Slice" in + resp.data.decode('utf-8') + ) def test_add_slice_redirect_to_druid(self, username='admin'): datasource = DruidDatasource( @@ -234,7 +261,10 @@ def test_add_slice_redirect_to_druid(self, username='admin'): self.login(username=username) url = '/slicemodelview/add' resp = self.client.get(url, follow_redirects=True) - assert "Click on a datasource link to create a Slice" in resp.data.decode('utf-8') + assert ( + "Click on a datasource link to create a Slice" + in resp.data.decode('utf-8') + ) db.session.delete(datasource) db.session.commit() @@ -305,7 +335,6 @@ def test_public_user_dashboard_access(self): data = resp.data.decode('utf-8') assert "/caravel/dashboard/world_health/" not in data - def test_only_owners_can_save(self): dash = ( db.session @@ -337,26 +366,26 @@ def test_only_owners_can_save(self): SEGMENT_METADATA = 
[{ "id": "some_id", - "intervals": [ "2013-05-13T00:00:00.000Z/2013-05-14T00:00:00.000Z" ], + "intervals": ["2013-05-13T00:00:00.000Z/2013-05-14T00:00:00.000Z"], "columns": { "__time": { "type": "LONG", "hasMultipleValues": False, - "size": 407240380, "cardinality": None, "errorMessage": None }, + "size": 407240380, "cardinality": None, "errorMessage": None}, "dim1": { "type": "STRING", "hasMultipleValues": False, - "size": 100000, "cardinality": 1944, "errorMessage": None }, + "size": 100000, "cardinality": 1944, "errorMessage": None}, "dim2": { "type": "STRING", "hasMultipleValues": True, - "size": 100000, "cardinality": 1504, "errorMessage": None }, + "size": 100000, "cardinality": 1504, "errorMessage": None}, "metric1": { "type": "FLOAT", "hasMultipleValues": False, - "size": 100000, "cardinality": None, "errorMessage": None } + "size": 100000, "cardinality": None, "errorMessage": None} }, "aggregators": { "metric1": { "type": "longSum", "name": "metric1", - "fieldName": "metric1" } + "fieldName": "metric1"} }, "size": 300000, "numRows": 5000000 @@ -422,7 +451,8 @@ def test_client(self, PyDruid): datasource_id = cluster.datasources[0].id db.session.commit() - resp = self.client.get('/caravel/explore/druid/{}/'.format(datasource_id)) + resp = self.client.get('/caravel/explore/druid/{}/'.format( + datasource_id)) assert "[test_cluster].[test_datasource]" in resp.data.decode('utf-8') nres = [ @@ -434,9 +464,15 @@ def test_client(self, PyDruid): instance.export_pandas.return_value = df instance.query_dict = {} instance.query_builder.last_query.query_dict = {} - resp = 
self.client.get('/caravel/explore/druid/{}/?viz_type=table&granularity=one+day&druid_time_origin=&since=7+days+ago&until=now&row_limit=5000&include_search=false&metrics=count&groupby=name&flt_col_0=dim1&flt_op_0=in&flt_eq_0=&slice_id=&slice_name=&collapsed_fieldsets=&action=&datasource_name=test_datasource&datasource_id={}&datasource_type=druid&previous_viz_type=table&json=true&force=true'.format(datasource_id, datasource_id)) + resp = self.client.get( + '/caravel/explore/druid/{}/?viz_type=table&granularity=one+day&' + 'druid_time_origin=&since=7+days+ago&until=now&row_limit=5000&' + 'include_search=false&metrics=count&groupby=name&flt_col_0=dim1&' + 'flt_op_0=in&flt_eq_0=&slice_id=&slice_name=&collapsed_fieldsets=&' + 'action=&datasource_name=test_datasource&datasource_id={}&' + 'datasource_type=druid&previous_viz_type=table&json=true&' + 'force=true'.format(datasource_id, datasource_id)) assert "Canada" in resp.data.decode('utf-8') - if __name__ == '__main__': unittest.main()