Skip to content

Commit e7057bc

Browse files
mehrdadkhani tensorflower-gardener
authored and committed
[XLA:MSA] XLA flags for converting synchronous copies and slices into asynchronous ones are now enabled by default. This change enables asynchronous data movement optimizations, potentially improving performance. The criteria for converting synchronous operations to asynchronous ones have also been refined with stricter checks on operands, memory allocation, and usage patterns.
PiperOrigin-RevId: 688363223
1 parent 2095602 commit e7057bc

File tree

5 files changed

+248
-78
lines changed

5 files changed

+248
-78
lines changed

third_party/xla/xla/service/memory_space_assignment/algorithm.cc

Lines changed: 163 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,9 +1285,76 @@ void MsaAlgorithm::IdentifyAndOptimizeMemoryBoundLoops() {
12851285

12861286
bool MsaAlgorithm::IsAsyncConversionCandidate(
12871287
const HloInstruction* instruction) const {
1288-
return IsAsyncConversionCopyCandidate(instruction) ||
1289-
IsAsyncConversionSliceCandidate(instruction) ==
1290-
AsyncConversionResult::kSuccess;
1288+
bool meets_special_preconditions =
1289+
IsAsyncConversionCopyCandidate(instruction) ||
1290+
IsAsyncConversionSliceCandidate(instruction) ==
1291+
AsyncConversionResult::kSuccess;
1292+
if (!meets_special_preconditions) {
1293+
return false;
1294+
}
1295+
1296+
for (auto& operand : instruction->operands()) {
1297+
// TODO(b/374835319): relax the operand constraint to be able to cover
1298+
// nested sync data movement cases.
1299+
if (IsAsyncConversionCandidate(operand)) {
1300+
VLOG(4) << "The instruction is not considered to be replaced, because it "
1301+
"potentially has a replaceable operand.";
1302+
return false;
1303+
}
1304+
const HloValue& operand_value = alias_analysis_.dataflow_analysis()
1305+
.GetValueSet(operand)
1306+
.GetUniqueValue();
1307+
if (!buffer_intervals_.at(&operand_value).need_allocation) {
1308+
VLOG(4)
1309+
<< "The instruction is not considered to be replaced, because its "
1310+
"operand value doesn't need an allocation.";
1311+
return false;
1312+
}
1313+
}
1314+
1315+
const HloValue& value = alias_analysis_.dataflow_analysis()
1316+
.GetValueSet(instruction)
1317+
.GetUniqueValue();
1318+
if (!buffer_intervals_.at(&value).need_allocation) {
1319+
VLOG(4) << "The instruction is not considered to be replaced, because its "
1320+
"output doesn't need an allocation and it might be too late to "
1321+
"replace this instruction.";
1322+
return false;
1323+
}
1324+
if (value.IsRootOf(instruction->parent())) {
1325+
VLOG(4) << "The instruction is not considered to be replaced, because its "
1326+
"output value is in the root of the computation.";
1327+
return false;
1328+
}
1329+
if (finalized_values_.contains(&value)) {
1330+
VLOG(4) << "The instruction is not considered to be replaced, because its "
1331+
"output value is in the finalized values.";
1332+
return false;
1333+
}
1334+
if (buffer_intervals_.at(&value).size > available_heap_size()) {
1335+
VLOG(4) << "The instruction is not considered to be replaced, because its "
1336+
"output value is too large to fit in the heap.";
1337+
return false;
1338+
}
1339+
// This check is here only because we skip processing the values that are not
1340+
// allowed in alternate memory.
1341+
if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
1342+
buffer_intervals_.at(&value), options_.alternate_memory_space)) {
1343+
VLOG(4) << "The instruction is not considered to be replaced, because its "
1344+
"output value is not allowed in alternate memory.";
1345+
return false;
1346+
}
1347+
1348+
for (const HloInstruction* user : instruction->users()) {
1349+
if (HloDataflowAnalysis::IsAsynchronousOperationStart(user->opcode())) {
1350+
VLOG(4) << "The instruction is not considered to be replaced, because "
1351+
"its used by an async start operation that might require "
1352+
"contiguous allocation.";
1353+
return false;
1354+
}
1355+
}
1356+
1357+
return true;
12911358
}
12921359

