Add unit tests for all modules #138

Merged
merged 51 commits on Nov 30, 2024
Changes from 37 commits
Commits (51)
27a921a
test: Add unit tests for LighterSystem functionality and behavior
ibro45 Nov 28, 2024
e3a784f
test: Add comprehensive unit tests for all modules in the lighter pac…
ibro45 Nov 28, 2024
e272ce8
fix: Import torch in test files to resolve undefined name errors
ibro45 Nov 28, 2024
2eeab00
fix: Resolve optimizer empty parameter list and update test assertions
ibro45 Nov 28, 2024
ab4e153
fix: Import Path from pathlib to resolve undefined name errors in tests
ibro45 Nov 28, 2024
9459cdf
fix: Ensure layers are frozen correctly in LighterFreezer test and fi…
ibro45 Nov 28, 2024
59d7df6
test: Remove unnecessary frozen state assignment in freezer test
ibro45 Nov 29, 2024
e751458
test: Include dummy Trainer in freezer and system tests for validation
ibro45 Nov 29, 2024
8958764
fix: Import missing classes in test_freezer and test_system files
ibro45 Nov 29, 2024
85ba211
fix: Import missing Dataset, DataLoader, and Accuracy in test files
ibro45 Nov 29, 2024
ce26668
chore: Update imports in test files for consistency and clarity
ibro45 Nov 29, 2024
e236d47
fix: Resolve TypeError in trainer setup and import DataLoader correctly
ibro45 Nov 29, 2024
2abb159
feat: Wrap dummy models into a minimal PyTorch Lightning system
ibro45 Nov 29, 2024
214e7b1
fix: Correct test case for replacing layer with identity in unit tests
ibro45 Nov 29, 2024
c9a89fc
chore: Import Module from torch.nn in test_freezer.py
ibro45 Nov 29, 2024
4fafbe9
test: Fix NameError and improve tests for LighterFreezer functionality
ibro45 Nov 29, 2024
431a18f
refactor: Simplify dataset return structure in test_freezer.py
ibro45 Nov 29, 2024
5097be7
refactor: Replace DummySystem with LighterSystem and add model tests
ibro45 Nov 29, 2024
5f53545
feat: Update DummyModel architecture and enhance dummy_system setup
ibro45 Nov 29, 2024
dc50ff4
test: Ensure optimizer is correctly set up in freezer exception test
ibro45 Nov 29, 2024
1f2bd41
test: Update assertion to check for grad_fn in freezer test case
ibro45 Nov 29, 2024
0da59a7
test: Remove redundant optimizer setup assertions from test_freezer.py
ibro45 Nov 29, 2024
469af2b
fix: Correct freezing logic in freezer tests to match expected behavior
ibro45 Nov 29, 2024
af6363e
Fixes, reorganize
ibro45 Nov 29, 2024
dd4cb7c
test: Remove skip markers from training and validation step tests
ibro45 Nov 29, 2024
93fe1d4
fix: Attach DummySystem to Trainer and resolve batch size mismatches
ibro45 Nov 29, 2024
92fca36
fix: Ensure DummySystem is attached to Trainer in test_valid_batch_fo…
ibro45 Nov 29, 2024
db006e6
test: Refactor training, validation, and prediction steps in tests
ibro45 Nov 29, 2024
6157ebb
feat: Add predict dataset to DummySystem and update corresponding tests
ibro45 Nov 29, 2024
ad01f6a
test: Update unit tests to use training_step with mock logging
ibro45 Nov 29, 2024
946da7e
test: Add unit tests for empty datasets, model forward pass, and syst…
ibro45 Nov 29, 2024
7070c1c
Fix tests
ibro45 Nov 29, 2024
b2333b8
Update tests/unit/test_callbacks_writer_file.py
ibro45 Nov 30, 2024
776a531
Update tests/unit/test_callbacks_writer_table.py
ibro45 Nov 30, 2024
8ae2af0
Update tests/unit/test_callbacks_writer_file.py
ibro45 Nov 30, 2024
f42a089
Update tests/unit/test_utils_collate.py
ibro45 Nov 30, 2024
98af60b
Update tests/unit/test_callbacks_writer_base.py
ibro45 Nov 30, 2024
85913cd
test: Fix assertion in table writer test for correct key usage
ibro45 Nov 30, 2024
6bf6f76
test: Add comprehensive tests for LighterTableWriter functionality
ibro45 Nov 30, 2024
e581031
test: Replace mocked Trainer with a real Trainer instance in tests
ibro45 Nov 30, 2024
fa87780
test: Update Trainer instantiation in table writer tests for simplicity
ibro45 Nov 30, 2024
48782a7
test: Update test_table_writer_write to validate all CSV entries with…
ibro45 Nov 30, 2024
4ff7260
test: Update unit test for LighterTableWriter with expected records
ibro45 Nov 30, 2024
9eced0f
test: Fix DataFrame ID type comparison in table writer tests
ibro45 Nov 30, 2024
6a7281c
test: Update test cases for LighterTableWriter initialization and val…
ibro45 Nov 30, 2024
0f2d396
fix: Mock world_size and is_global_zero methods in tests
ibro45 Nov 30, 2024
74567db
test: Update test directory name and improve error handling in tests
ibro45 Nov 30, 2024
51d4557
test: Fix assertion and error handling in file writer tests
ibro45 Nov 30, 2024
e9e97aa
test: Fix test assertions and mock properties in writer tests
ibro45 Nov 30, 2024
2047058
fix: Mock world_size and is_global_zero methods in tests
ibro45 Nov 30, 2024
7b7836b
Improve tests provided by coderabbit
ibro45 Nov 30, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -152,3 +152,5 @@ projects/*
**/predictions/
*/.DS_Store
.DS_Store
+.aider*
+test_dir/
2 changes: 1 addition & 1 deletion lighter/utils/runner.py
@@ -36,7 +36,7 @@ def parse_config(**kwargs) -> ConfigParser:
raise ValueError("'--config' not specified. Please provide a valid configuration file.")

# Initialize the parser with the predefined structure.
-parser = ConfigParser(ConfigSchema().dict(), globals=False)
+parser = ConfigParser(ConfigSchema().model_dump(), globals=False)
# Update the parser with the configuration file.
parser.update(parser.load_config_files(kwargs.pop("config")))
# Update the parser with the provided cli arguments.
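For context, this one-line change tracks the Pydantic v2 rename of BaseModel.dict() to BaseModel.model_dump(). A minimal, self-contained sketch of the rename, with illustrative fields only (Lighter's real ConfigSchema differs):

from pydantic import BaseModel, Field


class ConfigSchema(BaseModel):
    # Illustrative fields only; not Lighter's actual schema.
    trainer: dict = Field(default_factory=dict)
    system: dict = Field(default_factory=dict)


schema = ConfigSchema()
schema.dict()        # Pydantic v1 API; emits a deprecation warning under v2
schema.model_dump()  # Pydantic v2 replacement used in this PR -> {'trainer': {}, 'system': {}}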
65 changes: 65 additions & 0 deletions tests/unit/test_callbacks_freezer.py
@@ -0,0 +1,65 @@
import pytest
import torch
from pytorch_lightning import Trainer
from torch.nn import Module
from torch.utils.data import Dataset

from lighter.callbacks.freezer import LighterFreezer
from lighter.system import LighterSystem


class DummyModel(Module):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10)
        self.layer2 = torch.nn.Linear(10, 4)
        self.layer3 = torch.nn.Linear(4, 1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x


class DummyDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return {"input": torch.randn(10), "target": torch.tensor(0)}


@pytest.fixture
def dummy_system():
    model = DummyModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    dataset = DummyDataset()
    criterion = torch.nn.CrossEntropyLoss()
    return LighterSystem(model=model, batch_size=32, criterion=criterion, optimizer=optimizer, datasets={"train": dataset})


def test_freezer_initialization():
    freezer = LighterFreezer(names=["layer1"])
    assert freezer.names == ["layer1"]


def test_freezer_functionality(dummy_system):
    freezer = LighterFreezer(names=["layer1.weight", "layer1.bias"])
    trainer = Trainer(callbacks=[freezer], max_epochs=1)
    trainer.fit(dummy_system)
    assert not dummy_system.model.layer1.weight.requires_grad
    assert not dummy_system.model.layer1.bias.requires_grad
    assert dummy_system.model.layer2.weight.requires_grad


def test_freezer_with_exceptions(dummy_system):
    freezer = LighterFreezer(name_starts_with=["layer"], except_names=["layer2.weight", "layer2.bias"])
    trainer = Trainer(callbacks=[freezer], max_epochs=1)
    trainer.fit(dummy_system)
    assert not dummy_system.model.layer1.weight.requires_grad
    assert not dummy_system.model.layer1.bias.requires_grad
    assert dummy_system.model.layer2.weight.requires_grad
    assert dummy_system.model.layer2.bias.requires_grad
    assert not dummy_system.model.layer3.weight.requires_grad
    assert not dummy_system.model.layer3.bias.requires_grad
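For readers of these tests, a minimal sketch of the freezing semantics they assert, assuming matching on named_parameters() names; this is an illustration, not LighterFreezer's actual implementation:

def freeze_params(model, names=None, name_starts_with=None, except_names=None):
    """Illustration of the semantics asserted above, not the real LighterFreezer."""
    names = names or []
    name_starts_with = name_starts_with or []
    except_names = except_names or []
    for param_name, param in model.named_parameters():
        if param_name in except_names:
            continue  # explicitly exempted parameters stay trainable
        if param_name in names or any(param_name.startswith(prefix) for prefix in name_starts_with):
            param.requires_grad = False

Calling freeze_params(DummyModel(), name_starts_with=["layer"], except_names=["layer2.weight", "layer2.bias"]) leaves only layer2 trainable, which is exactly what test_freezer_with_exceptions checks.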
19 changes: 19 additions & 0 deletions tests/unit/test_callbacks_utils.py
@@ -0,0 +1,19 @@
import torch

from lighter.callbacks.utils import preprocess_image


def test_preprocess_image_2d():
    image = torch.rand(1, 3, 64, 64)  # Batch of 2D images
    processed_image = preprocess_image(image)
    assert processed_image.shape == (3, 64, 64)


def test_preprocess_image_3d():
    batch_size = 8
    depth = 20
    height = 64
    width = 64
    image = torch.rand(batch_size, 1, depth, height, width)  # Batch of 3D images
    processed_image = preprocess_image(image)
    assert processed_image.shape == (1, depth * height, batch_size * width)
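A note on the 3D case: the expected shape (1, depth * height, batch_size * width) implies a tiling layout in which depth slices stack vertically and batch items stack horizontally. A shape-only sketch under that assumption (not the actual lighter.callbacks.utils.preprocess_image):

import torch


def tile_3d_batch(image: torch.Tensor) -> torch.Tensor:
    # (batch, 1, depth, height, width) -> (1, depth * height, batch * width)
    b, c, d, h, w = image.shape
    assert c == 1, "sketch assumes single-channel volumes"
    return image.squeeze(1).permute(1, 2, 0, 3).reshape(1, d * h, b * w)


assert tile_3d_batch(torch.rand(8, 1, 20, 64, 64)).shape == (1, 20 * 64, 8 * 64)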
8 changes: 8 additions & 0 deletions tests/unit/test_callbacks_writer_base.py
@@ -0,0 +1,8 @@
import pytest

from lighter.callbacks.writer.base import LighterBaseWriter


def test_writer_initialization():
    with pytest.raises(TypeError):
        LighterBaseWriter(path="test", writer="tensor")
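The TypeError here presumably comes from LighterBaseWriter being an abstract base class. A generic sketch of that behavior (BaseWriter below is a hypothetical stand-in, not Lighter's class):

from abc import ABC, abstractmethod

import pytest


class BaseWriter(ABC):  # hypothetical stand-in, not Lighter's actual class
    @abstractmethod
    def write(self, tensor, id):
        ...


with pytest.raises(TypeError):
    BaseWriter()  # abstract classes cannot be instantiated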
52 changes: 52 additions & 0 deletions tests/unit/test_callbacks_writer_file.py
@@ -0,0 +1,52 @@
import shutil
from pathlib import Path

import pytest
import torch

from lighter.callbacks.writer.file import LighterFileWriter


def test_file_writer_initialization():
    """Test LighterFileWriter initialization with proper attributes."""
    path = Path("test_dir")
    path.mkdir(exist_ok=True)  # Ensure the directory exists
    try:
        writer = LighterFileWriter(path=path, writer="tensor")
        assert writer.path == Path("test_dir")
        assert writer.writer == "tensor"  # Verify writer type
    finally:
        shutil.rmtree(path)  # Clean up after test


def test_file_writer_write_tensor():
    """Test LighterFileWriter's ability to write and persist tensors correctly."""
    test_dir = Path("test_dir_tensor")
    test_dir.mkdir(exist_ok=True)
    try:
        writer = LighterFileWriter(path=test_dir, writer="tensor")
        tensor = torch.tensor([1, 2, 3])
        writer.write(tensor, id=1)

        # Verify file exists
        saved_path = writer.path / "1.pt"
        assert saved_path.exists()

        # Verify tensor contents
        loaded_tensor = torch.load(saved_path)
        assert torch.equal(loaded_tensor, tensor)
    finally:
        shutil.rmtree(test_dir)


def test_file_writer_write_tensor_errors():
    """Test error handling in LighterFileWriter."""
    writer = LighterFileWriter(path="test_dir_errors", writer="tensor")

    # Test invalid tensor
    with pytest.raises(TypeError):
        writer.write("not a tensor", id=1)

    # Test invalid ID
    with pytest.raises(ValueError):
        writer.write(torch.tensor([1]), id=-1)
37 changes: 37 additions & 0 deletions tests/unit/test_callbacks_writer_table.py
@@ -0,0 +1,37 @@
from pathlib import Path

import torch

from lighter.callbacks.writer.table import LighterTableWriter


def test_table_writer_initialization():
    writer = LighterTableWriter(path="test.csv", writer="tensor")
    assert writer.path == Path("test.csv")
Comment on lines +17 to +19
Contributor


🛠️ Refactor suggestion

Enhance initialization test coverage

The current test only verifies the path attribute. Consider adding:

  1. Verification of the writer attribute
  2. Negative test cases (invalid paths, missing arguments)
  3. Cleanup of any created files
  4. Docstring explaining test purpose and assertions
 def test_table_writer_initialization():
+    """Test LighterTableWriter initialization with valid and invalid inputs."""
     writer = LighterTableWriter(path="test.csv", writer="tensor")
     assert writer.path == Path("test.csv")
+    assert writer.writer == "tensor"
+    
+    # Test invalid writer type
+    with pytest.raises(ValueError):
+        LighterTableWriter(path="test.csv", writer="invalid_type")
+    
+    # Test invalid path
+    with pytest.raises(ValueError):
+        LighterTableWriter(path="", writer="tensor")

Committable suggestion skipped: line range outside the PR's diff.



def test_table_writer_write():
    """Test LighterTableWriter write functionality with various inputs."""
    test_file = Path("test.csv")
    writer = LighterTableWriter(path="test.csv", writer="tensor")

    # Test basic write
    test_tensor = torch.tensor([1, 2, 3])
    writer.write(tensor=test_tensor, id=1)
    assert len(writer.csv_records) == 1
    assert writer.csv_records[0]["tensor"] == test_tensor.tolist()
    assert writer.csv_records[0]["id"] == 1

    # Test edge cases
    writer.write(tensor=torch.tensor([]), id=2)  # empty tensor
    writer.write(tensor=torch.randn(1000), id=3)  # large tensor
    writer.write(tensor=torch.tensor([1.5, 2.5]), id=4)  # float tensor

    # Verify file creation and content
    assert test_file.exists()
    with open(test_file) as f:
        content = f.read()
        assert "1,2,3" in content  # verify first tensor

    # Cleanup
    test_file.unlink()
188 changes: 188 additions & 0 deletions tests/unit/test_system.py
@@ -0,0 +1,188 @@
from unittest import mock

import pytest
import torch
import torch.nn as nn
from pytorch_lightning import Trainer
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy

from lighter.system import LighterSystem


class DummyDataset(Dataset):
    def __init__(self, size=100):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        x = torch.randn(3, 32, 32)
        y = torch.randint(0, 10, size=()).long()  # Changed to return scalar tensor
        return {"input": x, "target": y}


class DummyPredictDataset(Dataset):
    def __init__(self, size=20):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        x = torch.randn(3, 32, 32)
        return {"input": x}


class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(3072, 10))

    def forward(self, x):
        return self.net(x)


class DummySystem(LighterSystem):
    def __init__(self):
        model = DummyModel()
        optimizer = Adam(model.parameters(), lr=0.001)
        scheduler = StepLR(optimizer, step_size=1)
        criterion = nn.CrossEntropyLoss()

        datasets = {
            "train": DummyDataset(),
            "val": DummyDataset(50),
            "test": DummyDataset(20),
            "predict": DummyPredictDataset(20),
        }

        metrics = {
            "train": Accuracy(task="multiclass", num_classes=10),
            "val": Accuracy(task="multiclass", num_classes=10),
            "test": Accuracy(task="multiclass", num_classes=10),
        }

        super().__init__(
            model=model,
            batch_size=32,
            optimizer=optimizer,
            scheduler=scheduler,
            criterion=criterion,
            datasets=datasets,
            metrics=metrics,
        )


@pytest.fixture
def dummy_system():
    return DummySystem()


def test_system_with_trainer(dummy_system):
    trainer = Trainer(max_epochs=1)
    trainer.fit(dummy_system)
    assert dummy_system.batch_size == 32
    assert isinstance(dummy_system.model, DummyModel)
    assert isinstance(dummy_system.optimizer, Adam)
    assert isinstance(dummy_system.scheduler, StepLR)


def test_configure_optimizers(dummy_system):
    config = dummy_system.configure_optimizers()
    assert "optimizer" in config
    assert "lr_scheduler" in config
    assert isinstance(config["optimizer"], Adam)
    assert isinstance(config["lr_scheduler"], StepLR)


def test_dataloader_creation(dummy_system):
    dummy_system.setup("fit")
    train_loader = dummy_system.train_dataloader()
    assert isinstance(train_loader, DataLoader)
    assert train_loader.batch_size == 32


def test_training_step(dummy_system):
    trainer = Trainer(max_epochs=1)
    trainer.fit(dummy_system)  # Only to attach the system to the trainer
    batch = next(iter(dummy_system.train_dataloader()))

    # https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063
    with mock.patch.object(dummy_system, "log"):
        result = dummy_system.training_step(batch, batch_idx=0)

    assert "loss" in result
    assert "metrics" in result
    assert "input" in result
    assert "target" in result
    assert "pred" in result
    assert torch.is_tensor(result["loss"])


def test_validation_step(dummy_system):
    trainer = Trainer(max_epochs=1)
    trainer.validate(dummy_system)  # Only to attach the system to the trainer
    batch = next(iter(dummy_system.val_dataloader()))

    # https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063
    with mock.patch.object(dummy_system, "log"):
        result = dummy_system.validation_step(batch, batch_idx=0)

    assert "loss" in result
    assert "metrics" in result
    assert torch.is_tensor(result["loss"])


def test_predict_step(dummy_system):
    trainer = Trainer(max_epochs=1)
    trainer.predict(dummy_system)  # Only to attach the system to the trainer
    batch = next(iter(dummy_system.predict_dataloader()))

    # https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063
    with mock.patch.object(dummy_system, "log"):
        result = dummy_system.predict_step(batch, batch_idx=0)

    assert "pred" in result
    assert torch.is_tensor(result["pred"])


def test_learning_rate_property(dummy_system):
    initial_lr = dummy_system.learning_rate
    assert initial_lr == 0.001

    dummy_system.learning_rate = 0.01
    assert dummy_system.learning_rate == 0.01


@pytest.mark.parametrize(
    "batch",
    [
        {"input": torch.randn(1, 3, 32, 32), "target": torch.randint(0, 10, size=(1,)).long()},
        {"input": torch.randn(2, 3, 32, 32), "target": torch.randint(0, 10, size=(2,)).long()},
        {"input": torch.randn(4, 3, 32, 32), "target": torch.randint(0, 10, size=(4,)).long(), "id": "test_id"},
    ],
)
def test_valid_batch_formats(dummy_system, batch):
    trainer = Trainer(max_epochs=1)
    trainer.fit(dummy_system)  # Only to attach the system to the trainer

    # https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063
    with mock.patch.object(dummy_system, "log"):
        result = dummy_system.training_step(batch, batch_idx=0)
        assert isinstance(result, dict)


@pytest.mark.xfail(raises=ValueError)
def test_invalid_batch_format(dummy_system):
    invalid_batch = {"wrong_key": torch.randn(1, 3, 32, 32)}
    trainer = Trainer(max_epochs=1)
    trainer.fit(dummy_system)  # Only to attach the system to the trainer

    # https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063
    with mock.patch.object(dummy_system, "log"):
        dummy_system.training_step(invalid_batch, batch_idx=0)