Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
name = "GeometricBase"
uuid = "9a0b12b7-583b-4f04-aa1f-d8551b6addc9"
authors = ["Michael Kraus"]
version = "0.12.7"
authors = ["Michael Kraus"]

[deps]
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[compat]
julia = "1.6"
Unicode = "1.10"
julia = "1.10"

[extras]
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Expand Down
29 changes: 15 additions & 14 deletions src/GeometricBase.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
module GeometricBase

include("Config.jl")
include("Utils.jl")
include("Config.jl")
include("Utils.jl")

include("methods.jl")
include("types.jl")
include("methods.jl")
include("types.jl")

include("abstract_problem.jl")
include("abstract_method.jl")
include("abstract_integrator.jl")
include("abstract_solution.jl")
include("abstract_solver.jl")
include("abstract_solver_state.jl")
include("abstract_problem.jl")
include("abstract_method.jl")
include("abstract_integrator.jl")
include("abstract_solution.jl")
include("abstract_solver.jl")
include("abstract_solver_state.jl")

include("data/system_types.jl")
include("data/data_types.jl")
include("data/geometric_data.jl")
include("data/state_variables.jl")
include("data/system_types.jl")
include("data/data_types.jl")
include("data/geometric_data.jl")
include("data/state_variables.jl")
include("data/state.jl")

end
139 changes: 139 additions & 0 deletions src/data/state.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
using Unicode: normalize

export State
export solution, state, vectorfield


# The `_state` function returns an appropriate empty state for a given state variable.
# _state(x::Number) = ScalarVariable(x)
_state(x::TimeVariable) = zero(x)
_state(x::StateVariable) = StateWithError(zero(x))
_state(x::VectorfieldVariable) = zero(x)
_state(x::AlgebraicVariable) = zero(x)
_state(x::StateWithError) = zero(x)

# The `_vectorfield` function returns an appropriate empty vectorfield for a given state variable.
# _vectorfield(x::Number) = missing
_vectorfield(x::TimeVariable) = missing
_vectorfield(x::StateVariable) = VectorfieldVariable(x)
_vectorfield(x::VectorfieldVariable) = missing
_vectorfield(x::AlgebraicVariable) = missing
_vectorfield(x::StateWithError) = _vectorfield(x.state)

# The `_convert` function returns an appropriate type for any given `AbstractVariable`.
# In particular, it returns the scalar value of a `TimeVariable`.
_convert(x::Missing)::Missing = x
_convert(x::TimeVariable) = value(x)
_convert(x::AbstractVariable) = x

# Adds a dot or bar to a symbol, indicating a time derivative or vector field, or a previous solution.
_add_symbol(s::Symbol, c::Char) = Symbol(normalize("$(s)$(c)"))
_add_bar(s::Symbol) = _add_symbol(s, Char(0x0304))
_add_dot(s::Symbol) = _add_symbol(s, Char(0x0307))

# Removes a dot or bar from a symbol.
_strip_symbol(s::Symbol, c::Char) = Symbol(strip(normalize(String(s); decompose=true), c))
_strip_bar(s::Symbol) = _strip_symbol(s, Char(0x0304))
_strip_dot(s::Symbol) = _strip_symbol(s, Char(0x0307))

# q̄ = Symbol(join([Char('q'), Char(0x0304)]))
# q̇ = Symbol(join([Char('q'), Char(0x0307)]))
# q = Symbol(strip(String(:q̄), Char(0x0304)))
# q = Symbol(strip(String(:q̇), Char(0x0307)))


"""
Holds the solution of a geometric equation at a single time step.

It stores all the information that is required to uniquely determine the state of a systen,
in particular all state variables and their corresponding vector fields.
"""
struct State{
stateType<:NamedTuple,
solutionType<:NamedTuple,
vectorfieldType<:NamedTuple
}

state::stateType
solution::solutionType
vectorfield::vectorfieldType

function State(ics::NamedTuple)
# create solution tupke for all variables in ics
solution = NamedTuple{keys(ics)}(_state(x) for x in ics)

