Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
bf6022e
Implement PGAS with NUTS
THargreaves Feb 12, 2026
436ac64
Implement CSMC loop
THargreaves Feb 16, 2026
fe9c83c
Add ESS benchmarks to RB-PGAS test
THargreaves Feb 16, 2026
d51a0b0
Refactor backward simulation for new CSMC interface
THargreaves Feb 16, 2026
e78a30e
Correct Kalman gradients for H and c
THargreaves Feb 17, 2026
82682b2
Implement ChainRules and LogDensityProblems interface for SSM
THargreaves Feb 17, 2026
d82d212
Implement particle Gibbs
THargreaves Feb 17, 2026
9a6fb7d
Refactor augemented LGSSM code
THargreaves Feb 17, 2026
04d52b0
Implement Turing integration
THargreaves Feb 17, 2026
518ff99
Remove prototype/legacy code
THargreaves Feb 17, 2026
e01b0d5
Refactor PG to use arbitrary param sampler, update AHMC/AMH compat bo…
THargreaves Feb 19, 2026
1f64c8d
Unified duplicate elements of code-base, generalised for other analyt…
THargreaves Feb 19, 2026
f4bddfc
Correct CSMC-AS loglik output and add unit test
THargreaves Feb 19, 2026
a74471e
Refactor duplicate code
THargreaves Feb 19, 2026
98afb47
Formatting
THargreaves Feb 23, 2026
aa98556
Add example script
THargreaves Feb 23, 2026
f245460
Improve type stability
THargreaves Feb 23, 2026
100ac33
Fix formatting
THargreaves Feb 24, 2026
5a86410
Merge branch 'main' into th/pgas-nuts
THargreaves Feb 25, 2026
fd14793
Merge branch 'main' into th/pgas-nuts
THargreaves Mar 3, 2026
2de0c92
Add return statements to model definitions
THargreaves Mar 3, 2026
77a093e
Add simple PGAS example
THargreaves Mar 3, 2026
d6932af
Add alternative AD frameworks and compare to MH
THargreaves Mar 3, 2026
db7d20a
Reduce allocations in resampling
THargreaves Mar 3, 2026
918021e
Implement rrule for mooncake
THargreaves Mar 9, 2026
a465b65
Add Mooncake benchmarks
THargreaves Mar 9, 2026
3575468
Cleaned Mooncake integration
THargreaves Mar 9, 2026
ebfb8fe
Generalise Mooncake AD support to all AbstractPDMat
THargreaves Mar 9, 2026
fa2f02a
Add unit tests for Mooncake integration
THargreaves Mar 10, 2026
f3e480e
Format
THargreaves Mar 10, 2026
ac95f71
Add missing jitter to KF gradient
THargreaves Mar 10, 2026
4973637
Move Mooncake integration to extension
THargreaves Mar 10, 2026
0ec1da0
Create package extension for CUDA
THargreaves Mar 10, 2026
09da0b5
Make ForwardDiff a test dep
THargreaves Mar 10, 2026
25f1632
Format
THargreaves Mar 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 53 additions & 4 deletions GeneralisedFilters/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,47 +4,96 @@ version = "0.4.2"
authors = ["THargreaves <tim.hargreaves@icloud.com>", "Charles Knipp <charlesknipp98@gmail.com>", "FredericWantiez <frederic.wantiez@gmail.com>", "Hong Ge <hg344@cam.ac.uk>"]

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
CUDAExt = "CUDA"
MooncakeExt = "Mooncake"

[compat]
AbstractMCMC = "5"
ADTypes = "1.21.0"
AbstractMCMC = "5.9.0"
AcceleratedKernels = "0.3, 0.4"
AdvancedHMC = "0.8.3"
AdvancedMH = "0.8.9"
Aqua = "0.8"
Bijectors = "0.15.16"
CUDA = "5"
ChainRulesCore = "1.26.0"
DataStructures = "0.18.20, 0.19"
Distributions = "0.25"
DynamicPPL = "0.39.13"
ForwardDiff = "1.3.2"
LogDensityProblems = "2.2.0"
LogDensityProblemsAD = "1.13.1"
LogExpFunctions = "0.3"
MCMCChains = "7.7.0"
Mooncake = "0.5.13"
OffsetArrays = "1.14.1"
PDMats = "0.11.35"
SSMProblems = "0.6"
StaticArrays = "1.9.17"
Statistics = "1"
StatsBase = "0.34.3"
Test = "1"
Turing = "0.42.6"
Zygote = "0.7.10"
julia = "1.10"

[extras]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "FiniteDifferences", "PDMats", "StableRNGs", "Test", "TestItemRunner", "TestItems", "JET"]
test = [
"AdvancedHMC",
"AdvancedMH",
"Aqua",
"CUDA",
"FiniteDifferences",
"ForwardDiff",
"JET",
"Mooncake",
"PDMats",
"StableRNGs",
"Test",
"TestItemRunner",
"TestItems",
"Zygote",
]
15 changes: 15 additions & 0 deletions GeneralisedFilters/examples/PGAS Example/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GeneralisedFilters = "3ef92589-7ab8-43f9-b5b9-a3a0c86ecbb7"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
97 changes: 97 additions & 0 deletions GeneralisedFilters/examples/PGAS Example/rb_ssm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
using GeneralisedFilters
using AbstractMCMC: AbstractMCMC
using AdvancedHMC
using ADTypes: ADTypes
using MCMCChains: MCMCChains
using Turing: @model
using Distributions
using PDMats
using LinearAlgebra
using Random
using Statistics
using SSMProblems
using StaticArrays
using Zygote
using Mooncake

