diff --git a/lib/bumblebee/text/conversation.ex b/lib/bumblebee/text/conversation.ex index c5ce4d64..d649e054 100644 --- a/lib/bumblebee/text/conversation.ex +++ b/lib/bumblebee/text/conversation.ex @@ -41,8 +41,6 @@ defmodule Bumblebee.Text.Conversation do batch_size = compile[:batch_size] sequence_length = compile[:sequence_length] - encoder_decoder? = encoder_decoder?(model) - generate_fun = Text.Generation.build_generate(model, spec, generation_config, Keyword.take(opts, [:seed])) @@ -66,18 +64,7 @@ defmodule Bumblebee.Text.Conversation do fn inputs -> inputs = Shared.maybe_pad(inputs, batch_size) - sequences = generate_fun.(params, inputs) - inputs = Nx.Defn.jit_apply(&Function.identity/1, [inputs]) - - start_idx = - if encoder_decoder? do - 1 - else - Nx.axis_size(inputs["input_ids"], 1) - end - - sequences[[.., start_idx..-1//1]] - |> Shared.serving_post_computation() + generate_fun.(params, inputs) |> Shared.serving_post_computation() end end, defn_options @@ -125,9 +112,4 @@ defmodule Bumblebee.Text.Conversation do defp validate_input(input) do {:error, "expected input to be a map with :text and :history, got: #{inspect(input)}"} end - - defp encoder_decoder?(model) do - inputs = Axon.get_inputs(model) - Map.has_key?(inputs, "input_ids") and Map.has_key?(inputs, "decoder_input_ids") - end end diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index fff77bba..4c71dc32 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -329,16 +329,16 @@ defmodule Bumblebee.Text.Generation do end end - deftransformp generate_impl( - inputs, - predict_fun, - params, - logits_processor_fun, - prepare_inputs_fun, - update_inputs_fun, - traverse_cache_fun, - opts \\ [] - ) do + defnp generate_impl( + inputs, + predict_fun, + params, + logits_processor_fun, + prepare_inputs_fun, + update_inputs_fun, + traverse_cache_fun, + opts \\ [] + ) do {decoder_inputs, decoder_input_ids, max_length} = prepare_inputs_fun.(inputs, params) length = Nx.axis_size(decoder_input_ids, 1) @@ -352,45 +352,54 @@ defmodule Bumblebee.Text.Generation do strategy = opts[:strategy] seed = opts[:seed] - case strategy.type do - :greedy_search -> - greedy( - decoder_inputs, - decoder_input_ids, - predict_fun, - params, - logits_processor_fun, - update_inputs_fun, - [max_length: max_length] ++ opts - ) + sequences = + case strategy.type do + :greedy_search -> + greedy( + decoder_inputs, + decoder_input_ids, + predict_fun, + params, + logits_processor_fun, + update_inputs_fun, + merge_options([max_length: max_length], opts) + ) - :contrastive_search -> - contrastive( - decoder_inputs, - decoder_input_ids, - predict_fun, - params, - logits_processor_fun, - update_inputs_fun, - traverse_cache_fun, - [max_length: max_length, top_k: strategy.top_k, penalty_alpha: strategy.alpha] ++ opts - ) + :contrastive_search -> + contrastive( + decoder_inputs, + decoder_input_ids, + predict_fun, + params, + logits_processor_fun, + update_inputs_fun, + traverse_cache_fun, + merge_options( + [max_length: max_length, top_k: strategy.top_k, penalty_alpha: strategy.alpha], + opts + ) + ) - :multinomial_sampling -> - prng_key = Nx.Random.key(seed) + :multinomial_sampling -> + prng_key = Nx.Random.key(seed) - sampling( - decoder_inputs, - decoder_input_ids, - predict_fun, - params, - logits_processor_fun, - update_inputs_fun, - [max_length: max_length, prng_key: prng_key] ++ opts - ) - end + sampling( + decoder_inputs, + decoder_input_ids, + predict_fun, + params, + logits_processor_fun, + update_inputs_fun, + merge_options([max_length: max_length, prng_key: prng_key], opts) + ) + end + + # Output only the newly generated tokens + sequences[[.., length..-1//1]] end + deftransformp merge_options(left, right), do: left ++ right + # Greedy search defnp greedy( @@ -457,10 +466,6 @@ defmodule Bumblebee.Text.Generation do defnp init_sequences(decoder_input_ids, max_length, pad_token_id) do {batch_size, length} = Nx.shape(decoder_input_ids) - if length > max_length do - raise ArgumentError, "expected the input to be at most #{max_length} tokens, got: #{length}" - end - sequences = Nx.broadcast(pad_token_id, {batch_size, max_length}) sequences = Nx.put_slice(sequences, [0, 0], decoder_input_ids) diff --git a/test/bumblebee/text/bart_test.exs b/test/bumblebee/text/bart_test.exs index b492ae6d..73d1d9de 100644 --- a/test/bumblebee/text/bart_test.exs +++ b/test/bumblebee/text/bart_test.exs @@ -157,6 +157,6 @@ defmodule Bumblebee.Text.BartTest do generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config) token_ids = generate.(params, inputs) - assert_equal(token_ids, Nx.tensor([[2, 988, 988, 988]])) + assert_equal(token_ids, Nx.tensor([[988, 988, 988]])) end end diff --git a/test/bumblebee/text/blenderbot_test.exs b/test/bumblebee/text/blenderbot_test.exs index 860da566..ddc7e6e2 100644 --- a/test/bumblebee/text/blenderbot_test.exs +++ b/test/bumblebee/text/blenderbot_test.exs @@ -82,6 +82,6 @@ defmodule Bumblebee.Text.BlenderbotTest do generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config) token_ids = generate.(params, inputs) - assert_equal(token_ids, Nx.tensor([[1, 382, 382, 382]])) + assert_equal(token_ids, Nx.tensor([[382, 382, 382]])) end end diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 23692e95..0d3d87c5 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -38,9 +38,9 @@ defmodule Bumblebee.Text.GenerationTest do serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config) # Without :no_repeat_ngram_length we get - # %{results: [%{text: "I was going to say, 'Well, I'm going to say,"}]} + # %{results: [%{text: " to say, 'Well, I'm going to say,"}]} - assert %{results: [%{text: "I was going to say, 'Well, I'm going back to the"}]} = + assert %{results: [%{text: " to say, 'Well, I'm going back to the"}]} = Nx.Serving.run(serving, "I was going") end @@ -60,11 +60,8 @@ defmodule Bumblebee.Text.GenerationTest do # Note that this is just a snapshot test, we do not use any # reference value, because of PRNG difference - assert %{ - results: [ - %{text: "I was going to fall asleep.\"\n\nThis is not Wallace's fifth"} - ] - } = Nx.Serving.run(serving, "I was going") + assert %{results: [%{text: " to fall asleep.\"\n\nThis is not Wallace's fifth"}]} = + Nx.Serving.run(serving, "I was going") end test "contrastive search" do @@ -80,7 +77,7 @@ defmodule Bumblebee.Text.GenerationTest do serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config) - assert %{results: [%{text: "I was going to say, 'Well, I don't know what you"}]} = + assert %{results: [%{text: " to say, 'Well, I don't know what you"}]} = Nx.Serving.run(serving, "I was going") end diff --git a/test/bumblebee/text/mbart_test.exs b/test/bumblebee/text/mbart_test.exs index a0e5123b..4aa0e825 100644 --- a/test/bumblebee/text/mbart_test.exs +++ b/test/bumblebee/text/mbart_test.exs @@ -155,6 +155,6 @@ defmodule Bumblebee.Text.MbartTest do generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config) token_ids = generate.(params, inputs) - assert_equal(token_ids, Nx.tensor([[0, 230_521, 20386, 20386]])) + assert_equal(token_ids, Nx.tensor([[230_521, 20386, 20386]])) end end diff --git a/test/bumblebee/text/t5_test.exs b/test/bumblebee/text/t5_test.exs index 98775ce0..5d4a2840 100644 --- a/test/bumblebee/text/t5_test.exs +++ b/test/bumblebee/text/t5_test.exs @@ -167,7 +167,7 @@ defmodule Bumblebee.Text.T5Test do generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config) token_ids = generate.(params, inputs) - assert_equal(token_ids, Nx.tensor([[0, 0, 0, 0]])) + assert_equal(token_ids, Nx.tensor([[0, 0, 0]])) end test "generation with :for_conditional_generation without tied embeddings" do @@ -195,6 +195,6 @@ defmodule Bumblebee.Text.T5Test do generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config) token_ids = generate.(params, inputs) - assert_equal(token_ids, Nx.tensor([[0, 6161, 29516, 9788]])) + assert_equal(token_ids, Nx.tensor([[6161, 29516, 9788]])) end end