Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return only new text from text generation #302

Merged
merged 1 commit
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 1 addition & 19 deletions lib/bumblebee/text/conversation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ defmodule Bumblebee.Text.Conversation do
batch_size = compile[:batch_size]
sequence_length = compile[:sequence_length]

encoder_decoder? = encoder_decoder?(model)

generate_fun =
Text.Generation.build_generate(model, spec, generation_config, Keyword.take(opts, [:seed]))

Expand All @@ -66,18 +64,7 @@ defmodule Bumblebee.Text.Conversation do

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
sequences = generate_fun.(params, inputs)
inputs = Nx.Defn.jit_apply(&Function.identity/1, [inputs])

start_idx =
if encoder_decoder? do
1
else
Nx.axis_size(inputs["input_ids"], 1)
end

sequences[[.., start_idx..-1//1]]
|> Shared.serving_post_computation()
generate_fun.(params, inputs) |> Shared.serving_post_computation()
end
end,
defn_options
Expand Down Expand Up @@ -125,9 +112,4 @@ defmodule Bumblebee.Text.Conversation do
defp validate_input(input) do
{:error, "expected input to be a map with :text and :history, got: #{inspect(input)}"}
end

defp encoder_decoder?(model) do
inputs = Axon.get_inputs(model)
Map.has_key?(inputs, "input_ids") and Map.has_key?(inputs, "decoder_input_ids")
end
end
101 changes: 53 additions & 48 deletions lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -329,16 +329,16 @@ defmodule Bumblebee.Text.Generation do
end
end

deftransformp generate_impl(
inputs,
predict_fun,
params,
logits_processor_fun,
prepare_inputs_fun,
update_inputs_fun,
traverse_cache_fun,
opts \\ []
) do
defnp generate_impl(
inputs,
predict_fun,
params,
logits_processor_fun,
prepare_inputs_fun,
update_inputs_fun,
traverse_cache_fun,
opts \\ []
) do
{decoder_inputs, decoder_input_ids, max_length} = prepare_inputs_fun.(inputs, params)

length = Nx.axis_size(decoder_input_ids, 1)
Expand All @@ -352,45 +352,54 @@ defmodule Bumblebee.Text.Generation do
strategy = opts[:strategy]
seed = opts[:seed]

case strategy.type do
:greedy_search ->
greedy(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
[max_length: max_length] ++ opts
)
sequences =
case strategy.type do
:greedy_search ->
greedy(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
merge_options([max_length: max_length], opts)
)

:contrastive_search ->
contrastive(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
traverse_cache_fun,
[max_length: max_length, top_k: strategy.top_k, penalty_alpha: strategy.alpha] ++ opts
)
:contrastive_search ->
contrastive(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
traverse_cache_fun,
merge_options(
[max_length: max_length, top_k: strategy.top_k, penalty_alpha: strategy.alpha],
opts
)
)

:multinomial_sampling ->
prng_key = Nx.Random.key(seed)
:multinomial_sampling ->
prng_key = Nx.Random.key(seed)

sampling(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
[max_length: max_length, prng_key: prng_key] ++ opts
)
end
sampling(
decoder_inputs,
decoder_input_ids,
predict_fun,
params,
logits_processor_fun,
update_inputs_fun,
merge_options([max_length: max_length, prng_key: prng_key], opts)
)
end

# Output only the newly generated tokens
sequences[[.., length..-1//1]]
end

deftransformp merge_options(left, right), do: left ++ right

# Greedy search

defnp greedy(
Expand Down Expand Up @@ -457,10 +466,6 @@ defmodule Bumblebee.Text.Generation do
defnp init_sequences(decoder_input_ids, max_length, pad_token_id) do
{batch_size, length} = Nx.shape(decoder_input_ids)

if length > max_length do
raise ArgumentError, "expected the input to be at most #{max_length} tokens, got: #{length}"
end

sequences = Nx.broadcast(pad_token_id, {batch_size, max_length})
sequences = Nx.put_slice(sequences, [0, 0], decoder_input_ids)

Expand Down
2 changes: 1 addition & 1 deletion test/bumblebee/text/bart_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,6 @@ defmodule Bumblebee.Text.BartTest do
generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
token_ids = generate.(params, inputs)

assert_equal(token_ids, Nx.tensor([[2, 988, 988, 988]]))
assert_equal(token_ids, Nx.tensor([[988, 988, 988]]))
end
end
2 changes: 1 addition & 1 deletion test/bumblebee/text/blenderbot_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,6 @@ defmodule Bumblebee.Text.BlenderbotTest do
generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
token_ids = generate.(params, inputs)

assert_equal(token_ids, Nx.tensor([[1, 382, 382, 382]]))
assert_equal(token_ids, Nx.tensor([[382, 382, 382]]))
end
end
13 changes: 5 additions & 8 deletions test/bumblebee/text/generation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ defmodule Bumblebee.Text.GenerationTest do
serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)

# Without :no_repeat_ngram_length we get
# %{results: [%{text: "I was going to say, 'Well, I'm going to say,"}]}
# %{results: [%{text: " to say, 'Well, I'm going to say,"}]}

assert %{results: [%{text: "I was going to say, 'Well, I'm going back to the"}]} =
assert %{results: [%{text: " to say, 'Well, I'm going back to the"}]} =
Nx.Serving.run(serving, "I was going")
end

Expand All @@ -60,11 +60,8 @@ defmodule Bumblebee.Text.GenerationTest do
# Note that this is just a snapshot test, we do not use any
# reference value, because of PRNG difference

assert %{
results: [
%{text: "I was going to fall asleep.\"\n\nThis is not Wallace's fifth"}
]
} = Nx.Serving.run(serving, "I was going")
assert %{results: [%{text: " to fall asleep.\"\n\nThis is not Wallace's fifth"}]} =
Nx.Serving.run(serving, "I was going")
end

test "contrastive search" do
Expand All @@ -80,7 +77,7 @@ defmodule Bumblebee.Text.GenerationTest do

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)

assert %{results: [%{text: "I was going to say, 'Well, I don't know what you"}]} =
assert %{results: [%{text: " to say, 'Well, I don't know what you"}]} =
Nx.Serving.run(serving, "I was going")
end

Expand Down
2 changes: 1 addition & 1 deletion test/bumblebee/text/mbart_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,6 @@ defmodule Bumblebee.Text.MbartTest do
generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
token_ids = generate.(params, inputs)

assert_equal(token_ids, Nx.tensor([[0, 230_521, 20386, 20386]]))
assert_equal(token_ids, Nx.tensor([[230_521, 20386, 20386]]))
end
end
4 changes: 2 additions & 2 deletions test/bumblebee/text/t5_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ defmodule Bumblebee.Text.T5Test do
generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
token_ids = generate.(params, inputs)

assert_equal(token_ids, Nx.tensor([[0, 0, 0, 0]]))
assert_equal(token_ids, Nx.tensor([[0, 0, 0]]))
end

test "generation with :for_conditional_generation without tied embeddings" do
Expand Down Expand Up @@ -195,6 +195,6 @@ defmodule Bumblebee.Text.T5Test do
generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
token_ids = generate.(params, inputs)

assert_equal(token_ids, Nx.tensor([[0, 6161, 29516, 9788]]))
assert_equal(token_ids, Nx.tensor([[6161, 29516, 9788]]))
end
end
Loading