Skip to content

Commit

Permalink
Merge pull request #4975 from jenshnielsen/add_types_to_tests_2
Browse files Browse the repository at this point in the history
Add types to part of test module
  • Loading branch information
jenshnielsen authored Feb 6, 2023
2 parents 6d9330b + 068232a commit 12dbe72
Show file tree
Hide file tree
Showing 41 changed files with 1,127 additions and 700 deletions.
2 changes: 1 addition & 1 deletion qcodes/parameters/combined_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def combine(
label: str | None = None,
unit: str | None = None,
units: str | None = None,
aggregator: Callable[[Sequence[Any]], Any] | None = None,
aggregator: Callable[..., Any] | None = None,
) -> CombinedParameter:
"""
Combine parameters into one sweepable parameter
Expand Down
7 changes: 4 additions & 3 deletions qcodes/parameters/val_mapping.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

from collections import OrderedDict
from typing import Any
from typing import TypeVar

T = TypeVar("T")

def create_on_off_val_mapping(
on_val: Any = True, off_val: Any = False
) -> dict[str | bool, Any]:
on_val: T | bool = True, off_val: T | bool = False
) -> OrderedDict[str | bool, T | bool]:
"""
Returns a value mapping which maps inputs which reasonably mean "on"/"off"
to the specified ``on_val``/``off_val`` which are to be sent to the
Expand Down
90 changes: 49 additions & 41 deletions qcodes/tests/parameter/conftest.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,49 @@
from __future__ import annotations

from collections import namedtuple
from typing import Any, Generator
from typing import Any, Callable, Generator, Literal, TypeVar

import pytest

import qcodes.validators as vals
from qcodes.instrument import InstrumentBase
from qcodes.parameters import Parameter
from qcodes.parameters import ParamDataType, Parameter, ParamRawDataType
from qcodes.tests.instrument_mocks import DummyChannelInstrument

NOT_PASSED = 'NOT_PASSED'
T = TypeVar("T")

NOT_PASSED: Literal["NOT_PASSED"] = "NOT_PASSED"


@pytest.fixture(params=(True, False, NOT_PASSED))
def snapshot_get(request: pytest.FixtureRequest) -> Any:
def snapshot_get(request: pytest.FixtureRequest) -> bool | Literal["NOT_PASSED"]:
return request.param


@pytest.fixture(params=(True, False, NOT_PASSED))
def snapshot_value(request: pytest.FixtureRequest) -> Any:
def snapshot_value(request: pytest.FixtureRequest) -> bool | Literal["NOT_PASSED"]:
return request.param


@pytest.fixture(params=(None, False, NOT_PASSED))
def get_cmd(request: pytest.FixtureRequest) -> Any:
def get_cmd(
request: pytest.FixtureRequest,
) -> None | Literal[False] | Literal["NOT_PASSED"]:
return request.param


@pytest.fixture(params=(True, False, NOT_PASSED))
def get_if_invalid(request: pytest.FixtureRequest) -> Any:
def get_if_invalid(request: pytest.FixtureRequest) -> bool | Literal["NOT_PASSED"]:
return request.param


@pytest.fixture(params=(True, False, None, NOT_PASSED))
def update(request: pytest.FixtureRequest) -> Any:
def update(request: pytest.FixtureRequest) -> bool | None | Literal["NOT_PASSED"]:
return request.param


@pytest.fixture(params=(True, False))
def cache_is_valid(request: pytest.FixtureRequest) -> Any:
def cache_is_valid(request: pytest.FixtureRequest) -> bool:
return request.param


Expand All @@ -52,19 +56,19 @@ def _make_dummy_instrument() -> Generator[DummyChannelInstrument, None, None]:

class GettableParam(Parameter):
""" Parameter that keeps track of number of get operations"""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self._get_count = 0

def get_raw(self):
def get_raw(self) -> int:
self._get_count += 1
return 42


class BetterGettableParam(Parameter):
""" Parameter that keeps track of number of get operations,
But can actually store values"""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self._get_count = 0

Expand All @@ -75,77 +79,81 @@ def get_raw(self) -> Any:

class SettableParam(Parameter):
""" Parameter that keeps track of number of set operations"""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
self._set_count = 0
super().__init__(*args, **kwargs)

def set_raw(self, value):
def set_raw(self, value: Any) -> None:
self._set_count += 1


class OverwriteGetParam(Parameter):
""" Parameter that overwrites get."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self._value = 42
self.set_count = 0
self.get_count = 0

def get(self):
def get(self) -> int:
self.get_count += 1
return self._value


class OverwriteSetParam(Parameter):
""" Parameter that overwrites set."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self._value = 42
self.set_count = 0
self.get_count = 0

def set(self, value):
def set(self, value: Any) -> None:
self.set_count += 1
self._value = value


class GetSetRawParameter(Parameter):
""" Parameter that implements get and set raw"""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)

def get_raw(self):
def get_raw(self) -> ParamRawDataType:
return self.cache.raw_value

def set_raw(self, value):
def set_raw(self, value: ParamRawDataType) -> None:
pass


class BookkeepingValidator(vals.Validator[Any]):
class BookkeepingValidator(vals.Validator[T]):
"""
Validator that keeps track of what it validates
"""
def __init__(self, min_value=-float("inf"), max_value=float("inf")):
self.values_validated = []

def validate(self, value, context=''):
def __init__(
self, min_value: float = -float("inf"), max_value: float = float("inf")
):
self.values_validated: list[T] = []

def validate(self, value: T, context: str = "") -> None:
self.values_validated.append(value)

is_numeric = True


