Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion src/tabpfn_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ def __new__(cls, *args, **kwargs):


def init(use_server=True):
"""
Initializes the TabPFN client and handles user authentication.

This function checks for existing credentials, prompts for login/registration
if necessary, and verifies email status. It must be called before performing
any inference tasks.

:param use_server: Whether to use the TabPFN cloud service. Currently, only
True is supported.
:raises RuntimeError: If local inference is requested or if the server is unreachable.
"""
# initialize config
Config.use_server = use_server

Expand Down Expand Up @@ -91,6 +102,12 @@ def init(use_server=True):


def reset():
"""
Resets the client state and clears local authentication caches.

Use this function if you need to log out or clear stored session data
from the local machine.
"""
Config.is_initialized = False
# reset user auth handler
if Config.use_server:
Expand All @@ -101,16 +118,42 @@ def reset():


def get_access_token() -> str:
"""
Retrieves the current active access token.

If the client is not yet initialized, this will trigger the `init()` login flow.

:return: The access token string used for API requests.
"""
init()
return ServiceClient.get_access_token()


def set_access_token(access_token: str):
"""
Manually sets the access token for the session.

This is useful for non-interactive environments
(e.g., CI/CD, Notebooks) where you want to skip
the interactive login prompt.

You can obtain your access token via the TabPFN
UX as explained at:
https://docs.priorlabs.ai/api-reference/getting-started#1-get-your-access-token

:param access_token: A valid TabPFN access token string.
"""
UserAuthenticationClient.set_token(access_token)
Config.is_initialized = True


def get_api_usage() -> dict:
def get_api_usage() -> str:
"""
Fetches and formats the current API usage statistics for the user.

:return: A human-readable string detailing current credit usage,
the total limit, and when the limit resets.
"""
access_token = get_access_token()
response = ServiceClient.get_api_usage(access_token)
return f"Currently, you have used {response['current_usage']} of the allowed limit of {'Unlimited' if int(response['usage_limit']) == -1 else response['usage_limit']} credits. The limit will reset at {response['reset_time']}."