Skip to content

Commit

Permalink
Fix the stubgen helper function to attach stubs to the correct class …
Browse files Browse the repository at this point in the history
…in modules with multiple classes (#276)

* test: Moved the extension mixin test into its own, separate file for easier testing in the future

* fix: Updated the stubgen code to properly detect the location to add the stub code to

* test: Ignore some code for coverage reporting that should be unreachable code

* docs: Update the changelog with information on the stubgen fix
  • Loading branch information
nfelt14 authored Aug 14, 2024
1 parent cbd2c41 commit e8cb8a9
Show file tree
Hide file tree
Showing 7 changed files with 324 additions and 282 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ Valid subsections within a version are:

Things to be included in the next release go here.

### Fixed

- Fixed the stubgen helper to properly attach stubs to the correct class in modules that have multiple classes.

---

## v2.2.1 (2024-08-07)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ yamlfix = "^1.16.0"
[tool.poetry.group.docs.dependencies]
black = "^24.4.2"
codespell = "^2.2.6"
griffe = "^0.47.0"
mkdocs = "^1.6.0"
mkdocs-ezglossary-plugin = "^1.6.10"
mkdocs-gen-files = "^0.5.0"
Expand Down
18 changes: 16 additions & 2 deletions src/tm_devices/helpers/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _get_data_type(data_object: Any) -> str:


# pylint: disable=too-many-locals
def add_info_to_stub(cls: Any, method: Any, is_property: bool = False) -> None: # noqa: C901
def add_info_to_stub(cls: Any, method: Any, is_property: bool = False) -> None: # noqa: C901,PLR0912
"""Add information to a stub file.
This method requires that an environment variable named ``TM_DEVICES_STUB_DIR`` is defined that
Expand All @@ -42,6 +42,7 @@ def add_info_to_stub(cls: Any, method: Any, is_property: bool = False) -> None:
Raises:
AssertionError: Indicates that the file that needs to be updated does not exist.
ValueError: Indicates that the class could not be found in the stub file.
"""
if stub_dir := os.getenv("TM_DEVICES_STUB_DIR"):
method_filepath = inspect.getfile(cls)
Expand Down Expand Up @@ -88,8 +89,21 @@ def add_info_to_stub(cls: Any, method: Any, is_property: bool = False) -> None:
with open(method_filepath, encoding="utf-8") as file_pointer:
contents = file_pointer.read()
if f" def {method.__name__}(" not in contents:
contents += method_stub_content
if typing_imports:
contents = f"from typing import {', '.join(typing_imports)}\n" + contents
# Use a regular expression to find the end of the current class
pattern = r"(class\s+" + cls.__name__ + r"\b.*?)(\n(?=def|class)|\Z)"
# Insert the new code at the end of the current class
if match := re.search(pattern, contents, flags=re.DOTALL):
end_pos = match.end()
first_half_contents = contents[:end_pos]
if first_half_contents.endswith("\n\n"):
first_half_contents = first_half_contents[:-1]
second_half_contents = contents[end_pos:]
contents = first_half_contents + method_stub_content + second_half_contents
else: # pragma: no cover
msg = f"Could not find the end of the {cls.__class__.__name__} class."
raise ValueError(msg)

with open(method_filepath, "w", encoding="utf-8") as file_pointer:
file_pointer.write(contents)
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,20 @@ def fixture_device_manager() -> Generator[DeviceManager, None, None]:
yield dev_manager


@pytest.fixture(autouse=True)
def _reset_dm(device_manager: DeviceManager) -> Generator[None, None, None]: # pyright: ignore[reportUnusedFunction]
"""Reset the device_manager settings after each test.
Args:
device_manager: The device manager fixture.
"""
saved_setup_enable = device_manager.setup_cleanup_enabled
saved_teardown_enable = device_manager.teardown_cleanup_enabled
yield
device_manager.setup_cleanup_enabled = saved_setup_enable
device_manager.teardown_cleanup_enabled = saved_teardown_enable


@pytest.fixture(name="mock_http_server", scope="session")
def _fixture_mock_http_server() -> ( # pyright: ignore [reportUnusedFunction]
Generator[None, None, None]
Expand Down
7 changes: 7 additions & 0 deletions tests/samples/golden_stubs/drivers/device.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,10 @@ class Device(ABC, metaclass=abc.ABCMeta):
This has a multi-line description.
"""

def function_1(arg1: str, arg2: int = 1) -> bool: ...

class OtherDevice(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...

def function_2(arg1: str, arg2: int = 2) -> bool: ...
280 changes: 0 additions & 280 deletions tests/test_device_manager.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
# pyright: reportUnusedFunction=none
# pyright: reportUnknownMemberType=none
# pyright: reportAttributeAccessIssue=none
# pyright: reportUnknownVariableType=none
# pyright: reportArgumentType=none
"""Tests for the device_manager.py file."""

import contextlib
import os
import subprocess
import sys

from pathlib import Path
from typing import Generator, Iterator, List
from unittest import mock

import pytest
Expand All @@ -20,51 +13,9 @@

from conftest import SIMULATED_VISA_LIB
from tm_devices import DeviceManager
from tm_devices.drivers import AFG3K, AFG3KC
from tm_devices.drivers.device import Device
from tm_devices.drivers.pi.scopes.scope import Scope
from tm_devices.drivers.pi.signal_generators.afgs.afg import AFG
from tm_devices.drivers.pi.signal_generators.signal_generator import SignalGenerator
from tm_devices.helpers import ConnectionTypes, DeviceTypes, PYVISA_PY_BACKEND, SerialConfig


@pytest.fixture(scope="module", autouse=True)
def _remove_added_methods() -> Iterator[None]:
"""Remove custom added methods from devices."""
yield
for obj, name in (
(Device, "inc_cached_count"),
(Device, "inc_count"),
(Device, "class_name"),
(Device, "custom_model_getter"),
(Device, "custom_list"),
(Device, "custom_return_none"),
(Device, "already_exists"),
(Scope, "custom_model_getter_scope"),
(Scope, "custom_return"),
(SignalGenerator, "custom_model_getter_ss"),
(AFG, "custom_model_getter_afg"),
(AFG3K, "custom_model_getter_afg3k"),
(AFG3KC, "custom_model_getter_afg3kc"),
):
with contextlib.suppress(AttributeError):
delattr(obj, name)


@pytest.fixture(autouse=True)
def _reset_dm(device_manager: DeviceManager) -> Generator[None, None, None]:
"""Reset the device_manager settings after each test.
Args:
device_manager: The device manager fixture.
"""
saved_setup_enable = device_manager.setup_cleanup_enabled
saved_teardown_enable = device_manager.teardown_cleanup_enabled
yield
device_manager.setup_cleanup_enabled = saved_setup_enable
device_manager.teardown_cleanup_enabled = saved_teardown_enable


class TestDeviceManager: # pylint: disable=no-self-use
"""Device Manager test class."""

Expand Down Expand Up @@ -222,237 +173,6 @@ def test_dm_properties(self, device_manager: DeviceManager) -> None:
device_manager.verbose = saved_verbose
device_manager.visa_library = saved_visa_lib

# pylint: disable=too-many-locals
def test_visa_device_methods_and_method_adding( # noqa: C901,PLR0915
self, device_manager: DeviceManager, capsys: pytest.CaptureFixture[str]
) -> None:
"""Test methods pertaining to VISA devices.
Args:
device_manager: The DeviceManager object.
capsys: The captured stdout and stderr.
"""
# Remove all previous devices
device_manager.remove_all_devices()
# Read the captured stdout to clear it
_ = capsys.readouterr().out
saved_setup_enable = device_manager.setup_cleanup_enabled
saved_teardown_enable = device_manager.teardown_cleanup_enabled
device_manager.setup_cleanup_enabled = True
device_manager.teardown_cleanup_enabled = True

############################################################################################
# Make sure to add all methods to the remove_added_methods() fixture
# at the top of this test module.

def gen_count() -> Iterator[int]:
"""Local counter."""
count = 0
while True:
count += 1
yield count

local_count = gen_count()

initial_input = '''import abc
from abc import ABC
from tm_devices.helpers import DeviceConfigEntry
class Device(ABC, metaclass=abc.ABCMeta):
def __init__(self, config_entry: DeviceConfigEntry, verbose: bool) -> None: ...
def already_exists(self) -> None:
"""Return nothing."""
'''
sub_filepath = Path("drivers/device.pyi")
generated_stub_dir = (
Path(__file__).parent
/ "samples/generated_stubs"
/ f"output_{sys.version_info.major}{sys.version_info.minor}/tm_devices"
)
generated_stub_file = generated_stub_dir / sub_filepath
golden_stub_dir = Path(__file__).parent / "samples" / "golden_stubs"
generated_stub_file.parent.mkdir(parents=True, exist_ok=True)
with open(generated_stub_file, "w", encoding="utf-8") as generated_file:
generated_file.write(initial_input)
with mock.patch.dict("os.environ", {"TM_DEVICES_STUB_DIR": str(generated_stub_dir)}):
# noinspection PyUnusedLocal,PyShadowingNames
@Device.add_property(is_cached=True)
def inc_cached_count(self: Device) -> int: # noqa: ARG001
"""Increment a local counter."""
return next(local_count)

# noinspection PyUnusedLocal,PyShadowingNames
@Device.add_property(is_cached=False)
def inc_count(self: Device) -> int: # noqa: ARG001
"""Increment a local counter."""
return next(local_count)

# noinspection PyShadowingNames
@Device.add_property
def class_name(self: Device) -> str:
"""Return the class name."""
return self.__class__.__name__

# noinspection PyShadowingNames
@Device.add_method
def custom_model_getter(
self: Device,
value1: str,
value2: str = "add",
value3: str = "",
value4: float = 0.1,
) -> str:
"""Return the model."""
return " ".join(["Device", self.model, value1, value2, value3, str(value4)])

# noinspection PyShadowingNames
@Device.add_method
def custom_list(self: Device) -> List[str]:
"""Return the model and serial in a list."""
return [self.model, self.serial]

@Device.add_method
def custom_return_none() -> None:
"""Return nothing.
This has a multi-line description.
"""

@Device.add_method
def already_exists() -> None:
"""Return nothing."""

with pytest.raises(AssertionError):

@Scope.add_method
def custom_return() -> None:
"""Return nothing."""

@Scope.add_method
def custom_model_getter_scope(device: Scope, value: str) -> str:
"""Return the model."""
return f"Scope {device.model} {value}"

@SignalGenerator.add_method
def custom_model_getter_ss(device: SignalGenerator, value: str) -> str:
"""Return the model."""
return f"SignalGenerator {device.model} {value}"

@AFG.add_method
def custom_model_getter_afg(device: AFG, value: str) -> str:
"""Return the model."""
return f"AFG {device.model} {value}"

@AFG3K.add_method
def custom_model_getter_afg3k(device: AFG3K, value: str) -> str:
"""Return the model."""
return f"AFG3K {device.model} {value}"

@AFG3KC.add_method
def custom_model_getter_afg3kc(device: AFG3KC, value: str) -> str:
"""Return the model."""
return f"AFG3KC {device.model} {value}"

############################################################################################
start_dir = os.getcwd()
try:
os.chdir(generated_stub_file.parent)
subprocess.check_call( # noqa: S603
[
sys.executable,
"-m",
"ruff",
"format",
"--quiet",
generated_stub_file.name,
]
)
subprocess.check_call( # noqa: S603
[
sys.executable,
"-m",
"ruff",
"check",
"--quiet",
"--select=I",
"--fix",
generated_stub_file.name,
]
)
finally:
os.chdir(start_dir)
with open(golden_stub_dir / sub_filepath, encoding="utf-8") as golden_file:
golden_contents = golden_file.read()
with open(generated_stub_file, encoding="utf-8") as generated_file:
generated_contents = generated_file.read()
assert generated_contents == golden_contents

# Test the custom added properties
afg = device_manager.add_afg("afg3252c-hostname", alias="testing")
# noinspection PyUnresolvedReferences
assert afg.class_name == "AFG3KC"
# noinspection PyUnresolvedReferences
_ = afg.inc_cached_count
# noinspection PyUnresolvedReferences
assert afg.inc_cached_count == 1, "cached property is not working"
# noinspection PyUnresolvedReferences
_ = afg.inc_count
# noinspection PyUnresolvedReferences
assert afg.inc_count == 3, "uncached property is not working"

# Test the custom added methods
# noinspection PyUnresolvedReferences
assert afg.custom_model_getter("a", "b", "c", 0.1) == "Device AFG3252C a b c 0.1"
# noinspection PyUnresolvedReferences
assert afg.custom_model_getter_ss("hello") == "SignalGenerator AFG3252C hello"
# noinspection PyUnresolvedReferences
assert afg.custom_model_getter_afg("hello") == "AFG AFG3252C hello"
# noinspection PyUnresolvedReferences
assert afg.custom_model_getter_afg3k("hello") == "AFG3K AFG3252C hello"
# noinspection PyUnresolvedReferences
assert afg.custom_model_getter_afg3kc("hello") == "AFG3KC AFG3252C hello"
with pytest.raises(AttributeError):
# noinspection PyUnresolvedReferences
afg.custom_model_getter_scope("hello")

# Test VISA methods
assert afg.set_and_check("OUTPUT1:STATE", "1", custom_message_prefix="Custom prefix") == "1"
device_manager.disable_device_command_checking()
assert afg.set_and_check("OUTPUT1:STATE", "0") == ""
device_manager.cleanup_all_devices()
console_output = capsys.readouterr()
assert "Beginning Device Cleanup on AFG " in console_output.out
assert "Response from 'OUTPUT1:STATE?' >> '1'" in console_output.out
assert "Response from 'OUTPUT1:STATE?' >> '0'" not in console_output.out
assert console_output.err == ""

assert len(device_manager.devices) == 1
device_manager.close()
assert "Beginning Device Cleanup" in capsys.readouterr().out
assert len(device_manager.devices) == 1

device_manager.setup_cleanup_enabled = False
device_manager.open()
device_manager.verbose_visa = True
afg = device_manager.get_afg(number_or_alias="testing")
afg.ieee_cmds.idn()
assert "pyvisa - DEBUG" in capsys.readouterr().err
device_manager.verbose_visa = False
assert not device_manager.verbose_visa
afg.ieee_cmds.idn()
assert "pyvisa - DEBUG" not in capsys.readouterr().err
device_manager.teardown_cleanup_enabled = False
assert len(device_manager.devices) == 1
device_manager.close()
assert "Beginning Device Cleanup" not in capsys.readouterr().out
assert len(device_manager.devices) == 1

device_manager.open()
device_manager.remove_device(alias="testing")
device_manager.setup_cleanup_enabled = saved_setup_enable
device_manager.teardown_cleanup_enabled = saved_teardown_enable

def test_failed_cleanup(self, device_manager: DeviceManager) -> None:
"""Test what happens when a device manager cleanup fails.
Expand Down
Loading

0 comments on commit e8cb8a9

Please sign in to comment.