Skip to content

Commit

Permalink
Merge branch 'devel' into add_dpa1
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Nov 28, 2024
2 parents bc1cb38 + a6b61b9 commit 48b7146
Show file tree
Hide file tree
Showing 46 changed files with 736 additions and 75 deletions.
7 changes: 7 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ def get_sel(self) -> list[int]:
"""Get the neighbor selection."""
return self.descriptor.get_sel()

def set_case_embd(self, case_idx: int):
"""
Set the case embedding of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
self.fitting.set_case_embd(case_idx)

def mixed_types(self) -> bool:
"""If true, the model
1. assumes total number of atoms aligned across frames;
Expand Down
16 changes: 16 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ def get_model_rcuts(self) -> list[float]:
def get_sel(self) -> list[int]:
return [max([model.get_nsel() for model in self.models])]

def set_case_embd(self, case_idx: int):
"""
Set the case embedding of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
for model in self.models:
model.set_case_embd(case_idx)

def get_model_nsels(self) -> list[int]:
"""Get the processed sels for each individual models. Not distinguishing types."""
return [model.get_nsel() for model in self.models]
Expand Down Expand Up @@ -428,6 +436,14 @@ def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
data.pop("type", None)
return super().deserialize(data)

def set_case_embd(self, case_idx: int):
"""
Set the case embedding of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
# only set case_idx for dpmodel
self.models[0].set_case_embd(case_idx)

