Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions ggml/src/ggml-cann/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,9 @@ struct ggml_backend_cann_context {
cann_task_queue task_queue;
bool async_mode;
bool support_set_rows;
bool multi_stream_enabled; /**< Whether multi-stream execution is enabled. */
int num_streams; /**< Number of streams to use for parallel execution. */
aclrtEvent stream_events[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Events for stream synchronization. */
// Rope Cache
void* rope_init_ptr = nullptr;
void* rope_sin_ptr = nullptr;
Expand All @@ -387,6 +390,7 @@ struct ggml_backend_cann_context {
int64_t f32_one_cache_element = 0;

aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
int current_stream_idx = 0; /**< Index of the current active stream for multi-stream execution. */

/**
* @brief Constructor for initializing the context with a given device.
Expand All @@ -408,6 +412,14 @@ struct ggml_backend_cann_context {
GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. "
"Falling back to eager mode.\n", __func__);
}

multi_stream_enabled = parse_bool(get_env("GGML_CANN_MULTI_STREAM").value_or(""));
auto num_streams_env = get_env("GGML_CANN_NUM_STREAMS");
num_streams = num_streams_env.has_value() ? std::min(std::stoi(num_streams_env.value()), GGML_CANN_MAX_STREAMS) : 4;
if (multi_stream_enabled) {
GGML_LOG_INFO("%s: device %d multi-stream execution is ON with %d streams\n", __func__,
device, num_streams);
}
}

/**
Expand All @@ -423,6 +435,9 @@ struct ggml_backend_cann_context {
if (streams[i] != nullptr) {
ACL_CHECK(aclrtDestroyStream(streams[i]));
}
if (stream_events[i] != nullptr) {
ACL_CHECK(aclrtDestroyEvent(stream_events[i]));
}
}
if(rope_init_ptr != nullptr) {
ACL_CHECK(aclrtFree(rope_init_ptr));
Expand All @@ -443,22 +458,43 @@ struct ggml_backend_cann_context {

/**
* @brief Get or create a stream for a given index.
* @param stream Index of the stream.
* @param stream_idx Index of the stream.
* @return The stream corresponding to the given index.
*/
aclrtStream stream(int stream) {
if (streams[stream] == nullptr) {
aclrtStream stream(int stream_idx) {
if (streams[stream_idx] == nullptr) {
ggml_cann_set_device(device);
ACL_CHECK(aclrtCreateStream(&streams[stream]));
ACL_CHECK(aclrtCreateStream(&streams[stream_idx]));
}
return streams[stream];
return streams[stream_idx];
}

/**
* @brief Get or create the default stream (index 0).
* @return The default stream.
* @brief Get or create the current active stream (based on current_stream_idx).
* @return The current active stream.
*/
aclrtStream stream() { return stream(0); }
aclrtStream stream() { return stream(current_stream_idx); }

/**
* @brief Set the current active stream index for multi-stream execution.
* @param idx The stream index to set as active.
*/
void set_current_stream(int idx) {
current_stream_idx = idx % num_streams;
}

/**
* @brief Get or create an event for stream synchronization.
* @param stream_idx Index of the stream for which to get the event.
* @return The event for the specified stream.
*/
aclrtEvent get_stream_event(int stream_idx) {
if (stream_events[stream_idx] == nullptr) {
ggml_cann_set_device(device);
ACL_CHECK(aclrtCreateEvent(&stream_events[stream_idx]));
}
return stream_events[stream_idx];
}

// TODO: each stream should have a memory pool.
std::unique_ptr<ggml_cann_pool>
Expand Down
250 changes: 240 additions & 10 deletions ggml/src/ggml-cann/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@
#include <stdarg.h>
#include <aclnnop/aclnn_trans_matmul_weight.h>

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <mutex>
#include <queue>
#include <chrono>
#include <unordered_map>
#include <unordered_set>
#include <optional>
#include <vector>

#include "ggml-impl.h"
#include "ggml-backend-impl.h"
Expand Down Expand Up @@ -2063,7 +2066,8 @@ static bool ggml_backend_cann_cpy_tensor_async(
* @brief Synchronizes a CANN backend.
*
* This function synchronizes the specified CANN backend by waiting for all
* operations in its associated stream to complete.
* operations in its associated streams to complete. When multi-stream execution
* is enabled, it synchronizes all active streams.
*
* @param backend Pointer to the CANN backend structure to synchronize.
*/
Expand All @@ -2072,7 +2076,17 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
(ggml_backend_cann_context*)backend->context;
cann_ctx->task_queue.wait();
ggml_cann_set_device(cann_ctx->device);
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));

// Synchronize all active streams when multi-stream is enabled
if (cann_ctx->multi_stream_enabled) {
for (int i = 0; i < cann_ctx->num_streams; ++i) {
if (cann_ctx->streams[i] != nullptr) {
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->streams[i]));
}
}
} else {
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
}
}

#ifdef USE_ACL_GRAPH
Expand Down Expand Up @@ -2173,6 +2187,148 @@ static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx,
}
#endif // USE_ACL_GRAPH

