diff --git a/arrow/array.go b/arrow/array.go index d42ca6d05..891697e9b 100644 --- a/arrow/array.go +++ b/arrow/array.go @@ -111,7 +111,7 @@ type Array interface { ValueStr(i int) string // Get single value to be marshalled with `json.Marshal` - GetOneForMarshal(i int) interface{} + GetOneForMarshal(i int, nullable bool) interface{} Data() ArrayData diff --git a/arrow/array/binary.go b/arrow/array/binary.go index a8e77ae9c..ba9ce287b 100644 --- a/arrow/array/binary.go +++ b/arrow/array/binary.go @@ -152,8 +152,8 @@ func (a *Binary) setData(data *Data) { } } -func (a *Binary) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { +func (a *Binary) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && a.IsNull(i) { return nil } return a.Value(i) @@ -162,7 +162,7 @@ func (a *Binary) GetOneForMarshal(i int) interface{} { func (a *Binary) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { - vals[i] = a.GetOneForMarshal(i) + vals[i] = a.GetOneForMarshal(i, true) } // golang marshal standard says that []byte will be marshalled // as a base64-encoded string @@ -223,9 +223,9 @@ func (a *Binary) ValidateFull() error { return nil } -func arrayEqualBinary(left, right *Binary) bool { +func arrayEqualBinary(left, right *Binary, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if !bytes.Equal(left.Value(i), right.Value(i)) { @@ -346,8 +346,8 @@ func (a *LargeBinary) setData(data *Data) { } } -func (a *LargeBinary) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { +func (a *LargeBinary) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && a.IsNull(i) { return nil } return a.Value(i) @@ -356,7 +356,7 @@ func (a *LargeBinary) GetOneForMarshal(i int) interface{} { func (a *LargeBinary) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { - vals[i] = a.GetOneForMarshal(i) + vals[i] = a.GetOneForMarshal(i, true) } // golang marshal standard says that []byte will be marshalled // as a base64-encoded string @@ -417,9 +417,9 @@ func (a *LargeBinary) ValidateFull() error { return nil } -func arrayEqualLargeBinary(left, right *LargeBinary) bool { +func arrayEqualLargeBinary(left, right *LargeBinary, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if !bytes.Equal(left.Value(i), right.Value(i)) { @@ -522,8 +522,8 @@ func (a *BinaryView) ValueStr(i int) string { return base64.StdEncoding.EncodeToString(a.Value(i)) } -func (a *BinaryView) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { +func (a *BinaryView) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && a.IsNull(i) { return nil } return a.Value(i) @@ -532,17 +532,17 @@ func (a *BinaryView) GetOneForMarshal(i int) interface{} { func (a *BinaryView) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { - vals[i] = a.GetOneForMarshal(i) + vals[i] = a.GetOneForMarshal(i, true) } // golang marshal standard says that []byte will be marshalled // as a base64-encoded string return json.Marshal(vals) } -func arrayEqualBinaryView(left, right *BinaryView) bool { +func arrayEqualBinaryView(left, right *BinaryView, opt equalOption) bool { leftBufs, rightBufs := left.dataBuffers, right.dataBuffers for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if !left.ValueHeader(i).Equals(leftBufs, right.ValueHeader(i), rightBufs) { diff --git a/arrow/array/boolean.go b/arrow/array/boolean.go index d579fa0c8..c555536f5 100644 --- a/arrow/array/boolean.go +++ b/arrow/array/boolean.go @@ -90,8 +90,8 @@ func (a *Boolean) setData(data *Data) { } } -func (a *Boolean) GetOneForMarshal(i int) interface{} { - if a.IsValid(i) { +func (a *Boolean) GetOneForMarshal(i int, nullable bool) interface{} { + if !nullable || a.IsValid(i) { return a.Value(i) } return nil @@ -109,9 +109,9 @@ func (a *Boolean) MarshalJSON() ([]byte, error) { return json.Marshal(vals) } -func arrayEqualBoolean(left, right *Boolean) bool { +func arrayEqualBoolean(left, right *Boolean, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if left.Value(i) != right.Value(i) { diff --git a/arrow/array/compare.go b/arrow/array/compare.go index 3f1dad177..41247eaa5 100644 --- a/arrow/array/compare.go +++ b/arrow/array/compare.go @@ -26,7 +26,11 @@ import ( ) // RecordEqual reports whether the two provided records are equal. -func RecordEqual(left, right arrow.RecordBatch) bool { +func RecordEqual(left, right arrow.RecordBatch, opts ...EqualOption) bool { + return recordEqual(left, right, newEqualOption(opts...)) +} + +func recordEqual(left, right arrow.RecordBatch, opt equalOption) bool { switch { case left.NumCols() != right.NumCols(): return false @@ -35,9 +39,17 @@ func RecordEqual(left, right arrow.RecordBatch) bool { } for i := range left.Columns() { + lf := left.Schema().Field(i) + rf := left.Schema().Field(i) + if !lf.Equal(rf) { + return false + } + + opt.nullable = lf.Nullable + lc := left.Column(i) rc := right.Column(i) - if !Equal(lc, rc) { + if !equal(lc, rc, opt) { return false } } @@ -47,6 +59,10 @@ func RecordEqual(left, right arrow.RecordBatch) bool { // RecordApproxEqual reports whether the two provided records are approximately equal. // For non-floating point columns, it is equivalent to RecordEqual. func RecordApproxEqual(left, right arrow.RecordBatch, opts ...EqualOption) bool { + return recordApproxEqual(left, right, newEqualOption(opts...)) +} + +func recordApproxEqual(left, right arrow.RecordBatch, opt equalOption) bool { switch { case left.NumCols() != right.NumCols(): return false @@ -54,9 +70,15 @@ func RecordApproxEqual(left, right arrow.RecordBatch, opts ...EqualOption) bool return false } - opt := newEqualOption(opts...) - for i := range left.Columns() { + lf := left.Schema().Field(i) + rf := left.Schema().Field(i) + if !lf.Equal(rf) { + return false + } + + opt.nullable = lf.Nullable + lc := left.Column(i) rc := right.Column(i) if !arrayApproxEqual(lc, rc, opt) { @@ -106,13 +128,17 @@ func chunkedBinaryApply(left, right *arrow.Chunked, fn func(left arrow.Array, lb } // ChunkedEqual reports whether two chunked arrays are equal regardless of their chunkings -func ChunkedEqual(left, right *arrow.Chunked) bool { +func ChunkedEqual(left, right *arrow.Chunked, opts ...EqualOption) bool { + return chunkedEqual(left, right, newEqualOption(opts...)) +} + +func chunkedEqual(left, right *arrow.Chunked, opt equalOption) bool { switch { case left == right: return true case left.Len() != right.Len(): return false - case left.NullN() != right.NullN(): + case opt.nullable && left.NullN() != right.NullN(): return false case !arrow.TypeEqual(left.DataType(), right.DataType()): return false @@ -130,12 +156,16 @@ func ChunkedEqual(left, right *arrow.Chunked) bool { // ChunkedApproxEqual reports whether two chunked arrays are approximately equal regardless of their chunkings // for non-floating point arrays, this is equivalent to ChunkedEqual func ChunkedApproxEqual(left, right *arrow.Chunked, opts ...EqualOption) bool { + return chunkedApproxEqual(left, right, newEqualOption(opts...)) +} + +func chunkedApproxEqual(left, right *arrow.Chunked, opt equalOption) bool { switch { case left == right: return true case left.Len() != right.Len(): return false - case left.NullN() != right.NullN(): + case opt.nullable && left.NullN() != right.NullN(): return false case !arrow.TypeEqual(left.DataType(), right.DataType()): return false @@ -143,7 +173,7 @@ func ChunkedApproxEqual(left, right *arrow.Chunked, opts ...EqualOption) bool { var isequal bool chunkedBinaryApply(left, right, func(left arrow.Array, lbeg, lend int64, right arrow.Array, rbeg, rend int64) bool { - isequal = SliceApproxEqual(left, lbeg, lend, right, rbeg, rend, opts...) + isequal = sliceApproxEqual(left, lbeg, lend, right, rbeg, rend, opt) return isequal }) @@ -151,7 +181,11 @@ func ChunkedApproxEqual(left, right *arrow.Chunked, opts ...EqualOption) bool { } // TableEqual returns if the two tables have the same data in the same schema -func TableEqual(left, right arrow.Table) bool { +func TableEqual(left, right arrow.Table, opts ...EqualOption) bool { + return tableEqual(left, right, newEqualOption(opts...)) +} + +func tableEqual(left, right arrow.Table, opt equalOption) bool { switch { case left.NumCols() != right.NumCols(): return false @@ -160,21 +194,29 @@ func TableEqual(left, right arrow.Table) bool { } for i := 0; int64(i) < left.NumCols(); i++ { - lc := left.Column(i) - rc := right.Column(i) - if !lc.Field().Equal(rc.Field()) { + lf := left.Schema().Field(i) + rf := left.Schema().Field(i) + if !lf.Equal(rf) { return false } - if !ChunkedEqual(lc.Data(), rc.Data()) { + opt.nullable = lf.Nullable + + lc := left.Column(i) + rc := right.Column(i) + if !chunkedEqual(lc.Data(), rc.Data(), opt) { return false } } return true } -// TableEqual returns if the two tables have the approximately equal data in the same schema +// TableApproxEqual returns if the two tables have the approximately equal data in the same schema func TableApproxEqual(left, right arrow.Table, opts ...EqualOption) bool { + return tableApproxEqual(left, right, newEqualOption(opts...)) +} + +func tableApproxEqual(left, right arrow.Table, opt equalOption) bool { switch { case left.NumCols() != right.NumCols(): return false @@ -183,13 +225,17 @@ func TableApproxEqual(left, right arrow.Table, opts ...EqualOption) bool { } for i := 0; int64(i) < left.NumCols(); i++ { - lc := left.Column(i) - rc := right.Column(i) - if !lc.Field().Equal(rc.Field()) { + lf := left.Schema().Field(i) + rf := left.Schema().Field(i) + if !lf.Equal(rf) { return false } - if !ChunkedApproxEqual(lc.Data(), rc.Data(), opts...) { + opt.nullable = lf.Nullable + + lc := left.Column(i) + rc := right.Column(i) + if !chunkedApproxEqual(lc.Data(), rc.Data(), opt) { return false } } @@ -197,13 +243,17 @@ func TableApproxEqual(left, right arrow.Table, opts ...EqualOption) bool { } // Equal reports whether the two provided arrays are equal. -func Equal(left, right arrow.Array) bool { +func Equal(left, right arrow.Array, opts ...EqualOption) bool { + return equal(left, right, newEqualOption(opts...)) +} + +func equal(left, right arrow.Array, opt equalOption) bool { switch { - case !baseArrayEqual(left, right): + case !baseArrayEqual(left, right, opt): return false case left.Len() == 0: return true - case left.NullN() == left.Len(): + case opt.nullable && left.NullN() == left.Len(): return true } @@ -216,127 +266,127 @@ func Equal(left, right arrow.Array) bool { return true case *Boolean: r := right.(*Boolean) - return arrayEqualBoolean(l, r) + return arrayEqualBoolean(l, r, opt) case *FixedSizeBinary: r := right.(*FixedSizeBinary) - return arrayEqualFixedSizeBinary(l, r) + return arrayEqualFixedSizeBinary(l, r, opt) case *Binary: r := right.(*Binary) - return arrayEqualBinary(l, r) + return arrayEqualBinary(l, r, opt) case *String: r := right.(*String) - return arrayEqualString(l, r) + return arrayEqualString(l, r, opt) case *LargeBinary: r := right.(*LargeBinary) - return arrayEqualLargeBinary(l, r) + return arrayEqualLargeBinary(l, r, opt) case *LargeString: r := right.(*LargeString) - return arrayEqualLargeString(l, r) + return arrayEqualLargeString(l, r, opt) case *BinaryView: r := right.(*BinaryView) - return arrayEqualBinaryView(l, r) + return arrayEqualBinaryView(l, r, opt) case *StringView: r := right.(*StringView) - return arrayEqualStringView(l, r) + return arrayEqualStringView(l, r, opt) case *Int8: r := right.(*Int8) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Int16: r := right.(*Int16) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Int32: r := right.(*Int32) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Int64: r := right.(*Int64) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Uint8: r := right.(*Uint8) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Uint16: r := right.(*Uint16) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Uint32: r := right.(*Uint32) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Uint64: r := right.(*Uint64) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Float16: r := right.(*Float16) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Float32: r := right.(*Float32) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Float64: r := right.(*Float64) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Decimal32: r := right.(*Decimal32) - return arrayEqualDecimal(l, r) + return arrayEqualDecimal(l, r, opt) case *Decimal64: r := right.(*Decimal64) - return arrayEqualDecimal(l, r) + return arrayEqualDecimal(l, r, opt) case *Decimal128: r := right.(*Decimal128) - return arrayEqualDecimal(l, r) + return arrayEqualDecimal(l, r, opt) case *Decimal256: r := right.(*Decimal256) - return arrayEqualDecimal(l, r) + return arrayEqualDecimal(l, r, opt) case *Date32: r := right.(*Date32) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Date64: r := right.(*Date64) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Time32: r := right.(*Time32) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Time64: r := right.(*Time64) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Timestamp: r := right.(*Timestamp) - return arrayEqualTimestamp(l, r) + return arrayEqualTimestamp(l, r, opt) case *List: r := right.(*List) - return arrayEqualList(l, r) + return arrayEqualList(l, r, opt) case *LargeList: r := right.(*LargeList) - return arrayEqualLargeList(l, r) + return arrayEqualLargeList(l, r, opt) case *ListView: r := right.(*ListView) - return arrayEqualListView(l, r) + return arrayEqualListView(l, r, opt) case *LargeListView: r := right.(*LargeListView) - return arrayEqualLargeListView(l, r) + return arrayEqualLargeListView(l, r, opt) case *FixedSizeList: r := right.(*FixedSizeList) - return arrayEqualFixedSizeList(l, r) + return arrayEqualFixedSizeList(l, r, opt) case *Struct: r := right.(*Struct) - return arrayEqualStruct(l, r) + return arrayEqualStruct(l, r, opt) case *MonthInterval: r := right.(*MonthInterval) - return arrayEqualMonthInterval(l, r) + return arrayEqualMonthInterval(l, r, opt) case *DayTimeInterval: r := right.(*DayTimeInterval) - return arrayEqualDayTimeInterval(l, r) + return arrayEqualDayTimeInterval(l, r, opt) case *MonthDayNanoInterval: r := right.(*MonthDayNanoInterval) - return arrayEqualMonthDayNanoInterval(l, r) + return arrayEqualMonthDayNanoInterval(l, r, opt) case *Duration: r := right.(*Duration) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Map: r := right.(*Map) - return arrayEqualMap(l, r) + return arrayEqualMap(l, r, opt) case ExtensionArray: r := right.(ExtensionArray) - return arrayEqualExtension(l, r) + return arrayEqualExtension(l, r, opt) case *Dictionary: r := right.(*Dictionary) - return arrayEqualDict(l, r) + return arrayEqualDict(l, r, opt) case *SparseUnion: r := right.(*SparseUnion) return arraySparseUnionEqual(l, r) @@ -345,26 +395,29 @@ func Equal(left, right arrow.Array) bool { return arrayDenseUnionEqual(l, r) case *RunEndEncoded: r := right.(*RunEndEncoded) - return arrayRunEndEncodedEqual(l, r) + return arrayRunEndEncodedEqual(l, r, opt) default: panic(fmt.Errorf("arrow/array: unknown array type %T", l)) } } // SliceEqual reports whether slices left[lbeg:lend] and right[rbeg:rend] are equal. -func SliceEqual(left arrow.Array, lbeg, lend int64, right arrow.Array, rbeg, rend int64) bool { +func SliceEqual(left arrow.Array, lbeg, lend int64, right arrow.Array, rbeg, rend int64, opts ...EqualOption) bool { + return sliceEqual(left, lbeg, lend, right, rbeg, rend, newEqualOption(opts...)) +} + +func sliceEqual(left arrow.Array, lbeg, lend int64, right arrow.Array, rbeg, rend int64, opt equalOption) bool { l := NewSlice(left, lbeg, lend) defer l.Release() r := NewSlice(right, rbeg, rend) defer r.Release() - return Equal(l, r) + return equal(l, r, opt) } // SliceApproxEqual reports whether slices left[lbeg:lend] and right[rbeg:rend] are approximately equal. func SliceApproxEqual(left arrow.Array, lbeg, lend int64, right arrow.Array, rbeg, rend int64, opts ...EqualOption) bool { - opt := newEqualOption(opts...) - return sliceApproxEqual(left, lbeg, lend, right, rbeg, rend, opt) + return sliceApproxEqual(left, lbeg, lend, right, rbeg, rend, newEqualOption(opts...)) } func sliceApproxEqual(left arrow.Array, lbeg, lend int64, right arrow.Array, rbeg, rend int64, opt equalOption) bool { @@ -382,6 +435,7 @@ type equalOption struct { atol float64 // absolute tolerance nansEq bool // whether NaNs are considered equal. unorderedMapKeys bool // whether maps are allowed to have different entries order + nullable bool // whether the fields being compared are considered nullable } func (eq equalOption) f16(f1, f2 float16.Num) bool { @@ -417,8 +471,9 @@ func (eq equalOption) f64(v1, v2 float64) bool { func newEqualOption(opts ...EqualOption) equalOption { eq := equalOption{ - atol: defaultAbsoluteTolerance, - nansEq: false, + atol: defaultAbsoluteTolerance, + nansEq: false, + nullable: true, } for _, opt := range opts { opt(&eq) @@ -452,20 +507,27 @@ func WithUnorderedMapKeys(v bool) EqualOption { } } +// WithNullable sets whether the comparison function will consider both fields as nullable. If they're non-nullable, their +// valids buffer will be ignored for comparison and the underlying values will be used instead +func WithNullable(v bool) EqualOption { + return func(o *equalOption) { + o.nullable = v + } +} + // ApproxEqual reports whether the two provided arrays are approximately equal. // For non-floating point arrays, it is equivalent to Equal. func ApproxEqual(left, right arrow.Array, opts ...EqualOption) bool { - opt := newEqualOption(opts...) - return arrayApproxEqual(left, right, opt) + return arrayApproxEqual(left, right, newEqualOption(opts...)) } func arrayApproxEqual(left, right arrow.Array, opt equalOption) bool { switch { - case !baseArrayEqual(left, right): + case !baseArrayEqual(left, right, opt): return false case left.Len() == 0: return true - case left.NullN() == left.Len(): + case opt.nullable && left.NullN() == left.Len(): return true } @@ -478,52 +540,52 @@ func arrayApproxEqual(left, right arrow.Array, opt equalOption) bool { return true case *Boolean: r := right.(*Boolean) - return arrayEqualBoolean(l, r) + return arrayEqualBoolean(l, r, opt) case *FixedSizeBinary: r := right.(*FixedSizeBinary) - return arrayEqualFixedSizeBinary(l, r) + return arrayEqualFixedSizeBinary(l, r, opt) case *Binary: r := right.(*Binary) - return arrayEqualBinary(l, r) + return arrayEqualBinary(l, r, opt) case *String: r := right.(*String) - return arrayApproxEqualString(l, r) + return arrayApproxEqualString(l, r, opt) case *LargeBinary: r := right.(*LargeBinary) - return arrayEqualLargeBinary(l, r) + return arrayEqualLargeBinary(l, r, opt) case *LargeString: r := right.(*LargeString) - return arrayApproxEqualLargeString(l, r) + return arrayApproxEqualLargeString(l, r, opt) case *BinaryView: r := right.(*BinaryView) - return arrayEqualBinaryView(l, r) + return arrayEqualBinaryView(l, r, opt) case *StringView: r := right.(*StringView) - return arrayApproxEqualStringView(l, r) + return arrayApproxEqualStringView(l, r, opt) case *Int8: r := right.(*Int8) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Int16: r := right.(*Int16) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Int32: r := right.(*Int32) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Int64: r := right.(*Int64) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Uint8: r := right.(*Uint8) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Uint16: r := right.(*Uint16) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Uint32: r := right.(*Uint32) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Uint64: r := right.(*Uint64) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Float16: r := right.(*Float16) return arrayApproxEqualFloat16(l, r, opt) @@ -535,31 +597,31 @@ func arrayApproxEqual(left, right arrow.Array, opt equalOption) bool { return arrayApproxEqualFloat64(l, r, opt) case *Decimal32: r := right.(*Decimal32) - return arrayEqualDecimal(l, r) + return arrayEqualDecimal(l, r, opt) case *Decimal64: r := right.(*Decimal64) - return arrayEqualDecimal(l, r) + return arrayEqualDecimal(l, r, opt) case *Decimal128: r := right.(*Decimal128) - return arrayEqualDecimal(l, r) + return arrayEqualDecimal(l, r, opt) case *Decimal256: r := right.(*Decimal256) - return arrayEqualDecimal(l, r) + return arrayEqualDecimal(l, r, opt) case *Date32: r := right.(*Date32) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Date64: r := right.(*Date64) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Time32: r := right.(*Time32) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Time64: r := right.(*Time64) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Timestamp: r := right.(*Timestamp) - return arrayEqualTimestamp(l, r) + return arrayEqualTimestamp(l, r, opt) case *List: r := right.(*List) return arrayApproxEqualList(l, r, opt) @@ -580,16 +642,16 @@ func arrayApproxEqual(left, right arrow.Array, opt equalOption) bool { return arrayApproxEqualStruct(l, r, opt) case *MonthInterval: r := right.(*MonthInterval) - return arrayEqualMonthInterval(l, r) + return arrayEqualMonthInterval(l, r, opt) case *DayTimeInterval: r := right.(*DayTimeInterval) - return arrayEqualDayTimeInterval(l, r) + return arrayEqualDayTimeInterval(l, r, opt) case *MonthDayNanoInterval: r := right.(*MonthDayNanoInterval) - return arrayEqualMonthDayNanoInterval(l, r) + return arrayEqualMonthDayNanoInterval(l, r, opt) case *Duration: r := right.(*Duration) - return arrayEqualFixedWidth(l, r) + return arrayEqualFixedWidth(l, r, opt) case *Map: r := right.(*Map) if opt.unorderedMapKeys { @@ -616,15 +678,15 @@ func arrayApproxEqual(left, right arrow.Array, opt equalOption) bool { } } -func baseArrayEqual(left, right arrow.Array) bool { +func baseArrayEqual(left, right arrow.Array, opt equalOption) bool { switch { case left.Len() != right.Len(): return false - case left.NullN() != right.NullN(): + case opt.nullable && left.NullN() != right.NullN(): return false case !arrow.TypeEqual(left.DataType(), right.DataType()): // We do not check for metadata as in the C++ implementation. return false - case !validityBitmapEqual(left, right): + case opt.nullable && !validityBitmapEqual(left, right): return false } return true @@ -644,9 +706,9 @@ func validityBitmapEqual(left, right arrow.Array) bool { return true } -func arrayApproxEqualString(left, right *String) bool { +func arrayApproxEqualString(left, right *String, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if stripNulls(left.Value(i)) != stripNulls(right.Value(i)) { @@ -656,9 +718,9 @@ func arrayApproxEqualString(left, right *String) bool { return true } -func arrayApproxEqualLargeString(left, right *LargeString) bool { +func arrayApproxEqualLargeString(left, right *LargeString, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if stripNulls(left.Value(i)) != stripNulls(right.Value(i)) { @@ -668,9 +730,9 @@ func arrayApproxEqualLargeString(left, right *LargeString) bool { return true } -func arrayApproxEqualStringView(left, right *StringView) bool { +func arrayApproxEqualStringView(left, right *StringView, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if stripNulls(left.Value(i)) != stripNulls(right.Value(i)) { @@ -682,7 +744,7 @@ func arrayApproxEqualStringView(left, right *StringView) bool { func arrayApproxEqualFloat16(left, right *Float16, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if !opt.f16(left.Value(i), right.Value(i)) { @@ -694,7 +756,7 @@ func arrayApproxEqualFloat16(left, right *Float16, opt equalOption) bool { func arrayApproxEqualFloat32(left, right *Float32, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if !opt.f32(left.Value(i), right.Value(i)) { @@ -706,7 +768,7 @@ func arrayApproxEqualFloat32(left, right *Float32, opt equalOption) bool { func arrayApproxEqualFloat64(left, right *Float64, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if !opt.f64(left.Value(i), right.Value(i)) { @@ -718,7 +780,7 @@ func arrayApproxEqualFloat64(left, right *Float64, opt equalOption) bool { func arrayApproxEqualList(left, right *List, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } o := func() bool { @@ -737,7 +799,7 @@ func arrayApproxEqualList(left, right *List, opt equalOption) bool { func arrayApproxEqualLargeList(left, right *LargeList, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } o := func() bool { @@ -756,7 +818,7 @@ func arrayApproxEqualLargeList(left, right *LargeList, opt equalOption) bool { func arrayApproxEqualListView(left, right *ListView, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } o := func() bool { @@ -775,7 +837,7 @@ func arrayApproxEqualListView(left, right *ListView, opt equalOption) bool { func arrayApproxEqualLargeListView(left, right *LargeListView, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } o := func() bool { @@ -794,7 +856,7 @@ func arrayApproxEqualLargeListView(left, right *LargeListView, opt equalOption) func arrayApproxEqualFixedSizeList(left, right *FixedSizeList, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } o := func() bool { @@ -812,11 +874,15 @@ func arrayApproxEqualFixedSizeList(left, right *FixedSizeList, opt equalOption) } func arrayApproxEqualStruct(left, right *Struct, opt equalOption) bool { - return bitutils.VisitSetBitRuns( - left.NullBitmapBytes(), - int64(left.Offset()), int64(left.Len()), - approxEqualStructRun(left, right, opt), - ) == nil + visitFn := approxEqualStructRun(left, right, opt) + if opt.nullable { + return bitutils.VisitSetBitRuns( + left.NullBitmapBytes(), + int64(left.Offset()), int64(left.Len()), + visitFn, + ) == nil + } + return visitFn(0, int64(left.Len())) == nil } func approxEqualStructRun(left, right *Struct, opt equalOption) bitutils.VisitFn { @@ -833,7 +899,7 @@ func approxEqualStructRun(left, right *Struct, opt equalOption) bitutils.VisitFn // arrayApproxEqualMap doesn't care about the order of keys (in Go map traversal order is undefined) func arrayApproxEqualMap(left, right *Map, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if !arrayApproxEqualSingleMapEntry(left.newListValue(i).(*Struct), right.newListValue(i).(*Struct), opt) { @@ -854,7 +920,7 @@ func arrayApproxEqualSingleMapEntry(left, right *Struct, opt equalOption) bool { switch { case left.Len() != right.Len(): return false - case left.NullN() != right.NullN(): + case opt.nullable && left.NullN() != right.NullN(): return false case !arrow.TypeEqual(left.DataType(), right.DataType()): // We do not check for metadata as in the C++ implementation. return false @@ -864,7 +930,7 @@ func arrayApproxEqualSingleMapEntry(left, right *Struct, opt equalOption) bool { used := make(map[int]bool, right.Len()) for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } @@ -874,7 +940,7 @@ func arrayApproxEqualSingleMapEntry(left, right *Struct, opt equalOption) bool { if used[j] { continue } - if right.IsNull(j) { + if opt.nullable && right.IsNull(j) { used[j] = true continue } diff --git a/arrow/array/compare_test.go b/arrow/array/compare_test.go index 5c569f25d..ea1610221 100644 --- a/arrow/array/compare_test.go +++ b/arrow/array/compare_test.go @@ -391,6 +391,49 @@ func TestArrayApproxEqualFloats(t *testing.T) { } } +func TestArrayEqualNonNullable(t *testing.T) { + for name, recs := range arrdata.Records { + t.Run(name, func(t *testing.T) { + rec := recs[0] + + // Clone the schema and make everything non-nullable + fields := rec.Schema().Fields() + meta := rec.Schema().Metadata() + for i := range fields { + fields[i].Nullable = false + } + schema := arrow.NewSchema(fields, &meta) + + for i, rawCol := range rec.Columns() { + // make a clone of the column with NullN=0 + col := array.MakeFromData(array.NewData( + rawCol.DataType(), + rawCol.Len(), + rawCol.Data().Buffers(), + rawCol.Data().Children(), + 0, + 0, + )) + t.Run(schema.Field(i).Name, func(t *testing.T) { + arr := col + if !array.Equal(arr, arr, array.WithNullable(false)) { + t.Fatalf("identical arrays should compare equal:\narray=%v", arr) + } + sub1 := array.NewSlice(arr, 1, int64(arr.Len())) + defer sub1.Release() + + sub2 := array.NewSlice(arr, 0, int64(arr.Len()-1)) + defer sub2.Release() + + if array.Equal(sub1, sub2) && name != "nulls" { + t.Fatalf("non-identical arrays should not compare equal:\nsub1=%v\nsub2=%v\narrf=%v\n", sub1, sub2, arr) + } + }) + } + }) + } +} + func testStringMap(mem memory.Allocator, m map[string]string, keys []string) *array.Map { dt := arrow.MapOf(arrow.BinaryTypes.String, arrow.BinaryTypes.String) builder := array.NewMapBuilderWithType(mem, dt) diff --git a/arrow/array/decimal.go b/arrow/array/decimal.go index 993242b91..9cc1a9964 100644 --- a/arrow/array/decimal.go +++ b/arrow/array/decimal.go @@ -55,7 +55,7 @@ func (a *baseDecimal[T]) ValueStr(i int) string { if a.IsNull(i) { return NullValueStr } - return a.GetOneForMarshal(i).(string) + return a.GetOneForMarshal(i, true).(string) } func (a *baseDecimal[T]) Values() []T { return a.values } @@ -89,8 +89,8 @@ func (a *baseDecimal[T]) setData(data *Data) { } } -func (a *baseDecimal[T]) GetOneForMarshal(i int) any { - if a.IsNull(i) { +func (a *baseDecimal[T]) GetOneForMarshal(i int, nullable bool) any { + if nullable && a.IsNull(i) { return nil } @@ -102,7 +102,7 @@ func (a *baseDecimal[T]) GetOneForMarshal(i int) any { func (a *baseDecimal[T]) MarshalJSON() ([]byte, error) { vals := make([]any, a.Len()) for i := 0; i < a.Len(); i++ { - vals[i] = a.GetOneForMarshal(i) + vals[i] = a.GetOneForMarshal(i, true) } return json.Marshal(vals) } @@ -110,9 +110,9 @@ func (a *baseDecimal[T]) MarshalJSON() ([]byte, error) { func arrayEqualDecimal[T interface { decimal.DecimalTypes decimal.Num[T] -}](left, right *baseDecimal[T]) bool { +}](left, right *baseDecimal[T], opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } diff --git a/arrow/array/decimal128_test.go b/arrow/array/decimal128_test.go index 4d48a97a7..6f51e794f 100644 --- a/arrow/array/decimal128_test.go +++ b/arrow/array/decimal128_test.go @@ -278,6 +278,6 @@ func TestDecimal128GetOneForMarshal(t *testing.T) { } for i := range cases { - assert.Equalf(t, cases[i].want, arr.GetOneForMarshal(i), "unexpected value at index %d", i) + assert.Equalf(t, cases[i].want, arr.GetOneForMarshal(i, true), "unexpected value at index %d", i) } } diff --git a/arrow/array/decimal256_test.go b/arrow/array/decimal256_test.go index 025f3bd3a..0ba7f4f7d 100644 --- a/arrow/array/decimal256_test.go +++ b/arrow/array/decimal256_test.go @@ -288,6 +288,6 @@ func TestDecimal256GetOneForMarshal(t *testing.T) { } for i := range cases { - assert.Equalf(t, cases[i].want, arr.GetOneForMarshal(i), "unexpected value at index %d", i) + assert.Equalf(t, cases[i].want, arr.GetOneForMarshal(i, true), "unexpected value at index %d", i) } } diff --git a/arrow/array/dictionary.go b/arrow/array/dictionary.go index 8f9fe2451..4a796996f 100644 --- a/arrow/array/dictionary.go +++ b/arrow/array/dictionary.go @@ -287,24 +287,24 @@ func (d *Dictionary) GetValueIndex(i int) int { return -1 } -func (d *Dictionary) GetOneForMarshal(i int) interface{} { - if d.IsNull(i) { +func (d *Dictionary) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && d.IsNull(i) { return nil } vidx := d.GetValueIndex(i) - return d.Dictionary().GetOneForMarshal(vidx) + return d.Dictionary().GetOneForMarshal(vidx, nullable) } func (d *Dictionary) MarshalJSON() ([]byte, error) { vals := make([]any, d.Len()) for i := range d.Len() { - vals[i] = d.GetOneForMarshal(i) + vals[i] = d.GetOneForMarshal(i, true) } return json.Marshal(vals) } -func arrayEqualDict(l, r *Dictionary) bool { - return Equal(l.Dictionary(), r.Dictionary()) && Equal(l.indices, r.indices) +func arrayEqualDict(l, r *Dictionary, opt equalOption) bool { + return equal(l.Dictionary(), r.Dictionary(), opt) && equal(l.indices, r.indices, opt) } func arrayApproxEqualDict(l, r *Dictionary, opt equalOption) bool { diff --git a/arrow/array/encoded.go b/arrow/array/encoded.go index 85432a13c..ff2dab92f 100644 --- a/arrow/array/encoded.go +++ b/arrow/array/encoded.go @@ -219,13 +219,13 @@ func (r *RunEndEncoded) String() string { buf.WriteByte(',') } - value := r.values.GetOneForMarshal(i) + value := r.values.GetOneForMarshal(i, true) if byts, ok := value.(json.RawMessage); ok { value = string(byts) } var runEnd int - switch e := r.ends.GetOneForMarshal(i).(type) { + switch e := r.ends.GetOneForMarshal(i, true).(type) { case int16: runEnd = int(e) - r.data.offset case int32: @@ -240,8 +240,8 @@ func (r *RunEndEncoded) String() string { return buf.String() } -func (r *RunEndEncoded) GetOneForMarshal(i int) interface{} { - return r.values.GetOneForMarshal(r.GetPhysicalIndex(i)) +func (r *RunEndEncoded) GetOneForMarshal(i int, nullable bool) interface{} { + return r.values.GetOneForMarshal(r.GetPhysicalIndex(i), nullable) } func (r *RunEndEncoded) MarshalJSON() ([]byte, error) { @@ -252,7 +252,7 @@ func (r *RunEndEncoded) MarshalJSON() ([]byte, error) { if i != 0 { buf.WriteByte(',') } - if err := enc.Encode(r.GetOneForMarshal(i)); err != nil { + if err := enc.Encode(r.GetOneForMarshal(i, true)); err != nil { return nil, err } } @@ -260,14 +260,14 @@ func (r *RunEndEncoded) MarshalJSON() ([]byte, error) { return buf.Bytes(), nil } -func arrayRunEndEncodedEqual(l, r *RunEndEncoded) bool { +func arrayRunEndEncodedEqual(l, r *RunEndEncoded, opt equalOption) bool { // types were already checked before getting here, so we know // the encoded types are equal mr := encoded.NewMergedRuns([2]arrow.Array{l, r}) for mr.Next() { lIndex := mr.IndexIntoArray(0) rIndex := mr.IndexIntoArray(1) - if !SliceEqual(l.values, lIndex, lIndex+1, r.values, rIndex, rIndex+1) { + if !sliceEqual(l.values, lIndex, lIndex+1, r.values, rIndex, rIndex+1, opt) { return false } } diff --git a/arrow/array/extension.go b/arrow/array/extension.go index e509b5e0f..48c4f03ad 100644 --- a/arrow/array/extension.go +++ b/arrow/array/extension.go @@ -46,12 +46,12 @@ type ExtensionArray interface { // two extension arrays are equal if their data types are equal and // their underlying storage arrays are equal. -func arrayEqualExtension(l, r ExtensionArray) bool { +func arrayEqualExtension(l, r ExtensionArray, opt equalOption) bool { if !arrow.TypeEqual(l.DataType(), r.DataType()) { return false } - return Equal(l.Storage(), r.Storage()) + return equal(l.Storage(), r.Storage(), opt) } // two extension arrays are approximately equal if their data types are @@ -116,8 +116,8 @@ func (e *ExtensionArrayBase) String() string { return fmt.Sprintf("(%s)%s", e.data.dtype, e.storage) } -func (e *ExtensionArrayBase) GetOneForMarshal(i int) interface{} { - return e.storage.GetOneForMarshal(i) +func (e *ExtensionArrayBase) GetOneForMarshal(i int, nullable bool) interface{} { + return e.storage.GetOneForMarshal(i, nullable) } func (e *ExtensionArrayBase) MarshalJSON() ([]byte, error) { diff --git a/arrow/array/fixed_size_list.go b/arrow/array/fixed_size_list.go index 69cb67b22..56f7c46b3 100644 --- a/arrow/array/fixed_size_list.go +++ b/arrow/array/fixed_size_list.go @@ -51,7 +51,7 @@ func (a *FixedSizeList) ValueStr(i int) string { if a.IsNull(i) { return NullValueStr } - return string(a.GetOneForMarshal(i).(json.RawMessage)) + return string(a.GetOneForMarshal(i, true).(json.RawMessage)) } func (a *FixedSizeList) String() string { @@ -84,9 +84,9 @@ func (a *FixedSizeList) setData(data *Data) { a.values = MakeFromData(data.childData[0]) } -func arrayEqualFixedSizeList(left, right *FixedSizeList) bool { +func arrayEqualFixedSizeList(left, right *FixedSizeList, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } o := func() bool { @@ -94,7 +94,7 @@ func arrayEqualFixedSizeList(left, right *FixedSizeList) bool { defer l.Release() r := right.newListValue(i) defer r.Release() - return Equal(l, r) + return equal(l, r, opt) }() if !o { return false @@ -123,8 +123,8 @@ func (a *FixedSizeList) Release() { a.values.Release() } -func (a *FixedSizeList) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { +func (a *FixedSizeList) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && a.IsNull(i) { return nil } slice := a.newListValue(i) diff --git a/arrow/array/fixedsize_binary.go b/arrow/array/fixedsize_binary.go index 31d507c5b..b9cf2bfdc 100644 --- a/arrow/array/fixedsize_binary.go +++ b/arrow/array/fixedsize_binary.go @@ -86,8 +86,8 @@ func (a *FixedSizeBinary) setData(data *Data) { } } -func (a *FixedSizeBinary) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { +func (a *FixedSizeBinary) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && a.IsNull(i) { return nil } @@ -106,9 +106,9 @@ func (a *FixedSizeBinary) MarshalJSON() ([]byte, error) { return json.Marshal(vals) } -func arrayEqualFixedSizeBinary(left, right *FixedSizeBinary) bool { +func arrayEqualFixedSizeBinary(left, right *FixedSizeBinary, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if !bytes.Equal(left.Value(i), right.Value(i)) { diff --git a/arrow/array/float16.go b/arrow/array/float16.go index 41276803b..8536df402 100644 --- a/arrow/array/float16.go +++ b/arrow/array/float16.go @@ -77,8 +77,8 @@ func (a *Float16) setData(data *Data) { } } -func (a *Float16) GetOneForMarshal(i int) interface{} { - if a.IsValid(i) { +func (a *Float16) GetOneForMarshal(i int, nullable bool) interface{} { + if !nullable || a.IsValid(i) { return a.values[i].Float32() } return nil diff --git a/arrow/array/interval.go b/arrow/array/interval.go index 2900a592b..8d9ad5681 100644 --- a/arrow/array/interval.go +++ b/arrow/array/interval.go @@ -94,8 +94,8 @@ func (a *MonthInterval) setData(data *Data) { } } -func (a *MonthInterval) GetOneForMarshal(i int) interface{} { - if a.IsValid(i) { +func (a *MonthInterval) GetOneForMarshal(i int, nullable bool) interface{} { + if !nullable || a.IsValid(i) { return a.values[i] } return nil @@ -120,9 +120,9 @@ func (a *MonthInterval) MarshalJSON() ([]byte, error) { return json.Marshal(vals) } -func arrayEqualMonthInterval(left, right *MonthInterval) bool { +func arrayEqualMonthInterval(left, right *MonthInterval, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if left.Value(i) != right.Value(i) { @@ -360,7 +360,7 @@ func (a *DayTimeInterval) ValueStr(i int) string { if a.IsNull(i) { return NullValueStr } - data, err := json.Marshal(a.GetOneForMarshal(i)) + data, err := json.Marshal(a.GetOneForMarshal(i, true)) if err != nil { panic(err) } @@ -398,8 +398,8 @@ func (a *DayTimeInterval) setData(data *Data) { } } -func (a *DayTimeInterval) GetOneForMarshal(i int) interface{} { - if a.IsValid(i) { +func (a *DayTimeInterval) GetOneForMarshal(i int, nullable bool) interface{} { + if !nullable || a.IsValid(i) { return a.values[i] } return nil @@ -422,9 +422,9 @@ func (a *DayTimeInterval) MarshalJSON() ([]byte, error) { return json.Marshal(vals) } -func arrayEqualDayTimeInterval(left, right *DayTimeInterval) bool { +func arrayEqualDayTimeInterval(left, right *DayTimeInterval, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if left.Value(i) != right.Value(i) { @@ -661,7 +661,7 @@ func (a *MonthDayNanoInterval) ValueStr(i int) string { if a.IsNull(i) { return NullValueStr } - data, err := json.Marshal(a.GetOneForMarshal(i)) + data, err := json.Marshal(a.GetOneForMarshal(i, true)) if err != nil { panic(err) } @@ -701,8 +701,8 @@ func (a *MonthDayNanoInterval) setData(data *Data) { } } -func (a *MonthDayNanoInterval) GetOneForMarshal(i int) interface{} { - if a.IsValid(i) { +func (a *MonthDayNanoInterval) GetOneForMarshal(i int, nullable bool) interface{} { + if !nullable || a.IsValid(i) { return a.values[i] } return nil @@ -725,9 +725,9 @@ func (a *MonthDayNanoInterval) MarshalJSON() ([]byte, error) { return json.Marshal(vals) } -func arrayEqualMonthDayNanoInterval(left, right *MonthDayNanoInterval) bool { +func arrayEqualMonthDayNanoInterval(left, right *MonthDayNanoInterval, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if left.Value(i) != right.Value(i) { diff --git a/arrow/array/list.go b/arrow/array/list.go index d87cc2662..7ad3c1b21 100644 --- a/arrow/array/list.go +++ b/arrow/array/list.go @@ -61,7 +61,7 @@ func (a *List) ValueStr(i int) string { if !a.IsValid(i) { return NullValueStr } - return string(a.GetOneForMarshal(i).(json.RawMessage)) + return string(a.GetOneForMarshal(i, true).(json.RawMessage)) } func (a *List) String() string { @@ -98,8 +98,8 @@ func (a *List) setData(data *Data) { a.values = MakeFromData(data.childData[0]) } -func (a *List) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { +func (a *List) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && a.IsNull(i) { return nil } @@ -121,7 +121,7 @@ func (a *List) MarshalJSON() ([]byte, error) { if i != 0 { buf.WriteByte(',') } - if err := enc.Encode(a.GetOneForMarshal(i)); err != nil { + if err := enc.Encode(a.GetOneForMarshal(i, true)); err != nil { return nil, err } } @@ -129,9 +129,9 @@ func (a *List) MarshalJSON() ([]byte, error) { return buf.Bytes(), nil } -func arrayEqualList(left, right *List) bool { +func arrayEqualList(left, right *List, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } o := func() bool { @@ -139,7 +139,7 @@ func arrayEqualList(left, right *List) bool { defer l.Release() r := right.newListValue(i) defer r.Release() - return Equal(l, r) + return equal(l, r, opt) }() if !o { return false @@ -193,7 +193,7 @@ func (a *LargeList) ValueStr(i int) string { if !a.IsValid(i) { return NullValueStr } - return string(a.GetOneForMarshal(i).(json.RawMessage)) + return string(a.GetOneForMarshal(i, true).(json.RawMessage)) } func (a *LargeList) String() string { @@ -230,8 +230,8 @@ func (a *LargeList) setData(data *Data) { a.values = MakeFromData(data.childData[0]) } -func (a *LargeList) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { +func (a *LargeList) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && a.IsNull(i) { return nil } @@ -253,7 +253,7 @@ func (a *LargeList) MarshalJSON() ([]byte, error) { if i != 0 { buf.WriteByte(',') } - if err := enc.Encode(a.GetOneForMarshal(i)); err != nil { + if err := enc.Encode(a.GetOneForMarshal(i, true)); err != nil { return nil, err } } @@ -261,9 +261,9 @@ func (a *LargeList) MarshalJSON() ([]byte, error) { return buf.Bytes(), nil } -func arrayEqualLargeList(left, right *LargeList) bool { +func arrayEqualLargeList(left, right *LargeList, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } o := func() bool { @@ -271,7 +271,7 @@ func arrayEqualLargeList(left, right *LargeList) bool { defer l.Release() r := right.newListValue(i) defer r.Release() - return Equal(l, r) + return equal(l, r, opt) }() if !o { return false @@ -663,7 +663,7 @@ func (a *ListView) ValueStr(i int) string { if !a.IsValid(i) { return NullValueStr } - return string(a.GetOneForMarshal(i).(json.RawMessage)) + return string(a.GetOneForMarshal(i, true).(json.RawMessage)) } func (a *ListView) String() string { @@ -704,8 +704,8 @@ func (a *ListView) setData(data *Data) { a.values = MakeFromData(data.childData[0]) } -func (a *ListView) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { +func (a *ListView) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && a.IsNull(i) { return nil } @@ -727,7 +727,7 @@ func (a *ListView) MarshalJSON() ([]byte, error) { if i != 0 { buf.WriteByte(',') } - if err := enc.Encode(a.GetOneForMarshal(i)); err != nil { + if err := enc.Encode(a.GetOneForMarshal(i, true)); err != nil { return nil, err } } @@ -735,9 +735,9 @@ func (a *ListView) MarshalJSON() ([]byte, error) { return buf.Bytes(), nil } -func arrayEqualListView(left, right *ListView) bool { +func arrayEqualListView(left, right *ListView, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } o := func() bool { @@ -745,7 +745,7 @@ func arrayEqualListView(left, right *ListView) bool { defer l.Release() r := right.newListValue(i) defer r.Release() - return Equal(l, r) + return equal(l, r, opt) }() if !o { return false @@ -810,7 +810,7 @@ func (a *LargeListView) ValueStr(i int) string { if !a.IsValid(i) { return NullValueStr } - return string(a.GetOneForMarshal(i).(json.RawMessage)) + return string(a.GetOneForMarshal(i, true).(json.RawMessage)) } func (a *LargeListView) String() string { @@ -851,8 +851,8 @@ func (a *LargeListView) setData(data *Data) { a.values = MakeFromData(data.childData[0]) } -func (a *LargeListView) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { +func (a *LargeListView) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && a.IsNull(i) { return nil } @@ -874,7 +874,7 @@ func (a *LargeListView) MarshalJSON() ([]byte, error) { if i != 0 { buf.WriteByte(',') } - if err := enc.Encode(a.GetOneForMarshal(i)); err != nil { + if err := enc.Encode(a.GetOneForMarshal(i, true)); err != nil { return nil, err } } @@ -882,9 +882,9 @@ func (a *LargeListView) MarshalJSON() ([]byte, error) { return buf.Bytes(), nil } -func arrayEqualLargeListView(left, right *LargeListView) bool { +func arrayEqualLargeListView(left, right *LargeListView, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } o := func() bool { @@ -892,7 +892,7 @@ func arrayEqualLargeListView(left, right *LargeListView) bool { defer l.Release() r := right.newListValue(i) defer r.Release() - return Equal(l, r) + return equal(l, r, opt) }() if !o { return false diff --git a/arrow/array/map.go b/arrow/array/map.go index da9a150ba..35902bd5c 100644 --- a/arrow/array/map.go +++ b/arrow/array/map.go @@ -106,9 +106,9 @@ func (a *Map) Release() { a.items.Release() } -func arrayEqualMap(left, right *Map) bool { +func arrayEqualMap(left, right *Map, opt equalOption) bool { // since Map is implemented using a list, we can just use arrayEqualList - return arrayEqualList(left.List, right.List) + return arrayEqualList(left.List, right.List, opt) } type MapBuilder struct { diff --git a/arrow/array/null.go b/arrow/array/null.go index 02ea12eb7..7d036f3e9 100644 --- a/arrow/array/null.go +++ b/arrow/array/null.go @@ -80,7 +80,7 @@ func (a *Null) setData(data *Data) { a.data.nulls = a.data.length } -func (a *Null) GetOneForMarshal(i int) interface{} { +func (a *Null) GetOneForMarshal(i int, nullable bool) interface{} { return nil } diff --git a/arrow/array/numeric_generic.go b/arrow/array/numeric_generic.go index 874e86d62..e30d3a92d 100644 --- a/arrow/array/numeric_generic.go +++ b/arrow/array/numeric_generic.go @@ -82,8 +82,8 @@ func (a *numericArray[T]) ValueStr(i int) string { return fmt.Sprintf("%v", a.values[i]) } -func (a *numericArray[T]) GetOneForMarshal(i int) any { - if a.IsNull(i) { +func (a *numericArray[T]) GetOneForMarshal(i int, nullable bool) any { + if nullable && a.IsNull(i) { return nil } @@ -106,8 +106,8 @@ type oneByteArrs[T int8 | uint8] struct { numericArray[T] } -func (a *oneByteArrs[T]) GetOneForMarshal(i int) any { - if a.IsNull(i) { +func (a *oneByteArrs[T]) GetOneForMarshal(i int, nullable bool) any { + if nullable && a.IsNull(i) { return nil } @@ -140,8 +140,8 @@ func (a *floatArray[T]) ValueStr(i int) string { return strconv.FormatFloat(float64(a.Value(i)), 'g', -1, bitWidth) } -func (a *floatArray[T]) GetOneForMarshal(i int) any { - if a.IsNull(i) { +func (a *floatArray[T]) GetOneForMarshal(i int, nullable bool) any { + if nullable && a.IsNull(i) { return nil } @@ -159,7 +159,7 @@ func (a *floatArray[T]) GetOneForMarshal(i int) any { func (a *floatArray[T]) MarshalJSON() ([]byte, error) { vals := make([]any, a.Len()) for i := range a.values { - vals[i] = a.GetOneForMarshal(i) + vals[i] = a.GetOneForMarshal(i, true) } return json.Marshal(vals) } @@ -175,7 +175,7 @@ type dateArray[T interface { func (d *dateArray[T]) MarshalJSON() ([]byte, error) { vals := make([]any, d.Len()) for i := range d.values { - vals[i] = d.GetOneForMarshal(i) + vals[i] = d.GetOneForMarshal(i, true) } return json.Marshal(vals) } @@ -188,8 +188,8 @@ func (d *dateArray[T]) ValueStr(i int) string { return d.values[i].FormattedString() } -func (d *dateArray[T]) GetOneForMarshal(i int) interface{} { - if d.IsNull(i) { +func (d *dateArray[T]) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && d.IsNull(i) { return nil } @@ -211,7 +211,7 @@ type timeArray[T interface { func (a *timeArray[T]) MarshalJSON() ([]byte, error) { vals := make([]any, a.Len()) for i := range a.values { - vals[i] = a.GetOneForMarshal(i) + vals[i] = a.GetOneForMarshal(i, true) } return json.Marshal(vals) } @@ -224,8 +224,8 @@ func (a *timeArray[T]) ValueStr(i int) string { return a.values[i].FormattedString(a.DataType().(timeType).TimeUnit()) } -func (a *timeArray[T]) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { +func (a *timeArray[T]) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && a.IsNull(i) { return nil } @@ -245,7 +245,7 @@ func (a *Duration) DurationValues() []arrow.Duration { return a.Values() } func (a *Duration) MarshalJSON() ([]byte, error) { vals := make([]any, a.Len()) for i := range a.values { - vals[i] = a.GetOneForMarshal(i) + vals[i] = a.GetOneForMarshal(i, true) } return json.Marshal(vals) } @@ -258,8 +258,8 @@ func (a *Duration) ValueStr(i int) string { return fmt.Sprintf("%d%s", a.values[i], a.DataType().(timeType).TimeUnit()) } -func (a *Duration) GetOneForMarshal(i int) any { - if a.IsNull(i) { +func (a *Duration) GetOneForMarshal(i int, nullable bool) any { + if nullable && a.IsNull(i) { return nil } return fmt.Sprintf("%d%s", a.values[i], a.DataType().(timeType).TimeUnit()) @@ -405,9 +405,9 @@ func NewDate64Data(data arrow.ArrayData) *Date64 { func (a *Date64) Date64Values() []arrow.Date64 { return a.Values() } -func arrayEqualFixedWidth[T arrow.FixedWidthType](left, right arrow.TypedArray[T]) bool { +func arrayEqualFixedWidth[T arrow.FixedWidthType](left, right arrow.TypedArray[T], opt equalOption) bool { for i := range left.Len() { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if left.Value(i) != right.Value(i) { diff --git a/arrow/array/record.go b/arrow/array/record.go index 0aaf771b5..b71d15af3 100644 --- a/arrow/array/record.go +++ b/arrow/array/record.go @@ -421,10 +421,21 @@ func (b *RecordBuilder) UnmarshalJSON(data []byte) error { } continue } + idx := indices[0] - if err := b.fields[indices[0]].UnmarshalOne(dec); err != nil { + var next json.RawMessage + if err := dec.Decode(&next); err != nil { return err } + + if json.IsNullMessage(next) && !b.schema.Field(idx).Nullable { + b.fields[idx].AppendEmptyValue() + } else { + sub := json.NewDecoder(bytes.NewReader(next)) + if err := b.fields[idx].UnmarshalOne(sub); err != nil { + return err + } + } } for i := 0; i < b.schema.NumFields(); i++ { diff --git a/arrow/array/record_test.go b/arrow/array/record_test.go index 5900efe7f..003d01f53 100644 --- a/arrow/array/record_test.go +++ b/arrow/array/record_test.go @@ -17,6 +17,7 @@ package array_test import ( + "bytes" "fmt" "reflect" "testing" @@ -24,6 +25,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/apache/arrow-go/v18/internal/json" "github.com/stretchr/testify/assert" ) @@ -484,9 +486,9 @@ func TestRecordBuilder(t *testing.T) { mapDt.SetItemNullable(false) schema := arrow.NewSchema( []arrow.Field{ - {Name: "f1-i32", Type: arrow.PrimitiveTypes.Int32}, - {Name: "f2-f64", Type: arrow.PrimitiveTypes.Float64}, - {Name: "map", Type: mapDt}, + {Name: "f1-i32", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "f2-f64-notnull", Type: arrow.PrimitiveTypes.Float64, Nullable: false}, + {Name: "map", Type: mapDt, Nullable: true}, }, nil, ) @@ -497,11 +499,14 @@ func TestRecordBuilder(t *testing.T) { b.Retain() b.Release() - b.Field(0).(*array.Int32Builder).AppendValues([]int32{1, 2, 3}, nil) + b.Field(0).(*array.Int32Builder).AppendNull() + b.Field(0).(*array.Int32Builder).AppendValues([]int32{2, 3}, nil) b.Field(0).(*array.Int32Builder).AppendValues([]int32{4, 5}, nil) - b.Field(1).(*array.Float64Builder).AppendValues([]float64{1, 2, 3, 4, 5}, nil) + + b.Field(1).(*array.Float64Builder).AppendValues([]float64{1.1, 2.2, 3.3, 4.4, 5.5}, nil) + mb := b.Field(2).(*array.MapBuilder) - for i := 0; i < 5; i++ { + for i := range 5 { mb.Append(true) if i%3 == 0 { @@ -510,6 +515,12 @@ func TestRecordBuilder(t *testing.T) { } } + err := b.UnmarshalJSON([]byte(`{"f1-i32": 6, "f2-f64-notnull": 6.6, "map": [{"key": "4": "value": "d"}]}`)) + assert.NoError(t, err) + + err = b.UnmarshalJSON([]byte(`{"f1-i32": null, "f2-f64-notnull": null, "map": null}`)) + assert.NoError(t, err) + rec := b.NewRecordBatch() defer rec.Release() @@ -517,7 +528,7 @@ func TestRecordBuilder(t *testing.T) { t.Fatalf("invalid schema: got=%#v, want=%#v", got, want) } - if got, want := rec.NumRows(), int64(5); got != want { + if got, want := rec.NumRows(), int64(7); got != want { t.Fatalf("invalid number of rows: got=%d, want=%d", got, want) } if got, want := rec.NumCols(), int64(3); got != want { @@ -526,9 +537,27 @@ func TestRecordBuilder(t *testing.T) { if got, want := rec.ColumnName(0), schema.Field(0).Name; got != want { t.Fatalf("invalid column name: got=%q, want=%q", got, want) } - if got, want := rec.Column(2).String(), `[{["0" "2" "3"] ["a" "b" "c"]} {[] []} {[] []} {["3" "2" "3"] ["a" "b" "c"]} {[] []}]`; got != want { - t.Fatalf("invalid column name: got=%q, want=%q", got, want) + + if got, want := rec.Column(0).String(), `[(null) 2 3 4 5 6 (null)]`; got != want { + t.Fatalf("invalid column values: got=%q, want=%q", got, want) + } + if got, want := rec.Column(1).String(), `[1.1 2.2 3.3 4.4 5.5 6.6 0]`; got != want { + t.Fatalf("invalid column values: got=%q, want=%q", got, want) } + if got, want := rec.Column(2).String(), `[{["0" "2" "3"] ["a" "b" "c"]} {[] []} {[] []} {["3" "2" "3"] ["a" "b" "c"]} {[] []} {["4"] ["d"]} (null)]`; got != want { + t.Fatalf("invalid column values: got=%q, want=%q", got, want) + } + + // roundtripping from JSON with array.FromJSON should work + arr := array.RecordToStructArray(rec) + defer arr.Release() + jsonStr, err := json.Marshal(arr) + assert.NoError(t, err) + + roundtripped, _, err := array.FromJSON(mem, arr.DataType(), bytes.NewReader(jsonStr)) + defer roundtripped.Release() + assert.NoError(t, err) + assert.Truef(t, array.Equal(arr, roundtripped), "JSON round trip returns different array: got=%q, want=%d", arr, roundtripped) } type testMessage struct { diff --git a/arrow/array/string.go b/arrow/array/string.go index 60323e371..463c691d7 100644 --- a/arrow/array/string.go +++ b/arrow/array/string.go @@ -151,8 +151,8 @@ func (a *String) setData(data *Data) { } } -func (a *String) GetOneForMarshal(i int) interface{} { - if a.IsValid(i) { +func (a *String) GetOneForMarshal(i int, nullable bool) interface{} { + if !nullable || a.IsValid(i) { return a.Value(i) } return nil @@ -228,9 +228,9 @@ func (a *String) ValidateFull() error { return nil } -func arrayEqualString(left, right *String) bool { +func arrayEqualString(left, right *String, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if left.Value(i) != right.Value(i) { @@ -356,8 +356,8 @@ func (a *LargeString) setData(data *Data) { } } -func (a *LargeString) GetOneForMarshal(i int) interface{} { - if a.IsValid(i) { +func (a *LargeString) GetOneForMarshal(i int, nullable bool) interface{} { + if !nullable || a.IsValid(i) { return a.Value(i) } return nil @@ -366,7 +366,7 @@ func (a *LargeString) GetOneForMarshal(i int) interface{} { func (a *LargeString) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { - vals[i] = a.GetOneForMarshal(i) + vals[i] = a.GetOneForMarshal(i, true) } return json.Marshal(vals) } @@ -429,9 +429,9 @@ func (a *LargeString) ValidateFull() error { return nil } -func arrayEqualLargeString(left, right *LargeString) bool { +func arrayEqualLargeString(left, right *LargeString, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if left.Value(i) != right.Value(i) { @@ -520,8 +520,8 @@ func (a *StringView) ValueStr(i int) string { return a.Value(i) } -func (a *StringView) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { +func (a *StringView) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && a.IsNull(i) { return nil } return a.Value(i) @@ -530,15 +530,15 @@ func (a *StringView) GetOneForMarshal(i int) interface{} { func (a *StringView) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := 0; i < a.Len(); i++ { - vals[i] = a.GetOneForMarshal(i) + vals[i] = a.GetOneForMarshal(i, true) } return json.Marshal(vals) } -func arrayEqualStringView(left, right *StringView) bool { +func arrayEqualStringView(left, right *StringView, opt equalOption) bool { leftBufs, rightBufs := left.dataBuffers, right.dataBuffers for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if !left.ValueHeader(i).Equals(leftBufs, right.ValueHeader(i), rightBufs) { diff --git a/arrow/array/struct.go b/arrow/array/struct.go index 6883712c9..2c443b9eb 100644 --- a/arrow/array/struct.go +++ b/arrow/array/struct.go @@ -130,7 +130,7 @@ func (a *Struct) ValueStr(i int) string { return NullValueStr } - data, err := json.Marshal(a.GetOneForMarshal(i)) + data, err := json.Marshal(a.GetOneForMarshal(i, true)) if err != nil { panic(err) } @@ -201,15 +201,16 @@ func (a *Struct) setData(data *Data) { } } -func (a *Struct) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { +func (a *Struct) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && a.IsNull(i) { return nil } tmp := make(map[string]interface{}) - fieldList := a.data.dtype.(*arrow.StructType).Fields() + dtype := a.data.dtype.(*arrow.StructType) + fieldList := dtype.Fields() for j, d := range a.fields { - tmp[fieldList[j].Name] = d.GetOneForMarshal(i) + tmp[fieldList[j].Name] = d.GetOneForMarshal(i, dtype.Field(j).Nullable) } return tmp } @@ -223,7 +224,7 @@ func (a *Struct) MarshalJSON() ([]byte, error) { if i != 0 { buf.WriteByte(',') } - if err := enc.Encode(a.GetOneForMarshal(i)); err != nil { + if err := enc.Encode(a.GetOneForMarshal(i, true)); err != nil { return nil, err } } @@ -231,10 +232,11 @@ func (a *Struct) MarshalJSON() ([]byte, error) { return buf.Bytes(), nil } -func arrayEqualStruct(left, right *Struct) bool { +func arrayEqualStruct(left, right *Struct, opt equalOption) bool { for i, lf := range left.fields { rf := right.fields[i] - if !Equal(lf, rf) { + opt.nullable = left.data.dtype.(*arrow.StructType).Field(i).Nullable + if !equal(lf, rf, opt) { return false } } @@ -479,9 +481,19 @@ func (b *StructBuilder) UnmarshalOne(dec *json.Decoder) error { continue } - if err := b.fields[idx].UnmarshalOne(dec); err != nil { + var next json.RawMessage + if err := dec.Decode(&next); err != nil { return err } + + if json.IsNullMessage(next) && !b.dtype.(*arrow.StructType).Field(idx).Nullable { + b.fields[idx].AppendEmptyValue() + } else { + sub := json.NewDecoder(bytes.NewReader(next)) + if err := b.fields[idx].UnmarshalOne(sub); err != nil { + return err + } + } } // Append null values to all optional fields that were not presented in the json input diff --git a/arrow/array/struct_test.go b/arrow/array/struct_test.go index 24f522ed1..3e4a17d92 100644 --- a/arrow/array/struct_test.go +++ b/arrow/array/struct_test.go @@ -486,6 +486,12 @@ func TestStructArrayUnmarshalJSONMissingFields(t *testing.T) { panic: false, want: `{[(null)] [3] {[(null)] [(null)] ["test"]}}`, }, + { + name: "explicit null in required field", + jsonInput: `[{"f2": 3, "f3": {"f3_3": null}}]`, + panic: false, + want: `{[(null)] [3] {[(null)] [(null)] [""]}}`, + }, } for _, tc := range tests { diff --git a/arrow/array/timestamp.go b/arrow/array/timestamp.go index d0b5b0626..9e98905fd 100644 --- a/arrow/array/timestamp.go +++ b/arrow/array/timestamp.go @@ -108,8 +108,8 @@ func (a *Timestamp) ValueStr(i int) string { return toTime(a.values[i]).Format(layout) } -func (a *Timestamp) GetOneForMarshal(i int) interface{} { - if val := a.ValueStr(i); val != NullValueStr { +func (a *Timestamp) GetOneForMarshal(i int, nullable bool) interface{} { + if val := a.ValueStr(i); !nullable || val != NullValueStr { return val } return nil @@ -118,15 +118,15 @@ func (a *Timestamp) GetOneForMarshal(i int) interface{} { func (a *Timestamp) MarshalJSON() ([]byte, error) { vals := make([]interface{}, a.Len()) for i := range a.values { - vals[i] = a.GetOneForMarshal(i) + vals[i] = a.GetOneForMarshal(i, true) } return json.Marshal(vals) } -func arrayEqualTimestamp(left, right *Timestamp) bool { +func arrayEqualTimestamp(left, right *Timestamp, opt equalOption) bool { for i := 0; i < left.Len(); i++ { - if left.IsNull(i) { + if opt.nullable && left.IsNull(i) { continue } if left.Value(i) != right.Value(i) { diff --git a/arrow/array/union.go b/arrow/array/union.go index 9c13af05b..027318bfc 100644 --- a/arrow/array/union.go +++ b/arrow/array/union.go @@ -312,17 +312,17 @@ func (a *SparseUnion) setData(data *Data) { debug.Assert(a.data.buffers[0] == nil, "arrow/array: validity bitmap for sparse unions should be nil") } -func (a *SparseUnion) GetOneForMarshal(i int) interface{} { +func (a *SparseUnion) GetOneForMarshal(i int, nullable bool) interface{} { typeID := a.RawTypeCodes()[i] childID := a.ChildID(i) data := a.Field(childID) - if data.IsNull(i) { + if nullable && data.IsNull(i) { return nil } - return []interface{}{typeID, data.GetOneForMarshal(i)} + return []interface{}{typeID, data.GetOneForMarshal(i, nullable)} } func (a *SparseUnion) MarshalJSON() ([]byte, error) { @@ -334,7 +334,7 @@ func (a *SparseUnion) MarshalJSON() ([]byte, error) { if i != 0 { buf.WriteByte(',') } - if err := enc.Encode(a.GetOneForMarshal(i)); err != nil { + if err := enc.Encode(a.GetOneForMarshal(i, true)); err != nil { return nil, err } } @@ -347,7 +347,7 @@ func (a *SparseUnion) ValueStr(i int) string { return NullValueStr } - val := a.GetOneForMarshal(i) + val := a.GetOneForMarshal(i, true) if val == nil { // child is nil return NullValueStr @@ -372,7 +372,7 @@ func (a *SparseUnion) String() string { field := fieldList[a.ChildID(i)] f := a.Field(a.ChildID(i)) - fmt.Fprintf(&b, "{%s=%v}", field.Name, f.GetOneForMarshal(i)) + fmt.Fprintf(&b, "{%s=%v}", field.Name, f.GetOneForMarshal(i, true)) } b.WriteByte(']') return b.String() @@ -587,18 +587,18 @@ func (a *DenseUnion) setData(data *Data) { } } -func (a *DenseUnion) GetOneForMarshal(i int) interface{} { +func (a *DenseUnion) GetOneForMarshal(i int, nullable bool) interface{} { typeID := a.RawTypeCodes()[i] childID := a.ChildID(i) data := a.Field(childID) offset := int(a.RawValueOffsets()[i]) - if data.IsNull(offset) { + if nullable && data.IsNull(offset) { return nil } - return []interface{}{typeID, data.GetOneForMarshal(offset)} + return []interface{}{typeID, data.GetOneForMarshal(offset, nullable)} } func (a *DenseUnion) MarshalJSON() ([]byte, error) { @@ -610,7 +610,7 @@ func (a *DenseUnion) MarshalJSON() ([]byte, error) { if i != 0 { buf.WriteByte(',') } - if err := enc.Encode(a.GetOneForMarshal(i)); err != nil { + if err := enc.Encode(a.GetOneForMarshal(i, true)); err != nil { return nil, err } } @@ -623,7 +623,7 @@ func (a *DenseUnion) ValueStr(i int) string { return NullValueStr } - val := a.GetOneForMarshal(i) + val := a.GetOneForMarshal(i, true) if val == nil { // child in nil return NullValueStr @@ -650,7 +650,7 @@ func (a *DenseUnion) String() string { field := fieldList[a.ChildID(i)] f := a.Field(a.ChildID(i)) - fmt.Fprintf(&b, "{%s=%v}", field.Name, f.GetOneForMarshal(int(offsets[i]))) + fmt.Fprintf(&b, "{%s=%v}", field.Name, f.GetOneForMarshal(int(offsets[i]), true)) } b.WriteByte(']') return b.String() diff --git a/arrow/array/util.go b/arrow/array/util.go index 6a1d29cb0..2b8306082 100644 --- a/arrow/array/util.go +++ b/arrow/array/util.go @@ -287,7 +287,7 @@ func RecordToJSON(rec arrow.RecordBatch, w io.Writer) error { cols := make(map[string]interface{}) for i := 0; int64(i) < rec.NumRows(); i++ { for j, c := range rec.Columns() { - cols[fields[j].Name] = c.GetOneForMarshal(i) + cols[fields[j].Name] = c.GetOneForMarshal(i, rec.Schema().Field(j).Nullable) } if err := enc.Encode(cols); err != nil { return err diff --git a/arrow/array/util_test.go b/arrow/array/util_test.go index fb837871b..f587038fc 100644 --- a/arrow/array/util_test.go +++ b/arrow/array/util_test.go @@ -452,29 +452,50 @@ func TestArrRecordsJSONRoundTrip(t *testing.T) { continue } t.Run(k, func(t *testing.T) { - var buf bytes.Buffer - assert.NotPanics(t, func() { - enc := json.NewEncoder(&buf) - for _, r := range v { - if err := enc.Encode(r); err != nil { - panic(err) - } + for _, nullable := range []bool{true, false} { + var name string + if nullable { + name = "nullable" + } else { + name = "non-nullable" } - }) - - rdr := bytes.NewReader(buf.Bytes()) - var cur int64 - - mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) - defer mem.AssertSize(t, 0) - - for _, r := range v { - rec, off, err := array.RecordFromJSON(mem, r.Schema(), rdr, array.WithStartOffset(cur)) - assert.NoError(t, err) - defer rec.Release() - assert.Truef(t, array.RecordApproxEqual(r, rec), "expected: %s\ngot: %s\n", r, rec) - cur += off + t.Run(name, func(t *testing.T) { + fields := v[0].Schema().Fields() + for i := range fields { + fields[i].Nullable = nullable + } + meta := v[0].Schema().Metadata() + schema := arrow.NewSchema(fields, &meta) + + var buf bytes.Buffer + assert.NotPanics(t, func() { + enc := json.NewEncoder(&buf) + for _, rawBatch := range v { + batch := array.NewRecordBatch(schema, rawBatch.Columns(), rawBatch.NumRows()) + if err := enc.Encode(batch); err != nil { + panic(err) + } + } + }) + + rdr := bytes.NewReader(buf.Bytes()) + var cur int64 + + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + for _, rawBatch := range v { + batch := array.NewRecordBatch(schema, rawBatch.Columns(), rawBatch.NumRows()) + + rec, off, err := array.RecordFromJSON(mem, schema, rdr, array.WithStartOffset(cur)) + assert.NoError(t, err) + defer rec.Release() + + assert.Truef(t, array.RecordApproxEqual(batch, rec), "expected: %s\ngot: %s\n", batch, rec) + cur += off + } + }) } }) } diff --git a/arrow/compute/vector_sort_test.go b/arrow/compute/vector_sort_test.go index 39bf5e95f..5a15428e7 100644 --- a/arrow/compute/vector_sort_test.go +++ b/arrow/compute/vector_sort_test.go @@ -1349,8 +1349,8 @@ func TestVectorSortIndicesCppRecordBatchParity(t *testing.T) { t.Run("NoNull", func(t *testing.T) { schema := arrow.NewSchema([]arrow.Field{ - {Name: "a", Type: arrow.PrimitiveTypes.Uint8}, - {Name: "b", Type: arrow.PrimitiveTypes.Uint32}, + {Name: "a", Type: arrow.PrimitiveTypes.Uint8, Nullable: true}, + {Name: "b", Type: arrow.PrimitiveTypes.Uint32, Nullable: true}, }, nil) jsonRows := `[ {"a": 3, "b": 5}, @@ -1373,8 +1373,8 @@ func TestVectorSortIndicesCppRecordBatchParity(t *testing.T) { t.Run("Null", func(t *testing.T) { schema := arrow.NewSchema([]arrow.Field{ - {Name: "a", Type: arrow.PrimitiveTypes.Uint8}, - {Name: "b", Type: arrow.PrimitiveTypes.Uint32}, + {Name: "a", Type: arrow.PrimitiveTypes.Uint8, Nullable: true}, + {Name: "b", Type: arrow.PrimitiveTypes.Uint32, Nullable: true}, }, nil) jsonRows := `[ {"a": null, "b": 5}, @@ -1396,8 +1396,8 @@ func TestVectorSortIndicesCppRecordBatchParity(t *testing.T) { t.Run("NaN", func(t *testing.T) { schema := arrow.NewSchema([]arrow.Field{ - {Name: "a", Type: arrow.PrimitiveTypes.Float32}, - {Name: "b", Type: arrow.PrimitiveTypes.Float64}, + {Name: "a", Type: arrow.PrimitiveTypes.Float32, Nullable: true}, + {Name: "b", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, }, nil) ba := array.NewFloat32Builder(mem) defer ba.Release() @@ -1426,8 +1426,8 @@ func TestVectorSortIndicesCppRecordBatchParity(t *testing.T) { t.Run("NaNAndNull", func(t *testing.T) { schema := arrow.NewSchema([]arrow.Field{ - {Name: "a", Type: arrow.PrimitiveTypes.Float32}, - {Name: "b", Type: arrow.PrimitiveTypes.Float64}, + {Name: "a", Type: arrow.PrimitiveTypes.Float32, Nullable: true}, + {Name: "b", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, }, nil) ba := array.NewFloat32Builder(mem) defer ba.Release() @@ -1460,8 +1460,8 @@ func TestVectorSortIndicesCppRecordBatchParity(t *testing.T) { t.Run("Boolean", func(t *testing.T) { schema := arrow.NewSchema([]arrow.Field{ - {Name: "a", Type: arrow.FixedWidthTypes.Boolean}, - {Name: "b", Type: arrow.FixedWidthTypes.Boolean}, + {Name: "a", Type: arrow.FixedWidthTypes.Boolean, Nullable: true}, + {Name: "b", Type: arrow.FixedWidthTypes.Boolean, Nullable: true}, }, nil) jsonRows := `[ {"a": true, "b": null}, @@ -1486,9 +1486,9 @@ func TestVectorSortIndicesCppRecordBatchParity(t *testing.T) { ts := &arrow.TimestampType{Unit: arrow.Microsecond} fsb3 := &arrow.FixedSizeBinaryType{ByteWidth: 3} schema := arrow.NewSchema([]arrow.Field{ - {Name: "a", Type: ts}, - {Name: "b", Type: arrow.BinaryTypes.LargeString}, - {Name: "c", Type: fsb3}, + {Name: "a", Type: ts, Nullable: true}, + {Name: "b", Type: arrow.BinaryTypes.LargeString, Nullable: true}, + {Name: "c", Type: fsb3, Nullable: true}, }, nil) ba := array.NewTimestampBuilder(mem, ts) defer ba.Release() @@ -1535,8 +1535,8 @@ func TestVectorSortIndicesCppRecordBatchParity(t *testing.T) { d128 := &arrow.Decimal128Type{Precision: 3, Scale: 1} d256 := &arrow.Decimal256Type{Precision: 4, Scale: 2} schema := arrow.NewSchema([]arrow.Field{ - {Name: "a", Type: d128}, - {Name: "b", Type: d256}, + {Name: "a", Type: d128, Nullable: true}, + {Name: "b", Type: d256, Nullable: true}, }, nil) jsonRows := `[ {"a": "12.3", "b": "12.34"}, @@ -1561,8 +1561,8 @@ func TestVectorSortIndicesCppRecordBatchParity(t *testing.T) { t.Run("DuplicateSortKeys", func(t *testing.T) { schema := arrow.NewSchema([]arrow.Field{ - {Name: "a", Type: arrow.PrimitiveTypes.Float32}, - {Name: "b", Type: arrow.PrimitiveTypes.Float64}, + {Name: "a", Type: arrow.PrimitiveTypes.Float32, Nullable: true}, + {Name: "b", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, }, nil) ba := array.NewFloat32Builder(mem) defer ba.Release() @@ -1610,8 +1610,8 @@ func TestVectorSortIndicesCppTableParity(t *testing.T) { ctx := context.Background() schemaAB := arrow.NewSchema([]arrow.Field{ - {Name: "a", Type: arrow.PrimitiveTypes.Uint8}, - {Name: "b", Type: arrow.PrimitiveTypes.Uint32}, + {Name: "a", Type: arrow.PrimitiveTypes.Uint8, Nullable: true}, + {Name: "b", Type: arrow.PrimitiveTypes.Uint32, Nullable: true}, }, nil) t.Run("EmptyTable", func(t *testing.T) { @@ -1667,8 +1667,8 @@ func TestVectorSortIndicesCppTableParity(t *testing.T) { t.Run("BinaryLikeTwoChunks", func(t *testing.T) { fsb3 := &arrow.FixedSizeBinaryType{ByteWidth: 3} s := arrow.NewSchema([]arrow.Field{ - {Name: "a", Type: arrow.BinaryTypes.LargeString}, - {Name: "b", Type: fsb3}, + {Name: "a", Type: arrow.BinaryTypes.LargeString, Nullable: true}, + {Name: "b", Type: fsb3, Nullable: true}, }, nil) buildBatch := func(a []string, b [][]byte, bNulls []bool) arrow.RecordBatch { ab := array.NewLargeStringBuilder(mem) @@ -1719,8 +1719,8 @@ func TestVectorSortIndicesCppTableParity(t *testing.T) { t.Run("HeterogenousChunking", func(t *testing.T) { s := arrow.NewSchema([]arrow.Field{ - {Name: "a", Type: arrow.PrimitiveTypes.Float32}, - {Name: "b", Type: arrow.PrimitiveTypes.Float64}, + {Name: "a", Type: arrow.PrimitiveTypes.Float32, Nullable: true}, + {Name: "b", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, }, nil) a0, _, err := array.FromJSON(mem, arrow.PrimitiveTypes.Float32, strings.NewReader("[null, 1]")) require.NoError(t, err) diff --git a/arrow/extensions/bool8.go b/arrow/extensions/bool8.go index 97038a1bf..aaf51f0c5 100644 --- a/arrow/extensions/bool8.go +++ b/arrow/extensions/bool8.go @@ -114,8 +114,8 @@ func (a *Bool8Array) MarshalJSON() ([]byte, error) { return json.Marshal(values) } -func (a *Bool8Array) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { +func (a *Bool8Array) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && a.IsNull(i) { return nil } return a.Value(i) diff --git a/arrow/extensions/extensions.go b/arrow/extensions/extensions.go index 04566c750..6f13aa646 100644 --- a/arrow/extensions/extensions.go +++ b/arrow/extensions/extensions.go @@ -21,11 +21,12 @@ import ( ) var canonicalExtensionTypes = []arrow.ExtensionType{ - NewBool8Type(), - NewUUIDType(), - &OpaqueType{}, &JSONType{}, + &OpaqueType{}, + &TimestampWithOffsetType{}, &VariantType{}, + NewBool8Type(), + NewUUIDType(), } func init() { diff --git a/arrow/extensions/json.go b/arrow/extensions/json.go index 3f46b50e3..b9cc50a73 100644 --- a/arrow/extensions/json.go +++ b/arrow/extensions/json.go @@ -116,9 +116,7 @@ func (a *JSONArray) ValueBytes(i int) []byte { return b } -// ValueJSON wraps the underlying string value as a json.RawMessage, -// or returns nil if the array value is null. -func (a *JSONArray) ValueJSON(i int) json.RawMessage { +func (a *JSONArray) valueJSON(i int, nullable bool) json.RawMessage { var val json.RawMessage if a.IsValid(i) { val = json.RawMessage(a.Storage().(array.StringLike).Value(i)) @@ -126,6 +124,12 @@ func (a *JSONArray) ValueJSON(i int) json.RawMessage { return val } +// ValueJSON wraps the underlying string value as a json.RawMessage, +// or returns nil if the array value is null. +func (a *JSONArray) ValueJSON(i int) json.RawMessage { + return a.valueJSON(i, true) +} + // MarshalJSON implements json.Marshaler. // Marshaling json.RawMessage is a no-op, except that nil values will // be marshaled as a JSON null. @@ -138,8 +142,8 @@ func (a *JSONArray) MarshalJSON() ([]byte, error) { } // GetOneForMarshal implements arrow.Array. -func (a *JSONArray) GetOneForMarshal(i int) interface{} { - return a.ValueJSON(i) +func (a *JSONArray) GetOneForMarshal(i int, nullable bool) interface{} { + return a.valueJSON(i, nullable) } var ( diff --git a/arrow/extensions/timestamp_with_offset.go b/arrow/extensions/timestamp_with_offset.go new file mode 100644 index 000000000..471a3117d --- /dev/null +++ b/arrow/extensions/timestamp_with_offset.go @@ -0,0 +1,535 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions + +import ( + "fmt" + "iter" + "math" + "reflect" + "slices" + "strings" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/apache/arrow-go/v18/internal/json" +) + +// TimestampWithOffsetType represents a timestamp column that stores a timezone offset per row instead of +// applying the same timezone offset to the entire column. +type TimestampWithOffsetType struct { + arrow.ExtensionBase +} + +func isOffsetTypeOk(offsetType arrow.DataType) bool { + switch offsetType := offsetType.(type) { + case *arrow.Int16Type: + return true + case *arrow.DictionaryType: + return arrow.TypeEqual(offsetType.ValueType, arrow.PrimitiveTypes.Int16) + case *arrow.RunEndEncodedType: + return offsetType.ValidRunEndsType(offsetType.RunEnds()) && + arrow.TypeEqual(offsetType.Encoded(), arrow.PrimitiveTypes.Int16) + // FIXME: Technically this should be non-nullable, but a Arrow IPC does not deserialize + // ValueNullable properly, so enforcing this here would always fail when reading from an IPC + // stream + // !offsetType.ValueNullable + default: + return false + } +} + +// Whether the storageType is compatible with TimestampWithOffset. +// +// Returns (time_unit, offset_type, ok). If ok is false, time_unit and offset_type are garbage. +func isDataTypeCompatible(storageType arrow.DataType) (unit arrow.TimeUnit, offsetType arrow.DataType, ok bool) { + unit = arrow.Second + offsetType = arrow.PrimitiveTypes.Int16 + ok = false + + st, compat := storageType.(*arrow.StructType) + if !compat || st.NumFields() != 2 { + return + } + + if ts, compat := st.Field(0).Type.(*arrow.TimestampType); compat && ts.TimeZone == "UTC" { + unit = ts.TimeUnit() + } else { + return + } + + maybeOffset := st.Field(1) + offsetType = maybeOffset.Type + + ok = st.Field(0).Name == "timestamp" && + !st.Field(0).Nullable && + maybeOffset.Name == "offset_minutes" && + isOffsetTypeOk(offsetType) && + !maybeOffset.Nullable + return +} + +// NewTimestampWithOffsetType creates a new TimestampWithOffsetType with the underlying storage type set correctly to +// Struct(timestamp=Timestamp(T, "UTC"), offset_minutes=Int16), where T is any TimeUnit. +func NewTimestampWithOffsetType(unit arrow.TimeUnit) *TimestampWithOffsetType { + v, _ := NewTimestampWithOffsetTypeCustomOffset(unit, arrow.PrimitiveTypes.Int16) + // SAFETY: This should never error as Int16 is always a valid offset type + + return v +} + +// NewTimestampWithOffsetTypeCustomOffset creates a new TimestampWithOffsetType with the underlying storage type set correctly to +// Struct(timestamp=Timestamp(T, "UTC"), offset_minutes=O), where T is any TimeUnit and O is a valid offset type. +// +// The error will be populated if the data type is not a valid encoding of the offsets field. +func NewTimestampWithOffsetTypeCustomOffset(unit arrow.TimeUnit, offsetType arrow.DataType) (*TimestampWithOffsetType, error) { + if !isOffsetTypeOk(offsetType) { + return nil, fmt.Errorf("invalid offset type %s", offsetType) + } + + return &TimestampWithOffsetType{ + ExtensionBase: arrow.ExtensionBase{ + Storage: arrow.StructOf( + arrow.Field{ + Name: "timestamp", + Type: &arrow.TimestampType{ + Unit: unit, + TimeZone: "UTC", + }, + Nullable: false, + }, + arrow.Field{ + Name: "offset_minutes", + Type: offsetType, + Nullable: false, + }, + ), + }, + }, nil +} + +type DictIndexType interface { + *arrow.Int8Type | *arrow.Int16Type | *arrow.Int32Type | *arrow.Int64Type | + *arrow.Uint8Type | *arrow.Uint16Type | *arrow.Uint32Type | *arrow.Uint64Type +} + +type RunEndsType interface { + *arrow.Int16Type | *arrow.Int32Type | *arrow.Int64Type +} + +// NewTimestampWithOffsetTypeDictionaryEncoded creates a new TimestampWithOffsetType with the underlying storage type set correctly to +// Struct(timestamp=Timestamp(T, "UTC"), offset_minutes=Dictionary(I, Int16)), where T is any TimeUnit and I is a +// valid Dictionary index type. +func NewTimestampWithOffsetTypeDictionaryEncoded[I DictIndexType](unit arrow.TimeUnit, index I) *TimestampWithOffsetType { + offsetType := arrow.DictionaryType{ + IndexType: arrow.DataType(index), + ValueType: arrow.PrimitiveTypes.Int16, + Ordered: false, + } + v, _ := NewTimestampWithOffsetTypeCustomOffset(unit, &offsetType) + // SAFETY: This should never error as DictIndexType is always a valid index type + + return v +} + +// NewTimestampWithOffsetTypeRunEndEncoded creates a new TimestampWithOffsetType with the underlying storage type set correctly to +// Struct(timestamp=Timestamp(T, "UTC"), offset_minutes=RunEndEncoded(E, Int16)), where T is any TimeUnit and E is a +// valid run-ends type. +func NewTimestampWithOffsetTypeRunEndEncoded[E RunEndsType](unit arrow.TimeUnit, runEnds E) *TimestampWithOffsetType { + offsetType := arrow.RunEndEncodedOf(arrow.DataType(runEnds), arrow.PrimitiveTypes.Int16) + + v, _ := NewTimestampWithOffsetTypeCustomOffset(unit, offsetType) + // SAFETY: This should never error as RunEndsType always a valid run ends type + + return v + +} + +func (b *TimestampWithOffsetType) ArrayType() reflect.Type { + return reflect.TypeOf(TimestampWithOffsetArray{}) +} + +func (b *TimestampWithOffsetType) ExtensionName() string { return "arrow.timestamp_with_offset" } + +func (b *TimestampWithOffsetType) String() string { + return fmt.Sprintf("extension<%s>", b.ExtensionName()) +} + +func (e *TimestampWithOffsetType) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`{"name":"%s","metadata":%s}`, e.ExtensionName(), e.Serialize())), nil +} + +func (b *TimestampWithOffsetType) Serialize() string { return "" } + +func (b *TimestampWithOffsetType) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { + timeUnit, offsetType, ok := isDataTypeCompatible(storageType) + if !ok { + return nil, fmt.Errorf("invalid storage type for TimestampWithOffsetType: %s", storageType.Name()) + } + + return NewTimestampWithOffsetTypeCustomOffset(timeUnit, offsetType) +} + +func (b *TimestampWithOffsetType) ExtensionEquals(other arrow.ExtensionType) bool { + return b.ExtensionName() == other.ExtensionName() && + arrow.TypeEqual(b.StorageType(), other.StorageType()) +} + +func (b *TimestampWithOffsetType) OffsetType() arrow.DataType { + return b.ExtensionBase.Storage.(*arrow.StructType).Field(1).Type +} + +func (b *TimestampWithOffsetType) TimeUnit() arrow.TimeUnit { + return b.ExtensionBase.Storage.(*arrow.StructType).Field(0).Type.(*arrow.TimestampType).TimeUnit() +} + +func (b *TimestampWithOffsetType) NewBuilder(mem memory.Allocator) array.Builder { + v, _ := NewTimestampWithOffsetBuilder(mem, b.TimeUnit(), b.OffsetType()) + return v +} + +// TimestampWithOffsetArray is a simple array of struct +type TimestampWithOffsetArray struct { + array.ExtensionArrayBase +} + +func (a *TimestampWithOffsetArray) String() string { + var o strings.Builder + o.WriteString("[") + for i := 0; i < a.Len(); i++ { + if i > 0 { + o.WriteString(" ") + } + switch { + case a.IsNull(i): + o.WriteString(array.NullValueStr) + default: + fmt.Fprintf(&o, "\"%s\"", a.Value(i)) + } + } + o.WriteString("]") + return o.String() +} + +func timeFromFieldValues(utcTimestamp arrow.Timestamp, offsetMinutes int16, unit arrow.TimeUnit) time.Time { + hours := offsetMinutes / 60 + minutes := offsetMinutes % 60 + if minutes < 0 { + minutes = -minutes + } + + loc := time.FixedZone(fmt.Sprintf("UTC%+03d:%02d", hours, minutes), int(offsetMinutes)*60) + return utcTimestamp.ToTime(unit).In(loc) +} + +func fieldValuesFromTime(t time.Time, unit arrow.TimeUnit) (arrow.Timestamp, int16) { + _, offsetSeconds := t.Zone() + offsetMinutes := int16(offsetSeconds / 60) + ts, _ := arrow.TimestampFromTime(t.UTC(), unit) + return ts, offsetMinutes +} + +// Get the raw arrow values at the given index +// +// SAFETY: the value at i must not be nil +func (a *TimestampWithOffsetArray) rawValueUnsafe(i int) (arrow.Timestamp, int16, arrow.TimeUnit) { + structs := a.Storage().(*array.Struct) + + timestampField := structs.Field(0) + timestamps := timestampField.(*array.Timestamp) + + timeUnit := timestampField.DataType().(*arrow.TimestampType).Unit + utcTimestamp := timestamps.Value(i) + + var offsetMinutes int16 + + switch offsets := structs.Field(1).(type) { + case *array.Int16: + offsetMinutes = offsets.Value(i) + case *array.Dictionary: + offsetMinutes = offsets.Dictionary().(*array.Int16).Value(offsets.GetValueIndex(i)) + case *array.RunEndEncoded: + offsetMinutes = offsets.Values().(*array.Int16).Value(offsets.GetPhysicalIndex(i)) + } + + return utcTimestamp, offsetMinutes, timeUnit +} + +func (a *TimestampWithOffsetArray) Value(i int) time.Time { + if a.IsNull(i) { + return time.Time{} + } + utcTimestamp, offsetMinutes, timeUnit := a.rawValueUnsafe(i) + return timeFromFieldValues(utcTimestamp, offsetMinutes, timeUnit) +} + +// Iterates over the array returning the timestamp at each position. +// +// The second parameter indicates whether the timestamp is valid or not. +// +// This will iterate using the fastest method given the underlying storage array +func (a *TimestampWithOffsetArray) iterValues() iter.Seq2[time.Time, bool] { + return func(yield func(time.Time, bool) bool) { + structs := a.Storage().(*array.Struct) + offsets := structs.Field(1) + if reeOffsets, isRee := offsets.(*array.RunEndEncoded); isRee { + timestampField := structs.Field(0) + timeUnit := timestampField.DataType().(*arrow.TimestampType).Unit + timestamps := timestampField.(*array.Timestamp) + + offsetValues := reeOffsets.Values().(*array.Int16) + offsetPhysicalIdx := 0 + + var getRunEnd (func(int) int) + switch arr := reeOffsets.RunEndsArr().(type) { + case *array.Int16: + getRunEnd = func(idx int) int { return int(arr.Value(idx)) } + case *array.Int32: + getRunEnd = func(idx int) int { return int(arr.Value(idx)) } + case *array.Int64: + getRunEnd = func(idx int) int { return int(arr.Value(idx)) } + } + + for i := 0; i < a.Len(); i++ { + if i >= getRunEnd(offsetPhysicalIdx) { + offsetPhysicalIdx += 1 + } + + var ts time.Time + valid := a.IsValid(i) + if valid { + utcTimestamp := timestamps.Value(i) + offsetMinutes := offsetValues.Value(offsetPhysicalIdx) + v := timeFromFieldValues(utcTimestamp, offsetMinutes, timeUnit) + ts = v + } + + if !yield(ts, valid) { + return + } + } + } else { + for i := 0; i < a.Len(); i++ { + var ts time.Time + valid := a.IsValid(i) + if valid { + utcTimestamp, offsetMinutes, timeUnit := a.rawValueUnsafe(i) + v := timeFromFieldValues(utcTimestamp, offsetMinutes, timeUnit) + ts = v + } + + if !yield(ts, valid) { + return + } + } + } + } +} + +func (a *TimestampWithOffsetArray) Values() []time.Time { + return slices.Collect(func(yield func(time.Time) bool) { + for time := range a.iterValues() { + if !yield(time) { + return + } + } + }) +} + +func (a *TimestampWithOffsetArray) ValueStr(i int) string { + switch { + case a.IsNull(i): + return array.NullValueStr + default: + return a.Value(i).String() + } +} + +func (a *TimestampWithOffsetArray) MarshalJSON() ([]byte, error) { + values := make([]interface{}, a.Len()) + i := 0 + for ts, valid := range a.iterValues() { + if !valid { + values[i] = nil + } else { + values[i] = &ts + } + i += 1 + } + return json.Marshal(values) +} + +func (a *TimestampWithOffsetArray) GetOneForMarshal(i int, nullable bool) interface{} { + if nullable && a.IsNull(i) { + return nil + } + return a.Value(i) +} + +// TimestampWithOffsetBuilder is a convenience builder for the TimestampWithOffset extension type, +// allowing arrays to be built with boolean values rather than the underlying storage type. +type TimestampWithOffsetBuilder struct { + *array.ExtensionBuilder + + // The layout used to parse any timestamps from strings. Defaults to time.RFC3339 + Layout string + unit arrow.TimeUnit + offsetType arrow.DataType + // lastOffset is only used to determine when to start new runs with run-end encoded offsets + lastOffset int16 +} + +// NewTimestampWithOffsetBuilder creates a new TimestampWithOffsetBuilder, exposing a convenient and efficient interface +// for writing time.Time values to the underlying storage array. +func NewTimestampWithOffsetBuilder(mem memory.Allocator, unit arrow.TimeUnit, offsetType arrow.DataType) (*TimestampWithOffsetBuilder, error) { + dataType, err := NewTimestampWithOffsetTypeCustomOffset(unit, offsetType) + if err != nil { + return nil, err + } + + return &TimestampWithOffsetBuilder{ + unit: unit, + offsetType: offsetType, + lastOffset: math.MaxInt16, + Layout: time.RFC3339, + ExtensionBuilder: array.NewExtensionBuilder(mem, dataType), + }, nil +} + +func (b *TimestampWithOffsetBuilder) Append(v time.Time) { + timestamp, offsetMinutes := fieldValuesFromTime(v, b.unit) + offsetMinutes16 := int16(offsetMinutes) + structBuilder := b.Builder.(*array.StructBuilder) + + structBuilder.Append(true) + structBuilder.FieldBuilder(0).(*array.TimestampBuilder).Append(timestamp) + + switch offsets := structBuilder.FieldBuilder(1).(type) { + case *array.Int16Builder: + offsets.Append(offsetMinutes16) + case *array.Int16DictionaryBuilder: + offsets.Append(offsetMinutes16) + case *array.RunEndEncodedBuilder: + if offsetMinutes != b.lastOffset { + offsets.Append(1) + offsets.ValueBuilder().(*array.Int16Builder).Append(offsetMinutes16) + } else { + offsets.ContinueRun(1) + } + + b.lastOffset = offsetMinutes16 + } + +} + +// By default, this will try to parse the string using the RFC3339 layout. +// +// You can change the default layout by using builder.SetLayout() +func (b *TimestampWithOffsetBuilder) AppendValueFromString(s string) error { + if s == array.NullValueStr { + b.AppendNull() + return nil + } + + parsed, err := time.Parse(b.Layout, s) + if err != nil { + return err + } + + b.Append(parsed) + return nil +} + +func (b *TimestampWithOffsetBuilder) AppendValues(values []time.Time, valids []bool) { + structBuilder := b.Builder.(*array.StructBuilder) + timestamps := structBuilder.FieldBuilder(0).(*array.TimestampBuilder) + + structBuilder.AppendValues(valids) + // SAFETY: by this point we know all buffers have available space given the earlier + // call to structBuilder.AppendValues which calls Reserve internally, so it's OK to + // call UnsafeAppend on inner builders + + switch offsets := structBuilder.FieldBuilder(1).(type) { + case *array.Int16Builder: + for _, v := range values { + timestamp, offsetMinutes := fieldValuesFromTime(v, b.unit) + timestamps.UnsafeAppend(timestamp) + offsets.UnsafeAppend(offsetMinutes) + } + case *array.Int16DictionaryBuilder: + for _, v := range values { + timestamp, offsetMinutes := fieldValuesFromTime(v, b.unit) + timestamps.UnsafeAppend(timestamp) + offsets.UnsafeAppend(offsetMinutes) + } + case *array.RunEndEncodedBuilder: + offsetValuesBuilder := offsets.ValueBuilder().(*array.Int16Builder) + for i, v := range values { + timestamp, offsetMinutes := fieldValuesFromTime(v, b.unit) + timestamps.UnsafeAppend(timestamp) + offsetMinutes16 := int16(offsetMinutes) + // If value at i is null, simply continue the run to maximize compression + if valids[i] && offsetMinutes != b.lastOffset { + offsets.Append(1) + offsetValuesBuilder.Append(offsetMinutes16) + } else { + offsets.ContinueRun(1) + } + b.lastOffset = offsetMinutes16 + } + } +} + +func (b *TimestampWithOffsetBuilder) UnmarshalOne(dec *json.Decoder) error { + tok, err := dec.Token() + if err != nil { + return fmt.Errorf("failed to decode json: %w", err) + } + + switch raw := tok.(type) { + case string: + t, err := time.Parse(b.Layout, raw) + if err != nil { + return fmt.Errorf("failed to parse string \"%s\" as time.Time using layout \"%s\"", raw, b.Layout) + } + b.Append(t) + case nil: + b.AppendNull() + default: + return fmt.Errorf("expected date string") + } + + return nil +} + +func (b *TimestampWithOffsetBuilder) Unmarshal(dec *json.Decoder) error { + for dec.More() { + if err := b.UnmarshalOne(dec); err != nil { + return err + } + } + return nil +} + +var ( + _ arrow.ExtensionType = (*TimestampWithOffsetType)(nil) + _ array.CustomExtensionBuilder = (*TimestampWithOffsetType)(nil) + _ array.ExtensionArray = (*TimestampWithOffsetArray)(nil) + _ array.Builder = (*TimestampWithOffsetBuilder)(nil) +) diff --git a/arrow/extensions/timestamp_with_offset_test.go b/arrow/extensions/timestamp_with_offset_test.go new file mode 100644 index 000000000..897a641e5 --- /dev/null +++ b/arrow/extensions/timestamp_with_offset_test.go @@ -0,0 +1,376 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions_test + +import ( + "bytes" + "fmt" + "testing" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/extensions" + "github.com/apache/arrow-go/v18/arrow/ipc" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/apache/arrow-go/v18/internal/json" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var testTimeUnit = arrow.Microsecond + +var epoch = time.Unix(0, 0).In(time.FixedZone("UTC+00:00", 0)) + +var testDate0 = time.Date(2025, 01, 01, 00, 00, 00, 00, time.FixedZone("UTC+00:00", 0)) + +var testZone1 = time.FixedZone("UTC-08:30", -8*60*60-30*60) +var testDate1 = testDate0.In(testZone1) + +var testZone2 = time.FixedZone("UTC+11:00", +11*60*60) +var testDate2 = testDate0.In(testZone2) + +func dict(index arrow.DataType) arrow.DataType { + return &arrow.DictionaryType{ + IndexType: index, + ValueType: arrow.PrimitiveTypes.Int16, + Ordered: false, + } +} + +func ree(runEnds arrow.DataType) arrow.DataType { + v := arrow.RunEndEncodedOf(runEnds, arrow.PrimitiveTypes.Int16) + v.ValueNullable = false + return v +} + +// All tests use this in a for loop to make sure everything works for every possible +// encoding of offsets (primitive, dictionary, run-end) +var allAllowedOffsetTypes = make(map[string]arrow.DataType) + +func init() { + // primitive offsetType + allAllowedOffsetTypes["primitive-int16"] = arrow.PrimitiveTypes.Int16 + + // dict-encoded offsetType + allAllowedOffsetTypes["dict-Uint8"] = dict(arrow.PrimitiveTypes.Uint8) + allAllowedOffsetTypes["dict-Uint16"] = dict(arrow.PrimitiveTypes.Uint16) + allAllowedOffsetTypes["dict-Uint32"] = dict(arrow.PrimitiveTypes.Uint32) + allAllowedOffsetTypes["dict-Uint64"] = dict(arrow.PrimitiveTypes.Uint64) + allAllowedOffsetTypes["dict-Int8"] = dict(arrow.PrimitiveTypes.Int8) + allAllowedOffsetTypes["dict-Int16"] = dict(arrow.PrimitiveTypes.Int16) + allAllowedOffsetTypes["dict-Int32"] = dict(arrow.PrimitiveTypes.Int32) + allAllowedOffsetTypes["dict-Int64"] = dict(arrow.PrimitiveTypes.Int64) + + // run-end encoded offsetType + allAllowedOffsetTypes["ree-Int16"] = ree(arrow.PrimitiveTypes.Int16) + allAllowedOffsetTypes["ree-Int32"] = ree(arrow.PrimitiveTypes.Int32) + allAllowedOffsetTypes["ree-Int64"] = ree(arrow.PrimitiveTypes.Int64) +} + +func TestTimestampWithOffsetTypePrimitiveBasics(t *testing.T) { + typ := extensions.NewTimestampWithOffsetType(testTimeUnit) + + assert.Equal(t, "arrow.timestamp_with_offset", typ.ExtensionName()) + assert.True(t, typ.ExtensionEquals(typ)) + + assert.True(t, arrow.TypeEqual(typ, typ)) + assert.True(t, arrow.TypeEqual( + arrow.StructOf( + arrow.Field{ + Name: "timestamp", + Type: &arrow.TimestampType{ + Unit: testTimeUnit, + TimeZone: "UTC", + }, + Nullable: false, + }, + arrow.Field{ + Name: "offset_minutes", + Type: arrow.PrimitiveTypes.Int16, + Nullable: false, + }, + ), + typ.StorageType())) + + assert.Equal(t, "extension", typ.String()) +} + +func assertDictBasics[I extensions.DictIndexType](t *testing.T, indexType I) { + typ := extensions.NewTimestampWithOffsetTypeDictionaryEncoded(testTimeUnit, indexType) + + assert.Equal(t, "arrow.timestamp_with_offset", typ.ExtensionName()) + assert.True(t, typ.ExtensionEquals(typ)) + + assert.True(t, arrow.TypeEqual(typ, typ)) + assert.True(t, arrow.TypeEqual( + arrow.StructOf( + arrow.Field{ + Name: "timestamp", + Type: &arrow.TimestampType{ + Unit: testTimeUnit, + TimeZone: "UTC", + }, + Nullable: false, + }, + arrow.Field{ + Name: "offset_minutes", + Type: dict(arrow.DataType(indexType)), + Nullable: false, + }, + ), + typ.StorageType())) + + assert.Equal(t, "extension", typ.String()) +} + +func TestTimestampWithOffsetTypeDictionaryEncodedBasics(t *testing.T) { + assertDictBasics(t, &arrow.Uint8Type{}) + assertDictBasics(t, &arrow.Uint16Type{}) + assertDictBasics(t, &arrow.Uint32Type{}) + assertDictBasics(t, &arrow.Uint64Type{}) + assertDictBasics(t, &arrow.Int8Type{}) + assertDictBasics(t, &arrow.Int16Type{}) + assertDictBasics(t, &arrow.Int32Type{}) + assertDictBasics(t, &arrow.Int64Type{}) +} + +func assertReeBasics[E extensions.RunEndsType](t *testing.T, runEndsType E) { + typ := extensions.NewTimestampWithOffsetTypeRunEndEncoded(testTimeUnit, runEndsType) + + assert.Equal(t, "arrow.timestamp_with_offset", typ.ExtensionName()) + assert.True(t, typ.ExtensionEquals(typ)) + + assert.True(t, arrow.TypeEqual(typ, typ)) + assert.True(t, arrow.TypeEqual( + arrow.StructOf( + arrow.Field{ + Name: "timestamp", + Type: &arrow.TimestampType{ + Unit: testTimeUnit, + TimeZone: "UTC", + }, + Nullable: false, + }, + arrow.Field{ + Name: "offset_minutes", + Type: ree(arrow.DataType(runEndsType)), + Nullable: false, + }, + ), + typ.StorageType())) + + assert.Equal(t, "extension", typ.String()) +} + +func TestTimestampWithOffsetTypeRunEndEncodedBasics(t *testing.T) { + assertReeBasics(t, &arrow.Int16Type{}) + assertReeBasics(t, &arrow.Int32Type{}) + assertReeBasics(t, &arrow.Int64Type{}) +} + +func TestTimestampWithOffsetEquals(t *testing.T) { + // Completely different types are not equal + assert.False(t, extensions.NewTimestampWithOffsetType(arrow.Nanosecond).ExtensionEquals(extensions.NewBool8Type())) + + // Different time units are not equal + assert.False(t, extensions.NewTimestampWithOffsetType(arrow.Nanosecond).ExtensionEquals(extensions.NewTimestampWithOffsetType(arrow.Microsecond))) + assert.False(t, extensions.NewTimestampWithOffsetType(arrow.Nanosecond).ExtensionEquals(extensions.NewTimestampWithOffsetType(arrow.Second))) + assert.False(t, extensions.NewTimestampWithOffsetType(arrow.Microsecond).ExtensionEquals(extensions.NewTimestampWithOffsetType(arrow.Second))) + + // Different underlying storage type is not equal + assert.False(t, extensions.NewTimestampWithOffsetType(arrow.Microsecond).ExtensionEquals(extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Int16Type{}))) + assert.False(t, extensions.NewTimestampWithOffsetType(arrow.Microsecond).ExtensionEquals(extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Int16Type{}))) + assert.False(t, extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Int16Type{}).ExtensionEquals(extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Int16Type{}))) + + // Dict-encoding key type is not equal + assert.False(t, extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Int16Type{}).ExtensionEquals(extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Uint16Type{}))) + + // REE index type is not equal + assert.False(t, extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Int16Type{}).ExtensionEquals(extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Int32Type{}))) + + // Equals OK + assert.True(t, extensions.NewTimestampWithOffsetType(arrow.Nanosecond).ExtensionEquals(extensions.NewTimestampWithOffsetType(arrow.Nanosecond))) + assert.True(t, extensions.NewTimestampWithOffsetType(arrow.Microsecond).ExtensionEquals(extensions.NewTimestampWithOffsetType(arrow.Microsecond))) + assert.True(t, extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Int16Type{}).ExtensionEquals(extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Int16Type{}))) + assert.True(t, extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Uint16Type{}).ExtensionEquals(extensions.NewTimestampWithOffsetTypeDictionaryEncoded(arrow.Microsecond, &arrow.Uint16Type{}))) + assert.True(t, extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Int16Type{}).ExtensionEquals(extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Int16Type{}))) + assert.True(t, extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Int32Type{}).ExtensionEquals(extensions.NewTimestampWithOffsetTypeRunEndEncoded(arrow.Microsecond, &arrow.Int32Type{}))) +} + +func TestTimestampWithOffsetExtensionBuilder(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + for name, offsetType := range allAllowedOffsetTypes { + t.Run(name, func(t *testing.T) { + builder, _ := extensions.NewTimestampWithOffsetBuilder(mem, testTimeUnit, offsetType) + + builder.Append(testDate0) + builder.AppendNull() + builder.Append(testDate1) + builder.Append(testDate2) + builder.Append(epoch) + + // it should build the array with the correct size + arr := builder.NewArray() + typedArr := arr.(*extensions.TimestampWithOffsetArray) + assert.Equal(t, 5, arr.Data().Len()) + defer arr.Release() + + // typedArr.Value(i) should return values adjusted for their original timezone + assert.Equal(t, testDate0, typedArr.Value(0)) + assert.Equal(t, testDate1, typedArr.Value(2)) + assert.Equal(t, testDate2, typedArr.Value(3)) + assert.Equal(t, epoch, typedArr.Value(4)) + + // storage TimeUnit should be the same as we pass in to the builder, and storage timezone should be UTC + timestampStructField := typedArr.Storage().(*array.Struct).Field(0) + timestampStructDataType := timestampStructField.DataType().(*arrow.TimestampType) + assert.Equal(t, timestampStructDataType.Unit, testTimeUnit) + assert.Equal(t, timestampStructDataType.TimeZone, "UTC") + + // stored values should be equivalent to the raw values in UTC + timestampsArr := timestampStructField.(*array.Timestamp) + assert.Equal(t, testDate0.In(time.UTC), timestampsArr.Value(0).ToTime(testTimeUnit)) + assert.Equal(t, testDate1.In(time.UTC), timestampsArr.Value(2).ToTime(testTimeUnit)) + assert.Equal(t, testDate2.In(time.UTC), timestampsArr.Value(3).ToTime(testTimeUnit)) + assert.Equal(t, epoch.In(time.UTC), timestampsArr.Value(4).ToTime(testTimeUnit)) + + // the array should encode itself as JSON and string + arrStr := arr.String() + assert.Equal(t, fmt.Sprintf(`["%[1]s" (null) "%[2]s" "%[3]s" "%[4]s"]`, testDate0, testDate1, testDate2, epoch), arrStr) + jsonStr, err := json.Marshal(arr) + assert.NoError(t, err) + + // roundtripping from JSON with array.FromJSON should work + expectedDataType, _ := extensions.NewTimestampWithOffsetTypeCustomOffset(testTimeUnit, offsetType) + roundtripped, _, err := array.FromJSON(mem, expectedDataType, bytes.NewReader(jsonStr)) + defer roundtripped.Release() + assert.NoError(t, err) + assert.Truef(t, array.Equal(arr, roundtripped), "expected %s\n\ngot %s", arr, roundtripped) + }) + } +} + +func TestTimestampWithOffsetExtensionRecordBuilder(t *testing.T) { + for name, offsetType := range allAllowedOffsetTypes { + t.Run(name, func(t *testing.T) { + dataType, _ := extensions.NewTimestampWithOffsetTypeCustomOffset(testTimeUnit, offsetType) + schema := arrow.NewSchema([]arrow.Field{ + { + Name: "timestamp_with_offset", + Nullable: true, + Type: dataType, + }, + }, nil) + builder := array.NewRecordBuilder(memory.DefaultAllocator, schema) + defer builder.Release() + + fieldBuilder := builder.Field(0).(*extensions.TimestampWithOffsetBuilder) + + // append a simple time.Time + fieldBuilder.Append(testDate0) + + // append the epoch + fieldBuilder.Append(epoch) + + // append a null and 2 time.Time all at once + values := []time.Time{ + time.Unix(0, 0).In(time.UTC), + testDate1, + testDate2, + } + valids := []bool{false, true, true} + fieldBuilder.AppendValues(values, valids) + + // append a value from RFC3339 string + fieldBuilder.AppendValueFromString(testDate0.Format(time.RFC3339)) + + // append value formatted in a different string layout + fieldBuilder.Layout = time.RFC3339Nano + fieldBuilder.AppendValueFromString(testDate1.Format(time.RFC3339Nano)) + + record := builder.NewRecordBatch() + + // Record batch should JSON-encode values containing per-row timezone info + json, err := record.MarshalJSON() + require.NoError(t, err) + expect := `[{"timestamp_with_offset":"2025-01-01T00:00:00Z"} +,{"timestamp_with_offset":"1970-01-01T00:00:00Z"} +,{"timestamp_with_offset":null} +,{"timestamp_with_offset":"2024-12-31T15:30:00-08:30"} +,{"timestamp_with_offset":"2025-01-01T11:00:00+11:00"} +,{"timestamp_with_offset":"2025-01-01T00:00:00Z"} +,{"timestamp_with_offset":"2024-12-31T15:30:00-08:30"} +]` + require.Equal(t, expect, string(json)) + + // Record batch roundtrip to JSON should work + roundtripped, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, bytes.NewReader(json)) + require.NoError(t, err) + defer roundtripped.Release() + require.Equal(t, schema, roundtripped.Schema()) + assert.Truef(t, array.RecordEqual(record, roundtripped), "expected %s\n\ngot %s", record, roundtripped) + }) + } +} + +func TestTimestampWithOffsetTypeBatchIPCRoundTrip(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + for name, offsetType := range allAllowedOffsetTypes { + t.Run(name, func(t *testing.T) { + builder, _ := extensions.NewTimestampWithOffsetBuilder(mem, testTimeUnit, offsetType) + builder.Append(testDate0) + builder.AppendNull() + builder.Append(testDate1) + builder.Append(testDate2) + builder.Append(epoch) + arr := builder.NewArray() + defer arr.Release() + + typ, _ := extensions.NewTimestampWithOffsetTypeCustomOffset(testTimeUnit, offsetType) + + batch := array.NewRecordBatch(arrow.NewSchema([]arrow.Field{{Name: "timestamp_with_offset", Type: typ, Nullable: true}}, nil), []arrow.Array{arr}, -1) + defer batch.Release() + + var written arrow.RecordBatch + { + var buf bytes.Buffer + wr := ipc.NewWriter(&buf, ipc.WithSchema(batch.Schema())) + require.NoError(t, wr.Write(batch)) + require.NoError(t, wr.Close()) + + rdr, err := ipc.NewReader(&buf) + require.NoError(t, err) + written, err = rdr.Read() + require.NoError(t, err) + written.Retain() + defer written.Release() + rdr.Release() + } + + assert.Truef(t, batch.Schema().Equal(written.Schema()), "expected: %s\n\ngot: %s", + batch.Schema(), written.Schema()) + + assert.Truef(t, array.RecordEqual(batch, written), "expected: %s\n\ngot: %s", + batch, written) + }) + } +} diff --git a/arrow/extensions/uuid.go b/arrow/extensions/uuid.go index 9aac02253..cbe56ecd6 100644 --- a/arrow/extensions/uuid.go +++ b/arrow/extensions/uuid.go @@ -190,13 +190,13 @@ func (a *UUIDArray) ValueStr(i int) string { func (a *UUIDArray) MarshalJSON() ([]byte, error) { vals := make([]any, a.Len()) for i := range vals { - vals[i] = a.GetOneForMarshal(i) + vals[i] = a.GetOneForMarshal(i, true) } return json.Marshal(vals) } -func (a *UUIDArray) GetOneForMarshal(i int) interface{} { - if a.IsValid(i) { +func (a *UUIDArray) GetOneForMarshal(i int, nullable bool) interface{} { + if !nullable || a.IsValid(i) { return a.Value(i) } return nil diff --git a/arrow/extensions/uuid_test.go b/arrow/extensions/uuid_test.go index a76b77a91..36bb812b9 100644 --- a/arrow/extensions/uuid_test.go +++ b/arrow/extensions/uuid_test.go @@ -62,7 +62,7 @@ func TestUUIDExtensionBuilder(t *testing.T) { func TestUUIDExtensionRecordBuilder(t *testing.T) { schema := arrow.NewSchema([]arrow.Field{ - {Name: "uuid", Type: extensions.NewUUIDType()}, + {Name: "uuid", Type: extensions.NewUUIDType(), Nullable: true}, }, nil) builder := array.NewRecordBuilder(memory.DefaultAllocator, schema) builder.Field(0).(*extensions.UUIDBuilder).Append(testUUID) diff --git a/arrow/extensions/variant.go b/arrow/extensions/variant.go index 659f571c5..35573dcfb 100644 --- a/arrow/extensions/variant.go +++ b/arrow/extensions/variant.go @@ -560,8 +560,8 @@ func (v *VariantArray) MarshalJSON() ([]byte, error) { return json.Marshal(values) } -func (v *VariantArray) GetOneForMarshal(i int) any { - if v.IsNull(i) { +func (v *VariantArray) GetOneForMarshal(i int, nullable bool) any { + if nullable && v.IsNull(i) { return nil } diff --git a/arrow/internal/arrdata/arrdata.go b/arrow/internal/arrdata/arrdata.go index 095571a8f..d95ee7cc7 100644 --- a/arrow/internal/arrdata/arrdata.go +++ b/arrow/internal/arrdata/arrdata.go @@ -192,59 +192,59 @@ func makeStructsRecords() []arrow.RecordBatch { mem := memory.NewGoAllocator() fields := []arrow.Field{ - {Name: "f1", Type: arrow.PrimitiveTypes.Int32}, - {Name: "f2", Type: arrow.BinaryTypes.String}, + {Name: "f1", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "f2", Type: arrow.BinaryTypes.String, Nullable: true}, } dtype := arrow.StructOf(fields...) schema := arrow.NewSchema([]arrow.Field{{Name: "struct_nullable", Type: dtype, Nullable: true}}, nil) - mask := []bool{true, false, false, true, true, true, false, true} + innerValids := []bool{true, false, false, true, true} chunks := [][]arrow.Array{ { structOf(mem, dtype, [][]arrow.Array{ { - arrayOf(mem, []int32{-1, -2, -3, -4, -5}, mask[:5]), - arrayOf(mem, []string{"111", "222", "333", "444", "555"}, mask[:5]), + arrayOf(mem, []int32{-1, -2, -3, -4, -5}, innerValids), + arrayOf(mem, []string{"111", "222", "333", "444", "555"}, innerValids), }, { - arrayOf(mem, []int32{-11, -12, -13, -14, -15}, mask[:5]), - arrayOf(mem, []string{"1111", "1222", "1333", "1444", "1555"}, mask[:5]), + arrayOf(mem, []int32{-11, -12, -13, -14, -15}, innerValids), + arrayOf(mem, []string{"1111", "1222", "1333", "1444", "1555"}, innerValids), }, { - arrayOf(mem, []int32{-21, -22, -23, -24, -25}, mask[:5]), - arrayOf(mem, []string{"2111", "2222", "2333", "2444", "2555"}, mask[:5]), + arrayOf(mem, []int32{-21, -22, -23, -24, -25}, innerValids), + arrayOf(mem, []string{"2111", "2222", "2333", "2444", "2555"}, innerValids), }, { - arrayOf(mem, []int32{-31, -32, -33, -34, -35}, mask[:5]), - arrayOf(mem, []string{"3111", "3222", "3333", "3444", "3555"}, mask[:5]), + arrayOf(mem, []int32{-31, -32, -33, -34, -35}, innerValids), + arrayOf(mem, []string{"3111", "3222", "3333", "3444", "3555"}, innerValids), }, { - arrayOf(mem, []int32{-41, -42, -43, -44, -45}, mask[:5]), - arrayOf(mem, []string{"4111", "4222", "4333", "4444", "4555"}, mask[:5]), + arrayOf(mem, []int32{-41, -42, -43, -44, -45}, innerValids), + arrayOf(mem, []string{"4111", "4222", "4333", "4444", "4555"}, innerValids), }, }, []bool{true, false, true, true, true}), }, { structOf(mem, dtype, [][]arrow.Array{ { - arrayOf(mem, []int32{1, 2, 3, 4, 5}, mask[:5]), - arrayOf(mem, []string{"-111", "-222", "-333", "-444", "-555"}, mask[:5]), + arrayOf(mem, []int32{1, 2, 3, 4, 5}, innerValids), + arrayOf(mem, []string{"-111", "-222", "-333", "-444", "-555"}, innerValids), }, { - arrayOf(mem, []int32{11, 12, 13, 14, 15}, mask[:5]), - arrayOf(mem, []string{"-1111", "-1222", "-1333", "-1444", "-1555"}, mask[:5]), + arrayOf(mem, []int32{11, 12, 13, 14, 15}, innerValids), + arrayOf(mem, []string{"-1111", "-1222", "-1333", "-1444", "-1555"}, innerValids), }, { - arrayOf(mem, []int32{21, 22, 23, 24, 25}, mask[:5]), - arrayOf(mem, []string{"-2111", "-2222", "-2333", "-2444", "-2555"}, mask[:5]), + arrayOf(mem, []int32{21, 22, 23, 24, 25}, innerValids), + arrayOf(mem, []string{"-2111", "-2222", "-2333", "-2444", "-2555"}, innerValids), }, { - arrayOf(mem, []int32{31, 32, 33, 34, 35}, mask[:5]), - arrayOf(mem, []string{"-3111", "-3222", "-3333", "-3444", "-3555"}, mask[:5]), + arrayOf(mem, []int32{31, 32, 33, 34, 35}, innerValids), + arrayOf(mem, []string{"-3111", "-3222", "-3333", "-3444", "-3555"}, innerValids), }, { - arrayOf(mem, []int32{41, 42, 43, 44, 45}, mask[:5]), - arrayOf(mem, []string{"-4111", "-4222", "-4333", "-4444", "-4555"}, mask[:5]), + arrayOf(mem, []int32{41, 42, 43, 44, 45}, innerValids), + arrayOf(mem, []string{"-4111", "-4222", "-4333", "-4444", "-4555"}, innerValids), }, }, []bool{true, false, false, true, true}), }, diff --git a/arrow/internal/arrjson/arrjson_test.go b/arrow/internal/arrjson/arrjson_test.go index 7e2f386fc..faeecfc01 100644 --- a/arrow/internal/arrjson/arrjson_test.go +++ b/arrow/internal/arrjson/arrjson_test.go @@ -948,7 +948,7 @@ func makeStructsWantJSONs() string { "isSigned": true, "bitWidth": 32 }, - "nullable": false, + "nullable": true, "children": [] }, { @@ -956,7 +956,7 @@ func makeStructsWantJSONs() string { "type": { "name": "utf8" }, - "nullable": false, + "nullable": true, "children": [] } ] diff --git a/arrow/ipc/cmd/arrow-ls/main_test.go b/arrow/ipc/cmd/arrow-ls/main_test.go index f90e4a800..0f3b5377c 100644 --- a/arrow/ipc/cmd/arrow-ls/main_test.go +++ b/arrow/ipc/cmd/arrow-ls/main_test.go @@ -59,7 +59,7 @@ records: 3 name: "structs", want: `schema: fields: 1 - - struct_nullable: type=struct, nullable + - struct_nullable: type=struct, nullable records: 2 `, }, @@ -221,7 +221,7 @@ records: 3 name: "structs", want: `schema: fields: 1 - - struct_nullable: type=struct, nullable + - struct_nullable: type=struct, nullable records: 2 `, }, @@ -230,7 +230,7 @@ records: 2 want: `version: V5 schema: fields: 1 - - struct_nullable: type=struct, nullable + - struct_nullable: type=struct, nullable records: 2 `, }, diff --git a/internal/json/json.go b/internal/json/json.go index b4c4c9f6e..7fdd0d863 100644 --- a/internal/json/json.go +++ b/internal/json/json.go @@ -20,6 +20,7 @@ package json import ( + "bytes" "io" "github.com/goccy/go-json" @@ -49,3 +50,7 @@ func NewDecoder(r io.Reader) *Decoder { func NewEncoder(w io.Writer) *Encoder { return json.NewEncoder(w) } + +func IsNullMessage(m RawMessage) bool { + return bytes.Equal(m, []byte("null")) +}