@@ -21,10 +21,12 @@ limitations under the License.
21
21
#include < string>
22
22
#include < vector>
23
23
24
+ #include " absl/algorithm/container.h"
24
25
#include " absl/log/check.h"
25
26
#include " absl/log/log.h"
26
27
#include " absl/status/status.h"
27
28
#include " absl/strings/str_cat.h"
29
+ #include " absl/strings/str_join.h"
28
30
#include " llvm/Support/raw_ostream.h"
29
31
#include " mlir/Dialect/Arith/IR/Arith.h"
30
32
#include " mlir/Dialect/Math/IR/Math.h"
@@ -42,6 +44,7 @@ limitations under the License.
42
44
#include " xla/primitive_util.h"
43
45
#include " xla/service/algorithm_util.h"
44
46
#include " xla/tsl/platform/errors.h"
47
+ #include " xla/tsl/platform/statusor.h"
45
48
#include " xla/xla_data.pb.h"
46
49
#include " tsl/platform/tensor_float_32_utils.h"
47
50
#include " triton/Dialect/Triton/IR/Dialect.h"
@@ -277,11 +280,14 @@ ttir::InputPrecision InferDotPrecision(const HloDotInstruction& dot) {
277
280
return use_tf32 ? ttir::InputPrecision::TF32 : ttir::InputPrecision::IEEE;
278
281
}
279
282
280
- Type GetAlgUnsetAccumulatorType (EmitterLocOpBuilder& b,
281
- const DotOperands& dot_operands) {
282
- Type lhs_type = ElementType (dot_operands.lhs );
283
- Type rhs_type = ElementType (dot_operands.rhs );
284
- Type accumulator_type = ElementType (dot_operands.accumulator );
283
+ absl::StatusOr<Type> GetAlgUnsetAccumulatorType (EmitterLocOpBuilder& b,
284
+ const HloDotInstruction& dot) {
285
+ TF_ASSIGN_OR_RETURN (Type lhs_type,
286
+ TritonType (b, dot.operand (0 )->shape ().element_type ()));
287
+ TF_ASSIGN_OR_RETURN (Type rhs_type,
288
+ TritonType (b, dot.operand (1 )->shape ().element_type ()));
289
+ TF_ASSIGN_OR_RETURN (Type accumulator_type,
290
+ TritonType (b, dot.shape ().element_type ()));
285
291
286
292
// The code below assumes that lhs and rhs have the same type. However
287
293
// this may not always be the case with f8 matmuls, e.g. e4m3×e5m2 is
@@ -313,12 +319,6 @@ absl::StatusOr<Value> EmitDotAlgUnset(EmitterLocOpBuilder& b,
313
319
Value rhs = dot_operands.rhs ;
314
320
Value acc = dot_operands.accumulator ;
315
321
316
- Type expected_acc_type = GetAlgUnsetAccumulatorType (b, dot_operands);
317
- if (ElementType (acc) != expected_acc_type) {
318
- return absl::FailedPreconditionError (
319
- " Given accumulator type for unset dot does not match expected type." );
320
- }
321
-
322
322
int max_num_imprecise_acc = 0 ;
323
323
if (ElementType (lhs).isFloat (8 ) || ElementType (rhs).isFloat (8 )) {
324
324
// For fp8 dots, disable accumulator promotion to mimick cuBLAS. It may make
@@ -365,114 +365,130 @@ absl::StatusOr<Value> EmitRegularDot(EmitterLocOpBuilder& b,
365
365
/* maxNumImpreciseAcc=*/ max_num_imprecise_acc);
366
366
}
367
367
368
- } // namespace
369
-
370
- absl::StatusOr<Value> EmitSingleTileDot (EmitterLocOpBuilder& b,
371
- const HloDotInstruction& dot,
372
- DotOperands dot_operands) {
373
- AlgorithmEmitter algorithm_emitter = nullptr ;
374
- PrecisionSpec precision_spec{dot.precision_config ().algorithm (),
375
- dot.precision_config ().operand_precision (0 ),
376
- dot.precision_config ().operand_precision (1 ),
377
- InferDotPrecision (dot)};
378
-
379
- // Algorithms mostly expect that their input and output types correspond to
380
- // what the algorithm describes. This is not always the case though, e.g.
381
- // for BF16_BF16_F32_X9, working from inputs casted to BF16 makes no sense;
382
- // this algorithm instead expects F32 inputs, and performs splits into BF16
383
- // sub-values under the hood.
384
- std::optional<Type> force_operands_type;
385
- std::optional<Type> force_accumulator_type;
386
-
387
- PrecisionConfig::Algorithm algorithm = precision_spec.algorithm ;
388
-
389
- Type bf16 = b.getBF16Type ();
390
- Type f16 = b.getF16Type ();
391
- Type f32 = b.getF32Type ();
392
- Type f64 = b.getF64Type ();
393
-
368
+ // Returns an emitter for the given dot algorithm. Returns an
369
+ // `UnimplementedError` status if the algorithm is not supported.
370
+ absl::StatusOr<AlgorithmEmitter> GetAlgorithmEmitter (
371
+ const PrecisionConfig::Algorithm algorithm) {
394
372
switch (algorithm) {
395
373
case PrecisionConfig::ALG_UNSET:
396
- algorithm_emitter = EmitDotAlgUnset;
397
- break ;
374
+ return EmitDotAlgUnset;
398
375
case PrecisionConfig::ALG_DOT_F16_F16_F16:
399
- force_operands_type = f16 ;
400
- force_accumulator_type = f16 ;
401
- algorithm_emitter = EmitRegularDot;
402
- break ;
403
376
case PrecisionConfig::ALG_DOT_F32_F32_F32:
404
- force_operands_type = f32 ;
405
- force_accumulator_type = f32 ;
406
- algorithm_emitter = EmitRegularDot;
407
- break ;
408
377
case PrecisionConfig::ALG_DOT_F64_F64_F64:
409
- force_operands_type = f64 ;
410
- force_accumulator_type = f64 ;
411
- algorithm_emitter = EmitRegularDot;
412
- break ;
413
378
case PrecisionConfig::ALG_DOT_F16_F16_F32:
414
- force_operands_type = f16 ;
415
- force_accumulator_type = f32 ;
416
- algorithm_emitter = EmitRegularDot;
417
- break ;
418
379
case PrecisionConfig::ALG_DOT_BF16_BF16_BF16:
419
- force_operands_type = bf16 ;
420
- force_accumulator_type = bf16 ;
421
- algorithm_emitter = EmitRegularDot;
422
- break ;
423
380
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
424
- force_operands_type = bf16 ;
425
- force_accumulator_type = f32 ;
426
- algorithm_emitter = EmitRegularDot;
427
- break ;
381
+ return EmitRegularDot;
428
382
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
429
- force_operands_type = f32 ; // This is not a typo.
430
- force_accumulator_type = f32 ;
431
- algorithm_emitter = EmitBF16x3Matmul;
432
- break ;
383
+ return EmitBF16x3Matmul;
433
384
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
434
- force_operands_type = f32 ; // This is not a typo.
435
- force_accumulator_type = f32 ;
436
- algorithm_emitter = EmitBF16x6Matmul;
437
- break ;
385
+ return EmitBF16x6Matmul;
438
386
case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
439
- // TODO(bchetioui): pass around tf32 matmul config.
440
- force_operands_type = f32 ;
441
- force_accumulator_type = f32 ;
442
387
// TODO(bchetioui): this should be factored out of EmitRegularDot.
443
- algorithm_emitter = EmitRegularDot;
444
- break ;
388
+ return EmitRegularDot;
445
389
case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
446
- // TODO(bchetioui): pass around tf32 matmul config.
447
- force_operands_type = f32 ;
448
- force_accumulator_type = f32 ;
449
390
// TODO(bchetioui): this should be factored out of EmitRegularDot.
450
- algorithm_emitter = EmitRegularDot;
451
- break ;
391
+ return EmitRegularDot;
452
392
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9:
453
- force_operands_type = f32 ; // This is not a typo.
454
- force_accumulator_type = f32 ;
455
- algorithm_emitter = EmitBF16x9Matmul;
456
- break ;
393
+ return EmitBF16x9Matmul;
457
394
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32:
458
- // TODO(bchetioui): How to enforce "any f8"?
459
- force_accumulator_type = f32 ;
460
- break ;
461
395
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM:
462
- // TODO(bchetioui): How to enforce "any f8"?
463
- force_accumulator_type = f32 ;
464
- break ;
465
396
default :
466
397
break ;
467
398
}
468
399
469
400
// Couldn't find an algorithm emitter for this algorithm. Raise an error.
470
- if (algorithm_emitter == nullptr ) {
471
- return absl::UnimplementedError (
472
- absl::StrCat (" This algorithm is not supported yet: " ,
473
- PrecisionConfig::Algorithm_Name (algorithm)));
401
+ return absl::UnimplementedError (
402
+ absl::StrCat (" This algorithm is not supported yet: " ,
403
+ PrecisionConfig::Algorithm_Name (algorithm)));
404
+ }
405
+
406
+ // Returns the `Type` that the dot operands should be cast to if there is a
407
+ // clear candidate. Returns an error if there are multiple allowed choices but
408
+ // the operands do not already conform to any of them. Returns `std::nullopt` if
409
+ // no casting is a priori needed.
410
+ absl::StatusOr<std::optional<Type>> GetForceOperandsType (
411
+ EmitterLocOpBuilder& b, const HloDotInstruction& dot,
412
+ const DotOperands& dot_operands) {
413
+ PrecisionConfig::Algorithm algorithm = dot.precision_config ().algorithm ();
414
+ if (algorithm == PrecisionConfig::ALG_UNSET) {
415
+ return std::nullopt;
416
+ }
417
+
418
+ TF_ASSIGN_OR_RETURN (
419
+ std::vector<PrimitiveType> allowed_operands_primitive_types,
420
+ algorithm_util::GetAllowedOperandsTypeForAlgorithm (algorithm));
421
+ CHECK (!allowed_operands_primitive_types.empty ());
422
+
423
+ std::vector<Type> allowed_operands_types;
424
+ allowed_operands_types.reserve (allowed_operands_primitive_types.size ());
425
+ for (PrimitiveType primitive_type : allowed_operands_primitive_types) {
426
+ TF_ASSIGN_OR_RETURN (Type type, TritonType (b, primitive_type));
427
+ allowed_operands_types.push_back (type);
428
+ }
429
+
430
+ Type lhs_type = ElementType (dot_operands.lhs );
431
+ Type rhs_type = ElementType (dot_operands.rhs );
432
+ if (allowed_operands_types.size () == 1 ) {
433
+ // If there is a single allowed operand type, we force the operands to use
434
+ // this type.
435
+ return allowed_operands_types.front ();
436
+
437
+ } else {
438
+ // If there are several allowed operand types, we just check that the
439
+ // operands have the same type, and that this type is one of the allowed
440
+ // ones. Return an error otherwise.
441
+ if (lhs_type != rhs_type ||
442
+ !absl::c_linear_search (allowed_operands_types, lhs_type)) {
443
+ std::string allowed_operands_types_str = absl::StrJoin (
444
+ allowed_operands_types, " , " , [&](std::string* out, Type type) {
445
+ absl::StrAppend (out, MlirToString (type));
446
+ });
447
+ return absl::FailedPreconditionError (absl::StrCat (
448
+ " Expected dot operands to both have the same type, and for this type "
449
+ " to be one of the following types: " ,
450
+ allowed_operands_types_str, " but got " , MlirToString (lhs_type),
451
+ " and " , MlirToString (rhs_type)));
452
+ }
453
+ }
454
+
455
+ return std::nullopt;
456
+ }
457
+
458
+ } // namespace
459
+
460
+ // TODO(b/266862493): Add support for more types as needed.
461
+ absl::StatusOr<Type> GetDotAccumulatorType (EmitterLocOpBuilder& b,
462
+ const HloDotInstruction& dot) {
463
+ const PrecisionConfig::Algorithm algorithm =
464
+ dot.precision_config ().algorithm ();
465
+
466
+ if (algorithm == PrecisionConfig::ALG_UNSET) {
467
+ return GetAlgUnsetAccumulatorType (b, dot);
474
468
}
475
469
470
+ TF_ASSIGN_OR_RETURN (PrimitiveType accumulator_type,
471
+ algorithm_util::GetDotAccumulatorType (algorithm));
472
+ return TritonType (b, accumulator_type);
473
+ }
474
+
475
+ absl::StatusOr<Value> EmitSingleTileDot (EmitterLocOpBuilder& b,
476
+ const HloDotInstruction& dot,
477
+ DotOperands dot_operands) {
478
+ PrecisionConfig::Algorithm algorithm = dot.precision_config ().algorithm ();
479
+ PrecisionSpec precision_spec{
480
+ algorithm, dot.precision_config ().operand_precision (0 ),
481
+ dot.precision_config ().operand_precision (1 ), InferDotPrecision (dot)};
482
+
483
+ TF_ASSIGN_OR_RETURN (AlgorithmEmitter algorithm_emitter,
484
+ GetAlgorithmEmitter (algorithm));
485
+
486
+ TF_ASSIGN_OR_RETURN (std::optional<Type> force_operands_type,
487
+ GetForceOperandsType (b, dot, dot_operands));
488
+
489
+ TF_ASSIGN_OR_RETURN (Type force_accumulator_type,
490
+ GetDotAccumulatorType (b, dot));
491
+
476
492
if (force_operands_type.has_value ()) {
477
493
if (ElementType (dot_operands.lhs ) != *force_operands_type) {
478
494
dot_operands.lhs = Cast (b, dot_operands.lhs , *force_operands_type);
@@ -483,11 +499,9 @@ absl::StatusOr<Value> EmitSingleTileDot(EmitterLocOpBuilder& b,
483
499
}
484
500
}
485
501
486
- if (force_accumulator_type.has_value ()) {
487
- if (ElementType (dot_operands.accumulator ) != *force_accumulator_type) {
488
- dot_operands.accumulator =
489
- Cast (b, dot_operands.accumulator , *force_accumulator_type);
490
- }
502
+ if (ElementType (dot_operands.accumulator ) != force_accumulator_type) {
503
+ dot_operands.accumulator =
504
+ Cast (b, dot_operands.accumulator , force_accumulator_type);
491
505
}
492
506
493
507
TF_ASSIGN_OR_RETURN (Value result,
0 commit comments