diff --git a/Project.toml b/Project.toml index 1f01229..7bfc860 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,14 @@ name = "GeometricBase" uuid = "9a0b12b7-583b-4f04-aa1f-d8551b6addc9" -authors = ["Michael Kraus"] version = "0.12.7" +authors = ["Michael Kraus"] + +[deps] +Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [compat] -julia = "1.6" +Unicode = "1.10" +julia = "1.10" [extras] SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" diff --git a/src/GeometricBase.jl b/src/GeometricBase.jl index c981cd4..1240943 100644 --- a/src/GeometricBase.jl +++ b/src/GeometricBase.jl @@ -1,21 +1,22 @@ module GeometricBase - include("Config.jl") - include("Utils.jl") +include("Config.jl") +include("Utils.jl") - include("methods.jl") - include("types.jl") +include("methods.jl") +include("types.jl") - include("abstract_problem.jl") - include("abstract_method.jl") - include("abstract_integrator.jl") - include("abstract_solution.jl") - include("abstract_solver.jl") - include("abstract_solver_state.jl") +include("abstract_problem.jl") +include("abstract_method.jl") +include("abstract_integrator.jl") +include("abstract_solution.jl") +include("abstract_solver.jl") +include("abstract_solver_state.jl") - include("data/system_types.jl") - include("data/data_types.jl") - include("data/geometric_data.jl") - include("data/state_variables.jl") +include("data/system_types.jl") +include("data/data_types.jl") +include("data/geometric_data.jl") +include("data/state_variables.jl") +include("data/state.jl") end diff --git a/src/data/state.jl b/src/data/state.jl new file mode 100644 index 0000000..3f3ec91 --- /dev/null +++ b/src/data/state.jl @@ -0,0 +1,139 @@ +using Unicode: normalize + +export State +export solution, state, vectorfield + + +# The `_state` function returns an appropriate empty state for a given state variable. +# _state(x::Number) = ScalarVariable(x) +_state(x::TimeVariable) = zero(x) +_state(x::StateVariable) = StateWithError(zero(x)) +_state(x::VectorfieldVariable) = zero(x) +_state(x::AlgebraicVariable) = zero(x) +_state(x::StateWithError) = zero(x) + +# The `_vectorfield` function returns an appropriate empty vectorfield for a given state variable. +# _vectorfield(x::Number) = missing +_vectorfield(x::TimeVariable) = missing +_vectorfield(x::StateVariable) = VectorfieldVariable(x) +_vectorfield(x::VectorfieldVariable) = missing +_vectorfield(x::AlgebraicVariable) = missing +_vectorfield(x::StateWithError) = _vectorfield(x.state) + +# The `_convert` function returns an appropriate type for any given `AbstractVariable`. +# In particular, it returns the scalar value of a `TimeVariable`. +_convert(x::Missing)::Missing = x +_convert(x::TimeVariable) = value(x) +_convert(x::AbstractVariable) = x + +# Adds a dot or bar to a symbol, indicating a time derivative or vector field, or a previous solution. +_add_symbol(s::Symbol, c::Char) = Symbol(normalize("$(s)$(c)")) +_add_bar(s::Symbol) = _add_symbol(s, Char(0x0304)) +_add_dot(s::Symbol) = _add_symbol(s, Char(0x0307)) + +# Removes a dot or bar from a symbol. +_strip_symbol(s::Symbol, c::Char) = Symbol(strip(normalize(String(s); decompose=true), c)) +_strip_bar(s::Symbol) = _strip_symbol(s, Char(0x0304)) +_strip_dot(s::Symbol) = _strip_symbol(s, Char(0x0307)) + +# q̄ = Symbol(join([Char('q'), Char(0x0304)])) +# q̇ = Symbol(join([Char('q'), Char(0x0307)])) +# q = Symbol(strip(String(:q̄), Char(0x0304))) +# q = Symbol(strip(String(:q̇), Char(0x0307))) + + +""" +Holds the solution of a geometric equation at a single time step. + +It stores all the information that is required to uniquely determine the state of a systen, +in particular all state variables and their corresponding vector fields. +""" +struct State{ + stateType<:NamedTuple, + solutionType<:NamedTuple, + vectorfieldType<:NamedTuple +} + + state::stateType + solution::solutionType + vectorfield::vectorfieldType + + function State(ics::NamedTuple) + # create solution tupke for all variables in ics + solution = NamedTuple{keys(ics)}(_state(x) for x in ics) + + # create vectorfield vector for all state variables in ics + vectorfield = NamedTuple{keys(ics)}(_vectorfield(x) for x in ics) + + # remove all fields that are missing, i.e., that correspond to a variable without vectorfield + vectorfield_filtered = NamedTuple{filter(k -> !all(ismissing.(vectorfield[k])), keys(vectorfield))}(vectorfield) + + # create vector field symbols with dotted solution symbols + vectorfield_keys = Tuple(_add_dot(k) for k in keys(vectorfield_filtered)) + vectorfield_dots = NamedTuple{vectorfield_keys}(values(vectorfield_filtered)) + + # create state by merging solution fields with filtered vector fields + state_fields = merge(solution, vectorfield_dots) + + # create state + state = new{typeof(state_fields),typeof(solution),typeof(vectorfield_filtered)}(state_fields, solution, vectorfield_filtered) + + # copy initial conditions to state + copy!(state, ics) + + return state + end +end + +@inline function Base.hasproperty(::State{ST}, s::Symbol) where {ST} + hasfield(ST, s) || hasfield(State, s) +end + +@inline function Base.getproperty(st::State{ST}, s::Symbol) where {ST} + if hasfield(ST, s) + return _convert(getfield(st, :state)[s]) + else + return getfield(st, s) + end +end + +@inline function Base.setproperty!(st::State{ST}, s::Symbol, x) where {ST} + if hasfield(ST, s) + return copy!(getfield(st, :state)[s], x) + else + return setfield!(st, s, x) + end +end + + +state(st::State) = st.state +solution(st::State) = st.solution +vectorfield(st::State) = st.vectorfield + +""" + keys(st::State) + +Return the keys of all the state variables in the `State`. +""" +Base.keys(st::State) = keys(state(st)) + +""" + copy!(st::State, sol::NamedTuple) + +Copy the values from a `NamedTuple` `sol` to the `State` `st`. + +The keys of `sol` must be a subset of the keys of the state. + +# Arguments +- `st`: the state to copy into +- `sol`: the named tuple containing the solution values to copy +""" +function Base.copy!(st::State, sol::NamedTuple) + @assert keys(sol) ⊆ keys(st) + + for k in keys(sol) + copy!(solution(st)[k], sol[k]) + end + + return st +end diff --git a/src/data/state_variables.jl b/src/data/state_variables.jl index 880159a..1a87eed 100644 --- a/src/data/state_variables.jl +++ b/src/data/state_variables.jl @@ -5,7 +5,7 @@ export TimeVariable, StateVariable, VectorfieldVariable, AlgebraicVariable export Increment, StateWithError, StateVariableWithError, StateVector export isperiodic, parenttype, verifyrange -export add!, reset!, periodic, value, vectorfield, zerovector +export add!, reset!, periodic, value, zerovector """ @@ -29,6 +29,7 @@ Base.:(==)(x::AV, y::AV) where {AV<:AbstractVariable} = parent(x) == parent(y) """ abstract type AbstractScalarVariable{DT} <: AbstractVariable{DT,0} end +Base.:(==)(x::AbstractScalarVariable, y::AbstractScalarVariable) = value(x) == value(y) Base.:(==)(x::AbstractScalarVariable, y::Number) = value(x) == y Base.:(==)(x::Number, y::AbstractScalarVariable) = y == x diff --git a/test/runtests.jl b/test/runtests.jl index 768040f..4d57429 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,12 +1,35 @@ using SafeTestsets -@safetestset "Abstract Problem " begin include("abstract_problem_tests.jl") end -@safetestset "Abstract Solution " begin include("abstract_solution_tests.jl") end -@safetestset "Abstract Integrator " begin include("abstract_integrator_tests.jl") end -@safetestset "Abstract Method " begin include("abstract_method_tests.jl") end -@safetestset "Abstract Solver " begin include("abstract_solver_tests.jl") end -@safetestset "Geometric Data " begin include("geometric_data_tests.jl") end -@safetestset "State Variables " begin include("state_variables_tests.jl") end -@safetestset "Methods " begin include("methods_tests.jl") end -@safetestset "Types " begin include("types_tests.jl") end -@safetestset "Utils " begin include("utils_tests.jl") end +@safetestset "Abstract Problem " begin + include("abstract_problem_tests.jl") +end +@safetestset "Abstract Solution " begin + include("abstract_solution_tests.jl") +end +@safetestset "Abstract Integrator " begin + include("abstract_integrator_tests.jl") +end +@safetestset "Abstract Method " begin + include("abstract_method_tests.jl") +end +@safetestset "Abstract Solver " begin + include("abstract_solver_tests.jl") +end +@safetestset "Geometric Data " begin + include("geometric_data_tests.jl") +end +@safetestset "State Variables " begin + include("state_variables_tests.jl") +end +@safetestset "States " begin + include("state_tests.jl") +end +@safetestset "Methods " begin + include("methods_tests.jl") +end +@safetestset "Types " begin + include("types_tests.jl") +end +@safetestset "Utils " begin + include("utils_tests.jl") +end diff --git a/test/state_tests.jl b/test/state_tests.jl new file mode 100644 index 0000000..89af6a2 --- /dev/null +++ b/test/state_tests.jl @@ -0,0 +1,92 @@ +using GeometricBase +using Test + +using GeometricBase: StateWithError, TimeVariable, VectorfieldVariable +using GeometricBase: periodic, value +using GeometricBase: _strip_symbol, _strip_bar, _strip_dot, _add_symbol, _add_bar, _add_dot, _state, _vectorfield + + +@testset "$(rpad("State Helper Functions",80))" begin + + x = rand(3) + + @test _strip_symbol(:q̄, Char(0x0304)) == :q + @test _strip_symbol(:q̇, Char(0x0307)) == :q + + @test _strip_bar(:q̄) == :q + @test _strip_dot(:q̇) == :q + + @test _add_symbol(:q, Char(0x0304)) == :q̄ + @test _add_symbol(:q, Char(0x0307)) == :q̇ + + @test _add_bar(:q) == :q̄ + @test _add_dot(:q) == :q̇ + + @test _state(TimeVariable(1.0)) == TimeVariable(0.0) + @test _state(StateVariable(x)) == StateWithError(StateVariable(zero(x))) + + # @test ismissing(_vectorfield(1)) + @test ismissing(_vectorfield(TimeVariable(1.0))) + @test ismissing(_vectorfield(AlgebraicVariable(x))) + @test ismissing(_vectorfield(VectorfieldVariable(StateVariable(x)))) + + @test _vectorfield(StateVariable(x)) == VectorfieldVariable(zero(x)) + @test _vectorfield(StateWithError(StateVariable(x))) == VectorfieldVariable(zero(x)) + +end + + +@testset "$(rpad("State Constructor and Access Functions",80))" begin + + data = ( + t=TimeVariable(0.0), + q=StateVariable(rand(3)), + p=StateVariable(rand(3)), + λ=AlgebraicVariable(rand(2)), + ) + + st = State(data) + + @test state(st) == st.state + @test solution(st) == st.solution + @test vectorfield(st) == st.vectorfield + + @test keys(st) == keys(state(st)) + + @test st.t == data.t + @test st.q == data.q + @test st.p == data.p + @test st.λ == data.λ + + @test hasproperty(st, :t) + @test hasproperty(st, :q) + @test hasproperty(st, :p) + @test hasproperty(st, :λ) + @test hasproperty(st, :q̇) + @test hasproperty(st, :ṗ) + + @test :t ∈ keys(st) + @test :q ∈ keys(st) + @test :p ∈ keys(st) + @test :λ ∈ keys(st) + @test :q̇ ∈ keys(st) + @test :ṗ ∈ keys(st) + + @test :t ∈ keys(state(st)) + @test :q ∈ keys(state(st)) + @test :p ∈ keys(state(st)) + @test :λ ∈ keys(state(st)) + @test :q̇ ∈ keys(state(st)) + @test :ṗ ∈ keys(state(st)) + + @test :t ∈ keys(solution(st)) + @test :q ∈ keys(solution(st)) + @test :p ∈ keys(solution(st)) + @test :λ ∈ keys(solution(st)) + + @test :t ∉ keys(vectorfield(st)) + @test :q ∈ keys(vectorfield(st)) + @test :p ∈ keys(vectorfield(st)) + @test :λ ∉ keys(vectorfield(st)) + +end