Skip to content

Commit

Permalink
add n_update_has_a
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jan 23, 2025
1 parent 7c9a5c2 commit e5e6d49
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 3 deletions.
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __init__(
n_attn_hidden: int = 64,
n_attn_head: int = 4,
pre_ln: bool = False,
n_update_has_a: bool = False,
n_update_has_a_first_sum: bool = False,
) -> None:
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
Expand Down Expand Up @@ -123,6 +125,8 @@ def __init__(
self.a_norm_use_max_v = a_norm_use_max_v
self.e_norm_use_max_v = e_norm_use_max_v
self.e_a_reduce_use_sqrt = e_a_reduce_use_sqrt
self.n_update_has_a = n_update_has_a
self.n_update_has_a_first_sum = n_update_has_a_first_sum

def __getitem__(self, key):
if hasattr(self, key):
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ def init_subclass_params(sub_data, sub_class):
a_norm_use_max_v=self.repflow_args.a_norm_use_max_v,
e_norm_use_max_v=self.repflow_args.e_norm_use_max_v,
e_a_reduce_use_sqrt=self.repflow_args.e_a_reduce_use_sqrt,
n_update_has_a=self.repflow_args.n_update_has_a,
n_update_has_a_first_sum=self.repflow_args.n_update_has_a_first_sum,
h1_dim=self.repflow_args.h1_dim,
pre_ln=self.repflow_args.pre_ln,
skip_stat=self.repflow_args.skip_stat,
Expand Down
63 changes: 60 additions & 3 deletions deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def __init__(
a_norm_use_max_v: bool = False,
e_norm_use_max_v: bool = False,
e_a_reduce_use_sqrt: bool = True,
n_update_has_a: bool = False,
n_update_has_a_first_sum: bool = False,
pre_ln: bool = False,
activation_function: str = "silu",
update_style: str = "res_residual",
Expand Down Expand Up @@ -129,6 +131,8 @@ def __init__(
self.a_norm_use_max_v = a_norm_use_max_v
self.e_norm_use_max_v = e_norm_use_max_v
self.e_a_reduce_use_sqrt = e_a_reduce_use_sqrt
self.n_update_has_a = n_update_has_a
self.n_update_has_a_first_sum = n_update_has_a_first_sum

assert update_residual_init in [
"norm",
Expand Down Expand Up @@ -376,6 +380,27 @@ def __init__(
self.a_compress_n_linear = None
self.a_compress_e_linear = None

# node angle message
if self.n_update_has_a:
self.node_angle_linear = MLPLayer(
self.angle_dim,
self.n_dim,
precision=precision,
seed=child_seed(seed, 15),
)
if self.update_style == "res_residual":
self.n_residual.append(
get_residual(
self.n_dim,
self.update_residual,
self.update_residual_init,
precision=precision,
seed=child_seed(seed, 16),
)
)
else:
self.node_angle_linear = None

# edge angle message
self.edge_angle_linear1 = MLPLayer(
self.angle_dim,
Expand Down Expand Up @@ -781,9 +806,6 @@ def forward(
else:
h1_update = None

# update node_ebd
n_updated = self.list_update(n_update_list, "node")

# edge self message
edge_self_update = self.act(self.edge_self_linear(edge_info))
e_update_list.append(edge_self_update)
Expand Down Expand Up @@ -845,6 +867,37 @@ def forward(
# nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c)
angle_info = torch.cat(angle_info_list, dim=-1)

if self.n_update_has_a:
# node angle message
assert self.node_angle_linear is not None
if not self.n_update_has_a_first_sum:
node_angle_update = self.act(self.node_angle_linear(angle_info))
# nb x nloc x a_nnei x a_nnei x n_dim
weighted_node_angle_update = (
node_angle_update
* a_sw[:, :, :, None, None]
* a_sw[:, :, None, :, None]
)
# nb x nloc x n_dim
reduced_node_angle_update = torch.sum(
torch.sum(weighted_node_angle_update, dim=-2), dim=-2
) / (self.a_sel**2)
else:
reduced_angle_info = (
angle_info
* a_sw[:, :, :, None, None]
* a_sw[:, :, None, :, None]
)
# nb x nloc x angle_dim
reduced_angle_info = torch.sum(
torch.sum(reduced_angle_info, dim=-2), dim=-2
) / (self.a_sel**2)
# nb x nloc x n_dim
reduced_node_angle_update = self.act(
self.node_angle_linear(reduced_angle_info)
)
n_update_list.append(reduced_node_angle_update)

# edge angle message
# nb x nloc x a_nnei x a_nnei x e_dim
edge_angle_update = self.act(self.edge_angle_linear1(angle_info))
Expand Down Expand Up @@ -892,6 +945,8 @@ def forward(
e_update_list.append(
self.act(self.edge_angle_linear2(padding_edge_angle_update))
)
# update node_ebd
n_updated = self.list_update(n_update_list, "node")
# update edge_ebd
e_updated = self.list_update(e_update_list, "edge")

Expand All @@ -900,6 +955,8 @@ def forward(
angle_self_update = self.act(self.angle_self_linear(angle_info))
a_update_list.append(angle_self_update)
else:
# update node_ebd
n_updated = self.list_update(n_update_list, "node")
# update edge_ebd
e_updated = self.list_update(e_update_list, "edge")

Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def __init__(
a_norm_use_max_v: bool = False,
e_norm_use_max_v: bool = False,
e_a_reduce_use_sqrt: bool = True,
n_update_has_a: bool = False,
n_update_has_a_first_sum: bool = False,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
r"""
Expand Down Expand Up @@ -217,6 +219,8 @@ def __init__(
self.h1_message_only_nei = h1_message_only_nei
self.h1_dim = h1_dim
self.update_n_has_attn = update_n_has_attn
self.n_update_has_a = n_update_has_a
self.n_update_has_a_first_sum = n_update_has_a_first_sum
self.n_attn_hidden = n_attn_hidden
self.n_attn_head = n_attn_head

Expand Down Expand Up @@ -323,6 +327,8 @@ def __init__(
update_style=self.update_style,
update_residual=self.update_residual,
update_residual_init=self.update_residual_init,
n_update_has_a=self.n_update_has_a,
n_update_has_a_first_sum=self.n_update_has_a_first_sum,
precision=precision,
pre_ln=self.pre_ln,
seed=child_seed(child_seed(seed, 1), ii),
Expand Down
12 changes: 12 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,6 +1637,18 @@ def dpa3_repflow_args():
optional=True,
default=True,
),
Argument(
"n_update_has_a",
bool,
optional=True,
default=False,
),
Argument(
"n_update_has_a_first_sum",
bool,
optional=True,
default=False,
),
]


Expand Down

0 comments on commit e5e6d49

Please sign in to comment.