Skip to content

Commit b9305e7

Browse files
committed
CANN: fix multi-thread set_tensor race conditions
When ollama calls ggml_backend_tensor_set from multiple threads (each writing a different chunk of the same tensor), the CANN backend exhibits three concurrency issues:

1. Quantized tensors (Q4_0/Q8_0) require a full-tensor format transform before being uploaded to the device. Per-chunk transforms produced corrupt data.
2. ND-to-NZ weight conversion requires the complete tensor data to be on the device. Per-chunk conversion operated on incomplete data.
3. The global g_nz_workspaces array was accessed concurrently without any protection.

Fix this by introducing a TensorSetTracker that accumulates write progress per tensor. For quantized tensors, raw data is staged in a host buffer, and the transform + upload is deferred until all chunks have arrived. For NZ weights, chunks are uploaded directly, but the conversion is deferred until the last chunk has been written. The tracker and its staging buffer are released immediately after post-processing completes.

Also add a per-device mutex to g_nz_workspaces to prevent data races.
1 parent a0ed91a commit b9305e7

File tree

1 file changed

+105
-14
lines changed

1 file changed

+105
-14
lines changed

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 105 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,13 @@
3636
#include <cmath>
3737
#include <cstdio>
3838
#include <cstring>
39+
#include <memory>
3940
#include <mutex>
4041
#include <optional>
4142
#include <queue>
43+
#include <unordered_map>
4244
#include <unordered_set>
45+
#include <vector>
4346

4447
#define GGML_COMMON_DECL_C
4548

@@ -770,6 +773,28 @@ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(i
770773
}
771774

772775
// cann buffer
776+
777+
/**
 * @brief Tracks multi-threaded write progress for a single tensor.
 *
 * When multiple threads call set_tensor on different chunks of the same tensor,
 * this tracker accumulates progress and defers post-processing (quantized format
 * transform or ND-to-NZ conversion) until all data has been written.
 *
 * The struct owns a std::mutex, so it is neither copyable nor movable; share it
 * through a pointer instead of passing it by value.
 */
struct TensorSetTracker {
    std::mutex           mtx;                ///< Protects bytes_written and host_buffer access.
    size_t               bytes_written = 0;  ///< Accumulated bytes written so far.
    size_t               total_bytes   = 0;  ///< Total bytes of the tensor (ggml_nbytes).
    std::vector<uint8_t> host_buffer;        ///< Staging buffer for quantized tensors only.

    /**
     * @brief Create a tracker for a tensor of @p total bytes.
     *
     * @param total        Total size of the tensor in bytes.
     * @param need_staging When true, host_buffer is resized to @p total bytes
     *                     (value-initialized, i.e. zero-filled); when false the
     *                     buffer stays empty and no host memory is reserved.
     */
    TensorSetTracker(size_t total, bool need_staging) : total_bytes(total) {
        if (need_staging) {
            host_buffer.resize(total);
        }
    }
};
797+
773798
/**
774799
* @brief Context for managing a CANN buffer associated with a specific device.
775800
*
@@ -780,6 +805,9 @@ struct ggml_backend_cann_buffer_context {
780805
int32_t device; ///< The device ID associated with this buffer context.
781806
void * dev_ptr = nullptr; ///< Pointer to the device memory allocated for the buffer.
782807

808+
std::mutex tracker_mutex;
809+
std::unordered_map<ggml_tensor *, std::shared_ptr<TensorSetTracker>> trackers;
810+
783811
/**
784812
* @brief Constructor to initialize the CANN buffer context.
785813
*
@@ -792,6 +820,28 @@ struct ggml_backend_cann_buffer_context {
792820
* @brief Destructor to free the device memory allocated for the buffer.
793821
*/
794822
// Only dev_ptr is released explicitly here; the tracker_mutex and trackers
// members are destroyed implicitly once this body has run.
~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
823+
824+
/**
825+
* @brief Get or create a tracker for the given tensor.
826+
*/
827+
std::shared_ptr<TensorSetTracker> get_or_create_tracker(ggml_tensor * tensor, bool need_staging) {
828+
std::lock_guard<std::mutex> lock(tracker_mutex);
829+
auto it = trackers.find(tensor);
830+
if (it == trackers.end()) {
831+
auto tracker = std::make_shared<TensorSetTracker>(ggml_nbytes(tensor), need_staging);
832+
trackers[tensor] = tracker;
833+
return tracker;
834+
}
835+
return it->second;
836+
}
837+
838+
/**
839+
* @brief Remove the tracker for the given tensor.
840+
*/
841+
void remove_tracker(ggml_tensor * tensor) {
842+
std::lock_guard<std::mutex> lock(tracker_mutex);
843+
trackers.erase(tensor);
844+
}
795845
};
796846

