Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,30 @@ repo = "https://github.com/NVE/JulES.jl.git"
version = "0.4.0"

[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TuLiPa = "970f5c25-cd7d-4f04-b50d-7a4fe2af6639"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[weakdeps]
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"

[compat]
julia = "1.9.2"
OrdinaryDiffEq = "6.66.0"
ComponentArrays = "0.15.17"
Interpolations = "0.16.1"
JLD2 = "0.5.15"

[extensions]
IfmExt = ["ComponentArrays", "Interpolations", "JLD2", "OrdinaryDiffEq"]
200 changes: 200 additions & 0 deletions ext/IfmExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
"""
This code is copied from https://github.com/marv-in/HydroNODE
and has BSD 3-Clause License
"""
# TODO: Update License file
# TODO: Write license info in each source file

module IfmExt

using OrdinaryDiffEq
using ComponentArrays
using Interpolations
using JLD2
using Dates
using Statistics
using JulES


step_fct(x) = (tanh(5.0*x) + 1.0)*0.5
Ps(P, T, Tmin) = step_fct(Tmin-T)*P
Pr(P, T, Tmin) = step_fct(T-Tmin)*P
M(S0, T, Df, Tmax) = step_fct(T-Tmax)*step_fct(S0)*minimum([S0, Df*(T-Tmax)])
PET(T, Lday) = 29.8 * Lday * 0.611 * exp((17.3*T)/(T+237.3)) / (T + 273.2)
ET(S1, T, Lday, Smax) = step_fct(S1)*step_fct(S1-Smax)*PET(T,Lday) + step_fct(S1)*step_fct(Smax-S1)*PET(T,Lday)*(S1/Smax)
Qb(S1,f,Smax,Qmax) = step_fct(S1)*step_fct(S1-Smax)*Qmax + step_fct(S1)*step_fct(Smax-S1)*Qmax*exp(-f*(Smax-S1))
Qs(S1, Smax) = step_fct(S1)*step_fct(S1-Smax)*(S1-Smax)

function JulES.predict(m::JulES.TwoStateIfmHandler, u0::Vector{Float64}, t::JulES.TuLiPa.ProbTime)
JulES.update_prediction_data(m, m.updater, t)

# create interpolation input functions
itp_method = SteffenMonotonicInterpolation()
itp_P = interpolate(m.data_pred.timepoints, m.data_pred.P, itp_method)
itp_T = interpolate(m.data_pred.timepoints, m.data_pred.T, itp_method)
itp_Lday = interpolate(m.data_pred.timepoints, m.data_pred.Lday, itp_method)

(S0, G0) = u0
(Q, __) = JulES.predict(m.predictor, S0, G0, itp_Lday, itp_P, itp_T, m.data_pred.timepoints)

Q = Float64.(Q)

Q .= Q .* m.m3s_per_mm

return Q
end

function JulES.estimate_u0(m::JulES.TwoStateIfmHandler, t::JulES.TuLiPa.ProbTime)
if isnothing(m.prev_t)
JulES._initial_data_obs_update(m, t)
else
JulES._data_obs_update(m, t)
end
m.prev_t = t

# do prediction from start of obs up until today
(S0, G0) = (Float32(0), Float32(0))

# create interpolation input functions
itp_method = SteffenMonotonicInterpolation()
itp_P = interpolate(m.data_obs.timepoints, m.data_obs.P, itp_method)
itp_T = interpolate(m.data_obs.timepoints, m.data_obs.T, itp_method)
itp_Lday = interpolate(m.data_obs.timepoints, m.data_obs.Lday, itp_method)

(__, OED_sol) = JulES.predict(m.predictor, S0, G0, itp_Lday, itp_P, itp_T, m.data_obs.timepoints)

# extract states
est_S0 = Float64(last(OED_sol[1, :]))
est_G0 = Float64(last(OED_sol[2, :]))

return [est_S0, est_G0]
end

function JulES.calculate_normalize_factor(ifm_model)
start_with_buffer = ifm_model.handler.scen_start - Day(ifm_model.handler.ndays_obs) # Add ndays_obs days buffer
days_with_buffer = Day(ifm_model.handler.scen_stop - start_with_buffer) |> Dates.value
timepoints_with_buffer = (1:days_with_buffer)
days = Dates.value(Day(ifm_model.handler.scen_stop - ifm_model.handler.scen_start))
timepoints_start = days_with_buffer - days

P = zeros(length(timepoints_with_buffer))
T = zeros(length(timepoints_with_buffer))
Lday = zeros(length(timepoints_with_buffer))
for i in timepoints_with_buffer
start = ifm_model.handler.scen_start + Day(i - 1)
P[i] = JulES.TuLiPa.getweightedaverage(ifm_model.handler.hist_P, start, JulES.ONEDAY_MS_TIMEDELTA)
T[i] = JulES.TuLiPa.getweightedaverage(ifm_model.handler.hist_T, start, JulES.ONEDAY_MS_TIMEDELTA)
Lday[i] = JulES.TuLiPa.getweightedaverage(ifm_model.handler.hist_Lday, start, JulES.ONEDAY_MS_TIMEDELTA)
end

itp_method = SteffenMonotonicInterpolation()
itp_P = interpolate(timepoints_with_buffer, P, itp_method)
itp_T = interpolate(timepoints_with_buffer, T, itp_method)
itp_Lday = interpolate(timepoints_with_buffer, Lday, itp_method)
Q, _ = JulES.predict(ifm_model.handler.predictor, 0, 0, itp_Lday, itp_P, itp_T, timepoints_with_buffer)
Q = Float64.(Q)[timepoints_start:end]
Q .= Q .* ifm_model.handler.m3s_per_mm
return 1 / mean(Q)
end

function JulES.common_includeTwoStateIfm!(Constructor, toplevel::Dict, lowlevel::Dict, elkey::JulES.TuLiPa.ElementKey, value::Dict)
JulES.TuLiPa.checkkey(toplevel, elkey)

model_params = JulES.TuLiPa.getdictvalue(value, "ModelParams", String, elkey)

moments = nothing
if haskey(value, "Moments")
moments = JulES.TuLiPa.getdictvalue(value, "Moments", String, elkey)
end

hist_P = JulES.TuLiPa.getdictvalue(value, "HistoricalPercipitation", JulES.TuLiPa.TIMEVECTORPARSETYPES, elkey)
hist_T = JulES.TuLiPa.getdictvalue(value, "HistoricalTemperature", JulES.TuLiPa.TIMEVECTORPARSETYPES, elkey)
hist_Lday = JulES.TuLiPa.getdictvalue(value, "HistoricalDaylight", JulES.TuLiPa.TIMEVECTORPARSETYPES, elkey)

ndays_pred = JulES.TuLiPa.getdictvalue(value, "NDaysPred", Real, elkey)
try
ndays_pred = Int(ndays_pred)
@assert ndays_pred >= 0
catch e
error("Value for key NDaysPred must be positive integer for $elkey")
end

basin_area = JulES.TuLiPa.getdictvalue(value, "BasinArea", Float64, elkey)

deps = JulES.TuLiPa.Id[]

all_ok = true

(id, hist_P, ok) = JulES.TuLiPa.getdicttimevectorvalue(lowlevel, hist_P)
all_ok = all_ok && ok
JulES.TuLiPa._update_deps(deps, id, ok)

(id, hist_T, ok) = JulES.TuLiPa.getdicttimevectorvalue(lowlevel, hist_T)
all_ok = all_ok && ok
JulES.TuLiPa._update_deps(deps, id, ok)

(id, hist_Lday, ok) = JulES.TuLiPa.getdicttimevectorvalue(lowlevel, hist_Lday)
all_ok = all_ok && ok
JulES.TuLiPa._update_deps(deps, id, ok)

if all_ok == false
return (false, deps)
end

# TODO: Maybe make this user input in future?
updater = JulES.SimpleIfmDataUpdater()

model_params = JLD2.load_object(model_params)

is_nn = !isnothing(moments)
if is_nn
# convert model_params, stored with simpler data structure for stability between versions,
# into ComponentArray, which the NN-model needs
# (the simpler data structure is Vector{Tuple{Vector{Float32}, Vector{Float32}}})
_subarray(i) = ComponentArray(weight = model_params[i][1], bias = model_params[i][2])
_tuple(i) = (Symbol("layer_", i), _subarray(i))
model_params = ComponentArray(NamedTuple(_tuple(i) for i in eachindex(model_params)))
# add moments, which is needed to normalize state inputs to the NN-model
moments = JLD2.load_object(moments)
model_params = (model_params, moments)
end

data_forecast = nothing
data_obs = nothing
ndays_obs = 365
data_forecast = nothing
ndays_forecast = 0

periodkey = JulES.TuLiPa.Id(JulES.TuLiPa.TIMEPERIOD_CONCEPT, "ScenarioTimePeriod")
period = lowlevel[periodkey]
scen_start = period["Start"]
scen_stop = period["Stop"]

id = JulES.TuLiPa.getobjkey(elkey)
toplevel[id] = Constructor(id, model_params, updater, basin_area, hist_P, hist_T, hist_Lday,
ndays_pred, ndays_obs, ndays_forecast, data_obs, data_forecast, scen_start, scen_stop)

return (true, deps)
end

function JulES.basic_bucket_incl_states(p_, itp_Lday, itp_P, itp_T, t_out)
function exp_hydro_optim_states!(dS,S,ps,t)
f, Smax, Qmax, Df, Tmax, Tmin = ps
Lday = itp_Lday(t)
P = itp_P(t)
T = itp_T(t)
Q_out = Qb(S[2],f,Smax,Qmax) + Qs(S[2], Smax)
dS[1] = Ps(P, T, Tmin) - M(S[1], T, Df, Tmax)
dS[2] = Pr(P, T, Tmin) + M(S[1], T, Df, Tmax) - ET(S[2], T, Lday, Smax) - Q_out
end

prob = ODEProblem(exp_hydro_optim_states!, p_[1:2], Float64.((t_out[1], maximum(t_out))))
# sol = solve(prob, BS3(), u0=p_[1:2], p=p_[3:end], saveat=t_out, reltol=1e-3, abstol=1e-3, sensealg=ForwardDiffSensitivity())
sol = solve(prob, BS3(), u0=p_[1:2], p=p_[3:end], saveat=t_out, reltol=1e-3, abstol=1e-3)
Qb_ = Qb.(sol[2,:], p_[3], p_[4], p_[5])
Qs_ = Qs.(sol[2,:], p_[4])
Qout_ = Qb_.+Qs_
return Qout_, sol
end

end
52 changes: 52 additions & 0 deletions ext/IfmNeuralExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
This code is copied from https://github.com/marv-in/HydroNODE
and has BSD 3-Clause License
"""
# TODO: Update License file
# TODO: Write license info in each source file

module IfmNeuralExt

using DiffEqFlux
using SciMLSensitivity
using Optimization
using OptimizationOptimisers
using OptimizationBBO
using Zygote
using Lux
using Random
using CSV
using JulES

function JulES.initialize_NN_model()
rng = Random.default_rng()
NNmodel = Lux.Chain(Lux.Dense(4, 32, tanh), Lux.Dense(32,32, leakyrelu), Lux.Dense(32,32, leakyrelu), Lux.Dense(32,32, leakyrelu),
Lux.Dense(32,32, leakyrelu), Lux.Dense(32,5))
p_NN_init, st_NN_init = Lux.setup(rng, NNmodel)
NN_in_fct(x, p) = NNmodel(x,p,st_NN_init)[1]
p_NN_init = ComponentArray(p_NN_init)
return NN_in_fct, p_NN_init
end

function JulES.NeuralODE_M100(p, norm_S0, norm_S1, norm_P, norm_T, itp_Lday, itp_P, itp_T, t_out, ann; S_init = [0.0, 0.0])
function NeuralODE_M100_core!(dS,S,p,t)
Lday = itp_Lday(t)
P = itp_P(t)
T = itp_T(t)
g = ann([norm_S0(S[1]), norm_S1(S[2]), norm_P(P), norm_T(T)],p)
melting = relu(step_fct(S[1])*sinh(g[3]))
dS[1] = relu(sinh(g[4])*step_fct(-T)) - melting
dS[2] = relu(sinh(g[5])) + melting - step_fct(S[2])*Lday*exp(g[1])- step_fct(S[2])*exp(g[2])
end
prob = ODEProblem(NeuralODE_M100_core!, S_init, Float64.((t_out[1], maximum(t_out))), p)
# sol = solve(prob, BS3(), dt=1.0, saveat=t_out, reltol=1e-3, abstol=1e-3, sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP()))
sol = solve(prob, BS3(), dt=1.0, saveat=t_out, reltol=1e-3, abstol=1e-3)
P_interp = norm_P.(itp_P.(t_out))
T_interp = norm_T.(itp_T.(t_out))
S0_ = norm_S0.(sol[1,:])
S1_ = norm_S1.(sol[2,:])
Qout_ = exp.(ann(permutedims([S0_ S1_ P_interp T_interp]),p)[2,:])
return Qout_, sol
end

end
5 changes: 0 additions & 5 deletions src/JulES.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@ using YAML
using HDF5

# Used by ifm
#using CSV
#using Random
#using OrdinaryDiffEq
#using Lux
#using ComponentArrays
#using Interpolations
#using JLD2
Expand All @@ -31,7 +27,6 @@ using HDF5

include("abstract_types.jl")
include("dimension_types.jl")
include("ifm_bsd.jl")
include("ifm.jl")
include("generic_io.jl")
include("io.jl")
Expand Down
Loading
Loading