From 40e7057bcec55421275d0121d365a2e4f2d38a29 Mon Sep 17 00:00:00 2001 From: Bogdan Date: Thu, 13 Oct 2016 18:18:03 -0700 Subject: [PATCH] Override the role with perms for give datasources. (#1335) * Override the role with perms for give datasources. * Address comments. --- caravel/config.py | 1 - caravel/models.py | 10 +- caravel/source_registry.py | 8 ++ caravel/utils.py | 7 + caravel/views.py | 47 ++++++ tests/{access_requests.py => access_tests.py} | 134 +++++++++++++++++- tests/base_tests.py | 13 +- tests/import_export_tests.py | 6 +- 8 files changed, 212 insertions(+), 14 deletions(-) rename tests/{access_requests.py => access_tests.py} (66%) diff --git a/caravel/config.py b/caravel/config.py index 89bcaa76f4cd0..71b754234ccee 100644 --- a/caravel/config.py +++ b/caravel/config.py @@ -173,7 +173,6 @@ DEFAULT_MODULE_DS_MAP = {'caravel.models': ['DruidDatasource', 'SqlaTable']} ADDITIONAL_MODULE_DS_MAP = {} - """ 1) http://docs.python-guide.org/en/latest/writing/logging/ 2) https://docs.python.org/2/library/logging.config.html diff --git a/caravel/models.py b/caravel/models.py index 00a819989f4f4..4d0ab5a7e2e6b 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -582,6 +582,7 @@ class Database(Model, AuditMixinNullable): """An ORM object that stores Database related information""" __tablename__ = 'dbs' + id = Column(Integer, primary_key=True) database_name = Column(String(250), unique=True) sqlalchemy_uri = Column(String(1024)) @@ -806,7 +807,8 @@ def name(self): @property def full_name(self): - return "[{obj.database}].[{obj.table_name}]".format(obj=self) + return utils.get_datasource_full_name( + self.database, self.table_name, schema=self.schema) @property def dttm_cols(self): @@ -1372,6 +1374,7 @@ class DruidCluster(Model, AuditMixinNullable): """ORM object referencing the Druid clusters""" __tablename__ = 'clusters' + id = Column(Integer, primary_key=True) cluster_name = Column(String(250), unique=True) coordinator_host = Column(String(255)) @@ -1484,9 +1487,8 @@ def link(self): @property def full_name(self): - return ( - "[{obj.cluster_name}]." - "[{obj.datasource_name}]").format(obj=self) + return utils.get_datasource_full_name( + self.cluster_name, self.datasource_name) @property def time_column_grains(self): diff --git a/caravel/source_registry.py b/caravel/source_registry.py index 6176c9c0c96b1..669ca176fe243 100644 --- a/caravel/source_registry.py +++ b/caravel/source_registry.py @@ -22,6 +22,14 @@ def get_datasource(cls, datasource_type, datasource_id, session): .one() ) + @classmethod + def get_all_datasources(cls, session): + datasources = [] + for source_type in SourceRegistry.sources: + datasources.extend( + session.query(SourceRegistry.sources[source_type]).all()) + return datasources + @classmethod def get_datasource_by_name(cls, session, datasource_type, datasource_name, schema, database_name): diff --git a/caravel/utils.py b/caravel/utils.py index f3f19673047cc..c453a6e4263be 100644 --- a/caravel/utils.py +++ b/caravel/utils.py @@ -229,6 +229,7 @@ def init(caravel): ADMIN_ONLY_PERMISSIONS = set([ 'can_sync_druid_source', + 'can_override_role_permissions', 'can_approve', ]) @@ -443,6 +444,12 @@ def generic_find_constraint_name(table, columns, referenced, db): return fk.name +def get_datasource_full_name(database_name, datasource_name, schema=None): + if not schema: + return "[{}].[{}]".format(database_name, datasource_name) + return "[{}].[{}].[{}]".format(database_name, schema, datasource_name) + + def validate_json(obj): if obj: try: diff --git a/caravel/views.py b/caravel/views.py index 19307a05036d6..fd4a38721b93e 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -1062,6 +1062,53 @@ def msg(self): class Caravel(BaseCaravelView): """The base views for Caravel!""" + @has_access + @expose("/override_role_permissions/", methods=['POST']) + def override_role_permissions(self): + """Updates the role with the give datasource permissions. + + Permissions not in the request will be revoked. This endpoint should + be available to admins only. Expects JSON in the format: + { + 'role_name': '{role_name}', + 'database': [{ + 'datasource_type': '{table|druid}', + 'name': '{database_name}', + 'schema': [{ + 'name': '{schema_name}', + 'datasources': ['{datasource name}, {datasource name}'] + }] + }] + } + """ + data = request.get_json(force=True) + role_name = data['role_name'] + databases = data['database'] + + db_ds_names = set() + for dbs in databases: + for schema in dbs['schema']: + for ds_name in schema['datasources']: + db_ds_names.add(utils.get_datasource_full_name( + dbs['name'], ds_name, schema=schema['name'])) + + existing_datasources = SourceRegistry.get_all_datasources(db.session) + datasources = [ + d for d in existing_datasources if d.full_name in db_ds_names] + + role = sm.find_role(role_name) + # remove all permissions + role.permissions = [] + # grant permissions to the list of datasources + for ds_name in datasources: + role.permissions.append( + sm.find_permission_view_menu( + view_menu_name=ds_name.perm, + permission_name='datasource_access') + ) + db.session.commit() + return Response(status=201) + @log_this @has_access @expose("/request_access/") diff --git a/tests/access_requests.py b/tests/access_tests.py similarity index 66% rename from tests/access_requests.py rename to tests/access_tests.py index f443c273eeccb..65e6a9a6670d6 100644 --- a/tests/access_requests.py +++ b/tests/access_tests.py @@ -4,6 +4,7 @@ from __future__ import print_function from __future__ import unicode_literals +import json import unittest from caravel import db, models, sm @@ -11,16 +12,144 @@ from .base_tests import CaravelTestCase +ROLE_TABLES_PERM_DATA = { + 'role_name': 'override_me', + 'database': [{ + 'datasource_type': 'table', + 'name': 'main', + 'schema': [{ + 'name': '', + 'datasources': ['birth_names'] + }] + }] +} + +ROLE_ALL_PERM_DATA = { + 'role_name': 'override_me', + 'database': [{ + 'datasource_type': 'table', + 'name': 'main', + 'schema': [{ + 'name': '', + 'datasources': ['birth_names'] + }] + }, { + 'datasource_type': 'druid', + 'name': 'druid_test', + 'schema': [{ + 'name': '', + 'datasources': ['druid_ds_1', 'druid_ds_2'] + }] + } + ] +} class RequestAccessTests(CaravelTestCase): - requires_examples = True + requires_examples = False + + @classmethod + def setUpClass(cls): + sm.add_role('override_me') + db.session.commit() + + @classmethod + def tearDownClass(cls): + override_me = sm.find_role('override_me') + db.session.delete(override_me) + db.session.commit() + + def setUp(self): + self.login('admin') + + def tearDown(self): + self.logout() + override_me = sm.find_role('override_me') + override_me.permissions = [] + db.session.commit() + db.session.close() + + def test_override_role_permissions_is_admin_only(self): + self.logout() + self.login('alpha') + response = self.client.post( + '/caravel/override_role_permissions/', + data=json.dumps(ROLE_TABLES_PERM_DATA), + content_type='application/json', + follow_redirects=True) + self.assertNotEquals(405, response.status_code) + + def test_override_role_permissions_1_table(self): + response = self.client.post( + '/caravel/override_role_permissions/', + data=json.dumps(ROLE_TABLES_PERM_DATA), + content_type='application/json') + self.assertEquals(201, response.status_code) + + updated_override_me = sm.find_role('override_me') + self.assertEquals(1, len(updated_override_me.permissions)) + birth_names = self.get_table_by_name('birth_names') + self.assertEquals( + birth_names.perm, + updated_override_me.permissions[0].view_menu.name) + self.assertEquals( + 'datasource_access', + updated_override_me.permissions[0].permission.name) + + def test_override_role_permissions_druid_and_table(self): + response = self.client.post( + '/caravel/override_role_permissions/', + data=json.dumps(ROLE_ALL_PERM_DATA), + content_type='application/json') + self.assertEquals(201, response.status_code) + + updated_role = sm.find_role('override_me') + perms = sorted( + updated_role.permissions, key=lambda p: p.view_menu.name) + self.assertEquals(3, len(perms)) + druid_ds_1 = self.get_druid_ds_by_name('druid_ds_1') + self.assertEquals(druid_ds_1.perm, perms[0].view_menu.name) + self.assertEquals('datasource_access', perms[0].permission.name) + + druid_ds_2 = self.get_druid_ds_by_name('druid_ds_2') + self.assertEquals(druid_ds_2.perm, perms[1].view_menu.name) + self.assertEquals( + 'datasource_access', updated_role.permissions[1].permission.name) + + birth_names = self.get_table_by_name('birth_names') + self.assertEquals(birth_names.perm, perms[2].view_menu.name) + self.assertEquals( + 'datasource_access', updated_role.permissions[2].permission.name) + + def test_override_role_permissions_drops_absent_perms(self): + override_me = sm.find_role('override_me') + override_me.permissions.append( + sm.find_permission_view_menu( + view_menu_name=self.get_table_by_name('long_lat').perm, + permission_name='datasource_access') + ) + db.session.flush() + + response = self.client.post( + '/caravel/override_role_permissions/', + data=json.dumps(ROLE_TABLES_PERM_DATA), + content_type='application/json') + self.assertEquals(201, response.status_code) + updated_override_me = sm.find_role('override_me') + self.assertEquals(1, len(updated_override_me.permissions)) + birth_names = self.get_table_by_name('birth_names') + self.assertEquals( + birth_names.perm, + updated_override_me.permissions[0].view_menu.name) + self.assertEquals( + 'datasource_access', + updated_override_me.permissions[0].permission.name) + def test_approve(self): session = db.session TEST_ROLE_NAME = 'table_role' sm.add_role(TEST_ROLE_NAME) - self.login('admin') def create_access_request(ds_type, ds_name, role_name): ds_class = SourceRegistry.sources[ds_type] @@ -116,6 +245,7 @@ def create_access_request(ds_type, ds_name, role_name): def test_request_access(self): session = db.session + self.logout() self.login(username='gamma') gamma_user = sm.find_user(username='gamma') sm.add_role('dummy_role') diff --git a/tests/base_tests.py b/tests/base_tests.py index 1e22934e27bf4..c841c77cdd583 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -20,8 +20,8 @@ class CaravelTestCase(unittest.TestCase): - requires_examples = False - examples_loaded = False + requires_examples = True + examples_loaded = True def __init__(self, *args, **kwargs): if ( @@ -119,6 +119,15 @@ def get_slice(self, slice_name, session): session.expunge_all() return slc + def get_table_by_name(self, name): + return db.session.query(models.SqlaTable).filter_by( + table_name=name).first() + + def get_druid_ds_by_name(self, name): + return db.session.query(models.DruidDatasource).filter_by( + datasource_name=name).first() + + def get_resp(self, url): """Shortcut to get the parsed results while following redirects""" resp = self.client.get(url, follow_redirects=True) diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index 6a929a6c4292a..33f9bd506dacf 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -20,6 +20,7 @@ class ImportExportTests(CaravelTestCase): def __init__(self, *args, **kwargs): super(ImportExportTests, self).__init__(*args, **kwargs) + db.session.commit() @classmethod def delete_imports(cls): @@ -109,10 +110,6 @@ def get_table(self, table_id): return db.session.query(models.SqlaTable).filter_by( id=table_id).first() - def get_table_by_name(self, name): - return db.session.query(models.SqlaTable).filter_by( - table_name=name).first() - def assert_dash_equals(self, expected_dash, actual_dash): self.assertEquals(expected_dash.slug, actual_dash.slug) self.assertEquals( @@ -221,7 +218,6 @@ def test_import_2_slices_for_same_table(self): self.assert_slice_equals(slc_1, imported_slc_1) self.assertEquals(imported_slc_1.datasource.perm, imported_slc_1.perm) - self.assertEquals(table_id, imported_slc_2.datasource_id) self.assert_slice_equals(slc_2, imported_slc_2) self.assertEquals(imported_slc_2.datasource.perm, imported_slc_2.perm)