Skip to content

Commit

Permalink
feat: add ml.metrics.mean_squared_error (#559)
Browse files Browse the repository at this point in the history
* feat: add ml.metrics.mean_squared_error

* fix docs

* fix docs
  • Loading branch information
GarrettWu authored Apr 3, 2024
1 parent 90bcec5 commit 853c25e
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 69 deletions.
2 changes: 2 additions & 0 deletions bigframes/ml/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
auc,
confusion_matrix,
f1_score,
mean_squared_error,
precision_score,
r2_score,
recall_score,
Expand All @@ -35,5 +36,6 @@
"confusion_matrix",
"precision_score",
"f1_score",
"mean_squared_error",
"pairwise",
]
14 changes: 14 additions & 0 deletions bigframes/ml/metrics/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,17 @@ def f1_score(


f1_score.__doc__ = inspect.getdoc(vendored_metrics_classification.f1_score)


def mean_squared_error(
y_true: Union[bpd.DataFrame, bpd.Series],
y_pred: Union[bpd.DataFrame, bpd.Series],
) -> float:
y_true_series, y_pred_series = utils.convert_to_series(y_true, y_pred)

return (y_pred_series - y_true_series).pow(2).sum() / len(y_true_series)


mean_squared_error.__doc__ = inspect.getdoc(
vendored_metrics_regression.mean_squared_error
)
Loading

0 comments on commit 853c25e

Please sign in to comment.