Skip to content

Commit

Permalink
Add Blenderbot (#177)
Browse files Browse the repository at this point in the history
Co-authored-by: Mateusz Sluszniak <[email protected]>
  • Loading branch information
jonatanklosko and msluszniak authored Mar 28, 2023
1 parent faa5474 commit d94737b
Show file tree
Hide file tree
Showing 7 changed files with 712 additions and 5 deletions.
4 changes: 4 additions & 0 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ defmodule Bumblebee do
"BertForTokenClassification" => {Bumblebee.Text.Bert, :for_token_classification},
"BertLMHeadModel" => {Bumblebee.Text.Bert, :for_causal_language_modeling},
"BertModel" => {Bumblebee.Text.Bert, :base},
"BlenderbotForConditionalGeneration" =>
{Bumblebee.Text.Blenderbot, :for_conditional_generation},
"BlenderbotModel" => {Bumblebee.Text.Blenderbot, :base},
"BlipForConditionalGeneration" => {Bumblebee.Multimodal.Blip, :for_conditional_generation},
# These models are just RoBERTa models, but the config will list them as CamemBERT
"CamembertModel" => {Bumblebee.Text.Roberta, :base},
Expand Down Expand Up @@ -155,6 +158,7 @@ defmodule Bumblebee do
"albert" => Bumblebee.Text.AlbertTokenizer,
"bart" => Bumblebee.Text.BartTokenizer,
"bert" => Bumblebee.Text.BertTokenizer,
"blenderbot" => Bumblebee.Text.BlenderbotTokenizer,
"blip" => Bumblebee.Text.BertTokenizer,
"distilbert" => Bumblebee.Text.DistilbertTokenizer,
"camembert" => Bumblebee.Text.CamembertTokenizer,
Expand Down
8 changes: 7 additions & 1 deletion lib/bumblebee/text/bart.ex
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,10 @@ defmodule Bumblebee.Text.Bart do
inputs = encoder_decoder_inputs(spec)
outputs = core(inputs, spec)

logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head")
logits =
outputs.hidden_state
|> language_modeling_head(spec, name: "language_modeling_head")
|> Axon.bias(name: "language_modeling_head.logits_bias", bias_initializer: :zeros)

Layers.output(%{
logits: logits,
Expand Down Expand Up @@ -685,6 +688,9 @@ defmodule Bumblebee.Text.Bart do
"decoder.blocks.{n}.ffn.output" => "model.decoder.layers.{n}.fc2",
"decoder.blocks.{n}.output_norm" => "model.decoder.layers.{n}.final_layer_norm",
"language_modeling_head.output" => "model.shared",
"language_modeling_head.logits_bias" => %{
"bias" => {[{"model", "final_logits_bias"}], fn [value] -> Nx.squeeze(value) end}
},
"sequence_classification_head.dense" => "classification_head.dense",
"sequence_classification_head.output" => "classification_head.out_proj",
"question_answering_head.output" => "qa_outputs"
Expand Down
Loading

0 comments on commit d94737b

Please sign in to comment.