Skip to content

Commit

Permalink
fix: Update stub generation helper function to handle classes followe…
Browse files Browse the repository at this point in the history
…d by dataclasses (#307)

* fix: Updated the stub generation helper function to properly detect dataclasses that immediately follow a class that is having methods added to it via the extension mechanism
  • Loading branch information
nfelt14 authored Sep 17, 2024
1 parent 4efdddc commit 74553a4
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 33 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ Things to be included in the next release go here.
- _**SEMI-BREAKING CHANGE**_: Changed the `USB_MODEL_ID_LOOKUP` constant to use `SupportedModels` as keys instead of values to make the documentation clearer.
- _**SEMI-BREAKING CHANGE**_: Changed the `DEVICE_DRIVER_MODEL_MAPPING` constant to use `SupportedModels` as keys instead of values to make the documentation clearer.

### Fixed

- Fixed a bug in the stubgen helper code responsible for adding dynamically added methods to stub files that caused invalid stub files to be created if a dataclass immediately followed a class that was being dynamically updated.

---

## v2.3.0 (2024-08-23)
Expand Down
4 changes: 2 additions & 2 deletions src/tm_devices/helpers/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def add_info_to_stub(cls: Any, method: Any, is_property: bool = False) -> None:
if typing_imports:
contents = f"from typing import {', '.join(typing_imports)}\n" + contents
# Use a regular expression to find the end of the current class
pattern = r"(class\s+" + cls.__name__ + r"\b.*?)(\n(?=def|class)|\Z)"
pattern = r"(class\s+" + cls.__name__ + r"\b.*?)(\n(?=def|class|@)|\Z)"
# Insert the new code at the end of the current class
if match := re.search(pattern, contents, flags=re.DOTALL):
end_pos = match.end()
Expand All @@ -102,7 +102,7 @@ def add_info_to_stub(cls: Any, method: Any, is_property: bool = False) -> None:
second_half_contents = contents[end_pos:]
contents = first_half_contents + method_stub_content + second_half_contents
else: # pragma: no cover
msg = f"Could not find the end of the {cls.__class__.__name__} class."
msg = f"Could not find the end of the {cls.__name__} class."
raise ValueError(msg)

with open(method_filepath, "w", encoding="utf-8") as file_pointer:
Expand Down
6 changes: 6 additions & 0 deletions tests/samples/golden_stubs/drivers/device.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@ from typing import List
from tm_devices.helpers import DeviceConfigEntry

class Device(ABC, metaclass=abc.ABCMeta):
class NestedClass:
"""This is a nested class."""

def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""
@property
def existing_property(self) -> int:
"""Return an int."""
@property
def inc_cached_count(self) -> int:
"""Increment a local counter."""
@property
Expand Down
15 changes: 15 additions & 0 deletions tests/samples/golden_stubs/drivers/pi/pi_device.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import abc

from abc import ABC

from tm_devices.helpers import DeviceConfigEntry

class PIDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""
def added_method(self) -> None:
"""Return nothing."""

class OtherDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
18 changes: 18 additions & 0 deletions tests/samples/golden_stubs/drivers/pi/tsp_device.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import abc

from abc import ABC
from dataclasses import dataclass

from tm_devices.helpers import DeviceConfigEntry

class TSPDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""
def added_tsp_method(self) -> None:
"""Return nothing."""

@dataclass(frozen=True)
class CustomDataclass:
value1: str
value2: int = 1
128 changes: 97 additions & 31 deletions tests/test_extension_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,66 @@
from tm_devices import DeviceManager
from tm_devices.drivers import AFG3K, AFG3KC
from tm_devices.drivers.device import Device
from tm_devices.drivers.pi.pi_device import PIDevice
from tm_devices.drivers.pi.scopes.scope import Scope
from tm_devices.drivers.pi.signal_generators.afgs.afg import AFG
from tm_devices.drivers.pi.signal_generators.signal_generator import SignalGenerator
from tm_devices.drivers.pi.tsp_device import TSPDevice

INITIAL_DEVICE_INPUT = '''import abc
from abc import ABC
from tm_devices.helpers import DeviceConfigEntry
class Device(ABC, metaclass=abc.ABCMeta):
class NestedClass:
"""This is a nested class."""
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""
@property
def existing_property(self) -> int:
"""Return an int."""
def function_1(arg1: str, arg2: int = 1) -> bool: ...
class OtherDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def function_2(arg1: str, arg2: int = 2) -> bool: ...
'''
INITIAL_PI_DEVICE_INPUT = '''import abc
from abc import ABC
from tm_devices.helpers import DeviceConfigEntry
class PIDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""
class OtherDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
'''
INITIAL_TSP_DEVICE_INPUT = '''import abc
from abc import ABC
from dataclasses import dataclass
from tm_devices.helpers import DeviceConfigEntry
class TSPDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""
@dataclass(frozen=True)
class CustomDataclass:
value1: str
value2: int = 1
'''


@pytest.fixture(scope="module", autouse=True)
Expand All @@ -42,6 +99,8 @@ def _remove_added_methods() -> Iterator[None]:
(AFG, "custom_model_getter_afg"),
(AFG3K, "custom_model_getter_afg3k"),
(AFG3KC, "custom_model_getter_afg3kc"),
(PIDevice, "added_method"),
(TSPDevice, "added_tsp_method"),
):
with contextlib.suppress(AttributeError):
delattr(obj, name)
Expand Down Expand Up @@ -77,34 +136,23 @@ def gen_count() -> Iterator[int]:

