Skip to content

Implement factorised pointwise probabilities#1331

Draft
penelopeysm wants to merge 1 commit intomainfrom
py/pointwise
Draft

Implement factorised pointwise probabilities#1331
penelopeysm wants to merge 1 commit intomainfrom
py/pointwise

Conversation

@penelopeysm
Copy link
Copy Markdown
Member

@penelopeysm penelopeysm commented Mar 22, 2026

This is a proof-of-concept for #1038; it's not mergeable yet. I would want a few things to be changed for this, but the basic concept is there:

  • A unified API in PosteriorStats
  • Moving it to a PosteriorStats extension (PS brings in a lot of deps)
  • Implementing it for the chains methods
  • Tests
  • Docs

cc @sethaxen: Especially regarding PosteriorStats being a relatively heavy dep, is there any reasonable chance that this functionality could be extracted into a separate package that depends only on Distributions + basic stats stuff?

I wonder in general if turning PosteriorStats into a monorepo-like structure would quite useful for encouraging extensions? You could have a number of small repos that define units of functionality, and PosteriorStats could just be an aggregator that re-exports all of them. That would make it easy for other packages to hook into bits of it as needed, without having to pick up the deps from the other bits that they don't need.

The thing is that this is quite neat and I'd really like to have it in the main DynamicPPL package rather than an extension, since extensions have much poorer discoverability!

julia> using DynamicPPL, Distributions, LinearAlgebra

julia> @model function f(y)
           x ~ MvNormal(zeros(2), I)
           y ~ MvNormal(zeros(2), I)
       end
f (generic function with 2 methods)

julia> model = f([1.0, 1.0])
Model{typeof(f), (:y,), (), (), Tuple{Vector{Float64}}, Tuple{}, DefaultContext, false}(f, (y = [1.0, 1.0],), NamedTuple(), DefaultContext())

julia> vi = VarInfo(model)
VarInfo
 ├─ transform_strategy: UnlinkAll()
 ├─ values
 │  VarNamedTuple
 │  └─ x => VectorValue{Vector{Float64}, Bijectors.VectorBijectors.TypedIdentity}([0.3793074864362537, -1.0132794015923987], Bijectors.VectorBijectors.TypedIdentity())
 └─ accs
    AccumulatorTuple with 3 accumulators
    ├─ LogPrior => LogPriorAccumulator(-2.423181723888365)
    ├─ LogJacobian => LogJacobianAccumulator(0.0)
    └─ LogLikelihood => LogLikelihoodAccumulator(-2.8378770664093453)

julia> pointwise_logdensities(model, vi)
VarNamedTuple
├─ x => -2.423181723888365
└─ y => -2.8378770664093453

julia> pointwise_logdensities(model, vi; factorize=true)
VarNamedTuple
├─ x => [-0.9908756178379672, -1.4323061060503974]
└─ y => [-1.4189385332046727, -1.4189385332046727]

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Mar 22, 2026

Benchmark Report

  • this PR's head: aa2e6e430e324c81f4e90405ca54496bfc45cfb8
  • base branch: 34b8230ac30bb948798e58ec95adb65a9fbad4b4

Computer Information

