Skip to content

Commit

Permalink
feat(pt): train with energy Hessian (#4169)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced support for Hessian calculations across various components,
enhancing the model's capabilities.
- Added a new loss function for Hessian, allowing for more comprehensive
training scenarios.
- New JSON configuration files for multi-task and single-task learning
models.
- Enhanced output handling to include Hessian data in model evaluations.
- Added new methods and properties to support Hessian in several classes
and modules.

- **Bug Fixes**
- Improved handling of output shapes and results related to Hessian
data.

- **Documentation**
- Updated documentation to include new Hessian properties and training
guidelines.
- Added sections detailing Hessian configurations and requirements in
the training documentation.

- **Tests**
- Added unit tests for the new Hessian-related functionalities to ensure
consistency and correctness.
- Enhanced existing test cases to incorporate Hessian data handling and
validation.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Anchor Yu <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: anyangml <[email protected]>
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
5 people authored Dec 27, 2024
1 parent bf79cc6 commit c5ad841
Show file tree
Hide file tree
Showing 50 changed files with 1,083 additions and 45 deletions.
2 changes: 1 addition & 1 deletion deepmd/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def calculate(
cell = None
symbols = self.atoms.get_chemical_symbols()
atype = [self.type_dict[k] for k in symbols]
e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)
e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)[:3]
self.results["energy"] = e[0][0]
# see https://gitlab.com/ase/ase/-/merge_requests/2485
self.results["free_energy"] = e[0][0]
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,9 @@ def _get_output_shape(self, odef, nframes, natoms):
# Something wrong here?
# return [nframes, *shape, natoms, 1]
return [nframes, natoms, *odef.shape, 1]
elif odef.category == OutputVariableCategory.DERV_R_DERV_R:
# hessian
return [nframes, 3 * natoms, 3 * natoms]
else:
raise RuntimeError("unknown category")

