diff --git a/src/Conversion/ONNXToKrnl/Tensor/Unique.cpp b/src/Conversion/ONNXToKrnl/Tensor/Unique.cpp index 095f56def4..91c6ab594b 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Unique.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Unique.cpp @@ -160,7 +160,13 @@ struct ONNXUniqueOpLowering : public ConversionPattern { // Insert an allocation and deallocation for the outputs. // Value outputY; - if (axis < 0) { + if (hasStaticShape(uniqueOp.getY().getType())) { + // This is a patch related to https://github.com/onnx/onnx/issues/6133 + MemRefType memrefType = + typeConverter->convertType(uniqueOp.getY().getType()) + .cast(); + outputY = create.mem.alignedAlloc(memrefType); + } else if (axis < 0) { MemRefType memrefType = MemRefType::get({ShapedType::kDynamic}, elementType); outputY = create.mem.alignedAlloc(memrefType, outputYDims); @@ -175,16 +181,37 @@ struct ONNXUniqueOpLowering : public ConversionPattern { Type i64Type = rewriter.getI64Type(); MemRefType memrefType = MemRefType::get({ShapedType::kDynamic}, i64Type); Value emptyMemref = create.mem.alignedAlloc(MemRefType::get({0}, i64Type)); - Value indices = isNoneValue(uniqueOp.getIndices()) - ? emptyMemref - : create.mem.alignedAlloc(memrefType, outputIndexDims); + + Type indicesType = uniqueOp.getIndices().getType(); + Value indices = + isNoneValue(uniqueOp.getIndices()) + ? emptyMemref + : (hasStaticShape(indicesType) + ? create.mem.alignedAlloc( + typeConverter->convertType(indicesType) + .cast()) + : create.mem.alignedAlloc(memrefType, outputIndexDims)); + + Type inverseIndicesType = uniqueOp.getInverseIndices().getType(); Value inverseIndices = isNoneValue(uniqueOp.getInverseIndices()) ? emptyMemref - : create.mem.alignedAlloc(memrefType, outputInverseIndexDims); - Value counts = isNoneValue(uniqueOp.getCounts()) - ? emptyMemref - : create.mem.alignedAlloc(memrefType, outputIndexDims); + : (hasStaticShape(inverseIndicesType) + ? create.mem.alignedAlloc( + typeConverter->convertType(inverseIndicesType) + .cast()) + : create.mem.alignedAlloc( + memrefType, outputInverseIndexDims)); + + Type countsType = uniqueOp.getCounts().getType(); + Value counts = + isNoneValue(uniqueOp.getCounts()) + ? emptyMemref + : (hasStaticShape(countsType) + ? create.mem.alignedAlloc( + typeConverter->convertType(countsType) + .cast()) + : create.mem.alignedAlloc(memrefType, outputIndexDims)); // // Emit a Unique call to get the outputs //