Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add efficient MPS sampling, including to Circuit MPS classes #244

Merged
merged 2 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion quimb/tensor/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4008,7 +4008,7 @@
@property
def psi(self):
# no squeeze so that bond dims of 1 preserved
return self._psi
return self._psi.copy()

Check warning on line 4011 in quimb/tensor/circuit.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/circuit.py#L4011

Added line #L4011 was not covered by tests

@property
def uni(self):
Expand All @@ -4029,6 +4029,23 @@
"""
return self.psi

def sample(
self,
C,
seed=None,
):
"""Sample the MPS circuit ``C`` times.

Parameters
----------
C : int
The number of samples to generate.
seed : None, int, or generator, optional
A random seed or generator to use for reproducibility.
"""
for config, _ in self._psi.sample(C, seed=seed):
yield "".join(map(str, config))

Check warning on line 4047 in quimb/tensor/circuit.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/circuit.py#L4046-L4047

Added lines #L4046 - L4047 were not covered by tests

def fidelity_estimate(self):
r"""Estimate the fidelity of the current state based on its norm, which
tracks how much the state has been truncated:
Expand Down Expand Up @@ -4122,6 +4139,26 @@
"""
return self._psi.copy()

def sample(self, C, seed=None):
"""Sample the PermMPS circuit ``C`` times.

Parameters
----------
C : int
The number of samples to generate.
seed : None, int, or generator, optional
A random seed or generator to use for reproducibility.

Yields
------
str
The next sample bitstring.
"""
# configuring is in physical order, so need to reorder for sampling
ordering = self.calc_qubit_ordering()
for config, _ in self._psi.sample(C, seed=seed):
yield "".join(str(config[i]) for i in ordering)

Check warning on line 4160 in quimb/tensor/circuit.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/circuit.py#L4158-L4160

Added lines #L4158 - L4160 were not covered by tests

@property
def psi(self):
# need to reindex and retag the MPS
Expand Down
77 changes: 77 additions & 0 deletions quimb/tensor/tensor_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3191,6 +3191,83 @@

measure_ = functools.partialmethod(measure, inplace=True)

def sample_configuration(self, seed=None, info=None):
"""Sample a configuration from this MPS.

Parameters
----------
seed : None, int, or np.random.Generator, optional
A random seed or generator to use.
info : dict, optional
If given, will be used to infer and store various extra
information. Currently the key "cur_orthog" is used to store the
current orthogonality center.
"""
import numpy as np

Check warning on line 3206 in quimb/tensor/tensor_1d.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/tensor_1d.py#L3206

Added line #L3206 was not covered by tests

# if seed is already a generator this simply returns it
rng = np.random.default_rng(seed)

Check warning on line 3209 in quimb/tensor/tensor_1d.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/tensor_1d.py#L3209

Added line #L3209 was not covered by tests

# right canonicalize
psi = self.canonicalize(0, info=info)

Check warning on line 3212 in quimb/tensor/tensor_1d.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/tensor_1d.py#L3212

Added line #L3212 was not covered by tests

config = []
omega = 1.0
for i in range(psi.L):

Check warning on line 3216 in quimb/tensor/tensor_1d.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/tensor_1d.py#L3214-L3216

Added lines #L3214 - L3216 were not covered by tests

# form local density matrix
ki = psi[i]
bi = ki.H
ix = psi.site_ind(i)

Check warning on line 3221 in quimb/tensor/tensor_1d.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/tensor_1d.py#L3219-L3221

Added lines #L3219 - L3221 were not covered by tests
# contract diagonal to get probabilities
pi = (ki & bi).contract(output_inds=[ix]).data

Check warning on line 3223 in quimb/tensor/tensor_1d.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/tensor_1d.py#L3223

Added line #L3223 was not covered by tests

# sample outcome using numpy
pi = do("to_numpy", pi).real
pi /= pi.sum()
xi = rng.choice(pi.size, p=pi)
config.append(xi)

Check warning on line 3229 in quimb/tensor/tensor_1d.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/tensor_1d.py#L3226-L3229

Added lines #L3226 - L3229 were not covered by tests
# track local probability
omega *= pi[xi]

Check warning on line 3231 in quimb/tensor/tensor_1d.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/tensor_1d.py#L3231

Added line #L3231 was not covered by tests

# project outcome
psi.isel_({ix: xi})
if i < psi.L - 1:

Check warning on line 3235 in quimb/tensor/tensor_1d.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/tensor_1d.py#L3234-L3235

Added lines #L3234 - L3235 were not covered by tests
# and absorb projected site into next site
psi.contract_tags_([psi.site_tag(i), psi.site_tag(i + 1)])

Check warning on line 3237 in quimb/tensor/tensor_1d.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/tensor_1d.py#L3237

Added line #L3237 was not covered by tests

return config, omega

Check warning on line 3239 in quimb/tensor/tensor_1d.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/tensor_1d.py#L3239

Added line #L3239 was not covered by tests

def sample(self, C, seed=None, info=None):
"""Generate ``C`` samples rom this MPS, along with their probabilities.

