diff --git a/lib/kino_bumblebee/task_cell.ex b/lib/kino_bumblebee/task_cell.ex index 3bdac6d..c9096f5 100644 --- a/lib/kino_bumblebee/task_cell.ex +++ b/lib/kino_bumblebee/task_cell.ex @@ -189,7 +189,7 @@ defmodule KinoBumblebee.TaskCell do field: "aggregation", label: "Aggregation", type: :select, - options: [%{value: nil, label: "None"}, %{value: :same, label: "Same"}], + options: [%{value: nil, label: "None"}, %{value: "same", label: "Same"}], default: :same }, %{field: "sequence_length", label: "Max input tokens", type: :number, default: 100} @@ -690,7 +690,7 @@ defmodule KinoBumblebee.TaskCell do defp to_quoted(%{"task_id" => "token_classification"} = attrs) do opts = if(aggregation = attrs["aggregation"], - do: [aggregation: aggregation], + do: [aggregation: String.to_atom(aggregation)], else: [] ) ++ [compile: [batch_size: 1, sequence_length: attrs["sequence_length"]]] ++ diff --git a/test/kino_bumblebee/task_cell_test.exs b/test/kino_bumblebee/task_cell_test.exs index 6d62fa2..1de9265 100644 --- a/test/kino_bumblebee/task_cell_test.exs +++ b/test/kino_bumblebee/task_cell_test.exs @@ -91,6 +91,18 @@ defmodule KinoBumblebee.TaskCellTest do Kino.Layout.grid([form, frame], boxed: true, gap: 16)\ """ + + attrs = %{ + "task_id" => "token_classification", + "variant_id" => "bert_base_cased_ner", + "aggregation" => "same", + "sequence_length" => 100, + "compiler" => "exla" + } + + {_kino, source} = start_smart_cell!(TaskCell, attrs) + + assert source =~ "aggregation: :same" end end