Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions arrow-arith/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,36 @@ impl<T: ArrowNativeTypeOp> NumericAccumulator<T> for SumAccumulator<T> {
}
}

/// Accumulator state for computing the product of a sequence of values.
///
/// Holds only the running product; `Default` starts it at `T::ONE`.
#[derive(Clone, Copy)]
struct ProductAccumulator<T: ArrowNativeTypeOp> {
/// Product of every value folded in so far.
product: T,
}

impl<T: ArrowNativeTypeOp> Default for ProductAccumulator<T> {
    /// Creates an accumulator whose running product is `T::ONE`, so that the
    /// first multiplication yields the first value unchanged.
    fn default() -> Self {
        let product = T::ONE;
        Self { product }
    }
}

impl<T: ArrowNativeTypeOp> NumericAccumulator<T> for ProductAccumulator<T> {
    /// Multiplies `value` into the running product using `mul_wrapping`.
    fn accumulate(&mut self, value: T) {
        let updated = self.product.mul_wrapping(value);
        self.product = updated;
    }

    /// Folds in a possibly-null `value`.
    ///
    /// The multiplied candidate is always computed; `select` then picks between
    /// it and the untouched product based on `valid`.
    fn accumulate_nullable(&mut self, value: T, valid: bool) {
        let current = self.product;
        let candidate = current.mul_wrapping(value);
        // NOTE(review): `select` is defined elsewhere in this file; presumably it
        // returns its second argument when `valid` is true and the third otherwise.
        self.product = select(valid, candidate, current);
    }

    /// Combines the partial product held by `other` into this accumulator.
    fn merge(&mut self, other: Self) {
        let Self { product: rhs } = other;
        self.product = self.product.mul_wrapping(rhs);
    }

    /// Returns the product accumulated so far.
    fn finish(&mut self) -> T {
        self.product
    }
}

#[derive(Clone, Copy)]
struct MinAccumulator<T: ArrowNativeTypeOp> {
min: T,
Expand Down Expand Up @@ -914,6 +944,60 @@ pub fn sum<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native>
aggregate::<T::Native, T, SumAccumulator<T::Native>>(array)
}

/// Returns the product of the non-null values in the primitive array.
///
/// Returns `None` if the array is empty or only contains null values.
///
/// Overflow is never detected: the accumulator used here multiplies with
/// `mul_wrapping`, so an overflowing integer product wraps around (in debug builds
/// as well as release builds). For an overflow-checking variant, use
/// [`product_checked`] instead.
pub fn product<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native> {
aggregate::<T::Native, T, ProductAccumulator<T::Native>>(array)
}

/// Multiplies together every non-null value of the primitive array, checking
/// each multiplication for overflow.
///
/// Returns `Ok(None)` if the array is empty or only contains null values.
///
/// # Errors
///
/// Returns an `Err` as soon as an intermediate multiplication overflows. For a
/// non-overflow-checking variant, use [`product`] instead.
pub fn product_checked<T: ArrowNumericType>(
    array: &PrimitiveArray<T>,
) -> Result<Option<T::Native>, ArrowError> {
    // An empty array also satisfies this comparison (0 == 0).
    if array.null_count() == array.len() {
        return Ok(None);
    }

    // Start from the multiplicative identity.
    let mut running = T::Native::ONE;

    if let Some(validity) = array.nulls() {
        // A null buffer is present: visit only the indices marked valid.
        try_for_each_valid_idx(
            validity.len(),
            validity.offset(),
            validity.null_count(),
            Some(validity.validity()),
            |idx| {
                // SAFETY: `idx` must be in bounds for `array`. This relies on
                // `try_for_each_valid_idx` only yielding indices below the `len`
                // it is given — NOTE(review): the helper is defined elsewhere,
                // confirm.
                let value = unsafe { array.value_unchecked(idx) };
                running = running.mul_checked(value)?;
                Ok::<_, ArrowError>(())
            },
        )?;
    } else {
        // No null buffer: every slot participates in the product.
        let values: &[T::Native] = array.values();
        for value in values {
            running = running.mul_checked(*value)?;
        }
    }

    Ok(Some(running))
}

/// Returns the minimum value in the array, according to the natural order.
/// For floating point arrays any NaN values are considered to be greater than any other non-null value
///
Expand Down Expand Up @@ -963,6 +1047,67 @@ mod tests {
assert_eq!(16.5, sum(&a).unwrap());
}

#[test]
fn test_primitive_array_product() {
    // 1 * 2 * 3 * 4 * 5 == 120
    let values = Int32Array::from(vec![1, 2, 3, 4, 5]);
    let result = product(&values);
    assert_eq!(120, result.unwrap());
}

#[test]
fn test_primitive_array_float_product() {
    // Same factors as the integer case, as floats.
    let values = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    let result = product(&values);
    assert_eq!(120.0, result.unwrap());
}

#[test]
fn test_primitive_array_product_with_nulls() {
    // Null slots must be skipped: 2 * 3 * 5 == 30.
    let sparse = Int32Array::from(vec![None, Some(2), Some(3), None, Some(5)]);
    let result = product(&sparse);
    assert_eq!(30, result.unwrap());
}

#[test]
fn test_primitive_array_product_all_nulls() {
    // With no valid value there is nothing to multiply.
    let all_null = Int32Array::from(vec![None, None, None]);
    let result = product(&all_null);
    assert_eq!(None, result);
}

#[test]
fn test_primitive_array_product_empty() {
    // A zero-length array yields `None`.
    let no_values: Vec<i32> = Vec::new();
    let empty = Int32Array::from(no_values);
    assert_eq!(None, product(&empty));
}

#[test]
fn test_primitive_array_product_checked() {
    // No overflow here, so the checked variant returns the plain product.
    let values = Int32Array::from(vec![1, 2, 3, 4, 5]);
    let result = product_checked(&values).unwrap();
    assert_eq!(120, result.unwrap());
}

#[test]
fn test_primitive_array_product_checked_with_nulls() {
    // Exercises the null-buffer path: 2 * 3 * 5 == 30.
    let sparse = Int32Array::from(vec![None, Some(2), Some(3), None, Some(5)]);
    let result = product_checked(&sparse).unwrap();
    assert_eq!(30, result.unwrap());
}

#[test]
fn test_primitive_array_product_checked_all_nulls() {
    // An all-null array is not an error; it simply has no product.
    let all_null = Int32Array::from(vec![None, None, None]);
    let result = product_checked(&all_null).unwrap();
    assert_eq!(None, result);
}

#[test]
fn test_product_overflow() {
    let overflowing = Int32Array::from(vec![i32::MAX, 2]);
    // The wrapping variant silently overflows: i32::MAX * 2 wraps to -2.
    let wrapped = product(&overflowing).unwrap();
    assert_eq!(wrapped, -2);
}

#[test]
fn test_product_checked_overflow() {
    // i32::MAX * 2 does not fit in an i32, so the checked variant must fail.
    let overflowing = Int32Array::from(vec![i32::MAX, 2]);
    let outcome = product_checked(&overflowing);
    outcome.expect_err("overflow should be detected");
}

#[test]
fn test_primitive_array_sum_with_nulls() {
let a = Int32Array::from(vec![None, Some(2), Some(3), None, Some(5)]);
Expand Down
Loading