@@ -122,10 +122,10 @@ class StreamPipeliner {
122
122
public:
123
123
StreamPipeliner (scf::ForOp _forOp, int _numStages, int _globalPrefetch,
124
124
int _localPrefetch, bool _useAsyncCopy,
125
- bool _useF16BlockPingpong)
125
+ bool _useF16BlockPingpong, bool _useAsyncCopyOverlap )
126
126
: forOp(_forOp), numStages(_numStages), numBuffers(1 ),
127
127
useAsyncCopy (_useAsyncCopy), useF16BlockPingpong(_useF16BlockPingpong),
128
- schedule(numStages),
128
+ useAsyncCopyOverlap(_useAsyncCopyOverlap), schedule(numStages),
129
129
axisInfoAnalysis(forOp->getParentOfType<ModuleOp>()) {
130
130
int lastStage = numStages - 1 ;
131
131
stages[SCHED_GLOBAL_LOAD] = 0 ;
@@ -181,6 +181,9 @@ class StreamPipeliner {
181
181
// Whether or not we are intend to ping-pong.
182
182
bool useF16BlockPingpong;
183
183
184
+ // Move AsyncCopy before AsyncWait.
185
+ bool useAsyncCopyOverlap;
186
+
184
187
// Stage for each SchedType Op
185
188
int stages[SCHED_SIZE];
186
189
// Cluster for each SchedType Op
@@ -297,6 +300,14 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) {
297
300
computeCluster = localLoadCluster;
298
301
}
299
302
303
+ if (useAsyncCopyOverlap) {
304
+ globalLoadCluster = 0 ;
305
+ localStoreCluster = 1 ;
306
+ asyncWaitCluster = 2 ;
307
+ localLoadCluster = 3 ;
308
+ computeCluster = 3 ;
309
+ }
310
+
300
311
// Make assignments
301
312
std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> clusterVec;
302
313
std::generate (clusterVec.begin (), clusterVec.end (),
@@ -1072,6 +1083,9 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
1072
1083
// between MXFP4 and FP16.
1073
1084
bool useF16BlockPingpong =
1074
1085
triton::tools::getBoolEnv (" TRITON_HIP_ENABLE_F16_ASYNC_PINGPONG" );
1086
+ bool useAsyncCopyOverlap =
1087
+ triton::tools::getBoolEnv (" TRITON_HIP_ASYNC_COPY_OVERLAP" ) &
1088
+ useAsyncCopy;
1075
1089
SmallVector<scf::ForOp> loops;
1076
1090
getOperation ()->walk ([&](scf::ForOp forOp) {
1077
1091
labelLoadOpsForTritonDot (forOp);
@@ -1092,7 +1106,7 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
1092
1106
} else {
1093
1107
StreamPipeliner sp (forOp, tt::getNumStagesOrDefault (forOp, numStages),
1094
1108
globalPrefetch, localPrefetch, useAsyncCopy,
1095
- useF16BlockPingpong);
1109
+ useF16BlockPingpong, useAsyncCopyOverlap );
1096
1110
(void )sp.pipelineLoop ();
1097
1111
}
1098
1112
}
0 commit comments