feat: add server dataset hub import #5591

Merged
21 commits
b2594b9
feat: first iteration of background job to import datasets from hub
jfcalvo Oct 10, 2024
6108ffe
feat: improve import_dataset_from_hub_job to get dataset before insta…
jfcalvo Oct 11, 2024
3a875ee
feat: improve HubDataset batch processing
jfcalvo Oct 11, 2024
b10a92e
feat: use UpsertRecordsBulk of CreateRecordsBulk for importing datase…
jfcalvo Oct 11, 2024
2b47522
feat: transform dataset importing value columns with PIL images to da…
jfcalvo Oct 11, 2024
7f93e0b
feat: add support to map suggestions importing datasets from hub
jfcalvo Oct 14, 2024
07aeec6
Merge branch 'feat/argilla-direct-feature-branch' into feat/add-hub-d…
frascuchon Oct 14, 2024
83237a7
Merge branch 'feat/argilla-direct-feature-branch' into feat/add-hub-d…
frascuchon Oct 14, 2024
a38fdda
feat: add support for hub dataset mapping
jfcalvo Oct 14, 2024
15d694f
feat: set metadata and suggestions as optional for HubDatasetMapping
jfcalvo Oct 14, 2024
c3752e8
feat: when no external_id is mapped row_idx is used
jfcalvo Oct 14, 2024
469ebc4
feat: use streaming when loading the dataset
jfcalvo Oct 15, 2024
b665019
feat: refactor UpsertRecordsBulk to validate records individually
jfcalvo Oct 15, 2024
a2fbc10
feat: ignore invalid records when importing datasets from hub
jfcalvo Oct 15, 2024
f3e33bd
Merge branch 'feat/argilla-direct-feature-branch' into feat/add-hub-d…
jfcalvo Oct 15, 2024
fde2603
Merge branch 'feat/argilla-direct-feature-branch' into feat/add-hub-d…
frascuchon Oct 16, 2024
3d1f04f
feat: add a fixed number of rows to take importing dataset from Hub (…
jfcalvo Oct 17, 2024
e84fb45
Merge branch 'feat/argilla-direct-feature-branch' into feat/add-hub-d…
jfcalvo Oct 17, 2024
08ebe28
feat: add support for class labels and casting rows (#5601)
jfcalvo Oct 18, 2024
312551a
feat: improve `HubDataset` image processing support (#5606)
jfcalvo Oct 18, 2024
f064267
feat: add support to `-1` no label values for `ClassLabel` dataset fe…
jfcalvo Oct 18, 2024
330 changes: 154 additions & 176 deletions argilla-server/pdm.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion argilla-server/pyproject.toml
@@ -59,8 +59,10 @@ dependencies = [
"typer >= 0.6.0, < 0.10.0", # spaCy only supports typer<0.10.0
"packaging>=23.2",
"psycopg2-binary>=2.9.9",
"datasets>=3.0.1",
# For Telemetry
"huggingface_hub>=0.13,<1",

]

[project.optional-dependencies]
@@ -100,7 +102,6 @@ test = [
"factory-boy~=3.2.1",
"httpx>=0.26.0",
# Required by tests/unit/utils/test_dependency.py, but we should take a look and probably remove them
"datasets > 1.17.0,!= 2.3.2",
"spacy>=3.5.0,<3.7.0",
"pytest-randomly>=3.15.0",
]
argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py
@@ -29,6 +29,7 @@
DatasetProgress,
Datasets,
DatasetUpdate,
HubDataset,
UsersProgress,
)
from argilla_server.api.schemas.v1.fields import Field, FieldCreate, Fields
@@ -38,9 +39,11 @@
MetadataPropertyCreate,
)
from argilla_server.api.schemas.v1.vector_settings import VectorSettings, VectorSettingsCreate, VectorsSettings
from argilla_server.api.schemas.v1.jobs import Job as JobSchema
from argilla_server.contexts import datasets
from argilla_server.database import get_async_db
from argilla_server.enums import DatasetStatus
from argilla_server.jobs import hub_jobs
from argilla_server.models import Dataset, User
from argilla_server.search_engine import (
SearchEngine,
@@ -301,3 +304,26 @@
await authorize(current_user, DatasetPolicy.update(dataset))

return await datasets.update_dataset(db, dataset, dataset_update.dict(exclude_unset=True))


# TODO: Maybe change /import to /import-from-hub?
@router.post("/datasets/{dataset_id}/import", status_code=status.HTTP_202_ACCEPTED, response_model=JobSchema)
async def import_dataset_from_hub(
*,
db: AsyncSession = Depends(get_async_db),
dataset_id: UUID,
hub_dataset: HubDataset,
current_user: User = Security(auth.get_current_user),
):
dataset = await Dataset.get_or_raise(db, dataset_id)

await authorize(current_user, DatasetPolicy.import_from_hub(dataset))

job = hub_jobs.import_dataset_from_hub_job.delay(

name=hub_dataset.name,
subset=hub_dataset.subset,
split=hub_dataset.split,
dataset_id=dataset.id,
)

return JobSchema(id=job.id, status=job.get_status())
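
To make the contract of the new endpoint concrete, here is a minimal client-side sketch that builds (but does not send) the request. The body mirrors the three required string fields of the `HubDataset` schema added in this PR. The server address, the `/api/v1` prefix, the `X-Argilla-Api-Key` header, the API key, the dataset id, and the Hub repository values are all assumptions chosen for illustration; they are not taken from this diff.

```python
import json
import urllib.request
from uuid import UUID


def build_import_request(
    base_url: str, api_key: str, dataset_id: UUID, name: str, subset: str, split: str
) -> urllib.request.Request:
    """Build (without sending) a POST request for the dataset import endpoint."""
    # The body carries the same three fields as the HubDataset schema.
    body = json.dumps({"name": name, "subset": subset, "split": split}).encode("utf-8")

    return urllib.request.Request(
        # Assumption: the v1 router is mounted under "/api/v1".
        url=f"{base_url}/api/v1/datasets/{dataset_id}/import",
        data=body,
        method="POST",
        headers={
            "Content-Type": "application/json",
            # Assumption: API-key authentication via this header.
            "X-Argilla-Api-Key": api_key,
        },
    )


request = build_import_request(
    base_url="http://localhost:6900",  # hypothetical server address
    api_key="owner.apikey",  # hypothetical API key
    dataset_id=UUID("00000000-0000-0000-0000-000000000001"),  # hypothetical dataset id
    name="stanfordnlp/imdb",  # illustrative Hub repository
    subset="plain_text",
    split="train",
)
```

Sending the request with `urllib.request.urlopen(request)` should return `202 Accepted` with a body of the form `{"id": ..., "status": ...}`, since the handler only enqueues the job and responds with `JobSchema`.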

52 changes: 52 additions & 0 deletions argilla-server/src/argilla_server/api/handlers/v1/jobs.py
@@ -0,0 +1,52 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from fastapi import APIRouter, Depends, HTTPException, Security, status
from sqlalchemy.ext.asyncio import AsyncSession

from rq.job import Job
from rq.exceptions import NoSuchJobError

from argilla_server.database import get_async_db
from argilla_server.jobs.queues import REDIS_CONNECTION
from argilla_server.models import User
from argilla_server.api.policies.v1 import JobPolicy, authorize
from argilla_server.api.schemas.v1.jobs import Job as JobSchema
from argilla_server.security import auth

router = APIRouter(tags=["jobs"])


def _get_job(job_id: str) -> Job:
try:
return Job.fetch(job_id, connection=REDIS_CONNECTION)
except NoSuchJobError:
raise HTTPException(

status_code=status.HTTP_404_NOT_FOUND,
detail=f"Job with id `{job_id}` not found",
)


@router.get("/jobs/{job_id}", response_model=JobSchema)
async def get_job(
*,
db: AsyncSession = Depends(get_async_db),
job_id: str,
current_user: User = Security(auth.get_current_user),
):
job = _get_job(job_id)

await authorize(current_user, JobPolicy.get)

return JobSchema(id=job.id, status=job.get_status(refresh=True))
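
Because the import runs in the background, a client would typically poll `GET /jobs/{job_id}` until the job reaches a terminal state. The sketch below shows one possible polling loop. `wait_for_job`, its parameters, and the injected `fetch_status` callable are hypothetical helpers, not part of this PR; the status strings are the values of `rq`'s `JobStatus` enum, which the `Job` schema serializes.

```python
import time
from typing import Callable

# Terminal values of rq's JobStatus enum. The remaining states
# ("queued", "started", "deferred", "scheduled") mean the job is still pending.
TERMINAL_STATUSES = {"finished", "failed", "stopped", "canceled"}


def wait_for_job(
    fetch_status: Callable[[], str],
    interval: float = 1.0,
    max_attempts: int = 60,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Call fetch_status until it returns a terminal status or attempts run out."""
    for _ in range(max_attempts):
        status = fetch_status()
        if status in TERMINAL_STATUSES:
            return status
        sleep(interval)

    raise TimeoutError(f"job did not reach a terminal status after {max_attempts} attempts")


# Simulated sequence of responses; a real client would issue the HTTP GET here.
statuses = iter(["queued", "started", "finished"])
result = wait_for_job(lambda: next(statuses), sleep=lambda _seconds: None)
```

Injecting `fetch_status` and `sleep` keeps the loop independent of any HTTP library and makes it easy to exercise without a running server.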

2 changes: 2 additions & 0 deletions argilla-server/src/argilla_server/api/policies/v1/__init__.py
@@ -24,6 +24,7 @@
from argilla_server.api.policies.v1.vector_settings_policy import VectorSettingsPolicy
from argilla_server.api.policies.v1.workspace_policy import WorkspacePolicy
from argilla_server.api.policies.v1.workspace_user_policy import WorkspaceUserPolicy
from argilla_server.api.policies.v1.job_policy import JobPolicy

__all__ = [
"DatasetPolicy",
@@ -37,6 +38,7 @@
"VectorSettingsPolicy",
"WorkspacePolicy",
"WorkspaceUserPolicy",
"JobPolicy",
"authorize",
"is_authorized",
]
argilla-server/src/argilla_server/api/policies/v1/dataset_policy.py
@@ -141,3 +141,10 @@
return actor.is_owner or (actor.is_admin and await actor.is_member(dataset.workspace_id))

return is_allowed

@classmethod
def import_from_hub(cls, dataset: Dataset) -> PolicyAction:
async def is_allowed(actor: User) -> bool:
return actor.is_owner or (actor.is_admin and await actor.is_member(dataset.workspace_id))

return is_allowed
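
The policy above returns an async closure that `authorize` later awaits with the current user. The following self-contained sketch reproduces that pattern with stand-in types so the rule can be checked in isolation. `FakeUser`, its `role` and `workspace_ids` attributes, and the workspace ids are hypothetical; only the boolean rule itself is copied from the diff.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Awaitable, Callable, List, Set


@dataclass
class FakeUser:
    """Hypothetical stand-in for argilla_server.models.User."""

    role: str
    workspace_ids: Set[str] = field(default_factory=set)

    @property
    def is_owner(self) -> bool:
        return self.role == "owner"

    @property
    def is_admin(self) -> bool:
        return self.role == "admin"

    async def is_member(self, workspace_id: str) -> bool:
        return workspace_id in self.workspace_ids


PolicyAction = Callable[[FakeUser], Awaitable[bool]]


def import_from_hub(workspace_id: str) -> PolicyAction:
    """Return a closure applying the same rule as DatasetPolicy.import_from_hub."""

    async def is_allowed(actor: FakeUser) -> bool:
        # Owners always pass; admins pass only for workspaces they belong to.
        return actor.is_owner or (actor.is_admin and await actor.is_member(workspace_id))

    return is_allowed


async def main() -> List[bool]:
    policy = import_from_hub("ws-1")
    return [
        await policy(FakeUser("owner")),
        await policy(FakeUser("admin", {"ws-1"})),
        await policy(FakeUser("admin", {"ws-2"})),
        await policy(FakeUser("annotator", {"ws-1"})),
    ]


results = asyncio.run(main())  # [True, True, False, False]
```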

21 changes: 21 additions & 0 deletions argilla-server/src/argilla_server/api/policies/v1/job_policy.py
@@ -0,0 +1,21 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla_server.models import User


class JobPolicy:
@classmethod
async def get(cls, actor: User) -> bool:
return actor.is_owner or actor.is_admin

2 changes: 2 additions & 0 deletions argilla-server/src/argilla_server/api/routes.py
@@ -62,6 +62,7 @@
from argilla_server.api.handlers.v1 import (
workspaces as workspaces_v1,
)
from argilla_server.api.handlers.v1 import jobs as jobs_v1
from argilla_server.errors.base_errors import __ALL__
from argilla_server.errors.error_handler import APIErrorHandler

@@ -92,6 +93,7 @@ def create_api_v1():
users_v1.router,
vectors_settings_v1.router,
workspaces_v1.router,
jobs_v1.router,
oauth2_v1.router,
settings_v1.router,
]:
6 changes: 6 additions & 0 deletions argilla-server/src/argilla_server/api/schemas/v1/datasets.py
@@ -142,3 +142,9 @@ class DatasetUpdate(UpdateSchema):
distribution: Optional[DatasetDistributionUpdate]

__non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution"}


class HubDataset(BaseModel):
name: str
subset: str
split: str
21 changes: 21 additions & 0 deletions argilla-server/src/argilla_server/api/schemas/v1/jobs.py
@@ -0,0 +1,21 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from rq.job import JobStatus
from pydantic import BaseModel


class Job(BaseModel):
id: str
status: JobStatus
90 changes: 90 additions & 0 deletions argilla-server/src/argilla_server/contexts/hub.py
@@ -0,0 +1,90 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing_extensions import Self

from datasets import load_dataset
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.models.database import Dataset
from argilla_server.search_engine import SearchEngine
from argilla_server.bulk.records_bulk import CreateRecordsBulk
from argilla_server.api.schemas.v1.records import RecordCreate as RecordCreateSchema
from argilla_server.api.schemas.v1.records_bulk import RecordsBulkCreate as RecordsBulkCreateSchema

BATCH_SIZE = 100


class HubDataset:
# TODO: (Ben feedback) rename `name` to `repository_id` or `repo_id`
# TODO: (Ben feedback) check subset and split and see if we should support None
def __init__(self, name: str, subset: str, split: str):
self.dataset = load_dataset(path=name, name=subset, split=split)
self.iterable_dataset = self.dataset.to_iterable_dataset()

@property
def num_rows(self) -> int:
return self.dataset.num_rows

def take(self, n: int) -> Self:
self.iterable_dataset = self.iterable_dataset.take(n)

return self

# TODO: We can change things so we get the database and search engine here instead of receiving them as parameters
async def import_to(self, db: AsyncSession, search_engine: SearchEngine, dataset: Dataset) -> None:
if not dataset.is_ready:
raise Exception("it's not possible to import records to a non published dataset")

batched_dataset = self.iterable_dataset.batch(batch_size=BATCH_SIZE)
for batch in batched_dataset:
await self._import_batch_to(db, search_engine, batch, dataset)

async def _import_batch_to(
self, db: AsyncSession, search_engine: SearchEngine, batch: dict, dataset: Dataset
) -> None:
batch_size = len(next(iter(batch.values())))

items = []
for i in range(batch_size):
# NOTE: if there is a value with key "id" in the batch, we will use it as external_id
external_id = None
if "id" in batch:
external_id = batch["id"][i]

fields = {}
for field in dataset.fields:
# TODO: Should we cast to string or change the schema to use not strict string?
value = batch[field.name][i]
if field.is_text:
value = str(value)

fields[field.name] = value

metadata = {}
for metadata_property in dataset.metadata_properties:
metadata[metadata_property.name] = batch[metadata_property.name][i]

items.append(
RecordCreateSchema(
fields=fields,
metadata=metadata,
external_id=external_id,
responses=None,
suggestions=None,
vectors=None,
),
)

await CreateRecordsBulk(db, search_engine).create_records_bulk(dataset, RecordsBulkCreateSchema(items=items))
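
The core of `_import_batch_to` is a column-to-row conversion: a batch arrives as a dict mapping each column name to a list of values, and each row index becomes one record. The sketch below isolates that step using plain dicts instead of `RecordCreateSchema`. The function name `batch_to_items`, its parameters, and the sample batch are hypothetical; the logic follows the method above, including the use of an `id` column as `external_id` when one is present.

```python
from typing import Any, Dict, List, Set


def batch_to_items(
    batch: Dict[str, List[Any]],
    field_names: List[str],
    text_fields: Set[str],
    metadata_names: List[str],
) -> List[Dict[str, Any]]:
    """Convert a columnar batch into one item per row."""
    # All columns have the same length, so any column gives the batch size.
    batch_size = len(next(iter(batch.values())))

    items = []
    for i in range(batch_size):
        # An "id" column, when present, is used as the external_id.
        external_id = batch["id"][i] if "id" in batch else None

        fields = {}
        for name in field_names:
            value = batch[name][i]
            if name in text_fields:
                # Text fields are cast to str, as in the original method.
                value = str(value)
            fields[name] = value

        metadata = {name: batch[name][i] for name in metadata_names}

        items.append({"fields": fields, "metadata": metadata, "external_id": external_id})

    return items


batch = {"id": ["a", "b"], "text": ["good film", 42], "stars": [5, 1]}
items = batch_to_items(batch, field_names=["text"], text_fields={"text"}, metadata_names=["stars"])
```

Note that the non-string value `42` in the `text` column ends up as the string `"42"`, while metadata values keep their original types.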
2 changes: 1 addition & 1 deletion argilla-server/src/argilla_server/jobs/dataset_jobs.py
@@ -31,7 +31,7 @@


@job(DEFAULT_QUEUE, timeout=JOB_TIMEOUT_DISABLED, retry=Retry(max=3))
-async def update_dataset_records_status_job(dataset_id: UUID):
+async def update_dataset_records_status_job(dataset_id: UUID) -> None:
"""This Job updates the status of all the records in the dataset when the distribution strategy changes."""

record_ids = []
48 changes: 48 additions & 0 deletions argilla-server/src/argilla_server/jobs/hub_jobs.py
@@ -0,0 +1,48 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from uuid import UUID

from rq import Retry
from rq.decorators import job
from sqlalchemy.orm import selectinload

from argilla_server.models import Dataset
from argilla_server.settings import settings
from argilla_server.contexts.hub import HubDataset
from argilla_server.database import AsyncSessionLocal
from argilla_server.search_engine.base import SearchEngine
from argilla_server.jobs.queues import DEFAULT_QUEUE

# TODO: Move this to be defined on jobs queues as a shared constant
JOB_TIMEOUT_DISABLED = -1


# TODO: Once we merge webhooks we should change the queue to use a different one (default queue is deleted there)
@job(DEFAULT_QUEUE, timeout=JOB_TIMEOUT_DISABLED, retry=Retry(max=3))
async def import_dataset_from_hub_job(name: str, subset: str, split: str, dataset_id: UUID) -> None:
hub_dataset = HubDataset(name, subset, split)

async with AsyncSessionLocal() as db:
async with SearchEngine.get_by_name(settings.search_engine) as search_engine:
dataset = await Dataset.get_or_raise(

db,
dataset_id,
options=[
selectinload(Dataset.fields),
selectinload(Dataset.metadata_properties),
],
)

await hub_dataset.import_to(db, search_engine, dataset)
