Skip to content

Commit

Permalink
Fix attributes import in the token classification task (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Dec 15, 2022
1 parent 3f15405 commit c3a6e7e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lib/kino_bumblebee/task_cell.ex
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ defmodule KinoBumblebee.TaskCell do
field: "aggregation",
label: "Aggregation",
type: :select,
options: [%{value: nil, label: "None"}, %{value: :same, label: "Same"}],
options: [%{value: nil, label: "None"}, %{value: "same", label: "Same"}],
default: :same
},
%{field: "sequence_length", label: "Max input tokens", type: :number, default: 100}
Expand Down Expand Up @@ -690,7 +690,7 @@ defmodule KinoBumblebee.TaskCell do
defp to_quoted(%{"task_id" => "token_classification"} = attrs) do
opts =
if(aggregation = attrs["aggregation"],
do: [aggregation: aggregation],
do: [aggregation: String.to_atom(aggregation)],
else: []
) ++
[compile: [batch_size: 1, sequence_length: attrs["sequence_length"]]] ++
Expand Down
12 changes: 12 additions & 0 deletions test/kino_bumblebee/task_cell_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,18 @@ defmodule KinoBumblebee.TaskCellTest do
Kino.Layout.grid([form, frame], boxed: true, gap: 16)\
"""

attrs = %{
"task_id" => "token_classification",
"variant_id" => "bert_base_cased_ner",
"aggregation" => "same",
"sequence_length" => 100,
"compiler" => "exla"
}

{_kino, source} = start_smart_cell!(TaskCell, attrs)

assert source =~ "aggregation: :same"
end
end

Expand Down

0 comments on commit c3a6e7e

Please sign in to comment.