Skip to content

Commit

Permalink
Add unparse_data() function returning bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
trossi committed May 16, 2024
1 parent a6af5fd commit 1c1cbfb
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 9 deletions.
9 changes: 2 additions & 7 deletions rdata/tests/test_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import io
import tempfile
from contextlib import contextmanager
from pathlib import Path
Expand Down Expand Up @@ -62,14 +61,11 @@ def test_unparse(fname: str) -> None:

r_data = rdata.parser.parse_data(data, expand_altrep=False)

fd = io.BytesIO()
try:
unparse_data(fd, r_data, file_format=fmt, rds=rds)
out_data = unparse_data(r_data, file_format=fmt, rds=rds)
except NotImplementedError as e:
pytest.xfail(str(e))

out_data = fd.getvalue()

if fmt == "ascii":
data = data.replace(b"\r\n", b"\n")

Expand Down Expand Up @@ -137,9 +133,8 @@ def test_unparse_big_int() -> None:
"""Test checking too large integers."""
big_int = 2**32
r_data = rdata.conversion.convert_to_r_data(big_int)
fd = io.BytesIO()
with pytest.raises(ValueError, match="(?i)not castable"):
unparse_data(fd, r_data, file_format="xdr")
unparse_data(r_data, file_format="xdr")


@pytest.mark.parametrize("compression", [*valid_compressions, None, "fail"])
Expand Down
35 changes: 33 additions & 2 deletions rdata/unparser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import io
from typing import IO, TYPE_CHECKING, Any

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,10 +53,10 @@ def unparse_file(
raise ValueError(msg)

with open(path, "wb") as f:
unparse_data(f, r_data, file_format=file_format, rds=rds)
unparse_fileobj(f, r_data, file_format=file_format, rds=rds)


def unparse_data(
def unparse_fileobj(
fileobj: IO[Any],
r_data: RData,
*,
Expand All @@ -73,6 +74,8 @@ def unparse_data(
RData object
file_format:
File format (ascii or xdr)
rds:
Whether to create RDS or RDA file
"""
Unparser: type[UnparserXDR | UnparserASCII] # noqa: N806

Expand All @@ -86,3 +89,31 @@ def unparse_data(

unparser = Unparser(fileobj) # type: ignore [arg-type]
unparser.unparse_r_data(r_data, rds=rds)


def unparse_data(
r_data: RData,
*,
file_format: str = "xdr",
rds: bool = True,
) -> bytes:
"""
Unparse RData object to a bytestring.
Parameters
----------
r_data:
RData object
file_format:
File format (ascii or xdr)
rds:
Whether to create RDS or RDA file
Returns:
-------
data:
Bytestring of data
"""
fd = io.BytesIO()
unparse_fileobj(fd, r_data, file_format=file_format, rds=rds)
return fd.getvalue()

0 comments on commit 1c1cbfb

Please sign in to comment.