From 7edaad893d6f3876672877129a45f4e219eeeed6 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 19 Feb 2026 17:21:56 -0600 Subject: [PATCH 01/10] Update ck_fused_attn logging to direct to thread-specific files --- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 362 ++++++++++-------- .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 180 +++++---- .../ck_fused_attn/src/ck_fused_attn_utils.hpp | 26 +- 3 files changed, 314 insertions(+), 254 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 3f51b96b6..3d04ead5e 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -16,6 +16,22 @@ namespace ck_fused_attn{ +// We want to cache and reuse the log stream so we use thread_local here. +namespace { +std::ofstream* get_bwd_log_stream() { + thread_local std::ofstream log_file; + thread_local bool attempted = false; + if (!attempted) { + attempted = true; + open_ck_fused_attn_log_file(log_file, "ck_fused_attn_bwd"); + } + if (!log_file.is_open()) { + return nullptr; + } + return &log_file; +} +} // namespace + // TODO: unify with binary search in TE/common/fused_attn(rocm)/util // no device std::upper_bound // in an increasing array with given size len, search for the index that: @@ -346,110 +362,104 @@ void log_bwd_config(const char* func_name, const bool is_v3_atomic_fp32, const int how_v3_bf16_cvt, const fmha_bwd_args& fmha_args){ - - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } - if (ck_fused_attn_log_config) { - std::cout<::type>(mask_type)<::type>(bias_type)<::type>(mask_type) << "\n"; + *log_file << "bias_type: " << static_cast::type>(bias_type) << "\n"; + *log_file << "has_dbias: " << has_dbias << "\n"; + *log_file << "has_dropout: " << has_dropout << "\n"; + *log_file << "is_store_randval: " << is_store_randval << "\n"; + *log_file << "is_deterministic: " << is_deterministic << "\n"; + *log_file << "uses_bwd_v3: " << uses_bwd_v3 << "\n"; + *log_file << "is_v3_atomic_fp32: " << is_v3_atomic_fp32 << "\n"; + *log_file << "how_v3_bf16_cvt: " << how_v3_bf16_cvt << "\n"; // fmha_args debug - std::cout<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset)) << "\n"; + *log_file << "dropout_offset_ptr: " << std::get<1>(std::get>(fmha_args.drop_seed_offset)) << "\n"; } } @@ -531,7 +541,7 @@ hipError_t ck_attn_bwd( mask_enum mask_type = static_cast(attn_mask_type); bool ck_fused_attn_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") + if (env_p != nullptr && std::string(env_p) != "") ck_fused_attn_log_config = true; } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); @@ -708,17 +718,19 @@ hipError_t ck_attn_bwd( if (d_qk == d_v) { dim3 block(d_qk); if (ck_fused_attn_log_config){ - std::cout<(dbias_ptr));); }else if(bias_shape==BiasShape::k1HSS){ if (ck_fused_attn_log_config){ - std::cout<(dbias_ptr));); }else if(bias_shape==BiasShape::kB1SS){ if (ck_fused_attn_log_config){ - std::cout<::type>(mask_type)<::type>(bias_type)<::type>(mask_type) << "\n"; + *log_file << "bias_type: " << static_cast::type>(bias_type) << "\n"; + *log_file << "has_lse: " << has_lse << "\n"; + *log_file << "has_dropout: " << has_dropout << "\n"; + *log_file << "do_fp8_static_quant: " << do_fp8_static_quant << "\n"; + *log_file << "skip_min_seqlen_q: " << (fmha_args.min_seqlen_q != 0) << "\n"; + *log_file << "uses_fwd_v3: " << uses_fwd_v3 << "\n"; + *log_file << "how_v3_bf16_cvt: " << how_v3_bf16_cvt << "\n"; // debug fmha_args - std::cout<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset)) << "\n"; + *log_file << "dropout_offset_ptr: " << std::get<1>(std::get>(fmha_args.drop_seed_offset)) << "\n"; } } @@ -181,7 +191,7 @@ hipError_t ck_attn_fwd( bool ck_fused_attn_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") + if (env_p != nullptr && std::string(env_p) != "") ck_fused_attn_log_config = true; } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); @@ -356,7 +366,7 @@ hipError_t ck_attn_varlen_fwd( bool ck_fused_attn_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") + if (env_p != nullptr && std::string(env_p) != "") ck_fused_attn_log_config = true; } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index a75915ee2..b0ba19b08 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -7,8 +7,12 @@ #ifndef CK_FUSED_ATTN_UTILS_H #define CK_FUSED_ATTN_UTILS_H -#include -#include +#include +#include +#include +#include +#include +#include #include //forward declaration for ck_tile enum @@ -56,5 +60,23 @@ std::pair get_ck_bias_type_shape(BiasType attn_bias_type, uint64_t get_runtime_max_seqlen(uint64_t b, const void* cu_seqlen_ptr, const void* cu_seqlen_padded_ptr, void* workspace, hipStream_t stream); +inline bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefix) { + const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG"); + if (env_p == nullptr) { + return false; + } + const std::string log_dir_str(env_p); + if (log_dir_str.empty() || log_dir_str == "0") { + return false; + } + std::filesystem::path log_dir(log_dir_str); + std::error_code ec; + std::filesystem::create_directories(log_dir, ec); + std::ostringstream filename; + filename << file_prefix << "_" << getpid() << "_" << std::this_thread::get_id() << ".log"; + log_file.open(log_dir / filename.str(), std::ios_base::app); + return log_file.is_open(); +} + }//namespace ck_fused_attn #endif // CK_FUSED_ATTN_UTILS_H From 13920f8012f11240a7bb017653035480ce2a9d41 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 19 Feb 2026 17:23:42 -0600 Subject: [PATCH 02/10] Added error logging --- .../common/ck_fused_attn/src/ck_fused_attn_utils.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index b0ba19b08..10cd28bfd 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -72,6 +72,10 @@ inline bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* fil std::filesystem::path log_dir(log_dir_str); std::error_code ec; std::filesystem::create_directories(log_dir, ec); + if(ec){ + std::cerr << "Failed to create log directory: " << log_dir_str << ", error: " << ec.message() << std::endl; + return false; + } std::ostringstream filename; filename << file_prefix << "_" << getpid() << "_" << std::this_thread::get_id() << ".log"; log_file.open(log_dir / filename.str(), std::ios_base::app); From d33b4994759482a714316b35710556239a823cf4 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 19 Feb 2026 17:26:09 -0600 Subject: [PATCH 03/10] Moved function body out of header --- .../ck_fused_attn/src/ck_fused_attn_utils.cpp | 22 +++++++++++++++++++ .../ck_fused_attn/src/ck_fused_attn_utils.hpp | 22 +------------------ 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp index 26c92ca2b..86abf4ffe 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp @@ -13,6 +13,28 @@ namespace ck_fused_attn{ +bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefix) { + const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG"); + if (env_p == nullptr) { + return false; + } + const std::string log_dir_str(env_p); + if (log_dir_str.empty() || log_dir_str == "0") { + return false; + } + std::filesystem::path log_dir(log_dir_str); + std::error_code ec; + std::filesystem::create_directories(log_dir, ec); + if(ec){ + std::cerr << "Failed to create log directory: " << log_dir_str << ", error: " << ec.message() << std::endl; + return false; + } + std::ostringstream filename; + filename << file_prefix << "_" << getpid() << "_" << std::this_thread::get_id() << ".log"; + log_file.open(log_dir / filename.str(), std::ios_base::app); + return log_file.is_open(); +} + std::string get_data_type_str(DType dtype){ std::string data_type_str; if(dtype==DType::kFloat16){ diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index 10cd28bfd..6b1dbe711 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -60,27 +60,7 @@ std::pair get_ck_bias_type_shape(BiasType attn_bias_type, uint64_t get_runtime_max_seqlen(uint64_t b, const void* cu_seqlen_ptr, const void* cu_seqlen_padded_ptr, void* workspace, hipStream_t stream); -inline bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefix) { - const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG"); - if (env_p == nullptr) { - return false; - } - const std::string log_dir_str(env_p); - if (log_dir_str.empty() || log_dir_str == "0") { - return false; - } - std::filesystem::path log_dir(log_dir_str); - std::error_code ec; - std::filesystem::create_directories(log_dir, ec); - if(ec){ - std::cerr << "Failed to create log directory: " << log_dir_str << ", error: " << ec.message() << std::endl; - return false; - } - std::ostringstream filename; - filename << file_prefix << "_" << getpid() << "_" << std::this_thread::get_id() << ".log"; - log_file.open(log_dir / filename.str(), std::ios_base::app); - return log_file.is_open(); -} +bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefix); }//namespace ck_fused_attn #endif // CK_FUSED_ATTN_UTILS_H From 85a52f7e642317e931579090ba1150848bb3affa Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 20 Feb 2026 14:43:18 -0600 Subject: [PATCH 04/10] Removed risky dir-create and streamlined header --- .../common/ck_fused_attn/src/ck_fused_attn_utils.cpp | 10 ++++------ .../common/ck_fused_attn/src/ck_fused_attn_utils.hpp | 4 ---- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp index 86abf4ffe..f601bf060 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp @@ -5,6 +5,10 @@ ************************************************************************/ #include +#include +#include +#include +#include #include "ck_fused_attn_utils.hpp" #include "ck_fused_attn/ck_fused_attn.hpp" #include "mask.hpp" @@ -23,12 +27,6 @@ bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefi return false; } std::filesystem::path log_dir(log_dir_str); - std::error_code ec; - std::filesystem::create_directories(log_dir, ec); - if(ec){ - std::cerr << "Failed to create log directory: " << log_dir_str << ", error: " << ec.message() << std::endl; - return false; - } std::ostringstream filename; filename << file_prefix << "_" << getpid() << "_" << std::this_thread::get_id() << ".log"; log_file.open(log_dir / filename.str(), std::ios_base::app); diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index 6b1dbe711..cac1a0b9d 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -9,10 +9,6 @@ #include #include -#include -#include -#include -#include #include //forward declaration for ck_tile enum From 11184b810d2bdd1ab18a9aa6b9d5473567f7f2b1 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 23 Feb 2026 10:27:28 -0600 Subject: [PATCH 05/10] Minor refactor --- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 207 ++++++++---------- .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 37 ++-- .../ck_fused_attn/src/ck_fused_attn_utils.cpp | 16 +- .../ck_fused_attn/src/ck_fused_attn_utils.hpp | 3 +- 4 files changed, 126 insertions(+), 137 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 3d04ead5e..12a9a9c0f 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -18,16 +18,28 @@ namespace ck_fused_attn{ // We want to cache and reuse the log stream so we use thread_local here. namespace { -std::ofstream* get_bwd_log_stream() { +std::ostream* get_bwd_log_stream() { thread_local std::ofstream log_file; thread_local bool attempted = false; + thread_local bool opened = false; + thread_local bool requested = false; + thread_local std::string log_dir_str; if (!attempted) { attempted = true; - open_ck_fused_attn_log_file(log_file, "ck_fused_attn_bwd"); + if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG")) { + log_dir_str = std::string(env_p); + requested = !log_dir_str.empty() && log_dir_str != "0"; + } + if (requested) { + opened = open_ck_fused_attn_log_file(log_file, "ck_fused_attn_bwd", log_dir_str); + } } - if (!log_file.is_open()) { + if (!requested) { return nullptr; } + if (!opened) { + return &std::cout; + } return &log_file; } } // namespace @@ -539,15 +551,10 @@ hipError_t ck_attn_bwd( right = window_size_right; mask_enum mask_type = static_cast(attn_mask_type); - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) != "") - ck_fused_attn_log_config = true; - } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_bwd_log_stream() != nullptr}; ck_tile::index_t shape_seqlen_q = seqlen_q; ck_tile::index_t shape_seqlen_k = seqlen_k; @@ -717,20 +724,18 @@ hipError_t ck_attn_bwd( dim3 grid(b, s_kv, hg); if (d_qk == d_v) { dim3 block(d_qk); - if (ck_fused_attn_log_config){ - if (auto* log_file = get_bwd_log_stream()) { - *log_file << "\n" << "run dk_dv_reduce: " << "\n"; - *log_file << "dk_expanded_ptr: " << dk_expanded_ptr << "\n"; - *log_file << "dv_expanded_ptr: " << dv_expanded_ptr << "\n"; - *log_file << "stride_b_dkv_expanded: " << stride_b_dk_expanded << "\n"; - *log_file << "stride_h_dkv_expanded: " << stride_h_dk_expanded << "\n"; - *log_file << "stride_s_dkv_expanded: " << stride_s_dk_expanded << "\n"; - *log_file << "dk_ptr: " << dk_ptr << "\n"; - *log_file << "dv_ptr: " << dv_ptr << "\n"; - *log_file << "stride_b_dk: " << stride_b_dk << "\n"; - *log_file << "stride_h_dk: " << stride_h_dk << "\n"; - *log_file << "stride_s_dk: " << stride_s_dk << "\n"; - } + if (auto* log_file = get_bwd_log_stream()) { + *log_file << "\n" << "run dk_dv_reduce: " << "\n"; + *log_file << "dk_expanded_ptr: " << dk_expanded_ptr << "\n"; + *log_file << "dv_expanded_ptr: " << dv_expanded_ptr << "\n"; + *log_file << "stride_b_dkv_expanded: " << stride_b_dk_expanded << "\n"; + *log_file << "stride_h_dkv_expanded: " << stride_h_dk_expanded << "\n"; + *log_file << "stride_s_dkv_expanded: " << stride_s_dk_expanded << "\n"; + *log_file << "dk_ptr: " << dk_ptr << "\n"; + *log_file << "dv_ptr: " << dv_ptr << "\n"; + *log_file << "stride_b_dk: " << stride_b_dk << "\n"; + *log_file << "stride_h_dk: " << stride_h_dk << "\n"; + *log_file << "stride_s_dk: " << stride_s_dk << "\n"; } CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE, hipLaunchKernelGGL( @@ -744,18 +749,16 @@ hipError_t ck_attn_bwd( stride_b_dk, stride_h_dk, stride_s_dk);); } else { dim3 block_dk(d_qk); - if (ck_fused_attn_log_config){ - if (auto* log_file = get_bwd_log_stream()) { - *log_file << "\n" << "run dk_or_dv_reduce on dk: " << "\n"; - *log_file << "dk_expanded_ptr: " << dk_expanded_ptr << "\n"; - *log_file << "stride_b_dk_expanded: " << stride_b_dk_expanded << "\n"; - *log_file << "stride_h_dk_expanded: " << stride_h_dk_expanded << "\n"; - *log_file << "stride_s_dk_expanded: " << stride_s_dk_expanded << "\n"; - *log_file << "dk_ptr: " << dk_ptr << "\n"; - *log_file << "stride_b_dk: " << stride_b_dk << "\n"; - *log_file << "stride_h_dk: " << stride_h_dk << "\n"; - *log_file << "stride_s_dk: " << stride_s_dk << "\n"; - } + if (auto* log_file = get_bwd_log_stream()) { + *log_file << "\n" << "run dk_or_dv_reduce on dk: " << "\n"; + *log_file << "dk_expanded_ptr: " << dk_expanded_ptr << "\n"; + *log_file << "stride_b_dk_expanded: " << stride_b_dk_expanded << "\n"; + *log_file << "stride_h_dk_expanded: " << stride_h_dk_expanded << "\n"; + *log_file << "stride_s_dk_expanded: " << stride_s_dk_expanded << "\n"; + *log_file << "dk_ptr: " << dk_ptr << "\n"; + *log_file << "stride_b_dk: " << stride_b_dk << "\n"; + *log_file << "stride_h_dk: " << stride_h_dk << "\n"; + *log_file << "stride_s_dk: " << stride_s_dk << "\n"; } CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE, hipLaunchKernelGGL( @@ -767,18 +770,16 @@ hipError_t ck_attn_bwd( stride_b_dk, stride_h_dk, stride_s_dk);); dim3 block_dv(d_v); - if (ck_fused_attn_log_config){ - if (auto* log_file = get_bwd_log_stream()) { - *log_file << "\n" << "run dk_or_dv_reduce on dv: " << "\n"; - *log_file << "dv_expanded_ptr: " << dv_expanded_ptr << "\n"; - *log_file << "stride_b_dv_expanded: " << stride_b_dv_expanded << "\n"; - *log_file << "stride_h_dv_expanded: " << stride_h_dv_expanded << "\n"; - *log_file << "stride_s_dv_expanded: " << stride_s_dv_expanded << "\n"; - *log_file << "dv_ptr: " << dv_ptr << "\n"; - *log_file << "stride_b_dv: " << stride_b_dv << "\n"; - *log_file << "stride_h_dv: " << stride_h_dv << "\n"; - *log_file << "stride_s_dv: " << stride_s_dv << "\n"; - } + if (auto* log_file = get_bwd_log_stream()) { + *log_file << "\n" << "run dk_or_dv_reduce on dv: " << "\n"; + *log_file << "dv_expanded_ptr: " << dv_expanded_ptr << "\n"; + *log_file << "stride_b_dv_expanded: " << stride_b_dv_expanded << "\n"; + *log_file << "stride_h_dv_expanded: " << stride_h_dv_expanded << "\n"; + *log_file << "stride_s_dv_expanded: " << stride_s_dv_expanded << "\n"; + *log_file << "dv_ptr: " << dv_ptr << "\n"; + *log_file << "stride_b_dv: " << stride_b_dv << "\n"; + *log_file << "stride_h_dv: " << stride_h_dv << "\n"; + *log_file << "stride_s_dv: " << stride_s_dv << "\n"; } CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE, hipLaunchKernelGGL( @@ -797,12 +798,10 @@ hipError_t ck_attn_bwd( dim3 block(THREADS_PER_BLOCK); dim3 grid(ceil(1.0 * s_q * s_kv/THREADS_PER_BLOCK)); if(bias_shape==BiasShape::k11SS){ - if (ck_fused_attn_log_config){ - if (auto* log_file = get_bwd_log_stream()) { - *log_file << "\n" << "run dbias_reduce_11SS: " << "\n"; - *log_file << "dbias_ptr: " << dbias_ptr << "\n"; - *log_file << "dbias_expanded_ptr: " << dbias_expanded_ptr << "\n"; - } + if (auto* log_file = get_bwd_log_stream()) { + *log_file << "\n" << "run dbias_reduce_11SS: " << "\n"; + *log_file << "dbias_ptr: " << dbias_ptr << "\n"; + *log_file << "dbias_expanded_ptr: " << dbias_expanded_ptr << "\n"; } CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE, hipLaunchKernelGGL( @@ -811,12 +810,10 @@ hipError_t ck_attn_bwd( static_cast(dbias_expanded_ptr), static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::k1HSS){ - if (ck_fused_attn_log_config){ - if (auto* log_file = get_bwd_log_stream()) { - *log_file << "\n" << "run dbias_reduce_1HSS: " << "\n"; - *log_file << "dbias_ptr: " << dbias_ptr << "\n"; - *log_file << "dbias_expanded_ptr: " << dbias_expanded_ptr << "\n"; - } + if (auto* log_file = get_bwd_log_stream()) { + *log_file << "\n" << "run dbias_reduce_1HSS: " << "\n"; + *log_file << "dbias_ptr: " << dbias_ptr << "\n"; + *log_file << "dbias_expanded_ptr: " << dbias_expanded_ptr << "\n"; } CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE, hipLaunchKernelGGL( @@ -825,12 +822,10 @@ hipError_t ck_attn_bwd( static_cast(dbias_expanded_ptr), static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::kB1SS){ - if (ck_fused_attn_log_config){ - if (auto* log_file = get_bwd_log_stream()) { - *log_file << "\n" << "run dbias_reduce_B1SS: " << "\n"; - *log_file << "dbias_ptr: " << dbias_ptr << "\n"; - *log_file << "dbias_expanded_ptr: " << dbias_expanded_ptr << "\n"; - } + if (auto* log_file = get_bwd_log_stream()) { + *log_file << "\n" << "run dbias_reduce_B1SS: " << "\n"; + *log_file << "dbias_ptr: " << dbias_ptr << "\n"; + *log_file << "dbias_expanded_ptr: " << dbias_expanded_ptr << "\n"; } CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE, hipLaunchKernelGGL( @@ -907,14 +902,9 @@ hipError_t ck_attn_varlen_bwd( right = window_size_right; mask_enum mask_type = static_cast(attn_mask_type); - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) != "") - ck_fused_attn_log_config = true; - } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_bwd_log_stream() != nullptr}; std::string data_type_str = get_data_type_str(dtype); @@ -1056,8 +1046,9 @@ hipError_t ck_attn_varlen_bwd( // lse_thd_ptr used as buffer if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { if(std::string(env_p) == "1"){ - if(ck_fused_attn_log_config){ - std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + if (auto* log_file = get_bwd_log_stream()) { + *log_file + << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.\n"; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream); fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream); @@ -1090,20 +1081,18 @@ hipError_t ck_attn_varlen_bwd( dim3 grid(max_tokens_kv, hg); if (d_qk == d_v) { dim3 block(d_qk); - if (ck_fused_attn_log_config){ - if (auto* log_file = get_bwd_log_stream()) { - *log_file << "\n" << "run dk_dv_reduce_thd: " << "\n"; - *log_file << "cu_seqlen_kv_ptr: " << cu_seqlen_kv_ptr << "\n"; - *log_file << "cu_seqlen_kv_padded_ptr: " << cu_seqlen_kv_padded_ptr << "\n"; - *log_file << "dk_expanded_ptr: " << dk_expanded_ptr << "\n"; - *log_file << "dv_expanded_ptr: " << dv_expanded_ptr << "\n"; - *log_file << "stride_h_dkv_expanded: " << stride_h_dk_expanded << "\n"; - *log_file << "stride_s_dkv_expanded: " << stride_s_dk_expanded << "\n"; - *log_file << "dk_ptr: " << dk_ptr << "\n"; - *log_file << "dv_ptr: " << dv_ptr << "\n"; - *log_file << "stride_h_dk: " << stride_h_dk << "\n"; - *log_file << "stride_s_dk: " << stride_s_dk << "\n"; - } + if (auto* log_file = get_bwd_log_stream()) { + *log_file << "\n" << "run dk_dv_reduce_thd: " << "\n"; + *log_file << "cu_seqlen_kv_ptr: " << cu_seqlen_kv_ptr << "\n"; + *log_file << "cu_seqlen_kv_padded_ptr: " << cu_seqlen_kv_padded_ptr << "\n"; + *log_file << "dk_expanded_ptr: " << dk_expanded_ptr << "\n"; + *log_file << "dv_expanded_ptr: " << dv_expanded_ptr << "\n"; + *log_file << "stride_h_dkv_expanded: " << stride_h_dk_expanded << "\n"; + *log_file << "stride_s_dkv_expanded: " << stride_s_dk_expanded << "\n"; + *log_file << "dk_ptr: " << dk_ptr << "\n"; + *log_file << "dv_ptr: " << dv_ptr << "\n"; + *log_file << "stride_h_dk: " << stride_h_dk << "\n"; + *log_file << "stride_s_dk: " << stride_s_dk << "\n"; } CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE, hipLaunchKernelGGL( @@ -1119,18 +1108,16 @@ hipError_t ck_attn_varlen_bwd( stride_h_dk, stride_s_dk);); } else { dim3 block_dk(d_qk); - if (ck_fused_attn_log_config){ - if (auto* log_file = get_bwd_log_stream()) { - *log_file << "\n" << "run dk_or_dv_reduce_thd on dk: " << "\n"; - *log_file << "cu_seqlen_kv_ptr: " << cu_seqlen_kv_ptr << "\n"; - *log_file << "cu_seqlen_kv_padded_ptr: " << cu_seqlen_kv_padded_ptr << "\n"; - *log_file << "dk_expanded_ptr: " << dk_expanded_ptr << "\n"; - *log_file << "stride_h_dk_expanded: " << stride_h_dk_expanded << "\n"; - *log_file << "stride_s_dk_expanded: " << stride_s_dk_expanded << "\n"; - *log_file << "dk_ptr: " << dk_ptr << "\n"; - *log_file << "stride_h_dk: " << stride_h_dk << "\n"; - *log_file << "stride_s_dk: " << stride_s_dk << "\n"; - } + if (auto* log_file = get_bwd_log_stream()) { + *log_file << "\n" << "run dk_or_dv_reduce_thd on dk: " << "\n"; + *log_file << "cu_seqlen_kv_ptr: " << cu_seqlen_kv_ptr << "\n"; + *log_file << "cu_seqlen_kv_padded_ptr: " << cu_seqlen_kv_padded_ptr << "\n"; + *log_file << "dk_expanded_ptr: " << dk_expanded_ptr << "\n"; + *log_file << "stride_h_dk_expanded: " << stride_h_dk_expanded << "\n"; + *log_file << "stride_s_dk_expanded: " << stride_s_dk_expanded << "\n"; + *log_file << "dk_ptr: " << dk_ptr << "\n"; + *log_file << "stride_h_dk: " << stride_h_dk << "\n"; + *log_file << "stride_s_dk: " << stride_s_dk << "\n"; } CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE, hipLaunchKernelGGL( @@ -1144,18 +1131,16 @@ hipError_t ck_attn_varlen_bwd( stride_h_dk, stride_s_dk);); dim3 block_dv(d_v); - if (ck_fused_attn_log_config){ - if (auto* log_file = get_bwd_log_stream()) { - *log_file << "\n" << "run dk_or_dv_reduce_thd on dv: " << "\n"; - *log_file << "cu_seqlen_kv_ptr: " << cu_seqlen_kv_ptr << "\n"; - *log_file << "cu_seqlen_kv_padded_ptr: " << cu_seqlen_kv_padded_ptr << "\n"; - *log_file << "dv_expanded_ptr: " << dv_expanded_ptr << "\n"; - *log_file << "stride_h_dv_expanded: " << stride_h_dv_expanded << "\n"; - *log_file << "stride_s_dv_expanded: " << stride_s_dv_expanded << "\n"; - *log_file << "dv_ptr: " << dv_ptr << "\n"; - *log_file << "stride_h_dv: " << stride_h_dv << "\n"; - *log_file << "stride_s_dv: " << stride_s_dv << "\n"; - } + if (auto* log_file = get_bwd_log_stream()) { + *log_file << "\n" << "run dk_or_dv_reduce_thd on dv: " << "\n"; + *log_file << "cu_seqlen_kv_ptr: " << cu_seqlen_kv_ptr << "\n"; + *log_file << "cu_seqlen_kv_padded_ptr: " << cu_seqlen_kv_padded_ptr << "\n"; + *log_file << "dv_expanded_ptr: " << dv_expanded_ptr << "\n"; + *log_file << "stride_h_dv_expanded: " << stride_h_dv_expanded << "\n"; + *log_file << "stride_s_dv_expanded: " << stride_s_dv_expanded << "\n"; + *log_file << "dv_ptr: " << dv_ptr << "\n"; + *log_file << "stride_h_dv: " << stride_h_dv << "\n"; + *log_file << "stride_s_dv: " << stride_s_dv << "\n"; } CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE, hipLaunchKernelGGL( diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index abb49f3a5..5f231ee2d 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -17,16 +17,28 @@ namespace ck_fused_attn{ namespace { -std::ofstream* get_fwd_log_stream() { +std::ostream* get_fwd_log_stream() { thread_local std::ofstream log_file; thread_local bool attempted = false; + thread_local bool opened = false; + thread_local bool requested = false; + thread_local std::string log_dir_str; if (!attempted) { attempted = true; - open_ck_fused_attn_log_file(log_file, "ck_fused_attn_fwd"); + if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG")) { + log_dir_str = std::string(env_p); + requested = !log_dir_str.empty() && log_dir_str != "0"; + } + if (requested) { + opened = open_ck_fused_attn_log_file(log_file, "ck_fused_attn_fwd", log_dir_str); + } } - if (!log_file.is_open()) { + if (!requested) { return nullptr; } + if (!opened) { + return &std::cout; + } return &log_file; } } // namespace @@ -189,14 +201,9 @@ hipError_t ck_attn_fwd( right = window_size_right; mask_enum mask_type = static_cast(attn_mask_type); - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) != "") - ck_fused_attn_log_config = true; - } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_fwd_log_stream() != nullptr}; std::string data_type_str = get_data_type_str(dtype); @@ -364,14 +371,9 @@ hipError_t ck_attn_varlen_fwd( bias_enum bias_type = bias_enum::no_bias; - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) != "") - ck_fused_attn_log_config = true; - } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_fwd_log_stream() != nullptr}; std::string data_type_str = get_data_type_str(dtype); @@ -467,8 +469,9 @@ hipError_t ck_attn_varlen_fwd( // lse_thd_ptr used as buffer if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){ if(std::string(env_p) == "1"){ - if(ck_fused_attn_log_config){ - std::cout << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + if (auto* log_file = get_fwd_log_stream()) { + *log_file + << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.\n"; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, cu_seqlen_q_padded_ptr, lse_thd_ptr, stream); } diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp index f601bf060..c1361a6f1 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp @@ -17,20 +17,20 @@ namespace ck_fused_attn{ -bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefix) { - const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG"); - if (env_p == nullptr) { - return false; - } - const std::string log_dir_str(env_p); - if (log_dir_str.empty() || log_dir_str == "0") { +bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefix, const std::string& log_dir_str) { + // Explicitly use std::cout as a fallback + if (log_dir_str == "1") { return false; } std::filesystem::path log_dir(log_dir_str); std::ostringstream filename; filename << file_prefix << "_" << getpid() << "_" << std::this_thread::get_id() << ".log"; log_file.open(log_dir / filename.str(), std::ios_base::app); - return log_file.is_open(); + if (!log_file.is_open()) { + std::cerr << "Failed to open log file: " << (log_dir / filename.str()) << "\n"; + return false; + } + return true; } std::string get_data_type_str(DType dtype){ diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index cac1a0b9d..13e3d3c0a 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -9,6 +9,7 @@ #include #include +#include #include //forward declaration for ck_tile enum @@ -56,7 +57,7 @@ std::pair get_ck_bias_type_shape(BiasType attn_bias_type, uint64_t get_runtime_max_seqlen(uint64_t b, const void* cu_seqlen_ptr, const void* cu_seqlen_padded_ptr, void* workspace, hipStream_t stream); -bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefix); +bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefix, const std::string& log_dir_str); }//namespace ck_fused_attn #endif // CK_FUSED_ATTN_UTILS_H From ce12f95ad2a72b81f3222fcc4810eb33eb47039f Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 23 Feb 2026 10:42:10 -0600 Subject: [PATCH 06/10] Copyright --- .../common/ck_fused_attn/src/ck_fused_attn_bwd.cpp | 2 +- .../common/ck_fused_attn/src/ck_fused_attn_fwd.cpp | 2 +- .../common/ck_fused_attn/src/ck_fused_attn_utils.cpp | 2 +- .../common/ck_fused_attn/src/ck_fused_attn_utils.hpp | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 12a9a9c0f..0bd062f6b 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 5f231ee2d..935caa9eb 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp index c1361a6f1..6bbfbda4f 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index 13e3d3c0a..a0ea13d81 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ From a570bf0f8d6f48479e5cc5c1f1d118a07358ad77 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 25 Feb 2026 10:39:57 -0600 Subject: [PATCH 07/10] Streamlined implementation and unified fwd/bwd logging --- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 279 ++++++++---------- .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 234 +++++++-------- .../ck_fused_attn/src/ck_fused_attn_utils.cpp | 25 +- .../ck_fused_attn/src/ck_fused_attn_utils.hpp | 2 +- 4 files changed, 256 insertions(+), 284 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 0bd062f6b..e0b5fd875 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -16,34 +16,6 @@ namespace ck_fused_attn{ -// We want to cache and reuse the log stream so we use thread_local here. -namespace { -std::ostream* get_bwd_log_stream() { - thread_local std::ofstream log_file; - thread_local bool attempted = false; - thread_local bool opened = false; - thread_local bool requested = false; - thread_local std::string log_dir_str; - if (!attempted) { - attempted = true; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG")) { - log_dir_str = std::string(env_p); - requested = !log_dir_str.empty() && log_dir_str != "0"; - } - if (requested) { - opened = open_ck_fused_attn_log_file(log_file, "ck_fused_attn_bwd", log_dir_str); - } - } - if (!requested) { - return nullptr; - } - if (!opened) { - return &std::cout; - } - return &log_file; -} -} // namespace - // TODO: unify with binary search in TE/common/fused_attn(rocm)/util // no device std::upper_bound // in an increasing array with given size len, search for the index that: @@ -361,119 +333,119 @@ __global__ void dbias_reduce_b1ss( } // print the fmha_traits and args passed into ck apis -void log_bwd_config(const char* func_name, - const std::string data_type_str, - const bool is_group_mode, - const mask_enum mask_type, - const bias_enum bias_type, - const bool has_dbias, - const bool has_dropout, - const bool is_store_randval, - const bool is_deterministic, - const bool uses_bwd_v3, - const bool is_v3_atomic_fp32, - const int how_v3_bf16_cvt, - const fmha_bwd_args& fmha_args){ - if (auto* log_file = get_bwd_log_stream()) { - *log_file << "\n" << func_name << "\n"; - - // fmha_traits debug - *log_file << "\n" << "fmha_traits: " << "\n"; - *log_file << "hdim_q: " << fmha_args.hdim_q << "\n"; - *log_file << "hdim_v: " << fmha_args.hdim_v << "\n"; - *log_file << "data_type: " << data_type_str << "\n"; - *log_file << "is_group_mode: " << is_group_mode << "\n"; - *log_file << "mask_type: " << static_cast::type>(mask_type) << "\n"; - *log_file << "bias_type: " << static_cast::type>(bias_type) << "\n"; - *log_file << "has_dbias: " << has_dbias << "\n"; - *log_file << "has_dropout: " << has_dropout << "\n"; - *log_file << "is_store_randval: " << is_store_randval << "\n"; - *log_file << "is_deterministic: " << is_deterministic << "\n"; - *log_file << "uses_bwd_v3: " << uses_bwd_v3 << "\n"; - *log_file << "is_v3_atomic_fp32: " << is_v3_atomic_fp32 << "\n"; - *log_file << "how_v3_bf16_cvt: " << how_v3_bf16_cvt << "\n"; +void log_bwd_config( + std::ostream* log_file, + const char* func_name, + const std::string data_type_str, + const bool is_group_mode, + const mask_enum mask_type, + const bias_enum bias_type, + const bool has_dbias, + const bool has_dropout, + const bool is_store_randval, + const bool is_deterministic, + const bool uses_bwd_v3, + const bool is_v3_atomic_fp32, + const int how_v3_bf16_cvt, + const fmha_bwd_args& fmha_args +){ + *log_file << "\n" << func_name << "\n"; - // fmha_args debug - *log_file << "\n" << "fmha_args: " << "\n"; - *log_file << "q_ptr: " << fmha_args.q_ptr << "\n"; - *log_file << "k_ptr: " << fmha_args.k_ptr << "\n"; - *log_file << "v_ptr: " << fmha_args.v_ptr << "\n"; - *log_file << "bias_ptr: " << fmha_args.bias_ptr << "\n"; - *log_file << "o_ptr: " << fmha_args.o_ptr << "\n"; - *log_file << "lse_ptr: " << fmha_args.lse_ptr << "\n"; - *log_file << "do_ptr: " << fmha_args.do_ptr << "\n"; - *log_file << "d_ptr: " << fmha_args.d_ptr << "\n"; - *log_file << "rand_val_ptr: " << fmha_args.rand_val_ptr << "\n"; - *log_file << "dq_ptr: " << fmha_args.dq_ptr << "\n"; - *log_file << "dk_ptr: " << fmha_args.dk_ptr << "\n"; - *log_file << "dv_ptr: " << fmha_args.dv_ptr << "\n"; - *log_file << "dbias_ptr: " << fmha_args.dbias_ptr << "\n"; - *log_file << "dq_acc_ptr: " << fmha_args.dq_acc_ptr << "\n"; + // fmha_traits debug + *log_file << "\n" << "fmha_traits: " << "\n"; + *log_file << "hdim_q: " << fmha_args.hdim_q << "\n"; + *log_file << "hdim_v: " << fmha_args.hdim_v << "\n"; + *log_file << "data_type: " << data_type_str << "\n"; + *log_file << "is_group_mode: " << is_group_mode << "\n"; + *log_file << "mask_type: " << static_cast::type>(mask_type) << "\n"; + *log_file << "bias_type: " << static_cast::type>(bias_type) << "\n"; + *log_file << "has_dbias: " << has_dbias << "\n"; + *log_file << "has_dropout: " << has_dropout << "\n"; + *log_file << "is_store_randval: " << is_store_randval << "\n"; + *log_file << "is_deterministic: " << is_deterministic << "\n"; + *log_file << "uses_bwd_v3: " << uses_bwd_v3 << "\n"; + *log_file << "is_v3_atomic_fp32: " << is_v3_atomic_fp32 << "\n"; + *log_file << "how_v3_bf16_cvt: " << how_v3_bf16_cvt << "\n"; - *log_file << "seqstart_q_ptr: " << fmha_args.seqstart_q_ptr << "\n"; - *log_file << "seqstart_k_ptr: " << fmha_args.seqstart_k_ptr << "\n"; - *log_file << "seqlen_q_ptr: " << fmha_args.seqlen_q_ptr << "\n"; - *log_file << "seqlen_k_ptr: " << fmha_args.seqlen_k_ptr << "\n"; - *log_file << "cu_seqlen_q_ptr: " << fmha_args.cu_seqlen_q_ptr << "\n"; - *log_file << "cu_seqlen_k_ptr: " << fmha_args.cu_seqlen_k_ptr << "\n"; + // fmha_args debug + *log_file << "\n" << "fmha_args: " << "\n"; + *log_file << "q_ptr: " << fmha_args.q_ptr << "\n"; + *log_file << "k_ptr: " << fmha_args.k_ptr << "\n"; + *log_file << "v_ptr: " << fmha_args.v_ptr << "\n"; + *log_file << "bias_ptr: " << fmha_args.bias_ptr << "\n"; + *log_file << "o_ptr: " << fmha_args.o_ptr << "\n"; + *log_file << "lse_ptr: " << fmha_args.lse_ptr << "\n"; + *log_file << "do_ptr: " << fmha_args.do_ptr << "\n"; + *log_file << "d_ptr: " << fmha_args.d_ptr << "\n"; + *log_file << "rand_val_ptr: " << fmha_args.rand_val_ptr << "\n"; + *log_file << "dq_ptr: " << fmha_args.dq_ptr << "\n"; + *log_file << "dk_ptr: " << fmha_args.dk_ptr << "\n"; + *log_file << "dv_ptr: " << fmha_args.dv_ptr << "\n"; + *log_file << "dbias_ptr: " << fmha_args.dbias_ptr << "\n"; + *log_file << "dq_acc_ptr: " << fmha_args.dq_acc_ptr << "\n"; - *log_file << "seqlen_q: " << fmha_args.seqlen_q << "\n"; - *log_file << "seqlen_k: " << fmha_args.seqlen_k << "\n"; - *log_file << "batch: " << fmha_args.batch << "\n"; - *log_file << "max_seqlen_q: " << fmha_args.max_seqlen_q << "\n"; - *log_file << "max_seqlen_k: " << fmha_args.max_seqlen_k << "\n"; - *log_file << "hdim_q: " << fmha_args.hdim_q << "\n"; - *log_file << "hdim_v: " << fmha_args.hdim_v << "\n"; - *log_file << "nhead_q: " << fmha_args.nhead_q << "\n"; - *log_file << "nhead_k: " << fmha_args.nhead_k << "\n"; - *log_file << "scale: " << fmha_args.scale << "\n"; - *log_file << "stride_q: " << fmha_args.stride_q << "\n"; - *log_file << "stride_k: " << fmha_args.stride_k << "\n"; - *log_file << "stride_v: " << fmha_args.stride_v << "\n"; - *log_file << "stride_bias: " << fmha_args.stride_bias << "\n"; - *log_file << "stride_o: " << fmha_args.stride_o << "\n"; - *log_file << "stride_randval: " << fmha_args.stride_randval << "\n"; - *log_file << "stride_do: " << fmha_args.stride_do << "\n"; - *log_file << "stride_dq_acc: " << fmha_args.stride_dq_acc << "\n"; - *log_file << "stride_dq: " << fmha_args.stride_dq << "\n"; - *log_file << "stride_dk: " << fmha_args.stride_dk << "\n"; - *log_file << "stride_dv: " << fmha_args.stride_dv << "\n"; - *log_file << "stride_dbias: " << fmha_args.stride_dbias << "\n"; - *log_file << "nhead_stride_q: " << fmha_args.nhead_stride_q << "\n"; - *log_file << "nhead_stride_k: " << fmha_args.nhead_stride_k << "\n"; - *log_file << "nhead_stride_v: " << fmha_args.nhead_stride_v << "\n"; - *log_file << "nhead_stride_bias: " << fmha_args.nhead_stride_bias << "\n"; - *log_file << "nhead_stride_o: " << fmha_args.nhead_stride_o << "\n"; - *log_file << "nhead_stride_randval: " << fmha_args.nhead_stride_randval << "\n"; - *log_file << "nhead_stride_do: " << fmha_args.nhead_stride_do << "\n"; - *log_file << "nhead_stride_lsed: " << fmha_args.nhead_stride_lsed << "\n"; - *log_file << "nhead_stride_dq_acc: " << fmha_args.nhead_stride_dq_acc << "\n"; - *log_file << "nhead_stride_dq: " << fmha_args.nhead_stride_dq << "\n"; - *log_file << "nhead_stride_dk: " << fmha_args.nhead_stride_dk << "\n"; - *log_file << "nhead_stride_dv: " << fmha_args.nhead_stride_dv << "\n"; - *log_file << "nhead_stride_dbias: " << fmha_args.nhead_stride_dbias << "\n"; - *log_file << "batch_stride_q: " << fmha_args.batch_stride_q << "\n"; - *log_file << "batch_stride_k: " << fmha_args.batch_stride_k << "\n"; - *log_file << "batch_stride_v: " << fmha_args.batch_stride_v << "\n"; - *log_file << "batch_stride_bias: " << fmha_args.batch_stride_bias << "\n"; - *log_file << "batch_stride_o: " << fmha_args.batch_stride_o << "\n"; - *log_file << "batch_stride_randval: " << fmha_args.batch_stride_randval << "\n"; - *log_file << "batch_stride_do: " << fmha_args.batch_stride_do << "\n"; - *log_file << "batch_stride_lsed: " << fmha_args.batch_stride_lsed << "\n"; - *log_file << "batch_stride_dq_acc: " << fmha_args.batch_stride_dq_acc << "\n"; - *log_file << "batch_stride_dq: " << fmha_args.batch_stride_dq << "\n"; - *log_file << "batch_stride_dk: " << fmha_args.batch_stride_dk << "\n"; - *log_file << "batch_stride_dv: " << fmha_args.batch_stride_dv << "\n"; - *log_file << "batch_stride_dbias: " << fmha_args.batch_stride_dbias << "\n"; - *log_file << "window_size_left: " << fmha_args.window_size_left << "\n"; - *log_file << "window_size_right: " << fmha_args.window_size_right << "\n"; - *log_file << "mask_type: " << fmha_args.mask_type << "\n"; - *log_file << "p_drop: " << fmha_args.p_drop << "\n"; - *log_file << "p_undrop: " << fmha_args.p_undrop << "\n"; - *log_file << "dropout_seed_ptr: " << std::get<0>(std::get>(fmha_args.drop_seed_offset)) << "\n"; - *log_file << "dropout_offset_ptr: " << std::get<1>(std::get>(fmha_args.drop_seed_offset)) << "\n"; - } + *log_file << "seqstart_q_ptr: " << fmha_args.seqstart_q_ptr << "\n"; + *log_file << "seqstart_k_ptr: " << fmha_args.seqstart_k_ptr << "\n"; + *log_file << "seqlen_q_ptr: " << fmha_args.seqlen_q_ptr << "\n"; + *log_file << "seqlen_k_ptr: " << fmha_args.seqlen_k_ptr << "\n"; + *log_file << "cu_seqlen_q_ptr: " << fmha_args.cu_seqlen_q_ptr << "\n"; + *log_file << "cu_seqlen_k_ptr: " << fmha_args.cu_seqlen_k_ptr << "\n"; + *log_file << "seqlen_q: " << fmha_args.seqlen_q << "\n"; + *log_file << "seqlen_k: " << fmha_args.seqlen_k << "\n"; + *log_file << "batch: " << fmha_args.batch << "\n"; + *log_file << "max_seqlen_q: " << fmha_args.max_seqlen_q << "\n"; + *log_file << "max_seqlen_k: " << fmha_args.max_seqlen_k << "\n"; + *log_file << "hdim_q: " << fmha_args.hdim_q << "\n"; + *log_file << "hdim_v: " << fmha_args.hdim_v << "\n"; + *log_file << "nhead_q: " << fmha_args.nhead_q << "\n"; + *log_file << "nhead_k: " << fmha_args.nhead_k << "\n"; + *log_file << "scale: " << fmha_args.scale << "\n"; + *log_file << "stride_q: " << fmha_args.stride_q << "\n"; + *log_file << "stride_k: " << fmha_args.stride_k << "\n"; + *log_file << "stride_v: " << fmha_args.stride_v << "\n"; + *log_file << "stride_bias: " << fmha_args.stride_bias << "\n"; + *log_file << "stride_o: " << fmha_args.stride_o << "\n"; + *log_file << "stride_randval: " << fmha_args.stride_randval << "\n"; + *log_file << "stride_do: " << fmha_args.stride_do << "\n"; + *log_file << "stride_dq_acc: " << fmha_args.stride_dq_acc << "\n"; + *log_file << "stride_dq: " << fmha_args.stride_dq << "\n"; + *log_file << "stride_dk: " << fmha_args.stride_dk << "\n"; + *log_file << "stride_dv: " << fmha_args.stride_dv << "\n"; + *log_file << "stride_dbias: " << fmha_args.stride_dbias << "\n"; + *log_file << "nhead_stride_q: " << fmha_args.nhead_stride_q << "\n"; + *log_file << "nhead_stride_k: " << fmha_args.nhead_stride_k << "\n"; + *log_file << "nhead_stride_v: " << fmha_args.nhead_stride_v << "\n"; + *log_file << "nhead_stride_bias: " << fmha_args.nhead_stride_bias << "\n"; + *log_file << "nhead_stride_o: " << fmha_args.nhead_stride_o << "\n"; + *log_file << "nhead_stride_randval: " << fmha_args.nhead_stride_randval << "\n"; + *log_file << "nhead_stride_do: " << fmha_args.nhead_stride_do << "\n"; + *log_file << "nhead_stride_lsed: " << fmha_args.nhead_stride_lsed << "\n"; + *log_file << "nhead_stride_dq_acc: " << fmha_args.nhead_stride_dq_acc << "\n"; + *log_file << "nhead_stride_dq: " << fmha_args.nhead_stride_dq << "\n"; + *log_file << "nhead_stride_dk: " << fmha_args.nhead_stride_dk << "\n"; + *log_file << "nhead_stride_dv: " << fmha_args.nhead_stride_dv << "\n"; + *log_file << "nhead_stride_dbias: " << fmha_args.nhead_stride_dbias << "\n"; + *log_file << "batch_stride_q: " << fmha_args.batch_stride_q << "\n"; + *log_file << "batch_stride_k: " << fmha_args.batch_stride_k << "\n"; + *log_file << "batch_stride_v: " << fmha_args.batch_stride_v << "\n"; + *log_file << "batch_stride_bias: " << fmha_args.batch_stride_bias << "\n"; + *log_file << "batch_stride_o: " << fmha_args.batch_stride_o << "\n"; + *log_file << "batch_stride_randval: " << fmha_args.batch_stride_randval << "\n"; + *log_file << "batch_stride_do: " << fmha_args.batch_stride_do << "\n"; + *log_file << "batch_stride_lsed: " << fmha_args.batch_stride_lsed << "\n"; + *log_file << "batch_stride_dq_acc: " << fmha_args.batch_stride_dq_acc << "\n"; + *log_file << "batch_stride_dq: " << fmha_args.batch_stride_dq << "\n"; + *log_file << "batch_stride_dk: " << fmha_args.batch_stride_dk << "\n"; + *log_file << "batch_stride_dv: " << fmha_args.batch_stride_dv << "\n"; + *log_file << "batch_stride_dbias: " << fmha_args.batch_stride_dbias << "\n"; + *log_file << "window_size_left: " << fmha_args.window_size_left << "\n"; + *log_file << "window_size_right: " << fmha_args.window_size_right << "\n"; + *log_file << "mask_type: " << fmha_args.mask_type << "\n"; + *log_file << "p_drop: " << fmha_args.p_drop << "\n"; + *log_file << "p_undrop: " << fmha_args.p_undrop << "\n"; + *log_file << "dropout_seed_ptr: " << std::get<0>(std::get>(fmha_args.drop_seed_offset)) << "\n"; + *log_file << "dropout_offset_ptr: " << std::get<1>(std::get>(fmha_args.drop_seed_offset)) << "\n"; } void dump_bwd_timings(const char* dump_path, float average_runtime){ @@ -554,7 +526,7 @@ hipError_t ck_attn_bwd( const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_bwd_log_stream() != nullptr}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; ck_tile::index_t shape_seqlen_q = seqlen_q; ck_tile::index_t shape_seqlen_k = seqlen_k; @@ -699,8 +671,9 @@ hipError_t ck_attn_bwd( }(); // print ck traits and args when needed - log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_type, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); - + if (auto* log_file = get_ck_log_stream()) { + log_bwd_config(log_file, __FUNCTION__, data_type_str, is_group_mode, mask_type, bias_type, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); + } float average_runtime = aiter::mha_bwd(fmha_args, stream_config, data_type_str, @@ -724,7 +697,7 @@ hipError_t ck_attn_bwd( dim3 grid(b, s_kv, hg); if (d_qk == d_v) { dim3 block(d_qk); - if (auto* log_file = get_bwd_log_stream()) { + if (auto* log_file = get_ck_log_stream()) { *log_file << "\n" << "run dk_dv_reduce: " << "\n"; *log_file << "dk_expanded_ptr: " << dk_expanded_ptr << "\n"; *log_file << "dv_expanded_ptr: " << dv_expanded_ptr << "\n"; @@ -749,7 +722,7 @@ hipError_t ck_attn_bwd( stride_b_dk, stride_h_dk, stride_s_dk);); } else { dim3 block_dk(d_qk); - if (auto* log_file = get_bwd_log_stream()) { + if (auto* log_file = get_ck_log_stream()) { *log_file << "\n" << "run dk_or_dv_reduce on dk: " << "\n"; *log_file << "dk_expanded_ptr: " << dk_expanded_ptr << "\n"; *log_file << "stride_b_dk_expanded: " << stride_b_dk_expanded << "\n"; @@ -770,7 +743,7 @@ hipError_t ck_attn_bwd( stride_b_dk, stride_h_dk, stride_s_dk);); dim3 block_dv(d_v); - if (auto* log_file = get_bwd_log_stream()) { + if (auto* log_file = get_ck_log_stream()) { *log_file << "\n" << "run dk_or_dv_reduce on dv: " << "\n"; *log_file << "dv_expanded_ptr: " << dv_expanded_ptr << "\n"; *log_file << "stride_b_dv_expanded: " << stride_b_dv_expanded << "\n"; @@ -798,7 +771,7 @@ hipError_t ck_attn_bwd( dim3 block(THREADS_PER_BLOCK); dim3 grid(ceil(1.0 * s_q * s_kv/THREADS_PER_BLOCK)); if(bias_shape==BiasShape::k11SS){ - if (auto* log_file = get_bwd_log_stream()) { + if (auto* log_file = get_ck_log_stream()) { *log_file << "\n" << "run dbias_reduce_11SS: " << "\n"; *log_file << "dbias_ptr: " << dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << dbias_expanded_ptr << "\n"; @@ -810,7 +783,7 @@ hipError_t ck_attn_bwd( static_cast(dbias_expanded_ptr), static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::k1HSS){ - if (auto* log_file = get_bwd_log_stream()) { + if (auto* log_file = get_ck_log_stream()) { *log_file << "\n" << "run dbias_reduce_1HSS: " << "\n"; *log_file << "dbias_ptr: " << dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << dbias_expanded_ptr << "\n"; @@ -822,7 +795,7 @@ hipError_t ck_attn_bwd( static_cast(dbias_expanded_ptr), static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::kB1SS){ - if (auto* log_file = get_bwd_log_stream()) { + if (auto* log_file = get_ck_log_stream()) { *log_file << "\n" << "run dbias_reduce_B1SS: " << "\n"; *log_file << "dbias_ptr: " << dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << dbias_expanded_ptr << "\n"; @@ -904,7 +877,7 @@ hipError_t ck_attn_varlen_bwd( const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_bwd_log_stream() != nullptr}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; std::string data_type_str = get_data_type_str(dtype); @@ -1046,7 +1019,7 @@ hipError_t ck_attn_varlen_bwd( // lse_thd_ptr used as buffer if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { if(std::string(env_p) == "1"){ - if (auto* log_file = get_bwd_log_stream()) { + if (auto* log_file = get_ck_log_stream()) { *log_file << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.\n"; } @@ -1056,7 +1029,9 @@ hipError_t ck_attn_varlen_bwd( } // print ck traits and args when needed - log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); + if (auto* log_file = get_ck_log_stream()) { + log_bwd_config(log_file, __FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); + } float average_runtime = aiter::mha_bwd(fmha_args, stream_config, @@ -1081,7 +1056,7 @@ hipError_t ck_attn_varlen_bwd( dim3 grid(max_tokens_kv, hg); if (d_qk == d_v) { dim3 block(d_qk); - if (auto* log_file = get_bwd_log_stream()) { + if (auto* log_file = get_ck_log_stream()) { *log_file << "\n" << "run dk_dv_reduce_thd: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << cu_seqlen_kv_padded_ptr << "\n"; @@ -1108,7 +1083,7 @@ hipError_t ck_attn_varlen_bwd( stride_h_dk, stride_s_dk);); } else { dim3 block_dk(d_qk); - if (auto* log_file = get_bwd_log_stream()) { + if (auto* log_file = get_ck_log_stream()) { *log_file << "\n" << "run dk_or_dv_reduce_thd on dk: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << cu_seqlen_kv_padded_ptr << "\n"; @@ -1131,7 +1106,7 @@ hipError_t ck_attn_varlen_bwd( stride_h_dk, stride_s_dk);); dim3 block_dv(d_v); - if (auto* log_file = get_bwd_log_stream()) { + if (auto* log_file = get_ck_log_stream()) { *log_file << "\n" << "run dk_or_dv_reduce_thd on dv: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << cu_seqlen_kv_padded_ptr << "\n"; diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 935caa9eb..e739560c4 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -16,130 +16,104 @@ namespace ck_fused_attn{ -namespace { -std::ostream* get_fwd_log_stream() { - thread_local std::ofstream log_file; - thread_local bool attempted = false; - thread_local bool opened = false; - thread_local bool requested = false; - thread_local std::string log_dir_str; - if (!attempted) { - attempted = true; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG")) { - log_dir_str = std::string(env_p); - requested = !log_dir_str.empty() && log_dir_str != "0"; - } - if (requested) { - opened = open_ck_fused_attn_log_file(log_file, "ck_fused_attn_fwd", log_dir_str); - } - } - if (!requested) { - return nullptr; - } - if (!opened) { - return &std::cout; - } - return &log_file; -} -} // namespace - // print the fmha traits and args when calling ck apis -void log_fwd_config(const char* func_name, - const std::string data_type_str, - const bool is_group_mode, - const bool has_logits_soft_cap, - const mask_enum mask_type, - const bias_enum bias_type, - const bool has_lse, - const bool has_dropout, - const bool is_v_rowmajor, - const bool do_fp8_static_quant, - const bool uses_fwd_v3, - const bool how_v3_bf16_cvt, - const fmha_fwd_args& fmha_args){ - if (auto* log_file = get_fwd_log_stream()) { - *log_file << "\n" << func_name << "\n"; - - // debug fmha_traits - *log_file << "\n" << "fmha_traits: " << "\n"; - *log_file << "hdim_q: " << fmha_args.hdim_q << "\n"; - *log_file << "hdim_v: " << fmha_args.hdim_v << "\n"; - *log_file << "data_type: " << data_type_str << "\n"; - *log_file << "is_group_mode: " << is_group_mode << "\n"; - *log_file << "is_v_rowmajor: " << is_v_rowmajor << "\n"; - *log_file << "has_logits_soft_cap: " << has_logits_soft_cap << "\n"; - *log_file << "mask_type: " << static_cast::type>(mask_type) << "\n"; - *log_file << "bias_type: " << static_cast::type>(bias_type) << "\n"; - *log_file << "has_lse: " << has_lse << "\n"; - *log_file << "has_dropout: " << has_dropout << "\n"; - *log_file << "do_fp8_static_quant: " << do_fp8_static_quant << "\n"; - *log_file << "skip_min_seqlen_q: " << (fmha_args.min_seqlen_q != 0) << "\n"; - *log_file << "uses_fwd_v3: " << uses_fwd_v3 << "\n"; - *log_file << "how_v3_bf16_cvt: " << how_v3_bf16_cvt << "\n"; - - // debug fmha_args - *log_file << "\n" << "fmha_args: " << "\n"; - - *log_file << "q_ptr: " << fmha_args.q_ptr << "\n"; - *log_file << "k_ptr: " << fmha_args.k_ptr << "\n"; - *log_file << "v_ptr: " << fmha_args.v_ptr << "\n"; - *log_file << "bias_ptr: " << fmha_args.bias_ptr << "\n"; - *log_file << "rand_val_ptr: " << fmha_args.rand_val_ptr << "\n"; - *log_file << "lse_ptr: " << fmha_args.lse_ptr << "\n"; - *log_file << "o_ptr: " << fmha_args.o_ptr << "\n"; - - *log_file << "seqstart_q_ptr: " << fmha_args.seqstart_q_ptr << "\n"; - *log_file << "seqstart_k_ptr: " << fmha_args.seqstart_k_ptr << "\n"; - *log_file << "seqlen_q_ptr: " << fmha_args.seqlen_q_ptr << "\n"; - *log_file << "seqlen_k_ptr: " << fmha_args.seqlen_k_ptr << "\n"; - *log_file << "cu_seqlen_q_ptr: " << fmha_args.cu_seqlen_q_ptr << "\n"; - *log_file << "cu_seqlen_k_ptr: " << fmha_args.cu_seqlen_k_ptr << "\n"; - - *log_file << "seqlen_q: " << fmha_args.seqlen_q << "\n"; - *log_file << "seqlen_k: " << fmha_args.seqlen_k << "\n"; - *log_file << "batch: " << fmha_args.batch << "\n"; - *log_file << "max_seqlen_q: " << fmha_args.max_seqlen_q << "\n"; - *log_file << "hdim_q: " << fmha_args.hdim_q << "\n"; - *log_file << "hdim_v: " << fmha_args.hdim_v << "\n"; - *log_file << "nhead_q: " << fmha_args.nhead_q << "\n"; - *log_file << "nhead_k: " << fmha_args.nhead_k << "\n"; - - *log_file << "scale_s: " << fmha_args.scale_s << "\n"; - - *log_file << "logits_soft_cap: " << fmha_args.logits_soft_cap << "\n"; - - *log_file << "stride_q: " << fmha_args.stride_q << "\n"; - *log_file << "stride_k: " << fmha_args.stride_k << "\n"; - *log_file << "stride_v: " << fmha_args.stride_v << "\n"; - *log_file << "stride_bias: " << fmha_args.stride_bias << "\n"; - *log_file << "stride_randval: " << fmha_args.stride_randval << "\n"; - *log_file << "stride_o: " << fmha_args.stride_o << "\n"; - *log_file << "nhead_stride_q: " << fmha_args.nhead_stride_q << "\n"; - *log_file << "nhead_stride_k: " << fmha_args.nhead_stride_k << "\n"; - *log_file << "nhead_stride_v: " << fmha_args.nhead_stride_v << "\n"; - *log_file << "nhead_stride_bias: " << fmha_args.nhead_stride_bias << "\n"; - *log_file << "nhead_stride_randval: " << fmha_args.nhead_stride_randval << "\n"; - *log_file << "nhead_stride_lse: " << fmha_args.nhead_stride_lse << "\n"; - *log_file << "nhead_stride_o: " << fmha_args.nhead_stride_o << "\n"; - *log_file << "batch_stride_q: " << fmha_args.batch_stride_q << "\n"; - *log_file << "batch_stride_k: " << fmha_args.batch_stride_k << "\n"; - *log_file << "batch_stride_v: " << fmha_args.batch_stride_v << "\n"; - *log_file << "batch_stride_bias: " << fmha_args.batch_stride_bias << "\n"; - *log_file << "batch_stride_randval: " << fmha_args.batch_stride_randval << "\n"; - *log_file << "batch_stride_lse: " << fmha_args.batch_stride_lse << "\n"; - *log_file << "batch_stride_o: " << fmha_args.batch_stride_o << "\n"; - - *log_file << "window_size_left: " << fmha_args.window_size_left << "\n"; - *log_file << "window_size_right: " << fmha_args.window_size_right << "\n"; - *log_file << "mask_type: " << fmha_args.mask_type << "\n"; - *log_file << "min_seqlen_q: " << fmha_args.min_seqlen_q << "\n"; - - *log_file << "p_drop: " << fmha_args.p_drop << "\n"; - *log_file << "s_randval: " << fmha_args.s_randval << "\n"; - - *log_file << "dropout_seed_ptr: " << std::get<0>(std::get>(fmha_args.drop_seed_offset)) << "\n"; - *log_file << "dropout_offset_ptr: " << std::get<1>(std::get>(fmha_args.drop_seed_offset)) << "\n"; - } +void log_fwd_config( + std::ostream* log_file, + const char* func_name, + const std::string data_type_str, + const bool is_group_mode, + const bool has_logits_soft_cap, + const mask_enum mask_type, + const bias_enum bias_type, + const bool has_lse, + const bool has_dropout, + const bool is_v_rowmajor, + const bool do_fp8_static_quant, + const bool uses_fwd_v3, + const bool how_v3_bf16_cvt, + const fmha_fwd_args& fmha_args +){ + *log_file << "\n" << func_name << "\n"; + + // debug fmha_traits + *log_file << "\n" << "fmha_traits: " << "\n"; + *log_file << "hdim_q: " << fmha_args.hdim_q << "\n"; + *log_file << "hdim_v: " << fmha_args.hdim_v << "\n"; + *log_file << "data_type: " << data_type_str << "\n"; + *log_file << "is_group_mode: " << is_group_mode << "\n"; + *log_file << "is_v_rowmajor: " << is_v_rowmajor << "\n"; + *log_file << "has_logits_soft_cap: " << has_logits_soft_cap << "\n"; + *log_file << "mask_type: " << static_cast::type>(mask_type) << "\n"; + *log_file << "bias_type: " << static_cast::type>(bias_type) << "\n"; + *log_file << "has_lse: " << has_lse << "\n"; + *log_file << "has_dropout: " << has_dropout << "\n"; + *log_file << "do_fp8_static_quant: " << do_fp8_static_quant << "\n"; + *log_file << "skip_min_seqlen_q: " << (fmha_args.min_seqlen_q != 0) << "\n"; + *log_file << "uses_fwd_v3: " << uses_fwd_v3 << "\n"; + *log_file << "how_v3_bf16_cvt: " << how_v3_bf16_cvt << "\n"; + + // debug fmha_args + *log_file << "\n" << "fmha_args: " << "\n"; + + *log_file << "q_ptr: " << fmha_args.q_ptr << "\n"; + *log_file << "k_ptr: " << fmha_args.k_ptr << "\n"; + *log_file << "v_ptr: " << fmha_args.v_ptr << "\n"; + *log_file << "bias_ptr: " << fmha_args.bias_ptr << "\n"; + *log_file << "rand_val_ptr: " << fmha_args.rand_val_ptr << "\n"; + *log_file << "lse_ptr: " << fmha_args.lse_ptr << "\n"; + *log_file << "o_ptr: " << fmha_args.o_ptr << "\n"; + + *log_file << "seqstart_q_ptr: " << fmha_args.seqstart_q_ptr << "\n"; + *log_file << "seqstart_k_ptr: " << fmha_args.seqstart_k_ptr << "\n"; + *log_file << "seqlen_q_ptr: " << fmha_args.seqlen_q_ptr << "\n"; + *log_file << "seqlen_k_ptr: " << fmha_args.seqlen_k_ptr << "\n"; + *log_file << "cu_seqlen_q_ptr: " << fmha_args.cu_seqlen_q_ptr << "\n"; + *log_file << "cu_seqlen_k_ptr: " << fmha_args.cu_seqlen_k_ptr << "\n"; + + *log_file << "seqlen_q: " << fmha_args.seqlen_q << "\n"; + *log_file << "seqlen_k: " << fmha_args.seqlen_k << "\n"; + *log_file << "batch: " << fmha_args.batch << "\n"; + *log_file << "max_seqlen_q: " << fmha_args.max_seqlen_q << "\n"; + *log_file << "hdim_q: " << fmha_args.hdim_q << "\n"; + *log_file << "hdim_v: " << fmha_args.hdim_v << "\n"; + *log_file << "nhead_q: " << fmha_args.nhead_q << "\n"; + *log_file << "nhead_k: " << fmha_args.nhead_k << "\n"; + + *log_file << "scale_s: " << fmha_args.scale_s << "\n"; + + *log_file << "logits_soft_cap: " << fmha_args.logits_soft_cap << "\n"; + + *log_file << "stride_q: " << fmha_args.stride_q << "\n"; + *log_file << "stride_k: " << fmha_args.stride_k << "\n"; + *log_file << "stride_v: " << fmha_args.stride_v << "\n"; + *log_file << "stride_bias: " << fmha_args.stride_bias << "\n"; + *log_file << "stride_randval: " << fmha_args.stride_randval << "\n"; + *log_file << "stride_o: " << fmha_args.stride_o << "\n"; + *log_file << "nhead_stride_q: " << fmha_args.nhead_stride_q << "\n"; + *log_file << "nhead_stride_k: " << fmha_args.nhead_stride_k << "\n"; + *log_file << "nhead_stride_v: " << fmha_args.nhead_stride_v << "\n"; + *log_file << "nhead_stride_bias: " << fmha_args.nhead_stride_bias << "\n"; + *log_file << "nhead_stride_randval: " << fmha_args.nhead_stride_randval << "\n"; + *log_file << "nhead_stride_lse: " << fmha_args.nhead_stride_lse << "\n"; + *log_file << "nhead_stride_o: " << fmha_args.nhead_stride_o << "\n"; + *log_file << "batch_stride_q: " << fmha_args.batch_stride_q << "\n"; + *log_file << "batch_stride_k: " << fmha_args.batch_stride_k << "\n"; + *log_file << "batch_stride_v: " << fmha_args.batch_stride_v << "\n"; + *log_file << "batch_stride_bias: " << fmha_args.batch_stride_bias << "\n"; + *log_file << "batch_stride_randval: " << fmha_args.batch_stride_randval << "\n"; + *log_file << "batch_stride_lse: " << fmha_args.batch_stride_lse << "\n"; + *log_file << "batch_stride_o: " << fmha_args.batch_stride_o << "\n"; + + *log_file << "window_size_left: " << fmha_args.window_size_left << "\n"; + *log_file << "window_size_right: " << fmha_args.window_size_right << "\n"; + *log_file << "mask_type: " << fmha_args.mask_type << "\n"; + *log_file << "min_seqlen_q: " << fmha_args.min_seqlen_q << "\n"; + + *log_file << "p_drop: " << fmha_args.p_drop << "\n"; + *log_file << "s_randval: " << fmha_args.s_randval << "\n"; + + *log_file << "dropout_seed_ptr: " << std::get<0>(std::get>(fmha_args.drop_seed_offset)) << "\n"; + *log_file << "dropout_offset_ptr: " << std::get<1>(std::get>(fmha_args.drop_seed_offset)) << "\n"; } void dump_fwd_timings(const char* dump_path, float average_runtime){ @@ -203,7 +177,7 @@ hipError_t ck_attn_fwd( const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_fwd_log_stream() != nullptr}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; std::string data_type_str = get_data_type_str(dtype); @@ -294,7 +268,9 @@ hipError_t ck_attn_fwd( }(); // print ck traits and args when needed - log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); + if (auto* log_file = get_ck_log_stream()) { + log_fwd_config(log_file, __FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); + } float average_runtime = aiter::mha_fwd(fmha_args, stream_config, @@ -373,7 +349,7 @@ hipError_t ck_attn_varlen_fwd( const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_fwd_log_stream() != nullptr}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; std::string data_type_str = get_data_type_str(dtype); @@ -469,7 +445,7 @@ hipError_t ck_attn_varlen_fwd( // lse_thd_ptr used as buffer if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){ if(std::string(env_p) == "1"){ - if (auto* log_file = get_fwd_log_stream()) { + if (auto* log_file = get_ck_log_stream()) { *log_file << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.\n"; } @@ -477,7 +453,9 @@ hipError_t ck_attn_varlen_fwd( } } // print ck traits and args when needed - log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); + if (auto* log_file = get_ck_log_stream()) { + log_fwd_config(log_file, __FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); + } float average_runtime = aiter::mha_fwd( fmha_args, diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp index 6bbfbda4f..abacb8f0a 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp @@ -19,9 +19,6 @@ namespace ck_fused_attn{ bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefix, const std::string& log_dir_str) { // Explicitly use std::cout as a fallback - if (log_dir_str == "1") { - return false; - } std::filesystem::path log_dir(log_dir_str); std::ostringstream filename; filename << file_prefix << "_" << getpid() << "_" << std::this_thread::get_id() << ".log"; @@ -33,6 +30,28 @@ bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefi return true; } +std::ostream* get_ck_log_stream() { + thread_local std::ofstream log_file; + thread_local std::ostream* log_stream = nullptr; + thread_local bool initialized = false; + if (!initialized) { + initialized = true; + if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG")) { + std::string log_dir_str(env_p); + if (!log_dir_str.empty() && log_dir_str != "0") { + if (log_dir_str == "1") { + log_stream = static_cast(&std::cout); + } + if (open_ck_fused_attn_log_file(log_file, "ck_fused_attn", log_dir_str)) { + log_stream = static_cast(&log_file); + } + } + } + } + + return log_stream; +} + std::string get_data_type_str(DType dtype){ std::string data_type_str; if(dtype==DType::kFloat16){ diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index a0ea13d81..71dce0c0e 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -57,7 +57,7 @@ std::pair get_ck_bias_type_shape(BiasType attn_bias_type, uint64_t get_runtime_max_seqlen(uint64_t b, const void* cu_seqlen_ptr, const void* cu_seqlen_padded_ptr, void* workspace, hipStream_t stream); -bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefix, const std::string& log_dir_str); +std::ostream* get_ck_log_stream(); }//namespace ck_fused_attn #endif // CK_FUSED_ATTN_UTILS_H From ce5e5864b5b3a3ddc17004c68ca2f0a0eed88ba5 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 26 Feb 2026 11:23:00 -0600 Subject: [PATCH 08/10] Trimmed headers --- .../common/ck_fused_attn/src/ck_fused_attn_utils.hpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index 71dce0c0e..3270ec32f 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -7,9 +7,7 @@ #ifndef CK_FUSED_ATTN_UTILS_H #define CK_FUSED_ATTN_UTILS_H -#include #include -#include #include //forward declaration for ck_tile enum From 7262f58942b71e6c45d502d9ee6000913259b55e Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 26 Feb 2026 11:25:54 -0600 Subject: [PATCH 09/10] Removed static casts --- .../common/ck_fused_attn/src/ck_fused_attn_utils.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp index abacb8f0a..72a0122d9 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp @@ -40,10 +40,10 @@ std::ostream* get_ck_log_stream() { std::string log_dir_str(env_p); if (!log_dir_str.empty() && log_dir_str != "0") { if (log_dir_str == "1") { - log_stream = static_cast(&std::cout); + log_stream = &std::cout; } - if (open_ck_fused_attn_log_file(log_file, "ck_fused_attn", log_dir_str)) { - log_stream = static_cast(&log_file); + else if (open_ck_fused_attn_log_file(log_file, "ck_fused_attn", log_dir_str)) { + log_stream = &log_file; } } } From f279d525dc6fc99d9a4d9b107baa35479d677aed Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 27 Feb 2026 14:15:05 -0600 Subject: [PATCH 10/10] Removed redundant includes --- .../common/ck_fused_attn/src/ck_fused_attn_bwd.cpp | 1 - .../common/ck_fused_attn/src/ck_fused_attn_fwd.cpp | 1 - 2 files changed, 2 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index e0b5fd875..c349b9681 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -4,7 +4,6 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ -#include #include #include #include diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index e739560c4..f9e93515d 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -4,7 +4,6 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ -#include #include #include #include