From 236aee455613957488e7dc91afa1809b48aa6133 Mon Sep 17 00:00:00 2001 From: seitzdom Date: Mon, 17 Jun 2024 19:05:33 +0200 Subject: [PATCH] [Fix] Remove jax import in `states` module + Fix docs warnings (#466) --- pyproject.toml | 2 +- qadence/ml_tools/saveload.py | 5 +---- qadence/parameters.py | 2 +- qadence/states.py | 15 +++++++-------- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c17daa918..18629d740 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ authors = [ ] requires-python = ">=3.9" license = {text = "Apache 2.0"} -version = "1.6.1" +version = "1.6.2" classifiers=[ "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", diff --git a/qadence/ml_tools/saveload.py b/qadence/ml_tools/saveload.py index a19cedfcd..7052db508 100644 --- a/qadence/ml_tools/saveload.py +++ b/qadence/ml_tools/saveload.py @@ -120,10 +120,7 @@ def load_model( except Exception as e: msg = f"Unable to load state dict due to {e}.\ No corresponding pre-trained model found. Returning the un-trained model." - import warnings - - warnings.warn(msg, UserWarning) - logger.warn(msg) + logger.warning(msg) return model, iteration diff --git a/qadence/parameters.py b/qadence/parameters.py index 675d36c7c..a92a5adbf 100644 --- a/qadence/parameters.py +++ b/qadence/parameters.py @@ -16,7 +16,7 @@ from qadence.types import DifferentiableExpression, Engine, TNumber # Modules to be automatically added to the qadence namespace -__all__ = ["FeatureParameter", "Parameter", "VariationalParameter"] +__all__ = ["FeatureParameter", "Parameter", "VariationalParameter", "ParamMap"] logger = getLogger(__name__) diff --git a/qadence/states.py b/qadence/states.py index d75176fff..7fca667df 100644 --- a/qadence/states.py +++ b/qadence/states.py @@ -1,18 +1,18 @@ from __future__ import annotations import random -import warnings from functools import singledispatch from typing import List import torch -from jax.typing import ArrayLike +from numpy.typing import ArrayLike from torch import Tensor, concat from torch.distributions import Categorical, Distribution from qadence.blocks import ChainBlock, KronBlock, PrimitiveBlock, chain, kron from qadence.circuit import QuantumCircuit from qadence.execution import run +from qadence.logger import get_script_logger from qadence.operations import CNOT, RX, RY, RZ, H, I, X from qadence.types import PI, BackendName, Endianness, StateGeneratorType from qadence.utils import basis_to_int @@ -45,6 +45,7 @@ parametric_single_qubit_gates: List = [RX, RY, RZ] +logger = get_script_logger(__name__) # PRIVATE @@ -189,7 +190,7 @@ def product_state( bitstring: str, batch_size: int = 1, endianness: Endianness = Endianness.BIG, - backend: str = "pyqtorch", + backend: BackendName = BackendName.PYQTORCH, ) -> ArrayLike: """ Creates a product state from a bitstring. @@ -197,7 +198,7 @@ def product_state( Arguments: bitstring (str): A bitstring. batch_size (int) : Batch size. - backend (str): The backend to use. Default is "pyqtorch". + backend (BackendName): The backend to use. Default is "pyqtorch". Returns: A torch.Tensor. @@ -211,11 +212,9 @@ def product_state( ``` """ if batch_size: - warnings.warn( + logger.debug( "The input `batch_size` is going to be deprecated. " - "For now, default batch_size is set to 1.", - DeprecationWarning, - stacklevel=2, + "For now, default batch_size is set to 1." ) return run(product_block(bitstring), backend=backend, endianness=endianness)