Skip to content

Commit 1e8a146

Browse files
authored
Create vectorized versions of ScalarQuantizer.quantize and recalculateCorrectiveOffset (#14304)
This resolves #13922. It takes the existing `quantize` and `recalculateCorrectiveOffset` methods in `ScalarQuantizer` and creates vectorized versions of the same algorithm. JMH shows a ~13x speedup:

```
Benchmark              Mode  Cnt     Score     Error   Units
Quant.quantize        thrpt    5   235.029 ±   3.204  ops/ms
Quant.quantizeVector  thrpt    5  3153.388 ± 192.635  ops/ms
```
1 parent 70abd1f commit 1e8a146

File tree

7 files changed

+307
-40
lines changed

7 files changed

+307
-40
lines changed

lucene/CHANGES.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Optimizations
3535
---------------------
3636
* GITHUB#14011: Reduce allocation rate in HNSW concurrent merge. (Viliam Durina)
3737
* GITHUB#14022: Optimize DFS marking of connected components in HNSW by reducing stack depth, improving performance and reducing allocations. (Viswanath Kuchibhotla)
38+
* GITHUB#14304: Add SIMD optimizations for scalar quantized queries and indexing. (Simon Cooper)
3839

3940
Bug Fixes
4041
---------------------

lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,4 +234,79 @@ public static long int4BitDotProductImpl(byte[] q, byte[] d) {
234234
}
235235
return ret;
236236
}
237+
238+
// Scalar (non-SIMD) implementation: builds a ScalarQuantizer helper from the supplied
// parameters and quantizes every element of `vector` (start index 0) into `dest`,
// returning the summed corrective offset.
// Note the argument order: this method takes (scale, alpha, ...) while the helper's
// constructor takes (alpha, scale, ...).
@Override
239+
public float minMaxScalarQuantize(
240+
float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile) {
241+
return new ScalarQuantizer(alpha, scale, minQuantile, maxQuantile).quantize(vector, dest, 0);
242+
}
243+
244+
// Scalar (non-SIMD) implementation: `vector` holds already-quantized bytes. Each byte is
// dequantized with the old parameters (oldAlpha, oldMinQuantile) and the corrective offset
// is recomputed under the new parameters (scale, alpha, minQuantile, maxQuantile).
// Nothing is written back; only the summed offset is returned.
@Override
245+
public float recalculateScalarQuantizationOffset(
246+
byte[] vector,
247+
float oldAlpha,
248+
float oldMinQuantile,
249+
float scale,
250+
float alpha,
251+
float minQuantile,
252+
float maxQuantile) {
253+
// start index 0: process the whole vector
return new ScalarQuantizer(alpha, scale, minQuantile, maxQuantile)
254+
.recalculateOffset(vector, 0, oldAlpha, oldMinQuantile);
255+
}
256+
257+
// Element-by-element scalar quantizer. It holds one set of quantization parameters and
// exposes two loops over a vector, both of which accept a start index so that a caller
// can use them to finish the elements it has not already processed itself.
static class ScalarQuantizer {
258+
// multiplier that maps a rounded quantized value back into the original value range
private final float alpha;
259+
// multiplier applied to the clamped, min-shifted value before rounding
private final float scale;
260+
// lower and upper clamp bounds for input values
private final float minQuantile, maxQuantile;
261+
262+
// Note the parameter order: alpha first, then scale.
ScalarQuantizer(float alpha, float scale, float minQuantile, float maxQuantile) {
263+
this.alpha = alpha;
264+
this.scale = scale;
265+
this.minQuantile = minQuantile;
266+
this.maxQuantile = maxQuantile;
267+
}
268+
269+
// Quantizes vector[start..] into dest[start..] (same indices) and returns the sum of the
// per-element corrective terms for those elements only. Elements before `start` are
// neither read nor written.
float quantize(float[] vector, byte[] dest, int start) {
270+
assert vector.length == dest.length;
271+
float correction = 0;
272+
for (int i = start; i < vector.length; i++) {
273+
correction += quantizeFloat(vector[i], dest, i);
274+
}
275+
return correction;
276+
}
277+
278+
// Recomputes the summed corrective term for the already-quantized bytes vector[start..]:
// each byte is dequantized with the old parameters and then run through quantizeFloat
// with this quantizer's parameters. No quantized output is stored.
float recalculateOffset(byte[] vector, int start, float oldAlpha, float oldMinQuantile) {
279+
float correction = 0;
280+
for (int i = start; i < vector.length; i++) {
281+
// undo the old quantization
282+
float v = (oldAlpha * vector[i]) + oldMinQuantile;
283+
// null destination: only the corrective term is wanted, nothing is written
correction += quantizeFloat(v, null, 0);
284+
}
285+
return correction;
286+
}
287+
288+
// Quantizes a single value and returns its corrective term. When `dest` is non-null the
// rounded value is stored (narrowed to byte) at dest[destIndex]; when `dest` is null the
// store is skipped and `destIndex` is ignored.
private float quantizeFloat(float v, byte[] dest, int destIndex) {
289+
assert dest == null || destIndex < dest.length;
290+
// Make sure the value is within the quantile range, cutting off the tails
291+
// see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
292+
// minQuantile)
293+
// dx is the UNCLAMPED distance from minQuantile; dxc is the clamped distance (>= 0)
float dx = v - minQuantile;
294+
float dxc = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
295+
// Scale the clamped value into the quantized range; this is our quantized value.
296+
// e.g. scale = 127/(maxQuantile - minQuantile) gives the range [0, 127].
// NOTE(review): scale is supplied by the caller, so the actual range depends on it - confirm.
297+
int roundedDxs = Math.round(scale * dxc);
298+
// We multiply by `alpha` here to get the quantized value back into the original range
299+
// to aid in calculating the corrective offset
300+
float dxq = roundedDxs * alpha;
301+
if (dest != null) {
302+
dest[destIndex] = (byte) roundedDxs;
303+
}
304+
// Calculate the corrective offset that needs to be applied to the score
305+
// in addition to the `byte * minQuantile * alpha` term in the equation
306+
// we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
307+
// will be rounded to the nearest whole number and lose some accuracy
308+
// Additionally, we account for the global correction of `minQuantile^2` in the equation
309+
return minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
310+
}
311+
}
237312
}

lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,39 @@ public interface VectorUtilSupport {
6565
* @return the dot product
6666
*/
6767
long int4BitDotProduct(byte[] int4Quantized, byte[] binaryQuantized);
68+
69+
/**
70+
* Scalar quantizes {@code vector}, putting the quantized bytes into {@code dest}.
71+
*
72+
* @param vector the float vector to quantize
73+
* @param dest the destination array that receives one quantized byte per element of {@code
* vector}
74+
* @param scale the factor that maps a clamped, min-shifted value onto the quantized range
75+
* @param alpha the factor that maps a quantized value back into the original value range
76+
* @param minQuantile the lower quantile of the distribution; smaller values are clamped to it
77+
* @param maxQuantile the upper quantile of the distribution; larger values are clamped to it
78+
* @return the corrective offset that needs to be applied to the score
79+
*/
80+
float minMaxScalarQuantize(
81+
float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile);
82+
83+
/**
84+
* Recalculates the corrective offset for an already-quantized {@code vector} under a new set of
* quantization parameters.
85+
*
86+
* @param vector the previously quantized vector (it is read, not re-quantized in place)
87+
* @param oldAlpha the alpha value {@code vector} was quantized with
88+
* @param oldMinQuantile the lower quantile {@code vector} was quantized with
89+
* @param scale the new scaling factor
90+
* @param alpha the new alpha value
91+
* @param minQuantile the new lower quantile of the distribution
92+
* @param maxQuantile the new upper quantile of the distribution
93+
* @return the new corrective offset
94+
*/
95+
float recalculateScalarQuantizationOffset(
96+
byte[] vector,
97+
float oldAlpha,
98+
float oldMinQuantile,
99+
float scale,
100+
float alpha,
101+
float minQuantile,
102+
float maxQuantile);
68103
}

