cdata: reference-count the imported C record batch reader

What this patch does (as shown by the hunks below):

  * initReader sets the reader's refCount to 1 and replaces the single
    runtime.SetFinalizer callback with two runtime.AddCleanup registrations.
    One releases and frees the C ArrowArrayStream struct, the other releases
    and frees the C ArrowArray struct. Each callback receives only the C
    pointer it owns, and the returned handles are stored in cleanUps.
  * Retain increments refCount.
  * Release decrements refCount and asserts through debug.Assert that the
    count has not gone negative. When the count reaches zero it stops both
    cleanups, releases the current batch if there is one, and then releases
    and frees the stream and array structs itself.
  * ImportCRecordReader calls out.Release() before returning the error when
    initReader fails.

NOTE(review): the removed finalizer also released r.cur; the new cleanup
callbacks do not touch the current batch, so a reader that is garbage
collected without an explicit Release no longer releases cur through this
path. Confirm that is intended.

NOTE(review): after the count reaches zero, Release leaves n.stream, n.arr and
n.cur set. Because Retain only increments the counter, a Retain followed by a
Release on an already-released reader would run the teardown a second time.
Consider clearing the fields or guarding against it.

NOTE(review): line breaks, blank lines and indentation in this patch were
restored from the hunk line counts and gofmt conventions; verify with
"git apply --check" before use.

diff --git a/arrow/cdata/cdata.go b/arrow/cdata/cdata.go
index 352dfd9a..68c6f7e3 100644
--- a/arrow/cdata/cdata.go
+++ b/arrow/cdata/cdata.go
@@ -47,12 +47,14 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
+	"sync/atomic"
 	"syscall"
 	"unsafe"
 
 	"github.com/apache/arrow-go/v18/arrow"
 	"github.com/apache/arrow-go/v18/arrow/array"
 	"github.com/apache/arrow-go/v18/arrow/bitutil"
+	"github.com/apache/arrow-go/v18/arrow/internal/debug"
 	"github.com/apache/arrow-go/v18/arrow/memory"
 )
 
@@ -903,18 +905,19 @@ func importCArrayAsType(arr *CArrowArray, dt arrow.DataType) (imp *cimporter, er
 }
 
 func initReader(rdr *nativeCRecordBatchReader, stream *CArrowArrayStream) error {
+	rdr.refCount.Store(1)
 	rdr.stream = C.get_stream()
 	C.ArrowArrayStreamMove(stream, rdr.stream)
 	rdr.arr = C.get_arr()
-	runtime.SetFinalizer(rdr, func(r *nativeCRecordBatchReader) {
-		if r.cur != nil {
-			r.cur.Release()
-		}
-		C.ArrowArrayStreamRelease(r.stream)
-		C.ArrowArrayRelease(r.arr)
-		C.free(unsafe.Pointer(r.stream))
-		C.free(unsafe.Pointer(r.arr))
-	})
+
+	rdr.cleanUps[0] = runtime.AddCleanup(rdr, func(s *CArrowArrayStream) {
+		C.ArrowArrayStreamRelease(s)
+		C.free(unsafe.Pointer(s))
+	}, rdr.stream)
+	rdr.cleanUps[1] = runtime.AddCleanup(rdr, func(a *CArrowArray) {
+		C.ArrowArrayRelease(a)
+		C.free(unsafe.Pointer(a))
+	}, rdr.arr)
 
 	var sc CArrowSchema
 	errno := C.stream_get_schema(rdr.stream, &sc)
@@ -940,12 +943,32 @@ type nativeCRecordBatchReader struct {
 
 	cur arrow.RecordBatch
 	err error
+
+	refCount atomic.Int64
+	cleanUps [2]runtime.Cleanup
+}
+
+func (n *nativeCRecordBatchReader) Retain() {
+	n.refCount.Add(1)
 }
 
-// No need to implement retain and release here as we used runtime.SetFinalizer when constructing
-// the reader to free up the ArrowArrayStream memory when the garbage collector cleans it up.
-func (n *nativeCRecordBatchReader) Retain()  {}
-func (n *nativeCRecordBatchReader) Release() {}
+func (n *nativeCRecordBatchReader) Release() {
+	rc := n.refCount.Add(-1)
+	debug.Assert(rc >= 0, "too many releases")
+
+	if rc == 0 {
+		n.cleanUps[0].Stop()
+		n.cleanUps[1].Stop()
+		if n.cur != nil {
+			n.cur.Release()
+		}
+
+		C.ArrowArrayStreamRelease(n.stream)
+		C.ArrowArrayRelease(n.arr)
+		C.free(unsafe.Pointer(n.stream))
+		C.free(unsafe.Pointer(n.arr))
+	}
+}
 
 func (n *nativeCRecordBatchReader) Err() error                     { return n.err }
 func (n *nativeCRecordBatchReader) RecordBatch() arrow.RecordBatch { return n.cur }
diff --git a/arrow/cdata/interface.go b/arrow/cdata/interface.go
index f776d7f7..a3690662 100644
--- a/arrow/cdata/interface.go
+++ b/arrow/cdata/interface.go
@@ -189,8 +189,10 @@ func ImportCArrayStream(stream *CArrowArrayStream, schema *arrow.Schema) arrio.R
 func ImportCRecordReader(stream *CArrowArrayStream, schema *arrow.Schema) (arrio.Reader, error) {
 	out := &nativeCRecordBatchReader{schema: schema}
 	if err := initReader(out, stream); err != nil {
+		out.Release()
 		return nil, err
 	}
+
 	return out, nil
 }
 