diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 3f51b96b6..c349b9681 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -1,10 +1,9 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ -#include #include #include #include @@ -333,125 +332,119 @@ __global__ void dbias_reduce_b1ss( } // print the fmha_traits and args passed into ck apis -void log_bwd_config(const char* func_name, - const std::string data_type_str, - const bool is_group_mode, - const mask_enum mask_type, - const bias_enum bias_type, - const bool has_dbias, - const bool has_dropout, - const bool is_store_randval, - const bool is_deterministic, - const bool uses_bwd_v3, - const bool is_v3_atomic_fp32, - const int how_v3_bf16_cvt, - const fmha_bwd_args& fmha_args){ +void log_bwd_config( + std::ostream* log_file, + const char* func_name, + const std::string data_type_str, + const bool is_group_mode, + const mask_enum mask_type, + const bias_enum bias_type, + const bool has_dbias, + const bool has_dropout, + const bool is_store_randval, + const bool is_deterministic, + const bool uses_bwd_v3, + const bool is_v3_atomic_fp32, + const int how_v3_bf16_cvt, + const fmha_bwd_args& fmha_args +){ + *log_file << "\n" << func_name << "\n"; - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } - if (ck_fused_attn_log_config) { - std::cout<::type>(mask_type)<::type>(bias_type)<::type>(mask_type) << "\n"; + *log_file << "bias_type: " << static_cast::type>(bias_type) << "\n"; + *log_file << "has_dbias: " << has_dbias << "\n"; + *log_file << "has_dropout: " << has_dropout << "\n"; + *log_file << "is_store_randval: " << is_store_randval << "\n"; + *log_file << "is_deterministic: " << is_deterministic << "\n"; + *log_file << "uses_bwd_v3: " << uses_bwd_v3 << "\n"; + *log_file << "is_v3_atomic_fp32: " << is_v3_atomic_fp32 << "\n"; + *log_file << "how_v3_bf16_cvt: " << how_v3_bf16_cvt << "\n"; - // fmha_args debug - std::cout<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset)) << "\n"; + *log_file << "dropout_offset_ptr: " << std::get<1>(std::get>(fmha_args.drop_seed_offset)) << "\n"; } void dump_bwd_timings(const char* dump_path, float average_runtime){ @@ -529,15 +522,10 @@ hipError_t ck_attn_bwd( right = window_size_right; mask_enum mask_type = static_cast(attn_mask_type); - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; ck_tile::index_t shape_seqlen_q = seqlen_q; ck_tile::index_t shape_seqlen_k = seqlen_k; @@ -682,8 +670,9 @@ hipError_t ck_attn_bwd( }(); // print ck traits and args when needed - log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_type, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); - + if (auto* log_file = get_ck_log_stream()) { + log_bwd_config(log_file, __FUNCTION__, data_type_str, is_group_mode, mask_type, bias_type, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); + } float average_runtime = aiter::mha_bwd(fmha_args, stream_config, data_type_str, @@ -707,18 +696,18 @@ hipError_t ck_attn_bwd( dim3 grid(b, s_kv, hg); if (d_qk == d_v) { dim3 block(d_qk); - if (ck_fused_attn_log_config){ - std::cout<(dbias_expanded_ptr), static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::k1HSS){ - if (ck_fused_attn_log_config){ - std::cout<(dbias_expanded_ptr), static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::kB1SS){ - if (ck_fused_attn_log_config){ - std::cout<(attn_mask_type); - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; std::string data_type_str = get_data_type_str(dtype); @@ -1034,8 +1018,9 @@ hipError_t ck_attn_varlen_bwd( // lse_thd_ptr used as buffer if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { if(std::string(env_p) == "1"){ - if(ck_fused_attn_log_config){ - std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + if (auto* log_file = get_ck_log_stream()) { + *log_file + << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.\n"; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream); fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream); @@ -1043,7 +1028,9 @@ hipError_t ck_attn_varlen_bwd( } // print ck traits and args when needed - log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); + if (auto* log_file = get_ck_log_stream()) { + log_bwd_config(log_file, __FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); + } float average_runtime = aiter::mha_bwd(fmha_args, stream_config, @@ -1068,18 +1055,18 @@ hipError_t ck_attn_varlen_bwd( dim3 grid(max_tokens_kv, hg); if (d_qk == d_v) { dim3 block(d_qk); - if (ck_fused_attn_log_config){ - std::cout< #include #include #include @@ -17,107 +16,103 @@ namespace ck_fused_attn{ // print the fmha traits and args when calling ck apis -void log_fwd_config(const char* func_name, - const std::string data_type_str, - const bool is_group_mode, - const bool has_logits_soft_cap, - const mask_enum mask_type, - const bias_enum bias_type, - const bool has_lse, - const bool has_dropout, - const bool is_v_rowmajor, - const bool do_fp8_static_quant, - const bool uses_fwd_v3, - const bool how_v3_bf16_cvt, - const fmha_fwd_args& fmha_args){ - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } - if (ck_fused_attn_log_config) { - std::cout<::type>(mask_type)<::type>(bias_type)<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset))<::type>(mask_type) << "\n"; + *log_file << "bias_type: " << static_cast::type>(bias_type) << "\n"; + *log_file << "has_lse: " << has_lse << "\n"; + *log_file << "has_dropout: " << has_dropout << "\n"; + *log_file << "do_fp8_static_quant: " << do_fp8_static_quant << "\n"; + *log_file << "skip_min_seqlen_q: " << (fmha_args.min_seqlen_q != 0) << "\n"; + *log_file << "uses_fwd_v3: " << uses_fwd_v3 << "\n"; + *log_file << "how_v3_bf16_cvt: " << how_v3_bf16_cvt << "\n"; + + // debug fmha_args + *log_file << "\n" << "fmha_args: " << "\n"; + + *log_file << "q_ptr: " << fmha_args.q_ptr << "\n"; + *log_file << "k_ptr: " << fmha_args.k_ptr << "\n"; + *log_file << "v_ptr: " << fmha_args.v_ptr << "\n"; + *log_file << "bias_ptr: " << fmha_args.bias_ptr << "\n"; + *log_file << "rand_val_ptr: " << fmha_args.rand_val_ptr << "\n"; + *log_file << "lse_ptr: " << fmha_args.lse_ptr << "\n"; + *log_file << "o_ptr: " << fmha_args.o_ptr << "\n"; + + *log_file << "seqstart_q_ptr: " << fmha_args.seqstart_q_ptr << "\n"; + *log_file << "seqstart_k_ptr: " << fmha_args.seqstart_k_ptr << "\n"; + *log_file << "seqlen_q_ptr: " << fmha_args.seqlen_q_ptr << "\n"; + *log_file << "seqlen_k_ptr: " << fmha_args.seqlen_k_ptr << "\n"; + *log_file << "cu_seqlen_q_ptr: " << fmha_args.cu_seqlen_q_ptr << "\n"; + *log_file << "cu_seqlen_k_ptr: " << fmha_args.cu_seqlen_k_ptr << "\n"; + + *log_file << "seqlen_q: " << fmha_args.seqlen_q << "\n"; + *log_file << "seqlen_k: " << fmha_args.seqlen_k << "\n"; + *log_file << "batch: " << fmha_args.batch << "\n"; + *log_file << "max_seqlen_q: " << fmha_args.max_seqlen_q << "\n"; + *log_file << "hdim_q: " << fmha_args.hdim_q << "\n"; + *log_file << "hdim_v: " << fmha_args.hdim_v << "\n"; + *log_file << "nhead_q: " << fmha_args.nhead_q << "\n"; + *log_file << "nhead_k: " << fmha_args.nhead_k << "\n"; + + *log_file << "scale_s: " << fmha_args.scale_s << "\n"; + + *log_file << "logits_soft_cap: " << fmha_args.logits_soft_cap << "\n"; + + *log_file << "stride_q: " << fmha_args.stride_q << "\n"; + *log_file << "stride_k: " << fmha_args.stride_k << "\n"; + *log_file << "stride_v: " << fmha_args.stride_v << "\n"; + *log_file << "stride_bias: " << fmha_args.stride_bias << "\n"; + *log_file << "stride_randval: " << fmha_args.stride_randval << "\n"; + *log_file << "stride_o: " << fmha_args.stride_o << "\n"; + *log_file << "nhead_stride_q: " << fmha_args.nhead_stride_q << "\n"; + *log_file << "nhead_stride_k: " << fmha_args.nhead_stride_k << "\n"; + *log_file << "nhead_stride_v: " << fmha_args.nhead_stride_v << "\n"; + *log_file << "nhead_stride_bias: " << fmha_args.nhead_stride_bias << "\n"; + *log_file << "nhead_stride_randval: " << fmha_args.nhead_stride_randval << "\n"; + *log_file << "nhead_stride_lse: " << fmha_args.nhead_stride_lse << "\n"; + *log_file << "nhead_stride_o: " << fmha_args.nhead_stride_o << "\n"; + *log_file << "batch_stride_q: " << fmha_args.batch_stride_q << "\n"; + *log_file << "batch_stride_k: " << fmha_args.batch_stride_k << "\n"; + *log_file << "batch_stride_v: " << fmha_args.batch_stride_v << "\n"; + *log_file << "batch_stride_bias: " << fmha_args.batch_stride_bias << "\n"; + *log_file << "batch_stride_randval: " << fmha_args.batch_stride_randval << "\n"; + *log_file << "batch_stride_lse: " << fmha_args.batch_stride_lse << "\n"; + *log_file << "batch_stride_o: " << fmha_args.batch_stride_o << "\n"; + + *log_file << "window_size_left: " << fmha_args.window_size_left << "\n"; + *log_file << "window_size_right: " << fmha_args.window_size_right << "\n"; + *log_file << "mask_type: " << fmha_args.mask_type << "\n"; + *log_file << "min_seqlen_q: " << fmha_args.min_seqlen_q << "\n"; + + *log_file << "p_drop: " << fmha_args.p_drop << "\n"; + *log_file << "s_randval: " << fmha_args.s_randval << "\n"; + + *log_file << "dropout_seed_ptr: " << std::get<0>(std::get>(fmha_args.drop_seed_offset)) << "\n"; + *log_file << "dropout_offset_ptr: " << std::get<1>(std::get>(fmha_args.drop_seed_offset)) << "\n"; } void dump_fwd_timings(const char* dump_path, float average_runtime){ @@ -179,14 +174,9 @@ hipError_t ck_attn_fwd( right = window_size_right; mask_enum mask_type = static_cast(attn_mask_type); - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; std::string data_type_str = get_data_type_str(dtype); @@ -277,7 +267,9 @@ hipError_t ck_attn_fwd( }(); // print ck traits and args when needed - log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); + if (auto* log_file = get_ck_log_stream()) { + log_fwd_config(log_file, __FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); + } float average_runtime = aiter::mha_fwd(fmha_args, stream_config, @@ -354,14 +346,9 @@ hipError_t ck_attn_varlen_fwd( bias_enum bias_type = bias_enum::no_bias; - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; std::string data_type_str = get_data_type_str(dtype); @@ -457,14 +444,17 @@ hipError_t ck_attn_varlen_fwd( // lse_thd_ptr used as buffer if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){ if(std::string(env_p) == "1"){ - if(ck_fused_attn_log_config){ - std::cout << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + if (auto* log_file = get_ck_log_stream()) { + *log_file + << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.\n"; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, cu_seqlen_q_padded_ptr, lse_thd_ptr, stream); } } // print ck traits and args when needed - log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); + if (auto* log_file = get_ck_log_stream()) { + log_fwd_config(log_file, __FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); + } float average_runtime = aiter::mha_fwd( fmha_args, diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp index 26c92ca2b..72a0122d9 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp @@ -1,10 +1,14 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ #include +#include +#include +#include +#include #include "ck_fused_attn_utils.hpp" #include "ck_fused_attn/ck_fused_attn.hpp" #include "mask.hpp" @@ -13,6 +17,41 @@ namespace ck_fused_attn{ +bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefix, const std::string& log_dir_str) { + // Explicitly use std::cout as a fallback + std::filesystem::path log_dir(log_dir_str); + std::ostringstream filename; + filename << file_prefix << "_" << getpid() << "_" << std::this_thread::get_id() << ".log"; + log_file.open(log_dir / filename.str(), std::ios_base::app); + if (!log_file.is_open()) { + std::cerr << "Failed to open log file: " << (log_dir / filename.str()) << "\n"; + return false; + } + return true; +} + +std::ostream* get_ck_log_stream() { + thread_local std::ofstream log_file; + thread_local std::ostream* log_stream = nullptr; + thread_local bool initialized = false; + if (!initialized) { + initialized = true; + if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG")) { + std::string log_dir_str(env_p); + if (!log_dir_str.empty() && log_dir_str != "0") { + if (log_dir_str == "1") { + log_stream = &std::cout; + } + else if (open_ck_fused_attn_log_file(log_file, "ck_fused_attn", log_dir_str)) { + log_stream = &log_file; + } + } + } + } + + return log_stream; +} + std::string get_data_type_str(DType dtype){ std::string data_type_str; if(dtype==DType::kFloat16){ diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index a75915ee2..3270ec32f 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ @@ -7,8 +7,7 @@ #ifndef CK_FUSED_ATTN_UTILS_H #define CK_FUSED_ATTN_UTILS_H -#include -#include +#include #include //forward declaration for ck_tile enum @@ -56,5 +55,7 @@ std::pair get_ck_bias_type_shape(BiasType attn_bias_type, uint64_t get_runtime_max_seqlen(uint64_t b, const void* cu_seqlen_ptr, const void* cu_seqlen_padded_ptr, void* workspace, hipStream_t stream); +std::ostream* get_ck_log_stream(); + }//namespace ck_fused_attn #endif // CK_FUSED_ATTN_UTILS_H