Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt): support complete form energy loss #3782

Merged
merged 3 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 174 additions & 51 deletions deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
Optional,
)

import torch
Expand Down Expand Up @@ -34,6 +35,11 @@
limit_pref_ae: float = 0.0,
start_pref_pf: float = 0.0,
limit_pref_pf: float = 0.0,
relative_f: Optional[float] = None,
enable_atom_ener_coeff: bool = False,
start_pref_gf: float = 0.0,
limit_pref_gf: float = 0.0,
numb_generalized_coord: int = 0,
use_l1_all: bool = False,
inference=False,
**kwargs,
Expand Down Expand Up @@ -64,6 +70,18 @@
The prefactor of atomic prefactor force loss at the start of the training.
limit_pref_pf : float
The prefactor of atomic prefactor force loss at the end of the training.
relative_f : float
If provided, relative force error will be used in the loss. The difference
of force will be normalized by the magnitude of the force in the label with
a shift given by relative_f
enable_atom_ener_coeff : bool
if true, the energy will be computed as \sum_i c_i E_i
start_pref_gf : float
The prefactor of generalized force loss at the start of the training.
limit_pref_gf : float
The prefactor of generalized force loss at the end of the training.
numb_generalized_coord : int
The dimension of generalized coordinates.
use_l1_all : bool
Whether to use L1 loss, if False (default), it will use L2 loss.
inference : bool
Expand All @@ -76,17 +94,29 @@
self.has_e = (start_pref_e != 0.0 and limit_pref_e != 0.0) or inference
self.has_f = (start_pref_f != 0.0 and limit_pref_f != 0.0) or inference
self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference

# TODO EnergyStdLoss need support for atomic energy and atomic pref
self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference
self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference
self.has_gf = (start_pref_gf != 0.0 and limit_pref_gf != 0.0) or inference

self.start_pref_e = start_pref_e
self.limit_pref_e = limit_pref_e
self.start_pref_f = start_pref_f
self.limit_pref_f = limit_pref_f
self.start_pref_v = start_pref_v
self.limit_pref_v = limit_pref_v
self.start_pref_ae = start_pref_ae
self.limit_pref_ae = limit_pref_ae
self.start_pref_pf = start_pref_pf
self.limit_pref_pf = limit_pref_pf
self.start_pref_gf = start_pref_gf
self.limit_pref_gf = limit_pref_gf
self.relative_f = relative_f
self.enable_atom_ener_coeff = enable_atom_ener_coeff
self.numb_generalized_coord = numb_generalized_coord
if self.has_gf and self.numb_generalized_coord < 1:
raise RuntimeError(

Check warning on line 117 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L117

Added line #L117 was not covered by tests
"When generalized force loss is used, the dimension of generalized coordinates should be larger than 0"
)
self.use_l1_all = use_l1_all
self.inference = inference

Expand Down Expand Up @@ -118,18 +148,35 @@
pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
pref_f = self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * coef
pref_v = self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * coef
pref_ae = self.limit_pref_ae + (self.start_pref_ae - self.limit_pref_ae) * coef
pref_pf = self.limit_pref_pf + (self.start_pref_pf - self.limit_pref_pf) * coef
pref_gf = self.limit_pref_gf + (self.start_pref_gf - self.limit_pref_gf) * coef

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
# more_loss['log_keys'] = [] # showed when validation on the fly
# more_loss['test_keys'] = [] # showed when doing dp test
atom_norm = 1.0 / natoms
if self.has_e and "energy" in model_pred and "energy" in label:
energy_pred = model_pred["energy"]
energy_label = label["energy"]
if self.enable_atom_ener_coeff and "atom_energy" in model_pred:
atom_ener_pred = model_pred["atom_energy"]
# when ener_coeff (\nu) is defined, the energy is defined as
# E = \sum_i \nu_i E_i
# instead of the sum of atomic energies.
#
# A case is that we want to train reaction energy
# A + B -> C + D
# E = - E(A) - E(B) + E(C) + E(D)
# A, B, C, D could be put far away from each other
atom_ener_coeff = label["atom_ener_coeff"]
atom_ener_coeff = atom_ener_coeff.reshape(atom_ener_pred.shape)
energy_pred = torch.sum(atom_ener_coeff * atom_ener_pred, dim=1)
find_energy = label.get("find_energy", 0.0)
pref_e = pref_e * find_energy
if not self.use_l1_all:
l2_ener_loss = torch.mean(
torch.square(model_pred["energy"] - label["energy"])
)
l2_ener_loss = torch.mean(torch.square(energy_pred - energy_label))
if not self.inference:
more_loss["l2_ener_loss"] = self.display_if_exist(
l2_ener_loss.detach(), find_energy
Expand All @@ -142,77 +189,111 @@
# more_loss['log_keys'].append('rmse_e')
else: # use l1 and for all atoms
l1_ener_loss = F.l1_loss(
model_pred["energy"].reshape(-1),
label["energy"].reshape(-1),
energy_pred.reshape(-1),
energy_label.reshape(-1),
reduction="sum",
)
loss += pref_e * l1_ener_loss
more_loss["mae_e"] = self.display_if_exist(
F.l1_loss(
model_pred["energy"].reshape(-1),
label["energy"].reshape(-1),
energy_pred.reshape(-1),
energy_label.reshape(-1),
reduction="mean",
).detach(),
find_energy,
)
# more_loss['log_keys'].append('rmse_e')
if mae:
mae_e = (
torch.mean(torch.abs(model_pred["energy"] - label["energy"]))
* atom_norm
)
mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm

Check warning on line 207 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L207

Added line #L207 was not covered by tests
more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
mae_e_all = torch.mean(
torch.abs(model_pred["energy"] - label["energy"])
)
mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))

