diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/MemorySegmentVectorProviderBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/MemorySegmentVectorProviderBenchmark.java
new file mode 100644
index 000000000..b8d354b0c
--- /dev/null
+++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/MemorySegmentVectorProviderBenchmark.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.github.jbellis.jvector.bench;
+
+import io.github.jbellis.jvector.vector.MemorySegmentVectorProvider;
+import io.github.jbellis.jvector.vector.types.VectorFloat;
+import io.github.jbellis.jvector.vector.types.ByteSequence;
+import io.github.jbellis.jvector.disk.IndexWriter;
+import org.openjdk.jmh.annotations.*;
+
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Measures how fast {@link MemorySegmentVectorProvider} serializes float vectors and byte
+ * sequences into an in-memory {@link IndexWriter}.
+ */
+@BenchmarkMode(Mode.Throughput)
+@OutputTimeUnit(TimeUnit.SECONDS)
+@Warmup(iterations = 2)
+@Measurement(iterations = 5)
+@Fork(2)
+@State(Scope.Benchmark)
+public class MemorySegmentVectorProviderBenchmark {
+
+    /** Number of elements in each vector under test. */
+    @Param({"512","1024","1536"})
+    public int length;
+
+    private MemorySegmentVectorProvider provider;
+    private VectorFloat<?> fvector;
+    private ByteSequence<?> bvector;
+
+    /** Builds one float vector and one byte sequence of {@link #length} elements per trial. */
+    @Setup(Level.Trial)
+    public void setup() {
+        provider = new MemorySegmentVectorProvider();
+        float[] fdata = new float[length];
+        byte[] bdata = new byte[length];
+        for (int i = 0; i < length; i++) {
+            fdata[i] = (float) i;
+            bdata[i] = (byte) i;
+        }
+        fvector = provider.createFloatVector(fdata);
+        bvector = provider.createByteSequence(bdata);
+    }
+
+    /** Serializes the float vector into a writer pre-sized to its exact encoded size. */
+    @Benchmark
+    public void writeFloatVector() throws IOException {
+        try (MemoryIndexWriter w = new MemoryIndexWriter(length * Float.BYTES)) {
+            provider.writeFloatVector(w, fvector);
+        }
+    }
+
+    /** Serializes the byte sequence into a writer pre-sized to its exact encoded size. */
+    @Benchmark
+    public void writeByteVector() throws IOException {
+        // One byte per element: sizing this like the float case would allocate 4x too much
+        // and charge the byte path for allocation the write never needs.
+        try (MemoryIndexWriter w = new MemoryIndexWriter(length)) {
+            provider.writeByteSequence(w, bvector);
+        }
+    }
+
+    /** {@link IndexWriter} that forwards every write to an in-memory {@link DataOutputStream}. */
+    static final class MemoryIndexWriter implements IndexWriter {
+        private final ByteArrayOutputStream bos;
+        private final DataOutputStream out;
+
+        MemoryIndexWriter(int capacity) {
+            this.bos = new ByteArrayOutputStream(capacity);
+            this.out = new DataOutputStream(bos);
+        }
+
+        byte[] toByteArray() { return bos.toByteArray(); }
+
+        @Override public long position() { return bos.size(); }
+        @Override public void close() throws IOException { out.close(); }
+
+        @Override public void write(int b) throws IOException { out.write(b); }
+        @Override public void write(byte[] b) throws IOException { out.write(b); }
+        @Override public void write(byte[] b, int off, int len) throws IOException { out.write(b, off, len); }
+        @Override public void writeFloat(float v) throws IOException { out.writeFloat(v); }
+
+        @Override public void writeBoolean(boolean v) throws IOException { out.writeBoolean(v); }
+        @Override public void writeByte(int v) throws IOException { out.writeByte(v); }
+        @Override public void writeShort(int v) throws IOException { out.writeShort(v); }
+        @Override public void writeChar(int v) throws IOException { out.writeChar(v); }
+        @Override public void writeInt(int v) throws IOException { out.writeInt(v); }
+        @Override public void writeLong(long v) throws IOException { out.writeLong(v); }
+        @Override public void writeDouble(double v) throws IOException { out.writeDouble(v); }
+        @Override public void writeBytes(String s) throws IOException { out.writeBytes(s); }
+        @Override public void writeChars(String s) throws IOException { out.writeChars(s); }
+        @Override public void writeUTF(String s) throws IOException { out.writeUTF(s); }
+    }
+}
diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java
index 1ce0d81b2..a201e48c8 100644
--- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java
+++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java
@@ -62,8 +62,17 @@ public void readFloatVector(RandomAccessReader r, int count, VectorFloat<?> vect
 
     @Override
     public void writeFloatVector(IndexWriter out, VectorFloat<?> vector) throws IOException {
-        for (int i = 0; i < vector.length(); i++)
-            out.writeFloat(vector.get(i));
+        // Bulk-write straight from the backing array, but only when the segment is a heap
+        // segment spanning an entire float[]: heapBase() is empty for segments with no heap
+        // array, and a segment shorter than its array need not start at index 0.
+        Object base = ((MemorySegmentVectorFloat) vector).get().heapBase().orElse(null);
+        if (base instanceof float[] data && data.length == vector.length()) {
+            out.writeFloats(data, 0, data.length);
+            return;
+        }
+        for (int i = 0; i < vector.length(); i++) {
+            out.writeFloat(vector.get(i));
+        }
     }
 
     @Override
@@ -98,7 +107,14 @@ public void readByteSequence(RandomAccessReader r, ByteSequence<?> sequence) thr
 
     @Override
     public void writeByteSequence(IndexWriter out, ByteSequence<?> sequence) throws IOException {
-        for (int i = 0; i < sequence.length(); i++)
-            out.writeByte(sequence.get(i));
+        java.nio.ByteBuffer bb = ((MemorySegmentByteSequence) sequence).get().asByteBuffer();
+        if (bb.hasArray()) {
+            // A buffer's first readable byte sits at arrayOffset() + position() in its array.
+            out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining());
+            return;
+        }
+        for (int i = 0; i < sequence.length(); i++) {
+            out.writeByte(sequence.get(i));
+        }
     }
 }
