From 7a2f20e9fef9cf283a72ce6c9c78a9bb68a48dce Mon Sep 17 00:00:00 2001 From: Bastian Grumbrecht Date: Sat, 27 Dec 2025 21:36:18 +0100 Subject: [PATCH] feat:add Narwhals utilities for horizontal statistical operations. --- src/centimators/narwhals_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/centimators/narwhals_utils.py b/src/centimators/narwhals_utils.py index b506710..1848900 100644 --- a/src/centimators/narwhals_utils.py +++ b/src/centimators/narwhals_utils.py @@ -13,7 +13,7 @@ def _ensure_numpy(data, allow_series: bool = False): """Convert data to numpy array, handling both numpy arrays and dataframes. Args: - data: Input data (numpy array, dataframe, or series) + data: Input data (numpy array, dataframe, series, or PyTorch tensor) allow_series: Whether to allow series inputs Returns: @@ -24,6 +24,14 @@ def _ensure_numpy(data, allow_series: bool = False): try: return nw.from_native(data, allow_series=allow_series).to_numpy() except Exception: + # Handle PyTorch tensors (including CUDA tensors) + try: + import torch + if isinstance(data, torch.Tensor): + # Move to CPU if on GPU, then convert to numpy + return data.detach().cpu().numpy() + except ImportError: + pass return numpy.asarray(data)