Skip to content

Merge OpenAI Triton commit 99b5e29 #4219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/triton/Dialect/Triton/Transforms/ArithTypeConversion.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_ARITH_TYPE_CONVERSION_H_
#define TRITON_DIALECT_TRITON_TRANSFORMS_ARITH_TYPE_CONVERSION_H_
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::triton {

/**
* @brief Provides helper patterns for converting arith operations using a type
* converter.
*
* Note: as of the time of writing, this isn't provided in upstream MLIR.
*
* @param converter Type converter the added conversion patterns are built with.
* @param patterns Pattern set that the conversion patterns are added to.
*/
void populateArithTypeConversions(const TypeConverter &converter,
RewritePatternSet &patterns);

} // namespace mlir::triton

#endif // TRITON_DIALECT_TRITON_TRANSFORMS_ARITH_TYPE_CONVERSION_H_
19 changes: 19 additions & 0 deletions include/triton/Dialect/Triton/Transforms/FunctionTypeConversion.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_FUNCTION_TYPE_CONVERSION_H_
#define TRITON_DIALECT_TRITON_TRANSFORMS_FUNCTION_TYPE_CONVERSION_H_
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::triton {

/**
* @brief Provides helper patterns for converting triton function operations
* using a type converter.
*
* Note: we cannot use upstream passes for this because they are unaware of
* tt.call and tt.return.
*
* @param converter Type converter the added conversion patterns are built with.
* @param patterns Pattern set that the conversion patterns are added to.
*/
void populateFunctionTypeConversions(const TypeConverter &converter,
RewritePatternSet &patterns);

} // namespace mlir::triton

#endif // TRITON_DIALECT_TRITON_TRANSFORMS_FUNCTION_TYPE_CONVERSION_H_
11 changes: 11 additions & 0 deletions include/triton/Dialect/Triton/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ def TritonRewriteTensorPointer : Pass</*cli-arg*/"triton-rewrite-tensor-pointer"
let dependentDialects = ["mlir::triton::TritonDialect"];
}

// Pass declaration for rewriting tensor-descriptor based loads/stores into
// pointer based loads/stores. The pass is anchored on mlir::ModuleOp and is
// registered under the command-line argument
// "triton-rewrite-tensor-descriptor-to-pointer".
def TritonRewriteTensorDescriptorToPointer : Pass</*cli-arg*/"triton-rewrite-tensor-descriptor-to-pointer", /*Op*/"mlir::ModuleOp"> {
let summary = "Rewrite load/stores of tensor descriptors into pointer load/stores";
let description = [{
This pass rewrites all load/store semantics initiated by a `tt.make_tensor_descriptor` into pointer semantics. After
this pass, `tt.make_tensor_descriptor` will disappear, and it generates logics to compute the pointer/mask/other
for each load/store.
}];

// Dialects declared as dependencies of this pass.
let dependentDialects = ["mlir::triton::TritonDialect"];
}

