diff --git a/include/kde1d/kde1d.hpp b/include/kde1d/kde1d.hpp index 71da074..f6440a2 100644 --- a/include/kde1d/kde1d.hpp +++ b/include/kde1d/kde1d.hpp @@ -26,14 +26,16 @@ class Kde1d VarType type, double multiplier = 1.0, double bandwidth = NAN, - size_t degree = 2); + size_t degree = 2, + size_t grid_size = 400); Kde1d(double xmin = NAN, double xmax = NAN, std::string type = "continuous", double multiplier = 1.0, double bandwidth = NAN, - size_t degree = 2); + size_t degree = 2, + size_t grid_size = 400); Kde1d(const interp::InterpolationGrid& grid, double xmin, @@ -72,6 +74,8 @@ class Kde1d double get_multiplier() const { return multiplier_; } double get_bandwidth() const { return bandwidth_; } size_t get_degree() const { return degree_; } + size_t get_grid_size() const { return grid_size_; } + size_t get_actual_grid_size() const { return grid_.get_grid_points().size(); } double get_edf() const { return edf_; } double get_loglik() const { return loglik_; } void set_xmin_xmax(double xmin = NAN, double xmax = NAN); @@ -99,6 +103,7 @@ class Kde1d double multiplier_; double bandwidth_; size_t degree_; + size_t grid_size_; double prob0_{ 0.0 }; double loglik_{ NAN }; double edf_{ NAN }; @@ -161,18 +166,21 @@ class Kde1d //! @param bandwidth positive bandwidth parameter (`NaN` means automatic //! selection). //! @param degree degree of the local polynomial. +//! @param grid_size number of grid points for the interpolation grid. inline Kde1d::Kde1d(double xmin, double xmax, VarType type, double multiplier, double bandwidth, - size_t degree) + size_t degree, + size_t grid_size) : xmin_(xmin) , xmax_(xmax) , type_(type) , multiplier_(multiplier) , bandwidth_(bandwidth) , degree_(degree) + , grid_size_(grid_size) { this->check_xmin_xmax(xmin, xmax); if (multiplier <= 0.0) { @@ -184,6 +192,9 @@ inline Kde1d::Kde1d(double xmin, if (degree_ > 2) { throw std::invalid_argument("degree must be 0, 1 or 2"); } + if (grid_size_ < 4) { + throw std::invalid_argument("grid_size must be at least 4"); + } } //! construct model from an already fit interpolation grid. @@ -206,6 +217,7 @@ inline Kde1d::Kde1d(const interp::InterpolationGrid& grid, , xmin_(xmin) , xmax_(xmax) , type_(type) + , grid_size_(grid.get_grid_points().size()) , prob0_(prob0) { this->check_xmin_xmax(xmin, xmax); @@ -227,13 +239,21 @@ inline Kde1d::Kde1d(const interp::InterpolationGrid& grid, //! @param bandwidth positive bandwidth parameter (`NaN` means automatic //! selection). //! @param degree degree of the local polynomial. +//! @param grid_size number of grid points for the interpolation grid. inline Kde1d::Kde1d(double xmin, double xmax, std::string type, double multiplier, double bandwidth, - size_t degree) - : Kde1d(xmin, xmax, this->as_enum(type), multiplier, bandwidth, degree) + size_t degree, + size_t grid_size) + : Kde1d(xmin, + xmax, + this->as_enum(type), + multiplier, + bandwidth, + degree, + grid_size) { } @@ -582,7 +602,7 @@ Kde1d::fit_lp(const Eigen::VectorXd& x, { size_t m = grid_points.size(); fft::KdeFFT kde_fft( - x, bandwidth_, grid_points(0), grid_points(m - 1), weights); + x, bandwidth_, grid_points(0), grid_points(m - 1), weights, m - 1); Eigen::VectorXd f0 = kde_fft.kde_drv(0); Eigen::VectorXd f1(f0.size()), f2(f0.size()); @@ -761,7 +781,7 @@ Kde1d::boundary_correct(const Eigen::VectorXd& x, const Eigen::VectorXd& fhat) //! constructs a grid later used for interpolation //! @param x vector of observations. -//! @return a grid of size 50. +//! @return a grid of size 400. inline Eigen::VectorXd Kde1d::construct_grid_points(const Eigen::VectorXd& x) { @@ -771,7 +791,7 @@ Kde1d::construct_grid_points(const Eigen::VectorXd& x) rng(0) -= 4 * bandwidth_; rng(1) += 4 * bandwidth_; } - auto zgrid = Eigen::VectorXd::LinSpaced(401, rng(0), rng(1)); + auto zgrid = Eigen::VectorXd::LinSpaced(grid_size_ + 1, rng(0), rng(1)); return boundary_transform(zgrid, true); } diff --git a/include/kde1d/kdefft.hpp b/include/kde1d/kdefft.hpp index 5a9bd1a..1434486 100644 --- a/include/kde1d/kdefft.hpp +++ b/include/kde1d/kdefft.hpp @@ -19,7 +19,8 @@ class KdeFFT double bandwidth, double lower, double upper, - const Eigen::VectorXd& weights = Eigen::VectorXd()); + const Eigen::VectorXd& weights = Eigen::VectorXd(), + size_t num_bins = 400); Eigen::VectorXd kde_drv(unsigned drv) const; Eigen::VectorXd get_bin_counts() const { return bin_counts_; }; @@ -29,7 +30,7 @@ class KdeFFT double bandwidth_; double lower_; double upper_; - static constexpr size_t num_bins_{ 400 }; + size_t num_bins_; Eigen::VectorXd bin_counts_; }; @@ -38,14 +39,17 @@ class KdeFFT //! @param lower lower bound of the grid. //! @param upper bound of the grid. //! @param weigths optional vector of weights for each observation. +//! @param num_bins number of bins for the FFT grid. inline KdeFFT::KdeFFT(const Eigen::VectorXd& x, double bandwidth, double lower, double upper, - const Eigen::VectorXd& weights) + const Eigen::VectorXd& weights, + size_t num_bins) : bandwidth_(bandwidth) , lower_(lower) , upper_(upper) + , num_bins_(num_bins) { if (weights.size() > 0 && (weights.size() != x.size())) throw std::invalid_argument("x and weights must have the same size"); @@ -65,7 +69,7 @@ inline KdeFFT::KdeFFT(const Eigen::VectorXd& x, inline Eigen::VectorXd KdeFFT::kde_drv(unsigned drv) const { - double delta = (upper_ - lower_) / num_bins_; + double delta = (upper_ - lower_) / static_cast(num_bins_); double tau = 4.0 + drv; size_t L = static_cast(std::floor(tau * bandwidth_ / delta)); L = std::min(L, num_bins_ + 1); diff --git a/test/test.cpp b/test/test.cpp index 6327ea5..d6dc704 100644 --- a/test/test.cpp +++ b/test/test.cpp @@ -22,6 +22,144 @@ size_t nlevels = 50; Eigen::VectorXd x_d = (x_cb.array() * (static_cast(nlevels) - 1)).round(); +TEST_CASE("grid_size parameter", "[grid-size]") +{ + + SECTION("constructor accepts grid_size parameter") + { + // Test VarType constructor + kde1d::Kde1d fit1(NAN, NAN, kde1d::VarType::continuous, 1.0, NAN, 2, 100); + CHECK(fit1.get_grid_size() == 100); + + // Test string constructor + kde1d::Kde1d fit2(NAN, NAN, "continuous", 1.0, NAN, 2, 200); + CHECK(fit2.get_grid_size() == 200); + } + + SECTION("grid_size validation") + { + // Should throw for grid_size < 4 + CHECK_THROWS(kde1d::Kde1d(NAN, NAN, "continuous", 1.0, NAN, 2, 3)); + CHECK_THROWS(kde1d::Kde1d(NAN, NAN, "continuous", 1.0, NAN, 2, 0)); + + // Should work for grid_size >= 4 + CHECK_NOTHROW(kde1d::Kde1d(NAN, NAN, "continuous", 1.0, NAN, 2, 4)); + CHECK_NOTHROW(kde1d::Kde1d(NAN, NAN, "continuous", 1.0, NAN, 2, 50)); + } + + SECTION("default grid_size is 400") + { + kde1d::Kde1d fit; // Use default constructor + CHECK(fit.get_grid_size() == 400); + } + + SECTION("grid_size affects interpolation grid size") + { + // Just test that the grid_size parameter is stored correctly + std::vector grid_sizes = {50, 100, 200}; + + for (size_t requested_size : grid_sizes) { + kde1d::Kde1d fit(NAN, NAN, "continuous", 1.0, NAN, 2, requested_size); + + // Check that requested grid size is stored correctly + CHECK(fit.get_grid_size() == requested_size); + } + } + + SECTION("grid_size works after fitting") + { + // Test various grid sizes to ensure they work properly now + std::vector test_sizes = {50, 100, 200, 400, 600}; + + for (size_t grid_size : test_sizes) { + kde1d::Kde1d fit(NAN, NAN, "continuous", 1.0, NAN, 2, grid_size); + CHECK_NOTHROW(fit.fit(x_ub)); + + // Check that requested grid size is stored correctly + CHECK(fit.get_grid_size() == grid_size); + + // Check that we can call methods that depend on the fitted model + CHECK_NOTHROW(fit.pdf(x_ub)); + CHECK_NOTHROW(fit.cdf(x_ub)); + CHECK_NOTHROW(fit.quantile(ugrid)); + + // Check that actual grid size matches requested size + size_t actual_size = fit.get_actual_grid_size(); + CHECK(actual_size == grid_size + 1); // Grid points = grid_size + 1 + } + } + + SECTION("grid_size affects estimation with different data types") + { + size_t test_grid_size = 150; + + // Continuous data + kde1d::Kde1d fit_cont(NAN, NAN, "continuous", 1.0, NAN, 2, test_grid_size); + CHECK_NOTHROW(fit_cont.fit(x_ub)); + CHECK(fit_cont.get_grid_size() == test_grid_size); + + // Discrete data + kde1d::Kde1d fit_disc(NAN, NAN, "discrete", 1.0, NAN, 2, test_grid_size); + CHECK_NOTHROW(fit_disc.fit(x_d)); + CHECK(fit_disc.get_grid_size() == test_grid_size); + + // Zero-inflated data + Eigen::VectorXd x_zi = x_lb; + x_zi.head(n_sample / 4).setZero(); + kde1d::Kde1d fit_zi(0, NAN, "zero-inflated", 1.0, NAN, 2, test_grid_size); + CHECK_NOTHROW(fit_zi.fit(x_zi)); + CHECK(fit_zi.get_grid_size() == test_grid_size); + } + + SECTION("grid_size works with boundaries") + { + size_t test_grid_size = 120; + + // Left boundary + kde1d::Kde1d fit_lb(0, NAN, "continuous", 1.0, NAN, 2, test_grid_size); + CHECK_NOTHROW(fit_lb.fit(x_lb)); + CHECK(fit_lb.get_grid_size() == test_grid_size); + + // Right boundary + kde1d::Kde1d fit_rb(NAN, 0, "continuous", 1.0, NAN, 2, test_grid_size); + CHECK_NOTHROW(fit_rb.fit(x_rb)); + CHECK(fit_rb.get_grid_size() == test_grid_size); + + // Both boundaries + kde1d::Kde1d fit_bb(0, 1, "continuous", 1.0, NAN, 2, test_grid_size); + CHECK_NOTHROW(fit_bb.fit(x_cb)); + CHECK(fit_bb.get_grid_size() == test_grid_size); + } + + SECTION("grid_size affects estimation accuracy") + { + auto points = stats::qnorm(upoints); + auto target = stats::dnorm(points); + + // Test with small grid size + kde1d::Kde1d fit_small(NAN, NAN, "continuous", 1.0, NAN, 2, 50); + fit_small.fit(x_ub); + auto pdf_small = fit_small.pdf(points); + + // Test with large grid size + kde1d::Kde1d fit_large(NAN, NAN, "continuous", 1.0, NAN, 2, 800); + fit_large.fit(x_ub); + auto pdf_large = fit_large.pdf(points); + + // Both should be reasonable approximations + double error_small = (pdf_small - target).array().abs().mean(); + double error_large = (pdf_large - target).array().abs().mean(); + + // Both errors should be reasonable (less than the tolerance used elsewhere) + CHECK(error_small <= pdf_tol); + CHECK(error_large <= pdf_tol); + + // Generally, larger grid should perform at least as well or better + // (though for very large grids, numerical issues might make this not always true) + CHECK(error_large <= error_small * 2.0); // Allow reasonable tolerance + } +} + TEST_CASE("misc checks", "[input-checks][argument-checks]") {