[WIP] AITER integration #185

Open
wants to merge 33 commits into base: dev
Commits (33)
17e9a42
update
slippedJim May 12, 2025
bcbc614
update top CMakeLists & update fused_attn_aiter.cpp
slippedJim May 12, 2025
4d682fd
update: replace ck api
slippedJim May 13, 2025
2da3d45
update: fix header & CMakeLists
slippedJim May 13, 2025
af3bf66
update
slippedJim May 13, 2025
6f4a914
update
slippedJim May 13, 2025
a303a28
update
slippedJim May 13, 2025
fc2f92c
update
slippedJim May 13, 2025
0d40d76
typo
slippedJim May 13, 2025
9f07aea
typo
slippedJim May 13, 2025
bf37cbf
update
slippedJim May 13, 2025
df5219f
update
slippedJim May 13, 2025
0b0cfeb
update
slippedJim May 14, 2025
e4e053b
update 3rdparty/aiter
slippedJim May 14, 2025
c02df6d
fix: mha_bwd bias
slippedJim May 14, 2025
21eb2e1
update
slippedJim May 15, 2025
4367bdd
update
slippedJim May 15, 2025
77aa9c1
update
slippedJim May 15, 2025
b79547b
update aiter
slippedJim May 15, 2025
ea1097f
update 3rdparty
slippedJim May 15, 2025
5acc481
update 3rdparty
slippedJim May 15, 2025
fe6d2dd
update fa build flags
slippedJim May 15, 2025
c10b86c
[ROCm] take ck bf16_cvt from upstream in building
wangye805 May 16, 2025
067f07e
update for fa fwd asm
slippedJim May 19, 2025
8f712ec
update CMakeLists
slippedJim May 19, 2025
43fbb14
update: fix CMakeLists && update aiter
slippedJim May 19, 2025
8f3c6c5
[ROCm] limit fwd v3 asm on BSHD_BSHD_BSHD only
wangye805 May 19, 2025
5a6a9ef
update aiter: enable native bwd hd128 asm kernels
slippedJim May 29, 2025
9f91a4d
update aiter: unused variable
slippedJim May 29, 2025
2329940
update test & benchmark script
slippedJim May 29, 2025
280092c
[ROCm] update aiter commit and use RTNE by default in gfx950
wangye805 Jun 2, 2025
5a4db15
[ROCm] restore pytorch/fused_attn tests
wangye805 Jun 2, 2025
6b079c9
[ROCm] remove Werror for ck_fused_attn compiling due to amdgpu-waves-…
wangye805 Jun 20, 2025
6 changes: 3 additions & 3 deletions .gitmodules
@@ -14,9 +14,9 @@
[submodule "3rdparty/aotriton"]
path = 3rdparty/aotriton
url = https://github.com/ROCm/aotriton.git
[submodule "3rdparty/composable_kernel"]
path = 3rdparty/composable_kernel
url = https://github.com/ROCm/composable_kernel.git
[submodule "3rdparty/aiter"]
path = 3rdparty/aiter
url = https://github.com/ROCm/aiter.git
[submodule "examples/pytorch/nanogpt"]
path = examples/pytorch/nanogpt
url = https://github.com/floraamd/nanoGPTwTE.git
1 change: 1 addition & 0 deletions 3rdparty/aiter
Submodule aiter added at 07ddac
1 change: 0 additions & 1 deletion 3rdparty/composable_kernel
Submodule composable_kernel deleted from 4c0781
19 changes: 4 additions & 15 deletions README.rst
@@ -242,25 +242,14 @@ NVTE_FUSED_ATTN=0 will use the TE unfused attention even if NVTE_FUSED_ATTN_CK o
Fused attention backends are selected by matching the actual problem configuration against the support matrix of each backend.
When both backends are enabled and both match the problem configuration, the CK backend is chosen with higher priority.
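
For illustration only, a minimal Python sketch of the priority rule above (this is not Transformer Engine's actual dispatch code; the environment variable names are the ones documented in this README, and treating both backends as enabled by default is an assumption):

.. code-block:: python

    # Sketch of the backend priority described above -- not TE's real dispatch logic.
    # Assumes NVTE_FUSED_ATTN_CK / NVTE_FUSED_ATTN_AOTRITON default to enabled ("1").
    import os

    def pick_fused_attn_backend(ck_supports_config, aotriton_supports_config):
        ck_enabled = os.getenv("NVTE_FUSED_ATTN_CK", "1") == "1"
        aotriton_enabled = os.getenv("NVTE_FUSED_ATTN_AOTRITON", "1") == "1"
        # CK takes priority when both backends are enabled and both match the config.
        if ck_enabled and ck_supports_config:
            return "ck"
        if aotriton_enabled and aotriton_supports_config:
            return "aotriton"
        return None  # fall back to unfused attention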

FA v3 Backward Kernels in CK Backend
FA v3 Kernels in CK Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ROCm TE provides experimental support for flash-attention v3 bwd kernels using the ck backend for limited fused attention configs (currently only for hdim=128).
ROCm TE provides experimental support for flash-attention v3 fwd/bwd kernels using the ck backend for limited fused attention configs.
To enable FA v3 kernels, the following environment variables can be used:

* NVTE_CK_USES_FWD_V3 - by default 0, if set to 1, some cases will call the fwd v3 kernel;
* NVTE_CK_USES_BWD_V3 - by default 0, if set to 1, some cases will call the bwd v3 dqdkdv kernel;
* NVTE_CK_IS_V3_ATOMIC_FP32 - by default 1, if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) when NVTE_CK_USES_BWD_V3 is set to 1;
* NVTE_CK_HOW_V3_BF16_CVT - by default 1, float to bf16 convert type when bwd_v3 is set to 1, 0:RTNE; 1:RTNA; 2:RTZ.

Float to BFloat16 Conversion in CK Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
How fp32 converts to bf16 affects both the performance and accuracy in ck fused attn.
ROCm TE provides the compile-time env NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT with the following values available to choose from:

* 0 - standard;
* 1 - truncate with nan;
* 2 - truncate;
* 3 - standard asm, default;
* 4 - rta_asm.
* NVTE_CK_IS_V3_ATOMIC_FP32 - by default 1, if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) in bwd pass when NVTE_CK_USES_BWD_V3 is set to 1;
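
A minimal sketch of opting into these kernels from Python, mirroring what the benchmark script in this PR does (the specific values shown are only an example; the defaults are those documented above):

.. code-block:: python

    import os

    # Opt into the experimental FA v3 kernels via the environment variables
    # documented above; set them before running the attention module, as the
    # benchmark script in this PR does.
    os.environ["NVTE_FUSED_ATTN_CK"] = "1"         # use the CK fused-attention backend
    os.environ["NVTE_CK_USES_FWD_V3"] = "1"        # eligible cases call the fwd v3 kernel
    os.environ["NVTE_CK_USES_BWD_V3"] = "1"        # eligible cases call the bwd v3 dqdkdv kernel
    os.environ["NVTE_CK_IS_V3_ATOMIC_FP32"] = "1"  # keep fp32 atomics (the default)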

Experimental Triton Kernels on ROCm
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
49 changes: 27 additions & 22 deletions benchmarks/attention/benchmark_attention_rocm.py
@@ -9,7 +9,7 @@
import pandas as pd
import numpy as np
import torch
import nvtx
# import nvtx
import transformer_engine
from transformer_engine_torch import NVTE_Fused_Attn_Backend

@@ -45,32 +45,32 @@

model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
"test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
"test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
# "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
# "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
# "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(8, 64, 8, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
}

# Define DataFrame indices and columns
indices = [model for model in model_configs.keys()]
columns = [
"FusedAttention Module",
"FusedAttention Kernels (fwd)",
"FusedAttention Kernels (bwd)",
"FusedAttention Kernels (fwd+bwd)",
"FlashAttention Module",
"FlashAttention Kernels (fwd)",
"FlashAttention Kernels (bwd)",
"FlashAttention Kernels (fwd+bwd)",
"Fused vs Flash Kernels Speedup (fwd+bwd)",
# "FusedAttention Module",
# "FusedAttention Kernels (fwd)",
# "FusedAttention Kernels (bwd)",
# "FusedAttention Kernels (fwd+bwd)",
# "FlashAttention Module",
# "FlashAttention Kernels (fwd)",
# "FlashAttention Kernels (bwd)",
# "FlashAttention Kernels (fwd+bwd)",
# "Fused vs Flash Kernels Speedup (fwd+bwd)",
"FusedAttention CK Module",
"FusedAttention CK Kernels (fwd)",
"FusedAttention CK Kernels (bwd)",
"FusedAttention CK Kernels (fwd+bwd)",
"FusedAttention AOTriton Module",
"FusedAttention AOTriton Kernels (fwd)",
"FusedAttention AOTriton Kernels (bwd)",
"FusedAttention AOTriton Kernels (fwd+bwd)",
# "FusedAttention AOTriton Module",
# "FusedAttention AOTriton Kernels (fwd)",
# "FusedAttention AOTriton Kernels (bwd)",
# "FusedAttention AOTriton Kernels (fwd+bwd)",
]

output_csv="times.csv"
@@ -103,8 +103,8 @@ def benchmark_dot_product_attention(model, attention, column_name, filename):
f"""'{model}', '{attention}', '{column_name}')" """,
]
prof_cmd = " ".join(prof_cmd)
print(prof_cmd)
subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

if os.path.exists("results.stats.csv"):
shutil.move("results.stats.csv", filename)
else:
@@ -128,6 +128,7 @@ def benchmark_dot_product_attention_profiler(model, attention, column_name):
pad_between_seqs,
is_training,
)
print("++++++++++++++RUN+++++++++++++++++++++")
torch.cuda.synchronize()
attn_time = time.time() - attn_start

@@ -219,19 +219,23 @@ def main():
)

filename_flash_attn, filename_fused_attn, filename_fused_ck, filename_fused_aotriton = None, None, None, None
print(fused_attn_backends)
# Benchmark for each attention backend
if flash_attn_supported:
print("===============================================")
filename_flash_attn = os.path.join("profiler_outputs/", f"prof_flash_{model}.csv")
benchmark_dot_product_attention(model, "FlashAttention", "FlashAttention Module", filename_flash_attn)

if fused_attn_supported:
print("============================")
filename_fused_attn = os.path.join("profiler_outputs/", f"prof_fused_{model}.csv")
benchmark_dot_product_attention(model, "FusedAttention", "FusedAttention Module", filename_fused_attn)

if NVTE_Fused_Attn_Backend.NVTE_CK in fused_attn_backends:
#CK Backend
os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
os.environ["NVTE_CK_USES_BWD_V3"] = "1"
os.environ["NVTE_CK_USES_FWD_V3"] = "1"
os.environ["NVTE_FUSED_ATTN_CK"] = "1"
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
@@ -260,11 +260,11 @@ def main():
a = df_times[
[
"FusedAttention Kernels (fwd+bwd)",
"FlashAttention Kernels (fwd+bwd)",
"Fused vs Flash Kernels Speedup (fwd+bwd)",
# "FlashAttention Kernels (fwd+bwd)",
# "Fused vs Flash Kernels Speedup (fwd+bwd)",
]
]
a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"]
# a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"]
print()
print(a)

1 change: 0 additions & 1 deletion setup.py
@@ -82,7 +82,6 @@ def setup_common_extension() -> CMakeExtension:
if os.getenv("NVTE_AOTRITON_PATH"):
aotriton_path = Path(os.getenv("NVTE_AOTRITON_PATH"))
cmake_flags.append(f"-DAOTRITON_PATH={aotriton_path}")
cmake_flags.append(f"-DCK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT={os.getenv('NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT', 3)}")
if os.getenv("NVTE_CK_FUSED_ATTN_PATH"):
ck_path = Path(os.getenv("NVTE_CK_FUSED_ATTN_PATH"))
cmake_flags.append(f"-DCK_FUSED_ATTN_PATH={ck_path}")
4 changes: 3 additions & 1 deletion transformer_engine/common/CMakeLists.txt
@@ -12,6 +12,8 @@ option(USE_ROCBLAS "Use ROCBLAS" ON)
option(USE_FUSED_ATTN_AOTRITON "Use aotriton backend" ON)
option(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS "Build AOTriton GPU kernels" OFF)
option(USE_FUSED_ATTN_CK "Use ck backend" ON)
# TODO: fix aiter build path
option(USE_FUSED_ATTN_AITER "Use aiter backend" ON)
set(USE_CUDA OFF)

if (USE_ROCM)
@@ -103,6 +105,7 @@ if(USE_ROCM)
message(STATUS "CMAKE_HIP_ARCHITECTURES: ${CMAKE_HIP_ARCHITECTURES}")
message(STATUS "USE_HIPBLASLT ${USE_HIPBLASLT} USE_ROCBLAS ${USE_ROCBLAS}")
message(STATUS "USE_FUSED_ATTN_CK ${USE_FUSED_ATTN_CK}")
message(STATUS "USE_FUSED_ATTN_AITER ${USE_FUSED_ATTN_AITER}")
message(STATUS "USE_FUSED_ATTN_AOTRITON ${USE_FUSED_ATTN_AOTRITON}")
endif()

@@ -338,7 +341,6 @@ else()

if(USE_FUSED_ATTN_CK)
if(NOT DEFINED CK_FUSED_ATTN_PATH)
set(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT ${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} CACHE STRING "ck float to bf16 conversion rounding")
set(CK_FUSED_ATTN_TARGET_GPUS ${CMAKE_HIP_ARCHITECTURES} CACHE STRING "Target arch to compile ck fused attn backend")
add_subdirectory(ck_fused_attn ${CMAKE_CURRENT_BINARY_DIR}/ck_fused_attn)
else()
111 changes: 96 additions & 15 deletions transformer_engine/common/ck_fused_attn/CMakeLists.txt
@@ -1,45 +1,87 @@
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT

# This file is for temporary use
cmake_minimum_required(VERSION 3.21)
set(CMAKE_CXX_STANDARD 17)
project(ck_fused_attn LANGUAGES HIP CXX)
# generate ck fused attn kernels, both fwd/bwd
# generate fused attn kernels, both fwd/bwd

set(CK_FUSED_ATTN_TARGET_GPUS "gfx942,gfx950" CACHE STRING "Target Architecture to build ck_fused_attn backend")
set(CK_FUSED_ATTN_TARGET_GPUS "gfx950" CACHE STRING "Target Architecture to build ck_fused_attn backend")

# remove files that should be regenerated
file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp ${CMAKE_CURRENT_BINARY_DIR}/gen_src/blob_list.txt)

# create gen_src and gen_src/tmp directories if needed
file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp)

set(__CK_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/composable_kernel")
set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/aiter")
set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel")

#fwd kernels list
execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_blob_list.txt
--api fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_blob_list.txt --receipt 600
)
execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_splitkv_blob_list.txt --receipt 600
)
execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api batch_prefill --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_batch_prefill_blob_list.txt --receipt 600
)

#bwd kernels list
execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/bwd_blob_list.txt --receipt 5
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/gen_src/bwd_blob_list.txt --receipt 600
)

file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_splitkv_blob_list.txt FMHA_FWD_SPLITKV_GEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/fwd_batch_prefill_blob_list.txt FMHA_FWD_BATCH_PREFILL_GEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gen_src/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)

# generate the actual fwd kernel cpp files
execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp
--api fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600
)

execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600
)

execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api batch_prefill --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600
)

execute_process(
COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/cpp_itfs/mha_fwd_generate.py
--output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 5
)

execute_process(
COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/py_itfs_cu/fmha_v3_fwd_kernel_generate.py
--output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp
)

# generate the actual bwd kernel cpp files
execute_process(
COMMAND python3 ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 5
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp --receipt 600
)

execute_process(
COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/py_itfs_cu/fmha_v3_bwd_kernel_generate.py
--receipt 1 --filter *@*_ndeterministic@*_nbias*_dropout*_ndeterministic* --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp
)

execute_process(
COMMAND python3 ${__AITER_SOURCE_DIR}/csrc/cpp_itfs/mha_bwd_generate.py
--receipt 3 --output_dir ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp
)

set(ck_fused_attn_SOURCES)
@@ -48,33 +48,61 @@ list(APPEND ck_fused_attn_SOURCES
src/ck_fused_attn_bwd.cpp
src/ck_fused_attn_utils.cpp)

# Update all new and modified kernels and add to ck_fused_attn
foreach(blob ${FMHA_FWD_GEN_BLOBS})
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
endforeach()
list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_GEN_BLOBS})

foreach(blob ${FMHA_FWD_SPLITKV_GEN_BLOBS})
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
endforeach()
list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_SPLITKV_GEN_BLOBS})

foreach(blob ${FMHA_FWD_BATCH_PREFILL_GEN_BLOBS})
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
endforeach()
list(APPEND ck_fused_attn_SOURCES ${FMHA_FWD_BATCH_PREFILL_GEN_BLOBS})

foreach(blob ${FMHA_BWD_GEN_BLOBS})
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${blob})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${blob} ONLY_IF_DIFFERENT)
endforeach()
list(APPEND ck_fused_attn_SOURCES ${FMHA_BWD_GEN_BLOBS})

set(MHA_BWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/mha_bwd.cpp")
set(MHA_FWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/mha_fwd.cpp")
set(ASM_MHA_BWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/asm_fmha_bwd_v3.cpp")
set(ASM_MHA_FWD_SRC "${CMAKE_CURRENT_BINARY_DIR}/gen_src/asm_fmha_fwd_v3.cpp")

# TODO: other generated files need to be added to `ck_fused_attn_SOURCES`
file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${MHA_BWD_SRC})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${MHA_BWD_SRC} ONLY_IF_DIFFERENT)

file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${MHA_FWD_SRC})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${MHA_FWD_SRC} ONLY_IF_DIFFERENT)

file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${ASM_MHA_BWD_SRC})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${ASM_MHA_BWD_SRC} ONLY_IF_DIFFERENT)

file(RELATIVE_PATH blob_path ${CMAKE_CURRENT_BINARY_DIR}/gen_src ${ASM_MHA_FWD_SRC})
file(COPY_FILE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp/${blob_path} ${ASM_MHA_FWD_SRC} ONLY_IF_DIFFERENT)

list(APPEND ck_fused_attn_SOURCES ${MHA_BWD_SRC} ${MHA_FWD_SRC} ${ASM_MHA_BWD_SRC} ${ASM_MHA_FWD_SRC})

# remove all previously generated temporary files
file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/gen_src/tmp)

# Glob all hsaco .cpp files and append to ck_fused_attn
file(GLOB_RECURSE CK_HSACO_FILES "${__CK_SOURCE_DIR}/example/ck_tile/01_fmha/hsaco/*.cpp")
list(APPEND ck_fused_attn_SOURCES ${CK_HSACO_FILES})

message(STATUS "Found the following CK fused attention files:")
message(STATUS "Found the following fused attention files:")
foreach(file ${ck_fused_attn_SOURCES})
message(STATUS " ${file}")
endforeach()

add_library(ck_fused_attn STATIC ${ck_fused_attn_SOURCES})
set(CK_FUSED_ATTN_COMPILE_OPTIONS)
list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -DCK_TILE_FMHA_FWD_SPLITKV_API=0 -DCK_TILE_FMHA_FWD_APPENDKV_API=0 -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} -fgpu-flush-denormals-to-zero -Wno-float-equal -ftemplate-backtrace-limit=0 -fPIC -Wno-gnu-line-marker -Wunused-variable -Wuninitialized -Werror)
list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -DCK_TILE_FMHA_FWD_SPLITKV_API=1 -DCK_TILE_FMHA_FWD_APPENDKV_API=0 -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=0 -fgpu-flush-denormals-to-zero -Wno-float-equal -ftemplate-backtrace-limit=0 -fPIC -Wno-gnu-line-marker -Wunused-variable -Wuninitialized "SHELL:-mllvm -enable-post-misched=0" "SHELL:-mllvm -amdgpu-early-inline-all=true" "SHELL:-mllvm -amdgpu-function-calls=false" "SHELL:-mllvm -amdgpu-coerce-illegal-types=1" "SHELL:-mllvm --amdgpu-kernarg-preload-count=16")
foreach(rocm_arch ${CK_FUSED_ATTN_TARGET_GPUS})
list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS --offload-arch=${rocm_arch})
endforeach()
@@ -89,8 +159,19 @@ if(NOT EXISTS "${CK_INCLUDE_DIR}")
"within the Transformer Engine source.")
endif()

set(AITER_INCLUDE_DIR "${__AITER_SOURCE_DIR}/csrc/include")
message(STATUS "aiter_include_dir: ${AITER_INCLUDE_DIR}")

if(NOT EXISTS "${AITER_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find AITER API. "
"Try running 'git submodule update --init --recursive' "
"within the Transformer Engine source.")
endif()

target_include_directories(ck_fused_attn PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")
target_include_directories(ck_fused_attn PRIVATE ${CK_INCLUDE_DIR} ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha)
target_include_directories(ck_fused_attn PRIVATE ${AITER_INCLUDE_DIR})

find_package(hip)
list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64)