diff --git a/examples/pytorch/test_examples.py b/examples/pytorch/test_examples.py index 1547fc84d714a3..0981bfadb25fa0 100644 --- a/examples/pytorch/test_examples.py +++ b/examples/pytorch/test_examples.py @@ -195,6 +195,7 @@ def test_run_ner(self): --per_device_train_batch_size=2 --per_device_eval_batch_size=2 --num_train_epochs={epochs} + --seed 7 """.split() if torch_device != "cuda": diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py old mode 100755 new mode 100644 index 66875a02829797..33e66697baa070 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -18,6 +18,7 @@ import os import re import warnings +from contextlib import contextmanager from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union @@ -50,6 +51,26 @@ logger = logging.get_logger(__name__) + +_init_weights = True + + +@contextmanager +def no_init_weights(_enable=True): + """ + Context manager to globally disable weight initialization to speed up loading large models. + + TODO(Patrick): Delete safety argument `_enable=True` at next major version. . + """ + global _init_weights + if _enable: + _init_weights = False + try: + yield + finally: + _init_weights = True + + try: from torch.nn import Identity except ImportError: @@ -766,17 +787,19 @@ def _get_resized_lm_head( def init_weights(self): """ - Initializes and prunes weights if needed. + If needed prunes and maybe initializes weights. """ - # Initialize weights - self.apply(self._init_weights) - # Prune heads if needed if self.config.pruned_heads: self.prune_heads(self.config.pruned_heads) - # Tie weights if needed - self.tie_weights() + if _init_weights: + # Initialize weights + self.apply(self._init_weights) + + # Tie weights should be skipped when not initializing all weights + # since from_pretrained(...) calls tie weights anyways + self.tie_weights() def prune_heads(self, heads_to_prune: Dict[int, List[int]]): """ @@ -954,6 +977,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please refer to the mirror site for more information. + _fast_init(:obj:`bool`, `optional`, defaults to `:obj:`True`): + Whether or not to disable fast initialization. + + .. warning:: + + One should only disable `_fast_init` to ensure backwards compatibility with + ``transformers.__version__ < 4.6.0`` for seeded model initialization. This argument will be removed + at the next major version. See `pull request 11471 + `__ for more information. + kwargs (remaining dictionary of keyword arguments, `optional`): Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or @@ -1010,6 +1043,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P mirror = kwargs.pop("mirror", None) from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) + _fast_init = kwargs.pop("_fast_init", True) user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} if from_pipeline is not None: @@ -1117,7 +1151,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P config.name_or_path = pretrained_model_name_or_path # Instantiate model. - if is_deepspeed_zero3_enabled(): import deepspeed @@ -1125,23 +1158,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # this immediately partitions the model across all gpus, to avoid the overhead in time # and memory copying it on CPU or each GPU first with deepspeed.zero.Init(config=deepspeed_config()): - model = cls(config, *model_args, **model_kwargs) + with no_init_weights(_enable=_fast_init): + model = cls(config, *model_args, **model_kwargs) else: - model = cls(config, *model_args, **model_kwargs) - - if state_dict is None and not (from_tf or from_flax): - try: - state_dict = torch.load(resolved_archive_file, map_location="cpu") - except Exception: - raise OSError( - f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' " - f"at '{resolved_archive_file}'" - "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. " - ) - - missing_keys = [] - unexpected_keys = [] - error_msgs = [] + with no_init_weights(_enable=_fast_init): + model = cls(config, *model_args, **model_kwargs) if from_tf: if resolved_archive_file.endswith(".index"): @@ -1171,102 +1192,20 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) raise else: - # Convert old format to new format if needed from a PyTorch state_dict - old_keys = [] - new_keys = [] - for key in state_dict.keys(): - new_key = None - if "gamma" in key: - new_key = key.replace("gamma", "weight") - if "beta" in key: - new_key = key.replace("beta", "bias") - if new_key: - old_keys.append(key) - new_keys.append(new_key) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, "_metadata", None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants - # so we need to apply the function recursively. - def load(module: nn.Module, prefix=""): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - args = (state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) - if is_deepspeed_zero3_enabled(): - import deepspeed - - # because zero3 puts placeholders in model params, this context - # manager gathers (unpartitions) the params of the current layer, then loads from - # the state dict and then re-partitions them again - with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): - if torch.distributed.get_rank() == 0: - module._load_from_state_dict(*args) - else: - module._load_from_state_dict(*args) - - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - - # Make sure we are able to load base models as well as derived models (with heads) - start_prefix = "" - model_to_load = model - has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()) - if not hasattr(model, cls.base_model_prefix) and has_prefix_module: - start_prefix = cls.base_model_prefix + "." - if hasattr(model, cls.base_model_prefix) and not has_prefix_module: - model_to_load = getattr(model, cls.base_model_prefix) - - load(model_to_load, prefix=start_prefix) - - if model.__class__.__name__ != model_to_load.__class__.__name__: - base_model_state_dict = model_to_load.state_dict().keys() - head_model_state_dict_without_base_prefix = [ - key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys() - ] - missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict) - - # Some models may have keys that are not in the state by design, removing them before needlessly warning - # the user. - if cls._keys_to_ignore_on_load_missing is not None: - for pat in cls._keys_to_ignore_on_load_missing: - missing_keys = [k for k in missing_keys if re.search(pat, k) is None] - - if cls._keys_to_ignore_on_load_unexpected is not None: - for pat in cls._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " - f"initializing {model.__class__.__name__}: {unexpected_keys}\n" - f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " - f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n" - f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " - f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." - ) - else: - logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") - if len(missing_keys) > 0: - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " - f"and are newly initialized: {missing_keys}\n" - f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." - ) - else: - logger.info( - f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" - f"If your task is similar to the task the model of the checkpoint was trained on, " - f"you can already use {model.__class__.__name__} for predictions without further training." - ) - if len(error_msgs) > 0: - error_msg = "\n\t".join(error_msgs) - raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + if state_dict is None: + try: + state_dict = torch.load(resolved_archive_file, map_location="cpu") + except Exception: + raise OSError( + f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' " + f"at '{resolved_archive_file}'" + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. " + ) + + model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model( + model, state_dict, pretrained_model_name_or_path + ) + # make sure token embedding weights are still tied if needed model.tie_weights() @@ -1283,6 +1222,142 @@ def load(module: nn.Module, prefix=""): return model + @classmethod + def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path): + + # Convert old format to new format if needed from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + # Retrieve missing & unexpected_keys + expected_keys = list(model.state_dict().keys()) + loaded_keys = list(state_dict.keys()) + prefix = model.base_model_prefix + + has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) + expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) + remove_prefix = not has_prefix_module and expects_prefix_module + add_prefix = has_prefix_module and not expects_prefix_module + + if remove_prefix: + expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys] + elif add_prefix: + expected_keys = [".".join([prefix, s]) for s in expected_keys] + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Some models may have keys that are not in the state by design, removing them before needlessly warning + # the user. + if cls._keys_to_ignore_on_load_missing is not None: + for pat in cls._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + # tie unintialized modules + unintialized_modules = model.retrieve_modules_from_names( + missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix + ) + for module in unintialized_modules: + model._init_weights(module) + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: nn.Module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + if is_deepspeed_zero3_enabled(): + import deepspeed + + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + if torch.distributed.get_rank() == 0: + module._load_from_state_dict(*args) + else: + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = "" + model_to_load = model + if not hasattr(model, cls.base_model_prefix) and has_prefix_module: + start_prefix = cls.base_model_prefix + "." + if hasattr(model, cls.base_model_prefix) and not has_prefix_module: + model_to_load = getattr(model, cls.base_model_prefix) + + load(model_to_load, prefix=start_prefix) + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " + f"initializing {model.__class__.__name__}: {unexpected_keys}\n" + f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n" + f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " + f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " + f"and are newly initialized: {missing_keys}\n" + f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + else: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" + f"If your task is similar to the task the model of the checkpoint was trained on, " + f"you can already use {model.__class__.__name__} for predictions without further training." + ) + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + return model, missing_keys, unexpected_keys, error_msgs + + def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): + module_keys = set([".".join(key.split(".")[:-1]) for key in names]) + + retrieved_modules = [] + # retrieve all modules that has at least one missing weight name + for name, module in self.named_modules(): + if remove_prefix: + name = ".".join(name.split(".")[1:]) if name.startswith(self.base_model_prefix) else name + elif add_prefix: + name = ".".join([self.base_model_prefix, name]) + + if name in module_keys: + retrieved_modules.append(module) + + return retrieved_modules + class Conv1D(nn.Module): """ diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d193a9e7a47862..a782ce378ef159 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -176,6 +176,103 @@ def test_save_load__keys_to_ignore_on_save(self): for k in _keys_to_ignore_on_save: self.assertNotIn(k, state_dict_saved) + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + + def test_save_load_fast_init_from_base(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + base_class = MODEL_MAPPING[config.__class__] + + if isinstance(base_class, tuple): + base_class = base_class[0] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + # make a copy of model class to not break future tests + # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class + class CopyClass(model_class): + pass + + model_class_copy = CopyClass + + # make sure that all keys are expected for test + model_class_copy._keys_to_ignore_on_load_missing = [] + + # make init deterministic, but make sure that + # non-initialized weights throw errors nevertheless + model_class_copy._init_weights = self._mock_init_weights + + model = base_class(config) + state_dict = model.state_dict() + + # this will often delete a single weight of a multi-weight module + # to test an edge case + random_key_to_del = random.choice(list(state_dict.keys())) + del state_dict[random_key_to_del] + + # check that certain keys didn't get saved with the model + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin")) + + model_fast_init = model_class_copy.from_pretrained(tmpdirname) + model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False) + + for key in model_fast_init.state_dict().keys(): + max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + def test_save_load_fast_init_to_base(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + base_class = MODEL_MAPPING[config.__class__] + + if isinstance(base_class, tuple): + base_class = base_class[0] + + for model_class in self.all_model_classes: + + if model_class == base_class: + continue + + # make a copy of model class to not break future tests + # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class + class CopyClass(base_class): + pass + + base_class_copy = CopyClass + + # make sure that all keys are expected for test + base_class_copy._keys_to_ignore_on_load_missing = [] + + # make init deterministic, but make sure that + # non-initialized weights throw errors nevertheless + base_class_copy._init_weights = self._mock_init_weights + + model = model_class(config) + state_dict = model.state_dict() + + # this will often delete a single weight of a multi-weight module + # to test an edge case + random_key_to_del = random.choice(list(state_dict.keys())) + del state_dict[random_key_to_del] + + # check that certain keys didn't get saved with the model + with tempfile.TemporaryDirectory() as tmpdirname: + model.config.save_pretrained(tmpdirname) + torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin")) + + model_fast_init = base_class_copy.from_pretrained(tmpdirname) + model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False) + + for key in model_fast_init.state_dict().keys(): + max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_funnel.py b/tests/test_modeling_funnel.py index 2d59e9f4e4100d..7e8190f00b97d2 100644 --- a/tests/test_modeling_funnel.py +++ b/tests/test_modeling_funnel.py @@ -399,6 +399,18 @@ def test_for_question_answering(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + + for param in ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]: + if hasattr(module, param) and getattr(module, param) is not None: + weight = getattr(module, param) + weight.data.fill_(3) + @require_torch class FunnelBaseModelTest(ModelTesterMixin, unittest.TestCase): @@ -442,6 +454,18 @@ def test_training(self): loss = model(**inputs).loss loss.backward() + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + + for param in ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]: + if hasattr(module, param) and getattr(module, param) is not None: + weight = getattr(module, param) + weight.data.fill_(3) + @require_torch @require_sentencepiece diff --git a/tests/test_modeling_transfo_xl.py b/tests/test_modeling_transfo_xl.py index 6f771ece01dfeb..adbaf3642e8b3b 100644 --- a/tests/test_modeling_transfo_xl.py +++ b/tests/test_modeling_transfo_xl.py @@ -348,6 +348,31 @@ def _check_hidden_states_for_generate( [expected_shape] * len(iter_hidden_states), ) + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "cluster_weight") and module.cluster_weight is not None: + module.cluster_weight.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + if hasattr(module, "cluster_bias") and module.cluster_bias is not None: + module.cluster_bias.data.fill_(3) + + if hasattr(module, "emb_projs"): + for i in range(len(module.emb_projs)): + if module.emb_projs[i] is not None: + torch.nn.init.constant_(module.emb_projs[i], 0.0003) + if hasattr(module, "out_projs"): + for i in range(len(module.out_projs)): + if module.out_projs[i] is not None: + torch.nn.init.constant_(module.out_projs[i], 0.0003) + + for param in ["r_emb", "r_w_bias", "r_r_bias", "r_bias"]: + if hasattr(module, param) and getattr(module, param) is not None: + weight = getattr(module, param) + weight.data.fill_(3) + @require_torch class TransfoXLModelLanguageGenerationTest(unittest.TestCase): diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index abb57eb9af3053..f2bb897e55129d 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -329,6 +329,15 @@ def test_initialization(self): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "weight_g") and module.weight is not None: + module.weight_g.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + @slow def test_model_from_pretrained(self): model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") @@ -446,6 +455,15 @@ def test_initialization(self): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "weight_g") and module.weight is not None: + module.weight_g.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + @slow def test_model_from_pretrained(self): model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py index 1423ef6980f2eb..7544334028e477 100644 --- a/tests/test_modeling_xlnet.py +++ b/tests/test_modeling_xlnet.py @@ -593,6 +593,18 @@ def test_retain_grad_hidden_states_attentions(self): # xlnet cannot keep gradients in attentions or hidden states return + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + + for param in ["q", "k", "v", "o", "r", "r_r_bias", "r_s_bias", "r_w_bias", "seg_embed", "mask_emb"]: + if hasattr(module, param) and getattr(module, param) is not None: + weight = getattr(module, param) + weight.data.fill_(3) + def _check_hidden_states_for_generate( self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 ):