Commit: add tests

patrickvonplaten committed Apr 30, 2021
1 parent 69f3fee commit 33b0912
Showing 2 changed files with 53 additions and 10 deletions.
18 changes: 9 additions & 9 deletions src/transformers/modeling_utils.py
@@ -786,16 +786,16 @@ def init_weights(self):
         """
         Maybe initializes and prunes weights if needed.
         """
+        # Prune heads if needed
+        if self.config.pruned_heads:
+            self.prune_heads(self.config.pruned_heads)
+
         if not _init_weights:
             return
 
         # Initialize weights
         self.apply(self._init_weights)
 
-        # Prune heads if needed
-        if self.config.pruned_heads:
-            self.prune_heads(self.config.pruned_heads)
-
         # Tie weights if needed
         self.tie_weights()
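Note on the hunk above: head pruning now runs before the early return on _init_weights, so configured heads are still pruned even when the fast-init path skips random weight initialization. A toy sketch of the behavioral difference (ToyModel and its attributes are illustrative stand-ins, not transformers API):

_init_weights = False  # module-level flag: fast init skips random initialization

class ToyModel:
    def __init__(self, pruned_heads):
        self.pruned_heads = pruned_heads  # mirrors config.pruned_heads
        self.heads_pruned = False

    def prune_heads(self, heads):
        self.heads_pruned = True

    def init_weights(self):
        # new order: prune first, then maybe skip initialization
        if self.pruned_heads:
            self.prune_heads(self.pruned_heads)
        if not _init_weights:
            return
        # ... random weight initialization would run here ...

model = ToyModel(pruned_heads={0: [1]})
model.init_weights()
assert model.heads_pruned  # under the old order, the early return skipped pruning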

@@ -1163,9 +1163,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             try:
                 from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model
 
-                model, missing_keys = load_tf2_checkpoint_in_pytorch_model(
-                    model, resolved_archive_file, allow_missing_keys=True
-                )
+                model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
             except ImportError:
                 logger.error(
                     "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
@@ -1241,9 +1239,11 @@ def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or
         expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
 
         if not has_prefix_module and expects_prefix_module:
-            expected_keys = [s.split(prefix)[-1] for s in expected_keys if s.startswith(prefix)]
+            expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
         elif has_prefix_module and not expects_prefix_module:
-            expected_keys = [".".join([prefix, s]) for s in expected_keys if ".".join([prefix, s]) in set(loaded_keys)]
+            expected_keys = [
+                ".".join([prefix, s]) if ".".join([prefix, s]) in set(loaded_keys) else s for s in expected_keys
+            ]
 
         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
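Note on the hunk above: the rewritten comprehensions remap checkpoint keys across the base-model prefix without discarding keys that do not participate in the mapping, so head-only weights can now surface in missing_keys. A minimal sketch of the first branch with hypothetical key names (the prefix and keys are made up for illustration):

# A base-model checkpoint loaded into a model with a head: strip the
# base-model prefix where present, and keep head-only keys instead of
# filtering them out as before.
prefix = "bert"
loaded_keys = ["embeddings.word_embeddings.weight"]  # keys present in the checkpoint
expected_keys = [
    "bert.embeddings.word_embeddings.weight",  # lives under the base-model prefix
    "cls.predictions.bias",  # head-only key, carries no prefix
]

expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
assert expected_keys == ["embeddings.word_embeddings.weight", "cls.predictions.bias"]

# The old comprehension filtered on s.startswith(prefix), silently dropping
# "cls.predictions.bias", so a still-uninitialized head weight could never
# be reported here:
missing_keys = list(set(expected_keys) - set(loaded_keys))
assert missing_keys == ["cls.predictions.bias"]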
45 changes: 44 additions & 1 deletion tests/test_modeling_common.py
@@ -176,6 +176,49 @@ def test_save_load__keys_to_ignore_on_save(self):
         for k in _keys_to_ignore_on_save:
             self.assertNotIn(k, state_dict_saved)
 
+    def test_save_load_fast_init_from_base(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        base_class = MODEL_MAPPING[config.__class__]
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            model = base_class(config)
+
+            # save the base model, then reload it into the head class with and without fast init
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                torch.manual_seed(0)
+                model_fast_init = model_class.from_pretrained(tmpdirname)
+                model_slow_init = model_class.from_pretrained(tmpdirname, _no_fast_init=True)
+
+                for key in model_fast_init.state_dict().keys():
+                    max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
+                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+    def test_save_load_fast_init_to_base(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        base_class = MODEL_MAPPING[config.__class__]
+
+        for model_class in self.all_model_classes:
+            if model_class == base_class:
+                continue
+
+            model = model_class(config)
+
+            # save the head model, then reload it into the base class with and without fast init
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                torch.manual_seed(0)
+                model_fast_init = base_class.from_pretrained(tmpdirname)
+                model_slow_init = base_class.from_pretrained(tmpdirname, _no_fast_init=True)
+
+                for key in model_fast_init.state_dict().keys():
+                    max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
+                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
@@ -635,7 +678,7 @@ def test_head_pruning_integration(self):
         if not self.test_pruning:
             return
 
-        for model_class in self.all_model_classes:
+        for model_class in self.all_model_classes[:1]:
             (
                 config,
                 inputs_dict,
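Note on the test changes above: the two new tests round-trip a checkpoint between a base class and a head class, loading it once with the default fast initialization and once with the private _no_fast_init flag this commit exercises (the flag name is taken from this commit and may differ in later revisions), then require the resulting state dicts to agree; the last hunk additionally restricts the pruning integration test to the first model class. Since the tests compare the signed sum of per-tensor differences, a stricter element-wise check, sketched here against a public checkpoint, would be:

import torch
from transformers import BertForSequenceClassification

# Illustrative only: loading a base-only checkpoint such as bert-base-uncased
# into a head model forces the head weights through exactly the (fast or slow)
# initialization path the tests exercise.
torch.manual_seed(0)
fast = BertForSequenceClassification.from_pretrained("bert-base-uncased")
torch.manual_seed(0)
slow = BertForSequenceClassification.from_pretrained("bert-base-uncased", _no_fast_init=True)

for key, fast_tensor in fast.state_dict().items():
    max_diff = (slow.state_dict()[key] - fast_tensor).abs().max().item()
    assert max_diff <= 1e-3, f"{key} not identical"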
