@observe x ~ dist #1270

@penelopeysm

This sort of model doesn't work "as intended" (one of the oldest remaining issues, #519):

@model function f()
    x ~ Normal()
    y = x + 1
    y ~ Normal()
end

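Because y is not an argument of the model, `y ~ Normal()` is treated as an assume statement rather than an observation: y becomes a random variable, and the value x + 1 assigned on the previous line is not conditioned on.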

I think it's fine to say that this is how DynamicPPL works, but the question is what workarounds we have for it right now. The main solution is to use @addlogprob!:

@model function f()
    x ~ Normal()
    y = x + 1
    @addlogprob! logpdf(Normal(), y)
end

While this works correctly for log-density tracking purposes, it's not semantically equivalent to a true observation y ~ Normal(). In particular, @addlogprob! only triggers accumulation for the log-likelihood accumulator, and any other accumulators are ignored. That means that any custom behaviour in accumulate_observe!! will not work.
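
To make the distinction concrete, here is a toy sketch in plain Julia. It is not DynamicPPL code: every name in it (`ToyLogLik`, `ToyPointwise`, `toy_observe!`, `toy_addlogprob!`) is made up for illustration, and it uses mutable structs for brevity where DynamicPPL's accumulators follow a `!!` convention. It only mimics the shape of the problem: a true observation notifies every accumulator, while the workaround touches just the log-likelihood total.

```julia
# Toy model of the accumulator idea. None of these names are DynamicPPL API.
abstract type ToyAccumulator end

mutable struct ToyLogLik <: ToyAccumulator
    total::Float64
end

mutable struct ToyPointwise <: ToyAccumulator
    values::Dict{Symbol,Float64}
end

# Each accumulator decides for itself what an observation means to it.
toy_accumulate_observe!(acc::ToyLogLik, name, logp) = (acc.total += logp; acc)
toy_accumulate_observe!(acc::ToyPointwise, name, logp) = (acc.values[name] = logp; acc)

# Analogue of a true observation: every accumulator is notified.
function toy_observe!(accs, name, logp)
    foreach(acc -> toy_accumulate_observe!(acc, name, logp), accs)
    return accs
end

# Analogue of the @addlogprob! workaround: only the log-likelihood total changes.
function toy_addlogprob!(accs, logp)
    for acc in accs
        acc isa ToyLogLik && (acc.total += logp)
    end
    return accs
end

accs_a = (ToyLogLik(0.0), ToyPointwise(Dict{Symbol,Float64}()))
toy_addlogprob!(accs_a, -1.5)   # total updated, pointwise dict stays empty

accs_b = (ToyLogLik(0.0), ToyPointwise(Dict{Symbol,Float64}()))
toy_observe!(accs_b, :y, -1.5)  # total updated and :y recorded
```

Both paths leave the same log-likelihood total, which is why the workaround is fine for log-density tracking; only the second path leaves anything for a pointwise-style accumulator to report.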

For most accumulators in DynamicPPL, we actually define accumulate_observe!!(acc, ...) = acc, i.e., the accumulators don't track observed variables. However, there are a couple of exceptions. One is the accumulator used by pointwise_loglikelihoods, which does in fact hook into accumulate_observe!!. That means that for the model above, we can't extract any loglikelihoods:

pointwise_loglikelihoods(f(), VarInfo(f()))
# OrderedDict{VarName, Float64}()

whereas with a 'true' observe statement, we can:

@model function g(y)
    x ~ Normal()
    y ~ Normal()
end
pointwise_loglikelihoods(g(1.0), VarInfo(g(1.0)))
# OrderedDict{VarName, Float64} with 1 entry:
#   y => -1.41894
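
As a sanity check on that number: for a standard normal, the log-density at y is -0.5 * log(2π) - 0.5 * y^2. A minimal sketch in base Julia follows; the helper name `std_normal_logpdf` is made up for illustration and is not part of Distributions.jl or DynamicPPL.

```julia
# Hypothetical helper: log-density of Normal(0, 1) at y, written out by hand.
std_normal_logpdf(y) = -0.5 * log(2π) - 0.5 * y^2

std_normal_logpdf(1.0)  # ≈ -1.41894, matching the entry for y above
```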

I think the right solution is to have something that forces a tilde-statement to be treated as an observation. In essence, @addlogprob! expands to just a call to accumulate_observe!! on one single accumulator (the log-likelihood one), but we should have something that expands to DynamicPPL.tilde_observe!!(...), which will in turn behave like a true observation.

I threw this together (it took me way longer than I would like to admit):

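# Assumed setup: these packages need to be loaded for the snippet to run.
using DynamicPPL, AbstractPPL, Distributions

macro observe(expr)
    return _observe(expr)
end
function _observe(expr)
    @gensym dist left_val val
    # Only accept expressions of the form `var ~ dist`.
    if Meta.isexpr(expr, :call) && length(expr.args) == 3 && expr.args[1] == :~
        left_arg, right_arg = expr.args[2], expr.args[3]
        # Expression that constructs the VarName for the left-hand side.
        vn = AbstractPPL.varname(left_arg, false)
        return esc(quote
            $left_val = $left_arg
            $dist = $right_arg
            # Go through tilde_observe!! so that this behaves like a true observation.
            $val, __varinfo__ = $(DynamicPPL.tilde_observe!!)(__model__.context, $dist, $left_val, $vn, __varinfo__)
            $val
        end)
    else
        error("@observe expects an expression of the form `var ~ dist`.")
    end
end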

@model function fnew()
    x ~ Normal()
    y = x + 1
    @observe y ~ Normal()
end

and now you get

julia> pointwise_loglikelihoods(fnew(), VarInfo(fnew()))
OrderedDict{VarName, Float64} with 1 entry:
  y => -1.52277
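
For anyone reading the macro, the pattern match in `_observe` relies on how Julia parses a tilde statement. The following base-Julia sketch (no packages needed) shows the expression structure that the `Meta.isexpr` check and the `expr.args` indexing depend on:

```julia
# Base Julia only: inspect how the parser represents a tilde statement.
expr = :(y ~ Normal())

Meta.isexpr(expr, :call)   # true: `~` is parsed as an ordinary call
expr.args[1] == :~         # true: the "function" being called is `~`
expr.args[2]               # :y, the left-hand side
expr.args[3]               # :(Normal()), the right-hand side

# A plain assignment has a different head, so it would hit the error branch.
Meta.isexpr(:(y = 1), :call)  # false
```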

Labels: modelling-syntax (`@model`, `~`, and associated syntax)