/**
 * @brief Check if a tensor operation is an empty/view operation that doesn't do computation.
 *
 * These ops only describe a different view of existing data (or no op at
 * all), so the scheduler can skip them when assigning work to streams.
 *
 * @param node The tensor node to check.
 * @return true if the node is empty or a view operation, false otherwise.
 */
static bool ggml_cann_is_empty_op(const ggml_tensor * node) {
    switch (node->op) {
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
            return true;
        default:
            return false;
    }
}

/**
* @brief Check if dst tensor depends on src tensor.
*
* @param dst The destination tensor to check.
* @param src The source tensor to check.
* @return true if dst depends on src, false otherwise.
*/
static bool ggml_cann_is_src_of(const ggml_tensor * dst, const ggml_tensor * src) {
// Check direct source dependency
for (int s = 0; s < GGML_MAX_SRC; ++s) {
if (dst->src[s] == src) {
return true;
}
}
// Check implicit dependency if they view the same tensor
const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst;
const ggml_tensor * src2 = src->view_src ? src->view_src : src;
if (dst2 == src2) {
return true;
}
return false;
}

/**
 * @brief Optimize the graph to allow more parallel execution.
 *
 * Greedily reorders graph->nodes so that runs of mutually independent
 * nodes become adjacent, which lets the multi-stream executor dispatch
 * them concurrently. The reorder only kicks in when at least half of the
 * counted compute nodes are "small" (ggml_nrows <= 8), i.e. when launch
 * latency likely dominates and overlap pays off. Set the
 * GGML_CANN_DISABLE_OPTIMIZE_GRAPH env var to disable it entirely.
 *
 * @param backend The CANN backend.
 * @param graph The computation graph to optimize.
 */
static void ggml_cann_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * graph) {
    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *)backend->context;

    // Opt-out switch; function-local static so the env var is parsed once.
    static bool disable_optimize_graph = parse_bool(get_env("GGML_CANN_DISABLE_OPTIMIZE_GRAPH").value_or(""));
    if (disable_optimize_graph) {
        return;
    }

    // Count real compute nodes (views/no-ops and SET_ROWS excluded) and how
    // many of them are small. If fewer than half are small, reordering is
    // not expected to help, so leave the graph untouched.
    int num_small_nodes = 0;
    int num_counted_nodes = 0;
    for (int i = 0; i < graph->n_nodes; ++i) {
        if (!ggml_cann_is_empty_op(graph->nodes[i]) &&
            graph->nodes[i]->op != GGML_OP_SET_ROWS) {
            if (ggml_nrows(graph->nodes[i]) <= 8) {
                num_small_nodes++;
            }
            num_counted_nodes++;
        }
    }
    if (num_small_nodes < num_counted_nodes / 2) {
        return;
    }

    std::vector<ggml_tensor *> new_order;         // nodes in the new execution order
    std::vector<bool> used(graph->n_nodes, false); // used[i]: node i already placed
    int first_unused = 0;                          // lowest index not yet placed

    while (first_unused < graph->n_nodes) {
        // Indices (into graph->nodes) emitted together as one group; nodes
        // within one group are intended to be mutually independent.
        std::vector<int> current_set;

        current_set.push_back(first_unused);

        // Bounded lookahead window for candidates to join the group.
        const int NUM_TO_CHECK = 20;
        for (int j = first_unused + 1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
            if (used[j]) {
                continue;
            }
            if (ggml_cann_is_empty_op(graph->nodes[j])) {
                continue;
            }
            // Node j may join only if it does not depend on any still-unplaced
            // node c in [first_unused, j). Exception: an RMS_NORM immediately
            // followed by the MUL that consumes it is allowed to stay together
            // in the same group (presumably handled as a fused pair downstream
            // — TODO confirm).
            bool ok = true;
            for (int c = first_unused; c < j; ++c) {
                if (!used[c] &&
                    ggml_cann_is_src_of(graph->nodes[j], graph->nodes[c]) &&
                    !(j == c + 1 && c == (int)current_set.back() &&
                      graph->nodes[c]->op == GGML_OP_RMS_NORM &&
                      graph->nodes[j]->op == GGML_OP_MUL)) {
                    ok = false;
                    break;
                }
            }
            if (ok) {
                current_set.push_back(j);
            }
        }

        // Second pass: also pull forward empty/view ops whose producers are
        // either already placed or inside this group. Skipped when the group
        // ends with an ADD — NOTE(review): the reason for the ADD exclusion
        // is not evident from this code; verify against the scheduler.
        if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) {
            for (int j = first_unused + 1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
                if (used[j]) {
                    continue;
                }
                if (!ggml_cann_is_empty_op(graph->nodes[j])) {
                    continue;
                }
                bool ok = true;
                for (int c = first_unused; c < j; ++c) {
                    bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end();
                    // skip views whose srcs haven't been processed
                    if (!used[c] &&
                        ggml_cann_is_src_of(graph->nodes[j], graph->nodes[c]) &&
                        !c_in_current_set) {
                        ok = false;
                        break;
                    }
                }
                if (ok) {
                    current_set.push_back(j);
                }
            }
        }

        // Push the current set into new_order
        for (auto c : current_set) {
            new_order.push_back(graph->nodes[c]);
            used[c] = true;
        }
        // Advance past the prefix of already-placed nodes.
        while (first_unused < graph->n_nodes && used[first_unused]) {
            first_unused++;
        }
    }

    // Replace the graph with the new order. Every node is placed exactly
    // once, so new_order has exactly n_nodes entries.
    for (int i = 0; i < graph->n_nodes; ++i) {
        graph->nodes[i] = new_order[i];
    }

    GGML_UNUSED(cann_ctx);
}