Expand Down
2 changes: 1 addition & 1 deletion deepmd/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def label(self, data: dict) -> dict:
cell = data["cells"].reshape((nframes, 9))
else:
cell = None
e, f, v = self.dp.eval(coord, cell, atype)
e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)[:3]
data = data.copy()
data["energies"] = e.reshape((nframes,))
data["forces"] = f.reshape((nframes, natoms, 3))
Expand Down
39 changes: 37 additions & 2 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ def test_ener(
if dp.has_spin:
data.add("spin", 3, atomic=True, must=True, high_prec=False)
data.add("force_mag", 3, atomic=True, must=False, high_prec=False)
if dp.has_hessian:
data.add("hessian", 1, atomic=True, must=True, high_prec=False)

test_data = data.get_test()
mixed_type = data.mixed_type
Expand Down Expand Up @@ -352,6 +354,9 @@ def test_ener(
energy = energy.reshape([numb_test, 1])
force = force.reshape([numb_test, -1])
virial = virial.reshape([numb_test, 9])
if dp.has_hessian:
hessian = ret[3]
hessian = hessian.reshape([numb_test, -1])
if has_atom_ener:
ae = ret[3]
av = ret[4]
Expand Down Expand Up @@ -415,6 +420,10 @@ def test_ener(
rmse_ea = rmse_e / natoms
mae_va = mae_v / natoms
rmse_va = rmse_v / natoms
if dp.has_hessian:
diff_h = hessian - test_data["hessian"][:numb_test]
mae_h = mae(diff_h)
rmse_h = rmse(diff_h)
if has_atom_ener:
diff_ae = test_data["atom_ener"][:numb_test].reshape([-1]) - ae.reshape([-1])
mae_ae = mae(diff_ae)
Expand Down Expand Up @@ -447,6 +456,9 @@ def test_ener(
if has_atom_ener:
log.info(f"Atomic ener MAE : {mae_ae:e} eV")
log.info(f"Atomic ener RMSE : {rmse_ae:e} eV")
if dp.has_hessian:
log.info(f"Hessian MAE : {mae_h:e} eV/A^2")
log.info(f"Hessian RMSE : {rmse_h:e} eV/A^2")

if detail_file is not None:
detail_path = Path(detail_file)
Expand Down Expand Up @@ -530,8 +542,24 @@ def test_ener(
"pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz",
append=append_detail,
)
if dp.has_hessian:
data_h = test_data["hessian"][:numb_test].reshape(-1, 1)
pred_h = hessian.reshape(-1, 1)
h = np.concatenate(
(
data_h,
pred_h,
),
axis=1,
)
save_txt_file(
detail_path.with_suffix(".h.out"),
h,
header=f"{system}: data_h pred_h (3Na*3Na matrix in row-major order)",
append=append_detail,
)
if not out_put_spin:
return {
dict_to_return = {
"mae_e": (mae_e, energy.size),
"mae_ea": (mae_ea, energy.size),
"mae_f": (mae_f, force.size),
Expand All @@ -544,7 +572,7 @@ def test_ener(
"rmse_va": (rmse_va, virial.size),
}
else:
return {
dict_to_return = {
"mae_e": (mae_e, energy.size),
"mae_ea": (mae_ea, energy.size),
"mae_fr": (mae_fr, force_r.size),
Expand All @@ -558,6 +586,10 @@ def test_ener(
"rmse_v": (rmse_v, virial.size),
"rmse_va": (rmse_va, virial.size),
}
if dp.has_hessian:
dict_to_return["mae_h"] = (mae_h, hessian.size)
dict_to_return["rmse_h"] = (rmse_h, hessian.size)
return dict_to_return


def print_ener_sys_avg(avg: dict[str, float]) -> None:
Expand All @@ -584,6 +616,9 @@ def print_ener_sys_avg(avg: dict[str, float]) -> None:
log.info(f"Virial RMSE : {avg['rmse_v']:e} eV")
log.info(f"Virial MAE/Natoms : {avg['mae_va']:e} eV")
log.info(f"Virial RMSE/Natoms : {avg['rmse_va']:e} eV")
if "rmse_h" in avg.keys():
log.info(f"Hessian MAE : {avg['mae_h']:e} eV/A^2")
log.info(f"Hessian RMSE : {avg['rmse_h']:e} eV/A^2")


def test_dos(
Expand Down
10 changes: 10 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class DeepEvalBackend(ABC):
# old models in v1
"global_polar": "global_polar",
"wfc": "wfc",
"energy_derv_r_derv_r": "hessian",
}

@abstractmethod
Expand Down Expand Up @@ -274,6 +275,10 @@ def get_has_spin(self) -> bool:
"""Check if the model has spin atom types."""
return False

def get_has_hessian(self) -> bool:
    """Check if the model has hessian.

    Returns
    -------
    bool
        Always ``False`` in this default implementation. Backends that can
        produce a Hessian are presumably expected to override this method.
    """
    return False

def get_var_name(self) -> str:
"""Get the name of the fitting property."""
raise NotImplementedError
Expand Down Expand Up @@ -543,6 +548,11 @@ def has_spin(self) -> bool:
"""Check if the model has spin."""
return self.deep_eval.get_has_spin()

@property
def has_hessian(self) -> bool:
"""Check if the model has hessian."""
return self.deep_eval.get_has_hessian()

def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model. Only used in the old implementation."""
return self.deep_eval.get_ntypes_spin()
Expand Down
18 changes: 16 additions & 2 deletions deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def output_def(self) -> ModelOutputDef:
r_differentiable=True,
c_differentiable=True,
atomic=True,
r_hessian=True,
),
]
)
Expand Down Expand Up @@ -99,7 +100,10 @@ def eval(
aparam: Optional[np.ndarray],
mixed_type: bool,
**kwargs: Any,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
) -> Union[
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
]:
pass

@overload
Expand All @@ -113,7 +117,10 @@ def eval(
aparam: Optional[np.ndarray],
mixed_type: bool,
**kwargs: Any,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> Union[
tuple[np.ndarray, np.ndarray, np.ndarray],
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
]:
pass

@overload
Expand Down Expand Up @@ -179,6 +186,8 @@ def eval(
atomic_virial
The atomic virial of the system, in shape (nframes, natoms, 9). Only returned
when atomic is True.
hessian
The Hessian matrix of the system, in shape (nframes, 3 * natoms, 3 * natoms). Returned when available.
"""
# This method has been used by:
# documentation python.md
Expand Down Expand Up @@ -239,6 +248,11 @@ def eval(
force_mag = results["energy_derv_r_mag"].reshape(nframes, natoms, 3)
mask_mag = results["mask_mag"].reshape(nframes, natoms, 1)
result = (*list(result), force_mag, mask_mag)
if self.deep_eval.get_has_hessian():
hessian = results["energy_derv_r_derv_r"].reshape(
nframes, 3 * natoms, 3 * natoms
)
result = (*list(result), hessian)
return result


Expand Down
3 changes: 3 additions & 0 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,9 @@ def _get_output_shape(self, odef, nframes, natoms):
elif odef.category == OutputVariableCategory.OUT:
# atom_energy, atom_tensor
return [nframes, natoms, *odef.shape, 1]
elif odef.category == OutputVariableCategory.DERV_R_DERV_R:
# hessian
return [nframes, 3 * natoms, 3 * natoms]
else:
raise RuntimeError("unknown category")

Expand Down
12 changes: 11 additions & 1 deletion deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def __init__(
] = state_dict[item].clone()
state_dict = state_dict_head
model = get_model(self.input_param).to(DEVICE)
model = torch.jit.script(model)
if not self.input_param.get("hessian_mode"):
model = torch.jit.script(model)
self.dp = ModelWrapper(model)
self.dp.load_state_dict(state_dict)
elif str(self.model_path).endswith(".pth"):
Expand Down Expand Up @@ -160,6 +161,7 @@ def __init__(
self._has_spin = getattr(self.dp.model["Default"], "has_spin", False)
if callable(self._has_spin):
self._has_spin = self._has_spin()
self._has_hessian = self.model_def_script.get("hessian_mode", False)

def get_rcut(self) -> float:
"""Get the cutoff radius of this model."""
Expand Down Expand Up @@ -243,6 +245,10 @@ def get_has_spin(self):
"""Check if the model has spin atom types."""
return self._has_spin

def get_has_hessian(self):
"""Check if the model has hessian."""
return self._has_hessian

def eval(
self,
coords: np.ndarray,
Expand Down Expand Up @@ -348,6 +354,7 @@ def _get_request_defs(self, atomic: bool) -> list[OutputVariableDef]:
OutputVariableCategory.REDU,
OutputVariableCategory.DERV_R,
OutputVariableCategory.DERV_C_REDU,
OutputVariableCategory.DERV_R_DERV_R,
)
]

Expand Down Expand Up @@ -577,6 +584,9 @@ def _get_output_shape(self, odef, nframes, natoms):
# Something wrong here?
# return [nframes, *shape, natoms, 1]
return [nframes, natoms, *odef.shape, 1]
elif odef.category == OutputVariableCategory.DERV_R_DERV_R:
return [nframes, 3 * natoms, 3 * natoms]
# return [nframes, *odef.shape, 3 * natoms, 3 * natoms]
else:
raise RuntimeError("unknown category")

Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/infer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def __init__(
] = state_dict[item].clone()
state_dict = state_dict_head

model_params.pop(
"hessian_mode", None
) # wrapper Hessian to Energy model due to JIT limit
self.model_params = deepcopy(model_params)
self.model = get_model(model_params).to(DEVICE)

Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DOSLoss,
)
from .ener import (
EnergyHessianStdLoss,
EnergyStdLoss,
)
from .ener_spin import (
Expand All @@ -24,6 +25,7 @@
__all__ = [
"DOSLoss",
"DenoiseLoss",
"EnergyHessianStdLoss",
"EnergySpinLoss",
"EnergyStdLoss",
"PropertyLoss",
Expand Down
72 changes: 72 additions & 0 deletions deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,75 @@ def label_requirement(self) -> list[DataRequirementItem]:
)
)
return label_requirement


class EnergyHessianStdLoss(EnergyStdLoss):
    """Energy loss with an extra mean-squared-error term on the Hessian."""

    def __init__(
        self,
        start_pref_h=0.0,
        limit_pref_h=0.0,
        **kwargs,
    ):
        r"""Build the energy loss and configure the optional Hessian term.

        Parameters
        ----------
        start_pref_h : float
            The prefactor of hessian loss at the start of the training.
        limit_pref_h : float
            The prefactor of hessian loss at the end of the training.
        **kwargs
            Other keyword arguments, passed on to the parent constructor.
        """
        super().__init__(**kwargs)
        self.start_pref_h = start_pref_h
        self.limit_pref_h = limit_pref_h
        # The Hessian term is switched on only when both prefactors are
        # non-zero; inference mode switches it on regardless.
        # NOTE(review): ``self.inference`` is not set here, so it is assumed
        # to come from the parent constructor -- confirm.
        both_prefs_nonzero = start_pref_h != 0.0 and limit_pref_h != 0.0
        self.has_h = both_prefs_nonzero or self.inference

    def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
        """Evaluate the parent loss and add the Hessian contribution.

        All arguments are handed unchanged to the parent ``forward``; only
        ``label``, ``learning_rate`` and ``mae`` are used again here.

        Returns
        -------
        model_pred
            The predictions returned by the parent ``forward``.
        loss
            The parent loss plus the weighted Hessian mean squared error,
            when the Hessian term is active and both the prediction and the
            label contain a ``"hessian"`` entry.
        more_loss
            The parent's extra metrics, extended with ``"rmse_h"``,
            ``"l2_hessian_loss"`` (outside inference only) and ``"mae_h"``
            (only when ``mae`` is true). Outside inference ``"rmse"`` is
            recomputed from the final loss.
        """
        model_pred, loss, more_loss = super().forward(
            input_dict, model, label, natoms, learning_rate, mae=mae
        )
        # Interpolate the prefactor between its limit and start values
        # according to the ratio of current to starting learning rate.
        lr_ratio = learning_rate / self.starter_learning_rate
        pref_h = self.limit_pref_h + (self.start_pref_h - self.limit_pref_h) * lr_ratio

        hessian_available = (
            self.has_h and "hessian" in model_pred and "hessian" in label
        )
        if hessian_available:
            # A missing "find_hessian" entry zeroes the Hessian weight.
            find_hessian = label.get("find_hessian", 0.0)
            weighted_pref_h = pref_h * find_hessian

            # Compare label and prediction as flat vectors.
            label_h = label["hessian"].reshape(-1)
            pred_h = model_pred["hessian"].reshape(-1)
            diff_h = label_h - pred_h
            l2_hessian_loss = torch.mean(torch.square(diff_h))

            if not self.inference:
                more_loss["l2_hessian_loss"] = self.display_if_exist(
                    l2_hessian_loss.detach(), find_hessian
                )
            loss += weighted_pref_h * l2_hessian_loss
            more_loss["rmse_h"] = self.display_if_exist(
                l2_hessian_loss.sqrt().detach(), find_hessian
            )
            if mae:
                more_loss["mae_h"] = self.display_if_exist(
                    torch.mean(torch.abs(diff_h)).detach(), find_hessian
                )

        if not self.inference:
            # Refresh the overall RMSE so that it reflects the final loss.
            more_loss["rmse"] = torch.sqrt(loss.detach())
        return model_pred, loss, more_loss

    @property
    def label_requirement(self) -> list[DataRequirementItem]:
        """Return the parent's label requirements, plus the Hessian label when the Hessian term is active."""
        requirements = super().label_requirement
        if not self.has_h:
            return requirements
        hessian_item = DataRequirementItem(
            "hessian",
            ndof=1,  # 9=3*3 --> 3N*3N=ndof*natoms*natoms
            atomic=True,
            must=False,
            high_prec=False,
        )
        requirements.append(hessian_item)
        return requirements
5 changes: 3 additions & 2 deletions deepmd/pt/model/descriptor/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ def _make_env_mat(
nall = coord.shape[1]
mask = nlist >= 0
# nlist = nlist * mask ## this impl will contribute nans in Hessian calculation.
nlist = torch.where(mask, nlist, nall - 1)
nlist = torch.where(mask, nlist, nall)
coord_l = coord[:, :natoms].view(bsz, -1, 1, 3)
index = nlist.view(bsz, -1).unsqueeze(-1).expand(-1, -1, 3)
coord_r = torch.gather(coord, 1, index)
coord_pad = torch.concat([coord, coord[:, -1:, :] + rcut], dim=1)
coord_r = torch.gather(coord_pad, 1, index)
coord_r = coord_r.view(bsz, natoms, nnei, 3)
diff = coord_r - coord_l
length = torch.linalg.norm(diff, dim=-1, keepdim=True)
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ def get_standard_model(model_params):
pair_exclude_types=pair_exclude_types,
preset_out_bias=preset_out_bias,
)
if model_params.get("hessian_mode"):
model.enable_hessian()
model.model_def_script = json.dumps(model_params_old)
return model

Expand Down
Loading

0 comments on commit c5ad841

Please sign in to comment.