Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: consistent DPA-1 model #4320

Merged
merged 6 commits into from
Nov 8, 2024

Conversation

njzjz
Copy link
Member

@njzjz njzjz commented Nov 7, 2024

Fix #4022.
Note that smooth_type_embedding==True is not consistent between TF and others.
Also, fix several issues.

Summary by CodeRabbit

Release Notes

  • New Features

    • Enhanced configurability of descriptors with new optional parameters for type mapping and type count.
    • Introduction of a new class DescrptSeAttenV2 for advanced attention mechanisms.
    • Added a new unit test framework for validating energy models across multiple backends.
  • Bug Fixes

    • Improved error handling in descriptor serialization methods to prevent unsupported operations.
  • Documentation

    • Updated backend documentation to include JAX support and clarify file extensions for various backends.
  • Style

    • Enhanced readability of error messages in fitting classes.
  • Tests

    • Comprehensive unit tests added for energy models across different machine learning frameworks.

Fix deepmodeling#4022.
Note that `smooth_type_embedding==True` is not consistent between TF and others.
Also fix several issues.

Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link
Contributor

coderabbitai bot commented Nov 7, 2024

📝 Walkthrough
📝 Walkthrough

Walkthrough

The pull request introduces several modifications across multiple files, primarily focusing on enhancing descriptor handling in the DescrptHybrid class and related model functions. Key changes include the addition of optional parameters for type mapping and type counts, the removal of specific descriptor types to simplify model creation, and updates to serialization and deserialization processes for various descriptor classes. Additionally, new unit tests for energy models across different backends have been added, along with documentation updates for backend support.

Changes

File Path Change Summary
deepmd/dpmodel/descriptor/hybrid.py Updated DescrptHybrid constructor to include type_map and ntypes. Added get_type_map method.
deepmd/dpmodel/model/model.py Removed DescrptSeA handling from get_standard_model, get_zbl_model, and get_spin_model.
deepmd/jax/descriptor/__init__.py Imported DescrptSeAttenV2 and updated __all__ list.
deepmd/jax/model/model.py Added ntypes assignment in get_standard_model and get_zbl_model.
deepmd/pt/model/task/fitting.py Updated error handling in _forward_common with f-strings. Initialized bias_atom_e tensor with zeros.
deepmd/tf/descriptor/hybrid.py Added checks for type_embedding in serialize and deserialize methods.
deepmd/tf/descriptor/se_atten.py Updated error messages and version compatibility checks in DescrptSeAtten.
deepmd/tf/descriptor/se_atten_v2.py Introduced DescrptSeAttenV2 class with updated initialization and deserialization.
deepmd/tf/fit/ener.py Added tebd_dim attribute to EnerFitting and updated serialization.
deepmd/tf/model/model.py Enhanced type embedding handling in deserialize and serialize methods of StandardModel.
deepmd/tf/utils/type_embed.py Updated serialize method to skip weights set to 0.0.
doc/backend.md Updated documentation for JAX backend and expanded DP backend section.
source/tests/consistent/model/test_dpa1.py Added unit tests for EnergyModel classes across multiple backends.

Assessment against linked issues

