Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pd: support dpa1 #4414

Open
wants to merge 68 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 67 commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
48f77f3
add core modules of paddle backend and water/se_e2_a example
HydrogenSulfate Nov 2, 2024
2082a59
add paddle code in consistent test
HydrogenSulfate Nov 2, 2024
2ae45b8
clean env and training
HydrogenSulfate Nov 2, 2024
7f03a04
add more test files
HydrogenSulfate Nov 2, 2024
4d1c44c
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 2, 2024
72c9b4e
fix pt->pd
HydrogenSulfate Nov 2, 2024
3b1c348
update test_python.yml
HydrogenSulfate Nov 2, 2024
a46dcb5
restore .pre-commit-config.yaml
HydrogenSulfate Nov 3, 2024
90f9ff9
remove redundant file
HydrogenSulfate Nov 3, 2024
0a6baa6
Skip bfloat16 for some cases
HydrogenSulfate Nov 3, 2024
4b77e55
enable prim by default in unitest
HydrogenSulfate Nov 3, 2024
6e139a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2024
9437957
fix env code
HydrogenSulfate Nov 5, 2024
f1d762f
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 5, 2024
8534597
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 6, 2024
c22b45d
update test_ener.py
HydrogenSulfate Nov 6, 2024
39842ff
add missing pd_class
HydrogenSulfate Nov 6, 2024
07cd98e
use paddle Tensor instead of numpy array in pd/test_auto_batch_size.p…
HydrogenSulfate Nov 6, 2024
bb2d547
add training test and remove ase_calc.py
HydrogenSulfate Nov 7, 2024
5fb6d8e
add training test and remove ase_calc.py
HydrogenSulfate Nov 7, 2024
91066f8
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 7, 2024
90c9c03
upload missing json
HydrogenSulfate Nov 7, 2024
eb7384e
restore pt/test_auto_batch_size.py
HydrogenSulfate Nov 7, 2024
9faf54f
rerun CI for network problem
HydrogenSulfate Nov 7, 2024
4e3a121
add multitask unitest
HydrogenSulfate Nov 7, 2024
18333ab
add more unitest
HydrogenSulfate Nov 7, 2024
f9c6da8
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 7, 2024
3fd979d
remove redundant file and fix typo
HydrogenSulfate Nov 7, 2024
5922e84
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
f5a17a9
update unitest
HydrogenSulfate Nov 8, 2024
8bea1bf
delete record
HydrogenSulfate Nov 8, 2024
8a7875f
remove more unused code and files
HydrogenSulfate Nov 8, 2024
df9f887
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 8, 2024
67b81e1
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 8, 2024
b0bf733
Merge branch 'add_paddle_backend_core_and_water_se_e2_a' of https://g…
HydrogenSulfate Nov 8, 2024
71a3c0a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 8, 2024
ede5047
remove redundant annotations
HydrogenSulfate Nov 8, 2024
d11bf4d
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 8, 2024
b7a8cec
add nvtx profiler code in training, which is more accurate and detailed
HydrogenSulfate Nov 8, 2024
7567cf8
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 8, 2024
416fec8
update code as devel and fix typo
HydrogenSulfate Nov 9, 2024
1c0161c
fix pth -> json
HydrogenSulfate Nov 9, 2024
02a6f84
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 9, 2024
3354e5c
update unitest and training
HydrogenSulfate Nov 9, 2024
0d3f8cf
install paddle when test_cuda
HydrogenSulfate Nov 9, 2024
18215ff
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 9, 2024
859b94d
fix unitest
HydrogenSulfate Nov 9, 2024
74ee1c2
add eta in logging message for convenient
HydrogenSulfate Nov 9, 2024
f176309
remove hybrid code and enable one unitest
HydrogenSulfate Nov 9, 2024
4935e7b
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 9, 2024
fac51d3
add pd/__init__.py
HydrogenSulfate Nov 11, 2024
d3ca1f0
Merge branch 'devel' into add_paddle_backend_core_and_water_se_e2_a
HydrogenSulfate Nov 11, 2024
db1cd76
fix enable_prim
HydrogenSulfate Nov 11, 2024
36512fd
remove unused layernorm
HydrogenSulfate Nov 11, 2024
351bf7a
update dpa1 code
HydrogenSulfate Nov 25, 2024
01d4179
Merge branch 'devel' into add_dpa1
HydrogenSulfate Nov 25, 2024
bc1cb38
update code of dpa1
HydrogenSulfate Nov 28, 2024
b4bc9db
Merge branch 'devel' into add_dpa1
HydrogenSulfate Nov 28, 2024
c944b82
restore decomp to paddle function
HydrogenSulfate Nov 28, 2024
4c925f9
remove redundant files
HydrogenSulfate Nov 28, 2024
dd3191a
update unitest and codes
HydrogenSulfate Nov 29, 2024
7df0e2f
fix
HydrogenSulfate Nov 29, 2024
ac479ed
update code
HydrogenSulfate Nov 29, 2024
3e64196
Merge branch 'devel' into add_dpa1
HydrogenSulfate Nov 29, 2024
56e079c
update typos
HydrogenSulfate Nov 29, 2024
3d70e7c
update code
HydrogenSulfate Nov 29, 2024
8b5b4a8
fix coverage
HydrogenSulfate Nov 29, 2024
e74d272
update consistent check of dpa1
HydrogenSulfate Nov 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions deepmd/pd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@
use_pretrain_script: bool = False,
force_load: bool = False,
output: str = "out.json",
):
) -> None:
log.info("Configuration path: %s", input_file)
SummaryPrinter()()
with open(input_file) as fin:
Expand Down Expand Up @@ -321,18 +321,26 @@
# save min_nbor_dist
if min_nbor_dist is not None:
if not multi_task:
trainer.model.min_nbor_dist = min_nbor_dist
trainer.model.min_nbor_dist = paddle.to_tensor(

Check warning on line 324 in deepmd/pd/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/entrypoints/main.py#L324

Added line #L324 was not covered by tests
min_nbor_dist,
dtype=paddle.float64,
place=DEVICE,
)
else:
for model_item in min_nbor_dist:
trainer.model[model_item].min_nbor_dist = min_nbor_dist[model_item]
trainer.model[model_item].min_nbor_dist = paddle.to_tensor(

Check warning on line 331 in deepmd/pd/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/entrypoints/main.py#L331

Added line #L331 was not covered by tests
min_nbor_dist[model_item],
dtype=paddle.float64,
place=DEVICE,
)
trainer.run()


