Add unit tests for all modules #138
Conversation
…x writer path error
Walkthrough
The pull request introduces several changes across multiple files.
Changes
Possibly related PRs
Suggested reviewers
Actionable comments posted: 11
🧹 Outside diff range and nitpick comments (23)
tests/unit/test_utils_logging.py (2)
1-1
: Consider testing through public interfaces instead of private functions
The test imports _setup_logging, which is marked as private (leading underscore). Testing private implementation details directly can make refactoring more difficult. Consider testing through public interfaces instead (see the sketch below).
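As a rough illustration, one way to exercise the behavior without touching the private helper is to check that messages sent through the standard logging module end up in Loguru once lighter is imported. This is only a sketch; it assumes the import-time intercept-and-redirect setup described in the verification later in this review, and the logger and message names are made up:

import logging

from loguru import logger

import lighter  # noqa: F401  # importing is assumed to trigger the logging setup


def test_stdlib_logging_is_routed_to_loguru():
    captured = []
    # Attach a temporary Loguru sink that records every formatted message.
    sink_id = logger.add(lambda message: captured.append(str(message)))
    try:
        logging.getLogger("lighter.dummy").warning("hello from stdlib logging")
    finally:
        logger.remove(sink_id)
    assert any("hello from stdlib logging" in line for line in captured)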
4-6
: Add docstring to document test purpose
The test function lacks documentation explaining its purpose and expected behavior. Add a docstring to improve maintainability.
tests/unit/test_utils_runner.py (2)
6-8
: Good test! Consider adding more test cases.
The test correctly verifies the error case when no config is provided. However, we could enhance the test coverage.
Consider adding these additional test cases:
- Test with empty config
- Test with invalid config
- Test with minimal valid config
Here's a suggested implementation:
def test_parse_config_empty_config():
    with pytest.raises(ValueError):
        parse_config({})


def test_parse_config_invalid_config():
    with pytest.raises(ValueError):
        parse_config({"invalid_key": "value"})


def test_parse_config_minimal_valid():
    config = {"name": "test", "version": "1.0"}
    result = parse_config(config)
    assert result.name == "test"
    assert result.version == "1.0"
5-5
: Remove extra blank line.
There's an extra blank line that can be removed to maintain consistent spacing.
- def test_parse_config_no_config():
tests/unit/test_utils_dynamic_imports.py (1)
6-8
: Consider adding more test cases for comprehensive coverage.
While the error handling test is good, consider adding:
- Positive test case with a valid module
- Edge cases testing different path formats
- Test case with an existing path but invalid module
- Test case with a module containing syntax errors
Here's a suggested implementation:
import os
import tempfile
import textwrap


def test_import_module_from_path():
    # Test non-existent module (current test)
    with pytest.raises(FileNotFoundError):
        import_module_from_path("non_existent_module", "non_existent_path")

    # Test valid module
    with tempfile.NamedTemporaryFile(suffix='.py', mode='w', delete=False) as f:
        f.write(textwrap.dedent('''
            def sample_function():
                return 42
        '''))
    try:
        module = import_module_from_path("sample_module", f.name)
        assert module.sample_function() == 42
    finally:
        os.unlink(f.name)

    # Test invalid module content
    with tempfile.NamedTemporaryFile(suffix='.py', mode='w', delete=False) as f:
        f.write("invalid python syntax :")
    try:
        with pytest.raises(SyntaxError):
            import_module_from_path("invalid_module", f.name)
    finally:
        os.unlink(f.name)
tests/unit/test_utils_collate.py (1)
3-3
: Remove unnecessary blank line
There's an extra blank line that can be removed to improve code readability.
 from lighter.utils.collate import collate_replace_corrupted
-
 def test_collate_replace_corrupted():
tests/unit/test_callbacks_writer_table.py (2)
1-7
: Remove extra empty line after imports
There's an unnecessary double empty line after the imports.
 from lighter.callbacks.writer.table import LighterTableWriter
-
1-16
: Consider improving test organization and coverage
The test file would benefit from several architectural improvements:
- Add pytest fixtures for common setup/teardown
- Use parametrize for testing multiple scenarios
- Add integration tests with other writers
- Consider adding a test class for better organization
Here's a suggested structure:
import pytest
from pathlib import Path

import torch

from lighter.callbacks.writer.table import LighterTableWriter


@pytest.fixture
def table_writer():
    writer = LighterTableWriter(path="test.csv", writer="tensor")
    yield writer
    Path("test.csv").unlink(missing_ok=True)


class TestLighterTableWriter:
    @pytest.mark.parametrize("tensor,id", [
        (torch.tensor([1, 2, 3]), 1),
        (torch.tensor([]), 2),
        (torch.randn(1000), 3),
    ])
    def test_write(self, table_writer, tensor, id):
        table_writer.write(tensor=tensor, id=id)
        assert len(table_writer.csv_records) > 0
        # Add more assertions
tests/unit/test_callbacks_writer_file.py (1)
1-19
: Consider adopting a more structured testing approach.
Given this is part of a larger unit testing effort, consider:
- Using pytest fixtures for common setup/teardown
- Implementing parametrized tests for different tensor types and sizes
- Adding integration tests with other writers
- Adding property-based testing using hypothesis for tensor generation
Example fixture implementation:
@pytest.fixture
def temp_writer_dir():
    """Fixture providing a temporary directory for writer tests."""
    test_dir = Path("test_dir_fixture")
    test_dir.mkdir(exist_ok=True)
    yield test_dir
    shutil.rmtree(test_dir)


@pytest.fixture
def file_writer(temp_writer_dir):
    """Fixture providing a configured LighterFileWriter instance."""
    return LighterFileWriter(path=temp_writer_dir, writer="tensor")
tests/unit/test_callbacks_utils.py (1)
6-9
: Enhance test coverage and documentation for 2D image preprocessing.
While the basic shape verification is good, consider these improvements:
- Add docstring explaining the test purpose and expected behavior
- Verify the content/values of the processed image, not just shape
- Add edge cases with different batch sizes and channel counts
- Document why these specific dimensions (64x64) were chosen
Here's a suggested enhancement:
 def test_preprocess_image_2d():
+    """Test 2D image preprocessing.
+
+    Verifies that preprocess_image correctly handles 2D images by:
+    1. Removing batch dimension
+    2. Preserving channel and spatial dimensions
+    3. Maintaining pixel values
+    """
    image = torch.rand(1, 3, 64, 64)  # Batch of 2D images
    processed_image = preprocess_image(image)
    assert processed_image.shape == (3, 64, 64)
+    # Verify content preservation
+    assert torch.allclose(processed_image, image.squeeze(0))
+
+    # Test with different batch and channel sizes
+    image_multi = torch.rand(4, 1, 32, 32)
+    processed_multi = preprocess_image(image_multi)
+    assert processed_multi.shape == (1, 32, 32)
1-3
: Remove unused import Module
The Module class from torch.nn is imported but never used in the test file.
-from torch.nn import Linear, Module
+from torch.nn import Linear
🧰 Tools
🪛 Ruff (0.8.0)
1-1: torch.nn.Module imported but unused
Remove unused import: torch.nn.Module
(F401)
6-25
: Good test coverage, but could be enhanced
The test effectively covers basic functionality with reserved names. Here are some suggestions for improvement:
- Add a docstring explaining the test's purpose and importance
- Consider using
@pytest.mark.parametrize for testing different reserved names
- Add error case testing (e.g., invalid values, None keys)
Here's a suggested enhancement:
+import pytest
+
 def test_patched_module_dict_handles_reserved_names():
+    """
+    Verify PatchedModuleDict properly handles Python reserved names.
+
+    This test ensures that potentially problematic Python reserved names
+    can be used as keys without conflicts, maintaining proper dictionary
+    functionality for access and deletion operations.
+    """
    # Test with previously problematic reserved names
Would you like me to provide a complete example with parametrization and additional test cases?
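For illustration, such a parametrized variant could look like the sketch below. The import path, the constructor call, and the exact set of reserved names are assumptions inferred from the test file's naming convention, not verified against the codebase:

import pytest
from torch.nn import Linear

from lighter.utils.patches import PatchedModuleDict  # assumed import path


@pytest.mark.parametrize("reserved_name", ["forward", "type", "to", "keys", "items", "values"])
def test_patched_module_dict_reserved_name(reserved_name):
    # Each reserved name should be usable as a key for storing, accessing, and deleting a module.
    patched_dict = PatchedModuleDict({reserved_name: Linear(10, 10)})
    assert reserved_name in patched_dict
    assert isinstance(patched_dict[reserved_name], Linear)
    del patched_dict[reserved_name]
    assert reserved_name not in patched_dict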
11-12
: Enhance assertions for Linear module verification
While the test verifies basic dictionary operations, it should also verify that the Linear modules maintain their properties and can perform their intended operations.
Add these assertions after line 21:
# Verify Linear modules maintain their properties
assert isinstance(patched_dict["forward"], Linear)
assert patched_dict["forward"].in_features == 10
assert patched_dict["forward"].out_features == 10

# Verify module operations still work
import torch
input_tensor = torch.randn(1, 10)
output = patched_dict["forward"](input_tensor)
assert output.shape == (1, 10)
Also applies to: 19-21
tests/unit/test_utils_model.py (2)
7-14
: Add docstring and type hints to the test fixture.
While the implementation is correct, adding documentation would improve maintainability.
 class DummyModel(torch.nn.Module):
+    """Test fixture representing a simple neural network with two linear layers.
+
+    Used for testing layer manipulation utilities.
+    """
    def __init__(self):
        super().__init__()
        self.layer1 = Linear(10, 10)
        self.layer2 = Linear(10, 10)

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer2(self.layer1(x))
30-37
: Refactor duplicate model creation and add edge case tests.
- The Sequential model creation is duplicated. Consider using a fixture.
- Add test cases for edge cases like removing more layers than available.
+@pytest.fixture
+def sequential_model():
+    return Sequential(Linear(10, 10), Linear(10, 10), Linear(10, 10))
+
 def test_remove_n_last_layers_sequentially():
-    model = Sequential(Linear(10, 10), Linear(10, 10), Linear(10, 10))
+    model = sequential_model()
    new_model = remove_n_last_layers_sequentially(model, num_layers=1)
    assert len(new_model) == 2

-    model = Sequential(Linear(10, 10), Linear(10, 10), Linear(10, 10))
+    model = sequential_model()
    new_model = remove_n_last_layers_sequentially(model, num_layers=2)
    assert len(new_model) == 1
+
+    # Test edge case
+    with pytest.raises(ValueError):
+        remove_n_last_layers_sequentially(model, num_layers=4)
11-23
: Add docstring and consider adding activation functions.
The DummyModel could be improved by:
- Adding a docstring explaining its purpose in testing.
- Adding activation functions between layers for a more realistic test scenario.
- Documenting why this specific architecture was chosen.
 class DummyModel(Module):
+    """A simple neural network for testing the LighterFreezer callback.
+
+    Architecture:
+        input (10) -> linear -> linear (4) -> linear (1)
+    """
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10)
+        self.relu1 = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(10, 4)
+        self.relu2 = torch.nn.ReLU()
        self.layer3 = torch.nn.Linear(4, 1)

    def forward(self, x):
        x = self.layer1(x)
+        x = self.relu1(x)
        x = self.layer2(x)
+        x = self.relu2(x)
        x = self.layer3(x)
        return x
25-31
: Enhance test coverage with more diverse data.
The DummyDataset could be improved by:
- Adding a docstring explaining its testing purpose.
- Increasing the dataset size for more robust testing.
- Generating varied targets instead of all zeros.
 class DummyDataset(Dataset):
+    """A mock dataset for testing the LighterFreezer callback.
+
+    Generates random input tensors and corresponding target values.
+    """
+    def __init__(self, size=100):
+        self.size = size
+
    def __len__(self):
-        return 10
+        return self.size

    def __getitem__(self, idx):
-        return {"input": torch.randn(10), "target": torch.tensor(0)}
+        return {
+            "input": torch.randn(10),
+            "target": torch.randint(0, 2, (1,))  # Binary targets for variety
+        }
33-40
: Document fixture and parameterize test configurations.
The fixture could be improved by:
- Adding a docstring explaining its purpose and usage.
- Making batch_size and optimizer configurations parameterizable.
- Using a more stable optimizer like Adam for testing.
 @pytest.fixture
-def dummy_system():
+def dummy_system(batch_size=32):
+    """Create a LighterSystem instance for testing.
+
+    Args:
+        batch_size (int, optional): Batch size for training. Defaults to 32.
+
+    Returns:
+        LighterSystem: A system configured with dummy model and data.
+    """
    model = DummyModel()
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    dataset = DummyDataset()
    criterion = torch.nn.CrossEntropyLoss()
-    return LighterSystem(model=model, batch_size=32, criterion=criterion, optimizer=optimizer, datasets={"train": dataset})
+    return LighterSystem(
+        model=model,
+        batch_size=batch_size,
+        criterion=criterion,
+        optimizer=optimizer,
+        datasets={"train": dataset}
+    )
1-65
: Consider restructuring tests using a proper test class.
The test suite would benefit from being organized into a proper test class with shared setup and teardown methods. This would:
- Reduce code duplication
- Provide better organization of test cases
- Allow for shared fixtures and helper methods
- Make it easier to add new test cases
Example structure:
@pytest.mark.freezer
class TestLighterFreezer:
    @pytest.fixture(autouse=True)
    def setup(self):
        self.model = DummyModel()
        self.dataset = DummyDataset()
        # ... other setup code

    def test_initialization(self):
        # ... initialization tests

    def test_functionality(self):
        # ... functionality tests

    def test_exceptions(self):
        # ... exception tests
tests/unit/test_system.py (4)
15-48
: Add docstrings to test helper classes.
While the implementation is correct, adding docstrings would improve code maintainability by documenting:
- Purpose of each test helper class
- Expected shapes and formats of the data
- Relationship to the system under test
Example docstring for DummyDataset:
class DummyDataset(Dataset):
    """Test dataset that generates random image-like tensors and class labels.

    Returns batches in the format:
        {'input': torch.Tensor(3, 32, 32), 'target': torch.Tensor()}
    """
50-79
: Consider adding validation for dataset configuration.
The DummySystem initialization looks good, but consider adding validation to ensure:
- All required dataset splits are provided
- Metrics configuration matches the dataset splits
- Dataset sizes are appropriate for batch size
Example validation:
def __init__(self):
    # ... existing code ...
    required_splits = {'train', 'val', 'test'}
    if not all(split in datasets for split in required_splits):
        raise ValueError(f"Missing required dataset splits: {required_splits - set(datasets.keys())}")
    if not all(split in metrics for split in required_splits):
        raise ValueError(f"Missing metrics for splits: {required_splits - set(metrics.keys())}")
110-125
: Enhance step function tests with edge cases.
While the basic functionality is well tested, consider adding tests for:
- Empty batches
- Batches with different tensor types (float32, float64)
- GPU/CPU tensor handling
- Gradient calculation verification
Example edge case test:
def test_training_step_empty_batch(dummy_system):
    empty_batch = {"input": torch.empty(0, 3, 32, 32), "target": torch.empty(0)}
    with mock.patch.object(dummy_system, "log"):
        with pytest.raises(ValueError, match="Empty batch"):
            dummy_system.training_step(empty_batch, batch_idx=0)
Also applies to: 127-139, 141-152
162-188
: Add more batch format test cases.
The parameterized tests are good, but consider adding test cases for:
- Different input shapes
- Multi-target scenarios
- Missing optional keys
- Type mismatches between input and target
Example additional test cases:
@pytest.mark.parametrize(
    "batch",
    [
        # Different input shapes
        {"input": torch.randn(1, 1, 64, 64), "target": torch.randint(0, 10, size=(1,))},
        # Multi-target
        {
            "input": torch.randn(1, 3, 32, 32),
            "target": {"class": torch.randint(0, 10, size=(1,)), "bbox": torch.randn(1, 4)}
        },
    ]
)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (15)
.gitignore (1 hunks)
lighter/utils/runner.py (1 hunks)
tests/unit/test_callbacks_freezer.py (1 hunks)
tests/unit/test_callbacks_utils.py (1 hunks)
tests/unit/test_callbacks_writer_base.py (1 hunks)
tests/unit/test_callbacks_writer_file.py (1 hunks)
tests/unit/test_callbacks_writer_table.py (1 hunks)
tests/unit/test_system.py (1 hunks)
tests/unit/test_utils_collate.py (1 hunks)
tests/unit/test_utils_dynamic_imports.py (1 hunks)
tests/unit/test_utils_logging.py (1 hunks)
tests/unit/test_utils_misc.py (1 hunks)
tests/unit/test_utils_model.py (1 hunks)
tests/unit/test_utils_patches.py (1 hunks)
tests/unit/test_utils_runner.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- .gitignore
🧰 Additional context used
🪛 Ruff (0.8.0)
tests/unit/test_callbacks_writer_base.py
8-8: Local variable writer
is assigned to but never used
Remove assignment to unused variable writer
(F841)
tests/unit/test_utils_patches.py
1-1: torch.nn.Module
imported but unused
Remove unused import: torch.nn.Module
(F401)
🔇 Additional comments (16)
tests/unit/test_utils_misc.py (1)
1-2
: LGTM! Clean and specific import.
The import statement follows good practices by importing only the specific function being tested.
tests/unit/test_utils_logging.py (1)
1-6
: Verify logging configuration persistence
Let's verify that the logging configuration persists across the codebase and that there are no conflicting configurations.
✅ Verification successful
Let me check the implementation of the logging setup to ensure it's properly configured.
Logging configuration is properly centralized and safe
The logging setup is well-implemented with a single configuration point in lighter/utils/logging.py
that:
- Properly intercepts and redirects all Python logging to Loguru
- Is imported and initialized early in
lighter/__init__.py
- Has no competing or conflicting configurations across the codebase
- Uses a custom formatter and handler to ensure consistent logging behavior
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for multiple logging configuration points that might conflict
# Search for logging configuration in Python files
echo "Searching for logging configuration points..."
rg -l "logging\.(basicConfig|config)" --type py
# Search for logger instantiation
echo "\nSearching for logger instantiation points..."
rg "logging\.getLogger\(" --type py -A 2
# Search for custom logging setup
echo "\nSearching for custom logging setup..."
rg "setup_logging" --type py
Length of output: 919
Script:
#!/bin/bash
# Get the implementation of _setup_logging
ast-grep --pattern 'def _setup_logging():
$$$'
Length of output: 4259
tests/unit/test_utils_runner.py (2)
1-4
: LGTM! Clean and minimal imports.
The imports are appropriate and follow best practices by importing only what's needed for the test.
1-8
: Verify test implementation and coverage.
Let's ensure the test structure aligns with the implementation.
✅ Verification successful
Let me gather more information about the implementation of parse_config
to ensure proper test coverage.
Let me try one more verification to see the complete implementation and usage:
Test implementation aligns with the function's behavior
The test correctly verifies the error case when no config is provided to parse_config(). The implementation shows that parse_config() explicitly raises a ValueError when "config" is not in the kwargs, which is exactly what the test is checking. The test is also properly located in the unit test directory and is the only test file for the runner module.
- ✓ Test location is correct (tests/unit/test_utils_runner.py)
- ✓ Implementation matches test expectations
- ✓ No duplicate tests exist
⚠️ Consider adding more test cases for other scenarios (config validation, overrides)
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify test location, coverage, and implementation
# Check if there are other test files for runner.py
echo "Checking for other runner tests..."
fd test_.*runner\.py tests/
# Verify the parse_config implementation
echo "Checking parse_config implementation..."
ast-grep --pattern 'def parse_config($$$)'
# Check for similar existing tests
echo "Checking for similar tests..."
rg -A 5 'test_parse_config' tests/
Length of output: 569
Script:
#!/bin/bash
# Let's check the actual implementation of parse_config
echo "Checking parse_config implementation in lighter/utils/runner.py..."
rg -A 10 "def parse_config" lighter/utils/runner.py
# Also check if there are any other related functions or usages
echo "Checking related functions and usages..."
rg "parse_config" --type python
Length of output: 756
Script:
#!/bin/bash
# Check the complete implementation file
echo "Checking complete runner.py implementation..."
cat lighter/utils/runner.py
# Check for imports and usage without file type restriction
echo "Checking all parse_config references..."
rg "parse_config"
Length of output: 4592
tests/unit/test_callbacks_writer_base.py (1)
6-8
: Enhance test coverage with docstring and positive cases.
While testing the error case is important, consider:
- Adding a docstring to describe the test's purpose
- Including positive test cases to verify successful initialization scenarios
def test_writer_initialization():
+ """
+ Test LighterBaseWriter initialization:
+ - Raises TypeError for invalid parameters
+ - Successfully initializes with valid parameters
+ """
with pytest.raises(TypeError):
- writer = LighterBaseWriter(path="test", writer="tensor")
+ LighterBaseWriter(path="test", writer="tensor")
+
+ # Test successful initialization
+ writer = LighterBaseWriter(path="test", writer="valid_type") # Replace with actual valid type
+ assert isinstance(writer, LighterBaseWriter)
🧰 Tools
🪛 Ruff (0.8.0)
8-8: Local variable writer
is assigned to but never used
Remove assignment to unused variable writer
(F841)
tests/unit/test_utils_dynamic_imports.py (1)
6-8
: LGTM! The error handling test is well-implemented.
The test correctly verifies that attempting to import a non-existent module raises a FileNotFoundError.
tests/unit/test_utils_collate.py (1)
1-2
: Verify import usage across codebase
Let's verify that the imported collate_replace_corrupted
function exists and is properly exported.
✅ Verification successful
Import verified and properly implemented
The collate_replace_corrupted function is correctly implemented in lighter/utils/collate.py and is being properly imported in both the test file and lighter/system.py. The function exists, is exported, and is being used appropriately across the codebase.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify the existence and export of collate_replace_corrupted function
# Check if the function exists in the source file
ast-grep --pattern 'def collate_replace_corrupted($$$)'
# Check if the function is properly exported
rg -l "^from.*collate_replace_corrupted import" || rg -l "^from.*collate import.*collate_replace_corrupted"
Length of output: 217
Script:
#!/bin/bash
# Let's verify the actual implementation and location of the function
rg -p "def collate_replace_corrupted" -g '*.py'
# Also check the content of the collate module to ensure proper export
fd collate.py --exec cat {}
Length of output: 3027
tests/unit/test_callbacks_writer_file.py (1)
1-7
: LGTM! Clean and minimal imports.
The imports are well-organized and include only the necessary dependencies.
tests/unit/test_callbacks_utils.py (2)
1-4
: LGTM! Imports are clean and minimal.
The imports are well-organized and include only the necessary dependencies.
12-19
: 🛠️ Refactor suggestion
Verify the expected output shape transformation for 3D images.
The expected output shape (1, depth*height, batch_size*width)
seems unusual. This transformation:
- Combines spatial dimensions (depth*height)
- Multiplies width by batch_size
- Sets channels to 1
Please verify if this is the intended behavior of preprocess_image.
Additionally, consider these improvements:
- Add docstring explaining the transformation logic
- Verify pixel values are preserved correctly
- Add error handling tests (e.g., invalid dimensions)
Here's a suggested enhancement:
def test_preprocess_image_3d():
+ """Test 3D image preprocessing.
+
+ Verifies that preprocess_image correctly handles 3D images by:
+ 1. Reshaping spatial dimensions as per model requirements
+ 2. Preserving total pixel count and values
+ 3. Handling batch dimension appropriately
+ """
batch_size = 8
depth = 20
height = 64
width = 64
image = torch.rand(batch_size, 1, depth, height, width) # Batch of 3D images
processed_image = preprocess_image(image)
assert processed_image.shape == (1, depth * height, batch_size * width)
+
+ # Verify total element count is preserved
+ assert image.numel() == processed_image.numel()
+
+ # Test error handling
+ with pytest.raises(ValueError, match="Invalid dimensions"):
+ invalid_image = torch.rand(8, 1, 20) # Missing spatial dimensions
+ preprocess_image(invalid_image)
✅ Verification successful
Let me search for more context about the preprocess_image function implementation.
The test correctly verifies the intended shape transformation
The test is correct and matches the implementation's documented behavior. From the source code in lighter/callbacks/utils.py, we can see that:
- The function is specifically designed to handle 3D images by combining depth and height dimensions vertically (shape[2] * shape[3]) and concatenating batch images horizontally (B*W), which matches exactly the test's expected shape (1, depth*height, batch_size*width).
- The transformation is intentional for visualization purposes, as documented: "In case of 3D, a single image is displayed with slices stacked vertically, while a batch of 3D images as a grid where each column is a different 3D image"
The test's assertions correctly verify this visualization-oriented transformation.
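For intuition, the shape arithmetic the test asserts can be reproduced with plain tensor operations. This is only a sketch of the described layout (slices stacked vertically, volumes placed side by side), not the library's actual implementation:

import torch

B, D, H, W = 8, 20, 64, 64
image = torch.rand(B, 1, D, H, W)

# Stack each volume's D slices vertically (D*H rows), then place the B volumes side by side (B*W columns).
columns = [volume.squeeze(0).reshape(D * H, W) for volume in image]
grid = torch.cat(columns, dim=1).unsqueeze(0)

assert grid.shape == (1, D * H, B * W)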
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for the preprocess_image implementation and its usage
ast-grep --pattern 'def preprocess_image($$$)' -A 10
rg -A 5 'preprocess_image.*3[dD]'
Length of output: 479
Script:
#!/bin/bash
# Search for preprocess_image implementation in the codebase
rg -l "def preprocess_image"
# Once we find the file, let's look at its content
rg -B 5 -A 15 "def preprocess_image"
# Let's also search for any related documentation or comments
rg -i "preprocess.*image"
Length of output: 2409
tests/unit/test_utils_model.py (1)
1-4
: LGTM! Imports are well-organized.
The imports are appropriately separated between external (PyTorch) and internal (lighter) dependencies, and all imported components are utilized in the tests.
tests/unit/test_callbacks_freezer.py (2)
1-9
: LGTM! Imports are well-organized and complete.
42-65
: 🛠️ Refactor suggestion
Enhance test coverage with docstrings and edge cases.
The test suite needs improvement in several areas:
The test suite needs improvement in several areas:
- Missing docstrings for all test functions
- No negative test cases
- Limited testing of error conditions
- Single epoch might not reveal training-related issues
Here's how to enhance the test coverage:
+def test_freezer_invalid_initialization():
+ """Test that LighterFreezer raises ValueError for invalid layer names."""
+ with pytest.raises(ValueError):
+ LighterFreezer(names=["non_existent_layer"])
def test_freezer_initialization():
+ """Test that LighterFreezer correctly initializes with valid layer names."""
freezer = LighterFreezer(names=["layer1"])
assert freezer.names == ["layer1"]
+ assert isinstance(freezer.names, list)
def test_freezer_functionality(dummy_system):
+ """Test that LighterFreezer correctly freezes specified layers during training."""
freezer = LighterFreezer(names=["layer1.weight", "layer1.bias"])
- trainer = Trainer(callbacks=[freezer], max_epochs=1)
+ trainer = Trainer(callbacks=[freezer], max_epochs=3) # More epochs for stability
trainer.fit(dummy_system)
+
+ # Test initial parameter states
assert not dummy_system.model.layer1.weight.requires_grad
assert not dummy_system.model.layer1.bias.requires_grad
assert dummy_system.model.layer2.weight.requires_grad
+
+ # Store and verify frozen parameters don't change
+ initial_weights = dummy_system.model.layer1.weight.clone()
+ trainer.fit(dummy_system)
+ assert torch.allclose(initial_weights, dummy_system.model.layer1.weight)
def test_freezer_with_exceptions(dummy_system):
+ """Test that LighterFreezer correctly handles exception patterns during training."""
freezer = LighterFreezer(name_starts_with=["layer"], except_names=["layer2.weight", "layer2.bias"])
- trainer = Trainer(callbacks=[freezer], max_epochs=1)
+ trainer = Trainer(callbacks=[freezer], max_epochs=3)
trainer.fit(dummy_system)
+
+ # Verify frozen state of all layers
assert not dummy_system.model.layer1.weight.requires_grad
assert not dummy_system.model.layer1.bias.requires_grad
assert dummy_system.model.layer2.weight.requires_grad
assert dummy_system.model.layer2.bias.requires_grad
assert not dummy_system.model.layer3.weight.requires_grad
assert not dummy_system.model.layer3.bias.requires_grad
+
+ # Verify training affects only unfrozen layers
+ initial_l1_weights = dummy_system.model.layer1.weight.clone()
+ initial_l2_weights = dummy_system.model.layer2.weight.clone()
+ trainer.fit(dummy_system)
+ assert torch.allclose(initial_l1_weights, dummy_system.model.layer1.weight)
+ assert not torch.allclose(initial_l2_weights, dummy_system.model.layer2.weight)
Let's verify the test coverage:
lighter/utils/runner.py (1)
39-39
: LGTM! Consider adding version compatibility handling.
The change from .dict() to .model_dump() appears to be an update for compatibility with newer versions of Pydantic (v2.x). While the change is correct, we should ensure compatibility across different Pydantic versions.
Let's verify the Pydantic version and its compatibility:
Consider adding version compatibility handling:
try:
parser = ConfigParser(ConfigSchema().model_dump(), globals=False)
except AttributeError: # Fallback for older Pydantic versions
parser = ConfigParser(ConfigSchema().dict(), globals=False)
✅ Verification successful
Change is correct and consistent with dependencies
The change from .dict() to .model_dump() is appropriate as the project explicitly depends on Pydantic v2 (pydantic = "^2.8.2" in pyproject.toml). The codebase consistently uses model_dump() across multiple files, and the schema implementation in schema.py follows Pydantic v2 patterns (e.g., using model_validator, model_fields, etc.). No compatibility layer is needed since the project has standardized on Pydantic v2.
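As a quick illustration of the v2 API in isolation, the following is a minimal sketch with a stand-in model, not lighter's actual ConfigSchema:

from pydantic import BaseModel


class ExampleSchema(BaseModel):  # hypothetical stand-in for ConfigSchema
    project: str = "lighter"
    max_epochs: int = 1


# Pydantic v2: model_dump() is the supported replacement for the deprecated .dict().
assert ExampleSchema().model_dump() == {"project": "lighter", "max_epochs": 1}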
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check Pydantic version requirements and usage
# Check for Pydantic version in requirements
rg -l "pydantic[>=<~]"
# Check for other potential .dict() usage that might need updates
rg "\.dict\(\)"
# Check for existing version handling
rg "pydantic.*version"
Length of output: 67
Script:
#!/bin/bash
# Let's try a different approach to check dependencies and usage
# Check for pyproject.toml, setup.py, or requirements files
fd "pyproject.toml|setup.py|requirements.*\.txt"
# Check for pydantic imports and usage
rg -l "from pydantic|import pydantic"
# Check for model_dump usage to see if other files were updated
rg "model_dump"
# Look for ConfigSchema usage across the codebase
rg "ConfigSchema"
Length of output: 1679
Script:
#!/bin/bash
# Let's check the pyproject.toml for pydantic version and examine schema.py for BaseModel usage
# Check pyproject.toml content for dependencies
cat pyproject.toml
# Check the schema implementation
cat lighter/utils/schema.py
Length of output: 8683
tests/unit/test_system.py (2)
1-12
: LGTM! Well-organized imports.
The imports are properly organized following Python's import style guidelines, with standard library imports first, followed by third-party and local imports.
1-188
: Overall well-structured test suite with good coverage.
The test implementation demonstrates good practices:
- Proper use of pytest fixtures and parameterization
- Clear separation of test helper classes
- Comprehensive coverage of core functionality
Consider the suggested improvements for documentation and edge cases, but the current implementation is solid and ready for use.
def test_ensure_list():
    assert ensure_list(1) == [1]
    assert ensure_list([1, 2]) == [1, 2]
🛠️ Refactor suggestion
The current test cases look good, but coverage could be improved.
While the existing test cases correctly verify basic functionality, consider adding more test cases to ensure comprehensive coverage:
While the existing test cases correctly verify basic functionality, consider adding more test cases to ensure comprehensive coverage:
def test_ensure_list():
+ """Test ensure_list utility function handles various input types correctly."""
assert ensure_list(1) == [1]
assert ensure_list([1, 2]) == [1, 2]
+ # Test None input
+ assert ensure_list(None) == [], "Should handle None input"
+ # Test empty list
+ assert ensure_list([]) == [], "Should preserve empty list"
+ # Test string input
+ assert ensure_list("test") == ["test"], "Should convert string to list"
+ # Test tuple input
+ assert ensure_list((1, 2)) == [1, 2], "Should convert tuple to list"
+ # Test nested structures
+ assert ensure_list([[1, 2]]) == [[1, 2]], "Should preserve nested lists"
Also consider:
- Adding descriptive messages to assertions for better error reporting
- Using pytest's parametrize decorator for cleaner test organization
- Adding docstring to explain the test's purpose
Committable suggestion skipped: line range outside the PR's diff.
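Expanding on the parametrize suggestion above, a sketch could look like the following. It assumes ensure_list behaves as described in the verification further down in this thread (lists pass through, tuples convert, None becomes an empty list, other values are wrapped):

import pytest

from lighter.utils.misc import ensure_list


@pytest.mark.parametrize(
    "value, expected",
    [
        (1, [1]),
        ([1, 2], [1, 2]),
        ((1, 2), [1, 2]),
        (None, []),
        ("test", ["test"]),
    ],
)
def test_ensure_list_parametrized(value, expected):
    assert ensure_list(value) == expected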
from lighter.utils.misc import ensure_list


def test_ensure_list():
    assert ensure_list(1) == [1]
    assert ensure_list([1, 2]) == [1, 2]
💡 Codebase verification
Test cases need to be expanded to cover all implementation cases
The current test cases only cover two scenarios (ensure_list(1) and ensure_list([1, 2])), but the implementation handles four distinct cases:
- List input (returns as-is)
- Tuple input (converts to list)
- None input (returns empty list)
- Other types (wraps in list)
Missing test cases (a minimal sketch follows below):
- ensure_list(None) should return []
- ensure_list((1, 2)) should return [1, 2]
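A minimal addition covering just those two cases might look like this, assuming the four-case behavior summarized above:

from lighter.utils.misc import ensure_list


def test_ensure_list_none_and_tuple():
    assert ensure_list(None) == []
    assert ensure_list((1, 2)) == [1, 2]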
🔗 Analysis chain
Verify ensure_list implementation details
Let's verify the actual implementation to ensure our suggested test cases align with the intended behavior.
Let me try another approach to find the implementation details.
Let me get the complete implementation to verify the test cases.
Let me get a few more lines of the implementation as it seems truncated.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Find and display the implementation of ensure_list
echo "Looking for ensure_list implementation:"
ast-grep --pattern 'def ensure_list($_):
$$$
'
# Find any other existing tests for ensure_list
echo -e "\nLooking for other existing tests:"
rg -l "test.*ensure_list" tests/
Length of output: 287
Script:
#!/bin/bash
# Search for ensure_list implementation using ripgrep for better context
echo "Looking for ensure_list implementation in source files:"
rg "def ensure_list" -A 5 lighter/
Length of output: 480
Script:
#!/bin/bash
# Get the complete implementation of ensure_list
echo "Complete ensure_list implementation:"
rg "def ensure_list" -A 10 lighter/utils/misc.py
Length of output: 419
Script:
#!/bin/bash
# Get the complete implementation with more context
echo "Complete ensure_list implementation:"
rg "def ensure_list" -A 15 lighter/utils/misc.py
Length of output: 538
def test_replace_layer_with():
    model = DummyModel()
    new_layer = Linear(10, 4)
    replace_layer_with(model, "layer1", new_layer)
    assert model.layer1 == new_layer


def test_replace_layer_with_identity():
    model = DummyModel()
    replace_layer_with_identity(model, "layer1")
    assert isinstance(model.layer1, torch.nn.Identity)
🛠️ Refactor suggestion
Add test cases for error conditions and shape validation.
The current tests verify successful operations but don't cover error cases. Consider adding tests for:
- Invalid layer names in
replace_layer_with
- Shape mismatches when replacing layers
- Input/output tensor shape validation
Example test case to add:
def test_replace_layer_with_invalid_name():
model = DummyModel()
new_layer = Linear(10, 4)
with pytest.raises(AttributeError):
replace_layer_with(model, "non_existent_layer", new_layer)
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (6)
tests/unit/test_callbacks_writer_table.py (2)
13-37
: Add more comprehensive write test scenarios
While the current test covers basic functionality and edge cases, consider adding:
- CSV header verification
- Test for concurrent writes
- Error cases (file permission issues, disk full scenarios)
- Verification of CSV format compliance
 def test_table_writer_write():
    """Test LighterTableWriter write functionality with various inputs."""
    test_file = Path("test.csv")
    writer = LighterTableWriter(path="test.csv", writer="tensor")

    # Test basic write
    test_tensor = torch.tensor([1, 2, 3])
    writer.write(tensor=test_tensor, id=1)
    assert len(writer.csv_records) == 1
    assert writer.csv_records[0]["tensor"] == test_tensor.tolist()
    assert writer.csv_records[0]["id"] == 1

    # Test edge cases
    writer.write(tensor=torch.tensor([]), id=2)  # empty tensor
    writer.write(tensor=torch.randn(1000), id=3)  # large tensor
    writer.write(tensor=torch.tensor([1.5, 2.5]), id=4)  # float tensor

    # Verify file creation and content
    assert test_file.exists()
    with open(test_file) as f:
+        # Verify CSV header
+        header = f.readline().strip()
+        assert header == "id,tensor"
+
+        # Verify content format
        content = f.read()
        assert "1,2,3" in content  # verify first tensor
+        # Verify CSV format compliance
+        import csv
+        f.seek(0)
+        reader = csv.DictReader(f)
+        rows = list(reader)
+        assert len(rows) == 4  # All records present

+    # Test concurrent writes
+    import threading
+    def concurrent_write():
+        writer.write(tensor=torch.tensor([5, 6, 7]), id=5)
+    thread = threading.Thread(target=concurrent_write)
+    thread.start()
+    thread.join()
+    assert len(writer.csv_records) == 5
+
+    # Test error cases
+    import os
+    # Test file permission error
+    test_file.chmod(0o000)  # Remove all permissions
+    with pytest.raises(PermissionError):
+        writer.write(tensor=torch.tensor([8, 9, 10]), id=6)
+    test_file.chmod(0o666)  # Restore permissions

    # Cleanup
    test_file.unlink()
1-37
: Consider architectural improvements for better test organization
To improve test maintainability and reduce duplication:
- Use pytest fixtures for common setup (writer initialization, test file creation)
- Mock file operations to avoid actual file I/O in tests
- Create test data fixtures for different tensor scenarios
Example fixture structure:
import pytest
from unittest.mock import patch, mock_open


@pytest.fixture
def table_writer():
    writer = LighterTableWriter(path="test.csv", writer="tensor")
    yield writer


@pytest.fixture
def test_tensors():
    return {
        'basic': torch.tensor([1, 2, 3]),
        'empty': torch.tensor([]),
        'large': torch.randn(1000),
        'float': torch.tensor([1.5, 2.5])
    }


# Then use in tests:
def test_table_writer_write(table_writer, test_tensors):
    with patch('builtins.open', mock_open()) as mock_file:
        table_writer.write(tensor=test_tensors['basic'], id=1)
        # Test assertions...
tests/unit/test_utils_collate.py (4)
4-12
: Consider enhancing the docstring with function behavior details.
While the docstring effectively lists what is being tested, it could be more helpful to include:
- Brief description of what
collate_replace_corrupted does
- Expected behavior for each test case
 def test_collate_replace_corrupted():
    """Test collate_replace_corrupted function handles corrupted data correctly.

+    The collate_replace_corrupted function replaces None values in a batch
+    with random values from a provided dataset while preserving non-None values.
+
    Tests:
-    - Corrupted values (None) are replaced with valid dataset values
-    - Non-corrupted values remain unchanged
-    - Output maintains correct length
-    - Edge cases: empty batch, all corrupted values
+    - Corrupted values (None) are replaced with valid dataset values -> Should pick random values from dataset
+    - Non-corrupted values remain unchanged -> Original values should be preserved
+    - Output maintains correct length -> Output batch should match input batch length
+    - Edge cases:
+        - empty batch -> Should return empty list
+        - all corrupted values -> Should replace all with valid dataset values
    """
13-26
: Consider using test fixtures or setup data.
The test data could be better organized using test fixtures or setup variables at the module level, especially if these test cases will be reused in other test functions.
+# At module level
+SAMPLE_BATCH = [1, None, 2, None, 3]
+SAMPLE_DATASET = [1, 2, 3, 4, 5]
+
 def test_collate_replace_corrupted():
    """..."""
-    batch = [1, None, 2, None, 3]
-    dataset = [1, 2, 3, 4, 5]
+    batch = SAMPLE_BATCH
+    dataset = SAMPLE_DATASET
28-35
: Consider adding an empty dataset edge case.
The edge cases are well covered, but it might be valuable to test the behavior when the dataset is empty, as this could lead to potential issues.
 # Test edge cases
 empty_batch = []
 assert len(collate_replace_corrupted(empty_batch, dataset)) == 0

+# Test empty dataset
+empty_dataset = []
+with pytest.raises(ValueError, match="Dataset cannot be empty"):
+    collate_replace_corrupted(batch, empty_dataset)
+
 all_corrupted = [None, None, None]
 collated_all_corrupted = collate_replace_corrupted(all_corrupted, dataset)
 assert len(collated_all_corrupted) == len(all_corrupted)
 assert all(val in dataset for val in collated_all_corrupted)
Note: This assumes the function should raise a ValueError for empty datasets. If the behavior should be different, adjust the test accordingly.
1-1
: Add pytest import for the suggested test.
If you implement the empty dataset test, you'll need to add:
+import pytest
 from lighter.utils.collate import collate_replace_corrupted
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (4)
tests/unit/test_callbacks_writer_base.py (1 hunks)
tests/unit/test_callbacks_writer_file.py (1 hunks)
tests/unit/test_callbacks_writer_table.py (1 hunks)
tests/unit/test_utils_collate.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/unit/test_callbacks_writer_base.py
- tests/unit/test_callbacks_writer_file.py
🔇 Additional comments (3)
tests/unit/test_callbacks_writer_table.py (2)
1-7
: LGTM!
The imports are well-organized and include all necessary dependencies.
8-11
: 🛠️ Refactor suggestion
Enhance test coverage for initialization
The test needs improvement in several areas:
- Add docstring explaining test purpose
- Verify the
writer
attribute - Add negative test cases (invalid writer types, empty paths)
- Add type hints for better code documentation
+from typing import Any
+
def test_table_writer_initialization():
+ """Test LighterTableWriter initialization with valid and invalid inputs.
+
+ Tests:
+ - Valid initialization with correct path and writer
+ - Invalid writer type raises ValueError
+ - Empty path raises ValueError
+ """
writer = LighterTableWriter(path="test.csv", writer="tensor")
assert writer.path == Path("test.csv")
+ assert writer.writer == "tensor"
+
+ # Test invalid writer type
+ with pytest.raises(ValueError, match="Invalid writer type"):
+ LighterTableWriter(path="test.csv", writer="invalid")
+
+ # Test empty path
+ with pytest.raises(ValueError, match="Path cannot be empty"):
+ LighterTableWriter(path="", writer="tensor")
tests/unit/test_utils_collate.py (1)
1-3
: LGTM! Clean and focused import.
The import statement is properly focused on the specific function being tested.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (5)
tests/unit/test_utils_collate.py (2)
1-4
: Add a comment explaining the torch dependency
While the torch import is necessary (the collated batch contains tensor items), its usage isn't immediately obvious. Consider adding a comment to explain why it's needed.
+# torch is required as collate_replace_corrupted returns tensor items
 import torch
🧰 Tools
🪛 Ruff (0.8.0)
1-1: torch imported but unused
Remove unused import: torch
(F401)
15-37
: Enhance test implementation
While the test coverage is good, there are a few improvements to consider:
- The empty batch test case mentioned in the docstring is missing
- The filtered batch comparison could be more robust using sets
- Consider using parametrized tests for better maintainability
Here's a suggested implementation:
+import pytest
+
+@pytest.mark.parametrize("batch, dataset, expected_length", [
+    ([1, None, 2, None, 3], [1, 2, 3, 4, 5], 5),  # Normal case
+    ([None, None, None], [1, 2, 3, 4, 5], 3),  # All corrupted
+    ([], [1, 2, 3, 4, 5], 0),  # Empty batch
+])
 def test_collate_replace_corrupted():
    """Test collate_replace_corrupted function handles corrupted data correctly.
    ...
    """
-    batch = [1, None, 2, None, 3]
-    dataset = [1, 2, 3, 4, 5]
    collated_batch = collate_replace_corrupted(batch, dataset)

    # Test length
-    assert len(collated_batch) == len(batch)
+    assert len(collated_batch) == expected_length

    # Test non-corrupted values remain unchanged.
    filtered_batch = list(filter(lambda x: x is not None, batch))
-    assert collated_batch[0].item() == filtered_batch[0]
-    assert collated_batch[1].item() == filtered_batch[1]
-    assert collated_batch[2].item() == filtered_batch[2]
+    if filtered_batch:  # Skip for all-corrupted or empty batch
+        filtered_values = set(val.item() for val in collated_batch[:len(filtered_batch)])
+        assert filtered_values == set(filtered_batch)

    # Test corrupted values are replaced with valid dataset values
-    assert collated_batch[3].item() in dataset
-    assert collated_batch[4].item() in dataset
+    if len(collated_batch) > len(filtered_batch):
+        replaced_values = [val.item() for val in collated_batch[len(filtered_batch):]]
+        assert all(val in dataset for val in replaced_values)

-    all_corrupted = [None, None, None]
-    collated_all_corrupted = collate_replace_corrupted(all_corrupted, dataset)
-    assert len(collated_all_corrupted) == len(all_corrupted)
-    assert all(val in dataset for val in collated_all_corrupted)
lighter/callbacks/writer/table.py (1)
Line range hint
41-42
: Add type validation in the write method
To prevent sorting issues downstream, add type validation for the id parameter in the write method.
 def write(self, tensor: Any, id: Union[int, str]) -> None:
+    if not isinstance(id, (int, str)):
+        raise TypeError(f"'id' must be int or str, got {type(id)}")
    self.csv_records.append({"id": id, "pred": self.writer(tensor)})
🧰 Tools
🪛 Ruff (0.8.0)
63-66: Use contextlib.suppress(TypeError) instead of try-except-pass
Replace with contextlib.suppress(TypeError)
(SIM105)
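For reference, the pattern SIM105 points to looks roughly like this; the data is illustrative and this is not the actual table.py code:

import contextlib

records = [{"id": 3}, {"id": "1"}, {"id": 2}]

# Equivalent to try: records.sort(...) except TypeError: pass, but more explicit.
with contextlib.suppress(TypeError):
    records.sort(key=lambda record: record["id"])  # mixed int/str ids raise TypeError, which is silently ignored here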
tests/unit/test_callbacks_writer_table.py (2)
10-10
: Remove unused import
The LighterSystem import is not used in any of the tests.
-from lighter.system import LighterSystem
🧰 Tools
🪛 Ruff (0.8.0)
10-10: lighter.system.LighterSystem imported but unused
Remove unused import: lighter.system.LighterSystem
(F401)
67-109
: Consider enhancing multi-process test documentation and coverage
The test effectively validates the multi-process writing scenario with good mocking and assertions. Consider adding:
- Docstring explaining the test's purpose and setup
- Error cases (e.g., gather failures, rank mismatches)
 def test_table_writer_write_multi_process(tmp_path, monkeypatch):
+    """Test LighterTableWriter's behavior in a distributed training setup.
+
+    Tests:
+    - Record gathering from multiple processes
+    - Correct CSV file creation with combined records
+    - Proper handling of distributed environment
+    """
    test_file = tmp_path / "test.csv"
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (4)
lighter/callbacks/writer/table.py (1 hunks)
tests/unit/test_callbacks_writer_file.py (1 hunks)
tests/unit/test_callbacks_writer_table.py (1 hunks)
tests/unit/test_utils_collate.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/unit/test_callbacks_writer_file.py
🧰 Additional context used
🪛 Ruff (0.8.0)
lighter/callbacks/writer/table.py
63-66: Use contextlib.suppress(TypeError) instead of try-except-pass
Replace with contextlib.suppress(TypeError)
(SIM105)
tests/unit/test_callbacks_writer_table.py
5-5: pytest
imported but unused
Remove unused import: pytest
(F401)
10-10: lighter.system.LighterSystem
imported but unused
Remove unused import: lighter.system.LighterSystem
(F401)
tests/unit/test_utils_collate.py
1-1: torch
imported but unused
Remove unused import: torch
(F401)
🔇 Additional comments (5)
tests/unit/test_utils_collate.py (1)
6-14
: LGTM! Well-documented test function
The docstring clearly describes the test's purpose and comprehensively lists all test cases.
tests/unit/test_callbacks_writer_table.py (4)
13-14
: LGTM!
The helper function is well-defined and correctly implements a custom writer for testing purposes.
17-19
: Enhance initialization test coverage
The test only verifies the path attribute. Consider adding validation of the writer attribute and negative test cases.
22-26
: LGTM!
The test effectively validates the custom writer functionality with proper assertions.
29-64
: LGTM!
The test is well-structured with:
- Clear test cases including edge cases
- Comprehensive data validation
- Proper file cleanup
Description
Mostly generated with Aider
The unit test filenames follow the module paths from lighter, e.g. lighter/callbacks/writer_file.py will be test_callbacks_writer_file.py.
Related Issue
Type of Change
Checklist
CODE_OF_CONDUCT.md document.
CONTRIBUTING.md guide.
make codestyle.
Summary by CodeRabbit
Bug Fixes
- LighterTableWriter: prevent crashes during DataFrame sorting.
New Features
- New unit tests for LighterFreezer, LighterFileWriter, LighterTableWriter, and LighterSystem.
- New tests for preprocess_image, collate_replace_corrupted, and ensure_list.
Tests
Chores
- Updated .gitignore to streamline version control by ignoring unnecessary files and directories.