Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Update stub generation helper function to handle classes followed by dataclasses #307

Merged
merged 2 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ Things to be included in the next release go here.
- _**SEMI-BREAKING CHANGE**_: Changed the `USB_MODEL_ID_LOOKUP` constant to use `SupportedModels` as keys instead of values to make the documentation clearer.
- _**SEMI-BREAKING CHANGE**_: Changed the `DEVICE_DRIVER_MODEL_MAPPING` constant to use `SupportedModels` as keys instead of values to make the documentation clearer.

### Fixed

- Fixed a bug in the stubgen helper code responsible for adding dynamically added methods to stub files that caused invalid stub files to be created if a dataclass immediately followed a class that was being dynamically updated.

---

## v2.3.0 (2024-08-23)
Expand Down
4 changes: 2 additions & 2 deletions src/tm_devices/helpers/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def add_info_to_stub(cls: Any, method: Any, is_property: bool = False) -> None:
if typing_imports:
contents = f"from typing import {', '.join(typing_imports)}\n" + contents
# Use a regular expression to find the end of the current class
pattern = r"(class\s+" + cls.__name__ + r"\b.*?)(\n(?=def|class)|\Z)"
pattern = r"(class\s+" + cls.__name__ + r"\b.*?)(\n(?=def|class|@)|\Z)"
# Insert the new code at the end of the current class
if match := re.search(pattern, contents, flags=re.DOTALL):
end_pos = match.end()
Expand All @@ -102,7 +102,7 @@ def add_info_to_stub(cls: Any, method: Any, is_property: bool = False) -> None:
second_half_contents = contents[end_pos:]
contents = first_half_contents + method_stub_content + second_half_contents
else: # pragma: no cover
msg = f"Could not find the end of the {cls.__class__.__name__} class."
msg = f"Could not find the end of the {cls.__name__} class."
raise ValueError(msg)

with open(method_filepath, "w", encoding="utf-8") as file_pointer:
Expand Down
6 changes: 6 additions & 0 deletions tests/samples/golden_stubs/drivers/device.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@ from typing import List
from tm_devices.helpers import DeviceConfigEntry

class Device(ABC, metaclass=abc.ABCMeta):
class NestedClass:
"""This is a nested class."""

def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""
@property
def existing_property(self) -> int:
"""Return an int."""
@property
def inc_cached_count(self) -> int:
"""Increment a local counter."""
@property
Expand Down
15 changes: 15 additions & 0 deletions tests/samples/golden_stubs/drivers/pi/pi_device.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import abc

from abc import ABC

from tm_devices.helpers import DeviceConfigEntry

class PIDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""
def added_method(self) -> None:
"""Return nothing."""

class OtherDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
18 changes: 18 additions & 0 deletions tests/samples/golden_stubs/drivers/pi/tsp_device.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import abc

from abc import ABC
from dataclasses import dataclass

from tm_devices.helpers import DeviceConfigEntry

class TSPDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""
def added_tsp_method(self) -> None:
"""Return nothing."""

@dataclass(frozen=True)
class CustomDataclass:
value1: str
value2: int = 1
128 changes: 97 additions & 31 deletions tests/test_extension_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,66 @@
from tm_devices import DeviceManager
from tm_devices.drivers import AFG3K, AFG3KC
from tm_devices.drivers.device import Device
from tm_devices.drivers.pi.pi_device import PIDevice
from tm_devices.drivers.pi.scopes.scope import Scope
from tm_devices.drivers.pi.signal_generators.afgs.afg import AFG
from tm_devices.drivers.pi.signal_generators.signal_generator import SignalGenerator
from tm_devices.drivers.pi.tsp_device import TSPDevice

INITIAL_DEVICE_INPUT = '''import abc
from abc import ABC

from tm_devices.helpers import DeviceConfigEntry

class Device(ABC, metaclass=abc.ABCMeta):
class NestedClass:
"""This is a nested class."""
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""
@property
def existing_property(self) -> int:
"""Return an int."""

def function_1(arg1: str, arg2: int = 1) -> bool: ...

class OtherDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...

def function_2(arg1: str, arg2: int = 2) -> bool: ...
'''
INITIAL_PI_DEVICE_INPUT = '''import abc

from abc import ABC

from tm_devices.helpers import DeviceConfigEntry

class PIDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""

class OtherDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
'''
INITIAL_TSP_DEVICE_INPUT = '''import abc

from abc import ABC
from dataclasses import dataclass
from tm_devices.helpers import DeviceConfigEntry

class TSPDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""

@dataclass(frozen=True)
class CustomDataclass:

value1: str
value2: int = 1

'''


