Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
389531a
SIR model
mschauer Jan 30, 2020
75af874
Tuning
mschauer Jan 30, 2020
676a21a
Fixes SIR
mschauer Jan 31, 2020
c1d01a5
Experiment
mschauer Feb 1, 2020
93fa09b
js particle animation
mschauer Mar 11, 2020
aef9ae9
Restart particles
mschauer Mar 11, 2020
4f1fcfa
interface with FitzHugh-Nagumo dynamics
SebaGraz Mar 11, 2020
5c91968
Fix sir diffusion and auxiliary diffusion
SebaGraz Mar 11, 2020
15c36a0
smoothing
SebaGraz Mar 11, 2020
d9cd9c3
Run automatically
mschauer Mar 11, 2020
36fac3a
try multiple sliders
SebaGraz Mar 11, 2020
04d0bff
Merge branch 'sir' of https://github.com/mmider/BridgeSDEInference.jl…
SebaGraz Mar 11, 2020
5452f6a
update
SebaGraz Mar 11, 2020
b901196
Experiment with kline
mschauer Mar 12, 2020
d540cfe
Fixes. Parameter estimation not working yet
SebaGraz Mar 17, 2020
97b160f
Fixes. Parameter estimation not working yet
SebaGraz Mar 17, 2020
25f9820
Merge branch 'sir' of https://github.com/mmider/BridgeSDEInference.jl…
SebaGraz Mar 17, 2020
edb7743
add sliders for each parameter and automatic update of the particles
SebaGraz Mar 20, 2020
c25d3d4
add visible nullclines and small fixes
SebaGraz Mar 21, 2020
090737a
updatekline not working
SebaGraz Mar 21, 2020
8cbef36
samll fix, still nullkline does not move
SebaGraz Mar 21, 2020
8e30990
interface fully functioning
SebaGraz Mar 21, 2020
d29254b
Small fixes
mschauer Mar 21, 2020
bc20498
Bisschen huebscher
mschauer Mar 22, 2020
ab6f0fa
small changes
SebaGraz Mar 22, 2020
91e064d
Serve a background picture for the sliders
mschauer Mar 23, 2020
4aef50d
Current nextjournal code
mschauer Mar 25, 2020
6da3c8f
Script to profile
SebaGraz Mar 25, 2020
5e01837
Merge branch 'sir' of https://github.com/mmider/BridgeSDEInference.jl…
SebaGraz Mar 25, 2020
ea77f7a
Fixes
mschauer Mar 26, 2020
dfab57f
Fix some type inference issues
mschauer Mar 26, 2020
1c9f31c
Profile inner loop
mschauer Mar 26, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions scripts/inference/sir.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
SRC_DIR = joinpath(Base.source_dir(), "..", "..", "src")
OUT_DIR = joinpath(Base.source_dir(), "..", "..", "output")
mkpath(OUT_DIR)
using Makie: lines, lines!

include(joinpath(SRC_DIR, "BridgeSDEInference_for_tests.jl"))

using StaticArrays
using Distributions
using Random
# Let's generate the data
# -----------------------
using Bridge
#import BridgeSDEInference: CIR, CIRaux
DIR = "auxiliary"
include(joinpath(SRC_DIR, DIR, "data_simulation_fns.jl"))
include(joinpath(SRC_DIR, DIR, "utility_functions.jl"))
Random.seed!(2)
pop = 50_000_000
#θˣ = [0.37, 0.05, 0.05, 0.01]
#α, β, σ1, σ2
θˣ = [0.37, 0.05, 0.3/sqrt(pop), 1.0/sqrt(pop)]

Pˣ = SIR(θˣ...)

x0, dt, T = ℝ{2}(1/pop, 0.), 1/10000, 10.0
tt = 0.0:dt:T

Random.seed!(1)
XX, _ = simulate_segment(ℝ{2}(1.0, 0.0), x0, Pˣ, tt)
last(XX)[2]*pop

#lines(XX.tt, K .- sum.(XX.yy))
lines(XX.tt, first.(XX.yy), color = :red) #infected
lines!(XX.tt,last.(XX.yy), color = :blue) #recovered + dead people

θ_init = [0.5, 0.1, 0.3/sqrt(pop), 1.0/sqrt(pop)]
Pˣ = SIR(θ_init...)

length(XX.tt)
skip = 2000

#Σdiagel =
Σ = @SMatrix[1.0 0.0; 0.0 0.5]/pop
L = @SMatrix[1.0 0.0; 0.0 1.0]

obs_time, obs_vals = XX.tt[1:skip:end], [rand(Gaussian(L*x, Σ)) for x in XX.yy[1:skip:end]]