def freeze(
model: str,
output: str = "frozen_model.json",
head: Optional[str] = None,
):
) -> None:
paddle.set_flags(
{
"FLAGS_save_cf_stack_op": 1,
Expand Down Expand Up @@ -383,7 +391,7 @@
numb_batch: int = 0,
model_branch: Optional[str] = None,
output: Optional[str] = None,
):
) -> None:
if input_file.endswith(".pd"):
old_state_dict = paddle.load(input_file)
model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
Expand Down
6 changes: 1 addition & 5 deletions deepmd/pd/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
TaskLoss,
)
from deepmd.pd.utils import (
decomp,
env,
)
from deepmd.pd.utils.env import (
Expand Down Expand Up @@ -224,10 +223,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):

if self.relative_f is not None:
force_label_3 = force_label.reshape([-1, 3])
# norm_f = force_label_3.norm(axis=1, keepdim=True) + self.relative_f
norm_f = (
decomp.norm(force_label_3, axis=1, keepdim=True) + self.relative_f
)
norm_f = force_label_3.norm(axis=1, keepdim=True) + self.relative_f
diff_f_3 = diff_f.reshape([-1, 3])
diff_f_3 = diff_f_3 / norm_f
diff_f = diff_f_3.reshape([-1])
Expand Down
38 changes: 34 additions & 4 deletions deepmd/pd/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import functools
import logging
from typing import (
Expand Down Expand Up @@ -52,7 +51,7 @@
fitting,
type_map: list[str],
**kwargs,
):
) -> None:
super().__init__(type_map, **kwargs)
ntypes = len(type_map)
self.type_map = type_map
Expand Down Expand Up @@ -201,7 +200,7 @@

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
data.pop("@class", None)
data.pop("type", None)
Expand All @@ -212,6 +211,37 @@
obj = super().deserialize(data)
return obj

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Call descriptor enable_compression().

Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
self.descriptor.enable_compression(

Check warning on line 237 in deepmd/pd/model/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/model/atomic_model/dp_atomic_model.py#L237

Added line #L237 was not covered by tests
min_nbor_dist,
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
)

def forward_atomic(
self,
extended_coord,
Expand Down Expand Up @@ -278,7 +308,7 @@
self,
sampled_func,
stat_file_path: Optional[DPPath] = None,
):
) -> None:
"""
Compute or load the statistics parameters of the model,
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pd/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from .descriptor import (
DescriptorBlock,
)
from .dpa1 import (
DescrptBlockSeAtten,
DescrptDPA1,
)
from .env_mat import (
prod_env_mat,
)
Expand All @@ -17,6 +21,8 @@
"BaseDescriptor",
"DescriptorBlock",
"DescrptBlockSeA",
"DescrptBlockSeAtten",
"DescrptDPA1",
"DescrptSeA",
"prod_env_mat",
]
Loading