Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the explore view #1252

Merged
merged 3 commits into from
Oct 7, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions caravel/source_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,11 @@ def register_sources(cls, datasource_config):
for class_name in class_names:
source_class = getattr(module_obj, class_name)
cls.sources[source_class.type] = source_class

@classmethod
def get_datasource(cls, datasource_type, datasource_id, session):
return (
session.query(cls.sources[datasource_type])
.filter_by(id=datasource_id)
.one()
)
2 changes: 1 addition & 1 deletion caravel/templates/caravel/standalone.html
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<html>
<head>
<title>{{viz.token}}</title>
<title>{{ viz.token }}</title>
<link rel="stylesheet" type="text/css" href="/static/assets/node_modules/font-awesome/css/font-awesome.min.css" />
<link rel="stylesheet" type="text/css" href="/static/assets/stylesheets/caravel.css" />
<link rel="stylesheet" type="text/css" href="/static/appbuilder/css/flags/flags16.css" />
Expand Down
143 changes: 76 additions & 67 deletions caravel/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from flask_appbuilder.models.sqla.filters import BaseFilter

from sqlalchemy import create_engine
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.routing import BaseConverter
from wtforms.validators import ValidationError

Expand Down Expand Up @@ -244,7 +243,8 @@ def apply(self, query, func): # noqa
druid_datasources = []
for perm in perms:
match = re.search(r'\(id:(\d+)\)', perm)
druid_datasources.append(match.group(1))
if match:
druid_datasources.append(match.group(1))
qry = query.filter(self.model.id.in_(druid_datasources))
return qry

Expand Down Expand Up @@ -672,6 +672,7 @@ class DruidClusterModelView(CaravelModelView, DeleteMixin): # noqa
'broker_port': _("Broker Port"),
'broker_endpoint': _("Broker Endpoint"),
}

def pre_add(self, db):
utils.merge_perm(sm, 'database_access', db.perm)

Expand Down Expand Up @@ -699,7 +700,8 @@ class SliceModelView(CaravelModelView, DeleteMixin): # noqa
list_columns = [
'slice_link', 'viz_type', 'datasource_link', 'creator', 'modified']
edit_columns = [
'slice_name', 'description', 'viz_type', 'owners', 'dashboards', 'params', 'cache_timeout']
'slice_name', 'description', 'viz_type', 'owners', 'dashboards',
'params', 'cache_timeout']
base_order = ('changed_on', 'desc')
description_columns = {
'description': Markup(
Expand Down Expand Up @@ -1099,107 +1101,110 @@ def approve(self):
session.commit()
return redirect('/accessrequestsmodelview/list/')

def get_viz(
self,
slice_id=None,
args=None,
datasource_type=None,
datasource_id=None):
if slice_id:
slc = db.session.query(models.Slice).filter_by(id=slice_id).one()
return slc.get_viz()
else:
viz_type = args.get('viz_type', 'table')
datasource = SourceRegistry.get_datasource(
datasource_type, datasource_id, db.session)
viz_obj = viz.viz_types[viz_type](datasource, request.args)
return viz_obj

@has_access
@expose("/explore/<datasource_type>/<datasource_id>/<slice_id>/")
@expose("/explore/<datasource_type>/<datasource_id>/")
@expose("/datasource/<datasource_type>/<datasource_id>/") # Legacy url
@expose("/slice/<slice_id>/")
def slice(self, slice_id):
viz_obj = self.get_viz(slice_id)
return redirect(viz_obj.get_url(**request.args))

@has_access_api
@expose("/explore_json/<datasource_type>/<datasource_id>/")
def explore_json(self, datasource_type, datasource_id):
viz_obj = self.get_viz(
datasource_type=datasource_type,
datasource_id=datasource_id,
args=request.args)
if not self.datasource_access(viz_obj.datasource):
return Response(
json.dumps(
{'error': _("You don't have access to this datasource")}),
status=404,
mimetype="application/json")
return Response(
viz_obj.get_json(),
status=200,
mimetype="application/json")

@log_this
def explore(self, datasource_type, datasource_id, slice_id=None):
@has_access
@expose("/explore/<datasource_type>/<datasource_id>/")
def explore(self, datasource_type, datasource_id):
viz_type = request.args.get("viz_type")
slice_id = request.args.get('slice_id')
slc = db.session.query(models.Slice).filter_by(id=slice_id).first()

error_redirect = '/slicemodelview/list/'
datasource_class = SourceRegistry.sources[datasource_type]
datasources = db.session.query(datasource_class).all()
datasources = sorted(datasources, key=lambda ds: ds.full_name)
datasource = [ds for ds in datasources if int(datasource_id) == ds.id]
datasource = datasource[0] if datasource else None

if not datasource:
viz_obj = self.get_viz(
datasource_type=datasource_type,
datasource_id=datasource_id,
args=request.args)

if not viz_obj.datasource:
flash(DATASOURCE_MISSING_ERR, "alert")
return redirect(error_redirect)

if not self.datasource_access(datasource):
if not self.datasource_access(viz_obj.datasource):
flash(
__(get_datasource_access_error_msg(datasource.name)), "danger")
__(get_datasource_access_error_msg(viz_obj.datasource.name)),
"danger")
return redirect(
'caravel/request_access/?'
'datasource_type={datasource_type}&'
'datasource_id={datasource_id}&'
''.format(**locals()))

request_args_multi_dict = request.args # MultiDict

slice_id = slice_id or request_args_multi_dict.get("slice_id")
slc = None
# build viz_obj and get it's params
if slice_id:
slc = db.session.query(models.Slice).filter_by(id=slice_id).first()
try:
viz_obj = slc.get_viz(
url_params_multidict=request_args_multi_dict)
except Exception as e:
logging.exception(e)
flash(utils.error_msg_from_exception(e), "danger")
return redirect(error_redirect)
else:
viz_type = request_args_multi_dict.get("viz_type")
if not viz_type and datasource.default_endpoint:
return redirect(datasource.default_endpoint)
# default to table if no default endpoint and no viz_type
viz_type = viz_type or "table"
# validate viz params
try:
viz_obj = viz.viz_types[viz_type](
datasource, request_args_multi_dict)
except Exception as e:
logging.exception(e)
flash(utils.error_msg_from_exception(e), "danger")
return redirect(error_redirect)
slice_params_multi_dict = ImmutableMultiDict(viz_obj.orig_form_data)
if not viz_type and viz_obj.datasource.default_endpoint:
return redirect(viz_obj.datasource.default_endpoint)

# slc perms
slice_add_perm = self.can_access('can_add', 'SliceModelView')
slice_edit_perm = check_ownership(slc, raise_if_false=False)
slice_download_perm = self.can_access('can_download', 'SliceModelView')

# handle save or overwrite
action = slice_params_multi_dict.get('action')
action = request.args.get('action')
if action in ('saveas', 'overwrite'):
return self.save_or_overwrite_slice(
slice_params_multi_dict, slc, slice_add_perm, slice_edit_perm)
request.args, slc, slice_add_perm, slice_edit_perm)

