Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mean_tweedie_deviance and particular cases #193

Merged
merged 5 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 165 additions & 2 deletions lib/scholar/metrics/regression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ defmodule Scholar.Metrics.Regression do
>
"""
defn mean_square_error(y_true, y_pred) do
  # The mean squared error is the mean Tweedie deviance with power 0, so the
  # shared implementation is used directly instead of also computing (and
  # discarding) the squared difference inline.
  mean_tweedie_deviance_n(y_true, y_pred, 0)
end

@doc ~S"""
Expand Down Expand Up @@ -133,6 +132,170 @@ defmodule Scholar.Metrics.Regression do
|> Nx.mean()
end

@doc """
Calculates the mean Tweedie deviance of predictions
with respect to targets. Includes the Gaussian, Poisson,
Gamma and inverse-Gaussian families as special cases.

The `power` argument selects the per-sample deviance that is averaged:

#{~S'''
$$d(y,\mu) =
\begin{cases}
(y-\mu)^2, & \text{for }p=0\\\\
2(y \log(y/\mu) + \mu - y), & \text{for }p=1\\\\
2(\log(\mu/y) + y/\mu - 1), & \text{for }p=2\\\\
2\left(\frac{\max(y,0)^{2-p}}{(1-p)(2-p)}-\frac{y\mu^{1-p}}{1-p}+\frac{\mu^{2-p}}{2-p}\right), & \text{otherwise}
\end{cases}$$
'''}

The clamp $\max(y,0)$ is only applied for $p<0$; for every other power
handled by the last case, $y$ is used as is.

This function does not validate its inputs. Use `mean_tweedie_deviance!/3`
outside of `defn` to raise when `y_true` or `y_pred` is outside of the domain
required by `power`.

## Examples

    iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32)
    iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32)
    iex> Scholar.Metrics.Regression.mean_tweedie_deviance(y_true, y_pred, 1)
    #Nx.Tensor<
      f32
      0.18411168456077576
    >
"""
defn mean_tweedie_deviance(y_true, y_pred, power) do
  mean_tweedie_deviance_n(y_true, y_pred, power)
end

@doc """
Similar to `mean_tweedie_deviance/3` but raises `RuntimeError` if the
inputs cannot be used with the given power argument.

The accepted inputs depend on `power`:

  * `power < 0` - `y_pred` must be strictly positive
  * `power == 0` - any input is accepted
  * `1 <= power < 2` - `y_true` must be non-negative and `y_pred`
    strictly positive
  * `power >= 2` - both `y_true` and `y_pred` must be strictly positive

A `power` strictly between 0 and 1 is not supported and always raises.

Note: This function cannot be used in `defn`.

## Examples

    iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32)
    iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32)
    iex> Scholar.Metrics.Regression.mean_tweedie_deviance!(y_true, y_pred, 1)
    #Nx.Tensor<
      f32
      0.18411168456077576
    >
"""
def mean_tweedie_deviance!(y_true, y_pred, power) do
  message = "mean Tweedie deviance with power=#{power} can only be used on "

  # The validator encodes its verdict as a small integer status code.
  case check_tweedie_deviance_power(y_true, y_pred, power) |> Nx.to_number() do
    1 ->
      :ok

    2 ->
      raise message <> "strictly positive y_pred"

    4 ->
      raise message <> "non-negative y_true and strictly positive y_pred"

    5 ->
      raise message <> "strictly positive y_true and strictly positive y_pred"

    # Returned when none of the validator's power ranges matched, which is
    # the case for 0 < power < 1.
    100 ->
      raise "mean Tweedie deviance is only defined for power <= 0 or power >= 1, " <>
              "got power=#{power}"
  end

  mean_tweedie_deviance_n(y_true, y_pred, power)
end

# Computes the per-sample Tweedie deviance selected by `power` and returns its
# mean. No validation of `y_true`/`y_pred` is performed here.
defnp mean_tweedie_deviance_n(y_true, y_pred, power) do
  deviance =
    cond do
      # Same expression as the general clause at the bottom, except that
      # `y_true` is clamped with `max(y_true, 0)` in the first term, so the two
      # clauses are not interchangeable.
      #
      # The binary `-` and `+` must stay at the END of each line: a line that
      # starts with `-`/`+` is parsed as a new expression with a unary
      # operator, which would make the parenthesised block evaluate to its
      # last line only.
      power < 0 ->
        2 *
          (Nx.pow(max(y_true, 0), 2 - power) / ((1 - power) * (2 - power)) -
             y_true * Nx.pow(y_pred, 1 - power) / (1 - power) +
             Nx.pow(y_pred, 2 - power) / (2 - power))

      # Normal distribution
      power == 0 ->
        Nx.pow(y_true - y_pred, 2)

      # Poisson distribution
      power == 1 ->
        # y * log(y / mu) is taken to be 0 for y == 0; computing it directly
        # would give 0 * log(0) = NaN for zero targets.
        y_log_ratio = Nx.select(y_true == 0, 0, y_true * Nx.log(y_true / y_pred))
        2 * (y_log_ratio + y_pred - y_true)

      # Gamma distribution
      power == 2 ->
        2 * (Nx.log(y_pred / y_true) + y_true / y_pred - 1)

      # 1 < power < 2 -> Compound Poisson distribution, non-negative with mass at zero
      # power == 3 -> Inverse-Gaussian distribution
      # power > 2 -> Stable distribution, with support on the positive reals
      true ->
        2 *
          (Nx.pow(y_true, 2 - power) / ((1 - power) * (2 - power)) -
             y_true * Nx.pow(y_pred, 1 - power) / (1 - power) +
             Nx.pow(y_pred, 2 - power) / (2 - power))
    end

  Nx.mean(deviance)