def _compute_weight(
self,
extended_coord: np.ndarray,
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ def get_sel(self) -> list[int]:
"""Returns the number of selected atoms for each type."""
pass

@abstractmethod
def set_case_embd(self, case_idx: int) -> None:
"""
Set the case embedding of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
pass

def get_nsel(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return sum(self.get_sel())
Expand Down
9 changes: 9 additions & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,15 @@ def get_type_map(self) -> list[str]:
def get_sel(self) -> list[int]:
return [self.sel]

def set_case_embd(self, case_idx: int):
"""
Set the case embedding of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
raise NotImplementedError(
"Case identification not supported for PairTabAtomicModel!"
)

def get_nsel(self) -> int:
return self.sel

Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
dim_case_embd: int = 0,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
trainable: Optional[list[bool]] = None,
Expand Down Expand Up @@ -130,6 +131,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
dim_case_embd=dim_case_embd,
rcond=rcond,
tot_ener_zero=tot_ener_zero,
trainable=trainable,
Expand Down Expand Up @@ -159,7 +161,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
check_version_compatibility(data.pop("@version", 1), 3, 1)
var_name = data.pop("var_name", None)
assert var_name == "dipole"
return super().deserialize(data)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/dos_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
dim_case_embd: int = 0,
bias_dos: Optional[np.ndarray] = None,
rcond: Optional[float] = None,
trainable: Union[bool, list[bool]] = True,
Expand All @@ -60,6 +61,7 @@ def __init__(
bias_atom=bias_dos,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
dim_case_embd=dim_case_embd,
rcond=rcond,
trainable=trainable,
activation_function=activation_function,
Expand All @@ -73,7 +75,7 @@ def __init__(
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
check_version_compatibility(data.pop("@version", 1), 3, 1)
data["numb_dos"] = data.pop("dim_out")
data.pop("tot_ener_zero", None)
data.pop("var_name", None)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
dim_case_embd: int = 0,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
trainable: Optional[list[bool]] = None,
Expand All @@ -55,6 +56,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
dim_case_embd=dim_case_embd,
rcond=rcond,
tot_ener_zero=tot_ener_zero,
trainable=trainable,
Expand All @@ -73,7 +75,7 @@ def __init__(
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
check_version_compatibility(data.pop("@version", 1), 3, 1)
data.pop("var_name")
data.pop("dim_out")
return super().deserialize(data)
Expand Down
35 changes: 34 additions & 1 deletion deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
dim_case_embd: int = 0,
bias_atom_e: Optional[np.ndarray] = None,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
Expand All @@ -127,6 +128,7 @@ def __init__(
self.resnet_dt = resnet_dt
self.numb_fparam = numb_fparam
self.numb_aparam = numb_aparam
self.dim_case_embd = dim_case_embd
self.rcond = rcond
self.tot_ener_zero = tot_ener_zero
self.trainable = trainable
Expand Down Expand Up @@ -171,11 +173,16 @@ def __init__(
self.aparam_inv_std = np.ones(self.numb_aparam, dtype=self.prec)
else:
self.aparam_avg, self.aparam_inv_std = None, None
if self.dim_case_embd > 0:
self.case_embd = np.zeros(self.dim_case_embd, dtype=self.prec)
else:
self.case_embd = None
# init networks
in_dim = (
self.dim_descrpt
+ self.numb_fparam
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
+ self.dim_case_embd
)
self.nets = NetworkCollection(
1 if not self.mixed_types else 0,
Expand Down Expand Up @@ -222,6 +229,13 @@ def get_type_map(self) -> list[str]:
"""Get the name to each type of atoms."""
return self.type_map

def set_case_embd(self, case_idx: int):
"""
Set the case embedding of this fitting net by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
self.case_embd = np.eye(self.dim_case_embd, dtype=self.prec)[case_idx]

def change_type_map(
self, type_map: list[str], model_with_new_type_stat=None
) -> None:
Expand Down Expand Up @@ -255,6 +269,8 @@ def __setitem__(self, key, value) -> None:
self.aparam_avg = value
elif key in ["aparam_inv_std"]:
self.aparam_inv_std = value
elif key in ["case_embd"]:
self.case_embd = value
elif key in ["scale"]:
self.scale = value
else:
Expand All @@ -271,6 +287,8 @@ def __getitem__(self, key):
return self.aparam_avg
elif key in ["aparam_inv_std"]:
return self.aparam_inv_std
elif key in ["case_embd"]:
return self.case_embd
elif key in ["scale"]:
return self.scale
else:
Expand All @@ -287,14 +305,15 @@ def serialize(self) -> dict:
"""Serialize the fitting to dict."""
return {
"@class": "Fitting",
"@version": 2,
"@version": 3,
"var_name": self.var_name,
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
"neuron": self.neuron,
"resnet_dt": self.resnet_dt,
"numb_fparam": self.numb_fparam,
"numb_aparam": self.numb_aparam,
"dim_case_embd": self.dim_case_embd,
"rcond": self.rcond,
"activation_function": self.activation_function,
"precision": self.precision,
Expand All @@ -303,6 +322,7 @@ def serialize(self) -> dict:
"nets": self.nets.serialize(),
"@variables": {
"bias_atom_e": to_numpy_array(self.bias_atom_e),
"case_embd": to_numpy_array(self.case_embd),
"fparam_avg": to_numpy_array(self.fparam_avg),
"fparam_inv_std": to_numpy_array(self.fparam_inv_std),
"aparam_avg": to_numpy_array(self.aparam_avg),
Expand Down Expand Up @@ -423,6 +443,19 @@ def _call_common(
axis=-1,
)

if self.dim_case_embd > 0:
assert self.case_embd is not None
case_embd = xp.tile(xp.reshape(self.case_embd, [1, 1, -1]), [nf, nloc, 1])
xx = xp.concat(
[xx, case_embd],
axis=-1,
)
if xx_zeros is not None:
xx_zeros = xp.concat(
[xx_zeros, case_embd],
axis=-1,
)

# calculate the prediction
if not self.mixed_types:
outs = xp.zeros(
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
dim_case_embd: int = 0,
bias_atom: Optional[np.ndarray] = None,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
Expand Down Expand Up @@ -155,6 +156,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
dim_case_embd=dim_case_embd,
rcond=rcond,
bias_atom_e=bias_atom,
tot_ener_zero=tot_ener_zero,
Expand Down Expand Up @@ -183,7 +185,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
check_version_compatibility(data.pop("@version", 1), 3, 1)
return super().deserialize(data)

def _net_out_dim(self):
Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
dim_case_embd: int = 0,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
trainable: Optional[list[bool]] = None,
Expand Down Expand Up @@ -150,6 +151,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
dim_case_embd=dim_case_embd,
rcond=rcond,
tot_ener_zero=tot_ener_zero,
trainable=trainable,
Expand Down Expand Up @@ -187,7 +189,7 @@ def __getitem__(self, key):
def serialize(self) -> dict:
data = super().serialize()
data["type"] = "polar"
data["@version"] = 3
data["@version"] = 4
data["embedding_width"] = self.embedding_width
data["fit_diag"] = self.fit_diag
data["shift_diag"] = self.shift_diag
Expand All @@ -198,7 +200,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 3, 1)
check_version_compatibility(data.pop("@version", 1), 4, 1)
var_name = data.pop("var_name", None)
assert var_name == "polar"
return super().deserialize(data)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/property_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
dim_case_embd: int = 0,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
mixed_types: bool = True,
Expand All @@ -99,6 +100,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
dim_case_embd=dim_case_embd,
rcond=rcond,
trainable=trainable,
activation_function=activation_function,
Expand All @@ -111,7 +113,7 @@ def __init__(
@classmethod
def deserialize(cls, data: dict) -> "PropertyFittingNet":
data = data.copy()
check_version_compatibility(data.pop("@version"), 2, 1)
check_version_compatibility(data.pop("@version"), 3, 1)
data.pop("dim_out")
data.pop("var_name")
data.pop("tot_ener_zero")
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,9 @@ def serialize(self) -> dict:
def deserialize(cls, data) -> "CM":
return cls(atomic_model_=T_AtomicModel.deserialize(data))

def set_case_embd(self, case_idx: int):
self.atomic_model.set_case_embd(case_idx)

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return self.atomic_model.get_dim_fparam()
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pd/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,13 @@ def get_sel(self) -> list[int]:
"""Get the neighbor selection."""
return self.sel

def set_case_embd(self, case_idx: int):
"""
Set the case embedding of this atomic model by the given case_idx,
typically concatenated with the output of the descriptor and fed into the fitting net.
"""
self.fitting_net.set_case_embd(case_idx)

def mixed_types(self) -> bool:
"""If true, the model
1. assumes total number of atoms aligned across frames;
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pd/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,9 @@ def serialize(self) -> dict:
def deserialize(cls, data) -> "CM":
return cls(atomic_model_=T_AtomicModel.deserialize(data))

def set_case_embd(self, case_idx: int):
self.atomic_model.set_case_embd(case_idx)

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return self.atomic_model.get_dim_fparam()
Expand Down
Loading

0 comments on commit 48b7146

Please sign in to comment.