From 03a0e1383aefd05e48dc4185071503a8ad979285 Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Thu, 5 Mar 2026 15:26:37 +0000 Subject: [PATCH] Fix HIP memory leaks in RNN kernels Free dropout_states_dev GPU memory allocated via gpuMalloc in the HIP paths of DoRnnComputeWorkspaceReserveSpaceSizes, DnnRNNForward_, and DnnRNNBackward_. Also destroy the leaked miopenTensorDescriptor in the forward and backward functions. --- jaxlib/gpu/rnn_kernels.cc | 9 +++++++++ jaxlib/gpu/vendor.h | 1 + 2 files changed, 10 insertions(+) diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index 006b3d1150e3..52736202f015 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -172,6 +172,9 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size, JAX_RETURN_IF_ERROR( JAX_AS_STATUS(gpudnnDestroyRNNDataDescriptor(input_data_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyRNNDescriptor(rnn_desc))); +#ifdef JAX_GPU_HIP + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuFree(dropout_states_dev))); +#endif // Round up to nearest multiples of 4 so we can return them as f32 arrays. workSpaceSize += (workSpaceSize % 4); @@ -351,6 +354,9 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers, JAX_RETURN_IF_ERROR( JAX_AS_STATUS(gpudnnDestroyDropoutDescriptor(dropout_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyRNNDescriptor(rnn_desc))); +#ifdef JAX_GPU_HIP + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuFree(dropout_states_dev))); +#endif return absl::OkStatus(); } @@ -536,6 +542,9 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, JAX_RETURN_IF_ERROR( JAX_AS_STATUS(gpudnnDestroyDropoutDescriptor(dropout_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyRNNDescriptor(rnn_desc))); +#ifdef JAX_GPU_HIP + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuFree(dropout_states_dev))); +#endif return absl::OkStatus(); } diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index d93a488067bd..becfb049de88 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -774,6 +774,7 @@ inline hipsparseStatus_t gpusparseCreate(gpusparseHandle_t* handle) { #define GPU_STREAM_NON_BLOCKING hipStreamNonBlocking #define gpuMalloc hipMalloc +#define gpuFree hipFree #define gpuGetLastError hipExtGetLastError #define gpuGetErrorString hipGetErrorString #define gpuMemcpyAsync hipMemcpyAsync