Create vectorized versions of ScalarQuantizer.quantize and recalculateCorrectiveOffset #14304
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -907,4 +907,87 @@ public static long int4BitDotProduct128(byte[] q, byte[] d) { | |
} | ||
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); | ||
} | ||
|
||
@Override | ||
public float quantize( | ||
Review comment: Let's name this something better, we can call it "minMaxScalarQuantization" or something?
Reply: Done - and the recalculate method too
||
float[] vector, byte[] dest, float scale, float alpha, float minQuantile, float maxQuantile) { | ||
float correction = 0; | ||
thecoop marked this conversation as resolved.
|
||
int i = 0; | ||
// only vectorize if we have a viable BYTE_SPECIES we can use for output | ||
if (VECTOR_BITSIZE >= 256) { | ||
for (; i < FLOAT_SPECIES.loopBound(vector.length); i += FLOAT_SPECIES.length()) { | ||
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i); | ||
|
||
// Make sure the value is within the quantile range, cutting off the tails | ||
// see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile - | ||
// minQuantile) | ||
FloatVector dxc = v.min(maxQuantile).max(minQuantile).sub(minQuantile); | ||
// Scale the value to the range [0, 127], this is our quantized value | ||
// scale = 127/(maxQuantile - minQuantile) | ||
// Math.round rounds ties toward positive infinity; the clamped value is non-negative, so emulate it by adding 0.5 and truncating to int | |
Vector<Integer> roundedDxs = dxc.mul(scale).add(0.5f).convert(VectorOperators.F2I, 0); | ||
// output this to the array | ||
((ByteVector) roundedDxs.castShape(BYTE_SPECIES, 0)).intoArray(dest, i); | ||
// We multiply by `alpha` here to get the quantized value back into the original range | ||
// to aid in calculating the corrective offset | ||
Vector<Float> dxq = ((FloatVector) roundedDxs.castShape(FLOAT_SPECIES, 0)).mul(alpha); | ||
// Calculate the corrective offset that needs to be applied to the score | ||
// in addition to the `byte * minQuantile * alpha` term in the equation | ||
// we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value | ||
// will be rounded to the nearest whole number and lose some accuracy | ||
// Additionally, we account for the global correction of `minQuantile^2` in the equation | ||
correction += | ||
v.sub(minQuantile / 2f) | ||
.mul(minQuantile) | ||
.add(v.sub(minQuantile).sub(dxq).mul(dxq)) | ||
.reduceLanes(VectorOperators.ADD); | ||
Review comment: Could you collect the corrections in a float array? This way we keep all lanes parallelized and then sum the floats later.
Review comment: I think if you could keep the lanes separate for as long as possible, we get a bigger perf boost. Reducing lanes is a serious bottleneck.
Reply: Indeed it is - this doubles the performance
Reply: And even more with FMA operations
Review comment: Oh yes ;). Those are the numbers I am expecting.
||
} | ||
} | ||
|
||
// complete the tail normally | ||
correction += | ||
new DefaultVectorUtilSupport.ScalarQuantizer(alpha, scale, minQuantile, maxQuantile) | ||
.quantize(vector, dest, i); | ||
|
||
return correction; | ||
} | ||
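The review thread above suggests keeping per-lane partial sums and reducing only once after the loop, with a further gain from FMA. The block below is a scalar emulation of that accumulation pattern in plain Java, not the PR's code. `LaneAccumulationSketch`, `LANES`, `dxq`, `reduceEveryStep`, and `reduceOnce` are hypothetical names, and `alpha = (maxQuantile - minQuantile) / 127` is an assumption inferred from the comments in the diff. It does not use the Panama Vector API and says nothing about real performance; it only shows that the two summation orders agree up to floating-point rounding.

```java
final class LaneAccumulationSketch {
    // Hypothetical lane count: a 256-bit vector holds 8 floats.
    static final int LANES = 8;

    // Clamp, scale, round, then map back to the original range (dxq in the diff).
    static float dxq(float v, float scale, float alpha, float minQ, float maxQ) {
        float dxc = Math.max(minQ, Math.min(maxQ, v)) - minQ;
        int rounded = (int) (dxc * scale + 0.5f); // dxc >= 0, so add 0.5 and truncate
        return rounded * alpha;
    }

    // Adds every element's correction straight into one scalar,
    // analogous to calling reduceLanes on each loop iteration.
    static float reduceEveryStep(float[] vec, float scale, float alpha, float minQ, float maxQ) {
        float correction = 0;
        for (float v : vec) {
            float q = dxq(v, scale, alpha, minQ, maxQ);
            correction += (v - minQ / 2f) * minQ + (v - minQ - q) * q;
        }
        return correction;
    }

    // Keeps LANES independent partial sums, updates them with fused
    // multiply-adds, and combines them once after the main loop.
    static float reduceOnce(float[] vec, float scale, float alpha, float minQ, float maxQ) {
        float[] acc = new float[LANES];
        int bound = vec.length - vec.length % LANES; // plays the role of loopBound
        int i = 0;
        for (; i < bound; i += LANES) {
            for (int lane = 0; lane < LANES; lane++) {
                float v = vec[i + lane];
                float q = dxq(v, scale, alpha, minQ, maxQ);
                acc[lane] = Math.fma(v - minQ / 2f, minQ, acc[lane]);
                acc[lane] = Math.fma(v - minQ - q, q, acc[lane]);
            }
        }
        float correction = 0;
        for (float a : acc) {
            correction += a; // the single reduction
        }
        for (; i < vec.length; i++) { // scalar tail
            float v = vec[i];
            float q = dxq(v, scale, alpha, minQ, maxQ);
            correction += (v - minQ / 2f) * minQ + (v - minQ - q) * q;
        }
        return correction;
    }

    public static void main(String[] args) {
        float minQ = -0.8f, maxQ = 0.9f;
        float scale = 127f / (maxQ - minQ);
        float alpha = (maxQ - minQ) / 127f; // assumed inverse of scale
        float[] vec = new float[19];
        for (int k = 0; k < vec.length; k++) {
            vec[k] = -1.0f + 0.11f * k;
        }
        System.out.println(reduceEveryStep(vec, scale, alpha, minQ, maxQ));
        System.out.println(reduceOnce(vec, scale, alpha, minQ, maxQ));
    }
}
```

Because the partial sums are combined in a different order and FMA rounds once instead of twice, the two results can differ in the last few bits, so a comparison should use a tolerance rather than exact equality.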
|
||
@Override | ||
public float recalculateOffset( | ||
byte[] vector, | ||
float oldAlpha, | ||
float oldMinQuantile, | ||
float scale, | ||
float alpha, | ||
float minQuantile, | ||
float maxQuantile) { | ||
float correction = 0; | ||
int i = 0; | ||
// only vectorize if we have a viable BYTE_SPECIES that we can use | ||
if (VECTOR_BITSIZE >= 256) { | ||
for (; i < BYTE_SPECIES.loopBound(vector.length); i += BYTE_SPECIES.length()) { | ||
ByteVector bv = ByteVector.fromArray(BYTE_SPECIES, vector, i); | ||
// undo the old quantization | ||
FloatVector v = | ||
((FloatVector) bv.castShape(FLOAT_SPECIES, 0)).mul(oldAlpha).add(oldMinQuantile); | ||
|
||
// same operations as in quantize above | ||
FloatVector dxc = v.min(maxQuantile).max(minQuantile).sub(minQuantile); | ||
Vector<Integer> roundedDxs = dxc.mul(scale).add(0.5f).convert(VectorOperators.F2I, 0); | ||
Vector<Float> dxq = ((FloatVector) roundedDxs.castShape(FLOAT_SPECIES, 0)).mul(alpha); | ||
correction += | ||
v.sub(minQuantile / 2f) | ||
.mul(minQuantile) | ||
.add(v.sub(minQuantile).sub(dxq).mul(dxq)) | ||
.reduceLanes(VectorOperators.ADD); | ||
} | ||
} | ||
|
||
// complete the tail normally | ||
correction += | ||
new DefaultVectorUtilSupport.ScalarQuantizer(alpha, scale, minQuantile, maxQuantile) | ||
.recalculateOffset(vector, i, oldAlpha, oldMinQuantile); | ||
|
||
return correction; | ||
} | ||
} |
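For comparison with the vectorized `recalculateOffset` above, the same per-element arithmetic can be written as a scalar loop. This is a sketch derived from the operations visible in the diff, not the scalar implementation the PR delegates to for the tail; the class name `RecalculateOffsetSketch` is hypothetical, and the parameter values in `main` are chosen only so the arithmetic is exact in `float`.

```java
final class RecalculateOffsetSketch {
    // Dequantizes each byte with the old parameters, requantizes it with the
    // new ones, and accumulates the corrective offset term by term.
    static float recalculateOffset(
            byte[] quantized,
            float oldAlpha,
            float oldMinQuantile,
            float scale,
            float alpha,
            float minQuantile,
            float maxQuantile) {
        float correction = 0;
        for (byte b : quantized) {
            // undo the old quantization
            float v = b * oldAlpha + oldMinQuantile;
            // clamp to the new quantile range and shift to start at zero
            float dxc = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
            // scale and round (dxc is non-negative, so add 0.5 and truncate)
            int rounded = (int) (dxc * scale + 0.5f);
            // map the rounded value back into the original range
            float dxq = rounded * alpha;
            correction += (v - minQuantile / 2f) * minQuantile + (v - minQuantile - dxq) * dxq;
        }
        return correction;
    }

    public static void main(String[] args) {
        // Illustrative parameters: scale = alpha = 1 keeps every step exact.
        byte[] quantized = {0, 10};
        System.out.println(recalculateOffset(quantized, 1f, 1f, 1f, 1f, 1f, 128f));
    }
}
```

With these inputs the dequantized values are 1 and 11, neither is clipped, and the rounding error term `(v - minQuantile - dxq) * dxq` is zero for both, so only the `minQuantile` term contributes.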