Skip to content

Commit

Permalink
fix(tests): fix tests and dataprep module (#182)
Browse files Browse the repository at this point in the history
* fix(tests): fix tests and dataprep module

* fix(tests): temporarily remove test run on Windows OS
  • Loading branch information
KevKibe authored Sep 10, 2024
1 parent 31ce34f commit b34b78d
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 61 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/deployment.speech_inference_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, macOS-latest, windows-latest]
os: [ubuntu-latest, macOS-latest]

runs-on: ${{ matrix.os }}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
name: Test deployment.peft_speech_inference Module.
name: Test training.model_trainer Module.

on: [pull_request]

jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, macOS-latest, windows-latest]
os: [ubuntu-latest, macOS-latest]

runs-on: ${{ matrix.os }}

Expand Down Expand Up @@ -45,4 +45,4 @@ jobs:
HF_READ_TOKEN: ${{ secrets.HF_READ_TOKEN }}
HF_WRITE_TOKEN: ${{ secrets.HF_WRITE_TOKEN }}
WANDB_TOKEN: ${{ secrets.WANDB_TOKEN }}
run: pytest src/tests/test_peft_speech_inference.py
run: pytest src/tests/test_model_prep.py
2 changes: 1 addition & 1 deletion .github/workflows/training.model_trainer_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, macOS-latest, windows-latest]
os: [ubuntu-latest, macOS-latest]

runs-on: ${{ matrix.os }}

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/training_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, macOS-latest, windows-latest]
os: [ubuntu-latest, macOS-latest]

runs-on: ${{ matrix.os }}

Expand Down
7 changes: 6 additions & 1 deletion src/tests/test_audio_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ def setUp(self):
dataset_name="mozilla-foundation/common_voice_16_1",
language_abbr=["yi", "ti"]
)
self.dataset = self.data_loader.load_dataset()
self.dataset = self.data_loader.load_dataset(train_num_samples=10, test_num_samples=10)
has_train_sample = any(True for _ in self.dataset["train"])
assert has_train_sample, "Train dataset is empty!"

has_test_sample = any(True for _ in self.dataset["test"])
assert has_test_sample, "Test dataset is empty!"

# Initialize model preparation
self.model_prep = WhisperModelPrep(
Expand Down
5 changes: 5 additions & 0 deletions src/tests/test_data_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ def test_load_dataset(self):
"""Test the load_dataset method."""
tokenizer, feature_extractor, processor, model = self.data_prep.prepare_model()
dataset = self.data_prep.load_dataset(feature_extractor, tokenizer, processor, train_num_samples = 10, test_num_samples=10)
has_train_sample = any(True for _ in dataset["train"])
assert has_train_sample, "Train dataset is empty!"

has_test_sample = any(True for _ in dataset["test"])
assert has_test_sample, "Test dataset is empty!"
self.assertIsInstance(dataset, dict)
self.assertIsInstance(dataset["train"], IterableDataset)
self.assertIsInstance(dataset["test"], IterableDataset)
Expand Down
5 changes: 4 additions & 1 deletion src/tests/test_load_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ def test_load_dataset(self):
"""Test loading the dataset and verifying its contents."""
# Act
data = self.dataset_manager.load_dataset(train_num_samples=10, test_num_samples = 10)

has_train_sample = any(True for _ in data["train"])
assert has_train_sample, "Train dataset is empty!"
has_test_sample = any(True for _ in data["test"])
assert has_test_sample, "Test dataset is empty!"
# Assert
self.assertIsNotNone(data, "The loaded dataset should not be None.")
self.assertIn("train", data, "The dataset should contain a 'train' split.")
Expand Down
16 changes: 14 additions & 2 deletions src/tests/test_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@ def setUp(self) -> None:
use_peft=False,
)
tokenizer, feature_extractor, feature_processor, model = process.prepare_model()
dataset = process.load_dataset(feature_extractor, tokenizer, feature_processor)
dataset = process.load_dataset(feature_extractor, tokenizer, feature_processor, train_num_samples=10, test_num_samples=10)

has_train_sample = any(True for _ in dataset["train"])
assert has_train_sample, "Train dataset is empty!"

has_test_sample = any(True for _ in dataset["test"])
assert has_test_sample, "Test dataset is empty!"

self.trainer = Trainer(
huggingface_write_token= os.environ.get("HF_WRITE_TOKEN"),
model_id=self.model_id,
Expand All @@ -34,6 +41,11 @@ def setUp(self) -> None:
return super().setUp()

def test_train(self):
# print(self.trainer.dataset['train'])
# data_loader = self.trainer.get_train_dataloader()
# for batch in data_loader:
# print(batch)
# assert batch is not None, "Empty batch found!"
self.trainer.train(
max_steps = 10,
learning_rate = 1e-5,
Expand All @@ -48,7 +60,7 @@ def test_train(self):
)
assert os.path.exists(f"../{self.model_id}-finetuned/preprocessor_config.json")
assert os.path.exists(f"../{self.model_id}-finetuned/tokenizer_config.json")


if __name__ == '__main__':
unittest.main()
49 changes: 0 additions & 49 deletions src/tests/test_peft_speech_inference.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/training/data_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
Initializes the Trainer with the necessary configuration and loads the evaluation metric.
Parameters:
huggingface_token (str): Hugging Face API token for authenticated access.
huggingface_read_token (str): Hugging Face API token for authenticated access.
dataset_name (str): Name of the dataset to be downloaded from Hugging Face.
language_abbr (str): Language abbreviation for the dataset.
model_id (str): Model ID for the model to be used in training.
Expand All @@ -56,8 +56,8 @@ def __init__(
def prepare_model(
self,
) -> Tuple[
WhisperFeatureExtractor,
WhisperTokenizer,
WhisperFeatureExtractor,
WhisperProcessor,
WhisperForConditionalGeneration,
]:
Expand Down

0 comments on commit b34b78d

Please sign in to comment.