diff --git a/Project.toml b/Project.toml index c2f3ec2..0704bf4 100644 --- a/Project.toml +++ b/Project.toml @@ -4,24 +4,30 @@ repo = "https://github.com/NVE/JulES.jl.git" version = "0.4.0" [deps] -CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" -Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" -JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TuLiPa = "970f5c25-cd7d-4f04-b50d-7a4fe2af6639" YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" +[weakdeps] +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" + [compat] julia = "1.9.2" +OrdinaryDiffEq = "6.66.0" +ComponentArrays = "0.15.17" +Interpolations = "0.16.1" +JLD2 = "0.5.15" + +[extensions] +IfmExt = ["ComponentArrays", "Interpolations", "JLD2", "OrdinaryDiffEq"] \ No newline at end of file diff --git a/ext/IfmExt.jl b/ext/IfmExt.jl new file mode 100644 index 0000000..4c8d0f6 --- /dev/null +++ b/ext/IfmExt.jl @@ -0,0 +1,200 @@ +""" +This code is copied from https://github.com/marv-in/HydroNODE +and has BSD 3-Clause License +""" +# TODO: Update License file +# TODO: Write license info in each source file + +module IfmExt + +using OrdinaryDiffEq +using ComponentArrays +using Interpolations +using JLD2 +using Dates +using Statistics +using JulES + + +step_fct(x) = (tanh(5.0*x) + 1.0)*0.5 +Ps(P, T, Tmin) = step_fct(Tmin-T)*P +Pr(P, T, Tmin) = step_fct(T-Tmin)*P +M(S0, T, Df, Tmax) = step_fct(T-Tmax)*step_fct(S0)*minimum([S0, Df*(T-Tmax)]) +PET(T, Lday) = 29.8 * Lday * 0.611 * exp((17.3*T)/(T+237.3)) / (T + 273.2) +ET(S1, T, Lday, Smax) = step_fct(S1)*step_fct(S1-Smax)*PET(T,Lday) + step_fct(S1)*step_fct(Smax-S1)*PET(T,Lday)*(S1/Smax) +Qb(S1,f,Smax,Qmax) = step_fct(S1)*step_fct(S1-Smax)*Qmax + step_fct(S1)*step_fct(Smax-S1)*Qmax*exp(-f*(Smax-S1)) +Qs(S1, Smax) = step_fct(S1)*step_fct(S1-Smax)*(S1-Smax) + +function JulES.predict(m::JulES.TwoStateIfmHandler, u0::Vector{Float64}, t::JulES.TuLiPa.ProbTime) + JulES.update_prediction_data(m, m.updater, t) + + # create interpolation input functions + itp_method = SteffenMonotonicInterpolation() + itp_P = interpolate(m.data_pred.timepoints, m.data_pred.P, itp_method) + itp_T = interpolate(m.data_pred.timepoints, m.data_pred.T, itp_method) + itp_Lday = interpolate(m.data_pred.timepoints, m.data_pred.Lday, itp_method) + + (S0, G0) = u0 + (Q, __) = JulES.predict(m.predictor, S0, G0, itp_Lday, itp_P, itp_T, m.data_pred.timepoints) + + Q = Float64.(Q) + + Q .= Q .* m.m3s_per_mm + + return Q +end + +function JulES.estimate_u0(m::JulES.TwoStateIfmHandler, t::JulES.TuLiPa.ProbTime) + if isnothing(m.prev_t) + JulES._initial_data_obs_update(m, t) + else + JulES._data_obs_update(m, t) + end + m.prev_t = t + + # do prediction from start of obs up until today + (S0, G0) = (Float32(0), Float32(0)) + + # 
create interpolation input functions + itp_method = SteffenMonotonicInterpolation() + itp_P = interpolate(m.data_obs.timepoints, m.data_obs.P, itp_method) + itp_T = interpolate(m.data_obs.timepoints, m.data_obs.T, itp_method) + itp_Lday = interpolate(m.data_obs.timepoints, m.data_obs.Lday, itp_method) + + (__, OED_sol) = JulES.predict(m.predictor, S0, G0, itp_Lday, itp_P, itp_T, m.data_obs.timepoints) + + # extract states + est_S0 = Float64(last(OED_sol[1, :])) + est_G0 = Float64(last(OED_sol[2, :])) + + return [est_S0, est_G0] +end + +function JulES.calculate_normalize_factor(ifm_model) + start_with_buffer = ifm_model.handler.scen_start - Day(ifm_model.handler.ndays_obs) # Add ndays_obs days buffer + days_with_buffer = Day(ifm_model.handler.scen_stop - start_with_buffer) |> Dates.value + timepoints_with_buffer = (1:days_with_buffer) + days = Dates.value(Day(ifm_model.handler.scen_stop - ifm_model.handler.scen_start)) + timepoints_start = days_with_buffer - days + + P = zeros(length(timepoints_with_buffer)) + T = zeros(length(timepoints_with_buffer)) + Lday = zeros(length(timepoints_with_buffer)) + for i in timepoints_with_buffer + start = ifm_model.handler.scen_start + Day(i - 1) + P[i] = JulES.TuLiPa.getweightedaverage(ifm_model.handler.hist_P, start, JulES.ONEDAY_MS_TIMEDELTA) + T[i] = JulES.TuLiPa.getweightedaverage(ifm_model.handler.hist_T, start, JulES.ONEDAY_MS_TIMEDELTA) + Lday[i] = JulES.TuLiPa.getweightedaverage(ifm_model.handler.hist_Lday, start, JulES.ONEDAY_MS_TIMEDELTA) + end + + itp_method = SteffenMonotonicInterpolation() + itp_P = interpolate(timepoints_with_buffer, P, itp_method) + itp_T = interpolate(timepoints_with_buffer, T, itp_method) + itp_Lday = interpolate(timepoints_with_buffer, Lday, itp_method) + Q, _ = JulES.predict(ifm_model.handler.predictor, 0, 0, itp_Lday, itp_P, itp_T, timepoints_with_buffer) + Q = Float64.(Q)[timepoints_start:end] + Q .= Q .* ifm_model.handler.m3s_per_mm + return 1 / mean(Q) +end + +function JulES.common_includeTwoStateIfm!(Constructor, toplevel::Dict, lowlevel::Dict, elkey::JulES.TuLiPa.ElementKey, value::Dict) + JulES.TuLiPa.checkkey(toplevel, elkey) + + model_params = JulES.TuLiPa.getdictvalue(value, "ModelParams", String, elkey) + + moments = nothing + if haskey(value, "Moments") + moments = JulES.TuLiPa.getdictvalue(value, "Moments", String, elkey) + end + + hist_P = JulES.TuLiPa.getdictvalue(value, "HistoricalPercipitation", JulES.TuLiPa.TIMEVECTORPARSETYPES, elkey) + hist_T = JulES.TuLiPa.getdictvalue(value, "HistoricalTemperature", JulES.TuLiPa.TIMEVECTORPARSETYPES, elkey) + hist_Lday = JulES.TuLiPa.getdictvalue(value, "HistoricalDaylight", JulES.TuLiPa.TIMEVECTORPARSETYPES, elkey) + + ndays_pred = JulES.TuLiPa.getdictvalue(value, "NDaysPred", Real, elkey) + try + ndays_pred = Int(ndays_pred) + @assert ndays_pred >= 0 + catch e + error("Value for key NDaysPred must be positive integer for $elkey") + end + + basin_area = JulES.TuLiPa.getdictvalue(value, "BasinArea", Float64, elkey) + + deps = JulES.TuLiPa.Id[] + + all_ok = true + + (id, hist_P, ok) = JulES.TuLiPa.getdicttimevectorvalue(lowlevel, hist_P) + all_ok = all_ok && ok + JulES.TuLiPa._update_deps(deps, id, ok) + + (id, hist_T, ok) = JulES.TuLiPa.getdicttimevectorvalue(lowlevel, hist_T) + all_ok = all_ok && ok + JulES.TuLiPa._update_deps(deps, id, ok) + + (id, hist_Lday, ok) = JulES.TuLiPa.getdicttimevectorvalue(lowlevel, hist_Lday) + all_ok = all_ok && ok + JulES.TuLiPa._update_deps(deps, id, ok) + + if all_ok == false + return (false, deps) + end + + # TODO: Maybe make 
this user input in future? + updater = JulES.SimpleIfmDataUpdater() + + model_params = JLD2.load_object(model_params) + + is_nn = !isnothing(moments) + if is_nn + # convert model_params, stored with simpler data structure for stability between versions, + # into ComponentArray, which the NN-model needs + # (the simpler data structure is Vector{Tuple{Vector{Float32}, Vector{Float32}}}) + _subarray(i) = ComponentArray(weight = model_params[i][1], bias = model_params[i][2]) + _tuple(i) = (Symbol("layer_", i), _subarray(i)) + model_params = ComponentArray(NamedTuple(_tuple(i) for i in eachindex(model_params))) + # add moments, which is needed to normalize state inputs to the NN-model + moments = JLD2.load_object(moments) + model_params = (model_params, moments) + end + + data_forecast = nothing + data_obs = nothing + ndays_obs = 365 + data_forecast = nothing + ndays_forecast = 0 + + periodkey = JulES.TuLiPa.Id(JulES.TuLiPa.TIMEPERIOD_CONCEPT, "ScenarioTimePeriod") + period = lowlevel[periodkey] + scen_start = period["Start"] + scen_stop = period["Stop"] + + id = JulES.TuLiPa.getobjkey(elkey) + toplevel[id] = Constructor(id, model_params, updater, basin_area, hist_P, hist_T, hist_Lday, + ndays_pred, ndays_obs, ndays_forecast, data_obs, data_forecast, scen_start, scen_stop) + + return (true, deps) +end + +function JulES.basic_bucket_incl_states(p_, itp_Lday, itp_P, itp_T, t_out) + function exp_hydro_optim_states!(dS,S,ps,t) + f, Smax, Qmax, Df, Tmax, Tmin = ps + Lday = itp_Lday(t) + P = itp_P(t) + T = itp_T(t) + Q_out = Qb(S[2],f,Smax,Qmax) + Qs(S[2], Smax) + dS[1] = Ps(P, T, Tmin) - M(S[1], T, Df, Tmax) + dS[2] = Pr(P, T, Tmin) + M(S[1], T, Df, Tmax) - ET(S[2], T, Lday, Smax) - Q_out + end + + prob = ODEProblem(exp_hydro_optim_states!, p_[1:2], Float64.((t_out[1], maximum(t_out)))) + # sol = solve(prob, BS3(), u0=p_[1:2], p=p_[3:end], saveat=t_out, reltol=1e-3, abstol=1e-3, sensealg=ForwardDiffSensitivity()) + sol = solve(prob, BS3(), u0=p_[1:2], p=p_[3:end], saveat=t_out, reltol=1e-3, abstol=1e-3) + Qb_ = Qb.(sol[2,:], p_[3], p_[4], p_[5]) + Qs_ = Qs.(sol[2,:], p_[4]) + Qout_ = Qb_.+Qs_ + return Qout_, sol +end + +end \ No newline at end of file diff --git a/ext/IfmNeuralExt.jl b/ext/IfmNeuralExt.jl new file mode 100644 index 0000000..33f7e33 --- /dev/null +++ b/ext/IfmNeuralExt.jl @@ -0,0 +1,52 @@ +""" +This code is copied from https://github.com/marv-in/HydroNODE +and has BSD 3-Clause License +""" +# TODO: Update License file +# TODO: Write license info in each source file + +module IfmNeuralExt + +using DiffEqFlux +using SciMLSensitivity +using Optimization +using OptimizationOptimisers +using OptimizationBBO +using Zygote +using Lux +using Random +using CSV +using JulES + +function JulES.initialize_NN_model() + rng = Random.default_rng() + NNmodel = Lux.Chain(Lux.Dense(4, 32, tanh), Lux.Dense(32,32, leakyrelu), Lux.Dense(32,32, leakyrelu), Lux.Dense(32,32, leakyrelu), + Lux.Dense(32,32, leakyrelu), Lux.Dense(32,5)) + p_NN_init, st_NN_init = Lux.setup(rng, NNmodel) + NN_in_fct(x, p) = NNmodel(x,p,st_NN_init)[1] + p_NN_init = ComponentArray(p_NN_init) + return NN_in_fct, p_NN_init +end + +function JulES.NeuralODE_M100(p, norm_S0, norm_S1, norm_P, norm_T, itp_Lday, itp_P, itp_T, t_out, ann; S_init = [0.0, 0.0]) + function NeuralODE_M100_core!(dS,S,p,t) + Lday = itp_Lday(t) + P = itp_P(t) + T = itp_T(t) + g = ann([norm_S0(S[1]), norm_S1(S[2]), norm_P(P), norm_T(T)],p) + melting = relu(step_fct(S[1])*sinh(g[3])) + dS[1] = relu(sinh(g[4])*step_fct(-T)) - melting + dS[2] = relu(sinh(g[5])) + 
melting - step_fct(S[2])*Lday*exp(g[1])- step_fct(S[2])*exp(g[2]) + end + prob = ODEProblem(NeuralODE_M100_core!, S_init, Float64.((t_out[1], maximum(t_out))), p) + # sol = solve(prob, BS3(), dt=1.0, saveat=t_out, reltol=1e-3, abstol=1e-3, sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())) + sol = solve(prob, BS3(), dt=1.0, saveat=t_out, reltol=1e-3, abstol=1e-3) + P_interp = norm_P.(itp_P.(t_out)) + T_interp = norm_T.(itp_T.(t_out)) + S0_ = norm_S0.(sol[1,:]) + S1_ = norm_S1.(sol[2,:]) + Qout_ = exp.(ann(permutedims([S0_ S1_ P_interp T_interp]),p)[2,:]) + return Qout_, sol +end + +end \ No newline at end of file diff --git a/src/JulES.jl b/src/JulES.jl index 7b41518..b485f9b 100644 --- a/src/JulES.jl +++ b/src/JulES.jl @@ -13,10 +13,6 @@ using YAML using HDF5 # Used by ifm -#using CSV -#using Random -#using OrdinaryDiffEq -#using Lux #using ComponentArrays #using Interpolations #using JLD2 @@ -31,7 +27,6 @@ using HDF5 include("abstract_types.jl") include("dimension_types.jl") -include("ifm_bsd.jl") include("ifm.jl") include("generic_io.jl") include("io.jl") diff --git a/src/ifm.jl b/src/ifm.jl index 27800f5..4cbde2f 100644 --- a/src/ifm.jl +++ b/src/ifm.jl @@ -11,6 +11,7 @@ and integration with JulES # TODO: Replace all calls to Day(n) with Millisecond(86400000*n) for better performance const ONEDAY_MS_TIMEDELTA = TuLiPa.MsTimeDelta(Day(1)) +const extension_error_msg = "Missing optional dependency, cannot load extension." struct TwoStateIfmData P::Vector{Float32} @@ -176,48 +177,11 @@ function _data_obs_update(m::TwoStateIfmHandler, t::TuLiPa.ProbTime) end function estimate_u0(m::TwoStateIfmHandler, t::TuLiPa.ProbTime) - if isnothing(m.prev_t) - _initial_data_obs_update(m, t) - else - _data_obs_update(m, t) - end - m.prev_t = t - - # do prediction from start of obs up until today - (S0, G0) = (Float32(0), Float32(0)) - - # create interpolation input functions - itp_method = SteffenMonotonicInterpolation() - itp_P = interpolate(m.data_obs.timepoints, m.data_obs.P, itp_method) - itp_T = interpolate(m.data_obs.timepoints, m.data_obs.T, itp_method) - itp_Lday = interpolate(m.data_obs.timepoints, m.data_obs.Lday, itp_method) - - (__, OED_sol) = predict(m.predictor, S0, G0, itp_Lday, itp_P, itp_T, m.data_obs.timepoints) - - # extract states - est_S0 = Float64(last(OED_sol[1, :])) - est_G0 = Float64(last(OED_sol[2, :])) - - return [est_S0, est_G0] + error(extension_error_msg) end function predict(m::TwoStateIfmHandler, u0::Vector{Float64}, t::TuLiPa.ProbTime) - update_prediction_data(m, m.updater, t) - - # create interpolation input functions - itp_method = SteffenMonotonicInterpolation() - itp_P = interpolate(m.data_pred.timepoints, m.data_pred.P, itp_method) - itp_T = interpolate(m.data_pred.timepoints, m.data_pred.T, itp_method) - itp_Lday = interpolate(m.data_pred.timepoints, m.data_pred.Lday, itp_method) - - (S0, G0) = u0 - (Q, __) = predict(m.predictor, S0, G0, itp_Lday, itp_P, itp_T, m.data_pred.timepoints) - - Q = Float64.(Q) - - Q .= Q .* m.m3s_per_mm - - return Q + error(extension_error_msg) end # TODO: Add AutoCorrIfmDataUpdater that use value w(t)*x(t0) + (1-w(t-t0))*x(t), where w(0) = 1 and w -> 0 for larger inputs @@ -337,86 +301,7 @@ function includeTwoStateNeuralODEIfm!(toplevel::Dict, lowlevel::Dict, elkey::TuL end function common_includeTwoStateIfm!(Constructor, toplevel::Dict, lowlevel::Dict, elkey::TuLiPa.ElementKey, value::Dict) - TuLiPa.checkkey(toplevel, elkey) - - model_params = TuLiPa.getdictvalue(value, "ModelParams", String, elkey) - - moments = nothing - if 
haskey(value, "Moments") - moments = TuLiPa.getdictvalue(value, "Moments", String, elkey) - end - - hist_P = TuLiPa.getdictvalue(value, "HistoricalPercipitation", TuLiPa.TIMEVECTORPARSETYPES, elkey) - hist_T = TuLiPa.getdictvalue(value, "HistoricalTemperature", TuLiPa.TIMEVECTORPARSETYPES, elkey) - hist_Lday = TuLiPa.getdictvalue(value, "HistoricalDaylight", TuLiPa.TIMEVECTORPARSETYPES, elkey) - - ndays_pred = TuLiPa.getdictvalue(value, "NDaysPred", Real, elkey) - try - ndays_pred = Int(ndays_pred) - @assert ndays_pred >= 0 - catch e - error("Value for key NDaysPred must be positive integer for $elkey") - end - - basin_area = TuLiPa.getdictvalue(value, "BasinArea", Float64, elkey) - - deps = TuLiPa.Id[] - - all_ok = true - - (id, hist_P, ok) = TuLiPa.getdicttimevectorvalue(lowlevel, hist_P) - all_ok = all_ok && ok - TuLiPa._update_deps(deps, id, ok) - - (id, hist_T, ok) = TuLiPa.getdicttimevectorvalue(lowlevel, hist_T) - all_ok = all_ok && ok - TuLiPa._update_deps(deps, id, ok) - - (id, hist_Lday, ok) = TuLiPa.getdicttimevectorvalue(lowlevel, hist_Lday) - all_ok = all_ok && ok - TuLiPa._update_deps(deps, id, ok) - - if all_ok == false - return (false, deps) - end - - # TODO: Maybe make this user input in future? - updater = SimpleIfmDataUpdater() - - model_params = JLD2.load_object(model_params) - - is_nn = !isnothing(moments) - if is_nn - # convert model_params, stored with simpler data structure for stability between versions, - # into ComponentArray, which the NN-model needs - # (the simpler data structure is Vector{Tuple{Vector{Float32}, Vector{Float32}}}) - _subarray(i) = ComponentArray(weight = model_params[i][1], bias = model_params[i][2]) - _tuple(i) = (Symbol("layer_", i), _subarray(i)) - model_params = ComponentArray(NamedTuple(_tuple(i) for i in eachindex(model_params))) - # add moments, which is needed to normalize state inputs to the NN-model - moments = JLD2.load_object(moments) - model_params = (model_params, moments) - end - - data_forecast = nothing - data_obs = nothing - ndays_obs = 365 - data_forecast = nothing - ndays_forecast = 0 - - - - periodkey = TuLiPa.Id(TuLiPa.TIMEPERIOD_CONCEPT, "ScenarioTimePeriod") - period = lowlevel[periodkey] - scen_start = period["Start"] - scen_stop = period["Stop"] - - - id = TuLiPa.getobjkey(elkey) - toplevel[id] = Constructor(id, model_params, updater, basin_area, hist_P, hist_T, hist_Lday, - ndays_pred, ndays_obs, ndays_forecast, data_obs, data_forecast, scen_start, scen_stop) - - return (true, deps) + error(extension_error_msg) end # --- Functions used in run_serial in connection with inflow models --- @@ -472,30 +357,7 @@ function save_ifm_Q(div_db, inflow_name, stepnr, Q) end function calculate_normalize_factor(ifm_model) - start_with_buffer = ifm_model.handler.scen_start - Day(ifm_model.handler.ndays_obs) # Add ndays_obs days buffer - days_with_buffer = Day(ifm_model.handler.scen_stop - start_with_buffer) |> Dates.value - timepoints_with_buffer = (1:days_with_buffer) - days = Dates.value(Day(ifm_model.handler.scen_stop - ifm_model.handler.scen_start)) - timepoints_start = days_with_buffer - days - - P = zeros(length(timepoints_with_buffer)) - T = zeros(length(timepoints_with_buffer)) - Lday = zeros(length(timepoints_with_buffer)) - for i in timepoints_with_buffer - start = ifm_model.handler.scen_start + Day(i - 1) - P[i] = TuLiPa.getweightedaverage(ifm_model.handler.hist_P, start, JulES.ONEDAY_MS_TIMEDELTA) - T[i] = TuLiPa.getweightedaverage(ifm_model.handler.hist_T, start, JulES.ONEDAY_MS_TIMEDELTA) - Lday[i] = 
TuLiPa.getweightedaverage(ifm_model.handler.hist_Lday, start, JulES.ONEDAY_MS_TIMEDELTA) - end - - itp_method = JulES.SteffenMonotonicInterpolation() - itp_P = JulES.interpolate(timepoints_with_buffer, P, itp_method) - itp_T = JulES.interpolate(timepoints_with_buffer, T, itp_method) - itp_Lday = JulES.interpolate(timepoints_with_buffer, Lday, itp_method) - Q, _ = JulES.predict(ifm_model.handler.predictor, 0, 0, itp_Lday, itp_P, itp_T, timepoints_with_buffer) - Q = Float64.(Q)[timepoints_start:end] - Q .= Q .* ifm_model.handler.m3s_per_mm - return 1 / mean(Q) + error(extension_error_msg) end """ @@ -821,3 +683,16 @@ function copy_elements_iprogtype(elements, iprogtype, ifm_names, ifm_derivedname @assert length(elements) == length(elements1) return elements1 end + +function initialize_NN_model() + error(extension_error_msg) +end + +function basic_bucket_incl_states(p_, itp_Lday, itp_P, itp_T, t_out) + error(extension_error_msg) +end + +function NeuralODE_M100(p, norm_S0, norm_S1, norm_P, norm_T, + itp_Lday, itp_P, itp_T, t_out, ann; S_init = [0.0, 0.0]) + error(extension_error_msg) +end \ No newline at end of file diff --git a/src/ifm_bsd.jl b/src/ifm_bsd.jl deleted file mode 100644 index cfc0513..0000000 --- a/src/ifm_bsd.jl +++ /dev/null @@ -1,66 +0,0 @@ -""" -This code is copied from https://github.com/marv-in/HydroNODE -and has BSD 3-Clause License -""" -# TODO: Update License file -# TODO: Write license info in each source file - -function initialize_NN_model() - rng = Random.default_rng() - NNmodel = Lux.Chain(Lux.Dense(4, 32, tanh), Lux.Dense(32,32, leakyrelu), Lux.Dense(32,32, leakyrelu), Lux.Dense(32,32, leakyrelu), - Lux.Dense(32,32, leakyrelu), Lux.Dense(32,5)) - p_NN_init, st_NN_init = Lux.setup(rng, NNmodel) - NN_in_fct(x, p) = NNmodel(x,p,st_NN_init)[1] - p_NN_init = ComponentArray(p_NN_init) - return NN_in_fct, p_NN_init -end - -step_fct(x) = (tanh(5.0*x) + 1.0)*0.5 -Ps(P, T, Tmin) = step_fct(Tmin-T)*P -Pr(P, T, Tmin) = step_fct(T-Tmin)*P -M(S0, T, Df, Tmax) = step_fct(T-Tmax)*step_fct(S0)*minimum([S0, Df*(T-Tmax)]) -PET(T, Lday) = 29.8 * Lday * 0.611 * exp((17.3*T)/(T+237.3)) / (T + 273.2) -ET(S1, T, Lday, Smax) = step_fct(S1)*step_fct(S1-Smax)*PET(T,Lday) + step_fct(S1)*step_fct(Smax-S1)*PET(T,Lday)*(S1/Smax) -Qb(S1,f,Smax,Qmax) = step_fct(S1)*step_fct(S1-Smax)*Qmax + step_fct(S1)*step_fct(Smax-S1)*Qmax*exp(-f*(Smax-S1)) -Qs(S1, Smax) = step_fct(S1)*step_fct(S1-Smax)*(S1-Smax) - -function basic_bucket_incl_states(p_, itp_Lday, itp_P, itp_T, t_out) - function exp_hydro_optim_states!(dS,S,ps,t) - f, Smax, Qmax, Df, Tmax, Tmin = ps - Lday = itp_Lday(t) - P = itp_P(t) - T = itp_T(t) - Q_out = Qb(S[2],f,Smax,Qmax) + Qs(S[2], Smax) - dS[1] = Ps(P, T, Tmin) - M(S[1], T, Df, Tmax) - dS[2] = Pr(P, T, Tmin) + M(S[1], T, Df, Tmax) - ET(S[2], T, Lday, Smax) - Q_out - end - - prob = ODEProblem(exp_hydro_optim_states!, p_[1:2], Float64.((t_out[1], maximum(t_out)))) - # sol = solve(prob, BS3(), u0=p_[1:2], p=p_[3:end], saveat=t_out, reltol=1e-3, abstol=1e-3, sensealg=ForwardDiffSensitivity()) - sol = solve(prob, BS3(), u0=p_[1:2], p=p_[3:end], saveat=t_out, reltol=1e-3, abstol=1e-3) - Qb_ = Qb.(sol[2,:], p_[3], p_[4], p_[5]) - Qs_ = Qs.(sol[2,:], p_[4]) - Qout_ = Qb_.+Qs_ - return Qout_, sol -end - -function NeuralODE_M100(p, norm_S0, norm_S1, norm_P, norm_T, itp_Lday, itp_P, itp_T, t_out, ann; S_init = [0.0, 0.0]) - function NeuralODE_M100_core!(dS,S,p,t) - Lday = itp_Lday(t) - P = itp_P(t) - T = itp_T(t) - g = ann([norm_S0(S[1]), norm_S1(S[2]), norm_P(P), norm_T(T)],p) - melting = 
relu(step_fct(S[1])*sinh(g[3])) - dS[1] = relu(sinh(g[4])*step_fct(-T)) - melting - dS[2] = relu(sinh(g[5])) + melting - step_fct(S[2])*Lday*exp(g[1])- step_fct(S[2])*exp(g[2]) - end - prob = ODEProblem(NeuralODE_M100_core!, S_init, Float64.((t_out[1], maximum(t_out))), p) - # sol = solve(prob, BS3(), dt=1.0, saveat=t_out, reltol=1e-3, abstol=1e-3, sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())) - sol = solve(prob, BS3(), dt=1.0, saveat=t_out, reltol=1e-3, abstol=1e-3) - P_interp = norm_P.(itp_P.(t_out)) - T_interp = norm_T.(itp_T.(t_out)) - S0_ = norm_S0.(sol[1,:]) - S1_ = norm_S1.(sol[2,:]) - Qout_ = exp.(ann(permutedims([S0_ S1_ P_interp T_interp]),p)[2,:]) - return Qout_, sol -end \ No newline at end of file diff --git a/src/run_jules_wrapper.jl b/src/run_jules_wrapper.jl index 256b8d1..b5f2b08 100644 --- a/src/run_jules_wrapper.jl +++ b/src/run_jules_wrapper.jl @@ -40,15 +40,26 @@ function getdataset(config, names, filename_clearing, filename_aggregated) end function load_ifm_dep() - mod = @__MODULE__ + if myid() == 1 + function ensure_packages(pkgs::Vector{String}) + deps = values(Pkg.dependencies()) + not_installed = filter(pkg -> !any(d -> d.name == pkg, deps), pkgs) + if !isempty(not_installed) + println("Installing missing packages: ", join(not_installed, ", ")) + Pkg.add(not_installed) + else + println("All packages already installed.") + end + end + ensure_packages(["OrdinaryDiffEq", "ComponentArrays", "Interpolations", "JLD2"]) + end + @everywhere begin - @eval $mod using CSV - @eval $mod using Random - @eval $mod using OrdinaryDiffEq - @eval $mod using Lux - @eval $mod using ComponentArrays - @eval $mod using Interpolations - @eval $mod using JLD2 + Pkg.instantiate() + Base.eval(Main, :(using OrdinaryDiffEq)) + Base.eval(Main, :(using ComponentArrays)) + Base.eval(Main, :(using Interpolations)) + Base.eval(Main, :(using JLD2)) end end
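
Note on the extension mechanism introduced above: the packages moved from [deps] to [weakdeps] in Project.toml are no longer hard dependencies. Julia only compiles ext/IfmExt.jl once all four weak dependencies are present in the session, and until then the inflow-model entry points left in src/ifm.jl (predict, estimate_u0, calculate_normalize_factor, common_includeTwoStateIfm!, ...) raise extension_error_msg. Below is a minimal sketch of the expected behaviour, not part of the patch; only the package, module and function names are taken from the diff, everything else is an assumption (it presumes a Julia >= 1.9 session with JulES and the optional packages installed):

    # Hypothetical session illustrating how the IfmExt extension is expected to activate.
    using JulES
    # At this point the inflow-model stubs in src/ifm.jl are in effect, so calling e.g.
    # JulES.calculate_normalize_factor(...) would error with extension_error_msg.

    # Loading the [weakdeps] listed in Project.toml triggers compilation of ext/IfmExt.jl,
    # whose methods (JulES.predict, JulES.estimate_u0, ...) take over from the stubs:
    using OrdinaryDiffEq, ComponentArrays, Interpolations, JLD2

    # Since Julia 1.9 the loaded extension module can be inspected directly:
    ext = Base.get_extension(JulES, :IfmExt)
    @assert ext !== nothing  # IfmExt is active once all four weak deps are loaded

This is the same effect that load_ifm_dep() in src/run_jules_wrapper.jl achieves for a distributed run, where it first ensures the optional packages are installed on the main process and then loads them on every worker via @everywhere.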