local_count = gen_count()

initial_input = '''import abc
from abc import ABC
from tm_devices.helpers import DeviceConfigEntry
class Device(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""
def function_1(arg1: str, arg2: int = 1) -> bool: ...
class OtherDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def function_2(arg1: str, arg2: int = 2) -> bool: ...
'''
sub_filepath = Path("drivers/device.pyi")
golden_stub_dir = Path(__file__).parent / "samples" / "golden_stubs"
stub_device_filepath = Path("drivers/device.pyi")
stub_pi_device_filepath = Path("drivers/pi/pi_device.pyi")
stub_tsp_device_filepath = Path("drivers/pi/tsp_device.pyi")
generated_stub_dir = (
Path(__file__).parent
/ "samples/generated_stubs"
/ f"output_{sys.version_info.major}{sys.version_info.minor}/tm_devices"
)
generated_stub_file = generated_stub_dir / sub_filepath
golden_stub_dir = Path(__file__).parent / "samples" / "golden_stubs"
generated_stub_file.parent.mkdir(parents=True, exist_ok=True)
with open(generated_stub_file, "w", encoding="utf-8") as generated_file:
generated_file.write(initial_input)
generated_device_stub_file = generated_stub_dir / stub_device_filepath
generated_device_stub_file.parent.mkdir(parents=True, exist_ok=True)
generated_pi_device_stub_file = generated_stub_dir / stub_pi_device_filepath
generated_tsp_device_stub_file = generated_stub_dir / stub_tsp_device_filepath
generated_pi_device_stub_file.parent.mkdir(parents=True, exist_ok=True)
generated_device_stub_file.write_text(INITIAL_DEVICE_INPUT, encoding="utf-8")
generated_pi_device_stub_file.write_text(INITIAL_PI_DEVICE_INPUT, encoding="utf-8")
generated_tsp_device_stub_file.write_text(INITIAL_TSP_DEVICE_INPUT, encoding="utf-8")
with mock.patch.dict("os.environ", {"TM_DEVICES_STUB_DIR": str(generated_stub_dir)}):
# noinspection PyUnusedLocal,PyShadowingNames
@Device.add_property(is_cached=True)
Expand Down Expand Up @@ -153,6 +201,14 @@ def custom_return_none() -> None:
def already_exists() -> None:
"""Return nothing."""

@PIDevice.add_method
def added_method() -> None:
"""Return nothing."""

@TSPDevice.add_method
def added_tsp_method() -> None:
"""Return nothing."""

with pytest.raises(AssertionError):

@Scope.add_method
Expand Down Expand Up @@ -187,15 +243,15 @@ def custom_model_getter_afg3kc(device: AFG3KC, value: str) -> str:
############################################################################################
start_dir = os.getcwd()
try:
os.chdir(generated_stub_file.parent)
os.chdir(generated_stub_dir)
subprocess.check_call( # noqa: S603
[
sys.executable,
"-m",
"ruff",
"format",
"--quiet",
generated_stub_file.name,
generated_stub_dir,
]
)
subprocess.check_call( # noqa: S603
Expand All @@ -207,16 +263,26 @@ def custom_model_getter_afg3kc(device: AFG3KC, value: str) -> str:
"--quiet",
"--select=I",
"--fix",
generated_stub_file.name,
generated_stub_dir,
]
)
finally:
os.chdir(start_dir)
with open(golden_stub_dir / sub_filepath, encoding="utf-8") as golden_file:
golden_contents = golden_file.read()
with open(generated_stub_file, encoding="utf-8") as generated_file:
generated_contents = generated_file.read()
assert generated_contents == golden_contents

# Compare the file contents
golden_device_contents = (golden_stub_dir / stub_device_filepath).read_text(encoding="utf-8")
generated_device_contents = generated_device_stub_file.read_text(encoding="utf-8")
assert generated_device_contents == golden_device_contents
golden_pi_device_contents = (golden_stub_dir / stub_pi_device_filepath).read_text(
encoding="utf-8"
)
generated_pi_device_contents = generated_pi_device_stub_file.read_text(encoding="utf-8")
assert generated_pi_device_contents == golden_pi_device_contents
golden_tsp_device_contents = (golden_stub_dir / stub_tsp_device_filepath).read_text(
encoding="utf-8"
)
generated_tsp_device_contents = generated_tsp_device_stub_file.read_text(encoding="utf-8")
assert generated_tsp_device_contents == golden_tsp_device_contents

# Test the custom added properties
afg = device_manager.add_afg("afg3252c-hostname", alias="testing")
Expand Down

0 comments on commit 74553a4

Please sign in to comment.