Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions .github/workflows/pypi-publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: Publish Python Package

on:
push:
tags:
- 'v*'

jobs:
build-and-publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build pytest
pip install -e .
- name: Run tests
run: pytest
- name: Build package
run: python -m build
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}

24 changes: 24 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: CI

on:
push:
branches: [main, master]
pull_request:

jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .
pip install pytest
- name: Run tests
run: pytest

31 changes: 31 additions & 0 deletions tests/test_internal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import streamlit as st
from streamlit_oauth import _generate_state, _generate_pkce_pair
import pytest


def test_generate_state_same_key():
st.session_state.clear()
s1 = _generate_state(key="a")
s2 = _generate_state(key="a")
assert s1 == s2


def test_generate_state_different_key():
st.session_state.clear()
s1 = _generate_state(key="a")
s2 = _generate_state(key="b")
assert s1 != s2


def test_generate_pkce_pair_same_key():
st.session_state.clear()
p1 = _generate_pkce_pair("S256", key="x")
p2 = _generate_pkce_pair("S256", key="x")
assert p1 == p2
assert len(p1) == 2


def test_generate_pkce_pair_invalid():
st.session_state.clear()
with pytest.raises(Exception):
_generate_pkce_pair("plain", key="x")
60 changes: 60 additions & 0 deletions tests/test_oauth_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import time
import streamlit as st
import pytest
from unittest.mock import AsyncMock

from streamlit_oauth import OAuth2Component, OAuth2, StreamlitOauthError


def test_authorize_button_success(monkeypatch):
st.session_state.clear()
client = OAuth2("id", "secret", "auth", "token")
oauth = OAuth2Component(client=client)

# Mock async client methods
monkeypatch.setattr(oauth.client, "get_authorization_url", AsyncMock(return_value="http://auth"))
monkeypatch.setattr(oauth.client, "get_access_token", AsyncMock(return_value={"access_token": "tok"}))

# Force deterministic state and component output
monkeypatch.setattr("streamlit_oauth._generate_state", lambda key=None: "STATE")
monkeypatch.setattr("streamlit_oauth._authorize_button", lambda **kwargs: {"code": "CODE", "state": "STATE"})

result = oauth.authorize_button("Login", "http://cb", "scope", key="k")
assert result["token"]["access_token"] == "tok"
assert f"state-k" not in st.session_state


def test_authorize_button_state_mismatch(monkeypatch):
st.session_state.clear()
client = OAuth2("id", "secret", "auth", "token")
oauth = OAuth2Component(client=client)

monkeypatch.setattr(oauth.client, "get_authorization_url", AsyncMock(return_value="http://auth"))
monkeypatch.setattr(oauth.client, "get_access_token", AsyncMock(return_value={"access_token": "tok"}))
monkeypatch.setattr("streamlit_oauth._generate_state", lambda key=None: "GOOD")
monkeypatch.setattr("streamlit_oauth._authorize_button", lambda **kwargs: {"code": "CODE", "state": "BAD"})

with pytest.raises(StreamlitOauthError):
oauth.authorize_button("Login", "http://cb", "scope", key="k")


def test_refresh_token_expired(monkeypatch):
client = OAuth2("id", "secret", "auth", "token")
oauth = OAuth2Component(client=client)

monkeypatch.setattr(oauth.client, "refresh_token", AsyncMock(return_value={"access_token": "new"}))

token = {"access_token": "old", "refresh_token": "r", "expires_at": time.time() - 1}
result = oauth.refresh_token(token)

assert result["access_token"] == "new"


def test_revoke_token(monkeypatch):
client = OAuth2("id", "secret", "auth", "token")
oauth = OAuth2Component(client=client)
revoke_mock = AsyncMock()
monkeypatch.setattr(oauth.client, "revoke_token", revoke_mock)

assert oauth.revoke_token({"access_token": "abc"}) is True
revoke_mock.assert_awaited_once()