Skip to content

Commit

Permalink
Fix python defined residuals
Browse files Browse the repository at this point in the history
  • Loading branch information
edyounis committed Nov 18, 2023
1 parent 3aaacf7 commit 7bc7dec
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/python/minimizers/residual_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,14 @@ impl ResidualFn for PyResidualFn {
let py = gil.python();
let parameters = PyArray1::from_slice(py, params);
let args = PyTuple::new(py, &[parameters]);
match self.cost_fn.call_method1(py, "get_cost", args) {
match self.cost_fn.call_method1(py, "get_residuals", args) {
Ok(val) => val
.extract::<Vec<f64>>(py)
.expect("Return type of get_cost was not a sequence of floats."),
Err(..) => panic!("Failed to call 'get_cost' on passed ResidualFunction."), // TODO: make a Python exception?
.expect("Return type of get_residuals was not a sequence of floats."),
Err(e) => {
println!("{:?}, {:?}, {:?}", e.get_type(py), e.value(py), e.traceback(py));
panic!("Failed to call 'get_residuals' on passed ResidualFunction."); // TODO: make a Python exception?
},
}
}
}
Expand Down

0 comments on commit 7bc7dec

Please sign in to comment.