end

# Checks `y_true` and `y_pred` against the domain required by `power` and
# returns the verdict as a `u8` scalar status code:
#
#   1   - the inputs can be used with `power`
#   2   - power < 0, but some `y_pred` is not strictly positive
#   4   - 1 <= power < 2, but some `y_true` is negative or some `y_pred`
#         is not strictly positive
#   5   - power >= 2, but some `y_true` or `y_pred` is not strictly positive
#   100 - `power` matched none of the ranges above
defnp check_tweedie_deviance_power(y_true, y_pred, power) do
  pred_positive = Nx.all(y_pred > 0)
  true_non_negative = Nx.all(y_true >= 0)
  true_positive = Nx.all(y_true > 0)

  cond do
    power < 0 ->
      if pred_positive, do: Nx.u8(1), else: Nx.u8(2)

    power == 0 ->
      Nx.u8(1)

    power >= 1 and power < 2 ->
      if true_non_negative and pred_positive, do: Nx.u8(1), else: Nx.u8(4)

    power >= 2 ->
      if true_positive and pred_positive, do: Nx.u8(1), else: Nx.u8(5)

    true ->
      Nx.u8(100)
  end
end

@doc """
Calculates the mean Poisson deviance of predictions
with respect to targets.

This is the mean Tweedie deviance with the power fixed to `1`.

## Examples

    iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32)
    iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32)
    iex> Scholar.Metrics.Regression.mean_poisson_deviance(y_true, y_pred)
    #Nx.Tensor<
      f32
      0.18411168456077576
    >
"""
defn mean_poisson_deviance(y_true, y_pred),
  do: mean_tweedie_deviance_n(y_true, y_pred, 1)

@doc """
Calculates the mean Gamma deviance of predictions
with respect to targets.

This is the mean Tweedie deviance with the power fixed to `2`.

## Examples

    iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32)
    iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32)
    iex> Scholar.Metrics.Regression.mean_gamma_deviance(y_true, y_pred)
    #Nx.Tensor<
      f32
      0.115888312458992
    >
"""
defn mean_gamma_deviance(y_true, y_pred),
  do: mean_tweedie_deviance_n(y_true, y_pred, 2)

@doc """
Calculates the $R^2$ score of predictions with respect to targets.

Expand Down
52 changes: 52 additions & 0 deletions test/scholar/metrics/regression_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,56 @@ defmodule Scholar.Metrics.RegressionTest do

alias Scholar.Metrics.Regression
doctest Regression

describe "mean_tweedie_deviance!/3" do
  # Each test pins the specific validation message, not just the shared
  # "mean Tweedie deviance" prefix, so a wrong validation branch is detected.

  test "raise when y_pred <= 0 and power < 0" do
    power = -1
    y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32)
    y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32)

    assert_raise RuntimeError,
                 ~r/mean Tweedie deviance with power=-1 can only be used on strictly positive y_pred$/,
                 fn ->
                   Regression.mean_tweedie_deviance!(y_true, y_pred, power)
                 end
  end

  test "raise when y_pred <= 0 and 1 <= power < 2" do
    power = 1
    y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32)
    y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32)

    assert_raise RuntimeError,
                 ~r/mean Tweedie deviance with power=1 can only be used on non-negative y_true and strictly positive y_pred$/,
                 fn ->
                   Regression.mean_tweedie_deviance!(y_true, y_pred, power)
                 end
  end

  test "raise when y_pred <= 0 and power >= 2" do
    power = 2
    y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32)
    y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32)

    assert_raise RuntimeError,
                 ~r/mean Tweedie deviance with power=2 can only be used on strictly positive y_true and strictly positive y_pred$/,
                 fn ->
                   Regression.mean_tweedie_deviance!(y_true, y_pred, power)
                 end
  end

  test "raise when y_true < 0 and 1 <= power < 2" do
    power = 1
    y_true = Nx.tensor([-1, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32)
    y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32)

    assert_raise RuntimeError,
                 ~r/mean Tweedie deviance with power=1 can only be used on non-negative y_true and strictly positive y_pred$/,
                 fn ->
                   Regression.mean_tweedie_deviance!(y_true, y_pred, power)
                 end
  end

  test "raise when y_true <= 0 and power >= 2" do
    power = 2
    y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32)
    y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32)

    assert_raise RuntimeError,
                 ~r/mean Tweedie deviance with power=2 can only be used on strictly positive y_true and strictly positive y_pred$/,
                 fn ->
                   Regression.mean_tweedie_deviance!(y_true, y_pred, power)
                 end
  end

  # 0 < power < 1 matches none of the validator's ranges; only the exception
  # type is asserted here.
  test "raise when 0 < power < 1" do
    power = 0.5
    y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32)
    y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32)

    assert_raise RuntimeError, fn ->
      Regression.mean_tweedie_deviance!(y_true, y_pred, power)
    end
  end
end
end
Loading