From 5a222eb41667d91757361fe61a0039292b6da96b Mon Sep 17 00:00:00 2001 From: Clayton Rosenthal Date: Wed, 19 Jul 2023 23:58:54 -0700 Subject: [PATCH 1/3] new-feature: oauth support --- examples/devices.py | 10 ++++- examples/oauth.py | 30 +++++++++++++++ src/tailscale/tailscale.py | 78 ++++++++++++++++++++++++++++++++++---- tests/test_tailscale.py | 78 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 186 insertions(+), 10 deletions(-) mode change 100644 => 100755 examples/devices.py create mode 100755 examples/oauth.py diff --git a/examples/devices.py b/examples/devices.py old mode 100644 new mode 100755 index 156abfc6..29723565 --- a/examples/devices.py +++ b/examples/devices.py @@ -1,16 +1,22 @@ +#!/usr/bin/env python3 # pylint: disable=W0621 """Asynchronous client for the Tailscale API.""" import asyncio +import os from tailscale import Tailscale +# "-" is the default tailnet of the API key +TAILNET = os.environ.get("TS_TAILNET", "-") +API_KEY = os.environ.get("TS_API_KEY", "") + async def main() -> None: """Show example on using the Tailscale API client.""" async with Tailscale( - tailnet="frenck", - api_key="tskey-somethingsomething", + tailnet=TAILNET, + api_key=API_KEY, ) as tailscale: devices = await tailscale.devices() print(devices) diff --git a/examples/oauth.py b/examples/oauth.py new file mode 100755 index 00000000..edd188b9 --- /dev/null +++ b/examples/oauth.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# pylint: disable=W0621 +"""Asynchronous client for the Tailscale API.""" + +import asyncio +import os + +from tailscale import Tailscale + +# "-" is the default tailnet of the API key +TAILNET = os.environ.get("TS_TAILNET", "-") + +# OAuth client ID and secret are required for OAuth authentication +OAUTH_CLIENT_ID = os.environ.get("TS_API_CLIENT_ID", "") +OAUTH_CLIENT_SECRET = os.environ.get("TS_API_CLIENT_SECRET", "") + + +async def main_oauth() -> None: + """Show example on using the Tailscale API client with OAuth.""" + async with Tailscale( + tailnet=TAILNET, + oauth_client_id=OAUTH_CLIENT_ID, + oauth_client_secret=OAUTH_CLIENT_SECRET, + ) as tailscale: + devices = await tailscale.devices() + print(devices) + + +if __name__ == "__main__": + asyncio.run(main_oauth()) diff --git a/src/tailscale/tailscale.py b/src/tailscale/tailscale.py index 305a52ea..4daadeed 100644 --- a/src/tailscale/tailscale.py +++ b/src/tailscale/tailscale.py @@ -3,13 +3,13 @@ from __future__ import annotations import asyncio +import json import socket from dataclasses import dataclass from typing import Any, Self -from aiohttp import BasicAuth from aiohttp.client import ClientError, ClientResponseError, ClientSession -from aiohttp.hdrs import METH_GET +from aiohttp.hdrs import METH_GET, METH_POST from yarl import URL from .exceptions import ( @@ -19,25 +19,80 @@ ) from .models import Device, Devices +# Placeholder value for the access token when it is not yet set. +ACCESS_TOKEN_PENDING = "" # noqa: S105 + @dataclass class Tailscale: """Main class for handling connections with the Tailscale API.""" - tailnet: str - api_key: str + # tailnet of '-' is the default tailnet of the API key + tailnet: str = "-" + api_key: str = "" # nosec + oauth_client_id: str = "" # nosec + oauth_client_secret: str = "" # nosec request_timeout: int = 8 session: ClientSession | None = None _close_session: bool = False + async def _check_access(self) -> None: + """Initialize the Tailscale client. + + Raises: + TailscaleAuthenticationError: when neither api_key nor oauth_client_id and + oauth_client_secret are provided. + + """ + if ( + not self.api_key + and not self.oauth_client_id + and not self.oauth_client_secret + ): + msg = "Either api_key or oauth client is required" + raise TailscaleAuthenticationError(msg) + if not self.api_key: + self.api_key = ACCESS_TOKEN_PENDING + self.api_key = await self._get_oauth_token() + + async def _get_oauth_token(self) -> str: + """Get an OAuth token from the Tailscale API. + + Raises: + TailscaleAuthenticationError: when access key not found in response. + + Returns: + A string with the OAuth token, or nothing on error + + """ + # Tailscale's OAuth endpoint requires form-encoded body + # with client_id and client_secret + data = { + "client_id": self.oauth_client_id, + "client_secret": self.oauth_client_secret, + } + response = await self._request( + "oauth/token", + data=data, + method=METH_POST, + use_form_encoding=True, + ) + + token = json.loads(response).get("access_token", "") + if not token: + msg = "Failed to get OAuth token" + raise TailscaleAuthenticationError(msg) + return str(token) + async def _request( self, uri: str, *, method: str = METH_GET, data: dict[str, Any] | None = None, + use_form_encoding: bool = False, ) -> str: """Handle a request to the Tailscale API. @@ -66,22 +121,29 @@ async def _request( """ url = URL("https://api.tailscale.com/api/v2/").join(URL(uri)) - headers = { + await self._check_access() + + headers: dict[str, str] = { "Accept": "application/json", } + if self.api_key and self.api_key != ACCESS_TOKEN_PENDING: + # API keys and oauth tokens can use Bearer authentication + headers["Authorization"] = f"Bearer {self.api_key}" + if self.session is None: self.session = ClientSession() self._close_session = True try: async with asyncio.timeout(self.request_timeout): + # Use form encoding for OAuth token requests, JSON for others response = await self.session.request( method, url, - json=data, - auth=BasicAuth(self.api_key), - headers=headers, + headers=headers if headers else None, + data=data if use_form_encoding else None, + json=data if not use_form_encoding else None, ) response.raise_for_status() except asyncio.TimeoutError as exception: diff --git a/tests/test_tailscale.py b/tests/test_tailscale.py index d3276819..68935118 100644 --- a/tests/test_tailscale.py +++ b/tests/test_tailscale.py @@ -15,6 +15,84 @@ ) +@pytest.mark.asyncio +async def test_no_access() -> None: + """Test api key or oauth key is checked correctly.""" + async with Tailscale(tailnet="frenck") as tailscale: + with pytest.raises(TailscaleAuthenticationError): + assert await tailscale._request("test") + + +@pytest.mark.asyncio +async def test_key_from_oauth(aresponses: ResponsesMockServer) -> None: + """Test oauth key response is handled correctly.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token"}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", # nosec + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + await tailscale._request("test") + second_request = aresponses.history[1].request + assert "Bearer" in second_request.headers["Authorization"] + await tailscale.close() + + aresponses.assert_plan_strictly_followed() + + +@pytest.mark.asyncio +async def test_bad_oauth(aresponses: ResponsesMockServer) -> None: + """Test bad oauth error is handled correctly.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"no_access_token": "unauthorized"}', + ), + ) + + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", # nosec + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + with pytest.raises(TailscaleAuthenticationError) as excinfo: + assert await tailscale._request("test") + + assert excinfo.value.args[0] == "Failed to get OAuth token" + + await tailscale.close() + + aresponses.assert_plan_strictly_followed() + + +@pytest.mark.asyncio async def test_json_request(aresponses: ResponsesMockServer) -> None: """Test JSON response is handled correctly.""" aresponses.add( From 50ceacc3e7f46891c307361856ea94e5a9587b56 Mon Sep 17 00:00:00 2001 From: Laszlo Magyar Date: Tue, 6 Jan 2026 18:14:44 +0100 Subject: [PATCH 2/3] fix oauth - handle race condition in _check_api_key() - handle token expiration - handle token invalidation - fix argument validation condition - fix attribute initialization - remove magic constant - add tests --- src/tailscale/tailscale.py | 110 ++++++++++----- tests/test_tailscale.py | 278 +++++++++++++++++++++++++++++++++++-- 2 files changed, 345 insertions(+), 43 deletions(-) diff --git a/src/tailscale/tailscale.py b/src/tailscale/tailscale.py index 4daadeed..ef43f387 100644 --- a/src/tailscale/tailscale.py +++ b/src/tailscale/tailscale.py @@ -19,26 +19,26 @@ ) from .models import Device, Devices -# Placeholder value for the access token when it is not yet set. -ACCESS_TOKEN_PENDING = "" # noqa: S105 - @dataclass +# pylint: disable-next=too-many-instance-attributes class Tailscale: """Main class for handling connections with the Tailscale API.""" # tailnet of '-' is the default tailnet of the API key tailnet: str = "-" - api_key: str = "" # nosec - oauth_client_id: str = "" # nosec - oauth_client_secret: str = "" # nosec + api_key: str | None = None + oauth_client_id: str | None = None + oauth_client_secret: str | None = None request_timeout: int = 8 session: ClientSession | None = None + _get_oauth_token_task: asyncio.Task[None] | None = None + _expire_oauth_token_task: asyncio.Task[None] | None = None _close_session: bool = False - async def _check_access(self) -> None: + async def _check_api_key(self) -> None: """Initialize the Tailscale client. Raises: @@ -46,25 +46,44 @@ async def _check_access(self) -> None: oauth_client_secret are provided. """ - if ( - not self.api_key - and not self.oauth_client_id - and not self.oauth_client_secret + if not ( + (self.api_key and not self.oauth_client_id and not self.oauth_client_secret) + or (not self.api_key and self.oauth_client_id and self.oauth_client_secret) + or ( + self.api_key + and self.oauth_client_id + and self.oauth_client_secret + and self._get_oauth_token_task + ) ): - msg = "Either api_key or oauth client is required" + msg = ( + "Either api_key or oauth_client_id and oauth_client_secret " + "are required when Tailscale client is initialized" + ) raise TailscaleAuthenticationError(msg) if not self.api_key: - self.api_key = ACCESS_TOKEN_PENDING - self.api_key = await self._get_oauth_token() + # Handle some inconsistent state + # possibly caused by user manually deleting api_key + if self._expire_oauth_token_task: + self._expire_oauth_token_task.cancel() + self._expire_oauth_token_task = None + if self._get_oauth_token_task: + self._get_oauth_token_task.cancel() + self._get_oauth_token_task = None + # Get a new OAuth token if not already in the process of getting one + if not self._get_oauth_token_task: + self._get_oauth_token_task = asyncio.create_task( + self._get_oauth_token() + ) + # Wait for the OAuth token to be retrieved + await self._get_oauth_token_task - async def _get_oauth_token(self) -> str: + async def _get_oauth_token(self) -> None: """Get an OAuth token from the Tailscale API. Raises: - TailscaleAuthenticationError: when access key not found in response. - - Returns: - A string with the OAuth token, or nothing on error + TailscaleAuthenticationError: when access token not found in response or + access token expires in less than 5 minutes. """ # Tailscale's OAuth endpoint requires form-encoded body @@ -77,14 +96,31 @@ async def _get_oauth_token(self) -> str: "oauth/token", data=data, method=METH_POST, - use_form_encoding=True, + _use_authentication=False, + _use_form_encoding=True, ) - token = json.loads(response).get("access_token", "") - if not token: + json_response = json.loads(response) + access_token = str(json_response.get("access_token", "")) + expires_in = float(json_response.get("expires_in", 0)) + if not access_token or not expires_in: msg = "Failed to get OAuth token" raise TailscaleAuthenticationError(msg) - return str(token) + if expires_in <= 60: + msg = "OAuth token expires in less than 1 minute" + raise TailscaleAuthenticationError(msg) + + self._expire_oauth_token_task = asyncio.create_task( + self._expire_oauth_token(expires_in) + ) + self.api_key = access_token + + async def _expire_oauth_token(self, expires_in: float) -> None: + """Expires the OAuth token 1 minute before expiration.""" + await asyncio.sleep(expires_in - 60) + self.api_key = None + self._get_oauth_token_task = None + self._expire_oauth_token_task = None async def _request( self, @@ -92,7 +128,8 @@ async def _request( *, method: str = METH_GET, data: dict[str, Any] | None = None, - use_form_encoding: bool = False, + _use_authentication: bool = True, + _use_form_encoding: bool = False, ) -> str: """Handle a request to the Tailscale API. @@ -107,8 +144,7 @@ async def _request( Returns: ------- - A Python dictionary (JSON decoded) with the response from - the Tailscale API. + The response from the Tailscale API. Raises: ------ @@ -121,13 +157,12 @@ async def _request( """ url = URL("https://api.tailscale.com/api/v2/").join(URL(uri)) - await self._check_access() - headers: dict[str, str] = { "Accept": "application/json", } - if self.api_key and self.api_key != ACCESS_TOKEN_PENDING: + if _use_authentication: + await self._check_api_key() # API keys and oauth tokens can use Bearer authentication headers["Authorization"] = f"Bearer {self.api_key}" @@ -142,8 +177,8 @@ async def _request( method, url, headers=headers if headers else None, - data=data if use_form_encoding else None, - json=data if not use_form_encoding else None, + data=data if _use_form_encoding else None, + json=data if not _use_form_encoding else None, ) response.raise_for_status() except asyncio.TimeoutError as exception: @@ -151,6 +186,13 @@ async def _request( raise TailscaleConnectionError(msg) from exception except ClientResponseError as exception: if exception.status in [401, 403]: + if _use_authentication and self.api_key and self.oauth_client_id: + # Invalidate the current OAuth token + self.api_key = None + self._get_oauth_token_task = None + if self._expire_oauth_token_task: + self._expire_oauth_token_task.cancel() + self._expire_oauth_token_task = None msg = "Authentication to the Tailscale API failed" raise TailscaleAuthenticationError(msg) from exception msg = "Error occurred while connecting to the Tailscale API" @@ -176,9 +218,13 @@ async def devices(self) -> dict[str, Device]: return Devices.from_json(data).devices async def close(self) -> None: - """Close open client session.""" + """Close open client session and cancel tasks.""" if self.session and self._close_session: await self.session.close() + if self._get_oauth_token_task: + self._get_oauth_token_task.cancel() + if self._expire_oauth_token_task: + self._expire_oauth_token_task.cancel() async def __aenter__(self) -> Self: """Async enter. diff --git a/tests/test_tailscale.py b/tests/test_tailscale.py index 68935118..bb21fe80 100644 --- a/tests/test_tailscale.py +++ b/tests/test_tailscale.py @@ -15,15 +15,48 @@ ) -@pytest.mark.asyncio -async def test_no_access() -> None: +async def test_wrong_arguments_no_auth() -> None: """Test api key or oauth key is checked correctly.""" - async with Tailscale(tailnet="frenck") as tailscale: - with pytest.raises(TailscaleAuthenticationError): + async with Tailscale() as tailscale: + with pytest.raises(TailscaleAuthenticationError) as excinfo: assert await tailscale._request("test") + assert excinfo.value.args[0] == ( + "Either api_key or oauth_client_id and oauth_client_secret " + "are required when Tailscale client is initialized" + ) + + +async def test_wrong_arguments_both_auth() -> None: + """Test api key or oauth key is checked correctly.""" + async with Tailscale( + api_key="abc", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + ) as tailscale: + with pytest.raises(TailscaleAuthenticationError) as excinfo: + assert await tailscale._request("test") + + assert excinfo.value.args[0] == ( + "Either api_key or oauth_client_id and oauth_client_secret " + "are required when Tailscale client is initialized" + ) + + +async def test_wrong_arguments_partial_oauth() -> None: + """Test api key or oauth key is checked correctly.""" + async with Tailscale( + oauth_client_id="client", + ) as tailscale: + with pytest.raises(TailscaleAuthenticationError) as excinfo: + assert await tailscale._request("test") + + assert excinfo.value.args[0] == ( + "Either api_key or oauth_client_id and oauth_client_secret " + "are required when Tailscale client is initialized" + ) + -@pytest.mark.asyncio async def test_key_from_oauth(aresponses: ResponsesMockServer) -> None: """Test oauth key response is handled correctly.""" aresponses.add( @@ -33,7 +66,7 @@ async def test_key_from_oauth(aresponses: ResponsesMockServer) -> None: aresponses.Response( status=200, headers={"Content-Type": "application/json"}, - text='{"access_token": "short-lived-token"}', + text='{"access_token": "short-lived-token", "expires_in": 3600}', ), ) aresponses.add( @@ -49,7 +82,7 @@ async def test_key_from_oauth(aresponses: ResponsesMockServer) -> None: async with aiohttp.ClientSession() as session: tailscale = Tailscale( tailnet="frenck", - oauth_client_id="client", # nosec + oauth_client_id="client", oauth_client_secret="notsosecret", # noqa: S106 session=session, ) @@ -61,9 +94,165 @@ async def test_key_from_oauth(aresponses: ResponsesMockServer) -> None: aresponses.assert_plan_strictly_followed() -@pytest.mark.asyncio +async def test_key_from_oauth_with_race_condition( + aresponses: ResponsesMockServer, +) -> None: + """Test oauth key request is sent out only once.""" + + async def oauth_handler(_: aiohttp.ClientResponse) -> Response: + """Response handler emulating slow oauth response.""" + await asyncio.sleep(1) + return aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 3600}', + ) + + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + oauth_handler, + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + + first_task = asyncio.create_task(tailscale._request("test")) + second_task = asyncio.create_task(tailscale._request("test")) + await asyncio.gather(first_task, second_task) + await tailscale.close() + + aresponses.assert_plan_strictly_followed() + + +async def test_new_key_from_oauth_on_manual_invalidation( + aresponses: ResponsesMockServer, +) -> None: + """Test oauth key manual invalidation is handled correctly.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 3600}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 3600}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + await tailscale._request("test") + tailscale.api_key = None # Manual invalidation + await tailscale._request("test") + await tailscale.close() + + aresponses.assert_plan_strictly_followed() + + +async def test_oauth_key_expiration(aresponses: ResponsesMockServer) -> None: + """Test oauth key expiration.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 61}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + await tailscale._request("test") + assert tailscale.api_key == "short-lived-token" + assert tailscale._get_oauth_token_task is not None + assert tailscale._expire_oauth_token_task is not None + await asyncio.sleep(2) + assert tailscale.api_key is None + assert tailscale._get_oauth_token_task is None + assert tailscale._expire_oauth_token_task is None + await tailscale.close() + + aresponses.assert_plan_strictly_followed() + + async def test_bad_oauth(aresponses: ResponsesMockServer) -> None: - """Test bad oauth error is handled correctly.""" + """Test bad oauth response is handled correctly.""" aresponses.add( "api.tailscale.com", "/api/v2/oauth/token", @@ -78,7 +267,7 @@ async def test_bad_oauth(aresponses: ResponsesMockServer) -> None: async with aiohttp.ClientSession() as session: tailscale = Tailscale( tailnet="frenck", - oauth_client_id="client", # nosec + oauth_client_id="client", oauth_client_secret="notsosecret", # noqa: S106 session=session, ) @@ -92,7 +281,34 @@ async def test_bad_oauth(aresponses: ResponsesMockServer) -> None: aresponses.assert_plan_strictly_followed() -@pytest.mark.asyncio +async def test_too_short_oauth_expiration(aresponses: ResponsesMockServer) -> None: + """Test too short oauth expiration is handled correctly.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 60}', + ), + ) + + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + with pytest.raises(TailscaleAuthenticationError) as excinfo: + assert await tailscale._request("test") + + assert excinfo.value.args[0] == "OAuth token expires in less than 1 minute" + + await tailscale.close() + + async def test_json_request(aresponses: ResponsesMockServer) -> None: """Test JSON response is handled correctly.""" aresponses.add( @@ -201,3 +417,43 @@ async def test_http_error401(aresponses: ResponsesMockServer) -> None: tailscale = Tailscale(tailnet="frenck", api_key="abc", session=session) with pytest.raises(TailscaleAuthenticationError): assert await tailscale._request("test") + + +async def test_http_error401_and_oauth_token_invalidation( + aresponses: ResponsesMockServer, +) -> None: + """Test HTTP 401 response handling and oauth token invalidation.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 3600}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response(text="Access denied!", status=401), + ) + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + with pytest.raises(TailscaleAuthenticationError) as excinfo: + assert await tailscale._request("test") + + assert excinfo.value.args[0] == "Authentication to the Tailscale API failed" + assert tailscale.api_key is None + assert tailscale._get_oauth_token_task is None + assert tailscale._expire_oauth_token_task is None + + await tailscale.close() + + aresponses.assert_plan_strictly_followed() From c78604e8f08c5e093efa404fe9ed606bd5ef447e Mon Sep 17 00:00:00 2001 From: Laszlo Magyar Date: Wed, 7 Jan 2026 19:03:00 +0100 Subject: [PATCH 3/3] add token storage functionality --- src/tailscale/__init__.py | 2 ++ src/tailscale/storage.py | 29 +++++++++++++++ src/tailscale/tailscale.py | 24 +++++++++++-- tests/storage.py | 27 ++++++++++++++ tests/test_tailscale.py | 74 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 154 insertions(+), 2 deletions(-) create mode 100644 src/tailscale/storage.py create mode 100644 tests/storage.py diff --git a/src/tailscale/__init__.py b/src/tailscale/__init__.py index 5000b29a..a2382f09 100644 --- a/src/tailscale/__init__.py +++ b/src/tailscale/__init__.py @@ -6,6 +6,7 @@ TailscaleError, ) from .models import ClientConnectivity, ClientSupports, Device, Devices +from .storage import TokenStorage from .tailscale import Tailscale __all__ = [ @@ -17,4 +18,5 @@ "TailscaleAuthenticationError", "TailscaleConnectionError", "TailscaleError", + "TokenStorage", ] diff --git a/src/tailscale/storage.py b/src/tailscale/storage.py new file mode 100644 index 00000000..36cf50ba --- /dev/null +++ b/src/tailscale/storage.py @@ -0,0 +1,29 @@ +"""Abstract token storage.""" + +from abc import ABC, abstractmethod +from datetime import datetime + + +class TokenStorage(ABC): + """Abstract class for token storage implementations.""" + + @abstractmethod + async def get_token(self) -> tuple[str, datetime] | None: + """Get the stored token. + + Returns: + The stored token and expiration time, or None if no token is stored. + + """ + raise NotImplementedError + + @abstractmethod + async def set_token(self, access_token: str, expires_at: datetime) -> None: + """Store the given token. + + Args: + access_token: The access token to store. + expires_at: The expiration time of the access token. + + """ + raise NotImplementedError diff --git a/src/tailscale/tailscale.py b/src/tailscale/tailscale.py index ef43f387..6319a3e3 100644 --- a/src/tailscale/tailscale.py +++ b/src/tailscale/tailscale.py @@ -6,7 +6,8 @@ import json import socket from dataclasses import dataclass -from typing import Any, Self +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any, Self from aiohttp.client import ClientError, ClientResponseError, ClientSession from aiohttp.hdrs import METH_GET, METH_POST @@ -19,6 +20,9 @@ ) from .models import Device, Devices +if TYPE_CHECKING: + from .storage import TokenStorage + @dataclass # pylint: disable-next=too-many-instance-attributes @@ -33,6 +37,7 @@ class Tailscale: request_timeout: int = 8 session: ClientSession | None = None + token_storage: TokenStorage | None = None _get_oauth_token_task: asyncio.Task[None] | None = None _expire_oauth_token_task: asyncio.Task[None] | None = None @@ -79,13 +84,25 @@ async def _check_api_key(self) -> None: await self._get_oauth_token_task async def _get_oauth_token(self) -> None: - """Get an OAuth token from the Tailscale API. + """Get an OAuth token from the Tailscale API or token storage. Raises: TailscaleAuthenticationError: when access token not found in response or access token expires in less than 5 minutes. """ + if self.token_storage: + token_data = await self.token_storage.get_token() + if token_data: + access_token, expires_at = token_data + expires_in = (expires_at - datetime.now(timezone.utc)).total_seconds() + if expires_in > 60: + self._expire_oauth_token_task = asyncio.create_task( + self._expire_oauth_token(expires_in) + ) + self.api_key = access_token + return + # Tailscale's OAuth endpoint requires form-encoded body # with client_id and client_secret data = { @@ -113,6 +130,9 @@ async def _get_oauth_token(self) -> None: self._expire_oauth_token_task = asyncio.create_task( self._expire_oauth_token(expires_in) ) + if self.token_storage: + expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + await self.token_storage.set_token(access_token, expires_at) self.api_key = access_token async def _expire_oauth_token(self, expires_in: float) -> None: diff --git a/tests/storage.py b/tests/storage.py new file mode 100644 index 00000000..e17114c5 --- /dev/null +++ b/tests/storage.py @@ -0,0 +1,27 @@ +"""Dummy token storage.""" + +from datetime import datetime + +from tailscale.storage import TokenStorage + + +class InMemoryTokenStorage(TokenStorage): + """In-memory token storage for testing purposes.""" + + def __init__( + self, access_token: str | None = None, expires_at: datetime | None = None + ) -> None: + """Initialize the in-memory token storage.""" + self._access_token = access_token + self._expires_at = expires_at + + async def get_token(self) -> tuple[str, datetime] | None: + """Get the stored token.""" + if self._access_token and self._expires_at: + return self._access_token, self._expires_at + return None + + async def set_token(self, access_token: str, expires_at: datetime) -> None: + """Store the token.""" + self._access_token = access_token + self._expires_at = expires_at diff --git a/tests/test_tailscale.py b/tests/test_tailscale.py index bb21fe80..f31fc758 100644 --- a/tests/test_tailscale.py +++ b/tests/test_tailscale.py @@ -2,6 +2,7 @@ # pylint: disable=protected-access import asyncio +from datetime import datetime, timedelta, timezone import aiohttp import pytest @@ -13,6 +14,7 @@ TailscaleConnectionError, TailscaleError, ) +from tests.storage import InMemoryTokenStorage async def test_wrong_arguments_no_auth() -> None: @@ -251,6 +253,78 @@ async def test_oauth_key_expiration(aresponses: ResponsesMockServer) -> None: aresponses.assert_plan_strictly_followed() +async def test_key_from_storage(aresponses: ResponsesMockServer) -> None: + """Test oauth key is loaded from storage.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + token_storage=InMemoryTokenStorage( + "stored-token", datetime.now(timezone.utc) + timedelta(hours=1) + ), + ) + await tailscale._request("test") + first_request = aresponses.history[0].request + assert "Bearer" in first_request.headers["Authorization"] + assert "stored-token" in first_request.headers["Authorization"] + await tailscale.close() + + +async def test_drop_key_from_storage(aresponses: ResponsesMockServer) -> None: + """Test oauth key response is handled correctly.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 3600}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + async with aiohttp.ClientSession() as session: + token_storage = InMemoryTokenStorage( + "stored-token", datetime.now(timezone.utc) + timedelta(seconds=30) + ) + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + token_storage=token_storage, + ) + await tailscale._request("test") + second_request = aresponses.history[1].request + assert "Bearer" in second_request.headers["Authorization"] + assert "short-lived-token" in second_request.headers["Authorization"] + assert token_storage._access_token == "short-lived-token" # noqa: S105 + await tailscale.close() + + aresponses.assert_plan_strictly_followed() + + async def test_bad_oauth(aresponses: ResponsesMockServer) -> None: """Test bad oauth response is handled correctly.""" aresponses.add(