Check warning on line 209 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L209

Added line #L209 was not covered by tests
more_loss["mae_e_all"] = self.display_if_exist(
mae_e_all.detach(), find_energy
)

if self.has_f and "force" in model_pred and "force" in label:
if (
(self.has_f or self.has_pf or self.relative_f or self.has_gf)
and "force" in model_pred
and "force" in label
):
find_force = label.get("find_force", 0.0)
pref_f = pref_f * find_force
if "force_target_mask" in model_pred:
force_target_mask = model_pred["force_target_mask"]
else:
force_target_mask = None
if not self.use_l1_all:
if force_target_mask is not None:
diff_f = (label["force"] - model_pred["force"]) * force_target_mask
force_cnt = force_target_mask.squeeze(-1).sum(-1)
l2_force_loss = torch.mean(
torch.square(diff_f).mean(-1).sum(-1) / force_cnt
)
else:
diff_f = label["force"] - model_pred["force"]
force_pred = model_pred["force"]
force_label = label["force"]
diff_f = (force_label - force_pred).reshape(-1)

if self.relative_f is not None:
force_label_3 = force_label.reshape(-1, 3)
norm_f = force_label_3.norm(dim=1, keepdim=True) + self.relative_f
diff_f_3 = diff_f.reshape(-1, 3)
diff_f_3 = diff_f_3 / norm_f
diff_f = diff_f_3.reshape(-1)

if self.has_f:
if not self.use_l1_all:
l2_force_loss = torch.mean(torch.square(diff_f))
if not self.inference:
more_loss["l2_force_loss"] = self.display_if_exist(
l2_force_loss.detach(), find_force
if not self.inference:
more_loss["l2_force_loss"] = self.display_if_exist(
l2_force_loss.detach(), find_force
)
loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
rmse_f = l2_force_loss.sqrt()
more_loss["rmse_f"] = self.display_if_exist(
rmse_f.detach(), find_force
)
loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
rmse_f = l2_force_loss.sqrt()
more_loss["rmse_f"] = self.display_if_exist(rmse_f.detach(), find_force)
else:
l1_force_loss = F.l1_loss(
label["force"], model_pred["force"], reduction="none"
)
if force_target_mask is not None:
l1_force_loss *= force_target_mask
force_cnt = force_target_mask.squeeze(-1).sum(-1)
more_loss["mae_f"] = self.display_if_exist(
(l1_force_loss.mean(-1).sum(-1) / force_cnt).mean(), find_force
)
l1_force_loss = (l1_force_loss.sum(-1).sum(-1) / force_cnt).sum()
else:
l1_force_loss = F.l1_loss(force_label, force_pred, reduction="none")

Check warning on line 245 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L245

Added line #L245 was not covered by tests
more_loss["mae_f"] = self.display_if_exist(
l1_force_loss.mean().detach(), find_force
)
l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
if mae:
mae_f = torch.mean(torch.abs(diff_f))
more_loss["mae_f"] = self.display_if_exist(mae_f.detach(), find_force)
loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)

