Commit 4e1199b

Merge OpenAI Triton commit 86e7117 (#4217)
This PR changes the Triton base from 6116bfe to 86e7117 (May 12). Pass rate: 97.77%
2 parents 1f4e845 + 05fdcfd commit 4e1199b

33 files changed: +641, -241 lines

README.md  (-40)

@@ -24,46 +24,6 @@ pip install triton
 
 Binary wheels are available for CPython 3.9-3.13.
 
-# Enabling Blackwell Support
-
-The main branch now features support for NVIDIA Blackwell GPUs using 5th
-generation tensor cores. To enable this, you will need two additional steps:
-
-1. Build a pre-release PyTorch from source with CUDA 12.8
-2. Build triton from the latest source
-
-
-First, to build pytorch you need to have CUDA 12.8 installed locally. If not,
-follow the [instructions for your platform](https://developer.nvidia.com/cuda-downloads)
-```bash
-# Clone and checkout pytorch 2.6 release candidate
-git clone https://github.com/pytorch/pytorch
-cd pytorch
-git checkout v2.6.0-rc9
-git submodule sync
-git submodule update --init --recursive -j 8
-
-# Install build dependencies (assumes you already have a system compiler)
-pip install -r requirements.txt
-pip install mkl-static mkl-include wheel
-
-# Build PyTorch (will take a long time)
-export CUDA_HOME=/usr/local/cuda-12.8
-export CUDA_PATH=$CUDA_HOME
-export TORCH_CUDA_ARCH_LIST=Blackwell
-python setup.py develop
-
-# Optional, package build into a wheel to install on other machines.
-python setup.py bdist_wheel
-ls dist # Wheel should be output in this directory
-```
-
-Note that if you use the domain libraries (`torchvision`, `torchtext`,
-`torchaudio`, etc.) these will need to be built from source as well, otherwise
-their custom PyTorch extensions will not work.
-
-Finally, follow the instructions below to install triton from source.
-
 # Install from source
 
 ```shell

cmake/llvm-hash.txt  (+1, -1)

@@ -1 +1 @@
-092b6e73e651469527662443b592f98f442ece72
+3c709802d31b5bc5ed3af8284b40593ff39b9eec

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h  (+1, -2)

@@ -287,8 +287,7 @@ LinearLayout chooseScaledMfmaScaleLayout(
 // 8 elements. This layout is useful for emitting the widest 128-bit global
 // store instructions. Since it closely resembles mfmaLayout, conversion between
 // the two can be done using transferWithinWarp, without involving LDS
-LinearLayout chooseMfmaLikeStoreLayout(AMDMfmaEncodingAttr mfmaLayout,
-                                       ArrayRef<int64_t> shape);
+std::optional<LinearLayout> chooseMfmaLikeStoreLayout(RankedTensorType valType);
 
 } // namespace mlir::triton::gpu
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
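
Note (not part of the diff): `chooseMfmaLikeStoreLayout` now reports unsupported inputs by returning an empty `std::optional` instead of asserting, so callers are expected to check the result. A minimal hypothetical caller sketch — `pickStoreLayout` and `defaultLayout` are invented names, not code from this commit:

```cpp
#include <optional>

#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"

using namespace mlir;
using namespace mlir::triton::gpu;

// Hypothetical helper: prefer the mfma-like store layout when the new overload
// accepts the tensor type (transposed [B]F16 mfma32 on CDNA4), otherwise keep
// whatever layout the caller already computed.
LinearLayout pickStoreLayout(RankedTensorType valType,
                             const LinearLayout &defaultLayout) {
  if (std::optional<LinearLayout> ll = chooseMfmaLikeStoreLayout(valType))
    return *ll;
  return defaultLayout;
}
```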

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td  (-5)

@@ -265,11 +265,6 @@ def TTG_MemDescReshapeOp : TTG_Op<"memdesc_reshape", [Pure,
   }];
 
   let arguments = (ins TTG_MemDescType:$src);
-
-  let arguments = (
-    ins TTG_MemDescType:$src
-  );
-
   let results = (outs TTG_MemDescType:$result);
 
   let assemblyFormat = "$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))";

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h  (+2, -1)

@@ -20,6 +20,7 @@ static const char *kLoopStageAttrName = "loop.stage";
 static const char *kLoopClusterAttrName = "loop.cluster";
 static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
 static const char *kAssignedStageAttrName = "ttg.assigned_stage";
+static const char *kAssignedClusterAttrName = "ttg.assigned_cluster";
 
 //===----------------------------------------------------------------------===//
 // Hoisting Utilities

@@ -106,7 +107,7 @@ Value createScalarAlloc(ImplicitLocOpBuilder &rewriter, Type type,
 Value createBarrierAlloc(scf::ForOp forOp, int numBarriers,
                          int arriveCount = 1);
 // Create an allocation that can hold distance number of tensor shapes.
-Value createAlloc(scf::ForOp forOp, RankedTensorType ty, Location loc,
+Value createAlloc(Operation *insertBefore, RankedTensorType ty, Location loc,
                   gpu::SharedEncodingTrait sharedEnc, unsigned distance);
 
 // Determine if the operation is a TMA load.

lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp  (+6, -6)

@@ -157,22 +157,22 @@ struct AllocateWarpGroups
     }
 
     // Compute the register deficit over the partition warp groups.
-    int registerDeficit = 0;
+    int registerBudget = maxnreg * baseNumWarps * threadsPerWarp;
     for (const WarpGroupInfo &wg : warpGroups) {
       assert(wg.numWarps % 4 == 0);
-      registerDeficit +=
+      registerBudget +=
           (maxnreg - wg.maxRequestedRegs) * wg.numWarps * threadsPerWarp;
     }
-    if (registerDeficit <= 0)
+    if (registerBudget <= 0)
       return;
 
     // Determine the number of extra registers that we can distribute to the
     // default warp group.
-    int leftover =
-        ((baseNumWarps * threadsPerWarp * maxnreg) + registerDeficit) /
-        baseNumWarps / threadsPerWarp;
+    int leftover = registerBudget / (baseNumWarps * threadsPerWarp);
     // Round down to the nearest multiple of 8.
     leftover = leftover / 8 * 8;
+    if (leftover < 24)
+      return; // too few registers
 
     // Generate setmaxnreg in each partition according to its warp group.
     SmallVector<int32_t> maxnregsPerPartition(1 + arr.size());
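
For intuition (this note and the arithmetic below are not part of the diff): with the new formula, the registers left to the default warp group are the total per-thread budget minus what the partitions over-request, divided back across the default group. Plugging in the values from the `steal_from_default` test added below, and assuming 32 threads per warp:

```cpp
// Worked example of the new register redistribution, using the values from the
// steal_from_default test added in this commit: maxnreg = 128, an 8-warp
// default group, and one 8-warp partition requesting 192 registers per thread.
// threadsPerWarp = 32 is an assumption, not stated in the diff.
#include <cstdio>

int main() {
  const int maxnreg = 128, baseNumWarps = 8, threadsPerWarp = 32;
  int registerBudget = maxnreg * baseNumWarps * threadsPerWarp;     // 32768
  // One partition warp group: 8 warps, each requesting 192 registers.
  registerBudget += (maxnreg - 192) * 8 * threadsPerWarp;           // 32768 - 16384
  int leftover = registerBudget / (baseNumWarps * threadsPerWarp);  // 64
  leftover = leftover / 8 * 8;  // already a multiple of 8, and >= 24
  std::printf("default warp group keeps %d registers per thread\n", leftover);
  return 0;  // matches the expected actualRegisters = array<i32: 64, 192>
}
```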

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp  (+2, -2)

@@ -298,8 +298,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
             b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)),
             offset);
       }
-      auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset);
-      vecAddr.setInbounds(true);
+      auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset,
+                           LLVM::GEPNoWrapFlags::inbounds);
       return vecAddr;
     };
 

lib/Conversion/TritonGPUToLLVM/Utility.cpp  (+2, -2)

@@ -398,8 +398,8 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
     smemOffset = b.sub(smemOffset, baseToAllocBaseDist);
   }
   auto ptrTy = smemBase.getType();
-  auto vecAddr = b.gep(ptrTy, elemLlvmTy, smemBase, smemOffset);
-  vecAddr.setInbounds(true);
+  auto vecAddr = b.gep(ptrTy, elemLlvmTy, smemBase, smemOffset,
+                       LLVM::GEPNoWrapFlags::inbounds);
   return vecAddr;
 }
 

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp  (+13, -6)

@@ -1537,10 +1537,17 @@ LinearLayout chooseScaledMfmaScaleLayout(
   return newLL;
 }
 
-LinearLayout chooseMfmaLikeStoreLayout(AMDMfmaEncodingAttr mfmaLayout,
-                                       ArrayRef<int64_t> shape) {
-  assert(shape.size() == 2 && mfmaLayout.getMDim() == 32 &&
-         mfmaLayout.getNDim() == 32 && mfmaLayout.getIsTransposed());
+std::optional<LinearLayout>
+chooseMfmaLikeStoreLayout(RankedTensorType valType) {
+  auto mfmaLayout = cast<AMDMfmaEncodingAttr>(valType.getEncoding());
+
+  // Currently support transposed [B]F16 MFMA32x32 on CDNA4
+  bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32;
+  Type elemType = valType.getElementType();
+  if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) &&
+        mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() &&
+        isMfma32))
+    return {};
 
   MLIRContext *ctx = mfmaLayout.getContext();
   StringAttr kRegister = S("register");

@@ -1565,8 +1572,8 @@ LinearLayout chooseMfmaLikeStoreLayout(AMDMfmaEncodingAttr mfmaLayout,
       identityStandardND(kWarp, mfmaLayout.getWarpsPerCTA(), order);
   LinearLayout ctaLayout = mfma8Layout.transposeOuts(standardOutDims) *
                            warpLayout.transposeOuts(standardOutDims);
-  mfma8Layout =
-      combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
+  mfma8Layout = combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(),
+                                       valType.getShape());
   return mfma8Layout;
 }
 
lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -380,22 +380,22 @@ Value mlir::triton::createBarrierAlloc(scf::ForOp forOp, int numBarriers,
380380
return barrierAlloc;
381381
}
382382

383-
Value mlir::triton::createAlloc(scf::ForOp forOp, RankedTensorType ty,
383+
Value mlir::triton::createAlloc(Operation *insertBefore, RankedTensorType ty,
384384
Location loc,
385385
gpu::SharedEncodingTrait sharedEnc,
386386
unsigned distance) {
387-
OpBuilder builder(forOp);
387+
OpBuilder builder(insertBefore);
388388
Attribute sharedMemorySpace =
389-
ttg::SharedMemorySpaceAttr::get(forOp.getContext());
389+
ttg::SharedMemorySpaceAttr::get(insertBefore->getContext());
390390
SmallVector<int64_t> bufferShape(ty.getShape().begin(), ty.getShape().end());
391391
bufferShape.insert(bufferShape.begin(), distance);
392392
Type memdescType = ttg::MemDescType::get(bufferShape, ty.getElementType(),
393393
sharedEnc, sharedMemorySpace,
394394
/*mutableMemory=*/true);
395395
Value alloc = builder.create<ttg::LocalAllocOp>(loc, memdescType);
396396

397-
builder.setInsertionPointAfter(forOp);
398-
builder.create<ttg::LocalDeallocOp>(forOp.getLoc(), alloc);
397+
builder.setInsertionPointAfter(insertBefore);
398+
builder.create<ttg::LocalDeallocOp>(insertBefore->getLoc(), alloc);
399399
return alloc;
400400
}
401401
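
A usage note (not from the diff): because MLIR op handles convert implicitly to `Operation *`, existing pipeliner call sites that pass the `scf::ForOp` should keep compiling, while the alloc/dealloc pair can now bracket any operation. A hypothetical sketch — `example`, `someOp`, `ty`, `sharedEnc`, and `numStages` are assumed names:

```cpp
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"

using namespace mlir;

// Hypothetical call sites for the generalized createAlloc.
void example(scf::ForOp forOp, Operation *someOp, RankedTensorType ty,
             triton::gpu::SharedEncodingTrait sharedEnc, unsigned numStages) {
  // Old behavior preserved: an scf::ForOp still works, since op handles
  // convert implicitly to Operation *.
  Value fromLoop =
      triton::createAlloc(forOp, ty, forOp.getLoc(), sharedEnc, numStages);
  // New flexibility: the local_alloc is created before `someOp` and the
  // matching local_dealloc right after it.
  Value fromOp =
      triton::createAlloc(someOp, ty, someOp->getLoc(), sharedEnc, numStages);
  (void)fromLoop;
  (void)fromOp;
}
```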

lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp  (+3, -2)

@@ -276,10 +276,11 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
   }
 
   // If it's a shmem operand, it must either be defined outside the loop, or
-  // come from an MemDescSubview op. Only ConvertLayout and Trans ops are
+  // come from an MemDescSubview op. Only ConvertLayout and view ops are
   // allowed in between.
   Value transitiveOperand = operand;
-  while (isa_and_nonnull<ttg::ConvertLayoutOp, ttg::MemDescTransOp>(
+  while (isa_and_nonnull<ttg::ConvertLayoutOp, ttg::MemDescTransOp,
+                         ttg::MemDescReshapeOp>(
              transitiveOperand.getDefiningOp()) ||
          isa<BlockArgument>(transitiveOperand)) {
     auto blockArg = dyn_cast<BlockArgument>(transitiveOperand);

lib/Dialect/TritonGPU/Transforms/Utility.cpp  (+10, -8)

@@ -1241,10 +1241,10 @@ ttg::LocalAllocOp findShmemAlloc(Value operand) {
   // come from an MemDescSubview op. Only ConvertLayout and Trans ops are
   // allowed in between.
   Value transitiveOperand = operand;
-  while (
-      isa_and_nonnull<ttg::ConvertLayoutOp, tt::TransOp, ttg::MemDescTransOp>(
-          transitiveOperand.getDefiningOp()) ||
-      isa<BlockArgument>(transitiveOperand)) {
+  while (isa_and_nonnull<ttg::ConvertLayoutOp, tt::TransOp, ttg::MemDescTransOp,
+                         ttg::MemDescReshapeOp>(
+             transitiveOperand.getDefiningOp()) ||
+         isa<BlockArgument>(transitiveOperand)) {
     if (auto blockArg = dyn_cast<BlockArgument>(transitiveOperand)) {
       assert(isa<scf::ForOp>(blockArg.getOwner()->getParentOp()) &&
              "Block argument must come from a for loop");

@@ -1409,7 +1409,7 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
     }
 
     // Non-subview/trans ops will be replaced by `val`.
-    if (!isa<ttg::MemDescTransOp, ttg::MemDescSubviewOp>(use.getOwner())) {
+    if (!use.getOwner()->hasTrait<OpTrait::MemDescViewTrait>()) {
       operandsToReplace.push_back(&use);
       continue;
     }

@@ -1427,13 +1427,15 @@
                                      oldType.getMemorySpace(), isMutable);
       newVal = builder.create<ttg::MemDescSubviewOp>(
           subview.getLoc(), newDstType, val, subview.getOffsets());
-      newVal.getDefiningOp()->setAttrs(user->getAttrs());
     } else if (auto trans = dyn_cast<ttg::MemDescTransOp>(user)) {
       newVal = builder.create<ttg::MemDescTransOp>(trans.getLoc(), val,
                                                    trans.getOrder());
-      newVal.getDefiningOp()->setAttrs(user->getAttrs());
+    } else if (auto reshape = dyn_cast<ttg::MemDescReshapeOp>(user)) {
+      newVal = builder.create<ttg::MemDescReshapeOp>(reshape.getLoc(),
+                                                     reshape.getType(), val);
     }
-    assert(newVal);
+    assert(newVal && "unhandled memdesc view");
+    newVal.getDefiningOp()->setAttrs(user->getAttrs());
     replaceUsesAndPropagateType(builder, user, newVal);
     opsToDelete.push_back(use.getOwner());
   }

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp  (+2, -2)

@@ -130,7 +130,7 @@ static PartitionScheme getPartitionScheme(scf::ForOp loop) {
   }
   while (!operandViews.empty()) {
     Operation *op = operandViews.pop_back_val();
-    if (!op->hasOneUse() || !isa<MemDescSubviewOp, MemDescTransOp>(op))
+    if (!op->hasOneUse() || !op->hasTrait<OpTrait::MemDescViewTrait>())
      continue;
     mma.operandViews.push_back(op);
     if (Operation *defOp = op->getOperand(0).getDefiningOp())

@@ -669,7 +669,7 @@ findSharedMemorySinkOps(Value value, SmallVectorImpl<Operation *> &sinkOps) {
   for (Operation *user : value.getUsers()) {
     if (isa<ttng::MMAv5OpInterface, LocalLoadOp>(user)) {
       sinkOps.push_back(user);
-    } else if (isa<MemDescTransOp, MemDescSubviewOp>(user)) {
+    } else if (user->hasTrait<OpTrait::MemDescViewTrait>()) {
       if (failed(findSharedMemorySinkOps(user->getResult(0), sinkOps)))
         return failure();
     } else {

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp  (+16, -4)

@@ -208,11 +208,23 @@ LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) {
       continue;
     }
 
-    if (isa<RankedTensorType>(capture.getType())) {
-      return mlir::emitWarning(capture.getLoc(),
-                               "FIXME: capturing tensor values into warp "
-                               "partitions is not supported");
+    // Explicitly pass tensor captures through shared memory.
+    auto tensorTy = dyn_cast<RankedTensorType>(capture.getType());
+    if (tensorTy) {
+      SharedEncodingTrait sharedEnc = getSharedEncoding(tensorTy);
+      ImplicitLocOpBuilder b(capture.getLoc(), wsOp);
+      auto memdescTy = MemDescType::get(
+          tensorTy.getShape(), tensorTy.getElementType(), sharedEnc,
+          SharedMemorySpaceAttr::get(tensorTy.getContext()));
+      auto alloc = b.create<LocalAllocOp>(memdescTy, capture);
+      for (Region *region : wsOp.getPartitionRegions()) {
+        b.setInsertionPointToStart(&region->front());
+        Value value = b.create<LocalLoadOp>(tensorTy, alloc);
+        replaceAllUsesInRegionWith(capture, value, *region);
+      }
+      capture = alloc;
     }
+
     explicitCaptures.push_back(capture);
   }
 

test/Conversion/allocate_warp_groups.mlir  (+19)

@@ -92,3 +92,22 @@ tt.func @setmaxnreg() {
 }
 
 }
+
+// -----
+
+// CHECK: module attributes {ttg.maxnreg = 128 : i32
+module attributes {"ttg.num-warps" = 8 : i32} {
+
+tt.func @steal_from_default() {
+  // CHECK: actualRegisters = array<i32: 64, 192>
+  ttg.warp_specialize() attributes {requestedRegisters = array<i32: 192>}
+  default {
+    ttg.warp_yield
+  }
+  partition0() num_warps(8) {
+    ttg.warp_return
+  } : () -> ()
+  tt.return
+}
+
+}
