From ee45e2c7bec6cb9810724c0b75e436e90660bf14 Mon Sep 17 00:00:00 2001 From: Brendan Roof Date: Mon, 17 Nov 2025 23:56:50 +0100 Subject: [PATCH] wip --- src/tabpfn_client/client.py | 25 ++++++++++++++++++------- src/tabpfn_client/estimator.py | 12 ++++++++++-- src/tabpfn_client/server_config.yaml | 18 +++++++++--------- src/tabpfn_client/service_wrapper.py | 6 ++++-- tests/unit/test_client.py | 12 ++++++++---- 5 files changed, 49 insertions(+), 24 deletions(-) diff --git a/src/tabpfn_client/client.py b/src/tabpfn_client/client.py index 30712f9..cfabb01 100644 --- a/src/tabpfn_client/client.py +++ b/src/tabpfn_client/client.py @@ -207,6 +207,7 @@ class FitCompleteEvent(BaseModel): status: Literal["complete"] = "complete" # Identifies the dataset and thus the model train_set_uid: str + model_id: Optional[str] = None class FitErrorEvent(BaseModel): @@ -225,6 +226,11 @@ class FitErrorEvent(BaseModel): ResponseEventAdapter = TypeAdapter(ResponseEvents) +class FittedModel(BaseModel): + train_set_uid: str + model_id: Optional[str] + + @dataclass(frozen=True) class PredictionResult: y_pred: Union[np.ndarray, list[np.ndarray], dict[str, np.ndarray]] @@ -324,7 +330,7 @@ def _build_tabpfn_params(tabpfn_config: Union[dict, None]) -> dict: params : dict Dictionary containing 'tabpfn_systems' and optionally 'tabpfn_config'. """ - _, tabpfn_systems, processed_config, _ = ServiceClient._process_tabpfn_config(tabpfn_config) + _, tabpfn_systems, processed_config, tabpfnr_params = ServiceClient._process_tabpfn_config(tabpfn_config) params = { "tabpfn_systems": json.dumps(tabpfn_systems) @@ -334,7 +340,9 @@ def _build_tabpfn_params(tabpfn_config: Union[dict, None]) -> dict: params["tabpfn_config"] = json.dumps( processed_config, default=lambda x: x.to_dict() ) - + if tabpfnr_params is not None: + params["tabpfnr_params"] = json.dumps(tabpfnr_params) + return params @classmethod @@ -379,7 +387,7 @@ def fit( tabpfn_config: Union[dict, None] = None, task: Optional[Literal["classification", "regression"]] = None, description: str = "", - ) -> str: + ) -> FittedModel: """ Upload a train set to server and return the train set UID if successful. @@ -447,8 +455,6 @@ def fit( query_params["task"] = task else: assert model_type != ModelType.TABPFN_R, "Thinking mode requires a task." - if tabpfnr_params: - form_data["tabpfnr_params"] = json.dumps(tabpfnr_params) with cls.httpx_client.stream( "POST", @@ -498,6 +504,7 @@ def fit( ) elif isinstance(event, FitCompleteEvent): train_set_uid = event.train_set_uid + model_id = event.model_id elif isinstance(event, FitErrorEvent): # Handle structured error events raise RuntimeError(f"Error from server: {event.message}") @@ -509,7 +516,7 @@ def fit( raise RuntimeError("Error during fit. No valid model received.") cls.dataset_uid_cache_manager.add_dataset_uid(dataset_hash, train_set_uid) - return train_set_uid + return FittedModel(train_set_uid=train_set_uid, model_id=model_id) @classmethod @backoff.on_exception( @@ -535,6 +542,7 @@ def predict( x_test, model_type: ModelType, task: Literal["classification", "regression"], + model_id: Optional[str] = None, predict_params: Union[dict, None] = None, tabpfn_config: Union[dict, None] = None, X_train=None, @@ -563,6 +571,7 @@ def predict( params = cls._build_tabpfn_params(tabpfn_config) params.update({ "train_set_uid": train_set_uid, + "model_id": model_id, "task": task, "predict_params": json.dumps(predict_params), }) @@ -578,6 +587,7 @@ def predict( ) = cls.dataset_uid_cache_manager.get_dataset_uid( x_test_serialized, train_set_uid, + model_id, # Note: This may not be strictly necessary for predict, but keep it for safety. model_type.name, cls._access_token, @@ -664,13 +674,14 @@ def run_progress(): raise NotImplementedError( "Automatically re-uploading the train set is not supported for thinking mode. Please call fit()." ) - train_set_uid = cls.fit( + fitted_model = cls.fit( X_train, y_train, model_type=model_type, tabpfn_config=tabpfn_config, task=task ) + train_set_uid = fitted_model.train_set_uid params["train_set_uid"] = train_set_uid cached_test_set_uid = None else: diff --git a/src/tabpfn_client/estimator.py b/src/tabpfn_client/estimator.py index 757eb72..6b09f26 100644 --- a/src/tabpfn_client/estimator.py +++ b/src/tabpfn_client/estimator.py @@ -221,6 +221,7 @@ def __init__( self.inference_config = inference_config self.paper_version = paper_version self.last_train_set_uid = None + self.last_model_id = None self.last_train_X = None self.last_train_y = None self.thinking = thinking @@ -240,7 +241,7 @@ def fit(self, X, y, description: str = ""): estimator_param = self._get_estimator_params_with_model_path("classification") if Config.use_server: model_type = ModelType.TABPFN_R if self.thinking else ModelType.TABPFN - self.last_train_set_uid = InferenceClient.fit( + fitted_model = InferenceClient.fit( X, y, tabpfn_config=estimator_param, @@ -248,6 +249,8 @@ def fit(self, X, y, description: str = ""): task="classification", description=description, ) + self.last_train_set_uid = fitted_model.train_set_uid + self.last_model_id = fitted_model.model_id self.last_train_X = X self.last_train_y = y self.fitted_ = True @@ -294,6 +297,7 @@ def predict_task() -> PredictionResult: model_type=model_type, task="classification", train_set_uid=self.last_train_set_uid, + model_id=self.last_model_id, tabpfn_config=estimator_param, predict_params={"output_type": output_type}, X_train=self.last_train_X, @@ -416,6 +420,7 @@ def __init__( self.thinking = thinking self.thinking_params = thinking_params self.last_train_set_uid = None + self.last_model_id = None self.last_train_X = None self.last_train_y = None self.last_meta = {} @@ -433,7 +438,7 @@ def fit(self, X, y, description: str = ""): estimator_param = self._get_estimator_params_with_model_path("regression") if Config.use_server: model_type = ModelType.TABPFN_R if self.thinking else ModelType.TABPFN - self.last_train_set_uid = InferenceClient.fit( + fitted_model = InferenceClient.fit( X, y, tabpfn_config=estimator_param, @@ -441,6 +446,8 @@ def fit(self, X, y, description: str = ""): task="regression", description=description, ) + self.last_train_set_uid = fitted_model.train_set_uid + self.last_model_id = fitted_model.model_id self.last_train_X = X self.last_train_y = y self.fitted_ = True @@ -502,6 +509,7 @@ def predict_task() -> PredictionResult: model_type=model_type, task="regression", train_set_uid=self.last_train_set_uid, + model_id=self.last_model_id, tabpfn_config=estimator_param, predict_params=predict_params, X_train=self.last_train_X, diff --git a/src/tabpfn_client/server_config.yaml b/src/tabpfn_client/server_config.yaml index d553f7d..a75d65b 100644 --- a/src/tabpfn_client/server_config.yaml +++ b/src/tabpfn_client/server_config.yaml @@ -1,12 +1,12 @@ -# # local testing -# protocol: "http" -# host: "localhost" -# port: "8080" - -# production -protocol: "https" -host: "api.priorlabs.ai" -port: "443" +# local testing +protocol: "http" +host: "localhost" +port: "8080" + +## production +#protocol: "https" +#host: "api.priorlabs.ai" +#port: "443" gui_url: "https://ux.priorlabs.ai" endpoints: diff --git a/src/tabpfn_client/service_wrapper.py b/src/tabpfn_client/service_wrapper.py index 9ce0af0..980f0b8 100644 --- a/src/tabpfn_client/service_wrapper.py +++ b/src/tabpfn_client/service_wrapper.py @@ -8,7 +8,7 @@ import numpy as np -from tabpfn_client.client import ServiceClient, ModelType +from tabpfn_client.client import FittedModel, ServiceClient, ModelType from tabpfn_client.constants import CACHE_DIR from tabpfn_common_utils.utils import Singleton @@ -254,7 +254,7 @@ def fit( tabpfn_config=None, task: Optional[Literal["classification", "regression"]] = None, description: str = "", - ) -> str: + ) -> FittedModel: return ServiceClient.fit( X, y, @@ -270,6 +270,7 @@ def predict( X, task: Literal["classification", "regression"], train_set_uid: str, + model_id: Optional[str] = None, model_type: ModelType = ModelType.TABPFN, tabpfn_config=None, predict_params=None, @@ -280,6 +281,7 @@ def predict( train_set_uid=train_set_uid, x_test=X, tabpfn_config=tabpfn_config, + model_id=model_id, predict_params=predict_params, task=task, X_train=X_train, diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 66433c3..a21535b 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -278,14 +278,16 @@ def test_fit_with_caching(self, mock_server): mock_stream.return_value = mock_response # First upload - train_set_uid1 = ServiceClient.fit( + fitted_model1 = ServiceClient.fit( self.X_train, self.y_train, model_type=ModelType.TABPFN ) + train_set_uid1 = fitted_model1.train_set_uid # Second upload with the same data - train_set_uid2 = ServiceClient.fit( + fitted_model2 = ServiceClient.fit( self.X_train, self.y_train, model_type=ModelType.TABPFN ) + train_set_uid2 = fitted_model2.train_set_uid # The train_set_uid should be the same due to caching self.assertEqual(train_set_uid1, train_set_uid2) @@ -348,9 +350,10 @@ def side_effect(*args, **kwargs): mock_stream.side_effect = side_effect # Upload train set - train_set_uid = ServiceClient.fit( + fitted_model = ServiceClient.fit( self.X_train, self.y_train, model_type=ModelType.TABPFN ) + train_set_uid = fitted_model.train_set_uid # First prediction pred1 = ServiceClient.predict( @@ -452,9 +455,10 @@ def side_effect_counter(*args, **kwargs): mock_stream.side_effect = side_effect_counter # Upload train set - train_set_uid = ServiceClient.fit( + fitted_model = ServiceClient.fit( self.X_train, self.y_train, model_type=ModelType.TABPFN ) + train_set_uid = fitted_model.train_set_uid # Attempt prediction, which should fail and trigger retry pred = ServiceClient.predict(