Skip to content

Commit

Permalink
Update pyqtorch/utils.py if finite_diff
Browse files Browse the repository at this point in the history
Co-authored-by: Doomsk <[email protected]>
  • Loading branch information
chMoussa and Doomsk authored Aug 19, 2024
1 parent efc762e commit 9574759
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions pyqtorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,17 @@ def finitediff(
if len(derivative_indices) > 3 or len(set(derivative_indices)) > 1:
di = derivative_indices[1:]
return (finitediff(f, x + ev, di) - finitediff(f, x - ev, di)) * _eps / 2
elif len(derivative_indices) == 3:
if len(derivative_indices) == 3:
return (
(f(x + 2 * ev) - 2 * f(x + ev) + 2 * f(x - ev) - f(x - 2 * ev))
* _eps**3
/ 2
)
elif len(derivative_indices) == 2:
if len(derivative_indices) == 2:
return (f(x + ev) + f(x - ev) - 2 * f(x)) * _eps**2
elif len(derivative_indices) == 1:
if len(derivative_indices) == 1:
return (f(x + ev) - f(x - ev)) * _eps / 2
else:
raise ValueError(
raise ValueError(
"If you see this error there is a bug in the `finitediff` function."
)

Expand Down

0 comments on commit 9574759

Please sign in to comment.