From 84d45e6c5a89d69c1dc64c469b93370c1bd33f3e Mon Sep 17 00:00:00 2001 From: Curtis Rueden Date: Sat, 10 Aug 2024 16:22:41 -0500 Subject: [PATCH] Let SharedMemory and NDArray support `with` blocks To do this, we subclass Python's SharedMemory class, rather than simply passing it along verbatim anymore. And we implement the __exit__ method to call the Appose SharedMemory subclass's dispose() method, which calls either close() or unlink() depending on whether the unlink_on_dispose flag is set. Note that this implementation cannot quite align with the Java one, because in Java, the AutoCloseable interface always calls close(). As such, the close() method must be overridden and taught to sometimes call unlink(), and sometimes not, depending on the unlinkOnClose flag. --- src/appose/types.py | 85 ++++++++++++++++++++++++++++++++++----------- tests/test_shm.py | 33 ++++++++---------- tests/test_types.py | 68 ++++++++++++++++++------------------ 3 files changed, 112 insertions(+), 74 deletions(-) diff --git a/src/appose/types.py b/src/appose/types.py index dbd877e..363ea7e 100644 --- a/src/appose/types.py +++ b/src/appose/types.py @@ -30,13 +30,65 @@ import json import re from math import ceil, prod -from multiprocessing import resource_tracker -from multiprocessing.shared_memory import SharedMemory +from multiprocessing import resource_tracker, shared_memory from typing import Any, Dict, Sequence, Union Args = Dict[str, Any] +class SharedMemory(shared_memory.SharedMemory): + """ + An enhanced version of Python's multiprocessing.shared_memory.SharedMemory + class which can be used with a `with` statement. When the program flow + exits the `with` block, this class's `dispose()` method will be invoked, + which might call `close()` or `unlink()` depending on the value of its + `unlink_on_dispose` flag. + """ + + def __init__(self, name: str = None, create: bool = False, size: int = 0): + super().__init__(name=name, create=create, size=size) + self._unlink_on_dispose = create + if _is_worker: + # HACK: Remove this shared memory block from the resource_tracker, + # which wants to clean up shared memory blocks after all known + # references are done using them. + # + # There is one resource_tracker per Python process, and they will + # each try to delete shared memory blocks known to them when they + # are shutting down, even when other processes still need them. + # + # As such, the rule Appose follows is: let the service process + # always handle cleanup of shared memory blocks, regardless of + # which process initially allocated it. + resource_tracker.unregister(self._name, "shared_memory") + + def unlink_on_dispose(self, value: bool) -> None: + """ + Set whether the `unlink()` method should be invoked to destroy + the shared memory block when the `dispose()` method is called. + + Note: dispose() is the method called when exiting a `with` block. + + By default, shared memory objects constructed with `create=True` + will behave this way, whereas shared memory objects constructed + with `create=False` will not. But this method allows to override + the behavior. + """ + self._unlink_on_dispose = value + + def dispose(self) -> None: + if self._unlink_on_dispose: + self.unlink() + else: + self.close() + + def __enter__(self) -> "SharedMemory": + return self + + def __exit__(self, exc_type, exc_value, exc_tb) -> None: + self.dispose() + + def encode(data: Args) -> str: return json.dumps(data, cls=_ApposeJSONEncoder, separators=(",", ":")) @@ -63,7 +115,9 @@ def __init__(self, dtype: str, shape: Sequence[int], shm: SharedMemory = None): self.dtype = dtype self.shape = shape self.shm = ( - _create_shm(create=True, size=ceil(prod(shape) * _bytes_per_element(dtype))) + SharedMemory( + create=True, size=ceil(prod(shape) * _bytes_per_element(dtype)) + ) if shm is None else shm ) @@ -91,6 +145,12 @@ def ndarray(self): except ModuleNotFoundError: raise ImportError("NumPy is not available.") + def __enter__(self) -> "NDArray": + return self + + def __exit__(self, exc_type, exc_value, exc_tb) -> None: + self.shm.dispose() + class _ApposeJSONEncoder(json.JSONEncoder): def default(self, obj): @@ -114,7 +174,7 @@ def _appose_object_hook(obj: Dict): atype = obj.get("appose_type") if atype == "shm": # Attach to existing shared memory block. - return _create_shm(name=(obj["name"]), size=(obj["size"])) + return SharedMemory(name=(obj["name"]), size=(obj["size"])) elif atype == "ndarray": return NDArray(obj["dtype"], obj["shape"], obj["shm"]) else: @@ -129,23 +189,6 @@ def _bytes_per_element(dtype: str) -> Union[int, float]: return bits / 8 -def _create_shm(name: str = None, create: bool = False, size: int = 0): - shm = SharedMemory(name=name, create=create, size=size) - if _is_worker: - # HACK: Disable this process's resource_tracker, which wants to clean up - # shared memory blocks after all known references are done using them. - # - # There is one resource_tracker per Python process, and they will each - # try to delete shared memory blocks known to them when they are - # shutting down, even when other processes still need them. - # - # As such, the rule Appose follows is: let the service process always - # do the cleanup of shared memory blocks, regardless of which process - # initially allocated it. - resource_tracker.unregister(shm._name, "shared_memory") - return shm - - _is_worker = False diff --git a/tests/test_shm.py b/tests/test_shm.py index 965e398..68dcf8b 100644 --- a/tests/test_shm.py +++ b/tests/test_shm.py @@ -41,23 +41,20 @@ def test_ndarray(): env = appose.system() with env.python() as service: - # Construct the data. - shm = appose.SharedMemory(create=True, size=2 * 2 * 20 * 25) - shm.buf[0] = 123 - shm.buf[456] = 78 - shm.buf[1999] = 210 - data = appose.NDArray("uint16", [2, 20, 25], shm) + with appose.SharedMemory(create=True, size=2 * 2 * 20 * 25) as shm: + # Construct the data. + shm.buf[0] = 123 + shm.buf[456] = 78 + shm.buf[1999] = 210 + data = appose.NDArray("uint16", [2, 20, 25], shm) - # Run the task. - task = service.task(ndarray_inspect, {"data": data}) - task.wait_for() + # Run the task. + task = service.task(ndarray_inspect, {"data": data}) + task.wait_for() - # Validate the execution result. - assert TaskStatus.COMPLETE == task.status - assert 2 * 20 * 25 * 2 == task.outputs["size"] - assert "uint16" == task.outputs["dtype"] - assert [2, 20, 25] == task.outputs["shape"] - assert 123 + 78 + 210 == task.outputs["sum"] - - # Clean up. - shm.unlink() + # Validate the execution result. + assert TaskStatus.COMPLETE == task.status + assert 2 * 20 * 25 * 2 == task.outputs["size"] + assert "uint16" == task.outputs["dtype"] + assert [2, 20, 25] == task.outputs["shape"] + assert 123 + 78 + 210 == task.outputs["sum"] diff --git a/tests/test_types.py b/tests/test_types.py index 445c749..dc38d23 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -65,40 +65,38 @@ def test_encode(self): "numbers": self.NUMBERS, "words": self.WORDS, } - ndarray = appose.NDArray("float32", [2, 20, 25]) - shm_name = ndarray.shm.name - data["ndArray"] = ndarray - json_str = appose.types.encode(data) - self.assertIsNotNone(json_str) - expected = self.JSON.replace("SHM_NAME", shm_name) - self.assertEqual(expected, json_str) - ndarray.shm.unlink() + with appose.NDArray("float32", [2, 20, 25]) as ndarray: + shm_name = ndarray.shm.name + data["ndArray"] = ndarray + json_str = appose.types.encode(data) + self.assertIsNotNone(json_str) + expected = self.JSON.replace("SHM_NAME", shm_name) + self.assertEqual(expected, json_str) def test_decode(self): - shm = appose.SharedMemory(create=True, size=4000) - shm_name = shm.name - data = appose.types.decode(self.JSON.replace("SHM_NAME", shm_name)) - self.assertIsNotNone(data) - self.assertEqual(19, len(data)) - self.assertEqual(123, data["posByte"]) - self.assertEqual(-98, data["negByte"]) - self.assertEqual(9.876543210123456, data["posDouble"]) - self.assertEqual(-1.234567890987654e302, data["negDouble"]) - self.assertEqual(9.876543, data["posFloat"]) - self.assertEqual(-1.2345678, data["negFloat"]) - self.assertEqual(1234567890, data["posInt"]) - self.assertEqual(-987654321, data["negInt"]) - self.assertEqual(12345678987654321, data["posLong"]) - self.assertEqual(-98765432123456789, data["negLong"]) - self.assertEqual(32109, data["posShort"]) - self.assertEqual(-23456, data["negShort"]) - self.assertTrue(data["trueBoolean"]) - self.assertFalse(data["falseBoolean"]) - self.assertEqual("\0", data["nullChar"]) - self.assertEqual(self.STRING, data["aString"]) - self.assertEqual(self.NUMBERS, data["numbers"]) - self.assertEqual(self.WORDS, data["words"]) - ndArray = data["ndArray"] - self.assertEqual("float32", ndArray.dtype) - self.assertEqual([2, 20, 25], ndArray.shape) - shm.unlink() + with appose.SharedMemory(create=True, size=4000) as shm: + shm_name = shm.name + data = appose.types.decode(self.JSON.replace("SHM_NAME", shm_name)) + self.assertIsNotNone(data) + self.assertEqual(19, len(data)) + self.assertEqual(123, data["posByte"]) + self.assertEqual(-98, data["negByte"]) + self.assertEqual(9.876543210123456, data["posDouble"]) + self.assertEqual(-1.234567890987654e302, data["negDouble"]) + self.assertEqual(9.876543, data["posFloat"]) + self.assertEqual(-1.2345678, data["negFloat"]) + self.assertEqual(1234567890, data["posInt"]) + self.assertEqual(-987654321, data["negInt"]) + self.assertEqual(12345678987654321, data["posLong"]) + self.assertEqual(-98765432123456789, data["negLong"]) + self.assertEqual(32109, data["posShort"]) + self.assertEqual(-23456, data["negShort"]) + self.assertTrue(data["trueBoolean"]) + self.assertFalse(data["falseBoolean"]) + self.assertEqual("\0", data["nullChar"]) + self.assertEqual(self.STRING, data["aString"]) + self.assertEqual(self.NUMBERS, data["numbers"]) + self.assertEqual(self.WORDS, data["words"]) + ndArray = data["ndArray"] + self.assertEqual("float32", ndArray.dtype) + self.assertEqual([2, 20, 25], ndArray.shape)