days = [0.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0]
cases = [1.0 0.0; 62.0 0.0; 121.0 0.0; 198.0 3.0; 291.0 6.0; 440.0 9.0; 571.0 17.0; 830.0 25.0; 1287.0 41.0; 1975.0 56.0; 2744.0 80.0; 4515.0 106.0; 5974.0 132.0; 7771.0 170.0]

if true
obs_time = days
obs_vals = 1/pop*reinterpret(SVector{2,Float64}, [1 -1; 0 1]*cases') # first coordinate is I + R
end
obs_vals[end]
P̃ = [SIRAux(θ_init..., t₀, u, T, v) for (t₀, T, u, v)
in zip(obs_time[1:end-1], obs_time[2:end], obs_vals[1:end-1], obs_vals[2:end])]

model_setup = DiffusionSetup(Pˣ, P̃, PartObs())
set_observations!(model_setup, [L for _ in P̃], [Σ for _ in P̃], obs_vals, obs_time) # uses default fpt
set_imputation_grid!(model_setup, 1/1000)
set_x0_prior!(model_setup,
GsnStartingPt(x0, @SMatrix [0.01 0.0;
0.0 0.01;]),
x0)
set_auxiliary!(model_setup; skip_for_save=10^0,
adaptive_prop=NoAdaptation())
initialise!(eltype(x0), model_setup, Vern7(), false, NoChangePt(100))
#:step, :scale, :min, :max, :trgt, :offset
readj = (100, 0.001, 0.001, 999.9, 0.4, 50)
readj2 = (100, 0.01, 0.05, 0.2, 0.7, 50)


mcmc_setup = MCMCSetup(
Imputation(NoBlocking(), 0.99, Vern7()),
ParamUpdate(ConjugateUpdt(), [1,2], θ_init, nothing,
MvNormal(fill(0.05, 2), diagm(0=>fill(1000.0, 2))),
UpdtAuxiliary(Vern7(), check_if_recompute_ODEs(P̃, [1,2]))),

)

schedule = MCMCSchedule(10^4, [[1],[2]], #[[1],[2], [5]],
(save=10^2, verbose=10^2, warm_up=100,
readjust=(x->x%100==0), fuse=(x->false)))

Random.seed!(4)
out = mcmc(mcmc_setup, schedule, model_setup)
error("STOP HERE")


using Makie
ws = out[2]


θs = ws.θ_chain
±(a, b) = a - b, a + b
beta, gamma, s1 = [median(getindex.(θs, i)) for i in [1,2,3]] .± [std(getindex.(θs, i)) for i in [1,2, 3]]
R0 = mean(getindex.(θs, 1)./getindex.(θs, 2)) ± std(getindex.(θs, 1)./getindex.(θs, 2))

θs = ws.θ_chain



@show beta
@show gamma
@show s1
@show R0

lines(getindex.(θs, 1))
lines(getindex.(θs, 2))



lines(getindex.(θs, 3))

include(joinpath(SRC_DIR, DIR, "plotting_fns.jl"))
plot_chains(ws; truth=θˣ)
#=
plot_paths(out[1], out[2], schedule; obs=(times=obs_time[2:end],
vals=[[v[1] for v in obs_vals[2:end]],
[v[2] for v in obs_vals[2:end]]], indices=[2,3]))
=#
165 changes: 165 additions & 0 deletions scripts/nextjournal/domserv.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
using JSServe, WGLMakie, AbstractPlotting
using JSServe: JSServe.DOM, @js_str, onjs
global three, scene

using StaticArrays
using Colors
using Random
rebirth(α, R) = x -> (rand() > α ? x : (2rand(typeof(x)) .- 1).*R)
const 𝕏 = SVector

using FileIO
using Makie: band
using GLMakie

using Hyperscript, Markdown
using JSServe, Observables
using JSServe: Application, Session, evaljs, linkjs, div, active_sessions, Asset
using JSServe: @js_str, onjs, Button, TextField, Slider, JSString, Dependency, with_session
using JSServe.DOM

function dom_handler(session, request)
global three, scene

if @isdefined dom
dom != nothing && return dom
end

WGLMakie.activate!()


# slider and field for sigma
sliders = JSServe.Slider(0.01:0.01:1)
nrs = JSServe.NumberInput(0.0)
linkjs(session, sliders.value, nrs.value)

# time wheel ;-)
button = JSServe.Slider(1:109)

# init
R = 𝕏(1.5,6.0)
R1, R2 = R
limits = FRect(-R[1], -R[2], 2R[1], 2R[2])
n = 800
K = 80
dt = 0.001
sqrtdt = sqrt(dt)

particlecss = Asset(joinpath(@__DIR__,"particle.css"))
global sliderbg = Asset(joinpath(@__DIR__,"slider1.png"))

ms = 0.03
global scene = scatter(repeat(2randn(n), outer=K), repeat(2randn(n),outer=K), color = fill(:white, n*K),
backgroundcolor = RGB{Float32}(0.04, 0.11, 0.22), markersize = ms,
glowwidth = 0.005, glowcolor = :white,
resolution=(600,600), limits = limits,
)
axis = scene[Axis]
axis[:grid, :linewidth] = (0.3, 0.3)
axis[:grid, :linecolor] = (RGBA{Float32}(0.5, 0.7, 1.0, 0.3),RGBA{Float32}(0.5, 0.7, 1.0, 0.3))
axis[:names][:textsize] = (0.0,0.0)
axis[:ticks, :textcolor] = (RGBA{Float32}(0.5, 0.7, 1.0, 0.5),RGBA{Float32}(0.5, 0.7, 1.0, 0.5))


splot = scene[end]
scatter!(scene, -R1:0.01:R1, sin.(-R1:0.01:R1), color = RGBA{Float32}(255, 0.0, 4.0, 1.0), markersize=ms)
kplot = scene[end]

three, canvas = WGLMakie.three_display(session, scene)
js_scene = WGLMakie.to_jsscene(three, scene)
mesh = js_scene.getObjectByName(string(objectid(splot)))
mesh2 = js_scene.getObjectByName(string(objectid(kplot)))

# init javascript
evaljs(session, js"""
console.log("Hello");
iter = 1;
si = 0.0;
R1 = $(R1);
R2 = $(R2);
updatekline = function (value){
si = value;
var mesh = $(mesh2);
var positions = mesh.geometry.attributes.offset.array;
for ( var i = 0, l = positions.length; i < l; i += 2 ) {
positions[i+1] = si*Math.sin(positions[i]);
}
mesh.geometry.attributes.offset.needsUpdate = true;
}
setInterval(
function (){
function randn_bm() {
var u = 0, v = 0;
while(u === 0) u = Math.random(); //Converting [0,1) to (0,1)
while(v === 0) v = Math.random();
return Math.sqrt( -2.0 * Math.log( u ) ) * Math.cos( 2.0 * Math.PI * v );
}
var mu = 0.2;
var mesh = $(mesh);
var K = $(K);
var n = $(n);
var dt = $(dt);
console.log(iter++);
var sqrtdt = $(sqrtdt);

k = iter%K;
var positions = mesh.geometry.attributes.offset.array;
var color = mesh.geometry.attributes.color.array;
console.log(color.length);
for ( var i = 0; i < n; i++ ) {
inew = k*2*n + 2*i;
iold = ((K + k - 1)%K)*2*n + 2*i;
positions[inew] = (1 - mu*dt)*positions[iold] - 3*dt*positions[iold+1] + si*sqrtdt*randn_bm(); // x
positions[inew+1] = (1 - mu*dt)*positions[iold+1] + 3*dt*positions[iold] + si*sqrtdt*randn_bm();
color[k*4*n + 4*i] = 1.0;
color[k*4*n + 4*i + 1] = 1.0;
color[k*4*n + 4*i + 2] = 1.0;
color[k*4*n + 4*i + 3] = 1.0;
if (Math.random() < 0.01)
{
positions[inew] = (2*Math.random()-1)*R1;
positions[inew+1] = (2*Math.random()-1)*R2;
}

}
for ( var k = 0; k < K; k++ ) {
for ( var i = 0; i < n; i++ ) {
color[k*4*n + 4*i + 3] = 0.98*color[k*4*n + 4*i + 3];
}
}
mesh.geometry.attributes.color.needsUpdate = true;
mesh.geometry.attributes.offset.needsUpdate = true;

}
, 50);
""")
onjs(session, sliders.value, js"""function (value){
updatekline(value);
}""")
sliderbgurl = JSServe.url(sliderbg)
style_obs = Observable(" background-repeat: no-repeat; background-image: url($sliderbgurl);")

global dom = DOM.div(particlecss, DOM.p(canvas), DOM.p("Parameters"), DOM.div(sliders, id="slider", style=style_obs),
DOM.p(nrs))
println("running...")
dom
end


app = JSServe.Application(
dom_handler,
get(ENV, "WEBIO_SERVER_HOST_URL", "127.0.0.1"),
parse(Int, get(ENV, "WEBIO_HTTP_PORT", "8081")),
verbose = false
)
function cl()
close(app)
global dom = nothing
end
println("Done.")
#
if false

cl()

end
Loading