Skip to content

Add compile-time options to change default LogProbType#1328

Merged
penelopeysm merged 11 commits intomainfrom
py/f32
Mar 24, 2026
Merged

Add compile-time options to change default LogProbType#1328
penelopeysm merged 11 commits intomainfrom
py/f32

Conversation

@penelopeysm
Copy link
Copy Markdown
Member

@penelopeysm penelopeysm commented Mar 21, 2026

With this PR:

julia> using DynamicPPL

julia> set_logprob_type!(Float32) # Then restart session
[ Info: DynamicPPL's log probability type has been set to Float32; please note you will need to restart your Julia session for this change to take effect.

julia> using DynamicPPL, LogDensityProblems, ForwardDiff, Distributions, ADTypes

julia> @model function f()
           x ~ Normal(0.0f0, 1.0f0)
       end
f (generic function with 2 methods)

julia> typeof(rand(f())[@varname(x)])
Float32

julia> typeof(logjoint(f(), (; x=0.0f0)))
Float32

julia> ldf = LogDensityFunction(f(), getlogjoint_internal, LinkAll(); adtype=AutoForwardDiff());

julia> typeof(rand(ldf))
Vector{Float32} (alias for Array{Float32, 1})

julia> typeof(LogDensityProblems.logdensity(ldf, [0.0f0]))
Float32

julia> typeof(LogDensityProblems.logdensity_and_gradient(ldf, [0.0f0]))
Tuple{Float32, Vector{Float32}}

Prior to this PR, logjoint and LogDensityProblems.logdensity would return Float64 instead. See TuringLang/Turing.jl#2212.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Mar 21, 2026

Benchmark Report

  • this PR's head: 9f2975d2c43d2f7281af8625b605c57e137cb689
  • base branch: 34b8230ac30bb948798e58ec95adb65a9fbad4b4

Computer Information

Julia Version 1.11.9
Commit 53a02c0720c (2026-02-06 00:27 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, icelake-server)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬────────┬───────────────────────────────┬────────────────────────────┬─────────────────────────────────┐
│                       │       │             │        │       t(eval) / t(ref)        │     t(grad) / t(eval)      │        t(grad) / t(ref)         │
│                       │       │             │        │ ─────────┬──────────┬──────── │ ───────┬─────────┬──────── │ ──────────┬───────────┬──────── │
│                 Model │   Dim │  AD Backend │ Linked │     base │  this PR │ speedup │   base │ this PR │ speedup │      base │   this PR │ speedup │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│               Dynamic │    10 │    mooncake │   true │   282.98 │   280.26 │    1.01 │   7.72 │    8.13 │    0.95 │   2185.29 │   2277.45 │    0.96 │
│                   LDA │    12 │ reversediff │   true │  3485.55 │  3306.00 │    1.05 │   3.29 │    1.95 │    1.69 │  11483.28 │   6439.73 │    1.78 │
│   Loop univariate 10k │ 10000 │    mooncake │   true │ 30908.25 │ 30837.72 │    1.00 │   6.36 │    6.31 │    1.01 │ 196476.77 │ 194681.42 │    1.01 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│    Loop univariate 1k │  1000 │    mooncake │   true │  3144.22 │  3144.12 │    1.00 │   6.99 │    6.20 │    1.13 │  21964.97 │  19501.99 │    1.13 │
│      Multivariate 10k │ 10000 │    mooncake │   true │ 31896.32 │ 31989.73 │    1.00 │   9.86 │    9.73 │    1.01 │ 314375.20 │ 311353.11 │    1.01 │
│       Multivariate 1k │  1000 │    mooncake │   true │  3428.76 │  3417.58 │    1.00 │   9.31 │    9.30 │    1.00 │  31927.86 │  31792.76 │    1.00 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│ Simple assume observe │     1 │ forwarddiff │  false │     0.88 │     1.01 │    0.87 │  10.50 │    9.08 │    1.16 │      9.22 │      9.14 │    1.01 │
│           Smorgasbord │   201 │ forwarddiff │  false │   960.16 │   933.92 │    1.03 │  71.50 │   71.84 │    1.00 │  68650.05 │  67091.60 │    1.02 │
│           Smorgasbord │   201 │      enzyme │   true │  1297.65 │  1292.76 │    1.00 │   6.47 │    4.84 │    1.34 │   8395.18 │   6258.34 │    1.34 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │   true │  1296.15 │  1291.82 │    1.00 │  67.36 │   67.71 │    0.99 │  87306.86 │  87475.17 │    1.00 │
│           Smorgasbord │   201 │    mooncake │   true │  1301.63 │  1298.80 │    1.00 │   4.68 │    4.75 │    0.99 │   6090.86 │   6167.60 │    0.99 │
│           Smorgasbord │   201 │ reversediff │   true │  1296.65 │  1282.38 │    1.01 │ 127.53 │  130.76 │    0.98 │ 165365.71 │ 167686.37 │    0.99 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│              Submodel │     1 │    mooncake │   true │     0.88 │     0.87 │    1.01 │  27.04 │   26.88 │    1.01 │     23.71 │     23.40 │    1.01 │
└───────────────────────┴───────┴─────────────┴────────┴──────────┴──────────┴─────────┴────────┴─────────┴─────────┴───────────┴───────────┴─────────┘

@codecov
Copy link
Copy Markdown

codecov bot commented Mar 21, 2026

Codecov Report

❌ Patch coverage is 34.48276% with 19 lines in your changes missing coverage. Please review.
✅ Project coverage is 77.93%. Comparing base (34b8230) to head (613dc1d).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
src/utils.jl 25.00% 18 Missing ⚠️
src/transformed_values.jl 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1328      +/-   ##
==========================================
- Coverage   78.26%   77.93%   -0.34%     
==========================================
  Files          50       50              
  Lines        3566     3585      +19     
==========================================
+ Hits         2791     2794       +3     
- Misses        775      791      +16     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@penelopeysm

This comment was marked as outdated.

@github-actions
Copy link
Copy Markdown
Contributor

DynamicPPL.jl documentation for PR #1328 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1328/

@penelopeysm penelopeysm changed the title Fix propagation of non-Float64 floats Add compile-time options to change default LogProbType Mar 23, 2026
@penelopeysm penelopeysm changed the base branch from breaking to main March 23, 2026 21:29
@penelopeysm penelopeysm merged commit a1e8f06 into main Mar 24, 2026
20 of 22 checks passed
@penelopeysm penelopeysm deleted the py/f32 branch March 24, 2026 10:29
penelopeysm added a commit to TuringLang/Turing.jl that referenced this pull request Mar 26, 2026
…-export `DynamicPPL.set_logprob_type!` (#2794)

Closes #2739.

As a nice by-product of using `rand(ldf)` rather than `vi[:]`, we also
avoid accidentally promoting Float32 to Float64. This means that
(together with TuringLang/DynamicPPL.jl#1328 and
tpapp/DynamicHMC.jl#199) one can do

```julia
julia> using DynamicPPL; DynamicPPL.set_logprob_type!(Float32)
┌ Info: DynamicPPL's log probability type has been set to Float32.
└ Please note you will need to restart your Julia session for this change to take effect.
```

and then after restarting

```julia
julia> using Turing, FlexiChains, DynamicHMC

julia> @model function f()
           x ~ Normal(0.0f0, 1.0f0)
       end
f (generic function with 2 methods)

julia> chn = sample(f(), externalsampler(DynamicHMC.NUTS()), 100; chain_type=VNChain)
Sampling 100%|████████████████████████████████████████████| Time: 0:00:02
FlexiChain (100 iterations, 1 chain)
↓ iter=1:100 | → chain=1:1

Parameter type   VarName
Parameters       x
Extra keys       :logprior, :loglikelihood, :logjoint


julia> eltype(chn[@varname(x)])
Float32

julia> eltype(chn[:logjoint])
Float32
```

(Previously, the values of `x` would be Float32, but logjoint would be
Float64. And if you used MCMCChains, everything would be Float64.)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant