Add unit tests for all modules #138

Merged
merged 51 commits on Nov 30, 2024

Conversation

@ibro45 (Collaborator) commented Nov 29, 2024

Description

Mostly generated with Aider

The unit test filenames follow the module paths in lighter, e.g. lighter/callbacks/writer_file.py is tested in test_callbacks_writer_file.py.
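
For example, with module and test paths taken from files touched in this PR:

lighter/utils/runner.py            -> tests/unit/test_utils_runner.py
lighter/utils/misc.py              -> tests/unit/test_utils_misc.py
lighter/callbacks/writer/table.py  -> tests/unit/test_callbacks_writer_table.py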

Related Issue

Type of Change

  • 📚 Examples / docs / tutorials / dependencies update
  • 🔧 Bug fix (non-breaking change which fixes an issue)
  • 🥂 Improvement (non-breaking change which improves an existing feature)
  • 🚀 New feature (non-breaking change which adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to change)
  • 🔐 Security fix

Checklist

  • I've read the CODE_OF_CONDUCT.md document.
  • I've read the CONTRIBUTING.md guide.
  • I've updated the code style using make codestyle.
  • I've written tests for all new methods and classes that I created.
  • I've written the docstring in Google format for all the methods and classes that I used.

Summary by CodeRabbit

  • Bug Fixes

    • Enhanced error handling in the LighterTableWriter to prevent crashes during DataFrame sorting.
  • New Features

    • Introduced comprehensive unit tests for various components, including LighterFreezer, LighterFileWriter, LighterTableWriter, and LighterSystem.
    • Added tests for utility functions like preprocess_image, collate_replace_corrupted, and ensure_list.
  • Tests

    • Expanded test coverage across multiple modules to ensure functionality and error handling.
    • New tests added for configuration parsing and handling of corrupted data.
  • Chores

    • Updated .gitignore to streamline version control by ignoring unnecessary files and directories.

coderabbitai bot (Contributor) commented Nov 29, 2024

Walkthrough

The pull request introduces several changes across multiple files. The .gitignore file is updated to ignore specific files and directories. The lighter/utils/runner.py file modifies the parse_config function to change how the configuration schema is initialized. A series of new tests are added across various test files, focusing on components like the LighterFreezer, LighterFileWriter, LighterTableWriter, and LighterSystem, among others. These tests cover initialization, functionality, and error handling for different components and utility functions.

Changes

File Change Summary
.gitignore Added patterns to ignore files starting with "aider" and the test_dir/ directory.
lighter/utils/runner.py Updated parse_config to use ConfigSchema().model_dump() instead of ConfigSchema().dict().
tests/unit/test_callbacks_freezer.py Introduced tests for LighterFreezer, including dummy model and dataset classes, and three test functions.
tests/unit/test_callbacks_utils.py Added tests for preprocess_image function handling 2D and 3D images.
tests/unit/test_callbacks_writer_base.py Introduced test_writer_initialization to check LighterBaseWriter behavior during initialization.
tests/unit/test_callbacks_writer_file.py Added tests for LighterFileWriter, including initialization and writing tensor functionality.
tests/unit/test_callbacks_writer_table.py Introduced tests for LighterTableWriter, verifying initialization and writing capabilities.
tests/unit/test_system.py Comprehensive tests for LighterSystem, including dummy classes and various test functions for training and validation.
tests/unit/test_utils_collate.py Added test_collate_replace_corrupted to check handling of corrupted data in collate_replace_corrupted.
tests/unit/test_utils_dynamic_imports.py Introduced test_import_module_from_path to verify error handling for non-existent modules.
tests/unit/test_utils_logging.py Added test_setup_logging to ensure logging setup executes without exceptions.
tests/unit/test_utils_misc.py Introduced test_ensure_list to verify behavior of ensure_list utility function.
tests/unit/test_utils_model.py Added tests for modifying neural network layers, including DummyModel and three utility functions.
tests/unit/test_utils_patches.py Introduced test_patched_module_dict_handles_reserved_names to validate behavior of PatchedModuleDict.
tests/unit/test_utils_runner.py Added test_parse_config_no_config to check error handling in parse_config when no configuration is provided.

Possibly related PRs

  • Fix #130: that PR renamed the LighterFileWriter parameter from directory to path; it is related to this PR's .gitignore update, which now ignores the test_dir/ directory.
  • Switch from manual to pydantic validation of config and LighterSystem #135: that PR also modified the parse_config function in lighter/utils/runner.py, which this PR updates as well.

Suggested reviewers

  • surajpaib

🐇 In the meadow, we hop and play,
New tests and changes brighten the day.
With configs parsed and writers too,
Our code is fresh, and so much to do!
Let's celebrate this joyful spree,
For every line brings harmony! 🐇✨



@ibro45 marked this pull request as ready for review on November 30, 2024, 00:02
@dosubot added the size:L label (this PR changes 100-499 lines, ignoring generated files) on Nov 30, 2024
@ibro45 changed the title from "Unit tests" to "Add unit tests for all modules" on Nov 30, 2024
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 11

🧹 Outside diff range and nitpick comments (23)
tests/unit/test_utils_logging.py (2)

1-1: Consider testing through public interfaces instead of private functions

The test imports _setup_logging, which is marked as private (leading underscore). Testing private implementation details directly can make refactoring harder; consider testing through public interfaces instead.


4-6: Add docstring to document test purpose

The test function lacks documentation explaining its purpose and expected behavior. Add a docstring to improve maintainability.

tests/unit/test_utils_runner.py (2)

6-8: Good test! Consider adding more test cases.

The test correctly verifies the error case when no config is provided. However, we could enhance the test coverage.

Consider adding these additional test cases:

  1. Test with empty config
  2. Test with invalid config
  3. Test with minimal valid config

Here's a suggested implementation:

def test_parse_config_empty_config():
    with pytest.raises(ValueError):
        parse_config({})

def test_parse_config_invalid_config():
    with pytest.raises(ValueError):
        parse_config({"invalid_key": "value"})

def test_parse_config_minimal_valid():
    config = {
        "name": "test",
        "version": "1.0"
    }
    result = parse_config(config)
    assert result.name == "test"
    assert result.version == "1.0"

5-5: Remove extra blank line.

There's an extra blank line that can be removed to maintain consistent spacing.

-

def test_parse_config_no_config():
tests/unit/test_utils_dynamic_imports.py (1)

6-8: Consider adding more test cases for comprehensive coverage.

While the error handling test is good, consider adding:

  1. Positive test case with a valid module
  2. Edge cases testing different path formats
  3. Test case with an existing path but invalid module
  4. Test case with a module containing syntax errors

Here's a suggested implementation:

import os
import tempfile
import textwrap

def test_import_module_from_path():
    # Test non-existent module (current test)
    with pytest.raises(FileNotFoundError):
        import_module_from_path("non_existent_module", "non_existent_path")
    
    # Test valid module
    with tempfile.NamedTemporaryFile(suffix='.py', mode='w', delete=False) as f:
        f.write(textwrap.dedent('''
            def sample_function():
                return 42
        '''))
    
    try:
        module = import_module_from_path("sample_module", f.name)
        assert module.sample_function() == 42
    finally:
        os.unlink(f.name)
    
    # Test invalid module content
    with tempfile.NamedTemporaryFile(suffix='.py', mode='w', delete=False) as f:
        f.write("invalid python syntax :")
    
    try:
        with pytest.raises(SyntaxError):
            import_module_from_path("invalid_module", f.name)
    finally:
        os.unlink(f.name)
tests/unit/test_utils_collate.py (1)

3-3: Remove unnecessary blank line

There's an extra blank line that can be removed to improve code readability.

from lighter.utils.collate import collate_replace_corrupted

-
def test_collate_replace_corrupted():
tests/unit/test_callbacks_writer_table.py (2)

1-7: Remove extra empty line after imports

There's an unnecessary double empty line after the imports.

 from lighter.callbacks.writer.table import LighterTableWriter

-

1-16: Consider improving test organization and coverage

The test file would benefit from several architectural improvements:

  1. Add pytest fixtures for common setup/teardown
  2. Use parametrize for testing multiple scenarios
  3. Add integration tests with other writers
  4. Consider adding a test class for better organization

Here's a suggested structure:

import pytest
from pathlib import Path
import torch
from lighter.callbacks.writer.table import LighterTableWriter

@pytest.fixture
def table_writer():
    writer = LighterTableWriter(path="test.csv", writer="tensor")
    yield writer
    Path("test.csv").unlink(missing_ok=True)

class TestLighterTableWriter:
    @pytest.mark.parametrize("tensor,id", [
        (torch.tensor([1, 2, 3]), 1),
        (torch.tensor([]), 2),
        (torch.randn(1000), 3),
    ])
    def test_write(self, table_writer, tensor, id):
        table_writer.write(tensor=tensor, id=id)
        assert len(table_writer.csv_records) > 0
        # Add more assertions
tests/unit/test_callbacks_writer_file.py (1)

1-19: Consider adopting a more structured testing approach.

Given this is part of a larger unit testing effort, consider:

  1. Using pytest fixtures for common setup/teardown
  2. Implementing parametrized tests for different tensor types and sizes
  3. Adding integration tests with other writers
  4. Adding property-based testing using hypothesis for tensor generation

Example fixture implementation:

@pytest.fixture
def temp_writer_dir():
    """Fixture providing a temporary directory for writer tests."""
    test_dir = Path("test_dir_fixture")
    test_dir.mkdir(exist_ok=True)
    yield test_dir
    shutil.rmtree(test_dir)

@pytest.fixture
def file_writer(temp_writer_dir):
    """Fixture providing a configured LighterFileWriter instance."""
    return LighterFileWriter(path=temp_writer_dir, writer="tensor")
tests/unit/test_callbacks_utils.py (1)

6-9: Enhance test coverage and documentation for 2D image preprocessing.

While the basic shape verification is good, consider these improvements:

  1. Add docstring explaining the test purpose and expected behavior
  2. Verify the content/values of the processed image, not just shape
  3. Add edge cases with different batch sizes and channel counts
  4. Document why these specific dimensions (64x64) were chosen

Here's a suggested enhancement:

 def test_preprocess_image_2d():
+    """Test 2D image preprocessing.
+    
+    Verifies that preprocess_image correctly handles 2D images by:
+    1. Removing batch dimension
+    2. Preserving channel and spatial dimensions
+    3. Maintaining pixel values
+    """
     image = torch.rand(1, 3, 64, 64)  # Batch of 2D images
     processed_image = preprocess_image(image)
     assert processed_image.shape == (3, 64, 64)
+    # Verify content preservation
+    assert torch.allclose(processed_image, image.squeeze(0))
+    
+    # Test with different batch and channel sizes
+    image_multi = torch.rand(4, 1, 32, 32)
+    processed_multi = preprocess_image(image_multi)
+    assert processed_multi.shape == (1, 32, 32)
tests/unit/test_utils_patches.py (3)

1-3: Remove unused import Module

The Module class from torch.nn is imported but never used in the test file.

-from torch.nn import Linear, Module
+from torch.nn import Linear
🧰 Tools
🪛 Ruff (0.8.0)

1-1: torch.nn.Module imported but unused

Remove unused import: torch.nn.Module

(F401)


6-25: Good test coverage, but could be enhanced

The test effectively covers basic functionality with reserved names. Here are some suggestions for improvement:

  1. Add a docstring explaining the test's purpose and importance
  2. Consider using @pytest.mark.parametrize for testing different reserved names
  3. Add error case testing (e.g., invalid values, None keys)

Here's a suggested enhancement:

+import pytest
+
 def test_patched_module_dict_handles_reserved_names():
+    """
+    Verify PatchedModuleDict properly handles Python reserved names.
+    
+    This test ensures that potentially problematic Python reserved names
+    can be used as keys without conflicts, maintaining proper dictionary
+    functionality for access and deletion operations.
+    """
     # Test with previously problematic reserved names

Would you like me to provide a complete example with parametrization and additional test cases?


11-12: Enhance assertions for Linear module verification

While the test verifies basic dictionary operations, it should also verify that the Linear modules maintain their properties and can perform their intended operations.

Add these assertions after line 21:

# Verify Linear modules maintain their properties
assert isinstance(patched_dict["forward"], Linear)
assert patched_dict["forward"].in_features == 10
assert patched_dict["forward"].out_features == 10

# Verify module operations still work
import torch
input_tensor = torch.randn(1, 10)
output = patched_dict["forward"](input_tensor)
assert output.shape == (1, 10)

Also applies to: 19-21

tests/unit/test_utils_model.py (2)

7-14: Add docstring and type hints to the test fixture.

While the implementation is correct, adding documentation would improve maintainability.

 class DummyModel(torch.nn.Module):
+    """Test fixture representing a simple neural network with two linear layers.
+    
+    Used for testing layer manipulation utilities.
+    """
     def __init__(self):
         super().__init__()
         self.layer1 = Linear(10, 10)
         self.layer2 = Linear(10, 10)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.layer2(self.layer1(x))

30-37: Refactor duplicate model creation and add edge case tests.

  1. The Sequential model creation is duplicated. Consider using a fixture.
  2. Add test cases for edge cases like removing more layers than available.
+@pytest.fixture
+def sequential_model():
+    return Sequential(Linear(10, 10), Linear(10, 10), Linear(10, 10))
+
 def test_remove_n_last_layers_sequentially():
-    model = Sequential(Linear(10, 10), Linear(10, 10), Linear(10, 10))
+    model = sequential_model()
     new_model = remove_n_last_layers_sequentially(model, num_layers=1)
     assert len(new_model) == 2

-    model = Sequential(Linear(10, 10), Linear(10, 10), Linear(10, 10))
+    model = sequential_model()
     new_model = remove_n_last_layers_sequentially(model, num_layers=2)
     assert len(new_model) == 1
+
+    # Test edge case
+    with pytest.raises(ValueError):
+        remove_n_last_layers_sequentially(model, num_layers=4)
tests/unit/test_callbacks_freezer.py (4)

11-23: Add docstring and consider adding activation functions.

The DummyModel could be improved by:

  1. Adding a docstring explaining its purpose in testing.
  2. Adding activation functions between layers for a more realistic test scenario.
  3. Documenting why this specific architecture was chosen.
 class DummyModel(Module):
+    """A simple neural network for testing the LighterFreezer callback.
+    
+    Architecture:
+        input (10) -> linear -> linear (4) -> linear (1)
+    """
     def __init__(self):
         super().__init__()
         self.layer1 = torch.nn.Linear(10, 10)
+        self.relu1 = torch.nn.ReLU()
         self.layer2 = torch.nn.Linear(10, 4)
+        self.relu2 = torch.nn.ReLU()
         self.layer3 = torch.nn.Linear(4, 1)

     def forward(self, x):
         x = self.layer1(x)
+        x = self.relu1(x)
         x = self.layer2(x)
+        x = self.relu2(x)
         x = self.layer3(x)
         return x

25-31: Enhance test coverage with more diverse data.

The DummyDataset could be improved by:

  1. Adding a docstring explaining its testing purpose.
  2. Increasing the dataset size for more robust testing.
  3. Generating varied targets instead of all zeros.
 class DummyDataset(Dataset):
+    """A mock dataset for testing the LighterFreezer callback.
+    
+    Generates random input tensors and corresponding target values.
+    """
+    def __init__(self, size=100):
+        self.size = size
+
     def __len__(self):
-        return 10
+        return self.size

     def __getitem__(self, idx):
-        return {"input": torch.randn(10), "target": torch.tensor(0)}
+        return {
+            "input": torch.randn(10),
+            "target": torch.randint(0, 2, (1,))  # Binary targets for variety
+        }

33-40: Document fixture and parameterize test configurations.

The fixture could be improved by:

  1. Adding a docstring explaining its purpose and usage.
  2. Making batch_size and optimizer configurations parameterizable.
  3. Using a more stable optimizer like Adam for testing.
 @pytest.fixture
-def dummy_system():
+def dummy_system(batch_size=32):
+    """Create a LighterSystem instance for testing.
+    
+    Args:
+        batch_size (int, optional): Batch size for training. Defaults to 32.
+    
+    Returns:
+        LighterSystem: A system configured with dummy model and data.
+    """
     model = DummyModel()
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
     dataset = DummyDataset()
     criterion = torch.nn.CrossEntropyLoss()
-    return LighterSystem(model=model, batch_size=32, criterion=criterion, optimizer=optimizer, datasets={"train": dataset})
+    return LighterSystem(
+        model=model,
+        batch_size=batch_size,
+        criterion=criterion,
+        optimizer=optimizer,
+        datasets={"train": dataset}
+    )

1-65: Consider restructuring tests using a proper test class.

The test suite would benefit from being organized into a proper test class with shared setup and teardown methods. This would:

  1. Reduce code duplication
  2. Provide better organization of test cases
  3. Allow for shared fixtures and helper methods
  4. Make it easier to add new test cases

Example structure:

@pytest.mark.freezer
class TestLighterFreezer:
    @pytest.fixture(autouse=True)
    def setup(self):
        self.model = DummyModel()
        self.dataset = DummyDataset()
        # ... other setup code

    def test_initialization(self):
        # ... initialization tests

    def test_functionality(self):
        # ... functionality tests

    def test_exceptions(self):
        # ... exception tests
tests/unit/test_system.py (4)

15-48: Add docstrings to test helper classes.

While the implementation is correct, adding docstrings would improve code maintainability by documenting:

  • Purpose of each test helper class
  • Expected shapes and formats of the data
  • Relationship to the system under test

Example docstring for DummyDataset:

class DummyDataset(Dataset):
    """Test dataset that generates random image-like tensors and class labels.
    
    Returns batches in the format:
        {'input': torch.Tensor(3, 32, 32), 'target': torch.Tensor()}
    """

50-79: Consider adding validation for dataset configuration.

The DummySystem initialization looks good, but consider adding validation to ensure:

  • All required dataset splits are provided
  • Metrics configuration matches the dataset splits
  • Dataset sizes are appropriate for batch size

Example validation:

def __init__(self):
    # ... existing code ...
    
    required_splits = {'train', 'val', 'test'}
    if not all(split in datasets for split in required_splits):
        raise ValueError(f"Missing required dataset splits: {required_splits - set(datasets.keys())}")
    
    if not all(split in metrics for split in required_splits):
        raise ValueError(f"Missing metrics for splits: {required_splits - set(metrics.keys())}")

110-125: Enhance step function tests with edge cases.

While the basic functionality is well tested, consider adding tests for:

  • Empty batches
  • Batches with different tensor types (float32, float64)
  • GPU/CPU tensor handling
  • Gradient calculation verification

Example edge case test:

def test_training_step_empty_batch(dummy_system):
    empty_batch = {"input": torch.empty(0, 3, 32, 32), "target": torch.empty(0)}
    with mock.patch.object(dummy_system, "log"):
        with pytest.raises(ValueError, match="Empty batch"):
            dummy_system.training_step(empty_batch, batch_idx=0)

Also applies to: 127-139, 141-152


162-188: Add more batch format test cases.

The parameterized tests are good, but consider adding test cases for:

  • Different input shapes
  • Multi-target scenarios
  • Missing optional keys
  • Type mismatches between input and target

Example additional test cases:

@pytest.mark.parametrize(
    "batch",
    [
        # Different input shapes
        {"input": torch.randn(1, 1, 64, 64), "target": torch.randint(0, 10, size=(1,))},
        # Multi-target
        {
            "input": torch.randn(1, 3, 32, 32),
            "target": {"class": torch.randint(0, 10, size=(1,)), "bbox": torch.randn(1, 4)}
        },
    ]
)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between c332a2b and 7070c1c.

📒 Files selected for processing (15)
  • .gitignore (1 hunks)
  • lighter/utils/runner.py (1 hunks)
  • tests/unit/test_callbacks_freezer.py (1 hunks)
  • tests/unit/test_callbacks_utils.py (1 hunks)
  • tests/unit/test_callbacks_writer_base.py (1 hunks)
  • tests/unit/test_callbacks_writer_file.py (1 hunks)
  • tests/unit/test_callbacks_writer_table.py (1 hunks)
  • tests/unit/test_system.py (1 hunks)
  • tests/unit/test_utils_collate.py (1 hunks)
  • tests/unit/test_utils_dynamic_imports.py (1 hunks)
  • tests/unit/test_utils_logging.py (1 hunks)
  • tests/unit/test_utils_misc.py (1 hunks)
  • tests/unit/test_utils_model.py (1 hunks)
  • tests/unit/test_utils_patches.py (1 hunks)
  • tests/unit/test_utils_runner.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • .gitignore
🧰 Additional context used
🪛 Ruff (0.8.0)
tests/unit/test_callbacks_writer_base.py

8-8: Local variable writer is assigned to but never used

Remove assignment to unused variable writer

(F841)

tests/unit/test_utils_patches.py

1-1: torch.nn.Module imported but unused

Remove unused import: torch.nn.Module

(F401)

🔇 Additional comments (16)
tests/unit/test_utils_misc.py (1)

1-2: LGTM! Clean and specific import.

The import statement follows good practices by importing only the specific function being tested.

tests/unit/test_utils_logging.py (1)

1-6: Verify logging configuration persistence

Let's verify that the logging configuration persists across the codebase and that there are no conflicting configurations.

✅ Verification successful

Let me check the implementation of the logging setup to ensure it's properly configured.


Logging configuration is properly centralized and safe

The logging setup is well-implemented with a single configuration point in lighter/utils/logging.py that:

  • Properly intercepts and redirects all Python logging to Loguru
  • Is imported and initialized early in lighter/__init__.py
  • Has no competing or conflicting configurations across the codebase
  • Uses a custom formatter and handler to ensure consistent logging behavior
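
For reference, here is a minimal sketch of the standard Loguru interception recipe that the points above describe; the names and details are illustrative, and the actual lighter/utils/logging.py may differ (e.g. in caller-depth handling and formatting):

import logging

from loguru import logger


class InterceptHandler(logging.Handler):
    """Forward records from the standard logging module to Loguru."""

    def emit(self, record: logging.LogRecord) -> None:
        # Use the Loguru level with the same name if it exists, else the numeric level.
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno
        # Re-emit through Loguru, preserving any attached exception info.
        logger.opt(exception=record.exc_info).log(level, record.getMessage())


def _setup_logging() -> None:
    # Replace existing stdlib handlers with the interceptor so every
    # logging.* call is routed to Loguru.
    logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True)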
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for multiple logging configuration points that might conflict

# Search for logging configuration in Python files
echo "Searching for logging configuration points..."
rg -l "logging\.(basicConfig|config)" --type py

# Search for logger instantiation
echo "\nSearching for logger instantiation points..."
rg "logging\.getLogger\(" --type py -A 2

# Search for custom logging setup
echo "\nSearching for custom logging setup..."
rg "setup_logging" --type py

Length of output: 919


Script:

#!/bin/bash
# Get the implementation of _setup_logging
ast-grep --pattern 'def _setup_logging():
  $$$'

Length of output: 4259

tests/unit/test_utils_runner.py (2)

1-4: LGTM! Clean and minimal imports.

The imports are appropriate and follow best practices by importing only what's needed for the test.


1-8: Verify test implementation and coverage.

Let's ensure the test structure aligns with the implementation.

✅ Verification successful

Let me gather more information about the implementation of parse_config to ensure proper test coverage.


Let me try one more verification to see the complete implementation and usage:


Test implementation aligns with the function's behavior

The test correctly verifies the error case when no config is provided to parse_config(). The implementation shows that parse_config() explicitly raises a ValueError when "config" is not in the kwargs, which is exactly what the test is checking. The test is also properly located in the unit test directory and is the only test file for the runner module.

  • ✓ Test location is correct (tests/unit/test_utils_runner.py)
  • ✓ Implementation matches test expectations
  • ✓ No duplicate tests exist
  • ⚠️ Consider adding more test cases for other scenarios (config validation, overrides)
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify test location, coverage, and implementation

# Check if there are other test files for runner.py
echo "Checking for other runner tests..."
fd test_.*runner\.py tests/

# Verify the parse_config implementation
echo "Checking parse_config implementation..."
ast-grep --pattern 'def parse_config($$$)' 

# Check for similar existing tests
echo "Checking for similar tests..."
rg -A 5 'test_parse_config' tests/

Length of output: 569


Script:

#!/bin/bash
# Let's check the actual implementation of parse_config
echo "Checking parse_config implementation in lighter/utils/runner.py..."
rg -A 10 "def parse_config" lighter/utils/runner.py

# Also check if there are any other related functions or usages
echo "Checking related functions and usages..."
rg "parse_config" --type python

Length of output: 756


Script:

#!/bin/bash
# Check the complete implementation file
echo "Checking complete runner.py implementation..."
cat lighter/utils/runner.py

# Check for imports and usage without file type restriction
echo "Checking all parse_config references..."
rg "parse_config"

Length of output: 4592

tests/unit/test_callbacks_writer_base.py (1)

6-8: Enhance test coverage with docstring and positive cases.

While testing the error case is important, consider:

  1. Adding a docstring to describe the test's purpose
  2. Including positive test cases to verify successful initialization scenarios
 def test_writer_initialization():
+    """
+    Test LighterBaseWriter initialization:
+    - Raises TypeError for invalid parameters
+    - Successfully initializes with valid parameters
+    """
     with pytest.raises(TypeError):
-        writer = LighterBaseWriter(path="test", writer="tensor")
+        LighterBaseWriter(path="test", writer="tensor")
+    
+    # Test successful initialization
+    writer = LighterBaseWriter(path="test", writer="valid_type")  # Replace with actual valid type
+    assert isinstance(writer, LighterBaseWriter)
🧰 Tools
🪛 Ruff (0.8.0)

8-8: Local variable writer is assigned to but never used

Remove assignment to unused variable writer

(F841)

tests/unit/test_utils_dynamic_imports.py (1)

6-8: LGTM! The error handling test is well-implemented.

The test correctly verifies that attempting to import a non-existent module raises a FileNotFoundError.

tests/unit/test_utils_collate.py (1)

1-2: Verify import usage across codebase

Let's verify that the imported collate_replace_corrupted function exists and is properly exported.

✅ Verification successful

Import verified and properly implemented

The collate_replace_corrupted function is correctly implemented in lighter/utils/collate.py and is being properly imported in both the test file and lighter/system.py. The function exists, is exported, and is being used appropriately across the codebase.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the existence and export of collate_replace_corrupted function

# Check if the function exists in the source file
ast-grep --pattern 'def collate_replace_corrupted($$$)'

# Check if the function is properly exported
rg -l "^from.*collate_replace_corrupted import" || rg -l "^from.*collate import.*collate_replace_corrupted"

Length of output: 217


Script:

#!/bin/bash
# Let's verify the actual implementation and location of the function
rg -p "def collate_replace_corrupted" -g '*.py'

# Also check the content of the collate module to ensure proper export
fd collate.py --exec cat {}

Length of output: 3027

tests/unit/test_callbacks_writer_file.py (1)

1-7: LGTM! Clean and minimal imports.

The imports are well-organized and include only the necessary dependencies.

tests/unit/test_callbacks_utils.py (2)

1-4: LGTM! Imports are clean and minimal.

The imports are well-organized and include only the necessary dependencies.


12-19: 🛠️ Refactor suggestion

Verify the expected output shape transformation for 3D images.

The expected output shape (1, depth*height, batch_size*width) seems unusual. This transformation:

  1. Combines spatial dimensions (depth*height)
  2. Multiplies width by batch_size
  3. Sets channels to 1

Please verify if this is the intended behavior of preprocess_image.

Additionally, consider these improvements:

  1. Add docstring explaining the transformation logic
  2. Verify pixel values are preserved correctly
  3. Add error handling tests (e.g., invalid dimensions)

Here's a suggested enhancement:

 def test_preprocess_image_3d():
+    """Test 3D image preprocessing.
+    
+    Verifies that preprocess_image correctly handles 3D images by:
+    1. Reshaping spatial dimensions as per model requirements
+    2. Preserving total pixel count and values
+    3. Handling batch dimension appropriately
+    """
     batch_size = 8
     depth = 20
     height = 64
     width = 64
     image = torch.rand(batch_size, 1, depth, height, width)  # Batch of 3D images
     processed_image = preprocess_image(image)
     assert processed_image.shape == (1, depth * height, batch_size * width)
+    
+    # Verify total element count is preserved
+    assert image.numel() == processed_image.numel()
+    
+    # Test error handling
+    with pytest.raises(ValueError, match="Invalid dimensions"):
+        invalid_image = torch.rand(8, 1, 20)  # Missing spatial dimensions
+        preprocess_image(invalid_image)
✅ Verification successful

Let me search for more context about the preprocess_image function implementation.


The test correctly verifies the intended shape transformation

The test is correct and matches the implementation's documented behavior. From the source code in lighter/callbacks/utils.py, we can see that:

  1. The function is specifically designed to handle 3D images by:

    • Combining depth and height dimensions vertically (shape[2] * shape[3])
    • Concatenating batch images horizontally (B*W)
    • This matches exactly with the test's expected shape (1, depth*height, batch_size*width)
  2. The transformation is intentional for visualization purposes, as documented:

    "In case of 3D, a single image is displayed with slices stacked vertically, while a batch of 3D images as a grid where each column is a different 3D image"

The test's assertions correctly verify this visualization-oriented transformation.
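
For reference, a small self-contained sketch (independent of lighter's actual preprocess_image code) that reproduces the described reshaping and makes the expected output shape concrete:

import torch

batch_size, depth, height, width = 8, 20, 64, 64
image = torch.rand(batch_size, 1, depth, height, width)

# Stack each volume's slices vertically: (B, 1, D*H, W)
stacked = image.reshape(batch_size, 1, depth * height, width)
# Concatenate the batch horizontally, one column per 3D image: (1, D*H, B*W)
grid = torch.cat(list(stacked), dim=-1)

assert grid.shape == (1, depth * height, batch_size * width)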

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for the preprocess_image implementation and its usage
ast-grep --pattern 'def preprocess_image($$$)' -A 10
rg -A 5 'preprocess_image.*3[dD]'

Length of output: 479


Script:

#!/bin/bash
# Search for preprocess_image implementation in the codebase
rg -l "def preprocess_image"

# Once we find the file, let's look at its content
rg -B 5 -A 15 "def preprocess_image"

# Let's also search for any related documentation or comments
rg -i "preprocess.*image"

Length of output: 2409

tests/unit/test_utils_model.py (1)

1-4: LGTM! Imports are well-organized.

The imports are appropriately separated between external (PyTorch) and internal (lighter) dependencies, and all imported components are utilized in the tests.

tests/unit/test_callbacks_freezer.py (2)

1-9: LGTM! Imports are well-organized and complete.


42-65: 🛠️ Refactor suggestion

Enhance test coverage with docstrings and edge cases.

The test suite needs improvement in several areas:

  1. Missing docstrings for all test functions
  2. No negative test cases
  3. Limited testing of error conditions
  4. Single epoch might not reveal training-related issues

Here's how to enhance the test coverage:

+def test_freezer_invalid_initialization():
+    """Test that LighterFreezer raises ValueError for invalid layer names."""
+    with pytest.raises(ValueError):
+        LighterFreezer(names=["non_existent_layer"])

 def test_freezer_initialization():
+    """Test that LighterFreezer correctly initializes with valid layer names."""
     freezer = LighterFreezer(names=["layer1"])
     assert freezer.names == ["layer1"]
+    assert isinstance(freezer.names, list)

 def test_freezer_functionality(dummy_system):
+    """Test that LighterFreezer correctly freezes specified layers during training."""
     freezer = LighterFreezer(names=["layer1.weight", "layer1.bias"])
-    trainer = Trainer(callbacks=[freezer], max_epochs=1)
+    trainer = Trainer(callbacks=[freezer], max_epochs=3)  # More epochs for stability
     trainer.fit(dummy_system)
+    
+    # Test initial parameter states
     assert not dummy_system.model.layer1.weight.requires_grad
     assert not dummy_system.model.layer1.bias.requires_grad
     assert dummy_system.model.layer2.weight.requires_grad
+    
+    # Store and verify frozen parameters don't change
+    initial_weights = dummy_system.model.layer1.weight.clone()
+    trainer.fit(dummy_system)
+    assert torch.allclose(initial_weights, dummy_system.model.layer1.weight)

 def test_freezer_with_exceptions(dummy_system):
+    """Test that LighterFreezer correctly handles exception patterns during training."""
     freezer = LighterFreezer(name_starts_with=["layer"], except_names=["layer2.weight", "layer2.bias"])
-    trainer = Trainer(callbacks=[freezer], max_epochs=1)
+    trainer = Trainer(callbacks=[freezer], max_epochs=3)
     trainer.fit(dummy_system)
+    
+    # Verify frozen state of all layers
     assert not dummy_system.model.layer1.weight.requires_grad
     assert not dummy_system.model.layer1.bias.requires_grad
     assert dummy_system.model.layer2.weight.requires_grad
     assert dummy_system.model.layer2.bias.requires_grad
     assert not dummy_system.model.layer3.weight.requires_grad
     assert not dummy_system.model.layer3.bias.requires_grad
+    
+    # Verify training affects only unfrozen layers
+    initial_l1_weights = dummy_system.model.layer1.weight.clone()
+    initial_l2_weights = dummy_system.model.layer2.weight.clone()
+    trainer.fit(dummy_system)
+    assert torch.allclose(initial_l1_weights, dummy_system.model.layer1.weight)
+    assert not torch.allclose(initial_l2_weights, dummy_system.model.layer2.weight)

Let's verify the test coverage:

lighter/utils/runner.py (1)

39-39: LGTM! Consider adding version compatibility handling.

The change from .dict() to .model_dump() appears to be an update for compatibility with newer versions of Pydantic (v2.x). While the change is correct, we should ensure compatibility across different Pydantic versions.

Let's verify the Pydantic version and its compatibility:

Consider adding version compatibility handling:

try:
    parser = ConfigParser(ConfigSchema().model_dump(), globals=False)
except AttributeError:  # Fallback for older Pydantic versions
    parser = ConfigParser(ConfigSchema().dict(), globals=False)
✅ Verification successful

Change is correct and consistent with dependencies

The change from .dict() to .model_dump() is appropriate as the project explicitly depends on Pydantic v2 (pydantic = "^2.8.2" in pyproject.toml). The codebase consistently uses model_dump() across multiple files, and the schema implementation in schema.py follows Pydantic v2 patterns (e.g., using model_validator, model_fields, etc.). No compatibility layer is needed since the project has standardized on Pydantic v2.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check Pydantic version requirements and usage
# Check for Pydantic version in requirements
rg -l "pydantic[>=<~]" 

# Check for other potential .dict() usage that might need updates
rg "\.dict\(\)" 

# Check for existing version handling
rg "pydantic.*version"

Length of output: 67


Script:

#!/bin/bash
# Let's try a different approach to check dependencies and usage

# Check for pyproject.toml, setup.py, or requirements files
fd "pyproject.toml|setup.py|requirements.*\.txt"

# Check for pydantic imports and usage
rg -l "from pydantic|import pydantic"

# Check for model_dump usage to see if other files were updated
rg "model_dump"

# Look for ConfigSchema usage across the codebase
rg "ConfigSchema"

Length of output: 1679


Script:

#!/bin/bash
# Let's check the pyproject.toml for pydantic version and examine schema.py for BaseModel usage

# Check pyproject.toml content for dependencies
cat pyproject.toml

# Check the schema implementation
cat lighter/utils/schema.py

Length of output: 8683

tests/unit/test_system.py (2)

1-12: LGTM! Well-organized imports.

The imports are properly organized following Python's import style guidelines, with standard library imports first, followed by third-party and local imports.


1-188: Overall well-structured test suite with good coverage.

The test implementation demonstrates good practices:

  • Proper use of pytest fixtures and parameterization
  • Clear separation of test helper classes
  • Comprehensive coverage of core functionality

Consider the suggested improvements for documentation and edge cases, but the current implementation is solid and ready for use.

Comment on lines +4 to +6
def test_ensure_list():
    assert ensure_list(1) == [1]
    assert ensure_list([1, 2]) == [1, 2]

🛠️ Refactor suggestion

The current test cases look good, but coverage could be improved.

While the existing test cases correctly verify basic functionality, consider adding more test cases to ensure comprehensive coverage:

 def test_ensure_list():
+    """Test ensure_list utility function handles various input types correctly."""
     assert ensure_list(1) == [1]
     assert ensure_list([1, 2]) == [1, 2]
+    # Test None input
+    assert ensure_list(None) == [], "Should handle None input"
+    # Test empty list
+    assert ensure_list([]) == [], "Should preserve empty list"
+    # Test string input
+    assert ensure_list("test") == ["test"], "Should convert string to list"
+    # Test tuple input
+    assert ensure_list((1, 2)) == [1, 2], "Should convert tuple to list"
+    # Test nested structures
+    assert ensure_list([[1, 2]]) == [[1, 2]], "Should preserve nested lists"

Also consider:

  1. Adding descriptive messages to assertions for better error reporting
  2. Using pytest's parametrize decorator for cleaner test organization
  3. Adding docstring to explain the test's purpose

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +1 to +6
from lighter.utils.misc import ensure_list


def test_ensure_list():
    assert ensure_list(1) == [1]
    assert ensure_list([1, 2]) == [1, 2]

💡 Codebase verification

Test cases need to be expanded to cover all implementation cases

The current test cases only cover two scenarios (ensure_list(1) and ensure_list([1, 2])), but the implementation handles four distinct cases:

  • List input (returns as-is)
  • Tuple input (converts to list)
  • None input (returns empty list)
  • Other types (wraps in list)

Missing test cases:

  • ensure_list(None) should return []
  • ensure_list((1, 2)) should return [1, 2]
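
For reference, a minimal sketch matching the four cases listed above; the actual implementation in lighter/utils/misc.py may differ in details:

from typing import Any, List


def ensure_list(value: Any) -> List:
    if isinstance(value, list):
        return value          # list input is returned as-is
    if isinstance(value, tuple):
        return list(value)    # tuple input is converted to a list
    if value is None:
        return []             # None input returns an empty list
    return [value]            # any other type is wrapped in a list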
🔗 Analysis chain

Verify ensure_list implementation details

Let's verify the actual implementation to ensure our suggested test cases align with the intended behavior.


Let me try another approach to find the implementation details.


Let me get the complete implementation to verify the test cases.


Let me get a few more lines of the implementation as it seems truncated.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Find and display the implementation of ensure_list
echo "Looking for ensure_list implementation:"
ast-grep --pattern 'def ensure_list($_):
  $$$
'

# Find any other existing tests for ensure_list
echo -e "\nLooking for other existing tests:"
rg -l "test.*ensure_list" tests/

Length of output: 287


Script:

#!/bin/bash
# Search for ensure_list implementation using ripgrep for better context
echo "Looking for ensure_list implementation in source files:"
rg "def ensure_list" -A 5 lighter/

Length of output: 480


Script:

#!/bin/bash
# Get the complete implementation of ensure_list
echo "Complete ensure_list implementation:"
rg "def ensure_list" -A 10 lighter/utils/misc.py

Length of output: 419


Script:

#!/bin/bash
# Get the complete implementation with more context
echo "Complete ensure_list implementation:"
rg "def ensure_list" -A 15 lighter/utils/misc.py

Length of output: 538

Resolved review threads:
  • tests/unit/test_utils_logging.py
  • tests/unit/test_callbacks_writer_base.py (outdated)
  • tests/unit/test_utils_collate.py (outdated)
  • tests/unit/test_callbacks_writer_table.py (outdated)
  • tests/unit/test_callbacks_writer_file.py (outdated, two threads)
Comment on lines +17 to +28
def test_replace_layer_with():
    model = DummyModel()
    new_layer = Linear(10, 4)
    replace_layer_with(model, "layer1", new_layer)
    assert model.layer1 == new_layer


def test_replace_layer_with_identity():
    model = DummyModel()
    replace_layer_with_identity(model, "layer1")
    assert isinstance(model.layer1, torch.nn.Identity)


🛠️ Refactor suggestion

Add test cases for error conditions and shape validation.

The current tests verify successful operations but don't cover error cases. Consider adding tests for:

  1. Invalid layer names in replace_layer_with
  2. Shape mismatches when replacing layers
  3. Input/output tensor shape validation

Example test case to add:

def test_replace_layer_with_invalid_name():
    model = DummyModel()
    new_layer = Linear(10, 4)
    with pytest.raises(AttributeError):
        replace_layer_with(model, "non_existent_layer", new_layer)

Resolved review thread: tests/unit/test_utils_model.py
ibro45 and others added 5 commits November 29, 2024 19:07
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> (this trailer appears on each of the 5 commits)
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (6)
tests/unit/test_callbacks_writer_table.py (2)

13-37: Add more comprehensive write test scenarios

While the current test covers basic functionality and edge cases, consider adding:

  1. CSV header verification
  2. Test for concurrent writes
  3. Error cases (file permission issues, disk full scenarios)
  4. Verification of CSV format compliance
 def test_table_writer_write():
     """Test LighterTableWriter write functionality with various inputs."""
     test_file = Path("test.csv")
     writer = LighterTableWriter(path="test.csv", writer="tensor")
     
     # Test basic write
     test_tensor = torch.tensor([1, 2, 3])
     writer.write(tensor=test_tensor, id=1)
     assert len(writer.csv_records) == 1
     assert writer.csv_records[0]["tensor"] == test_tensor.tolist()
     assert writer.csv_records[0]["id"] == 1
     
     # Test edge cases
     writer.write(tensor=torch.tensor([]), id=2)  # empty tensor
     writer.write(tensor=torch.randn(1000), id=3)  # large tensor
     writer.write(tensor=torch.tensor([1.5, 2.5]), id=4)  # float tensor
     
     # Verify file creation and content
     assert test_file.exists()
     with open(test_file) as f:
+        # Verify CSV header
+        header = f.readline().strip()
+        assert header == "id,tensor"
+        
+        # Verify content format
         content = f.read()
         assert "1,2,3" in content  # verify first tensor
+        # Verify CSV format compliance
+        import csv
+        f.seek(0)
+        reader = csv.DictReader(f)
+        rows = list(reader)
+        assert len(rows) == 4  # All records present
     
+    # Test concurrent writes
+    import threading
+    def concurrent_write():
+        writer.write(tensor=torch.tensor([5,6,7]), id=5)
+    thread = threading.Thread(target=concurrent_write)
+    thread.start()
+    thread.join()
+    assert len(writer.csv_records) == 5
+    
+    # Test error cases
+    import os
+    # Test file permission error
+    test_file.chmod(0o000)  # Remove all permissions
+    with pytest.raises(PermissionError):
+        writer.write(tensor=torch.tensor([8,9,10]), id=6)
+    test_file.chmod(0o666)  # Restore permissions
+    
     # Cleanup
     test_file.unlink()

1-37: Consider architectural improvements for better test organization

To improve test maintainability and reduce duplication:

  1. Use pytest fixtures for common setup (writer initialization, test file creation)
  2. Mock file operations to avoid actual file I/O in tests
  3. Create test data fixtures for different tensor scenarios

Example fixture structure:

import pytest
from unittest.mock import patch, mock_open

@pytest.fixture
def table_writer():
    writer = LighterTableWriter(path="test.csv", writer="tensor")
    yield writer

@pytest.fixture
def test_tensors():
    return {
        'basic': torch.tensor([1, 2, 3]),
        'empty': torch.tensor([]),
        'large': torch.randn(1000),
        'float': torch.tensor([1.5, 2.5])
    }

# Then use in tests:
def test_table_writer_write(table_writer, test_tensors):
    with patch('builtins.open', mock_open()) as mock_file:
        table_writer.write(tensor=test_tensors['basic'], id=1)
        # Test assertions...
tests/unit/test_utils_collate.py (4)

4-12: Consider enhancing the docstring with function behavior details.

While the docstring effectively lists what is being tested, it could be more helpful to include:

  • Brief description of what collate_replace_corrupted does
  • Expected behavior for each test case
 def test_collate_replace_corrupted():
     """Test collate_replace_corrupted function handles corrupted data correctly.
     
+    The collate_replace_corrupted function replaces None values in a batch
+    with random values from a provided dataset while preserving non-None values.
+    
     Tests:
-        - Corrupted values (None) are replaced with valid dataset values
-        - Non-corrupted values remain unchanged
-        - Output maintains correct length
-        - Edge cases: empty batch, all corrupted values
+        - Corrupted values (None) are replaced with valid dataset values -> Should pick random values from dataset
+        - Non-corrupted values remain unchanged -> Original values should be preserved
+        - Output maintains correct length -> Output batch should match input batch length
+        - Edge cases:
+            - empty batch -> Should return empty list
+            - all corrupted values -> Should replace all with valid dataset values
     """

13-26: Consider using test fixtures or setup data.

The test data could be better organized using test fixtures or setup variables at the module level, especially if these test cases will be reused in other test functions.

+# At module level
+SAMPLE_BATCH = [1, None, 2, None, 3]
+SAMPLE_DATASET = [1, 2, 3, 4, 5]
+
 def test_collate_replace_corrupted():
     """..."""
-    batch = [1, None, 2, None, 3]
-    dataset = [1, 2, 3, 4, 5]
+    batch = SAMPLE_BATCH
+    dataset = SAMPLE_DATASET

28-35: Consider adding an empty dataset edge case.

The edge cases are well covered, but it might be valuable to test the behavior when the dataset is empty, as this could lead to potential issues.

     # Test edge cases
     empty_batch = []
     assert len(collate_replace_corrupted(empty_batch, dataset)) == 0
     
+    # Test empty dataset
+    empty_dataset = []
+    with pytest.raises(ValueError, match="Dataset cannot be empty"):
+        collate_replace_corrupted(batch, empty_dataset)
+
     all_corrupted = [None, None, None]
     collated_all_corrupted = collate_replace_corrupted(all_corrupted, dataset)
     assert len(collated_all_corrupted) == len(all_corrupted)
     assert all(val in dataset for val in collated_all_corrupted)

Note: This assumes the function should raise a ValueError for empty datasets. If the behavior should be different, adjust the test accordingly.


1-1: Add pytest import for the suggested test.

If you implement the empty dataset test, you'll need to add:

+import pytest
 from lighter.utils.collate import collate_replace_corrupted
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 7070c1c and 98af60b.

📒 Files selected for processing (4)
  • tests/unit/test_callbacks_writer_base.py (1 hunks)
  • tests/unit/test_callbacks_writer_file.py (1 hunks)
  • tests/unit/test_callbacks_writer_table.py (1 hunks)
  • tests/unit/test_utils_collate.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/unit/test_callbacks_writer_base.py
  • tests/unit/test_callbacks_writer_file.py
🔇 Additional comments (3)
tests/unit/test_callbacks_writer_table.py (2)

1-7: LGTM!

The imports are well-organized and include all necessary dependencies.


8-11: 🛠️ Refactor suggestion

Enhance test coverage for initialization

The test needs improvement in several areas:

  1. Add docstring explaining test purpose
  2. Verify the writer attribute
  3. Add negative test cases (invalid writer types, empty paths)
  4. Add type hints for better code documentation
+from typing import Any
+
 def test_table_writer_initialization():
+    """Test LighterTableWriter initialization with valid and invalid inputs.
+    
+    Tests:
+        - Valid initialization with correct path and writer
+        - Invalid writer type raises ValueError
+        - Empty path raises ValueError
+    """
     writer = LighterTableWriter(path="test.csv", writer="tensor")
     assert writer.path == Path("test.csv")
+    assert writer.writer == "tensor"
+    
+    # Test invalid writer type
+    with pytest.raises(ValueError, match="Invalid writer type"):
+        LighterTableWriter(path="test.csv", writer="invalid")
+    
+    # Test empty path
+    with pytest.raises(ValueError, match="Path cannot be empty"):
+        LighterTableWriter(path="", writer="tensor")
tests/unit/test_utils_collate.py (1)

1-3: LGTM! Clean and focused import.

The import statement is properly focused on the specific function being tested.

@dosubot dosubot bot added size:XL This PR changes 500-999 lines, ignoring generated files. and removed size:L This PR changes 100-499 lines, ignoring generated files. labels Nov 30, 2024
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (5)
tests/unit/test_utils_collate.py (2)

1-4: Add a comment explaining the torch dependency

While the torch import is necessary (the collated batch contains tensor items), its usage isn't immediately obvious. Consider adding a comment to explain why it's needed.

+# torch is required as collate_replace_corrupted returns tensor items
 import torch
🧰 Tools
🪛 Ruff (0.8.0)

1-1: torch imported but unused

Remove unused import: torch

(F401)


15-37: Enhance test implementation

While the test coverage is good, there are a few improvements to consider:

  1. The empty batch test case mentioned in the docstring is missing
  2. The filtered batch comparison could be more robust using sets
  3. Consider using parametrized tests for better maintainability

Here's a suggested implementation:

+import pytest
+
+@pytest.mark.parametrize("batch, dataset, expected_length", [
+    ([1, None, 2, None, 3], [1, 2, 3, 4, 5], 5),  # Normal case
+    ([None, None, None], [1, 2, 3, 4, 5], 3),     # All corrupted
+    ([], [1, 2, 3, 4, 5], 0),                     # Empty batch
+])
-def test_collate_replace_corrupted():
+def test_collate_replace_corrupted(batch, dataset, expected_length):
     """Test collate_replace_corrupted function handles corrupted data correctly.
     ...
     """
-    batch = [1, None, 2, None, 3]
-    dataset = [1, 2, 3, 4, 5]
     collated_batch = collate_replace_corrupted(batch, dataset)

     # Test length
-    assert len(collated_batch) == len(batch)
+    assert len(collated_batch) == expected_length

     # Test non-corrupted values remain unchanged.
     filtered_batch = list(filter(lambda x: x is not None, batch))
-    assert collated_batch[0].item() == filtered_batch[0]
-    assert collated_batch[1].item() == filtered_batch[1]
-    assert collated_batch[2].item() == filtered_batch[2]
+    if filtered_batch:  # Skip for all-corrupted or empty batch
+        filtered_values = set(val.item() for val in collated_batch[:len(filtered_batch)])
+        assert filtered_values == set(filtered_batch)

     # Test corrupted values are replaced with valid dataset values
-    assert collated_batch[3].item() in dataset
-    assert collated_batch[4].item() in dataset
+    if len(collated_batch) > len(filtered_batch):
+        replaced_values = [val.item() for val in collated_batch[len(filtered_batch):]]
+        assert all(val in dataset for val in replaced_values)

-    all_corrupted = [None, None, None]
-    collated_all_corrupted = collate_replace_corrupted(all_corrupted, dataset)
-    assert len(collated_all_corrupted) == len(all_corrupted)
-    assert all(val in dataset for val in collated_all_corrupted)
lighter/callbacks/writer/table.py (1)

Line range hint 41-42: Add type validation in the write method

To prevent sorting issues downstream, add type validation for the id parameter in the write method.

     def write(self, tensor: Any, id: Union[int, str]) -> None:
+        if not isinstance(id, (int, str)):
+            raise TypeError(f"'id' must be int or str, got {type(id)}")
         self.csv_records.append({"id": id, "pred": self.writer(tensor)})
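For context, here is a standalone illustration (plain pandas, not lighter code) of the failure the validation above guards against: sorting an object column with mixed int/str ids raises a TypeError.

import pandas as pd

# Mixed id types, as could happen if a caller passes both ints and strings.
records = [{"id": 1, "pred": 0.1}, {"id": "a", "pred": 0.2}]
df = pd.DataFrame(records)

# Raises TypeError: '<' not supported between instances of 'str' and 'int'
df.sort_values(by="id")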
🧰 Tools
🪛 Ruff (0.8.0)

63-66: Use contextlib.suppress(TypeError) instead of try-except-pass

Replace with contextlib.suppress(TypeError)

(SIM105)
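
For reference, the SIM105 rewrite Ruff suggests looks like the sketch below; the DataFrame here is stand-in data, not the actual code at lines 63-66.

import contextlib

import pandas as pd

df = pd.DataFrame([{"id": 2, "pred": 0.2}, {"id": 1, "pred": 0.1}])

# Equivalent to try/except TypeError/pass, but satisfies Ruff's SIM105.
with contextlib.suppress(TypeError):
    df = df.sort_values(by="id")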

tests/unit/test_callbacks_writer_table.py (2)

10-10: Remove unused import

The LighterSystem import is not used in any of the tests.

-from lighter.system import LighterSystem
🧰 Tools
🪛 Ruff (0.8.0)

10-10: lighter.system.LighterSystem imported but unused

Remove unused import: lighter.system.LighterSystem

(F401)


67-109: Consider enhancing multi-process test documentation and coverage

The test effectively validates the multi-process writing scenario with good mocking and assertions. Consider adding:

  1. Docstring explaining the test's purpose and setup
  2. Error cases (e.g., gather failures, rank mismatches)
 def test_table_writer_write_multi_process(tmp_path, monkeypatch):
+    """Test LighterTableWriter's behavior in a distributed training setup.
+    
+    Tests:
+    - Record gathering from multiple processes
+    - Correct CSV file creation with combined records
+    - Proper handling of distributed environment
+    """
     test_file = tmp_path / "test.csv"
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 98af60b and 7b7836b.

📒 Files selected for processing (4)
  • lighter/callbacks/writer/table.py (1 hunks)
  • tests/unit/test_callbacks_writer_file.py (1 hunks)
  • tests/unit/test_callbacks_writer_table.py (1 hunks)
  • tests/unit/test_utils_collate.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unit/test_callbacks_writer_file.py
🧰 Additional context used
🪛 Ruff (0.8.0)
lighter/callbacks/writer/table.py

63-66: Use contextlib.suppress(TypeError) instead of try-except-pass

Replace with contextlib.suppress(TypeError)

(SIM105)

tests/unit/test_callbacks_writer_table.py

5-5: pytest imported but unused

Remove unused import: pytest

(F401)


10-10: lighter.system.LighterSystem imported but unused

Remove unused import: lighter.system.LighterSystem

(F401)

tests/unit/test_utils_collate.py

1-1: torch imported but unused

Remove unused import: torch

(F401)

🔇 Additional comments (5)
tests/unit/test_utils_collate.py (1)

6-14: LGTM! Well-documented test function

The docstring clearly describes the test's purpose and comprehensively lists all test cases.

tests/unit/test_callbacks_writer_table.py (4)

13-14: LGTM!

The helper function is well-defined and correctly implements a custom writer for testing purposes.
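
For readers without the test file open, the helper referred to above is a small custom writer callable; the body below is an assumption for illustration, not the code under review.

import torch

def custom_writer(tensor: torch.Tensor) -> dict:
    # Hypothetical custom writer: reduce the tensor to a single scalar record.
    return {"custom": tensor.sum().item()}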


17-19: Enhance initialization test coverage

The test only verifies the path attribute. Consider adding validation of the writer attribute and negative test cases.


22-26: LGTM!

The test effectively validates the custom writer functionality with proper assertions.


29-64: LGTM!

The test is well-structured with:

  • Clear test cases including edge cases
  • Comprehensive data validation
  • Proper file cleanup

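Relatedly, manual cleanup can be avoided by writing into pytest's built-in tmp_path fixture, which creates a per-test directory that pytest removes on its own. A minimal sketch, assuming LighterTableWriter accepts a Path for path as the tests above suggest:

from pathlib import Path

from lighter.callbacks.writer.table import LighterTableWriter

def test_table_writer_uses_tmp_path(tmp_path: Path) -> None:
    # tmp_path is cleaned up by pytest, so no manual file deletion is needed.
    writer = LighterTableWriter(path=tmp_path / "test.csv", writer="tensor")
    assert writer.path == tmp_path / "test.csv"
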
@ibro45 ibro45 merged commit 4d0dd74 into main Nov 30, 2024
3 checks passed
@ibro45 ibro45 deleted the tests branch November 30, 2024 03:18