rng = MersenneTwister(1234)

a = 0.8
c_val = 0.5 # coupling: inner offset = b + c_val * outer_state
q² = 0.1
r² = 0.5
σ₀² = 1.0
σ_b² = 4.0
T_len = 100
N_particles = 50
N_iter = 50
N_adapts = 10

# HierarchicalSSM with inner drift b:
# outer state — AR(1), sampled via particles
# inner state — AR(1) with offset b + c_val * outer_state, marginalised via KF
# observations — from inner state only
#
# All constant parts are pre-built so that PDMat constructors never appear in the
# Zygote trace — only getfield accesses, which Zygote handles natively.
const _fixed_ssm = let
outer_prior = HomogeneousGaussianPrior(SVector{1}(0.0), PDMat(SMatrix{1,1}(σ₀²)))
outer_dyn = HomogeneousLinearGaussianLatentDynamics(
SMatrix{1,1}(a), SVector{1}(0.0), PDMat(SMatrix{1,1}(q²))
)
inner_prior = HomogeneousGaussianPrior(SVector{1}(0.0), PDMat(SMatrix{1,1}(σ₀²)))
inner_dyn = GeneralisedFilters.GFTest.InnerDynamics(
SMatrix{1,1}(a),
SVector{1,Float64}([0.0]),
SMatrix{1,1}(c_val),
PDMat(SMatrix{1,1}(q²)),
)
inner_obs = HomogeneousLinearGaussianObservationProcess(
SMatrix{1,1}(1.0), SVector{1}(0.0), PDMat(SMatrix{1,1}(r²))
)
HierarchicalSSM(outer_prior, outer_dyn, inner_prior, inner_dyn, inner_obs)
end

function build_ssm_rb(b)
dyn = _fixed_ssm.inner_model.dyn
new_inner_dyn = GeneralisedFilters.GFTest.InnerDynamics(
dyn.A, SVector{1,Float64}(b), dyn.C, dyn.Q
)
return HierarchicalSSM(
_fixed_ssm.outer_prior,
_fixed_ssm.outer_dyn,
_fixed_ssm.inner_model.prior,
new_inner_dyn,
_fixed_ssm.inner_model.obs,
)
end

true_b = 1.5
_, _, _, _, ys = AbstractMCMC.sample(rng, build_ssm_rb([true_b]), T_len)

@model function drift_model_rb(ys)
b ~ MvNormal([0.0], σ_b² * I)
ssm = build_ssm_rb(b)
x ~ SSMTrajectory(ssm, KF(), ys)
return nothing
end

m = drift_model_rb(ys)
param_sampler = HMC(0.01, 10)
adtype = ADTypes.AutoZygote()
# adtype = ADTypes.AutoMooncake()
pg = ParticleGibbs(CSMCAS(RBPF(BF(N_particles), KF())), param_sampler; adtype=adtype)

chain = AbstractMCMC.sample(
rng, m, pg, N_iter; n_adapts=N_adapts, progress=true, chain_type=MCMCChains.Chains
)

@profview begin
chain = AbstractMCMC.sample(
rng, m, pg, N_iter; n_adapts=N_adapts, progress=true, chain_type=MCMCChains.Chains
)
end

@benchmark AbstractMCMC.sample(
rng, m, pg, N_iter; n_adapts=N_adapts, progress=true, chain_type=MCMCChains.Chains
)
68 changes: 68 additions & 0 deletions GeneralisedFilters/examples/PGAS Example/regular_ssm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using GeneralisedFilters
using AbstractMCMC: AbstractMCMC
using AdvancedHMC: NUTS
using MCMCChains: MCMCChains
using Turing: @model
using Distributions
using PDMats
using LinearAlgebra
using Random
using Statistics
using SSMProblems
using StaticArrays
using Zygote

rng = MersenneTwister(1234)

a = 0.8
q² = 0.1
r² = 0.5
σ₀² = 1.0
σ_b² = 4.0
T_len = 100
N_particles = 50
N_iter = 5000
N_adapts = 1000

function build_ssm_reg(drift)
return create_homogeneous_linear_gaussian_model(
SVector{1}(0.0),
PDMat(SMatrix{1,1}(σ₀²)),
SMatrix{1,1}(a),
SVector{1}(only(drift)),
PDMat(SMatrix{1,1}(q²)),
SMatrix{1,1}(1.0),
SVector{1}(0.0),
PDMat(SMatrix{1,1}(r²)),
)
end

true_b = 1.5
true_ssm = build_ssm_reg([true_b])
_, _, ys = SSMProblems.sample(rng, true_ssm, T_len)

@model function drift_model_reg(ys)
b ~ MvNormal([0.0], σ_b² * I)
ssm = build_ssm_reg(b)
x ~ SSMTrajectory(ssm, ys)
return nothing
end

m = drift_model_reg(ys)
# pg = ParticleGibbs(CSMC(BF(N_particles)), NUTS(0.8))
pg = ParticleGibbs(CSMCAS(BF(N_particles)), NUTS(0.8))
# pg = ParticleGibbs(CSMCBS(BF(N_particles)), NUTS(0.8))

chain = AbstractMCMC.sample(
rng, m, pg, N_iter; n_adapts=N_adapts, progress=false, chain_type=MCMCChains.Chains
)

display(
@benchmark AbstractMCMC.sample(
rng, m, pg, N_iter; n_adapts=N_adapts, progress=false, chain_type=MCMCChains.Chains
)
)

@profview AbstractMCMC.sample(
rng, m, pg, N_iter; n_adapts=N_adapts, progress=false, chain_type=MCMCChains.Chains
)
Loading
Loading