Skip to content

Create vectorized versions of ScalarQuantizer.quantize and recalculateCorrectiveOffset #14304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Optimizations
---------------------
* GITHUB#14011: Reduce allocation rate in HNSW concurrent merge. (Viliam Durina)
* GITHUB#14022: Optimize DFS marking of connected components in HNSW by reducing stack depth, improving performance and reducing allocations. (Viswanath Kuchibhotla)
* GITHUB#14304: Add SIMD optimizations for scalar quantized queries and indexing. (Simon Cooper)

Bug Fixes
---------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,79 @@ public static long int4BitDotProductImpl(byte[] q, byte[] d) {
}
return ret;
}

@Override
public float minMaxScalarQuantize(
    float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile) {
  // Scalar (non-SIMD) path: build a quantizer with the given parameters and run it
  // over the whole vector, starting at index 0.
  ScalarQuantizer quantizer = new ScalarQuantizer(alpha, scale, minQuantile, maxQuantile);
  return quantizer.quantize(vector, dest, 0);
}

@Override
public float recalculateScalarQuantizationOffset(
    byte[] vector,
    float oldAlpha,
    float oldMinQuantile,
    float scale,
    float alpha,
    float minQuantile,
    float maxQuantile) {
  // Scalar (non-SIMD) path: the quantizer dequantizes each byte using the old
  // parameters and recomputes the corrective offset under the new ones.
  ScalarQuantizer quantizer = new ScalarQuantizer(alpha, scale, minQuantile, maxQuantile);
  return quantizer.recalculateOffset(vector, 0, oldAlpha, oldMinQuantile);
}

/**
 * Scalar (non-vectorized) implementation of min/max scalar quantization. Quantizes floats into
 * bytes by clamping to {@code [minQuantile, maxQuantile]} and scaling, and accumulates the
 * corrective offset needed to adjust scores computed on the quantized values.
 */
static class ScalarQuantizer {
  private final float alpha;
  private final float scale;
  private final float minQuantile, maxQuantile;

  ScalarQuantizer(float alpha, float scale, float minQuantile, float maxQuantile) {
    this.alpha = alpha;
    this.scale = scale;
    this.minQuantile = minQuantile;
    this.maxQuantile = maxQuantile;
  }

  /**
   * Quantizes {@code vector[start..]} into {@code dest}, returning the accumulated corrective
   * offset for the processed elements.
   */
  float quantize(float[] vector, byte[] dest, int start) {
    assert vector.length == dest.length;
    float total = 0;
    int idx = start;
    while (idx < vector.length) {
      total += quantizeFloat(vector[idx], dest, idx);
      idx++;
    }
    return total;
  }

  /**
   * Recomputes the corrective offset for an already-quantized {@code vector}, dequantizing each
   * byte with the old parameters before re-quantizing (without writing) under the new ones.
   */
  float recalculateOffset(byte[] vector, int start, float oldAlpha, float oldMinQuantile) {
    float total = 0;
    for (int idx = start; idx < vector.length; idx++) {
      // Reverse the previous quantization to recover an approximation of the original value
      float dequantized = (oldAlpha * vector[idx]) + oldMinQuantile;
      // null dest: we only want the corrective term, not the quantized byte
      total += quantizeFloat(dequantized, null, 0);
    }
    return total;
  }

