From f8f9e18f1c24f85dda344cb073a11ad731e545ed Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Thu, 5 Mar 2026 11:42:53 -0600 Subject: [PATCH] Revert "Fix HIP memory leaks in RNN kernels (#726)" This reverts commit 48d2ef1435bddcd0d690b1686b47db1af8018913. --- jaxlib/gpu/rnn_kernels.cc | 11 ----------- jaxlib/gpu/vendor.h | 1 - 2 files changed, 12 deletions(-) diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index 5020c9a1d36f..006b3d1150e3 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -172,9 +172,6 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size, JAX_RETURN_IF_ERROR( JAX_AS_STATUS(gpudnnDestroyRNNDataDescriptor(input_data_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyRNNDescriptor(rnn_desc))); -#ifdef JAX_GPU_HIP - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuFree(dropout_states_dev))); -#endif // Round up to nearest multiples of 4 so we can return them as f32 arrays. workSpaceSize += (workSpaceSize % 4); @@ -354,10 +351,6 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers, JAX_RETURN_IF_ERROR( JAX_AS_STATUS(gpudnnDestroyDropoutDescriptor(dropout_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyRNNDescriptor(rnn_desc))); -#ifdef JAX_GPU_HIP - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuFree(dropout_states_dev))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyTensorDescriptor(input_tensor_desc))); -#endif return absl::OkStatus(); } @@ -543,10 +536,6 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, JAX_RETURN_IF_ERROR( JAX_AS_STATUS(gpudnnDestroyDropoutDescriptor(dropout_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyRNNDescriptor(rnn_desc))); -#ifdef JAX_GPU_HIP - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuFree(dropout_states_dev))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyTensorDescriptor(input_tensor_desc))); -#endif return absl::OkStatus(); } diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index becfb049de88..d93a488067bd 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -774,7 +774,6 @@ inline hipsparseStatus_t gpusparseCreate(gpusparseHandle_t* handle) { #define GPU_STREAM_NON_BLOCKING hipStreamNonBlocking #define gpuMalloc hipMalloc -#define gpuFree hipFree #define gpuGetLastError hipExtGetLastError #define gpuGetErrorString hipGetErrorString #define gpuMemcpyAsync hipMemcpyAsync