From 73744d07b5078d4820dc838cfd8d3a224dab9e6f Mon Sep 17 00:00:00 2001 From: Manjunath Gaonkar Date: Thu, 5 Mar 2026 17:11:53 +0000 Subject: [PATCH] Remove premature input_tensor_desc destruction in RNN kernels The input_tensor_desc is still needed by MIOpen for the execution buffer and must not be destroyed before the RNN operation completes. --- jaxlib/gpu/rnn_kernels.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index 5020c9a1d36f..52736202f015 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -356,7 +356,6 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyRNNDescriptor(rnn_desc))); #ifdef JAX_GPU_HIP JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuFree(dropout_states_dev))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyTensorDescriptor(input_tensor_desc))); #endif return absl::OkStatus(); @@ -545,7 +544,6 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyRNNDescriptor(rnn_desc))); #ifdef JAX_GPU_HIP JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuFree(dropout_states_dev))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyTensorDescriptor(input_tensor_desc))); #endif return absl::OkStatus();