Refactoring more in the connector base classes (#2431)
mistercrunch authored Mar 27, 2017
1 parent 398036d commit 121b1d0
Showing 7 changed files with 99 additions and 82 deletions.
9 changes: 5 additions & 4 deletions superset/config.py
@@ -12,6 +12,7 @@
 import imp
 import json
 import os
+from collections import OrderedDict
 
 from dateutil import tz
 from flask_appbuilder.security.manager import AUTH_DB
@@ -178,10 +179,10 @@
 # --------------------------------------------------
 # Modules, datasources and middleware to be registered
 # --------------------------------------------------
-DEFAULT_MODULE_DS_MAP = {
-    'superset.connectors.druid.models': ['DruidDatasource'],
-    'superset.connectors.sqla.models': ['SqlaTable'],
-}
+DEFAULT_MODULE_DS_MAP = OrderedDict([
+    ('superset.connectors.sqla.models', ['SqlaTable']),
+    ('superset.connectors.druid.models', ['DruidDatasource']),
+])
 ADDITIONAL_MODULE_DS_MAP = {}
 ADDITIONAL_MIDDLEWARE = []
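
The switch to OrderedDict makes connector registration order deterministic (plain dicts did not guarantee insertion order on the Python versions Superset supported at the time), with the SQLAlchemy connector now listed ahead of Druid. A minimal sketch of how a module-to-datasource map like this can be consumed; the load_datasource_classes helper is hypothetical, not Superset's actual loader:

from importlib import import_module

def load_datasource_classes(module_ds_map):
    """Import each module and collect its named datasource classes, in order."""
    classes = []
    for module_name, class_names in module_ds_map.items():
        module = import_module(module_name)
        classes.extend(getattr(module, name) for name in class_names)
    return classes

# Iterating an OrderedDict follows insertion order, so SqlaTable-style
# connectors would be registered before the Druid ones here.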
19 changes: 17 additions & 2 deletions superset/connectors/base.py
@@ -1,7 +1,8 @@
 import json
 
-from sqlalchemy import Column, Integer, String, Text, Boolean
-
+from sqlalchemy import (
+    Column, Integer, String, Text, Boolean,
+)
 from superset import utils
 from superset.models.helpers import AuditMixinNullable, ImportMixin
@@ -12,9 +13,23 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
 
     __tablename__ = None  # {connector_name}_datasource
 
+    column_class = None  # link to derivative of BaseColumn
+    metric_class = None  # link to derivative of BaseMetric
+
     # Used to do code highlighting when displaying the query in the UI
     query_language = None
 
+    # Columns
+    id = Column(Integer, primary_key=True)
+    description = Column(Text)
+    default_endpoint = Column(Text)
+    is_featured = Column(Boolean, default=False)
+    filter_select_enabled = Column(Boolean, default=False)
+    offset = Column(Integer, default=0)
+    cache_timeout = Column(Integer)
+    params = Column(String(1000))
+    perm = Column(String(1000))
+
     @property
     def column_names(self):
         return sorted([c.column_name for c in self.columns])
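
Hoisting id, description, offset and friends into BaseDatasource leans on SQLAlchemy's declarative mixin behavior: Column objects declared on a non-mapped class are copied onto each concrete subclass's table. A self-contained sketch of the pattern with toy names, not Superset's models:

from sqlalchemy import Boolean, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class CommonColumns(object):
    """Mixin in the spirit of BaseDatasource: shared columns, no table."""
    id = Column(Integer, primary_key=True)
    description = Column(Text)
    is_featured = Column(Boolean, default=False)
    offset = Column(Integer, default=0)
    cache_timeout = Column(Integer)
    params = Column(String(1000))

class ToyTable(CommonColumns, Base):
    __tablename__ = 'toy_tables'  # each connector supplies its own table name
    table_name = Column(String(250))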
27 changes: 9 additions & 18 deletions superset/connectors/druid/models.py
@@ -312,38 +312,29 @@ class DruidDatasource(Model, BaseDatasource):
 
     """ORM object referencing Druid datasources (tables)"""
 
+    __tablename__ = 'datasources'
+
     type = "druid"
     query_language = "json"
-    metric_class = DruidMetric
     cluster_class = DruidCluster
+    metric_class = DruidMetric
+    column_class = DruidColumn
 
     baselink = "druiddatasourcemodelview"
 
-    __tablename__ = 'datasources'
-    id = Column(Integer, primary_key=True)
+    # Columns
     datasource_name = Column(String(255), unique=True)
-    is_featured = Column(Boolean, default=False)
     is_hidden = Column(Boolean, default=False)
-    filter_select_enabled = Column(Boolean, default=False)
-    description = Column(Text)
     fetch_values_from = Column(String(100))
-    default_endpoint = Column(Text)
+    cluster_name = Column(
+        String(250), ForeignKey('clusters.cluster_name'))
+    cluster = relationship(
+        'DruidCluster', backref='datasources', foreign_keys=[cluster_name])
     user_id = Column(Integer, ForeignKey('ab_user.id'))
     owner = relationship(
         'User',
         backref=backref('datasources', cascade='all, delete-orphan'),
         foreign_keys=[user_id])
-    cluster_name = Column(
-        String(250), ForeignKey('clusters.cluster_name'))
-    cluster = relationship(
-        'DruidCluster', backref='datasources', foreign_keys=[cluster_name])
-    offset = Column(Integer, default=0)
-    cache_timeout = Column(Integer)
-    params = Column(String(1000))
-    perm = Column(String(1000))
-
-    metric_cls = DruidMetric
-    column_cls = DruidColumn
 
     export_fields = (
        'datasource_name', 'is_hidden', 'description', 'default_endpoint',
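
With metric_cls and column_cls gone, every connector now advertises its companion models through the metric_class/column_class/cluster_class attributes declared on BaseDatasource. That is what lets call sites resolve related models through the registry, e.g. ConnectorRegistry.sources['druid'].cluster_class, instead of importing Druid models directly. A toy stand-in for that registry pattern; the names are illustrative, not the real ConnectorRegistry API:

class ToyConnectorRegistry(object):
    """Maps a datasource type name to its datasource class."""
    sources = {}

    @classmethod
    def register(cls, datasource_class):
        cls.sources[datasource_class.type] = datasource_class
        return datasource_class

# Generic code can then hop from the datasource class to related models:
#     cluster_class = ToyConnectorRegistry.sources['druid'].cluster_class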
53 changes: 51 additions & 2 deletions superset/connectors/druid/views.py
@@ -1,14 +1,20 @@
+from datetime import datetime
+import logging
+
 import sqlalchemy as sqla
 
-from flask import Markup
-from flask_appbuilder import CompactCRUDMixin
+from flask import Markup, flash, redirect
+from flask_appbuilder import CompactCRUDMixin, expose
 from flask_appbuilder.models.sqla.interface import SQLAInterface
 
 from flask_babel import lazy_gettext as _
 from flask_babel import gettext as __
 
 import superset
 from superset import db, utils, appbuilder, sm, security
+from superset.connectors.connector_registry import ConnectorRegistry
+from superset.utils import has_access
+from superset.views.base import BaseSupersetView
 from superset.views.base import (
     SupersetModelView, validate_json, DeleteMixin, ListWidgetWithCheckboxes,
     DatasourceFilter, get_datasource_exist_error_mgs)
@@ -205,3 +211,46 @@ def post_update(self, datasource):
     category="Sources",
     category_label=__("Sources"),
     icon="fa-cube")
+
+
+class Druid(BaseSupersetView):
+    """Base views for the Druid connector"""
+
+    @has_access
+    @expose("/refresh_datasources/")
+    def refresh_datasources(self):
+        """Endpoint that refreshes Druid datasources metadata"""
+        session = db.session()
+        DruidCluster = ConnectorRegistry.sources['druid'].cluster_class
+        for cluster in session.query(DruidCluster).all():
+            cluster_name = cluster.cluster_name
+            try:
+                cluster.refresh_datasources()
+            except Exception as e:
+                flash(
+                    "Error while processing cluster '{}'\n{}".format(
+                        cluster_name, utils.error_msg_from_exception(e)),
+                    "danger")
+                logging.exception(e)
+                return redirect('/druidclustermodelview/list/')
+            cluster.metadata_last_refreshed = datetime.now()
+            flash(
+                "Refreshed metadata from cluster "
+                "[" + cluster.cluster_name + "]",
+                'info')
+        session.commit()
+        return redirect("/druiddatasourcemodelview/list/")
+
+appbuilder.add_view_no_menu(Druid)
+
+appbuilder.add_link(
+    "Refresh Druid Metadata",
+    label=__("Refresh Druid Metadata"),
+    href='/druid/refresh_datasources/',
+    category='Sources',
+    category_label=__("Sources"),
+    category_icon='fa-database',
+    icon="fa-cog")
+
+
+appbuilder.add_separator("Sources")
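
Moving the endpoint onto a dedicated Druid view also moves its URL: Flask-AppBuilder routes expose()-decorated methods under the view's route base, so the link target changes from /superset/refresh_datasources/ to /druid/refresh_datasources/. A compact sketch of that expose/has_access pattern, using a toy view and assuming a configured Flask-AppBuilder app:

from flask_appbuilder import BaseView, expose
from flask_appbuilder.security.decorators import has_access

class ToyView(BaseView):
    route_base = '/toy'

    @has_access  # enforces Flask-AppBuilder's permission checks
    @expose('/ping/')
    def ping(self):
        # Served at /toy/ping/ once registered via appbuilder.add_view_no_menu
        return 'pong'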
17 changes: 5 additions & 12 deletions superset/connectors/sqla/models.py
@@ -162,33 +162,26 @@ class SqlaTable(Model, BaseDatasource):
     type = "table"
     query_language = 'sql'
     metric_class = SqlMetric
+    column_class = TableColumn
 
     __tablename__ = 'tables'
-    id = Column(Integer, primary_key=True)
     table_name = Column(String(250))
     main_dttm_col = Column(String(250))
-    description = Column(Text)
-    default_endpoint = Column(Text)
     database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
-    is_featured = Column(Boolean, default=False)
-    filter_select_enabled = Column(Boolean, default=False)
     fetch_values_predicate = Column(String(1000))
     user_id = Column(Integer, ForeignKey('ab_user.id'))
-    owner = relationship('User', backref='tables', foreign_keys=[user_id])
+    owner = relationship(
+        'User',
+        backref='tables',
+        foreign_keys=[user_id])
     database = relationship(
         'Database',
         backref=backref('tables', cascade='all, delete-orphan'),
         foreign_keys=[database_id])
-    offset = Column(Integer, default=0)
-    cache_timeout = Column(Integer)
     schema = Column(String(255))
     sql = Column(Text)
-    params = Column(Text)
-    perm = Column(String(1000))
 
     baselink = "tablemodelview"
-    column_cls = TableColumn
-    metric_cls = SqlMetric
     export_fields = (
         'table_name', 'main_dttm_col', 'description', 'default_endpoint',
         'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema',
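
SqlaTable now declares only what is table-specific; the generic columns and the import/export machinery come from BaseDatasource and ImportMixin. The class-level export_fields tuple is what keeps serialization generic. A hedged one-liner of the idea; ImportMixin's real implementation does more, such as handling relationships and defaults:

def export_to_dict(obj):
    """Naive sketch: serialize whichever fields the class chose to export."""
    return {field: getattr(obj, field) for field in type(obj).export_fields}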
4 changes: 2 additions & 2 deletions superset/import_util.py
@@ -38,7 +38,7 @@ def import_datasource(
         new_m.table_id = datasource.id
         logging.info('Importing metric {} from the datasource: {}'.format(
             new_m.to_json(), i_datasource.full_name))
-        imported_m = i_datasource.metric_cls.import_obj(new_m)
+        imported_m = i_datasource.metric_class.import_obj(new_m)
         if (imported_m.metric_name not in
                 [m.metric_name for m in datasource.metrics]):
             datasource.metrics.append(imported_m)
@@ -48,7 +48,7 @@ def import_datasource(
         new_c.table_id = datasource.id
         logging.info('Importing column {} from the datasource: {}'.format(
             new_c.to_json(), i_datasource.full_name))
-        imported_c = i_datasource.column_cls.import_obj(new_c)
+        imported_c = i_datasource.column_class.import_obj(new_c)
         if (imported_c.column_name not in
                 [c.column_name for c in datasource.columns]):
             datasource.columns.append(imported_c)
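
This rename is the payoff of the base-class work: import_datasource no longer cares which connector it is fed, only that the datasource class honors the BaseDatasource contract. A toy illustration of that duck-typed contract, with hypothetical class names:

class ToyMetric(object):
    @classmethod
    def import_obj(cls, new_metric):
        return new_metric  # the real method upserts against the DB session

class ToyDatasource(object):
    metric_class = ToyMetric  # the hook import_datasource relies on

def import_one_metric(datasource, new_metric):
    return type(datasource).metric_class.import_obj(new_metric)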
52 changes: 10 additions & 42 deletions superset/views/core.py
Expand Up @@ -9,7 +9,6 @@
import pandas as pd
import pickle
import re
import sys
import time
import traceback
import zlib
Expand Down Expand Up @@ -130,7 +129,7 @@ def check_ownership(obj, raise_if_false=True):
return False

security_exception = utils.SupersetSecurityException(
"You don't have the rights to alter [{}]".format(obj))
"You don't have the rights to alter [{}]".format(obj))

if g.user.is_anonymous():
if raise_if_false:
Expand Down Expand Up @@ -762,8 +761,8 @@ def override_role_permissions(self):
granted_perms = []
for datasource in datasources:
view_menu_perm = sm.find_permission_view_menu(
view_menu_name=datasource.perm,
permission_name='datasource_access')
view_menu_name=datasource.perm,
permission_name='datasource_access')
# prevent creating empty permissions
if view_menu_perm and view_menu_perm.view_menu:
role.permissions.append(view_menu_perm)
Expand Down Expand Up @@ -1214,8 +1213,12 @@ def overwrite_slice(self, slc):
@expose("/checkbox/<model_view>/<id_>/<attr>/<value>", methods=['GET'])
def checkbox(self, model_view, id_, attr, value):
"""endpoint for checking/unchecking any boolean in a sqla model"""
Col = ConnectorRegistry.sources['table'].column_cls
obj = db.session.query(Col).filter_by(id=id_).first()
modelview_to_model = {
'TableColumnInlineView':
ConnectorRegistry.sources['table'].column_class,
}
model = modelview_to_model[model_view]
obj = db.session.query(model).filter_by(id=id_).first()
if obj:
setattr(obj, attr, value == 'true')
db.session.commit()
Expand Down Expand Up @@ -1750,6 +1753,7 @@ def sync_druid_source(self):
@expose("/sqllab_viz/", methods=['POST'])
@log_this
def sqllab_viz(self):
SqlaTable = ConnectorRegistry.sources['table']
data = json.loads(request.form.get('data'))
table_name = data.get('datasourceName')
viz_type = data.get('chartType')
Expand Down Expand Up @@ -2158,32 +2162,6 @@ def search_queries(self):
status=200,
mimetype="application/json")

@has_access
@expose("/refresh_datasources/")
def refresh_datasources(self):
"""endpoint that refreshes druid datasources metadata"""
session = db.session()
DruidDatasource = ConnectorRegistry.sources['druid']
DruidCluster = DruidDatasource.cluster_class
for cluster in session.query(DruidCluster).all():
cluster_name = cluster.cluster_name
try:
cluster.refresh_datasources()
except Exception as e:
flash(
"Error while processing cluster '{}'\n{}".format(
cluster_name, utils.error_msg_from_exception(e)),
"danger")
logging.exception(e)
return redirect('/druidclustermodelview/list/')
cluster.metadata_last_refreshed = datetime.now()
flash(
"Refreshed metadata from cluster "
"[" + cluster.cluster_name + "]",
'info')
session.commit()
return redirect("/druiddatasourcemodelview/list/")

@app.errorhandler(500)
def show_traceback(self):
return render_template(
Expand Down Expand Up @@ -2257,16 +2235,6 @@ def sqllab(self):
)
appbuilder.add_view_no_menu(Superset)

if config['DRUID_IS_ACTIVE']:
appbuilder.add_link(
"Refresh Druid Metadata",
label=__("Refresh Druid Metadata"),
href='/superset/refresh_datasources/',
category='Sources',
category_label=__("Sources"),
category_icon='fa-database',
icon="fa-cog")


class CssTemplateModelView(SupersetModelView, DeleteMixin):
datamodel = SQLAInterface(models.CssTemplate)
Expand Down
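
Two details in this file are worth calling out. First, the checkbox endpoint previously resolved every request to the table connector's column class regardless of the model_view path segment; the explicit whitelist now maps known view names to models, so an unrecognized model_view fails with a KeyError instead of toggling a flag on the wrong model. Extending it is a one-entry change; the Druid entry below is hypothetical, not part of this commit:

modelview_to_model = {
    'TableColumnInlineView':
        ConnectorRegistry.sources['table'].column_class,
    # Hypothetical second connector entry, for illustration only:
    'DruidColumnInlineView':
        ConnectorRegistry.sources['druid'].column_class,
}

Second, refresh_datasources and its menu link are deleted here because they moved wholesale into superset/connectors/druid/views.py above, which also retires the DRUID_IS_ACTIVE gate at this call site.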
