diff --git a/package.json b/package.json
index 8337e3d..99d4892 100644
--- a/package.json
+++ b/package.json
@@ -6,11 +6,13 @@
   "scripts": {
     "dev": "vite",
     "build": "tsc -b && vite build",
-    "preview": "vite preview"
+    "preview": "vite preview",
+    "test": "vitest"
   },
   "dependencies": {
     "@emotion/react": "^11.14.0",
     "@emotion/styled": "^11.14.1",
+    "@huggingface/hub": "^2.4.0",
     "@mlc-ai/web-llm": "^0.2.79",
     "@mui/icons-material": "^7.2.0",
     "@mui/material": "^7.2.0",
@@ -23,10 +25,14 @@
     "react-syntax-highlighter": "^15.6.1"
   },
   "devDependencies": {
+    "@testing-library/jest-dom": "^6.6.4",
+    "@testing-library/react": "^16.3.0",
     "@types/react-dom": "^19.1.7",
     "@vitejs/plugin-react": "^4.7.0",
     "@webgpu/types": "^0.1.64",
+    "jsdom": "^26.1.0",
     "typescript": "~5.8.3",
-    "vite": "^7.0.4"
+    "vite": "^7.0.4",
+    "vitest": "^3.2.4"
   }
 }
diff --git a/src/App.tsx b/src/App.tsx
index d9c716a..e7c9c06 100644
--- a/src/App.tsx
+++ b/src/App.tsx
@@ -1,6 +1,8 @@
-import {downloadModel, sendPrompt} from "./LLM.ts";
+import {downloadModel, interrupt, sendPrompt} from "./LLM.ts";
 import {useEffect, useState} from "react";
 import {useTypedDispatch, useTypedSelector} from "./redux/store.ts";
+import {getCompatibleModels} from "./huggingface.ts";
+import {setModels} from "./redux/llmSlice.ts";
 import {
     AppBar,
     Box,
@@ -19,18 +21,21 @@
 import {Send} from "@mui/icons-material";
 import Markdown from "react-markdown";
 import {setCriticalError} from "./redux/llmSlice.ts";
 import {isWebGPUok} from "./CheckWebGPU.ts";
-
-const MODEL = 'Llama-3.2-1B-Instruct-q4f16_1-MLC';
-const MODEL_SIZE_MB = 664;
+import {DOWNLOADED_MODELS_KEY} from "./constants.ts";
+import {ModelSelector} from "./ModelSelector.tsx";
 
 export function App() {
-    const {downloadStatus, messageHistory, criticalError} = useTypedSelector(state => state.llm);
+    const {downloadStatus, messageHistory, criticalError, isGenerating, selectedModel} = useTypedSelector(state => state.llm);
     const dispatch = useTypedDispatch();
     const [inputValue, setInputValue] = useState('');
     const [alreadyFromCache, setAlreadyFromCache] = useState(false);
     const [loadFinished, setLoadFinished] = useState(false);
     useEffect(() => {
+        getCompatibleModels().then(models => {
+            dispatch(setModels(models));
+        });
+
         isWebGPUok().then(trueOrError => {
             if (trueOrError !== true) {
                 dispatch(setCriticalError('WebGPU error: ' + trueOrError));
@@ -44,18 +49,18 @@ export function App() {
             navigator.storage.estimate().then(estimate => {
                 if (estimate) {
                     const remainingMb = (estimate.quota - estimate.usage) / 1024 / 1024;
-                    if (!alreadyFromCache && remainingMb > 10 && remainingMb < MODEL_SIZE_MB) {
-                        dispatch(setCriticalError('Remaining cache storage, that browser allowed is too low'));
-                    }
+                    // if (!alreadyFromCache && remainingMb > 10 && remainingMb < MODEL_SIZE_MB) {
+                    //     dispatch(setCriticalError('Remaining cache storage, that browser allowed is too low'));
+                    // }
                 }
             });
         } else {
             dispatch(setCriticalError('StorageManager API is not supported in your browser'));
         }
 
-        if (localStorage.getItem('downloaded_models')) {
+        if (localStorage.getItem(DOWNLOADED_MODELS_KEY)) {
             setAlreadyFromCache(true);
-            downloadModel(MODEL).then(() => setLoadFinished(true));
+            downloadModel(selectedModel).then(() => setLoadFinished(true));
         }
     }, []);
 
@@ -90,10 +95,11 @@ export function App() {

                         <Typography>
                             Browser LLM demo working on JavaScript and WebGPU
                         </Typography>
                 {!alreadyFromCache && !loadFinished && !criticalError && (
-                    <Button
-                        onClick={() => downloadModel(MODEL).then(() => setLoadFinished(true))}>Download
-                        Model</Button>
+                    <>
+                        <ModelSelector/>
+                        <Button
+                            onClick={() => downloadModel(selectedModel).then(() => setLoadFinished(true))}>Download
+                            Model</Button>
+                    </>
                 )}
                 <Typography>Loading model: {downloadStatus}</Typography>
@@ -113,8 +119,7 @@ export function App() {
                         >
                             {message.role}:
-                            {/* @ts-ignore */}
-                            <Markdown>{message.content}</Markdown>
+                            <Markdown>{message.content || ''}</Markdown>
                         ))}
@@ -147,10 +152,15 @@ export function App() {
                         sx={{ml: 1, flex: 1}}
                         InputProps={{disableUnderline: true}}
                     />
-                    <IconButton type="submit">
-                        <Send/>
-                    </IconButton>
+                    {isGenerating ? (
+                        <Button onClick={interrupt}>Stop</Button>
+                    ) : (
+                        <IconButton type="submit">
+                            <Send/>
+                        </IconButton>
+                    )}
+                    {isGenerating && <Typography>Generating response...</Typography>}
                 )}
diff --git a/src/CheckWebGPU.test.ts b/src/CheckWebGPU.test.ts
new file mode 100644
index 0000000..e48744f
--- /dev/null
+++ b/src/CheckWebGPU.test.ts
@@ -0,0 +1,47 @@
+import {isWebGPUok} from './CheckWebGPU.ts';
+import {describe, test, expect, vi} from 'vitest';
+
+describe('isWebGPUok', () => {
+    test('should return an error message if WebGPU is not supported', async () => {
+        // In jsdom, navigator.gpu is undefined, so this should fail.
+        const result = await isWebGPUok();
+        expect(result).toBe('WebGPU is NOT supported on this browser.');
+    });
+
+    test('should return an error message if adapter is not found', async () => {
+        // Mock navigator.gpu but make requestAdapter return null
+        vi.stubGlobal('navigator', {
+            gpu: {
+                requestAdapter: async () => null,
+            },
+        });
+
+        const result = await isWebGPUok();
+        expect(result).toBe('WebGPU Adapter not found.');
+    });
+
+    test('should return true if WebGPU is supported and working', async () => {
+        // A more complete mock of the WebGPU API
+        const mockDevice = {
+            createShaderModule: vi.fn(() => ({
+                getCompilationInfo: vi.fn().mockResolvedValue({
+                    messages: [],
+                }),
+            })),
+        };
+
+        const mockAdapter = {
+            requestDevice: async () => mockDevice,
+            features: new Set(['shader-f16']),
+        };
+
+        vi.stubGlobal('navigator', {
+            gpu: {
+                requestAdapter: async () => mockAdapter,
+            },
+        });
+
+        const result = await isWebGPUok();
+        expect(result).toBe(true);
+    });
+});
diff --git a/src/LLM.ts b/src/LLM.ts
index c38491e..d9db6de 100644
--- a/src/LLM.ts
+++ b/src/LLM.ts
@@ -1,7 +1,15 @@
 import type {ChatCompletionMessageParam} from "@mlc-ai/web-llm/lib/openai_api_protocols/chat_completion";
 import type {MLCEngine} from "@mlc-ai/web-llm";
-import {setDownloadStatus, setMessageHistory, setCriticalError} from "./redux/llmSlice.ts";
+import {
+    setDownloadStatus,
+    setCriticalError,
+    addUserMessage,
+    addBotMessage,
+    updateLastBotMessageContent,
+    setIsGenerating
+} from "./redux/llmSlice.ts";
 import {dispatch, getState} from "./redux/store.ts";
+import {DOWNLOADED_MODELS_KEY} from "./constants.ts";
 
 let libraryCache: any = null;
 
@@ -42,43 +50,51 @@ export async function downloadModel(name: string) {
         console.error(error);
         return;
     }
+
     dispatch(setDownloadStatus('done'));
-    localStorage.setItem('downloaded_models', JSON.stringify([name]));
+    localStorage.setItem(DOWNLOADED_MODELS_KEY, JSON.stringify([name]));
 }
 
 export async function sendPrompt(message: string, maxTokens = 1000) {
-    const messagesHistory = getState(state => state.llm.messageHistory);
-    const newUserMessage: ChatCompletionMessageParam = {role: 'user', content: message};
-    let updatedHistory = [...messagesHistory, newUserMessage];
-    dispatch(setMessageHistory(updatedHistory));
-
     if (!model) {
         throw new Error("Model not loaded");
     }
-
-    const stream = await model.chat.completions.create({
-        messages: updatedHistory,
-        stream: true,
-        max_tokens: maxTokens,
-    });
-    const response: ChatCompletionMessageParam = {
-        role: "assistant",
-        content: ""
-    };
-    updatedHistory = [...updatedHistory, response];
-    dispatch(setMessageHistory(updatedHistory));
+    dispatch(setIsGenerating(true));
 
-    for await (const chunk of stream) {
-        const delta = chunk?.choices?.[0]?.delta?.content ?? "";
-        if (delta) {
-            const current = getState(state => state.llm.messageHistory);
-            const updated = [...current];
-            const lastIndex = updated.length - 1;
-            updated[lastIndex] = {
-                ...updated[lastIndex],
-                content: updated[lastIndex].content + delta
-            };
-            dispatch(setMessageHistory(updated));
+    const newUserMessage: ChatCompletionMessageParam = {role: 'user', content: message};
+    dispatch(addUserMessage(newUserMessage));
+
+    const messagesHistory = getState().llm.messageHistory;
+
+    try {
+        const stream = await model.chat.completions.create({
+            messages: messagesHistory,
+            stream: true,
+            max_tokens: maxTokens,
+        });
+
+        const botMessage: ChatCompletionMessageParam = {
+            role: "assistant",
+            content: ""
+        };
+        dispatch(addBotMessage(botMessage));
+
+        for await (const chunk of stream) {
+            const delta = chunk?.choices?.[0]?.delta?.content ?? "";
+            if (delta) {
+                dispatch(updateLastBotMessageContent(delta));
+            }
         }
+    } catch (e) {
+        console.error(e)
+    } finally {
+        dispatch(setIsGenerating(false));
+    }
+}
+
+export function interrupt() {
+    if (model) {
+        model.interrupt();
     }
 }
diff --git a/src/ModelSelector.tsx b/src/ModelSelector.tsx
new file mode 100644
index 0000000..3eafb8f
--- /dev/null
+++ b/src/ModelSelector.tsx
@@ -0,0 +1,25 @@
+import {FormControl, InputLabel, Select, MenuItem} from "@mui/material";
+import {useTypedDispatch, useTypedSelector} from "./redux/store.ts";
+import {setSelectedModel} from "./redux/llmSlice.ts";
+
+export function ModelSelector() {
+    const dispatch = useTypedDispatch();
+    const {selectedModel, models} = useTypedSelector(state => state.llm);
+
+    return (
+        <FormControl>
+            <InputLabel>Model</InputLabel>
+            <Select
+                value={selectedModel}
+                label="Model"
+                onChange={(e) => dispatch(setSelectedModel(e.target.value))}
+            >
+                {models.map(model => (
+                    <MenuItem key={model.id} value={model.id}>
+                        {model.name}
+                    </MenuItem>
+                ))}
+            </Select>
+        </FormControl>
+    );
+}
diff --git a/src/constants.ts b/src/constants.ts
new file mode 100644
index 0000000..08044ff
--- /dev/null
+++ b/src/constants.ts
@@ -0,0 +1 @@
+export const DOWNLOADED_MODELS_KEY = 'downloaded_models';
diff --git a/src/huggingface.ts b/src/huggingface.ts
new file mode 100644
index 0000000..27716f0
--- /dev/null
+++ b/src/huggingface.ts
@@ -0,0 +1,13 @@
+import {listModels} from "@huggingface/hub";
+
+export async function getCompatibleModels() {
+    const models = [];
+    for await (const model of listModels({
+        search: {
+            tags: ['gguf'],
+        }
+    })) {
+        models.push(model);
+    }
+    return models;
+}
diff --git a/src/redux/llmSlice.ts b/src/redux/llmSlice.ts
index ab6418e..553fe3f 100644
--- a/src/redux/llmSlice.ts
+++ b/src/redux/llmSlice.ts
@@ -1,27 +1,40 @@
-import {createSlice} from '@reduxjs/toolkit'
+import {createSlice, PayloadAction} from '@reduxjs/toolkit'
 import type {ChatCompletionMessageParam} from "@mlc-ai/web-llm/lib/openai_api_protocols/chat_completion";
+import {ModelEntry} from "@huggingface/hub";
 
 type State = {
     messageHistory: ChatCompletionMessageParam[],
     criticalError: string | false,
     downloadStatus: string,
+    isGenerating: boolean,
+    models: ModelEntry[],
+    selectedModel: string,
 }
 
 const initialState: State = {
     messageHistory: [],
     criticalError: false,
     downloadStatus: 'waiting',
+    isGenerating: false,
+    models: [],
+    selectedModel: '',
 }
 
 export const llmSlice = createSlice({
     name: 'llm',
     initialState,
     reducers: {
-        setMessageHistory: (
-            state,
-            {payload}: { payload: ChatCompletionMessageParam[] }
-        ) => {
-            state.messageHistory = payload;
+        addUserMessage: (state, action: PayloadAction<ChatCompletionMessageParam>) => {
+            state.messageHistory.push(action.payload);
+        },
+        addBotMessage: (state, action: PayloadAction<ChatCompletionMessageParam>) => {
+            state.messageHistory.push(action.payload);
+        },
+        updateLastBotMessageContent: (state, action: PayloadAction<string>) => {
+            const lastMessage = state.messageHistory[state.messageHistory.length - 1];
+            if (lastMessage && lastMessage.role === 'assistant') {
+                lastMessage.content += action.payload;
+            }
         },
         setDownloadStatus: (
             state,
@@ -35,8 +48,29 @@ export const llmSlice = createSlice({
         ) => {
             state.criticalError = payload;
         },
+        setIsGenerating: (state, action: PayloadAction<boolean>) => {
+            state.isGenerating = action.payload;
+        },
+        setModels: (state, action: PayloadAction<ModelEntry[]>) => {
+            state.models = action.payload;
+            if (action.payload.length > 0 && !state.selectedModel) {
+                state.selectedModel = action.payload[0].id;
+            }
+        },
+        setSelectedModel: (state, action: PayloadAction<string>) => {
+            state.selectedModel = action.payload;
+        }
     },
 })
 
-export const {setMessageHistory, setDownloadStatus, setCriticalError} = llmSlice.actions;
+export const {
+    addUserMessage,
+    addBotMessage,
+    updateLastBotMessageContent,
+    setDownloadStatus,
+    setCriticalError,
+    setIsGenerating,
+    setModels,
+    setSelectedModel
+} = llmSlice.actions;
 export const llmReducer = llmSlice.reducer;
\ No newline at end of file
diff --git a/src/test/setup.ts b/src/test/setup.ts
new file mode 100644
index 0000000..7b0828b
--- /dev/null
+++ b/src/test/setup.ts
@@ -0,0 +1 @@
+import '@testing-library/jest-dom';
diff --git a/vite.config.ts b/vite.config.ts
index eb50b36..3a92bf3 100644
--- a/vite.config.ts
+++ b/vite.config.ts
@@ -1,3 +1,4 @@
+/// <reference types="vitest" />
 import {defineConfig} from 'vite'
 import react from '@vitejs/plugin-react'
 
@@ -8,4 +9,9 @@ export default defineConfig({
     },
     base: 'browser-llm',
     plugins: [react()],
+    test: {
+        globals: true,
+        environment: 'jsdom',
+        setupFiles: './src/test/setup.ts',
+    },
 })