diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index bb85d7035a4c..aeb6f56f7a5e 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -107,7 +107,7 @@ fn concat_dictionaries( .inspect(|d| output_len += d.len()) .collect(); - if !should_merge_dictionary_values::(&dictionaries, output_len).0 { + if !should_merge_dictionary_values::(&dictionaries).0 { return concat_fallback(arrays, Capacities::Array(output_len)); } @@ -1304,11 +1304,8 @@ mod tests { let actual = collect_string_dictionary(dictionary); assert_eq!(actual, expected); - // Should have concatenated inputs together - assert_eq!( - dictionary.values().len(), - input_1.values().len() + input_2.values().len(), - ) + // Should have merged inputs together (deduplicated) + assert_eq!(dictionary.values().len(), 6) } #[test] @@ -1483,6 +1480,42 @@ mod tests { assert!(!new.values().to_data().ptr_eq(&com.values().to_data())); } + #[test] + fn concat_dictionary_batches_deduplicates_values() { + // Reproducer for https://github.com/apache/arrow-rs/issues/10160 + // Concatenating batches with overlapping dictionary values must + // deduplicate the dictionary entries, not naively concatenate them. + let schema = Arc::new(Schema::new(vec![Field::new( + "symbol", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + )])); + + let batch_0 = { + let dict = DictionaryArray::::try_new( + Int32Array::from(vec![0, 1, 2, 0]), + Arc::new(StringArray::from(vec!["alpha", "beta", "gamma"])), + ) + .unwrap(); + RecordBatch::try_new(schema.clone(), vec![Arc::new(dict)]).unwrap() + }; + + let batch_1 = { + let dict = DictionaryArray::::try_new( + Int32Array::from(vec![2, 1, 0, 2]), + Arc::new(StringArray::from(vec!["gamma", "alpha", "beta"])), + ) + .unwrap(); + RecordBatch::try_new(schema.clone(), vec![Arc::new(dict)]).unwrap() + }; + + let merged = concat_batches(&schema, &[batch_0, batch_1]).unwrap(); + let column = merged.column(0).as_dictionary::(); + + // All 3 unique dictionary values should be preserved (no duplicates) + assert_eq!(column.values().len(), 3); + } + #[test] fn concat_record_batches() { let schema = Arc::new(Schema::new(vec![ diff --git a/arrow-select/src/dictionary.rs b/arrow-select/src/dictionary.rs index 5b32f4e761f8..371f21318b8e 100644 --- a/arrow-select/src/dictionary.rs +++ b/arrow-select/src/dictionary.rs @@ -169,18 +169,12 @@ fn bytes_ptr_eq(a: &dyn Array, b: &dyn Array) -> bool { /// A type-erased function that compares two array for pointer equality type PtrEq = fn(&dyn Array, &dyn Array) -> bool; -/// A weak heuristic of whether to merge dictionary values that aims to only -/// perform the expensive merge computation when it is likely to yield at least -/// some return over the naive approach used by MutableArrayData -/// -/// `len` is the total length of the merged output -/// /// Returns `(should_merge, has_overflow)` where: -/// - `should_merge`: whether dictionary values should be merged +/// - `should_merge`: whether dictionary values should be merged (`true` when dictionaries +/// have different backing arrays, to avoid duplicate values) /// - `has_overflow`: whether the combined dictionary values would overflow the key type pub(crate) fn should_merge_dictionary_values( dictionaries: &[&DictionaryArray], - len: usize, ) -> (bool, bool) { use DataType::*; let first_values = dictionaries[0].values().as_ref(); @@ -202,22 +196,18 @@ pub(crate) fn should_merge_dictionary_values( }; let mut single_dictionary = true; - let mut total_values = first_values.len(); for dict in dictionaries.iter().skip(1) { - let values = dict.values().as_ref(); - total_values += values.len(); if single_dictionary { + let values = dict.values().as_ref(); single_dictionary = ptr_eq(first_values, values) } } - let overflow = K::Native::from_usize(total_values).is_none(); - let values_exceed_length = total_values >= len; + let overflow = + K::Native::from_usize(dictionaries.iter().map(|d| d.values().len()).sum::()) + .is_none(); - ( - !single_dictionary && (overflow || values_exceed_length), - overflow, - ) + (!single_dictionary, overflow) } /// Given an array of dictionaries and an optional key mask compute a values array diff --git a/arrow-select/src/interleave.rs b/arrow-select/src/interleave.rs index c3d47980e3c3..a578c2f535b4 100644 --- a/arrow-select/src/interleave.rs +++ b/arrow-select/src/interleave.rs @@ -252,8 +252,7 @@ fn interleave_dictionaries( indices: &[(usize, usize)], ) -> Result { let dictionaries: Vec<_> = arrays.iter().map(|x| x.as_dictionary::()).collect(); - let (should_merge, has_overflow) = - should_merge_dictionary_values::(&dictionaries, indices.len()); + let (should_merge, has_overflow) = should_merge_dictionary_values::(&dictionaries); if !should_merge { return if has_overflow { interleave_fallback(arrays, indices) @@ -913,11 +912,11 @@ mod tests { let a = DictionaryArray::::from_iter(["a", "b", "c", "a", "b"]); let b = DictionaryArray::::from_iter(["a", "c", "a", "c", "a"]); - // Should not recompute dictionary + // Should merge dictionaries (deduplicate values) let values = interleave(&[&a, &b], &[(0, 2), (0, 2), (0, 2), (1, 0), (1, 1), (0, 1)]).unwrap(); let v = values.as_dictionary::(); - assert_eq!(v.values().len(), 5); + assert_eq!(v.values().len(), 3); let vc = v.downcast_dict::().unwrap(); let collected: Vec<_> = vc.into_iter().map(Option::unwrap).collect();