diff --git a/scripts/inference/sir.jl b/scripts/inference/sir.jl new file mode 100644 index 0000000..9ad3f74 --- /dev/null +++ b/scripts/inference/sir.jl @@ -0,0 +1,123 @@ +SRC_DIR = joinpath(Base.source_dir(), "..", "..", "src") +OUT_DIR = joinpath(Base.source_dir(), "..", "..", "output") +mkpath(OUT_DIR) +using Makie: lines, lines! + +include(joinpath(SRC_DIR, "BridgeSDEInference_for_tests.jl")) + +using StaticArrays +using Distributions +using Random +# Let's generate the data +# ----------------------- +using Bridge +#import BridgeSDEInference: CIR, CIRaux +DIR = "auxiliary" +include(joinpath(SRC_DIR, DIR, "data_simulation_fns.jl")) +include(joinpath(SRC_DIR, DIR, "utility_functions.jl")) +Random.seed!(2) +pop = 50_000_000 +#θˣ = [0.37, 0.05, 0.05, 0.01] +#α, β, σ1, σ2 +θˣ = [0.37, 0.05, 0.3/sqrt(pop), 1.0/sqrt(pop)] + +Pˣ = SIR(θˣ...) + +x0, dt, T = ℝ{2}(1/pop, 0.), 1/10000, 10.0 +tt = 0.0:dt:T + +Random.seed!(1) +XX, _ = simulate_segment(ℝ{2}(1.0, 0.0), x0, Pˣ, tt) +last(XX)[2]*pop + +#lines(XX.tt, K .- sum.(XX.yy)) +lines(XX.tt, first.(XX.yy), color = :red) #infected +lines!(XX.tt,last.(XX.yy), color = :blue) #recovered + dead people + +θ_init = [0.5, 0.1, 0.3/sqrt(pop), 1.0/sqrt(pop)] +Pˣ = SIR(θ_init...) + +length(XX.tt) +skip = 2000 + +#Σdiagel = +Σ = @SMatrix[1.0 0.0; 0.0 0.5]/pop +L = @SMatrix[1.0 0.0; 0.0 1.0] + +obs_time, obs_vals = XX.tt[1:skip:end], [rand(Gaussian(L*x, Σ)) for x in XX.yy[1:skip:end]] + +days = [0.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0] +cases = [1.0 0.0; 62.0 0.0; 121.0 0.0; 198.0 3.0; 291.0 6.0; 440.0 9.0; 571.0 17.0; 830.0 25.0; 1287.0 41.0; 1975.0 56.0; 2744.0 80.0; 4515.0 106.0; 5974.0 132.0; 7771.0 170.0] + +if true + obs_time = days + obs_vals = 1/pop*reinterpret(SVector{2,Float64}, [1 -1; 0 1]*cases') # first coordinate is I + R +end +obs_vals[end] +P̃ = [SIRAux(θ_init..., t₀, u, T, v) for (t₀, T, u, v) + in zip(obs_time[1:end-1], obs_time[2:end], obs_vals[1:end-1], obs_vals[2:end])] + +model_setup = DiffusionSetup(Pˣ, P̃, PartObs()) +set_observations!(model_setup, [L for _ in P̃], [Σ for _ in P̃], obs_vals, obs_time) # uses default fpt +set_imputation_grid!(model_setup, 1/1000) +set_x0_prior!(model_setup, + GsnStartingPt(x0, @SMatrix [0.01 0.0; + 0.0 0.01;]), + x0) +set_auxiliary!(model_setup; skip_for_save=10^0, + adaptive_prop=NoAdaptation()) +initialise!(eltype(x0), model_setup, Vern7(), false, NoChangePt(100)) +#:step, :scale, :min, :max, :trgt, :offset +readj = (100, 0.001, 0.001, 999.9, 0.4, 50) +readj2 = (100, 0.01, 0.05, 0.2, 0.7, 50) + + +mcmc_setup = MCMCSetup( + Imputation(NoBlocking(), 0.99, Vern7()), + ParamUpdate(ConjugateUpdt(), [1,2], θ_init, nothing, + MvNormal(fill(0.05, 2), diagm(0=>fill(1000.0, 2))), + UpdtAuxiliary(Vern7(), check_if_recompute_ODEs(P̃, [1,2]))), + + ) + +schedule = MCMCSchedule(10^4, [[1],[2]], #[[1],[2], [5]], + (save=10^2, verbose=10^2, warm_up=100, + readjust=(x->x%100==0), fuse=(x->false))) + +Random.seed!(4) +out = mcmc(mcmc_setup, schedule, model_setup) +error("STOP HERE") + + +using Makie +ws = out[2] + + +θs = ws.θ_chain +±(a, b) = a - b, a + b +beta, gamma, s1 = [median(getindex.(θs, i)) for i in [1,2,3]] .± [std(getindex.(θs, i)) for i in [1,2, 3]] +R0 = mean(getindex.(θs, 1)./getindex.(θs, 2)) ± std(getindex.(θs, 1)./getindex.(θs, 2)) + +θs = ws.θ_chain + + + +@show beta +@show gamma +@show s1 +@show R0 + +lines(getindex.(θs, 1)) +lines(getindex.(θs, 2)) + + + +lines(getindex.(θs, 3)) + +include(joinpath(SRC_DIR, DIR, "plotting_fns.jl")) +plot_chains(ws; truth=θˣ) +#= +plot_paths(out[1], out[2], schedule; obs=(times=obs_time[2:end], + vals=[[v[1] for v in obs_vals[2:end]], + [v[2] for v in obs_vals[2:end]]], indices=[2,3])) +=# diff --git a/scripts/nextjournal/domserv.jl b/scripts/nextjournal/domserv.jl new file mode 100644 index 0000000..3af921b --- /dev/null +++ b/scripts/nextjournal/domserv.jl @@ -0,0 +1,165 @@ +using JSServe, WGLMakie, AbstractPlotting +using JSServe: JSServe.DOM, @js_str, onjs +global three, scene + +using StaticArrays +using Colors +using Random +rebirth(α, R) = x -> (rand() > α ? x : (2rand(typeof(x)) .- 1).*R) +const 𝕏 = SVector + +using FileIO +using Makie: band +using GLMakie + +using Hyperscript, Markdown +using JSServe, Observables +using JSServe: Application, Session, evaljs, linkjs, div, active_sessions, Asset +using JSServe: @js_str, onjs, Button, TextField, Slider, JSString, Dependency, with_session +using JSServe.DOM + +function dom_handler(session, request) + global three, scene + + if @isdefined dom + dom != nothing && return dom + end + + WGLMakie.activate!() + + + # slider and field for sigma + sliders = JSServe.Slider(0.01:0.01:1) + nrs = JSServe.NumberInput(0.0) + linkjs(session, sliders.value, nrs.value) + + # time wheel ;-) + button = JSServe.Slider(1:109) + + # init + R = 𝕏(1.5,6.0) + R1, R2 = R + limits = FRect(-R[1], -R[2], 2R[1], 2R[2]) + n = 800 + K = 80 + dt = 0.001 + sqrtdt = sqrt(dt) + + particlecss = Asset(joinpath(@__DIR__,"particle.css")) + global sliderbg = Asset(joinpath(@__DIR__,"slider1.png")) + + ms = 0.03 + global scene = scatter(repeat(2randn(n), outer=K), repeat(2randn(n),outer=K), color = fill(:white, n*K), + backgroundcolor = RGB{Float32}(0.04, 0.11, 0.22), markersize = ms, + glowwidth = 0.005, glowcolor = :white, + resolution=(600,600), limits = limits, + ) + axis = scene[Axis] + axis[:grid, :linewidth] = (0.3, 0.3) + axis[:grid, :linecolor] = (RGBA{Float32}(0.5, 0.7, 1.0, 0.3),RGBA{Float32}(0.5, 0.7, 1.0, 0.3)) + axis[:names][:textsize] = (0.0,0.0) + axis[:ticks, :textcolor] = (RGBA{Float32}(0.5, 0.7, 1.0, 0.5),RGBA{Float32}(0.5, 0.7, 1.0, 0.5)) + + + splot = scene[end] + scatter!(scene, -R1:0.01:R1, sin.(-R1:0.01:R1), color = RGBA{Float32}(255, 0.0, 4.0, 1.0), markersize=ms) + kplot = scene[end] + + three, canvas = WGLMakie.three_display(session, scene) + js_scene = WGLMakie.to_jsscene(three, scene) + mesh = js_scene.getObjectByName(string(objectid(splot))) + mesh2 = js_scene.getObjectByName(string(objectid(kplot))) + + # init javascript + evaljs(session, js""" + console.log("Hello"); + iter = 1; + si = 0.0; + R1 = $(R1); + R2 = $(R2); + updatekline = function (value){ + si = value; + var mesh = $(mesh2); + var positions = mesh.geometry.attributes.offset.array; + for ( var i = 0, l = positions.length; i < l; i += 2 ) { + positions[i+1] = si*Math.sin(positions[i]); + } + mesh.geometry.attributes.offset.needsUpdate = true; + } + setInterval( + function (){ + function randn_bm() { + var u = 0, v = 0; + while(u === 0) u = Math.random(); //Converting [0,1) to (0,1) + while(v === 0) v = Math.random(); + return Math.sqrt( -2.0 * Math.log( u ) ) * Math.cos( 2.0 * Math.PI * v ); + } + var mu = 0.2; + var mesh = $(mesh); + var K = $(K); + var n = $(n); + var dt = $(dt); + console.log(iter++); + var sqrtdt = $(sqrtdt); + + k = iter%K; + var positions = mesh.geometry.attributes.offset.array; + var color = mesh.geometry.attributes.color.array; + console.log(color.length); + for ( var i = 0; i < n; i++ ) { + inew = k*2*n + 2*i; + iold = ((K + k - 1)%K)*2*n + 2*i; + positions[inew] = (1 - mu*dt)*positions[iold] - 3*dt*positions[iold+1] + si*sqrtdt*randn_bm(); // x + positions[inew+1] = (1 - mu*dt)*positions[iold+1] + 3*dt*positions[iold] + si*sqrtdt*randn_bm(); + color[k*4*n + 4*i] = 1.0; + color[k*4*n + 4*i + 1] = 1.0; + color[k*4*n + 4*i + 2] = 1.0; + color[k*4*n + 4*i + 3] = 1.0; + if (Math.random() < 0.01) + { + positions[inew] = (2*Math.random()-1)*R1; + positions[inew+1] = (2*Math.random()-1)*R2; + } + + } + for ( var k = 0; k < K; k++ ) { + for ( var i = 0; i < n; i++ ) { + color[k*4*n + 4*i + 3] = 0.98*color[k*4*n + 4*i + 3]; + } + } + mesh.geometry.attributes.color.needsUpdate = true; + mesh.geometry.attributes.offset.needsUpdate = true; + + } + , 50); + """) + onjs(session, sliders.value, js"""function (value){ + updatekline(value); + }""") + sliderbgurl = JSServe.url(sliderbg) + style_obs = Observable(" background-repeat: no-repeat; background-image: url($sliderbgurl);") + + global dom = DOM.div(particlecss, DOM.p(canvas), DOM.p("Parameters"), DOM.div(sliders, id="slider", style=style_obs), + DOM.p(nrs)) + println("running...") + dom +end + + +app = JSServe.Application( + dom_handler, + get(ENV, "WEBIO_SERVER_HOST_URL", "127.0.0.1"), + parse(Int, get(ENV, "WEBIO_HTTP_PORT", "8081")), + verbose = false +) +function cl() + close(app) + global dom = nothing +end +println("Done.") +# +if false + +cl() + +end diff --git a/scripts/nextjournal/domserv_fhn.jl b/scripts/nextjournal/domserv_fhn.jl new file mode 100644 index 0000000..cd2e203 --- /dev/null +++ b/scripts/nextjournal/domserv_fhn.jl @@ -0,0 +1,259 @@ +using JSServe, WGLMakie, AbstractPlotting +using JSServe: JSServe.DOM, @js_str, onjs +global three, scene + +using Colors +using Random +using WGLMakie: scatter, scatter! + +# fallback values if statistical analysis is not run +if !@isdefined thetas + thetas = [0.1, -0.8, 1.5, 0.0, 0.3] +end +if !@isdefined thetalims + thetalims = [(0.0,1.0), (-5.0,5.0), (0.0,10.0), (0.0,10.0), (0.0,1.0)] +end +WGLMakie.activate!() + +using Hyperscript, Markdown +using JSServe, Observables +using JSServe: Application, Session, evaljs, linkjs, div, active_sessions, Asset +using JSServe: @js_str, onjs, Button, TextField, Slider, JSString, Dependency, with_session +using JSServe.DOM + +function dom_handler(session, request) + global three, scene + + # fetch initial parameter (initial slider settings) + eps = thetas[1]; + s = thetas[2]; + gamma = thetas[3]; + beta = thetas[4]; + si = thetas[5]; + rebirth = 0.001; # how often a particle is "reborn" at random position + sl = 101 # slider sub-divisions + + # load histogram images to use as slider background + sliderbg = [JSServe.Asset(), JSServe.Asset(), JSServe.Asset(), JSServe.Asset(), JSServe.Asset()] + + # slider and field for sigma + slider5 = JSServe.Slider(range(thetalims[5]..., length=sl), si) + nrs5 = JSServe.NumberInput(si) + linkjs(session, slider5.value, nrs5.value) + + # slider and field for beta + slider4 = JSServe.Slider(range(thetalims[4]..., length=sl), beta) + nrs4 = JSServe.NumberInput(beta) + linkjs(session, slider4.value, nrs4.value) + + # slider and field for gamma + slider3 = JSServe.Slider(range(thetalims[3]..., length=sl), gamma) + nrs3 = JSServe.NumberInput(gamma) + linkjs(session, slider3.value, nrs3.value) + + # slider and field for s + slider2 = JSServe.Slider(range(thetalims[2]..., length=sl), s) + nrs2 = JSServe.NumberInput(s) + linkjs(session, slider2.value, nrs2.value) + + # slider and field for eps + slider1 = JSServe.Slider(range(thetalims[1]..., length=sl), eps) + nrs1 = JSServe.NumberInput(eps) + linkjs(session, slider1.value, nrs1.value) + + # slider and field for rebirth + slider6 = JSServe.Slider(0.0:0.0001:0.005, rebirth) + nrs6 = JSServe.NumberInput(rebirth) + linkjs(session, slider6.value, nrs6.value) + + # init + R = (1.5, 3.0) # plot area + R1, R2 = R + limits = FRect(-R[1], -R[2], 2R[1], 2R[2]) + n = 400 # no of particles + K = 150 # display K past positions of particle fading out + dt = 0.0005 # time step + sqrtdt = sqrt(dt) + + particlecss = JSServe.Asset() # style sheet + ms1 = 0.02 # markersize particles + ms2 = 0.02 # markersize isokline + + # plot particles, initially at random positions + global scene = WGLMakie.scatter(repeat(R1*(2rand(n) .- 1), outer=K), repeat(R2*(2rand(n) .- 1),outer=K), color = fill((:white,0f0), n*K), + backgroundcolor = RGB{Float32}(0.04, 0.11, 0.22), markersize = ms1, + glowwidth = 0.005, glowcolor = :white, + resolution=(500,500), limits = limits, + ) + + # style plot + axis = scene[Axis] + axis[:grid, :linewidth] = (0.3, 0.3) + axis[:grid, :linecolor] = (RGBA{Float32}(0.5, 0.7, 1.0, 0.3),RGBA{Float32}(0.5, 0.7, 1.0, 0.3)) + axis[:names][:textsize] = (0.0,0.0) + axis[:ticks, :textcolor] = (RGBA{Float32}(0.5, 0.7, 1.0, 0.5),RGBA{Float32}(0.5, 0.7, 1.0, 0.5)) + splot = scene[end] + + # plot isoklines + WGLMakie.scatter!(scene, -R1:0.01:R1, (-R1:0.01:R1) .- (-R1:0.01:R1).^3 .+ s, color = RGBA{Float32}(0.5, 0.7, 1.0, 0.8), markersize=ms2) + kplot1 = scene[end] + WGLMakie.scatter!(scene, -R1:0.01:R1, gamma*(-R1:0.01:R1) .+ beta , color = RGBA{Float32}(0.5, 0.7, 1.0, 0.8), markersize=ms2) + kplot2 = scene[end] + + # set up threejs scene + three, canvas = WGLMakie.three_display(session, scene) + js_scene = WGLMakie.to_jsscene(three, scene) + mesh = js_scene.getObjectByName(string(objectid(splot))) + mesh1 = js_scene.getObjectByName(string(objectid(kplot1))) + mesh2 = js_scene.getObjectByName(string(objectid(kplot2))) + + # init javascript + evaljs(session, js""" + iter = 1; // iteration number + + // fetch parameters + eps = $(eps); + s = $(s); + gamma = $(gamma); + beta = $(beta); + si = $(si); + R1 = $(R1); + R2 = $(R2); + rebirth = $(rebirth); + + // update functions for isoklines + updateklinebeta = function (value){ + beta = value; + var mesh = $(mesh2); + var positions = mesh.geometry.attributes.offset.array; + for ( var i = 0, l = positions.length; i < l; i += 2 ) { + positions[i+1] = beta + positions[i]*gamma; + } + mesh.geometry.attributes.offset.needsUpdate = true; + //mesh.geometry.attributes.color.needsUpdate = true; + } + updateklinegamma = function (value){ + gamma = value; + var mesh = $(mesh2); + var positions = mesh.geometry.attributes.offset.array; + for ( var i = 0, l = positions.length; i < l; i += 2 ) { + positions[i+1] = beta + positions[i]*gamma; + } + mesh.geometry.attributes.offset.needsUpdate = true; + //mesh.geometry.attributes.color.needsUpdate = true; + } + updateklines = function (value){ + s = value; + var mesh = $(mesh1); + var positions = mesh.geometry.attributes.offset.array; + for ( var i = 0, l = positions.length; i < l; i += 2 ) { + positions[i+1] = positions[i] - positions[i]*positions[i]*positions[i] + s; + } + mesh.geometry.attributes.offset.needsUpdate = true; + //mesh.geometry.attributes.color.needsUpdate = true; + } + + // move particles every x milliseconds + setInterval( + function (){ + function randn_bm() { + var u = 0, v = 0; + while(u === 0) u = Math.random(); //Converting [0,1) to (0,1) + while(v === 0) v = Math.random(); + return Math.sqrt( -2.0 * Math.log( u ) ) * Math.cos( 2.0 * Math.PI * v ); + } + var mu = 0.2; + var mesh = $(mesh); + var K = $(K); + var n = $(n); + var dt = $(dt); + console.log(iter++); + var sqrtdt = $(sqrtdt); + k = iter%K; + var positions = mesh.geometry.attributes.offset.array; + var color = mesh.geometry.attributes.color.array; + console.log(color.length); + for ( var i = 0; i < n; i++ ) { + inew = k*2*n + 2*i; + iold = ((K + k - 1)%K)*2*n + 2*i; + positions[inew] = positions[iold] + dt/eps*((1 - positions[iold]*positions[iold])*positions[iold] - positions[iold+1] + s); // x + positions[inew+1] = positions[iold+1] + dt*(-positions[iold+1] + gamma*positions[iold] + beta) + si*sqrtdt*randn_bm(); + color[k*4*n + 4*i] = 1.0; + color[k*4*n + 4*i + 1] = 1.0; + color[k*4*n + 4*i + 2] = 1.0; + color[k*4*n + 4*i + 3] = 1.0; + if (Math.random() < rebirth) + { + positions[inew] = (2*Math.random()-1)*R1; + positions[inew+1] = (2*Math.random()-1)*R2; + } + } + for ( var k = 0; k < K; k++ ) { + for ( var i = 0; i < n; i++ ) { + color[k*4*n + 4*i + 3] = 0.98*color[k*4*n + 4*i + 3]; + } + } + mesh.geometry.attributes.color.needsUpdate = true; + mesh.geometry.attributes.offset.needsUpdate = true; + } + , 15); // milliseconds + """) + + # react on slider movements + + onjs(session, slider1.value, js"""function (value){ + eps = value; + }""") + onjs(session, slider2.value, js"""function (value){ + updateklines(value); + }""") + onjs(session, slider3.value, js"""function (value){ + updateklinegamma(value); + }""") + onjs(session, slider4.value, js"""function (value){ + updateklinebeta(value); + }""") + onjs(session, slider5.value, js"""function (value){ + si = value; + }""") + onjs(session, slider6.value, js"""function (value){ + rebirth = value; + }""") + + # set background for sliders + styles = [Observable("padding-left:0px; padding-top: 10px; height: 50px; width=150px; background-size: 115px 50px; background-repeat: no-repeat; + background-position: center center; background-image: url($(JSServe.url(sliderbg[j])));") for j in 1:5] + + # arrange canvas and sliders as html elements + dom = DOM.div(particlecss, DOM.p(canvas), DOM.p("Parameters"), DOM.table( + DOM.tr(DOM.td("eps"), DOM.td("s"), DOM.td("gamma"), DOM.td("beta"), DOM.td("sigma")), + DOM.tr( + DOM.td(DOM.div(slider1, id="slider1", style=styles[1]), DOM.div(nrs1)), + DOM.td(DOM.div(slider2, id="slider2", style=styles[2]), DOM.div(nrs2)), + DOM.td(DOM.div(slider3, id="slider3", style=styles[3]), DOM.div(nrs3)), + DOM.td(DOM.div(slider4, id="slider4", style=styles[4]), DOM.div(nrs4)), + DOM.td(DOM.div(slider5, id="slider5", style=styles[5] ), DOM.div(nrs5)) + ), + DOM.tr( + DOM.td("rebirth", DOM.div(slider6, id="slider6"), DOM.div(nrs6)), + ))) + println("running...") + dom +end + +# attach handler to current session +#JSServe.with_session() do session, request +# dom_handler(session, request) +#end + +app = JSServe.Application( + dom_handler, + get(ENV, "WEBIO_SERVER_HOST_URL", "127.0.0.1"), + parse(Int, get(ENV, "WEBIO_HTTP_PORT", "8081")), + verbose = false +) + +cl() = (close(app), "stopped") +#println("Done.") +# +cl() diff --git a/scripts/nextjournal/nextjournal_script.jl b/scripts/nextjournal/nextjournal_script.jl new file mode 100644 index 0000000..fc7cfe5 --- /dev/null +++ b/scripts/nextjournal/nextjournal_script.jl @@ -0,0 +1,101 @@ +## Intro +using Profile, Revise +using Bridge +using StaticArrays +using BridgeSDEInference +using Random, LinearAlgebra, Distributions +const State = SArray{Tuple{2},T,1,2} where {T}; +using BridgeSDEInference: EulerMaruyamaBounded +param = :regular # regular parametrization of the trajectories (no transformation of the coordinate processes) +ε = 0.1 ; s =-0.8 ; γ =1.5 ; β = 0.0 ; σ =0.3 ; +P = FitzhughDiffusion(param, ε, s, γ, β, σ); + +include(joinpath(pathof(BridgeSDEInference), "..", "auxiliary", "data_simulation_fns.jl")) +# starting point under :regular parametrisation +x0 = State(-0.5, -0.6) + +## Simulation +# time grid +dt = 1/50000 +T = 20.0 +tt = 0.0:dt:T + +Random.seed!(4) +X, _ = simulate_segment(0.0, x0, P, tt); + +# subsampling +num_obs = 100 +skip = div(length(tt), num_obs) +Σ = [10^(-4)] +L = [1.0 0.0] +obs = (time = X.tt[1:skip:end], values = [Float64(rand(MvNormal(L*x, Σ))[1]) for x in X.yy[1:skip:end]]); + +## Plotting +#=using Plots +gr() +Plots.plot(X.tt, first.(X.yy), label = "X") +Plots.plot!(X.tt, last.(X.yy), label = "Y") +Plots.scatter!(obs.time, obs.values, markersize=1.5, label = "observations") +=# +## Inference +θ_init = [ε, s, γ, β, σ].*(1 .+ (2*rand(5) .- 1).*[0.2, 0.2, 0.2, 0.0, 0.2]) + + +# Take the real β, as it is fixed. + +P_trgt = FitzhughDiffusion(param, θ_init...) +P_aux = [FitzhughDiffusionAux(param, θ_init..., t₀, u, T, v) for (t₀,T,u,v) + in zip(obs.time[1:end-1], obs.time[2:end], obs.values[1:end-1], obs.values[2:end])] + +# Container +model_setup = DiffusionSetup(P_trgt, P_aux, PartObs()) + +# Observation scheme +L = @SMatrix [1. 0.] +Σ = @SMatrix [10^(-4)] +set_observations!(model_setup, [L for _ in P_aux], [Σ for _ in P_aux], + obs.values, obs.time) + +# Imputation grid +dt = 1/200 +set_imputation_grid!(model_setup, dt) +# Prior distribution on (X_0, Y_0) +set_x0_prior!(model_setup, KnownStartingPt(x0)); +#set_x0_prior!(model_setup, GsnStartingPt(x0, @SMatrix [.1 0; 0 .1]), x0); + +initialise!(eltype(x0), model_setup, Vern7(), false, NoChangePt(100)) +# Further setting +set_auxiliary!(model_setup; skip_for_save=1, adaptive_prop=NoAdaptation()); + +mcmc_setup = MCMCSetup( + Imputation(NoBlocking(), 0.975, Vern7()), + ParamUpdate(MetropolisHastingsUpdt(), 1, θ_init, + UniformRandomWalk(0.5, true), ImproperPosPrior(), + UpdtAuxiliary(Vern7(), check_if_recompute_ODEs(P_aux, 1)) + ), + ParamUpdate(MetropolisHastingsUpdt(), 2, θ_init, + UniformRandomWalk(0.5, false), ImproperPrior(), + UpdtAuxiliary(Vern7(), check_if_recompute_ODEs(P_aux, 2)) + ), + ParamUpdate(MetropolisHastingsUpdt(), 3, θ_init, + UniformRandomWalk(0.5, true), ImproperPosPrior(), + UpdtAuxiliary(Vern7(), check_if_recompute_ODEs(P_aux, 3)) + ), + ParamUpdate(MetropolisHastingsUpdt(), 5, θ_init, + UniformRandomWalk(0.5, true), ImproperPosPrior(), + UpdtAuxiliary(Vern7(), check_if_recompute_ODEs(P_aux, 5)) + )) + +schedule = MCMCSchedule(1*10^3, [[1,2,3,4,5]], + (save=3*10^2, verbose=10^2, warm_up=100, + readjust=(x->x%100==0), fuse=(x->false))); + +Random.seed!(4) +Profile.init() +setup_mcmc = mcmc_setup +setup = model_setup +ws, ll, θ = BridgeSDEInference.create_workspace(setup) +ws_mcmc = BridgeSDEInference.create_workspace(setup_mcmc, schedule, θ) +adpt = BridgeSDEInference.adaptation_object(setup, ws) +Profile.clear() +@profile out = BridgeSDEInference.mcmc_(setup_mcmc, schedule, setup, ws, ll, θ, ws_mcmc, adpt) diff --git a/scripts/nextjournal/particle.css b/scripts/nextjournal/particle.css new file mode 100644 index 0000000..6eafc55 --- /dev/null +++ b/scripts/nextjournal/particle.css @@ -0,0 +1,30 @@ +#slider1 { + background-color: #778899; + width: 4cm; + background-image: linear-gradient(to right, #8888FF, #FFFF88, #8888FF); +} + +#slider2 { + background-color: #778899; + width: 4cm; + background-image: linear-gradient(to right, #8888FF, #FFFF88, #8888FF); +} + + +#slider3 { + background-color: #778899; + width: 4cm; + background-image: linear-gradient(to right, #8888FF, #FFFF88, #8888FF); +} + +#slider4 { + background-color: #778899; + width: 4cm; + background-image: linear-gradient(to right, #8888FF, #FFFF88, #8888FF); +} + +#slider5 { + background-color: #778899; + width: 4cm; + background-image: linear-gradient(to right, #8888FF, #FFFF88, #8888FF); +} diff --git a/scripts/nextjournal/slider1.png b/scripts/nextjournal/slider1.png new file mode 100644 index 0000000..cf625f3 Binary files /dev/null and b/scripts/nextjournal/slider1.png differ diff --git a/src/BridgeSDEInference.jl b/src/BridgeSDEInference.jl index f0bf8d0..c171080 100644 --- a/src/BridgeSDEInference.jl +++ b/src/BridgeSDEInference.jl @@ -5,6 +5,10 @@ using Statistics, Random, LinearAlgebra using ForwardDiff using ForwardDiff: value +#sir.jl +export SIR, SIRAux + + # fitzHughNagumo.jl export FitzhughDiffusion, FitzhughDiffusionAux, ℝ export regularToAlter, alterToRegular, regularToConjug, conjugToRegular, display @@ -103,6 +107,7 @@ include(joinpath(_DIR, "lorenz_system_const_vola.jl")) include(joinpath(_DIR, "prokaryotic_autoregulatory_gene_network.jl")) include(joinpath(_DIR, "Jansen_and_Rit_simple.jl")) include(joinpath(_DIR, "lotka_volterra.jl")) +include(joinpath(_DIR, "sir.jl")) _DIR = "mcmc" include(joinpath(_DIR, "priors.jl")) diff --git a/src/BridgeSDEInference_for_tests.jl b/src/BridgeSDEInference_for_tests.jl index 91e9f74..2e31662 100644 --- a/src/BridgeSDEInference_for_tests.jl +++ b/src/BridgeSDEInference_for_tests.jl @@ -36,6 +36,8 @@ include(joinpath(_DIR, "lorenz_system.jl")) include(joinpath(_DIR, "lorenz_system_const_vola.jl")) include(joinpath(_DIR, "prokaryotic_autoregulatory_gene_network.jl")) include(joinpath(_DIR, "lotka_volterra.jl")) +include(joinpath(_DIR, "sir.jl")) + _DIR = "mcmc" include(joinpath(_DIR, "priors.jl")) diff --git a/src/examples/sir.jl b/src/examples/sir.jl new file mode 100644 index 0000000..926483c --- /dev/null +++ b/src/examples/sir.jl @@ -0,0 +1,104 @@ +using Bridge +using StaticArrays +import Bridge: b, σ, B, β, a, constdiff +const ℝ = SVector{N,T} where {N,T} +import Base.display +sq(x) = sqrt(max(x, 2e-10)) +struct SIR{T} <: ContinuousTimeProcess{SVector{2,T}} + α::T + β::T + σ1::T + σ2::T +end + +b(t, u, P::SIR) = @SVector [P.α*(1 - u[1] - u[2])*u[1] - P.β*u[1], P.β*u[1]] +σ(t, u, P::SIR) = @SMatrix Float64[ + -P.σ1*sq((1 - u[1] - u[2])*u[1]) -P.σ2*sq(u[1]) + 0.0 P.σ2*sq.(u[1]) + ] +a(t, u, P::SIR) = σ(t, u, P)*σ(t, u, P)' +constdiff(::SIR) = false +clone(P::SIR, θ) = SIR(θ...) +params(P::SIR) = [P.α, P.β, P.σ1, P.σ2] +#domain(P::SIR) = LowerBoundedDomain((0.0, 0.0), (1,2)) +domain(P::SIR) = LowerBoundedDomain((0.0, 0.0), (1, 2)) + + +# <--------------------------------------------- +# this is optional, needed for conjugate updates +phi(::Val{0}, t, u, P::SIR) = (zero(u[1]), zero(u[2])) +#[P.α*(1 - u[1] - u[2])*u[1] - P.β*u[1], P.β*u[1]] +phi(::Val{1}, t, u, P::SIR) = ((1 - u[1] - u[2])*u[1], zero(u[2])) +phi(::Val{2}, t, u, P::SIR) = (-u[1], u[1]) +phi(::Val{3}, t, x, P::SIR) = (zero(x[1]), zero(x[2])) +phi(::Val{4}, t, x, P::SIR) = (zero(x[1]), zero(x[2])) +phi(::Val{5}, t, x, P::SIR) = (zero(x[1]), zero(x[2])) + + + +nonhypo(P::SIR, x) = x +@inline hypo_a_inv(P::SIR, t, x) = inv(a(t, x, P)) +num_non_hypo(P::Type{<:SIR}) = 2 + + +struct SIRAux{T,S1,S2} <: ContinuousTimeProcess{SVector{2,T}} + α::T + β::T + σ1::T + σ2::T + t::Float64 + u::S1 + T::Float64 + v::S2 +end + + +# function B(t, P::SIRAux) +# # b(t, u, P::SIR) = @SVector [P.α*(P.k - u[1] - u[2])*u[1] - P.β*u[1], P.β*u[1]] +# @SMatrix [(P.α) -(P.β); +# (P.β) -0.0] +# end +function B(t, P::SIRAux) +# b(t, u, P::SIR) = @SVector [P.α*(1 - u[1] - u[2])*u[1] - P.β*u[1], P.β*u[1]] + @SMatrix [(P.α*(1 - P.v[1] - P.v[2]) - (P.β)) 0.0; + (P.β) 0.0] +end + + + + + +# mean = ℝ{2}(P.γ/P.δ, P.α/P.β) +function β(t, P::SIRAux) + ℝ{2}(0.0, 0.0) +end + +# function σ(t, P::SIRAux) +# sq(0.0001)*@SMatrix Float64[ +# -P.σ1 -P.σ2 +# 0.0 P.σ2 +# ] +# end + +function σ(t, P::SIRAux) + @SMatrix Float64[ + (-P.σ1*(1 - P.v[2] - P.v[1])*P.v[1]) -P.σ2*P.v[1]; + 0.0 P.σ2*P.v[1]] +end + + + + +σ(t, x, P::SIRAux) = σ(t, P) + +depends_on_params(::SIRAux) = (3, 4, 5) + +constdiff(::SIRAux) = true +b(t, x, P::SIRAux) = B(t,P) * x + β(t,P) +a(t, P::SIRAux) = σ(t,P) * σ(t, P)' + + +clone(P::SIRAux, θ) = SIRAux(θ..., P.t, P.u, P.T, P.v) + +clone(P::SIRAux, θ, v) = SIRAux(θ..., P.t, v, P.T, v) +params(P::SIRAux) = [P.α, P.β, P.σ1, P.σ2] diff --git a/src/general/readjustments.jl b/src/general/readjustments.jl index e53bf42..5848535 100644 --- a/src/general/readjustments.jl +++ b/src/general/readjustments.jl @@ -6,8 +6,14 @@ function named_readjust(p) (step=p[1], scale=p[2], min=p[3], max=p[4], trgt=p[5], offset=p[6]) end +""" +δ decreases roughly proportional to scale/sqrt(iteration) +""" compute_δ(p, mcmc_iter) = p.scale/sqrt(max(1.0, mcmc_iter/p.step-p.offset)) +""" +ϵ is moved by δ to adapt to target acceptance rate +""" function compute_ϵ(ϵ_old, p, a_r, δ, flip=1.0, f=identity, finv=identity) ϵ = finv(f(ϵ_old) + flip*(2*(a_r > p.trgt)-1)*δ) ϵ = max(min(ϵ, p.max), p.min) # trim excessive updates diff --git a/src/mcmc/mcmc.jl b/src/mcmc/mcmc.jl index 6775cc7..0507a2f 100644 --- a/src/mcmc/mcmc.jl +++ b/src/mcmc/mcmc.jl @@ -13,7 +13,9 @@ function mcmc(setup_mcmc::MCMCSetup, schedule::MCMCSchedule, setup::T) where T < ws, ll, θ = create_workspace(setup) ws_mcmc = create_workspace(setup_mcmc, schedule, θ) adpt = adaptation_object(setup, ws) - + mcmc_(setup_mcmc, schedule, setup, ws, ll, θ, ws_mcmc, adpt) +end +function mcmc_(setup_mcmc, schedule, setup, ws, ll, θ, ws_mcmc, adpt) aux = nothing for step in schedule step.save && save_imputed!(ws) diff --git a/src/mcmc/updates.jl b/src/mcmc/updates.jl index 6cb1079..cd69813 100644 --- a/src/mcmc/updates.jl +++ b/src/mcmc/updates.jl @@ -138,7 +138,13 @@ end Preconditioned Crank-Nicolson update with memory parameter `ρ`, previous vector `y` and new vector `yᵒ` """ -crank_nicolson!(yᵒ, y, ρ) = (yᵒ .= √(1-ρ)*yᵒ + √(ρ)*y) +function crank_nicolson!(yᵒ, y, ρ) + r1 = √(1-ρ) + r2 = √(ρ) + for i in eachindex(y) + yᵒ[i] = r1*yᵒ[i] + r2*y[i] + end +end """ path_log_likhd(::OS, XX, P, iRange, fpt; skipFPT=false diff --git a/src/mcmc/workspace.jl b/src/mcmc/workspace.jl index 85391ff..920679a 100644 --- a/src/mcmc/workspace.jl +++ b/src/mcmc/workspace.jl @@ -151,7 +151,7 @@ set!(x::SingleElem{T}, y::T) where T = (x.val = y) The main container of the `mcmc` function from `mcmc.jl` in which most data pertinent to sampling is stored """ -struct Workspace{ObsScheme,S,TX,TW,R,TP,TZ}# ,Q, where Q = eltype(result) +struct Workspace{ObsScheme,S,TX,TW,R,TP,TL,TZ}# ,Q, where Q = eltype(result) # Related to imputed path Wnr::Wiener{S} # Wiener, driving law XXᵒ::Vector{TX} # Diffusion proposal paths @@ -167,6 +167,7 @@ struct Workspace{ObsScheme,S,TX,TW,R,TP,TZ}# ,Q, where Q = eltype(result) time::Vector{Float64} # Storage with time axis # Related to the starting point x0_prior::TP + ll::TL z::SingleElem{TZ} #recompute_ODEs::Vector{Bool} # Info on whether to recompute H,Hν,c after resp. param updt @@ -197,10 +198,12 @@ struct Workspace{ObsScheme,S,TX,TW,R,TP,TZ}# ,Q, where Q = eltype(result) end y = XX[i].yy[end] end + y = x0_guess ll = ( logpdf(x0_prior, y) + path_log_likhd(ObsScheme(), XX, P, 1:m, fpt, skipFPT=true) + lobslikelihood(P[1], y) ) + TL = typeof(ll) XXᵒ, WWᵒ, Pᵒ = deepcopy(XX), deepcopy(WW), deepcopy(P) @@ -213,17 +216,16 @@ struct Workspace{ObsScheme,S,TX,TW,R,TP,TZ}# ,Q, where Q = eltype(result) _time = collect(Iterators.flatten(p.tt[1:skip:end-1] for p in P)) #check_if_recompute_ODEs(setup) - (workspace = new{ObsScheme,S,TX,TW,R,TP,TZ}(Wnr, XXᵒ, XX, WWᵒ, WW, Pᵒ, + new{ObsScheme,S,TX,TW,R,TP,TL,TZ}(Wnr, XXᵒ, XX, WWᵒ, WW, Pᵒ, P, fpt, skip, [], _time, - x0_prior, z), - ll = ll, θ = params(P[1].Target)) + x0_prior, ll, z) end function Workspace(ws::Workspace{ObsScheme,S,TX,TW,R̃,TP,TZ}, P::Vector{R}, Pᵒ::Vector{R}) where {ObsScheme,S,TX,TW,R̃,R,TP,TZ} - new{ObsScheme,S,TX,TW,R,TP,TZ}(ws.Wnr, ws.XXᵒ, ws.XX, ws.WWᵒ, ws.WW, Pᵒ, + new{ObsScheme,S,TX,TW,R,TP,TL,TZ}(ws.Wnr, ws.XXᵒ, ws.XX, ws.WWᵒ, ws.WW, Pᵒ, P, ws.fpt, ws.skip_for_save, ws.paths, - ws.time, ws.x0_prior, ws.z) + ws.time, ws.x0_prior, ws.ll, ws.z) end end @@ -274,5 +276,7 @@ function create_workspace(setup::MCMCSetup, schedule::MCMCSchedule, θ) end function create_workspace(setup::T) where {T <: ModelSetup} - Workspace(setup) + ws = Workspace(setup) + θ = params(ws.P[1].Target) + ws, ws.ll, θ end