diff --git a/docs/slic/constraints.qmd b/docs/slic/constraints.qmd new file mode 100644 index 0000000..9efde11 --- /dev/null +++ b/docs/slic/constraints.qmd @@ -0,0 +1,376 @@ +--- +title: "User defined reusable constraints" +--- + +The full Julia code for this notebook can be accessed via the top right corner (` Code`). + +The Julia packages needed to reproduce this document are [`StanBlocks.jl`](https://github.com/nsiccha/StanBlocks.jl) (for the model generation) and [`QuartoComponents.jl`](https://github.com/nsiccha/QuartoComponents.jl) (for the "pretty" printing). +Both packages have to be installed from the latest `main` branch (as of Oct 14th 2025). + +## StanBlocks.jl implementation + +The function and model definitions below make use of + +* [Julia-style Named Tuples](https://docs.julialang.org/en/v1/base/base/#Core.NamedTuple) - allowing `.` syntax access to fields via names (instead of [via numbers as in Stan](https://mc-stan.org/docs/reference-manual/types.html#assigning-tuple-elements)). + +# A uniform prior on a disk + +## Full Julia + StanBlocks.jl code to define the models + +### Current status + +Currently, this is roughly how I'd add a custom constraining transform + prior on the constraint variables: + +```julia +using StanBlocks + +@deffun begin + """ + This function should of course actually return the jacobian adjustment of the `uniform_disk_constrain` transform + (wrt to `x` and `y`) - I'll work it out another time. + """ + uniform_disk_lpdf(xi::vector[n], n) = 0. + """ + This function returns the constrained parameters `x` and `y` together with the intermediate quantities `radius` and `angle` as a a Named Tuple, + which for Stan will look like a regular Tuple. + Within StanBlocks.jl, its possible to access the fields of the return value via `.` syntax, + e.g. `rv.radius` or `rv.x` after `rv = uniform_disk_constrain(xi)`. + """ + uniform_disk_constrain(xi::vector[n]) = begin + radius = inv_logit(xi[1]) + angle = 2pi * inv_logit(xi[2]) + (;radius, angle, x=cos(angle) * radius, y=sin(angle) * radius) + end +end + +disk = @slic begin + """ + This initializes the unconstrained parameters and adds the jacobian adjustment (and any potential prior). + + If you want to also put a prior on `x` and `y`, you'd currently (in StanBlocks.jl) have to write a new `_lpdf` function - e.g. `non_uniform_disk_lpdf`. + """ + xi ~ uniform_disk(2) + """ + This does the constraining. + + It *would* be more convenient if there was an easier mechanism to combine (custom) "constraining" and "putting a prior on the constrained parameters".. + """ + return uniform_disk_constrain(xi) +end +``` + +With the `disk` model defined (**by the user in this notebook - not in the back end**) as above, usage would look like this, +where `obs` and `obs_scale` come from somewhere else in this example: + +```julia +model_using_disk = @slic begin + disk_parameters ~ disk() + obs ~ normal( + some_binary_function_working_on_disk_coordinates( + disk_parameters.x, disk_parameters.y + ), + obs_scale + ) +end +``` + +### Future status + +TO DO: figure out a syntax which would make custom constraints even easier. + +## Generated Stan code + +For the generated Stan code, the disk model has been spliced into a dummy model, resulting in a model definition equivalent to + +```julia +dummy_model = @slic begin + "Needed to prevent StanBlocks.jl from moving `theta` to generated quantities" + dummy_obs = 1. + theta ~ disk() + "Needed to prevent StanBlocks.jl from moving `theta` to generated quantities" + dummy_obs ~ dummy(theta) +end +``` + +Without the dummy likelihood term, StanBlocks.jl would move `theta` to generated quantities, because it wouldn't affect the likelihood. + +```{julia} +using StanBlocks, QuartoComponents + +@deffun begin + """ + This function should of course actually return the jacobian adjustment of the `uniform_disk_constrain` transform + (wrt to `x` and `y`) - I'll work it out another time. + """ + uniform_disk_lpdf(xi::vector[n], n) = 0. + """ + This function returns the constrained parameters `x` and `y` together with the intermediate quantities `radius` and `angle` as a a Named Tuple, + which for Stan will look like a regular Tuple. + Within StanBlocks.jl, its possible to access the fields of the return value via `.` syntax, + e.g. `rv.radius` or `rv.x` after `rv = uniform_disk_constrain(xi)`. + """ + uniform_disk_constrain(xi::vector[n]) = begin + radius = inv_logit(xi[1]) + angle = 2pi * inv_logit(xi[2]) + (;radius, angle, x=cos(angle) * radius, y=sin(angle) * radius) + end + "Needed to prevent StanBlocks.jl from moving `theta` to generated quantities" + dummy_lpdf(args...) = 0. + "Needed to prevent StanBlocks.jl from moving `theta` to generated quantities" + dummy_lpdfs(args...) = 0. + "Needed to prevent StanBlocks.jl from moving `theta` to generated quantities" + dummy_rng(args...) = 0. +end + + +disk = @slic begin + """ + This initializes the unconstrained parameters and adds the jacobian adjustment (and any potential prior). + + If you want to also put a prior on `x` and `y`, you'd currently (in StanBlocks.jl) have to write a new `_lpdf` function - e.g. `non_uniform_disk_lpdf`. + """ + xi ~ uniform_disk(2) + """ + This does the constraining. + + It *would* be more convenient if there was an easier mechanism to combine (custom) "constraining" and "putting a prior on the constrained parameters".. + """ + return uniform_disk_constrain(xi) +end + +dummy_model = @slic begin + "Needed to prevent StanBlocks.jl from moving `theta` to generated quantities" + dummy_obs = 1. + "Needed to prevent StanBlocks.jl from moving `theta` to generated quantities" + dummy_obs ~ dummy(theta) +end + +disk_posteriors = (; + disk=dummy_model(quote + theta ~ disk() + end) +) + +map(disk_posteriors) do posterior + QuartoComponents.Code("stan", stan_code(posterior)) +end |> QuartoComponents.Tabset +``` + +# (Adaptive?) centering for hierarchical models + +If [this discourse thread](https://discourse.mc-stan.org/t/offset-multiplier-initialization/20712/20) is any indication, it looks like generally, it might be preferable to be able compute the Jacobian adjustment and the prior density together for the sake of numerical stability. +This is something that Stan's lack of compound declare-distribute statements "prevents" - see also [When to add a feature to Stan? The recurring issue of the compound declare-distribute statement](https://statmodeling.stat.columbia.edu/2018/02/01/stan-feature-declare-distribute/). + +Furthermore, it's probably a good idea to be able to parametrize the constraining transformations, allowing to vary e.g. the centeredness after defining the model via `data` arguments. + +TO DO: elaborate + + +# Various simplex constraining transformations + +The underlying code reproduces all transformations from [https://github.com/bob-carpenter/transforms](https://github.com/bob-carpenter/transforms). + +## Example Julia + StanBlocks.jl code to define the models + +### Current status + +It does so by first defining (**by the user in this notebook - not in the back end**) a general model + +```julia +any_simplex = @slic begin + xi ~ simplex_prior(constrain_f, prior_f, n) + return simplex_constrain(xi, constrain_f) +end +``` + +where `n` is the dimension of the simplex and `constrain_f` and `prior_f` can be used to change the used constraining transform or the imposed prior. + +Defining a custom simplex constraining transform is then as easy as defining two functions, + +* a function `f` accepting the vector of unconstrained parameters and returning a Named Tuple with fields `jac` (the jacobian adjustment) and `x` (the constrained parameters) +(to be passed as the `constrain_f` argument), +* and an overload of `unconstrained_dim(::typeof(f), n)` returning the dimension of the unconstrained parameters accepted by `f`. + +See below for an example implementation for the ALR constraining transform: + +```julia +constrain_alr(xi::vector[n]) = begin + r = log1p_exp(log_sum_exp(xi)) + (; + jac=sum(xi)-(n+1)*r, + x=exp(append_row(xi - r, -r)) + ) +end +unconstrained_dim(::typeof(constrain_alr), n) = n-1 +``` + +The two functions used in the `any_simplex` model, `simplex_prior_lpdf` and `simplex_constrain` are defined (**by the user in this notebook - not in the back end**) as + +```julia +simplex_prior_lpdf(xi::vector[unconstrained_dim(constrain_f, n)], constrain_f, prior_f, n) = begin + tmp = constrain_f(xi) + tmp.jac + prior_f(tmp.x) +end +simplex_constrain(xi, constrain_f) = constrain_f(xi).x +``` + +### Current limitations + +Currently, both the jacobian adjustment have to be computed twice - once in `simplex_prior_lpdf` and once in `simplex_constrain`. +Ideally, this could be avoided. + +### Future status + +TO DO: Think about how to best prevent the double work. The main work would be to think about appropriate syntax signalling to StanBlocks.jl that these functions work together to compute the constraining and the jacobian adjustment. + +## Generated Stan code + +For the generated Stan code, the `any_simplex` model has been spliced into a dummy model, resulting in a model definition equivalent to + +```julia +dummy_simplex_model = @slic begin + "Needed to prevent StanBlocks.jl from moving `theta` to generated quantities" + dummy_obs = 1. + theta ~ any_simplex(;constrain_f, prior_f, n=10) + "Needed to prevent StanBlocks.jl from moving `theta` to generated quantities" + dummy_obs ~ dummy(theta) +end +``` + +where again without the dummy likelihood StanBlocks.jl would move `theta` to generated quantities, because it wouldn't affect the likelihood. + +```{julia} +@deffun begin + Base.reverse(x::vector[n])::vector[n] + StanBlocks.stan.std_normal_lcdf(x)::real + Base.log2()::real + std_normal_lcdfs(x::real) = std_normal_lcdf(x) + std_normal_lcdfs(x::vector[n]) = jbroadcasted(std_normal_lcdfs, x) + + unconstrained_dim(constrain_f, n) = reject(n) + simplex_prior_lpdf(xi::vector[unconstrained_dim(constrain_f, n)], constrain_f, prior_f, n) = begin + tmp = constrain_f(xi) + tmp.jac + prior_f(tmp.x) + end + simplex_constrain(xi, constrain_f) = constrain_f(xi).x + uniform_simplex_lpdf(xi) = 0. + + constrain_alr(xi::vector[n]) = begin + r = log1p_exp(log_sum_exp(xi)) + (; + jac=sum(xi)-(n+1)*r, + x=exp(append_row(xi - r, -r)) + ) + end + unconstrained_dim(::typeof(constrain_alr), n) = n-1 + constrain_expanded_softmax(xi::vector[n]) = begin + r = log_sum_exp(xi) + (; + jac=std_normal_lpdf(r - log(n)) + sum(xi) - n * r, x=exp(xi - r)) + end + unconstrained_dim(::typeof(constrain_expanded_softmax), n) = n + constrain_ilr(xi::vector[n]) = begin + ns = linspaced_vector(n, 1, n) + w = xi ./ sqrt(ns .* (ns + 1)) + z = append_row(reverse(cumulative_sum(reverse(w))), 0) - append_row(0, ns .* w) + r = log_sum_exp(z) + (; + jac=0.5 * log(n+1)+sum(z) - (n+1) * r, + x=exp(z - r) + ) + end + unconstrained_dim(::typeof(constrain_ilr), n) = n-1 + constrain_ilr_reflector(xi::vector[n]) = begin + sqrtN = sqrt((n+1)) + zN = sum(xi) / sqrtN + z = append_row(xi - zN ./ (sqrtN - 1), zN) + r = log_sum_exp(z) + (; + jac=0.5 * log(n+1)+sum(z) - (n+1) * r, + x=exp(z - r) + ) + end + unconstrained_dim(::typeof(constrain_ilr_reflector), n) = n-1 + exponential_log_qf(x) = -log1m_exp(x) + constrain_normalized_exponential(xi::vector[n]) = begin + z = log(exponential_log_qf(std_normal_lcdfs(xi))) + r = log_sum_exp(z) + (; + jac=std_normal_lpdf(xi) - lgamma(n), + x=exp(z - r) + ) + end + unconstrained_dim(::typeof(constrain_normalized_exponential), n) = n + constrain_stickbreaking_angular(xi::vector[n]) = begin + log_u = log_inv_logit(xi) + log_phi = log_u + (log(pi) - log2()) + phi = exp(log_phi) + log_s = log(sin(phi)) + log_c = log(cos(phi)) + log_s2_prod = append_row(0, 2 * cumulative_sum(log_s)) + (; + jac=n * log2() + sum(log1m_exp(log_u)) + sum(log_s) + sum(log_phi)+sum(log_s2_prod[2:n]) + sum(log_c), + x=exp(log_s2_prod + append_row(2 * log_c, 0)) + ) + end + unconstrained_dim(::typeof(constrain_stickbreaking_angular), n) = n-1 + constrain_stickbreaking_logistic(xi::vector[n]) = begin + log_z = log_inv_logit(xi - log(reverse(linspaced_vector(n, 1, n)))) + log_cum_prod = append_row(0, cumulative_sum(log1m_exp(log_z))) + (; + jac=sum(log_cum_prod) + sum(log_z), + x=exp(append_row(log_z, 0) + log_cum_prod) + ) + end + unconstrained_dim(::typeof(constrain_stickbreaking_logistic), n) = n-1 + constrain_stickbreaking_normal(xi::vector[n]) = begin + w = xi - log(reverse(linspaced_vector(n, 1, n))) / 2 + log_z = std_normal_lcdfs(w) + log_cum_prod = append_row(0, cumulative_sum(log1m_exp(log_z))) + (; + jac=std_normal_lpdf(w) + sum(log_cum_prod[2:n]), + x=exp(append_row(log_z, 0) + log_cum_prod) + ) + end + unconstrained_dim(::typeof(constrain_stickbreaking_normal), n) = n-1 + constrain_stickbreaking_power_logistic(xi::vector[n]) = begin + log_u = log_inv_logit(xi) + log_w = log_u ./ reverse(linspaced_vector(n, 1, n)) + (; + jac=2 * sum(log_u) - sum(xi) - lgamma(n+1), + x=exp(append_row(log1m_exp(log_w), 0) + append_row(0, cumulative_sum(log_w))) + ) + end + unconstrained_dim(::typeof(constrain_stickbreaking_power_logistic), n) = n-1 + constrain_stickbreaking_power_normal(xi::vector[n]) = begin + log_u = std_normal_lcdfs(xi) + log_w = log_u ./ reverse(linspaced_vector(n, 1, n)) + (; + jac=std_normal_lpdf(xi) - lgamma(n+1), + x=exp(append_row(log1m_exp(log_w), 0) + append_row(0, cumulative_sum(log_w))) + ) + end + unconstrained_dim(::typeof(constrain_stickbreaking_power_normal), n) = n-1 +end + +any_simplex = @slic begin + xi ~ simplex_prior(constrain_f, prior_f, n) + return simplex_constrain(xi, constrain_f) +end +dummy_simplex_model = dummy_model(quote + theta ~ any_simplex(;constrain_f, prior_f, n=10) +end) + +simplex_posteriors = map((; + constrain_alr, constrain_expanded_softmax, constrain_ilr, constrain_ilr_reflector, + constrain_normalized_exponential, + constrain_stickbreaking_angular, constrain_stickbreaking_logistic, constrain_stickbreaking_normal, + constrain_stickbreaking_power_logistic, constrain_stickbreaking_power_normal + )) do constrain_f + dummy_simplex_model(;constrain_f, prior_f=uniform_simplex_lpdf) +end + +map(simplex_posteriors) do posterior + QuartoComponents.Code("stan", stan_code(posterior)) +end |> QuartoComponents.Tabset +``` \ No newline at end of file