diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 8cb9f45512..c99887b1c4 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -29,6 +29,7 @@ These are the section headers that we use: ### Changed +- API endpoint added to the User router to allow updates to User objects ([#5615](https://github.com/argilla-io/argilla/pull/5615)) - Changed default python version to 3.13. ([#5649](https://github.com/argilla-io/argilla/pull/5649)) - Changed Pydantic version to v2. ([#5666](https://github.com/argilla-io/argilla/pull/5666)) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/users.py b/argilla-server/src/argilla_server/api/handlers/v1/users.py index 7548b54520..728d497eb6 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/users.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/users.py @@ -19,7 +19,7 @@ from argilla_server.api.policies.v1 import UserPolicy, authorize from argilla_server.api.schemas.v1.users import User as UserSchema -from argilla_server.api.schemas.v1.users import UserCreate, Users +from argilla_server.api.schemas.v1.users import UserCreate, Users, UserUpdate from argilla_server.api.schemas.v1.workspaces import Workspaces from argilla_server.contexts import accounts from argilla_server.database import get_async_db @@ -89,6 +89,21 @@ async def delete_user( return await accounts.delete_user(db, user) +@router.patch("/users/{user_id}", status_code=status.HTTP_200_OK, response_model=UserSchema) +async def update_user( + *, + db: AsyncSession = Depends(get_async_db), + user_id: UUID, + user_update: UserUpdate, + current_user: User = Security(auth.get_current_user), +): + user = await User.get_or_raise(db, user_id) + + await authorize(current_user, UserPolicy.update) + + return await accounts.update_user(db, user, user_update.model_dump(exclude_unset=True)) + + @router.get("/users/{user_id}/workspaces", response_model=Workspaces) async def list_user_workspaces( *, diff --git a/argilla-server/src/argilla_server/api/policies/v1/user_policy.py b/argilla-server/src/argilla_server/api/policies/v1/user_policy.py index cd97a843ad..3564ae5e09 100644 --- a/argilla-server/src/argilla_server/api/policies/v1/user_policy.py +++ b/argilla-server/src/argilla_server/api/policies/v1/user_policy.py @@ -28,6 +28,10 @@ async def list(cls, actor: User) -> bool: async def create(cls, actor: User) -> bool: return actor.is_owner + @classmethod + async def update(cls, actor: User) -> bool: + return actor.is_owner + @classmethod async def delete(cls, actor: User) -> bool: return actor.is_owner diff --git a/argilla-server/src/argilla_server/api/schemas/v1/users.py b/argilla-server/src/argilla_server/api/schemas/v1/users.py index f45dbda70c..56ec19bf33 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/users.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/users.py @@ -13,16 +13,35 @@ # limitations under the License. from datetime import datetime -from typing import List, Optional +from typing import Annotated, List, Optional from uuid import UUID from pydantic import BaseModel, Field, constr, ConfigDict +from argilla_server.api.schemas.v1.commons import UpdateSchema from argilla_server.enums import UserRole USER_PASSWORD_MIN_LENGTH = 8 USER_PASSWORD_MAX_LENGTH = 100 +UserFirstName = Annotated[ + constr(min_length=1, strip_whitespace=True), Field(..., description="The first name for the user") +] +UserLastName = Annotated[ + constr(min_length=1, strip_whitespace=True), Field(..., description="The last name for the user") +] +UserUsername = Annotated[str, Field(..., min_length=1, description="The username for the user")] + +UserPassword = Annotated[ + str, + Field( + ..., + min_length=USER_PASSWORD_MIN_LENGTH, + max_length=USER_PASSWORD_MAX_LENGTH, + description="The password for the user", + ), +] + class User(BaseModel): id: UUID @@ -40,11 +59,21 @@ class User(BaseModel): class UserCreate(BaseModel): - username: str = Field(..., min_length=1) - password: str = Field(min_length=USER_PASSWORD_MIN_LENGTH, max_length=USER_PASSWORD_MAX_LENGTH) - first_name: constr(min_length=1, strip_whitespace=True) - last_name: Optional[constr(min_length=1, strip_whitespace=True)] = None + first_name: UserFirstName + last_name: Optional[UserLastName] = None + username: UserUsername + role: Optional[UserRole] = None + password: UserPassword + + +class UserUpdate(UpdateSchema): + __non_nullable_fields__ = {"first_name", "username", "role", "password"} + + first_name: Optional[UserFirstName] = None + last_name: Optional[UserLastName] = None + username: Optional[UserUsername] = None role: Optional[UserRole] = None + password: Optional[UserPassword] = None class Users(BaseModel): diff --git a/argilla-server/src/argilla_server/contexts/accounts.py b/argilla-server/src/argilla_server/contexts/accounts.py index 0b65bdc815..f12fbadf0c 100644 --- a/argilla-server/src/argilla_server/contexts/accounts.py +++ b/argilla-server/src/argilla_server/contexts/accounts.py @@ -154,6 +154,18 @@ async def create_user_with_random_password( return await create_user(db, user_attrs, workspaces) +async def update_user(db: AsyncSession, user: User, user_attrs: dict) -> User: + username = user_attrs.get("username") + if username is not None and username != user.username: + if await get_user_by_username(db, username): + raise UnprocessableEntityError(f"Username {username!r} already exists") + + if "password" in user_attrs: + user_attrs["password_hash"] = hash_password(user_attrs.pop("password")) + + return await user.update(db, **user_attrs) + + async def delete_user(db: AsyncSession, user: User) -> User: return await user.delete(db) diff --git a/argilla-server/tests/unit/api/handlers/v1/users/test_update_user.py b/argilla-server/tests/unit/api/handlers/v1/users/test_update_user.py new file mode 100644 index 0000000000..08615cca4d --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/users/test_update_user.py @@ -0,0 +1,273 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from uuid import UUID, uuid4 +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.contexts import accounts +from argilla_server.enums import UserRole +from argilla_server.models import User +from httpx import AsyncClient +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from tests.factories import UserFactory + + +@pytest.mark.asyncio +class TestUpdateUser: + def url(self, user_id: UUID) -> str: + return f"/api/v1/users/{user_id}" + + async def test_update_user(self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict): + user = await UserFactory.create() + + user_password_hash = user.password_hash + + response = await async_client.patch( + self.url(user.id), + headers=owner_auth_header, + json={ + "first_name": "Updated First Name", + "last_name": "Updated Last Name", + "username": "updated_username", + "role": UserRole.admin, + "password": "new_password", + }, + ) + + assert response.status_code == 200 + + updated_user = (await db.execute(select(User).filter_by(id=user.id))).scalar_one() + assert updated_user.first_name == "Updated First Name" + assert updated_user.last_name == "Updated Last Name" + assert updated_user.username == "updated_username" + assert updated_user.role == UserRole.admin + assert updated_user.password_hash != user_password_hash + + async def test_update_user_password(self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict): + user = await UserFactory.create() + old_password_hash = user.password_hash + + response = await async_client.patch( + self.url(user.id), + headers=owner_auth_header, + json={ + "password": "new_password", + }, + ) + + assert response.status_code == 200 + + assert accounts.verify_password("new_password", user.password_hash) is True + assert accounts.verify_password("new_password", old_password_hash) is False + + async def test_update_user_without_authentication(self, db: AsyncSession, async_client: AsyncClient): + user = await UserFactory.create() + + response = await async_client.patch( + self.url(user.id), + json={ + "first_name": "Updated First Name", + "last_name": "Updated Last Name", + "username": "updated_username", + "role": UserRole.admin, + }, + ) + + assert response.status_code == 401 + + @pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator]) + async def test_update_user_with_unauthorized_role( + self, db: AsyncSession, async_client: AsyncClient, user_role: UserRole + ): + user = await UserFactory.create() + user_with_unauthorized_role = await UserFactory.create(role=user_role) + + response = await async_client.patch( + self.url(user.id), + headers={API_KEY_HEADER_NAME: user_with_unauthorized_role.api_key}, + json={ + "first_name": "Updated First Name", + "last_name": "Updated Last Name", + "username": "updated_username", + "role": UserRole.admin, + }, + ) + + assert response.status_code == 403 + + async def test_update_user_with_nonexistent_user_id(self, async_client: AsyncClient, owner_auth_header: dict): + user_id = uuid4() + + response = await async_client.patch( + self.url(user_id), + headers=owner_auth_header, + json={ + "first_name": "Updated First Name", + "last_name": "Updated Last Name", + "username": "updated_username", + "role": UserRole.admin, + }, + ) + + assert response.status_code == 404 + assert response.json() == {"detail": f"User with id `{user_id}` not found"} + + async def test_update_user_with_invalid_data( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + user = await UserFactory.create() + + response = await async_client.patch( + self.url(user.id), + headers=owner_auth_header, + json={ + "first_name": "", + "last_name": "Updated Last Name", + "username": "updated_username", + "role": "invalid_role", + }, + ) + + assert response.status_code == 422 + + async def test_update_user_with_duplicate_username( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + user1 = await UserFactory.create(username="user1") + user2 = await UserFactory.create(username="user2") + + response = await async_client.patch( + self.url(user2.id), + headers=owner_auth_header, + json={ + "username": user1.username, + }, + ) + + assert response.status_code == 422, response.json() + assert response.json() == {"detail": "Username 'user1' already exists"} + + async def test_update_user_with_none_first_name( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + user = await UserFactory.create() + + response = await async_client.patch( + self.url(user.id), + headers=owner_auth_header, + json={ + "first_name": None, + }, + ) + + assert response.status_code == 422 + assert response.json() == { + "detail": { + "code": "argilla.api.errors::ValidationError", + "params": { + "errors": [ + { + "loc": ["body"], + "msg": "Value error, The following keys must have non-null values: first_name", + "type": "value_error", + } + ] + }, + } + } + + async def test_update_user_with_none_last_name( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + user = await UserFactory.create() + + response = await async_client.patch( + self.url(user.id), + headers=owner_auth_header, + json={ + "last_name": None, + }, + ) + + assert response.status_code == 200 + assert response.json() == { + "id": str(user.id), + "api_key": user.api_key, + "first_name": user.first_name, + "last_name": None, + "username": user.username, + "role": user.role, + "inserted_at": user.inserted_at.isoformat(), + "updated_at": user.updated_at.isoformat(), + } + + async def test_update_user_with_none_username( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + user = await UserFactory.create() + + response = await async_client.patch( + self.url(user.id), + headers=owner_auth_header, + json={ + "username": None, + }, + ) + + assert response.status_code == 422 + assert response.json() == { + "detail": { + "code": "argilla.api.errors::ValidationError", + "params": { + "errors": [ + { + "loc": ["body"], + "msg": "Value error, The following keys must have non-null values: username", + "type": "value_error", + } + ] + }, + } + } + + async def test_update_user_with_none_password( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + user = await UserFactory.create() + + response = await async_client.patch( + self.url(user.id), + headers=owner_auth_header, + json={ + "password": None, + }, + ) + + assert response.status_code == 422 + assert response.json() == { + "detail": { + "code": "argilla.api.errors::ValidationError", + "params": { + "errors": [ + { + "loc": ["body"], + "msg": "Value error, The following keys must have non-null values: password", + "type": "value_error", + } + ] + }, + } + } diff --git a/argilla/CHANGELOG.md b/argilla/CHANGELOG.md index a35dded88a..b0a76c59ff 100644 --- a/argilla/CHANGELOG.md +++ b/argilla/CHANGELOG.md @@ -37,6 +37,7 @@ These are the section headers that we use: ### Changed +- User parameters can now be updated using the client ([#5614](https://github.com/argilla-io/argilla/issues/5614)) - Changed `Dataset.from_hub` method to open configure URL when `settings="ui"`. ([#5622](https://github.com/argilla-io/argilla/pull/5622)) - Terms metadata properties accept other values than `str`. ([#5594](https://github.com/argilla-io/argilla/pull/5594)) - Added support for `with_vectors` while fetching records along with a search query. ([#5638](https://github.com/argilla-io/argilla/pull/5638)) diff --git a/argilla/src/argilla/_api/_users.py b/argilla/src/argilla/_api/_users.py index 30d711fcbf..0063e15f92 100644 --- a/argilla/src/argilla/_api/_users.py +++ b/argilla/src/argilla/_api/_users.py @@ -43,6 +43,15 @@ def create(self, user: UserModel) -> UserModel: return user_created + @api_error_handler + def update(self, user: UserModel) -> UserModel: + json_body = user.model_dump() + response = self.http_client.patch(f"/api/v1/users/{user.id}", json=json_body).raise_for_status() + user_updated = self._model_from_json(response_json=response.json()) + self._log_message(message=f"Updated user {user_updated.username}") + + return user_updated + @api_error_handler def get(self, user_id: UUID) -> UserModel: # TODO: Implement this endpoint in the API diff --git a/argilla/tests/integration/test_manage_users.py b/argilla/tests/integration/test_manage_users.py index 58d716cef2..4671ce3ecd 100644 --- a/argilla/tests/integration/test_manage_users.py +++ b/argilla/tests/integration/test_manage_users.py @@ -16,7 +16,7 @@ import pytest from argilla import User, Argilla, Workspace -from argilla._exceptions import UnprocessableEntityError +from argilla._exceptions import UnprocessableEntityError, ConflictError class TestManageUsers: @@ -46,3 +46,30 @@ def test_add_user_to_workspace(self, client: Argilla, workspace: Workspace): user.add_to_workspace(workspace) assert user in workspace.users + + def test_update_user(self, client: Argilla): + user = User(username=f"test_update_user_{uuid.uuid4()}", password="test_password") + client.users.add(user) + + updated_username = f"updated_user_{uuid.uuid4()}" + user.username = updated_username + user.first_name = "Updated First Name" + user.last_name = "Updated Last Name" + user.role = "admin" + user.update() + + updated_user = client.users(id=user.id) + assert updated_user.username == updated_username + assert updated_user.first_name == "Updated First Name" + assert updated_user.last_name == "Updated Last Name" + assert updated_user.role == "admin" + + def test_update_user_with_duplicate_username(self, client: Argilla): + user1 = User(username=f"test_user1_{uuid.uuid4()}", password="test_password") + user2 = User(username=f"test_user2_{uuid.uuid4()}", password="test_password") + client.users.add(user1) + client.users.add(user2) + + user2.username = user1.username + with pytest.raises(expected_exception=UnprocessableEntityError): + user2.update()