Skip to content

Commit 8ff94d3

Browse files
bchetioui and tensorflower-gardener
authored and committed
[XLA:GPU][Cleanup] Remove pre-Ampere paths in GEMM fusion autotuner.
These paths are dead, given that GEMM fusions are gated on the compute capability being at least Ampere. Also fix the `#include` list (and the matching BUILD deps) as a side cleanup. PiperOrigin-RevId: 681909633
1 parent a22c6d4 commit 8ff94d3

File tree

2 files changed

+25
-27
lines changed

2 files changed

+25
-27
lines changed

third_party/xla/xla/service/gpu/autotuning/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ cc_library(
6060
"//xla/service/gpu:backend_configs_cc",
6161
"//xla/service/gpu:buffer_comparator",
6262
"//xla/service/gpu:gpu_float_support",
63+
"//xla/service/gpu:hlo_traversal",
6364
"//xla/service/gpu:ir_emission_utils",
6465
"//xla/service/gpu:matmul_utils",
6566
"//xla/service/gpu:split_k_gemm_rewriter",
@@ -77,9 +78,11 @@ cc_library(
7778
"//xla/stream_executor:device_memory",
7879
"//xla/stream_executor:semantic_version",
7980
"//xla/stream_executor:stream_executor_memory_allocator",
81+
"//xla/stream_executor/gpu:redzone_allocator",
8082
"//xla/tools:hlo_decomposer_lib",
8183
"//xla/tsl/lib/core:bits",
8284
"//xla/tsl/util/proto:proto_utils",
85+
"@com_google_absl//absl/algorithm:container",
8386
"@com_google_absl//absl/container:flat_hash_map",
8487
"@com_google_absl//absl/container:flat_hash_set",
8588
"@com_google_absl//absl/log",

third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@ limitations under the License.
1919
#include <array>
2020
#include <atomic>
2121
#include <cstdint>
22+
#include <iterator>
2223
#include <memory>
2324
#include <optional>
2425
#include <string>
2526
#include <utility>
2627
#include <variant>
2728
#include <vector>
2829

30+
#include "absl/algorithm/container.h"
2931
#include "absl/container/flat_hash_map.h"
3032
#include "absl/container/flat_hash_set.h"
3133
#include "absl/log/check.h"
@@ -56,12 +58,14 @@ limitations under the License.
5658
#include "xla/service/algorithm_util.h"
5759
#include "xla/service/call_inliner.h"
5860
#include "xla/service/dump.h"
61+
#include "xla/service/executable.h"
5962
#include "xla/service/float_normalization.h"
6063
#include "xla/service/gpu/autotuning/autotuner_compile_util.h"
6164
#include "xla/service/gpu/autotuning/autotuner_util.h"
6265
#include "xla/service/gpu/backend_configs.pb.h"
6366
#include "xla/service/gpu/buffer_comparator.h"
6467
#include "xla/service/gpu/gpu_float_support.h"
68+
#include "xla/service/gpu/hlo_traversal.h"
6569
#include "xla/service/gpu/ir_emission_utils.h"
6670
#include "xla/service/gpu/kernels/custom_kernel.h"
6771
#include "xla/service/gpu/kernels/custom_kernel_fusion.h"
@@ -84,6 +88,7 @@ limitations under the License.
8488
#include "xla/stream_executor/device_description.h"
8589
#include "xla/stream_executor/device_memory.h"
8690
#include "xla/stream_executor/device_memory_allocator.h"
91+
#include "xla/stream_executor/gpu/redzone_allocator.h"
8792
#include "xla/stream_executor/semantic_version.h"
8893
#include "xla/stream_executor/stream.h"
8994
#include "xla/stream_executor/stream_executor_memory_allocator.h"
@@ -717,7 +722,7 @@ std::vector<BackendConfig> GenerateCustomKernelFusionConfigs(
717722
std::vector<CustomKernelFusionPattern::Match> match =
718723
patterns->Match(device_description, dot_instruction);
719724

720-
// For Cutlass we expect only one match for a gemm fusion.
725+
// For Cutlass we expect only one match for a GEMM fusion.
721726
if (match.size() == 1) {
722727
CustomKernelFusionRegistry* registry =
723728
CustomKernelFusionRegistry::Default();
@@ -1195,10 +1200,6 @@ GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const {
11951200
debug_options_.xla_gpu_exhaustive_tiling_search() && cc.IsAtLeastHopper();
11961201

11971202
for (int num_stages : kNumStages) {
1198-
// Volta doesn't support num_stages > 2.
1199-
if (!cc.IsAtLeastAmpere() && num_stages > 2) {
1200-
break;
1201-
}
12021203
for (int tile_m : kBlockSizes) {
12031204
for (int tile_n : kBlockSizes) {
12041205
for (int tile_k : kBlockSizes) {
@@ -1242,28 +1243,22 @@ std::vector<TritonGemmConfig> GemmFusionAutotunerImpl::GetDefaultTritonConfigs()
12421243
const {
12431244
using Config = TritonGemmConfig;
12441245
std::vector<Config> configs = {
1245-
Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4),
1246-
Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4),
1247-
Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4),
1248-
Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4),
1249-
Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4),
1250-
Config(64, 32, 64, 1, 2, 8)};
1251-
if (GetComputeCapability().IsAtLeastAmpere()) {
1252-
absl::c_copy(
1253-
std::vector<Config>{
1254-
Config(128, 256, 32, 1, 3, 8), Config(256, 128, 32, 1, 3, 8),
1255-
Config(256, 64, 32, 1, 4, 4), Config(64, 256, 32, 1, 4, 4),
1256-
Config(128, 64, 32, 1, 4, 4), Config(64, 128, 32, 1, 4, 4),
1257-
Config(256, 128, 128, 1, 3, 8), Config(256, 64, 128, 1, 4, 4),
1258-
Config(64, 256, 128, 1, 4, 4), Config(128, 128, 128, 1, 4, 4),
1259-
Config(128, 64, 64, 1, 4, 4), Config(64, 128, 64, 1, 4, 4),
1260-
Config(128, 32, 64, 1, 4, 4), Config(64, 32, 64, 1, 4, 4),
1261-
Config(32, 128, 32, 1, 4, 4), Config(128, 128, 32, 1, 4, 4),
1262-
Config(16, 16, 256, 1, 3, 4), Config(128, 128, 64, 2, 1, 8),
1263-
Config(64, 64, 64, 1, 2, 4), Config(16, 64, 256, 8, 1, 4),
1264-
Config(256, 256, 128, 1, 3, 8)},
1265-
std::back_inserter(configs));
1266-
}
1246+
Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4),
1247+
Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4),
1248+
Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4),
1249+
Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4),
1250+
Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4),
1251+
Config(64, 32, 64, 1, 2, 8), Config(128, 256, 32, 1, 3, 8),
1252+
Config(256, 128, 32, 1, 3, 8), Config(256, 64, 32, 1, 4, 4),
1253+
Config(64, 256, 32, 1, 4, 4), Config(128, 64, 32, 1, 4, 4),
1254+
Config(64, 128, 32, 1, 4, 4), Config(256, 128, 128, 1, 3, 8),
1255+
Config(256, 64, 128, 1, 4, 4), Config(64, 256, 128, 1, 4, 4),
1256+
Config(128, 128, 128, 1, 4, 4), Config(128, 64, 64, 1, 4, 4),
1257+
Config(64, 128, 64, 1, 4, 4), Config(128, 32, 64, 1, 4, 4),
1258+
Config(64, 32, 64, 1, 4, 4), Config(32, 128, 32, 1, 4, 4),
1259+
Config(128, 128, 32, 1, 4, 4), Config(16, 16, 256, 1, 3, 4),
1260+
Config(128, 128, 64, 2, 1, 8), Config(64, 64, 64, 1, 2, 4),
1261+
Config(16, 64, 256, 8, 1, 4), Config(256, 256, 128, 1, 3, 8)};
12671262
if (GetComputeCapability().IsAtLeastHopper()) {
12681263
absl::c_copy(
12691264
std::vector<Config>{

0 commit comments

Comments
 (0)