Skip to content

Commit

Permalink
✨Personalized resource limits: only allow specific groups to override…
Browse files Browse the repository at this point in the history
… resources (#4417)
  • Loading branch information
sanderegg authored Jun 29, 2023
1 parent fd95cd3 commit 52fb373
Show file tree
Hide file tree
Showing 32 changed files with 1,100 additions and 77 deletions.
37 changes: 37 additions & 0 deletions api/specs/webserver/openapi-users.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,19 @@ paths:
responses:
'204':
description: Successful Response
/me/permissions:
get:
tags:
- user
summary: List User Permissions
operationId: list_user_permissions
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/Envelope_list_simcore_service_webserver.users.schemas.PermissionGet__'
components:
schemas:
AllUsersGroups:
Expand Down Expand Up @@ -218,6 +231,17 @@ components:
$ref: '#/components/schemas/UserNotification'
error:
title: Error
Envelope_list_simcore_service_webserver.users.schemas.PermissionGet__:
title: Envelope[list[simcore_service_webserver.users.schemas.PermissionGet]]
type: object
properties:
data:
title: Data
type: array
items:
$ref: '#/components/schemas/PermissionGet'
error:
title: Error
Envelope_list_simcore_service_webserver.users.schemas.Token__:
title: Envelope[list[simcore_service_webserver.users.schemas.Token]]
type: object
Expand Down Expand Up @@ -256,6 +280,19 @@ components:
- ANNOTATION_NOTE
type: string
description: An enumeration.
PermissionGet:
title: PermissionGet
required:
- name
- allowed
type: object
properties:
name:
title: Name
type: string
allowed:
title: Allowed
type: boolean
ProfileGet:
title: ProfileGet
required:
Expand Down
3 changes: 3 additions & 0 deletions api/specs/webserver/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ paths:
/me/notifications/{notification_id}:
$ref: "./openapi-users.yaml#/paths/~1me~1notifications~1{notification_id}"

/me/permissions:
$ref: "./openapi-users.yaml#/paths/~1me~1permissions"

# GROUP SETTINGS ------------------------------------------------------------------

/groups:
Expand Down
11 changes: 11 additions & 0 deletions api/specs/webserver/scripts/openapi_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
UserNotificationPatch,
)
from simcore_service_webserver.users.schemas import (
PermissionGet,
ProfileGet,
ProfileUpdate,
Token,
Expand Down Expand Up @@ -131,6 +132,16 @@ async def mark_notification_as_read(
...


@app.get(
"/me/permissions",
response_model=Envelope[list[PermissionGet]],
tags=TAGS,
operation_id="list_user_permissions",
)
async def list_user_permissions():
...


if __name__ == "__main__":
from _common import CURRENT_DIR, create_openapi_specs

Expand Down
4 changes: 2 additions & 2 deletions packages/models-library/src/models_library/clusters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Final, Literal, Union
from typing import Final, Literal, TypeAlias, Union

from pydantic import AnyUrl, BaseModel, Extra, Field, HttpUrl, SecretStr, root_validator
from pydantic.types import NonNegativeInt
Expand Down Expand Up @@ -107,7 +107,7 @@ class Config:
use_enum_values = True


ClusterID = NonNegativeInt
ClusterID: TypeAlias = NonNegativeInt
DEFAULT_CLUSTER_ID: Final[NonNegativeInt] = 0


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def ensure_limits_are_equal_or_above_reservations(cls, values):

return values

def set_reservation_same_as_limit(self) -> None:
self.reservation = self.limit

class Config:
validate_assignment = True

Expand Down Expand Up @@ -82,6 +85,10 @@ class ImageResources(BaseModel):
description="describe how a service shall be booted, using CPU, MPI, openMP or GPU",
)

def set_reservation_same_as_limit(self) -> None:
for resource in self.resources.values():
resource.set_reservation_same_as_limit()

class Config:
schema_extra = {
"example": {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""add override_services_specifications
Revision ID: 38fe651b4196
Revises: 417f9eb848ce
Create Date: 2023-06-23 11:37:04.833354+00:00
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "38fe651b4196"
down_revision = "417f9eb848ce"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"groups_extra_properties",
sa.Column(
"override_services_specifications",
sa.Boolean(),
server_default=sa.text("false"),
nullable=True,
),
)
# ### end Alembic commands ###
groups_extra_properties_table = sa.table(
"groups_extra_properties",
sa.column("group_id"),
sa.column("node_id"),
sa.column("created"),
sa.column("modified"),
sa.column("override_services_specifications"),
)

# default to false
op.execute(
groups_extra_properties_table.update().values(
override_services_specifications=False
)
)
# # set to non nullable
op.alter_column(
"groups_extra_properties", "override_services_specifications", nullable=False
)


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("groups_extra_properties", "override_services_specifications")
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@
"If a user is part of this group, it's "
"service can access the internet.",
),
sa.Column(
"override_services_specifications",
sa.Boolean(),
nullable=False,
server_default=sa.sql.expression.false(),
doc="allows group to override default service specifications.",
),
# TIME STAMPS ----
column_created_datetime(timezone=False),
column_modified_datetime(timezone=False),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import datetime
import logging
from dataclasses import dataclass, fields
from typing import Any

import sqlalchemy as sa
from aiopg.sa.connection import SAConnection
from aiopg.sa.result import RowProxy
from simcore_postgres_database.models.groups_extra_properties import (
groups_extra_properties,
)

from .models.groups import GroupType, groups, user_to_groups
from .utils_models import FromRowMixin

_logger = logging.getLogger(__name__)


class GroupExtraPropertiesError(Exception):
...


class GroupExtraPropertiesNotFound(GroupExtraPropertiesError):
...


@dataclass(frozen=True, slots=True, kw_only=True)
class GroupExtraProperties(FromRowMixin):
group_id: int
product_name: str
internet_access: bool
override_services_specifications: bool
created: datetime.datetime
modified: datetime.datetime


async def _list_table_entries_ordered_by_group_type(
connection: SAConnection, user_id: int, product_name: str
) -> list[RowProxy]:
list_stmt = (
sa.select(
groups_extra_properties,
groups.c.type,
sa.case(
# NOTE: the ordering is important for the aggregation afterwards
(groups.c.type == GroupType.EVERYONE, sa.literal(3)),
(groups.c.type == GroupType.STANDARD, sa.literal(2)),
(groups.c.type == GroupType.PRIMARY, sa.literal(1)),
else_=sa.literal(4),
).label("type_order"),
)
.select_from(
sa.join(
sa.join(
groups_extra_properties,
user_to_groups,
groups_extra_properties.c.group_id == user_to_groups.c.gid,
),
groups,
groups_extra_properties.c.group_id == groups.c.gid,
)
)
.where(
(groups_extra_properties.c.product_name == product_name)
& (user_to_groups.c.uid == user_id)
)
.alias()
)

result = await connection.execute(
sa.select(list_stmt).order_by(list_stmt.c.type_order)
)
assert result # nosec

rows: list[RowProxy] | None = await result.fetchall()
assert rows is not None # nosec
return rows


def _merge_extra_properties_booleans(
instance1: GroupExtraProperties, instance2: GroupExtraProperties
) -> GroupExtraProperties:
merged_properties: dict[str, Any] = {}
for field in fields(instance1):
value1 = getattr(instance1, field.name)
value2 = getattr(instance2, field.name)

if isinstance(value1, bool):
merged_properties[field.name] = value1 or value2
else:
merged_properties[field.name] = value1
return GroupExtraProperties(**merged_properties) # pylint: disable=missing-kwoa


@dataclass(frozen=True, slots=True, kw_only=True)
class GroupExtraPropertiesRepo:
@staticmethod
async def get(
connection: SAConnection, *, gid: int, product_name: str
) -> GroupExtraProperties:
get_stmt = sa.select(groups_extra_properties).where(
(groups_extra_properties.c.group_id == gid)
& (groups_extra_properties.c.product_name == product_name)
)
result = await connection.execute(get_stmt)
assert result # nosec
if row := await result.first():
return GroupExtraProperties.from_row(row)
raise GroupExtraPropertiesNotFound(f"Properties for group {gid} not found")

@staticmethod
async def get_aggregated_properties_for_user(
connection: SAConnection,
*,
user_id: int,
product_name: str,
) -> GroupExtraProperties:
rows = await _list_table_entries_ordered_by_group_type(
connection, user_id, product_name
)
merged_standard_extra_properties = None
for row in rows:
group_extra_properties = GroupExtraProperties.from_row(row)
match row.type:
case GroupType.PRIMARY:
# this always has highest priority
return group_extra_properties
case GroupType.STANDARD:
if merged_standard_extra_properties:
merged_standard_extra_properties = (
_merge_extra_properties_booleans(
merged_standard_extra_properties,
group_extra_properties,
)
)
else:
merged_standard_extra_properties = group_extra_properties
case GroupType.EVERYONE:
# if there are standard properties, they take precedence
return (
merged_standard_extra_properties
if merged_standard_extra_properties
else group_extra_properties
)
case _:
_logger.warning(
"Unexpected GroupType found in %s db table! Please adapt code here!",
groups_extra_properties.name,
)
if merged_standard_extra_properties:
return merged_standard_extra_properties
raise GroupExtraPropertiesNotFound(
f"Properties for user {user_id} in {product_name} not found"
)
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import fields, is_dataclass
from typing import TypeVar

from aiopg.sa.result import RowProxy
Expand All @@ -10,4 +11,6 @@ class FromRowMixin:

@classmethod
def from_row(cls: type[ModelType], row: RowProxy) -> ModelType:
return cls(**dict(row.items()))
assert is_dataclass(cls) # nosec
field_names = [f.name for f in fields(cls)]
return cls(**{k: v for k, v in row.items() if k in field_names}) # type: ignore[return-value]
2 changes: 1 addition & 1 deletion packages/postgres-database/tests/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ services:
# - net.ipv4.tcp_keepalive_probes=9
# - net.ipv4.tcp_keepalive_time=600
adminer:
image: adminer:4.8.0
image: adminer:4.8.1
init: true
environment:
- ADMINER_DEFAULT_SERVER=postgres
Expand Down
Loading

0 comments on commit 52fb373

Please sign in to comment.