From 3e9cae6d840da186c2e65da2e1a1e24cff585471 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 6 Jan 2026 15:58:27 +0000 Subject: [PATCH 01/29] Initial batching demo --- GeneralisedFilters/Project.toml | 16 +- GeneralisedFilters/src/GFTest/GFTest.jl | 1 - GeneralisedFilters/src/GeneralisedFilters.jl | 7 +- .../src/batching/broadcasting.jl | 202 +++++ GeneralisedFilters/src/batching/operations.jl | 214 ++++++ GeneralisedFilters/src/batching/types.jl | 222 ++++++ GeneralisedFilters/src/batching/wrappers.jl | 720 ++++++++++++++++++ research/batching/batching_demo.jl | 113 +++ 8 files changed, 1484 insertions(+), 11 deletions(-) create mode 100644 GeneralisedFilters/src/batching/broadcasting.jl create mode 100644 GeneralisedFilters/src/batching/operations.jl create mode 100644 GeneralisedFilters/src/batching/types.jl create mode 100644 GeneralisedFilters/src/batching/wrappers.jl create mode 100644 research/batching/batching_demo.jl diff --git a/GeneralisedFilters/Project.toml b/GeneralisedFilters/Project.toml index 06e080fe..fdb57c87 100644 --- a/GeneralisedFilters/Project.toml +++ b/GeneralisedFilters/Project.toml @@ -1,12 +1,7 @@ name = "GeneralisedFilters" uuid = "3ef92589-7ab8-43f9-b5b9-a3a0c86ecbb7" version = "0.4.2" -authors = [ - "THargreaves ", - "Charles Knipp ", - "FredericWantiez ", - "Hong Ge " -] +authors = ["THargreaves ", "Charles Knipp ", "FredericWantiez ", "Hong Ge "] [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -14,9 +9,10 @@ AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +IRTools = "7869d1d1-7146-5819-86e3-90919afe41df" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Magma = "a4173727-5e3e-4567-b12d-2e3cf2fa2f28" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -24,6 +20,7 @@ SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [compat] AbstractMCMC = "5" @@ -32,25 +29,26 @@ Aqua = "0.8" CUDA = "5" DataStructures = "0.18.20, 0.19" Distributions = "0.25" +IRTools = "0.4.15" LogExpFunctions = "0.3" -NNlib = "0.9" OffsetArrays = "1.14.1" PDMats = "0.11.35" SSMProblems = "0.6" StaticArrays = "1.9.14" Statistics = "1" StatsBase = "0.34.3" +StructArrays = "0.7.2" Test = "1" julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" [targets] test = ["Aqua", "PDMats", "StableRNGs", "Test", "TestItemRunner", "TestItems", "JET"] diff --git a/GeneralisedFilters/src/GFTest/GFTest.jl b/GeneralisedFilters/src/GFTest/GFTest.jl index 3d673c44..c5a0f6a8 100644 --- a/GeneralisedFilters/src/GFTest/GFTest.jl +++ b/GeneralisedFilters/src/GFTest/GFTest.jl @@ -2,7 +2,6 @@ module GFTest using CUDA using LinearAlgebra -using NNlib 
using Random using StaticArrays diff --git a/GeneralisedFilters/src/GeneralisedFilters.jl b/GeneralisedFilters/src/GeneralisedFilters.jl index deb349c7..37042d89 100644 --- a/GeneralisedFilters/src/GeneralisedFilters.jl +++ b/GeneralisedFilters/src/GeneralisedFilters.jl @@ -10,7 +10,6 @@ using StatsBase # TODO: heavy modules—move to extension using CUDA -using NNlib export initialise, step, predict, update, filter @@ -19,6 +18,12 @@ include("callbacks.jl") include("containers.jl") include("resamplers.jl") +# Batching utilities +include("batching/types.jl") +include("batching/broadcasting.jl") +include("batching/wrappers.jl") +include("batching/operations.jl") + ## FILTERING BASE ########################################################################## abstract type AbstractFilter <: AbstractSampler end diff --git a/GeneralisedFilters/src/batching/broadcasting.jl b/GeneralisedFilters/src/batching/broadcasting.jl new file mode 100644 index 00000000..92d50811 --- /dev/null +++ b/GeneralisedFilters/src/batching/broadcasting.jl @@ -0,0 +1,202 @@ +using IRTools +using IRTools: @code_ir, IR, Statement, Variable, func +using StructArrays +using LinearAlgebra: I, UniformScaling + +using Base.Broadcast: Broadcasted, BroadcastStyle, DefaultArrayStyle +import Base.Broadcast: broadcasted + +import PDMats: PDMat + +# ============================================================================= +# Broadcast Style +# ============================================================================= + +struct BatchedStyle <: Broadcast.BroadcastStyle end + +Base.BroadcastStyle(::Type{<:BatchedCuMatrix}) = BatchedStyle() +Base.BroadcastStyle(::Type{<:BatchedCuVector}) = BatchedStyle() +Base.BroadcastStyle(::Type{<:SharedCuMatrix}) = BatchedStyle() +Base.BroadcastStyle(::Type{<:SharedCuVector}) = BatchedStyle() +Base.BroadcastStyle(::Type{<:BatchedCholesky}) = BatchedStyle() +Base.BroadcastStyle(::Type{<:BatchedPDMat}) = BatchedStyle() +# HACK: Currently hard-coded but can be replaced with a custom StructArray type +Base.BroadcastStyle(::Type{<:StructArray}) = BatchedStyle() +Base.BroadcastStyle(::BatchedStyle, ::BatchedStyle) = BatchedStyle() +Base.BroadcastStyle(::BatchedStyle, ::DefaultArrayStyle{0}) = BatchedStyle() + +# ============================================================================= +# Ref Conversion (for Shared arrays) +# ============================================================================= + +maybe_convert_ref(x) = x +function maybe_convert_ref(r::Base.RefValue{<:CuVector{T}}) where {T} + return SharedCuVector{T,CuVector{T}}(r[]) +end +function maybe_convert_ref(r::Base.RefValue{<:CuMatrix{T}}) where {T} + return SharedCuMatrix{T,CuMatrix{T}}(r[]) +end + +# ============================================================================= +# Structural Operations (Pass-through) +# ============================================================================= + +broadcasted(::typeof(tuple), args...) = tuple(args...) 
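+
+# NOTE (illustrative): these pass-throughs keep structural operations out of
+# the batching machinery. The IR transform rewrites every call into a
+# broadcast, but tuple construction and field access must act on the batched
+# containers themselves; e.g. `broadcasted(getfield, dyn_params, 1)` should
+# return the whole batched array rather than map over its elements.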
+broadcasted(::typeof(getproperty), x, s::Symbol) = getproperty(x, s) +broadcasted(::typeof(getfield), x, s::Symbol) = getfield(x, s) +broadcasted(::typeof(getfield), x, i::Int) = getfield(x, i) + +# Special handling for RefValue - unwrap before indexing +broadcasted(::typeof(getfield), r::Base.RefValue, i::Int) = getfield(r[], i) +broadcasted(::typeof(getfield), r::Base.RefValue, s::Symbol) = getfield(r[], s) + +# ============================================================================= +# StructArray Wrapping +# ============================================================================= + +inner_eltype(arg::BatchedCuVector{T}) where {T} = CuVector{T} +inner_eltype(arg::BatchedCuMatrix{T}) where {T} = CuMatrix{T} +inner_eltype(arg::SharedCuVector{T}) where {T} = CuVector{T} +inner_eltype(arg::SharedCuMatrix{T}) where {T} = CuMatrix{T} +inner_eltype(arg::BatchedPDMat{T}) where {T} = PDMat{T,CuMatrix{T}} +inner_eltype(arg) = typeof(arg) + +function wrap_if_batched(::Type{T}, args...) where {T} + if any(arg -> arg isa Union{BatchedArray,SharedArray,BatchedPDMat}, args) + field_names = fieldnames(T) + element_types = Tuple{map(inner_eltype, args)...} + ElType = Core.Compiler.return_type(T, element_types) + nt = NamedTuple{field_names}(args) + return StructArray{ElType}(nt) + else + return T(args...) + end +end + +# ============================================================================= +# IR Transformation +# ============================================================================= + +const SKIP_BROADCAST = Set{Any}([ + tuple, + Core.tuple, + getfield, + getproperty, + adjoint, + transpose, + LowerTriangular, + UpperTriangular, +]) + +const BROADCAST_TYPES = Set{Any}([PDMat]) + +maybe_wrap_scalar(x) = x +maybe_wrap_scalar(x::UniformScaling) = Ref(x) + +@inline function broadcast_and_materialize(f, args...) + wrapped_args = map(maybe_wrap_scalar, args) + result = broadcasted(f, wrapped_args...) + if result isa Broadcasted + return Broadcast.materialize(result) + end + return result +end + +function resolve_to_type(ir::IR, val) + val isa Type && return val + if val isa Variable + stmt = ir[val] + if stmt !== nothing + return resolve_to_type(ir, stmt.expr) + end + end + if val isa GlobalRef + try + resolved = getfield(val.mod, val.name) + return resolved isa Type ? resolved : nothing + catch + return nothing + end + end + return nothing +end + +function is_type_ref(ir::IR, val) + return resolve_to_type(ir, val) !== nothing +end + +function is_broadcast_type(ir::IR, val) + resolved = resolve_to_type(ir, val) + return resolved !== nothing && resolved in BROADCAST_TYPES +end + +function transform_to_batched(ir::IR) + ir = copy(ir) + + for (v, stmt) in ir + if stmt.expr isa Expr && stmt.expr.head == :call + fn = stmt.expr.args[1] + if fn in SKIP_BROADCAST + continue + end + if is_broadcast_type(ir, fn) + new_args = [broadcast_and_materialize, stmt.expr.args...] + ir[v] = Statement(stmt; expr=Expr(:call, new_args...)) + continue + end + if is_type_ref(ir, fn) + new_args = [wrap_if_batched, stmt.expr.args...] + ir[v] = Statement(stmt; expr=Expr(:call, new_args...)) + continue + end + new_args = [broadcast_and_materialize, stmt.expr.args...] 
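+            # Fallback: rewrite a plain call f(args...) into
+            # broadcast_and_materialize(f, args...), which broadcasts over any
+            # batched arguments and materializes the result eagerly.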
+ ir[v] = Statement(stmt; expr=Expr(:call, new_args...)) + end + end + + return ir +end + +ir_element_type(::Type{T}) where {T} = T +ir_element_type(::Type{<:StructArray{T}}) where {T} = T + +function generate_batched_function(f, argtypes::Type{<:Tuple}) + element_types = Tuple{map(ir_element_type, argtypes.parameters)...} + + ir = IRTools.Inner.code_ir(f, element_types) + if ir === nothing + error( + "Could not get IR for function $f with types $element_types (original: $argtypes)", + ) + end + batched_ir = transform_to_batched(ir) + return IRTools.func(batched_ir) +end + +# ============================================================================= +# Broadcast Materialization +# ============================================================================= + +const BATCHED_FUNC_CACHE = Dict{Tuple,Any}() + +function Broadcast.materialize(bc::Broadcasted{BatchedStyle}) + f = bc.f + args = map(maybe_convert_ref, bc.args) + + result = broadcasted(f, args...) + if !(result isa Broadcasted) + return result + end + + argtypes = Tuple{map(typeof, args)...} + key = (f, argtypes) + + if !haskey(BATCHED_FUNC_CACHE, key) + println(" [Generating batched version of $f]") + batched_f = generate_batched_function(f, argtypes) + BATCHED_FUNC_CACHE[key] = batched_f + end + + batched_f = BATCHED_FUNC_CACHE[key] + return Base.invokelatest(batched_f, nothing, args...) +end diff --git a/GeneralisedFilters/src/batching/operations.jl b/GeneralisedFilters/src/batching/operations.jl new file mode 100644 index 00000000..8464fc61 --- /dev/null +++ b/GeneralisedFilters/src/batching/operations.jl @@ -0,0 +1,214 @@ +import PDMats: X_A_Xt + +# ============================================================================= +# Adjoint/Transpose Broadcasting +# ============================================================================= + +function broadcasted(::typeof(adjoint), A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} + return BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}(A.data) +end + +function broadcasted(::typeof(transpose), A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} + return BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}(A.data) +end + +function broadcasted(::typeof(adjoint), A::SharedCuMatrix{T,CuMatrix{T}}) where {T} + return SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}}(A.data) +end + +function broadcasted(::typeof(transpose), A::SharedCuMatrix{T,CuMatrix{T}}) where {T} + return SharedCuMatrix{T,Transpose{T,CuMatrix{T}}}(A.data) +end + +function broadcasted( + ::typeof(adjoint), A::BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}} +) where {T} + return BatchedCuMatrix{T,CuMatrix{T}}(A.data) +end + +function broadcasted( + ::typeof(adjoint), A::SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}} +) where {T} + return SharedCuMatrix{T,CuMatrix{T}}(A.data) +end + +function broadcasted(::Type{LowerTriangular}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} + return BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}(A.data) +end + +function broadcasted(::Type{UpperTriangular}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} + return BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}(A.data) +end + +# ============================================================================= +# Matrix Multiply Broadcasting +# ============================================================================= + +function broadcasted( + ::typeof(*), + A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + transA = trans_flag(A) + transB = trans_flag(B) + + A_inner = inner_size_for_blas(A) + B_inner = 
inner_size_for_blas(B) + + m = transA == 'N' ? A_inner[1] : A_inner[2] + n = transB == 'N' ? B_inner[2] : B_inner[1] + N = get_batch_size(A, B) + + C_data = CuArray{T}(undef, m, n, N) + C = BatchedCuMatrix(C_data) + + gemm_batched!(transA, transB, one(T), A, B, zero(T), C) + return C +end + +# Multi-argument multiply +function broadcasted( + ::typeof(*), + A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + C::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + rest::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}..., +) where {T} + result = broadcasted(*, A, B) + result = broadcasted(*, result, C) + for R in rest + result = broadcasted(*, result, R) + end + return result +end + +# ============================================================================= +# Matrix-Vector Multiply Broadcasting +# ============================================================================= + +function broadcasted( + ::typeof(*), + A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + x::Union{BatchedCuVector{T},SharedCuVector{T}}, +) where {T} + transA = trans_flag(A) + A_inner = inner_size_for_blas(A) + m = transA == 'N' ? A_inner[1] : A_inner[2] + N = get_batch_size(A, x) + + y_data = CuArray{T}(undef, m, N) + y = BatchedCuVector(y_data) + + gemv_batched!(transA, one(T), A, x, zero(T), y) + + return y +end + +# ============================================================================= +# Identity Minus Matrix (I - A) Custom Kernel +# ============================================================================= + +function identity_minus_kernel!(C, A, m, n) + batch_idx = blockIdx().z + i = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + j = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y + + if i <= m && j <= n + if i == j + @inbounds C[i, j, batch_idx] = one(eltype(C)) - A[i, j, batch_idx] + else + @inbounds C[i, j, batch_idx] = -A[i, j, batch_idx] + end + end + return nothing +end + +function identity_minus_batched!(C::CuArray{T,3}, A::CuArray{T,3}) where {T} + m, n, N = size(A) + threads = (16, 16) + blocks = (cld(m, 16), cld(n, 16), N) + @cuda threads = threads blocks = blocks identity_minus_kernel!(C, A, m, n) + return C +end + +function broadcasted( + ::typeof(-), ::Base.RefValue{UniformScaling{Bool}}, A::BatchedCuMatrix{T} +) where {T} + C = CuArray{T}(undef, size(A.data)) + identity_minus_batched!(C, A.data) + return BatchedCuMatrix(C) +end + +# ============================================================================= +# PDMat Broadcasting +# ============================================================================= + +function broadcasted(::Type{PDMat}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} + chol = cholesky_batched(A) + return BatchedPDMat{T}(chol) +end + +function broadcasted(::typeof(\), S::BatchedPDMat{T}, A::BatchedCuMatrix{T}) where {T} + return pdmat_solve(S, A) +end + +function broadcasted(::typeof(/), A::BatchedCuMatrix{T}, S::BatchedPDMat{T}) where {T} + # Need to actually transpose the data, not just wrap it + At_data = permutedims(A.data, (2, 1, 3)) + At = BatchedCuMatrix(At_data) + result_t = pdmat_solve(S, At) + # Transpose back + result_data = permutedims(result_t.data, (2, 1, 3)) + return BatchedCuMatrix(result_data) +end + +# ============================================================================= +# Quadratic Form Broadcasting +# ============================================================================= + +# HACK: treat this as two GEMMs for now +function broadcasted( + ::typeof(X_A_Xt), + 
A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + temp = broadcasted(*, X, A) + Xt = broadcasted(adjoint, X) + return broadcasted(*, temp, Xt) +end + +# X_A_Xt for BatchedPDMat: X * P * X' where P = L * L' +# Computed as (X * L) * (X * L)' using TRMM and SYRK +function broadcasted( + ::typeof(X_A_Xt), + P::BatchedPDMat{T}, + X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + L = P.chol.factors + N = get_batch_size(P, X) + + X_inner = inner_size_for_blas(X) + m = X_inner[1] + + # Copy X to XL (TRMM overwrites in-place) + XL_data = if X isa SharedCuMatrix + repeat(reshape(X.data, size(X.data, 1), size(X.data, 2), 1), 1, 1, N) + else + copy(X.data) + end + XL = BatchedCuMatrix(XL_data) + + # XL = X * L using TRMM (side='R' for right multiply, uplo='L' for lower triangular) + L_data = BatchedCuMatrix(L.data) + trmm_batched!('R', 'L', 'N', 'N', one(T), L_data, XL) + + # Result = XL * XL' using SYRK (fills lower triangle) + Result_data = CuArray{T}(undef, m, m, N) + Result = BatchedCuMatrix(Result_data) + syrk_batched!('L', 'N', one(T), XL, zero(T), Result) + + # Symmetrize: copy lower triangle to upper + symmetrize_lower!(Result) + + return Result +end diff --git a/GeneralisedFilters/src/batching/types.jl b/GeneralisedFilters/src/batching/types.jl new file mode 100644 index 00000000..7ae2b399 --- /dev/null +++ b/GeneralisedFilters/src/batching/types.jl @@ -0,0 +1,222 @@ +using CUDA +using LinearAlgebra: + Adjoint, Transpose, LowerTriangular, UpperTriangular, UniformScaling, Cholesky +using PDMats: PDMat + +export BatchedCuMatrix, BatchedCuVector +export SharedCuMatrix, SharedCuVector +export BatchedPDMat, BatchedCholesky + +# ============================================================================= +# Core Batched Types +# ============================================================================= + +struct BatchedCuMatrix{T,Inner<:AbstractMatrix{T}} <: AbstractVector{Inner} + data::CuArray{T,3} +end + +struct BatchedCuVector{T,Inner<:AbstractVector{T}} <: AbstractVector{Inner} + data::CuMatrix{T} +end + +const BatchedArray = Union{BatchedCuVector,BatchedCuMatrix} + +BatchedCuMatrix(data::CuArray{T,3}) where {T} = BatchedCuMatrix{T,CuMatrix{T}}(data) +BatchedCuVector(data::CuMatrix{T}) where {T} = BatchedCuVector{T,CuVector{T}}(data) + +batch_size(x::BatchedCuVector) = size(x.data, 2) +batch_size(x::BatchedCuMatrix) = size(x.data, 3) + +Base.size(x::BatchedCuVector) = (batch_size(x),) +Base.size(x::BatchedCuMatrix) = (batch_size(x),) +Base.length(x::BatchedArray) = batch_size(x) + +inner_size(x::BatchedCuVector) = (size(x.data, 1),) +inner_size(x::BatchedCuMatrix) = (size(x.data, 1), size(x.data, 2)) + +function Base.getindex(x::BatchedCuVector{T,CuVector{T}}, i::Int) where {T} + return view(x.data, :, i) +end + +function Base.getindex(x::BatchedCuMatrix{T,CuMatrix{T}}, i::Int) where {T} + return view(x.data, :, :, i) +end + +function Base.getindex( + x::BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, i::Int +) where {T} + return LowerTriangular(view(x.data, :, :, i)) +end + +function Base.getindex( + x::BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}, i::Int +) where {T} + return UpperTriangular(view(x.data, :, :, i)) +end + +function Base.getindex(x::BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}, i::Int) where {T} + return adjoint(view(x.data, :, :, i)) +end + +function Base.getindex(x::BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}, i::Int) where {T} + return transpose(view(x.data, :, :, i)) +end + +# 
============================================================================= +# Shared Types (same data reused across all batch elements) +# ============================================================================= + +struct SharedCuMatrix{T,Inner<:AbstractMatrix{T}} <: AbstractVector{Inner} + data::CuMatrix{T} +end + +struct SharedCuVector{T,Inner<:AbstractVector{T}} <: AbstractVector{Inner} + data::CuVector{T} +end + +const SharedArray = Union{SharedCuVector,SharedCuMatrix} + +SharedCuMatrix(data::CuMatrix{T}) where {T} = SharedCuMatrix{T,CuMatrix{T}}(data) +SharedCuVector(data::CuVector{T}) where {T} = SharedCuVector{T,CuVector{T}}(data) + +Shared(x::CuMatrix{T}) where {T} = SharedCuMatrix(x) +Shared(x::CuVector{T}) where {T} = SharedCuVector(x) + +batch_size(::SharedCuVector) = nothing +batch_size(::SharedCuMatrix) = nothing + +inner_size(x::SharedCuVector) = size(x.data) +inner_size(x::SharedCuMatrix) = size(x.data) + +Base.size(x::SharedCuVector) = (1,) +Base.size(x::SharedCuMatrix) = (1,) +Base.length(::SharedArray) = 1 + +Base.getindex(x::SharedCuVector, ::Int) = x.data +Base.getindex(x::SharedCuMatrix{T,CuMatrix{T}}, ::Int) where {T} = x.data +function Base.getindex(x::SharedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, ::Int) where {T} + return LowerTriangular(x.data) +end + +# ============================================================================= +# Type Aliases and Union Types for Dispatch +# ============================================================================= + +const AnyBatchedMatrix{T} = Union{ + BatchedCuMatrix{T,CuMatrix{T}}, + BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}, + BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}, + BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, + BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}, +} + +const AnySharedMatrix{T} = Union{ + SharedCuMatrix{T,CuMatrix{T}}, + SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}}, + SharedCuMatrix{T,Transpose{T,CuMatrix{T}}}, + SharedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, + SharedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}, +} + +const AnyMatrix{T} = Union{AnyBatchedMatrix{T},AnySharedMatrix{T}} +const AnyVector{T} = Union{BatchedCuVector{T},SharedCuVector{T}} + +# ============================================================================= +# Helper Functions +# ============================================================================= + +is_shared(::BatchedCuMatrix) = false +is_shared(::BatchedCuVector) = false +is_shared(::SharedCuMatrix) = true +is_shared(::SharedCuVector) = true + +unwrap_data(A::BatchedCuMatrix) = A.data +unwrap_data(A::SharedCuMatrix) = A.data +unwrap_data(x::BatchedCuVector) = x.data +unwrap_data(x::SharedCuVector) = x.data + +trans_flag(::BatchedCuMatrix{T,CuMatrix{T}}) where {T} = 'N' +trans_flag(::BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}) where {T} = T <: Real ? 'T' : 'C' +trans_flag(::BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}) where {T} = 'T' +trans_flag(::BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}) where {T} = 'N' +trans_flag(::BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}) where {T} = 'N' + +trans_flag(::SharedCuMatrix{T,CuMatrix{T}}) where {T} = 'N' +trans_flag(::SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}}) where {T} = T <: Real ? 
'T' : 'C' +trans_flag(::SharedCuMatrix{T,Transpose{T,CuMatrix{T}}}) where {T} = 'T' +trans_flag(::SharedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}) where {T} = 'N' +trans_flag(::SharedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}) where {T} = 'N' + +function inner_size_for_blas(A::BatchedCuMatrix) + m, n = size(A.data, 1), size(A.data, 2) + return (m, n) +end + +function inner_size_for_blas(A::SharedCuMatrix) + m, n = size(A.data) + return (m, n) +end + +function get_batch_size(args...) + for arg in args + bs = batch_size(arg) + if bs !== nothing + return bs + end + end + return error("At least one argument must be batched") +end + +# ============================================================================= +# Stateful Wrapper Types (Cholesky, PDMat) +# ============================================================================= + +struct BatchedCholesky{T} <: AbstractVector{Cholesky{T,CuMatrix{T}}} + factors::BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}} + info::CuVector{Int32} + uplo::Char +end + +struct BatchedPDMat{T} <: AbstractVector{PDMat{T,CuMatrix{T}}} + chol::BatchedCholesky{T} +end + +batch_size(c::BatchedCholesky) = batch_size(c.factors) +batch_size(p::BatchedPDMat) = batch_size(p.chol) + +inner_size(c::BatchedCholesky) = inner_size(c.factors) +inner_size(p::BatchedPDMat) = inner_size(p.chol) + +Base.size(c::BatchedCholesky) = (batch_size(c),) +Base.size(p::BatchedPDMat) = (batch_size(p),) + +Base.getindex(p::BatchedPDMat{T}, i::Int) where {T} = p.chol.factors[i] * p.chol.factors[i]' + +# ============================================================================= +# Pointer Array Creation +# ============================================================================= + +function create_pointer_array(A::BatchedCuMatrix{T}) where {T} + return CUDA.CUBLAS.unsafe_strided_batch(A.data) +end + +function create_pointer_array(A::SharedCuMatrix{T}, N::Int) where {T} + base_ptr = pointer(A.data) + ptrs_cpu = fill(base_ptr, N) + return CuArray(ptrs_cpu) +end + +function create_pointer_array_vector(x::BatchedCuVector{T}) where {T} + n = size(x.data, 1) + N = size(x.data, 2) + base_ptr = pointer(x.data) + stride = n * sizeof(T) + ptrs = CuArray([base_ptr + (i - 1) * stride for i in 1:N]) + return ptrs +end + +function create_pointer_array_vector(x::SharedCuVector{T}, N::Int) where {T} + base_ptr = pointer(x.data) + ptrs_cpu = fill(base_ptr, N) + return CuArray(ptrs_cpu) +end diff --git a/GeneralisedFilters/src/batching/wrappers.jl b/GeneralisedFilters/src/batching/wrappers.jl new file mode 100644 index 00000000..6f7e741b --- /dev/null +++ b/GeneralisedFilters/src/batching/wrappers.jl @@ -0,0 +1,720 @@ +using Magma +using Magma.LibMagma + +# ============================================================================= +# Trivial Wrappers (reductions and elementwise operations) +# ============================================================================= + +function broadcasted( + ::typeof(+), + A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + if is_shared(A) && is_shared(B) + return SharedCuMatrix(A.data .+ B.data) + else + return BatchedCuMatrix(A.data .+ B.data) + end +end + +function broadcasted( + ::typeof(+), + a::Union{BatchedCuVector{T},SharedCuVector{T}}, + b::Union{BatchedCuVector{T},SharedCuVector{T}}, +) where {T} + if is_shared(a) && is_shared(b) + return SharedCuVector(a.data .+ b.data) + else + return BatchedCuVector(a.data .+ b.data) + end +end + +function broadcasted( + ::typeof(-), + 
A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + if is_shared(A) && is_shared(B) + return SharedCuMatrix(A.data .- B.data) + else + return BatchedCuMatrix(A.data .- B.data) + end +end + +function broadcasted( + ::typeof(-), + a::Union{BatchedCuVector{T},SharedCuVector{T}}, + b::Union{BatchedCuVector{T},SharedCuVector{T}}, +) where {T} + if is_shared(a) && is_shared(b) + return SharedCuVector(a.data .- b.data) + else + return BatchedCuVector(a.data .- b.data) + end +end + +# ============================================================================= +# MAGMA Constants Conversion +# ============================================================================= + +function magma_trans(c::Char) + if c == 'N' + return LibMagma.MagmaNoTrans + elseif c == 'T' + return LibMagma.MagmaTrans + elseif c == 'C' + return LibMagma.MagmaConjTrans + else + error("Unknown transpose char: $c") + end +end + +function magma_uplo(c::Char) + if c == 'L' + return LibMagma.MagmaLower + elseif c == 'U' + return LibMagma.MagmaUpper + else + error("Unknown uplo char: $c") + end +end + +function magma_side(c::Char) + if c == 'L' + return LibMagma.MagmaLeft + elseif c == 'R' + return LibMagma.MagmaRight + else + error("Unknown side char: $c") + end +end + +function magma_diag(c::Char) + if c == 'N' + return LibMagma.MagmaNonUnit + elseif c == 'U' + return LibMagma.MagmaUnit + else + error("Unknown diag char: $c") + end +end + +# ============================================================================= +# MAGMA Operations +# ============================================================================= + +function gemm_batched!( + transA::Char, + transB::Char, + alpha::Float32, + A::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + B::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + beta::Float32, + C::BatchedCuMatrix{Float32}, +) + N = batch_size(C) + m, n = size(C.data, 1), size(C.data, 2) + k = transA == 'N' ? size(unwrap_data(A), 2) : size(unwrap_data(A), 1) + + dA = A isa BatchedCuMatrix ? create_pointer_array(A) : create_pointer_array(A, N) + dB = B isa BatchedCuMatrix ? create_pointer_array(B) : create_pointer_array(B, N) + dC = create_pointer_array(C) + + ldda = size(unwrap_data(A), 1) + lddb = size(unwrap_data(B), 1) + lddc = m + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magma_sgemm_batched( + magma_trans(transA), + magma_trans(transB), + m, + n, + k, + alpha, + dA, + ldda, + dB, + lddb, + beta, + dC, + lddc, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + CUDA.unsafe_free!(dC) + + return C +end + +# ============================================================================= +# Part 3b: Batched GEMM Small Square Wrapper (for D < 32) +# ============================================================================= + +function gemm_batched_smallsq!( + transA::Char, + transB::Char, + alpha::Float32, + A::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + B::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + beta::Float32, + C::BatchedCuMatrix{Float32}, +) + N = batch_size(C) + m, n = size(C.data, 1), size(C.data, 2) + k = transA == 'N' ? size(unwrap_data(A), 2) : size(unwrap_data(A), 1) + + dA = A isa BatchedCuMatrix ? 
create_pointer_array(A) : create_pointer_array(A, N) + dB = B isa BatchedCuMatrix ? create_pointer_array(B) : create_pointer_array(B, N) + dC = create_pointer_array(C) + + ldda = size(unwrap_data(A), 1) + lddb = size(unwrap_data(B), 1) + lddc = m + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_sgemm_batched_smallsq( + magma_trans(transA), + magma_trans(transB), + m, + n, + k, + alpha, + dA, + 0, # ai + 0, # aj + ldda, + dB, + 0, # bi + 0, # bj + lddb, + beta, + dC, + 0, # ci + 0, # cj + lddc, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + CUDA.unsafe_free!(dC) + + return C +end + +function gemm_batched_smallsq!( + transA::Char, + transB::Char, + alpha::Float64, + A::Union{BatchedCuMatrix{Float64},SharedCuMatrix{Float64}}, + B::Union{BatchedCuMatrix{Float64},SharedCuMatrix{Float64}}, + beta::Float64, + C::BatchedCuMatrix{Float64}, +) + N = batch_size(C) + m, n = size(C.data, 1), size(C.data, 2) + k = transA == 'N' ? size(unwrap_data(A), 2) : size(unwrap_data(A), 1) + + dA = A isa BatchedCuMatrix ? create_pointer_array(A) : create_pointer_array(A, N) + dB = B isa BatchedCuMatrix ? create_pointer_array(B) : create_pointer_array(B, N) + dC = create_pointer_array(C) + + ldda = size(unwrap_data(A), 1) + lddb = size(unwrap_data(B), 1) + lddc = m + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_dgemm_batched_smallsq( + magma_trans(transA), + magma_trans(transB), + m, + n, + k, + alpha, + dA, + 0, # ai + 0, # aj + ldda, + dB, + 0, # bi + 0, # bj + lddb, + beta, + dC, + 0, # ci + 0, # cj + lddc, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + CUDA.unsafe_free!(dC) + + return C +end + +function gemv_batched!( + transA::Char, + alpha::Float32, + A::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + x::Union{BatchedCuVector{Float32},SharedCuVector{Float32}}, + beta::Float32, + y::BatchedCuVector{Float32}, +) + N = batch_size(y) + m, n = size(unwrap_data(A), 1), size(unwrap_data(A), 2) + + dA = A isa BatchedCuMatrix ? 
create_pointer_array(A) : create_pointer_array(A, N) + dx = if x isa BatchedCuVector + create_pointer_array_vector(x) + else + create_pointer_array_vector(x, N) + end + dy = create_pointer_array_vector(y) + + ldda = m + incx = 1 + incy = 1 + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_sgemv_batched( + magma_trans(transA), m, n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue_ptr[] + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dx) + CUDA.unsafe_free!(dy) + + return y +end + +function gemv_batched!( + transA::Char, + alpha::Float64, + A::Union{BatchedCuMatrix{Float64},SharedCuMatrix{Float64}}, + x::Union{BatchedCuVector{Float64},SharedCuVector{Float64}}, + beta::Float64, + y::BatchedCuVector{Float64}, +) + N = batch_size(y) + m, n = size(unwrap_data(A), 1), size(unwrap_data(A), 2) + + dA = A isa BatchedCuMatrix ? create_pointer_array(A) : create_pointer_array(A, N) + dx = if x isa BatchedCuVector + create_pointer_array_vector(x) + else + create_pointer_array_vector(x, N) + end + dy = create_pointer_array_vector(y) + + ldda = m + incx = 1 + incy = 1 + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_dgemv_batched( + magma_trans(transA), m, n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue_ptr[] + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dx) + CUDA.unsafe_free!(dy) + + return y +end + +function gemv_batched_smallsq!( + transA::Char, + alpha::Float32, + A::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + x::Union{BatchedCuVector{Float32},SharedCuVector{Float32}}, + beta::Float32, + y::BatchedCuVector{Float32}, +) + N = batch_size(y) + n = size(unwrap_data(A), 1) + + dA = A isa BatchedCuMatrix ? 
create_pointer_array(A) : create_pointer_array(A, N) + dx = if x isa BatchedCuVector + create_pointer_array_vector(x) + else + create_pointer_array_vector(x, N) + end + dy = create_pointer_array_vector(y) + + ldda = n + incx = 1 + incy = 1 + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_sgemv_batched_smallsq( + magma_trans(transA), n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue_ptr[] + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dx) + CUDA.unsafe_free!(dy) + + return y +end + +function potrf_batched!(uplo::Char, A::BatchedCuMatrix{Float32}) + N = batch_size(A) + n = size(A.data, 1) + lda = n + + dA = create_pointer_array(A) + info_gpu = CUDA.zeros(Int64, N) + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magma_spotrf_batched( + magma_uplo(uplo), n, dA, lda, pointer(info_gpu), N, queue_ptr[] + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + + factors = BatchedCuMatrix{Float32,LowerTriangular{Float32,CuMatrix{Float32}}}(A.data) + return BatchedCholesky{Float32}(factors, info_gpu, uplo) +end + +function potrs_batched!( + uplo::Char, A::BatchedCuMatrix{Float32}, B::BatchedCuMatrix{Float32} +) + N = batch_size(B) + n = size(A.data, 1) + nrhs = size(B.data, 2) + + dA = create_pointer_array(A) + dB = create_pointer_array(B) + + ldda = n + lddb = n + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magma_spotrs_batched( + magma_uplo(uplo), n, nrhs, dA, ldda, dB, lddb, N, queue_ptr[] + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + + return B +end + +function trsm_batched!( + side::Char, + uplo::Char, + transA::Char, + diag::Char, + alpha::Float32, + A::BatchedCuMatrix{Float32}, + B::BatchedCuMatrix{Float32}, +) + N = batch_size(B) + m, n = size(B.data, 1), size(B.data, 2) + + dA = create_pointer_array(A) + dB = create_pointer_array(B) + + ldda = size(A.data, 1) + lddb = m + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_strsm_batched( + magma_side(side), + magma_uplo(uplo), + magma_trans(transA), + magma_diag(diag), + m, + n, + alpha, + dA, + ldda, + dB, + lddb, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + + return B +end + +# ============================================================================= +# Higher-level Cholesky Operations +# ============================================================================= + +function cholesky_batched(A::BatchedCuMatrix{T}) where {T} + A_copy = BatchedCuMatrix(copy(A.data)) + return potrf_batched!('L', A_copy) +end + +function pdmat_solve(S::BatchedPDMat{T}, B::BatchedCuMatrix{T}) where {T} + L = S.chol.factors + L_data = BatchedCuMatrix(L.data) + + 
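+    # Work on a copy of B: the two triangular solves below overwrite their
+    # right-hand side in place.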
B_copy = BatchedCuMatrix(copy(B.data)) + + # Solve L*L'*X = B via two triangular solves: + # 1. Solve L*Y = B (Y stored in B_copy) + trsm_batched!('L', 'L', 'N', 'N', one(T), L_data, B_copy) + # 2. Solve L'*X = Y (X stored in B_copy) + trsm_batched!('L', 'L', 'T', 'N', one(T), L_data, B_copy) + + return B_copy +end + +# ============================================================================= +# Batched TRMM (Triangular Matrix Multiply) +# ============================================================================= + +function trmm_batched!( + side::Char, + uplo::Char, + transA::Char, + diag::Char, + alpha::Float32, + A::BatchedCuMatrix{Float32}, + B::BatchedCuMatrix{Float32}, +) + N = batch_size(B) + m, n = size(B.data, 1), size(B.data, 2) + + dA = create_pointer_array(A) + dB = create_pointer_array(B) + + ldda = size(A.data, 1) + lddb = m + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_strmm_batched( + magma_side(side), + magma_uplo(uplo), + magma_trans(transA), + magma_diag(diag), + m, + n, + alpha, + dA, + ldda, + dB, + lddb, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + + return B +end + +function trmm_batched!( + side::Char, + uplo::Char, + transA::Char, + diag::Char, + alpha::Float64, + A::BatchedCuMatrix{Float64}, + B::BatchedCuMatrix{Float64}, +) + N = batch_size(B) + m, n = size(B.data, 1), size(B.data, 2) + + dA = create_pointer_array(A) + dB = create_pointer_array(B) + + ldda = size(A.data, 1) + lddb = m + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_dtrmm_batched( + magma_side(side), + magma_uplo(uplo), + magma_trans(transA), + magma_diag(diag), + m, + n, + alpha, + dA, + ldda, + dB, + lddb, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + + return B +end + +# ============================================================================= +# Batched SYRK (Symmetric Rank-K Update) +# ============================================================================= + +function syrk_batched!( + uplo::Char, + trans::Char, + alpha::Float32, + A::BatchedCuMatrix{Float32}, + beta::Float32, + C::BatchedCuMatrix{Float32}, +) + N = batch_size(C) + n = size(C.data, 1) + k = trans == 'N' ? size(A.data, 2) : size(A.data, 1) + + dA = create_pointer_array(A) + dC = create_pointer_array(C) + + ldda = size(A.data, 1) + lddc = n + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_ssyrk_batched( + magma_uplo(uplo), + magma_trans(trans), + n, + k, + alpha, + dA, + ldda, + beta, + dC, + lddc, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dC) + + return C +end + +function syrk_batched!( + uplo::Char, + trans::Char, + alpha::Float64, + A::BatchedCuMatrix{Float64}, + beta::Float64, + C::BatchedCuMatrix{Float64}, +) + N = batch_size(C) + n = size(C.data, 1) + k = trans == 'N' ? 
size(A.data, 2) : size(A.data, 1) + + dA = create_pointer_array(A) + dC = create_pointer_array(C) + + ldda = size(A.data, 1) + lddc = n + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_dsyrk_batched( + magma_uplo(uplo), + magma_trans(trans), + n, + k, + alpha, + dA, + ldda, + beta, + dC, + lddc, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dC) + + return C +end + +# ============================================================================= +# Symmetrize Lower Triangular Matrix +# ============================================================================= + +function symmetrize_lower_kernel!(A, n) + batch_idx = blockIdx().z + i = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + j = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y + + if i <= n && j <= n && j > i + @inbounds A[j, i, batch_idx] = A[i, j, batch_idx] + end + return nothing +end + +function symmetrize_lower!(A::BatchedCuMatrix{T}) where {T} + n = size(A.data, 1) + N = size(A.data, 3) + threads = (16, 16) + blocks = (cld(n, 16), cld(n, 16), N) + @cuda threads = threads blocks = blocks symmetrize_lower_kernel!(A.data, n) + return A +end diff --git a/research/batching/batching_demo.jl b/research/batching/batching_demo.jl new file mode 100644 index 00000000..a97effce --- /dev/null +++ b/research/batching/batching_demo.jl @@ -0,0 +1,113 @@ +using GeneralisedFilters + +using Distributions +using LinearAlgebra +using Base.Broadcast: broadcasted +using PDMats +using StructArrays +using BenchmarkTools + +using CUDA +using Magma +using Magma.LibMagma + +Magma.magma_init() + +# ============================================================================= +# Configuration +# ============================================================================= + +D_state = 64 +D_obs = 64 +N = 1000 + +function kalman_predict(state, dyn_params) + A = dyn_params[1] + b = dyn_params[2] + Q = dyn_params[3] + + μ̂ = A * state.μ + b + Σ̂ = X_A_Xt(state.Σ, A) + Q + return MvNormal(μ̂, Σ̂) +end + +I_mat = CuArray{Float32}(I, D_state, D_state) +Is = SharedCuMatrix(I_mat) + +μs = BatchedCuVector(CUDA.randn(Float32, D_state, N)) +Σs_root = BatchedCuMatrix(CUDA.randn(Float32, D_state, D_state, N)) +Σs = Σs_root .* adjoint.(Σs_root) .+ Is + +As = SharedCuMatrix(CUDA.randn(Float32, D_state, D_state)) +bs = BatchedCuVector(CUDA.randn(Float32, D_state, N)) +Q_root = CUDA.randn(Float32, D_state, D_state) +Q = Q_root * Q_root' + I +Qs = SharedCuMatrix(Q) + +Σ_PDs = broadcasted(PDMat, Σs); +Gs = StructArray{MvNormal}((μ=μs, Σ=Σ_PDs)); + +function kalman_predict(state, dyn_params) + A = dyn_params[1] + b = dyn_params[2] + Q = dyn_params[3] + + μ̂ = A * state.μ + b + Σ̂ = X_A_Xt(state.Σ, A) + Q + return MvNormal(μ̂, Σ̂) +end + +dyn_params = (As, bs, Qs) +pred_Gs = kalman_predict.(Gs, Ref(dyn_params)); + +# Compare to CPU +μ_test = Array(μs[end]) +Σ_test = Array(Σs[end]) +A_test = Array(As.data) +b_test = Array(bs[end]) +Q_test = Array(Qs.data) +pred_G_test = kalman_predict(MvNormal(μ_test, PDMat(Σ_test)), (A_test, b_test, Q_test)) + +println("CPU Mean: ", pred_G_test.μ[1:5]) +println("GPU Mean: ", Array(pred_Gs[end].μ[1:5])) + +println("CPU Covariance Diagonal: ", diag(pred_G_test.Σ)[1:5]) +println("GPU Covariance Diagonal: ", Array(diag(pred_Gs[end].Σ))[1:5]) + +# Increase 
batch size and benchmark +D_large = 32 +N_large = 10000 +μs_large = BatchedCuVector(CUDA.randn(Float32, D_large, N_large)) +Σs_root_large = BatchedCuMatrix(CUDA.randn(Float32, D_large, D_large, N_large)) +Σs_large = Σs_root_large .* adjoint.(Σs_root_large) .+ SharedCuMatrix(CuArray{Float32}(I, D_large, D_large)) +Σ_PDs_large = broadcasted(PDMat, Σs_large); +Gs_large = StructArray{MvNormal}((μ=μs_large, Σ=Σ_PDs_large)); +dyn_params_large = ( + SharedCuMatrix(CUDA.randn(Float32, D_large, D_large)), + BatchedCuVector(CUDA.randn(Float32, D_large, N_large)), + SharedCuMatrix((CUDA.randn(Float32, D_large, D_large) * CUDA.randn(Float32, D_large, D_large)') .+ CuArray{Float32}(I, D_large, D_large)), +) +display(@benchmark kalman_predict.($Gs_large, Ref($dyn_params_large))) + +# Compare to multithreading StaticArrays +using StaticArrays +μs_static = [SVector{D_large, Float32}(randn(Float32, D_large)) for _ in 1:N_large]; +Σs_root_static = [SMatrix{D_large,D_large,Float32}(randn(Float32, D_large, D_large)) for _ in 1:N_large]; +Σs_static = [Σs_root_static[i] * adjoint(Σs_root_static[i]) + I for i in 1:N_large]; +Gs_static = [MvNormal(μs_static[i], Σs_static[i]) for i in 1:N_large]; +A_static = SMatrix{D_large,D_large,Float32}(randn(Float32, D_large, D_large)); +b_static = [SVector{D_large, Float32}(randn(Float32, D_large)) for _ in 1:N_large]; +Q_root_static = SMatrix{D_large,D_large,Float32}(randn(Float32, D_large, D_large)); +Q_static = Q_root_static * adjoint(Q_root_static) + I; + +function test_static(Gs, A, b, Q) + out = Vector{MvNormal{Float32, PDMat{Float32, SMatrix{32, 32, Float32, 1024}}, SVector{32, Float32}}}(undef, length(Gs)) + for i in 1:length(Gs) + @inbounds out[i] = kalman_predict(Gs[i], (A, b[i], Q)) + end + return out +end + +display(@benchmark test_static($Gs_static, $A_static, $b_static, $Q_static)) + +@profview test_static(Gs_static, A_static, b_static, Q_static) From 70b0dc5ed001a16430d752b25c71bb8d49c71a7b Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 6 Jan 2026 16:08:47 +0000 Subject: [PATCH 02/29] Extend demo to full Kalman filter --- research/batching/batching_demo.jl | 122 ++++++++++++++++++++--------- 1 file changed, 85 insertions(+), 37 deletions(-) diff --git a/research/batching/batching_demo.jl b/research/batching/batching_demo.jl index a97effce..d2841061 100644 --- a/research/batching/batching_demo.jl +++ b/research/batching/batching_demo.jl @@ -68,46 +68,94 @@ b_test = Array(bs[end]) Q_test = Array(Qs.data) pred_G_test = kalman_predict(MvNormal(μ_test, PDMat(Σ_test)), (A_test, b_test, Q_test)) +println("=== Predict Comparison ===") println("CPU Mean: ", pred_G_test.μ[1:5]) println("GPU Mean: ", Array(pred_Gs[end].μ[1:5])) -println("CPU Covariance Diagonal: ", diag(pred_G_test.Σ)[1:5]) -println("GPU Covariance Diagonal: ", Array(diag(pred_Gs[end].Σ))[1:5]) - -# Increase batch size and benchmark -D_large = 32 -N_large = 10000 -μs_large = BatchedCuVector(CUDA.randn(Float32, D_large, N_large)) -Σs_root_large = BatchedCuMatrix(CUDA.randn(Float32, D_large, D_large, N_large)) -Σs_large = Σs_root_large .* adjoint.(Σs_root_large) .+ SharedCuMatrix(CuArray{Float32}(I, D_large, D_large)) -Σ_PDs_large = broadcasted(PDMat, Σs_large); -Gs_large = StructArray{MvNormal}((μ=μs_large, Σ=Σ_PDs_large)); -dyn_params_large = ( - SharedCuMatrix(CUDA.randn(Float32, D_large, D_large)), - BatchedCuVector(CUDA.randn(Float32, D_large, N_large)), - SharedCuMatrix((CUDA.randn(Float32, D_large, D_large) * CUDA.randn(Float32, D_large, D_large)') .+ CuArray{Float32}(I, D_large, 
D_large)), -) -display(@benchmark kalman_predict.($Gs_large, Ref($dyn_params_large))) - -# Compare to multithreading StaticArrays -using StaticArrays -μs_static = [SVector{D_large, Float32}(randn(Float32, D_large)) for _ in 1:N_large]; -Σs_root_static = [SMatrix{D_large,D_large,Float32}(randn(Float32, D_large, D_large)) for _ in 1:N_large]; -Σs_static = [Σs_root_static[i] * adjoint(Σs_root_static[i]) + I for i in 1:N_large]; -Gs_static = [MvNormal(μs_static[i], Σs_static[i]) for i in 1:N_large]; -A_static = SMatrix{D_large,D_large,Float32}(randn(Float32, D_large, D_large)); -b_static = [SVector{D_large, Float32}(randn(Float32, D_large)) for _ in 1:N_large]; -Q_root_static = SMatrix{D_large,D_large,Float32}(randn(Float32, D_large, D_large)); -Q_static = Q_root_static * adjoint(Q_root_static) + I; - -function test_static(Gs, A, b, Q) - out = Vector{MvNormal{Float32, PDMat{Float32, SMatrix{32, 32, Float32, 1024}}, SVector{32, Float32}}}(undef, length(Gs)) - for i in 1:length(Gs) - @inbounds out[i] = kalman_predict(Gs[i], (A, b[i], Q)) - end - return out +println("CPU Covariance [1:3, 1:3]: ", Matrix(pred_G_test.Σ)[1:3, 1:3]) +println("GPU Covariance [1:3, 1:3]: ", Array(pred_Gs[end].Σ)[1:3, 1:3]) + +# ============================================================================= +# Kalman Update +# ============================================================================= + +function kalman_update(state, obs_params, observation) + μ = state.μ + Σ = state.Σ + H = obs_params[1] + c = obs_params[2] + R = obs_params[3] + + # Compute innovation distribution + m = H * μ + c + S = PDMat(X_A_Xt(Σ, H) + R) + ȳ = observation - m + + # Kalman gain + K = Σ * H' / S + + # Update parameters using Joseph form for numerical stability + μ̂ = μ + K * ȳ + Σ̂ = X_A_Xt(Σ, I - K * H) + X_A_Xt(R, K) + + return MvNormal(μ̂, Σ̂) end -display(@benchmark test_static($Gs_static, $A_static, $b_static, $Q_static)) +function kalman_step(state, dyn_params, obs_params, observation) + state = kalman_predict(state, dyn_params) + state = kalman_update(state, obs_params, observation) + return state +end + +# Observation parameters (H and c shared, R batched) +Hs = SharedCuMatrix(CUDA.randn(Float32, D_obs, D_state)) +cs = SharedCuVector(CUDA.randn(Float32, D_obs)) +I_obs = CuArray{Float32}(I, D_obs, D_obs) +I_obs_shared = SharedCuMatrix(I_obs) +Rs_root = BatchedCuMatrix(CUDA.randn(Float32, D_obs, D_obs, N)) +Rs = Rs_root .* adjoint.(Rs_root) .+ I_obs_shared + +obs_params = (Hs, cs, Rs) + +# Observations +observations = BatchedCuVector(CUDA.randn(Float32, D_obs, N)) + +# Run update on GPU +update_Gs = kalman_update.(pred_Gs, Ref(obs_params), observations); + +# Compare update to CPU +H_test = Array(Hs.data) +c_test = Array(cs.data) +R_test = PDMat(Array(Rs[end])) +obs_test = Array(observations[end]) + +update_G_test = kalman_update(pred_G_test, (H_test, c_test, R_test), obs_test) + +println("\n=== Update Comparison ===") +println("CPU Mean: ", update_G_test.μ[1:5]) +println("GPU Mean: ", Array(update_Gs.μ[end][1:5])) + +println("CPU Covariance [1:3, 1:3]: ", Matrix(update_G_test.Σ)[1:3, 1:3]) +println("GPU Covariance [1:3, 1:3]: ", Array(update_Gs.Σ[end])[1:3, 1:3]) + +# ============================================================================= +# Full Kalman Step +# ============================================================================= + +# Run full step on GPU (from original state) +step_Gs = kalman_step.(Gs, Ref(dyn_params), Ref(obs_params), observations); + +# Compare full step to CPU +step_G_test = kalman_step( + 
MvNormal(μ_test, PDMat(Σ_test)), + (A_test, b_test, Q_test), + (H_test, c_test, R_test), + obs_test, +) + +println("\n=== Full Step Comparison ===") +println("CPU Mean: ", step_G_test.μ[1:5]) +println("GPU Mean: ", Array(step_Gs.μ[end][1:5])) -@profview test_static(Gs_static, A_static, b_static, Q_static) +println("CPU Covariance [1:3, 1:3]: ", Matrix(step_G_test.Σ)[1:3, 1:3]) +println("GPU Covariance [1:3, 1:3]: ", Array(step_Gs.Σ[end])[1:3, 1:3]) From 7732369f3e5b337c1cbe0e82310b04672bbf4520 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 7 Jan 2026 08:29:57 +0000 Subject: [PATCH 03/29] Fix formatting --- GeneralisedFilters/src/batching/operations.jl | 4 +--- GeneralisedFilters/src/batching/types.jl | 10 +++++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/GeneralisedFilters/src/batching/operations.jl b/GeneralisedFilters/src/batching/operations.jl index 8464fc61..5d138155 100644 --- a/GeneralisedFilters/src/batching/operations.jl +++ b/GeneralisedFilters/src/batching/operations.jl @@ -180,9 +180,7 @@ end # X_A_Xt for BatchedPDMat: X * P * X' where P = L * L' # Computed as (X * L) * (X * L)' using TRMM and SYRK function broadcasted( - ::typeof(X_A_Xt), - P::BatchedPDMat{T}, - X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + ::typeof(X_A_Xt), P::BatchedPDMat{T}, X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}} ) where {T} L = P.chol.factors N = get_batch_size(P, X) diff --git a/GeneralisedFilters/src/batching/types.jl b/GeneralisedFilters/src/batching/types.jl index 7ae2b399..02be3464 100644 --- a/GeneralisedFilters/src/batching/types.jl +++ b/GeneralisedFilters/src/batching/types.jl @@ -39,27 +39,27 @@ function Base.getindex(x::BatchedCuVector{T,CuVector{T}}, i::Int) where {T} end function Base.getindex(x::BatchedCuMatrix{T,CuMatrix{T}}, i::Int) where {T} - return view(x.data, :, :, i) + return view(x.data,:,:,i) end function Base.getindex( x::BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, i::Int ) where {T} - return LowerTriangular(view(x.data, :, :, i)) + return LowerTriangular(view(x.data,:,:,i)) end function Base.getindex( x::BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}, i::Int ) where {T} - return UpperTriangular(view(x.data, :, :, i)) + return UpperTriangular(view(x.data,:,:,i)) end function Base.getindex(x::BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}, i::Int) where {T} - return adjoint(view(x.data, :, :, i)) + return adjoint(view(x.data,:,:,i)) end function Base.getindex(x::BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}, i::Int) where {T} - return transpose(view(x.data, :, :, i)) + return transpose(view(x.data,:,:,i)) end # ============================================================================= From b00cfefca91f5e1fde6ad0454889a2101b2a0594 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 7 Jan 2026 11:17:43 +0000 Subject: [PATCH 04/29] Implement generic wrapping system Removes parametric batched types and custom Cholesky/PDMat batched types and replaces them with a generic system based on StructArrays. 
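
A minimal sketch of the new behaviour (illustrative only; the 3x3 inner
dimension and batch size 16 are arbitrary, and this assumes the batching
module is loaded):

    using CUDA, StructArrays
    using LinearAlgebra

    A  = BatchedCuMatrix(CUDA.randn(Float32, 3, 3, 16))
    At = adjoint.(A)    # StructArray with element type Adjoint{Float32,<:CuMatrix}
    At.parent === A     # true: the batch is shared, not re-allocated

Because the wrapper kind now lives in the StructArray's element type, no
per-wrapper Batched* structs are needed and no data is copied.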
--- .../src/batching/broadcasting.jl | 57 ++--- GeneralisedFilters/src/batching/operations.jl | 204 ++++++++++-------- GeneralisedFilters/src/batching/types.jl | 122 ++--------- GeneralisedFilters/src/batching/wrappers.jl | 47 ++-- research/batching/wrappers_demo.jl | 146 +++++++++++++ 5 files changed, 338 insertions(+), 238 deletions(-) create mode 100644 research/batching/wrappers_demo.jl diff --git a/GeneralisedFilters/src/batching/broadcasting.jl b/GeneralisedFilters/src/batching/broadcasting.jl index 92d50811..8ffda2a0 100644 --- a/GeneralisedFilters/src/batching/broadcasting.jl +++ b/GeneralisedFilters/src/batching/broadcasting.jl @@ -18,9 +18,6 @@ Base.BroadcastStyle(::Type{<:BatchedCuMatrix}) = BatchedStyle() Base.BroadcastStyle(::Type{<:BatchedCuVector}) = BatchedStyle() Base.BroadcastStyle(::Type{<:SharedCuMatrix}) = BatchedStyle() Base.BroadcastStyle(::Type{<:SharedCuVector}) = BatchedStyle() -Base.BroadcastStyle(::Type{<:BatchedCholesky}) = BatchedStyle() -Base.BroadcastStyle(::Type{<:BatchedPDMat}) = BatchedStyle() -# HACK: Currently hard-coded but can be replaced with a custom StructArray type Base.BroadcastStyle(::Type{<:StructArray}) = BatchedStyle() Base.BroadcastStyle(::BatchedStyle, ::BatchedStyle) = BatchedStyle() Base.BroadcastStyle(::BatchedStyle, ::DefaultArrayStyle{0}) = BatchedStyle() @@ -30,12 +27,8 @@ Base.BroadcastStyle(::BatchedStyle, ::DefaultArrayStyle{0}) = BatchedStyle() # ============================================================================= maybe_convert_ref(x) = x -function maybe_convert_ref(r::Base.RefValue{<:CuVector{T}}) where {T} - return SharedCuVector{T,CuVector{T}}(r[]) -end -function maybe_convert_ref(r::Base.RefValue{<:CuMatrix{T}}) where {T} - return SharedCuMatrix{T,CuMatrix{T}}(r[]) -end +maybe_convert_ref(r::Base.RefValue{<:CuVector}) = SharedCuVector(r[]) +maybe_convert_ref(r::Base.RefValue{<:CuMatrix}) = SharedCuMatrix(r[]) # ============================================================================= # Structural Operations (Pass-through) @@ -54,15 +47,12 @@ broadcasted(::typeof(getfield), r::Base.RefValue, s::Symbol) = getfield(r[], s) # StructArray Wrapping # ============================================================================= -inner_eltype(arg::BatchedCuVector{T}) where {T} = CuVector{T} -inner_eltype(arg::BatchedCuMatrix{T}) where {T} = CuMatrix{T} -inner_eltype(arg::SharedCuVector{T}) where {T} = CuVector{T} -inner_eltype(arg::SharedCuMatrix{T}) where {T} = CuMatrix{T} -inner_eltype(arg::BatchedPDMat{T}) where {T} = PDMat{T,CuMatrix{T}} +inner_eltype(arg::BatchedOrShared) = eltype(arg) +inner_eltype(arg::StructArray) = eltype(arg) inner_eltype(arg) = typeof(arg) function wrap_if_batched(::Type{T}, args...) where {T} - if any(arg -> arg isa Union{BatchedArray,SharedArray,BatchedPDMat}, args) + if any(arg -> arg isa Union{BatchedOrShared,StructArray}, args) field_names = fieldnames(T) element_types = Tuple{map(inner_eltype, args)...} ElType = Core.Compiler.return_type(T, element_types) @@ -73,20 +63,37 @@ function wrap_if_batched(::Type{T}, args...) where {T} end end +# ============================================================================= +# Generic Wrapper Broadcasting +# ============================================================================= + +""" + broadcasted(::Type{W}, args::BatchedOrShared...) + +Generic wrapper for any type constructor applied to batched arguments. +Works for single-field wrappers (Adjoint, Transpose, LowerTriangular, etc.) +as well as multi-field types (PDMat, Cholesky, etc.). 
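+Constructor arguments are matched positionally to the field names of the
+constructed element type, so the number of arguments must equal the number
+of fields.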
+ +Returns a StructArray where each element is the type applied to the +corresponding elements of the input arrays. +""" +function broadcasted(::Type{W}, args::Vararg{BatchedOrShared}) where {W} + element_types = Tuple{map(eltype, args)...} + ElType = Core.Compiler.return_type(W, element_types) + field_names = fieldnames(ElType) + nt = NamedTuple{field_names}(args) + return StructArray{ElType}(nt) +end + +# Redirect function forms to type constructors +broadcasted(::typeof(adjoint), A::BatchedOrShared) = broadcasted(Adjoint, A) +broadcasted(::typeof(transpose), A::BatchedOrShared) = broadcasted(Transpose, A) + # ============================================================================= # IR Transformation # ============================================================================= -const SKIP_BROADCAST = Set{Any}([ - tuple, - Core.tuple, - getfield, - getproperty, - adjoint, - transpose, - LowerTriangular, - UpperTriangular, -]) +const SKIP_BROADCAST = Set{Any}([tuple, Core.tuple, getfield, getproperty]) const BROADCAST_TYPES = Set{Any}([PDMat]) diff --git a/GeneralisedFilters/src/batching/operations.jl b/GeneralisedFilters/src/batching/operations.jl index 5d138155..b1fc55d1 100644 --- a/GeneralisedFilters/src/batching/operations.jl +++ b/GeneralisedFilters/src/batching/operations.jl @@ -1,53 +1,71 @@ import PDMats: X_A_Xt # ============================================================================= -# Adjoint/Transpose Broadcasting +# GEMM-Compatible Types # ============================================================================= -function broadcasted(::typeof(adjoint), A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} - return BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}(A.data) -end - -function broadcasted(::typeof(transpose), A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} - return BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}(A.data) -end - -function broadcasted(::typeof(adjoint), A::SharedCuMatrix{T,CuMatrix{T}}) where {T} - return SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}}(A.data) -end - -function broadcasted(::typeof(transpose), A::SharedCuMatrix{T,CuMatrix{T}}) where {T} - return SharedCuMatrix{T,Transpose{T,CuMatrix{T}}}(A.data) -end - -function broadcasted( - ::typeof(adjoint), A::BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}} -) where {T} - return BatchedCuMatrix{T,CuMatrix{T}}(A.data) -end - -function broadcasted( - ::typeof(adjoint), A::SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}} -) where {T} - return SharedCuMatrix{T,CuMatrix{T}}(A.data) -end - -function broadcasted(::Type{LowerTriangular}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} - return BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}(A.data) -end - -function broadcasted(::Type{UpperTriangular}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} - return BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}(A.data) -end +# Type aliases for StructArray-wrapped matrices +const BatchedAdjoint{T,M} = StructArray{ + Adjoint{T,CuArray{T,2,M}},1,@NamedTuple{parent::BatchedCuMatrix{T,M}} +} +const BatchedTranspose{T,M} = StructArray{ + Transpose{T,CuArray{T,2,M}},1,@NamedTuple{parent::BatchedCuMatrix{T,M}} +} +const SharedAdjoint{T,M} = StructArray{ + Adjoint{T,CuArray{T,2,M}},1,@NamedTuple{parent::SharedCuMatrix{T,M}} +} +const SharedTranspose{T,M} = StructArray{ + Transpose{T,CuArray{T,2,M}},1,@NamedTuple{parent::SharedCuMatrix{T,M}} +} + +# Union of all GEMM-compatible matrix types +const GEMMCompatibleMatrix{T} = Union{ + BatchedCuMatrix{T}, + SharedCuMatrix{T}, + BatchedAdjoint{T}, + BatchedTranspose{T}, + SharedAdjoint{T}, + 
SharedTranspose{T}, +} + +# trans_flag: returns BLAS transpose flag for each type +trans_flag(::BatchedCuMatrix{T}) where {T} = 'N' +trans_flag(::SharedCuMatrix{T}) where {T} = 'N' +trans_flag(::BatchedAdjoint{T}) where {T} = T <: Real ? 'T' : 'C' +trans_flag(::BatchedTranspose{T}) where {T} = 'T' +trans_flag(::SharedAdjoint{T}) where {T} = T <: Real ? 'T' : 'C' +trans_flag(::SharedTranspose{T}) where {T} = 'T' + +# gemm_data: extracts the underlying BatchedCuMatrix/SharedCuMatrix for GEMM +gemm_data(A::BatchedCuMatrix) = A +gemm_data(A::SharedCuMatrix) = A +gemm_data(A::BatchedAdjoint) = A.parent +gemm_data(A::BatchedTranspose) = A.parent +gemm_data(A::SharedAdjoint) = A.parent +gemm_data(A::SharedTranspose) = A.parent + +# inner_size_for_blas for wrapped types (delegates to underlying data) +inner_size_for_blas(A::BatchedAdjoint) = inner_size_for_blas(A.parent) +inner_size_for_blas(A::BatchedTranspose) = inner_size_for_blas(A.parent) +inner_size_for_blas(A::SharedAdjoint) = inner_size_for_blas(A.parent) +inner_size_for_blas(A::SharedTranspose) = inner_size_for_blas(A.parent) + +# batch_size for wrapped types +batch_size(A::BatchedAdjoint) = batch_size(A.parent) +batch_size(A::BatchedTranspose) = batch_size(A.parent) +batch_size(A::SharedAdjoint) = batch_size(A.parent) +batch_size(A::SharedTranspose) = batch_size(A.parent) + +# TODO: For nested wrappers (e.g., Adjoint{LowerTriangular{...}}), we should +# materialize the inner wrapper first before extracting. For now, we only +# support single-level Adjoint/Transpose wrappers for efficient GEMM dispatch. # ============================================================================= # Matrix Multiply Broadcasting # ============================================================================= function broadcasted( - ::typeof(*), - A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, - B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + ::typeof(*), A::GEMMCompatibleMatrix{T}, B::GEMMCompatibleMatrix{T} ) where {T} transA = trans_flag(A) transB = trans_flag(B) @@ -62,17 +80,17 @@ function broadcasted( C_data = CuArray{T}(undef, m, n, N) C = BatchedCuMatrix(C_data) - gemm_batched!(transA, transB, one(T), A, B, zero(T), C) + gemm_batched!(transA, transB, one(T), gemm_data(A), gemm_data(B), zero(T), C) return C end # Multi-argument multiply function broadcasted( ::typeof(*), - A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, - B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, - C::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, - rest::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}..., + A::GEMMCompatibleMatrix{T}, + B::GEMMCompatibleMatrix{T}, + C::GEMMCompatibleMatrix{T}, + rest::GEMMCompatibleMatrix{T}..., ) where {T} result = broadcasted(*, A, B) result = broadcasted(*, result, C) @@ -143,24 +161,24 @@ end # PDMat Broadcasting # ============================================================================= -function broadcasted(::Type{PDMat}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} - chol = cholesky_batched(A) - return BatchedPDMat{T}(chol) -end - -function broadcasted(::typeof(\), S::BatchedPDMat{T}, A::BatchedCuMatrix{T}) where {T} - return pdmat_solve(S, A) -end - -function broadcasted(::typeof(/), A::BatchedCuMatrix{T}, S::BatchedPDMat{T}) where {T} - # Need to actually transpose the data, not just wrap it - At_data = permutedims(A.data, (2, 1, 3)) - At = BatchedCuMatrix(At_data) - result_t = pdmat_solve(S, At) - # Transpose back - result_data = permutedims(result_t.data, (2, 1, 3)) - return BatchedCuMatrix(result_data) -end +# function 
broadcasted(::Type{PDMat}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} +# chol = cholesky_batched(A) +# return BatchedPDMat{T}(chol) +# end + +# function broadcasted(::typeof(\), S::BatchedPDMat{T}, A::BatchedCuMatrix{T}) where {T} +# return pdmat_solve(S, A) +# end + +# function broadcasted(::typeof(/), A::BatchedCuMatrix{T}, S::BatchedPDMat{T}) where {T} +# # Need to actually transpose the data, not just wrap it +# At_data = permutedims(A.data, (2, 1, 3)) +# At = BatchedCuMatrix(At_data) +# result_t = pdmat_solve(S, At) +# # Transpose back +# result_data = permutedims(result_t.data, (2, 1, 3)) +# return BatchedCuMatrix(result_data) +# end # ============================================================================= # Quadratic Form Broadcasting @@ -179,34 +197,34 @@ end # X_A_Xt for BatchedPDMat: X * P * X' where P = L * L' # Computed as (X * L) * (X * L)' using TRMM and SYRK -function broadcasted( - ::typeof(X_A_Xt), P::BatchedPDMat{T}, X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}} -) where {T} - L = P.chol.factors - N = get_batch_size(P, X) - - X_inner = inner_size_for_blas(X) - m = X_inner[1] - - # Copy X to XL (TRMM overwrites in-place) - XL_data = if X isa SharedCuMatrix - repeat(reshape(X.data, size(X.data, 1), size(X.data, 2), 1), 1, 1, N) - else - copy(X.data) - end - XL = BatchedCuMatrix(XL_data) - - # XL = X * L using TRMM (side='R' for right multiply, uplo='L' for lower triangular) - L_data = BatchedCuMatrix(L.data) - trmm_batched!('R', 'L', 'N', 'N', one(T), L_data, XL) - - # Result = XL * XL' using SYRK (fills lower triangle) - Result_data = CuArray{T}(undef, m, m, N) - Result = BatchedCuMatrix(Result_data) - syrk_batched!('L', 'N', one(T), XL, zero(T), Result) - - # Symmetrize: copy lower triangle to upper - symmetrize_lower!(Result) - - return Result -end +# function broadcasted( +# ::typeof(X_A_Xt), P::BatchedPDMat{T}, X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}} +# ) where {T} +# L = P.chol.factors +# N = get_batch_size(P, X) + +# X_inner = inner_size_for_blas(X) +# m = X_inner[1] + +# # Copy X to XL (TRMM overwrites in-place) +# XL_data = if X isa SharedCuMatrix +# repeat(reshape(X.data, size(X.data, 1), size(X.data, 2), 1), 1, 1, N) +# else +# copy(X.data) +# end +# XL = BatchedCuMatrix(XL_data) + +# # XL = X * L using TRMM (side='R' for right multiply, uplo='L' for lower triangular) +# L_data = BatchedCuMatrix(L.data) +# trmm_batched!('R', 'L', 'N', 'N', one(T), L_data, XL) + +# # Result = XL * XL' using SYRK (fills lower triangle) +# Result_data = CuArray{T}(undef, m, m, N) +# Result = BatchedCuMatrix(Result_data) +# syrk_batched!('L', 'N', one(T), XL, zero(T), Result) + +# # Symmetrize: copy lower triangle to upper +# symmetrize_lower!(Result) + +# return Result +# end diff --git a/GeneralisedFilters/src/batching/types.jl b/GeneralisedFilters/src/batching/types.jl index 02be3464..8efc0bd1 100644 --- a/GeneralisedFilters/src/batching/types.jl +++ b/GeneralisedFilters/src/batching/types.jl @@ -2,28 +2,25 @@ using CUDA using LinearAlgebra: Adjoint, Transpose, LowerTriangular, UpperTriangular, UniformScaling, Cholesky using PDMats: PDMat +using StructArrays: StructArray export BatchedCuMatrix, BatchedCuVector export SharedCuMatrix, SharedCuVector -export BatchedPDMat, BatchedCholesky # ============================================================================= # Core Batched Types # ============================================================================= -struct BatchedCuMatrix{T,Inner<:AbstractMatrix{T}} <: AbstractVector{Inner} - data::CuArray{T,3} 
+struct BatchedCuMatrix{T,M} <: AbstractVector{CuArray{T,2,M}} + data::CuArray{T,3,M} end -struct BatchedCuVector{T,Inner<:AbstractVector{T}} <: AbstractVector{Inner} - data::CuMatrix{T} +struct BatchedCuVector{T,M} <: AbstractVector{CuArray{T,1,M}} + data::CuArray{T,2,M} end const BatchedArray = Union{BatchedCuVector,BatchedCuMatrix} -BatchedCuMatrix(data::CuArray{T,3}) where {T} = BatchedCuMatrix{T,CuMatrix{T}}(data) -BatchedCuVector(data::CuMatrix{T}) where {T} = BatchedCuVector{T,CuVector{T}}(data) - batch_size(x::BatchedCuVector) = size(x.data, 2) batch_size(x::BatchedCuMatrix) = size(x.data, 3) @@ -34,53 +31,29 @@ Base.length(x::BatchedArray) = batch_size(x) inner_size(x::BatchedCuVector) = (size(x.data, 1),) inner_size(x::BatchedCuMatrix) = (size(x.data, 1), size(x.data, 2)) -function Base.getindex(x::BatchedCuVector{T,CuVector{T}}, i::Int) where {T} - return view(x.data, :, i) -end - -function Base.getindex(x::BatchedCuMatrix{T,CuMatrix{T}}, i::Int) where {T} - return view(x.data,:,:,i) -end - -function Base.getindex( - x::BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, i::Int -) where {T} - return LowerTriangular(view(x.data,:,:,i)) -end - -function Base.getindex( - x::BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}, i::Int -) where {T} - return UpperTriangular(view(x.data,:,:,i)) -end - -function Base.getindex(x::BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}, i::Int) where {T} - return adjoint(view(x.data,:,:,i)) -end - -function Base.getindex(x::BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}, i::Int) where {T} - return transpose(view(x.data,:,:,i)) -end +Base.getindex(x::BatchedCuVector, i::Int) = view(x.data, :, i) +Base.getindex(x::BatchedCuMatrix, i::Int) = view(x.data,:,:,i) # ============================================================================= # Shared Types (same data reused across all batch elements) # ============================================================================= -struct SharedCuMatrix{T,Inner<:AbstractMatrix{T}} <: AbstractVector{Inner} - data::CuMatrix{T} +struct SharedCuMatrix{T,M} <: AbstractVector{CuArray{T,2,M}} + data::CuArray{T,2,M} end -struct SharedCuVector{T,Inner<:AbstractVector{T}} <: AbstractVector{Inner} - data::CuVector{T} +struct SharedCuVector{T,M} <: AbstractVector{CuArray{T,1,M}} + data::CuArray{T,1,M} end const SharedArray = Union{SharedCuVector,SharedCuMatrix} -SharedCuMatrix(data::CuMatrix{T}) where {T} = SharedCuMatrix{T,CuMatrix{T}}(data) -SharedCuVector(data::CuVector{T}) where {T} = SharedCuVector{T,CuVector{T}}(data) +# Convenience constructors that infer memory type +SharedCuMatrix(data::CuArray{T,2,M}) where {T,M} = SharedCuMatrix{T,M}(data) +SharedCuVector(data::CuArray{T,1,M}) where {T,M} = SharedCuVector{T,M}(data) -Shared(x::CuMatrix{T}) where {T} = SharedCuMatrix(x) -Shared(x::CuVector{T}) where {T} = SharedCuVector(x) +Shared(x::CuArray{T,2,M}) where {T,M} = SharedCuMatrix{T,M}(x) +Shared(x::CuArray{T,1,M}) where {T,M} = SharedCuVector{T,M}(x) batch_size(::SharedCuVector) = nothing batch_size(::SharedCuMatrix) = nothing @@ -93,34 +66,16 @@ Base.size(x::SharedCuMatrix) = (1,) Base.length(::SharedArray) = 1 Base.getindex(x::SharedCuVector, ::Int) = x.data -Base.getindex(x::SharedCuMatrix{T,CuMatrix{T}}, ::Int) where {T} = x.data -function Base.getindex(x::SharedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, ::Int) where {T} - return LowerTriangular(x.data) -end +Base.getindex(x::SharedCuMatrix, ::Int) = x.data # ============================================================================= -# Type Aliases and Union 
Types for Dispatch +# Union Types for Dispatch # ============================================================================= -const AnyBatchedMatrix{T} = Union{ - BatchedCuMatrix{T,CuMatrix{T}}, - BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}, - BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}, - BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, - BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}, -} - -const AnySharedMatrix{T} = Union{ - SharedCuMatrix{T,CuMatrix{T}}, - SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}}, - SharedCuMatrix{T,Transpose{T,CuMatrix{T}}}, - SharedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, - SharedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}, +const BatchedOrShared = Union{ + BatchedCuMatrix,BatchedCuVector,SharedCuMatrix,SharedCuVector,StructArray } -const AnyMatrix{T} = Union{AnyBatchedMatrix{T},AnySharedMatrix{T}} -const AnyVector{T} = Union{BatchedCuVector{T},SharedCuVector{T}} - # ============================================================================= # Helper Functions # ============================================================================= @@ -135,18 +90,6 @@ unwrap_data(A::SharedCuMatrix) = A.data unwrap_data(x::BatchedCuVector) = x.data unwrap_data(x::SharedCuVector) = x.data -trans_flag(::BatchedCuMatrix{T,CuMatrix{T}}) where {T} = 'N' -trans_flag(::BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}) where {T} = T <: Real ? 'T' : 'C' -trans_flag(::BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}) where {T} = 'T' -trans_flag(::BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}) where {T} = 'N' -trans_flag(::BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}) where {T} = 'N' - -trans_flag(::SharedCuMatrix{T,CuMatrix{T}}) where {T} = 'N' -trans_flag(::SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}}) where {T} = T <: Real ? 'T' : 'C' -trans_flag(::SharedCuMatrix{T,Transpose{T,CuMatrix{T}}}) where {T} = 'T' -trans_flag(::SharedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}) where {T} = 'N' -trans_flag(::SharedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}) where {T} = 'N' - function inner_size_for_blas(A::BatchedCuMatrix) m, n = size(A.data, 1), size(A.data, 2) return (m, n) @@ -167,31 +110,6 @@ function get_batch_size(args...) 
return error("At least one argument must be batched") end -# ============================================================================= -# Stateful Wrapper Types (Cholesky, PDMat) -# ============================================================================= - -struct BatchedCholesky{T} <: AbstractVector{Cholesky{T,CuMatrix{T}}} - factors::BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}} - info::CuVector{Int32} - uplo::Char -end - -struct BatchedPDMat{T} <: AbstractVector{PDMat{T,CuMatrix{T}}} - chol::BatchedCholesky{T} -end - -batch_size(c::BatchedCholesky) = batch_size(c.factors) -batch_size(p::BatchedPDMat) = batch_size(p.chol) - -inner_size(c::BatchedCholesky) = inner_size(c.factors) -inner_size(p::BatchedPDMat) = inner_size(p.chol) - -Base.size(c::BatchedCholesky) = (batch_size(c),) -Base.size(p::BatchedPDMat) = (batch_size(p),) - -Base.getindex(p::BatchedPDMat{T}, i::Int) where {T} = p.chol.factors[i] * p.chol.factors[i]' - # ============================================================================= # Pointer Array Creation # ============================================================================= diff --git a/GeneralisedFilters/src/batching/wrappers.jl b/GeneralisedFilters/src/batching/wrappers.jl index 6f7e741b..969a6bb8 100644 --- a/GeneralisedFilters/src/batching/wrappers.jl +++ b/GeneralisedFilters/src/batching/wrappers.jl @@ -1,5 +1,7 @@ using Magma using Magma.LibMagma +using LinearAlgebra: cholesky, Cholesky, LowerTriangular +using StructArrays: StructArray # ============================================================================= # Trivial Wrappers (reductions and elementwise operations) @@ -389,27 +391,25 @@ function gemv_batched_smallsq!( return y end -function potrf_batched!(uplo::Char, A::BatchedCuMatrix{Float32}) +function potrf_batched!(uplo::Char, A::BatchedCuMatrix{Float32}, info::CuVector{Int64}) N = batch_size(A) n = size(A.data, 1) lda = n dA = create_pointer_array(A) - info_gpu = CUDA.zeros(Int64, N) CUDA.synchronize() queue_ptr = Ref{LibMagma.magma_queue_t}() LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) LibMagma.magma_spotrf_batched( - magma_uplo(uplo), n, dA, lda, pointer(info_gpu), N, queue_ptr[] + magma_uplo(uplo), n, dA, lda, pointer(info), N, queue_ptr[] ) LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) - factors = BatchedCuMatrix{Float32,LowerTriangular{Float32,CuMatrix{Float32}}}(A.data) - return BatchedCholesky{Float32}(factors, info_gpu, uplo) + return A end function potrs_batched!( @@ -489,26 +489,37 @@ end # Higher-level Cholesky Operations # ============================================================================= -function cholesky_batched(A::BatchedCuMatrix{T}) where {T} +function broadcasted(::typeof(cholesky), A::BatchedCuMatrix{T,M}) where {T,M} + N = batch_size(A) A_copy = BatchedCuMatrix(copy(A.data)) - return potrf_batched!('L', A_copy) -end + info = CUDA.zeros(Int64, N) -function pdmat_solve(S::BatchedPDMat{T}, B::BatchedCuMatrix{T}) where {T} - L = S.chol.factors - L_data = BatchedCuMatrix(L.data) + potrf_batched!('L', A_copy, info) - B_copy = BatchedCuMatrix(copy(B.data)) + factors_wrapped = broadcasted(LowerTriangular, A_copy) - # Solve L*L'*X = B via two triangular solves: - # 1. Solve L*Y = B (Y stored in B_copy) - trsm_batched!('L', 'L', 'N', 'N', one(T), L_data, B_copy) - # 2. 
Solve L'*X = Y (X stored in B_copy) - trsm_batched!('L', 'L', 'T', 'N', one(T), L_data, B_copy) + # TODO: Use a lazy constant vector for uplo instead of dense fill + uplo = fill('L', N) - return B_copy + ElType = Cholesky{T,eltype(A)} + return StructArray{ElType}((; factors=factors_wrapped, uplo=uplo, info=info)) end +# function pdmat_solve(S::BatchedPDMat{T}, B::BatchedCuMatrix{T}) where {T} +# L = S.chol.factors +# L_data = BatchedCuMatrix(L.data) + +# B_copy = BatchedCuMatrix(copy(B.data)) + +# # Solve L*L'*X = B via two triangular solves: +# # 1. Solve L*Y = B (Y stored in B_copy) +# trsm_batched!('L', 'L', 'N', 'N', one(T), L_data, B_copy) +# # 2. Solve L'*X = Y (X stored in B_copy) +# trsm_batched!('L', 'L', 'T', 'N', one(T), L_data, B_copy) + +# return B_copy +# end + # ============================================================================= # Batched TRMM (Triangular Matrix Multiply) # ============================================================================= diff --git a/research/batching/wrappers_demo.jl b/research/batching/wrappers_demo.jl new file mode 100644 index 00000000..a135cef1 --- /dev/null +++ b/research/batching/wrappers_demo.jl @@ -0,0 +1,146 @@ +using GeneralisedFilters +using CUDA +using LinearAlgebra +using StructArrays +using Base.Broadcast: broadcasted +using Distributions +using PDMats + +# ============================================================================= +# Configuration +# ============================================================================= + +N = 10 # batch size +D = 4 # matrix dimension + +println("Creating batched matrices...") +A = BatchedCuMatrix(CUDA.randn(Float32, D, D, N)); + +# ============================================================================= +# Test 1: Basic wrapper creation +# ============================================================================= + +println("\n=== Test 1: Basic wrapper creation ===\n") + +# Adjoint +A_adj = broadcasted(Adjoint, A); +println("Adjoint type: ", typeof(A_adj)) +println("Adjoint eltype: ", eltype(A_adj)) +println("First element type: ", typeof(A_adj[1])) + +# Transpose +A_trans = broadcasted(Transpose, A); +println("\nTranspose type: ", typeof(A_trans)) +println("Transpose eltype: ", eltype(A_trans)) + +# LowerTriangular +A_lower = broadcasted(LowerTriangular, A); +println("\nLowerTriangular type: ", typeof(A_lower)) +println("LowerTriangular eltype: ", eltype(A_lower)) + +# UpperTriangular +A_upper = broadcasted(UpperTriangular, A); +println("\nUpperTriangular type: ", typeof(A_upper)) +println("UpperTriangular eltype: ", eltype(A_upper)) + +# ============================================================================= +# Test 2: Function form redirects +# ============================================================================= + +println("\n=== Test 2: Function form redirects ===\n") + +A_adj2 = broadcasted(adjoint, A); +println("adjoint redirect type: ", typeof(A_adj2)) +println("Types match: ", typeof(A_adj) == typeof(A_adj2)) + +A_trans2 = broadcasted(transpose, A); +println("\ntranspose redirect type: ", typeof(A_trans2)) +println("Types match: ", typeof(A_trans) == typeof(A_trans2)) + +# ============================================================================= +# Test 3: Nested wrappers +# ============================================================================= + +println("\n=== Test 3: Nested wrappers ===\n") + +# Adjoint of LowerTriangular +A_lower_adj = broadcasted(Adjoint, A_lower); +println("Adjoint(LowerTriangular) type: ", typeof(A_lower_adj)) 
+println("Adjoint(LowerTriangular) eltype: ", eltype(A_lower_adj)) + +# ============================================================================= +# Test 4: Verify element access +# ============================================================================= + +println("\n=== Test 4: Element access verification ===\n") + +# Get first element of each wrapped array +println("A[1] type: ", typeof(A[1])) +println("A_adj[1] type: ", typeof(A_adj[1])) +println("A_lower[1] type: ", typeof(A_lower[1])) + +# Check values match +A_cpu = Array(A[1]); +A_adj_cpu = Array(parent(A_adj[1])); +println("\nValues match (A vs parent of A_adj): ", A_cpu ≈ A_adj_cpu) + +# ============================================================================= +# Test 5: SharedCuMatrix wrappers +# ============================================================================= + +println("\n=== Test 5: SharedCuMatrix wrappers ===\n") + +S = SharedCuMatrix(CUDA.randn(Float32, D, D)); +S_adj = broadcasted(Adjoint, S); +println("SharedCuMatrix adjoint type: ", typeof(S_adj)) +println("SharedCuMatrix adjoint eltype: ", eltype(S_adj)) + +# ============================================================================= +# Test 6: Cholesky +# ============================================================================= + +println("\n=== Test 6: Batched Cholesky ===\n") + +# Create positive definite matrices: B = A * A' +B = A .* broadcasted(adjoint, A); +println("B type: ", typeof(B)) + +chol_result = cholesky.(B); +println("Cholesky result type: ", typeof(chol_result)) +println("Cholesky eltype: ", eltype(chol_result)) + +# Check fields +println("\nCholesky fields:") +println(" factors type: ", typeof(chol_result.factors)) +println(" uplo type: ", typeof(chol_result.uplo)) +println(" info type: ", typeof(chol_result.info)) + +# Access individual element +# SKIP: requires scalar indexing into info +# println("\nFirst element:") +# println(" chol_result[1] type: ", typeof(chol_result[1])) + +# ============================================================================= +# Test 7: PDMat wrapper +# ============================================================================= + +println("\n=== Test 7: PDMat wrapper ===\n") + +P = broadcasted(PDMat, B, chol_result); +println("PDMat result type: ", typeof(P)) +println("PDMat eltype: ", eltype(P)) + +println("\nPDMat fields:") +# println(" dim type: ", typeof(P.dim)) # SKIP: not a real field +println(" mat type: ", typeof(P.mat)) +println(" chol type: ", typeof(P.chol)) + +# ============================================================================= +# Test 8: MvNormal wrapper +# ============================================================================= + +println("\n=== Test 8: MvNormal wrapper ===\n") +μ = BatchedCuVector(CUDA.randn(Float32, D, N)); +G = broadcasted(MvNormal, μ, P); +println("MvNormal type: ", typeof(G)) +println("MvNormal eltype: ", eltype(G)) From f96cb6dfb00f9ba685e4e1be5451ffcad35213eb Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 7 Jan 2026 12:31:23 +0000 Subject: [PATCH 05/29] Temporarily disable caching due to world age bugs --- .../src/batching/broadcasting.jl | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/GeneralisedFilters/src/batching/broadcasting.jl b/GeneralisedFilters/src/batching/broadcasting.jl index 8ffda2a0..636217df 100644 --- a/GeneralisedFilters/src/batching/broadcasting.jl +++ b/GeneralisedFilters/src/batching/broadcasting.jl @@ -198,12 +198,17 @@ function 
Broadcast.materialize(bc::Broadcasted{BatchedStyle})
     argtypes = Tuple{map(typeof, args)...}
     key = (f, argtypes)
 
-    if !haskey(BATCHED_FUNC_CACHE, key)
-        println("  [Generating batched version of $f]")
-        batched_f = generate_batched_function(f, argtypes)
-        BATCHED_FUNC_CACHE[key] = batched_f
-    end
-
-    batched_f = BATCHED_FUNC_CACHE[key]
+    # HACK: caching caused issues when functions were modified
+    # if !haskey(BATCHED_FUNC_CACHE, key)
+    #     println("  [Generating batched version of $f]")
+    #     batched_f = generate_batched_function(f, argtypes)
+    #     BATCHED_FUNC_CACHE[key] = batched_f
+    # end
+
+    # batched_f = BATCHED_FUNC_CACHE[key]
+    # return Base.invokelatest(batched_f, nothing, args...)
+
+    println("  [Generating batched version of $f]")
+    batched_f = generate_batched_function(f, argtypes)
     return Base.invokelatest(batched_f, nothing, args...)
 end

From 1fe3b2faa581cd2d55f289222396c81cdd5eabf3 Mon Sep 17 00:00:00 2001
From: Tim Hargreaves
Date: Wed, 7 Jan 2026 12:31:53 +0000
Subject: [PATCH 06/29] Update PD/Cholesky ops to new wrapping interface

---
 GeneralisedFilters/src/batching/operations.jl | 123 +++++++++++-------
 GeneralisedFilters/src/batching/types.jl      |   3 -
 2 files changed, 73 insertions(+), 53 deletions(-)

diff --git a/GeneralisedFilters/src/batching/operations.jl b/GeneralisedFilters/src/batching/operations.jl
index b1fc55d1..7568f1f0 100644
--- a/GeneralisedFilters/src/batching/operations.jl
+++ b/GeneralisedFilters/src/batching/operations.jl
@@ -161,24 +161,47 @@ end
 # PDMat Broadcasting
 # =============================================================================
 
-# function broadcasted(::Type{PDMat}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T}
-#     chol = cholesky_batched(A)
-#     return BatchedPDMat{T}(chol)
-# end
-
-# function broadcasted(::typeof(\), S::BatchedPDMat{T}, A::BatchedCuMatrix{T}) where {T}
-#     return pdmat_solve(S, A)
-# end
-
-# function broadcasted(::typeof(/), A::BatchedCuMatrix{T}, S::BatchedPDMat{T}) where {T}
-#     # Need to actually transpose the data, not just wrap it
-#     At_data = permutedims(A.data, (2, 1, 3))
-#     At = BatchedCuMatrix(At_data)
-#     result_t = pdmat_solve(S, At)
-#     # Transpose back
-#     result_data = permutedims(result_t.data, (2, 1, 3))
-#     return BatchedCuMatrix(result_data)
-# end
+# HACK: PDMat is a constructor, so it will use
+# `broadcasted(::Type{W}, args::Union{BatchedCuMatrix, BatchedCuVector, SharedCuMatrix, SharedCuVector, StructArray}...) where W`
+# rather than the desired recursive broadcast to
+# `PDMat(mat::AbstractMatrix) = PDMat(mat, cholesky(mat))`
+# This method hardcodes a manual override for BatchedCuMatrix inputs. This should be replaced
+# by a more general solution in the future.
+function broadcasted(::Type{PDMat}, A::BatchedCuMatrix{T}) where {T}
+    chol = cholesky.(A)
+    return PDMat.(A, chol)
+end
+
+# HACK: Addition with PDMat extracts the .mat field. Should be replaced by automatic
+# materialization of PDMat to BatchedCuMatrix in the future.
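+# For illustration (a sketch of the effect, not additional API): given a
+# batched matrix `A` and a PDMat batch `P`, the overloads below make
+#     A .+ P   and   P .+ A
+# lower to plain batched elementwise adds against `P.mat`.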
+function broadcasted( + ::typeof(+), A::BatchedCuMatrix{T}, P::StructArray{<:PDMat{T}} +) where {T} + return broadcasted(+, A, P.mat) +end + +function broadcasted( + ::typeof(+), P::StructArray{<:PDMat{T}}, A::BatchedCuMatrix{T} +) where {T} + return broadcasted(+, P.mat, A) +end + +# A / S where S is PDMat: computes A * inv(S) +# potrs solves S * X = B, so we solve S * X = A' and transpose back +function broadcasted( + ::typeof(/), A::BatchedCuMatrix{T}, S::StructArray{<:PDMat{T}} +) where {T} + L = S.chol.factors.data + + # Transpose A: potrs solves S*X = B, we want A*inv(S) = (inv(S)*A')' + At = BatchedCuMatrix(permutedims(A.data, (2, 1, 3))) + + # Solve S * X = A' in-place (result stored in At) + potrs_batched!('L', L, At) + + # Transpose back + return BatchedCuMatrix(permutedims(At.data, (2, 1, 3))) +end # ============================================================================= # Quadratic Form Broadcasting @@ -195,36 +218,36 @@ function broadcasted( return broadcasted(*, temp, Xt) end -# X_A_Xt for BatchedPDMat: X * P * X' where P = L * L' +# X_A_Xt for StructArray{PDMat}: X * P * X' where P = L * L' # Computed as (X * L) * (X * L)' using TRMM and SYRK -# function broadcasted( -# ::typeof(X_A_Xt), P::BatchedPDMat{T}, X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}} -# ) where {T} -# L = P.chol.factors -# N = get_batch_size(P, X) - -# X_inner = inner_size_for_blas(X) -# m = X_inner[1] - -# # Copy X to XL (TRMM overwrites in-place) -# XL_data = if X isa SharedCuMatrix -# repeat(reshape(X.data, size(X.data, 1), size(X.data, 2), 1), 1, 1, N) -# else -# copy(X.data) -# end -# XL = BatchedCuMatrix(XL_data) - -# # XL = X * L using TRMM (side='R' for right multiply, uplo='L' for lower triangular) -# L_data = BatchedCuMatrix(L.data) -# trmm_batched!('R', 'L', 'N', 'N', one(T), L_data, XL) - -# # Result = XL * XL' using SYRK (fills lower triangle) -# Result_data = CuArray{T}(undef, m, m, N) -# Result = BatchedCuMatrix(Result_data) -# syrk_batched!('L', 'N', one(T), XL, zero(T), Result) - -# # Symmetrize: copy lower triangle to upper -# symmetrize_lower!(Result) - -# return Result -# end +# HACK: this function should dispatch to specialised `*` for triangular types but this is +# not yet implemented +function broadcasted( + ::typeof(X_A_Xt), + P::StructArray{<:PDMat{T}}, + X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + # P.chol.factors is StructArray{LowerTriangular}, .data is the BatchedCuMatrix + L = P.chol.factors.data + N = get_batch_size(L, X) + out_dim = inner_size_for_blas(X)[1] + + # Copy X for in-place TRMM + XL = if X isa SharedCuMatrix + BatchedCuMatrix(repeat(reshape(X.data, size(X.data)..., 1), 1, 1, N)) + else + BatchedCuMatrix(copy(X.data)) + end + + # XL = X * L using TRMM (side='R', uplo='L', no transpose, non-unit diagonal) + trmm_batched!('R', 'L', 'N', 'N', one(T), L, XL) + + # result = XL * XL' using SYRK (fills lower triangle only) + result = BatchedCuMatrix(CuArray{T}(undef, out_dim, out_dim, N)) + syrk_batched!('L', 'N', one(T), XL, zero(T), result) + + # Copy lower triangle to upper for full symmetric matrix + symmetrize_lower!(result) + + return result +end diff --git a/GeneralisedFilters/src/batching/types.jl b/GeneralisedFilters/src/batching/types.jl index 8efc0bd1..88843b24 100644 --- a/GeneralisedFilters/src/batching/types.jl +++ b/GeneralisedFilters/src/batching/types.jl @@ -48,9 +48,6 @@ end const SharedArray = Union{SharedCuVector,SharedCuMatrix} -# Convenience constructors that infer memory type -SharedCuMatrix(data::CuArray{T,2,M}) where 
{T,M} = SharedCuMatrix{T,M}(data) -SharedCuVector(data::CuArray{T,1,M}) where {T,M} = SharedCuVector{T,M}(data) Shared(x::CuArray{T,2,M}) where {T,M} = SharedCuMatrix{T,M}(x) Shared(x::CuArray{T,1,M}) where {T,M} = SharedCuVector{T,M}(x) From 424d33fd3c67c28e25996b1f4aa68fc42c88960e Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 7 Jan 2026 12:32:11 +0000 Subject: [PATCH 07/29] Update batching demo script to new wrapper interface --- research/batching/batching_demo.jl | 40 ++++++++++++++++-------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/research/batching/batching_demo.jl b/research/batching/batching_demo.jl index d2841061..99ea2f77 100644 --- a/research/batching/batching_demo.jl +++ b/research/batching/batching_demo.jl @@ -17,9 +17,9 @@ Magma.magma_init() # Configuration # ============================================================================= -D_state = 64 -D_obs = 64 -N = 1000 +D_state = 2 +D_obs = 2 +N = 3 function kalman_predict(state, dyn_params) A = dyn_params[1] @@ -45,7 +45,7 @@ Q = Q_root * Q_root' + I Qs = SharedCuMatrix(Q) Σ_PDs = broadcasted(PDMat, Σs); -Gs = StructArray{MvNormal}((μ=μs, Σ=Σ_PDs)); +Gs = MvNormal.(μs, Σ_PDs); function kalman_predict(state, dyn_params) A = dyn_params[1] @@ -53,7 +53,8 @@ function kalman_predict(state, dyn_params) Q = dyn_params[3] μ̂ = A * state.μ + b - Σ̂ = X_A_Xt(state.Σ, A) + Q + Σ̂ = PDMat(X_A_Xt(state.Σ, A) + Q) + return MvNormal(μ̂, Σ̂) end @@ -69,11 +70,11 @@ Q_test = Array(Qs.data) pred_G_test = kalman_predict(MvNormal(μ_test, PDMat(Σ_test)), (A_test, b_test, Q_test)) println("=== Predict Comparison ===") -println("CPU Mean: ", pred_G_test.μ[1:5]) -println("GPU Mean: ", Array(pred_Gs[end].μ[1:5])) +println("CPU Mean: ", pred_G_test.μ) +println("GPU Mean: ", Array(pred_Gs.μ[end])) -println("CPU Covariance [1:3, 1:3]: ", Matrix(pred_G_test.Σ)[1:3, 1:3]) -println("GPU Covariance [1:3, 1:3]: ", Array(pred_Gs[end].Σ)[1:3, 1:3]) +println("CPU Covariance: ", Matrix(pred_G_test.Σ)) +println("GPU Covariance: ", Array(pred_Gs.Σ.mat[end])) # ============================================================================= # Kalman Update @@ -96,7 +97,7 @@ function kalman_update(state, obs_params, observation) # Update parameters using Joseph form for numerical stability μ̂ = μ + K * ȳ - Σ̂ = X_A_Xt(Σ, I - K * H) + X_A_Xt(R, K) + Σ̂ = PDMat(X_A_Xt(Σ, I - K * H) + X_A_Xt(R, K)) return MvNormal(μ̂, Σ̂) end @@ -114,6 +115,7 @@ I_obs = CuArray{Float32}(I, D_obs, D_obs) I_obs_shared = SharedCuMatrix(I_obs) Rs_root = BatchedCuMatrix(CUDA.randn(Float32, D_obs, D_obs, N)) Rs = Rs_root .* adjoint.(Rs_root) .+ I_obs_shared +Rs = PDMat.(Rs); obs_params = (Hs, cs, Rs) @@ -126,17 +128,17 @@ update_Gs = kalman_update.(pred_Gs, Ref(obs_params), observations); # Compare update to CPU H_test = Array(Hs.data) c_test = Array(cs.data) -R_test = PDMat(Array(Rs[end])) +R_test = PDMat(Array(Rs.mat[end])) obs_test = Array(observations[end]) update_G_test = kalman_update(pred_G_test, (H_test, c_test, R_test), obs_test) println("\n=== Update Comparison ===") -println("CPU Mean: ", update_G_test.μ[1:5]) -println("GPU Mean: ", Array(update_Gs.μ[end][1:5])) +println("CPU Mean: ", update_G_test.μ) +println("GPU Mean: ", Array(update_Gs.μ[end])) -println("CPU Covariance [1:3, 1:3]: ", Matrix(update_G_test.Σ)[1:3, 1:3]) -println("GPU Covariance [1:3, 1:3]: ", Array(update_Gs.Σ[end])[1:3, 1:3]) +println("CPU Covariance: ", Matrix(update_G_test.Σ)) +println("GPU Covariance: ", Array(update_Gs.Σ.mat[end])) # 
============================================================================= # Full Kalman Step @@ -154,8 +156,8 @@ step_G_test = kalman_step( ) println("\n=== Full Step Comparison ===") -println("CPU Mean: ", step_G_test.μ[1:5]) -println("GPU Mean: ", Array(step_Gs.μ[end][1:5])) +println("CPU Mean: ", step_G_test.μ) +println("GPU Mean: ", Array(step_Gs.μ[end])) -println("CPU Covariance [1:3, 1:3]: ", Matrix(step_G_test.Σ)[1:3, 1:3]) -println("GPU Covariance [1:3, 1:3]: ", Array(step_Gs.Σ[end])[1:3, 1:3]) +println("CPU Covariance: ", Matrix(step_G_test.Σ)) +println("GPU Covariance: ", Array(step_Gs.Σ.mat[end])) From d21b09b5f8befa9a5e83b376f56b230507d23203 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 7 Jan 2026 12:34:29 +0000 Subject: [PATCH 08/29] Fix formatting --- GeneralisedFilters/src/batching/types.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/GeneralisedFilters/src/batching/types.jl b/GeneralisedFilters/src/batching/types.jl index 88843b24..ece85e77 100644 --- a/GeneralisedFilters/src/batching/types.jl +++ b/GeneralisedFilters/src/batching/types.jl @@ -48,7 +48,6 @@ end const SharedArray = Union{SharedCuVector,SharedCuMatrix} - Shared(x::CuArray{T,2,M}) where {T,M} = SharedCuMatrix{T,M}(x) Shared(x::CuArray{T,1,M}) where {T,M} = SharedCuVector{T,M}(x) From 876d58cdba2f86dab150dd821511139274e05894 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 7 Jan 2026 13:23:57 +0000 Subject: [PATCH 09/29] Reintroduce broadcast caching with proper world age handling --- .../src/batching/broadcasting.jl | 52 +++++++++++++++---- research/batching/batching_demo.jl | 2 + 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/GeneralisedFilters/src/batching/broadcasting.jl b/GeneralisedFilters/src/batching/broadcasting.jl index 636217df..8b9797ba 100644 --- a/GeneralisedFilters/src/batching/broadcasting.jl +++ b/GeneralisedFilters/src/batching/broadcasting.jl @@ -8,6 +8,8 @@ import Base.Broadcast: broadcasted import PDMats: PDMat +export BATCHED_CACHE_VERBOSITY, clear_batched_cache! + # ============================================================================= # Broadcast Style # ============================================================================= @@ -184,7 +186,24 @@ end # Broadcast Materialization # ============================================================================= -const BATCHED_FUNC_CACHE = Dict{Tuple,Any}() +# Verbosity levels: :silent, :verbose, :debug +# :silent - no output +# :verbose - print when generating or regenerating (i.e. cache misses) +# :debug - print all cache activity including hits +const BATCHED_CACHE_VERBOSITY = Ref{Symbol}(:silent) + +# Cache stores (batched_function, world_age_when_cached) +const BATCHED_FUNC_CACHE = Dict{Tuple{Any,Type},Tuple{Any,UInt}}() + +""" + clear_batched_cache!() + +Clear the batched function cache. Useful for debugging or forcing regeneration. 
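+
+Hypothetical usage sketch:
+
+    clear_batched_cache!()  # the next batched broadcast re-derives its kernel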
+""" +function clear_batched_cache!() + empty!(BATCHED_FUNC_CACHE) + return nothing +end function Broadcast.materialize(bc::Broadcasted{BatchedStyle}) f = bc.f @@ -198,17 +217,30 @@ function Broadcast.materialize(bc::Broadcasted{BatchedStyle}) argtypes = Tuple{map(typeof, args)...} key = (f, argtypes) - # HACK: caching was issues when functions were modified - # if !haskey(BATCHED_FUNC_CACHE, key) - # println(" [Generating batched version of $f]") - # batched_f = generate_batched_function(f, argtypes) - # BATCHED_FUNC_CACHE[key] = batched_f - # end + # Get element types for method lookup + element_types = Tuple{map(ir_element_type, argtypes.parameters)...} - # batched_f = BATCHED_FUNC_CACHE[key] - # return Base.invokelatest(batched_f, nothing, args...) + if haskey(BATCHED_FUNC_CACHE, key) + batched_f, cached_world = BATCHED_FUNC_CACHE[key] + # Check if the method has been redefined since caching + m = which(f, element_types) + if m.primary_world <= cached_world + if BATCHED_CACHE_VERBOSITY[] == :debug + println(" [Using cached batched version of $f]") + end + return Base.invokelatest(batched_f, nothing, args...) + end + if BATCHED_CACHE_VERBOSITY[] in (:verbose, :debug) + println(" [Regenerating batched version of $f (method redefined)]") + end + else + if BATCHED_CACHE_VERBOSITY[] in (:verbose, :debug) + println(" [Generating batched version of $f]") + end + end - println(" [Generating batched version of $f]") batched_f = generate_batched_function(f, argtypes) + current_world = Base.get_world_counter() + BATCHED_FUNC_CACHE[key] = (batched_f, current_world) return Base.invokelatest(batched_f, nothing, args...) end diff --git a/research/batching/batching_demo.jl b/research/batching/batching_demo.jl index 99ea2f77..303db60d 100644 --- a/research/batching/batching_demo.jl +++ b/research/batching/batching_demo.jl @@ -21,6 +21,8 @@ D_state = 2 D_obs = 2 N = 3 +BATCHED_CACHE_VERBOSITY[] = :debug + function kalman_predict(state, dyn_params) A = dyn_params[1] b = dyn_params[2] From 5e0996d10e2e333e0dc0c138351b9efa68eb1fed Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 7 Jan 2026 15:26:12 +0000 Subject: [PATCH 10/29] Implement reusual queue (3x speedup) --- GeneralisedFilters/src/batching/wrappers.jl | 151 +++++++++++--------- 1 file changed, 82 insertions(+), 69 deletions(-) diff --git a/GeneralisedFilters/src/batching/wrappers.jl b/GeneralisedFilters/src/batching/wrappers.jl index 969a6bb8..56de9a57 100644 --- a/GeneralisedFilters/src/batching/wrappers.jl +++ b/GeneralisedFilters/src/batching/wrappers.jl @@ -3,6 +3,49 @@ using Magma.LibMagma using LinearAlgebra: cholesky, Cholesky, LowerTriangular using StructArrays: StructArray +export get_magma_queue, reset_magma_queue! + +# ============================================================================= +# MAGMA Queue Management +# ============================================================================= + +# Store the queue pointer directly (not in a Ref) to avoid lifetime issues +mutable struct MagmaQueueCache + ptr::LibMagma.magma_queue_t + initialized::Bool +end + +const MAGMA_QUEUE_CACHE = MagmaQueueCache(LibMagma.magma_queue_t(), false) + +""" + get_magma_queue() + +Get the cached MAGMA queue, creating it on first use. +Returns the raw queue pointer for use with MAGMA functions. 
+""" +function get_magma_queue() + if !MAGMA_QUEUE_CACHE.initialized + queue_ref = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ref, C_NULL, C_NULL, 0) + MAGMA_QUEUE_CACHE.ptr = queue_ref[] + MAGMA_QUEUE_CACHE.initialized = true + end + return MAGMA_QUEUE_CACHE.ptr +end + +""" + reset_magma_queue!() + +Destroy and recreate the MAGMA queue. Useful if you suspect queue corruption. +""" +function reset_magma_queue!() + if MAGMA_QUEUE_CACHE.initialized + LibMagma.magma_queue_destroy_internal(MAGMA_QUEUE_CACHE.ptr, C_NULL, C_NULL, 0) + MAGMA_QUEUE_CACHE.initialized = false + end + return get_magma_queue() +end + # ============================================================================= # Trivial Wrappers (reductions and elementwise operations) # ============================================================================= @@ -127,8 +170,7 @@ function gemm_batched!( lddc = m CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magma_sgemm_batched( magma_trans(transA), magma_trans(transB), @@ -144,10 +186,9 @@ function gemm_batched!( dC, lddc, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dB) @@ -182,8 +223,7 @@ function gemm_batched_smallsq!( lddc = m CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_sgemm_batched_smallsq( magma_trans(transA), magma_trans(transB), @@ -205,10 +245,9 @@ function gemm_batched_smallsq!( 0, # cj lddc, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dB) @@ -239,8 +278,7 @@ function gemm_batched_smallsq!( lddc = m CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_dgemm_batched_smallsq( magma_trans(transA), magma_trans(transB), @@ -262,10 +300,9 @@ function gemm_batched_smallsq!( 0, # cj lddc, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dB) @@ -298,13 +335,11 @@ function gemv_batched!( incy = 1 CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_sgemv_batched( - magma_trans(transA), m, n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue_ptr[] + magma_trans(transA), m, n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dx) @@ -337,13 +372,11 @@ function gemv_batched!( incy = 1 CUDA.synchronize() - queue_ptr = 
Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_dgemv_batched( - magma_trans(transA), m, n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue_ptr[] + magma_trans(transA), m, n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dx) @@ -376,13 +409,11 @@ function gemv_batched_smallsq!( incy = 1 CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_sgemv_batched_smallsq( - magma_trans(transA), n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue_ptr[] + magma_trans(transA), n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dx) @@ -399,13 +430,9 @@ function potrf_batched!(uplo::Char, A::BatchedCuMatrix{Float32}, info::CuVector{ dA = create_pointer_array(A) CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) - LibMagma.magma_spotrf_batched( - magma_uplo(uplo), n, dA, lda, pointer(info), N, queue_ptr[] - ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + queue = get_magma_queue() + LibMagma.magma_spotrf_batched(magma_uplo(uplo), n, dA, lda, pointer(info), N, queue) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) @@ -426,13 +453,9 @@ function potrs_batched!( lddb = n CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) - LibMagma.magma_spotrs_batched( - magma_uplo(uplo), n, nrhs, dA, ldda, dB, lddb, N, queue_ptr[] - ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + queue = get_magma_queue() + LibMagma.magma_spotrs_batched(magma_uplo(uplo), n, nrhs, dA, ldda, dB, lddb, N, queue) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dB) @@ -459,8 +482,7 @@ function trsm_batched!( lddb = m CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_strsm_batched( magma_side(side), magma_uplo(uplo), @@ -474,10 +496,9 @@ function trsm_batched!( dB, lddb, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dB) @@ -543,8 +564,7 @@ function trmm_batched!( lddb = m CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_strmm_batched( magma_side(side), magma_uplo(uplo), @@ -558,10 +578,9 @@ 
function trmm_batched!( dB, lddb, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dB) @@ -588,8 +607,7 @@ function trmm_batched!( lddb = m CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_dtrmm_batched( magma_side(side), magma_uplo(uplo), @@ -603,10 +621,9 @@ function trmm_batched!( dB, lddb, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dB) @@ -637,8 +654,7 @@ function syrk_batched!( lddc = n CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_ssyrk_batched( magma_uplo(uplo), magma_trans(trans), @@ -651,10 +667,9 @@ function syrk_batched!( dC, lddc, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dC) @@ -681,8 +696,7 @@ function syrk_batched!( lddc = n CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_dsyrk_batched( magma_uplo(uplo), magma_trans(trans), @@ -695,10 +709,9 @@ function syrk_batched!( dC, lddc, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dC) From 7d1bfb37f864d6fd180b6a7630fa27510473e38c Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 7 Jan 2026 16:03:06 +0000 Subject: [PATCH 11/29] Add benchmarking to demo script --- research/batching/batching_demo.jl | 118 ++++++++++++++++++++++++++++- 1 file changed, 115 insertions(+), 3 deletions(-) diff --git a/research/batching/batching_demo.jl b/research/batching/batching_demo.jl index 303db60d..8830fc1d 100644 --- a/research/batching/batching_demo.jl +++ b/research/batching/batching_demo.jl @@ -71,7 +71,7 @@ b_test = Array(bs[end]) Q_test = Array(Qs.data) pred_G_test = kalman_predict(MvNormal(μ_test, PDMat(Σ_test)), (A_test, b_test, Q_test)) -println("=== Predict Comparison ===") +println("=== Predict Comparison ===\n") println("CPU Mean: ", pred_G_test.μ) println("GPU Mean: ", Array(pred_Gs.μ[end])) @@ -135,7 +135,7 @@ obs_test = Array(observations[end]) update_G_test = kalman_update(pred_G_test, (H_test, c_test, R_test), obs_test) -println("\n=== Update Comparison ===") +println("\n=== Update Comparison ===\n") println("CPU Mean: ", update_G_test.μ) println("GPU Mean: ", Array(update_Gs.μ[end])) @@ -157,9 +157,121 @@ step_G_test = kalman_step( obs_test, ) -println("\n=== Full Step Comparison ===") +println("\n=== Full Step Comparison ===\n") println("CPU Mean: ", step_G_test.μ) println("GPU Mean: ", Array(step_Gs.μ[end])) println("CPU Covariance: ", 
Matrix(step_G_test.Σ)) println("GPU Covariance: ", Array(step_Gs.Σ.mat[end])) + +# ============================================================================= +# Benchmarking +# ============================================================================= + +using BenchmarkTools +using StaticArrays + +D_bench = 10 +N_bench = 60000 + +println("\n=== Benchmarking batched Kalman step ===\n") + +println("D = $D_bench, N = $N_bench") +println( + "Size of batched covariance matrices: ", + round(Int, D_bench * D_bench * N_bench * sizeof(Float32) / 1024^2), + " MB", +) + +Is_bench = SharedCuMatrix(CuArray{Float32}(I, D_bench, D_bench)) +μs_bench = BatchedCuVector(CUDA.randn(Float32, D_bench, N_bench)) +Σs_root_bench = BatchedCuMatrix(CUDA.randn(Float32, D_bench, D_bench, N_bench)) +Σs_bench = Σs_root_bench .* adjoint.(Σs_root_bench) .+ Is_bench +Σ_PDs_bench = broadcasted(PDMat, Σs_bench); +Gs_bench = MvNormal.(μs_bench, Σ_PDs_bench); + +As_bench = SharedCuMatrix(CUDA.randn(Float32, D_bench, D_bench)) +bs_bench = BatchedCuVector(CUDA.randn(Float32, D_bench, N_bench)) +Qs_root_bench = CUDA.randn(Float32, D_bench, D_bench) +Qs_bench_mat = Qs_root_bench * adjoint(Qs_root_bench) +Qs_bench_mat += I +Qs_bench = SharedCuMatrix(Qs_bench_mat) +dyn_params_bench = (As_bench, bs_bench, Qs_bench) + +Hs_bench = SharedCuMatrix(CUDA.randn(Float32, D_bench, D_bench)) +cs_bench = SharedCuVector(CUDA.randn(Float32, D_bench)) +Rs_root_bench = BatchedCuMatrix(CUDA.randn(Float32, D_bench, D_bench, N_bench)) +Rs_bench = Rs_root_bench .* adjoint.(Rs_root_bench) .+ Is_bench +Rs_bench = PDMat.(Rs_bench) +obs_params_bench = (Hs_bench, cs_bench, Rs_bench) + +ys_bench = BatchedCuVector(CUDA.randn(Float32, D_bench, N_bench)) + +println("\nBenchmarking batched Kalman step...\n") + +BATCHED_CACHE_VERBOSITY[] = :silent +display( + @benchmark kalman_step.( + $Gs_bench, Ref($dyn_params_bench), Ref($obs_params_bench), $ys_bench + ) +) + +# Compare to static arrays +μs_static = [@SVector randn(Float32, D_bench) for _ in 1:N_bench] +Σs_static = [ + begin + A = @SMatrix randn(Float32, D_bench, D_bench) + A * A' + I + end for _ in 1:N_bench +] +Gs_static = [MvNormal(μs_static[n], PDMat(Σs_static[n])) for n in 1:N_bench] + +As_static = @SMatrix randn(Float32, D_bench, D_bench) +bs_static = [@SVector randn(Float32, D_bench) for _ in 1:N_bench] +Qs_root_static = @SMatrix randn(Float32, D_bench, D_bench) +Qs_static_mat = Qs_root_static * adjoint(Qs_root_static) + I +Qs_static = Qs_static_mat +dyn_params_static = (As_static, bs_static, Qs_static) + +Hs_static = @SMatrix randn(Float32, D_bench, D_bench) +cs_static = @SVector randn(Float32, D_bench) +Rs_root_static = [@SMatrix randn(Float32, D_bench, D_bench) for _ in 1:N_bench] +Rs_static = [ + Symmetric(R_root * R_root') + SMatrix{D_bench,D_bench,Float32}(I) for + R_root in Rs_root_static +] +obs_static = [@SVector randn(Float32, D_bench) for _ in 1:N_bench] +obs_params_static = (Hs_static, cs_static, PDMat.(Rs_static)) + +function test_static(Gs, dyn_params, obs_params, observations, out) + N = length(out) + Threads.@threads for n in 1:N + @inbounds out[n] = kalman_step( + Gs[n], + (dyn_params[1], dyn_params[2][n], dyn_params[3]), + (obs_params[1], obs_params[2], obs_params[3][n]), + observations[n], + ) + end + return out +end + +out = Vector{eltype(Gs_static)}(undef, N_bench) + +println("\nBenchmarking static Kalman step...\n") + +test_static(Gs_static, dyn_params_static, obs_params_static, obs_static, out); # warm-up + +display( + @benchmark test_static( + $Gs_static, 
$dyn_params_static, $obs_params_static, $obs_static, $out + ) +) + +println("\nBenchmarking single batched matrix multiplication...\n") + +display(@benchmark $Σs_root_bench .* $Σs_root_bench) + +tot_mem = 3 * D_bench * D_bench * N_bench * sizeof(Float32) +throughput = 1008 * 10^9 +println("\nTheoretical optimum time: ", tot_mem / throughput * 10^6, " μs") From 2c1093b065d633e89df19014990115efd67074d4 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 6 Jan 2026 15:58:27 +0000 Subject: [PATCH 12/29] Initial batching demo --- GeneralisedFilters/Project.toml | 5 + GeneralisedFilters/src/GeneralisedFilters.jl | 8 + .../src/batching/broadcasting.jl | 202 +++++ GeneralisedFilters/src/batching/operations.jl | 214 ++++++ GeneralisedFilters/src/batching/types.jl | 222 ++++++ GeneralisedFilters/src/batching/wrappers.jl | 720 ++++++++++++++++++ research/batching/batching_demo.jl | 113 +++ 7 files changed, 1484 insertions(+) create mode 100644 GeneralisedFilters/src/batching/broadcasting.jl create mode 100644 GeneralisedFilters/src/batching/operations.jl create mode 100644 GeneralisedFilters/src/batching/types.jl create mode 100644 GeneralisedFilters/src/batching/wrappers.jl create mode 100644 research/batching/batching_demo.jl diff --git a/GeneralisedFilters/Project.toml b/GeneralisedFilters/Project.toml index 198b7ef3..682265e9 100644 --- a/GeneralisedFilters/Project.toml +++ b/GeneralisedFilters/Project.toml @@ -9,8 +9,10 @@ AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +IRTools = "7869d1d1-7146-5819-86e3-90919afe41df" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +Magma = "a4173727-5e3e-4567-b12d-2e3cf2fa2f28" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -18,6 +20,7 @@ SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [compat] AbstractMCMC = "5" @@ -26,6 +29,7 @@ Aqua = "0.8" CUDA = "5" DataStructures = "0.18.20, 0.19" Distributions = "0.25" +IRTools = "0.4.15" LogExpFunctions = "0.3" OffsetArrays = "1.14.1" PDMats = "0.11.35" @@ -33,6 +37,7 @@ SSMProblems = "0.6" StaticArrays = "1.9.16" Statistics = "1" StatsBase = "0.34.3" +StructArrays = "0.7.2" Test = "1" julia = "1.10" diff --git a/GeneralisedFilters/src/GeneralisedFilters.jl b/GeneralisedFilters/src/GeneralisedFilters.jl index 1fdf3024..fe8f38d9 100644 --- a/GeneralisedFilters/src/GeneralisedFilters.jl +++ b/GeneralisedFilters/src/GeneralisedFilters.jl @@ -11,11 +11,19 @@ using StatsBase # TODO: heavy modules—move to extension using CUDA +export initialise, step, predict, update, filter + # Filtering utilities include("callbacks.jl") include("containers.jl") include("resamplers.jl") +# Batching utilities +include("batching/types.jl") +include("batching/broadcasting.jl") +include("batching/wrappers.jl") +include("batching/operations.jl") + ## FILTERING BASE ########################################################################## abstract type AbstractFilter <: AbstractSampler end diff --git a/GeneralisedFilters/src/batching/broadcasting.jl 
b/GeneralisedFilters/src/batching/broadcasting.jl new file mode 100644 index 00000000..92d50811 --- /dev/null +++ b/GeneralisedFilters/src/batching/broadcasting.jl @@ -0,0 +1,202 @@ +using IRTools +using IRTools: @code_ir, IR, Statement, Variable, func +using StructArrays +using LinearAlgebra: I, UniformScaling + +using Base.Broadcast: Broadcasted, BroadcastStyle, DefaultArrayStyle +import Base.Broadcast: broadcasted + +import PDMats: PDMat + +# ============================================================================= +# Broadcast Style +# ============================================================================= + +struct BatchedStyle <: Broadcast.BroadcastStyle end + +Base.BroadcastStyle(::Type{<:BatchedCuMatrix}) = BatchedStyle() +Base.BroadcastStyle(::Type{<:BatchedCuVector}) = BatchedStyle() +Base.BroadcastStyle(::Type{<:SharedCuMatrix}) = BatchedStyle() +Base.BroadcastStyle(::Type{<:SharedCuVector}) = BatchedStyle() +Base.BroadcastStyle(::Type{<:BatchedCholesky}) = BatchedStyle() +Base.BroadcastStyle(::Type{<:BatchedPDMat}) = BatchedStyle() +# HACK: Currently hard-coded but can be replaced with a custom StructArray type +Base.BroadcastStyle(::Type{<:StructArray}) = BatchedStyle() +Base.BroadcastStyle(::BatchedStyle, ::BatchedStyle) = BatchedStyle() +Base.BroadcastStyle(::BatchedStyle, ::DefaultArrayStyle{0}) = BatchedStyle() + +# ============================================================================= +# Ref Conversion (for Shared arrays) +# ============================================================================= + +maybe_convert_ref(x) = x +function maybe_convert_ref(r::Base.RefValue{<:CuVector{T}}) where {T} + return SharedCuVector{T,CuVector{T}}(r[]) +end +function maybe_convert_ref(r::Base.RefValue{<:CuMatrix{T}}) where {T} + return SharedCuMatrix{T,CuMatrix{T}}(r[]) +end + +# ============================================================================= +# Structural Operations (Pass-through) +# ============================================================================= + +broadcasted(::typeof(tuple), args...) = tuple(args...) +broadcasted(::typeof(getproperty), x, s::Symbol) = getproperty(x, s) +broadcasted(::typeof(getfield), x, s::Symbol) = getfield(x, s) +broadcasted(::typeof(getfield), x, i::Int) = getfield(x, i) + +# Special handling for RefValue - unwrap before indexing +broadcasted(::typeof(getfield), r::Base.RefValue, i::Int) = getfield(r[], i) +broadcasted(::typeof(getfield), r::Base.RefValue, s::Symbol) = getfield(r[], s) + +# ============================================================================= +# StructArray Wrapping +# ============================================================================= + +inner_eltype(arg::BatchedCuVector{T}) where {T} = CuVector{T} +inner_eltype(arg::BatchedCuMatrix{T}) where {T} = CuMatrix{T} +inner_eltype(arg::SharedCuVector{T}) where {T} = CuVector{T} +inner_eltype(arg::SharedCuMatrix{T}) where {T} = CuMatrix{T} +inner_eltype(arg::BatchedPDMat{T}) where {T} = PDMat{T,CuMatrix{T}} +inner_eltype(arg) = typeof(arg) + +function wrap_if_batched(::Type{T}, args...) where {T} + if any(arg -> arg isa Union{BatchedArray,SharedArray,BatchedPDMat}, args) + field_names = fieldnames(T) + element_types = Tuple{map(inner_eltype, args)...} + ElType = Core.Compiler.return_type(T, element_types) + nt = NamedTuple{field_names}(args) + return StructArray{ElType}(nt) + else + return T(args...) 
+ end +end + +# ============================================================================= +# IR Transformation +# ============================================================================= + +const SKIP_BROADCAST = Set{Any}([ + tuple, + Core.tuple, + getfield, + getproperty, + adjoint, + transpose, + LowerTriangular, + UpperTriangular, +]) + +const BROADCAST_TYPES = Set{Any}([PDMat]) + +maybe_wrap_scalar(x) = x +maybe_wrap_scalar(x::UniformScaling) = Ref(x) + +@inline function broadcast_and_materialize(f, args...) + wrapped_args = map(maybe_wrap_scalar, args) + result = broadcasted(f, wrapped_args...) + if result isa Broadcasted + return Broadcast.materialize(result) + end + return result +end + +function resolve_to_type(ir::IR, val) + val isa Type && return val + if val isa Variable + stmt = ir[val] + if stmt !== nothing + return resolve_to_type(ir, stmt.expr) + end + end + if val isa GlobalRef + try + resolved = getfield(val.mod, val.name) + return resolved isa Type ? resolved : nothing + catch + return nothing + end + end + return nothing +end + +function is_type_ref(ir::IR, val) + return resolve_to_type(ir, val) !== nothing +end + +function is_broadcast_type(ir::IR, val) + resolved = resolve_to_type(ir, val) + return resolved !== nothing && resolved in BROADCAST_TYPES +end + +function transform_to_batched(ir::IR) + ir = copy(ir) + + for (v, stmt) in ir + if stmt.expr isa Expr && stmt.expr.head == :call + fn = stmt.expr.args[1] + if fn in SKIP_BROADCAST + continue + end + if is_broadcast_type(ir, fn) + new_args = [broadcast_and_materialize, stmt.expr.args...] + ir[v] = Statement(stmt; expr=Expr(:call, new_args...)) + continue + end + if is_type_ref(ir, fn) + new_args = [wrap_if_batched, stmt.expr.args...] + ir[v] = Statement(stmt; expr=Expr(:call, new_args...)) + continue + end + new_args = [broadcast_and_materialize, stmt.expr.args...] + ir[v] = Statement(stmt; expr=Expr(:call, new_args...)) + end + end + + return ir +end + +ir_element_type(::Type{T}) where {T} = T +ir_element_type(::Type{<:StructArray{T}}) where {T} = T + +function generate_batched_function(f, argtypes::Type{<:Tuple}) + element_types = Tuple{map(ir_element_type, argtypes.parameters)...} + + ir = IRTools.Inner.code_ir(f, element_types) + if ir === nothing + error( + "Could not get IR for function $f with types $element_types (original: $argtypes)", + ) + end + batched_ir = transform_to_batched(ir) + return IRTools.func(batched_ir) +end + +# ============================================================================= +# Broadcast Materialization +# ============================================================================= + +const BATCHED_FUNC_CACHE = Dict{Tuple,Any}() + +function Broadcast.materialize(bc::Broadcasted{BatchedStyle}) + f = bc.f + args = map(maybe_convert_ref, bc.args) + + result = broadcasted(f, args...) + if !(result isa Broadcasted) + return result + end + + argtypes = Tuple{map(typeof, args)...} + key = (f, argtypes) + + if !haskey(BATCHED_FUNC_CACHE, key) + println(" [Generating batched version of $f]") + batched_f = generate_batched_function(f, argtypes) + BATCHED_FUNC_CACHE[key] = batched_f + end + + batched_f = BATCHED_FUNC_CACHE[key] + return Base.invokelatest(batched_f, nothing, args...) 
+end diff --git a/GeneralisedFilters/src/batching/operations.jl b/GeneralisedFilters/src/batching/operations.jl new file mode 100644 index 00000000..8464fc61 --- /dev/null +++ b/GeneralisedFilters/src/batching/operations.jl @@ -0,0 +1,214 @@ +import PDMats: X_A_Xt + +# ============================================================================= +# Adjoint/Transpose Broadcasting +# ============================================================================= + +function broadcasted(::typeof(adjoint), A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} + return BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}(A.data) +end + +function broadcasted(::typeof(transpose), A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} + return BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}(A.data) +end + +function broadcasted(::typeof(adjoint), A::SharedCuMatrix{T,CuMatrix{T}}) where {T} + return SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}}(A.data) +end + +function broadcasted(::typeof(transpose), A::SharedCuMatrix{T,CuMatrix{T}}) where {T} + return SharedCuMatrix{T,Transpose{T,CuMatrix{T}}}(A.data) +end + +function broadcasted( + ::typeof(adjoint), A::BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}} +) where {T} + return BatchedCuMatrix{T,CuMatrix{T}}(A.data) +end + +function broadcasted( + ::typeof(adjoint), A::SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}} +) where {T} + return SharedCuMatrix{T,CuMatrix{T}}(A.data) +end + +function broadcasted(::Type{LowerTriangular}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} + return BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}(A.data) +end + +function broadcasted(::Type{UpperTriangular}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} + return BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}(A.data) +end + +# ============================================================================= +# Matrix Multiply Broadcasting +# ============================================================================= + +function broadcasted( + ::typeof(*), + A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + transA = trans_flag(A) + transB = trans_flag(B) + + A_inner = inner_size_for_blas(A) + B_inner = inner_size_for_blas(B) + + m = transA == 'N' ? A_inner[1] : A_inner[2] + n = transB == 'N' ? B_inner[2] : B_inner[1] + N = get_batch_size(A, B) + + C_data = CuArray{T}(undef, m, n, N) + C = BatchedCuMatrix(C_data) + + gemm_batched!(transA, transB, one(T), A, B, zero(T), C) + return C +end + +# Multi-argument multiply +function broadcasted( + ::typeof(*), + A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + C::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + rest::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}..., +) where {T} + result = broadcasted(*, A, B) + result = broadcasted(*, result, C) + for R in rest + result = broadcasted(*, result, R) + end + return result +end + +# ============================================================================= +# Matrix-Vector Multiply Broadcasting +# ============================================================================= + +function broadcasted( + ::typeof(*), + A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + x::Union{BatchedCuVector{T},SharedCuVector{T}}, +) where {T} + transA = trans_flag(A) + A_inner = inner_size_for_blas(A) + m = transA == 'N' ? 
A_inner[1] : A_inner[2] + N = get_batch_size(A, x) + + y_data = CuArray{T}(undef, m, N) + y = BatchedCuVector(y_data) + + gemv_batched!(transA, one(T), A, x, zero(T), y) + + return y +end + +# ============================================================================= +# Identity Minus Matrix (I - A) Custom Kernel +# ============================================================================= + +function identity_minus_kernel!(C, A, m, n) + batch_idx = blockIdx().z + i = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + j = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y + + if i <= m && j <= n + if i == j + @inbounds C[i, j, batch_idx] = one(eltype(C)) - A[i, j, batch_idx] + else + @inbounds C[i, j, batch_idx] = -A[i, j, batch_idx] + end + end + return nothing +end + +function identity_minus_batched!(C::CuArray{T,3}, A::CuArray{T,3}) where {T} + m, n, N = size(A) + threads = (16, 16) + blocks = (cld(m, 16), cld(n, 16), N) + @cuda threads = threads blocks = blocks identity_minus_kernel!(C, A, m, n) + return C +end + +function broadcasted( + ::typeof(-), ::Base.RefValue{UniformScaling{Bool}}, A::BatchedCuMatrix{T} +) where {T} + C = CuArray{T}(undef, size(A.data)) + identity_minus_batched!(C, A.data) + return BatchedCuMatrix(C) +end + +# ============================================================================= +# PDMat Broadcasting +# ============================================================================= + +function broadcasted(::Type{PDMat}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} + chol = cholesky_batched(A) + return BatchedPDMat{T}(chol) +end + +function broadcasted(::typeof(\), S::BatchedPDMat{T}, A::BatchedCuMatrix{T}) where {T} + return pdmat_solve(S, A) +end + +function broadcasted(::typeof(/), A::BatchedCuMatrix{T}, S::BatchedPDMat{T}) where {T} + # Need to actually transpose the data, not just wrap it + At_data = permutedims(A.data, (2, 1, 3)) + At = BatchedCuMatrix(At_data) + result_t = pdmat_solve(S, At) + # Transpose back + result_data = permutedims(result_t.data, (2, 1, 3)) + return BatchedCuMatrix(result_data) +end + +# ============================================================================= +# Quadratic Form Broadcasting +# ============================================================================= + +# HACK: treat this as two GEMMs for now +function broadcasted( + ::typeof(X_A_Xt), + A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + temp = broadcasted(*, X, A) + Xt = broadcasted(adjoint, X) + return broadcasted(*, temp, Xt) +end + +# X_A_Xt for BatchedPDMat: X * P * X' where P = L * L' +# Computed as (X * L) * (X * L)' using TRMM and SYRK +function broadcasted( + ::typeof(X_A_Xt), + P::BatchedPDMat{T}, + X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + L = P.chol.factors + N = get_batch_size(P, X) + + X_inner = inner_size_for_blas(X) + m = X_inner[1] + + # Copy X to XL (TRMM overwrites in-place) + XL_data = if X isa SharedCuMatrix + repeat(reshape(X.data, size(X.data, 1), size(X.data, 2), 1), 1, 1, N) + else + copy(X.data) + end + XL = BatchedCuMatrix(XL_data) + + # XL = X * L using TRMM (side='R' for right multiply, uplo='L' for lower triangular) + L_data = BatchedCuMatrix(L.data) + trmm_batched!('R', 'L', 'N', 'N', one(T), L_data, XL) + + # Result = XL * XL' using SYRK (fills lower triangle) + Result_data = CuArray{T}(undef, m, m, N) + Result = BatchedCuMatrix(Result_data) + syrk_batched!('L', 'N', one(T), XL, zero(T), Result) + + # Symmetrize: copy 
lower triangle to upper + symmetrize_lower!(Result) + + return Result +end diff --git a/GeneralisedFilters/src/batching/types.jl b/GeneralisedFilters/src/batching/types.jl new file mode 100644 index 00000000..7ae2b399 --- /dev/null +++ b/GeneralisedFilters/src/batching/types.jl @@ -0,0 +1,222 @@ +using CUDA +using LinearAlgebra: + Adjoint, Transpose, LowerTriangular, UpperTriangular, UniformScaling, Cholesky +using PDMats: PDMat + +export BatchedCuMatrix, BatchedCuVector +export SharedCuMatrix, SharedCuVector +export BatchedPDMat, BatchedCholesky + +# ============================================================================= +# Core Batched Types +# ============================================================================= + +struct BatchedCuMatrix{T,Inner<:AbstractMatrix{T}} <: AbstractVector{Inner} + data::CuArray{T,3} +end + +struct BatchedCuVector{T,Inner<:AbstractVector{T}} <: AbstractVector{Inner} + data::CuMatrix{T} +end + +const BatchedArray = Union{BatchedCuVector,BatchedCuMatrix} + +BatchedCuMatrix(data::CuArray{T,3}) where {T} = BatchedCuMatrix{T,CuMatrix{T}}(data) +BatchedCuVector(data::CuMatrix{T}) where {T} = BatchedCuVector{T,CuVector{T}}(data) + +batch_size(x::BatchedCuVector) = size(x.data, 2) +batch_size(x::BatchedCuMatrix) = size(x.data, 3) + +Base.size(x::BatchedCuVector) = (batch_size(x),) +Base.size(x::BatchedCuMatrix) = (batch_size(x),) +Base.length(x::BatchedArray) = batch_size(x) + +inner_size(x::BatchedCuVector) = (size(x.data, 1),) +inner_size(x::BatchedCuMatrix) = (size(x.data, 1), size(x.data, 2)) + +function Base.getindex(x::BatchedCuVector{T,CuVector{T}}, i::Int) where {T} + return view(x.data, :, i) +end + +function Base.getindex(x::BatchedCuMatrix{T,CuMatrix{T}}, i::Int) where {T} + return view(x.data, :, :, i) +end + +function Base.getindex( + x::BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, i::Int +) where {T} + return LowerTriangular(view(x.data, :, :, i)) +end + +function Base.getindex( + x::BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}, i::Int +) where {T} + return UpperTriangular(view(x.data, :, :, i)) +end + +function Base.getindex(x::BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}, i::Int) where {T} + return adjoint(view(x.data, :, :, i)) +end + +function Base.getindex(x::BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}, i::Int) where {T} + return transpose(view(x.data, :, :, i)) +end + +# ============================================================================= +# Shared Types (same data reused across all batch elements) +# ============================================================================= + +struct SharedCuMatrix{T,Inner<:AbstractMatrix{T}} <: AbstractVector{Inner} + data::CuMatrix{T} +end + +struct SharedCuVector{T,Inner<:AbstractVector{T}} <: AbstractVector{Inner} + data::CuVector{T} +end + +const SharedArray = Union{SharedCuVector,SharedCuMatrix} + +SharedCuMatrix(data::CuMatrix{T}) where {T} = SharedCuMatrix{T,CuMatrix{T}}(data) +SharedCuVector(data::CuVector{T}) where {T} = SharedCuVector{T,CuVector{T}}(data) + +Shared(x::CuMatrix{T}) where {T} = SharedCuMatrix(x) +Shared(x::CuVector{T}) where {T} = SharedCuVector(x) + +batch_size(::SharedCuVector) = nothing +batch_size(::SharedCuMatrix) = nothing + +inner_size(x::SharedCuVector) = size(x.data) +inner_size(x::SharedCuMatrix) = size(x.data) + +Base.size(x::SharedCuVector) = (1,) +Base.size(x::SharedCuMatrix) = (1,) +Base.length(::SharedArray) = 1 + +Base.getindex(x::SharedCuVector, ::Int) = x.data +Base.getindex(x::SharedCuMatrix{T,CuMatrix{T}}, ::Int) where {T} = x.data 
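+# NOTE: every Shared getindex method ignores the index and returns the one
+# underlying array (re-wrapped where needed), so a single parameter set
+# stands in for all batch elements. Illustrative usage, assuming Q::CuMatrix:
+#
+#     Qs = SharedCuMatrix(Q)
+#     Qs[1] === Qs[42] === Q
+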
+function Base.getindex(x::SharedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, ::Int) where {T} + return LowerTriangular(x.data) +end + +# ============================================================================= +# Type Aliases and Union Types for Dispatch +# ============================================================================= + +const AnyBatchedMatrix{T} = Union{ + BatchedCuMatrix{T,CuMatrix{T}}, + BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}, + BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}, + BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, + BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}, +} + +const AnySharedMatrix{T} = Union{ + SharedCuMatrix{T,CuMatrix{T}}, + SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}}, + SharedCuMatrix{T,Transpose{T,CuMatrix{T}}}, + SharedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, + SharedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}, +} + +const AnyMatrix{T} = Union{AnyBatchedMatrix{T},AnySharedMatrix{T}} +const AnyVector{T} = Union{BatchedCuVector{T},SharedCuVector{T}} + +# ============================================================================= +# Helper Functions +# ============================================================================= + +is_shared(::BatchedCuMatrix) = false +is_shared(::BatchedCuVector) = false +is_shared(::SharedCuMatrix) = true +is_shared(::SharedCuVector) = true + +unwrap_data(A::BatchedCuMatrix) = A.data +unwrap_data(A::SharedCuMatrix) = A.data +unwrap_data(x::BatchedCuVector) = x.data +unwrap_data(x::SharedCuVector) = x.data + +trans_flag(::BatchedCuMatrix{T,CuMatrix{T}}) where {T} = 'N' +trans_flag(::BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}) where {T} = T <: Real ? 'T' : 'C' +trans_flag(::BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}) where {T} = 'T' +trans_flag(::BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}) where {T} = 'N' +trans_flag(::BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}) where {T} = 'N' + +trans_flag(::SharedCuMatrix{T,CuMatrix{T}}) where {T} = 'N' +trans_flag(::SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}}) where {T} = T <: Real ? 'T' : 'C' +trans_flag(::SharedCuMatrix{T,Transpose{T,CuMatrix{T}}}) where {T} = 'T' +trans_flag(::SharedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}) where {T} = 'N' +trans_flag(::SharedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}) where {T} = 'N' + +function inner_size_for_blas(A::BatchedCuMatrix) + m, n = size(A.data, 1), size(A.data, 2) + return (m, n) +end + +function inner_size_for_blas(A::SharedCuMatrix) + m, n = size(A.data) + return (m, n) +end + +function get_batch_size(args...) 
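+    # Batched arguments report their batch size, while Shared arguments
+    # report `nothing` (see batch_size above) and so never constrain it.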
+ for arg in args + bs = batch_size(arg) + if bs !== nothing + return bs + end + end + return error("At least one argument must be batched") +end + +# ============================================================================= +# Stateful Wrapper Types (Cholesky, PDMat) +# ============================================================================= + +struct BatchedCholesky{T} <: AbstractVector{Cholesky{T,CuMatrix{T}}} + factors::BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}} + info::CuVector{Int32} + uplo::Char +end + +struct BatchedPDMat{T} <: AbstractVector{PDMat{T,CuMatrix{T}}} + chol::BatchedCholesky{T} +end + +batch_size(c::BatchedCholesky) = batch_size(c.factors) +batch_size(p::BatchedPDMat) = batch_size(p.chol) + +inner_size(c::BatchedCholesky) = inner_size(c.factors) +inner_size(p::BatchedPDMat) = inner_size(p.chol) + +Base.size(c::BatchedCholesky) = (batch_size(c),) +Base.size(p::BatchedPDMat) = (batch_size(p),) + +Base.getindex(p::BatchedPDMat{T}, i::Int) where {T} = p.chol.factors[i] * p.chol.factors[i]' + +# ============================================================================= +# Pointer Array Creation +# ============================================================================= + +function create_pointer_array(A::BatchedCuMatrix{T}) where {T} + return CUDA.CUBLAS.unsafe_strided_batch(A.data) +end + +function create_pointer_array(A::SharedCuMatrix{T}, N::Int) where {T} + base_ptr = pointer(A.data) + ptrs_cpu = fill(base_ptr, N) + return CuArray(ptrs_cpu) +end + +function create_pointer_array_vector(x::BatchedCuVector{T}) where {T} + n = size(x.data, 1) + N = size(x.data, 2) + base_ptr = pointer(x.data) + stride = n * sizeof(T) + ptrs = CuArray([base_ptr + (i - 1) * stride for i in 1:N]) + return ptrs +end + +function create_pointer_array_vector(x::SharedCuVector{T}, N::Int) where {T} + base_ptr = pointer(x.data) + ptrs_cpu = fill(base_ptr, N) + return CuArray(ptrs_cpu) +end diff --git a/GeneralisedFilters/src/batching/wrappers.jl b/GeneralisedFilters/src/batching/wrappers.jl new file mode 100644 index 00000000..6f7e741b --- /dev/null +++ b/GeneralisedFilters/src/batching/wrappers.jl @@ -0,0 +1,720 @@ +using Magma +using Magma.LibMagma + +# ============================================================================= +# Trivial Wrappers (reductions and elementwise operations) +# ============================================================================= + +function broadcasted( + ::typeof(+), + A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + if is_shared(A) && is_shared(B) + return SharedCuMatrix(A.data .+ B.data) + else + return BatchedCuMatrix(A.data .+ B.data) + end +end + +function broadcasted( + ::typeof(+), + a::Union{BatchedCuVector{T},SharedCuVector{T}}, + b::Union{BatchedCuVector{T},SharedCuVector{T}}, +) where {T} + if is_shared(a) && is_shared(b) + return SharedCuVector(a.data .+ b.data) + else + return BatchedCuVector(a.data .+ b.data) + end +end + +function broadcasted( + ::typeof(-), + A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + if is_shared(A) && is_shared(B) + return SharedCuMatrix(A.data .- B.data) + else + return BatchedCuMatrix(A.data .- B.data) + end +end + +function broadcasted( + ::typeof(-), + a::Union{BatchedCuVector{T},SharedCuVector{T}}, + b::Union{BatchedCuVector{T},SharedCuVector{T}}, +) where {T} + if is_shared(a) && is_shared(b) + return SharedCuVector(a.data .- b.data) + else + return 
BatchedCuVector(a.data .- b.data) + end +end + +# ============================================================================= +# MAGMA Constants Conversion +# ============================================================================= + +function magma_trans(c::Char) + if c == 'N' + return LibMagma.MagmaNoTrans + elseif c == 'T' + return LibMagma.MagmaTrans + elseif c == 'C' + return LibMagma.MagmaConjTrans + else + error("Unknown transpose char: $c") + end +end + +function magma_uplo(c::Char) + if c == 'L' + return LibMagma.MagmaLower + elseif c == 'U' + return LibMagma.MagmaUpper + else + error("Unknown uplo char: $c") + end +end + +function magma_side(c::Char) + if c == 'L' + return LibMagma.MagmaLeft + elseif c == 'R' + return LibMagma.MagmaRight + else + error("Unknown side char: $c") + end +end + +function magma_diag(c::Char) + if c == 'N' + return LibMagma.MagmaNonUnit + elseif c == 'U' + return LibMagma.MagmaUnit + else + error("Unknown diag char: $c") + end +end + +# ============================================================================= +# MAGMA Operations +# ============================================================================= + +function gemm_batched!( + transA::Char, + transB::Char, + alpha::Float32, + A::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + B::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + beta::Float32, + C::BatchedCuMatrix{Float32}, +) + N = batch_size(C) + m, n = size(C.data, 1), size(C.data, 2) + k = transA == 'N' ? size(unwrap_data(A), 2) : size(unwrap_data(A), 1) + + dA = A isa BatchedCuMatrix ? create_pointer_array(A) : create_pointer_array(A, N) + dB = B isa BatchedCuMatrix ? create_pointer_array(B) : create_pointer_array(B, N) + dC = create_pointer_array(C) + + ldda = size(unwrap_data(A), 1) + lddb = size(unwrap_data(B), 1) + lddc = m + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magma_sgemm_batched( + magma_trans(transA), + magma_trans(transB), + m, + n, + k, + alpha, + dA, + ldda, + dB, + lddb, + beta, + dC, + lddc, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + CUDA.unsafe_free!(dC) + + return C +end + +# ============================================================================= +# Part 3b: Batched GEMM Small Square Wrapper (for D < 32) +# ============================================================================= + +function gemm_batched_smallsq!( + transA::Char, + transB::Char, + alpha::Float32, + A::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + B::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + beta::Float32, + C::BatchedCuMatrix{Float32}, +) + N = batch_size(C) + m, n = size(C.data, 1), size(C.data, 2) + k = transA == 'N' ? size(unwrap_data(A), 2) : size(unwrap_data(A), 1) + + dA = A isa BatchedCuMatrix ? create_pointer_array(A) : create_pointer_array(A, N) + dB = B isa BatchedCuMatrix ? 
create_pointer_array(B) : create_pointer_array(B, N) + dC = create_pointer_array(C) + + ldda = size(unwrap_data(A), 1) + lddb = size(unwrap_data(B), 1) + lddc = m + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_sgemm_batched_smallsq( + magma_trans(transA), + magma_trans(transB), + m, + n, + k, + alpha, + dA, + 0, # ai + 0, # aj + ldda, + dB, + 0, # bi + 0, # bj + lddb, + beta, + dC, + 0, # ci + 0, # cj + lddc, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + CUDA.unsafe_free!(dC) + + return C +end + +function gemm_batched_smallsq!( + transA::Char, + transB::Char, + alpha::Float64, + A::Union{BatchedCuMatrix{Float64},SharedCuMatrix{Float64}}, + B::Union{BatchedCuMatrix{Float64},SharedCuMatrix{Float64}}, + beta::Float64, + C::BatchedCuMatrix{Float64}, +) + N = batch_size(C) + m, n = size(C.data, 1), size(C.data, 2) + k = transA == 'N' ? size(unwrap_data(A), 2) : size(unwrap_data(A), 1) + + dA = A isa BatchedCuMatrix ? create_pointer_array(A) : create_pointer_array(A, N) + dB = B isa BatchedCuMatrix ? create_pointer_array(B) : create_pointer_array(B, N) + dC = create_pointer_array(C) + + ldda = size(unwrap_data(A), 1) + lddb = size(unwrap_data(B), 1) + lddc = m + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_dgemm_batched_smallsq( + magma_trans(transA), + magma_trans(transB), + m, + n, + k, + alpha, + dA, + 0, # ai + 0, # aj + ldda, + dB, + 0, # bi + 0, # bj + lddb, + beta, + dC, + 0, # ci + 0, # cj + lddc, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + CUDA.unsafe_free!(dC) + + return C +end + +function gemv_batched!( + transA::Char, + alpha::Float32, + A::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + x::Union{BatchedCuVector{Float32},SharedCuVector{Float32}}, + beta::Float32, + y::BatchedCuVector{Float32}, +) + N = batch_size(y) + m, n = size(unwrap_data(A), 1), size(unwrap_data(A), 2) + + dA = A isa BatchedCuMatrix ? create_pointer_array(A) : create_pointer_array(A, N) + dx = if x isa BatchedCuVector + create_pointer_array_vector(x) + else + create_pointer_array_vector(x, N) + end + dy = create_pointer_array_vector(y) + + ldda = m + incx = 1 + incy = 1 + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_sgemv_batched( + magma_trans(transA), m, n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue_ptr[] + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dx) + CUDA.unsafe_free!(dy) + + return y +end + +function gemv_batched!( + transA::Char, + alpha::Float64, + A::Union{BatchedCuMatrix{Float64},SharedCuMatrix{Float64}}, + x::Union{BatchedCuVector{Float64},SharedCuVector{Float64}}, + beta::Float64, + y::BatchedCuVector{Float64}, +) + N = batch_size(y) + m, n = size(unwrap_data(A), 1), size(unwrap_data(A), 2) + + dA = A isa BatchedCuMatrix ? 
create_pointer_array(A) : create_pointer_array(A, N) + dx = if x isa BatchedCuVector + create_pointer_array_vector(x) + else + create_pointer_array_vector(x, N) + end + dy = create_pointer_array_vector(y) + + ldda = m + incx = 1 + incy = 1 + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_dgemv_batched( + magma_trans(transA), m, n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue_ptr[] + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dx) + CUDA.unsafe_free!(dy) + + return y +end + +function gemv_batched_smallsq!( + transA::Char, + alpha::Float32, + A::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + x::Union{BatchedCuVector{Float32},SharedCuVector{Float32}}, + beta::Float32, + y::BatchedCuVector{Float32}, +) + N = batch_size(y) + n = size(unwrap_data(A), 1) + + dA = A isa BatchedCuMatrix ? create_pointer_array(A) : create_pointer_array(A, N) + dx = if x isa BatchedCuVector + create_pointer_array_vector(x) + else + create_pointer_array_vector(x, N) + end + dy = create_pointer_array_vector(y) + + ldda = n + incx = 1 + incy = 1 + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_sgemv_batched_smallsq( + magma_trans(transA), n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue_ptr[] + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dx) + CUDA.unsafe_free!(dy) + + return y +end + +function potrf_batched!(uplo::Char, A::BatchedCuMatrix{Float32}) + N = batch_size(A) + n = size(A.data, 1) + lda = n + + dA = create_pointer_array(A) + info_gpu = CUDA.zeros(Int64, N) + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magma_spotrf_batched( + magma_uplo(uplo), n, dA, lda, pointer(info_gpu), N, queue_ptr[] + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + + factors = BatchedCuMatrix{Float32,LowerTriangular{Float32,CuMatrix{Float32}}}(A.data) + return BatchedCholesky{Float32}(factors, info_gpu, uplo) +end + +function potrs_batched!( + uplo::Char, A::BatchedCuMatrix{Float32}, B::BatchedCuMatrix{Float32} +) + N = batch_size(B) + n = size(A.data, 1) + nrhs = size(B.data, 2) + + dA = create_pointer_array(A) + dB = create_pointer_array(B) + + ldda = n + lddb = n + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magma_spotrs_batched( + magma_uplo(uplo), n, nrhs, dA, ldda, dB, lddb, N, queue_ptr[] + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + + return B +end + +function trsm_batched!( + side::Char, + uplo::Char, + transA::Char, + diag::Char, + alpha::Float32, + A::BatchedCuMatrix{Float32}, + B::BatchedCuMatrix{Float32}, +) + N = batch_size(B) + m, n = size(B.data, 1), size(B.data, 2) + + dA = create_pointer_array(A) + dB = 
create_pointer_array(B) + + ldda = size(A.data, 1) + lddb = m + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_strsm_batched( + magma_side(side), + magma_uplo(uplo), + magma_trans(transA), + magma_diag(diag), + m, + n, + alpha, + dA, + ldda, + dB, + lddb, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + + return B +end + +# ============================================================================= +# Higher-level Cholesky Operations +# ============================================================================= + +function cholesky_batched(A::BatchedCuMatrix{T}) where {T} + A_copy = BatchedCuMatrix(copy(A.data)) + return potrf_batched!('L', A_copy) +end + +function pdmat_solve(S::BatchedPDMat{T}, B::BatchedCuMatrix{T}) where {T} + L = S.chol.factors + L_data = BatchedCuMatrix(L.data) + + B_copy = BatchedCuMatrix(copy(B.data)) + + # Solve L*L'*X = B via two triangular solves: + # 1. Solve L*Y = B (Y stored in B_copy) + trsm_batched!('L', 'L', 'N', 'N', one(T), L_data, B_copy) + # 2. Solve L'*X = Y (X stored in B_copy) + trsm_batched!('L', 'L', 'T', 'N', one(T), L_data, B_copy) + + return B_copy +end + +# ============================================================================= +# Batched TRMM (Triangular Matrix Multiply) +# ============================================================================= + +function trmm_batched!( + side::Char, + uplo::Char, + transA::Char, + diag::Char, + alpha::Float32, + A::BatchedCuMatrix{Float32}, + B::BatchedCuMatrix{Float32}, +) + N = batch_size(B) + m, n = size(B.data, 1), size(B.data, 2) + + dA = create_pointer_array(A) + dB = create_pointer_array(B) + + ldda = size(A.data, 1) + lddb = m + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_strmm_batched( + magma_side(side), + magma_uplo(uplo), + magma_trans(transA), + magma_diag(diag), + m, + n, + alpha, + dA, + ldda, + dB, + lddb, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + + return B +end + +function trmm_batched!( + side::Char, + uplo::Char, + transA::Char, + diag::Char, + alpha::Float64, + A::BatchedCuMatrix{Float64}, + B::BatchedCuMatrix{Float64}, +) + N = batch_size(B) + m, n = size(B.data, 1), size(B.data, 2) + + dA = create_pointer_array(A) + dB = create_pointer_array(B) + + ldda = size(A.data, 1) + lddb = m + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_dtrmm_batched( + magma_side(side), + magma_uplo(uplo), + magma_trans(transA), + magma_diag(diag), + m, + n, + alpha, + dA, + ldda, + dB, + lddb, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + + return B +end + +# ============================================================================= +# Batched SYRK (Symmetric Rank-K Update) +# ============================================================================= 
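+# syrk_batched! computes C := alpha*A*A' + beta*C (trans = 'N') or
+# C := alpha*A'*A + beta*C (trans = 'T') for each matrix in the batch,
+# writing only the `uplo` triangle of C. Callers that need the full dense
+# result must mirror the triangle afterwards, e.g. with symmetrize_lower!
+# below (see the X_A_Xt implementation in operations.jl).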
+ +function syrk_batched!( + uplo::Char, + trans::Char, + alpha::Float32, + A::BatchedCuMatrix{Float32}, + beta::Float32, + C::BatchedCuMatrix{Float32}, +) + N = batch_size(C) + n = size(C.data, 1) + k = trans == 'N' ? size(A.data, 2) : size(A.data, 1) + + dA = create_pointer_array(A) + dC = create_pointer_array(C) + + ldda = size(A.data, 1) + lddc = n + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_ssyrk_batched( + magma_uplo(uplo), + magma_trans(trans), + n, + k, + alpha, + dA, + ldda, + beta, + dC, + lddc, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dC) + + return C +end + +function syrk_batched!( + uplo::Char, + trans::Char, + alpha::Float64, + A::BatchedCuMatrix{Float64}, + beta::Float64, + C::BatchedCuMatrix{Float64}, +) + N = batch_size(C) + n = size(C.data, 1) + k = trans == 'N' ? size(A.data, 2) : size(A.data, 1) + + dA = create_pointer_array(A) + dC = create_pointer_array(C) + + ldda = size(A.data, 1) + lddc = n + + CUDA.synchronize() + queue_ptr = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + LibMagma.magmablas_dsyrk_batched( + magma_uplo(uplo), + magma_trans(trans), + n, + k, + alpha, + dA, + ldda, + beta, + dC, + lddc, + N, + queue_ptr[], + ) + LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dC) + + return C +end + +# ============================================================================= +# Symmetrize Lower Triangular Matrix +# ============================================================================= + +function symmetrize_lower_kernel!(A, n) + batch_idx = blockIdx().z + i = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + j = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y + + if i <= n && j <= n && j > i + @inbounds A[j, i, batch_idx] = A[i, j, batch_idx] + end + return nothing +end + +function symmetrize_lower!(A::BatchedCuMatrix{T}) where {T} + n = size(A.data, 1) + N = size(A.data, 3) + threads = (16, 16) + blocks = (cld(n, 16), cld(n, 16), N) + @cuda threads = threads blocks = blocks symmetrize_lower_kernel!(A.data, n) + return A +end diff --git a/research/batching/batching_demo.jl b/research/batching/batching_demo.jl new file mode 100644 index 00000000..a97effce --- /dev/null +++ b/research/batching/batching_demo.jl @@ -0,0 +1,113 @@ +using GeneralisedFilters + +using Distributions +using LinearAlgebra +using Base.Broadcast: broadcasted +using PDMats +using StructArrays +using BenchmarkTools + +using CUDA +using Magma +using Magma.LibMagma + +Magma.magma_init() + +# ============================================================================= +# Configuration +# ============================================================================= + +D_state = 64 +D_obs = 64 +N = 1000 + +function kalman_predict(state, dyn_params) + A = dyn_params[1] + b = dyn_params[2] + Q = dyn_params[3] + + μ̂ = A * state.μ + b + Σ̂ = X_A_Xt(state.Σ, A) + Q + return MvNormal(μ̂, Σ̂) +end + +I_mat = CuArray{Float32}(I, D_state, D_state) +Is = SharedCuMatrix(I_mat) + +μs = BatchedCuVector(CUDA.randn(Float32, D_state, N)) +Σs_root = BatchedCuMatrix(CUDA.randn(Float32, D_state, D_state, N)) +Σs = 
Σs_root .* adjoint.(Σs_root) .+ Is + +As = SharedCuMatrix(CUDA.randn(Float32, D_state, D_state)) +bs = BatchedCuVector(CUDA.randn(Float32, D_state, N)) +Q_root = CUDA.randn(Float32, D_state, D_state) +Q = Q_root * Q_root' + I +Qs = SharedCuMatrix(Q) + +Σ_PDs = broadcasted(PDMat, Σs); +Gs = StructArray{MvNormal}((μ=μs, Σ=Σ_PDs)); + +function kalman_predict(state, dyn_params) + A = dyn_params[1] + b = dyn_params[2] + Q = dyn_params[3] + + μ̂ = A * state.μ + b + Σ̂ = X_A_Xt(state.Σ, A) + Q + return MvNormal(μ̂, Σ̂) +end + +dyn_params = (As, bs, Qs) +pred_Gs = kalman_predict.(Gs, Ref(dyn_params)); + +# Compare to CPU +μ_test = Array(μs[end]) +Σ_test = Array(Σs[end]) +A_test = Array(As.data) +b_test = Array(bs[end]) +Q_test = Array(Qs.data) +pred_G_test = kalman_predict(MvNormal(μ_test, PDMat(Σ_test)), (A_test, b_test, Q_test)) + +println("CPU Mean: ", pred_G_test.μ[1:5]) +println("GPU Mean: ", Array(pred_Gs[end].μ[1:5])) + +println("CPU Covariance Diagonal: ", diag(pred_G_test.Σ)[1:5]) +println("GPU Covariance Diagonal: ", Array(diag(pred_Gs[end].Σ))[1:5]) + +# Increase batch size and benchmark +D_large = 32 +N_large = 10000 +μs_large = BatchedCuVector(CUDA.randn(Float32, D_large, N_large)) +Σs_root_large = BatchedCuMatrix(CUDA.randn(Float32, D_large, D_large, N_large)) +Σs_large = Σs_root_large .* adjoint.(Σs_root_large) .+ SharedCuMatrix(CuArray{Float32}(I, D_large, D_large)) +Σ_PDs_large = broadcasted(PDMat, Σs_large); +Gs_large = StructArray{MvNormal}((μ=μs_large, Σ=Σ_PDs_large)); +dyn_params_large = ( + SharedCuMatrix(CUDA.randn(Float32, D_large, D_large)), + BatchedCuVector(CUDA.randn(Float32, D_large, N_large)), + SharedCuMatrix((CUDA.randn(Float32, D_large, D_large) * CUDA.randn(Float32, D_large, D_large)') .+ CuArray{Float32}(I, D_large, D_large)), +) +display(@benchmark kalman_predict.($Gs_large, Ref($dyn_params_large))) + +# Compare to multithreading StaticArrays +using StaticArrays +μs_static = [SVector{D_large, Float32}(randn(Float32, D_large)) for _ in 1:N_large]; +Σs_root_static = [SMatrix{D_large,D_large,Float32}(randn(Float32, D_large, D_large)) for _ in 1:N_large]; +Σs_static = [Σs_root_static[i] * adjoint(Σs_root_static[i]) + I for i in 1:N_large]; +Gs_static = [MvNormal(μs_static[i], Σs_static[i]) for i in 1:N_large]; +A_static = SMatrix{D_large,D_large,Float32}(randn(Float32, D_large, D_large)); +b_static = [SVector{D_large, Float32}(randn(Float32, D_large)) for _ in 1:N_large]; +Q_root_static = SMatrix{D_large,D_large,Float32}(randn(Float32, D_large, D_large)); +Q_static = Q_root_static * adjoint(Q_root_static) + I; + +function test_static(Gs, A, b, Q) + out = Vector{MvNormal{Float32, PDMat{Float32, SMatrix{32, 32, Float32, 1024}}, SVector{32, Float32}}}(undef, length(Gs)) + for i in 1:length(Gs) + @inbounds out[i] = kalman_predict(Gs[i], (A, b[i], Q)) + end + return out +end + +display(@benchmark test_static($Gs_static, $A_static, $b_static, $Q_static)) + +@profview test_static(Gs_static, A_static, b_static, Q_static) From 36d4d0b0e9b4d8b8256e02889ef695539033bf15 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 6 Jan 2026 16:08:47 +0000 Subject: [PATCH 13/29] Extend demo to full Kalman filter --- research/batching/batching_demo.jl | 122 ++++++++++++++++++++--------- 1 file changed, 85 insertions(+), 37 deletions(-) diff --git a/research/batching/batching_demo.jl b/research/batching/batching_demo.jl index a97effce..d2841061 100644 --- a/research/batching/batching_demo.jl +++ b/research/batching/batching_demo.jl @@ -68,46 +68,94 @@ b_test = Array(bs[end]) Q_test = 
Array(Qs.data) pred_G_test = kalman_predict(MvNormal(μ_test, PDMat(Σ_test)), (A_test, b_test, Q_test)) +println("=== Predict Comparison ===") println("CPU Mean: ", pred_G_test.μ[1:5]) println("GPU Mean: ", Array(pred_Gs[end].μ[1:5])) -println("CPU Covariance Diagonal: ", diag(pred_G_test.Σ)[1:5]) -println("GPU Covariance Diagonal: ", Array(diag(pred_Gs[end].Σ))[1:5]) - -# Increase batch size and benchmark -D_large = 32 -N_large = 10000 -μs_large = BatchedCuVector(CUDA.randn(Float32, D_large, N_large)) -Σs_root_large = BatchedCuMatrix(CUDA.randn(Float32, D_large, D_large, N_large)) -Σs_large = Σs_root_large .* adjoint.(Σs_root_large) .+ SharedCuMatrix(CuArray{Float32}(I, D_large, D_large)) -Σ_PDs_large = broadcasted(PDMat, Σs_large); -Gs_large = StructArray{MvNormal}((μ=μs_large, Σ=Σ_PDs_large)); -dyn_params_large = ( - SharedCuMatrix(CUDA.randn(Float32, D_large, D_large)), - BatchedCuVector(CUDA.randn(Float32, D_large, N_large)), - SharedCuMatrix((CUDA.randn(Float32, D_large, D_large) * CUDA.randn(Float32, D_large, D_large)') .+ CuArray{Float32}(I, D_large, D_large)), -) -display(@benchmark kalman_predict.($Gs_large, Ref($dyn_params_large))) - -# Compare to multithreading StaticArrays -using StaticArrays -μs_static = [SVector{D_large, Float32}(randn(Float32, D_large)) for _ in 1:N_large]; -Σs_root_static = [SMatrix{D_large,D_large,Float32}(randn(Float32, D_large, D_large)) for _ in 1:N_large]; -Σs_static = [Σs_root_static[i] * adjoint(Σs_root_static[i]) + I for i in 1:N_large]; -Gs_static = [MvNormal(μs_static[i], Σs_static[i]) for i in 1:N_large]; -A_static = SMatrix{D_large,D_large,Float32}(randn(Float32, D_large, D_large)); -b_static = [SVector{D_large, Float32}(randn(Float32, D_large)) for _ in 1:N_large]; -Q_root_static = SMatrix{D_large,D_large,Float32}(randn(Float32, D_large, D_large)); -Q_static = Q_root_static * adjoint(Q_root_static) + I; - -function test_static(Gs, A, b, Q) - out = Vector{MvNormal{Float32, PDMat{Float32, SMatrix{32, 32, Float32, 1024}}, SVector{32, Float32}}}(undef, length(Gs)) - for i in 1:length(Gs) - @inbounds out[i] = kalman_predict(Gs[i], (A, b[i], Q)) - end - return out +println("CPU Covariance [1:3, 1:3]: ", Matrix(pred_G_test.Σ)[1:3, 1:3]) +println("GPU Covariance [1:3, 1:3]: ", Array(pred_Gs[end].Σ)[1:3, 1:3]) + +# ============================================================================= +# Kalman Update +# ============================================================================= + +function kalman_update(state, obs_params, observation) + μ = state.μ + Σ = state.Σ + H = obs_params[1] + c = obs_params[2] + R = obs_params[3] + + # Compute innovation distribution + m = H * μ + c + S = PDMat(X_A_Xt(Σ, H) + R) + ȳ = observation - m + + # Kalman gain + K = Σ * H' / S + + # Update parameters using Joseph form for numerical stability + μ̂ = μ + K * ȳ + Σ̂ = X_A_Xt(Σ, I - K * H) + X_A_Xt(R, K) + + return MvNormal(μ̂, Σ̂) end -display(@benchmark test_static($Gs_static, $A_static, $b_static, $Q_static)) +function kalman_step(state, dyn_params, obs_params, observation) + state = kalman_predict(state, dyn_params) + state = kalman_update(state, obs_params, observation) + return state +end + +# Observation parameters (H and c shared, R batched) +Hs = SharedCuMatrix(CUDA.randn(Float32, D_obs, D_state)) +cs = SharedCuVector(CUDA.randn(Float32, D_obs)) +I_obs = CuArray{Float32}(I, D_obs, D_obs) +I_obs_shared = SharedCuMatrix(I_obs) +Rs_root = BatchedCuMatrix(CUDA.randn(Float32, D_obs, D_obs, N)) +Rs = Rs_root .* adjoint.(Rs_root) .+ I_obs_shared + +obs_params = 
(Hs, cs, Rs) + +# Observations +observations = BatchedCuVector(CUDA.randn(Float32, D_obs, N)) + +# Run update on GPU +update_Gs = kalman_update.(pred_Gs, Ref(obs_params), observations); + +# Compare update to CPU +H_test = Array(Hs.data) +c_test = Array(cs.data) +R_test = PDMat(Array(Rs[end])) +obs_test = Array(observations[end]) + +update_G_test = kalman_update(pred_G_test, (H_test, c_test, R_test), obs_test) + +println("\n=== Update Comparison ===") +println("CPU Mean: ", update_G_test.μ[1:5]) +println("GPU Mean: ", Array(update_Gs.μ[end][1:5])) + +println("CPU Covariance [1:3, 1:3]: ", Matrix(update_G_test.Σ)[1:3, 1:3]) +println("GPU Covariance [1:3, 1:3]: ", Array(update_Gs.Σ[end])[1:3, 1:3]) + +# ============================================================================= +# Full Kalman Step +# ============================================================================= + +# Run full step on GPU (from original state) +step_Gs = kalman_step.(Gs, Ref(dyn_params), Ref(obs_params), observations); + +# Compare full step to CPU +step_G_test = kalman_step( + MvNormal(μ_test, PDMat(Σ_test)), + (A_test, b_test, Q_test), + (H_test, c_test, R_test), + obs_test, +) + +println("\n=== Full Step Comparison ===") +println("CPU Mean: ", step_G_test.μ[1:5]) +println("GPU Mean: ", Array(step_Gs.μ[end][1:5])) -@profview test_static(Gs_static, A_static, b_static, Q_static) +println("CPU Covariance [1:3, 1:3]: ", Matrix(step_G_test.Σ)[1:3, 1:3]) +println("GPU Covariance [1:3, 1:3]: ", Array(step_Gs.Σ[end])[1:3, 1:3]) From 4d10f18bee989ba4a2e8d69051f000af86a1d1c5 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 7 Jan 2026 08:29:57 +0000 Subject: [PATCH 14/29] Fix formatting --- GeneralisedFilters/src/batching/operations.jl | 4 +--- GeneralisedFilters/src/batching/types.jl | 10 +++++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/GeneralisedFilters/src/batching/operations.jl b/GeneralisedFilters/src/batching/operations.jl index 8464fc61..5d138155 100644 --- a/GeneralisedFilters/src/batching/operations.jl +++ b/GeneralisedFilters/src/batching/operations.jl @@ -180,9 +180,7 @@ end # X_A_Xt for BatchedPDMat: X * P * X' where P = L * L' # Computed as (X * L) * (X * L)' using TRMM and SYRK function broadcasted( - ::typeof(X_A_Xt), - P::BatchedPDMat{T}, - X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + ::typeof(X_A_Xt), P::BatchedPDMat{T}, X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}} ) where {T} L = P.chol.factors N = get_batch_size(P, X) diff --git a/GeneralisedFilters/src/batching/types.jl b/GeneralisedFilters/src/batching/types.jl index 7ae2b399..02be3464 100644 --- a/GeneralisedFilters/src/batching/types.jl +++ b/GeneralisedFilters/src/batching/types.jl @@ -39,27 +39,27 @@ function Base.getindex(x::BatchedCuVector{T,CuVector{T}}, i::Int) where {T} end function Base.getindex(x::BatchedCuMatrix{T,CuMatrix{T}}, i::Int) where {T} - return view(x.data, :, :, i) + return view(x.data,:,:,i) end function Base.getindex( x::BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, i::Int ) where {T} - return LowerTriangular(view(x.data, :, :, i)) + return LowerTriangular(view(x.data,:,:,i)) end function Base.getindex( x::BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}, i::Int ) where {T} - return UpperTriangular(view(x.data, :, :, i)) + return UpperTriangular(view(x.data,:,:,i)) end function Base.getindex(x::BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}, i::Int) where {T} - return adjoint(view(x.data, :, :, i)) + return adjoint(view(x.data,:,:,i)) end function 
Base.getindex(x::BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}, i::Int) where {T} - return transpose(view(x.data, :, :, i)) + return transpose(view(x.data,:,:,i)) end # ============================================================================= From 89e3aa9b965421ded482985b5349f28cc09ef9e0 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 7 Jan 2026 11:17:43 +0000 Subject: [PATCH 15/29] Implement generic wrapping system Removes parametric batched types and custom Cholesky/PDMat batched types and replaces them with a generic system based on StructArrays. --- .../src/batching/broadcasting.jl | 57 ++--- GeneralisedFilters/src/batching/operations.jl | 204 ++++++++++-------- GeneralisedFilters/src/batching/types.jl | 122 ++--------- GeneralisedFilters/src/batching/wrappers.jl | 47 ++-- research/batching/wrappers_demo.jl | 146 +++++++++++++ 5 files changed, 338 insertions(+), 238 deletions(-) create mode 100644 research/batching/wrappers_demo.jl diff --git a/GeneralisedFilters/src/batching/broadcasting.jl b/GeneralisedFilters/src/batching/broadcasting.jl index 92d50811..8ffda2a0 100644 --- a/GeneralisedFilters/src/batching/broadcasting.jl +++ b/GeneralisedFilters/src/batching/broadcasting.jl @@ -18,9 +18,6 @@ Base.BroadcastStyle(::Type{<:BatchedCuMatrix}) = BatchedStyle() Base.BroadcastStyle(::Type{<:BatchedCuVector}) = BatchedStyle() Base.BroadcastStyle(::Type{<:SharedCuMatrix}) = BatchedStyle() Base.BroadcastStyle(::Type{<:SharedCuVector}) = BatchedStyle() -Base.BroadcastStyle(::Type{<:BatchedCholesky}) = BatchedStyle() -Base.BroadcastStyle(::Type{<:BatchedPDMat}) = BatchedStyle() -# HACK: Currently hard-coded but can be replaced with a custom StructArray type Base.BroadcastStyle(::Type{<:StructArray}) = BatchedStyle() Base.BroadcastStyle(::BatchedStyle, ::BatchedStyle) = BatchedStyle() Base.BroadcastStyle(::BatchedStyle, ::DefaultArrayStyle{0}) = BatchedStyle() @@ -30,12 +27,8 @@ Base.BroadcastStyle(::BatchedStyle, ::DefaultArrayStyle{0}) = BatchedStyle() # ============================================================================= maybe_convert_ref(x) = x -function maybe_convert_ref(r::Base.RefValue{<:CuVector{T}}) where {T} - return SharedCuVector{T,CuVector{T}}(r[]) -end -function maybe_convert_ref(r::Base.RefValue{<:CuMatrix{T}}) where {T} - return SharedCuMatrix{T,CuMatrix{T}}(r[]) -end +maybe_convert_ref(r::Base.RefValue{<:CuVector}) = SharedCuVector(r[]) +maybe_convert_ref(r::Base.RefValue{<:CuMatrix}) = SharedCuMatrix(r[]) # ============================================================================= # Structural Operations (Pass-through) @@ -54,15 +47,12 @@ broadcasted(::typeof(getfield), r::Base.RefValue, s::Symbol) = getfield(r[], s) # StructArray Wrapping # ============================================================================= -inner_eltype(arg::BatchedCuVector{T}) where {T} = CuVector{T} -inner_eltype(arg::BatchedCuMatrix{T}) where {T} = CuMatrix{T} -inner_eltype(arg::SharedCuVector{T}) where {T} = CuVector{T} -inner_eltype(arg::SharedCuMatrix{T}) where {T} = CuMatrix{T} -inner_eltype(arg::BatchedPDMat{T}) where {T} = PDMat{T,CuMatrix{T}} +inner_eltype(arg::BatchedOrShared) = eltype(arg) +inner_eltype(arg::StructArray) = eltype(arg) inner_eltype(arg) = typeof(arg) function wrap_if_batched(::Type{T}, args...) 
where {T} - if any(arg -> arg isa Union{BatchedArray,SharedArray,BatchedPDMat}, args) + if any(arg -> arg isa Union{BatchedOrShared,StructArray}, args) field_names = fieldnames(T) element_types = Tuple{map(inner_eltype, args)...} ElType = Core.Compiler.return_type(T, element_types) @@ -73,20 +63,37 @@ function wrap_if_batched(::Type{T}, args...) where {T} end end +# ============================================================================= +# Generic Wrapper Broadcasting +# ============================================================================= + +""" + broadcasted(::Type{W}, args::BatchedOrShared...) + +Generic wrapper for any type constructor applied to batched arguments. +Works for single-field wrappers (Adjoint, Transpose, LowerTriangular, etc.) +as well as multi-field types (PDMat, Cholesky, etc.). + +Returns a StructArray where each element is the type applied to the +corresponding elements of the input arrays. +""" +function broadcasted(::Type{W}, args::Vararg{BatchedOrShared}) where {W} + element_types = Tuple{map(eltype, args)...} + ElType = Core.Compiler.return_type(W, element_types) + field_names = fieldnames(ElType) + nt = NamedTuple{field_names}(args) + return StructArray{ElType}(nt) +end + +# Redirect function forms to type constructors +broadcasted(::typeof(adjoint), A::BatchedOrShared) = broadcasted(Adjoint, A) +broadcasted(::typeof(transpose), A::BatchedOrShared) = broadcasted(Transpose, A) + # ============================================================================= # IR Transformation # ============================================================================= -const SKIP_BROADCAST = Set{Any}([ - tuple, - Core.tuple, - getfield, - getproperty, - adjoint, - transpose, - LowerTriangular, - UpperTriangular, -]) +const SKIP_BROADCAST = Set{Any}([tuple, Core.tuple, getfield, getproperty]) const BROADCAST_TYPES = Set{Any}([PDMat]) diff --git a/GeneralisedFilters/src/batching/operations.jl b/GeneralisedFilters/src/batching/operations.jl index 5d138155..b1fc55d1 100644 --- a/GeneralisedFilters/src/batching/operations.jl +++ b/GeneralisedFilters/src/batching/operations.jl @@ -1,53 +1,71 @@ import PDMats: X_A_Xt # ============================================================================= -# Adjoint/Transpose Broadcasting +# GEMM-Compatible Types # ============================================================================= -function broadcasted(::typeof(adjoint), A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} - return BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}(A.data) -end - -function broadcasted(::typeof(transpose), A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} - return BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}(A.data) -end - -function broadcasted(::typeof(adjoint), A::SharedCuMatrix{T,CuMatrix{T}}) where {T} - return SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}}(A.data) -end - -function broadcasted(::typeof(transpose), A::SharedCuMatrix{T,CuMatrix{T}}) where {T} - return SharedCuMatrix{T,Transpose{T,CuMatrix{T}}}(A.data) -end - -function broadcasted( - ::typeof(adjoint), A::BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}} -) where {T} - return BatchedCuMatrix{T,CuMatrix{T}}(A.data) -end - -function broadcasted( - ::typeof(adjoint), A::SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}} -) where {T} - return SharedCuMatrix{T,CuMatrix{T}}(A.data) -end - -function broadcasted(::Type{LowerTriangular}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} - return BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}(A.data) -end - -function broadcasted(::Type{UpperTriangular}, 
A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} - return BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}(A.data) -end +# Type aliases for StructArray-wrapped matrices +const BatchedAdjoint{T,M} = StructArray{ + Adjoint{T,CuArray{T,2,M}},1,@NamedTuple{parent::BatchedCuMatrix{T,M}} +} +const BatchedTranspose{T,M} = StructArray{ + Transpose{T,CuArray{T,2,M}},1,@NamedTuple{parent::BatchedCuMatrix{T,M}} +} +const SharedAdjoint{T,M} = StructArray{ + Adjoint{T,CuArray{T,2,M}},1,@NamedTuple{parent::SharedCuMatrix{T,M}} +} +const SharedTranspose{T,M} = StructArray{ + Transpose{T,CuArray{T,2,M}},1,@NamedTuple{parent::SharedCuMatrix{T,M}} +} + +# Union of all GEMM-compatible matrix types +const GEMMCompatibleMatrix{T} = Union{ + BatchedCuMatrix{T}, + SharedCuMatrix{T}, + BatchedAdjoint{T}, + BatchedTranspose{T}, + SharedAdjoint{T}, + SharedTranspose{T}, +} + +# trans_flag: returns BLAS transpose flag for each type +trans_flag(::BatchedCuMatrix{T}) where {T} = 'N' +trans_flag(::SharedCuMatrix{T}) where {T} = 'N' +trans_flag(::BatchedAdjoint{T}) where {T} = T <: Real ? 'T' : 'C' +trans_flag(::BatchedTranspose{T}) where {T} = 'T' +trans_flag(::SharedAdjoint{T}) where {T} = T <: Real ? 'T' : 'C' +trans_flag(::SharedTranspose{T}) where {T} = 'T' + +# gemm_data: extracts the underlying BatchedCuMatrix/SharedCuMatrix for GEMM +gemm_data(A::BatchedCuMatrix) = A +gemm_data(A::SharedCuMatrix) = A +gemm_data(A::BatchedAdjoint) = A.parent +gemm_data(A::BatchedTranspose) = A.parent +gemm_data(A::SharedAdjoint) = A.parent +gemm_data(A::SharedTranspose) = A.parent + +# inner_size_for_blas for wrapped types (delegates to underlying data) +inner_size_for_blas(A::BatchedAdjoint) = inner_size_for_blas(A.parent) +inner_size_for_blas(A::BatchedTranspose) = inner_size_for_blas(A.parent) +inner_size_for_blas(A::SharedAdjoint) = inner_size_for_blas(A.parent) +inner_size_for_blas(A::SharedTranspose) = inner_size_for_blas(A.parent) + +# batch_size for wrapped types +batch_size(A::BatchedAdjoint) = batch_size(A.parent) +batch_size(A::BatchedTranspose) = batch_size(A.parent) +batch_size(A::SharedAdjoint) = batch_size(A.parent) +batch_size(A::SharedTranspose) = batch_size(A.parent) + +# TODO: For nested wrappers (e.g., Adjoint{LowerTriangular{...}}), we should +# materialize the inner wrapper first before extracting. For now, we only +# support single-level Adjoint/Transpose wrappers for efficient GEMM dispatch. 
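+# Example (a minimal sketch; dimensions chosen arbitrarily): `adjoint.(A)`
+# only builds a StructArray view over A's storage, so the product below
+# reaches gemm_batched! as a single call, with transA = 'T' and the
+# untransposed buffer supplied by gemm_data:
+#
+#     A = BatchedCuMatrix(CUDA.randn(Float32, 4, 4, 8))
+#     B = BatchedCuMatrix(CUDA.randn(Float32, 4, 4, 8))
+#     C = adjoint.(A) .* B  # one batched GEMM; no copy made for the adjoint
+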
# ============================================================================= # Matrix Multiply Broadcasting # ============================================================================= function broadcasted( - ::typeof(*), - A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, - B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + ::typeof(*), A::GEMMCompatibleMatrix{T}, B::GEMMCompatibleMatrix{T} ) where {T} transA = trans_flag(A) transB = trans_flag(B) @@ -62,17 +80,17 @@ function broadcasted( C_data = CuArray{T}(undef, m, n, N) C = BatchedCuMatrix(C_data) - gemm_batched!(transA, transB, one(T), A, B, zero(T), C) + gemm_batched!(transA, transB, one(T), gemm_data(A), gemm_data(B), zero(T), C) return C end # Multi-argument multiply function broadcasted( ::typeof(*), - A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, - B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, - C::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, - rest::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}..., + A::GEMMCompatibleMatrix{T}, + B::GEMMCompatibleMatrix{T}, + C::GEMMCompatibleMatrix{T}, + rest::GEMMCompatibleMatrix{T}..., ) where {T} result = broadcasted(*, A, B) result = broadcasted(*, result, C) @@ -143,24 +161,24 @@ end # PDMat Broadcasting # ============================================================================= -function broadcasted(::Type{PDMat}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} - chol = cholesky_batched(A) - return BatchedPDMat{T}(chol) -end - -function broadcasted(::typeof(\), S::BatchedPDMat{T}, A::BatchedCuMatrix{T}) where {T} - return pdmat_solve(S, A) -end - -function broadcasted(::typeof(/), A::BatchedCuMatrix{T}, S::BatchedPDMat{T}) where {T} - # Need to actually transpose the data, not just wrap it - At_data = permutedims(A.data, (2, 1, 3)) - At = BatchedCuMatrix(At_data) - result_t = pdmat_solve(S, At) - # Transpose back - result_data = permutedims(result_t.data, (2, 1, 3)) - return BatchedCuMatrix(result_data) -end +# function broadcasted(::Type{PDMat}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T} +# chol = cholesky_batched(A) +# return BatchedPDMat{T}(chol) +# end + +# function broadcasted(::typeof(\), S::BatchedPDMat{T}, A::BatchedCuMatrix{T}) where {T} +# return pdmat_solve(S, A) +# end + +# function broadcasted(::typeof(/), A::BatchedCuMatrix{T}, S::BatchedPDMat{T}) where {T} +# # Need to actually transpose the data, not just wrap it +# At_data = permutedims(A.data, (2, 1, 3)) +# At = BatchedCuMatrix(At_data) +# result_t = pdmat_solve(S, At) +# # Transpose back +# result_data = permutedims(result_t.data, (2, 1, 3)) +# return BatchedCuMatrix(result_data) +# end # ============================================================================= # Quadratic Form Broadcasting @@ -179,34 +197,34 @@ end # X_A_Xt for BatchedPDMat: X * P * X' where P = L * L' # Computed as (X * L) * (X * L)' using TRMM and SYRK -function broadcasted( - ::typeof(X_A_Xt), P::BatchedPDMat{T}, X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}} -) where {T} - L = P.chol.factors - N = get_batch_size(P, X) - - X_inner = inner_size_for_blas(X) - m = X_inner[1] - - # Copy X to XL (TRMM overwrites in-place) - XL_data = if X isa SharedCuMatrix - repeat(reshape(X.data, size(X.data, 1), size(X.data, 2), 1), 1, 1, N) - else - copy(X.data) - end - XL = BatchedCuMatrix(XL_data) - - # XL = X * L using TRMM (side='R' for right multiply, uplo='L' for lower triangular) - L_data = BatchedCuMatrix(L.data) - trmm_batched!('R', 'L', 'N', 'N', one(T), L_data, XL) - - # Result = XL * XL' using SYRK (fills lower triangle) - Result_data = 
CuArray{T}(undef, m, m, N)
-    Result = BatchedCuMatrix(Result_data)
-    syrk_batched!('L', 'N', one(T), XL, zero(T), Result)
-
-    # Symmetrize: copy lower triangle to upper
-    symmetrize_lower!(Result)
-
-    return Result
-end
+# function broadcasted(
+#     ::typeof(X_A_Xt), P::BatchedPDMat{T}, X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}
+# ) where {T}
+#     L = P.chol.factors
+#     N = get_batch_size(P, X)
+
+#     X_inner = inner_size_for_blas(X)
+#     m = X_inner[1]
+
+#     # Copy X to XL (TRMM overwrites in-place)
+#     XL_data = if X isa SharedCuMatrix
+#         repeat(reshape(X.data, size(X.data, 1), size(X.data, 2), 1), 1, 1, N)
+#     else
+#         copy(X.data)
+#     end
+#     XL = BatchedCuMatrix(XL_data)
+
+#     # XL = X * L using TRMM (side='R' for right multiply, uplo='L' for lower triangular)
+#     L_data = BatchedCuMatrix(L.data)
+#     trmm_batched!('R', 'L', 'N', 'N', one(T), L_data, XL)
+
+#     # Result = XL * XL' using SYRK (fills lower triangle)
+#     Result_data = CuArray{T}(undef, m, m, N)
+#     Result = BatchedCuMatrix(Result_data)
+#     syrk_batched!('L', 'N', one(T), XL, zero(T), Result)
+
+#     # Symmetrize: copy lower triangle to upper
+#     symmetrize_lower!(Result)
+
+#     return Result
+# end
diff --git a/GeneralisedFilters/src/batching/types.jl b/GeneralisedFilters/src/batching/types.jl
index 02be3464..8efc0bd1 100644
--- a/GeneralisedFilters/src/batching/types.jl
+++ b/GeneralisedFilters/src/batching/types.jl
@@ -2,28 +2,25 @@ using CUDA
 using LinearAlgebra:
     Adjoint, Transpose, LowerTriangular, UpperTriangular, UniformScaling, Cholesky
 using PDMats: PDMat
+using StructArrays: StructArray
 
 export BatchedCuMatrix, BatchedCuVector
 export SharedCuMatrix, SharedCuVector
-export BatchedPDMat, BatchedCholesky
 
 # =============================================================================
 # Core Batched Types
 # =============================================================================
 
-struct BatchedCuMatrix{T,Inner<:AbstractMatrix{T}} <: AbstractVector{Inner}
-    data::CuArray{T,3}
+struct BatchedCuMatrix{T,M} <: AbstractVector{CuArray{T,2,M}}
+    data::CuArray{T,3,M}
 end
 
-struct BatchedCuVector{T,Inner<:AbstractVector{T}} <: AbstractVector{Inner}
-    data::CuMatrix{T}
+struct BatchedCuVector{T,M} <: AbstractVector{CuArray{T,1,M}}
+    data::CuArray{T,2,M}
 end
 
 const BatchedArray = Union{BatchedCuVector,BatchedCuMatrix}
 
-BatchedCuMatrix(data::CuArray{T,3}) where {T} = BatchedCuMatrix{T,CuMatrix{T}}(data)
-BatchedCuVector(data::CuMatrix{T}) where {T} = BatchedCuVector{T,CuVector{T}}(data)
-
 batch_size(x::BatchedCuVector) = size(x.data, 2)
 batch_size(x::BatchedCuMatrix) = size(x.data, 3)
 
@@ -34,53 +31,29 @@ Base.length(x::BatchedArray) = batch_size(x)
 inner_size(x::BatchedCuVector) = (size(x.data, 1),)
 inner_size(x::BatchedCuMatrix) = (size(x.data, 1), size(x.data, 2))
 
-function Base.getindex(x::BatchedCuVector{T,CuVector{T}}, i::Int) where {T}
-    return view(x.data, :, i)
-end
-
-function Base.getindex(x::BatchedCuMatrix{T,CuMatrix{T}}, i::Int) where {T}
-    return view(x.data, :, :, i)
-end
-
-function Base.getindex(
-    x::BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, i::Int
-) where {T}
-    return LowerTriangular(view(x.data, :, :, i))
-end
-
-function Base.getindex(
-    x::BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}, i::Int
-) where {T}
-    return UpperTriangular(view(x.data, :, :, i))
-end
-
-function Base.getindex(x::BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}, i::Int) where {T}
-    return adjoint(view(x.data, :, :, i))
-end
-
-function Base.getindex(x::BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}, i::Int) where {T}
-    return transpose(view(x.data, :, :, i))
-end
+Base.getindex(x::BatchedCuVector, i::Int) = view(x.data, :, i)
+Base.getindex(x::BatchedCuMatrix, i::Int) = view(x.data, :, :, i)
 
 # =============================================================================
 # Shared Types (same data reused across all batch elements)
 # =============================================================================
 
-struct SharedCuMatrix{T,Inner<:AbstractMatrix{T}} <: AbstractVector{Inner}
-    data::CuMatrix{T}
+struct SharedCuMatrix{T,M} <: AbstractVector{CuArray{T,2,M}}
+    data::CuArray{T,2,M}
 end
 
-struct SharedCuVector{T,Inner<:AbstractVector{T}} <: AbstractVector{Inner}
-    data::CuVector{T}
+struct SharedCuVector{T,M} <: AbstractVector{CuArray{T,1,M}}
+    data::CuArray{T,1,M}
 end
 
 const SharedArray = Union{SharedCuVector,SharedCuMatrix}
 
-SharedCuMatrix(data::CuMatrix{T}) where {T} = SharedCuMatrix{T,CuMatrix{T}}(data)
-SharedCuVector(data::CuVector{T}) where {T} = SharedCuVector{T,CuVector{T}}(data)
+# Convenience constructors that infer memory type
+SharedCuMatrix(data::CuArray{T,2,M}) where {T,M} = SharedCuMatrix{T,M}(data)
+SharedCuVector(data::CuArray{T,1,M}) where {T,M} = SharedCuVector{T,M}(data)
 
-Shared(x::CuMatrix{T}) where {T} = SharedCuMatrix(x)
-Shared(x::CuVector{T}) where {T} = SharedCuVector(x)
+Shared(x::CuArray{T,2,M}) where {T,M} = SharedCuMatrix{T,M}(x)
+Shared(x::CuArray{T,1,M}) where {T,M} = SharedCuVector{T,M}(x)
 
 batch_size(::SharedCuVector) = nothing
 batch_size(::SharedCuMatrix) = nothing
@@ -93,34 +66,16 @@ Base.size(x::SharedCuMatrix) = (1,)
 Base.length(::SharedArray) = 1
 
 Base.getindex(x::SharedCuVector, ::Int) = x.data
-Base.getindex(x::SharedCuMatrix{T,CuMatrix{T}}, ::Int) where {T} = x.data
-function Base.getindex(x::SharedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}, ::Int) where {T}
-    return LowerTriangular(x.data)
-end
+Base.getindex(x::SharedCuMatrix, ::Int) = x.data
 
 # =============================================================================
-# Type Aliases and Union Types for Dispatch
+# Union Types for Dispatch
 # =============================================================================
 
-const AnyBatchedMatrix{T} = Union{
-    BatchedCuMatrix{T,CuMatrix{T}},
-    BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}},
-    BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}},
-    BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}},
-    BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}},
-}
-
-const AnySharedMatrix{T} = Union{
-    SharedCuMatrix{T,CuMatrix{T}},
-    SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}},
-    SharedCuMatrix{T,Transpose{T,CuMatrix{T}}},
-    SharedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}},
-    SharedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}},
+const BatchedOrShared = Union{
+    BatchedCuMatrix,BatchedCuVector,SharedCuMatrix,SharedCuVector,StructArray
 }
 
-const AnyMatrix{T} = Union{AnyBatchedMatrix{T},AnySharedMatrix{T}}
-const AnyVector{T} = Union{BatchedCuVector{T},SharedCuVector{T}}
-
 # =============================================================================
 # Helper Functions
 # =============================================================================
@@ -135,18 +90,6 @@ unwrap_data(A::SharedCuMatrix) = A.data
 unwrap_data(x::BatchedCuVector) = x.data
 unwrap_data(x::SharedCuVector) = x.data
 
-trans_flag(::BatchedCuMatrix{T,CuMatrix{T}}) where {T} = 'N'
-trans_flag(::BatchedCuMatrix{T,Adjoint{T,CuMatrix{T}}}) where {T} = T <: Real ?
'T' : 'C' -trans_flag(::BatchedCuMatrix{T,Transpose{T,CuMatrix{T}}}) where {T} = 'T' -trans_flag(::BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}) where {T} = 'N' -trans_flag(::BatchedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}) where {T} = 'N' - -trans_flag(::SharedCuMatrix{T,CuMatrix{T}}) where {T} = 'N' -trans_flag(::SharedCuMatrix{T,Adjoint{T,CuMatrix{T}}}) where {T} = T <: Real ? 'T' : 'C' -trans_flag(::SharedCuMatrix{T,Transpose{T,CuMatrix{T}}}) where {T} = 'T' -trans_flag(::SharedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}}) where {T} = 'N' -trans_flag(::SharedCuMatrix{T,UpperTriangular{T,CuMatrix{T}}}) where {T} = 'N' - function inner_size_for_blas(A::BatchedCuMatrix) m, n = size(A.data, 1), size(A.data, 2) return (m, n) @@ -167,31 +110,6 @@ function get_batch_size(args...) return error("At least one argument must be batched") end -# ============================================================================= -# Stateful Wrapper Types (Cholesky, PDMat) -# ============================================================================= - -struct BatchedCholesky{T} <: AbstractVector{Cholesky{T,CuMatrix{T}}} - factors::BatchedCuMatrix{T,LowerTriangular{T,CuMatrix{T}}} - info::CuVector{Int32} - uplo::Char -end - -struct BatchedPDMat{T} <: AbstractVector{PDMat{T,CuMatrix{T}}} - chol::BatchedCholesky{T} -end - -batch_size(c::BatchedCholesky) = batch_size(c.factors) -batch_size(p::BatchedPDMat) = batch_size(p.chol) - -inner_size(c::BatchedCholesky) = inner_size(c.factors) -inner_size(p::BatchedPDMat) = inner_size(p.chol) - -Base.size(c::BatchedCholesky) = (batch_size(c),) -Base.size(p::BatchedPDMat) = (batch_size(p),) - -Base.getindex(p::BatchedPDMat{T}, i::Int) where {T} = p.chol.factors[i] * p.chol.factors[i]' - # ============================================================================= # Pointer Array Creation # ============================================================================= diff --git a/GeneralisedFilters/src/batching/wrappers.jl b/GeneralisedFilters/src/batching/wrappers.jl index 6f7e741b..969a6bb8 100644 --- a/GeneralisedFilters/src/batching/wrappers.jl +++ b/GeneralisedFilters/src/batching/wrappers.jl @@ -1,5 +1,7 @@ using Magma using Magma.LibMagma +using LinearAlgebra: cholesky, Cholesky, LowerTriangular +using StructArrays: StructArray # ============================================================================= # Trivial Wrappers (reductions and elementwise operations) @@ -389,27 +391,25 @@ function gemv_batched_smallsq!( return y end -function potrf_batched!(uplo::Char, A::BatchedCuMatrix{Float32}) +function potrf_batched!(uplo::Char, A::BatchedCuMatrix{Float32}, info::CuVector{Int64}) N = batch_size(A) n = size(A.data, 1) lda = n dA = create_pointer_array(A) - info_gpu = CUDA.zeros(Int64, N) CUDA.synchronize() queue_ptr = Ref{LibMagma.magma_queue_t}() LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) LibMagma.magma_spotrf_batched( - magma_uplo(uplo), n, dA, lda, pointer(info_gpu), N, queue_ptr[] + magma_uplo(uplo), n, dA, lda, pointer(info), N, queue_ptr[] ) LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) - factors = BatchedCuMatrix{Float32,LowerTriangular{Float32,CuMatrix{Float32}}}(A.data) - return BatchedCholesky{Float32}(factors, info_gpu, uplo) + return A end function potrs_batched!( @@ -489,26 +489,37 @@ end # Higher-level Cholesky Operations # ============================================================================= 
-function cholesky_batched(A::BatchedCuMatrix{T}) where {T} +function broadcasted(::typeof(cholesky), A::BatchedCuMatrix{T,M}) where {T,M} + N = batch_size(A) A_copy = BatchedCuMatrix(copy(A.data)) - return potrf_batched!('L', A_copy) -end + info = CUDA.zeros(Int64, N) -function pdmat_solve(S::BatchedPDMat{T}, B::BatchedCuMatrix{T}) where {T} - L = S.chol.factors - L_data = BatchedCuMatrix(L.data) + potrf_batched!('L', A_copy, info) - B_copy = BatchedCuMatrix(copy(B.data)) + factors_wrapped = broadcasted(LowerTriangular, A_copy) - # Solve L*L'*X = B via two triangular solves: - # 1. Solve L*Y = B (Y stored in B_copy) - trsm_batched!('L', 'L', 'N', 'N', one(T), L_data, B_copy) - # 2. Solve L'*X = Y (X stored in B_copy) - trsm_batched!('L', 'L', 'T', 'N', one(T), L_data, B_copy) + # TODO: Use a lazy constant vector for uplo instead of dense fill + uplo = fill('L', N) - return B_copy + ElType = Cholesky{T,eltype(A)} + return StructArray{ElType}((; factors=factors_wrapped, uplo=uplo, info=info)) end +# function pdmat_solve(S::BatchedPDMat{T}, B::BatchedCuMatrix{T}) where {T} +# L = S.chol.factors +# L_data = BatchedCuMatrix(L.data) + +# B_copy = BatchedCuMatrix(copy(B.data)) + +# # Solve L*L'*X = B via two triangular solves: +# # 1. Solve L*Y = B (Y stored in B_copy) +# trsm_batched!('L', 'L', 'N', 'N', one(T), L_data, B_copy) +# # 2. Solve L'*X = Y (X stored in B_copy) +# trsm_batched!('L', 'L', 'T', 'N', one(T), L_data, B_copy) + +# return B_copy +# end + # ============================================================================= # Batched TRMM (Triangular Matrix Multiply) # ============================================================================= diff --git a/research/batching/wrappers_demo.jl b/research/batching/wrappers_demo.jl new file mode 100644 index 00000000..a135cef1 --- /dev/null +++ b/research/batching/wrappers_demo.jl @@ -0,0 +1,146 @@ +using GeneralisedFilters +using CUDA +using LinearAlgebra +using StructArrays +using Base.Broadcast: broadcasted +using Distributions +using PDMats + +# ============================================================================= +# Configuration +# ============================================================================= + +N = 10 # batch size +D = 4 # matrix dimension + +println("Creating batched matrices...") +A = BatchedCuMatrix(CUDA.randn(Float32, D, D, N)); + +# ============================================================================= +# Test 1: Basic wrapper creation +# ============================================================================= + +println("\n=== Test 1: Basic wrapper creation ===\n") + +# Adjoint +A_adj = broadcasted(Adjoint, A); +println("Adjoint type: ", typeof(A_adj)) +println("Adjoint eltype: ", eltype(A_adj)) +println("First element type: ", typeof(A_adj[1])) + +# Transpose +A_trans = broadcasted(Transpose, A); +println("\nTranspose type: ", typeof(A_trans)) +println("Transpose eltype: ", eltype(A_trans)) + +# LowerTriangular +A_lower = broadcasted(LowerTriangular, A); +println("\nLowerTriangular type: ", typeof(A_lower)) +println("LowerTriangular eltype: ", eltype(A_lower)) + +# UpperTriangular +A_upper = broadcasted(UpperTriangular, A); +println("\nUpperTriangular type: ", typeof(A_upper)) +println("UpperTriangular eltype: ", eltype(A_upper)) + +# ============================================================================= +# Test 2: Function form redirects +# ============================================================================= + +println("\n=== Test 2: Function form redirects ===\n") + 
+A_adj2 = broadcasted(adjoint, A); +println("adjoint redirect type: ", typeof(A_adj2)) +println("Types match: ", typeof(A_adj) == typeof(A_adj2)) + +A_trans2 = broadcasted(transpose, A); +println("\ntranspose redirect type: ", typeof(A_trans2)) +println("Types match: ", typeof(A_trans) == typeof(A_trans2)) + +# ============================================================================= +# Test 3: Nested wrappers +# ============================================================================= + +println("\n=== Test 3: Nested wrappers ===\n") + +# Adjoint of LowerTriangular +A_lower_adj = broadcasted(Adjoint, A_lower); +println("Adjoint(LowerTriangular) type: ", typeof(A_lower_adj)) +println("Adjoint(LowerTriangular) eltype: ", eltype(A_lower_adj)) + +# ============================================================================= +# Test 4: Verify element access +# ============================================================================= + +println("\n=== Test 4: Element access verification ===\n") + +# Get first element of each wrapped array +println("A[1] type: ", typeof(A[1])) +println("A_adj[1] type: ", typeof(A_adj[1])) +println("A_lower[1] type: ", typeof(A_lower[1])) + +# Check values match +A_cpu = Array(A[1]); +A_adj_cpu = Array(parent(A_adj[1])); +println("\nValues match (A vs parent of A_adj): ", A_cpu ≈ A_adj_cpu) + +# ============================================================================= +# Test 5: SharedCuMatrix wrappers +# ============================================================================= + +println("\n=== Test 5: SharedCuMatrix wrappers ===\n") + +S = SharedCuMatrix(CUDA.randn(Float32, D, D)); +S_adj = broadcasted(Adjoint, S); +println("SharedCuMatrix adjoint type: ", typeof(S_adj)) +println("SharedCuMatrix adjoint eltype: ", eltype(S_adj)) + +# ============================================================================= +# Test 6: Cholesky +# ============================================================================= + +println("\n=== Test 6: Batched Cholesky ===\n") + +# Create positive definite matrices: B = A * A' +B = A .* broadcasted(adjoint, A); +println("B type: ", typeof(B)) + +chol_result = cholesky.(B); +println("Cholesky result type: ", typeof(chol_result)) +println("Cholesky eltype: ", eltype(chol_result)) + +# Check fields +println("\nCholesky fields:") +println(" factors type: ", typeof(chol_result.factors)) +println(" uplo type: ", typeof(chol_result.uplo)) +println(" info type: ", typeof(chol_result.info)) + +# Access individual element +# SKIP: requires scalar indexing into info +# println("\nFirst element:") +# println(" chol_result[1] type: ", typeof(chol_result[1])) + +# ============================================================================= +# Test 7: PDMat wrapper +# ============================================================================= + +println("\n=== Test 7: PDMat wrapper ===\n") + +P = broadcasted(PDMat, B, chol_result); +println("PDMat result type: ", typeof(P)) +println("PDMat eltype: ", eltype(P)) + +println("\nPDMat fields:") +# println(" dim type: ", typeof(P.dim)) # SKIP: not a real field +println(" mat type: ", typeof(P.mat)) +println(" chol type: ", typeof(P.chol)) + +# ============================================================================= +# Test 8: MvNormal wrapper +# ============================================================================= + +println("\n=== Test 8: MvNormal wrapper ===\n") +μ = BatchedCuVector(CUDA.randn(Float32, D, N)); +G = broadcasted(MvNormal, μ, P); 
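+# A sketch of what to expect here: G is a StructArray{<:MvNormal} whose
+# fields alias the batched inputs (G.μ === μ and G.Σ === P), so e.g.
+# G.μ[1] is a D-element view into the first column of the batched data.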
+println("MvNormal type: ", typeof(G))
+println("MvNormal eltype: ", eltype(G))

From 84825e76512d98af9fef1090be8f36e6568c5c6e Mon Sep 17 00:00:00 2001
From: Tim Hargreaves 
Date: Wed, 7 Jan 2026 12:31:23 +0000
Subject: [PATCH 16/29] Temporarily disable caching due to world age bugs

---
 .../src/batching/broadcasting.jl          | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/GeneralisedFilters/src/batching/broadcasting.jl b/GeneralisedFilters/src/batching/broadcasting.jl
index 8ffda2a0..636217df 100644
--- a/GeneralisedFilters/src/batching/broadcasting.jl
+++ b/GeneralisedFilters/src/batching/broadcasting.jl
@@ -198,12 +198,17 @@ function Broadcast.materialize(bc::Broadcasted{BatchedStyle})
     argtypes = Tuple{map(typeof, args)...}
     key = (f, argtypes)
 
-    if !haskey(BATCHED_FUNC_CACHE, key)
-        println("  [Generating batched version of $f]")
-        batched_f = generate_batched_function(f, argtypes)
-        BATCHED_FUNC_CACHE[key] = batched_f
-    end
-
-    batched_f = BATCHED_FUNC_CACHE[key]
+    # HACK: caching caused issues when functions were modified
+    # if !haskey(BATCHED_FUNC_CACHE, key)
+    #     println("  [Generating batched version of $f]")
+    #     batched_f = generate_batched_function(f, argtypes)
+    #     BATCHED_FUNC_CACHE[key] = batched_f
+    # end
+
+    # batched_f = BATCHED_FUNC_CACHE[key]
+    # return Base.invokelatest(batched_f, nothing, args...)
+
+    println("  [Generating batched version of $f]")
+    batched_f = generate_batched_function(f, argtypes)
     return Base.invokelatest(batched_f, nothing, args...)
 end

From 708b86c95114f34baf01b0dd9df256fdbf854200 Mon Sep 17 00:00:00 2001
From: Tim Hargreaves 
Date: Wed, 7 Jan 2026 12:31:53 +0000
Subject: [PATCH 17/29] Update PD/Cholesky ops to new wrapping interface

---
 GeneralisedFilters/src/batching/operations.jl | 123 +++++++++++-------
 GeneralisedFilters/src/batching/types.jl      |   3 -
 2 files changed, 73 insertions(+), 53 deletions(-)

diff --git a/GeneralisedFilters/src/batching/operations.jl b/GeneralisedFilters/src/batching/operations.jl
index b1fc55d1..7568f1f0 100644
--- a/GeneralisedFilters/src/batching/operations.jl
+++ b/GeneralisedFilters/src/batching/operations.jl
@@ -161,24 +161,47 @@ end
 # PDMat Broadcasting
 # =============================================================================
 
-# function broadcasted(::Type{PDMat}, A::BatchedCuMatrix{T,CuMatrix{T}}) where {T}
-#     chol = cholesky_batched(A)
-#     return BatchedPDMat{T}(chol)
-# end
-
-# function broadcasted(::typeof(\), S::BatchedPDMat{T}, A::BatchedCuMatrix{T}) where {T}
-#     return pdmat_solve(S, A)
-# end
-
-# function broadcasted(::typeof(/), A::BatchedCuMatrix{T}, S::BatchedPDMat{T}) where {T}
-#     # Need to actually transpose the data, not just wrap it
-#     At_data = permutedims(A.data, (2, 1, 3))
-#     At = BatchedCuMatrix(At_data)
-#     result_t = pdmat_solve(S, At)
-#     # Transpose back
-#     result_data = permutedims(result_t.data, (2, 1, 3))
-#     return BatchedCuMatrix(result_data)
-# end
+# HACK: PDMat is a constructor, so it will use
+# `broadcasted(::Type{W}, args::Union{BatchedCuMatrix, BatchedCuVector, SharedCuMatrix, SharedCuVector, StructArray}...) where W`
+# rather than the desired recursive broadcast to
+# `PDMat(mat::AbstractMatrix) = PDMat(mat, cholesky(mat))`
+# This method hardcodes a manual override for BatchedCuMatrix inputs. This should be replaced
+# by a more general solution in the future.
+function broadcasted(::Type{PDMat}, A::BatchedCuMatrix{T}) where {T}
+    chol = cholesky.(A)
+    return PDMat.(A, chol)
+end
+
+# HACK: Addition with PDMat extracts .mat field.
Should be replaced by automatic +# materialization of PDMat to BatchedCuMatrix in the future. +function broadcasted( + ::typeof(+), A::BatchedCuMatrix{T}, P::StructArray{<:PDMat{T}} +) where {T} + return broadcasted(+, A, P.mat) +end + +function broadcasted( + ::typeof(+), P::StructArray{<:PDMat{T}}, A::BatchedCuMatrix{T} +) where {T} + return broadcasted(+, P.mat, A) +end + +# A / S where S is PDMat: computes A * inv(S) +# potrs solves S * X = B, so we solve S * X = A' and transpose back +function broadcasted( + ::typeof(/), A::BatchedCuMatrix{T}, S::StructArray{<:PDMat{T}} +) where {T} + L = S.chol.factors.data + + # Transpose A: potrs solves S*X = B, we want A*inv(S) = (inv(S)*A')' + At = BatchedCuMatrix(permutedims(A.data, (2, 1, 3))) + + # Solve S * X = A' in-place (result stored in At) + potrs_batched!('L', L, At) + + # Transpose back + return BatchedCuMatrix(permutedims(At.data, (2, 1, 3))) +end # ============================================================================= # Quadratic Form Broadcasting @@ -195,36 +218,36 @@ function broadcasted( return broadcasted(*, temp, Xt) end -# X_A_Xt for BatchedPDMat: X * P * X' where P = L * L' +# X_A_Xt for StructArray{PDMat}: X * P * X' where P = L * L' # Computed as (X * L) * (X * L)' using TRMM and SYRK -# function broadcasted( -# ::typeof(X_A_Xt), P::BatchedPDMat{T}, X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}} -# ) where {T} -# L = P.chol.factors -# N = get_batch_size(P, X) - -# X_inner = inner_size_for_blas(X) -# m = X_inner[1] - -# # Copy X to XL (TRMM overwrites in-place) -# XL_data = if X isa SharedCuMatrix -# repeat(reshape(X.data, size(X.data, 1), size(X.data, 2), 1), 1, 1, N) -# else -# copy(X.data) -# end -# XL = BatchedCuMatrix(XL_data) - -# # XL = X * L using TRMM (side='R' for right multiply, uplo='L' for lower triangular) -# L_data = BatchedCuMatrix(L.data) -# trmm_batched!('R', 'L', 'N', 'N', one(T), L_data, XL) - -# # Result = XL * XL' using SYRK (fills lower triangle) -# Result_data = CuArray{T}(undef, m, m, N) -# Result = BatchedCuMatrix(Result_data) -# syrk_batched!('L', 'N', one(T), XL, zero(T), Result) - -# # Symmetrize: copy lower triangle to upper -# symmetrize_lower!(Result) - -# return Result -# end +# HACK: this function should dispatch to specialised `*` for triangular types but this is +# not yet implemented +function broadcasted( + ::typeof(X_A_Xt), + P::StructArray{<:PDMat{T}}, + X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + # P.chol.factors is StructArray{LowerTriangular}, .data is the BatchedCuMatrix + L = P.chol.factors.data + N = get_batch_size(L, X) + out_dim = inner_size_for_blas(X)[1] + + # Copy X for in-place TRMM + XL = if X isa SharedCuMatrix + BatchedCuMatrix(repeat(reshape(X.data, size(X.data)..., 1), 1, 1, N)) + else + BatchedCuMatrix(copy(X.data)) + end + + # XL = X * L using TRMM (side='R', uplo='L', no transpose, non-unit diagonal) + trmm_batched!('R', 'L', 'N', 'N', one(T), L, XL) + + # result = XL * XL' using SYRK (fills lower triangle only) + result = BatchedCuMatrix(CuArray{T}(undef, out_dim, out_dim, N)) + syrk_batched!('L', 'N', one(T), XL, zero(T), result) + + # Copy lower triangle to upper for full symmetric matrix + symmetrize_lower!(result) + + return result +end diff --git a/GeneralisedFilters/src/batching/types.jl b/GeneralisedFilters/src/batching/types.jl index 8efc0bd1..88843b24 100644 --- a/GeneralisedFilters/src/batching/types.jl +++ b/GeneralisedFilters/src/batching/types.jl @@ -48,9 +48,6 @@ end const SharedArray = Union{SharedCuVector,SharedCuMatrix} -# 
Convenience constructors that infer memory type -SharedCuMatrix(data::CuArray{T,2,M}) where {T,M} = SharedCuMatrix{T,M}(data) -SharedCuVector(data::CuArray{T,1,M}) where {T,M} = SharedCuVector{T,M}(data) Shared(x::CuArray{T,2,M}) where {T,M} = SharedCuMatrix{T,M}(x) Shared(x::CuArray{T,1,M}) where {T,M} = SharedCuVector{T,M}(x) From 5977d4ceb83e709cd4c0efea26ecd6cd417ee543 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 7 Jan 2026 12:32:11 +0000 Subject: [PATCH 18/29] Update batching demo script to new wrapper interface --- research/batching/batching_demo.jl | 40 ++++++++++++++++-------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/research/batching/batching_demo.jl b/research/batching/batching_demo.jl index d2841061..99ea2f77 100644 --- a/research/batching/batching_demo.jl +++ b/research/batching/batching_demo.jl @@ -17,9 +17,9 @@ Magma.magma_init() # Configuration # ============================================================================= -D_state = 64 -D_obs = 64 -N = 1000 +D_state = 2 +D_obs = 2 +N = 3 function kalman_predict(state, dyn_params) A = dyn_params[1] @@ -45,7 +45,7 @@ Q = Q_root * Q_root' + I Qs = SharedCuMatrix(Q) Σ_PDs = broadcasted(PDMat, Σs); -Gs = StructArray{MvNormal}((μ=μs, Σ=Σ_PDs)); +Gs = MvNormal.(μs, Σ_PDs); function kalman_predict(state, dyn_params) A = dyn_params[1] @@ -53,7 +53,8 @@ function kalman_predict(state, dyn_params) Q = dyn_params[3] μ̂ = A * state.μ + b - Σ̂ = X_A_Xt(state.Σ, A) + Q + Σ̂ = PDMat(X_A_Xt(state.Σ, A) + Q) + return MvNormal(μ̂, Σ̂) end @@ -69,11 +70,11 @@ Q_test = Array(Qs.data) pred_G_test = kalman_predict(MvNormal(μ_test, PDMat(Σ_test)), (A_test, b_test, Q_test)) println("=== Predict Comparison ===") -println("CPU Mean: ", pred_G_test.μ[1:5]) -println("GPU Mean: ", Array(pred_Gs[end].μ[1:5])) +println("CPU Mean: ", pred_G_test.μ) +println("GPU Mean: ", Array(pred_Gs.μ[end])) -println("CPU Covariance [1:3, 1:3]: ", Matrix(pred_G_test.Σ)[1:3, 1:3]) -println("GPU Covariance [1:3, 1:3]: ", Array(pred_Gs[end].Σ)[1:3, 1:3]) +println("CPU Covariance: ", Matrix(pred_G_test.Σ)) +println("GPU Covariance: ", Array(pred_Gs.Σ.mat[end])) # ============================================================================= # Kalman Update @@ -96,7 +97,7 @@ function kalman_update(state, obs_params, observation) # Update parameters using Joseph form for numerical stability μ̂ = μ + K * ȳ - Σ̂ = X_A_Xt(Σ, I - K * H) + X_A_Xt(R, K) + Σ̂ = PDMat(X_A_Xt(Σ, I - K * H) + X_A_Xt(R, K)) return MvNormal(μ̂, Σ̂) end @@ -114,6 +115,7 @@ I_obs = CuArray{Float32}(I, D_obs, D_obs) I_obs_shared = SharedCuMatrix(I_obs) Rs_root = BatchedCuMatrix(CUDA.randn(Float32, D_obs, D_obs, N)) Rs = Rs_root .* adjoint.(Rs_root) .+ I_obs_shared +Rs = PDMat.(Rs); obs_params = (Hs, cs, Rs) @@ -126,17 +128,17 @@ update_Gs = kalman_update.(pred_Gs, Ref(obs_params), observations); # Compare update to CPU H_test = Array(Hs.data) c_test = Array(cs.data) -R_test = PDMat(Array(Rs[end])) +R_test = PDMat(Array(Rs.mat[end])) obs_test = Array(observations[end]) update_G_test = kalman_update(pred_G_test, (H_test, c_test, R_test), obs_test) println("\n=== Update Comparison ===") -println("CPU Mean: ", update_G_test.μ[1:5]) -println("GPU Mean: ", Array(update_Gs.μ[end][1:5])) +println("CPU Mean: ", update_G_test.μ) +println("GPU Mean: ", Array(update_Gs.μ[end])) -println("CPU Covariance [1:3, 1:3]: ", Matrix(update_G_test.Σ)[1:3, 1:3]) -println("GPU Covariance [1:3, 1:3]: ", Array(update_Gs.Σ[end])[1:3, 1:3]) +println("CPU Covariance: ", Matrix(update_G_test.Σ)) 
+println("GPU Covariance: ", Array(update_Gs.Σ.mat[end])) # ============================================================================= # Full Kalman Step @@ -154,8 +156,8 @@ step_G_test = kalman_step( ) println("\n=== Full Step Comparison ===") -println("CPU Mean: ", step_G_test.μ[1:5]) -println("GPU Mean: ", Array(step_Gs.μ[end][1:5])) +println("CPU Mean: ", step_G_test.μ) +println("GPU Mean: ", Array(step_Gs.μ[end])) -println("CPU Covariance [1:3, 1:3]: ", Matrix(step_G_test.Σ)[1:3, 1:3]) -println("GPU Covariance [1:3, 1:3]: ", Array(step_Gs.Σ[end])[1:3, 1:3]) +println("CPU Covariance: ", Matrix(step_G_test.Σ)) +println("GPU Covariance: ", Array(step_Gs.Σ.mat[end])) From 8d2a1195f6c2b7f8e94d01f5e547017dd453ba5c Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 7 Jan 2026 12:34:29 +0000 Subject: [PATCH 19/29] Fix formatting --- GeneralisedFilters/src/batching/types.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/GeneralisedFilters/src/batching/types.jl b/GeneralisedFilters/src/batching/types.jl index 88843b24..ece85e77 100644 --- a/GeneralisedFilters/src/batching/types.jl +++ b/GeneralisedFilters/src/batching/types.jl @@ -48,7 +48,6 @@ end const SharedArray = Union{SharedCuVector,SharedCuMatrix} - Shared(x::CuArray{T,2,M}) where {T,M} = SharedCuMatrix{T,M}(x) Shared(x::CuArray{T,1,M}) where {T,M} = SharedCuVector{T,M}(x) From ccf28e121a1487de48670e22addea83cbef5959b Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 7 Jan 2026 13:23:57 +0000 Subject: [PATCH 20/29] Reintroduce broadcast caching with proper world age handling --- .../src/batching/broadcasting.jl | 52 +++++++++++++++---- research/batching/batching_demo.jl | 2 + 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/GeneralisedFilters/src/batching/broadcasting.jl b/GeneralisedFilters/src/batching/broadcasting.jl index 636217df..8b9797ba 100644 --- a/GeneralisedFilters/src/batching/broadcasting.jl +++ b/GeneralisedFilters/src/batching/broadcasting.jl @@ -8,6 +8,8 @@ import Base.Broadcast: broadcasted import PDMats: PDMat +export BATCHED_CACHE_VERBOSITY, clear_batched_cache! + # ============================================================================= # Broadcast Style # ============================================================================= @@ -184,7 +186,24 @@ end # Broadcast Materialization # ============================================================================= -const BATCHED_FUNC_CACHE = Dict{Tuple,Any}() +# Verbosity levels: :silent, :verbose, :debug +# :silent - no output +# :verbose - print when generating or regenerating (i.e. cache misses) +# :debug - print all cache activity including hits +const BATCHED_CACHE_VERBOSITY = Ref{Symbol}(:silent) + +# Cache stores (batched_function, world_age_when_cached) +const BATCHED_FUNC_CACHE = Dict{Tuple{Any,Type},Tuple{Any,UInt}}() + +""" + clear_batched_cache!() + +Clear the batched function cache. Useful for debugging or forcing regeneration. 
+"""
+function clear_batched_cache!()
+    empty!(BATCHED_FUNC_CACHE)
+    return nothing
+end
 
 function Broadcast.materialize(bc::Broadcasted{BatchedStyle})
     f = bc.f
@@ -198,12 +217,30 @@ function Broadcast.materialize(bc::Broadcasted{BatchedStyle})
     argtypes = Tuple{map(typeof, args)...}
     key = (f, argtypes)
 
-    # HACK: caching caused issues when functions were modified
-    # if !haskey(BATCHED_FUNC_CACHE, key)
-    #     println("  [Generating batched version of $f]")
-    #     batched_f = generate_batched_function(f, argtypes)
-    #     BATCHED_FUNC_CACHE[key] = batched_f
-    # end
+    # Get element types for method lookup
+    element_types = Tuple{map(ir_element_type, argtypes.parameters)...}
 
-    # batched_f = BATCHED_FUNC_CACHE[key]
-    # return Base.invokelatest(batched_f, nothing, args...)
+    if haskey(BATCHED_FUNC_CACHE, key)
+        batched_f, cached_world = BATCHED_FUNC_CACHE[key]
+        # Check if the method has been redefined since caching
+        m = which(f, element_types)
+        if m.primary_world <= cached_world
+            if BATCHED_CACHE_VERBOSITY[] == :debug
+                println("  [Using cached batched version of $f]")
+            end
+            return Base.invokelatest(batched_f, nothing, args...)
+        end
+        if BATCHED_CACHE_VERBOSITY[] in (:verbose, :debug)
+            println("  [Regenerating batched version of $f (method redefined)]")
+        end
+    else
+        if BATCHED_CACHE_VERBOSITY[] in (:verbose, :debug)
+            println("  [Generating batched version of $f]")
+        end
+    end
 
-    println("  [Generating batched version of $f]")
     batched_f = generate_batched_function(f, argtypes)
+    current_world = Base.get_world_counter()
+    BATCHED_FUNC_CACHE[key] = (batched_f, current_world)
     return Base.invokelatest(batched_f, nothing, args...)
 end
diff --git a/research/batching/batching_demo.jl b/research/batching/batching_demo.jl
index 99ea2f77..303db60d 100644
--- a/research/batching/batching_demo.jl
+++ b/research/batching/batching_demo.jl
@@ -21,6 +21,8 @@ D_state = 2
 D_obs = 2
 N = 3
 
+BATCHED_CACHE_VERBOSITY[] = :debug
+
 function kalman_predict(state, dyn_params)
     A = dyn_params[1]
     b = dyn_params[2]

From 67826f0c00af917b09cd411538cb11abcc99746c Mon Sep 17 00:00:00 2001
From: Tim Hargreaves 
Date: Wed, 7 Jan 2026 15:26:12 +0000
Subject: [PATCH 21/29] Implement reusable queue (3x speedup)

---
 GeneralisedFilters/src/batching/wrappers.jl | 151 +++++++++++---------
 1 file changed, 82 insertions(+), 69 deletions(-)

diff --git a/GeneralisedFilters/src/batching/wrappers.jl b/GeneralisedFilters/src/batching/wrappers.jl
index 969a6bb8..56de9a57 100644
--- a/GeneralisedFilters/src/batching/wrappers.jl
+++ b/GeneralisedFilters/src/batching/wrappers.jl
@@ -3,6 +3,49 @@ using Magma.LibMagma
 using LinearAlgebra: cholesky, Cholesky, LowerTriangular
 using StructArrays: StructArray
 
+export get_magma_queue, reset_magma_queue!
+
+# =============================================================================
+# MAGMA Queue Management
+# =============================================================================
+
+# Store the queue pointer directly (not in a Ref) to avoid lifetime issues
+mutable struct MagmaQueueCache
+    ptr::LibMagma.magma_queue_t
+    initialized::Bool
+end
+
+const MAGMA_QUEUE_CACHE = MagmaQueueCache(LibMagma.magma_queue_t(), false)
+
+"""
+    get_magma_queue()
+
+Get the cached MAGMA queue, creating it on first use.
+Returns the raw queue pointer for use with MAGMA functions.
+""" +function get_magma_queue() + if !MAGMA_QUEUE_CACHE.initialized + queue_ref = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ref, C_NULL, C_NULL, 0) + MAGMA_QUEUE_CACHE.ptr = queue_ref[] + MAGMA_QUEUE_CACHE.initialized = true + end + return MAGMA_QUEUE_CACHE.ptr +end + +""" + reset_magma_queue!() + +Destroy and recreate the MAGMA queue. Useful if you suspect queue corruption. +""" +function reset_magma_queue!() + if MAGMA_QUEUE_CACHE.initialized + LibMagma.magma_queue_destroy_internal(MAGMA_QUEUE_CACHE.ptr, C_NULL, C_NULL, 0) + MAGMA_QUEUE_CACHE.initialized = false + end + return get_magma_queue() +end + # ============================================================================= # Trivial Wrappers (reductions and elementwise operations) # ============================================================================= @@ -127,8 +170,7 @@ function gemm_batched!( lddc = m CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magma_sgemm_batched( magma_trans(transA), magma_trans(transB), @@ -144,10 +186,9 @@ function gemm_batched!( dC, lddc, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dB) @@ -182,8 +223,7 @@ function gemm_batched_smallsq!( lddc = m CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_sgemm_batched_smallsq( magma_trans(transA), magma_trans(transB), @@ -205,10 +245,9 @@ function gemm_batched_smallsq!( 0, # cj lddc, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dB) @@ -239,8 +278,7 @@ function gemm_batched_smallsq!( lddc = m CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_dgemm_batched_smallsq( magma_trans(transA), magma_trans(transB), @@ -262,10 +300,9 @@ function gemm_batched_smallsq!( 0, # cj lddc, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dB) @@ -298,13 +335,11 @@ function gemv_batched!( incy = 1 CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_sgemv_batched( - magma_trans(transA), m, n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue_ptr[] + magma_trans(transA), m, n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dx) @@ -337,13 +372,11 @@ function gemv_batched!( incy = 1 CUDA.synchronize() - queue_ptr = 
Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_dgemv_batched( - magma_trans(transA), m, n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue_ptr[] + magma_trans(transA), m, n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dx) @@ -376,13 +409,11 @@ function gemv_batched_smallsq!( incy = 1 CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_sgemv_batched_smallsq( - magma_trans(transA), n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue_ptr[] + magma_trans(transA), n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dx) @@ -399,13 +430,9 @@ function potrf_batched!(uplo::Char, A::BatchedCuMatrix{Float32}, info::CuVector{ dA = create_pointer_array(A) CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) - LibMagma.magma_spotrf_batched( - magma_uplo(uplo), n, dA, lda, pointer(info), N, queue_ptr[] - ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + queue = get_magma_queue() + LibMagma.magma_spotrf_batched(magma_uplo(uplo), n, dA, lda, pointer(info), N, queue) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) @@ -426,13 +453,9 @@ function potrs_batched!( lddb = n CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) - LibMagma.magma_spotrs_batched( - magma_uplo(uplo), n, nrhs, dA, ldda, dB, lddb, N, queue_ptr[] - ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + queue = get_magma_queue() + LibMagma.magma_spotrs_batched(magma_uplo(uplo), n, nrhs, dA, ldda, dB, lddb, N, queue) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dB) @@ -459,8 +482,7 @@ function trsm_batched!( lddb = m CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_strsm_batched( magma_side(side), magma_uplo(uplo), @@ -474,10 +496,9 @@ function trsm_batched!( dB, lddb, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dB) @@ -543,8 +564,7 @@ function trmm_batched!( lddb = m CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_strmm_batched( magma_side(side), magma_uplo(uplo), @@ -558,10 +578,9 @@ 
function trmm_batched!( dB, lddb, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dB) @@ -588,8 +607,7 @@ function trmm_batched!( lddb = m CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_dtrmm_batched( magma_side(side), magma_uplo(uplo), @@ -603,10 +621,9 @@ function trmm_batched!( dB, lddb, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dB) @@ -637,8 +654,7 @@ function syrk_batched!( lddc = n CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_ssyrk_batched( magma_uplo(uplo), magma_trans(trans), @@ -651,10 +667,9 @@ function syrk_batched!( dC, lddc, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dC) @@ -681,8 +696,7 @@ function syrk_batched!( lddc = n CUDA.synchronize() - queue_ptr = Ref{LibMagma.magma_queue_t}() - LibMagma.magma_queue_create_internal(0, queue_ptr, C_NULL, C_NULL, 0) + queue = get_magma_queue() LibMagma.magmablas_dsyrk_batched( magma_uplo(uplo), magma_trans(trans), @@ -695,10 +709,9 @@ function syrk_batched!( dC, lddc, N, - queue_ptr[], + queue, ) - LibMagma.magma_queue_sync_internal(queue_ptr[], C_NULL, C_NULL, 0) - LibMagma.magma_queue_destroy_internal(queue_ptr[], C_NULL, C_NULL, 0) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) CUDA.unsafe_free!(dA) CUDA.unsafe_free!(dC) From a46c607823c50e3d225c62ded7437aec262e5f95 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 7 Jan 2026 16:03:06 +0000 Subject: [PATCH 22/29] Add benchmarking to demo script --- research/batching/batching_demo.jl | 118 ++++++++++++++++++++++++++++- 1 file changed, 115 insertions(+), 3 deletions(-) diff --git a/research/batching/batching_demo.jl b/research/batching/batching_demo.jl index 303db60d..8830fc1d 100644 --- a/research/batching/batching_demo.jl +++ b/research/batching/batching_demo.jl @@ -71,7 +71,7 @@ b_test = Array(bs[end]) Q_test = Array(Qs.data) pred_G_test = kalman_predict(MvNormal(μ_test, PDMat(Σ_test)), (A_test, b_test, Q_test)) -println("=== Predict Comparison ===") +println("=== Predict Comparison ===\n") println("CPU Mean: ", pred_G_test.μ) println("GPU Mean: ", Array(pred_Gs.μ[end])) @@ -135,7 +135,7 @@ obs_test = Array(observations[end]) update_G_test = kalman_update(pred_G_test, (H_test, c_test, R_test), obs_test) -println("\n=== Update Comparison ===") +println("\n=== Update Comparison ===\n") println("CPU Mean: ", update_G_test.μ) println("GPU Mean: ", Array(update_Gs.μ[end])) @@ -157,9 +157,121 @@ step_G_test = kalman_step( obs_test, ) -println("\n=== Full Step Comparison ===") +println("\n=== Full Step Comparison ===\n") println("CPU Mean: ", step_G_test.μ) println("GPU Mean: ", Array(step_Gs.μ[end])) println("CPU Covariance: ", 
Matrix(step_G_test.Σ)) println("GPU Covariance: ", Array(step_Gs.Σ.mat[end])) + +# ============================================================================= +# Benchmarking +# ============================================================================= + +using BenchmarkTools +using StaticArrays + +D_bench = 10 +N_bench = 60000 + +println("\n=== Benchmarking batched Kalman step ===\n") + +println("D = $D_bench, N = $N_bench") +println( + "Size of batched covariance matrices: ", + round(Int, D_bench * D_bench * N_bench * sizeof(Float32) / 1024^2), + " MB", +) + +Is_bench = SharedCuMatrix(CuArray{Float32}(I, D_bench, D_bench)) +μs_bench = BatchedCuVector(CUDA.randn(Float32, D_bench, N_bench)) +Σs_root_bench = BatchedCuMatrix(CUDA.randn(Float32, D_bench, D_bench, N_bench)) +Σs_bench = Σs_root_bench .* adjoint.(Σs_root_bench) .+ Is_bench +Σ_PDs_bench = broadcasted(PDMat, Σs_bench); +Gs_bench = MvNormal.(μs_bench, Σ_PDs_bench); + +As_bench = SharedCuMatrix(CUDA.randn(Float32, D_bench, D_bench)) +bs_bench = BatchedCuVector(CUDA.randn(Float32, D_bench, N_bench)) +Qs_root_bench = CUDA.randn(Float32, D_bench, D_bench) +Qs_bench_mat = Qs_root_bench * adjoint(Qs_root_bench) +Qs_bench_mat += I +Qs_bench = SharedCuMatrix(Qs_bench_mat) +dyn_params_bench = (As_bench, bs_bench, Qs_bench) + +Hs_bench = SharedCuMatrix(CUDA.randn(Float32, D_bench, D_bench)) +cs_bench = SharedCuVector(CUDA.randn(Float32, D_bench)) +Rs_root_bench = BatchedCuMatrix(CUDA.randn(Float32, D_bench, D_bench, N_bench)) +Rs_bench = Rs_root_bench .* adjoint.(Rs_root_bench) .+ Is_bench +Rs_bench = PDMat.(Rs_bench) +obs_params_bench = (Hs_bench, cs_bench, Rs_bench) + +ys_bench = BatchedCuVector(CUDA.randn(Float32, D_bench, N_bench)) + +println("\nBenchmarking batched Kalman step...\n") + +BATCHED_CACHE_VERBOSITY[] = :silent +display( + @benchmark kalman_step.( + $Gs_bench, Ref($dyn_params_bench), Ref($obs_params_bench), $ys_bench + ) +) + +# Compare to static arrays +μs_static = [@SVector randn(Float32, D_bench) for _ in 1:N_bench] +Σs_static = [ + begin + A = @SMatrix randn(Float32, D_bench, D_bench) + A * A' + I + end for _ in 1:N_bench +] +Gs_static = [MvNormal(μs_static[n], PDMat(Σs_static[n])) for n in 1:N_bench] + +As_static = @SMatrix randn(Float32, D_bench, D_bench) +bs_static = [@SVector randn(Float32, D_bench) for _ in 1:N_bench] +Qs_root_static = @SMatrix randn(Float32, D_bench, D_bench) +Qs_static_mat = Qs_root_static * adjoint(Qs_root_static) + I +Qs_static = Qs_static_mat +dyn_params_static = (As_static, bs_static, Qs_static) + +Hs_static = @SMatrix randn(Float32, D_bench, D_bench) +cs_static = @SVector randn(Float32, D_bench) +Rs_root_static = [@SMatrix randn(Float32, D_bench, D_bench) for _ in 1:N_bench] +Rs_static = [ + Symmetric(R_root * R_root') + SMatrix{D_bench,D_bench,Float32}(I) for + R_root in Rs_root_static +] +obs_static = [@SVector randn(Float32, D_bench) for _ in 1:N_bench] +obs_params_static = (Hs_static, cs_static, PDMat.(Rs_static)) + +function test_static(Gs, dyn_params, obs_params, observations, out) + N = length(out) + Threads.@threads for n in 1:N + @inbounds out[n] = kalman_step( + Gs[n], + (dyn_params[1], dyn_params[2][n], dyn_params[3]), + (obs_params[1], obs_params[2], obs_params[3][n]), + observations[n], + ) + end + return out +end + +out = Vector{eltype(Gs_static)}(undef, N_bench) + +println("\nBenchmarking static Kalman step...\n") + +test_static(Gs_static, dyn_params_static, obs_params_static, obs_static, out); # warm-up + +display( + @benchmark test_static( + $Gs_static, 
$dyn_params_static, $obs_params_static, $obs_static, $out + ) +) + +println("\nBenchmarking single batched matrix multiplication...\n") + +display(@benchmark $Σs_root_bench .* $Σs_root_bench) + +tot_mem = 3 * D_bench * D_bench * N_bench * sizeof(Float32) +throughput = 1008 * 10^9 +println("\nTheoretical optimum time: ", tot_mem / throughput * 10^6, " μs") From 9c7a9a59dc2990bd179b46fcb219309977a1950d Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Sat, 10 Jan 2026 09:32:21 +0000 Subject: [PATCH 23/29] Improve scalar handling --- .../src/batching/broadcasting.jl | 43 +++++++- GeneralisedFilters/src/batching/operations.jl | 98 +++++++++++++++++++ research/batching/full_kalman_demo.jl | 78 +++++++++++++++ 3 files changed, 215 insertions(+), 4 deletions(-) create mode 100644 research/batching/full_kalman_demo.jl diff --git a/GeneralisedFilters/src/batching/broadcasting.jl b/GeneralisedFilters/src/batching/broadcasting.jl index 8b9797ba..1be4a0e3 100644 --- a/GeneralisedFilters/src/batching/broadcasting.jl +++ b/GeneralisedFilters/src/batching/broadcasting.jl @@ -36,7 +36,6 @@ maybe_convert_ref(r::Base.RefValue{<:CuMatrix}) = SharedCuMatrix(r[]) # Structural Operations (Pass-through) # ============================================================================= -broadcasted(::typeof(tuple), args...) = tuple(args...) broadcasted(::typeof(getproperty), x, s::Symbol) = getproperty(x, s) broadcasted(::typeof(getfield), x, s::Symbol) = getfield(x, s) broadcasted(::typeof(getfield), x, i::Int) = getfield(x, i) @@ -45,6 +44,16 @@ broadcasted(::typeof(getfield), x, i::Int) = getfield(x, i) broadcasted(::typeof(getfield), r::Base.RefValue, i::Int) = getfield(r[], i) broadcasted(::typeof(getfield), r::Base.RefValue, s::Symbol) = getfield(r[], s) +# StructArray{<:Tuple} destructuring: extract the i-th component array +function broadcasted(::typeof(Base.indexed_iterate), sa::StructArray{<:Tuple}, i::Int) + return (StructArrays.component(sa, i), i + 1) +end +function broadcasted( + ::typeof(Base.indexed_iterate), sa::StructArray{<:Tuple}, i::Int, ::Any +) + return (StructArrays.component(sa, i), i + 1) +end + # ============================================================================= # StructArray Wrapping # ============================================================================= @@ -91,19 +100,44 @@ end broadcasted(::typeof(adjoint), A::BatchedOrShared) = broadcasted(Adjoint, A) broadcasted(::typeof(transpose), A::BatchedOrShared) = broadcasted(Transpose, A) +# Batched tuple creation: returns StructArray{Tuple{...}} +function broadcasted(::typeof(tuple), args::Vararg{BatchedOrShared}) + ElType = Tuple{map(eltype, args)...} + return StructArray{ElType}(args) +end + # ============================================================================= # IR Transformation # ============================================================================= -const SKIP_BROADCAST = Set{Any}([tuple, Core.tuple, getfield, getproperty]) +const SKIP_BROADCAST = Set{Any}([getfield, getproperty]) const BROADCAST_TYPES = Set{Any}([PDMat]) -maybe_wrap_scalar(x) = x -maybe_wrap_scalar(x::UniformScaling) = Ref(x) +# Don't wrap: batched arrays (already batched), callables (functions/types), modules, symbols, integers, already-wrapped refs +maybe_wrap_scalar(x::Union{BatchedOrShared,StructArray}) = x +maybe_wrap_scalar(x::Union{Type,Module,Symbol,Integer,Base.RefValue}) = x +maybe_wrap_scalar(x) = typeof(x) <: Function ? x : Ref(x) @inline function broadcast_and_materialize(f, args...) 
wrapped_args = map(maybe_wrap_scalar, args) + + # Check if any argument is actually batched (not just Ref-wrapped scalar) + has_batched = any(arg -> arg isa Union{BatchedOrShared,StructArray}, wrapped_args) + + if !has_batched + # All scalars - unwrap, execute scalar operation, re-wrap result + unwrapped_args = map(a -> a isa Base.RefValue ? a[] : a, wrapped_args) + result = f(unwrapped_args...) + # Don't wrap code/metadata, batched results, or already-wrapped values + should_wrap = !( + typeof(result) <: Function || + result isa Union{Type,Module,Symbol,BatchedOrShared,StructArray,Base.RefValue} + ) + return should_wrap ? Ref(result) : result + end + + # Has batched inputs - normal broadcast path result = broadcasted(f, wrapped_args...) if result isa Broadcasted return Broadcast.materialize(result) @@ -168,6 +202,7 @@ end ir_element_type(::Type{T}) where {T} = T ir_element_type(::Type{<:StructArray{T}}) where {T} = T +ir_element_type(::Type{<:Base.RefValue{T}}) where {T} = T function generate_batched_function(f, argtypes::Type{<:Tuple}) element_types = Tuple{map(ir_element_type, argtypes.parameters)...} diff --git a/GeneralisedFilters/src/batching/operations.jl b/GeneralisedFilters/src/batching/operations.jl index 7568f1f0..29cfd70b 100644 --- a/GeneralisedFilters/src/batching/operations.jl +++ b/GeneralisedFilters/src/batching/operations.jl @@ -1,4 +1,5 @@ import PDMats: X_A_Xt +import LinearAlgebra: norm # ============================================================================= # GEMM-Compatible Types @@ -157,6 +158,67 @@ function broadcasted( return BatchedCuMatrix(C) end +# ============================================================================= +# Matrix Plus Scaled Identity (A + λI) Custom Kernel +# ============================================================================= + +function plus_scaled_identity_kernel!(C, A, λ, m, n) + batch_idx = blockIdx().z + i = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + j = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y + + if i <= m && j <= n + if i == j + @inbounds C[i, j, batch_idx] = A[i, j, batch_idx] + λ + else + @inbounds C[i, j, batch_idx] = A[i, j, batch_idx] + end + end + return nothing +end + +function plus_scaled_identity_batched!(C::CuArray{T,3}, A::CuArray{T,3}, λ::T) where {T} + m, n, N = size(A) + threads = (16, 16) + blocks = (cld(m, 16), cld(n, 16), N) + @cuda threads = threads blocks = blocks plus_scaled_identity_kernel!(C, A, λ, m, n) + return C +end + +# A + Ref(I) where I is unscaled identity +function broadcasted( + ::typeof(+), A::BatchedCuMatrix{T}, ::Base.RefValue{UniformScaling{Bool}} +) where {T} + C = CuArray{T}(undef, size(A.data)) + plus_scaled_identity_batched!(C, A.data, one(T)) + return BatchedCuMatrix(C) +end + +function broadcasted( + ::typeof(+), ::Base.RefValue{UniformScaling{Bool}}, A::BatchedCuMatrix{T} +) where {T} + C = CuArray{T}(undef, size(A.data)) + plus_scaled_identity_batched!(C, A.data, one(T)) + return BatchedCuMatrix(C) +end + +# A + λI where λI is a scaled UniformScaling +function broadcasted( + ::typeof(+), A::BatchedCuMatrix{T}, J::Base.RefValue{<:UniformScaling{T}} +) where {T} + C = CuArray{T}(undef, size(A.data)) + plus_scaled_identity_batched!(C, A.data, J[].λ) + return BatchedCuMatrix(C) +end + +function broadcasted( + ::typeof(+), J::Base.RefValue{<:UniformScaling{T}}, A::BatchedCuMatrix{T} +) where {T} + C = CuArray{T}(undef, size(A.data)) + plus_scaled_identity_batched!(C, A.data, J[].λ) + return BatchedCuMatrix(C) +end + # 
============================================================================= # PDMat Broadcasting # ============================================================================= @@ -251,3 +313,39 @@ function broadcasted( return result end + +# ============================================================================= +# Batched norm +# ============================================================================= + +# Compute 2-norm for each vector in the batch, returns a CuVector of scalars +function broadcasted(::typeof(norm), v::BatchedCuVector{T}) where {T} + # v.data is D×N, compute norm of each column + return vec(sqrt.(sum(abs2, v.data; dims=1))) +end + +# ============================================================================= +# Batched ifelse (for conditional selection with batched conditions) +# ============================================================================= + +# Select entire vectors: x[:,j] if cond[j], else y[:,j] +function broadcasted( + ::typeof(ifelse), cond::CuVector{Bool}, x::BatchedCuVector{T}, y::BatchedCuVector{T} +) where {T} + # cond is length N (one bool per batch element) + # x.data and y.data are D×N + mask = reshape(T.(cond), 1, :) # 1×N mask for column selection + result = mask .* x.data .+ (one(T) .- mask) .* y.data + return BatchedCuVector(result) +end + +# Select entire matrices: x[:,:,j] if cond[j], else y[:,:,j] +function broadcasted( + ::typeof(ifelse), cond::CuVector{Bool}, x::BatchedCuMatrix{T}, y::BatchedCuMatrix{T} +) where {T} + # cond is length N (one bool per batch element) + # x.data and y.data are D×D×N + mask = reshape(T.(cond), 1, 1, :) # 1×1×N mask for batch selection + result = mask .* x.data .+ (one(T) .- mask) .* y.data + return BatchedCuMatrix(result) +end diff --git a/research/batching/full_kalman_demo.jl b/research/batching/full_kalman_demo.jl new file mode 100644 index 00000000..2c65aaef --- /dev/null +++ b/research/batching/full_kalman_demo.jl @@ -0,0 +1,78 @@ +using GeneralisedFilters +const GF = GeneralisedFilters + +using Distributions +using LinearAlgebra +using Base.Broadcast: broadcasted +using PDMats +using StructArrays +using BenchmarkTools +import Distributions: params + +using CUDA +using Magma +using Magma.LibMagma + +Magma.magma_init() +BATCHED_CACHE_VERBOSITY[] = :debug + +D_state = 3 +D_obs = 2 +N = 3 + +μs = BatchedCuVector(CUDA.randn(Float32, D_state, N)) +Σs_root = BatchedCuMatrix(CUDA.randn(Float32, D_state, D_state, N)) +Σs = Σs_root .* adjoint.(Σs_root) .+ Ref(I) + +As = BatchedCuMatrix(CUDA.randn(Float32, D_state, D_state, N)) +bs = BatchedCuVector(CUDA.randn(Float32, D_state, N)) +Q_root = BatchedCuMatrix(CUDA.randn(Float32, D_state, D_state, N)) +Qs = Q_root .* adjoint.(Q_root) .+ Ref(I) +Qs = PDMat.(Qs); + +Σ_PDs = broadcasted(PDMat, Σs); +Gs = MvNormal.(μs, Σ_PDs); + +# Observation parameters (H and c shared, R batched) +Hs = BatchedCuMatrix(CUDA.randn(Float32, D_obs, D_state, N)) +cs = BatchedCuVector(CUDA.randn(Float32, D_obs, N)) +Rs_root = BatchedCuMatrix(CUDA.randn(Float32, D_obs, D_obs, N)) +Rs = Rs_root .* adjoint.(Rs_root) .+ Ref(I) +Rs = PDMat.(Rs); + +# Observations +observations = BatchedCuVector(CUDA.randn(Float32, D_obs, N)) +jitter = 1f-6 + +dyn_params = tuple.(As, bs, Qs) +obs_params = tuple.(Hs, cs, Rs) + +# Dispatch-based jitter application (avoids if statements in batched code) +_maybe_apply_jitter(Σ, ::Nothing) = Σ +_maybe_apply_jitter(Σ, jitter::Real) = Σ + jitter * I + +function kalman_step(state, dyn_params, obs_params, observation, jitter) + μ, Σ = params(state) + A, 
b, Q = dyn_params + + μ = A * μ + b + Σ = X_A_Xt(Σ, A) + Q + + H, c, R = obs_params + + z = GF._compute_innovation(μ, H, c, observation) + S = GF._compute_innovation_cov(Σ, H, R) + K = GF._compute_kalman_gain(Σ, H, S) + _, Σ̂_raw = GF._compute_joseph_update(Σ, K, H, R) + + μ = μ + K * z + Σ = PDMat(_maybe_apply_jitter(Σ̂_raw, jitter)) + + # ll = logpdf(MvNormal(z, S), zero(z)) + + return MvNormal(μ, Σ) +end + +res = kalman_step.(Gs, dyn_params, obs_params, observations, Ref(jitter)) +println(typeof(res)) +println(eltype(res)) From 789ccba2e9867f09262ebc59efdefd5000063a0f Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Sat, 10 Jan 2026 10:19:11 +0000 Subject: [PATCH 24/29] Complete full Kalman batching --- .../src/batching/broadcasting.jl | 15 ++-- GeneralisedFilters/src/batching/operations.jl | 76 +++++++++++++++++++ research/batching/full_kalman_demo.jl | 12 +-- 3 files changed, 92 insertions(+), 11 deletions(-) diff --git a/GeneralisedFilters/src/batching/broadcasting.jl b/GeneralisedFilters/src/batching/broadcasting.jl index 1be4a0e3..10967b5b 100644 --- a/GeneralisedFilters/src/batching/broadcasting.jl +++ b/GeneralisedFilters/src/batching/broadcasting.jl @@ -100,8 +100,11 @@ end broadcasted(::typeof(adjoint), A::BatchedOrShared) = broadcasted(Adjoint, A) broadcasted(::typeof(transpose), A::BatchedOrShared) = broadcasted(Transpose, A) +# Union of all types that represent batched data +const BatchedData = Union{BatchedOrShared,StructArray,CuVector} + # Batched tuple creation: returns StructArray{Tuple{...}} -function broadcasted(::typeof(tuple), args::Vararg{BatchedOrShared}) +function broadcasted(::typeof(tuple), args::Vararg{BatchedData}) ElType = Tuple{map(eltype, args)...} return StructArray{ElType}(args) end @@ -114,16 +117,16 @@ const SKIP_BROADCAST = Set{Any}([getfield, getproperty]) const BROADCAST_TYPES = Set{Any}([PDMat]) -# Don't wrap: batched arrays (already batched), callables (functions/types), modules, symbols, integers, already-wrapped refs -maybe_wrap_scalar(x::Union{BatchedOrShared,StructArray}) = x +# Don't wrap: batched data, callables, modules, symbols, integers, already-wrapped refs +maybe_wrap_scalar(x::BatchedData) = x maybe_wrap_scalar(x::Union{Type,Module,Symbol,Integer,Base.RefValue}) = x maybe_wrap_scalar(x) = typeof(x) <: Function ? x : Ref(x) @inline function broadcast_and_materialize(f, args...) wrapped_args = map(maybe_wrap_scalar, args) - # Check if any argument is actually batched (not just Ref-wrapped scalar) - has_batched = any(arg -> arg isa Union{BatchedOrShared,StructArray}, wrapped_args) + # Check if any argument is actually batched + has_batched = any(arg -> arg isa BatchedData, wrapped_args) if !has_batched # All scalars - unwrap, execute scalar operation, re-wrap result @@ -132,7 +135,7 @@ maybe_wrap_scalar(x) = typeof(x) <: Function ? x : Ref(x) # Don't wrap code/metadata, batched results, or already-wrapped values should_wrap = !( typeof(result) <: Function || - result isa Union{Type,Module,Symbol,BatchedOrShared,StructArray,Base.RefValue} + result isa Union{Type,Module,Symbol,BatchedData,Base.RefValue} ) return should_wrap ? 
Ref(result) : result end diff --git a/GeneralisedFilters/src/batching/operations.jl b/GeneralisedFilters/src/batching/operations.jl index 29cfd70b..50581d83 100644 --- a/GeneralisedFilters/src/batching/operations.jl +++ b/GeneralisedFilters/src/batching/operations.jl @@ -349,3 +349,79 @@ function broadcasted( result = mask .* x.data .+ (one(T) .- mask) .* y.data return BatchedCuMatrix(result) end + +# ============================================================================= +# Batched zero +# ============================================================================= + +function broadcasted(::typeof(zero), v::BatchedCuVector{T}) where {T} + return BatchedCuVector(CUDA.zeros(T, size(v.data))) +end + +function broadcasted(::typeof(zero), A::BatchedCuMatrix{T}) where {T} + return BatchedCuMatrix(CUDA.zeros(T, size(A.data))) +end + +# ============================================================================= +# Batched logdetcov and invquad (for _logpdf) +# ============================================================================= + +import Distributions: logdetcov, sqmahal +import PDMats: invquad + +# length for BatchedCuVector: returns the inner dimension (same for all batch elements) +function broadcasted(::typeof(length), v::BatchedCuVector) + return size(v.data, 1) +end + +# logdetcov for StructArray{MvNormal}: delegates to the covariance matrix +function broadcasted(::typeof(logdetcov), d::StructArray{<:MvNormal{T}}) where {T} + return broadcasted(logdetcov, d.Σ) +end + +# logdetcov for StructArray{PDMat}: 2 * sum(log.(diag(L))) for each batch element +function broadcasted(::typeof(logdetcov), P::StructArray{<:PDMat{T}}) where {T} + # P.chol.factors is StructArray{LowerTriangular}, .data is BatchedCuMatrix + L_data = P.chol.factors.data.data # D×D×N CuArray + D, _, N = size(L_data) + + # Extract diagonal: L_data[i,i,k] for each i,k + diag_indices = [CartesianIndex(i, i, k) for k in 1:N for i in 1:D] + diag_flat = L_data[diag_indices] # (D*N,) vector + diag_matrix = reshape(diag_flat, D, N) # D×N + + # logdet = 2 * sum(log.(diag(L))) + return vec(T(2) .* sum(log.(diag_matrix); dims=1)) +end + +# invquad for StructArray{PDMat} and BatchedCuVector: x' * inv(P) * x +# Computed as: solve P*y = x (via potrs), then dot(x, y) +function broadcasted( + ::typeof(invquad), P::StructArray{<:PDMat{T}}, x::BatchedCuVector{T} +) where {T} + L = P.chol.factors.data # BatchedCuMatrix (D×D×N) + D, N = size(x.data) + + # Reshape x to matrix for potrs: D×N -> D×1×N + x_mat = BatchedCuMatrix(reshape(copy(x.data), D, 1, N)) + + # Solve L*L'*y = x in-place + potrs_batched!('L', L, x_mat) + + # y is now in x_mat, compute dot(x, y) = sum(x .* y) for each batch + y_vec = reshape(x_mat.data, D, N) + return vec(sum(x.data .* y_vec; dims=1)) +end + +# sqmahal for StructArray{MvNormal} and BatchedCuVector: (x - μ)' * inv(Σ) * (x - μ) +function broadcasted( + ::typeof(sqmahal), d::StructArray{<:MvNormal{T}}, x::BatchedCuVector{T} +) where {T} + diff = broadcasted(-, x, d.μ) + return broadcasted(invquad, d.Σ, diff) +end + +# oftype for CuVector result: convert scalar to element type +function broadcasted(::typeof(oftype), x::CuVector{T}, y::Base.RefValue) where {T} + return T(y[]) +end diff --git a/research/batching/full_kalman_demo.jl b/research/batching/full_kalman_demo.jl index 2c65aaef..ee48dea3 100644 --- a/research/batching/full_kalman_demo.jl +++ b/research/batching/full_kalman_demo.jl @@ -42,7 +42,7 @@ Rs = PDMat.(Rs); # Observations observations = BatchedCuVector(CUDA.randn(Float32, D_obs, N)) 
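# Illustrative CPU-only sketch (not part of the patch) of the identities the
# batched logdetcov/invquad methods above implement per batch element. For a
# covariance stored as P = PDMat(Σ) with Cholesky factor L:
#
#     using LinearAlgebra, PDMats
#     A = randn(Float32, 3, 3); Σ = A * A' + I
#     P = PDMat(Σ); x = randn(Float32, 3)
#     invquad(P, x) ≈ dot(x, Σ \ x)              # solve P*y = x, then dot(x, y)
#     logdet(P) ≈ 2 * sum(log.(diag(P.chol.L)))  # log-determinant via the factor
#
# The Gaussian log-density assembled from these pieces, with D = length(x), is
#     logpdf(MvNormal(μ, P), x) = -(D*log(2π) + logdet(P) + invquad(P, x - μ)) / 2
# which is what the `ll` computation below relies on.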
-jitter = 1f-6
+jitter = 1.0f-6
 
 dyn_params = tuple.(As, bs, Qs)
 obs_params = tuple.(Hs, cs, Rs)
@@ -68,11 +68,13 @@ function kalman_step(state, dyn_params, obs_params, observation, jitter)
     μ = μ + K * z
     Σ = PDMat(_maybe_apply_jitter(Σ̂_raw, jitter))
 
-    # ll = logpdf(MvNormal(z, S), zero(z))
+    ll = Distributions._logpdf(MvNormal(z, S), zero(z))
 
-    return MvNormal(μ, Σ)
+    return MvNormal(μ, Σ), ll
 end
 
 res = kalman_step.(Gs, dyn_params, obs_params, observations, Ref(jitter))
-println(typeof(res))
-println(eltype(res))
+println("\nFull type: ", typeof(res))
+println("\nElement type: ", eltype(res))
+
+lls = StructArrays.component(res, 2)

From 0a8e0f05ed28fed9c5e81d67a36bd2f77b20b9dc Mon Sep 17 00:00:00 2001
From: Tim Hargreaves
Date: Thu, 26 Feb 2026 16:16:14 +0000
Subject: [PATCH 25/29] Refactor type system

Add batched struct
Generalised batched array
Make shared array have concrete dimensions
---
 .../src/batching/broadcasting.jl              | 178 +++++++---
 GeneralisedFilters/src/batching/operations.jl |  50 +--
 GeneralisedFilters/src/batching/types.jl      | 316 ++++++++++++++----
 GeneralisedFilters/src/batching/wrappers.jl   |  57 ++--
 research/batching/batching_demo.jl            |  24 +-
 research/batching/full_kalman_demo.jl         |   8 +-
 research/batching/wrappers_demo.jl            |   3 +-
 7 files changed, 445 insertions(+), 191 deletions(-)

diff --git a/GeneralisedFilters/src/batching/broadcasting.jl b/GeneralisedFilters/src/batching/broadcasting.jl
index 10967b5b..560b2dc0 100644
--- a/GeneralisedFilters/src/batching/broadcasting.jl
+++ b/GeneralisedFilters/src/batching/broadcasting.jl
@@ -1,6 +1,5 @@
 using IRTools
 using IRTools: @code_ir, IR, Statement, Variable, func
-using StructArrays
 using LinearAlgebra: I, UniformScaling
 
 using Base.Broadcast: Broadcasted, BroadcastStyle, DefaultArrayStyle
@@ -16,11 +15,9 @@ export BATCHED_CACHE_VERBOSITY, clear_batched_cache!
struct BatchedStyle <: Broadcast.BroadcastStyle end -Base.BroadcastStyle(::Type{<:BatchedCuMatrix}) = BatchedStyle() -Base.BroadcastStyle(::Type{<:BatchedCuVector}) = BatchedStyle() -Base.BroadcastStyle(::Type{<:SharedCuMatrix}) = BatchedStyle() -Base.BroadcastStyle(::Type{<:SharedCuVector}) = BatchedStyle() -Base.BroadcastStyle(::Type{<:StructArray}) = BatchedStyle() +Base.BroadcastStyle(::Type{<:BatchedCuArray}) = BatchedStyle() +Base.BroadcastStyle(::Type{<:SharedCuArray}) = BatchedStyle() +Base.BroadcastStyle(::Type{<:BatchedStruct}) = BatchedStyle() Base.BroadcastStyle(::BatchedStyle, ::BatchedStyle) = BatchedStyle() Base.BroadcastStyle(::BatchedStyle, ::DefaultArrayStyle{0}) = BatchedStyle() @@ -28,47 +25,48 @@ Base.BroadcastStyle(::BatchedStyle, ::DefaultArrayStyle{0}) = BatchedStyle() # Ref Conversion (for Shared arrays) # ============================================================================= -maybe_convert_ref(x) = x -maybe_convert_ref(r::Base.RefValue{<:CuVector}) = SharedCuVector(r[]) -maybe_convert_ref(r::Base.RefValue{<:CuMatrix}) = SharedCuMatrix(r[]) - -# ============================================================================= -# Structural Operations (Pass-through) -# ============================================================================= - -broadcasted(::typeof(getproperty), x, s::Symbol) = getproperty(x, s) -broadcasted(::typeof(getfield), x, s::Symbol) = getfield(x, s) -broadcasted(::typeof(getfield), x, i::Int) = getfield(x, i) - -# Special handling for RefValue - unwrap before indexing -broadcasted(::typeof(getfield), r::Base.RefValue, i::Int) = getfield(r[], i) -broadcasted(::typeof(getfield), r::Base.RefValue, s::Symbol) = getfield(r[], s) +maybe_convert_ref(x, ::Nothing) = x +maybe_convert_ref(x, ::Int) = x +function maybe_convert_ref(r::Base.RefValue{<:CuVector{T,M}}, N::Int) where {T,M} + return SharedCuVector{T,M}(r[], (N,)) +end +function maybe_convert_ref(r::Base.RefValue{<:CuMatrix{T,M}}, N::Int) where {T,M} + return SharedCuMatrix{T,M}(r[], (N,)) +end +# Can't convert without knowing N — leave as Ref and let downstream handle it +maybe_convert_ref(r::Base.RefValue{<:CuVector}, ::Nothing) = r +maybe_convert_ref(r::Base.RefValue{<:CuMatrix}, ::Nothing) = r -# StructArray{<:Tuple} destructuring: extract the i-th component array -function broadcasted(::typeof(Base.indexed_iterate), sa::StructArray{<:Tuple}, i::Int) - return (StructArrays.component(sa, i), i + 1) +# BatchedStruct{<:Tuple} destructuring: extract the i-th component +function broadcasted(::typeof(Base.indexed_iterate), bs::BatchedStruct{<:Tuple}, i::Int) + return (getfield(bs, :components)[i], i + 1) end function broadcasted( - ::typeof(Base.indexed_iterate), sa::StructArray{<:Tuple}, i::Int, ::Any + ::typeof(Base.indexed_iterate), bs::BatchedStruct{<:Tuple}, i::Int, ::Any ) - return (StructArrays.component(sa, i), i + 1) + return (getfield(bs, :components)[i], i + 1) end # ============================================================================= -# StructArray Wrapping +# BatchedStruct Wrapping # ============================================================================= inner_eltype(arg::BatchedOrShared) = eltype(arg) -inner_eltype(arg::StructArray) = eltype(arg) inner_eltype(arg) = typeof(arg) +""" + wrap_if_batched(::Type{T}, args...) + +If any argument is batched, create a BatchedStruct{T} with the args as components. +Otherwise, call the constructor T(args...) normally. +""" function wrap_if_batched(::Type{T}, args...) 
where {T} - if any(arg -> arg isa Union{BatchedOrShared,StructArray}, args) + if any(arg -> arg isa BatchedOrShared, args) field_names = fieldnames(T) element_types = Tuple{map(inner_eltype, args)...} ElType = Core.Compiler.return_type(T, element_types) nt = NamedTuple{field_names}(args) - return StructArray{ElType}(nt) + return BatchedStruct{ElType}(nt) else return T(args...) end @@ -85,7 +83,7 @@ Generic wrapper for any type constructor applied to batched arguments. Works for single-field wrappers (Adjoint, Transpose, LowerTriangular, etc.) as well as multi-field types (PDMat, Cholesky, etc.). -Returns a StructArray where each element is the type applied to the +Returns a BatchedStruct where each element is the type applied to the corresponding elements of the input arrays. """ function broadcasted(::Type{W}, args::Vararg{BatchedOrShared}) where {W} @@ -93,32 +91,101 @@ function broadcasted(::Type{W}, args::Vararg{BatchedOrShared}) where {W} ElType = Core.Compiler.return_type(W, element_types) field_names = fieldnames(ElType) nt = NamedTuple{field_names}(args) - return StructArray{ElType}(nt) + return BatchedStruct{ElType}(nt) end # Redirect function forms to type constructors broadcasted(::typeof(adjoint), A::BatchedOrShared) = broadcasted(Adjoint, A) broadcasted(::typeof(transpose), A::BatchedOrShared) = broadcasted(Transpose, A) +# copy for Adjoint/Transpose wrappers - materialize the transposition +function broadcasted(::typeof(copy), x::BatchedStruct{<:Adjoint}) + parent_data = x.parent # BatchedCuMatrix or SharedCuMatrix + if parent_data isa BatchedCuMatrix + return BatchedCuMatrix(permutedims(parent_data.data, (2, 1, 3))) + else # SharedCuMatrix + return SharedCuMatrix(permutedims(parent_data.data, (2, 1))) + end +end + +function broadcasted(::typeof(copy), x::BatchedStruct{<:Transpose}) + parent_data = x.parent + if parent_data isa BatchedCuMatrix + return BatchedCuMatrix(permutedims(parent_data.data, (2, 1, 3))) + else # SharedCuMatrix + return SharedCuMatrix(permutedims(parent_data.data, (2, 1))) + end +end + # Union of all types that represent batched data -const BatchedData = Union{BatchedOrShared,StructArray,CuVector} +const BatchedData = Union{BatchedOrShared,CuVector} -# Batched tuple creation: returns StructArray{Tuple{...}} +# Batched tuple creation: returns BatchedStruct{Tuple{...}} function broadcasted(::typeof(tuple), args::Vararg{BatchedData}) ElType = Tuple{map(eltype, args)...} - return StructArray{ElType}(args) + # For tuples, components is a regular Tuple, not NamedTuple + components = NamedTuple{ntuple(i -> Symbol("x$i"), length(args))}(args) + return BatchedStruct{ElType}(components) end # ============================================================================= -# IR Transformation +# getfield/getproperty broadcasting for BatchedStruct +# ============================================================================= + +# getfield on BatchedStruct: return the batched component +# Handle both unwrapped and Ref-wrapped field names (Ref from maybe_wrap_scalar) +function broadcasted(::typeof(getfield), x::BatchedStruct{T}, s::Symbol) where {T} + s in fieldnames(T) && return getfield(x, :components)[s] + return error("BatchedStruct{$T} has no field `$s`") +end +function broadcasted( + ::typeof(getfield), x::BatchedStruct{T}, s::Base.RefValue{Symbol} +) where {T} + return broadcasted(getfield, x, s[]) +end +broadcasted(::typeof(getfield), x::BatchedStruct, i::Int) = getfield(x, :components)[i] +function broadcasted(::typeof(getfield), x::BatchedStruct, 
i::Base.RefValue{<:Integer}) + return getfield(x, :components)[i[]] +end + +# getproperty on BatchedStruct: return batched component for real fields, +# fall through to tracing for computed properties +function broadcasted(::typeof(getproperty), x::BatchedStruct{T}, s::Symbol) where {T} + s in fieldnames(T) && return getfield(x, :components)[s] + # Computed property - return Broadcasted to trigger IR transformation + return Broadcasted{BatchedStyle}(getproperty, (x, s)) +end + +function broadcasted( + ::typeof(getproperty), x::BatchedStruct{T}, s::Base.RefValue{Symbol} +) where {T} + return broadcasted(getproperty, x, s[]) +end + +# ============================================================================= +# size broadcasting for batched arrays (returns inner dimensions) # ============================================================================= -const SKIP_BROADCAST = Set{Any}([getfield, getproperty]) +broadcasted(::typeof(size), A::BatchedCuMatrix) = inner_size(A) +broadcasted(::typeof(size), A::BatchedCuMatrix, i::Integer) = inner_size(A)[i] +broadcasted(::typeof(size), A::SharedCuMatrix) = inner_size(A) +broadcasted(::typeof(size), A::SharedCuMatrix, i::Integer) = inner_size(A)[i] + +broadcasted(::typeof(size), x::BatchedCuVector) = inner_size(x) +broadcasted(::typeof(size), x::BatchedCuVector, i::Integer) = inner_size(x)[i] +broadcasted(::typeof(size), x::SharedCuVector) = inner_size(x) +broadcasted(::typeof(size), x::SharedCuVector, i::Integer) = inner_size(x)[i] + +# ============================================================================= +# IR Transformation +# ============================================================================= +const SKIP_BROADCAST = Set{Any}() const BROADCAST_TYPES = Set{Any}([PDMat]) -# Don't wrap: batched data, callables, modules, symbols, integers, already-wrapped refs +# Don't wrap: batched data, shared scalars, callables, modules, symbols, integers, already-wrapped refs maybe_wrap_scalar(x::BatchedData) = x +maybe_wrap_scalar(x::SharedScalar) = x maybe_wrap_scalar(x::Union{Type,Module,Symbol,Integer,Base.RefValue}) = x maybe_wrap_scalar(x) = typeof(x) <: Function ? x : Ref(x) @@ -129,15 +196,17 @@ maybe_wrap_scalar(x) = typeof(x) <: Function ? x : Ref(x) has_batched = any(arg -> arg isa BatchedData, wrapped_args) if !has_batched - # All scalars - unwrap, execute scalar operation, re-wrap result + # All scalars - unwrap and execute directly unwrapped_args = map(a -> a isa Base.RefValue ? a[] : a, wrapped_args) - result = f(unwrapped_args...) - # Don't wrap code/metadata, batched results, or already-wrapped values - should_wrap = !( - typeof(result) <: Function || - result isa Union{Type,Module,Symbol,BatchedData,Base.RefValue} - ) - return should_wrap ? Ref(result) : result + return f(unwrapped_args...) + end + + # Special case: getfield on Ref-wrapped scalar object (getfield is a builtin, can't trace) + if f === getfield && length(wrapped_args) >= 1 + obj = wrapped_args[1] + if obj isa Base.RefValue + return getfield(obj[], wrapped_args[2:end]...) 
+ end end # Has batched inputs - normal broadcast path @@ -204,7 +273,7 @@ function transform_to_batched(ir::IR) end ir_element_type(::Type{T}) where {T} = T -ir_element_type(::Type{<:StructArray{T}}) where {T} = T +ir_element_type(::Type{<:BatchedStruct{T}}) where {T} = T ir_element_type(::Type{<:Base.RefValue{T}}) where {T} = T function generate_batched_function(f, argtypes::Type{<:Tuple}) @@ -243,9 +312,22 @@ function clear_batched_cache!() return nothing end +function _find_batch_size(args) + for arg in args + if arg isa BatchedCuArray || arg isa SharedCuArray + return batch_size(arg) + elseif arg isa Broadcasted + n = _find_batch_size(arg.args) + n !== nothing && return n + end + end + return nothing +end + function Broadcast.materialize(bc::Broadcasted{BatchedStyle}) f = bc.f - args = map(maybe_convert_ref, bc.args) + N = _find_batch_size(bc.args) + args = map(a -> maybe_convert_ref(a, N), bc.args) result = broadcasted(f, args...) if !(result isa Broadcasted) diff --git a/GeneralisedFilters/src/batching/operations.jl b/GeneralisedFilters/src/batching/operations.jl index 50581d83..e871d1f0 100644 --- a/GeneralisedFilters/src/batching/operations.jl +++ b/GeneralisedFilters/src/batching/operations.jl @@ -5,18 +5,18 @@ import LinearAlgebra: norm # GEMM-Compatible Types # ============================================================================= -# Type aliases for StructArray-wrapped matrices -const BatchedAdjoint{T,M} = StructArray{ - Adjoint{T,CuArray{T,2,M}},1,@NamedTuple{parent::BatchedCuMatrix{T,M}} +# Type aliases for BatchedStruct-wrapped matrices +const BatchedAdjoint{T,M} = BatchedStruct{ + Adjoint{T,CuArray{T,2,M}},@NamedTuple{parent::BatchedCuMatrix{T,M}} } -const BatchedTranspose{T,M} = StructArray{ - Transpose{T,CuArray{T,2,M}},1,@NamedTuple{parent::BatchedCuMatrix{T,M}} +const BatchedTranspose{T,M} = BatchedStruct{ + Transpose{T,CuArray{T,2,M}},@NamedTuple{parent::BatchedCuMatrix{T,M}} } -const SharedAdjoint{T,M} = StructArray{ - Adjoint{T,CuArray{T,2,M}},1,@NamedTuple{parent::SharedCuMatrix{T,M}} +const SharedAdjoint{T,M} = BatchedStruct{ + Adjoint{T,CuArray{T,2,M}},@NamedTuple{parent::SharedCuMatrix{T,M}} } -const SharedTranspose{T,M} = StructArray{ - Transpose{T,CuArray{T,2,M}},1,@NamedTuple{parent::SharedCuMatrix{T,M}} +const SharedTranspose{T,M} = BatchedStruct{ + Transpose{T,CuArray{T,2,M}},@NamedTuple{parent::SharedCuMatrix{T,M}} } # Union of all GEMM-compatible matrix types @@ -224,7 +224,7 @@ end # ============================================================================= # HACK: PDMat is a constructor so will use -# `broadcasted(::Type{W}, args::Union{BatchedCuMatrix, BatchedCuVector, SharedCuMatrix, SharedCuVector, StructArray}...) where W` +# `broadcasted(::Type{W}, args::BatchedOrShared...) where W` # rather than the desired recursive broadcast to # `PDMat(mat::AbstractMatrix) = PDMat(mat, cholesky(mat))` # This method hardcodes a manual override for BatchedCuMatrix inputs. This should be replaced @@ -237,13 +237,13 @@ end # HACK: Addition with PDMat extracts .mat field. Should be replaced by automatic # materialization of PDMat to BatchedCuMatrix in the future. 
function broadcasted( - ::typeof(+), A::BatchedCuMatrix{T}, P::StructArray{<:PDMat{T}} + ::typeof(+), A::BatchedCuMatrix{T}, P::BatchedStruct{<:PDMat{T}} ) where {T} return broadcasted(+, A, P.mat) end function broadcasted( - ::typeof(+), P::StructArray{<:PDMat{T}}, A::BatchedCuMatrix{T} + ::typeof(+), P::BatchedStruct{<:PDMat{T}}, A::BatchedCuMatrix{T} ) where {T} return broadcasted(+, P.mat, A) end @@ -251,7 +251,7 @@ end # A / S where S is PDMat: computes A * inv(S) # potrs solves S * X = B, so we solve S * X = A' and transpose back function broadcasted( - ::typeof(/), A::BatchedCuMatrix{T}, S::StructArray{<:PDMat{T}} + ::typeof(/), A::BatchedCuMatrix{T}, S::BatchedStruct{<:PDMat{T}} ) where {T} L = S.chol.factors.data @@ -280,16 +280,16 @@ function broadcasted( return broadcasted(*, temp, Xt) end -# X_A_Xt for StructArray{PDMat}: X * P * X' where P = L * L' +# X_A_Xt for BatchedStruct{PDMat}: X * P * X' where P = L * L' # Computed as (X * L) * (X * L)' using TRMM and SYRK # HACK: this function should dispatch to specialised `*` for triangular types but this is # not yet implemented function broadcasted( ::typeof(X_A_Xt), - P::StructArray{<:PDMat{T}}, + P::BatchedStruct{<:PDMat{T}}, X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, ) where {T} - # P.chol.factors is StructArray{LowerTriangular}, .data is the BatchedCuMatrix + # P.chol.factors is BatchedStruct{LowerTriangular}, .parent is the BatchedCuMatrix L = P.chol.factors.data N = get_batch_size(L, X) out_dim = inner_size_for_blas(X)[1] @@ -374,14 +374,14 @@ function broadcasted(::typeof(length), v::BatchedCuVector) return size(v.data, 1) end -# logdetcov for StructArray{MvNormal}: delegates to the covariance matrix -function broadcasted(::typeof(logdetcov), d::StructArray{<:MvNormal{T}}) where {T} +# logdetcov for BatchedStruct{MvNormal}: delegates to the covariance matrix +function broadcasted(::typeof(logdetcov), d::BatchedStruct{<:MvNormal{T}}) where {T} return broadcasted(logdetcov, d.Σ) end -# logdetcov for StructArray{PDMat}: 2 * sum(log.(diag(L))) for each batch element -function broadcasted(::typeof(logdetcov), P::StructArray{<:PDMat{T}}) where {T} - # P.chol.factors is StructArray{LowerTriangular}, .data is BatchedCuMatrix +# logdetcov for BatchedStruct{PDMat}: 2 * sum(log.(diag(L))) for each batch element +function broadcasted(::typeof(logdetcov), P::BatchedStruct{<:PDMat{T}}) where {T} + # P.chol.factors is BatchedStruct{LowerTriangular}, .parent is BatchedCuMatrix L_data = P.chol.factors.data.data # D×D×N CuArray D, _, N = size(L_data) @@ -394,10 +394,10 @@ function broadcasted(::typeof(logdetcov), P::StructArray{<:PDMat{T}}) where {T} return vec(T(2) .* sum(log.(diag_matrix); dims=1)) end -# invquad for StructArray{PDMat} and BatchedCuVector: x' * inv(P) * x +# invquad for BatchedStruct{PDMat} and BatchedCuVector: x' * inv(P) * x # Computed as: solve P*y = x (via potrs), then dot(x, y) function broadcasted( - ::typeof(invquad), P::StructArray{<:PDMat{T}}, x::BatchedCuVector{T} + ::typeof(invquad), P::BatchedStruct{<:PDMat{T}}, x::BatchedCuVector{T} ) where {T} L = P.chol.factors.data # BatchedCuMatrix (D×D×N) D, N = size(x.data) @@ -413,9 +413,9 @@ function broadcasted( return vec(sum(x.data .* y_vec; dims=1)) end -# sqmahal for StructArray{MvNormal} and BatchedCuVector: (x - μ)' * inv(Σ) * (x - μ) +# sqmahal for BatchedStruct{MvNormal} and BatchedCuVector: (x - μ)' * inv(Σ) * (x - μ) function broadcasted( - ::typeof(sqmahal), d::StructArray{<:MvNormal{T}}, x::BatchedCuVector{T} + ::typeof(sqmahal), 
d::BatchedStruct{<:MvNormal{T}}, x::BatchedCuVector{T} ) where {T} diff = broadcasted(-, x, d.μ) return broadcasted(invquad, d.Σ, diff) diff --git a/GeneralisedFilters/src/batching/types.jl b/GeneralisedFilters/src/batching/types.jl index ece85e77..6e2049c3 100644 --- a/GeneralisedFilters/src/batching/types.jl +++ b/GeneralisedFilters/src/batching/types.jl @@ -2,89 +2,290 @@ using CUDA using LinearAlgebra: Adjoint, Transpose, LowerTriangular, UpperTriangular, UniformScaling, Cholesky using PDMats: PDMat -using StructArrays: StructArray -export BatchedCuMatrix, BatchedCuVector -export SharedCuMatrix, SharedCuVector +export BatchedCuArray, BatchedCuMatrix, BatchedCuVector +export SharedCuArray, SharedCuMatrix, SharedCuVector, SharedScalar +export Shared +export BatchedStruct # ============================================================================= -# Core Batched Types +# Core Batched Type # ============================================================================= -struct BatchedCuMatrix{T,M} <: AbstractVector{CuArray{T,2,M}} - data::CuArray{T,3,M} +""" + BatchedCuArray{T, NE, NB, NT, M} <: AbstractArray{CuArray{T,NE,M}, NB} + +An `NB`-dimensional batch of `NE`-dimensional CuArrays, stored as a single contiguous +`CuArray{T, NT, M}` where `NT = NE + NB`. + +- `NE`: number of element dimensions (the "inner" array shape) +- `NB`: number of batch dimensions +- `NT`: total number of dimensions (`NE + NB`); required explicitly because Julia's type + system cannot express arithmetic on type parameters + +The first `NE` dimensions index within each element; the last `NB` dimensions index across +the batch. + +# Common aliases +- `BatchedCuMatrix{T,M}` = `BatchedCuArray{T,2,1,3,M}` — a vector of matrices +- `BatchedCuVector{T,M}` = `BatchedCuArray{T,1,1,2,M}` — a vector of vectors +""" +struct BatchedCuArray{T,NE,NB,NT,M} <: AbstractArray{CuArray{T,NE,M},NB} + data::CuArray{T,NT,M} + + function BatchedCuArray{T,NE,NB,NT,M}(data::CuArray{T,NT,M}) where {T,NE,NB,NT,M} + NE + NB == NT || error("NE ($NE) + NB ($NB) must equal ndims(data) ($NT)") + return new{T,NE,NB,NT,M}(data) + end end -struct BatchedCuVector{T,M} <: AbstractVector{CuArray{T,1,M}} - data::CuArray{T,2,M} +# Convenience constructor: infer T and M, require explicit NE and NB +function BatchedCuArray{T,NE,NB}(data::CuArray{T,NT,M}) where {T,NE,NB,NT,M} + NE + NB == NT || error("NE ($NE) + NB ($NB) must equal ndims(data) ($NT)") + return BatchedCuArray{T,NE,NB,NT,M}(data) end -const BatchedArray = Union{BatchedCuVector,BatchedCuMatrix} +# Common case aliases +const BatchedCuMatrix{T,M} = BatchedCuArray{T,2,1,3,M} +const BatchedCuVector{T,M} = BatchedCuArray{T,1,1,2,M} -batch_size(x::BatchedCuVector) = size(x.data, 2) -batch_size(x::BatchedCuMatrix) = size(x.data, 3) +# Constructors for aliased cases +BatchedCuMatrix(data::CuArray{T,3,M}) where {T,M} = BatchedCuArray{T,2,1,3,M}(data) +BatchedCuVector(data::CuArray{T,2,M}) where {T,M} = BatchedCuArray{T,1,1,2,M}(data) -Base.size(x::BatchedCuVector) = (batch_size(x),) -Base.size(x::BatchedCuMatrix) = (batch_size(x),) -Base.length(x::BatchedArray) = batch_size(x) +const BatchedArray = BatchedCuArray -inner_size(x::BatchedCuVector) = (size(x.data, 1),) -inner_size(x::BatchedCuMatrix) = (size(x.data, 1), size(x.data, 2)) +Base.IndexStyle(::Type{<:BatchedCuArray}) = Base.IndexCartesian() -Base.getindex(x::BatchedCuVector, i::Int) = view(x.data, :, i) -Base.getindex(x::BatchedCuMatrix, i::Int) = view(x.data,:,:,i) +function Base.size(x::BatchedCuArray{T,NE,NB}) where {T,NE,NB} + return 
ntuple(i -> size(x.data, NE + i), NB) +end + +function Base.getindex(x::BatchedCuArray{T,NE,NB}, I::Vararg{Int,NB}) where {T,NE,NB} + return view(x.data, ntuple(_ -> :, NE)..., I...) +end + +function inner_size(x::BatchedCuArray{T,NE}) where {T,NE} + return ntuple(i -> size(x.data, i), NE) +end + +batch_size(x::BatchedCuArray) = length(x) # ============================================================================= # Shared Types (same data reused across all batch elements) # ============================================================================= -struct SharedCuMatrix{T,M} <: AbstractVector{CuArray{T,2,M}} - data::CuArray{T,2,M} +""" + SharedCuArray{T, InnerN, BatchN, M} <: AbstractArray{CuArray{T,InnerN,M}, BatchN} + +A batch of CuArrays where every element is the same underlying `CuArray{T,InnerN,M}`. +Unlike `Ref(array)`, this type carries an explicit batch size and satisfies the +`AbstractArray` contract honestly. + +Use `Ref(array)` when the batch size is unknown or irrelevant (e.g. during broadcast +setup). Use `SharedCuArray` when you need a proper `AbstractArray` with a known size. + +# Common aliases +- `SharedCuMatrix{T,M}` = `SharedCuArray{T,2,1,M}` +- `SharedCuVector{T,M}` = `SharedCuArray{T,1,1,M}` +""" +struct SharedCuArray{T,InnerN,BatchN,M} <: AbstractArray{CuArray{T,InnerN,M},BatchN} + data::CuArray{T,InnerN,M} + batchsize::NTuple{BatchN,Int} end -struct SharedCuVector{T,M} <: AbstractVector{CuArray{T,1,M}} - data::CuArray{T,1,M} +# Outer constructor: accept a plain Int for the common 1D-batch case +function SharedCuArray{T,InnerN,1,M}(data::CuArray{T,InnerN,M}, N::Int) where {T,InnerN,M} + return SharedCuArray{T,InnerN,1,M}(data, (N,)) +end + +const SharedCuMatrix{T,M} = SharedCuArray{T,2,1,M} +const SharedCuVector{T,M} = SharedCuArray{T,1,1,M} + +const SharedArray = SharedCuArray + +""" + Shared(data::CuArray, N::Int) -> SharedCuArray + +Convenience constructor: create a `SharedCuArray` from a CuArray with an explicit +1D batch size `N`. 
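+
+# Example
+
+A minimal usage sketch (illustrative sizes):
+
+    A = CUDA.randn(Float32, 4, 4)  # one matrix held on the device
+    S = Shared(A, 100)             # behaves as a length-100 vector of that matrix
+    S[7] === A                     # every index returns the same underlying array
+    length(S) == 100               # the explicit batch size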
+""" +Shared(x::CuArray{T,2,M}, N::Int) where {T,M} = SharedCuArray{T,2,1,M}(x, (N,)) +Shared(x::CuArray{T,1,M}, N::Int) where {T,M} = SharedCuArray{T,1,1,M}(x, (N,)) + +Base.IndexStyle(::Type{<:SharedCuArray}) = Base.IndexCartesian() + +Base.size(x::SharedCuArray) = x.batchsize + +function Base.getindex( + x::SharedCuArray{T,InnerN,BatchN}, ::Vararg{Int,BatchN} +) where {T,InnerN,BatchN} + return x.data end -const SharedArray = Union{SharedCuVector,SharedCuMatrix} +function inner_size(x::SharedCuArray) + return size(x.data) +end -Shared(x::CuArray{T,2,M}) where {T,M} = SharedCuMatrix{T,M}(x) -Shared(x::CuArray{T,1,M}) where {T,M} = SharedCuVector{T,M}(x) +batch_size(x::SharedCuArray) = length(x) -batch_size(::SharedCuVector) = nothing -batch_size(::SharedCuMatrix) = nothing +# ============================================================================= +# SharedScalar: a scalar value shared across all batch elements +# ============================================================================= -inner_size(x::SharedCuVector) = size(x.data) -inner_size(x::SharedCuMatrix) = size(x.data) +struct SharedScalar{T} <: AbstractVector{T} + value::T +end -Base.size(x::SharedCuVector) = (1,) -Base.size(x::SharedCuMatrix) = (1,) -Base.length(::SharedArray) = 1 +Base.size(::SharedScalar) = (1,) +Base.length(::SharedScalar) = 1 +Base.getindex(x::SharedScalar, ::Int) = x.value +Base.eltype(::Type{SharedScalar{T}}) where {T} = T +batch_size(::SharedScalar) = nothing +_get_component_batch_size(::SharedScalar) = nothing -Base.getindex(x::SharedCuVector, ::Int) = x.data -Base.getindex(x::SharedCuMatrix, ::Int) = x.data +# Comparisons unwrap the SharedScalar automatically +Base.:(==)(x::SharedScalar, y) = x.value == y +Base.:(==)(x, y::SharedScalar) = x == y.value +Base.:(==)(x::SharedScalar, y::SharedScalar) = x.value == y.value + +# ============================================================================= +# BatchedStruct - Custom wrapper for batched composite types +# ============================================================================= + +""" + BatchedStruct{T, C <: NamedTuple} <: AbstractVector{T} + +A wrapper type representing a batch of structs of type `T`, stored in a +column-oriented (struct-of-arrays) format. + +# Type Parameters +- `T`: The element type (e.g., `PDMat{Float32, CuMatrix{Float32}}`) +- `C`: The NamedTuple type holding the batched components + +# Fields +- `components::C`: A NamedTuple where each field is a batched array or nested BatchedStruct + +# Usage +BatchedStruct is designed to be created automatically by the batching IR transform +when constructors are called with batched arguments. Users typically don't need +to construct these directly. 
+ +# Property Access +- `x.fieldname` returns the batched component for real fields of `T` +- `x.components` returns the underlying NamedTuple storage +- For computed properties (custom getproperty), falls back to element-wise evaluation + +# Indexing +- `x[i]` constructs and returns an instance of `T` for the i-th batch element +""" +struct BatchedStruct{T,C<:NamedTuple} <: AbstractVector{T} + components::C + + function BatchedStruct{T}(components::C) where {T,C<:NamedTuple} + # Validate all components have consistent batch size + sizes = map(_get_component_batch_size, values(components)) + non_nothing = Base.filter(!isnothing, sizes) + if !isempty(non_nothing) + first_size = first(non_nothing) + if !all(s -> s == first_size, non_nothing) + error("All batched components must have the same batch size") + end + end + return new{T,C}(components) + end +end + +# Helper to get batch size from a component +_get_component_batch_size(x::BatchedCuArray) = batch_size(x) +_get_component_batch_size(x::SharedCuArray) = batch_size(x) +_get_component_batch_size(x::BatchedStruct) = length(x) +_get_component_batch_size(::Any) = nothing + +# Convenience constructor that infers T from the fieldnames matching a type +function BatchedStruct{T}(; kwargs...) where {T} + components = NamedTuple{fieldnames(T)}(values(kwargs)) + return BatchedStruct{T}(components) +end + +# ============================================================================= +# BatchedStruct - AbstractVector Interface +# ============================================================================= + +function batch_size(x::BatchedStruct) + for component in values(getfield(x, :components)) + bs = _get_component_batch_size(component) + if bs !== nothing + return bs + end + end + return error("BatchedStruct has no batched components") +end + +Base.size(x::BatchedStruct) = (batch_size(x),) +Base.IndexStyle(::Type{<:BatchedStruct}) = IndexLinear() + +# Indexing: construct an element of type T by calling its constructor +@generated function Base.getindex(x::BatchedStruct{T}, i::Integer) where {T} + fields = fieldnames(T) + field_exprs = [:(getfield(x, :components).$(fields[j])[i]) for j in 1:length(fields)] + return :(T($(field_exprs...))) +end + +# ============================================================================= +# BatchedStruct - Property Access +# ============================================================================= + +function Base.getproperty(x::BatchedStruct{T}, s::Symbol) where {T} + s === :components && return getfield(x, :components) + s in fieldnames(T) && return getfield(x, :components)[s] + # Computed property - broadcast getproperty, triggering IR transformation + return getproperty.(x, Ref(s)) +end + +Base.propertynames(::BatchedStruct{T}) where {T} = (:components, fieldnames(T)...) 
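# Illustrative sketch (not part of the patch): how the interface above behaves
# for a hypothetical batch of lower-triangular Cholesky factors. Field access is
# a cheap component lookup; indexing reconstructs one element via T's constructor.
#
#     Ls = BatchedCuMatrix(CUDA.randn(Float32, 2, 2, 8))
#     bs = BatchedStruct{LowerTriangular{Float32,CuMatrix{Float32}}}((; data=Ls))
#     bs.data     # the BatchedCuMatrix component (column-oriented storage)
#     bs[3]       # rebuilds a LowerTriangular from the third 2×2 slice
#     length(bs)  # 8, the common batch size of the components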
+ +# ============================================================================= +# BatchedStruct - Display (avoid materialization) +# ============================================================================= + +function Base.show(io::IO, x::BatchedStruct{T}) where {T} + return print(io, "BatchedStruct{", T, "} with ", length(x), " elements") +end + +function Base.show(io::IO, ::MIME"text/plain", x::BatchedStruct{T}) where {T} + println(io, "BatchedStruct{", T, "} with ", length(x), " elements:") + comps = getfield(x, :components) + for name in keys(comps) + component = comps[name] + print(io, " .", name, " :: ") + if component isa BatchedCuArray + println(io, typeof(component), " (", inner_size(component), ")") + elseif component isa SharedCuArray + println(io, typeof(component), " [shared] (", inner_size(component), ")") + elseif component isa BatchedStruct + println(io, typeof(component)) + else + println(io, typeof(component)) + end + end +end # ============================================================================= # Union Types for Dispatch # ============================================================================= -const BatchedOrShared = Union{ - BatchedCuMatrix,BatchedCuVector,SharedCuMatrix,SharedCuVector,StructArray -} +const BatchedOrShared = Union{BatchedCuArray,SharedCuArray,BatchedStruct} # ============================================================================= # Helper Functions # ============================================================================= -is_shared(::BatchedCuMatrix) = false -is_shared(::BatchedCuVector) = false -is_shared(::SharedCuMatrix) = true -is_shared(::SharedCuVector) = true +is_shared(::BatchedCuArray) = false +is_shared(::SharedCuArray) = true -unwrap_data(A::BatchedCuMatrix) = A.data -unwrap_data(A::SharedCuMatrix) = A.data -unwrap_data(x::BatchedCuVector) = x.data -unwrap_data(x::SharedCuVector) = x.data +unwrap_data(A::BatchedCuArray) = A.data +unwrap_data(A::SharedCuArray) = A.data function inner_size_for_blas(A::BatchedCuMatrix) m, n = size(A.data, 1), size(A.data, 2) @@ -110,27 +311,12 @@ end # Pointer Array Creation # ============================================================================= -function create_pointer_array(A::BatchedCuMatrix{T}) where {T} +function create_pointer_array(A::BatchedCuArray{T,InnerN,1}) where {T,InnerN} return CUDA.CUBLAS.unsafe_strided_batch(A.data) end -function create_pointer_array(A::SharedCuMatrix{T}, N::Int) where {T} - base_ptr = pointer(A.data) - ptrs_cpu = fill(base_ptr, N) - return CuArray(ptrs_cpu) -end - -function create_pointer_array_vector(x::BatchedCuVector{T}) where {T} - n = size(x.data, 1) - N = size(x.data, 2) - base_ptr = pointer(x.data) - stride = n * sizeof(T) - ptrs = CuArray([base_ptr + (i - 1) * stride for i in 1:N]) - return ptrs -end - -function create_pointer_array_vector(x::SharedCuVector{T}, N::Int) where {T} - base_ptr = pointer(x.data) - ptrs_cpu = fill(base_ptr, N) - return CuArray(ptrs_cpu) +function create_pointer_array(A::SharedCuArray{T,InnerN,1}) where {T,InnerN} + N = batch_size(A) + ptr = pointer(A.data) + return reinterpret(CuPtr{T}, CUDA.fill(UInt(ptr), N)) end diff --git a/GeneralisedFilters/src/batching/wrappers.jl b/GeneralisedFilters/src/batching/wrappers.jl index 56de9a57..7f26b5e7 100644 --- a/GeneralisedFilters/src/batching/wrappers.jl +++ b/GeneralisedFilters/src/batching/wrappers.jl @@ -1,7 +1,6 @@ using Magma using Magma.LibMagma using LinearAlgebra: cholesky, Cholesky, LowerTriangular -using StructArrays: StructArray export 
get_magma_queue, reset_magma_queue! @@ -56,7 +55,7 @@ function broadcasted( B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, ) where {T} if is_shared(A) && is_shared(B) - return SharedCuMatrix(A.data .+ B.data) + return SharedCuMatrix(A.data .+ B.data, batch_size(A)) else return BatchedCuMatrix(A.data .+ B.data) end @@ -68,7 +67,7 @@ function broadcasted( b::Union{BatchedCuVector{T},SharedCuVector{T}}, ) where {T} if is_shared(a) && is_shared(b) - return SharedCuVector(a.data .+ b.data) + return SharedCuVector(a.data .+ b.data, batch_size(a)) else return BatchedCuVector(a.data .+ b.data) end @@ -80,7 +79,7 @@ function broadcasted( B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, ) where {T} if is_shared(A) && is_shared(B) - return SharedCuMatrix(A.data .- B.data) + return SharedCuMatrix(A.data .- B.data, batch_size(A)) else return BatchedCuMatrix(A.data .- B.data) end @@ -92,7 +91,7 @@ function broadcasted( b::Union{BatchedCuVector{T},SharedCuVector{T}}, ) where {T} if is_shared(a) && is_shared(b) - return SharedCuVector(a.data .- b.data) + return SharedCuVector(a.data .- b.data, batch_size(a)) else return BatchedCuVector(a.data .- b.data) end @@ -161,8 +160,8 @@ function gemm_batched!( m, n = size(C.data, 1), size(C.data, 2) k = transA == 'N' ? size(unwrap_data(A), 2) : size(unwrap_data(A), 1) - dA = A isa BatchedCuMatrix ? create_pointer_array(A) : create_pointer_array(A, N) - dB = B isa BatchedCuMatrix ? create_pointer_array(B) : create_pointer_array(B, N) + dA = create_pointer_array(A) + dB = create_pointer_array(B) dC = create_pointer_array(C) ldda = size(unwrap_data(A), 1) @@ -214,8 +213,8 @@ function gemm_batched_smallsq!( m, n = size(C.data, 1), size(C.data, 2) k = transA == 'N' ? size(unwrap_data(A), 2) : size(unwrap_data(A), 1) - dA = A isa BatchedCuMatrix ? create_pointer_array(A) : create_pointer_array(A, N) - dB = B isa BatchedCuMatrix ? create_pointer_array(B) : create_pointer_array(B, N) + dA = create_pointer_array(A) + dB = create_pointer_array(B) dC = create_pointer_array(C) ldda = size(unwrap_data(A), 1) @@ -269,8 +268,8 @@ function gemm_batched_smallsq!( m, n = size(C.data, 1), size(C.data, 2) k = transA == 'N' ? size(unwrap_data(A), 2) : size(unwrap_data(A), 1) - dA = A isa BatchedCuMatrix ? create_pointer_array(A) : create_pointer_array(A, N) - dB = B isa BatchedCuMatrix ? create_pointer_array(B) : create_pointer_array(B, N) + dA = create_pointer_array(A) + dB = create_pointer_array(B) dC = create_pointer_array(C) ldda = size(unwrap_data(A), 1) @@ -322,13 +321,9 @@ function gemv_batched!( N = batch_size(y) m, n = size(unwrap_data(A), 1), size(unwrap_data(A), 2) - dA = A isa BatchedCuMatrix ? create_pointer_array(A) : create_pointer_array(A, N) - dx = if x isa BatchedCuVector - create_pointer_array_vector(x) - else - create_pointer_array_vector(x, N) - end - dy = create_pointer_array_vector(y) + dA = create_pointer_array(A) + dx = create_pointer_array(x) + dy = create_pointer_array(y) ldda = m incx = 1 @@ -359,13 +354,9 @@ function gemv_batched!( N = batch_size(y) m, n = size(unwrap_data(A), 1), size(unwrap_data(A), 2) - dA = A isa BatchedCuMatrix ? 
create_pointer_array(A) : create_pointer_array(A, N) - dx = if x isa BatchedCuVector - create_pointer_array_vector(x) - else - create_pointer_array_vector(x, N) - end - dy = create_pointer_array_vector(y) + dA = create_pointer_array(A) + dx = create_pointer_array(x) + dy = create_pointer_array(y) ldda = m incx = 1 @@ -396,13 +387,9 @@ function gemv_batched_smallsq!( N = batch_size(y) n = size(unwrap_data(A), 1) - dA = A isa BatchedCuMatrix ? create_pointer_array(A) : create_pointer_array(A, N) - dx = if x isa BatchedCuVector - create_pointer_array_vector(x) - else - create_pointer_array_vector(x, N) - end - dy = create_pointer_array_vector(y) + dA = create_pointer_array(A) + dx = create_pointer_array(x) + dy = create_pointer_array(y) ldda = n incx = 1 @@ -519,11 +506,11 @@ function broadcasted(::typeof(cholesky), A::BatchedCuMatrix{T,M}) where {T,M} factors_wrapped = broadcasted(LowerTriangular, A_copy) - # TODO: Use a lazy constant vector for uplo instead of dense fill - uplo = fill('L', N) + # Store as SharedScalar since it's the same for all batch elements + uplo = SharedScalar('L') ElType = Cholesky{T,eltype(A)} - return StructArray{ElType}((; factors=factors_wrapped, uplo=uplo, info=info)) + return BatchedStruct{ElType}((; factors=factors_wrapped, uplo=uplo, info=info)) end # function pdmat_solve(S::BatchedPDMat{T}, B::BatchedCuMatrix{T}) where {T} diff --git a/research/batching/batching_demo.jl b/research/batching/batching_demo.jl index 8830fc1d..e6f98994 100644 --- a/research/batching/batching_demo.jl +++ b/research/batching/batching_demo.jl @@ -4,7 +4,6 @@ using Distributions using LinearAlgebra using Base.Broadcast: broadcasted using PDMats -using StructArrays using BenchmarkTools using CUDA @@ -34,19 +33,20 @@ function kalman_predict(state, dyn_params) end I_mat = CuArray{Float32}(I, D_state, D_state) -Is = SharedCuMatrix(I_mat) +Is = Shared(I_mat, N) μs = BatchedCuVector(CUDA.randn(Float32, D_state, N)) Σs_root = BatchedCuMatrix(CUDA.randn(Float32, D_state, D_state, N)) Σs = Σs_root .* adjoint.(Σs_root) .+ Is -As = SharedCuMatrix(CUDA.randn(Float32, D_state, D_state)) +As = Shared(CUDA.randn(Float32, D_state, D_state), N) bs = BatchedCuVector(CUDA.randn(Float32, D_state, N)) Q_root = CUDA.randn(Float32, D_state, D_state) Q = Q_root * Q_root' + I -Qs = SharedCuMatrix(Q) +Qs = Shared(Q, N) Σ_PDs = broadcasted(PDMat, Σs); +PDMat.(Σs) Gs = MvNormal.(μs, Σ_PDs); function kalman_predict(state, dyn_params) @@ -111,10 +111,10 @@ function kalman_step(state, dyn_params, obs_params, observation) end # Observation parameters (H and c shared, R batched) -Hs = SharedCuMatrix(CUDA.randn(Float32, D_obs, D_state)) -cs = SharedCuVector(CUDA.randn(Float32, D_obs)) +Hs = Shared(CUDA.randn(Float32, D_obs, D_state), N) +cs = Shared(CUDA.randn(Float32, D_obs), N) I_obs = CuArray{Float32}(I, D_obs, D_obs) -I_obs_shared = SharedCuMatrix(I_obs) +I_obs_shared = Shared(I_obs, N) Rs_root = BatchedCuMatrix(CUDA.randn(Float32, D_obs, D_obs, N)) Rs = Rs_root .* adjoint.(Rs_root) .+ I_obs_shared Rs = PDMat.(Rs); @@ -183,23 +183,23 @@ println( " MB", ) -Is_bench = SharedCuMatrix(CuArray{Float32}(I, D_bench, D_bench)) +Is_bench = Shared(CuArray{Float32}(I, D_bench, D_bench), N_bench) μs_bench = BatchedCuVector(CUDA.randn(Float32, D_bench, N_bench)) Σs_root_bench = BatchedCuMatrix(CUDA.randn(Float32, D_bench, D_bench, N_bench)) Σs_bench = Σs_root_bench .* adjoint.(Σs_root_bench) .+ Is_bench Σ_PDs_bench = broadcasted(PDMat, Σs_bench); Gs_bench = MvNormal.(μs_bench, Σ_PDs_bench); -As_bench = 
-As_bench = SharedCuMatrix(CUDA.randn(Float32, D_bench, D_bench))
+As_bench = Shared(CUDA.randn(Float32, D_bench, D_bench), N_bench)
 bs_bench = BatchedCuVector(CUDA.randn(Float32, D_bench, N_bench))
 Qs_root_bench = CUDA.randn(Float32, D_bench, D_bench)
 Qs_bench_mat = Qs_root_bench * adjoint(Qs_root_bench)
 Qs_bench_mat += I
-Qs_bench = SharedCuMatrix(Qs_bench_mat)
+Qs_bench = Shared(Qs_bench_mat, N_bench)
 dyn_params_bench = (As_bench, bs_bench, Qs_bench)
 
-Hs_bench = SharedCuMatrix(CUDA.randn(Float32, D_bench, D_bench))
-cs_bench = SharedCuVector(CUDA.randn(Float32, D_bench))
+Hs_bench = Shared(CUDA.randn(Float32, D_bench, D_bench), N_bench)
+cs_bench = Shared(CUDA.randn(Float32, D_bench), N_bench)
 Rs_root_bench = BatchedCuMatrix(CUDA.randn(Float32, D_bench, D_bench, N_bench))
 Rs_bench = Rs_root_bench .* adjoint.(Rs_root_bench) .+ Is_bench
 Rs_bench = PDMat.(Rs_bench)
diff --git a/research/batching/full_kalman_demo.jl b/research/batching/full_kalman_demo.jl
index ee48dea3..f9590b44 100644
--- a/research/batching/full_kalman_demo.jl
+++ b/research/batching/full_kalman_demo.jl
@@ -5,7 +5,6 @@ using Distributions
 using LinearAlgebra
 using Base.Broadcast: broadcasted
 using PDMats
-using StructArrays
 using BenchmarkTools
 
 import Distributions: params
@@ -34,8 +33,8 @@ Qs = PDMat.(Qs);
 Gs = MvNormal.(μs, Σ_PDs);
 
 # Observation parameters (H and c shared, R batched)
-Hs = BatchedCuMatrix(CUDA.randn(Float32, D_obs, D_state, N))
-cs = BatchedCuVector(CUDA.randn(Float32, D_obs, N))
+Hs = Shared(CUDA.randn(Float32, D_obs, D_state), N)
+cs = Shared(CUDA.randn(Float32, D_obs), N)
 Rs_root = BatchedCuMatrix(CUDA.randn(Float32, D_obs, D_obs, N))
 Rs = Rs_root .* adjoint.(Rs_root) .+ Ref(I)
 Rs = PDMat.(Rs);
@@ -77,4 +76,5 @@ res = kalman_step.(Gs, dyn_params, obs_params, observations, Ref(jitter))
 println("\nFull type: ", typeof(res))
 println("\nElement type: ", eltype(res))
 
-lls = StructArrays.component(res, 2)
+# Access second component of batched tuple (the log-likelihoods)
+lls = res.components[2]
diff --git a/research/batching/wrappers_demo.jl b/research/batching/wrappers_demo.jl
index a135cef1..6bbd6090 100644
--- a/research/batching/wrappers_demo.jl
+++ b/research/batching/wrappers_demo.jl
@@ -1,7 +1,6 @@
 using GeneralisedFilters
 using CUDA
 using LinearAlgebra
-using StructArrays
 using Base.Broadcast: broadcasted
 using Distributions
 using PDMats
@@ -90,7 +89,7 @@ println("\nValues match (A vs parent of A_adj): ", A_cpu ≈ A_adj_cpu)
 
 println("\n=== Test 5: SharedCuMatrix wrappers ===\n")
 
-S = SharedCuMatrix(CUDA.randn(Float32, D, D));
+S = Shared(CUDA.randn(Float32, D, D), N);
 S_adj = broadcasted(Adjoint, S);
 println("SharedCuMatrix adjoint type: ", typeof(S_adj))
 println("SharedCuMatrix adjoint eltype: ", eltype(S_adj))

From 34dd99bc75188848e39151a7afda09700c12fcb7 Mon Sep 17 00:00:00 2001
From: Tim Hargreaves
Date: Thu, 26 Feb 2026 16:17:34 +0000
Subject: [PATCH 26/29] Fix indexing bug in symmetrize

---
 GeneralisedFilters/src/batching/wrappers.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/GeneralisedFilters/src/batching/wrappers.jl b/GeneralisedFilters/src/batching/wrappers.jl
index 7f26b5e7..2170490d 100644
--- a/GeneralisedFilters/src/batching/wrappers.jl
+++ b/GeneralisedFilters/src/batching/wrappers.jl
@@ -716,7 +716,7 @@ function symmetrize_lower_kernel!(A, n)
     j = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y
 
     if i <= n && j <= n && j > i
-        @inbounds A[j, i, batch_idx] = A[i, j, batch_idx]
+        @inbounds A[i, j, batch_idx] = A[j, i, batch_idx]
     end
     return nothing
 end
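A CPU analogue makes it easy to check that the corrected index order really copies the lower triangle into the upper one (illustrative code only; the actual kernel above runs one thread per (i, j) pair per batch slice):

# For j > i, (i, j) lies in the strict upper triangle, so after the fix it is
# overwritten with the lower-triangle value A[j, i], slice by slice.
function symmetrize_lower!(A::AbstractArray{<:Any,3})
    n = size(A, 1)
    for b in axes(A, 3), j in 1:n, i in 1:(j - 1)
        A[i, j, b] = A[j, i, b]
    end
    return A
end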
From ed23a120fd604d0498ff62d340e6ed7047fa92e3 Mon Sep 17 00:00:00 2001
From: Tim Hargreaves
Date: Thu, 26 Feb 2026 16:26:39 +0000
Subject: [PATCH 27/29] Remove redundant constructors

---
 GeneralisedFilters/src/batching/broadcasting.jl | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/GeneralisedFilters/src/batching/broadcasting.jl b/GeneralisedFilters/src/batching/broadcasting.jl
index 560b2dc0..499b1cc7 100644
--- a/GeneralisedFilters/src/batching/broadcasting.jl
+++ b/GeneralisedFilters/src/batching/broadcasting.jl
@@ -94,10 +94,6 @@ function broadcasted(::Type{W}, args::Vararg{BatchedOrShared}) where {W}
     return BatchedStruct{ElType}(nt)
 end
 
-# Redirect function forms to type constructors
-broadcasted(::typeof(adjoint), A::BatchedOrShared) = broadcasted(Adjoint, A)
-broadcasted(::typeof(transpose), A::BatchedOrShared) = broadcasted(Transpose, A)
-
 # copy for Adjoint/Transpose wrappers - materialize the transposition
 function broadcasted(::typeof(copy), x::BatchedStruct{<:Adjoint})
     parent_data = x.parent  # BatchedCuMatrix or SharedCuMatrix

From a63c5e13e4feab1f004c2c32549d87ee52f5e534 Mon Sep 17 00:00:00 2001
From: Tim Hargreaves
Date: Thu, 26 Feb 2026 16:42:58 +0000
Subject: [PATCH 28/29] Materialise returned values

---
 GeneralisedFilters/src/batching/broadcasting.jl | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/GeneralisedFilters/src/batching/broadcasting.jl b/GeneralisedFilters/src/batching/broadcasting.jl
index 499b1cc7..51c8325a 100644
--- a/GeneralisedFilters/src/batching/broadcasting.jl
+++ b/GeneralisedFilters/src/batching/broadcasting.jl
@@ -324,6 +324,10 @@ function Broadcast.materialize(bc::Broadcasted{BatchedStyle})
     f = bc.f
     N = _find_batch_size(bc.args)
     args = map(a -> maybe_convert_ref(a, N), bc.args)
+    # Julia's broadcast fusion leaves inner dot-calls as lazy Broadcasted objects.
+    # Materialize them before dispatch so specialized broadcasted methods see concrete types.
+    # TODO: is this desirable?
+    args = map(a -> a isa Broadcasted ? Broadcast.materialize(a) : a, args)
 
     result = broadcasted(f, args...)
     if !(result isa Broadcasted)
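The behaviour this patch works around is reproducible with plain Base broadcasting, independent of the batching types: Julia fuses nested dot-calls into a single `Broadcasted` tree, so without the eager pass above the specialised `broadcasted` methods would receive lazy inner nodes instead of concrete batched arrays.

using Base.Broadcast: Broadcasted, broadcasted
inner = broadcasted(+, [1.0, 2.0], [3.0, 4.0])  # a lazy Broadcasted node, not a Vector
outer = broadcasted(abs, inner)                 # fused tree: the inner node stays lazy
outer.args[1] isa Broadcasted                   # true, hence the eager materialize pass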
From 2dacc2e89d6f9bdb40478e9db4220da1cb66c575 Mon Sep 17 00:00:00 2001
From: Sihan Yu
Date: Mon, 9 Mar 2026 12:29:44 +0000
Subject: [PATCH 29/29] Made BatchedCuArray, SharedCuArray, SharedScalar and
 BatchedStruct Adapt.jl compatible

---
 GeneralisedFilters/Project.toml               |   2 +
 GeneralisedFilters/src/batching/operations.jl |  16 +--
 GeneralisedFilters/src/batching/types.jl      | 116 +++++++++++++-----
 3 files changed, 97 insertions(+), 37 deletions(-)

diff --git a/GeneralisedFilters/Project.toml b/GeneralisedFilters/Project.toml
index 682265e9..dcc206ba 100644
--- a/GeneralisedFilters/Project.toml
+++ b/GeneralisedFilters/Project.toml
@@ -6,6 +6,7 @@ authors = ["THargreaves ", "Charles Knipp ", ...]
[The body of this hunk (which adds the Adapt dependency), the whole
operations.jl diff, and the head of the types.jl diff are missing from the
source; the surviving types.jl fragment resumes mid-docstring below.]

diff --git a/GeneralisedFilters/src/batching/types.jl b/GeneralisedFilters/src/batching/types.jl
 """
-    Shared(data::CuArray, N::Int) -> SharedCuArray
+    Shared(data::AbstractArray, N::Int) -> SharedCuArray
 
-Convenience constructor: create a `SharedCuArray` from a CuArray with an explicit
+Convenience constructor: create a `SharedCuArray` from an array with an explicit
 1D batch size `N`.
+
+The underlying storage is generic to support `Adapt.jl` transformations, but in
+the intended user-facing interface `data` is a `CuArray`.
 """
-Shared(x::CuArray{T,2,M}, N::Int) where {T,M} = SharedCuArray{T,2,1,M}(x, (N,))
-Shared(x::CuArray{T,1,M}, N::Int) where {T,M} = SharedCuArray{T,1,1,M}(x, (N,))
+Shared(x::A, N::Int) where {T,A<:AbstractArray{T,2}} = SharedCuArray{T,2,1,A}(x, (N,))
+Shared(x::A, N::Int) where {T,A<:AbstractArray{T,1}} = SharedCuArray{T,1,1,A}(x, (N,))
 
 Base.IndexStyle(::Type{<:SharedCuArray}) = Base.IndexCartesian()
@@ -129,6 +157,15 @@ end
 
 batch_size(x::SharedCuArray) = length(x)
 
+# Adapt.jl support: adapt the underlying storage (e.g. CuArray -> CuDeviceArray)
+function Adapt.adapt_structure(
+    to,
+    x::SharedCuArray{T,InnerN,BatchN,A},
+) where {T,InnerN,BatchN,A<:AbstractArray{T,InnerN}}
+    data_adapted = Adapt.adapt(to, x.data)
+    return SharedCuArray{T,InnerN,BatchN,typeof(data_adapted)}(data_adapted, x.batchsize)
+end
+
 # =============================================================================
 # SharedScalar: a scalar value shared across all batch elements
 # =============================================================================
@@ -149,6 +186,9 @@ Base.:(==)(x::SharedScalar, y) = x.value == y
 Base.:(==)(x, y::SharedScalar) = x == y.value
 Base.:(==)(x::SharedScalar, y::SharedScalar) = x.value == y.value
 
+# Adapt.jl support: adapt SharedScalar field-by-field
+Adapt.@adapt_structure SharedScalar
+
 # =============================================================================
 # BatchedStruct - Custom wrapper for batched composite types
 # =============================================================================
@@ -225,6 +265,15 @@ end
 Base.size(x::BatchedStruct) = (batch_size(x),)
 Base.IndexStyle(::Type{<:BatchedStruct}) = IndexLinear()
 
+# Specialised indexing for lazy Adjoint/Transpose wrappers
+function Base.getindex(x::BatchedStruct{<:Adjoint}, i::Integer)
+    return adjoint(getfield(x, :components).parent[i])
+end
+
+function Base.getindex(x::BatchedStruct{<:Transpose}, i::Integer)
+    return transpose(getfield(x, :components).parent[i])
+end
+
 # Indexing: construct an element of type T by calling its constructor
 @generated function Base.getindex(x::BatchedStruct{T}, i::Integer) where {T}
     fields = fieldnames(T)
@@ -271,6 +320,15 @@ function Base.show(io::IO, ::MIME"text/plain", x::BatchedStruct{T}) where {T}
     end
 end
 
+# Adapt.jl support: adapt the components NamedTuple
+function Adapt.adapt_structure(
+    to,
+    x::BatchedStruct{T,C},
+) where {T,C<:NamedTuple}
+    comps_adapted = Adapt.adapt(to, x.components)
+    return BatchedStruct{T}(comps_adapted)
+end
+
 # =============================================================================
 # Union Types for Dispatch
 # =============================================================================
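Taken together, the `Adapt` rules in this patch are what let the batching wrappers cross the host/device boundary: CUDA.jl converts kernel arguments with `cudaconvert`, which is built on `Adapt.adapt`, recursively replacing each `CuArray` with a `CuDeviceArray`. A hedged usage sketch, assuming the `Shared` constructor and adapt rules defined above:

using Adapt, CUDA
S = Shared(CUDA.randn(Float32, 4, 4), 16)  # SharedCuArray backed by a CuArray
S_dev = cudaconvert(S)                     # same wrapper, data now a CuDeviceArray
# With these rules in place, S can be passed straight into an @cuda kernel launch.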