diff --git a/examples/devices.py b/examples/devices.py old mode 100644 new mode 100755 index 156abfc6..29723565 --- a/examples/devices.py +++ b/examples/devices.py @@ -1,16 +1,22 @@ +#!/usr/bin/env python3 # pylint: disable=W0621 """Asynchronous client for the Tailscale API.""" import asyncio +import os from tailscale import Tailscale +# "-" is the default tailnet of the API key +TAILNET = os.environ.get("TS_TAILNET", "-") +API_KEY = os.environ.get("TS_API_KEY", "") + async def main() -> None: """Show example on using the Tailscale API client.""" async with Tailscale( - tailnet="frenck", - api_key="tskey-somethingsomething", + tailnet=TAILNET, + api_key=API_KEY, ) as tailscale: devices = await tailscale.devices() print(devices) diff --git a/examples/oauth.py b/examples/oauth.py new file mode 100755 index 00000000..edd188b9 --- /dev/null +++ b/examples/oauth.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# pylint: disable=W0621 +"""Asynchronous client for the Tailscale API.""" + +import asyncio +import os + +from tailscale import Tailscale + +# "-" is the default tailnet of the API key +TAILNET = os.environ.get("TS_TAILNET", "-") + +# OAuth client ID and secret are required for OAuth authentication +OAUTH_CLIENT_ID = os.environ.get("TS_API_CLIENT_ID", "") +OAUTH_CLIENT_SECRET = os.environ.get("TS_API_CLIENT_SECRET", "") + + +async def main_oauth() -> None: + """Show example on using the Tailscale API client with OAuth.""" + async with Tailscale( + tailnet=TAILNET, + oauth_client_id=OAUTH_CLIENT_ID, + oauth_client_secret=OAUTH_CLIENT_SECRET, + ) as tailscale: + devices = await tailscale.devices() + print(devices) + + +if __name__ == "__main__": + asyncio.run(main_oauth()) diff --git a/src/tailscale/__init__.py b/src/tailscale/__init__.py index 5000b29a..a2382f09 100644 --- a/src/tailscale/__init__.py +++ b/src/tailscale/__init__.py @@ -6,6 +6,7 @@ TailscaleError, ) from .models import ClientConnectivity, ClientSupports, Device, Devices +from .storage import TokenStorage from .tailscale import Tailscale __all__ = [ @@ -17,4 +18,5 @@ "TailscaleAuthenticationError", "TailscaleConnectionError", "TailscaleError", + "TokenStorage", ] diff --git a/src/tailscale/storage.py b/src/tailscale/storage.py new file mode 100644 index 00000000..36cf50ba --- /dev/null +++ b/src/tailscale/storage.py @@ -0,0 +1,29 @@ +"""Abstract token storage.""" + +from abc import ABC, abstractmethod +from datetime import datetime + + +class TokenStorage(ABC): + """Abstract class for token storage implementations.""" + + @abstractmethod + async def get_token(self) -> tuple[str, datetime] | None: + """Get the stored token. + + Returns: + The stored token and expiration time, or None if no token is stored. + + """ + raise NotImplementedError + + @abstractmethod + async def set_token(self, access_token: str, expires_at: datetime) -> None: + """Store the given token. + + Args: + access_token: The access token to store. + expires_at: The expiration time of the access token. + + """ + raise NotImplementedError diff --git a/src/tailscale/tailscale.py b/src/tailscale/tailscale.py index 305a52ea..6319a3e3 100644 --- a/src/tailscale/tailscale.py +++ b/src/tailscale/tailscale.py @@ -3,13 +3,14 @@ from __future__ import annotations import asyncio +import json import socket from dataclasses import dataclass -from typing import Any, Self +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any, Self -from aiohttp import BasicAuth from aiohttp.client import ClientError, ClientResponseError, ClientSession -from aiohttp.hdrs import METH_GET +from aiohttp.hdrs import METH_GET, METH_POST from yarl import URL from .exceptions import ( @@ -19,25 +20,136 @@ ) from .models import Device, Devices +if TYPE_CHECKING: + from .storage import TokenStorage + @dataclass +# pylint: disable-next=too-many-instance-attributes class Tailscale: """Main class for handling connections with the Tailscale API.""" - tailnet: str - api_key: str + # tailnet of '-' is the default tailnet of the API key + tailnet: str = "-" + api_key: str | None = None + oauth_client_id: str | None = None + oauth_client_secret: str | None = None request_timeout: int = 8 session: ClientSession | None = None + token_storage: TokenStorage | None = None + _get_oauth_token_task: asyncio.Task[None] | None = None + _expire_oauth_token_task: asyncio.Task[None] | None = None _close_session: bool = False + async def _check_api_key(self) -> None: + """Initialize the Tailscale client. + + Raises: + TailscaleAuthenticationError: when neither api_key nor oauth_client_id and + oauth_client_secret are provided. + + """ + if not ( + (self.api_key and not self.oauth_client_id and not self.oauth_client_secret) + or (not self.api_key and self.oauth_client_id and self.oauth_client_secret) + or ( + self.api_key + and self.oauth_client_id + and self.oauth_client_secret + and self._get_oauth_token_task + ) + ): + msg = ( + "Either api_key or oauth_client_id and oauth_client_secret " + "are required when Tailscale client is initialized" + ) + raise TailscaleAuthenticationError(msg) + if not self.api_key: + # Handle some inconsistent state + # possibly caused by user manually deleting api_key + if self._expire_oauth_token_task: + self._expire_oauth_token_task.cancel() + self._expire_oauth_token_task = None + if self._get_oauth_token_task: + self._get_oauth_token_task.cancel() + self._get_oauth_token_task = None + # Get a new OAuth token if not already in the process of getting one + if not self._get_oauth_token_task: + self._get_oauth_token_task = asyncio.create_task( + self._get_oauth_token() + ) + # Wait for the OAuth token to be retrieved + await self._get_oauth_token_task + + async def _get_oauth_token(self) -> None: + """Get an OAuth token from the Tailscale API or token storage. + + Raises: + TailscaleAuthenticationError: when access token not found in response or + access token expires in less than 5 minutes. + + """ + if self.token_storage: + token_data = await self.token_storage.get_token() + if token_data: + access_token, expires_at = token_data + expires_in = (expires_at - datetime.now(timezone.utc)).total_seconds() + if expires_in > 60: + self._expire_oauth_token_task = asyncio.create_task( + self._expire_oauth_token(expires_in) + ) + self.api_key = access_token + return + + # Tailscale's OAuth endpoint requires form-encoded body + # with client_id and client_secret + data = { + "client_id": self.oauth_client_id, + "client_secret": self.oauth_client_secret, + } + response = await self._request( + "oauth/token", + data=data, + method=METH_POST, + _use_authentication=False, + _use_form_encoding=True, + ) + + json_response = json.loads(response) + access_token = str(json_response.get("access_token", "")) + expires_in = float(json_response.get("expires_in", 0)) + if not access_token or not expires_in: + msg = "Failed to get OAuth token" + raise TailscaleAuthenticationError(msg) + if expires_in <= 60: + msg = "OAuth token expires in less than 1 minute" + raise TailscaleAuthenticationError(msg) + + self._expire_oauth_token_task = asyncio.create_task( + self._expire_oauth_token(expires_in) + ) + if self.token_storage: + expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + await self.token_storage.set_token(access_token, expires_at) + self.api_key = access_token + + async def _expire_oauth_token(self, expires_in: float) -> None: + """Expires the OAuth token 1 minute before expiration.""" + await asyncio.sleep(expires_in - 60) + self.api_key = None + self._get_oauth_token_task = None + self._expire_oauth_token_task = None + async def _request( self, uri: str, *, method: str = METH_GET, data: dict[str, Any] | None = None, + _use_authentication: bool = True, + _use_form_encoding: bool = False, ) -> str: """Handle a request to the Tailscale API. @@ -52,8 +164,7 @@ async def _request( Returns: ------- - A Python dictionary (JSON decoded) with the response from - the Tailscale API. + The response from the Tailscale API. Raises: ------ @@ -66,22 +177,28 @@ async def _request( """ url = URL("https://api.tailscale.com/api/v2/").join(URL(uri)) - headers = { + headers: dict[str, str] = { "Accept": "application/json", } + if _use_authentication: + await self._check_api_key() + # API keys and oauth tokens can use Bearer authentication + headers["Authorization"] = f"Bearer {self.api_key}" + if self.session is None: self.session = ClientSession() self._close_session = True try: async with asyncio.timeout(self.request_timeout): + # Use form encoding for OAuth token requests, JSON for others response = await self.session.request( method, url, - json=data, - auth=BasicAuth(self.api_key), - headers=headers, + headers=headers if headers else None, + data=data if _use_form_encoding else None, + json=data if not _use_form_encoding else None, ) response.raise_for_status() except asyncio.TimeoutError as exception: @@ -89,6 +206,13 @@ async def _request( raise TailscaleConnectionError(msg) from exception except ClientResponseError as exception: if exception.status in [401, 403]: + if _use_authentication and self.api_key and self.oauth_client_id: + # Invalidate the current OAuth token + self.api_key = None + self._get_oauth_token_task = None + if self._expire_oauth_token_task: + self._expire_oauth_token_task.cancel() + self._expire_oauth_token_task = None msg = "Authentication to the Tailscale API failed" raise TailscaleAuthenticationError(msg) from exception msg = "Error occurred while connecting to the Tailscale API" @@ -114,9 +238,13 @@ async def devices(self) -> dict[str, Device]: return Devices.from_json(data).devices async def close(self) -> None: - """Close open client session.""" + """Close open client session and cancel tasks.""" if self.session and self._close_session: await self.session.close() + if self._get_oauth_token_task: + self._get_oauth_token_task.cancel() + if self._expire_oauth_token_task: + self._expire_oauth_token_task.cancel() async def __aenter__(self) -> Self: """Async enter. diff --git a/tests/storage.py b/tests/storage.py new file mode 100644 index 00000000..e17114c5 --- /dev/null +++ b/tests/storage.py @@ -0,0 +1,27 @@ +"""Dummy token storage.""" + +from datetime import datetime + +from tailscale.storage import TokenStorage + + +class InMemoryTokenStorage(TokenStorage): + """In-memory token storage for testing purposes.""" + + def __init__( + self, access_token: str | None = None, expires_at: datetime | None = None + ) -> None: + """Initialize the in-memory token storage.""" + self._access_token = access_token + self._expires_at = expires_at + + async def get_token(self) -> tuple[str, datetime] | None: + """Get the stored token.""" + if self._access_token and self._expires_at: + return self._access_token, self._expires_at + return None + + async def set_token(self, access_token: str, expires_at: datetime) -> None: + """Store the token.""" + self._access_token = access_token + self._expires_at = expires_at diff --git a/tests/test_tailscale.py b/tests/test_tailscale.py index d3276819..f31fc758 100644 --- a/tests/test_tailscale.py +++ b/tests/test_tailscale.py @@ -2,6 +2,7 @@ # pylint: disable=protected-access import asyncio +from datetime import datetime, timedelta, timezone import aiohttp import pytest @@ -13,6 +14,373 @@ TailscaleConnectionError, TailscaleError, ) +from tests.storage import InMemoryTokenStorage + + +async def test_wrong_arguments_no_auth() -> None: + """Test api key or oauth key is checked correctly.""" + async with Tailscale() as tailscale: + with pytest.raises(TailscaleAuthenticationError) as excinfo: + assert await tailscale._request("test") + + assert excinfo.value.args[0] == ( + "Either api_key or oauth_client_id and oauth_client_secret " + "are required when Tailscale client is initialized" + ) + + +async def test_wrong_arguments_both_auth() -> None: + """Test api key or oauth key is checked correctly.""" + async with Tailscale( + api_key="abc", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + ) as tailscale: + with pytest.raises(TailscaleAuthenticationError) as excinfo: + assert await tailscale._request("test") + + assert excinfo.value.args[0] == ( + "Either api_key or oauth_client_id and oauth_client_secret " + "are required when Tailscale client is initialized" + ) + + +async def test_wrong_arguments_partial_oauth() -> None: + """Test api key or oauth key is checked correctly.""" + async with Tailscale( + oauth_client_id="client", + ) as tailscale: + with pytest.raises(TailscaleAuthenticationError) as excinfo: + assert await tailscale._request("test") + + assert excinfo.value.args[0] == ( + "Either api_key or oauth_client_id and oauth_client_secret " + "are required when Tailscale client is initialized" + ) + + +async def test_key_from_oauth(aresponses: ResponsesMockServer) -> None: + """Test oauth key response is handled correctly.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 3600}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + await tailscale._request("test") + second_request = aresponses.history[1].request + assert "Bearer" in second_request.headers["Authorization"] + await tailscale.close() + + aresponses.assert_plan_strictly_followed() + + +async def test_key_from_oauth_with_race_condition( + aresponses: ResponsesMockServer, +) -> None: + """Test oauth key request is sent out only once.""" + + async def oauth_handler(_: aiohttp.ClientResponse) -> Response: + """Response handler emulating slow oauth response.""" + await asyncio.sleep(1) + return aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 3600}', + ) + + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + oauth_handler, + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + + first_task = asyncio.create_task(tailscale._request("test")) + second_task = asyncio.create_task(tailscale._request("test")) + await asyncio.gather(first_task, second_task) + await tailscale.close() + + aresponses.assert_plan_strictly_followed() + + +async def test_new_key_from_oauth_on_manual_invalidation( + aresponses: ResponsesMockServer, +) -> None: + """Test oauth key manual invalidation is handled correctly.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 3600}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 3600}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + await tailscale._request("test") + tailscale.api_key = None # Manual invalidation + await tailscale._request("test") + await tailscale.close() + + aresponses.assert_plan_strictly_followed() + + +async def test_oauth_key_expiration(aresponses: ResponsesMockServer) -> None: + """Test oauth key expiration.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 61}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + await tailscale._request("test") + assert tailscale.api_key == "short-lived-token" + assert tailscale._get_oauth_token_task is not None + assert tailscale._expire_oauth_token_task is not None + await asyncio.sleep(2) + assert tailscale.api_key is None + assert tailscale._get_oauth_token_task is None + assert tailscale._expire_oauth_token_task is None + await tailscale.close() + + aresponses.assert_plan_strictly_followed() + + +async def test_key_from_storage(aresponses: ResponsesMockServer) -> None: + """Test oauth key is loaded from storage.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + token_storage=InMemoryTokenStorage( + "stored-token", datetime.now(timezone.utc) + timedelta(hours=1) + ), + ) + await tailscale._request("test") + first_request = aresponses.history[0].request + assert "Bearer" in first_request.headers["Authorization"] + assert "stored-token" in first_request.headers["Authorization"] + await tailscale.close() + + +async def test_drop_key_from_storage(aresponses: ResponsesMockServer) -> None: + """Test oauth key response is handled correctly.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 3600}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + async with aiohttp.ClientSession() as session: + token_storage = InMemoryTokenStorage( + "stored-token", datetime.now(timezone.utc) + timedelta(seconds=30) + ) + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + token_storage=token_storage, + ) + await tailscale._request("test") + second_request = aresponses.history[1].request + assert "Bearer" in second_request.headers["Authorization"] + assert "short-lived-token" in second_request.headers["Authorization"] + assert token_storage._access_token == "short-lived-token" # noqa: S105 + await tailscale.close() + + aresponses.assert_plan_strictly_followed() + + +async def test_bad_oauth(aresponses: ResponsesMockServer) -> None: + """Test bad oauth response is handled correctly.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"no_access_token": "unauthorized"}', + ), + ) + + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + with pytest.raises(TailscaleAuthenticationError) as excinfo: + assert await tailscale._request("test") + + assert excinfo.value.args[0] == "Failed to get OAuth token" + + await tailscale.close() + + aresponses.assert_plan_strictly_followed() + + +async def test_too_short_oauth_expiration(aresponses: ResponsesMockServer) -> None: + """Test too short oauth expiration is handled correctly.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 60}', + ), + ) + + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + with pytest.raises(TailscaleAuthenticationError) as excinfo: + assert await tailscale._request("test") + + assert excinfo.value.args[0] == "OAuth token expires in less than 1 minute" + + await tailscale.close() async def test_json_request(aresponses: ResponsesMockServer) -> None: @@ -123,3 +491,43 @@ async def test_http_error401(aresponses: ResponsesMockServer) -> None: tailscale = Tailscale(tailnet="frenck", api_key="abc", session=session) with pytest.raises(TailscaleAuthenticationError): assert await tailscale._request("test") + + +async def test_http_error401_and_oauth_token_invalidation( + aresponses: ResponsesMockServer, +) -> None: + """Test HTTP 401 response handling and oauth token invalidation.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token", "expires_in": 3600}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response(text="Access denied!", status=401), + ) + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + with pytest.raises(TailscaleAuthenticationError) as excinfo: + assert await tailscale._request("test") + + assert excinfo.value.args[0] == "Authentication to the Tailscale API failed" + assert tailscale.api_key is None + assert tailscale._get_oauth_token_task is None + assert tailscale._expire_oauth_token_task is None + + await tailscale.close() + + aresponses.assert_plan_strictly_followed()