diff --git a/streamlit_oauth/__init__.py b/streamlit_oauth/__init__.py index 9335fa3..714e71a 100644 --- a/streamlit_oauth/__init__.py +++ b/streamlit_oauth/__init__.py @@ -134,7 +134,11 @@ def refresh_token(self, token, force=False): if token.get('refresh_token') is None: raise Exception("Token is expired and no refresh token is available") else: - token = asyncio.run(self.client.refresh_token(token.get('refresh_token'))) + new_token = asyncio.run(self.client.refresh_token(token.get('refresh_token'))) + # Keep the old refresh token if the new one is missing it + if not new_token.get('refresh_token'): + new_token['refresh_token'] = token.get('refresh_token') + token = new_token return token def revoke_token(self, token, token_type_hint="access_token"): diff --git a/tests/test_oauth_component.py b/tests/test_oauth_component.py index ae3cdfb..ca34fc2 100644 --- a/tests/test_oauth_component.py +++ b/tests/test_oauth_component.py @@ -48,6 +48,7 @@ def test_refresh_token_expired(monkeypatch): result = oauth.refresh_token(token) assert result["access_token"] == "new" + assert "refresh_token" in result def test_revoke_token(monkeypatch):