Add 4 pt descriptor compression #4227

Merged · 102 commits · Nov 1, 2024
Changes from 96 commits

Commits
6bb8795
yan devel
cherryWangY Jun 9, 2024
7ecd122
tabulate_fusion_se_t
cherryWangY Jun 10, 2024
ab670ed
tabulate_fusion_all_op_basic_verion
cherryWangY Jun 10, 2024
9fc3fb0
compile safe version
cherryWangY Jun 10, 2024
cab50c9
compile safe version
cherryWangY Jun 10, 2024
e9ccb98
Merge branch 'devel' of https://github.com/cherryWangY/deepmd-kit int…
cherryWangY Jun 13, 2024
87909e2
se_a & se_atten
cherryWangY Jun 13, 2024
2225975
se_r
cherryWangY Jun 13, 2024
c09a7a7
remove print
cherryWangY Jun 13, 2024
ee5b64e
move pt op test
cherryWangY Jun 13, 2024
c7efbce
Merge remote-tracking branch 'upstream/devel' into devel
cherryWangY Jun 13, 2024
763c7b4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2024
3fe4b64
remove print
cherryWangY Jun 13, 2024
6f76ccf
fixed for commit
cherryWangY Jun 13, 2024
b89ceed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2024
5a8b77e
fix pull request warning
cherryWangY Jun 13, 2024
25ca8c1
fix pr warning
cherryWangY Jun 13, 2024
f1c43f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2024
179e175
gpu test debug
cherryWangY Jun 16, 2024
34c664c
merge
cherryWangY Jun 16, 2024
4cc1478
merge
cherryWangY Jun 16, 2024
5921a60
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 16, 2024
b63209c
table_info set cpu
cherryWangY Jun 17, 2024
a824a80
remove print
cherryWangY Jun 17, 2024
8527819
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2024
e47dcba
add dtype=float64
cherryWangY Jun 17, 2024
95a9566
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2024
114f7a6
add dtype=float64
cherryWangY Jun 17, 2024
9e677f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2024
22ae3b7
test both float64 and float32
njzjz Jun 17, 2024
9920e57
skip tests if customized ops are not enables
njzjz Jun 17, 2024
7dd7f6a
Merge branch 'devel' into devel
cherryWangY Jun 18, 2024
82a5035
Merge branch 'devel' into devel
cherryWangY Jun 20, 2024
e6bc120
Merge branch 'devel' into tabulate_op
njzjz Jun 21, 2024
1523196
reduce test size from 192 atoms to 4 atoms
njzjz Jun 21, 2024
4c6f69b
Merge branch 'deepmodeling:devel' into devel
cherryWangY Jun 29, 2024
ff57db2
basic descriptor se_a
cherryWangY Jul 7, 2024
fdb13bb
Merge branch 'devel' of https://github.com/cherryWangY/deepmd-kit int…
cherryWangY Jul 7, 2024
882297c
basic descriptor se_a
cherryWangY Jul 7, 2024
4afd634
basic descriptor se_a
cherryWangY Jul 7, 2024
6880e8b
basic torch version
cherryWangY Jul 27, 2024
48a7508
compressed se_a test
cherryWangY Aug 6, 2024
1af34ce
se_a debug
cherryWangY Aug 28, 2024
e9915e3
four descriptors compression
cherryWangY Oct 17, 2024
322ad02
remove redundant code for pt descriptor compression
cherryWangY Oct 21, 2024
d7813d9
align to latest version
cherryWangY Oct 21, 2024
0331c43
align to latest version
cherryWangY Oct 21, 2024
d6cea9b
Merge branch 'devel' into devel
cherryWangY Oct 21, 2024
8984eb9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2024
c38472a
remove extra indent
cherryWangY Oct 22, 2024
7d9d1e2
fix pre-commit
cherryWangY Oct 22, 2024
ed1d598
Merge branch 'devel' into devel
cherryWangY Oct 22, 2024
f12fa1b
enhance code robustness
cherryWangY Oct 23, 2024
95e4f9f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2024
776816f
fix pre-commit
cherryWangY Oct 23, 2024
a0d7403
solve some problems
cherryWangY Oct 24, 2024
0d46b21
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2024
2617b6f
fix annotation error
cherryWangY Oct 24, 2024
7b9332b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2024
e2106af
fix self.table and remove locals
cherryWangY Oct 25, 2024
6e1f787
merge latest
cherryWangY Oct 25, 2024
2802947
fix se_atten
cherryWangY Oct 25, 2024
b572034
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2024
8e86382
fix gg undefine
cherryWangY Oct 25, 2024
3eb8d3f
Merge remote-tracking branch 'new-fork/devel' into devel
cherryWangY Oct 25, 2024
7d84da6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2024
37b1644
make torchscript happy
njzjz Oct 25, 2024
c0ab006
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2024
c924750
add fake op for freezing
cherryWangY Oct 26, 2024
78f20fd
add explicit device
cherryWangY Oct 26, 2024
755401e
add parametized tests
cherryWangY Oct 26, 2024
976c8df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2024
126a170
fix parameterized
cherryWangY Oct 26, 2024
706d614
remove useless exception
cherryWangY Oct 26, 2024
b92c1f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2024
4d3c1b8
fix prec bug
cherryWangY Oct 26, 2024
619e5fb
Merge branch 'devel' of https://github.com/cherryWangY/deepmd-kit int…
cherryWangY Oct 26, 2024
1b1f0e9
set se_t precesion as 1e-6 when float64
cherryWangY Oct 27, 2024
b3737b9
avoid using env.DEVICE in the forward
cherryWangY Oct 27, 2024
90c5ac3
change enable_compression
cherryWangY Oct 27, 2024
5d8c96a
solve coderabbitai conversations
cherryWangY Oct 27, 2024
feda81b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 27, 2024
7231776
change error info
cherryWangY Oct 29, 2024
05c9145
avoid calling serialize() multiple times
cherryWangY Oct 29, 2024
dbe92ef
add enable_compression() to BaseBaseModel
cherryWangY Oct 29, 2024
a7b9f66
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
f5d1757
remove enable_compression() in base_model
cherryWangY Oct 30, 2024
3fb094a
add enable_compression() in base_descriptor
cherryWangY Oct 30, 2024
621b45f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2024
99f9f6b
simplified code
cherryWangY Oct 30, 2024
4df6df0
update branch
cherryWangY Oct 30, 2024
30c722c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2024
7c86385
merge pt and tf similar implementation of tabulate
cherryWangY Oct 31, 2024
2ce4356
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
39d6d17
Merge branch 'devel' into devel
cherryWangY Oct 31, 2024
cb15335
set is_pt
cherryWangY Oct 31, 2024
15bf3e2
fix for loop; fix codeql warnings
njzjz Oct 31, 2024
689748e
add comment at descrpt SeT build()
cherryWangY Nov 1, 2024
9c7534e
Refactor duplicate code in _get_bias and _get_matrix methods
cherryWangY Nov 1, 2024
5794248
fix device inconsistency
cherryWangY Nov 1, 2024
c06f54a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2024
006c159
Merge branch 'devel' into devel
cherryWangY Nov 1, 2024
25 changes: 25 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
@@ -147,6 +147,31 @@ def compute_input_stats(
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.

Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
raise NotImplementedError("This descriptor doesn't support compression!")
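The hunk above adds `enable_compression` to the descriptor base class as an opt-in hook: descriptors that cannot be tabulated inherit the `NotImplementedError`, while compressible descriptors override it and guard against enabling twice. A minimal standalone sketch of that pattern — the class names here are illustrative, not the actual deepmd classes:

```python
class BaseDescriptor:
    """Base class: compression is unsupported unless a subclass opts in."""

    def enable_compression(
        self,
        min_nbor_dist: float,
        table_extrapolate: float = 5,
        table_stride_1: float = 0.01,
        table_stride_2: float = 0.1,
        check_frequency: int = -1,
    ) -> None:
        raise NotImplementedError("This descriptor doesn't support compression!")


class CompressibleDescriptor(BaseDescriptor):
    """Subclass that opts in, mirroring the guard used in the PR."""

    def __init__(self) -> None:
        self.compress = False

    def enable_compression(
        self,
        min_nbor_dist: float,
        table_extrapolate: float = 5,
        table_stride_1: float = 0.01,
        table_stride_2: float = 0.1,
        check_frequency: int = -1,
    ) -> None:
        if self.compress:
            raise ValueError("Compression is already enabled.")
        # A real implementation would build the lookup tables here.
        self.compress = True
```

This keeps the abstract base free of any tabulation dependency while giving every caller a uniform entry point.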

@abstractmethod
def fwd(
self,
87 changes: 87 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
@@ -24,9 +24,15 @@
from deepmd.pt.utils.env import (
RESERVED_PRECISON_DICT,
)
from deepmd.pt.utils.tabulate import (
DPTabulate,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.pt.utils.utils import (
ActivationFn,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
@@ -261,6 +267,8 @@ def __init__(
if ln_eps is None:
ln_eps = 1e-5

self.tebd_input_mode = tebd_input_mode

del type, spin, attn_mask
self.se_atten = DescrptBlockSeAtten(
rcut,
@@ -293,6 +301,7 @@ def __init__(
self.use_econf_tebd = use_econf_tebd
self.use_tebd_bias = use_tebd_bias
self.type_map = type_map
self.compress = False
self.type_embedding = TypeEmbedNet(
ntypes,
tebd_dim,
@@ -551,6 +560,84 @@ def t_cvt(xx):
)
return obj

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.

Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
# do some checks before the model compression process
if self.compress:
raise ValueError("Compression is already enabled.")
assert (
not self.se_atten.resnet_dt
), "Model compression error: descriptor resnet_dt must be false!"
for tt in self.se_atten.exclude_types:
if (tt[0] not in range(self.se_atten.ntypes)) or (
tt[1] not in range(self.se_atten.ntypes)
):
raise RuntimeError(
"exclude types "
+ str(tt)
+ " must be within the number of atomic types "
+ str(self.se_atten.ntypes)
+ "!"
)
if (
self.se_atten.ntypes * self.se_atten.ntypes
- len(self.se_atten.exclude_types)
== 0
):
raise RuntimeError(
"Empty embedding-nets are not supported in model compression!"
)

if self.se_atten.attn_layer != 0:
raise RuntimeError("Cannot compress model when attention layer is not 0.")

if self.tebd_input_mode != "strip":
raise RuntimeError("Cannot compress model when tebd_input_mode is not 'strip'.")

data = self.serialize()
self.table = DPTabulate(
self,
data["neuron"],
data["type_one_side"],
data["exclude_types"],
ActivationFn(data["activation_function"]),
)
self.table_config = [
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
]
self.lower, self.upper = self.table.build(
min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
)

self.se_atten.enable_compression(
self.table.data, self.table_config, self.lower, self.upper
)
self.compress = True

def forward(
self,
extended_coord: torch.Tensor,
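Before building any tables, the DPA-1 `enable_compression` above validates its configuration: every excluded type pair must name valid atomic types, and excluding all pairs would leave no embedding nets to tabulate. The same checks can be factored into a small standalone helper — a sketch under illustrative names, not the PR's actual function:

```python
def check_exclude_types(ntypes: int, exclude_types: list) -> None:
    """Validate pairwise type exclusions before model compression.

    Mirrors the two guards in enable_compression: each excluded pair must
    reference existing atomic types, and at least one embedding net must
    remain after exclusion.
    """
    for tt in exclude_types:
        if tt[0] not in range(ntypes) or tt[1] not in range(ntypes):
            raise RuntimeError(
                f"exclude types {tt} must be within the number of atomic types {ntypes}!"
            )
    # ntypes * ntypes pairs exist in total; excluding them all leaves
    # nothing to compress.
    if ntypes * ntypes - len(exclude_types) == 0:
        raise RuntimeError(
            "Empty embedding-nets are not supported in model compression!"
        )
```

Failing fast here avoids building a tabulation for a configuration the fused kernels cannot serve.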
134 changes: 130 additions & 4 deletions deepmd/pt/model/descriptor/se_a.py
@@ -58,11 +58,34 @@
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.tabulate import (
DPTabulate,
)
from deepmd.pt.utils.utils import (
ActivationFn,
)

from .base_descriptor import (
BaseDescriptor,
)

if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_a"):

def tabulate_fusion_se_a(
argument0,
argument1,
argument2,
argument3,
argument4,
) -> list[torch.Tensor]:
raise NotImplementedError(
"tabulate_fusion_se_a is not available since customized PyTorch OP library is not built when freezing the model. "
"See documentation for model compression for details."
)

# Note: this hack cannot actually save a model that can be run with LAMMPS.
torch.ops.deepmd.tabulate_fusion_se_a = tabulate_fusion_se_a
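
The block above registers a Python placeholder on `torch.ops.deepmd` when the compiled `tabulate_fusion_se_a` op is missing, so freezing can still resolve the symbol while any real call fails with a clear message. The pattern in isolation — a generic namespace and a hypothetical helper name, not deepmd code:

```python
from types import SimpleNamespace

# Stand-in for an op namespace such as torch.ops.deepmd (illustrative only).
ops = SimpleNamespace()


def install_fallback(namespace, name: str) -> None:
    """Register a stub that fails loudly only when invoked.

    Tracing/freezing can then resolve the symbol even though the compiled
    custom-op library was never built.
    """
    if not hasattr(namespace, name):

        def _stub(*args, **kwargs):
            raise NotImplementedError(
                f"{name} is not available: the customized op library was not built."
            )

        setattr(namespace, name, _stub)


install_fallback(ops, "tabulate_fusion_se_a")
```

The `hasattr` guard ensures the stub never shadows the real kernel when the compiled library is present.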


@BaseDescriptor.register("se_e2_a")
@BaseDescriptor.register("se_a")
@@ -93,6 +116,7 @@ def __init__(
raise NotImplementedError("old implementation of spin is not supported.")
super().__init__()
self.type_map = type_map
self.compress = False
self.sea = DescrptBlockSeA(
rcut,
rcut_smth,
@@ -225,6 +249,53 @@ def reinit_exclude(
"""Update the type exclusions."""
self.sea.reinit_exclude(exclude_types)

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.

Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
if self.compress:
raise ValueError("Compression is already enabled.")
data = self.serialize()
self.table = DPTabulate(
self,
data["neuron"],
data["type_one_side"],
data["exclude_types"],
ActivationFn(data["activation_function"]),
)
self.table_config = [
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
]
self.lower, self.upper = self.table.build(
min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
)
self.sea.enable_compression(
self.table.data, self.table_config, self.lower, self.upper
)
self.compress = True

def forward(
self,
coord_ext: torch.Tensor,
@@ -366,6 +437,10 @@ def update_sel(
class DescrptBlockSeA(DescriptorBlock):
ndescrpt: Final[int]
__constants__: ClassVar[list] = ["ndescrpt"]
lower: dict[str, int]
upper: dict[str, int]
table_data: dict[str, torch.Tensor]
table_config: list[Union[int, float]]

def __init__(
self,
@@ -425,6 +500,13 @@ def __init__(
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)

# add for compression
self.compress = False
self.lower = {}
self.upper = {}
self.table_data = {}
self.table_config = []

ndim = 1 if self.type_one_side else 2
filter_layers = NetworkCollection(
ndim=ndim, ntypes=len(sel), network_type="embedding_network"
@@ -443,6 +525,7 @@ def __init__(
self.filter_layers = filter_layers
self.stats = None
# set trainable
self.trainable = trainable
for param in self.parameters():
param.requires_grad = trainable

@@ -470,6 +553,10 @@ def get_dim_out(self) -> int:
"""Returns the output dimension."""
return self.dim_out

def get_dim_rot_mat_1(self) -> int:
"""Returns the first dimension of the rotation matrix. The rotation is of shape dim_1 x 3."""
return self.filter_neuron[-1]

def get_dim_emb(self) -> int:
"""Returns the output dimension."""
return self.neuron[-1]
@@ -578,6 +665,19 @@ def reinit_exclude(
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

def enable_compression(
self,
table_data,
table_config,
lower,
upper,
) -> None:
self.compress = True
self.table_data = table_data
self.table_config = table_config
self.lower = lower
self.upper = upper

def forward(
self,
nlist: torch.Tensor,
@@ -627,6 +727,7 @@ def forward(
for embedding_idx, ll in enumerate(self.filter_layers.networks):
if self.type_one_side:
ii = embedding_idx
ti = -1
# torch.jit is not happy with slice(None)
# ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device)
# applying a mask seems to cause performance degradation
@@ -648,10 +749,35 @@
rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :]
rr = rr * mm[:, :, None]
ss = rr[:, :, :1]
# nfnl x nt x ng
gg = ll.forward(ss)
# nfnl x 4 x ng
gr = torch.matmul(rr.permute(0, 2, 1), gg)

if self.compress:
if self.type_one_side:
net = "filter_-1_net_" + str(ii)
else:
net = "filter_" + str(ti) + "_net_" + str(ii)
info = [
self.lower[net],
self.upper[net],
self.upper[net] * self.table_config[0],
self.table_config[1],
self.table_config[2],
self.table_config[3],
]
ss = ss.reshape(-1, 1) # xyz_scatter_tensor in tf
tensor_data = self.table_data[net].to(ss.device).to(dtype=self.prec)
gr = torch.ops.deepmd.tabulate_fusion_se_a(
tensor_data.contiguous(),
torch.tensor(info, dtype=self.prec, device="cpu").contiguous(),
ss.contiguous(),
rr.contiguous(),
self.filter_neuron[-1],
)[0]
else:
# nfnl x nt x ng
gg = ll.forward(ss)
# nfnl x 4 x ng
gr = torch.matmul(rr.permute(0, 2, 1), gg)

if ti_mask is not None:
xyz_scatter[ti_mask] += gr
else:
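In the compressed `forward` branch of `se_a.py`, the op receives `info = [lower, upper, upper * table_extrapolate, stride_1, stride_2, check_frequency]`, which partitions the input range into a fine-grained table on `[lower, upper)` with stride `stride_1` and a coarser extrapolation region above `upper` with stride `stride_2`. The real `tabulate_fusion_se_a` is a compiled C++/CUDA kernel evaluating polynomial spline segments; the pure-Python sketch below only illustrates how such an `info` layout maps an input to a table segment (the function name is illustrative, not a deepmd API):

```python
def table_segment(
    x: float,
    lower: float,
    upper: float,
    extrapolate: float,
    stride1: float,
    stride2: float,
) -> int:
    """Locate the table segment an input value falls into.

    Fine-grained region: [lower, upper) with spacing stride1.
    Coarse extrapolation region: [upper, extrapolate * upper) with
    spacing stride2 (the clamp at extrapolate * upper is omitted here).
    """
    first_size = int((upper - lower) / stride1)  # segments in the fine region
    if x < lower:
        return 0  # clamp below the table
    if x < upper:
        return int((x - lower) / stride1)
    return first_size + int((x - upper) / stride2)
```

This is why a smaller `table_stride_1` trades memory for accuracy near typical neighbor distances, while `table_stride_2` only affects the rarely-hit extrapolation region.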