Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mean_tweedie_deviance and particular cases #193

Merged
merged 5 commits into from
Oct 26, 2023

Conversation

0urobor0s
Copy link
Contributor

This PR also adds the particular cases:

  • mean_poisson_deviance
  • mean_gamma_deviance

And updates mean_square_error to use mean_tweedie_deviance as well.

Particular cases:
- mean_poisson_deviance
- mean_gamma_deviance

And update mean_square_error to use mean_tweedie_deviance as well
Comment on lines 163 to 169
case check_tweedie_deviance_power(y_true, y_pred, power) |> Nx.to_number() do
2 -> raise message <> "strictly positive y_pred."
4 -> raise message <> "non-negative y_true and strictly positive y_pred."
5 -> raise message <> "strictly positive y_true and strictly positive y_pred."
100 -> raise "Something went wrong, branch should never appear."
1 -> :ok
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot use runtime checks, because they would fail when function would be called in another defn

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can test this by calling Nx.Defn.jit(&mean_tweedie_deviance/3).(y_true, y_pred, 1). You will see that jit returns an expression and you can't convert it to number.

Copy link
Contributor Author

@0urobor0s 0urobor0s Oct 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean this:

iex> Nx.Defn.jit(&mean_tweedie_deviance/3).(y_true, y_pred, 1)
** (Protocol.UndefinedError) protocol String.Chars not implemented for #Nx.Tensor<
  s64

  Nx.Defn.Expr
  parameter a:2   s64
> of type Nx.Tensor (a struct). This protocol is implemented for the following type(s): Atom, BitString, Complex, Date, DateTime, Float, Hex.Solver.Assignment, Hex.Solver.Constraints.Empty, Hex.Solver.Constraints.Range, Hex.Solver.Constraints.Union, Hex.Solver.Incompatibility, Hex.Solver.PackageRange, Hex.Solver.Term, Integer, List, NaiveDateTime, Time, URI, Version, Version.Requirement
    (elixir 1.14.5) lib/string/chars.ex:3: String.Chars.impl_for!/1
    (elixir 1.14.5) lib/string/chars.ex:22: String.Chars.to_string/1
    (scholar 0.2.1) lib/scholar/metrics/regression.ex:161: Scholar.Metrics.Regression.mean_tweedie_deviance/3
    (nx 0.7.0-dev) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
    (nx 0.7.0-dev) lib/nx/defn/evaluator.ex:82: Nx.Defn.Evaluator.precompile/3
    (nx 0.7.0-dev) lib/nx/defn/evaluator.ex:60: Nx.Defn.Evaluator.__compile__/4
    (nx 0.7.0-dev) lib/nx/defn/evaluator.ex:53: Nx.Defn.Evaluator.__jit__/5
    iex:4: (file)

I thought that by using the deftransform, it could be done. That assumption from reading the docs, was the ideia behing the check_tweedie_deviance_power returning a tensor "error value". But from what I understand now, what you mean is that this happens when a defn function calls the mean_tweedie_deviance defined with deftransform.

If so, do you propose any idea/solution for it. In this case, the values need to be checked in order to guarantee that the results "can be calculated" as to have probabilistic meaning.

PS: For the alternative, would it be feasible to return tensor of different shapes depending on the error?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to imagine that defn is running inside the GPU. Inside the GPU you can't have different shapes and you can't have exceptions. The best you can do is return NaN to signal that something went wrong OR have a separate function, that is a regular def, that receives a tensor to check it. In this case, the best is to assume you receive valid inputs and, if not, the value is unspecified.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think having the checks might be useful, at least for sanity check. I've added an option to allow for that check to happen in a def function, which by default is false (no check is made). Which means that by default, if the incorrect inputs for the given power value are passed to these functions, there is no guarantee of a result (with some NaNs appearing). Let me know if this can be done like this, or if you have any other concerns.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was following the opts approach, but if possible I agree. In that case would mean_tweedie_deviance! be a def function?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it would be a def!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should a note be added, for mean_tweedie_deviance!, as in Nx.to_number/1?
Note: This function cannot be used in defn.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be great, yeah!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, added it as well.

Nx.pow(max(y_true, 0), 2 - power) / ((1 - power) * (2 - power))
-y_true * Nx.pow(y_pred, 1 - power) / (1 - power)
+Nx.pow(y_pred, 2 - power) / (2 - power)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove this clause as it is the same as the last?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, they are not the same. In the first one, there is max(y_true,0)

Nx.mean(deviance)
end

defp check_tweedie_deviance_power(y_true, y_pred, power) when is_number(power) do
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I liked your previous implementation more, where you did the checking in Nx, and returned an integer. :)

@josevalim josevalim merged commit 33d3bd6 into elixir-nx:main Oct 26, 2023
2 checks passed
@josevalim
Copy link
Contributor

💚 💙 💜 💛 ❤️

@0urobor0s 0urobor0s mentioned this pull request Oct 26, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants