diff --git a/packages/router/src/__tests__/useBlocker.test.tsx b/packages/router/src/__tests__/useBlocker.test.tsx index 571f46ebe9..eac84864ff 100644 --- a/packages/router/src/__tests__/useBlocker.test.tsx +++ b/packages/router/src/__tests__/useBlocker.test.tsx @@ -97,4 +97,74 @@ describe('useBlocker', () => { gHistory.remove(listenerId) unmount() }) + + describe('when function', () => { + it('should initialize with IDLE state when using a function', () => { + const { result, unmount } = renderHook(() => + useBlocker({ when: () => false }), + ) + expect(result.current.state).toBe('IDLE') + unmount() + }) + + it('should block when function returns true', () => { + const whenFn = vi.fn(() => true) + const { result, unmount } = renderHook(() => useBlocker({ when: whenFn })) + + act(() => { + navigate('/blocked-path') + }) + + expect(whenFn).toHaveBeenCalled() + expect(result.current.state).toBe('BLOCKED') + unmount() + }) + + it('should not block when function returns false', () => { + const whenFn = vi.fn(() => false) + const { result, unmount } = renderHook(() => useBlocker({ when: whenFn })) + + act(() => { + navigate('/allowed-path') + }) + + expect(whenFn).toHaveBeenCalled() + expect(result.current.state).toBe('IDLE') + unmount() + }) + + it('should pass nextLocation to when function', () => { + const whenFn = vi.fn(() => true) + const { result, unmount } = renderHook(() => useBlocker({ when: whenFn })) + + act(() => { + navigate('/new-destination') + }) + + expect(whenFn).toHaveBeenCalledWith({ + nextLocation: '/new-destination', + }) + expect(result.current.state).toBe('BLOCKED') + unmount() + }) + + it('should block based on nextLocation', () => { + const whenFn = vi.fn(({ nextLocation }: { nextLocation: string }) => + nextLocation.startsWith('/protected'), + ) + const { result, unmount } = renderHook(() => useBlocker({ when: whenFn })) + + act(() => { + navigate('/allowed') + }) + expect(result.current.state).toBe('IDLE') + + act(() => { + navigate('/protected/page') + }) + expect(result.current.state).toBe('BLOCKED') + + unmount() + }) + }) }) diff --git a/packages/router/src/history.ts b/packages/router/src/history.ts index d746f29ef0..6950c6457c 100644 --- a/packages/router/src/history.ts +++ b/packages/router/src/history.ts @@ -5,7 +5,10 @@ export interface NavigateOptions { export type Listener = (ev?: PopStateEvent, options?: NavigateOptions) => any export type BeforeUnloadListener = (ev: BeforeUnloadEvent) => any -export type BlockerCallback = (tx: { retry: () => void }) => void +export type BlockerCallback = (tx: { + retry: () => void + nextLocation: string +}) => void export type Blocker = { id: string; callback: BlockerCallback } const createHistory = () => { @@ -51,7 +54,7 @@ const createHistory = () => { } if (blockers.length > 0) { - processBlockers(0, performNavigation) + processBlockers(0, performNavigation, to) } else { performNavigation() } @@ -65,7 +68,8 @@ const createHistory = () => { } if (blockers.length > 0) { - processBlockers(0, performBack) + // FIXME: for navigating back, we don't have the next location info + processBlockers(0, performBack, '') } else { performBack() } @@ -105,10 +109,15 @@ const createHistory = () => { }, } - const processBlockers = (index: number, navigate: () => void) => { + const processBlockers = ( + index: number, + navigate: () => void, + nextLocation: string, + ) => { if (index < blockers.length) { blockers[index].callback({ - retry: () => processBlockers(index + 1, navigate), + retry: () => processBlockers(index + 1, navigate, nextLocation), + nextLocation, }) } else { navigate() diff --git a/packages/router/src/useBlocker.ts b/packages/router/src/useBlocker.ts index 067f8a38ce..2a5883549d 100644 --- a/packages/router/src/useBlocker.ts +++ b/packages/router/src/useBlocker.ts @@ -5,8 +5,10 @@ import type { BlockerCallback } from './history.js' type BlockerState = 'IDLE' | 'BLOCKED' +type WhenFunction = (args: { nextLocation: string }) => boolean + interface UseBlockerOptions { - when: boolean + when: boolean | WhenFunction } export function useBlocker({ when }: UseBlockerOptions) { @@ -17,8 +19,11 @@ export function useBlocker({ when }: UseBlockerOptions) { const blockerId = useId() const blocker: BlockerCallback = useCallback( - ({ retry }) => { - if (when) { + ({ retry, nextLocation }) => { + const shouldBlock = + typeof when === 'function' ? when({ nextLocation }) : when + + if (shouldBlock) { setBlockerState('BLOCKED') setPendingNavigation(() => retry) } else { @@ -29,7 +34,8 @@ export function useBlocker({ when }: UseBlockerOptions) { ) useEffect(() => { - if (when) { + const shouldRegister = typeof when === 'function' || when + if (shouldRegister) { block(blockerId, blocker) } else { unblock(blockerId)