Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ibro45 committed Nov 29, 2024
1 parent 946da7e commit 7070c1c
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 55 deletions.
2 changes: 1 addition & 1 deletion lighter/utils/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def parse_config(**kwargs) -> ConfigParser:
raise ValueError("'--config' not specified. Please provide a valid configuration file.")

# Initialize the parser with the predefined structure.
parser = ConfigParser(ConfigSchema().dict(), globals=False)
parser = ConfigParser(ConfigSchema().model_dump(), globals=False)
# Update the parser with the configuration file.
parser.update(parser.load_config_files(kwargs.pop("config")))
# Update the parser with the provided cli arguments.
Expand Down
43 changes: 9 additions & 34 deletions tests/unit/test_system.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest import mock

import pytest
import torch
import torch.nn as nn
Expand All @@ -8,7 +10,6 @@
from torchmetrics import Accuracy

from lighter.system import LighterSystem
from unittest import mock


class DummyDataset(Dataset):
Expand Down Expand Up @@ -82,34 +83,7 @@ def dummy_system():
return DummySystem()


def test_empty_dataset():
empty_dataset = DummyDataset(size=0)
assert len(empty_dataset) == 0
with pytest.raises(IndexError):
_ = empty_dataset[0]

def test_empty_predict_dataset():
empty_predict_dataset = DummyPredictDataset(size=0)
assert len(empty_predict_dataset) == 0
with pytest.raises(IndexError):
_ = empty_predict_dataset[0]

def test_model_forward_pass():
model = DummyModel()
input_tensor = torch.randn(1, 3, 32, 32)
output = model(input_tensor)
assert output.shape == (1, 10)

def test_system_initialization():
system = DummySystem()
assert isinstance(system.model, DummyModel)
assert isinstance(system.optimizer, Adam)
assert isinstance(system.scheduler, StepLR)
assert isinstance(system.criterion, nn.CrossEntropyLoss)
assert "train" in system.datasets
assert "val" in system.datasets
assert "test" in system.datasets
assert "predict" in system.datasets
def test_system_with_trainer(dummy_system):
trainer = Trainer(max_epochs=1)
trainer.fit(dummy_system)
assert dummy_system.batch_size == 32
Expand Down Expand Up @@ -139,7 +113,7 @@ def test_training_step(dummy_system):
batch = next(iter(dummy_system.train_dataloader()))

# https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063
with mock.patch.object(dummy_system, 'log'):
with mock.patch.object(dummy_system, "log"):
result = dummy_system.training_step(batch, batch_idx=0)

assert "loss" in result
Expand All @@ -156,7 +130,7 @@ def test_validation_step(dummy_system):
batch = next(iter(dummy_system.val_dataloader()))

# https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063
with mock.patch.object(dummy_system, 'log'):
with mock.patch.object(dummy_system, "log"):
result = dummy_system.validation_step(batch, batch_idx=0)

assert "loss" in result
Expand All @@ -170,7 +144,7 @@ def test_predict_step(dummy_system):
batch = next(iter(dummy_system.predict_dataloader()))

# https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063
with mock.patch.object(dummy_system, 'log'):
with mock.patch.object(dummy_system, "log"):
result = dummy_system.predict_step(batch, batch_idx=0)

assert "pred" in result
Expand All @@ -184,6 +158,7 @@ def test_learning_rate_property(dummy_system):
dummy_system.learning_rate = 0.01
assert dummy_system.learning_rate == 0.01


@pytest.mark.parametrize(
"batch",
[
Expand All @@ -197,7 +172,7 @@ def test_valid_batch_formats(dummy_system, batch):
trainer.fit(dummy_system) # Only to attach the system to the trainer

# https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063
with mock.patch.object(dummy_system, 'log'):
with mock.patch.object(dummy_system, "log"):
result = dummy_system.training_step(batch, batch_idx=0)
assert isinstance(result, dict)

Expand All @@ -209,5 +184,5 @@ def test_invalid_batch_format(dummy_system):
trainer.fit(dummy_system) # Only to attach the system to the trainer

# https://github.com/Lightning-AI/pytorch-lightning/issues/9674#issuecomment-926243063
with mock.patch.object(dummy_system, 'log'):
with mock.patch.object(dummy_system, "log"):
dummy_system.training_step(invalid_batch, batch_idx=0)
12 changes: 8 additions & 4 deletions tests/unit/test_utils_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from lighter.utils.model import remove_n_last_layers_sequentially, replace_layer_with, replace_layer_with_identity


class SimpleModel(torch.nn.Module):
class DummyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer1 = Linear(10, 10)
Expand All @@ -15,14 +15,14 @@ def forward(self, x):


def test_replace_layer_with():
model = SimpleModel()
new_layer = Linear(10, 10)
model = DummyModel()
new_layer = Linear(10, 4)
replace_layer_with(model, "layer1", new_layer)
assert model.layer1 == new_layer


def test_replace_layer_with_identity():
model = SimpleModel()
model = DummyModel()
replace_layer_with_identity(model, "layer1")
assert isinstance(model.layer1, torch.nn.Identity)

Expand All @@ -31,3 +31,7 @@ def test_remove_n_last_layers_sequentially():
model = Sequential(Linear(10, 10), Linear(10, 10), Linear(10, 10))
new_model = remove_n_last_layers_sequentially(model, num_layers=1)
assert len(new_model) == 2

model = Sequential(Linear(10, 10), Linear(10, 10), Linear(10, 10))
new_model = remove_n_last_layers_sequentially(model, num_layers=2)
assert len(new_model) == 1
27 changes: 11 additions & 16 deletions tests/unit/test_utils_patches.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,25 @@
from torch.nn import Linear, Module
import pytest

from lighter.utils.patches import PatchedModuleDict


def test_patched_module_dict_handles_reserved_names():
# Test with previously problematic reserved names
reserved_names = {
'type': None,
'to': None,
'forward': Linear(10, 10),
'training': Linear(10, 10),
'modules': [Linear(10, 10)]
"type": None,
"to": None,
"forward": Linear(10, 10),
"training": Linear(10, 10),
}

# Should work without raising exceptions
patched_dict = PatchedModuleDict(reserved_names)

# Verify all keys are accessible
for key in reserved_names:
assert key in patched_dict
assert isinstance(patched_dict[key], Linear)

# Test dictionary operations
assert set(patched_dict.keys()) == set(reserved_names.keys())
assert len(patched_dict.values()) == len(reserved_names)

assert patched_dict[key] == reserved_names[key]

# Test deletion
del patched_dict['type']
assert 'type' not in patched_dict
del patched_dict["type"]
assert "type" not in patched_dict

0 comments on commit 7070c1c

Please sign in to comment.