feat(pt): consistent fine-tuning with init-model #3803

Merged: 38 commits, merged on Jun 13, 2024

Commits (38 total; the diff below shows changes from 9 commits):
a08ee08  feat(pt): consistent fine-tuning with init-model (iProzd, May 22, 2024)
da92310  Merge branch 'devel' into rf_finetune (iProzd, May 22, 2024)
70531cc  FIx uts (iProzd, May 22, 2024)
591de3e  Update test_finetune.py (iProzd, May 22, 2024)
7c909cb  Update test_finetune.py (iProzd, May 22, 2024)
21b77d6  Merge branch 'devel' into rf_finetune (iProzd, May 22, 2024)
5850a2f  Merge branch 'devel' into rf_finetune (iProzd, May 23, 2024)
bc8bdf8  Merge branch 'devel' into rf_finetune (iProzd, May 29, 2024)
297b5d6  Merge branch 'devel' into rf_finetune (iProzd, May 30, 2024)
638c369  Merge branch 'devel' into rf_finetune (iProzd, May 31, 2024)
a67ef2c  Update slim type (iProzd, Jun 3, 2024)
8270305  Merge branch 'devel' into rf_finetune (iProzd, Jun 3, 2024)
915707b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 3, 2024)
951fb1e  Update main.py (iProzd, Jun 3, 2024)
c03e90a  rm extra doc (iProzd, Jun 3, 2024)
aa08e30  update uts (iProzd, Jun 3, 2024)
ddaa38d  fix uts (iProzd, Jun 3, 2024)
be1a18e  mv change_energy_bias_lower to tf (iProzd, Jun 3, 2024)
a394d97  resolve conversation (iProzd, Jun 4, 2024)
4d09586  update version (iProzd, Jun 4, 2024)
dfbf01f  Update test_cuda.yml (iProzd, Jun 4, 2024)
f5ee0ab  Revert "Update test_cuda.yml" (iProzd, Jun 4, 2024)
5664240  Merge branch 'devel' into rf_finetune (iProzd, Jun 4, 2024)
9f1d473  Merge branch 'devel' into rf_finetune (iProzd, Jun 6, 2024)
12788ab  Update deepmd/dpmodel/atomic_model/base_atomic_model.py (iProzd, Jun 6, 2024)
25909aa  Merge branch 'devel' into rf_finetune (iProzd, Jun 6, 2024)
fbe8396  Add uts for slim_type_map (iProzd, Jun 6, 2024)
7316f32  Merge branch 'devel' into rf_finetune (iProzd, Jun 7, 2024)
7c30b47  support extend type map in finetune (iProzd, Jun 8, 2024)
ab04399  Merge branch 'devel' into rf_finetune (iProzd, Jun 8, 2024)
4599213  resolve conversations (iProzd, Jun 10, 2024)
bf20853  add doc for use-pretrain-script in tf (iProzd, Jun 10, 2024)
af6c8b2  fix tebd (iProzd, Jun 10, 2024)
ad838c4  add ut for extend stat (iProzd, Jun 11, 2024)
9fac36e  Merge branch 'devel' into rf_finetune (iProzd, Jun 11, 2024)
911b043  Update deepmd/main.py (iProzd, Jun 11, 2024)
c0d57e9  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 11, 2024)
fd64ee5  Update deepmd/utils/finetune.py (iProzd, Jun 11, 2024)

9 changes: 9 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
@@ -113,6 +113,15 @@
]
)

def update_type_params(
self,
state_dict: Dict[str, np.ndarray],
mapping_index: List[int],
prefix: str = "",
) -> Dict[str, np.ndarray]:
"""Update the type related params when loading from pretrained model with redundant types."""
raise NotImplementedError

def forward_common_atomic(
self,
extended_coord: np.ndarray,
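
The update_type_params hook added above is a stub that concrete atomic models, descriptors, and fittings are expected to override. A minimal illustration of the intended behaviour, assuming the type-dependent parameters are arrays whose leading axis runs over atom types (the helper name below is hypothetical, not part of this PR):

from typing import Dict, List

import numpy as np


def slim_type_params(
    state_dict: Dict[str, np.ndarray],
    mapping_index: List[int],
    prefix: str = "",
) -> Dict[str, np.ndarray]:
    """Illustrative only: for each type-indexed array stored under `prefix`,
    keep only the rows selected by `mapping_index`, i.e. drop the types the
    pretrained model knows about but the fine-tuning data does not use."""
    out = dict(state_dict)
    idx = np.asarray(mapping_index, dtype=int)
    for key, value in state_dict.items():
        if key.startswith(prefix):
            # e.g. a per-type bias of shape (ntypes_pretrained, ...) becomes
            # shape (len(mapping_index), ...)
            out[key] = value[idx]
    return out
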
10 changes: 10 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
@@ -136,6 +136,16 @@
def deserialize(cls, data: dict):
pass

@abstractmethod
def update_type_params(
self,
state_dict: Dict[str, t_tensor],
mapping_index: List[int],
prefix: str = "",
) -> Dict[str, t_tensor]:
"""Update the type related params when loading from pretrained model with redundant types."""
pass

def make_atom_mask(
self,
atype: t_tensor,
10 changes: 10 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
@@ -29,6 +29,7 @@
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
@@ -364,6 +365,15 @@
"""
raise NotImplementedError

def update_type_params(
self,
state_dict: Dict[str, np.ndarray],
mapping_index: List[int],
prefix: str = "",
) -> Dict[str, np.ndarray]:
"""Update the type related params when loading from pretrained model with redundant types."""
raise NotImplementedError

@property
def dim_out(self):
return self.get_dim_out()
10 changes: 10 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
@@ -28,6 +28,7 @@
__version__ = "unknown"

from typing import (
Dict,
List,
Optional,
Tuple,
@@ -539,6 +540,15 @@
"""
raise NotImplementedError

def update_type_params(
self,
state_dict: Dict[str, np.ndarray],
mapping_index: List[int],
prefix: str = "",
) -> Dict[str, np.ndarray]:
"""Update the type related params when loading from pretrained model with redundant types."""
raise NotImplementedError

@property
def dim_out(self):
return self.get_dim_out()
9 changes: 9 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
@@ -152,6 +152,15 @@
"""
raise NotImplementedError

def update_type_params(
self,
state_dict: Dict[str, np.ndarray],
mapping_index: List[int],
prefix: str = "",
) -> Dict[str, np.ndarray]:
"""Update the type related params when loading from pretrained model with redundant types."""
raise NotImplementedError

def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
"""Update mean and stddev for descriptor elements."""
for descrpt in self.descrpt_list:
11 changes: 11 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
@@ -5,6 +5,7 @@
)
from typing import (
Callable,
Dict,
List,
Optional,
Union,
@@ -105,6 +106,16 @@
"""
pass

@abstractmethod
def update_type_params(
self,
state_dict: Dict[str, t_tensor],
mapping_index: List[int],
prefix: str = "",
) -> Dict[str, t_tensor]:
"""Update the type related params when loading from pretrained model with redundant types."""
pass

def compute_input_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
10 changes: 10 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
@@ -24,6 +24,7 @@
import copy
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
@@ -263,6 +264,15 @@
"""
raise NotImplementedError

def update_type_params(
self,
state_dict: Dict[str, np.ndarray],
mapping_index: List[int],
prefix: str = "",
) -> Dict[str, np.ndarray]:
"""Update the type related params when loading from pretrained model with redundant types."""
raise NotImplementedError

def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.ntypes
10 changes: 10 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
@@ -19,6 +19,7 @@
import copy
from typing import (
Any,
Dict,
List,
Optional,
)
@@ -219,6 +220,15 @@
"""
raise NotImplementedError

def update_type_params(
self,
state_dict: Dict[str, np.ndarray],
mapping_index: List[int],
prefix: str = "",
) -> Dict[str, np.ndarray]:
"""Update the type related params when loading from pretrained model with redundant types."""
raise NotImplementedError

def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.ntypes
9 changes: 9 additions & 0 deletions deepmd/dpmodel/fitting/dipole_fitting.py
@@ -175,6 +175,15 @@
]
)

def update_type_params(
self,
state_dict: Dict[str, np.ndarray],
mapping_index: List[int],
prefix: str = "",
) -> Dict[str, np.ndarray]:
"""Update the type related params when loading from pretrained model with redundant types."""
raise NotImplementedError

def call(
self,
descriptor: np.ndarray,
9 changes: 9 additions & 0 deletions deepmd/dpmodel/fitting/general_fitting.py
@@ -270,6 +270,15 @@
obj.nets = NetworkCollection.deserialize(nets)
return obj

def update_type_params(
self,
state_dict: Dict[str, np.ndarray],
mapping_index: List[int],
prefix: str = "",
) -> Dict[str, np.ndarray]:
"""Update the type related params when loading from pretrained model with redundant types."""
raise NotImplementedError

def _call_common(
self,
descriptor: np.ndarray,
11 changes: 11 additions & 0 deletions deepmd/dpmodel/fitting/make_base_fitting.py
@@ -5,6 +5,7 @@
)
from typing import (
Dict,
List,
Optional,
)

@@ -63,6 +64,16 @@
"""Calculate fitting."""
pass

@abstractmethod
def update_type_params(
self,
state_dict: Dict[str, t_tensor],
mapping_index: List[int],
prefix: str = "",
) -> Dict[str, t_tensor]:
"""Update the type related params when loading from pretrained model with redundant types."""
pass

def compute_output_stats(self, merged):
"""Update the output bias for fitting net."""
raise NotImplementedError
9 changes: 9 additions & 0 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
@@ -215,6 +215,15 @@
]
)

def update_type_params(
self,
state_dict: Dict[str, np.ndarray],
mapping_index: List[int],
prefix: str = "",
) -> Dict[str, np.ndarray]:
"""Update the type related params when loading from pretrained model with redundant types."""
raise NotImplementedError

def call(
self,
descriptor: np.ndarray,
5 changes: 5 additions & 0 deletions deepmd/main.py
@@ -255,6 +255,11 @@ def main_parser() -> argparse.ArgumentParser:
default=None,
help="Finetune the frozen pretrained model.",
)
parser_train.add_argument(
"--use-pretrain-script",
action="store_true",
help="(Supported Backend: PyTorch) Use model params in the script of the pretrained model instead of user input.",
)
parser_train.add_argument(
"-o",
"--output",
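
A usage sketch for the new option (the backend selector and file names are illustrative assumptions, not taken from this diff): combined with --finetune, it makes training read the model section from the pretrained checkpoint instead of from the input file.

dp --pt train input.json --finetune pretrained.pt --use-pretrain-script
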
53 changes: 19 additions & 34 deletions deepmd/pt/entrypoints/main.py
@@ -54,7 +54,7 @@
DEVICE,
)
from deepmd.pt.utils.finetune import (
change_finetune_model_params,
get_finetune_rules,
)
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
@@ -81,10 +81,10 @@ def get_trainer(
init_model=None,
restart_model=None,
finetune_model=None,
model_branch="",
force_load=False,
init_frz_model=None,
shared_links=None,
finetune_links=None,
):
multi_task = "model_dict" in config.get("model", {})

@@ -95,23 +95,8 @@
assert dist.is_nccl_available()
dist.init_process_group(backend="nccl")

ckpt = init_model if init_model is not None else restart_model
finetune_links = None
if finetune_model is not None:
config["model"], finetune_links = change_finetune_model_params(
finetune_model,
config["model"],
model_branch=model_branch,
)
config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)

def prepare_trainer_input_single(
model_params_single, data_dict_single, loss_dict_single, suffix="", rank=0
):
def prepare_trainer_input_single(model_params_single, data_dict_single, rank=0):
training_dataset_params = data_dict_single["training_data"]
type_split = False
if model_params_single["descriptor"]["type"] in ["se_e2_a"]:
type_split = True
validation_dataset_params = data_dict_single.get("validation_data", None)
validation_systems = (
validation_dataset_params["systems"] if validation_dataset_params else None
@@ -144,18 +129,11 @@ def prepare_trainer_input_single(
if validation_systems
else None
)
if ckpt or finetune_model:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single["type_map"],
)
else:
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single["type_map"],
)
train_data_single = DpLoaderSet(
training_systems,
training_dataset_params["batch_size"],
model_params_single["type_map"],
)
return (
train_data_single,
validation_data_single,
@@ -171,7 +149,6 @@ def prepare_trainer_input_single(
) = prepare_trainer_input_single(
config["model"],
config["training"],
config["loss"],
rank=rank,
)
else:
@@ -184,8 +161,6 @@ def prepare_trainer_input_single(
) = prepare_trainer_input_single(
config["model"]["model_dict"][model_key],
config["training"]["data_dict"][model_key],
config["loss_dict"][model_key],
suffix=f"_{model_key}",
rank=rank,
)

@@ -245,6 +220,16 @@ def train(FLAGS):
if multi_task:
config["model"], shared_links = preprocess_shared_params(config["model"])

# update fine-tuning config
finetune_links = None
if FLAGS.finetune is not None:
config["model"], finetune_links = get_finetune_rules(
FLAGS.finetune,
config["model"],
model_branch=FLAGS.model_branch,
change_model_params=FLAGS.use_pretrain_script,
)

# argcheck
if not multi_task:
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
@@ -279,10 +264,10 @@ def train(FLAGS):
FLAGS.init_model,
FLAGS.restart,
FLAGS.finetune,
FLAGS.model_branch,
FLAGS.force_load,
FLAGS.init_frz_model,
shared_links=shared_links,
finetune_links=finetune_links,
)
trainer.run()

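
The rules returned by get_finetune_rules hinge on mapping the fine-tuning type map onto the pretrained one. A minimal sketch of how such a mapping index could be built (hypothetical helper, not the PR's implementation; the real logic also covers extending the type map with unseen types and selecting a model branch from multi-task checkpoints):

from typing import List


def build_type_mapping(pretrained_type_map: List[str], new_type_map: List[str]) -> List[int]:
    # For each type used during fine-tuning, look up its index in the pretrained
    # type map; a KeyError means the pretrained model does not cover that type.
    index = {name: i for i, name in enumerate(pretrained_type_map)}
    return [index[name] for name in new_type_map]


# Example: pretrained on H/C/N/O, fine-tuned on a water system:
# build_type_mapping(["H", "C", "N", "O"], ["O", "H"]) -> [3, 0]
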
1 change: 0 additions & 1 deletion deepmd/pt/infer/deep_eval.py
@@ -118,7 +118,6 @@ def __init__(
item.replace(f"model.{head}.", "model.Default.")
] = state_dict[item].clone()
state_dict = state_dict_head
self.input_param["resuming"] = True
model = get_model(self.input_param).to(DEVICE)
model = torch.jit.script(model)
self.dp = ModelWrapper(model)
1 change: 0 additions & 1 deletion deepmd/pt/infer/inference.py
@@ -56,7 +56,6 @@ def __init__(
state_dict = state_dict_head

self.model_params = deepcopy(model_params)
model_params["resuming"] = True
self.model = get_model(model_params).to(DEVICE)

# Model Wrapper