Skip to content

Commit 77c00fa

Browse files
authored
[AMD] Add an option to force async copy overlapping
Set the `TRITON_HIP_ASYNC_COPY_OVERLAP=1` environment variable to enable async copy overlap.
2 parents 5c4b1fb + 18ae32b commit 77c00fa

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
3535
"TRITON_HIP_LOCAL_PREFETCH",
3636
"TRITON_HIP_USE_ASYNC_COPY",
3737
"TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE",
38+
"TRITON_HIP_ASYNC_COPY_OVERLAP",
3839
"TRITON_HIP_ENABLE_F16_ASYNC_PINGPONG",
3940
"TRITON_HIP_USE_BLOCK_PINGPONG",
4041
"TRITON_HIP_USE_IN_THREAD_TRANSPOSE",

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,10 @@ class StreamPipeliner {
122122
public:
123123
StreamPipeliner(scf::ForOp _forOp, int _numStages, int _globalPrefetch,
124124
int _localPrefetch, bool _useAsyncCopy,
125-
bool _useF16BlockPingpong)
125+
bool _useF16BlockPingpong, bool _useAsyncCopyOverlap)
126126
: forOp(_forOp), numStages(_numStages), numBuffers(1),
127127
useAsyncCopy(_useAsyncCopy), useF16BlockPingpong(_useF16BlockPingpong),
128-
schedule(numStages),
128+
useAsyncCopyOverlap(_useAsyncCopyOverlap), schedule(numStages),
129129
axisInfoAnalysis(forOp->getParentOfType<ModuleOp>()) {
130130
int lastStage = numStages - 1;
131131
stages[SCHED_GLOBAL_LOAD] = 0;
@@ -181,6 +181,9 @@ class StreamPipeliner {
181181
// Whether or not we intend to ping-pong.
182182
bool useF16BlockPingpong;
183183

184+
// Whether to move AsyncCopy before AsyncWait.
185+
bool useAsyncCopyOverlap;
186+
184187
// Stage for each SchedType Op
185188
int stages[SCHED_SIZE];
186189
// Cluster for each SchedType Op
@@ -297,6 +300,14 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) {
297300
computeCluster = localLoadCluster;
298301
}
299302

303+
if (useAsyncCopyOverlap) {
304+
globalLoadCluster = 0;
305+
localStoreCluster = 1;
306+
asyncWaitCluster = 2;
307+
localLoadCluster = 3;
308+
computeCluster = 3;
309+
}
310+
300311
// Make assignments
301312
std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> clusterVec;
302313
std::generate(clusterVec.begin(), clusterVec.end(),
@@ -1072,6 +1083,9 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
10721083
// between MXFP4 and FP16.
10731084
bool useF16BlockPingpong =
10741085
triton::tools::getBoolEnv("TRITON_HIP_ENABLE_F16_ASYNC_PINGPONG");
1086+
bool useAsyncCopyOverlap =
1087+
triton::tools::getBoolEnv("TRITON_HIP_ASYNC_COPY_OVERLAP") &
1088+
useAsyncCopy;
10751089
SmallVector<scf::ForOp> loops;
10761090
getOperation()->walk([&](scf::ForOp forOp) {
10771091
labelLoadOpsForTritonDot(forOp);
@@ -1092,7 +1106,7 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
10921106
} else {
10931107
StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages),
10941108
globalPrefetch, localPrefetch, useAsyncCopy,
1095-
useF16BlockPingpong);
1109+
useF16BlockPingpong, useAsyncCopyOverlap);
10961110
(void)sp.pipelineLoop();
10971111
}
10981112
}

0 commit comments

Comments
 (0)