/**
* @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
*
Expand Down Expand Up @@ -2201,18 +2357,89 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
// With the use of CANN graphs, the execution will be performed by the graph launch.
if (!use_cann_graph || cann_graph_update_required) {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (cann_ctx->multi_stream_enabled && cann_ctx->num_streams > 1) {
std::unordered_map<const ggml_tensor*, int> tensor_stream_map;
int current_stream = 0;
int nodes_in_current_batch = 0;
const int batch_size = 4;

if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];

if (ggml_is_empty(node) || ggml_cann_is_empty_op(node)) {
continue;
}

int max_src_stream = -1;
for (int s = 0; s < GGML_MAX_SRC; ++s) {
if (node->src[s] != nullptr) {
const ggml_tensor* src = node->src[s]->view_src ? node->src[s]->view_src : node->src[s];
auto it = tensor_stream_map.find(src);
if (it != tensor_stream_map.end()) {
max_src_stream = std::max(max_src_stream, it->second);
}
}
}

int target_stream;
if (max_src_stream >= 0) {
target_stream = max_src_stream;
} else {
target_stream = current_stream;
nodes_in_current_batch++;
if (nodes_in_current_batch >= batch_size) {
current_stream = (current_stream + 1) % cann_ctx->num_streams;
nodes_in_current_batch = 0;
}
}

for (int s = 0; s < GGML_MAX_SRC; ++s) {
if (node->src[s] != nullptr) {
const ggml_tensor* src = node->src[s]->view_src ? node->src[s]->view_src : node->src[s];
auto it = tensor_stream_map.find(src);
if (it != tensor_stream_map.end() && it->second != target_stream) {
aclrtEvent event = cann_ctx->get_stream_event(it->second);
ACL_CHECK(aclrtRecordEvent(event, cann_ctx->stream(it->second)));
ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(target_stream), event));
}
}
}

cann_ctx->set_current_stream(target_stream);
bool ok = ggml_cann_compute_forward(*cann_ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
}
GGML_ASSERT(ok);

const ggml_tensor* out_tensor = node->view_src ? node->view_src : node;
tensor_stream_map[out_tensor] = target_stream;
}

bool ok = ggml_cann_compute_forward(*cann_ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
for (int s = 0; s < cann_ctx->num_streams; ++s) {
if (cann_ctx->streams[s] != nullptr) {
aclrtEvent event = cann_ctx->get_stream_event(s);
ACL_CHECK(aclrtRecordEvent(event, cann_ctx->stream(s)));
if (s != 0) {
ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(0), event));
}
}
}
cann_ctx->set_current_stream(0);
} else {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];

if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}

bool ok = ggml_cann_compute_forward(*cann_ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
}
GGML_ASSERT(ok);
}
GGML_ASSERT(ok);
}
}

Expand Down Expand Up @@ -2247,6 +2474,9 @@ static enum ggml_status ggml_backend_cann_graph_compute(
(ggml_backend_cann_context*)backend->context;
ggml_cann_set_device(cann_ctx->device);
release_nz_workspace();

ggml_cann_optimize_graph(backend, cgraph);

#ifdef USE_ACL_GRAPH
bool use_cann_graph = true;
bool cann_graph_update_required = false;
Expand Down