Skip to content

Commit

Permalink
feat(pt): DPA-2 repinit compress
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 9, 2024
1 parent 0c5ab07 commit 22d0192
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 5 deletions.
86 changes: 86 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,14 @@
build_multiple_neighbor_list,
get_multiple_nlist_key,
)
from deepmd.pt.utils.tabulate import (
DPTabulate,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.pt.utils.utils import (
ActivationFn,
to_numpy_array,
)
from deepmd.utils.data_system import (
Expand Down Expand Up @@ -859,3 +863,85 @@ def update_sel(
)
local_jdata_cpy["repformer"]["nsel"] = repformer_sel[0]
return local_jdata_cpy, min_nbor_dist

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Receive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data.
Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
# do some checks before the mocel compression process
if self.repinit.compress:
raise ValueError("Compression is already enabled.")

Check warning on line 892 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L892

Added line #L892 was not covered by tests
assert (
not self.repinit.resnet_dt
), "Model compression error: repinit resnet_dt must be false!"
for tt in self.repinit.exclude_types:
if (tt[0] not in range(self.repinit.ntypes)) or (

Check warning on line 897 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L897

Added line #L897 was not covered by tests
tt[1] not in range(self.repinit.ntypes)
):
raise RuntimeError(

Check warning on line 900 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L900

Added line #L900 was not covered by tests
"Repinit exclude types"
+ str(tt)
+ " must within the number of atomic types "
+ str(self.repinit.ntypes)
+ "!"
)
if (
self.repinit.ntypes * self.repinit.ntypes - len(self.repinit.exclude_types)
== 0
):
raise RuntimeError(

Check warning on line 911 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L911

Added line #L911 was not covered by tests
"Repinit empty embedding-nets are not supported in model compression!"
)

if self.repinit.attn_layer != 0:
raise RuntimeError(

Check warning on line 916 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L916

Added line #L916 was not covered by tests
"Cannot compress model when repinit attention layer is not 0."
)

if self.repinit.tebd_input_mode != "strip":
raise RuntimeError(

Check warning on line 921 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L921

Added line #L921 was not covered by tests
"Cannot compress model when repinit tebd_input_mode == 'concat'"
)

# repinit doesn't have a serialize method
data = self.serialize()
self.table = DPTabulate(
self,
data["repinit_args"]["neuron"],
data["repinit_args"]["type_one_side"],
data["exclude_types"],
ActivationFn(data["repinit_args"]["activation_function"]),
)
self.table_config = [
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
]
self.lower, self.upper = self.table.build(
min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
)

self.repinit.enable_compression(
self.table.data, self.table_config, self.lower, self.upper
)
self.compress = True
19 changes: 14 additions & 5 deletions deepmd/pt/utils/tabulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,14 @@ def __init__(
raise RuntimeError("Unknown activation function type!")

self.activation_fn = activation_fn
self.davg = self.descrpt.serialize()["@variables"]["davg"]
self.dstd = self.descrpt.serialize()["@variables"]["dstd"]
self.ntypes = self.descrpt.get_ntypes()
serialized = self.descrpt.serialize()
if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA2):
serialized = serialized["repinit_variable"]
self.davg = serialized["@variables"]["davg"]
self.dstd = serialized["@variables"]["dstd"]
self.embedding_net_nodes = serialized["embeddings"]["networks"]

self.embedding_net_nodes = self.descrpt.serialize()["embeddings"]["networks"]
self.ntypes = self.descrpt.get_ntypes()

self.layer_size = self._get_layer_size()
self.table_size = self._get_table_size()
Expand Down Expand Up @@ -291,7 +294,13 @@ def _layer_1(self, x, w, b):
return t, self.activation_fn(torch.matmul(x, w) + b) + t

def _get_descrpt_type(self):
if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA1):
if isinstance(
self.descrpt,
(
deepmd.pt.model.descriptor.DescrptDPA1,
deepmd.pt.model.descriptor.DescrptDPA2,
),
):
return "Atten"
elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeA):
return "A"
Expand Down
149 changes: 149 additions & 0 deletions source/tests/pt/model/test_compressed_descriptor_dpa2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest
from typing import (
Any,
)

import numpy as np
import torch

from deepmd.dpmodel.descriptor.dpa2 import (
RepformerArgs,
RepinitArgs,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.pt.model.descriptor.dpa2 import (
DescrptDPA2,
)
from deepmd.pt.utils.env import DEVICE as PT_DEVICE
from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt
from deepmd.pt.utils.nlist import (
extend_coord_with_ghosts as extend_coord_with_ghosts_pt,
)

from ...consistent.common import (
parameterized,
)


def eval_pt_descriptor(
pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False
) -> Any:
ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt(
torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3),
torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1),
torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3),
pt_obj.get_rcut(),
)
nlist = build_neighbor_list_pt(
ext_coords,
ext_atype,
natoms[0],
pt_obj.get_rcut(),
pt_obj.get_sel(),
distinguish_types=(not mixed_types),
)
result, _, _, _, _ = pt_obj(ext_coords, ext_atype, nlist, mapping=mapping)
return result


@parameterized(("float32", "float64"), (True, False))
class TestDescriptorDPA2(unittest.TestCase):
def setUp(self):
(self.dtype, self.type_one_side) = self.param
if self.dtype == "float32":
self.skipTest("FP32 has bugs:")
# ../../../../deepmd/pt/model/descriptor/repformer_layer.py:521: in forward
# torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni)
# E RuntimeError: expected scalar type Float but found Double
if self.dtype == "float32":
self.atol = 1e-5
elif self.dtype == "float64":
self.atol = 1e-10
self.seed = 21
self.sel = [10]
self.rcut_smth = 5.80
self.rcut = 6.00
self.neuron = [6, 12, 24]
self.axis_neuron = 3
self.ntypes = 2
self.coords = np.array(
[
12.83,
2.56,
2.18,
12.09,
2.87,
2.74,
00.25,
3.32,
1.68,
3.36,
3.00,
1.81,
3.51,
2.51,
2.60,
4.27,
3.22,
1.56,
],
dtype=GLOBAL_NP_FLOAT_PRECISION,
)
self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32)
self.box = np.array(
[13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0],
dtype=GLOBAL_NP_FLOAT_PRECISION,
)
self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)

repinit = RepinitArgs(
rcut=self.rcut,
rcut_smth=self.rcut_smth,
nsel=10,
tebd_input_mode="strip",
type_one_side=self.type_one_side,
)
repformer = RepformerArgs(
rcut=self.rcut - 1,
rcut_smth=self.rcut_smth - 1,
nsel=9,
)

self.descriptor = DescrptDPA2(
ntypes=self.ntypes,
repinit=repinit,
repformer=repformer,
precision=self.dtype,
)

def test_compressed_forward(self):
result_pt = eval_pt_descriptor(
self.descriptor,
self.natoms,
self.coords,
self.atype,
self.box,
)
self.descriptor.enable_compression(0.5)
result_pt_compressed = eval_pt_descriptor(
self.descriptor,
self.natoms,
self.coords,
self.atype,
self.box,
)

self.assertEqual(result_pt.shape, result_pt_compressed.shape)
torch.testing.assert_close(
result_pt,
result_pt_compressed,
atol=self.atol,
rtol=self.atol,
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 22d0192

Please sign in to comment.