Skip to content

Commit

Permalink
run cargo fmt --all && cargo clippy --all-features
Browse files Browse the repository at this point in the history
  • Loading branch information
azmyrajab committed Mar 22, 2024
1 parent a4da788 commit 045f96f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 18 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/publish_to_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
strategy:
matrix:
target: [x86_64, x86]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11"]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
Expand Down Expand Up @@ -59,7 +59,7 @@ jobs:
sudo apt-get install -y pkg-config libssl-dev libudev-dev libopenblas-dev
- uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: '3.11'
- name: Build wheels
uses: PyO3/maturin-action@v1
env:
Expand Down Expand Up @@ -92,7 +92,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: '3.11'
- name: Build wheels
uses: PyO3/maturin-action@v1
with:
Expand Down
6 changes: 1 addition & 5 deletions src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,7 @@ pub struct RLSKwargs {
}

fn convert_option_vec_to_array1(opt_vec: Option<Vec<f32>>) -> Option<Array1<f32>> {
if let Some(vec) = opt_vec {
Some(Array1::from(vec))
} else {
None
}
opt_vec.map(Array1::from)
}

#[polars_expr(output_type_func=list_float_dtype)]
Expand Down
18 changes: 8 additions & 10 deletions src/least_squares.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ pub fn solve_ols_qr(y: &Array1<f32>, x: &Array2<f32>) -> Array1<f32> {
/// Solves the normal equations: (X^T X) coefficients = X^T Y
fn solve_normal_equations(xtx: &Array2<f32>, xty: &Array1<f32>) -> Array1<f32> {
// attempt to solve via cholesky making use of X.T X being SPD
match xtx.solvec(&xty) {
match xtx.solvec(xty) {
Ok(coefficients) => coefficients,
Err(_) => {
// else fallback to QR decomposition
solve_ols_qr(&xty, &xtx)
solve_ols_qr(xty, xtx)
}
}
}
Expand Down Expand Up @@ -201,9 +201,7 @@ fn outer_product(u: &ArrayView1<f32>, v: &ArrayView1<f32>) -> Array2<f32> {
let v_reshaped = v.insert_axis(Axis(0));

// Compute the outer product using broadcasting and dot product
let outer_product = u_reshaped.dot(&v_reshaped);

outer_product
u_reshaped.dot(&v_reshaped)
}

fn inv_diag(c: &Array2<f32>) -> Array2<f32> {
Expand Down Expand Up @@ -235,7 +233,7 @@ pub fn woodbury_update(
) -> Array2<f32> {
// Check if c_is_diag is Some(true)
let inv_c = if let Some(true) = c_is_diag {
inv_diag(&c)
inv_diag(c)
} else {
c.inv().unwrap()
};
Expand All @@ -250,7 +248,7 @@ pub fn woodbury_update(
/// Function to update inv(X^TX) by x_update array of rank r using Woodbury Identity.
fn update_xtx_inv(xtx_inv: &Array2<f32>, x_update: &Array2<f32>) -> Array2<f32> {
// Reshape x_new and x_old for Woodbury update
let u = (&x_update.t()).to_owned(); // K x r
let u = x_update.t().to_owned(); // K x r
let v = u.t().to_owned(); // r x K
let c = Array2::eye(u.shape()[1]); // Identity matrix r x r

Expand Down Expand Up @@ -330,13 +328,13 @@ pub fn solve_rolling_ols(
let x_new = x.row(i);

// Add new contributions
xtx = xtx + &outer_product(&x_new, &x_new);
xty = xty + &x_new * y[i];
xtx += &outer_product(&x_new, &x_new);
xty = xty - &x_new * y[i];

// Subtract the previous contribution
if i > window_size {
let x_prev = x.row(i_start);
xtx = xtx - &outer_product(&x_prev, &x_prev);
xtx -= &outer_product(&x_prev, &x_prev);
xty = xty - &x_prev * y[i_start];
}

Expand Down

0 comments on commit 045f96f

Please sign in to comment.