# handle different endpoints
if slice_params_multi_dict.get("json") == "true":
if config.get("DEBUG"):
# Allows for nice debugger stack traces in debug mode
return Response(
viz_obj.get_json(),
status=200,
mimetype="application/json")
try:
return Response(
viz_obj.get_json(),
status=200,
mimetype="application/json")
except Exception as e:
logging.exception(e)
return json_error_response(utils.error_msg_from_exception(e))

elif slice_params_multi_dict.get("csv") == "true":
if request.args.get("csv") == "true":
payload = viz_obj.get_csv()
return Response(
payload,
status=200,
headers=generate_download_headers("csv"),
mimetype="application/csv")
elif request.args.get("standalone") == "true":
return self.render_template("caravel/standalone.html", viz=viz_obj)
else:
if slice_params_multi_dict.get("standalone") == "true":
template = "caravel/standalone.html"
else:
template = "caravel/explore.html"
return self.render_template(
template, viz=viz_obj, slice=slc, datasources=datasources,
"caravel/explore.html",
viz=viz_obj, slice=slc, datasources=datasources,
can_add=slice_add_perm, can_edit=slice_edit_perm,
can_download=slice_download_perm,
userid=g.user.get_id() if g.user else '')
userid=g.user.get_id() if g.user else ''
)

def save_or_overwrite_slice(
self, args, slc, slice_add_perm, slice_edit_perm):
Expand Down Expand Up @@ -1598,7 +1603,11 @@ def sqllab_viz(self):
data = json.loads(request.args.get('data'))
table_name = data.get('datasourceName')
viz_type = data.get('chartType')
table = db.session.query(models.SqlaTable).filter_by(table_name=table_name).first()
table = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we add a schema here to the filter?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

risk of collisions in the context of unit tests is close to zero

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this code belongs to the views.py @mistercrunch

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh right, I was just linting here, but we should tackle this

db.session.query(models.SqlaTable)
.filter_by(table_name=table_name)
.first()
)
if not table:
table = models.SqlaTable(table_name=table_name)
table.database_id = data.get('dbId')
Expand Down
13 changes: 7 additions & 6 deletions caravel/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def flat_form_fields(cls):
def reassignments(self):
pass

