Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 73 additions & 113 deletions tabpfn_client/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,92 +32,25 @@
# NOTE(review): presumably upper bounds on the number of input columns and
# target classes enforced elsewhere in this file — confirm against the
# validation code, which is not visible here.
MAX_COLS = 2000
MAX_NUMBER_OF_CLASSES = 10

# Special string used to identify v2.5 models in model paths.
V_2_5_IDENTIFIER = "v2.5"

# Model names used as `model_path` to select the default model of each
# model version (v2 and v2.5 respectively).
DEFAULT_V2_MODEL_PATH = "v2_default"
DEFAULT_V2_5_MODEL_PATH = "v2.5_default"


class TabPFNModelSelection:
"""Base class for TabPFN model selection and path handling."""

_AVAILABLE_MODELS: list[str] = []
_VALID_TASKS = {"classification", "regression"}

@classmethod
def list_available_models(cls) -> list[str]:
    """Return the model names registered in ``cls._AVAILABLE_MODELS``.

    The class-level list itself is returned, not a copy.
    """
    available = cls._AVAILABLE_MODELS
    return available

@classmethod
def _validate_model_name(cls, model_name: str) -> None:
if model_name != "default" and model_name not in cls._AVAILABLE_MODELS:
raise ValueError(
f"Invalid model name: {model_name}. "
f"Available models are: {cls.list_available_models()}"
)

@classmethod
def _model_name_to_path(
    cls, task: Literal["classification", "regression"], model_name: str
) -> Optional[str]:
    """Translate a model name into a checkpoint file name.

    Parameters
    ----------
    task:
        ``"classification"`` maps to the ``classifier`` file-name segment;
        any other value maps to ``regressor``.
    model_name:
        A name accepted by ``_validate_model_name``.

    Returns
    -------
    ``None`` for ``"default"``, otherwise
    ``tabpfn-<version>-<classifier|regressor>-<model_name>.ckpt`` where
    ``<version>`` is ``V_2_5_IDENTIFIER`` if that string occurs in
    ``model_name`` and ``v2`` otherwise.

    Raises
    ------
    ValueError
        Propagated from ``_validate_model_name`` for unknown names.
    """
    cls._validate_model_name(model_name)

    # Let the server handle the default model. This enables v2.5 as well.
    if model_name == "default":
        return None

    estimator_kind = "classifier" if task == "classification" else "regressor"
    version_tag = V_2_5_IDENTIFIER if V_2_5_IDENTIFIER in model_name else "v2"
    return f"tabpfn-{version_tag}-{estimator_kind}-{model_name}.ckpt"

@classmethod
def create_default_for_version(cls, version: ModelVersion, **overrides) -> Self:
"""Construct an estimator that uses the given version of the model.

In addition to selecting the model, this also configures the estimator with
certain default settings associated with this model version.

Any kwargs will override the default settings.
"""
options = {
"n_estimators": 8,
"softmax_temperature": 0.9,
}
if version == ModelVersion.V2:
options["model_path"] = DEFAULT_V2_MODEL_PATH
elif version == ModelVersion.V2_5:
options["model_path"] = DEFAULT_V2_5_MODEL_PATH
else:
raise ValueError(f"Unknown version: {version}")
def _default_model_path_v2(task: Literal["classification", "regression"]) -> str:
if task == "classification":
return "tabpfn-v2-classifier-finetuned-zk73skhh.ckpt"
elif task == "regression":
return "tabpfn-v2-regressor.ckpt"
else:
raise ValueError(f"Invalid task: {task}")

options.update(overrides)

return cls(**options)
def _default_model_path_v2_5(task: Literal["classification", "regression"]) -> str:
# using None makes this robust to changes in the tabpfn package
return None


class TabPFNClassifier(ClassifierMixin, BaseEstimator, TabPFNModelSelection):
_AVAILABLE_MODELS = [
"v2.5_default-2",
DEFAULT_V2_5_MODEL_PATH,
"v2.5_large-features-L",
"v2.5_large-features-XL",
"v2.5_large-samples",
"v2.5_real-large-features",
"v2.5_real-large-samples-and-features",
"v2.5_real",
"v2.5_variant",
DEFAULT_V2_MODEL_PATH,
"default",
"gn2p4bpt",
"llderlii",
"od3j1g5m",
"vutqq28w",
"znskzxi4",
]

