Skip to content

Commit

Permalink
Add support for Blenderbot and BlenderbotSmall (huggingface#292)
Browse files Browse the repository at this point in the history
* Add support for `Blenderbot` models

Closes huggingface#37
References huggingface#29

* Add support for `BlenderbotTokenizer`

* Add blenderbot to supported models

* Add support for `BlenderbotSmallTokenizer`

* Add custom tests for blenderbot-small

* Add support for `BlenderbotSmall` models

* Update list of supported models

* Improve `addPastKeyValues` function

* Allow skipping of adding encoder past key values
  • Loading branch information
xenova authored Sep 19, 2023
1 parent c453e6b commit c367f9d
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 49 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
1. **[BEiT](https://huggingface.co/docs/transformers/model_doc/beit)** (from Microsoft) released with the paper [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) by Hangbo Bao, Li Dong, Furu Wei.
1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
1. **[Blenderbot](https://huggingface.co/docs/transformers/model_doc/blenderbot)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
1. **[BlenderbotSmall](https://huggingface.co/docs/transformers/model_doc/blenderbot-small)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
1. **[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
1. **[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot.
1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever.
Expand Down
2 changes: 2 additions & 0 deletions docs/snippets/6_supported-models.snippet
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
1. **[BART](https://huggingface.co/docs/transformers/model_doc/bart)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer.
1. **[BEiT](https://huggingface.co/docs/transformers/model_doc/beit)** (from Microsoft) released with the paper [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) by Hangbo Bao, Li Dong, Furu Wei.
1. **[BERT](https://huggingface.co/docs/transformers/model_doc/bert)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
1. **[Blenderbot](https://huggingface.co/docs/transformers/model_doc/blenderbot)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
1. **[BlenderbotSmall](https://huggingface.co/docs/transformers/model_doc/blenderbot-small)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
1. **[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
1. **[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot.
1. **[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)** (from OpenAI) released with the paper [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever.
Expand Down
20 changes: 10 additions & 10 deletions scripts/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,16 @@
'bert-base-chinese',
'emilyalsentzer/Bio_ClinicalBERT',
],
# 'blenderbot': [
# # Text2text generation (TODO add conversational)
# 'facebook/blenderbot-400M-distill',
# 'facebook/blenderbot-1B-distill',
# ],
# 'blenderbot-small': [
# # Text2text generation (TODO add conversational)
# 'facebook/blenderbot-90M', # DEPRECATED
# 'facebook/blenderbot_small-90M',
# ],
'blenderbot': [
# Text2text generation (TODO add conversational)
'facebook/blenderbot-400M-distill',
# 'facebook/blenderbot-1B-distill',
],
'blenderbot-small': [
# Text2text generation (TODO add conversational)
# 'facebook/blenderbot-90M', # DEPRECATED
'facebook/blenderbot_small-90M',
],
'bloom': [
# Text generation
'bigscience/bloom-560m',
Expand Down
152 changes: 113 additions & 39 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,6 @@ function boolTensor(value) {
* @private
*/
async function seq2seqForward(self, model_inputs) {
const add_decoder_pkv = self.add_decoder_pkv ?? true;

let { encoder_outputs, past_key_values } = model_inputs;

Expand All @@ -327,7 +326,7 @@ async function seq2seqForward(self, model_inputs) {
if (self.decoder_merged_session.inputNames.includes('encoder_attention_mask')) {
decoderFeeds.encoder_attention_mask = model_inputs.attention_mask
}
self.addPastKeyValues(decoderFeeds, past_key_values, add_decoder_pkv);
self.addPastKeyValues(decoderFeeds, past_key_values);

const decoderResults = await sessionRun(self.decoder_merged_session, decoderFeeds);
let logits = decoderResults.logits;
Expand Down Expand Up @@ -1199,57 +1198,51 @@ export class PreTrainedModel extends Callable {
*
* @param {Object} decoderFeeds The decoder feeds object to add past key values to.
* @param {Object} pastKeyValues An object containing past key values.
* @param {boolean} [hasDecoder=false] Whether the model has a decoder.
*/
addPastKeyValues(decoderFeeds, pastKeyValues, hasDecoder = false) {
addPastKeyValues(decoderFeeds, pastKeyValues) {
if (pastKeyValues) {
Object.assign(decoderFeeds, pastKeyValues)
} else {
// TODO support batches (i.e., batch_size > 1)
if (hasDecoder) {
// @ts-ignore
if (this.config.is_encoder_decoder && (this.add_encoder_pkv ?? true)) {
// @ts-ignore
let encoder_dims = [1, this.num_encoder_heads, 0, this.encoder_dim_kv];
// @ts-ignore
for (let i = 0; i < this.num_encoder_layers; ++i) {
decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor('float32', [], encoder_dims)
decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor('float32', [], encoder_dims)
}

// @ts-ignore
let decoder_dims = [1, this.num_decoder_heads, 0, this.decoder_dim_kv];
// @ts-ignore
for (let i = 0; i < this.num_decoder_layers; ++i) {
decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor('float32', [], encoder_dims)
decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor('float32', [], encoder_dims)
decoderFeeds[`past_key_values.${i}.decoder.key`] = new Tensor('float32', [], decoder_dims)
decoderFeeds[`past_key_values.${i}.decoder.value`] = new Tensor('float32', [], decoder_dims)
}
} else if (this.config.multi_query) { // e.g., for `gpt_bigcode`
// @ts-ignore
let dims = [1, 0, 2 * this.dim_kv]
// @ts-ignore
for (let i = 0; i < this.num_layers; ++i) {
decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims)
}
} else if (this.config.model_type === 'bloom') {
// NOTE: Custom implementation for Bloom

} else {
if (this.config.multi_query) {
// @ts-ignore
let dims = [1, 0, 2 * this.dim_kv]
// @ts-ignore
for (let i = 0; i < this.num_layers; ++i) {
decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims)
}
} else if (this.config.model_type === 'bloom') {
// Custom implementation for Bloom
// @ts-ignore
let keyDims = [1 * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length]
// @ts-ignore
let valueDims = [1 * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64]
// @ts-ignore
for (let i = 0; i < this.num_layers; ++i) {
decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims)
decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims)
}
} else {
// @ts-ignore
let dims = [1, this.num_heads, 0, this.dim_kv]
// @ts-ignore
for (let i = 0; i < this.num_layers; ++i) {
decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims)
decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims)
}
// @ts-ignore
let keyDims = [1 * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length]
// @ts-ignore
let valueDims = [1 * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64]
// @ts-ignore
for (let i = 0; i < this.num_layers; ++i) {
decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims)
decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims)
}
} else { // Decoder-only
// @ts-ignore
let dims = [1, this.num_heads, 0, this.dim_kv]
// @ts-ignore
for (let i = 0; i < this.num_layers; ++i) {
decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims)
decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims)
}
}
}
Expand Down Expand Up @@ -2033,6 +2026,83 @@ export class MBartForSequenceClassification extends MBartPreTrainedModel {

//////////////////////////////////////////////////


//////////////////////////////////////////////////
// Blenderbot models
export class BlenderbotPreTrainedModel extends PreTrainedModel { };

/**
* The bare Blenderbot Model outputting raw hidden-states without any specific head on top.
*/
export class BlenderbotModel extends BlenderbotPreTrainedModel { }

/**
* The Blenderbot Model with a language modeling head. Can be used for summarization.
*/
export class BlenderbotForConditionalGeneration extends BlenderbotPreTrainedModel {

/**
* Creates a new instance of the `BlenderbotForConditionalGeneration` class.
* @param {any} config The model configuration.
* @param {any} session The ONNX session containing the encoder weights.
* @param {any} decoder_merged_session The ONNX session containing the merged decoder weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session, decoder_merged_session, generation_config) {
super(config, session);
this.decoder_merged_session = decoder_merged_session;
this.generation_config = generation_config;

this.num_decoder_layers = this.config.decoder_layers;
this.num_decoder_heads = this.config.decoder_attention_heads;
this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;

this.num_encoder_layers = this.config.encoder_layers;
this.num_encoder_heads = this.config.encoder_attention_heads;
this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
}
}
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// Blenderbot models
export class BlenderbotSmallPreTrainedModel extends PreTrainedModel { };

/**
* The bare BlenderbotSmall Model outputting raw hidden-states without any specific head on top.
*/
export class BlenderbotSmallModel extends BlenderbotSmallPreTrainedModel { }

/**
* The BlenderbotSmall Model with a language modeling head. Can be used for summarization.
*/
export class BlenderbotSmallForConditionalGeneration extends BlenderbotSmallPreTrainedModel {

/**
* Creates a new instance of the `BlenderbotForConditionalGeneration` class.
* @param {any} config The model configuration.
* @param {any} session The ONNX session containing the encoder weights.
* @param {any} decoder_merged_session The ONNX session containing the merged decoder weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session, decoder_merged_session, generation_config) {
super(config, session);
this.decoder_merged_session = decoder_merged_session;
this.generation_config = generation_config;

this.num_decoder_layers = this.config.decoder_layers;
this.num_decoder_heads = this.config.decoder_attention_heads;
this.decoder_dim_kv = this.config.d_model / this.num_decoder_heads;

this.num_encoder_layers = this.config.encoder_layers;
this.num_encoder_heads = this.config.encoder_attention_heads;
this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
}
}
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// Roberta models
export class RobertaPreTrainedModel extends PreTrainedModel { }
Expand Down Expand Up @@ -2458,7 +2528,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
*/
export class VisionEncoderDecoderModel extends PreTrainedModel {
main_input_name = 'pixel_values';
add_decoder_pkv = false;
add_encoder_pkv = false;

/**
* Creates a new instance of the `VisionEncoderDecoderModel` class.
Expand Down Expand Up @@ -3422,6 +3492,8 @@ const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
['marian', ['MarianModel', MarianModel]],
['whisper', ['WhisperModel', WhisperModel]],
['m2m_100', ['M2M100Model', M2M100Model]],
['blenderbot', ['BlenderbotModel', BlenderbotModel]],
['blenderbot-small', ['BlenderbotSmallModel', BlenderbotSmallModel]],
]);


Expand Down Expand Up @@ -3475,6 +3547,8 @@ const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([
['whisper', ['WhisperForConditionalGeneration', WhisperForConditionalGeneration]],
['marian', ['MarianMTModel', MarianMTModel]],
['m2m_100', ['M2M100ForConditionalGeneration', M2M100ForConditionalGeneration]],
['blenderbot', ['BlenderbotForConditionalGeneration', BlenderbotForConditionalGeneration]],
['blenderbot-small', ['BlenderbotSmallForConditionalGeneration', BlenderbotSmallForConditionalGeneration]],
]);

const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([
Expand Down
Loading

0 comments on commit c367f9d

Please sign in to comment.