diff --git a/tests/quick_test.py b/tests/quick_test.py index 3ca5526..c092b48 100644 --- a/tests/quick_test.py +++ b/tests/quick_test.py @@ -9,6 +9,7 @@ import logging from unittest.mock import patch +import pandas as pd from sklearn.datasets import load_breast_cancer, load_diabetes from sklearn.model_selection import train_test_split @@ -18,18 +19,27 @@ logging.basicConfig(level=logging.INFO) +def embiggen(x): + df = pd.DataFrame(x) + print(f"shape before: {df.shape}") + big = pd.concat([df] * 50, ignore_index=True) + print(f"shape after: {big.shape}") + return big if __name__ == "__main__": # Patch webbrowser.open to prevent browser login with patch("webbrowser.open", return_value=False): - use_server = True - # use_server = False + trigger_multi_gpu_threshold = True X, y = load_breast_cancer(return_X_y=True) X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.33, random_state=42 ) + if trigger_multi_gpu_threshold: + X_train = embiggen(X_train) + y_train = embiggen(y_train).values.ravel() + tabpfn = TabPFNClassifier.create_default_for_version( ModelVersion.V2_5, n_estimators=3 ) @@ -49,6 +59,10 @@ X, y, test_size=0.33, random_state=42 ) + if trigger_multi_gpu_threshold: + X_train = embiggen(X_train) + y_train = embiggen(y_train).values.ravel() + tabpfn = TabPFNRegressor.create_default_for_version( ModelVersion.V2_5, n_estimators=3 )