[BC-Breaking] Move fine-tune specific module out of wav2vec2 encoder #1782

Status: Merged · 2 commits · Sep 22, 2021
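
In short: the fine-tune-specific readout layer moves out of `Encoder` and becomes an optional, top-level `aux` module on `Wav2Vec2Model`. For downstream code, the break looks like the sketch below (the `model` variable is hypothetical; the attribute names are taken from the diff):

```python
# Migration sketch for code that reached into the model for the readout layer.
# Before this change, the readout Linear hung off the encoder:
#     num_out = model.encoder.readout.out_features
# After this change, it is the optional top-level `aux` module:
num_out = model.aux.out_features
```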
@@ -118,7 +118,7 @@ def _test_import_finetune(self, original, imported, config):
         # Readout
         x = torch.randn(3, 10, config["hidden_size"])
         ref = original.lm_head(x)
-        hyp = imported.encoder.readout(x)
+        hyp = imported.aux(x)
         self.assertEqual(ref, hyp)
         # The whole model without mask
         x = torch.randn(3, 1024)

@@ -195,8 +195,8 @@ def _test_recreate(self, imported, reloaded, config):
         self.assertEqual(ref, hyp)
         # Readout
         x = torch.randn(3, 10, config["hidden_size"])
-        ref = imported.encoder.readout(x)
-        hyp = reloaded.encoder.readout(x)
+        ref = imported.aux(x)
+        hyp = reloaded.aux(x)
         self.assertEqual(ref, hyp)
         # The whole model
         x = torch.randn(3, 1024)

@@ -208,7 +208,7 @@ def _test_recreate(self, imported, reloaded, config):
     def test_recreate_pretrain(self, config, factory_func):
         """Imported models can be recreated via a factory function without Hugging Face transformers."""
         imported = import_huggingface_model(self._get_model(config)).eval()
-        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
+        reloaded = factory_func(num_out=imported.aux.out_features)
         reloaded.load_state_dict(imported.state_dict())
         reloaded.eval()
         self._test_recreate(imported, reloaded, config)

@@ -217,7 +217,7 @@ def test_recreate_pretrain(self, config, factory_func):
     def test_recreate_finetune(self, config, factory_func):
         """Imported models can be recreated via a factory function without Hugging Face transformers."""
         imported = import_huggingface_model(self._get_model(config)).eval()
-        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
+        reloaded = factory_func(num_out=imported.aux.out_features)
         reloaded.load_state_dict(imported.state_dict())
         reloaded.eval()
         self._test_recreate(imported, reloaded, config)

torchaudio/models/wav2vec2/components.py (10 changes: 1 addition & 9 deletions)
@@ -426,12 +426,10 @@ def __init__(
             self,
             feature_projection: Module,
             transformer: Module,
-            readout: Module,
     ):
         super().__init__()
         self.feature_projection = feature_projection
         self.transformer = transformer
-        self.readout = readout

     def _preprocess(
             self,

@@ -458,7 +456,6 @@ def forward(
     ) -> Tensor:
         x, mask = self._preprocess(features, lengths)
         x = self.transformer(x, attention_mask=mask)
-        x = self.readout(x)
         return x

     def extract_features(

@@ -561,7 +558,6 @@ def _get_encoder(
         dropout: float,
         layer_norm_first: bool,
         layer_drop: float,
-        num_out: int,
 ) -> Encoder:
     """
     Args:

@@ -720,8 +716,4 @@ def _get_encoder(
         layer_norm_first=not layer_norm_first,
         layer_drop=layer_drop,
     )
-    readout = nn.Linear(
-        in_features=embed_dim,
-        out_features=num_out,
-    )
-    return Encoder(feature_projection, transformer, readout)
+    return Encoder(feature_projection, transformer)

torchaudio/models/wav2vec2/model.py (25 changes: 18 additions & 7 deletions)
@@ -20,15 +20,20 @@ class Wav2Vec2Model(Module):
         encoder (torch.nn.Module):
             Encoder that converts the audio features into the sequence of probability
             distribution (in negative log-likelihood) over labels.
+
+        aux (Optional[torch.nn.Module]):
+            Auxiliary module. If provided, the output from encoder is passed to this module.
     """
     def __init__(
             self,
             feature_extractor: Module,
             encoder: Module,
+            aux: Optional[Module] = None,
     ):
         super().__init__()
         self.feature_extractor = feature_extractor
         self.encoder = encoder
+        self.aux = aux

     @torch.jit.export
     def extract_features(
@@ -89,7 +94,10 @@ def forward(
                 Shape: ``(batch, )``.
         """
         x, lengths = self.feature_extractor(waveforms, lengths)
-        return self.encoder(x, lengths), lengths
+        x = self.encoder(x, lengths)
+        if self.aux is not None:
+            x = self.aux(x)
+        return x, lengths


 def _get_model(
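
A minimal usage sketch of the new forward flow (assumes a torchaudio build containing this change; `wav2vec2_base` and the shapes in the comments follow the factory defaults below):

```python
import torch
from torchaudio.models import wav2vec2_base

model = wav2vec2_base(num_out=32).eval()
waveforms = torch.randn(3, 1024)
with torch.no_grad():
    logits, _ = model(waveforms)    # `aux` (the readout Linear) is applied last
    model.aux = None                # pre-training use case: no readout at all
    features, _ = model(waveforms)  # raw encoder output is returned unchanged
print(logits.shape[-1])    # 32, i.e. num_out
print(features.shape[-1])  # encoder embed dim (768 for the base model)
```

Making `aux` optional (default `None`) is what lets the same encoder serve pre-training, where no readout layer exists.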
@@ -108,7 +116,7 @@ def _get_model(
         encoder_dropout: float,
         encoder_layer_norm_first: bool,
         encoder_layer_drop: float,
-        encoder_num_out: int,
+        aux_num_out: int,
 ) -> Wav2Vec2Model:
     if extractor_conv_layer_config is None:
         extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

@@ -129,9 +137,12 @@ def _get_model(
         dropout=encoder_dropout,
         layer_norm_first=encoder_layer_norm_first,
         layer_drop=encoder_layer_drop,
-        num_out=encoder_num_out,
     )
-    return Wav2Vec2Model(feature_extractor, encoder)
+    aux = torch.nn.Linear(
+        in_features=encoder_embed_dim,
+        out_features=aux_num_out,
+    )
+    return Wav2Vec2Model(feature_extractor, encoder, aux)


 def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
@@ -172,7 +183,7 @@ def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
         encoder_dropout=0.1,
         encoder_layer_norm_first=False,
         encoder_layer_drop=0.1,
-        encoder_num_out=num_out,
+        aux_num_out=num_out,
     )

@@ -214,7 +225,7 @@ def wav2vec2_large(num_out: int) -> Wav2Vec2Model:
         encoder_dropout=0.1,
         encoder_layer_norm_first=False,
         encoder_layer_drop=0.1,
-        encoder_num_out=num_out,
+        aux_num_out=num_out,
     )

@@ -256,5 +267,5 @@ def wav2vec2_large_lv60k(num_out: int) -> Wav2Vec2Model:
         encoder_dropout=0.0,
         encoder_layer_norm_first=True,
         encoder_layer_drop=0.1,
-        encoder_num_out=num_out,
+        aux_num_out=num_out,
     )

torchaudio/models/wav2vec2/utils/import_fairseq.py (4 changes: 2 additions & 2 deletions)
@@ -46,7 +46,7 @@ def _parse_config(w2v_model, num_out):
         'encoder_dropout': encoder.layers[0].dropout3.p,
         'encoder_layer_norm_first': encoder.layer_norm_first,
         'encoder_layer_drop': encoder.layerdrop,
-        'encoder_num_out': num_out,
+        'aux_num_out': num_out,
     }
     return config

@@ -110,7 +110,7 @@ def _map_key(key):
     match = re.match(r"proj\.(weight|bias)", key)
     # Encoder - Readout layer
     if match:
-        return f"encoder.readout.{match.group(1)}"
+        return f"aux.{match.group(1)}"
     raise ValueError(f'Unexpected key: {key_}')
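
For reference, a standalone sketch of the key remapping this hunk changes; `_map_readout_key` is a hypothetical extraction of the branch above:

```python
import re

def _map_readout_key(key):
    # Fairseq fine-tuned checkpoints store the readout layer as
    # "proj.weight" / "proj.bias"; these now map onto the top-level
    # `aux` module instead of `encoder.readout`.
    match = re.match(r"proj\.(weight|bias)", key)
    if match:
        return f"aux.{match.group(1)}"
    return key

assert _map_readout_key("proj.weight") == "aux.weight"
assert _map_readout_key("proj.bias") == "aux.bias"
```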
torchaudio/models/wav2vec2/utils/import_huggingface.py (4 changes: 2 additions & 2 deletions)
@@ -26,7 +26,7 @@ def _get_config(cfg):
         'encoder_dropout': cfg.hidden_dropout,
         'encoder_layer_norm_first': cfg.do_stable_layer_norm,
         'encoder_layer_drop': cfg.layerdrop,
-        'encoder_num_out': cfg.vocab_size,
+        'aux_num_out': cfg.vocab_size,
     }
     return config

@@ -42,7 +42,7 @@ def _build(config, original):
     imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
     imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
     if original.__class__.__name__ == 'Wav2Vec2ForCTC':
-        imported.encoder.readout.load_state_dict(original.lm_head.state_dict())
+        imported.aux.load_state_dict(original.lm_head.state_dict())
     return imported
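
A hedged usage sketch of the import path this file implements (the checkpoint name is illustrative; assumes `transformers` is installed):

```python
from transformers import Wav2Vec2ForCTC
from torchaudio.models.wav2vec2.utils import import_huggingface_model

original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
imported = import_huggingface_model(original).eval()
# For CTC models, the Hugging Face lm_head weights now live in `imported.aux`
# rather than in `imported.encoder.readout`.
assert imported.aux.out_features == original.config.vocab_size
```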