Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion arrow-flight/benches/flight.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ mod common;
use common::{TYPES, build_batch, start_server};

const ROWS: [usize; 2] = [8 * 1024, 64 * 1024];
const COLS: [usize; 2] = [1, 8];
const COLS: [usize; 4] = [1, 4, 8, 16];

fn bench_encode(c: &mut Criterion) {
let rt = tokio::runtime::Runtime::new().unwrap();
Expand Down
105 changes: 55 additions & 50 deletions arrow-flight/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
use crate::{FlightData, FlightDescriptor, SchemaAsIpc, error::Result};

use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray};
use arrow_ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteContext, IpcWriteOptions};

use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode};
use bytes::Bytes;
Expand Down Expand Up @@ -159,16 +159,12 @@ pub struct FlightDataEncoderBuilder {
dictionary_handling: DictionaryHandling,
}

/// Default target size for encoded [`FlightData`].
///
/// Note this value would normally be 4MB, but the size calculation is
/// somewhat inexact, so we set it to 2MB.
pub const GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES: usize = 2097152;

impl Default for FlightDataEncoderBuilder {
fn default() -> Self {
Self {
max_flight_data_size: GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES,
// Default target size for encoded FlightData. gRPC has a default max message
// size of 4MB; the size calculation is somewhat inexact, so we use 3MB to
// leave headroom and keep split batches within that limit.
max_flight_data_size: 3 * 1024 * 1024,
options: IpcWriteOptions::default(),
app_metadata: Bytes::new(),
schema: None,
Expand All @@ -185,7 +181,7 @@ impl FlightDataEncoderBuilder {
}

/// Set the (approximate) maximum size, in bytes, of the
/// [`FlightData`] produced by this encoder. Defaults to 2MB.
/// [`FlightData`] produced by this encoder. Defaults to 3MB.
///
/// Since there is often a maximum message size for gRPC messages
/// (typically around 4MB), this encoder splits up [`RecordBatch`]s
Expand Down Expand Up @@ -329,20 +325,14 @@ impl FlightDataEncoder {
}

/// Place the `FlightData` in the queue to send
#[inline]

@Rich-T-kid Rich-T-kid Jun 12, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The compiler very likely would have inlined this anyway, but I think it's worth adding this explicitly.

fn queue_message(&mut self, mut data: FlightData) {
// `take()` moves the pending descriptor out and leaves `None` behind, so the
// descriptor is attached only to the first message queued after it was set.
if let Some(descriptor) = self.descriptor.take() {
data.flight_descriptor = Some(descriptor);
}
// Messages are appended at the back of the queue.
self.queue.push_back(data);
}

/// Place the `FlightData` in the queue to send
fn queue_messages(&mut self, datas: impl IntoIterator<Item = FlightData>) {
for data in datas {
self.queue_message(data)
}
}

/// Encodes schema as a [`FlightData`] in self.queue.
/// Updates `self.schema` and returns the new schema
fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef {
Expand Down Expand Up @@ -378,14 +368,15 @@ impl FlightDataEncoder {
DictionaryHandling::Resend => batch,
DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?,
};

for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) {
let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?;

self.queue_messages(flight_dictionaries);
let (flight_dictionaries, flight_batch) = self
.encoder
.encode_batch(&batch, self.max_flight_data_size)?;
for dict in flight_dictionaries {
self.queue_message(dict);
}
self.queue_message(flight_batch);
}

Ok(())
}
}
Expand Down Expand Up @@ -671,27 +662,31 @@ fn prepare_schema_for_flight(
fn split_batch_for_grpc_response(
batch: RecordBatch,
max_flight_data_size: usize,
) -> Vec<RecordBatch> {
) -> impl Iterator<Item = RecordBatch> {
let size = batch
.columns()
.iter()
.map(|col| col.get_buffer_memory_size())
.sum::<usize>();

let n_batches =
(size / max_flight_data_size + usize::from(size % max_flight_data_size != 0)).max(1);
let rows_per_batch = (batch.num_rows() / n_batches).max(1);
let mut out = Vec::with_capacity(n_batches + 1);

let num_rows = batch.num_rows();
let rows_per_batch = if size == 0 {
num_rows
} else {
(max_flight_data_size * num_rows / size).max(1)
};
let mut offset = 0;
while offset < batch.num_rows() {
let length = (rows_per_batch).min(batch.num_rows() - offset);
out.push(batch.slice(offset, length));

offset += length;
}

out
std::iter::from_fn(move || {
if offset < num_rows {
let length = rows_per_batch.min(num_rows - offset);
let slice = batch.slice(offset, length);
offset += length;
Some(slice)
} else {
None
}
})
}

/// The data needed to encode a stream of flight data, holding on to
Expand All @@ -704,7 +699,7 @@ struct FlightIpcEncoder {
options: IpcWriteOptions,
data_gen: IpcDataGenerator,
dictionary_tracker: DictionaryTracker,
compression_context: CompressionContext,
compression_context: IpcWriteContext,
}

impl FlightIpcEncoder {
Expand All @@ -713,7 +708,7 @@ impl FlightIpcEncoder {
options,
data_gen: IpcDataGenerator::default(),
dictionary_tracker: DictionaryTracker::new(error_on_replacement),
compression_context: CompressionContext::default(),
compression_context: IpcWriteContext::default(),
}
}

Expand All @@ -724,15 +719,22 @@ impl FlightIpcEncoder {

/// Convert a `RecordBatch` to a Vec of `FlightData` representing
/// dictionaries and a `FlightData` representing the batch
fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec<FlightData>, FlightData)> {
fn encode_batch(
&mut self,
batch: &RecordBatch,
max_flight_data_size: usize,
) -> Result<(impl Iterator<Item = FlightData> + 'static, FlightData)> {
self.compression_context
.scratch
.reserve(max_flight_data_size);
let (encoded_dictionaries, encoded_batch) = self.data_gen.encode(
batch,
&mut self.dictionary_tracker,
&self.options,
&mut self.compression_context,
)?;

let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
let flight_dictionaries = encoded_dictionaries.into_iter().map(|e| e.into());
let flight_batch = encoded_batch.into();

Ok((flight_dictionaries, flight_batch))
Expand Down Expand Up @@ -1813,7 +1815,7 @@ mod tests {
) -> (Vec<FlightData>, FlightData) {
let data_gen = IpcDataGenerator::default();
let mut dictionary_tracker = DictionaryTracker::new(false);
let mut compression_context = CompressionContext::default();
let mut compression_context = IpcWriteContext::default();

let (encoded_dictionaries, encoded_batch) = data_gen
.encode(
Expand All @@ -1838,7 +1840,8 @@ mod tests {
let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
.expect("cannot create record batch");
let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
let split: Vec<_> =
split_batch_for_grpc_response(batch.clone(), max_flight_data_size).collect();
assert_eq!(split.len(), 1);
assert_eq!(batch, split[0]);

Expand All @@ -1848,8 +1851,9 @@ mod tests {
let c = UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::<Vec<_>>());
let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
.expect("cannot create record batch");
let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
assert_eq!(split.len(), 3);
let split: Vec<_> =
split_batch_for_grpc_response(batch.clone(), max_flight_data_size).collect();
assert_eq!(split.len(), 2);
assert_eq!(
split.iter().map(|batch| batch.num_rows()).sum::<usize>(),
n_rows
Expand All @@ -1861,14 +1865,14 @@ mod tests {

#[test]
fn test_split_batch_for_grpc_response_sizes() {
// 2000 8 byte entries into 2k pieces: 8 chunks of 250 rows
verify_split(2000, 2 * 1024, vec![250, 250, 250, 250, 250, 250, 250, 250]);
// 2000 8 byte entries into 2k pieces: fill to limit, last chunk is remainder
verify_split(2000, 2 * 1024, vec![256, 256, 256, 256, 256, 256, 256, 208]);

// 2000 8 byte entries into 4k pieces: 4 chunks of 500 rows
verify_split(2000, 4 * 1024, vec![500, 500, 500, 500]);
// 2000 8 byte entries into 4k pieces: fill to limit, last chunk is remainder
verify_split(2000, 4 * 1024, vec![512, 512, 512, 464]);

// 2023 8 byte entries into 3k pieces does not divide evenly
verify_split(2023, 3 * 1024, vec![337, 337, 337, 337, 337, 337, 1]);
verify_split(2023, 3 * 1024, vec![384, 384, 384, 384, 384, 103]);

// 10 8 byte entries into 1 byte pieces means each row gets its own batch
verify_split(10, 1, vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1]);
Expand All @@ -1892,7 +1896,8 @@ mod tests {

let input_rows = batch.num_rows();

let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes);
let split: Vec<_> =
split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes).collect();
let sizes: Vec<_> = split.iter().map(RecordBatch::num_rows).collect();
let output_rows: usize = sizes.iter().sum();

Expand All @@ -1918,7 +1923,7 @@ mod tests {
])
.unwrap();

verify_encoded_split(batch, 120).await;
verify_encoded_split(batch, 152).await;
}

#[tokio::test]
Expand Down Expand Up @@ -1981,7 +1986,7 @@ mod tests {

let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();

verify_encoded_split(batch, 56).await;
verify_encoded_split(batch, 163).await;
}

#[tokio::test]
Expand Down
4 changes: 2 additions & 2 deletions arrow-flight/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use std::sync::Arc;
use arrow_array::{ArrayRef, RecordBatch};
use arrow_buffer::Buffer;
use arrow_ipc::convert::fb_to_schema;
use arrow_ipc::writer::CompressionContext;
use arrow_ipc::writer::IpcWriteContext;
use arrow_ipc::{reader, root_as_message, writer, writer::IpcWriteOptions};
use arrow_schema::{ArrowError, Schema, SchemaRef};

Expand Down Expand Up @@ -92,7 +92,7 @@ pub fn batches_to_flight_data(

let data_gen = writer::IpcDataGenerator::default();
let mut dictionary_tracker = writer::DictionaryTracker::new(false);
let mut compression_context = CompressionContext::default();
let mut compression_context = IpcWriteContext::default();

for batch in batches.iter() {
let (encoded_dictionaries, encoded_batch) = data_gen.encode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use arrow::{
datatypes::SchemaRef,
ipc::{
self, reader,
writer::{self, CompressionContext},
writer::{self, IpcWriteContext},
},
record_batch::RecordBatch,
};
Expand Down Expand Up @@ -95,7 +95,7 @@ async fn upload_data(

let mut original_data_iter = original_data.iter().enumerate();

let mut compression_context = CompressionContext::default();
let mut compression_context = IpcWriteContext::default();

if let Some((counter, first_batch)) = original_data_iter.next() {
let metadata = counter.to_string().into_bytes();
Expand Down Expand Up @@ -159,7 +159,7 @@ async fn send_batch(
batch: &RecordBatch,
options: &writer::IpcWriteOptions,
dictionary_tracker: &mut writer::DictionaryTracker,
compression_context: &mut CompressionContext,
compression_context: &mut IpcWriteContext,
) -> Result {
let data_gen = writer::IpcDataGenerator::default();

Expand Down
28 changes: 15 additions & 13 deletions arrow-ipc/src/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,24 @@
use crate::CompressionType;
use arrow_buffer::Buffer;
use arrow_schema::ArrowError;
use flatbuffers::FlatBufferBuilder;

const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
const LENGTH_OF_PREFIX_DATA: i64 = 8;

/// Additional context that may be needed for compression.
///
/// In the case of zstd, this will contain the zstd context, which can be reused between subsequent
/// compression calls to avoid the performance overhead of initialising a new context for every
/// compression.
/// - The flatbuffer builder (`fbb`) is reset and reused across calls.
/// - The zstd compressor (when enabled) is kept alive to avoid re-initialisation overhead.
#[derive(Default)]
pub struct CompressionContext {
pub struct IpcWriteContext {
#[cfg(feature = "zstd")]
compressor: Option<zstd::bulk::Compressor<'static>>,
pub(crate) fbb: FlatBufferBuilder<'static>,
/// Scratch buffer for the IPC arrow data body. When set by the caller before
/// encode(), the existing allocation is reused instead of creating a fresh Vec.
pub scratch: Vec<u8>,
}

impl CompressionContext {
impl IpcWriteContext {
#[cfg(feature = "zstd")]
fn zstd_compressor(&mut self) -> &mut zstd::bulk::Compressor<'static> {
self.compressor.get_or_insert_with(|| {
Expand All @@ -43,9 +45,9 @@ impl CompressionContext {
}
}

impl std::fmt::Debug for CompressionContext {
impl std::fmt::Debug for IpcWriteContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut ds = f.debug_struct("CompressionContext");
let mut ds = f.debug_struct("IpcWriteContext");

#[cfg(feature = "zstd")]
ds.field(
Expand Down Expand Up @@ -143,7 +145,7 @@ impl CompressionCodec {
&self,
input: &[u8],
output: &mut Vec<u8>,
context: &mut CompressionContext,
context: &mut IpcWriteContext,
) -> Result<usize, ArrowError> {
let uncompressed_data_len = input.len();
let original_output_len = output.len();
Expand Down Expand Up @@ -209,7 +211,7 @@ impl CompressionCodec {
&self,
input: &[u8],
output: &mut Vec<u8>,
context: &mut CompressionContext,
context: &mut IpcWriteContext,
) -> Result<(), ArrowError> {
match self {
CompressionCodec::Lz4Frame => compress_lz4(input, output),
Expand Down Expand Up @@ -278,7 +280,7 @@ fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, A
fn compress_zstd(
input: &[u8],
output: &mut Vec<u8>,
context: &mut CompressionContext,
context: &mut IpcWriteContext,
) -> Result<(), ArrowError> {
let result = context.zstd_compressor().compress(input)?;
output.extend_from_slice(&result);
Expand All @@ -290,7 +292,7 @@ fn compress_zstd(
fn compress_zstd(
_input: &[u8],
_output: &mut Vec<u8>,
_context: &mut CompressionContext,
_context: &mut IpcWriteContext,
) -> Result<(), ArrowError> {
Err(ArrowError::InvalidArgumentError(
"zstd IPC compression requires the zstd feature".to_string(),
Expand Down
Loading
Loading