From 0c2764bafba3663a1ee054e63e2f307e8b7ffe21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bojan=20Karla=C5=A1?= Date: Thu, 17 Aug 2023 02:41:16 -0400 Subject: [PATCH] Changes to Shapley tests. --- tests/importance/test_shapley.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/importance/test_shapley.py b/tests/importance/test_shapley.py index f61910f..ee81981 100644 --- a/tests/importance/test_shapley.py +++ b/tests/importance/test_shapley.py @@ -98,9 +98,9 @@ def test_comparative_1(trainsize: int, testsize: int, method: ImportanceMethod): # @pytest.mark.parametrize("n_samples_train", [100, 500, 1000, 5000, 10000]) -@pytest.mark.parametrize("n_samples_train", [500, 1000, 5000, 10000]) +@pytest.mark.parametrize("n_samples_train", [1024, 2048, 4096, 8192, 16384, 32768, 65536]) # @pytest.mark.parametrize("n_samples_test", [10, 50, 100]) -@pytest.mark.parametrize("n_samples_test", [1000]) +@pytest.mark.parametrize("n_samples_test", [1024]) def test_neighbor_benchmark_1(n_samples_train: int, n_samples_test: int, benchmark): X, y = make_classification( n_samples=n_samples_train + n_samples_test, @@ -122,8 +122,8 @@ def test_neighbor_benchmark_1(n_samples_train: int, n_samples_test: int, benchma if __name__ == "__main__": - n_samples_train = 10000 - n_samples_test = 500 + n_samples_train = 100000 + n_samples_test = 1000 X, y = make_classification( n_samples=n_samples_train + n_samples_test, @@ -137,10 +137,10 @@ def test_neighbor_benchmark_1(n_samples_train: int, n_samples_test: int, benchma ) X, X_test, y, y_test = train_test_split(X, y, train_size=n_samples_train, test_size=n_samples_test, random_state=7) - provenance = Provenance(units=1000) - provenance = provenance.fork(size=10) + provenance = Provenance(units=n_samples_train) + # provenance = provenance.fork(size=10) utility = SklearnModelAccuracy(KNeighborsClassifier(n_neighbors=1)) importance = ShapleyImportance(method=ImportanceMethod.NEIGHBOR, utility=utility) importance.fit(X, y, provenance=provenance) - importance.score(X_test, y_test) + print(importance.score(X_test, y_test))