@@ -46,10 +46,11 @@ limitations under the License.
 
 namespace xla {
 namespace spmd {
-
 namespace {
-
 using hlo_sharding_util::GroupedSharding;
+PartitioningMethod gather_partition_method = PartitioningMethod::kIndexParallel;
+PartitioningMethod scatter_partition_method =
+    PartitioningMethod::kIndexParallel;
 
 // Generates per-group partitioned hlo based on given grouped sharding.
 PartitionedHlo PerGroupPartitionedHlo(
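
Note on the two file-scope variables added above: they cache the method selected in the partitioner options so that the free-function cost models later in this file, which have no handle on the visitor, can read the choice. A minimal self-contained sketch of the pattern; the enumerator names come from this diff, everything else is illustrative stand-in code, not the actual XLA definitions:

    #include <cstdio>

    // Stand-in for XLA's PartitioningMethod option (enumerators per the diff).
    enum class PartitioningMethod {
      kIndexParallel,
      kOperandPassthrough,
      kTrivialSlicedOperand,
      kIndexPassthrough,
    };

    // File-scope default, mirroring gather_partition_method above: free
    // functions in this translation unit read the user's choice without
    // threading options through every call site.
    PartitioningMethod gather_partition_method =
        PartitioningMethod::kIndexParallel;

    // Cost-model stand-in: treats the cached choice as the zero-cost method.
    bool IsZeroCostChoice(PartitioningMethod candidate) {
      return candidate == gather_partition_method;
    }

    int main() {
      // The handler refreshes the cached value before partitioning (see the
      // HandleGather hunk below), so the cost model sees the latest option.
      gather_partition_method = PartitioningMethod::kTrivialSlicedOperand;
      std::printf("%d\n",
                  IsZeroCostChoice(PartitioningMethod::kTrivialSlicedOperand));
      return 0;
    }
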
@@ -723,6 +724,22 @@ GatherPartitionMethods() {
           "PartitionGatherIndexPassthroughDimensions"}};
 }
 
+// Helper function to get the gather partitioning method.
+decltype(PartitionGather)* GetGatherPartitionMethod(PartitioningMethod method) {
+  switch (method) {
+    case PartitioningMethod::kIndexParallel:
+      return PartitionGatherIndexParallelDimensions;
+    case PartitioningMethod::kOperandPassthrough:
+      return PartitionGatherOperandPassthroughDimensions;
+    case PartitioningMethod::kTrivialSlicedOperand:
+      return PartitionGatherTrivialSlicedOperandDimensions;
+    case PartitioningMethod::kIndexPassthrough:
+      return PartitionGatherIndexPassthroughDimensions;
+    default:
+      return PartitionGatherIndexParallelDimensions;
+  }
+}
+
 // Estimates the memory and communication cost for each partitioning method for
 // gather.
 std::pair<int64_t, int64_t> GatherPartitionMethodCostModel(
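
The helper above relies on `decltype(PartitionGather)*`, which names "pointer to a function with PartitionGather's signature" without spelling the long signature out again. A self-contained illustration of the idiom; the function and enum names below are stand-ins, not the XLA symbols:

    #include <cstdio>

    int PartitionFoo(int x) { return x + 1; }
    int PartitionBar(int x) { return x * 2; }

    enum class Method { kFoo, kBar };

    // decltype(PartitionFoo) is the function type int(int); adding * yields
    // int (*)(int), exactly as decltype(PartitionGather)* does in the diff.
    decltype(PartitionFoo)* GetMethod(Method m) {
      switch (m) {
        case Method::kFoo:
          return PartitionFoo;
        case Method::kBar:
          return PartitionBar;
      }
      return PartitionFoo;  // fallback, mirroring the diff's default case
    }

    int main() {
      std::printf("%d\n", GetMethod(Method::kBar)(21));  // prints 42
      return 0;
    }
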
@@ -731,9 +748,12 @@ std::pair<int64_t, int64_t> GatherPartitionMethodCostModel(
     const PartitionedHlo& indices, const Shape& output_shape,
     const HloSharding& output_sharding, absl::Span<const int64_t> batch_dims,
     absl::Span<const int64_t> slice_sizes, SpmdPartitioningVisitor* visitor) {
-  if (partition_method == PartitionGatherIndexParallelDimensions) {
-    // Always prioritize index parallel partitioning, and assume it has zero
+  decltype(PartitionGather)* zero_cost_method =
+      GetGatherPartitionMethod(gather_partition_method);
+  if (partition_method == zero_cost_method) {
+    // Always prioritize the user's chosen partitioning, and assume it has zero
     // cost.
+    // This defaults to IndexParallel.
     return {0, 0};
   }
   return EvaluatePartitionCost(gather, partition_method, gather, operand,
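
Why returning {0, 0} pins the choice: the caller ranks methods by their (memory, communication) estimates, and a zero pair is minimal under any reasonable ordering, so the configured method is always tried first. The actual selection loop is outside this diff; the sketch below assumes a lexicographic minimum purely for illustration:

    #include <cstdint>
    #include <utility>
    #include <vector>

    using Cost = std::pair<int64_t, int64_t>;  // (memory, communication)

    // Hypothetical ranking: pick the index of the cheapest method. A method
    // the cost model scored as {0, 0} can never lose this comparison.
    int PickMethod(const std::vector<Cost>& costs) {
      int best = 0;
      for (int i = 1; i < static_cast<int>(costs.size()); ++i) {
        // std::pair compares lexicographically: memory first, then comms.
        if (costs[i] < costs[best]) best = i;
      }
      return best;
    }

    int main() {
      std::vector<Cost> costs = {{0, 0}, {1024, 64}, {512, 128}};
      return PickMethod(costs);  // returns 0: the zero-cost method wins
    }
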
@@ -838,6 +858,7 @@ absl::Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
       batch_dims.push_back(i);
     }
   }
+  gather_partition_method = options().gather_partition_method;
   TF_ASSIGN_OR_RETURN(
       HloInstruction * pgather,
       PartitionGather(gather, operand, indices, gather->shape(),
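
With the hunk above, each gather picks up the method from the partitioner options just before partitioning. A hypothetical usage sketch, not taken from this diff: the two field names are confirmed by the options() calls in the diff, but the SpmdPartitionerOptions struct, header path, namespaces, and constructor shown here are assumptions about the surrounding XLA code:

    // Hypothetical configuration sketch (assumed API, see note above).
    #include "xla/service/spmd/spmd_partitioner.h"  // assumed header path

    xla::spmd::SpmdPartitioner MakePartitioner() {
      xla::spmd::SpmdPartitionerOptions options;
      // Field names per this diff; prefer sliced-operand gathers and
      // index-passthrough scatters over the kIndexParallel default.
      options.gather_partition_method =
          xla::PartitioningMethod::kTrivialSlicedOperand;
      options.scatter_partition_method =
          xla::PartitioningMethod::kIndexPassthrough;
      // Constructor signature assumed: 8 partitions, 1 replica.
      return xla::spmd::SpmdPartitioner(/*num_partitions=*/8,
                                        /*num_replicas=*/1, options);
    }
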
@@ -1292,82 +1313,80 @@ absl::StatusOr<HloInstruction*> PartitionScatterIndexPassthroughDimensions(
     // results.
     return nullptr;
   }
-    HloInstruction* identity;
-    switch (*reduction_opcode) {
-      case HloOpcode::kAdd:
-      case HloOpcode::kOr:
+  HloInstruction* identity;
+  switch (*reduction_opcode) {
+    case HloOpcode::kAdd:
+    case HloOpcode::kOr:
       identity = CreateZero(per_group_operand.hlo()->shape(), b);
       break;
-      case HloOpcode::kMultiply:
-      case HloOpcode::kAnd:
+    case HloOpcode::kMultiply:
+    case HloOpcode::kAnd:
       identity = CreateOne(per_group_operand.hlo()->shape(), b);
       break;
-      case HloOpcode::kMinimum:
+    case HloOpcode::kMinimum:
       identity = CreateConstant(
           per_group_operand.hlo()->shape(),
           LiteralUtil::MaxValue(scatter->shape().element_type()), b);
       break;
-      case HloOpcode::kMaximum:
+    case HloOpcode::kMaximum:
       identity = CreateConstant(
           per_group_operand.hlo()->shape(),
          LiteralUtil::MinValue(scatter->shape().element_type()), b);
       break;
-      default:
-        return nullptr;
-    }
-    // Update partition_id for partial replicate.
-    auto partition_id = indices.state().partition_id;
-    if (indices.sharding().ReplicateOnLastTileDim()) {
-      auto sharding_grouped = hlo_sharding_util::GroupShardingOnDims(
-          indices.sharding(),
-          {indices.sharding().tile_assignment().num_dimensions() - 1});
-      auto per_group_partitioner_state = CreatePerGroupPartitioningState(
-          indices.state(), sharding_grouped.device_groups, b);
-      partition_id = per_group_partitioner_state.partition_id;
-    }
-    // To avoid accumulating the initial operand multiple times during
-    // all-reduce, we use identity operands for all non-zero partitions.
-    auto not_partition_zero = b->AddInstruction(HloInstruction::CreateConvert(
-        ShapeUtil::MakeScalarShape(PRED), partition_id));
-    not_partition_zero = b->AddInstruction(HloInstruction::CreateBroadcast(
-        ShapeUtil::ChangeElementType(identity->shape(), PRED),
-        not_partition_zero, {}));
-    auto select_operand =
-        b->AddInstruction(HloInstruction::HloInstruction::CreateTernary(
-            identity->shape(), HloOpcode::kSelect, not_partition_zero, identity,
-            per_group_operand.hlo()));
-    PartitionedHlo new_operand =
-        per_group_operand.CloneWithNewHlo(select_operand);
-    std::vector<PartitionedHlo> per_group_new_operands = {new_operand};
-    std::vector<PartitionedHlo> per_group_updates = {
-        PerGroupPartitionedHlo(updates[0], update_grouped, b, clean_ups)};
-    PartitionedHlo per_group_indices =
-        PerGroupPartitionedHlo(indices, indices_grouped, b, clean_ups);
-    auto pshape = MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape);
-    TF_ASSIGN_OR_RETURN(
-        HloInstruction * pscatter,
-        PartitionScatter(
-            scatter, per_group_new_operands, per_group_indices,
-            per_group_updates, pshape,
-            HloSharding::Single(scatter->shape(), output_grouped.sharding),
-            slice_sizes, visitor, allow_recursive));
-    // All-reduce along all dims in operand sharding -- this is OK because the
-    // operand is not sharded on index_vector_dim.
-    std::vector<int64_t> all_dims(indices.rank());
-    absl::c_iota(all_dims, 0);
-    auto all_reduce =
-        operands[0].state().partitioner->AllReduceAlongShardingDims(
-            b, pscatter, original_indices_sharding,
-            indices.state().next_channel_id, all_dims,
-            operands[0].state().collective_ops_creator, scatter->to_apply());
-    all_reduce->set_sharding(
-        hlo_sharding_util::UngroupSharding(output_grouped));
-    if (allow_recursive) {
-      VLOG(5) << "[Scatter partitioning]: Partitioned as index passthrough";
-    }
-    return PartitionedHlo(all_reduce, output_shape, operands[0].state())
-        .Reshard(output_sharding)
-        .hlo();
+    default:
+      return nullptr;
+  }
+  // Update partition_id for partial replicate.
+  auto partition_id = indices.state().partition_id;
+  if (indices.sharding().ReplicateOnLastTileDim()) {
+    auto sharding_grouped = hlo_sharding_util::GroupShardingOnDims(
+        indices.sharding(),
+        {indices.sharding().tile_assignment().num_dimensions() - 1});
+    auto per_group_partitioner_state = CreatePerGroupPartitioningState(
+        indices.state(), sharding_grouped.device_groups, b);
+    partition_id = per_group_partitioner_state.partition_id;
+  }
+  // To avoid accumulating the initial operand multiple times during
+  // all-reduce, we use identity operands for all non-zero partitions.
+  auto not_partition_zero = b->AddInstruction(HloInstruction::CreateConvert(
+      ShapeUtil::MakeScalarShape(PRED), partition_id));
+  not_partition_zero = b->AddInstruction(HloInstruction::CreateBroadcast(
+      ShapeUtil::ChangeElementType(identity->shape(), PRED), not_partition_zero,
+      {}));
+  auto select_operand =
+      b->AddInstruction(HloInstruction::HloInstruction::CreateTernary(
+          identity->shape(), HloOpcode::kSelect, not_partition_zero, identity,
+          per_group_operand.hlo()));
+  PartitionedHlo new_operand =
+      per_group_operand.CloneWithNewHlo(select_operand);
+  std::vector<PartitionedHlo> per_group_new_operands = {new_operand};
+  std::vector<PartitionedHlo> per_group_updates = {
+      PerGroupPartitionedHlo(updates[0], update_grouped, b, clean_ups)};
+  PartitionedHlo per_group_indices =
+      PerGroupPartitionedHlo(indices, indices_grouped, b, clean_ups);
+  auto pshape = MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape);
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * pscatter,
+      PartitionScatter(
+          scatter, per_group_new_operands, per_group_indices, per_group_updates,
+          pshape,
+          HloSharding::Single(scatter->shape(), output_grouped.sharding),
+          slice_sizes, visitor, allow_recursive));
+  // All-reduce along all dims in operand sharding -- this is OK because the
+  // operand is not sharded on index_vector_dim.
+  std::vector<int64_t> all_dims(indices.rank());
+  absl::c_iota(all_dims, 0);
+  auto all_reduce = operands[0].state().partitioner->AllReduceAlongShardingDims(
+      b, pscatter, original_indices_sharding, indices.state().next_channel_id,
+      all_dims, operands[0].state().collective_ops_creator,
+      scatter->to_apply());
+  all_reduce->set_sharding(hlo_sharding_util::UngroupSharding(output_grouped));
+  if (allow_recursive) {
+    VLOG(5) << "[Scatter partitioning]: Partitioned as index passthrough";
+  }
+  return PartitionedHlo(all_reduce, output_shape, operands[0].state())
+      .Reshard(output_sharding)
+      .hlo();
 }
 
 // Partition a Scatter when it's sliced in a dimension in the operand that is
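
Stepping back to the large hunk above (which appears to be a formatting-only re-indentation): the select-identity trick it carries is worth a worked example. Each partition scatters its own slice of the updates, but only partition 0 keeps the real operand; every other partition substitutes the reduction's identity, so the all-reduce that follows folds the initial operand in exactly once. A plain-C++ numeric sketch of the kAdd case (illustration only, not XLA code):

    #include <cstdio>
    #include <vector>

    int main() {
      const int num_partitions = 4;
      const double operand = 10.0;  // initial value at one output index
      const std::vector<double> updates = {1, 2, 3, 4};  // one per partition

      double all_reduce_sum = 0.0;
      for (int p = 0; p < num_partitions; ++p) {
        const double identity = 0.0;               // CreateZero for kAdd
        const bool not_partition_zero = (p != 0);  // Convert(partition_id)
        const double base = not_partition_zero ? identity : operand;  // kSelect
        all_reduce_sum += base + updates[p];  // local scatter, then all-reduce
      }
      // Prints 20: 10 + (1+2+3+4), with the operand counted exactly once.
      std::printf("%g\n", all_reduce_sum);
      return 0;
    }
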
@@ -1487,14 +1506,31 @@ absl::StatusOr<HloInstruction*> PartitionScatterTrivialSlicedOperandDimensions(
 // Returns a full list of partitioning methods used for scatter.
 std::vector<std::pair<decltype(PartitionScatter)*, absl::string_view>>
 ScatterPartitionMethods() {
-  return {{PartitionScatterIndexParallelDimensions,
-           "PartitionScatterIndexParallelDimensions"},
-          {PartitionScatterOperandPassthroughDimensions,
-           "PartitionScatterOperandPassthroughDimensions"},
-          {PartitionScatterTrivialSlicedOperandDimensions,
-           "PartitionScatterTrivialSlicedOperandDimensions"},
-          {PartitionScatterIndexPassthroughDimensions,
-           "PartitionScatterIndexPassthroughDimensions"}};
+  return {{PartitionScatterIndexParallelDimensions,
+           "PartitionScatterIndexParallelDimensions"},
+          {PartitionScatterOperandPassthroughDimensions,
+           "PartitionScatterOperandPassthroughDimensions"},
+          {PartitionScatterTrivialSlicedOperandDimensions,
+           "PartitionScatterTrivialSlicedOperandDimensions"},
+          {PartitionScatterIndexPassthroughDimensions,
+           "PartitionScatterIndexPassthroughDimensions"}};
+}
+
+// Helper function to get the actual scatter partitioning method.
+decltype(PartitionScatter)* GetScatterPartitionMethod(
+    PartitioningMethod method) {
+  switch (method) {
+    case PartitioningMethod::kIndexParallel:
+      return PartitionScatterIndexParallelDimensions;
+    case PartitioningMethod::kOperandPassthrough:
+      return PartitionScatterOperandPassthroughDimensions;
+    case PartitioningMethod::kTrivialSlicedOperand:
+      return PartitionScatterTrivialSlicedOperandDimensions;
+    case PartitioningMethod::kIndexPassthrough:
+      return PartitionScatterIndexPassthroughDimensions;
+    default:
+      return PartitionScatterIndexParallelDimensions;
+  }
 }
 
 // Estimates the memory and communication cost for each partitioning method for
@@ -1506,7 +1542,10 @@ std::pair<int64_t, int64_t> ScatterPartitionMethodCostModel(
     const std::vector<PartitionedHlo>& updates, const Shape& output_shape,
     const HloSharding& output_sharding, absl::Span<const int64_t> slice_sizes,
     SpmdPartitioningVisitor* visitor) {
-  if (partition_method == PartitionScatterIndexParallelDimensions) {
+  decltype(PartitionScatter)* zero_cost_method =
+      GetScatterPartitionMethod(scatter_partition_method);
+
+  if (partition_method == zero_cost_method) {
     // Always prioritize index parallel partitioning, and assume it has zero
     // cost.
     return {0, 0};
@@ -1679,6 +1718,7 @@ absl::Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
       break;
     }
   }
+  scatter_partition_method = options().scatter_partition_method;
  std::vector<int64_t> slice_sizes = hlo_sharding_util::GetScatterSliceSize(
      operands[0].base_shape(), updates[0].base_shape(), dnums);
 