# create vectorfield vector for all state variables in ics
vectorfield = NamedTuple{keys(ics)}(_vectorfield(x) for x in ics)

# remove all fields that are missing, i.e., that correspond to a variable without vectorfield
vectorfield_filtered = NamedTuple{filter(k -> !all(ismissing.(vectorfield[k])), keys(vectorfield))}(vectorfield)

# create vector field symbols with dotted solution symbols
vectorfield_keys = Tuple(_add_dot(k) for k in keys(vectorfield_filtered))
vectorfield_dots = NamedTuple{vectorfield_keys}(values(vectorfield_filtered))

# create state by merging solution fields with filtered vector fields
state_fields = merge(solution, vectorfield_dots)

# create state
state = new{typeof(state_fields),typeof(solution),typeof(vectorfield_filtered)}(state_fields, solution, vectorfield_filtered)

# copy initial conditions to state
copy!(state, ics)

return state
end
end

@inline function Base.hasproperty(::State{ST}, s::Symbol) where {ST}
hasfield(ST, s) || hasfield(State, s)
end

@inline function Base.getproperty(st::State{ST}, s::Symbol) where {ST}
if hasfield(ST, s)
return _convert(getfield(st, :state)[s])
else
return getfield(st, s)
end
end

@inline function Base.setproperty!(st::State{ST}, s::Symbol, x) where {ST}
if hasfield(ST, s)
return copy!(getfield(st, :state)[s], x)
else
return setfield!(st, s, x)
end
end


state(st::State) = st.state
solution(st::State) = st.solution
vectorfield(st::State) = st.vectorfield

"""
keys(st::State)

Return the keys of all the state variables in the `State`.
"""
Base.keys(st::State) = keys(state(st))

"""
copy!(st::State, sol::NamedTuple)

Copy the values from a `NamedTuple` `sol` to the `State` `st`.

The keys of `sol` must be a subset of the keys of the state.

# Arguments
- `st`: the state to copy into
- `sol`: the named tuple containing the solution values to copy
"""
function Base.copy!(st::State, sol::NamedTuple)
@assert keys(sol) ⊆ keys(st)

for k in keys(sol)
copy!(solution(st)[k], sol[k])
end

return st
end
3 changes: 2 additions & 1 deletion src/data/state_variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export TimeVariable, StateVariable, VectorfieldVariable, AlgebraicVariable
export Increment, StateWithError, StateVariableWithError, StateVector

export isperiodic, parenttype, verifyrange
export add!, reset!, periodic, value, vectorfield, zerovector
export add!, reset!, periodic, value, zerovector


"""
Expand All @@ -29,6 +29,7 @@ Base.:(==)(x::AV, y::AV) where {AV<:AbstractVariable} = parent(x) == parent(y)
"""
abstract type AbstractScalarVariable{DT} <: AbstractVariable{DT,0} end

Base.:(==)(x::AbstractScalarVariable, y::AbstractScalarVariable) = value(x) == value(y)
Base.:(==)(x::AbstractScalarVariable, y::Number) = value(x) == y
Base.:(==)(x::Number, y::AbstractScalarVariable) = y == x

Expand Down
43 changes: 33 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,35 @@
using SafeTestsets

