[BC-Breaking] Move fine-tune specific module out of wav2vec2 encoder (pytorch#1782)

Previously, the Linear module (called `readout` and used only for the ASR fine-tuning
task) was placed in the encoder module. Conceptually, the encoder has nothing to
do with a module specific to a fine-tuning / downstream task.

The problems here are that:
1. The encoder can also be used in the pre-training phase, in which such a module should
not be present.
2. The choice of a Linear module is arbitrary, and a hard-coded module structure in the
encoder is inconvenient for users.

Therefore, this commit moves the Linear module out of the encoder and places it
as the `aux` attribute of `Wav2Vec2Model`. (As a result, `Wav2Vec2Model` has
`feature_extractor`, `encoder`, and `aux` attributes.)
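
For downstream code, the BC-breaking part is the attribute path. A minimal sketch of
the rename (the `num_out=32` value and the dummy tensor are illustrative only; this
assumes the factory functions are exported from `torchaudio.models`, and 768 is the
base configuration's encoder embed dim):

    import torch
    from torchaudio.models import wav2vec2_base

    model = wav2vec2_base(num_out=32)
    encoded = torch.randn(1, 49, 768)   # dummy encoder output: (batch, frames, embed dim)
    logits = model.aux(encoded)         # previously: model.encoder.readout(encoded)
    assert logits.shape == (1, 49, 32)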

An alternative approach is to define another module and place `Wav2Vec2Model`
and the aux module alongside each other, but that would introduce a new class we
need to maintain.
The expected use of `aux` is only for 1. loading the pre-trained parameters
published by `fairseq` (and their variants from HF) and 2. creating the same model
architectures for comparison experiments.
A newly introduced class would not be general enough for downstream adaptations,
where there will be many different, more complicated models (e.g. s3prl).

Therefore, following the minimalistic approach, we put the aux module inside `Wav2Vec2Model`.
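
For illustration, a sketch of the resulting forward path (batch size, dummy waveform
lengths, and `num_out=32` are assumptions, not part of this commit):

    import torch
    from torchaudio.models import wav2vec2_base

    model = wav2vec2_base(num_out=32).eval()
    # `aux` now sits next to `feature_extractor` and `encoder` as a plain Linear layer.
    assert isinstance(model.aux, torch.nn.Linear)

    waveforms = torch.randn(2, 16000)        # (batch, time): 1 second of dummy 16 kHz audio
    lengths = torch.tensor([16000, 12800])   # valid length of each batch element
    with torch.no_grad():
        emissions, out_lengths = model(waveforms, lengths)
    # The encoder output is passed through `aux`, so the last dimension equals num_out.
    print(emissions.shape, out_lengths)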
mthrok authored Sep 22, 2021
1 parent e9cab8f commit 40f2a08
Showing 5 changed files with 28 additions and 25 deletions.
@@ -118,7 +118,7 @@ def _test_import_finetune(self, original, imported, config):
# Readout
x = torch.randn(3, 10, config["hidden_size"])
ref = original.lm_head(x)
- hyp = imported.encoder.readout(x)
+ hyp = imported.aux(x)
self.assertEqual(ref, hyp)
# The whole model without mask
x = torch.randn(3, 1024)
@@ -195,8 +195,8 @@ def _test_recreate(self, imported, reloaded, config):
self.assertEqual(ref, hyp)
# Readout
x = torch.randn(3, 10, config["hidden_size"])
- ref = imported.encoder.readout(x)
- hyp = reloaded.encoder.readout(x)
+ ref = imported.aux(x)
+ hyp = reloaded.aux(x)
self.assertEqual(ref, hyp)
# The whole model
x = torch.randn(3, 1024)
@@ -208,7 +208,7 @@ def _test_recreate(self, imported, reloaded, config):
def test_recreate_pretrain(self, config, factory_func):
"""Imported models can be recreated via a factory function without Hugging Face transformers."""
imported = import_huggingface_model(self._get_model(config)).eval()
- reloaded = factory_func(num_out=imported.encoder.readout.out_features)
+ reloaded = factory_func(num_out=imported.aux.out_features)
reloaded.load_state_dict(imported.state_dict())
reloaded.eval()
self._test_recreate(imported, reloaded, config)
@@ -217,7 +217,7 @@ def test_recreate_pretrain(self, config, factory_func):
def test_recreate_finetune(self, config, factory_func):
"""Imported models can be recreated via a factory function without Hugging Face transformers."""
imported = import_huggingface_model(self._get_model(config)).eval()
- reloaded = factory_func(num_out=imported.encoder.readout.out_features)
+ reloaded = factory_func(num_out=imported.aux.out_features)
reloaded.load_state_dict(imported.state_dict())
reloaded.eval()
self._test_recreate(imported, reloaded, config)
10 changes: 1 addition & 9 deletions torchaudio/models/wav2vec2/components.py
@@ -426,12 +426,10 @@ def __init__(
self,
feature_projection: Module,
transformer: Module,
- readout: Module,
):
super().__init__()
self.feature_projection = feature_projection
self.transformer = transformer
- self.readout = readout

def _preprocess(
self,
@@ -458,7 +456,6 @@ def forward(
) -> Tensor:
x, mask = self._preprocess(features, lengths)
x = self.transformer(x, attention_mask=mask)
- x = self.readout(x)
return x

def extract_features(
@@ -561,7 +558,6 @@ def _get_encoder(
dropout: float,
layer_norm_first: bool,
layer_drop: float,
- num_out: int,
) -> Encoder:
"""
Args:
@@ -720,8 +716,4 @@ def _get_encoder(
layer_norm_first=not layer_norm_first,
layer_drop=layer_drop,
)
- readout = nn.Linear(
- in_features=embed_dim,
- out_features=num_out,
- )
- return Encoder(feature_projection, transformer, readout)
+ return Encoder(feature_projection, transformer)
25 changes: 18 additions & 7 deletions torchaudio/models/wav2vec2/model.py
@@ -20,15 +20,20 @@ class Wav2Vec2Model(Module):
encoder (torch.nn.Module):
Encoder that converts the audio features into the sequence of probability
distribution (in negative log-likelihood) over labels.
+ aux (torch.nn.Module or None, optional):
+ Auxiliary module. If provided, the output from encoder is passed to this module.
"""
def __init__(
self,
feature_extractor: Module,
encoder: Module,
+ aux: Optional[Module] = None,
):
super().__init__()
self.feature_extractor = feature_extractor
self.encoder = encoder
+ self.aux = aux

@torch.jit.export
def extract_features(
@@ -89,7 +94,10 @@ def forward(
Shape: ``(batch, )``.
"""
x, lengths = self.feature_extractor(waveforms, lengths)
- return self.encoder(x, lengths), lengths
+ x = self.encoder(x, lengths)
+ if self.aux is not None:
+     x = self.aux(x)
+ return x, lengths


def _get_model(
@@ -108,7 +116,7 @@ def _get_model(
encoder_dropout: float,
encoder_layer_norm_first: bool,
encoder_layer_drop: float,
- encoder_num_out: int,
+ aux_num_out: int,
) -> Wav2Vec2Model:
if extractor_conv_layer_config is None:
extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
@@ -129,9 +137,12 @@ def _get_model(
dropout=encoder_dropout,
layer_norm_first=encoder_layer_norm_first,
layer_drop=encoder_layer_drop,
- num_out=encoder_num_out,
)
- return Wav2Vec2Model(feature_extractor, encoder)
+ aux = torch.nn.Linear(
+ in_features=encoder_embed_dim,
+ out_features=aux_num_out,
+ )
+ return Wav2Vec2Model(feature_extractor, encoder, aux)


def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
@@ -172,7 +183,7 @@ def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
encoder_dropout=0.1,
encoder_layer_norm_first=False,
encoder_layer_drop=0.1,
- encoder_num_out=num_out,
+ aux_num_out=num_out,
)


@@ -214,7 +225,7 @@ def wav2vec2_large(num_out: int) -> Wav2Vec2Model:
encoder_dropout=0.1,
encoder_layer_norm_first=False,
encoder_layer_drop=0.1,
- encoder_num_out=num_out,
+ aux_num_out=num_out,
)


@@ -256,5 +267,5 @@ def wav2vec2_large_lv60k(num_out: int) -> Wav2Vec2Model:
encoder_dropout=0.0,
encoder_layer_norm_first=True,
encoder_layer_drop=0.1,
- encoder_num_out=num_out,
+ aux_num_out=num_out,
)
4 changes: 2 additions & 2 deletions torchaudio/models/wav2vec2/utils/import_fairseq.py
@@ -46,7 +46,7 @@ def _parse_config(w2v_model, num_out):
'encoder_dropout': encoder.layers[0].dropout3.p,
'encoder_layer_norm_first': encoder.layer_norm_first,
'encoder_layer_drop': encoder.layerdrop,
- 'encoder_num_out': num_out,
+ 'aux_num_out': num_out,
}
return config

@@ -110,7 +110,7 @@ def _map_key(key):
match = re.match(r"proj\.(weight|bias)", key)
# Encoder - Readout layer
if match:
return f"encoder.readout.{match.group(1)}"
return f"aux.{match.group(1)}"
raise ValueError(f'Unexpected key: {key_}')


4 changes: 2 additions & 2 deletions torchaudio/models/wav2vec2/utils/import_huggingface.py
@@ -26,7 +26,7 @@ def _get_config(cfg):
'encoder_dropout': cfg.hidden_dropout,
'encoder_layer_norm_first': cfg.do_stable_layer_norm,
'encoder_layer_drop': cfg.layerdrop,
- 'encoder_num_out': cfg.vocab_size,
+ 'aux_num_out': cfg.vocab_size,
}
return config

@@ -42,7 +42,7 @@ def _build(config, original):
imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
if original.__class__.__name__ == 'Wav2Vec2ForCTC':
- imported.encoder.readout.load_state_dict(original.lm_head.state_dict())
+ imported.aux.load_state_dict(original.lm_head.state_dict())
return imported


