Skip to content

Commit

Permalink
[Fix] Remove jax import in states module + Fix docs warnings (#466)
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikandreasseitz authored Jun 17, 2024
1 parent 5021ab4 commit 236aee4
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ authors = [
]
requires-python = ">=3.9"
license = {text = "Apache 2.0"}
version = "1.6.1"
version = "1.6.2"
classifiers=[
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
Expand Down
5 changes: 1 addition & 4 deletions qadence/ml_tools/saveload.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,7 @@ def load_model(
except Exception as e:
msg = f"Unable to load state dict due to {e}.\
No corresponding pre-trained model found. Returning the un-trained model."
import warnings

warnings.warn(msg, UserWarning)
logger.warn(msg)
logger.warning(msg)
return model, iteration


Expand Down
2 changes: 1 addition & 1 deletion qadence/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from qadence.types import DifferentiableExpression, Engine, TNumber

# Modules to be automatically added to the qadence namespace
__all__ = ["FeatureParameter", "Parameter", "VariationalParameter"]
__all__ = ["FeatureParameter", "Parameter", "VariationalParameter", "ParamMap"]

logger = getLogger(__name__)

Expand Down
15 changes: 7 additions & 8 deletions qadence/states.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from __future__ import annotations

import random
import warnings
from functools import singledispatch
from typing import List

import torch
from jax.typing import ArrayLike
from numpy.typing import ArrayLike
from torch import Tensor, concat
from torch.distributions import Categorical, Distribution

from qadence.blocks import ChainBlock, KronBlock, PrimitiveBlock, chain, kron
from qadence.circuit import QuantumCircuit
from qadence.execution import run
from qadence.logger import get_script_logger
from qadence.operations import CNOT, RX, RY, RZ, H, I, X
from qadence.types import PI, BackendName, Endianness, StateGeneratorType
from qadence.utils import basis_to_int
Expand Down Expand Up @@ -45,6 +45,7 @@

parametric_single_qubit_gates: List = [RX, RY, RZ]

logger = get_script_logger(__name__)
# PRIVATE


Expand Down Expand Up @@ -189,15 +190,15 @@ def product_state(
bitstring: str,
batch_size: int = 1,
endianness: Endianness = Endianness.BIG,
backend: str = "pyqtorch",
backend: BackendName = BackendName.PYQTORCH,
) -> ArrayLike:
"""
Creates a product state from a bitstring.
Arguments:
bitstring (str): A bitstring.
batch_size (int) : Batch size.
backend (str): The backend to use. Default is "pyqtorch".
backend (BackendName): The backend to use. Default is "pyqtorch".
Returns:
A torch.Tensor.
Expand All @@ -211,11 +212,9 @@ def product_state(
```
"""
if batch_size:
warnings.warn(
logger.debug(
"The input `batch_size` is going to be deprecated. "
"For now, default batch_size is set to 1.",
DeprecationWarning,
stacklevel=2,
"For now, default batch_size is set to 1."
)
return run(product_block(bitstring), backend=backend, endianness=endianness)

Expand Down

0 comments on commit 236aee4

Please sign in to comment.