|
29 | 29 | import io.github.jbellis.jvector.vector.types.ByteSequence; |
30 | 30 | import io.github.jbellis.jvector.vector.types.VectorFloat; |
31 | 31 | import io.github.jbellis.jvector.vector.types.VectorTypeSupport; |
| 32 | +import org.agrona.collections.IntHashSet; |
32 | 33 |
|
33 | 34 | import java.io.IOException; |
34 | 35 | import java.util.Arrays; |
35 | 36 | import java.util.List; |
36 | 37 | import java.util.Objects; |
| 38 | +import java.util.SplittableRandom; |
37 | 39 | import java.util.concurrent.Callable; |
38 | 40 | import java.util.concurrent.ForkJoinPool; |
39 | | -import java.util.concurrent.ThreadLocalRandom; |
40 | 41 | import java.util.concurrent.atomic.AtomicReference; |
41 | | -import java.util.function.Supplier; |
42 | 42 | import java.util.logging.Logger; |
43 | 43 | import java.util.stream.Collectors; |
44 | | -import java.util.stream.DoubleStream; |
45 | 44 | import java.util.stream.IntStream; |
46 | | -import java.util.stream.Stream; |
47 | 45 |
|
48 | 46 | import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED; |
49 | 47 | import static io.github.jbellis.jvector.util.MathUtil.square; |
50 | 48 | import static io.github.jbellis.jvector.vector.VectorUtil.dotProduct; |
51 | 49 | import static io.github.jbellis.jvector.vector.VectorUtil.sub; |
52 | | -import static java.lang.Math.min; |
53 | 50 | import static java.lang.Math.sqrt; |
54 | 51 |
|
55 | 52 | /** |
@@ -139,11 +136,36 @@ public static ProductQuantization compute(RandomAccessVectorValues ravv, |
139 | 136 | } |
140 | 137 |
|
141 | 138 | static List<VectorFloat<?>> extractTrainingVectors(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) { |
142 | | - // limit the number of vectors we train on |
143 | | - var P = min(1.0f, MAX_PQ_TRAINING_SET_SIZE / (float) ravv.size()); |
| 139 | + final IntStream ordinalStream; |
| 140 | + |
| 141 | + if (ravv.size() <= MAX_PQ_TRAINING_SET_SIZE) { |
| 142 | + ordinalStream = IntStream.range(0, ravv.size()); |
| 143 | + } else { |
| 144 | + // Uses Floyd's sampling algorithm to select MAX_PQ_TRAINING_SET_SIZE distinct random ordinals from 0 (inclusive) to ravv.size() (exclusive) |
| 145 | + // while only iterating MAX_PQ_TRAINING_SET_SIZE times. |
| 146 | + SplittableRandom rng = new SplittableRandom(1); |
| 147 | + IntHashSet ordinals = new IntHashSet(MAX_PQ_TRAINING_SET_SIZE); |
| 148 | + // j runs from (ravv.size() - MAX_PQ_TRAINING_SET_SIZE) to ravv.size() (exclusive) |
| 149 | + for (int j = ravv.size() - MAX_PQ_TRAINING_SET_SIZE; j < ravv.size(); j++) { |
| 150 | + int t = rng.nextInt(j + 1); // uniform in [0, j] |
| 151 | + if (ordinals.contains(t)) { |
| 152 | + ordinals.add(j); |
| 153 | + } else { |
| 154 | + ordinals.add(t); |
| 155 | + } |
| 156 | + } |
| 157 | + int[] ordinalArray = new int[ordinals.size()]; |
| 158 | + IntHashSet.IntIterator it = ordinals.iterator(); |
| 159 | + for (int i = 0; i < ordinals.size(); i++) { |
| 160 | + assert it.hasNext(); |
| 161 | + ordinalArray[i] = it.next(); |
| 162 | + } |
| 163 | + assert !it.hasNext(); |
| 164 | + ordinalStream = IntStream.of(ordinalArray); |
| 165 | + } |
| 166 | + |
144 | 167 | var ravvCopy = ravv.threadLocalSupplier(); |
145 | | - return parallelExecutor.submit(() -> IntStream.range(0, ravv.size()).parallel() |
146 | | - .filter(i -> ThreadLocalRandom.current().nextFloat() < P) |
| 168 | + return parallelExecutor.submit(() -> ordinalStream.parallel() |
147 | 169 | .mapToObj(targetOrd -> { |
148 | 170 | var localRavv = ravvCopy.get(); |
149 | 171 | VectorFloat<?> v = localRavv.getVector(targetOrd); |
|
0 commit comments