Skip to content

Commit

Permalink
fixed failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vytautas-a committed Sep 24, 2024
1 parent 8439656 commit 866e125
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
30 changes: 17 additions & 13 deletions pyqtorch/composite/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch.nn import Module, ModuleList, ParameterDict

from pyqtorch.apply import apply_operator
from pyqtorch.embed import Embedding
from pyqtorch.embed import ConcretizedCallable, Embedding
from pyqtorch.matrices import add_batch_dim
from pyqtorch.primitives import CNOT, RX, RY, Parametric, Primitive
from pyqtorch.utils import (
Expand Down Expand Up @@ -71,12 +71,14 @@ def forward(
if embedding is not None:
values = embedding(values)

scale = (
values[self.param_name]
if isinstance(self.param_name, str)
else self.param_name
)
return scale * self.operations[0].forward(state, values, embedding)
if isinstance(self.param_name, str):
scale = values[self.param_name]
elif isinstance(self.param_name, Tensor):
scale = self.param_name
elif isinstance(self.param_name, ConcretizedCallable):
scale = self.param_name(values)

return scale * self.operations[0].forward(state, values)

def tensor(
self,
Expand All @@ -99,12 +101,14 @@ def tensor(
if embedding is not None:
values = embedding(values)

scale = (
values[self.param_name]
if isinstance(self.param_name, str)
else self.param_name
)
return scale * self.operations[0].tensor(values, embedding, full_support)
if isinstance(self.param_name, str):
scale = values[self.param_name]
elif isinstance(self.param_name, (Tensor, int, float)):
scale = self.param_name
elif isinstance(self.param_name, ConcretizedCallable):
scale = self.param_name(values)

return scale * self.operations[0].tensor(values, full_support=full_support)

def flatten(self) -> list[Scale]:
"""This method should only be called in the AdjointExpectation,
Expand Down
4 changes: 1 addition & 3 deletions pyqtorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,7 @@ def is_diag(H: Tensor, atol: Tensor = ATOL) -> bool:
Returns:
True if diagonal, else False.
"""
m = H.shape[0]
p, q = H.stride()
offdiag_view = torch.as_strided(H[:, 1:], (m - 1, m), (p + q, q))
offdiag_view = H - torch.diag(torch.diag(H))
return torch.count_nonzero(torch.abs(offdiag_view).gt(atol)) == 0


Expand Down
3 changes: 2 additions & 1 deletion tests/test_analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ def test_timedependent(
)
hamiltonian_evolution = pyq.HamiltonianEvolution(
generator=hamevo_generator,
time=torch.as_tensor(duration),
time=tparam,
duration=duration,
steps=n_steps,
solver=ode_solver,
)
Expand Down

0 comments on commit 866e125

Please sign in to comment.