From 2dd503bcc5ab93e4e04045e52193998b10bf80c7 Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 16 Feb 2026 01:59:46 +0000 Subject: [PATCH 01/13] pca-preprocessor --- cpp/CMakeLists.txt | 1 + cpp/include/cuvs/preprocessing/pca.hpp | 238 ++++++++++++++++ cpp/src/preprocessing/pca/detail/pca.cuh | 113 ++++++++ cpp/src/preprocessing/pca/pca.cu | 102 +++++++ cpp/tests/CMakeLists.txt | 7 +- cpp/tests/preprocessing/pca.cu | 335 +++++++++++++++++++++++ 6 files changed, 794 insertions(+), 2 deletions(-) create mode 100644 cpp/include/cuvs/preprocessing/pca.hpp create mode 100644 cpp/src/preprocessing/pca/detail/pca.cuh create mode 100644 cpp/src/preprocessing/pca/pca.cu create mode 100644 cpp/tests/preprocessing/pca.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6313db71ca..7cc86eb2f8 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -549,6 +549,7 @@ if(NOT BUILD_CPU_ONLY) src/preprocessing/quantize/binary.cu src/preprocessing/quantize/pq.cu src/preprocessing/spectral/spectral_embedding.cu + src/preprocessing/pca/pca.cu src/selection/select_k_float_int64_t.cu src/selection/select_k_float_int32_t.cu src/selection/select_k_float_uint32_t.cu diff --git a/cpp/include/cuvs/preprocessing/pca.hpp b/cpp/include/cuvs/preprocessing/pca.hpp new file mode 100644 index 0000000000..b26613d50c --- /dev/null +++ b/cpp/include/cuvs/preprocessing/pca.hpp @@ -0,0 +1,238 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +namespace cuvs::preprocessing::pca { + +/** + * @brief Solver algorithm for PCA decomposition. 
+ * + * @param COV_EIG_DQ: covariance of input will be used along with eigen decomposition using the divide + * and conquer method for symmetric matrices + * @param COV_EIG_JACOBI: covariance of input will be used along with eigen decomposition using + * the Jacobi method for symmetric matrices + */ +enum class solver : int { + COV_EIG_DQ, + COV_EIG_JACOBI, +}; + +/** + * @brief Parameters for PCA decomposition. Ref: + * http://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html + */ +struct params { + /** @brief Number of components to keep. */ + int n_components = 1; + + /** + * @brief If false, data passed to fit are overwritten and running fit(X).transform(X) will + * not yield the expected results; use fit_transform(X) instead. + */ + bool copy = true; + + /** + * @brief When true (false by default) the component vectors are multiplied by the square + * root of n_samples and then divided by the singular values to ensure uncorrelated outputs with + * unit component-wise variances. + */ + bool whiten = false; + + /** @brief The solver algorithm to use. */ + solver algorithm = solver::COV_EIG_DQ; + + /** + * @brief Tolerance for singular values computed when + * algorithm == solver::COV_EIG_JACOBI. + */ + float tol = 0.0f; + + /** + * @brief Number of iterations for the power method computed by the Jacobi method + * (algorithm == solver::COV_EIG_JACOBI). + */ + int n_iterations = 15; + + /** @brief 0: no error message printing, 1: print error messages. */ + int verbose = 0; +}; + +/** + * @defgroup pca PCA (Principal Component Analysis) + * @{ + */ + +/** + * @brief Perform PCA fit operation. + * + * Computes the principal components, explained variances, singular values, and column means + * from the input data. 
+ * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * cuvs::preprocessing::pca::params params; + * params.n_components = 2; + * + * auto input = raft::make_device_matrix(handle, n_rows, n_cols); + * // ... fill input ... + * + * auto components = raft::make_device_matrix( + * handle, params.n_components, n_cols); + * auto explained_var = raft::make_device_vector(handle, params.n_components); + * auto explained_var_ratio = raft::make_device_vector(handle, params.n_components); + * auto singular_vals = raft::make_device_vector(handle, params.n_components); + * auto mu = raft::make_device_vector(handle, n_cols); + * auto noise_vars = raft::make_device_scalar(handle); + * + * cuvs::preprocessing::pca::fit(handle, params, + * input.view(), components.view(), explained_var.view(), + * explained_var_ratio.view(), singular_vals.view(), mu.view(), noise_vars.view()); + * @endcode + * + * @param[in] handle raft resource handle + * @param[in] config PCA parameters + * @param[inout] input input data [n_rows x n_cols] (col-major). Modified temporarily. 
+ * @param[out] components principal components [n_components x n_cols] (col-major) + * @param[out] explained_var explained variances [n_components] + * @param[out] explained_var_ratio explained variance ratios [n_components] + * @param[out] singular_vals singular values [n_components] + * @param[out] mu column means [n_cols] + * @param[out] noise_vars noise variance (scalar) + * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) + */ +void fit(raft::resources const& handle, + params config, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_scalar_view noise_vars, + bool flip_signs_based_on_U = false); + +void fit(raft::resources const& handle, + params config, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_scalar_view noise_vars, + bool flip_signs_based_on_U = false); + +/** + * @brief Perform PCA fit and transform operations. + * + * Computes the principal components and transforms the input data into the eigenspace + * in a single operation. + * + * @param[in] handle raft resource handle + * @param[in] config PCA parameters + * @param[inout] input input data [n_rows x n_cols] (col-major). Modified temporarily. 
+ * @param[out] trans_input transformed data [n_rows x n_components] (col-major) + * @param[out] components principal components [n_components x n_cols] (col-major) + * @param[out] explained_var explained variances [n_components] + * @param[out] explained_var_ratio explained variance ratios [n_components] + * @param[out] singular_vals singular values [n_components] + * @param[out] mu column means [n_cols] + * @param[out] noise_vars noise variance (scalar) + * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) + */ +void fit_transform(raft::resources const& handle, + params config, + raft::device_matrix_view input, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_scalar_view noise_vars, + bool flip_signs_based_on_U = false); + +void fit_transform(raft::resources const& handle, + params config, + raft::device_matrix_view input, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_scalar_view noise_vars, + bool flip_signs_based_on_U = false); + +/** + * @brief Perform PCA transform operation. + * + * Transforms the input data into the eigenspace using previously computed principal components. + * + * @param[in] handle raft resource handle + * @param[in] config PCA parameters + * @param[inout] input data to transform [n_rows x n_cols] (col-major). Modified temporarily + * (mean-centered then restored). 
+ * @param[in] components principal components [n_components x n_cols] (col-major) + * @param[in] singular_vals singular values [n_components] + * @param[in] mu column means [n_cols] + * @param[out] trans_input transformed data [n_rows x n_components] (col-major) + */ +void transform(raft::resources const& handle, + params config, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_matrix_view trans_input); + +void transform(raft::resources const& handle, + params config, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_matrix_view trans_input); + +/** + * @brief Perform PCA inverse transform operation. + * + * Transforms data from the eigenspace back to the original space. + * + * @param[in] handle raft resource handle + * @param[in] config PCA parameters + * @param[in] trans_input transformed data [n_rows x n_components] (col-major) + * @param[in] components principal components [n_components x n_cols] (col-major) + * @param[in] singular_vals singular values [n_components] + * @param[in] mu column means [n_cols] + * @param[out] output reconstructed data [n_rows x n_cols] (col-major) + */ +void inverse_transform(raft::resources const& handle, + params config, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_matrix_view output); + +void inverse_transform(raft::resources const& handle, + params config, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_matrix_view output); + +/** @} */ // end group pca + +} // namespace cuvs::preprocessing::pca diff --git a/cpp/src/preprocessing/pca/detail/pca.cuh 
b/cpp/src/preprocessing/pca/detail/pca.cuh new file mode 100644 index 0000000000..42d1e7bdd7 --- /dev/null +++ b/cpp/src/preprocessing/pca/detail/pca.cuh @@ -0,0 +1,113 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include +#include +#include + +namespace cuvs::preprocessing::pca::detail { + +/** + * @brief Convert cuvs::preprocessing::pca::params to raft::linalg::paramsPCA. + */ +inline raft::linalg::paramsPCA to_raft_params(params config, std::size_t n_rows, std::size_t n_cols) +{ + raft::linalg::paramsPCA prms; + prms.n_rows = n_rows; + prms.n_cols = n_cols; + prms.n_components = static_cast(config.n_components); + prms.algorithm = static_cast(static_cast(config.algorithm)); + prms.tol = config.tol; + prms.n_iterations = static_cast(config.n_iterations); + prms.verbose = config.verbose; + prms.copy = config.copy; + prms.whiten = config.whiten; + return prms; +} + +template +void fit(raft::resources const& handle, + params config, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_scalar_view noise_vars, + bool flip_signs_based_on_U) +{ + auto raft_prms = to_raft_params(config, input.extent(0), input.extent(1)); + raft::linalg::pca_fit(handle, + raft_prms, + input, + components, + explained_var, + explained_var_ratio, + singular_vals, + mu, + noise_vars, + flip_signs_based_on_U); +} + +template +void fit_transform(raft::resources const& handle, + params config, + raft::device_matrix_view input, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_scalar_view 
noise_vars, + bool flip_signs_based_on_U) +{ + auto raft_prms = to_raft_params(config, input.extent(0), input.extent(1)); + raft::linalg::pca_fit_transform(handle, + raft_prms, + input, + trans_input, + components, + explained_var, + explained_var_ratio, + singular_vals, + mu, + noise_vars, + flip_signs_based_on_U); +} + +template +void transform(raft::resources const& handle, + params config, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_matrix_view trans_input) +{ + auto raft_prms = to_raft_params(config, input.extent(0), input.extent(1)); + raft::linalg::pca_transform(handle, raft_prms, input, components, singular_vals, mu, trans_input); +} + +template +void inverse_transform(raft::resources const& handle, + params config, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_matrix_view output) +{ + auto raft_prms = to_raft_params(config, output.extent(0), output.extent(1)); + raft::linalg::pca_inverse_transform( + handle, raft_prms, trans_input, components, singular_vals, mu, output); +} + +} // namespace cuvs::preprocessing::pca::detail diff --git a/cpp/src/preprocessing/pca/pca.cu b/cpp/src/preprocessing/pca/pca.cu new file mode 100644 index 0000000000..35f67206a7 --- /dev/null +++ b/cpp/src/preprocessing/pca/pca.cu @@ -0,0 +1,102 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include "./detail/pca.cuh" + +#include + +namespace cuvs::preprocessing::pca { + +#define CUVS_INST_PCA_FIT(DataT) \ + void fit(raft::resources const& handle, \ + params config, \ + raft::device_matrix_view input, \ + raft::device_matrix_view components, \ + raft::device_vector_view explained_var, \ + raft::device_vector_view explained_var_ratio, \ + raft::device_vector_view singular_vals, \ + raft::device_vector_view mu, \ + raft::device_scalar_view noise_vars, \ + bool flip_signs_based_on_U) \ + { \ + detail::fit(handle, \ + config, \ + input, \ + components, \ + explained_var, \ + explained_var_ratio, \ + singular_vals, \ + mu, \ + noise_vars, \ + flip_signs_based_on_U); \ + } + +CUVS_INST_PCA_FIT(float); +CUVS_INST_PCA_FIT(double); +#undef CUVS_INST_PCA_FIT + +#define CUVS_INST_PCA_FIT_TRANSFORM(DataT) \ + void fit_transform(raft::resources const& handle, \ + params config, \ + raft::device_matrix_view input, \ + raft::device_matrix_view trans_input, \ + raft::device_matrix_view components, \ + raft::device_vector_view explained_var, \ + raft::device_vector_view explained_var_ratio, \ + raft::device_vector_view singular_vals, \ + raft::device_vector_view mu, \ + raft::device_scalar_view noise_vars, \ + bool flip_signs_based_on_U) \ + { \ + detail::fit_transform(handle, \ + config, \ + input, \ + trans_input, \ + components, \ + explained_var, \ + explained_var_ratio, \ + singular_vals, \ + mu, \ + noise_vars, \ + flip_signs_based_on_U); \ + } + +CUVS_INST_PCA_FIT_TRANSFORM(float); +CUVS_INST_PCA_FIT_TRANSFORM(double); +#undef CUVS_INST_PCA_FIT_TRANSFORM + +#define CUVS_INST_PCA_TRANSFORM(DataT) \ + void transform(raft::resources const& handle, \ + params config, \ + raft::device_matrix_view input, \ + raft::device_matrix_view components, \ + raft::device_vector_view singular_vals, \ + raft::device_vector_view mu, \ + raft::device_matrix_view trans_input) \ + { \ + detail::transform(handle, config, input, 
components, singular_vals, mu, trans_input); \ + } + +CUVS_INST_PCA_TRANSFORM(float); +CUVS_INST_PCA_TRANSFORM(double); +#undef CUVS_INST_PCA_TRANSFORM + +#define CUVS_INST_PCA_INVERSE_TRANSFORM(DataT) \ + void inverse_transform(raft::resources const& handle, \ + params config, \ + raft::device_matrix_view trans_input, \ + raft::device_matrix_view components, \ + raft::device_vector_view singular_vals, \ + raft::device_vector_view mu, \ + raft::device_matrix_view output) \ + { \ + detail::inverse_transform(handle, config, trans_input, components, singular_vals, mu, output); \ + } + +CUVS_INST_PCA_INVERSE_TRANSFORM(float); +CUVS_INST_PCA_INVERSE_TRANSFORM(double); +#undef CUVS_INST_PCA_INVERSE_TRANSFORM + +} // namespace cuvs::preprocessing::pca diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 9fc620b4cb..44c4f77fc0 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -334,8 +334,11 @@ ConfigureTest( ConfigureTest( NAME PREPROCESSING_TEST - PATH preprocessing/scalar_quantization.cu preprocessing/binary_quantization.cu - preprocessing/spectral_embedding.cu preprocessing/product_quantization.cu + PATH preprocessing/scalar_quantization.cu + preprocessing/binary_quantization.cu + preprocessing/spectral_embedding.cu + preprocessing/product_quantization.cu + preprocessing/pca.cu GPUS 1 PERCENT 100 ) diff --git a/cpp/tests/preprocessing/pca.cu b/cpp/tests/preprocessing/pca.cu new file mode 100644 index 0000000000..dfe0808289 --- /dev/null +++ b/cpp/tests/preprocessing/pca.cu @@ -0,0 +1,335 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include "../test_utils.cuh" + +#include +#include +// #include +#include +#include +#include +#include + +#include +#include + +#include + +namespace cuvs::preprocessing::pca { + +template +struct PcaInputs { + T tolerance; + int len; + int n_row; + int n_col; + int len2; + int n_row2; + int n_col2; + unsigned long long int seed; + int algo; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const PcaInputs& dims) +{ + return os; +} + +template +class PcaTest : public ::testing::TestWithParam> { + public: + PcaTest() + : params(::testing::TestWithParam>::GetParam()), + stream(raft::resource::get_cuda_stream(handle)), + explained_vars(params.n_col, stream), + explained_vars_ref(params.n_col, stream), + components(params.n_col * params.n_col, stream), + components_ref(params.n_col * params.n_col, stream), + trans_data(params.len, stream), + trans_data_ref(params.len, stream), + data(params.len, stream), + data_back(params.len, stream), + data2(params.len2, stream), + data2_back(params.len2, stream) + { + basicTest(); + advancedTest(); + } + + protected: + void basicTest() + { + raft::random::Rng r(params.seed, raft::random::GenPC); + int len = params.len; + + std::vector data_h = {1.0, 2.0, 5.0, 4.0, 2.0, 1.0}; + data_h.resize(len); + raft::update_device(data.data(), data_h.data(), len, stream); + + std::vector trans_data_ref_h = {-2.3231, -0.3517, 2.6748, 0.3979, -0.6571, 0.2592}; + trans_data_ref_h.resize(len); + raft::update_device(trans_data_ref.data(), trans_data_ref_h.data(), len, stream); + + int len_comp = params.n_col * params.n_col; + rmm::device_uvector explained_var_ratio(params.n_col, stream); + rmm::device_uvector singular_vals(params.n_col, stream); + rmm::device_uvector mean(params.n_col, stream); + rmm::device_uvector noise_vars(1, stream); + + std::vector components_ref_h = {0.8163, 0.5776, -0.5776, 0.8163}; + components_ref_h.resize(len_comp); + std::vector explained_vars_ref_h = {6.338, 
0.3287}; + explained_vars_ref_h.resize(params.n_col); + + raft::update_device(components_ref.data(), components_ref_h.data(), len_comp, stream); + raft::update_device( + explained_vars_ref.data(), explained_vars_ref_h.data(), params.n_col, stream); + + cuvs::preprocessing::pca::params prms; + // prms.n_cols = params.n_col; + // prms.n_rows = params.n_row; + prms.n_components = params.n_col; + prms.whiten = false; + if (params.algo == 0) + prms.algorithm = solver::COV_EIG_DQ; + else + prms.algorithm = solver::COV_EIG_JACOBI; + + auto input_view = raft::make_device_matrix_view( + data.data(), params.n_row, params.n_col); + auto components_view = raft::make_device_matrix_view( + components.data(), prms.n_components, params.n_col); + auto explained_var_view = + raft::make_device_vector_view(explained_vars.data(), prms.n_components); + auto explained_var_ratio_view = + raft::make_device_vector_view(explained_var_ratio.data(), prms.n_components); + auto singular_vals_view = + raft::make_device_vector_view(singular_vals.data(), prms.n_components); + auto mu_view = raft::make_device_vector_view(mean.data(), params.n_col); + auto noise_vars_view = raft::make_device_scalar_view(noise_vars.data()); + + cuvs::preprocessing::pca::fit(handle, + prms, + input_view, + components_view, + explained_var_view, + explained_var_ratio_view, + singular_vals_view, + mu_view, + noise_vars_view); + + auto trans_data_view = raft::make_device_matrix_view( + trans_data.data(), params.n_row, prms.n_components); + + cuvs::preprocessing::pca::transform( + handle, prms, input_view, components_view, singular_vals_view, mu_view, trans_data_view); + + auto data_back_view = raft::make_device_matrix_view( + data_back.data(), params.n_row, params.n_col); + + cuvs::preprocessing::pca::inverse_transform( + handle, prms, trans_data_view, components_view, singular_vals_view, mu_view, data_back_view); + } + + void advancedTest() + { + raft::random::Rng r(params.seed, raft::random::GenPC); + int len = 
params.len2; + + cuvs::preprocessing::pca::params prms; + // prms.n_cols = params.n_col2; + // prms.n_rows = params.n_row2; + prms.n_components = params.n_col2; + prms.whiten = false; + if (params.algo == 0) + prms.algorithm = solver::COV_EIG_DQ; + else if (params.algo == 1) + prms.algorithm = solver::COV_EIG_JACOBI; + + r.uniform(data2.data(), len, T(-1.0), T(1.0), stream); + rmm::device_uvector data2_trans(params.n_row2 * prms.n_components, stream); + + int len_comp = params.n_col2 * prms.n_components; + rmm::device_uvector components2(len_comp, stream); + rmm::device_uvector explained_vars2(prms.n_components, stream); + rmm::device_uvector explained_var_ratio2(prms.n_components, stream); + rmm::device_uvector singular_vals2(prms.n_components, stream); + rmm::device_uvector mean2(params.n_col2, stream); + rmm::device_uvector noise_vars2(1, stream); + + auto input_view = raft::make_device_matrix_view( + data2.data(), params.n_row2, params.n_col2); + auto trans_view = raft::make_device_matrix_view( + data2_trans.data(), params.n_row2, prms.n_components); + auto comp_view = raft::make_device_matrix_view( + components2.data(), prms.n_components, params.n_col2); + auto ev_view = + raft::make_device_vector_view(explained_vars2.data(), prms.n_components); + auto evr_view = + raft::make_device_vector_view(explained_var_ratio2.data(), prms.n_components); + auto sv_view = + raft::make_device_vector_view(singular_vals2.data(), prms.n_components); + auto mu_view = raft::make_device_vector_view(mean2.data(), params.n_col2); + auto noise_view = raft::make_device_scalar_view(noise_vars2.data()); + + cuvs::preprocessing::pca::fit_transform(handle, + prms, + input_view, + trans_view, + comp_view, + ev_view, + evr_view, + sv_view, + mu_view, + noise_view); + + auto data2_back_view = raft::make_device_matrix_view( + data2_back.data(), params.n_row2, params.n_col2); + + cuvs::preprocessing::pca::inverse_transform( + handle, prms, trans_view, comp_view, sv_view, mu_view, 
data2_back_view); + } + + protected: + raft::device_resources handle; + cudaStream_t stream; + + PcaInputs params; + + rmm::device_uvector explained_vars, explained_vars_ref, components, components_ref, trans_data, + trans_data_ref, data, data_back, data2, data2_back; +}; + +const std::vector> inputsf2 = { + {0.01f, 3 * 2, 3, 2, 1024 * 128, 1024, 128, 1234ULL, 0}, + {0.01f, 3 * 2, 3, 2, 256 * 32, 256, 32, 1234ULL, 1}}; + +const std::vector> inputsd2 = { + {0.01, 3 * 2, 3, 2, 1024 * 128, 1024, 128, 1234ULL, 0}, + {0.01, 3 * 2, 3, 2, 256 * 32, 256, 32, 1234ULL, 1}}; + +typedef PcaTest PcaTestValF; +TEST_P(PcaTestValF, Result) +{ + ASSERT_TRUE(devArrMatch(explained_vars.data(), + explained_vars_ref.data(), + params.n_col, + cuvs::CompareApprox(params.tolerance), + raft::resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestValD; +TEST_P(PcaTestValD, Result) +{ + ASSERT_TRUE(devArrMatch(explained_vars.data(), + explained_vars_ref.data(), + params.n_col, + cuvs::CompareApprox(params.tolerance), + raft::resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestLeftVecF; +TEST_P(PcaTestLeftVecF, Result) +{ + ASSERT_TRUE(devArrMatch(components.data(), + components_ref.data(), + (params.n_col * params.n_col), + cuvs::CompareApprox(params.tolerance), + raft::resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestLeftVecD; +TEST_P(PcaTestLeftVecD, Result) +{ + ASSERT_TRUE(devArrMatch(components.data(), + components_ref.data(), + (params.n_col * params.n_col), + cuvs::CompareApprox(params.tolerance), + raft::resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestTransDataF; +TEST_P(PcaTestTransDataF, Result) +{ + ASSERT_TRUE(devArrMatch(trans_data.data(), + trans_data_ref.data(), + (params.n_row * params.n_col), + cuvs::CompareApprox(params.tolerance), + raft::resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestTransDataD; +TEST_P(PcaTestTransDataD, Result) +{ + ASSERT_TRUE(devArrMatch(trans_data.data(), + 
trans_data_ref.data(), + (params.n_row * params.n_col), + cuvs::CompareApprox(params.tolerance), + raft::resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestDataVecSmallF; +TEST_P(PcaTestDataVecSmallF, Result) +{ + ASSERT_TRUE(devArrMatch(data.data(), + data_back.data(), + (params.n_col * params.n_col), + cuvs::CompareApprox(params.tolerance), + raft::resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestDataVecSmallD; +TEST_P(PcaTestDataVecSmallD, Result) +{ + ASSERT_TRUE(devArrMatch(data.data(), + data_back.data(), + (params.n_col * params.n_col), + cuvs::CompareApprox(params.tolerance), + raft::resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestDataVecF; +TEST_P(PcaTestDataVecF, Result) +{ + ASSERT_TRUE(devArrMatch(data2.data(), + data2_back.data(), + (params.n_col2 * params.n_col2), + cuvs::CompareApprox(params.tolerance), + raft::resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestDataVecD; +TEST_P(PcaTestDataVecD, Result) +{ + ASSERT_TRUE(devArrMatch(data2.data(), + data2_back.data(), + (params.n_col2 * params.n_col2), + cuvs::CompareApprox(params.tolerance), + raft::resource::get_cuda_stream(handle))); +} + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestValF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestValD, ::testing::ValuesIn(inputsd2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestLeftVecF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestLeftVecD, ::testing::ValuesIn(inputsd2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecSmallF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecSmallD, ::testing::ValuesIn(inputsd2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestTransDataF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestTransDataD, ::testing::ValuesIn(inputsd2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecF, ::testing::ValuesIn(inputsf2)); + 
+INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecD, ::testing::ValuesIn(inputsd2)); + +} // end namespace cuvs::preprocessing::pca From 14b08e7b027453af3e5397644da42dccb87405b6 Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 16 Feb 2026 02:19:15 +0000 Subject: [PATCH 02/13] IndexT tparam --- cpp/src/preprocessing/pca/detail/pca.cuh | 58 +++++----- cpp/src/preprocessing/pca/pca.cu | 132 +++++++++++------------ 2 files changed, 95 insertions(+), 95 deletions(-) diff --git a/cpp/src/preprocessing/pca/detail/pca.cuh b/cpp/src/preprocessing/pca/detail/pca.cuh index 42d1e7bdd7..e0629da684 100644 --- a/cpp/src/preprocessing/pca/detail/pca.cuh +++ b/cpp/src/preprocessing/pca/detail/pca.cuh @@ -31,16 +31,16 @@ inline raft::linalg::paramsPCA to_raft_params(params config, std::size_t n_rows, return prms; } -template +template void fit(raft::resources const& handle, params config, - raft::device_matrix_view input, - raft::device_matrix_view components, - raft::device_vector_view explained_var, - raft::device_vector_view explained_var_ratio, - raft::device_vector_view singular_vals, - raft::device_vector_view mu, - raft::device_scalar_view noise_vars, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_scalar_view noise_vars, bool flip_signs_based_on_U) { auto raft_prms = to_raft_params(config, input.extent(0), input.extent(1)); @@ -56,17 +56,17 @@ void fit(raft::resources const& handle, flip_signs_based_on_U); } -template +template void fit_transform(raft::resources const& handle, params config, - raft::device_matrix_view input, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, - raft::device_vector_view explained_var, - raft::device_vector_view explained_var_ratio, - raft::device_vector_view singular_vals, - raft::device_vector_view mu, - 
raft::device_scalar_view noise_vars, + raft::device_matrix_view input, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_scalar_view noise_vars, bool flip_signs_based_on_U) { auto raft_prms = to_raft_params(config, input.extent(0), input.extent(1)); @@ -83,27 +83,27 @@ void fit_transform(raft::resources const& handle, flip_signs_based_on_U); } -template +template void transform(raft::resources const& handle, params config, - raft::device_matrix_view input, - raft::device_matrix_view components, - raft::device_vector_view singular_vals, - raft::device_vector_view mu, - raft::device_matrix_view trans_input) + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_matrix_view trans_input) { auto raft_prms = to_raft_params(config, input.extent(0), input.extent(1)); raft::linalg::pca_transform(handle, raft_prms, input, components, singular_vals, mu, trans_input); } -template +template void inverse_transform(raft::resources const& handle, params config, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, - raft::device_vector_view singular_vals, - raft::device_vector_view mu, - raft::device_matrix_view output) + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_matrix_view output) { auto raft_prms = to_raft_params(config, output.extent(0), output.extent(1)); raft::linalg::pca_inverse_transform( diff --git a/cpp/src/preprocessing/pca/pca.cu b/cpp/src/preprocessing/pca/pca.cu index 35f67206a7..1091090815 100644 --- a/cpp/src/preprocessing/pca/pca.cu +++ b/cpp/src/preprocessing/pca/pca.cu @@ -9,94 +9,94 @@ namespace 
cuvs::preprocessing::pca { -#define CUVS_INST_PCA_FIT(DataT) \ - void fit(raft::resources const& handle, \ - params config, \ - raft::device_matrix_view input, \ - raft::device_matrix_view components, \ - raft::device_vector_view explained_var, \ - raft::device_vector_view explained_var_ratio, \ - raft::device_vector_view singular_vals, \ - raft::device_vector_view mu, \ - raft::device_scalar_view noise_vars, \ - bool flip_signs_based_on_U) \ - { \ - detail::fit(handle, \ - config, \ - input, \ - components, \ - explained_var, \ - explained_var_ratio, \ - singular_vals, \ - mu, \ - noise_vars, \ - flip_signs_based_on_U); \ +#define CUVS_INST_PCA_FIT(DataT, IndexT) \ + void fit(raft::resources const& handle, \ + params config, \ + raft::device_matrix_view input, \ + raft::device_matrix_view components, \ + raft::device_vector_view explained_var, \ + raft::device_vector_view explained_var_ratio, \ + raft::device_vector_view singular_vals, \ + raft::device_vector_view mu, \ + raft::device_scalar_view noise_vars, \ + bool flip_signs_based_on_U) \ + { \ + detail::fit(handle, \ + config, \ + input, \ + components, \ + explained_var, \ + explained_var_ratio, \ + singular_vals, \ + mu, \ + noise_vars, \ + flip_signs_based_on_U); \ } -CUVS_INST_PCA_FIT(float); -CUVS_INST_PCA_FIT(double); +CUVS_INST_PCA_FIT(float, int64_t); +CUVS_INST_PCA_FIT(double, int64_t); #undef CUVS_INST_PCA_FIT -#define CUVS_INST_PCA_FIT_TRANSFORM(DataT) \ - void fit_transform(raft::resources const& handle, \ - params config, \ - raft::device_matrix_view input, \ - raft::device_matrix_view trans_input, \ - raft::device_matrix_view components, \ - raft::device_vector_view explained_var, \ - raft::device_vector_view explained_var_ratio, \ - raft::device_vector_view singular_vals, \ - raft::device_vector_view mu, \ - raft::device_scalar_view noise_vars, \ - bool flip_signs_based_on_U) \ - { \ - detail::fit_transform(handle, \ - config, \ - input, \ - trans_input, \ - components, \ - explained_var, \ - 
explained_var_ratio, \ - singular_vals, \ - mu, \ - noise_vars, \ - flip_signs_based_on_U); \ +#define CUVS_INST_PCA_FIT_TRANSFORM(DataT, IndexT) \ + void fit_transform(raft::resources const& handle, \ + params config, \ + raft::device_matrix_view input, \ + raft::device_matrix_view trans_input, \ + raft::device_matrix_view components, \ + raft::device_vector_view explained_var, \ + raft::device_vector_view explained_var_ratio, \ + raft::device_vector_view singular_vals, \ + raft::device_vector_view mu, \ + raft::device_scalar_view noise_vars, \ + bool flip_signs_based_on_U) \ + { \ + detail::fit_transform(handle, \ + config, \ + input, \ + trans_input, \ + components, \ + explained_var, \ + explained_var_ratio, \ + singular_vals, \ + mu, \ + noise_vars, \ + flip_signs_based_on_U); \ } -CUVS_INST_PCA_FIT_TRANSFORM(float); -CUVS_INST_PCA_FIT_TRANSFORM(double); +CUVS_INST_PCA_FIT_TRANSFORM(float, int64_t); +CUVS_INST_PCA_FIT_TRANSFORM(double, int64_t); #undef CUVS_INST_PCA_FIT_TRANSFORM -#define CUVS_INST_PCA_TRANSFORM(DataT) \ +#define CUVS_INST_PCA_TRANSFORM(DataT, IndexT) \ void transform(raft::resources const& handle, \ params config, \ - raft::device_matrix_view input, \ - raft::device_matrix_view components, \ - raft::device_vector_view singular_vals, \ - raft::device_vector_view mu, \ - raft::device_matrix_view trans_input) \ + raft::device_matrix_view input, \ + raft::device_matrix_view components, \ + raft::device_vector_view singular_vals, \ + raft::device_vector_view mu, \ + raft::device_matrix_view trans_input) \ { \ detail::transform(handle, config, input, components, singular_vals, mu, trans_input); \ } -CUVS_INST_PCA_TRANSFORM(float); -CUVS_INST_PCA_TRANSFORM(double); +CUVS_INST_PCA_TRANSFORM(float, int64_t); +CUVS_INST_PCA_TRANSFORM(double, int64_t); #undef CUVS_INST_PCA_TRANSFORM -#define CUVS_INST_PCA_INVERSE_TRANSFORM(DataT) \ +#define CUVS_INST_PCA_INVERSE_TRANSFORM(DataT, IndexT) \ void inverse_transform(raft::resources const& handle, \ params 
config, \ - raft::device_matrix_view trans_input, \ - raft::device_matrix_view components, \ - raft::device_vector_view singular_vals, \ - raft::device_vector_view mu, \ - raft::device_matrix_view output) \ + raft::device_matrix_view trans_input, \ + raft::device_matrix_view components, \ + raft::device_vector_view singular_vals, \ + raft::device_vector_view mu, \ + raft::device_matrix_view output) \ { \ detail::inverse_transform(handle, config, trans_input, components, singular_vals, mu, output); \ } -CUVS_INST_PCA_INVERSE_TRANSFORM(float); -CUVS_INST_PCA_INVERSE_TRANSFORM(double); +CUVS_INST_PCA_INVERSE_TRANSFORM(float, int64_t); +CUVS_INST_PCA_INVERSE_TRANSFORM(double, int64_t); #undef CUVS_INST_PCA_INVERSE_TRANSFORM } // namespace cuvs::preprocessing::pca From c3f74efcb924082daac922c56fec3e8295d69c9b Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 16 Feb 2026 02:20:47 +0000 Subject: [PATCH 03/13] remove unused --- cpp/tests/preprocessing/pca.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/tests/preprocessing/pca.cu b/cpp/tests/preprocessing/pca.cu index dfe0808289..ebbb54a717 100644 --- a/cpp/tests/preprocessing/pca.cu +++ b/cpp/tests/preprocessing/pca.cu @@ -5,10 +5,10 @@ #include "../test_utils.cuh" +#include + #include #include -// #include -#include #include #include #include From a7f89a950eb309456ac1b92da4d95256cc5b9a1a Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 16 Feb 2026 02:29:12 +0000 Subject: [PATCH 04/13] trailing return --- cpp/src/preprocessing/pca/detail/pca.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/preprocessing/pca/detail/pca.cuh b/cpp/src/preprocessing/pca/detail/pca.cuh index e0629da684..ee3353700a 100644 --- a/cpp/src/preprocessing/pca/detail/pca.cuh +++ b/cpp/src/preprocessing/pca/detail/pca.cuh @@ -16,7 +16,8 @@ namespace cuvs::preprocessing::pca::detail { /** * @brief Convert cuvs::preprocessing::pca::params to raft::linalg::paramsPCA. 
*/ -inline raft::linalg::paramsPCA to_raft_params(params config, std::size_t n_rows, std::size_t n_cols) +inline auto to_raft_params(params config, std::size_t n_rows, std::size_t n_cols) + -> raft::linalg::paramsPCA { raft::linalg::paramsPCA prms; prms.n_rows = n_rows; From e153fe16ec45c76e88d55d827b102f7e4ae9d677 Mon Sep 17 00:00:00 2001 From: aamijar Date: Wed, 18 Feb 2026 08:44:27 +0000 Subject: [PATCH 05/13] Simplify paramsPCA --- cpp/include/cuvs/preprocessing/pca.hpp | 19 ++++--------------- cpp/src/preprocessing/pca/detail/pca.cuh | 7 +++---- cpp/tests/preprocessing/pca.cu | 10 ++++++---- 3 files changed, 13 insertions(+), 23 deletions(-) diff --git a/cpp/include/cuvs/preprocessing/pca.hpp b/cpp/include/cuvs/preprocessing/pca.hpp index b26613d50c..bbc06fb199 100644 --- a/cpp/include/cuvs/preprocessing/pca.hpp +++ b/cpp/include/cuvs/preprocessing/pca.hpp @@ -7,21 +7,11 @@ #include #include +#include namespace cuvs::preprocessing::pca { -/** - * @brief Solver algorithm for PCA decomposition. - * - * @param COV_EIG_DQ: covariance of input will be used along with eigen decomposition using divide - * and conquer method for symmetric matrices - * @param COV_EIG_JACOBI: covariance of input will be used along with eigen decomposition using - * jacobi method for symmetric matrices - */ -enum class solver : int { - COV_EIG_DQ, - COV_EIG_JACOBI, -}; +using solver = raft::linalg::solver; /** * @brief Parameters for PCA decomposition. Ref: @@ -49,13 +39,12 @@ struct params { /** * @brief Tolerance for singular values computed by svd_solver == 'arpack' or - * svd_solver == 'COV_EIG_JACOBI'. + * the Jacobi solver. */ float tol = 0.0f; /** - * @brief Number of iterations for the power method computed by jacobi method - * (svd_solver == 'COV_EIG_JACOBI'). + * @brief Number of iterations for the power method computed by the Jacobi solver. 
*/ int n_iterations = 15; diff --git a/cpp/src/preprocessing/pca/detail/pca.cuh b/cpp/src/preprocessing/pca/detail/pca.cuh index ee3353700a..c5643ec24d 100644 --- a/cpp/src/preprocessing/pca/detail/pca.cuh +++ b/cpp/src/preprocessing/pca/detail/pca.cuh @@ -22,11 +22,10 @@ inline auto to_raft_params(params config, std::size_t n_rows, std::size_t n_cols raft::linalg::paramsPCA prms; prms.n_rows = n_rows; prms.n_cols = n_cols; - prms.n_components = static_cast(config.n_components); - prms.algorithm = static_cast(static_cast(config.algorithm)); + prms.n_components = config.n_components; + prms.algorithm = config.algorithm; prms.tol = config.tol; - prms.n_iterations = static_cast(config.n_iterations); - prms.verbose = config.verbose; + prms.n_iterations = config.n_iterations; prms.copy = config.copy; prms.whiten = config.whiten; return prms; diff --git a/cpp/tests/preprocessing/pca.cu b/cpp/tests/preprocessing/pca.cu index ebbb54a717..a6e234d355 100644 --- a/cpp/tests/preprocessing/pca.cu +++ b/cpp/tests/preprocessing/pca.cu @@ -94,10 +94,11 @@ class PcaTest : public ::testing::TestWithParam> { // prms.n_rows = params.n_row; prms.n_components = params.n_col; prms.whiten = false; - if (params.algo == 0) + if (params.algo == 0) { prms.algorithm = solver::COV_EIG_DQ; - else + } else { prms.algorithm = solver::COV_EIG_JACOBI; + } auto input_view = raft::make_device_matrix_view( data.data(), params.n_row, params.n_col); @@ -145,10 +146,11 @@ class PcaTest : public ::testing::TestWithParam> { // prms.n_rows = params.n_row2; prms.n_components = params.n_col2; prms.whiten = false; - if (params.algo == 0) + if (params.algo == 0) { prms.algorithm = solver::COV_EIG_DQ; - else if (params.algo == 1) + } else if (params.algo == 1) { prms.algorithm = solver::COV_EIG_JACOBI; + } r.uniform(data2.data(), len, T(-1.0), T(1.0), stream); rmm::device_uvector data2_trans(params.n_row2 * prms.n_components, stream); From f3a4007e916102a5fee34b2df7e716afc22b6370 Mon Sep 17 00:00:00 2001 From: 
Anupam <54245698+aamijar@users.noreply.github.com> Date: Wed, 4 Mar 2026 08:31:56 -0800 Subject: [PATCH 06/13] Update cpp/tests/preprocessing/pca.cu Co-authored-by: Dante Gama Dessavre --- cpp/tests/preprocessing/pca.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/tests/preprocessing/pca.cu b/cpp/tests/preprocessing/pca.cu index a6e234d355..a96ed6ed07 100644 --- a/cpp/tests/preprocessing/pca.cu +++ b/cpp/tests/preprocessing/pca.cu @@ -289,7 +289,7 @@ TEST_P(PcaTestDataVecSmallD, Result) { ASSERT_TRUE(devArrMatch(data.data(), data_back.data(), - (params.n_col * params.n_col), + (params.n_row * params.n_col), cuvs::CompareApprox(params.tolerance), raft::resource::get_cuda_stream(handle))); } From ff4e72345cd21b4264ee4a6a69938a5986446dd6 Mon Sep 17 00:00:00 2001 From: Anupam <54245698+aamijar@users.noreply.github.com> Date: Wed, 4 Mar 2026 08:33:27 -0800 Subject: [PATCH 07/13] Apply suggestions from code review Co-authored-by: Dante Gama Dessavre --- cpp/tests/preprocessing/pca.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/tests/preprocessing/pca.cu b/cpp/tests/preprocessing/pca.cu index a96ed6ed07..5499463e56 100644 --- a/cpp/tests/preprocessing/pca.cu +++ b/cpp/tests/preprocessing/pca.cu @@ -279,7 +279,7 @@ TEST_P(PcaTestDataVecSmallF, Result) { ASSERT_TRUE(devArrMatch(data.data(), data_back.data(), - (params.n_col * params.n_col), + (params.n_row * params.n_col), cuvs::CompareApprox(params.tolerance), raft::resource::get_cuda_stream(handle))); } @@ -299,7 +299,7 @@ TEST_P(PcaTestDataVecF, Result) { ASSERT_TRUE(devArrMatch(data2.data(), data2_back.data(), - (params.n_col2 * params.n_col2), + (params.n_row2 * params.n_col2), cuvs::CompareApprox(params.tolerance), raft::resource::get_cuda_stream(handle))); } @@ -309,7 +309,7 @@ TEST_P(PcaTestDataVecD, Result) { ASSERT_TRUE(devArrMatch(data2.data(), data2_back.data(), - (params.n_col2 * params.n_col2), + (params.n_row2 * params.n_col2), 
cuvs::CompareApprox(params.tolerance), raft::resource::get_cuda_stream(handle))); } From 99f32fcdcaa809f6ca1f71c4e4147c36121ddc6e Mon Sep 17 00:00:00 2001 From: aamijar Date: Wed, 4 Mar 2026 08:47:44 -0800 Subject: [PATCH 08/13] remove verbose param --- cpp/include/cuvs/preprocessing/pca.hpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/cpp/include/cuvs/preprocessing/pca.hpp b/cpp/include/cuvs/preprocessing/pca.hpp index bbc06fb199..42a52866a5 100644 --- a/cpp/include/cuvs/preprocessing/pca.hpp +++ b/cpp/include/cuvs/preprocessing/pca.hpp @@ -47,9 +47,6 @@ struct params { * @brief Number of iterations for the power method computed by the Jacobi solver. */ int n_iterations = 15; - - /** @brief 0: no error message printing, 1: print error messages. */ - int verbose = 0; }; /** From 074fd96e6290997ac6325600b53016d485a3dbb9 Mon Sep 17 00:00:00 2001 From: aamijar Date: Wed, 4 Mar 2026 08:50:50 -0800 Subject: [PATCH 09/13] remove commented out --- cpp/tests/preprocessing/pca.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/tests/preprocessing/pca.cu b/cpp/tests/preprocessing/pca.cu index 5499463e56..a99df21c2a 100644 --- a/cpp/tests/preprocessing/pca.cu +++ b/cpp/tests/preprocessing/pca.cu @@ -142,8 +142,6 @@ class PcaTest : public ::testing::TestWithParam> { int len = params.len2; cuvs::preprocessing::pca::params prms; - // prms.n_cols = params.n_col2; - // prms.n_rows = params.n_row2; prms.n_components = params.n_col2; prms.whiten = false; if (params.algo == 0) { From c7c52a70ce4985bd7c473b554879193cf87c3f7d Mon Sep 17 00:00:00 2001 From: aamijar Date: Wed, 4 Mar 2026 09:08:26 -0800 Subject: [PATCH 10/13] const params& config --- cpp/include/cuvs/preprocessing/pca.hpp | 16 ++++++++-------- cpp/src/preprocessing/pca/detail/pca.cuh | 10 +++++----- cpp/src/preprocessing/pca/pca.cu | 8 ++++---- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/cpp/include/cuvs/preprocessing/pca.hpp b/cpp/include/cuvs/preprocessing/pca.hpp index 
42a52866a5..846264e57e 100644 --- a/cpp/include/cuvs/preprocessing/pca.hpp +++ b/cpp/include/cuvs/preprocessing/pca.hpp @@ -97,7 +97,7 @@ struct params { * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) */ void fit(raft::resources const& handle, - params config, + const params& config, raft::device_matrix_view input, raft::device_matrix_view components, raft::device_vector_view explained_var, @@ -108,7 +108,7 @@ void fit(raft::resources const& handle, bool flip_signs_based_on_U = false); void fit(raft::resources const& handle, - params config, + const params& config, raft::device_matrix_view input, raft::device_matrix_view components, raft::device_vector_view explained_var, @@ -137,7 +137,7 @@ void fit(raft::resources const& handle, * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) */ void fit_transform(raft::resources const& handle, - params config, + const params& config, raft::device_matrix_view input, raft::device_matrix_view trans_input, raft::device_matrix_view components, @@ -149,7 +149,7 @@ void fit_transform(raft::resources const& handle, bool flip_signs_based_on_U = false); void fit_transform(raft::resources const& handle, - params config, + const params& config, raft::device_matrix_view input, raft::device_matrix_view trans_input, raft::device_matrix_view components, @@ -175,7 +175,7 @@ void fit_transform(raft::resources const& handle, * @param[out] trans_input transformed data [n_rows x n_components] (col-major) */ void transform(raft::resources const& handle, - params config, + const params& config, raft::device_matrix_view input, raft::device_matrix_view components, raft::device_vector_view singular_vals, @@ -183,7 +183,7 @@ void transform(raft::resources const& handle, raft::device_matrix_view trans_input); void transform(raft::resources const& handle, - params config, + const params& config, raft::device_matrix_view input, raft::device_matrix_view components, 
raft::device_vector_view singular_vals, @@ -204,7 +204,7 @@ void transform(raft::resources const& handle, * @param[out] output reconstructed data [n_rows x n_cols] (col-major) */ void inverse_transform(raft::resources const& handle, - params config, + const params& config, raft::device_matrix_view trans_input, raft::device_matrix_view components, raft::device_vector_view singular_vals, @@ -212,7 +212,7 @@ void inverse_transform(raft::resources const& handle, raft::device_matrix_view output); void inverse_transform(raft::resources const& handle, - params config, + const params& config, raft::device_matrix_view trans_input, raft::device_matrix_view components, raft::device_vector_view singular_vals, diff --git a/cpp/src/preprocessing/pca/detail/pca.cuh b/cpp/src/preprocessing/pca/detail/pca.cuh index c5643ec24d..0aa1f555e2 100644 --- a/cpp/src/preprocessing/pca/detail/pca.cuh +++ b/cpp/src/preprocessing/pca/detail/pca.cuh @@ -16,7 +16,7 @@ namespace cuvs::preprocessing::pca::detail { /** * @brief Convert cuvs::preprocessing::pca::params to raft::linalg::paramsPCA. 
*/ -inline auto to_raft_params(params config, std::size_t n_rows, std::size_t n_cols) +inline auto to_raft_params(const params& config, std::size_t n_rows, std::size_t n_cols) -> raft::linalg::paramsPCA { raft::linalg::paramsPCA prms; @@ -33,7 +33,7 @@ inline auto to_raft_params(params config, std::size_t n_rows, std::size_t n_cols template void fit(raft::resources const& handle, - params config, + const params& config, raft::device_matrix_view input, raft::device_matrix_view components, raft::device_vector_view explained_var, @@ -58,7 +58,7 @@ void fit(raft::resources const& handle, template void fit_transform(raft::resources const& handle, - params config, + const params& config, raft::device_matrix_view input, raft::device_matrix_view trans_input, raft::device_matrix_view components, @@ -85,7 +85,7 @@ void fit_transform(raft::resources const& handle, template void transform(raft::resources const& handle, - params config, + const params& config, raft::device_matrix_view input, raft::device_matrix_view components, raft::device_vector_view singular_vals, @@ -98,7 +98,7 @@ void transform(raft::resources const& handle, template void inverse_transform(raft::resources const& handle, - params config, + const params& config, raft::device_matrix_view trans_input, raft::device_matrix_view components, raft::device_vector_view singular_vals, diff --git a/cpp/src/preprocessing/pca/pca.cu b/cpp/src/preprocessing/pca/pca.cu index 1091090815..faad547cfd 100644 --- a/cpp/src/preprocessing/pca/pca.cu +++ b/cpp/src/preprocessing/pca/pca.cu @@ -11,7 +11,7 @@ namespace cuvs::preprocessing::pca { #define CUVS_INST_PCA_FIT(DataT, IndexT) \ void fit(raft::resources const& handle, \ - params config, \ + const params& config, \ raft::device_matrix_view input, \ raft::device_matrix_view components, \ raft::device_vector_view explained_var, \ @@ -39,7 +39,7 @@ CUVS_INST_PCA_FIT(double, int64_t); #define CUVS_INST_PCA_FIT_TRANSFORM(DataT, IndexT) \ void fit_transform(raft::resources const& 
handle, \ - params config, \ + const params& config, \ raft::device_matrix_view input, \ raft::device_matrix_view trans_input, \ raft::device_matrix_view components, \ @@ -69,7 +69,7 @@ CUVS_INST_PCA_FIT_TRANSFORM(double, int64_t); #define CUVS_INST_PCA_TRANSFORM(DataT, IndexT) \ void transform(raft::resources const& handle, \ - params config, \ + const params& config, \ raft::device_matrix_view input, \ raft::device_matrix_view components, \ raft::device_vector_view singular_vals, \ @@ -85,7 +85,7 @@ CUVS_INST_PCA_TRANSFORM(double, int64_t); #define CUVS_INST_PCA_INVERSE_TRANSFORM(DataT, IndexT) \ void inverse_transform(raft::resources const& handle, \ - params config, \ + const params& config, \ raft::device_matrix_view trans_input, \ raft::device_matrix_view components, \ raft::device_vector_view singular_vals, \ From 0bfd500fd55f8294843fa70bac059eaf890a361d Mon Sep 17 00:00:00 2001 From: aamijar Date: Wed, 4 Mar 2026 09:59:05 -0800 Subject: [PATCH 11/13] remove comments --- cpp/tests/preprocessing/pca.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/tests/preprocessing/pca.cu b/cpp/tests/preprocessing/pca.cu index a99df21c2a..29f17a2813 100644 --- a/cpp/tests/preprocessing/pca.cu +++ b/cpp/tests/preprocessing/pca.cu @@ -90,8 +90,6 @@ class PcaTest : public ::testing::TestWithParam> { explained_vars_ref.data(), explained_vars_ref_h.data(), params.n_col, stream); cuvs::preprocessing::pca::params prms; - // prms.n_cols = params.n_col; - // prms.n_rows = params.n_row; prms.n_components = params.n_col; prms.whiten = false; if (params.algo == 0) { From 5144ab27c89f8c8e1198b2d4acd13b24e085e880 Mon Sep 17 00:00:00 2001 From: aamijar Date: Wed, 4 Mar 2026 10:13:13 -0800 Subject: [PATCH 12/13] remove double apis --- cpp/include/cuvs/preprocessing/pca.hpp | 39 ---------------- cpp/src/preprocessing/pca/pca.cu | 4 -- cpp/tests/preprocessing/pca.cu | 64 -------------------------- 3 files changed, 107 deletions(-) diff --git a/cpp/include/cuvs/preprocessing/pca.hpp 
b/cpp/include/cuvs/preprocessing/pca.hpp index 846264e57e..f30b245d34 100644 --- a/cpp/include/cuvs/preprocessing/pca.hpp +++ b/cpp/include/cuvs/preprocessing/pca.hpp @@ -107,17 +107,6 @@ void fit(raft::resources const& handle, raft::device_scalar_view noise_vars, bool flip_signs_based_on_U = false); -void fit(raft::resources const& handle, - const params& config, - raft::device_matrix_view input, - raft::device_matrix_view components, - raft::device_vector_view explained_var, - raft::device_vector_view explained_var_ratio, - raft::device_vector_view singular_vals, - raft::device_vector_view mu, - raft::device_scalar_view noise_vars, - bool flip_signs_based_on_U = false); - /** * @brief Perform PCA fit and transform operations. * @@ -148,18 +137,6 @@ void fit_transform(raft::resources const& handle, raft::device_scalar_view noise_vars, bool flip_signs_based_on_U = false); -void fit_transform(raft::resources const& handle, - const params& config, - raft::device_matrix_view input, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, - raft::device_vector_view explained_var, - raft::device_vector_view explained_var_ratio, - raft::device_vector_view singular_vals, - raft::device_vector_view mu, - raft::device_scalar_view noise_vars, - bool flip_signs_based_on_U = false); - /** * @brief Perform PCA transform operation. * @@ -182,14 +159,6 @@ void transform(raft::resources const& handle, raft::device_vector_view mu, raft::device_matrix_view trans_input); -void transform(raft::resources const& handle, - const params& config, - raft::device_matrix_view input, - raft::device_matrix_view components, - raft::device_vector_view singular_vals, - raft::device_vector_view mu, - raft::device_matrix_view trans_input); - /** * @brief Perform PCA inverse transform operation. 
* @@ -211,14 +180,6 @@ void inverse_transform(raft::resources const& handle, raft::device_vector_view mu, raft::device_matrix_view output); -void inverse_transform(raft::resources const& handle, - const params& config, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, - raft::device_vector_view singular_vals, - raft::device_vector_view mu, - raft::device_matrix_view output); - /** @} */ // end group pca } // namespace cuvs::preprocessing::pca diff --git a/cpp/src/preprocessing/pca/pca.cu b/cpp/src/preprocessing/pca/pca.cu index faad547cfd..ac944fddd7 100644 --- a/cpp/src/preprocessing/pca/pca.cu +++ b/cpp/src/preprocessing/pca/pca.cu @@ -34,7 +34,6 @@ namespace cuvs::preprocessing::pca { } CUVS_INST_PCA_FIT(float, int64_t); -CUVS_INST_PCA_FIT(double, int64_t); #undef CUVS_INST_PCA_FIT #define CUVS_INST_PCA_FIT_TRANSFORM(DataT, IndexT) \ @@ -64,7 +63,6 @@ CUVS_INST_PCA_FIT(double, int64_t); } CUVS_INST_PCA_FIT_TRANSFORM(float, int64_t); -CUVS_INST_PCA_FIT_TRANSFORM(double, int64_t); #undef CUVS_INST_PCA_FIT_TRANSFORM #define CUVS_INST_PCA_TRANSFORM(DataT, IndexT) \ @@ -80,7 +78,6 @@ CUVS_INST_PCA_FIT_TRANSFORM(double, int64_t); } CUVS_INST_PCA_TRANSFORM(float, int64_t); -CUVS_INST_PCA_TRANSFORM(double, int64_t); #undef CUVS_INST_PCA_TRANSFORM #define CUVS_INST_PCA_INVERSE_TRANSFORM(DataT, IndexT) \ @@ -96,7 +93,6 @@ CUVS_INST_PCA_TRANSFORM(double, int64_t); } CUVS_INST_PCA_INVERSE_TRANSFORM(float, int64_t); -CUVS_INST_PCA_INVERSE_TRANSFORM(double, int64_t); #undef CUVS_INST_PCA_INVERSE_TRANSFORM } // namespace cuvs::preprocessing::pca diff --git a/cpp/tests/preprocessing/pca.cu b/cpp/tests/preprocessing/pca.cu index 29f17a2813..81e2861140 100644 --- a/cpp/tests/preprocessing/pca.cu +++ b/cpp/tests/preprocessing/pca.cu @@ -206,10 +206,6 @@ const std::vector> inputsf2 = { {0.01f, 3 * 2, 3, 2, 1024 * 128, 1024, 128, 1234ULL, 0}, {0.01f, 3 * 2, 3, 2, 256 * 32, 256, 32, 1234ULL, 1}}; -const std::vector> inputsd2 = { - {0.01, 3 * 2, 3, 2, 1024 
* 128, 1024, 128, 1234ULL, 0}, - {0.01, 3 * 2, 3, 2, 256 * 32, 256, 32, 1234ULL, 1}}; - typedef PcaTest PcaTestValF; TEST_P(PcaTestValF, Result) { @@ -220,16 +216,6 @@ TEST_P(PcaTestValF, Result) raft::resource::get_cuda_stream(handle))); } -typedef PcaTest PcaTestValD; -TEST_P(PcaTestValD, Result) -{ - ASSERT_TRUE(devArrMatch(explained_vars.data(), - explained_vars_ref.data(), - params.n_col, - cuvs::CompareApprox(params.tolerance), - raft::resource::get_cuda_stream(handle))); -} - typedef PcaTest PcaTestLeftVecF; TEST_P(PcaTestLeftVecF, Result) { @@ -240,16 +226,6 @@ TEST_P(PcaTestLeftVecF, Result) raft::resource::get_cuda_stream(handle))); } -typedef PcaTest PcaTestLeftVecD; -TEST_P(PcaTestLeftVecD, Result) -{ - ASSERT_TRUE(devArrMatch(components.data(), - components_ref.data(), - (params.n_col * params.n_col), - cuvs::CompareApprox(params.tolerance), - raft::resource::get_cuda_stream(handle))); -} - typedef PcaTest PcaTestTransDataF; TEST_P(PcaTestTransDataF, Result) { @@ -260,16 +236,6 @@ TEST_P(PcaTestTransDataF, Result) raft::resource::get_cuda_stream(handle))); } -typedef PcaTest PcaTestTransDataD; -TEST_P(PcaTestTransDataD, Result) -{ - ASSERT_TRUE(devArrMatch(trans_data.data(), - trans_data_ref.data(), - (params.n_row * params.n_col), - cuvs::CompareApprox(params.tolerance), - raft::resource::get_cuda_stream(handle))); -} - typedef PcaTest PcaTestDataVecSmallF; TEST_P(PcaTestDataVecSmallF, Result) { @@ -280,16 +246,6 @@ TEST_P(PcaTestDataVecSmallF, Result) raft::resource::get_cuda_stream(handle))); } -typedef PcaTest PcaTestDataVecSmallD; -TEST_P(PcaTestDataVecSmallD, Result) -{ - ASSERT_TRUE(devArrMatch(data.data(), - data_back.data(), - (params.n_row * params.n_col), - cuvs::CompareApprox(params.tolerance), - raft::resource::get_cuda_stream(handle))); -} - typedef PcaTest PcaTestDataVecF; TEST_P(PcaTestDataVecF, Result) { @@ -300,34 +256,14 @@ TEST_P(PcaTestDataVecF, Result) raft::resource::get_cuda_stream(handle))); } -typedef PcaTest PcaTestDataVecD; 
-TEST_P(PcaTestDataVecD, Result) -{ - ASSERT_TRUE(devArrMatch(data2.data(), - data2_back.data(), - (params.n_row2 * params.n_col2), - cuvs::CompareApprox(params.tolerance), - raft::resource::get_cuda_stream(handle))); -} - INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestValF, ::testing::ValuesIn(inputsf2)); -INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestValD, ::testing::ValuesIn(inputsd2)); - INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestLeftVecF, ::testing::ValuesIn(inputsf2)); -INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestLeftVecD, ::testing::ValuesIn(inputsd2)); - INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecSmallF, ::testing::ValuesIn(inputsf2)); -INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecSmallD, ::testing::ValuesIn(inputsd2)); - INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestTransDataF, ::testing::ValuesIn(inputsf2)); -INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestTransDataD, ::testing::ValuesIn(inputsd2)); - INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecF, ::testing::ValuesIn(inputsf2)); -INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecD, ::testing::ValuesIn(inputsd2)); - } // end namespace cuvs::preprocessing::pca From 948882ce636ec4bd9555e6c36c8cb34d84e8bbe3 Mon Sep 17 00:00:00 2001 From: aamijar Date: Wed, 4 Mar 2026 17:30:40 -0800 Subject: [PATCH 13/13] add new gtest and refactor --- cpp/tests/preprocessing/pca.cu | 399 ++++++++++++++++----------------- 1 file changed, 198 insertions(+), 201 deletions(-) diff --git a/cpp/tests/preprocessing/pca.cu b/cpp/tests/preprocessing/pca.cu index 81e2861140..4430972911 100644 --- a/cpp/tests/preprocessing/pca.cu +++ b/cpp/tests/preprocessing/pca.cu @@ -16,6 +16,8 @@ #include #include +#include +#include #include namespace cuvs::preprocessing::pca { @@ -23,10 +25,8 @@ namespace cuvs::preprocessing::pca { template struct PcaInputs { T tolerance; - int len; int n_row; int n_col; - int len2; int n_row2; int n_col2; unsigned long long int seed; @@ -39,231 +39,228 @@ template return os; } +/** + * @brief Run fit_transform followed by 
inverse_transform. + * + * Intermediate buffers are managed internally unless the caller provides + * pre-allocated pointers via the optional parameters, in which case the + * results are written there directly. + */ +template +void pca_roundtrip(raft::resources const& handle, + T* input, + int n_rows, + int n_cols, + T* output, + int n_components, + int algo, + cudaStream_t stream, + T* components_out = nullptr, + T* explained_var_out = nullptr, + T* trans_out = nullptr) +{ + params prms; + prms.n_components = n_components; + if (algo == 0) + prms.algorithm = solver::COV_EIG_DQ; + else + prms.algorithm = solver::COV_EIG_JACOBI; + + rmm::device_uvector comp_buf(components_out ? 0 : n_components * n_cols, stream); + rmm::device_uvector ev_buf(explained_var_out ? 0 : n_components, stream); + rmm::device_uvector trans_buf(trans_out ? 0 : n_rows * n_components, stream); + + T* comp_ptr = components_out ? components_out : comp_buf.data(); + T* ev_ptr = explained_var_out ? explained_var_out : ev_buf.data(); + T* trans_ptr = trans_out ? 
trans_out : trans_buf.data(); + + rmm::device_uvector evr(n_components, stream); + rmm::device_uvector sv(n_components, stream); + rmm::device_uvector mu(n_cols, stream); + rmm::device_uvector nv(1, stream); + + auto input_view = + raft::make_device_matrix_view(input, n_rows, n_cols); + auto trans_view = + raft::make_device_matrix_view(trans_ptr, n_rows, n_components); + auto comp_view = + raft::make_device_matrix_view(comp_ptr, n_components, n_cols); + auto ev_view = raft::make_device_vector_view(ev_ptr, n_components); + auto evr_view = raft::make_device_vector_view(evr.data(), n_components); + auto sv_view = raft::make_device_vector_view(sv.data(), n_components); + auto mu_view = raft::make_device_vector_view(mu.data(), n_cols); + auto nv_view = raft::make_device_scalar_view(nv.data()); + auto output_view = + raft::make_device_matrix_view(output, n_rows, n_cols); + + fit_transform( + handle, prms, input_view, trans_view, comp_view, ev_view, evr_view, sv_view, mu_view, nv_view); + inverse_transform(handle, prms, trans_view, comp_view, sv_view, mu_view, output_view); +} + template class PcaTest : public ::testing::TestWithParam> { public: PcaTest() - : params(::testing::TestWithParam>::GetParam()), + : params_(::testing::TestWithParam>::GetParam()), stream(raft::resource::get_cuda_stream(handle)), - explained_vars(params.n_col, stream), - explained_vars_ref(params.n_col, stream), - components(params.n_col * params.n_col, stream), - components_ref(params.n_col * params.n_col, stream), - trans_data(params.len, stream), - trans_data_ref(params.len, stream), - data(params.len, stream), - data_back(params.len, stream), - data2(params.len2, stream), - data2_back(params.len2, stream) + explained_vars(params_.n_col, stream), + explained_vars_ref(params_.n_col, stream), + components(params_.n_col * params_.n_col, stream), + components_ref(params_.n_col * params_.n_col, stream), + trans_data(params_.n_row * params_.n_col, stream), + trans_data_ref(params_.n_row * 
params_.n_col, stream), + data(params_.n_row * params_.n_col, stream), + data_back(params_.n_row * params_.n_col, stream), + data2(params_.n_row2 * params_.n_col2, stream), + data2_back(params_.n_row2 * params_.n_col2, stream) { - basicTest(); - advancedTest(); } protected: - void basicTest() + void SetUp() override { - raft::random::Rng r(params.seed, raft::random::GenPC); - int len = params.len; - - std::vector data_h = {1.0, 2.0, 5.0, 4.0, 2.0, 1.0}; - data_h.resize(len); - raft::update_device(data.data(), data_h.data(), len, stream); - - std::vector trans_data_ref_h = {-2.3231, -0.3517, 2.6748, 0.3979, -0.6571, 0.2592}; - trans_data_ref_h.resize(len); - raft::update_device(trans_data_ref.data(), trans_data_ref_h.data(), len, stream); - - int len_comp = params.n_col * params.n_col; - rmm::device_uvector explained_var_ratio(params.n_col, stream); - rmm::device_uvector singular_vals(params.n_col, stream); - rmm::device_uvector mean(params.n_col, stream); - rmm::device_uvector noise_vars(1, stream); - - std::vector components_ref_h = {0.8163, 0.5776, -0.5776, 0.8163}; - components_ref_h.resize(len_comp); - std::vector explained_vars_ref_h = {6.338, 0.3287}; - explained_vars_ref_h.resize(params.n_col); - - raft::update_device(components_ref.data(), components_ref_h.data(), len_comp, stream); - raft::update_device( - explained_vars_ref.data(), explained_vars_ref_h.data(), params.n_col, stream); - - cuvs::preprocessing::pca::params prms; - prms.n_components = params.n_col; - prms.whiten = false; - if (params.algo == 0) { - prms.algorithm = solver::COV_EIG_DQ; - } else { - prms.algorithm = solver::COV_EIG_JACOBI; + int len = params_.n_row * params_.n_col; + int len2 = params_.n_row2 * params_.n_col2; + + // --- basic test: all components, known reference data --- + { + std::vector data_h = {1.0, 2.0, 5.0, 4.0, 2.0, 1.0}; + data_h.resize(len); + raft::update_device(data.data(), data_h.data(), len, stream); + + std::vector trans_data_ref_h = {-2.3231, -0.3517, 2.6748, 
0.3979, -0.6571, 0.2592}; + trans_data_ref_h.resize(len); + raft::update_device(trans_data_ref.data(), trans_data_ref_h.data(), len, stream); + + int len_comp = params_.n_col * params_.n_col; + + std::vector components_ref_h = {0.8163, 0.5776, -0.5776, 0.8163}; + components_ref_h.resize(len_comp); + std::vector explained_vars_ref_h = {6.338, 0.3287}; + explained_vars_ref_h.resize(params_.n_col); + + raft::update_device(components_ref.data(), components_ref_h.data(), len_comp, stream); + raft::update_device( + explained_vars_ref.data(), explained_vars_ref_h.data(), params_.n_col, stream); + + pca_roundtrip(handle, + data.data(), + params_.n_row, + params_.n_col, + data_back.data(), + params_.n_col, + params_.algo, + stream, + components.data(), + explained_vars.data(), + trans_data.data()); } - auto input_view = raft::make_device_matrix_view( - data.data(), params.n_row, params.n_col); - auto components_view = raft::make_device_matrix_view( - components.data(), prms.n_components, params.n_col); - auto explained_var_view = - raft::make_device_vector_view(explained_vars.data(), prms.n_components); - auto explained_var_ratio_view = - raft::make_device_vector_view(explained_var_ratio.data(), prms.n_components); - auto singular_vals_view = - raft::make_device_vector_view(singular_vals.data(), prms.n_components); - auto mu_view = raft::make_device_vector_view(mean.data(), params.n_col); - auto noise_vars_view = raft::make_device_scalar_view(noise_vars.data()); - - cuvs::preprocessing::pca::fit(handle, - prms, - input_view, - components_view, - explained_var_view, - explained_var_ratio_view, - singular_vals_view, - mu_view, - noise_vars_view); - - auto trans_data_view = raft::make_device_matrix_view( - trans_data.data(), params.n_row, prms.n_components); - - cuvs::preprocessing::pca::transform( - handle, prms, input_view, components_view, singular_vals_view, mu_view, trans_data_view); - - auto data_back_view = raft::make_device_matrix_view( - data_back.data(), 
params.n_row, params.n_col); - - cuvs::preprocessing::pca::inverse_transform( - handle, prms, trans_data_view, components_view, singular_vals_view, mu_view, data_back_view); - } + // --- advanced test: all components, random data --- + { + raft::random::Rng r(params_.seed, raft::random::GenPC); + r.uniform(data2.data(), len2, T(-1.0), T(1.0), stream); + + pca_roundtrip(handle, + data2.data(), + params_.n_row2, + params_.n_col2, + data2_back.data(), + params_.n_col2, + params_.algo, + stream); + } - void advancedTest() - { - raft::random::Rng r(params.seed, raft::random::GenPC); - int len = params.len2; - - cuvs::preprocessing::pca::params prms; - prms.n_components = params.n_col2; - prms.whiten = false; - if (params.algo == 0) { - prms.algorithm = solver::COV_EIG_DQ; - } else if (params.algo == 1) { - prms.algorithm = solver::COV_EIG_JACOBI; + // --- dim reduction test: n_components < n_cols, random data --- + { + int n_components = std::max(1, params_.n_col2 / 4); + + rmm::device_uvector input(len2, stream); + rmm::device_uvector input_copy(len2, stream); + rmm::device_uvector recon(len2, stream); + + raft::random::Rng rng(params_.seed + 1, raft::random::GenPC); + rng.uniform(input.data(), len2, T(-1.0), T(1.0), stream); + raft::copy(input_copy.data(), input.data(), len2, stream); + + pca_roundtrip(handle, + input.data(), + params_.n_row2, + params_.n_col2, + recon.data(), + n_components, + params_.algo, + stream); + + std::vector orig_h(len2); + std::vector recon_h(len2); + raft::update_host(orig_h.data(), input_copy.data(), len2, stream); + raft::update_host(recon_h.data(), recon.data(), len2, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + + max_recon_err = T(0); + for (int i = 0; i < len2; ++i) { + max_recon_err = std::max(max_recon_err, std::abs(orig_h[i] - recon_h[i])); + } } + } - r.uniform(data2.data(), len, T(-1.0), T(1.0), stream); - rmm::device_uvector data2_trans(params.n_row2 * prms.n_components, stream); - - int len_comp = params.n_col2 * 
prms.n_components; - rmm::device_uvector components2(len_comp, stream); - rmm::device_uvector explained_vars2(prms.n_components, stream); - rmm::device_uvector explained_var_ratio2(prms.n_components, stream); - rmm::device_uvector singular_vals2(prms.n_components, stream); - rmm::device_uvector mean2(params.n_col2, stream); - rmm::device_uvector noise_vars2(1, stream); - - auto input_view = raft::make_device_matrix_view( - data2.data(), params.n_row2, params.n_col2); - auto trans_view = raft::make_device_matrix_view( - data2_trans.data(), params.n_row2, prms.n_components); - auto comp_view = raft::make_device_matrix_view( - components2.data(), prms.n_components, params.n_col2); - auto ev_view = - raft::make_device_vector_view(explained_vars2.data(), prms.n_components); - auto evr_view = - raft::make_device_vector_view(explained_var_ratio2.data(), prms.n_components); - auto sv_view = - raft::make_device_vector_view(singular_vals2.data(), prms.n_components); - auto mu_view = raft::make_device_vector_view(mean2.data(), params.n_col2); - auto noise_view = raft::make_device_scalar_view(noise_vars2.data()); - - cuvs::preprocessing::pca::fit_transform(handle, - prms, - input_view, - trans_view, - comp_view, - ev_view, - evr_view, - sv_view, - mu_view, - noise_view); - - auto data2_back_view = raft::make_device_matrix_view( - data2_back.data(), params.n_row2, params.n_col2); - - cuvs::preprocessing::pca::inverse_transform( - handle, prms, trans_view, comp_view, sv_view, mu_view, data2_back_view); + void testPca() + { + auto s = raft::resource::get_cuda_stream(handle); + + ASSERT_TRUE(devArrMatch(explained_vars.data(), + explained_vars_ref.data(), + params_.n_col, + cuvs::CompareApprox(params_.tolerance), + s)); + + ASSERT_TRUE(devArrMatch(components.data(), + components_ref.data(), + (params_.n_col * params_.n_col), + cuvs::CompareApprox(params_.tolerance), + s)); + + ASSERT_TRUE(devArrMatch(trans_data.data(), + trans_data_ref.data(), + (params_.n_row * params_.n_col), + 
cuvs::CompareApprox(params_.tolerance), + s)); + + ASSERT_TRUE(devArrMatch(data.data(), + data_back.data(), + (params_.n_row * params_.n_col), + cuvs::CompareApprox(params_.tolerance), + s)); + + ASSERT_TRUE(devArrMatch(data2.data(), + data2_back.data(), + (params_.n_row2 * params_.n_col2), + cuvs::CompareApprox(params_.tolerance), + s)); + + EXPECT_GT(max_recon_err, T(1e-5)) << "Error should be non-zero when n_components < n_cols"; + EXPECT_LT(max_recon_err, T(2.0)) << "Reconstruction error should be bounded"; } - protected: + private: raft::device_resources handle; cudaStream_t stream; - PcaInputs params; + PcaInputs params_; + T max_recon_err = T(0); rmm::device_uvector explained_vars, explained_vars_ref, components, components_ref, trans_data, trans_data_ref, data, data_back, data2, data2_back; }; -const std::vector> inputsf2 = { - {0.01f, 3 * 2, 3, 2, 1024 * 128, 1024, 128, 1234ULL, 0}, - {0.01f, 3 * 2, 3, 2, 256 * 32, 256, 32, 1234ULL, 1}}; - -typedef PcaTest PcaTestValF; -TEST_P(PcaTestValF, Result) -{ - ASSERT_TRUE(devArrMatch(explained_vars.data(), - explained_vars_ref.data(), - params.n_col, - cuvs::CompareApprox(params.tolerance), - raft::resource::get_cuda_stream(handle))); -} - -typedef PcaTest PcaTestLeftVecF; -TEST_P(PcaTestLeftVecF, Result) -{ - ASSERT_TRUE(devArrMatch(components.data(), - components_ref.data(), - (params.n_col * params.n_col), - cuvs::CompareApprox(params.tolerance), - raft::resource::get_cuda_stream(handle))); -} - -typedef PcaTest PcaTestTransDataF; -TEST_P(PcaTestTransDataF, Result) -{ - ASSERT_TRUE(devArrMatch(trans_data.data(), - trans_data_ref.data(), - (params.n_row * params.n_col), - cuvs::CompareApprox(params.tolerance), - raft::resource::get_cuda_stream(handle))); -} - -typedef PcaTest PcaTestDataVecSmallF; -TEST_P(PcaTestDataVecSmallF, Result) -{ - ASSERT_TRUE(devArrMatch(data.data(), - data_back.data(), - (params.n_row * params.n_col), - cuvs::CompareApprox(params.tolerance), - raft::resource::get_cuda_stream(handle))); 
-} - -typedef PcaTest PcaTestDataVecF; -TEST_P(PcaTestDataVecF, Result) -{ - ASSERT_TRUE(devArrMatch(data2.data(), - data2_back.data(), - (params.n_row2 * params.n_col2), - cuvs::CompareApprox(params.tolerance), - raft::resource::get_cuda_stream(handle))); -} - -INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestValF, ::testing::ValuesIn(inputsf2)); - -INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestLeftVecF, ::testing::ValuesIn(inputsf2)); - -INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecSmallF, ::testing::ValuesIn(inputsf2)); +const std::vector> inputsf2 = {{0.01f, 3, 2, 1024, 128, 1234ULL, 0}, + {0.01f, 3, 2, 256, 32, 1234ULL, 1}}; -INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestTransDataF, ::testing::ValuesIn(inputsf2)); +typedef PcaTest PcaTestF; +TEST_P(PcaTestF, Result) { this->testPca(); } -INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecF, ::testing::ValuesIn(inputsf2)); +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestF, ::testing::ValuesIn(inputsf2)); } // end namespace cuvs::preprocessing::pca