Julia Version 1.11.9
Commit 53a02c0720c (2026-02-06 00:27 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬────────┬───────────────────────────────┬────────────────────────────┬─────────────────────────────────┐
│                       │       │             │        │       t(eval) / t(ref)        │     t(grad) / t(eval)      │        t(grad) / t(ref)         │
│                       │       │             │        │ ─────────┬──────────┬──────── │ ───────┬─────────┬──────── │ ──────────┬───────────┬──────── │
│                 Model │   Dim │  AD Backend │ Linked │     base │  this PR │ speedup │   base │ this PR │ speedup │      base │   this PR │ speedup │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│               Dynamic │    10 │    mooncake │   true │   281.26 │   278.66 │    1.01 │   7.55 │    8.17 │    0.92 │   2122.76 │   2276.12 │    0.93 │
│                   LDA │    12 │ reversediff │   true │  2523.71 │  2572.47 │    0.98 │   2.01 │    2.29 │    0.88 │   5068.95 │   5885.75 │    0.86 │
│   Loop univariate 10k │ 10000 │    mooncake │   true │ 29813.56 │ 31157.28 │    0.96 │   7.62 │    6.45 │    1.18 │ 227152.41 │ 200993.51 │    1.13 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│    Loop univariate 1k │  1000 │    mooncake │   true │  3027.21 │  3195.05 │    0.95 │   6.22 │   10.12 │    0.61 │  18815.99 │  32337.98 │    0.58 │
│      Multivariate 10k │ 10000 │    mooncake │   true │ 29758.38 │ 30836.28 │    0.97 │  10.17 │   10.13 │    1.00 │ 302671.57 │ 312358.30 │    0.97 │
│       Multivariate 1k │  1000 │    mooncake │   true │  3326.77 │  3367.13 │    0.99 │   9.28 │    9.34 │    0.99 │  30865.48 │  31446.08 │    0.98 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│ Simple assume observe │     1 │ forwarddiff │  false │     0.85 │     0.88 │    0.97 │  10.39 │   10.40 │    1.00 │      8.80 │      9.11 │    0.97 │
│           Smorgasbord │   201 │ forwarddiff │  false │   934.42 │   940.93 │    0.99 │  68.87 │   68.23 │    1.01 │  64350.14 │  64203.05 │    1.00 │
│           Smorgasbord │   201 │      enzyme │   true │  1245.90 │  1292.46 │    0.96 │   4.93 │    4.90 │    1.01 │   6136.90 │   6330.35 │    0.97 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │   true │  1256.54 │  1286.98 │    0.98 │  65.73 │   68.53 │    0.96 │  82593.15 │  88193.42 │    0.94 │
│           Smorgasbord │   201 │    mooncake │   true │  1250.76 │  1303.46 │    0.96 │   4.64 │    4.53 │    1.02 │   5805.57 │   5909.57 │    0.98 │
│           Smorgasbord │   201 │ reversediff │   true │  1236.97 │  1315.41 │    0.94 │ 126.68 │  123.27 │    1.03 │ 156702.85 │ 162148.55 │    0.97 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│              Submodel │     1 │    mooncake │   true │     0.85 │     0.97 │    0.88 │  27.01 │   25.41 │    1.06 │     22.89 │     24.60 │    0.93 │
└───────────────────────┴───────┴─────────────┴────────┴──────────┴──────────┴─────────┴────────┴─────────┴─────────┴───────────┴───────────┴─────────┘

@codecov
Copy link
Copy Markdown

codecov bot commented Mar 22, 2026

Codecov Report

❌ Patch coverage is 0% with 15 lines in your changes missing coverage. Please review.
✅ Project coverage is 42.38%. Comparing base (34b8230) to head (aa2e6e4).

Files with missing lines Patch % Lines
src/accumulators/pointwise_logdensities.jl 0.00% 15 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (34b8230) and HEAD (aa2e6e4). Click for more details.

HEAD has 6 uploads less than BASE
Flag BASE (34b8230) HEAD (aa2e6e4)
12 6
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1331       +/-   ##
===========================================
- Coverage   78.26%   42.38%   -35.89%     
===========================================
  Files          50       50               
  Lines        3566     3546       -20     
===========================================
- Hits         2791     1503     -1288     
- Misses        775     2043     +1268     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@sethaxen
Copy link
Copy Markdown
Member

cc @sethaxen: Especially regarding PosteriorStats being a relatively heavy dep, is there any reasonable chance that this functionality could be extracted into a separate package that depends only on Distributions + basic stats stuff?

Yeah, I think we can do that. Even in testing the functionality, I needed to implement much of the machinery needed to compute arbitrary conditional/marginal distributions. There have been longstanding issues open for each of these features in Distributions, and they probably ultimately belong there, but since that's not likely to happen soon, I think it could make sense to have a PartitionedDistributions.jl package that implements at least marginal, conditional, and pointwise_conditional_logpdfs (AKA pointwise_conditional_loglikelihoods). I could start by just putting the latter in such a package and register it and then add the rest of the features later.

@penelopeysm
Copy link
Copy Markdown
Member Author

That sounds great! I'm very happy to help out in whatever way I can :)

@sethaxen
Copy link
Copy Markdown
Member

sethaxen commented Mar 24, 2026

Once sethaxen/PartitionedDistributions.jl#5 is merged and there are some basic docs, I'll make a release and register it. The easier marginal/conditional implementations are already merged.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants