add test to check if plot works
patrickleonardy committed Nov 3, 2023
1 parent a80291a commit 8025f7c
Showing 3 changed files with 110 additions and 2 deletions.
36 changes: 34 additions & 2 deletions tests/evaluation/test_evaluation.py
@@ -10,8 +10,8 @@ def mock_data():
     d = {'variable': ['education', 'education', 'education', 'education'],
          'label': ['1st-4th', '5th-6th', '7th-8th', '9th'],
          'pop_size': [0.002, 0.004, 0.009, 0.019],
-         'avg_incidence': [0.23, 0.23, 0.23, 0.23],
-         'incidence': [0.047, 0.0434, 0.054, 0.069]}
+         'global_avg_target': [0.23, 0.23, 0.23, 0.23],
+         'avg_target': [0.047, 0.0434, 0.054, 0.069]}
     return pd.DataFrame(d)

@@ -22,6 +22,11 @@ def mock_preds(n, seed = 505):

    return y_true, y_pred






class TestEvaluation:

    def test_plot_incidence_with_unsupported_model_type(self):
@@ -127,3 +132,30 @@ def test_fit_regression(self):
        for metric in ["R2", "MAE", "MSE", "RMSE"]:
            assert evaluator.scalar_metrics[metric] is not None
        assert evaluator.qq is not None


class TestClassificationEvaluator:
    y_true, y_pred = mock_preds(50)
    y_true = (y_true > 0.5).astype(int)  # convert to 0-1 labels

    evaluator = ClassificationEvaluator(n_bins=5)
    evaluator.fit(y_true, y_pred)

    def test_plot_roc_curve(self):
        self.evaluator.plot_roc_curve()

    def test_plot_confusion_matrix(self):
        self.evaluator.plot_confusion_matrix()

    def test_plot_cumulative_response_curve(self):
        self.evaluator.plot_cumulative_response_curve()

    def test_plot_lift_curve(self):
        self.evaluator.plot_lift_curve()

    def test_plot_cumulative_gains(self):
        self.evaluator.plot_cumulative_gains()


class TestRegressionEvaluator:
    y_true, y_pred = mock_preds(50)
    evaluator = RegressionEvaluator()
    evaluator.fit(y_true, y_pred)

    def test_plot_predictions(self):
        self.evaluator.plot_predictions()

    def test_plot_qq(self):
        self.evaluator.plot_qq()
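
Outside the test harness, the class-level setup above corresponds to roughly the following standalone usage. This is a minimal sketch based only on what the tests exercise (a fit(y_true, y_pred) call followed by the plot_* methods); the import path is assumed to match the plot_incidence import in the new files below, and the random scores are a stand-in for mock_preds, whose body is collapsed in this diff.

import numpy as np
from cobra.evaluation import ClassificationEvaluator, RegressionEvaluator  # assumed import path

# Stand-in for mock_preds: paired scores in [0, 1], seeded for reproducibility.
rng = np.random.default_rng(505)
y_pred = rng.uniform(size=50)
y_true = (rng.uniform(size=50) > 0.5).astype(int)  # 0/1 labels, as in the fixture

clf = ClassificationEvaluator(n_bins=5)
clf.fit(y_true, y_pred)   # computes the metrics the plots draw from
clf.plot_roc_curve()      # any of the plot_* methods tested above works the same way

reg = RegressionEvaluator()
reg.fit(rng.normal(size=50), rng.normal(size=50))  # continuous target and predictions
reg.plot_qq()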
19 changes: 19 additions & 0 deletions tests/evaluation/test_pig_tables.py
@@ -0,0 +1,19 @@
import pytest
import pandas as pd

from cobra.evaluation import plot_incidence


def mock_data():
    d = {'variable': ['education', 'education', 'education', 'education'],
         'label': ['1st-4th', '5th-6th', '7th-8th', '9th'],
         'pop_size': [0.002, 0.004, 0.009, 0.019],
         'global_avg_target': [0.23, 0.23, 0.23, 0.23],
         'avg_target': [0.047, 0.0434, 0.054, 0.069]}
    return pd.DataFrame(d)


def test_plot_incidence():
    plot_incidence(pig_tables=mock_data(),
                   variable="education",
                   model_type="regression")

57 changes: 57 additions & 0 deletions tests/evaluation/test_plotting_utils.py
@@ -0,0 +1,57 @@
from pandas import DataFrame

from cobra.evaluation import (plot_univariate_predictor_quality,
                              plot_correlation_matrix,
                              plot_performance_curves,
                              plot_variable_importance)


def mock_df_rmse() -> DataFrame:
    return DataFrame(
        {'predictor': {0: 'weight', 1: 'displacement', 2: 'horsepower',
                       3: 'cylinders', 4: 'origin', 5: 'model_year',
                       6: 'name', 7: 'acceleration'},
         'RMSE train': {0: 4.225088318760745, 1: 4.403878881676005,
                        2: 4.3343326307873875, 3: 4.901531871261906,
                        4: 6.6435969708016955, 5: 6.318271823003904,
                        6: 1.4537996193882199, 7: 6.631180878197439},
         'RMSE selection': {0: 4.006855931973032, 1: 4.146696570151399,
                            2: 4.321365764687869, 3: 4.466259266291863,
                            4: 5.833138420191894, 5: 5.979795941821068,
                            6: 6.99641113758452, 7: 7.449190759856361},
         'preselection': {0: True, 1: True, 2: True, 3: True, 4: True,
                          5: True, 6: True, 7: True}}
    )


def mock_df_corr() -> DataFrame:
    return DataFrame({
        'cylinders': {'cylinders': 1.0, 'weight': 0.8767772796304492, 'horsepower': 0.8124872187173973},
        'weight': {'cylinders': 0.8767772796304492, 'weight': 1.0, 'horsepower': 0.8786843186591881},
        'horsepower': {'cylinders': 0.8124872187173973, 'weight': 0.8786843186591881, 'horsepower': 1.0}})


def mock_performances() -> DataFrame:
    return DataFrame({
        'predictors': {0: ['weight_enc'], 1: ['weight_enc', 'horsepower_enc'], 2: ['horsepower_enc', 'weight_enc', 'cylinders_enc']},
        'last_added_predictor': {0: 'weight_enc', 1: 'horsepower_enc', 2: 'cylinders_enc'},
        'train_performance': {0: 4.225088318760745, 1: 3.92118718828259, 2: 3.8929681840552495},
        'selection_performance': {0: 4.006855931973032, 1: 3.630079770314085, 2: 3.531305702221386},
        'validation_performance': {0: 4.348180862267973, 1: 4.089638309577036, 2: 3.9989641017455995},
        'model_type': {0: 'regression', 1: 'regression', 2: 'regression'}
    })


def mock_variable_importance() -> DataFrame:
    return DataFrame({
        'predictor': {0: 'weight', 1: 'horsepower', 2: 'model_year', 3: 'origin'},
        'importance': {0: 0.8921354566046729, 1: 0.864633073581914, 2: 0.694399044392948, 3: 0.6442243718390968}
    })


def test_plot_univariate_predictor_quality():
    plot_univariate_predictor_quality(mock_df_rmse())


def test_plot_correlation_matrix():
    plot_correlation_matrix(mock_df_corr())


def test_plot_performance_curves():
    plot_performance_curves(mock_performances())


def test_plot_variable_importance():
    plot_variable_importance(mock_variable_importance())
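
All three files are plain pytest modules: each test passes as long as the plot call completes without raising. Assuming pytest is installed, the new smoke tests can be run from the repository root with, for example:

pytest tests/evaluation -q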
