Pytorch - Lazy initialization of models #11471

Merged · 29 commits · May 5, 2021

Changes from all commits:
6a1cc07  lazy_init_weights (patrickvonplaten, Apr 27, 2021)
4c92704  remove ipdb (patrickvonplaten, Apr 27, 2021)
2230ff5  save int (patrickvonplaten, Apr 30, 2021)
d43ed82  merge conflict (patrickvonplaten, Apr 30, 2021)
ac3710c  add necessary code (patrickvonplaten, Apr 30, 2021)
a5107a7  remove unnecessary utils (patrickvonplaten, Apr 30, 2021)
881a86d  Update src/transformers/models/t5/modeling_t5.py (patrickvonplaten, Apr 30, 2021)
cc546e7  Merge branch 'lazy_init' of https://github.com/patrickvonplaten/trans… (patrickvonplaten, Apr 30, 2021)
69f3fee  clean (patrickvonplaten, Apr 30, 2021)
33b0912  add tests (patrickvonplaten, Apr 30, 2021)
6b27fe8  correct (patrickvonplaten, Apr 30, 2021)
6d3f829  finish tests (patrickvonplaten, Apr 30, 2021)
c7ab932  finish tests (patrickvonplaten, Apr 30, 2021)
bdc03c1  fix some more tests (patrickvonplaten, Apr 30, 2021)
6314565  fix xlnet & transfo-xl (patrickvonplaten, May 3, 2021)
38629a2  fix merge conflict (patrickvonplaten, May 3, 2021)
df21f7b  fix more tests (patrickvonplaten, May 3, 2021)
5c8097b  make sure tests are independent (patrickvonplaten, May 3, 2021)
ca32874  fix tests more (patrickvonplaten, May 3, 2021)
d025196  finist tests (patrickvonplaten, May 3, 2021)
806e27a  final touches (patrickvonplaten, May 3, 2021)
7f134b7  Update src/transformers/modeling_utils.py (patrickvonplaten, May 3, 2021)
2356a32  Apply suggestions from code review (patrickvonplaten, May 3, 2021)
3d63d77  Update src/transformers/modeling_utils.py (patrickvonplaten, May 5, 2021)
4e051e8  Update src/transformers/modeling_utils.py (patrickvonplaten, May 5, 2021)
eeff9a6  clean tests (patrickvonplaten, May 5, 2021)
a63e0b1  Merge branch 'lazy_init' of https://github.com/patrickvonplaten/trans… (patrickvonplaten, May 5, 2021)
59a0020  give arg positive name (patrickvonplaten, May 5, 2021)
9acb97a  add more mock weights to xlnet (patrickvonplaten, May 5, 2021)
1 change: 1 addition & 0 deletions examples/pytorch/test_examples.py
@@ -195,6 +195,7 @@ def test_run_ner(self):
--per_device_train_batch_size=2
--per_device_eval_batch_size=2
--num_train_epochs={epochs}
--seed 7
""".split()

if torch_device != "cuda":
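The only change to this example test is the pinned `--seed 7`. A plausible reason, not stated in the diff itself: changing how model weights are initialized also changes how many draws are taken from the random number generator, which shifts every later draw and can move example metrics past their thresholds; pinning the seed keeps the run deterministic either way. A tiny PyTorch illustration of that RNG-stream effect, not code from this PR:

```python
import torch

torch.manual_seed(0)
_ = torch.randn(10)        # an init draw that a faster init path might skip
with_extra_draw = torch.randn(3)

torch.manual_seed(0)
without_extra_draw = torch.randn(3)

# Same seed, but different preceding RNG consumption -> different values.
print(torch.allclose(with_extra_draw, without_extra_draw))  # False
```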
313 changes: 194 additions & 119 deletions src/transformers/modeling_utils.py
File mode changed: 100755 → 100644

Large diffs are not rendered by default.
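Since the main diff to `src/transformers/modeling_utils.py` is not rendered here, the following is a minimal, self-contained sketch of the idea this PR introduces and that the new tests exercise via the `_fast_init` flag: run the model's weight initialization only for parameters that are missing from the checkpoint, because everything present in the checkpoint is overwritten on load anyway. The helper name `fast_load_pretrained` and its structure are illustrative assumptions, not the actual `from_pretrained` implementation.

```python
import torch
import torch.nn as nn


def fast_load_pretrained(model: nn.Module, state_dict: dict, init_fn) -> nn.Module:
    """Sketch only: initialize just the parameters missing from `state_dict`, then load it."""
    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())

    # Run the (potentially expensive) init only for modules that own a missing
    # parameter; all other parameters are overwritten by the checkpoint below.
    for module_name, module in model.named_modules():
        prefix = f"{module_name}." if module_name else ""
        param_keys = [prefix + name for name, _ in module.named_parameters(recurse=False)]
        if any(key in missing_keys for key in param_keys):
            init_fn(module)

    # strict=False: keys absent from the checkpoint keep the values set above.
    model.load_state_dict(state_dict, strict=False)
    return model
```

With a deterministic `init_fn` such as the `_mock_init_weights` helpers added in the tests below, a weight deleted from the saved checkpoint ends up with the same mocked value on both the fast and the slow (`_fast_init=False`) path, which is what the new common tests assert.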

97 changes: 97 additions & 0 deletions tests/test_modeling_common.py
@@ -176,6 +176,103 @@ def test_save_load__keys_to_ignore_on_save(self):
for k in _keys_to_ignore_on_save:
self.assertNotIn(k, state_dict_saved)

def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)

def test_save_load_fast_init_from_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
base_class = MODEL_MAPPING[config.__class__]

if isinstance(base_class, tuple):
base_class = base_class[0]

for model_class in self.all_model_classes:
if model_class == base_class:
continue

# make a copy of model class to not break future tests
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
class CopyClass(model_class):
pass

model_class_copy = CopyClass

# make sure that all keys are expected for test
model_class_copy._keys_to_ignore_on_load_missing = []

# make init deterministic, but make sure that
# non-initialized weights throw errors nevertheless
model_class_copy._init_weights = self._mock_init_weights

model = base_class(config)
state_dict = model.state_dict()

# this will often delete a single weight of a multi-weight module
# to test an edge case
random_key_to_del = random.choice(list(state_dict.keys()))
del state_dict[random_key_to_del]

# check that certain keys didn't get saved with the model
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

model_fast_init = model_class_copy.from_pretrained(tmpdirname)
model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)

for key in model_fast_init.state_dict().keys():
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

def test_save_load_fast_init_to_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
base_class = MODEL_MAPPING[config.__class__]

if isinstance(base_class, tuple):
base_class = base_class[0]

for model_class in self.all_model_classes:

if model_class == base_class:
continue

# make a copy of model class to not break future tests
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
class CopyClass(base_class):
pass

base_class_copy = CopyClass

# make sure that all keys are expected for test
base_class_copy._keys_to_ignore_on_load_missing = []

# make init deterministic, but make sure that
# non-initialized weights throw errors nevertheless
base_class_copy._init_weights = self._mock_init_weights

model = model_class(config)
state_dict = model.state_dict()

# this will often delete a single weight of a multi-weight module
# to test an edge case
random_key_to_del = random.choice(list(state_dict.keys()))
del state_dict[random_key_to_del]

# check that certain keys didn't get saved with the model
with tempfile.TemporaryDirectory() as tmpdirname:
model.config.save_pretrained(tmpdirname)
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

model_fast_init = base_class_copy.from_pretrained(tmpdirname)
model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)

for key in model_fast_init.state_dict().keys():
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

24 changes: 24 additions & 0 deletions tests/test_modeling_funnel.py
@@ -399,6 +399,18 @@ def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)

for param in ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]:
if hasattr(module, param) and getattr(module, param) is not None:
weight = getattr(module, param)
weight.data.fill_(3)


@require_torch
class FunnelBaseModelTest(ModelTesterMixin, unittest.TestCase):
@@ -442,6 +454,18 @@ def test_training(self):
loss = model(**inputs).loss
loss.backward()

# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)

for param in ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]:
if hasattr(module, param) and getattr(module, param) is not None:
weight = getattr(module, param)
weight.data.fill_(3)


@require_torch
@require_sentencepiece
25 changes: 25 additions & 0 deletions tests/test_modeling_transfo_xl.py
@@ -348,6 +348,31 @@ def _check_hidden_states_for_generate(
[expected_shape] * len(iter_hidden_states),
)

# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "cluster_weight") and module.cluster_weight is not None:
module.cluster_weight.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
if hasattr(module, "cluster_bias") and module.cluster_bias is not None:
module.cluster_bias.data.fill_(3)

if hasattr(module, "emb_projs"):
for i in range(len(module.emb_projs)):
if module.emb_projs[i] is not None:
torch.nn.init.constant_(module.emb_projs[i], 0.0003)
if hasattr(module, "out_projs"):
for i in range(len(module.out_projs)):
if module.out_projs[i] is not None:
torch.nn.init.constant_(module.out_projs[i], 0.0003)

for param in ["r_emb", "r_w_bias", "r_r_bias", "r_bias"]:
if hasattr(module, param) and getattr(module, param) is not None:
weight = getattr(module, param)
weight.data.fill_(3)


@require_torch
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
18 changes: 18 additions & 0 deletions tests/test_modeling_wav2vec2.py
@@ -329,6 +329,15 @@ def test_initialization(self):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)

# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "weight_g") and module.weight is not None:
module.weight_g.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)

@slow
def test_model_from_pretrained(self):
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
@@ -446,6 +455,15 @@ def test_initialization(self):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)

# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "weight_g") and module.weight is not None:
module.weight_g.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)

@slow
def test_model_from_pretrained(self):
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
12 changes: 12 additions & 0 deletions tests/test_modeling_xlnet.py
@@ -593,6 +593,18 @@ def test_retain_grad_hidden_states_attentions(self):
# xlnet cannot keep gradients in attentions or hidden states
return

# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)

for param in ["q", "k", "v", "o", "r", "r_r_bias", "r_s_bias", "r_w_bias", "seg_embed", "mask_emb"]:
if hasattr(module, param) and getattr(module, param) is not None:
weight = getattr(module, param)
weight.data.fill_(3)

def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
):