From 541e0a21800c800bdab3e7e23ec4b97a23416abc Mon Sep 17 00:00:00 2001 From: Clayton Rosenthal Date: Wed, 19 Jul 2023 23:58:54 -0700 Subject: [PATCH] new-feature: oauth support --- examples/devices.py | 10 ++++- examples/oauth.py | 30 +++++++++++++++ src/tailscale/tailscale.py | 78 ++++++++++++++++++++++++++++++++++---- tests/test_tailscale.py | 78 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 186 insertions(+), 10 deletions(-) mode change 100644 => 100755 examples/devices.py create mode 100755 examples/oauth.py diff --git a/examples/devices.py b/examples/devices.py old mode 100644 new mode 100755 index 156abfc6..29723565 --- a/examples/devices.py +++ b/examples/devices.py @@ -1,16 +1,22 @@ +#!/usr/bin/env python3 # pylint: disable=W0621 """Asynchronous client for the Tailscale API.""" import asyncio +import os from tailscale import Tailscale +# "-" is the default tailnet of the API key +TAILNET = os.environ.get("TS_TAILNET", "-") +API_KEY = os.environ.get("TS_API_KEY", "") + async def main() -> None: """Show example on using the Tailscale API client.""" async with Tailscale( - tailnet="frenck", - api_key="tskey-somethingsomething", + tailnet=TAILNET, + api_key=API_KEY, ) as tailscale: devices = await tailscale.devices() print(devices) diff --git a/examples/oauth.py b/examples/oauth.py new file mode 100755 index 00000000..edd188b9 --- /dev/null +++ b/examples/oauth.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# pylint: disable=W0621 +"""Asynchronous client for the Tailscale API.""" + +import asyncio +import os + +from tailscale import Tailscale + +# "-" is the default tailnet of the API key +TAILNET = os.environ.get("TS_TAILNET", "-") + +# OAuth client ID and secret are required for OAuth authentication +OAUTH_CLIENT_ID = os.environ.get("TS_API_CLIENT_ID", "") +OAUTH_CLIENT_SECRET = os.environ.get("TS_API_CLIENT_SECRET", "") + + +async def main_oauth() -> None: + """Show example on using the Tailscale API client with OAuth.""" + async with Tailscale( + tailnet=TAILNET, + oauth_client_id=OAUTH_CLIENT_ID, + oauth_client_secret=OAUTH_CLIENT_SECRET, + ) as tailscale: + devices = await tailscale.devices() + print(devices) + + +if __name__ == "__main__": + asyncio.run(main_oauth()) diff --git a/src/tailscale/tailscale.py b/src/tailscale/tailscale.py index 305a52ea..4daadeed 100644 --- a/src/tailscale/tailscale.py +++ b/src/tailscale/tailscale.py @@ -3,13 +3,13 @@ from __future__ import annotations import asyncio +import json import socket from dataclasses import dataclass from typing import Any, Self -from aiohttp import BasicAuth from aiohttp.client import ClientError, ClientResponseError, ClientSession -from aiohttp.hdrs import METH_GET +from aiohttp.hdrs import METH_GET, METH_POST from yarl import URL from .exceptions import ( @@ -19,25 +19,80 @@ ) from .models import Device, Devices +# Placeholder value for the access token when it is not yet set. +ACCESS_TOKEN_PENDING = "" # noqa: S105 + @dataclass class Tailscale: """Main class for handling connections with the Tailscale API.""" - tailnet: str - api_key: str + # tailnet of '-' is the default tailnet of the API key + tailnet: str = "-" + api_key: str = "" # nosec + oauth_client_id: str = "" # nosec + oauth_client_secret: str = "" # nosec request_timeout: int = 8 session: ClientSession | None = None _close_session: bool = False + async def _check_access(self) -> None: + """Initialize the Tailscale client. + + Raises: + TailscaleAuthenticationError: when neither api_key nor oauth_client_id and + oauth_client_secret are provided. + + """ + if ( + not self.api_key + and not self.oauth_client_id + and not self.oauth_client_secret + ): + msg = "Either api_key or oauth client is required" + raise TailscaleAuthenticationError(msg) + if not self.api_key: + self.api_key = ACCESS_TOKEN_PENDING + self.api_key = await self._get_oauth_token() + + async def _get_oauth_token(self) -> str: + """Get an OAuth token from the Tailscale API. + + Raises: + TailscaleAuthenticationError: when access key not found in response. + + Returns: + A string with the OAuth token, or nothing on error + + """ + # Tailscale's OAuth endpoint requires form-encoded body + # with client_id and client_secret + data = { + "client_id": self.oauth_client_id, + "client_secret": self.oauth_client_secret, + } + response = await self._request( + "oauth/token", + data=data, + method=METH_POST, + use_form_encoding=True, + ) + + token = json.loads(response).get("access_token", "") + if not token: + msg = "Failed to get OAuth token" + raise TailscaleAuthenticationError(msg) + return str(token) + async def _request( self, uri: str, *, method: str = METH_GET, data: dict[str, Any] | None = None, + use_form_encoding: bool = False, ) -> str: """Handle a request to the Tailscale API. @@ -66,22 +121,29 @@ async def _request( """ url = URL("https://api.tailscale.com/api/v2/").join(URL(uri)) - headers = { + await self._check_access() + + headers: dict[str, str] = { "Accept": "application/json", } + if self.api_key and self.api_key != ACCESS_TOKEN_PENDING: + # API keys and oauth tokens can use Bearer authentication + headers["Authorization"] = f"Bearer {self.api_key}" + if self.session is None: self.session = ClientSession() self._close_session = True try: async with asyncio.timeout(self.request_timeout): + # Use form encoding for OAuth token requests, JSON for others response = await self.session.request( method, url, - json=data, - auth=BasicAuth(self.api_key), - headers=headers, + headers=headers if headers else None, + data=data if use_form_encoding else None, + json=data if not use_form_encoding else None, ) response.raise_for_status() except asyncio.TimeoutError as exception: diff --git a/tests/test_tailscale.py b/tests/test_tailscale.py index d3276819..68935118 100644 --- a/tests/test_tailscale.py +++ b/tests/test_tailscale.py @@ -15,6 +15,84 @@ ) +@pytest.mark.asyncio +async def test_no_access() -> None: + """Test api key or oauth key is checked correctly.""" + async with Tailscale(tailnet="frenck") as tailscale: + with pytest.raises(TailscaleAuthenticationError): + assert await tailscale._request("test") + + +@pytest.mark.asyncio +async def test_key_from_oauth(aresponses: ResponsesMockServer) -> None: + """Test oauth key response is handled correctly.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"access_token": "short-lived-token"}', + ), + ) + aresponses.add( + "api.tailscale.com", + "/api/v2/test", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"status": "ok"}', + ), + ) + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", # nosec + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + await tailscale._request("test") + second_request = aresponses.history[1].request + assert "Bearer" in second_request.headers["Authorization"] + await tailscale.close() + + aresponses.assert_plan_strictly_followed() + + +@pytest.mark.asyncio +async def test_bad_oauth(aresponses: ResponsesMockServer) -> None: + """Test bad oauth error is handled correctly.""" + aresponses.add( + "api.tailscale.com", + "/api/v2/oauth/token", + "POST", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text='{"no_access_token": "unauthorized"}', + ), + ) + + async with aiohttp.ClientSession() as session: + tailscale = Tailscale( + tailnet="frenck", + oauth_client_id="client", # nosec + oauth_client_secret="notsosecret", # noqa: S106 + session=session, + ) + with pytest.raises(TailscaleAuthenticationError) as excinfo: + assert await tailscale._request("test") + + assert excinfo.value.args[0] == "Failed to get OAuth token" + + await tailscale.close() + + aresponses.assert_plan_strictly_followed() + + +@pytest.mark.asyncio async def test_json_request(aresponses: ResponsesMockServer) -> None: """Test JSON response is handled correctly.""" aresponses.add(