Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 45 additions & 16 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ git-tree-sha1 = "59e36cd1d343605aec81d7ef1d59ed98b9a22949"
repo-rev = "master"
repo-url = "https://github.com/mschauer/Bridge.jl.git"
uuid = "2d3116d5-4b8f-5680-861c-71f149790274"
version = "0.9.0+"
version = "0.10.0+"

[[CategoricalArrays]]
deps = ["Compat", "DataAPI", "Future", "JSON", "Missings", "Printf", "Reexport"]
git-tree-sha1 = "13240cfcc884837fc1aa89b60d500a652bcc3f10"
deps = ["Compat", "DataAPI", "Future", "JSON", "Missings", "Printf", "Reexport", "Unicode"]
git-tree-sha1 = "5f4400b24adb1fbed17a9dcc1e8ab8aaf5b03d1f"
uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597"
version = "0.5.5"
version = "0.6.0"

[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
Expand All @@ -42,10 +42,16 @@ uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.8.0"

[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport"]
git-tree-sha1 = "c9c1845d6bf22e34738bee65c357a69f416ed5d1"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.9.5"
version = "0.9.6"

[[CommonSubexpressions]]
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.2.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
Expand Down Expand Up @@ -83,6 +89,18 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DiffResults]]
deps = ["Compat", "StaticArrays"]
git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "0.0.4"

[[DiffRules]]
deps = ["Random", "Test"]
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.10"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -98,6 +116,12 @@ git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.6.1"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.3"

[[Future]]
deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
Expand Down Expand Up @@ -147,14 +171,19 @@ deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[Missings]]
deps = ["SparseArrays", "Test"]
git-tree-sha1 = "f0719736664b4358aa9ec173077d4285775f8007"
git-tree-sha1 = "29858ce6c8ae629cf2d733bffa329619a1c843d0"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.1"
version = "0.4.2"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[NaNMath]]
deps = ["Compat"]
git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"

