Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve HubDataset image processing support #5606

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions argilla-server/src/argilla_server/api/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,15 @@ class HubDatasetMapping(BaseModel):
suggestions: Optional[List[HubDatasetMappingItem]] = []
external_id: Optional[str] = None

@property
def sources(self) -> List[str]:
    """Return the unique hub-dataset column names referenced by this mapping.

    Collects the `source` of every field, metadata and suggestion mapping
    item, plus the external id column when one is configured.

    Duplicates are removed while preserving first-seen order (fields, then
    metadata, then suggestions, then external id) so the result is
    deterministic — `set()` ordering varies with string hash randomization.
    """
    fields_sources = [field.source for field in self.fields]
    metadata_sources = [metadata.source for metadata in self.metadata]
    suggestions_sources = [suggestion.source for suggestion in self.suggestions]
    external_id_source = [self.external_id] if self.external_id else []

    return list(dict.fromkeys(fields_sources + metadata_sources + suggestions_sources + external_id_source))


class HubDataset(BaseModel):
name: str
Expand Down
68 changes: 48 additions & 20 deletions argilla-server/src/argilla_server/contexts/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from datasets import load_dataset
from sqlalchemy.ext.asyncio import AsyncSession


from argilla_server.models.database import Dataset
from argilla_server.search_engine import SearchEngine
from argilla_server.bulk.records_bulk import UpsertRecordsBulk
Expand All @@ -34,13 +33,24 @@
# Number of rows pulled from the hub dataset per bulk-upsert batch.
BATCH_SIZE = 100
# Initial value for the auto-incrementing row index used to generate
# external ids when the mapping does not provide one.
RESET_ROW_IDX = -1

# `datasets` feature `_type` discriminators that get special value handling
# during import (label decoding / image encoding).
FEATURE_TYPE_IMAGE = "Image"
FEATURE_TYPE_CLASS_LABEL = "ClassLabel"

# Fallbacks used when a PIL image carries no format information.
DATA_URL_DEFAULT_IMAGE_FORMAT = "png"
DATA_URL_DEFAULT_IMAGE_MIMETYPE = "image/png"


class HubDataset:
def __init__(self, name: str, subset: str, split: str, mapping: HubDatasetMapping):
    """Open hub dataset `name` (given subset and split) in streaming mode.

    Keeps the user-provided mapping and precomputes the set of source
    column names it references; the row index starts at its reset value.
    """
    self.row_idx = RESET_ROW_IDX
    self.mapping = mapping
    self.mapping_feature_names = mapping.sources
    self.dataset = load_dataset(path=name, name=subset, split=split, streaming=True)

@property
def features(self) -> dict:
    """Column name -> `datasets` feature definition for the loaded dataset."""
    return self.dataset.features

def take(self, n: int) -> Self:
self.dataset = self.dataset.take(n)

