Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
name = "FerriteAssembly"
uuid = "fd21fc07-c509-4fe1-9468-19963fd5935d"
authors = ["Knut Andreas Meyer and contributors"]
version = "0.3.6"
authors = ["Knut Andreas Meyer and contributors"]

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Ferrite = "c061ca5d-56c9-439f-9c0e-210fe06d3992"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MaterialModelsBase = "af893363-701d-44dc-8b1e-d9a2c129bfc9"

[sources]
MaterialModelsBase = {rev = "main", url = "https://github.com/KnutAM/MaterialModelsBase.jl.git"}
MechanicalMaterialModels = {url = "https://github.com/KnutAM/MechanicalMaterialModels.jl.git"}
Newton = {url = "https://github.com/KnutAM/Newton.jl.git"}

[compat]
ConstructionBase = "1.5"
Ferrite = "1"
ForwardDiff = "0.10, 1"
MaterialModelsBase = "0.2"
MaterialModelsBase = "0.3"
julia = "1.11"

[extras]
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
MechanicalMaterialModels = "b3282f9b-607f-4337-ab95-e5488ab5652c"
Newton = "83aa5b51-0588-403c-85e4-434ec185aae7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Logging", "SparseArrays", "MechanicalMaterialModels", "Newton"]

[sources]
MaterialModelsBase = {url = "https://github.com/KnutAM/MaterialModelsBase.jl.git"}
MechanicalMaterialModels = {url = "https://github.com/KnutAM/MechanicalMaterialModels.jl.git"}
Newton = {url = "https://github.com/KnutAM/Newton.jl.git"}
12 changes: 7 additions & 5 deletions src/DomainBuffers.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
abstract type AbstractDomainBuffer end

get_num_tasks(::AbstractDomainBuffer) = Threads.nthreads() # Default for now

const DomainBuffers = Dict{String, <:AbstractDomainBuffer}

# Accessor functions
Expand Down Expand Up @@ -147,18 +145,22 @@ end
struct ThreadedDomainBuffer{I,B,S,SDH<:SubDofHandler} <: AbstractDomainBuffer
chunks::Vector{Vector{Vector{I}}} # I=Int (cell), I=FacetIndex (facet), or
set::Vector{I} # I=NTuple{2,FacetIndex} (interface)
num_tasks::Int
itembuffer::TaskLocals{B,B} # cell, facet, or interface buffer
states::StateVariables{S}
sdh::SDH
end
function ThreadedDomainBuffer(set, itembuffer::AbstractItemBuffer, states::StateVariables, sdh::SubDofHandler, colors_or_chunks=nothing)
function ThreadedDomainBuffer(set, itembuffer::AbstractItemBuffer, states::StateVariables, sdh::SubDofHandler, colors_or_chunks=nothing; num_tasks = Threads.nthreads())
grid = _getgrid(sdh)
set_vector = collect(set)
chunks = create_chunks(grid, set_vector, colors_or_chunks)
itembuffers = TaskLocals(itembuffer)
return ThreadedDomainBuffer(chunks, set_vector, itembuffers, states, sdh)
itembuffers = TaskLocals(itembuffer; num_tasks)
return ThreadedDomainBuffer(chunks, set_vector, num_tasks, itembuffers, states, sdh)
end

get_num_tasks(db::ThreadedDomainBuffer) = db.num_tasks
get_num_tasks(dbs::DomainBuffers) = maximum(get_num_tasks, values(dbs))

get_chunks(db::ThreadedDomainBuffer) = db.chunks

const StdDomainBuffer = Union{DomainBuffer, ThreadedDomainBuffer}
Expand Down
2 changes: 1 addition & 1 deletion src/Multithreading/TaskLocals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct TaskLocals{TB,TL}
base::TB
locals::Vector{TL}
end
function TaskLocals(base; num_tasks=Threads.nthreads())
function TaskLocals(base; num_tasks)
locals = [create_local(base) for _ in 1:num_tasks]
return TaskLocals(base, locals)
end
Expand Down
2 changes: 1 addition & 1 deletion src/Simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ update_states!(sim::Simulation) = update_states!(sim.db)
set_time_increment!(sim::Simulation, Δt) = set_time_increment!(sim.db, Δt)