def get_url(self, for_cache_key=False, **kwargs):
def get_url(self, for_cache_key=False, json_endpoint=False, **kwargs):
"""Returns the URL for the viz

:param for_cache_key: when getting the url as the identifier to hash
Expand Down Expand Up @@ -140,8 +140,12 @@ def get_url(self, for_cache_key=False, **kwargs):
for item in v:
od.add(key, item)

base_endpoint = '/caravel/explore'
if json_endpoint:
base_endpoint = '/caravel/explore_json'

href = Href(
'/caravel/explore/{self.datasource.type}/'
'{base_endpoint}/{self.datasource.type}/'
'{self.datasource.id}/'.format(**locals()))
if for_cache_key and 'force' in od:
del od['force']
Expand Down Expand Up @@ -373,7 +377,7 @@ def get_data(self):

@property
def json_endpoint(self):
return self.get_url(json="true")
return self.get_url(json_endpoint=True)

@property
def cache_key(self):
Expand Down Expand Up @@ -1261,7 +1265,6 @@ class HistogramViz(BaseViz):
}
}


def query_obj(self):
"""Returns the query object for this visualization"""
d = super(HistogramViz, self).query_obj()
Expand All @@ -1272,7 +1275,6 @@ def query_obj(self):
d['columns'] = [numeric_column]
return d


def get_df(self, query_obj=None):
"""Returns a pandas dataframe based on the query object"""
if not query_obj:
Expand All @@ -1289,7 +1291,6 @@ def get_df(self, query_obj=None):
df = df.fillna(0)
return df


def get_data(self):
"""Returns the chart data"""
df = self.get_df()
Expand Down
2 changes: 1 addition & 1 deletion run_specific_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ export CARAVEL_CONFIG=tests.caravel_test_config
set -e
caravel/bin/caravel version -v
export SOLO_TEST=1
nosetests tests.core_tests:CoreTests.test_public_user_dashboard_access
nosetests tests.core_tests:CoreTests.test_slice_endpoint
20 changes: 15 additions & 5 deletions tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def __init__(self, *args, **kwargs):

utils.init(caravel)

def get_or_create(self, cls, criteria, session):
obj = session.query(cls).filter_by(**criteria).first()
if not obj:
obj = cls(**criteria)
return obj

def login(self, username='admin', password='general'):
resp = self.client.post(
'/login/',
Expand All @@ -104,6 +110,15 @@ def get_latest_query(self, sql):
session.close()
return query

def get_slice(self, slice_name, session):
slc = (
session.query(models.Slice)
.filter_by(slice_name=slice_name)
.one()
)
session.expunge_all()
return slc

def get_resp(self, url):
"""Shortcut to get the parsed results while following redirects"""
resp = self.client.get(url, follow_redirects=True)
Expand All @@ -124,11 +139,6 @@ def get_access_requests(self, username, ds_type, ds_id):
def logout(self):
self.client.get('/logout/', follow_redirects=True)

def test_welcome(self):
self.login()
resp = self.client.get('/caravel/welcome')
assert 'Welcome' in resp.data.decode('utf-8')

def setup_public_access_for_dashboard(self, table_name):
public_role = appbuilder.sm.find_role('Public')
perms = db.session.query(ab_models.PermissionView).all()
Expand Down
35 changes: 28 additions & 7 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,33 @@ def setUp(self):
def tearDown(self):
pass

def test_welcome(self):
self.login()
resp = self.client.get('/caravel/welcome')
assert 'Welcome' in resp.data.decode('utf-8')

def test_slice_endpoint(self):
self.login(username='admin')
slc = self.get_slice("Girls", db.session)
resp = self.get_resp('/caravel/slice/{}/'.format(slc.id))
assert 'Time Column' in resp
assert 'List Roles' in resp

# Testing overrides
resp = self.get_resp(
'/caravel/slice/{}/?standalone=true'.format(slc.id))
assert 'List Roles' not in resp

def test_endpoints_for_a_slice(self):
self.login(username='admin')
slc = self.get_slice("Girls", db.session)

resp = self.get_resp(slc.viz.csv_endpoint)
assert 'Jennifer,' in resp

resp = self.get_resp(slc.viz.json_endpoint)
assert '"Jennifer"' in resp

def test_admin_only_permissions(self):
def assert_admin_permission_in(role_name, assert_func):
role = sm.find_role(role_name)
Expand Down Expand Up @@ -73,13 +100,7 @@ def assert_admin_view_menus_in(role_name, assert_func):

def test_save_slice(self):
self.login(username='admin')

slc = (
db.session.query(models.Slice.id)
.filter_by(slice_name="Energy Sankey")
.first())
slice_id = slc.id

slice_id = self.get_slice("Energy Sankey", db.session).id
copy_name = "Test Sankey Save"
tbl_id = self.table_ids.get('energy_usage')
url = (
Expand Down
Loading