diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index 418d4bd0ff..ff4d6f8be9 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -83,6 +83,7 @@ DEFAULT_TIMEOUT, CalculatableTimeout, Callback, + LazyMock, NotConnected, Reference, StrictEnum, @@ -176,6 +177,7 @@ "DEFAULT_TIMEOUT", "CalculatableTimeout", "Callback", + "LazyMock", "CALCULATE_TIMEOUT", "NotConnected", "Reference", diff --git a/src/ophyd_async/core/_device.py b/src/ophyd_async/core/_device.py index 1fe7855f3d..eb43abff58 100644 --- a/src/ophyd_async/core/_device.py +++ b/src/ophyd_async/core/_device.py @@ -3,17 +3,15 @@ import asyncio import sys from collections.abc import Coroutine, Iterator, Mapping, MutableMapping +from functools import cached_property from logging import LoggerAdapter, getLogger from typing import Any, TypeVar -from unittest.mock import Mock from bluesky.protocols import HasName from bluesky.run_engine import call_in_bluesky_event_loop, in_bluesky_event_loop from ._protocol import Connectable -from ._utils import DEFAULT_TIMEOUT, NotConnected, wait_for_connection - -_device_mocks: dict[Device, Mock] = {} +from ._utils import DEFAULT_TIMEOUT, LazyMock, NotConnected, wait_for_connection class DeviceConnector: @@ -37,25 +35,23 @@ def create_children_from_annotations(self, device: Device): during ``__init__``. """ - async def connect( - self, - device: Device, - mock: bool | Mock, - timeout: float, - force_reconnect: bool, - ): + async def connect_mock(self, device: Device, mock: LazyMock): + # Connect serially, no errors to gather up as in mock mode + for name, child_device in device.children(): + await child_device.connect(mock=mock.child(name)) + + async def connect_real(self, device: Device, timeout: float, force_reconnect: bool): """Used during ``Device.connect``. This is called when a previous connect has not been done, or has been done in a different mock more. It should connect the Device and all its children. """ - coros = {} - for name, child_device in device.children(): - child_mock = getattr(mock, name) if mock else mock # Mock() or False - coros[name] = child_device.connect( - mock=child_mock, timeout=timeout, force_reconnect=force_reconnect - ) + # Connect in parallel, gathering up NotConnected errors + coros = { + name: child_device.connect(timeout=timeout, force_reconnect=force_reconnect) + for name, child_device in device.children() + } await wait_for_connection(**coros) @@ -67,9 +63,8 @@ class Device(HasName, Connectable): parent: Device | None = None # None if connect hasn't started, a Task if it has _connect_task: asyncio.Task | None = None - # If not None, then this is the mock arg of the previous connect - # to let us know if we can reuse an existing connection - _connect_mock_arg: bool | None = None + # The mock if we have connected in mock mode + _mock: LazyMock | None = None def __init__( self, name: str = "", connector: DeviceConnector | None = None @@ -83,10 +78,18 @@ def name(self) -> str: """Return the name of the Device""" return self._name + @cached_property + def _child_devices(self) -> dict[str, Device]: + return {} + def children(self) -> Iterator[tuple[str, Device]]: - for attr_name, attr in self.__dict__.items(): - if attr_name != "parent" and isinstance(attr, Device): - yield attr_name, attr + yield from self._child_devices.items() + + @cached_property + def log(self) -> LoggerAdapter: + return LoggerAdapter( + getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name} + ) def set_name(self, name: str): """Set ``self.name=name`` and each ``self.child.name=name+"-child"``. @@ -97,28 +100,33 @@ def set_name(self, name: str): New name to set """ self._name = name - # Ensure self.log is recreated after a name change - self.log = LoggerAdapter( - getLogger("ophyd_async.devices"), {"ophyd_async_device_name": self.name} - ) + # Ensure logger is recreated after a name change + if "log" in self.__dict__: + del self.log for child_name, child in self.children(): child_name = f"{self.name}-{child_name.strip('_')}" if self.name else "" child.set_name(child_name) def __setattr__(self, name: str, value: Any) -> None: + # Bear in mind that this function is called *a lot*, so + # we need to make sure nothing expensive happens in it... if name == "parent": if self.parent not in (value, None): raise TypeError( f"Cannot set the parent of {self} to be {value}: " f"it is already a child of {self.parent}" ) - elif isinstance(value, Device): + # ...hence not doing an isinstance check for attributes we + # know not to be Devices + elif name not in _not_device_attrs and isinstance(value, Device): value.parent = self - return super().__setattr__(name, value) + self._child_devices[name] = value + # ...and avoiding the super call as we know it resolves to `object` + return object.__setattr__(self, name, value) async def connect( self, - mock: bool | Mock = False, + mock: bool | LazyMock = False, timeout: float = DEFAULT_TIMEOUT, force_reconnect: bool = False, ) -> None: @@ -133,26 +141,39 @@ async def connect( timeout: Time to wait before failing with a TimeoutError. """ - uses_mock = bool(mock) - can_use_previous_connect = ( - uses_mock is self._connect_mock_arg - and self._connect_task - and not (self._connect_task.done() and self._connect_task.exception()) - ) - if mock is True: - mock = Mock() # create a new Mock if one not provided - if force_reconnect or not can_use_previous_connect: - self._connect_mock_arg = uses_mock - if self._connect_mock_arg: - _device_mocks[self] = mock - coro = self._connector.connect( - device=self, mock=mock, timeout=timeout, force_reconnect=force_reconnect + if mock: + # Always connect in mock mode serially + if isinstance(mock, LazyMock): + # Use the provided mock + self._mock = mock + elif not self._mock: + # Make one + self._mock = LazyMock() + await self._connector.connect_mock(self, self._mock) + else: + # Try to cache the connect in real mode + can_use_previous_connect = ( + self._mock is None + and self._connect_task + and not (self._connect_task.done() and self._connect_task.exception()) ) - self._connect_task = asyncio.create_task(coro) - - assert self._connect_task, "Connect task not created, this shouldn't happen" - # Wait for it to complete - await self._connect_task + if force_reconnect or not can_use_previous_connect: + self._mock = None + coro = self._connector.connect_real(self, timeout, force_reconnect) + self._connect_task = asyncio.create_task(coro) + assert self._connect_task, "Connect task not created, this shouldn't happen" + # Wait for it to complete + await self._connect_task + + +_not_device_attrs = { + "_name", + "_children", + "_connector", + "_timeout", + "_mock", + "_connect_task", +} DeviceT = TypeVar("DeviceT", bound=Device) diff --git a/src/ophyd_async/core/_mock_signal_backend.py b/src/ophyd_async/core/_mock_signal_backend.py index 878313e051..43fb2ae7df 100644 --- a/src/ophyd_async/core/_mock_signal_backend.py +++ b/src/ophyd_async/core/_mock_signal_backend.py @@ -1,13 +1,13 @@ import asyncio from collections.abc import Callable from functools import cached_property -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock from bluesky.protocols import Descriptor, Reading from ._signal_backend import SignalBackend, SignalDatatypeT from ._soft_signal_backend import SoftSignalBackend -from ._utils import Callback +from ._utils import Callback, LazyMock class MockSignalBackend(SignalBackend[SignalDatatypeT]): @@ -16,7 +16,7 @@ class MockSignalBackend(SignalBackend[SignalDatatypeT]): def __init__( self, initial_backend: SignalBackend[SignalDatatypeT], - mock: Mock, + mock: LazyMock, ) -> None: if isinstance(initial_backend, MockSignalBackend): raise ValueError("Cannot make a MockSignalBackend for a MockSignalBackend") @@ -34,11 +34,14 @@ def __init__( # use existing Mock if provided self.mock = mock - self.put_mock = AsyncMock(name="put", spec=Callable) - self.mock.attach_mock(self.put_mock, "put") - super().__init__(datatype=self.initial_backend.datatype) + @cached_property + def put_mock(self) -> AsyncMock: + put_mock = AsyncMock(name="put", spec=Callable) + self.mock().attach_mock(put_mock, "put") + return put_mock + def set_value(self, value: SignalDatatypeT): self.soft_backend.set_value(value) @@ -46,7 +49,7 @@ def source(self, name: str, read: bool) -> str: return f"mock+{self.initial_backend.source(name, read)}" async def connect(self, timeout: float) -> None: - pass + raise RuntimeError("It is not possible to connect a MockSignalBackend") @cached_property def put_proceeds(self) -> asyncio.Event: diff --git a/src/ophyd_async/core/_mock_signal_utils.py b/src/ophyd_async/core/_mock_signal_utils.py index 30d48dbfe0..08976a0468 100644 --- a/src/ophyd_async/core/_mock_signal_utils.py +++ b/src/ophyd_async/core/_mock_signal_utils.py @@ -2,17 +2,26 @@ from contextlib import asynccontextmanager, contextmanager from unittest.mock import AsyncMock, Mock -from ._device import Device, _device_mocks +from ._device import Device from ._mock_signal_backend import MockSignalBackend -from ._signal import Signal, SignalR, _mock_signal_backends +from ._signal import Signal, SignalConnector, SignalR from ._soft_signal_backend import SignalDatatypeT +from ._utils import LazyMock + + +def get_mock(device: Device | Signal) -> Mock: + mock = device._mock # noqa: SLF001 + assert isinstance(mock, LazyMock), f"Device {device} not connected in mock mode" + return mock() def _get_mock_signal_backend(signal: Signal) -> MockSignalBackend: - assert ( - signal in _mock_signal_backends + connector = signal._connector # noqa: SLF001 + assert isinstance(connector, SignalConnector), f"Expected Signal, got {signal}" + assert isinstance( + connector.backend, MockSignalBackend ), f"Signal {signal} not connected in mock mode" - return _mock_signal_backends[signal] + return connector.backend def set_mock_value(signal: Signal[SignalDatatypeT], value: SignalDatatypeT): @@ -45,12 +54,6 @@ def get_mock_put(signal: Signal) -> AsyncMock: return _get_mock_signal_backend(signal).put_mock -def get_mock(device: Device | Signal) -> Mock: - if isinstance(device, Signal): - return _get_mock_signal_backend(device).mock - return _device_mocks[device] - - def reset_mock_put_calls(signal: Signal): backend = _get_mock_signal_backend(signal) backend.put_mock.reset_mock() diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 22f4ffdc57..d4d4d7ffe3 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -4,7 +4,6 @@ import functools from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping from typing import Any, Generic, cast -from unittest.mock import Mock from bluesky.protocols import ( Locatable, @@ -30,9 +29,14 @@ ) from ._soft_signal_backend import SoftSignalBackend from ._status import AsyncStatus -from ._utils import CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, Callback, T - -_mock_signal_backends: dict[Device, MockSignalBackend] = {} +from ._utils import ( + CALCULATE_TIMEOUT, + DEFAULT_TIMEOUT, + CalculatableTimeout, + Callback, + LazyMock, + T, +) async def _wait_for(coro: Awaitable[T], timeout: float | None, source: str) -> T: @@ -54,26 +58,28 @@ class SignalConnector(DeviceConnector): def __init__(self, backend: SignalBackend): self.backend = self._init_backend = backend - async def connect( - self, - device: Device, - mock: bool | Mock, - timeout: float, - force_reconnect: bool, - ): - if mock: - self.backend = MockSignalBackend(self._init_backend, mock) - _mock_signal_backends[device] = self.backend - else: - self.backend = self._init_backend + async def connect_mock(self, device: Device, mock: LazyMock): + self.backend = MockSignalBackend(self._init_backend, mock) + + async def connect_real(self, device: Device, timeout: float, force_reconnect: bool): + self.backend = self._init_backend device.log.debug(f"Connecting to {self.backend.source(device.name, read=True)}") await self.backend.connect(timeout) +class _ChildrenNotAllowed(dict[str, Device]): + def __setitem__(self, key: str, value: Device) -> None: + raise AttributeError( + f"Cannot add Device or Signal child {key}={value} of Signal, " + "make a subclass of Device instead" + ) + + class Signal(Device, Generic[SignalDatatypeT]): """A Device with the concept of a value, with R, RW, W and X flavours""" _connector: SignalConnector + _child_devices = _ChildrenNotAllowed() # type: ignore def __init__( self, @@ -89,14 +95,6 @@ def source(self) -> str: """Like ca://PV_PREFIX:SIGNAL, or "" if not set""" return self._connector.backend.source(self.name, read=True) - def __setattr__(self, name: str, value: Any) -> None: - if name != "parent" and isinstance(value, Device): - raise AttributeError( - f"Cannot add Device or Signal {value} as a child of Signal {self}, " - "make a subclass of Device instead" - ) - return super().__setattr__(name, value) - class _SignalCache(Generic[SignalDatatypeT]): def __init__(self, backend: SignalBackend[SignalDatatypeT], signal: Signal): diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index d0e48c7212..ba21c3ba9b 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -4,6 +4,7 @@ from abc import abstractmethod from collections.abc import Sequence from dataclasses import dataclass +from functools import lru_cache from typing import Any, Generic, get_origin import numpy as np @@ -90,6 +91,7 @@ def write_value(self, value: Any) -> TableT: raise TypeError(f"Cannot convert {value} to {self.datatype}") +@lru_cache def make_converter(datatype: type[SignalDatatype]) -> SoftConverter: enum_cls = get_enum_cls(datatype) if datatype == Sequence[str]: diff --git a/src/ophyd_async/core/_utils.py b/src/ophyd_async/core/_utils.py index db4afae04a..ca20d90a3c 100644 --- a/src/ophyd_async/core/_utils.py +++ b/src/ophyd_async/core/_utils.py @@ -14,6 +14,7 @@ get_args, get_origin, ) +from unittest.mock import Mock import numpy as np @@ -120,20 +121,29 @@ async def wait_for_connection(**coros: Awaitable[None]): Expected kwargs should be a mapping of names to coroutine tasks to execute. """ - results = await asyncio.gather(*coros.values(), return_exceptions=True) - exceptions = {} + exceptions: dict[str, Exception] = {} + if len(coros) == 1: + # Single device optimization + name, coro = coros.popitem() + try: + await coro + except Exception as e: + exceptions[name] = e + else: + # Use gather to connect in parallel + results = await asyncio.gather(*coros.values(), return_exceptions=True) + for name, result in zip(coros, results, strict=False): + if isinstance(result, Exception): + exceptions[name] = result - for name, result in zip(coros, results, strict=False): - if isinstance(result, Exception): - exceptions[name] = result - if not isinstance(result, NotConnected): + if exceptions: + for name, exception in exceptions.items(): + if not isinstance(exception, NotConnected): logging.exception( f"device `{name}` raised unexpected exception " - f"{type(result).__name__}", - exc_info=result, + f"{type(exception).__name__}", + exc_info=exception, ) - - if exceptions: raise NotConnected(exceptions) @@ -252,3 +262,38 @@ def __init__(self, obj: T): def __call__(self) -> T: return self._obj + + +class LazyMock: + """A lazily created Mock to be used when connecting in mock mode. + + Creating Mocks is reasonably expensive when each Device (and Signal) + requires its own, and the tree is only used when ``Signal.set()`` is + called. This class allows a tree of lazily connected Mocks to be + constructed so that when the leaf is created, so are its parents. + Any calls to the child are then accessible from the parent mock. + + >>> parent = LazyMock() + >>> child = parent.child("child") + >>> child_mock = child() + >>> child_mock() # doctest: +ELLIPSIS + + >>> parent_mock = parent() + >>> parent_mock.mock_calls + [call.child()] + """ + + def __init__(self, name: str = "", parent: LazyMock | None = None) -> None: + self.parent = parent + self.name = name + self._mock: Mock | None = None + + def child(self, name: str) -> LazyMock: + return LazyMock(name, self) + + def __call__(self) -> Mock: + if self._mock is None: + self._mock = Mock(spec=object) + if self.parent is not None: + self.parent().attach_mock(self._mock, self.name) + return self._mock diff --git a/src/ophyd_async/epics/adcore/_single_trigger.py b/src/ophyd_async/epics/adcore/_single_trigger.py index 9fd81b413d..165204d371 100644 --- a/src/ophyd_async/epics/adcore/_single_trigger.py +++ b/src/ophyd_async/epics/adcore/_single_trigger.py @@ -19,7 +19,8 @@ def __init__( **plugins: NDPluginBaseIO, ) -> None: self.drv = drv - self.__dict__.update(plugins) + for k, v in plugins.items(): + setattr(self, k, v) self.add_readables( [self.drv.array_counter, *read_uncached], diff --git a/src/ophyd_async/epics/core/_pvi_connector.py b/src/ophyd_async/epics/core/_pvi_connector.py index 812e4ec473..1c5c0eceb6 100644 --- a/src/ophyd_async/epics/core/_pvi_connector.py +++ b/src/ophyd_async/epics/core/_pvi_connector.py @@ -1,7 +1,5 @@ from __future__ import annotations -from unittest.mock import Mock - from ophyd_async.core import ( Device, DeviceConnector, @@ -11,6 +9,7 @@ SignalRW, SignalX, ) +from ophyd_async.core._utils import LazyMock from ._epics_connector import fill_backend_with_prefix from ._signal import PvaSignalBackend, pvget_with_timeout @@ -64,29 +63,29 @@ def _fill_child(self, name: str, entry: Entry, vector_index: int | None = None): backend.read_pv = read_pv backend.write_pv = write_pv - async def connect( - self, device: Device, mock: bool | Mock, timeout: float, force_reconnect: bool + async def connect_mock(self, device: Device, mock: LazyMock): + self.filler.create_device_vector_entries_to_mock(2) + # Set the name of the device to name all children + device.set_name(device.name) + return await super().connect_mock(device, mock) + + async def connect_real( + self, device: Device, timeout: float, force_reconnect: bool ) -> None: - if mock: - # Make 2 entries for each DeviceVector - self.filler.create_device_vector_entries_to_mock(2) - else: - pvi_structure = await pvget_with_timeout(self.pvi_pv, timeout) - entries: dict[str, Entry | list[Entry | None]] = pvi_structure[ - "value" - ].todict() - # Fill based on what PVI gives us - for name, entry in entries.items(): - if isinstance(entry, dict): - # This is a child - self._fill_child(name, entry) - else: - # This is a DeviceVector of children - for i, e in enumerate(entry): - if e: - self._fill_child(name, e, i) - # Check that all the requested children have been filled - self.filler.check_filled(f"{self.pvi_pv}: {entries}") + pvi_structure = await pvget_with_timeout(self.pvi_pv, timeout) + entries: dict[str, Entry | list[Entry | None]] = pvi_structure["value"].todict() + # Fill based on what PVI gives us + for name, entry in entries.items(): + if isinstance(entry, dict): + # This is a child + self._fill_child(name, entry) + else: + # This is a DeviceVector of children + for i, e in enumerate(entry): + if e: + self._fill_child(name, e, i) + # Check that all the requested children have been filled + self.filler.check_filled(f"{self.pvi_pv}: {entries}") # Set the name of the device to name all children device.set_name(device.name) - return await super().connect(device, mock, timeout, force_reconnect) + return await super().connect_real(device, timeout, force_reconnect) diff --git a/src/ophyd_async/plan_stubs/_ensure_connected.py b/src/ophyd_async/plan_stubs/_ensure_connected.py index d4835b710c..0ad5cff518 100644 --- a/src/ophyd_async/plan_stubs/_ensure_connected.py +++ b/src/ophyd_async/plan_stubs/_ensure_connected.py @@ -1,13 +1,11 @@ -from unittest.mock import Mock - import bluesky.plan_stubs as bps -from ophyd_async.core import DEFAULT_TIMEOUT, Device, wait_for_connection +from ophyd_async.core import DEFAULT_TIMEOUT, Device, LazyMock, wait_for_connection def ensure_connected( *devices: Device, - mock: bool | Mock = False, + mock: bool | LazyMock = False, timeout: float = DEFAULT_TIMEOUT, force_reconnect=False, ): diff --git a/src/ophyd_async/tango/base_devices/_base_device.py b/src/ophyd_async/tango/base_devices/_base_device.py index 2227d5ddbc..1f98c4bd4f 100644 --- a/src/ophyd_async/tango/base_devices/_base_device.py +++ b/src/ophyd_async/tango/base_devices/_base_device.py @@ -1,9 +1,9 @@ from __future__ import annotations from typing import TypeVar -from unittest.mock import Mock from ophyd_async.core import Device, DeviceConnector, DeviceFiller +from ophyd_async.core._utils import LazyMock from ophyd_async.tango.signal import ( TangoSignalBackend, infer_python_type, @@ -117,41 +117,42 @@ def create_children_from_annotations(self, device: Device): list(self.filler.create_signals_from_annotations(filled=False)) self.filler.check_created() - async def connect( - self, device: Device, mock: bool | Mock, timeout: float, force_reconnect: bool - ) -> None: - if mock: - # Make 2 entries for each DeviceVector - self.filler.create_device_vector_entries_to_mock(2) + async def connect_mock(self, device: Device, mock: LazyMock): + # Make 2 entries for each DeviceVector + self.filler.create_device_vector_entries_to_mock(2) + # Set the name of the device to name all children + device.set_name(device.name) + return await super().connect_mock(device, mock) + + async def connect_real(self, device: Device, timeout: float, force_reconnect: bool): + if self.trl and self.proxy is None: + self.proxy = await AsyncDeviceProxy(self.trl) + elif self.proxy and not self.trl: + self.trl = self.proxy.name() else: - if self.trl and self.proxy is None: - self.proxy = await AsyncDeviceProxy(self.trl) - elif self.proxy and not self.trl: - self.trl = self.proxy.name() - else: - raise TypeError("Neither proxy nor trl supplied") - - children = sorted( - set() - .union(self.proxy.get_attribute_list()) - .union(self.proxy.get_command_list()) - ) - for name in children: - # TODO: strip attribute name - full_trl = f"{self.trl}/{name}" - signal_type = await infer_signal_type(full_trl, self.proxy) - if signal_type: - backend = self.filler.fill_child_signal(name, signal_type) - backend.datatype = await infer_python_type(full_trl, self.proxy) - backend.set_trl(full_trl) - if polling := self._signal_polling.get(name, ()): - backend.set_polling(*polling) - backend.allow_events(False) - elif self._polling[0]: - backend.set_polling(*self._polling) - backend.allow_events(False) - # Check that all the requested children have been filled - self.filler.check_filled(f"{self.trl}: {children}") + raise TypeError("Neither proxy nor trl supplied") + + children = sorted( + set() + .union(self.proxy.get_attribute_list()) + .union(self.proxy.get_command_list()) + ) + for name in children: + # TODO: strip attribute name + full_trl = f"{self.trl}/{name}" + signal_type = await infer_signal_type(full_trl, self.proxy) + if signal_type: + backend = self.filler.fill_child_signal(name, signal_type) + backend.datatype = await infer_python_type(full_trl, self.proxy) + backend.set_trl(full_trl) + if polling := self._signal_polling.get(name, ()): + backend.set_polling(*polling) + backend.allow_events(False) + elif self._polling[0]: + backend.set_polling(*self._polling) + backend.allow_events(False) + # Check that all the requested children have been filled + self.filler.check_filled(f"{self.trl}: {children}") # Set the name of the device to name all children device.set_name(device.name) - return await super().connect(device, mock, timeout, force_reconnect) + return await super().connect_real(device, timeout, force_reconnect) diff --git a/tests/core/test_device.py b/tests/core/test_device.py index 2fc127f17b..39a9b70a5d 100644 --- a/tests/core/test_device.py +++ b/tests/core/test_device.py @@ -1,4 +1,5 @@ import asyncio +import time import traceback from unittest.mock import Mock @@ -174,43 +175,50 @@ def __init__(self, name: str) -> None: super().__init__(name) +@pytest.mark.parametrize("parallel", (False, True)) +async def test_many_individual_device_connects_not_slow(parallel): + start = time.time() + bundles = [MotorBundle(f"bundle{i}") for i in range(100)] + if parallel: + for bundle in bundles: + await bundle.connect(mock=True) + else: + coros = {bundle.name: bundle.connect(mock=True) for bundle in bundles} + await wait_for_connection(**coros) + duration = time.time() - start + assert duration < 1 + + async def test_device_with_children_lazily_connects(RE): parentMotor = MotorBundle("parentMotor") for device in [parentMotor, parentMotor.X, parentMotor.Y] + list( parentMotor.V.values() ): - assert device._connect_task is None + assert device._mock is None RE(ensure_connected(parentMotor, mock=True)) for device in [parentMotor, parentMotor.X, parentMotor.Y] + list( parentMotor.V.values() ): - assert ( - device._connect_task is not None - and device._connect_task.done() - and not device._connect_task.exception() - ) + assert device._mock is not None -@pytest.mark.parametrize("use_Mock", [False, True]) -async def test_no_reconnect_signals_if_not_forced(use_Mock): +async def test_no_reconnect_signals_if_not_forced(): parent = DummyDeviceGroup("parent") - connect_mock_arg = Mock() if use_Mock else True - - async def inner_connect(mock, timeout, force_reconnect): + async def inner_connect(mock=False, timeout=None, force_reconnect=False): parent.child1.connected = True parent.child1.connect = Mock(side_effect=inner_connect) - await parent.connect(mock=connect_mock_arg, timeout=0.01) + await parent.connect(mock=False, timeout=0.01) assert parent.child1.connected assert parent.child1.connect.call_count == 1 - await parent.connect(mock=connect_mock_arg, timeout=0.01) + await parent.connect(mock=False, timeout=0.01) assert parent.child1.connected assert parent.child1.connect.call_count == 1 for count in range(2, 10): - await parent.connect(mock=connect_mock_arg, timeout=0.01, force_reconnect=True) + await parent.connect(mock=False, timeout=0.01, force_reconnect=True) assert parent.child1.connected assert parent.child1.connect.call_count == count diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index b03f3b42c5..f80e21f084 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -9,11 +9,7 @@ from bluesky.protocols import Reading from ophyd_async.core import ( - DEFAULT_TIMEOUT, DeviceCollector, - MockSignalBackend, - NotConnected, - Signal, SignalR, SignalRW, SoftSignalBackend, @@ -32,80 +28,44 @@ ) from ophyd_async.core import StandardReadableFormat as Format from ophyd_async.epics.core import epics_signal_r, epics_signal_rw -from ophyd_async.plan_stubs import ensure_connected def num_occurrences(substring: str, string: str) -> int: return len(list(re.finditer(re.escape(substring), string))) +def test_cannot_add_child_to_signal(): + signal = soft_signal_rw(str) + with pytest.raises( + AttributeError, + match="Cannot add Device or Signal child foo=<.*> of Signal, " + "make a subclass of Device instead", + ): + signal.foo = signal + + async def test_signal_connects_to_previous_backend(caplog): caplog.set_level(logging.DEBUG) - int_mock_backend = MockSignalBackend(SoftSignalBackend(int), Mock()) - original_connect = int_mock_backend.connect - times_backend_connect_called = 0 - - async def new_connect(timeout=1): - nonlocal times_backend_connect_called - times_backend_connect_called += 1 - await asyncio.sleep(0.1) - await original_connect(timeout=timeout) - - int_mock_backend.connect = new_connect - signal = Signal(int_mock_backend) - await asyncio.gather(signal.connect(), signal.connect()) + signal = soft_signal_rw(int) + mock_connect = Mock(side_effect=signal._connector.backend.connect) + signal._connector.backend.connect = mock_connect + await signal.connect() + assert mock_connect.call_count == 1 + assert num_occurrences(f"Connecting to {signal.source}", caplog.text) == 1 + await asyncio.gather(signal.connect(), signal.connect(), signal.connect()) + assert mock_connect.call_count == 1 assert num_occurrences(f"Connecting to {signal.source}", caplog.text) == 1 - assert times_backend_connect_called == 1 async def test_signal_connects_with_force_reconnect(caplog): caplog.set_level(logging.DEBUG) - signal = Signal(MockSignalBackend(SoftSignalBackend(int), Mock())) + signal = soft_signal_rw(int) await signal.connect() assert num_occurrences(f"Connecting to {signal.source}", caplog.text) == 1 await signal.connect(force_reconnect=True) assert num_occurrences(f"Connecting to {signal.source}", caplog.text) == 2 -async def test_signal_lazily_connects(RE): - class MockSignalBackendFailingFirst(MockSignalBackend): - succeed_on_connect = False - - async def connect(self, timeout=DEFAULT_TIMEOUT): - if self.succeed_on_connect: - self.succeed_on_connect = False - await super().connect(timeout=timeout) - else: - self.succeed_on_connect = True - raise RuntimeError("connect fail") - - signal = SignalRW(MockSignalBackendFailingFirst(SoftSignalBackend(int), Mock())) - - with pytest.raises(RuntimeError, match="connect fail"): - await signal.connect(mock=False) - - assert ( - signal._connect_task - and signal._connect_task.done() - and signal._connect_task.exception() - ) - - RE(ensure_connected(signal, mock=False)) - assert ( - signal._connect_task - and signal._connect_task.done() - and not signal._connect_task.exception() - ) - - with pytest.raises(NotConnected, match="RuntimeError: connect fail"): - RE(ensure_connected(signal, mock=False, force_reconnect=True)) - assert ( - signal._connect_task - and signal._connect_task.done() - and signal._connect_task.exception() - ) - - async def time_taken_by(coro) -> float: start = time.monotonic() await coro diff --git a/tests/epics/demo/test_demo.py b/tests/epics/demo/test_demo.py index 838b9f5811..e29e13bf29 100644 --- a/tests/epics/demo/test_demo.py +++ b/tests/epics/demo/test_demo.py @@ -10,6 +10,7 @@ from ophyd_async.core import ( DeviceCollector, + LazyMock, NotConnected, assert_emitted, assert_reading, @@ -198,9 +199,10 @@ async def test_retrieve_mock_and_assert(mock_mover: demo.Mover): async def test_mocks_in_device_share_parent(): - mock = Mock() - async with DeviceCollector(mock=mock): - mock_mover = demo.Mover("BLxxI-MO-TABLE-01:Y:") + lm = LazyMock() + mock_mover = demo.Mover("BLxxI-MO-TABLE-01:Y:") + await mock_mover.connect(mock=lm) + mock = lm() assert get_mock(mock_mover) is mock assert get_mock(mock_mover.setpoint) is mock.setpoint diff --git a/tests/plan_stubs/test_ensure_connected.py b/tests/plan_stubs/test_ensure_connected.py index 62cc7d034b..4665ddcee9 100644 --- a/tests/plan_stubs/test_ensure_connected.py +++ b/tests/plan_stubs/test_ensure_connected.py @@ -31,8 +31,8 @@ def connect(): device2 = MyDevice("PREFIX2", name="device2") def connect_with_mocking(): - assert device2.signal._connect_task is None + assert device2.signal._mock is None yield from ensure_connected(device2, mock=True, timeout=0.1) - assert device2.signal._connect_task.done() + assert device2.signal._mock is not None RE(connect_with_mocking())