From d4635165d28caa8d1c197ea01075be7956bd07e0 Mon Sep 17 00:00:00 2001 From: kasaur <1703543+ksaur@users.noreply.github.com> Date: Thu, 30 May 2024 18:28:23 +0000 Subject: [PATCH 1/2] Fixing the decomposition bug based on skl PR #28612 --- .../ml/operator_converters/_decomposition_implementations.py | 1 - 1 file changed, 1 deletion(-) diff --git a/hummingbird/ml/operator_converters/_decomposition_implementations.py b/hummingbird/ml/operator_converters/_decomposition_implementations.py index 6a03d9316..9325a43f3 100644 --- a/hummingbird/ml/operator_converters/_decomposition_implementations.py +++ b/hummingbird/ml/operator_converters/_decomposition_implementations.py @@ -96,6 +96,5 @@ def __init__(self, logical_operator, x_mean, x_std, y_mean, coefficients, device def forward(self, x): x -= self.x_mean - x /= self.x_std y_pred = torch.mm(x, self.coefficients).float() return y_pred + self.y_mean From 2d3d40202a3c7aa9d2bdbd937209846a2b84f2c8 Mon Sep 17 00:00:00 2001 From: kasaur <1703543+ksaur@users.noreply.github.com> Date: Thu, 30 May 2024 18:49:09 +0000 Subject: [PATCH 2/2] now using skl>=1.5.0 --- .github/workflows/pythonapp.yml | 6 ------ setup.py | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 63eec1e4f..2142d5c80 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -54,12 +54,6 @@ jobs: - name: Install basic dependencies run: | python -m pip install -e .[tests] -f https://download.pytorch.org/whl/torch_stable.html - - name: Test with older SKLearn on Linux with py3.9 to check backward compatibility - if: ${{ matrix.python-version == '3.9' && startsWith(matrix.os, 'ubuntu') == true }} - run: python -m pip install scikit-learn==1.2.1 - - name: Pin SKLearn<1.5.0 on not Linux with py3.9 - if: ${{ !(matrix.python-version == '3.9' && startsWith(matrix.os, 'ubuntu') == true) }} - run: python -m pip install "scikit-learn<1.5.0" - name: Run basic tests without extra run: pytest - name: Coverage on basic tests without extra diff --git a/setup.py b/setup.py index dca5a7a58..5d832d017 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ "numpy>=1.15", "onnxconverter-common>=1.6.0", "scipy", - "scikit-learn", + "scikit-learn>=1.5.0", "torch>1.7.0", "psutil", "dill",