Skip to content

Commit

Permalink
Let SharedMemory and NDArray support `with` blocks
Browse files Browse the repository at this point in the history
To do this, we now subclass Python's SharedMemory class,
instead of simply passing it along verbatim as before.

And we implement the __exit__ method to call the Appose SharedMemory
subclass's dispose() method, which calls either close() or unlink()
depending on whether the unlink_on_dispose flag is set.

Note that this implementation cannot quite align with the Java one,
because in Java, the AutoCloseable interface always calls close().
As such, the close() method must be overridden and taught to sometimes
call unlink(), and sometimes not, depending on the unlinkOnClose flag.
  • Loading branch information
ctrueden committed Aug 10, 2024
1 parent 522f0e9 commit 84d45e6
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 74 deletions.
85 changes: 64 additions & 21 deletions src/appose/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,65 @@
import json
import re
from math import ceil, prod
from multiprocessing import resource_tracker
from multiprocessing.shared_memory import SharedMemory
from multiprocessing import resource_tracker, shared_memory
from typing import Any, Dict, Sequence, Union

Args = Dict[str, Any]


class SharedMemory(shared_memory.SharedMemory):
    """
    An enhanced version of Python's multiprocessing.shared_memory.SharedMemory
    class which can be used with a `with` statement. When the program flow
    exits the `with` block, this class's `dispose()` method will be invoked,
    which closes this process's handle on the block and, depending on the
    value of its `unlink_on_dispose` flag, also calls `unlink()` to request
    destruction of the block.
    """

    def __init__(self, name: str = None, create: bool = False, size: int = 0):
        """
        Create a new shared memory block, or attach to an existing one.

        :param name: Name of the block; forwarded to the superclass.
        :param create: True to create a new block, False to attach to an
                       existing one. Also the initial value of the
                       unlink-on-dispose flag.
        :param size: Requested size in bytes; forwarded to the superclass.
        """
        super().__init__(name=name, create=create, size=size)
        self._unlink_on_dispose = create
        if _is_worker:
            # HACK: Remove this shared memory block from the resource_tracker,
            # which wants to clean up shared memory blocks after all known
            # references are done using them.
            #
            # There is one resource_tracker per Python process, and they will
            # each try to delete shared memory blocks known to them when they
            # are shutting down, even when other processes still need them.
            #
            # As such, the rule Appose follows is: let the service process
            # always handle cleanup of shared memory blocks, regardless of
            # which process initially allocated it.
            resource_tracker.unregister(self._name, "shared_memory")

    def unlink_on_dispose(self, value: bool) -> None:
        """
        Set whether the `unlink()` method should be invoked to destroy
        the shared memory block when the `dispose()` method is called.
        Note: dispose() is the method called when exiting a `with` block.
        By default, shared memory objects constructed with `create=True`
        will behave this way, whereas shared memory objects constructed
        with `create=False` will not. This method makes it possible to
        override that default.
        """
        self._unlink_on_dispose = value

    def dispose(self) -> None:
        """
        Release this object's hold on the shared memory block.

        If the unlink-on-dispose flag is set, `unlink()` is called and then
        `close()`; otherwise only `close()` is called.

        :raises BufferError: Only when the flag is not set, if `close()`
                             fails because views of the buffer are still
                             exported.
        """
        if not self._unlink_on_dispose:
            self.close()
            return

        try:
            self.unlink()
        finally:
            # unlink() alone never releases this process's own mapping and
            # file handle, so close as well -- even if unlink() raised.
            try:
                self.close()
            except BufferError:
                # Views of the buffer are still exported, so the mapping
                # cannot be released yet. Unlinking was the point of this
                # branch, so do not fail the dispose over it.
                pass

    def __enter__(self) -> "SharedMemory":
        """Return this object, so it can be bound by `with ... as`."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb) -> None:
        """Dispose of the shared memory when leaving the `with` block."""
        self.dispose()


def encode(data: Args) -> str:
    """
    Serialize the given arguments to a compact JSON string.

    The output contains no whitespace after "," or ":" separators, and
    _ApposeJSONEncoder is used as the JSON encoder class.
    """
    compact_separators = (",", ":")
    return json.dumps(data, separators=compact_separators, cls=_ApposeJSONEncoder)

Expand All @@ -63,7 +115,9 @@ def __init__(self, dtype: str, shape: Sequence[int], shm: SharedMemory = None):
self.dtype = dtype
self.shape = shape
self.shm = (
_create_shm(create=True, size=ceil(prod(shape) * _bytes_per_element(dtype)))
SharedMemory(
create=True, size=ceil(prod(shape) * _bytes_per_element(dtype))
)
if shm is None
else shm
)
Expand Down Expand Up @@ -91,6 +145,12 @@ def ndarray(self):
except ModuleNotFoundError:
raise ImportError("NumPy is not available.")

def __enter__(self) -> "NDArray":
    """Return this object, so it can be bound by `with ... as`."""
    return self

def __exit__(self, exc_type, exc_value, exc_tb) -> None:
    """
    Dispose of the backing shared memory when leaving the `with` block.
    The exception arguments are not inspected.
    """
    backing = self.shm
    backing.dispose()


class _ApposeJSONEncoder(json.JSONEncoder):
def default(self, obj):
Expand All @@ -114,7 +174,7 @@ def _appose_object_hook(obj: Dict):
atype = obj.get("appose_type")
if atype == "shm":
# Attach to existing shared memory block.
return _create_shm(name=(obj["name"]), size=(obj["size"]))
return SharedMemory(name=(obj["name"]), size=(obj["size"]))
elif atype == "ndarray":
return NDArray(obj["dtype"], obj["shape"], obj["shm"])
else:
Expand All @@ -129,23 +189,6 @@ def _bytes_per_element(dtype: str) -> Union[int, float]:
return bits / 8


def _create_shm(name: str = None, create: bool = False, size: int = 0):
    """
    Construct a SharedMemory with the given arguments and return it.

    When the module-level `_is_worker` flag is set, the block is first
    unregistered from the resource_tracker.
    """
    block = SharedMemory(name=name, create=create, size=size)
    if not _is_worker:
        return block

    # HACK: Disable this process's resource_tracker, which wants to clean up
    # shared memory blocks after all known references are done using them.
    #
    # There is one resource_tracker per Python process, and they will each
    # try to delete shared memory blocks known to them when they are
    # shutting down, even when other processes still need them.
    #
    # As such, the rule Appose follows is: let the service process always
    # do the cleanup of shared memory blocks, regardless of which process
    # initially allocated it.
    resource_tracker.unregister(block._name, "shared_memory")
    return block


_is_worker = False


Expand Down
33 changes: 15 additions & 18 deletions tests/test_shm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,20 @@
def test_ndarray():
env = appose.system()
with env.python() as service:
# Construct the data.
shm = appose.SharedMemory(create=True, size=2 * 2 * 20 * 25)
shm.buf[0] = 123
shm.buf[456] = 78
shm.buf[1999] = 210
data = appose.NDArray("uint16", [2, 20, 25], shm)
with appose.SharedMemory(create=True, size=2 * 2 * 20 * 25) as shm:
# Construct the data.
shm.buf[0] = 123
shm.buf[456] = 78
shm.buf[1999] = 210
data = appose.NDArray("uint16", [2, 20, 25], shm)

# Run the task.
task = service.task(ndarray_inspect, {"data": data})
task.wait_for()
# Run the task.
task = service.task(ndarray_inspect, {"data": data})
task.wait_for()

# Validate the execution result.
assert TaskStatus.COMPLETE == task.status
assert 2 * 20 * 25 * 2 == task.outputs["size"]
assert "uint16" == task.outputs["dtype"]
assert [2, 20, 25] == task.outputs["shape"]
assert 123 + 78 + 210 == task.outputs["sum"]

# Clean up.
shm.unlink()
# Validate the execution result.
assert TaskStatus.COMPLETE == task.status
assert 2 * 20 * 25 * 2 == task.outputs["size"]
assert "uint16" == task.outputs["dtype"]
assert [2, 20, 25] == task.outputs["shape"]
assert 123 + 78 + 210 == task.outputs["sum"]
68 changes: 33 additions & 35 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,40 +65,38 @@ def test_encode(self):
"numbers": self.NUMBERS,
"words": self.WORDS,
}
ndarray = appose.NDArray("float32", [2, 20, 25])
shm_name = ndarray.shm.name
data["ndArray"] = ndarray
json_str = appose.types.encode(data)
self.assertIsNotNone(json_str)
expected = self.JSON.replace("SHM_NAME", shm_name)
self.assertEqual(expected, json_str)
ndarray.shm.unlink()
with appose.NDArray("float32", [2, 20, 25]) as ndarray:
shm_name = ndarray.shm.name
data["ndArray"] = ndarray
json_str = appose.types.encode(data)
self.assertIsNotNone(json_str)
expected = self.JSON.replace("SHM_NAME", shm_name)
self.assertEqual(expected, json_str)

def test_decode(self):
shm = appose.SharedMemory(create=True, size=4000)
shm_name = shm.name
data = appose.types.decode(self.JSON.replace("SHM_NAME", shm_name))
self.assertIsNotNone(data)
self.assertEqual(19, len(data))
self.assertEqual(123, data["posByte"])
self.assertEqual(-98, data["negByte"])
self.assertEqual(9.876543210123456, data["posDouble"])
self.assertEqual(-1.234567890987654e302, data["negDouble"])
self.assertEqual(9.876543, data["posFloat"])
self.assertEqual(-1.2345678, data["negFloat"])
self.assertEqual(1234567890, data["posInt"])
self.assertEqual(-987654321, data["negInt"])
self.assertEqual(12345678987654321, data["posLong"])
self.assertEqual(-98765432123456789, data["negLong"])
self.assertEqual(32109, data["posShort"])
self.assertEqual(-23456, data["negShort"])
self.assertTrue(data["trueBoolean"])
self.assertFalse(data["falseBoolean"])
self.assertEqual("\0", data["nullChar"])
self.assertEqual(self.STRING, data["aString"])
self.assertEqual(self.NUMBERS, data["numbers"])
self.assertEqual(self.WORDS, data["words"])
ndArray = data["ndArray"]
self.assertEqual("float32", ndArray.dtype)
self.assertEqual([2, 20, 25], ndArray.shape)
shm.unlink()
with appose.SharedMemory(create=True, size=4000) as shm:
shm_name = shm.name
data = appose.types.decode(self.JSON.replace("SHM_NAME", shm_name))
self.assertIsNotNone(data)
self.assertEqual(19, len(data))
self.assertEqual(123, data["posByte"])
self.assertEqual(-98, data["negByte"])
self.assertEqual(9.876543210123456, data["posDouble"])
self.assertEqual(-1.234567890987654e302, data["negDouble"])
self.assertEqual(9.876543, data["posFloat"])
self.assertEqual(-1.2345678, data["negFloat"])
self.assertEqual(1234567890, data["posInt"])
self.assertEqual(-987654321, data["negInt"])
self.assertEqual(12345678987654321, data["posLong"])
self.assertEqual(-98765432123456789, data["negLong"])
self.assertEqual(32109, data["posShort"])
self.assertEqual(-23456, data["negShort"])
self.assertTrue(data["trueBoolean"])
self.assertFalse(data["falseBoolean"])
self.assertEqual("\0", data["nullChar"])
self.assertEqual(self.STRING, data["aString"])
self.assertEqual(self.NUMBERS, data["numbers"])
self.assertEqual(self.WORDS, data["words"])
ndArray = data["ndArray"]
self.assertEqual("float32", ndArray.dtype)
self.assertEqual([2, 20, 25], ndArray.shape)

0 comments on commit 84d45e6

Please sign in to comment.