Skip to content

Error when the training set has only one example #73

@LeoGrin

Description

@LeoGrin

Example to reproduce:

from tabpfn_client import TabPFNClassifier
import numpy as np
import pandas as pd

# Reproduction script: fit TabPFNClassifier on a training set that holds a
# single row, then predict on a small random test set.

N_TEST_ROWS = 10
FEATURES = ("feature1", "feature2")

# Training data: exactly one sample, two features, one class label.
X_train = pd.DataFrame({"feature1": [0.5], "feature2": [0.7]})
y_train = np.array([1])

# Test data drawn from NumPy's global RNG in a fixed order: feature1,
# then feature2, then the random binary labels.
X_test = pd.DataFrame(
    {name: np.random.rand(N_TEST_ROWS) for name in FEATURES}
)
y_test = np.random.randint(0, 2, size=N_TEST_ROWS)

# Fitting on the single-row training set is the call that fails with HTTP 500.
classifier = TabPFNClassifier()
classifier.fit(X_train, y_train)

# Exercise both prediction entry points.
y_pred = classifier.predict(X_test)
y_pred_proba = classifier.predict_proba(X_test)

# Fraction of test labels predicted correctly (labels are random here).
accuracy = np.mean(y_pred == y_test)
print(f"Test accuracy: {accuracy:.4f}")

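Traceback (the server returns HTTP 500 because `preprocess_y_train` receives a `numpy.int64` scalar instead of an object with a `to_csv` method):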

ERROR:tabpfn_client.client:Fail to call fit, response status: 500
Traceback (most recent call last):
  File "/scratch/lgrinszt/lm_tab/scripts/../test_one_example.py", line 24, in <module>
    model.fit(X_train, y_train)
  File "/scratch/lgrinszt/micromamba/envs/lm_tab/lib/python3.10/site-packages/tabpfn_client/estimator.py", line 146, in fit
    self.last_train_set_uid = InferenceClient.fit(X, y, config=estimator_param)
  File "/scratch/lgrinszt/micromamba/envs/lm_tab/lib/python3.10/site-packages/tabpfn_client/service_wrapper.py", line 225, in fit
    return ServiceClient.fit(X, y, config=config)
  File "/scratch/lgrinszt/micromamba/envs/lm_tab/lib/python3.10/site-packages/tabpfn_client/client.py", line 237, in fit
    cls._validate_response(response, "fit")
  File "/scratch/lgrinszt/micromamba/envs/lm_tab/lib/python3.10/site-packages/tabpfn_client/client.py", line 477, in _validate_response
    raise RuntimeError(
RuntimeError: Fail to call fit with error: 500, reason: Internal Server Error and text: Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/starlette/middleware/errors.py", line 165, in __call__
    await self.app(scope, receive, _send)
  File "/usr/local/lib/python3.10/site-packages/starlette/middleware/exceptions.py", line 62, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.10/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 715, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 735, in app
    await route.handle(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 288, in handle
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 76, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.10/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 73, in app
    response = await f(request)
  File "/usr/local/lib/python3.10/site-packages/fastapi/routing.py", line 301, in app
    raw_response = await run_endpoint_function(
  File "/usr/local/lib/python3.10/site-packages/fastapi/routing.py", line 212, in run_endpoint_function
    return await dependant.call(**values)
  File "/code/tabpfn-server/app/routers/fit.py", line 70, in fit
    train_set_schema = await upload_train_set(
  File "/code/tabpfn-server/app/routers/fit.py", line 39, in upload_train_set
    user_train_set_mapping = await dataset_serv.add_train_set(
  File "/code/tabpfn-server/app/services/dataset_repo_service.py", line 327, in add_train_set
    content[FileType.Y_TRAIN] = self.preprocess_y_train(content[FileType.Y_TRAIN])
  File "/code/tabpfn-server/app/services/dataset_repo_service.py", line 312, in preprocess_y_train
    return y_train.to_csv(index=False).encode()
AttributeError: 'numpy.int64' object has no attribute 'to_csv'

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions