@@ -246,135 +246,179 @@ HloInstruction* SpmdBuilder::AddInstruction(
246
246
HloInstruction* hlo =
247
247
HloComputation::Builder::AddInstruction (std::move (instruction));
248
248
if (visiting_hlo_) {
249
- hlo->set_metadata (visiting_hlo_->metadata ());
249
+ std::shared_ptr<const HloSharding> prev_sharding = hlo->sharding_ptr ();
250
+ visiting_hlo_->SetupDerivedInstruction (hlo);
251
+ if (prev_sharding != nullptr ) {
252
+ hlo->set_sharding (*prev_sharding);
253
+ } else {
254
+ hlo->clear_sharding ();
255
+ }
250
256
instructions_[visiting_hlo_].push_back (hlo);
251
257
}
252
- if (hlo->opcode () == HloOpcode::kBroadcast ) {
253
- for (int64_t i = 0 ; i < hlo->shape ().rank (); ++i) {
254
- if (!absl::c_linear_search (hlo->dimensions (), i)) {
255
- broadcast_dims_[hlo].insert (i);
258
+ SetBroadcastDimsForAddedHlo (*hlo);
259
+ return hlo;
260
+ }
261
+
262
+ void SpmdBuilder::SetBroadcastDimsForAddedHlo (const HloInstruction& hlo) {
263
+ if (hlo.opcode () == HloOpcode::kBroadcast ) {
264
+ for (int64_t i = 0 ; i < hlo.shape ().rank (); ++i) {
265
+ if (!absl::c_linear_search (hlo.dimensions (), i)) {
266
+ broadcast_dims_[&hlo].insert (i);
256
267
}
257
268
}
258
269
}
259
- if (hlo-> IsElementwise () && hlo-> operand_count () > 0 &&
270
+ if (hlo. IsElementwise () && hlo. operand_count () > 0 &&
260
271
// Copy can have a tuple result.
261
- hlo->shape ().IsArray ()) {
262
- absl::flat_hash_set<int64_t > broadcast_dims;
263
- for (int64_t i = 0 ; i < hlo->shape ().rank (); ++i) {
264
- broadcast_dims.insert (i);
272
+ hlo.shape ().IsArray ()) {
273
+ SetBroadcastDimsForElementwise (hlo);
274
+ }
275
+ if (hlo.opcode () == HloOpcode::kTranspose ) {
276
+ SetBroadcastDimsForTranspose (hlo);
277
+ }
278
+ if (hlo.opcode () == HloOpcode::kReshape &&
279
+ Product (hlo.shape ().dimensions ()) > 0 ) {
280
+ SetBroadcastDimsForReshape (hlo);
281
+ }
282
+ if (hlo.opcode () == HloOpcode::kSlice ||
283
+ hlo.opcode () == HloOpcode::kDynamicSlice ) {
284
+ SetBroadcastDimsForSlice (hlo);
285
+ }
286
+ if (hlo.opcode () == HloOpcode::kPad ) {
287
+ SetBroadcastDimsForPad (hlo);
288
+ }
289
+ }
290
+
291
+ void SpmdBuilder::SetBroadcastDimsForReshape (const HloInstruction& hlo) {
292
+ CHECK (hlo.opcode () == HloOpcode::kReshape );
293
+
294
+ auto it = broadcast_dims_.find (hlo.operand (0 ));
295
+ if (it == broadcast_dims_.end ()) {
296
+ return ;
297
+ }
298
+ std::vector<int64_t > iota_dims (hlo.shape ().rank ());
299
+ absl::c_iota (iota_dims, 0 );
300
+ absl::flat_hash_set<int64_t > reshape_broadcast_dims (iota_dims.begin (),
301
+ iota_dims.end ());
302
+
303
+ absl::Span<const int64_t > operand_dims = hlo.operand (0 )->shape ().dimensions ();
304
+ absl::Span<const int64_t > hlo_dims = hlo.shape ().dimensions ();
305
+ std::vector<int64_t > before_dim_size_stack (operand_dims.rbegin (),
306
+ operand_dims.rend ());
307
+ std::vector<int64_t > after_dim_size_stack (hlo_dims.rbegin (), hlo_dims.rend ());
308
+
309
+ auto erase_reshape_broadcast_dims = [&reshape_broadcast_dims](int64_t from,
310
+ int64_t to) {
311
+ for (int64_t i = from; i < to; ++i) {
312
+ reshape_broadcast_dims.erase (i);
265
313
}
266
- for (int64_t i = 0 ; i < hlo->operand_count (); ++i) {
267
- auto it = broadcast_dims_.find (hlo->operand (i));
268
- if (it == broadcast_dims_.end ()) {
269
- broadcast_dims.clear ();
270
- break ;
271
- }
272
- for (int64_t i = 0 ; i < hlo->shape ().rank (); ++i) {
273
- if (!it->second .contains (i)) {
274
- broadcast_dims.erase (i);
275
- }
276
- }
314
+ };
315
+
316
+ while (!before_dim_size_stack.empty () && !after_dim_size_stack.empty ()) {
317
+ int64_t before_size = before_dim_size_stack.back ();
318
+ int64_t after_size = after_dim_size_stack.back ();
319
+ int64_t current_before_dim =
320
+ hlo.operand (0 )->shape ().rank () - before_dim_size_stack.size ();
321
+ int64_t current_after_dim =
322
+ hlo.shape ().rank () - after_dim_size_stack.size ();
323
+ before_dim_size_stack.pop_back ();
324
+ after_dim_size_stack.pop_back ();
325
+ if (!it->second .contains (current_before_dim)) {
326
+ reshape_broadcast_dims.erase (current_after_dim);
327
+ }
328
+ if (before_size == after_size) {
329
+ continue ;
277
330
}
278
- if (!broadcast_dims.empty ()) {
279
- broadcast_dims_[hlo] = std::move (broadcast_dims);
331
+ if (before_size % after_size == 0 ) {
332
+ // Split dim.
333
+ before_dim_size_stack.push_back (before_size / after_size);
334
+ } else if (after_size % before_size == 0 ) {
335
+ // Merge dim.
336
+ after_dim_size_stack.push_back (after_size / before_size);
337
+ } else {
338
+ // Other cases, mark all remaining dims as non-broadcast.
339
+ erase_reshape_broadcast_dims (current_after_dim, hlo.shape ().rank ());
340
+ break ;
280
341
}
281
342
}
282
- if (hlo->opcode () == HloOpcode::kTranspose ) {
283
- auto it = broadcast_dims_.find (hlo->operand (0 ));
284
- if (it != broadcast_dims_.end ()) {
285
- absl::flat_hash_set<int64_t > xpose_broadcast_dims;
286
- std::vector<int64_t > reverse_map (hlo->shape ().rank ());
287
- for (int64_t i = 0 ; i < reverse_map.size (); ++i) {
288
- reverse_map[hlo->dimensions (i)] = i;
289
- }
290
- for (int64_t dim : it->second ) {
291
- xpose_broadcast_dims.insert (reverse_map[dim]);
292
- }
293
- broadcast_dims_[hlo] = std::move (xpose_broadcast_dims);
294
- }
343
+
344
+ bool has_broadcast_dims = !reshape_broadcast_dims.empty () &&
345
+ before_dim_size_stack.empty () &&
346
+ after_dim_size_stack.empty ();
347
+ if (has_broadcast_dims) {
348
+ broadcast_dims_[&hlo] = std::move (reshape_broadcast_dims);
295
349
}
296
- if (hlo->opcode () == HloOpcode::kReshape &&
297
- Product (hlo->shape ().dimensions ()) > 0 ) {
298
- auto it = broadcast_dims_.find (hlo->operand (0 ));
299
- if (it != broadcast_dims_.end ()) {
300
- absl::flat_hash_set<int64_t > reshape_broadcast_dims;
301
- for (int64_t i = 0 ; i < hlo->shape ().rank (); ++i) {
302
- reshape_broadcast_dims.insert (i);
303
- }
304
- std::vector<int64_t > before_dim_size_stack;
305
- std::vector<int64_t > after_dim_size_stack;
306
- const int64_t operand0_rank = hlo->operand (0 )->shape ().rank ();
307
- const int64_t hlo_shape_rank = hlo->shape ().rank ();
308
- before_dim_size_stack.reserve (operand0_rank);
309
- after_dim_size_stack.reserve (hlo_shape_rank);
310
- for (int64_t i = operand0_rank - 1 ; i >= 0 ; --i) {
311
- before_dim_size_stack.push_back (hlo->operand (0 )->shape ().dimensions (i));
312
- }
313
- for (int64_t i = hlo_shape_rank - 1 ; i >= 0 ; --i) {
314
- after_dim_size_stack.push_back (hlo->shape ().dimensions (i));
315
- }
316
- while (!before_dim_size_stack.empty () && !after_dim_size_stack.empty ()) {
317
- int64_t before_size = before_dim_size_stack.back ();
318
- int64_t after_size = after_dim_size_stack.back ();
319
- int64_t current_before_dim =
320
- hlo->operand (0 )->shape ().rank () - before_dim_size_stack.size ();
321
- int64_t current_after_dim =
322
- hlo->shape ().rank () - after_dim_size_stack.size ();
323
- before_dim_size_stack.pop_back ();
324
- after_dim_size_stack.pop_back ();
325
- if (!it->second .contains (current_before_dim)) {
326
- reshape_broadcast_dims.erase (current_after_dim);
327
- }
328
- if (before_size == after_size) {
329
- continue ;
330
- }
331
- if (before_size % after_size == 0 ) {
332
- // Split dim.
333
- before_dim_size_stack.push_back (before_size / after_size);
334
- } else if (after_size % before_size == 0 ) {
335
- // Merge dim.
336
- after_dim_size_stack.push_back (after_size / before_size);
337
- } else {
338
- // Other cases, mark all remaining dims as non-broadcast.
339
- for (int64_t i = current_after_dim; i < hlo->shape ().rank (); ++i) {
340
- reshape_broadcast_dims.erase (i);
341
- }
342
- break ;
343
- }
344
- }
345
- if (!before_dim_size_stack.empty () || !after_dim_size_stack.empty ()) {
346
- reshape_broadcast_dims.clear ();
347
- }
348
- if (!reshape_broadcast_dims.empty ()) {
349
- broadcast_dims_[hlo] = std::move (reshape_broadcast_dims);
350
- }
351
- }
350
+ }
351
+
352
+ void SpmdBuilder::SetBroadcastDimsForTranspose (const HloInstruction& hlo) {
353
+ CHECK (hlo.opcode () == HloOpcode::kTranspose );
354
+ auto it = broadcast_dims_.find (hlo.operand (0 ));
355
+ if (it == broadcast_dims_.end ()) {
356
+ return ;
357
+ }
358
+ absl::flat_hash_set<int64_t > xpose_broadcast_dims;
359
+ std::vector<int64_t > reverse_map (hlo.shape ().rank ());
360
+ for (int64_t i = 0 ; i < reverse_map.size (); ++i) {
361
+ reverse_map[hlo.dimensions (i)] = i;
362
+ }
363
+ for (int64_t dim : it->second ) {
364
+ xpose_broadcast_dims.insert (reverse_map[dim]);
365
+ }
366
+ broadcast_dims_[&hlo] = std::move (xpose_broadcast_dims);
367
+ }
368
+
369
+ void SpmdBuilder::SetBroadcastDimsForPad (const HloInstruction& hlo) {
370
+ CHECK (hlo.opcode () == HloOpcode::kPad );
371
+ auto it = broadcast_dims_.find (hlo.operand (0 ));
372
+ if (it == broadcast_dims_.end ()) {
373
+ return ;
352
374
}
353
- if (hlo-> opcode () == HloOpcode:: kSlice ||
354
- hlo-> opcode () == HloOpcode:: kDynamicSlice ) {
355
- auto it = broadcast_dims_. find (hlo-> operand ( 0 ) );
356
- if (it != broadcast_dims_. end ()) {
357
- auto dims = it->second ;
358
- broadcast_dims_[hlo] = std::move (dims );
375
+ absl::flat_hash_set< int64_t > pad_broadcast_dims;
376
+ for ( int64_t i = 0 ; i < hlo. shape (). rank (); ++i ) {
377
+ const auto & dim = hlo. padding_config (). dimensions (i );
378
+ if (dim. edge_padding_low () == 0 && dim. edge_padding_high () == 0 &&
379
+ dim. interior_padding () == 0 && it->second . contains (i)) {
380
+ pad_broadcast_dims. insert (i );
359
381
}
360
382
}
361
- if (hlo->opcode () == HloOpcode::kPad ) {
362
- auto it = broadcast_dims_.find (hlo->operand (0 ));
363
- if (it != broadcast_dims_.end ()) {
364
- absl::flat_hash_set<int64_t > pad_broadcast_dims;
365
- for (int64_t i = 0 ; i < hlo->shape ().rank (); ++i) {
366
- const auto & dim = hlo->padding_config ().dimensions (i);
367
- if (dim.edge_padding_low () == 0 && dim.edge_padding_high () == 0 &&
368
- dim.interior_padding () == 0 && it->second .contains (i)) {
369
- pad_broadcast_dims.insert (i);
370
- }
371
- }
372
- if (!pad_broadcast_dims.empty ()) {
373
- broadcast_dims_[hlo] = std::move (pad_broadcast_dims);
383
+ if (!pad_broadcast_dims.empty ()) {
384
+ broadcast_dims_[&hlo] = std::move (pad_broadcast_dims);
385
+ }
386
+ }
387
+
388
+ void SpmdBuilder::SetBroadcastDimsForSlice (const HloInstruction& hlo) {
389
+ CHECK (hlo.opcode () == HloOpcode::kSlice ||
390
+ hlo.opcode () == HloOpcode::kDynamicSlice );
391
+ auto it = broadcast_dims_.find (hlo.operand (0 ));
392
+ if (it != broadcast_dims_.end ()) {
393
+ auto dims = it->second ;
394
+ broadcast_dims_[&hlo] = std::move (dims);
395
+ }
396
+ }
397
+
398
+ void SpmdBuilder::SetBroadcastDimsForElementwise (const HloInstruction& hlo) {
399
+ CHECK (hlo.IsElementwise ());
400
+ if (hlo.operand_count () == 0 || hlo.shape ().IsTuple ()) {
401
+ return ;
402
+ }
403
+ absl::flat_hash_set<int64_t > broadcast_dims;
404
+ for (int64_t i = 0 ; i < hlo.shape ().rank (); ++i) {
405
+ broadcast_dims.insert (i);
406
+ }
407
+ for (int64_t i = 0 ; i < hlo.operand_count (); ++i) {
408
+ auto it = broadcast_dims_.find (hlo.operand (i));
409
+ if (it == broadcast_dims_.end ()) {
410
+ broadcast_dims.clear ();
411
+ break ;
412
+ }
413
+ for (int64_t i = 0 ; i < hlo.shape ().rank (); ++i) {
414
+ if (!it->second .contains (i)) {
415
+ broadcast_dims.erase (i);
374
416
}
375
417
}
376
418
}
377
- return hlo;
419
+ if (!broadcast_dims.empty ()) {
420
+ broadcast_dims_[&hlo] = std::move (broadcast_dims);
421
+ }
378
422
}
379
423
380
424
PartitionedHlo PartitionedHlo::Reshard (const HloSharding& target,
0 commit comments