-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add unit tests for all modules #138
Merged
Changes from 37 commits
Commits
Show all changes
51 commits
Select commit
Hold shift + click to select a range
27a921a
test: Add unit tests for LighterSystem functionality and behavior
ibro45 e3a784f
test: Add comprehensive unit tests for all modules in the lighter pac…
ibro45 e272ce8
fix: Import torch in test files to resolve undefined name errors
ibro45 2eeab00
fix: Resolve optimizer empty parameter list and update test assertions
ibro45 ab4e153
fix: Import Path from pathlib to resolve undefined name errors in tests
ibro45 9459cdf
fix: Ensure layers are frozen correctly in LighterFreezer test and fi…
ibro45 59d7df6
test: Remove unnecessary frozen state assignment in freezer test
ibro45 e751458
test: Include dummy Trainer in freezer and system tests for validation
ibro45 8958764
fix: Import missing classes in test_freezer and test_system files
ibro45 85ba211
fix: Import missing Dataset, DataLoader, and Accuracy in test files
ibro45 ce26668
chore: Update imports in test files for consistency and clarity
ibro45 e236d47
fix: Resolve TypeError in trainer setup and import DataLoader correctly
ibro45 2abb159
feat: Wrap dummy models into a minimal PyTorch Lightning system
ibro45 214e7b1
fix: Correct test case for replacing layer with identity in unit tests
ibro45 c9a89fc
chore: Import Module from torch.nn in test_freezer.py
ibro45 4fafbe9
test: Fix NameError and improve tests for LighterFreezer functionality
ibro45 431a18f
refactor: Simplify dataset return structure in test_freezer.py
ibro45 5097be7
refactor: Replace DummySystem with LighterSystem and add model tests
ibro45 5f53545
feat: Update DummyModel architecture and enhance dummy_system setup
ibro45 dc50ff4
test: Ensure optimizer is correctly set up in freezer exception test
ibro45 1f2bd41
test: Update assertion to check for grad_fn in freezer test case
ibro45 0da59a7
test: Remove redundant optimizer setup assertions from test_freezer.py
ibro45 469af2b
fix: Correct freezing logic in freezer tests to match expected behavior
ibro45 af6363e
Fixes, reorganize
ibro45 dd4cb7c
test: Remove skip markers from training and validation step tests
ibro45 93fe1d4
fix: Attach DummySystem to Trainer and resolve batch size mismatches
ibro45 92fca36
fix: Ensure DummySystem is attached to Trainer in test_valid_batch_fo…
ibro45 db006e6
test: Refactor training, validation, and prediction steps in tests
ibro45 6157ebb
feat: Add predict dataset to DummySystem and update corresponding tests
ibro45 ad01f6a
test: Update unit tests to use training_step with mock logging
ibro45 946da7e
test: Add unit tests for empty datasets, model forward pass, and syst…
ibro45 7070c1c
Fix tests
ibro45 b2333b8
Update tests/unit/test_callbacks_writer_file.py
ibro45 776a531
Update tests/unit/test_callbacks_writer_table.py
ibro45 8ae2af0
Update tests/unit/test_callbacks_writer_file.py
ibro45 f42a089
Update tests/unit/test_utils_collate.py
ibro45 98af60b
Update tests/unit/test_callbacks_writer_base.py
ibro45 85913cd
test: Fix assertion in table writer test for correct key usage
ibro45 6bf6f76
test: Add comprehensive tests for LighterTableWriter functionality
ibro45 e581031
test: Replace mocked Trainer with a real Trainer instance in tests
ibro45 fa87780
test: Update Trainer instantiation in table writer tests for simplicity
ibro45 48782a7
test: Update test_table_writer_write to validate all CSV entries with…
ibro45 4ff7260
test: Update unit test for LighterTableWriter with expected records
ibro45 9eced0f
test: Fix DataFrame ID type comparison in table writer tests
ibro45 6a7281c
test: Update test cases for LighterTableWriter initialization and val…
ibro45 0f2d396
fix: Mock world_size and is_global_zero methods in tests
ibro45 74567db
test: Update test directory name and improve error handling in tests
ibro45 51d4557
test: Fix assertion and error handling in file writer tests
ibro45 e9e97aa
test: Fix test assertions and mock properties in writer tests
ibro45 2047058
fix: Mock world_size and is_global_zero methods in tests
ibro45 7b7836b
Improve tests provided by coderabbit
ibro45 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -152,3 +152,5 @@ projects/* | |
**/predictions/ | ||
*/.DS_Store | ||
.DS_Store | ||
.aider* | ||
test_dir/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import pytest | ||
import torch | ||
from pytorch_lightning import Trainer | ||
from torch.nn import Module | ||
from torch.utils.data import Dataset | ||
|
||
from lighter.callbacks.freezer import LighterFreezer | ||
from lighter.system import LighterSystem | ||
|
||
|
||
class DummyModel(Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.layer1 = torch.nn.Linear(10, 10) | ||
self.layer2 = torch.nn.Linear(10, 4) | ||
self.layer3 = torch.nn.Linear(4, 1) | ||
|
||
def forward(self, x): | ||
x = self.layer1(x) | ||
x = self.layer2(x) | ||
x = self.layer3(x) | ||
return x | ||
|
||
|
||
class DummyDataset(Dataset): | ||
def __len__(self): | ||
return 10 | ||
|
||
def __getitem__(self, idx): | ||
return {"input": torch.randn(10), "target": torch.tensor(0)} | ||
|
||
|
||
@pytest.fixture | ||
def dummy_system(): | ||
model = DummyModel() | ||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) | ||
dataset = DummyDataset() | ||
criterion = torch.nn.CrossEntropyLoss() | ||
return LighterSystem(model=model, batch_size=32, criterion=criterion, optimizer=optimizer, datasets={"train": dataset}) | ||
|
||
|
||
def test_freezer_initialization(): | ||
freezer = LighterFreezer(names=["layer1"]) | ||
assert freezer.names == ["layer1"] | ||
|
||
|
||
def test_freezer_functionality(dummy_system): | ||
freezer = LighterFreezer(names=["layer1.weight", "layer1.bias"]) | ||
trainer = Trainer(callbacks=[freezer], max_epochs=1) | ||
trainer.fit(dummy_system) | ||
assert not dummy_system.model.layer1.weight.requires_grad | ||
assert not dummy_system.model.layer1.bias.requires_grad | ||
assert dummy_system.model.layer2.weight.requires_grad | ||
|
||
|
||
def test_freezer_with_exceptions(dummy_system): | ||
freezer = LighterFreezer(name_starts_with=["layer"], except_names=["layer2.weight", "layer2.bias"]) | ||
trainer = Trainer(callbacks=[freezer], max_epochs=1) | ||
trainer.fit(dummy_system) | ||
assert not dummy_system.model.layer1.weight.requires_grad | ||
assert not dummy_system.model.layer1.bias.requires_grad | ||
assert dummy_system.model.layer2.weight.requires_grad | ||
assert dummy_system.model.layer2.bias.requires_grad | ||
assert not dummy_system.model.layer3.weight.requires_grad | ||
assert not dummy_system.model.layer3.bias.requires_grad |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import torch | ||
|
||
from lighter.callbacks.utils import preprocess_image | ||
|
||
|
||
def test_preprocess_image_2d(): | ||
image = torch.rand(1, 3, 64, 64) # Batch of 2D images | ||
processed_image = preprocess_image(image) | ||
assert processed_image.shape == (3, 64, 64) | ||
|
||
|
||
def test_preprocess_image_3d(): | ||
batch_size = 8 | ||
depth = 20 | ||
height = 64 | ||
width = 64 | ||
image = torch.rand(batch_size, 1, depth, height, width) # Batch of 3D images | ||
processed_image = preprocess_image(image) | ||
assert processed_image.shape == (1, depth * height, batch_size * width) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import pytest | ||
|
||
from lighter.callbacks.writer.base import LighterBaseWriter | ||
|
||
|
||
def test_writer_initialization(): | ||
with pytest.raises(TypeError): | ||
LighterBaseWriter(path="test", writer="tensor") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from pathlib import Path | ||
|
||
import torch | ||
|
||
from lighter.callbacks.writer.file import LighterFileWriter | ||
|
||
|
||
import shutil | ||
|
||
def test_file_writer_initialization(): | ||
"""Test LighterFileWriter initialization with proper attributes.""" | ||
path = Path("test_dir") | ||
path.mkdir(exist_ok=True) # Ensure the directory exists | ||
try: | ||
writer = LighterFileWriter(path=path, writer="tensor") | ||
assert writer.path == Path("test_dir") | ||
assert writer.writer == "tensor" # Verify writer type | ||
finally: | ||
shutil.rmtree(path) # Clean up after test | ||
|
||
import pytest | ||
|
||
def test_file_writer_write_tensor(): | ||
"""Test LighterFileWriter's ability to write and persist tensors correctly.""" | ||
test_dir = Path("test_dir_tensor") | ||
test_dir.mkdir(exist_ok=True) | ||
try: | ||
writer = LighterFileWriter(path=test_dir, writer="tensor") | ||
tensor = torch.tensor([1, 2, 3]) | ||
writer.write(tensor, id=1) | ||
|
||
# Verify file exists | ||
saved_path = writer.path / "1.pt" | ||
assert saved_path.exists() | ||
|
||
# Verify tensor contents | ||
loaded_tensor = torch.load(saved_path) | ||
assert torch.equal(loaded_tensor, tensor) | ||
finally: | ||
shutil.rmtree(test_dir) | ||
|
||
def test_file_writer_write_tensor_errors(): | ||
"""Test error handling in LighterFileWriter.""" | ||
writer = LighterFileWriter(path="test_dir_errors", writer="tensor") | ||
|
||
# Test invalid tensor | ||
with pytest.raises(TypeError): | ||
writer.write("not a tensor", id=1) | ||
|
||
# Test invalid ID | ||
with pytest.raises(ValueError): | ||
writer.write(torch.tensor([1]), id=-1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from pathlib import Path | ||
|
||
import torch | ||
|
||
from lighter.callbacks.writer.table import LighterTableWriter | ||
|
||
|
||
def test_table_writer_initialization(): | ||
writer = LighterTableWriter(path="test.csv", writer="tensor") | ||
assert writer.path == Path("test.csv") | ||
|
||
|
||
def test_table_writer_write(): | ||
"""Test LighterTableWriter write functionality with various inputs.""" | ||
test_file = Path("test.csv") | ||
writer = LighterTableWriter(path="test.csv", writer="tensor") | ||
|
||
# Test basic write | ||
test_tensor = torch.tensor([1, 2, 3]) | ||
writer.write(tensor=test_tensor, id=1) | ||
assert len(writer.csv_records) == 1 | ||
assert writer.csv_records[0]["tensor"] == test_tensor.tolist() | ||
assert writer.csv_records[0]["id"] == 1 | ||
|
||
# Test edge cases | ||
writer.write(tensor=torch.tensor([]), id=2) # empty tensor | ||
writer.write(tensor=torch.randn(1000), id=3) # large tensor | ||
writer.write(tensor=torch.tensor([1.5, 2.5]), id=4) # float tensor | ||
|
||
# Verify file creation and content | ||
assert test_file.exists() | ||
with open(test_file) as f: | ||
content = f.read() | ||
assert "1,2,3" in content # verify first tensor | ||
|
||
# Cleanup | ||
test_file.unlink() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
from unittest import mock | ||
|
||
import pytest | ||
import torch | ||
import torch.nn as nn | ||
from pytorch_lightning import Trainer | ||
from torch.optim import Adam | ||
from torch.optim.lr_scheduler import StepLR | ||
from torch.utils.data import DataLoader, Dataset | ||
from torchmetrics import Accuracy | ||
|
||
from lighter.system import LighterSystem | ||
|
||
|
||
class DummyDataset(Dataset): | ||
def __init__(self, size=100): | ||
self.size = size | ||
|
||
def __len__(self): | ||
return self.size | ||
|
||
def __getitem__(self, idx): | ||
x = torch.randn(3, 32, 32) | ||
|
||
y = torch.randint(0, 10, size=()).long() # Changed to return scalar tensor | ||
return {"input": x, "target": y} | ||
|
||
|
||
class DummyPredictDataset(Dataset): | ||
def __init__(self, size=20): | ||
self.size = size | ||
|
||
def __len__(self): | ||
return self.size | ||
|
||
def __getitem__(self, idx): | ||
x = torch.randn(3, 32, 32) | ||
return {"input": x} | ||
|
||
|
||
class DummyModel(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.net = nn.Sequential(nn.Flatten(), nn.Linear(3072, 10)) | ||
|
||
def forward(self, x): | ||
return self.net(x) | ||
|
||
|
||
class DummySystem(LighterSystem): | ||
def __init__(self): | ||
model = DummyModel() | ||
optimizer = Adam(model.parameters(), lr=0.001) | ||
scheduler = StepLR(optimizer, step_size=1) | ||
criterion = nn.CrossEntropyLoss() | ||
|
||
datasets = { | ||
"train": DummyDataset(), | ||
"val": DummyDataset(50), | ||
"test": DummyDataset(20), | ||
"predict": DummyPredictDataset(20), | ||
} | ||
|
||
metrics = { | ||
"train": Accuracy(task="multiclass", num_classes=10), | ||
"val": Accuracy(task="multiclass", num_classes=10), | ||
"test": Accuracy(task="multiclass", num_classes=10), | ||
} | ||
|
||
super().__init__( | ||
model=model, | ||
batch_size=32, | ||
optimizer=optimizer, | ||
scheduler=scheduler, | ||
criterion=criterion, | ||
datasets=datasets, | ||
metrics=metrics, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def dummy_system(): | ||
return DummySystem() | ||
|
||
|
||
def test_system_with_trainer(dummy_system): | ||
trainer = Trainer(max_epochs=1) | ||
trainer.fit(dummy_system) | ||
assert dummy_system.batch_size == 32 | ||
assert isinstance(dummy_system.model, DummyModel) | ||
assert isinstance(dummy_system.optimizer, Adam) | ||
assert isinstance(dummy_system.scheduler, StepLR) | ||
|
||
|
||
def test_configure_optimizers(dummy_system): | ||
config = dummy_system.configure_optimizers() | ||
assert "optimizer" in config | ||
assert "lr_scheduler" in config | ||
assert isinstance(config["optimizer"], Adam) | ||
assert isinstance(config["lr_scheduler"], StepLR) | ||
|
||
|
||
def test_dataloader_creation(dummy_system): | ||
dummy_system.setup("fit") | ||
train_loader = dummy_system.train_dataloader() | ||
assert isinstance(train_loader, DataLoader) | ||
assert train_loader.batch_size == 32 | ||
|
||
|
||
def test_training_step(dummy_system): | ||
trainer = Trainer(max_epochs=1) | ||
trainer.fit(dummy_system) # Only to attach the system to the trainer | ||
batch = next(iter(dummy_system.train_dataloader())) | ||
|
||
# https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063 | ||
with mock.patch.object(dummy_system, "log"): | ||
result = dummy_system.training_step(batch, batch_idx=0) | ||
|
||
assert "loss" in result | ||
assert "metrics" in result | ||
assert "input" in result | ||
assert "target" in result | ||
assert "pred" in result | ||
assert torch.is_tensor(result["loss"]) | ||
|
||
|
||
def test_validation_step(dummy_system): | ||
trainer = Trainer(max_epochs=1) | ||
trainer.validate(dummy_system) # Only to attach the system to the trainer | ||
batch = next(iter(dummy_system.val_dataloader())) | ||
|
||
# https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063 | ||
with mock.patch.object(dummy_system, "log"): | ||
result = dummy_system.validation_step(batch, batch_idx=0) | ||
|
||
assert "loss" in result | ||
assert "metrics" in result | ||
assert torch.is_tensor(result["loss"]) | ||
|
||
|
||
def test_predict_step(dummy_system): | ||
trainer = Trainer(max_epochs=1) | ||
trainer.predict(dummy_system) # Only to attach the system to the trainer | ||
batch = next(iter(dummy_system.predict_dataloader())) | ||
|
||
# https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063 | ||
with mock.patch.object(dummy_system, "log"): | ||
result = dummy_system.predict_step(batch, batch_idx=0) | ||
|
||
assert "pred" in result | ||
assert torch.is_tensor(result["pred"]) | ||
|
||
|
||
def test_learning_rate_property(dummy_system): | ||
initial_lr = dummy_system.learning_rate | ||
assert initial_lr == 0.001 | ||
|
||
dummy_system.learning_rate = 0.01 | ||
assert dummy_system.learning_rate == 0.01 | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"batch", | ||
[ | ||
{"input": torch.randn(1, 3, 32, 32), "target": torch.randint(0, 10, size=(1,)).long()}, | ||
{"input": torch.randn(2, 3, 32, 32), "target": torch.randint(0, 10, size=(2,)).long()}, | ||
{"input": torch.randn(4, 3, 32, 32), "target": torch.randint(0, 10, size=(4,)).long(), "id": "test_id"}, | ||
], | ||
) | ||
def test_valid_batch_formats(dummy_system, batch): | ||
trainer = Trainer(max_epochs=1) | ||
trainer.fit(dummy_system) # Only to attach the system to the trainer | ||
|
||
# https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063 | ||
with mock.patch.object(dummy_system, "log"): | ||
result = dummy_system.training_step(batch, batch_idx=0) | ||
assert isinstance(result, dict) | ||
|
||
|
||
@pytest.mark.xfail(raises=ValueError) | ||
def test_invalid_batch_format(dummy_system): | ||
invalid_batch = {"wrong_key": torch.randn(1, 3, 32, 32)} | ||
trainer = Trainer(max_epochs=1) | ||
trainer.fit(dummy_system) # Only to attach the system to the trainer | ||
|
||
# https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063 | ||
with mock.patch.object(dummy_system, "log"): | ||
dummy_system.training_step(invalid_batch, batch_idx=0) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Enhance initialization test coverage
The current test only verifies the path attribute. Consider adding:
writer
attribute