@pytest.fixture(scope="module", autouse=True)
Expand All @@ -42,6 +99,8 @@ def _remove_added_methods() -> Iterator[None]:
(AFG, "custom_model_getter_afg"),
(AFG3K, "custom_model_getter_afg3k"),
(AFG3KC, "custom_model_getter_afg3kc"),
(PIDevice, "added_method"),
(TSPDevice, "added_tsp_method"),
):
with contextlib.suppress(AttributeError):
delattr(obj, name)
Expand Down Expand Up @@ -77,34 +136,23 @@ def gen_count() -> Iterator[int]:

local_count = gen_count()

initial_input = '''import abc
from abc import ABC

from tm_devices.helpers import DeviceConfigEntry

class Device(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""

def function_1(arg1: str, arg2: int = 1) -> bool: ...

class OtherDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...

def function_2(arg1: str, arg2: int = 2) -> bool: ...
'''
sub_filepath = Path("drivers/device.pyi")
golden_stub_dir = Path(__file__).parent / "samples" / "golden_stubs"
stub_device_filepath = Path("drivers/device.pyi")
stub_pi_device_filepath = Path("drivers/pi/pi_device.pyi")
stub_tsp_device_filepath = Path("drivers/pi/tsp_device.pyi")
generated_stub_dir = (
Path(__file__).parent
/ "samples/generated_stubs"
/ f"output_{sys.version_info.major}{sys.version_info.minor}/tm_devices"
)
generated_stub_file = generated_stub_dir / sub_filepath
golden_stub_dir = Path(__file__).parent / "samples" / "golden_stubs"
generated_stub_file.parent.mkdir(parents=True, exist_ok=True)
with open(generated_stub_file, "w", encoding="utf-8") as generated_file:
generated_file.write(initial_input)
generated_device_stub_file = generated_stub_dir / stub_device_filepath
generated_device_stub_file.parent.mkdir(parents=True, exist_ok=True)
generated_pi_device_stub_file = generated_stub_dir / stub_pi_device_filepath
generated_tsp_device_stub_file = generated_stub_dir / stub_tsp_device_filepath
generated_pi_device_stub_file.parent.mkdir(parents=True, exist_ok=True)
generated_device_stub_file.write_text(INITIAL_DEVICE_INPUT, encoding="utf-8")
generated_pi_device_stub_file.write_text(INITIAL_PI_DEVICE_INPUT, encoding="utf-8")
generated_tsp_device_stub_file.write_text(INITIAL_TSP_DEVICE_INPUT, encoding="utf-8")
with mock.patch.dict("os.environ", {"TM_DEVICES_STUB_DIR": str(generated_stub_dir)}):
# noinspection PyUnusedLocal,PyShadowingNames
@Device.add_property(is_cached=True)
Expand Down Expand Up @@ -153,6 +201,14 @@ def custom_return_none() -> None:
def already_exists() -> None:
"""Return nothing."""

@PIDevice.add_method
def added_method() -> None:
"""Return nothing."""

@TSPDevice.add_method
def added_tsp_method() -> None:
"""Return nothing."""

with pytest.raises(AssertionError):

@Scope.add_method
Expand Down Expand Up @@ -187,15 +243,15 @@ def custom_model_getter_afg3kc(device: AFG3KC, value: str) -> str:
############################################################################################
start_dir = os.getcwd()
try:
os.chdir(generated_stub_file.parent)
os.chdir(generated_stub_dir)
subprocess.check_call( # noqa: S603
[
sys.executable,
"-m",
"ruff",
"format",
"--quiet",
generated_stub_file.name,
generated_stub_dir,
]
)
subprocess.check_call( # noqa: S603
Expand All @@ -207,16 +263,26 @@ def custom_model_getter_afg3kc(device: AFG3KC, value: str) -> str:
"--quiet",
"--select=I",
"--fix",
generated_stub_file.name,
generated_stub_dir,
]
)
finally:
os.chdir(start_dir)
with open(golden_stub_dir / sub_filepath, encoding="utf-8") as golden_file:
golden_contents = golden_file.read()
with open(generated_stub_file, encoding="utf-8") as generated_file:
generated_contents = generated_file.read()
assert generated_contents == golden_contents

# Compare the file contents
golden_device_contents = (golden_stub_dir / stub_device_filepath).read_text(encoding="utf-8")
generated_device_contents = generated_device_stub_file.read_text(encoding="utf-8")
assert generated_device_contents == golden_device_contents
golden_pi_device_contents = (golden_stub_dir / stub_pi_device_filepath).read_text(
encoding="utf-8"
)
generated_pi_device_contents = generated_pi_device_stub_file.read_text(encoding="utf-8")
assert generated_pi_device_contents == golden_pi_device_contents
golden_tsp_device_contents = (golden_stub_dir / stub_tsp_device_filepath).read_text(
encoding="utf-8"
)
generated_tsp_device_contents = generated_tsp_device_stub_file.read_text(encoding="utf-8")
assert generated_tsp_device_contents == golden_tsp_device_contents

# Test the custom added properties
afg = device_manager.add_afg("afg3252c-hostname", alias="testing")
Expand Down
Loading