Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RBAC resource sharing #2320

Merged
merged 11 commits into from
Jan 22, 2024
2 changes: 1 addition & 1 deletion examples/e2e/.copier-answers.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Changes here will be overwritten by Copier
_commit: 2024.01.17-2-g9c82435
_commit: 2024.01.18
_src_path: gh:zenml-io/template-e2e-batch
data_quality_checks: true
email: ''
Expand Down
2 changes: 1 addition & 1 deletion examples/e2e_nlp/.copier-answers.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Changes here will be overwritten by Copier
_commit: 0.45.0-2-gdb7862c
_commit: 2024.01.12
_src_path: gh:zenml-io/template-nlp
accelerator: cpu
cloud_of_choice: aws
Expand Down
2 changes: 1 addition & 1 deletion examples/quickstart/.copier-answers.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Changes here will be overwritten by Copier
_commit: 2023.12.18-3-g5b0d7c9
_commit: 2024.01.12
_src_path: gh:zenml-io/template-starter
email: ''
full_name: ZenML GmbH
Expand Down
2 changes: 2 additions & 0 deletions src/zenml/zen_server/rbac/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class Action(StrEnum):
# Secrets
BACKUP_RESTORE = "backup_restore"

SHARE = "share"


class ResourceType(StrEnum):
"""Resource types of the server API."""
Expand Down
13 changes: 13 additions & 0 deletions src/zenml/zen_server/rbac/rbac_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,16 @@ def list_allowed_resource_ids(
will contain the list of instance IDs that the user can perform
the action on.
"""

@abstractmethod
def update_resource_membership(
self, user: "UserResponse", resource: Resource, actions: List[Action]
) -> None:
"""Update the resource membership of a user.

Args:
user: User for which the resource membership should be updated.
resource: The resource.
actions: The actions that the user should be able to perform on the
resource.
"""
87 changes: 82 additions & 5 deletions src/zenml/zen_server/rbac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,29 @@
"""RBAC utility functions."""

from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Sequence,
Set,
Type,
TypeVar,
)
from uuid import UUID

from pydantic import BaseModel

from zenml.exceptions import IllegalOperationError
from zenml.models import (
BaseResponse,
Page,
UserScopedResponse,
)
from zenml.models import BaseResponse, Page, UserResponse, UserScopedResponse
from zenml.zen_server.auth import get_auth_context
from zenml.zen_server.rbac.models import Action, Resource, ResourceType
from zenml.zen_server.utils import rbac, server_config

if TYPE_CHECKING:
from zenml.zen_stores.schemas import BaseSchema

AnyResponse = TypeVar("AnyResponse", bound=BaseResponse) # type: ignore[type-arg]
AnyModel = TypeVar("AnyModel", bound=BaseModel)

Expand Down Expand Up @@ -500,3 +501,79 @@ def _get_subresources_for_value(value: Any) -> Set[Resource]:
return set.union(*resources_list) if resources_list else set()
else:
return set()


def get_schema_for_resource_type(
resource_type: ResourceType,
) -> Type["BaseSchema"]:
"""Get the database schema for a resource type.

Args:
resource_type: The resource type for which to get the database schema.

Returns:
The database schema.
"""
from zenml.zen_stores.schemas import (
ArtifactSchema,
ArtifactVersionSchema,
CodeRepositorySchema,
FlavorSchema,
ModelSchema,
ModelVersionSchema,
PipelineBuildSchema,
PipelineDeploymentSchema,
PipelineRunSchema,
PipelineSchema,
RunMetadataSchema,
SecretSchema,
ServiceConnectorSchema,
StackComponentSchema,
StackSchema,
TagSchema,
UserSchema,
WorkspaceSchema,
)

mapping: Dict[ResourceType, Type["BaseSchema"]] = {
ResourceType.STACK: StackSchema,
ResourceType.FLAVOR: FlavorSchema,
ResourceType.STACK_COMPONENT: StackComponentSchema,
ResourceType.PIPELINE: PipelineSchema,
ResourceType.CODE_REPOSITORY: CodeRepositorySchema,
ResourceType.MODEL: ModelSchema,
ResourceType.MODEL_VERSION: ModelVersionSchema,
ResourceType.SERVICE_CONNECTOR: ServiceConnectorSchema,
ResourceType.ARTIFACT: ArtifactSchema,
ResourceType.ARTIFACT_VERSION: ArtifactVersionSchema,
ResourceType.SECRET: SecretSchema,
ResourceType.TAG: TagSchema,
ResourceType.SERVICE_ACCOUNT: UserSchema,
ResourceType.WORKSPACE: WorkspaceSchema,
ResourceType.PIPELINE_RUN: PipelineRunSchema,
ResourceType.PIPELINE_DEPLOYMENT: PipelineDeploymentSchema,
ResourceType.PIPELINE_BUILD: PipelineBuildSchema,
ResourceType.RUN_METADATA: RunMetadataSchema,
ResourceType.USER: UserSchema,
}

return mapping[resource_type]


def update_resource_membership(
user: UserResponse, resource: Resource, actions: List[Action]
) -> None:
"""Update the resource membership of a user.

Args:
user: User for which the resource membership should be updated.
resource: The resource.
actions: The actions that the user should be able to perform on the
resource.
"""
if not server_config().rbac_enabled:
return

rbac().update_resource_membership(
user=user, resource=resource, actions=actions
)
64 changes: 64 additions & 0 deletions src/zenml/zen_server/rbac/zenml_cloud_rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ZENML_CLOUD_RBAC_ENV_PREFIX = "ZENML_CLOUD_"
PERMISSIONS_ENDPOINT = "/rbac/check_permissions"
ALLOWED_RESOURCE_IDS_ENDPOINT = "/rbac/allowed_resource_ids"
RESOURCE_MEMBERSHIP_ENDPOINT = "/rbac/resource_members"

SERVER_SCOPE_IDENTIFIER = "server"

Expand Down Expand Up @@ -211,6 +212,28 @@ def list_allowed_resource_ids(

return full_resource_access, allowed_ids

def update_resource_membership(
self, user: "UserResponse", resource: Resource, actions: List[Action]
AlexejPenner marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""Update the resource membership of a user.

Args:
user: User for which the resource membership should be updated.
resource: The resource.
actions: The actions that the user should be able to perform on the
resource.
"""
if user.is_service_account:
# Service accounts have full permissions for now
return

data = {
"user_id": str(user.external_user_id),
"resource": _convert_to_cloud_resource(resource),
"actions": [str(action) for action in actions],
}
self._post(endpoint=RESOURCE_MEMBERSHIP_ENDPOINT, data=data)

def _get(self, endpoint: str, params: Dict[str, Any]) -> requests.Response:
"""Send a GET request using the active session.

Expand Down Expand Up @@ -242,6 +265,47 @@ def _get(self, endpoint: str, params: Dict[str, Any]) -> requests.Response:

return response

def _post(
self,
endpoint: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
) -> requests.Response:
"""Send a POST request using the active session.

Args:
endpoint: The endpoint to send the request to. This will be appended
to the base URL.
params: Parameters to include in the request.
data: Data to include in the request.

Raises:
RuntimeError: If the request failed.

Returns:
The response.
"""
url = self._config.api_url + endpoint

response = self.session.post(
url=url, params=params, json=data, timeout=7
)
if response.status_code == 401:
# Refresh the auth token and try again
self._clear_session()
response = self.session.post(
url=url, params=params, json=data, timeout=7
)

try:
response.raise_for_status()
except requests.HTTPError as e:
raise RuntimeError(
f"Failed while trying to contact RBAC service: {e}"
)

return response

@property
def session(self) -> requests.Session:
"""Authenticate to the ZenML Cloud API.
Expand Down
83 changes: 81 additions & 2 deletions src/zenml/zen_server/routers/users_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# permissions and limitations under the License.
"""Endpoint definitions for users."""

from typing import Optional, Union
from typing import List, Optional, Union
from uuid import UUID

from fastapi import APIRouter, Depends, Security
Expand Down Expand Up @@ -46,11 +46,14 @@
from zenml.zen_server.rbac.endpoint_utils import (
verify_permissions_and_create_entity,
)
from zenml.zen_server.rbac.models import Action, ResourceType
from zenml.zen_server.rbac.models import Action, Resource, ResourceType
from zenml.zen_server.rbac.utils import (
dehydrate_page,
dehydrate_response_model,
get_allowed_resource_ids,
get_schema_for_resource_type,
update_resource_membership,
verify_permission,
verify_permission_for_model,
)
from zenml.zen_server.utils import (
Expand Down Expand Up @@ -462,3 +465,79 @@ def update_myself(
user_id=auth_context.user.id, user_update=user
)
return dehydrate_response_model(updated_user)


if server_config().rbac_enabled:

@router.post(
"/{user_name_or_id}/resource_membership",
responses={
401: error_response,
404: error_response,
422: error_response,
},
)
@handle_exceptions
def update_user_resource_membership(
Copy link
Contributor

@AlexejPenner AlexejPenner Jan 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct me if I'm wrong. This handles adding a Permission to a User for a Resource. How would I remove a permission?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You would remove a permission by sending an empty list of actions. I think I might have to document that in the docstring though, or do you think having two separate endpoints makes more sense. It's definitely a little confusing with the RBACInterface.share_resource(...) method name though, maybe that should have two separate methods or it should also be called RBACInterface.update_resource_membership(...)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for the renaming

user_name_or_id: Union[str, UUID],
resource_type: str,
resource_id: UUID,
actions: List[str],
auth_context: AuthContext = Security(authorize),
) -> None:
"""Updates resource memberships of a user.

Args:
user_name_or_id: Name or ID of the user.
resource_type: Type of the resource for which to update the
membership.
resource_id: ID of the resource for which to update the membership.
actions: List of actions that the user should be able to perform on
the resource. If the user currently has permissions to perform
actions which are not passed in this list, the permissions will
be removed.
auth_context: Authentication context.

Raises:
ValueError: If a user tries to update their own membership.
KeyError: If no resource with the given type and ID exists.
"""
user = zen_store().get_user(user_name_or_id)
verify_permission_for_model(user, action=Action.READ)

if user.id == auth_context.user.id:
raise ValueError(
"Not allowed to call endpoint with the authenticated user."
)

resource_type = ResourceType(resource_type)
resource = Resource(type=resource_type, id=resource_id)

schema_class = get_schema_for_resource_type(resource_type)
if not zen_store().object_exists(
object_id=resource_id, schema_class=schema_class
):
raise KeyError(
f"Resource of type {resource_type} with ID {resource_id} does "
"not exist."
)

verify_permission(
resource_type=resource_type,
action=Action.SHARE,
resource_id=resource_id,
)
for action in actions:
# Make sure users aren't able to share permissions they don't have
# themselves
verify_permission(
resource_type=resource_type,
action=Action(action),
resource_id=resource_id,
)

update_resource_membership(
user=user,
resource=resource,
actions=[Action(action) for action in actions],
)
19 changes: 19 additions & 0 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6884,6 +6884,25 @@ def _count_entity(

return int(entity_count)

def object_exists(
self, object_id: UUID, schema_class: Type[AnySchema]
) -> bool:
"""Check whether an object exists in the database.

Args:
object_id: The ID of the object to check.
schema_class: The schema class.

Returns:
If the object exists.
"""
with Session(self.engine) as session:
schema = session.exec(
select(schema_class.id).where(schema_class.id == object_id)
).first()

return False if schema is None else True

@staticmethod
def _get_schema_by_name_or_id(
object_name_or_id: Union[str, UUID],
Expand Down
Loading