
New pass Reduce variable liveness #3965


Draft · wants to merge 22 commits into base: main

Conversation

@mfrancepillois (Contributor) commented Apr 18, 2025

Add a new pass to reduce variable liveness by prefetching data and then moving the load op closer to its use op.
Add a test.

@mfrancepillois requested review from whitneywhtsang, etiotto and a team on April 18, 2025 10:52

@mfrancepillois (Contributor, Author) commented:

Performance improvement for FA on PVC1550:
[image]

@mfrancepillois changed the title from "Add pass: Reduce the register pressure" to "New pass Reduce register pressure" on Apr 18, 2025
@mfrancepillois linked an issue on Apr 18, 2025 that may be closed by this pull request
@mfrancepillois changed the title from "New pass Reduce register pressure" to "[Draft] New pass Reduce register pressure" on Apr 18, 2025
@mfrancepillois marked this pull request as draft on April 18, 2025 11:40
@mfrancepillois marked this pull request as ready for review on April 18, 2025 16:31
@mfrancepillois changed the title from "[Draft] New pass Reduce register pressure" to "New pass Reduce register pressure" on Apr 18, 2025
namespace {

/// Return true if the lifespan of the value \p v is considered long.
static bool isLongLifeSpanVariable(Value v, Block *useBlock) {
Contributor:

The heuristic is pretty crude. At some point I wrote an analysis to estimate the live range of a variable: https://github.com/intel/intel-xpu-backend-for-triton/blob/main/third_party/intel/include/Analysis/Liveness.h. I am wondering whether we should attempt to use that analysis to collect the live ranges of variables in a loop, and then sink variables whose live range is "too big".

Can you give that a try?

Contributor:

It is tricky to get this kind of instruction "scheduling" correct at the Triton level. I kind of feel the low-level compiler (e.g. IGC) would have all the information to schedule instructions based on register usage; it is hard to do that well at the abstraction level Triton operates at.

Contributor Author:

I recognize that at this level it is difficult to claim that we are reducing register pressure. Instead, we reduce the variable liveness, hoping that the liveness reduction will allow us to save registers. I have therefore renamed the pass accordingly.
I have also based the heuristic on the MLIR Liveness analysis.

@mfrancepillois changed the title from "New pass Reduce register pressure" to "New pass Reduce variable liveness" on Apr 24, 2025
// CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
// CHECK-NOT: tt.load {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
%1 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #dot1>>
%2 = tt.load %0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<128x64xf16, #dot0>>
Contributor:

We should not sink the load inside the loop in Triton. Sinking the load into the loop means that the value is loaded at every loop iteration, and Triton doesn't have enough information about register pressure to determine whether that is profitable.

@mfrancepillois marked this pull request as draft on April 30, 2025 16:53
Signed-off-by: Maxime France-Pillois <[email protected]>
@mfrancepillois marked this pull request as ready for review on April 30, 2025 17:46
@etiotto requested review from alexbaden and chengjunlu on May 1, 2025 19:37
@whitneywhtsang (Contributor) left a comment:

There is a loop sink pass in IGC. Can you please create an issue for the IGC team to investigate why it doesn't catch the FA case with the shape that gives the most gain?


/// Create a prefetch operation for the given load operation.
static void createPrefetchOp(tt::LoadOp loadOp) {
Operation *op = loadOp.getPtr().getDefiningOp();
Contributor:

When did we check that loadOp.getPtr() is defined by an operation? Do we need to add that check to isLoadCandidate?
Or should we add support for the case where the pointer is a region argument?

Contributor Author:

Thanks for noticing. A check has been added to isLoadCandidate.
As the pass adds a prefetch right after the defining op, I'm concerned that adding this prefetch in another region (in the case where the load pointer has been defined in another region) could have side effects on the cache, as an early data fetch could evict data that are still needed.

Contributor:

Do we care about the case where the pointer comes directly from a function argument?

@chengjunlu (Contributor) commented May 6, 2025

It is good to have the reduce-variable-liveness pass as a starting point for liveness optimization in the Triton middle end. This PR looks good to me as a beginning.

The optimization relies on the cache holding the values that we may reuse in the loop, but the cache system is not fully controllable by the program. It would be better to enhance the pass with the use of shared local memory, making it somewhat like a RegisterToMem pass for the general case.

@etiotto (Contributor) commented May 6, 2025

@mfrancepillois can you do a Triton Benchmark run with this PR to identify improvements (or degradations, hopefully none) in all the microbenchmarks we have?

@mfrancepillois marked this pull request as draft on May 12, 2025 13:11
Operation *forOp) {
  // Only pointers to tensors are considered to be moved
- if (!mlir::triton::isTensorPointerType(loadOp.getPtr().getType()))
+ if (!mlir::triton::isTensorOrTensorPointerType(loadOp.getPtr().getType()))
Contributor:

[optional]

Suggested change:
- if (!mlir::triton::isTensorOrTensorPointerType(loadOp.getPtr().getType()))
+ if (!mlir::triton::isTensorPointerType(loadOp.getResult().getType()))

// Multiple users
if (any_of(loadOp->getUsers(), [&](Operation *user) {
return ((user->getBlock() == forOp->getBlock()) &&
user->isBeforeInBlock(forOp));
Contributor:

What does user->isBeforeInBlock(forOp) mean?
Does user->getBlock() == forOp->getBlock() mean that the user is part of the loop?

Development

Successfully merging this pull request may close these issues.

[FA performance] Improve the Q matrix load strategy
4 participants