@@ -339,7 +339,8 @@ static inline void replaceGenericLayerNormOp(
339
339
// TODO: conversions of types are not handled.
340
340
template <typename OP_TYPE>
341
341
LogicalResult generateGenericLayerNormOpONNXCode (
342
- ConversionPatternRewriter &rewriter, Location loc, OP_TYPE lnOp) {
342
+ ConversionPatternRewriter &rewriter, Location loc, OP_TYPE lnOp,
343
+ const TypeConverter *const typeConverter) {
343
344
MDBuilder create (rewriter, loc);
344
345
Value X = lnOp.getX (); // Original value, not translated.
345
346
TensorType XType = X.getType ().cast <TensorType>();
@@ -368,6 +369,10 @@ LogicalResult generateGenericLayerNormOpONNXCode(
368
369
if constexpr (std::is_same<OP_TYPE, ONNXLayerNormalizationOp>::value) {
369
370
// Reduction of input
370
371
meanOfX = create.onnx .reduceMean (reductionType, X, axes);
372
+ Type originType = lnOp.getMean ().getType ();
373
+ if (hasStaticShape (originType)) {
374
+ meanOfX.setType (typeConverter->convertType (originType));
375
+ }
371
376
Value pow2OfMeanOfX = create.onnx .mul (meanOfX, meanOfX);
372
377
Value XPow2 = create.onnx .mul (X, X);
373
378
Value meanOfXPow2 = create.onnx .reduceMean (reductionType, XPow2, axes);
@@ -383,19 +388,30 @@ LogicalResult generateGenericLayerNormOpONNXCode(
383
388
Value varWithEpsilon = create.onnx .add (var, epsilon);
384
389
Value stdDev = create.onnx .sqrt (varWithEpsilon);
385
390
Value invStdDev = create.onnx .reciprocal (stdDev);
391
+ Type originType = lnOp.getInvStdDev ().getType ();
392
+ if (hasStaticShape (originType)) {
393
+ invStdDev.setType (typeConverter->convertType (originType));
394
+ }
386
395
Value normalized = create.onnx .mul (d, invStdDev);
387
396
Value Y = create.onnx .mul (normalized, lnOp.getScale ());
388
- if (!isNoneValue (lnOp.getB ()))
397
+ if (!isNoneValue (lnOp.getB ())) {
389
398
Y = create.onnx .add (Y, lnOp.getB ());
399
+ Type originYType = lnOp.getY ().getType ();
400
+ if (hasStaticShape (originYType)) {
401
+ Y.setType (typeConverter->convertType (originYType));
402
+ }
403
+ }
390
404
replaceGenericLayerNormOp<OP_TYPE>(rewriter, lnOp, Y, meanOfX, invStdDev);
391
405
return success ();
392
406
}
393
407
408
+ /*
394
409
LogicalResult generateONNXLayerNormalizationOpONNXCode(
395
410
ConversionPatternRewriter &rewriter, Location loc,
396
411
ONNXLayerNormalizationOp lnOp) {
397
412
return generateGenericLayerNormOpONNXCode(rewriter, loc, lnOp);
398
413
}
414
+ */
399
415
400
416
template <typename OP_TYPE, typename SHAPE_HELPER_TYPE>
401
417
struct GenericLayerNormaOpLowering : public OpConversionPattern <OP_TYPE> {
@@ -667,7 +683,8 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern<OP_TYPE> {
667
683
return generateSIMDCode (rewriter, loc, lnOp, adaptor, shapeHelper, 4 , VL,
668
684
scaleBroadcastKind, biasBroadcastKind, scaleModFactor, biasModFactor);
669
685
}
670
- return generateGenericLayerNormOpONNXCode (rewriter, loc, lnOp);
686
+ return generateGenericLayerNormOpONNXCode (
687
+ rewriter, loc, lnOp, this ->typeConverter );
671
688
}
672
689
673
690
using F1 = std::function<void (int64_t offsetInt, Value offsetVal)>;
0 commit comments