diff --git a/GeneralisedFilters/Project.toml b/GeneralisedFilters/Project.toml index 198b7ef3..dcc206ba 100644 --- a/GeneralisedFilters/Project.toml +++ b/GeneralisedFilters/Project.toml @@ -6,11 +6,14 @@ authors = ["THargreaves ", "Charles Knipp arg isa BatchedOrShared, args) + field_names = fieldnames(T) + element_types = Tuple{map(inner_eltype, args)...} + ElType = Core.Compiler.return_type(T, element_types) + nt = NamedTuple{field_names}(args) + return BatchedStruct{ElType}(nt) + else + return T(args...) + end +end + +# ============================================================================= +# Generic Wrapper Broadcasting +# ============================================================================= + +""" + broadcasted(::Type{W}, args::BatchedOrShared...) + +Generic wrapper for any type constructor applied to batched arguments. +Works for single-field wrappers (Adjoint, Transpose, LowerTriangular, etc.) +as well as multi-field types (PDMat, Cholesky, etc.). + +Returns a BatchedStruct where each element is the type applied to the +corresponding elements of the input arrays. +""" +function broadcasted(::Type{W}, args::Vararg{BatchedOrShared}) where {W} + element_types = Tuple{map(eltype, args)...} + ElType = Core.Compiler.return_type(W, element_types) + field_names = fieldnames(ElType) + nt = NamedTuple{field_names}(args) + return BatchedStruct{ElType}(nt) +end + +# copy for Adjoint/Transpose wrappers - materialize the transposition +function broadcasted(::typeof(copy), x::BatchedStruct{<:Adjoint}) + parent_data = x.parent # BatchedCuMatrix or SharedCuMatrix + if parent_data isa BatchedCuMatrix + return BatchedCuMatrix(permutedims(parent_data.data, (2, 1, 3))) + else # SharedCuMatrix + return SharedCuMatrix(permutedims(parent_data.data, (2, 1))) + end +end + +function broadcasted(::typeof(copy), x::BatchedStruct{<:Transpose}) + parent_data = x.parent + if parent_data isa BatchedCuMatrix + return BatchedCuMatrix(permutedims(parent_data.data, (2, 1, 3))) + else # SharedCuMatrix + return SharedCuMatrix(permutedims(parent_data.data, (2, 1))) + end +end + +# Union of all types that represent batched data +const BatchedData = Union{BatchedOrShared,CuVector} + +# Batched tuple creation: returns BatchedStruct{Tuple{...}} +function broadcasted(::typeof(tuple), args::Vararg{BatchedData}) + ElType = Tuple{map(eltype, args)...} + # For tuples, components is a regular Tuple, not NamedTuple + components = NamedTuple{ntuple(i -> Symbol("x$i"), length(args))}(args) + return BatchedStruct{ElType}(components) +end + +# ============================================================================= +# getfield/getproperty broadcasting for BatchedStruct +# ============================================================================= + +# getfield on BatchedStruct: return the batched component +# Handle both unwrapped and Ref-wrapped field names (Ref from maybe_wrap_scalar) +function broadcasted(::typeof(getfield), x::BatchedStruct{T}, s::Symbol) where {T} + s in fieldnames(T) && return getfield(x, :components)[s] + return error("BatchedStruct{$T} has no field `$s`") +end +function broadcasted( + ::typeof(getfield), x::BatchedStruct{T}, s::Base.RefValue{Symbol} +) where {T} + return broadcasted(getfield, x, s[]) +end +broadcasted(::typeof(getfield), x::BatchedStruct, i::Int) = getfield(x, :components)[i] +function broadcasted(::typeof(getfield), x::BatchedStruct, i::Base.RefValue{<:Integer}) + return getfield(x, :components)[i[]] +end + +# getproperty on BatchedStruct: return batched component for 
real fields, +# fall through to tracing for computed properties +function broadcasted(::typeof(getproperty), x::BatchedStruct{T}, s::Symbol) where {T} + s in fieldnames(T) && return getfield(x, :components)[s] + # Computed property - return Broadcasted to trigger IR transformation + return Broadcasted{BatchedStyle}(getproperty, (x, s)) +end + +function broadcasted( + ::typeof(getproperty), x::BatchedStruct{T}, s::Base.RefValue{Symbol} +) where {T} + return broadcasted(getproperty, x, s[]) +end + +# ============================================================================= +# size broadcasting for batched arrays (returns inner dimensions) +# ============================================================================= + +broadcasted(::typeof(size), A::BatchedCuMatrix) = inner_size(A) +broadcasted(::typeof(size), A::BatchedCuMatrix, i::Integer) = inner_size(A)[i] +broadcasted(::typeof(size), A::SharedCuMatrix) = inner_size(A) +broadcasted(::typeof(size), A::SharedCuMatrix, i::Integer) = inner_size(A)[i] + +broadcasted(::typeof(size), x::BatchedCuVector) = inner_size(x) +broadcasted(::typeof(size), x::BatchedCuVector, i::Integer) = inner_size(x)[i] +broadcasted(::typeof(size), x::SharedCuVector) = inner_size(x) +broadcasted(::typeof(size), x::SharedCuVector, i::Integer) = inner_size(x)[i] + +# ============================================================================= +# IR Transformation +# ============================================================================= + +const SKIP_BROADCAST = Set{Any}() +const BROADCAST_TYPES = Set{Any}([PDMat]) + +# Don't wrap: batched data, shared scalars, callables, modules, symbols, integers, already-wrapped refs +maybe_wrap_scalar(x::BatchedData) = x +maybe_wrap_scalar(x::SharedScalar) = x +maybe_wrap_scalar(x::Union{Type,Module,Symbol,Integer,Base.RefValue}) = x +maybe_wrap_scalar(x) = typeof(x) <: Function ? x : Ref(x) + +@inline function broadcast_and_materialize(f, args...) + wrapped_args = map(maybe_wrap_scalar, args) + + # Check if any argument is actually batched + has_batched = any(arg -> arg isa BatchedData, wrapped_args) + + if !has_batched + # All scalars - unwrap and execute directly + unwrapped_args = map(a -> a isa Base.RefValue ? a[] : a, wrapped_args) + return f(unwrapped_args...) + end + + # Special case: getfield on Ref-wrapped scalar object (getfield is a builtin, can't trace) + if f === getfield && length(wrapped_args) >= 1 + obj = wrapped_args[1] + if obj isa Base.RefValue + return getfield(obj[], wrapped_args[2:end]...) + end + end + + # Has batched inputs - normal broadcast path + result = broadcasted(f, wrapped_args...) + if result isa Broadcasted + return Broadcast.materialize(result) + end + return result +end + +function resolve_to_type(ir::IR, val) + val isa Type && return val + if val isa Variable + stmt = ir[val] + if stmt !== nothing + return resolve_to_type(ir, stmt.expr) + end + end + if val isa GlobalRef + try + resolved = getfield(val.mod, val.name) + return resolved isa Type ? 
resolved : nothing + catch + return nothing + end + end + return nothing +end + +function is_type_ref(ir::IR, val) + return resolve_to_type(ir, val) !== nothing +end + +function is_broadcast_type(ir::IR, val) + resolved = resolve_to_type(ir, val) + return resolved !== nothing && resolved in BROADCAST_TYPES +end + +function transform_to_batched(ir::IR) + ir = copy(ir) + + for (v, stmt) in ir + if stmt.expr isa Expr && stmt.expr.head == :call + fn = stmt.expr.args[1] + if fn in SKIP_BROADCAST + continue + end + if is_broadcast_type(ir, fn) + new_args = [broadcast_and_materialize, stmt.expr.args...] + ir[v] = Statement(stmt; expr=Expr(:call, new_args...)) + continue + end + if is_type_ref(ir, fn) + new_args = [wrap_if_batched, stmt.expr.args...] + ir[v] = Statement(stmt; expr=Expr(:call, new_args...)) + continue + end + new_args = [broadcast_and_materialize, stmt.expr.args...] + ir[v] = Statement(stmt; expr=Expr(:call, new_args...)) + end + end + + return ir +end + +ir_element_type(::Type{T}) where {T} = T +ir_element_type(::Type{<:BatchedStruct{T}}) where {T} = T +ir_element_type(::Type{<:Base.RefValue{T}}) where {T} = T + +function generate_batched_function(f, argtypes::Type{<:Tuple}) + element_types = Tuple{map(ir_element_type, argtypes.parameters)...} + + ir = IRTools.Inner.code_ir(f, element_types) + if ir === nothing + error( + "Could not get IR for function $f with types $element_types (original: $argtypes)", + ) + end + batched_ir = transform_to_batched(ir) + return IRTools.func(batched_ir) +end + +# ============================================================================= +# Broadcast Materialization +# ============================================================================= + +# Verbosity levels: :silent, :verbose, :debug +# :silent - no output +# :verbose - print when generating or regenerating (i.e. cache misses) +# :debug - print all cache activity including hits +const BATCHED_CACHE_VERBOSITY = Ref{Symbol}(:silent) + +# Cache stores (batched_function, world_age_when_cached) +const BATCHED_FUNC_CACHE = Dict{Tuple{Any,Type},Tuple{Any,UInt}}() + +""" + clear_batched_cache!() + +Clear the batched function cache. Useful for debugging or forcing regeneration. +""" +function clear_batched_cache!() + empty!(BATCHED_FUNC_CACHE) + return nothing +end + +function _find_batch_size(args) + for arg in args + if arg isa BatchedCuArray || arg isa SharedCuArray + return batch_size(arg) + elseif arg isa Broadcasted + n = _find_batch_size(arg.args) + n !== nothing && return n + end + end + return nothing +end + +function Broadcast.materialize(bc::Broadcasted{BatchedStyle}) + f = bc.f + N = _find_batch_size(bc.args) + args = map(a -> maybe_convert_ref(a, N), bc.args) + # Julia's broadcast fusion leaves inner dot-calls as lazy Broadcasted objects. + # Materialize them before dispatch so specialized broadcasted methods see concrete types. + # TODO: is this desirable? + args = map(a -> a isa Broadcasted ? Broadcast.materialize(a) : a, args) + + result = broadcasted(f, args...) 
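+    # If a specialised `broadcasted` method matched these argument types, `result` is
+    # already a concrete batched value and is returned directly; only when the call is
+    # still a lazy `Broadcasted` do we fall through to the cached IR-transformation path.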
+ if !(result isa Broadcasted) + return result + end + + argtypes = Tuple{map(typeof, args)...} + key = (f, argtypes) + + # Get element types for method lookup + element_types = Tuple{map(ir_element_type, argtypes.parameters)...} + + if haskey(BATCHED_FUNC_CACHE, key) + batched_f, cached_world = BATCHED_FUNC_CACHE[key] + # Check if the method has been redefined since caching + m = which(f, element_types) + if m.primary_world <= cached_world + if BATCHED_CACHE_VERBOSITY[] == :debug + println(" [Using cached batched version of $f]") + end + return Base.invokelatest(batched_f, nothing, args...) + end + if BATCHED_CACHE_VERBOSITY[] in (:verbose, :debug) + println(" [Regenerating batched version of $f (method redefined)]") + end + else + if BATCHED_CACHE_VERBOSITY[] in (:verbose, :debug) + println(" [Generating batched version of $f]") + end + end + + batched_f = generate_batched_function(f, argtypes) + current_world = Base.get_world_counter() + BATCHED_FUNC_CACHE[key] = (batched_f, current_world) + return Base.invokelatest(batched_f, nothing, args...) +end diff --git a/GeneralisedFilters/src/batching/operations.jl b/GeneralisedFilters/src/batching/operations.jl new file mode 100644 index 00000000..cb73aea5 --- /dev/null +++ b/GeneralisedFilters/src/batching/operations.jl @@ -0,0 +1,427 @@ +import PDMats: X_A_Xt +import LinearAlgebra: norm + +# ============================================================================= +# GEMM-Compatible Types +# ============================================================================= + +# Type aliases for BatchedStruct-wrapped matrices +const BatchedAdjoint{T,A<:AbstractArray{T,3}} = BatchedStruct{ + <:Adjoint{T,<:AbstractArray{T,2}},@NamedTuple{parent::BatchedCuMatrix{T,A}} +} +const BatchedTranspose{T,A<:AbstractArray{T,3}} = BatchedStruct{ + <:Transpose{T,<:AbstractArray{T,2}},@NamedTuple{parent::BatchedCuMatrix{T,A}} +} +const SharedAdjoint{T,A<:AbstractArray{T,2}} = BatchedStruct{ + <:Adjoint{T,<:AbstractArray{T,2}},@NamedTuple{parent::SharedCuMatrix{T,A}} +} +const SharedTranspose{T,A<:AbstractArray{T,2}} = BatchedStruct{ + <:Transpose{T,<:AbstractArray{T,2}},@NamedTuple{parent::SharedCuMatrix{T,A}} +} + +# Union of all GEMM-compatible matrix types +const GEMMCompatibleMatrix{T} = Union{ + BatchedCuMatrix{T}, + SharedCuMatrix{T}, + BatchedAdjoint{T}, + BatchedTranspose{T}, + SharedAdjoint{T}, + SharedTranspose{T}, +} + +# trans_flag: returns BLAS transpose flag for each type +trans_flag(::BatchedCuMatrix{T}) where {T} = 'N' +trans_flag(::SharedCuMatrix{T}) where {T} = 'N' +trans_flag(::BatchedAdjoint{T}) where {T} = T <: Real ? 'T' : 'C' +trans_flag(::BatchedTranspose{T}) where {T} = 'T' +trans_flag(::SharedAdjoint{T}) where {T} = T <: Real ? 
'T' : 'C' +trans_flag(::SharedTranspose{T}) where {T} = 'T' + +# gemm_data: extracts the underlying BatchedCuMatrix/SharedCuMatrix for GEMM +gemm_data(A::BatchedCuMatrix) = A +gemm_data(A::SharedCuMatrix) = A +gemm_data(A::BatchedAdjoint) = A.parent +gemm_data(A::BatchedTranspose) = A.parent +gemm_data(A::SharedAdjoint) = A.parent +gemm_data(A::SharedTranspose) = A.parent + +# inner_size_for_blas for wrapped types (delegates to underlying data) +inner_size_for_blas(A::BatchedAdjoint) = inner_size_for_blas(A.parent) +inner_size_for_blas(A::BatchedTranspose) = inner_size_for_blas(A.parent) +inner_size_for_blas(A::SharedAdjoint) = inner_size_for_blas(A.parent) +inner_size_for_blas(A::SharedTranspose) = inner_size_for_blas(A.parent) + +# batch_size for wrapped types +batch_size(A::BatchedAdjoint) = batch_size(A.parent) +batch_size(A::BatchedTranspose) = batch_size(A.parent) +batch_size(A::SharedAdjoint) = batch_size(A.parent) +batch_size(A::SharedTranspose) = batch_size(A.parent) + +# TODO: For nested wrappers (e.g., Adjoint{LowerTriangular{...}}), we should +# materialize the inner wrapper first before extracting. For now, we only +# support single-level Adjoint/Transpose wrappers for efficient GEMM dispatch. + +# ============================================================================= +# Matrix Multiply Broadcasting +# ============================================================================= + +function broadcasted( + ::typeof(*), A::GEMMCompatibleMatrix{T}, B::GEMMCompatibleMatrix{T} +) where {T} + transA = trans_flag(A) + transB = trans_flag(B) + + A_inner = inner_size_for_blas(A) + B_inner = inner_size_for_blas(B) + + m = transA == 'N' ? A_inner[1] : A_inner[2] + n = transB == 'N' ? B_inner[2] : B_inner[1] + N = get_batch_size(A, B) + + C_data = CuArray{T}(undef, m, n, N) + C = BatchedCuMatrix(C_data) + + gemm_batched!(transA, transB, one(T), gemm_data(A), gemm_data(B), zero(T), C) + return C +end + +# Multi-argument multiply +function broadcasted( + ::typeof(*), + A::GEMMCompatibleMatrix{T}, + B::GEMMCompatibleMatrix{T}, + C::GEMMCompatibleMatrix{T}, + rest::GEMMCompatibleMatrix{T}..., +) where {T} + result = broadcasted(*, A, B) + result = broadcasted(*, result, C) + for R in rest + result = broadcasted(*, result, R) + end + return result +end + +# ============================================================================= +# Matrix-Vector Multiply Broadcasting +# ============================================================================= + +function broadcasted( + ::typeof(*), + A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + x::Union{BatchedCuVector{T},SharedCuVector{T}}, +) where {T} + transA = trans_flag(A) + A_inner = inner_size_for_blas(A) + m = transA == 'N' ? 
A_inner[1] : A_inner[2] + N = get_batch_size(A, x) + + y_data = CuArray{T}(undef, m, N) + y = BatchedCuVector(y_data) + + gemv_batched!(transA, one(T), A, x, zero(T), y) + + return y +end + +# ============================================================================= +# Identity Minus Matrix (I - A) Custom Kernel +# ============================================================================= + +function identity_minus_kernel!(C, A, m, n) + batch_idx = blockIdx().z + i = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + j = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y + + if i <= m && j <= n + if i == j + @inbounds C[i, j, batch_idx] = one(eltype(C)) - A[i, j, batch_idx] + else + @inbounds C[i, j, batch_idx] = -A[i, j, batch_idx] + end + end + return nothing +end + +function identity_minus_batched!(C::CuArray{T,3}, A::CuArray{T,3}) where {T} + m, n, N = size(A) + threads = (16, 16) + blocks = (cld(m, 16), cld(n, 16), N) + @cuda threads = threads blocks = blocks identity_minus_kernel!(C, A, m, n) + return C +end + +function broadcasted( + ::typeof(-), ::Base.RefValue{UniformScaling{Bool}}, A::BatchedCuMatrix{T} +) where {T} + C = CuArray{T}(undef, size(A.data)) + identity_minus_batched!(C, A.data) + return BatchedCuMatrix(C) +end + +# ============================================================================= +# Matrix Plus Scaled Identity (A + λI) Custom Kernel +# ============================================================================= + +function plus_scaled_identity_kernel!(C, A, λ, m, n) + batch_idx = blockIdx().z + i = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + j = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y + + if i <= m && j <= n + if i == j + @inbounds C[i, j, batch_idx] = A[i, j, batch_idx] + λ + else + @inbounds C[i, j, batch_idx] = A[i, j, batch_idx] + end + end + return nothing +end + +function plus_scaled_identity_batched!(C::CuArray{T,3}, A::CuArray{T,3}, λ::T) where {T} + m, n, N = size(A) + threads = (16, 16) + blocks = (cld(m, 16), cld(n, 16), N) + @cuda threads = threads blocks = blocks plus_scaled_identity_kernel!(C, A, λ, m, n) + return C +end + +# A + Ref(I) where I is unscaled identity +function broadcasted( + ::typeof(+), A::BatchedCuMatrix{T}, ::Base.RefValue{UniformScaling{Bool}} +) where {T} + C = CuArray{T}(undef, size(A.data)) + plus_scaled_identity_batched!(C, A.data, one(T)) + return BatchedCuMatrix(C) +end + +function broadcasted( + ::typeof(+), ::Base.RefValue{UniformScaling{Bool}}, A::BatchedCuMatrix{T} +) where {T} + C = CuArray{T}(undef, size(A.data)) + plus_scaled_identity_batched!(C, A.data, one(T)) + return BatchedCuMatrix(C) +end + +# A + λI where λI is a scaled UniformScaling +function broadcasted( + ::typeof(+), A::BatchedCuMatrix{T}, J::Base.RefValue{<:UniformScaling{T}} +) where {T} + C = CuArray{T}(undef, size(A.data)) + plus_scaled_identity_batched!(C, A.data, J[].λ) + return BatchedCuMatrix(C) +end + +function broadcasted( + ::typeof(+), J::Base.RefValue{<:UniformScaling{T}}, A::BatchedCuMatrix{T} +) where {T} + C = CuArray{T}(undef, size(A.data)) + plus_scaled_identity_batched!(C, A.data, J[].λ) + return BatchedCuMatrix(C) +end + +# ============================================================================= +# PDMat Broadcasting +# ============================================================================= + +# HACK: PDMat is a constructor so will use +# `broadcasted(::Type{W}, args::BatchedOrShared...) 
where W` +# rather than the desired recursive broadcast to +# `PDMat(mat::AbstractMatrix) = PDMat(mat, cholesky(mat))` +# This method hardcodes a manual override for BatchedCuMatrix inputs. This should be replaced +# by a more general solution in the future. +function broadcasted(::Type{PDMat}, A::BatchedCuMatrix{T}) where {T} + chol = cholesky.(A) + return PDMat.(A, chol) +end + +# HACK: Addition with PDMat extracts .mat field. Should be replaced by automatic +# materialization of PDMat to BatchedCuMatrix in the future. +function broadcasted( + ::typeof(+), A::BatchedCuMatrix{T}, P::BatchedStruct{<:PDMat{T}} +) where {T} + return broadcasted(+, A, P.mat) +end + +function broadcasted( + ::typeof(+), P::BatchedStruct{<:PDMat{T}}, A::BatchedCuMatrix{T} +) where {T} + return broadcasted(+, P.mat, A) +end + +# A / S where S is PDMat: computes A * inv(S) +# potrs solves S * X = B, so we solve S * X = A' and transpose back +function broadcasted( + ::typeof(/), A::BatchedCuMatrix{T}, S::BatchedStruct{<:PDMat{T}} +) where {T} + L = S.chol.factors.data + + # Transpose A: potrs solves S*X = B, we want A*inv(S) = (inv(S)*A')' + At = BatchedCuMatrix(permutedims(A.data, (2, 1, 3))) + + # Solve S * X = A' in-place (result stored in At) + potrs_batched!('L', L, At) + + # Transpose back + return BatchedCuMatrix(permutedims(At.data, (2, 1, 3))) +end + +# ============================================================================= +# Quadratic Form Broadcasting +# ============================================================================= + +# HACK: treat this as two GEMMs for now +function broadcasted( + ::typeof(X_A_Xt), + A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + temp = broadcasted(*, X, A) + Xt = broadcasted(adjoint, X) + return broadcasted(*, temp, Xt) +end + +# X_A_Xt for BatchedStruct{PDMat}: X * P * X' where P = L * L' +# Computed as (X * L) * (X * L)' using TRMM and SYRK +# HACK: this function should dispatch to specialised `*` for triangular types but this is +# not yet implemented +function broadcasted( + ::typeof(X_A_Xt), + P::BatchedStruct{<:PDMat{T}}, + X::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + # P.chol.factors is BatchedStruct{LowerTriangular}, .parent is the BatchedCuMatrix + L = P.chol.factors.data + N = get_batch_size(L, X) + out_dim = inner_size_for_blas(X)[1] + + # Copy X for in-place TRMM + XL = if X isa SharedCuMatrix + BatchedCuMatrix(repeat(reshape(X.data, size(X.data)..., 1), 1, 1, N)) + else + BatchedCuMatrix(copy(X.data)) + end + + # XL = X * L using TRMM (side='R', uplo='L', no transpose, non-unit diagonal) + trmm_batched!('R', 'L', 'N', 'N', one(T), L, XL) + + # result = XL * XL' using SYRK (fills lower triangle only) + result = BatchedCuMatrix(CuArray{T}(undef, out_dim, out_dim, N)) + syrk_batched!('L', 'N', one(T), XL, zero(T), result) + + # Copy lower triangle to upper for full symmetric matrix + symmetrize_lower!(result) + + return result +end + +# ============================================================================= +# Batched norm +# ============================================================================= + +# Compute 2-norm for each vector in the batch, returns a CuVector of scalars +function broadcasted(::typeof(norm), v::BatchedCuVector{T}) where {T} + # v.data is D×N, compute norm of each column + return vec(sqrt.(sum(abs2, v.data; dims=1))) +end + +# ============================================================================= +# Batched ifelse (for conditional 
selection with batched conditions) +# ============================================================================= + +# Select entire vectors: x[:,j] if cond[j], else y[:,j] +function broadcasted( + ::typeof(ifelse), cond::CuVector{Bool}, x::BatchedCuVector{T}, y::BatchedCuVector{T} +) where {T} + # cond is length N (one bool per batch element) + # x.data and y.data are D×N + mask = reshape(T.(cond), 1, :) # 1×N mask for column selection + result = mask .* x.data .+ (one(T) .- mask) .* y.data + return BatchedCuVector(result) +end + +# Select entire matrices: x[:,:,j] if cond[j], else y[:,:,j] +function broadcasted( + ::typeof(ifelse), cond::CuVector{Bool}, x::BatchedCuMatrix{T}, y::BatchedCuMatrix{T} +) where {T} + # cond is length N (one bool per batch element) + # x.data and y.data are D×D×N + mask = reshape(T.(cond), 1, 1, :) # 1×1×N mask for batch selection + result = mask .* x.data .+ (one(T) .- mask) .* y.data + return BatchedCuMatrix(result) +end + +# ============================================================================= +# Batched zero +# ============================================================================= + +function broadcasted(::typeof(zero), v::BatchedCuVector{T}) where {T} + return BatchedCuVector(CUDA.zeros(T, size(v.data))) +end + +function broadcasted(::typeof(zero), A::BatchedCuMatrix{T}) where {T} + return BatchedCuMatrix(CUDA.zeros(T, size(A.data))) +end + +# ============================================================================= +# Batched logdetcov and invquad (for _logpdf) +# ============================================================================= + +import Distributions: logdetcov, sqmahal +import PDMats: invquad + +# length for BatchedCuVector: returns the inner dimension (same for all batch elements) +function broadcasted(::typeof(length), v::BatchedCuVector) + return size(v.data, 1) +end + +# logdetcov for BatchedStruct{MvNormal}: delegates to the covariance matrix +function broadcasted(::typeof(logdetcov), d::BatchedStruct{<:MvNormal{T}}) where {T} + return broadcasted(logdetcov, d.Σ) +end + +# logdetcov for BatchedStruct{PDMat}: 2 * sum(log.(diag(L))) for each batch element +function broadcasted(::typeof(logdetcov), P::BatchedStruct{<:PDMat{T}}) where {T} + # P.chol.factors is BatchedStruct{LowerTriangular}, .parent is BatchedCuMatrix + L_data = P.chol.factors.data.data # D×D×N CuArray + D, _, N = size(L_data) + + # Extract diagonal: L_data[i,i,k] for each i,k + diag_indices = [CartesianIndex(i, i, k) for k in 1:N for i in 1:D] + diag_flat = L_data[diag_indices] # (D*N,) vector + diag_matrix = reshape(diag_flat, D, N) # D×N + + # logdet = 2 * sum(log.(diag(L))) + return vec(T(2) .* sum(log.(diag_matrix); dims=1)) +end + +# invquad for BatchedStruct{PDMat} and BatchedCuVector: x' * inv(P) * x +# Computed as: solve P*y = x (via potrs), then dot(x, y) +function broadcasted( + ::typeof(invquad), P::BatchedStruct{<:PDMat{T}}, x::BatchedCuVector{T} +) where {T} + L = P.chol.factors.data # BatchedCuMatrix (D×D×N) + D, N = size(x.data) + + # Reshape x to matrix for potrs: D×N -> D×1×N + x_mat = BatchedCuMatrix(reshape(copy(x.data), D, 1, N)) + + # Solve L*L'*y = x in-place + potrs_batched!('L', L, x_mat) + + # y is now in x_mat, compute dot(x, y) = sum(x .* y) for each batch + y_vec = reshape(x_mat.data, D, N) + return vec(sum(x.data .* y_vec; dims=1)) +end + +# sqmahal for BatchedStruct{MvNormal} and BatchedCuVector: (x - μ)' * inv(Σ) * (x - μ) +function broadcasted( + ::typeof(sqmahal), d::BatchedStruct{<:MvNormal{T}}, x::BatchedCuVector{T} 
+) where {T} + diff = broadcasted(-, x, d.μ) + return broadcasted(invquad, d.Σ, diff) +end + +# oftype for CuVector result: convert scalar to element type +function broadcasted(::typeof(oftype), x::CuVector{T}, y::Base.RefValue) where {T} + return T(y[]) +end diff --git a/GeneralisedFilters/src/batching/types.jl b/GeneralisedFilters/src/batching/types.jl new file mode 100644 index 00000000..dcbd93a5 --- /dev/null +++ b/GeneralisedFilters/src/batching/types.jl @@ -0,0 +1,380 @@ +using Adapt +using CUDA +using LinearAlgebra: + Adjoint, Transpose, LowerTriangular, UpperTriangular, UniformScaling, Cholesky +using PDMats: PDMat + +export BatchedCuArray, BatchedCuMatrix, BatchedCuVector +export SharedCuArray, SharedCuMatrix, SharedCuVector, SharedScalar +export Shared +export BatchedStruct + +# ============================================================================= +# Core Batched Type +# ============================================================================= + +""" + BatchedCuArray{T,NE,NB,NT,A<:AbstractArray{T,NT}} <: AbstractArray{Any,NB} + +An `NB`-dimensional batch of `NE`-dimensional arrays, stored as a single contiguous +`NT`-dimensional array `data` where `NT = NE + NB`. + +- `NE`: number of element dimensions (the "inner" array shape) +- `NB`: number of batch dimensions +- `NT`: total number of dimensions (`NE + NB`); required explicitly because Julia's type + system cannot express arithmetic on type parameters +- `A`: storage array type + +The first `NE` dimensions index within each element; the last `NB` dimensions index across +the batch. + +This type is generic over the storage array type so that it can participate in `Adapt.jl` +transformations. In the user-facing intended usage, `data` is a `CuArray{T, NT, M}`. + +# Common aliases +- `BatchedCuMatrix{T,A}` = `BatchedCuArray{T,2,1,3,A}` — a vector of matrices +- `BatchedCuVector{T,A}` = `BatchedCuArray{T,1,1,2,A}` — a vector of vectors +""" +struct BatchedCuArray{T,NE,NB,NT,A<:AbstractArray{T,NT}} <: AbstractArray{Any,NB} + data::A + + function BatchedCuArray{T,NE,NB,NT,A}(data::A) where {T,NE,NB,NT,A<:AbstractArray{T,NT}} + NE + NB == NT || error("NE ($NE) + NB ($NB) must equal ndims(data) ($NT)") + return new{T,NE,NB,NT,A}(data) + end +end + +# Convenience constructor: infer T and M, require explicit NE and NB +function BatchedCuArray{T,NE,NB}(data::A) where {T,NE,NB,A<:AbstractArray{T}} + NT = ndims(data) + NE + NB == NT || error("NE ($NE) + NB ($NB) must equal ndims(data) ($NT)") + return BatchedCuArray{T,NE,NB,NT,A}(data) +end + +# Common case aliases +const BatchedCuMatrix{T,A<:AbstractArray{T,3}} = BatchedCuArray{T,2,1,3,A} +const BatchedCuVector{T,A<:AbstractArray{T,2}} = BatchedCuArray{T,1,1,2,A} + +# Constructors for aliased cases +BatchedCuMatrix(data::A) where {T,A<:AbstractArray{T,3}} = BatchedCuArray{T,2,1,3,A}(data) +BatchedCuVector(data::A) where {T,A<:AbstractArray{T,2}} = BatchedCuArray{T,1,1,2,A}(data) + +const BatchedArray = BatchedCuArray + +Base.IndexStyle(::Type{<:BatchedCuArray}) = Base.IndexCartesian() + +function Base.size(x::BatchedCuArray{T,NE,NB}) where {T,NE,NB} + return ntuple(i -> size(x.data, NE + i), NB) +end + +function Base.getindex(x::BatchedCuArray{T,NE,NB}, I::Vararg{Int,NB}) where {T,NE,NB} + return view(x.data, ntuple(_ -> :, NE)..., I...) 
+end + +function inner_size(x::BatchedCuArray{T,NE}) where {T,NE} + return ntuple(i -> size(x.data, i), NE) +end + +batch_size(x::BatchedCuArray) = length(x) + +# Adapting BatchedCuArray to bitstype +function Adapt.adapt_structure( + to, + x::BatchedCuArray{T,NE,NB,NT,A}, +) where {T,NE,NB,NT,A} + data_adapted = Adapt.adapt(to, x.data) + return BatchedCuArray{T,NE,NB,NT,typeof(data_adapted)}(data_adapted) +end + +# ============================================================================= +# Shared Types (same data reused across all batch elements) +# ============================================================================= + +""" + SharedCuArray{T,InnerN,BatchN,A<:AbstractArray{T,InnerN}} <: AbstractArray{Any,BatchN} + +A batch of arrays where every element is the same underlying array. +Unlike `Ref(array)`, this type carries an explicit batch size and satisfies the +`AbstractArray` contract honestly. + +Use `Ref(array)` when the batch size is unknown or irrelevant (e.g. during broadcast +setup). Use `SharedCuArray` when you need a proper `AbstractArray` with a known size. + +This type is generic over the storage array type so that it can participate in `Adapt.jl` +transformations. In the user-facing intended usage, `data` is a `CuArray{T,InnerN,M}`. + +# Common aliases +- `SharedCuMatrix{T,A}` = `SharedCuArray{T,2,1,A}` +- `SharedCuVector{T,A}` = `SharedCuArray{T,1,1,A}` +""" +struct SharedCuArray{T,InnerN,BatchN,A<:AbstractArray{T,InnerN}} <: AbstractArray{Any,BatchN} + data::A + batchsize::NTuple{BatchN,Int} +end + +# Outer constructor: accept a plain Int for the common 1D-batch case +function SharedCuArray{T,InnerN,1,A}(data::A, N::Int) where {T,InnerN,A<:AbstractArray{T,InnerN}} + return SharedCuArray{T,InnerN,1,A}(data, (N,)) +end + +const SharedCuMatrix{T,A<:AbstractArray{T,2}} = SharedCuArray{T,2,1,A} +const SharedCuVector{T,A<:AbstractArray{T,1}} = SharedCuArray{T,1,1,A} + +# Constructors for aliased cases +SharedCuMatrix(data::A, N::Int) where {T,A<:AbstractArray{T,2}} = SharedCuArray{T,2,1,A}(data, N) +SharedCuVector(data::A, N::Int) where {T,A<:AbstractArray{T,1}} = SharedCuArray{T,1,1,A}(data, N) + +const SharedArray = SharedCuArray + +Base.eltype(::Type{<:BatchedCuArray{T,NE}}) where {T,NE} = AbstractArray{T,NE} +Base.eltype(::Type{<:SharedCuArray{T,InnerN}}) where {T,InnerN} = AbstractArray{T,InnerN} + +""" + Shared(data::AbstractArray, N::Int) -> SharedCuArray + +Convenience constructor: create a `SharedCuArray` from an arrat with an explicit +1D batch size `N`. 
+ +The underlying storage is generic to support `Adapt.jl` transformations, but in +the user-facing intended interface `A` is type `CuArray` +""" +Shared(x::A, N::Int) where {T,A<:AbstractArray{T,2}} = SharedCuArray{T,2,1,A}(x, (N,)) +Shared(x::A, N::Int) where {T,A<:AbstractArray{T,1}} = SharedCuArray{T,1,1,A}(x, (N,)) + +Base.IndexStyle(::Type{<:SharedCuArray}) = Base.IndexCartesian() + +Base.size(x::SharedCuArray) = x.batchsize + +function Base.getindex( + x::SharedCuArray{T,InnerN,BatchN}, ::Vararg{Int,BatchN} +) where {T,InnerN,BatchN} + return x.data +end + +function inner_size(x::SharedCuArray) + return size(x.data) +end + +batch_size(x::SharedCuArray) = length(x) + +# Adapting SharedCuArray to bitstype +function Adapt.adapt_structure( + to, + x::SharedCuArray{T,InnerN,BatchN,A}, +) where {T,InnerN,BatchN,A<:AbstractArray{T,InnerN}} + data_adapted = Adapt.adapt(to, x.data) + return SharedCuArray{T,InnerN,BatchN,typeof(data_adapted)}(data_adapted, x.batchsize) +end + +# ============================================================================= +# SharedScalar: a scalar value shared across all batch elements +# ============================================================================= + +struct SharedScalar{T} <: AbstractVector{T} + value::T +end + +Base.size(::SharedScalar) = (1,) +Base.length(::SharedScalar) = 1 +Base.getindex(x::SharedScalar, ::Int) = x.value +Base.eltype(::Type{SharedScalar{T}}) where {T} = T +batch_size(::SharedScalar) = nothing +_get_component_batch_size(::SharedScalar) = nothing + +# Comparisons unwrap the SharedScalar automatically +Base.:(==)(x::SharedScalar, y) = x.value == y +Base.:(==)(x, y::SharedScalar) = x == y.value +Base.:(==)(x::SharedScalar, y::SharedScalar) = x.value == y.value + +# Adapting SharedScalar to bitstype +Adapt.@adapt_structure SharedScalar + +# ============================================================================= +# BatchedStruct - Custom wrapper for batched composite types +# ============================================================================= + +""" + BatchedStruct{T, C <: NamedTuple} <: AbstractVector{T} + +A wrapper type representing a batch of structs of type `T`, stored in a +column-oriented (struct-of-arrays) format. + +# Type Parameters +- `T`: The element type (e.g., `PDMat{Float32, CuMatrix{Float32}}`) +- `C`: The NamedTuple type holding the batched components + +# Fields +- `components::C`: A NamedTuple where each field is a batched array or nested BatchedStruct + +# Usage +BatchedStruct is designed to be created automatically by the batching IR transform +when constructors are called with batched arguments. Users typically don't need +to construct these directly. 
+ +# Property Access +- `x.fieldname` returns the batched component for real fields of `T` +- `x.components` returns the underlying NamedTuple storage +- For computed properties (custom getproperty), falls back to element-wise evaluation + +# Indexing +- `x[i]` constructs and returns an instance of `T` for the i-th batch element +""" +struct BatchedStruct{T,C<:NamedTuple} <: AbstractVector{T} + components::C + + function BatchedStruct{T}(components::C) where {T,C<:NamedTuple} + # Validate all components have consistent batch size + sizes = map(_get_component_batch_size, values(components)) + non_nothing = Base.filter(!isnothing, sizes) + if !isempty(non_nothing) + first_size = first(non_nothing) + if !all(s -> s == first_size, non_nothing) + error("All batched components must have the same batch size") + end + end + return new{T,C}(components) + end +end + +# Helper to get batch size from a component +_get_component_batch_size(x::BatchedCuArray) = batch_size(x) +_get_component_batch_size(x::SharedCuArray) = batch_size(x) +_get_component_batch_size(x::BatchedStruct) = length(x) +_get_component_batch_size(::Any) = nothing + +# Convenience constructor that infers T from the fieldnames matching a type +function BatchedStruct{T}(; kwargs...) where {T} + components = NamedTuple{fieldnames(T)}(values(kwargs)) + return BatchedStruct{T}(components) +end + +# ============================================================================= +# BatchedStruct - AbstractVector Interface +# ============================================================================= + +function batch_size(x::BatchedStruct) + for component in values(getfield(x, :components)) + bs = _get_component_batch_size(component) + if bs !== nothing + return bs + end + end + return error("BatchedStruct has no batched components") +end + +Base.size(x::BatchedStruct) = (batch_size(x),) +Base.IndexStyle(::Type{<:BatchedStruct}) = IndexLinear() + +# Non-generic indexing +function Base.getindex(x::BatchedStruct{<:Adjoint}, i::Integer) + return adjoint(getfield(x, :components).parent[i]) +end + +function Base.getindex(x::BatchedStruct{<:Transpose}, i::Integer) + return transpose(getfield(x, :components).parent[i]) +end + +# Indexing: construct an element of type T by calling its constructor +@generated function Base.getindex(x::BatchedStruct{T}, i::Integer) where {T} + fields = fieldnames(T) + field_exprs = [:(getfield(x, :components).$(fields[j])[i]) for j in 1:length(fields)] + return :(T($(field_exprs...))) +end + +# ============================================================================= +# BatchedStruct - Property Access +# ============================================================================= + +function Base.getproperty(x::BatchedStruct{T}, s::Symbol) where {T} + s === :components && return getfield(x, :components) + s in fieldnames(T) && return getfield(x, :components)[s] + # Computed property - broadcast getproperty, triggering IR transformation + return getproperty.(x, Ref(s)) +end + +Base.propertynames(::BatchedStruct{T}) where {T} = (:components, fieldnames(T)...) 
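+
+# Example (illustrative sketch, not part of the test suite; assumes the two-field
+# `mat`/`chol` PDMat layout relied on elsewhere in this PR):
+#
+#     Σs = BatchedCuMatrix(CUDA.randn(Float32, 2, 2, 16))  # batch of 16 2×2 matrices
+#     Ps = PDMat.(Σs)     # BatchedStruct{<:PDMat} built by the batched broadcast
+#     Ps.mat              # batched `mat` component (a BatchedCuMatrix)
+#     Ps.chol             # nested BatchedStruct{<:Cholesky}
+#     Ps[3]               # reconstructs the third batch element as a PDMat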
+ +# ============================================================================= +# BatchedStruct - Display (avoid materialization) +# ============================================================================= + +function Base.show(io::IO, x::BatchedStruct{T}) where {T} + return print(io, "BatchedStruct{", T, "} with ", length(x), " elements") +end + +function Base.show(io::IO, ::MIME"text/plain", x::BatchedStruct{T}) where {T} + println(io, "BatchedStruct{", T, "} with ", length(x), " elements:") + comps = getfield(x, :components) + for name in keys(comps) + component = comps[name] + print(io, " .", name, " :: ") + if component isa BatchedCuArray + println(io, typeof(component), " (", inner_size(component), ")") + elseif component isa SharedCuArray + println(io, typeof(component), " [shared] (", inner_size(component), ")") + elseif component isa BatchedStruct + println(io, typeof(component)) + else + println(io, typeof(component)) + end + end +end + +# Adapting BatchedStruct to bitstype +function Adapt.adapt_structure( + to, + x::BatchedStruct{T,C}, +) where {T,C<:NamedTuple} + comps_adapted = Adapt.adapt(to, x.components) + return BatchedStruct{T}(comps_adapted) +end + +# ============================================================================= +# Union Types for Dispatch +# ============================================================================= + +const BatchedOrShared = Union{BatchedCuArray,SharedCuArray,BatchedStruct} + +# ============================================================================= +# Helper Functions +# ============================================================================= + +is_shared(::BatchedCuArray) = false +is_shared(::SharedCuArray) = true + +unwrap_data(A::BatchedCuArray) = A.data +unwrap_data(A::SharedCuArray) = A.data + +function inner_size_for_blas(A::BatchedCuMatrix) + m, n = size(A.data, 1), size(A.data, 2) + return (m, n) +end + +function inner_size_for_blas(A::SharedCuMatrix) + m, n = size(A.data) + return (m, n) +end + +function get_batch_size(args...) + for arg in args + bs = batch_size(arg) + if bs !== nothing + return bs + end + end + return error("At least one argument must be batched") +end + +# ============================================================================= +# Pointer Array Creation +# ============================================================================= + +function create_pointer_array(A::BatchedCuArray{T,InnerN,1}) where {T,InnerN} + return CUDA.CUBLAS.unsafe_strided_batch(A.data) +end + +function create_pointer_array(A::SharedCuArray{T,InnerN,1}) where {T,InnerN} + N = batch_size(A) + ptr = pointer(A.data) + return reinterpret(CuPtr{T}, CUDA.fill(UInt(ptr), N)) +end diff --git a/GeneralisedFilters/src/batching/wrappers.jl b/GeneralisedFilters/src/batching/wrappers.jl new file mode 100644 index 00000000..2170490d --- /dev/null +++ b/GeneralisedFilters/src/batching/wrappers.jl @@ -0,0 +1,731 @@ +using Magma +using Magma.LibMagma +using LinearAlgebra: cholesky, Cholesky, LowerTriangular + +export get_magma_queue, reset_magma_queue! 
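+
+# Typical call pattern for the batched MAGMA wrappers defined in this file
+# (illustrative sketch only; the dimensions and element type are assumptions):
+#
+#     A = BatchedCuMatrix(CUDA.randn(Float32, 4, 4, 128))
+#     B = BatchedCuMatrix(CUDA.randn(Float32, 4, 4, 128))
+#     C = BatchedCuMatrix(CuArray{Float32}(undef, 4, 4, 128))
+#     gemm_batched!('N', 'N', 1.0f0, A, B, 0.0f0, C)   # C[:, :, k] = A[:, :, k] * B[:, :, k]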
+ +# ============================================================================= +# MAGMA Queue Management +# ============================================================================= + +# Store the queue pointer directly (not in a Ref) to avoid lifetime issues +mutable struct MagmaQueueCache + ptr::LibMagma.magma_queue_t + initialized::Bool +end + +const MAGMA_QUEUE_CACHE = MagmaQueueCache(LibMagma.magma_queue_t(), false) + +""" + get_magma_queue() + +Get the cached MAGMA queue, creating it on first use. +Returns the raw queue pointer for use with MAGMA functions. +""" +function get_magma_queue() + if !MAGMA_QUEUE_CACHE.initialized + queue_ref = Ref{LibMagma.magma_queue_t}() + LibMagma.magma_queue_create_internal(0, queue_ref, C_NULL, C_NULL, 0) + MAGMA_QUEUE_CACHE.ptr = queue_ref[] + MAGMA_QUEUE_CACHE.initialized = true + end + return MAGMA_QUEUE_CACHE.ptr +end + +""" + reset_magma_queue!() + +Destroy and recreate the MAGMA queue. Useful if you suspect queue corruption. +""" +function reset_magma_queue!() + if MAGMA_QUEUE_CACHE.initialized + LibMagma.magma_queue_destroy_internal(MAGMA_QUEUE_CACHE.ptr, C_NULL, C_NULL, 0) + MAGMA_QUEUE_CACHE.initialized = false + end + return get_magma_queue() +end + +# ============================================================================= +# Trivial Wrappers (reductions and elementwise operations) +# ============================================================================= + +function broadcasted( + ::typeof(+), + A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + if is_shared(A) && is_shared(B) + return SharedCuMatrix(A.data .+ B.data, batch_size(A)) + else + return BatchedCuMatrix(A.data .+ B.data) + end +end + +function broadcasted( + ::typeof(+), + a::Union{BatchedCuVector{T},SharedCuVector{T}}, + b::Union{BatchedCuVector{T},SharedCuVector{T}}, +) where {T} + if is_shared(a) && is_shared(b) + return SharedCuVector(a.data .+ b.data, batch_size(a)) + else + return BatchedCuVector(a.data .+ b.data) + end +end + +function broadcasted( + ::typeof(-), + A::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, + B::Union{BatchedCuMatrix{T},SharedCuMatrix{T}}, +) where {T} + if is_shared(A) && is_shared(B) + return SharedCuMatrix(A.data .- B.data, batch_size(A)) + else + return BatchedCuMatrix(A.data .- B.data) + end +end + +function broadcasted( + ::typeof(-), + a::Union{BatchedCuVector{T},SharedCuVector{T}}, + b::Union{BatchedCuVector{T},SharedCuVector{T}}, +) where {T} + if is_shared(a) && is_shared(b) + return SharedCuVector(a.data .- b.data, batch_size(a)) + else + return BatchedCuVector(a.data .- b.data) + end +end + +# ============================================================================= +# MAGMA Constants Conversion +# ============================================================================= + +function magma_trans(c::Char) + if c == 'N' + return LibMagma.MagmaNoTrans + elseif c == 'T' + return LibMagma.MagmaTrans + elseif c == 'C' + return LibMagma.MagmaConjTrans + else + error("Unknown transpose char: $c") + end +end + +function magma_uplo(c::Char) + if c == 'L' + return LibMagma.MagmaLower + elseif c == 'U' + return LibMagma.MagmaUpper + else + error("Unknown uplo char: $c") + end +end + +function magma_side(c::Char) + if c == 'L' + return LibMagma.MagmaLeft + elseif c == 'R' + return LibMagma.MagmaRight + else + error("Unknown side char: $c") + end +end + +function magma_diag(c::Char) + if c == 'N' + return LibMagma.MagmaNonUnit + elseif c == 'U' + return 
LibMagma.MagmaUnit + else + error("Unknown diag char: $c") + end +end + +# ============================================================================= +# MAGMA Operations +# ============================================================================= + +function gemm_batched!( + transA::Char, + transB::Char, + alpha::Float32, + A::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + B::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + beta::Float32, + C::BatchedCuMatrix{Float32}, +) + N = batch_size(C) + m, n = size(C.data, 1), size(C.data, 2) + k = transA == 'N' ? size(unwrap_data(A), 2) : size(unwrap_data(A), 1) + + dA = create_pointer_array(A) + dB = create_pointer_array(B) + dC = create_pointer_array(C) + + ldda = size(unwrap_data(A), 1) + lddb = size(unwrap_data(B), 1) + lddc = m + + CUDA.synchronize() + queue = get_magma_queue() + LibMagma.magma_sgemm_batched( + magma_trans(transA), + magma_trans(transB), + m, + n, + k, + alpha, + dA, + ldda, + dB, + lddb, + beta, + dC, + lddc, + N, + queue, + ) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + CUDA.unsafe_free!(dC) + + return C +end + +# ============================================================================= +# Part 3b: Batched GEMM Small Square Wrapper (for D < 32) +# ============================================================================= + +function gemm_batched_smallsq!( + transA::Char, + transB::Char, + alpha::Float32, + A::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + B::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + beta::Float32, + C::BatchedCuMatrix{Float32}, +) + N = batch_size(C) + m, n = size(C.data, 1), size(C.data, 2) + k = transA == 'N' ? size(unwrap_data(A), 2) : size(unwrap_data(A), 1) + + dA = create_pointer_array(A) + dB = create_pointer_array(B) + dC = create_pointer_array(C) + + ldda = size(unwrap_data(A), 1) + lddb = size(unwrap_data(B), 1) + lddc = m + + CUDA.synchronize() + queue = get_magma_queue() + LibMagma.magmablas_sgemm_batched_smallsq( + magma_trans(transA), + magma_trans(transB), + m, + n, + k, + alpha, + dA, + 0, # ai + 0, # aj + ldda, + dB, + 0, # bi + 0, # bj + lddb, + beta, + dC, + 0, # ci + 0, # cj + lddc, + N, + queue, + ) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + CUDA.unsafe_free!(dC) + + return C +end + +function gemm_batched_smallsq!( + transA::Char, + transB::Char, + alpha::Float64, + A::Union{BatchedCuMatrix{Float64},SharedCuMatrix{Float64}}, + B::Union{BatchedCuMatrix{Float64},SharedCuMatrix{Float64}}, + beta::Float64, + C::BatchedCuMatrix{Float64}, +) + N = batch_size(C) + m, n = size(C.data, 1), size(C.data, 2) + k = transA == 'N' ? 
size(unwrap_data(A), 2) : size(unwrap_data(A), 1) + + dA = create_pointer_array(A) + dB = create_pointer_array(B) + dC = create_pointer_array(C) + + ldda = size(unwrap_data(A), 1) + lddb = size(unwrap_data(B), 1) + lddc = m + + CUDA.synchronize() + queue = get_magma_queue() + LibMagma.magmablas_dgemm_batched_smallsq( + magma_trans(transA), + magma_trans(transB), + m, + n, + k, + alpha, + dA, + 0, # ai + 0, # aj + ldda, + dB, + 0, # bi + 0, # bj + lddb, + beta, + dC, + 0, # ci + 0, # cj + lddc, + N, + queue, + ) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + CUDA.unsafe_free!(dC) + + return C +end + +function gemv_batched!( + transA::Char, + alpha::Float32, + A::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + x::Union{BatchedCuVector{Float32},SharedCuVector{Float32}}, + beta::Float32, + y::BatchedCuVector{Float32}, +) + N = batch_size(y) + m, n = size(unwrap_data(A), 1), size(unwrap_data(A), 2) + + dA = create_pointer_array(A) + dx = create_pointer_array(x) + dy = create_pointer_array(y) + + ldda = m + incx = 1 + incy = 1 + + CUDA.synchronize() + queue = get_magma_queue() + LibMagma.magmablas_sgemv_batched( + magma_trans(transA), m, n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue + ) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dx) + CUDA.unsafe_free!(dy) + + return y +end + +function gemv_batched!( + transA::Char, + alpha::Float64, + A::Union{BatchedCuMatrix{Float64},SharedCuMatrix{Float64}}, + x::Union{BatchedCuVector{Float64},SharedCuVector{Float64}}, + beta::Float64, + y::BatchedCuVector{Float64}, +) + N = batch_size(y) + m, n = size(unwrap_data(A), 1), size(unwrap_data(A), 2) + + dA = create_pointer_array(A) + dx = create_pointer_array(x) + dy = create_pointer_array(y) + + ldda = m + incx = 1 + incy = 1 + + CUDA.synchronize() + queue = get_magma_queue() + LibMagma.magmablas_dgemv_batched( + magma_trans(transA), m, n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue + ) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dx) + CUDA.unsafe_free!(dy) + + return y +end + +function gemv_batched_smallsq!( + transA::Char, + alpha::Float32, + A::Union{BatchedCuMatrix{Float32},SharedCuMatrix{Float32}}, + x::Union{BatchedCuVector{Float32},SharedCuVector{Float32}}, + beta::Float32, + y::BatchedCuVector{Float32}, +) + N = batch_size(y) + n = size(unwrap_data(A), 1) + + dA = create_pointer_array(A) + dx = create_pointer_array(x) + dy = create_pointer_array(y) + + ldda = n + incx = 1 + incy = 1 + + CUDA.synchronize() + queue = get_magma_queue() + LibMagma.magmablas_sgemv_batched_smallsq( + magma_trans(transA), n, alpha, dA, ldda, dx, incx, beta, dy, incy, N, queue + ) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dx) + CUDA.unsafe_free!(dy) + + return y +end + +function potrf_batched!(uplo::Char, A::BatchedCuMatrix{Float32}, info::CuVector{Int64}) + N = batch_size(A) + n = size(A.data, 1) + lda = n + + dA = create_pointer_array(A) + + CUDA.synchronize() + queue = get_magma_queue() + LibMagma.magma_spotrf_batched(magma_uplo(uplo), n, dA, lda, pointer(info), N, queue) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + + return A +end + +function potrs_batched!( + uplo::Char, A::BatchedCuMatrix{Float32}, B::BatchedCuMatrix{Float32} +) + N = batch_size(B) + n = size(A.data, 1) + nrhs = 
size(B.data, 2) + + dA = create_pointer_array(A) + dB = create_pointer_array(B) + + ldda = n + lddb = n + + CUDA.synchronize() + queue = get_magma_queue() + LibMagma.magma_spotrs_batched(magma_uplo(uplo), n, nrhs, dA, ldda, dB, lddb, N, queue) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + + return B +end + +function trsm_batched!( + side::Char, + uplo::Char, + transA::Char, + diag::Char, + alpha::Float32, + A::BatchedCuMatrix{Float32}, + B::BatchedCuMatrix{Float32}, +) + N = batch_size(B) + m, n = size(B.data, 1), size(B.data, 2) + + dA = create_pointer_array(A) + dB = create_pointer_array(B) + + ldda = size(A.data, 1) + lddb = m + + CUDA.synchronize() + queue = get_magma_queue() + LibMagma.magmablas_strsm_batched( + magma_side(side), + magma_uplo(uplo), + magma_trans(transA), + magma_diag(diag), + m, + n, + alpha, + dA, + ldda, + dB, + lddb, + N, + queue, + ) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + + return B +end + +# ============================================================================= +# Higher-level Cholesky Operations +# ============================================================================= + +function broadcasted(::typeof(cholesky), A::BatchedCuMatrix{T,M}) where {T,M} + N = batch_size(A) + A_copy = BatchedCuMatrix(copy(A.data)) + info = CUDA.zeros(Int64, N) + + potrf_batched!('L', A_copy, info) + + factors_wrapped = broadcasted(LowerTriangular, A_copy) + + # Store as SharedScalar since it's the same for all batch elements + uplo = SharedScalar('L') + + ElType = Cholesky{T,eltype(A)} + return BatchedStruct{ElType}((; factors=factors_wrapped, uplo=uplo, info=info)) +end + +# function pdmat_solve(S::BatchedPDMat{T}, B::BatchedCuMatrix{T}) where {T} +# L = S.chol.factors +# L_data = BatchedCuMatrix(L.data) + +# B_copy = BatchedCuMatrix(copy(B.data)) + +# # Solve L*L'*X = B via two triangular solves: +# # 1. Solve L*Y = B (Y stored in B_copy) +# trsm_batched!('L', 'L', 'N', 'N', one(T), L_data, B_copy) +# # 2. 
Solve L'*X = Y (X stored in B_copy) +# trsm_batched!('L', 'L', 'T', 'N', one(T), L_data, B_copy) + +# return B_copy +# end + +# ============================================================================= +# Batched TRMM (Triangular Matrix Multiply) +# ============================================================================= + +function trmm_batched!( + side::Char, + uplo::Char, + transA::Char, + diag::Char, + alpha::Float32, + A::BatchedCuMatrix{Float32}, + B::BatchedCuMatrix{Float32}, +) + N = batch_size(B) + m, n = size(B.data, 1), size(B.data, 2) + + dA = create_pointer_array(A) + dB = create_pointer_array(B) + + ldda = size(A.data, 1) + lddb = m + + CUDA.synchronize() + queue = get_magma_queue() + LibMagma.magmablas_strmm_batched( + magma_side(side), + magma_uplo(uplo), + magma_trans(transA), + magma_diag(diag), + m, + n, + alpha, + dA, + ldda, + dB, + lddb, + N, + queue, + ) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + + return B +end + +function trmm_batched!( + side::Char, + uplo::Char, + transA::Char, + diag::Char, + alpha::Float64, + A::BatchedCuMatrix{Float64}, + B::BatchedCuMatrix{Float64}, +) + N = batch_size(B) + m, n = size(B.data, 1), size(B.data, 2) + + dA = create_pointer_array(A) + dB = create_pointer_array(B) + + ldda = size(A.data, 1) + lddb = m + + CUDA.synchronize() + queue = get_magma_queue() + LibMagma.magmablas_dtrmm_batched( + magma_side(side), + magma_uplo(uplo), + magma_trans(transA), + magma_diag(diag), + m, + n, + alpha, + dA, + ldda, + dB, + lddb, + N, + queue, + ) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dB) + + return B +end + +# ============================================================================= +# Batched SYRK (Symmetric Rank-K Update) +# ============================================================================= + +function syrk_batched!( + uplo::Char, + trans::Char, + alpha::Float32, + A::BatchedCuMatrix{Float32}, + beta::Float32, + C::BatchedCuMatrix{Float32}, +) + N = batch_size(C) + n = size(C.data, 1) + k = trans == 'N' ? size(A.data, 2) : size(A.data, 1) + + dA = create_pointer_array(A) + dC = create_pointer_array(C) + + ldda = size(A.data, 1) + lddc = n + + CUDA.synchronize() + queue = get_magma_queue() + LibMagma.magmablas_ssyrk_batched( + magma_uplo(uplo), + magma_trans(trans), + n, + k, + alpha, + dA, + ldda, + beta, + dC, + lddc, + N, + queue, + ) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dC) + + return C +end + +function syrk_batched!( + uplo::Char, + trans::Char, + alpha::Float64, + A::BatchedCuMatrix{Float64}, + beta::Float64, + C::BatchedCuMatrix{Float64}, +) + N = batch_size(C) + n = size(C.data, 1) + k = trans == 'N' ? 
size(A.data, 2) : size(A.data, 1) + + dA = create_pointer_array(A) + dC = create_pointer_array(C) + + ldda = size(A.data, 1) + lddc = n + + CUDA.synchronize() + queue = get_magma_queue() + LibMagma.magmablas_dsyrk_batched( + magma_uplo(uplo), + magma_trans(trans), + n, + k, + alpha, + dA, + ldda, + beta, + dC, + lddc, + N, + queue, + ) + LibMagma.magma_queue_sync_internal(queue, C_NULL, C_NULL, 0) + + CUDA.unsafe_free!(dA) + CUDA.unsafe_free!(dC) + + return C +end + +# ============================================================================= +# Symmetrize Lower Triangular Matrix +# ============================================================================= + +function symmetrize_lower_kernel!(A, n) + batch_idx = blockIdx().z + i = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + j = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y + + if i <= n && j <= n && j > i + @inbounds A[i, j, batch_idx] = A[j, i, batch_idx] + end + return nothing +end + +function symmetrize_lower!(A::BatchedCuMatrix{T}) where {T} + n = size(A.data, 1) + N = size(A.data, 3) + threads = (16, 16) + blocks = (cld(n, 16), cld(n, 16), N) + @cuda threads = threads blocks = blocks symmetrize_lower_kernel!(A.data, n) + return A +end diff --git a/research/batching/batching_demo.jl b/research/batching/batching_demo.jl new file mode 100644 index 00000000..e6f98994 --- /dev/null +++ b/research/batching/batching_demo.jl @@ -0,0 +1,277 @@ +using GeneralisedFilters + +using Distributions +using LinearAlgebra +using Base.Broadcast: broadcasted +using PDMats +using BenchmarkTools + +using CUDA +using Magma +using Magma.LibMagma + +Magma.magma_init() + +# ============================================================================= +# Configuration +# ============================================================================= + +D_state = 2 +D_obs = 2 +N = 3 + +BATCHED_CACHE_VERBOSITY[] = :debug + +function kalman_predict(state, dyn_params) + A = dyn_params[1] + b = dyn_params[2] + Q = dyn_params[3] + + μ̂ = A * state.μ + b + Σ̂ = X_A_Xt(state.Σ, A) + Q + return MvNormal(μ̂, Σ̂) +end + +I_mat = CuArray{Float32}(I, D_state, D_state) +Is = Shared(I_mat, N) + +μs = BatchedCuVector(CUDA.randn(Float32, D_state, N)) +Σs_root = BatchedCuMatrix(CUDA.randn(Float32, D_state, D_state, N)) +Σs = Σs_root .* adjoint.(Σs_root) .+ Is + +As = Shared(CUDA.randn(Float32, D_state, D_state), N) +bs = BatchedCuVector(CUDA.randn(Float32, D_state, N)) +Q_root = CUDA.randn(Float32, D_state, D_state) +Q = Q_root * Q_root' + I +Qs = Shared(Q, N) + +Σ_PDs = broadcasted(PDMat, Σs); +PDMat.(Σs) +Gs = MvNormal.(μs, Σ_PDs); + +function kalman_predict(state, dyn_params) + A = dyn_params[1] + b = dyn_params[2] + Q = dyn_params[3] + + μ̂ = A * state.μ + b + Σ̂ = PDMat(X_A_Xt(state.Σ, A) + Q) + + return MvNormal(μ̂, Σ̂) +end + +dyn_params = (As, bs, Qs) +pred_Gs = kalman_predict.(Gs, Ref(dyn_params)); + +# Compare to CPU +μ_test = Array(μs[end]) +Σ_test = Array(Σs[end]) +A_test = Array(As.data) +b_test = Array(bs[end]) +Q_test = Array(Qs.data) +pred_G_test = kalman_predict(MvNormal(μ_test, PDMat(Σ_test)), (A_test, b_test, Q_test)) + +println("=== Predict Comparison ===\n") +println("CPU Mean: ", pred_G_test.μ) +println("GPU Mean: ", Array(pred_Gs.μ[end])) + +println("CPU Covariance: ", Matrix(pred_G_test.Σ)) +println("GPU Covariance: ", Array(pred_Gs.Σ.mat[end])) + +# ============================================================================= +# Kalman Update +# ============================================================================= + 
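+# The function below implements the standard Kalman measurement update:
+#   m = H μ + c                      innovation mean
+#   S = H Σ Hᵀ + R                   innovation covariance
+#   K = Σ Hᵀ S⁻¹                     Kalman gain
+#   μ̂ = μ + K (y − m)
+#   Σ̂ = (I − K H) Σ (I − K H)ᵀ + K R Kᵀ   (Joseph form, for numerical stability)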
+function kalman_update(state, obs_params, observation) + μ = state.μ + Σ = state.Σ + H = obs_params[1] + c = obs_params[2] + R = obs_params[3] + + # Compute innovation distribution + m = H * μ + c + S = PDMat(X_A_Xt(Σ, H) + R) + ȳ = observation - m + + # Kalman gain + K = Σ * H' / S + + # Update parameters using Joseph form for numerical stability + μ̂ = μ + K * ȳ + Σ̂ = PDMat(X_A_Xt(Σ, I - K * H) + X_A_Xt(R, K)) + + return MvNormal(μ̂, Σ̂) +end + +function kalman_step(state, dyn_params, obs_params, observation) + state = kalman_predict(state, dyn_params) + state = kalman_update(state, obs_params, observation) + return state +end + +# Observation parameters (H and c shared, R batched) +Hs = Shared(CUDA.randn(Float32, D_obs, D_state), N) +cs = Shared(CUDA.randn(Float32, D_obs), N) +I_obs = CuArray{Float32}(I, D_obs, D_obs) +I_obs_shared = Shared(I_obs, N) +Rs_root = BatchedCuMatrix(CUDA.randn(Float32, D_obs, D_obs, N)) +Rs = Rs_root .* adjoint.(Rs_root) .+ I_obs_shared +Rs = PDMat.(Rs); + +obs_params = (Hs, cs, Rs) + +# Observations +observations = BatchedCuVector(CUDA.randn(Float32, D_obs, N)) + +# Run update on GPU +update_Gs = kalman_update.(pred_Gs, Ref(obs_params), observations); + +# Compare update to CPU +H_test = Array(Hs.data) +c_test = Array(cs.data) +R_test = PDMat(Array(Rs.mat[end])) +obs_test = Array(observations[end]) + +update_G_test = kalman_update(pred_G_test, (H_test, c_test, R_test), obs_test) + +println("\n=== Update Comparison ===\n") +println("CPU Mean: ", update_G_test.μ) +println("GPU Mean: ", Array(update_Gs.μ[end])) + +println("CPU Covariance: ", Matrix(update_G_test.Σ)) +println("GPU Covariance: ", Array(update_Gs.Σ.mat[end])) + +# ============================================================================= +# Full Kalman Step +# ============================================================================= + +# Run full step on GPU (from original state) +step_Gs = kalman_step.(Gs, Ref(dyn_params), Ref(obs_params), observations); + +# Compare full step to CPU +step_G_test = kalman_step( + MvNormal(μ_test, PDMat(Σ_test)), + (A_test, b_test, Q_test), + (H_test, c_test, R_test), + obs_test, +) + +println("\n=== Full Step Comparison ===\n") +println("CPU Mean: ", step_G_test.μ) +println("GPU Mean: ", Array(step_Gs.μ[end])) + +println("CPU Covariance: ", Matrix(step_G_test.Σ)) +println("GPU Covariance: ", Array(step_Gs.Σ.mat[end])) + +# ============================================================================= +# Benchmarking +# ============================================================================= + +using BenchmarkTools +using StaticArrays + +D_bench = 10 +N_bench = 60000 + +println("\n=== Benchmarking batched Kalman step ===\n") + +println("D = $D_bench, N = $N_bench") +println( + "Size of batched covariance matrices: ", + round(Int, D_bench * D_bench * N_bench * sizeof(Float32) / 1024^2), + " MB", +) + +Is_bench = Shared(CuArray{Float32}(I, D_bench, D_bench), N_bench) +μs_bench = BatchedCuVector(CUDA.randn(Float32, D_bench, N_bench)) +Σs_root_bench = BatchedCuMatrix(CUDA.randn(Float32, D_bench, D_bench, N_bench)) +Σs_bench = Σs_root_bench .* adjoint.(Σs_root_bench) .+ Is_bench +Σ_PDs_bench = broadcasted(PDMat, Σs_bench); +Gs_bench = MvNormal.(μs_bench, Σ_PDs_bench); + +As_bench = Shared(CUDA.randn(Float32, D_bench, D_bench), N_bench) +bs_bench = BatchedCuVector(CUDA.randn(Float32, D_bench, N_bench)) +Qs_root_bench = CUDA.randn(Float32, D_bench, D_bench) +Qs_bench_mat = Qs_root_bench * adjoint(Qs_root_bench) +Qs_bench_mat += I +Qs_bench = 
Shared(Qs_bench_mat, N_bench) +dyn_params_bench = (As_bench, bs_bench, Qs_bench) + +Hs_bench = Shared(CUDA.randn(Float32, D_bench, D_bench), N_bench) +cs_bench = Shared(CUDA.randn(Float32, D_bench), N_bench) +Rs_root_bench = BatchedCuMatrix(CUDA.randn(Float32, D_bench, D_bench, N_bench)) +Rs_bench = Rs_root_bench .* adjoint.(Rs_root_bench) .+ Is_bench +Rs_bench = PDMat.(Rs_bench) +obs_params_bench = (Hs_bench, cs_bench, Rs_bench) + +ys_bench = BatchedCuVector(CUDA.randn(Float32, D_bench, N_bench)) + +println("\nBenchmarking batched Kalman step...\n") + +BATCHED_CACHE_VERBOSITY[] = :silent +display( + @benchmark kalman_step.( + $Gs_bench, Ref($dyn_params_bench), Ref($obs_params_bench), $ys_bench + ) +) + +# Compare to static arrays +μs_static = [@SVector randn(Float32, D_bench) for _ in 1:N_bench] +Σs_static = [ + begin + A = @SMatrix randn(Float32, D_bench, D_bench) + A * A' + I + end for _ in 1:N_bench +] +Gs_static = [MvNormal(μs_static[n], PDMat(Σs_static[n])) for n in 1:N_bench] + +As_static = @SMatrix randn(Float32, D_bench, D_bench) +bs_static = [@SVector randn(Float32, D_bench) for _ in 1:N_bench] +Qs_root_static = @SMatrix randn(Float32, D_bench, D_bench) +Qs_static_mat = Qs_root_static * adjoint(Qs_root_static) + I +Qs_static = Qs_static_mat +dyn_params_static = (As_static, bs_static, Qs_static) + +Hs_static = @SMatrix randn(Float32, D_bench, D_bench) +cs_static = @SVector randn(Float32, D_bench) +Rs_root_static = [@SMatrix randn(Float32, D_bench, D_bench) for _ in 1:N_bench] +Rs_static = [ + Symmetric(R_root * R_root') + SMatrix{D_bench,D_bench,Float32}(I) for + R_root in Rs_root_static +] +obs_static = [@SVector randn(Float32, D_bench) for _ in 1:N_bench] +obs_params_static = (Hs_static, cs_static, PDMat.(Rs_static)) + +function test_static(Gs, dyn_params, obs_params, observations, out) + N = length(out) + Threads.@threads for n in 1:N + @inbounds out[n] = kalman_step( + Gs[n], + (dyn_params[1], dyn_params[2][n], dyn_params[3]), + (obs_params[1], obs_params[2], obs_params[3][n]), + observations[n], + ) + end + return out +end + +out = Vector{eltype(Gs_static)}(undef, N_bench) + +println("\nBenchmarking static Kalman step...\n") + +test_static(Gs_static, dyn_params_static, obs_params_static, obs_static, out); # warm-up + +display( + @benchmark test_static( + $Gs_static, $dyn_params_static, $obs_params_static, $obs_static, $out + ) +) + +println("\nBenchmarking single batched matrix multiplication...\n") + +display(@benchmark $Σs_root_bench .* $Σs_root_bench) + +tot_mem = 3 * D_bench * D_bench * N_bench * sizeof(Float32) +throughput = 1008 * 10^9 +println("\nTheoretical optimum time: ", tot_mem / throughput * 10^6, " μs") diff --git a/research/batching/full_kalman_demo.jl b/research/batching/full_kalman_demo.jl new file mode 100644 index 00000000..f9590b44 --- /dev/null +++ b/research/batching/full_kalman_demo.jl @@ -0,0 +1,80 @@ +using GeneralisedFilters +const GF = GeneralisedFilters + +using Distributions +using LinearAlgebra +using Base.Broadcast: broadcasted +using PDMats +using BenchmarkTools +import Distributions: params + +using CUDA +using Magma +using Magma.LibMagma + +Magma.magma_init() +BATCHED_CACHE_VERBOSITY[] = :debug + +D_state = 3 +D_obs = 2 +N = 3 + +μs = BatchedCuVector(CUDA.randn(Float32, D_state, N)) +Σs_root = BatchedCuMatrix(CUDA.randn(Float32, D_state, D_state, N)) +Σs = Σs_root .* adjoint.(Σs_root) .+ Ref(I) + +As = BatchedCuMatrix(CUDA.randn(Float32, D_state, D_state, N)) +bs = BatchedCuVector(CUDA.randn(Float32, D_state, N)) +Q_root = 
BatchedCuMatrix(CUDA.randn(Float32, D_state, D_state, N)) +Qs = Q_root .* adjoint.(Q_root) .+ Ref(I) +Qs = PDMat.(Qs); + +Σ_PDs = broadcasted(PDMat, Σs); +Gs = MvNormal.(μs, Σ_PDs); + +# Observation parameters (H and c shared, R batched) +Hs = Shared(CUDA.randn(Float32, D_obs, D_state), N) +cs = Shared(CUDA.randn(Float32, D_obs), N) +Rs_root = BatchedCuMatrix(CUDA.randn(Float32, D_obs, D_obs, N)) +Rs = Rs_root .* adjoint.(Rs_root) .+ Ref(I) +Rs = PDMat.(Rs); + +# Observations +observations = BatchedCuVector(CUDA.randn(Float32, D_obs, N)) +jitter = 1.0f-6 + +dyn_params = tuple.(As, bs, Qs) +obs_params = tuple.(Hs, cs, Rs) + +# Dispatch-based jitter application (avoids if statements in batched code) +_maybe_apply_jitter(Σ, ::Nothing) = Σ +_maybe_apply_jitter(Σ, jitter::Real) = Σ + jitter * I + +function kalman_step(state, dyn_params, obs_params, observation, jitter) + μ, Σ = params(state) + A, b, Q = dyn_params + + μ = A * μ + b + Σ = X_A_Xt(Σ, A) + Q + + H, c, R = obs_params + + z = GF._compute_innovation(μ, H, c, observation) + S = GF._compute_innovation_cov(Σ, H, R) + K = GF._compute_kalman_gain(Σ, H, S) + _, Σ̂_raw = GF._compute_joseph_update(Σ, K, H, R) + + μ = μ + K * z + Σ = PDMat(_maybe_apply_jitter(Σ̂_raw, jitter)) + + ll = Distributions._logpdf(MvNormal(z, S), zero(z)) + + return MvNormal(μ, Σ), ll +end + +res = kalman_step.(Gs, dyn_params, obs_params, observations, Ref(jitter)) +println("\nFull type: ", typeof(res)) +println("\nElement type: ", eltype(res)) + +# Access second component of batched tuple (the log-likelihoods) +lls = res.components[2] diff --git a/research/batching/wrappers_demo.jl b/research/batching/wrappers_demo.jl new file mode 100644 index 00000000..6bbd6090 --- /dev/null +++ b/research/batching/wrappers_demo.jl @@ -0,0 +1,145 @@ +using GeneralisedFilters +using CUDA +using LinearAlgebra +using Base.Broadcast: broadcasted +using Distributions +using PDMats + +# ============================================================================= +# Configuration +# ============================================================================= + +N = 10 # batch size +D = 4 # matrix dimension + +println("Creating batched matrices...") +A = BatchedCuMatrix(CUDA.randn(Float32, D, D, N)); + +# ============================================================================= +# Test 1: Basic wrapper creation +# ============================================================================= + +println("\n=== Test 1: Basic wrapper creation ===\n") + +# Adjoint +A_adj = broadcasted(Adjoint, A); +println("Adjoint type: ", typeof(A_adj)) +println("Adjoint eltype: ", eltype(A_adj)) +println("First element type: ", typeof(A_adj[1])) + +# Transpose +A_trans = broadcasted(Transpose, A); +println("\nTranspose type: ", typeof(A_trans)) +println("Transpose eltype: ", eltype(A_trans)) + +# LowerTriangular +A_lower = broadcasted(LowerTriangular, A); +println("\nLowerTriangular type: ", typeof(A_lower)) +println("LowerTriangular eltype: ", eltype(A_lower)) + +# UpperTriangular +A_upper = broadcasted(UpperTriangular, A); +println("\nUpperTriangular type: ", typeof(A_upper)) +println("UpperTriangular eltype: ", eltype(A_upper)) + +# ============================================================================= +# Test 2: Function form redirects +# ============================================================================= + +println("\n=== Test 2: Function form redirects ===\n") + +A_adj2 = broadcasted(adjoint, A); +println("adjoint redirect type: ", typeof(A_adj2)) +println("Types match: ", 
typeof(A_adj) == typeof(A_adj2)) + +A_trans2 = broadcasted(transpose, A); +println("\ntranspose redirect type: ", typeof(A_trans2)) +println("Types match: ", typeof(A_trans) == typeof(A_trans2)) + +# ============================================================================= +# Test 3: Nested wrappers +# ============================================================================= + +println("\n=== Test 3: Nested wrappers ===\n") + +# Adjoint of LowerTriangular +A_lower_adj = broadcasted(Adjoint, A_lower); +println("Adjoint(LowerTriangular) type: ", typeof(A_lower_adj)) +println("Adjoint(LowerTriangular) eltype: ", eltype(A_lower_adj)) + +# ============================================================================= +# Test 4: Verify element access +# ============================================================================= + +println("\n=== Test 4: Element access verification ===\n") + +# Get first element of each wrapped array +println("A[1] type: ", typeof(A[1])) +println("A_adj[1] type: ", typeof(A_adj[1])) +println("A_lower[1] type: ", typeof(A_lower[1])) + +# Check values match +A_cpu = Array(A[1]); +A_adj_cpu = Array(parent(A_adj[1])); +println("\nValues match (A vs parent of A_adj): ", A_cpu ≈ A_adj_cpu) + +# ============================================================================= +# Test 5: SharedCuMatrix wrappers +# ============================================================================= + +println("\n=== Test 5: SharedCuMatrix wrappers ===\n") + +S = Shared(CUDA.randn(Float32, D, D), N); +S_adj = broadcasted(Adjoint, S); +println("SharedCuMatrix adjoint type: ", typeof(S_adj)) +println("SharedCuMatrix adjoint eltype: ", eltype(S_adj)) + +# ============================================================================= +# Test 6: Cholesky +# ============================================================================= + +println("\n=== Test 6: Batched Cholesky ===\n") + +# Create positive definite matrices: B = A * A' +B = A .* broadcasted(adjoint, A); +println("B type: ", typeof(B)) + +chol_result = cholesky.(B); +println("Cholesky result type: ", typeof(chol_result)) +println("Cholesky eltype: ", eltype(chol_result)) + +# Check fields +println("\nCholesky fields:") +println(" factors type: ", typeof(chol_result.factors)) +println(" uplo type: ", typeof(chol_result.uplo)) +println(" info type: ", typeof(chol_result.info)) + +# Access individual element +# SKIP: requires scalar indexing into info +# println("\nFirst element:") +# println(" chol_result[1] type: ", typeof(chol_result[1])) + +# ============================================================================= +# Test 7: PDMat wrapper +# ============================================================================= + +println("\n=== Test 7: PDMat wrapper ===\n") + +P = broadcasted(PDMat, B, chol_result); +println("PDMat result type: ", typeof(P)) +println("PDMat eltype: ", eltype(P)) + +println("\nPDMat fields:") +# println(" dim type: ", typeof(P.dim)) # SKIP: not a real field +println(" mat type: ", typeof(P.mat)) +println(" chol type: ", typeof(P.chol)) + +# ============================================================================= +# Test 8: MvNormal wrapper +# ============================================================================= + +println("\n=== Test 8: MvNormal wrapper ===\n") +μ = BatchedCuVector(CUDA.randn(Float32, D, N)); +G = broadcasted(MvNormal, μ, P); +println("MvNormal type: ", typeof(G)) +println("MvNormal eltype: ", eltype(G))
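# =============================================================================
# Test 9: Batched field access on MvNormal (sketch)
# =============================================================================

# A minimal sketch, assuming getproperty broadcasts over the batched MvNormal
# exactly as it does for pred_Gs.μ / pred_Gs.Σ.mat in batching_demo.jl; these
# lines are illustrative and were not run as part of this demo.
println("\n=== Test 9: Batched field access (sketch) ===\n")
println("G.μ type:     ", typeof(G.μ))        # expected: BatchedCuVector
println("G.Σ type:     ", typeof(G.Σ))        # expected: BatchedStruct{<:PDMat}
println("G.Σ.mat type: ", typeof(G.Σ.mat))    # expected: BatchedCuMatrix
println("First mean matches input μ: ", Array(G.μ[1]) ≈ Array(μ[1]))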