-
Notifications
You must be signed in to change notification settings - Fork 526
Commit
Fix #4105. Fix #4429. Signed-off-by: Jinzhe Zeng <[email protected]>
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,352 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from typing import ( | ||
Optional, | ||
) | ||
|
||
import array_api_compat | ||
import numpy as np | ||
|
||
from deepmd.dpmodel.loss.loss import ( | ||
Loss, | ||
) | ||
from deepmd.utils.data import ( | ||
DataRequirementItem, | ||
) | ||
from deepmd.utils.version import ( | ||
check_version_compatibility, | ||
) | ||
|
||
|
||
class EnergyLoss(Loss): | ||
def __init__( | ||
self, | ||
starter_learning_rate: float, | ||
start_pref_e: float = 0.02, | ||
limit_pref_e: float = 1.00, | ||
start_pref_f: float = 1000, | ||
limit_pref_f: float = 1.00, | ||
start_pref_v: float = 0.0, | ||
limit_pref_v: float = 0.0, | ||
start_pref_ae: float = 0.0, | ||
limit_pref_ae: float = 0.0, | ||
start_pref_pf: float = 0.0, | ||
limit_pref_pf: float = 0.0, | ||
relative_f: Optional[float] = None, | ||
enable_atom_ener_coeff: bool = False, | ||
start_pref_gf: float = 0.0, | ||
limit_pref_gf: float = 0.0, | ||
numb_generalized_coord: int = 0, | ||
**kwargs, | ||
) -> None: | ||
self.starter_learning_rate = starter_learning_rate | ||
self.start_pref_e = start_pref_e | ||
self.limit_pref_e = limit_pref_e | ||
self.start_pref_f = start_pref_f | ||
self.limit_pref_f = limit_pref_f | ||
self.start_pref_v = start_pref_v | ||
self.limit_pref_v = limit_pref_v | ||
self.start_pref_ae = start_pref_ae | ||
self.limit_pref_ae = limit_pref_ae | ||
self.start_pref_pf = start_pref_pf | ||
self.limit_pref_pf = limit_pref_pf | ||
self.relative_f = relative_f | ||
self.enable_atom_ener_coeff = enable_atom_ener_coeff | ||
self.start_pref_gf = start_pref_gf | ||
self.limit_pref_gf = limit_pref_gf | ||
self.numb_generalized_coord = numb_generalized_coord | ||
self.has_e = self.start_pref_e != 0.0 or self.limit_pref_e != 0.0 | ||
self.has_f = self.start_pref_f != 0.0 or self.limit_pref_f != 0.0 | ||
self.has_v = self.start_pref_v != 0.0 or self.limit_pref_v != 0.0 | ||
self.has_ae = self.start_pref_ae != 0.0 or self.limit_pref_ae != 0.0 | ||
self.has_pf = self.start_pref_pf != 0.0 or self.limit_pref_pf != 0.0 | ||
self.has_gf = self.start_pref_gf != 0.0 or self.limit_pref_gf != 0.0 | ||
if self.has_gf and self.numb_generalized_coord < 1: | ||
raise RuntimeError( | ||
"When generalized force loss is used, the dimension of generalized coordinates should be larger than 0" | ||
) | ||
|
||
def call( | ||
self, | ||
learning_rate: float, | ||
natoms: int, | ||
model_dict: dict[str, np.ndarray], | ||
label_dict: dict[str, np.ndarray], | ||
) -> dict[str, np.ndarray]: | ||
"""Calculate loss from model results and labeled results.""" | ||
energy = model_dict["energy"] | ||
force = model_dict["force"] | ||
virial = model_dict["virial"] | ||
atom_ener = model_dict["atom_ener"] | ||
energy_hat = label_dict["energy"] | ||
force_hat = label_dict["force"] | ||
virial_hat = label_dict["virial"] | ||
atom_ener_hat = label_dict["atom_ener"] | ||
atom_pref = label_dict["atom_pref"] | ||
find_energy = label_dict["find_energy"] | ||
find_force = label_dict["find_force"] | ||
find_virial = label_dict["find_virial"] | ||
find_atom_ener = label_dict["find_atom_ener"] | ||
find_atom_pref = label_dict["find_atom_pref"] | ||
xp = array_api_compat.array_namespace( | ||
energy, | ||
force, | ||
virial, | ||
atom_ener, | ||
energy_hat, | ||
force_hat, | ||
virial_hat, | ||
atom_ener_hat, | ||
atom_pref, | ||
) | ||
if self.has_gf: | ||
drdq = label_dict["drdq"] | ||
Check warning Code scanning / CodeQL Variable defined multiple times Warning
This assignment to 'drdq' is unnecessary as it is
redefined Error loading related location Loading |
||
find_drdq = label_dict["find_drdq"] | ||
|
||
if self.enable_atom_ener_coeff: | ||
# when ener_coeff (\nu) is defined, the energy is defined as | ||
# E = \sum_i \nu_i E_i | ||
# instead of the sum of atomic energies. | ||
# | ||
# A case is that we want to train reaction energy | ||
# A + B -> C + D | ||
# E = - E(A) - E(B) + E(C) + E(D) | ||
# A, B, C, D could be put far away from each other | ||
atom_ener_coeff = label_dict["atom_ener_coeff"] | ||
atom_ener_coeff = xp.reshape(atom_ener_coeff, xp.shape(atom_ener)) | ||
energy = xp.sum(atom_ener_coeff * atom_ener, 1) | ||
if self.has_e: | ||
l2_ener_loss = xp.mean(xp.square(energy - energy_hat)) | ||
|
||
if self.has_f or self.has_pf or self.relative_f or self.has_gf: | ||
force_reshape = xp.reshape(force, [-1]) | ||
force_hat_reshape = xp.reshape(force_hat, [-1]) | ||
diff_f = force_hat_reshape - force_reshape | ||
|
||
if self.relative_f is not None: | ||
force_hat_3 = xp.reshape(force_hat, [-1, 3]) | ||
norm_f = xp.reshape(xp.norm(force_hat_3, axis=1), [-1, 1]) + self.relative_f | ||
diff_f_3 = xp.reshape(diff_f, [-1, 3]) | ||
Check failure Code scanning / CodeQL Potentially uninitialized local variable Error
Local variable 'diff_f' may be used before it is initialized.
|
||
diff_f_3 = diff_f_3 / norm_f | ||
diff_f = xp.reshape(diff_f_3, [-1]) | ||
|
||
if self.has_f: | ||
l2_force_loss = xp.mean(xp.square(diff_f)) | ||
|
||
if self.has_pf: | ||
atom_pref_reshape = xp.reshape(atom_pref, [-1]) | ||
l2_pref_force_loss = xp.mean( | ||
xp.multiply(xp.square(diff_f), atom_pref_reshape), | ||
) | ||
|
||
if self.has_gf: | ||
drdq = label_dict["drdq"] | ||
force_reshape_nframes = xp.reshape(force, [-1, natoms[0] * 3]) | ||
force_hat_reshape_nframes = xp.reshape(force_hat, [-1, natoms[0] * 3]) | ||
drdq_reshape = xp.reshape( | ||
drdq, [-1, natoms[0] * 3, self.numb_generalized_coord] | ||
) | ||
gen_force_hat = xp.einsum( | ||
"bij,bi->bj", drdq_reshape, force_hat_reshape_nframes | ||
) | ||
gen_force = xp.einsum("bij,bi->bj", drdq_reshape, force_reshape_nframes) | ||
diff_gen_force = gen_force_hat - gen_force | ||
l2_gen_force_loss = xp.mean(xp.square(diff_gen_force)) | ||
|
||
if self.has_v: | ||
virial_reshape = xp.reshape(virial, [-1]) | ||
virial_hat_reshape = xp.reshape(virial_hat, [-1]) | ||
l2_virial_loss = xp.mean( | ||
xp.square(virial_hat_reshape - virial_reshape), | ||
) | ||
|
||
if self.has_ae: | ||
atom_ener_reshape = xp.reshape(atom_ener, [-1]) | ||
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, [-1]) | ||
l2_atom_ener_loss = xp.mean( | ||
xp.square(atom_ener_hat_reshape - atom_ener_reshape), | ||
) | ||
|
||
atom_norm = 1.0 / natoms | ||
atom_norm_ener = 1.0 / natoms | ||
lr_ratio = learning_rate / self.starter_learning_rate | ||
pref_e = find_energy * ( | ||
self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * lr_ratio | ||
) | ||
pref_f = find_force * ( | ||
self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * lr_ratio | ||
) | ||
pref_v = find_virial * ( | ||
self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * lr_ratio | ||
) | ||
pref_ae = find_atom_ener * ( | ||
self.limit_pref_ae + (self.start_pref_ae - self.limit_pref_ae) * lr_ratio | ||
) | ||
pref_pf = find_atom_pref * ( | ||
self.limit_pref_pf + (self.start_pref_pf - self.limit_pref_pf) * lr_ratio | ||
) | ||
if self.has_gf: | ||
pref_gf = find_drdq * ( | ||
Check failure Code scanning / CodeQL Potentially uninitialized local variable Error
Local variable 'find_drdq' may be used before it is initialized.
|
||
self.limit_pref_gf | ||
+ (self.start_pref_gf - self.limit_pref_gf) * lr_ratio | ||
) | ||
|
||
l2_loss = 0 | ||
more_loss = {} | ||
if self.has_e: | ||
l2_loss += atom_norm_ener * (pref_e * l2_ener_loss) | ||
Check failure Code scanning / CodeQL Potentially uninitialized local variable Error
Local variable 'l2_ener_loss' may be used before it is initialized.
|
||
more_loss["l2_ener_loss"] = self.display_if_exist(l2_ener_loss, find_energy) | ||
if self.has_f: | ||
l2_loss += pref_f * l2_force_loss | ||
Check failure Code scanning / CodeQL Potentially uninitialized local variable Error
Local variable 'l2_force_loss' may be used before it is initialized.
|
||
more_loss["l2_force_loss"] = self.display_if_exist( | ||
l2_force_loss, find_force | ||
) | ||
if self.has_v: | ||
l2_loss += atom_norm * (pref_v * l2_virial_loss) | ||
Check failure Code scanning / CodeQL Potentially uninitialized local variable Error
Local variable 'l2_virial_loss' may be used before it is initialized.
|
||
more_loss["l2_virial_loss"] = self.display_if_exist( | ||
l2_virial_loss, find_virial | ||
) | ||
if self.has_ae: | ||
l2_loss += pref_ae * l2_atom_ener_loss | ||
Check failure Code scanning / CodeQL Potentially uninitialized local variable Error
Local variable 'l2_atom_ener_loss' may be used before it is initialized.
|
||
more_loss["l2_atom_ener_loss"] = self.display_if_exist( | ||
l2_atom_ener_loss, find_atom_ener | ||
) | ||
if self.has_pf: | ||
l2_loss += pref_pf * l2_pref_force_loss | ||
Check failure Code scanning / CodeQL Potentially uninitialized local variable Error
Local variable 'l2_pref_force_loss' may be used before it is initialized.
|
||
more_loss["l2_pref_force_loss"] = self.display_if_exist( | ||
l2_pref_force_loss, find_atom_pref | ||
) | ||
if self.has_gf: | ||
l2_loss += pref_gf * l2_gen_force_loss | ||
Check failure Code scanning / CodeQL Potentially uninitialized local variable Error
Local variable 'pref_gf' may be used before it is initialized.
Check failure Code scanning / CodeQL Potentially uninitialized local variable Error
Local variable 'l2_gen_force_loss' may be used before it is initialized.
|
||
more_loss["l2_gen_force_loss"] = self.display_if_exist( | ||
l2_gen_force_loss, find_drdq | ||
Check failure Code scanning / CodeQL Potentially uninitialized local variable Error
Local variable 'find_drdq' may be used before it is initialized.
|
||
) | ||
|
||
self.l2_l = l2_loss | ||
self.l2_more = more_loss | ||
return l2_loss, more_loss | ||
|
||
@property | ||
def label_requirement(self) -> list[DataRequirementItem]: | ||
"""Return data label requirements needed for this loss calculation.""" | ||
label_requirement = [] | ||
if self.has_e: | ||
label_requirement.append( | ||
DataRequirementItem( | ||
"energy", | ||
ndof=1, | ||
atomic=False, | ||
must=False, | ||
high_prec=True, | ||
) | ||
) | ||
if self.has_f: | ||
label_requirement.append( | ||
DataRequirementItem( | ||
"force", | ||
ndof=3, | ||
atomic=True, | ||
must=False, | ||
high_prec=False, | ||
) | ||
) | ||
if self.has_v: | ||
label_requirement.append( | ||
DataRequirementItem( | ||
"virial", | ||
ndof=9, | ||
atomic=False, | ||
must=False, | ||
high_prec=False, | ||
) | ||
) | ||
if self.has_ae: | ||
label_requirement.append( | ||
DataRequirementItem( | ||
"atom_ener", | ||
ndof=1, | ||
atomic=True, | ||
must=False, | ||
high_prec=False, | ||
) | ||
) | ||
if self.has_pf: | ||
label_requirement.append( | ||
DataRequirementItem( | ||
"atom_pref", | ||
ndof=1, | ||
atomic=True, | ||
must=False, | ||
high_prec=False, | ||
repeat=3, | ||
) | ||
) | ||
if self.has_gf > 0: | ||
label_requirement.append( | ||
DataRequirementItem( | ||
"drdq", | ||
ndof=self.numb_generalized_coord * 3, | ||
atomic=True, | ||
must=False, | ||
high_prec=False, | ||
) | ||
) | ||
if self.enable_atom_ener_coeff: | ||
label_requirement.append( | ||
DataRequirementItem( | ||
"atom_ener_coeff", | ||
ndof=1, | ||
atomic=True, | ||
must=False, | ||
high_prec=False, | ||
default=1.0, | ||
) | ||
) | ||
return label_requirement | ||
|
||
def serialize(self) -> dict: | ||
"""Serialize the loss module. | ||
Returns | ||
------- | ||
dict | ||
The serialized loss module | ||
""" | ||
return { | ||
"@class": "EnergyLoss", | ||
"@version": 1, | ||
"starter_learning_rate": self.starter_learning_rate, | ||
"start_pref_e": self.start_pref_e, | ||
"limit_pref_e": self.limit_pref_e, | ||
"start_pref_f": self.start_pref_f, | ||
"limit_pref_f": self.limit_pref_f, | ||
"start_pref_v": self.start_pref_v, | ||
"limit_pref_v": self.limit_pref_v, | ||
"start_pref_ae": self.start_pref_ae, | ||
"limit_pref_ae": self.limit_pref_ae, | ||
"start_pref_pf": self.start_pref_pf, | ||
"limit_pref_pf": self.limit_pref_pf, | ||
"relative_f": self.relative_f, | ||
"enable_atom_ener_coeff": self.enable_atom_ener_coeff, | ||
"start_pref_gf": self.start_pref_gf, | ||
"limit_pref_gf": self.limit_pref_gf, | ||
"numb_generalized_coord": self.numb_generalized_coord, | ||
} | ||
|
||
@classmethod | ||
def deserialize(cls, data: dict) -> "Loss": | ||
"""Deserialize the loss module. | ||
Parameters | ||
---------- | ||
data : dict | ||
The serialized loss module | ||
Returns | ||
------- | ||
Loss | ||
The deserialized loss module | ||
""" | ||
data = data.copy() | ||
check_version_compatibility(data.pop("@version"), 1, 1) | ||
data.pop("@class") | ||
return cls(**data) |