Skip to content

Commit

Permalink
[WB-6749] Add yea test demonstrating sklearn fix (#2826)
Browse files Browse the repository at this point in the history
  • Loading branch information
raubitsj authored Oct 26, 2021
1 parent 391d414 commit 6530fa5
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
2 changes: 2 additions & 0 deletions functional_tests/lightning/train-ddp.yea
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
id: 0.lightning.ddp
plugin:
  - wandb
tag:
  shard: service
command:
  program: train-ddp.py
depend:
Expand Down
66 changes: 66 additions & 0 deletions functional_tests/sklearn/01-plot-calibration-curve-nonbinary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#!/usr/bin/env python
"""Demonstrate non-binary plot calibration curve failure
Reproduction for WB-6749.
---
id: 0.sklearn.01-plot-calibration-curve-nonbinary
plugin:
  - wandb
depend:
  files:
    - file: wine.csv
      source: https://raw.githubusercontent.com/wandb/examples/master/examples/data/wine.csv
assert:
  - :wandb:runs_len: 1
  - :wandb:runs[0][exitcode]: 1
  - :yea:exit: 1
  - :op:contains_regex:
    - :wandb:runs[0][output][stderr]
    - This function only supports binary classification at the moment and therefore expects labels to be binary
  - :op:contains:
    - :wandb:runs[0][telemetry][1] # imports before
    - 5 # sklearn
  - :op:contains:
    - :wandb:runs[0][telemetry][2] # imports after
    - 5 # sklearn
"""

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import wandb

# wine.csv is listed under `depend: files:` in the module docstring, so the
# yea test harness is expected to fetch it before this script runs; the
# original notebook-style download is kept below for reference only.
# data_url = "https://raw.githubusercontent.com/wandb/examples/master/examples/data/wine.csv"
# !wget {data_url} -O "wine.csv"

# Load data: "quality" is the target column, every other column is a feature.
wine_quality = pd.read_csv("wine.csv")
y = wine_quality["quality"]
y = y.values
X = wine_quality.drop(["quality"], axis=1)
X = X.values
# NOTE(review): taken from the full frame, so this still includes "quality";
# it is not referenced again in this script.
feature_names = wine_quality.columns

# Hold out 20% of the rows. No random_state is passed, so the split is
# presumably different on every run.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# NOTE(review): not referenced again in this script.
labels = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']

# Train model, get predictions
model = RandomForestClassifier()
model.fit(X_train, y_train)
# NOTE(review): y_pred, y_probas, importances and indices are computed but not
# used below; presumably carried over from the example this reproduction is
# based on. Left in place so the script matches the reported scenario.
y_pred = model.predict(X_test)
y_probas = model.predict_proba(X_test)
importances = model.feature_importances_
# Feature indices ordered from most to least important.
indices = np.argsort(importances)[::-1]

# Feature count of the fitted model *before* the wandb call; compare with the
# second print below.
# NOTE(review): newer scikit-learn releases deprecate `n_features_` in favour
# of `n_features_in_` -- confirm against the version pinned for this test.
print(model.n_features_)

run = wandb.init(project='my-scikit-integration')

# The call under test (WB-6749): per the module docstring this exercises the
# non-binary case, and the yea assertions expect the "only supports binary
# classification" message on stderr.
wandb.sklearn.plot_calibration_curve(model, X_train, y_train, 'RandomForestClassifier')

# Feature count *after* the wandb call; presumably printed to show whether the
# call altered the caller's fitted model -- TODO confirm against WB-6749.
print(model.n_features_)

# Predict again on the original training features; presumably checks that the
# model still accepts its original input after the wandb call -- TODO confirm.
outs = model.predict(X_train)

0 comments on commit 6530fa5

Please sign in to comment.