Support for nested types in non-key fields
mroz45 committed Nov 27, 2024
1 parent e1fa43d commit 105633e
Showing 3 changed files with 104 additions and 17 deletions.
3 changes: 3 additions & 0 deletions cpp/src/arrow/acero/asof_join_node.cc
@@ -1215,6 +1215,9 @@ class AsofJoinNode : public ExecNode {
       case Type::LARGE_STRING:
       case Type::BINARY:
       case Type::LARGE_BINARY:
+      case Type::LIST:
+      case Type::FIXED_SIZE_LIST:
+      case Type::STRUCT:
         return Status::OK();
       default:
         return Status::Invalid("Unsupported type for data field ", field->name(), " : ",
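These case labels form the constructor-time whitelist of types accepted for non-key ("data") fields; anything that falls through to default is rejected with Status::Invalid. As a minimal sketch (illustrative only, not part of the commit), a right-table schema carrying a list-typed data column now passes this check:

#include <arrow/api.h>

int main() {
  // Illustrative schema: "time"/"key" are join fields, "r0_v0" is a nested
  // (list-typed) data field that this validation previously rejected.
  auto right_schema = arrow::schema(
      {arrow::field("time", arrow::int64()), arrow::field("key", arrow::int32()),
       arrow::field("r0_v0", arrow::list(arrow::int32()))});
  return right_schema ? 0 : 1;
}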
105 changes: 97 additions & 8 deletions cpp/src/arrow/acero/asof_join_node_test.cc
@@ -1286,14 +1286,6 @@ TRACED_TEST(AsofJoinTest, TestUnsupportedByType, {
field("r0_v0", float32())}));
})

TRACED_TEST(AsofJoinTest, TestUnsupportedDatatype, {
// List is unsupported
DoRunInvalidTypeTest(
schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}),
schema({field("time", int64()), field("key", int32()),
field("r0_v0", list(int32()))}));
})

TRACED_TEST(AsofJoinTest, TestMissingKeys, {
DoRunMissingKeysTest(
schema({field("time1", int64()), field("key", int32()), field("l_v0", float64())}),
@@ -1732,5 +1724,102 @@ TEST(AsofJoinTest, RhsEmptinessRaceEmptyBy) {
   AssertExecBatchesEqualIgnoringOrder(result.schema, {exp_batch}, result.batches);
 }
 
+// GH-44729: Testing nested data type for non-key fields
+TEST(AsofJoinTest, FixedListDataType) {
+  const int32_t list_size = 3;
+  auto list_type = arrow::fixed_size_list(arrow::int32(), list_size);
+
+  auto left_batch = ExecBatchFromJSON({int64()}, R"([[1], [2], [3]])");
+  auto right_batch = ExecBatchFromJSON({list_type, int64()}, R"([
+    [[0, 1, 2], 2],
+    [[3, 4, 5], 3],
+    [[6, 7, 8], 4]
+  ])");
+
+  Declaration left{"exec_batch_source",
+                   ExecBatchSourceNodeOptions(schema({field("on", int64())}),
+                                              {std::move(left_batch)})};
+  Declaration right{"exec_batch_source",
+                    ExecBatchSourceNodeOptions(
+                        schema({field("colVals", list_type), field("on", int64())}),
+                        {std::move(right_batch)})};
+
+  AsofJoinNodeOptions asof_join_opts({{{"on"}, {}}, {{"on"}, {}}}, 1);
+  Declaration asof_join{
+      "asofjoin", {std::move(left), std::move(right)}, std::move(asof_join_opts)};
+
+  ASSERT_OK_AND_ASSIGN(auto result, DeclarationToExecBatches(std::move(asof_join)));
+
+  auto exp_batch = ExecBatchFromJSON({int64(), list_type}, R"([
+    [1, [0, 1, 2]],
+    [2, [0, 1, 2]],
+    [3, [3, 4, 5]]
+  ])");
+
+  AssertExecBatchesEqual(result.schema, {exp_batch}, result.batches);
+}
+
+TEST(AsofJoinTest, ListDataType) {
+  auto list_type = list(int32());
+
+  auto left_batch = ExecBatchFromJSON({int64()}, R"([[1], [2], [3]])");
+  auto right_batch = ExecBatchFromJSON({list_type, int64()}, R"([
+    [[0, 1, 2, 9], 2],
+    [[3, 4, 5, 7], 3],
+    [[6, 7, 8], 4]
+  ])");
+
+  Declaration left{"exec_batch_source",
+                   ExecBatchSourceNodeOptions(schema({field("on", int64())}),
+                                              {std::move(left_batch)})};
+  Declaration right{"exec_batch_source",
+                    ExecBatchSourceNodeOptions(
+                        schema({field("colVals", list_type), field("on", int64())}),
+                        {std::move(right_batch)})};
+
+  AsofJoinNodeOptions asof_join_opts({{{"on"}, {}}, {{"on"}, {}}}, 1);
+  Declaration asof_join{
+      "asofjoin", {std::move(left), std::move(right)}, std::move(asof_join_opts)};
+
+  ASSERT_OK_AND_ASSIGN(auto result, DeclarationToExecBatches(std::move(asof_join)));
+  auto exp_batch = ExecBatchFromJSON({int64(), list_type}, R"([
+    [1, [0, 1, 2]],
+    [2, [0, 1, 2]],
+    [3, [3, 4, 5]]
+  ])");
+
+  AssertExecBatchesEqual(result.schema, {exp_batch}, result.batches);
+}
+
+TEST(AsofJoinTest, StructTestDataType) {
+  auto struct_type = struct_({field("key", utf8()), field("value", int64())});
+
+  auto left_batch = ExecBatchFromJSON({int64()}, R"([[1], [2], [3]])");
+  auto right_batch = ExecBatchFromJSON({struct_type, int64()}, R"([
+    [{"key": "a", "value": 1}, 2],
+    [{"key": "b", "value": 3}, 3],
+    [{"key": "c", "value": 5}, 4]
+  ])");
+
+  Declaration left{"exec_batch_source",
+                   ExecBatchSourceNodeOptions(schema({field("on", int64())}),
+                                              {std::move(left_batch)})};
+  Declaration right{"exec_batch_source",
+                    ExecBatchSourceNodeOptions(
+                        schema({field("col", struct_type), field("on", int64())}),
+                        {std::move(right_batch)})};
+  AsofJoinNodeOptions asof_join_opts({{{"on"}, {}}, {{"on"}, {}}}, 1);
+  Declaration asof_join{
+      "asofjoin", {std::move(left), std::move(right)}, std::move(asof_join_opts)};
+  ASSERT_OK_AND_ASSIGN(auto result, DeclarationToExecBatches(std::move(asof_join)));
+
+  auto exp_batch = ExecBatchFromJSON({int64(), struct_type}, R"([
+    [1, {"key": "a", "value": 1}],
+    [2, {"key": "a", "value": 1}],
+    [3, {"key": "b", "value": 3}]
+  ])");
+  AssertExecBatchesEqual(result.schema, {exp_batch}, result.batches);
+}
+
 }  // namespace acero
 }  // namespace arrow
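All three new tests join on a single "on" key with no "by" keys and a tolerance of 1, and the expected batches are consistent with a future as-of match: each left row takes the closest right row whose right.on - left.on lies in [0, tolerance], so left on=1 picks the right row at on=2, and on=3 picks on=3 rather than on=4. A standalone sketch of that matching rule under this assumption (this is not Acero's implementation):

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Nearest future match within tolerance, mirroring the rule assumed above.
std::optional<size_t> AsofMatch(int64_t left_on, const std::vector<int64_t>& right_on,
                                int64_t tolerance) {
  std::optional<size_t> best;
  for (size_t i = 0; i < right_on.size(); ++i) {
    int64_t d = right_on[i] - left_on;
    if (d < 0 || d > tolerance) continue;                   // outside [0, tolerance]
    if (!best || d < right_on[*best] - left_on) best = i;   // keep the closest
  }
  return best;
}

int main() {
  const std::vector<int64_t> right = {2, 3, 4};  // right "on" keys from the tests
  for (int64_t left : {1, 2, 3}) {               // left "on" keys from the tests
    auto m = AsofMatch(left, right, /*tolerance=*/1);
    std::cout << left << " -> " << (m ? std::to_string(right[*m]) : "null") << "\n";
  }
  // Prints 1 -> 2, 2 -> 2, 3 -> 3, matching the rows in exp_batch above.
}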
13 changes: 4 additions & 9 deletions cpp/src/arrow/acero/unmaterialized_table_internal.h
@@ -114,6 +114,8 @@ class UnmaterializedCompositeTable {
         MATERIALIZE_CASE(BINARY)
         MATERIALIZE_CASE(LARGE_BINARY)
         MATERIALIZE_CASE(FIXED_SIZE_LIST)
+        MATERIALIZE_CASE(LIST)
+        MATERIALIZE_CASE(STRUCT)
         default:
           return arrow::Status::Invalid("Unsupported data type ",
                                         field->type()->ToString(), " for field ",
@@ -167,8 +169,6 @@ class UnmaterializedCompositeTable {
     num_rows += slice.Size();
   }
 
-
-
   template <class Type, class Builder = typename arrow::TypeTraits<Type>::BuilderType>
   arrow::Result<std::shared_ptr<arrow::Array>> materializeColumn(
       const std::shared_ptr<arrow::DataType>& type, int i_col) {
@@ -181,13 +181,8 @@ class UnmaterializedCompositeTable {
     for (const auto& unmaterialized_slice : slices) {
       const auto& [batch, start, end] = unmaterialized_slice.components[table_index];
       if (batch) {
-        ARROW_RETURN_NOT_OK(builder.AppendArraySlice(*batch->column_data(column_index),start,end-start));
-
-        // for (uint64_t rowNum = start; rowNum < end; ++rowNum) {
-        //   arrow::Status st = BuilderAppend<Type, Builder>(
-        //       builder, batch->column_data(column_index), rowNum);
-        //   ARROW_RETURN_NOT_OK(st);
-        // }
+        ARROW_RETURN_NOT_OK(builder.AppendArraySlice(*batch->column_data(column_index),
+                                                     start, end - start));
       } else {
         for (uint64_t rowNum = start; rowNum < end; ++rowNum) {
           ARROW_RETURN_NOT_OK(builder.AppendNull());
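The functional change in this file: the per-row BuilderAppend path (previously left behind as a commented-out loop) is replaced with a single ArrayBuilder::AppendArraySlice call, which copies a run of rows through the builder's virtual interface and therefore works for list and struct columns as well as flat ones. A self-contained sketch of the same call on a list builder, assuming only the public Arrow C++ API (not this node's internals):

#include <arrow/api.h>
#include <iostream>

arrow::Status Demo() {
  auto pool = arrow::default_memory_pool();
  // Build a source list<int32> array: [[0, 1], [2], [3, 4, 5]]
  auto values = std::make_shared<arrow::Int32Builder>(pool);
  arrow::ListBuilder src_builder(pool, values);
  ARROW_RETURN_NOT_OK(src_builder.Append());
  ARROW_RETURN_NOT_OK(values->AppendValues({0, 1}));
  ARROW_RETURN_NOT_OK(src_builder.Append());
  ARROW_RETURN_NOT_OK(values->Append(2));
  ARROW_RETURN_NOT_OK(src_builder.Append());
  ARROW_RETURN_NOT_OK(values->AppendValues({3, 4, 5}));
  std::shared_ptr<arrow::Array> source;
  ARROW_RETURN_NOT_OK(src_builder.Finish(&source));

  // Copy rows [1, 3) in one call -- no per-row scalar handling, which is
  // what lets the materialization work uniformly for nested types.
  auto out_values = std::make_shared<arrow::Int32Builder>(pool);
  arrow::ListBuilder out_builder(pool, out_values);
  ARROW_RETURN_NOT_OK(out_builder.AppendArraySlice(arrow::ArraySpan(*source->data()),
                                                   /*offset=*/1, /*length=*/2));
  std::shared_ptr<arrow::Array> out;
  ARROW_RETURN_NOT_OK(out_builder.Finish(&out));
  std::cout << out->ToString() << std::endl;  // the two copied rows: [2] and [3, 4, 5]
  return arrow::Status::OK();
}

int main() { return Demo().ok() ? 0 : 1; }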
