
Commit 5c3785b

PR tensorflow#14862: Add SPMD config option to specify zero cost method for gather/scatter.

Imported from GitHub PR openxla/xla#14862
Issue tensorflow#13304

In SPMD handling of gather/scatter, the partitioning strategy is hardcoded to the IndexParallel strategy, which is not optimal for all topologies. This PR makes the strategy an SPMD config option, defaulting to IndexParallel to preserve existing behavior. Clang-format also fixed some formatting. Tests were added, and all tests pass.

Copybara import of the project:

-- 7f83c21573f24cd4e314b13ce2e349dd6194b451 by ptoulme-aws <[email protected]>:

Add SPMD config option to specify zero cost method for gather/scatter.

Merging this change closes tensorflow#14862

PiperOrigin-RevId: 652736743
1 parent 15c227b commit 5c3785b
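
For readers who want to try the new knob, here is a minimal usage sketch. The field and enum names come from the diff below; the SpmdPartitioner constructor call is an assumption about the surrounding XLA API, not part of this PR.

#include "xla/service/spmd/spmd_partitioner.h"

// Minimal sketch: prefer trivially-sliced-operand partitioning for gather,
// keep the default (IndexParallel) for scatter. Field/enum names are from
// this PR; the constructor call below is assumed usage, not PR code.
xla::spmd::SpmdPartitionerOptions options;
options.gather_partition_method =
    xla::spmd::PartitioningMethod::kTrivialSlicedOperand;
options.scatter_partition_method =
    xla::spmd::PartitioningMethod::kIndexParallel;  // default
xla::spmd::SpmdPartitioner partitioner(/*num_partitions=*/8,
                                       /*num_replicas=*/1, options);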

File tree

3 files changed: +284, -79 lines changed


third_party/xla/xla/service/spmd/gather_scatter_handler.cc

Lines changed: 117 additions & 77 deletions
@@ -46,10 +46,11 @@ limitations under the License.
 
 namespace xla {
 namespace spmd {
-
 namespace {
-
 using hlo_sharding_util::GroupedSharding;
+PartitioningMethod gather_partition_method = PartitioningMethod::kIndexParallel;
+PartitioningMethod scatter_partition_method =
+    PartitioningMethod::kIndexParallel;
 
 // Generates per-group partitioned hlo based on given grouped sharding.
 PartitionedHlo PerGroupPartitionedHlo(
@@ -723,6 +724,22 @@ GatherPartitionMethods()
           "PartitionGatherIndexPassthroughDimensions"}};
 }
 
+// Helper function to get the gather partitioning method.
+decltype(PartitionGather)* GetGatherPartitionMethod(PartitioningMethod method) {
+  switch (method) {
+    case PartitioningMethod::kIndexParallel:
+      return PartitionGatherIndexParallelDimensions;
+    case PartitioningMethod::kOperandPassthrough:
+      return PartitionGatherOperandPassthroughDimensions;
+    case PartitioningMethod::kTrivialSlicedOperand:
+      return PartitionGatherTrivialSlicedOperandDimensions;
+    case PartitioningMethod::kIndexPassthrough:
+      return PartitionGatherIndexPassthroughDimensions;
+    default:
+      return PartitionGatherIndexParallelDimensions;
+  }
+}
+
 // Estimates the memory and communication cost for each partitioning methods for
 // gather.
 std::pair<int64_t, int64_t> GatherPartitionMethodCostModel(
@@ -731,9 +748,12 @@ std::pair<int64_t, int64_t> GatherPartitionMethodCostModel(
     const PartitionedHlo& indices, const Shape& output_shape,
     const HloSharding& output_sharding, absl::Span<const int64_t> batch_dims,
     absl::Span<const int64_t> slice_sizes, SpmdPartitioningVisitor* visitor) {
-  if (partition_method == PartitionGatherIndexParallelDimensions) {
-    // Always prioritize index parallel partitioning, and assume it has zero
+  decltype(PartitionGather)* zero_cost_method =
+      GetGatherPartitionMethod(gather_partition_method);
+  if (partition_method == zero_cost_method) {
+    // Always prioritize the user's chosen partitioning, and assume it has zero
     // cost.
+    // This defaults to IndexParallel.
     return {0, 0};
   }
   return EvaluatePartitionCost(gather, partition_method, gather, operand,
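
The {0, 0} convention above is what makes the configured method win: the cost model is queried for every candidate method and the cheapest one is kept, so a method modeled at zero cost is always selected when it applies. A self-contained sketch of that selection rule (not XLA code; names and costs below are invented for illustration):

#include <cstdint>
#include <limits>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Hypothetical (method name, modeled cost) pairs; the chosen method
  // reports 0, mirroring the {0, 0} convention in the cost model above.
  std::vector<std::pair<std::string, int64_t>> methods = {
      {"IndexParallel", 0},          // user's chosen zero-cost method
      {"OperandPassthrough", 1024},  // example non-zero modeled cost
      {"TrivialSlicedOperand", 4096},
  };
  std::string best;
  int64_t best_cost = std::numeric_limits<int64_t>::max();
  for (const auto& [name, cost] : methods) {
    if (cost < best_cost) {
      best_cost = cost;
      best = name;  // "IndexParallel" wins with cost 0
    }
  }
  return best == "IndexParallel" ? 0 : 1;
}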
@@ -838,6 +858,7 @@ absl::Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
       batch_dims.push_back(i);
     }
   }
+  gather_partition_method = options().gather_partition_method;
   TF_ASSIGN_OR_RETURN(
       HloInstruction * pgather,
       PartitionGather(gather, operand, indices, gather->shape(),
@@ -1292,82 +1313,80 @@ absl::StatusOr<HloInstruction*> PartitionScatterIndexPassthroughDimensions(
     // results.
     return nullptr;
   }
-    HloInstruction* identity;
-    switch (*reduction_opcode) {
-      case HloOpcode::kAdd:
-      case HloOpcode::kOr:
+  HloInstruction* identity;
+  switch (*reduction_opcode) {
+    case HloOpcode::kAdd:
+    case HloOpcode::kOr:
       identity = CreateZero(per_group_operand.hlo()->shape(), b);
       break;
-      case HloOpcode::kMultiply:
-      case HloOpcode::kAnd:
+    case HloOpcode::kMultiply:
+    case HloOpcode::kAnd:
       identity = CreateOne(per_group_operand.hlo()->shape(), b);
       break;
-      case HloOpcode::kMinimum:
+    case HloOpcode::kMinimum:
       identity = CreateConstant(
           per_group_operand.hlo()->shape(),
          LiteralUtil::MaxValue(scatter->shape().element_type()), b);
       break;
-      case HloOpcode::kMaximum:
+    case HloOpcode::kMaximum:
       identity = CreateConstant(
           per_group_operand.hlo()->shape(),
           LiteralUtil::MinValue(scatter->shape().element_type()), b);
       break;
-      default:
-        return nullptr;
-    }
-    // Update partition_id for partial replicate.
-    auto partition_id = indices.state().partition_id;
-    if (indices.sharding().ReplicateOnLastTileDim()) {
-      auto sharding_grouped = hlo_sharding_util::GroupShardingOnDims(
-          indices.sharding(),
-          {indices.sharding().tile_assignment().num_dimensions() - 1});
-      auto per_group_partitioner_state = CreatePerGroupPartitioningState(
-          indices.state(), sharding_grouped.device_groups, b);
-      partition_id = per_group_partitioner_state.partition_id;
-    }
-    // To avoid accumulating the initial operand multiple times during
-    // all-reduce, we use identity operands for all non-zero partitions.
-    auto not_partition_zero = b->AddInstruction(HloInstruction::CreateConvert(
-        ShapeUtil::MakeScalarShape(PRED), partition_id));
-    not_partition_zero = b->AddInstruction(HloInstruction::CreateBroadcast(
-        ShapeUtil::ChangeElementType(identity->shape(), PRED),
-        not_partition_zero, {}));
-    auto select_operand =
-        b->AddInstruction(HloInstruction::HloInstruction::CreateTernary(
-            identity->shape(), HloOpcode::kSelect, not_partition_zero, identity,
-            per_group_operand.hlo()));
-    PartitionedHlo new_operand =
-        per_group_operand.CloneWithNewHlo(select_operand);
-    std::vector<PartitionedHlo> per_group_new_operands = {new_operand};
-    std::vector<PartitionedHlo> per_group_updates = {
-        PerGroupPartitionedHlo(updates[0], update_grouped, b, clean_ups)};
-    PartitionedHlo per_group_indices =
-        PerGroupPartitionedHlo(indices, indices_grouped, b, clean_ups);
-    auto pshape = MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape);
-    TF_ASSIGN_OR_RETURN(
-        HloInstruction * pscatter,
-        PartitionScatter(
-            scatter, per_group_new_operands, per_group_indices,
-            per_group_updates, pshape,
-            HloSharding::Single(scatter->shape(), output_grouped.sharding),
-            slice_sizes, visitor, allow_recursive));
-    // All-reduce along all dims in operand sharding -- this is OK because the
-    // operand is not sharded on index_vector_dim.
-    std::vector<int64_t> all_dims(indices.rank());
-    absl::c_iota(all_dims, 0);
-    auto all_reduce =
-        operands[0].state().partitioner->AllReduceAlongShardingDims(
-            b, pscatter, original_indices_sharding,
-            indices.state().next_channel_id, all_dims,
-            operands[0].state().collective_ops_creator, scatter->to_apply());
-    all_reduce->set_sharding(
-        hlo_sharding_util::UngroupSharding(output_grouped));
-    if (allow_recursive) {
-      VLOG(5) << "[Scatter partitioning]: Partitioned as index passthrough";
-    }
-    return PartitionedHlo(all_reduce, output_shape, operands[0].state())
-        .Reshard(output_sharding)
-        .hlo();
+    default:
+      return nullptr;
+  }
+  // Update partition_id for partial replicate.
+  auto partition_id = indices.state().partition_id;
+  if (indices.sharding().ReplicateOnLastTileDim()) {
+    auto sharding_grouped = hlo_sharding_util::GroupShardingOnDims(
+        indices.sharding(),
+        {indices.sharding().tile_assignment().num_dimensions() - 1});
+    auto per_group_partitioner_state = CreatePerGroupPartitioningState(
+        indices.state(), sharding_grouped.device_groups, b);
+    partition_id = per_group_partitioner_state.partition_id;
+  }
+  // To avoid accumulating the initial operand multiple times during
+  // all-reduce, we use identity operands for all non-zero partitions.
+  auto not_partition_zero = b->AddInstruction(HloInstruction::CreateConvert(
+      ShapeUtil::MakeScalarShape(PRED), partition_id));
+  not_partition_zero = b->AddInstruction(HloInstruction::CreateBroadcast(
+      ShapeUtil::ChangeElementType(identity->shape(), PRED), not_partition_zero,
+      {}));
+  auto select_operand =
+      b->AddInstruction(HloInstruction::HloInstruction::CreateTernary(
+          identity->shape(), HloOpcode::kSelect, not_partition_zero, identity,
+          per_group_operand.hlo()));
+  PartitionedHlo new_operand =
+      per_group_operand.CloneWithNewHlo(select_operand);
+  std::vector<PartitionedHlo> per_group_new_operands = {new_operand};
+  std::vector<PartitionedHlo> per_group_updates = {
+      PerGroupPartitionedHlo(updates[0], update_grouped, b, clean_ups)};
+  PartitionedHlo per_group_indices =
+      PerGroupPartitionedHlo(indices, indices_grouped, b, clean_ups);
+  auto pshape = MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape);
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * pscatter,
+      PartitionScatter(
+          scatter, per_group_new_operands, per_group_indices, per_group_updates,
+          pshape,
+          HloSharding::Single(scatter->shape(), output_grouped.sharding),
+          slice_sizes, visitor, allow_recursive));
+  // All-reduce along all dims in operand sharding -- this is OK because the
+  // operand is not sharded on index_vector_dim.
+  std::vector<int64_t> all_dims(indices.rank());
+  absl::c_iota(all_dims, 0);
+  auto all_reduce = operands[0].state().partitioner->AllReduceAlongShardingDims(
+      b, pscatter, original_indices_sharding, indices.state().next_channel_id,
+      all_dims, operands[0].state().collective_ops_creator,
+      scatter->to_apply());
+  all_reduce->set_sharding(hlo_sharding_util::UngroupSharding(output_grouped));
+  if (allow_recursive) {
+    VLOG(5) << "[Scatter partitioning]: Partitioned as index passthrough";
+  }
+  return PartitionedHlo(all_reduce, output_shape, operands[0].state())
+      .Reshard(output_sharding)
+      .hlo();
 }
 
 // Partition a Scatter when its sliced in a dimension in the operand that is
@@ -1487,14 +1506,31 @@ absl::StatusOr<HloInstruction*> PartitionScatterTrivialSlicedOperandDimensions(
 // Returns a full list of partitioning methods used for scatter.
 std::vector<std::pair<decltype(PartitionScatter)*, absl::string_view>>
 ScatterPartitionMethods() {
-  return {{PartitionScatterIndexParallelDimensions,
-           "PartitionScatterIndexParallelDimensions"},
-          {PartitionScatterOperandPassthroughDimensions,
-           "PartitionScatterOperandPassthroughDimensions"},
-          {PartitionScatterTrivialSlicedOperandDimensions,
-           "PartitionScatterTrivialSlicedOperandDimensions"},
-          {PartitionScatterIndexPassthroughDimensions,
-           "PartitionScatterIndexPassthroughDimensions"}};
+  return {{PartitionScatterIndexParallelDimensions,
+           "PartitionScatterIndexParallelDimensions"},
+          {PartitionScatterOperandPassthroughDimensions,
+           "PartitionScatterOperandPassthroughDimensions"},
+          {PartitionScatterTrivialSlicedOperandDimensions,
+           "PartitionScatterTrivialSlicedOperandDimensions"},
+          {PartitionScatterIndexPassthroughDimensions,
+           "PartitionScatterIndexPassthroughDimensions"}};
+}
+
+// Helper function to get the actual scatter partitioning method
+decltype(PartitionScatter)* GetScatterPartitionMethod(
+    PartitioningMethod method) {
+  switch (method) {
+    case PartitioningMethod::kIndexParallel:
+      return PartitionScatterIndexParallelDimensions;
+    case PartitioningMethod::kOperandPassthrough:
+      return PartitionScatterOperandPassthroughDimensions;
+    case PartitioningMethod::kTrivialSlicedOperand:
+      return PartitionScatterTrivialSlicedOperandDimensions;
+    case PartitioningMethod::kIndexPassthrough:
+      return PartitionScatterIndexPassthroughDimensions;
+    default:
+      return PartitionScatterIndexParallelDimensions;
+  }
 }
 
 // Estimates the memory and communication for each partitioning methods for
@@ -1506,7 +1542,10 @@ std::pair<int64_t, int64_t> ScatterPartitionMethodCostModel(
     const std::vector<PartitionedHlo>& updates, const Shape& output_shape,
     const HloSharding& output_sharding, absl::Span<const int64_t> slice_sizes,
     SpmdPartitioningVisitor* visitor) {
-  if (partition_method == PartitionScatterIndexParallelDimensions) {
+  decltype(PartitionScatter)* zero_cost_method =
+      GetScatterPartitionMethod(scatter_partition_method);
+
+  if (partition_method == zero_cost_method) {
     // Always prioritize index parallel partitioning, and assume it has zero
     // cost.
     return {0, 0};
@@ -1679,6 +1718,7 @@ absl::Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
       break;
     }
   }
+  scatter_partition_method = options().scatter_partition_method;
   std::vector<int64_t> slice_sizes = hlo_sharding_util::GetScatterSliceSize(
       operands[0].base_shape(), updates[0].base_shape(), dnums);
 
third_party/xla/xla/service/spmd/spmd_partitioner.h

Lines changed: 16 additions & 0 deletions
@@ -52,6 +52,14 @@ limitations under the License.
 namespace xla {
 namespace spmd {
 
+// Enum representing the partitioning methods for gather and scatter.
+enum class PartitioningMethod {
+  kIndexParallel,
+  kOperandPassthrough,
+  kTrivialSlicedOperand,
+  kIndexPassthrough,
+};
+
 struct SpmdPartitionerOptions {
   // Always exchange halo on LHS for all convolutions. If false, backprop filter
   // convolution exchanges halo on RHS.
@@ -100,6 +108,14 @@ struct SpmdPartitionerOptions {
   // Whether disable rewrite for dots that share the same
   // operand as an already rewritten windowed einsum loop.
   bool disable_ag_rewrite_for_multiple_consumers = false;
+
+  // Partitioning method to prioritize for gather operations.
+  PartitioningMethod gather_partition_method =
+      PartitioningMethod::kIndexParallel;
+
+  // Partitioning method to prioritize for scatter operations.
+  PartitioningMethod scatter_partition_method =
+      PartitioningMethod::kIndexParallel;
 };
 
 // Class to wrap the computation builder to capture information during SPMD
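
Both new fields default to kIndexParallel, so partitioning behavior is unchanged unless a caller opts in. Callers wiring these options up from a textual config might use a small mapping helper like the hypothetical sketch below (ParsePartitioningMethod and the string spellings are illustrative, not part of this PR):

#include <string_view>

// Hypothetical helper (not in this PR): map a config string to the enum.
xla::spmd::PartitioningMethod ParsePartitioningMethod(std::string_view name) {
  using xla::spmd::PartitioningMethod;
  if (name == "operand_passthrough") {
    return PartitioningMethod::kOperandPassthrough;
  }
  if (name == "trivial_sliced_operand") {
    return PartitioningMethod::kTrivialSlicedOperand;
  }
  if (name == "index_passthrough") {
    return PartitioningMethod::kIndexPassthrough;
  }
  return PartitioningMethod::kIndexParallel;  // default, matching the PR
}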
