-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Split up Ophyd tests to conform to new directory structure
- Loading branch information
Rose Yemelyanova
committed
Aug 31, 2023
1 parent
fc00894
commit 318b716
Showing
8 changed files
with
539 additions
and
486 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
import asyncio | ||
import time | ||
from enum import Enum | ||
from typing import Any, Callable, Sequence, Tuple, Type | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
import pytest | ||
from bluesky.protocols import Reading | ||
|
||
from ophyd_async.core.backends import SignalBackend, SimSignalBackend | ||
from ophyd_async.core.signals import Signal | ||
from ophyd_async.core.utils import T | ||
|
||
|
||
class MyEnum(str, Enum): | ||
a = "Aaa" | ||
b = "Bbb" | ||
c = "Ccc" | ||
|
||
|
||
def integer_d(value): | ||
return dict(dtype="integer", shape=[]) | ||
|
||
|
||
def number_d(value): | ||
return dict(dtype="number", shape=[]) | ||
|
||
|
||
def string_d(value): | ||
return dict(dtype="string", shape=[]) | ||
|
||
|
||
def enum_d(value): | ||
return dict(dtype="string", shape=[], choices=["Aaa", "Bbb", "Ccc"]) | ||
|
||
|
||
def waveform_d(value): | ||
return dict(dtype="array", shape=[len(value)]) | ||
|
||
|
||
class MonitorQueue: | ||
def __init__(self, backend: SignalBackend): | ||
self.backend = backend | ||
self.updates: asyncio.Queue[Tuple[Reading, Any]] = asyncio.Queue() | ||
backend.set_callback(self.add_reading_value) | ||
|
||
def add_reading_value(self, reading: Reading, value): | ||
self.updates.put_nowait((reading, value)) | ||
|
||
async def assert_updates(self, expected_value): | ||
expected_reading = { | ||
"value": expected_value, | ||
"timestamp": pytest.approx(time.monotonic(), rel=0.1), | ||
"alarm_severity": 0, | ||
} | ||
reading, value = await self.updates.get() | ||
|
||
backend_value = await self.backend.get_value() | ||
backend_reading = await self.backend.get_reading() | ||
|
||
assert value == expected_value == backend_value | ||
assert reading == expected_reading == backend_reading | ||
|
||
def close(self): | ||
self.backend.set_callback(None) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"datatype, initial_value, put_value, descriptor", | ||
[ | ||
(int, 0, 43, integer_d), | ||
(float, 0.0, 43.5, number_d), | ||
(str, "", "goodbye", string_d), | ||
(MyEnum, MyEnum.a, MyEnum.c, enum_d), | ||
(npt.NDArray[np.int8], [], [-8, 3, 44], waveform_d), | ||
(npt.NDArray[np.uint8], [], [218], waveform_d), | ||
(npt.NDArray[np.int16], [], [-855], waveform_d), | ||
(npt.NDArray[np.uint16], [], [5666], waveform_d), | ||
(npt.NDArray[np.int32], [], [-2], waveform_d), | ||
(npt.NDArray[np.uint32], [], [1022233], waveform_d), | ||
(npt.NDArray[np.int64], [], [-3], waveform_d), | ||
(npt.NDArray[np.uint64], [], [995444], waveform_d), | ||
(npt.NDArray[np.float32], [], [1.0], waveform_d), | ||
(npt.NDArray[np.float64], [], [0.2], waveform_d), | ||
(Sequence[str], [], ["nine", "ten"], waveform_d), | ||
# Can't do long strings until https://github.com/epics-base/pva2pva/issues/17 | ||
# (str, "longstr", ls1, ls2, string_d), | ||
# (str, "longstr2.VAL$", ls1, ls2, string_d), | ||
], | ||
) | ||
async def test_backend_get_put_monitor( | ||
datatype: Type[T], | ||
initial_value: T, | ||
put_value: T, | ||
descriptor: Callable[[Any], dict], | ||
): | ||
backend = SimSignalBackend(datatype, "") | ||
|
||
await backend.connect() | ||
q = MonitorQueue(backend) | ||
try: | ||
# Check descriptor | ||
assert ( | ||
dict(source="sim://", **descriptor(initial_value)) | ||
== await backend.get_descriptor() | ||
) | ||
# Check initial value | ||
await q.assert_updates( | ||
pytest.approx(initial_value) if initial_value != "" else initial_value | ||
) | ||
# Put to new value and check that | ||
await backend.put(put_value) | ||
await q.assert_updates(pytest.approx(put_value)) | ||
finally: | ||
q.close() | ||
|
||
|
||
async def test_sim_backend_if_disconnected(): | ||
sim_backend = SimSignalBackend(npt.NDArray[np.float64], "SOME-IOC:PV") | ||
with pytest.raises(NotImplementedError): | ||
await sim_backend.get_value() | ||
|
||
|
||
async def test_sim_backend_with_numpy_typing(): | ||
sim_backend = SimSignalBackend(npt.NDArray[np.float64], "SOME-IOC:PV") | ||
await sim_backend.connect() | ||
|
||
array = await sim_backend.get_value() | ||
assert array.shape == (0,) | ||
|
||
|
||
async def test_sim_backend_descriptor_fails_for_invalid_class(): | ||
class myClass: | ||
def __init__(self) -> None: | ||
pass | ||
|
||
sim_signal = Signal(SimSignalBackend(myClass, "test")) | ||
await sim_signal.connect(sim=True) | ||
|
||
with pytest.raises(AssertionError): | ||
await sim_signal._backend.get_descriptor() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import asyncio | ||
from typing import Callable, Coroutine | ||
|
||
import pytest | ||
from bluesky.run_engine import RunEngine, TransitionError | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def RE(request): | ||
loop = asyncio.new_event_loop() | ||
loop.set_debug(True) | ||
RE = RunEngine({}, call_returns_result=True, loop=loop) | ||
|
||
def clean_event_loop(): | ||
if RE.state not in ("idle", "panicked"): | ||
try: | ||
RE.halt() | ||
except TransitionError: | ||
pass | ||
loop.call_soon_threadsafe(loop.stop) | ||
RE._th.join() | ||
loop.close() | ||
|
||
request.addfinalizer(clean_event_loop) | ||
return RE | ||
|
||
|
||
@pytest.fixture | ||
async def normal_coroutine() -> Callable[[None], Coroutine]: | ||
async def inner_coroutine(): | ||
await asyncio.sleep(0.01) | ||
|
||
return inner_coroutine | ||
|
||
|
||
@pytest.fixture | ||
async def failing_coroutine() -> Callable[[None], Coroutine]: | ||
async def inner_coroutine(): | ||
await asyncio.sleep(0.01) | ||
raise ValueError() | ||
|
||
return inner_coroutine |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import asyncio | ||
import traceback | ||
|
||
import pytest | ||
|
||
from ophyd_async.core.device_collector import DeviceCollector | ||
from ophyd_async.core.devices import Device, DeviceVector, get_device_children | ||
from ophyd_async.core.utils import wait_for_connection | ||
|
||
|
||
class DummyBaseDevice(Device): | ||
def __init__(self) -> None: | ||
self.connected = False | ||
|
||
async def connect(self, sim=False): | ||
self.connected = True | ||
|
||
|
||
class DummyDeviceGroup(Device): | ||
def __init__(self, name: str) -> None: | ||
self.child1 = DummyBaseDevice() | ||
self.child2 = DummyBaseDevice() | ||
self.dict_with_children: DeviceVector[DummyBaseDevice] = DeviceVector( | ||
{123: DummyBaseDevice()} | ||
) | ||
self.set_name(name) | ||
|
||
|
||
@pytest.fixture | ||
def parent() -> DummyDeviceGroup: | ||
return DummyDeviceGroup("parent") | ||
|
||
|
||
def test_get_device_children(parent: DummyDeviceGroup): | ||
names = ["child1", "child2", "dict_with_children"] | ||
for idx, (name, child) in enumerate(get_device_children(parent)): | ||
assert name == names[idx] | ||
assert ( | ||
type(child) is DummyBaseDevice | ||
if name.startswith("child") | ||
else type(child) is DeviceVector | ||
) | ||
|
||
|
||
async def test_children_of_device_have_set_names_and_get_connected( | ||
parent: DummyDeviceGroup, | ||
): | ||
assert parent.name == "parent" | ||
assert parent.child1.name == "parent-child1" | ||
assert parent.child2.name == "parent-child2" | ||
assert parent.dict_with_children.name == "parent-dict_with_children" | ||
assert parent.dict_with_children[123].name == "parent-dict_with_children-123" | ||
|
||
await parent.connect() | ||
|
||
assert parent.child1.connected | ||
assert parent.dict_with_children[123].connected | ||
|
||
|
||
async def test_device_with_device_collector(): | ||
async with DeviceCollector(sim=True): | ||
parent = DummyDeviceGroup("parent") | ||
|
||
assert parent.name == "parent" | ||
assert parent.child1.name == "parent-child1" | ||
assert parent.child2.name == "parent-child2" | ||
assert parent.dict_with_children.name == "parent-dict_with_children" | ||
assert parent.dict_with_children[123].name == "parent-dict_with_children-123" | ||
assert parent.child1.connected | ||
assert parent.dict_with_children[123].connected | ||
|
||
|
||
async def test_wait_for_connection(): | ||
class DummyDeviceWithSleep(DummyBaseDevice): | ||
def __init__(self, name) -> None: | ||
self.set_name(name) | ||
|
||
async def connect(self, sim=False): | ||
await asyncio.sleep(0.01) | ||
self.connected = True | ||
|
||
device1, device2 = DummyDeviceWithSleep("device1"), DummyDeviceWithSleep("device2") | ||
|
||
normal_coros = {"device1": device1.connect(), "device2": device2.connect()} | ||
|
||
await wait_for_connection(**normal_coros) | ||
|
||
assert device1.connected | ||
assert device2.connected | ||
|
||
|
||
async def test_wait_for_connection_propagates_error( | ||
normal_coroutine, failing_coroutine | ||
): | ||
failing_coros = {"test": normal_coroutine(), "failing": failing_coroutine()} | ||
|
||
with pytest.raises(ValueError) as e: | ||
await wait_for_connection(**failing_coros) | ||
assert traceback.extract_tb(e.__traceback__)[-1].name == "failing_coroutine" |
Oops, something went wrong.