Skip to content

Commit

Permalink
Refactor the explore view (#1252)
Browse files Browse the repository at this point in the history
* Refactor the explore view

* Fixing the tests

* Addressing comments
  • Loading branch information
mistercrunch authored Oct 7, 2016
1 parent b7d1f78 commit f70d301
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 138 deletions.
8 changes: 8 additions & 0 deletions caravel/source_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,11 @@ def register_sources(cls, datasource_config):
for class_name in class_names:
source_class = getattr(module_obj, class_name)
cls.sources[source_class.type] = source_class

@classmethod
def get_datasource(cls, datasource_type, datasource_id, session):
return (
session.query(cls.sources[datasource_type])
.filter_by(id=datasource_id)
.one()
)
2 changes: 1 addition & 1 deletion caravel/templates/caravel/standalone.html
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<html>
<head>
<title>{{viz.token}}</title>
<title>{{ viz.token }}</title>
<link rel="stylesheet" type="text/css" href="/static/assets/node_modules/font-awesome/css/font-awesome.min.css" />
<link rel="stylesheet" type="text/css" href="/static/assets/stylesheets/caravel.css" />
<link rel="stylesheet" type="text/css" href="/static/appbuilder/css/flags/flags16.css" />
Expand Down
143 changes: 76 additions & 67 deletions caravel/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from flask_appbuilder.models.sqla.filters import BaseFilter

from sqlalchemy import create_engine
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.routing import BaseConverter
from wtforms.validators import ValidationError

Expand Down Expand Up @@ -244,7 +243,8 @@ def apply(self, query, func): # noqa
druid_datasources = []
for perm in perms:
match = re.search(r'\(id:(\d+)\)', perm)
druid_datasources.append(match.group(1))
if match:
druid_datasources.append(match.group(1))
qry = query.filter(self.model.id.in_(druid_datasources))
return qry

Expand Down Expand Up @@ -672,6 +672,7 @@ class DruidClusterModelView(CaravelModelView, DeleteMixin): # noqa
'broker_port': _("Broker Port"),
'broker_endpoint': _("Broker Endpoint"),
}

def pre_add(self, db):
utils.merge_perm(sm, 'database_access', db.perm)

