@@ -1285,9 +1285,76 @@ void MsaAlgorithm::IdentifyAndOptimizeMemoryBoundLoops() {
1285
1285
1286
1286
bool MsaAlgorithm::IsAsyncConversionCandidate (
1287
1287
const HloInstruction* instruction) const {
1288
- return IsAsyncConversionCopyCandidate (instruction) ||
1289
- IsAsyncConversionSliceCandidate (instruction) ==
1290
- AsyncConversionResult::kSuccess ;
1288
+ bool meets_special_preconditions =
1289
+ IsAsyncConversionCopyCandidate (instruction) ||
1290
+ IsAsyncConversionSliceCandidate (instruction) ==
1291
+ AsyncConversionResult::kSuccess ;
1292
+ if (!meets_special_preconditions) {
1293
+ return false ;
1294
+ }
1295
+
1296
+ for (auto & operand : instruction->operands ()) {
1297
+ // TODO(b/374835319): relax the operand constraint to be able to cover
1298
+ // nested sync data movement cases.
1299
+ if (IsAsyncConversionCandidate (operand)) {
1300
+ VLOG (4 ) << " The instruction is not considered to be replaced, because it "
1301
+ " potentially has a replaceable operand." ;
1302
+ return false ;
1303
+ }
1304
+ const HloValue& operand_value = alias_analysis_.dataflow_analysis ()
1305
+ .GetValueSet (operand)
1306
+ .GetUniqueValue ();
1307
+ if (!buffer_intervals_.at (&operand_value).need_allocation ) {
1308
+ VLOG (4 )
1309
+ << " The instruction is not considered to be replaced, because its "
1310
+ " operand value doesn't need an allocation." ;
1311
+ return false ;
1312
+ }
1313
+ }
1314
+
1315
+ const HloValue& value = alias_analysis_.dataflow_analysis ()
1316
+ .GetValueSet (instruction)
1317
+ .GetUniqueValue ();
1318
+ if (!buffer_intervals_.at (&value).need_allocation ) {
1319
+ VLOG (4 ) << " The instruction is not considered to be replaced, because its "
1320
+ " output doesn't need an allocation and it might be too late to "
1321
+ " replace this instruction." ;
1322
+ return false ;
1323
+ }
1324
+ if (value.IsRootOf (instruction->parent ())) {
1325
+ VLOG (4 ) << " The instruction is not considered to be replaced, because its "
1326
+ " output value is in the root of the computation." ;
1327
+ return false ;
1328
+ }
1329
+ if (finalized_values_.contains (&value)) {
1330
+ VLOG (4 ) << " The instruction is not considered to be replaced, because its "
1331
+ " output value is in the finalized values." ;
1332
+ return false ;
1333
+ }
1334
+ if (buffer_intervals_.at (&value).size > available_heap_size ()) {
1335
+ VLOG (4 ) << " The instruction is not considered to be replaced, because its "
1336
+ " output value is too large to fit in the heap." ;
1337
+ return false ;
1338
+ }
1339
+ // This check is here only because we skip processing the values that are not
1340
+ // allowed in alternate memory.
1341
+ if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory (
1342
+ buffer_intervals_.at (&value), options_.alternate_memory_space )) {
1343
+ VLOG (4 ) << " The instruction is not considered to be replaced, because its "
1344
+ " output value is not allowed in alternate memory." ;
1345
+ return false ;
1346
+ }
1347
+
1348
+ for (const HloInstruction* user : instruction->users ()) {
1349
+ if (HloDataflowAnalysis::IsAsynchronousOperationStart (user->opcode ())) {
1350
+ VLOG (4 ) << " The instruction is not considered to be replaced, because "
1351
+ " its used by an async start operation that might require "
1352
+ " contiguous allocation." ;
1353
+ return false ;
1354
+ }
1355
+ }
1356
+
1357
+ return true ;
1291
1358
}
1292
1359
1293
1360
bool MsaAlgorithm::IsAsyncConversionCopyCandidate (
@@ -1363,18 +1430,6 @@ MsaAlgorithm::IsAsyncConversionSliceCandidate(
1363
1430
<< instruction->ToShortString ();
1364
1431
return AsyncConversionResult::kFailedPrecondition ;
1365
1432
}
1366
- bool has_slice_operand = false ;
1367
- for (auto & operand : instruction->operands ()) {
1368
- if (operand->opcode () == HloOpcode::kSlice ) {
1369
- has_slice_operand = true ;
1370
- break ;
1371
- }
1372
- }
1373
- if (has_slice_operand) {
1374
- VLOG (4 ) << " The sync slice is not considered to be replaced, because it "
1375
- " has a slice operand." ;
1376
- return AsyncConversionResult::kFailedPrecondition ;
1377
- }
1378
1433
1379
1434
if (instruction->shape ().layout ().memory_space () !=
1380
1435
static_cast <int64_t >(MemorySpace::kDefault ) ||
@@ -1444,22 +1499,6 @@ void MsaAlgorithm::UpdateSyncDataMovementCandidatesForJointProcessedValues(
1444
1499
const std::vector<const HloValue*>& joint_processed_values) {
1445
1500
absl::flat_hash_set<const HloInstruction*> replaceable_sync_instructions;
1446
1501
absl::flat_hash_set<const HloInstruction*> do_not_touch_instructions;
1447
- for (const HloValue* value : joint_processed_values) {
1448
- if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory (
1449
- buffer_intervals_.at (value), options_.alternate_memory_space )) {
1450
- HloInstruction* inst = value->instruction ();
1451
- if (IsAsyncConversionCandidate (inst)) {
1452
- do_not_touch_instructions.insert (inst);
1453
- failed_async_conversions_[inst] =
1454
- AsyncConversionResult::kFailedValueNotAllowedInAlternateMemory ;
1455
- VLOG (4 ) << " Not trying to replace sync instruction "
1456
- << inst->ToShortString ()
1457
- << " with an async version, because the sync instruction "
1458
- " defines a value that is not allowed in alternate memory." ;
1459
- }
1460
- }
1461
- }
1462
-
1463
1502
for (const HloValue* value : joint_processed_values) {
1464
1503
for (const auto & use : value->GetUses ()) {
1465
1504
bool is_use_replaceable_sync_candidate =
@@ -1577,12 +1616,23 @@ void MsaAlgorithm::CreateAllocationValuesForJointProcessedValues(
1577
1616
continue ;
1578
1617
}
1579
1618
1580
- if (!options_.enable_sync_slice_replacement &&
1581
- !options_.enable_window_prefetch &&
1619
+ if (!options_.enable_window_prefetch &&
1582
1620
interval.size > available_heap_size ()) {
1583
- VLOG (3 ) << " Skip " << interval.buffer ->ToShortString ()
1584
- << " because the buffer is larger than the heap size." ;
1585
- continue ;
1621
+ const HloInstruction* defining_instruction =
1622
+ interval.buffer ->instruction ();
1623
+ auto may_be_replaced_by_slice_fn = [this ](const HloInstruction* user) {
1624
+ return IsInstructionPendingReplacements (user) &&
1625
+ user->opcode () == HloOpcode::kSlice ;
1626
+ };
1627
+ bool may_be_replaced_by_slice = std::any_of (
1628
+ defining_instruction->users ().begin (),
1629
+ defining_instruction->users ().end (), may_be_replaced_by_slice_fn);
1630
+
1631
+ if (!may_be_replaced_by_slice) {
1632
+ VLOG (3 ) << " Skip " << interval.buffer ->ToShortString ()
1633
+ << " because the buffer is larger than the heap size." ;
1634
+ continue ;
1635
+ }
1586
1636
}
1587
1637
1588
1638
auto colocated_intervals = GetSortedColocatedIntervals (interval);
@@ -1642,6 +1692,24 @@ MsaAlgorithm::JointAllocationProposal MsaAlgorithm::GetJointProposal(
1642
1692
return proposal;
1643
1693
}
1644
1694
1695
+ bool MsaAlgorithm::RepackAllocationsIncludeConvertedSyncMemOp () {
1696
+ for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) {
1697
+ if (allocation_block.allocation ->is_copy_allocation ()) {
1698
+ if (dynamic_cast <CopyAllocation*>(allocation_block.allocation )
1699
+ ->sync_mem_op ()) {
1700
+ return true ;
1701
+ }
1702
+ }
1703
+ if (allocation_block.allocation ->is_sliced_copy_allocation ()) {
1704
+ if (dynamic_cast <SlicedCopyAllocation*>(allocation_block.allocation )
1705
+ ->sync_mem_op ()) {
1706
+ return true ;
1707
+ }
1708
+ }
1709
+ }
1710
+ return false ;
1711
+ }
1712
+
1645
1713
absl::StatusOr<HeapSimulator::Result<HloValue>> MsaAlgorithm::Finish () {
1646
1714
// Note: Memory Space Assignment creates a HeapSimulator and passes an
1647
1715
// MsaAlgorithm object to it. buffer_intervals_ is populated by calling the
@@ -1795,11 +1863,18 @@ absl::StatusOr<HeapSimulator::Result<HloValue>> MsaAlgorithm::Finish() {
1795
1863
} else if (result_requires_uncommit (result)) {
1796
1864
UncommitPendingChunks (absl::MakeSpan (proposal.allocation_values ));
1797
1865
VLOG (2 ) << " Couldn't allocate. Retry number " << retry_number;
1866
+ if (retry_number > 0 && !sorted_async_conversion_candidates_.empty ()) {
1867
+ failed_async_conversions_[sorted_async_conversion_candidates_.at (0 )] =
1868
+ AsyncConversionResult::kFailedGaveUp ;
1869
+ VLOG (2 ) << " Giving the allocation another chance by dropping one "
1870
+ " async conversion candidate." ;
1871
+ proposal = GetJointProposal (interval);
1872
+ --retry_number;
1873
+ }
1798
1874
} else if ((result_is (result, Result::kFailOutOfMemory ) ||
1799
1875
options_.repack_after_every_allocation ) &&
1800
1876
num_repacks_ < options_.max_repacks && !repacked &&
1801
- (sorted_async_conversion_candidates_.empty () ||
1802
- !options_.enable_sync_slice_replacement )) {
1877
+ !RepackAllocationsIncludeConvertedSyncMemOp ()) {
1803
1878
UncommitPendingChunks (absl::MakeSpan (proposal.allocation_values ));
1804
1879
++num_repacks_;
1805
1880
repacked = true ;
@@ -1861,6 +1936,9 @@ absl::StatusOr<HeapSimulator::Result<HloValue>> MsaAlgorithm::Finish() {
1861
1936
1862
1937
if (options_.repack_after_every_allocation ) {
1863
1938
CHECK_NE (options_.repacker , nullptr );
1939
+ CHECK (!RepackAllocationsIncludeConvertedSyncMemOp ())
1940
+ << " Repacking is not supported yet when there are converted sync mem "
1941
+ " ops." ;
1864
1942
std::vector<AllocationBlock*> repack_allocation_blocks;
1865
1943
ExportAllocationsForRepacking (repack_allocation_blocks);
1866
1944
VLOG (2 ) << " Final Repacking." ;
@@ -2238,7 +2316,7 @@ MsaAlgorithm::GenerateAllocationSegmentContexts(
2238
2316
uses_work_list.push_back ({&allocation_value.uses (), primary_use_idx,
2239
2317
allocation_value_idx, true });
2240
2318
for (auto sync_destination_idx :
2241
- value_indices_by_sync_inst[ primary_use.hlo_use .instruction ] ) {
2319
+ value_indices_by_sync_inst. at ( primary_use.hlo_use .instruction ) ) {
2242
2320
AllocationValue& sync_destination =
2243
2321
allocation_values.at (sync_destination_idx);
2244
2322
if (sync_destination.defining_instruction () ==
@@ -2257,13 +2335,6 @@ MsaAlgorithm::GenerateAllocationSegmentContexts(
2257
2335
uses_work_list.push_back ({&sync_destination.uses (),
2258
2336
secondary_use_id,
2259
2337
updates_allocation_value_idx, false });
2260
- if (sync_destination.requires_contiguous_allocation ()) {
2261
- VLOG (3 ) << " Setting requires_contiguous_allocation to true for "
2262
- << allocation_value.ToShortString ()
2263
- << " because its skip destination "
2264
- << sync_destination.ToShortString () << " requires it." ;
2265
- allocation_value.set_requires_contiguous_allocation (true );
2266
- }
2267
2338
}
2268
2339
} else {
2269
2340
VLOG (3 ) << " Skipping secondary uses related to allocation value "
@@ -2421,7 +2492,7 @@ absl::StatusOr<MsaAlgorithm::Result> MsaAlgorithm::AllocateAllocationValues(
2421
2492
return copy_allocation &&
2422
2493
(copy_allocation->copy_done_schedule_before () <=
2423
2494
request.required_copy_allocation_latest_time ) &&
2424
- (copy_allocation->sync_instruction () ==
2495
+ (copy_allocation->sync_mem_op () ==
2425
2496
request.required_copy_allocation_for ) &&
2426
2497
(!request.required_copy_for_slice ||
2427
2498
(request.required_copy_for_slice &&
@@ -2435,6 +2506,8 @@ absl::StatusOr<MsaAlgorithm::Result> MsaAlgorithm::AllocateAllocationValues(
2435
2506
return sliced_copy_allocation &&
2436
2507
(sliced_copy_allocation->earliest_available_time () <=
2437
2508
request.required_copy_allocation_latest_time ) &&
2509
+ (sliced_copy_allocation->sync_mem_op () ==
2510
+ request.required_copy_allocation_for ) &&
2438
2511
!request.required_copy_for_slice ;
2439
2512
}
2440
2513
return false ;
@@ -2446,10 +2519,6 @@ absl::StatusOr<MsaAlgorithm::Result> MsaAlgorithm::AllocateAllocationValues(
2446
2519
" segment allocation. "
2447
2520
" Sync copy replacement has failed. Fall back to the "
2448
2521
" normal mode." ;
2449
- VLOG (3 ) << " result_requires_uncommit(result)"
2450
- << result_requires_uncommit (result)
2451
- << " it == allocation_sequence->end()"
2452
- << (it == allocation_sequence->end ());
2453
2522
failed_async_conversions_[request.required_copy_allocation_for ] =
2454
2523
AsyncConversionResult::kFailedSatisfyingConstraints ;
2455
2524
result_mark (Result::kFailSyncDataMoveReplacement , result);
@@ -2472,6 +2541,8 @@ absl::StatusOr<MsaAlgorithm::Result> MsaAlgorithm::AllocateAllocationValues(
2472
2541
result_mark (Result::kFailSyncDataMoveReplacement , result);
2473
2542
result_mark (Result::kFailRequiresUncommit , result);
2474
2543
} else {
2544
+ not_finalized_async_conversions_.push_back (
2545
+ request.required_copy_allocation_for );
2475
2546
VLOG (3 ) << " Replacing "
2476
2547
<< request.required_copy_allocation_for ->ToShortString ()
2477
2548
<< " with " << (*it)->ToString ();
@@ -2513,9 +2584,33 @@ absl::StatusOr<MsaAlgorithm::Result> MsaAlgorithm::AllocateAllocationValues(
2513
2584
preferred_offset_for_computation);
2514
2585
}
2515
2586
}
2587
+
2588
+ if (!VerifyAllConversionsAreSuccessful ()) {
2589
+ result_mark (Result::kFailSyncDataMoveReplacement , result);
2590
+ result_mark (Result::kFailRequiresUncommit , result);
2591
+ }
2592
+
2516
2593
return result;
2517
2594
}
2518
2595
2596
+ bool MsaAlgorithm::VerifyAllConversionsAreSuccessful () {
2597
+ for (const HloInstruction* instruction :
2598
+ sorted_async_conversion_candidates_) {
2599
+ if (absl::c_find (not_finalized_async_conversions_, instruction) ==
2600
+ not_finalized_async_conversions_.end ()) {
2601
+ if (!failed_async_conversions_.contains (instruction)) {
2602
+ failed_async_conversions_[instruction] =
2603
+ AsyncConversionResult::kFailedNotProcessed ;
2604
+ VLOG (3 ) << " Async conversion failed for "
2605
+ << instruction->ToShortString ()
2606
+ << " because its operand or user was not processed." ;
2607
+ }
2608
+ return false ;
2609
+ }
2610
+ }
2611
+ return true ;
2612
+ }
2613
+
2519
2614
MsaAlgorithm::AliasedOffset* MsaAlgorithm::UpdatePreferredOffsetForUse (
2520
2615
const AllocationValue::Use& use,
2521
2616
MsaAlgorithm::AliasedOffset* preferred_offset) const {
@@ -2555,11 +2650,11 @@ MsaAlgorithm::AllocationRequest MsaAlgorithm::CreateAllocationRequest(
2555
2650
required_copy_for_slice =
2556
2651
(IsAsyncConversionSliceCandidate (use.sync_mem_op_operand ) ==
2557
2652
AsyncConversionResult::kSuccess );
2558
-
2559
2653
// The async copy allocation can be delayed until the earliest time at which
2560
2654
// the value is used in a position or the earliest use time of the updated
2561
2655
// allocation value. We find the minimum of these two times.
2562
- int64_t min_time = GetCorrectedUseTime (use.sync_mem_op_operand );
2656
+ int64_t min_time =
2657
+ GetCorrectedUseTime (allocation_value.defining_instruction ());
2563
2658
int64_t earliest_position_time = std::numeric_limits<int64_t >::max ();
2564
2659
for (auto & position : allocation_value.value ()->positions ()) {
2565
2660
auto position_time = GetCorrectedUseTime (position.instruction );
@@ -4079,6 +4174,7 @@ void MsaAlgorithm::FinalizeAllocations(
4079
4174
for (const HloInstruction* copy_inst : sorted_async_conversion_candidates_) {
4080
4175
successful_async_conversion_set_.insert (copy_inst);
4081
4176
}
4177
+ not_finalized_async_conversions_.clear ();
4082
4178
std::vector<std::pair<const AliasedOffset*, std::vector<Allocation*>>>
4083
4179
colocation_vector;
4084
4180
absl::flat_hash_map<const AliasedOffset*, size_t > offset_to_index;
@@ -4317,7 +4413,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) {
4317
4413
return allocation->memory_space () == required_memory_space_at_start;
4318
4414
});
4319
4415
if (prev_allocation_it != allocation_sequence->rend ()) {
4320
- (*prev_allocation_it)->set_end_time (request.inclusive_start_time );
4416
+ (*prev_allocation_it)->Extend (request.inclusive_start_time );
4321
4417
needs_required_allocation = false ;
4322
4418
}
4323
4419
}
@@ -4416,7 +4512,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) {
4416
4512
if (required_memory_space_at_end == MemorySpace::kDefault ) {
4417
4513
VLOG (3 )
4418
4514
<< " Not trying to prefetch because use requires buffer in default mem." ;
4419
- (*prev_allocation_in_default_mem_it)->set_end_time (request.end_time );
4515
+ (*prev_allocation_in_default_mem_it)->Extend (request.end_time );
4420
4516
(*prev_allocation_in_default_mem_it)->AddUse (request.use ->hlo_use );
4421
4517
4422
4518
// If the buffer is placed in default memory, we can also try window
@@ -4628,7 +4724,8 @@ void MsaAlgorithm::AddAsyncSlicesForPrefetch(
4628
4724
const Allocation& prev_allocation, AllocationSequence* allocations,
4629
4725
AliasedOffset* aliased_offset,
4630
4726
const std::vector<SliceDecision>& slice_decisions_sorted_by_start_time,
4631
- int64_t prefetch_end_time, int64_t allocation_end_time) {
4727
+ int64_t prefetch_end_time, int64_t allocation_end_time,
4728
+ HloInstruction* sync_mem_op) {
4632
4729
VLOG (3 ) << " Sliced copy to alternate memory. "
4633
4730
<< SliceTimesAndCopyResourcesToString (
4634
4731
slice_decisions_sorted_by_start_time, prefetch_end_time,
@@ -4642,7 +4739,7 @@ void MsaAlgorithm::AddAsyncSlicesForPrefetch(
4642
4739
prev_allocation, MemorySpace::kAlternate ,
4643
4740
slice_decisions_sorted_by_start_time, prefetch_end_time,
4644
4741
allocation_end_time, options_.sliced_prefetch_options ,
4645
- options_.get_equivalent_s8_shape_fn ));
4742
+ options_.get_equivalent_s8_shape_fn , sync_mem_op ));
4646
4743
4647
4744
// Register the additional async copy with the interval tree to keep track of
4648
4745
// the limit at any given time.
@@ -4856,7 +4953,12 @@ MsaAlgorithm::Result MsaAlgorithm::Evict(const AllocationRequest& request) {
4856
4953
4857
4954
MsaBufferInterval eviction_mem_interval;
4858
4955
eviction_mem_interval.buffer = request.allocation_value ->value ();
4859
- eviction_mem_interval.size = request.size ;
4956
+ // When replacing an sync slice, the size of the original allocation_value
4957
+ // matters instead of the queuing_allocation_value
4958
+ // TODO(mehrdadk): separate the request size for src and dst
4959
+ // AllocationSequence
4960
+ eviction_mem_interval.size =
4961
+ std::max (request.allocation_value ->size (), request.size );
4860
4962
// Try to reserve a buffer from the end of the previous allocation to the
4861
4963
// preferred eviction end time.
4862
4964
eviction_mem_interval.start = eviction_end_time + 1 ;
@@ -4912,7 +5014,7 @@ MsaAlgorithm::Result MsaAlgorithm::Evict(const AllocationRequest& request) {
4912
5014
// See if this interval would violate the asynchronous copy limit.
4913
5015
if (!eviction_interval_too_short && !eviction_violates_outstanding_copies &&
4914
5016
!eviction_violates_resource) {
4915
- prev_allocation->set_end_time (eviction_end_time);
5017
+ prev_allocation->Extend (eviction_end_time);
4916
5018
AddAsyncCopyOrOtherMemOp (
4917
5019
*prev_allocation, MemorySpace::kDefault ,
4918
5020
/* chunk=*/ std::nullopt, eviction_exclusive_start_time,
@@ -5144,7 +5246,8 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch(
5144
5246
->mutable_allocation_sequence (),
5145
5247
context.request ->preferred_offset ,
5146
5248
context.sliced_solution ->slice_decisions_sorted_by_start_time ,
5147
- context.prefetch_end_time , context.request ->end_time );
5249
+ context.prefetch_end_time , context.request ->end_time ,
5250
+ context.request ->required_copy_allocation_for );
5148
5251
context.request ->updates_allocation_value ->allocation_sequence ()
5149
5252
->back ()
5150
5253
->AddUse (context.request ->use ->hlo_use );
0 commit comments