Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: consistent DPA-1 model #4320

Merged
merged 6 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class DescrptHybrid(BaseDescriptor, NativeOP):
def __init__(
self,
list: list[Union[BaseDescriptor, dict[str, Any]]],
type_map: Optional[list[str]] = None,
ntypes: Optional[int] = None, # to be compat with input
) -> None:
super().__init__()
# warning: list is conflict with built-in list
Expand All @@ -56,6 +58,10 @@ def __init__(
if isinstance(ii, BaseDescriptor):
formatted_descript_list.append(ii)
elif isinstance(ii, dict):
ii = ii.copy()
# only pass if not already set
ii.setdefault("type_map", type_map)
ii.setdefault("ntypes", ntypes)
formatted_descript_list.append(BaseDescriptor(**ii))
else:
raise NotImplementedError
Expand Down
14 changes: 4 additions & 10 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
)
from deepmd.dpmodel.fitting.ener_fitting import (
EnergyFittingNet,
)
Expand Down Expand Up @@ -39,16 +36,13 @@ def get_standard_model(data: dict) -> EnergyModel:
data : dict
The data to construct the model.
"""
descriptor_type = data["descriptor"].pop("type")
data["descriptor"]["type_map"] = data["type_map"]
data["descriptor"]["ntypes"] = len(data["type_map"])
fitting_type = data["fitting_net"].pop("type")
data["fitting_net"]["type_map"] = data["type_map"]
if descriptor_type == "se_e2_a":
descriptor = DescrptSeA(
**data["descriptor"],
)
else:
raise ValueError(f"Unknown descriptor type {descriptor_type}")
descriptor = BaseDescriptor(
**data["descriptor"],
)
if fitting_type == "ener":
fitting = EnergyFittingNet(
ntypes=descriptor.get_ntypes(),
Expand Down
4 changes: 4 additions & 0 deletions deepmd/jax/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from deepmd.jax.descriptor.hybrid import (
DescrptHybrid,
)
from deepmd.jax.descriptor.se_atten_v2 import (
DescrptSeAttenV2,
)
from deepmd.jax.descriptor.se_e2_a import (
DescrptSeA,
)
Expand All @@ -27,6 +30,7 @@
"DescrptSeT",
"DescrptSeTTebd",
"DescrptDPA1",
"DescrptSeAttenV2",
"DescrptDPA2",
"DescrptHybrid",
]
1 change: 1 addition & 0 deletions deepmd/jax/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def get_standard_model(data: dict):
data = deepcopy(data)
descriptor_type = data["descriptor"].pop("type")
data["descriptor"]["type_map"] = data["type_map"]
data["descriptor"]["ntypes"] = len(data["type_map"])
fitting_type = data["fitting_net"].pop("type")
data["fitting_net"]["type_map"] = data["type_map"]
descriptor = BaseDescriptor.get_class_by_type(descriptor_type)(
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,8 @@ def _forward_common(

if nd != self.dim_descrpt:
raise ValueError(
"get an input descriptor of dim {nd},"
"which is not consistent with {self.dim_descrpt}."
f"get an input descriptor of dim {nd},"
f"which is not consistent with {self.dim_descrpt}."
)
# check fparam dim, concate to input descriptor
if self.numb_fparam > 0:
Expand Down
6 changes: 6 additions & 0 deletions deepmd/tf/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,8 @@
return local_jdata_cpy, min_nbor_dist

def serialize(self, suffix: str = "") -> dict:
if hasattr(self, "type_embedding"):
raise NotImplementedError("hybrid + type embedding is not supported")

Check warning on line 465 in deepmd/tf/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/hybrid.py#L465

Added line #L465 was not covered by tests
return {
"@class": "Descriptor",
"type": "hybrid",
Expand All @@ -485,4 +487,8 @@
for idx, ii in enumerate(data["list"])
],
)
# search for type embedding
for ii in obj.descrpt_list:
if hasattr(ii, "type_embedding"):
raise NotImplementedError("hybrid + type embedding is not supported")

Check warning on line 493 in deepmd/tf/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/hybrid.py#L493

Added line #L493 was not covered by tests
njzjz marked this conversation as resolved.
Show resolved Hide resolved
return obj
123 changes: 96 additions & 27 deletions deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,11 @@
if scaling_factor != 1.0:
raise NotImplementedError("scaling_factor is not supported.")
if not normalize:
raise NotImplementedError("normalize is not supported.")
raise NotImplementedError("Disabling normalize is not supported.")

Check warning on line 222 in deepmd/tf/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten.py#L222

Added line #L222 was not covered by tests
if temperature is not None:
raise NotImplementedError("temperature is not supported.")
if not concat_output_tebd:
raise NotImplementedError("concat_output_tebd is not supported.")
raise NotImplementedError("Disbaling concat_output_tebd is not supported.")

Check warning on line 226 in deepmd/tf/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten.py#L226

Added line #L226 was not covered by tests
njzjz marked this conversation as resolved.
Show resolved Hide resolved
if env_protection != 0.0:
raise NotImplementedError("env_protection != 0.0 is not supported.")
# to keep consistent with default value in this backends
Expand Down Expand Up @@ -1866,7 +1866,11 @@
if cls is not DescrptSeAtten:
raise NotImplementedError(f"Not implemented in class {cls.__name__}")
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
if data["smooth_type_embedding"]:
raise RuntimeError(

Check warning on line 1870 in deepmd/tf/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten.py#L1870

Added line #L1870 was not covered by tests
"The implementation for smooth_type_embedding is inconsistent with other backends"
)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
data.pop("type")
embedding_net_variables = cls.deserialize_network(
Expand All @@ -1878,10 +1882,13 @@
data.pop("env_mat")
variables = data.pop("@variables")
tebd_input_mode = data["tebd_input_mode"]
if tebd_input_mode in ["strip"]:
raise ValueError(
"Deserialization is unsupported for `tebd_input_mode='strip'` in the native model."
)
type_embedding = TypeEmbedNet.deserialize(
data.pop("type_embedding"), suffix=suffix
)
if "use_tebd_bias" not in data:
# v1 compatibility
data["use_tebd_bias"] = True

Check warning on line 1890 in deepmd/tf/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten.py#L1890

Added line #L1890 was not covered by tests
type_embedding.use_tebd_bias = data.pop("use_tebd_bias")
descriptor = cls(**data)
descriptor.embedding_net_variables = embedding_net_variables
descriptor.attention_layer_variables = attention_layer_variables
Expand All @@ -1891,6 +1898,17 @@
descriptor.dstd = variables["dstd"].reshape(
descriptor.ntypes, descriptor.ndescrpt
)
descriptor.type_embedding = type_embedding
if tebd_input_mode in ["strip"]:
type_one_side = data["type_one_side"]
two_side_embeeding_net_variables = cls.deserialize_network_strip(
data.pop("embeddings_strip"),
suffix=suffix,
type_one_side=type_one_side,
)
descriptor.two_side_embeeding_net_variables = (
two_side_embeeding_net_variables
)
return descriptor

def serialize(self, suffix: str = "") -> dict:
Expand All @@ -1906,10 +1924,9 @@
dict
The serialized data
"""
if self.stripped_type_embedding and type(self) is DescrptSeAtten:
# only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'
raise NotImplementedError(
"serialization is unsupported by the native model when tebd_input_mode=='strip'"
if self.smooth:
raise RuntimeError(

Check warning on line 1928 in deepmd/tf/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten.py#L1928

Added line #L1928 was not covered by tests
"The implementation for smooth_type_embedding is inconsistent with other backends"
)
# todo support serialization when tebd_input_mode=='strip' and type_one_side is True
if self.stripped_type_embedding and self.type_one_side:
Expand All @@ -1927,10 +1944,18 @@
assert self.davg is not None
assert self.dstd is not None

tebd_dim = self.type_embedding.neuron[0]
if self.tebd_input_mode in ["concat"]:
if not self.type_one_side:
embd_input_dim = 1 + tebd_dim * 2
else:
embd_input_dim = 1 + tebd_dim

Check warning on line 1952 in deepmd/tf/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten.py#L1952

Added line #L1952 was not covered by tests
else:
embd_input_dim = 1
data = {
"@class": "Descriptor",
"type": "se_atten",
"@version": 1,
"type": "dpa1",
"@version": 2,
"rcut": self.rcut_r,
"rcut_smth": self.rcut_r_smth,
"sel": self.sel_a,
Expand All @@ -1952,9 +1977,7 @@
"embeddings": self.serialize_network(
ntypes=self.ntypes,
ndim=0,
in_dim=1
if not hasattr(self, "embd_input_dim")
else self.embd_input_dim,
in_dim=embd_input_dim,
neuron=self.filter_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.filter_resnet_dt,
Expand Down Expand Up @@ -1986,17 +2009,23 @@
"type_one_side": self.type_one_side,
"spin": self.spin,
}
data["type_embedding"] = self.type_embedding.serialize(suffix=suffix)
data["use_tebd_bias"] = self.type_embedding.use_tebd_bias
data["tebd_dim"] = tebd_dim
if len(self.type_embedding.neuron) > 1:
raise NotImplementedError(

Check warning on line 2016 in deepmd/tf/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten.py#L2016

Added line #L2016 was not covered by tests
"Only support single layer type embedding network"
)
if self.tebd_input_mode in ["strip"]:
assert (
type(self) is not DescrptSeAtten
), "only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'"
# assert (
# type(self) is not DescrptSeAtten
# ), "only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'"
data.update(
{
"embeddings_strip": self.serialize_network_strip(
ntypes=self.ntypes,
ndim=0,
in_dim=2
* self.tebd_dim, # only DescrptDPA1Compat has this attribute
in_dim=2 * tebd_dim,
neuron=self.filter_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.filter_resnet_dt,
Expand All @@ -2006,8 +2035,54 @@
)
}
)
# default values
data.update(
{
"scaling_factor": 1.0,
"normalize": True,
"temperature": None,
"concat_output_tebd": True,
"use_econf_tebd": False,
}
)
data["attention_layers"] = self.update_attention_layers_serialize(
data["attention_layers"]
)
return data

def update_attention_layers_serialize(self, data: dict):
    """Update the serialized attention-layer data to be consistent with
    other backend references.

    Parameters
    ----------
    data : dict
        The serialized ``attention_layers`` payload produced by this backend.

    Returns
    -------
    dict
        A new dict with the reference ``NeighborGatedAttention`` header fields
        and with every attention layer (and its nested ``attention_layer``
        entry) augmented by this descriptor's configuration.
    """
    # Header defaults expected by the reference serialization format;
    # any keys already present in ``data`` take precedence.
    result = {
        "@class": "NeighborGatedAttention",
        "@version": 1,
        "scaling_factor": 1.0,
        "normalize": True,
        "temperature": None,
    }
    result.update(data)
    # Per-layer fields sourced from this descriptor's own configuration.
    layer_info = {
        "nnei": self.nnei_a,
        "embed_dim": self.filter_neuron[-1],
        "hidden_dim": self.att_n,
        "dotr": self.attn_dotr,
        "do_mask": self.attn_mask,
        "scaling_factor": 1.0,
        "normalize": True,
        "temperature": None,
        "precision": self.filter_precision.name,
    }
    layers = result["attention_layers"]
    for idx in range(self.attn_layer):
        entry = layers[idx]
        entry.update(layer_info)
        inner = entry["attention_layer"]
        inner.update(layer_info)
        # Reference implementation always uses a single attention head.
        inner["num_heads"] = 1
    return result


class DescrptDPA1Compat(DescrptSeAtten):
r"""Consistent version of the model for testing with other backend references.
Expand Down Expand Up @@ -2433,17 +2508,11 @@
{
"type": "dpa1",
"@version": 2,
"tebd_dim": self.tebd_dim,
"scaling_factor": self.scaling_factor,
"normalize": self.normalize,
"temperature": self.temperature,
"concat_output_tebd": self.concat_output_tebd,
"use_econf_tebd": self.use_econf_tebd,
"use_tebd_bias": self.use_tebd_bias,
"type_embedding": self.type_embedding.serialize(suffix),
}
)
data["attention_layers"] = self.update_attention_layers_serialize(
data["attention_layers"]
)
return data
16 changes: 15 additions & 1 deletion deepmd/tf/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
Optional,
)

from deepmd.tf.utils.type_embed import (
TypeEmbedNet,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -127,10 +130,13 @@
Model
The deserialized model
"""
raise RuntimeError(

Check warning on line 133 in deepmd/tf/descriptor/se_atten_v2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten_v2.py#L133

Added line #L133 was not covered by tests
"The implementation for smooth_type_embedding is inconsistent with other backends"
)
if cls is not DescrptSeAttenV2:
raise NotImplementedError(f"Not implemented in class {cls.__name__}")
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
check_version_compatibility(data.pop("@version"), 2, 1)

Check warning on line 139 in deepmd/tf/descriptor/se_atten_v2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten_v2.py#L139

Added line #L139 was not covered by tests
data.pop("@class")
data.pop("type")
embedding_net_variables = cls.deserialize_network(
Expand All @@ -147,6 +153,13 @@
suffix=suffix,
type_one_side=type_one_side,
)
type_embedding = TypeEmbedNet.deserialize(

Check warning on line 156 in deepmd/tf/descriptor/se_atten_v2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten_v2.py#L156

Added line #L156 was not covered by tests
data.pop("type_embedding"), suffix=suffix
)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
if "use_tebd_bias" not in data:

Check warning on line 159 in deepmd/tf/descriptor/se_atten_v2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten_v2.py#L159

Added line #L159 was not covered by tests
# v1 compatibility
data["use_tebd_bias"] = True
type_embedding.use_tebd_bias = data.pop("use_tebd_bias")

Check warning on line 162 in deepmd/tf/descriptor/se_atten_v2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten_v2.py#L161-L162

Added lines #L161 - L162 were not covered by tests
descriptor = cls(**data)
descriptor.embedding_net_variables = embedding_net_variables
descriptor.attention_layer_variables = attention_layer_variables
Expand All @@ -157,6 +170,7 @@
descriptor.dstd = variables["dstd"].reshape(
descriptor.ntypes, descriptor.ndescrpt
)
descriptor.type_embedding = type_embedding

Check warning on line 173 in deepmd/tf/descriptor/se_atten_v2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_atten_v2.py#L173

Added line #L173 was not covered by tests
return descriptor

def serialize(self, suffix: str = "") -> dict:
Expand Down
6 changes: 5 additions & 1 deletion deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def __init__(
len(self.layer_name) == len(self.n_neuron) + 1
), "length of layer_name should be that of n_neuron + 1"
self.mixed_types = mixed_types
self.tebd_dim = 0

def get_numb_fparam(self) -> int:
"""Get the number of frame parameters."""
Expand Down Expand Up @@ -754,6 +755,8 @@ def build(
outs = tf.reshape(outs, [-1])

tf.summary.histogram("fitting_net_output", outs)
# recover original dim_descrpt, which needs to be serialized
self.dim_descrpt = original_dim_descrpt
return tf.reshape(outs, [-1])

def init_variables(
Expand Down Expand Up @@ -908,7 +911,7 @@ def serialize(self, suffix: str = "") -> dict:
"@version": 2,
"var_name": "energy",
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
"dim_descrpt": self.dim_descrpt + self.tebd_dim,
"mixed_types": self.mixed_types,
"dim_out": 1,
"neuron": self.n_neuron,
Expand All @@ -930,6 +933,7 @@ def serialize(self, suffix: str = "") -> dict:
ndim=0 if self.mixed_types else 1,
in_dim=(
self.dim_descrpt
+ self.tebd_dim
+ self.numb_fparam
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
),
Expand Down
Loading