[REFACTOR] argilla server: using pydantic v2 #5666

Merged
merged 32 commits into from
Nov 20, 2024
50b0e82
config: Update to pydantic v2
frascuchon Nov 5, 2024
2f53019
refactor: Update API schemas
frascuchon Nov 5, 2024
1a8539d
refactor: Import pydantic module
frascuchon Nov 5, 2024
34d5e30
chore: Remove unused files
frascuchon Nov 5, 2024
798db75
refactor: Convert to settingschema
frascuchon Nov 5, 2024
18e721e
chore: Remove v2 logic for APP
frascuchon Nov 5, 2024
43330fe
refactor: Review and adapt error handler
frascuchon Nov 5, 2024
a3050fa
tests: Adapt tests messages
frascuchon Nov 5, 2024
fb16594
fix: Add coerce 2 string for terms filter
frascuchon Nov 5, 2024
ff30ee1
chore: Redefine TextQuery as BaseModel with coerce
frascuchon Nov 5, 2024
79cfc93
refactor: review validator code
frascuchon Nov 5, 2024
6aec507
Merge branch 'develop' into refactor/argilla-server/using-pydanticV2
frascuchon Nov 6, 2024
d305ae4
Merge branch 'develop' into refactor/argilla-server/using-pydanticV2
frascuchon Nov 6, 2024
145ea38
Merge branch 'develop' into refactor/argilla-server/using-pydanticV2
frascuchon Nov 6, 2024
f346a84
[BUGFIX] `argilla frontend`: redirect after login (#5635)
frascuchon Nov 6, 2024
489c265
Merge branch 'develop' into refactor/argilla-server/using-pydanticV2
frascuchon Nov 6, 2024
06a666e
Merge branch 'develop' into refactor/argilla-server/using-pydanticV2
frascuchon Nov 8, 2024
9d758ab
Merge branch 'develop' into refactor/argilla-server/using-pydanticV2
frascuchon Nov 8, 2024
2c0ae2e
refactor: Migrate validators and deprecated methods/functions
frascuchon Nov 8, 2024
3d5f745
chore: Same with tests
frascuchon Nov 8, 2024
72975ff
update pdm.lock
frascuchon Nov 8, 2024
eb4234c
Merge branch 'develop' into refactor/argilla-server/using-pydanticV2
frascuchon Nov 11, 2024
612c721
Merge branch 'develop' into refactor/argilla-server/using-pydanticV2
frascuchon Nov 18, 2024
c92b407
Merge branch 'develop' into refactor/argilla-server/using-pydanticV2
frascuchon Nov 19, 2024
a29d185
chore: update pdm.lock hash
frascuchon Nov 19, 2024
764393d
suggestion: dev DB name
frascuchon Nov 19, 2024
e0e9538
chore: Add missing tests with empty chat values
frascuchon Nov 19, 2024
2862020
chore: Remove extra imports
frascuchon Nov 19, 2024
dd96edd
Merge branch 'develop' into refactor/argilla-server/using-pydanticV2
jfcalvo Nov 19, 2024
59d2122
fix: webhook models updated to use Pydantic v2
jfcalvo Nov 19, 2024
56d6830
fix: revert mistake
jfcalvo Nov 19, 2024
4bd1522
fix: return back model_dump for metadata properties settings
jfcalvo Nov 20, 2024
2 changes: 1 addition & 1 deletion argilla-server/.env.dev
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES # Needed by RQ to work with forked processes on MacOS
ALEMBIC_CONFIG=src/argilla_server/alembic.ini
ARGILLA_AUTH_SECRET_KEY=8VO7na5N/jQx+yP/N+HlE8q51vPdrxqlh6OzoebIyko= # With this we avoid using a different key every time the server is reloaded
-ARGILLA_DATABASE_URL=sqlite+aiosqlite:///${HOME}/.argilla/argilla.db?check_same_thread=False
+ARGILLA_DATABASE_URL=sqlite+aiosqlite:///${HOME}/.argilla/argilla-dev.db?check_same_thread=False
171 changes: 132 additions & 39 deletions argilla-server/pdm.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion argilla-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ maintainers = [{ name = "argilla", email = "[email protected]" }]
dependencies = [
# Basic dependencies
"fastapi ~= 0.115.0",
-"pydantic ~= 1.10.18",
+"pydantic ~= 2.9.0",
+"pydantic-settings ~= 2.6.0",
"uvicorn[standard] ~= 0.32.0",
"opensearch-py ~= 2.0.0",
"elasticsearch8[async] ~= 8.7.0",
Expand Down
9 changes: 1 addition & 8 deletions argilla-server/src/argilla_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Remove me
-import warnings
-
-from argilla_server.pydantic_v1 import PYDANTIC_MAJOR_VERSION
-
-if PYDANTIC_MAJOR_VERSION >= 2:
-    warnings.warn("The argilla_server package is not compatible with Pydantic 2. " "Please use Pydantic 1.x instead.")
-else:
-    from argilla_server._app import app # noqa
+from argilla_server._app import app # noqa
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
get_search_engine,
)
from argilla_server.security import auth
-from argilla_server.telemetry import TelemetryClient, get_telemetry_client

router = APIRouter()

Expand Down Expand Up @@ -203,7 +202,7 @@ async def create_dataset(
):
await authorize(current_user, DatasetPolicy.create(dataset_create.workspace_id))

-return await datasets.create_dataset(db, dataset_create.dict())
+return await datasets.create_dataset(db, dataset_create.model_dump())


@router.post("/datasets/{dataset_id}/fields", status_code=status.HTTP_201_CREATED, response_model=Field)
Expand Down Expand Up @@ -310,7 +309,7 @@ async def update_dataset(

await authorize(current_user, DatasetPolicy.update(dataset))

-return await datasets.update_dataset(db, dataset, dataset_update.dict(exclude_unset=True))
+return await datasets.update_dataset(db, dataset, dataset_update.model_dump(exclude_unset=True))


@router.post("/datasets/{dataset_id}/import", status_code=status.HTTP_202_ACCEPTED, response_model=JobSchema)
Expand All @@ -330,7 +329,7 @@ async def import_dataset_from_hub(
subset=hub_dataset.subset,
split=hub_dataset.split,
dataset_id=dataset.id,
-mapping=hub_dataset.mapping.dict(),
+mapping=hub_dataset.mapping.model_dump(),
)

return JobSchema(id=job.id, status=job.get_status())
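The handler changes in this file follow one pattern: Pydantic v1's `.dict()` becomes v2's `.model_dump()`, and `exclude_unset=True` is kept for update payloads. The sketch below, which assumes Pydantic v2 is installed, illustrates why that flag matters for PATCH-style updates. `DatasetUpdateSketch` and its fields are illustrative names, not the project's schema.

```python
from typing import Optional

from pydantic import BaseModel


class DatasetUpdateSketch(BaseModel):
    # Hypothetical update schema: every field is optional.
    name: Optional[str] = None
    guidelines: Optional[str] = None


# The client only sent "guidelines", explicitly set to null.
update = DatasetUpdateSketch(guidelines=None)

# Without the flag, defaults are included, so "name" would look like a requested change.
full_payload = update.model_dump()

# With exclude_unset=True, only the keys the client actually sent survive.
patch_payload = update.model_dump(exclude_unset=True)

print(full_payload)   # {'name': None, 'guidelines': None}
print(patch_payload)  # {'guidelines': None}
```

Passing `patch_payload` to an update function leaves `name` untouched, which is the behaviour `update_dataset` appears to rely on.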
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ async def search_current_user_dataset_records(
record.metadata_ = await _filter_record_metadata_for_user(record, current_user)

record_id_score_map[record.id]["search_record"] = SearchRecord(
-record=RecordSchema.from_orm(record),
+record=RecordSchema.model_validate(record),
query_score=record_id_score_map[record.id]["query_score"],
)

Expand Down Expand Up @@ -382,7 +382,7 @@ async def search_dataset_records(

for record in records:
record_id_score_map[record.id]["search_record"] = SearchRecord(
-record=RecordSchema.from_orm(record),
+record=RecordSchema.model_validate(record),
query_score=record_id_score_map[record.id]["query_score"],
)

Expand Down
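Both hunks replace `from_orm` with `model_validate`. In Pydantic v2 this only reads attributes from an arbitrary object when the schema opts in with `from_attributes=True`, the successor of v1's `orm_mode`. The sketch below assumes Pydantic v2. `RecordRow`, `RecordSketch` and `StrictRecordSketch` are hypothetical stand-ins, not the project's `Record` model or `RecordSchema`.

```python
from pydantic import BaseModel, ConfigDict, ValidationError


class RecordRow:
    """Hypothetical stand-in for an ORM row: a plain object with attributes."""

    def __init__(self, id: int, external_id: str) -> None:
        self.id = id
        self.external_id = external_id


class RecordSketch(BaseModel):
    # Opt in to reading attributes from arbitrary objects (v1: orm_mode = True).
    model_config = ConfigDict(from_attributes=True)

    id: int
    external_id: str


class StrictRecordSketch(BaseModel):
    # Same fields, but without from_attributes.
    id: int
    external_id: str


row = RecordRow(id=1, external_id="rec-1")

# v1: RecordSketch.from_orm(row)  ->  v2: RecordSketch.model_validate(row)
record = RecordSketch.model_validate(row)

try:
    StrictRecordSketch.model_validate(row)
    rejected_without_opt_in = False
except ValidationError:
    # Without from_attributes, a plain object is not accepted as input.
    rejected_without_opt_in = True
```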
5 changes: 2 additions & 3 deletions argilla-server/src/argilla_server/api/handlers/v1/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@
from fastapi.responses import RedirectResponse
from sqlalchemy.ext.asyncio import AsyncSession

-from argilla_server import telemetry
from argilla_server.api.schemas.v1.oauth2 import Provider, Providers, Token
from argilla_server.api.schemas.v1.users import UserCreate
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.enums import UserRole
from argilla_server.errors.future import NotFoundError
from argilla_server.models import User
-from argilla_server.pydantic_v1 import Field
+from pydantic import Field
from argilla_server.security.authentication.oauth2 import OAuth2ClientProvider
from argilla_server.security.authentication.userinfo import UserInfo
from argilla_server.security.settings import settings
Expand Down Expand Up @@ -86,7 +85,7 @@ async def get_access_token(
username=userinfo.username,
first_name=userinfo.first_name,
role=userinfo.role,
-).dict(exclude_unset=True),
+).model_dump(exclude_unset=True),
workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces],
)

Expand Down
2 changes: 1 addition & 1 deletion argilla-server/src/argilla_server/api/handlers/v1/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def create_user(
):
await authorize(current_user, UserPolicy.create)

-user = await accounts.create_user(db, user_create.dict())
+user = await accounts.create_user(db, user_create.model_dump())

return user

Expand Down
4 changes: 2 additions & 2 deletions argilla-server/src/argilla_server/api/handlers/v1/webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def create_webhook(
):
await authorize(current_user, WebhookPolicy.create)

-return await webhooks.create_webhook(db, webhook_create.dict())
+return await webhooks.create_webhook(db, webhook_create.model_dump())


@router.patch("/webhooks/{webhook_id}", response_model=WebhookSchema)
Expand All @@ -68,7 +68,7 @@ async def update_webhook(

await authorize(current_user, WebhookPolicy.update)

-return await webhooks.update_webhook(db, webhook, webhook_update.dict(exclude_unset=True))
+return await webhooks.update_webhook(db, webhook, webhook_update.model_dump(exclude_unset=True))


@router.delete("/webhooks/{webhook_id}", response_model=WebhookSchema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def create_workspace(
):
await authorize(current_user, WorkspacePolicy.create)

-return await accounts.create_workspace(db, workspace_create.dict())
+return await accounts.create_workspace(db, workspace_create.model_dump())


@router.delete("/workspaces/{workspace_id}", response_model=WorkspaceSchema)
Expand Down
7 changes: 3 additions & 4 deletions argilla-server/src/argilla_server/api/schemas/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from argilla_server.pydantic_v1 import BaseModel, Field
+from pydantic import BaseModel, Field

MIN_MESSAGE_LENGTH = 1
MAX_MESSAGE_LENGTH = 20000

MIN_ROLE_LENGTH = 1
MAX_ROLE_LENGTH = 20
-MAX_ROLE_REGEX = r"^\S+$"


class ChatFieldValue(BaseModel):
-role: str = Field(..., min_role_length=MIN_ROLE_LENGTH, max_length=MAX_ROLE_LENGTH, regex=MAX_ROLE_REGEX)
-content: str = Field(..., min_message_length=MIN_MESSAGE_LENGTH, max_length=MAX_MESSAGE_LENGTH)
+role: str = Field(..., min_length=MIN_ROLE_LENGTH, max_length=MAX_ROLE_LENGTH)
+content: str = Field(..., min_length=MIN_MESSAGE_LENGTH, max_length=MAX_MESSAGE_LENGTH)
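A note on the constraint names in this file. The old lines pass `min_role_length` and `min_message_length`, which are not constraint names Pydantic recognises, so the minimum lengths were presumably never enforced under v1. The new lines use `min_length`. The new lines also drop the v1 `regex` argument. Per the Pydantic migration guide, v2 renames that constraint to `pattern`. The sketch below assumes Pydantic v2 and shows what the schema could look like if the whitespace check were kept. It is not what the PR merged, and `ChatFieldValueSketch` and `is_valid` are hypothetical names.

```python
from pydantic import BaseModel, Field, ValidationError


class ChatFieldValueSketch(BaseModel):
    # pattern= is the v2 name for v1's regex=; keeping it here is an assumption,
    # since the merged schema only keeps the length constraints.
    role: str = Field(..., min_length=1, max_length=20, pattern=r"^\S+$")
    content: str = Field(..., min_length=1, max_length=20000)


def is_valid(role: str, content: str) -> bool:
    """Return True when the pair passes validation, False otherwise."""
    try:
        ChatFieldValueSketch(role=role, content=content)
        return True
    except ValidationError:
        return False
```

With this sketch, `is_valid("user", "hi")` passes, while an empty role, an empty content, or a role containing whitespace is rejected.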
13 changes: 7 additions & 6 deletions argilla-server/src/argilla_server/api/schemas/v1/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import Any, Dict, Set, Union

-from argilla_server.pydantic_v1 import BaseModel, root_validator
+from pydantic import BaseModel, model_validator


class UpdateSchema(BaseModel):
Expand All @@ -25,17 +25,18 @@ class UpdateSchema(BaseModel):

__non_nullable_fields__: Union[Set[str], None] = None

-@root_validator(pre=True)
-def validate_non_nullable_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+@model_validator(mode="before")
+@classmethod
+def validate_non_nullable_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]:
if cls.__non_nullable_fields__ is None:
-return values
+return data

invalid_keys = []
for key in cls.__non_nullable_fields__:
-if key in values and values[key] is None:
+if key in data and data[key] is None:
invalid_keys.append(key)

if invalid_keys:
raise ValueError(f"The following keys must have non-null values: {', '.join(invalid_keys)}")

-return values
+return data
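The `UpdateSchema` change swaps `@root_validator(pre=True)` for `@model_validator(mode="before")` plus `@classmethod`. A "before" validator sees the raw input, so it can still tell an omitted key from an explicit null. The sketch below assumes Pydantic v2 and is self-contained. `UpdateSketch`, `non_nullable_fields` and `accepts` are illustrative names, and the `isinstance` guard is an addition that the PR code does not have.

```python
from typing import Any, ClassVar, Optional, Set

from pydantic import BaseModel, ValidationError, model_validator


class UpdateSketch(BaseModel):
    # Hypothetical stand-in for __non_nullable_fields__ in the PR.
    non_nullable_fields: ClassVar[Set[str]] = {"name"}

    name: Optional[str] = None
    guidelines: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def reject_explicit_nulls(cls, data: Any) -> Any:
        # Raw input may not be a dict (e.g. an object), so guard before key lookups.
        if isinstance(data, dict):
            invalid = sorted(
                key for key in cls.non_nullable_fields if key in data and data[key] is None
            )
            if invalid:
                raise ValueError(
                    f"The following keys must have non-null values: {', '.join(invalid)}"
                )
        return data


def accepts(**payload: Any) -> bool:
    """Return True when the payload validates, False otherwise."""
    try:
        UpdateSketch(**payload)
        return True
    except ValidationError:
        return False
```

Omitting `name` is accepted because the key never reaches the validator, whereas `accepts(name=None)` is rejected.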
40 changes: 25 additions & 15 deletions argilla-server/src/argilla_server/api/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.

from datetime import datetime
-from typing import List, Literal, Optional, Union, Dict, Any
+from typing import List, Literal, Optional, Dict, Any
from uuid import UUID

+from pydantic.v1.utils import GetterDict

from argilla_server.api.schemas.v1.commons import UpdateSchema
from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus
-from argilla_server.pydantic_v1 import BaseModel, Field, constr
-from argilla_server.pydantic_v1.utils import GetterDict
+from pydantic import BaseModel, Field, constr, ConfigDict, model_validator

try:
from typing import Annotated
Expand Down Expand Up @@ -106,7 +107,7 @@ class UsersProgress(BaseModel):


class DatasetGetterDict(GetterDict):
-def get(self, key: str, default: Any) -> Any:
+def get(self, key: Any, default: Any = None) -> Any:
if key == "metadata":
return getattr(self._obj, "metadata_", None)

Expand All @@ -116,19 +117,28 @@ def get(self, key: str, default: Any) -> Any:
class Dataset(BaseModel):
id: UUID
name: str
-guidelines: Optional[str]
+guidelines: Optional[str] = None
allow_extra_metadata: bool
status: DatasetStatus
distribution: DatasetDistribution
-metadata: Optional[Dict[str, Any]]
+metadata: Optional[Dict[str, Any]] = None
workspace_id: UUID
last_activity_at: datetime
inserted_at: datetime
updated_at: datetime

-class Config:
-    orm_mode = True
-    getter_dict = DatasetGetterDict
+model_config = ConfigDict(from_attributes=True)
+
+@model_validator(mode="before")
+@classmethod
+def validate(cls, value) -> dict:
+    getter = DatasetGetterDict(value)
+
+    data = {}
+    for field in cls.model_fields:
+        data[field] = getter.get(field)
+
+    return data


class Datasets(BaseModel):
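The `Dataset` schema change in this hunk replaces v1's `getter_dict` hook with a "before" validator that builds a plain dict, reading the `metadata` field from the object's `metadata_` attribute. The sketch below assumes Pydantic v2 and uses `getattr` instead of the `pydantic.v1` `GetterDict` helper the PR keeps. `DatasetRow`, `DatasetSketch` and `read_from_attributes` are hypothetical names.

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel, ConfigDict, model_validator


class DatasetRow:
    """Hypothetical ORM-like row; the attribute is named metadata_ rather than metadata."""

    def __init__(self, name: str, metadata_: Optional[Dict[str, Any]]) -> None:
        self.name = name
        self.metadata_ = metadata_


class DatasetSketch(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    name: str
    metadata: Optional[Dict[str, Any]] = None

    @model_validator(mode="before")
    @classmethod
    def read_from_attributes(cls, value: Any) -> Any:
        # Dict input is already keyed by field name, so pass it through untouched.
        if isinstance(value, dict):
            return value

        data = {}
        for field in cls.model_fields:
            # Map the schema field "metadata" onto the row attribute "metadata_".
            source = "metadata_" if field == "metadata" else field
            data[field] = getattr(value, source, None)
        return data


dataset = DatasetSketch.model_validate(DatasetRow("reviews", {"source": "hub"}))
from_dict = DatasetSketch.model_validate({"name": "plain"})
```

The sketch avoids naming the validator `validate`, since that name also exists as a deprecated method on v2's `BaseModel`. Whether the shadowing in the PR causes any issue is not verified here.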
Expand All @@ -137,7 +147,7 @@ class Datasets(BaseModel):

class DatasetCreate(BaseModel):
name: DatasetName
-guidelines: Optional[DatasetGuidelines]
+guidelines: Optional[DatasetGuidelines] = None
allow_extra_metadata: bool = True
distribution: DatasetDistributionCreate = DatasetOverlapDistributionCreate(
strategy=DatasetDistributionStrategy.overlap,
Expand All @@ -148,10 +158,10 @@ class DatasetCreate(BaseModel):


class DatasetUpdate(UpdateSchema):
-name: Optional[DatasetName]
-guidelines: Optional[DatasetGuidelines]
-allow_extra_metadata: Optional[bool]
-distribution: Optional[DatasetDistributionUpdate]
+name: Optional[DatasetName] = None
+guidelines: Optional[DatasetGuidelines] = None
+allow_extra_metadata: Optional[bool] = None
+distribution: Optional[DatasetDistributionUpdate] = None
metadata_: Optional[Dict[str, Any]] = Field(None, alias="metadata")

__non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution"}
Expand All @@ -163,7 +173,7 @@ class HubDatasetMappingItem(BaseModel):


class HubDatasetMapping(BaseModel):
-fields: List[HubDatasetMappingItem] = Field(..., min_items=1)
+fields: List[HubDatasetMappingItem] = Field(..., min_length=1)
metadata: Optional[List[HubDatasetMappingItem]] = []
suggestions: Optional[List[HubDatasetMappingItem]] = []
external_id: Optional[str] = None
Expand Down
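The last hunk in this file renames `min_items` to `min_length` on a list field. In Pydantic v2 the same constraint name covers both strings and collections. The sketch below assumes Pydantic v2. `MappingSketch` and its `columns` field are illustrative names, not the project's `HubDatasetMapping`.

```python
from typing import List

from pydantic import BaseModel, Field, ValidationError


class MappingSketch(BaseModel):
    # v1: Field(..., min_items=1)  ->  v2: Field(..., min_length=1)
    columns: List[str] = Field(..., min_length=1)


def accepts_columns(columns: List[str]) -> bool:
    """Return True when the list satisfies the minimum-length constraint."""
    try:
        MappingSketch(columns=columns)
        return True
    except ValidationError:
        return False
```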
12 changes: 5 additions & 7 deletions argilla-server/src/argilla_server/api/schemas/v1/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@

from argilla_server.api.schemas.v1.commons import UpdateSchema
from argilla_server.enums import FieldType
-from argilla_server.pydantic_v1 import BaseModel, constr
-from argilla_server.pydantic_v1 import Field as PydanticField
+from pydantic import BaseModel, constr, Field as PydanticField, ConfigDict

FIELD_CREATE_NAME_MIN_LENGTH = 1
FIELD_CREATE_NAME_MAX_LENGTH = 200
Expand Down Expand Up @@ -145,8 +144,7 @@ class Field(BaseModel):
inserted_at: datetime
updated_at: datetime

-class Config:
-    orm_mode = True
+model_config = ConfigDict(from_attributes=True)


class Fields(BaseModel):
Expand All @@ -156,12 +154,12 @@ class Fields(BaseModel):
class FieldCreate(BaseModel):
name: FieldName
title: FieldTitle
-required: Optional[bool]
+required: Optional[bool] = None
settings: FieldSettingsCreate


class FieldUpdate(UpdateSchema):
-title: Optional[FieldTitle]
-settings: Optional[FieldSettingsUpdate]
+title: Optional[FieldTitle] = None
+settings: Optional[FieldSettingsUpdate] = None

__non_nullable_fields__ = {"title", "settings"}
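The `= None` additions across these schemas reflect a behaviour change. In Pydantic v1 an `Optional[X]` annotation without a default implied a default of `None`. In v2 the field is required but nullable unless a default is given. The sketch below assumes Pydantic v2; both model names are hypothetical.

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


class FieldCreateWithoutDefault(BaseModel):
    # v2: nullable, but the key must be present in the input.
    required: Optional[bool]


class FieldCreateWithDefault(BaseModel):
    # v2: nullable and optional, matching the v1 behaviour.
    required: Optional[bool] = None


try:
    FieldCreateWithoutDefault()
    missing_is_error = False
except ValidationError as error:
    # The first (and only) error reports the field as missing.
    missing_is_error = error.errors()[0]["type"] == "missing"

default_value = FieldCreateWithDefault().required
explicit_none = FieldCreateWithoutDefault(required=None).required
```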
2 changes: 1 addition & 1 deletion argilla-server/src/argilla_server/api/schemas/v1/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from argilla_server.pydantic_v1 import BaseModel
+from pydantic import BaseModel


class Version(BaseModel):
Expand Down