lucene/core/src/java/org/apache/lucene/util/VectorUtil.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,4 +334,46 @@ public static int findNextGEQ(int[] buffer, int target, int from, int to) {
334334
assert IntStream.range(0, to - 1).noneMatch(i -> buffer[i] > buffer[i + 1]);
335335
return IMPL.findNextGEQ(buffer, target, from, to);
336336
}
337+
338+
/**
339+
* Scalar quantizes {@code vector}, putting the result into {@code dest}. The work is delegated
* to the active {@code IMPL} after the array lengths have been checked.
340+
*
341+
* @param vector the float vector to quantize
342+
* @param dest the destination array; must have the same length as {@code vector}
343+
* @param scale the factor that maps a clamped, min-shifted value onto the quantized range
344+
* @param alpha the factor that maps a quantized value back into the original value range
345+
* @param minQuantile the lower quantile of the distribution
346+
* @param maxQuantile the upper quantile of the distribution
347+
* @return the corrective offset that needs to be applied to the score
* @throws IllegalArgumentException if {@code vector} and {@code dest} differ in length
348+
*/
349+
public static float minMaxScalarQuantize(
350+
float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile) {
351+
// checked unconditionally here; the implementations only assert this
if (vector.length != dest.length)
352+
throw new IllegalArgumentException("source and destination arrays should be the same size");
353+
return IMPL.minMaxScalarQuantize(vector, dest, scale, alpha, minQuantile, maxQuantile);
354+
}
355+
356+
/**
357+
* Recalculates the corrective offset for an already-quantized {@code vector} under a new set of
* quantization parameters. The work is delegated to the active {@code IMPL}.
358+
*
359+
* @param vector the previously quantized vector (it is read, not re-quantized in place)
360+
* @param oldAlpha the alpha value {@code vector} was quantized with
361+
* @param oldMinQuantile the lower quantile {@code vector} was quantized with
362+
* @param scale the new scaling factor
363+
* @param alpha the new alpha value
364+
* @param minQuantile the new lower quantile of the distribution
365+
* @param maxQuantile the new upper quantile of the distribution
366+
* @return the new corrective offset
367+
*/
368+
public static float recalculateOffset(
369+
byte[] vector,
370+
float oldAlpha,
371+
float oldMinQuantile,
372+
float scale,
373+
float alpha,
374+
float minQuantile,
375+
float maxQuantile) {
376+
return IMPL.recalculateScalarQuantizationOffset(
377+
vector, oldAlpha, oldMinQuantile, scale, alpha, minQuantile, maxQuantile);
378+
}
337379
}

lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -122,40 +122,15 @@ public ScalarQuantizer(float minQuantile, float maxQuantile, byte bits) {
122122
public float quantize(float[] src, byte[] dest, VectorSimilarityFunction similarityFunction) {
123123
assert src.length == dest.length;
124124
assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(src);
125-
float correction = 0;
126-
for (int i = 0; i < src.length; i++) {
127-
correction += quantizeFloat(src[i], dest, i);
128-
}
125+
126+
float correction =
127+
VectorUtil.minMaxScalarQuantize(src, dest, scale, alpha, minQuantile, maxQuantile);
129128
if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
130129
return 0;
131130
}
132131
return correction;
133132
}
134133