@safetestset "Abstract Problem " begin include("abstract_problem_tests.jl") end
@safetestset "Abstract Solution " begin include("abstract_solution_tests.jl") end
@safetestset "Abstract Integrator " begin include("abstract_integrator_tests.jl") end
@safetestset "Abstract Method " begin include("abstract_method_tests.jl") end
@safetestset "Abstract Solver " begin include("abstract_solver_tests.jl") end
@safetestset "Geometric Data " begin include("geometric_data_tests.jl") end
@safetestset "State Variables " begin include("state_variables_tests.jl") end
@safetestset "Methods " begin include("methods_tests.jl") end
@safetestset "Types " begin include("types_tests.jl") end
@safetestset "Utils " begin include("utils_tests.jl") end
@safetestset "Abstract Problem " begin
include("abstract_problem_tests.jl")
end
@safetestset "Abstract Solution " begin
include("abstract_solution_tests.jl")
end
@safetestset "Abstract Integrator " begin
include("abstract_integrator_tests.jl")
end
@safetestset "Abstract Method " begin
include("abstract_method_tests.jl")
end
@safetestset "Abstract Solver " begin
include("abstract_solver_tests.jl")
end
@safetestset "Geometric Data " begin
include("geometric_data_tests.jl")
end
@safetestset "State Variables " begin
include("state_variables_tests.jl")
end
@safetestset "States " begin
include("state_tests.jl")
end
@safetestset "Methods " begin
include("methods_tests.jl")
end
@safetestset "Types " begin
include("types_tests.jl")
end
@safetestset "Utils " begin
include("utils_tests.jl")
end
92 changes: 92 additions & 0 deletions test/state_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
using GeometricBase
using Test

using GeometricBase: StateWithError, TimeVariable, VectorfieldVariable
using GeometricBase: periodic, value
using GeometricBase: _strip_symbol, _strip_bar, _strip_dot, _add_symbol, _add_bar, _add_dot, _state, _vectorfield


@testset "$(rpad("State Helper Functions",80))" begin

x = rand(3)

@test _strip_symbol(:q̄, Char(0x0304)) == :q
@test _strip_symbol(:q̇, Char(0x0307)) == :q

@test _strip_bar(:q̄) == :q
@test _strip_dot(:q̇) == :q

@test _add_symbol(:q, Char(0x0304)) == :q̄
@test _add_symbol(:q, Char(0x0307)) == :q̇

@test _add_bar(:q) == :q̄
@test _add_dot(:q) == :q̇

@test _state(TimeVariable(1.0)) == TimeVariable(0.0)
@test _state(StateVariable(x)) == StateWithError(StateVariable(zero(x)))

# @test ismissing(_vectorfield(1))
@test ismissing(_vectorfield(TimeVariable(1.0)))
@test ismissing(_vectorfield(AlgebraicVariable(x)))
@test ismissing(_vectorfield(VectorfieldVariable(StateVariable(x))))

@test _vectorfield(StateVariable(x)) == VectorfieldVariable(zero(x))
@test _vectorfield(StateWithError(StateVariable(x))) == VectorfieldVariable(zero(x))

end


@testset "$(rpad("State Constructor and Access Functions",80))" begin

data = (
t=TimeVariable(0.0),
q=StateVariable(rand(3)),
p=StateVariable(rand(3)),
λ=AlgebraicVariable(rand(2)),
)

st = State(data)

@test state(st) == st.state
@test solution(st) == st.solution
@test vectorfield(st) == st.vectorfield

@test keys(st) == keys(state(st))

@test st.t == data.t
@test st.q == data.q
@test st.p == data.p
@test st.λ == data.λ

@test hasproperty(st, :t)
@test hasproperty(st, :q)
@test hasproperty(st, :p)
@test hasproperty(st, :λ)
@test hasproperty(st, :q̇)
@test hasproperty(st, :ṗ)

@test :t ∈ keys(st)
@test :q ∈ keys(st)
@test :p ∈ keys(st)
@test :λ ∈ keys(st)
@test :q̇ ∈ keys(st)
@test :ṗ ∈ keys(st)

@test :t ∈ keys(state(st))
@test :q ∈ keys(state(st))
@test :p ∈ keys(state(st))
@test :λ ∈ keys(state(st))
@test :q̇ ∈ keys(state(st))
@test :ṗ ∈ keys(state(st))

@test :t ∈ keys(solution(st))
@test :q ∈ keys(solution(st))
@test :p ∈ keys(solution(st))
@test :λ ∈ keys(solution(st))

@test :t ∉ keys(vectorfield(st))
@test :q ∈ keys(vectorfield(st))
@test :p ∈ keys(vectorfield(st))
@test :λ ∉ keys(vectorfield(st))

end
Loading