From 9e3498d90f3ff10841b015dd36a717cdd52f46c2 Mon Sep 17 00:00:00 2001 From: Brendan Roof Date: Tue, 9 Dec 2025 16:13:43 +0100 Subject: [PATCH 1/2] Add option to trigger multi-gpu mode --- tests/quick_test.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/quick_test.py b/tests/quick_test.py index 3ca5526..af8d94c 100644 --- a/tests/quick_test.py +++ b/tests/quick_test.py @@ -9,6 +9,7 @@ import logging from unittest.mock import patch +import pandas as pd from sklearn.datasets import load_breast_cancer, load_diabetes from sklearn.model_selection import train_test_split @@ -18,18 +19,27 @@ logging.basicConfig(level=logging.INFO) +def embiggen(x): + df = pd.DataFrame(x) + print(f"shape before: {df.shape}") + big = pd.concat([df] * 50, ignore_index=True) + print(f"shape after: {big.shape}") + return big if __name__ == "__main__": # Patch webbrowser.open to prevent browser login with patch("webbrowser.open", return_value=False): - use_server = True - # use_server = False + trigger_multi_gpu_threshold = False X, y = load_breast_cancer(return_X_y=True) X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.33, random_state=42 ) + if trigger_multi_gpu_threshold: + X_train = embiggen(X_train) + y_train = embiggen(y_train) + tabpfn = TabPFNClassifier.create_default_for_version( ModelVersion.V2_5, n_estimators=3 ) @@ -49,6 +59,10 @@ X, y, test_size=0.33, random_state=42 ) + if trigger_multi_gpu_threshold: + X_train = embiggen(X_train) + y_train = embiggen(y_train) + tabpfn = TabPFNRegressor.create_default_for_version( ModelVersion.V2_5, n_estimators=3 ) From 136c433f006686fffa75bd723244073573445e41 Mon Sep 17 00:00:00 2001 From: Brendan Roof Date: Tue, 9 Dec 2025 16:18:11 +0100 Subject: [PATCH 2/2] fix --- tests/quick_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/quick_test.py b/tests/quick_test.py index af8d94c..c092b48 100644 --- a/tests/quick_test.py +++ b/tests/quick_test.py @@ -29,7 +29,7 @@ def embiggen(x): if __name__ == "__main__": # Patch webbrowser.open to prevent browser login with patch("webbrowser.open", return_value=False): - trigger_multi_gpu_threshold = False + trigger_multi_gpu_threshold = True X, y = load_breast_cancer(return_X_y=True) X_train, X_test, y_train, y_test = train_test_split( @@ -38,7 +38,7 @@ def embiggen(x): if trigger_multi_gpu_threshold: X_train = embiggen(X_train) - y_train = embiggen(y_train) + y_train = embiggen(y_train).values.ravel() tabpfn = TabPFNClassifier.create_default_for_version( ModelVersion.V2_5, n_estimators=3 @@ -61,7 +61,7 @@ def embiggen(x): if trigger_multi_gpu_threshold: X_train = embiggen(X_train) - y_train = embiggen(y_train) + y_train = embiggen(y_train).values.ravel() tabpfn = TabPFNRegressor.create_default_for_version( ModelVersion.V2_5, n_estimators=3