diff --git a/pulsar-common/src/main/java/org/apache/pulsar/common/util/collections/Long2LongMap.java b/pulsar-common/src/main/java/org/apache/pulsar/common/util/collections/Long2LongMap.java new file mode 100644 index 0000000000000..6806fac7f305f --- /dev/null +++ b/pulsar-common/src/main/java/org/apache/pulsar/common/util/collections/Long2LongMap.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pulsar.common.util.collections; + +import java.util.function.LongUnaryOperator; + +/** + * A map with primitive {@code long} keys and primitive {@code long} values. + * + *

The default return value for missing keys is {@code 0}. Use {@link #getOrDefault(long, long)} + * or {@link #containsKey(long)} when {@code 0} is a valid mapped value. + */ +public interface Long2LongMap { + + @FunctionalInterface + interface EntryConsumer { + void accept(long key, long value); + } + + @FunctionalInterface + interface EntryPredicate { + boolean test(long key, long value); + } + + /** + * Returns the value for the given key, or {@code 0} if not present. + * + * @param key the key + * @return the mapped value, or {@code 0} + */ + long get(long key); + + /** + * Associates the given value with the given key. + * + * @param key the key + * @param value the value + * @return the previous value, or {@code 0} if there was no mapping + */ + long put(long key, long value); + + /** + * Removes the mapping for the given key. + * + * @param key the key + * @return the previous value, or {@code 0} if there was no mapping + */ + long remove(long key); + + /** + * Returns the value for the given key, or the specified default if not present. + * + * @param key the key + * @param defaultValue the default value to return if the key is absent + * @return the mapped value, or {@code defaultValue} + */ + long getOrDefault(long key, long defaultValue); + + /** + * If the key is not already present, computes its value using the given function and inserts it. + * + * @param key the key + * @param mappingFunction the function to compute a value + * @return the current (existing or computed) value + */ + long computeIfAbsent(long key, LongUnaryOperator mappingFunction); + + /** + * Returns {@code true} if this map contains the given key. + * + * @param key the key + * @return {@code true} if this map contains the key + */ + boolean containsKey(long key); + + /** + * Returns {@code true} if this map contains no entries. + */ + boolean isEmpty(); + + /** + * Returns the number of entries in this map. + */ + int size(); + + /** + * Removes all entries from this map. + */ + void clear(); + + /** + * Iterates over all entries, calling the consumer with primitive long keys and values. + * + * @param consumer the consumer to call for each entry + */ + void forEach(EntryConsumer consumer); + + /** + * Removes each entry that matches the predicate. + * + * @param predicate the predicate to test entries + * @return the number of removed entries + */ + int removeIf(EntryPredicate predicate); +} diff --git a/pulsar-common/src/main/java/org/apache/pulsar/common/util/collections/Long2LongOpenHashMap.java b/pulsar-common/src/main/java/org/apache/pulsar/common/util/collections/Long2LongOpenHashMap.java new file mode 100644 index 0000000000000..ffe8649ef5d25 --- /dev/null +++ b/pulsar-common/src/main/java/org/apache/pulsar/common/util/collections/Long2LongOpenHashMap.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pulsar.common.util.collections; + +import java.util.Arrays; +import java.util.function.LongUnaryOperator; + +/** + * Open-addressing hash map with primitive long keys and primitive long values. + * Uses linear probing and fibonacci hashing. + * Returns 0 for missing keys; use getOrDefault or containsKey when 0 is a valid mapped value. + * Not thread-safe. + */ +public class Long2LongOpenHashMap implements Long2LongMap { + + private static final float LOAD_FACTOR = 0.75f; + private static final int MIN_CAPACITY = 16; + + private long[] keys; + private long[] values; + private boolean[] used; + private int size; + private int capacity; + private int threshold; + + public Long2LongOpenHashMap() { + this(MIN_CAPACITY); + } + + public Long2LongOpenHashMap(int expectedItems) { + int cap = tableSizeFor(Math.max(MIN_CAPACITY, (int) (expectedItems / LOAD_FACTOR) + 1)); + keys = new long[cap]; + values = new long[cap]; + used = new boolean[cap]; + capacity = cap; + threshold = (int) (cap * LOAD_FACTOR); + } + + @Override + public long get(long key) { + int idx = indexOf(key); + return idx >= 0 ? values[idx] : 0; + } + + @Override + public long put(long key, long value) { + int idx = indexOf(key); + if (idx >= 0) { + long old = values[idx]; + values[idx] = value; + return old; + } + if (size >= threshold) { + rehash(capacity * 2); + } + insertNew(key, value); + return 0; + } + + @Override + public long remove(long key) { + int idx = indexOf(key); + if (idx < 0) { + return 0; + } + long old = values[idx]; + removeAt(idx); + return old; + } + + @Override + public long getOrDefault(long key, long defaultValue) { + int idx = indexOf(key); + return idx >= 0 ? values[idx] : defaultValue; + } + + @Override + public long computeIfAbsent(long key, LongUnaryOperator mappingFunction) { + int idx = indexOf(key); + if (idx >= 0) { + return values[idx]; + } + long value = mappingFunction.applyAsLong(key); + if (size >= threshold) { + rehash(capacity * 2); + } + insertNew(key, value); + return value; + } + + @Override + public boolean containsKey(long key) { + return indexOf(key) >= 0; + } + + @Override + public boolean isEmpty() { + return size == 0; + } + + @Override + public int size() { + return size; + } + + @Override + public void clear() { + if (size > 0) { + Arrays.fill(used, false); + size = 0; + } + } + + @Override + public void forEach(EntryConsumer consumer) { + for (int i = 0; i < capacity; i++) { + if (used[i]) { + consumer.accept(keys[i], values[i]); + } + } + } + + @Override + public int removeIf(EntryPredicate predicate) { + int removed = 0; + for (int i = 0; i < capacity;) { + if (!used[i]) { + i++; + continue; + } + if (predicate.test(keys[i], values[i])) { + removeAt(i); + removed++; + } else { + i++; + } + } + return removed; + } + + private int indexOf(long key) { + int mask = capacity - 1; + int idx = Long2ObjectOpenHashMap.hash(key) & mask; + while (true) { + if (!used[idx]) { + return -1; + } + if (keys[idx] == key) { + return idx; + } + idx = (idx + 1) & mask; + } + } + + private void insertNew(long key, long value) { + int mask = capacity - 1; + int idx = Long2ObjectOpenHashMap.hash(key) & mask; + while (used[idx]) { + idx = (idx + 1) & mask; + } + keys[idx] = key; + values[idx] = value; + used[idx] = true; + size++; + } + + private void removeAt(int idx) { + int mask = capacity - 1; + size--; + int next = (idx + 1) & mask; + while (used[next]) { + int naturalSlot = Long2ObjectOpenHashMap.hash(keys[next]) & mask; + if ((next > idx && (naturalSlot <= idx || naturalSlot > next)) + || (next < idx && (naturalSlot <= idx && naturalSlot > next))) { + keys[idx] = keys[next]; + values[idx] = values[next]; + idx = next; + } + next = (next + 1) & mask; + } + used[idx] = false; + } + + private void rehash(int newCapacity) { + long[] oldKeys = keys; + long[] oldValues = values; + boolean[] oldUsed = used; + int oldCapacity = capacity; + + capacity = newCapacity; + keys = new long[newCapacity]; + values = new long[newCapacity]; + used = new boolean[newCapacity]; + threshold = (int) (newCapacity * LOAD_FACTOR); + size = 0; + + for (int i = 0; i < oldCapacity; i++) { + if (oldUsed[i]) { + insertNew(oldKeys[i], oldValues[i]); + } + } + } + + private static int tableSizeFor(int cap) { + int n = cap - 1; + n |= n >>> 1; + n |= n >>> 2; + n |= n >>> 4; + n |= n >>> 8; + n |= n >>> 16; + return (n < MIN_CAPACITY) ? MIN_CAPACITY : n + 1; + } +} diff --git a/pulsar-common/src/test/java/org/apache/pulsar/common/util/collections/Long2LongOpenHashMapTest.java b/pulsar-common/src/test/java/org/apache/pulsar/common/util/collections/Long2LongOpenHashMapTest.java new file mode 100644 index 0000000000000..c2c6e428db5e6 --- /dev/null +++ b/pulsar-common/src/test/java/org/apache/pulsar/common/util/collections/Long2LongOpenHashMapTest.java @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pulsar.common.util.collections; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; +import org.testng.Reporter; +import org.testng.annotations.Test; + +public class Long2LongOpenHashMapTest { + + @Test + public void testEmpty() { + Long2LongOpenHashMap map = new Long2LongOpenHashMap(); + assertTrue(map.isEmpty()); + assertEquals(map.size(), 0); + assertEquals(map.get(0), 0L); + assertFalse(map.containsKey(0)); + } + + @Test + public void testPutGet() { + Long2LongOpenHashMap map = new Long2LongOpenHashMap(); + assertEquals(map.put(1, 10), 0L); + assertEquals(map.put(2, Long.MAX_VALUE), 0L); + assertFalse(map.isEmpty()); + assertEquals(map.size(), 2); + assertTrue(map.containsKey(1)); + assertEquals(map.get(1), 10L); + assertEquals(map.get(2), Long.MAX_VALUE); + assertEquals(map.get(3), 0L); + } + + @Test + public void testPutReplace() { + Long2LongOpenHashMap map = new Long2LongOpenHashMap(); + map.put(1, 10); + assertEquals(map.put(1, 100), 10L); + assertEquals(map.get(1), 100L); + assertEquals(map.size(), 1); + } + + @Test + public void testRemove() { + Long2LongOpenHashMap map = new Long2LongOpenHashMap(); + map.put(1, 10); + map.put(2, 20); + assertEquals(map.remove(1), 10L); + assertFalse(map.containsKey(1)); + assertEquals(map.get(1), 0L); + assertEquals(map.remove(99), 0L); + assertEquals(map.size(), 1); + } + + @Test + public void testGetOrDefault() { + Long2LongOpenHashMap map = new Long2LongOpenHashMap(); + map.put(1, 10); + assertEquals(map.getOrDefault(1, -1), 10L); + assertEquals(map.getOrDefault(2, -1), -1L); + } + + @Test + public void testZeroValueCanBeDistinguishedFromMissingKey() { + Long2LongOpenHashMap map = new Long2LongOpenHashMap(); + map.put(1, 0); + + assertTrue(map.containsKey(1)); + assertEquals(map.get(1), 0L); + assertEquals(map.getOrDefault(1, -1), 0L); + assertFalse(map.containsKey(2)); + assertEquals(map.get(2), 0L); + assertEquals(map.getOrDefault(2, -1), -1L); + } + + @Test + public void testEdgeKeysAndValuesRoundTrip() { + Long2LongOpenHashMap map = new Long2LongOpenHashMap(4); + Map expected = new HashMap<>(); + long[][] entries = { + {0L, 0L}, + {Long.MIN_VALUE, Long.MIN_VALUE}, + {Long.MAX_VALUE, Long.MAX_VALUE}, + {-1L, 1L}, + {1L, -1L}, + {Long.MIN_VALUE + 1, Long.MAX_VALUE - 1}, + {Long.MAX_VALUE - 1, Long.MIN_VALUE + 1} + }; + + for (long[] entry : entries) { + expected.put(entry[0], entry[1]); + assertEquals(map.put(entry[0], entry[1]), 0L); + } + + assertLong2LongMapMatches(expected, expected.keySet(), map, "edge values"); + } + + @Test + public void testComputeIfAbsent() { + Long2LongOpenHashMap map = new Long2LongOpenHashMap(); + assertEquals(map.computeIfAbsent(1, k -> 10), 10L); + assertEquals(map.computeIfAbsent(1, k -> 99), 10L); + } + + @Test + public void testClear() { + Long2LongOpenHashMap map = new Long2LongOpenHashMap(); + map.put(1, 10); + map.put(2, 20); + map.clear(); + assertTrue(map.isEmpty()); + assertEquals(map.size(), 0); + assertEquals(map.get(1), 0L); + } + + @Test + public void testForEach() { + Long2LongOpenHashMap map = new Long2LongOpenHashMap(); + map.put(1, 10); + map.put(2, 20); + Map values = new HashMap<>(); + + map.forEach(values::put); + + assertEquals(values, Map.of(1L, 10L, 2L, 20L)); + } + + @Test + public void testRemoveIf() { + Long2LongOpenHashMap map = new Long2LongOpenHashMap(); + for (int i = 0; i < 100; i++) { + map.put(i, i * 10L); + } + + int removed = map.removeIf((key, value) -> key % 2 == 0); + + assertEquals(removed, 50); + assertEquals(map.size(), 50); + for (int i = 0; i < 100; i++) { + assertEquals(map.containsKey(i), i % 2 != 0); + } + } + + @Test + public void testRehash() { + Long2LongOpenHashMap map = new Long2LongOpenHashMap(4); + for (int i = 0; i < 100; i++) { + map.put(i, i * 10L); + } + assertEquals(map.size(), 100); + for (int i = 0; i < 100; i++) { + assertEquals(map.get(i), i * 10L); + } + } + + @Test + public void testRemovePreservesProbeChainWithCollisions() { + Long2LongOpenHashMap map = new Long2LongOpenHashMap(4); + List keys = collidingLongKeys(16, 12); + + for (int i = 0; i < keys.size(); i++) { + assertEquals(map.put(keys.get(i), valueForIndex(i)), 0L); + } + + assertEquals(map.remove(keys.get(0)), valueForIndex(0)); + assertEquals(map.remove(keys.get(5)), valueForIndex(5)); + assertEquals(map.remove(keys.get(11)), valueForIndex(11)); + + for (int i = 1; i < keys.size() - 1; i++) { + long key = keys.get(i); + if (i != 5) { + assertEquals(map.get(key), valueForIndex(i)); + assertTrue(map.containsKey(key)); + } + } + assertFalse(map.containsKey(keys.get(0))); + assertFalse(map.containsKey(keys.get(5))); + assertFalse(map.containsKey(keys.get(11))); + + assertEquals(map.put(keys.get(5), Long.MIN_VALUE), 0L); + assertEquals(map.getOrDefault(keys.get(5), -1L), Long.MIN_VALUE); + } + + @Test + public void testRandomizedOperationsAgainstHashMap() { + Long2LongOpenHashMap map = new Long2LongOpenHashMap(4); + Map expected = new HashMap<>(); + Set seenKeys = new HashSet<>(); + long seed = randomSeed("testRandomizedOperationsAgainstHashMap"); + Random random = new Random(seed); + + for (int i = 0; i < 20_000; i++) { + long key = randomLongWithEdgeCases(random, 512); + seenKeys.add(key); + String context = "seed=" + seed + " iteration=" + i + " key=" + key; + + switch (random.nextInt(100)) { + case 0 -> { + long value = randomValue(random); + Long previous = expected.put(key, value); + assertEquals(map.put(key, value), previous == null ? 0L : previous.longValue(), context); + } + case 1 -> { + Long previous = expected.remove(key); + assertEquals(map.remove(key), previous == null ? 0L : previous.longValue(), context); + } + case 2 -> { + long value = randomValue(random); + Long previous = expected.get(key); + assertEquals(map.computeIfAbsent(key, ignored -> value), + previous == null ? value : previous.longValue(), context); + expected.putIfAbsent(key, value); + } + case 3 -> assertEquals(map.get(key), expected.getOrDefault(key, 0L).longValue(), context); + case 4 -> { + long defaultValue = randomValue(random); + assertEquals(map.getOrDefault(key, defaultValue), + expected.getOrDefault(key, defaultValue).longValue(), context); + } + case 5 -> assertEquals(map.containsKey(key), expected.containsKey(key), context); + case 6 -> runRemoveIfScenario(map, expected, random, context); + case 7 -> { + map.clear(); + expected.clear(); + } + default -> { + long otherKey = randomLongWithEdgeCases(random, 512); + seenKeys.add(otherKey); + long value = randomValue(random); + Long previous = expected.put(otherKey, value); + assertEquals(map.put(otherKey, value), previous == null ? 0L : previous.longValue(), + context + " otherKey=" + otherKey); + } + } + + if (i % 257 == 0) { + runRemoveIfScenario(map, expected, random, context + " periodicRemoveIf"); + } + + assertLong2LongMapMatches(expected, seenKeys, map, context); + } + } + + private static void runRemoveIfScenario(Long2LongOpenHashMap map, Map expected, Random random, + String context) { + int selector = random.nextInt(4); + int removed = map.removeIf((entryKey, value) -> removeIfMatches(selector, entryKey, value)); + + int expectedRemoved = 0; + Iterator> iterator = expected.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + if (removeIfMatches(selector, entry.getKey(), entry.getValue())) { + iterator.remove(); + expectedRemoved++; + } + } + assertEquals(removed, expectedRemoved, context + " removeIfSelector=" + selector); + } + + private static boolean removeIfMatches(int selector, long key, long value) { + return switch (selector) { + case 0 -> (key & 3L) == 0; + case 1 -> (value & 7L) == 0; + case 2 -> key < 0 && value <= 0; + case 3 -> key == Long.MIN_VALUE || value == Long.MAX_VALUE; + default -> throw new IllegalArgumentException("Unknown selector: " + selector); + }; + } + + private static void assertLong2LongMapMatches(Map expected, Iterable seenKeys, + Long2LongOpenHashMap actual, String context) { + long missingValue = 0x5A5A_5A5A_5A5A_5A5AL; + assertEquals(actual.isEmpty(), expected.isEmpty(), context); + assertEquals(actual.size(), expected.size(), context); + + for (long key : seenKeys) { + Long expectedValue = expected.get(key); + assertEquals(actual.containsKey(key), expectedValue != null, context + " checkedKey=" + key); + assertEquals(actual.get(key), expectedValue == null ? 0L : expectedValue.longValue(), + context + " checkedKey=" + key); + assertEquals(actual.getOrDefault(key, missingValue), + expectedValue == null ? missingValue : expectedValue.longValue(), context + " checkedKey=" + key); + } + + Map actualEntries = new HashMap<>(); + actual.forEach(actualEntries::put); + assertEquals(actualEntries, expected, context); + } + + private static long randomSeed(String testName) { + String configuredSeed = System.getProperty("pulsar.collections.randomSeed"); + long seed = configuredSeed != null ? Long.parseLong(configuredSeed) : ThreadLocalRandom.current().nextLong(); + String message = Long2LongOpenHashMapTest.class.getSimpleName() + "." + testName + " seed=" + seed; + Reporter.log(message, true); + System.err.println(message); + return seed; + } + + private static long randomValue(Random random) { + return switch (random.nextInt(32)) { + case 0 -> 0L; + case 1 -> Long.MIN_VALUE; + case 2 -> Long.MAX_VALUE; + default -> random.nextInt(1_024) - 512L; + }; + } + + private static long randomLongWithEdgeCases(Random random, int bound) { + return switch (random.nextInt(64)) { + case 0 -> 0L; + case 1 -> Long.MIN_VALUE; + case 2 -> Long.MAX_VALUE; + default -> random.nextInt(bound) - bound / 2L; + }; + } + + private static List collidingLongKeys(int capacity, int count) { + int mask = capacity - 1; + int bucket = Long2ObjectOpenHashMap.hash(0) & mask; + List keys = new ArrayList<>(); + for (long candidate = 0; keys.size() < count; candidate++) { + if ((Long2ObjectOpenHashMap.hash(candidate) & mask) == bucket) { + keys.add(candidate); + } + } + return keys; + } + + private static long valueForIndex(int index) { + return index % 3 == 0 ? 0L : index * 101L - 17L; + } +}