Commit 049f8e9

Fix the type of Normalization output (#2833)
* parallel for
  Signed-off-by: Chen Tong <[email protected]>
* run through
  Signed-off-by: Tong Chen <[email protected]>
* flag
  Signed-off-by: Tong Chen <[email protected]>
* format
  Signed-off-by: chentong319 <[email protected]>
* Revert "format"
  This reverts commit ebd27e2.
* Revert "flag"
  This reverts commit 432e4be.
* fix merge error
  Signed-off-by: chentong319 <[email protected]>
* pass test
  Signed-off-by: chentong319 <[email protected]>

---------

Signed-off-by: Chen Tong <[email protected]>
Signed-off-by: Tong Chen <[email protected]>
Signed-off-by: Tong Chen <[email protected]>
Signed-off-by: chentong319 <[email protected]>
Parent: 35f66a0 | Commit: 049f8e9

File tree: 1 file changed (+20, -3 lines)


src/Conversion/ONNXToKrnl/NN/Normalization.cpp

Lines changed: 20 additions & 3 deletions
@@ -339,7 +339,8 @@ static inline void replaceGenericLayerNormOp(
 // TODO: conversions of types are not handled.
 template <typename OP_TYPE>
 LogicalResult generateGenericLayerNormOpONNXCode(
-    ConversionPatternRewriter &rewriter, Location loc, OP_TYPE lnOp) {
+    ConversionPatternRewriter &rewriter, Location loc, OP_TYPE lnOp,
+    const TypeConverter *const typeConverter) {
   MDBuilder create(rewriter, loc);
   Value X = lnOp.getX(); // Original value, not translated.
   TensorType XType = X.getType().cast<TensorType>();
@@ -368,6 +369,10 @@ LogicalResult generateGenericLayerNormOpONNXCode(
   if constexpr (std::is_same<OP_TYPE, ONNXLayerNormalizationOp>::value) {
     // Reduction of input
     meanOfX = create.onnx.reduceMean(reductionType, X, axes);
+    Type originType = lnOp.getMean().getType();
+    if (hasStaticShape(originType)) {
+      meanOfX.setType(typeConverter->convertType(originType));
+    }
     Value pow2OfMeanOfX = create.onnx.mul(meanOfX, meanOfX);
     Value XPow2 = create.onnx.mul(X, X);
     Value meanOfXPow2 = create.onnx.reduceMean(reductionType, XPow2, axes);
@@ -383,19 +388,30 @@ LogicalResult generateGenericLayerNormOpONNXCode(
   Value varWithEpsilon = create.onnx.add(var, epsilon);
   Value stdDev = create.onnx.sqrt(varWithEpsilon);
   Value invStdDev = create.onnx.reciprocal(stdDev);
+  Type originType = lnOp.getInvStdDev().getType();
+  if (hasStaticShape(originType)) {
+    invStdDev.setType(typeConverter->convertType(originType));
+  }
   Value normalized = create.onnx.mul(d, invStdDev);
   Value Y = create.onnx.mul(normalized, lnOp.getScale());
-  if (!isNoneValue(lnOp.getB()))
+  if (!isNoneValue(lnOp.getB())) {
     Y = create.onnx.add(Y, lnOp.getB());
+    Type originYType = lnOp.getY().getType();
+    if (hasStaticShape(originYType)) {
+      Y.setType(typeConverter->convertType(originYType));
+    }
+  }
   replaceGenericLayerNormOp<OP_TYPE>(rewriter, lnOp, Y, meanOfX, invStdDev);
   return success();
 }
 
+/*
 LogicalResult generateONNXLayerNormalizationOpONNXCode(
     ConversionPatternRewriter &rewriter, Location loc,
     ONNXLayerNormalizationOp lnOp) {
   return generateGenericLayerNormOpONNXCode(rewriter, loc, lnOp);
 }
+*/
 
 template <typename OP_TYPE, typename SHAPE_HELPER_TYPE>
 struct GenericLayerNormaOpLowering : public OpConversionPattern<OP_TYPE> {
@@ -667,7 +683,8 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern<OP_TYPE> {
     return generateSIMDCode(rewriter, loc, lnOp, adaptor, shapeHelper, 4, VL,
         scaleBroadcastKind, biasBroadcastKind, scaleModFactor, biasModFactor);
   }
-  return generateGenericLayerNormOpONNXCode(rewriter, loc, lnOp);
+  return generateGenericLayerNormOpONNXCode(
+      rewriter, loc, lnOp, this->typeConverter);
 }
 
 using F1 = std::function<void(int64_t offsetInt, Value offsetVal)>;
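Every hunk applies the same fix: after the ONNX builder materializes one of the op's results (Mean, InvStdDev, Y), its type is overwritten with the type-converted type of the original result, but only when that original type has a static shape. Below is a minimal sketch of that pattern factored into a standalone helper. The names `adjustOutputType` and `hasStaticShapeSketch` are hypothetical (the commit inlines the logic at each call site, using onnx-mlir's own `hasStaticShape`); `TypeConverter::convertType`, `Value::setType`, and `ShapedType::hasStaticShape` are standard MLIR APIs.

```cpp
// Hedged sketch of the pattern repeated in the hunks above, not onnx-mlir code.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Stand-in for onnx-mlir's hasStaticShape helper: true iff the type is a
// shaped type whose dimensions are all known at compile time.
static bool hasStaticShapeSketch(Type ty) {
  auto shaped = ty.dyn_cast<ShapedType>();
  return shaped && shaped.hasStaticShape();
}

// Stamp a freshly built value with the type-converted original result type
// instead of the builder-inferred one. Like the commit, this assumes
// convertType succeeds whenever the original type is fully static.
static void adjustOutputType(
    Value newVal, Type originType, const TypeConverter *typeConverter) {
  if (hasStaticShapeSketch(originType))
    newVal.setType(typeConverter->convertType(originType));
}

// Usage, mirroring the meanOfX hunk:
//   meanOfX = create.onnx.reduceMean(reductionType, X, axes);
//   adjustOutputType(meanOfX, lnOp.getMean().getType(), typeConverter);
```

The last hunk threads `this->typeConverter` from the enclosing `OpConversionPattern` into the generic code generator, which is what makes the converted types available at these call sites.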
