Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions src/tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ class FitCompleteEvent(BaseModel):
status: Literal["complete"] = "complete"
# Identifies the dataset and thus the model
train_set_uid: str
model_id: Optional[str] = None


class FitErrorEvent(BaseModel):
Expand All @@ -225,6 +226,11 @@ class FitErrorEvent(BaseModel):
ResponseEventAdapter = TypeAdapter(ResponseEvents)


class FittedModel(BaseModel):
    """Identifies a successfully fitted model on the server.

    Populated from the server's ``FitCompleteEvent`` after a fit stream
    completes.
    """

    # UID of the uploaded train set (also used as the dataset cache key).
    train_set_uid: str
    # Server-assigned model identifier. Defaults to None for consistency
    # with FitCompleteEvent.model_id, which the server may omit.
    model_id: Optional[str] = None


@dataclass(frozen=True)
class PredictionResult:
y_pred: Union[np.ndarray, list[np.ndarray], dict[str, np.ndarray]]
Expand Down Expand Up @@ -324,7 +330,7 @@ def _build_tabpfn_params(tabpfn_config: Union[dict, None]) -> dict:
params : dict
Dictionary containing 'tabpfn_systems' and optionally 'tabpfn_config'.
"""
_, tabpfn_systems, processed_config, _ = ServiceClient._process_tabpfn_config(tabpfn_config)
_, tabpfn_systems, processed_config, tabpfnr_params = ServiceClient._process_tabpfn_config(tabpfn_config)

params = {
"tabpfn_systems": json.dumps(tabpfn_systems)
Expand All @@ -334,7 +340,9 @@ def _build_tabpfn_params(tabpfn_config: Union[dict, None]) -> dict:
params["tabpfn_config"] = json.dumps(
processed_config, default=lambda x: x.to_dict()
)

if tabpfnr_params is not None:
params["tabpfnr_params"] = json.dumps(tabpfnr_params)

return params

@classmethod
Expand Down Expand Up @@ -379,7 +387,7 @@ def fit(
tabpfn_config: Union[dict, None] = None,
task: Optional[Literal["classification", "regression"]] = None,
description: str = "",
) -> str:
) -> FittedModel:
"""
Upload a train set to server and return the train set UID if successful.

