TNOptimizer: fix torch for complex parameters
jcmgray committed Dec 5, 2024
1 parent 4445bdd commit 2e13857
Showing 2 changed files with 11 additions and 13 deletions.
quimb/tensor/optimize.py (3 additions & 8 deletions)
@@ -20,7 +20,6 @@
     tree_map,
     tree_unflatten,
 )
-from .contraction import contract_backend
 from .interface import get_jax
 from .tensor_core import (
     TensorNetwork,
@@ -266,7 +265,6 @@ def _parse_pytree_to_backend(x, to_constant):
 
     def collect(x):
         if hasattr(x, "get_params"):
-
             if hasattr(x, "apply_to_arrays"):
                 x.apply_to_arrays(to_constant)
 
@@ -655,7 +653,7 @@ def value_and_grad(self, arrays):
         def get_gradient_from_torch(t):
             if t.grad is None:
                 return np.zeros(t.shape, dtype=get_dtype_name(t))
-            return to_numpy(t.grad).conj()
+            return to_numpy(t.grad)
 
         result.backward()
         grads = tree_map(get_gradient_from_torch, variables)
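Context for the change above: per PyTorch's autograd notes, after backward() on a real-valued loss, the .grad of a complex leaf tensor already holds the conjugate Wirtinger derivative, which is the correct gradient-descent direction, so the extra .conj() double-conjugated the gradient for complex parameters. A minimal sketch (not part of the commit) illustrating the convention:

import torch

# complex leaf parameter and a real-valued loss |z|^2
z = torch.tensor(1.0 + 2.0j, requires_grad=True)
loss = (z * z.conj()).real
loss.backward()

# stepping against the raw z.grad decreases the loss; conjugating it
# again (as the old code did) would flip the imaginary part of the step
with torch.no_grad():
    z_stepped = z - 0.1 * z.grad
assert (z_stepped * z_stepped.conj()).real < loss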
@@ -1108,10 +1106,7 @@ def __init__(self, tn_opt, loss_fn, norm_fn, autodiff_backend):
 
     def __call__(self, arrays):
         tn_compute = inject_variables(arrays, self.tn_opt)
-
-        # set backend explicitly as maybe mixing with numpy arrays
-        with contract_backend(self.autodiff_backend):
-            return self.loss_fn(self.norm_fn(tn_compute))
+        return self.loss_fn(self.norm_fn(tn_compute))
 
 
 def identity_fn(x):
@@ -1795,7 +1790,7 @@ def plot(
     ax.plot(xs, ys, ".-")
     if xscale == "symlog":
         ax.set_xscale(xscale, linthresh=xscale_linthresh)
-        ax.axvline(xscale_linthresh, color=(.5, .5, .5), ls="-", lw=0.5)
+        ax.axvline(xscale_linthresh, color=(0.5, 0.5, 0.5), ls="-", lw=0.5)
     else:
         ax.set_xscale(xscale)
     ax.set_xlabel("Iteration")
tests/test_tensor/test_optimizers.py (8 additions & 5 deletions)
@@ -1,16 +1,15 @@
 import functools
 import importlib
 
-import pytest
 import numpy as np
-from numpy.testing import assert_allclose
+import pytest
 from autoray import real
+from numpy.testing import assert_allclose
 
 import quimb as qu
 import quimb.tensor as qtn
 from quimb.tensor.optimize import Vectorizer, parse_network_to_backend
 
-
 found_torch = importlib.util.find_spec("torch") is not None
 found_autograd = importlib.util.find_spec("autograd") is not None
 found_jax = importlib.util.find_spec("jax") is not None
@@ -206,7 +205,9 @@ def test_optimize_pbc_heis(heis_pbc, backend, method):
     assert loss_fn(psi_opt, H) == pytest.approx(en_ex, rel=1e-2)
 
 
-@pytest.mark.parametrize("backend", [jax_case, autograd_case, tensorflow_case])
+@pytest.mark.parametrize(
+    "backend", [jax_case, autograd_case, tensorflow_case, pytorch_case]
+)
 @pytest.mark.parametrize("method", ["simple", "basin"])
 def test_optimize_ham_mbl_complex(ham_mbl_pbc_complex, backend, method):
     psi0, H, norm_fn, loss_fn, en_ex = ham_mbl_pbc_complex
@@ -269,7 +270,9 @@ def loss(psi, target):
     assert tnopt.loss < f0
 
 
-@pytest.mark.parametrize("backend", [jax_case, autograd_case, tensorflow_case])
+@pytest.mark.parametrize(
+    "backend", [jax_case, autograd_case, tensorflow_case, pytorch_case]
+)
def test_parametrized_circuit(backend):
     H = qu.ham_mbl(4, dh=3.0, dh_dim=3)
     gs = qu.groundstate(H)
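With the fix, the torch backend handles complex parameters like the jax, autograd and tensorflow backends, which is what the newly parametrized tests exercise. A hypothetical end-to-end sketch (the states and loss here are illustrative, not taken from the test file) of the kind of optimization this enables:

import quimb.tensor as qtn

# random complex MPS to optimize, and a fixed complex target state
psi0 = qtn.MPS_rand_state(8, bond_dim=8, dtype="complex128")
target = qtn.MPS_rand_state(8, bond_dim=8, dtype="complex128")

def norm_fn(psi):
    # keep the state normalized during optimization
    return psi / (psi.H @ psi) ** 0.5

def loss_fn(psi, target):
    # real-valued loss: negative fidelity with the target
    return -abs(psi.H @ target) ** 2

tnopt = qtn.TNOptimizer(
    psi0,
    loss_fn=loss_fn,
    norm_fn=norm_fn,
    loss_constants={"target": target},
    autodiff_backend="torch",
)
psi_opt = tnopt.optimize(100)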
