Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions include/kde1d/kde1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@ class Kde1d
VarType type,
double multiplier = 1.0,
double bandwidth = NAN,
size_t degree = 2);
size_t degree = 2,
size_t grid_size = 400);

Kde1d(double xmin = NAN,
double xmax = NAN,
std::string type = "continuous",
double multiplier = 1.0,
double bandwidth = NAN,
size_t degree = 2);
size_t degree = 2,
size_t grid_size = 400);

Kde1d(const interp::InterpolationGrid& grid,
double xmin,
Expand Down Expand Up @@ -72,6 +74,8 @@ class Kde1d
double get_multiplier() const { return multiplier_; }
double get_bandwidth() const { return bandwidth_; }
size_t get_degree() const { return degree_; }
size_t get_grid_size() const { return grid_size_; }
size_t get_actual_grid_size() const { return grid_.get_grid_points().size(); }
double get_edf() const { return edf_; }
double get_loglik() const { return loglik_; }
void set_xmin_xmax(double xmin = NAN, double xmax = NAN);
Expand Down Expand Up @@ -99,6 +103,7 @@ class Kde1d
double multiplier_;
double bandwidth_;
size_t degree_;
size_t grid_size_;
double prob0_{ 0.0 };
double loglik_{ NAN };
double edf_{ NAN };
Expand Down Expand Up @@ -161,18 +166,21 @@ class Kde1d
//! @param bandwidth positive bandwidth parameter (`NaN` means automatic
//! selection).
//! @param degree degree of the local polynomial.
//! @param grid_size number of intervals of the interpolation grid (the grid
//!   has `grid_size + 1` points); must be at least 4.
inline Kde1d::Kde1d(double xmin,
double xmax,
VarType type,
double multiplier,
double bandwidth,
size_t degree)
size_t degree,
size_t grid_size)
: xmin_(xmin)
, xmax_(xmax)
, type_(type)
, multiplier_(multiplier)
, bandwidth_(bandwidth)
, degree_(degree)
, grid_size_(grid_size)
{
this->check_xmin_xmax(xmin, xmax);
if (multiplier <= 0.0) {
Expand All @@ -184,6 +192,9 @@ inline Kde1d::Kde1d(double xmin,
if (degree_ > 2) {
throw std::invalid_argument("degree must be 0, 1 or 2");
}
if (grid_size_ < 4) {
throw std::invalid_argument("grid_size must be at least 4");
}
}

//! construct model from an already fit interpolation grid.
Expand All @@ -206,6 +217,7 @@ inline Kde1d::Kde1d(const interp::InterpolationGrid& grid,
, xmin_(xmin)
, xmax_(xmax)
, type_(type)
, grid_size_(grid.get_grid_points().size())
, prob0_(prob0)
{
this->check_xmin_xmax(xmin, xmax);
Expand All @@ -227,13 +239,21 @@ inline Kde1d::Kde1d(const interp::InterpolationGrid& grid,
//! @param bandwidth positive bandwidth parameter (`NaN` means automatic
//! selection).
//! @param degree degree of the local polynomial.
//! @param grid_size number of intervals of the interpolation grid (the grid
//!   has `grid_size + 1` points).
inline Kde1d::Kde1d(double xmin,
double xmax,
std::string type,
double multiplier,
double bandwidth,
size_t degree)
: Kde1d(xmin, xmax, this->as_enum(type), multiplier, bandwidth, degree)
size_t degree,
size_t grid_size)
: Kde1d(xmin,
xmax,
this->as_enum(type),
multiplier,
bandwidth,
degree,
grid_size)
{
}

Expand Down Expand Up @@ -582,7 +602,7 @@ Kde1d::fit_lp(const Eigen::VectorXd& x,
{
size_t m = grid_points.size();
fft::KdeFFT kde_fft(
x, bandwidth_, grid_points(0), grid_points(m - 1), weights);
x, bandwidth_, grid_points(0), grid_points(m - 1), weights, m - 1);
Eigen::VectorXd f0 = kde_fft.kde_drv(0);
Eigen::VectorXd f1(f0.size()), f2(f0.size());

Expand Down Expand Up @@ -761,7 +781,7 @@ Kde1d::boundary_correct(const Eigen::VectorXd& x, const Eigen::VectorXd& fhat)

//! constructs a grid later used for interpolation
//! @param x vector of observations.
//! @return a grid of size 50.
//! @return a grid with `grid_size_ + 1` points.
inline Eigen::VectorXd
Kde1d::construct_grid_points(const Eigen::VectorXd& x)
{
Expand All @@ -771,7 +791,7 @@ Kde1d::construct_grid_points(const Eigen::VectorXd& x)
rng(0) -= 4 * bandwidth_;
rng(1) += 4 * bandwidth_;
}
auto zgrid = Eigen::VectorXd::LinSpaced(401, rng(0), rng(1));
auto zgrid = Eigen::VectorXd::LinSpaced(grid_size_ + 1, rng(0), rng(1));
return boundary_transform(zgrid, true);
}

Expand Down
12 changes: 8 additions & 4 deletions include/kde1d/kdefft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class KdeFFT
double bandwidth,
double lower,
double upper,
const Eigen::VectorXd& weights = Eigen::VectorXd());
const Eigen::VectorXd& weights = Eigen::VectorXd(),
size_t num_bins = 400);

Eigen::VectorXd kde_drv(unsigned drv) const;
Eigen::VectorXd get_bin_counts() const { return bin_counts_; };
Expand All @@ -29,7 +30,7 @@ class KdeFFT
double bandwidth_;
double lower_;
double upper_;
static constexpr size_t num_bins_{ 400 };
size_t num_bins_;
Eigen::VectorXd bin_counts_;
};

Expand All @@ -38,14 +39,17 @@ class KdeFFT
//! @param lower lower bound of the grid.
//! @param upper upper bound of the grid.
//! @param weights optional vector of weights for each observation.
//! @param num_bins number of bins for the FFT grid.
inline KdeFFT::KdeFFT(const Eigen::VectorXd& x,
double bandwidth,
double lower,
double upper,
const Eigen::VectorXd& weights)
const Eigen::VectorXd& weights,
size_t num_bins)
: bandwidth_(bandwidth)
, lower_(lower)
, upper_(upper)
, num_bins_(num_bins)
{
if (weights.size() > 0 && (weights.size() != x.size()))
throw std::invalid_argument("x and weights must have the same size");
Expand All @@ -65,7 +69,7 @@ inline KdeFFT::KdeFFT(const Eigen::VectorXd& x,
inline Eigen::VectorXd
KdeFFT::kde_drv(unsigned drv) const
{
double delta = (upper_ - lower_) / num_bins_;
double delta = (upper_ - lower_) / static_cast<double>(num_bins_);
double tau = 4.0 + drv;
size_t L = static_cast<size_t>(std::floor(tau * bandwidth_ / delta));
L = std::min(L, num_bins_ + 1);
Expand Down
138 changes: 138 additions & 0 deletions test/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,144 @@ size_t nlevels = 50;
Eigen::VectorXd x_d =
(x_cb.array() * (static_cast<double>(nlevels) - 1)).round();

//! Tests for the `grid_size` constructor argument of `kde1d::Kde1d`.
//!
//! Covers: storage of the requested value, validation (values below 4 are
//! rejected), the default of 400, the size of the fitted interpolation grid
//! (`grid_size + 1` points), and fitting with different variable types and
//! boundary configurations.
//!
//! NOTE(review): the fixtures `x_ub`, `x_lb`, `x_rb`, `x_cb`, `x_d`, `ugrid`,
//! `upoints`, `n_sample` and `pdf_tol` are defined elsewhere in this file;
//! presumably they are unbounded / left- / right- / both-sided bounded
//! samples and evaluation grids -- confirm against their definitions.
TEST_CASE("grid_size parameter", "[grid-size]")
{

  SECTION("constructor accepts grid_size parameter")
  {
    // VarType overload
    kde1d::Kde1d fit1(NAN, NAN, kde1d::VarType::continuous, 1.0, NAN, 2, 100);
    CHECK(fit1.get_grid_size() == 100);

    // std::string overload (delegates to the VarType overload)
    kde1d::Kde1d fit2(NAN, NAN, "continuous", 1.0, NAN, 2, 200);
    CHECK(fit2.get_grid_size() == 200);
  }

  SECTION("grid_size validation")
  {
    // anything below 4 must be rejected by the constructor
    CHECK_THROWS(kde1d::Kde1d(NAN, NAN, "continuous", 1.0, NAN, 2, 3));
    CHECK_THROWS(kde1d::Kde1d(NAN, NAN, "continuous", 1.0, NAN, 2, 0));

    // 4 is the smallest accepted value
    CHECK_NOTHROW(kde1d::Kde1d(NAN, NAN, "continuous", 1.0, NAN, 2, 4));
    CHECK_NOTHROW(kde1d::Kde1d(NAN, NAN, "continuous", 1.0, NAN, 2, 50));
  }

  SECTION("default grid_size is 400")
  {
    kde1d::Kde1d fit; // all constructor arguments defaulted
    CHECK(fit.get_grid_size() == 400);
  }

  SECTION("grid_size affects interpolation grid size")
  {
    std::vector<size_t> grid_sizes = { 50, 100, 200 };

    for (size_t requested_size : grid_sizes) {
      kde1d::Kde1d fit(NAN, NAN, "continuous", 1.0, NAN, 2, requested_size);

      // the requested value is stored unchanged ...
      CHECK(fit.get_grid_size() == requested_size);

      // ... and actually drives the fitted interpolation grid: it has one
      // point more than the requested number of intervals.
      CHECK_NOTHROW(fit.fit(x_ub));
      CHECK(fit.get_actual_grid_size() == requested_size + 1);
    }
  }

  SECTION("grid_size works after fitting")
  {
    // sizes below, at and above the default of 400
    std::vector<size_t> test_sizes = { 50, 100, 200, 400, 600 };

    for (size_t grid_size : test_sizes) {
      kde1d::Kde1d fit(NAN, NAN, "continuous", 1.0, NAN, 2, grid_size);
      CHECK_NOTHROW(fit.fit(x_ub));

      // fitting must not overwrite the requested grid size
      CHECK(fit.get_grid_size() == grid_size);

      // methods that depend on the fitted model must be usable
      CHECK_NOTHROW(fit.pdf(x_ub));
      CHECK_NOTHROW(fit.cdf(x_ub));
      CHECK_NOTHROW(fit.quantile(ugrid));

      // number of grid points = number of intervals + 1
      size_t actual_size = fit.get_actual_grid_size();
      CHECK(actual_size == grid_size + 1);
    }
  }

  SECTION("grid_size affects estimation with different data types")
  {
    size_t test_grid_size = 150;

    // continuous data
    kde1d::Kde1d fit_cont(NAN, NAN, "continuous", 1.0, NAN, 2, test_grid_size);
    CHECK_NOTHROW(fit_cont.fit(x_ub));
    CHECK(fit_cont.get_grid_size() == test_grid_size);

    // discrete data
    kde1d::Kde1d fit_disc(NAN, NAN, "discrete", 1.0, NAN, 2, test_grid_size);
    CHECK_NOTHROW(fit_disc.fit(x_d));
    CHECK(fit_disc.get_grid_size() == test_grid_size);

    // zero-inflated data: the first quarter of the sample is set to zero
    Eigen::VectorXd x_zi = x_lb;
    x_zi.head(n_sample / 4).setZero();
    kde1d::Kde1d fit_zi(0, NAN, "zero-inflated", 1.0, NAN, 2, test_grid_size);
    CHECK_NOTHROW(fit_zi.fit(x_zi));
    CHECK(fit_zi.get_grid_size() == test_grid_size);
  }

  SECTION("grid_size works with boundaries")
  {
    size_t test_grid_size = 120;

    // left boundary only
    kde1d::Kde1d fit_lb(0, NAN, "continuous", 1.0, NAN, 2, test_grid_size);
    CHECK_NOTHROW(fit_lb.fit(x_lb));
    CHECK(fit_lb.get_grid_size() == test_grid_size);

    // right boundary only
    kde1d::Kde1d fit_rb(NAN, 0, "continuous", 1.0, NAN, 2, test_grid_size);
    CHECK_NOTHROW(fit_rb.fit(x_rb));
    CHECK(fit_rb.get_grid_size() == test_grid_size);

    // both boundaries
    kde1d::Kde1d fit_bb(0, 1, "continuous", 1.0, NAN, 2, test_grid_size);
    CHECK_NOTHROW(fit_bb.fit(x_cb));
    CHECK(fit_bb.get_grid_size() == test_grid_size);
  }

  SECTION("grid_size affects estimation accuracy")
  {
    // evaluation points and the standard normal density at those points
    auto points = stats::qnorm(upoints);
    auto target = stats::dnorm(points);

    // coarse grid
    kde1d::Kde1d fit_small(NAN, NAN, "continuous", 1.0, NAN, 2, 50);
    fit_small.fit(x_ub);
    auto pdf_small = fit_small.pdf(points);

    // fine grid
    kde1d::Kde1d fit_large(NAN, NAN, "continuous", 1.0, NAN, 2, 800);
    fit_large.fit(x_ub);
    auto pdf_large = fit_large.pdf(points);

    // mean absolute deviation from the target density
    double error_small = (pdf_small - target).array().abs().mean();
    double error_large = (pdf_large - target).array().abs().mean();

    // both estimates must stay within the tolerance used elsewhere
    CHECK(error_small <= pdf_tol);
    CHECK(error_large <= pdf_tol);

    // a finer grid should not be substantially worse than a coarse one; the
    // factor 2 leaves room for numerical noise on very fine grids
    CHECK(error_large <= error_small * 2.0);
  }
}

TEST_CASE("misc checks", "[input-checks][argument-checks]")
{

Expand Down
Loading