class TabPFNClassifier(ClassifierMixin, BaseEstimator):
def __init__(
self,
model_path: str = "default",
model_path: Optional[str] = None,
n_estimators: int = 8,
softmax_temperature: float = 0.9,
balance_probabilities: bool = False,
Expand All @@ -138,8 +71,9 @@ def __init__(

Parameters
----------
model_path: str, default="default"
The name of the model to use.
model_path: str or None, default=None
The name of the model to use. If None, default to our latest
default model.
n_estimators: int, default=8
The number of estimators in the TabPFN ensemble. We aggregate the
predictions of `n_estimators`-many forward passes of TabPFN. Each forward
Expand Down Expand Up @@ -194,6 +128,30 @@ def __init__(
self.last_train_y = None
self.last_meta = {}

@classmethod
def create_default_for_version(cls, version: ModelVersion, **overrides) -> Self:
    """Build a classifier preconfigured for the given model version.

    Besides choosing the model path for ``version``, this applies the default
    settings associated with it (``n_estimators=8``,
    ``softmax_temperature=0.9``). Keyword arguments in ``overrides`` take
    precedence over all of these defaults, including ``model_path``.

    Raises
    ------
    ValueError
        If ``version`` is neither ``ModelVersion.V2`` nor ``ModelVersion.V2_5``.
    """
    if version == ModelVersion.V2:
        default_path = _default_model_path_v2("classification")
    elif version == ModelVersion.V2_5:
        default_path = _default_model_path_v2_5("classification")
    else:
        raise ValueError(f"Unknown version: {version}")

    # Later entries win, so caller-supplied overrides replace the defaults.
    settings = {
        "n_estimators": 8,
        "softmax_temperature": 0.9,
        "model_path": default_path,
        **overrides,
    }
    return cls(**settings)

def fit(self, X, y):
# assert init() is called
init()
Expand All @@ -204,9 +162,8 @@ def fit(self, X, y):
_check_paper_version(self.paper_version, X)

estimator_param = self.get_params()
estimator_param["model_path"] = TabPFNClassifier._model_name_to_path(
"classification", self.model_path
)
estimator_param["model_path"] = self.model_path

if Config.use_server:
self.last_train_set_uid = InferenceClient.fit(X, y, config=estimator_param)
self.last_train_X = X
Expand Down Expand Up @@ -247,9 +204,7 @@ def _predict(self, X, output_type):
X = _clean_text_features(X)

estimator_param = self.get_params()
estimator_param["model_path"] = TabPFNClassifier._model_name_to_path(
"classification", self.model_path
)
estimator_param["model_path"] = self.model_path

result: PredictionResult = InferenceClient.predict(
X,
Expand Down Expand Up @@ -285,26 +240,10 @@ def _validate_targets_and_classes(self, y) -> np.ndarray:
)


class TabPFNRegressor(RegressorMixin, BaseEstimator, TabPFNModelSelection):
_AVAILABLE_MODELS = [
DEFAULT_V2_5_MODEL_PATH,
"v2.5_low-skew",
"v2.5_quantiles",
"v2.5_real-variant",
"v2.5_real",
"v2.5_small-samples",
"v2.5_variant",
DEFAULT_V2_MODEL_PATH,
"default",
"2noar4o2",
"5wof9ojf",
"09gpqh39",
"wyl4o83o",
]

class TabPFNRegressor(RegressorMixin, BaseEstimator):
def __init__(
self,
model_path: str = "default",
model_path: Optional[str] = None,
n_estimators: int = 8,
softmax_temperature: float = 0.9,
average_before_softmax: bool = False,
Expand All @@ -324,8 +263,9 @@ def __init__(

Parameters
----------
model_path: str, default="default"
The name to the model to use.
model_path: str or None, default=None
The name of the model to use. If None, default to our latest
default model.
n_estimators: int, default=8
The number of estimators in the TabPFN ensemble. We aggregate the
predictions of `n_estimators`-many forward passes of TabPFN. Each forward
Expand Down Expand Up @@ -373,6 +313,30 @@ def __init__(
self.last_train_y = None
self.last_meta = {}

@classmethod
def create_default_for_version(cls, version: ModelVersion, **overrides) -> Self:
    """Build a regressor preconfigured for the given model version.

    Besides choosing the model path for ``version``, this applies the default
    settings associated with it (``n_estimators=8``,
    ``softmax_temperature=0.9``). Keyword arguments in ``overrides`` take
    precedence over all of these defaults, including ``model_path``.

    Raises
    ------
    ValueError
        If ``version`` is neither ``ModelVersion.V2`` nor ``ModelVersion.V2_5``.
    """
    if version == ModelVersion.V2:
        default_path = _default_model_path_v2("regression")
    elif version == ModelVersion.V2_5:
        default_path = _default_model_path_v2_5("regression")
    else:
        raise ValueError(f"Unknown version: {version}")

    # Later entries win, so caller-supplied overrides replace the defaults.
    settings = {
        "n_estimators": 8,
        "softmax_temperature": 0.9,
        "model_path": default_path,
        **overrides,
    }
    return cls(**settings)

def fit(self, X, y):
# assert init() is called
init()
Expand All @@ -383,9 +347,7 @@ def fit(self, X, y):
_check_paper_version(self.paper_version, X)

estimator_param = self.get_params()
estimator_param["model_path"] = TabPFNRegressor._model_name_to_path(
"regression", self.model_path
)
estimator_param["model_path"] = self.model_path
if Config.use_server:
self.last_train_set_uid = InferenceClient.fit(X, y, config=estimator_param)
self.last_train_X = X
Expand Down Expand Up @@ -441,9 +403,7 @@ def predict(
}

estimator_param = self.get_params()
estimator_param["model_path"] = TabPFNRegressor._model_name_to_path(
"regression", self.model_path
)
estimator_param["model_path"] = self.model_path

result: PredictionResult = InferenceClient.predict(
X,
Expand Down
Loading
Loading