Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 0 additions & 156 deletions stan/math/prim/fun/log_gamma_q_dgamma.hpp

This file was deleted.

148 changes: 39 additions & 109 deletions stan/math/prim/prob/gamma_lccdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,15 @@
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/fma.hpp>
#include <stan/math/prim/fun/gamma_p.hpp>
#include <stan/math/prim/fun/gamma_q.hpp>
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
#include <stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1m.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/scalar_seq_view.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/tgamma.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <stan/math/prim/fun/log_gamma_q_dgamma.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/functor/partials_propagator.hpp>
#include <cmath>

Expand All @@ -29,9 +24,10 @@ namespace math {
template <typename T_y, typename T_shape, typename T_inv_scale>
inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
const T_y& y, const T_shape& alpha, const T_inv_scale& beta) {
using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
using std::exp;
using std::log;
using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
using std::pow;
using T_y_ref = ref_type_t<T_y>;
using T_alpha_ref = ref_type_t<T_shape>;
using T_beta_ref = ref_type_t<T_inv_scale>;
Expand All @@ -55,127 +51,61 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
scalar_seq_view<T_y_ref> y_vec(y_ref);
scalar_seq_view<T_alpha_ref> alpha_vec(alpha_ref);
scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
const size_t N = max_size(y, alpha, beta);

constexpr bool any_fvar = is_fvar<scalar_type_t<T_y>>::value
|| is_fvar<scalar_type_t<T_shape>>::value
|| is_fvar<scalar_type_t<T_inv_scale>>::value;
constexpr bool partials_fvar = is_fvar<T_partials_return>::value;
size_t N = max_size(y, alpha, beta);

// Explicit return for extreme values
// The gradients are technically ill-defined, but treated as zero
for (size_t i = 0; i < stan::math::size(y); i++) {
if (y_vec.val(i) == 0) {
// LCCDF(0) = log(P(Y > 0)) = log(1) = 0
return ops_partials.build(0.0);
}
}

for (size_t n = 0; n < N; n++) {
// Explicit results for extreme values
// The gradients are technically ill-defined, but treated as zero
const T_partials_return y_dbl = y_vec.val(n);
if (y_dbl == 0.0) {
continue;
}
if (y_dbl == INFTY) {
if (y_vec.val(n) == INFTY) {
// LCCDF(∞) = log(P(Y > ∞)) = log(0) = -∞
return ops_partials.build(negative_infinity());
}

const T_partials_return y_dbl = y_vec.val(n);
const T_partials_return alpha_dbl = alpha_vec.val(n);
const T_partials_return beta_dbl = beta_vec.val(n);
const T_partials_return beta_y_dbl = beta_dbl * y_dbl;

const T_partials_return beta_y = beta_dbl * y_dbl;
if (beta_y == INFTY) {
return ops_partials.build(negative_infinity());
}
// Qn = 1 - Pn
const T_partials_return Qn = gamma_q(alpha_dbl, beta_y_dbl);
const T_partials_return log_Qn = log(Qn);

bool use_cf = beta_y > alpha_dbl + 1.0;
T_partials_return log_Qn;
[[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0;

// Branch by autodiff type first, then handle use_cf logic inside each path
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
// var-only path: use log_gamma_q_dgamma which computes both log_q
// and its gradient analytically with double inputs
const double beta_y_dbl = value_of_rec(beta_y);
const double alpha_dbl_val = value_of_rec(alpha_dbl);

if (use_cf) {
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
log_Qn = log_q_result.log_q;
dlogQ_dalpha = log_q_result.dlog_q_da;
} else {
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
log_Qn = log1m(Pn);
const T_partials_return Qn = exp(log_Qn);

// Check if we need to fallback to continued fraction
bool need_cf_fallback
= !std::isfinite(value_of_rec(log_Qn)) || Qn <= 0.0;
if (need_cf_fallback && beta_y > 0.0) {
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
log_Qn = log_q_result.log_q;
dlogQ_dalpha = log_q_result.dlog_q_da;
} else {
dlogQ_dalpha = -grad_reg_lower_inc_gamma(alpha_dbl, beta_y) / Qn;
}
}
} else if constexpr (partials_fvar && is_autodiff_v<T_shape>) {
// fvar path: use unit derivative trick to compute gradients
T_partials_return alpha_unit = alpha_dbl;
alpha_unit.d_ = 1;
T_partials_return beta_unit = beta_y;
beta_unit.d_ = 0;

if (use_cf) {
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
T_partials_return log_Qn_fvar
= internal::log_q_gamma_cf(alpha_unit, beta_unit);
dlogQ_dalpha = log_Qn_fvar.d_;
} else {
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
log_Qn = log1m(Pn);

if (!std::isfinite(value_of_rec(log_Qn)) && beta_y > 0.0) {
// Fallback to continued fraction
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
T_partials_return log_Qn_fvar
= internal::log_q_gamma_cf(alpha_unit, beta_unit);
dlogQ_dalpha = log_Qn_fvar.d_;
} else {
T_partials_return log_Qn_fvar = log1m(gamma_p(alpha_unit, beta_unit));
dlogQ_dalpha = log_Qn_fvar.d_;
}
}
} else {
// No alpha derivative needed (alpha is constant or double-only)
if (use_cf) {
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
} else {
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
log_Qn = log1m(Pn);

if (!std::isfinite(value_of_rec(log_Qn)) && beta_y > 0.0) {
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
}
}
}
if (!std::isfinite(value_of_rec(log_Qn))) {
return ops_partials.build(negative_infinity());
}
P += log_Qn;

if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
const T_partials_return log_y = log(y_dbl);
const T_partials_return alpha_minus_one = fma(alpha_dbl, log_y, -log_y);

const T_partials_return log_pdf = alpha_dbl * log(beta_dbl)
- lgamma(alpha_dbl) + alpha_minus_one
- beta_y;

const T_partials_return hazard = exp(log_pdf - log_Qn); // f/Q
if constexpr (is_any_autodiff_v<T_y, T_inv_scale>) {
const T_partials_return log_y_dbl = log(y_dbl);
const T_partials_return log_beta_dbl = log(beta_dbl);
const T_partials_return log_pdf
= alpha_dbl * log_beta_dbl - lgamma(alpha_dbl)
+ (alpha_dbl - 1.0) * log_y_dbl - beta_y_dbl;
const T_partials_return common_term = exp(log_pdf - log_Qn);

if constexpr (is_autodiff_v<T_y>) {
partials<0>(ops_partials)[n] -= hazard;
// d/dy log(1-F(y)) = -f(y)/(1-F(y))
partials<0>(ops_partials)[n] -= common_term;
}
if constexpr (is_autodiff_v<T_inv_scale>) {
partials<2>(ops_partials)[n] -= (y_dbl / beta_dbl) * hazard;
// d/dbeta log(1-F(y)) = -y*f(y)/(beta*(1-F(y)))
partials<2>(ops_partials)[n] -= y_dbl / beta_dbl * common_term;
}
}

if constexpr (is_autodiff_v<T_shape>) {
partials<1>(ops_partials)[n] += dlogQ_dalpha;
const T_partials_return digamma_val = digamma(alpha_dbl);
const T_partials_return gamma_val = tgamma(alpha_dbl);
// d/dalpha log(1-F(y)) = grad_upper_inc_gamma / (1-F(y))
partials<1>(ops_partials)[n]
+= grad_reg_inc_gamma(alpha_dbl, beta_y_dbl, gamma_val, digamma_val)
/ Qn;
}
}
return ops_partials.build(P);
Expand Down
Loading