diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml new file mode 100644 index 0000000..80a7ea5 --- /dev/null +++ b/.github/workflows/pypi-publish.yml @@ -0,0 +1,30 @@ +name: Publish Python Package + +on: + push: + tags: + - 'v*' + +jobs: + build-and-publish: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install build pytest + pip install -e . + - name: Run tests + run: pytest + - name: Build package + run: python -m build + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_API_TOKEN }} + diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..1c13b2a --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,24 @@ +name: CI + +on: + push: + branches: [main, master] + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e . + pip install pytest + - name: Run tests + run: pytest + diff --git a/tests/test_internal.py b/tests/test_internal.py new file mode 100644 index 0000000..6e627bf --- /dev/null +++ b/tests/test_internal.py @@ -0,0 +1,31 @@ +import streamlit as st +from streamlit_oauth import _generate_state, _generate_pkce_pair +import pytest + + +def test_generate_state_same_key(): + st.session_state.clear() + s1 = _generate_state(key="a") + s2 = _generate_state(key="a") + assert s1 == s2 + + +def test_generate_state_different_key(): + st.session_state.clear() + s1 = _generate_state(key="a") + s2 = _generate_state(key="b") + assert s1 != s2 + + +def test_generate_pkce_pair_same_key(): + st.session_state.clear() + p1 = _generate_pkce_pair("S256", key="x") + p2 = _generate_pkce_pair("S256", key="x") + assert p1 == p2 + assert len(p1) == 2 + + +def test_generate_pkce_pair_invalid(): + st.session_state.clear() + with pytest.raises(Exception): + _generate_pkce_pair("plain", key="x") diff --git a/tests/test_oauth_component.py b/tests/test_oauth_component.py new file mode 100644 index 0000000..ae3cdfb --- /dev/null +++ b/tests/test_oauth_component.py @@ -0,0 +1,60 @@ +import time +import streamlit as st +import pytest +from unittest.mock import AsyncMock + +from streamlit_oauth import OAuth2Component, OAuth2, StreamlitOauthError + + +def test_authorize_button_success(monkeypatch): + st.session_state.clear() + client = OAuth2("id", "secret", "auth", "token") + oauth = OAuth2Component(client=client) + + # Mock async client methods + monkeypatch.setattr(oauth.client, "get_authorization_url", AsyncMock(return_value="http://auth")) + monkeypatch.setattr(oauth.client, "get_access_token", AsyncMock(return_value={"access_token": "tok"})) + + # Force deterministic state and component output + monkeypatch.setattr("streamlit_oauth._generate_state", lambda key=None: "STATE") + monkeypatch.setattr("streamlit_oauth._authorize_button", lambda **kwargs: {"code": "CODE", "state": "STATE"}) + + result = oauth.authorize_button("Login", "http://cb", "scope", key="k") + assert result["token"]["access_token"] == "tok" + assert f"state-k" not in st.session_state + + +def test_authorize_button_state_mismatch(monkeypatch): + st.session_state.clear() + client = OAuth2("id", "secret", "auth", "token") + oauth = OAuth2Component(client=client) + + monkeypatch.setattr(oauth.client, "get_authorization_url", AsyncMock(return_value="http://auth")) + monkeypatch.setattr(oauth.client, "get_access_token", AsyncMock(return_value={"access_token": "tok"})) + monkeypatch.setattr("streamlit_oauth._generate_state", lambda key=None: "GOOD") + monkeypatch.setattr("streamlit_oauth._authorize_button", lambda **kwargs: {"code": "CODE", "state": "BAD"}) + + with pytest.raises(StreamlitOauthError): + oauth.authorize_button("Login", "http://cb", "scope", key="k") + + +def test_refresh_token_expired(monkeypatch): + client = OAuth2("id", "secret", "auth", "token") + oauth = OAuth2Component(client=client) + + monkeypatch.setattr(oauth.client, "refresh_token", AsyncMock(return_value={"access_token": "new"})) + + token = {"access_token": "old", "refresh_token": "r", "expires_at": time.time() - 1} + result = oauth.refresh_token(token) + + assert result["access_token"] == "new" + + +def test_revoke_token(monkeypatch): + client = OAuth2("id", "secret", "auth", "token") + oauth = OAuth2Component(client=client) + revoke_mock = AsyncMock() + monkeypatch.setattr(oauth.client, "revoke_token", revoke_mock) + + assert oauth.revoke_token({"access_token": "abc"}) is True + revoke_mock.assert_awaited_once()