From 7ac470b20a0589bf5efd1ddc7a1c51ba8bf7b8c5 Mon Sep 17 00:00:00 2001 From: solarisch Date: Sun, 28 Dec 2025 23:17:44 +0800 Subject: [PATCH] feat: multi stream support --- ggml/src/ggml-cann/common.h | 52 ++++++- ggml/src/ggml-cann/ggml-cann.cpp | 250 +++++++++++++++++++++++++++++-- 2 files changed, 284 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 33794062f56..ee5b993752a 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -375,6 +375,9 @@ struct ggml_backend_cann_context { cann_task_queue task_queue; bool async_mode; bool support_set_rows; + bool multi_stream_enabled; /**< Whether multi-stream execution is enabled. */ + int num_streams; /**< Number of streams to use for parallel execution. */ + aclrtEvent stream_events[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Events for stream synchronization. */ // Rope Cache void* rope_init_ptr = nullptr; void* rope_sin_ptr = nullptr; @@ -387,6 +390,7 @@ struct ggml_backend_cann_context { int64_t f32_one_cache_element = 0; aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */ + int current_stream_idx = 0; /**< Index of the current active stream for multi-stream execution. */ /** * @brief Constructor for initializing the context with a given device. @@ -408,6 +412,14 @@ struct ggml_backend_cann_context { GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. " "Falling back to eager mode.\n", __func__); } + + multi_stream_enabled = parse_bool(get_env("GGML_CANN_MULTI_STREAM").value_or("")); + auto num_streams_env = get_env("GGML_CANN_NUM_STREAMS"); + num_streams = num_streams_env.has_value() ? 
std::min(std::stoi(num_streams_env.value()), GGML_CANN_MAX_STREAMS) : 4; + if (multi_stream_enabled) { + GGML_LOG_INFO("%s: device %d multi-stream execution is ON with %d streams\n", __func__, + device, num_streams); + } } /** @@ -423,6 +435,9 @@ struct ggml_backend_cann_context { if (streams[i] != nullptr) { ACL_CHECK(aclrtDestroyStream(streams[i])); } + if (stream_events[i] != nullptr) { + ACL_CHECK(aclrtDestroyEvent(stream_events[i])); + } } if(rope_init_ptr != nullptr) { ACL_CHECK(aclrtFree(rope_init_ptr)); @@ -443,22 +458,43 @@ struct ggml_backend_cann_context { /** * @brief Get or create a stream for a given index. - * @param stream Index of the stream. + * @param stream_idx Index of the stream. * @return The stream corresponding to the given index. */ - aclrtStream stream(int stream) { - if (streams[stream] == nullptr) { + aclrtStream stream(int stream_idx) { + if (streams[stream_idx] == nullptr) { ggml_cann_set_device(device); - ACL_CHECK(aclrtCreateStream(&streams[stream])); + ACL_CHECK(aclrtCreateStream(&streams[stream_idx])); } - return streams[stream]; + return streams[stream_idx]; } /** - * @brief Get or create the default stream (index 0). - * @return The default stream. + * @brief Get or create the current active stream (based on current_stream_idx). + * @return The current active stream. */ - aclrtStream stream() { return stream(0); } + aclrtStream stream() { return stream(current_stream_idx); } + + /** + * @brief Set the current active stream index for multi-stream execution. + * @param idx The stream index to set as active. + */ + void set_current_stream(int idx) { + current_stream_idx = idx % num_streams; + } + + /** + * @brief Get or create an event for stream synchronization. + * @param stream_idx Index of the stream for which to get the event. + * @return The event for the specified stream. 
+ */ + aclrtEvent get_stream_event(int stream_idx) { + if (stream_events[stream_idx] == nullptr) { + ggml_cann_set_device(device); + ACL_CHECK(aclrtCreateEvent(&stream_events[stream_idx])); + } + return stream_events[stream_idx]; + } // TODO: each stream should have a memory pool. std::unique_ptr diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index cb8af42ebf9..bad3334dc59 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -26,14 +26,17 @@ #include #include +#include #include #include #include #include #include #include +#include #include #include +#include #include "ggml-impl.h" #include "ggml-backend-impl.h" @@ -2063,7 +2066,8 @@ static bool ggml_backend_cann_cpy_tensor_async( * @brief Synchronizes a CANN backend. * * This function synchronizes the specified CANN backend by waiting for all - * operations in its associated stream to complete. + * operations in its associated streams to complete. When multi-stream execution + * is enabled, it synchronizes all active streams. * * @param backend Pointer to the CANN backend structure to synchronize. */ @@ -2072,7 +2076,17 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) { (ggml_backend_cann_context*)backend->context; cann_ctx->task_queue.wait(); ggml_cann_set_device(cann_ctx->device); - ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); + + // Synchronize all active streams when multi-stream is enabled + if (cann_ctx->multi_stream_enabled) { + for (int i = 0; i < cann_ctx->num_streams; ++i) { + if (cann_ctx->streams[i] != nullptr) { + ACL_CHECK(aclrtSynchronizeStream(cann_ctx->streams[i])); + } + } + } else { + ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); + } } #ifdef USE_ACL_GRAPH @@ -2173,6 +2187,148 @@ static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, } #endif // USE_ACL_GRAPH +/** + * @brief Check if a tensor operation is an empty/view operation that doesn't do computation. 
+ * + * @param node The tensor node to check. + * @return true if the node is empty or a view operation, false otherwise. + */ +static bool ggml_cann_is_empty_op(const ggml_tensor * node) { + return node->op == GGML_OP_NONE || + node->op == GGML_OP_RESHAPE || + node->op == GGML_OP_TRANSPOSE || + node->op == GGML_OP_VIEW || + node->op == GGML_OP_PERMUTE; +} + +/** + * @brief Check if dst tensor depends on src tensor. + * + * @param dst The destination tensor to check. + * @param src The source tensor to check. + * @return true if dst depends on src, false otherwise. + */ +static bool ggml_cann_is_src_of(const ggml_tensor * dst, const ggml_tensor * src) { + // Check direct source dependency + for (int s = 0; s < GGML_MAX_SRC; ++s) { + if (dst->src[s] == src) { + return true; + } + } + // Check implicit dependency if they view the same tensor + const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst; + const ggml_tensor * src2 = src->view_src ? src->view_src : src; + if (dst2 == src2) { + return true; + } + return false; +} + +/** + * @brief Optimize the graph to allow more parallel execution. + * + * @param backend The CANN backend. + * @param graph The computation graph to optimize. 
+ */ +static void ggml_cann_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * graph) { + ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *)backend->context; + + static bool disable_optimize_graph = parse_bool(get_env("GGML_CANN_DISABLE_OPTIMIZE_GRAPH").value_or("")); + if (disable_optimize_graph) { + return; + } + + int num_small_nodes = 0; + int num_counted_nodes = 0; + for (int i = 0; i < graph->n_nodes; ++i) { + if (!ggml_cann_is_empty_op(graph->nodes[i]) && + graph->nodes[i]->op != GGML_OP_SET_ROWS) { + if (ggml_nrows(graph->nodes[i]) <= 8) { + num_small_nodes++; + } + num_counted_nodes++; + } + } + if (num_small_nodes < num_counted_nodes / 2) { + return; + } + + std::vector<ggml_tensor *> new_order; + std::vector<bool> used(graph->n_nodes, false); + int first_unused = 0; + + while (first_unused < graph->n_nodes) { + std::vector<int> current_set; + + current_set.push_back(first_unused); + + const int NUM_TO_CHECK = 20; + for (int j = first_unused + 1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (ggml_cann_is_empty_op(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + if (!used[c] && + ggml_cann_is_src_of(graph->nodes[j], graph->nodes[c]) && + !(j == c + 1 && c == (int)current_set.back() && + graph->nodes[c]->op == GGML_OP_RMS_NORM && + graph->nodes[j]->op == GGML_OP_MUL)) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + + if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) { + for (int j = first_unused + 1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (!ggml_cann_is_empty_op(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end(); + // skip views whose srcs haven't been processed + if (!used[c] && + 
ggml_cann_is_src_of(graph->nodes[j], graph->nodes[c]) && + !c_in_current_set) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + } + + // Push the current set into new_order + for (auto c : current_set) { + new_order.push_back(graph->nodes[c]); + used[c] = true; + } + while (first_unused < graph->n_nodes && used[first_unused]) { + first_unused++; + } + } + + // Replace the graph with the new order + for (int i = 0; i < graph->n_nodes; ++i) { + graph->nodes[i] = new_order[i]; + } + + GGML_UNUSED(cann_ctx); +} + /** * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API. * @@ -2201,18 +2357,89 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph. // With the use of CANN graphs, the execution will be performed by the graph launch. if (!use_cann_graph || cann_graph_update_required) { - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; + if (cann_ctx->multi_stream_enabled && cann_ctx->num_streams > 1) { + std::unordered_map<const ggml_tensor *, int> tensor_stream_map; + int current_stream = 0; + int nodes_in_current_batch = 0; + const int batch_size = 4; - if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { - continue; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + if (ggml_is_empty(node) || ggml_cann_is_empty_op(node)) { + continue; + } + + int max_src_stream = -1; + for (int s = 0; s < GGML_MAX_SRC; ++s) { + if (node->src[s] != nullptr) { + const ggml_tensor* src = node->src[s]->view_src ? 
node->src[s]->view_src : node->src[s]; + auto it = tensor_stream_map.find(src); + if (it != tensor_stream_map.end()) { + max_src_stream = std::max(max_src_stream, it->second); + } + } + } + + int target_stream; + if (max_src_stream >= 0) { + target_stream = max_src_stream; + } else { + target_stream = current_stream; + nodes_in_current_batch++; + if (nodes_in_current_batch >= batch_size) { + current_stream = (current_stream + 1) % cann_ctx->num_streams; + nodes_in_current_batch = 0; + } + } + + for (int s = 0; s < GGML_MAX_SRC; ++s) { + if (node->src[s] != nullptr) { + const ggml_tensor* src = node->src[s]->view_src ? node->src[s]->view_src : node->src[s]; + auto it = tensor_stream_map.find(src); + if (it != tensor_stream_map.end() && it->second != target_stream) { + aclrtEvent event = cann_ctx->get_stream_event(it->second); + ACL_CHECK(aclrtRecordEvent(event, cann_ctx->stream(it->second))); + ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(target_stream), event)); + } + } + } + + cann_ctx->set_current_stream(target_stream); + bool ok = ggml_cann_compute_forward(*cann_ctx, node); + if (!ok) { + GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + + const ggml_tensor* out_tensor = node->view_src ? 
node->view_src : node; + tensor_stream_map[out_tensor] = target_stream; } - bool ok = ggml_cann_compute_forward(*cann_ctx, node); - if (!ok) { - GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + for (int s = 0; s < cann_ctx->num_streams; ++s) { + if (cann_ctx->streams[s] != nullptr) { + aclrtEvent event = cann_ctx->get_stream_event(s); + ACL_CHECK(aclrtRecordEvent(event, cann_ctx->stream(s))); + if (s != 0) { + ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(0), event)); + } + } + } + cann_ctx->set_current_stream(0); + } else { + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + + bool ok = ggml_cann_compute_forward(*cann_ctx, node); + if (!ok) { + GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); } - GGML_ASSERT(ok); } } @@ -2247,6 +2474,9 @@ static enum ggml_status ggml_backend_cann_graph_compute( (ggml_backend_cann_context*)backend->context; ggml_cann_set_device(cann_ctx->device); release_nz_workspace(); + + ggml_cann_optimize_graph(backend, cgraph); + #ifdef USE_ACL_GRAPH bool use_cann_graph = true; bool cann_graph_update_required = false;