Skip to content

Commit e0e9dca

Browse files
Preserve the frontend attributes associated with an HLO when partitioning it into a partitioned HLO through the spmd_partitioner pass.
As part of this change, `SpmdBuilder::AddInstruction` was broken down into multiple smaller functions. PiperOrigin-RevId: 684223463
1 parent ffda817 commit e0e9dca

File tree

2 files changed

+168
-112
lines changed

2 files changed

+168
-112
lines changed

third_party/xla/xla/service/spmd/spmd_partitioner.cc

Lines changed: 156 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -246,135 +246,179 @@ HloInstruction* SpmdBuilder::AddInstruction(
246246
HloInstruction* hlo =
247247
HloComputation::Builder::AddInstruction(std::move(instruction));
248248
if (visiting_hlo_) {
249-
hlo->set_metadata(visiting_hlo_->metadata());
249+
std::shared_ptr<const HloSharding> prev_sharding = hlo->sharding_ptr();
250+
visiting_hlo_->SetupDerivedInstruction(hlo);
251+
if (prev_sharding != nullptr) {
252+
hlo->set_sharding(*prev_sharding);
253+
} else {
254+
hlo->clear_sharding();
255+
}
250256
instructions_[visiting_hlo_].push_back(hlo);
251257
}
252-
if (hlo->opcode() == HloOpcode::kBroadcast) {
253-
for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
254-
if (!absl::c_linear_search(hlo->dimensions(), i)) {
255-
broadcast_dims_[hlo].insert(i);
258+
SetBroadcastDimsForAddedHlo(*hlo);
259+
return hlo;
260+
}
261+
262+
// Records broadcast-dimension information for `hlo` in broadcast_dims_,
// choosing the rule to apply from the instruction's opcode:
//  - kBroadcast: every result dimension index that is NOT listed in
//    hlo.dimensions() is inserted into broadcast_dims_[&hlo].
//  - Elementwise instructions with at least one operand and an array-shaped
//    result are forwarded to SetBroadcastDimsForElementwise.
//  - kTranspose, kReshape (only when the product of the result dimensions is
//    positive), kSlice/kDynamicSlice and kPad are forwarded to the matching
//    SetBroadcastDimsFor* helper.
// The conditions are separate `if` statements (not an else-if chain), so each
// one is evaluated for every instruction passed in.
void SpmdBuilder::SetBroadcastDimsForAddedHlo(const HloInstruction& hlo) {
263+
if (hlo.opcode() == HloOpcode::kBroadcast) {
264+
for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
265+
if (!absl::c_linear_search(hlo.dimensions(), i)) {
266+
broadcast_dims_[&hlo].insert(i);
256267
}
257268
}
258269
}
259-
if (hlo->IsElementwise() && hlo->operand_count() > 0 &&
270+
if (hlo.IsElementwise() && hlo.operand_count() > 0 &&
260271
// Copy can have a tuple result.
261-
hlo->shape().IsArray()) {
262-
absl::flat_hash_set<int64_t> broadcast_dims;
263-
for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
264-
broadcast_dims.insert(i);
272+
hlo.shape().IsArray()) {
273+
SetBroadcastDimsForElementwise(hlo);
274+
}
275+
if (hlo.opcode() == HloOpcode::kTranspose) {
276+
SetBroadcastDimsForTranspose(hlo);
277+
}
278+
if (hlo.opcode() == HloOpcode::kReshape &&
279+
Product(hlo.shape().dimensions()) > 0) {
280+
SetBroadcastDimsForReshape(hlo);
281+
}
282+
if (hlo.opcode() == HloOpcode::kSlice ||
283+
hlo.opcode() == HloOpcode::kDynamicSlice) {
284+
SetBroadcastDimsForSlice(hlo);
285+
}
286+
if (hlo.opcode() == HloOpcode::kPad) {
287+
SetBroadcastDimsForPad(hlo);
288+
}
289+
}
290+
291+
// Derives the broadcast dims of the reshape `hlo` from the set recorded for
// its operand(0).
//
// CHECK-fails unless `hlo` is a kReshape. Does nothing when operand(0) has no
// entry in broadcast_dims_.
//
// Algorithm: every result dimension starts out as a candidate. The operand
// and result dimension sizes are copied into two vectors in reverse order, so
// back() is always the lowest-numbered dimension not yet consumed. The loop
// pops one size from each stack per iteration:
//  - if the current operand dimension is not in the operand's recorded set,
//    the current result dimension is removed from the candidates;
//  - equal sizes: both dimensions are consumed;
//  - operand size divisible by result size ("split"): the quotient is pushed
//    back onto the operand stack, so the same operand dimension is matched
//    against the next result dimension;
//  - result size divisible by operand size ("merge"): the quotient is pushed
//    back onto the result stack, so the same result dimension is matched
//    against the next operand dimension;
//  - otherwise all result dimensions from the current one onwards are removed
//    from the candidates and the walk stops.
// An entry is written to broadcast_dims_[&hlo] only if at least one candidate
// survives AND both stacks were fully consumed.
//
// NOTE(review): the `%` and `/` operations assume non-zero dimension sizes.
// The caller visible in this change only invokes this helper when
// Product(hlo.shape().dimensions()) > 0; operand sizes being non-zero
// presumably follows from reshape preserving element count - confirm.
void SpmdBuilder::SetBroadcastDimsForReshape(const HloInstruction& hlo) {
292+
CHECK(hlo.opcode() == HloOpcode::kReshape);
293+
294+
auto it = broadcast_dims_.find(hlo.operand(0));
295+
if (it == broadcast_dims_.end()) {
296+
return;
297+
}
298+
std::vector<int64_t> iota_dims(hlo.shape().rank());
299+
absl::c_iota(iota_dims, 0);
300+
absl::flat_hash_set<int64_t> reshape_broadcast_dims(iota_dims.begin(),
301+
iota_dims.end());
302+
303+
absl::Span<const int64_t> operand_dims = hlo.operand(0)->shape().dimensions();
304+
absl::Span<const int64_t> hlo_dims = hlo.shape().dimensions();
305+
std::vector<int64_t> before_dim_size_stack(operand_dims.rbegin(),
306+
operand_dims.rend());
307+
std::vector<int64_t> after_dim_size_stack(hlo_dims.rbegin(), hlo_dims.rend());
308+
309+
auto erase_reshape_broadcast_dims = [&reshape_broadcast_dims](int64_t from,
310+
int64_t to) {
311+
for (int64_t i = from; i < to; ++i) {
312+
reshape_broadcast_dims.erase(i);
265313
}
266-
for (int64_t i = 0; i < hlo->operand_count(); ++i) {
267-
auto it = broadcast_dims_.find(hlo->operand(i));
268-
if (it == broadcast_dims_.end()) {
269-
broadcast_dims.clear();
270-
break;
271-
}
272-
for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
273-
if (!it->second.contains(i)) {
274-
broadcast_dims.erase(i);
275-
}
276-
}
314+
};
315+
316+
while (!before_dim_size_stack.empty() && !after_dim_size_stack.empty()) {
317+
int64_t before_size = before_dim_size_stack.back();
318+
int64_t after_size = after_dim_size_stack.back();
319+
int64_t current_before_dim =
320+
hlo.operand(0)->shape().rank() - before_dim_size_stack.size();
321+
int64_t current_after_dim =
322+
hlo.shape().rank() - after_dim_size_stack.size();
323+
before_dim_size_stack.pop_back();
324+
after_dim_size_stack.pop_back();
325+
if (!it->second.contains(current_before_dim)) {
326+
reshape_broadcast_dims.erase(current_after_dim);
327+
}
328+
if (before_size == after_size) {
329+
continue;
277330
}
278-
if (!broadcast_dims.empty()) {
279-
broadcast_dims_[hlo] = std::move(broadcast_dims);
331+
if (before_size % after_size == 0) {
332+
// Split dim.
333+
before_dim_size_stack.push_back(before_size / after_size);
334+
} else if (after_size % before_size == 0) {
335+
// Merge dim.
336+
after_dim_size_stack.push_back(after_size / before_size);
337+
} else {
338+
// Other cases, mark all remaining dims as non-broadcast.
339+
erase_reshape_broadcast_dims(current_after_dim, hlo.shape().rank());
340+
break;
280341
}
281342
}
282-
if (hlo->opcode() == HloOpcode::kTranspose) {
283-
auto it = broadcast_dims_.find(hlo->operand(0));
284-
if (it != broadcast_dims_.end()) {
285-
absl::flat_hash_set<int64_t> xpose_broadcast_dims;
286-
std::vector<int64_t> reverse_map(hlo->shape().rank());
287-
for (int64_t i = 0; i < reverse_map.size(); ++i) {
288-
reverse_map[hlo->dimensions(i)] = i;
289-
}
290-
for (int64_t dim : it->second) {
291-
xpose_broadcast_dims.insert(reverse_map[dim]);
292-
}
293-
broadcast_dims_[hlo] = std::move(xpose_broadcast_dims);
294-
}
343+
344+
bool has_broadcast_dims = !reshape_broadcast_dims.empty() &&
345+
before_dim_size_stack.empty() &&
346+
after_dim_size_stack.empty();
347+
if (has_broadcast_dims) {
348+
broadcast_dims_[&hlo] = std::move(reshape_broadcast_dims);
295349
}
296-
if (hlo->opcode() == HloOpcode::kReshape &&
297-
Product(hlo->shape().dimensions()) > 0) {
298-
auto it = broadcast_dims_.find(hlo->operand(0));
299-
if (it != broadcast_dims_.end()) {
300-
absl::flat_hash_set<int64_t> reshape_broadcast_dims;
301-
for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
302-
reshape_broadcast_dims.insert(i);
303-
}
304-
std::vector<int64_t> before_dim_size_stack;
305-
std::vector<int64_t> after_dim_size_stack;
306-
const int64_t operand0_rank = hlo->operand(0)->shape().rank();
307-
const int64_t hlo_shape_rank = hlo->shape().rank();
308-
before_dim_size_stack.reserve(operand0_rank);
309-
after_dim_size_stack.reserve(hlo_shape_rank);
310-
for (int64_t i = operand0_rank - 1; i >= 0; --i) {
311-
before_dim_size_stack.push_back(hlo->operand(0)->shape().dimensions(i));
312-
}
313-
for (int64_t i = hlo_shape_rank - 1; i >= 0; --i) {
314-
after_dim_size_stack.push_back(hlo->shape().dimensions(i));
315-
}
316-
while (!before_dim_size_stack.empty() && !after_dim_size_stack.empty()) {
317-
int64_t before_size = before_dim_size_stack.back();
318-
int64_t after_size = after_dim_size_stack.back();
319-
int64_t current_before_dim =
320-
hlo->operand(0)->shape().rank() - before_dim_size_stack.size();
321-
int64_t current_after_dim =
322-
hlo->shape().rank() - after_dim_size_stack.size();
323-
before_dim_size_stack.pop_back();
324-
after_dim_size_stack.pop_back();
325-
if (!it->second.contains(current_before_dim)) {
326-
reshape_broadcast_dims.erase(current_after_dim);
327-
}
328-
if (before_size == after_size) {
329-
continue;
330-
}
331-
if (before_size % after_size == 0) {
332-
// Split dim.
333-
before_dim_size_stack.push_back(before_size / after_size);
334-
} else if (after_size % before_size == 0) {
335-
// Merge dim.
336-
after_dim_size_stack.push_back(after_size / before_size);
337-
} else {
338-
// Other cases, mark all remaining dims as non-broadcast.
339-
for (int64_t i = current_after_dim; i < hlo->shape().rank(); ++i) {
340-
reshape_broadcast_dims.erase(i);
341-
}
342-
break;
343-
}
344-
}
345-
if (!before_dim_size_stack.empty() || !after_dim_size_stack.empty()) {
346-
reshape_broadcast_dims.clear();
347-
}
348-
if (!reshape_broadcast_dims.empty()) {
349-
broadcast_dims_[hlo] = std::move(reshape_broadcast_dims);
350-
}
351-
}
350+
}
351+
352+
// Derives the broadcast dims of the transpose `hlo` from the set recorded for
// its operand(0).
//
// CHECK-fails unless `hlo` is a kTranspose. Does nothing when operand(0) has
// no entry in broadcast_dims_.
//
// reverse_map is the inverse of hlo.dimensions(): for each index i,
// reverse_map[hlo.dimensions(i)] == i. Each dim recorded for the operand is
// looked up in reverse_map and the result is inserted into the new set.
// Unlike the pad/elementwise/reshape helpers there is no emptiness check: the
// new set is stored in broadcast_dims_[&hlo] whenever the operand has an entry.
void SpmdBuilder::SetBroadcastDimsForTranspose(const HloInstruction& hlo) {
353+
CHECK(hlo.opcode() == HloOpcode::kTranspose);
354+
auto it = broadcast_dims_.find(hlo.operand(0));
355+
if (it == broadcast_dims_.end()) {
356+
return;
357+
}
358+
absl::flat_hash_set<int64_t> xpose_broadcast_dims;
359+
std::vector<int64_t> reverse_map(hlo.shape().rank());
360+
for (int64_t i = 0; i < reverse_map.size(); ++i) {
361+
reverse_map[hlo.dimensions(i)] = i;
362+
}
363+
for (int64_t dim : it->second) {
364+
xpose_broadcast_dims.insert(reverse_map[dim]);
365+
}
366+
broadcast_dims_[&hlo] = std::move(xpose_broadcast_dims);
367+
}
368+
369+
// Derives the broadcast dims of the pad `hlo` from the set recorded for its
// operand(0).
//
// CHECK-fails unless `hlo` is a kPad. Does nothing when operand(0) has no
// entry in broadcast_dims_.
//
// A result dimension i is kept only when its padding config has zero low
// edge padding, zero high edge padding and zero interior padding, AND i is in
// the operand's recorded set. The resulting set is stored in
// broadcast_dims_[&hlo] only if it is non-empty.
void SpmdBuilder::SetBroadcastDimsForPad(const HloInstruction& hlo) {
370+
CHECK(hlo.opcode() == HloOpcode::kPad);
371+
auto it = broadcast_dims_.find(hlo.operand(0));
372+
if (it == broadcast_dims_.end()) {
373+
return;
352374
}
353-
if (hlo->opcode() == HloOpcode::kSlice ||
354-
hlo->opcode() == HloOpcode::kDynamicSlice) {
355-
auto it = broadcast_dims_.find(hlo->operand(0));
356-
if (it != broadcast_dims_.end()) {
357-
auto dims = it->second;
358-
broadcast_dims_[hlo] = std::move(dims);
375+
absl::flat_hash_set<int64_t> pad_broadcast_dims;
376+
for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
377+
const auto& dim = hlo.padding_config().dimensions(i);
378+
if (dim.edge_padding_low() == 0 && dim.edge_padding_high() == 0 &&
379+
dim.interior_padding() == 0 && it->second.contains(i)) {
380+
pad_broadcast_dims.insert(i);
359381
}
360382
}
361-
if (hlo->opcode() == HloOpcode::kPad) {
362-
auto it = broadcast_dims_.find(hlo->operand(0));
363-
if (it != broadcast_dims_.end()) {
364-
absl::flat_hash_set<int64_t> pad_broadcast_dims;
365-
for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
366-
const auto& dim = hlo->padding_config().dimensions(i);
367-
if (dim.edge_padding_low() == 0 && dim.edge_padding_high() == 0 &&
368-
dim.interior_padding() == 0 && it->second.contains(i)) {
369-
pad_broadcast_dims.insert(i);
370-
}
371-
}
372-
if (!pad_broadcast_dims.empty()) {
373-
broadcast_dims_[hlo] = std::move(pad_broadcast_dims);
383+
if (!pad_broadcast_dims.empty()) {
384+
broadcast_dims_[&hlo] = std::move(pad_broadcast_dims);
385+
}
386+
}
387+
388+
// Propagates the broadcast dims recorded for operand(0) unchanged to the
// slice `hlo`.
//
// CHECK-fails unless `hlo` is a kSlice or kDynamicSlice. When operand(0) has
// an entry in broadcast_dims_, a copy of that set is stored in
// broadcast_dims_[&hlo]; otherwise nothing is recorded.
//
// NOTE(review): the set is copied into a local before broadcast_dims_[&hlo]
// is evaluated, presumably because inserting a new key may invalidate `it` -
// confirm against the container's iterator-invalidation rules.
void SpmdBuilder::SetBroadcastDimsForSlice(const HloInstruction& hlo) {
389+
CHECK(hlo.opcode() == HloOpcode::kSlice ||
390+
hlo.opcode() == HloOpcode::kDynamicSlice);
391+
auto it = broadcast_dims_.find(hlo.operand(0));
392+
if (it != broadcast_dims_.end()) {
393+
auto dims = it->second;
394+
broadcast_dims_[&hlo] = std::move(dims);
395+
}
396+
}
397+
398+
// Derives the broadcast dims of the elementwise `hlo` as the intersection of
// the sets recorded for all of its operands.
//
// CHECK-fails unless hlo.IsElementwise(). Returns without recording anything
// when `hlo` has no operands or a tuple-shaped result.
//
// Every result dimension starts out as a candidate. For each operand, any
// candidate not contained in that operand's recorded set is removed; if an
// operand has no entry in broadcast_dims_ at all, the candidates are cleared
// and the scan stops. The resulting set is stored in broadcast_dims_[&hlo]
// only if it is non-empty.
//
// Note: the inner loop declares its own `i`, shadowing the operand index of
// the outer loop. `it` is obtained before the inner loop starts, so the
// operand lookup is not affected by the shadowing.
void SpmdBuilder::SetBroadcastDimsForElementwise(const HloInstruction& hlo) {
399+
CHECK(hlo.IsElementwise());
400+
if (hlo.operand_count() == 0 || hlo.shape().IsTuple()) {
401+
return;
402+
}
403+
absl::flat_hash_set<int64_t> broadcast_dims;
404+
for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
405+
broadcast_dims.insert(i);
406+
}
407+
for (int64_t i = 0; i < hlo.operand_count(); ++i) {
408+
auto it = broadcast_dims_.find(hlo.operand(i));
409+
if (it == broadcast_dims_.end()) {
410+
broadcast_dims.clear();
411+
break;
412+
}
413+
for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
414+
if (!it->second.contains(i)) {
415+
broadcast_dims.erase(i);
374416
}
375417
}
376418
}
377-
return hlo;
419+
if (!broadcast_dims.empty()) {
420+
broadcast_dims_[&hlo] = std::move(broadcast_dims);
421+
}
378422
}
379423

380424
PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target,

third_party/xla/xla/service/spmd/spmd_partitioner.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,18 @@ class SpmdBuilder : public HloComputation::Builder {
154154
}
155155

156156
private:
157+
// Sets the broadcast dims for the newly added/created hlo.
158+
void SetBroadcastDimsForAddedHlo(const HloInstruction& hlo);
159+
160+
// Sets the broadcast dims for a reshape from the dims recorded for its
// operand(0), matching operand and result dimensions by size (split/merge).
void SetBroadcastDimsForReshape(const HloInstruction& hlo);
161+
162+
// Sets the broadcast dims for a transpose by remapping the dims recorded for
// its operand(0) through the inverse of the transpose's dimensions().
void SetBroadcastDimsForTranspose(const HloInstruction& hlo);
163+
164+
// Sets the broadcast dims for a pad: the dims recorded for its operand(0)
// that have no edge or interior padding.
void SetBroadcastDimsForPad(const HloInstruction& hlo);
165+
166+
// Sets the broadcast dims for a slice/dynamic-slice to a copy of the dims
// recorded for its operand(0).
void SetBroadcastDimsForSlice(const HloInstruction& hlo);
167+
168+
// Sets the broadcast dims for an elementwise hlo to the dims recorded for
// every one of its operands (set intersection).
void SetBroadcastDimsForElementwise(const HloInstruction& hlo);
157169
// Currently visiting instruction.
158170
HloInstruction* visiting_hlo_;
159171

0 commit comments

Comments
 (0)