Objective Addressed Explanation
Support dp convert-backend for models with type embedding for tf (#[4022]) Changes do not implement support for type embedding in conversion.

Possibly related PRs

Suggested labels

Python, Docs

Suggested reviewers

  • iProzd
  • wanghan-iapcm

📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 6ae7169 and cc67fdd.

📒 Files selected for processing (1)
  • deepmd/tf/utils/type_embed.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/tf/utils/type_embed.py

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

🧹 Outside diff range and nitpick comments (9)
deepmd/jax/model/model.py (1)

40-40: Consider adding type validation.

While setting ntypes is correct, consider adding validation to ensure type_map is not empty and contains unique types.

+    if not data["type_map"]:
+        raise ValueError("type_map cannot be empty")
+    if len(set(data["type_map"])) != len(data["type_map"]):
+        raise ValueError("type_map contains duplicate types")
     data["descriptor"]["ntypes"] = len(data["type_map"])
deepmd/tf/utils/type_embed.py (1)

Line range hint 1-394: Consider modernizing the TensorFlow implementation.

While not directly related to the current changes, there are some architectural improvements that could be considered for future updates:

  1. The code uses TensorFlow 1.x style APIs (tf.variable_scope, etc.). Consider modernizing to TF 2.x patterns for better maintainability and performance.
  2. The versioning system could be enhanced to better handle backward compatibility, especially around serialization changes.
  3. The error handling for version mismatches could be more informative.

These improvements would make the codebase more maintainable and future-proof. Consider creating separate issues to track these modernization efforts.

deepmd/dpmodel/descriptor/hybrid.py (1)

46-47: LGTM! Consider enhancing parameter documentation.

The addition of type_map and ntypes parameters aligns well with the PR objective of fixing type embedding consistency across frameworks.

Consider updating the class docstring to document these new parameters:

 Parameters
 ----------
 list : list : list[Union[BaseDescriptor, dict[str, Any]]]
     Build a descriptor from the concatenation of the list of descriptors.
     The descriptor can be either an object or a dictionary.
+type_map : Optional[list[str]], optional
+    List of atom type names, by default None
+ntypes : Optional[int], optional
+    Number of atom types, by default None
deepmd/tf/descriptor/hybrid.py (1)

464-465: Enhance error message with more context and alternatives.

While the error check is correct, the error message could be more informative to help users understand why this combination is not supported and what alternatives they might consider.

-            raise NotImplementedError("hybrid + type embedding is not supported")
+            raise NotImplementedError(
+                "Hybrid descriptors with type embedding are not supported due to framework compatibility issues. "
+                "Consider using individual descriptors with type embedding instead."
+            )
deepmd/tf/model/model.py (1)

843-844: LGTM! Consider adding error handling for missing neuron dimensions.

The type embedding handling during serialization looks correct. The code properly assigns the type embedding to the descriptor and updates the fitting dimensions.

Consider adding a validation check for self.typeebd.neuron to handle cases where the neuron dimensions might not be initialized:

 if self.typeebd is not None:
+    if not hasattr(self.typeebd, 'neuron') or not self.typeebd.neuron:
+        raise ValueError("Type embedding neuron dimensions not initialized")
     self.descrpt.type_embedding = self.typeebd
     self.fitting.tebd_dim = self.typeebd.neuron[-1]
source/tests/consistent/model/test_dpa1.py (1)

160-160: Correct typographical error in comment

There's a minor typographical error in the comment on line 160:

-     # TF requires the atype to be sort
+     # TF requires the atype to be sorted

Updating the comment improves code readability and clarity.

deepmd/tf/descriptor/se_atten.py (3)

226-226: Typo in Exception Message: Correct 'Disbaling' to 'Disabling'

There's a typographical error in the exception message at line 226. The word 'Disbaling' should be corrected to 'Disabling'.

Apply this diff to fix the typo:

-                raise NotImplementedError("Disbaling concat_output_tebd is not supported.")
+                raise NotImplementedError("Disabling concat_output_tebd is not supported.")

2008-2010: Clarify Exception Message for Unsupported Feature

The exception message at lines 2008-2010 could be improved for clarity. Consider rephrasing it to "Only single-layer type embedding networks are supported." for better understanding.

Apply this diff:

                 raise NotImplementedError(
-                    "Only support single layer type embedding network"
+                    "Only single-layer type embedding networks are supported"
                 )

1884-1887: Avoid Modifying Input Parameter data to Prevent Side Effects

At lines 1884-1887, the code modifies the input dictionary data by setting a default value for "use_tebd_bias". Modifying input parameters can lead to unintended side effects if data is used elsewhere. Consider creating a copy of data or setting a local variable to avoid altering the original input.

Apply this diff to prevent modifying data directly:

             if "use_tebd_bias" not in data:
                 # v1 compatibility
-                data["use_tebd_bias"] = True
+                use_tebd_bias = True
             else:
+                use_tebd_bias = data.pop("use_tebd_bias")
-            type_embedding.use_tebd_bias = data.pop("use_tebd_bias")
+            type_embedding.use_tebd_bias = use_tebd_bias
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 430dfa9 and 4de5511.

📒 Files selected for processing (13)
  • deepmd/dpmodel/descriptor/hybrid.py (2 hunks)
  • deepmd/dpmodel/model/model.py (1 hunks)
  • deepmd/jax/descriptor/__init__.py (2 hunks)
  • deepmd/jax/model/model.py (1 hunks)
  • deepmd/pt/model/task/fitting.py (1 hunks)
  • deepmd/tf/descriptor/hybrid.py (2 hunks)
  • deepmd/tf/descriptor/se_atten.py (8 hunks)
  • deepmd/tf/descriptor/se_atten_v2.py (4 hunks)
  • deepmd/tf/fit/ener.py (4 hunks)
  • deepmd/tf/model/model.py (3 hunks)
  • deepmd/tf/utils/type_embed.py (1 hunks)
  • doc/backend.md (0 hunks)
  • source/tests/consistent/model/test_dpa1.py (1 hunks)
💤 Files with no reviewable changes (1)
  • doc/backend.md
🔇 Additional comments (20)
deepmd/jax/descriptor/__init__.py (2)

11-13: LGTM! Import follows consistent style.

The import of DescrptSeAttenV2 follows the established pattern and maintains consistency with other descriptor imports.


33-33: LGTM! Proper exposure of the new descriptor.

The addition of DescrptSeAttenV2 to __all__ correctly exposes the new descriptor class while maintaining alphabetical ordering.

Let's verify the module structure and implementation:

✅ Verification successful

Let me gather more information to verify the class implementation since the previous ast-grep command didn't show the actual implementation.


Proper exposure and implementation of DescrptSeAttenV2 confirmed

The verification shows that:

  • The class is properly implemented in se_atten_v2.py, inheriting from DescrptDPA1 and DescrptSeAttenV2DP
  • It's correctly registered with the @BaseDescriptor.register("se_atten_v2") decorator
  • The class is properly imported and exposed in __init__.py
  • The __all__ list maintains alphabetical ordering
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the se_atten_v2 module exists and contains the DescrptSeAttenV2 class

# Test 1: Check if the module file exists
fd "se_atten_v2.py" deepmd/jax/descriptor/

# Test 2: Verify class implementation
ast-grep --pattern 'class DescrptSeAttenV2' deepmd/jax/descriptor/se_atten_v2.py

Length of output: 162


Script:

#!/bin/bash
# Let's check the actual content of the se_atten_v2.py file
cat deepmd/jax/descriptor/se_atten_v2.py

# Also verify the import in __init__.py to ensure proper exposure
cat deepmd/jax/descriptor/__init__.py

Length of output: 1184

deepmd/jax/model/model.py (2)

40-40: Consistent implementation across model types.

The same ntypes assignment logic is correctly applied to both standard and ZBL models, maintaining consistency in type handling.

Also applies to: 63-63


40-40: LGTM: Explicit type count enhances model consistency.

Setting ntypes explicitly from type_map length ensures consistent type handling across different backends, which is crucial for type embedding support.

Let's verify this change is consistent with other backend implementations:

✅ Verification successful

Based on the search results, I can now generate the final response as we have sufficient evidence to verify the consistency of ntypes handling across different backend implementations:

✓ Setting ntypes from type_map length is consistent across backends

The change is correctly implemented and matches the pattern used in other backend implementations:

  • JAX backend: data["descriptor"]["ntypes"] = len(data["type_map"]) in both get_standard_model and get_zbl_model
  • DPModel backend: Same pattern in deepmd/dpmodel/model/model.py
  • PyTorch backend: Uses ntypes = len(type_map) consistently in model initialization
  • TensorFlow backend: Sets ntypes from len(self.get_type_map()) in model classes
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if ntypes is consistently set across different backend implementations
# Expected: Similar ntypes assignment in other backend model files

# Search for model.py files in different backend directories
fd -t f "model.py" -x rg -A 5 "descriptor\[.\"'\]ntypes.*=.*len.*type_map"

Length of output: 76


Script:

#!/bin/bash
# Let's try a broader search to understand how ntypes is handled across backends
# First, let's find all model.py files
echo "=== Model files ==="
fd "model.py$"

echo -e "\n=== Searching for ntypes handling ==="
# Search for ntypes assignments in these files
rg -l "ntypes" $(fd "model.py$")

echo -e "\n=== Detailed ntypes context ==="
# Get context around ntypes usage
rg -B 2 -A 2 "ntypes" $(fd "model.py$")

Length of output: 58352

deepmd/dpmodel/model/model.py (2)

40-40: LGTM! Improved type handling consistency.

The explicit setting of ntypes and unified descriptor creation through BaseDescriptor aligns well with the PR's objective to fix type embedding inconsistencies between frameworks.

Also applies to: 43-45


43-45: Verify backward compatibility with existing models.

The switch to using BaseDescriptor uniformly needs verification to ensure:

  1. All existing descriptor types work correctly
  2. Type embedding is properly handled in BaseDescriptor
  3. Models can still be converted between frameworks
✅ Verification successful

Based on the search results, I can now generate the final response:

Backward compatibility verification completed - no issues found

The codebase shows strong evidence of backward compatibility:

  1. All descriptor types are properly handled through the BaseDescriptor interface, with existing descriptors (se_a, se_atten, etc.) still maintained in tests
  2. Type embedding is well supported with smooth_type_embedding parameter properly implemented across TF/PyTorch/JAX backends
  3. Model conversion between frameworks is supported through the convert-backend command with extensive test coverage
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for potential impacts on existing models and type embedding

# Test 1: Check for any remaining references to removed descriptor types
echo "Checking for references to removed descriptor types..."
rg "DescrptSeA|se_atten|se_atten_v2" --type py

# Test 2: Check for type embedding related configurations in tests
echo "Checking for type embedding tests..."
rg "smooth_type_embedding.*True" --type py

# Test 3: Look for model conversion tests
echo "Checking for model conversion tests..."
rg "convert.*backend|dp\s+convert" --type py

Length of output: 62838

deepmd/tf/utils/type_embed.py (1)

330-332: Optimization: Skip serialization of default identity weights.

This change improves the consistency of type embedding serialization across different frameworks by explicitly skipping the serialization of identity weights that are set to 0.0. This is a good optimization as:

  1. It prevents unnecessary data in the serialized model
  2. Aligns with the PR objective of making type embedding handling consistent between TensorFlow and other frameworks
  3. The comment clearly explains the purpose of this check

However, there are a few considerations:

Let's verify if this change maintains backward compatibility with existing models:

Consider adding a version check to ensure older versions of the framework can still load models serialized with this change. This could be done by:

  1. Incrementing the @version field when this behavior changes
  2. Adding appropriate version handling in the deserialize method
✅ Verification successful

Based on the gathered information, I can now provide the final response:

The optimization to skip serialization of zero identity weights is safe and consistent.

The change is well-aligned with the framework's design and implementation because:

  1. Version compatibility is properly maintained through check_version_compatibility checks across frameworks (TF, JAX, and dpmodel)
  2. The deserialize method in array_api_strict confirms consistent handling across implementations
  3. Zero weights are explicitly handled as default values during deserialization
  4. No existing tests or models were found that rely on serialized zero identity weights

The change improves consistency across frameworks while maintaining backward compatibility through proper version checks (supporting versions 1 and 2).

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for any existing models that might rely on the serialized 0.0 idt weights
rg -g '*.json' -g '*.pb' '"idt":\s*0\.0' 

Length of output: 42


Script:

#!/bin/bash
# Let's verify the serialization and deserialization logic more comprehensively

# 1. Check the deserialize method to understand how 0.0 values are handled
ast-grep --pattern 'def deserialize(self, data):
  $$$
'

# 2. Look for any tests that verify serialization/deserialization
rg -l "test.*serialize" 

# 3. Check if there are any version checks in the type embedding
rg "version|@version" deepmd/tf/utils/type_embed.py

# 4. Look for any documentation about the serialization format
fd "README|CHANGELOG" --type f --exec grep -l "type.*embed|serialize" {} \;

Length of output: 753


Script:

#!/bin/bash
# Let's examine the deserialize method and version handling more closely

# 1. Check the full context of deserialize method in type_embed.py
rg -B 5 -A 15 "def deserialize" deepmd/tf/utils/type_embed.py

# 2. Look for any test cases specifically for TypeEmbedNet
rg -l "TypeEmbedNet.*test|test.*TypeEmbedNet"

# 3. Check how idt weights are handled in tests
rg -C 5 "idt.*=.*0\.0|idt.*:.*0\.0" source/tests/

# 4. Look for any related changes in other frameworks (pytorch)
fd "type_embed" --type f --exec rg -l "serialize" {} \;

Length of output: 903


Script:

#!/bin/bash
# Let's check the serialization implementations across different frameworks

# 1. Check how other frameworks handle idt weights in serialization
rg -C 5 "def serialize" deepmd/jax/utils/type_embed.py deepmd/dpmodel/utils/type_embed.py

# 2. Check version handling in deserialization across frameworks
rg -C 5 "check_version_compatibility" deepmd/jax/utils/type_embed.py deepmd/dpmodel/utils/type_embed.py

# 3. Look for any related test cases in array_api_strict
rg -C 5 "serialize|deserialize" source/tests/array_api_strict/utils/type_embed.py

# 4. Check if there are any model compatibility tests
rg -l "test.*compatibility|compatibility.*test"

Length of output: 2765

deepmd/dpmodel/descriptor/hybrid.py (1)

61-63: LGTM! Verify backward compatibility.

The changes correctly propagate type information to child descriptors, which is essential for consistent type embedding support.

Let's verify that existing models without type information still work:

✅ Verification successful

Based on the search results from previous scripts, I can now make a determination:

Type information propagation is safe and backward compatible

The code changes are safe because:

  1. The hybrid descriptor implementation properly handles both cases - with and without type information
  2. Test cases show successful initialization with and without type parameters
  3. The type information is only propagated when provided, maintaining compatibility with existing models
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for existing model usage without type_map/ntypes
rg -l "DescrptHybrid.*list\s*=.*\}" --type py

Length of output: 47


Script:

#!/bin/bash
# Let's try a broader search for DescrptHybrid usage patterns
rg "DescrptHybrid" --type py -B 2 -A 2

# Also search for test files that might show usage examples
fd "test.*\.py" --exec rg "DescrptHybrid" {} -B 2 -A 2

# Look for descriptor list initialization patterns
rg "descript_list.*=.*\[" --type py -B 2 -A 2

Length of output: 52603


Script:

#!/bin/bash
# Let's look for test cases that verify backward compatibility
rg "test.*backward.*compatibility" --type py -B 2 -A 2

# Also search for test cases that use DescrptHybrid without type_map/ntypes
ast-grep --pattern 'DescrptHybrid\(\s*list\s*=\s*\[.*\]\s*\)'

Length of output: 120

deepmd/tf/descriptor/hybrid.py (1)

Line range hint 464-493: Verify impact on existing models.

The addition of type embedding checks could affect existing models that use hybrid descriptors.

Let's check for potential usage patterns in the codebase:

✅ Verification successful

Based on the search results, I can now generate the final response since I have enough information to verify the impact of type embedding checks on existing models:

Type embedding check is safe and won't affect existing models

The added type embedding check is consistent with existing behavior since:

  1. All example configurations and test files show hybrid descriptors are used without type embedding:
  • examples/water/hybrid/input.json
  • examples/dprc/generalized_force/input.json
  • examples/dprc/normal/input.json
  • source/tests/tf/water_hybrid.json
  1. The error message was already effectively enforced through the implementation, as evidenced by test files explicitly checking this limitation:
  • Test files contain comments like "unsupported for SpinModel to hybrid both mixed_types and no-mixed_types descriptor"
  • Multiple test classes verify this behavior across different frameworks (TF, PT, JAX)
  1. The change only makes the limitation more explicit by failing early with a clear error message rather than failing later during execution.
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Search for hybrid descriptor usage with type embedding
# Expected results: No instances of hybrid descriptors with type embedding

# Test 1: Look for hybrid descriptor configurations in test files
echo "Searching for hybrid descriptor test configurations..."
rg -A 10 '"type":\s*"hybrid"' --type json

# Test 2: Look for hybrid descriptor instantiations in Python files
echo "Searching for hybrid descriptor usage in Python files..."
rg -B 5 -A 5 'DescrptHybrid|"type":\s*"hybrid"' --type py

Length of output: 75747

deepmd/pt/model/task/fitting.py (2)

422-423: LGTM: Improved error message formatting

The change to f-strings improves readability while maintaining the same error information.


422-423: Verify alignment with PR objectives

While the error message improvements are good, they seem tangential to the main PR objectives of fixing DPA-1 model consistency issues. Please confirm if there are other changes needed in this file to address the core objectives.

Let's check for any DPA-1 model related code:

deepmd/tf/model/model.py (2)

825-825: LGTM! Type embedding parameter properly added.

The type embedding parameter is correctly added to the class constructor, maintaining consistency with the deserialization logic.


811-814: LGTM! Verify type embedding dimension adjustment.

The type embedding handling during deserialization looks correct. The code properly passes the descriptor's type embedding to the model and adjusts the descriptor dimension to avoid double-counting.

Let's verify the dimension adjustment logic:

✅ Verification successful

Let me gather more context about the type embedding dimension handling.


Let me gather more information about the type embedding and descriptor dimensions.


Let me try a different search pattern to understand the type embedding and dimension handling.


Based on the code analysis, I can now provide a definitive response about the type embedding dimension adjustment.

Type embedding dimension adjustment is correct and necessary

The dimension adjustment in the deserialization code is correct because:

  1. The descriptor's output dimension (get_dim_out()) includes the type embedding dimension when explicit_ntypes is true
  2. The fitting network expects the raw descriptor dimension without the type embedding dimension
  3. The adjustment fitting.dim_descrpt -= type_embedding.neuron[-1] correctly subtracts the type embedding dimension from the fitting network's input dimension to avoid double-counting

The code shows consistent handling across different descriptor implementations (se_a, se_t, se_atten, etc.) where:

  • Descriptors with explicit_ntypes=True concatenate type embedding to their output
  • The fitting network needs the original descriptor dimension without the type embedding
  • The type embedding is passed separately to the fitting network via the model
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for any other places where descriptor dimension is adjusted
rg -A 5 "dim_descrpt.*=.*type_embedding" 

Length of output: 441


Script:

#!/bin/bash
# Search for type_embedding related code
ast-grep --pattern 'type_embedding' 

# Search for dim_descrpt usage
rg -B 3 -A 3 'dim_descrpt'

Length of output: 116453


Script:

#!/bin/bash
# Search for type embedding neuron dimension
ast-grep --pattern 'type_embedding.*neuron'

# Search for get_dim_out in descriptor classes
ast-grep --pattern 'def get_dim_out'

Length of output: 84


Script:

#!/bin/bash
# Search for TypeEmbedNet class definition and usage
ast-grep --pattern 'class TypeEmbedNet'

# Search for explicit_ntypes usage
rg -B 3 -A 3 'explicit_ntypes'

# Search for get_dim_out method calls
rg -B 3 -A 3 'get_dim_out\('

Length of output: 91537

deepmd/tf/fit/ener.py (3)

245-245: LGTM: New attribute for type embedding dimension.

The addition of tebd_dim attribute with default value 0 is appropriate for tracking the type embedding dimension.


758-759: LGTM: Proper dimension recovery.

Correctly restores the original descriptor dimension after processing, which is essential for proper serialization.


914-914: LGTM: Consistent dimension handling in serialization.

The changes properly account for the type embedding dimension in both the descriptor dimension and network input dimension calculations during serialization.

Let's verify the consistency of type embedding dimension handling across the codebase:

Also applies to: 936-936

✅ Verification successful

Let me gather more information about the type embedding dimension handling by searching for related patterns.


Based on the extensive search results, I can now provide a final response:

Verified: The dimension handling is consistent across the codebase

The type embedding dimension (tebd_dim) is correctly added to the descriptor dimension (dim_descrpt) in both the serialization and network input dimension calculations. This pattern is consistently implemented across:

  • Energy fitting (ener.py): Adds tebd_dim to dim_descrpt for both serialization and network input
  • DOS fitting (dos.py): Properly handles combined dimensions for network input
  • General fitting base class: Consistently handles dimension combinations
  • PyTorch implementation: Follows the same dimension handling patterns

The change ensures proper dimensionality when using type embeddings while maintaining consistency across different model components and backends.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for consistent type embedding dimension handling across different backends

# Test: Look for similar dimension handling patterns in other backend implementations
rg -A 2 "dim_descrpt.*tebd" 

Length of output: 227


Script:

#!/bin/bash
# Search for type embedding dimension (tebd) related code
rg "tebd" -A 3

# Search for dimension descriptor handling
rg "dim_descrpt" -A 3

# Look for network input dimension calculations
ast-grep --pattern 'self.dim_descrpt + self.$_'

Length of output: 314480

deepmd/tf/descriptor/se_atten_v2.py (4)

7-9: Import of TypeEmbedNet is appropriate

The import statement correctly brings in TypeEmbedNet, which is necessary for handling type embeddings in the descriptor.


136-136: Version compatibility check is correctly implemented

The check_version_compatibility function is used appropriately to ensure that the serialized data is compatible with version 2.1, which helps maintain consistency and prevent errors during deserialization.


156-159: Backward compatibility for use_tebd_bias is well-handled

The code ensures backward compatibility by setting use_tebd_bias to True when it's not present in the data. This approach maintains expected behavior for older versions that may not include this key.


170-170: Assignment of type_embedding to descriptor is appropriate

The type_embedding instance is correctly assigned to descriptor.type_embedding, ensuring that the deserialized type embedding is incorporated into the descriptor object.

deepmd/tf/descriptor/hybrid.py Show resolved Hide resolved
deepmd/tf/descriptor/se_atten_v2.py Show resolved Hide resolved
source/tests/consistent/model/test_dpa1.py Show resolved Hide resolved
source/tests/consistent/model/test_dpa1.py Show resolved Hide resolved
source/tests/consistent/model/test_dpa1.py Show resolved Hide resolved
deepmd/tf/descriptor/se_atten.py Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (3)
deepmd/tf/model/model.py (1)

845-846: LGTM! Consider adding docstring for type embedding fields.

The changes correctly handle type embedding during serialization by updating both the descriptor and fitting components. This ensures consistency between frameworks.

Consider adding docstring comments to document these fields:

 if self.typeebd is not None:
+    # Pass type embedding to descriptor for consistent serialization
     self.descrpt.type_embedding = self.typeebd
+    # Update fitting dimension to include type embedding
     self.fitting.tebd_dim = self.typeebd.neuron[-1]
deepmd/tf/descriptor/se_atten.py (2)

1947-1954: Robust type embedding dimension handling

The code properly calculates the embedding input dimension based on the type embedding mode and configuration. However, the logic could be more maintainable.

Consider extracting this logic into a separate method for better maintainability:

-        tebd_dim = self.type_embedding.neuron[0]
-        if self.tebd_input_mode in ["concat"]:
-            if not self.type_one_side:
-                embd_input_dim = 1 + tebd_dim * 2
-            else:
-                embd_input_dim = 1 + tebd_dim
-        else:
-            embd_input_dim = 1
+        def calculate_embedding_input_dim(tebd_dim: int) -> int:
+            if self.tebd_input_mode in ["concat"]:
+                return 1 + tebd_dim * (2 if not self.type_one_side else 1)
+            return 1
+            
+        tebd_dim = self.type_embedding.neuron[0]
+        embd_input_dim = calculate_embedding_input_dim(tebd_dim)

2012-2018: Validation of type embedding network structure

Good validation check for single layer type embedding networks. However, consider providing more guidance in the error message.

Enhance the error message:

-            raise NotImplementedError(
-                "Only support single layer type embedding network"
-            )
+            raise NotImplementedError(
+                "Multi-layer type embedding networks are not supported. Please use a single layer network with neuron=[tebd_dim]"
+            )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 4de5511 and 0fcd127.

📒 Files selected for processing (4)
  • deepmd/dpmodel/descriptor/hybrid.py (2 hunks)
  • deepmd/tf/descriptor/se_atten.py (9 hunks)
  • deepmd/tf/descriptor/se_atten_v2.py (4 hunks)
  • deepmd/tf/model/model.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/dpmodel/descriptor/hybrid.py
🔇 Additional comments (10)
deepmd/tf/descriptor/se_atten_v2.py (5)

7-9: LGTM: Clean import and well-documented class.

The import of TypeEmbedNet and class documentation are well-structured and complete.


Line range hint 89-117: LGTM: Constructor properly enforces consistent configuration.

The constructor correctly sets smooth_type_embedding=True and tebd_input_mode="strip" to ensure consistency across backends, which aligns with the PR objectives.


133-135: LGTM: Clear error handling and version requirements.

The RuntimeError and version compatibility check (2.1) effectively enforce the intended backend consistency requirements.

Also applies to: 139-139


Line range hint 176-192: LGTM: Clean serialization implementation.

The serialization method correctly removes the fixed configuration values and updates the type identifier.


156-162: ⚠️ Potential issue

Previous review comment still applies: Handle potential KeyError.

The code is still vulnerable to KeyError when accessing 'type_embedding'. Consider implementing the suggested error handling from the previous review.

Apply this diff as suggested in the previous review:

+            if "type_embedding" not in data:
+                raise KeyError("'type_embedding' key not found in data during deserialization.")
             type_embedding = TypeEmbedNet.deserialize(
                 data.pop("type_embedding"), suffix=suffix
             )
deepmd/tf/model/model.py (1)

811-816: LGTM! Consider adding error handling for dimension validation.

The changes correctly handle type embedding during deserialization by passing it from the descriptor to the model when explicit_ntypes is true. The dimension adjustment prevents double-counting of the embedding dimension.

Let's verify the type embedding dimension handling:

Also applies to: 827-827

✅ Verification successful

Based on the search results, I can now generate the final response:

The dimension adjustment is safe and correctly implemented.

The code correctly subtracts the type embedding dimension (type_embedding.neuron[-1]) from fitting.dim_descrpt when explicit_ntypes is true. This adjustment is necessary because:

  1. The type embedding dimension is properly defined in the descriptor (see deepmd/tf/descriptor/se_atten.py and deepmd/tf/descriptor/se_a.py)
  2. The fitting networks consistently use dim_descrpt as input dimension across all implementations
  3. The dimension adjustment only happens when explicit_ntypes is true, preventing any unintended modifications
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for potential dimension mismatches in type embedding usage
rg -A 5 "dim_descrpt\s*[-+]=\s*" 

Length of output: 407


Script:

#!/bin/bash
# Let's check how type_embedding.neuron is used and defined
rg -A 5 "type_embedding.*neuron" 

# Also check where dim_descrpt is used to ensure dimension consistency
rg -A 5 "dim_descrpt\s*=\s*" 

# Look for any validation of dimensions
ast-grep --pattern 'if $_.dim_descrpt $_'

Length of output: 35719

deepmd/tf/descriptor/se_atten.py (4)

1885-1891: Improved version compatibility handling for type embedding

Good addition of version compatibility handling for use_tebd_bias. The code properly handles both v1 and v2 formats by setting a default value of True for v1 compatibility.


1902-1911: Proper handling of stripped type embedding

The code correctly handles the stripped type embedding mode by deserializing the embeddings with appropriate type_one_side parameter. This ensures compatibility with different model configurations.


1927-1929: Consistent error handling for smooth type embedding

The error handling for smooth type embedding is consistent with the deserialization logic, ensuring that incompatible configurations are caught early.


2038-2084: Comprehensive attention layer serialization update

The code properly updates the attention layer serialization to maintain consistency with other backend references. The default values and attention layer parameters are well-structured.

However, consider adding validation for the default values.

✅ Verification successful

Based on the search results, I can now provide a final response:

Default values are consistently defined across backends

The verification shows that the default values in the serialization code are consistent with their usage across different backend implementations (TensorFlow, PyTorch, and Universal):

  • scaling_factor=1.0
  • normalize=True
  • temperature=None
  • concat_output_tebd=True
  • use_econf_tebd=False

These values are properly synchronized in:

  • TensorFlow: deepmd/tf/descriptor/se_atten.py
  • PyTorch: deepmd/pt/model/descriptor/se_atten.py
  • Universal: deepmd/dpmodel/descriptor/se_atten_v2.py, deepmd/dpmodel/descriptor/dpa1.py

The attention layer serialization implementation maintains consistency across all backend references.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Verify that the default values match across all backends
rg --type python "scaling_factor.*=.*1\.0|normalize.*=.*True|temperature.*=.*None|concat_output_tebd.*=.*True|use_econf_tebd.*=.*False"

Length of output: 168


Script:

#!/bin/bash
# Search for default values in Python files without specifying file type
rg "scaling_factor.*=.*1\.0|normalize.*=.*True|temperature.*=.*None|concat_output_tebd.*=.*True|use_econf_tebd.*=.*False" --type-add 'py:*.py' -t py

Length of output: 6221


Script:

#!/bin/bash
# Search for attention layer configuration in Python files
rg "attention_layers.*=|attention_layer.*=" --type-add 'py:*.py' -t py -C 2

Length of output: 14584


Script:

#!/bin/bash
# Find all Python files that might contain attention-related configurations
fd -e py | xargs grep -l "attention"

Length of output: 792

deepmd/tf/descriptor/se_atten.py Show resolved Hide resolved
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Copy link

codecov bot commented Nov 7, 2024

Codecov Report

Attention: Patch coverage is 72.46377% with 19 lines in your changes missing coverage. Please review.

Project coverage is 84.47%. Comparing base (430dfa9) to head (cc67fdd).
Report is 5 commits behind head on devel.

Files with missing lines Patch % Lines
deepmd/tf/descriptor/se_atten.py 73.68% 10 Missing ⚠️
deepmd/tf/descriptor/se_atten_v2.py 12.50% 7 Missing ⚠️
deepmd/tf/descriptor/hybrid.py 60.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4320      +/-   ##
==========================================
+ Coverage   84.42%   84.47%   +0.05%     
==========================================
  Files         570      570              
  Lines       53080    53160      +80     
  Branches     3054     3059       +5     
==========================================
+ Hits        44813    44909      +96     
+ Misses       7307     7289      -18     
- Partials      960      962       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@coderabbitai coderabbitai bot mentioned this pull request Nov 30, 2024
9 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature Request] Support dp convert-backend for models with type embedding for tf
3 participants