Skip to content

Commit

Permalink
Add LazyLoader for delayed imports of slow modules (#4653)
Browse files Browse the repository at this point in the history
This allows for delayed import for modules, which is useful when the modules load time is slow and we don't want these imports on a global level.
  • Loading branch information
dabacon authored Nov 10, 2021
1 parent 83c440f commit a4f5561
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 24 deletions.
52 changes: 52 additions & 0 deletions cirq-core/cirq/_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,55 @@ def wrap_func(module: ModuleType) -> Optional[ModuleType]:
for module in execute_list:
if module.__loader__ is not None and hasattr(module.__loader__, 'exec_module'):
cast(Loader, module.__loader__).exec_module(module) # Calls back into wrap_func


class LazyLoader(ModuleType):
"""Lazily import a module, mainly to avoid pulling in large dependencies.
This class is a modified version of a similar class in TensorFlow.
To use, instead of importing the module normally
```
import heavy_module
```
define the module
```
heavy_module = LazyLoader("heavy_module", globals(), "mypackage.heavy_module")
```
"""

def __init__(self, local_name, parent_module_globals, name):
"""Create the LazyLoader module.
Args:
local_name: The local name that the module will be refered to as.
parent_module_globals: The globals of the module where this should be imported.
Typically this will be globals().
name: The full qualified name of the module.
"""
self._local_name = local_name
self._parent_module_globals = parent_module_globals
self._module = None
super().__init__(name)

def _load(self):
"""Load the module and insert it into the parent's globals."""
# Import the target module and insert it into the parent's namespace
if self._module:
return self._module
self._module = importlib.import_module(self.__name__)
self._parent_module_globals[self._local_name] = self._module

# Update this object's dict so that if someone keeps a reference to the LazyLoader,
# lookups are efficient (__getattr__ is only called on lookups that fail).
self.__dict__.update(self._module.__dict__)

return self._module

def __getattr__(self, item):
module = self._load()
return getattr(module, item)

def __dir__(self):
module = self._load()
return dir(module)
30 changes: 30 additions & 0 deletions cirq-core/cirq/_import_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from cirq import _import


def test_lazy_loader():
linalg = _import.LazyLoader("linalg", globals(), "scipy.linalg")
linalg.fun = 1
assert linalg._module is None
assert "linalg" not in linalg.__dict__

linalg.det([[1]])

assert linalg._module is not None
assert globals()["linalg"] == linalg._module
assert "fun" in linalg.__dict__
assert "LinAlgError" in dir(linalg)
assert linalg.fun == 1
10 changes: 5 additions & 5 deletions cirq-core/cirq/experiments/t1_decay_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@
import numpy as np


from cirq import circuits, ops, study, value
from cirq import circuits, ops, study, value, _import
from cirq._compat import proper_repr

if TYPE_CHECKING:
import cirq

# We initialize optimize lazily, otherwise it slows global import speed.
optimize = _import.LazyLoader("optimize", globals(), "scipy.optimize")


# TODO(#3388) Add documentation for Raises.
# pylint: disable=missing-raises-doc
Expand Down Expand Up @@ -130,10 +133,7 @@ def exp_decay(x, t1):

# Fit to exponential decay to find the t1 constant
try:
# Import scipy.optimize here to avoid costly module level import.
import scipy.optimize

popt, _ = scipy.optimize.curve_fit(exp_decay, xs, probs, p0=[t1_guess])
popt, _ = optimize.curve_fit(exp_decay, xs, probs, p0=[t1_guess])
t1 = popt[0]
return t1
except RuntimeError:
Expand Down
18 changes: 8 additions & 10 deletions cirq-core/cirq/experiments/xeb_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np
import pandas as pd
import sympy
from cirq import ops, protocols
from cirq import ops, protocols, _import
from cirq.circuits import Circuit
from cirq.experiments.xeb_simulation import simulate_2q_xeb_circuits

Expand All @@ -36,6 +36,10 @@
import multiprocessing
import scipy.optimize

# We initialize these lazily, otherwise they slow global import speed.
optimize = _import.LazyLoader("optimize", globals(), "scipy.optimize")
stats = _import.LazyLoader("stats", globals(), "scipy.stats")

THETA_SYMBOL, ZETA_SYMBOL, CHI_SYMBOL, GAMMA_SYMBOL, PHI_SYMBOL = sympy.symbols(
'theta zeta chi gamma phi'
)
Expand Down Expand Up @@ -410,10 +414,7 @@ def _mean_infidelity(angles):
print(f"Loss: {loss:7.3g}", flush=True)
return loss

# Import scipy.optimize here to avoid costly top level moule import.
import scipy.optimize

optimization_result = scipy.optimize.minimize(
optimization_result = optimize.minimize(
_mean_infidelity,
x0=x0,
options={
Expand Down Expand Up @@ -574,15 +575,12 @@ def _fit_exponential_decay(
cycle_depths_pos = cycle_depths[positives]
log_fidelities = np.log(fidelities[positives])

# We import here to avoid costly module level load time dependency on scipy.stats.
import scipy.stats

slope, intercept, _, _, _ = scipy.stats.linregress(cycle_depths_pos, log_fidelities)
slope, intercept, _, _, _ = stats.linregress(cycle_depths_pos, log_fidelities)
layer_fid_0 = np.clip(np.exp(slope), 0, 1)
a_0 = np.clip(np.exp(intercept), 0, 1)

try:
(a, layer_fid), pcov = scipy.optimize.curve_fit(
(a, layer_fid), pcov = optimize.curve_fit(
exponential_decay,
cycle_depths,
fidelities,
Expand Down
19 changes: 10 additions & 9 deletions cirq-core/cirq/qis/measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.
"""Measures on and between quantum states and operations."""


from typing import Optional, TYPE_CHECKING, Tuple

import numpy as np
import scipy

from cirq import protocols, value
from cirq import protocols, value, _import
from cirq.qis.states import (
QuantumState,
infer_qid_shape,
Expand All @@ -27,13 +27,18 @@
validate_normalized_state_vector,
)

# We initialize these lazily, otherwise they slow global import speed.
stats = _import.LazyLoader("stats", globals(), "scipy.stats")
linalg = _import.LazyLoader("linalg", globals(), "scipy.linalg")


if TYPE_CHECKING:
import cirq


def _sqrt_positive_semidefinite_matrix(mat: np.ndarray) -> np.ndarray:
"""Square root of a positive semidefinite matrix."""
eigs, vecs = scipy.linalg.eigh(mat)
eigs, vecs = linalg.eigh(mat)
return vecs @ (np.sqrt(np.abs(eigs)) * vecs).T.conj()


Expand Down Expand Up @@ -237,7 +242,7 @@ def _fidelity_state_vectors_or_density_matrices(state1: np.ndarray, state2: np.n
elif state1.ndim == 2 and state2.ndim == 2:
# Both density matrices
state1_sqrt = _sqrt_positive_semidefinite_matrix(state1)
eigs = scipy.linalg.eigvalsh(state1_sqrt @ state2 @ state1_sqrt)
eigs = linalg.eigvalsh(state1_sqrt @ state2 @ state1_sqrt)
trace = np.sum(np.sqrt(np.abs(eigs)))
return trace ** 2
raise ValueError(
Expand Down Expand Up @@ -277,11 +282,7 @@ def von_neumann_entropy(
qid_shape = (state.shape[0],)
validate_density_matrix(state, qid_shape=qid_shape, dtype=state.dtype, atol=atol)
eigenvalues = np.linalg.eigvalsh(state)

# We import here to avoid a costly module level load time dependency on scipy.stats.
import scipy.stats

return scipy.stats.entropy(np.abs(eigenvalues), base=2)
return stats.entropy(np.abs(eigenvalues), base=2)
if validate:
_ = quantum_state(state, qid_shape=qid_shape, copy=False, validate=True, atol=atol)
return 0.0
Expand Down

0 comments on commit a4f5561

Please sign in to comment.