Skip to content

Commit

Permalink
Merge pull request #244 from jcmgray/mps-sampling
Browse files Browse the repository at this point in the history
add efficient MPS sampling, including to Circuit MPS classes
  • Loading branch information
jcmgray authored Jul 4, 2024
2 parents e05b77e + cb6a4e0 commit 8a22deb
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 8 deletions.
39 changes: 38 additions & 1 deletion quimb/tensor/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4008,7 +4008,7 @@ def apply_gates(self, gates, progbar=False, **gate_opts):
@property
def psi(self):
# no squeeze so that bond dims of 1 preserved
return self._psi
return self._psi.copy()

@property
def uni(self):
Expand All @@ -4029,6 +4029,23 @@ def get_psi_reverse_lightcone(self, where, keep_psi0=False):
"""
return self.psi

def sample(
self,
C,
seed=None,
):
"""Sample the MPS circuit ``C`` times.
Parameters
----------
C : int
The number of samples to generate.
seed : None, int, or generator, optional
A random seed or generator to use for reproducibility.
"""
for config, _ in self._psi.sample(C, seed=seed):
yield "".join(map(str, config))

def fidelity_estimate(self):
r"""Estimate the fidelity of the current state based on its norm, which
tracks how much the state has been truncated:
Expand Down Expand Up @@ -4122,6 +4139,26 @@ def get_psi_unordered(self):
"""
return self._psi.copy()

def sample(self, C, seed=None):
"""Sample the PermMPS circuit ``C`` times.
Parameters
----------
C : int
The number of samples to generate.
seed : None, int, or generator, optional
A random seed or generator to use for reproducibility.
Yields
------
str
The next sample bitstring.
"""
# configuring is in physical order, so need to reorder for sampling
ordering = self.calc_qubit_ordering()
for config, _ in self._psi.sample(C, seed=seed):
yield "".join(str(config[i]) for i in ordering)

@property
def psi(self):
# need to reindex and retag the MPS
Expand Down
77 changes: 77 additions & 0 deletions quimb/tensor/tensor_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3191,6 +3191,83 @@ def measure(

measure_ = functools.partialmethod(measure, inplace=True)

def sample_configuration(self, seed=None, info=None):
"""Sample a configuration from this MPS.
Parameters
----------
seed : None, int, or np.random.Generator, optional
A random seed or generator to use.
info : dict, optional
If given, will be used to infer and store various extra
information. Currently the key "cur_orthog" is used to store the
current orthogonality center.
"""
import numpy as np

# if seed is already a generator this simply returns it
rng = np.random.default_rng(seed)

# right canonicalize
psi = self.canonicalize(0, info=info)

config = []
omega = 1.0
for i in range(psi.L):

# form local density matrix
ki = psi[i]
bi = ki.H
ix = psi.site_ind(i)
# contract diagonal to get probabilities
pi = (ki & bi).contract(output_inds=[ix]).data

# sample outcome using numpy
pi = do("to_numpy", pi).real
pi /= pi.sum()
xi = rng.choice(pi.size, p=pi)
config.append(xi)
# track local probability
omega *= pi[xi]

# project outcome
psi.isel_({ix: xi})
if i < psi.L - 1:
# and absorb projected site into next site
psi.contract_tags_([psi.site_tag(i), psi.site_tag(i + 1)])

return config, omega

def sample(self, C, seed=None, info=None):
"""Generate ``C`` samples rom this MPS, along with their probabilities.
Parameters
----------
C : int
The number of samples to generate.
seed : None, int, or np.random.Generator, optional
A random seed or generator to use.
info : dict, optional
If given, will be used to infer and store various extra
information. Currently the key "cur_orthog" is used to store the
current orthogonality center.
Yields
------
config : sequence of int
The sample configuration.
omega : float
The probability of this configuration.
"""

if info is None:
info = {}

# do right canonicalization once (supplying info avoids re-performing)
psi0 = self.canonicalize(0, info=info)

for _ in range(C):
yield psi0.sample_configuration(seed=seed, info=info)

class MatrixProductOperator(TensorNetwork1DOperator, TensorNetwork1DFlat):
"""Initialise a matrix product operator, with auto labelling and tagging.
Expand Down
43 changes: 36 additions & 7 deletions tests/test_tensor/test_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,6 @@ def test_from_qsim(self):
qc = qtn.Circuit.from_qsim_str(qsim)
assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0)

def test_from_qsim_mps_swapsplit(self):
G = rand_reg_graph(reg=3, n=18, seed=42)
qsim = graph_to_qsim(G)
qc = qtn.CircuitMPS.from_qsim_str(qsim)
assert len(qc.psi.tensors) == 18
assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0)

def test_from_openqasm2(self):
qc = qtn.Circuit.from_openqasm2_str(example_openqasm2_qft())
assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0)
Expand Down Expand Up @@ -641,6 +634,15 @@ def test_multi_controlled_circuit(self):
(b,) = circ.sample(1, group_size=3)
assert b[N - 2] == "0"


class TestCircuitMPS:
def test_from_qsim_mps_swapsplit(self):
G = rand_reg_graph(reg=3, n=18, seed=42)
qsim = graph_to_qsim(G)
qc = qtn.CircuitMPS.from_qsim_str(qsim)
assert len(qc.psi.tensors) == 18
assert (qc.psi.H & qc.psi) ^ all == pytest.approx(1.0)

def test_multi_controlled_mps_circuit(self):
N = 10
rng = np.random.default_rng(42)
Expand Down Expand Up @@ -681,6 +683,33 @@ def test_multi_controlled_mps_circuit(self):
assert mps.norm() == pytest.approx(1.0)
assert mps.distance_normalized(psi_lazy) < 1e-6

def test_mps_sampling(self):
N = 6
circ = qtn.CircuitMPS(N)
circ.h(3)
circ.cx(3, 2)
circ.cx(2, 1)
circ.cx(1, 0)
circ.cx(0, 5)
circ.cx(5, 4)
circ.x(4)
for x in circ.sample(10):
assert x in {"000010", "111101"}

def test_permmps_sampling(self):
N = 6
circ = qtn.CircuitPermMPS(N)
circ.h(3)
circ.cx(3, 2)
circ.cx(2, 1)
circ.cx(1, 0)
circ.cx(0, 5)
circ.cx(5, 4)
circ.x(4)
assert circ.qubits != tuple(range(N))
for x in circ.sample(10):
assert x in {"000010", "111101"}


class TestCircuitGen:
@pytest.mark.parametrize(
Expand Down
10 changes: 10 additions & 0 deletions tests/test_tensor/test_tensor_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,16 @@ def test_gate_non_local(self, where, phys_dim):
psi.gate(G, where, contract=False)
) == pytest.approx(0.0, abs=1e-6)

def test_sample_configuration(self):
psi = qtn.MPS_rand_state(10, 7)
config, omega = psi.sample_configuration()
assert len(config) == 10
assert abs(
psi.isel(
{psi.site_ind(i): xi for i, xi in enumerate(config)}
).contract()
) ** 2 == pytest.approx(omega)


class TestMatrixProductOperator:
@pytest.mark.parametrize("cyclic", [False, True])
Expand Down

0 comments on commit 8a22deb

Please sign in to comment.