Skip to content

Commit 20c348e

Browse files
Move buffer position in ByteBufferIndexWriter#writeFloats (#607)
* Move buffer position in ByteBufferIndexWriter#writeFloats * Add simpler way to write VectorFloat to an IndexWriter
1 parent d9ddce5 commit 20c348e

5 files changed

Lines changed: 268 additions & 0 deletions

File tree

jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferIndexWriter.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ public void writeFloat(float v) {
236236
@Override
237237
public void writeFloats(float[] floats, int offset, int count) throws IOException {
238238
buffer.asFloatBuffer().put(floats, offset, count);
239+
buffer.position(buffer.position() + count * Float.BYTES);
239240
}
240241

241242
@Override

jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayVectorFloat.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616

1717
package io.github.jbellis.jvector.vector;
1818

19+
import io.github.jbellis.jvector.disk.IndexWriter;
1920
import io.github.jbellis.jvector.util.RamUsageEstimator;
2021
import io.github.jbellis.jvector.vector.types.VectorFloat;
2122

23+
import java.io.IOException;
2224
import java.util.Arrays;
2325

2426
/**
@@ -65,6 +67,11 @@ public int length()
6567
return data.length;
6668
}
6769

70+
@Override
71+
public void writeTo(IndexWriter writer) throws IOException {
72+
writer.writeFloats(this.get(), 0, data.length);
73+
}
74+
6875
@Override
6976
public VectorFloat<float[]> copy()
7077
{

jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorFloat.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616

1717
package io.github.jbellis.jvector.vector.types;
1818

19+
import io.github.jbellis.jvector.disk.IndexWriter;
1920
import io.github.jbellis.jvector.util.Accountable;
2021

22+
import java.io.IOException;
23+
2124
public interface VectorFloat<T> extends Accountable
2225
{
2326
/**
@@ -31,6 +34,8 @@ default int offset(int i) {
3134
return i;
3235
}
3336

37+
void writeTo(IndexWriter indexWriter) throws IOException;
38+
3439
VectorFloat<T> copy();
3540

3641
void copyFrom(VectorFloat<?> src, int srcOffset, int destOffset, int length);

jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorFloat.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616

1717
package io.github.jbellis.jvector.vector;
1818

19+
import io.github.jbellis.jvector.disk.IndexWriter;
1920
import io.github.jbellis.jvector.util.RamUsageEstimator;
2021
import io.github.jbellis.jvector.vector.types.VectorFloat;
2122

23+
import java.io.IOException;
2224
import java.lang.foreign.MemorySegment;
2325
import java.nio.Buffer;
2426

@@ -86,6 +88,13 @@ public int offset(int i)
8688
return i * Float.BYTES;
8789
}
8890

91+
@Override
92+
public void writeTo(IndexWriter writer) throws IOException {
93+
for (int i = 0; i < length(); i++) {
94+
writer.writeFloat(get(i));
95+
}
96+
}
97+
8998
@Override
9099
public VectorFloat<MemorySegment> copy()
91100
{
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.github.jbellis.jvector.disk;
18+
19+
import io.github.jbellis.jvector.LuceneTestCase;
20+
import io.github.jbellis.jvector.vector.VectorUtil;
21+
import io.github.jbellis.jvector.vector.VectorizationProvider;
22+
import io.github.jbellis.jvector.vector.types.VectorFloat;
23+
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
24+
import org.junit.Test;
25+
26+
import java.io.IOException;
27+
import java.nio.ByteBuffer;
28+
29+
import static org.junit.Assert.*;
30+
31+
public class TestByteBufferIndexWriter extends LuceneTestCase {
32+
33+
@Test
34+
public void testWriteFloatsAdvancesPosition() throws IOException {
35+
ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
36+
37+
float[] floats = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
38+
39+
// Write floats
40+
writer.writeFloats(floats, 0, floats.length);
41+
42+
// Verify position advanced correctly
43+
assertEquals(floats.length * Float.BYTES, writer.position());
44+
45+
// Write more data to ensure position is correct
46+
writer.writeInt(42);
47+
assertEquals(floats.length * Float.BYTES + Integer.BYTES, writer.position());
48+
}
49+
50+
@Test
51+
public void testWriteFloatsWithOffsetAndCount() throws IOException {
52+
ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
53+
54+
float[] floats = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
55+
56+
// Write only middle 3 elements
57+
writer.writeFloats(floats, 1, 3);
58+
59+
// Verify position
60+
assertEquals(3 * Float.BYTES, writer.position());
61+
62+
// Verify content
63+
ByteBuffer buffer = writer.getWrittenData();
64+
assertEquals(2.0f, buffer.getFloat(), 0.001f);
65+
assertEquals(3.0f, buffer.getFloat(), 0.001f);
66+
assertEquals(4.0f, buffer.getFloat(), 0.001f);
67+
}
68+
69+
@Test
70+
public void testWriteFloatVectorIntegration() throws IOException {
71+
VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
72+
ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
73+
74+
// Create a vector
75+
float[] data = {1.5f, 2.5f, 3.5f, 4.5f};
76+
VectorFloat<?> vector = vts.createFloatVector(data);
77+
78+
// Write vector using VectorTypeSupport
79+
vts.writeFloatVector(writer, vector);
80+
81+
// Verify position advanced
82+
assertEquals(data.length * Float.BYTES, writer.position());
83+
84+
// Verify content
85+
ByteBuffer buffer = writer.getWrittenData();
86+
for (float expected : data) {
87+
assertEquals(expected, buffer.getFloat(), 0.001f);
88+
}
89+
}
90+
91+
@Test
92+
public void testMultipleWriteFloatsCalls() throws IOException {
93+
ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
94+
95+
float[] floats1 = {1.0f, 2.0f};
96+
float[] floats2 = {3.0f, 4.0f, 5.0f};
97+
98+
writer.writeFloats(floats1, 0, floats1.length);
99+
long pos1 = writer.position();
100+
assertEquals(floats1.length * Float.BYTES, pos1);
101+
102+
writer.writeFloats(floats2, 0, floats2.length);
103+
long pos2 = writer.position();
104+
assertEquals((floats1.length + floats2.length) * Float.BYTES, pos2);
105+
106+
// Verify all data was written correctly
107+
ByteBuffer buffer = writer.getWrittenData();
108+
assertEquals(1.0f, buffer.getFloat(), 0.001f);
109+
assertEquals(2.0f, buffer.getFloat(), 0.001f);
110+
assertEquals(3.0f, buffer.getFloat(), 0.001f);
111+
assertEquals(4.0f, buffer.getFloat(), 0.001f);
112+
assertEquals(5.0f, buffer.getFloat(), 0.001f);
113+
}
114+
115+
@Test
116+
public void testWriteFloatsDoesNotOverwrite() throws IOException {
117+
ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
118+
119+
// Write an int first
120+
writer.writeInt(999);
121+
long posAfterInt = writer.position();
122+
123+
// Write floats
124+
float[] floats = {1.0f, 2.0f, 3.0f};
125+
writer.writeFloats(floats, 0, floats.length);
126+
127+
// Verify position
128+
assertEquals(posAfterInt + floats.length * Float.BYTES, writer.position());
129+
130+
// Verify the int wasn't overwritten
131+
ByteBuffer buffer = writer.getWrittenData();
132+
assertEquals(999, buffer.getInt());
133+
assertEquals(1.0f, buffer.getFloat(), 0.001f);
134+
assertEquals(2.0f, buffer.getFloat(), 0.001f);
135+
assertEquals(3.0f, buffer.getFloat(), 0.001f);
136+
}
137+
138+
@Test
139+
public void testWriteFloatsWithDirectBuffer() throws IOException {
140+
ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, true);
141+
142+
float[] floats = {1.0f, 2.0f, 3.0f};
143+
writer.writeFloats(floats, 0, floats.length);
144+
145+
assertEquals(floats.length * Float.BYTES, writer.position());
146+
147+
ByteBuffer buffer = writer.getWrittenData();
148+
for (float expected : floats) {
149+
assertEquals(expected, buffer.getFloat(), 0.001f);
150+
}
151+
}
152+
153+
@Test
154+
public void testBytesWritten() throws IOException {
155+
ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
156+
157+
assertEquals(0, writer.bytesWritten());
158+
159+
float[] floats = {1.0f, 2.0f, 3.0f};
160+
writer.writeFloats(floats, 0, floats.length);
161+
162+
assertEquals(floats.length * Float.BYTES, writer.bytesWritten());
163+
164+
writer.writeInt(42);
165+
assertEquals(floats.length * Float.BYTES + Integer.BYTES, writer.bytesWritten());
166+
}
167+
168+
@Test
169+
public void testArrayVectorFloatWriteTo() throws IOException {
170+
VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
171+
ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
172+
173+
// Create an ArrayVectorFloat
174+
float[] data = {1.5f, 2.5f, 3.5f, 4.5f, 5.5f};
175+
VectorFloat<?> vector = vts.createFloatVector(data);
176+
177+
// Call writeTo which internally calls writeFloats
178+
vector.writeTo(writer);
179+
180+
// Verify position advanced correctly
181+
assertEquals(data.length * Float.BYTES, writer.position());
182+
assertEquals(data.length * Float.BYTES, writer.bytesWritten());
183+
184+
// Verify content was written correctly
185+
ByteBuffer buffer = writer.getWrittenData();
186+
for (float expected : data) {
187+
assertEquals(expected, buffer.getFloat(), 0.001f);
188+
}
189+
}
190+
191+
@Test
192+
public void testArrayVectorFloatWriteToMultipleTimes() throws IOException {
193+
VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
194+
ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, false);
195+
196+
// Create two vectors
197+
float[] data1 = {1.0f, 2.0f, 3.0f};
198+
float[] data2 = {4.0f, 5.0f, 6.0f, 7.0f};
199+
VectorFloat<?> vector1 = vts.createFloatVector(data1);
200+
VectorFloat<?> vector2 = vts.createFloatVector(data2);
201+
202+
// Write first vector
203+
vector1.writeTo(writer);
204+
long pos1 = writer.position();
205+
assertEquals(data1.length * Float.BYTES, pos1);
206+
207+
// Write second vector
208+
vector2.writeTo(writer);
209+
long pos2 = writer.position();
210+
assertEquals((data1.length + data2.length) * Float.BYTES, pos2);
211+
212+
// Verify all data was written correctly
213+
ByteBuffer buffer = writer.getWrittenData();
214+
assertEquals(1.0f, buffer.getFloat(), 0.001f);
215+
assertEquals(2.0f, buffer.getFloat(), 0.001f);
216+
assertEquals(3.0f, buffer.getFloat(), 0.001f);
217+
assertEquals(4.0f, buffer.getFloat(), 0.001f);
218+
assertEquals(5.0f, buffer.getFloat(), 0.001f);
219+
assertEquals(6.0f, buffer.getFloat(), 0.001f);
220+
assertEquals(7.0f, buffer.getFloat(), 0.001f);
221+
}
222+
223+
@Test
224+
public void testArrayVectorFloatWriteToWithDirectBuffer() throws IOException {
225+
VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
226+
ByteBufferIndexWriter writer = ByteBufferIndexWriter.create(1024, true);
227+
228+
// Create a vector
229+
float[] data = {10.5f, 20.5f, 30.5f};
230+
VectorFloat<?> vector = vts.createFloatVector(data);
231+
232+
// Write to direct buffer
233+
vector.writeTo(writer);
234+
235+
// Verify position
236+
assertEquals(data.length * Float.BYTES, writer.position());
237+
238+
// Verify content
239+
ByteBuffer buffer = writer.getWrittenData();
240+
for (float expected : data) {
241+
assertEquals(expected, buffer.getFloat(), 0.001f);
242+
}
243+
}
244+
}
245+
246+
// Made with Bob

0 commit comments

Comments
 (0)