Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the stubgen helper function to attach stubs to the correct class in modules with multiple classes #276

Merged
merged 5 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ Valid subsections within a version are:

Things to be included in the next release go here.

### Fixed

- Fixed the stubgen helper to properly attach stubs to the correct class in modules that have multiple classes.

---

## v2.2.1 (2024-08-07)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ yamlfix = "^1.16.0"
[tool.poetry.group.docs.dependencies]
black = "^24.4.2"
codespell = "^2.2.6"
griffe = "^0.47.0"
mkdocs = "^1.6.0"
mkdocs-ezglossary-plugin = "^1.6.10"
mkdocs-gen-files = "^0.5.0"
Expand Down
18 changes: 16 additions & 2 deletions src/tm_devices/helpers/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _get_data_type(data_object: Any) -> str:


# pylint: disable=too-many-locals
def add_info_to_stub(cls: Any, method: Any, is_property: bool = False) -> None: # noqa: C901
def add_info_to_stub(cls: Any, method: Any, is_property: bool = False) -> None: # noqa: C901,PLR0912
"""Add information to a stub file.

This method requires that an environment variable named ``TM_DEVICES_STUB_DIR`` is defined that
Expand All @@ -42,6 +42,7 @@ def add_info_to_stub(cls: Any, method: Any, is_property: bool = False) -> None:

Raises:
AssertionError: Indicates that the file that needs to be updated does not exist.
ValueError: Indicates that the class could not be found in the stub file.
"""
if stub_dir := os.getenv("TM_DEVICES_STUB_DIR"):
method_filepath = inspect.getfile(cls)
Expand Down Expand Up @@ -88,8 +89,21 @@ def add_info_to_stub(cls: Any, method: Any, is_property: bool = False) -> None:
with open(method_filepath, encoding="utf-8") as file_pointer:
contents = file_pointer.read()
if f" def {method.__name__}(" not in contents:
contents += method_stub_content
if typing_imports:
contents = f"from typing import {', '.join(typing_imports)}\n" + contents
# Use a regular expression to find the end of the current class
pattern = r"(class\s+" + cls.__name__ + r"\b.*?)(\n(?=def|class)|\Z)"
# Insert the new code at the end of the current class
if match := re.search(pattern, contents, flags=re.DOTALL):
end_pos = match.end()
first_half_contents = contents[:end_pos]
if first_half_contents.endswith("\n\n"):
first_half_contents = first_half_contents[:-1]
second_half_contents = contents[end_pos:]
contents = first_half_contents + method_stub_content + second_half_contents
else: # pragma: no cover
msg = f"Could not find the end of the {cls.__class__.__name__} class."
raise ValueError(msg)

with open(method_filepath, "w", encoding="utf-8") as file_pointer:
file_pointer.write(contents)
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,20 @@ def fixture_device_manager() -> Generator[DeviceManager, None, None]:
yield dev_manager


@pytest.fixture(autouse=True)
def _reset_dm(device_manager: DeviceManager) -> Generator[None, None, None]: # pyright: ignore[reportUnusedFunction]
"""Reset the device_manager settings after each test.

Args:
device_manager: The device manager fixture.
"""
saved_setup_enable = device_manager.setup_cleanup_enabled
saved_teardown_enable = device_manager.teardown_cleanup_enabled
yield
device_manager.setup_cleanup_enabled = saved_setup_enable
device_manager.teardown_cleanup_enabled = saved_teardown_enable