diff --git a/jvector-native/src/test/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProviderTest.java b/jvector-native/src/test/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProviderTest.java
new file mode 100644
index 000000000..29b415a86
--- /dev/null
+++ b/jvector-native/src/test/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProviderTest.java
@@ -0,0 +1,128 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.github.jbellis.jvector.vector;
+
+import io.github.jbellis.jvector.disk.IndexWriter;
+import io.github.jbellis.jvector.vector.types.ByteSequence;
+
+import java.io.IOException;
+import org.junit.jupiter.api.Test;
+
+/** Checks the bytes {@link MemorySegmentVectorProvider} emits through an {@link IndexWriter}. */
+public class MemorySegmentVectorProviderTest {
+
+    @Test
+    void testWriteByteSequenceSlice() throws IOException {
+        MemorySegmentVectorProvider provider = new MemorySegmentVectorProvider();
+
+        byte[] originalBytes = {10, 20, 30, 40, 50};
+        ByteSequence<?> original = provider.createByteSequence(originalBytes);
+
+        ByteSequence<?> slice = original.slice(2, 3);
+
+        MockIndexWriter dummyWriter = new MockIndexWriter();
+
+        // Only the slice's own three bytes may be written, not the whole backing array
+        provider.writeByteSequence(dummyWriter, slice);
+
+        byte[] expected = {30, 40, 50};
+        org.junit.jupiter.api.Assertions.assertArrayEquals(expected, dummyWriter.toByteArray());
+    }
+
+    @Test
+    void testWriteByteSequenceFull() throws IOException {
+        MemorySegmentVectorProvider provider = new MemorySegmentVectorProvider();
+
+        byte[] expectedBytes = {1, 2, 3, 4, 5};
+        ByteSequence<?> sequence = provider.createByteSequence(expectedBytes);
+
+        MockIndexWriter dummyWriter = new MockIndexWriter();
+
+        // An unsliced sequence must be written in full
+        provider.writeByteSequence(dummyWriter, sequence);
+
+        org.junit.jupiter.api.Assertions.assertArrayEquals(expectedBytes, dummyWriter.toByteArray());
+    }
+
+    @Test
+    void testWriteByteSequenceZeroLength() throws IOException {
+        MemorySegmentVectorProvider provider = new MemorySegmentVectorProvider();
+
+        byte[] originalBytes = {10, 20, 30};
+        ByteSequence<?> original = provider.createByteSequence(originalBytes);
+
+        // Create a logical empty slice
+        ByteSequence<?> emptySlice = original.slice(1, 0);
+
+        MockIndexWriter dummyWriter = new MockIndexWriter();
+
+        // An empty slice must write nothing and must not throw
+        provider.writeByteSequence(dummyWriter, emptySlice);
+
+        org.junit.jupiter.api.Assertions.assertArrayEquals(new byte[0], dummyWriter.toByteArray());
+    }
+
+    @Test
+    void testWriteFloatVectorMatchesPerElementEncoding() throws IOException {
+        MemorySegmentVectorProvider provider = new MemorySegmentVectorProvider();
+
+        float[] values = {0.0f, 1.5f, -2.25f, Float.MAX_VALUE};
+        var vector = provider.createFloatVector(values);
+
+        MockIndexWriter dummyWriter = new MockIndexWriter();
+        provider.writeFloatVector(dummyWriter, vector);
+
+        // Reference encoding: one writeFloat per element, which is what the method emitted
+        // before it switched to a bulk write
+        java.io.ByteArrayOutputStream expected = new java.io.ByteArrayOutputStream();
+        java.io.DataOutputStream reference = new java.io.DataOutputStream(expected);
+        for (float value : values) {
+            reference.writeFloat(value);
+        }
+
+        org.junit.jupiter.api.Assertions.assertArrayEquals(expected.toByteArray(), dummyWriter.toByteArray());
+    }
+
+    /**
+     * A lightweight mock to capture IndexWriter output without boilerplate.
+     */
+    private static class MockIndexWriter implements IndexWriter {
+        private final java.io.ByteArrayOutputStream bos = new java.io.ByteArrayOutputStream();
+        private final java.io.DataOutputStream out = new java.io.DataOutputStream(bos);
+
+        public byte[] toByteArray() { return bos.toByteArray(); }
+
+        @Override public long position() { return bos.size(); }
+        @Override public void close() throws IOException { out.close(); }
+
+        // DataOutput delegation
+        @Override public void write(int b) throws IOException { out.write(b); }
+        @Override public void write(byte[] b) throws IOException { out.write(b); }
+        @Override public void write(byte[] b, int off, int len) throws IOException { out.write(b, off, len); }
+        @Override public void writeBoolean(boolean v) throws IOException { out.writeBoolean(v); }
+        @Override public void writeByte(int v) throws IOException { out.writeByte(v); }
+        @Override public void writeShort(int v) throws IOException { out.writeShort(v); }
+        @Override public void writeChar(int v) throws IOException { out.writeChar(v); }
+        @Override public void writeInt(int v) throws IOException { out.writeInt(v); }
+        @Override public void writeLong(long v) throws IOException { out.writeLong(v); }
+        @Override public void writeFloat(float v) throws IOException { out.writeFloat(v); }
+        @Override public void writeDouble(double v) throws IOException { out.writeDouble(v); }
+        @Override public void writeBytes(String s) throws IOException { out.writeBytes(s); }
+        @Override public void writeChars(String s) throws IOException { out.writeChars(s); }
+        @Override public void writeUTF(String s) throws IOException { out.writeUTF(s); }
+    }
+}