Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added some visualizations for the HF dataset #2851

Merged
merged 10 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) ZenML GmbH 2021. All Rights Reserved.
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,12 +16,21 @@
import os
from collections import defaultdict
from tempfile import TemporaryDirectory, mkdtemp
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type, Union
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Optional,
Tuple,
Type,
Union,
)

from datasets import Dataset, load_from_disk
from datasets.dataset_dict import DatasetDict

from zenml.enums import ArtifactType
from zenml.enums import ArtifactType, VisualizationType
from zenml.io import fileio
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.materializers.pandas_materializer import PandasMaterializer
Expand All @@ -33,13 +42,38 @@
DEFAULT_DATASET_DIR = "hf_datasets"


def extract_repo_url(checksum_str: str) -> Optional[str]:
"""Extracts the repo url from the checksum URL.
schustmi marked this conversation as resolved.
Show resolved Hide resolved

An example of a checksum_str is:
"hf://datasets/nyu-mll/glue@bcdcba79d07bc864c1c254ccfcedcce55bcc9a8c/mrpc/train-00000-of-00001.parquet"
and the expected output is "nyu-mll/glue".

Args:
checksum_str: The checksum_str to extract the dataset name from.

Returns:
Optional[str]: The extracted dataset name.
"""
schustmi marked this conversation as resolved.
Show resolved Hide resolved
dataset = None
try:
parts = checksum_str.split("/")
if len(parts) >= 4:
# Case: nyu-mll/glue
dataset = f"{parts[3]}/{parts[4].split('@')[0]}"
except Exception: # pylint: disable=broad-except
pass

return dataset


class HFDatasetMaterializer(BaseMaterializer):
"""Materializer to read data to and from huggingface datasets."""

ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Dataset, DatasetDict)
ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = (
ArtifactType.DATA_ANALYSIS
)
ASSOCIATED_ARTIFACT_TYPE: ClassVar[
ArtifactType
] = ArtifactType.DATA_ANALYSIS

def load(
self, data_type: Union[Type[Dataset], Type[DatasetDict]]
Expand Down Expand Up @@ -103,3 +137,53 @@ def extract_metadata(
metadata[key][dataset_name] = value
return dict(metadata)
raise ValueError(f"Unsupported type {type(ds)}")

def save_visualizations(
self, ds: Union[Dataset, DatasetDict]
) -> Dict[str, VisualizationType]:
"""Save visualizations for the dataset.

Args:
ds: The Dataset or DatasetDict to visualize.

Returns:
A dictionary mapping visualization paths to their types.

Raises:
ValueError: If the given object is not a `Dataset` or `DatasetDict`.
"""
visualizations = {}

if isinstance(ds, Dataset):
datasets = {"default": ds}
elif isinstance(ds, DatasetDict):
datasets = ds
else:
raise ValueError(f"Unsupported type {type(ds)}")

for name, dataset in datasets.items():
# Generate a unique identifier for the dataset
dataset_id = extract_repo_url(
[x for x in dataset.info.download_checksums.keys()][0]
)
if dataset_id:
# Create the iframe HTML
html = f"""
<iframe
src="https://huggingface.co/datasets/{dataset_id}/embed/viewer"
frameborder="0"
width="100%"
height="560px"
></iframe>
"""

# Save the HTML to a file
visualization_path = os.path.join(
self.uri, f"{name}_viewer.html"
)
with fileio.open(visualization_path, "w") as f:
f.write(html)

visualizations[visualization_path] = VisualizationType.HTML

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dashboard only works with a single visualization per artifact right now, but that shouldn't stop this PR. Just so you're aware of it.

return visualizations
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from tests.unit.test_general import _test_materializer
from zenml.integrations.huggingface.materializers.huggingface_datasets_materializer import (
HFDatasetMaterializer,
extract_repo_url,
)


Expand All @@ -35,3 +36,63 @@ def test_huggingface_datasets_materializer(clean_client):
data = dataset.data.to_pydict()
assert "0" in data.keys()
assert [1, 2, 3] in data.values()


def test_extract_repo_url():
"""Tests whether the extract_repo_url function works correctly."""
# Test valid URL
url = "hf://datasets/nyu-mll/glue@bcdcba79d07bc864c1c254ccfcedcce55bcc9a8c/mrpc/train-00000-of-00001.parquet"
assert extract_repo_url(url) == "nyu-mll/glue"

# Test valid URL with different dataset
url = "hf://datasets/huggingface/dataset-name@123456/subset/file.parquet"
assert extract_repo_url(url) == "huggingface/dataset-name"

# Test URL without file
url = "hf://datasets/org/repo@commit"
assert extract_repo_url(url) == "org/repo"

# Test URL with extra parts
url = "hf://datasets/org/repo/extra/parts@commit/file.parquet"
assert extract_repo_url(url) == "org/repo"

# Test invalid URL (too short)
url = "hf://datasets/org"
assert extract_repo_url(url) is None

# Test invalid URL format
url = "https://huggingface.co/datasets/org/repo"
assert extract_repo_url(url) is None

# Test empty string
assert extract_repo_url("") is None

# Test None input
assert extract_repo_url(None) is None


def test_extract_repo_url_edge_cases():
"""Tests edge cases for the extract_repo_url function."""
# Test URL with no '@' symbol
url = "hf://datasets/org/repo/file.parquet"
assert extract_repo_url(url) == "org/repo"

# Test URL with multiple '@' symbols
url = "hf://datasets/org/repo@commit@extra/file.parquet"
assert extract_repo_url(url) == "org/repo"

# Test URL with special characters in repo name
url = "hf://datasets/org-name/repo_name-123@commit/file.parquet"
assert extract_repo_url(url) == "org-name/repo_name-123"


def test_extract_repo_url_exceptions():
"""Tests exception handling in the extract_repo_url function."""
# Test with non-string input
assert extract_repo_url(123) is None

# Test with list input
assert extract_repo_url(["not", "a", "string"]) is None

# Test with dict input
assert extract_repo_url({"not": "a string"}) is None
Loading