From a4f556181952a7d44fac9c01bc9c3052f1535453 Mon Sep 17 00:00:00 2001 From: Dave Bacon Date: Wed, 10 Nov 2021 16:53:30 +0000 Subject: [PATCH] Add LazyLoader for delayed imports of slow modules (#4653) This allows for delayed import for modules, which is useful when the modules load time is slow and we don't want these imports on a global level. --- cirq-core/cirq/_import.py | 52 +++++++++++++++++++ cirq-core/cirq/_import_test.py | 30 +++++++++++ .../cirq/experiments/t1_decay_experiment.py | 10 ++-- cirq-core/cirq/experiments/xeb_fitting.py | 18 +++---- cirq-core/cirq/qis/measures.py | 19 +++---- 5 files changed, 105 insertions(+), 24 deletions(-) create mode 100644 cirq-core/cirq/_import_test.py diff --git a/cirq-core/cirq/_import.py b/cirq-core/cirq/_import.py index 36c7d5d35ad..45362b242cd 100644 --- a/cirq-core/cirq/_import.py +++ b/cirq-core/cirq/_import.py @@ -171,3 +171,55 @@ def wrap_func(module: ModuleType) -> Optional[ModuleType]: for module in execute_list: if module.__loader__ is not None and hasattr(module.__loader__, 'exec_module'): cast(Loader, module.__loader__).exec_module(module) # Calls back into wrap_func + + +class LazyLoader(ModuleType): + """Lazily import a module, mainly to avoid pulling in large dependencies. + + This class is a modified version of a similar class in TensorFlow. + + To use, instead of importing the module normally + ``` + import heavy_module + ``` + define the module + ``` + heavy_module = LazyLoader("heavy_module", globals(), "mypackage.heavy_module") + ``` + """ + + def __init__(self, local_name, parent_module_globals, name): + """Create the LazyLoader module. + + Args: + local_name: The local name that the module will be refered to as. + parent_module_globals: The globals of the module where this should be imported. + Typically this will be globals(). + name: The full qualified name of the module. + """ + self._local_name = local_name + self._parent_module_globals = parent_module_globals + self._module = None + super().__init__(name) + + def _load(self): + """Load the module and insert it into the parent's globals.""" + # Import the target module and insert it into the parent's namespace + if self._module: + return self._module + self._module = importlib.import_module(self.__name__) + self._parent_module_globals[self._local_name] = self._module + + # Update this object's dict so that if someone keeps a reference to the LazyLoader, + # lookups are efficient (__getattr__ is only called on lookups that fail). + self.__dict__.update(self._module.__dict__) + + return self._module + + def __getattr__(self, item): + module = self._load() + return getattr(module, item) + + def __dir__(self): + module = self._load() + return dir(module) diff --git a/cirq-core/cirq/_import_test.py b/cirq-core/cirq/_import_test.py new file mode 100644 index 00000000000..2fa20246fd1 --- /dev/null +++ b/cirq-core/cirq/_import_test.py @@ -0,0 +1,30 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cirq import _import + + +def test_lazy_loader(): + linalg = _import.LazyLoader("linalg", globals(), "scipy.linalg") + linalg.fun = 1 + assert linalg._module is None + assert "linalg" not in linalg.__dict__ + + linalg.det([[1]]) + + assert linalg._module is not None + assert globals()["linalg"] == linalg._module + assert "fun" in linalg.__dict__ + assert "LinAlgError" in dir(linalg) + assert linalg.fun == 1 diff --git a/cirq-core/cirq/experiments/t1_decay_experiment.py b/cirq-core/cirq/experiments/t1_decay_experiment.py index fd1cb6c37e9..4c785539f9c 100644 --- a/cirq-core/cirq/experiments/t1_decay_experiment.py +++ b/cirq-core/cirq/experiments/t1_decay_experiment.py @@ -21,12 +21,15 @@ import numpy as np -from cirq import circuits, ops, study, value +from cirq import circuits, ops, study, value, _import from cirq._compat import proper_repr if TYPE_CHECKING: import cirq +# We initialize optimize lazily, otherwise it slows global import speed. +optimize = _import.LazyLoader("optimize", globals(), "scipy.optimize") + # TODO(#3388) Add documentation for Raises. # pylint: disable=missing-raises-doc @@ -130,10 +133,7 @@ def exp_decay(x, t1): # Fit to exponential decay to find the t1 constant try: - # Import scipy.optimize here to avoid costly module level import. - import scipy.optimize - - popt, _ = scipy.optimize.curve_fit(exp_decay, xs, probs, p0=[t1_guess]) + popt, _ = optimize.curve_fit(exp_decay, xs, probs, p0=[t1_guess]) t1 = popt[0] return t1 except RuntimeError: diff --git a/cirq-core/cirq/experiments/xeb_fitting.py b/cirq-core/cirq/experiments/xeb_fitting.py index bbe11264c2c..32a26e1f59f 100644 --- a/cirq-core/cirq/experiments/xeb_fitting.py +++ b/cirq-core/cirq/experiments/xeb_fitting.py @@ -27,7 +27,7 @@ import numpy as np import pandas as pd import sympy -from cirq import ops, protocols +from cirq import ops, protocols, _import from cirq.circuits import Circuit from cirq.experiments.xeb_simulation import simulate_2q_xeb_circuits @@ -36,6 +36,10 @@ import multiprocessing import scipy.optimize +# We initialize these lazily, otherwise they slow global import speed. +optimize = _import.LazyLoader("optimize", globals(), "scipy.optimize") +stats = _import.LazyLoader("stats", globals(), "scipy.stats") + THETA_SYMBOL, ZETA_SYMBOL, CHI_SYMBOL, GAMMA_SYMBOL, PHI_SYMBOL = sympy.symbols( 'theta zeta chi gamma phi' ) @@ -410,10 +414,7 @@ def _mean_infidelity(angles): print(f"Loss: {loss:7.3g}", flush=True) return loss - # Import scipy.optimize here to avoid costly top level moule import. - import scipy.optimize - - optimization_result = scipy.optimize.minimize( + optimization_result = optimize.minimize( _mean_infidelity, x0=x0, options={ @@ -574,15 +575,12 @@ def _fit_exponential_decay( cycle_depths_pos = cycle_depths[positives] log_fidelities = np.log(fidelities[positives]) - # We import here to avoid costly module level load time dependency on scipy.stats. - import scipy.stats - - slope, intercept, _, _, _ = scipy.stats.linregress(cycle_depths_pos, log_fidelities) + slope, intercept, _, _, _ = stats.linregress(cycle_depths_pos, log_fidelities) layer_fid_0 = np.clip(np.exp(slope), 0, 1) a_0 = np.clip(np.exp(intercept), 0, 1) try: - (a, layer_fid), pcov = scipy.optimize.curve_fit( + (a, layer_fid), pcov = optimize.curve_fit( exponential_decay, cycle_depths, fidelities, diff --git a/cirq-core/cirq/qis/measures.py b/cirq-core/cirq/qis/measures.py index 518b8ec2de5..38602bca319 100644 --- a/cirq-core/cirq/qis/measures.py +++ b/cirq-core/cirq/qis/measures.py @@ -13,12 +13,12 @@ # limitations under the License. """Measures on and between quantum states and operations.""" + from typing import Optional, TYPE_CHECKING, Tuple import numpy as np -import scipy -from cirq import protocols, value +from cirq import protocols, value, _import from cirq.qis.states import ( QuantumState, infer_qid_shape, @@ -27,13 +27,18 @@ validate_normalized_state_vector, ) +# We initialize these lazily, otherwise they slow global import speed. +stats = _import.LazyLoader("stats", globals(), "scipy.stats") +linalg = _import.LazyLoader("linalg", globals(), "scipy.linalg") + + if TYPE_CHECKING: import cirq def _sqrt_positive_semidefinite_matrix(mat: np.ndarray) -> np.ndarray: """Square root of a positive semidefinite matrix.""" - eigs, vecs = scipy.linalg.eigh(mat) + eigs, vecs = linalg.eigh(mat) return vecs @ (np.sqrt(np.abs(eigs)) * vecs).T.conj() @@ -237,7 +242,7 @@ def _fidelity_state_vectors_or_density_matrices(state1: np.ndarray, state2: np.n elif state1.ndim == 2 and state2.ndim == 2: # Both density matrices state1_sqrt = _sqrt_positive_semidefinite_matrix(state1) - eigs = scipy.linalg.eigvalsh(state1_sqrt @ state2 @ state1_sqrt) + eigs = linalg.eigvalsh(state1_sqrt @ state2 @ state1_sqrt) trace = np.sum(np.sqrt(np.abs(eigs))) return trace ** 2 raise ValueError( @@ -277,11 +282,7 @@ def von_neumann_entropy( qid_shape = (state.shape[0],) validate_density_matrix(state, qid_shape=qid_shape, dtype=state.dtype, atol=atol) eigenvalues = np.linalg.eigvalsh(state) - - # We import here to avoid a costly module level load time dependency on scipy.stats. - import scipy.stats - - return scipy.stats.entropy(np.abs(eigenvalues), base=2) + return stats.entropy(np.abs(eigenvalues), base=2) if validate: _ = quantum_state(state, qid_shape=qid_shape, copy=False, validate=True, atol=atol) return 0.0