class MemoryParameter(Parameter):
def __init__(self, get_cmd=None, **kwargs):
self.set_values = []
self.get_values = []
def __init__(self, get_cmd: None | Callable[[], Any] = None, **kwargs: Any):
self.set_values: list[Any] = []
self.get_values: list[Any] = []
super().__init__(set_cmd=self.add_set_value,
get_cmd=self.create_get_func(get_cmd), **kwargs)

def add_set_value(self, value):
def add_set_value(self, value: ParamDataType) -> None:
self.set_values.append(value)

def create_get_func(self, func):
def get_func():
def create_get_func(
self, func: None | Callable[[], ParamDataType]
) -> Callable[[], ParamDataType]:
def get_func() -> ParamDataType:
if func is not None:
val = func()
else:
Expand All @@ -156,15 +164,15 @@ def get_func():


class VirtualParameter(Parameter):
def __init__(self, name: str, param: Parameter, **kwargs):
def __init__(self, name: str, param: Parameter, **kwargs: Any):
self._param = param
super().__init__(name=name, **kwargs)

@property
def underlying_instrument(self) -> InstrumentBase | None:
return self._param.instrument

def get_raw(self):
def get_raw(self) -> ParamRawDataType:
return self._param.get()


Expand All @@ -178,22 +186,22 @@ def get_raw(self):

class ParameterMemory:

def __init__(self):
self._value = None
def __init__(self) -> None:
self._value: Any | None = None

def get(self):
def get(self) -> ParamDataType:
return self._value

def set(self, value):
def set(self, value: ParamDataType) -> None:
self._value = value

def set_p_prefixed(self, val):
def set_p_prefixed(self, val: int) -> None:
self._value = f'PVAL: {val:d}'

@staticmethod
def parse_set_p(val):
def parse_set_p(val: int) -> str:
return f'{val:d}'

@staticmethod
def strip_prefix(val):
def strip_prefix(val: str) -> int:
return int(val[6:])
45 changes: 26 additions & 19 deletions qcodes/tests/parameter/test_array_parameter.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
from typing import Any

import pytest

from qcodes.parameters import ArrayParameter
from qcodes.parameters import ArrayParameter, ParamRawDataType

from .conftest import blank_instruments, named_instrument


class SimpleArrayParam(ArrayParameter):
def __init__(self, return_val, *args, **kwargs):
def __init__(self, return_val: ParamRawDataType, *args: Any, **kwargs: Any):
self._return_val = return_val
self._get_count = 0
super().__init__(*args, **kwargs)

def get_raw(self):
def get_raw(self) -> ParamRawDataType:
self._get_count += 1
return self._return_val


class SettableArray(SimpleArrayParam):
# this is not allowed - just created to raise an error in the test below
def set_raw(self, value):
def set_raw(self, value: Any) -> None:
self.v = value


def test_default_attributes():
def test_default_attributes() -> None:
name = 'array_param'
shape = (2, 3)
p = SimpleArrayParam([[1, 2, 3], [4, 5, 6]], name, shape)
Expand Down Expand Up @@ -52,10 +54,11 @@ def test_default_attributes():
assert 'raw_value' not in snap
assert snap['ts'] is None

assert p.__doc__ is not None
assert name in p.__doc__


def test_explicit_attributes():
def test_explicit_attributes() -> None:
name = 'tiny_array'
shape = (2,)
label = 'it takes two to tango'
Expand Down Expand Up @@ -98,11 +101,12 @@ def test_explicit_attributes():
assert snap[k] == v
assert snap['ts'] is not None

assert p.__doc__ is not None
assert name in p.__doc__
assert docstring in p.__doc__


def test_has_set_get():
def test_has_set_get() -> None:
name = 'array_param'
shape = (3,)
with pytest.raises(AttributeError):
Expand All @@ -128,27 +132,30 @@ def test_has_set_get():
SettableArray([1, 2, 3], name, shape)


def test_full_name():
def test_full_name() -> None:
# three cases where only name gets used for full_name
for instrument in blank_instruments:
p = SimpleArrayParam([6, 7], 'fred', (2,),
setpoint_names=('barney',))
p._instrument = instrument
# this is not allowed since instrument
# here is not actually an instrument
# but useful for testing
p._instrument = instrument # type: ignore[assignment]
assert str(p) == 'fred'
assert p.setpoint_full_names == ('barney',)

# and then an instrument that really has a name
p = SimpleArrayParam([6, 7], 'wilma', (2,),
setpoint_names=('betty',))
p._instrument = named_instrument
assert str(p) == 'astro_wilma'
assert p.setpoint_full_names == ('astro_betty',)
p = SimpleArrayParam([6, 7], "wilma", (2,), setpoint_names=("betty",))
p._instrument = named_instrument # type: ignore[assignment]
assert str(p) == "astro_wilma"
assert p.setpoint_full_names == ("astro_betty",)

# and with a 2d parameter to test mixed setpoint_names
p = SimpleArrayParam([[6, 7, 8], [1, 2, 3]], 'wilma', (3, 2),
setpoint_names=('betty', None))
p._instrument = named_instrument
assert p.setpoint_full_names == ('astro_betty', None)
p = SimpleArrayParam(
[[6, 7, 8], [1, 2, 3]], "wilma", (3, 2), setpoint_names=("betty", None)
)
p._instrument = named_instrument # type: ignore[assignment]
assert p.setpoint_full_names == ("astro_betty", None)


@pytest.mark.parametrize("constructor", [
Expand All @@ -158,6 +165,6 @@ def test_full_name():
{'shape': [3], 'setpoint_labels': 'the index'}, # ['the index']
{'shape': [3], 'setpoint_names': [None, 'index2']}
])
def test_constructor_errors(constructor):
def test_constructor_errors(constructor: dict) -> None:
with pytest.raises(ValueError):
SimpleArrayParam([1, 2, 3], 'p', **constructor)
Loading

0 comments on commit 12dbe72

Please sign in to comment.