Skip to content

Commit

Permalink
Return only new text from text generation
Browse files — browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Dec 12, 2023
1 parent d1783dc commit 2aebec2
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 80 deletions.
20 changes: 1 addition & 19 deletions lib/bumblebee/text/conversation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ defmodule Bumblebee.Text.Conversation do
batch_size = compile[:batch_size]
sequence_length = compile[:sequence_length]

encoder_decoder? = encoder_decoder?(model)

generate_fun =
Text.Generation.build_generate(model, spec, generation_config, Keyword.take(opts, [:seed]))

Expand All @@ -66,18 +64,7 @@ defmodule Bumblebee.Text.Conversation do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
sequences = generate_fun.(params, inputs)
inputs = Nx.Defn.jit_apply(&Function.identity/1, [inputs])

start_idx =
if encoder_decoder? do
1
else
Nx.axis_size(inputs["input_ids"], 1)
end

sequences[[.., start_idx..-1//1]]
|> Shared.serving_post_computation()
generate_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down Expand Up @@ -125,9 +112,4 @@ defmodule Bumblebee.Text.Conversation do
defp validate_input(input) do
{:error, "expected input to be a map with :text and :history, got: #{inspect(input)}"}
end

defp encoder_decoder?(model) do
inputs = Axon.get_inputs(model)
Map.has_key?(inputs, "input_ids") and Map.has_key?(inputs, "decoder_input_ids")
end
end
101 changes: 53 additions & 48 deletions lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -329,16 +329,16 @@ defmodule Bumblebee.Text.Generation do
end
end

deftransformp generate_impl(
inputs,
predict_fun,
params,
logits_processor_fun,
prepare_inputs_fun,
update_inputs_fun,
traverse_cache_fun,
opts \\ []
) do
defnp generate_impl(
inputs,
predict_fun,
params,
logits_processor_fun,
prepare_inputs_fun,
update_inputs_fun,
traverse_cache_fun,
opts \\ []
) do
{decoder_inputs, decoder_input_ids, max_length} = prepare_inputs_fun.(inputs, params)

length = Nx.axis_size(decoder_input_ids, 1)
Expand All @@ -352,45 +352,54 @@ defmodule Bumblebee.Text.Generation do
strategy = opts[:strategy]
seed = opts[:seed]

case strategy.type do
:greedy_search ->
greedy(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
[max_length: max_length] ++ opts
)
sequences =
case strategy.type do
:greedy_search ->
greedy(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
merge_options([max_length: max_length], opts)
)

:contrastive_search ->
contrastive(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
traverse_cache_fun,
[max_length: max_length, top_k: strategy.top_k, penalty_alpha: strategy.alpha] ++ opts
)
:contrastive_search ->
contrastive(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
traverse_cache_fun,
merge_options(
[max_length: max_length, top_k: strategy.top_k, penalty_alpha: strategy.alpha],
opts
)
)

:multinomial_sampling ->
prng_key = Nx.Random.key(seed)
:multinomial_sampling ->
prng_key = Nx.Random.key(seed)

sampling(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
[max_length: max_length, prng_key: prng_key] ++ opts
)
end
sampling(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
merge_options([max_length: max_length, prng_key: prng_key], opts)
)
end

# Output only the newly generated tokens
sequences[[.., length..-1//1]]
end

deftransformp merge_options(left, right), do: left ++ right

# Greedy search

defnp greedy(
Expand Down Expand Up @@ -457,10 +466,6 @@ defmodule Bumblebee.Text.Generation do
defnp init_sequences(decoder_input_ids, max_length, pad_token_id) do
{batch_size, length} = Nx.shape(decoder_input_ids)

if length > max_length do
raise ArgumentError, "expected the input to be at most #{max_length} tokens, got: #{length}"
end

sequences = Nx.broadcast(pad_token_id, {batch_size, max_length})
sequences = Nx.put_slice(sequences, [0, 0], decoder_input_ids)

Expand Down
2 changes: 1 addition & 1 deletion test/bumblebee/text/bart_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,6 @@ defmodule Bumblebee.Text.BartTest do
generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
token_ids = generate.(params, inputs)

assert_equal(token_ids, Nx.tensor([[2, 988, 988, 988]]))
assert_equal(token_ids, Nx.tensor([[988, 988, 988]]))
end
end
2 changes: 1 addition & 1 deletion test/bumblebee/text/blenderbot_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,6 @@ defmodule Bumblebee.Text.BlenderbotTest do
generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
token_ids = generate.(params, inputs)

assert_equal(token_ids, Nx.tensor([[1, 382, 382, 382]]))
assert_equal(token_ids, Nx.tensor([[382, 382, 382]]))
end
end
13 changes: 5 additions & 8 deletions test/bumblebee/text/generation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ defmodule Bumblebee.Text.GenerationTest do
serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)

# Without :no_repeat_ngram_length we get
# %{results: [%{text: "I was going to say, 'Well, I'm going to say,"}]}
# %{results: [%{text: " to say, 'Well, I'm going to say,"}]}

assert %{results: [%{text: "I was going to say, 'Well, I'm going back to the"}]} =
assert %{results: [%{text: " to say, 'Well, I'm going back to the"}]} =
Nx.Serving.run(serving, "I was going")
end

Expand All @@ -60,11 +60,8 @@ defmodule Bumblebee.Text.GenerationTest do
# Note that this is just a snapshot test, we do not use any
# reference value, because of PRNG difference

assert %{
results: [
%{text: "I was going to fall asleep.\"\n\nThis is not Wallace's fifth"}
]
} = Nx.Serving.run(serving, "I was going")
assert %{results: [%{text: " to fall asleep.\"\n\nThis is not Wallace's fifth"}]} =
Nx.Serving.run(serving, "I was going")
end

test "contrastive search" do
Expand All @@ -80,7 +77,7 @@ defmodule Bumblebee.Text.GenerationTest do

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)

assert %{results: [%{text: "I was going to say, 'Well, I don't know what you"}]} =
assert %{results: [%{text: " to say, 'Well, I don't know what you"}]} =
Nx.Serving.run(serving, "I was going")
end

Expand Down
2 changes: 1 addition & 1 deletion test/bumblebee/text/mbart_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,6 @@ defmodule Bumblebee.Text.MbartTest do
generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
token_ids = generate.(params, inputs)

assert_equal(token_ids, Nx.tensor([[0, 230_521, 20386, 20386]]))
assert_equal(token_ids, Nx.tensor([[230_521, 20386, 20386]]))
end
end
4 changes: 2 additions & 2 deletions test/bumblebee/text/t5_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ defmodule Bumblebee.Text.T5Test do
generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
token_ids = generate.(params, inputs)

assert_equal(token_ids, Nx.tensor([[0, 0, 0, 0]]))
assert_equal(token_ids, Nx.tensor([[0, 0, 0]]))
end

test "generation with :for_conditional_generation without tied embeddings" do
Expand Down Expand Up @@ -195,6 +195,6 @@ defmodule Bumblebee.Text.T5Test do
generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
token_ids = generate.(params, inputs)

assert_equal(token_ids, Nx.tensor([[0, 6161, 29516, 9788]]))
assert_equal(token_ids, Nx.tensor([[6161, 29516, 9788]]))
end
end

0 comments on commit 2aebec2

Please sign in to comment.