Skip to content

Commit

Permalink
add test time init + better finite_diff writing
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles MOUSSA committed Aug 19, 2024
1 parent 9574759 commit 3ede788
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
7 changes: 6 additions & 1 deletion pyqtorch/hamiltonians/evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,12 @@ def __init__(
)
super().__init__(generator)
self._qubit_support = qubit_support # type: ignore
self.time = time

if isinstance(time, str) or isinstance(time, Tensor):
self.time = time
else:
raise ValueError("time should be passed as str or tensor.")

logger.debug("Hamiltonian Evolution initialized")
if logger.isEnabledFor(logging.DEBUG):
# When Debugging let's add logging and NVTX markers
Expand Down
25 changes: 14 additions & 11 deletions pyqtorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,29 +247,32 @@ def finitediff(
eps = torch.finfo(x.dtype).eps ** (1 / (2 + order))

# compute derivative direction vector(s)
eps = torch.as_tensor(eps, dtype=x.dtype)
_eps = 1 / eps # type: ignore[operator]
ev = torch.zeros_like(x)
delta = torch.zeros_like(x)
i = derivative_indices[0]
ev[:, i] += eps
delta[:, i] += torch.as_tensor(eps, dtype=x.dtype)
denominator = 1 / delta[:, i]

# recursive finite differencing for higher order than 3 / mixed derivatives
if len(derivative_indices) > 3 or len(set(derivative_indices)) > 1:
di = derivative_indices[1:]
return (finitediff(f, x + ev, di) - finitediff(f, x - ev, di)) * _eps / 2
return (
(finitediff(f, x + delta, di) - finitediff(f, x - delta, di))
* denominator
/ 2
)
if len(derivative_indices) == 3:
return (
(f(x + 2 * ev) - 2 * f(x + ev) + 2 * f(x - ev) - f(x - 2 * ev))
* _eps**3
(f(x + 2 * delta) - 2 * f(x + delta) + 2 * f(x - delta) - f(x - 2 * delta))
* denominator**3
/ 2
)
if len(derivative_indices) == 2:
return (f(x + ev) + f(x - ev) - 2 * f(x)) * _eps**2
return (f(x + delta) + f(x - delta) - 2 * f(x)) * denominator**2
if len(derivative_indices) == 1:
return (f(x + ev) - f(x - ev)) * _eps / 2
return (f(x + delta) - f(x - delta)) * denominator / 2
raise ValueError(
"If you see this error there is a bug in the `finitediff` function."
)
"If you see this error there is a bug in the `finitediff` function."
)


def product_state(
Expand Down

0 comments on commit 3ede788

Please sign in to comment.