Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uniquely constrained users table #2483

Merged
merged 15 commits into from
Mar 4, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""unique users [72675226b2de].
Revision ID: 72675226b2de
Revises: 0.55.4
Create Date: 2024-02-29 14:58:25.584731
"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "72675226b2de"
down_revision = "0.55.4"
branch_labels = None
depends_on = None


def upgrade() -> None:
"""Upgrade database schema and/or data, creating a new revision."""
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.create_unique_constraint(
"uq_user_name_is_service_account", ["name", "is_service_account"]
)


def downgrade() -> None:
"""Downgrade database schema and/or data back to the previous revision."""
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.drop_constraint(
"uq_user_name_is_service_account", type_="unique"
)
3 changes: 2 additions & 1 deletion src/zenml/zen_stores/schemas/user_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, Any, List, Optional
from uuid import UUID

from sqlalchemy import TEXT, Column
from sqlalchemy import TEXT, Column, UniqueConstraint
from sqlmodel import Field, Relationship

from zenml.models import (
Expand Down Expand Up @@ -65,6 +65,7 @@ class UserSchema(NamedSchema, table=True):
"""SQL Model for users."""

__tablename__ = "user"
__table_args__ = (UniqueConstraint("name", "is_service_account"),)

is_service_account: bool = Field(default=False)
full_name: str
Expand Down
66 changes: 35 additions & 31 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5146,28 +5146,29 @@ def create_service_account(
already exists.
"""
with Session(self.engine) as session:
# Check if a service account with the given name already
# exists
err_msg = (
f"Unable to create service account with name "
f"'{service_account.name}': Found existing service "
"account with this name."
)
try:
self._get_account_schema(
service_account.name, session=session, service_account=True
)
raise EntityExistsError(err_msg)
except KeyError:
pass

# Create the service account
new_account = UserSchema.from_service_account_request(
service_account
)
session.add(new_account)
# on commit an IntegrityError may arise we let it bubble up
session.commit()

# Check if a service account with the given name already
# exists
service_accounts = session.execute(
select(UserSchema).where(
UserSchema.name == service_account.name,
UserSchema.is_service_account.is_(True), # type: ignore[attr-defined]
)
).fetchall()
if len(service_accounts) == 1:
session.commit()
else:
raise EntityExistsError(
f"Unable to create service account with name "
f"'{service_account.name}': Found existing service "
"account with this name."
)
return new_account.to_service_account_model(include_metadata=True)

def get_service_account(
Expand Down Expand Up @@ -7357,24 +7358,27 @@ def create_user(self, user: UserRequest) -> UserResponse:
already exists.
"""
with Session(self.engine) as session:
# Check if a user account with the given name already exists
err_msg = (
f"Unable to create user with name '{user.name}': "
f"Found an existing user account with this name."
)
try:
self._get_account_schema(
user.name,
session=session,
# Filter out service accounts
service_account=False,
)
raise EntityExistsError(err_msg)
except KeyError:
pass

# Create the user
new_user = UserSchema.from_user_request(user)
session.add(new_user)

# Check if a user account with the given name already exists
users = session.execute(
select(UserSchema).where(
UserSchema.name == user.name,
UserSchema.is_service_account.is_(False), # type: ignore[attr-defined]
)
).fetchall()
if len(users) == 1:
session.commit()
else:
raise EntityExistsError(
f"Unable to create user with name '{user.name}': "
f"Found an existing user account with this name."
)
# on commit an IntegrityError may arise we let it bubble up
session.commit()
return new_user.to_model(include_metadata=True)

def get_user(
Expand Down
5 changes: 3 additions & 2 deletions tests/integration/functional/zen_stores/test_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pytest
from pydantic import SecretStr
from sqlalchemy.exc import IntegrityError

from tests.integration.functional.utils import sample_name
from tests.integration.functional.zen_stores.utils import (
Expand Down Expand Up @@ -422,7 +423,7 @@ def silent_create_user(user_request: UserRequest):
"""
try:
clean_client.zen_store.create_user(user_request)
except EntityExistsError:
except (EntityExistsError, IntegrityError):
pass

user_name = "test_user"
Expand Down Expand Up @@ -463,7 +464,7 @@ def silent_create_service_account(
clean_client.zen_store.create_service_account(
service_account_request
)
except EntityExistsError:
except (EntityExistsError, IntegrityError):
pass

user_name = "test_user"
Expand Down
Loading