Skip to content

Commit d9ddce5

Browse files
Ensure extractTrainingVectors returns a list of at most MAX_PQ_TRAINING_SET_SIZE vectors (#610)
* For PQ codebook training * Ensures extractTrainingVectors returns a list of at most MAX_PQ_TRAINING_SET_SIZE vectors * Uses Floyd's sampling algorithm
1 parent d663b4f commit d9ddce5

1 file changed

Lines changed: 31 additions & 9 deletions

File tree

jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,27 +29,24 @@
2929
import io.github.jbellis.jvector.vector.types.ByteSequence;
3030
import io.github.jbellis.jvector.vector.types.VectorFloat;
3131
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
32+
import org.agrona.collections.IntHashSet;
3233

3334
import java.io.IOException;
3435
import java.util.Arrays;
3536
import java.util.List;
3637
import java.util.Objects;
38+
import java.util.SplittableRandom;
3739
import java.util.concurrent.Callable;
3840
import java.util.concurrent.ForkJoinPool;
39-
import java.util.concurrent.ThreadLocalRandom;
4041
import java.util.concurrent.atomic.AtomicReference;
41-
import java.util.function.Supplier;
4242
import java.util.logging.Logger;
4343
import java.util.stream.Collectors;
44-
import java.util.stream.DoubleStream;
4544
import java.util.stream.IntStream;
46-
import java.util.stream.Stream;
4745

4846
import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED;
4947
import static io.github.jbellis.jvector.util.MathUtil.square;
5048
import static io.github.jbellis.jvector.vector.VectorUtil.dotProduct;
5149
import static io.github.jbellis.jvector.vector.VectorUtil.sub;
52-
import static java.lang.Math.min;
5350
import static java.lang.Math.sqrt;
5451

5552
/**
@@ -139,11 +136,36 @@ public static ProductQuantization compute(RandomAccessVectorValues ravv,
139136
}
140137

141138
static List<VectorFloat<?>> extractTrainingVectors(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) {
142-
// limit the number of vectors we train on
143-
var P = min(1.0f, MAX_PQ_TRAINING_SET_SIZE / (float) ravv.size());
139+
final IntStream ordinalStream;
140+
141+
if (ravv.size() <= MAX_PQ_TRAINING_SET_SIZE) {
142+
ordinalStream = IntStream.range(0, ravv.size());
143+
} else {
144+
// Uses Floyd’s sampling algorithm to select MAX_PQ_TRAINING_SET_SIZE random ordinals from 0 to ravv.size()
145+
// while only iterating MAX_PQ_TRAINING_SET_SIZE times.
146+
SplittableRandom rng = new SplittableRandom(1);
147+
IntHashSet ordinals = new IntHashSet(MAX_PQ_TRAINING_SET_SIZE);
148+
// j runs from (ravv.size() - MAX_PQ_TRAINING_SET_SIZE) to ravv.size() (exclusive)
149+
for (int j = ravv.size() - MAX_PQ_TRAINING_SET_SIZE; j < ravv.size(); j++) {
150+
int t = rng.nextInt(j + 1); // uniform in [0, j]
151+
if (ordinals.contains(t)) {
152+
ordinals.add(j);
153+
} else {
154+
ordinals.add(t);
155+
}
156+
}
157+
int[] ordinalArray = new int[ordinals.size()];
158+
IntHashSet.IntIterator it = ordinals.iterator();
159+
for (int i = 0; i < ordinals.size(); i++) {
160+
assert it.hasNext();
161+
ordinalArray[i] = it.next();
162+
}
163+
assert !it.hasNext();
164+
ordinalStream = IntStream.of(ordinalArray);
165+
}
166+
144167
var ravvCopy = ravv.threadLocalSupplier();
145-
return parallelExecutor.submit(() -> IntStream.range(0, ravv.size()).parallel()
146-
.filter(i -> ThreadLocalRandom.current().nextFloat() < P)
168+
return parallelExecutor.submit(() -> ordinalStream.parallel()
147169
.mapToObj(targetOrd -> {
148170
var localRavv = ravvCopy.get();
149171
VectorFloat<?> v = localRavv.getVector(targetOrd);

0 commit comments

Comments
 (0)