diff --git a/src/stdio_mgr/stdio_mgr.py b/src/stdio_mgr/stdio_mgr.py index 844fd94..40af42a 100644 --- a/src/stdio_mgr/stdio_mgr.py +++ b/src/stdio_mgr/stdio_mgr.py @@ -27,12 +27,57 @@ """ import sys -from contextlib import contextmanager -from io import BufferedReader, BytesIO, StringIO, TextIOBase, TextIOWrapper +from contextlib import contextmanager, ExitStack, suppress +from io import BufferedRandom, BufferedReader, BytesIO, TextIOBase, TextIOWrapper import attr +class _PersistedBytesIO(BytesIO): + """Class to persist the stream after close. + + The persisted stream is available at _closed_buf. + """ + + def close(self): + self._closed_buf = self.getvalue() + super().close() + + +class RandomTextIO(TextIOWrapper): + """Class to capture writes to a buffer even when detached. + + Subclass of :cls:`~io.TextIOWrapper` that utilises an internal + buffer defaulting to utf-8 encoding. + + All writes are flushed to the buffer. + + This class provides :meth:`~RandomTextIO.getvalue` which emulates the + behavior of :meth:`~io.StringIO.getvalue`, decoding the buffer + using the :attr:`~io.TextIOWrapper.encoding`. The value is available + even if the stream is detached or closed. + """ + + def __init__(self): + """Initialise buffer with utf-8 encoding.""" + self._stream = _PersistedBytesIO() + self._encoding = "utf-8" + self._buf = BufferedRandom(self._stream) + super().__init__(self._buf, encoding=self._encoding) + + def write(self, *args, **kwargs): + """Flush after each write.""" + super().write(*args, **kwargs) + self.flush() + + def getvalue(self): + """Obtain buffer of text sent to the stream.""" + if self._stream.closed: + return self._stream._closed_buf.decode(self.encoding) + else: + return self._stream.getvalue().decode(self.encoding) + + @attr.s(slots=False) class TeeStdin(TextIOWrapper): """Class to tee contents to a side buffer on read. @@ -152,13 +197,44 @@ def getvalue(self): return self.buffer.peek().decode(self.encoding) +class _SafeCloseIOBase(TextIOBase): + """Class to ignore ValueError when exiting the context. + + Subclass of :cls:`~io.TextIOBase` that disregards ValueError which can + occur if the file has already been closed. + """ + + def __exit__(self, exc_type, exc_value, traceback): + """Suppress ValueError while exiting context. + + ValueError may occur when the underlying + buffer is detached or the file was closed. + """ + with suppress(ValueError): + super().__exit__(exc_type, exc_value, traceback) + + +class SafeCloseRandomTextIO(_SafeCloseIOBase, RandomTextIO): + """Class to capture writes to a buffer even when detached and safely close. + + Subclass of :cls:`~_SafeCloseIOBase` and :cls:`~TeeStdin`. + """ + + +class SafeCloseTeeStdin(_SafeCloseIOBase, TeeStdin): + """Class to tee contents to a side buffer on read and safely close. + + Subclass of :cls:`~_SafeCloseIOBase` and :cls:`~TeeStdin`. + """ + + @contextmanager -def stdio_mgr(in_str=""): +def stdio_mgr(in_str="", close=True): r"""Subsitute temporary text buffers for `stdio` in a managed context. Context manager. - Substitutes empty :cls:`~io.StringIO`\ s for + Substitutes empty :cls:`~io.RandomTextIO`\ s for :cls:`sys.stdout` and :cls:`sys.stderr`, and a :cls:`TeeStdin` for :cls:`sys.stdin` within the managed context. @@ -181,22 +257,32 @@ def stdio_mgr(in_str=""): out_ - :cls:`~io.StringIO` -- Temporary stream for `stdout`, + :cls:`~io.RandomTextIO` -- Temporary stream for `stdout`, initially empty. err_ - :cls:`~io.StringIO` -- Temporary stream for `stderr`, + :cls:`~io.RandomTextIO` -- Temporary stream for `stderr`, initially empty. """ + if close: + out_cls = SafeCloseRandomTextIO + in_cls = SafeCloseTeeStdin + else: + out_cls = RandomTextIO + in_cls = TeeStdin + old_stdin = sys.stdin old_stdout = sys.stdout old_stderr = sys.stderr - new_stdout = StringIO() - new_stderr = StringIO() - new_stdin = TeeStdin(new_stdout, in_str) + with ExitStack() as stack: + new_stdout = stack.enter_context(out_cls()) + new_stderr = stack.enter_context(out_cls()) + new_stdin = stack.enter_context(in_cls(new_stdout, in_str)) + + close_files = stack.pop_all().close sys.stdin = new_stdin sys.stdout = new_stdout @@ -208,14 +294,5 @@ def stdio_mgr(in_str=""): sys.stdout = old_stdout sys.stderr = old_stderr - try: - closed = new_stdin.closed - except ValueError: - # ValueError occurs when the underlying buffer is detached - pass - else: - if not closed: - new_stdin.close() - - new_stdout.close() - new_stderr.close() + if close: + close_files() diff --git a/tests/test_stdiomgr_base.py b/tests/test_stdiomgr_base.py index 5cd3167..931a9e8 100644 --- a/tests/test_stdiomgr_base.py +++ b/tests/test_stdiomgr_base.py @@ -26,9 +26,10 @@ """ - +import io import warnings +import pytest from stdio_mgr import stdio_mgr @@ -114,8 +115,108 @@ def test_repeated_use(): test_capture_stderr() +def test_manual_close(): + """Confirm files remain open if close=False after the context has exited.""" + with stdio_mgr(close=False) as (i, o, e): + test_default_stdin() + test_capture_stderr() + assert not i.closed + assert not o.closed + assert not e.closed + + i.close() + o.close() + e.close() + + +def test_manual_close_detached_fails(): + """Confirm files remain open if close=False after the context has exited.""" + with stdio_mgr(close=False) as (i, o, e): + test_default_stdin() + test_capture_stderr() + i.detach() + o.detach() + e.detach() + + with pytest.raises(ValueError): + i.close() + with pytest.raises(ValueError): + i.closed + with pytest.raises(ValueError): + o.close() + with pytest.raises(ValueError): + o.closed + with pytest.raises(ValueError): + e.close() + + def test_stdin_detached(): """Confirm stdin's buffer can be detached within the context.""" with stdio_mgr() as (i, o, e): + print("test str") + f = i.detach() + + assert "test str\n" == o.getvalue() + + print("second test str") + + assert "test str\nsecond test str\n" == o.getvalue() + assert not f.closed + assert o.closed + assert e.closed + + +def test_stdout_detached(): + """Confirm stdout's buffer can be detached within the context. + + Like the real sys.stdout, writes after detach should fail, however + writes to the detached stream should be captured. + """ + with stdio_mgr() as (i, o, e): + print("test str") + + f = o.detach() + + assert f is o._buf + assert f is i.tee._buf + + assert "test str\n" == o.getvalue() + + with pytest.raises(ValueError): + print("anything") + + f.write("second test str\n".encode("utf8")) + f.flush() + + assert "test str\nsecond test str\n" == o.getvalue() + + assert not f.closed + assert i.closed + assert e.closed + + +def test_stdout_detached_and_closed(): + """Confirm stdout's buffer can be detached within the context. + + Like the real sys.stdout, writes after detach should fail, however + writes to the detached stream should be captured. + """ + with stdio_mgr() as (i, o, e): + print("test str") + + f = o.detach() + + assert isinstance(f, io.BufferedRandom) + assert f is o._buf + + assert "test str\n" == o.getvalue() + + with pytest.raises(ValueError): + print("anything") + + f.write("second test str\n".encode("utf8")) + f.close() + + assert "test str\nsecond test str\n" == o.getvalue()