-
Notifications
You must be signed in to change notification settings - Fork 46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add KNNImputer #303
base: main
Are you sure you want to change the base?
Add KNNImputer #303
Changes from all commits
d6c7a55
47b4a65
eb8f245
642b15e
520633a
926a1c7
a3e0eba
4757dd1
108475d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,247 @@ | ||||||||||||
defmodule Scholar.Impute.KNNImputter do | ||||||||||||
@moduledoc """ | ||||||||||||
Imputer for completing missing values using k-Nearest Neighbors. | ||||||||||||
|
||||||||||||
Each sample's missing values are imputed using the mean value from | ||||||||||||
`n_neighbors` nearest neighbors found in the training set. Two samples are | ||||||||||||
close if the features that neither is missing are close. | ||||||||||||
""" | ||||||||||||
import Nx.Defn | ||||||||||||
import Scholar.Metrics.Distance | ||||||||||||
|
||||||||||||
@derive {Nx.Container, keep: [:missing_values], containers: [:statistics]} | ||||||||||||
defstruct [:statistics, :missing_values] | ||||||||||||
|
||||||||||||
opts_schema = [ | ||||||||||||
missing_values: [ | ||||||||||||
type: {:or, [:float, :integer, {:in, [:nan]}]}, | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I believe this should allow :infinity and :neg_infinity too for completeness |
||||||||||||
default: :nan, | ||||||||||||
doc: ~S""" | ||||||||||||
The placeholder for the missing values. All occurrences of `:missing_values` will be imputed. | ||||||||||||
|
||||||||||||
The default value expects there are no NaNs in the input tensor. | ||||||||||||
""" | ||||||||||||
], | ||||||||||||
number_of_neighbors: [ | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would suggest changing this to |
||||||||||||
type: :pos_integer, | ||||||||||||
default: 2, | ||||||||||||
doc: "The number of nearest neighbors." | ||||||||||||
] | ||||||||||||
] | ||||||||||||
|
||||||||||||
@opts_schema NimbleOptions.new!(opts_schema) | ||||||||||||
|
||||||||||||
@doc """ | ||||||||||||
Imputter for completing missing values using k-Nearest Neighbors. | ||||||||||||
|
||||||||||||
Preconditions: | ||||||||||||
* `number_of_neighbors` is a positive integer. | ||||||||||||
* number of neighbors must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value) otherwise it is better to use simple imputter | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please try to break this long line :)
Comment on lines
+38
to
+39
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
* when you set a value different than :nan in `missing_values` there should be no NaNs in the input tensor | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
||||||||||||
## Options | ||||||||||||
|
||||||||||||
#{NimbleOptions.docs(@opts_schema)} | ||||||||||||
|
||||||||||||
## Return Values | ||||||||||||
|
||||||||||||
The function returns a struct with the following parameters: | ||||||||||||
|
||||||||||||
* `:missing_values` - the same value as in `:missing_values` | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
||||||||||||
* `:statistics` - The imputation fill value for each feature. Computing statistics can result in | ||||||||||||
[`Nx.Constant.nan/0`](https://hexdocs.pm/nx/Nx.Constants.html#nan/0) values. | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you need the explicit linking in hexdoc? |
||||||||||||
|
||||||||||||
## Examples | ||||||||||||
|
||||||||||||
iex> x = Nx.tensor([[40.0, 2.0],[4.0, 5.0],[7.0, :nan],[:nan, 8.0],[11.0, 11.0]]) | ||||||||||||
iex> Scholar.Impute.KNNImputter.fit(x, number_of_neighbors: 2) | ||||||||||||
%Scholar.Impute.KNNImputter{ | ||||||||||||
statistics: Nx.tensor( | ||||||||||||
[ | ||||||||||||
[:nan, :nan], | ||||||||||||
[:nan, :nan], | ||||||||||||
[:nan, 8.0], | ||||||||||||
[7.5, :nan], | ||||||||||||
[:nan, :nan] | ||||||||||||
] | ||||||||||||
), | ||||||||||||
missing_values: :nan | ||||||||||||
} | ||||||||||||
|
||||||||||||
""" | ||||||||||||
|
||||||||||||
deftransform fit(x, opts \\ []) do | ||||||||||||
opts = NimbleOptions.validate!(opts, @opts_schema) | ||||||||||||
|
||||||||||||
input_rank = Nx.rank(x) | ||||||||||||
|
||||||||||||
if input_rank != 2 do | ||||||||||||
raise ArgumentError, "Wrong input rank. Expected: 2, got: #{inspect(input_rank)}" | ||||||||||||
end | ||||||||||||
|
||||||||||||
x = | ||||||||||||
if opts[:missing_values] != :nan, | ||||||||||||
do: Nx.select(Nx.equal(x, opts[:missing_values]), Nx.Constants.nan(), x), | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should be able to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a deftransform, so |
||||||||||||
else: x | ||||||||||||
|
||||||||||||
num_neighbors = opts[:number_of_neighbors] | ||||||||||||
|
||||||||||||
placeholder_value = Nx.Constants.nan() |> Nx.tensor() | ||||||||||||
josevalim marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you probably want to pass the input type here to avoid upcasts |
||||||||||||
|
||||||||||||
statistics = knn_impute(x, placeholder_value, num_neighbors: num_neighbors) | ||||||||||||
missing_values = opts[:missing_values] | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would move this line above so that you don't access |
||||||||||||
%__MODULE__{statistics: statistics, missing_values: missing_values} | ||||||||||||
end | ||||||||||||
|
||||||||||||
@doc """ | ||||||||||||
Impute all missing values in `x` using fitted imputer. | ||||||||||||
|
||||||||||||
## Return Values | ||||||||||||
|
||||||||||||
The function returns input tensor with NaN replaced with values saved in fitted imputer. | ||||||||||||
|
||||||||||||
## Examples | ||||||||||||
|
||||||||||||
iex> x = Nx.tensor([[40.0, 2.0],[4.0, 5.0],[7.0, :nan],[:nan, 8.0],[11.0, 11.0]]) | ||||||||||||
iex> imputer = Scholar.Impute.KNNImputter.fit(x, number_of_neighbors: 2) | ||||||||||||
iex> Scholar.Impute.KNNImputter.transform(imputer, x) | ||||||||||||
Nx.tensor( | ||||||||||||
[ | ||||||||||||
[40.0, 2.0], | ||||||||||||
[4.0, 5.0], | ||||||||||||
[7.0, 8.0], | ||||||||||||
[7.5, 8.0], | ||||||||||||
[11.0, 11.0] | ||||||||||||
] | ||||||||||||
) | ||||||||||||
""" | ||||||||||||
deftransform transform(%__MODULE__{statistics: statistics, missing_values: missing_values}, x) do | ||||||||||||
mask = if missing_values == :nan, do: Nx.is_nan(x), else: Nx.equal(x, missing_values) | ||||||||||||
Nx.select(mask, statistics, x) | ||||||||||||
end | ||||||||||||
|
||||||||||||
defnp knn_impute(x, placeholder_value, opts \\ []) do | ||||||||||||
mask = Nx.is_nan(x) | ||||||||||||
{num_rows, num_cols} = Nx.shape(x) | ||||||||||||
num_neighbors = opts[:num_neighbors] | ||||||||||||
|
||||||||||||
values_to_impute = Nx.broadcast(placeholder_value, x) | ||||||||||||
|
||||||||||||
{_, values_to_impute} = | ||||||||||||
while {{row = 0, mask, num_neighbors, num_rows, x}, values_to_impute}, | ||||||||||||
Nx.less(row, num_rows) do | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can use |
||||||||||||
{_, values_to_impute} = | ||||||||||||
while {{col = 0, mask, num_neighbors, num_cols, row, x}, values_to_impute}, | ||||||||||||
Nx.less(col, num_cols) do | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. |
||||||||||||
if mask[row][col] > 0 do | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think |
||||||||||||
{rows, cols} = Nx.shape(x) | ||||||||||||
|
||||||||||||
neighbor_avg = | ||||||||||||
calculate_knn(x, row, col, rows: rows, num_neighbors: opts[:num_neighbors]) | ||||||||||||
|
||||||||||||
indices = | ||||||||||||
[Nx.stack(row), Nx.stack(col)] | ||||||||||||
|> Nx.concatenate() | ||||||||||||
|> Nx.stack() | ||||||||||||
Comment on lines
+143
to
+146
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
If I read the code correctly, row and col are scalars and this should yield the same result |
||||||||||||
|
||||||||||||
values_to_impute = Nx.indexed_put(values_to_impute, indices, Nx.stack(neighbor_avg)) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I think this is even simpler |
||||||||||||
{{col + 1, mask, num_neighbors, cols, row, x}, values_to_impute} | ||||||||||||
else | ||||||||||||
{{col + 1, mask, num_neighbors, num_cols, row, x}, values_to_impute} | ||||||||||||
end | ||||||||||||
end | ||||||||||||
|
||||||||||||
{{row + 1, mask, num_neighbors, num_rows, x}, values_to_impute} | ||||||||||||
end | ||||||||||||
|
||||||||||||
values_to_impute | ||||||||||||
end | ||||||||||||
|
||||||||||||
defnp calculate_knn(x, nan_row, nan_col, opts \\ []) do | ||||||||||||
opts = keyword!(opts, rows: 1, num_neighbors: 2) | ||||||||||||
rows = opts[:rows] | ||||||||||||
num_neighbors = opts[:num_neighbors] | ||||||||||||
|
||||||||||||
row_distances = Nx.iota({rows}, type: {:f, 32}) | ||||||||||||
|
||||||||||||
row_with_value_to_fill = x[nan_row] | ||||||||||||
|
||||||||||||
# calculate distance between row with nan to fill and all other rows where distance | ||||||||||||
# to the row is under its index in the tensor | ||||||||||||
{_, row_distances} = | ||||||||||||
while {{i = 0, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances}, | ||||||||||||
Nx.less(i, rows) do | ||||||||||||
potential_donor = x[i] | ||||||||||||
|
||||||||||||
distance = | ||||||||||||
if i == nan_row do | ||||||||||||
Nx.Constants.infinity(Nx.type(row_with_value_to_fill)) | ||||||||||||
else | ||||||||||||
nan_euclidian(row_with_value_to_fill, nan_col, potential_donor) | ||||||||||||
end | ||||||||||||
|
||||||||||||
row_distances = Nx.indexed_put(row_distances, Nx.new_axis(i, 0), distance) | ||||||||||||
{{i + 1, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances} | ||||||||||||
end | ||||||||||||
Comment on lines
+172
to
+186
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. try this: potential_donors = Nx.vectorize(x, :rows)
distances = nan_euclidean(row_with_value_to_fill, nan_col, potential_donors) |> Nx.devectorize()
row_distances = Nx.indexed_put(distances, [i], Nx.Constants.infinity()) |
||||||||||||
|
||||||||||||
{_, indices} = Nx.top_k(-row_distances, k: num_neighbors) | ||||||||||||
|
||||||||||||
gather_indices = Nx.stack([indices, Nx.broadcast(nan_col, indices)], axis: 1) | ||||||||||||
values = Nx.gather(x, gather_indices) | ||||||||||||
Nx.sum(values) / num_neighbors | ||||||||||||
end | ||||||||||||
|
||||||||||||
# nan_col is the column of the value to impute | ||||||||||||
defnp nan_euclidian(row, nan_col, potential_neighbor) do | ||||||||||||
{coordinates} = Nx.shape(row) | ||||||||||||
|
||||||||||||
# minus nan column | ||||||||||||
coordinates = coordinates - 1 | ||||||||||||
|
||||||||||||
# inputes zeros in nan_col to calculate distance with squared_euclidean | ||||||||||||
new_row = Nx.indexed_put(row, Nx.new_axis(nan_col, 0), Nx.tensor(0)) | ||||||||||||
|
||||||||||||
# if potential neighbor has nan in nan_col, we don't want to calculate distance and the case if potential_neighbour is the row to impute | ||||||||||||
{potential_neighbor} = | ||||||||||||
if Nx.is_nan(potential_neighbor[nan_col]) do | ||||||||||||
potential_neighbor = | ||||||||||||
Nx.broadcast(Nx.Constants.infinity(Nx.type(potential_neighbor)), potential_neighbor) | ||||||||||||
|
||||||||||||
{potential_neighbor} | ||||||||||||
else | ||||||||||||
# inputes zeros in nan_col to calculate distance with squared_euclidean - distance will be 0 so no change to the distance value | ||||||||||||
potential_neighbor = | ||||||||||||
Nx.indexed_put( | ||||||||||||
potential_neighbor, | ||||||||||||
Nx.new_axis(nan_col, 0), | ||||||||||||
Nx.tensor(0, type: Nx.type(row)) | ||||||||||||
) | ||||||||||||
|
||||||||||||
{potential_neighbor} | ||||||||||||
end | ||||||||||||
|
||||||||||||
# calculates how many values are present in the row without nan_col to calculate weight for the distance | ||||||||||||
present_coordinates = Nx.sum(Nx.logical_not(Nx.is_nan(potential_neighbor))) - 1 | ||||||||||||
|
||||||||||||
# if row has all nans we skip it | ||||||||||||
{weight, potential_neighbor} = | ||||||||||||
if present_coordinates == 0 do | ||||||||||||
potential_neighbor = | ||||||||||||
Nx.broadcast(Nx.Constants.infinity(Nx.type(potential_neighbor)), potential_neighbor) | ||||||||||||
|
||||||||||||
weight = 0 | ||||||||||||
{weight, potential_neighbor} | ||||||||||||
else | ||||||||||||
potential_neighbor = Nx.select(Nx.is_nan(potential_neighbor), new_row, potential_neighbor) | ||||||||||||
weight = coordinates / present_coordinates | ||||||||||||
{weight, potential_neighbor} | ||||||||||||
end | ||||||||||||
|
||||||||||||
# calculating weighted euclidian distance | ||||||||||||
distance = Nx.sqrt(weight * squared_euclidean(new_row, potential_neighbor)) | ||||||||||||
|
||||||||||||
# return inf if potential_row is row to impute | ||||||||||||
Nx.select(Nx.is_nan(distance), Nx.Constants.infinity(Nx.type(distance)), distance) | ||||||||||||
end | ||||||||||||
end |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,128 @@ | ||||||||||||||||||||||||||||||||||||||||||||||
defmodule KNNImputterTest do | ||||||||||||||||||||||||||||||||||||||||||||||
use Scholar.Case, async: true | ||||||||||||||||||||||||||||||||||||||||||||||
alias Scholar.Impute.KNNImputter | ||||||||||||||||||||||||||||||||||||||||||||||
doctest KNNImputter | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
describe "general cases" do | ||||||||||||||||||||||||||||||||||||||||||||||
def generate_data() do | ||||||||||||||||||||||||||||||||||||||||||||||
x = Nx.iota({5, 4}) | ||||||||||||||||||||||||||||||||||||||||||||||
x = Nx.select(Nx.equal(Nx.quotient(x, 5), 2), Nx.Constants.nan(), x) | ||||||||||||||||||||||||||||||||||||||||||||||
Nx.indexed_put(x, Nx.tensor([[4, 2]]), Nx.tensor([6.0])) | ||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
test "general KNN imputer" do | ||||||||||||||||||||||||||||||||||||||||||||||
x = generate_data() | ||||||||||||||||||||||||||||||||||||||||||||||
jit_fit = Nx.Defn.jit(&KNNImputter.fit/2) | ||||||||||||||||||||||||||||||||||||||||||||||
jit_transform = Nx.Defn.jit(&KNNImputter.transform/2) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
knn_imputer = | ||||||||||||||||||||||||||||||||||||||||||||||
%KNNImputter{statistics: statistics, missing_values: missing_values} = | ||||||||||||||||||||||||||||||||||||||||||||||
jit_fit.(x, missing_values: :nan, number_of_neighbors: 2) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
assert missing_values == :nan | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
assert statistics == | ||||||||||||||||||||||||||||||||||||||||||||||
Nx.tensor([ | ||||||||||||||||||||||||||||||||||||||||||||||
[:nan, :nan, :nan, :nan], | ||||||||||||||||||||||||||||||||||||||||||||||
[:nan, :nan, :nan, :nan], | ||||||||||||||||||||||||||||||||||||||||||||||
[:nan, :nan, 4.0, 5.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[2.0, 3.0, 4.0, :nan], | ||||||||||||||||||||||||||||||||||||||||||||||
[:nan, :nan, :nan, :nan] | ||||||||||||||||||||||||||||||||||||||||||||||
]) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
assert jit_transform.(knn_imputer, x) == | ||||||||||||||||||||||||||||||||||||||||||||||
Nx.tensor([ | ||||||||||||||||||||||||||||||||||||||||||||||
[0.0, 1.0, 2.0, 3.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[4.0, 5.0, 6.0, 7.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[8.0, 9.0, 4.0, 5.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[2.0, 3.0, 4.0, 15.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[16.0, 17.0, 6.0, 19.0] | ||||||||||||||||||||||||||||||||||||||||||||||
]) | ||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
test "general KNN imputer with different number of neighbors" do | ||||||||||||||||||||||||||||||||||||||||||||||
x = generate_data() | ||||||||||||||||||||||||||||||||||||||||||||||
jit_fit = Nx.Defn.jit(&KNNImputter.fit/2) | ||||||||||||||||||||||||||||||||||||||||||||||
jit_transform = Nx.Defn.jit(&KNNImputter.transform/2) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
knn_imputter = | ||||||||||||||||||||||||||||||||||||||||||||||
%KNNImputter{statistics: statistics, missing_values: missing_values} = | ||||||||||||||||||||||||||||||||||||||||||||||
jit_fit.(x, missing_values: :nan, number_of_neighbors: 1) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
assert missing_values == :nan | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
assert statistics == | ||||||||||||||||||||||||||||||||||||||||||||||
Nx.tensor([ | ||||||||||||||||||||||||||||||||||||||||||||||
[:nan, :nan, :nan, :nan], | ||||||||||||||||||||||||||||||||||||||||||||||
[:nan, :nan, :nan, :nan], | ||||||||||||||||||||||||||||||||||||||||||||||
[:nan, :nan, 2.0, 3.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[0.0, 1.0, 2.0, :nan], | ||||||||||||||||||||||||||||||||||||||||||||||
[:nan, :nan, :nan, :nan] | ||||||||||||||||||||||||||||||||||||||||||||||
]) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
assert jit_transform.(knn_imputter, x) == | ||||||||||||||||||||||||||||||||||||||||||||||
Nx.tensor([ | ||||||||||||||||||||||||||||||||||||||||||||||
[0.0, 1.0, 2.0, 3.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[4.0, 5.0, 6.0, 7.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[8.0, 9.0, 2.0, 3.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[0.0, 1.0, 2.0, 15.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[16.0, 17.0, 6.0, 19.0] | ||||||||||||||||||||||||||||||||||||||||||||||
]) | ||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
test "missing values different than :nan" do | ||||||||||||||||||||||||||||||||||||||||||||||
x = generate_data() | ||||||||||||||||||||||||||||||||||||||||||||||
x = Nx.select(Nx.is_nan(x), Nx.tensor(19.0), x) | ||||||||||||||||||||||||||||||||||||||||||||||
jit_fit = Nx.Defn.jit(&KNNImputter.fit/2) | ||||||||||||||||||||||||||||||||||||||||||||||
jit_transform = Nx.Defn.jit(&KNNImputter.transform/2) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
knn_imputter = | ||||||||||||||||||||||||||||||||||||||||||||||
%KNNImputter{statistics: statistics, missing_values: missing_values} = | ||||||||||||||||||||||||||||||||||||||||||||||
jit_fit.(x, missing_values: 19.0, number_of_neighbors: 2) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
assert missing_values == 19.0 | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
assert statistics == | ||||||||||||||||||||||||||||||||||||||||||||||
Nx.tensor([ | ||||||||||||||||||||||||||||||||||||||||||||||
[:nan, :nan, :nan, :nan], | ||||||||||||||||||||||||||||||||||||||||||||||
[:nan, :nan, :nan, :nan], | ||||||||||||||||||||||||||||||||||||||||||||||
[:nan, :nan, 4.0, 5.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[2.0, 3.0, 4.0, :nan], | ||||||||||||||||||||||||||||||||||||||||||||||
[:nan, :nan, :nan, 5.0] | ||||||||||||||||||||||||||||||||||||||||||||||
]) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
assert jit_transform.(knn_imputter, x) == | ||||||||||||||||||||||||||||||||||||||||||||||
Nx.tensor([ | ||||||||||||||||||||||||||||||||||||||||||||||
[0.0, 1.0, 2.0, 3.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[4.0, 5.0, 6.0, 7.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[8.0, 9.0, 4.0, 5.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[2.0, 3.0, 4.0, 15.0], | ||||||||||||||||||||||||||||||||||||||||||||||
[16.0, 17.0, 6.0, 5.0] | ||||||||||||||||||||||||||||||||||||||||||||||
]) | ||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
describe "errors" do | ||||||||||||||||||||||||||||||||||||||||||||||
test "Wrong impute rank" do | ||||||||||||||||||||||||||||||||||||||||||||||
x = Nx.tensor([1, 2, 2, 3]) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
assert_raise ArgumentError, | ||||||||||||||||||||||||||||||||||||||||||||||
"Wrong input rank. Expected: 2, got: 1", | ||||||||||||||||||||||||||||||||||||||||||||||
fn -> | ||||||||||||||||||||||||||||||||||||||||||||||
KNNImputter.fit(x, missing_values: 1, number_of_neighbors: 2) | ||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
test "Invalid n_neighbors value" do | ||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+106
to
+116
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test names start in lowercase :)
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||
x = generate_data() | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
jit_fit = Nx.Defn.jit(&KNNImputter.fit/2) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
assert_raise NimbleOptions.ValidationError, | ||||||||||||||||||||||||||||||||||||||||||||||
"invalid value for :number_of_neighbors option: expected positive integer, got: -1", | ||||||||||||||||||||||||||||||||||||||||||||||
fn -> | ||||||||||||||||||||||||||||||||||||||||||||||
jit_fit.(x, missing_values: 1.0, number_of_neighbors: -1) | ||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.