Commit 15409fb: improve lazy load
patrickvonplaten committed Apr 30, 2021 (1 parent: 33b0912)

Showing 5 changed files with 59 additions and 22 deletions.
1 change: 0 additions & 1 deletion datasets
Submodule datasets deleted from 8afd0b
30 changes: 24 additions & 6 deletions src/transformers/modeling_utils.py
@@ -1237,10 +1237,12 @@ def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or
 
         has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
         expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
+        remove_prefix = not has_prefix_module and expects_prefix_module
+        add_prefix = has_prefix_module and not expects_prefix_module
 
-        if not has_prefix_module and expects_prefix_module:
+        if remove_prefix:
             expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
-        elif has_prefix_module and not expects_prefix_module:
+        elif add_prefix:
             expected_keys = [
                 ".".join([prefix, s]) if ".".join([prefix, s]) in set(loaded_keys) else s for s in expected_keys
             ]
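For context, a minimal sketch of what the remove_prefix branch does, assuming a base_model_prefix of "bert" and toy key names (both are illustrative, not taken from the commit):

    # Sketch only: "bert" and the key names are assumed for illustration.
    prefix = "bert"
    loaded_keys = ["embeddings.word_embeddings.weight"]          # checkpoint saved from the bare base model
    expected_keys = ["bert.embeddings.word_embeddings.weight"]   # model with head expects the "bert." prefix

    has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)        # False
    expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)  # True
    remove_prefix = not has_prefix_module and expects_prefix_module           # True
    add_prefix = has_prefix_module and not expects_prefix_module              # False

    if remove_prefix:
        expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]

    assert expected_keys == loaded_keys  # keys now line up with the checkpoint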
@@ -1259,7 +1261,9 @@ def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or
             unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
         # tie unintialized modules
-        unintialized_modules = model.retrieve_modules_from_names(missing_keys)
+        unintialized_modules = model.retrieve_modules_from_names(
+            missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix
+        )
         for module in unintialized_modules:
             model._init_weights(module)
         # copy state_dict so _load_from_state_dict can modify it
@@ -1274,7 +1278,7 @@ def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or
         # so we need to apply the function recursively.
         def load(module: nn.Module, prefix=""):
             local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-            args = (state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
+            args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
             if is_deepspeed_zero3_enabled():
                 import deepspeed
@@ -1330,9 +1334,23 @@ def load(module: nn.Module, prefix=""):
 
         return model, missing_keys, unexpected_keys, error_msgs
 
-    def retrieve_modules_from_names(self, names):
+    def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
         module_keys = [".".join(key.split(".")[:-1]) for key in names]
-        return [module for name, module in self.named_modules() if name in module_keys]
+
+        retrieved_modules = []
+        # retrieve all modules that have at least one missing weight name
+        for name, module in self.named_modules():
+            if remove_prefix:
+                name = ".".join(name.split(".")[1:]) if name.startswith(self.base_model_prefix) else name
+            elif add_prefix:
+                name = (
+                    ".".join([self.base_model_prefix, name]) if not name.startswith(self.base_model_prefix) else name
+                )
+
+            if name in module_keys:
+                retrieved_modules.append(module)
+
+        return retrieved_modules
 
 
 class Conv1D(nn.Module):
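retrieve_modules_from_names maps each weight name back to its owning module by dropping the final path component, then normalizes module names with the same prefix logic used for the state dict keys. A rough illustration with assumed toy names (not from the commit):

    # Assumed toy names for illustration only.
    missing_keys = ["classifier.weight", "classifier.bias", "encoder.layer.0.dense.weight"]
    module_keys = [".".join(key.split(".")[:-1]) for key in missing_keys]
    # -> ["classifier", "classifier", "encoder.layer.0.dense"]
    # self.named_modules() yields (name, module) pairs; any module whose
    # prefix-normalized name appears in module_keys is returned so that
    # _init_weights runs only on those modules.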
7 changes: 0 additions & 7 deletions tests/fixtures/tests_samples/MRPC/dev.csv
This file was deleted.

7 changes: 0 additions & 7 deletions tests/fixtures/tests_samples/MRPC/train.csv
This file was deleted.
36 changes: 35 additions & 1 deletion tests/test_modeling_common.py
@@ -179,17 +179,34 @@ def test_save_load__keys_to_ignore_on_save(self):
     def test_save_load_fast_init_from_base(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         base_class = MODEL_MAPPING[config.__class__]
 
+        def _mock_init_weights(self, module):
+            if hasattr(module, "weight"):
+                module.weight.data.fill_(3)
+            if hasattr(module, "bias"):
+                module.bias.data.fill_(3)
+
         for model_class in self.all_model_classes:
             if model_class == base_class:
                 continue
 
+            # make init deterministic, but make sure that
+            # non-initialized weights throw errors nevertheless
+            model_class._init_weights = _mock_init_weights
+
             model = base_class(config)
+            state_dict = model.state_dict()
 
+            # this will often delete a single weight of a multi-weight module
+            # to test an edge case
+            random_key_to_del = random.choice(list(state_dict.keys()))
+            del state_dict[random_key_to_del]
+
             # check that certain keys didn't get saved with the model
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
+                torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
 
+                torch.manual_seed(0)
                 model_fast_init = model_class.from_pretrained(tmpdirname)
                 model_slow_init = model_class.from_pretrained(tmpdirname, _no_fast_init=True)
@@ -201,15 +218,32 @@ def test_save_load_fast_init_to_base(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         base_class = MODEL_MAPPING[config.__class__]
 
+        def _mock_init_weights(self, module):
+            if hasattr(module, "weight"):
+                module.weight.data.fill_(3)
+            if hasattr(module, "bias"):
+                module.bias.data.fill_(3)
+
         for model_class in self.all_model_classes:
             if model_class == base_class:
                 continue
 
+            # make init deterministic, but make sure that
+            # non-initialized weights throw errors nevertheless
+            model_class._init_weights = _mock_init_weights
+
             model = model_class(config)
+            state_dict = model.state_dict()
 
+            # this will often delete a single weight of a multi-weight module
+            # to test an edge case
+            random_key_to_del = random.choice(list(state_dict.keys()))
+            del state_dict[random_key_to_del]
+
             # check that certain keys didn't get saved with the model
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
+                torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
 
+                torch.manual_seed(0)
                 model_fast_init = base_class.from_pretrained(tmpdirname, _no_fast_init=True)
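Both tests rely on the same trick: monkey-patching _init_weights with a constant fill makes initialization deterministic, so a fast-initialized model and a fully re-initialized model can be compared tensor by tensor, while a weight that is never initialized at all still stands out. A self-contained sketch of the idea (the nn.Linear layer is an assumption for illustration, not from the tests):

    import torch
    import torch.nn as nn

    def _mock_init_weights(self, module):
        # a constant fill makes initialization order-independent and deterministic
        if hasattr(module, "weight"):
            module.weight.data.fill_(3)
        if hasattr(module, "bias"):
            module.bias.data.fill_(3)

    layer = nn.Linear(2, 2)          # assumed toy module
    _mock_init_weights(None, layer)  # self is unused by the mock
    assert torch.equal(layer.weight.data, torch.full((2, 2), 3.0))
    assert torch.equal(layer.bias.data, torch.full((2,), 3.0))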