Expand Down Expand Up @@ -699,7 +700,8 @@ class SliceModelView(CaravelModelView, DeleteMixin): # noqa
list_columns = [
'slice_link', 'viz_type', 'datasource_link', 'creator', 'modified']
edit_columns = [
'slice_name', 'description', 'viz_type', 'owners', 'dashboards', 'params', 'cache_timeout']
'slice_name', 'description', 'viz_type', 'owners', 'dashboards',
'params', 'cache_timeout']
base_order = ('changed_on', 'desc')
description_columns = {
'description': Markup(
Expand Down Expand Up @@ -1099,107 +1101,110 @@ def approve(self):
session.commit()
return redirect('/accessrequestsmodelview/list/')

def get_viz(
self,
slice_id=None,
args=None,
datasource_type=None,
datasource_id=None):
if slice_id:
slc = db.session.query(models.Slice).filter_by(id=slice_id).one()
return slc.get_viz()
else:
viz_type = args.get('viz_type', 'table')
datasource = SourceRegistry.get_datasource(
datasource_type, datasource_id, db.session)
viz_obj = viz.viz_types[viz_type](datasource, request.args)
return viz_obj

@has_access
@expose("/explore/<datasource_type>/<datasource_id>/<slice_id>/")
@expose("/explore/<datasource_type>/<datasource_id>/")
@expose("/datasource/<datasource_type>/<datasource_id>/") # Legacy url
@expose("/slice/<slice_id>/")
def slice(self, slice_id):
viz_obj = self.get_viz(slice_id)
return redirect(viz_obj.get_url(**request.args))

@has_access_api
@expose("/explore_json/<datasource_type>/<datasource_id>/")
def explore_json(self, datasource_type, datasource_id):
viz_obj = self.get_viz(
datasource_type=datasource_type,
datasource_id=datasource_id,
args=request.args)
if not self.datasource_access(viz_obj.datasource):
return Response(
json.dumps(
{'error': _("You don't have access to this datasource")}),
status=404,
mimetype="application/json")
return Response(
viz_obj.get_json(),
status=200,
mimetype="application/json")

@log_this
def explore(self, datasource_type, datasource_id, slice_id=None):
@has_access
@expose("/explore/<datasource_type>/<datasource_id>/")
def explore(self, datasource_type, datasource_id):
viz_type = request.args.get("viz_type")
slice_id = request.args.get('slice_id')
slc = db.session.query(models.Slice).filter_by(id=slice_id).first()

error_redirect = '/slicemodelview/list/'
datasource_class = SourceRegistry.sources[datasource_type]
datasources = db.session.query(datasource_class).all()
datasources = sorted(datasources, key=lambda ds: ds.full_name)
datasource = [ds for ds in datasources if int(datasource_id) == ds.id]
datasource = datasource[0] if datasource else None

if not datasource:
viz_obj = self.get_viz(
datasource_type=datasource_type,
datasource_id=datasource_id,
args=request.args)

if not viz_obj.datasource:
flash(DATASOURCE_MISSING_ERR, "alert")
return redirect(error_redirect)

if not self.datasource_access(datasource):
if not self.datasource_access(viz_obj.datasource):
flash(
__(get_datasource_access_error_msg(datasource.name)), "danger")
__(get_datasource_access_error_msg(viz_obj.datasource.name)),
"danger")
return redirect(
'caravel/request_access/?'
'datasource_type={datasource_type}&'
'datasource_id={datasource_id}&'
''.format(**locals()))

request_args_multi_dict = request.args # MultiDict

slice_id = slice_id or request_args_multi_dict.get("slice_id")
slc = None
# build viz_obj and get it's params
if slice_id:
slc = db.session.query(models.Slice).filter_by(id=slice_id).first()
try:
viz_obj = slc.get_viz(
url_params_multidict=request_args_multi_dict)
except Exception as e:
logging.exception(e)
flash(utils.error_msg_from_exception(e), "danger")
return redirect(error_redirect)
else:
viz_type = request_args_multi_dict.get("viz_type")
if not viz_type and datasource.default_endpoint:
return redirect(datasource.default_endpoint)
# default to table if no default endpoint and no viz_type
viz_type = viz_type or "table"
# validate viz params
try:
viz_obj = viz.viz_types[viz_type](
datasource, request_args_multi_dict)
except Exception as e:
logging.exception(e)
flash(utils.error_msg_from_exception(e), "danger")
return redirect(error_redirect)
slice_params_multi_dict = ImmutableMultiDict(viz_obj.orig_form_data)
if not viz_type and viz_obj.datasource.default_endpoint:
return redirect(viz_obj.datasource.default_endpoint)

# slc perms
slice_add_perm = self.can_access('can_add', 'SliceModelView')
slice_edit_perm = check_ownership(slc, raise_if_false=False)
slice_download_perm = self.can_access('can_download', 'SliceModelView')

# handle save or overwrite
action = slice_params_multi_dict.get('action')
action = request.args.get('action')
if action in ('saveas', 'overwrite'):
return self.save_or_overwrite_slice(
slice_params_multi_dict, slc, slice_add_perm, slice_edit_perm)
request.args, slc, slice_add_perm, slice_edit_perm)

# handle different endpoints
if slice_params_multi_dict.get("json") == "true":
if config.get("DEBUG"):
# Allows for nice debugger stack traces in debug mode
return Response(
viz_obj.get_json(),
status=200,
mimetype="application/json")
try:
return Response(
viz_obj.get_json(),
status=200,
mimetype="application/json")
except Exception as e:
logging.exception(e)
return json_error_response(utils.error_msg_from_exception(e))

elif slice_params_multi_dict.get("csv") == "true":
if request.args.get("csv") == "true":
payload = viz_obj.get_csv()
return Response(
payload,
status=200,
headers=generate_download_headers("csv"),
mimetype="application/csv")
elif request.args.get("standalone") == "true":
return self.render_template("caravel/standalone.html", viz=viz_obj)
else:
if slice_params_multi_dict.get("standalone") == "true":
template = "caravel/standalone.html"
else:
template = "caravel/explore.html"
return self.render_template(
template, viz=viz_obj, slice=slc, datasources=datasources,
"caravel/explore.html",
viz=viz_obj, slice=slc, datasources=datasources,
can_add=slice_add_perm, can_edit=slice_edit_perm,
can_download=slice_download_perm,
userid=g.user.get_id() if g.user else '')
userid=g.user.get_id() if g.user else ''
)

@has_access
@expose("/exploreV2/<datasource_type>/<datasource_id>/<slice_id>/")
Expand Down Expand Up @@ -1705,7 +1710,11 @@ def sqllab_viz(self):
data = json.loads(request.args.get('data'))
table_name = data.get('datasourceName')
viz_type = data.get('chartType')
table = db.session.query(models.SqlaTable).filter_by(table_name=table_name).first()
table = (
db.session.query(models.SqlaTable)
.filter_by(table_name=table_name)
.first()
)
if not table:
table = models.SqlaTable(table_name=table_name)
table.database_id = data.get('dbId')
Expand Down
13 changes: 7 additions & 6 deletions caravel/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def flat_form_fields(cls):
def reassignments(self):
pass

