Skip to content

Commit 22a9821

Browse files
Merge OpenAI Triton commit 6116bfe (#4189)
This PR changes the Triton base from c6ee626 to 6116bfe (May 12). Pass rate: 97.77%
2 parents b665d5a + 13827be commit 22a9821

File tree

109 files changed

+2141
-1056
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

109 files changed

+2141
-1056
lines changed

.pre-commit-config.yaml

+6
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ repos:
4040
hooks:
4141
- id: clang-format
4242

43+
- repo: https://github.com/pre-commit/mirrors-mypy
44+
rev: "v1.15.0"
45+
hooks:
46+
- id: mypy
47+
pass_filenames: false
48+
4349
# Expand YAML anchors in files used by github workflows, because github can't
4450
# do this itself. This lets us use anchors, which avoids code duplication.
4551
- repo: local

bench/triton_bench/matmul_ogs_details/_finalize_scatter.py

-150
This file was deleted.

bench/triton_bench/matmul_ogs_details/_finalize_split_k.py

-38
This file was deleted.

bin/RegisterTritonDialects.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ void registerTestTritonAMDGPURangeAnalysis();
5959

6060
inline void registerTritonDialects(mlir::DialectRegistry &registry) {
6161
mlir::registerAllPasses();
62-
mlir::registerTritonPasses();
62+
mlir::triton::registerTritonPasses();
6363
mlir::triton::gpu::registerTritonGPUPasses();
6464
mlir::registerTritonNvidiaGPUPasses();
6565
mlir::test::intel::registerTestAxisInfoPass();

include/triton/Analysis/Allocation.h

-4
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77
#include "llvm/ADT/SetVector.h"
88
#include "llvm/Support/raw_ostream.h"
99

10-
#include "triton/Dialect/Triton/IR/Dialect.h"
11-
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
12-
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
13-
#include <atomic>
1410
#include <limits>
1511

1612
namespace mlir {

include/triton/Analysis/AxisInfo.h

-4
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,8 @@
66

77
#include "mlir/Support/LLVM.h"
88
#include "triton/Analysis/Utility.h"
9-
#include "triton/Dialect/Triton/IR/Dialect.h"
10-
#include "triton/Dialect/Triton/IR/Utility.h"
11-
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
129

1310
#include <optional>
14-
#include <type_traits>
1511

1612
namespace mlir::triton {
1713

include/triton/Analysis/Membar.h

-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#define TRITON_ANALYSIS_MEMBAR_H
33

44
#include "Allocation.h"
5-
#include "llvm/ADT/SmallPtrSet.h"
65

76
#include <set>
87

include/triton/Conversion/TritonGPUToLLVM/Utility.h

+43
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,49 @@ SmallVector<Value> unpackLLVector(Location loc, Value llvmVec,
729729

730730
Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter);
731731

732+
inline std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp) {
733+
switch (atomicOp) {
734+
case RMWOp::AND:
735+
return LLVM::AtomicBinOp::_and;
736+
case RMWOp::OR:
737+
return LLVM::AtomicBinOp::_or;
738+
case RMWOp::XOR:
739+
return LLVM::AtomicBinOp::_xor;
740+
case RMWOp::ADD:
741+
return LLVM::AtomicBinOp::add;
742+
case RMWOp::FADD:
743+
return LLVM::AtomicBinOp::fadd;
744+
case RMWOp::MAX:
745+
return LLVM::AtomicBinOp::max;
746+
case RMWOp::MIN:
747+
return LLVM::AtomicBinOp::min;
748+
case RMWOp::UMAX:
749+
return LLVM::AtomicBinOp::umax;
750+
case RMWOp::UMIN:
751+
return LLVM::AtomicBinOp::umin;
752+
case RMWOp::XCHG:
753+
return LLVM::AtomicBinOp::xchg;
754+
default:
755+
return {};
756+
}
757+
}
758+
759+
inline std::optional<LLVM::AtomicOrdering>
760+
getMemoryOrdering(MemSemantic memOrdering) {
761+
switch (memOrdering) {
762+
case MemSemantic::RELAXED:
763+
return LLVM::AtomicOrdering::monotonic;
764+
case MemSemantic::ACQUIRE:
765+
return LLVM::AtomicOrdering::acquire;
766+
case MemSemantic::RELEASE:
767+
return LLVM::AtomicOrdering::release;
768+
case MemSemantic::ACQUIRE_RELEASE:
769+
return LLVM::AtomicOrdering::acq_rel;
770+
default:
771+
return {};
772+
}
773+
}
774+
732775
inline bool
733776
isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
734777
ArrayRef<int64_t> allocShape,

include/triton/Dialect/Triton/IR/TritonDialect.td

+2-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def Triton_Dialect : Dialect {
4545

4646
let discardableAttrs = (ins
4747
"::mlir::IntegerAttr":$num_stages,
48-
"::mlir::IntegerAttr":$latency
48+
"::mlir::IntegerAttr":$latency,
49+
"::mlir::IntegerAttr":$self_latency
4950
);
5051

5152
let hasConstantMaterializer = 1;

include/triton/Dialect/Triton/Transforms/Passes.h

+4-8
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,14 @@
66
namespace mlir {
77
namespace triton {
88

9-
std::unique_ptr<Pass> createCombineOpsPass();
10-
11-
std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
12-
std::unique_ptr<Pass> createReorderBroadcastPass();
13-
std::unique_ptr<Pass> createRewriteTensorPointerPass();
14-
std::unique_ptr<Pass> createLoopUnrollPass();
15-
16-
} // namespace triton
9+
// Generate the pass class declarations.
10+
#define GEN_PASS_DECL
11+
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
1712

1813
#define GEN_PASS_REGISTRATION
1914
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
2015

16+
} // namespace triton
2117
} // namespace mlir
2218

2319
#endif

include/triton/Dialect/Triton/Transforms/Passes.td

+3-7
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
1919
=> dot(x,y,splat(0))`
2020
}];
2121

22-
let constructor = "mlir::triton::createCombineOpsPass()";
23-
2422
let dependentDialects = ["mlir::arith::ArithDialect"];
2523
}
2624

@@ -33,7 +31,7 @@ def TritonReorderBroadcast : Pass</*cli-arg*/"triton-reorder-broadcast", /*Op*/"
3331
In the event of a match, the broadcast (or splat) operation is delayed
3432
and performed after the ElementWise operation.
3533
}];
36-
let constructor = "mlir::triton::createReorderBroadcastPass()";
34+
3735
let dependentDialects = ["mlir::triton::TritonDialect"];
3836
}
3937

@@ -45,8 +43,6 @@ def TritonRewriteTensorPointer : Pass</*cli-arg*/"triton-rewrite-tensor-pointer"
4543
the pointer/mask/other for each load/store.
4644
}];
4745

48-
let constructor = "mlir::triton::createRewriteTensorPointerPass()";
49-
5046
let dependentDialects = ["mlir::triton::TritonDialect"];
5147
}
5248

@@ -56,7 +52,7 @@ def TritonLoopUnroll : Pass</*cli-arg*/"triton-loop-unroll", /*Op*/"mlir::Module
5652
The pass unrolls a scf loop with tt.loop_unroll_factor attribute. The attribute specialises how many iterations
5753
the loop should be unrolled.
5854
}];
59-
let constructor = "mlir::triton::createLoopUnrollPass()";
55+
6056
let dependentDialects = ["mlir::triton::TritonDialect"];
6157
}
6258

@@ -68,7 +64,7 @@ def TritonLoopInvariantCodeMotion : Pass</*cli-arg*/"triton-licm", /*Op*/"mlir::
6864
generates a trip-count check. For scf.while loops, it clones the condition
6965
from the before body.
7066
}];
71-
let constructor = "mlir::triton::createLoopInvariantCodeMotionPass()";
67+
7268
let dependentDialects = ["mlir::triton::TritonDialect"];
7369
}
7470

0 commit comments

Comments
 (0)