Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
376 changes: 376 additions & 0 deletions docs/slic/constraints.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,376 @@
---
title: "User defined reusable constraints"
---

The full Julia code for this notebook can be accessed via the top right corner (`</> Code`).

The Julia packages needed to reproduce this document are [`StanBlocks.jl`](https://github.com/nsiccha/StanBlocks.jl) (for the model generation) and [`QuartoComponents.jl`](https://github.com/nsiccha/QuartoComponents.jl) (for the "pretty" printing).
Both packages have to be installed from the latest `main` branch (as of Oct 14th 2025).

## StanBlocks.jl implementation

The function and model definitions below make use of

* [Julia-style Named Tuples](https://docs.julialang.org/en/v1/base/base/#Core.NamedTuple) - allowing `.` syntax access to fields via names (instead of [via numbers as in Stan](https://mc-stan.org/docs/reference-manual/types.html#assigning-tuple-elements)).

# A uniform prior on a disk

## Full Julia + StanBlocks.jl code to define the models

### Current status

Currently, this is roughly how I'd add a custom constraining transform + prior on the constraint variables:

```julia
using StanBlocks

@deffun begin
"""
This function should of course actually return the jacobian adjustment of the `uniform_disk_constrain` transform
(wrt to `x` and `y`) - I'll work it out another time.
"""
uniform_disk_lpdf(xi::vector[n], n) = 0.
"""
This function returns the constrained parameters `x` and `y` together with the intermediate quantities `radius` and `angle` as a a Named Tuple,
which for Stan will look like a regular Tuple.
Within StanBlocks.jl, its possible to access the fields of the return value via `.` syntax,
e.g. `rv.radius` or `rv.x` after `rv = uniform_disk_constrain(xi)`.
"""
uniform_disk_constrain(xi::vector[n]) = begin
radius = inv_logit(xi[1])
angle = 2pi * inv_logit(xi[2])
(;radius, angle, x=cos(angle) * radius, y=sin(angle) * radius)
end
end

disk = @slic begin
"""
This initializes the unconstrained parameters and adds the jacobian adjustment (and any potential prior).

If you want to also put a prior on `x` and `y`, you'd currently (in StanBlocks.jl) have to write a new `_lpdf` function - e.g. `non_uniform_disk_lpdf`.
"""
xi ~ uniform_disk(2)
"""
This does the constraining.

It *would* be more convenient if there was an easier mechanism to combine (custom) "constraining" and "putting a prior on the constrained parameters"..
"""
return uniform_disk_constrain(xi)
end
```

With the `disk` model defined (**by the user in this notebook - not in the back end**) as above, usage would look like this,
where `obs` and `obs_scale` come from somewhere else in this example:

```julia
model_using_disk = @slic begin
disk_parameters ~ disk()
obs ~ normal(
some_binary_function_working_on_disk_coordinates(
disk_parameters.x, disk_parameters.y
),
obs_scale
)
end
```

### Future status

TO DO: figure out a syntax which would make custom constraints even easier.

## Generated Stan code

For the generated Stan code, the disk model has been spliced into a dummy model, resulting in a model definition equivalent to

```julia
dummy_model = @slic begin
"Needed to prevent StanBlocks.jl from moving `theta` to generated quantities"
dummy_obs = 1.
theta ~ disk()
"Needed to prevent StanBlocks.jl from moving `theta` to generated quantities"
dummy_obs ~ dummy(theta)
end
```

Without the dummy likelihood term, StanBlocks.jl would move `theta` to generated quantities, because it wouldn't affect the likelihood.

```{julia}
using StanBlocks, QuartoComponents

@deffun begin
"""
This function should of course actually return the jacobian adjustment of the `uniform_disk_constrain` transform
(wrt to `x` and `y`) - I'll work it out another time.
"""
uniform_disk_lpdf(xi::vector[n], n) = 0.
"""
This function returns the constrained parameters `x` and `y` together with the intermediate quantities `radius` and `angle` as a a Named Tuple,
which for Stan will look like a regular Tuple.
Within StanBlocks.jl, its possible to access the fields of the return value via `.` syntax,
e.g. `rv.radius` or `rv.x` after `rv = uniform_disk_constrain(xi)`.
"""
uniform_disk_constrain(xi::vector[n]) = begin
radius = inv_logit(xi[1])
angle = 2pi * inv_logit(xi[2])
(;radius, angle, x=cos(angle) * radius, y=sin(angle) * radius)
end
"Needed to prevent StanBlocks.jl from moving `theta` to generated quantities"
dummy_lpdf(args...) = 0.
"Needed to prevent StanBlocks.jl from moving `theta` to generated quantities"
dummy_lpdfs(args...) = 0.
"Needed to prevent StanBlocks.jl from moving `theta` to generated quantities"
dummy_rng(args...) = 0.
end


disk = @slic begin
"""
This initializes the unconstrained parameters and adds the jacobian adjustment (and any potential prior).

If you want to also put a prior on `x` and `y`, you'd currently (in StanBlocks.jl) have to write a new `_lpdf` function - e.g. `non_uniform_disk_lpdf`.
"""
xi ~ uniform_disk(2)
"""
This does the constraining.

It *would* be more convenient if there was an easier mechanism to combine (custom) "constraining" and "putting a prior on the constrained parameters"..
"""
return uniform_disk_constrain(xi)
end

dummy_model = @slic begin
"Needed to prevent StanBlocks.jl from moving `theta` to generated quantities"
dummy_obs = 1.
"Needed to prevent StanBlocks.jl from moving `theta` to generated quantities"
dummy_obs ~ dummy(theta)
end

disk_posteriors = (;
disk=dummy_model(quote
theta ~ disk()
end)
)

map(disk_posteriors) do posterior
QuartoComponents.Code("stan", stan_code(posterior))
end |> QuartoComponents.Tabset
```

# (Adaptive?) centering for hierarchical models

If [this discourse thread](https://discourse.mc-stan.org/t/offset-multiplier-initialization/20712/20) is any indication, it looks like generally, it might be preferable to be able compute the Jacobian adjustment and the prior density together for the sake of numerical stability.
This is something that Stan's lack of compound declare-distribute statements "prevents" - see also [When to add a feature to Stan? The recurring issue of the compound declare-distribute statement](https://statmodeling.stat.columbia.edu/2018/02/01/stan-feature-declare-distribute/).

Furthermore, it's probably a good idea to be able to parametrize the constraining transformations, allowing to vary e.g. the centeredness after defining the model via `data` arguments.

TO DO: elaborate


# Various simplex constraining transformations

The underlying code reproduces all transformations from [https://github.com/bob-carpenter/transforms](https://github.com/bob-carpenter/transforms).

## Example Julia + StanBlocks.jl code to define the models

### Current status

It does so by first defining (**by the user in this notebook - not in the back end**) a general model

```julia
any_simplex = @slic begin
xi ~ simplex_prior(constrain_f, prior_f, n)
return simplex_constrain(xi, constrain_f)
end
```

where `n` is the dimension of the simplex and `constrain_f` and `prior_f` can be used to change the used constraining transform or the imposed prior.

Defining a custom simplex constraining transform is then as easy as defining two functions,

* a function `f` accepting the vector of unconstrained parameters and returning a Named Tuple with fields `jac` (the jacobian adjustment) and `x` (the constrained parameters)
(to be passed as the `constrain_f` argument),
* and an overload of `unconstrained_dim(::typeof(f), n)` returning the dimension of the unconstrained parameters accepted by `f`.

See below for an example implementation for the ALR constraining transform:

```julia
constrain_alr(xi::vector[n]) = begin
r = log1p_exp(log_sum_exp(xi))
(;
jac=sum(xi)-(n+1)*r,
x=exp(append_row(xi - r, -r))
)
end
unconstrained_dim(::typeof(constrain_alr), n) = n-1
```

The two functions used in the `any_simplex` model, `simplex_prior_lpdf` and `simplex_constrain` are defined (**by the user in this notebook - not in the back end**) as

```julia
simplex_prior_lpdf(xi::vector[unconstrained_dim(constrain_f, n)], constrain_f, prior_f, n) = begin
tmp = constrain_f(xi)
tmp.jac + prior_f(tmp.x)
end
simplex_constrain(xi, constrain_f) = constrain_f(xi).x
```

### Current limitations

Currently, both the jacobian adjustment have to be computed twice - once in `simplex_prior_lpdf` and once in `simplex_constrain`.
Ideally, this could be avoided.

### Future status

TO DO: Think about how to best prevent the double work. The main work would be to think about appropriate syntax signalling to StanBlocks.jl that these functions work together to compute the constraining and the jacobian adjustment.

## Generated Stan code

For the generated Stan code, the `any_simplex` model has been spliced into a dummy model, resulting in a model definition equivalent to

```julia
dummy_simplex_model = @slic begin
"Needed to prevent StanBlocks.jl from moving `theta` to generated quantities"
dummy_obs = 1.
theta ~ any_simplex(;constrain_f, prior_f, n=10)
"Needed to prevent StanBlocks.jl from moving `theta` to generated quantities"
dummy_obs ~ dummy(theta)
end
```

where again without the dummy likelihood StanBlocks.jl would move `theta` to generated quantities, because it wouldn't affect the likelihood.

```{julia}
@deffun begin
Base.reverse(x::vector[n])::vector[n]
StanBlocks.stan.std_normal_lcdf(x)::real
Base.log2()::real
std_normal_lcdfs(x::real) = std_normal_lcdf(x)
std_normal_lcdfs(x::vector[n]) = jbroadcasted(std_normal_lcdfs, x)

unconstrained_dim(constrain_f, n) = reject(n)
simplex_prior_lpdf(xi::vector[unconstrained_dim(constrain_f, n)], constrain_f, prior_f, n) = begin
tmp = constrain_f(xi)
tmp.jac + prior_f(tmp.x)
end
simplex_constrain(xi, constrain_f) = constrain_f(xi).x
uniform_simplex_lpdf(xi) = 0.

constrain_alr(xi::vector[n]) = begin
r = log1p_exp(log_sum_exp(xi))
(;
jac=sum(xi)-(n+1)*r,
x=exp(append_row(xi - r, -r))
)
end
unconstrained_dim(::typeof(constrain_alr), n) = n-1
constrain_expanded_softmax(xi::vector[n]) = begin
r = log_sum_exp(xi)
(;
jac=std_normal_lpdf(r - log(n)) + sum(xi) - n * r, x=exp(xi - r))
end
unconstrained_dim(::typeof(constrain_expanded_softmax), n) = n
constrain_ilr(xi::vector[n]) = begin
ns = linspaced_vector(n, 1, n)
w = xi ./ sqrt(ns .* (ns + 1))
z = append_row(reverse(cumulative_sum(reverse(w))), 0) - append_row(0, ns .* w)
r = log_sum_exp(z)
(;
jac=0.5 * log(n+1)+sum(z) - (n+1) * r,
x=exp(z - r)
)
end
unconstrained_dim(::typeof(constrain_ilr), n) = n-1
constrain_ilr_reflector(xi::vector[n]) = begin
sqrtN = sqrt((n+1))
zN = sum(xi) / sqrtN
z = append_row(xi - zN ./ (sqrtN - 1), zN)
r = log_sum_exp(z)
(;
jac=0.5 * log(n+1)+sum(z) - (n+1) * r,
x=exp(z - r)
)
end
unconstrained_dim(::typeof(constrain_ilr_reflector), n) = n-1
exponential_log_qf(x) = -log1m_exp(x)
constrain_normalized_exponential(xi::vector[n]) = begin
z = log(exponential_log_qf(std_normal_lcdfs(xi)))
r = log_sum_exp(z)
(;
jac=std_normal_lpdf(xi) - lgamma(n),
x=exp(z - r)
)
end
unconstrained_dim(::typeof(constrain_normalized_exponential), n) = n
constrain_stickbreaking_angular(xi::vector[n]) = begin
log_u = log_inv_logit(xi)
log_phi = log_u + (log(pi) - log2())
phi = exp(log_phi)
log_s = log(sin(phi))
log_c = log(cos(phi))
log_s2_prod = append_row(0, 2 * cumulative_sum(log_s))
(;
jac=n * log2() + sum(log1m_exp(log_u)) + sum(log_s) + sum(log_phi)+sum(log_s2_prod[2:n]) + sum(log_c),
x=exp(log_s2_prod + append_row(2 * log_c, 0))
)
end
unconstrained_dim(::typeof(constrain_stickbreaking_angular), n) = n-1
constrain_stickbreaking_logistic(xi::vector[n]) = begin
log_z = log_inv_logit(xi - log(reverse(linspaced_vector(n, 1, n))))
log_cum_prod = append_row(0, cumulative_sum(log1m_exp(log_z)))
(;
jac=sum(log_cum_prod) + sum(log_z),
x=exp(append_row(log_z, 0) + log_cum_prod)
)
end
unconstrained_dim(::typeof(constrain_stickbreaking_logistic), n) = n-1
constrain_stickbreaking_normal(xi::vector[n]) = begin
w = xi - log(reverse(linspaced_vector(n, 1, n))) / 2
log_z = std_normal_lcdfs(w)
log_cum_prod = append_row(0, cumulative_sum(log1m_exp(log_z)))
(;
jac=std_normal_lpdf(w) + sum(log_cum_prod[2:n]),
x=exp(append_row(log_z, 0) + log_cum_prod)
)
end
unconstrained_dim(::typeof(constrain_stickbreaking_normal), n) = n-1
constrain_stickbreaking_power_logistic(xi::vector[n]) = begin
log_u = log_inv_logit(xi)
log_w = log_u ./ reverse(linspaced_vector(n, 1, n))
(;
jac=2 * sum(log_u) - sum(xi) - lgamma(n+1),
x=exp(append_row(log1m_exp(log_w), 0) + append_row(0, cumulative_sum(log_w)))
)
end
unconstrained_dim(::typeof(constrain_stickbreaking_power_logistic), n) = n-1
constrain_stickbreaking_power_normal(xi::vector[n]) = begin
log_u = std_normal_lcdfs(xi)
log_w = log_u ./ reverse(linspaced_vector(n, 1, n))
(;
jac=std_normal_lpdf(xi) - lgamma(n+1),
x=exp(append_row(log1m_exp(log_w), 0) + append_row(0, cumulative_sum(log_w)))
)
end
unconstrained_dim(::typeof(constrain_stickbreaking_power_normal), n) = n-1
end

any_simplex = @slic begin
xi ~ simplex_prior(constrain_f, prior_f, n)
return simplex_constrain(xi, constrain_f)
end
dummy_simplex_model = dummy_model(quote
theta ~ any_simplex(;constrain_f, prior_f, n=10)
end)

simplex_posteriors = map((;
constrain_alr, constrain_expanded_softmax, constrain_ilr, constrain_ilr_reflector,
constrain_normalized_exponential,
constrain_stickbreaking_angular, constrain_stickbreaking_logistic, constrain_stickbreaking_normal,
constrain_stickbreaking_power_logistic, constrain_stickbreaking_power_normal
)) do constrain_f
dummy_simplex_model(;constrain_f, prior_f=uniform_simplex_lpdf)
end

map(simplex_posteriors) do posterior
QuartoComponents.Code("stan", stan_code(posterior))
end |> QuartoComponents.Tabset
```
Loading