From 318b71629b038edd72d087403867f3fb0a3f8746 Mon Sep 17 00:00:00 2001 From: Rose Yemelyanova Date: Thu, 31 Aug 2023 16:12:20 +0100 Subject: [PATCH] Split up Ophyd tests to conform to new directory structure --- pyproject.toml | 3 +- tests/core/backends/test_sim.py | 142 +++++++++ tests/core/conftest.py | 42 +++ tests/core/devices/test_device.py | 99 ++++++ tests/core/signals/test_signal.py | 120 ++++++++ tests/core/test_async_status.py | 133 +++++++++ tests/core/test_core.py | 480 ------------------------------ tests/core/test_epicsdemo.py | 6 +- 8 files changed, 539 insertions(+), 486 deletions(-) create mode 100644 tests/core/backends/test_sim.py create mode 100644 tests/core/conftest.py create mode 100644 tests/core/devices/test_device.py create mode 100644 tests/core/signals/test_signal.py create mode 100644 tests/core/test_async_status.py delete mode 100644 tests/core/test_core.py diff --git a/pyproject.toml b/pyproject.toml index 70384e2bac..48804ca713 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,10 +115,9 @@ addopts = """ --cov=src/ophyd_async --cov-report term --cov-report xml:cov.xml """ # https://iscinumpy.gitlab.io/post/bound-version-constraints/#watch-for-warnings -filterwarnings = "error" +filterwarnings = ["error", "ignore::DeprecationWarning:pkg_resources"] # Doctest python code in docs, python code in src docstrings, test functions in tests testpaths = "docs src tests" -timeout = 20 log_format = "%(asctime)s,%(msecs)03d %(levelname)s (%(threadName)s) %(message)s" log_date_format = "%H:%M:%S" diff --git a/tests/core/backends/test_sim.py b/tests/core/backends/test_sim.py new file mode 100644 index 0000000000..2285bdcf08 --- /dev/null +++ b/tests/core/backends/test_sim.py @@ -0,0 +1,142 @@ +import asyncio +import time +from enum import Enum +from typing import Any, Callable, Sequence, Tuple, Type + +import numpy as np +import numpy.typing as npt +import pytest +from bluesky.protocols import Reading + +from ophyd_async.core.backends import SignalBackend, SimSignalBackend +from ophyd_async.core.signals import Signal +from ophyd_async.core.utils import T + + +class MyEnum(str, Enum): + a = "Aaa" + b = "Bbb" + c = "Ccc" + + +def integer_d(value): + return dict(dtype="integer", shape=[]) + + +def number_d(value): + return dict(dtype="number", shape=[]) + + +def string_d(value): + return dict(dtype="string", shape=[]) + + +def enum_d(value): + return dict(dtype="string", shape=[], choices=["Aaa", "Bbb", "Ccc"]) + + +def waveform_d(value): + return dict(dtype="array", shape=[len(value)]) + + +class MonitorQueue: + def __init__(self, backend: SignalBackend): + self.backend = backend + self.updates: asyncio.Queue[Tuple[Reading, Any]] = asyncio.Queue() + backend.set_callback(self.add_reading_value) + + def add_reading_value(self, reading: Reading, value): + self.updates.put_nowait((reading, value)) + + async def assert_updates(self, expected_value): + expected_reading = { + "value": expected_value, + "timestamp": pytest.approx(time.monotonic(), rel=0.1), + "alarm_severity": 0, + } + reading, value = await self.updates.get() + + backend_value = await self.backend.get_value() + backend_reading = await self.backend.get_reading() + + assert value == expected_value == backend_value + assert reading == expected_reading == backend_reading + + def close(self): + self.backend.set_callback(None) + + +@pytest.mark.parametrize( + "datatype, initial_value, put_value, descriptor", + [ + (int, 0, 43, integer_d), + (float, 0.0, 43.5, number_d), + (str, "", "goodbye", string_d), + (MyEnum, MyEnum.a, MyEnum.c, enum_d), + (npt.NDArray[np.int8], [], [-8, 3, 44], waveform_d), + (npt.NDArray[np.uint8], [], [218], waveform_d), + (npt.NDArray[np.int16], [], [-855], waveform_d), + (npt.NDArray[np.uint16], [], [5666], waveform_d), + (npt.NDArray[np.int32], [], [-2], waveform_d), + (npt.NDArray[np.uint32], [], [1022233], waveform_d), + (npt.NDArray[np.int64], [], [-3], waveform_d), + (npt.NDArray[np.uint64], [], [995444], waveform_d), + (npt.NDArray[np.float32], [], [1.0], waveform_d), + (npt.NDArray[np.float64], [], [0.2], waveform_d), + (Sequence[str], [], ["nine", "ten"], waveform_d), + # Can't do long strings until https://github.com/epics-base/pva2pva/issues/17 + # (str, "longstr", ls1, ls2, string_d), + # (str, "longstr2.VAL$", ls1, ls2, string_d), + ], +) +async def test_backend_get_put_monitor( + datatype: Type[T], + initial_value: T, + put_value: T, + descriptor: Callable[[Any], dict], +): + backend = SimSignalBackend(datatype, "") + + await backend.connect() + q = MonitorQueue(backend) + try: + # Check descriptor + assert ( + dict(source="sim://", **descriptor(initial_value)) + == await backend.get_descriptor() + ) + # Check initial value + await q.assert_updates( + pytest.approx(initial_value) if initial_value != "" else initial_value + ) + # Put to new value and check that + await backend.put(put_value) + await q.assert_updates(pytest.approx(put_value)) + finally: + q.close() + + +async def test_sim_backend_if_disconnected(): + sim_backend = SimSignalBackend(npt.NDArray[np.float64], "SOME-IOC:PV") + with pytest.raises(NotImplementedError): + await sim_backend.get_value() + + +async def test_sim_backend_with_numpy_typing(): + sim_backend = SimSignalBackend(npt.NDArray[np.float64], "SOME-IOC:PV") + await sim_backend.connect() + + array = await sim_backend.get_value() + assert array.shape == (0,) + + +async def test_sim_backend_descriptor_fails_for_invalid_class(): + class myClass: + def __init__(self) -> None: + pass + + sim_signal = Signal(SimSignalBackend(myClass, "test")) + await sim_signal.connect(sim=True) + + with pytest.raises(AssertionError): + await sim_signal._backend.get_descriptor() diff --git a/tests/core/conftest.py b/tests/core/conftest.py new file mode 100644 index 0000000000..9228b3ed5c --- /dev/null +++ b/tests/core/conftest.py @@ -0,0 +1,42 @@ +import asyncio +from typing import Callable, Coroutine + +import pytest +from bluesky.run_engine import RunEngine, TransitionError + + +@pytest.fixture(scope="function") +def RE(request): + loop = asyncio.new_event_loop() + loop.set_debug(True) + RE = RunEngine({}, call_returns_result=True, loop=loop) + + def clean_event_loop(): + if RE.state not in ("idle", "panicked"): + try: + RE.halt() + except TransitionError: + pass + loop.call_soon_threadsafe(loop.stop) + RE._th.join() + loop.close() + + request.addfinalizer(clean_event_loop) + return RE + + +@pytest.fixture +async def normal_coroutine() -> Callable[[None], Coroutine]: + async def inner_coroutine(): + await asyncio.sleep(0.01) + + return inner_coroutine + + +@pytest.fixture +async def failing_coroutine() -> Callable[[None], Coroutine]: + async def inner_coroutine(): + await asyncio.sleep(0.01) + raise ValueError() + + return inner_coroutine diff --git a/tests/core/devices/test_device.py b/tests/core/devices/test_device.py new file mode 100644 index 0000000000..b8c62b886c --- /dev/null +++ b/tests/core/devices/test_device.py @@ -0,0 +1,99 @@ +import asyncio +import traceback + +import pytest + +from ophyd_async.core.device_collector import DeviceCollector +from ophyd_async.core.devices import Device, DeviceVector, get_device_children +from ophyd_async.core.utils import wait_for_connection + + +class DummyBaseDevice(Device): + def __init__(self) -> None: + self.connected = False + + async def connect(self, sim=False): + self.connected = True + + +class DummyDeviceGroup(Device): + def __init__(self, name: str) -> None: + self.child1 = DummyBaseDevice() + self.child2 = DummyBaseDevice() + self.dict_with_children: DeviceVector[DummyBaseDevice] = DeviceVector( + {123: DummyBaseDevice()} + ) + self.set_name(name) + + +@pytest.fixture +def parent() -> DummyDeviceGroup: + return DummyDeviceGroup("parent") + + +def test_get_device_children(parent: DummyDeviceGroup): + names = ["child1", "child2", "dict_with_children"] + for idx, (name, child) in enumerate(get_device_children(parent)): + assert name == names[idx] + assert ( + type(child) is DummyBaseDevice + if name.startswith("child") + else type(child) is DeviceVector + ) + + +async def test_children_of_device_have_set_names_and_get_connected( + parent: DummyDeviceGroup, +): + assert parent.name == "parent" + assert parent.child1.name == "parent-child1" + assert parent.child2.name == "parent-child2" + assert parent.dict_with_children.name == "parent-dict_with_children" + assert parent.dict_with_children[123].name == "parent-dict_with_children-123" + + await parent.connect() + + assert parent.child1.connected + assert parent.dict_with_children[123].connected + + +async def test_device_with_device_collector(): + async with DeviceCollector(sim=True): + parent = DummyDeviceGroup("parent") + + assert parent.name == "parent" + assert parent.child1.name == "parent-child1" + assert parent.child2.name == "parent-child2" + assert parent.dict_with_children.name == "parent-dict_with_children" + assert parent.dict_with_children[123].name == "parent-dict_with_children-123" + assert parent.child1.connected + assert parent.dict_with_children[123].connected + + +async def test_wait_for_connection(): + class DummyDeviceWithSleep(DummyBaseDevice): + def __init__(self, name) -> None: + self.set_name(name) + + async def connect(self, sim=False): + await asyncio.sleep(0.01) + self.connected = True + + device1, device2 = DummyDeviceWithSleep("device1"), DummyDeviceWithSleep("device2") + + normal_coros = {"device1": device1.connect(), "device2": device2.connect()} + + await wait_for_connection(**normal_coros) + + assert device1.connected + assert device2.connected + + +async def test_wait_for_connection_propagates_error( + normal_coroutine, failing_coroutine +): + failing_coros = {"test": normal_coroutine(), "failing": failing_coroutine()} + + with pytest.raises(ValueError) as e: + await wait_for_connection(**failing_coros) + assert traceback.extract_tb(e.__traceback__)[-1].name == "failing_coroutine" diff --git a/tests/core/signals/test_signal.py b/tests/core/signals/test_signal.py new file mode 100644 index 0000000000..da59937546 --- /dev/null +++ b/tests/core/signals/test_signal.py @@ -0,0 +1,120 @@ +import asyncio +import re +import time + +import pytest + +from ophyd_async.core.backends import SimSignalBackend +from ophyd_async.core.signals import ( + Signal, + SignalRW, + set_and_wait_for_value, + set_sim_put_proceeds, + set_sim_value, + wait_for_value, +) + + +class MySignal(Signal): + @property + def source(self) -> str: + return "me" + + async def connect(self, sim=False): + pass + + +def test_signals_equality_raises(): + sim_backend = SimSignalBackend(str, "test") + + s1 = MySignal(sim_backend) + s2 = MySignal(sim_backend) + with pytest.raises( + TypeError, + match=re.escape( + "Can't compare two Signals, did you mean await signal.get_value() instead?" + ), + ): + s1 == s2 + with pytest.raises( + TypeError, + match=re.escape("'>' not supported between instances of 'MySignal' and 'int'"), + ): + s1 > 4 + + +async def test_set_sim_put_proceeds(): + sim_signal = Signal(SimSignalBackend(str, "test")) + await sim_signal.connect(sim=True) + + assert sim_signal._backend.put_proceeds.is_set() is True + + set_sim_put_proceeds(sim_signal, False) + assert sim_signal._backend.put_proceeds.is_set() is False + set_sim_put_proceeds(sim_signal, True) + assert sim_signal._backend.put_proceeds.is_set() is True + + +async def time_taken_by(coro) -> float: + start = time.monotonic() + await coro + return time.monotonic() - start + + +async def test_wait_for_value_with_value(): + sim_signal = SignalRW(SimSignalBackend(str, "test")) + sim_signal.set_name("sim_signal") + await sim_signal.connect(sim=True) + set_sim_value(sim_signal, "blah") + + with pytest.raises( + TimeoutError, + match="sim_signal didn't match 'something' in 0.1s, last value 'blah'", + ): + await wait_for_value(sim_signal, "something", timeout=0.1) + assert await time_taken_by(wait_for_value(sim_signal, "blah", timeout=2)) < 0.1 + t = asyncio.create_task( + time_taken_by(wait_for_value(sim_signal, "something else", timeout=2)) + ) + await asyncio.sleep(0.2) + assert not t.done() + set_sim_value(sim_signal, "something else") + assert 0.2 < await t < 1.0 + + +async def test_wait_for_value_with_funcion(): + sim_signal = SignalRW(SimSignalBackend(float, "test")) + sim_signal.set_name("sim_signal") + await sim_signal.connect(sim=True) + set_sim_value(sim_signal, 45.8) + + def less_than_42(v): + return v < 42 + + with pytest.raises( + TimeoutError, + match="sim_signal didn't match less_than_42 in 0.1s, last value 45.8", + ): + await wait_for_value(sim_signal, less_than_42, timeout=0.1) + t = asyncio.create_task( + time_taken_by(wait_for_value(sim_signal, less_than_42, timeout=2)) + ) + await asyncio.sleep(0.2) + assert not t.done() + set_sim_value(sim_signal, 41) + assert 0.2 < await t < 1.0 + assert ( + await time_taken_by(wait_for_value(sim_signal, less_than_42, timeout=2)) < 0.1 + ) + + +async def test_set_and_wait_for_value(): + sim_signal = SignalRW(SimSignalBackend(int, "test")) + sim_signal.set_name("sim_signal") + await sim_signal.connect(sim=True) + set_sim_value(sim_signal, 0) + set_sim_put_proceeds(sim_signal, False) + st = await set_and_wait_for_value(sim_signal, 1) + assert not st.done + set_sim_put_proceeds(sim_signal, True) + assert await time_taken_by(st) < 0.1 diff --git a/tests/core/test_async_status.py b/tests/core/test_async_status.py new file mode 100644 index 0000000000..cfd3c15707 --- /dev/null +++ b/tests/core/test_async_status.py @@ -0,0 +1,133 @@ +import asyncio +import traceback +from unittest.mock import Mock + +import bluesky.plan_stubs as bps +import pytest +from bluesky.protocols import Movable, Status +from bluesky.utils import FailedStatus + +from ophyd_async.core.async_status import AsyncStatus +from ophyd_async.core.devices import Device + + +async def test_async_status_success(): + st = AsyncStatus(asyncio.sleep(0.1)) + assert isinstance(st, Status) + assert not st.done + assert not st.success + await st + assert st.done + assert st.success + + +async def test_async_status_propagates_exception(failing_coroutine): + status = AsyncStatus(failing_coroutine()) + assert status.exception() is None + + with pytest.raises(ValueError): + await status + + assert type(status.exception()) == ValueError + + +async def test_async_status_propagates_cancelled_error(normal_coroutine): + status = AsyncStatus(normal_coroutine()) + assert status.exception() is None + + status.task.exception = Mock(side_effect=asyncio.CancelledError("")) + await status + + assert type(status.exception()) == asyncio.CancelledError + + +async def test_async_status_has_no_exception_if_coroutine_successful(normal_coroutine): + status = AsyncStatus(normal_coroutine()) + assert status.exception() is None + + await status + + assert status.exception() is None + + +async def test_async_status_success_if_cancelled(normal_coroutine): + status = AsyncStatus(normal_coroutine()) + assert status.exception() is None + status.task.cancel() + with pytest.raises(asyncio.CancelledError): + await status + assert status.success is False + assert isinstance(status.exception(), asyncio.CancelledError) + + +async def coroutine_to_wrap(time: float): + await asyncio.sleep(time) + + +async def test_async_status_wrap(): + wrapped_coroutine = AsyncStatus.wrap(coroutine_to_wrap) + status: AsyncStatus = wrapped_coroutine(0.01) + + await status + assert status.success is True + + +async def test_async_status_initialised_with_a_task(normal_coroutine): + normal_task = asyncio.Task(normal_coroutine()) + status = AsyncStatus(normal_task) + + await status + assert status.success is True + + +async def test_async_status_str_for_normal_coroutine(normal_coroutine): + normal_task = asyncio.Task(normal_coroutine()) + status = AsyncStatus(normal_task) + + assert str(status) == "" + await status + + assert str(status) == "" + + +async def test_async_status_str_for_failing_coroutine(failing_coroutine): + failing_task = asyncio.Task(failing_coroutine()) + status = AsyncStatus(failing_task) + + assert str(status) == "" + with pytest.raises(ValueError): + await status + + assert str(status) == "" + + +class FailingMovable(Movable, Device): + def _fail(self): + raise ValueError("This doesn't work") + + async def _set(self, value): + if value: + self._fail() + + def set(self, value) -> AsyncStatus: + return AsyncStatus(self._set(value)) + + +async def test_status_propogates_traceback_under_RE(RE) -> None: + expected_call_stack = ["_set", "_fail"] + d = FailingMovable() + with pytest.raises(FailedStatus) as ctx: + RE(bps.mv(d, 3)) + # We get "The above exception was the direct cause of the following exception:", + # so extract that first exception traceback and check + assert ctx.value.__cause__ + assert expected_call_stack == [ + x.name for x in traceback.extract_tb(ctx.value.__cause__.__traceback__) + ] + # Check we get the same from the status.exception + status: AsyncStatus = ctx.value.args[0] + exception = status.exception() + assert exception + assert expected_call_stack == [ + x.name for x in traceback.extract_tb(exception.__traceback__) + ] diff --git a/tests/core/test_core.py b/tests/core/test_core.py deleted file mode 100644 index d202ea32f6..0000000000 --- a/tests/core/test_core.py +++ /dev/null @@ -1,480 +0,0 @@ -import asyncio -import re -import time -import traceback -from enum import Enum -from typing import Any, Callable, Sequence, Tuple, Type -from unittest.mock import Mock - -import bluesky.plan_stubs as bps -import numpy as np -import numpy.typing as npt -import pytest -from bluesky import FailedStatus, RunEngine -from bluesky.protocols import Movable, Reading, Status - -from ophyd_async.core import ( - AsyncStatus, - Device, - DeviceCollector, - DeviceVector, - Signal, - SignalBackend, - SignalRW, - SimSignalBackend, - T, - get_device_children, - set_and_wait_for_value, - set_sim_put_proceeds, - set_sim_value, - wait_for_connection, - wait_for_value, -) - - -class MySignal(Signal): - @property - def source(self) -> str: - return "me" - - async def connect(self, sim=False): - pass - - -def test_signals_equality_raises(): - sim_backend = SimSignalBackend(str, "test") - - s1 = MySignal(sim_backend) - s2 = MySignal(sim_backend) - with pytest.raises( - TypeError, - match=re.escape( - "Can't compare two Signals, did you mean await signal.get_value() instead?" - ), - ): - s1 == s2 - with pytest.raises( - TypeError, - match=re.escape("'>' not supported between instances of 'MySignal' and 'int'"), - ): - s1 > 4 - - -class MyEnum(str, Enum): - a = "Aaa" - b = "Bbb" - c = "Ccc" - - -def integer_d(value): - return dict(dtype="integer", shape=[]) - - -def number_d(value): - return dict(dtype="number", shape=[]) - - -def string_d(value): - return dict(dtype="string", shape=[]) - - -def enum_d(value): - return dict(dtype="string", shape=[], choices=["Aaa", "Bbb", "Ccc"]) - - -def waveform_d(value): - return dict(dtype="array", shape=[len(value)]) - - -class MonitorQueue: - def __init__(self, backend: SignalBackend): - self.backend = backend - self.updates: asyncio.Queue[Tuple[Reading, Any]] = asyncio.Queue() - backend.set_callback(self.add_reading_value) - - def add_reading_value(self, reading: Reading, value): - self.updates.put_nowait((reading, value)) - - async def assert_updates(self, expected_value): - expected_reading = { - "value": expected_value, - "timestamp": pytest.approx(time.monotonic(), rel=0.1), - "alarm_severity": 0, - } - reading, value = await self.updates.get() - - backend_value = await self.backend.get_value() - backend_reading = await self.backend.get_reading() - - assert value == expected_value == backend_value - assert reading == expected_reading == backend_reading - - def close(self): - self.backend.set_callback(None) - - -@pytest.mark.parametrize( - "datatype, initial_value, put_value, descriptor", - [ - (int, 0, 43, integer_d), - (float, 0.0, 43.5, number_d), - (str, "", "goodbye", string_d), - (MyEnum, MyEnum.a, MyEnum.c, enum_d), - (npt.NDArray[np.int8], [], [-8, 3, 44], waveform_d), - (npt.NDArray[np.uint8], [], [218], waveform_d), - (npt.NDArray[np.int16], [], [-855], waveform_d), - (npt.NDArray[np.uint16], [], [5666], waveform_d), - (npt.NDArray[np.int32], [], [-2], waveform_d), - (npt.NDArray[np.uint32], [], [1022233], waveform_d), - (npt.NDArray[np.int64], [], [-3], waveform_d), - (npt.NDArray[np.uint64], [], [995444], waveform_d), - (npt.NDArray[np.float32], [], [1.0], waveform_d), - (npt.NDArray[np.float64], [], [0.2], waveform_d), - (Sequence[str], [], ["nine", "ten"], waveform_d), - # Can't do long strings until https://github.com/epics-base/pva2pva/issues/17 - # (str, "longstr", ls1, ls2, string_d), - # (str, "longstr2.VAL$", ls1, ls2, string_d), - ], -) -async def test_backend_get_put_monitor( - datatype: Type[T], - initial_value: T, - put_value: T, - descriptor: Callable[[Any], dict], -): - backend = SimSignalBackend(datatype, "") - - await backend.connect() - q = MonitorQueue(backend) - try: - # Check descriptor - assert ( - dict(source="sim://", **descriptor(initial_value)) - == await backend.get_descriptor() - ) - # Check initial value - await q.assert_updates( - pytest.approx(initial_value) if initial_value != "" else initial_value - ) - # Put to new value and check that - await backend.put(put_value) - await q.assert_updates(pytest.approx(put_value)) - finally: - q.close() - - -async def test_sim_backend_if_disconnected(): - sim_backend = SimSignalBackend(npt.NDArray[np.float64], "SOME-IOC:PV") - with pytest.raises(NotImplementedError): - await sim_backend.get_value() - - -async def test_sim_backend_with_numpy_typing(): - sim_backend = SimSignalBackend(npt.NDArray[np.float64], "SOME-IOC:PV") - await sim_backend.connect() - - array = await sim_backend.get_value() - assert array.shape == (0,) - - -async def test_async_status_success(): - st = AsyncStatus(asyncio.sleep(0.1)) - assert isinstance(st, Status) - assert not st.done - assert not st.success - await st - assert st.done - assert st.success - - -class DummyBaseDevice(Device): - def __init__(self) -> None: - self.connected = False - - async def connect(self, sim=False): - self.connected = True - - -class DummyDeviceGroup(Device): - def __init__(self, name: str) -> None: - self.child1 = DummyBaseDevice() - self.child2 = DummyBaseDevice() - self.dict_with_children: DeviceVector[DummyBaseDevice] = DeviceVector( - {123: DummyBaseDevice()} - ) - self.set_name(name) - - -def test_get_device_children(): - parent = DummyDeviceGroup("parent") - - names = ["child1", "child2", "dict_with_children"] - for idx, (name, child) in enumerate(get_device_children(parent)): - assert name == names[idx] - assert ( - type(child) is DummyBaseDevice - if name.startswith("child") - else type(child) is DeviceVector - ) - - -async def test_children_of_device_have_set_names_and_get_connected(): - parent = DummyDeviceGroup("parent") - - assert parent.name == "parent" - assert parent.child1.name == "parent-child1" - assert parent.child2.name == "parent-child2" - assert parent.dict_with_children.name == "parent-dict_with_children" - assert parent.dict_with_children[123].name == "parent-dict_with_children-123" - - await parent.connect() - - assert parent.child1.connected - assert parent.dict_with_children[123].connected - - -async def test_device_with_device_collector(): - async with DeviceCollector(sim=True): - parent = DummyDeviceGroup("parent") - - assert parent.name == "parent" - assert parent.child1.name == "parent-child1" - assert parent.child2.name == "parent-child2" - assert parent.dict_with_children.name == "parent-dict_with_children" - assert parent.dict_with_children[123].name == "parent-dict_with_children-123" - assert parent.child1.connected - assert parent.dict_with_children[123].connected - - -async def normal_coroutine(time: float): - await asyncio.sleep(time) - - -async def failing_coroutine(time: float): - await normal_coroutine(time) - raise ValueError() - - -async def test_async_status_propagates_exception(): - status = AsyncStatus(failing_coroutine(0.1)) - assert status.exception() is None - - with pytest.raises(ValueError): - await status - - assert type(status.exception()) == ValueError - - -async def test_async_status_propagates_cancelled_error(): - status = AsyncStatus(normal_coroutine(0.1)) - assert status.exception() is None - - status.task.exception = Mock(side_effect=asyncio.CancelledError("")) - await status - - assert type(status.exception()) == asyncio.CancelledError - - -async def test_async_status_has_no_exception_if_coroutine_successful(): - status = AsyncStatus(normal_coroutine(0.1)) - assert status.exception() is None - - await status - - assert status.exception() is None - - -async def test_async_status_success_if_cancelled(): - status = AsyncStatus(normal_coroutine(0.1)) - assert status.exception() is None - status.task.cancel() - with pytest.raises(asyncio.CancelledError): - await status - assert status.success is False - assert isinstance(status.exception(), asyncio.CancelledError) - - -async def test_async_status_wrap(): - wrapped_coroutine = AsyncStatus.wrap(normal_coroutine) - status = wrapped_coroutine(0.1) - - await status - assert status.success is True - - -async def test_async_status_initialised_with_a_task(): - normal_task = asyncio.Task(normal_coroutine(0.1)) - status = AsyncStatus(normal_task) - - await status - assert status.success is True - - -async def test_async_status_str_for_normal_coroutine(): - normal_task = asyncio.Task(normal_coroutine(0.01)) - status = AsyncStatus(normal_task) - - assert str(status) == "" - await status - - assert str(status) == "" - - -async def test_async_status_str_for_failing_coroutine(): - failing_task = asyncio.Task(failing_coroutine(0.01)) - status = AsyncStatus(failing_task) - - assert str(status) == "" - with pytest.raises(ValueError): - await status - - assert str(status) == "" - - -async def test_wait_for_connection(): - class DummyDeviceWithSleep(DummyBaseDevice): - def __init__(self, name) -> None: - self.set_name(name) - - async def connect(self, sim=False): - await asyncio.sleep(0.01) - self.connected = True - - device1, device2 = DummyDeviceWithSleep("device1"), DummyDeviceWithSleep("device2") - - normal_coros = {"device1": device1.connect(), "device2": device2.connect()} - - await wait_for_connection(**normal_coros) - - assert device1.connected - assert device2.connected - - -async def test_wait_for_connection_propagates_error(): - failing_coros = {"test": normal_coroutine(0.01), "failing": failing_coroutine(0.01)} - - with pytest.raises(ValueError) as e: - await wait_for_connection(**failing_coros) - assert traceback.extract_tb(e.__traceback__)[-1].name == "failing_coroutine" - - -class FailingMovable(Movable, Device): - def _fail(self): - raise ValueError("This doesn't work") - - async def _set(self, value): - if value: - self._fail() - - def set(self, value) -> AsyncStatus: - return AsyncStatus(self._set(value)) - - -async def test_status_propogates_traceback_under_RE() -> None: - expected_call_stack = ["_set", "_fail"] - RE = RunEngine() - d = FailingMovable() - with pytest.raises(FailedStatus) as ctx: - RE(bps.mv(d, 3)) - # We get "The above exception was the direct cause of the following exception:", - # so extract that first exception traceback and check - assert ctx.value.__cause__ - assert expected_call_stack == [ - x.name for x in traceback.extract_tb(ctx.value.__cause__.__traceback__) - ] - # Check we get the same from the status.exception - status: AsyncStatus = ctx.value.args[0] - exception = status.exception() - assert exception - assert expected_call_stack == [ - x.name for x in traceback.extract_tb(exception.__traceback__) - ] - - -async def test_set_sim_put_proceeds(): - sim_signal = Signal(SimSignalBackend(str, "test")) - await sim_signal.connect(sim=True) - - assert sim_signal._backend.put_proceeds.is_set() is True - - set_sim_put_proceeds(sim_signal, False) - assert sim_signal._backend.put_proceeds.is_set() is False - set_sim_put_proceeds(sim_signal, True) - assert sim_signal._backend.put_proceeds.is_set() is True - - -async def test_sim_backend_descriptor_fails_for_invalid_class(): - class myClass: - def __init__(self) -> None: - pass - - sim_signal = Signal(SimSignalBackend(myClass, "test")) - await sim_signal.connect(sim=True) - - with pytest.raises(AssertionError): - await sim_signal._backend.get_descriptor() - - -async def time_taken_by(coro) -> float: - start = time.monotonic() - await coro - return time.monotonic() - start - - -async def test_wait_for_value_with_value(): - sim_signal = SignalRW(SimSignalBackend(str, "test")) - sim_signal.set_name("sim_signal") - await sim_signal.connect(sim=True) - set_sim_value(sim_signal, "blah") - - with pytest.raises( - TimeoutError, - match="sim_signal didn't match 'something' in 0.1s, last value 'blah'", - ): - await wait_for_value(sim_signal, "something", timeout=0.1) - assert await time_taken_by(wait_for_value(sim_signal, "blah", timeout=2)) < 0.1 - t = asyncio.create_task( - time_taken_by(wait_for_value(sim_signal, "something else", timeout=2)) - ) - await asyncio.sleep(0.2) - assert not t.done() - set_sim_value(sim_signal, "something else") - assert 0.2 < await t < 1.0 - - -async def test_wait_for_value_with_funcion(): - sim_signal = SignalRW(SimSignalBackend(float, "test")) - sim_signal.set_name("sim_signal") - await sim_signal.connect(sim=True) - set_sim_value(sim_signal, 45.8) - - def less_than_42(v): - return v < 42 - - with pytest.raises( - TimeoutError, - match="sim_signal didn't match less_than_42 in 0.1s, last value 45.8", - ): - await wait_for_value(sim_signal, less_than_42, timeout=0.1) - t = asyncio.create_task( - time_taken_by(wait_for_value(sim_signal, less_than_42, timeout=2)) - ) - await asyncio.sleep(0.2) - assert not t.done() - set_sim_value(sim_signal, 41) - assert 0.2 < await t < 1.0 - assert ( - await time_taken_by(wait_for_value(sim_signal, less_than_42, timeout=2)) < 0.1 - ) - - -async def test_set_and_wait_for_value(): - sim_signal = SignalRW(SimSignalBackend(int, "test")) - sim_signal.set_name("sim_signal") - await sim_signal.connect(sim=True) - set_sim_value(sim_signal, 0) - set_sim_put_proceeds(sim_signal, False) - st = await set_and_wait_for_value(sim_signal, 1) - assert not st.done - set_sim_put_proceeds(sim_signal, True) - assert await time_taken_by(st) < 0.1 diff --git a/tests/core/test_epicsdemo.py b/tests/core/test_epicsdemo.py index 0c6c875ea3..b6d6504ffe 100644 --- a/tests/core/test_epicsdemo.py +++ b/tests/core/test_epicsdemo.py @@ -4,12 +4,11 @@ import pytest from bluesky.protocols import Reading -from bluesky.run_engine import RunEngine -from ophyd_async.core import epicsdemo from ophyd_async.core import ( DeviceCollector, NotConnected, + epicsdemo, set_sim_callback, set_sim_value, ) @@ -183,8 +182,7 @@ async def test_assembly_renaming() -> None: assert thing.x.stop_.name == "foo-x-stop" -def test_mover_in_re(sim_mover: epicsdemo.Mover) -> None: - RE = RunEngine() +def test_mover_in_re(sim_mover: epicsdemo.Mover, RE) -> None: sim_mover.move(0) def my_plan():