Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions cpp/src/cluster/detail/agglomerative.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@

#pragma once

#include <raft/core/copy.cuh>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/init.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

Expand Down Expand Up @@ -108,9 +112,15 @@ void build_dendrogram_host(raft::resources const& handle,
std::vector<value_idx> mst_dst_h(n_edges);
std::vector<value_t> mst_weights_h(n_edges);

raft::update_host(mst_src_h.data(), rows, n_edges, stream);
raft::update_host(mst_dst_h.data(), cols, n_edges, stream);
raft::update_host(mst_weights_h.data(), data, n_edges, stream);
raft::copy(handle,
raft::make_host_vector_view(mst_src_h.data(), n_edges),
raft::make_device_vector_view(rows, n_edges));
raft::copy(handle,
raft::make_host_vector_view(mst_dst_h.data(), n_edges),
raft::make_device_vector_view(cols, n_edges));
raft::copy(handle,
raft::make_host_vector_view(mst_weights_h.data(), n_edges),
raft::make_device_vector_view(data, n_edges));

raft::resource::sync_stream(handle, stream);

Expand Down Expand Up @@ -138,9 +148,15 @@ void build_dendrogram_host(raft::resources const& handle,
U.perform_union(aa, bb);
}

raft::update_device(children, children_h.data(), n_edges * 2, stream);
raft::update_device(out_size, out_size_h.data(), n_edges, stream);
raft::update_device(out_delta, out_delta_h.data(), n_edges, stream);
raft::copy(handle,
raft::make_device_vector_view(children, n_edges * 2),
raft::make_host_vector_view(children_h.data(), n_edges * 2));
raft::copy(handle,
raft::make_device_vector_view(out_size, n_edges),
raft::make_host_vector_view(out_size_h.data(), n_edges));
raft::copy(handle,
raft::make_device_vector_view(out_delta, n_edges),
raft::make_host_vector_view(out_delta_h.data(), n_edges));
}

template <typename value_idx>
Expand Down Expand Up @@ -236,7 +252,8 @@ void extract_flattened_clusters(raft::resources const& handle,

// Handle special case where n_clusters == 1
if (n_clusters == 1) {
thrust::fill(thrust_policy, labels, labels + n_leaves, 0);
raft::matrix::fill(
handle, raft::make_device_vector_view<value_idx>(labels, n_leaves), value_idx(0));
} else {
/**
* Compute levels for each node
Expand Down
27 changes: 16 additions & 11 deletions cpp/src/cluster/detail/connectivities.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -9,11 +9,13 @@
#include "./kmeans_common.cuh"
#include <cuvs/cluster/agglomerative.hpp>
#include <cuvs/distance/distance.hpp>
#include <raft/core/copy.cuh>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/sparse/convert/csr.cuh>
#include <raft/sparse/coo.hpp>
#include <raft/util/cuda_utils.cuh>
Expand Down Expand Up @@ -85,9 +87,9 @@ struct distance_graph_impl<Linkage::KNN_GRAPH, value_idx, value_t> {
bool self_loop = row == col;
return (self_loop * std::numeric_limits<value_t>::max()) + (!self_loop * val);
},
rows_view,
cols_view,
vals_in_view);
raft::make_const_mdspan(rows_view),
raft::make_const_mdspan(cols_view),
raft::make_const_mdspan(vals_in_view));

raft::sparse::convert::sorted_coo_to_csr(
knn_graph_coo.rows(), knn_graph_coo.nnz, indptr.data(), m + 1, stream);
Expand Down Expand Up @@ -144,7 +146,7 @@ void pairwise_distances(const raft::resources& handle,
raft::make_device_vector_view<value_idx, value_idx>(indptr, m),
[=] __device__(value_idx idx) { return idx * m; });

raft::update_device(indptr + m, &nnz, 1, stream);
raft::copy(handle, raft::make_device_scalar_view(indptr + m), raft::make_host_scalar_view(&nnz));

// TODO: It would ultimately be nice if the MST could accept
// dense inputs directly so we don't need to double the memory
Expand All @@ -157,11 +159,14 @@ void pairwise_distances(const raft::resources& handle,
// self-loops get max distance
auto data_view = raft::make_device_vector_view<value_t, value_idx>(data, nnz);

raft::linalg::map_offset(handle, data_view, [=] __device__(value_idx idx) {
value_t val = data[idx];
bool self_loop = idx % m == idx / m;
return (self_loop * std::numeric_limits<value_t>::max()) + (!self_loop * val);
});
raft::linalg::map_offset(
handle,
data_view,
[=] __device__(value_idx idx, value_t val) {
bool self_loop = idx % m == idx / m;
return (self_loop * std::numeric_limits<value_t>::max()) + (!self_loop * val);
},
raft::make_const_mdspan(data_view));
}

/**
Expand Down
Loading