Skip to content

Commit

Permalink
Merge pull request #44 from tma15/feature/load_data_from_hf
Browse files Browse the repository at this point in the history
Feature/load data from hf
  • Loading branch information
tma15 authored Feb 23, 2024
2 parents e157b71 + 6c8f2bd commit 154e4f7
Show file tree
Hide file tree
Showing 23 changed files with 201 additions and 173 deletions.
46 changes: 28 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ Example of `sklearn.svm.SVC`

```yaml
data:
train: train.jsonl
dev: dev.jsonl
test: test.jsonl
label_column: category
text_column: title
args:
path: data/jsonl

output_dir: models/svm-model

Expand Down Expand Up @@ -53,9 +54,10 @@ Example of BERT
```yaml
data:
train: train.jsonl
dev: dev.jsonl
test: test.jsonl
label_column: category
text_column: title
args:
path: data/jsonl

output_dir: models/transformer-model

Expand Down Expand Up @@ -96,36 +98,44 @@ You can set data-related settings in `data`.

```yaml
data:
train: train.jsonl # training data
dev: dev.jsonl # development data
test: test.jsonl # test data
label_column: label
text_column: text
label_column: category
text_column: title
args:
# Use local data in `data/jsonl`. This path is assumed to contain data files such as train.jsonl, validation.jsonl and test.jsonl
path: data/jsonl

# If you want to use data on Hugging Face Hub, use the following args instead.
# Data is from https://huggingface.co/datasets/shunk031/livedoor-news-corpus
# path: shunk031/livedoor-news-corpus
# random_state: 0
# shuffle: true

```

