From e94fca4f4270581b178165051f95feefff1f9c1a Mon Sep 17 00:00:00 2001 From: nsiccha Date: Tue, 14 Oct 2025 14:22:22 +0200 Subject: [PATCH 1/3] add logit(...) tracetypes --- src/slic_stan/builtin.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/slic_stan/builtin.jl b/src/slic_stan/builtin.jl index 0eedd90..b19b5a0 100644 --- a/src/slic_stan/builtin.jl +++ b/src/slic_stan/builtin.jl @@ -283,7 +283,7 @@ autokwargs(::CanonicalExpr{<:Union{typeof.((lognormal,chi_square,inv_chi_square, # end end @defsig begin - Union{typeof.((sqrt, exp, log, log10, sin, cos, asin, acos, log1m, inv_logit, log_inv_logit, log1m_exp, expm1, Phi, lgamma, abs, log1p_exp, log1m_exp, Base.inv, Base.log1p))...} => begin + Union{typeof.((sqrt, exp, log, log10, sin, cos, asin, acos, log1m, inv_logit, logit, log_inv_logit, log1m_exp, expm1, Phi, lgamma, abs, log1p_exp, log1m_exp, Base.inv, Base.log1p))...} => begin (real,)=>real (vector[n],)=>vector[n] (row_vector[n],)=>row_vector[n] From 7bea53d30cc0a1e1355366d3ba239faecdc57cfb Mon Sep 17 00:00:00 2001 From: nsiccha Date: Wed, 15 Oct 2025 13:10:14 +0200 Subject: [PATCH 2/3] Start defining built-in binary vectorized function tracetypes --- src/slic_stan/builtin.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/slic_stan/builtin.jl b/src/slic_stan/builtin.jl index b19b5a0..ad8a021 100644 --- a/src/slic_stan/builtin.jl +++ b/src/slic_stan/builtin.jl @@ -290,6 +290,10 @@ end (real[n],)=>real[n] (matrix[m,n],)=>matrix[m,n] end + Union{typeof.((log_sum_exp, ))...} => begin + (real, real) => real + (vector[n], vector[n]) => vector[n] + end typeof(÷) => begin (int, int) => int end From d804397a368490998b09164075de5476abaadd13 Mon Sep 17 00:00:00 2001 From: nsiccha Date: Wed, 15 Oct 2025 13:13:14 +0200 Subject: [PATCH 3/3] Fix #45 --- src/slic_stan/slic.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/slic_stan/slic.jl b/src/slic_stan/slic.jl index 6a9b4e3..b746c70 100644 --- a/src/slic_stan/slic.jl +++ b/src/slic_stan/slic.jl @@ -340,9 +340,17 @@ forward!(x::AssignmentExpr{Symbol}; info) = begin forward!(remake(x, name, rhs); info) end maybe_lazy_size(key::Symbol, i, sizei; info) = sizei -maybe_lazy_size(key::Symbol, i, ::StanExpr{<:CanonicalExpr}; info) = StanExpr( - forward!(CanonicalExpr(:getindex, CanonicalExpr(:dims, key), i); info), StanType(types.int) -) +is_simple_size(x::StanExpr) = is_simple_size(expr(x)) +is_simple_size(x::CanonicalExpr{<:Union{typeof.((+,-,*,÷))...}}) = all(is_simple_size, x.args) +is_simple_size(x::CanonicalExpr) = false +is_simple_size(x::Symbol) = true +is_simple_size(x::Number) = true +is_simple_size(x) = false#error(typeof(x)) +maybe_lazy_size(key::Symbol, i, sizei::StanExpr{<:CanonicalExpr}; info) = if is_simple_size(sizei) + sizei +else + forward!(canonical(:(dims($key)[$i])); info) +end forward!(x::AssignmentExpr{Symbol,<:StanExpr}; info) = begin name, rhs = x.args @assert name ∉ keys(info)