797847
// cann buffer type
@@ -1124,6 +1174,7 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(ggml_backend_buffer
11241174
* designed to be used with a global array, one per device.
11251175
*/
11261176
struct ggml_cann_nz_workspace {
1177+
std::mutex mtx; // Protects ptr/allocated from concurrent access
11271178
void * ptr; // Pointer to allocated device buffer
11281179
size_t allocated; // Size of currently allocated buffer in bytes
11291180

@@ -1190,13 +1241,15 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
11901241
* @note The workspace buffer used in this function is managed globally and reused
11911242
* across calls. This reduces overhead from repeated memory allocation and deallocation.
11921243
*/
1193-
static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) {
1194-
acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
1244+
static void weight_format_to_nz(ggml_tensor * tensor, int device) {
1245+
acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, 0);
11951246
uint64_t workspaceSize = 0;
11961247
aclOpExecutor * executor;
11971248

11981249
// TransMatmulWeight
11991250
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor));
1251+
1252+
std::lock_guard<std::mutex> lock(g_nz_workspaces[device].mtx);
12001253
// Avoid frequent malloc/free of the workspace.
12011254
g_nz_workspaces[device].realloc(workspaceSize);
12021255

@@ -1210,7 +1263,13 @@ static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device)
12101263
* @brief Set tensor data in a CANN buffer.
12111264
*
12121265
* This function sets tensor data in a CANN buffer, handling transformations
1213-
* if needed based on the tensor's type.
1266+
* if needed based on the tensor's type. It supports multi-threaded calls
1267+
* where different threads write different chunks of the same tensor.
1268+
*
1269+
* For quantized tensors (Q4_0/Q8_0), data is staged in a host buffer and
1270+
* the format transform is deferred until all chunks are written.
1271+
* For NZ weight tensors, chunks are uploaded directly but the ND-to-NZ
1272+
* conversion is deferred until all chunks are written.
12141273
*
12151274
* @param buffer The CANN buffer where the tensor data will be set.
12161275
* @param tensor Pointer to the tensor whose data will be set.
@@ -1226,25 +1285,57 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
12261285
ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
12271286

12281287
ggml_cann_set_device(ctx->device);
1229-
// TODO: refer to cann(#6017), it use thread's default stream.
1230-
// For acl, synchronous functions use this default stream.
1231-
// Why aclrtSynchronizeDevice?
12321288

12331289
// Only check env once.
12341290
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
1235-
if (!need_transform(tensor->type)) {
1291+
1292+
const bool needs_transform = need_transform(tensor->type);
1293+
const bool needs_nz = !needs_transform && weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor);
1294+
1295+
if (!needs_transform && !needs_nz) {
1296+
// Plain tensor: direct memcpy is safe per-chunk, no tracker needed.
12361297
ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
1237-
if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
1298+
return;
1299+
}
1300+
1301+
// Needs post-processing: use tracker to defer until all chunks are written.
1302+
auto tracker = ctx->get_or_create_tracker(tensor, needs_transform);
1303+
1304+
bool all_done = false;
1305+
{
1306+
std::lock_guard<std::mutex> lock(tracker->mtx);
1307+
1308+
if (needs_transform) {
1309+
// Stage raw data in host buffer; transform requires the full tensor.
1310+
memcpy(tracker->host_buffer.data() + offset, data, size);
1311+
} else {
1312+
// NZ case: upload chunk to device immediately (different offsets, safe).
1313+
ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
1314+
}
1315+
1316+
tracker->bytes_written += size;
1317+
all_done = (tracker->bytes_written >= tracker->total_bytes);
1318+
}
1319+
1320+
if (all_done) {
1321+
if (needs_transform) {
1322+
// All data staged, now transform entire tensor and upload at once.
1323+
size_t total = tracker->total_bytes;
1324+
void * transform_buf = malloc(total);
1325+
ggml_backend_cann_transform(tensor, tracker->host_buffer.data(), transform_buf);
1326+
ACL_CHECK(aclrtMemcpy(tensor->data, total, transform_buf, total, ACL_MEMCPY_HOST_TO_DEVICE));
1327+
free(transform_buf);
1328+
}
1329+
1330+
if (needs_nz) {
1331+
// All data on device, now convert entire tensor to NZ format.
12381332
GGML_ASSERT(tensor->ne[2] == 1);
12391333
GGML_ASSERT(tensor->ne[3] == 1);
1240-
weight_format_to_nz(tensor, offset, ctx->device);
1334+
weight_format_to_nz(tensor, ctx->device);
12411335
}
1242-
} else {
1243-
void * transform_buffer = malloc(size);
1244-
ggml_backend_cann_transform(tensor, data, transform_buffer);
12451336

1246-
ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE));
1247-
free(transform_buffer);
1337+
// Cleanup: release the tracker and its host_buffer memory.
1338+
ctx->remove_tracker(tensor);
12481339
}
12491340
}
12501341

0 commit comments

Comments
 (0)