diff --git a/src/pcms/interpolator/mls_interpolation.cpp b/src/pcms/interpolator/mls_interpolation.cpp index 113986a1..3bbf9767 100644 --- a/src/pcms/interpolator/mls_interpolation.cpp +++ b/src/pcms/interpolator/mls_interpolation.cpp @@ -105,7 +105,8 @@ Write mls_interpolation(const Reals source_values, const Reals source_coordinates, const Reals target_coordinates, const SupportResults& support, const LO& dim, - const LO& degree, RadialBasisFunction bf) + const LO& degree, RadialBasisFunction bf, + double lambda_factor) { const auto nvertices_target = target_coordinates.size() / dim; @@ -116,19 +117,19 @@ Write mls_interpolation(const Reals source_values, case RadialBasisFunction::RBF_GAUSSIAN: interpolated_values = detail::mls_interpolation( source_values, source_coordinates, target_coordinates, support, dim, - degree, RBF_GAUSSIAN{}); + degree, RBF_GAUSSIAN{}, lambda_factor); break; case RadialBasisFunction::RBF_C4: interpolated_values = detail::mls_interpolation( source_values, source_coordinates, target_coordinates, support, dim, - degree, RBF_C4{}); + degree, RBF_C4{}, lambda_factor); break; case RadialBasisFunction::RBF_CONST: interpolated_values = detail::mls_interpolation( source_values, source_coordinates, target_coordinates, support, dim, - degree, RBF_CONST{}); + degree, RBF_CONST{}, lambda_factor); break; } diff --git a/src/pcms/interpolator/mls_interpolation.hpp b/src/pcms/interpolator/mls_interpolation.hpp index 17b085aa..4a3a5503 100644 --- a/src/pcms/interpolator/mls_interpolation.hpp +++ b/src/pcms/interpolator/mls_interpolation.hpp @@ -19,6 +19,7 @@ Write mls_interpolation(const Reals source_values, const Reals source_coordinates, const Reals target_coordinates, const SupportResults& support, const LO& dim, - const LO& degree, RadialBasisFunction bf); + const LO& degree, RadialBasisFunction bf, + double lambda_factor = 0); } // namespace pcms #endif diff --git a/src/pcms/interpolator/mls_interpolation_impl.hpp b/src/pcms/interpolator/mls_interpolation_impl.hpp index 5a79f64a..b8bd2553 100644 --- a/src/pcms/interpolator/mls_interpolation_impl.hpp +++ b/src/pcms/interpolator/mls_interpolation_impl.hpp @@ -231,6 +231,14 @@ void scale_column_trans_matrix(const ScratchMatView& matrix, } } +KOKKOS_INLINE_FUNCTION +void add_regularization(const member_type& team, ScratchMatView& square_matrix, + Real lambda_factor) +{ + Kokkos::parallel_for(Kokkos::TeamThreadRange(team, square_matrix.extent(0)), + [=](int i) { square_matrix(i, i) += lambda_factor; }); +} + /** * @struct ResultConvertNormal * @brief Stores the results of matrix and vector transformations. @@ -283,7 +291,8 @@ KOKKOS_INLINE_FUNCTION ResultConvertNormal convert_normal_equation(const ScratchMatView& matrix, const ScratchVecView& weight_vector, const ScratchVecView& rhs, - member_type team) + member_type team, + double lambda_factor) { int m = matrix.extent(0); @@ -319,6 +328,10 @@ ResultConvertNormal convert_normal_equation(const ScratchMatView& matrix, team.team_barrier(); + add_regularization(team, square_matrix, lambda_factor); + + team.team_barrier(); + KokkosBlas::Experimental:: Gemv::invoke( team, 'N', 1.0, scaled_matrix, rhs, 0.0, transformed_rhs); @@ -419,7 +432,8 @@ void mls_interpolation(RealConstDefaultScalarArrayView source_values, RealConstDefaultScalarArrayView target_coordinates, const SupportResults& support, const LO& dim, const LO& degree, Func rbf_func, - RealDefaultScalarArrayView approx_target_values) + RealDefaultScalarArrayView approx_target_values, + double lambda_factor) { PCMS_FUNCTION_TIMER; static_assert(std::is_invocable_r_v, @@ -550,8 +564,8 @@ void mls_interpolation(RealConstDefaultScalarArrayView source_values, team.team_barrier(); - auto result = convert_normal_equation(vandermonde_matrix, phi_vector, - support_values, team); + auto result = convert_normal_equation( + vandermonde_matrix, phi_vector, support_values, team, lambda_factor); team.team_barrier(); @@ -593,7 +607,8 @@ Write mls_interpolation(const Reals source_values, const Reals source_coordinates, const Reals target_coordinates, const SupportResults& support, const LO& dim, - const LO& degree, Func rbf_func) + const LO& degree, Func rbf_func, + double lambda_factor) { const auto nvertices_source = source_coordinates.size() / dim; @@ -619,7 +634,7 @@ Write mls_interpolation(const Reals source_values, mls_interpolation(source_values_array_view, source_coordinates_array_view, target_coordinates_array_view, support, dim, degree, - rbf_func, interpolated_values_array_view); + rbf_func, interpolated_values_array_view, lambda_factor); return interpolated_values; } diff --git a/test/test_linear_solver.cpp b/test/test_linear_solver.cpp index 2f8664ec..e11e8fbb 100644 --- a/test/test_linear_solver.cpp +++ b/test/test_linear_solver.cpp @@ -103,8 +103,9 @@ TEST_CASE("solver test") team.team_barrier(); + double lambda = 0.5; auto result = convert_normal_equation(vandermonde_matrix, phi, - support_values, team); + support_values, team, lambda); team.team_barrier(); @@ -159,10 +160,10 @@ TEST_CASE("solver test") Kokkos::View expected_solution( "expected solution", nvertices_target, size); - expected_lhs_matrix(0, 0, 0) = 35.0; + expected_lhs_matrix(0, 0, 0) = 35.5; expected_lhs_matrix(0, 0, 1) = 44.0; expected_lhs_matrix(0, 1, 0) = 44.0; - expected_lhs_matrix(0, 1, 1) = 56.0; + expected_lhs_matrix(0, 1, 1) = 56.5; expected_rhs_vector(0, 0) = 76.0; expected_rhs_vector(0, 1) = 100.0; @@ -174,13 +175,13 @@ TEST_CASE("solver test") expected_scaled_matrix(0, 1, 1) = 4.0; expected_scaled_matrix(0, 1, 2) = 6.0; - expected_solution(0, 0) = -6.0; - expected_solution(0, 1) = 6.5; + expected_solution(0, 0) = -1.519713; + expected_solution(0, 1) = 2.953405; - expected_lhs_matrix(1, 0, 0) = 30.0; + expected_lhs_matrix(1, 0, 0) = 30.5; expected_lhs_matrix(1, 0, 1) = 20.0; expected_lhs_matrix(1, 1, 0) = 20.0; - expected_lhs_matrix(1, 1, 1) = 30.0; + expected_lhs_matrix(1, 1, 1) = 30.5; expected_rhs_vector(1, 0) = 70.0; expected_rhs_vector(1, 1) = 40.0; @@ -192,8 +193,8 @@ TEST_CASE("solver test") expected_scaled_matrix(1, 1, 1) = -2.0; expected_scaled_matrix(1, 1, 2) = 1.0; - expected_solution(1, 0) = 2.6; - expected_solution(1, 1) = -0.4; + expected_solution(1, 0) = 2.517680; + expected_solution(1, 1) = -0.339463; for (int i = 0; i < nvertices_target; ++i) { for (int j = 0; j < size; ++j) { @@ -203,7 +204,7 @@ TEST_CASE("solver test") i, j, l); REQUIRE_THAT( expected_scaled_matrix(i, j, l), - Catch::Matchers::WithinAbs(host_result_scaled(i, j, l), 1E-10)); + Catch::Matchers::WithinAbs(host_result_scaled(i, j, l), 1E-6)); } } for (int j = 0; j < size; ++j) { @@ -213,14 +214,14 @@ TEST_CASE("solver test") k); REQUIRE_THAT( expected_lhs_matrix(i, j, k), - Catch::Matchers::WithinAbs(host_result_lhs(i, j, k), 1E-10)); + Catch::Matchers::WithinAbs(host_result_lhs(i, j, k), 1E-6)); } } for (int j = 0; j < size; ++j) { printf("A^T Q b: (%f,%f) at (%d,%d)\n", expected_rhs_vector(i, j), host_result_rhs(i, j), i, j); REQUIRE_THAT(expected_rhs_vector(i, j), - Catch::Matchers::WithinAbs(host_result_rhs(i, j), 1E-10)); + Catch::Matchers::WithinAbs(host_result_rhs(i, j), 1E-6)); } for (int j = 0; j < size; ++j) { @@ -228,7 +229,7 @@ TEST_CASE("solver test") host_result_solution(i, j), i, j); REQUIRE_THAT( expected_solution(i, j), - Catch::Matchers::WithinAbs(host_result_solution(i, j), 1E-10)); + Catch::Matchers::WithinAbs(host_result_solution(i, j), 1E-6)); } } }