Parameters
----------
C : int
The number of samples to generate.
seed : None, int, or np.random.Generator, optional
A random seed or generator to use.
info : dict, optional
If given, will be used to infer and store various extra
information. Currently the key "cur_orthog" is used to store the
current orthogonality center.

Yields
------
config : sequence of int
The sample configuration.
omega : float
The probability of this configuration.
"""

if info is None:
info = {}

Check warning on line 3264 in quimb/tensor/tensor_1d.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/tensor_1d.py#L3263-L3264

Added lines #L3263 - L3264 were not covered by tests

# do right canonicalization once (supplying info avoids re-performing)
psi0 = self.canonicalize(0, info=info)

Check warning on line 3267 in quimb/tensor/tensor_1d.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/tensor_1d.py#L3267

Added line #L3267 was not covered by tests

for _ in range(C):
yield psi0.sample_configuration(seed=seed, info=info)

Check warning on line 3270 in quimb/tensor/tensor_1d.py

View check run for this annotation

Codecov / codecov/patch

quimb/tensor/tensor_1d.py#L3269-L3270

Added lines #L3269 - L3270 were not covered by tests

class MatrixProductOperator(TensorNetwork1DOperator, TensorNetwork1DFlat):
"""Initialise a matrix product operator, with auto labelling and tagging.
Expand Down
43 changes: 36 additions & 7 deletions tests/test_tensor/test_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,6 @@ def test_from_qsim(self):
qc = qtn.Circuit.from_qsim_str(qsim)
assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0)

def test_from_qsim_mps_swapsplit(self):
G = rand_reg_graph(reg=3, n=18, seed=42)
qsim = graph_to_qsim(G)
qc = qtn.CircuitMPS.from_qsim_str(qsim)
assert len(qc.psi.tensors) == 18
assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0)

def test_from_openqasm2(self):
qc = qtn.Circuit.from_openqasm2_str(example_openqasm2_qft())
assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0)
Expand Down Expand Up @@ -641,6 +634,15 @@ def test_multi_controlled_circuit(self):
(b,) = circ.sample(1, group_size=3)
assert b[N - 2] == "0"


class TestCircuitMPS:
def test_from_qsim_mps_swapsplit(self):
G = rand_reg_graph(reg=3, n=18, seed=42)
qsim = graph_to_qsim(G)
qc = qtn.CircuitMPS.from_qsim_str(qsim)
assert len(qc.psi.tensors) == 18
assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0)

def test_multi_controlled_mps_circuit(self):
N = 10
rng = np.random.default_rng(42)
Expand Down Expand Up @@ -681,6 +683,33 @@ def test_multi_controlled_mps_circuit(self):
assert mps.norm() == pytest.approx(1.0)
assert mps.distance_normalized(psi_lazy) < 1e-6

def test_mps_sampling(self):
N = 6
circ = qtn.CircuitMPS(N)
circ.h(3)
circ.cx(3, 2)
circ.cx(2, 1)
circ.cx(1, 0)
circ.cx(0, 5)
circ.cx(5, 4)
circ.x(4)
for x in circ.sample(10):
assert x in {"000010", "111101"}

def test_permmps_sampling(self):
N = 6
circ = qtn.CircuitPermMPS(N)
circ.h(3)
circ.cx(3, 2)
circ.cx(2, 1)
circ.cx(1, 0)
circ.cx(0, 5)
circ.cx(5, 4)
circ.x(4)
assert circ.qubits != tuple(range(N))
for x in circ.sample(10):
assert x in {"000010", "111101"}


class TestCircuitGen:
@pytest.mark.parametrize(
Expand Down
10 changes: 10 additions & 0 deletions tests/test_tensor/test_tensor_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,16 @@ def test_gate_non_local(self, where, phys_dim):
psi.gate(G, where, contract=False)
) == pytest.approx(0.0, abs=1e-6)

def test_sample_configuration(self):
psi = qtn.MPS_rand_state(10, 7)
config, omega = psi.sample_configuration()
assert len(config) == 10
assert abs(
psi.isel(
{psi.site_ind(i): xi for i, xi in enumerate(config)}
).contract()
) ** 2 == pytest.approx(omega)


class TestMatrixProductOperator:
@pytest.mark.parametrize("cyclic", [False, True])
Expand Down