diff --git a/.github/workflows/publish_to_pypi.yml b/.github/workflows/publish_to_pypi.yml index db96ef5..47a11df 100644 --- a/.github/workflows/publish_to_pypi.yml +++ b/.github/workflows/publish_to_pypi.yml @@ -31,7 +31,7 @@ jobs: strategy: matrix: target: [x86_64, x86] - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11"] steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 @@ -59,7 +59,7 @@ jobs: sudo apt-get install -y pkg-config libssl-dev libudev-dev libopenblas-dev - uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' - name: Build wheels uses: PyO3/maturin-action@v1 env: @@ -92,7 +92,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' - name: Build wheels uses: PyO3/maturin-action@v1 with: diff --git a/src/expressions.rs b/src/expressions.rs index 46cd189..6f62d8a 100644 --- a/src/expressions.rs +++ b/src/expressions.rs @@ -112,11 +112,7 @@ pub struct RLSKwargs { } fn convert_option_vec_to_array1(opt_vec: Option>) -> Option> { - if let Some(vec) = opt_vec { - Some(Array1::from(vec)) - } else { - None - } + opt_vec.map(Array1::from) } #[polars_expr(output_type_func=list_float_dtype)] diff --git a/src/least_squares.rs b/src/least_squares.rs index a300cb0..db1b734 100644 --- a/src/least_squares.rs +++ b/src/least_squares.rs @@ -21,11 +21,11 @@ pub fn solve_ols_qr(y: &Array1, x: &Array2) -> Array1 { /// Solves the normal equations: (X^T X) coefficients = X^T Y fn solve_normal_equations(xtx: &Array2, xty: &Array1) -> Array1 { // attempt to solve via cholesky making use of X.T X being SPD - match xtx.solvec(&xty) { + match xtx.solvec(xty) { Ok(coefficients) => coefficients, Err(_) => { // else fallback to QR decomposition - solve_ols_qr(&xty, &xtx) + solve_ols_qr(xty, xtx) } } } @@ -201,9 +201,7 @@ fn outer_product(u: &ArrayView1, v: &ArrayView1) -> Array2 { let v_reshaped = v.insert_axis(Axis(0)); // Compute the outer product using broadcasting and dot product - let outer_product = u_reshaped.dot(&v_reshaped); - - outer_product + u_reshaped.dot(&v_reshaped) } fn inv_diag(c: &Array2) -> Array2 { @@ -235,7 +233,7 @@ pub fn woodbury_update( ) -> Array2 { // Check if c_is_diag is Some(true) let inv_c = if let Some(true) = c_is_diag { - inv_diag(&c) + inv_diag(c) } else { c.inv().unwrap() }; @@ -250,7 +248,7 @@ pub fn woodbury_update( /// Function to update inv(X^TX) by x_update array of rank r using Woodbury Identity. fn update_xtx_inv(xtx_inv: &Array2, x_update: &Array2) -> Array2 { // Reshape x_new and x_old for Woodbury update - let u = (&x_update.t()).to_owned(); // K x r + let u = x_update.t().to_owned(); // K x r let v = u.t().to_owned(); // r x K let c = Array2::eye(u.shape()[1]); // Identity matrix r x r @@ -330,13 +328,13 @@ pub fn solve_rolling_ols( let x_new = x.row(i); // Add new contributions - xtx = xtx + &outer_product(&x_new, &x_new); - xty = xty + &x_new * y[i]; + xtx += &outer_product(&x_new, &x_new); + xty = xty - &x_new * y[i]; // Subtract the previous contribution if i > window_size { let x_prev = x.row(i_start); - xtx = xtx - &outer_product(&x_prev, &x_prev); + xtx -= &outer_product(&x_prev, &x_prev); xty = xty - &x_prev * y[i_start]; }