feat(pt): DPA-2 repinit compress #4329
Signed-off-by: Jinzhe Zeng <[email protected]>
📝 Walkthrough
The changes in this pull request introduce a new method
Changes
Actionable comments posted: 7
🧹 Outside diff range and nitpick comments (3)
source/tests/pt/model/test_compressed_descriptor_dpa2.py (2)
72-100: Document the test data structure.
The hardcoded coordinates and box values would benefit from documentation explaining their significance and structure (e.g., what molecular system they represent).
Add a comment explaining the test system:
+        # Test system configuration:
+        # - 6 atoms in total (2 of type 0, 4 of type 1)
+        # - Cubic box of size 13.0
         self.coords = np.array([
122-146: Consider adding more test cases for compression.
While the current test verifies basic compression functionality, consider adding tests for:
- Different compression ratios
- Edge cases (e.g., very small or large compression ratios)
- Error cases (e.g., invalid compression values)
Example additional test:
def test_compression_invalid_ratio(self):
    with self.assertRaises(ValueError):
        self.descriptor.enable_compression(-1.0)
    with self.assertRaises(ValueError):
        self.descriptor.enable_compression(1.5)
deepmd/pt/model/descriptor/dpa2.py (1)
890-890: Typo in comment: 'mocel' should be 'model'
There's a typographical error in the inline comment. The word 'mocel' should be corrected to 'model'.
Apply this diff to fix the typo:
-        # do some checks before the mocel compression process
+        # do some checks before the model compression process
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
deepmd/pt/model/descriptor/dpa2.py (2 hunks)
deepmd/pt/utils/tabulate.py (2 hunks)
source/tests/pt/model/test_compressed_descriptor_dpa2.py (1 hunks)
🔇 Additional comments (3)
source/tests/pt/model/test_compressed_descriptor_dpa2.py (1)
1-50: LGTM! Well-structured helper function with proper type hints.
The helper function eval_pt_descriptor is well-implemented with:
- Clear type hints and parameters
- Proper device handling for PyTorch tensors
- Logical workflow for descriptor evaluation
deepmd/pt/utils/tabulate.py (2)
297-303: LGTM! Clean implementation of descriptor type handling.
The changes elegantly extend support for DescrptDPA2 while maintaining consistency with DescrptDPA1. The implementation follows best practices by using isinstance with a tuple for multiple type checks.
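For readers unfamiliar with the idiom praised here, a minimal sketch follows. The classes are empty stand-ins defined only for illustration (they are not the real deepmd descriptor classes), and `shares_dpa_branch` is a hypothetical helper name.

```python
# Stand-in classes for illustration only; they are NOT the real deepmd classes.
class DescrptDPA1:
    pass


class DescrptDPA2:
    pass


class OtherDescriptor:
    pass


def shares_dpa_branch(descriptor: object) -> bool:
    """Return True when the descriptor should take the shared DPA1/DPA2 code path.

    isinstance accepts a tuple of types, so one call covers both classes
    instead of chaining two isinstance checks with `or`.
    """
    return isinstance(descriptor, (DescrptDPA1, DescrptDPA2))
```

Adding another supported type later then only requires extending the tuple.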
98-103: Verify compression functionality for DescrptDPA2.
Since these changes enable compression for DescrptDPA2, it would be valuable to verify:
- The compression process works as expected
- The compressed model produces consistent results
✅ Verification successful
Compression functionality for DescrptDPA2 is properly tested
The test file test_compressed_descriptor_dpa2.py contains comprehensive tests that:
- Verify the model's output shape remains consistent after compression
- Validate that compressed model results match the uncompressed results within acceptable tolerance
- Test both float32 and float64 precision levels
- Include proper test setup with realistic input data
The implementation includes proper error handling and precision-specific tolerances, ensuring reliable compression functionality.
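The kind of check described above can be sketched as follows. This is not the code from the test file: the tolerance values and the helper name `all_close` are assumptions chosen for illustration, and plain Python lists stand in for descriptor outputs.

```python
import math

# Hypothetical absolute tolerances per precision; the real test may use other values.
TOLERANCES = {"float64": 1e-10, "float32": 1e-4}


def all_close(reference, candidate, precision):
    """Compare a compressed output against the uncompressed reference.

    Returns False if the lengths differ (a stand-in for the shape check),
    otherwise checks every element against the precision-specific tolerance.
    """
    atol = TOLERANCES[precision]
    if len(reference) != len(candidate):
        return False
    return all(
        math.isclose(r, c, rel_tol=0.0, abs_tol=atol)
        for r, c in zip(reference, candidate)
    )
```

A looser tolerance for float32 reflects that single-precision results cannot be expected to agree as closely as double-precision ones.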
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Search for tests that verify DescrptDPA2 compression functionality
# Look for test files related to DescrptDPA2
echo "Searching for DescrptDPA2 test files..."
fd -e py "test.*dpa2.*" source/tests/
# Look for compression-related test cases
echo "Searching for compression test cases..."
rg -l "compress.*dpa2|dpa2.*compress" source/tests/
Length of output: 509
Script:
#!/bin/bash
# Let's examine the content of the compression test file
echo "Checking compression test content..."
cat source/tests/pt/model/test_compressed_descriptor_dpa2.py
# Let's also check if there are any compression-related tests in the descriptor test file
echo -e "\nChecking descriptor test content..."
rg "compress" source/tests/pt/model/test_descriptor_dpa2.py -A 5
# Look for actual compression implementation
echo -e "\nChecking compression implementation..."
rg "compress.*dpa2|dpa2.*compress" -A 5 --type py
Length of output: 4803
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4329      +/-   ##
==========================================
- Coverage   84.60%   84.45%   -0.16%
==========================================
  Files         571      571
  Lines       53163    53187      +24
  Branches     3059     3059
==========================================
- Hits        44981    44918      -63
- Misses       7218     7306      +88
+ Partials      964      963       -1
Signed-off-by: Jinzhe Zeng <[email protected]>
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (3)
deepmd/pt/model/descriptor/dpa2.py (3)
876-876: Fix typo in docstring: 'statisitcs' → 'statistics'
The docstring contains a spelling error.
-        """Receive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data.
+        """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.
868-875: Add return type hint to method signature
The method signature should include a return type hint for better type safety.
-    def enable_compression(
+    def enable_compression(
         self,
         min_nbor_dist: float,
         table_extrapolate: float = 5,
         table_stride_1: float = 0.01,
         table_stride_2: float = 0.1,
         check_frequency: int = -1,
-    ) -> None:
+    ) -> None:
935-940: Consider using a dictionary for table configuration
Using a list for configuration makes the code less maintainable and more prone to errors. Consider using a dictionary instead.
-        self.table_config = [
-            table_extrapolate,
-            table_stride_1,
-            table_stride_2,
-            check_frequency,
-        ]
+        self.table_config = {
+            "extrapolate": table_extrapolate,
+            "stride_1": table_stride_1,
+            "stride_2": table_stride_2,
+            "check_frequency": check_frequency,
+        }
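A standalone sketch of the trade-off behind this suggestion is given below. The function names `make_table_config` and `as_legacy_list` are hypothetical; the default values are taken from the `enable_compression` signature quoted in this review. Whether downstream code in the project expects the positional list is not established here, so the converter is included as an assumption about how both forms could coexist.

```python
def make_table_config(
    table_extrapolate: float = 5,
    table_stride_1: float = 0.01,
    table_stride_2: float = 0.1,
    check_frequency: int = -1,
) -> dict:
    """Build the table configuration with named keys instead of positions."""
    return {
        "extrapolate": table_extrapolate,
        "stride_1": table_stride_1,
        "stride_2": table_stride_2,
        "check_frequency": check_frequency,
    }


def as_legacy_list(config: dict) -> list:
    """Convert back to the positional layout for any consumer that indexes by position."""
    return [
        config["extrapolate"],
        config["stride_1"],
        config["stride_2"],
        config["check_frequency"],
    ]
```

With named keys, a reader sees `config["stride_2"]` rather than `config[2]`, and reordering the fields cannot silently swap two values.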
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/pt/model/descriptor/dpa2.py (3 hunks)
🔇 Additional comments (3)
deepmd/pt/model/descriptor/dpa2.py (3)
34-36: LGTM: New imports are properly organized.
The additions of DPTabulate and ActivationFn imports are well-organized and follow the project's import style.
Also applies to: 41-42
313-313: LGTM: Proper initialization of compression flag.
The compress flag is correctly initialized to False in __init__, addressing potential AttributeError concerns.
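The AttributeError concern can be reproduced in isolation. The two classes below are hypothetical stand-ins, not the real descriptor; only the `compress` attribute name and the "Compression is already enabled." message come from the code under review.

```python
class CompressibleDescriptor:
    """Stand-in class: the flag exists from construction onward."""

    def __init__(self) -> None:
        self.compress = False

    def enable_compression(self) -> None:
        # Safe to read self.compress here because __init__ always sets it.
        if self.compress:
            raise ValueError("Compression is already enabled.")
        self.compress = True


class DescriptorWithoutFlag:
    """Stand-in class that forgets to set the flag in __init__."""

    def is_compressed(self) -> bool:
        # Raises AttributeError: the attribute was never assigned.
        return self.compress
```

Initializing the flag in `__init__` means every later read, including the guard against enabling compression twice, works on a freshly constructed object.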
922-924: Fix inconsistent error message
The error message doesn't match the condition being checked. The condition checks if tebd_input_mode != "strip", but the error message refers to "concat".
- raise RuntimeError(
- "Cannot compress model when repinit tebd_input_mode == 'concat'"
- )
+ raise RuntimeError(
+ "Cannot compress model when repinit tebd_input_mode != 'strip'"
+ )
Likely invalid or redundant comment.
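One way to keep the message and the condition from drifting apart is to build the message from the same value the condition tests. The sketch below is an illustration only: `check_tebd_input_mode` is a hypothetical free function, not part of the descriptor, and reporting the offending value is an addition not present in the suggested diff.

```python
REQUIRED_MODE = "strip"


def check_tebd_input_mode(mode: str) -> None:
    """Raise if the mode is anything other than the one compression supports."""
    if mode != REQUIRED_MODE:
        # The message reuses REQUIRED_MODE, so it always matches the condition,
        # and it reports the offending value rather than assuming "concat".
        raise RuntimeError(
            f"Cannot compress model when repinit tebd_input_mode != {REQUIRED_MODE!r} "
            f"(got {mode!r})"
        )
```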
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <[email protected]>
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (2)
deepmd/pt/model/descriptor/dpa2.py (2)
868-875: Add return type hint to method signature.
The method signature should include a return type hint for better type safety.
     def enable_compression(
         self,
         min_nbor_dist: float,
         table_extrapolate: float = 5,
         table_stride_1: float = 0.01,
         table_stride_2: float = 0.1,
         check_frequency: int = -1,
-    ) -> None:
+    ) -> None:  # noqa: D202
891-925: Consider refactoring validation checks into separate methods.
The pre-compression validation logic is complex and could be modularized for better maintainability. Consider extracting the checks into separate validation methods.
+    def _validate_compression_state(self) -> None:
+        if self.compress:
+            raise ValueError("Compression is already enabled.")
+    def _validate_repinit_settings(self) -> None:
+        if self.repinit.resnet_dt:
+            raise RuntimeError("Model compression error: repinit resnet_dt must be false!")
+        if self.repinit.attn_layer != 0:
+            raise RuntimeError("Cannot compress model when repinit attention layer is not 0.")
+        if self.repinit.tebd_input_mode != "strip":
+            raise RuntimeError("Cannot compress model when repinit tebd_input_mode != 'strip'")
+    def _validate_exclude_types(self) -> None:
+        for tt in self.repinit.exclude_types:
+            if (tt[0] not in range(self.repinit.ntypes)) or (tt[1] not in range(self.repinit.ntypes)):
+                raise RuntimeError(
+                    f"Repinit exclude types {tt} must within the number of atomic types {self.repinit.ntypes}!"
+                )
+        if self.repinit.ntypes * self.repinit.ntypes - len(self.repinit.exclude_types) == 0:
+            raise RuntimeError("Repinit empty embedding-nets are not supported in model compression!")
     def enable_compression(
         self,
         min_nbor_dist: float,
         table_extrapolate: float = 5,
         table_stride_1: float = 0.01,
         table_stride_2: float = 0.1,
         check_frequency: int = -1,
     ) -> None:
-        # do some checks before the mocel compression process
-        if self.compress:
-            raise ValueError("Compression is already enabled.")
-        assert (
-            not self.repinit.resnet_dt
-        ), "Model compression error: repinit resnet_dt must be false!"
-        for tt in self.repinit.exclude_types:
-            if (tt[0] not in range(self.repinit.ntypes)) or (
-                tt[1] not in range(self.repinit.ntypes)
-            ):
-                raise RuntimeError(
-                    "Repinit exclude types"
-                    + str(tt)
-                    + " must within the number of atomic types "
-                    + str(self.repinit.ntypes)
-                    + "!"
-                )
-        if (
-            self.repinit.ntypes * self.repinit.ntypes - len(self.repinit.exclude_types)
-            == 0
-        ):
-            raise RuntimeError(
-                "Repinit empty embedding-nets are not supported in model compression!"
-            )
-
-        if self.repinit.attn_layer != 0:
-            raise RuntimeError(
-                "Cannot compress model when repinit attention layer is not 0."
-            )
-
-        if self.repinit.tebd_input_mode != "strip":
-            raise RuntimeError(
-                "Cannot compress model when repinit tebd_input_mode == 'concat'"
-            )
+        self._validate_compression_state()
+        self._validate_repinit_settings()
+        self._validate_exclude_types()
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/pt/model/descriptor/dpa2.py (3 hunks)
🔇 Additional comments (5)
deepmd/pt/model/descriptor/dpa2.py (5)
34-36: LGTM: New imports are properly organized.
The added imports for DPTabulate and ActivationFn are correctly placed and necessary for the new compression functionality.
Also applies to: 41-41
313-313: LGTM: Properly initialized compression flag.
The compress attribute is correctly initialized with a default value of False, addressing the potential AttributeError issue mentioned in past review comments.
876-876: Fix typo in docstring.
The word "statistics" is misspelled.
926-948: LGTM: Compression setup looks correct.
The compression setup logic is well-structured:
- Serializes the model data
- Initializes the tabulation with proper parameters
- Builds the tables with provided configuration
- Enables compression on the repinit module
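To give a feel for what "building the tables" means, here is a deliberately simplified sketch of tabulating a function on two grids and looking values up by interpolation. It is not the DPTabulate implementation. The reading of the parameters (a fine grid with `stride_1` up to `upper`, then a coarse grid with `stride_2` up to `extrapolate * upper`) is an assumption based on their names, linear interpolation is substituted for whatever scheme the project actually uses, and `build_table` and `lookup` are hypothetical names.

```python
import bisect
import math


def build_table(fn, lower, upper, extrapolate, stride_1, stride_2):
    """Sample fn on a fine grid over [lower, upper] and a coarse grid beyond it."""
    n_fine = round((upper - lower) / stride_1)
    xs = [lower + i * stride_1 for i in range(n_fine + 1)]
    far = extrapolate * upper  # assumed meaning of the extrapolation factor
    n_coarse = math.ceil((far - xs[-1]) / stride_2)
    last = xs[-1]
    xs += [last + i * stride_2 for i in range(1, n_coarse + 1)]
    return xs, [fn(x) for x in xs]


def lookup(table, x):
    """Linearly interpolate the tabulated values at x; reject out-of-range input."""
    xs, ys = table
    if not xs[0] <= x <= xs[-1]:
        raise ValueError("x is outside the tabulated range")
    # Index of the right-hand grid point of the interval containing x.
    i = min(bisect.bisect_right(xs, x), len(xs) - 1)
    x0, x1 = xs[i - 1], xs[i]
    t = (x - x0) / (x1 - x0)
    return ys[i - 1] + t * (ys[i] - ys[i - 1])
```

For example, `build_table(math.tanh, 0.0, 1.0, 5, 0.01, 0.1)` uses the default strides from the signature above; a lookup then replaces a network evaluation with a table read, at the cost of a small interpolation error that grows with the stride.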
921-924: Fix inconsistent error message.
The error message doesn't match the condition being checked. When tebd_input_mode != "strip", the error message incorrectly states it's "concat".
Summary by CodeRabbit
New Features
Tests