From 62a1fd0e774603c4cd1f88c464d4a5e4f78b0807 Mon Sep 17 00:00:00 2001 From: sean-hickey-wf <102522212+sean-hickey-wf@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:29:31 +0100 Subject: [PATCH] [FEATURE]: Adding Functionality To Update Users (#5615) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description Argilla offers the ability to create and delete users but not the ability to update a User object after it has been created. For example, if we want to update the Role of a user after they have been created (from annotator to admin for example), this is not possible without deleting and recreating the User. This PR adds an update endpoint to the FastAPI server and also the convenience of doing this through the python sdk also Closes # **Type of change** - New feature (non-breaking change which adds functionality) - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** Tests have been added at both the server and SDK level to ensure that the update method is working as expected **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Paco Aranda Co-authored-by: Paco Aranda Co-authored-by: Francisco Aranda Co-authored-by: José Francisco Calvo --- argilla-server/CHANGELOG.md | 1 + .../argilla_server/api/handlers/v1/users.py | 17 +- .../api/policies/v1/user_policy.py | 4 + .../argilla_server/api/schemas/v1/users.py | 39 ++- .../src/argilla_server/contexts/accounts.py | 12 + .../api/handlers/v1/users/test_update_user.py | 273 ++++++++++++++++++ argilla/CHANGELOG.md | 1 + argilla/src/argilla/_api/_users.py | 9 + .../tests/integration/test_manage_users.py | 29 +- 9 files changed, 378 insertions(+), 7 deletions(-) create mode 100644 argilla-server/tests/unit/api/handlers/v1/users/test_update_user.py diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 8cb9f45512..c99887b1c4 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -29,6 +29,7 @@ These are the section headers that we use: ### Changed +- API endpoint added to the User router to allow updates to User objects ([#5615](https://github.com/argilla-io/argilla/pull/5615)) - Changed default python version to 3.13. ([#5649](https://github.com/argilla-io/argilla/pull/5649)) - Changed Pydantic version to v2. ([#5666](https://github.com/argilla-io/argilla/pull/5666)) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/users.py b/argilla-server/src/argilla_server/api/handlers/v1/users.py index 7548b54520..728d497eb6 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/users.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/users.py @@ -19,7 +19,7 @@ from argilla_server.api.policies.v1 import UserPolicy, authorize from argilla_server.api.schemas.v1.users import User as UserSchema -from argilla_server.api.schemas.v1.users import UserCreate, Users +from argilla_server.api.schemas.v1.users import UserCreate, Users, UserUpdate from argilla_server.api.schemas.v1.workspaces import Workspaces from argilla_server.contexts import accounts from argilla_server.database import get_async_db @@ -89,6 +89,21 @@ async def delete_user( return await accounts.delete_user(db, user) +@router.patch("/users/{user_id}", status_code=status.HTTP_200_OK, response_model=UserSchema) +async def update_user( + *, + db: AsyncSession = Depends(get_async_db), + user_id: UUID, + user_update: UserUpdate, + current_user: User = Security(auth.get_current_user), +): + user = await User.get_or_raise(db, user_id) + + await authorize(current_user, UserPolicy.update) + + return await accounts.update_user(db, user, user_update.model_dump(exclude_unset=True)) + + @router.get("/users/{user_id}/workspaces", response_model=Workspaces) async def list_user_workspaces( *, diff --git a/argilla-server/src/argilla_server/api/policies/v1/user_policy.py b/argilla-server/src/argilla_server/api/policies/v1/user_policy.py index cd97a843ad..3564ae5e09 100644 --- a/argilla-server/src/argilla_server/api/policies/v1/user_policy.py +++ b/argilla-server/src/argilla_server/api/policies/v1/user_policy.py @@ -28,6 +28,10 @@ async def list(cls, actor: User) -> bool: async def create(cls, actor: User) -> bool: return actor.is_owner + @classmethod + async def update(cls, actor: User) -> bool: + return actor.is_owner + @classmethod async def delete(cls, actor: User) -> bool: return actor.is_owner diff --git a/argilla-server/src/argilla_server/api/schemas/v1/users.py b/argilla-server/src/argilla_server/api/schemas/v1/users.py index f45dbda70c..56ec19bf33 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/users.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/users.py @@ -13,16 +13,35 @@ # limitations under the License. from datetime import datetime -from typing import List, Optional +from typing import Annotated, List, Optional from uuid import UUID from pydantic import BaseModel, Field, constr, ConfigDict +from argilla_server.api.schemas.v1.commons import UpdateSchema from argilla_server.enums import UserRole USER_PASSWORD_MIN_LENGTH = 8 USER_PASSWORD_MAX_LENGTH = 100 +UserFirstName = Annotated[ + constr(min_length=1, strip_whitespace=True), Field(..., description="The first name for the user") +] +UserLastName = Annotated[ + constr(min_length=1, strip_whitespace=True), Field(..., description="The last name for the user") +] +UserUsername = Annotated[str, Field(..., min_length=1, description="The username for the user")] + +UserPassword = Annotated[ + str, + Field( + ..., + min_length=USER_PASSWORD_MIN_LENGTH, + max_length=USER_PASSWORD_MAX_LENGTH, + description="The password for the user", + ), +] + class User(BaseModel): id: UUID @@ -40,11 +59,21 @@ class User(BaseModel): class UserCreate(BaseModel): - username: str = Field(..., min_length=1) - password: str = Field(min_length=USER_PASSWORD_MIN_LENGTH, max_length=USER_PASSWORD_MAX_LENGTH) - first_name: constr(min_length=1, strip_whitespace=True) - last_name: Optional[constr(min_length=1, strip_whitespace=True)] = None + first_name: UserFirstName + last_name: Optional[UserLastName] = None + username: UserUsername + role: Optional[UserRole] = None + password: UserPassword + + +class UserUpdate(UpdateSchema): + __non_nullable_fields__ = {"first_name", "username", "role", "password"} + + first_name: Optional[UserFirstName] = None + last_name: Optional[UserLastName] = None + username: Optional[UserUsername] = None role: Optional[UserRole] = None + password: Optional[UserPassword] = None class Users(BaseModel): diff --git a/argilla-server/src/argilla_server/contexts/accounts.py b/argilla-server/src/argilla_server/contexts/accounts.py index 0b65bdc815..f12fbadf0c 100644 --- a/argilla-server/src/argilla_server/contexts/accounts.py +++ b/argilla-server/src/argilla_server/contexts/accounts.py @@ -154,6 +154,18 @@ async def create_user_with_random_password( return await create_user(db, user_attrs, workspaces) +async def update_user(db: AsyncSession, user: User, user_attrs: dict) -> User: + username = user_attrs.get("username") + if username is not None and username != user.username: + if await get_user_by_username(db, username): + raise UnprocessableEntityError(f"Username {username!r} already exists") + + if "password" in user_attrs: + user_attrs["password_hash"] = hash_password(user_attrs.pop("password")) + + return await user.update(db, **user_attrs) + + async def delete_user(db: AsyncSession, user: User) -> User: return await user.delete(db) diff --git a/argilla-server/tests/unit/api/handlers/v1/users/test_update_user.py b/argilla-server/tests/unit/api/handlers/v1/users/test_update_user.py new file mode 100644 index 0000000000..08615cca4d --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/users/test_update_user.py @@ -0,0 +1,273 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from uuid import UUID, uuid4 +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.contexts import accounts +from argilla_server.enums import UserRole +from argilla_server.models import User +from httpx import AsyncClient +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from tests.factories import UserFactory + + +@pytest.mark.asyncio +class TestUpdateUser: + def url(self, user_id: UUID) -> str: + return f"/api/v1/users/{user_id}" + + async def test_update_user(self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict): + user = await UserFactory.create() + + user_password_hash = user.password_hash + + response = await async_client.patch( + self.url(user.id), + headers=owner_auth_header, + json={ + "first_name": "Updated First Name", + "last_name": "Updated Last Name", + "username": "updated_username", + "role": UserRole.admin, + "password": "new_password", + }, + ) + + assert response.status_code == 200 + + updated_user = (await db.execute(select(User).filter_by(id=user.id))).scalar_one() + assert updated_user.first_name == "Updated First Name" + assert updated_user.last_name == "Updated Last Name" + assert updated_user.username == "updated_username" + assert updated_user.role == UserRole.admin + assert updated_user.password_hash != user_password_hash + + async def test_update_user_password(self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict): + user = await UserFactory.create() + old_password_hash = user.password_hash + + response = await async_client.patch( + self.url(user.id), + headers=owner_auth_header, + json={ + "password": "new_password", + }, + ) + + assert response.status_code == 200 + + assert accounts.verify_password("new_password", user.password_hash) is True + assert accounts.verify_password("new_password", old_password_hash) is False + + async def test_update_user_without_authentication(self, db: AsyncSession, async_client: AsyncClient): + user = await UserFactory.create() + + response = await async_client.patch( + self.url(user.id), + json={ + "first_name": "Updated First Name", + "last_name": "Updated Last Name", + "username": "updated_username", + "role": UserRole.admin, + }, + ) + + assert response.status_code == 401 + + @pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator]) + async def test_update_user_with_unauthorized_role( + self, db: AsyncSession, async_client: AsyncClient, user_role: UserRole + ): + user = await UserFactory.create() + user_with_unauthorized_role = await UserFactory.create(role=user_role) + + response = await async_client.patch( + self.url(user.id), + headers={API_KEY_HEADER_NAME: user_with_unauthorized_role.api_key}, + json={ + "first_name": "Updated First Name", + "last_name": "Updated Last Name", + "username": "updated_username", + "role": UserRole.admin, + }, + ) + + assert response.status_code == 403 + + async def test_update_user_with_nonexistent_user_id(self, async_client: AsyncClient, owner_auth_header: dict): + user_id = uuid4() + + response = await async_client.patch( + self.url(user_id), + headers=owner_auth_header, + json={ + "first_name": "Updated First Name", + "last_name": "Updated Last Name", + "username": "updated_username", + "role": UserRole.admin, + }, + ) + + assert response.status_code == 404 + assert response.json() == {"detail": f"User with id `{user_id}` not found"} + + async def test_update_user_with_invalid_data( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + user = await UserFactory.create() + + response = await async_client.patch( + self.url(user.id), + headers=owner_auth_header, + json={ + "first_name": "", + "last_name": "Updated Last Name", + "username": "updated_username", + "role": "invalid_role", + }, + ) + + assert response.status_code == 422 + + async def test_update_user_with_duplicate_username( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + user1 = await UserFactory.create(username="user1") + user2 = await UserFactory.create(username="user2") + + response = await async_client.patch( + self.url(user2.id), + headers=owner_auth_header, + json={ + "username": user1.username, + }, + ) + + assert response.status_code == 422, response.json() + assert response.json() == {"detail": "Username 'user1' already exists"} + + async def test_update_user_with_none_first_name( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + user = await UserFactory.create() + + response = await async_client.patch( + self.url(user.id), + headers=owner_auth_header, + json={ + "first_name": None, + }, + ) + + assert response.status_code == 422 + assert response.json() == { + "detail": { + "code": "argilla.api.errors::ValidationError", + "params": { + "errors": [ + { + "loc": ["body"], + "msg": "Value error, The following keys must have non-null values: first_name", + "type": "value_error", + } + ] + }, + } + } + + async def test_update_user_with_none_last_name( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + user = await UserFactory.create() + + response = await async_client.patch( + self.url(user.id), + headers=owner_auth_header, + json={ + "last_name": None, + }, + ) + + assert response.status_code == 200 + assert response.json() == { + "id": str(user.id), + "api_key": user.api_key, + "first_name": user.first_name, + "last_name": None, + "username": user.username, + "role": user.role, + "inserted_at": user.inserted_at.isoformat(), + "updated_at": user.updated_at.isoformat(), + } + + async def test_update_user_with_none_username( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + user = await UserFactory.create() + + response = await async_client.patch( + self.url(user.id), + headers=owner_auth_header, + json={ + "username": None, + }, + ) + + assert response.status_code == 422 + assert response.json() == { + "detail": { + "code": "argilla.api.errors::ValidationError", + "params": { + "errors": [ + { + "loc": ["body"], + "msg": "Value error, The following keys must have non-null values: username", + "type": "value_error", + } + ] + }, + } + } + + async def test_update_user_with_none_password( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + user = await UserFactory.create() + + response = await async_client.patch( + self.url(user.id), + headers=owner_auth_header, + json={ + "password": None, + }, + ) + + assert response.status_code == 422 + assert response.json() == { + "detail": { + "code": "argilla.api.errors::ValidationError", + "params": { + "errors": [ + { + "loc": ["body"], + "msg": "Value error, The following keys must have non-null values: password", + "type": "value_error", + } + ] + }, + } + } diff --git a/argilla/CHANGELOG.md b/argilla/CHANGELOG.md index a35dded88a..b0a76c59ff 100644 --- a/argilla/CHANGELOG.md +++ b/argilla/CHANGELOG.md @@ -37,6 +37,7 @@ These are the section headers that we use: ### Changed +- User parameters can now be updated using the client ([#5614](https://github.com/argilla-io/argilla/issues/5614)) - Changed `Dataset.from_hub` method to open configure URL when `settings="ui"`. ([#5622](https://github.com/argilla-io/argilla/pull/5622)) - Terms metadata properties accept other values than `str`. ([#5594](https://github.com/argilla-io/argilla/pull/5594)) - Added support for `with_vectors` while fetching records along with a search query. ([#5638](https://github.com/argilla-io/argilla/pull/5638)) diff --git a/argilla/src/argilla/_api/_users.py b/argilla/src/argilla/_api/_users.py index 30d711fcbf..0063e15f92 100644 --- a/argilla/src/argilla/_api/_users.py +++ b/argilla/src/argilla/_api/_users.py @@ -43,6 +43,15 @@ def create(self, user: UserModel) -> UserModel: return user_created + @api_error_handler + def update(self, user: UserModel) -> UserModel: + json_body = user.model_dump() + response = self.http_client.patch(f"/api/v1/users/{user.id}", json=json_body).raise_for_status() + user_updated = self._model_from_json(response_json=response.json()) + self._log_message(message=f"Updated user {user_updated.username}") + + return user_updated + @api_error_handler def get(self, user_id: UUID) -> UserModel: # TODO: Implement this endpoint in the API diff --git a/argilla/tests/integration/test_manage_users.py b/argilla/tests/integration/test_manage_users.py index 58d716cef2..4671ce3ecd 100644 --- a/argilla/tests/integration/test_manage_users.py +++ b/argilla/tests/integration/test_manage_users.py @@ -16,7 +16,7 @@ import pytest from argilla import User, Argilla, Workspace -from argilla._exceptions import UnprocessableEntityError +from argilla._exceptions import UnprocessableEntityError, ConflictError class TestManageUsers: @@ -46,3 +46,30 @@ def test_add_user_to_workspace(self, client: Argilla, workspace: Workspace): user.add_to_workspace(workspace) assert user in workspace.users + + def test_update_user(self, client: Argilla): + user = User(username=f"test_update_user_{uuid.uuid4()}", password="test_password") + client.users.add(user) + + updated_username = f"updated_user_{uuid.uuid4()}" + user.username = updated_username + user.first_name = "Updated First Name" + user.last_name = "Updated Last Name" + user.role = "admin" + user.update() + + updated_user = client.users(id=user.id) + assert updated_user.username == updated_username + assert updated_user.first_name == "Updated First Name" + assert updated_user.last_name == "Updated Last Name" + assert updated_user.role == "admin" + + def test_update_user_with_duplicate_username(self, client: Argilla): + user1 = User(username=f"test_user1_{uuid.uuid4()}", password="test_password") + user2 = User(username=f"test_user2_{uuid.uuid4()}", password="test_password") + client.users.add(user1) + client.users.add(user2) + + user2.username = user1.username + with pytest.raises(expected_exception=UnprocessableEntityError): + user2.update()