diff --git a/src/components/authkit-provider.tsx b/src/components/authkit-provider.tsx index 8d78dd0..048ad5f 100644 --- a/src/components/authkit-provider.tsx +++ b/src/components/authkit-provider.tsx @@ -1,6 +1,6 @@ 'use client'; -import React, { createContext, ReactNode, useCallback, useContext, useEffect, useState } from 'react'; +import React, { createContext, ReactNode, useCallback, useContext, useEffect, useReducer } from 'react'; import { checkSessionAction, getAuthAction, @@ -8,19 +8,18 @@ import { refreshAuthAction, switchToOrganizationAction, } from '../actions.js'; -import type { Impersonator, User } from '@workos-inc/node'; +import type { User } from '@workos-inc/node'; import type { UserInfo, SwitchToOrganizationOptions, NoUserInfo } from '../interfaces.js'; -type AuthContextType = { - user: User | null; - sessionId: string | undefined; - organizationId: string | undefined; - role: string | undefined; - roles: string[] | undefined; - permissions: string[] | undefined; - entitlements: string[] | undefined; - featureFlags: string[] | undefined; - impersonator: Impersonator | undefined; +type AuthState = + | { + status: 'loading'; + data: Omit; + } + | { status: 'unauthenticated'; data: Omit } + | { status: 'authenticated'; data: Omit }; + +type AuthContextType = Omit & { loading: boolean; getAuth: (options?: { ensureSignedIn?: boolean }) => Promise; refreshAuth: (options?: { ensureSignedIn?: boolean; organizationId?: string }) => Promise; @@ -46,43 +45,73 @@ interface AuthKitProviderProps { initialAuth?: Omit; } +const unauthenticatedAuthStateData: Omit = { + user: null, + sessionId: undefined, + organizationId: undefined, + role: undefined, + roles: undefined, + permissions: undefined, + entitlements: undefined, + featureFlags: undefined, + impersonator: undefined, +}; + +function initAuthState(initialAuth: Omit | undefined): AuthState { + if (!initialAuth) { + return { status: 'loading', data: unauthenticatedAuthStateData }; + } + + if (!initialAuth.user) { + return { status: 'unauthenticated', data: initialAuth as Omit }; + } + + return { status: 'authenticated', data: initialAuth as Omit }; +} + +type AuthAction = + | { type: 'START_LOADING' } + | { type: 'SET_AUTH_STATE_AS_UNAUTHENTICATED'; data: Omit } + | { type: 'SET_AUTH_STATE_AS_AUTHENTICATED'; data: Omit } + | { type: 'STOP_LOADING' }; + +function authReducer(state: AuthState, action: AuthAction): AuthState { + switch (action.type) { + case 'START_LOADING': + return { status: 'loading', data: state.data }; + + case 'SET_AUTH_STATE_AS_AUTHENTICATED': + return { status: 'authenticated', data: action.data }; + + case 'SET_AUTH_STATE_AS_UNAUTHENTICATED': + return { status: 'unauthenticated', data: action.data }; + + case 'STOP_LOADING': + if (state.data.user) { + return { status: 'authenticated', data: state.data as Omit }; + } + return { status: 'unauthenticated', data: state.data as Omit }; + + default: + return state; + } +} + export const AuthKitProvider = ({ children, onSessionExpired, initialAuth }: AuthKitProviderProps) => { - const [user, setUser] = useState(initialAuth?.user ?? null); - const [sessionId, setSessionId] = useState(initialAuth?.sessionId); - const [organizationId, setOrganizationId] = useState(initialAuth?.organizationId); - const [role, setRole] = useState(initialAuth?.role); - const [roles, setRoles] = useState(initialAuth?.roles); - const [permissions, setPermissions] = useState(initialAuth?.permissions); - const [entitlements, setEntitlements] = useState(initialAuth?.entitlements); - const [featureFlags, setFeatureFlags] = useState(initialAuth?.featureFlags); - const [impersonator, setImpersonator] = useState(initialAuth?.impersonator); - const [loading, setLoading] = useState(!initialAuth); + const [authState, dispatch] = useReducer(authReducer, initialAuth, initAuthState); const getAuth = useCallback(async ({ ensureSignedIn = false }: { ensureSignedIn?: boolean } = {}) => { - setLoading(true); + dispatch({ type: 'START_LOADING' }); try { const auth = await getAuthAction({ ensureSignedIn }); - setUser(auth.user); - setSessionId(auth.sessionId); - setOrganizationId(auth.organizationId); - setRole(auth.role); - setRoles(auth.roles); - setPermissions(auth.permissions); - setEntitlements(auth.entitlements); - setFeatureFlags(auth.featureFlags); - setImpersonator(auth.impersonator); + + if (auth.user) { + dispatch({ type: 'SET_AUTH_STATE_AS_AUTHENTICATED', data: auth as Omit }); + } else { + dispatch({ type: 'SET_AUTH_STATE_AS_UNAUTHENTICATED', data: auth as Omit }); + } } catch (error) { - setUser(null); - setSessionId(undefined); - setOrganizationId(undefined); - setRole(undefined); - setRoles(undefined); - setPermissions(undefined); - setEntitlements(undefined); - setFeatureFlags(undefined); - setImpersonator(undefined); - } finally { - setLoading(false); + dispatch({ type: 'SET_AUTH_STATE_AS_UNAUTHENTICATED', data: unauthenticatedAuthStateData }); } }, []); @@ -105,23 +134,18 @@ export const AuthKitProvider = ({ children, onSessionExpired, initialAuth }: Aut const refreshAuth = useCallback( async ({ ensureSignedIn = false, organizationId }: { ensureSignedIn?: boolean; organizationId?: string } = {}) => { + dispatch({ type: 'START_LOADING' }); try { - setLoading(true); const auth = await refreshAuthAction({ ensureSignedIn, organizationId }); - setUser(auth.user); - setSessionId(auth.sessionId); - setOrganizationId(auth.organizationId); - setRole(auth.role); - setRoles(auth.roles); - setPermissions(auth.permissions); - setEntitlements(auth.entitlements); - setFeatureFlags(auth.featureFlags); - setImpersonator(auth.impersonator); + if (auth.user) { + dispatch({ type: 'SET_AUTH_STATE_AS_AUTHENTICATED', data: auth as Omit }); + } else { + dispatch({ type: 'SET_AUTH_STATE_AS_UNAUTHENTICATED', data: auth as Omit }); + } } catch (error) { + dispatch({ type: 'STOP_LOADING' }); return error instanceof Error ? { error: error.message } : { error: String(error) }; - } finally { - setLoading(false); } }, [], @@ -182,36 +206,25 @@ export const AuthKitProvider = ({ children, onSessionExpired, initialAuth }: Aut window.removeEventListener('focus', handleVisibilityChange); window.removeEventListener('visibilitychange', handleVisibilityChange); }; - }, [onSessionExpired]); - - return ( - - {children} - - ); + }, [onSessionExpired, initialAuth, getAuth]); + + const contextValue: AuthContextType = { + ...authState.data, + loading: authState.status === 'loading', + getAuth, + refreshAuth, + signOut, + switchToOrganization, + }; + + return {children}; }; export function useAuth(options: { ensureSignedIn: true; }): AuthContextType & ({ loading: true; user: User | null } | { loading: false; user: User }); export function useAuth(options?: { ensureSignedIn?: false }): AuthContextType; + export function useAuth({ ensureSignedIn = false }: { ensureSignedIn?: boolean } = {}) { const context = useContext(AuthContext);