diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py
index 2ce015b36..e1ebf799e 100644
--- a/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/tests/function_libs/torch_lib/extra_opinfo.py
@@ -2270,6 +2270,16 @@ def sample_inputs_unique_dim(op_info, device, dtype, requires_grad, **kwargs):
         yield sample
 
 
+def sample_inputs_unique_consecutive(op_info, device, dtype, requires_grad, **kwargs):
+    for sample in common_methods_invocations.sample_inputs_unique(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # unique_consecutive only supports dim=None or (dim=0 with rank=1)
+        # So filter out samples with dim != None
+        if sample.kwargs.get("dim") is None:
+            yield sample
+
+
 def sample_inputs_upsample_trilinear3d_vec(op_info, device, dtype, requires_grad, **kwargs):
     del op_info
     del kwargs
@@ -2878,6 +2888,14 @@ def __init__(self):
         supports_out=False,
         supports_autograd=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten.unique_consecutive",
+        aten_name="unique_consecutive",
+        dtypes=common_dtype.integral_types(),
+        sample_inputs_func=sample_inputs_unique_consecutive,
+        supports_out=False,
+        supports_autograd=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten.upsample_bicubic2d.default",
         aten_name="upsample_bicubic2d",
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index f6ce0f517..46b4958f8 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1865,6 +1865,7 @@ def _where_input_wrangler(
             "Our implementation is based on that for CUDA"
         ),
     ),
+    TorchLibOpInfo("ops.aten.unique_consecutive", core_ops.aten_unique_consecutive),
     TorchLibOpInfo("ops.prims.broadcast_in_dim.default", prims_ops.prims_broadcast_in_dim),
     TorchLibOpInfo(
         "ops.prims.var.default", prims_ops.prims_var, tolerance={torch.float16: (1e-3, 5e-2)}