-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(tf): support tensor fitting with hybrid descriptor #4542
base: devel
Are you sure you want to change the base?
Conversation
Fix deepmodeling#4527. Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
Comments suppressed due to low confidence (2)
deepmd/tf/model/model.py:672
- The error message should be more specific. Suggestion: "The descriptor's embedding width is zero, indicating it cannot provide a rotation matrix for tensorial fitting. Please check the descriptor configuration."
raise ValueError("This descriptor cannot provide a rotation matrix for a tensorial fitting.")
deepmd/tf/descriptor/descriptor.py:543
- [nitpick] Returning a tensor with a zero dimension might lead to confusion or unexpected behavior. Consider revising the implementation to handle cases where no rotation matrix is provided more explicitly.
return tf.zeros([nframes, natoms, 0], dtype=GLOBAL_TF_FLOAT_PRECISION)
📝 WalkthroughWalkthroughThe pull request introduces modifications to the Changes
Assessment against linked issues
Possibly related PRs
Suggested labels
Suggested reviewers
📜 Recent review detailsConfiguration used: CodeRabbit UI 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
⏰ Context from checks skipped due to timeout of 90000ms (7)
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (6)
source/tests/tf/test_dipole_hybrid_descrpt.py (4)
23-26
: Consider using type hints for global variables.The global precision variables should have type hints for better code clarity and IDE support.
-GLOBAL_ENER_FLOAT_PRECISION = tf.float64 -GLOBAL_TF_FLOAT_PRECISION = tf.float64 -GLOBAL_NP_FLOAT_PRECISION = np.float64 +GLOBAL_ENER_FLOAT_PRECISION: tf.DType = tf.float64 +GLOBAL_TF_FLOAT_PRECISION: tf.DType = tf.float64 +GLOBAL_NP_FLOAT_PRECISION: np.dtype = np.float64
49-86
: Consider extracting descriptor configuration to a separate file.The descriptor configuration is quite large and contains repeated sections. Consider moving it to a separate JSON configuration file for better maintainability.
Also, the descriptor configuration has duplicate sections for the first two descriptors. Consider using a helper function to create these configurations:
def create_se_e2_a_config(seed: int = 1) -> dict: return { "type": "se_e2_a", "sel": [20, 20], "rcut_smth": 1.8, "rcut": 6.0, "neuron": [2, 4, 8], "resnet_dt": False, "axis_neuron": 8, "precision": "float64", "type_one_side": True, "seed": seed, }
95-95
: Remove commented code.The commented line appears to be test code. Remove it if it's not needed.
-# model._compute_dstats([test_data['coord']], [test_data['box']], [test_data['type']], [test_data['natoms_vec']], [test_data['default_mesh']])
113-114
: Initialize t_fparam with None type hint.Add type hint to t_fparam for better code clarity.
-t_fparam = None +t_fparam: Optional[tf.Tensor] = Nonedeepmd/tf/descriptor/hybrid.py (1)
507-512
: Add docstring parameters and return type.The get_rot_mat method needs more detailed documentation.
def get_rot_mat(self) -> tf.Tensor: - """Get rotational matrix.""" + """Get rotational matrix by concatenating rotation matrices from all descriptors. + + Returns + ------- + tf.Tensor + The concatenated rotation matrix tensor with shape [nframes, natoms, sum(rot_mat_dim)] + where sum(rot_mat_dim) is the sum of rotation matrix dimensions from all descriptors + """deepmd/tf/descriptor/descriptor.py (1)
108-109
: Add explanation for the default return value.Document why 0 is returned as the default value.
- # by default, no rotation matrix + # Return 0 to indicate that this descriptor doesn't support rotation matrices by default. + # Derived classes should override this method if they support rotation matrices.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
deepmd/tf/descriptor/descriptor.py
(2 hunks)deepmd/tf/descriptor/hybrid.py
(1 hunks)deepmd/tf/model/model.py
(1 hunks)source/tests/tf/test_dipole_hybrid_descrpt.py
(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (21)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Analyze (python)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Test C++ (false)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Test C++ (true)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C++ (cpu, cpu)
🔇 Additional comments (2)
deepmd/tf/descriptor/hybrid.py (1)
496-505
: LGTM! Well-documented method.The get_dim_rot_mat_1 method is well-documented with clear docstrings and return type hints.
deepmd/tf/model/model.py (1)
671-675
: LGTM! Good validation check.The validation ensures that tensorial fitting networks have the required rotation matrix support from their descriptors.
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4542 +/- ##
==========================================
+ Coverage 84.55% 84.58% +0.02%
==========================================
Files 677 677
Lines 63903 63915 +12
Branches 3488 3487 -1
==========================================
+ Hits 54035 54060 +25
+ Misses 8743 8729 -14
- Partials 1125 1126 +1 ☔ View full report in Codecov by Sentry. |
Fix #4527.
Summary by CodeRabbit
New Features
Bug Fixes
Tests
Refactor