12931360
bool MsaAlgorithm::IsAsyncConversionCopyCandidate(
@@ -1363,18 +1430,6 @@ MsaAlgorithm::IsAsyncConversionSliceCandidate(
13631430
<< instruction->ToShortString();
13641431
return AsyncConversionResult::kFailedPrecondition;
13651432
}
1366-
bool has_slice_operand = false;
1367-
for (auto& operand : instruction->operands()) {
1368-
if (operand->opcode() == HloOpcode::kSlice) {
1369-
has_slice_operand = true;
1370-
break;
1371-
}
1372-
}
1373-
if (has_slice_operand) {
1374-
VLOG(4) << "The sync slice is not considered to be replaced, because it "
1375-
"has a slice operand.";
1376-
return AsyncConversionResult::kFailedPrecondition;
1377-
}
13781433

13791434
if (instruction->shape().layout().memory_space() !=
13801435
static_cast<int64_t>(MemorySpace::kDefault) ||
@@ -1444,22 +1499,6 @@ void MsaAlgorithm::UpdateSyncDataMovementCandidatesForJointProcessedValues(
14441499
const std::vector<const HloValue*>& joint_processed_values) {
14451500
absl::flat_hash_set<const HloInstruction*> replaceable_sync_instructions;
14461501
absl::flat_hash_set<const HloInstruction*> do_not_touch_instructions;
1447-
for (const HloValue* value : joint_processed_values) {
1448-
if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
1449-
buffer_intervals_.at(value), options_.alternate_memory_space)) {
1450-
HloInstruction* inst = value->instruction();
1451-
if (IsAsyncConversionCandidate(inst)) {
1452-
do_not_touch_instructions.insert(inst);
1453-
failed_async_conversions_[inst] =
1454-
AsyncConversionResult::kFailedValueNotAllowedInAlternateMemory;
1455-
VLOG(4) << "Not trying to replace sync instruction "
1456-
<< inst->ToShortString()
1457-
<< " with an async version, because the sync instruction "
1458-
"defines a value that is not allowed in alternate memory.";
1459-
}
1460-
}
1461-
}
1462-
14631502
for (const HloValue* value : joint_processed_values) {
14641503
for (const auto& use : value->GetUses()) {
14651504
bool is_use_replaceable_sync_candidate =
@@ -1577,12 +1616,23 @@ void MsaAlgorithm::CreateAllocationValuesForJointProcessedValues(
15771616
continue;
15781617
}
15791618

1580-
if (!options_.enable_sync_slice_replacement &&
1581-
!options_.enable_window_prefetch &&
1619+
if (!options_.enable_window_prefetch &&
15821620
interval.size > available_heap_size()) {
1583-
VLOG(3) << "Skip " << interval.buffer->ToShortString()
1584-
<< " because the buffer is larger than the heap size.";
1585-
continue;
1621+
const HloInstruction* defining_instruction =
1622+
interval.buffer->instruction();
1623+
auto may_be_replaced_by_slice_fn = [this](const HloInstruction* user) {
1624+
return IsInstructionPendingReplacements(user) &&
1625+
user->opcode() == HloOpcode::kSlice;
1626+
};
1627+
bool may_be_replaced_by_slice = std::any_of(
1628+
defining_instruction->users().begin(),
1629+
defining_instruction->users().end(), may_be_replaced_by_slice_fn);
1630+
1631+
if (!may_be_replaced_by_slice) {
1632+
VLOG(3) << "Skip " << interval.buffer->ToShortString()
1633+
<< " because the buffer is larger than the heap size.";
1634+
continue;
1635+
}
15861636
}
15871637

15881638
auto colocated_intervals = GetSortedColocatedIntervals(interval);
@@ -1642,6 +1692,24 @@ MsaAlgorithm::JointAllocationProposal MsaAlgorithm::GetJointProposal(
16421692
return proposal;
16431693
}
16441694

1695+
bool MsaAlgorithm::RepackAllocationsIncludeConvertedSyncMemOp() {
1696+
for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) {
1697+
if (allocation_block.allocation->is_copy_allocation()) {
1698+
if (dynamic_cast<CopyAllocation*>(allocation_block.allocation)
1699+
->sync_mem_op()) {
1700+
return true;
1701+
}
1702+
}
1703+
if (allocation_block.allocation->is_sliced_copy_allocation()) {
1704+
if (dynamic_cast<SlicedCopyAllocation*>(allocation_block.allocation)
1705+
->sync_mem_op()) {
1706+
return true;
1707+
}
1708+
}
1709+
}
1710+
return false;
1711+
}
1712+
16451713
absl::StatusOr<HeapSimulator::Result<HloValue>> MsaAlgorithm::Finish() {
16461714
// Note: Memory Space Assignment creates a HeapSimulator and passes an
16471715
// MsaAlgorithm object to it. buffer_intervals_ is populated by calling the
@@ -1795,11 +1863,18 @@ absl::StatusOr<HeapSimulator::Result<HloValue>> MsaAlgorithm::Finish() {
17951863
} else if (result_requires_uncommit(result)) {
17961864
UncommitPendingChunks(absl::MakeSpan(proposal.allocation_values));
17971865
VLOG(2) << "Couldn't allocate. Retry number " << retry_number;
1866+
if (retry_number > 0 && !sorted_async_conversion_candidates_.empty()) {
1867+
failed_async_conversions_[sorted_async_conversion_candidates_.at(0)] =
1868+
AsyncConversionResult::kFailedGaveUp;
1869+
VLOG(2) << "Giving the allocation another chance by dropping one "
1870+
"async conversion candidate.";
1871+
proposal = GetJointProposal(interval);
1872+
--retry_number;
1873+
}
17981874
} else if ((result_is(result, Result::kFailOutOfMemory) ||
17991875
options_.repack_after_every_allocation) &&
18001876
num_repacks_ < options_.max_repacks && !repacked &&
1801-
(sorted_async_conversion_candidates_.empty() ||
1802-
!options_.enable_sync_slice_replacement)) {
1877+
!RepackAllocationsIncludeConvertedSyncMemOp()) {
18031878
UncommitPendingChunks(absl::MakeSpan(proposal.allocation_values));
18041879
++num_repacks_;
18051880
repacked = true;
@@ -1861,6 +1936,9 @@ absl::StatusOr<HeapSimulator::Result<HloValue>> MsaAlgorithm::Finish() {
18611936

18621937
if (options_.repack_after_every_allocation) {
18631938
CHECK_NE(options_.repacker, nullptr);
1939+
CHECK(!RepackAllocationsIncludeConvertedSyncMemOp())
1940+
<< "Repacking is not supported yet when there are converted sync mem "
1941+
"ops.";
18641942
std::vector<AllocationBlock*> repack_allocation_blocks;
18651943
ExportAllocationsForRepacking(repack_allocation_blocks);
18661944
VLOG(2) << "Final Repacking.";
@@ -2238,7 +2316,7 @@ MsaAlgorithm::GenerateAllocationSegmentContexts(
22382316
uses_work_list.push_back({&allocation_value.uses(), primary_use_idx,
22392317
allocation_value_idx, true});
22402318
for (auto sync_destination_idx :
2241-
value_indices_by_sync_inst[primary_use.hlo_use.instruction]) {
2319+
value_indices_by_sync_inst.at(primary_use.hlo_use.instruction)) {
22422320
AllocationValue& sync_destination =
22432321
allocation_values.at(sync_destination_idx);
22442322
if (sync_destination.defining_instruction() ==
@@ -2257,13 +2335,6 @@ MsaAlgorithm::GenerateAllocationSegmentContexts(
22572335
uses_work_list.push_back({&sync_destination.uses(),
22582336
secondary_use_id,
22592337
updates_allocation_value_idx, false});
2260-
if (sync_destination.requires_contiguous_allocation()) {
2261-
VLOG(3) << "Setting requires_contiguous_allocation to true for "
2262-
<< allocation_value.ToShortString()
2263-
<< " because its skip destination "
2264-
<< sync_destination.ToShortString() << " requires it.";
2265-
allocation_value.set_requires_contiguous_allocation(true);
2266-
}
22672338
}
22682339
} else {
22692340
VLOG(3) << "Skipping secondary uses related to allocation value "
@@ -2421,7 +2492,7 @@ absl::StatusOr<MsaAlgorithm::Result> MsaAlgorithm::AllocateAllocationValues(
24212492
return copy_allocation &&
24222493
(copy_allocation->copy_done_schedule_before() <=
24232494
request.required_copy_allocation_latest_time) &&
2424-
(copy_allocation->sync_instruction() ==
2495+
(copy_allocation->sync_mem_op() ==
24252496
request.required_copy_allocation_for) &&
24262497
(!request.required_copy_for_slice ||
24272498
(request.required_copy_for_slice &&
@@ -2435,6 +2506,8 @@ absl::StatusOr<MsaAlgorithm::Result> MsaAlgorithm::AllocateAllocationValues(
24352506
return sliced_copy_allocation &&
24362507
(sliced_copy_allocation->earliest_available_time() <=
24372508
request.required_copy_allocation_latest_time) &&
2509+
(sliced_copy_allocation->sync_mem_op() ==
2510+
request.required_copy_allocation_for) &&
24382511
!request.required_copy_for_slice;
24392512
}
24402513
return false;
@@ -2446,10 +2519,6 @@ absl::StatusOr<MsaAlgorithm::Result> MsaAlgorithm::AllocateAllocationValues(
24462519
"segment allocation. "
24472520
"Sync copy replacement has failed. Fall back to the "
24482521
"normal mode.";
2449-
VLOG(3) << "result_requires_uncommit(result)"
2450-
<< result_requires_uncommit(result)
2451-
<< " it == allocation_sequence->end()"
2452-
<< (it == allocation_sequence->end());
24532522
failed_async_conversions_[request.required_copy_allocation_for] =
24542523
AsyncConversionResult::kFailedSatisfyingConstraints;
24552524
result_mark(Result::kFailSyncDataMoveReplacement, result);
@@ -2472,6 +2541,8 @@ absl::StatusOr<MsaAlgorithm::Result> MsaAlgorithm::AllocateAllocationValues(
24722541
result_mark(Result::kFailSyncDataMoveReplacement, result);
24732542
result_mark(Result::kFailRequiresUncommit, result);
24742543
} else {
2544+
not_finalized_async_conversions_.push_back(
2545+
request.required_copy_allocation_for);
24752546
VLOG(3) << "Replacing "
24762547
<< request.required_copy_allocation_for->ToShortString()
24772548
<< " with " << (*it)->ToString();
@@ -2513,9 +2584,33 @@ absl::StatusOr<MsaAlgorithm::Result> MsaAlgorithm::AllocateAllocationValues(
25132584
preferred_offset_for_computation);
25142585
}
25152586
}
2587+
2588+
if (!VerifyAllConversionsAreSuccessful()) {
2589+
result_mark(Result::kFailSyncDataMoveReplacement, result);
2590+
result_mark(Result::kFailRequiresUncommit, result);
2591+
}
2592+
25162593
return result;
25172594
}
25182595

2596+
bool MsaAlgorithm::VerifyAllConversionsAreSuccessful() {
2597+
for (const HloInstruction* instruction :
2598+
sorted_async_conversion_candidates_) {
2599+
if (absl::c_find(not_finalized_async_conversions_, instruction) ==
2600+
not_finalized_async_conversions_.end()) {
2601+
if (!failed_async_conversions_.contains(instruction)) {
2602+
failed_async_conversions_[instruction] =
2603+
AsyncConversionResult::kFailedNotProcessed;
2604+
VLOG(3) << "Async conversion failed for "
2605+
<< instruction->ToShortString()
2606+
<< " because its operand or user was not processed.";
2607+
}
2608+
return false;
2609+
}
2610+
}
2611+
return true;
2612+
}
2613+
25192614
MsaAlgorithm::AliasedOffset* MsaAlgorithm::UpdatePreferredOffsetForUse(
25202615
const AllocationValue::Use& use,
25212616
MsaAlgorithm::AliasedOffset* preferred_offset) const {
@@ -2555,11 +2650,11 @@ MsaAlgorithm::AllocationRequest MsaAlgorithm::CreateAllocationRequest(
25552650
required_copy_for_slice =
25562651
(IsAsyncConversionSliceCandidate(use.sync_mem_op_operand) ==
25572652
AsyncConversionResult::kSuccess);
2558-
25592653
// The async copy allocation can be delayed until the earliest time at which
25602654
// the value is used in a position or the earliest use time of the updated
25612655
// allocation value. We find the minimum of these two times.
2562-
int64_t min_time = GetCorrectedUseTime(use.sync_mem_op_operand);
2656+
int64_t min_time =
2657+
GetCorrectedUseTime(allocation_value.defining_instruction());
25632658
int64_t earliest_position_time = std::numeric_limits<int64_t>::max();
25642659
for (auto& position : allocation_value.value()->positions()) {
25652660
auto position_time = GetCorrectedUseTime(position.instruction);
@@ -4079,6 +4174,7 @@ void MsaAlgorithm::FinalizeAllocations(
40794174
for (const HloInstruction* copy_inst : sorted_async_conversion_candidates_) {
40804175
successful_async_conversion_set_.insert(copy_inst);
40814176
}
4177+
not_finalized_async_conversions_.clear();
40824178
std::vector<std::pair<const AliasedOffset*, std::vector<Allocation*>>>
40834179
colocation_vector;
40844180
absl::flat_hash_map<const AliasedOffset*, size_t> offset_to_index;
@@ -4317,7 +4413,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) {
43174413
return allocation->memory_space() == required_memory_space_at_start;
43184414
});
43194415
if (prev_allocation_it != allocation_sequence->rend()) {
4320-
(*prev_allocation_it)->set_end_time(request.inclusive_start_time);
4416+
(*prev_allocation_it)->Extend(request.inclusive_start_time);
43214417
needs_required_allocation = false;
43224418
}
43234419
}
@@ -4416,7 +4512,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) {
44164512
if (required_memory_space_at_end == MemorySpace::kDefault) {
44174513
VLOG(3)
44184514
<< "Not trying to prefetch because use requires buffer in default mem.";
4419-
(*prev_allocation_in_default_mem_it)->set_end_time(request.end_time);
4515+
(*prev_allocation_in_default_mem_it)->Extend(request.end_time);
44204516
(*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
44214517

44224518
// If the buffer is placed in default memory, we can also try window
@@ -4628,7 +4724,8 @@ void MsaAlgorithm::AddAsyncSlicesForPrefetch(
46284724
const Allocation& prev_allocation, AllocationSequence* allocations,
46294725
AliasedOffset* aliased_offset,
46304726
const std::vector<SliceDecision>& slice_decisions_sorted_by_start_time,
4631-
int64_t prefetch_end_time, int64_t allocation_end_time) {
4727+
int64_t prefetch_end_time, int64_t allocation_end_time,
4728+
HloInstruction* sync_mem_op) {
46324729
VLOG(3) << "Sliced copy to alternate memory. "
46334730
<< SliceTimesAndCopyResourcesToString(
46344731
slice_decisions_sorted_by_start_time, prefetch_end_time,
@@ -4642,7 +4739,7 @@ void MsaAlgorithm::AddAsyncSlicesForPrefetch(
46424739
prev_allocation, MemorySpace::kAlternate,
46434740
slice_decisions_sorted_by_start_time, prefetch_end_time,
46444741
allocation_end_time, options_.sliced_prefetch_options,
4645-
options_.get_equivalent_s8_shape_fn));
4742+
options_.get_equivalent_s8_shape_fn, sync_mem_op));
46464743

46474744
// Register the additional async copy with the interval tree to keep track of
46484745
// the limit at any given time.
@@ -4856,7 +4953,12 @@ MsaAlgorithm::Result MsaAlgorithm::Evict(const AllocationRequest& request) {
48564953

48574954
MsaBufferInterval eviction_mem_interval;
48584955
eviction_mem_interval.buffer = request.allocation_value->value();
4859-
eviction_mem_interval.size = request.size;
4956+
// When replacing a sync slice, the size of the original allocation_value
4957+
// matters instead of the queuing_allocation_value
4958+
// TODO(mehrdadk): separate the request size for src and dst
4959+
// AllocationSequence
4960+
eviction_mem_interval.size =
4961+
std::max(request.allocation_value->size(), request.size);
48604962
// Try to reserve a buffer from the end of the previous allocation to the
48614963
// preferred eviction end time.
48624964
eviction_mem_interval.start = eviction_end_time + 1;
@@ -4912,7 +5014,7 @@ MsaAlgorithm::Result MsaAlgorithm::Evict(const AllocationRequest& request) {
49125014
// See if this interval would violate the asynchronous copy limit.
49135015
if (!eviction_interval_too_short && !eviction_violates_outstanding_copies &&
49145016
!eviction_violates_resource) {
4915-
prev_allocation->set_end_time(eviction_end_time);
5017+
prev_allocation->Extend(eviction_end_time);
49165018
AddAsyncCopyOrOtherMemOp(
49175019
*prev_allocation, MemorySpace::kDefault,
49185020
/*chunk=*/std::nullopt, eviction_exclusive_start_time,
@@ -5144,7 +5246,8 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch(
51445246
->mutable_allocation_sequence(),
51455247
context.request->preferred_offset,
51465248
context.sliced_solution->slice_decisions_sorted_by_start_time,
5147-
context.prefetch_end_time, context.request->end_time);
5249+
context.prefetch_end_time, context.request->end_time,
5250+
context.request->required_copy_allocation_for);
51485251
context.request->updates_allocation_value->allocation_sequence()
51495252
->back()
51505253
->AddUse(context.request->use->hlo_use);

0 commit comments

Comments
 (0)