diff --git a/src/slic_stan/slic.jl b/src/slic_stan/slic.jl index 6a9b4e3..b746c70 100644 --- a/src/slic_stan/slic.jl +++ b/src/slic_stan/slic.jl @@ -340,9 +340,17 @@ forward!(x::AssignmentExpr{Symbol}; info) = begin forward!(remake(x, name, rhs); info) end maybe_lazy_size(key::Symbol, i, sizei; info) = sizei -maybe_lazy_size(key::Symbol, i, ::StanExpr{<:CanonicalExpr}; info) = StanExpr( - forward!(CanonicalExpr(:getindex, CanonicalExpr(:dims, key), i); info), StanType(types.int) -) +is_simple_size(x::StanExpr) = is_simple_size(expr(x)) +is_simple_size(x::CanonicalExpr{<:Union{typeof.((+,-,*,÷))...}}) = all(is_simple_size, x.args) +is_simple_size(x::CanonicalExpr) = false +is_simple_size(x::Symbol) = true +is_simple_size(x::Number) = true +is_simple_size(x) = false#error(typeof(x)) +maybe_lazy_size(key::Symbol, i, sizei::StanExpr{<:CanonicalExpr}; info) = if is_simple_size(sizei) + sizei +else + forward!(canonical(:(dims($key)[$i])); info) +end forward!(x::AssignmentExpr{Symbol,<:StanExpr}; info) = begin name, rhs = x.args @assert name ∉ keys(info)