-
Notifications
You must be signed in to change notification settings - Fork 45
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add mean_tweedie_deviance and particular cases #193
Add mean_tweedie_deviance and particular cases #193
Conversation
Particular cases: - mean_poisson_deviance - mean_gamma_deviance And update mean_square_error to use mean_tweedie_deviance as well
lib/scholar/metrics/regression.ex
Outdated
case check_tweedie_deviance_power(y_true, y_pred, power) |> Nx.to_number() do | ||
2 -> raise message <> "strictly positive y_pred." | ||
4 -> raise message <> "non-negative y_true and strictly positive y_pred." | ||
5 -> raise message <> "strictly positive y_true and strictly positive y_pred." | ||
100 -> raise "Something went wrong, branch should never appear." | ||
1 -> :ok | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We cannot use runtime checks, because they would fail when function would be called in another defn
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can test this by calling Nx.Defn.jit(&mean_tweedie_deviance/3).(y_true, y_pred, 1)
. You will see that jit returns an expression and you can't convert it to number.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean this:
iex> Nx.Defn.jit(&mean_tweedie_deviance/3).(y_true, y_pred, 1)
** (Protocol.UndefinedError) protocol String.Chars not implemented for #Nx.Tensor<
s64
Nx.Defn.Expr
parameter a:2 s64
> of type Nx.Tensor (a struct). This protocol is implemented for the following type(s): Atom, BitString, Complex, Date, DateTime, Float, Hex.Solver.Assignment, Hex.Solver.Constraints.Empty, Hex.Solver.Constraints.Range, Hex.Solver.Constraints.Union, Hex.Solver.Incompatibility, Hex.Solver.PackageRange, Hex.Solver.Term, Integer, List, NaiveDateTime, Time, URI, Version, Version.Requirement
(elixir 1.14.5) lib/string/chars.ex:3: String.Chars.impl_for!/1
(elixir 1.14.5) lib/string/chars.ex:22: String.Chars.to_string/1
(scholar 0.2.1) lib/scholar/metrics/regression.ex:161: Scholar.Metrics.Regression.mean_tweedie_deviance/3
(nx 0.7.0-dev) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
(nx 0.7.0-dev) lib/nx/defn/evaluator.ex:82: Nx.Defn.Evaluator.precompile/3
(nx 0.7.0-dev) lib/nx/defn/evaluator.ex:60: Nx.Defn.Evaluator.__compile__/4
(nx 0.7.0-dev) lib/nx/defn/evaluator.ex:53: Nx.Defn.Evaluator.__jit__/5
iex:4: (file)
I thought that by using the deftransform
, it could be done. That assumption from reading the docs, was the ideia behing the check_tweedie_deviance_power
returning a tensor "error value". But from what I understand now, what you mean is that this happens when a defn
function calls the mean_tweedie_deviance
defined with deftransform
.
If so, do you propose any idea/solution for it. In this case, the values need to be checked in order to guarantee that the results "can be calculated" as to have probabilistic meaning.
PS: For the alternative, would it be feasible to return tensor of different shapes depending on the error?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You need to imagine that defn
is running inside the GPU. Inside the GPU you can't have different shapes and you can't have exceptions. The best you can do is return NaN to signal that something went wrong OR have a separate function, that is a regular def
, that receives a tensor to check it. In this case, the best is to assume you receive valid inputs and, if not, the value is unspecified.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think having the checks might be useful, at least for sanity check. I've added an option to allow for that check to happen in a def
function, which by default is false (no check is made). Which means that by default, if the incorrect inputs for the given power value are passed to these functions, there is no guarantee of a result (with some NaNs appearing). Let me know if this can be done like this, or if you have any other concerns.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was following the opts approach, but if possible I agree. In that case would mean_tweedie_deviance!
be a def
function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it would be a def
!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should a note be added, for mean_tweedie_deviance!
, as in Nx.to_number/1?
Note: This function cannot be used in defn.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That would be great, yeah!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, added it as well.
Nx.pow(max(y_true, 0), 2 - power) / ((1 - power) * (2 - power)) | ||
-y_true * Nx.pow(y_pred, 1 - power) / (1 - power) | ||
+Nx.pow(y_pred, 2 - power) / (2 - power) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can remove this clause as it is the same as the last?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, they are not the same. In the first one, there is max(y_true,0)
lib/scholar/metrics/regression.ex
Outdated
Nx.mean(deviance) | ||
end | ||
|
||
defp check_tweedie_deviance_power(y_true, y_pred, power) when is_number(power) do |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I liked your previous implementation more, where you did the checking in Nx, and returned an integer. :)
💚 💙 💜 💛 ❤️ |
This PR also adds the particular cases:
mean_poisson_deviance
mean_gamma_deviance
And updates
mean_square_error
to usemean_tweedie_deviance
as well.