Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions arrow-select/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ fn concat_dictionaries<K: ArrowDictionaryKeyType>(
.inspect(|d| output_len += d.len())
.collect();

if !should_merge_dictionary_values::<K>(&dictionaries, output_len).0 {
if !should_merge_dictionary_values::<K>(&dictionaries).0 {
return concat_fallback(arrays, Capacities::Array(output_len));
}

Expand Down Expand Up @@ -1304,11 +1304,8 @@ mod tests {
let actual = collect_string_dictionary(dictionary);
assert_eq!(actual, expected);

// Should have concatenated inputs together
assert_eq!(
dictionary.values().len(),
input_1.values().len() + input_2.values().len(),
)
// Should have merged inputs together (deduplicated)
assert_eq!(dictionary.values().len(), 6)
}

#[test]
Expand Down Expand Up @@ -1483,6 +1480,42 @@ mod tests {
assert!(!new.values().to_data().ptr_eq(&com.values().to_data()));
}

#[test]
fn concat_dictionary_batches_deduplicates_values() {
// Reproducer for https://github.com/apache/arrow-rs/issues/10160
// Concatenating batches with overlapping dictionary values must
// deduplicate the dictionary entries, not naively concatenate them.
let schema = Arc::new(Schema::new(vec![Field::new(
"symbol",
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
false,
)]));

let batch_0 = {
let dict = DictionaryArray::<Int32Type>::try_new(
Int32Array::from(vec![0, 1, 2, 0]),
Arc::new(StringArray::from(vec!["alpha", "beta", "gamma"])),
)
.unwrap();
RecordBatch::try_new(schema.clone(), vec![Arc::new(dict)]).unwrap()
};

let batch_1 = {
let dict = DictionaryArray::<Int32Type>::try_new(
Int32Array::from(vec![2, 1, 0, 2]),
Arc::new(StringArray::from(vec!["gamma", "alpha", "beta"])),
)
.unwrap();
RecordBatch::try_new(schema.clone(), vec![Arc::new(dict)]).unwrap()
};

let merged = concat_batches(&schema, &[batch_0, batch_1]).unwrap();
let column = merged.column(0).as_dictionary::<Int32Type>();

// All 3 unique dictionary values should be preserved (no duplicates)
assert_eq!(column.values().len(), 3);
}

#[test]
fn concat_record_batches() {
let schema = Arc::new(Schema::new(vec![
Expand Down
24 changes: 7 additions & 17 deletions arrow-select/src/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,12 @@ fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
/// A type-erased function that compares two array for pointer equality
type PtrEq = fn(&dyn Array, &dyn Array) -> bool;

/// A weak heuristic of whether to merge dictionary values that aims to only
/// perform the expensive merge computation when it is likely to yield at least
/// some return over the naive approach used by MutableArrayData
///
/// `len` is the total length of the merged output
///
/// Returns `(should_merge, has_overflow)` where:
/// - `should_merge`: whether dictionary values should be merged
/// - `should_merge`: whether dictionary values should be merged (`true` when dictionaries
/// have different backing arrays, to avoid duplicate values)
/// - `has_overflow`: whether the combined dictionary values would overflow the key type
pub(crate) fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
dictionaries: &[&DictionaryArray<K>],
len: usize,
) -> (bool, bool) {
use DataType::*;
let first_values = dictionaries[0].values().as_ref();
Expand All @@ -202,22 +196,18 @@ pub(crate) fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
};

let mut single_dictionary = true;
let mut total_values = first_values.len();
for dict in dictionaries.iter().skip(1) {
let values = dict.values().as_ref();
total_values += values.len();
if single_dictionary {
let values = dict.values().as_ref();
single_dictionary = ptr_eq(first_values, values)
}
}

let overflow = K::Native::from_usize(total_values).is_none();
let values_exceed_length = total_values >= len;
let overflow =
K::Native::from_usize(dictionaries.iter().map(|d| d.values().len()).sum::<usize>())
.is_none();

(
!single_dictionary && (overflow || values_exceed_length),
overflow,
)
(!single_dictionary, overflow)
}

/// Given an array of dictionaries and an optional key mask compute a values array
Expand Down
7 changes: 3 additions & 4 deletions arrow-select/src/interleave.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,7 @@ fn interleave_dictionaries<K: ArrowDictionaryKeyType>(
indices: &[(usize, usize)],
) -> Result<ArrayRef, ArrowError> {
let dictionaries: Vec<_> = arrays.iter().map(|x| x.as_dictionary::<K>()).collect();
let (should_merge, has_overflow) =
should_merge_dictionary_values::<K>(&dictionaries, indices.len());
let (should_merge, has_overflow) = should_merge_dictionary_values::<K>(&dictionaries);
if !should_merge {
return if has_overflow {
interleave_fallback(arrays, indices)
Expand Down Expand Up @@ -913,11 +912,11 @@ mod tests {
let a = DictionaryArray::<Int32Type>::from_iter(["a", "b", "c", "a", "b"]);
let b = DictionaryArray::<Int32Type>::from_iter(["a", "c", "a", "c", "a"]);

// Should not recompute dictionary
// Should merge dictionaries (deduplicate values)
let values =
interleave(&[&a, &b], &[(0, 2), (0, 2), (0, 2), (1, 0), (1, 1), (0, 1)]).unwrap();
let v = values.as_dictionary::<Int32Type>();
assert_eq!(v.values().len(), 5);
assert_eq!(v.values().len(), 3);

let vc = v.downcast_dict::<StringArray>().unwrap();
let collected: Vec<_> = vc.into_iter().map(Option::unwrap).collect();
Expand Down