docs(blog): classification metrics on the backend (#10501)

## Description of changes

Adding a blog post breaking down how to compute binary classification metrics with
Ibis. I did a fair amount of background explanation on these models and these metrics
because many Ibis users may not be as familiar with these topics, but we can scale
that back if needed and get more to the point.

---------

Co-authored-by: Deepyaman Datta <[email protected]>
Co-authored-by: Gil Forsyth <[email protected]>
1 parent 6e693b7 · commit aafb30f
Showing 3 changed files with 232 additions and 0 deletions.
docs/posts/classification-metrics-on-the-backend/index.qmd (232 additions, 0 deletions)

---
title: "Classification metrics on the backend"
author: "Tyler White"
date: "2024-12-05"
image: thumbnail.png
categories:
    - blog
    - machine learning
    - portability
---

A review of binary classification models, metrics used to evaluate them, and
corresponding metric calculations with Ibis.

We're going to explore common classification metrics such as accuracy, precision,
recall, and F1 score, demonstrating how to compute each one using Ibis. In this
example, we'll use DuckDB, the default Ibis backend, but we could use this same code
to execute against another backend such as Postgres or Snowflake. This capability is
useful because it offers an easy and performant way to evaluate model performance
without extracting data from the source system.
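
To make the portability point concrete, here's a minimal, non-executed sketch of what
swapping the connection might look like. The connection parameters and the
`predictions` table are placeholders, not something the post sets up (the post builds
its sample data with `ibis.memtable` instead).

```{python}
#| eval: false
import ibis

# Default backend: an in-memory DuckDB connection.
con = ibis.duckdb.connect()

# The same expression code would run unchanged against another backend, for example
# Postgres. These connection parameters are hypothetical.
# con = ibis.postgres.connect(
#     host="localhost", user="me", password="...", database="ml"
# )

# Assumes a table with `actual` and `prediction` columns already exists.
t = con.table("predictions")
```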

## Classification models

In machine learning, classification entails categorizing data into different groups.
Binary classification, which is what we'll be covering in this post, specifically
involves sorting data into only two distinct groups. For example, a model could
predict whether or not an email is spam.

## Model evaluation

It's important to validate the performance of the model to ensure it makes correct
predictions consistently and doesn't only perform well on the data it was trained on.
These metrics help us understand not just the errors the model makes, but also the
types of errors. For example, we might want to know if the model is more likely to
predict a positive outcome when the actual outcome is negative.

The easiest way to break down how this works is to look at a confusion matrix.

### Confusion matrix

A confusion matrix is a table used to describe the performance of a classification
model on a set of data for which the true values are known. As binary classification
only involves two categories, the confusion matrix only contains true positives, false
positives, false negatives, and true negatives.

![](confusion_matrix.png)

Here's a breakdown of the terms with examples.

True Positives (TP)
: Correctly predicted positive examples.

We guessed it was a spam email, and it was. This email is going straight to the junk
folder.

False Positives (FP)
: Incorrectly predicted as positive.

We guessed it was a spam email, but it actually wasn't. Hopefully, the recipient
doesn't miss anything important as this email is going to the junk folder.

False Negatives (FN)
: Incorrectly predicted as negative.

We didn't guess it was a spam email, but it really was. Hopefully, the recipient
doesn't click any links!

True Negatives (TN)
: Correctly predicted negative examples.

We guessed it was not a spam email, and it actually was not. The recipient can read
this email as intended.

### Building a confusion matrix

#### Sample dataset

Let's create a sample dataset that includes twelve rows with two columns: `actual` and
`prediction`. The `actual` column contains the true values, and the `prediction` column
contains the model's predictions.

```{python}
from ibis.interactive import *

t = ibis.memtable(
    {
        "id": range(1, 13),
        "actual": [1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1],
        "prediction": [1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1],
    }
)
t
```

We can use the `case` function to create a new column that categorizes the outcomes.

```{python}
case_expr = (
    ibis.case()
    .when((_.actual == 0) & (_.prediction == 0), "TN")
    .when((_.actual == 0) & (_.prediction == 1), "FP")
    .when((_.actual == 1) & (_.prediction == 0), "FN")
    .when((_.actual == 1) & (_.prediction == 1), "TP")
    .end()
)
t = t.mutate(outcome=case_expr)
t
```

To create the confusion matrix, we'll group by the outcome, count the occurrences, and
use `pivot_wider`. Widening our data makes it possible to perform column-wise
operations on the table expression for metric calculations.

```{python}
cm = (
    t.group_by("outcome")
    .agg(counted=_.count())
    .pivot_wider(names_from="outcome", values_from="counted")
    .select("TP", "FP", "FN", "TN")
)
cm
```
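
As a quick aside (not part of the original walkthrough), you can inspect the SQL that
Ibis generates for this expression, which is what ultimately runs on the backend:

```{python}
#| eval: false
# Show the query Ibis compiles for the confusion-matrix expression.
ibis.to_sql(cm)
```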

We can plot the confusion matrix to visualize the results.

```{python}
import matplotlib.pyplot as plt
import seaborn as sns

data = cm.to_pyarrow().to_pydict()
plt.figure(figsize=(6, 4))
sns.heatmap(
    # Rows are actual values, columns are predicted values.
    [[data["TP"][0], data["FN"][0]], [data["FP"][0], data["TN"][0]]],
    annot=True,
    fmt="d",
    cmap="Blues",
    cbar=False,
    xticklabels=["Predicted Positive", "Predicted Negative"],
    yticklabels=["Actual Positive", "Actual Negative"],
)
plt.title("Confusion Matrix")
plt.show()
```

Now that we've built a confusion matrix, we're able to more easily calculate a few
common classification metrics.

### Metrics

Here are the metrics we'll calculate as well as a brief description of each; the
formulas in terms of the confusion matrix counts follow the list.

Accuracy
: The proportion of correct predictions out of all predictions made. This measures the
overall effectiveness of the model across all classes.

Precision
: The proportion of true positive predictions out of all positive predictions made.
This tells us how many of the predicted positives were actually correct.

Recall
: The proportion of true positive predictions out of all actual positive examples. This
measures how well the model identifies all actual positives.

F1 Score
: A metric that combines precision and recall into a single score by taking their
harmonic mean. This balances the trade-off between precision and recall, making it
especially useful for imbalanced datasets.
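
Written in terms of the confusion matrix counts, the standard formulas are:

$$
\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}
\qquad
\text{Precision} = \frac{TP}{TP + FP}
$$

$$
\text{Recall} = \frac{TP}{TP + FN}
\qquad
F_1 = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}
    = \frac{2\,TP}{2\,TP + FP + FN}
$$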

We can calculate these metrics using the columns from the confusion matrix we created
earlier.

```{python}
accuracy_expr = (_.TP + _.TN) / (_.TP + _.TN + _.FP + _.FN)
precision_expr = _.TP / (_.TP + _.FP)
recall_expr = _.TP / (_.TP + _.FN)
f1_score_expr = 2 * (precision_expr * recall_expr) / (precision_expr + recall_expr)
metrics = cm.select(
    accuracy=accuracy_expr,
    precision=precision_expr,
    recall=recall_expr,
    f1_score=f1_score_expr,
)
metrics
```
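
If you want to sanity-check these numbers, one option (not used in the original post,
and assuming scikit-learn is installed) is to pull the two columns into pandas and
compare against scikit-learn's implementations:

```{python}
#| eval: false
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Materialize just the two columns needed for the comparison.
df = t.select("actual", "prediction").to_pandas()
print(
    f"accuracy={accuracy_score(df.actual, df.prediction)}",
    f"precision={precision_score(df.actual, df.prediction)}",
    f"recall={recall_score(df.actual, df.prediction)}",
    f"f1_score={f1_score(df.actual, df.prediction)}",
    sep="\n",
)
```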

## A more efficient approach

In the illustrative example above, we used a case expression and pivoted the data to
demonstrate where the values would fall in the confusion matrix before performing our
metric calculations using the pivoted data. We can actually skip this step using column
aggregation.

```{python}
tp = (t.actual * t.prediction).sum()
fp = t.prediction.sum() - tp
fn = t.actual.sum() - tp
tn = t.actual.count() - tp - fp - fn
accuracy_expr = (t.actual == t.prediction).mean()
precision_expr = tp / t.prediction.sum()
recall_expr = tp / t.actual.sum()
f1_score_expr = 2 * tp / (t.actual.sum() + t.prediction.sum())
print(
    f"accuracy={accuracy_expr.to_pyarrow().as_py()}",
    f"precision={precision_expr.to_pyarrow().as_py()}",
    f"recall={recall_expr.to_pyarrow().as_py()}",
    f"f1_score={f1_score_expr.to_pyarrow().as_py()}",
    sep="\n",
)
```
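
As a small extension of this approach (a sketch, not from the original post), the same
reductions can be bundled into a single `agg` call so the metrics come back as one row:

```{python}
#| eval: false
tp = (t.actual * t.prediction).sum()

# One aggregation expression holding all four metrics.
metrics = t.agg(
    accuracy=(t.actual == t.prediction).mean(),
    precision=tp / t.prediction.sum(),
    recall=tp / t.actual.sum(),
    f1_score=2 * tp / (t.actual.sum() + t.prediction.sum()),
)
metrics
```

This keeps everything in one table expression, so the backend only has to evaluate a
single aggregation query instead of four separate scalar queries.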

## Conclusion

By pushing the computation down to the backend, performance scales with the power of
whatever backend we're connected to. This capability allows us to easily scale to
different backends without modifying any code.

We hope you give this a try and let us know how it goes. If you have any questions or
feedback, please reach out to us on [GitHub](https://github.com/ibis-project) or
[Zulip](https://ibis-project.zulipchat.com/).