diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 749fe6bbf9..2fa072cc78 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -65,6 +65,13 @@ def get_sel(self) -> list[int]: """Get the neighbor selection.""" return self.descriptor.get_sel() + def set_case_embd(self, case_idx: int): + """ + Set the case embedding of this atomic model by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. + """ + self.fitting.set_case_embd(case_idx) + def mixed_types(self) -> bool: """If true, the model 1. assumes total number of atoms aligned across frames; diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 9676b34bfd..8108292bd2 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -134,6 +134,14 @@ def get_model_rcuts(self) -> list[float]: def get_sel(self) -> list[int]: return [max([model.get_nsel() for model in self.models])] + def set_case_embd(self, case_idx: int): + """ + Set the case embedding of this atomic model by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. + """ + for model in self.models: + model.set_case_embd(case_idx) + def get_model_nsels(self) -> list[int]: """Get the processed sels for each individual models. Not distinguishing types.""" return [model.get_nsel() for model in self.models] @@ -428,6 +436,14 @@ def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel": data.pop("type", None) return super().deserialize(data) + def set_case_embd(self, case_idx: int): + """ + Set the case embedding of this atomic model by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. + """ + # only set case_idx for dpmodel + self.models[0].set_case_embd(case_idx) + def _compute_weight( self, extended_coord: np.ndarray, diff --git a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py index a4c38518a3..01caa7cd64 100644 --- a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py @@ -68,6 +68,14 @@ def get_sel(self) -> list[int]: """Returns the number of selected atoms for each type.""" pass + @abstractmethod + def set_case_embd(self, case_idx: int) -> None: + """ + Set the case embedding of this atomic model by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. + """ + pass + def get_nsel(self) -> int: """Returns the total number of selected neighboring atoms in the cut-off radius.""" return sum(self.get_sel()) diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index a4bffe508d..0c35320e7f 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -120,6 +120,15 @@ def get_type_map(self) -> list[str]: def get_sel(self) -> list[int]: return [self.sel] + def set_case_embd(self, case_idx: int): + """ + Set the case embedding of this atomic model by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. + """ + raise NotImplementedError( + "Case identification not supported for PairTabAtomicModel!" + ) + def get_nsel(self) -> int: return self.sel diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index c872ef0555..fcaea43338 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -95,6 +95,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, rcond: Optional[float] = None, tot_ener_zero: bool = False, trainable: Optional[list[bool]] = None, @@ -130,6 +131,7 @@ def __init__( resnet_dt=resnet_dt, numb_fparam=numb_fparam, numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, rcond=rcond, tot_ener_zero=tot_ener_zero, trainable=trainable, @@ -159,7 +161,7 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = data.copy() - check_version_compatibility(data.pop("@version", 1), 2, 1) + check_version_compatibility(data.pop("@version", 1), 3, 1) var_name = data.pop("var_name", None) assert var_name == "dipole" return super().deserialize(data) diff --git a/deepmd/dpmodel/fitting/dos_fitting.py b/deepmd/dpmodel/fitting/dos_fitting.py index b4b1ee3cb2..2f6df77eac 100644 --- a/deepmd/dpmodel/fitting/dos_fitting.py +++ b/deepmd/dpmodel/fitting/dos_fitting.py @@ -36,6 +36,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, bias_dos: Optional[np.ndarray] = None, rcond: Optional[float] = None, trainable: Union[bool, list[bool]] = True, @@ -60,6 +61,7 @@ def __init__( bias_atom=bias_dos, numb_fparam=numb_fparam, numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, rcond=rcond, trainable=trainable, activation_function=activation_function, @@ -73,7 +75,7 @@ def __init__( @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = data.copy() - check_version_compatibility(data.pop("@version", 1), 2, 1) + check_version_compatibility(data.pop("@version", 1), 3, 1) data["numb_dos"] = data.pop("dim_out") data.pop("tot_ener_zero", None) data.pop("var_name", None) diff --git a/deepmd/dpmodel/fitting/ener_fitting.py b/deepmd/dpmodel/fitting/ener_fitting.py index 53bedb4cec..6435b6468f 100644 --- a/deepmd/dpmodel/fitting/ener_fitting.py +++ b/deepmd/dpmodel/fitting/ener_fitting.py @@ -32,6 +32,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, rcond: Optional[float] = None, tot_ener_zero: bool = False, trainable: Optional[list[bool]] = None, @@ -55,6 +56,7 @@ def __init__( resnet_dt=resnet_dt, numb_fparam=numb_fparam, numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, rcond=rcond, tot_ener_zero=tot_ener_zero, trainable=trainable, @@ -73,7 +75,7 @@ def __init__( @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = data.copy() - check_version_compatibility(data.pop("@version", 1), 2, 1) + check_version_compatibility(data.pop("@version", 1), 3, 1) data.pop("var_name") data.pop("dim_out") return super().deserialize(data) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 2958a7d18d..c05d84c4a1 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -105,6 +105,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, bias_atom_e: Optional[np.ndarray] = None, rcond: Optional[float] = None, tot_ener_zero: bool = False, @@ -127,6 +128,7 @@ def __init__( self.resnet_dt = resnet_dt self.numb_fparam = numb_fparam self.numb_aparam = numb_aparam + self.dim_case_embd = dim_case_embd self.rcond = rcond self.tot_ener_zero = tot_ener_zero self.trainable = trainable @@ -171,11 +173,16 @@ def __init__( self.aparam_inv_std = np.ones(self.numb_aparam, dtype=self.prec) else: self.aparam_avg, self.aparam_inv_std = None, None + if self.dim_case_embd > 0: + self.case_embd = np.zeros(self.dim_case_embd, dtype=self.prec) + else: + self.case_embd = None # init networks in_dim = ( self.dim_descrpt + self.numb_fparam + (0 if self.use_aparam_as_mask else self.numb_aparam) + + self.dim_case_embd ) self.nets = NetworkCollection( 1 if not self.mixed_types else 0, @@ -222,6 +229,13 @@ def get_type_map(self) -> list[str]: """Get the name to each type of atoms.""" return self.type_map + def set_case_embd(self, case_idx: int): + """ + Set the case embedding of this fitting net by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. + """ + self.case_embd = np.eye(self.dim_case_embd, dtype=self.prec)[case_idx] + def change_type_map( self, type_map: list[str], model_with_new_type_stat=None ) -> None: @@ -255,6 +269,8 @@ def __setitem__(self, key, value) -> None: self.aparam_avg = value elif key in ["aparam_inv_std"]: self.aparam_inv_std = value + elif key in ["case_embd"]: + self.case_embd = value elif key in ["scale"]: self.scale = value else: @@ -271,6 +287,8 @@ def __getitem__(self, key): return self.aparam_avg elif key in ["aparam_inv_std"]: return self.aparam_inv_std + elif key in ["case_embd"]: + return self.case_embd elif key in ["scale"]: return self.scale else: @@ -287,7 +305,7 @@ def serialize(self) -> dict: """Serialize the fitting to dict.""" return { "@class": "Fitting", - "@version": 2, + "@version": 3, "var_name": self.var_name, "ntypes": self.ntypes, "dim_descrpt": self.dim_descrpt, @@ -295,6 +313,7 @@ def serialize(self) -> dict: "resnet_dt": self.resnet_dt, "numb_fparam": self.numb_fparam, "numb_aparam": self.numb_aparam, + "dim_case_embd": self.dim_case_embd, "rcond": self.rcond, "activation_function": self.activation_function, "precision": self.precision, @@ -303,6 +322,7 @@ def serialize(self) -> dict: "nets": self.nets.serialize(), "@variables": { "bias_atom_e": to_numpy_array(self.bias_atom_e), + "case_embd": to_numpy_array(self.case_embd), "fparam_avg": to_numpy_array(self.fparam_avg), "fparam_inv_std": to_numpy_array(self.fparam_inv_std), "aparam_avg": to_numpy_array(self.aparam_avg), @@ -423,6 +443,19 @@ def _call_common( axis=-1, ) + if self.dim_case_embd > 0: + assert self.case_embd is not None + case_embd = xp.tile(xp.reshape(self.case_embd, [1, 1, -1]), [nf, nloc, 1]) + xx = xp.concat( + [xx, case_embd], + axis=-1, + ) + if xx_zeros is not None: + xx_zeros = xp.concat( + [xx_zeros, case_embd], + axis=-1, + ) + # calculate the prediction if not self.mixed_types: outs = xp.zeros( diff --git a/deepmd/dpmodel/fitting/invar_fitting.py b/deepmd/dpmodel/fitting/invar_fitting.py index 219589d9ee..b5d3a02d86 100644 --- a/deepmd/dpmodel/fitting/invar_fitting.py +++ b/deepmd/dpmodel/fitting/invar_fitting.py @@ -123,6 +123,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, bias_atom: Optional[np.ndarray] = None, rcond: Optional[float] = None, tot_ener_zero: bool = False, @@ -155,6 +156,7 @@ def __init__( resnet_dt=resnet_dt, numb_fparam=numb_fparam, numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, rcond=rcond, bias_atom_e=bias_atom, tot_ener_zero=tot_ener_zero, @@ -183,7 +185,7 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = data.copy() - check_version_compatibility(data.pop("@version", 1), 2, 1) + check_version_compatibility(data.pop("@version", 1), 3, 1) return super().deserialize(data) def _net_out_dim(self): diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 021359a96e..0db6a23377 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -101,6 +101,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, rcond: Optional[float] = None, tot_ener_zero: bool = False, trainable: Optional[list[bool]] = None, @@ -150,6 +151,7 @@ def __init__( resnet_dt=resnet_dt, numb_fparam=numb_fparam, numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, rcond=rcond, tot_ener_zero=tot_ener_zero, trainable=trainable, @@ -187,7 +189,7 @@ def __getitem__(self, key): def serialize(self) -> dict: data = super().serialize() data["type"] = "polar" - data["@version"] = 3 + data["@version"] = 4 data["embedding_width"] = self.embedding_width data["fit_diag"] = self.fit_diag data["shift_diag"] = self.shift_diag @@ -198,7 +200,7 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = data.copy() - check_version_compatibility(data.pop("@version", 1), 3, 1) + check_version_compatibility(data.pop("@version", 1), 4, 1) var_name = data.pop("var_name", None) assert var_name == "polar" return super().deserialize(data) diff --git a/deepmd/dpmodel/fitting/property_fitting.py b/deepmd/dpmodel/fitting/property_fitting.py index 18a56e3bf9..8b903af00e 100644 --- a/deepmd/dpmodel/fitting/property_fitting.py +++ b/deepmd/dpmodel/fitting/property_fitting.py @@ -78,6 +78,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, @@ -99,6 +100,7 @@ def __init__( resnet_dt=resnet_dt, numb_fparam=numb_fparam, numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, rcond=rcond, trainable=trainable, activation_function=activation_function, @@ -111,7 +113,7 @@ def __init__( @classmethod def deserialize(cls, data: dict) -> "PropertyFittingNet": data = data.copy() - check_version_compatibility(data.pop("@version"), 2, 1) + check_version_compatibility(data.pop("@version"), 3, 1) data.pop("dim_out") data.pop("var_name") data.pop("tot_ener_zero") diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 70ddbe09b8..ccad72c6a5 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -552,6 +552,9 @@ def serialize(self) -> dict: def deserialize(cls, data) -> "CM": return cls(atomic_model_=T_AtomicModel.deserialize(data)) + def set_case_embd(self, case_idx: int): + self.atomic_model.set_case_embd(case_idx) + def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" return self.atomic_model.get_dim_fparam() diff --git a/deepmd/pd/model/atomic_model/dp_atomic_model.py b/deepmd/pd/model/atomic_model/dp_atomic_model.py index 47b881e0cc..25a0f89d77 100644 --- a/deepmd/pd/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pd/model/atomic_model/dp_atomic_model.py @@ -139,6 +139,13 @@ def get_sel(self) -> list[int]: """Get the neighbor selection.""" return self.sel + def set_case_embd(self, case_idx: int): + """ + Set the case embedding of this atomic model by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. + """ + self.fitting_net.set_case_embd(case_idx) + def mixed_types(self) -> bool: """If true, the model 1. assumes total number of atoms aligned across frames; diff --git a/deepmd/pd/model/model/make_model.py b/deepmd/pd/model/model/make_model.py index 67b46d4d87..d5c5c6bd41 100644 --- a/deepmd/pd/model/model/make_model.py +++ b/deepmd/pd/model/model/make_model.py @@ -516,6 +516,9 @@ def serialize(self) -> dict: def deserialize(cls, data) -> "CM": return cls(atomic_model_=T_AtomicModel.deserialize(data)) + def set_case_embd(self, case_idx: int): + self.atomic_model.set_case_embd(case_idx) + def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" return self.atomic_model.get_dim_fparam() diff --git a/deepmd/pd/model/task/ener.py b/deepmd/pd/model/task/ener.py index ed0cfac69d..789ef75066 100644 --- a/deepmd/pd/model/task/ener.py +++ b/deepmd/pd/model/task/ener.py @@ -42,6 +42,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, @@ -59,6 +60,7 @@ def __init__( resnet_dt=resnet_dt, numb_fparam=numb_fparam, numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, activation_function=activation_function, precision=precision, mixed_types=mixed_types, @@ -70,7 +72,7 @@ def __init__( @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 2, 1) + check_version_compatibility(data.pop("@version", 1), 3, 1) data.pop("var_name") data.pop("dim_out") return super().deserialize(data) diff --git a/deepmd/pd/model/task/fitting.py b/deepmd/pd/model/task/fitting.py index 9008ef8af3..375cf834cc 100644 --- a/deepmd/pd/model/task/fitting.py +++ b/deepmd/pd/model/task/fitting.py @@ -103,6 +103,9 @@ class GeneralFitting(Fitting): Number of frame parameters. numb_aparam : int Number of atomic parameters. + dim_case_embd : int + (Not supported yet) + Dimension of case specific embedding. activation_function : str Activation function. precision : str @@ -140,6 +143,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, @@ -161,6 +165,10 @@ def __init__( self.resnet_dt = resnet_dt self.numb_fparam = numb_fparam self.numb_aparam = numb_aparam + self.dim_case_embd = dim_case_embd + if dim_case_embd > 0: + raise ValueError("dim_case_embd is not supported yet in PaddlePaddle.") + self.case_embd = None self.activation_function = activation_function self.precision = precision self.prec = PRECISION_DICT[self.precision] @@ -274,7 +282,7 @@ def serialize(self) -> dict: """Serialize the fitting to dict.""" return { "@class": "Fitting", - "@version": 2, + "@version": 3, "var_name": self.var_name, "ntypes": self.ntypes, "dim_descrpt": self.dim_descrpt, @@ -282,6 +290,7 @@ def serialize(self) -> dict: "resnet_dt": self.resnet_dt, "numb_fparam": self.numb_fparam, "numb_aparam": self.numb_aparam, + "dim_case_embd": self.dim_case_embd, "activation_function": self.activation_function, "precision": self.precision, "mixed_types": self.mixed_types, @@ -290,6 +299,7 @@ def serialize(self) -> dict: "exclude_types": self.exclude_types, "@variables": { "bias_atom_e": to_numpy_array(self.bias_atom_e), + "case_embd": None, "fparam_avg": to_numpy_array(self.fparam_avg), "fparam_inv_std": to_numpy_array(self.fparam_inv_std), "aparam_avg": to_numpy_array(self.aparam_avg), @@ -349,6 +359,13 @@ def get_type_map(self) -> list[str]: """Get the name to each type of atoms.""" return self.type_map + def set_case_embd(self, case_idx: int): + """ + Set the case embedding of this fitting net by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. + """ + raise NotImplementedError("set_case_embd is not supported yet in PaddlePaddle.") + def __setitem__(self, key, value): if key in ["bias_atom_e"]: value = value.reshape([self.ntypes, self._net_out_dim()]) @@ -361,6 +378,8 @@ def __setitem__(self, key, value): self.aparam_avg = value elif key in ["aparam_inv_std"]: self.aparam_inv_std = value + elif key in ["case_embd"]: + self.case_embd = value elif key in ["scale"]: self.scale = value else: @@ -377,6 +396,8 @@ def __getitem__(self, key): return self.aparam_avg elif key in ["aparam_inv_std"]: return self.aparam_inv_std + elif key in ["case_embd"]: + return self.case_embd elif key in ["scale"]: return self.scale else: diff --git a/deepmd/pd/model/task/invar_fitting.py b/deepmd/pd/model/task/invar_fitting.py index b366fc1d2e..b92c862dc8 100644 --- a/deepmd/pd/model/task/invar_fitting.py +++ b/deepmd/pd/model/task/invar_fitting.py @@ -57,6 +57,9 @@ class InvarFitting(GeneralFitting): Number of frame parameters. numb_aparam : int Number of atomic parameters. + dim_case_embd : int + (Not supported yet) + Dimension of case specific embedding. activation_function : str Activation function. precision : str @@ -92,6 +95,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, @@ -114,6 +118,7 @@ def __init__( resnet_dt=resnet_dt, numb_fparam=numb_fparam, numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, activation_function=activation_function, precision=precision, mixed_types=mixed_types, @@ -142,7 +147,7 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = copy.deepcopy(data) - check_version_compatibility(data.pop("@version", 1), 2, 1) + check_version_compatibility(data.pop("@version", 1), 3, 1) return super().deserialize(data) def output_def(self) -> FittingOutputDef: diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 2cdc97f934..c988d63213 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -93,6 +93,13 @@ def get_sel(self) -> list[int]: """Get the neighbor selection.""" return self.sel + def set_case_embd(self, case_idx: int): + """ + Set the case embedding of this atomic model by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. + """ + self.fitting_net.set_case_embd(case_idx) + def mixed_types(self) -> bool: """If true, the model 1. assumes total number of atoms aligned across frames; diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 3a6abccdf6..36c636ddfb 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -158,6 +158,14 @@ def get_model_rcuts(self) -> list[float]: def get_sel(self) -> list[int]: return [max([model.get_nsel() for model in self.models])] + def set_case_embd(self, case_idx: int): + """ + Set the case embedding of this atomic model by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. + """ + for model in self.models: + model.set_case_embd(case_idx) + def get_model_nsels(self) -> list[int]: """Get the processed sels for each individual models. Not distinguishing types.""" return [model.get_nsel() for model in self.models] @@ -561,6 +569,14 @@ def serialize(self) -> dict: ) return dd + def set_case_embd(self, case_idx: int): + """ + Set the case embedding of this atomic model by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. + """ + # only set case_idx for dpmodel + self.models[0].set_case_embd(case_idx) + @classmethod def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel": data = data.copy() diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index 0d3b2c0c41..62b47afb32 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -141,6 +141,15 @@ def get_type_map(self) -> list[str]: def get_sel(self) -> list[int]: return [self.sel] + def set_case_embd(self, case_idx: int): + """ + Set the case embedding of this atomic model by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. + """ + raise NotImplementedError( + "Case identification not supported for PairTabAtomicModel!" + ) + def get_nsel(self) -> int: return self.sel diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index ebad588e32..c8e430960b 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -403,25 +403,8 @@ def share_params(self, base_class, shared_level, resume=False) -> None: ] self.repformers.share_params(base_class.repformers, 0, resume=resume) # shared_level: 1 - # share all parameters in type_embedding and repinit - elif shared_level == 1: - self._modules["type_embedding"] = base_class._modules["type_embedding"] - self.repinit.share_params(base_class.repinit, 0, resume=resume) - if self.use_three_body: - self.repinit_three_body.share_params( - base_class.repinit_three_body, 0, resume=resume - ) - # shared_level: 2 - # share all parameters in type_embedding and repformers - elif shared_level == 2: - self._modules["type_embedding"] = base_class._modules["type_embedding"] - self._modules["g1_shape_tranform"] = base_class._modules[ - "g1_shape_tranform" - ] - self.repformers.share_params(base_class.repformers, 0, resume=resume) - # shared_level: 3 # share all parameters in type_embedding - elif shared_level == 3: + elif shared_level == 1: self._modules["type_embedding"] = base_class._modules["type_embedding"] # Other shared levels else: diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 83abf9ee4a..472eae5329 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -514,6 +514,9 @@ def serialize(self) -> dict: def deserialize(cls, data) -> "CM": return cls(atomic_model_=T_AtomicModel.deserialize(data)) + def set_case_embd(self, case_idx: int): + self.atomic_model.set_case_embd(case_idx) + @torch.jit.export def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index c2db53288a..65b64220ae 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -51,6 +51,8 @@ class DipoleFittingNet(GeneralFitting): Number of frame parameters. numb_aparam : int Number of atomic parameters. + dim_case_embd : int + Dimension of case specific embedding. activation_function : str Activation function. precision : str @@ -81,6 +83,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, @@ -103,6 +106,7 @@ def __init__( resnet_dt=resnet_dt, numb_fparam=numb_fparam, numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, activation_function=activation_function, precision=precision, mixed_types=mixed_types, @@ -128,7 +132,7 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = data.copy() - check_version_compatibility(data.pop("@version", 1), 2, 1) + check_version_compatibility(data.pop("@version", 1), 3, 1) data.pop("var_name", None) return super().deserialize(data) diff --git a/deepmd/pt/model/task/dos.py b/deepmd/pt/model/task/dos.py index a71117e587..568ef81c92 100644 --- a/deepmd/pt/model/task/dos.py +++ b/deepmd/pt/model/task/dos.py @@ -47,6 +47,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, rcond: Optional[float] = None, bias_dos: Optional[torch.Tensor] = None, trainable: Union[bool, list[bool]] = True, @@ -73,6 +74,7 @@ def __init__( resnet_dt=resnet_dt, numb_fparam=numb_fparam, numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, activation_function=activation_function, precision=precision, mixed_types=mixed_types, @@ -99,7 +101,7 @@ def output_def(self) -> FittingOutputDef: @classmethod def deserialize(cls, data: dict) -> "DOSFittingNet": data = data.copy() - check_version_compatibility(data.pop("@version", 1), 2, 1) + check_version_compatibility(data.pop("@version", 1), 3, 1) data.pop("@class", None) data.pop("var_name", None) data.pop("tot_ener_zero", None) diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index 543d987e31..07351b33f6 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -50,6 +50,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, @@ -67,6 +68,7 @@ def __init__( resnet_dt=resnet_dt, numb_fparam=numb_fparam, numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, activation_function=activation_function, precision=precision, mixed_types=mixed_types, @@ -78,7 +80,7 @@ def __init__( @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = data.copy() - check_version_compatibility(data.pop("@version", 1), 2, 1) + check_version_compatibility(data.pop("@version", 1), 3, 1) data.pop("var_name") data.pop("dim_out") return super().deserialize(data) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index fb0954979e..2486ab576f 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -64,14 +64,7 @@ def share_params(self, base_class, shared_level, resume=False) -> None: self.__class__ == base_class.__class__ ), "Only fitting nets of the same type can share params!" if shared_level == 0: - # link buffers - if hasattr(self, "bias_atom_e"): - self.bias_atom_e = base_class.bias_atom_e - # the following will successfully link all the params except buffers, which need manually link. - for item in self._modules: - self._modules[item] = base_class._modules[item] - elif shared_level == 1: - # only not share the bias_atom_e + # only not share the bias_atom_e and the case_embd # the following will successfully link all the params except buffers, which need manually link. for item in self._modules: self._modules[item] = base_class._modules[item] @@ -102,6 +95,8 @@ class GeneralFitting(Fitting): Number of frame parameters. numb_aparam : int Number of atomic parameters. + dim_case_embd : int + Dimension of case specific embedding. activation_function : str Activation function. precision : str @@ -139,6 +134,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, @@ -160,6 +156,7 @@ def __init__( self.resnet_dt = resnet_dt self.numb_fparam = numb_fparam self.numb_aparam = numb_aparam + self.dim_case_embd = dim_case_embd self.activation_function = activation_function self.precision = precision self.prec = PRECISION_DICT[self.precision] @@ -211,10 +208,20 @@ def __init__( else: self.aparam_avg, self.aparam_inv_std = None, None + if self.dim_case_embd > 0: + self.register_buffer( + "case_embd", + torch.zeros(self.dim_case_embd, dtype=self.prec, device=device), + # torch.eye(self.dim_case_embd, dtype=self.prec, device=device)[0], + ) + else: + self.case_embd = None + in_dim = ( self.dim_descrpt + self.numb_fparam + (0 if self.use_aparam_as_mask else self.numb_aparam) + + self.dim_case_embd ) self.filter_layers = NetworkCollection( @@ -274,7 +281,7 @@ def serialize(self) -> dict: """Serialize the fitting to dict.""" return { "@class": "Fitting", - "@version": 2, + "@version": 3, "var_name": self.var_name, "ntypes": self.ntypes, "dim_descrpt": self.dim_descrpt, @@ -282,6 +289,7 @@ def serialize(self) -> dict: "resnet_dt": self.resnet_dt, "numb_fparam": self.numb_fparam, "numb_aparam": self.numb_aparam, + "dim_case_embd": self.dim_case_embd, "activation_function": self.activation_function, "precision": self.precision, "mixed_types": self.mixed_types, @@ -290,6 +298,7 @@ def serialize(self) -> dict: "exclude_types": self.exclude_types, "@variables": { "bias_atom_e": to_numpy_array(self.bias_atom_e), + "case_embd": to_numpy_array(self.case_embd), "fparam_avg": to_numpy_array(self.fparam_avg), "fparam_inv_std": to_numpy_array(self.fparam_inv_std), "aparam_avg": to_numpy_array(self.aparam_avg), @@ -349,6 +358,15 @@ def get_type_map(self) -> list[str]: """Get the name to each type of atoms.""" return self.type_map + def set_case_embd(self, case_idx: int): + """ + Set the case embedding of this fitting net by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. + """ + self.case_embd = torch.eye(self.dim_case_embd, dtype=self.prec, device=device)[ + case_idx + ] + def __setitem__(self, key, value) -> None: if key in ["bias_atom_e"]: value = value.view([self.ntypes, self._net_out_dim()]) @@ -361,6 +379,8 @@ def __setitem__(self, key, value) -> None: self.aparam_avg = value elif key in ["aparam_inv_std"]: self.aparam_inv_std = value + elif key in ["case_embd"]: + self.case_embd = value elif key in ["scale"]: self.scale = value else: @@ -377,6 +397,8 @@ def __getitem__(self, key): return self.aparam_avg elif key in ["aparam_inv_std"]: return self.aparam_inv_std + elif key in ["case_embd"]: + return self.case_embd elif key in ["scale"]: return self.scale else: @@ -475,6 +497,19 @@ def _forward_common( dim=-1, ) + if self.dim_case_embd > 0: + assert self.case_embd is not None + case_embd = torch.tile(self.case_embd.reshape([1, 1, -1]), [nf, nloc, 1]) + xx = torch.cat( + [xx, case_embd], + dim=-1, + ) + if xx_zeros is not None: + xx_zeros = torch.cat( + [xx_zeros, case_embd], + dim=-1, + ) + outs = torch.zeros( (nf, nloc, net_dim_out), dtype=self.prec, diff --git a/deepmd/pt/model/task/invar_fitting.py b/deepmd/pt/model/task/invar_fitting.py index 2579f5b9da..b1599eac60 100644 --- a/deepmd/pt/model/task/invar_fitting.py +++ b/deepmd/pt/model/task/invar_fitting.py @@ -56,6 +56,8 @@ class InvarFitting(GeneralFitting): Number of frame parameters. numb_aparam : int Number of atomic parameters. + dim_case_embd : int + Dimension of case specific embedding. activation_function : str Activation function. precision : str @@ -91,6 +93,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, @@ -113,6 +116,7 @@ def __init__( resnet_dt=resnet_dt, numb_fparam=numb_fparam, numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, activation_function=activation_function, precision=precision, mixed_types=mixed_types, @@ -141,7 +145,7 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = data.copy() - check_version_compatibility(data.pop("@version", 1), 2, 1) + check_version_compatibility(data.pop("@version", 1), 3, 1) return super().deserialize(data) def output_def(self) -> FittingOutputDef: diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 8e07896e38..d9a421d635 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -53,6 +53,8 @@ class PolarFittingNet(GeneralFitting): Number of frame parameters. numb_aparam : int Number of atomic parameters. + dim_case_embd : int + Dimension of case specific embedding. activation_function : str Activation function. precision : str @@ -85,6 +87,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, @@ -128,6 +131,7 @@ def __init__( resnet_dt=resnet_dt, numb_fparam=numb_fparam, numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, activation_function=activation_function, precision=precision, mixed_types=mixed_types, @@ -191,7 +195,7 @@ def change_type_map( def serialize(self) -> dict: data = super().serialize() data["type"] = "polar" - data["@version"] = 3 + data["@version"] = 4 data["embedding_width"] = self.embedding_width data["fit_diag"] = self.fit_diag data["shift_diag"] = self.shift_diag @@ -202,7 +206,7 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": data = data.copy() - check_version_compatibility(data.pop("@version", 1), 3, 1) + check_version_compatibility(data.pop("@version", 1), 4, 1) data.pop("var_name", None) return super().deserialize(data) diff --git a/deepmd/pt/model/task/property.py b/deepmd/pt/model/task/property.py index 1c2b9e7c9c..dec0f1447b 100644 --- a/deepmd/pt/model/task/property.py +++ b/deepmd/pt/model/task/property.py @@ -60,6 +60,8 @@ class PropertyFittingNet(InvarFitting): Number of frame parameters. numb_aparam : int Number of atomic parameters. + dim_case_embd : int + Dimension of case specific embedding. activation_function : str Activation function. precision : str @@ -83,6 +85,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, @@ -102,6 +105,7 @@ def __init__( resnet_dt=resnet_dt, numb_fparam=numb_fparam, numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, activation_function=activation_function, precision=precision, mixed_types=mixed_types, @@ -129,7 +133,7 @@ def output_def(self) -> FittingOutputDef: @classmethod def deserialize(cls, data: dict) -> "PropertyFittingNet": data = data.copy() - check_version_compatibility(data.pop("@version", 1), 2, 1) + check_version_compatibility(data.pop("@version", 1), 3, 1) data.pop("dim_out") data.pop("var_name") obj = super().deserialize(data) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index af6e48191d..61683fd857 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -265,7 +265,7 @@ def get_lr(lr_params): self.opt_type, self.opt_param = get_opt_param(training_params) # Model - self.model = get_model_for_wrapper(model_params) + self.model = get_model_for_wrapper(model_params, resuming=resuming) # Loss if not self.multi_task: @@ -1267,7 +1267,7 @@ def get_single_model( return model -def get_model_for_wrapper(_model_params): +def get_model_for_wrapper(_model_params, resuming=False): if "model_dict" not in _model_params: _model = get_single_model( _model_params, @@ -1275,13 +1275,41 @@ def get_model_for_wrapper(_model_params): else: _model = {} model_keys = list(_model_params["model_dict"]) + do_case_embd, case_embd_index = get_case_embd_config(_model_params) for _model_key in model_keys: _model[_model_key] = get_single_model( _model_params["model_dict"][_model_key], ) + if do_case_embd and not resuming: + # only set case_embd when from scratch multitask training + _model[_model_key].set_case_embd(case_embd_index[_model_key]) return _model +def get_case_embd_config(_model_params): + assert ( + "model_dict" in _model_params + ), "Only support setting case embedding for multi-task model!" + model_keys = list(_model_params["model_dict"]) + sorted_model_keys = sorted(model_keys) + numb_case_embd_list = [ + _model_params["model_dict"][model_key] + .get("fitting_net", {}) + .get("dim_case_embd", 0) + for model_key in sorted_model_keys + ] + if not all(item == numb_case_embd_list[0] for item in numb_case_embd_list): + raise ValueError( + f"All models must have the same dimension of case embedding, while the settings are: {numb_case_embd_list}" + ) + if numb_case_embd_list[0] == 0: + return False, {} + case_embd_index = { + model_key: idx for idx, model_key in enumerate(sorted_model_keys) + } + return True, case_embd_index + + def model_change_out_bias( _model, _sample_func, diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index 48119caf19..f0253c283e 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -112,8 +112,10 @@ def share_params(self, shared_links, resume=False) -> None: f"Shared params of {model_key_base}.{class_type_base} and {model_key_link}.{class_type_link}!" ) else: - if hasattr(self.model[model_key_base], class_type_base): - base_class = self.model[model_key_base].__getattr__(class_type_base) + if hasattr(self.model[model_key_base].atomic_model, class_type_base): + base_class = self.model[model_key_base].atomic_model.__getattr__( + class_type_base + ) for link_item in shared_links[shared_item]["links"][1:]: class_type_link = link_item["shared_type"] model_key_link = link_item["model_key"] @@ -124,9 +126,9 @@ def share_params(self, shared_links, resume=False) -> None: assert ( class_type_base == class_type_link ), f"Class type mismatched: {class_type_base} vs {class_type_link}!" - link_class = self.model[model_key_link].__getattr__( - class_type_link - ) + link_class = self.model[ + model_key_link + ].atomic_model.__getattr__(class_type_link) link_class.share_params( base_class, shared_level_link, resume=resume ) diff --git a/deepmd/tf/fit/dipole.py b/deepmd/tf/fit/dipole.py index c05fa4b525..4428d06536 100644 --- a/deepmd/tf/fit/dipole.py +++ b/deepmd/tf/fit/dipole.py @@ -58,6 +58,8 @@ class DipoleFittingSeA(Fitting): Number of frame parameters numb_aparam Number of atomic parameters + dim_case_embd + Dimension of case specific embedding. sel_type : list[int] The atom types selected to have an atomic dipole prediction. If is None, all atoms are selected. seed : int @@ -84,6 +86,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, sel_type: Optional[list[int]] = None, seed: Optional[int] = None, activation_function: str = "tanh", @@ -119,10 +122,13 @@ def __init__( self.type_map = type_map self.numb_fparam = numb_fparam self.numb_aparam = numb_aparam + self.dim_case_embd = dim_case_embd if numb_fparam > 0: raise ValueError("numb_fparam is not supported in the dipole fitting") if numb_aparam > 0: raise ValueError("numb_aparam is not supported in the dipole fitting") + if dim_case_embd > 0: + raise ValueError("dim_case_embd is not supported in TensorFlow.") self.fparam_avg = None self.fparam_std = None self.fparam_inv_std = None @@ -385,7 +391,7 @@ def serialize(self, suffix: str) -> dict: data = { "@class": "Fitting", "type": "dipole", - "@version": 2, + "@version": 3, "ntypes": self.ntypes, "dim_descrpt": self.dim_descrpt, "embedding_width": self.dim_rot_mat_1, @@ -395,6 +401,7 @@ def serialize(self, suffix: str) -> dict: "resnet_dt": self.resnet_dt, "numb_fparam": self.numb_fparam, "numb_aparam": self.numb_aparam, + "dim_case_embd": self.dim_case_embd, "activation_function": self.activation_function_name, "precision": self.fitting_precision.name, "exclude_types": [], @@ -428,7 +435,7 @@ def deserialize(cls, data: dict, suffix: str): The deserialized model """ data = data.copy() - check_version_compatibility(data.pop("@version", 1), 2, 1) + check_version_compatibility(data.pop("@version", 1), 3, 1) fitting = cls(**data) fitting.fitting_net_variables = cls.deserialize_network( data["nets"], diff --git a/deepmd/tf/fit/dos.py b/deepmd/tf/fit/dos.py index 099cba0d12..1da0e55a92 100644 --- a/deepmd/tf/fit/dos.py +++ b/deepmd/tf/fit/dos.py @@ -74,6 +74,8 @@ class DOSFitting(Fitting): Number of frame parameter numb_aparam Number of atomic parameter + dim_case_embd + Dimension of case specific embedding. ! numb_dos (added) Number of gridpoints on which the DOS is evaluated (NEDOS in VASP) rcond @@ -111,6 +113,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, numb_dos: int = 300, rcond: Optional[float] = None, trainable: Optional[list[bool]] = None, @@ -132,6 +135,9 @@ def __init__( self.numb_fparam = numb_fparam self.numb_aparam = numb_aparam + self.dim_case_embd = dim_case_embd + if dim_case_embd > 0: + raise ValueError("dim_case_embd is not supported in TensorFlow.") self.numb_dos = numb_dos @@ -672,7 +678,7 @@ def deserialize(cls, data: dict, suffix: str = ""): The deserialized model """ data = data.copy() - check_version_compatibility(data.pop("@version", 1), 2, 1) + check_version_compatibility(data.pop("@version", 1), 3, 1) data["numb_dos"] = data.pop("dim_out") fitting = cls(**data) fitting.fitting_net_variables = cls.deserialize_network( @@ -699,7 +705,7 @@ def serialize(self, suffix: str = "") -> dict: data = { "@class": "Fitting", "type": "dos", - "@version": 2, + "@version": 3, "var_name": "dos", "ntypes": self.ntypes, "dim_descrpt": self.dim_descrpt, @@ -709,6 +715,7 @@ def serialize(self, suffix: str = "") -> dict: "resnet_dt": self.resnet_dt, "numb_fparam": self.numb_fparam, "numb_aparam": self.numb_aparam, + "dim_case_embd": self.dim_case_embd, "rcond": self.rcond, "trainable": self.trainable, "activation_function": self.activation_function, @@ -731,6 +738,7 @@ def serialize(self, suffix: str = "") -> dict: "fparam_inv_std": self.fparam_inv_std, "aparam_avg": self.aparam_avg, "aparam_inv_std": self.aparam_inv_std, + "case_embd": None, }, "type_map": self.type_map, } diff --git a/deepmd/tf/fit/ener.py b/deepmd/tf/fit/ener.py index 7a3ee8eade..068d3d8e35 100644 --- a/deepmd/tf/fit/ener.py +++ b/deepmd/tf/fit/ener.py @@ -117,6 +117,8 @@ class EnerFitting(Fitting): Number of frame parameter numb_aparam Number of atomic parameter + dim_case_embd + Dimension of case specific embedding. rcond The condition number for the regression of atomic energy. tot_ener_zero @@ -156,6 +158,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, rcond: Optional[float] = None, tot_ener_zero: bool = False, trainable: Optional[list[bool]] = None, @@ -190,6 +193,9 @@ def __init__( # .add("trainable", [list, bool], default = True) self.numb_fparam = numb_fparam self.numb_aparam = numb_aparam + self.dim_case_embd = dim_case_embd + if dim_case_embd > 0: + raise ValueError("dim_case_embd is not supported in TensorFlow.") self.n_neuron = neuron self.resnet_dt = resnet_dt self.rcond = rcond @@ -878,7 +884,7 @@ def deserialize(cls, data: dict, suffix: str = ""): The deserialized model """ data = data.copy() - check_version_compatibility(data.pop("@version", 1), 2, 1) + check_version_compatibility(data.pop("@version", 1), 3, 1) fitting = cls(**data) fitting.fitting_net_variables = cls.deserialize_network( data["nets"], @@ -904,7 +910,7 @@ def serialize(self, suffix: str = "") -> dict: data = { "@class": "Fitting", "type": "ener", - "@version": 2, + "@version": 3, "var_name": "energy", "ntypes": self.ntypes, "dim_descrpt": self.dim_descrpt + self.tebd_dim, @@ -914,6 +920,7 @@ def serialize(self, suffix: str = "") -> dict: "resnet_dt": self.resnet_dt, "numb_fparam": self.numb_fparam, "numb_aparam": self.numb_aparam, + "dim_case_embd": self.dim_case_embd, "rcond": self.rcond, "tot_ener_zero": self.tot_ener_zero, "trainable": self.trainable, @@ -945,6 +952,7 @@ def serialize(self, suffix: str = "") -> dict: "fparam_inv_std": self.fparam_inv_std, "aparam_avg": self.aparam_avg, "aparam_inv_std": self.aparam_inv_std, + "case_embd": None, }, "type_map": self.type_map, } diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index 2f1400e697..14dd6ee092 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -63,6 +63,8 @@ class PolarFittingSeA(Fitting): Number of frame parameters numb_aparam Number of atomic parameters + dim_case_embd + Dimension of case specific embedding. sel_type : list[int] The atom types selected to have an atomic polarizability prediction. If is None, all atoms are selected. fit_diag : bool @@ -95,6 +97,7 @@ def __init__( resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, + dim_case_embd: int = 0, sel_type: Optional[list[int]] = None, fit_diag: bool = True, scale: Optional[list[float]] = None, @@ -162,10 +165,13 @@ def __init__( self.type_map = type_map self.numb_fparam = numb_fparam self.numb_aparam = numb_aparam + self.dim_case_embd = dim_case_embd if numb_fparam > 0: raise ValueError("numb_fparam is not supported in the dipole fitting") if numb_aparam > 0: raise ValueError("numb_aparam is not supported in the dipole fitting") + if dim_case_embd > 0: + raise ValueError("dim_case_embd is not supported in TensorFlow.") self.fparam_avg = None self.fparam_std = None self.fparam_inv_std = None @@ -578,7 +584,7 @@ def serialize(self, suffix: str) -> dict: data = { "@class": "Fitting", "type": "polar", - "@version": 3, + "@version": 4, "ntypes": self.ntypes, "dim_descrpt": self.dim_descrpt, "embedding_width": self.dim_rot_mat_1, @@ -588,6 +594,7 @@ def serialize(self, suffix: str) -> dict: "resnet_dt": self.resnet_dt, "numb_fparam": self.numb_fparam, "numb_aparam": self.numb_aparam, + "dim_case_embd": self.dim_case_embd, "activation_function": self.activation_function_name, "precision": self.fitting_precision.name, "exclude_types": [], @@ -625,7 +632,7 @@ def deserialize(cls, data: dict, suffix: str): """ data = data.copy() check_version_compatibility( - data.pop("@version", 1), 3, 1 + data.pop("@version", 1), 4, 1 ) # to allow PT version. fitting = cls(**data) fitting.fitting_net_variables = cls.deserialize_network( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index d5419a38cd..5b57f15979 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1433,6 +1433,7 @@ def descrpt_variant_type_args(exclude_hybrid: bool = False) -> Variant: def fitting_ener(): doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams." + doc_dim_case_embd = "The dimension of the case embedding embedding. When training or fine-tuning a multitask model with case embedding embeddings, this number should be set to the number of model branches." doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built." doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.' doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." @@ -1459,6 +1460,13 @@ def fitting_ener(): return [ Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam), Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam), + Argument( + "dim_case_embd", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_dim_case_embd, + ), Argument( "neuron", list[int], @@ -1509,6 +1517,7 @@ def fitting_ener(): def fitting_dos(): doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams." + doc_dim_case_embd = "The dimension of the case embedding embedding. When training or fine-tuning a multitask model with case embedding embeddings, this number should be set to the number of model branches." doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built." doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.' doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." @@ -1525,6 +1534,13 @@ def fitting_dos(): return [ Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam), Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam), + Argument( + "dim_case_embd", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_dim_case_embd, + ), Argument( "neuron", list[int], optional=True, default=[120, 120, 120], doc=doc_neuron ), @@ -1556,6 +1572,7 @@ def fitting_dos(): def fitting_property(): doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams." + doc_dim_case_embd = "The dimension of the case embedding embedding. When training or fine-tuning a multitask model with case embedding embeddings, this number should be set to the number of model branches." doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built" doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.' doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection' @@ -1567,6 +1584,13 @@ def fitting_property(): return [ Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam), Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam), + Argument( + "dim_case_embd", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_dim_case_embd, + ), Argument( "neuron", list[int], @@ -1597,6 +1621,7 @@ def fitting_property(): def fitting_polar(): doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams." + doc_dim_case_embd = "The dimension of the case embedding embedding. When training or fine-tuning a multitask model with case embedding embeddings, this number should be set to the number of model branches." doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built." doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.' doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection' @@ -1625,6 +1650,13 @@ def fitting_polar(): default=0, doc=doc_only_pt_supported + doc_numb_aparam, ), + Argument( + "dim_case_embd", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_dim_case_embd, + ), Argument( "neuron", list[int], @@ -1667,6 +1699,7 @@ def fitting_polar(): def fitting_dipole(): doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams." + doc_dim_case_embd = "The dimension of the case embedding embedding. When training or fine-tuning a multitask model with case embedding embeddings, this number should be set to the number of model branches." doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built." doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.' doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection' @@ -1688,6 +1721,13 @@ def fitting_dipole(): default=0, doc=doc_only_pt_supported + doc_numb_aparam, ), + Argument( + "dim_case_embd", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_dim_case_embd, + ), Argument( "neuron", list[int], diff --git a/doc/train/multi-task-training.md b/doc/train/multi-task-training.md index 9d5b71592e..51dffcc5f5 100644 --- a/doc/train/multi-task-training.md +++ b/doc/train/multi-task-training.md @@ -48,14 +48,27 @@ Specifically, there are several parts that need to be modified: - {ref}`model/model_dict `: The core definition of the model part and the explanation of sharing rules, starting with user-defined model name keys `model_key`, such as `my_model_1`. Each model part needs to align with the components of the single-task training {ref}`model `, but with the following sharing rules: -- - If you want to share the current model component with other tasks, which should be part of the {ref}`model/shared_dict `, + + - If you want to share the current model component with other tasks, which should be part of the {ref}`model/shared_dict `, you can directly fill in the corresponding `part_key`, such as `"descriptor": "my_descriptor", ` to replace the previous detailed parameters. Here, you can also specify the shared_level, such as `"descriptor": "my_descriptor:shared_level", ` - and use the user-defined integer `shared_level` in the code to share the corresponding module to varying degrees - (default is to share all parameters, i.e., `shared_level`=0). - The parts that are exclusive to each model can be written following the previous definition. + and use the user-defined integer `shared_level` in the code to share the corresponding module to varying degrees. + - For descriptors, `shared_level` can be set as follows: + - Valid `shared_level` values are 0-1, depending on the descriptor type + - Each level enables different sharing behaviors: + - Level 0: Shares all parameters (default) + - Level 1: Shares type embedding only + - Not all descriptors support all levels (e.g., se_a only supports level 0) + - For fitting nets, we only support the default `shared_level`=0, where all parameters will be shared except for `bias_atom_e` and `case_embd`. + - To conduct multitask training, there are two typical approaches: + 1. **Descriptor sharing only**: Share the descriptor with `shared_level`=0. See [here](../../examples/water_multi_task/pytorch_example/input_torch.json) for an example. + 2. **Descriptor and fitting network sharing with data identification**: + - Share the descriptor and the fitting network with `shared_level`=0. + - {ref}`dim_case_embd ` must be set to the number of model branches, which will distinguish different data tasks using a one-hot embedding. + - See [here](../../examples/water_multi_task/pytorch_example/input_torch_sharefit.json) for an example. + - The parts that are exclusive to each model can be written following the previous definition. - {ref}`loss_dict `: The loss settings corresponding to each task model, specified by the `model_key`. Each {ref}`loss_dict/model_key ` contains the corresponding loss settings, diff --git a/examples/water_multi_task/pytorch_example/input_torch_sharefit.json b/examples/water_multi_task/pytorch_example/input_torch_sharefit.json new file mode 100644 index 0000000000..2fc23007c6 --- /dev/null +++ b/examples/water_multi_task/pytorch_example/input_torch_sharefit.json @@ -0,0 +1,155 @@ +{ + "_comment": "that's all", + "model": { + "shared_dict": { + "type_map_all": [ + "O", + "H" + ], + "dpa2_descriptor": { + "type": "dpa2", + "repinit": { + "tebd_dim": 8, + "rcut": 6.0, + "rcut_smth": 0.5, + "nsel": 120, + "neuron": [ + 25, + 50, + 100 + ], + "axis_neuron": 12, + "activation_function": "tanh", + "three_body_sel": 48, + "three_body_rcut": 4.0, + "three_body_rcut_smth": 3.5, + "use_three_body": true + }, + "repformer": { + "rcut": 4.0, + "rcut_smth": 3.5, + "nsel": 48, + "nlayers": 6, + "g1_dim": 128, + "g2_dim": 32, + "attn2_hidden": 32, + "attn2_nhead": 4, + "attn1_hidden": 128, + "attn1_nhead": 4, + "axis_neuron": 4, + "update_h2": false, + "update_g1_has_conv": true, + "update_g1_has_grrg": true, + "update_g1_has_drrd": true, + "update_g1_has_attn": false, + "update_g2_has_g1g1": false, + "update_g2_has_attn": true, + "update_style": "res_residual", + "update_residual": 0.01, + "update_residual_init": "norm", + "attn2_has_gate": true, + "use_sqrt_nnei": true, + "g1_out_conv": true, + "g1_out_mlp": true + }, + "precision": "float64", + "add_tebd_to_repinit_out": false, + "_comment": " that's all" + }, + "shared_fit_with_id": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "dim_case_embd": 2, + "_comment": " that's all" + }, + "_comment": "that's all" + }, + "model_dict": { + "water_1": { + "type_map": "type_map_all", + "descriptor": "dpa2_descriptor", + "fitting_net": "shared_fit_with_id" + }, + "water_2": { + "type_map": "type_map_all", + "descriptor": "dpa2_descriptor", + "fitting_net": "shared_fit_with_id" + } + } + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss_dict": { + "water_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0 + }, + "water_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0 + } + }, + "training": { + "model_prob": { + "water_1": 0.5, + "water_2": 0.5 + }, + "data_dict": { + "water_1": { + "training_data": { + "systems": [ + "../../water/data/data_0/", + "../../water/data/data_1/", + "../../water/data/data_2/" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../../water/data/data_3/" + ], + "batch_size": 1, + "_comment": "that's all" + } + }, + "water_2": { + "training_data": { + "systems": [ + "../../water/data/data_0/", + "../../water/data/data_1/", + "../../water/data/data_2/" + ], + "batch_size": 1, + "_comment": "that's all" + } + } + }, + "numb_steps": 100000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 100, + "_comment": "that's all" + } +} diff --git a/source/tests/common/test_examples.py b/source/tests/common/test_examples.py index 068a91709c..1ddbb50db9 100644 --- a/source/tests/common/test_examples.py +++ b/source/tests/common/test_examples.py @@ -64,6 +64,7 @@ input_files_multi = ( p_examples / "water_multi_task" / "pytorch_example" / "input_torch.json", + p_examples / "water_multi_task" / "pytorch_example" / "input_torch_sharefit.json", ) diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index cb4dbed391..a08e849c6c 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -101,7 +101,7 @@ class CommonTest(ABC): # we may usually skip jax before jax is fully supported skip_jax: ClassVar[bool] = True """Whether to skip the JAX model.""" - skip_pd: ClassVar[bool] = not INSTALLED_PD + skip_pd: ClassVar[bool] = True """Whether to skip the Paddle model.""" skip_array_api_strict: ClassVar[bool] = True """Whether to skip the array_api_strict model.""" @@ -185,7 +185,6 @@ def eval_jax(self, jax_obj: Any) -> Any: """ raise NotImplementedError("Not implemented") - @abstractmethod def eval_pd(self, pd_obj: Any) -> Any: """Evaluate the return value of PD. @@ -194,6 +193,7 @@ def eval_pd(self, pd_obj: Any) -> Any: pd_obj : Any The object of PD """ + raise NotImplementedError("Not implemented") def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: """Evaluate the return value of array_api_strict. diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index a463960fb7..8838696108 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -136,7 +136,7 @@ def skip_pd(self) -> bool: precision, env_protection, ) = self.param - return CommonTest.skip_pd + return not INSTALLED_PD @property def skip_array_api_strict(self) -> bool: diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index 12fafa7ba8..f5a79acabe 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -135,7 +135,7 @@ def skip_pd(self) -> bool: ) = self.param # Paddle do not support "bfloat16" in some kernels, # so skip this in CI test - return CommonTest.skip_pd or precision == "bfloat16" + return not INSTALLED_PD or precision == "bfloat16" tf_class = EnerFittingTF dp_class = EnerFittingDP diff --git a/source/tests/pt/model/water/multitask.json b/source/tests/pt/model/water/multitask.json index 06a4f88e55..e8d998e6f1 100644 --- a/source/tests/pt/model/water/multitask.json +++ b/source/tests/pt/model/water/multitask.json @@ -10,7 +10,8 @@ "type": "se_e2_a", "sel": [ 46, - 92 + 92, + 4 ], "rcut_smth": 0.50, "rcut": 6.00, diff --git a/source/tests/pt/model/water/multitask_sharefit.json b/source/tests/pt/model/water/multitask_sharefit.json new file mode 100644 index 0000000000..246b5992f7 --- /dev/null +++ b/source/tests/pt/model/water/multitask_sharefit.json @@ -0,0 +1,134 @@ +{ + "model": { + "shared_dict": { + "my_type_map": [ + "O", + "H", + "B" + ], + "my_descriptor": { + "type": "se_e2_a", + "sel": [ + 46, + 92, + 4 + ], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "seed": 1, + "_comment": " that's all" + }, + "my_fitting": { + "dim_case_embd": 2, + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + }, + "_comment": "that's all" + }, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": "my_fitting", + "data_stat_nbatch": 1 + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": "my_fitting", + "data_stat_nbatch": 1 + } + } + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.0002, + "decay_rate": 0.98, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0 + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0 + } + }, + "training": { + "model_prob": { + "model_1": 0.5, + "model_2": 0.5 + }, + "data_dict": { + "model_1": { + "stat_file": "./stat_files/model_1.hdf5", + "training_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "_comment": "that's all" + } + }, + "model_2": { + "stat_file": "./stat_files/model_2.hdf5", + "training_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "_comment": "that's all" + } + } + }, + "numb_steps": 100000, + "warmup_steps": 0, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 100, + "_comment": "that's all" + } +} diff --git a/source/tests/pt/test_multitask.py b/source/tests/pt/test_multitask.py index a59d6f8e54..62964abad3 100644 --- a/source/tests/pt/test_multitask.py +++ b/source/tests/pt/test_multitask.py @@ -42,12 +42,20 @@ def setUpModule() -> None: with open(multitask_template_json) as f: multitask_template = json.load(f) + global multitask_sharefit_template + multitask_sharefit_template_json = str( + Path(__file__).parent / "water/multitask_sharefit.json" + ) + with open(multitask_sharefit_template_json) as f: + multitask_sharefit_template = json.load(f) + class MultiTaskTrainTest: def test_multitask_train(self) -> None: # test multitask training self.config = update_deepmd_input(self.config, warning=True) self.config = normalize(self.config, multi_task=True) + self.share_fitting = getattr(self, "share_fitting", False) trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links) trainer.run() # check model keys @@ -62,7 +70,12 @@ def test_multitask_train(self) -> None: self.assertIn(state_key.replace("model_1", "model_2"), multi_state_dict) if "model_2" in state_key: self.assertIn(state_key.replace("model_2", "model_1"), multi_state_dict) - if "model_1.descriptor" in state_key: + if ("model_1.atomic_model.descriptor" in state_key) or ( + self.share_fitting + and "model_1.atomic_model.fitting_net" in state_key + and "fitting_net.bias_atom_e" not in state_key + and "fitting_net.case_embd" not in state_key + ): torch.testing.assert_close( multi_state_dict[state_key], multi_state_dict[state_key.replace("model_1", "model_2")], @@ -223,6 +236,46 @@ def tearDown(self) -> None: MultiTaskTrainTest.tearDown(self) +class TestMultiTaskSeASharefit(unittest.TestCase, MultiTaskTrainTest): + def setUp(self) -> None: + multitask_se_e2_a = deepcopy(multitask_sharefit_template) + multitask_se_e2_a["model"]["shared_dict"]["my_descriptor"] = model_se_e2_a[ + "descriptor" + ] + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.stat_files = "se_e2_a_share_fit" + os.makedirs(self.stat_files, exist_ok=True) + self.config = multitask_se_e2_a + self.config["training"]["data_dict"]["model_1"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_1"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_1"]["stat_file"] = ( + f"{self.stat_files}/model_1" + ) + self.config["training"]["data_dict"]["model_2"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_2"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_2"]["stat_file"] = ( + f"{self.stat_files}/model_2" + ) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.origin_config = deepcopy(self.config) + self.config["model"], self.shared_links = preprocess_shared_params( + self.config["model"] + ) + self.share_fitting = True + + def tearDown(self) -> None: + MultiTaskTrainTest.tearDown(self) + + class TestMultiTaskDPA1(unittest.TestCase, MultiTaskTrainTest): def setUp(self) -> None: multitask_DPA1 = deepcopy(multitask_template) diff --git a/source/tests/universal/dpmodel/fitting/test_fitting.py b/source/tests/universal/dpmodel/fitting/test_fitting.py index fe6ffd2e09..db199c02a3 100644 --- a/source/tests/universal/dpmodel/fitting/test_fitting.py +++ b/source/tests/universal/dpmodel/fitting/test_fitting.py @@ -39,7 +39,7 @@ def FittingParamEnergy( exclude_types=[], precision="float64", embedding_width=None, - numb_param=0, # test numb_fparam and numb_aparam together + numb_param=0, # test numb_fparam, numb_aparam and dim_case_embd together ): input_dict = { "ntypes": ntypes, @@ -51,6 +51,7 @@ def FittingParamEnergy( "precision": precision, "numb_fparam": numb_param, "numb_aparam": numb_param, + "dim_case_embd": numb_param, } return input_dict @@ -77,7 +78,7 @@ def FittingParamDos( exclude_types=[], precision="float64", embedding_width=None, - numb_param=0, # test numb_fparam and numb_aparam together + numb_param=0, # test numb_fparam, numb_aparam and dim_case_embd together ): input_dict = { "ntypes": ntypes, @@ -89,6 +90,7 @@ def FittingParamDos( "precision": precision, "numb_fparam": numb_param, "numb_aparam": numb_param, + "dim_case_embd": numb_param, } return input_dict @@ -115,7 +117,7 @@ def FittingParamDipole( exclude_types=[], precision="float64", embedding_width=None, - numb_param=0, # test numb_fparam and numb_aparam together + numb_param=0, # test numb_fparam, numb_aparam and dim_case_embd together ): assert ( embedding_width is not None @@ -131,6 +133,7 @@ def FittingParamDipole( "precision": precision, "numb_fparam": numb_param, "numb_aparam": numb_param, + "dim_case_embd": numb_param, } return input_dict @@ -157,7 +160,7 @@ def FittingParamPolar( exclude_types=[], precision="float64", embedding_width=None, - numb_param=0, # test numb_fparam and numb_aparam together + numb_param=0, # test numb_fparam, numb_aparam and dim_case_embd together ): assert embedding_width is not None, "embedding_width for polar fitting is required." input_dict = { @@ -171,6 +174,7 @@ def FittingParamPolar( "precision": precision, "numb_fparam": numb_param, "numb_aparam": numb_param, + "dim_case_embd": numb_param, } return input_dict @@ -197,7 +201,7 @@ def FittingParamProperty( exclude_types=[], precision="float64", embedding_width=None, - numb_param=0, # test numb_fparam and numb_aparam together + numb_param=0, # test numb_fparam, numb_aparam and dim_case_embd together ): input_dict = { "ntypes": ntypes, @@ -209,6 +213,7 @@ def FittingParamProperty( "precision": precision, "numb_fparam": numb_param, "numb_aparam": numb_param, + "dim_case_embd": numb_param, } return input_dict