Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/reference/internal.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ TimeStruct.AbstractTreeStructure
TimeStruct.AbstractStratScens
TimeStruct.StratTreeNodes
TimeStruct.StratScens
TimeStruct.ScenTreeNodes
TimeStruct.SingleStrategicScenarioWrapper
```

Expand Down
32 changes: 28 additions & 4 deletions src/strat_scenarios/strat_scenarios.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,40 @@ function Base.iterate(scs::StrategicScenario, state = (nothing, 1))
return next[1], (next[2], sp)
end

"""
struct ScenTreeNodes{S,T,N,OP<:AbstractTreeNode{S,T}} <: AbstractStratPers{T}

Type for iterating through the individual strategic nodes of a [`StrategicScenario`](@ref).
It is automatically created through the function [`strat_periods`](@ref), and hence,
[`strategic_periods`](@ref).
"""
struct ScenTreeNodes{S,T,N,OP<:AbstractTreeNode{S,T}} <: AbstractStratPers{T}
ts::StrategicScenario{S,T,N,OP}
end

# Adding methods to existing Julia functions
Base.length(sps::ScenTreeNodes) = length(sps.ts.nodes)
Base.eltype(_::Type{ScenTreeNodes{S,T,N,OP}}) where {S,T,N,OP} = OP
function Base.iterate(sps::ScenTreeNodes, state = nothing)
next = isnothing(state) ? 1 : state + 1
next == length(sps) + 1 && return nothing
return sps.ts.nodes[next], next
end
function Base.getindex(sps::ScenTreeNodes, index::Int)
return sps.ts.nodes[index]
end
function Base.last(sps::ScenTreeNodes)
return last(sps.ts.nodes)
end

"""
When the `TimeStructure` is a [`StrategicScenario`](@ref), `strat_periods` returns a
[`StratTreeNodes`](@ref) type, which, through iteration, provides [`StratNode`](@ref) types.
[`ScenTreeNodes`](@ref) type, which, through iteration, provides [`StratNode`](@ref) types.

These are equivalent to a [`StrategicPeriod`](@ref) of a [`TwoLevel`](@ref) time structure.
"""
function strat_periods(ts::StrategicScenario)
return StratTreeNodes(
TwoLevelTree(length(ts), first(ts), [n for n in ts.nodes], ts.op_per_strat),
)
return ScenTreeNodes(ts)
end

"""
Expand Down
8 changes: 7 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,8 @@ end
sps_scen_1 = strat_periods(first(scens))
@test sps == sps_scens
@test all(sps1 === sps2 for (sps1, sps2) in zip(sps_scen_1, collect(sps)[1:3]))
@test last(sps_scen_1) == collect(sps)[3]
@test isa(sps_scen_1, TS.ScenTreeNodes)

# Test that the representative periods are correct
rps = repr_periods(regtree)
Expand Down Expand Up @@ -1464,11 +1466,15 @@ end
end

two_level_tree = TwoLevelTree(5, [3, 2], uniform_day)

for (prev, n) in withprev(strat_periods(two_level_tree))
@test n.parent == prev
end

strat_scen_1 = first(strategic_scenarios(two_level_tree))
for (prev, n) in withprev(strat_periods(strat_scen_1))
@test n.parent == prev
end

@test_throws ErrorException withnext(strat_periods(two_level_tree))
@test_throws ErrorException chunk(strat_periods(two_level_tree), 2)
@test_throws ErrorException chunk_duration(strat_periods(two_level_tree), 2)
Expand Down
Loading