Skip to content

Commit

Permalink
Uniquely constrained users table (#2483)
Browse files Browse the repository at this point in the history
* use db constrains for users

* lint a bit

* coderabbitai

* set rate limits for login api

* Revert "set rate limits for login api"

This reverts commit 78ce40e.

* bring checks back

* Auto-update of Starter template

* Auto-update of NLP template

* review suggestions

* resolve branching

* Auto-update of Starter template

* Auto-update of NLP template

* lint

---------

Co-authored-by: GitHub Actions <[email protected]>
  • Loading branch information
avishniakov and actions-user authored Mar 4, 2024
1 parent 12d68dd commit afcaf74
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 34 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""unique users [72675226b2de].
Revision ID: 72675226b2de
Revises: 0.55.4
Create Date: 2024-02-29 14:58:25.584731
"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "72675226b2de"
down_revision = "0.55.4"
branch_labels = None
depends_on = None


def upgrade() -> None:
"""Upgrade database schema and/or data, creating a new revision."""
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.create_unique_constraint(
"uq_user_name_is_service_account", ["name", "is_service_account"]
)


def downgrade() -> None:
"""Downgrade database schema and/or data back to the previous revision."""
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.drop_constraint(
"uq_user_name_is_service_account", type_="unique"
)
3 changes: 2 additions & 1 deletion src/zenml/zen_stores/schemas/user_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, Any, List, Optional
from uuid import UUID

from sqlalchemy import TEXT, Column
from sqlalchemy import TEXT, Column, UniqueConstraint
from sqlmodel import Field, Relationship

from zenml.models import (
Expand Down Expand Up @@ -65,6 +65,7 @@ class UserSchema(NamedSchema, table=True):
"""SQL Model for users."""

__tablename__ = "user"
__table_args__ = (UniqueConstraint("name", "is_service_account"),)

is_service_account: bool = Field(default=False)
full_name: str
Expand Down
66 changes: 35 additions & 31 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5146,28 +5146,29 @@ def create_service_account(
already exists.
"""
with Session(self.engine) as session:
# Check if a service account with the given name already
# exists
err_msg = (
f"Unable to create service account with name "
f"'{service_account.name}': Found existing service "
"account with this name."
)
try:
self._get_account_schema(
service_account.name, session=session, service_account=True
)
raise EntityExistsError(err_msg)
except KeyError:
pass

# Create the service account
new_account = UserSchema.from_service_account_request(
service_account
)
session.add(new_account)
# on commit an IntegrityError may arise we let it bubble up
session.commit()

# Check if a service account with the given name already
# exists
service_accounts = session.execute(
select(UserSchema).where(
UserSchema.name == service_account.name,
UserSchema.is_service_account.is_(True), # type: ignore[attr-defined]
)
).fetchall()
if len(service_accounts) == 1:
session.commit()
else:
raise EntityExistsError(
f"Unable to create service account with name "
f"'{service_account.name}': Found existing service "
"account with this name."
)
return new_account.to_service_account_model(include_metadata=True)

def get_service_account(
Expand Down Expand Up @@ -7357,24 +7358,27 @@ def create_user(self, user: UserRequest) -> UserResponse:
already exists.
"""
with Session(self.engine) as session:
# Check if a user account with the given name already exists
err_msg = (
f"Unable to create user with name '{user.name}': "
f"Found an existing user account with this name."
)
try:
self._get_account_schema(
user.name,
session=session,
# Filter out service accounts
service_account=False,
)
raise EntityExistsError(err_msg)
except KeyError:
pass

# Create the user
new_user = UserSchema.from_user_request(user)
session.add(new_user)

# Check if a user account with the given name already exists
users = session.execute(
select(UserSchema).where(
UserSchema.name == user.name,
UserSchema.is_service_account.is_(False), # type: ignore[attr-defined]
)
).fetchall()
if len(users) == 1:
session.commit()
else:
raise EntityExistsError(
f"Unable to create user with name '{user.name}': "
f"Found an existing user account with this name."
)
# on commit an IntegrityError may arise we let it bubble up
session.commit()
return new_user.to_model(include_metadata=True)

def get_user(
Expand Down
5 changes: 3 additions & 2 deletions tests/integration/functional/zen_stores/test_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pytest
from pydantic import SecretStr
from sqlalchemy.exc import IntegrityError

from tests.integration.functional.utils import sample_name
from tests.integration.functional.zen_stores.utils import (
Expand Down Expand Up @@ -422,7 +423,7 @@ def silent_create_user(user_request: UserRequest):
"""
try:
clean_client.zen_store.create_user(user_request)
except EntityExistsError:
except (EntityExistsError, IntegrityError):
pass

user_name = "test_user"
Expand Down Expand Up @@ -463,7 +464,7 @@ def silent_create_service_account(
clean_client.zen_store.create_service_account(
service_account_request
)
except EntityExistsError:
except (EntityExistsError, IntegrityError):
pass

user_name = "test_user"
Expand Down

0 comments on commit afcaf74

Please sign in to comment.