# Forwarding for internal API
get_num_tasks(sim::Simulation{<:AbstractDomainBuffer}) = get_num_tasks(sim.db)
get_num_tasks(sim::Simulation) = get_num_tasks(sim.db)
get_chunks(sim::Simulation{<:AbstractDomainBuffer}) = get_chunks(sim.db)
get_itembuffer(sim::Simulation, args::Vararg{Any, N}) where {N} = get_itembuffer(sim.db, args...)

Expand Down
19 changes: 10 additions & 9 deletions src/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ function check_input(dbs::DomainBuffers, ::Type) # Unspecified (typically facet
end

"""
setup_domainbuffer(domain::DomainSpec; a=nothing, threading=false, autodiffbuffer=false)
setup_domainbuffer(domain::DomainSpec; a=nothing, threading=false, autodiffbuffer=false, num_tasks=Threads.nthreads())

Setup a domain buffer for a single grid domain, `domain`.
* `a::Vector`: The global degree of freedom values are used to pass the
Expand All @@ -99,6 +99,7 @@ Setup a domain buffer for a single grid domain, `domain`.
conditions for the field variables.
* `threading`: Should a `ThreadedDomainBuffer` be created to work the grid multithreaded
if supported by the used `worker`?
* `num_tasks`: The number of tasks to spawn during threaded assembly. Only applicable for `threading = true`.
* `autodiffbuffer`: Should a custom itembuffer be used to speed up the automatic
differentiation (if supported by the itembuffer)
"""
Expand All @@ -118,22 +119,22 @@ function setup_itembuffer(adb, domain::DomainSpec{Int}, states)
return setup_cellbuffer(adb, domain.sdh, domain.fe_values, domain.material, first(values(states)), dofrange, domain.user_data)
end

function _setup_domainbuffer(threaded, domain; a=nothing, autodiffbuffer=Val(false))
function _setup_domainbuffer(threaded, domain; a=nothing, autodiffbuffer=Val(false), kwargs...)
new_states = create_states(domain, a)
old_states = create_states(domain, a)
itembuffer = setup_itembuffer(autodiffbuffer, domain, new_states)
return _setup_domainbuffer(threaded, domain.set, itembuffer, StateVariables(old_states, new_states), domain.sdh, domain.colors_or_chunks)
return _setup_domainbuffer(threaded, domain.set, itembuffer, StateVariables(old_states, new_states), domain.sdh, domain.colors_or_chunks; kwargs...)
end

