-
Notifications
You must be signed in to change notification settings - Fork 0
/
concatenate_records.go
66 lines (59 loc) · 1.95 KB
/
concatenate_records.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
package arrowops
import (
"fmt"
"github.com/alekLukanen/errs"
"github.com/apache/arrow/go/v17/arrow"
"github.com/apache/arrow/go/v17/arrow/array"
"github.com/apache/arrow/go/v17/arrow/memory"
)
/*
Concatenate the records together. The records must all have the same schema. The resulting
record will have all data from each record in the order they were supplied and will be
a new record with all data copied.
*/
func ConcatenateRecords(mem *memory.GoAllocator, records ...arrow.Record) (arrow.Record, error) {
for _, record := range records {
record.Retain()
}
defer func() {
for _, record := range records {
record.Release()
}
}()
// validate the records
if len(records) == 0 {
return nil, errs.NewStackError(fmt.Errorf("%w| records not provided", ErrNoDataSupplied))
}
schema := records[0].Schema()
for _, record := range records {
if !schema.Equal(record.Schema()) {
return nil, errs.NewStackError(fmt.Errorf("%w| records schemas not all equal", ErrSchemasNotEqual))
}
}
// group all of the columns from each record together
// so that we can concatenate them together
fields := make([][]arrow.Array, schema.NumFields())
for i := 0; i < schema.NumFields(); i++ {
fields[i] = make([]arrow.Array, len(records))
}
for recordIdx, record := range records {
for i := 0; i < schema.NumFields(); i++ {
fields[i][recordIdx] = record.Column(i)
}
}
// concatenate the columns of the same index together
concatenatedFields := make([]arrow.Array, schema.NumFields())
for i := 0; i < schema.NumFields(); i++ {
concatenatedField, err := array.Concatenate(fields[i], mem)
if err != nil {
return nil, errs.NewStackError(fmt.Errorf("%w| concatenation failed", err))
}
concatenatedFields[i] = concatenatedField
}
// get the total number of rows in the concatenated record
var numRows uint32
for _, record := range records {
numRows += uint32(record.NumRows())
}
return array.NewRecord(schema, concatenatedFields, int64(numRows)), nil
}