From 3b9e5550e4506a5038d02fc44dd937dcd92fd4c6 Mon Sep 17 00:00:00 2001 From: Julian Straus Date: Tue, 16 Dec 2025 15:37:34 +0100 Subject: [PATCH] Created new type ScenTreeNodes for StrategicScenario * The new type allows for the functionality `last` * No other impact of the new type --- docs/src/reference/internal.md | 1 + src/strat_scenarios/strat_scenarios.jl | 32 ++++++++++++++++++++++---- test/runtests.jl | 8 ++++++- 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/docs/src/reference/internal.md b/docs/src/reference/internal.md index 5972578..733903b 100644 --- a/docs/src/reference/internal.md +++ b/docs/src/reference/internal.md @@ -27,6 +27,7 @@ TimeStruct.AbstractTreeStructure TimeStruct.AbstractStratScens TimeStruct.StratTreeNodes TimeStruct.StratScens +TimeStruct.ScenTreeNodes TimeStruct.SingleStrategicScenarioWrapper ``` diff --git a/src/strat_scenarios/strat_scenarios.jl b/src/strat_scenarios/strat_scenarios.jl index 4d36b2f..c3994b0 100644 --- a/src/strat_scenarios/strat_scenarios.jl +++ b/src/strat_scenarios/strat_scenarios.jl @@ -109,16 +109,40 @@ function Base.iterate(scs::StrategicScenario, state = (nothing, 1)) return next[1], (next[2], sp) end +""" + struct ScenTreeNodes{S,T,N,OP<:AbstractTreeNode{S,T}} <: AbstractStratPers{T} + +Type for iterating through the individual strategic nodes of a [`StrategicScenario`](@ref). +It is automatically created through the function [`strat_periods`](@ref), and hence, +[`strategic_periods`](@ref). +""" +struct ScenTreeNodes{S,T,N,OP<:AbstractTreeNode{S,T}} <: AbstractStratPers{T} + ts::StrategicScenario{S,T,N,OP} +end + +# Adding methods to existing Julia functions +Base.length(sps::ScenTreeNodes) = length(sps.ts.nodes) +Base.eltype(_::Type{ScenTreeNodes{S,T,N,OP}}) where {S,T,N,OP} = OP +function Base.iterate(sps::ScenTreeNodes, state = nothing) + next = isnothing(state) ? 1 : state + 1 + next == length(sps) + 1 && return nothing + return sps.ts.nodes[next], next +end +function Base.getindex(sps::ScenTreeNodes, index::Int) + return sps.ts.nodes[index] +end +function Base.last(sps::ScenTreeNodes) + return last(sps.ts.nodes) +end + """ When the `TimeStructure` is a [`StrategicScenario`](@ref), `strat_periods` returns a -[`StratTreeNodes`](@ref) type, which, through iteration, provides [`StratNode`](@ref) types. +[`ScenTreeNodes`](@ref) type, which, through iteration, provides [`StratNode`](@ref) types. These are equivalent to a [`StrategicPeriod`](@ref) of a [`TwoLevel`](@ref) time structure. """ function strat_periods(ts::StrategicScenario) - return StratTreeNodes( - TwoLevelTree(length(ts), first(ts), [n for n in ts.nodes], ts.op_per_strat), - ) + return ScenTreeNodes(ts) end """ diff --git a/test/runtests.jl b/test/runtests.jl index 5418619..190f844 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1228,6 +1228,8 @@ end sps_scen_1 = strat_periods(first(scens)) @test sps == sps_scens @test all(sps1 === sps2 for (sps1, sps2) in zip(sps_scen_1, collect(sps)[1:3])) + @test last(sps_scen_1) == collect(sps)[3] + @test isa(sps_scen_1, TS.ScenTreeNodes) # Test that the representative periods are correct rps = repr_periods(regtree) @@ -1464,11 +1466,15 @@ end end two_level_tree = TwoLevelTree(5, [3, 2], uniform_day) - for (prev, n) in withprev(strat_periods(two_level_tree)) @test n.parent == prev end + strat_scen_1 = first(strategic_scenarios(two_level_tree)) + for (prev, n) in withprev(strat_periods(strat_scen_1)) + @test n.parent == prev + end + @test_throws ErrorException withnext(strat_periods(two_level_tree)) @test_throws ErrorException chunk(strat_periods(two_level_tree), 2) @test_throws ErrorException chunk_duration(strat_periods(two_level_tree), 2)