Skip to content

Commit

Permalink
datasource access to allow more granular access to tables
Browse files Browse the repository at this point in the history
  • Loading branch information
painyjames committed Jan 20, 2022
1 parent 14b9298 commit 1fa370c
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 7 deletions.
15 changes: 13 additions & 2 deletions superset/databases/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,30 @@ class DatabaseFilter(BaseFilter):
# TODO(bogdan): consider caching.
def schema_access_databases(self) -> Set[str]: # noqa pylint: disable=no-self-use
return {
security_manager.unpack_schema_perm(vm)[0]
security_manager.unpack_perm(vm)[0]
for vm in security_manager.user_view_menu_names("schema_access")
}

def datasource_access_databases( # noqa pylint: disable=no-self-use
self,
) -> Set[str]:
return {
security_manager.unpack_perm(vm)[0]
for vm in security_manager.user_view_menu_names("datasource_access")
}

def apply(self, query: Query, value: Any) -> Query:
if security_manager.can_access_all_databases():
return query
database_perms = security_manager.user_view_menu_names("database_access")
# TODO(bogdan): consider adding datasource access here as well.
schema_access_databases = self.schema_access_databases()

datasource_access_databases = self.datasource_access_databases()

return query.filter(
or_(
self.model.perm.in_(database_perms),
self.model.database_name.in_(schema_access_databases),
self.model.database_name.in_(datasource_access_databases),
)
)
6 changes: 3 additions & 3 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def get_schema_perm( # pylint: disable=no-self-use

return None

def unpack_schema_perm( # pylint: disable=no-self-use
def unpack_perm( # pylint: disable=no-self-use
self, schema_permission: str
) -> Tuple[str, str]:
# [database_name].[schema_name]
Expand Down Expand Up @@ -532,7 +532,7 @@ def get_schemas_accessible_by_user(

# schema_access
accessible_schemas = {
self.unpack_schema_perm(s)[1]
self.unpack_perm(s)[1]
for s in self.user_view_menu_names("schema_access")
if s.startswith(f"[{database}].")
}
Expand Down Expand Up @@ -582,7 +582,7 @@ def get_datasources_accessible_by_user( # pylint: disable=invalid-name
)
if schema:
names = {d.table_name for d in user_datasources if d.schema == schema}
return [d for d in datasource_names if d in names]
return [d for d in datasource_names if d.table in names]

full_names = {d.full_name for d in user_datasources}
return [d for d in datasource_names if f"[{database}].[{d}]" in full_names]
Expand Down
34 changes: 34 additions & 0 deletions tests/integration_tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,40 @@ def test_get_superset_tables_not_allowed(self):
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)

@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_get_superset_tables_allowed(self):
session = db.session
table_name = "energy_usage"
role_name = "dummy_role"
self.logout()
self.login(username="gamma")
gamma_user = security_manager.find_user(username="gamma")
security_manager.add_role(role_name)
dummy_role = security_manager.find_role(role_name)
gamma_user.roles.append(dummy_role)

tbl_id = self.table_ids.get(table_name)
table = db.session.query(SqlaTable).filter(SqlaTable.id == tbl_id).first()
table_perm = table.perm

security_manager.add_permission_role(
dummy_role,
security_manager.find_permission_view_menu("datasource_access", table_perm),
)

session.commit()

example_db = utils.get_example_database()
schema_name = self.default_schema_backend_map[example_db.backend]
uri = f"superset/tables/{example_db.id}/{schema_name}/{table_name}/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)

# cleanup
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.remove(security_manager.find_role(role_name))
session.commit()

def test_get_superset_tables_substr(self):
example_db = utils.get_example_database()
if example_db.backend in {"presto", "hive"}:
Expand Down
6 changes: 4 additions & 2 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,10 @@ def test_get_dataset_related_database_gamma(self):
rv = self.client.get(uri)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
assert response["count"] == 0
assert response["result"] == []

assert response["count"] == 1
main_db = get_main_database()
assert filter(lambda x: x.text == main_db, response["result"]) != []

@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_get_dataset_item(self):
Expand Down

0 comments on commit 1fa370c

Please sign in to comment.