Check warning on line 250 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L250

Added line #L250 was not covered by tests
if mae:
mae_f = torch.mean(torch.abs(diff_f))
more_loss["mae_f"] = self.display_if_exist(

Check warning on line 253 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L252-L253

Added lines #L252 - L253 were not covered by tests
mae_f.detach(), find_force
)

if self.has_pf and "atom_pref" in label:
atom_pref = label["atom_pref"]
find_atom_pref = label.get("find_atom_pref", 0.0)
pref_pf = pref_pf * find_atom_pref
atom_pref_reshape = atom_pref.reshape(-1)
l2_pref_force_loss = (torch.square(diff_f) * atom_pref_reshape).mean()
if not self.inference:
more_loss["l2_pref_force_loss"] = self.display_if_exist(
l2_pref_force_loss.detach(), find_atom_pref
)
loss += (pref_pf * l2_pref_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
rmse_pf = l2_pref_force_loss.sqrt()
more_loss["rmse_pf"] = self.display_if_exist(
rmse_pf.detach(), find_atom_pref
)

if self.has_gf and "drdq" in label:
drdq = label["drdq"]
find_drdq = label.get("find_drdq", 0.0)
pref_gf = pref_gf * find_drdq
force_reshape_nframes = force_pred.reshape(-1, natoms * 3)
force_label_reshape_nframes = force_label.reshape(-1, natoms * 3)
drdq_reshape = drdq.reshape(-1, natoms * 3, self.numb_generalized_coord)
gen_force_label = torch.einsum(
"bij,bi->bj", drdq_reshape, force_label_reshape_nframes
)
gen_force = torch.einsum(
"bij,bi->bj", drdq_reshape, force_reshape_nframes
)
diff_gen_force = gen_force_label - gen_force
l2_gen_force_loss = torch.square(diff_gen_force).mean()
if not self.inference:
more_loss["l2_gen_force_loss"] = self.display_if_exist(
l2_gen_force_loss.detach(), find_drdq
)
loss += (pref_gf * l2_gen_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
rmse_gf = l2_gen_force_loss.sqrt()
more_loss["rmse_gf"] = self.display_if_exist(
rmse_gf.detach(), find_drdq
)

if self.has_v and "virial" in model_pred and "virial" in label:
find_virial = label.get("find_virial", 0.0)
Expand All @@ -229,6 +310,27 @@
if mae:
mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)

if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label:
atom_ener = model_pred["atom_energy"]
atom_ener_label = label["atom_ener"]
find_atom_ener = label.get("find_atom_ener", 0.0)
pref_ae = pref_ae * find_atom_ener
atom_ener_reshape = atom_ener.reshape(-1)
atom_ener_label_reshape = atom_ener_label.reshape(-1)
l2_atom_ener_loss = torch.square(
atom_ener_label_reshape - atom_ener_reshape
).mean()
if not self.inference:
more_loss["l2_atom_ener_loss"] = self.display_if_exist(
l2_atom_ener_loss.detach(), find_atom_ener
)
loss += (pref_ae * l2_atom_ener_loss).to(GLOBAL_PT_FLOAT_PRECISION)
rmse_ae = l2_atom_ener_loss.sqrt()
more_loss["rmse_ae"] = self.display_if_exist(
rmse_ae.detach(), find_atom_ener
)

if not self.inference:
more_loss["rmse"] = torch.sqrt(loss.detach())
return model_pred, loss, more_loss
Expand Down Expand Up @@ -288,4 +390,25 @@
repeat=3,
)
)
if self.has_gf > 0:
label_requirement.append(

Check warning on line 394 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L394

Added line #L394 was not covered by tests
DataRequirementItem(
"drdq",
ndof=self.numb_generalized_coord * 3,
atomic=True,
must=False,
high_prec=False,
)
)
if self.enable_atom_ener_coeff:
label_requirement.append(

Check warning on line 404 in deepmd/pt/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/ener.py#L404

Added line #L404 was not covered by tests
DataRequirementItem(
"atom_ener_coeff",
ndof=1,
atomic=True,
must=False,
high_prec=False,
default=1.0,
)
)
return label_requirement
Loading