def get_url(self, for_cache_key=False, **kwargs):
def get_url(self, for_cache_key=False, json_endpoint=False, **kwargs):
"""Returns the URL for the viz
:param for_cache_key: when getting the url as the identifier to hash
Expand Down Expand Up @@ -140,8 +140,12 @@ def get_url(self, for_cache_key=False, **kwargs):
for item in v:
od.add(key, item)

base_endpoint = '/caravel/explore'
if json_endpoint:
base_endpoint = '/caravel/explore_json'

href = Href(
'/caravel/explore/{self.datasource.type}/'
'{base_endpoint}/{self.datasource.type}/'
'{self.datasource.id}/'.format(**locals()))
if for_cache_key and 'force' in od:
del od['force']
Expand Down Expand Up @@ -373,7 +377,7 @@ def get_data(self):

@property
def json_endpoint(self):
return self.get_url(json="true")
return self.get_url(json_endpoint=True)

@property
def cache_key(self):
Expand Down Expand Up @@ -1261,7 +1265,6 @@ class HistogramViz(BaseViz):
}
}


def query_obj(self):
"""Returns the query object for this visualization"""
d = super(HistogramViz, self).query_obj()
Expand All @@ -1272,7 +1275,6 @@ def query_obj(self):
d['columns'] = [numeric_column]
return d


def get_df(self, query_obj=None):
"""Returns a pandas dataframe based on the query object"""
if not query_obj:
Expand All @@ -1289,7 +1291,6 @@ def get_df(self, query_obj=None):
df = df.fillna(0)
return df


def get_data(self):
"""Returns the chart data"""
df = self.get_df()
Expand Down
2 changes: 1 addition & 1 deletion run_specific_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ export CARAVEL_CONFIG=tests.caravel_test_config
set -e
caravel/bin/caravel version -v
export SOLO_TEST=1
nosetests tests.core_tests:CoreTests.test_public_user_dashboard_access
nosetests tests.core_tests:CoreTests.test_slice_endpoint
20 changes: 15 additions & 5 deletions tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def __init__(self, *args, **kwargs):

utils.init(caravel)

def get_or_create(self, cls, criteria, session):
obj = session.query(cls).filter_by(**criteria).first()
if not obj:
obj = cls(**criteria)
return obj

def login(self, username='admin', password='general'):
resp = self.client.post(
'/login/',
Expand All @@ -104,6 +110,15 @@ def get_latest_query(self, sql):
session.close()
return query

def get_slice(self, slice_name, session):
slc = (
session.query(models.Slice)
.filter_by(slice_name=slice_name)
.one()
)
session.expunge_all()
return slc

def get_resp(self, url):
"""Shortcut to get the parsed results while following redirects"""
resp = self.client.get(url, follow_redirects=True)
Expand All @@ -124,11 +139,6 @@ def get_access_requests(self, username, ds_type, ds_id):
def logout(self):
self.client.get('/logout/', follow_redirects=True)

def test_welcome(self):
self.login()
resp = self.client.get('/caravel/welcome')
assert 'Welcome' in resp.data.decode('utf-8')

def setup_public_access_for_dashboard(self, table_name):
public_role = appbuilder.sm.find_role('Public')
perms = db.session.query(ab_models.PermissionView).all()
Expand Down
35 changes: 28 additions & 7 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,33 @@ def setUp(self):
def tearDown(self):
pass

def test_welcome(self):
self.login()
resp = self.client.get('/caravel/welcome')
assert 'Welcome' in resp.data.decode('utf-8')

def test_slice_endpoint(self):
self.login(username='admin')
slc = self.get_slice("Girls", db.session)
resp = self.get_resp('/caravel/slice/{}/'.format(slc.id))
assert 'Time Column' in resp
assert 'List Roles' in resp

# Testing overrides
resp = self.get_resp(
'/caravel/slice/{}/?standalone=true'.format(slc.id))
assert 'List Roles' not in resp

def test_endpoints_for_a_slice(self):
self.login(username='admin')
slc = self.get_slice("Girls", db.session)

resp = self.get_resp(slc.viz.csv_endpoint)
assert 'Jennifer,' in resp

resp = self.get_resp(slc.viz.json_endpoint)
assert '"Jennifer"' in resp

def test_admin_only_permissions(self):
def assert_admin_permission_in(role_name, assert_func):
role = sm.find_role(role_name)
Expand Down Expand Up @@ -73,13 +100,7 @@ def assert_admin_view_menus_in(role_name, assert_func):

def test_save_slice(self):
self.login(username='admin')

slc = (
db.session.query(models.Slice.id)
.filter_by(slice_name="Energy Sankey")
.first())
slice_id = slc.id

slice_id = self.get_slice("Energy Sankey", db.session).id
copy_name = "Test Sankey Save"
tbl_id = self.table_ids.get('energy_usage')
url = (
Expand Down
Loading

0 comments on commit f70d301

Please sign in to comment.