Fix HF model integration (pytorch#1781)
* Fix HF model integration

Previously, when testing wav2vec2 models from HF transformers, all the models were
instantiated as the `Wav2Vec2ForCTC` class, while some of them were supposed to be
`Wav2Vec2Model`.

Fixing this revealed that the model importer could not correctly handle `Wav2Vec2Model`
imports.

This PR fixes both issues.
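For context, `Wav2Vec2ForCTC` wraps the bare `Wav2Vec2Model` encoder and adds a CTC head (`lm_head`), which is why the two classes have to be handled differently below. A minimal sketch of the relationship, assuming the `transformers` package is installed:

```python
from transformers.models.wav2vec2 import (
    Wav2Vec2Config,
    Wav2Vec2Model,
    Wav2Vec2ForCTC,
)

config = Wav2Vec2Config()             # default (base-sized) configuration
encoder_only = Wav2Vec2Model(config)  # pretrained checkpoints: encoder only
with_ctc = Wav2Vec2ForCTC(config)     # fine-tuned checkpoints: encoder + CTC head

# Wav2Vec2ForCTC keeps the encoder under the `wav2vec2` attribute and has
# an extra `lm_head`; the bare Wav2Vec2Model has neither.
assert isinstance(with_ctc.wav2vec2, Wav2Vec2Model)
assert hasattr(with_ctc, 'lm_head') and not hasattr(encoder_only, 'lm_head')
```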
mthrok authored Sep 22, 2021
1 parent 1b4b82e commit e9cab8f
Showing 2 changed files with 59 additions and 26 deletions.
@@ -39,14 +39,14 @@ def _name_func(testcase_func, i, param):
 HF_LARGE_XLSR_DE = _load_config('facebook', 'wav2vec2-large-xlsr-53-german')
 
 # Config and corresponding factory functions
-HF_CONFIGS = parameterized.expand([
-    # pretrained
+PRETRAIN_CONFIGS = parameterized.expand([
     (HF_BASE, wav2vec2_base),
     (HF_LARGE, wav2vec2_large),
     (HF_LARGE_LV60, wav2vec2_large_lv60k),
     (HF_LARGE_XLSR_53, wav2vec2_large_lv60k),
     (HF_BASE_10K_VOXPOPULI, wav2vec2_base),
-    # finetuned
+], name_func=_name_func)
+FINETUNE_CONFIGS = parameterized.expand([
     (HF_BASE_960H, wav2vec2_base),
     (HF_LARGE_960H, wav2vec2_large),
     (HF_LARGE_LV60_960H, wav2vec2_large_lv60k),
@@ -72,34 +72,34 @@ def _get_model(self, config):
         # the actual tests are started.
         from transformers.models.wav2vec2 import (
             Wav2Vec2Config,
+            Wav2Vec2Model,
             Wav2Vec2ForCTC,
         )
-        return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
-
-    @HF_CONFIGS
-    def test_import(self, config, _):
-        """wav2vec2 models from HF transformers can be imported and yields the same results"""
-        original = self._get_model(config).eval()
-        imported = import_huggingface_model(original).eval()
+        if config['architectures'] == ['Wav2Vec2Model']:
+            return Wav2Vec2Model(Wav2Vec2Config(**config))
+        if config['architectures'] == ['Wav2Vec2ForCTC']:
+            return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
+        raise ValueError(f'Unexpected arch: {config["architectures"]}')
+
+    def _test_import_pretrain(self, original, imported, config):
         torch.manual_seed(0)
         # FeatureExtractor
         x = torch.randn(3, 1024)
-        ref = original.wav2vec2.feature_extractor(x).transpose(1, 2)
+        ref = original.feature_extractor(x).transpose(1, 2)
         hyp, _ = imported.feature_extractor(x, None)
         self.assertEqual(ref, hyp)
         # Feature projection
         x = torch.randn(3, 10, config['conv_dim'][-1])
-        ref = original.wav2vec2.feature_projection(x)[0]
+        ref = original.feature_projection(x)[0]
         hyp = imported.encoder.feature_projection(x)
         self.assertEqual(ref, hyp)
         # Convolutional Positional Encoder
         x = torch.randn(3, 256, config['hidden_size'])
-        ref = original.wav2vec2.encoder.pos_conv_embed(x)
+        ref = original.encoder.pos_conv_embed(x)
         hyp = imported.encoder.transformer.pos_conv_embed(x)
         self.assertEqual(ref, hyp)
         # Encoder Transformer Layer
-        for original_, imported_ in zip(original.wav2vec2.encoder.layers, imported.encoder.transformer.layers):
+        for original_, imported_ in zip(original.encoder.layers, imported.encoder.transformer.layers):
             b, l, e = 16, 3, config["hidden_size"]
             x = torch.randn(b, l, e)
             mask = torch.randn(b, 1, l, l)
@@ -110,9 +110,11 @@ def test_import(self, config, _):
         # The whole Encoder Transformer
         b, l, e = 16, 3, config["hidden_size"]
         x = torch.randn(b, l, e)
-        ref = original.wav2vec2.encoder(x).last_hidden_state
+        ref = original.encoder(x).last_hidden_state
         hyp = imported.encoder.transformer(x)
         self.assertEqual(ref, hyp)
+
+    def _test_import_finetune(self, original, imported, config):
         # Readout
         x = torch.randn(3, 10, config["hidden_size"])
         ref = original.lm_head(x)
@@ -142,15 +144,22 @@ def test_import(self, config, _):
             for i, l in enumerate(output_lengths):
                 self.assertEqual(ref[i, :l, ...], hyp[i, :l, ...])
 
-    @HF_CONFIGS
-    def test_recreate(self, config, factory_func):
-        """Imported models can be recreated via a factory function without Hugging Face transformers."""
-        imported = import_huggingface_model(self._get_model(config)).eval()
+    @PRETRAIN_CONFIGS
+    def test_import_pretrain(self, config, _):
+        """wav2vec2 models from HF transformers can be imported and yields the same results"""
+        original = self._get_model(config).eval()
+        imported = import_huggingface_model(original).eval()
+        self._test_import_pretrain(original, imported, config)
 
-        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
-        reloaded.load_state_dict(imported.state_dict())
-        reloaded.eval()
+    @FINETUNE_CONFIGS
+    def test_import_finetune(self, config, _):
+        """wav2vec2 models from HF transformers can be imported and yields the same results"""
+        original = self._get_model(config).eval()
+        imported = import_huggingface_model(original).eval()
+        self._test_import_pretrain(original.wav2vec2, imported, config)
+        self._test_import_finetune(original, imported, config)
 
+    def _test_recreate(self, imported, reloaded, config):
         torch.manual_seed(0)
         # FeatureExtractor
         x = torch.randn(3, 1024)
@@ -194,3 +203,21 @@ def test_recreate(self, config, factory_func):
         ref, _ = imported(x)
         hyp, _ = reloaded(x)
         self.assertEqual(ref, hyp)
+
+    @PRETRAIN_CONFIGS
+    def test_recreate_pretrain(self, config, factory_func):
+        """Imported models can be recreated via a factory function without Hugging Face transformers."""
+        imported = import_huggingface_model(self._get_model(config)).eval()
+        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
+        reloaded.load_state_dict(imported.state_dict())
+        reloaded.eval()
+        self._test_recreate(imported, reloaded, config)
+
+    @FINETUNE_CONFIGS
+    def test_recreate_finetune(self, config, factory_func):
+        """Imported models can be recreated via a factory function without Hugging Face transformers."""
+        imported = import_huggingface_model(self._get_model(config)).eval()
+        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
+        reloaded.load_state_dict(imported.state_dict())
+        reloaded.eval()
+        self._test_recreate(imported, reloaded, config)
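Splitting the old `HF_CONFIGS` into `PRETRAIN_CONFIGS` and `FINETUNE_CONFIGS` means each test only runs against checkpoints whose architecture it can exercise. As a rough, self-contained illustration of how a stored `parameterized.expand` decorator fans one test method out into one case per config (the configs and names here are made up for the example):

```python
import unittest
from parameterized import parameterized


def _name_func(testcase_func, i, param):
    # `param[0]` is the tuple of positional args for this case.
    return f'{testcase_func.__name__}_{i}_{param[0][1]}'


DEMO_CONFIGS = parameterized.expand([
    ({'hidden_size': 768}, 'base'),
    ({'hidden_size': 1024}, 'large'),
], name_func=_name_func)


class DemoTest(unittest.TestCase):
    # Generates test_hidden_size_0_base and test_hidden_size_1_large.
    @DEMO_CONFIGS
    def test_hidden_size(self, config, _name):
        self.assertIn('hidden_size', config)
```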
14 changes: 10 additions & 4 deletions torchaudio/models/wav2vec2/utils/import_huggingface.py
@@ -32,11 +32,17 @@ def _get_config(cfg):


 def _build(config, original):
+    if original.__class__.__name__ == 'Wav2Vec2ForCTC':
+        wav2vec2 = original.wav2vec2
+    else:
+        wav2vec2 = original
+
     imported = _get_model(**config)
-    imported.feature_extractor.load_state_dict(original.wav2vec2.feature_extractor.state_dict())
-    imported.encoder.feature_projection.load_state_dict(original.wav2vec2.feature_projection.state_dict())
-    imported.encoder.transformer.load_state_dict(original.wav2vec2.encoder.state_dict())
-    imported.encoder.readout.load_state_dict(original.lm_head.state_dict())
+    imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict())
+    imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
+    imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
+    if original.__class__.__name__ == 'Wav2Vec2ForCTC':
+        imported.encoder.readout.load_state_dict(original.lm_head.state_dict())
     return imported
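With this change, `import_huggingface_model` accepts both model classes. A hedged usage sketch (the checkpoint names are real HF hub identifiers but require network access to download; the returned tuple follows torchaudio's `Wav2Vec2Model.forward` signature at the time of this commit):

```python
import torch
from torchaudio.models.wav2vec2.utils import import_huggingface_model
from transformers import Wav2Vec2ForCTC, Wav2Vec2Model

# Fine-tuned checkpoint: wraps the encoder and has an `lm_head`,
# so the readout weights are copied into the imported model.
original = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h')
imported = import_huggingface_model(original).eval()

waveform = torch.randn(1, 16000)  # one second of 16 kHz audio
logits, _ = imported(waveform)

# Pretrained, encoder-only checkpoint: no `lm_head`, so after this fix
# the readout layer of the imported model keeps its initial values.
original = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base')
imported = import_huggingface_model(original).eval()
```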


