diff --git a/src/Furnace.Backends.Torch/Torch.RawTensor.fs b/src/Furnace.Backends.Torch/Torch.RawTensor.fs
index fd1dad55..9708091e 100644
--- a/src/Furnace.Backends.Torch/Torch.RawTensor.fs
+++ b/src/Furnace.Backends.Torch/Torch.RawTensor.fs
@@ -1115,19 +1115,34 @@ type TorchRawTensor(tt: torch.Tensor, shape: Shape, dtype: Dtype, device: Device
         checkMutable()
         tt.add_(toTorchScalar t2) |> ignore
 
-    // TODO - this should be faster
+    // Optimized AddSliceInPlace - reduced allocations and conversions
     override t1.AddSliceInPlace(location, t2) =
         checkMutable()
         Shape.checkCanAddSlice t1.Shape location t2.Shape
         let shape1 = t1.Shape
         let shape2 = t2.Shape
        let expandedShape2 = Shape.unsqueezeAs shape2 shape1
-        let t2Expanded = t2.TorchTensor.expand(toTorchShape expandedShape2)
+
+        // Pre-compute torch shape to avoid repeated conversions
+        let torchExpandedShape2 =
+            let result = Array.zeroCreate expandedShape2.Length
+            for i = 0 to expandedShape2.Length - 1 do
+                result[i] <- int64 expandedShape2[i]
+            result
+
+        let t2Expanded = t2.TorchTensor.expand(torchExpandedShape2)
         let mutable t1Slice = tt // will share memory with res
+
+        // Optimize the slicing loop - cache shape values and reduce conditional checks
         for d in 0 .. location.Length - 1 do
+            let locationD = location[d]
             let len2 = expandedShape2[d]
-            if location[d] <> 0 || len2 <> shape1[d] then
-                t1Slice <- t1Slice.narrow(int64 d, int64 location[d], int64 len2)
+            let shape1D = shape1[d]
+
+            // Only narrow if we're not accessing the full dimension
+            if locationD <> 0 || len2 <> shape1D then
+                t1Slice <- t1Slice.narrow(int64 d, int64 locationD, int64 len2)
+
         t1Slice.add_(t2Expanded) |> ignore
 
     override _.SubInPlace(t2) = checkMutable(); tt.sub_(t2.TorchTensor) |> ignore
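
For context (not part of the patch): below is a minimal, self-contained F# sketch of the narrow/add_ pattern that AddSliceInPlace relies on, written against TorchSharp directly rather than through the Furnace RawTensor layer. The tensor sizes, the demoAddSlice name, and the omission of the expand step (the source already has the destination's rank) are illustrative assumptions only.

    // Sketch only: add a small tensor into a slice of a larger tensor in place,
    // using TorchSharp's narrow (a memory-sharing view) followed by add_.
    open TorchSharp

    let demoAddSlice () =
        // A 4x5 destination and a 2x3 block to add at offset (1, 2).
        let dst = torch.zeros(4L, 5L)
        let src = torch.ones(2L, 3L)
        let location = [| 1; 2 |]
        // Narrow dimension by dimension; each narrow returns a view sharing
        // memory with dst, so the final add_ mutates dst itself.
        let mutable view = dst
        for d in 0 .. location.Length - 1 do
            let len = src.shape[d]
            // Skip the narrow when the slice covers the full dimension.
            if location[d] <> 0 || len <> dst.shape[d] then
                view <- view.narrow(int64 d, int64 location[d], len)
        view.add_(src) |> ignore
        dst  // rows 1..2, columns 2..4 now hold 1.0

Because narrow only creates views, the same pattern the patched code uses avoids copying the destination; the only allocation beyond the views is the int64 shape array passed to expand.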