Expand Down Expand Up @@ -447,8 +455,6 @@ def fit(
query_params["task"] = task
else:
assert model_type != ModelType.TABPFN_R, "Thinking mode requires a task."
if tabpfnr_params:
form_data["tabpfnr_params"] = json.dumps(tabpfnr_params)

with cls.httpx_client.stream(
"POST",
Expand Down Expand Up @@ -498,6 +504,7 @@ def fit(
)
elif isinstance(event, FitCompleteEvent):
train_set_uid = event.train_set_uid
model_id = event.model_id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The variable model_id is assigned here within an elif block. However, it is not initialized before the loop (lines 474-511). If a FitCompleteEvent is not received from the server, the subsequent reference to model_id on line 519 will raise an UnboundLocalError. To prevent this, please initialize model_id to None before the try block, similar to how train_set_uid is initialized.

elif isinstance(event, FitErrorEvent):
# Handle structured error events
raise RuntimeError(f"Error from server: {event.message}")
Expand All @@ -509,7 +516,7 @@ def fit(
raise RuntimeError("Error during fit. No valid model received.")

cls.dataset_uid_cache_manager.add_dataset_uid(dataset_hash, train_set_uid)
return train_set_uid
return FittedModel(train_set_uid=train_set_uid, model_id=model_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this break backwards compatibility for anyone using the train_set_uid?


@classmethod
@backoff.on_exception(
Expand All @@ -535,6 +542,7 @@ def predict(
x_test,
model_type: ModelType,
task: Literal["classification", "regression"],
model_id: Optional[str] = None,
predict_params: Union[dict, None] = None,
tabpfn_config: Union[dict, None] = None,
X_train=None,
Expand Down Expand Up @@ -563,6 +571,7 @@ def predict(
params = cls._build_tabpfn_params(tabpfn_config)
params.update({
"train_set_uid": train_set_uid,
"model_id": model_id,
"task": task,
"predict_params": json.dumps(predict_params),
})
Expand All @@ -578,6 +587,7 @@ def predict(
) = cls.dataset_uid_cache_manager.get_dataset_uid(
x_test_serialized,
train_set_uid,
model_id,
# Note: This may not be strictly necessary for predict, but keep it for safety.
model_type.name,
cls._access_token,
Expand Down Expand Up @@ -664,13 +674,14 @@ def run_progress():
raise NotImplementedError(
"Automatically re-uploading the train set is not supported for thinking mode. Please call fit()."
)
train_set_uid = cls.fit(
fitted_model = cls.fit(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please share more context on why we re-fit it again in predict?

X_train,
y_train,
model_type=model_type,
tabpfn_config=tabpfn_config,
task=task
)
train_set_uid = fitted_model.train_set_uid
params["train_set_uid"] = train_set_uid
Comment on lines +684 to 685
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When retrying a prediction after a UID error, a new model is fitted, but only the train_set_uid is updated in the params for the next attempt. The new model_id from fitted_model is ignored. This could lead to using a stale model_id in the retried prediction request. You should also update the model_id in the params dictionary.

Suggested change
train_set_uid = fitted_model.train_set_uid
params["train_set_uid"] = train_set_uid
train_set_uid = fitted_model.train_set_uid
params["train_set_uid"] = train_set_uid
params["model_id"] = fitted_model.model_id

cached_test_set_uid = None
else:
Expand Down
12 changes: 10 additions & 2 deletions src/tabpfn_client/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def __init__(
self.inference_config = inference_config
self.paper_version = paper_version
self.last_train_set_uid = None
self.last_model_id = None
self.last_train_X = None
self.last_train_y = None
self.thinking = thinking
Expand All @@ -240,14 +241,16 @@ def fit(self, X, y, description: str = ""):
estimator_param = self._get_estimator_params_with_model_path("classification")
if Config.use_server:
model_type = ModelType.TABPFN_R if self.thinking else ModelType.TABPFN
self.last_train_set_uid = InferenceClient.fit(
fitted_model = InferenceClient.fit(
X,
y,
tabpfn_config=estimator_param,
model_type=model_type,
task="classification",
description=description,
)
self.last_train_set_uid = fitted_model.train_set_uid
self.last_model_id = fitted_model.model_id
self.last_train_X = X
self.last_train_y = y
self.fitted_ = True
Expand Down Expand Up @@ -294,6 +297,7 @@ def predict_task() -> PredictionResult:
model_type=model_type,
task="classification",
train_set_uid=self.last_train_set_uid,
model_id=self.last_model_id,
tabpfn_config=estimator_param,
predict_params={"output_type": output_type},
X_train=self.last_train_X,
Expand Down Expand Up @@ -416,6 +420,7 @@ def __init__(
self.thinking = thinking
self.thinking_params = thinking_params
self.last_train_set_uid = None
self.last_model_id = None
self.last_train_X = None
self.last_train_y = None
self.last_meta = {}
Expand All @@ -433,14 +438,16 @@ def fit(self, X, y, description: str = ""):
estimator_param = self._get_estimator_params_with_model_path("regression")
if Config.use_server:
model_type = ModelType.TABPFN_R if self.thinking else ModelType.TABPFN
self.last_train_set_uid = InferenceClient.fit(
fitted_model = InferenceClient.fit(
X,
y,
tabpfn_config=estimator_param,
model_type=model_type,
task="regression",
description=description,
)
self.last_train_set_uid = fitted_model.train_set_uid
self.last_model_id = fitted_model.model_id
self.last_train_X = X
self.last_train_y = y
self.fitted_ = True
Expand Down Expand Up @@ -502,6 +509,7 @@ def predict_task() -> PredictionResult:
model_type=model_type,
task="regression",
train_set_uid=self.last_train_set_uid,
model_id=self.last_model_id,
tabpfn_config=estimator_param,
predict_params=predict_params,
X_train=self.last_train_X,
Expand Down
18 changes: 9 additions & 9 deletions src/tabpfn_client/server_config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# # local testing
# protocol: "http"
# host: "localhost"
# port: "8080"

# production
protocol: "https"
host: "api.priorlabs.ai"
port: "443"
# local testing (uncomment for development against a local server)
# protocol: "http"
# host: "localhost"
# port: "8080"

# production
protocol: "https"
host: "api.priorlabs.ai"
port: "443"
Comment on lines +1 to +9
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This change switches the server configuration from production to a local testing environment. While this is useful for development, it should be reverted before merging to avoid pointing the client to a local server in a production release. Please ensure the production configuration is active in the final version of this pull request.

# # local testing
# protocol: "http"
# host: "localhost"
# port: "8080"

# production
protocol: "https"
host: "api.priorlabs.ai"
port: "443"


gui_url: "https://ux.priorlabs.ai"
endpoints:
Expand Down
6 changes: 4 additions & 2 deletions src/tabpfn_client/service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np

from tabpfn_client.client import ServiceClient, ModelType
from tabpfn_client.client import FittedModel, ServiceClient, ModelType
from tabpfn_client.constants import CACHE_DIR
from tabpfn_common_utils.utils import Singleton

Expand Down Expand Up @@ -254,7 +254,7 @@ def fit(
tabpfn_config=None,
task: Optional[Literal["classification", "regression"]] = None,
description: str = "",
) -> str:
) -> FittedModel:
return ServiceClient.fit(
X,
y,
Expand All @@ -270,6 +270,7 @@ def predict(
X,
task: Literal["classification", "regression"],
train_set_uid: str,
model_id: Optional[str] = None,
model_type: ModelType = ModelType.TABPFN,
tabpfn_config=None,
predict_params=None,
Expand All @@ -280,6 +281,7 @@ def predict(
train_set_uid=train_set_uid,
x_test=X,
tabpfn_config=tabpfn_config,
model_id=model_id,
predict_params=predict_params,
task=task,
X_train=X_train,
Expand Down
12 changes: 8 additions & 4 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,14 +278,16 @@ def test_fit_with_caching(self, mock_server):
mock_stream.return_value = mock_response

# First upload
train_set_uid1 = ServiceClient.fit(
fitted_model1 = ServiceClient.fit(
self.X_train, self.y_train, model_type=ModelType.TABPFN
)
train_set_uid1 = fitted_model1.train_set_uid

# Second upload with the same data
train_set_uid2 = ServiceClient.fit(
fitted_model2 = ServiceClient.fit(
self.X_train, self.y_train, model_type=ModelType.TABPFN
)
train_set_uid2 = fitted_model2.train_set_uid

# The train_set_uid should be the same due to caching
self.assertEqual(train_set_uid1, train_set_uid2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The existing tests are correctly updated to handle the new return type of fit(). However, the new model_id functionality is not covered by any tests. Please consider adding new tests to verify that:

  1. The model_id is correctly passed to the predict endpoint.
  2. Caching for predictions works correctly when different model_ids are used for the same test data.

Expand Down Expand Up @@ -348,9 +350,10 @@ def side_effect(*args, **kwargs):
mock_stream.side_effect = side_effect

# Upload train set
train_set_uid = ServiceClient.fit(
fitted_model = ServiceClient.fit(
self.X_train, self.y_train, model_type=ModelType.TABPFN
)
train_set_uid = fitted_model.train_set_uid

# First prediction
pred1 = ServiceClient.predict(
Expand Down Expand Up @@ -452,9 +455,10 @@ def side_effect_counter(*args, **kwargs):
mock_stream.side_effect = side_effect_counter

# Upload train set
train_set_uid = ServiceClient.fit(
fitted_model = ServiceClient.fit(
self.X_train, self.y_train, model_type=ModelType.TABPFN
)
train_set_uid = fitted_model.train_set_uid

# Attempt prediction, which should fail and trigger retry
pred = ServiceClient.predict(
Expand Down
Loading