Commit 85f473a
Merge OpenAI Triton commit 99b5e29 (#4219)
This PR changes the Triton base from 86e7117 to 99b5e29 (May 13). Pass rate: 97.77% -> 97.25% (#4221, #4222).
2 parents fff1773 + 7fa8493 commit 85f473a

24 files changed: +1153 -1433 lines changed

include/triton/Dialect/Triton/Transforms/ArithTypeConversion.h

+18

@@ -0,0 +1,18 @@
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_ARITH_TYPE_CONVERSION_H_
#define TRITON_DIALECT_TRITON_TRANSFORMS_ARITH_TYPE_CONVERSION_H_
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::triton {

/**
 * @brief Provides helper patterns for converting arith operations using a type
 * converter.
 *
 * Note that at the time of writing this isn't provided in upstream MLIR.
 */
void populateArithTypeConversions(const TypeConverter &converter,
                                  RewritePatternSet &patterns);

} // namespace mlir::triton

#endif // TRITON_DIALECT_TRITON_TRANSFORMS_ARITH_TYPE_CONVERSION_H_
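
For orientation, here is a minimal sketch of how a caller might drive these patterns. It is not part of this commit; the function name, the identity fallback conversion rule, and the legality policy are assumptions made purely for illustration:

// Illustrative only: a hypothetical driver for populateArithTypeConversions.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/Transforms/ArithTypeConversion.h"

#include <utility>

using namespace mlir;

static LogicalResult runExampleConversion(ModuleOp module) {
  MLIRContext *ctx = module.getContext();

  // The conversion rules are whatever the calling pass needs; an identity
  // fallback is used here only as a placeholder.
  TypeConverter converter;
  converter.addConversion([](Type type) { return type; });

  RewritePatternSet patterns(ctx);
  triton::populateArithTypeConversions(converter, patterns);

  // Ops whose operand and result types are already legal are left alone.
  ConversionTarget target(*ctx);
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *op) { return converter.isLegal(op); });

  return applyPartialConversion(module, target, std::move(patterns));
}
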
include/triton/Dialect/Triton/Transforms/FunctionTypeConversion.h

+19

@@ -0,0 +1,19 @@
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_FUNCTION_TYPE_CONVERSION_H_
#define TRITON_DIALECT_TRITON_TRANSFORMS_FUNCTION_TYPE_CONVERSION_H_
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::triton {

/**
 * @brief Provides helper patterns for converting Triton function operations
 * using a type converter.
 *
 * Note that we cannot use upstream passes for this because they are unaware of
 * tt.call and tt.return.
 */
void populateFunctionTypeConversions(const TypeConverter &converter,
                                     RewritePatternSet &patterns);

} // namespace mlir::triton

#endif // TRITON_DIALECT_TRITON_TRANSFORMS_FUNCTION_TYPE_CONVERSION_H_
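
A companion sketch, again illustrative rather than taken from this commit: because the upstream utilities are unaware of tt.call and tt.return, a caller would typically register both helper pattern sets together and mark tt.func, tt.call, and tt.return dynamically legal against the converter. The function name and the legality policy below are assumptions:

// Illustrative only: wiring both helpers into one conversion setup.
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/ArithTypeConversion.h"
#include "triton/Dialect/Triton/Transforms/FunctionTypeConversion.h"

using namespace mlir;

static void buildExampleConversion(const TypeConverter &converter,
                                   RewritePatternSet &patterns,
                                   ConversionTarget &target) {
  triton::populateArithTypeConversions(converter, patterns);
  triton::populateFunctionTypeConversions(converter, patterns);

  // tt.func is legal once its signature uses only converted types;
  // tt.call and tt.return are legal once their operand/result types are.
  target.addDynamicallyLegalOp<triton::FuncOp>([&](triton::FuncOp op) {
    return converter.isSignatureLegal(op.getFunctionType());
  });
  target.addDynamicallyLegalOp<triton::CallOp, triton::ReturnOp>(
      [&](Operation *op) { return converter.isLegal(op); });
}
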

include/triton/Dialect/Triton/Transforms/Passes.td

+11
@@ -46,6 +46,17 @@ def TritonRewriteTensorPointer : Pass</*cli-arg*/"triton-rewrite-tensor-pointer"
   let dependentDialects = ["mlir::triton::TritonDialect"];
 }
 
+def TritonRewriteTensorDescriptorToPointer : Pass</*cli-arg*/"triton-rewrite-tensor-descriptor-to-pointer", /*Op*/"mlir::ModuleOp"> {
+  let summary = "Rewrite load/stores of tensor descriptors into pointer load/stores";
+  let description = [{
+    This pass rewrites every load/store that goes through a `tt.make_tensor_descriptor` into pointer semantics.
+    After this pass, `tt.make_tensor_descriptor` disappears; the pass instead generates the logic that computes the
+    pointer/mask/other values for each load/store.
+  }];
+
+  let dependentDialects = ["mlir::triton::TritonDialect"];
+}
+
 def TritonLoopUnroll : Pass</*cli-arg*/"triton-loop-unroll", /*Op*/"mlir::ModuleOp"> {
   let summary = "Loop unroller";
   let description = [{
include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

+34 -26

@@ -72,32 +72,40 @@ def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> {
 //
 // WarpGroupDot Op
 //
-def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
-                                                     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-                                                     DeclareOpInterfaceMethods<DotOpInterface>,
-                                                     TypesMatchWith<"result's type matches accumulator's type",
-                                                                    "d", "c", "$_self">]> {
-  let summary = "warp group dot";
-
-  let description = [{
-    $d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp
-  }];
-
-  let arguments = (ins TTG_TensorOrMemDesc:$a,
-                       TTG_TensorOrMemDesc:$b,
-                       TT_FpIntTensor:$c,
-                       Optional<I1>:$useC,
-                       DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
-                       DefaultValuedAttr<I32Attr, "0">:$maxNumImpreciseAcc,
-                       DefaultValuedAttr<BoolAttr, "false">:$isAsync);
-
-  let results = (outs TT_FpIntTensor:$d);
-
-  let assemblyFormat = "$a`,` $b`,` $c (`,` $useC^)? attr-dict `:` type($a) `*` type($b) `->` type($d)";
-
-  let extraClassDeclaration = [{
-    bool needsPartialAccumulator();
-  }];
+def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [
+  DeclareOpInterfaceMethods<InferTypeOpInterface>,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  DeclareOpInterfaceMethods<DotOpInterface>,
+  TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self">
+]> {
+  let summary = "warp group dot";
+
+  let description = [{
+    $d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp
+  }];
+
+  let arguments = (ins
+    TTG_TensorOrMemDesc:$a,
+    TTG_TensorOrMemDesc:$b,
+    TT_FpIntTensor:$c,
+    Optional<I1>:$useC,
+    DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
+    DefaultValuedAttr<I32Attr, "0">:$maxNumImpreciseAcc,
+    DefaultValuedAttr<BoolAttr, "false">:$isAsync
+  );
+
+  let results = (outs TT_FpIntTensor:$d);
+
+  let assemblyFormat = [{
+    $a`,` $b`,` $c (`,` $useC^)? attr-dict
+    `:` type($a) `*` type($b) `->` type($d)
+  }];
+
+  let extraClassDeclaration = [{
+    bool needsPartialAccumulator();
+  }];
+
+  let hasVerifier = 1;
 }
 
 def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
lib/Dialect/Triton/Transforms/ArithTypeConversion.cpp

+50

@@ -0,0 +1,50 @@
#include "triton/Dialect/Triton/Transforms/ArithTypeConversion.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"

namespace {

struct RewriteArithSelectOp : mlir::OpConversionPattern<mlir::arith::SelectOp> {
  using mlir::OpConversionPattern<mlir::arith::SelectOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::arith::SelectOp op, OneToNOpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const {
    // Note we're replacing the select op with an if op because we are
    // converting one value into many values.
    auto newIf = rewriter.create<mlir::scf::IfOp>(
        op.getLoc(), mlir::TypeRange(adaptor.getTrueValue()), op.getCondition(),
        true);
    // We set the attributes from the op in case the op has any additional
    // attributes
    newIf->setAttrs(op->getAttrs());

    {
      mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(newIf.thenBlock());
      rewriter.create<mlir::scf::YieldOp>(op->getLoc(), adaptor.getTrueValue());
      rewriter.setInsertionPointToStart(newIf.elseBlock());
      rewriter.create<mlir::scf::YieldOp>(op->getLoc(),
                                          adaptor.getFalseValue());
    }

    // Replace the old operation results
    rewriter.replaceOpWithMultiple(op, {newIf->getResults()});

    return mlir::success();
  }
};

} // namespace
namespace mlir::triton {

void populateArithTypeConversions(const TypeConverter &converter,
                                  RewritePatternSet &patterns) {
  patterns.add<RewriteArithSelectOp>(converter, patterns.getContext());
}

} // namespace mlir::triton
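
To make the 1:N motivation in the comment above concrete, here is an illustrative conversion rule (not from this commit; the function name and the i128-to-two-i64 split are made up) under which arith.select could not simply be recreated with converted operands, because each converted value becomes two values:

// Illustrative only: a hypothetical 1:N rule that splits one value into two.
// With a rule like this, arith.select (a single result) cannot express the
// converted form directly, which is why the pattern above rebuilds it as
// scf.if, whose then/else regions can yield any number of values.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"

#include <optional>

using namespace mlir;

static void addExampleSplitRule(TypeConverter &converter) {
  converter.addConversion(
      [](IntegerType type, SmallVectorImpl<Type> &out)
          -> std::optional<LogicalResult> {
        if (type.getWidth() != 128)
          return std::nullopt; // not handled here; fall through to other rules
        // Made-up lowering: represent an i128 as a pair of i64 halves.
        auto i64 = IntegerType::get(type.getContext(), 64);
        out.push_back(i64);
        out.push_back(i64);
        return success();
      });
}
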

lib/Dialect/Triton/Transforms/CMakeLists.txt

+3
@@ -8,6 +8,9 @@ add_triton_library(TritonTransforms
   LoopUnroll.cpp
   ReorderBroadcast.cpp
   RewriteTensorPointer.cpp
+  RewriteTensorDescriptorToPointer.cpp
+  ArithTypeConversion.cpp
+  FunctionTypeConversion.cpp
 
   DEPENDS
   TritonTransformsIncGen
lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp

+86

@@ -0,0 +1,86 @@
#include "triton/Dialect/Triton/Transforms/FunctionTypeConversion.h"

#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

#include <cstdlib>

namespace mlir::triton {

namespace {

SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
  SmallVector<Value> ret;
  for (const auto &vs : values) {
    llvm::append_range(ret, vs);
  }
  return ret;
}

struct CallOpConversion : public OpConversionPattern<CallOp> {
  using OpConversionPattern<CallOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<std::size_t> resultReplacementGrouping;
    llvm::SmallVector<Type> convertedResults;

    for (auto type : callOp->getResultTypes()) {
      const auto oldNumFlattenedResults = convertedResults.size();
      if (failed(getTypeConverter()->convertTypes(type, convertedResults))) {
        return failure();
      }
      resultReplacementGrouping.push_back(convertedResults.size() -
                                          oldNumFlattenedResults);
    }

    auto newCallOp = rewriter.create<CallOp>(
        callOp->getLoc(), callOp.getCallee(), convertedResults,
        flattenValues(adaptor.getOperands()));
    // Preserve any additional attributes that may have been set on the op
    newCallOp->setAttrs(callOp->getAttrs());

    SmallVector<ValueRange> replacements;
    std::size_t offset = 0;
    for (auto groupSize : resultReplacementGrouping) {
      replacements.push_back(newCallOp->getResults().slice(offset, groupSize));
      offset += groupSize;
    }

    rewriter.replaceOpWithMultiple(callOp, replacements);
    return success();
  }
};

struct ReturnOpConversion : public OpConversionPattern<ReturnOp> {
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp returnOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto newReturnOp = rewriter.create<ReturnOp>(
        returnOp->getLoc(), flattenValues(adaptor.getOperands()));
    // Preserve any additional attributes that may have been set on the op
    newReturnOp->setAttrs(returnOp->getAttrs());

    rewriter.replaceOp(returnOp, newReturnOp);
    return success();
  }
};

} // namespace

void populateFunctionTypeConversions(const TypeConverter &converter,
                                     RewritePatternSet &patterns) {
  mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::triton::FuncOp>(
      patterns, converter);
  patterns.add<CallOpConversion, ReturnOpConversion>(converter,
                                                     patterns.getContext());
}

} // namespace mlir::triton
