diff --git a/Project.toml b/Project.toml index 24c3c37..64c69ab 100644 --- a/Project.toml +++ b/Project.toml @@ -61,6 +61,7 @@ julia = "1.10" [extras] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" ParameterizedFunctions = "65888b18-ceab-5e60-b2b9-181511a3b968" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -70,4 +71,4 @@ SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ADTypes", "Test", "Pkg", "OrdinaryDiffEq", "ParameterizedFunctions", "SafeTestsets", "StatsBase", "SteadyStateDiffEq"] +test = ["ADTypes", "JET", "Test", "Pkg", "OrdinaryDiffEq", "ParameterizedFunctions", "SafeTestsets", "StatsBase", "SteadyStateDiffEq"] diff --git a/src/dynamichmc_inference.jl b/src/dynamichmc_inference.jl index 6e72cf8..6794edb 100644 --- a/src/dynamichmc_inference.jl +++ b/src/dynamichmc_inference.jl @@ -12,7 +12,7 @@ $(FIELDS) Base.@kwdef struct DynamicHMCPosterior{TA, TP, TD, TT, TR, TS, TK, TI, TRe} "Algorithm for the ODE solver." algorithm::TA - "An ODE problem definition (`DiffEqBase.DEProblem`)." + "A problem definition (`DiffEqBase.DEProblem` or `DiffEqBase.AbstractNonlinearProblem`)." problem::TP "Time values at which the simulated path is compared to `data`." t::TT @@ -102,7 +102,7 @@ posterior values (transformed from `ℝⁿ`). - `mcmc_kwargs` are passed on as keyword arguments to `DynamicHMC.mcmc_with_warmup` """ function dynamichmc_inference( - problem::DiffEqBase.DEProblem, algorithm, t, data, + problem::Union{DiffEqBase.DEProblem, DiffEqBase.AbstractNonlinearProblem}, algorithm, t, data, parameter_priors, parameter_transformations = as( Vector, asℝ₊, diff --git a/src/stan_inference.jl b/src/stan_inference.jl index fbd3f0c..93093ac 100644 --- a/src/stan_inference.jl +++ b/src/stan_inference.jl @@ -10,7 +10,7 @@ end struct StanODEData end -function generate_priors(n, priors) +function generate_priors(n::Integer, priors) priors_string = "" if priors === nothing for i in 1:n @@ -27,7 +27,7 @@ function generate_priors(n, priors) return priors_string end -function generate_theta(n, priors) +function generate_theta(n::Integer, priors) theta = "" for i in 1:n upper_bound = "" @@ -55,7 +55,7 @@ function generate_theta(n, priors) end function stan_inference( - prob::DiffEqBase.DEProblem, + prob::Union{DiffEqBase.DEProblem, DiffEqBase.AbstractNonlinearProblem}, alg, # Positional arguments t, diff --git a/src/turing_inference.jl b/src/turing_inference.jl index cbcd121..9cea743 100644 --- a/src/turing_inference.jl +++ b/src/turing_inference.jl @@ -1,5 +1,5 @@ function turing_inference( - prob::DiffEqBase.DEProblem, + prob::Union{DiffEqBase.DEProblem, DiffEqBase.AbstractNonlinearProblem}, alg, t, data, diff --git a/test/jet.jl b/test/jet.jl new file mode 100644 index 0000000..297d053 --- /dev/null +++ b/test/jet.jl @@ -0,0 +1,10 @@ +using DiffEqBayes +using JET +using Test + +@testset "JET static analysis" begin + @testset "Package-level analysis" begin + result = JET.report_package("DiffEqBayes"; target_modules = (DiffEqBayes,)) + @test length(JET.get_reports(result)) == 0 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index e09e7a7..e726e57 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,3 +21,9 @@ if GROUP == "Stan" || GROUP == "All" include("stan.jl") end end + +if GROUP == "JET" || GROUP == "All" + @time @safetestset "JET" begin + include("jet.jl") + end +end