Skip to content

Commit

Permalink
Improved demo datasets (#325)
Browse files Browse the repository at this point in the history
Adds more tests, and fixes several typing issues. Now, all data types should be correctly stored and be reusable.
  • Loading branch information
qubixes authored Sep 9, 2024
1 parent 42ffed6 commit 93f7ab0
Show file tree
Hide file tree
Showing 8 changed files with 350 additions and 119 deletions.
253 changes: 171 additions & 82 deletions metasyn/demo/dataset.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,195 @@
"""Create and retrieve demo datasets."""

# import random
import string
from abc import ABC, abstractmethod, abstractproperty
from datetime import date, datetime, time, timedelta
from pathlib import Path

import numpy as np
import polars as pl

from metasyn.varspec import VarSpec

try:
    from importlib_resources import files
except ImportError:
    from importlib.resources import files  # type: ignore

# import numpy as np
# import wget
_AVAILABLE_DATASETS = {}


def register(*args):
    """Register a dataset class so that it can be found by name.

    Used as a class decorator: instantiates the class once and stores the
    instance in ``_AVAILABLE_DATASETS`` under its ``name``, then returns the
    class unchanged.
    """
    def _wrap(cls):
        # Instantiate only once; the original created two throwaway instances
        # (one to read .name, one to store).
        instance = cls()
        _AVAILABLE_DATASETS[instance.name] = instance
        return cls
    return _wrap(*args)


# from metasyn.distribution.datetime import (
# DateTimeUniformDistribution,
# DateUniformDistribution,
# TimeUniformDistribution,
# )
class BaseDataset(ABC):
    """Base class for demo datasets.

    Subclasses define a ``name`` and a polars ``schema``; the backing CSV is
    expected to ship with this package as ``demo_<name>.csv``.
    """

    # ``abstractproperty`` is deprecated since Python 3.3; the stacked
    # ``@property`` + ``@abstractmethod`` form is the modern equivalent.
    @property
    @abstractmethod
    def name(self):
        """Short identifier used to look the dataset up by name."""

    @property
    def file_location(self):
        """Path to the packaged CSV file for this dataset."""
        return files(__package__) / f"demo_{self.name}.csv"

    def get_dataframe(self):
        """Read the dataset into a polars DataFrame, applying the schema overrides."""
        return pl.read_csv(self.file_location, schema_overrides=self.schema,
                           try_parse_dates=True)

    @property
    @abstractmethod
    def schema(self):
        """Mapping of column name to polars dtype, used as read_csv overrides."""

    @property
    def var_specs(self):
        """Variable specifications for metasyn; none by default."""
        return []

# # Add a date column.
# date_dist = DateUniformDistribution.default_distribution()
# dframe["Birthday"] = [date_dist.draw() if np.random.rand() < 0.9 else pd.NA
# for _ in range(len(dframe))]

# # Add a time column.
@register
class TitanicDataset(BaseDataset):
    """Included in pandas, but post-processed to contain more columns."""

    @property
    def name(self):
        """Identifier for this dataset."""
        return "titanic"

    @property
    def schema(self):
        """Read sex and embarkation point as categorical columns."""
        return dict(Sex=pl.Categorical, Embarked=pl.Categorical)

    @property
    def var_specs(self):
        """Passenger identifiers are unique."""
        spec = VarSpec("PassengerId", unique=True)
        return [spec]

@register
class SpaceShipDataset(BaseDataset):
    """CC-BY from https://www.kaggle.com/competitions/spaceship-titanic."""

    @property
    def name(self):
        """Identifier for this dataset."""
        return "spaceship"

    @property
    def schema(self):
        """Every overridden column is categorical."""
        categorical_columns = (
            "HomePlanet",
            "CryoSleep",
            "VIP",
            "Destination",
            "Transported",
        )
        return {column: pl.Categorical for column in categorical_columns}


@register
class FruitDataset(BaseDataset):
    """Very basic example data from Polars."""

    @property
    def name(self):
        """Identifier for this dataset."""
        return "fruit"

    @property
    def schema(self):
        """Fruits and cars are both categorical columns."""
        return dict(fruits=pl.Categorical, cars=pl.Categorical)

    @property
    def var_specs(self):
        # NOTE(review): "ID" and "B" don't look like fruit-dataset column
        # names — confirm they exist in demo_fruit.csv.
        specs = [VarSpec("ID", unique=True)]
        specs.append(VarSpec("B", unique=False))
        return specs


@register
class SurveyDataset(BaseDataset):
    """Columns from ESS round 11 Human Values Scale questionnaire for the Netherlands."""

    @property
    def name(self):
        """Identifier for this dataset."""
        return "survey"

    @property
    def schema(self):
        """No dtype overrides; every column keeps its inferred type."""
        return {}


@register
class TestDataset(BaseDataset):
    """Test dataset with all supported data types.

    Column names encode their polars dtype (e.g. "pl.Int32"), except the
    all-null "NA" column, so the schema can be rebuilt from the CSV header.
    """

    @property
    def name(self):
        """Identifier for this dataset."""
        return "test"

    @property
    def schema(self):
        # Reconstruct each column's dtype from its name: drop the "pl."
        # prefix and look the dtype up on the polars module. The "NA" column
        # carries no dtype in its name and is read back as String.
        columns = pl.read_csv(self.file_location).columns
        return {col_name: (getattr(pl, col_name[3:]) if col_name != "NA" else pl.String)
                for col_name in columns}

    @classmethod
    def create(cls, csv_file):
        """Generate the test CSV with one column per supported polars dtype.

        Arguments
        ---------
        csv_file:
            Destination path for the generated CSV file.
        """
        all_series = []
        n_rows = 100

        # Signed and unsigned integer columns at every bit width.
        for int_val in [8, 16, 32, 64]:
            all_series.append(pl.Series(f"pl.Int{int_val}",
                                        [np.random.randint(-10, 10) for _ in range(n_rows)],
                                        dtype=getattr(pl, f"Int{int_val}")))
            all_series.append(pl.Series(f"pl.UInt{int_val}",
                                        [np.random.randint(10) for _ in range(n_rows)],
                                        dtype=getattr(pl, f"UInt{int_val}")))

        # Float columns at both precisions.
        for float_val in [32, 64]:
            all_series.append(pl.Series(f"pl.Float{float_val}",
                                        np.random.randn(n_rows),
                                        dtype=getattr(pl, f"Float{float_val}")))

        # Temporal columns: consecutive dates, drifting datetimes, and times
        # whose components advance slowly with the row index.
        all_series.append(pl.Series("pl.Date", [date(2024, 9, 4) + timedelta(days=i)
                                                for i in range(n_rows)],
                                    dtype=pl.Date))
        all_series.append(pl.Series("pl.Datetime",
                                    [datetime(2024, 9, 4, 12, 30, 12)
                                     + timedelta(hours=i, minutes=i*2, seconds=i*3)
                                     for i in range(n_rows)],
                                    dtype=pl.Datetime))
        all_series.append(pl.Series("pl.Time",
                                    [time(3+i//20, 6+i//12, 12+i//35) for i in range(n_rows)],
                                    dtype=pl.Time))
        # String columns (pl.Utf8 is an alias of pl.String; both are kept so
        # the alias round-trips through the schema lookup by name).
        all_series.append(pl.Series("pl.String",
                                    np.random.choice(list(string.printable), size=n_rows),
                                    dtype=pl.String))
        all_series.append(pl.Series("pl.Utf8",
                                    np.random.choice(list(string.printable), size=n_rows),
                                    dtype=pl.Utf8))
        all_series.append(pl.Series("pl.Categorical",
                                    np.random.choice(list(string.ascii_uppercase[:5]), size=n_rows),
                                    dtype=pl.Categorical))
        all_series.append(pl.Series("pl.Boolean",
                                    np.random.choice([True, False], size=n_rows),
                                    dtype=pl.Boolean))
        # Fully null column to exercise the all-missing code path.
        all_series.append(pl.Series("NA", [None for _ in range(n_rows)], dtype=pl.String))

        # Add NA's for all series except the categorical
        for series in all_series:
            if series.name != "pl.Categorical":
                # Null out a random 10% of rows, in place.
                none_idx = np.random.choice(np.arange(n_rows), size=n_rows//10, replace=False)
                none_idx.sort()
                series[none_idx] = None

        # Write to a csv file
        pl.DataFrame(all_series).write_csv(csv_file)


def _get_demo_class(name):
    """Return the registered demo dataset instance for *name*.

    Raises ValueError when no dataset was registered under that name.
    """
    try:
        return _AVAILABLE_DATASETS[name]
    except KeyError:
        raise ValueError(
            f"No demonstration dataset with name '{name}'. Options: {list(_AVAILABLE_DATASETS)}."
        ) from None


def demo_file(name: str = "titanic") -> Path:
Expand All @@ -92,18 +216,7 @@ def demo_file(name: str = "titanic") -> Path:
file, edition 1.0 [Data set]. Sikt - Norwegian Agency for Shared Services in Education and
Research. https://doi.org/10.21338/ess11e01_0
"""
if name == "titanic":
return files(__package__) / "demo_titanic.csv"
if name == "spaceship":
return files(__package__) / "demo_spaceship.csv"
if name == "fruit":
return files(__package__) / "demo_fruit.csv"
if name == "survey":
return files(__package__) / "demo_survey.csv"

raise ValueError(
f"No demonstration dataset with name '{name}'. Options: titanic, spaceship, fruit, survey."
)
return _get_demo_class(name).file_location


def demo_dataframe(name: str = "titanic") -> pl.DataFrame:
Expand All @@ -130,28 +243,4 @@ def demo_dataframe(name: str = "titanic") -> pl.DataFrame:
file, edition 1.0 [Data set]. Sikt - Norwegian Agency for Shared Services in Education and
Research. https://doi.org/10.21338/ess11e01_0
"""
file_path = demo_file(name=name)
if name == "spaceship":
# the Kaggle spaceship data (CC-BY)
data_types = {
"HomePlanet": pl.Categorical,
"CryoSleep": pl.Categorical,
"VIP": pl.Categorical,
"Destination": pl.Categorical,
"Transported": pl.Categorical,
}
return pl.read_csv(file_path, schema_overrides=data_types, try_parse_dates=True)
if name == "titanic":
# our edited titanic data
data_types = {"Sex": pl.Categorical, "Embarked": pl.Categorical}
return pl.read_csv(file_path, schema_overrides=data_types, try_parse_dates=True)
if name == "fruit":
# basic fruit data from polars example
data_types = {"fruits": pl.Categorical, "cars": pl.Categorical}
return pl.read_csv(file_path, schema_overrides=data_types)
if name == "survey":
return pl.read_csv(file_path)

raise ValueError(
f"No demonstration dataset with name '{name}'. Options: titanic, spaceship, fruit."
)
return _get_demo_class(name).get_dataframe()
Loading

0 comments on commit 93f7ab0

Please sign in to comment.