135-
private float quantizeFloat(float v, byte[] dest, int destIndex) {
136-
assert dest == null || destIndex < dest.length;
137-
// Make sure the value is within the quantile range, cutting off the tails
138-
// see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
139-
// minQuantile)
140-
float dx = v - minQuantile;
141-
float dxc = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
142-
// Scale the value to the range [0, 127], this is our quantized value
143-
// scale = 127/(maxQuantile - minQuantile)
144-
float dxs = scale * dxc;
145-
// We multiply by `alpha` here to get the quantized value back into the original range
146-
// to aid in calculating the corrective offset
147-
float dxq = Math.round(dxs) * alpha;
148-
if (dest != null) {
149-
dest[destIndex] = (byte) Math.round(dxs);
150-
}
151-
// Calculate the corrective offset that needs to be applied to the score
152-
// in addition to the `byte * minQuantile * alpha` term in the equation
153-
// we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
154-
// will be rounded to the nearest whole number and lose some accuracy
155-
// Additionally, we account for the global correction of `minQuantile^2` in the equation
156-
return minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
157-
}
158-
159134
/**
160135
* Recalculate the old score corrective value given new current quantiles
161136
*
@@ -171,13 +146,14 @@ public float recalculateCorrectiveOffset(
171146
if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
172147
return 0f;
173148
}
174-
float correctiveOffset = 0f;
175-
for (int i = 0; i < quantizedVector.length; i++) {
176-
// dequantize the old value in order to recalculate the corrective offset
177-
float v = (oldQuantizer.alpha * quantizedVector[i]) + oldQuantizer.minQuantile;
178-
correctiveOffset += quantizeFloat(v, null, 0);
179-
}
180-
return correctiveOffset;
149+
return VectorUtil.recalculateOffset(
150+
quantizedVector,
151+
oldQuantizer.alpha,
152+
oldQuantizer.minQuantile,
153+
scale,
154+
alpha,
155+
minQuantile,
156+
maxQuantile);
181157
}
182158

183159
/**

lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,4 +907,98 @@ public static long int4BitDotProduct128(byte[] q, byte[] d) {
907907
}
908908
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
909909
}
910+
911+
// Vectorized quantization. When VECTOR_BITSIZE >= 256, whole FLOAT_SPECIES-sized chunks are
// processed with the Vector API; whatever remains (or the entire vector, if the vectorized
// branch is skipped) is handled by the scalar DefaultVectorUtilSupport.ScalarQuantizer
// starting at index i. The result is the vectorized partial sum plus the scalar tail sum.
// NOTE(review): `fma`, FLOAT_SPECIES, BYTE_SPECIES and VECTOR_BITSIZE are declared outside
// this view; `fma(a, b, c)` presumably computes a * b + c per lane - confirm.
@Override
912+
public float minMaxScalarQuantize(
913+
float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile) {
914+
assert vector.length == dest.length;
915+
float correction = 0;
916+
// i is shared by the vector loop and the scalar tail below
int i = 0;
917+
// only vectorize if we have a viable BYTE_SPECIES we can use for output
// NOTE(review): the loop stores one BYTE_SPECIES vector per FLOAT_SPECIES step, so it
// presumably relies on both species having the same lane count - confirm.
918+
if (VECTOR_BITSIZE >= 256) {
919+
// per-lane running sum of the corrective terms
FloatVector sum = FloatVector.zero(FLOAT_SPECIES);
920+
921+
for (; i < FLOAT_SPECIES.loopBound(vector.length); i += FLOAT_SPECIES.length()) {
922+
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
923+
924+
// Make sure the value is within the quantile range, cutting off the tails
925+
// see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
926+
// minQuantile)
927+
FloatVector dxc = v.min(maxQuantile).max(minQuantile).sub(minQuantile);
928+
// Scale the value to the range [0, 127], this is our quantized value
929+
// scale = 127/(maxQuantile - minQuantile)
930+
// The scalar path uses Math.round, which rounds to nearest with ties going towards
// positive infinity. Here 0.5 is added and the float-to-int conversion truncates;
// since dxc is clamped to be >= 0 this matches Math.round as long as scale is
// non-negative (NOTE(review): the fused multiply-add may differ in the last bit - confirm).
931+
Vector<Integer> roundedDxs =
932+
fma(dxc, dxc.broadcast(scale), dxc.broadcast(0.5f)).convert(VectorOperators.F2I, 0);
933+
// narrow the int lanes to bytes and store them at dest[i..]
934+
((ByteVector) roundedDxs.castShape(BYTE_SPECIES, 0)).intoArray(dest, i);
935+
// We multiply by `alpha` here to get the quantized value back into the original range
936+
// to aid in calculating the corrective offset
937+
FloatVector dxq = ((FloatVector) roundedDxs.castShape(FLOAT_SPECIES, 0)).mul(alpha);
938+
// Calculate the corrective offset that needs to be applied to the score
939+
// in addition to the `byte * minQuantile * alpha` term in the equation
940+
// we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
941+
// will be rounded to the nearest whole number and lose some accuracy
942+
// Additionally, we account for the global correction of `minQuantile^2` in the equation
943+
// i.e. per lane: sum += minQuantile * (v - minQuantile / 2) + (v - minQuantile - dxq) * dxq
sum =
944+
fma(
945+
v.sub(minQuantile / 2f),
946+
v.broadcast(minQuantile),
947+
fma(v.sub(minQuantile).sub(dxq), dxq, sum));
948+
}
949+
950+
// collapse the per-lane sums into a single float
correction = sum.reduceLanes(VectorOperators.ADD);
951+
}
952+
953+
// complete the tail normally
954+
correction +=
955+
new DefaultVectorUtilSupport.ScalarQuantizer(alpha, scale, minQuantile, maxQuantile)
956+
.quantize(vector, dest, i);
957+
958+
return correction;
959+
}
960+
961+
// Vectorized recalculation of the corrective offset for an already-quantized byte vector.
// Each chunk of bytes is widened to floats, dequantized with the old parameters, and then
// run through the same clamp / scale / round / correct sequence as minMaxScalarQuantize,
// except that no quantized output is stored. Elements not covered by the vector loop are
// handled by the scalar DefaultVectorUtilSupport.ScalarQuantizer starting at index i.
// NOTE(review): `fma`, FLOAT_SPECIES, BYTE_SPECIES and VECTOR_BITSIZE are declared outside
// this view; `fma(a, b, c)` presumably computes a * b + c per lane - confirm.
@Override
962+
public float recalculateScalarQuantizationOffset(
963+
byte[] vector,
964+
float oldAlpha,
965+
float oldMinQuantile,
966+
float scale,
967+
float alpha,
968+
float minQuantile,
969+
float maxQuantile) {
970+
float correction = 0;
971+
// i is shared by the vector loop and the scalar tail below
int i = 0;
972+
// only vectorize if we have a viable BYTE_SPECIES that we can use
973+
if (VECTOR_BITSIZE >= 256) {
974+
// per-lane running sum of the corrective terms
FloatVector sum = FloatVector.zero(FLOAT_SPECIES);
975+
976+
for (; i < BYTE_SPECIES.loopBound(vector.length); i += BYTE_SPECIES.length()) {
977+
// load one chunk of quantized bytes and widen the lanes to float
FloatVector fv =
978+
(FloatVector) ByteVector.fromArray(BYTE_SPECIES, vector, i).castShape(FLOAT_SPECIES, 0);
979+
// undo the old quantization
980+
FloatVector v = fma(fv, fv.broadcast(oldAlpha), fv.broadcast(oldMinQuantile));
981+
982+
// same operations as in minMaxScalarQuantize: clamp, scale, round (+0.5 then truncate),
// map back with alpha, and accumulate the corrective term
983+
FloatVector dxc = v.min(maxQuantile).max(minQuantile).sub(minQuantile);
984+
Vector<Integer> roundedDxs =
985+
fma(dxc, dxc.broadcast(scale), dxc.broadcast(0.5f)).convert(VectorOperators.F2I, 0);
986+
FloatVector dxq = ((FloatVector) roundedDxs.castShape(FLOAT_SPECIES, 0)).mul(alpha);
987+
// per lane: sum += minQuantile * (v - minQuantile / 2) + (v - minQuantile - dxq) * dxq
sum =
988+
fma(
989+
v.sub(minQuantile / 2f),
990+
v.broadcast(minQuantile),
991+
fma(v.sub(minQuantile).sub(dxq), dxq, sum));
992+
}
993+
994+
// collapse the per-lane sums into a single float
correction = sum.reduceLanes(VectorOperators.ADD);
995+
}
996+
997+
// complete the tail normally
998+
correction +=
999+
new DefaultVectorUtilSupport.ScalarQuantizer(alpha, scale, minQuantile, maxQuantile)
1000+
.recalculateOffset(vector, i, oldAlpha, oldMinQuantile);
1001+
1002+
return correction;
1003+
}
9101004
}

0 commit comments

Comments
 (0)