[AMD] added slicing ttg.async_copy_global_to_local #797


Open

ravil-mobile wants to merge 47 commits into shared/triton-gfx950-launch from shared/triton-gfx950-launch-update

Conversation

ravil-mobile

> TRITON_HIP_USE_BLOCK_PINGPONG=1 TRITON_HIP_USE_ASYNC_COPY=1 pytest -s -v op_tests/triton_tests/test_gemm_afp4wfp4.py

...

================================================================================================= warnings summary =================================================================================================
op_tests/triton_tests/test_gemm_afp4wfp4.py::test_gemm_afp4_wfp4[dtype0-1024-1024-1024]
  /home/rdorozhi/work/aiter/op_tests/triton_tests/test_gemm_afp4wfp4.py:92: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/Context.cpp:328.)
    return torch.mm(x_f32, w_f32).to(dtype)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
===================================================================================== 74 passed, 1 warning in 65.59s (0:01:05) =====================================================================================

AlexAUT and others added 30 commits May 13, 2025 17:19
4-stage FA experiment

Cluster assignment
Initial support for already arranged ops.
…ction based on the loop. This is not meant as a permanent solution, just to make this branch usable for other workloads.
Computation part interleaves mfma and ds_read
Placed an extra conditional barrier to overlap the computation part
and the buffer_load part. Dot slicing by plognjen at https://github.com/plognjen/triton/tree/slice_dot_scaled
requires a vmcnt fix to achieve full performance.
Fix incorrect condition used to enable transforms.
Fix missing tokens for the local_load.
Only enable for 256x256x256 tile size.
… BufferLoadToLocal to avoid implicit barrier from Membar"

This reverts commit 012793a.
ravil-mobile force-pushed the shared/triton-gfx950-launch-update branch from 9276703 to 24ae652 on May 16, 2025, 12:23
ravil-mobile (Author)

@jungpark-mlir, @raikonenfnu. Thanks for your comments! I addressed all of them.

auto sizePerThread = encoding.getSizePerThread();
SmallVector<unsigned> threadPerWarp(warpsPerCTA.size(), 0);
for (size_t dim = 0; dim < numDims; ++dim) {
  threadPerWarp[dim] =
Member

I have a silly Q for educational purposes: it looks like we are only updating the threadsPerWarp here; wouldn't it make more sense to update the sizePerThread in the newEncoding?

Author

No, we want to preserve the number of elements per thread (i.e., each thread holds 16 consecutive elements of a tensor). We just want to change the layout of threads, and thus change which part of the tensor is held by each CTA.

Each CTA holds a 64x128 tile of a tensor in the following example:

#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>

But if we change threadsPerWarp to [32, 2], then a CTA holds a 256x32 tile of a tensor:

#blocked3 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [8, 1], order = [1, 0]}>

In both cases, a thread holds 16 consecutive elements, which determines the width of the load instructions.
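
To make the arithmetic concrete, here is a minimal standalone C++ sketch (illustrative only, not the PR's code; all names are made up) of how the per-CTA tile shape falls out of the three layout parameters:

#include <cstdio>
#include <vector>

// Per dimension d, a CTA tile covers
//   sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d]
// elements of the tensor.
std::vector<unsigned> ctaTileShape(const std::vector<unsigned> &sizePerThread,
                                   const std::vector<unsigned> &threadsPerWarp,
                                   const std::vector<unsigned> &warpsPerCTA) {
  std::vector<unsigned> tile(sizePerThread.size());
  for (size_t d = 0; d < tile.size(); ++d)
    tile[d] = sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d];
  return tile;
}

int main() {
  auto orig = ctaTileShape({1, 16}, {8, 8}, {8, 1});    // {64, 128}
  auto sliced = ctaTileShape({1, 16}, {32, 2}, {8, 1}); // {256, 32}
  std::printf("%ux%u vs. %ux%u\n", orig[0], orig[1], sliced[0], sliced[1]);
}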

Member

Thanks for the answer, that's really interesting! A follow-up question: if we change which part of the tensor the CTA holds, wouldn't we need an extra global read (or a potentially fused but wider global read) to get that part of the tensor?

Author (ravil-mobile), May 19, 2025

ConvertLayout results in some machine ops if it cannot be fully optimized away (i.e., propagated to the top of a function). If I remember correctly, the layout change is going to happen in LDS. In our case, it is going to be optimized away.

Member

Nice! I didn't realize this was the pre-layout-propagation phase, but this makes a lot of sense, thanks :)

Comment on lines +927 to +931
builder
    .create<ttg::ConvertLayoutOp>(tensor.getLoc(), newType, tensor)
    .getResult();
slicedTensorType =
    RankedTensorType::get(slicedShape, elemType, newEncoding);
Member

Another silly Q: I know that amdgpu.extract_slice src and dst layouts need to match. Is the main purpose of ttg.convert_layout here to set up the layout for ops like extract_slice, where the layout is sure to change? On top of that, does that mean a ttg.convert_layout for slices will most likely not have a layout matching its shape?

Author

The dst layout is determined by the source layout. The problem comes from the following. Let's assume we want to slice a 256x128 tensor into 4 pieces of 256x32 tiles. Let's also assume that the original layout is

# orig-layout
#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>

which holds a 64x128 tile per CTA. ExtractSliceOp is a CTA-level op, i.e., it can slice a tensor only if the new size is proportional to the CTA tile (64x128 in our case). Therefore, we cannot apply ExtractSliceOp to our tensor with orig-layout. Thus, we change the source layout to a new one, i.e.,

# new-layout
#blocked3 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [8, 1], order = [1, 0]}>

which has its CTA-level tile equal to 256x32. Now we can slice the 256x128 tensor into 4 pieces. (Note: the 256x32 size was determined by dot slicing.)
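
A minimal sketch of the legality check implied above (hypothetical helper, not the PR's code): ExtractSliceOp can only produce slices whose shape is a whole-number multiple of the CTA tile, which is why the 64x128 tile forces a layout change before a 256x32 slice.

#include <vector>

// Returns true if slicing to sliceShape needs a preceding convert_layout,
// i.e. the slice is not a whole-number multiple of the CTA tile.
bool sliceRequiresLayoutChange(const std::vector<unsigned> &sliceShape,
                               const std::vector<unsigned> &ctaTile) {
  for (size_t d = 0; d < sliceShape.size(); ++d)
    if (sliceShape[d] % ctaTile[d] != 0)
      return true;
  return false;
}

// With orig-layout the CTA tile is {64, 128}: a {256, 32} slice fails since
// 32 % 128 != 0. With new-layout the tile is {256, 32}, and the slice is
// exactly one CTA tile, so no further layout change is needed.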

Comment on lines +919 to +921
RankedTensorType newType = nullptr;
Value newTensor = nullptr;
RankedTensorType slicedTensorType = nullptr;
Member

NIT: IIRC, RankedTensorType/Type/Value will default to null even if you don't explicitly set nullptr

Author

You are right. I just like to be explicit in the code about the initialization of all local variables (an old habit).
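
For reference, a tiny sketch of the NIT (assuming standard MLIR headers): MLIR's Type and Value wrappers default-construct to null, so the explicit = nullptr is redundant, though harmless.

#include <cassert>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

void defaultsAreNull() {
  mlir::RankedTensorType newType; // the wrapped impl pointer is null by default
  mlir::Value newTensor;          // likewise null
  assert(!newType && !newTensor); // both convert to false until assigned
}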

ravil-mobile force-pushed the shared/triton-gfx950-launch-update branch from 24ae652 to 1b2a86b on May 19, 2025, 09:36
ravil-mobile requested a review from raikonenfnu on May 19, 2025, 09:47
antiagainst and others added 7 commits May 19, 2025 09:39
…ton-lang#6844)

This commit improves how we create the mfma-like layout for
optimizing global stores by using linear layout composition.
Along the way, it fixes a few implementation issues.

---------

Co-authored-by: Yi Qian <[email protected]>
Avoid the transform being wrongly enabled.
Requirements to enable the transform:
mxfp4, 128x128x512 tile size, async_copy, num_stages=2, num_warps=8
ravil-mobile force-pushed the shared/triton-gfx950-launch-update branch from 67b292c to 3167930 on May 21, 2025, 10:48
ravil-mobile force-pushed the shared/triton-gfx950-launch-update branch from 3167930 to a89b3b4 on May 21, 2025, 14:33
plognjen and others added 3 commits May 21, 2025 14:51
The "concat" operation combines a list of source n-dimensional tensors
into a single larger destination tensor.

All source tensors must have the same shape, element type, and encoding.
The concatenation dimension is inferred from the source and destination
shapes provided by the user.
For example, two tensors of shape 64x128 can produce a destination shape
of 128x128,
indicating concatenation along dimension 0; or 64x256, indicating
concatenation along dimension 1.

Generally, source tensors passed as op arguments can be arranged into
the resulting shape in multiple ways.
For example, given four tensors of shape 64x64:
  concat s0<64x64>, s1<64x64>, s2<64x64>, s3<64x64> -> <128x128>

They can be laid out in different configurations within the result
tensor:
   1) s0 s1 
       s2 s3  

   2) s0 s2
        s1 s3

From a logical tensor perspective, the source tensors are treated as
elements of a tensor of tensors.
In other words, the 1-D array of input tensors is conceptually reshaped
into an n-D grid.
The semantics of this op assume a row-major order (or its n-D
generalization),
meaning the fastest-varying dimension is filled first, and the
slowest-varying dimension is filled last.
In the example above, this corresponds to layout 1).

The source and destination tensors must have identical linear layouts at
the CTA tile level.
That is, all base vectors for input dimensions must match, except for
the register input dimension.
The register basis must align on the subset that defines the logical
tensor shape of a single CTA tile.

This ensures that the concatenation is a no-op, meaning no data
rearrangement among threads is required
to assemble the destination tensor with the given shape and layout.
However, the order of CTA tiles within the layout does not need to match
between source and destination layouts.
It is the responsibility of the op's lowering logic to handle this
correctly.

This op is designed to work on logical tensors directly, avoiding the
need for complex layout reinterpretation or reshaping.
For example, the `tt.join` operation only supports concatenation along
the innermost dimension,
and requires that the resulting innermost dimension provide 2 elements
per thread, distributed across registers.
In contrast, this `concat` op imposes no constraints on the
concatenation dimension or the size of dimensions.
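
As an illustration of the row-major placement described above, here is a small sketch (hypothetical helper names, not the op's actual lowering) that computes each source tensor's element offset in the destination by unflattening its linear index over the grid of sources:

#include <vector>

// The grid of sources along dimension d is dstShape[d] / srcShape[d].
// Unflatten srcIdx in row-major order (fastest-varying dimension last)
// and scale by the source shape to get element offsets.
std::vector<unsigned> srcOffsetInDst(unsigned srcIdx,
                                     const std::vector<unsigned> &srcShape,
                                     const std::vector<unsigned> &dstShape) {
  size_t rank = srcShape.size();
  std::vector<unsigned> offset(rank);
  for (size_t d = rank; d-- > 0;) {
    unsigned grid = dstShape[d] / srcShape[d];
    offset[d] = (srcIdx % grid) * srcShape[d];
    srcIdx /= grid;
  }
  return offset;
}

// Four 64x64 sources into a 128x128 destination: indices 0..3 map to
// offsets {0,0}, {0,64}, {64,0}, {64,64}, i.e. layout 1): s0 s1 / s2 s3.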

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>
ravil-mobile force-pushed the shared/triton-gfx950-launch-update branch from 20f8e72 to 34538bc on May 26, 2025, 13:37
antiagainst force-pushed the shared/triton-gfx950-launch branch from 77c00fa to a259f0a on May 26, 2025, 17:58