Skip to content

Commit

Permalink
Merge branch 'docs/ui-import-from-hub' of github.com:argilla-io/argil…
Browse files Browse the repository at this point in the history
…la into docs/ui-import-from-hub
  • Loading branch information
nataliaElv committed Oct 28, 2024
2 parents 1171445 + 5f4ddf9 commit 7f25c0f
Show file tree
Hide file tree
Showing 65 changed files with 2,648 additions and 1,059 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/argilla.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
build:
services:
argilla-server:
image: argilladev/argilla-hf-spaces:pr-5573
image: argilladev/argilla-hf-spaces:pr-5572
ports:
- 6900:6900
env:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
width="18"
height="18"
color="#F6C000"
aria-hidden="true"
></svgicon
>{{ $t("home.importFromHub") }}</BaseButton
>
Expand All @@ -18,7 +19,8 @@
class="import-from-hub__close-button"
@click="$emit('on-close')"
>
<svgicon name="close" width="8"></svgicon>Close</BaseButton
<svgicon name="close" width="8" aria-hidden="true"></svgicon
>Close</BaseButton
>
<form @submit.prevent="$emit('on-import-dataset', repositoryId)">
<transition name="slide-right" appear>
Expand All @@ -31,6 +33,7 @@
name="link"
width="20"
height="20"
aria-hidden="true"
></svgicon>
<BaseInput
v-model="repositoryId"
Expand Down
12 changes: 1 addition & 11 deletions argilla-frontend/pages/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
@on-import-dataset="importDataset"
:error="error"
/>
<ImportFromPython v-if="!showImportDatasetInput" />
</div>
<div class="home__sidebar__content">
<p
Expand Down Expand Up @@ -85,16 +84,7 @@
link="https://docs.argilla.io/dev/how_to_guides/query/"
/>
</div>
<p class="home__sidebar__link">
Log to
<a
href="https://huggingface.co/spaces/argilla/argilla-template-space"
target="_blank"
>
Argilla_template_space</a
>
to try it out
</p>
<p class="home__sidebar__link" v-html="$t('home.demoLink')" />
</div>
</template>
</template>
Expand Down
2 changes: 2 additions & 0 deletions argilla-frontend/translation/en.js
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ export default {
guidesTitle: "First time in Argilla?",
guidesText: "Take a look at these guides:",
pasteRepoIdPlaceholder: "Paste a repo id",
demoLink:
"Log into this <a href='https://huggingface.co/spaces/argilla/argilla-template-space' target='_blank'>demo</a> to try Argilla out",
},
datasetCreation: {
questions: {
Expand Down
2 changes: 2 additions & 0 deletions argilla-frontend/translation/es.js
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ export default {
guidesTitle: "¿Primera vez en Argilla?",
guidesText: "Echa un vistazo a estas guías:",
pasteRepoIdPlaceholder: "Pega un repo id",
demoLink:
"Entra en esta <a href='https://huggingface.co/spaces/argilla/argilla-template-space' target='_blank'>demo</a> para probar Argilla",
},
datasetCreation: {
questions: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,12 @@ describe("DatasetCreation", () => {
expect(labelQuestion.required).toBeTruthy();
expect(labelQuestion.options).toEqual([
{
id: "positive",
text: "positive",
value: "positive",
},
{
id: "negative",
text: "negative",
value: "negative",
},
Expand Down
4 changes: 4 additions & 0 deletions argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ These are the section headers that we use:
- Removed name pattern restrictions for Vector Settings. ([#5573](https://github.com/argilla-io/argilla/pull/5573))
- Removed name pattern validation for Workspaces, Datasets, and Users. ([#5575](https://github.com/argilla-io/argilla/pull/5575))

### Fixed

- Fixed wrong field content conversion for empty text and partial chat fields. ([#5600](https://github.com/argilla-io/argilla/pull/5600))

## [2.3.1](https://github.com/argilla-io/argilla/compare/v2.3.0...v2.3.1)

### Fixed
Expand Down
504 changes: 268 additions & 236 deletions argilla-server/pdm.lock

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions argilla-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ dependencies = [
"backoff>=1.11.1",
# Database dependencies
"alembic ~= 1.9.0",
"SQLAlchemy == 2.0.31",
"greenlet >= 2.0.0",
"SQLAlchemy == 2.0.35",
"greenlet ~= 3.1.0",
# Async SQLite
"aiosqlite == 0.20.0",
# metrics
Expand Down Expand Up @@ -59,6 +59,9 @@ dependencies = [
"typer >= 0.6.0, < 0.10.0", # spaCy only supports typer<0.10.0
"packaging>=23.2",
"psycopg2-binary>=2.9.9",
# For HF dataset import
"datasets>=3.0.1",
"pillow>=10.4.0",
# For Telemetry
"huggingface_hub>=0.13,<1",
]
Expand Down Expand Up @@ -100,7 +103,6 @@ test = [
"factory-boy~=3.2.1",
"httpx>=0.26.0",
# Required by tests/unit/utils/test_dependency.py, but we should take a look and probably remove them
"datasets > 1.17.0,!= 2.3.2",
"spacy>=3.5.0,<3.7.0",
"pytest-randomly>=3.15.0",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
DatasetProgress,
Datasets,
DatasetUpdate,
HubDataset,
UsersProgress,
)
from argilla_server.api.schemas.v1.fields import Field, FieldCreate, Fields
Expand All @@ -38,9 +39,11 @@
MetadataPropertyCreate,
)
from argilla_server.api.schemas.v1.vector_settings import VectorSettings, VectorSettingsCreate, VectorsSettings
from argilla_server.api.schemas.v1.jobs import Job as JobSchema
from argilla_server.contexts import datasets
from argilla_server.database import get_async_db
from argilla_server.enums import DatasetStatus
from argilla_server.jobs import hub_jobs
from argilla_server.models import Dataset, User
from argilla_server.search_engine import (
SearchEngine,
Expand Down Expand Up @@ -143,37 +146,44 @@ async def get_dataset(
@router.get("/me/datasets/{dataset_id}/metrics", response_model=DatasetMetrics)
async def get_current_user_dataset_metrics(
*,
db: AsyncSession = Depends(get_async_db),
dataset_id: UUID,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
current_user: User = Security(auth.get_current_user),
):
dataset = await Dataset.get_or_raise(db, dataset_id)

await authorize(current_user, DatasetPolicy.get(dataset))

return await datasets.get_user_dataset_metrics(db, current_user.id, dataset.id)
result = await datasets.get_user_dataset_metrics(db, search_engine, current_user, dataset)

return DatasetMetrics(responses=result)


@router.get("/datasets/{dataset_id}/progress", response_model=DatasetProgress)
async def get_dataset_progress(
*,
db: AsyncSession = Depends(get_async_db),
dataset_id: UUID,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
current_user: User = Security(auth.get_current_user),
):
dataset = await Dataset.get_or_raise(db, dataset_id)

await authorize(current_user, DatasetPolicy.get(dataset))

return await datasets.get_dataset_progress(db, dataset.id)
result = await datasets.get_dataset_progress(search_engine, dataset)

return DatasetProgress(**result)


@router.get("/datasets/{dataset_id}/users/progress", response_model=UsersProgress)
async def get_dataset_users_progress(
*,
current_user: User = Security(auth.get_current_user),
dataset_id: UUID,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
current_user: User = Security(auth.get_current_user),
):
dataset = await Dataset.get_or_raise(db, dataset_id)

Expand Down Expand Up @@ -301,3 +311,27 @@ async def update_dataset(
await authorize(current_user, DatasetPolicy.update(dataset))

return await datasets.update_dataset(db, dataset, dataset_update.dict(exclude_unset=True))


# TODO: Maybe change /import to /import-from-hub?
@router.post("/datasets/{dataset_id}/import", status_code=status.HTTP_202_ACCEPTED, response_model=JobSchema)
async def import_dataset_from_hub(
    *,
    db: AsyncSession = Depends(get_async_db),
    dataset_id: UUID,
    hub_dataset: HubDataset,
    current_user: User = Security(auth.get_current_user),
):
    """Start an import of a Hub dataset into an existing Argilla dataset.

    The target dataset is loaded, the caller is checked against
    ``DatasetPolicy.import_from_hub``, and the import itself is handed to
    ``hub_jobs.import_dataset_from_hub_job.delay`` instead of being run in
    the request. The endpoint answers with 202 Accepted and a job
    descriptor holding the job id and its current status.

    NOTE(review): ``.delay`` presumably enqueues a background job on a
    worker queue; confirm against ``argilla_server.jobs.hub_jobs``.
    """
    dataset = await Dataset.get_or_raise(db, dataset_id)

    await authorize(current_user, DatasetPolicy.import_from_hub(dataset))

    # The job receives plain values (the mapping as a dict), not the request model.
    job_arguments = {
        "name": hub_dataset.name,
        "subset": hub_dataset.subset,
        "split": hub_dataset.split,
        "dataset_id": dataset.id,
        "mapping": hub_dataset.mapping.dict(),
    }
    job = hub_jobs.import_dataset_from_hub_job.delay(**job_arguments)

    return JobSchema(id=job.id, status=job.get_status())
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def create_dataset_records_bulk(
async def upsert_dataset_records_bulk(
*,
dataset_id: UUID,
records_bulk_create: RecordsBulkUpsert,
records_bulk_upsert: RecordsBulkUpsert,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
current_user: User = Security(auth.get_current_user),
Expand All @@ -86,7 +86,7 @@ async def upsert_dataset_records_bulk(

await authorize(current_user, DatasetPolicy.upsert_records(dataset))

records_bulk = await UpsertRecordsBulk(db, search_engine).upsert_records_bulk(dataset, records_bulk_create)
records_bulk = await UpsertRecordsBulk(db, search_engine).upsert_records_bulk(dataset, records_bulk_upsert)

updated = len(records_bulk.updated_item_ids)
created = len(records_bulk.items) - updated
Expand Down
52 changes: 52 additions & 0 deletions argilla-server/src/argilla_server/api/handlers/v1/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from fastapi import APIRouter, Depends, HTTPException, Security, status
from sqlalchemy.ext.asyncio import AsyncSession

from rq.job import Job
from rq.exceptions import NoSuchJobError

from argilla_server.database import get_async_db
from argilla_server.jobs.queues import REDIS_CONNECTION
from argilla_server.models import User
from argilla_server.api.policies.v1 import JobPolicy, authorize
from argilla_server.api.schemas.v1.jobs import Job as JobSchema
from argilla_server.security import auth

router = APIRouter(tags=["jobs"])


def _get_job(job_id: str) -> Job:
    """Fetch the job identified by ``job_id`` using ``REDIS_CONNECTION``.

    Raises:
        HTTPException: with status 404 when ``Job.fetch`` raises ``NoSuchJobError``.
    """
    try:
        found_job = Job.fetch(job_id, connection=REDIS_CONNECTION)
    except NoSuchJobError:
        # Translate the queue-level error into an HTTP error for API handlers.
        missing_detail = f"Job with id `{job_id}` not found"
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=missing_detail)
    else:
        return found_job


@router.get("/jobs/{job_id}", response_model=JobSchema)
async def get_job(
    *,
    db: AsyncSession = Depends(get_async_db),
    job_id: str,
    current_user: User = Security(auth.get_current_user),
):
    """Return the id and refreshed status of a background job.

    Args:
        db: Async database session dependency; not used directly in this
            handler body.
        job_id: Identifier of the job to look up.
        current_user: Authenticated user making the request.

    Returns:
        A ``JobSchema`` with the job id and its status, refreshed from the
        job backend.

    Raises:
        HTTPException: 404 (raised by ``_get_job``) when no job with
            ``job_id`` exists.
    """
    # Authorize before touching the job store. ``JobPolicy.get`` only
    # inspects the actor, so checking it first keeps callers who are not
    # allowed to read jobs from telling existing ids apart from missing
    # ones, and skips a pointless lookup for them.
    await authorize(current_user, JobPolicy.get)

    job = _get_job(job_id)

    return JobSchema(id=job.id, status=job.get_status(refresh=True))
2 changes: 2 additions & 0 deletions argilla-server/src/argilla_server/api/policies/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from argilla_server.api.policies.v1.vector_settings_policy import VectorSettingsPolicy
from argilla_server.api.policies.v1.workspace_policy import WorkspacePolicy
from argilla_server.api.policies.v1.workspace_user_policy import WorkspaceUserPolicy
from argilla_server.api.policies.v1.job_policy import JobPolicy

__all__ = [
"DatasetPolicy",
Expand All @@ -37,6 +38,7 @@
"VectorSettingsPolicy",
"WorkspacePolicy",
"WorkspaceUserPolicy",
"JobPolicy",
"authorize",
"is_authorized",
]
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,10 @@ async def is_allowed(actor: User) -> bool:
return actor.is_owner or (actor.is_admin and await actor.is_member(dataset.workspace_id))

return is_allowed

@classmethod
def import_from_hub(cls, dataset: Dataset) -> PolicyAction:
    """Build the policy action guarding Hub imports into ``dataset``.

    The returned coroutine function allows owners unconditionally, and
    admins only when they are members of the dataset's workspace.
    """

    async def is_allowed(actor: User) -> bool:
        owner_flag = actor.is_owner
        if owner_flag:
            return owner_flag

        # Membership is only looked up for admins (short-circuit on is_admin).
        return actor.is_admin and await actor.is_member(dataset.workspace_id)

    return is_allowed
21 changes: 21 additions & 0 deletions argilla-server/src/argilla_server/api/policies/v1/job_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla_server.models import User


class JobPolicy:
    """Authorization rules for job resources."""

    @classmethod
    async def get(cls, actor: User) -> bool:
        """Tell whether ``actor`` may read a job: owners and admins only."""
        owner_flag = actor.is_owner
        return owner_flag if owner_flag else actor.is_admin
2 changes: 2 additions & 0 deletions argilla-server/src/argilla_server/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from argilla_server.api.handlers.v1 import (
workspaces as workspaces_v1,
)
from argilla_server.api.handlers.v1 import jobs as jobs_v1
from argilla_server.errors.base_errors import __ALL__
from argilla_server.errors.error_handler import APIErrorHandler

Expand Down Expand Up @@ -92,6 +93,7 @@ def create_api_v1():
users_v1.router,
vectors_settings_v1.router,
workspaces_v1.router,
jobs_v1.router,
oauth2_v1.router,
settings_v1.router,
]:
Expand Down
28 changes: 28 additions & 0 deletions argilla-server/src/argilla_server/api/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,31 @@ class DatasetUpdate(UpdateSchema):
metadata_: Optional[Dict[str, Any]] = Field(None, alias="metadata")

__non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution"}


class HubDatasetMappingItem(BaseModel):
    """One column-to-resource pairing used when importing a Hub dataset."""

    # Both values are required (``...`` marks the field as mandatory).
    source: str = Field(
        ...,
        description="The name of the column in the Hub Dataset",
    )
    target: str = Field(
        ...,
        description="The name of the target resource in the Argilla Dataset",
    )


class HubDatasetMapping(BaseModel):
    """Describes how Hub dataset columns map onto an Argilla dataset.

    ``fields`` must contain at least one item. ``metadata`` and
    ``suggestions`` default to empty lists, and ``external_id`` optionally
    names the Hub column to use as the record external id.
    """

    fields: List[HubDatasetMappingItem] = Field(..., min_items=1)
    metadata: Optional[List[HubDatasetMappingItem]] = []
    suggestions: Optional[List[HubDatasetMappingItem]] = []
    external_id: Optional[str] = None

    @property
    def sources(self) -> List[str]:
        """Return the distinct Hub column names referenced by this mapping.

        Names are listed in first-seen order: fields, then metadata, then
        suggestions, then ``external_id`` when it is set and non-empty.
        """
        # ``metadata`` and ``suggestions`` are Optional, so an explicit null
        # is valid input; treat it as "no items" instead of iterating None.
        mapped_items = [
            *self.fields,
            *(self.metadata or []),
            *(self.suggestions or []),
        ]
        source_names = [item.source for item in mapped_items]

        if self.external_id:
            source_names.append(self.external_id)

        # dict.fromkeys drops duplicates while keeping insertion order, so the
        # result is deterministic (a plain set() would not be).
        return list(dict.fromkeys(source_names))


class HubDataset(BaseModel):
    """Request body identifying a Hub dataset to import and how to map it."""

    # All four attributes are required; none declares a default.
    name: str
    subset: str
    split: str
    # Column mapping applied while importing (see HubDatasetMapping).
    mapping: HubDatasetMapping
Loading

0 comments on commit 7f25c0f

Please sign in to comment.