[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
Expand All @@ -163,9 +192,9 @@ version = "1.1.0"

[[PDMats]]
deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"]
git-tree-sha1 = "f99548922adf8dd5df2f02ab0063944201a12ed8"
git-tree-sha1 = "035f8d60ba2a22cb1d2580b1e0e5ce0cb05e4563"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.9.8"
version = "0.9.10"

[[Parameters]]
deps = ["OrderedCollections"]
Expand All @@ -175,9 +204,9 @@ version = "0.11.0"

[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "db2b35dedab3c0e46dc15996d170af07a5ab91c9"
git-tree-sha1 = "ef0af6c8601db18c282d092ccbd2f01f3f0cd70b"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "0.3.6"
version = "0.3.7"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
Expand Down Expand Up @@ -292,9 +321,9 @@ version = "1.0.0"

[[Tables]]
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"]
git-tree-sha1 = "951b5be359e92703f886881b175ecfe924d8bd91"
git-tree-sha1 = "aaed7b3b00248ff6a794375ad6adf30f30ca5591"
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
version = "0.2.10"
version = "0.2.11"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
Expand Down
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
julia = "1"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -20,6 +26,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Statistics", "Random", "LinearAlgebra"]

[compat]
julia = "1"
140 changes: 140 additions & 0 deletions scripts/mixedeffects.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
OUT_DIR = joinpath(Base.source_dir(), "..", "output")
mkpath(OUT_DIR)

#include(joinpath(SRC_DIR, "BridgeSDEInference.jl"))
using BridgeSDEInference

const BSI = BridgeSDEInference
using DataFrames, DelimitedFiles, CSV
using Test
using Makie
using Bridge, BridgeSDEInference, StaticArrays, Distributions
using Statistics, Random, LinearAlgebra

#include(joinpath(AUX_DIR, "read_and_write_data.jl"))
include(joinpath("..","src","auxiliary","read_and_write_data.jl"))
include(joinpath("..","src","auxiliary","transforms.jl"))
const 𝕏 = SVector
# decide if first passage time observations or partially observed diffusion
fptObsFlag = false

# pick dataset
using DelimitedFiles
using Makie
#data = readdlm("../LinneasData190920.csv", ';')
#
#data[isnan.(data)] .= circshift(data, (-1,0))[isnan.(data)]
#data[isnan.(data)] .= circshift(data, (1,0))[isnan.(data)]
#data[isnan.(data)] .= circshift(data, (-2,0))[isnan.(data)]
#data[isnan.(data)] .= circshift(data, (2,0))[isnan.(data)]
#data[isnan.(data)] .= circshift(data, (3,0))[isnan.(data)]
#any(isnan.(data))
# t = 0:30:N
#data = cumsum(0.1rand(200,100), dims=1) # Mock data

#N, K = size(data)

#x0 = [𝕏(x, 0.0) for x in data[:, 1]]
#obs = map(𝕏, data)
#obsTime = hcat([range(0, 1, length=N) for k in 1:K]...)




include("simulate_mix_part_obs.jl")
K = length(XX)
N = getunique(length.(XX))
obs = [map(x->x[1:1], XX[k].yy) for k in 1:K]
obsTime = [XX[k].tt for k in 1:K]

fpt = fill(NaN, size(data)) # really needed?
fptOrPartObs = PartObs()


param = :complexConjug
# Initial parameter guess.
θ₀ = (10.0, -8.0, 15.0, 0.0, 3.0)
randomEffects = (false, false, false, false, true)

# Target law
P˟ = [FitzhughDiffusion(param, θ₀...) for i in 1:K]

P̃ = map(1:K) do i
map(1:N-1) do
i, k = I[1], I[2]
t₀, T, u, v = obsTime[i], obsTime[i+1], obs[k][i], obs[k][i+1]
FitzhughDiffusionAux(param, θ₀..., t₀, u[1], T, v[1])
end
display(P̃[1,1])
𝕂 = Float64
L = @SMatrix [1. 0.]
Σdiagel = 1e-1
Σ = @SMatrix [Σdiagel]

Ls = [L for _ in P̃]
Σs = [Σ for _ in P̃]
τ(t₀,T) = (x) -> t₀ + (x-t₀) * (2-(x-t₀)/(T-t₀))
numSteps=1*10^5
saveIter=3*10^2
tKernel = RandomWalk([3.0, 5.0, 5.0, 0.01, 0.5],
[false, false, false, false, true])
priors = Priors((MvNormal([0.0,0.0,0.0], diagm(0=>[1000.0, 1000.0, 1000.0])),
#ImproperPrior(),
ImproperPrior()))
𝔅 = NoBlocking()
blockingParams = ([], 0.1, NoChangePt())
changePt = NoChangePt()
#x0Pr = KnownStartingPt(x0)
x0Pr = [GsnStartingPt(x, x, @SMatrix [20. 0; 0 20.]) for x in x0]
warmUp = 100

Random.seed!(4)
start = time()
(chain, accRateImp, accRateUpdt,
paths, time_) = BSI.mixedmcmc(𝕂, fptOrPartObs, obs, obsTime, x0Pr, 0.0, P˟,
P̃, Ls, Σs, numSteps, tKernel, priors, τ;
fpt=fpt,
ρ=0.975,
dt=1/1000,
saveIter=saveIter,
verbIter=10^2,
updtCoord=(Val((true, true, true, false, false)),
#Val((true, false, false, false, false)),
Val((false, false, false, false, true)),
),
randomEffects=randomEffects,
paramUpdt=true,
updtType=(ConjugateUpdt(),
#MetropolisHastingsUpdt(),
MixedEffectsMHUpdt(),
),
skipForSave=10^0,
blocking=𝔅,
blockingParams=blockingParams,
solver=Vern7(),
changePt=changePt,
warmUp=warmUp)
elapsed = time() - start
print("time elapsed: ", elapsed, "\n")

print("imputation acceptance rate: ", accRateImp,
", parameter update acceptance rate: ", accRateUpdt)

x0⁺, pathsToSave = transformMCMCOutput(x0, paths, saveIter; chain=chain,
numGibbsSteps=2,
parametrisation=param,
warmUp=warmUp)


df2 = savePathsToFile(pathsToSave, time_, joinpath(OUT_DIR, "sampled_paths.csv"))
df3 = saveChainToFile(chain, joinpath(OUT_DIR, "chain.csv"))

include(joinpath(AUX_DIR, "plotting_fns.jl"))
set_default_plot_size(30cm, 20cm)
plotPaths(df2, obs=[Float64.(df.x1), [x0⁺[2]]],
obsTime=[Float64.(df.time), [0.0]], obsCoords=[1,2])

plotChain(df3, coords=[1])
plotChain(df3, coords=[2])
plotChain(df3, coords=[3])
plotChain(df3, coords=[5])
3 changes: 2 additions & 1 deletion src/BridgeSDEInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export ImproperPrior, NoChangePt, SimpleChangePt

# mcmc.jl
export mcmc, PartObs, FPT, FPTInfo, ConjugateUpdt, MetropolisHastingsUpdt
export mixedmcmc, MixedEffectsMHUpdt

# ODE solvers:
export Ralston3, RK4, Tsit5, Vern7
Expand Down Expand Up @@ -48,5 +49,5 @@ include("blocking_schedule.jl")
include("starting_pt.jl")
include("mcmc.jl")
include("path_to_wiener.jl")

include("mixedeffects.jl")
end
3 changes: 2 additions & 1 deletion src/fitzHughNagumo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ end
constdiff(::FitzhughDiffusion) = true
clone(P::FitzhughDiffusion, θ) = FitzhughDiffusion(P.param, θ...)
params(P::FitzhughDiffusion) = [P.ϵ, P.s, P.γ, P.β, P.σ]

paramnames(::FitzhughDiffusion) = [:ϵ, :s, :γ, :β, :σ]

"""
struct FitzhughDiffusionAux <: ContinuousTimeProcess{ℝ{2}}
Expand Down Expand Up @@ -305,6 +305,7 @@ clone(P::FitzhughDiffusionAux, θ) = FitzhughDiffusionAux(P.param, θ..., P.t,
clone(P::FitzhughDiffusionAux, θ, v) = FitzhughDiffusionAux(P.param, θ..., P.t,
zero(v), P.T, v)
params(P::FitzhughDiffusionAux) = [P.ϵ, P.s, P.γ, P.β, P.σ]
paramnames(::FitzhughDiffusionAux) = [:ϵ, :s, :γ, :β, :σ]


"""
Expand Down
20 changes: 14 additions & 6 deletions src/mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ Flag for performing update according to Metropolis Hastings step
"""
struct MetropolisHastingsUpdt <: ParamUpdateType end

"""
MixedEffectsMHUpdt <: ParamUpdateType

Flag for performing update according to Metropolis Hastings step
for a mixed effect parameter.
"""
struct MixedEffectsMHUpdt <: ParamUpdateType end


"""
setBlocking(𝔅::NoBlocking, ::Any, ::Any, ::Any, ::Any)
Expand Down Expand Up @@ -204,7 +212,7 @@ end

Initialise the workspace for MCMC algorithm. Initialises containers for driving
Wiener processes `WWᵒ` & `WW`, for diffusion processes `XXᵒ` & `XX`, for
diffusion Law `Pᵒ` (parametetrised by proposal parameters) and defines the type
diffusion Law `Pᵒ` (parametrised by proposal parameters) and defines the type
of Wiener process `Wnr`.
"""
function initialise(::ObsScheme, P, m, yPr::StartingPtPrior{T}, ::S,
Expand Down Expand Up @@ -481,7 +489,7 @@ Imputation step of the MCMC scheme (without blocking).
- `XXᵒ`: containers for proposal diffusion paths
- `XX`: containers with old diffusion paths
- `P`: laws of the diffusion path (proposal and target)
- `11`: log-likelihood of the old (previously accepted) diffusion path
- `ll`: log-likelihood of the old (previously accepted) diffusion path
- `fpt`: info about first-passage time conditioning
- `ρ`: memory parameter for the Crank-Nicolson scheme
- `verbose`: whether to print updates info while sampling
Expand Down Expand Up @@ -593,7 +601,7 @@ Imputation step of the MCMC scheme (without blocking).
- `XXᵒ`: containers for proposal diffusion paths
- `XX`: containers with old diffusion paths
- `P`: laws of the diffusion path (proposal and target)
- `11`: log-likelihood of the old (previously accepted) diffusion path
- `ll`: log-likelihood of the old (previously accepted) diffusion path
- `fpt`: info about first-passage time conditioning
- `ρ`: memory parameter for the Crank-Nicolson scheme
- `verbose`: whether to print updates info while sampling
Expand Down Expand Up @@ -676,7 +684,7 @@ Imputation step of the MCMC scheme (without blocking).
- `XXᵒ`: containers for proposal diffusion paths
- `XX`: containers with old diffusion paths
- `P`: laws of the diffusion path (proposal and target)
- `11`: log-likelihood of the old (previously accepted) diffusion path
- `ll`: log-likelihood of the old (previously accepted) diffusion path
- `fpt`: info about first-passage time conditioning
- `ρ`: memory parameter for the Crank-Nicolson scheme
- `verbose`: whether to print updates info while sampling
Expand Down Expand Up @@ -840,7 +848,7 @@ Update parameters
- `P`: laws of the diffusion path with old parametrisation
- `XXᵒ`: containers for proposal diffusion paths
- `XX`: containers with old diffusion paths
- `11`: likelihood of the old (previously accepted) parametrisation
- `ll`: likelihood of the old (previously accepted) parametrisation
- `priors`: list of priors
- `fpt`: info about first-passage time conditioning
- `recomputeODEs`: whether auxiliary law depends on the updated params
Expand Down Expand Up @@ -901,7 +909,7 @@ Update parameters
- `P`: laws of the diffusion path with old parametrisation
- `XXᵒ`: containers for proposal diffusion paths
- `XX`: containers with old diffusion paths
- `11`: likelihood of the old (previously accepted) parametrisation
- `ll`: likelihood of the old (previously accepted) parametrisation
- `priors`: list of priors
- `fpt`: info about first-passage time conditioning
- `recomputeODEs`: whether auxiliary law depends on the updated params
Expand Down
Loading