diff --git a/tests/test.py b/tests/test_basic.py similarity index 71% rename from tests/test.py rename to tests/test_basic.py index 873d8c4..f5c3650 100644 --- a/tests/test.py +++ b/tests/test_basic.py @@ -3,7 +3,7 @@ import numpy as np from numpy.random import default_rng -from sortedl1 import slope +from sortedl1 import Slope class TestBasicUse(unittest.TestCase): @@ -20,9 +20,11 @@ def test_simple_problem(self): lam = np.array([2, 1, 0.2]) alph = np.array([1.0]) - res = slope(x, y, lam, alph) + model = Slope(alph, lam) - np.testing.assert_array_almost_equal(res, 0) + res = model.fit(x, y).predict(x) + + np.testing.assert_array_almost_equal(res, np.array([0.0, 1.0])) if __name__ == "__main__":