From 976dab313f8fe1af9aed1bcf0c0eed0d14f541c8 Mon Sep 17 00:00:00 2001 From: devanbenz Date: Tue, 16 Jun 2026 17:57:21 -0500 Subject: [PATCH] feat: Adds product aggregate compute kernel The C++ Arrow implementation has a product op for aggregate compute kernel. This commit adds parity to the C++ impl --- arrow-arith/src/aggregate.rs | 145 +++++++++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs index 8745c779ce0a..2e5713dd7d06 100644 --- a/arrow-arith/src/aggregate.rs +++ b/arrow-arith/src/aggregate.rs @@ -78,6 +78,36 @@ impl NumericAccumulator for SumAccumulator { } } +#[derive(Clone, Copy)] +struct ProductAccumulator { + product: T, +} + +impl Default for ProductAccumulator { + fn default() -> Self { + Self { product: T::ONE } + } +} + +impl NumericAccumulator for ProductAccumulator { + fn accumulate(&mut self, value: T) { + self.product = self.product.mul_wrapping(value); + } + + fn accumulate_nullable(&mut self, value: T, valid: bool) { + let product = self.product; + self.product = select(valid, product.mul_wrapping(value), product) + } + + fn merge(&mut self, other: Self) { + self.product = self.product.mul_wrapping(other.product); + } + + fn finish(&mut self) -> T { + self.product + } +} + #[derive(Clone, Copy)] struct MinAccumulator { min: T, @@ -914,6 +944,60 @@ pub fn sum(array: &PrimitiveArray) -> Option aggregate::>(array) } +/// Returns the product of values in the primitive array. +/// +/// Returns `None` if the array is empty or only contains null values. +/// +/// This doesn't detect overflow in release mode by default. Once overflowing, the result will +/// wrap around. For an overflow-checking variant, use [`product_checked`] instead. +pub fn product(array: &PrimitiveArray) -> Option { + aggregate::>(array) +} + +/// Returns the product of values in the primitive array. +/// +/// Returns `Ok(None)` if the array is empty or only contains null values. +/// +/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, +/// use [`product`] instead. +pub fn product_checked( + array: &PrimitiveArray, +) -> Result, ArrowError> { + let null_count = array.null_count(); + + if null_count == array.len() { + return Ok(None); + } + + let data: &[T::Native] = array.values(); + + match array.nulls() { + None => { + let product = data.iter().try_fold(T::Native::ONE, |accumulator, value| { + accumulator.mul_checked(*value) + })?; + + Ok(Some(product)) + } + Some(nulls) => { + let mut product = T::Native::ONE; + + try_for_each_valid_idx( + nulls.len(), + nulls.offset(), + nulls.null_count(), + Some(nulls.validity()), + |idx| { + unsafe { product = product.mul_checked(array.value_unchecked(idx))? }; + Ok::<_, ArrowError>(()) + }, + )?; + + Ok(Some(product)) + } + } +} + /// Returns the minimum value in the array, according to the natural order. /// For floating point arrays any NaN values are considered to be greater than any other non-null value /// @@ -963,6 +1047,67 @@ mod tests { assert_eq!(16.5, sum(&a).unwrap()); } + #[test] + fn test_primitive_array_product() { + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(120, product(&a).unwrap()); + } + + #[test] + fn test_primitive_array_float_product() { + let a = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]); + assert_eq!(120.0, product(&a).unwrap()); + } + + #[test] + fn test_primitive_array_product_with_nulls() { + let a = Int32Array::from(vec![None, Some(2), Some(3), None, Some(5)]); + assert_eq!(30, product(&a).unwrap()); + } + + #[test] + fn test_primitive_array_product_all_nulls() { + let a = Int32Array::from(vec![None, None, None]); + assert_eq!(None, product(&a)); + } + + #[test] + fn test_primitive_array_product_empty() { + let a = Int32Array::from(Vec::::new()); + assert_eq!(None, product(&a)); + } + + #[test] + fn test_primitive_array_product_checked() { + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(120, product_checked(&a).unwrap().unwrap()); + } + + #[test] + fn test_primitive_array_product_checked_with_nulls() { + let a = Int32Array::from(vec![None, Some(2), Some(3), None, Some(5)]); + assert_eq!(30, product_checked(&a).unwrap().unwrap()); + } + + #[test] + fn test_primitive_array_product_checked_all_nulls() { + let a = Int32Array::from(vec![None, None, None]); + assert_eq!(None, product_checked(&a).unwrap()); + } + + #[test] + fn test_product_overflow() { + let a = Int32Array::from(vec![i32::MAX, 2]); + // wrapping variant silently overflows + assert_eq!(product(&a).unwrap(), -2); + } + + #[test] + fn test_product_checked_overflow() { + let a = Int32Array::from(vec![i32::MAX, 2]); + product_checked(&a).expect_err("overflow should be detected"); + } + #[test] fn test_primitive_array_sum_with_nulls() { let a = Int32Array::from(vec![None, Some(2), Some(3), None, Some(5)]);