Expand Down Expand Up @@ -71,53 +81,68 @@ async def _import_batch_to(

items = []
for i in range(batch_size):
items.append(self._batch_row_to_record_schema(batch, i, dataset))
items.append(self._row_to_record_schema(self._batch_index_to_row(batch, i), dataset))

await UpsertRecordsBulk(db, search_engine).upsert_records_bulk(
dataset,
RecordsBulkUpsertSchema(items=items),
raise_on_error=False,
)

def _batch_row_to_record_schema(self, batch: dict, index: int, dataset: Dataset) -> RecordUpsertSchema:
def _batch_index_to_row(self, batch: dict, index: int) -> dict:
    """Extract row `index` from a columnar batch, normalizing feature values.

    Only columns referenced by the mapping are kept. ClassLabel integers are
    converted to their string labels and PIL images to data URLs so the row
    holds plain, JSON-friendly values.
    """
    row = {}
    for feature_name, values in batch.items():
        # Skip columns the user did not map to any dataset attribute.
        if feature_name not in self.mapping_feature_names:
            continue

        value = values[index]
        feature = self.features[feature_name]

        if feature._type == FEATURE_TYPE_CLASS_LABEL:
            # Store the human-readable label instead of the class index.
            row[feature_name] = feature.int2str(value)
        elif feature._type == FEATURE_TYPE_IMAGE and isinstance(value, Image.Image):
            row[feature_name] = pil_image_to_data_url(value)
        else:
            row[feature_name] = value

    return row

def _row_to_record_schema(self, row: dict, dataset: Dataset) -> RecordUpsertSchema:
    """Translate one normalized row into a record upsert schema for `dataset`."""
    return RecordUpsertSchema(
        id=None,
        responses=None,
        vectors=None,
        external_id=self._row_external_id(row),
        fields=self._row_fields(row, dataset),
        metadata=self._row_metadata(row, dataset),
        suggestions=self._row_suggestions(row, dataset),
    )

def _batch_row_external_id(self, batch: dict, index: int) -> str:
def _row_external_id(self, row: dict) -> str:
if not self.mapping.external_id:
return str(self._next_row_idx())

return batch[self.mapping.external_id][index]
return row[self.mapping.external_id]

def _batch_row_fields(self, batch: dict, index: int, dataset: Dataset) -> dict:
def _row_fields(self, row: dict, dataset: Dataset) -> dict:
fields = {}
for mapping_field in self.mapping.fields:
value = batch[mapping_field.source][index]
value = row[mapping_field.source]
field = dataset.field_by_name(mapping_field.target)
if not field:
continue

if field.is_text and value is not None:
value = str(value)

if field.is_image and isinstance(value, Image.Image):
value = pil_image_to_data_url(value)

fields[field.name] = value

return fields

def _batch_row_metadata(self, batch: dict, index: int, dataset: Dataset) -> dict:
def _row_metadata(self, row: dict, dataset: Dataset) -> dict:
metadata = {}
for mapping_metadata in self.mapping.metadata:
value = batch[mapping_metadata.source][index]
value = row[mapping_metadata.source]
metadata_property = dataset.metadata_property_by_name(mapping_metadata.target)
if not metadata_property:
continue
Expand All @@ -126,10 +151,10 @@ def _batch_row_metadata(self, batch: dict, index: int, dataset: Dataset) -> dict

return metadata

def _batch_row_suggestions(self, batch: dict, index: int, dataset: Dataset) -> list:
def _row_suggestions(self, row: dict, dataset: Dataset) -> list:
suggestions = []
for mapping_suggestion in self.mapping.suggestions:
value = batch[mapping_suggestion.source][index]
value = row[mapping_suggestion.source]
question = dataset.question_by_name(mapping_suggestion.target)
if not question:
continue
Expand All @@ -156,8 +181,11 @@ def _batch_row_suggestions(self, batch: dict, index: int, dataset: Dataset) -> l
def pil_image_to_data_url(image: Image.Image) -> str:
    """Serialize a PIL image into a `data:` URL with a base64-encoded payload.

    Images with no known source format (e.g. created in memory) fall back to
    PNG for both the encoded format and the mimetype.
    """
    buffer = io.BytesIO()

    image_format = image.format or DATA_URL_DEFAULT_IMAGE_FORMAT
    image_mimetype = image.get_format_mimetype() if image.format else DATA_URL_DEFAULT_IMAGE_MIMETYPE

    # Convert to RGB so formats without alpha support (e.g. JPEG) can be saved.
    # NOTE(review): this drops the alpha channel for RGBA inputs — confirm intended.
    image.convert("RGB").save(buffer, format=image_format)

    base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")

    return f"data:{image_mimetype};base64,{base64_image}"
76 changes: 74 additions & 2 deletions argilla-server/tests/unit/contexts/hub/test_hub_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from tests.factories import (
DatasetFactory,
ImageFieldFactory,
RatingQuestionFactory,
QuestionFactory,
TextFieldFactory,
IntegerMetadataPropertyFactory,
)
Expand Down Expand Up @@ -86,7 +86,7 @@ async def test_hub_dataset_import_to_with_suggestions(self, db: AsyncSession, mo
await TextFieldFactory.create(name="package_name", required=True, dataset=dataset)
await TextFieldFactory.create(name="review", required=True, dataset=dataset)

question = await RatingQuestionFactory.create(
question = await QuestionFactory.create(
name="star",
required=True,
dataset=dataset,
Expand Down Expand Up @@ -127,6 +127,78 @@ async def test_hub_dataset_import_to_with_suggestions(self, db: AsyncSession, mo
assert record.suggestions[0].value == 4
assert record.suggestions[0].question_id == question.id

async def test_hub_dataset_import_to_with_class_label_suggestions(
    self, db: AsyncSession, mock_search_engine: SearchEngine
):
    """ClassLabel columns mapped to suggestions are imported as their string labels."""
    dataset = await DatasetFactory.create(status=DatasetStatus.ready)

    await TextFieldFactory.create(name="text", required=True, dataset=dataset)

    question = await QuestionFactory.create(
        name="label",
        settings={
            "type": QuestionType.label_selection,
            "options": [
                {"value": "neg", "text": "Negative"},
                {"value": "pos", "text": "Positive"},
            ],
        },
        dataset=dataset,
    )

    # Pre-load relationships that the import accesses lazily.
    await dataset.awaitable_attrs.fields
    await dataset.awaitable_attrs.questions
    await dataset.awaitable_attrs.metadata_properties

    hub_dataset = HubDataset(
        name="stanfordnlp/imdb",
        subset="plain_text",
        split="train",
        mapping=HubDatasetMapping(
            fields=[
                HubDatasetMappingItem(source="text", target="text"),
            ],
            suggestions=[
                HubDatasetMappingItem(source="label", target="label"),
            ],
        ),
    )

    await hub_dataset.take(1).import_to(db, mock_search_engine, dataset)

    # The integer class id must have been converted to its "neg" string label.
    record = (await db.execute(select(Record))).scalar_one()
    assert record.suggestions[0].value == "neg"
    assert record.suggestions[0].question_id == question.id

async def test_hub_dataset_import_to_with_class_label_fields(
    self, db: AsyncSession, mock_search_engine: SearchEngine
):
    """ClassLabel columns mapped to text fields are stored as their string labels."""
    dataset = await DatasetFactory.create(status=DatasetStatus.ready)

    await TextFieldFactory.create(name="text", required=True, dataset=dataset)
    await TextFieldFactory.create(name="label", required=True, dataset=dataset)

    # Pre-load relationships that the import accesses lazily.
    await dataset.awaitable_attrs.fields
    await dataset.awaitable_attrs.questions
    await dataset.awaitable_attrs.metadata_properties

    hub_dataset = HubDataset(
        name="stanfordnlp/imdb",
        subset="plain_text",
        split="train",
        mapping=HubDatasetMapping(
            fields=[
                HubDatasetMappingItem(source="text", target="text"),
                HubDatasetMappingItem(source="label", target="label"),
            ],
        ),
    )

    await hub_dataset.take(1).import_to(db, mock_search_engine, dataset)

    # The integer class id must have been converted to its "neg" string label.
    record = (await db.execute(select(Record))).scalar_one()
    assert record.fields["label"] == "neg"

async def test_hub_dataset_import_to_with_image_fields(self, db: AsyncSession, mock_search_engine: SearchEngine):
dataset = await DatasetFactory.create(status=DatasetStatus.ready)

Expand Down