Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 92 additions & 79 deletions src/components/authkit-provider.tsx
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
'use client';

import React, { createContext, ReactNode, useCallback, useContext, useEffect, useState } from 'react';
import React, { createContext, ReactNode, useCallback, useContext, useEffect, useReducer } from 'react';
import {
checkSessionAction,
getAuthAction,
handleSignOutAction,
refreshAuthAction,
switchToOrganizationAction,
} from '../actions.js';
import type { Impersonator, User } from '@workos-inc/node';
import type { User } from '@workos-inc/node';
import type { UserInfo, SwitchToOrganizationOptions, NoUserInfo } from '../interfaces.js';

type AuthContextType = {
user: User | null;
sessionId: string | undefined;
organizationId: string | undefined;
role: string | undefined;
roles: string[] | undefined;
permissions: string[] | undefined;
entitlements: string[] | undefined;
featureFlags: string[] | undefined;
impersonator: Impersonator | undefined;
type AuthState =
| {
status: 'loading';
data: Omit<UserInfo | NoUserInfo, 'accessToken'>;
}
| { status: 'unauthenticated'; data: Omit<NoUserInfo, 'accessToken'> }
| { status: 'authenticated'; data: Omit<UserInfo, 'accessToken'> };

type AuthContextType = Omit<UserInfo | NoUserInfo, 'accessToken'> & {
loading: boolean;
getAuth: (options?: { ensureSignedIn?: boolean }) => Promise<void>;
refreshAuth: (options?: { ensureSignedIn?: boolean; organizationId?: string }) => Promise<void | { error: string }>;
Expand All @@ -46,43 +45,73 @@ interface AuthKitProviderProps {
initialAuth?: Omit<UserInfo | NoUserInfo, 'accessToken'>;
}

const unauthenticatedAuthStateData: Omit<NoUserInfo, 'accessToken'> = {
user: null,
sessionId: undefined,
organizationId: undefined,
role: undefined,
roles: undefined,
permissions: undefined,
entitlements: undefined,
featureFlags: undefined,
impersonator: undefined,
};

function initAuthState(initialAuth: Omit<UserInfo | NoUserInfo, 'accessToken'> | undefined): AuthState {
if (!initialAuth) {
return { status: 'loading', data: unauthenticatedAuthStateData };
}

if (!initialAuth.user) {
return { status: 'unauthenticated', data: initialAuth as Omit<NoUserInfo, 'accessToken'> };
}

return { status: 'authenticated', data: initialAuth as Omit<UserInfo, 'accessToken'> };
}

type AuthAction =
| { type: 'START_LOADING' }
| { type: 'SET_AUTH_STATE_AS_UNAUTHENTICATED'; data: Omit<NoUserInfo, 'accessToken'> }
| { type: 'SET_AUTH_STATE_AS_AUTHENTICATED'; data: Omit<UserInfo, 'accessToken'> }
| { type: 'STOP_LOADING' };

function authReducer(state: AuthState, action: AuthAction): AuthState {
switch (action.type) {
case 'START_LOADING':
return { status: 'loading', data: state.data };

case 'SET_AUTH_STATE_AS_AUTHENTICATED':
return { status: 'authenticated', data: action.data };

case 'SET_AUTH_STATE_AS_UNAUTHENTICATED':
return { status: 'unauthenticated', data: action.data };

case 'STOP_LOADING':
if (state.data.user) {
return { status: 'authenticated', data: state.data as Omit<UserInfo, 'accessToken'> };
}
return { status: 'unauthenticated', data: state.data as Omit<NoUserInfo, 'accessToken'> };

default:
return state;
}
}

export const AuthKitProvider = ({ children, onSessionExpired, initialAuth }: AuthKitProviderProps) => {
const [user, setUser] = useState<User | null>(initialAuth?.user ?? null);
const [sessionId, setSessionId] = useState<string | undefined>(initialAuth?.sessionId);
const [organizationId, setOrganizationId] = useState<string | undefined>(initialAuth?.organizationId);
const [role, setRole] = useState<string | undefined>(initialAuth?.role);
const [roles, setRoles] = useState<string[] | undefined>(initialAuth?.roles);
const [permissions, setPermissions] = useState<string[] | undefined>(initialAuth?.permissions);
const [entitlements, setEntitlements] = useState<string[] | undefined>(initialAuth?.entitlements);
const [featureFlags, setFeatureFlags] = useState<string[] | undefined>(initialAuth?.featureFlags);
const [impersonator, setImpersonator] = useState<Impersonator | undefined>(initialAuth?.impersonator);
const [loading, setLoading] = useState(!initialAuth);
const [authState, dispatch] = useReducer(authReducer, initialAuth, initAuthState);

const getAuth = useCallback(async ({ ensureSignedIn = false }: { ensureSignedIn?: boolean } = {}) => {
setLoading(true);
dispatch({ type: 'START_LOADING' });
try {
const auth = await getAuthAction({ ensureSignedIn });
setUser(auth.user);
setSessionId(auth.sessionId);
setOrganizationId(auth.organizationId);
setRole(auth.role);
setRoles(auth.roles);
setPermissions(auth.permissions);
setEntitlements(auth.entitlements);
setFeatureFlags(auth.featureFlags);
setImpersonator(auth.impersonator);

if (auth.user) {
dispatch({ type: 'SET_AUTH_STATE_AS_AUTHENTICATED', data: auth as Omit<UserInfo, 'accessToken'> });
} else {
dispatch({ type: 'SET_AUTH_STATE_AS_UNAUTHENTICATED', data: auth as Omit<NoUserInfo, 'accessToken'> });
}
} catch (error) {
setUser(null);
setSessionId(undefined);
setOrganizationId(undefined);
setRole(undefined);
setRoles(undefined);
setPermissions(undefined);
setEntitlements(undefined);
setFeatureFlags(undefined);
setImpersonator(undefined);
} finally {
setLoading(false);
dispatch({ type: 'SET_AUTH_STATE_AS_UNAUTHENTICATED', data: unauthenticatedAuthStateData });
}
}, []);

Expand All @@ -105,23 +134,18 @@ export const AuthKitProvider = ({ children, onSessionExpired, initialAuth }: Aut

const refreshAuth = useCallback(
async ({ ensureSignedIn = false, organizationId }: { ensureSignedIn?: boolean; organizationId?: string } = {}) => {
dispatch({ type: 'START_LOADING' });
try {
setLoading(true);
const auth = await refreshAuthAction({ ensureSignedIn, organizationId });

setUser(auth.user);
setSessionId(auth.sessionId);
setOrganizationId(auth.organizationId);
setRole(auth.role);
setRoles(auth.roles);
setPermissions(auth.permissions);
setEntitlements(auth.entitlements);
setFeatureFlags(auth.featureFlags);
setImpersonator(auth.impersonator);
if (auth.user) {
dispatch({ type: 'SET_AUTH_STATE_AS_AUTHENTICATED', data: auth as Omit<UserInfo, 'accessToken'> });
} else {
dispatch({ type: 'SET_AUTH_STATE_AS_UNAUTHENTICATED', data: auth as Omit<NoUserInfo, 'accessToken'> });
}
} catch (error) {
dispatch({ type: 'STOP_LOADING' });
return error instanceof Error ? { error: error.message } : { error: String(error) };
} finally {
setLoading(false);
Comment on lines -123 to -124
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this action call fails above then without this finally the loading state stays loading forever. We probably need to keep this in some form to reset the loading state.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

}
},
[],
Expand Down Expand Up @@ -182,36 +206,25 @@ export const AuthKitProvider = ({ children, onSessionExpired, initialAuth }: Aut
window.removeEventListener('focus', handleVisibilityChange);
window.removeEventListener('visibilitychange', handleVisibilityChange);
};
}, [onSessionExpired]);

return (
<AuthContext.Provider
value={{
user,
sessionId,
organizationId,
role,
roles,
permissions,
entitlements,
featureFlags,
impersonator,
loading,
getAuth,
refreshAuth,
signOut,
switchToOrganization,
}}
>
{children}
</AuthContext.Provider>
);
}, [onSessionExpired, initialAuth, getAuth]);

const contextValue: AuthContextType = {
...authState.data,
loading: authState.status === 'loading',
getAuth,
refreshAuth,
signOut,
switchToOrganization,
};

return <AuthContext.Provider value={contextValue}>{children}</AuthContext.Provider>;
};

export function useAuth(options: {
ensureSignedIn: true;
}): AuthContextType & ({ loading: true; user: User | null } | { loading: false; user: User });
export function useAuth(options?: { ensureSignedIn?: false }): AuthContextType;

export function useAuth({ ensureSignedIn = false }: { ensureSignedIn?: boolean } = {}) {
const context = useContext(AuthContext);

Expand Down