diff --git a/src/slic_stan/builtin.jl b/src/slic_stan/builtin.jl index 0eedd90..ad8a021 100644 --- a/src/slic_stan/builtin.jl +++ b/src/slic_stan/builtin.jl @@ -283,13 +283,17 @@ autokwargs(::CanonicalExpr{<:Union{typeof.((lognormal,chi_square,inv_chi_square, # end end @defsig begin - Union{typeof.((sqrt, exp, log, log10, sin, cos, asin, acos, log1m, inv_logit, log_inv_logit, log1m_exp, expm1, Phi, lgamma, abs, log1p_exp, log1m_exp, Base.inv, Base.log1p))...} => begin + Union{typeof.((sqrt, exp, log, log10, sin, cos, asin, acos, log1m, inv_logit, logit, log_inv_logit, log1m_exp, expm1, Phi, lgamma, abs, log1p_exp, log1m_exp, Base.inv, Base.log1p))...} => begin (real,)=>real (vector[n],)=>vector[n] (row_vector[n],)=>row_vector[n] (real[n],)=>real[n] (matrix[m,n],)=>matrix[m,n] end + Union{typeof.((log_sum_exp, ))...} => begin + (real, real) => real + (vector[n], vector[n]) => vector[n] + end typeof(÷) => begin (int, int) => int end diff --git a/src/slic_stan/slic.jl b/src/slic_stan/slic.jl index 6a9b4e3..b746c70 100644 --- a/src/slic_stan/slic.jl +++ b/src/slic_stan/slic.jl @@ -340,9 +340,17 @@ forward!(x::AssignmentExpr{Symbol}; info) = begin forward!(remake(x, name, rhs); info) end maybe_lazy_size(key::Symbol, i, sizei; info) = sizei -maybe_lazy_size(key::Symbol, i, ::StanExpr{<:CanonicalExpr}; info) = StanExpr( - forward!(CanonicalExpr(:getindex, CanonicalExpr(:dims, key), i); info), StanType(types.int) -) +is_simple_size(x::StanExpr) = is_simple_size(expr(x)) +is_simple_size(x::CanonicalExpr{<:Union{typeof.((+,-,*,÷))...}}) = all(is_simple_size, x.args) +is_simple_size(x::CanonicalExpr) = false +is_simple_size(x::Symbol) = true +is_simple_size(x::Number) = true +is_simple_size(x) = false#error(typeof(x)) +maybe_lazy_size(key::Symbol, i, sizei::StanExpr{<:CanonicalExpr}; info) = if is_simple_size(sizei) + sizei +else + forward!(canonical(:(dims($key)[$i])); info) +end forward!(x::AssignmentExpr{Symbol,<:StanExpr}; info) = begin name, rhs = x.args @assert name ∉ keys(info)