# Type-unstable switch
function _setup_domainbuffer(threaded::Bool, args...)
return _setup_domainbuffer(Val(threaded), args...)
function _setup_domainbuffer(threaded::Bool, args...; kwargs...)
return _setup_domainbuffer(Val(threaded), args...; kwargs...)
end
# Sequential
function _setup_domainbuffer(::Val{false}, set, itembuffer, states, sdh, args...)
return DomainBuffer(set, itembuffer, states, sdh)
function _setup_domainbuffer(::Val{false}, set, itembuffer, states, sdh, args...; kwargs...)
return DomainBuffer(set, itembuffer, states, sdh; kwargs...)
end
# Threaded
function _setup_domainbuffer(::Val{true}, set, itembuffer, states, sdh, args...)
return ThreadedDomainBuffer(set, itembuffer, states, sdh, args...)
function _setup_domainbuffer(::Val{true}, set, itembuffer, states, sdh, args...; kwargs...)
return ThreadedDomainBuffer(set, itembuffer, states, sdh, args...; kwargs...)
end
4 changes: 2 additions & 2 deletions src/work.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function work!(worker, sim::SingleDomainSim, coupled_simulations = CoupledSimula
end
function work!(worker, multisim::MultiDomainThreadedSim, coupled_simulations = CoupledSimulations())
if can_thread(worker)
workers = TaskLocals(worker)
workers = TaskLocals(worker, num_tasks = get_num_tasks(multisim))
for (name, sim) in multisim
skip_this_domain(worker, name) && continue
coupled = get_domain_simulation(coupled_simulations, name)
Expand All @@ -46,7 +46,7 @@ function work!(worker, multisim::MultiDomainThreadedSim, coupled_simulations = C
end
function work!(worker, sim::SingleDomainThreadedSim, coupled_simulations = CoupledSimulations())
if can_thread(worker)
workers = TaskLocals(worker)
workers = TaskLocals(worker; num_tasks = get_num_tasks(sim))
work_domain_threaded!(workers, sim, coupled_simulations)
else
work_domain_sequential!(worker, sim, coupled_simulations)
Expand Down
61 changes: 37 additions & 24 deletions test/heatequation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@
weak = EE.WeakForm((δu, ∇δu, u, ∇u, u_dot, ∇u_dot) -> 1.0*(∇δu ⋅ ∇u) - δu*1.0)
materials = (same=ThermalMaterial(), ad=ThermalMaterialAD(), weak=weak, mixed=Dict("A"=>ThermalMaterial(), "B"=>ThermalMaterialAD()))

function setup_assembly_test(dh, material, cv; autodiff_cb=false, threaded=false)
function setup_assembly_test(dh, material, cv; autodiff_cb=false, threaded=false, kwargs...)
BufferType = threaded ? FerriteAssembly.ThreadedDomainBuffer : FerriteAssembly.DomainBuffer
if isa(material, Dict) && length(dh.subdofhandlers) == 1
setA, setB = (getcellset(dh.grid, name) for name in ("A", "B"))
ad1 = DomainSpec(dh, material["A"], cv; set=setA)
ad2 = DomainSpec(dh, material["B"], cv; set=setB)
buffer = setup_domainbuffers(Dict("A"=>ad1, "B"=>ad2); autodiffbuffer=autodiff_cb, threading=threaded)
buffer = setup_domainbuffers(Dict("A"=>ad1, "B"=>ad2); autodiffbuffer=autodiff_cb, threading=threaded, kwargs...)
@test isa(buffer, Dict{String,<:BufferType})
@test isa(FerriteAssembly.get_old_state(buffer, "A"), FerriteAssembly.StateVector)
@test isa(FerriteAssembly.get_old_state(buffer, "B"), FerriteAssembly.StateVector)
Expand All @@ -149,7 +149,7 @@
# For ad3 and ad4; add the full set to check correct intersection with sdh2's cellset internally.
ad3 = DomainSpec(sdh2, material["A"], cv; set=setA) # sdh2A
ad4 = DomainSpec(sdh2, material["B"], cv; set=setB) # sdh2B
buffer = setup_domainbuffers(Dict("sdh1A"=>ad1, "sdh1B"=>ad2, "sdh2A"=>ad3, "sdh2B"=>ad4); autodiffbuffer=autodiff_cb, threading=threaded)
buffer = setup_domainbuffers(Dict("sdh1A"=>ad1, "sdh1B"=>ad2, "sdh2A"=>ad3, "sdh2B"=>ad4); autodiffbuffer=autodiff_cb, threading=threaded, kwargs...)
@test isa(buffer, Dict{String,<:BufferType})
@test isa(FerriteAssembly.get_old_state(buffer, "sdh1A"), FerriteAssembly.StateVector)
return buffer
Expand All @@ -159,12 +159,12 @@
set1 = sdh1.cellset; set2 = sdh2.cellset
ad1 = DomainSpec(sdh1, material, cv; set=set1)
ad2 = DomainSpec(sdh2, material, cv; set=set2)
buffer = setup_domainbuffers(Dict("sdh1"=>ad1, "sdh2"=>ad2); autodiffbuffer=autodiff_cb, threading=threaded)
buffer = setup_domainbuffers(Dict("sdh1"=>ad1, "sdh2"=>ad2); autodiffbuffer=autodiff_cb, threading=threaded, kwargs...)
@test isa(buffer, Dict{String,<:BufferType})
@test isa(FerriteAssembly.get_old_state(buffer, "sdh1"), FerriteAssembly.StateVector)
return buffer
else
buffer = setup_domainbuffer(DomainSpec(dh, material, cv); autodiffbuffer=autodiff_cb, threading=threaded)
buffer = setup_domainbuffer(DomainSpec(dh, material, cv); autodiffbuffer=autodiff_cb, threading=threaded, kwargs...)
@test isa(buffer, BufferType)
@test isa(FerriteAssembly.get_old_state(buffer), FerriteAssembly.StateVector)
return buffer
Expand Down Expand Up @@ -211,32 +211,45 @@
@testset "$DH, $mattype, threaded" begin
autdiff_cbs = isa(material,ThermalMaterial) ? (false,) : (false, true)
for autodiff_cb in autdiff_cbs
fill!(K, 0);
r .= rand(length(r)) # To ensure that it is actually changed
reset_scaling!(scaling)
ferrite_assembler = start_assemble(K, r)
assembler = isa(scaling, FerriteAssembly.NoScaling) ? ferrite_assembler : FerriteAssembly.KeReAssembler(ferrite_assembler; scaling=scaling)
buffer = setup_assembly_test(dh, material, cv; autodiff_cb=autodiff_cb, threaded=true)
# Quick check that test script works and that it is actually colored
TDB = FerriteAssembly.ThreadedDomainBuffer
@test isa(buffer, Union{Dict{String,<:TDB}, TDB})

work!(assembler, buffer; a=a)
if isa(scaling, ElementResidualScaling)
@test scaling.factors[:u] ≈ sum(abs, r)
end
@test K_ref ≈ K
@test r_ref ≈ r

if mattype == :ad
for num_tasks in (2, :default, Threads.maxthreadid() + 10)
fill!(K, 0);
r .= rand(length(r)) # To ensure that it is actually changed
reset_scaling!(scaling)
assembler = FerriteAssembly.ReAssembler(r; scaling=scaling)
ferrite_assembler = start_assemble(K, r)
assembler = isa(scaling, FerriteAssembly.NoScaling) ? ferrite_assembler : FerriteAssembly.KeReAssembler(ferrite_assembler; scaling=scaling)
buffer = if num_tasks == :default
setup_assembly_test(dh, material, cv; autodiff_cb=autodiff_cb, threaded=true)
else
_buffer = setup_assembly_test(dh, material, cv; autodiff_cb=autodiff_cb, threaded=true, num_tasks)
@test FerriteAssembly.get_num_tasks(_buffer) == num_tasks
if isa(_buffer, FerriteAssembly.AbstractDomainBuffer)
@test length(FerriteAssembly.get_locals(_buffer.itembuffer)) == num_tasks
elseif isa(_buffer, FerriteAssembly.DomainBuffers)
@test all(b -> length(FerriteAssembly.get_locals(b.itembuffer)) == num_tasks, values(_buffer))
end
_buffer
end
# Quick check that test script works and that it is actually colored
TDB = FerriteAssembly.ThreadedDomainBuffer
@test isa(buffer, Union{Dict{String,<:TDB}, TDB})

work!(assembler, buffer; a=a)
if isa(scaling, ElementResidualScaling)
@test scaling.factors[:u] ≈ sum(abs, r)
end
@test K_ref ≈ K
@test r_ref ≈ r

if mattype == :ad
r .= rand(length(r)) # To ensure that it is actually changed
reset_scaling!(scaling)
assembler = FerriteAssembly.ReAssembler(r; scaling=scaling)
work!(assembler, buffer; a=a)
if isa(scaling, ElementResidualScaling)
@test scaling.factors[:u] ≈ sum(abs, r)
end
@test r_ref ≈ r
end
end
end
end
Expand Down
Loading