From aeca7389d317a2a7395b3c4e2699e93adfbbbab5 Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Mon, 22 Jul 2024 14:47:13 -0700 Subject: [PATCH] TN: add psi.normalize_simple --- quimb/tensor/tensor_arbgeom.py | 26 +++++++++++++++++++++++- tests/test_tensor/test_tensor_arbgeom.py | 18 ++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/quimb/tensor/tensor_arbgeom.py b/quimb/tensor/tensor_arbgeom.py index 700674d6..0e61f70c 100644 --- a/quimb/tensor/tensor_arbgeom.py +++ b/quimb/tensor/tensor_arbgeom.py @@ -631,6 +631,30 @@ def __iadd__(self, other): def __isub__(self, other): return tensor_network_ag_sum(self, other, negate=True, inplace=True) + def normalize_simple(self, gauges, **contract_opts): + """Normalize this network using simple local gauges. After calling + this, any tree-like sub network gauged with ``gauges`` will have + 2-norm 1. Inplace operation on both the tensor network and ``gauges``. + + Parameters + ---------- + gauges : dict[str, array_like] + The gauges to normalize with. + """ + # normalize gauges + for ix, g in gauges.items(): + gauges[ix] = g / do("linalg.norm", g) + + # normalize sites + for site in self.sites: + tn_site = self.select(site) + tn_site_gauged = tn_site.copy() + tn_site_gauged.gauge_simple_insert(gauges) + lnorm = (tn_site_gauged.H | tn_site_gauged).contract( + all, **contract_opts + ) ** 0.5 + tn_site /= lnorm + def gauge_product_boundary_vector( tn, @@ -894,7 +918,7 @@ def gate_with_op_lazy(self, A, transpose=False, inplace=False, **kwargs): which_A="upper" if transpose else "lower", contract=False, inplace=inplace, - **kwargs + **kwargs, ) gate_with_op_lazy_ = functools.partialmethod( diff --git a/tests/test_tensor/test_tensor_arbgeom.py b/tests/test_tensor/test_tensor_arbgeom.py index 43946045..96320f6e 100644 --- a/tests/test_tensor/test_tensor_arbgeom.py +++ b/tests/test_tensor/test_tensor_arbgeom.py @@ -116,3 +116,21 @@ def test_gate_sandwich_with_op(): y = A.to_dense() @ B.to_dense() @ A.to_dense().conj().T B.gate_sandwich_with_op_lazy_(A) assert_allclose(B.to_dense(), y) + + +def test_normalize_simple(): + psi = qtn.PEPS.rand(3, 3, 2, dtype=complex) + gauges = {} + psi.gauge_all_simple_(100, 5e-6, gauges=gauges) + psi.normalize_simple(gauges) + + for where in [ + [(0, 0)], + [(1, 1), (1, 2)], + [(0, 0), (0, 1), (0, 2), (1, 2), (2, 2), (2, 1)], + ]: + tags = [psi.site_tag(w) for w in where] + k = psi.select_any(tags, virtual=False) + k.gauge_simple_insert(gauges) + + assert k.H @ k == pytest.approx(1.0)