3636#include < cmath>
3737#include < cstdio>
3838#include < cstring>
39+ #include < memory>
3940#include < mutex>
4041#include < optional>
4142#include < queue>
43+ #include < unordered_map>
4244#include < unordered_set>
45+ #include < vector>
4346
4447#define GGML_COMMON_DECL_C
4548
@@ -770,6 +773,28 @@ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(i
770773}
771774
772775// cann buffer
776+
/**
 * @brief Tracks multi-threaded write progress for a single tensor.
 *
 * When multiple threads call set_tensor on different chunks of the same tensor,
 * this tracker accumulates progress and defers post-processing (quantized format
 * transform or ND-to-NZ conversion) until all data has been written.
 */
struct TensorSetTracker {
    std::mutex           mtx;                ///< Protects bytes_written and host_buffer access.
    size_t               bytes_written = 0;  ///< Accumulated bytes written so far.
    size_t               total_bytes;        ///< Total bytes of the tensor (ggml_nbytes).
    std::vector<uint8_t> host_buffer;        ///< Staging buffer for quantized tensors only.

    /**
     * @param total        Full tensor size in bytes.
     * @param need_staging When true, allocate a zero-filled host staging buffer
     *                     of @p total bytes (quantized tensors stage raw data
     *                     host-side before the deferred transform).
     */
    TensorSetTracker(size_t total, bool need_staging) :
        total_bytes(total),
        // Sized construction zero-fills, matching the original resize() behavior.
        host_buffer(need_staging ? total : size_t{0}) {}
};
797+
773798/* *
774799 * @brief Context for managing a CANN buffer associated with a specific device.
775800 *
@@ -780,6 +805,9 @@ struct ggml_backend_cann_buffer_context {
780805 int32_t device; // /< The device ID associated with this buffer context.
781806 void * dev_ptr = nullptr ; // /< Pointer to the device memory allocated for the buffer.
782807
808+ std::mutex tracker_mutex;
809+ std::unordered_map<ggml_tensor *, std::shared_ptr<TensorSetTracker>> trackers;
810+
783811 /* *
784812 * @brief Constructor to initialize the CANN buffer context.
785813 *
@@ -792,6 +820,28 @@ struct ggml_backend_cann_buffer_context {
792820 * @brief Destructor to free the device memory allocated for the buffer.
793821 */
794822 ~ggml_backend_cann_buffer_context () { ACL_CHECK (aclrtFree (dev_ptr)); }
823+
824+ /* *
825+ * @brief Get or create a tracker for the given tensor.
826+ */
827+ std::shared_ptr<TensorSetTracker> get_or_create_tracker (ggml_tensor * tensor, bool need_staging) {
828+ std::lock_guard<std::mutex> lock (tracker_mutex);
829+ auto it = trackers.find (tensor);
830+ if (it == trackers.end ()) {
831+ auto tracker = std::make_shared<TensorSetTracker>(ggml_nbytes (tensor), need_staging);
832+ trackers[tensor] = tracker;
833+ return tracker;
834+ }
835+ return it->second ;
836+ }
837+
838+ /* *
839+ * @brief Remove the tracker for the given tensor.
840+ */
841+ void remove_tracker (ggml_tensor * tensor) {
842+ std::lock_guard<std::mutex> lock (tracker_mutex);
843+ trackers.erase (tensor);
844+ }
795845};
796846
797847// cann buffer type
@@ -1124,6 +1174,7 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(ggml_backend_buffer
11241174 * designed to be used with a global array, one per device.
11251175 */
11261176struct ggml_cann_nz_workspace {
1177+ std::mutex mtx; // Protects ptr/allocated from concurrent access
11271178 void * ptr; // Pointer to allocated device buffer
11281179 size_t allocated; // Size of currently allocated buffer in bytes
11291180
@@ -1190,13 +1241,15 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
11901241 * @note The workspace buffer used in this function is managed globally and reused
11911242 * across calls. This reduces overhead from repeated memory allocation and deallocation.
11921243 */
1193- static void weight_format_to_nz (ggml_tensor * tensor, size_t offset, int device) {
1194- acl_tensor_ptr weightTransposed = ggml_cann_create_tensor (tensor, tensor->ne , tensor->nb , 2 , ACL_FORMAT_ND, offset );
1244+ static void weight_format_to_nz (ggml_tensor * tensor, int device) {
1245+ acl_tensor_ptr weightTransposed = ggml_cann_create_tensor (tensor, tensor->ne , tensor->nb , 2 , ACL_FORMAT_ND, 0 );
11951246 uint64_t workspaceSize = 0 ;
11961247 aclOpExecutor * executor;
11971248
11981249 // TransMatmulWeight
11991250 ACL_CHECK (aclnnTransMatmulWeightGetWorkspaceSize (weightTransposed.get (), &workspaceSize, &executor));
1251+
1252+ std::lock_guard<std::mutex> lock (g_nz_workspaces[device].mtx );
12001253 // Avoid frequent malloc/free of the workspace.
12011254 g_nz_workspaces[device].realloc (workspaceSize);
12021255
@@ -1210,7 +1263,13 @@ static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device)
12101263 * @brief Set tensor data in a CANN buffer.
12111264 *
12121265 * This function sets tensor data in a CANN buffer, handling transformations
1213- * if needed based on the tensor's type.
1266+ * if needed based on the tensor's type. It supports multi-threaded calls
1267+ * where different threads write different chunks of the same tensor.
1268+ *
1269+ * For quantized tensors (Q4_0/Q8_0), data is staged in a host buffer and
1270+ * the format transform is deferred until all chunks are written.
1271+ * For NZ weight tensors, chunks are uploaded directly but the ND-to-NZ
1272+ * conversion is deferred until all chunks are written.
12141273 *
12151274 * @param buffer The CANN buffer where the tensor data will be set.
12161275 * @param tensor Pointer to the tensor whose data will be set.
@@ -1226,25 +1285,57 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
12261285 ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context ;
12271286
12281287 ggml_cann_set_device (ctx->device );
1229- // TODO: refer to cann(#6017), it use thread's default stream.
1230- // For acl, synchronous functions use this default stream.
1231- // Why aclrtSynchronizeDevice?
12321288
12331289 // Only check env once.
12341290 static bool weight_to_nz = parse_bool (get_env_as_lowercase (" GGML_CANN_WEIGHT_NZ" ).value_or (" on" ));
1235- if (!need_transform (tensor->type )) {
1291+
1292+ const bool needs_transform = need_transform (tensor->type );
1293+ const bool needs_nz = !needs_transform && weight_to_nz && is_matmul_weight ((const ggml_tensor *) tensor);
1294+
1295+ if (!needs_transform && !needs_nz) {
1296+ // Plain tensor: direct memcpy is safe per-chunk, no tracker needed.
12361297 ACL_CHECK (aclrtMemcpy ((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
1237- if (weight_to_nz && is_matmul_weight ((const ggml_tensor *) tensor)) {
1298+ return ;
1299+ }
1300+
1301+ // Needs post-processing: use tracker to defer until all chunks are written.
1302+ auto tracker = ctx->get_or_create_tracker (tensor, needs_transform);
1303+
1304+ bool all_done = false ;
1305+ {
1306+ std::lock_guard<std::mutex> lock (tracker->mtx );
1307+
1308+ if (needs_transform) {
1309+ // Stage raw data in host buffer; transform requires the full tensor.
1310+ memcpy (tracker->host_buffer .data () + offset, data, size);
1311+ } else {
1312+ // NZ case: upload chunk to device immediately (different offsets, safe).
1313+ ACL_CHECK (aclrtMemcpy ((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
1314+ }
1315+
1316+ tracker->bytes_written += size;
1317+ all_done = (tracker->bytes_written >= tracker->total_bytes );
1318+ }
1319+
1320+ if (all_done) {
1321+ if (needs_transform) {
1322+ // All data staged, now transform entire tensor and upload at once.
1323+ size_t total = tracker->total_bytes ;
1324+ void * transform_buf = malloc (total);
1325+ ggml_backend_cann_transform (tensor, tracker->host_buffer .data (), transform_buf);
1326+ ACL_CHECK (aclrtMemcpy (tensor->data , total, transform_buf, total, ACL_MEMCPY_HOST_TO_DEVICE));
1327+ free (transform_buf);
1328+ }
1329+
1330+ if (needs_nz) {
1331+ // All data on device, now convert entire tensor to NZ format.
12381332 GGML_ASSERT (tensor->ne [2 ] == 1 );
12391333 GGML_ASSERT (tensor->ne [3 ] == 1 );
1240- weight_format_to_nz (tensor, offset, ctx->device );
1334+ weight_format_to_nz (tensor, ctx->device );
12411335 }
1242- } else {
1243- void * transform_buffer = malloc (size);
1244- ggml_backend_cann_transform (tensor, data, transform_buffer);
12451336
1246- ACL_CHECK ( aclrtMemcpy (( char *) tensor-> data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE));
1247- free (transform_buffer );
1337+ // Cleanup: release the tracker and its host_buffer memory.
1338+ ctx-> remove_tracker (tensor );
12481339 }
12491340}
12501341
0 commit comments