Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update: language_abbr parameter to a string #220

Merged
merged 2 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DOCS/gettingstarted.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
huggingface_token = " " # make sure token has write permissions
dataset_name = "mozilla-foundation/common_voice_16_1" # Also supports "google/fleurs" and "facebook/multilingual_librispeech".
# For custom datasets, ensure the text key is one of the following: "sentence", "transcript", or "transcription".
language_abbr= [ ] # Example `["af"]`. see specific dataset for language code.
language_abbr= " " # Example `"af"`. See the specific dataset for the language code.
model_id= "model-id" # Example openai/whisper-small, openai/whisper-medium
processing_task= "translate" # translate or transcribe
wandb_api_key = " "
Expand Down
4 changes: 2 additions & 2 deletions src/training/data_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .load_data import Dataset
from .whisper_model_prep import WhisperModelPrep
from .audio_data_processor import AudioDataProcessor
from typing import Tuple, List
from typing import Tuple
import warnings
warnings.filterwarnings("ignore")

Expand All @@ -39,7 +39,7 @@ def __init__(
self,
huggingface_token: str,
dataset_name: str,
language_abbr: List[str],
language_abbr: str,
model_id: str,
processing_task: str,
use_peft: bool,
Expand Down
76 changes: 40 additions & 36 deletions src/training/load_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from datasets import load_dataset, IterableDatasetDict, concatenate_datasets
from datasets import load_dataset, IterableDatasetDict
import warnings
from typing import List
from datasets import DatasetDict
from huggingface_hub import HfFolder
warnings.filterwarnings("ignore")
Expand All @@ -15,7 +14,7 @@ class Dataset:
language_abbr (str): Abbreviation of the language for the dataset.
"""

def __init__(self, huggingface_token: str, dataset_name: str, language_abbr: List[str]):
def __init__(self, huggingface_token: str, dataset_name: str, language_abbr: str):
"""
Initializes the DatasetManager with necessary details for dataset operations.

Expand Down Expand Up @@ -46,39 +45,44 @@ def load_dataset(self, streaming: bool = True, train_num_samples: int = None, te
dict: A dictionary containing the train and test splits for the configured language.
"""
data = {}
for lang in self.language_abbr:
train_dataset = load_dataset(self.dataset_name,
lang,
split='train',
streaming=streaming,
token=self.huggingface_token,
trust_remote_code=True)
test_dataset = load_dataset(self.dataset_name,
lang,
split='test',
streaming=streaming,
token=self.huggingface_token,
trust_remote_code=True)
if streaming:
train_split = train_dataset.take(train_num_samples) if train_num_samples else train_dataset
test_split = test_dataset.take(test_num_samples) if test_num_samples else test_dataset

else:

train_split = train_dataset if not train_num_samples or len(train_dataset) < train_num_samples else \
train_dataset.select(range(train_num_samples))

test_split = test_dataset if not test_num_samples or len(test_dataset) < test_num_samples else \
test_dataset.select(range(test_num_samples))

if "train" in data:
data["train"] = concatenate_datasets([data["train"], train_split])
else:
data["train"] = train_split
if "test" in data:
data["test"] = concatenate_datasets([data["test"], test_split])
else:
data["test"] = test_split

train_dataset = load_dataset(
self.dataset_name,
self.language_abbr,
split="train",
streaming=streaming,
token=self.huggingface_token,
trust_remote_code=True,
)
test_dataset = load_dataset(
self.dataset_name,
self.language_abbr,
split="test",
streaming=streaming,
token=self.huggingface_token,
trust_remote_code=True,
)

if streaming:
train_split = train_dataset.take(train_num_samples) if train_num_samples else train_dataset
else:
train_split = (
train_dataset
if not train_num_samples or len(train_dataset) < train_num_samples
else train_dataset.select(range(train_num_samples))
)

if streaming:
test_split = test_dataset.take(test_num_samples) if test_num_samples else test_dataset
else:
test_split = (
test_dataset
if not test_num_samples or len(test_dataset) < test_num_samples
else test_dataset.select(range(test_num_samples))
)

data["train"] = train_split
data["test"] = test_split

return data

Expand Down
3 changes: 1 addition & 2 deletions src/training/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ def parse_args():
)
parser.add_argument(
"--language_abbr",
nargs='+',
required=True,
help="Abbreviation(s) of the language(s) for the dataset.",
help="Abbreviation of the language for the dataset.",
)
parser.add_argument(
"--model_id",
Expand Down
4 changes: 2 additions & 2 deletions src/training/whisper_model_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class WhisperModelPrep:
def __init__(
self,
model_id: str,
language: list,
language: str,
processing_task: str,
use_peft: bool,
):
Expand All @@ -45,7 +45,7 @@ def __init__(
"""

self.model_id = model_id
self.language = language[0]
self.language = language
self.processing_task = processing_task
self.use_peft = use_peft

Expand Down
4 changes: 2 additions & 2 deletions tests/test_audio_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def setUp(self):
self.data_loader = Dataset(
huggingface_token = os.environ.get("HF_TOKEN"),
dataset_name="mozilla-foundation/common_voice_16_1",
language_abbr=["af"]
language_abbr="af"
)
self.dataset_streaming = self.data_loader.load_dataset(streaming=True, train_num_samples=10, test_num_samples=10)
self.dataset_batch = self.data_loader.load_dataset(streaming=False, train_num_samples=10, test_num_samples=10)
Expand All @@ -38,7 +38,7 @@ def setUp(self):

# Initialize model preparation
self.model_prep = WhisperModelPrep(
language = ["af"],
language = "af",
model_id="openai/whisper-tiny",
processing_task="transcribe",
use_peft=False
Expand Down
4 changes: 2 additions & 2 deletions tests/test_data_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ def setUp(self):
self.data_prep = DataPrep(
huggingface_token= os.environ.get("HF_TOKEN"),
dataset_name="mozilla-foundation/common_voice_16_1",
language_abbr=["af"],
language_abbr="af",
model_id="openai/whisper-small",
processing_task="transcribe",
use_peft=False,
)
self.model_prep=WhisperModelPrep(
language= ["af"],
language= "af",
model_id="openai/whisper-small",
processing_task="transcribe",
use_peft=False),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_load_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def setUp(self):
self.dataset_manager = Dataset(
huggingface_token=os.environ.get("HF_TOKEN"),
dataset_name="mozilla-foundation/common_voice_16_1",
language_abbr=["af"]
language_abbr="af"
)

def test_load_dataset(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_model_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class TestDatasetManager(unittest.TestCase):
def setUp(self):
"""Initialize the test setup with an instance of WhisperModelPrep."""
self.model_prep = WhisperModelPrep(
language=["af"],
language="af",
model_id="openai/whisper-small",
processing_task="transcribe",
use_peft=False
Expand Down
2 changes: 1 addition & 1 deletion tests/test_model_trainer_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def setUp(self) -> None:
process = DataPrep(
huggingface_token=os.environ.get("HF_TOKEN"),
dataset_name="mozilla-foundation/common_voice_16_1",
language_abbr=["af"],
language_abbr="af",
model_id=self.model_id,
processing_task="transcribe",
use_peft=False,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_model_trainer_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def setUp(self) -> None:
process = DataPrep(
huggingface_token=os.environ.get("HF_TOKEN"),
dataset_name="mozilla-foundation/common_voice_16_1",
language_abbr=["af"],
language_abbr="af",
model_id=self.model_id,
processing_task="transcribe",
use_peft=True,
Expand Down
Loading