Commit 3f1d29b (parent: 923fbc1)
Authored by bchetioui, committed by Google-ML-Automation

[XLA:GPU][NFC] Extract the logic to deduce the type of a dot's accumulator into dot_algorithms.h.

Take the opportunity to split up the logic deriving the required operand types, the required accumulator type, and the algorithm emitter, and to hoist it out of `EmitSingleTileDot`. The purpose of this change is to make this logic available to the generic emitter as well.

PiperOrigin-RevId: 742628334
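For orientation, here is a condensed sketch of the control flow that `EmitSingleTileDot` follows after this change. It mirrors the new code in the dot_algorithms.cc diff below rather than quoting it verbatim; the surrounding function scaffolding and the final dispatch are elided:

    // Condensed sketch of the new EmitSingleTileDot (see the full diff below).
    PrecisionConfig::Algorithm algorithm = dot.precision_config().algorithm();
    // 1. Pick the emitter (UnimplementedError if the algorithm is unsupported).
    TF_ASSIGN_OR_RETURN(AlgorithmEmitter algorithm_emitter,
                        GetAlgorithmEmitter(algorithm));
    // 2. Deduce the type the operands must be cast to, if any.
    TF_ASSIGN_OR_RETURN(std::optional<Type> force_operands_type,
                        GetForceOperandsType(b, dot, dot_operands));
    // 3. Deduce the accumulator type (also covers ALG_UNSET).
    TF_ASSIGN_OR_RETURN(Type force_accumulator_type,
                        GetDotAccumulatorType(b, dot));
    // ...cast operands/accumulator where needed, then invoke the emitter.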

File tree: 8 files changed, +198 −152 lines

xla/backends/gpu/codegen/triton/BUILD

Lines changed: 1 addition & 2 deletions

@@ -101,13 +101,13 @@ cc_library(
         "//xla:xla_proto_cc",
         "//xla/codegen:emitter_loc_op_builder",
         "//xla/hlo/ir:hlo",
-        "//xla/hlo/utils:hlo_query",
         "//xla/mlir_hlo",
         "//xla/mlir_hlo:map_mhlo_to_scalar_op",
         "//xla/mlir_hlo:transformation_helpers",
         "//xla/service/gpu:target_util",
         "//xla/service/llvm_ir:llvm_util",
         "//xla/stream_executor:device_description",
+        "//xla/tsl/platform:status",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
@@ -120,7 +120,6 @@ cc_library(
         "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:Support",
         "@triton//:TritonDialects",
-        "@tsl//tsl/platform:status",
         "@tsl//tsl/platform:statusor",
     ],
 )

xla/backends/gpu/codegen/triton/dot_algorithms.cc

Lines changed: 114 additions & 100 deletions

@@ -21,10 +21,12 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include "absl/algorithm/container.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
 #include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Math/IR/Math.h"
@@ -42,6 +44,7 @@ limitations under the License.
 #include "xla/primitive_util.h"
 #include "xla/service/algorithm_util.h"
 #include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/statusor.h"
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/tensor_float_32_utils.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -277,11 +280,14 @@ ttir::InputPrecision InferDotPrecision(const HloDotInstruction& dot) {
   return use_tf32 ? ttir::InputPrecision::TF32 : ttir::InputPrecision::IEEE;
 }
 
-Type GetAlgUnsetAccumulatorType(EmitterLocOpBuilder& b,
-                                const DotOperands& dot_operands) {
-  Type lhs_type = ElementType(dot_operands.lhs);
-  Type rhs_type = ElementType(dot_operands.rhs);
-  Type accumulator_type = ElementType(dot_operands.accumulator);
+absl::StatusOr<Type> GetAlgUnsetAccumulatorType(EmitterLocOpBuilder& b,
+                                                const HloDotInstruction& dot) {
+  TF_ASSIGN_OR_RETURN(Type lhs_type,
+                      TritonType(b, dot.operand(0)->shape().element_type()));
+  TF_ASSIGN_OR_RETURN(Type rhs_type,
+                      TritonType(b, dot.operand(1)->shape().element_type()));
+  TF_ASSIGN_OR_RETURN(Type accumulator_type,
+                      TritonType(b, dot.shape().element_type()));
 
   // The code below assumes that lhs and rhs have the same type. However
   // this may not always be the case with f8 matmuls, e.g. e4m3×e5m2 is
@@ -313,12 +319,6 @@ absl::StatusOr<Value> EmitDotAlgUnset(EmitterLocOpBuilder& b,
   Value rhs = dot_operands.rhs;
   Value acc = dot_operands.accumulator;
 
-  Type expected_acc_type = GetAlgUnsetAccumulatorType(b, dot_operands);
-  if (ElementType(acc) != expected_acc_type) {
-    return absl::FailedPreconditionError(
-        "Given accumulator type for unset dot does not match expected type.");
-  }
-
   int max_num_imprecise_acc = 0;
   if (ElementType(lhs).isFloat(8) || ElementType(rhs).isFloat(8)) {
     // For fp8 dots, disable accumulator promotion to mimick cuBLAS. It may make
@@ -365,114 +365,130 @@ absl::StatusOr<Value> EmitRegularDot(EmitterLocOpBuilder& b,
                       /*maxNumImpreciseAcc=*/max_num_imprecise_acc);
 }
 
-}  // namespace
-
-absl::StatusOr<Value> EmitSingleTileDot(EmitterLocOpBuilder& b,
-                                        const HloDotInstruction& dot,
-                                        DotOperands dot_operands) {
-  AlgorithmEmitter algorithm_emitter = nullptr;
-  PrecisionSpec precision_spec{dot.precision_config().algorithm(),
-                               dot.precision_config().operand_precision(0),
-                               dot.precision_config().operand_precision(1),
-                               InferDotPrecision(dot)};
-
-  // Algorithms mostly expect that their input and output types correspond to
-  // what the algorithm describes. This is not always the case though, e.g.
-  // for BF16_BF16_F32_X9, working from inputs casted to BF16 makes no sense;
-  // this algorithm instead expects F32 inputs, and performs splits into BF16
-  // sub-values under the hood.
-  std::optional<Type> force_operands_type;
-  std::optional<Type> force_accumulator_type;
-
-  PrecisionConfig::Algorithm algorithm = precision_spec.algorithm;
-
-  Type bf16 = b.getBF16Type();
-  Type f16 = b.getF16Type();
-  Type f32 = b.getF32Type();
-  Type f64 = b.getF64Type();
-
+// Returns an emitter for the given dot algorithm. Raises an
+// `UnimplementedError` if the algorithm is not supported.
+absl::StatusOr<AlgorithmEmitter> GetAlgorithmEmitter(
+    const PrecisionConfig::Algorithm algorithm) {
   switch (algorithm) {
     case PrecisionConfig::ALG_UNSET:
-      algorithm_emitter = EmitDotAlgUnset;
-      break;
+      return EmitDotAlgUnset;
     case PrecisionConfig::ALG_DOT_F16_F16_F16:
-      force_operands_type = f16;
-      force_accumulator_type = f16;
-      algorithm_emitter = EmitRegularDot;
-      break;
     case PrecisionConfig::ALG_DOT_F32_F32_F32:
-      force_operands_type = f32;
-      force_accumulator_type = f32;
-      algorithm_emitter = EmitRegularDot;
-      break;
     case PrecisionConfig::ALG_DOT_F64_F64_F64:
-      force_operands_type = f64;
-      force_accumulator_type = f64;
-      algorithm_emitter = EmitRegularDot;
-      break;
     case PrecisionConfig::ALG_DOT_F16_F16_F32:
-      force_operands_type = f16;
-      force_accumulator_type = f32;
-      algorithm_emitter = EmitRegularDot;
-      break;
     case PrecisionConfig::ALG_DOT_BF16_BF16_BF16:
-      force_operands_type = bf16;
-      force_accumulator_type = bf16;
-      algorithm_emitter = EmitRegularDot;
-      break;
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
-      force_operands_type = bf16;
-      force_accumulator_type = f32;
-      algorithm_emitter = EmitRegularDot;
-      break;
+      return EmitRegularDot;
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
-      force_operands_type = f32;  // This is not a typo.
-      force_accumulator_type = f32;
-      algorithm_emitter = EmitBF16x3Matmul;
-      break;
+      return EmitBF16x3Matmul;
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
-      force_operands_type = f32;  // This is not a typo.
-      force_accumulator_type = f32;
-      algorithm_emitter = EmitBF16x6Matmul;
-      break;
+      return EmitBF16x6Matmul;
     case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
-      // TODO(bchetioui): pass around tf32 matmul config.
-      force_operands_type = f32;
-      force_accumulator_type = f32;
       // TODO(bchetioui): this should be factored out of EmitRegularDot.
-      algorithm_emitter = EmitRegularDot;
-      break;
+      return EmitRegularDot;
     case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
-      // TODO(bchetioui): pass around tf32 matmul config.
-      force_operands_type = f32;
-      force_accumulator_type = f32;
       // TODO(bchetioui): this should be factored out of EmitRegularDot.
-      algorithm_emitter = EmitRegularDot;
-      break;
+      return EmitRegularDot;
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9:
-      force_operands_type = f32;  // This is not a typo.
-      force_accumulator_type = f32;
-      algorithm_emitter = EmitBF16x9Matmul;
-      break;
+      return EmitBF16x9Matmul;
     case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32:
-      // TODO(bchetioui): How to enforce "any f8"?
-      force_accumulator_type = f32;
-      break;
     case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM:
-      // TODO(bchetioui): How to enforce "any f8"?
-      force_accumulator_type = f32;
-      break;
     default:
       break;
   }
 
   // Couldn't find an algorithm emitter for this algorithm. Raise an error.
-  if (algorithm_emitter == nullptr) {
-    return absl::UnimplementedError(
-        absl::StrCat("This algorithm is not supported yet: ",
-                     PrecisionConfig::Algorithm_Name(algorithm)));
+  return absl::UnimplementedError(
+      absl::StrCat("This algorithm is not supported yet: ",
                    PrecisionConfig::Algorithm_Name(algorithm)));
+}
+
+// Returns the `Type` that the dot operands should be casted to if there is a
+// clear candidate. Raises an error if there are multiple allowed choices but
+// the operands do not already conform to any of them. Returns `std::nullopt` if
+// no casting is a priori needed.
+absl::StatusOr<std::optional<Type>> GetForceOperandsType(
+    EmitterLocOpBuilder& b, const HloDotInstruction& dot,
+    const DotOperands& dot_operands) {
+  PrecisionConfig::Algorithm algorithm = dot.precision_config().algorithm();
+  if (algorithm == PrecisionConfig::ALG_UNSET) {
+    return std::nullopt;
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<PrimitiveType> allowed_operands_primitive_types,
+      algorithm_util::GetAllowedOperandsTypeForAlgorithm(algorithm));
+  CHECK(!allowed_operands_primitive_types.empty());
+
+  std::vector<Type> allowed_operands_types;
+  allowed_operands_types.reserve(allowed_operands_primitive_types.size());
+  for (PrimitiveType primitive_type : allowed_operands_primitive_types) {
+    TF_ASSIGN_OR_RETURN(Type type, TritonType(b, primitive_type));
+    allowed_operands_types.push_back(type);
+  }
+
+  Type lhs_type = ElementType(dot_operands.lhs);
+  Type rhs_type = ElementType(dot_operands.rhs);
+  if (allowed_operands_types.size() == 1) {
+    // If there is a single allowed operand type, we force the operands to use
+    // this type.
+    return allowed_operands_types.front();
+  } else {
+    // If there are several allowed operand types, we just check that the
+    // operands have the same type, and that this type is one of the allowed
+    // ones. Raise an error otherwise.
+    if (lhs_type != rhs_type ||
+        !absl::c_linear_search(allowed_operands_types, lhs_type)) {
+      std::string allowed_operands_types_str = absl::StrJoin(
+          allowed_operands_types, ", ", [&](std::string* out, Type type) {
+            absl::StrAppend(out, MlirToString(type));
+          });
+      return absl::FailedPreconditionError(absl::StrCat(
+          "Expected dot operands to both have the same type, and for this type "
+          "to be one of the following types: ",
+          allowed_operands_types_str, " but got ", MlirToString(lhs_type),
+          " and ", MlirToString(rhs_type)));
+    }
+  }
+
+  return std::nullopt;
+}
+
+}  // namespace
+
+// TODO(b/266862493): Add support for more types as needed.
+absl::StatusOr<Type> GetDotAccumulatorType(EmitterLocOpBuilder& b,
+                                           const HloDotInstruction& dot) {
+  const PrecisionConfig::Algorithm algorithm =
+      dot.precision_config().algorithm();
+
+  if (algorithm == PrecisionConfig::ALG_UNSET) {
+    return GetAlgUnsetAccumulatorType(b, dot);
   }
 
+  TF_ASSIGN_OR_RETURN(PrimitiveType accumulator_type,
+                      algorithm_util::GetDotAccumulatorType(algorithm));
+  return TritonType(b, accumulator_type);
+}
+
+absl::StatusOr<Value> EmitSingleTileDot(EmitterLocOpBuilder& b,
+                                        const HloDotInstruction& dot,
+                                        DotOperands dot_operands) {
+  PrecisionConfig::Algorithm algorithm = dot.precision_config().algorithm();
+  PrecisionSpec precision_spec{
+      algorithm, dot.precision_config().operand_precision(0),
+      dot.precision_config().operand_precision(1), InferDotPrecision(dot)};
+
+  TF_ASSIGN_OR_RETURN(AlgorithmEmitter algorithm_emitter,
+                      GetAlgorithmEmitter(algorithm));
+
+  TF_ASSIGN_OR_RETURN(std::optional<Type> force_operands_type,
+                      GetForceOperandsType(b, dot, dot_operands));
+
+  TF_ASSIGN_OR_RETURN(Type force_accumulator_type,
+                      GetDotAccumulatorType(b, dot));
+
   if (force_operands_type.has_value()) {
     if (ElementType(dot_operands.lhs) != *force_operands_type) {
       dot_operands.lhs = Cast(b, dot_operands.lhs, *force_operands_type);
@@ -483,11 +499,9 @@ absl::StatusOr<Value> EmitSingleTileDot(EmitterLocOpBuilder& b,
     }
   }
 
-  if (force_accumulator_type.has_value()) {
-    if (ElementType(dot_operands.accumulator) != *force_accumulator_type) {
-      dot_operands.accumulator =
-          Cast(b, dot_operands.accumulator, *force_accumulator_type);
-    }
+  if (ElementType(dot_operands.accumulator) != force_accumulator_type) {
+    dot_operands.accumulator =
+        Cast(b, dot_operands.accumulator, force_accumulator_type);
   }
 
   TF_ASSIGN_OR_RETURN(Value result,

xla/backends/gpu/codegen/triton/dot_algorithms.h

Lines changed: 6 additions & 0 deletions

@@ -17,6 +17,7 @@ limitations under the License.
 #define XLA_BACKENDS_GPU_CODEGEN_TRITON_DOT_ALGORITHMS_H_
 
 #include "absl/status/statusor.h"
+#include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "xla/codegen/emitter_loc_op_builder.h"
 #include "xla/hlo/ir/hlo_instructions.h"
@@ -33,6 +34,11 @@ struct DotOperands {
   ::mlir::Value accumulator;
 };
 
+// Returns the type to use for accumulation for the given `dot` instruction.
+// This also handles the case where the algorithm is `ALG_UNSET`.
+absl::StatusOr<::mlir::Type> GetDotAccumulatorType(
+    EmitterLocOpBuilder& b, const HloDotInstruction& dot);
+
 // Emits a single-tile dot, considering the given `dot` instruction's algorithm
 // and operand precisions. Raises an `UnimplementedError` if the algorithm is
 // not supported.
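Since the commit message states that the point of exposing `GetDotAccumulatorType` is to make it usable from the generic emitter, here is a minimal sketch of what such a call site could look like. The caller itself is an assumption, not part of this commit:

    // Hypothetical caller in a generic emitter (assumption): deduce the
    // accumulator element type for a dot, honoring its algorithm
    // (including ALG_UNSET).
    TF_ASSIGN_OR_RETURN(::mlir::Type acc_type,
                        triton::GetDotAccumulatorType(b, dot));
    // An accumulator tile would then be initialized with `acc_type` (elided).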

xla/backends/gpu/codegen/triton/emitter_helpers.h

Lines changed: 12 additions & 0 deletions

@@ -17,12 +17,14 @@ limitations under the License.
 #define XLA_BACKENDS_GPU_CODEGEN_TRITON_EMITTER_HELPERS_H_
 
 #include <cstdint>
+#include <string>
 
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -37,10 +39,20 @@ limitations under the License.
 #include "xla/service/llvm_ir/llvm_util.h"
 #include "xla/shape_util.h"
 #include "xla/stream_executor/device_description.h"
+#include "xla/tsl/platform/status.h"
 #include "xla/xla.pb.h"
 
 namespace xla::gpu::triton {
 
+// Returns a string representation of the given MLIR entity.
+template <typename T>
+std::string MlirToString(T&& value) {
+  std::string result;
+  llvm::raw_string_ostream os(result);
+  value.print(os);
+  return result;
+}
+
 // This is a wrapper around mlir::Value that can hold either a scalar or a
 // non-0D tensor. An attempt to use this class with 0D tensors will CHECK-fail
 // because 0D tensors are not supported by Triton.
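A minimal usage sketch for `MlirToString`, mirroring how the dot_algorithms.cc change uses it to build error messages; the `b` builder being in scope is an assumption of the snippet:

    // Render an MLIR type as a string for diagnostics.
    ::mlir::Type bf16 = b.getBF16Type();
    std::string msg =
        absl::StrCat("unexpected operand type: ", MlirToString(bf16));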
