diff --git a/src/Furnace.Core/Tensor.fs b/src/Furnace.Core/Tensor.fs
index d79270b0..2474cc79 100644
--- a/src/Furnace.Core/Tensor.fs
+++ b/src/Furnace.Core/Tensor.fs
@@ -792,35 +792,44 @@ type Tensor =
     /// The argument offset controls which diagonal to consider.
     ///
     member a.diagonal(?offset:int, ?dim1:int, ?dim2:int) =
-        // TODO: The following can be slow, especially for reverse mode differentiation of the diagonal of a large tensor. Consider a faster implementation.
         if a.dim < 2 then failwithf "Tensor must be at least 2-dimensional"
         let offset = defaultArg offset 0
         let dim1 = defaultArg dim1 0
         let dim2 = defaultArg dim2 1
-        let mutable finished = false
-        let mutable d = []
-        let mutable i = 0
-        let mutable j = offset
-        while not finished do
-            if i >= a.shape[dim1] || j >= a.shape[dim2] then
-                finished <- true
-            elif j >= 0 then
-                // let bounds = array2D [[i0min; i0max; i0given]; [i1min; i1max; i1given]; [i2min; i2max; i2given]; [i3min; i3max; i3given]]
-                let bounds = Array2D.init (a.dim) 3 (fun ii jj ->
-                                                        if ii = dim1 then
-                                                            if jj < 2 then i else 1
-                                                        elif ii = dim2 then
-                                                            if jj < 2 then j else 1
-                                                        else
-                                                            if jj = 0 then 0
-                                                            elif jj = 1 then a.shape[ii]-1
-                                                            else 0
-                                                    )
-                d <- [a.GetSlice(bounds)] |> List.append d
-            i <- i + 1
-            j <- j + 1
-        if d |> List.isEmpty then failwithf "Empty diagonal"
-        Tensor.stack(d)
+
+        // A diagonal needs two distinct dimensions; with dim1 = dim2 both index writes below would hit the same bounds row.
+        if dim1 = dim2 then failwithf "Expecting dim1 (%d) and dim2 (%d) to be different" dim1 dim2
+
+        // Diagonal element k sits at index (startI + k) along dim1 and (startJ + k) along dim2:
+        // a positive offset shifts the start along dim2, a negative offset shifts it along dim1.
+        let startI = max 0 (-offset)
+        let startJ = max 0 offset
+
+        // Number of diagonal elements that fit inside both dimensions from the start position.
+        let diagSize = max 0 (min (a.shape[dim1] - startI) (a.shape[dim2] - startJ))
+        if diagSize = 0 then failwithf "Empty diagonal"
+
+        // Slice bounds, one row per dimension: [| min index; max index; given flag |].
+        // Rows other than dim1/dim2 span the whole dimension with flag 0 and never change.
+        // Rows dim1/dim2 carry flag 1; the flag is loop-invariant, so it is written once here.
+        // Array2D.create already zero-fills, so only the non-zero entries are written.
+        let bounds = Array2D.create a.dim 3 0
+        for ii = 0 to a.dim - 1 do
+            if ii = dim1 || ii = dim2 then
+                bounds[ii, 2] <- 1
+            else
+                bounds[ii, 1] <- a.shape[ii] - 1
+
+        // NOTE(review): the same mutable bounds array is passed to every GetSlice call, which
+        // assumes GetSlice does not keep a reference to it (e.g. for reverse mode) - confirm.
+        let diagonalElements =
+            Array.init diagSize (fun k ->
+                bounds[dim1, 0] <- startI + k
+                bounds[dim1, 1] <- startI + k
+                bounds[dim2, 0] <- startJ + k
+                bounds[dim2, 1] <- startJ + k
+                a.GetSlice(bounds))
+
+        Tensor.stack(diagonalElements)
 
     /// Returns the sum of the elements of the diagonal of the input 2-D matrix.
     member a.trace() = let d:Tensor = a.diagonal() in d.sum()