Skip to content

Commit

Permalink
get append by iterable working
Browse files Browse the repository at this point in the history
  • Loading branch information
paleolimbot committed Nov 21, 2024
1 parent d0e7c94 commit ad09b2b
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 9 deletions.
10 changes: 9 additions & 1 deletion python/src/nanoarrow/_schema.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,8 @@ cdef class CSchemaView:
(_types.TIMESTAMP, _types.DATE64, _types.DURATION)
):
return 'q'
elif self.extension_name:
return self._get_buffer_format()
else:
return None

Expand All @@ -564,7 +566,13 @@ cdef class CSchemaView:
or None if there is no Python format string that can represent this
type without losing information.
"""
if self.extension_name or self._schema_view.type != self._schema_view.storage_type:
if self.extension_name:
return None
else:
return self._get_buffer_format()

def _get_buffer_format(self):
if self._schema_view.type != self._schema_view.storage_type:
return None

# String/binary types do not have format strings as far as the Python
Expand Down
8 changes: 8 additions & 0 deletions python/src/nanoarrow/c_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from nanoarrow._utils import obj_is_buffer, obj_is_capsule
from nanoarrow.c_buffer import c_buffer
from nanoarrow.c_schema import c_schema, c_schema_view
from nanoarrow.extension import resolve_extension

from nanoarrow import _types

Expand Down Expand Up @@ -462,6 +463,13 @@ def __init__(self, schema):

# Resolve the method name we are going to use to do the building from
# the provided schema.
ext = resolve_extension(self._schema_view)
if ext is not None:
maybe_appender = ext.get_iterable_appender(self._schema, self)
if maybe_appender:
self._append_impl = maybe_appender
return

type_id = self._schema_view.type_id
if type_id not in _ARRAY_BUILDER_FROM_ITERABLE_METHOD:
raise ValueError(
Expand Down
12 changes: 10 additions & 2 deletions python/src/nanoarrow/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any, Iterator, Mapping, Optional, Type
from typing import Any, Iterator, Mapping, Optional, Type, Callable, Iterable

from nanoarrow.c_schema import CSchema, CSchemaView, c_schema_view
from nanoarrow.c_array import CArrayBuilder


class Extension:
Expand All @@ -38,6 +39,11 @@ def get_pyiter(
# Default hook: this base implementation supplies no sequence converter and
# simply returns None. extension_canonical.py defines a method of the same
# name that returns Bool8SequenceConverter instead.
def get_sequence_converter(self, c_schema: CSchema):
return None

# Hook consulted by the array builder's __init__ in c_array.py: when an
# extension is resolved for the schema, the builder calls
# ext.get_iterable_appender(schema, builder) and, if the result is truthy,
# installs it as its _append_impl and skips the type_id-based method lookup.
# This base implementation returns None, so the type_id-based lookup runs.
def get_iterable_appender(
self, c_schema: CSchema, array_builder
) -> Optional[Callable[[Iterable], None]]:
return None


_global_extension_registry = {}

Expand All @@ -59,7 +65,9 @@ def register_extension(extension: Extension) -> Optional[Extension]:
else:
key = schema_view.type_id

prev = _global_extension_registry[key] if key in _global_extension_registry else None
prev = (
_global_extension_registry[key] if key in _global_extension_registry else None
)
_global_extension_registry[key] = extension
return prev

Expand Down
7 changes: 6 additions & 1 deletion python/src/nanoarrow/extension_canonical.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from typing import Any, Iterator, Mapping, Optional

from nanoarrow.c_array import CArrayBuilder
from nanoarrow.c_buffer import CBufferBuilder
from nanoarrow.c_schema import CSchema, c_schema_view
from nanoarrow.schema import extension_type, int8
Expand All @@ -33,7 +34,7 @@ def bool8(nullable: bool = True):
Use ``False`` to mark this field as non-nullable.
"""

return extension_type(int8(nullable=nullable), "arrow.bool8")
return extension_type(int8(), "arrow.bool8", nullable=nullable)


class Bool8SequenceConverter(ToPyBufferConverter):
Expand Down Expand Up @@ -72,3 +73,7 @@ def get_pyiter(
# Calls self.get_params(c_schema) before answering (get_params is not shown
# in this view; presumably it validates the schema -- TODO confirm), then
# returns the Bool8SequenceConverter class itself, not an instance.
def get_sequence_converter(self, c_schema: CSchema):
self.get_params(c_schema)
return Bool8SequenceConverter

def get_iterable_appender(self, c_schema: CSchema, array_builder):
    """Return a custom iterable appender for this extension, or None.

    The array builder in c_array.py looks this hook up by the name
    ``get_iterable_appender``. Under its previous name,
    ``get_sequence_appender``, this override was never invoked, so the
    ``get_params`` call below did not run when building from an iterable.

    Parameters
    ----------
    c_schema : CSchema
        Schema of the array being built; passed to ``self.get_params``
        (presumably to validate it -- confirm against ``get_params``).
    array_builder
        The builder requesting an appender. Unused here.

    Returns
    -------
    None
        No custom appender is supplied, so the builder falls through to
        its own type_id-based method lookup.
    """
    self.get_params(c_schema)
    return None


# Backward-compatible alias for the hook's earlier (never-called) name.
get_sequence_appender = get_iterable_appender
10 changes: 5 additions & 5 deletions python/tests/test_canonical_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ def test_extension_bool8():
assert schema.extension.name == "arrow.bool8"
assert schema.extension.metadata == b""

buf = na.c_buffer([1, 0, 1, 1], na.int8())
bool8_array = na.Array(na.c_array_from_buffers(schema, 4, [None, buf]))
assert na.bool8(nullable=False).nullable is False

bool8_array = na.Array([True, False, True, True], na.bool8())
assert bool8_array.schema.type == na.Type.EXTENSION
assert bool8_array.schema.extension.name == "arrow.bool8"
assert bool8_array.to_pylist() == [True, False, True, True]

sequence = bool8_array.to_pysequence()
assert list(sequence) == [True, False, True, True]

validity = na.c_buffer([True, True, False, True], na.bool_())
bool8_array = na.Array(na.c_array_from_buffers(schema, 4, [validity, buf]))
bool8_array = na.Array([True, False, None, True], na.bool8())
assert bool8_array.to_pylist() == [True, False, None, True]

sequence = bool8_array.to_pysequence(handle_nulls=na.nulls_separate())
assert list(sequence[1]) == [True, False, True, True]
assert list(sequence[1]) == [True, False, False, True]
assert list(sequence[0]) == [True, True, False, True]

0 comments on commit ad09b2b

Please sign in to comment.