fix: consistent DPA-1 model #4320
Conversation
Fix deepmodeling#4022. Note that `smooth_type_embedding==True` is not consistent between TF and the other backends. This PR also fixes several related issues. Signed-off-by: Jinzhe Zeng <[email protected]>
📝 Walkthrough

The pull request introduces several modifications across multiple files, primarily focusing on enhancing descriptor handling.
Actionable comments posted: 6
🧹 Outside diff range and nitpick comments (9)
deepmd/jax/model/model.py (1)
40-40: Consider adding type validation.

While setting `ntypes` is correct, consider adding validation to ensure `type_map` is not empty and contains unique types.

```diff
+ if not data["type_map"]:
+     raise ValueError("type_map cannot be empty")
+ if len(set(data["type_map"])) != len(data["type_map"]):
+     raise ValueError("type_map contains duplicate types")
  data["descriptor"]["ntypes"] = len(data["type_map"])
```

deepmd/tf/utils/type_embed.py (1)
Line range hint 1-394: Consider modernizing the TensorFlow implementation.

While not directly related to the current changes, there are some architectural improvements that could be considered for future updates:

- The code uses TensorFlow 1.x style APIs (tf.variable_scope, etc.). Consider modernizing to TF 2.x patterns for better maintainability and performance.
- The versioning system could be enhanced to better handle backward compatibility, especially around serialization changes.
- The error handling for version mismatches could be more informative.

These improvements would make the codebase more maintainable and future-proof. Consider creating separate issues to track these modernization efforts. A hypothetical sketch of the TF 2.x style appears below.
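To make the suggestion concrete, here is a minimal, hypothetical sketch of what a TF 2.x-style type embedding net could look like; the class and variable names are illustrative, not the actual deepmd API:

```python
import tensorflow as tf


class TypeEmbedNetV2(tf.Module):
    """Hypothetical TF 2.x rewrite sketch; not the deepmd implementation."""

    def __init__(self, ntypes: int, dim: int, name=None):
        super().__init__(name=name)
        # tf.Variable replaces tf.get_variable inside a tf.variable_scope
        self.embedding = tf.Variable(
            tf.random.normal([ntypes, dim]), name="type_embedding"
        )

    def __call__(self, atype: tf.Tensor) -> tf.Tensor:
        # look up one embedding row per atom type index
        return tf.nn.embedding_lookup(self.embedding, atype)


emb = TypeEmbedNetV2(ntypes=2, dim=8)
out = emb(tf.constant([0, 1, 1]))  # shape (3, 8)
```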
deepmd/dpmodel/descriptor/hybrid.py (1)

46-47: LGTM! Consider enhancing parameter documentation.

The addition of `type_map` and `ntypes` parameters aligns well with the PR objective of fixing type embedding consistency across frameworks. Consider updating the class docstring to document these new parameters:

```diff
 Parameters
 ----------
 list : list[Union[BaseDescriptor, dict[str, Any]]]
     Build a descriptor from the concatenation of the list of descriptors.
     The descriptor can be either an object or a dictionary.
+type_map : Optional[list[str]], optional
+    List of atom type names, by default None
+ntypes : Optional[int], optional
+    Number of atom types, by default None
```

deepmd/tf/descriptor/hybrid.py (1)
464-465: Enhance error message with more context and alternatives.

While the error check is correct, the error message could be more informative to help users understand why this combination is not supported and what alternatives they might consider.

```diff
- raise NotImplementedError("hybrid + type embedding is not supported")
+ raise NotImplementedError(
+     "Hybrid descriptors with type embedding are not supported due to framework compatibility issues. "
+     "Consider using individual descriptors with type embedding instead."
+ )
```

deepmd/tf/model/model.py (1)
843-844: LGTM! Consider adding error handling for missing neuron dimensions.

The type embedding handling during serialization looks correct. The code properly assigns the type embedding to the descriptor and updates the fitting dimensions. Consider adding a validation check for `self.typeebd.neuron` to handle cases where the neuron dimensions might not be initialized:

```diff
 if self.typeebd is not None:
+    if not hasattr(self.typeebd, 'neuron') or not self.typeebd.neuron:
+        raise ValueError("Type embedding neuron dimensions not initialized")
     self.descrpt.type_embedding = self.typeebd
     self.fitting.tebd_dim = self.typeebd.neuron[-1]
```

source/tests/consistent/model/test_dpa1.py (1)
160-160: Correct typographical error in comment

There's a minor typographical error in the comment on line 160:

```diff
- # TF requires the atype to be sort
+ # TF requires the atype to be sorted
```

Updating the comment improves code readability and clarity.
deepmd/tf/descriptor/se_atten.py (3)
226-226: Typo in Exception Message: Correct 'Disbaling' to 'Disabling'

There's a typographical error in the exception message at line 226. The word 'Disbaling' should be corrected to 'Disabling'. Apply this diff to fix the typo:

```diff
- raise NotImplementedError("Disbaling concat_output_tebd is not supported.")
+ raise NotImplementedError("Disabling concat_output_tebd is not supported.")
```
2008-2010: Clarify Exception Message for Unsupported Feature

The exception message at lines 2008-2010 could be improved for clarity. Consider rephrasing it to "Only single-layer type embedding networks are supported." for better understanding. Apply this diff:

```diff
 raise NotImplementedError(
-    "Only support single layer type embedding network"
+    "Only single-layer type embedding networks are supported"
 )
```
1884-1887: Avoid Modifying Input Parameter `data` to Prevent Side Effects

At lines 1884-1887, the code modifies the input dictionary `data` by setting a default value for `"use_tebd_bias"`. Modifying input parameters can lead to unintended side effects if `data` is used elsewhere. Consider creating a copy of `data` or setting a local variable to avoid altering the original input. Apply this diff to prevent modifying `data` directly:

```diff
 if "use_tebd_bias" not in data:
     # v1 compatibility
-    data["use_tebd_bias"] = True
+    use_tebd_bias = True
 else:
+    use_tebd_bias = data.pop("use_tebd_bias")
-type_embedding.use_tebd_bias = data.pop("use_tebd_bias")
+type_embedding.use_tebd_bias = use_tebd_bias
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (13)

- deepmd/dpmodel/descriptor/hybrid.py (2 hunks)
- deepmd/dpmodel/model/model.py (1 hunks)
- deepmd/jax/descriptor/__init__.py (2 hunks)
- deepmd/jax/model/model.py (1 hunks)
- deepmd/pt/model/task/fitting.py (1 hunks)
- deepmd/tf/descriptor/hybrid.py (2 hunks)
- deepmd/tf/descriptor/se_atten.py (8 hunks)
- deepmd/tf/descriptor/se_atten_v2.py (4 hunks)
- deepmd/tf/fit/ener.py (4 hunks)
- deepmd/tf/model/model.py (3 hunks)
- deepmd/tf/utils/type_embed.py (1 hunks)
- doc/backend.md (0 hunks)
- source/tests/consistent/model/test_dpa1.py (1 hunks)
💤 Files with no reviewable changes (1)
- doc/backend.md
🔇 Additional comments (20)
deepmd/jax/descriptor/__init__.py (2)
11-13: LGTM! Import follows consistent style.

The import of `DescrptSeAttenV2` follows the established pattern and maintains consistency with other descriptor imports.

33-33: LGTM! Proper exposure of the new descriptor.

The addition of `DescrptSeAttenV2` to `__all__` correctly exposes the new descriptor class while maintaining alphabetical ordering.
Let's verify the module structure and implementation:
✅ Verification successful
Let me gather more information to verify the class implementation since the previous ast-grep command didn't show the actual implementation.
Proper exposure and implementation of DescrptSeAttenV2 confirmed
The verification shows that:

- The class is properly implemented in `se_atten_v2.py`, inheriting from `DescrptDPA1` and `DescrptSeAttenV2DP`
- It's correctly registered with the `@BaseDescriptor.register("se_atten_v2")` decorator (see the sketch after this list)
- The class is properly imported and exposed in `__init__.py`
- The `__all__` list maintains alphabetical ordering
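For readers unfamiliar with the pattern, this is a minimal sketch of how such a register decorator typically works; the real `BaseDescriptor` in deepmd may differ in detail:

```python
class BaseDescriptor:
    """Sketch of a plugin registry; illustrative, not the deepmd source."""

    _registry: dict = {}

    @classmethod
    def register(cls, key: str):
        def decorator(subclass):
            # map the config "type" string to the implementing class
            cls._registry[key] = subclass
            return subclass

        return decorator

    @classmethod
    def get_class_by_type(cls, key: str):
        return cls._registry[key]


@BaseDescriptor.register("se_atten_v2")
class DescrptSeAttenV2(BaseDescriptor):
    pass


assert BaseDescriptor.get_class_by_type("se_atten_v2") is DescrptSeAttenV2
```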
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify the se_atten_v2 module exists and contains the DescrptSeAttenV2 class
# Test 1: Check if the module file exists
fd "se_atten_v2.py" deepmd/jax/descriptor/
# Test 2: Verify class implementation
ast-grep --pattern 'class DescrptSeAttenV2' deepmd/jax/descriptor/se_atten_v2.py
Length of output: 162
Script:
#!/bin/bash
# Let's check the actual content of the se_atten_v2.py file
cat deepmd/jax/descriptor/se_atten_v2.py
# Also verify the import in __init__.py to ensure proper exposure
cat deepmd/jax/descriptor/__init__.py
Length of output: 1184
deepmd/jax/model/model.py (2)
40-40: Consistent implementation across model types.

The same `ntypes` assignment logic is correctly applied to both standard and ZBL models, maintaining consistency in type handling.

Also applies to: 63-63

40-40: LGTM: Explicit type count enhances model consistency.

Setting `ntypes` explicitly from the `type_map` length ensures consistent type handling across different backends, which is crucial for type embedding support.
Let's verify this change is consistent with other backend implementations:
✅ Verification successful
Based on the search results, I can now generate the final response as we have sufficient evidence to verify the consistency of ntypes
handling across different backend implementations:
✓ Setting ntypes from type_map length is consistent across backends
The change is correctly implemented and matches the pattern used in other backend implementations (a minimal sketch follows this list):

- JAX backend: `data["descriptor"]["ntypes"] = len(data["type_map"])` in both `get_standard_model` and `get_zbl_model`
- DPModel backend: Same pattern in `deepmd/dpmodel/model/model.py`
- PyTorch backend: Uses `ntypes = len(type_map)` consistently in model initialization
- TensorFlow backend: Sets `ntypes` from `len(self.get_type_map())` in model classes
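As a minimal sketch of the shared pattern (the helper name is hypothetical):

```python
def normalize_model_data(data: dict) -> dict:
    """Hypothetical helper mirroring the shared backend pattern."""
    if not data.get("type_map"):
        raise ValueError("type_map cannot be empty")
    # every backend derives ntypes from the length of type_map
    data["descriptor"]["ntypes"] = len(data["type_map"])
    return data


data = {"type_map": ["O", "H"], "descriptor": {"type": "se_atten_v2"}}
assert normalize_model_data(data)["descriptor"]["ntypes"] == 2
```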
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check if ntypes is consistently set across different backend implementations
# Expected: Similar ntypes assignment in other backend model files
# Search for model.py files in different backend directories
fd -t f "model.py" -x rg -A 5 "descriptor\[.\"'\]ntypes.*=.*len.*type_map"
Length of output: 76
Script:
#!/bin/bash
# Let's try a broader search to understand how ntypes is handled across backends
# First, let's find all model.py files
echo "=== Model files ==="
fd "model.py$"
echo -e "\n=== Searching for ntypes handling ==="
# Search for ntypes assignments in these files
rg -l "ntypes" $(fd "model.py$")
echo -e "\n=== Detailed ntypes context ==="
# Get context around ntypes usage
rg -B 2 -A 2 "ntypes" $(fd "model.py$")
Length of output: 58352
deepmd/dpmodel/model/model.py (2)
40-40: LGTM! Improved type handling consistency.

The explicit setting of `ntypes` and unified descriptor creation through `BaseDescriptor` aligns well with the PR's objective to fix type embedding inconsistencies between frameworks.

Also applies to: 43-45
43-45: Verify backward compatibility with existing models.

The switch to using `BaseDescriptor` uniformly needs verification to ensure:

- All existing descriptor types work correctly
- Type embedding is properly handled in `BaseDescriptor`
- Models can still be converted between frameworks
✅ Verification successful
Based on the search results, I can now generate the final response:
Backward compatibility verification completed - no issues found

The codebase shows strong evidence of backward compatibility:

- All descriptor types are properly handled through the `BaseDescriptor` interface, with existing descriptors (se_a, se_atten, etc.) still maintained in tests
- Type embedding is well supported, with the `smooth_type_embedding` parameter properly implemented across TF/PyTorch/JAX backends
- Model conversion between frameworks is supported through the `convert-backend` command, with extensive test coverage (a minimal sketch of the shared-format idea follows this list)
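A minimal sketch of the idea behind such cross-backend conversion, assuming a backend-neutral serialized dict; all class names here are illustrative:

```python
class SerializableDescriptor:
    """Illustrative base: each backend consumes the same serialized dict."""

    @classmethod
    def deserialize(cls, data: dict):
        obj = cls()
        obj.ntypes = data["ntypes"]
        return obj


class JaxDescriptor(SerializableDescriptor):
    pass


class TfDescriptor(SerializableDescriptor):
    pass


BACKENDS = {"jax": JaxDescriptor, "tf": TfDescriptor}


def convert_backend(data: dict, target: str):
    # conversion reduces to deserializing the backend-neutral format
    return BACKENDS[target].deserialize(data)


model = convert_backend({"type": "se_atten_v2", "ntypes": 2}, "jax")
assert model.ntypes == 2
```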
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for potential impacts on existing models and type embedding
# Test 1: Check for any remaining references to removed descriptor types
echo "Checking for references to removed descriptor types..."
rg "DescrptSeA|se_atten|se_atten_v2" --type py
# Test 2: Check for type embedding related configurations in tests
echo "Checking for type embedding tests..."
rg "smooth_type_embedding.*True" --type py
# Test 3: Look for model conversion tests
echo "Checking for model conversion tests..."
rg "convert.*backend|dp\s+convert" --type py
Length of output: 62838
deepmd/tf/utils/type_embed.py (1)
330-332: Optimization: Skip serialization of default identity weights.
This change improves the consistency of type embedding serialization across different frameworks by explicitly skipping the serialization of identity weights that are set to 0.0. This is a good optimization as:
- It prevents unnecessary data in the serialized model
- Aligns with the PR objective of making type embedding handling consistent between TensorFlow and other frameworks
- The comment clearly explains the purpose of this check
However, there are a few considerations:
Let's verify if this change maintains backward compatibility with existing models:
Consider adding a version check to ensure older versions of the framework can still load models serialized with this change. This could be done by:

- Incrementing the `@version` field when this behavior changes
- Adding appropriate version handling in the `deserialize` method

A minimal sketch of the skip-on-default serialization discussed above appears below.
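A minimal sketch of the skip-on-default idea, assuming weights are stored in a dict; the keys and layout are illustrative:

```python
import numpy as np


def serialize_layer(weights: dict) -> dict:
    """Omit idt weights that still hold their 0.0 default."""
    data = {"w": weights["w"].tolist(), "b": weights["b"].tolist()}
    idt = weights.get("idt")
    if idt is not None and not np.allclose(idt, 0.0):
        data["idt"] = idt.tolist()  # only persist non-default idt
    return data


layer = {"w": np.eye(2), "b": np.zeros(2), "idt": np.zeros(2)}
assert "idt" not in serialize_layer(layer)
```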
✅ Verification successful
Based on the gathered information, I can now provide the final response:
The optimization to skip serialization of zero identity weights is safe and consistent.
The change is well-aligned with the framework's design and implementation because:

- Version compatibility is properly maintained through `check_version_compatibility` checks across frameworks (TF, JAX, and dpmodel)
- The deserialize method in array_api_strict confirms consistent handling across implementations
- Zero weights are explicitly handled as default values during deserialization
- No existing tests or models were found that rely on serialized zero identity weights

The change improves consistency across frameworks while maintaining backward compatibility through proper version checks (supporting versions 1 and 2).
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for any existing models that might rely on the serialized 0.0 idt weights
rg -g '*.json' -g '*.pb' '"idt":\s*0\.0'
Length of output: 42
Script:
#!/bin/bash
# Let's verify the serialization and deserialization logic more comprehensively
# 1. Check the deserialize method to understand how 0.0 values are handled
ast-grep --pattern 'def deserialize(self, data):
$$$
'
# 2. Look for any tests that verify serialization/deserialization
rg -l "test.*serialize"
# 3. Check if there are any version checks in the type embedding
rg "version|@version" deepmd/tf/utils/type_embed.py
# 4. Look for any documentation about the serialization format
fd "README|CHANGELOG" --type f --exec grep -l "type.*embed|serialize" {} \;
Length of output: 753
Script:
#!/bin/bash
# Let's examine the deserialize method and version handling more closely
# 1. Check the full context of deserialize method in type_embed.py
rg -B 5 -A 15 "def deserialize" deepmd/tf/utils/type_embed.py
# 2. Look for any test cases specifically for TypeEmbedNet
rg -l "TypeEmbedNet.*test|test.*TypeEmbedNet"
# 3. Check how idt weights are handled in tests
rg -C 5 "idt.*=.*0\.0|idt.*:.*0\.0" source/tests/
# 4. Look for any related changes in other frameworks (pytorch)
fd "type_embed" --type f --exec rg -l "serialize" {} \;
Length of output: 903
Script:
#!/bin/bash
# Let's check the serialization implementations across different frameworks
# 1. Check how other frameworks handle idt weights in serialization
rg -C 5 "def serialize" deepmd/jax/utils/type_embed.py deepmd/dpmodel/utils/type_embed.py
# 2. Check version handling in deserialization across frameworks
rg -C 5 "check_version_compatibility" deepmd/jax/utils/type_embed.py deepmd/dpmodel/utils/type_embed.py
# 3. Look for any related test cases in array_api_strict
rg -C 5 "serialize|deserialize" source/tests/array_api_strict/utils/type_embed.py
# 4. Check if there are any model compatibility tests
rg -l "test.*compatibility|compatibility.*test"
Length of output: 2765
deepmd/dpmodel/descriptor/hybrid.py (1)
61-63: LGTM! Verify backward compatibility.

The changes correctly propagate type information to child descriptors, which is essential for consistent type embedding support.
Let's verify that existing models without type information still work:
✅ Verification successful
Based on the search results from previous scripts, I can now make a determination:
Type information propagation is safe and backward compatible
The code changes are safe because (see the sketch after this list):
- The hybrid descriptor implementation properly handles both cases - with and without type information
- Test cases show successful initialization with and without type parameters
- The type information is only propagated when provided, maintaining compatibility with existing models
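A minimal sketch of that propagation logic; the function and parameter names are hypothetical:

```python
def propagate_type_info(child_configs: list, type_map=None, ntypes=None) -> list:
    """Forward type_map/ntypes to children only when they were provided."""
    children = []
    for cfg in child_configs:
        cfg = dict(cfg)  # avoid mutating the caller's config
        if type_map is not None:
            cfg.setdefault("type_map", type_map)
        if ntypes is not None:
            cfg.setdefault("ntypes", ntypes)
        children.append(cfg)
    return children


# older configs without type info keep working: nothing is injected
assert "ntypes" not in propagate_type_info([{"type": "se_e2_a"}])[0]
assert propagate_type_info([{"type": "se_e2_a"}], ntypes=2)[0]["ntypes"] == 2
```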
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for existing model usage without type_map/ntypes
rg -l "DescrptHybrid.*list\s*=.*\}" --type py
Length of output: 47
Script:
#!/bin/bash
# Let's try a broader search for DescrptHybrid usage patterns
rg "DescrptHybrid" --type py -B 2 -A 2
# Also search for test files that might show usage examples
fd "test.*\.py" --exec rg "DescrptHybrid" {} -B 2 -A 2
# Look for descriptor list initialization patterns
rg "descript_list.*=.*\[" --type py -B 2 -A 2
Length of output: 52603
Script:
#!/bin/bash
# Let's look for test cases that verify backward compatibility
rg "test.*backward.*compatibility" --type py -B 2 -A 2
# Also search for test cases that use DescrptHybrid without type_map/ntypes
ast-grep --pattern 'DescrptHybrid\(\s*list\s*=\s*\[.*\]\s*\)'
Length of output: 120
deepmd/tf/descriptor/hybrid.py (1)
Line range hint 464-493: Verify impact on existing models.

The addition of type embedding checks could affect existing models that use hybrid descriptors.
Let's check for potential usage patterns in the codebase:
✅ Verification successful
Based on the search results, I can now generate the final response since I have enough information to verify the impact of type embedding checks on existing models:
Type embedding check is safe and won't affect existing models
The added type embedding check is consistent with existing behavior since:
- All example configurations and test files show hybrid descriptors are used without type embedding:
examples/water/hybrid/input.json
examples/dprc/generalized_force/input.json
examples/dprc/normal/input.json
source/tests/tf/water_hybrid.json
- The error message was already effectively enforced through the implementation, as evidenced by test files explicitly checking this limitation:
- Test files contain comments like "unsupported for SpinModel to hybrid both mixed_types and no-mixed_types descriptor"
- Multiple test classes verify this behavior across different frameworks (TF, PT, JAX)
- The change only makes the limitation more explicit by failing early with a clear error message rather than failing later during execution.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Search for hybrid descriptor usage with type embedding
# Expected results: No instances of hybrid descriptors with type embedding
# Test 1: Look for hybrid descriptor configurations in test files
echo "Searching for hybrid descriptor test configurations..."
rg -A 10 '"type":\s*"hybrid"' --type json
# Test 2: Look for hybrid descriptor instantiations in Python files
echo "Searching for hybrid descriptor usage in Python files..."
rg -B 5 -A 5 'DescrptHybrid|"type":\s*"hybrid"' --type py
Length of output: 75747
deepmd/pt/model/task/fitting.py (2)
422-423: LGTM: Improved error message formatting

The change to f-strings improves readability while maintaining the same error information.

422-423: Verify alignment with PR objectives

While the error message improvements are good, they seem tangential to the main PR objectives of fixing DPA-1 model consistency issues. Please confirm if there are other changes needed in this file to address the core objectives.
Let's check for any DPA-1 model related code:
deepmd/tf/model/model.py (2)
825-825: LGTM! Type embedding parameter properly added.

The type embedding parameter is correctly added to the class constructor, maintaining consistency with the deserialization logic.

811-814: LGTM! Verify type embedding dimension adjustment.

The type embedding handling during deserialization looks correct. The code properly passes the descriptor's type embedding to the model and adjusts the descriptor dimension to avoid double-counting.
Let's verify the dimension adjustment logic:
✅ Verification successful
Let me gather more context about the type embedding dimension handling.
Let me gather more information about the type embedding and descriptor dimensions.
Let me try a different search pattern to understand the type embedding and dimension handling.
Based on the code analysis, I can now provide a definitive response about the type embedding dimension adjustment.
Type embedding dimension adjustment is correct and necessary
The dimension adjustment in the deserialization code is correct because:

- The descriptor's output dimension (`get_dim_out()`) includes the type embedding dimension when `explicit_ntypes` is true
- The fitting network expects the raw descriptor dimension without the type embedding dimension
- The adjustment `fitting.dim_descrpt -= type_embedding.neuron[-1]` correctly subtracts the type embedding dimension from the fitting network's input dimension to avoid double-counting

The code shows consistent handling across different descriptor implementations (se_a, se_t, se_atten, etc.) where:

- Descriptors with `explicit_ntypes=True` concatenate type embedding to their output
- The fitting network needs the original descriptor dimension without the type embedding
- The type embedding is passed separately to the fitting network via the model

A worked numeric example of this arithmetic follows.
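A worked numeric example of the adjustment, under the assumptions stated above; the concrete dimensions are made up for illustration:

```python
raw_descriptor_dim = 128  # descriptor output without type embedding
tebd_neuron = [8]         # single-layer type embedding network

# with explicit_ntypes, get_dim_out() reports the concatenated size
dim_out = raw_descriptor_dim + tebd_neuron[-1]  # 136

# the deserialization code: fitting.dim_descrpt -= type_embedding.neuron[-1]
fitting_dim_descrpt = dim_out - tebd_neuron[-1]

assert fitting_dim_descrpt == raw_descriptor_dim  # no double-counting
```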
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for any other places where descriptor dimension is adjusted
rg -A 5 "dim_descrpt.*=.*type_embedding"
Length of output: 441
Script:
#!/bin/bash
# Search for type_embedding related code
ast-grep --pattern 'type_embedding'
# Search for dim_descrpt usage
rg -B 3 -A 3 'dim_descrpt'
Length of output: 116453
Script:
#!/bin/bash
# Search for type embedding neuron dimension
ast-grep --pattern 'type_embedding.*neuron'
# Search for get_dim_out in descriptor classes
ast-grep --pattern 'def get_dim_out'
Length of output: 84
Script:
#!/bin/bash
# Search for TypeEmbedNet class definition and usage
ast-grep --pattern 'class TypeEmbedNet'
# Search for explicit_ntypes usage
rg -B 3 -A 3 'explicit_ntypes'
# Search for get_dim_out method calls
rg -B 3 -A 3 'get_dim_out\('
Length of output: 91537
deepmd/tf/fit/ener.py (3)
245-245: LGTM: New attribute for type embedding dimension.

The addition of the `tebd_dim` attribute with default value 0 is appropriate for tracking the type embedding dimension.

758-759: LGTM: Proper dimension recovery.

Correctly restores the original descriptor dimension after processing, which is essential for proper serialization.

914-914: LGTM: Consistent dimension handling in serialization.

The changes properly account for the type embedding dimension in both the descriptor dimension and network input dimension calculations during serialization.
Let's verify the consistency of type embedding dimension handling across the codebase:
Also applies to: 936-936
✅ Verification successful
Let me gather more information about the type embedding dimension handling by searching for related patterns.
Based on the extensive search results, I can now provide a final response:
Verified: The dimension handling is consistent across the codebase

The type embedding dimension (`tebd_dim`) is correctly added to the descriptor dimension (`dim_descrpt`) in both the serialization and network input dimension calculations. This pattern is consistently implemented across:

- Energy fitting (`ener.py`): Adds `tebd_dim` to `dim_descrpt` for both serialization and network input
- DOS fitting (`dos.py`): Properly handles combined dimensions for network input
- General fitting base class: Consistently handles dimension combinations
- PyTorch implementation: Follows the same dimension handling patterns

The change ensures proper dimensionality when using type embeddings while maintaining consistency across different model components and backends. A minimal sketch of the serialization-side arithmetic follows.
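Conversely, a minimal sketch of the serialization side; the attribute names follow the review, but the class itself is illustrative:

```python
class FittingSketch:
    """Illustrative stand-in for the fitting object described above."""

    def __init__(self, dim_descrpt: int, tebd_dim: int = 0):
        self.dim_descrpt = dim_descrpt
        self.tebd_dim = tebd_dim

    def serialized_in_dim(self) -> int:
        # the serialized network input adds the type embedding back on
        return self.dim_descrpt + self.tebd_dim


assert FittingSketch(dim_descrpt=128, tebd_dim=8).serialized_in_dim() == 136
```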
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for consistent type embedding dimension handling across different backends
# Test: Look for similar dimension handling patterns in other backend implementations
rg -A 2 "dim_descrpt.*tebd"
Length of output: 227
Script:
#!/bin/bash
# Search for type embedding dimension (tebd) related code
rg "tebd" -A 3
# Search for dimension descriptor handling
rg "dim_descrpt" -A 3
# Look for network input dimension calculations
ast-grep --pattern 'self.dim_descrpt + self.$_'
Length of output: 314480
deepmd/tf/descriptor/se_atten_v2.py (4)
7-9: Import of `TypeEmbedNet` is appropriate

The import statement correctly brings in `TypeEmbedNet`, which is necessary for handling type embeddings in the descriptor.

136-136: Version compatibility check is correctly implemented

The `check_version_compatibility` function is used appropriately to ensure that the serialized data is compatible with version 2.1, which helps maintain consistency and prevent errors during deserialization.

156-159: Backward compatibility for `use_tebd_bias` is well-handled

The code ensures backward compatibility by setting `use_tebd_bias` to `True` when it's not present in the data. This approach maintains expected behavior for older versions that may not include this key.

170-170: Assignment of `type_embedding` to descriptor is appropriate

The `type_embedding` instance is correctly assigned to `descriptor.type_embedding`, ensuring that the deserialized type embedding is incorporated into the descriptor object.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (3)
deepmd/tf/model/model.py (1)
845-846: LGTM! Consider adding docstring for type embedding fields.

The changes correctly handle type embedding during serialization by updating both the descriptor and fitting components. This ensures consistency between frameworks. Consider adding docstring comments to document these fields:

```diff
 if self.typeebd is not None:
+    # Pass type embedding to descriptor for consistent serialization
     self.descrpt.type_embedding = self.typeebd
+    # Update fitting dimension to include type embedding
     self.fitting.tebd_dim = self.typeebd.neuron[-1]
```

deepmd/tf/descriptor/se_atten.py (2)
1947-1954: Robust type embedding dimension handling

The code properly calculates the embedding input dimension based on the type embedding mode and configuration. However, the logic could be more maintainable. Consider extracting this logic into a separate method:

```diff
-tebd_dim = self.type_embedding.neuron[0]
-if self.tebd_input_mode in ["concat"]:
-    if not self.type_one_side:
-        embd_input_dim = 1 + tebd_dim * 2
-    else:
-        embd_input_dim = 1 + tebd_dim
-else:
-    embd_input_dim = 1
+def calculate_embedding_input_dim(tebd_dim: int) -> int:
+    if self.tebd_input_mode in ["concat"]:
+        return 1 + tebd_dim * (2 if not self.type_one_side else 1)
+    return 1
+
+tebd_dim = self.type_embedding.neuron[0]
+embd_input_dim = calculate_embedding_input_dim(tebd_dim)
```
2012-2018: Validation of type embedding network structure

Good validation check for single-layer type embedding networks. However, consider providing more guidance in the error message:

```diff
-raise NotImplementedError(
-    "Only support single layer type embedding network"
-)
+raise NotImplementedError(
+    "Multi-layer type embedding networks are not supported. Please use a single layer network with neuron=[tebd_dim]"
+)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (4)

- deepmd/dpmodel/descriptor/hybrid.py (2 hunks)
- deepmd/tf/descriptor/se_atten.py (9 hunks)
- deepmd/tf/descriptor/se_atten_v2.py (4 hunks)
- deepmd/tf/model/model.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/dpmodel/descriptor/hybrid.py
🔇 Additional comments (10)
deepmd/tf/descriptor/se_atten_v2.py (5)
7-9: LGTM: Clean import and well-documented class.

The import of `TypeEmbedNet` and the class documentation are well-structured and complete.

Line range hint 89-117: LGTM: Constructor properly enforces consistent configuration.

The constructor correctly sets `smooth_type_embedding=True` and `tebd_input_mode="strip"` to ensure consistency across backends, which aligns with the PR objectives.

133-135: LGTM: Clear error handling and version requirements.

The RuntimeError and version compatibility check (2.1) effectively enforce the intended backend consistency requirements.

Also applies to: 139-139

Line range hint 176-192: LGTM: Clean serialization implementation.

The serialization method correctly removes the fixed configuration values and updates the type identifier.

156-162: Previous review comment still applies: Handle potential KeyError.

The code is still vulnerable to a KeyError when accessing 'type_embedding'. Consider implementing the suggested error handling from the previous review. Apply this diff as suggested:

```diff
+ if "type_embedding" not in data:
+     raise KeyError("'type_embedding' key not found in data during deserialization.")
  type_embedding = TypeEmbedNet.deserialize(
      data.pop("type_embedding"), suffix=suffix
  )
```
deepmd/tf/model/model.py (1)
811-816: LGTM! Consider adding error handling for dimension validation.

The changes correctly handle type embedding during deserialization by passing it from the descriptor to the model when `explicit_ntypes` is true. The dimension adjustment prevents double-counting of the embedding dimension.
Let's verify the type embedding dimension handling:
Also applies to: 827-827
✅ Verification successful
Based on the search results, I can now generate the final response:
The dimension adjustment is safe and correctly implemented.

The code correctly subtracts the type embedding dimension (`type_embedding.neuron[-1]`) from `fitting.dim_descrpt` when `explicit_ntypes` is true. This adjustment is necessary because:

- The type embedding dimension is properly defined in the descriptor (see `deepmd/tf/descriptor/se_atten.py` and `deepmd/tf/descriptor/se_a.py`)
- The fitting networks consistently use `dim_descrpt` as the input dimension across all implementations
- The dimension adjustment only happens when `explicit_ntypes` is true, preventing any unintended modifications
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for potential dimension mismatches in type embedding usage
rg -A 5 "dim_descrpt\s*[-+]=\s*"
Length of output: 407
Script:
#!/bin/bash
# Let's check how type_embedding.neuron is used and defined
rg -A 5 "type_embedding.*neuron"
# Also check where dim_descrpt is used to ensure dimension consistency
rg -A 5 "dim_descrpt\s*=\s*"
# Look for any validation of dimensions
ast-grep --pattern 'if $_.dim_descrpt $_'
Length of output: 35719
deepmd/tf/descriptor/se_atten.py (4)
1885-1891: Improved version compatibility handling for type embedding

Good addition of version compatibility handling for `use_tebd_bias`. The code properly handles both v1 and v2 formats by setting a default value of True for v1 compatibility. A minimal sketch of this default appears below.
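A minimal sketch of that compatibility default; the helper name is hypothetical:

```python
def upgrade_v1_data(data: dict) -> dict:
    """v1 serialized models predate "use_tebd_bias"; default it to True."""
    data = dict(data)  # leave the caller's dict untouched
    data.setdefault("use_tebd_bias", True)
    return data


assert upgrade_v1_data({})["use_tebd_bias"] is True
assert upgrade_v1_data({"use_tebd_bias": False})["use_tebd_bias"] is False
```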
1902-1911: Proper handling of stripped type embedding

The code correctly handles the stripped type embedding mode by deserializing the embeddings with the appropriate `type_one_side` parameter. This ensures compatibility with different model configurations.

1927-1929: Consistent error handling for smooth type embedding

The error handling for smooth type embedding is consistent with the deserialization logic, ensuring that incompatible configurations are caught early.
2038-2084: Comprehensive attention layer serialization update

The code properly updates the attention layer serialization to maintain consistency with other backend references. The default values and attention layer parameters are well-structured. However, consider adding validation for the default values.
✅ Verification successful
Based on the search results, I can now provide a final response:
Default values are consistently defined across backends

The verification shows that the default values in the serialization code are consistent with their usage across different backend implementations (TensorFlow, PyTorch, and Universal):

- `scaling_factor=1.0`
- `normalize=True`
- `temperature=None`
- `concat_output_tebd=True`
- `use_econf_tebd=False`

These values are properly synchronized in:

- TensorFlow: `deepmd/tf/descriptor/se_atten.py`
- PyTorch: `deepmd/pt/model/descriptor/se_atten.py`
- Universal: `deepmd/dpmodel/descriptor/se_atten_v2.py`, `deepmd/dpmodel/descriptor/dpa1.py`

The attention layer serialization implementation maintains consistency across all backend references. The defaults are collected in the hedged sketch below.
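Collected in one place, the verified defaults could be applied like this sketch; the constant name and the assumption that omitted keys should be back-filled are illustrative:

```python
ATTENTION_DEFAULTS = {
    "scaling_factor": 1.0,
    "normalize": True,
    "temperature": None,
    "concat_output_tebd": True,
    "use_econf_tebd": False,
}


def apply_defaults(data: dict) -> dict:
    # fill in any key an older serialized model omitted
    return {**ATTENTION_DEFAULTS, **data}


assert apply_defaults({"normalize": False})["scaling_factor"] == 1.0
```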
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Verify that the default values match across all backends
rg --type python "scaling_factor.*=.*1\.0|normalize.*=.*True|temperature.*=.*None|concat_output_tebd.*=.*True|use_econf_tebd.*=.*False"
Length of output: 168
Script:
#!/bin/bash
# Search for default values in Python files without specifying file type
rg "scaling_factor.*=.*1\.0|normalize.*=.*True|temperature.*=.*None|concat_output_tebd.*=.*True|use_econf_tebd.*=.*False" --type-add 'py:*.py' -t py
Length of output: 6221
Script:
#!/bin/bash
# Search for attention layer configuration in Python files
rg "attention_layers.*=|attention_layer.*=" --type-add 'py:*.py' -t py -C 2
Length of output: 14584
Script:
#!/bin/bash
# Find all Python files that might contain attention-related configurations
fd -e py | xargs grep -l "attention"
Length of output: 792
Codecov Report

Attention: Patch coverage is

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##            devel    #4320      +/-   ##
==========================================
+ Coverage   84.42%   84.47%   +0.05%
==========================================
  Files         570      570
  Lines       53080    53160      +80
  Branches     3054     3059       +5
==========================================
+ Hits        44813    44909      +96
+ Misses       7307     7289      -18
- Partials      960      962       +2
```

☔ View full report in Codecov by Sentry.
Fix #4022.

Note that `smooth_type_embedding==True` is not consistent between TF and the other backends. Also, fix several issues.
Summary by CodeRabbit

Release Notes

- New Features
  - Introduced `DescrptSeAttenV2` for advanced attention mechanisms.
- Bug Fixes
- Documentation
- Style
- Tests