def TritonLoopUnroll : Pass</*cli-arg*/"triton-loop-unroll", /*Op*/"mlir::ModuleOp"> {
let summary = "Loop unroller";
let description = [{
Expand Down
60 changes: 34 additions & 26 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -72,32 +72,40 @@ def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> {
//
// WarpGroupDot Op
//
def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<DotOpInterface>,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "warp group dot";

let description = [{
$d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp
}];

let arguments = (ins TTG_TensorOrMemDesc:$a,
TTG_TensorOrMemDesc:$b,
TT_FpIntTensor:$c,
Optional<I1>:$useC,
DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
DefaultValuedAttr<I32Attr, "0">:$maxNumImpreciseAcc,
DefaultValuedAttr<BoolAttr, "false">:$isAsync);

let results = (outs TT_FpIntTensor:$d);

let assemblyFormat = "$a`,` $b`,` $c (`,` $useC^)? attr-dict `:` type($a) `*` type($b) `->` type($d)";

let extraClassDeclaration = [{
bool needsPartialAccumulator();
}];
// Op definition with mnemonic "warp_group_dot".
// The op declares the methods of InferTypeOpInterface,
// MemoryEffectsOpInterface and DotOpInterface, and constrains the type of the
// result $d to be the same as the type of the accumulator operand $c.
def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [
DeclareOpInterfaceMethods<InferTypeOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<DotOpInterface>,
TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self">
]> {
let summary = "warp group dot";

let description = [{
$d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp
}];

// Operands: $a and $b are tensors or memory descriptors, $c is the
// accumulator tensor and $useC is an optional i1 value.
// Attributes: inputPrecision defaults to IEEE, maxNumImpreciseAcc defaults to
// 0 and isAsync defaults to false.
let arguments = (ins
TTG_TensorOrMemDesc:$a,
TTG_TensorOrMemDesc:$b,
TT_FpIntTensor:$c,
Optional<I1>:$useC,
DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
DefaultValuedAttr<I32Attr, "0">:$maxNumImpreciseAcc,
DefaultValuedAttr<BoolAttr, "false">:$isAsync
);

let results = (outs TT_FpIntTensor:$d);

// Textual form: the three operands, the optional $useC operand, the
// attribute dictionary, then the operand types of $a and $b and the result
// type.
let assemblyFormat = [{
$a`,` $b`,` $c (`,` $useC^)? attr-dict
`:` type($a) `*` type($b) `->` type($d)
}];

// Extra C++ member declared on the op class; its definition is not part of
// this file.
let extraClassDeclaration = [{
bool needsPartialAccumulator();
}];

// Enables a custom verifier for this op; the verifier body is not part of
// this file.
let hasVerifier = 1;
}

def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
Expand Down
50 changes: 50 additions & 0 deletions lib/Dialect/Triton/Transforms/ArithTypeConversion.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include "triton/Dialect/Triton/Transforms/ArithTypeConversion.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"

namespace {

/// Conversion pattern that rewrites `arith.select` into `scf.if`.
///
/// `arith.select` produces a single result, so it cannot carry a value that
/// the type converter expands into several values. The replacement `scf.if`
/// yields the converted true values from its "then" block and the converted
/// false values from its "else" block, and its results replace the select.
struct RewriteArithSelectOp : mlir::OpConversionPattern<mlir::arith::SelectOp> {
  using mlir::OpConversionPattern<mlir::arith::SelectOp>::OpConversionPattern;

  /// Replaces `op` with an `scf.if` returning the converted operand values.
  ///
  /// @param op The select operation being converted.
  /// @param adaptor 1:N adaptor exposing the converted true/false values.
  /// @param rewriter Rewriter used to create the new ops and replace `op`.
  /// @return success once `op` has been replaced; a match failure when the
  ///         condition is not a scalar i1.
  mlir::LogicalResult
  matchAndRewrite(mlir::arith::SelectOp op, OneToNOpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // scf.if only accepts a scalar i1 condition, whereas arith.select also
    // allows an element-wise mask. Bail out cleanly instead of building an
    // scf.if that would fail verification.
    if (!op.getCondition().getType().isSignlessInteger(1)) {
      return rewriter.notifyMatchFailure(
          op, "expected a scalar i1 condition to rewrite select as scf.if");
    }

    // Note we're replacing the select op with an if op because we are
    // converting one value into many values.
    auto newIf = rewriter.create<mlir::scf::IfOp>(
        op.getLoc(), mlir::TypeRange(adaptor.getTrueValue()), op.getCondition(),
        true);
    // We set the attributes from the op in case the op has any additional
    // attributes
    newIf->setAttrs(op->getAttrs());

    {
      // The guard restores the rewriter's insertion point when this scope
      // ends, after the two yield terminators have been created.
      mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(newIf.thenBlock());
      rewriter.create<mlir::scf::YieldOp>(op->getLoc(), adaptor.getTrueValue());
      rewriter.setInsertionPointToStart(newIf.elseBlock());
      rewriter.create<mlir::scf::YieldOp>(op->getLoc(),
                                          adaptor.getFalseValue());
    }

    // Replace the old operation results
    rewriter.replaceOpWithMultiple(op, {newIf->getResults()});

    return mlir::success();
  }
};

} // namespace
namespace mlir::triton {

void populateArithTypeConversions(const TypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<RewriteArithSelectOp>(converter, patterns.getContext());
}

} // namespace mlir::triton
3 changes: 3 additions & 0 deletions lib/Dialect/Triton/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ add_triton_library(TritonTransforms
LoopUnroll.cpp
ReorderBroadcast.cpp
RewriteTensorPointer.cpp
RewriteTensorDescriptorToPointer.cpp
ArithTypeConversion.cpp
FunctionTypeConversion.cpp

DEPENDS
TritonTransformsIncGen
Expand Down
86 changes: 86 additions & 0 deletions lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#include "triton/Dialect/Triton/Transforms/FunctionTypeConversion.h"

#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

#include <cstdlib>

namespace mlir::triton {

namespace {

/// Concatenates the given value ranges into one flat vector. The ranges, and
/// the values inside each range, keep their original order.
SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
  SmallVector<Value> flattened;
  for (ValueRange group : values)
    flattened.append(group.begin(), group.end());
  return flattened;
}

/// Conversion pattern for the triton call op: rebuilds the call with the
/// converted (flattened) operands and the converted result types.
struct CallOpConversion : public OpConversionPattern<CallOp> {
  using OpConversionPattern<CallOp>::OpConversionPattern;

  /// Creates the replacement call and maps every original result onto the
  /// group of new results its type was converted into. Fails when a result
  /// type cannot be converted.
  LogicalResult
  matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert each original result type, remembering how many converted types
    // it contributed so the new results can be regrouped afterwards.
    llvm::SmallVector<Type> flatResultTypes;
    llvm::SmallVector<std::size_t> groupSizes;
    for (Type originalType : callOp->getResultTypes()) {
      const std::size_t sizeBefore = flatResultTypes.size();
      if (failed(
              getTypeConverter()->convertTypes(originalType, flatResultTypes)))
        return failure();
      groupSizes.push_back(flatResultTypes.size() - sizeBefore);
    }

    SmallVector<Value> flatOperands = flattenValues(adaptor.getOperands());
    auto newCallOp =
        rewriter.create<CallOp>(callOp->getLoc(), callOp.getCallee(),
                                flatResultTypes, flatOperands);
    // Carry over every attribute that was present on the original call.
    newCallOp->setAttrs(callOp->getAttrs());

    // Cut the flat result list back into one range per original result.
    SmallVector<ValueRange> replacements;
    auto newResults = newCallOp->getResults();
    std::size_t begin = 0;
    for (std::size_t size : groupSizes) {
      replacements.push_back(newResults.slice(begin, size));
      begin += size;
    }

    rewriter.replaceOpWithMultiple(callOp, replacements);
    return success();
  }
};

/// Conversion pattern for the triton return op: rebuilds the return with the
/// converted operands flattened into a single operand list.
struct ReturnOpConversion : public OpConversionPattern<ReturnOp> {
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  /// Creates the replacement return op, copies the original attributes onto
  /// it and replaces `returnOp` with it.
  LogicalResult
  matchAndRewrite(ReturnOp returnOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> flatOperands = flattenValues(adaptor.getOperands());
    auto replacement =
        rewriter.create<ReturnOp>(returnOp->getLoc(), flatOperands);
    // Carry over every attribute that was present on the original return.
    replacement->setAttrs(returnOp->getAttrs());

    rewriter.replaceOp(returnOp, replacement);
    return success();
  }
};

} // namespace

void populateFunctionTypeConversions(const TypeConverter &converter,
RewritePatternSet &patterns) {
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::triton::FuncOp>(
patterns, converter);
patterns.add<CallOpConversion, ReturnOpConversion>(converter,
patterns.getContext());
}

} // namespace mlir::triton
Loading