Skip to content

Commit

Permalink
Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alexscordellis committed Aug 31, 2024
1 parent 0204761 commit 6762e03
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pub fn compute_feature_metrics(
// Convert xtx_reg to a faer matrix
let xtx_faer = xtx_reg.view().into_faer();

let mut nans = Array::zeros(features.len());
let mut nans = Array::zeros(features.ncols());
nans.fill(f64::NAN);

// Compute X^T y
Expand Down
25 changes: 25 additions & 0 deletions tests/test_ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,31 @@ def test_least_squares_statistics():
assert np.allclose(df_stats["p_values"], res.pvalues)


def test_least_squares_statistics_fit_fails():
df = _make_data()
statistics = df.with_columns(pl.lit(0.0).alias("x0")).select(
pl.col("y").least_squares.ols(cs.starts_with("x"), mode="statistics", add_intercept=True)
)
statistics = statistics.unnest("statistics")

assert np.isnan(statistics["r2"].item())
assert np.isnan(statistics["mse"].item())

df_stats = (
statistics.explode(
["feature_names", "coefficients", "standard_errors", "t_values", "p_values"]
)
.to_pandas()
.set_index("feature_names")
.rename(index={"const": "Intercept"})
)

assert np.all(~np.isfinite(df_stats["coefficients"]))
assert np.all(~np.isfinite(df_stats["standard_errors"]))
assert np.all(~np.isfinite(df_stats["t_values"]))
assert np.all(~np.isfinite(df_stats["p_values"]))


def test_predict_formula():
df = _make_data()
df = (
Expand Down

0 comments on commit 6762e03

Please sign in to comment.