-
Notifications
You must be signed in to change notification settings - Fork 23
WIP: Save/load models for CAAFE -- Client #184
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -207,6 +207,7 @@ class FitCompleteEvent(BaseModel): | |||||||||||
| status: Literal["complete"] = "complete" | ||||||||||||
| # Identifies the dataset and thus the model | ||||||||||||
| train_set_uid: str | ||||||||||||
| model_id: Optional[str] = None | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class FitErrorEvent(BaseModel): | ||||||||||||
|
|
@@ -225,6 +226,11 @@ class FitErrorEvent(BaseModel): | |||||||||||
| ResponseEventAdapter = TypeAdapter(ResponseEvents) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class FittedModel(BaseModel): | ||||||||||||
| train_set_uid: str | ||||||||||||
| model_id: Optional[str] | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @dataclass(frozen=True) | ||||||||||||
| class PredictionResult: | ||||||||||||
| y_pred: Union[np.ndarray, list[np.ndarray], dict[str, np.ndarray]] | ||||||||||||
|
|
@@ -324,7 +330,7 @@ def _build_tabpfn_params(tabpfn_config: Union[dict, None]) -> dict: | |||||||||||
| params : dict | ||||||||||||
| Dictionary containing 'tabpfn_systems' and optionally 'tabpfn_config'. | ||||||||||||
| """ | ||||||||||||
| _, tabpfn_systems, processed_config, _ = ServiceClient._process_tabpfn_config(tabpfn_config) | ||||||||||||
| _, tabpfn_systems, processed_config, tabpfnr_params = ServiceClient._process_tabpfn_config(tabpfn_config) | ||||||||||||
|
|
||||||||||||
| params = { | ||||||||||||
| "tabpfn_systems": json.dumps(tabpfn_systems) | ||||||||||||
|
|
@@ -334,7 +340,9 @@ def _build_tabpfn_params(tabpfn_config: Union[dict, None]) -> dict: | |||||||||||
| params["tabpfn_config"] = json.dumps( | ||||||||||||
| processed_config, default=lambda x: x.to_dict() | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| if tabpfnr_params is not None: | ||||||||||||
| params["tabpfnr_params"] = json.dumps(tabpfnr_params) | ||||||||||||
|
|
||||||||||||
| return params | ||||||||||||
|
|
||||||||||||
| @classmethod | ||||||||||||
|
|
@@ -379,7 +387,7 @@ def fit( | |||||||||||
| tabpfn_config: Union[dict, None] = None, | ||||||||||||
| task: Optional[Literal["classification", "regression"]] = None, | ||||||||||||
| description: str = "", | ||||||||||||
| ) -> str: | ||||||||||||
| ) -> FittedModel: | ||||||||||||
| """ | ||||||||||||
| Upload a train set to server and return the train set UID if successful. | ||||||||||||
|
|
||||||||||||
|
|
@@ -447,8 +455,6 @@ def fit( | |||||||||||
| query_params["task"] = task | ||||||||||||
| else: | ||||||||||||
| assert model_type != ModelType.TABPFN_R, "Thinking mode requires a task." | ||||||||||||
| if tabpfnr_params: | ||||||||||||
| form_data["tabpfnr_params"] = json.dumps(tabpfnr_params) | ||||||||||||
|
|
||||||||||||
| with cls.httpx_client.stream( | ||||||||||||
| "POST", | ||||||||||||
|
|
@@ -498,6 +504,7 @@ def fit( | |||||||||||
| ) | ||||||||||||
| elif isinstance(event, FitCompleteEvent): | ||||||||||||
| train_set_uid = event.train_set_uid | ||||||||||||
| model_id = event.model_id | ||||||||||||
| elif isinstance(event, FitErrorEvent): | ||||||||||||
| # Handle structured error events | ||||||||||||
| raise RuntimeError(f"Error from server: {event.message}") | ||||||||||||
|
|
@@ -509,7 +516,7 @@ def fit( | |||||||||||
| raise RuntimeError("Error during fit. No valid model received.") | ||||||||||||
|
|
||||||||||||
| cls.dataset_uid_cache_manager.add_dataset_uid(dataset_hash, train_set_uid) | ||||||||||||
| return train_set_uid | ||||||||||||
| return FittedModel(train_set_uid=train_set_uid, model_id=model_id) | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Won't this break backwards compatibility for anyone using the |
||||||||||||
|
|
||||||||||||
| @classmethod | ||||||||||||
| @backoff.on_exception( | ||||||||||||
|
|
@@ -535,6 +542,7 @@ def predict( | |||||||||||
| x_test, | ||||||||||||
| model_type: ModelType, | ||||||||||||
| task: Literal["classification", "regression"], | ||||||||||||
| model_id: Optional[str] = None, | ||||||||||||
| predict_params: Union[dict, None] = None, | ||||||||||||
| tabpfn_config: Union[dict, None] = None, | ||||||||||||
| X_train=None, | ||||||||||||
|
|
@@ -563,6 +571,7 @@ def predict( | |||||||||||
| params = cls._build_tabpfn_params(tabpfn_config) | ||||||||||||
| params.update({ | ||||||||||||
| "train_set_uid": train_set_uid, | ||||||||||||
| "model_id": model_id, | ||||||||||||
| "task": task, | ||||||||||||
| "predict_params": json.dumps(predict_params), | ||||||||||||
| }) | ||||||||||||
|
|
@@ -578,6 +587,7 @@ def predict( | |||||||||||
| ) = cls.dataset_uid_cache_manager.get_dataset_uid( | ||||||||||||
| x_test_serialized, | ||||||||||||
| train_set_uid, | ||||||||||||
| model_id, | ||||||||||||
| # Note: This may not be strictly necessary for predict, but keep it for safety. | ||||||||||||
| model_type.name, | ||||||||||||
| cls._access_token, | ||||||||||||
|
|
@@ -664,13 +674,14 @@ def run_progress(): | |||||||||||
| raise NotImplementedError( | ||||||||||||
| "Automatically re-uploading the train set is not supported for thinking mode. Please call fit()." | ||||||||||||
| ) | ||||||||||||
| train_set_uid = cls.fit( | ||||||||||||
| fitted_model = cls.fit( | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you please share more context on why we re-fit it again in |
||||||||||||
| X_train, | ||||||||||||
| y_train, | ||||||||||||
| model_type=model_type, | ||||||||||||
| tabpfn_config=tabpfn_config, | ||||||||||||
| task=task | ||||||||||||
| ) | ||||||||||||
| train_set_uid = fitted_model.train_set_uid | ||||||||||||
| params["train_set_uid"] = train_set_uid | ||||||||||||
|
Comment on lines
+684
to
685
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When retrying a prediction after a UID error, a new model is fitted, but only the
Suggested change
|
||||||||||||
| cached_test_set_uid = None | ||||||||||||
| else: | ||||||||||||
|
|
||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,12 @@ | ||
| # # local testing | ||
| # protocol: "http" | ||
| # host: "localhost" | ||
| # port: "8080" | ||
|
|
||
| # production | ||
| protocol: "https" | ||
| host: "api.priorlabs.ai" | ||
| port: "443" | ||
| # local testing | ||
| protocol: "http" | ||
| host: "localhost" | ||
| port: "8080" | ||
|
|
||
| ## production | ||
| #protocol: "https" | ||
| #host: "api.priorlabs.ai" | ||
| #port: "443" | ||
|
Comment on lines
+1
to
+9
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This change switches the server configuration from production to a local testing environment. While this is useful for development, it should be reverted before merging to avoid pointing the client to a local server in a production release. Please ensure the production configuration is active in the final version of this pull request.
# # local testing
# protocol: "http"
# host: "localhost"
# port: "8080"
# production
protocol: "https"
host: "api.priorlabs.ai"
port: "443" |
||
|
|
||
| gui_url: "https://ux.priorlabs.ai" | ||
| endpoints: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -278,14 +278,16 @@ def test_fit_with_caching(self, mock_server): | |
| mock_stream.return_value = mock_response | ||
|
|
||
| # First upload | ||
| train_set_uid1 = ServiceClient.fit( | ||
| fitted_model1 = ServiceClient.fit( | ||
| self.X_train, self.y_train, model_type=ModelType.TABPFN | ||
| ) | ||
| train_set_uid1 = fitted_model1.train_set_uid | ||
|
|
||
| # Second upload with the same data | ||
| train_set_uid2 = ServiceClient.fit( | ||
| fitted_model2 = ServiceClient.fit( | ||
| self.X_train, self.y_train, model_type=ModelType.TABPFN | ||
| ) | ||
| train_set_uid2 = fitted_model2.train_set_uid | ||
|
|
||
| # The train_set_uid should be the same due to caching | ||
| self.assertEqual(train_set_uid1, train_set_uid2) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The existing tests are correctly updated to handle the new return type of
|
||
|
|
@@ -348,9 +350,10 @@ def side_effect(*args, **kwargs): | |
| mock_stream.side_effect = side_effect | ||
|
|
||
| # Upload train set | ||
| train_set_uid = ServiceClient.fit( | ||
| fitted_model = ServiceClient.fit( | ||
| self.X_train, self.y_train, model_type=ModelType.TABPFN | ||
| ) | ||
| train_set_uid = fitted_model.train_set_uid | ||
|
|
||
| # First prediction | ||
| pred1 = ServiceClient.predict( | ||
|
|
@@ -452,9 +455,10 @@ def side_effect_counter(*args, **kwargs): | |
| mock_stream.side_effect = side_effect_counter | ||
|
|
||
| # Upload train set | ||
| train_set_uid = ServiceClient.fit( | ||
| fitted_model = ServiceClient.fit( | ||
| self.X_train, self.y_train, model_type=ModelType.TABPFN | ||
| ) | ||
| train_set_uid = fitted_model.train_set_uid | ||
|
|
||
| # Attempt prediction, which should fail and trigger retry | ||
| pred = ServiceClient.predict( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable `model_id` is assigned here within an `elif` block. However, it is not initialized before the loop (lines 474-511). If a `FitCompleteEvent` is not received from the server, the subsequent reference to `model_id` on line 519 will raise an `UnboundLocalError`. To prevent this, please initialize `model_id` to `None` before the `try` block, similar to how `train_set_uid` is initialized.