
Create vectorized versions of ScalarQuantizer.quantize and recalculateCorrectiveOffset #14304


Merged: 14 commits, merged Mar 25, 2025
Changes from 7 commits
@@ -234,4 +234,79 @@ public static long int4BitDotProductImpl(byte[] q, byte[] d) {
}
return ret;
}

@Override
public float quantize(
float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile) {
return new ScalarQuantizer(alpha, scale, minQuantile, maxQuantile).quantize(vector, dest, 0);
}

@Override
public float recalculateOffset(
byte[] vector,
float oldAlpha,
float oldMinQuantile,
float scale,
float alpha,
float minQuantile,
float maxQuantile) {
return new ScalarQuantizer(alpha, scale, minQuantile, maxQuantile)
.recalculateOffset(vector, 0, oldAlpha, oldMinQuantile);
}

static class ScalarQuantizer {
private final float alpha;
private final float scale;
private final float minQuantile, maxQuantile;

ScalarQuantizer(float alpha, float scale, float minQuantile, float maxQuantile) {
this.alpha = alpha;
this.scale = scale;
this.minQuantile = minQuantile;
this.maxQuantile = maxQuantile;
}

float quantize(float[] vector, byte[] dest, int start) {
assert vector.length == dest.length;
float correction = 0;
for (int i = start; i < vector.length; i++) {
correction += quantizeFloat(vector[i], dest, i);
}
return correction;
}

float recalculateOffset(byte[] vector, int start, float oldAlpha, float oldMinQuantile) {
float correction = 0;
for (int i = start; i < vector.length; i++) {
// undo the old quantization
float v = (oldAlpha * vector[i]) + oldMinQuantile;
correction += quantizeFloat(v, null, 0);
}
return correction;
}

private float quantizeFloat(float v, byte[] dest, int destIndex) {
assert dest == null || destIndex < dest.length;
// Make sure the value is within the quantile range, cutting off the tails
// see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
// minQuantile)
float dx = v - minQuantile;
float dxc = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
// Scale the value to the range [0, 127], this is our quantized value
// scale = 127/(maxQuantile - minQuantile)
int roundedDxs = Math.round(scale * dxc);
// We multiply by `alpha` here to get the quantized value back into the original range
// to aid in calculating the corrective offset
float dxq = roundedDxs * alpha;
if (dest != null) {
dest[destIndex] = (byte) roundedDxs;
}
// Calculate the corrective offset that needs to be applied to the score
// in addition to the `byte * minQuantile * alpha` term in the equation
// we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
// will be rounded to the nearest whole number and lose some accuracy
// Additionally, we account for the global correction of `minQuantile^2` in the equation
return minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
}
}
}
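For reference, this is the per-component arithmetic that quantizeFloat above performs, worked through on made-up values (the quantiles, alpha, and the input value are hypothetical, and alpha is assumed to be the inverse of scale for the 7-bit range):

// Hypothetical values, mirroring the formula in quantizeFloat above.
float minQuantile = -0.5f, maxQuantile = 0.5f;
float scale = 127f / (maxQuantile - minQuantile); // 127
float alpha = (maxQuantile - minQuantile) / 127f; // ~0.00787, assumed inverse of scale
float v = 0.3f;

float dx = v - minQuantile; // 0.8
float dxc = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile; // clamp to the quantile range, still 0.8
int roundedDxs = Math.round(scale * dxc); // round(101.6) = 102, the stored byte
float dxq = roundedDxs * alpha; // ~0.803, the quantized value mapped back to the original range
float correction = minQuantile * (v - minQuantile / 2.0f) + (dx - dxq) * dxq; // ~-0.2775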
@@ -65,4 +65,39 @@ public interface VectorUtilSupport {
* @return the dot product
*/
long int4BitDotProduct(byte[] int4Quantized, byte[] binaryQuantized);

/**
* Quantizes {@code vector}, putting the result into {@code dest}.
*
* @param vector the vector to quantize
* @param dest the destination vector
* @param scale the scaling factor
* @param alpha the alpha value
* @param minQuantile the lower quantile of the distribution
* @param maxQuantile the upper quantile of the distribution
* @return the corrective offset that needs to be applied to the score
*/
float quantize(
float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile);

/**
* Recalculates the offset for {@code vector}.
*
* @param vector the vector to quantize
* @param oldAlpha the previous alpha value
* @param oldMinQuantile the previous lower quantile
* @param scale the scaling factor
* @param alpha the alpha value
* @param minQuantile the lower quantile of the distribution
* @param maxQuantile the upper quantile of the distribution
* @return the new corrective offset
*/
float recalculateOffset(
byte[] vector,
float oldAlpha,
float oldMinQuantile,
float scale,
float alpha,
float minQuantile,
float maxQuantile);
}
41 changes: 41 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
@@ -334,4 +334,45 @@ public static int findNextGEQ(int[] buffer, int target, int from, int to) {
assert IntStream.range(0, to - 1).noneMatch(i -> buffer[i] > buffer[i + 1]);
return IMPL.findNextGEQ(buffer, target, from, to);
}

/**
* Quantizes {@code vector}, putting the result into {@code dest}.
*
* @param vector the vector to quantize
* @param dest the destination vector; must have the same length as {@code vector}
* @param scale the scaling factor
* @param alpha the alpha value
* @param minQuantile the lower quantile of the distribution
* @param maxQuantile the upper quantile of the distribution
* @return the corrective offset that needs to be applied to the score
*/
public static float quantize(
float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile) {
assert vector.length == dest.length;
return IMPL.quantize(vector, dest, scale, alpha, minQuantile, maxQuantile);
}

/**
* Recalculates the offset for {@code vector}.
*
* @param vector the vector to quantize
* @param oldAlpha the previous alpha value
* @param oldMinQuantile the previous lower quantile
* @param scale the scaling factor
* @param alpha the alpha value
* @param minQuantile the lower quantile of the distribution
* @param maxQuantile the upper quantile of the distribution
* @return the new corrective offset
*/
public static float recalculateOffset(
byte[] vector,
float oldAlpha,
float oldMinQuantile,
float scale,
float alpha,
float minQuantile,
float maxQuantile) {
return IMPL.recalculateOffset(
vector, oldAlpha, oldMinQuantile, scale, alpha, minQuantile, maxQuantile);
}
}
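A minimal usage sketch of the two new public entry points, using the method names as they appear in this revision of the diff (the review thread below suggests renaming them). All concrete values, and the QuantizeUsageSketch class itself, are hypothetical; alpha is assumed to be the inverse of scale:

import org.apache.lucene.util.VectorUtil;

class QuantizeUsageSketch {
  public static void main(String[] args) {
    // Hypothetical quantiles; alpha is assumed to be the inverse of scale (7-bit range).
    float minQuantile = -0.4f;
    float maxQuantile = 0.6f;
    float scale = 127f / (maxQuantile - minQuantile);
    float alpha = (maxQuantile - minQuantile) / 127f;

    float[] vector = {0.12f, -0.73f, 0.40f, 0.05f};
    byte[] dest = new byte[vector.length];
    // Quantize into dest; the return value is the corrective offset for the score.
    float correction = VectorUtil.quantize(vector, dest, scale, alpha, minQuantile, maxQuantile);

    // If the quantiles are later re-estimated, the stored bytes can be re-corrected
    // without the original floats (the new* values are hypothetical re-estimates).
    float newMin = -0.35f;
    float newMax = 0.55f;
    float newScale = 127f / (newMax - newMin);
    float newAlpha = (newMax - newMin) / 127f;
    float newCorrection =
        VectorUtil.recalculateOffset(dest, alpha, minQuantile, newScale, newAlpha, newMin, newMax);

    System.out.println(correction + " -> " + newCorrection);
  }
}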
@@ -122,40 +122,14 @@ public ScalarQuantizer(float minQuantile, float maxQuantile, byte bits) {
public float quantize(float[] src, byte[] dest, VectorSimilarityFunction similarityFunction) {
assert src.length == dest.length;
assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(src);
float correction = 0;
for (int i = 0; i < src.length; i++) {
correction += quantizeFloat(src[i], dest, i);
}

float correction = VectorUtil.quantize(src, dest, scale, alpha, minQuantile, maxQuantile);
if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
return 0;
}
return correction;
}

private float quantizeFloat(float v, byte[] dest, int destIndex) {
assert dest == null || destIndex < dest.length;
// Make sure the value is within the quantile range, cutting off the tails
// see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
// minQuantile)
float dx = v - minQuantile;
float dxc = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
// Scale the value to the range [0, 127], this is our quantized value
// scale = 127/(maxQuantile - minQuantile)
float dxs = scale * dxc;
// We multiply by `alpha` here to get the quantized value back into the original range
// to aid in calculating the corrective offset
float dxq = Math.round(dxs) * alpha;
if (dest != null) {
dest[destIndex] = (byte) Math.round(dxs);
}
// Calculate the corrective offset that needs to be applied to the score
// in addition to the `byte * minQuantile * alpha` term in the equation
// we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
// will be rounded to the nearest whole number and lose some accuracy
// Additionally, we account for the global correction of `minQuantile^2` in the equation
return minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
}

/**
* Recalculate the old score corrective value given new current quantiles
*
@@ -171,13 +145,14 @@ public float recalculateCorrectiveOffset(
if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
return 0f;
}
float correctiveOffset = 0f;
for (int i = 0; i < quantizedVector.length; i++) {
// dequantize the old value in order to recalculate the corrective offset
float v = (oldQuantizer.alpha * quantizedVector[i]) + oldQuantizer.minQuantile;
correctiveOffset += quantizeFloat(v, null, 0);
}
return correctiveOffset;
return VectorUtil.recalculateOffset(
quantizedVector,
oldQuantizer.alpha,
oldQuantizer.minQuantile,
scale,
alpha,
minQuantile,
maxQuantile);
}

/**
@@ -907,4 +907,87 @@ public static long int4BitDotProduct128(byte[] q, byte[] d) {
}
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
}

@Override
public float quantize(
Member

Let's name this something better, we can call it "minMaxScalarQuantization" or something?

Contributor Author

Done - and the recalculate method too.
float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile) {
float correction = 0;
int i = 0;
// only vectorize if we have a viable BYTE_SPECIES we can use for output
if (VECTOR_BITSIZE >= 256) {
for (; i < FLOAT_SPECIES.loopBound(vector.length); i += FLOAT_SPECIES.length()) {
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);

// Make sure the value is within the quantile range, cutting off the tails
// see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
// minQuantile)
FloatVector dxc = v.min(maxQuantile).max(minQuantile).sub(minQuantile);
// Scale the value to the range [0, 127], this is our quantized value
// scale = 127/(maxQuantile - minQuantile)
// Math.round rounds ties toward positive infinity; since dxc is non-negative here,
// adding 0.5 and truncating to int gives the same result
Vector<Integer> roundedDxs = dxc.mul(scale).add(0.5f).convert(VectorOperators.F2I, 0);
// output this to the array
((ByteVector) roundedDxs.castShape(BYTE_SPECIES, 0)).intoArray(dest, i);
// We multiply by `alpha` here to get the quantized value back into the original range
// to aid in calculating the corrective offset
Vector<Float> dxq = ((FloatVector) roundedDxs.castShape(FLOAT_SPECIES, 0)).mul(alpha);
// Calculate the corrective offset that needs to be applied to the score
// in addition to the `byte * minQuantile * alpha` term in the equation
// we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
// will be rounded to the nearest whole number and lose some accuracy
// Additionally, we account for the global correction of `minQuantile^2` in the equation
correction +=
v.sub(minQuantile / 2f)
.mul(minQuantile)
.add(v.sub(minQuantile).sub(dxq).mul(dxq))
.reduceLanes(VectorOperators.ADD);
Member

Could you collect the corrections in a float array? This way we keep all lanes parallelized and then sum the floats later?

Member

I think if you could keep the lanes separate for as long as possible, we get a bigger perf boost. Reducing lanes is a serious bottleneck.

Contributor Author (@thecoop, Mar 10, 2025)

Indeed it is - this doubles the performance:

Benchmark              Mode  Cnt     Score    Error   Units
Quant.quantize        thrpt    5   235.029 ±  3.204  ops/ms
Quant.quantizeVector  thrpt    5  2831.313 ± 46.475  ops/ms

Contributor Author

And even more with FMA operations.

Member

Oh yes ;). Those are the numbers I am expecting.

}
}

// complete the tail normally
correction +=
new DefaultVectorUtilSupport.ScalarQuantizer(alpha, scale, minQuantile, maxQuantile)
.quantize(vector, dest, i);

return correction;
}
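As a companion to the review thread above, here is a minimal sketch of the suggested variant: keep the per-lane corrections in a FloatVector accumulator, fold the multiply-adds into fma calls, and reduce lanes only once after the loop. It reuses FLOAT_SPECIES, BYTE_SPECIES and the surrounding method's parameters, and it illustrates the idea rather than reproducing the code that was actually merged.

int i = 0;
FloatVector acc = FloatVector.zero(FLOAT_SPECIES);
for (; i < FLOAT_SPECIES.loopBound(vector.length); i += FLOAT_SPECIES.length()) {
  FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
  FloatVector dxc = v.min(maxQuantile).max(minQuantile).sub(minQuantile);
  Vector<Integer> roundedDxs = dxc.mul(scale).add(0.5f).convert(VectorOperators.F2I, 0);
  ((ByteVector) roundedDxs.castShape(BYTE_SPECIES, 0)).intoArray(dest, i);
  FloatVector dxq = ((FloatVector) roundedDxs.castShape(FLOAT_SPECIES, 0)).mul(alpha);
  FloatVector dx = v.sub(minQuantile);
  // per-lane correction: minQuantile * (v - minQuantile/2) + (dx - dxq) * dxq,
  // accumulated lane-wise with fused multiply-adds instead of reducing every iteration
  acc = dx.sub(dxq).fma(dxq, acc);
  acc = v.sub(minQuantile / 2f).fma(FloatVector.broadcast(FLOAT_SPECIES, minQuantile), acc);
}
// a single lane reduction after the loop, then the scalar tail as before
float correction = acc.reduceLanes(VectorOperators.ADD);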

@Override
public float recalculateOffset(
byte[] vector,
float oldAlpha,
float oldMinQuantile,
float scale,
float alpha,
float minQuantile,
float maxQuantile) {
float correction = 0;
int i = 0;
// only vectorize if we have a viable BYTE_SPECIES that we can use
if (VECTOR_BITSIZE >= 256) {
for (; i < BYTE_SPECIES.loopBound(vector.length); i += BYTE_SPECIES.length()) {
ByteVector bv = ByteVector.fromArray(BYTE_SPECIES, vector, i);
// undo the old quantization
FloatVector v =
((FloatVector) bv.castShape(FLOAT_SPECIES, 0)).mul(oldAlpha).add(oldMinQuantile);

// same operations as in quantize above
FloatVector dxc = v.min(maxQuantile).max(minQuantile).sub(minQuantile);
Vector<Integer> roundedDxs = dxc.mul(scale).add(0.5f).convert(VectorOperators.F2I, 0);
Vector<Float> dxq = ((FloatVector) roundedDxs.castShape(FLOAT_SPECIES, 0)).mul(alpha);
correction +=
v.sub(minQuantile / 2f)
.mul(minQuantile)
.add(v.sub(minQuantile).sub(dxq).mul(dxq))
.reduceLanes(VectorOperators.ADD);
}
}

// complete the tail normally
correction +=
new DefaultVectorUtilSupport.ScalarQuantizer(alpha, scale, minQuantile, maxQuantile)
.recalculateOffset(vector, i, oldAlpha, oldMinQuantile);

return correction;
}
}
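The throughput numbers quoted in the review thread come from a JMH benchmark. A hypothetical harness exercising the new entry point could look like the following (the class name, vector size, and quantiles are assumptions, not the author's actual Quant benchmark):

import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.util.VectorUtil;
import org.openjdk.jmh.annotations.*;

@State(Scope.Benchmark)
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
public class QuantizeBench {
  // Hypothetical dimension and quantiles.
  float[] vector = new float[1024];
  byte[] dest = new byte[1024];
  float minQuantile = -0.5f;
  float maxQuantile = 0.5f;
  float scale = 127f / (maxQuantile - minQuantile);
  float alpha = (maxQuantile - minQuantile) / 127f;

  @Setup(Level.Trial)
  public void setup() {
    Random random = new Random(42);
    for (int i = 0; i < vector.length; i++) {
      vector[i] = random.nextFloat() - 0.5f;
    }
  }

  @Benchmark
  public float quantizeVector() {
    // Returning the correction keeps the JIT from eliminating the work.
    return VectorUtil.quantize(vector, dest, scale, alpha, minQuantile, maxQuantile);
  }
}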