You can set local files in `train`, `dev`, and `test`.
Supported types are `csv`, `json` and `jsonl`.
Data is loaded via [datasets.load_dataset](https://huggingface.co/docs/datasets/main/en/package_reference/loading_methods#datasets.load_dataset).
So, you can load local data as well as data on [Hugging Face Hub](https://huggingface.co/datasets).
When loading data, `args` are passed to `load_dataset`.

`label_column` and `text_column` are field names of label and text.
For example, when you set `label_column` to `category` and `text_column` to `text`, actual data must be as follows:

Format of `csv`:

```csv
label,text
label_name,sentence
category,text
sports,I like sports!
```

Format of `json`:

```json
[{"label", "label_name", "text": "sentence"}]
[{"category", "sports", "text": "I like sports!"}]
```

Format of `jsonl`:

```json
{"label", "label_name", "text": "sentence"}
{"category", "sports", "text": "I like suports!"}
```

### pipeline
Expand Down
54 changes: 25 additions & 29 deletions bunruija/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path
from collections import UserDict

import datasets
from datasets import (
Dataset,
DatasetDict,
Expand All @@ -10,43 +11,38 @@


def load_data(
data_path: str | Path,
dataset_args: UserDict,
split: datasets.Split,
label_column: str = "label",
text_column: str | list[str] = "text",
) -> tuple[list[str], list[str] | list[list[str]]]:
if isinstance(data_path, str):
data_path = Path(data_path)
dataset: Dataset | DatasetDict | IterableDataset | IterableDatasetDict = (
load_dataset(split=split, **dataset_args)
)
assert isinstance(dataset, Dataset)

labels: list[str] = []
texts: list[str] | list[list[str]]
texts = [] # type: ignore

if data_path.suffix in [".csv", ".json", ".jsonl"]:
suffix: str = data_path.suffix[1:]
for idx, sample in enumerate(dataset):
label: str

# Because datasets does not support jsonl suffix, convert it to json
if suffix == "jsonl":
suffix = "json"
# If feature of label has names attribute, convert label to actual label strings
if hasattr(dataset.features[label_column], "names"):
label = dataset.features[label_column].names[sample[label_column]]
else:
label = sample[label_column]

# When data_files is only a single data_path, data split is "train"
dataset: DatasetDict | Dataset | IterableDataset | IterableDatasetDict = (
load_dataset(suffix, data_files=str(data_path), split="train")
)
assert isinstance(dataset, Dataset)
labels.append(label)

for idx, sample in enumerate(dataset):
labels.append(sample[label_column])
if isinstance(text_column, str):
input_example = sample[text_column]
texts.append(input_example)
elif isinstance(text_column, list):
if len(text_column) != 2:
raise ValueError(f"{len(text_column)=}")

if isinstance(text_column, str):
input_example = sample[text_column]
texts.append(input_example)
elif isinstance(text_column, list):
if len(text_column) != 2:
raise ValueError(f"{len(text_column)=}")

input_example = [sample[text_column[0]], sample[text_column[1]]]
texts.append(input_example)
return labels, texts

else:
raise ValueError(data_path.suffix)
input_example = [sample[text_column[0]], sample[text_column[1]]]
texts.append(input_example)
return labels, texts
20 changes: 9 additions & 11 deletions bunruija/dataclass.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import UserDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
Expand All @@ -13,32 +14,29 @@ class PipelineUnit:

@dataclass
class DataConfig:
train: Path = field(default_factory=Path)
dev: Path = field(default_factory=Path)
test: Path = field(default_factory=Path)
label_column: str = "label"
text_column: str = "text"

def __post_init__(self):
self.train = Path(self.train)
self.dev = Path(self.dev)
self.test = Path(self.test)
text_column: str | list[str] = "text"


@dataclass
class BunruijaConfig:
data: DataConfig
pipeline: list[PipelineUnit]
output_dir: Path
data: DataConfig | None = None
dataset_args: UserDict | None = None

@classmethod
def from_yaml(cls, config_file):
with open(config_file) as f:
yaml = ruamel.yaml.YAML()
config = yaml.load(f)

label_column: str = config["data"].pop("label_column", "label")
text_column: str | list[str] = config["data"].pop("text_column", "text")

return cls(
data=DataConfig(**config["data"]),
data=DataConfig(label_column=label_column, text_column=text_column),
pipeline=[PipelineUnit(**unit) for unit in config["pipeline"]],
output_dir=Path(config.get("output_dir", "output")),
dataset_args=UserDict(config["data"]["args"]),
)
5 changes: 4 additions & 1 deletion bunruija/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from argparse import Namespace

import datasets
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

Expand All @@ -18,10 +19,12 @@ def __init__(self, args: Namespace):

def evaluate(self):
labels_test, X_test = load_data(
self.config.data.test,
self.config.dataset_args,
split=datasets.Split.TEST,
label_column=self.config.data.label_column,
text_column=self.config.data.text_column,
)

y_test: np.ndarray = self.predictor.label_encoder.transform(labels_test)
y_pred: np.ndarray = self.predictor(X_test)

Expand Down
8 changes: 5 additions & 3 deletions bunruija/gen_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,11 @@ def main(args):

setting = {
"data": {
"train": "train.csv",
"dev": "dev.csv",
"test": "test.csv",
"label_column": "label",
"text_coliumn": "text",
"args": {
"path": "",
},
},
"pipeline": [
infer_vectorizer(model_cls),
Expand Down
34 changes: 18 additions & 16 deletions bunruija/trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datasets
import numpy as np
import sklearn # type: ignore
from sklearn.preprocessing import LabelEncoder # type: ignore
Expand All @@ -17,7 +18,8 @@ def __init__(self, config_file: str):

def train(self):
labels_train, X_train = load_data(
self.config.data.train,
self.config.dataset_args,
split=datasets.Split.TRAIN,
label_column=self.config.data.label_column,
text_column=self.config.data.text_column,
)
Expand All @@ -29,21 +31,21 @@ def train(self):

self.saver(self.model, label_encoder)

if self.config.data.dev.exists():
labels_dev, X_dev = load_data(
self.config.data.dev,
label_column=self.config.data.label_column,
text_column=self.config.data.text_column,
)
labels_dev, X_dev = load_data(
self.config.dataset_args,
split=datasets.Split.VALIDATION,
label_column=self.config.data.label_column,
text_column=self.config.data.text_column,
)

y_dev: np.ndarray = label_encoder.transform(labels_dev)
y_dev: np.ndarray = label_encoder.transform(labels_dev)

y_pred = self.model.predict(X_dev)
y_pred = self.model.predict(X_dev)

fscore = sklearn.metrics.f1_score(y_dev, y_pred, average="micro")
print(f"F-score on dev: {fscore}")
target_names: list[str] = list(label_encoder.classes_)
report = sklearn.metrics.classification_report(
y_pred, y_dev, target_names=target_names
)
print(report)
fscore = sklearn.metrics.f1_score(y_dev, y_pred, average="micro")
print(f"F-score on dev: {fscore}")
target_names: list[str] = list(label_encoder.classes_)
report = sklearn.metrics.classification_report(
y_pred, y_dev, target_names=target_names
)
print(report)
8 changes: 5 additions & 3 deletions example/jglue/jcola/create_jcola_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def write_json(ds: Dataset, name: Path):
for sample in ds:
category: str = ds.features["label"].names[sample["label"]]
sample_ = {
"text": sample["sentence"],
"sentence": sample["sentence"],
"label": category,
}
print(json.dumps(sample_), file=f)
Expand All @@ -20,7 +20,9 @@ def write_json(ds: Dataset, name: Path):

def main():
parser = ArgumentParser()
parser.add_argument("--output_dir", default="example/jglue/jcola/data", type=Path)
parser.add_argument(
"--output_dir", default="example/jglue/jcola/data/jsonl", type=Path
)
args = parser.parse_args()

if not args.output_dir.exists():
Expand All @@ -29,7 +31,7 @@ def main():
dataset = load_dataset("shunk031/JGLUE", name="JCoLA")

write_json(dataset["train"], args.output_dir / "train.jsonl")
write_json(dataset["validation"], args.output_dir / "dev.jsonl")
write_json(dataset["validation"], args.output_dir / "validation.jsonl")


if __name__ == "__main__":
Expand Down
10 changes: 6 additions & 4 deletions example/jglue/jnli/create_jnli_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def write_json(ds: Dataset, name: Path):
for sample in ds:
category: str = ds.features["label"].names[sample["label"]]
sample_ = {
"text1": sample["sentence1"],
"text2": sample["sentence2"],
"sentence1": sample["sentence1"],
"sentence2": sample["sentence2"],
"label": category,
}
print(json.dumps(sample_), file=f)
Expand All @@ -21,7 +21,9 @@ def write_json(ds: Dataset, name: Path):

def main():
parser = ArgumentParser()
parser.add_argument("--output_dir", default="example/jglue/jnli/data", type=Path)
parser.add_argument(
"--output_dir", default="example/jglue/jnli/data/jsonl", type=Path
)
args = parser.parse_args()

if not args.output_dir.exists():
Expand All @@ -31,7 +33,7 @@ def main():
print(dataset)

write_json(dataset["train"], args.output_dir / "train.jsonl")
write_json(dataset["validation"], args.output_dir / "dev.jsonl")
write_json(dataset["validation"], args.output_dir / "validation.jsonl")


if __name__ == "__main__":
Expand Down
8 changes: 5 additions & 3 deletions example/jglue/marc_ja/create_marc_ja_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def write_json(ds: Dataset, name: Path):
for sample in ds:
category: str = ds.features["label"].names[sample["label"]]
sample_ = {
"text": sample["sentence"],
"sentence": sample["sentence"],
"label": category,
}
print(json.dumps(sample_), file=f)
Expand All @@ -20,7 +20,9 @@ def write_json(ds: Dataset, name: Path):

def main():
parser = ArgumentParser()
parser.add_argument("--output_dir", default="example/jglue/jcola/data", type=Path)
parser.add_argument(
"--output_dir", default="example/jglue/marc_ja/data/jsonl", type=Path
)
args = parser.parse_args()

if not args.output_dir.exists():
Expand All @@ -29,7 +31,7 @@ def main():
dataset = load_dataset("shunk031/JGLUE", name="MARC-ja")

write_json(dataset["train"], args.output_dir / "train.jsonl")
write_json(dataset["validation"], args.output_dir / "dev.jsonl")
write_json(dataset["validation"], args.output_dir / "validation.jsonl")


if __name__ == "__main__":
Expand Down
8 changes: 6 additions & 2 deletions example/jglue/settings/classification/svm.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
data:
train: data/train.jsonl
dev: data/dev.jsonl
label_column: label
text_column: sentence
args:
path: shunk031/JGLUE
name: JCoLA
# path: data/jsonl

output_dir: models/svm-model

Expand Down
Loading

0 comments on commit 154e4f7

Please sign in to comment.