diff --git a/.gitignore b/.gitignore index 9b1a5785..8e456c27 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,5 @@ _git2_* docs/build *.info .vscode/spellright.dict +deps/build.log +.DS_Store diff --git a/Project.toml b/Project.toml index a1fb9c8f..4e471255 100644 --- a/Project.toml +++ b/Project.toml @@ -33,6 +33,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826" NumericalIntegration = "e7bfaba1-d571-5449-8927-abc22e82249b" Optim = "429524aa-4258-5aef-a3af-852621145aeb" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Peaks = "18e31ff7-3703-566c-8e60-38913d67486b" PhysicalConstants = "5ad8b20f-a522-5ce9-bfc9-ddf1d5bda6ab" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -82,6 +83,7 @@ Logging = "1.9" Memoize = "0.4.4" NumericalIntegration = "0.3.3" Optim = "1.4" +OrdinaryDiffEq = "6.102.0" Peaks = "0.3.2, 0.4, 0.5" PhysicalConstants = "0.2" Pkg = "1.9" diff --git a/scripts/solver_work_precision.jl b/scripts/solver_work_precision.jl new file mode 100644 index 00000000..6b318bcc --- /dev/null +++ b/scripts/solver_work_precision.jl @@ -0,0 +1,341 @@ +# Calculate work-precision plots for various NLSE solvers + +using DifferentialEquations, SciMLOperators +import FFTW +import LinearAlgebra: inv, mul!, ldiv!, norm, Diagonal +using PyPlot +import Luna +import Printf: @sprintf + +# NLSE grid and temporary storage +mutable struct NLSE{TFT} + n::Int + dt::Float64 + dΩ::Float64 + T::Vector{Float64} + Ω::Vector{Float64} + ut::Vector{ComplexF64} + utmp::Vector{ComplexF64} + utmp2::Vector{ComplexF64} + dutmp::Vector{ComplexF64} + L::Vector{ComplexF64} + FT::TFT + cz::Float64 + nfunc::Int +end + +function NLSE(dt, trange) + n = nextpow(2, ceil(Int, trange/dt)) + dt = trange/n + dΩ = 2π/trange + T = collect((-n//2:n//2-1)*dt) + Ω = FFTW.fftshift((-n//2:n//2-1)*dΩ) + ut = @. complex(5*sech(T)) + utmp = similar(ut) + utmp2 = similar(ut) + dutmp = similar(ut) + L = @. 1im*Ω^2/2 + FT = FFTW.plan_fft(ut) + inv(FT) + cz = 0.0 + NLSE(n, dt, dΩ, T, Ω, ut, utmp, utmp2, dutmp, L, FT, cz, 0) +end + +function reset!(nlse::NLSE) + nlse.cz = 0.0 + nlse.nfunc = 0 +end + +# explicit linear operator +function f1!(du,u,p,z) + @. du = p.L*u +end + +# nonlinear operator (Kerr effect) +function f2!(du,u,p,z) + p.nfunc += 1 + ldiv!(p.ut, p.FT, u) + @. p.utmp = -1im*abs2(p.ut)*p.ut + mul!(du, p.FT, p.utmp) +end + +# full interaction picture, constant L, analytically integrated +function fpre!(du,u,p,z) + @. p.utmp = u*exp(p.L*z) + f2!(du,p.utmp,p,z) + @. du *= exp(-p.L*z) +end + +# piecewise interaction picture, constant L, analytically integrated +function fpre2!(du,u,p,z) + @. p.utmp = u*exp(p.L*(z - p.cz)) + f2!(du,p.utmp,p,z) + @. du *= exp(-p.L*(z - p.cz)) +end + +# full interaction picture, L numerically integrated +function fdbl!(du,u,p,z) + uu = @view u[1:length(u)÷2] + ll = @view u[length(u)÷2+1:end] + duu = @view du[1:length(u)÷2] + dll = @view du[length(u)÷2+1:end] + @. p.utmp = uu*exp(ll) + f2!(p.utmp2,p.utmp,p,z) + @. duu = p.utmp2*exp(-ll) + @. dll = p.L +end + +# explicitly call both linear and nonlinear terms, this is stiff +function fall!(du,u,p,z) + f1!(p.dutmp,u,p,z) + f2!(du,u,p,z) + du .+= p.dutmp +end + +# 5th order soliton initial condition +function getinit(nlse::NLSE) + ut0 = @. complex.(5*sech(nlse.T)) + FFTW.fft(ut0) +end + +# reset u and cz at each step for piecewise interaction picture solver +function resetaffect!(integrator) + integrator.u .= integrator.u .* exp.(integrator.p.L .* (integrator.t - integrator.p.cz)) + integrator.p.cz = integrator.t +end + +function noaffect!(integrator) + # do nothing +end + +function geterror(nlse, u) + ana = @. 5*sech(nlse.T)*exp(-1im*π/4) + norm(FFTW.ifft(u) .- ana)/norm(ana) +end + +function run(prob, solver, adaptive, dt, reltol, abstol; cb=nothing) + zs = range(0.0, π/2, length=201) + println("building") + @time integrator = init(prob, solver; dt, adaptive, reltol, abstol, saveat=zs, callback=cb) + println("starting") + @time u = solve!(integrator) + zs, u, integrator +end + +# run full interaction picture, constant L, analytically integrated +function run_fullip(nlse::NLSE; solver=Tsit5(), adaptive=true, dt=0.0002, reltol=1e-2, abstol=1e-6, fullret=false) + reset!(nlse) + prob = ODEProblem(fpre!, getinit(nlse), (0.0, π/2), nlse) + zs, u, integrator = run(prob, solver, adaptive, dt, reltol, abstol) + res = Array{Complex{Float64}}(undef, nlse.n, length(zs)) + for (i,z) in enumerate(zs) + @. res[:,i] = u[:,i] * exp(nlse.L * z) + end + err = geterror(nlse, res[:,end]) + println("nfunc: $(nlse.nfunc)") + println("error: $err") + if fullret + return zs, res, nlse.nfunc, err, u, integrator + end + zs, res, nlse.nfunc, err +end + +# run full interaction picture, L numerically integrated +function run_numfullip(nlse::NLSE; solver=Tsit5(), adaptive=true, dt=0.0002, reltol=1e-2, abstol=1e-6) + reset!(nlse) + u0 = vcat(getinit(nlse), zero(nlse.L)) + prob = ODEProblem(fdbl!, u0, (0.0, π/2), nlse) + cb = DiscreteCallback((u,t,integrator) -> true, noaffect!, save_positions=(true,true)) + zs, u, integrator = run(prob, solver, adaptive, dt, reltol, abstol; cb=cb) + zs = Array(u.t) + res = Array{Complex{Float64}}(undef, nlse.n, length(zs)) + for i in 1:length(zs) + @. res[:,i] = u[1:nlse.n,i] * exp(u[nlse.n + 1:end,i]) + end + err = geterror(nlse, res[:,end]) + println("nfunc: $(nlse.nfunc)") + println("error: $err") + zs, res, nlse.nfunc, err +end + +# piecewise interaction picture, constant L, analytically integrated +function run_pieceip(nlse::NLSE; solver=Tsit5(), adaptive=true, dt=0.0002, reltol=1e-2, abstol=1e-6) + reset!(nlse) + prob = ODEProblem(fpre2!, getinit(nlse), (0.0, π/2), nlse) + cb = DiscreteCallback((u,t,integrator) -> true, resetaffect!, save_positions=(false,true)) + _, u, integrator = run(prob, solver, adaptive, dt, reltol, abstol; cb) + res = Array(u) + zs = u.t + err = geterror(nlse, res[:,end]) + println("nfunc: $(nlse.nfunc)") + println("error: $err") + zs, res, nlse.nfunc, err +end + +function run_stiff(nlse::NLSE; solver=Tsit5(), adaptive=true, dt=0.0002, reltol=1e-2, abstol=1e-6) + reset!(nlse) + prob = ODEProblem(fall!, getinit(nlse), (0.0, π/2), nlse) + zs, u, integrator = run(prob, solver, adaptive, dt, reltol, abstol) + res = Array(u) + err = geterror(nlse, res[:,end]) + println("nfunc: $(nlse.nfunc)") + println("error: $err") + zs, res, nlse.nfunc, err +end + +# Exponential RK integrator +function run_splitlin(nlse::NLSE; solver=ETDRK4(), adaptive=false, dt=0.0002, reltol=1e-2, abstol=1e-6) + reset!(nlse) + op = DiagonalOperator(nlse.L) + f = SplitFunction(op, f2!) + prob = SplitODEProblem(f, getinit(nlse), (0.0, π/2), nlse) + zs, u, integrator = run(prob, solver, adaptive, dt, reltol, abstol) + res = Array(u) + err = geterror(nlse, res[:,end]) + println("nfunc: $(nlse.nfunc)") + println("error: $err") + zs, res, nlse.nfunc, err +end + +# Luna original RK45 solver +function run_Luna_weak(nlse::NLSE; solver=nothing, adaptive=true, dt=0.0002, reltol=1e-2, abstol=1e-6) + reset!(nlse) + z, u, steps = Luna.RK45.solve_precon((du, u, z) -> f2!(du, u, nlse, z), nlse.L, getinit(nlse), 0.0, dt, π/2; + rtol=reltol, atol=abstol, output=true, locextrap=true, norm=Luna.RK45.weaknorm) + err = geterror(nlse, u[:,end]) + println("nfunc: $(nlse.nfunc)") + println("error: $err") + z, u, nlse.nfunc, err +end + +# Luna original RK45 solver with better norm +function run_Luna_norm(nlse::NLSE; solver=nothing, adaptive=true, dt=0.0002, reltol=1e-2, abstol=1e-6) + reset!(nlse) + z, u, steps = Luna.RK45.solve_precon((du, u, z) -> f2!(du, u, nlse, z), nlse.L, getinit(nlse), 0.0, dt, π/2; + rtol=reltol, atol=abstol, output=true, locextrap=true, norm=Luna.RK45.normnorm) + err = geterror(nlse, u[:,end]) + println("nfunc: $(nlse.nfunc)") + println("error: $err") + z, u, nlse.nfunc, err +end + +# Luna new solver +function run_newLuna(nlse::NLSE; solver=:Tsit5, adaptive=true, dt=0.0002, reltol=1e-2, abstol=1e-6) + reset!(nlse) + zs = range(0.0, π/2, length=201) + iz = 2 + res = Array{Complex{Float64}}(undef, nlse.n, length(zs)) + res[:,1] = getinit(nlse) + function stepfun(u, z, dz, interpolant) + while iz <= length(zs) && z >= zs[iz] + res[:,iz] = interpolant(zs[iz]) + iz += 1 + end + end + sol = Luna.Propagator.propagate((du, u, z) -> f2!(du, u, nlse, z), nlse.L, res[:,1], 0, π/2, stepfun; + rtol=reltol, atol=abstol, init_dz=dt, max_dz=Inf, min_dz=0, + status_period=10, solver) + err = geterror(nlse, res[:,end]) + println("nfunc: $(nlse.nfunc)") + println("error: $err") + zs, res, nlse.nfunc, err +end + +function workprecision(nlse::NLSE, solvers) + errs = [] + nfs = [] + for (i,solverset) in enumerate(solvers) + solver, dts, reltols, abstols, label = solverset + errsi = zeros(length(dts)) + nfsi = zeros(length(dts)) + for j in 1:length(dts) + z, u, nfuncs, err = solver(nlse; reltol=reltols[j], abstol=abstols[j], dt=dts[j]) + errsi[j] = err + nfsi[j] = nfuncs + end + if isnothing(label) + label = string(solver) + end + loglog(errsi, nfsi, label=label) + push!(errs, errsi) + push!(nfs, nfsi) + end + legend() + PyPlot.grid() + xlabel("Error") + ylabel("Function Evaluations") + ylim(2e3,4e4) + xlim(1e-6,1e-1) + errs, nfs +end + +function plot_nlse(nlse::NLSE, z, u; axs=nothing) + IT = 10log10.(abs2.(FFTW.ifft(u,1))) + IW = 10log10.(abs2.(FFTW.fftshift(u,1))) + IT .-= maximum(IT) + IW .-= maximum(IW) + if isnothing(axs) + fig = PyPlot.plt.figure(constrained_layout=true, figsize=(10, 6)) + axd = fig.subplot_mosaic( + """ + ab + """) + axs = (axd["a"], axd["b"]) + end + axs[1].pcolormesh(z, nlse.T, IT, clim=(-200,0), rasterized=true) + axs[1].set_xlabel("Position") + axs[1].set_ylabel("Time") + img = axs[2].pcolormesh(z, FFTW.fftshift(nlse.Ω), IW, clim=(-200,0), rasterized=true) + axs[2].set_xlabel("Position") + axs[2].set_ylabel("Frequency") + colorbar(img, ax=axs, fraction=0.05, pad=0.1, label="dB") +end + +function plot_nlse_cmp(nlse::NLSE, data) + fig = PyPlot.plt.figure(constrained_layout=true, figsize=(10, 3*length(data))) + ax_array = fig.subplots(length(data), 2) + for (i, (z, u, nfs, err)) in enumerate(data) + axs = (ax_array[i,1], ax_array[i,2]) + plot_nlse(nlse, z, u; axs=axs) + errs = @sprintf("%.2e", err) + axs[1].set_title("nfs=$(nfs), err=$(errs)") + end +end + +nlse = NLSE(0.016, 48.0); + +# run work-precision plots for various solvers +# errs, nfs = workprecision(nlse, ( +# (run_fullip, 0.0002 .* ones(30), collect(logrange(1e-5, 1e-1, 30)), 1e-6 .* ones(30), nothing), +# (run_pieceip, 0.0002 .* ones(30), collect(logrange(1e-5, 1e-1, 30)), 1e-6 .* ones(30), nothing), +# (run_numfullip, 0.0002 .* ones(30), collect(logrange(1e-5, 1e-1, 30)), 1e-6 .* ones(30), nothing), +# (run_splitlin, collect(logrange(1e-4, 1e-2, 30)), 1e-6 .* ones(30), 1e-6 .* ones(30), nothing), +# (run_Luna_weak, 0.0002 .* ones(30), collect(logrange(1e-10, 1e-3, 30)), 1e-10 .* ones(30), nothing), +# (run_Luna_norm, 0.0002 .* ones(30), collect(logrange(1e-7, 1e-1, 30)), 1e-6 .* ones(30), nothing), +# (run_newLuna, 0.0002 .* ones(40), collect(logrange(5e-5, 1.2e-1, 40)), 1e-6 .* ones(40), nothing) +# )) +# savefig("scripts/solver_work_precision_nofsal.svg")) + +# errs, nfs = workprecision(nlse, ( +# (run_fullip, 0.0002 .* ones(30), collect(logrange(1e-5, 1e-1, 30)), 1e-6 .* ones(30), nothing), +# (run_pieceip, 0.0002 .* ones(30), collect(logrange(1e-5, 1e-1, 30)), 1e-6 .* ones(30), nothing), +# (run_Luna_weak, 0.0002 .* ones(30), collect(logrange(1e-10, 1e-3, 30)), 1e-10 .* ones(30), nothing), +# (run_Luna_norm, 0.0002 .* ones(30), collect(logrange(1e-7, 1e-1, 30)), 1e-6 .* ones(30), nothing), +# (run_newLuna, 0.0002 .* ones(40), collect(logrange(5e-5, 1.2e-1, 40)), 1e-6 .* ones(40), nothing) +# )) +# savefig(solver_work_precision_nabsbound.svg")) + +# # work-precision curves for multiple atol values for new Luna solver +# errs, nfs = workprecision(nlse, ( +# (run_newLuna, 0.0002 .* ones(40), collect(logrange(1e-10, 1.2e-1, 40)), 1e-4 .* ones(40), "1e-4"), +# (run_newLuna, 0.0002 .* ones(40), collect(logrange(1e-10, 1.2e-1, 40)), 1e-5 .* ones(40), "1e-5"), +# (run_newLuna, 0.0002 .* ones(40), collect(logrange(1e-10, 1.2e-1, 40)), 1e-6 .* ones(40), "1e-6"), +# (run_newLuna, 0.0002 .* ones(40), collect(logrange(1e-10, 1.2e-1, 40)), 1e-8 .* ones(40), "1e-8"), +# (run_newLuna, 0.0002 .* ones(40), collect(logrange(1e-10, 1.2e-1, 40)), 1e-10 .* ones(40), "1e-10"), +# )) +# savefig("solver_work_precision_atolscan.svg") + +# # run a comparison to visualise the error +# data = [run_newLuna(nlse; reltol=rtol, abstol=1e-6) for rtol in (1e-1, 6.9e-2, 1.2e-2, 5e-4)] +# plot_nlse_cmp(nlse, data) +# savefig("solver_work_precision_cmp.svg"), dpi=600) diff --git a/src/Interface.jl b/src/Interface.jl index 7495088a..42cfff0b 100644 --- a/src/Interface.jl +++ b/src/Interface.jl @@ -339,11 +339,19 @@ If `raman` is `true`, then the following options apply: - `scanidx`: Current scan index within a scan being run. Only used when `scan` is passed. - `filename`: Can be used to to overwrite the scan name when running a parameter scan. The running `scanidx` will be appended to this filename. Ignored if no `scan` is given. + +# Solver options +- `rtol::Number`: Relative tolerance for the solver. Defaults to an optimal solver dependent value. +- `atol::Number`: Absolute tolerance for the solver. Defaults to an optimal solver dependent value. +- `solver::Symbol`: The solver to use. Defaults to `:Tsit5` from DifferentialEquations.jl. + Other possible choices can be found in the non-stiff ODE documentation of DifferentialEquations.jl. + For the original Luna RK45 solver, use `:OrigRK45`. - `status_period::Number`: Interval (in seconds) between printed status updates. """ -function prop_capillary(args...; status_period=5, kwargs...) +function prop_capillary(args...; rtol=nothing, atol=nothing, solver=:Tsit5, + status_period=5, kwargs...) Eω, grid, linop, transform, FT, output = prop_capillary_args(args...; kwargs...) - Luna.run(Eω, grid, linop, transform, FT, output; status_period) + Luna.run(Eω, grid, linop, transform, FT, output; rtol, atol, solver, status_period) output end @@ -897,11 +905,19 @@ Note that the current GNLSE model is single mode only. - `scanidx`: Current scan index within a scan being run. Only used when `scan` is passed. - `filename`: Can be used to to overwrite the scan name when running a parameter scan. The running `scanidx` will be appended to this filename. Ignored if no `scan` is given. + +# Solver options +- `rtol::Number`: Relative tolerance for the solver. Defaults to an optimal solver dependent value. +- `atol::Number`: Absolute tolerance for the solver. Defaults to an optimal solver dependent value. +- `solver::Symbol`: The solver to use. Defaults to `:Tsit5` from DifferentialEquations.jl. + Other possible choices can be found in the non-stiff ODE documentation of DifferentialEquations.jl. + For the original Luna RK45 solver, use `:OrigRK45`. - `status_period::Number`: Interval (in seconds) between printed status updates. """ -function prop_gnlse(args...; status_period=5, kwargs...) +function prop_gnlse(args...; rtol=nothing, atol=nothing, solver=:Tsit5, + status_period=5, kwargs...) Eω, grid, linop, transform, FT, output = prop_gnlse_args(args...; kwargs...) - Luna.run(Eω, grid, linop, transform, FT, output; status_period) + Luna.run(Eω, grid, linop, transform, FT, output; rtol, atol, solver, status_period) output end diff --git a/src/Luna.jl b/src/Luna.jl index 24de3188..a6203255 100644 --- a/src/Luna.jl +++ b/src/Luna.jl @@ -57,6 +57,7 @@ include("Grid.jl") include("Modes.jl") include("Fields.jl") include("RK45.jl") +include("Propagator.jl") include("LinearOps.jl") include("Capillary.jl") include("Antiresonant.jl") @@ -88,7 +89,7 @@ export Utils, Scans, Output, Maths, PhysData, Grid, RK45, Modes, Capillary, Rect Nonlinear, Ionisation, NonlinearRHS, LinearOps, Stats, Polarisation, Tools, Plotting, Raman, Antiresonant, Fields, Processing, Interface, SFA, prop_capillary, prop_gnlse, Pulses, Scan, runscan, makefilename, addvariable!, - StepIndexFibre, SimpleFibre + StepIndexFibre, SimpleFibre, Propagator # for a tuple of TimeFields we assume all inputs are for mode 1 function doinput_sm(grid, inputs::Tuple{Vararg{T} where T <: Fields.TimeField}, FT) @@ -358,21 +359,21 @@ save_modeinfo_maybe(output, t) = nothing function run(Eω, grid, linop, transform, FT, output; min_dz=0, max_dz=grid.zmax/2, init_dz=1e-4, z0=0.0, - rtol=1e-6, atol=1e-10, safety=0.9, norm=RK45.weaknorm, - status_period=1) + rtol=nothing, atol=nothing, safety=0.9, norm=RK45.weaknorm, + status_period=1, solver=:Tsit5) Et = FT \ Eω - function stepfun(Eω, z, dz, interpolant) - Eω .*= grid.ωwin + function stepfun(Eω, z, dz, interpolant; stepcache=nothing) + #Eω .*= grid.ωwin ldiv!(Et, FT, Eω) Et .*= grid.twin mul!(Eω, FT, Et) - output(Eω, z, dz, interpolant) + output(Eω, z, dz, interpolant; stepcache) end # check_cache does nothing except for HDF5Outputs - Eωc, zc, dzc = Output.check_cache(output, Eω, z0, init_dz) + Eωc, zc, dzc, stepcache = Output.check_cache(output, Eω, z0, init_dz) if zc > z0 Logging.@info("Found cached propagation. Resuming...") Eω, z0, init_dz = Eωc, zc, dzc @@ -383,20 +384,36 @@ function run(Eω, grid, save_modeinfo_maybe(output, transform) flush(stderr) # flush std error once before starting to show setup steps - RK45.solve_precon( - transform, linop, Eω, z0, init_dz, grid.zmax, stepfun=stepfun, - max_dt=max_dz, min_dt=min_dz, - rtol=rtol, atol=atol, safety=safety, norm=norm, - status_period=status_period) + + if solver == :OrigRK45 + Logging.@info("Using original Luna RK45 solver.") + rtol = isnothing(rtol) ? 1e-6 : rtol + atol = isnothing(atol) ? 1e-10 : atol + Logging.@info("Using rtol = $rtol, atol = $atol") + RK45.solve_precon( + transform, linop, Eω, z0, init_dz, grid.zmax, stepfun=stepfun, + max_dt=max_dz, min_dt=min_dz, + rtol=rtol, atol=atol, safety=safety, norm=norm, + status_period=status_period) + return + else + Logging.@info("Using $solver solver") + rtol = isnothing(rtol) ? 7e-2 : rtol + atol = isnothing(atol) ? 1e-6 : atol + Logging.@info("Using rtol = $rtol, atol = $atol") + Propagator.propagate(transform, linop, Eω, z0, grid.zmax, stepfun; + rtol, atol, init_dz, max_dz, min_dz, status_period, + solver, stepcache) + end end # run some code for precompilation Logging.with_logger(Logging.NullLogger()) do - prop_capillary(125e-6, 0.3, :He, 1.0; λ0=800e-9, energy=1e-9, + prop_capillary(125e-6, 0.03, :He, 1.0; λ0=800e-9, energy=1e-9, τfwhm=10e-15, λlims=(150e-9, 4e-6), trange=1e-12, saveN=11) - prop_capillary(125e-6, 0.3, :He, (1.0, 0); λ0=800e-9, energy=1e-9, + prop_capillary(125e-6, 0.03, :He, (1.0, 0); λ0=800e-9, energy=1e-9, τfwhm=10e-15, λlims=(150e-9, 4e-6), trange=1e-12, saveN=11) - prop_capillary(125e-6, 0.3, :He, 1.0; λ0=800e-9, energy=1e-9, + prop_capillary(125e-6, 0.03, :He, 1.0; λ0=800e-9, energy=1e-9, τfwhm=10e-15, λlims=(150e-9, 4e-6), trange=1e-12, saveN=11, modes=4) p = Tools.capillary_params(120e-6, 10e-15, 800e-9, 125e-6, :He, P=1.0) diff --git a/src/Output.jl b/src/Output.jl index af2bf0fd..3231ea22 100644 --- a/src/Output.jl +++ b/src/Output.jl @@ -39,10 +39,14 @@ function MemoryOutput(save_cond, yname, tname, statsfun=nostats, script=nothing) MemoryOutput(save_cond, yname, tname, 0, data, statsfun) end -function initialise(o::MemoryOutput, y) +function initialise(o::MemoryOutput, t, dt, y) dims = init_dims(size(y), o.save_cond) o.data[o.yname] = Array{ComplexF64}(undef, dims) o.data[o.tname] = Array{Float64}(undef, (dims[end],)) + o.data["cache"] = Dict{String, Any}() + o.data["cache"]["y"] = copy(y) + o.data["cache"]["t"] = t + o.data["cache"]["dt"] = dt end "getindex works interchangeably so when switching from one Output to @@ -62,10 +66,10 @@ haskey(o::MemoryOutput, key) = haskey(o.data, key) yfun: callable which returns interpolated function value at different t Note that from RK45.jl, this will be called with yn and tn as arguments. """ -function (o::MemoryOutput)(y, t, dt, yfun) +function (o::MemoryOutput)(y, t, dt, yfun; stepcache=nothing) save, ts = o.save_cond(y, t, dt, o.saved) append_stats!(o, o.statsfun(y, t, dt)) - !haskey(o.data, o.yname) && initialise(o, y) + !haskey(o.data, o.yname) && initialise(o, t, dt, y) while save s = size(o.data[o.yname]) if s[end] < o.saved+1 @@ -79,6 +83,16 @@ function (o::MemoryOutput)(y, t, dt, yfun) o.saved += 1 save, ts = o.save_cond(y, t, dt, o.saved) end + o.data["cache"]["y"] .= y + o.data["cache"]["t"] = t + o.data["cache"]["dt"] = dt + if !isnothing(stepcache) + o.data["cache"]["dtpropose"] = stepcache.dtpropose + o.data["cache"]["dtcache"] = stepcache.dtcache + o.data["cache"]["qold"] = stepcache.qold + o.data["cache"]["erracc"] = stepcache.erracc + o.data["cache"]["dtacc"] = stepcache.dtacc + end end function append_stats!(o::MemoryOutput, d) @@ -212,7 +226,7 @@ function HDF5Output(fpath::AbstractString) HDF5Output(fpath, 0, 0, 1; readonly=true) end -function initialise(o::HDF5Output, y) +function initialise(o::HDF5Output, t, dt, y) ydims = size(y) idims = init_dims(ydims, o.save_cond) cdims = collect(idims) @@ -235,10 +249,15 @@ function initialise(o::HDF5Output, y) o.cachehash = hash((statsnames, size(y))) file["meta"]["cachehash"] = o.cachehash if o.cache - file["meta"]["cache"]["t"] = typemin(0.0) - file["meta"]["cache"]["dt"] = typemin(0.0) + file["meta"]["cache"]["t"] = t + file["meta"]["cache"]["dt"] = dt file["meta"]["cache"]["y"] = y file["meta"]["cache"]["saved"] = 0 + file["meta"]["cache"]["dtpropose"] = 0.0 + file["meta"]["cache"]["dtcache"] = 0.0 + file["meta"]["cache"]["qold"] = 0.0 + file["meta"]["cache"]["erracc"] = 0.0 + file["meta"]["cache"]["dtacc"] = 0.0 end end end @@ -310,13 +329,13 @@ end yfun: callable which returns interpolated function value at different t Note that from RK45.jl, this will be called with yn and tn as arguments. """ -function (o::HDF5Output)(y, t, dt, yfun) +function (o::HDF5Output)(y, t, dt, yfun; stepcache=nothing) o.readonly && error("Cannot add data to read-only output!") save, ts = o.save_cond(y, t, dt, o.saved) push!(o.stats_tmp, o.statsfun(y, t, dt)) if save HDF5.h5open(o.fpath, "r+") do file - !HDF5.haskey(file, o.yname) && initialise(o, y) + !HDF5.haskey(file, o.yname) && initialise(o, t, dt, y) statsnames = sort(collect(keys(o.stats_tmp[end]))) cachehash = hash((statsnames, size(y))) cachehash == o.cachehash || error( @@ -345,6 +364,13 @@ function (o::HDF5Output)(y, t, dt, yfun) write(file["meta"]["cache"]["dt"], dt) write(file["meta"]["cache"]["y"], y) write(file["meta"]["cache"]["saved"], o.saved) + if !isnothing(stepcache) + write(file["meta"]["cache"]["dtpropose"], stepcache.dtpropose) + write(file["meta"]["cache"]["dtcache"], stepcache.dtcache) + write(file["meta"]["cache"]["qold"], stepcache.qold) + write(file["meta"]["cache"]["erracc"], stepcache.erracc) + write(file["meta"]["cache"]["dtacc"], stepcache.dtacc) + end end end end @@ -464,19 +490,65 @@ Check for an existing cached propagation in the output `o` and return this cache """ function check_cache(o::HDF5Output, y, t, dt) if !o.cache || !haskey(o["meta"]["cache"], "t") - return y, t, dt + return y, t, dt, nothing end tc = o["meta"]["cache"]["t"] if tc < t - return y, t, dt + return y, t, dt, nothing end yc = o["meta"]["cache"]["y"] dtc = o["meta"]["cache"]["dt"] - return yc, tc, dtc + if haskey(o["meta"]["cache"], "dtacc") + dtacc = o["meta"]["cache"]["dtacc"] + dtpropose = o["meta"]["cache"]["dtpropose"] + dtcache = o["meta"]["cache"]["dtcache"] + qold = o["meta"]["cache"]["qold"] + erracc = o["meta"]["cache"]["erracc"] + dtacc = o["meta"]["cache"]["dtacc"] + stepcache = (; dtacc, dtpropose, dtcache, qold, erracc) + if all(iszero, stepcache) + stepcache = nothing + end + else + stepcache = nothing + end + return yc, tc, dtc, stepcache +end + +""" + check_cache(o::MemoryOutput, y, t, dt) + +Check for an existing cached propagation in the output `o` and return this cache if present. +""" +function check_cache(o::MemoryOutput, y, t, dt) + if !haskey(o, "cache") + return y, t, dt, nothing + end + tc = o.data["cache"]["t"] + if tc < t + return y, t, dt, nothing + end + yc = o.data["cache"]["y"] + dtc = o.data["cache"]["dt"] + if haskey(o["cache"], "dtacc") + dtacc = o["cache"]["dtacc"] + dtpropose = o["cache"]["dtpropose"] + dtcache = o["cache"]["dtcache"] + qold = o["cache"]["qold"] + erracc = o["cache"]["erracc"] + dtacc = o["cache"]["dtacc"] + stepcache = (; dtacc, dtpropose, dtcache, qold, erracc) + if all(iszero, stepcache) + stepcache = nothing + end + else + stepcache = nothing + end + return yc, tc, dtc, stepcache end -# For other outputs (e.g. MemoryOutput or another function), checking the cache does nothing. -check_cache(o, y, t, dt) = y, t, dt +# For other outputs, checking the cache does nothing. +check_cache(o, y, t, dt) = y, t, dt, nothing "Condition callable that distributes save points evenly on a grid" struct GridCondition diff --git a/src/Propagator.jl b/src/Propagator.jl new file mode 100644 index 00000000..08c8e14f --- /dev/null +++ b/src/Propagator.jl @@ -0,0 +1,220 @@ +module Propagator +import Dates +import Logging +import Printf: @sprintf +import Luna.Utils: format_elapsed +import OrdinaryDiffEq as ODE + +mutable struct Printer{DT} + status_period::Int + start::DT + tic::DT + zmax::Float64 +end + +function Printer(status_period, zmax) + Printer(status_period, Dates.now(), Dates.now(), zmax) +end + +function printstart(p::Printer) + p.start = Dates.now() + p.tic = Dates.now() + Logging.@info "Starting propagation" +end + +function printstep(p::Printer, z, dz) + if Dates.value(Dates.now() - p.tic) > 1000*p.status_period + speed = z/(Dates.value(Dates.now() - p.start)/1000) + eta_in_s = (p.zmax - z)/speed + if eta_in_s > 356400 + Logging.@info @sprintf("Progress: %.2f %%, ETA: XX:XX:XX, stepsize %.2e", + z/p.zmax*100, dz) + else + eta_in_ms = Dates.Millisecond(ceil(eta_in_s*1000)) + etad = Dates.DateTime(Dates.UTInstant(eta_in_ms)) + Logging.@info @sprintf("Progress: %.2f %%, ETA: %s, stepsize %.2e", + z/p.zmax*100, Dates.format(etad, "HH:MM:SS"), dz) + end + flush(stderr) + p.tic = Dates.now() + end +end + +function printstop(p::Printer, integrator) + totaltime = Dates.now() - p.start + dtstring = format_elapsed(totaltime) + Logging.@info @sprintf("Propagation finished in %s", dtstring) + Logging.@info @sprintf("Steps accepted: %d; rejected: %d", + integrator.stats.naccept, integrator.stats.nreject) + Logging.@info @sprintf("Nonlinear function calls: %d", integrator.stats.nf) +end + +abstract type AbstractPropagator end + +# For the cases where we can integrate L(z) analytically +struct AnalyticalPropagator{LT, NLT, SFT, PT, AT<:AbstractArray} <: AbstractPropagator + Li!::LT # function to get integrated linear operator at z + nonlinop!::NLT + stepfun::SFT + Litmp::AT + Eωtmp::AT + Pωtmp::AT + p::PT +end + +function fcl!(du,u,p,z) + p.Li!(p.Litmp, z) # Get integrated linear operator at z + @. p.Eωtmp = u * exp(p.Litmp) # Transform back from interaction picture + p.nonlinop!(p.Pωtmp, p.Eωtmp, z) # Apply nonlinear operator + @. du = p.Pωtmp * exp(-p.Litmp) # Transform to interaction picture +end + +function callbackcl(integrator) + # The output we want must be transformed back from the interaction picture + integrator.p.Li!(integrator.p.Litmp, integrator.t) + @. integrator.p.Eωtmp = integrator.u * exp(integrator.p.Litmp) + + # define interp function to pass to output + interp = let integrator=integrator + function interp(z) + u = integrator(z) + integrator.p.Li!(integrator.p.Litmp, z) + @. u * exp(integrator.p.Litmp) + end + end + + stepcache = (dtpropose = integrator.dtpropose, + dtcache = integrator.dtcache, + qold = integrator.qold, + erracc = integrator.erracc, + dtacc = integrator.dtacc) + + integrator.p.stepfun(integrator.p.Eωtmp, integrator.t, + integrator.dt, interp; stepcache) + + printstep(integrator.p.p, integrator.t, ODE.get_proposed_dt(integrator)) + + # copy back as we (might) modify u in stepfun (absorbing boundaries) + integrator.p.Li!(integrator.p.Litmp, integrator.t) + @. integrator.u = integrator.p.Eωtmp * exp(-integrator.p.Litmp) +end + +# For a non-constant linear operator, we need to integrate L(z) numerically along with +# the solution. We do this by simply including the integral of linear operator in the state vector. +struct NonConstPropagator{LT, NLT, SFT, PT, AT<:AbstractArray} <: AbstractPropagator + linop!::LT # function to get linear operator at z + nonlinop!::NLT + stepfun::SFT + n::Int + Eωtmp::AT + Pωtmp::AT + p::PT +end + +function fncl!(du,u,p,z) + Eω = @views u[1:p.n] # Actual Eω + Li = @views u[p.n+1:end] # Cumulatively integrated linear operator + dEω = @views du[1:p.n] + dLi = @views du[p.n+1:end] + @. p.Eωtmp = Eω * exp(Li) # Transform back from interaction picture + p.nonlinop!(p.Pωtmp, p.Eωtmp, z) # Apply nonlinear operator + @. dEω = p.Pωtmp * exp(-Li) # Transform to interaction picture + p.linop!(dLi, z) # Integrate linear operator +end + +function callbackncl(integrator) + n = integrator.p.n + Eω = @views integrator.u[1:n] # Actual Eω + Li = @views integrator.u[n+1:end] # Cumulatively integrated linear operator + # The output we want must be transformed back from the interaction picture + @. integrator.p.Eωtmp = Eω * exp(Li) + + # define interp function to pass to output + interp = let integrator=integrator, n=n + function interp(z) + u = integrator(z) + Eω = @views u[1:n] + Li = @views u[n+1:end] + @. Eω * exp(Li) + end + end + + stepcache = (dtpropose = integrator.dtpropose, + dtcache = integrator.dtcache, + qold = integrator.qold, + erracc = integrator.erracc, + dtacc = integrator.dtacc) + + integrator.p.stepfun(integrator.p.Eωtmp, integrator.t, + integrator.dt, interp; stepcache) + + printstep(integrator.p.p, integrator.t, ODE.get_proposed_dt(integrator)) + + # copy back as we (might) modify u in stepfun (absorbing boundaries) + @. Eω = integrator.p.Eωtmp * exp(-Li) +end + +# Constant linear operator case--linop is an array +function makeprop(f!, linop::Array{ComplexF64}, Eω0, z, zmax, stepfun, printer, rtol, atol) + # For a constant linear operator L, the integral is just L*z + Li! = let linop=linop, z0=z + function Li!(out, z) + @. out = linop * (z - z0) # difference from z0 to handle continuing + end + end + prop = AnalyticalPropagator(Li!, f!, stepfun, similar(Eω0), similar(Eω0), similar(Eω0), printer) + prob = ODE.ODEProblem(fcl!, Eω0, (z, zmax), prop) + prob, callbackcl, rtol, atol +end + +# For a linop tuple we expect a pair of functions (linop, ilinop) where the second function provides the +# cumulatively integrated linear operator. This is mostly for testing with analytically integrable operators. +function makeprop(f!, linop::Tuple, Eω0, z, zmax, stepfun, printer, rtol, atol) + Li0 = zero(Eω0) + ilinop! = linop[2] + ilinop!(Li0, z) + Li! = let ilinop! = ilinop!, Li0=Li0 + function Li!(out, z) + ilinop!(out, z) + out .-= Li0 # handle continuing + end + end + prop = AnalyticalPropagator(Li!, f!, stepfun, similar(Eω0), similar(Eω0), similar(Eω0), printer) + prob = ODE.ODEProblem(fcl!, Eω0, (z, zmax), prop) + prob, callbackcl, rtol, atol +end + +# General linop, we integrate numerically along with the solution +function makeprop(f!, linop, Eω0, z, zmax, stepfun, printer, rtol, atol) + prop = NonConstPropagator(linop, f!, stepfun, + length(Eω0), similar(Eω0), similar(Eω0), printer) + # continuing is handled implicilty + u0 = vcat(Eω0, zero(Eω0)) # state vector includes integrated linear operator + prob = ODE.ODEProblem(fncl!, u0, (z, zmax), prop) + prob, callbackncl, rtol, atol +end + +function propagate(f!, linop, Eω0, z, zmax, stepfun; + rtol=1e-3, atol=1e-6, init_dz=1e-4, max_dz=Inf, min_dz=0, + status_period=1, solver=:Tsit5, zstops=nothing, stepcache=nothing) + printer = Printer(status_period, zmax) + prob, cbfunc, rtol, atol = makeprop(f!, linop, Eω0, z, zmax, stepfun, printer, rtol, atol) + # We do all saving and stats in a callback called at every step + cb = ODE.DiscreteCallback((u,t,integrator) -> true, cbfunc, save_positions=(true,true)) + solveri = getproperty(ODE, solver)() + integrator = ODE.init(prob, solveri; adaptive=true, reltol=rtol, abstol=atol, + dt=init_dz, dtmin=min_dz, dtmax=max_dz, callback=cb) + if !isnothing(stepcache) + integrator.qold = stepcache.qold + integrator.erracc = stepcache.erracc + integrator.dtacc = stepcache.dtacc + integrator.dtpropose = stepcache.dtpropose + integrator.dtcache = stepcache.dtcache + end + printstart(printer) + sol = ODE.solve!(integrator) + printstop(printer, integrator) + sol +end + +end # module diff --git a/test/test_gradient.jl b/test/test_gradient.jl index 54502d49..299c2413 100644 --- a/test/test_gradient.jl +++ b/test/test_gradient.jl @@ -75,8 +75,13 @@ statsfun = Stats.collect_stats(grid, Eω, output_grad_array = Output.MemoryOutput(0, grid.zmax, 201, statsfun) Luna.run(Eω, grid, linop, transform, FT, output_grad_array, status_period=10) -@test all(output_grad.data["Eω"][grid.sidx, :] .≈ output_const.data["Eω"][grid.sidx, :]) -@test all(output_grad_array.data["Eω"][grid.sidx, :] .≈ output_const.data["Eω"][grid.sidx, :]) +# TODO: tolerances here are quite high, because the const propagator and non-const propagator handle the +# integration of the linear part differently. The absolute error is small (e.g. here 6 significant digits) +# and is easily reduced by tightening tolerances. This does not arise with the piecewise integrator originally +# used, because the const and piecewise operators are identical for a constant gradient. But see the +# analytical gradient test below for an example of where the piecewise approximation breaks down +@test isapprox(output_grad.data["Eω"][grid.sidx, :], output_const.data["Eω"][grid.sidx, :], rtol=3e-6) +@test isapprox(output_grad_array.data["Eω"][grid.sidx, :], output_const.data["Eω"][grid.sidx, :], rtol=3e-6) end @testset "envelope" begin @@ -134,6 +139,64 @@ statsfun = Stats.collect_stats(grid, Eω, output_grad_array = Output.MemoryOutput(0, grid.zmax, 201, statsfun) Luna.run(Eω, grid, linop, transform, FT, output_grad_array, status_period=10) -@test all(output_grad.data["Eω"][grid.sidx, :] .≈ output_const.data["Eω"][grid.sidx, :]) -@test all(output_grad_array.data["Eω"][grid.sidx, :] .≈ output_const.data["Eω"][grid.sidx, :]) +# See comment on line 78 +@test isapprox(output_grad.data["Eω"][grid.sidx, :], output_const.data["Eω"][grid.sidx, :], rtol=3e-6) +@test isapprox(output_grad_array.data["Eω"][grid.sidx, :], output_const.data["Eω"][grid.sidx, :], rtol=3e-6) +end + +@testset "analytical gradient" begin +a = 13e-6 +gas = :Ar +pres = 5 +τ = 30e-15 +λ0 = 800e-9 +L = 5e-2 + +# Common setup +grid = Grid.RealGrid(L, λ0, (160e-9, 3000e-9), 0.5e-12) +inputs = Fields.GaussField(λ0=λ0, τfwhm=τ, energy=1e-6) +responses = (Nonlinear.Kerr_field(PhysData.γ3_gas(gas)),) + +# Constant +dens0 = PhysData.density(gas, pres) +dens(z) = dens0 +m = Capillary.MarcatiliMode(a, gas, pres, loss=false) +aeff(z) = Modes.Aeff(m, z=z) +energyfun, energyfunω = Fields.energyfuncs(grid) +linop, βfun!, frame_vel, αfun = LinearOps.make_const_linop(grid, m, λ0) + +# Test a linearly increasing linop for analytical verification of propagator +linopfun = let linop=linop, L=L + function linopfun(out, z) + @. out = linop * z / L # linear increase from 0 to linop over distance L + end +end +linopifun = let linop=linop, L=L + function linopifun(out, z) + @. out = linop * z^2 / (2 * L) # integral of linopfun + end +end + +Eω, transform, FT = Luna.setup( + grid, dens, responses, inputs, βfun!, aeff) +statsfun = Stats.collect_stats(grid, Eω, + Stats.ω0(grid), + Stats.energy(grid, energyfunω)) + + +# Test analytical propagator with linopfun and linopifun +output_ana = Output.MemoryOutput(0, grid.zmax, 201, statsfun) +Luna.run(Eω, grid, (linopfun, linopifun), transform, FT, output_ana, status_period=10) + +# Test numerically integrated propagator with linopfun +output_num = Output.MemoryOutput(0, grid.zmax, 201, statsfun) +Luna.run(Eω, grid, linopfun, transform, FT, output_num, status_period=10) + +# Test piecewise version +output_piece = Output.MemoryOutput(0, grid.zmax, 201, statsfun) +Luna.run(Eω, grid, linopfun, transform, FT, output_piece, status_period=10, solver=:OrigRK45) + +@test isapprox(output_num.data["Eω"][grid.sidx, :], output_ana.data["Eω"][grid.sidx, :], rtol=2e-3) +# Piecewise solver struggles with the high frequency part: +@test isapprox(output_piece.data["Eω"][grid.sidx, :], output_ana.data["Eω"][grid.sidx, :], rtol=0.4) end diff --git a/test/test_output.jl b/test/test_output.jl index d8db385c..be639302 100644 --- a/test/test_output.jl +++ b/test/test_output.jl @@ -225,8 +225,8 @@ fpath = joinpath(homedir(), ".luna", "output_test", "test.h5") Stats.ω0(grid), Stats.energy(grid, energyfunω)) output = Output.HDF5Output(fpath, 0, grid.zmax, 51, statsfun) - function stepfun(Eω, z, dz, interpolant) - output(Eω, z, dz, interpolant) + function stepfun(Eω, z, dz, interpolant; stepcache) + output(Eω, z, dz, interpolant; stepcache) if z > 3e-2 error("Oh no!") end @@ -263,9 +263,9 @@ fpath = joinpath(homedir(), ".luna", "output_test", "test.h5") Eω = output["Eω"][idx1:idx2, :] Iω = abs2.(Eω) Iωm = abs2.(Eωm) - @test norm(Iω - Iωm)/norm(Iω) < 1e-7 - @test all(isapprox.(Eωm, Eω, atol=1e-4*maximum(abs.(Eωm)))) - @test all(isapprox.(output["stats"]["ω0"], mem.data["stats"]["ω0"], rtol=1e-6)) + @test norm(Iω - Iωm)/norm(Iω) < 6e-9 + @test all(isapprox.(Eωm, Eω, atol=1e-5*maximum(abs.(Eωm)))) + @test all(isapprox.(output["stats"]["ω0"], mem.data["stats"]["ω0"], rtol=3e-6)) @test all(output["stats"]["energy"] .≈ mem.data["stats"]["energy"]) @test output["z"] == mem.data["z"] @test output["grid"] == Grid.to_dict(grid)