  /**
   * Quantizes a single value, optionally storing the byte into {@code dest[destIndex]}, and
   * returns its contribution to the corrective offset.
   */
  private float quantizeFloat(float v, byte[] dest, int destIndex) {
    assert dest == null || destIndex < dest.length;
    // Distance from the lower quantile, before clamping
    // see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
    // minQuantile)
    float delta = v - minQuantile;
    // Clamp into [minQuantile, maxQuantile] so the tails of the distribution are cut off
    float clamped = Math.min(maxQuantile, v);
    clamped = Math.max(minQuantile, clamped);
    float clampedDelta = clamped - minQuantile;
    // Scale into [0, 127]; scale = 127/(maxQuantile - minQuantile)
    int quantized = Math.round(scale * clampedDelta);
    // Multiply by alpha to map the quantized value back into the original range,
    // which is needed to compute the corrective offset below
    float requantized = quantized * alpha;
    if (dest != null) {
      dest[destIndex] = (byte) quantized;
    }
    // Corrective offset applied to the score in addition to the
    // `byte * minQuantile * alpha` term in the scoring equation:
    //  - `(delta - requantized) * requantized` accounts for rounding loss
    //  - `minQuantile * (v - minQuantile / 2)` folds in the global `minQuantile^2` correction
    return minQuantile * (v - minQuantile / 2.0F) + (delta - requantized) * requantized;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,39 @@ public interface VectorUtilSupport {
* @return the dot product
*/
long int4BitDotProduct(byte[] int4Quantized, byte[] binaryQuantized);

/**
 * Scalar-quantizes {@code vector}, writing one quantized byte per input float into {@code dest}.
 * Values are clamped into {@code [minQuantile, maxQuantile]} before scaling.
 *
 * @param vector the vector to quantize
 * @param dest the destination vector; must be the same length as {@code vector}
 * @param scale the scaling factor applied after clamping (presumably 127/(maxQuantile -
 *     minQuantile) — confirm against callers)
 * @param alpha the alpha value used to map quantized values back to the original range
 * @param minQuantile the lower quantile of the distribution
 * @param maxQuantile the upper quantile of the distribution
 * @return the corrective offset that needs to be applied to the score
 */
float minMaxScalarQuantize(
    float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile);

/**
 * Recalculates the corrective offset for an already-quantized {@code vector}: each byte is
 * dequantized using the old parameters, then the offset is recomputed under the new ones.
 * No bytes are written; only the new offset is returned.
 *
 * @param vector the previously quantized vector
 * @param oldAlpha the previous alpha value, used to dequantize
 * @param oldMinQuantile the previous lower quantile, used to dequantize
 * @param scale the scaling factor
 * @param alpha the alpha value
 * @param minQuantile the lower quantile of the distribution
 * @param maxQuantile the upper quantile of the distribution
 * @return the new corrective offset
 */
float recalculateScalarQuantizationOffset(
    byte[] vector,
    float oldAlpha,
    float oldMinQuantile,
    float scale,
    float alpha,
    float minQuantile,
    float maxQuantile);
}
42 changes: 42 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -334,4 +334,46 @@ public static int findNextGEQ(int[] buffer, int target, int from, int to) {
assert IntStream.range(0, to - 1).noneMatch(i -> buffer[i] > buffer[i + 1]);
return IMPL.findNextGEQ(buffer, target, from, to);
}

/**
 * Scalar quantizes {@code vector}, putting the result into {@code dest}.
 *
 * @param vector the vector to quantize
 * @param dest the destination vector; must be the same length as {@code vector}
 * @param scale the scaling factor
 * @param alpha the alpha value
 * @param minQuantile the lower quantile of the distribution
 * @param maxQuantile the upper quantile of the distribution
 * @return the corrective offset that needs to be applied to the score
 * @throws IllegalArgumentException if {@code vector} and {@code dest} have different lengths
 */
public static float minMaxScalarQuantize(
    float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile) {
  // Always brace conditionals (project convention); validate here so the failure mode does not
  // depend on implementation-level asserts being enabled.
  if (vector.length != dest.length) {
    throw new IllegalArgumentException("source and destination arrays should be the same size");
  }
  return IMPL.minMaxScalarQuantize(vector, dest, scale, alpha, minQuantile, maxQuantile);
}

/**
 * Recalculates the corrective offset for an already-quantized {@code vector} under new
 * quantization parameters, dequantizing each byte with the old parameters first.
 *
 * @param vector the vector to quantize
 * @param oldAlpha the previous alpha value
 * @param oldMinQuantile the previous lower quantile
 * @param scale the scaling factor
 * @param alpha the alpha value
 * @param minQuantile the lower quantile of the distribution
 * @param maxQuantile the upper quantile of the distribution
 * @return the new corrective offset
 */
public static float recalculateOffset(
    byte[] vector,
    float oldAlpha,
    float oldMinQuantile,
    float scale,
    float alpha,
    float minQuantile,
    float maxQuantile) {
  // Delegate to the platform-selected implementation (scalar or SIMD).
  float offset =
      IMPL.recalculateScalarQuantizationOffset(
          vector, oldAlpha, oldMinQuantile, scale, alpha, minQuantile, maxQuantile);
  return offset;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -122,40 +122,15 @@ public ScalarQuantizer(float minQuantile, float maxQuantile, byte bits) {
public float quantize(float[] src, byte[] dest, VectorSimilarityFunction similarityFunction) {
assert src.length == dest.length;
assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(src);
float correction = 0;
for (int i = 0; i < src.length; i++) {
correction += quantizeFloat(src[i], dest, i);
}

float correction =
VectorUtil.minMaxScalarQuantize(src, dest, scale, alpha, minQuantile, maxQuantile);
if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
return 0;
}
return correction;
}

/**
 * Quantizes a single float {@code v}, optionally writing the quantized byte into
 * {@code dest[destIndex]} (skipped when {@code dest} is null), and returns the value's
 * contribution to the corrective score offset.
 *
 * @param v the value to quantize
 * @param dest the destination array, or null to compute only the corrective term
 * @param destIndex the index into {@code dest} to write
 * @return this value's contribution to the corrective offset
 */
private float quantizeFloat(float v, byte[] dest, int destIndex) {
assert dest == null || destIndex < dest.length;
// Make sure the value is within the quantile range, cutting off the tails
// see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
// minQuantile)
float dx = v - minQuantile;
float dxc = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
// Scale the value to the range [0, 127], this is our quantized value
// scale = 127/(maxQuantile - minQuantile)
float dxs = scale * dxc;
// We multiply by `alpha` here to get the quantized value back into the original range
// to aid in calculating the corrective offset
float dxq = Math.round(dxs) * alpha;
if (dest != null) {
dest[destIndex] = (byte) Math.round(dxs);
}
// Calculate the corrective offset that needs to be applied to the score
// in addition to the `byte * minQuantile * alpha` term in the equation
// we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
// will be rounded to the nearest whole number and lose some accuracy
// Additionally, we account for the global correction of `minQuantile^2` in the equation
return minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
}

/**
* Recalculate the old score corrective value given new current quantiles
*
Expand All @@ -171,13 +146,14 @@ public float recalculateCorrectiveOffset(
if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
return 0f;
}
float correctiveOffset = 0f;
for (int i = 0; i < quantizedVector.length; i++) {
// dequantize the old value in order to recalculate the corrective offset
float v = (oldQuantizer.alpha * quantizedVector[i]) + oldQuantizer.minQuantile;
correctiveOffset += quantizeFloat(v, null, 0);
}
return correctiveOffset;
return VectorUtil.recalculateOffset(
quantizedVector,
oldQuantizer.alpha,
oldQuantizer.minQuantile,
scale,
alpha,
minQuantile,
maxQuantile);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -907,4 +907,98 @@ public static long int4BitDotProduct128(byte[] q, byte[] d) {
}
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
}

/**
 * SIMD implementation of min/max scalar quantization using the Panama Vector API.
 * Processes full float lanes vectorized when the platform vector width is at least 256 bits,
 * then finishes the remaining tail with the scalar implementation so results cover the whole
 * array. NOTE(review): the vectorized and scalar paths may differ in the last ulp because the
 * vector path uses fma — confirm callers tolerate this.
 */
@Override
public float minMaxScalarQuantize(
    float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile) {
  assert vector.length == dest.length;
  float correction = 0;
  int i = 0;
  // only vectorize if we have a viable BYTE_SPECIES we can use for output
  if (VECTOR_BITSIZE >= 256) {
    FloatVector sum = FloatVector.zero(FLOAT_SPECIES);

    for (; i < FLOAT_SPECIES.loopBound(vector.length); i += FLOAT_SPECIES.length()) {
      FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);

      // Make sure the value is within the quantile range, cutting off the tails
      // see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
      // minQuantile)
      FloatVector dxc = v.min(maxQuantile).max(minQuantile).sub(minQuantile);
      // Scale the value to the range [0, 127], this is our quantized value
      // scale = 127/(maxQuantile - minQuantile)
      // Math.round rounds to positive infinity, so do the same by +0.5 then truncating to int
      // (valid because dxc is clamped non-negative, so truncation equals floor here)
      Vector<Integer> roundedDxs =
          fma(dxc, dxc.broadcast(scale), dxc.broadcast(0.5f)).convert(VectorOperators.F2I, 0);
      // output this to the array
      // assumes BYTE_SPECIES has the same lane count as FLOAT_SPECIES — confirm declarations
      ((ByteVector) roundedDxs.castShape(BYTE_SPECIES, 0)).intoArray(dest, i);
      // We multiply by `alpha` here to get the quantized value back into the original range
      // to aid in calculating the corrective offset
      FloatVector dxq = ((FloatVector) roundedDxs.castShape(FLOAT_SPECIES, 0)).mul(alpha);
      // Calculate the corrective offset that needs to be applied to the score
      // in addition to the `byte * minQuantile * alpha` term in the equation
      // we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
      // will be rounded to the nearest whole number and lose some accuracy
      // Additionally, we account for the global correction of `minQuantile^2` in the equation
      sum =
          fma(
              v.sub(minQuantile / 2f),
              v.broadcast(minQuantile),
              fma(v.sub(minQuantile).sub(dxq), dxq, sum));
    }

    correction = sum.reduceLanes(VectorOperators.ADD);
  }

  // complete the tail normally
  correction +=
      new DefaultVectorUtilSupport.ScalarQuantizer(alpha, scale, minQuantile, maxQuantile)
          .quantize(vector, dest, i);

  return correction;
}

/**
 * SIMD recalculation of the corrective offset for an already-quantized vector: each byte is
 * dequantized with the old parameters, then re-quantized (without writing) under the new ones
 * to accumulate the new offset. Vectorizes when the platform width is at least 256 bits and
 * finishes the tail with the scalar implementation.
 */
@Override
public float recalculateScalarQuantizationOffset(
    byte[] vector,
    float oldAlpha,
    float oldMinQuantile,
    float scale,
    float alpha,
    float minQuantile,
    float maxQuantile) {
  float correction = 0;
  int i = 0;
  // only vectorize if we have a viable BYTE_SPECIES that we can use
  if (VECTOR_BITSIZE >= 256) {
    FloatVector sum = FloatVector.zero(FLOAT_SPECIES);

    // assumes BYTE_SPECIES and FLOAT_SPECIES have matching lane counts, so castShape part 0
    // covers every loaded byte — confirm against the species declarations
    for (; i < BYTE_SPECIES.loopBound(vector.length); i += BYTE_SPECIES.length()) {
      FloatVector fv =
          (FloatVector) ByteVector.fromArray(BYTE_SPECIES, vector, i).castShape(FLOAT_SPECIES, 0);
      // undo the old quantization
      FloatVector v = fma(fv, fv.broadcast(oldAlpha), fv.broadcast(oldMinQuantile));

      // same operations as in quantize above
      FloatVector dxc = v.min(maxQuantile).max(minQuantile).sub(minQuantile);
      Vector<Integer> roundedDxs =
          fma(dxc, dxc.broadcast(scale), dxc.broadcast(0.5f)).convert(VectorOperators.F2I, 0);
      FloatVector dxq = ((FloatVector) roundedDxs.castShape(FLOAT_SPECIES, 0)).mul(alpha);
      sum =
          fma(
              v.sub(minQuantile / 2f),
              v.broadcast(minQuantile),
              fma(v.sub(minQuantile).sub(dxq), dxq, sum));
    }

    correction = sum.reduceLanes(VectorOperators.ADD);
  }

  // complete the tail normally
  correction +=
      new DefaultVectorUtilSupport.ScalarQuantizer(alpha, scale, minQuantile, maxQuantile)
          .recalculateOffset(vector, i, oldAlpha, oldMinQuantile);

  return correction;
}
}
Loading