From ad09b2ba4a20af04fba71aaa05fa17d08b05c5a6 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 21 Nov 2024 14:57:52 -0600 Subject: [PATCH] get append by iterable working --- python/src/nanoarrow/_schema.pyx | 10 +++++++++- python/src/nanoarrow/c_array.py | 8 ++++++++ python/src/nanoarrow/extension.py | 12 ++++++++++-- python/src/nanoarrow/extension_canonical.py | 7 ++++++- python/tests/test_canonical_extension.py | 10 +++++----- 5 files changed, 38 insertions(+), 9 deletions(-) diff --git a/python/src/nanoarrow/_schema.pyx b/python/src/nanoarrow/_schema.pyx index 3e82c0659..c1deaec72 100644 --- a/python/src/nanoarrow/_schema.pyx +++ b/python/src/nanoarrow/_schema.pyx @@ -555,6 +555,8 @@ cdef class CSchemaView: (_types.TIMESTAMP, _types.DATE64, _types.DURATION) ): return 'q' + elif self.extension_name: + return self._get_buffer_format() else: return None @@ -564,7 +566,13 @@ cdef class CSchemaView: or None if there is no Python format string that can represent this type without loosing information. """ - if self.extension_name or self._schema_view.type != self._schema_view.storage_type: + if self.extension_name: + return None + else: + return self._get_buffer_format() + + def _get_buffer_format(self): + if self._schema_view.type != self._schema_view.storage_type: return None # String/binary types do not have format strings as far as the Python diff --git a/python/src/nanoarrow/c_array.py b/python/src/nanoarrow/c_array.py index 0c71bda45..932aa174d 100644 --- a/python/src/nanoarrow/c_array.py +++ b/python/src/nanoarrow/c_array.py @@ -24,6 +24,7 @@ from nanoarrow._utils import obj_is_buffer, obj_is_capsule from nanoarrow.c_buffer import c_buffer from nanoarrow.c_schema import c_schema, c_schema_view +from nanoarrow.extension import resolve_extension from nanoarrow import _types @@ -462,6 +463,13 @@ def __init__(self, schema): # Resolve the method name we are going to use to do the building from # the provided schema. + ext = resolve_extension(self._schema_view) + if ext is not None: + maybe_appender = ext.get_iterable_appender(self._schema, self) + if maybe_appender: + self._append_impl = maybe_appender + return + type_id = self._schema_view.type_id if type_id not in _ARRAY_BUILDER_FROM_ITERABLE_METHOD: raise ValueError( diff --git a/python/src/nanoarrow/extension.py b/python/src/nanoarrow/extension.py index 88e545434..fa13f5aec 100644 --- a/python/src/nanoarrow/extension.py +++ b/python/src/nanoarrow/extension.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Iterator, Mapping, Optional, Type +from typing import Any, Iterator, Mapping, Optional, Type, Callable, Iterable from nanoarrow.c_schema import CSchema, CSchemaView, c_schema_view +from nanoarrow.c_array import CArrayBuilder class Extension: @@ -38,6 +39,11 @@ def get_pyiter( def get_sequence_converter(self, c_schema: CSchema): return None + def get_iterable_appender( + self, c_schema: CSchema, array_builder + ) -> Optional[Callable[[Iterable], None]]: + return None + _global_extension_registry = {} @@ -59,7 +65,9 @@ def register_extension(extension: Extension) -> Optional[Extension]: else: key = schema_view.type_id - prev = _global_extension_registry[key] if key in _global_extension_registry else None + prev = ( + _global_extension_registry[key] if key in _global_extension_registry else None + ) _global_extension_registry[key] = extension return prev diff --git a/python/src/nanoarrow/extension_canonical.py b/python/src/nanoarrow/extension_canonical.py index 62a2b35bc..f2813913c 100644 --- a/python/src/nanoarrow/extension_canonical.py +++ b/python/src/nanoarrow/extension_canonical.py @@ -17,6 +17,7 @@ from typing import Any, Iterator, Mapping, Optional +from nanoarrow.c_array import CArrayBuilder from nanoarrow.c_buffer import CBufferBuilder from nanoarrow.c_schema import CSchema, c_schema_view from nanoarrow.schema import extension_type, int8 @@ -33,7 +34,7 @@ def bool8(nullable: bool = True): Use ``False`` to mark this field as non-nullable. """ - return extension_type(int8(nullable=nullable), "arrow.bool8") + return extension_type(int8(), "arrow.bool8", nullable=nullable) class Bool8SequenceConverter(ToPyBufferConverter): @@ -72,3 +73,7 @@ def get_pyiter( def get_sequence_converter(self, c_schema: CSchema): self.get_params(c_schema) return Bool8SequenceConverter + + def get_sequence_appender(self, c_schema: CSchema, array_builder): + self.get_params(c_schema) + return None diff --git a/python/tests/test_canonical_extension.py b/python/tests/test_canonical_extension.py index a98bd96f1..461d06fec 100644 --- a/python/tests/test_canonical_extension.py +++ b/python/tests/test_canonical_extension.py @@ -25,8 +25,9 @@ def test_extension_bool8(): assert schema.extension.name == "arrow.bool8" assert schema.extension.metadata == b"" - buf = na.c_buffer([1, 0, 1, 1], na.int8()) - bool8_array = na.Array(na.c_array_from_buffers(schema, 4, [None, buf])) + assert na.bool8(nullable=False).nullable is False + + bool8_array = na.Array([True, False, True, True], na.bool8()) assert bool8_array.schema.type == na.Type.EXTENSION assert bool8_array.schema.extension.name == "arrow.bool8" assert bool8_array.to_pylist() == [True, False, True, True] @@ -34,10 +35,9 @@ def test_extension_bool8(): sequence = bool8_array.to_pysequence() assert list(sequence) == [True, False, True, True] - validity = na.c_buffer([True, True, False, True], na.bool_()) - bool8_array = na.Array(na.c_array_from_buffers(schema, 4, [validity, buf])) + bool8_array = na.Array([True, False, None, True], na.bool8()) assert bool8_array.to_pylist() == [True, False, None, True] sequence = bool8_array.to_pysequence(handle_nulls=na.nulls_separate()) - assert list(sequence[1]) == [True, False, True, True] + assert list(sequence[1]) == [True, False, False, True] assert list(sequence[0]) == [True, True, False, True]