diff --git a/cpp/src/cluster/detail/agglomerative.cuh b/cpp/src/cluster/detail/agglomerative.cuh index 16c2822c52..bc920f7701 100644 --- a/cpp/src/cluster/detail/agglomerative.cuh +++ b/cpp/src/cluster/detail/agglomerative.cuh @@ -5,9 +5,13 @@ #pragma once +#include +#include +#include #include #include #include +#include #include #include @@ -108,9 +112,15 @@ void build_dendrogram_host(raft::resources const& handle, std::vector mst_dst_h(n_edges); std::vector mst_weights_h(n_edges); - raft::update_host(mst_src_h.data(), rows, n_edges, stream); - raft::update_host(mst_dst_h.data(), cols, n_edges, stream); - raft::update_host(mst_weights_h.data(), data, n_edges, stream); + raft::copy(handle, + raft::make_host_vector_view(mst_src_h.data(), n_edges), + raft::make_device_vector_view(rows, n_edges)); + raft::copy(handle, + raft::make_host_vector_view(mst_dst_h.data(), n_edges), + raft::make_device_vector_view(cols, n_edges)); + raft::copy(handle, + raft::make_host_vector_view(mst_weights_h.data(), n_edges), + raft::make_device_vector_view(data, n_edges)); raft::resource::sync_stream(handle, stream); @@ -138,9 +148,15 @@ void build_dendrogram_host(raft::resources const& handle, U.perform_union(aa, bb); } - raft::update_device(children, children_h.data(), n_edges * 2, stream); - raft::update_device(out_size, out_size_h.data(), n_edges, stream); - raft::update_device(out_delta, out_delta_h.data(), n_edges, stream); + raft::copy(handle, + raft::make_device_vector_view(children, n_edges * 2), + raft::make_host_vector_view(children_h.data(), n_edges * 2)); + raft::copy(handle, + raft::make_device_vector_view(out_size, n_edges), + raft::make_host_vector_view(out_size_h.data(), n_edges)); + raft::copy(handle, + raft::make_device_vector_view(out_delta, n_edges), + raft::make_host_vector_view(out_delta_h.data(), n_edges)); } template @@ -236,7 +252,8 @@ void extract_flattened_clusters(raft::resources const& handle, // Handle special case where n_clusters == 1 if (n_clusters == 1) { - thrust::fill(thrust_policy, labels, labels + n_leaves, 0); + raft::matrix::fill( + handle, raft::make_device_vector_view(labels, n_leaves), value_idx(0)); } else { /** * Compute levels for each node diff --git a/cpp/src/cluster/detail/connectivities.cuh b/cpp/src/cluster/detail/connectivities.cuh index 1f0adad334..1737eead12 100644 --- a/cpp/src/cluster/detail/connectivities.cuh +++ b/cpp/src/cluster/detail/connectivities.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -9,11 +9,13 @@ #include "./kmeans_common.cuh" #include #include +#include +#include +#include #include #include #include #include -#include #include #include #include @@ -85,9 +87,9 @@ struct distance_graph_impl { bool self_loop = row == col; return (self_loop * std::numeric_limits::max()) + (!self_loop * val); }, - rows_view, - cols_view, - vals_in_view); + raft::make_const_mdspan(rows_view), + raft::make_const_mdspan(cols_view), + raft::make_const_mdspan(vals_in_view)); raft::sparse::convert::sorted_coo_to_csr( knn_graph_coo.rows(), knn_graph_coo.nnz, indptr.data(), m + 1, stream); @@ -144,7 +146,7 @@ void pairwise_distances(const raft::resources& handle, raft::make_device_vector_view(indptr, m), [=] __device__(value_idx idx) { return idx * m; }); - raft::update_device(indptr + m, &nnz, 1, stream); + raft::copy(handle, raft::make_device_scalar_view(indptr + m), raft::make_host_scalar_view(&nnz)); // TODO: It would ultimately be nice if the MST could accept // dense inputs directly so we don't need to double the memory @@ -157,11 +159,14 @@ void pairwise_distances(const raft::resources& handle, // self-loops get max distance auto data_view = raft::make_device_vector_view(data, nnz); - raft::linalg::map_offset(handle, data_view, [=] __device__(value_idx idx) { - value_t val = data[idx]; - bool self_loop = idx % m == idx / m; - return (self_loop * std::numeric_limits::max()) + (!self_loop * val); - }); + raft::linalg::map_offset( + handle, + data_view, + [=] __device__(value_idx idx, value_t val) { + bool self_loop = idx % m == idx / m; + return (self_loop * std::numeric_limits::max()) + (!self_loop * val); + }, + raft::make_const_mdspan(data_view)); } /** diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 0a27d3b351..6e7bff8450 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -10,8 +10,11 @@ #include #include +#include #include +#include #include +#include #include #include #include @@ -20,12 +23,15 @@ #include #include #include +#include #include #include #include +#include #include #include #include +#include #include #include #include @@ -34,9 +40,6 @@ #include #include -#include -#include -#include #include #include @@ -133,8 +136,7 @@ void kmeansPlusPlus(raft::resources const& handle, if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream); + raft::linalg::norm(handle, X, L2NormX.view()); } raft::random::RngState rng(params.rng_state.seed, params.rng_state.type); @@ -147,8 +149,9 @@ void kmeansPlusPlus(raft::resources const& handle, int n_clusters_picked = 1; // store the chosen centroid in the buffer - raft::copy( - centroidsRawData.data_handle(), initialCentroid.data_handle(), initialCentroid.size(), stream); + raft::copy(handle, + raft::make_device_vector_view(centroidsRawData.data_handle(), initialCentroid.size()), + raft::make_device_vector_view(initialCentroid.data_handle(), initialCentroid.size())); // C = initial set of centroids auto centroids = raft::make_device_matrix_view( @@ -198,12 +201,13 @@ void kmeansPlusPlus(raft::resources const& handle, // Calculate costPerCandidate[n_trials] where costPerCandidate[i] is the cluster cost when using // centroid candidate-i - raft::linalg::reduce(costPerCandidate.data_handle(), - minDistBuf.data_handle(), - minDistBuf.extent(1), - minDistBuf.extent(0), - static_cast(0), - stream); + raft::linalg::reduce( + handle, + raft::make_device_matrix_view( + minDistBuf.data_handle(), minDistBuf.extent(0), minDistBuf.extent(1)), + raft::make_device_vector_view(costPerCandidate.data_handle(), + minDistBuf.extent(0)), + static_cast(0)); // Greedy Choice - Choose the candidate that has minimum cluster cost // ArgMin operation below identifies the index of minimum cost in costPerCandidate @@ -229,21 +233,24 @@ void kmeansPlusPlus(raft::resources const& handle, stream); int bestCandidateIdx = -1; - raft::copy(&bestCandidateIdx, &minClusterIndexAndDistance.data()->key, 1, stream); + raft::copy(handle, + raft::make_host_scalar_view(&bestCandidateIdx), + raft::make_device_scalar_view(&minClusterIndexAndDistance.data()->key)); raft::resource::sync_stream(handle); /// <<< End of Step-3 >>> /// <<< Step-4 >>>: C = C U {x} // Update minimum cluster distance corresponding to the chosen centroid candidate - raft::copy(minClusterDistance.data_handle(), - minDistBuf.data_handle() + bestCandidateIdx * n_samples, - n_samples, - stream); + raft::copy(handle, + raft::make_device_vector_view(minClusterDistance.data_handle(), n_samples), + raft::make_device_vector_view( + minDistBuf.data_handle() + bestCandidateIdx * n_samples, n_samples)); - raft::copy(centroidsRawData.data_handle() + n_clusters_picked * n_features, - centroidCandidates.data_handle() + bestCandidateIdx * n_features, - n_features, - stream); + raft::copy(handle, + raft::make_device_vector_view( + centroidsRawData.data_handle() + n_clusters_picked * n_features, n_features), + raft::make_device_vector_view( + centroidCandidates.data_handle() + bestCandidateIdx * n_features, n_features)); ++n_clusters_picked; /// <<< End of Step-4 >>> @@ -383,8 +390,7 @@ void kmeans_fit_main(raft::resources const& handle, if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream); + raft::linalg::norm(handle, X, L2NormX.view()); } RAFT_LOG_DEBUG( @@ -419,23 +425,17 @@ void kmeans_fit_main(raft::resources const& handle, params.batch_centroids, workspace); - // Using TransformInputIteratorT to dereference an array of - // raft::KeyValuePair and converting them to just return the Key to be used - // in reduce_rows_by_key prims - cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; - thrust::transform_iterator, - raft::KeyValuePair*> - itr(minClusterAndDistance.data_handle(), conversion_op); - - update_centroids(handle, - X, - weight, - raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features), - itr, - wtInCluster.view(), - newCentroids.view(), - workspace); + update_centroids( + handle, + X, + weight, + raft::make_device_matrix_view( + centroidsRawData.data_handle(), n_clusters, n_features), + cuda::transform_iterator(minClusterAndDistance.data_handle(), + cuvs::cluster::kmeans::detail::KeyValueIndexOp{}), + wtInCluster.view(), + newCentroids.view(), + workspace); // compute the squared norm between the newCentroids and the original // centroids, destructor releases the resource @@ -448,10 +448,11 @@ void kmeans_fit_main(raft::resources const& handle, newCentroids.data_handle()); DataT sqrdNormError = 0; - raft::copy(&sqrdNormError, sqrdNorm.data_handle(), sqrdNorm.size(), stream); + raft::copy(handle, raft::make_host_scalar_view(&sqrdNormError), sqrdNorm.view()); - raft::copy( - centroidsRawData.data_handle(), newCentroids.data_handle(), newCentroids.size(), stream); + raft::copy(handle, + raft::make_device_vector_view(centroidsRawData.data_handle(), newCentroids.size()), + raft::make_device_vector_view(newCentroids.data_handle(), newCentroids.size())); bool done = false; if (params.inertia_check) { @@ -501,18 +502,17 @@ void kmeans_fit_main(raft::resources const& handle, params.batch_centroids, workspace); - // TODO: add different templates for InType of binaryOp to avoid thrust transform - thrust::transform(raft::resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - weight.data_handle(), - minClusterAndDistance.data_handle(), - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }); + raft::linalg::map( + handle, + minClusterAndDistance.view(), + [=] __device__(const raft::KeyValuePair kvp, DataT wt) { + raft::KeyValuePair res; + res.value = kvp.value * wt; + res.key = kvp.key; + return res; + }, + raft::make_const_mdspan(minClusterAndDistance.view()), + raft::make_const_mdspan(weight)); // calculate cluster cost phi_x(C) cuvs::cluster::kmeans::detail::computeClusterCost( @@ -586,13 +586,16 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, // device buffer to flag the sample that is chosen as initial centroid auto isSampleCentroid = raft::make_device_vector(handle, n_samples); - raft::copy( - isSampleCentroid.data_handle(), h_isSampleCentroid.data(), isSampleCentroid.size(), stream); + raft::copy(handle, + raft::make_device_vector_view(isSampleCentroid.data_handle(), isSampleCentroid.size()), + raft::make_host_vector_view(h_isSampleCentroid.data(), isSampleCentroid.size())); rmm::device_uvector centroidsBuf(initialCentroid.size(), stream); // reset buffer to store the chosen centroid - raft::copy(centroidsBuf.data(), initialCentroid.data_handle(), initialCentroid.size(), stream); + raft::copy(handle, + raft::make_device_vector_view(centroidsBuf.data(), initialCentroid.size()), + raft::make_device_vector_view(initialCentroid.data_handle(), initialCentroid.size())); auto potentialCentroids = raft::make_device_matrix_view( centroidsBuf.data(), initialCentroid.extent(0), initialCentroid.extent(1)); @@ -606,8 +609,7 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, auto L2NormX = raft::make_device_vector(handle, n_samples); if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream); + raft::linalg::norm(handle, X, L2NormX.view()); } auto minClusterDistanceVec = raft::make_device_vector(handle, n_samples); @@ -700,8 +702,10 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, /// <<<< Step-5 >>> : C = C U C' // append the data in Cp to the buffer holding the potentialCentroids centroidsBuf.resize(centroidsBuf.size() + Cp.size(), stream); - raft::copy( - centroidsBuf.data() + centroidsBuf.size() - Cp.size(), Cp.data_handle(), Cp.size(), stream); + raft::copy(handle, + raft::make_device_vector_view(centroidsBuf.data() + centroidsBuf.size() - Cp.size(), + Cp.size()), + raft::make_device_vector_view(Cp.data_handle(), Cp.size())); IndexT tot_centroids = potentialCentroids.extent(0) + Cp.extent(0); potentialCentroids = @@ -760,16 +764,17 @@ void initScalableKMeansPlusPlus(raft::resources const& handle, initRandom(handle, rand_params, X, centroidsRawData); // copy centroids generated during kmeans|| iteration to the buffer - raft::copy(centroidsRawData.data_handle() + n_random_clusters * n_features, - potentialCentroids.data_handle(), - potentialCentroids.size(), - stream); + raft::copy( + handle, + raft::make_device_vector_view(centroidsRawData.data_handle() + n_random_clusters * n_features, + potentialCentroids.size()), + raft::make_device_vector_view(potentialCentroids.data_handle(), potentialCentroids.size())); } else { // found the required n_clusters - raft::copy(centroidsRawData.data_handle(), - potentialCentroids.data_handle(), - potentialCentroids.size(), - stream); + raft::copy( + handle, + raft::make_device_vector_view(centroidsRawData.data_handle(), potentialCentroids.size()), + raft::make_device_vector_view(potentialCentroids.data_handle(), potentialCentroids.size())); } } @@ -850,12 +855,9 @@ void kmeans_fit(raft::resources const& handle, rmm::device_uvector workspace(0, stream); auto weight = raft::make_device_vector(handle, n_samples); if (sample_weight.has_value()) - raft::copy(weight.data_handle(), sample_weight.value().data_handle(), n_samples, stream); + raft::copy(handle, weight.view(), sample_weight.value()); else - thrust::fill(raft::resource::get_thrust_policy(handle), - weight.data_handle(), - weight.data_handle() + weight.size(), - 1); + raft::matrix::fill(handle, weight.view(), DataT(1)); // check if weights sum up to n_samples checkWeight(handle, weight.view(), workspace); @@ -910,7 +912,9 @@ void kmeans_fit(raft::resources const& handle, seed_iter + 1, n_init); raft::copy( - centroidsRawData.data_handle(), centroids.data_handle(), n_clusters * n_features, stream); + handle, + raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features), + raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features)); } else { THROW("unknown initialization method to select initial centers"); } @@ -928,7 +932,9 @@ void kmeans_fit(raft::resources const& handle, inertia[0] = iter_inertia; n_iter[0] = n_current_iter; raft::copy( - centroids.data_handle(), centroidsRawData.data_handle(), n_clusters * n_features, stream); + handle, + raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features), + raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features)); } RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter[0] - %d", seed_iter + 1, @@ -998,12 +1004,9 @@ void kmeans_predict(raft::resources const& handle, rmm::device_uvector workspace(0, stream); auto weight = raft::make_device_vector(handle, n_samples); if (sample_weight.has_value()) - raft::copy(weight.data_handle(), sample_weight.value().data_handle(), n_samples, stream); + raft::copy(handle, weight.view(), sample_weight.value()); else - thrust::fill(raft::resource::get_thrust_policy(handle), - weight.data_handle(), - weight.data_handle() + weight.size(), - 1); + raft::matrix::fill(handle, weight.view(), DataT(1)); // check if weights sum up to n_samples if (normalize_weight) checkWeight(handle, weight.view(), workspace); @@ -1016,8 +1019,7 @@ void kmeans_predict(raft::resources const& handle, auto L2NormX = raft::make_device_vector(handle, n_samples); if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream); + raft::linalg::norm(handle, X, L2NormX.view()); } // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] @@ -1041,18 +1043,17 @@ void kmeans_predict(raft::resources const& handle, // calculate cluster cost phi_x(C) rmm::device_scalar clusterCostD(stream); - // TODO: add different templates for InType of binaryOp to avoid thrust transform - thrust::transform(raft::resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - weight.data_handle(), - minClusterAndDistance.data_handle(), - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }); + raft::linalg::map( + handle, + minClusterAndDistance.view(), + [=] __device__(const raft::KeyValuePair kvp, DataT wt) { + raft::KeyValuePair res; + res.value = kvp.value * wt; + res.key = kvp.key; + return res; + }, + raft::make_const_mdspan(minClusterAndDistance.view()), + raft::make_const_mdspan(weight.view())); cuvs::cluster::kmeans::detail::computeClusterCost( handle, @@ -1062,11 +1063,8 @@ void kmeans_predict(raft::resources const& handle, raft::value_op{}, raft::add_op{}); - thrust::transform(raft::resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - labels.data_handle(), - raft::key_op{}); + raft::linalg::map( + handle, labels, raft::key_op{}, raft::make_const_mdspan(minClusterAndDistance.view())); inertia[0] = clusterCostD.value(stream); } diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 51ede618c4..7582ec900e 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -25,9 +25,9 @@ #include #include #include -#include #include #include +#include #include #include #include @@ -95,15 +95,14 @@ inline std::enable_if_t> predict_core( auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( handle, mr, raft::make_extents(n_rows)); raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - thrust::fill(raft::resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - initial_value); + raft::matrix::fill(handle, minClusterAndDistance.view(), initial_value); auto centroidsNorm = raft::make_device_mdarray(handle, mr, raft::make_extents(n_clusters)); - raft::linalg::rowNorm( - centroidsNorm.data_handle(), centers, dim, n_clusters, stream); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(centers, n_clusters, dim), + centroidsNorm.view()); cuvs::distance::fusedDistanceNNMinReduce, IdxT>( minClusterAndDistance.data_handle(), @@ -124,10 +123,9 @@ inline std::enable_if_t> predict_core( // todo(lsugy): use KVP + iterator in caller. // Copy keys to output labels - thrust::transform(raft::resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + n_rows, - labels, + raft::linalg::map(handle, + raft::make_const_mdspan(minClusterAndDistance.view()), + raft::make_device_vector_view(labels, n_rows), raft::compose_op, raft::key_op>()); break; } @@ -138,15 +136,15 @@ inline std::enable_if_t> predict_core( auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( handle, mr, raft::make_extents(n_rows)); raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - thrust::fill(raft::resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - initial_value); + raft::matrix::fill(handle, minClusterAndDistance.view(), initial_value); auto centroidsNorm = raft::make_device_mdarray(handle, mr, raft::make_extents(n_clusters)); - raft::linalg::rowNorm( - centroidsNorm.data_handle(), centers, dim, n_clusters, stream, raft::sqrt_op{}); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(centers, n_clusters, dim), + centroidsNorm.view(), + raft::sqrt_op{}); cuvs::distance::fusedDistanceNNMinReduce, IdxT>( minClusterAndDistance.data_handle(), @@ -165,10 +163,9 @@ inline std::enable_if_t> predict_core( 0.0f, stream); // Copy keys to output labels - thrust::transform(raft::resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + n_rows, - labels, + raft::linalg::map(handle, + raft::make_const_mdspan(minClusterAndDistance.view()), + raft::make_device_vector_view(labels, n_rows), raft::compose_op, raft::key_op>()); break; } @@ -339,7 +336,11 @@ void calc_centers_and_sizes(const raft::resources& handle, // Add previous sizes if necessary if (!reset_counters) { - raft::linalg::add(cluster_sizes, cluster_sizes, temp_sizes, n_clusters, stream); + raft::linalg::add( + handle, + raft::make_device_vector_view(cluster_sizes, n_clusters), + raft::make_device_vector_view(temp_sizes, n_clusters), + raft::make_device_vector_view(cluster_sizes, n_clusters)); } raft::linalg::matrix_vector_op(handle, @@ -372,13 +373,20 @@ void compute_norm(const raft::resources& handle, } else { mapped_dataset.resize(n_rows * dim, stream); - raft::linalg::unaryOp(mapped_dataset.data(), dataset, n_rows * dim, mapping_op, stream); + raft::linalg::map( + handle, + raft::make_device_vector_view(dataset, n_rows * dim), + raft::make_device_vector_view(mapped_dataset.data(), n_rows * dim), + mapping_op); dataset_ptr = static_cast(mapped_dataset.data()); } - raft::linalg::rowNorm( - dataset_norm, dataset_ptr, dim, n_rows, stream, norm_fin_op); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(dataset_ptr, n_rows, dim), + raft::make_device_vector_view(dataset_norm, n_rows), + norm_fin_op); } /** @@ -437,8 +445,11 @@ void predict(const raft::resources& handle, if constexpr (std::is_same_v) { cur_dataset_ptr = const_cast(dataset + offset * dim); } else { - raft::linalg::unaryOp( - cur_dataset_ptr, dataset + offset * dim, minibatch_size * dim, mapping_op, stream); + raft::linalg::map( + handle, + raft::make_device_vector_view(dataset + offset * dim, minibatch_size * dim), + raft::make_device_vector_view(cur_dataset_ptr, minibatch_size * dim), + mapping_op); } // Compute the norm now if it hasn't been pre-computed. @@ -974,10 +985,11 @@ auto build_fine_clusters(const raft::resources& handle, device_memory, mc_trainset_norm); - raft::copy(cluster_centers + (dim * fine_clusters_csum[i]), - mc_trainset_ccenters.data(), - fine_clusters_nums[i] * dim, - stream); + raft::copy(handle, + raft::make_device_vector_view(cluster_centers + (dim * fine_clusters_csum[i]), + fine_clusters_nums[i] * dim), + raft::make_device_vector_view(mc_trainset_ccenters.data(), + fine_clusters_nums[i] * dim)); raft::resource::sync_stream(handle, stream); n_clusters_done += fine_clusters_nums[i]; } diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index c5db4a4cfa..4e2a41b26a 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -10,8 +10,10 @@ #include #include +#include #include #include +#include #include #include #include @@ -19,9 +21,9 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -36,9 +38,8 @@ #include #include #include -#include +#include #include -#include #include #include @@ -147,7 +148,9 @@ void checkWeight(raft::resources const& handle, n_samples, stream)); DataT wt_sum = 0; - raft::copy(&wt_sum, wt_aggr.data_handle(), 1, stream); + raft::copy(handle, + raft::make_host_scalar_view(&wt_sum), + raft::make_device_scalar_view(wt_aggr.data_handle())); raft::resource::sync_stream(handle, stream); if (wt_sum != n_samples) { @@ -157,11 +160,8 @@ void checkWeight(raft::resources const& handle, n_samples); auto scale = static_cast(n_samples) / wt_sum; - raft::linalg::unaryOp(weight.data_handle(), - weight.data_handle(), - n_samples, - raft::mul_const_op{scale}, - stream); + raft::linalg::map( + handle, weight, raft::mul_const_op{scale}, raft::make_const_mdspan(weight)); } } @@ -193,7 +193,7 @@ void computeClusterCost(raft::resources const& handle, { cudaStream_t stream = raft::resource::get_cuda_stream(handle); - thrust::transform_iterator itr(minClusterDistance.data_handle(), main_op); + cuda::transform_iterator itr(minClusterDistance.data_handle(), main_op); size_t temp_storage_bytes = 0; RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(nullptr, @@ -256,7 +256,9 @@ void sampleCentroids(raft::resources const& handle, stream)); IndexT nPtsSampledInRank = 0; - raft::copy(&nPtsSampledInRank, nSelected.data_handle(), 1, stream); + raft::copy(handle, + raft::make_host_scalar_view(&nPtsSampledInRank), + raft::make_device_scalar_view(nSelected.data_handle())); raft::resource::sync_stream(handle, stream); uint8_t* rawPtr_isSampleCentroid = isSampleCentroid.data_handle(); @@ -456,13 +458,8 @@ void countSamplesInCluster(raft::resources const& handle, params.batch_centroids, workspace); - // Using TransformInputIteratorT to dereference an array of raft::KeyValuePair - // and converting them to just return the Key to be used in reduce_rows_by_key - // prims - cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; - thrust::transform_iterator, - raft::KeyValuePair*> - itr(minClusterAndDistance.data_handle(), conversion_op); + cuda::transform_iterator itr(minClusterAndDistance.data_handle(), + cuvs::cluster::kmeans::detail::KeyValueIndexOp{}); // count # of samples in each cluster countLabels(handle, diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index 9be5a6d674..4c8d7f8b2a 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,13 +8,16 @@ #include "../kmeans.cuh" #include +#include #include +#include #include #include +#include #include #include #include -#include +#include #include #include #include @@ -28,14 +31,10 @@ #include #include -#include -#include -#include -#include -#include -#include +#include #include +#include #include namespace cuvs::cluster::kmeans::mg::detail { @@ -51,15 +50,6 @@ namespace cuvs::cluster::kmeans::mg::detail { if (isRoot) { RAFT_LOG_DEBUG(fmt, ##__VA_ARGS__); } \ } while (0) -template -struct KeyValueIndexOp { - __host__ __device__ __forceinline__ IndexT - operator()(const raft::KeyValuePair& a) const - { - return a.key; - } -}; - #define KMEANS_COMM_ROOT 0 static cuvs::cluster::kmeans::params default_params; @@ -110,10 +100,10 @@ void initRandom(const raft::resources& handle, handle, X, centroidsSampledInRank.view(), nCentroidsSampledInRank, params.rng_state.seed); std::vector displs(n_ranks); - thrust::exclusive_scan(thrust::host, - nCentroidsElementsToReceiveFromRank.begin(), - nCentroidsElementsToReceiveFromRank.end(), - displs.begin()); + std::exclusive_scan(nCentroidsElementsToReceiveFromRank.begin(), + nCentroidsElementsToReceiveFromRank.end(), + displs.begin(), + size_t(0)); // gather centroids from all ranks comm.allgatherv(centroidsSampledInRank.data_handle(), // sendbuff @@ -173,9 +163,11 @@ void initKMeansPlusPlus(const raft::resources& handle, } { rmm::device_scalar rp_d(stream); - raft::copy(rp_d.data(), &rp, 1, stream); + raft::copy( + handle, raft::make_device_scalar_view(rp_d.data()), raft::make_host_scalar_view(&rp)); comm.bcast(rp_d.data(), 1, /*root=*/KMEANS_COMM_ROOT, stream); - raft::copy(&rp, rp_d.data(), 1, stream); + raft::copy( + handle, raft::make_host_scalar_view(&rp), raft::make_device_scalar_view(rp_d.data())); raft::resource::sync_stream(handle); } @@ -197,8 +189,9 @@ void initKMeansPlusPlus(const raft::resources& handle, auto centroidsView = raft::make_device_matrix_view( X.data_handle() + cIdx * n_features, 1, n_features); - raft::copy( - initialCentroid.data_handle(), centroidsView.data_handle(), centroidsView.size(), stream); + raft::copy(handle, + raft::make_device_vector_view(initialCentroid.data_handle(), centroidsView.size()), + raft::make_device_vector_view(centroidsView.data_handle(), centroidsView.size())); h_isSampleCentroid[cIdx] = 1; } @@ -209,14 +202,17 @@ void initKMeansPlusPlus(const raft::resources& handle, // device buffer to flag the sample that is chosen as initial centroid auto isSampleCentroid = raft::make_device_vector(handle, n_samples); - raft::copy( - isSampleCentroid.data_handle(), h_isSampleCentroid.data(), isSampleCentroid.size(), stream); + raft::copy(handle, + raft::make_device_vector_view(isSampleCentroid.data_handle(), isSampleCentroid.size()), + raft::make_host_vector_view(h_isSampleCentroid.data(), isSampleCentroid.size())); rmm::device_uvector centroidsBuf(0, stream); // reset buffer to store the chosen centroid centroidsBuf.resize(initialCentroid.size(), stream); - raft::copy(centroidsBuf.begin(), initialCentroid.data_handle(), initialCentroid.size(), stream); + raft::copy(handle, + raft::make_device_vector_view(centroidsBuf.begin(), initialCentroid.size()), + raft::make_device_vector_view(initialCentroid.data_handle(), initialCentroid.size())); auto potentialCentroids = raft::make_device_matrix_view( centroidsBuf.data(), initialCentroid.extent(0), initialCentroid.extent(1)); @@ -228,8 +224,11 @@ void initKMeansPlusPlus(const raft::resources& handle, auto L2NormX = raft::make_device_vector(handle, n_samples); if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream); + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + X.data_handle(), n_samples, n_features), + L2NormX.view()); } auto minClusterDistance = raft::make_device_vector(handle, n_samples); @@ -264,7 +263,7 @@ void initKMeansPlusPlus(const raft::resources& handle, clusterCost.data_handle(), clusterCost.data_handle(), 1, raft::comms::op_t::SUM, stream); DataT psi = 0; - raft::copy(&psi, clusterCost.data_handle(), 1, stream); + raft::copy(handle, raft::make_host_scalar_view(&psi), clusterCost.view()); // <<< End of Step-2 >>> @@ -310,7 +309,7 @@ void initKMeansPlusPlus(const raft::resources& handle, [] __device__(const DataT& a, const DataT& b) { return a + b; })); comm.allreduce( clusterCost.data_handle(), clusterCost.data_handle(), 1, raft::comms::op_t::SUM, stream); - raft::copy(&psi, clusterCost.data_handle(), 1, stream); + raft::copy(handle, raft::make_host_scalar_view(&psi), clusterCost.view()); ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, "An error occurred in the distributed operation. This can result " "from a failed rank"); @@ -349,20 +348,19 @@ void initKMeansPlusPlus(const raft::resources& handle, "An error occurred in the distributed operation. This can result " "from a failed rank"); - auto nPtsSampled = - thrust::reduce(thrust::host, nPtsSampledByRank, nPtsSampledByRank + n_rank, 0); + auto nPtsSampled = std::reduce(nPtsSampledByRank, nPtsSampledByRank + n_rank, 0); // gather centroids from all ranks std::vector sizes(n_rank); - thrust::transform( - thrust::host, nPtsSampledByRank, nPtsSampledByRank + n_rank, sizes.begin(), [&](int val) { - return val * n_features; + std::transform( + nPtsSampledByRank, nPtsSampledByRank + n_rank, sizes.begin(), [n_features](int val) { + return static_cast(val) * n_features; }); RAFT_CUDA_TRY_NO_THROW(cudaFreeHost(nPtsSampledByRank)); std::vector displs(n_rank); - thrust::exclusive_scan(thrust::host, sizes.begin(), sizes.end(), displs.begin()); + std::exclusive_scan(sizes.begin(), sizes.end(), displs.begin(), size_t(0)); centroidsBuf.resize(centroidsBuf.size() + nPtsSampled * n_features, stream); comm.allgatherv(inRankCp.data(), @@ -444,17 +442,18 @@ void initKMeansPlusPlus(const raft::resources& handle, initRandom(handle, rand_params, X, centroidsRawData); // copy centroids generated during kmeans|| iteration to the buffer - raft::copy(centroidsRawData.data_handle() + n_random_clusters * n_features, - potentialCentroids.data_handle(), - potentialCentroids.size(), - stream); + raft::copy( + handle, + raft::make_device_vector_view(centroidsRawData.data_handle() + n_random_clusters * n_features, + potentialCentroids.size()), + raft::make_device_vector_view(potentialCentroids.data_handle(), potentialCentroids.size())); } else { // found the required n_clusters - raft::copy(centroidsRawData.data_handle(), - potentialCentroids.data_handle(), - potentialCentroids.size(), - stream); + raft::copy( + handle, + raft::make_device_vector_view(centroidsRawData.data_handle(), potentialCentroids.size()), + raft::make_device_vector_view(potentialCentroids.data_handle(), potentialCentroids.size())); } } @@ -493,12 +492,8 @@ void checkWeights(const raft::resources& handle, n_samples); DataT scale = n_samples / wt_sum; - raft::linalg::unaryOp( - weight.data_handle(), - weight.data_handle(), - weight.size(), - cuda::proclaim_return_type([=] __device__(const DataT& wt) { return wt * scale; }), - stream); + raft::linalg::map( + handle, weight, raft::mul_const_op(scale), raft::make_const_mdspan(weight)); } } @@ -521,12 +516,9 @@ void fit(const raft::resources& handle, auto weight = raft::make_device_vector(handle, n_samples); if (sample_weight) { - raft::copy(weight.data_handle(), sample_weight->data_handle(), n_samples, stream); + raft::copy(handle, weight.view(), sample_weight.value()); } else { - thrust::fill(raft::resource::get_thrust_policy(handle), - weight.data_handle(), - weight.data_handle() + weight.size(), - 1); + raft::matrix::fill(handle, weight.view(), DataT(1)); } // check if weights sum up to n_samples @@ -573,8 +565,11 @@ void fit(const raft::resources& handle, auto L2NormX = raft::make_device_vector(handle, n_samples); if (metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm( - L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream); + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + X.data_handle(), n_samples, n_features), + L2NormX.view()); } DataT priorClusteringCost = 0; @@ -602,20 +597,14 @@ void fit(const raft::resources& handle, params.batch_centroids, workspace); - // Using TransformInputIteratorT to dereference an array of - // cub::KeyValuePair and converting them to just return the Key to be used - // in reduce_rows_by_key prims - KeyValueIndexOp conversion_op; - thrust::transform_iterator, raft::KeyValuePair*> - itr(minClusterAndDistance.data_handle(), conversion_op); - workspace.resize(n_samples, stream); - // Calculates weighted sum of all the samples assigned to cluster-i and - // store the result in newCentroids[i] + cuda::transform_iterator keys_itr( + minClusterAndDistance.data_handle(), + cuvs::cluster::kmeans::detail::KeyValueIndexOp{}); raft::linalg::reduce_rows_by_key((DataT*)X.data_handle(), X.extent(1), - itr, + keys_itr, weight.data_handle(), workspace.data(), X.extent(0), @@ -626,7 +615,7 @@ void fit(const raft::resources& handle, // Reduce weights by key to compute weight in each cluster raft::linalg::reduce_cols_by_key(weight.data_handle(), - itr, + keys_itr, wtInCluster.data_handle(), (IndexT)1, (IndexT)weight.extent(0), @@ -705,9 +694,11 @@ void fit(const raft::resources& handle, newCentroids.data_handle()); DataT sqrdNormError = 0; - raft::copy(&sqrdNormError, sqrdNorm.data_handle(), sqrdNorm.size(), stream); + raft::copy(handle, raft::make_host_scalar_view(&sqrdNormError), sqrdNorm.view()); - raft::copy(centroids.data_handle(), newCentroids.data_handle(), newCentroids.size(), stream); + raft::copy(handle, + raft::make_device_vector_view(centroids.data_handle(), newCentroids.size()), + raft::make_device_vector_view(newCentroids.data_handle(), newCentroids.size())); bool done = false; if (params.inertia_check) { @@ -736,7 +727,9 @@ void fit(const raft::resources& handle, stream); DataT curClusteringCost = 0; - raft::copy(&curClusteringCost, &(clusterCostD.data()->value), 1, stream); + raft::copy(handle, + raft::make_host_scalar_view(&curClusteringCost), + raft::make_device_scalar_view(&(clusterCostD.data()->value))); ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, "An error occurred in the distributed operation. This can result " diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 690c6b2ed0..8370ff922f 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -1,10 +1,12 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #include "kmeans_common.cuh" +#include + namespace cuvs::cluster::kmeans::detail { // Calculates a pair for every sample in input 'X' where key is an @@ -35,11 +37,10 @@ void minClusterAndDistanceCompute( if (is_fused) { L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::rowNorm(L2NormBuf_OR_DistBuf.data(), - centroids.data_handle(), - centroids.extent(1), - centroids.extent(0), - stream); + raft::linalg::norm( + handle, + centroids, + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); } else { // TODO: Unless pool allocator is used, passing in a workspace for this // isn't really increasing performance because this needs to do a re-allocation @@ -57,10 +58,7 @@ void minClusterAndDistanceCompute( raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - thrust::fill(raft::resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - initial_value); + raft::matrix::fill(handle, minClusterAndDistance, initial_value); // tile over the input dataset for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { @@ -189,11 +187,11 @@ void minClusterDistanceCompute(raft::resources const& handle, if (is_fused) { L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::rowNorm(L2NormBuf_OR_DistBuf.data(), - centroids.data_handle(), - centroids.extent(1), - centroids.extent(0), - stream); + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); } else { L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); } @@ -206,10 +204,7 @@ void minClusterDistanceCompute(raft::resources const& handle, auto pairwiseDistance = raft::make_device_matrix_view( L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - thrust::fill(raft::resource::get_thrust_policy(handle), - minClusterDistance.data_handle(), - minClusterDistance.data_handle() + minClusterDistance.size(), - std::numeric_limits::max()); + raft::matrix::fill(handle, minClusterDistance, std::numeric_limits::max()); // tile over the input data and calculate distance matrix [n_samples x // n_clusters] diff --git a/cpp/src/cluster/detail/mst.cuh b/cpp/src/cluster/detail/mst.cuh index 07a786caee..bbd74022ba 100644 --- a/cpp/src/cluster/detail/mst.cuh +++ b/cpp/src/cluster/detail/mst.cuh @@ -24,7 +24,6 @@ #include #include -#include #include namespace cuvs::cluster::agglomerative::detail { @@ -154,7 +153,9 @@ void connect_knn_graph( raft::label::make_monotonic(d_color_remapped.data(), color, m, stream, true); std::vector h_color(m); - raft::copy(h_color.data(), d_color_remapped.data(), m, stream); + raft::copy(handle, + raft::make_host_vector_view(h_color.data(), m), + raft::make_device_vector_view(d_color_remapped.data(), m)); raft::resource::sync_stream(handle, stream); // make key (color) : value (vector of ids that have that color) @@ -196,8 +197,14 @@ void connect_knn_graph( auto device_u_indices = raft::make_device_vector(handle, new_nnz); auto device_v_indices = raft::make_device_vector(handle, new_nnz); - raft::copy(device_u_indices.data_handle(), host_u_indices.data(), new_nnz, stream); - raft::copy(device_v_indices.data_handle(), host_v_indices.data(), new_nnz, stream); + raft::copy( + handle, + device_u_indices.view(), + raft::make_host_vector_view(host_u_indices.data(), value_idx(new_nnz))); + raft::copy( + handle, + device_v_indices.view(), + raft::make_host_vector_view(host_v_indices.data(), value_idx(new_nnz))); auto data_u = raft::make_device_matrix(handle, new_nnz, n); auto data_v = raft::make_device_matrix(handle, new_nnz, n); diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index ebd88a5c7b..7d37b9cf80 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -7,11 +7,15 @@ #include "detail/kmeans.cuh" #include "kmeans_mg.hpp" #include +#include +#include +#include #include #include #include #include #include +#include #include @@ -416,8 +420,7 @@ void cluster_cost(raft::resources const& handle, auto x_norms = raft::make_device_vector(handle, n_samples); - raft::linalg::rowNorm( - x_norms.data_handle(), X.data_handle(), n_features, n_samples, stream); + raft::linalg::norm(handle, X, x_norms.view()); auto min_cluster_distance = raft::make_device_vector(handle, n_samples); rmm::device_uvector l2_norm_or_distance_buffer(0, stream); @@ -437,14 +440,11 @@ void cluster_cost(raft::resources const& handle, n_clusters, workspace); - rmm::device_scalar device_cost(0, stream); + auto device_cost = raft::make_device_scalar(handle, DataT(0)); - cuvs::cluster::kmeans::cluster_cost(handle, - min_cluster_distance.view(), - workspace, - raft::make_device_scalar_view(device_cost.data()), - raft::add_op{}); - raft::update_host(cost.data_handle(), device_cost.data(), 1, stream); + cuvs::cluster::kmeans::cluster_cost( + handle, min_cluster_distance.view(), workspace, device_cost.view(), raft::add_op{}); + raft::copy(handle, cost, raft::make_const_mdspan(device_cost.view())); raft::resource::sync_stream(handle); } diff --git a/cpp/src/distance/detail/distance.cuh b/cpp/src/distance/detail/distance.cuh index d212af0808..44b6ae5a63 100644 --- a/cpp/src/distance/detail/distance.cuh +++ b/cpp/src/distance/detail/distance.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2018-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -13,9 +13,9 @@ #include #include #include +#include #include #include -#include #include // to_float #include @@ -130,33 +130,78 @@ void distance_impl(raft::resources const& handle, // perhaps the use of stridedSummationKernel could be causing this, // need to investigate and fix. if (x == y && is_row_major) { - raft::linalg::reduce( - x_norm, x, k, std::max(m, n), (AccT)0, stream, false, raft::identity_op(), raft::add_op()); + raft::linalg::reduce( + handle, + raft::make_device_matrix_view(x, std::max(m, n), k), + raft::make_device_vector_view(x_norm, std::max(m, n)), + (AccT)0, + false, + raft::identity_op(), + raft::add_op()); sq_x_norm += std::max(m, n); sq_y_norm = sq_x_norm; - raft::linalg::rowNorm(sq_x_norm, x, k, std::max(m, n), stream); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(x, std::max(m, n), k), + raft::make_device_vector_view(sq_x_norm, std::max(m, n))); } else { y_norm += m; if (is_row_major) { - raft::linalg::reduce( - x_norm, x, k, m, (AccT)0, stream, false, raft::identity_op(), raft::add_op()); - raft::linalg::reduce( - y_norm, y, k, n, (AccT)0, stream, false, raft::identity_op(), raft::add_op()); + raft::linalg::reduce( + handle, + raft::make_device_matrix_view(x, m, k), + raft::make_device_vector_view(x_norm, m), + (AccT)0, + false, + raft::identity_op(), + raft::add_op()); + raft::linalg::reduce( + handle, + raft::make_device_matrix_view(y, n, k), + raft::make_device_vector_view(y_norm, n), + (AccT)0, + false, + raft::identity_op(), + raft::add_op()); } else { - raft::linalg::reduce( - x_norm, x, k, m, (AccT)0, stream, false, raft::identity_op(), raft::add_op()); - raft::linalg::reduce( - y_norm, y, k, n, (AccT)0, stream, false, raft::identity_op(), raft::add_op()); + raft::linalg::reduce( + handle, + raft::make_device_matrix_view(x, m, k), + raft::make_device_vector_view(x_norm, m), + (AccT)0, + false, + raft::identity_op(), + raft::add_op()); + raft::linalg::reduce( + handle, + raft::make_device_matrix_view(y, n, k), + raft::make_device_vector_view(y_norm, n), + (AccT)0, + false, + raft::identity_op(), + raft::add_op()); } sq_x_norm += (m + n); sq_y_norm = sq_x_norm + m; if (is_row_major) { - raft::linalg::rowNorm(sq_x_norm, x, k, m, stream); - raft::linalg::rowNorm(sq_y_norm, y, k, n, stream); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(x, m, k), + raft::make_device_vector_view(sq_x_norm, m)); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(y, n, k), + raft::make_device_vector_view(sq_y_norm, n)); } else { - raft::linalg::rowNorm(sq_x_norm, x, k, m, stream); - raft::linalg::rowNorm(sq_y_norm, y, k, n, stream); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(x, m, k), + raft::make_device_vector_view(sq_x_norm, m)); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(y, n, k), + raft::make_device_vector_view(sq_y_norm, n)); } } @@ -197,16 +242,35 @@ void distance_impl(raft::resources const& handle, // perhaps the use of stridedSummationKernel could be causing this, // need to investigate and fix. if (x == y && is_row_major) { - raft::linalg::rowNorm( - x_norm, x, k, std::max(m, n), stream, raft::sqrt_op{}); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(x, std::max(m, n), k), + raft::make_device_vector_view(x_norm, std::max(m, n)), + raft::sqrt_op{}); } else { y_norm += m; if (is_row_major) { - raft::linalg::rowNorm(x_norm, x, k, m, stream, raft::sqrt_op{}); - raft::linalg::rowNorm(y_norm, y, k, n, stream, raft::sqrt_op{}); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(x, m, k), + raft::make_device_vector_view(x_norm, m), + raft::sqrt_op{}); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(y, n, k), + raft::make_device_vector_view(y_norm, n), + raft::sqrt_op{}); } else { - raft::linalg::rowNorm(x_norm, x, k, m, stream, raft::sqrt_op{}); - raft::linalg::rowNorm(y_norm, y, k, n, stream, raft::sqrt_op{}); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(x, m, k), + raft::make_device_vector_view(x_norm, m), + raft::sqrt_op{}); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(y, n, k), + raft::make_device_vector_view(y_norm, n), + raft::sqrt_op{}); } } @@ -285,29 +349,37 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT) // metric_arg unused { - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - // Check if arrays overlap const DataT* x_end = x + m * k; const DataT* y_end = y + n * k; bool arrays_overlap = (x < y_end) && (y < x_end); - const auto raft_sqrt = raft::linalg::unaryOp; - const auto raft_sq = raft::linalg::unaryOp; - if (!arrays_overlap) { // Arrays don't overlap: sqrt each array independently - raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); - raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); + raft::linalg::map( + handle, + raft::make_device_vector_view((DataT*)x, m * k), + raft::sqrt_op{}, + raft::make_const_mdspan(raft::make_device_vector_view(x, m * k))); + raft::linalg::map( + handle, + raft::make_device_vector_view((DataT*)y, n * k), + raft::sqrt_op{}, + raft::make_const_mdspan(raft::make_device_vector_view(y, n * k))); } else { // Arrays overlap: sqrt the union of both arrays exactly once const DataT* start = (x < y) ? x : y; const DataT* end = (x_end > y_end) ? x_end : y_end; IdxT union_size = end - start; - raft_sqrt((DataT*)start, start, union_size, raft::sqrt_op{}, stream); + raft::linalg::map( + handle, + raft::make_device_vector_view((DataT*)start, union_size), + raft::sqrt_op{}, + raft::make_const_mdspan(raft::make_device_vector_view(start, union_size))); } + cudaStream_t stream = raft::resource::get_cuda_stream(handle); // Calculate Hellinger distance ops::hellinger_distance_op distance_op{}; @@ -320,15 +392,27 @@ void distance_impl(raft::resources const& handle, // Restore arrays by squaring back if (!arrays_overlap) { // Arrays don't overlap: square each array independently - raft_sq((DataT*)x, x, m * k, raft::sq_op{}, stream); - raft_sq((DataT*)y, y, n * k, raft::sq_op{}, stream); + raft::linalg::map( + handle, + raft::make_device_vector_view((DataT*)x, m * k), + raft::sq_op{}, + raft::make_const_mdspan(raft::make_device_vector_view(x, m * k))); + raft::linalg::map( + handle, + raft::make_device_vector_view((DataT*)y, n * k), + raft::sq_op{}, + raft::make_const_mdspan(raft::make_device_vector_view(y, n * k))); } else { // Arrays overlap: square the union back const DataT* start = (x < y) ? x : y; const DataT* end = (x_end > y_end) ? x_end : y_end; IdxT union_size = end - start; - raft_sq((DataT*)start, start, union_size, raft::sq_op{}, stream); + raft::linalg::map( + handle, + raft::make_device_vector_view((DataT*)start, union_size), + raft::sq_op{}, + raft::make_const_mdspan(raft::make_device_vector_view(start, union_size))); } RAFT_CUDA_TRY(cudaGetLastError()); @@ -399,8 +483,11 @@ void distance_impl(raft::resources const& handle, }; if (x != y) { - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda, stream); + raft::linalg::map( + handle, + raft::make_device_vector_view((DataT*)y, n * k), + unaryOp_lambda, + raft::make_const_mdspan(raft::make_device_vector_view(y, n * k))); } const OutT* x_norm = nullptr; @@ -415,8 +502,11 @@ void distance_impl(raft::resources const& handle, if (x != y) { // Now reverse previous log (x) back to x using (e ^ log(x)) - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda_reverse, stream); + raft::linalg::map( + handle, + raft::make_device_vector_view((DataT*)y, n * k), + unaryOp_lambda_reverse, + raft::make_const_mdspan(raft::make_device_vector_view(y, n * k))); } } diff --git a/cpp/src/distance/detail/kernels/kernel_matrices.cu b/cpp/src/distance/detail/kernels/kernel_matrices.cu index 039db7cd43..9ed25f959c 100644 --- a/cpp/src/distance/detail/kernels/kernel_matrices.cu +++ b/cpp/src/distance/detail/kernels/kernel_matrices.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -466,18 +466,20 @@ void RBFKernel::matrixRowNormL2(raft::resources const& handle, int minor = is_row_major ? matrix.extent(1) : matrix.extent(0); int ld = is_row_major ? matrix.stride(0) : matrix.stride(1); ASSERT(ld == minor, "RBF Kernel lazy rowNorm compute does not support ld parameter"); + auto n_rows = matrix.extent(0); + auto n_cols = matrix.extent(1); if (is_row_major) { - raft::linalg::rowNorm(target, - matrix.data_handle(), - matrix.extent(1), - matrix.extent(0), - raft::resource::get_cuda_stream(handle)); + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + matrix.data_handle(), n_rows, n_cols), + raft::make_device_vector_view(target, n_rows)); } else { - raft::linalg::rowNorm(target, - matrix.data_handle(), - matrix.extent(1), - matrix.extent(0), - raft::resource::get_cuda_stream(handle)); + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + matrix.data_handle(), n_rows, n_cols), + raft::make_device_vector_view(target, n_rows)); } } diff --git a/cpp/src/distance/detail/masked_nn.cuh b/cpp/src/distance/detail/masked_nn.cuh index 3e6da23063..315ecf9d7e 100644 --- a/cpp/src/distance/detail/masked_nn.cuh +++ b/cpp/src/distance/detail/masked_nn.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -255,8 +256,10 @@ void masked_l2_nn_impl(raft::resources const& handle, size_t m_div_64 = raft::ceildiv(m, IdxT(64)); rmm::device_uvector ws_adj64{m_div_64 * num_groups, stream, ws_mr}; rmm::device_uvector ws_fused_nn{size_t(m), stream, ws_mr}; - RAFT_CUDA_TRY(cudaMemsetAsync(ws_adj64.data(), 0, ws_adj64.size() * sizeof(uint64_t), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(ws_fused_nn.data(), 0, ws_fused_nn.size() * sizeof(int), stream)); + raft::matrix::fill( + handle, raft::make_device_vector_view(ws_adj64.data(), ws_adj64.size()), uint64_t(0)); + raft::matrix::fill( + handle, raft::make_device_vector_view(ws_fused_nn.data(), ws_fused_nn.size()), int(0)); // Compress boolean adjacency matrix to bitfield. auto adj_view = raft::make_device_matrix_view(adj, m, num_groups); diff --git a/cpp/src/distance/detail/sparse/coo_spmv.cuh b/cpp/src/distance/detail/sparse/coo_spmv.cuh index cbfb299082..125e21f72e 100644 --- a/cpp/src/distance/detail/sparse/coo_spmv.cuh +++ b/cpp/src/distance/detail/sparse/coo_spmv.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -10,6 +10,7 @@ #include "coo_spmv_strategies/hash_strategy.cuh" #include +#include #include #include #include @@ -43,8 +44,10 @@ inline void balanced_coo_pairwise_generalized_spmv( strategy_t strategy, int chunk_size = 500000) { - uint64_t n = (uint64_t)sizeof(value_t) * (uint64_t)config_.a_nrows * (uint64_t)config_.b_nrows; - RAFT_CUDA_TRY(cudaMemsetAsync(out_dists, 0, n, raft::resource::get_cuda_stream(config_.handle))); + raft::matrix::fill( + config_.handle, + raft::make_device_vector_view(out_dists, (int64_t)config_.a_nrows * config_.b_nrows), + value_t(0)); strategy.dispatch(out_dists, coo_rows_b, product_func, accum_func, write_func, chunk_size); }; @@ -97,8 +100,10 @@ inline void balanced_coo_pairwise_generalized_spmv( write_f write_func, int chunk_size = 500000) { - uint64_t n = (uint64_t)sizeof(value_t) * (uint64_t)config_.a_nrows * (uint64_t)config_.b_nrows; - RAFT_CUDA_TRY(cudaMemsetAsync(out_dists, 0, n, raft::resource::get_cuda_stream(config_.handle))); + raft::matrix::fill( + config_.handle, + raft::make_device_vector_view(out_dists, (int64_t)config_.a_nrows * config_.b_nrows), + value_t(0)); int max_cols = max_cols_per_block(); diff --git a/cpp/src/distance/detail/sparse/l2_distance.cuh b/cpp/src/distance/detail/sparse/l2_distance.cuh index 627535472f..89d7c978a4 100644 --- a/cpp/src/distance/detail/sparse/l2_distance.cuh +++ b/cpp/src/distance/detail/sparse/l2_distance.cuh @@ -10,7 +10,7 @@ #include #include -#include +#include #include #include #include @@ -276,15 +276,15 @@ class l2_sqrt_expanded_distances_t : public l2_expanded_distances_t::compute(out_dists); // Sqrt Post-processing - raft::linalg::unaryOp( - out_dists, - out_dists, - this->config_->a_nrows * this->config_->b_nrows, + uint64_t n = (uint64_t)this->config_->a_nrows * this->config_->b_nrows; + raft::linalg::map( + this->config_->handle, + raft::make_device_vector_view(out_dists, n), [] __device__(value_t input) { int neg = input < 0 ? -1 : 1; return raft::sqrt(abs(input) * neg); }, - raft::resource::get_cuda_stream(this->config_->handle)); + raft::make_const_mdspan(raft::make_device_vector_view(out_dists, n))); } ~l2_sqrt_expanded_distances_t() = default; @@ -427,16 +427,16 @@ class hellinger_expanded_distances_t : public distances_t { raft::add_op(), raft::atomic_add_op()); - raft::linalg::unaryOp( - out_dists, - out_dists, - config_->a_nrows * config_->b_nrows, + uint64_t n = (uint64_t)config_->a_nrows * config_->b_nrows; + raft::linalg::map( + config_->handle, + raft::make_device_vector_view(out_dists, n), [=] __device__(value_t input) { // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative bool rectifier = (1 - input) > 0; return raft::sqrt(rectifier * (1 - input)); }, - raft::resource::get_cuda_stream(config_->handle)); + raft::make_const_mdspan(raft::make_device_vector_view(out_dists, n))); } ~hellinger_expanded_distances_t() = default; @@ -462,12 +462,12 @@ class russelrao_expanded_distances_t : public distances_t { value_t n_cols = config_->a_ncols; value_t n_cols_inv = 1.0 / n_cols; - raft::linalg::unaryOp( - out_dists, - out_dists, - config_->a_nrows * config_->b_nrows, + uint64_t n = (uint64_t)config_->a_nrows * config_->b_nrows; + raft::linalg::map( + config_->handle, + raft::make_device_vector_view(out_dists, n), [=] __device__(value_t input) { return (n_cols - input) * n_cols_inv; }, - raft::resource::get_cuda_stream(config_->handle)); + raft::make_const_mdspan(raft::make_device_vector_view(out_dists, n))); auto exec_policy = rmm::exec_policy(raft::resource::get_cuda_stream(config_->handle)); auto diags = cuda::make_counting_iterator(0); diff --git a/cpp/src/distance/detail/sparse/lp_distance.cuh b/cpp/src/distance/detail/sparse/lp_distance.cuh index 7e6a3cc7ae..38025329b9 100644 --- a/cpp/src/distance/detail/sparse/lp_distance.cuh +++ b/cpp/src/distance/detail/sparse/lp_distance.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -116,15 +117,14 @@ class l2_sqrt_unexpanded_distances_t : public l2_unexpanded_distances_tconfig_->a_nrows * (uint64_t)this->config_->b_nrows; // Sqrt Post-processing - raft::linalg::unaryOp( - out_dists, - out_dists, - n, + raft::linalg::map( + this->config_->handle, + raft::make_device_vector_view(out_dists, n), [] __device__(value_t input) { int neg = input < 0 ? -1 : 1; return raft::sqrt(abs(input) * neg); }, - raft::resource::get_cuda_stream(this->config_->handle)); + raft::make_const_mdspan(raft::make_device_vector_view(out_dists, n))); } }; @@ -194,11 +194,11 @@ class lp_unexpanded_distances_t : public distances_t { uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; value_t one_over_p = value_t{1} / p; - raft::linalg::unaryOp(out_dists, - out_dists, - n, - raft::pow_const_op(one_over_p), - raft::resource::get_cuda_stream(config_->handle)); + raft::linalg::map( + config_->handle, + raft::make_device_vector_view(out_dists, n), + raft::pow_const_op(one_over_p), + raft::make_const_mdspan(raft::make_device_vector_view(out_dists, n))); } private: @@ -221,11 +221,11 @@ class hamming_unexpanded_distances_t : public distances_t { uint64_t n = (uint64_t)config_->a_nrows * (uint64_t)config_->b_nrows; value_t n_cols = 1.0 / config_->a_ncols; - raft::linalg::unaryOp(out_dists, - out_dists, - n, - raft::mul_const_op(n_cols), - raft::resource::get_cuda_stream(config_->handle)); + raft::linalg::map( + config_->handle, + raft::make_device_vector_view(out_dists, n), + raft::mul_const_op(n_cols), + raft::make_const_mdspan(raft::make_device_vector_view(out_dists, n))); } private: @@ -263,12 +263,11 @@ class jensen_shannon_unexpanded_distances_t : public distances_t { raft::atomic_add_op()); uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; - raft::linalg::unaryOp( - out_dists, - out_dists, - n, + raft::linalg::map( + config_->handle, + raft::make_device_vector_view(out_dists, n), [=] __device__(value_t input) { return raft::sqrt(0.5 * input); }, - raft::resource::get_cuda_stream(config_->handle)); + raft::make_const_mdspan(raft::make_device_vector_view(out_dists, n))); } private: @@ -304,11 +303,11 @@ class kl_divergence_unexpanded_distances_t : public distances_t { raft::atomic_add_op()); uint64_t n = (uint64_t)this->config_->a_nrows * (uint64_t)this->config_->b_nrows; - raft::linalg::unaryOp(out_dists, - out_dists, - n, - raft::mul_const_op(0.5), - raft::resource::get_cuda_stream(config_->handle)); + raft::linalg::map( + config_->handle, + raft::make_device_vector_view(out_dists, n), + raft::mul_const_op(0.5), + raft::make_const_mdspan(raft::make_device_vector_view(out_dists, n))); } private: diff --git a/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh b/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh index 1ec968f945..7e4ff748a4 100644 --- a/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh +++ b/cpp/src/neighbors/all_neighbors/all_neighbors_batched.cuh @@ -15,8 +15,11 @@ #include #include #include +#include #include +#include #include +#include #include #include #include @@ -121,10 +124,11 @@ void single_gpu_assign_clusters( for (size_t i = 0; i < num_batches; i++) { size_t row_offset = n_rows_per_batch * i + base_row_offset; size_t n_rows_of_current_batch = std::min(n_rows_per_batch, num_rows - row_offset); - raft::copy(dataset_batch_d.data_handle(), - dataset.data_handle() + row_offset * num_cols, - n_rows_of_current_batch * num_cols, - stream); + raft::copy(res, + raft::make_device_vector_view(dataset_batch_d.data_handle(), + n_rows_of_current_batch * num_cols), + raft::make_host_vector_view(dataset.data_handle() + row_offset * num_cols, + n_rows_of_current_batch * num_cols)); // n_clusters is usually not large, so okay to do this brute-force cuvs::neighbors::brute_force::search(res, @@ -132,10 +136,12 @@ void single_gpu_assign_clusters( raft::make_const_mdspan(dataset_batch_d.view()), nearest_clusters_idx_d.view(), nearest_clusters_dist_d.view()); - raft::copy(global_nearest_cluster.data_handle() + row_offset * overlap_factor, - nearest_clusters_idx_d.data_handle(), - n_rows_of_current_batch * overlap_factor, - stream); + raft::copy(res, + raft::make_host_vector_view( + global_nearest_cluster.data_handle() + row_offset * overlap_factor, + n_rows_of_current_batch * overlap_factor), + raft::make_device_vector_view(nearest_clusters_idx_d.data_handle(), + n_rows_of_current_batch * overlap_factor)); } } @@ -160,10 +166,10 @@ void assign_clusters(raft::resources const& res, size_t num_cols = static_cast(dataset.extent(1)); auto centroids_h = raft::make_host_matrix(params.n_clusters, num_cols); - raft::copy(centroids_h.data_handle(), - centroids.data_handle(), - params.n_clusters * num_cols, - raft::resource::get_cuda_stream(res)); + raft::copy( + res, + raft::make_host_vector_view(centroids_h.data_handle(), params.n_clusters * num_cols), + raft::make_device_vector_view(centroids.data_handle(), params.n_clusters * num_cols)); size_t n_rows_per_cluster = (num_rows + params.n_clusters - 1) / params.n_clusters; @@ -179,10 +185,11 @@ void assign_clusters(raft::resources const& res, auto centroids_matrix = raft::make_device_matrix(dev_res, params.n_clusters, num_cols); - raft::copy(centroids_matrix.data_handle(), - centroids_h.data_handle(), - params.n_clusters * num_cols, - raft::resource::get_cuda_stream(dev_res)); + raft::copy( + dev_res, + raft::make_device_vector_view(centroids_matrix.data_handle(), params.n_clusters * num_cols), + raft::make_host_vector_view(centroids_h.data_handle(), + params.n_clusters * num_cols)); size_t base_cluster_idx = rank * clusters_per_rank + std::min((size_t)rank, rem); @@ -349,20 +356,16 @@ void multi_gpu_batch_build(const raft::resources& handle, size_t rem = params.n_clusters - clusters_per_rank * num_ranks; auto cluster_offsets = raft::make_host_vector(cluster_offsets_c.size()); - raft::copy(cluster_offsets.data_handle(), - cluster_offsets_c.data_handle(), - cluster_offsets_c.size(), - raft::resource::get_cuda_stream(handle)); + raft::copy(handle, cluster_offsets.view(), cluster_offsets_c); using ReachabilityPP = cuvs::neighbors::detail::reachability::ReachabilityPostProcess; const bool mutual_reach_dist = std::is_same_v; std::optional> core_distances_h; if constexpr (mutual_reach_dist) { core_distances_h.emplace(raft::make_host_vector(num_rows)); - raft::copy(core_distances_h.value().data_handle(), - dist_epilogue.core_dists, - num_rows, - raft::resource::get_cuda_stream(handle)); + raft::copy(handle, + core_distances_h.value().view(), + raft::make_device_vector_view(dist_epilogue.core_dists, num_rows)); } // Ensure all async copies complete before starting parallel region @@ -415,10 +418,9 @@ void multi_gpu_batch_build(const raft::resources& handle, auto dist_epilgogue_for_rank = [&]() { if constexpr (mutual_reach_dist) { core_distances_d_for_rank.emplace(raft::make_device_vector(dev_res, num_rows)); - raft::copy(core_distances_d_for_rank.value().data_handle(), - core_distances_h.value().data_handle(), - num_rows, - raft::resource::get_cuda_stream(dev_res)); + raft::copy(dev_res, + core_distances_d_for_rank.value().view(), + raft::make_const_mdspan(core_distances_h.value().view())); return ReachabilityPP{ core_distances_d_for_rank.value().data_handle(), dist_epilogue.alpha, num_rows}; } else { @@ -604,15 +606,15 @@ void batch_build( inverted_indices_view); } - raft::copy(indices.data_handle(), - global_neighbors.data_handle(), - num_rows * k, - raft::resource::get_cuda_stream(handle)); + raft::copy( + handle, + raft::make_device_vector_view(indices.data_handle(), num_rows * k), + raft::make_device_vector_view(global_neighbors.data_handle(), num_rows * k)); if (distances.has_value()) { - raft::copy(distances.value().data_handle(), - global_distances.data_handle(), - num_rows * k, - raft::resource::get_cuda_stream(handle)); + raft::copy( + handle, + raft::make_device_vector_view(distances.value().data_handle(), num_rows * k), + raft::make_device_vector_view(global_distances.data_handle(), num_rows * k)); } } diff --git a/cpp/src/neighbors/all_neighbors/all_neighbors_builder.cuh b/cpp/src/neighbors/all_neighbors/all_neighbors_builder.cuh index 0c757d51eb..8f54ea0b24 100644 --- a/cpp/src/neighbors/all_neighbors/all_neighbors_builder.cuh +++ b/cpp/src/neighbors/all_neighbors/all_neighbors_builder.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -15,9 +15,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -190,10 +192,11 @@ struct all_neighbors_builder_ivfpq : public all_neighbors_builder { candidate_distances_view); // copy candidate neighbors to host - raft::copy(candidate_neighbors_h.value().data_handle(), - candidate_neighbors_view.data_handle(), - num_data_in_cluster * candidate_k, - raft::resource::get_cuda_stream(this->res)); + raft::copy(this->res, + raft::make_host_vector_view(candidate_neighbors_h.value().data_handle(), + num_data_in_cluster * candidate_k), + raft::make_device_vector_view(candidate_neighbors_view.data_handle(), + num_data_in_cluster * candidate_k)); auto candidate_neighbors_h_view = raft::make_host_matrix_view( candidate_neighbors_h.value().data_handle(), num_data_in_cluster, candidate_k); auto refined_distances_h_view = raft::make_host_matrix_view( @@ -210,10 +213,11 @@ struct all_neighbors_builder_ivfpq : public all_neighbors_builder { all_ivf_pq_params.build_params.metric); if (this->n_clusters > 1) { // do batching - raft::copy(this->batch_distances_d.value().data_handle(), - refined_distances_h_view.data_handle(), - num_data_in_cluster * this->k, - raft::resource::get_cuda_stream(this->res)); + raft::copy(this->res, + raft::make_device_vector_view(this->batch_distances_d.value().data_handle(), + num_data_in_cluster * this->k), + raft::make_host_vector_view(refined_distances_h_view.data_handle(), + num_data_in_cluster * this->k)); remap_and_merge_subgraphs( this->res, @@ -231,15 +235,17 @@ struct all_neighbors_builder_ivfpq : public all_neighbors_builder { } else { size_t num_rows = num_data_in_cluster; // copy resulting indices and distances to device output - raft::copy(this->indices_.value().data_handle(), - refined_neighbors_h_view.data_handle(), - num_rows * this->k, - raft::resource::get_cuda_stream(this->res)); + raft::copy( + this->res, + raft::make_device_vector_view(this->indices_.value().data_handle(), num_rows * this->k), + raft::make_host_vector_view(refined_neighbors_h_view.data_handle(), + num_rows * this->k)); if (this->distances_.has_value()) { - raft::copy(this->distances_.value().data_handle(), - refined_distances_h_view.data_handle(), - num_rows * this->k, - raft::resource::get_cuda_stream(this->res)); + raft::copy( + this->res, + raft::make_device_vector_view(this->distances_.value().data_handle(), num_rows * this->k), + raft::make_host_vector_view(refined_distances_h_view.data_handle(), + num_rows * this->k)); } } } @@ -251,10 +257,9 @@ struct all_neighbors_builder_ivfpq : public all_neighbors_builder { std::optional> global_distances = std::nullopt) override { // we need data on device for ivfpq build and search. - raft::copy(data_d.value().data_handle(), - dataset.data_handle(), - dataset.size(), - raft::resource::get_cuda_stream(this->res)); + raft::copy(this->res, + raft::make_device_vector_view(data_d.value().data_handle(), dataset.size()), + raft::make_host_vector_view(dataset.data_handle(), dataset.size())); build_knn_common(raft::make_device_matrix_view( data_d.value().data_handle(), dataset.extent(0), dataset.extent(1)), @@ -275,10 +280,9 @@ struct all_neighbors_builder_ivfpq : public all_neighbors_builder { auto dataset_h = raft::make_host_matrix(dataset.extent(0), dataset.extent(1)); // we need data on host for refining - raft::copy(dataset_h.data_handle(), - dataset.data_handle(), - dataset.size(), - raft::resource::get_cuda_stream(this->res)); + raft::copy(this->res, + raft::make_host_vector_view(dataset_h.data_handle(), dataset.size()), + raft::make_device_vector_view(dataset.data_handle(), dataset.size())); build_knn_common(dataset, raft::make_host_matrix_view( @@ -381,10 +385,11 @@ struct all_neighbors_builder_nn_descent : public all_neighbors_builder size_t num_data_in_cluster = dataset.extent(0); if constexpr (std::is_same_v) { // gather core dists - raft::copy(this->inverted_indices_d.value().data_handle(), - inverted_indices.value().data_handle(), - num_data_in_cluster, - raft::resource::get_cuda_stream(this->res)); + raft::copy(this->res, + raft::make_device_vector_view(this->inverted_indices_d.value().data_handle(), + num_data_in_cluster), + raft::make_host_vector_view(inverted_indices.value().data_handle(), + num_data_in_cluster)); raft::matrix::gather(this->res, raft::make_device_matrix_view( @@ -458,10 +463,11 @@ struct all_neighbors_builder_nn_descent : public all_neighbors_builder } // copy to final device output - raft::copy(this->indices_.value().data_handle(), - tmp_indices.data_handle(), - tmp_indices.extent(0) * this->k, - raft::resource::get_cuda_stream(this->res)); + raft::copy(this->res, + raft::make_device_vector_view(this->indices_.value().data_handle(), + tmp_indices.extent(0) * this->k), + raft::make_host_vector_view(tmp_indices.data_handle(), + tmp_indices.extent(0) * this->k)); } } @@ -546,10 +552,11 @@ struct all_neighbors_builder_brute_force : public all_neighbors_builder if constexpr (std::is_same_v) { // gather core dists - raft::copy(this->inverted_indices_d.value().data_handle(), - inverted_indices.value().data_handle(), - num_data_in_cluster, - raft::resource::get_cuda_stream(this->res)); + raft::copy(this->res, + raft::make_device_vector_view(this->inverted_indices_d.value().data_handle(), + num_data_in_cluster), + raft::make_host_vector_view(inverted_indices.value().data_handle(), + num_data_in_cluster)); raft::matrix::gather(this->res, raft::make_device_matrix_view( @@ -591,10 +598,11 @@ struct all_neighbors_builder_brute_force : public all_neighbors_builder raft::make_device_matrix_view( this->batch_distances_d.value().data_handle(), num_data_in_cluster, this->k)); } - raft::copy(this->batch_neighbors_h.value().data_handle(), - this->batch_neighbors_d.value().data_handle(), - num_data_in_cluster * this->k, - raft::resource::get_cuda_stream(this->res)); + raft::copy(this->res, + raft::make_host_vector_view(this->batch_neighbors_h.value().data_handle(), + num_data_in_cluster * this->k), + raft::make_device_vector_view( + this->batch_neighbors_d.value().data_handle(), num_data_in_cluster * this->k)); remap_and_merge_subgraphs>( this->res, @@ -654,10 +662,9 @@ struct all_neighbors_builder_brute_force : public all_neighbors_builder std::optional> global_neighbors = std::nullopt, std::optional> global_distances = std::nullopt) override { - raft::copy(data_d.value().data_handle(), - dataset.data_handle(), - dataset.size(), - raft::resource::get_cuda_stream(this->res)); + raft::copy(this->res, + raft::make_device_vector_view(data_d.value().data_handle(), dataset.size()), + raft::make_host_vector_view(dataset.data_handle(), dataset.size())); build_knn_common(raft::make_device_matrix_view( data_d.value().data_handle(), dataset.extent(0), dataset.extent(1)), diff --git a/cpp/src/neighbors/all_neighbors/all_neighbors_merge.cuh b/cpp/src/neighbors/all_neighbors/all_neighbors_merge.cuh index 1c55eb1b95..f9088a2730 100644 --- a/cpp/src/neighbors/all_neighbors/all_neighbors_merge.cuh +++ b/cpp/src/neighbors/all_neighbors/all_neighbors_merge.cuh @@ -1,10 +1,11 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once #include +#include #include #include #include @@ -268,15 +269,16 @@ void remap_and_merge_subgraphs(raft::resources const& res, } } - raft::copy(inverted_indices_d.data_handle(), - inverted_indices.data_handle(), - num_data_in_cluster, - raft::resource::get_cuda_stream(res)); + raft::copy( + res, + raft::make_device_vector_view(inverted_indices_d.data_handle(), num_data_in_cluster), + raft::make_host_vector_view(inverted_indices.data_handle(), num_data_in_cluster)); - raft::copy(batch_neighbors_d.data_handle(), - batch_neighbors_h.data_handle(), - num_data_in_cluster * k, - raft::resource::get_cuda_stream(res)); + raft::copy( + res, + raft::make_device_vector_view(batch_neighbors_d.data_handle(), num_data_in_cluster * k), + raft::make_host_vector_view(batch_neighbors_h.data_handle(), + num_data_in_cluster * k)); merge_subgraphs(res, k, diff --git a/cpp/src/neighbors/ball_cover/ball_cover.cuh b/cpp/src/neighbors/ball_cover/ball_cover.cuh index c6647c7ca5..c39756f7d5 100644 --- a/cpp/src/neighbors/ball_cover/ball_cover.cuh +++ b/cpp/src/neighbors/ball_cover/ball_cover.cuh @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -24,9 +25,7 @@ #include #include -#include #include -#include #include #include @@ -56,15 +55,11 @@ void sample_landmarks(raft::resources const& handle, raft::linalg::map_offset(handle, index.get_R_1nn_cols(), raft::identity_op{}); - thrust::fill(raft::resource::get_thrust_policy(handle), - R_1nn_ones.data(), - R_1nn_ones.data() + R_1nn_ones.size(), - 1.0); + raft::matrix::fill( + handle, raft::make_device_vector_view(R_1nn_ones.data(), R_1nn_ones.size()), value_t(1.0)); - thrust::fill(raft::resource::get_thrust_policy(handle), - R_indices.data(), - R_indices.data() + R_indices.size(), - 0.0); + raft::matrix::fill( + handle, raft::make_device_vector_view(R_indices.data(), R_indices.size()), value_idx(0)); /** * 1. Randomly sample sqrt(n) points from X @@ -109,10 +104,9 @@ void construct_landmark_1nn(raft::resources const& handle, { auto R_1nn_inds = raft::make_device_vector(handle, index.m); - thrust::fill(raft::resource::get_thrust_policy(handle), - R_1nn_inds.data_handle(), - R_1nn_inds.data_handle() + index.m, - std::numeric_limits::max()); + raft::matrix::fill(handle, + raft::make_device_vector_view(R_1nn_inds.data_handle(), index.m), + std::numeric_limits::max()); raft::linalg::map_offset(handle, R_1nn_inds.view(), [R_knn_inds_ptr, k] __device__(value_idx i) { return R_knn_inds_ptr[i * k]; @@ -226,14 +220,12 @@ void perform_rbc_query(raft::resources const& handle, bool perform_post_filtering = true) { // initialize output inds and dists - thrust::fill(raft::resource::get_thrust_policy(handle), - inds, - inds + (k * n_query_pts), - std::numeric_limits::max()); - thrust::fill(raft::resource::get_thrust_policy(handle), - dists, - dists + (k * n_query_pts), - std::numeric_limits::max()); + raft::matrix::fill(handle, + raft::make_device_vector_view(inds, k * n_query_pts), + std::numeric_limits::max()); + raft::matrix::fill(handle, + raft::make_device_vector_view(dists, k * n_query_pts), + std::numeric_limits::max()); // Compute nearest k for each neighborhood in each closest R rbc_low_dim_pass_one( @@ -261,8 +253,8 @@ void perform_rbc_eps_nn_query(raft::resources const& handle, value_idx* vd) { // initialize output - RAFT_CUDA_TRY(cudaMemsetAsync( - adj, 0, index.m * n_query_pts * sizeof(bool), raft::resource::get_cuda_stream(handle))); + raft::matrix::fill( + handle, raft::make_device_vector_view(adj, index.m * n_query_pts), bool(false)); raft::resource::sync_stream(handle); @@ -311,14 +303,13 @@ void rbc_build_index(raft::resources const& handle, rmm::device_uvector R_knn_inds(index.m, raft::resource::get_cuda_stream(handle)); // Initialize the uvectors - thrust::fill(raft::resource::get_thrust_policy(handle), - R_knn_inds.begin(), - R_knn_inds.end(), - std::numeric_limits::max()); - thrust::fill(raft::resource::get_thrust_policy(handle), - index.get_R_closest_landmark_dists().data_handle(), - index.get_R_closest_landmark_dists().data_handle() + index.m, - std::numeric_limits::max()); + raft::matrix::fill(handle, + raft::make_device_vector_view(R_knn_inds.data(), R_knn_inds.size()), + std::numeric_limits::max()); + raft::matrix::fill(handle, + raft::make_device_vector_view( + index.get_R_closest_landmark_dists().data_handle(), index.m), + std::numeric_limits::max()); /** * 1. Randomly sample sqrt(n) points from X @@ -374,23 +365,19 @@ void rbc_all_knn_query(raft::resources const& handle, rmm::device_uvector R_knn_dists(k * index.m, raft::resource::get_cuda_stream(handle)); // Initialize the uvectors - thrust::fill(raft::resource::get_thrust_policy(handle), - R_knn_inds.begin(), - R_knn_inds.end(), - std::numeric_limits::max()); - thrust::fill(raft::resource::get_thrust_policy(handle), - R_knn_dists.begin(), - R_knn_dists.end(), - std::numeric_limits::max()); - - thrust::fill(raft::resource::get_thrust_policy(handle), - inds, - inds + (k * index.m), - std::numeric_limits::max()); - thrust::fill(raft::resource::get_thrust_policy(handle), - dists, - dists + (k * index.m), - std::numeric_limits::max()); + raft::matrix::fill(handle, + raft::make_device_vector_view(R_knn_inds.data(), R_knn_inds.size()), + std::numeric_limits::max()); + raft::matrix::fill(handle, + raft::make_device_vector_view(R_knn_dists.data(), R_knn_dists.size()), + std::numeric_limits::max()); + + raft::matrix::fill(handle, + raft::make_device_vector_view(inds, k * index.m), + std::numeric_limits::max()); + raft::matrix::fill(handle, + raft::make_device_vector_view(dists, k * index.m), + std::numeric_limits::max()); sample_landmarks(handle, index); @@ -440,23 +427,19 @@ void rbc_knn_query(raft::resources const& handle, raft::resource::get_cuda_stream(handle)); // Initialize the uvectors - thrust::fill(raft::resource::get_thrust_policy(handle), - R_knn_inds.begin(), - R_knn_inds.end(), - std::numeric_limits::max()); - thrust::fill(raft::resource::get_thrust_policy(handle), - R_knn_dists.begin(), - R_knn_dists.end(), - std::numeric_limits::max()); - - thrust::fill(raft::resource::get_thrust_policy(handle), - inds, - inds + (k * n_query_pts), - std::numeric_limits::max()); - thrust::fill(raft::resource::get_thrust_policy(handle), - dists, - dists + (k * n_query_pts), - std::numeric_limits::max()); + raft::matrix::fill(handle, + raft::make_device_vector_view(R_knn_inds.data(), R_knn_inds.size()), + std::numeric_limits::max()); + raft::matrix::fill(handle, + raft::make_device_vector_view(R_knn_dists.data(), R_knn_dists.size()), + std::numeric_limits::max()); + + raft::matrix::fill(handle, + raft::make_device_vector_view(inds, k * n_query_pts), + std::numeric_limits::max()); + raft::matrix::fill(handle, + raft::make_device_vector_view(dists, k * n_query_pts), + std::numeric_limits::max()); k_closest_landmarks(handle, index, query, n_query_pts, k, R_knn_inds.data(), R_knn_dists.data()); diff --git a/cpp/src/neighbors/ball_cover/registers.cuh b/cpp/src/neighbors/ball_cover/registers.cuh index af623957dc..4d381d8ca2 100644 --- a/cpp/src/neighbors/ball_cover/registers.cuh +++ b/cpp/src/neighbors/ball_cover/registers.cuh @@ -14,12 +14,12 @@ #include "../detail/faiss_select/key_value_block_select.cuh" #include #include -#include +#include +#include #include #include -#include -#include +#include #include #include @@ -1134,8 +1134,9 @@ void rbc_low_dim_pass_two(raft::resources const& handle, rmm::device_uvector bitset(bitset_size * n_query_rows, raft::resource::get_cuda_stream(handle)); - thrust::fill( - raft::resource::get_thrust_policy(handle), bitset.data(), bitset.data() + bitset.size(), 0); + raft::matrix::fill(handle, + raft::make_device_vector_view(bitset.data(), bitset.size()), + std::uint32_t(0)); perform_post_filter_registers << max_k_in) { // ceil vd to max_k - raft::linalg::unaryOp( - vd_ptr, - vd_ptr, - n_query_rows, - [max_k_in] __device__(value_idx vd_count) { - return vd_count > max_k_in ? max_k_in : vd_count; - }, - raft::resource::get_cuda_stream(handle)); + raft::linalg::map(handle, + raft::make_device_vector_view(vd_ptr, n_query_rows), + raft::make_device_vector_view(vd_ptr, n_query_rows), + [max_k_in] __device__(value_idx vd_count) { + return vd_count > max_k_in ? max_k_in : vd_count; + }); } thrust::exclusive_scan(raft::resource::get_thrust_policy(handle), diff --git a/cpp/src/neighbors/cagra.cuh b/cpp/src/neighbors/cagra.cuh index 30c5729f6b..f11e69ac7f 100644 --- a/cpp/src/neighbors/cagra.cuh +++ b/cpp/src/neighbors/cagra.cuh @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -46,16 +47,16 @@ void index::compute_dataset_norms_(raft::resources const& res) // first scale the dataset and then compute norms auto scaled_sq_op = raft::compose_op( raft::sq_op{}, raft::div_const_op{float(kScale)}, raft::cast_op()); - raft::linalg::reduce(dataset_norms_->data_handle(), - dataset_view.data_handle(), - dataset_view.stride(0), - dataset_view.extent(0), - (float)0, - raft::resource::get_cuda_stream(res), - false, - scaled_sq_op, - raft::add_op(), - raft::sqrt_op{}); + raft::linalg::reduce( + res, + raft::make_device_matrix_view( + dataset_view.data_handle(), dataset_view.extent(0), dataset_view.stride(0)), + dataset_norms_->view(), + (float)0, + false, + scaled_sq_op, + raft::add_op(), + raft::sqrt_op{}); } /** diff --git a/cpp/src/neighbors/detail/cagra/add_nodes.cuh b/cpp/src/neighbors/detail/cagra/add_nodes.cuh index 5eee2bc564..8d6ac67d83 100644 --- a/cpp/src/neighbors/detail/cagra/add_nodes.cuh +++ b/cpp/src/neighbors/detail/cagra/add_nodes.cuh @@ -1,14 +1,16 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #include "../../../core/omp_wrapper.hpp" #include "../ann_utils.cuh" #include +#include #include #include #include #include +#include #include #include @@ -37,10 +39,7 @@ void add_node_core( // Step 0: Calculate the number of incoming edges for each node auto dev_num_incoming_edges = raft::make_device_vector(handle, new_size); - RAFT_CUDA_TRY(cudaMemsetAsync(dev_num_incoming_edges.data_handle(), - 0, - sizeof(int) * new_size, - raft::resource::get_cuda_stream(handle))); + raft::matrix::fill(handle, dev_num_incoming_edges.view(), int(0)); raft::stats::histogram(raft::stats::HistTypeAuto, dev_num_incoming_edges.data_handle(), old_size, @@ -50,10 +49,7 @@ void add_node_core( raft::resource::get_cuda_stream(handle)); auto host_num_incoming_edges = raft::make_host_vector(new_size); - raft::copy(host_num_incoming_edges.data_handle(), - dev_num_incoming_edges.data_handle(), - new_size, - raft::resource::get_cuda_stream(handle)); + raft::copy(handle, host_num_incoming_edges.view(), dev_num_incoming_edges.view()); std::size_t data_size_per_vector = sizeof(IdxT) * base_degree + sizeof(DistanceT) * base_degree + sizeof(T) * dim; @@ -117,10 +113,10 @@ void add_node_core( neighbors::cagra::search( handle, params, idx, queries_view, neighbor_indices_view, neighbor_distances_view); - raft::copy(host_neighbor_indices.data_handle(), - neighbor_indices.data_handle(), - batch.size() * base_degree, - raft::resource::get_cuda_stream(handle)); + raft::copy( + handle, + raft::make_host_vector_view(host_neighbor_indices.data_handle(), batch.size() * base_degree), + raft::make_device_vector_view(neighbor_indices.data_handle(), batch.size() * base_degree)); raft::resource::sync_stream(handle); // Check search results @@ -300,10 +296,9 @@ void add_graph_nodes( const std::size_t max_chunk_size_ = params.max_chunk_size == 0 ? new_dataset_size : params.max_chunk_size; - raft::copy(updated_graph_view.data_handle(), - index.graph().data_handle(), - index.graph().size(), - raft::resource::get_cuda_stream(handle)); + raft::copy(handle, + raft::make_device_vector_view(updated_graph_view.data_handle(), index.graph().size()), + raft::make_device_vector_view(index.graph().data_handle(), index.graph().size())); neighbors::cagra::index internal_index( handle, @@ -444,10 +439,10 @@ void extend_core( } // Copy updated dataset on host memory to device memory - raft::copy(updated_dataset_view.data_handle(), - host_updated_dataset.data_handle(), - new_dataset_size * stride, - raft::resource::get_cuda_stream(handle)); + raft::copy( + handle, + raft::make_device_vector_view(updated_dataset_view.data_handle(), new_dataset_size * stride), + raft::make_host_vector_view(host_updated_dataset.data_handle(), new_dataset_size * stride)); // Add graph nodes cuvs::neighbors::cagra::add_graph_nodes( @@ -471,10 +466,10 @@ void extend_core( // Update index graph if (new_graph_buffer_view.has_value()) { auto device_graph_view = new_graph_buffer_view.value(); - raft::copy(device_graph_view.data_handle(), - updated_graph.data_handle(), - updated_graph.size(), - raft::resource::get_cuda_stream(handle)); + raft::copy( + handle, + raft::make_device_vector_view(device_graph_view.data_handle(), updated_graph.size()), + raft::make_host_vector_view(updated_graph.data_handle(), updated_graph.size())); index.update_graph(handle, raft::make_const_mdspan(device_graph_view)); } else { index.update_graph(handle, raft::make_const_mdspan(updated_graph.view())); diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 7a4c70be89..637b9276c0 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -8,6 +8,7 @@ #include "../../vpq_dataset.cuh" #include "graph_core.cuh" +#include #include #include #include @@ -109,8 +110,7 @@ void ace_get_partition_labels( } } auto sample_db_dev = raft::make_device_matrix(res, n_samples, dataset_dim); - raft::update_device( - sample_db_dev.data_handle(), sample_db.data_handle(), sample_db.size(), stream); + raft::copy(res, sample_db_dev.view(), sample_db.view()); cuvs::cluster::kmeans::balanced_params kmeans_params; auto centroids_dev = raft::make_device_matrix(res, n_partitions, dataset_dim); @@ -145,10 +145,11 @@ void ace_get_partition_labels( sub_dataset(i_sub, k) = static_cast(dataset(i, k)); } } + auto sub_dataset_dev_view = raft::make_device_matrix_view( + _sub_dataset_dev.data_handle(), sub_dataset_size, dataset_dim); + raft::copy(res, sub_dataset_dev_view, sub_dataset); auto sub_dataset_dev = raft::make_device_matrix_view( _sub_dataset_dev.data_handle(), sub_dataset_size, dataset_dim); - raft::update_device( - _sub_dataset_dev.data_handle(), sub_dataset.data_handle(), sub_dataset.size(), stream); auto sub_distances = raft::make_host_matrix_view( _sub_distances.data_handle(), sub_dataset_size, n_partitions); @@ -161,8 +162,7 @@ void ace_get_partition_labels( sub_distances_dev, cuvs::distance::DistanceType::L2Expanded); - raft::update_host( - sub_distances.data_handle(), sub_distances_dev.data_handle(), sub_distances.size(), stream); + raft::copy(res, sub_distances, sub_distances_dev); raft::resource::sync_stream(res, stream); // Find two closest partitions to each dataset vector @@ -1382,10 +1382,10 @@ index build_ace(raft::resources const& res, auto sub_search_graph = raft::make_host_matrix(core_sub_dataset_size, graph_degree); cudaStream_t stream = raft::resource::get_cuda_stream(res); - raft::update_host(sub_search_graph.data_handle(), - sub_index.graph().data_handle(), - sub_search_graph.size(), - stream); + raft::copy( + res, + raft::make_host_vector_view(sub_search_graph.data_handle(), sub_search_graph.size()), + raft::make_device_vector_view(sub_index.graph().data_handle(), sub_search_graph.size())); raft::resource::sync_stream(res, stream); if (use_disk_mode) { @@ -1771,16 +1771,14 @@ void build_knn_graph( } // copy next batch to host - raft::copy(neighbors_host.data_handle(), - neighbors.data_handle(), - neighbors_view.size(), - raft::resource::get_cuda_stream(res)); + raft::copy(res, + raft::make_host_vector_view(neighbors_host.data_handle(), neighbors_view.size()), + raft::make_device_vector_view(neighbors.data_handle(), neighbors_view.size())); if (top_k != gpu_top_k) { // can be skipped for disabled refinement - raft::copy(queries_host.data_handle(), - batch.data(), - queries_view.size(), - raft::resource::get_cuda_stream(res)); + raft::copy(res, + raft::make_host_vector_view(queries_host.data_handle(), queries_view.size()), + raft::make_device_vector_view(batch.data(), queries_view.size())); } previous_batch_size = batch.size(); @@ -1822,10 +1820,11 @@ void build_knn_graph( refined_neighbors_view, refined_distances_view, pq.build_params.metric); - raft::copy(refined_neighbors_host.data_handle(), - refined_neighbors_view.data_handle(), - refined_neighbors_view.size(), - raft::resource::get_cuda_stream(res)); + raft::copy(res, + raft::make_host_vector_view(refined_neighbors_host.data_handle(), + refined_neighbors_view.size()), + raft::make_device_vector_view(refined_neighbors_view.data_handle(), + refined_neighbors_view.size())); raft::resource::sync_stream(res); auto refined_neighbors_host_view = raft::make_host_matrix_view( @@ -2135,10 +2134,7 @@ auto iterative_build_graph( auto batch_neighbors_view = raft::make_host_matrix_view( neighbors_view.data_handle() + batch.offset() * curr_topk, batch.size(), curr_topk); - raft::copy(batch_neighbors_view.data_handle(), - batch_dev_neighbors_view.data_handle(), - batch_neighbors_view.size(), - raft::resource::get_cuda_stream(res)); + raft::copy(res, batch_neighbors_view, batch_dev_neighbors_view); } // Optimize graph diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 2d383a2429..f1650980e0 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -30,7 +30,7 @@ #include #include -#include +#include namespace cuvs::neighbors::cagra::detail { @@ -217,16 +217,16 @@ void search_main(raft::resources const& res, // first scale the queries and then compute norms auto scaled_sq_op = raft::compose_op( raft::sq_op{}, raft::div_const_op{DistanceT(kScale)}, raft::cast_op()); - raft::linalg::reduce(query_norms.data_handle(), - queries.data_handle(), - queries.extent(1), - queries.extent(0), - (DistanceT)0, - stream, - false, - scaled_sq_op, - raft::add_op(), - raft::sqrt_op{}); + raft::linalg::reduce( + res, + raft::make_device_matrix_view( + queries.data_handle(), queries.extent(0), queries.extent(1)), + query_norms.view(), + (DistanceT)0, + false, + scaled_sq_op, + raft::add_op(), + raft::sqrt_op{}); const auto n_queries = distances.extent(0); const auto k = distances.extent(1); @@ -239,14 +239,14 @@ void search_main(raft::resources const& res, distances, raft::compose_op(raft::add_const_op{DistanceT(1)}, raft::div_checkzero_op{})); } else { - cuvs::neighbors::ivf::detail::postprocess_distances(dist_out, + cuvs::neighbors::ivf::detail::postprocess_distances(res, + dist_out, dist_in, index.metric(), distances.extent(0), distances.extent(1), kScale, - true, - raft::resource::get_cuda_stream(res)); + true); } } /** @} */ // end group cagra diff --git a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh index 866415b1e4..4e6de1d2b6 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh @@ -1,11 +1,12 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once #include +#include #include #include #include @@ -182,10 +183,7 @@ void serialize_to_hnswlib( auto graph = index_.graph(); auto host_graph = raft::make_host_matrix(graph.extent(0), graph.extent(1)); - raft::copy(host_graph.data_handle(), - graph.data_handle(), - graph.size(), - raft::resource::get_cuda_stream(res)); + raft::copy(res, host_graph.view(), graph); raft::resource::sync_stream(res); size_t d_report_offset = index_.size() / 10; // Report progress in 10% steps. diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 41efa1686f..69175f152f 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -6,11 +6,14 @@ #include "utils.hpp" +#include #include #include +#include #include #include #include +#include // TODO: This shouldn't be invoking anything from spatial/knn #include "../../../core/nvtx.hpp" @@ -531,15 +534,9 @@ void sort_knn_graph( auto d_dataset = raft::make_device_mdarray( res, large_tmp_mr, raft::make_extents(dataset_size, dataset_dim)); - raft::copy(d_dataset.data_handle(), - dataset_ptr, - dataset_size * dataset_dim, - raft::resource::get_cuda_stream(res)); + raft::copy(res, d_dataset.view(), dataset); - raft::copy(d_input_graph.data_handle(), - input_graph_ptr, - graph_size * input_graph_degree, - raft::resource::get_cuda_stream(res)); + raft::copy(res, d_input_graph.view(), knn_graph); void (*kernel_sort)(const DataT* const, const IdxT, @@ -588,10 +585,7 @@ void sort_knn_graph( metric); raft::resource::sync_stream(res); RAFT_LOG_DEBUG("."); - raft::copy(input_graph_ptr, - d_input_graph.data_handle(), - graph_size * input_graph_degree, - raft::resource::get_cuda_stream(res)); + raft::copy(res, knn_graph, raft::make_const_mdspan(d_input_graph.view())); RAFT_LOG_DEBUG("\n"); const double time_sort_end = cur_time(); @@ -795,31 +789,17 @@ void mst_optimization(raft::resources const& res, auto d_stats_ptr = d_stats.data_handle(); if (use_gpu) { - raft::copy(d_mst_graph_ptr, - mst_graph_ptr, - (size_t)graph_size * mst_graph_degree, - raft::resource::get_cuda_stream(res)); - raft::copy(d_outgoing_num_edges_ptr, - outgoing_num_edges_ptr, - (size_t)graph_size, - raft::resource::get_cuda_stream(res)); - raft::copy(d_incoming_num_edges_ptr, - incoming_num_edges_ptr, - (size_t)graph_size, - raft::resource::get_cuda_stream(res)); - raft::copy(d_outgoing_max_edges_ptr, - outgoing_max_edges_ptr, - (size_t)graph_size, - raft::resource::get_cuda_stream(res)); - raft::copy(d_incoming_max_edges_ptr, - incoming_max_edges_ptr, - (size_t)graph_size, - raft::resource::get_cuda_stream(res)); - raft::copy(d_label_ptr, label_ptr, (size_t)graph_size, raft::resource::get_cuda_stream(res)); - raft::copy(d_cluster_size_ptr, - cluster_size_ptr, - (size_t)graph_size, - raft::resource::get_cuda_stream(res)); + raft::copy(res, d_mst_graph.view(), raft::make_const_mdspan(mst_graph.view())); + raft::copy( + res, d_outgoing_num_edges.view(), raft::make_const_mdspan(outgoing_num_edges.view())); + raft::copy( + res, d_incoming_num_edges.view(), raft::make_const_mdspan(incoming_num_edges.view())); + raft::copy( + res, d_outgoing_max_edges.view(), raft::make_const_mdspan(outgoing_max_edges.view())); + raft::copy( + res, d_incoming_max_edges.view(), raft::make_const_mdspan(incoming_max_edges.view())); + raft::copy(res, d_label.view(), raft::make_const_mdspan(label.view())); + raft::copy(res, d_cluster_size.view(), raft::make_const_mdspan(cluster_size.view())); } IdxT num_clusters = 0; @@ -836,11 +816,8 @@ void mst_optimization(raft::resources const& res, // If the number of clusters does not converge to 1, then edges are // made from all nodes not belonging to the main cluster to any node // in the main cluster. - raft::copy(cluster_size_ptr, - d_cluster_size_ptr, - (size_t)graph_size, - raft::resource::get_cuda_stream(res)); - raft::copy(label_ptr, d_label_ptr, (size_t)graph_size, raft::resource::get_cuda_stream(res)); + raft::copy(res, cluster_size.view(), raft::make_const_mdspan(d_cluster_size.view())); + raft::copy(res, label.view(), raft::make_const_mdspan(d_label.view())); raft::resource::sync_stream(res); uint32_t main_cluster_label = graph_size; #pragma omp parallel for reduction(min : main_cluster_label) @@ -871,15 +848,12 @@ void mst_optimization(raft::resources const& res, // 2. Update MST graph // * Try to add candidate edges to MST graph if (use_gpu) { - raft::copy(d_candidate_edges_ptr, - candidate_edges_ptr, - graph_size, - raft::resource::get_cuda_stream(res)); + raft::copy(res, d_candidate_edges.view(), raft::make_const_mdspan(candidate_edges.view())); stats_ptr[0] = 0; stats_ptr[1] = num_direct; stats_ptr[2] = num_alternate; stats_ptr[3] = num_failure; - raft::copy(d_stats_ptr, stats_ptr, 4, raft::resource::get_cuda_stream(res)); + raft::copy(res, d_stats.view(), raft::make_const_mdspan(stats.view())); constexpr uint64_t n_threads = 256; const dim3 threads(n_threads, 1, 1); @@ -896,7 +870,7 @@ void mst_optimization(raft::resources const& res, mst_graph_degree, d_stats_ptr); - raft::copy(stats_ptr, d_stats_ptr, 4, raft::resource::get_cuda_stream(res)); + raft::copy(res, stats.view(), raft::make_const_mdspan(d_stats.view())); raft::resource::sync_stream(res); num_direct = stats_ptr[1]; num_alternate = stats_ptr[2]; @@ -923,7 +897,9 @@ void mst_optimization(raft::resources const& res, flag_update = 0; if (use_gpu) { stats_ptr[0] = flag_update; - raft::copy(d_stats_ptr, stats_ptr, 1, raft::resource::get_cuda_stream(res)); + raft::copy(res, + raft::make_device_vector_view(d_stats_ptr, int64_t(1)), + raft::make_host_vector_view(stats_ptr, int64_t(1))); constexpr uint64_t n_threads = 256; const dim3 threads(n_threads, 1, 1); @@ -931,7 +907,9 @@ void mst_optimization(raft::resources const& res, kern_mst_opt_labeling<<>>( d_label_ptr, d_mst_graph_ptr, graph_size, mst_graph_degree, d_stats_ptr); - raft::copy(stats_ptr, d_stats_ptr, 1, raft::resource::get_cuda_stream(res)); + raft::copy(res, + raft::make_host_vector_view(stats_ptr, int64_t(1)), + raft::make_device_vector_view(d_stats_ptr, int64_t(1))); raft::resource::sync_stream(res); flag_update = stats_ptr[0]; } else { @@ -953,7 +931,9 @@ void mst_optimization(raft::resources const& res, num_clusters = 0; if (use_gpu) { stats_ptr[0] = num_clusters; - raft::copy(d_stats_ptr, stats_ptr, 1, raft::resource::get_cuda_stream(res)); + raft::copy(res, + raft::make_device_vector_view(d_stats_ptr, int64_t(1)), + raft::make_host_vector_view(stats_ptr, int64_t(1))); constexpr uint64_t n_threads = 256; const dim3 threads(n_threads, 1, 1); @@ -961,7 +941,9 @@ void mst_optimization(raft::resources const& res, kern_mst_opt_cluster_size<<>>( d_cluster_size_ptr, d_label_ptr, graph_size, d_stats_ptr); - raft::copy(stats_ptr, d_stats_ptr, 1, raft::resource::get_cuda_stream(res)); + raft::copy(res, + raft::make_host_vector_view(stats_ptr, int64_t(1)), + raft::make_device_vector_view(d_stats_ptr, int64_t(1))); raft::resource::sync_stream(res); num_clusters = stats_ptr[0]; } else { @@ -992,7 +974,7 @@ void mst_optimization(raft::resources const& res, stats_ptr[1] = cluster_size_max; stats_ptr[2] = total_outgoing_edges; stats_ptr[3] = total_incoming_edges; - raft::copy(d_stats_ptr, stats_ptr, 4, raft::resource::get_cuda_stream(res)); + raft::copy(res, d_stats.view(), raft::make_const_mdspan(stats.view())); constexpr uint64_t n_threads = 256; const dim3 threads(n_threads, 1, 1); @@ -1007,7 +989,7 @@ void mst_optimization(raft::resources const& res, mst_graph_degree, d_stats_ptr); - raft::copy(stats_ptr, d_stats_ptr, 4, raft::resource::get_cuda_stream(res)); + raft::copy(res, stats.view(), raft::make_const_mdspan(d_stats.view())); raft::resource::sync_stream(res); cluster_size_min = stats_ptr[0]; cluster_size_max = stats_ptr[1]; @@ -1071,10 +1053,7 @@ void mst_optimization(raft::resources const& res, // The edges that make up the MST are stored as edges in the output graph. if (use_gpu) { - raft::copy(mst_graph_ptr, - d_mst_graph_ptr, - (size_t)graph_size * mst_graph_degree, - raft::resource::get_cuda_stream(res)); + raft::copy(res, mst_graph.view(), raft::make_const_mdspan(d_mst_graph.view())); raft::resource::sync_stream(res); } #pragma omp parallel for @@ -1258,17 +1237,11 @@ void optimize( auto d_detour_count = raft::make_device_mdarray( res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), - 0xff, - graph_size * knn_graph_degree * sizeof(uint8_t), - raft::resource::get_cuda_stream(res))); + raft::matrix::fill(res, d_detour_count.view(), uint8_t(0xff)); auto d_num_no_detour_edges = raft::make_device_mdarray( res, large_tmp_mr, raft::make_extents(graph_size)); - RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), - 0x00, - graph_size * sizeof(uint32_t), - raft::resource::get_cuda_stream(res))); + raft::matrix::fill(res, d_num_no_detour_edges.view(), uint32_t(0)); auto dev_stats = raft::make_device_vector(res, 2); auto host_stats = raft::make_host_vector(2); @@ -1295,8 +1268,7 @@ void optimize( const dim3 threads_prune(32, 1, 1); const dim3 blocks_prune(batch_size, 1, 1); - RAFT_CUDA_TRY(cudaMemsetAsync( - dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, raft::resource::get_cuda_stream(res))); + raft::matrix::fill(res, dev_stats.view(), uint64_t(0)); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { kern_prune @@ -1318,13 +1290,9 @@ void optimize( raft::resource::sync_stream(res); RAFT_LOG_DEBUG("\n"); - raft::copy(detour_count.data_handle(), - d_detour_count.data_handle(), - detour_count.size(), - raft::resource::get_cuda_stream(res)); + raft::copy(res, detour_count.view(), raft::make_const_mdspan(d_detour_count.view())); - raft::copy( - host_stats.data_handle(), dev_stats.data_handle(), 2, raft::resource::get_cuda_stream(res)); + raft::copy(res, host_stats.view(), raft::make_const_mdspan(dev_stats.view())); num_keep = host_stats.data_handle()[0]; num_full = host_stats.data_handle()[1]; @@ -1425,17 +1393,14 @@ void optimize( const double time_make_start = cur_time(); device_matrix_view_from_host d_rev_graph(res, rev_graph.view()); - RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph.data_handle(), - 0xff, - graph_size * output_graph_degree * sizeof(IdxT), - raft::resource::get_cuda_stream(res))); + raft::matrix::fill(res, + raft::make_device_vector_view( + d_rev_graph.data_handle(), graph_size * output_graph_degree), + IdxT(-1)); auto d_rev_graph_count = raft::make_device_mdarray( res, large_tmp_mr, raft::make_extents(graph_size)); - RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(), - 0x00, - graph_size * sizeof(uint32_t), - raft::resource::get_cuda_stream(res))); + raft::matrix::fill(res, d_rev_graph_count.view(), uint32_t(0)); auto dest_nodes = raft::make_host_vector(graph_size); auto d_dest_nodes = @@ -1449,10 +1414,7 @@ void optimize( } raft::resource::sync_stream(res); - raft::copy(d_dest_nodes.data_handle(), - dest_nodes.data_handle(), - graph_size, - raft::resource::get_cuda_stream(res)); + raft::copy(res, d_dest_nodes.view(), raft::make_const_mdspan(dest_nodes.view())); dim3 threads(256, 1, 1); dim3 blocks(1024, 1, 1); @@ -1469,15 +1431,9 @@ void optimize( RAFT_LOG_DEBUG("\n"); if (d_rev_graph.allocated_memory()) { - raft::copy(rev_graph.data_handle(), - d_rev_graph.data_handle(), - graph_size * output_graph_degree, - raft::resource::get_cuda_stream(res)); + raft::copy(res, rev_graph.view(), raft::make_const_mdspan(d_rev_graph.view())); } - raft::copy(rev_graph_count.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - raft::resource::get_cuda_stream(res)); + raft::copy(res, rev_graph_count.view(), raft::make_const_mdspan(d_rev_graph_count.view())); const double time_make_end = cur_time(); RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 30c7287430..b1dadb1e36 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -1,13 +1,15 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once // TODO: This shouldn't be invoking anything from detail outside of neighbors namespace +#include #include #include #include +#include #include #include @@ -180,10 +182,7 @@ class device_matrix_view_from_host { // The user may opt to set this resource to managed memory to allow large allocations. device_mem_.emplace(raft::make_device_mdarray( res, raft::resource::get_large_workspace_resource(res), host_view.extents())); - raft::copy(device_mem_->data_handle(), - host_view.data_handle(), - host_view.extent(0) * host_view.extent(1), - raft::resource::get_cuda_stream(res)); + raft::copy(res, device_mem_->view(), host_view); device_ptr = device_mem_->data_handle(); } } @@ -241,10 +240,7 @@ class host_matrix_view_from_device { // allocate memory and copy over host_mem_.emplace( raft::make_host_matrix(device_view.extent(0), device_view.extent(1))); - raft::copy(host_mem_->data_handle(), - device_view.data_handle(), - device_view.extent(0) * device_view.extent(1), - raft::resource::get_cuda_stream(res)); + raft::copy(res, host_mem_->view(), device_view); host_ptr = host_mem_->data_handle(); } } @@ -282,12 +278,10 @@ void copy_with_padding( raft::make_device_mdarray(res, mr, raft::make_extents(src.extent(0), padded_dim)); } if (dst.extent(1) == src.extent(1)) { - raft::copy( - dst.data_handle(), src.data_handle(), src.size(), raft::resource::get_cuda_stream(res)); + raft::copy(res, dst.view(), src); } else { // copy with padding - RAFT_CUDA_TRY(cudaMemsetAsync( - dst.data_handle(), 0, dst.size() * sizeof(T), raft::resource::get_cuda_stream(res))); + raft::matrix::fill(res, dst.view(), T(0)); RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(), sizeof(T) * dst.extent(1), src.data_handle(), diff --git a/cpp/src/neighbors/detail/hnsw.hpp b/cpp/src/neighbors/detail/hnsw.hpp index e224513f2b..9ef1d0470f 100644 --- a/cpp/src/neighbors/detail/hnsw.hpp +++ b/cpp/src/neighbors/detail/hnsw.hpp @@ -13,7 +13,9 @@ #include #include +#include #include +#include #include #include @@ -267,10 +269,7 @@ std::enable_if_t>> fro // copy cagra graph to host host_graph = raft::make_host_matrix(host_graph_view.extent(0), host_graph_view.extent(1)); - raft::copy(host_graph.data_handle(), - host_graph_view.data_handle(), - host_graph_view.size(), - raft::resource::get_cuda_stream(res)); + raft::copy(res, host_graph.view(), host_graph_view); raft::resource::sync_stream(res); host_graph_view = host_graph.view(); } @@ -914,10 +913,11 @@ std::enable_if_t>> fro if (next_batch_i < n_batches) { auto offset = next_batch_i * max_batch_size; auto batch_size = std::min(max_batch_size, n_rows - offset); - raft::copy(bufs[next_batch_i % 2], - source_dataset + offset * source_stride, - batch_size * source_stride, - stream); + raft::copy( + res, + raft::make_host_vector_view(bufs[next_batch_i % 2], batch_size * source_stride), + raft::make_device_vector_view(source_dataset + offset * source_stride, + batch_size * source_stride)); } } if (batch_i < 0) { continue; } diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index ada7fab978..34ef8dd937 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -111,20 +111,33 @@ void tiled_brute_force_knn(const raft::resources& handle, // cosine needs the l2norm, where as l2 distances needs the squared norm if (metric == cuvs::distance::DistanceType::CosineExpanded) { if (!precomputed_search_norms) { - raft::linalg::rowNorm( - search_norms.data(), search, d, m, stream, raft::sqrt_op{}); + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + search, m, d), + raft::make_device_vector_view(search_norms.data(), m), + raft::sqrt_op{}); } if (!precomputed_index_norms) { - raft::linalg::rowNorm( - index_norms.data(), index, d, n, stream, raft::sqrt_op{}); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(index, n, d), + raft::make_device_vector_view(index_norms.data(), n), + raft::sqrt_op{}); } } else { if (!precomputed_search_norms) { - raft::linalg::rowNorm( - search_norms.data(), search, d, m, stream); + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + search, m, d), + raft::make_device_vector_view(search_norms.data(), m)); } if (!precomputed_index_norms) { - raft::linalg::rowNorm(index_norms.data(), index, d, n, stream); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(index, n, d), + raft::make_device_vector_view(index_norms.data(), n)); } } pairwise_metric = cuvs::distance::DistanceType::InnerProduct; @@ -377,7 +390,9 @@ void brute_force_knn_impl( rmm::device_uvector trans(0, userStream); if (id_ranges.size() > 0) { trans.resize(id_ranges.size(), userStream); - raft::update_device(trans.data(), id_ranges.data(), id_ranges.size(), userStream); + raft::copy(handle, + raft::make_device_vector_view(trans.data(), id_ranges.size()), + raft::make_host_vector_view(id_ranges.data(), id_ranges.size())); } rmm::device_uvector all_D(0, userStream); @@ -455,12 +470,12 @@ void brute_force_knn_impl( metric == cuvs::distance::DistanceType::LpUnexpanded) { DistType p = 0.5; // standard l2 if (metric == cuvs::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg; - raft::linalg::unaryOp( - res_D, - res_D, - n * k, + raft::linalg::map( + handle, + raft::make_device_vector_view(res_D, n * k), [p] __device__(DistType input) { return powf(fabsf(input), p); }, - stream); + raft::make_const_mdspan( + raft::make_device_vector_view(res_D, n * k))); } } else { switch (metric) { @@ -682,24 +697,21 @@ void brute_force_search_filtered( if (metric == cuvs::distance::DistanceType::CosineExpanded) { if (!query_norms) { query_norms_ = raft::make_device_vector(res, n_queries); - raft::linalg::rowNorm( - (DistanceT*)(query_norms_->data_handle()), - queries.data_handle(), - dim, - n_queries, - stream, + raft::linalg::norm( + res, + raft::make_device_matrix_view( + queries.data_handle(), n_queries, dim), + query_norms_->view(), raft::sqrt_op{}); } } else { if (!query_norms) { query_norms_ = raft::make_device_vector(res, n_queries); - raft::linalg::rowNorm( - (DistanceT*)(query_norms_->data_handle()), - queries.data_handle(), - dim, - n_queries, - stream, - raft::identity_op{}); + raft::linalg::norm( + res, + raft::make_device_matrix_view( + queries.data_handle(), n_queries, dim), + query_norms_->view()); } } cuvs::neighbors::detail::epilogue_on_csr( diff --git a/cpp/src/neighbors/detail/knn_graph.cuh b/cpp/src/neighbors/detail/knn_graph.cuh index aaacb65a5d..2d7637083b 100644 --- a/cpp/src/neighbors/detail/knn_graph.cuh +++ b/cpp/src/neighbors/detail/knn_graph.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include @@ -88,10 +87,10 @@ void knn_graph(raft::resources const& res, indices_64_view, distances_view); - raft::linalg::unary_op(res, - raft::make_const_mdspan(indices_64_view), - raft::make_device_vector_view(indices.data(), nnz), - raft::cast_op{}); + raft::linalg::map(res, + raft::make_device_vector_view(indices.data(), nnz), + raft::cast_op{}, + raft::make_const_mdspan(indices_64_view)); raft::sparse::linalg::symmetrize(res, rows.data(), diff --git a/cpp/src/neighbors/detail/nn_descent.cuh b/cpp/src/neighbors/detail/nn_descent.cuh index bd8b81ea03..a1eb829569 100644 --- a/cpp/src/neighbors/detail/nn_descent.cuh +++ b/cpp/src/neighbors/detail/nn_descent.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -13,6 +13,7 @@ #include #include +#include #include #include #include @@ -1368,7 +1369,9 @@ void GNND::add_reverse_edges(Index_t* graph_ptr, std::numeric_limits::max()); add_rev_edges_kernel<<>>( graph_ptr, d_rev_graph_ptr, NUM_SAMPLES, list_sizes); - raft::copy(h_rev_graph_ptr, d_rev_graph_ptr, nrow_ * NUM_SAMPLES, stream); + raft::copy(res, + raft::make_host_vector_view(h_rev_graph_ptr, nrow_ * NUM_SAMPLES), + raft::make_device_vector_view(d_rev_graph_ptr, nrow_ * NUM_SAMPLES)); } template @@ -1499,18 +1502,9 @@ void GNND::build(Data_t* data, }; for (size_t it = 0; it < build_config_.max_iterations; it++) { - raft::copy(d_list_sizes_new_.data_handle(), - graph_.h_list_sizes_new.data_handle(), - nrow_, - raft::resource::get_cuda_stream(res)); - raft::copy(h_graph_old_.data_handle(), - graph_.h_graph_old.data_handle(), - nrow_ * NUM_SAMPLES, - raft::resource::get_cuda_stream(res)); - raft::copy(d_list_sizes_old_.data_handle(), - graph_.h_list_sizes_old.data_handle(), - nrow_, - raft::resource::get_cuda_stream(res)); + raft::copy(res, d_list_sizes_new_.view(), graph_.h_list_sizes_new.view()); + raft::copy(res, h_graph_old_.view(), graph_.h_graph_old.view()); + raft::copy(res, d_list_sizes_old_.view(), graph_.h_list_sizes_old.view()); raft::resource::sync_stream(res); std::thread update_and_sample_thread(update_and_sample, it); @@ -1551,14 +1545,8 @@ void GNND::build(Data_t* data, update_and_sample_thread.join(); if (update_counter_ == -1) { break; } - raft::copy(graph_host_buffer_.data_handle(), - graph_buffer_.data_handle(), - nrow_ * DEGREE_ON_DEVICE, - raft::resource::get_cuda_stream(res)); - raft::copy(dists_host_buffer_.data_handle(), - dists_buffer_.data_handle(), - nrow_ * DEGREE_ON_DEVICE, - raft::resource::get_cuda_stream(res)); + raft::copy(res, graph_host_buffer_.view(), graph_buffer_.view()); + raft::copy(res, dists_host_buffer_.view(), dists_buffer_.view()); raft::resource::sync_stream(res); graph_.sample_graph_new(graph_host_buffer_.data_handle(), DEGREE_ON_DEVICE); @@ -1585,10 +1573,11 @@ void GNND::build(Data_t* data, graph_h_dists(i, j) = graph_.h_dists(i, j); } } - raft::copy(output_distances, - graph_h_dists.data_handle(), - nrow_ * build_config_.output_graph_degree, - raft::resource::get_cuda_stream(res)); + raft::copy( + res, + raft::make_device_vector_view(output_distances, nrow_ * build_config_.output_graph_degree), + raft::make_host_vector_view(graph_h_dists.data_handle(), + nrow_ * build_config_.output_graph_degree)); auto output_dist_view = raft::make_device_matrix_view( output_distances, nrow_, build_config_.output_graph_degree); diff --git a/cpp/src/neighbors/detail/reachability.cuh b/cpp/src/neighbors/detail/reachability.cuh index 5879e878f4..338dbd88a2 100644 --- a/cpp/src/neighbors/detail/reachability.cuh +++ b/cpp/src/neighbors/detail/reachability.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -9,7 +9,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/src/neighbors/detail/sparse_knn.cuh b/cpp/src/neighbors/detail/sparse_knn.cuh index f24b1d586c..2f20c30b97 100644 --- a/cpp/src/neighbors/detail/sparse_knn.cuh +++ b/cpp/src/neighbors/detail/sparse_knn.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -7,8 +7,10 @@ #include "../../distance/sparse_distance.cuh" #include +#include +#include +#include #include -#include #include #include @@ -286,25 +288,26 @@ class sparse_knn_t { } // copy merged output back into merge buffer partition for next iteration - raft::copy_async(merge_buffer_indices.data(), - indices_merge_buffer_tmp_ptr, - batch_rows * k, - raft::resource::get_cuda_stream(handle)); - raft::copy_async(merge_buffer_dists.data(), - dists_merge_buffer_tmp_ptr, - batch_rows * k, - raft::resource::get_cuda_stream(handle)); + raft::copy(handle, + raft::make_device_vector_view(merge_buffer_indices.data(), batch_rows * k), + raft::make_const_mdspan( + raft::make_device_vector_view(indices_merge_buffer_tmp_ptr, batch_rows * k))); + raft::copy(handle, + raft::make_device_vector_view(merge_buffer_dists.data(), batch_rows * k), + raft::make_const_mdspan( + raft::make_device_vector_view(dists_merge_buffer_tmp_ptr, batch_rows * k))); } // Copy final merged batch to output array - raft::copy_async(output_indices + (rows_processed * k), - merge_buffer_indices.data(), - query_batcher.batch_rows() * k, - raft::resource::get_cuda_stream(handle)); - raft::copy_async(output_dists + (rows_processed * k), - merge_buffer_dists.data(), - query_batcher.batch_rows() * k, - raft::resource::get_cuda_stream(handle)); + auto batch_len = query_batcher.batch_rows() * k; + raft::copy(handle, + raft::make_device_vector_view(output_indices + (rows_processed * k), batch_len), + raft::make_const_mdspan( + raft::make_device_vector_view(merge_buffer_indices.data(), batch_len))); + raft::copy(handle, + raft::make_device_vector_view(output_dists + (rows_processed * k), batch_len), + raft::make_const_mdspan( + raft::make_device_vector_view(merge_buffer_dists.data(), batch_len))); rows_processed += query_batcher.batch_rows(); } @@ -324,8 +327,9 @@ class sparse_knn_t { id_ranges.push_back(idx_batcher.batch_start()); rmm::device_uvector trans(id_ranges.size(), raft::resource::get_cuda_stream(handle)); - raft::update_device( - trans.data(), id_ranges.data(), id_ranges.size(), raft::resource::get_cuda_stream(handle)); + raft::copy(handle, + raft::make_device_vector_view(trans.data(), id_ranges.size()), + raft::make_host_vector_view(id_ranges.data(), id_ranges.size())); // combine merge buffers only if there's more than 1 partition to combine auto rows = query_batcher.batch_rows(); diff --git a/cpp/src/neighbors/detail/tiered_index.cuh b/cpp/src/neighbors/detail/tiered_index.cuh index 1338b4c78f..9cad64549b 100644 --- a/cpp/src/neighbors/detail/tiered_index.cuh +++ b/cpp/src/neighbors/detail/tiered_index.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -237,7 +237,10 @@ struct index_state { auto stream = raft::resource::get_cuda_stream(res); int64_t host_translations[2] = {0, static_cast(ann_rows())}; auto device_translations = raft::make_device_vector(res, 2); - raft::copy(device_translations.data_handle(), host_translations, 2, stream); + raft::copy( + res, + device_translations.view(), + raft::make_host_vector_view(host_translations, device_translations.extent(0))); knn_merge_parts(res, temp_distances.view(), diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 9de51e33c5..0798141c9c 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -237,7 +237,10 @@ void batched_insert_vamana( int num_blocks = min(maxBlocks, step_size); // Copy ids to be inserted for this batch - raft::copy(query_ids.data_handle(), &insert_order.data()[start], step_size, stream); + raft::copy( + res, + raft::make_device_vector_view(query_ids.data_handle(), int64_t(step_size)), + raft::make_host_vector_view(insert_order.data() + start, int64_t(step_size))); set_query_ids<<>>( query_list_ptr.data_handle(), query_ids.data_handle(), step_size); @@ -542,7 +545,7 @@ void batched_insert_vamana( batch_prune); #endif - raft::copy(graph.data_handle(), d_graph.data_handle(), d_graph.size(), stream); + raft::copy(res, graph, d_graph.view()); RAFT_CHECK_CUDA(stream); } @@ -604,10 +607,10 @@ index build( res, codebook_params.pq_encoding_table.size()); // logically a 2D matrix with dimensions // pq_codebook_size x dim_per_subspace * pq_dim - raft::copy(pq_encoding_table_device_vec.data_handle(), - codebook_params.pq_encoding_table.data(), - codebook_params.pq_encoding_table.size(), - raft::resource::get_cuda_stream(res)); + raft::copy(res, + pq_encoding_table_device_vec.view(), + raft::make_host_vector_view(codebook_params.pq_encoding_table.data(), + pq_encoding_table_device_vec.extent(0))); int dim_per_subspace = dim / pq_dim; auto pq_codebook = raft::make_device_matrix(res, pq_codebook_size * pq_dim, dim_per_subspace); @@ -631,10 +634,12 @@ index build( // prepare rotation matrix auto rotation_matrix_device = raft::make_device_matrix(res, dim, dim); - raft::copy(rotation_matrix_device.data_handle(), - codebook_params.rotation_matrix.data(), - codebook_params.rotation_matrix.size(), - raft::resource::get_cuda_stream(res)); + raft::copy( + res, + raft::make_device_vector_view(rotation_matrix_device.data_handle(), + int64_t(codebook_params.rotation_matrix.size())), + raft::make_host_vector_view(codebook_params.rotation_matrix.data(), + int64_t(codebook_params.rotation_matrix.size()))); // process in batches const uint32_t n_rows = dataset.extent(0); diff --git a/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh b/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh index b4960c726f..887c9eb448 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,6 +8,7 @@ #include "vamana_structs.cuh" #include +#include #include #include #include @@ -66,10 +67,11 @@ void serialize_dataset(raft::resources const& res, if (strided_dataset) { auto h_dataset = raft::make_host_matrix(strided_dataset->n_rows(), strided_dataset->dim()); - raft::copy(h_dataset.data_handle(), - strided_dataset->view().data_handle(), - strided_dataset->n_rows() * strided_dataset->dim(), - raft::resource::get_cuda_stream(res)); + raft::copy(res, + raft::make_host_vector_view(h_dataset.data_handle(), + strided_dataset->n_rows() * strided_dataset->dim()), + raft::make_device_vector_view(strided_dataset->view().data_handle(), + strided_dataset->n_rows() * strided_dataset->dim())); to_file(dataset_base_file, h_dataset); } else { RAFT_LOG_DEBUG("dynamic_cast to strided_dataset failed"); @@ -88,10 +90,7 @@ void serialize_dataset(raft::resources const& res, // try allocating a buffer for the dataset on host try { auto h_dataset = raft::make_host_matrix(dataset.extent(0), dataset.extent(1)); - raft::copy(h_dataset.data_handle(), - dataset.data_handle(), - dataset.extent(0) * dataset.extent(1), - raft::resource::get_cuda_stream(res)); + raft::copy(res, h_dataset.view(), dataset); to_file(dataset_base_file, h_dataset); } catch (std::bad_alloc& e) { RAFT_LOG_INFO("Failed to serialize dataset"); @@ -313,10 +312,7 @@ void serialize(raft::resources const& res, { auto d_graph = index_.graph(); auto h_graph = raft::make_host_matrix(d_graph.extent(0), d_graph.extent(1)); - raft::copy(h_graph.data_handle(), - d_graph.data_handle(), - d_graph.size(), - raft::resource::get_cuda_stream(res)); + raft::copy(res, h_graph.view(), d_graph); raft::resource::sync_stream(res); // if requested, write sector-aligned file and return diff --git a/cpp/src/neighbors/iface/iface.hpp b/cpp/src/neighbors/iface/iface.hpp index 59b1d55905..e76a3673af 100644 --- a/cpp/src/neighbors/iface/iface.hpp +++ b/cpp/src/neighbors/iface/iface.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -116,10 +117,7 @@ void search(const raft::resources& handle, int64_t n_rows = h_queries.extent(0); int64_t n_dims = h_queries.extent(1); auto d_queries = raft::make_device_matrix(handle, n_rows, n_dims); - raft::copy(d_queries.data_handle(), - h_queries.data_handle(), - n_rows * n_dims, - resource::get_cuda_stream(handle)); + raft::copy(handle, d_queries.view(), h_queries); auto d_query_view = raft::make_const_mdspan(d_queries.view()); search(handle, interface, search_params, d_query_view, d_neighbors, d_distances); diff --git a/cpp/src/neighbors/ivf_common.cuh b/cpp/src/neighbors/ivf_common.cuh index ad3dc86d0d..4624ce9a7a 100644 --- a/cpp/src/neighbors/ivf_common.cuh +++ b/cpp/src/neighbors/ivf_common.cuh @@ -6,7 +6,11 @@ #pragma once #include -#include +#include +#include +#include +#include +#include #include // matrix::detail::select::warpsort::warp_sort_distributed namespace cuvs::neighbors::ivf::detail { @@ -169,62 +173,65 @@ void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, to * translate the element type if necessary. */ template -void postprocess_distances(ScoreOutT* out, // [n_queries, topk] +void postprocess_distances(const raft::resources& res, + ScoreOutT* out, // [n_queries, topk] const ScoreInT* in, // [n_queries, topk] distance::DistanceType metric, uint32_t n_queries, uint32_t topk, float scaling_factor, - bool account_for_max_close, - rmm::cuda_stream_view stream) + bool account_for_max_close) { constexpr bool needs_cast = !std::is_same::value; const bool needs_copy = ((void*)in) != ((void*)out); size_t len = size_t(n_queries) * size_t(topk); + auto out_view = raft::make_device_vector_view(out, len); + auto in_view = raft::make_device_vector_view(in, len); switch (metric) { case distance::DistanceType::L2Unexpanded: case distance::DistanceType::L2Expanded: { if (scaling_factor != 1.0) { - raft::linalg::unaryOp( - out, - in, - len, + raft::linalg::map( + res, + out_view, raft::compose_op(raft::mul_const_op{scaling_factor * scaling_factor}, raft::cast_op{}), - stream); + raft::make_const_mdspan(in_view)); } else if (needs_cast || needs_copy) { - raft::linalg::unaryOp(out, in, len, raft::cast_op{}, stream); + raft::linalg::map( + res, out_view, raft::cast_op{}, raft::make_const_mdspan(in_view)); } } break; case distance::DistanceType::L2SqrtUnexpanded: case distance::DistanceType::L2SqrtExpanded: { if (scaling_factor != 1.0) { - raft::linalg::unaryOp(out, - in, - len, - raft::compose_op{raft::mul_const_op{scaling_factor}, - raft::sqrt_op{}, - raft::cast_op{}}, - stream); + raft::linalg::map(res, + out_view, + raft::compose_op{raft::mul_const_op{scaling_factor}, + raft::sqrt_op{}, + raft::cast_op{}}, + raft::make_const_mdspan(in_view)); } else if (needs_cast) { - raft::linalg::unaryOp( - out, in, len, raft::compose_op{raft::sqrt_op{}, raft::cast_op{}}, stream); + raft::linalg::map(res, + out_view, + raft::compose_op{raft::sqrt_op{}, raft::cast_op{}}, + raft::make_const_mdspan(in_view)); } else { - raft::linalg::unaryOp(out, in, len, raft::sqrt_op{}, stream); + raft::linalg::map(res, out_view, raft::sqrt_op{}, raft::make_const_mdspan(in_view)); } } break; case distance::DistanceType::CosineExpanded: case distance::DistanceType::InnerProduct: { float factor = (account_for_max_close ? -1.0 : 1.0) * scaling_factor * scaling_factor; if (factor != 1.0) { - raft::linalg::unaryOp( - out, - in, - len, + raft::linalg::map( + res, + out_view, raft::compose_op(raft::mul_const_op{factor}, raft::cast_op{}), - stream); + raft::make_const_mdspan(in_view)); } else if (needs_cast || needs_copy) { - raft::linalg::unaryOp(out, in, len, raft::cast_op{}, stream); + raft::linalg::map( + res, out_view, raft::cast_op{}, raft::make_const_mdspan(in_view)); } } break; case distance::DistanceType::BitwiseHamming: break; @@ -256,15 +263,17 @@ void recompute_internal_state(const raft::resources& res, Index& index) sort_cluster_sizes_descending( index.list_sizes().data_handle(), sorted_sizes.data(), index.n_lists(), stream, tmp_res); // copy the results to CPU - std::vector sorted_sizes_host(index.n_lists()); - raft::copy(sorted_sizes_host.data(), sorted_sizes.data(), index.n_lists(), stream); + auto sorted_sizes_host = raft::make_host_vector(index.n_lists()); + raft::copy(res, + sorted_sizes_host.view(), + raft::make_device_vector_view(sorted_sizes.data(), index.n_lists())); raft::resource::sync_stream(res); // accumulate the sorted cluster sizes auto accum_sorted_sizes = index.accum_sorted_sizes(); accum_sorted_sizes(0) = 0; - for (uint32_t label = 0; label < sorted_sizes_host.size(); label++) { - accum_sorted_sizes(label + 1) = accum_sorted_sizes(label) + sorted_sizes_host[label]; + for (uint32_t label = 0; label < sorted_sizes_host.extent(0); label++) { + accum_sorted_sizes(label + 1) = accum_sorted_sizes(label) + sorted_sizes_host(label); } } diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh index c191773b58..06862c083d 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -17,6 +17,9 @@ #include "../detail/ann_utils.cuh" #include #include +#include +#include +#include #include #include #include @@ -27,6 +30,7 @@ #include #include #include +#include #include #include @@ -53,20 +57,11 @@ auto clone(const raft::resources& res, const index& source) -> indexdata_handle(), - source.center_norms()->data_handle(), - source.center_norms()->size(), - stream); + raft::copy(res, target.center_norms().value(), source.center_norms().value()); } // Copy shared pointers target.lists() = source.lists(); @@ -230,7 +225,9 @@ void extend(raft::resources const& handle, auto* list_sizes_ptr = index->list_sizes().data_handle(); auto old_list_sizes_dev = raft::make_device_mdarray( handle, raft::resource::get_workspace_resource(handle), raft::make_extents(n_lists)); - raft::copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream); + raft::copy(handle, + old_list_sizes_dev.view(), + raft::make_device_vector_view(list_sizes_ptr, n_lists)); // Calculate the centers and sizes on the new data, starting from the original values if (index->adaptive_centers()) { @@ -260,16 +257,23 @@ void extend(raft::resources const& handle, n_rows, 1, stream); - raft::linalg::add( - list_sizes_ptr, list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); + raft::linalg::add(handle, + raft::make_device_vector_view(list_sizes_ptr, n_lists), + raft::make_device_vector_view( + old_list_sizes_dev.data_handle(), n_lists), + raft::make_device_vector_view(list_sizes_ptr, n_lists)); } // Calculate and allocate new list data std::vector new_list_sizes(n_lists); std::vector old_list_sizes(n_lists); { - raft::copy(old_list_sizes.data(), old_list_sizes_dev.data_handle(), n_lists, stream); - raft::copy(new_list_sizes.data(), list_sizes_ptr, n_lists, stream); + raft::copy(handle, + raft::make_host_vector_view(old_list_sizes.data(), n_lists), + raft::make_device_vector_view(old_list_sizes_dev.data_handle(), n_lists)); + raft::copy(handle, + raft::make_host_vector_view(new_list_sizes.data(), n_lists), + raft::make_device_vector_view(list_sizes_ptr, n_lists)); raft::resource::sync_stream(handle); auto& lists = index->lists(); for (uint32_t label = 0; label < n_lists; label++) { @@ -284,7 +288,10 @@ void extend(raft::resources const& handle, ivf::detail::recompute_internal_state(handle, *index); // Copy the old sizes, so we can start from the current state of the index; // we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter. - raft::copy(list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); + raft::copy( + handle, + raft::make_device_vector_view(list_sizes_ptr, n_lists), + raft::make_device_vector_view(old_list_sizes_dev.data_handle(), n_lists)); utils::batch_load_iterator vec_indices( new_indices, n_rows, 1, max_batch_size, stream, raft::resource::get_workspace_resource(handle)); @@ -329,33 +336,30 @@ void extend(raft::resources const& handle, if (!index->center_norms().has_value()) { index->allocate_center_norms(handle); if (index->center_norms().has_value()) { + auto centers_view = raft::make_device_matrix_view( + index->centers().data_handle(), n_lists, dim); + auto norms_view = raft::make_device_vector_view( + index->center_norms()->data_handle(), n_lists); if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::rowNorm(index->center_norms()->data_handle(), - index->centers().data_handle(), - dim, - n_lists, - stream, - raft::sqrt_op{}); + raft::linalg::norm( + handle, centers_view, norms_view, raft::sqrt_op{}); } else { - raft::linalg::rowNorm(index->center_norms()->data_handle(), - index->centers().data_handle(), - dim, - n_lists, - stream); + raft::linalg::norm( + handle, centers_view, norms_view); } RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); } } else if (index->center_norms().has_value() && index->adaptive_centers()) { + auto centers_view = raft::make_device_matrix_view( + index->centers().data_handle(), n_lists, dim); + auto norms_view = + raft::make_device_vector_view(index->center_norms()->data_handle(), n_lists); if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::rowNorm(index->center_norms()->data_handle(), - index->centers().data_handle(), - dim, - n_lists, - stream, - raft::sqrt_op{}); + raft::linalg::norm( + handle, centers_view, norms_view, raft::sqrt_op{}); } else { - raft::linalg::rowNorm( - index->center_norms()->data_handle(), index->centers().data_handle(), dim, n_lists, stream); + raft::linalg::norm( + handle, centers_view, norms_view); } RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); } @@ -485,7 +489,7 @@ inline void fill_refinement_index(raft::resources const& handle, // Update the pointers and the sizes ivf::detail::recompute_internal_state(handle, *refinement_index); - RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream)); + raft::matrix::fill(handle, raft::make_device_vector_view(list_sizes_ptr, n_lists), uint32_t(0)); const dim3 block_dim(256); const dim3 grid_dim(raft::ceildiv(n_queries * n_candidates, block_dim.x)); diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh index 03de3eb791..3379e7b8dc 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -18,10 +18,10 @@ #include #include #include -#include // raft::resources -#include // raft::linalg::gemm -#include // raft::linalg::norm -#include // raft::linalg::unary_op +#include // raft::resources +#include // raft::linalg::gemm +#include // raft::linalg::map +#include // raft::linalg::norm #include #include @@ -90,8 +90,12 @@ void search_impl(raft::resources const& handle, if constexpr (std::is_same_v) { converted_queries_ptr = const_cast(queries); } else { - raft::linalg::unaryOp( - converted_queries_ptr, queries, n_queries * index.dim(), utils::mapping{}, stream); + raft::linalg::map( + handle, + raft::make_device_vector_view(converted_queries_ptr, n_queries * index.dim()), + utils::mapping{}, + raft::make_const_mdspan( + raft::make_device_vector_view(queries, n_queries * index.dim()))); } float alpha = 1.0f; @@ -103,11 +107,12 @@ void search_impl(raft::resources const& handle, case cuvs::distance::DistanceType::L2SqrtExpanded: { alpha = -2.0f; beta = 1.0f; - raft::linalg::rowNorm(query_norm_dev.data(), - converted_queries_ptr, - static_cast(index.dim()), - static_cast(n_queries), - stream); + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + converted_queries_ptr, static_cast(n_queries), static_cast(index.dim())), + raft::make_device_vector_view(query_norm_dev.data(), + static_cast(n_queries))); utils::outer_add(query_norm_dev.data(), (IdxT)n_queries, index.center_norms()->data_handle(), @@ -119,12 +124,13 @@ void search_impl(raft::resources const& handle, break; } case cuvs::distance::DistanceType::CosineExpanded: { - raft::linalg::rowNorm(query_norm_dev.data(), - converted_queries_ptr, - static_cast(index.dim()), - static_cast(n_queries), - stream, - raft::sqrt_op{}); + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + converted_queries_ptr, static_cast(n_queries), static_cast(index.dim())), + raft::make_device_vector_view(query_norm_dev.data(), + static_cast(n_queries)), + raft::sqrt_op{}); alpha = -1.0f; beta = 0.0f; break; @@ -283,7 +289,7 @@ void search_impl(raft::resources const& handle, if (!manage_local_topk) { // post process distances && neighbor IDs ivf::detail::postprocess_distances( - distances, distances, index.metric(), n_queries, k, 1.0, false, stream); + handle, distances, distances, index.metric(), n_queries, k, 1.0, false); } ivf::detail::postprocess_neighbors(neighbors, neighbors_uint32, diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize.cuh index 9ce3a85d8c..e29d1d9589 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -64,10 +65,7 @@ void serialize(raft::resources const& handle, std::ostream& os, const index(index_.list_sizes().extent(0)); - raft::copy(sizes_host.data_handle(), - index_.list_sizes().data_handle(), - sizes_host.size(), - raft::resource::get_cuda_stream(handle)); + raft::copy(handle, sizes_host.view(), index_.list_sizes()); raft::resource::sync_stream(handle); serialize_mdspan(handle, os, sizes_host.view()); diff --git a/cpp/src/neighbors/ivf_list.cuh b/cpp/src/neighbors/ivf_list.cuh index a2dd54623c..d8704c7b89 100644 --- a/cpp/src/neighbors/ivf_list.cuh +++ b/cpp/src/neighbors/ivf_list.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,20 +8,22 @@ #include #include +#include #include +#include #include #include +#include #include #include #include #include #include +#include #include #include "ivf_common.cuh" -#include - #include #include #include @@ -56,10 +58,7 @@ list::list(raft::resources const& res, e.what()); } // Fill the index buffer with a pre-defined marker for easier debugging - thrust::fill_n(raft::resource::get_thrust_policy(res), - indices.data_handle(), - indices.size(), - ivf::kInvalidRecord); + raft::matrix::fill(res, indices.view(), ivf::kInvalidRecord); } template @@ -94,14 +93,12 @@ void resize_list(raft::resources const& res, raft::row_major, false, true>(new_list->data.data_handle(), copied_data_extents); - raft::copy(copied_view.data_handle(), - orig_list->data.data_handle(), - copied_view.size(), - raft::resource::get_cuda_stream(res)); - raft::copy(new_list->indices.data_handle(), - orig_list->indices.data_handle(), - old_used_size, - raft::resource::get_cuda_stream(res)); + raft::copy(res, + raft::make_device_vector_view(copied_view.data_handle(), copied_view.size()), + raft::make_device_vector_view(orig_list->data.data_handle(), copied_view.size())); + raft::copy(res, + raft::make_device_vector_view(new_list->indices.data_handle(), old_used_size), + raft::make_device_vector_view(orig_list->indices.data_handle(), old_used_size)); } // swap the shared pointer content with the new list new_list.swap(orig_list); @@ -124,14 +121,12 @@ enable_if_valid_list_t serialize_list(const raft::resources& handle, raft::make_host_mdarray(data_extents); auto inds_array = raft::make_host_mdarray( raft::make_extents(size)); - raft::copy(data_array.data_handle(), - ld.data.data_handle(), - data_array.size(), - raft::resource::get_cuda_stream(handle)); - raft::copy(inds_array.data_handle(), - ld.indices.data_handle(), - inds_array.size(), - raft::resource::get_cuda_stream(handle)); + raft::copy(handle, + raft::make_host_vector_view(data_array.data_handle(), data_array.size()), + raft::make_device_vector_view(ld.data.data_handle(), data_array.size())); + raft::copy(handle, + raft::make_host_vector_view(inds_array.data_handle(), inds_array.size()), + raft::make_device_vector_view(ld.indices.data_handle(), inds_array.size())); raft::resource::sync_stream(handle); raft::serialize_mdspan(handle, os, data_array.view()); raft::serialize_mdspan(handle, os, inds_array.view()); @@ -169,15 +164,13 @@ enable_if_valid_list_t deserialize_list(const raft::resources& handle, raft::make_extents(size)); raft::deserialize_mdspan(handle, is, data_array.view()); raft::deserialize_mdspan(handle, is, inds_array.view()); - raft::copy(ld->data.data_handle(), - data_array.data_handle(), - data_array.size(), - raft::resource::get_cuda_stream(handle)); + raft::copy(handle, + raft::make_device_vector_view(ld->data.data_handle(), data_array.size()), + raft::make_host_vector_view(data_array.data_handle(), data_array.size())); // NB: copying exactly 'size' indices to leave the rest 'kInvalidRecord' intact. - raft::copy(ld->indices.data_handle(), - inds_array.data_handle(), - size, - raft::resource::get_cuda_stream(handle)); + raft::copy(handle, + raft::make_device_vector_view(ld->indices.data_handle(), size), + raft::make_host_vector_view(inds_array.data_handle(), size)); // Make sure the data is copied from host to device before the host arrays get out of the scope. raft::resource::sync_stream(handle); } diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index 51bf9f2423..b2da2bb821 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -38,7 +38,6 @@ #include #include #include -#include #include #include #include @@ -149,28 +148,34 @@ void flat_compute_residuals( auto tmp_view = raft::make_device_vector_view(tmp.data(), tmp.size()); if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::map(handle, - tmp_view, - raft::cast_op{}, - raft::make_device_vector_view(dataset, n_rows * dim)); + raft::linalg::map( + handle, + tmp_view, + raft::cast_op{}, + raft::make_const_mdspan(raft::make_device_vector_view(dataset, n_rows * dim))); auto tmp_matrix_view = raft::make_device_matrix_view(tmp.data(), n_rows, dim); raft::linalg::row_normalize( handle, raft::make_const_mdspan(tmp_matrix_view), tmp_matrix_view); } else { - raft::linalg::map_offset(handle, tmp_view, [dataset, dim] __device__(size_t i) { - return utils::mapping{}(dataset[i]); - }); + raft::linalg::map_offset( + handle, + tmp_view, + [dim] __device__(size_t i, T val) { return utils::mapping{}(val); }, + raft::make_const_mdspan(raft::make_device_vector_view(dataset, n_rows * dim))); } raft::linalg::map_offset( - handle, tmp_view, [centers, tmp = tmp.data(), labels, dim] __device__(size_t i) { + handle, + tmp_view, + [centers, labels, dim] __device__(size_t i, float val) { auto row_ix = i / dim; auto el_ix = i % dim; auto label = std::holds_alternative(labels) ? std::get(labels) : std::get(labels)[row_ix]; - return tmp[i] - centers(label, el_ix); - }); + return val - centers(label, el_ix); + }, + raft::make_const_mdspan(tmp_view)); float alpha = 1.0f; float beta = 0.0f; @@ -268,8 +273,10 @@ inline void pad_centers_with_norms(raft::resources const& res, stream)); rmm::device_uvector center_norms(n_lists, stream); - raft::linalg::rowNorm( - center_norms.data(), centers, dim, n_lists, stream); + raft::linalg::norm( + res, + raft::make_device_matrix_view(centers, n_lists, dim), + raft::make_device_vector_view(center_norms.data(), n_lists)); RAFT_CUDA_TRY(cudaMemcpy2DAsync(padded_centers + dim, sizeof(float) * dim_ext, center_norms.data(), @@ -837,12 +844,14 @@ auto extend_list_prepare( uint32_t n_rows = new_indices.extent(0); uint32_t offset; // Allocate the lists to fit the new data - raft::copy( - &offset, index->list_sizes().data_handle() + label, 1, raft::resource::get_cuda_stream(res)); + raft::copy(res, + raft::make_host_scalar_view(&offset), + raft::make_device_scalar_view(index->list_sizes().data_handle() + label)); raft::resource::sync_stream(res); uint32_t new_size = offset + n_rows; - raft::copy( - index->list_sizes().data_handle() + label, &new_size, 1, raft::resource::get_cuda_stream(res)); + raft::copy(res, + raft::make_device_scalar_view(index->list_sizes().data_handle() + label), + raft::make_host_scalar_view(&new_size)); auto& list_data_base_ptr = index->lists()[label]; if (index->codes_layout() == list_layout::FLAT) { auto spec = list_spec_flat{ @@ -853,10 +862,10 @@ auto extend_list_prepare( index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()}; cuvs::neighbors::ivf_pq::helpers::resize_list(res, list_data_base_ptr, spec, new_size, offset); } - raft::copy(list_data_base_ptr->indices_ptr() + offset, - new_indices.data_handle(), - n_rows, - raft::resource::get_cuda_stream(res)); + raft::copy(res, + raft::make_device_vector_view( + list_data_base_ptr->indices_ptr() + offset, n_rows), + new_indices); return offset; } @@ -927,8 +936,9 @@ template void erase_list(raft::resources const& res, index* index, uint32_t label) { uint32_t zero = 0; - raft::copy( - index->list_sizes().data_handle() + label, &zero, 1, raft::resource::get_cuda_stream(res)); + raft::copy(res, + raft::make_device_scalar_view(index->list_sizes().data_handle() + label), + raft::make_host_scalar_view(&zero)); index->lists()[label].reset(); ivf::detail::recompute_internal_state(res, *index); } @@ -950,25 +960,12 @@ auto clone(const raft::resources& res, const index& source) -> index source.conservative_memory_allocation(), source.codes_layout()); - // raft::copy the independent parts using mutable accessors - raft::copy(impl->list_sizes().data_handle(), - source.list_sizes().data_handle(), - source.list_sizes().size(), - stream); - raft::copy(impl->rotation_matrix().data_handle(), - source.rotation_matrix().data_handle(), - source.rotation_matrix().size(), - stream); - raft::copy(impl->pq_centers().data_handle(), - source.pq_centers().data_handle(), - source.pq_centers().size(), - stream); - raft::copy( - impl->centers().data_handle(), source.centers().data_handle(), source.centers().size(), stream); - raft::copy(impl->centers_rot().data_handle(), - source.centers_rot().data_handle(), - source.centers_rot().size(), - stream); + // Copy the independent parts using mutable accessors + raft::copy(res, impl->list_sizes(), source.list_sizes()); + raft::copy(res, impl->rotation_matrix(), source.rotation_matrix()); + raft::copy(res, impl->pq_centers(), source.pq_centers()); + raft::copy(res, impl->centers(), source.centers()); + raft::copy(res, impl->centers_rot(), source.centers_rot()); // raft::copy shared pointers impl->lists() = source.lists(); @@ -1135,7 +1132,9 @@ void extend(raft::resources const& handle, auto list_sizes = index->list_sizes().data_handle(); // store the current cluster sizes, because we'll need them later rmm::device_uvector orig_list_sizes(n_clusters, stream, device_memory); - raft::copy(orig_list_sizes.data(), list_sizes, n_clusters, stream); + raft::copy(handle, + raft::make_device_vector_view(orig_list_sizes.data(), n_clusters), + raft::make_device_vector_view(list_sizes, n_clusters)); // Get the combined cluster sizes raft::stats::histogram(raft::stats::HistTypeAuto, @@ -1145,14 +1144,22 @@ void extend(raft::resources const& handle, n_rows, 1, stream); - raft::linalg::add(list_sizes, list_sizes, orig_list_sizes.data(), n_clusters, stream); + raft::linalg::add( + handle, + raft::make_device_vector_view(list_sizes, n_clusters), + raft::make_device_vector_view(orig_list_sizes.data(), n_clusters), + raft::make_device_vector_view(list_sizes, n_clusters)); // Allocate the lists to fit the new data { std::vector new_cluster_sizes(n_clusters); std::vector old_cluster_sizes(n_clusters); - raft::copy(new_cluster_sizes.data(), list_sizes, n_clusters, stream); - raft::copy(old_cluster_sizes.data(), orig_list_sizes.data(), n_clusters, stream); + raft::copy(handle, + raft::make_host_vector_view(new_cluster_sizes.data(), n_clusters), + raft::make_device_vector_view(list_sizes, n_clusters)); + raft::copy(handle, + raft::make_host_vector_view(old_cluster_sizes.data(), n_clusters), + raft::make_device_vector_view(orig_list_sizes.data(), n_clusters)); raft::resource::sync_stream(handle); if (index->codes_layout() == list_layout::FLAT) { auto spec = list_spec_flat{ @@ -1175,7 +1182,9 @@ void extend(raft::resources const& handle, ivf::detail::recompute_internal_state(handle, *index); // Recover old cluster sizes: they are used as counters in the fill-codes kernel - raft::copy(list_sizes, orig_list_sizes.data(), n_clusters, stream); + raft::copy(handle, + raft::make_device_vector_view(list_sizes, n_clusters), + raft::make_device_vector_view(orig_list_sizes.data(), n_clusters)); // By this point, the index state is updated and valid except it doesn't contain the new data // Fill the extended index with the new data (possibly, in batches) @@ -1300,11 +1309,12 @@ auto build(raft::resources const& handle, raft::matrix::sample_rows(handle, random_state, dataset, trainset_tmp.view()); - raft::linalg::unaryOp(trainset.data_handle(), - trainset_tmp.data_handle(), - trainset.size(), - utils::mapping{}, - raft::resource::get_cuda_stream(handle)); + raft::linalg::map(handle, + raft::make_device_vector_view(trainset.data_handle(), + (int64_t)trainset.size()), + utils::mapping{}, + raft::make_const_mdspan(raft::make_device_vector_view( + trainset_tmp.data_handle(), (int64_t)trainset.size()))); } // NB: here cluster_centers is used as if it is [n_clusters, data_dim] not [n_clusters, @@ -1571,10 +1581,7 @@ auto build( impl->n_lists()); if (centers.extent(1) == impl->dim_ext()) { - raft::copy(impl->centers().data_handle(), - centers.data_handle(), - impl->centers().extent(0) * impl->centers().extent(1), - stream); + raft::copy(handle, impl->centers(), centers); } else { cuvs::neighbors::ivf_pq::helpers::pad_centers_with_norms(handle, centers, impl->centers()); } @@ -1587,10 +1594,7 @@ auto build( dim, rotation_matrix.value().extent(0), rotation_matrix.value().extent(1)); - raft::copy(impl->rotation_matrix().data_handle(), - rotation_matrix.value().data_handle(), - rotation_matrix.value().size(), - stream); + raft::copy(handle, impl->rotation_matrix(), rotation_matrix.value()); } else { helpers::make_rotation_matrix( handle, impl->rotation_matrix(), index_params.force_random_rotation); @@ -1604,10 +1608,7 @@ auto build( impl->rot_dim(), centers_rot.value().extent(0), centers_rot.value().extent(1)); - raft::copy(impl->centers_rot().data_handle(), - centers_rot.value().data_handle(), - centers_rot.value().size(), - stream); + raft::copy(handle, impl->centers_rot(), centers_rot.value()); } else { cuvs::neighbors::ivf_pq::helpers::rotate_padded_centers( handle, impl->centers(), impl->rotation_matrix(), impl->centers_rot()); @@ -1623,7 +1624,7 @@ auto build( pq_centers.extent(0), pq_centers.extent(1), pq_centers.extent(2)); - raft::copy(impl->pq_centers().data_handle(), pq_centers.data_handle(), pq_centers.size(), stream); + raft::copy(handle, impl->pq_centers(), pq_centers); // Wrap the impl in an index and return return index(std::move(impl)); diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh index 095d706569..cfbdf2c321 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh @@ -28,7 +28,6 @@ #include #include #include -#include #include #include #include @@ -628,14 +627,14 @@ void ivfpq_search_worker(raft::resources const& handle, num_samples_vector); // Postprocessing - ivf::detail::postprocess_distances(distances, + ivf::detail::postprocess_distances(handle, + distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, - index.metric() != distance::DistanceType::CosineExpanded, - stream); + index.metric() != distance::DistanceType::CosineExpanded); ivf::detail::postprocess_neighbors(neighbors, neighbors_uint32, index.inds_ptrs().data_handle(), diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh index dbb4095dce..7a159e9797 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh @@ -10,6 +10,7 @@ #include "../ivf_pq_impl.hpp" #include #include +#include #include #include #include @@ -63,10 +64,7 @@ void serialize(raft::resources const& handle_, std::ostream& os, const index(index.list_sizes().extents()); - raft::copy(sizes_host.data_handle(), - index.list_sizes().data_handle(), - sizes_host.size(), - raft::resource::get_cuda_stream(handle_)); + raft::copy(handle_, sizes_host.view(), index.list_sizes()); raft::resource::sync_stream(handle_); raft::serialize_mdspan(handle_, os, sizes_host.view()); // NOTE: We use static_cast here because serialize_list requires the concrete list type diff --git a/cpp/src/neighbors/mg/snmg.cuh b/cpp/src/neighbors/mg/snmg.cuh index c39f3ff3ba..d2e98f1c1a 100644 --- a/cpp/src/neighbors/mg/snmg.cuh +++ b/cpp/src/neighbors/mg/snmg.cuh @@ -5,10 +5,14 @@ #pragma once +#include +#include +#include #include #include #include #include +#include #include #include "../../core/omp_wrapper.hpp" @@ -346,10 +350,9 @@ void sharded_search_with_direct_merge( } auto d_trans = raft::make_device_vector(root_handle_, index.num_ranks_); - raft::copy(d_trans.data_handle(), - h_trans.data(), - index.num_ranks_, - raft::resource::get_cuda_stream(root_handle_)); + raft::copy(root_handle_, + d_trans.view(), + raft::make_host_vector_view(h_trans.data(), index.num_ranks_)); knn_merge_parts(root_handle_, in_distances.view(), @@ -358,14 +361,13 @@ void sharded_search_with_direct_merge( out_neighbors.view(), d_trans.view()); - raft::copy(neighbors.data_handle() + output_offset, - out_neighbors.data_handle(), - part_size, - raft::resource::get_cuda_stream(root_handle_)); - raft::copy(distances.data_handle() + output_offset, - out_distances.data_handle(), - part_size, - raft::resource::get_cuda_stream(root_handle_)); + raft::copy( + root_handle_, + raft::make_host_vector_view(neighbors.data_handle() + output_offset, part_size), + raft::make_device_vector_view(out_neighbors.data_handle(), part_size)); + raft::copy(root_handle_, + raft::make_host_vector_view(distances.data_handle() + output_offset, part_size), + raft::make_device_vector_view(out_distances.data_handle(), part_size)); resource::sync_stream(root_handle_); } @@ -425,8 +427,7 @@ void sharded_search_with_tree_merge( raft::resource::get_cuda_stream(dev_res)); auto d_trans = raft::make_device_vector(dev_res, 2); - cudaMemsetAsync( - d_trans.data_handle(), 0, 2 * sizeof(searchIdxT), raft::resource::get_cuda_stream(dev_res)); + raft::matrix::fill(dev_res, d_trans.view(), searchIdxT(0)); int64_t remaining = index.num_ranks_; int64_t radix = 2; @@ -486,25 +487,26 @@ void sharded_search_with_tree_merge( distances_merge_res.view(), neighbors_merge_res.view(), d_trans.view()); - raft::copy(tmp_neighbors.data_handle(), - neighbors_merge_res.data_handle(), - part_size, - raft::resource::get_cuda_stream(dev_res)); - raft::copy(tmp_distances.data_handle(), - distances_merge_res.data_handle(), - part_size, - raft::resource::get_cuda_stream(dev_res)); + raft::copy(dev_res, + raft::make_device_vector_view(tmp_neighbors.data_handle(), part_size), + raft::make_device_vector_view( + neighbors_merge_res.data_handle(), part_size)); + raft::copy(dev_res, + raft::make_device_vector_view(tmp_distances.data_handle(), part_size), + raft::make_device_vector_view(distances_merge_res.data_handle(), + part_size)); // If done, copy the final result if (remaining <= 1) { - raft::copy(neighbors.data_handle() + output_offset, - tmp_neighbors.data_handle(), - part_size, - raft::resource::get_cuda_stream(dev_res)); - raft::copy(distances.data_handle() + output_offset, - tmp_distances.data_handle(), - part_size, - raft::resource::get_cuda_stream(dev_res)); + raft::copy( + dev_res, + raft::make_host_vector_view(neighbors.data_handle() + output_offset, part_size), + raft::make_device_vector_view(tmp_neighbors.data_handle(), + part_size)); + raft::copy( + dev_res, + raft::make_host_vector_view(distances.data_handle() + output_offset, part_size), + raft::make_device_vector_view(tmp_distances.data_handle(), part_size)); resource::sync_stream(dev_res); } } @@ -540,14 +542,16 @@ void run_search_batch(const raft::resources& clique, cuvs::neighbors::search( dev_res, ann_if, search_params, query_partition, d_neighbors.view(), d_distances.view()); - raft::copy(neighbors.data_handle() + output_offset, - d_neighbors.data_handle(), - n_rows_of_current_batch * n_neighbors, - raft::resource::get_cuda_stream(dev_res)); - raft::copy(distances.data_handle() + output_offset, - d_distances.data_handle(), - n_rows_of_current_batch * n_neighbors, - raft::resource::get_cuda_stream(dev_res)); + raft::copy(dev_res, + raft::make_host_vector_view(neighbors.data_handle() + output_offset, + n_rows_of_current_batch * n_neighbors), + raft::make_device_vector_view( + d_neighbors.data_handle(), n_rows_of_current_batch * n_neighbors)); + raft::copy(dev_res, + raft::make_host_vector_view(distances.data_handle() + output_offset, + n_rows_of_current_batch * n_neighbors), + raft::make_device_vector_view(d_distances.data_handle(), + n_rows_of_current_batch * n_neighbors)); resource::sync_stream(dev_res); } diff --git a/cpp/src/neighbors/refine/refine_device.cuh b/cpp/src/neighbors/refine/refine_device.cuh index b81c0b2b2a..10873891a6 100644 --- a/cpp/src/neighbors/refine/refine_device.cuh +++ b/cpp/src/neighbors/refine/refine_device.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -15,11 +15,10 @@ #include #include #include -#include #include +#include #include - -#include +#include namespace cuvs::neighbors { @@ -66,9 +65,9 @@ void refine_device( // - We run IVF flat search with n_probes=1 to select the best k elements of the candidates. rmm::device_uvector fake_coarse_idx(n_queries, raft::resource::get_cuda_stream(handle)); - thrust::sequence(raft::resource::get_thrust_policy(handle), - fake_coarse_idx.data(), - fake_coarse_idx.data() + n_queries); + raft::linalg::map_offset(handle, + raft::make_device_vector_view(fake_coarse_idx.data(), n_queries), + raft::cast_op{}); cuvs::neighbors::ivf_flat::index refinement_index( handle, metric, n_queries, false, true, dim); @@ -88,10 +87,8 @@ void refine_device( rmm::device_uvector chunk_index(n_queries, raft::resource::get_cuda_stream(handle)); // we know that each cluster has exactly n_candidates entries - thrust::fill(raft::resource::get_thrust_policy(handle), - chunk_index.data(), - chunk_index.data() + n_queries, - uint32_t(n_candidates)); + raft::matrix::fill( + handle, raft::make_device_vector_view(chunk_index.data(), n_queries), uint32_t(n_candidates)); uint32_t* neighbors_uint32 = nullptr; if constexpr (sizeof(idx_t) == sizeof(uint32_t)) { diff --git a/cpp/src/neighbors/scann/detail/scann_avq.cuh b/cpp/src/neighbors/scann/detail/scann_avq.cuh index 6c4797918e..e7c1663f3e 100644 --- a/cpp/src/neighbors/scann/detail/scann_avq.cuh +++ b/cpp/src/neighbors/scann/detail/scann_avq.cuh @@ -251,16 +251,11 @@ void compute_avq_centroid(raft::resources const& dev_resources, x_eta_1.view(), raft::mul_op()); - raft::linalg::reduce(avq_centroid.data_handle(), - raft::make_const_mdspan(x_eta_1.view()).data_handle(), - x_eta_1.extent(1), - x_eta_1.extent(0), - 0.0f, - raft::resource::get_cuda_stream(dev_resources), - false, - raft::identity_op(), - raft::add_op(), - raft::identity_op()); + raft::linalg::reduce( + dev_resources, + raft::make_const_mdspan(x_eta_1.view()), + raft::make_device_vector_view(avq_centroid.data_handle(), x_eta_1.extent(1)), + 0.0f); // scale x // skipping zero elements in the vector should be ok, since they are norms @@ -313,16 +308,8 @@ void compute_avq_centroid(raft::resources const& dev_resources, auto dots = raft::make_device_vector(dev_resources, x.extent(1)); - raft::linalg::reduce(dots.data_handle(), - raft::make_const_mdspan(x).data_handle(), - x.extent(1), - x.extent(0), - 0.0f, - raft::resource::get_cuda_stream(dev_resources), - false, - raft::identity_op(), - raft::add_op(), - raft::identity_op()); + raft::linalg::reduce( + dev_resources, raft::make_const_mdspan(x), dots.view(), 0.0f); raft::linalg::dot(dev_resources, raft::make_const_mdspan(dots.view()), @@ -348,31 +335,33 @@ void rescale_avq_centroids(raft::resources const& dev_resources, sum_reduce_vector(dev_resources, rescale_num_v, rescale_num.view()); - raft::linalg::map_offset(dev_resources, - raft::make_const_mdspan(rescale_denom_v), - rescale_denom_v, - [cluster_sizes, dataset_size] __device__(size_t i, float x) { - uint32_t cluster_size = i + 1 < cluster_sizes.extent(0) - ? cluster_sizes[i + 1] - cluster_sizes[i] - : dataset_size - cluster_sizes[i]; + raft::linalg::map_offset( + dev_resources, + rescale_denom_v, + [cluster_sizes, dataset_size] __device__(size_t i, float x) { + uint32_t cluster_size = i + 1 < cluster_sizes.extent(0) + ? cluster_sizes[i + 1] - cluster_sizes[i] + : dataset_size - cluster_sizes[i]; - return x * cluster_size; - }); + return x * cluster_size; + }, + raft::make_const_mdspan(rescale_denom_v)); sum_reduce_vector(dev_resources, rescale_denom_v, rescale_denom.view()); auto rescale_num_ptr = rescale_num.data_handle(); auto rescale_denom_ptr = rescale_denom.data_handle(); - raft::linalg::map_offset(dev_resources, - raft::make_const_mdspan(centroids), - centroids, - [rescale_num_ptr, rescale_denom_ptr] __device__(size_t i, float x) { - // should probably check the denominator is nonzero - float rescale = (*rescale_num_ptr) / (*rescale_denom_ptr); - - return x * rescale; - }); + raft::linalg::map_offset( + dev_resources, + centroids, + [rescale_num_ptr, rescale_denom_ptr] __device__(size_t i, float x) { + // should probably check the denominator is nonzero + float rescale = (*rescale_num_ptr) / (*rescale_denom_ptr); + + return x * rescale; + }, + raft::make_const_mdspan(centroids)); } /** @@ -526,8 +515,7 @@ class cluster_loader { auto h_cluster_ids = raft::make_pinned_vector_view(cluster_ids_buf_.data_handle(), size); - raft::copy( - h_cluster_ids.data_handle(), cluster_ids.data_handle(), cluster_ids.size(), stream_); + raft::copy(res, h_cluster_ids, cluster_ids); raft::resource::sync_stream(res, stream_); auto pinned_cluster = raft::make_pinned_matrix_view( @@ -541,10 +529,7 @@ class cluster_loader { sizeof(T) * dim_); } - raft::copy(cluster_vectors.data_handle(), - pinned_cluster.data_handle(), - pinned_cluster.size(), - stream_); + raft::copy(res, cluster_vectors, raft::make_const_mdspan(pinned_cluster)); raft::resource::sync_stream(res, stream_); } else { @@ -600,8 +585,7 @@ void apply_avq(raft::resources const& res, compute_cluster_offsets(res, labels_view, cluster_offsets.view(), max_cluster_size); auto h_cluster_offsets = raft::make_host_vector(cluster_offsets.extent(0)); - raft::copy( - h_cluster_offsets.data_handle(), cluster_offsets.data_handle(), cluster_offsets.size(), stream); + raft::copy(res, h_cluster_offsets.view(), raft::make_const_mdspan(cluster_offsets.view())); dim3 block(32, 1, 1); dim3 grid((dataset.extent(0) + block.x - 1) / block.x, 1, 1); diff --git a/cpp/src/neighbors/scann/detail/scann_build.cuh b/cpp/src/neighbors/scann/detail/scann_build.cuh index 8902fc2051..7805f622d3 100644 --- a/cpp/src/neighbors/scann/detail/scann_build.cuh +++ b/cpp/src/neighbors/scann/detail/scann_build.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -10,10 +10,12 @@ #include #include +#include #include #include #include #include +#include #include #include #include @@ -185,10 +187,12 @@ index build( auto sub_pq_codebook = create_pq_codebook(res, raft::make_const_mdspan(sub_trainset.view()), pq_params); - raft::copy(full_codebook.data_handle() + (subspace * sub_pq_codebook.size()), - sub_pq_codebook.data_handle(), - sub_pq_codebook.size(), - stream); + raft::copy( + res, + raft::make_device_vector_view( + full_codebook.data_handle() + (subspace * sub_pq_codebook.size()), sub_pq_codebook.size()), + raft::make_device_vector_view(sub_pq_codebook.data_handle(), + sub_pq_codebook.size())); } raft::resource::sync_stream(res); @@ -272,21 +276,26 @@ index build( // Copy unpacked codes to host // TODO (rmaschal): these copies are blocking and not overlapped - raft::copy(idx.quantized_residuals().data_handle() + batch.offset() * num_subspaces, - quantized_residuals.data_handle(), - quantized_residuals.size(), - stream); - - raft::copy(idx.quantized_soar_residuals().data_handle() + batch.offset() * num_subspaces, - quantized_soar_residuals.data_handle(), - quantized_soar_residuals.size(), - stream); + raft::copy(res, + raft::make_host_vector_view( + idx.quantized_residuals().data_handle() + batch.offset() * num_subspaces, + quantized_residuals.size()), + raft::make_device_vector_view(quantized_residuals.data_handle(), + quantized_residuals.size())); + + raft::copy(res, + raft::make_host_vector_view( + idx.quantized_soar_residuals().data_handle() + batch.offset() * num_subspaces, + quantized_soar_residuals.size()), + raft::make_device_vector_view(quantized_soar_residuals.data_handle(), + quantized_soar_residuals.size())); if (params.reordering_bf16) { - raft::copy(idx.bf16_dataset().data_handle() + batch.offset() * dim, - bf16_dataset.data_handle(), - bf16_dataset.size(), - stream); + raft::copy(res, + raft::make_host_vector_view( + idx.bf16_dataset().data_handle() + batch.offset() * dim, bf16_dataset.size()), + raft::make_device_vector_view(bf16_dataset.data_handle(), + bf16_dataset.size())); } // Make sure work on device is finished before swapping buffers diff --git a/cpp/src/neighbors/scann/detail/scann_common.cuh b/cpp/src/neighbors/scann/detail/scann_common.cuh index ec9a11e404..0fea279bc3 100644 --- a/cpp/src/neighbors/scann/detail/scann_common.cuh +++ b/cpp/src/neighbors/scann/detail/scann_common.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -7,8 +7,10 @@ #include "../../../core/omp_wrapper.hpp" +#include #include #include +#include #include namespace cuvs::neighbors::experimental::scann::detail { @@ -33,7 +35,7 @@ struct gather_functor { { auto h_cluster_ids = raft::make_host_vector(cluster_ids.extent(0)); - raft::copy(h_cluster_ids.data_handle(), cluster_ids.data_handle(), cluster_ids.size(), stream); + raft::copy(res, h_cluster_ids.view(), raft::make_const_mdspan(cluster_ids)); auto pinned_cluster = raft::make_host_matrix(cluster_vecs.extent(0), cluster_vecs.extent(1)); @@ -51,7 +53,9 @@ struct gather_functor { } raft::copy( - cluster_vecs.data_handle(), pinned_cluster.data_handle(), pinned_cluster.size(), stream); + res, + raft::make_device_vector_view(cluster_vecs.data_handle(), pinned_cluster.size()), + raft::make_host_vector_view(pinned_cluster.data_handle(), pinned_cluster.size())); raft::resource::sync_stream(res, stream); } }; diff --git a/cpp/src/neighbors/scann/detail/scann_quantize.cuh b/cpp/src/neighbors/scann/detail/scann_quantize.cuh index d00abad6d9..16ef1f4295 100644 --- a/cpp/src/neighbors/scann/detail/scann_quantize.cuh +++ b/cpp/src/neighbors/scann/detail/scann_quantize.cuh @@ -8,8 +8,10 @@ #include #include #include +#include #include #include +#include #include "scann_common.cuh" using namespace cuvs::neighbors; @@ -202,10 +204,7 @@ auto quantize_residuals(raft::resources const& res, // vq centers and computed residuals w.r.t those centers auto vq_codebook = raft::make_device_matrix(res, 1, dim); - RAFT_CUDA_TRY(cudaMemsetAsync(vq_codebook.data_handle(), - 0, - vq_codebook.size() * sizeof(T), - raft::resource::get_cuda_stream(res))); + raft::matrix::fill(res, vq_codebook.view(), T(0)); auto codes = process_and_fill_codes_subspaces( res, ps, residuals, raft::make_const_mdspan(vq_codebook.view()), pq_codebook); @@ -495,12 +494,13 @@ void quantize_bfloat16(raft::resources const& res, if (!std::isnan(noise_shaping_threshold)) { quantize_bfloat16_noise_shaped(res, dataset, bf16_dataset, noise_shaping_threshold); } else { - raft::linalg::unaryOp( - bf16_dataset.data_handle(), - dataset.data_handle(), - dataset.size(), + raft::linalg::map( + res, + raft::make_device_vector_view(bf16_dataset.data_handle(), + (int64_t)bf16_dataset.size()), [] __device__(float x) { return float_to_bfloat16(x); }, - resource::get_cuda_stream(res)); + raft::make_const_mdspan(raft::make_device_vector_view( + dataset.data_handle(), (int64_t)dataset.size()))); } } diff --git a/cpp/src/neighbors/scann/detail/scann_soar.cuh b/cpp/src/neighbors/scann/detail/scann_soar.cuh index e3cc662ebe..38bb7b0858 100644 --- a/cpp/src/neighbors/scann/detail/scann_soar.cuh +++ b/cpp/src/neighbors/scann/detail/scann_soar.cuh @@ -90,16 +90,13 @@ void compute_soar_labels(raft::resources const& dev_resources, auto centers_transpose = raft::make_device_matrix(dev_resources, centers.extent(1), centers.extent(0)); - raft::linalg::reduce(centers_norm.data_handle(), - centers.data_handle(), - centers.extent(1), - centers.extent(0), - 0.0f, - raft::resource::get_cuda_stream(dev_resources), - false, - raft::sq_op(), - raft::add_op(), - raft::identity_op()); + raft::linalg::reduce(dev_resources, + raft::make_const_mdspan(centers), + centers_norm.view(), + 0.0f, + false, + raft::sq_op(), + raft::add_op()); raft::linalg::transpose(dev_resources, centers, centers_transpose.view()); @@ -114,12 +111,12 @@ void compute_soar_labels(raft::resources const& dev_resources, raft::sub_op()); raft::linalg::map( - dev_resources, raft::make_const_mdspan(soar_scores.view()), soar_scores.view(), raft::sq_op()); + dev_resources, soar_scores.view(), raft::sq_op{}, raft::make_const_mdspan(soar_scores.view())); raft::linalg::map(dev_resources, - raft::make_const_mdspan(soar_scores.view()), soar_scores.view(), - raft::mul_const_op(lambda)); + raft::mul_const_op(lambda), + raft::make_const_mdspan(soar_scores.view())); auto nc_dataset = raft::make_device_matrix_view( const_cast(dataset.data_handle()), dataset.extent(0), dataset.extent(1)); diff --git a/cpp/src/preprocessing/quantize/detail/binary.cuh b/cpp/src/preprocessing/quantize/detail/binary.cuh index 1ad0183ce7..3c61aba562 100644 --- a/cpp/src/preprocessing/quantize/detail/binary.cuh +++ b/cpp/src/preprocessing/quantize/detail/binary.cuh @@ -1,21 +1,24 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once #include +#include +#include +#include #include +#include #include -#include +#include #include #include #include #include #include #include -#include namespace cuvs::preprocessing::quantize::detail { @@ -145,8 +148,7 @@ void mean_f16_in_f32(raft::resources const& res, auto mr = raft::resource::get_workspace_resource(res); auto f32_result_vec = raft::make_device_mdarray(res, mr, raft::make_extents(dataset_dim)); - RAFT_CUDA_TRY( - cudaMemsetAsync(f32_result_vec.data_handle(), 0, sizeof(float) * dataset_dim, cuda_stream)); + raft::matrix::fill(res, f32_result_vec.view(), float(0)); constexpr uint32_t dataset_size_per_cta = 2048; constexpr uint32_t block_size = 256; @@ -323,18 +325,19 @@ auto train(raft::resources const& res, } if constexpr (std::is_same_v) { - raft::copy(quantizer.threshold.data_handle(), - host_threshold_vec.data(), - dataset_dim, - raft::resource::get_cuda_stream(res)); + raft::copy( + res, + raft::make_device_vector_view(quantizer.threshold.data_handle(), (int64_t)dataset_dim), + raft::make_host_vector_view(host_threshold_vec.data(), + (int64_t)dataset_dim)); } else { auto mr = raft::resource::get_workspace_resource(res); auto casted_vec = raft::make_device_mdarray( res, mr, raft::make_extents(dataset_dim)); - raft::copy(casted_vec.data_handle(), - host_threshold_vec.data(), - dataset_dim, - raft::resource::get_cuda_stream(res)); + raft::copy(res, + casted_vec.view(), + raft::make_host_vector_view(host_threshold_vec.data(), + (int64_t)dataset_dim)); raft::linalg::map(res, quantizer.threshold.view(), raft::cast_op{}, @@ -417,10 +420,10 @@ void transform(raft::resources const& res, threshold_ptr = threshold_vec.data_handle(); if constexpr (std::is_same_v) { - raft::copy(threshold_ptr, - quantizer.threshold.data_handle(), - dataset_dim, - raft::resource::get_cuda_stream(res)); + raft::copy(res, + raft::make_host_vector_view(threshold_ptr, (int64_t)dataset_dim), + raft::make_device_vector_view(quantizer.threshold.data_handle(), + (int64_t)dataset_dim)); } else { auto mr = raft::resource::get_workspace_resource(res); auto casted_vec = raft::make_device_mdarray( @@ -430,7 +433,9 @@ void transform(raft::resources const& res, raft::cast_op{}, raft::make_const_mdspan(quantizer.threshold.view())); raft::copy( - threshold_ptr, casted_vec.data_handle(), dataset_dim, raft::resource::get_cuda_stream(res)); + res, + raft::make_host_vector_view(threshold_ptr, (int64_t)dataset_dim), + raft::make_device_vector_view(casted_vec.data_handle(), (int64_t)dataset_dim)); } // Populate the threshold_ptr on the host side before the host parallel loop. raft::resource::sync_stream(res); diff --git a/cpp/src/preprocessing/quantize/detail/scalar.cuh b/cpp/src/preprocessing/quantize/detail/scalar.cuh index b781f9797c..63a55aaf54 100644 --- a/cpp/src/preprocessing/quantize/detail/scalar.cuh +++ b/cpp/src/preprocessing/quantize/detail/scalar.cuh @@ -1,18 +1,20 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once #include +#include +#include +#include #include -#include +#include #include #include #include #include -#include namespace cuvs::preprocessing::quantize::detail { @@ -93,8 +95,12 @@ std::tuple quantile_min_max( int pos_min = subset_size - pos_max - 1; T minmax_h[2]; - raft::update_host(&(minmax_h[0]), subset.data_handle() + pos_min, 1, stream); - raft::update_host(&(minmax_h[1]), subset.data_handle() + pos_max, 1, stream); + raft::copy(res, + raft::make_host_scalar_view(&minmax_h[0]), + raft::make_device_scalar_view(subset.data_handle() + pos_min)); + raft::copy(res, + raft::make_host_scalar_view(&minmax_h[1]), + raft::make_device_scalar_view(subset.data_handle() + pos_max)); raft::resource::sync_stream(res); return {minmax_h[0], minmax_h[1]}; } @@ -141,7 +147,10 @@ void transform(raft::resources const& res, { cudaStream_t stream = raft::resource::get_cuda_stream(res); - raft::linalg::map(res, out, quantize_op(quantizer.min_, quantizer.max_), dataset); + raft::linalg::map(res, + out, + quantize_op(quantizer.min_, quantizer.max_), + raft::make_const_mdspan(dataset)); } template @@ -167,7 +176,10 @@ void inverse_transform(raft::resources const& res, { cudaStream_t stream = raft::resource::get_cuda_stream(res); - raft::linalg::map(res, out, quantize_op(quantizer.min_, quantizer.max_), dataset); + raft::linalg::map(res, + out, + quantize_op(quantizer.min_, quantizer.max_), + raft::make_const_mdspan(dataset)); } template diff --git a/cpp/src/preprocessing/spectral/detail/spectral_embedding.cuh b/cpp/src/preprocessing/spectral/detail/spectral_embedding.cuh index 169d331366..04c3e83b0a 100644 --- a/cpp/src/preprocessing/spectral/detail/spectral_embedding.cuh +++ b/cpp/src/preprocessing/spectral/detail/spectral_embedding.cuh @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -25,9 +26,6 @@ #include #include -#include -#include - namespace cuvs::preprocessing::spectral_embedding::detail { template @@ -44,10 +42,11 @@ OutSparseMatrixType create_laplacian(raft::resources const& handle, auto laplacian_elements_view = raft::make_device_vector_view( laplacian.get_elements().data(), laplacian.structure_view().get_nnz()); - raft::linalg::unary_op(handle, - raft::make_const_mdspan(laplacian_elements_view), - laplacian_elements_view, - [] __device__(DataT x) { return -x; }); + raft::linalg::map( + handle, + laplacian_elements_view, + [] __device__(DataT x) { return -x; }, + raft::make_const_mdspan(laplacian_elements_view)); return laplacian; } @@ -92,13 +91,9 @@ void compute_eigenpairs(raft::resources const& handle, spectral_embedding_config.drop_first ? config.n_components - 1 : config.n_components; auto col_indices = raft::make_device_vector(handle, config.n_components); - // TODO: https://github.com/rapidsai/raft/issues/2661 - thrust::sequence(thrust::device, - col_indices.data_handle(), - col_indices.data_handle() + config.n_components, - config.n_components - 1, // Start from the last column index - -1 // Decrement (move backward) - ); + raft::linalg::map_offset(handle, + col_indices.view(), + [n = config.n_components] __device__(int idx) { return n - 1 - idx; }); // Create row-major views of the column-major matrices // This is just a view re-interpretation, no data movement @@ -166,15 +161,15 @@ void create_connectivity_graph( auto knn_rows = raft::make_device_vector(handle, nnz); auto knn_cols = raft::make_device_vector(handle, nnz); - raft::linalg::unary_op( - handle, make_const_mdspan(d_indices.view()), knn_cols.view(), [] __device__(int64_t x) { - return static_cast(x); - }); + raft::linalg::map( + handle, + knn_cols.view(), + [] __device__(int64_t x) { return static_cast(x); }, + raft::make_const_mdspan(d_indices.view())); - thrust::tabulate(raft::resource::get_thrust_policy(handle), - knn_rows.data_handle(), - knn_rows.data_handle() + nnz, - [k_search] __device__(NNZType idx) { return idx / k_search; }); + raft::linalg::map_offset(handle, knn_rows.view(), [k_search] __device__(NNZType idx) -> int { + return static_cast(idx / k_search); + }); // set all distances to 1.0f (connectivity KNN graph) raft::matrix::fill( diff --git a/cpp/src/sparse/neighbors/detail/cross_component_nn.cuh b/cpp/src/sparse/neighbors/detail/cross_component_nn.cuh index f90915fd1a..9bfc593559 100644 --- a/cpp/src/sparse/neighbors/detail/cross_component_nn.cuh +++ b/cpp/src/sparse/neighbors/detail/cross_component_nn.cuh @@ -7,8 +7,10 @@ #include "../../../distance/masked_nn.cuh" #include +#include #include #include +#include #include #include #include @@ -324,8 +326,10 @@ void perform_1nn(raft::resources const& handle, colors_group_idxs.data_handle() + 1, n_components); auto x_norm = raft::make_device_vector(handle, (value_idx)n_rows); - raft::linalg::rowNorm( - x_norm.data_handle(), X, n_cols, n_rows, stream); + raft::linalg::norm( + handle, + raft::make_device_matrix_view(X, n_rows, n_cols), + x_norm.view()); auto adj = raft::make_device_matrix(handle, row_batch_size, n_components); using OutT = raft::KeyValuePair; @@ -381,16 +385,17 @@ void perform_1nn(raft::resources const& handle, } // Transform the keys so that they correctly point to the unpermuted indices. - thrust::transform(exec_policy, - kvp, - kvp + n_rows, - kvp, - [sort_plan = sort_plan.data_handle()] __device__(OutT KVP) { - OutT res; - res.value = KVP.value; - res.key = sort_plan[KVP.key]; - return res; - }); + raft::linalg::map( + handle, + raft::make_device_vector_view(kvp, (value_idx)n_rows), + [sort_plan = sort_plan.data_handle()] __device__(OutT KVP) { + OutT res; + res.value = KVP.value; + res.key = sort_plan[KVP.key]; + return res; + }, + raft::make_const_mdspan( + raft::make_device_vector_view(kvp, (value_idx)n_rows))); // Undo permutation of the rows of X by scattering in place. raft::matrix::scatter(handle, X_mutable_view, sort_plan_const_view, (value_idx)col_batch_size); @@ -409,7 +414,12 @@ void perform_1nn(raft::resources const& handle, raft::copy_async(kvp, tmp_kvp.data_handle(), n_rows, stream); LookupColorOp extract_colors_op(colors); - thrust::transform(exec_policy, kvp, kvp + n_rows, nn_colors, extract_colors_op); + raft::linalg::map( + handle, + raft::make_device_vector_view(nn_colors, (value_idx)n_rows), + extract_colors_op, + raft::make_const_mdspan( + raft::make_device_vector_view(kvp, (value_idx)n_rows))); } /** @@ -596,7 +606,9 @@ void cross_component_nn( // compute final size value_idx size_int = 0; - raft::update_host(&size_int, out_index.data() + (out_index.size() - 1), 1, stream); + raft::copy(handle, + raft::make_host_scalar_view(&size_int), + raft::make_device_scalar_view(out_index.data() + (out_index.size() - 1))); raft::resource::sync_stream(handle, stream); nnz_t size = static_cast(size_int); diff --git a/cpp/src/stats/detail/batched/silhouette_score.cuh b/cpp/src/stats/detail/batched/silhouette_score.cuh index 35d943277c..d438baeb66 100644 --- a/cpp/src/stats/detail/batched/silhouette_score.cuh +++ b/cpp/src/stats/detail/batched/silhouette_score.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -7,18 +7,17 @@ #include "../silhouette_score.cuh" +#include #include #include -#include +#include +#include +#include #include #include +#include #include -#include - -#include -#include -#include namespace cuvs { namespace stats { @@ -184,7 +183,6 @@ value_t silhouette_score( rmm::device_uvector cluster_counts = get_cluster_counts(handle, y, n_rows, n_labels); auto stream = raft::resource::get_cuda_stream(handle); - auto policy = raft::resource::get_thrust_policy(handle); auto b_size = n_rows * n_labels; @@ -202,7 +200,8 @@ value_t silhouette_score( a_ptr = scores; } - thrust::fill(policy, a_ptr, a_ptr + n_rows, 0); + raft::matrix::fill( + handle, raft::make_device_vector_view(a_ptr, n_rows), value_t(0)); dim3 block_size(std::min(n_rows, 32), std::min(n_labels, 32)); dim3 grid_size(raft::ceildiv(n_rows, (value_idx)block_size.x), @@ -247,23 +246,33 @@ value_t silhouette_score( raft::resource::sync_stream_pool(handle); // calculating row-wise minimum in b - // this prim only supports int indices for now - raft::linalg::reduce( - b_ptr, - b_ptr, - n_labels, - n_rows, + raft::linalg::reduce( + handle, + raft::make_device_matrix_view( + b_ptr, n_rows, n_labels), + raft::make_device_vector_view(b_ptr, n_rows), std::numeric_limits::max(), - stream, false, raft::identity_op(), raft::min_op()); // calculating the silhouette score per sample - raft::linalg::binaryOp, value_t, value_idx>( - a_ptr, a_ptr, b_ptr, n_rows, cuvs::stats::detail::SilOp(), stream); - - return thrust::reduce(policy, a_ptr, a_ptr + n_rows, value_t(0)) / n_rows; + raft::linalg::map( + handle, + raft::make_device_vector_view(a_ptr, n_rows), + cuvs::stats::detail::SilOp(), + raft::make_const_mdspan(raft::make_device_vector_view(a_ptr, n_rows)), + raft::make_const_mdspan( + raft::make_device_vector_view(b_ptr, n_rows))); + + auto sum = raft::make_device_vector(handle, 1); + raft::linalg::reduce( + handle, + raft::make_device_matrix_view(a_ptr, n_rows, 1), + sum.view(), + value_t(0)); + raft::resource::sync_stream(handle); + return sum(0) / n_rows; } } // namespace detail diff --git a/cpp/src/stats/detail/silhouette_score.cuh b/cpp/src/stats/detail/silhouette_score.cuh index e6a426b25f..5ba31bd911 100644 --- a/cpp/src/stats/detail/silhouette_score.cuh +++ b/cpp/src/stats/detail/silhouette_score.cuh @@ -9,10 +9,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include @@ -218,30 +220,36 @@ DataT silhouette_score( } else { perSampleSilScore = silhouette_scorePerSample; } - RAFT_CUDA_TRY(cudaMemsetAsync(perSampleSilScore, 0, nRows * sizeof(DataT), stream)); + raft::matrix::fill( + handle, raft::make_device_vector_view(perSampleSilScore, nRows), DataT(0)); // getting the sample count per cluster rmm::device_uvector binCountArray(nLabels, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(binCountArray.data(), 0, nLabels * sizeof(DataT), stream)); + raft::matrix::fill( + handle, raft::make_device_vector_view(binCountArray.data(), nLabels), DataT(0)); countLabels(labels, binCountArray.data(), nRows, nLabels, workspace, stream); // calculating the sample-cluster-distance-sum-array rmm::device_uvector sampleToClusterSumOfDistances(nRows * nLabels, stream); - RAFT_CUDA_TRY(cudaMemsetAsync( - sampleToClusterSumOfDistances.data(), 0, nRows * nLabels * sizeof(DataT), stream)); - raft::linalg::reduce_cols_by_key(distanceMatrix.data(), - labels, - sampleToClusterSumOfDistances.data(), - nRows, - nRows, - nLabels, - stream); + raft::matrix::fill(handle, + raft::make_device_vector_view(sampleToClusterSumOfDistances.data(), + nRows * nLabels), + DataT(0)); + raft::linalg::reduce_cols_by_key(handle, + raft::make_device_matrix_view( + distanceMatrix.data(), nRows, nRows), + raft::make_device_vector_view(labels, nRows), + raft::make_device_matrix_view( + sampleToClusterSumOfDistances.data(), nRows, nLabels), + nLabels); // creating the a array and b array rmm::device_uvector d_aArray(nRows, stream); rmm::device_uvector d_bArray(nRows, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(d_aArray.data(), 0, nRows * sizeof(DataT), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(d_bArray.data(), 0, nRows * sizeof(DataT), stream)); + raft::matrix::fill( + handle, raft::make_device_vector_view(d_aArray.data(), nRows), DataT(0)); + raft::matrix::fill( + handle, raft::make_device_vector_view(d_bArray.data(), nRows), DataT(0)); // kernel that populates the d_aArray // kernel configuration @@ -260,8 +268,10 @@ DataT silhouette_score( // elementwise dividing by bincounts rmm::device_uvector averageDistanceBetweenSampleAndCluster(nRows * nLabels, stream); - RAFT_CUDA_TRY(cudaMemsetAsync( - averageDistanceBetweenSampleAndCluster.data(), 0, nRows * nLabels * sizeof(DataT), stream)); + raft::matrix::fill(handle, + raft::make_device_vector_view( + averageDistanceBetweenSampleAndCluster.data(), nRows * nLabels), + DataT(0)); auto averageDistanceBetweenSampleAndClusterView = raft::make_device_matrix_view( averageDistanceBetweenSampleAndCluster.data(), nRows, nLabels); @@ -283,24 +293,28 @@ DataT silhouette_score( }); // calculating row-wise minimum - raft::linalg::reduce( - d_bArray.data(), - averageDistanceBetweenSampleAndCluster.data(), - nLabels, - nRows, + raft::linalg::reduce( + handle, + raft::make_device_matrix_view( + averageDistanceBetweenSampleAndCluster.data(), nRows, nLabels), + raft::make_device_vector_view(d_bArray.data(), nRows), std::numeric_limits::max(), - stream, false, raft::identity_op{}, raft::min_op{}); // calculating the silhouette score per sample using the d_aArray and d_bArray - raft::linalg::binaryOp>( - perSampleSilScore, d_aArray.data(), d_bArray.data(), nRows, SilOp(), stream); + raft::linalg::map( + handle, + raft::make_device_vector_view(perSampleSilScore, nRows), + SilOp(), + raft::make_const_mdspan(raft::make_device_vector_view(d_aArray.data(), nRows)), + raft::make_const_mdspan(raft::make_device_vector_view(d_bArray.data(), nRows))); // calculating the sum of all the silhouette score rmm::device_scalar d_avgSilhouetteScore(stream); - RAFT_CUDA_TRY(cudaMemsetAsync(d_avgSilhouetteScore.data(), 0, sizeof(DataT), stream)); + raft::matrix::fill( + handle, raft::make_device_vector_view(d_avgSilhouetteScore.data(), 1), DataT(0)); raft::linalg::mapThenSumReduce(d_avgSilhouetteScore.data(), nRows, diff --git a/cpp/src/stats/detail/trustworthiness_score.cuh b/cpp/src/stats/detail/trustworthiness_score.cuh index f6a5754843..f44404fefd 100644 --- a/cpp/src/stats/detail/trustworthiness_score.cuh +++ b/cpp/src/stats/detail/trustworthiness_score.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -181,7 +182,7 @@ double trustworthiness_score(const raft::resources& h, build_lookup_table<<>>( lookup_table.data(), X_ind.data(), n, work); - RAFT_CUDA_TRY(cudaMemsetAsync(t_dbuf.data(), 0, sizeof(double), stream)); + raft::matrix::fill(h, raft::make_device_scalar_view(t_dbuf.data()), double(0)); work = curBatchSize * (n_neighbors + 1); n_blocks = raft::ceildiv(work, N_THREADS);