@pytest.fixture(name="mock_http_server", scope="session")
def _fixture_mock_http_server() -> ( # pyright: ignore [reportUnusedFunction]
Generator[None, None, None]
Expand Down
7 changes: 7 additions & 0 deletions tests/samples/golden_stubs/drivers/device.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,10 @@ class Device(ABC, metaclass=abc.ABCMeta):

This has a multi-line description.
"""

def function_1(arg1: str, arg2: int = 1) -> bool: ...

class OtherDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...

def function_2(arg1: str, arg2: int = 2) -> bool: ...
280 changes: 0 additions & 280 deletions tests/test_device_manager.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
# pyright: reportUnusedFunction=none
# pyright: reportUnknownMemberType=none
# pyright: reportAttributeAccessIssue=none
# pyright: reportUnknownVariableType=none
# pyright: reportArgumentType=none
"""Tests for the device_manager.py file."""

import contextlib
import os
import subprocess
import sys

from pathlib import Path
from typing import Generator, Iterator, List
from unittest import mock

import pytest
Expand All @@ -20,51 +13,9 @@

from conftest import SIMULATED_VISA_LIB
from tm_devices import DeviceManager
from tm_devices.drivers import AFG3K, AFG3KC
from tm_devices.drivers.device import Device
from tm_devices.drivers.pi.scopes.scope import Scope
from tm_devices.drivers.pi.signal_generators.afgs.afg import AFG
from tm_devices.drivers.pi.signal_generators.signal_generator import SignalGenerator
from tm_devices.helpers import ConnectionTypes, DeviceTypes, PYVISA_PY_BACKEND, SerialConfig


@pytest.fixture(scope="module", autouse=True)
def _remove_added_methods() -> Iterator[None]:
"""Remove custom added methods from devices."""
yield
for obj, name in (
(Device, "inc_cached_count"),
(Device, "inc_count"),
(Device, "class_name"),
(Device, "custom_model_getter"),
(Device, "custom_list"),
(Device, "custom_return_none"),
(Device, "already_exists"),
(Scope, "custom_model_getter_scope"),
(Scope, "custom_return"),
(SignalGenerator, "custom_model_getter_ss"),
(AFG, "custom_model_getter_afg"),
(AFG3K, "custom_model_getter_afg3k"),
(AFG3KC, "custom_model_getter_afg3kc"),
):
with contextlib.suppress(AttributeError):
delattr(obj, name)


@pytest.fixture(autouse=True)
def _reset_dm(device_manager: DeviceManager) -> Generator[None, None, None]:
"""Reset the device_manager settings after each test.

Args:
device_manager: The device manager fixture.
"""
saved_setup_enable = device_manager.setup_cleanup_enabled
saved_teardown_enable = device_manager.teardown_cleanup_enabled
yield
device_manager.setup_cleanup_enabled = saved_setup_enable
device_manager.teardown_cleanup_enabled = saved_teardown_enable


class TestDeviceManager: # pylint: disable=no-self-use
"""Device Manager test class."""

Expand Down Expand Up @@ -222,237 +173,6 @@ def test_dm_properties(self, device_manager: DeviceManager) -> None:
device_manager.verbose = saved_verbose
device_manager.visa_library = saved_visa_lib

# pylint: disable=too-many-locals
def test_visa_device_methods_and_method_adding( # noqa: C901,PLR0915
self, device_manager: DeviceManager, capsys: pytest.CaptureFixture[str]
) -> None:
"""Test methods pertaining to VISA devices.

Args:
device_manager: The DeviceManager object.
capsys: The captured stdout and stderr.
"""
# Remove all previous devices
device_manager.remove_all_devices()
# Read the captured stdout to clear it
_ = capsys.readouterr().out
saved_setup_enable = device_manager.setup_cleanup_enabled
saved_teardown_enable = device_manager.teardown_cleanup_enabled
device_manager.setup_cleanup_enabled = True
device_manager.teardown_cleanup_enabled = True

############################################################################################
# Make sure to add all methods to the remove_added_methods() fixture
# at the top of this test module.

def gen_count() -> Iterator[int]:
"""Local counter."""
count = 0
while True:
count += 1
yield count

local_count = gen_count()

initial_input = '''import abc
from abc import ABC

from tm_devices.helpers import DeviceConfigEntry

class Device(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""
'''
sub_filepath = Path("drivers/device.pyi")
generated_stub_dir = (
Path(__file__).parent
/ "samples/generated_stubs"
/ f"output_{sys.version_info.major}{sys.version_info.minor}/tm_devices"
)
generated_stub_file = generated_stub_dir / sub_filepath
golden_stub_dir = Path(__file__).parent / "samples" / "golden_stubs"
generated_stub_file.parent.mkdir(parents=True, exist_ok=True)
with open(generated_stub_file, "w", encoding="utf-8") as generated_file:
generated_file.write(initial_input)
with mock.patch.dict("os.environ", {"TM_DEVICES_STUB_DIR": str(generated_stub_dir)}):
# noinspection PyUnusedLocal,PyShadowingNames
@Device.add_property(is_cached=True)
def inc_cached_count(self: Device) -> int: # noqa: ARG001
"""Increment a local counter."""
return next(local_count)

# noinspection PyUnusedLocal,PyShadowingNames
@Device.add_property(is_cached=False)
def inc_count(self: Device) -> int: # noqa: ARG001
"""Increment a local counter."""
return next(local_count)

# noinspection PyShadowingNames
@Device.add_property
def class_name(self: Device) -> str:
"""Return the class name."""
return self.__class__.__name__

# noinspection PyShadowingNames
@Device.add_method
def custom_model_getter(
self: Device,
value1: str,
value2: str = "add",
value3: str = "",
value4: float = 0.1,
) -> str:
"""Return the model."""
return " ".join(["Device", self.model, value1, value2, value3, str(value4)])

# noinspection PyShadowingNames
@Device.add_method
def custom_list(self: Device) -> List[str]:
"""Return the model and serial in a list."""
return [self.model, self.serial]

@Device.add_method
def custom_return_none() -> None:
"""Return nothing.

This has a multi-line description.
"""

@Device.add_method
def already_exists() -> None:
"""Return nothing."""

with pytest.raises(AssertionError):

@Scope.add_method
def custom_return() -> None:
"""Return nothing."""

@Scope.add_method
def custom_model_getter_scope(device: Scope, value: str) -> str:
"""Return the model."""
return f"Scope {device.model} {value}"

@SignalGenerator.add_method
def custom_model_getter_ss(device: SignalGenerator, value: str) -> str:
"""Return the model."""
return f"SignalGenerator {device.model} {value}"

@AFG.add_method
def custom_model_getter_afg(device: AFG, value: str) -> str:
"""Return the model."""
return f"AFG {device.model} {value}"

@AFG3K.add_method
def custom_model_getter_afg3k(device: AFG3K, value: str) -> str:
"""Return the model."""
return f"AFG3K {device.model} {value}"

@AFG3KC.add_method
def custom_model_getter_afg3kc(device: AFG3KC, value: str) -> str:
"""Return the model."""
return f"AFG3KC {device.model} {value}"

############################################################################################
start_dir = os.getcwd()
try:
os.chdir(generated_stub_file.parent)
subprocess.check_call( # noqa: S603
[
sys.executable,
"-m",
"ruff",
"format",
"--quiet",
generated_stub_file.name,
]
)
subprocess.check_call( # noqa: S603
[
sys.executable,
"-m",
"ruff",
"check",
"--quiet",
"--select=I",
"--fix",
generated_stub_file.name,
]
)
finally:
os.chdir(start_dir)
with open(golden_stub_dir / sub_filepath, encoding="utf-8") as golden_file:
golden_contents = golden_file.read()
with open(generated_stub_file, encoding="utf-8") as generated_file:
generated_contents = generated_file.read()
assert generated_contents == golden_contents

# Test the custom added properties
afg = device_manager.add_afg("afg3252c-hostname", alias="testing")
# noinspection PyUnresolvedReferences
assert afg.class_name == "AFG3KC"
# noinspection PyUnresolvedReferences
_ = afg.inc_cached_count
# noinspection PyUnresolvedReferences
assert afg.inc_cached_count == 1, "cached property is not working"
# noinspection PyUnresolvedReferences
_ = afg.inc_count
# noinspection PyUnresolvedReferences
assert afg.inc_count == 3, "uncached property is not working"

# Test the custom added methods
# noinspection PyUnresolvedReferences
assert afg.custom_model_getter("a", "b", "c", 0.1) == "Device AFG3252C a b c 0.1"
# noinspection PyUnresolvedReferences
assert afg.custom_model_getter_ss("hello") == "SignalGenerator AFG3252C hello"
# noinspection PyUnresolvedReferences
assert afg.custom_model_getter_afg("hello") == "AFG AFG3252C hello"
# noinspection PyUnresolvedReferences
assert afg.custom_model_getter_afg3k("hello") == "AFG3K AFG3252C hello"
# noinspection PyUnresolvedReferences
assert afg.custom_model_getter_afg3kc("hello") == "AFG3KC AFG3252C hello"
with pytest.raises(AttributeError):
# noinspection PyUnresolvedReferences
afg.custom_model_getter_scope("hello")

# Test VISA methods
assert afg.set_and_check("OUTPUT1:STATE", "1", custom_message_prefix="Custom prefix") == "1"
device_manager.disable_device_command_checking()
assert afg.set_and_check("OUTPUT1:STATE", "0") == ""
device_manager.cleanup_all_devices()
console_output = capsys.readouterr()
assert "Beginning Device Cleanup on AFG " in console_output.out
assert "Response from 'OUTPUT1:STATE?' >> '1'" in console_output.out
assert "Response from 'OUTPUT1:STATE?' >> '0'" not in console_output.out
assert console_output.err == ""

assert len(device_manager.devices) == 1
device_manager.close()
assert "Beginning Device Cleanup" in capsys.readouterr().out
assert len(device_manager.devices) == 1

device_manager.setup_cleanup_enabled = False
device_manager.open()
device_manager.verbose_visa = True
afg = device_manager.get_afg(number_or_alias="testing")
afg.ieee_cmds.idn()
assert "pyvisa - DEBUG" in capsys.readouterr().err
device_manager.verbose_visa = False
assert not device_manager.verbose_visa
afg.ieee_cmds.idn()
assert "pyvisa - DEBUG" not in capsys.readouterr().err
device_manager.teardown_cleanup_enabled = False
assert len(device_manager.devices) == 1
device_manager.close()
assert "Beginning Device Cleanup" not in capsys.readouterr().out
assert len(device_manager.devices) == 1

device_manager.open()
device_manager.remove_device(alias="testing")
device_manager.setup_cleanup_enabled = saved_setup_enable
device_manager.teardown_cleanup_enabled = saved_teardown_enable

def test_failed_cleanup(self, device_manager: DeviceManager) -> None:
"""Test what happens when a device manager cleanup fails.

Expand Down
Loading
Loading