@@ -19,13 +19,15 @@ limitations under the License.
19
19
#include < array>
20
20
#include < atomic>
21
21
#include < cstdint>
22
+ #include < iterator>
22
23
#include < memory>
23
24
#include < optional>
24
25
#include < string>
25
26
#include < utility>
26
27
#include < variant>
27
28
#include < vector>
28
29
30
+ #include " absl/algorithm/container.h"
29
31
#include " absl/container/flat_hash_map.h"
30
32
#include " absl/container/flat_hash_set.h"
31
33
#include " absl/log/check.h"
@@ -56,12 +58,14 @@ limitations under the License.
56
58
#include " xla/service/algorithm_util.h"
57
59
#include " xla/service/call_inliner.h"
58
60
#include " xla/service/dump.h"
61
+ #include " xla/service/executable.h"
59
62
#include " xla/service/float_normalization.h"
60
63
#include " xla/service/gpu/autotuning/autotuner_compile_util.h"
61
64
#include " xla/service/gpu/autotuning/autotuner_util.h"
62
65
#include " xla/service/gpu/backend_configs.pb.h"
63
66
#include " xla/service/gpu/buffer_comparator.h"
64
67
#include " xla/service/gpu/gpu_float_support.h"
68
+ #include " xla/service/gpu/hlo_traversal.h"
65
69
#include " xla/service/gpu/ir_emission_utils.h"
66
70
#include " xla/service/gpu/kernels/custom_kernel.h"
67
71
#include " xla/service/gpu/kernels/custom_kernel_fusion.h"
@@ -84,6 +88,7 @@ limitations under the License.
84
88
#include " xla/stream_executor/device_description.h"
85
89
#include " xla/stream_executor/device_memory.h"
86
90
#include " xla/stream_executor/device_memory_allocator.h"
91
+ #include " xla/stream_executor/gpu/redzone_allocator.h"
87
92
#include " xla/stream_executor/semantic_version.h"
88
93
#include " xla/stream_executor/stream.h"
89
94
#include " xla/stream_executor/stream_executor_memory_allocator.h"
@@ -717,7 +722,7 @@ std::vector<BackendConfig> GenerateCustomKernelFusionConfigs(
717
722
std::vector<CustomKernelFusionPattern::Match> match =
718
723
patterns->Match (device_description, dot_instruction);
719
724
720
- // For Cutlass we expect only one match for a gemm fusion.
725
+ // For Cutlass we expect only one match for a GEMM fusion.
721
726
if (match.size () == 1 ) {
722
727
CustomKernelFusionRegistry* registry =
723
728
CustomKernelFusionRegistry::Default ();
@@ -1195,10 +1200,6 @@ GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const {
1195
1200
debug_options_.xla_gpu_exhaustive_tiling_search () && cc.IsAtLeastHopper ();
1196
1201
1197
1202
for (int num_stages : kNumStages ) {
1198
- // Volta doesn't support num_stages > 2.
1199
- if (!cc.IsAtLeastAmpere () && num_stages > 2 ) {
1200
- break ;
1201
- }
1202
1203
for (int tile_m : kBlockSizes ) {
1203
1204
for (int tile_n : kBlockSizes ) {
1204
1205
for (int tile_k : kBlockSizes ) {
@@ -1242,28 +1243,22 @@ std::vector<TritonGemmConfig> GemmFusionAutotunerImpl::GetDefaultTritonConfigs()
1242
1243
const {
1243
1244
using Config = TritonGemmConfig;
1244
1245
std::vector<Config> configs = {
1245
- Config (32 , 32 , 256 , 1 , 1 , 4 ), Config (64 , 32 , 32 , 16 , 1 , 4 ),
1246
- Config (32 , 64 , 64 , 4 , 1 , 4 ), Config (128 , 128 , 64 , 4 , 1 , 4 ),
1247
- Config (16 , 16 , 256 , 1 , 1 , 4 ), Config (16 , 128 , 32 , 16 , 1 , 4 ),
1248
- Config (16 , 64 , 128 , 1 , 1 , 4 ), Config (16 , 128 , 32 , 8 , 1 , 4 ),
1249
- Config (16 , 16 , 512 , 1 , 1 , 4 ), Config (32 , 16 , 512 , 1 , 1 , 4 ),
1250
- Config (64 , 32 , 64 , 1 , 2 , 8 )};
1251
- if (GetComputeCapability ().IsAtLeastAmpere ()) {
1252
- absl::c_copy (
1253
- std::vector<Config>{
1254
- Config (128 , 256 , 32 , 1 , 3 , 8 ), Config (256 , 128 , 32 , 1 , 3 , 8 ),
1255
- Config (256 , 64 , 32 , 1 , 4 , 4 ), Config (64 , 256 , 32 , 1 , 4 , 4 ),
1256
- Config (128 , 64 , 32 , 1 , 4 , 4 ), Config (64 , 128 , 32 , 1 , 4 , 4 ),
1257
- Config (256 , 128 , 128 , 1 , 3 , 8 ), Config (256 , 64 , 128 , 1 , 4 , 4 ),
1258
- Config (64 , 256 , 128 , 1 , 4 , 4 ), Config (128 , 128 , 128 , 1 , 4 , 4 ),
1259
- Config (128 , 64 , 64 , 1 , 4 , 4 ), Config (64 , 128 , 64 , 1 , 4 , 4 ),
1260
- Config (128 , 32 , 64 , 1 , 4 , 4 ), Config (64 , 32 , 64 , 1 , 4 , 4 ),
1261
- Config (32 , 128 , 32 , 1 , 4 , 4 ), Config (128 , 128 , 32 , 1 , 4 , 4 ),
1262
- Config (16 , 16 , 256 , 1 , 3 , 4 ), Config (128 , 128 , 64 , 2 , 1 , 8 ),
1263
- Config (64 , 64 , 64 , 1 , 2 , 4 ), Config (16 , 64 , 256 , 8 , 1 , 4 ),
1264
- Config (256 , 256 , 128 , 1 , 3 , 8 )},
1265
- std::back_inserter (configs));
1266
- }
1246
+ Config (32 , 32 , 256 , 1 , 1 , 4 ), Config (64 , 32 , 32 , 16 , 1 , 4 ),
1247
+ Config (32 , 64 , 64 , 4 , 1 , 4 ), Config (128 , 128 , 64 , 4 , 1 , 4 ),
1248
+ Config (16 , 16 , 256 , 1 , 1 , 4 ), Config (16 , 128 , 32 , 16 , 1 , 4 ),
1249
+ Config (16 , 64 , 128 , 1 , 1 , 4 ), Config (16 , 128 , 32 , 8 , 1 , 4 ),
1250
+ Config (16 , 16 , 512 , 1 , 1 , 4 ), Config (32 , 16 , 512 , 1 , 1 , 4 ),
1251
+ Config (64 , 32 , 64 , 1 , 2 , 8 ), Config (128 , 256 , 32 , 1 , 3 , 8 ),
1252
+ Config (256 , 128 , 32 , 1 , 3 , 8 ), Config (256 , 64 , 32 , 1 , 4 , 4 ),
1253
+ Config (64 , 256 , 32 , 1 , 4 , 4 ), Config (128 , 64 , 32 , 1 , 4 , 4 ),
1254
+ Config (64 , 128 , 32 , 1 , 4 , 4 ), Config (256 , 128 , 128 , 1 , 3 , 8 ),
1255
+ Config (256 , 64 , 128 , 1 , 4 , 4 ), Config (64 , 256 , 128 , 1 , 4 , 4 ),
1256
+ Config (128 , 128 , 128 , 1 , 4 , 4 ), Config (128 , 64 , 64 , 1 , 4 , 4 ),
1257
+ Config (64 , 128 , 64 , 1 , 4 , 4 ), Config (128 , 32 , 64 , 1 , 4 , 4 ),
1258
+ Config (64 , 32 , 64 , 1 , 4 , 4 ), Config (32 , 128 , 32 , 1 , 4 , 4 ),
1259
+ Config (128 , 128 , 32 , 1 , 4 , 4 ), Config (16 , 16 , 256 , 1 , 3 , 4 ),
1260
+ Config (128 , 128 , 64 , 2 , 1 , 8 ), Config (64 , 64 , 64 , 1 , 2 , 4 ),
1261
+ Config (16 , 64 , 256 , 8 , 1 , 4 ), Config (256 , 256 , 128 , 1 , 3 , 8 )};
1267
1262
if (GetComputeCapability ().IsAtLeastHopper ()) {
1268
1263
absl::c_copy (
1269
1264
std::vector<Config>{
0 commit comments