diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index ea92dc347d..037f69feb4 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -14,7 +14,7 @@ from typing import Optional, Sequence -from onnxscript import INT64 +from onnxscript import INT64, ir from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import TFloat from onnxscript.onnx_opset import opset18 as op @@ -118,12 +118,18 @@ def aten__fft_c2r( # Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed # into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we # place no such restriction on the ONNX side. - transformed = op.DFT( - transformed, - dft_length=last_dim_size, - axis=dimension, - inverse=True, - onesided=False, + scale = (op.CastLike(last_dim_size, self)) / op.CastLike( + op.Shape(transformed, start=dimension, end=dimension + 1), self + ) + transformed = ( + op.DFT( + transformed, + dft_length=last_dim_size, + axis=dimension, + inverse=True, + onesided=False, + ) + * scale ) transformed = _fftn_onnx_normalization( transformed,