10 changes: 8 additions & 2 deletions package.json
@@ -6,11 +6,13 @@
   "scripts": {
     "dev": "vite",
     "build": "tsc -b && vite build",
-    "preview": "vite preview"
+    "preview": "vite preview",
+    "test": "vitest"
   },
   "dependencies": {
     "@emotion/react": "^11.14.0",
     "@emotion/styled": "^11.14.1",
+    "@huggingface/hub": "^2.4.0",
     "@mlc-ai/web-llm": "^0.2.79",
     "@mui/icons-material": "^7.2.0",
     "@mui/material": "^7.2.0",
@@ -23,10 +25,14 @@
     "react-syntax-highlighter": "^15.6.1"
   },
   "devDependencies": {
+    "@testing-library/jest-dom": "^6.6.4",
+    "@testing-library/react": "^16.3.0",
     "@types/react-dom": "^19.1.7",
     "@vitejs/plugin-react": "^4.7.0",
     "@webgpu/types": "^0.1.64",
+    "jsdom": "^26.1.0",
     "typescript": "~5.8.3",
-    "vite": "^7.0.4"
+    "vite": "^7.0.4",
+    "vitest": "^3.2.4"
   }
 }
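The new devDependencies pair vitest with jsdom and Testing Library, but no vitest configuration appears in this diff. A minimal sketch of how these packages typically fit together (an assumption, not taken from the PR):

```ts
// vite.config.ts (hypothetical; the PR's actual config is not shown)
import {defineConfig} from 'vitest/config';
import react from '@vitejs/plugin-react';

export default defineConfig({
    plugins: [react()],
    test: {
        environment: 'jsdom',              // jsdom provides the DOM for component tests
        setupFiles: './src/setupTests.ts', // e.g. imports '@testing-library/jest-dom'
    },
});
```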
46 changes: 28 additions & 18 deletions src/App.tsx
@@ -1,6 +1,8 @@
-import {downloadModel, sendPrompt} from "./LLM.ts";
+import {downloadModel, interrupt, sendPrompt} from "./LLM.ts";
 import {useEffect, useState} from "react";
 import {useTypedDispatch, useTypedSelector} from "./redux/store.ts";
+import {getCompatibleModels} from "./huggingface.ts";
+import {setModels} from "./redux/llmSlice.ts";
 import {
     AppBar,
     Box,
@@ -19,18 +21,21 @@ import {Send} from "@mui/icons-material";
 import Markdown from "react-markdown";
 import {setCriticalError} from "./redux/llmSlice.ts";
 import {isWebGPUok} from "./CheckWebGPU.ts";
-
-const MODEL = 'Llama-3.2-1B-Instruct-q4f16_1-MLC';
-const MODEL_SIZE_MB = 664;
+import {DOWNLOADED_MODELS_KEY} from "./constants.ts";
+import {ModelSelector} from "./ModelSelector.tsx";
 
 export function App() {
-    const {downloadStatus, messageHistory, criticalError} = useTypedSelector(state => state.llm);
+    const {downloadStatus, messageHistory, criticalError, isGenerating, selectedModel} = useTypedSelector(state => state.llm);
     const dispatch = useTypedDispatch();
     const [inputValue, setInputValue] = useState('');
    const [alreadyFromCache, setAlreadyFromCache] = useState(false);
     const [loadFinished, setLoadFinished] = useState(false);
 
     useEffect(() => {
+        getCompatibleModels().then(models => {
+            dispatch(setModels(models));
+        });
+
         isWebGPUok().then(trueOrError => {
             if (trueOrError !== true) {
                 dispatch(setCriticalError('WebGPU error: ' + trueOrError));
@@ -44,18 +49,18 @@ export function App() {
             navigator.storage.estimate().then(estimate => {
                 if (estimate) {
                     const remainingMb = (estimate.quota - estimate.usage) / 1024 / 1024;
-                    if (!alreadyFromCache && remainingMb > 10 && remainingMb < MODEL_SIZE_MB) {
-                        dispatch(setCriticalError('Remaining cache storage, that browser allowed is too low'));
-                    }
+                    // if (!alreadyFromCache && remainingMb > 10 && remainingMb < MODEL_SIZE_MB) {
+                    //     dispatch(setCriticalError('Remaining cache storage, that browser allowed is too low'));
+                    // }
                 }
             });
         } else {
             dispatch(setCriticalError('StorageManager API is not supported in your browser'));
         }
 
-        if (localStorage.getItem('downloaded_models')) {
+        if (localStorage.getItem(DOWNLOADED_MODELS_KEY)) {
             setAlreadyFromCache(true);
-            downloadModel(MODEL).then(() => setLoadFinished(true));
+            downloadModel(selectedModel).then(() => setLoadFinished(true));
         }
 
     }, []);
@@ -90,10 +95,11 @@ export function App() {
             <h1>Browser LLM demo working on JavaScript and WebGPU</h1>
             <Box sx={{flexGrow: 1, overflowY: 'auto', py: 2}}>
                 {!alreadyFromCache && !loadFinished && !criticalError && (
-                    <Box sx={{textAlign: 'center', mb: 2}}>
+                    <Box sx={{textAlign: 'center', mb: 2, display: 'flex', flexDirection: 'column', gap: 2}}>
+                        <ModelSelector/>
                         <Button variant="contained" color="primary"
-                                onClick={() => downloadModel(MODEL).then(() => setLoadFinished(true))}>Download
-                            Model ({MODEL_SIZE_MB}MB)</Button>
+                                onClick={() => downloadModel(selectedModel).then(() => setLoadFinished(true))}>Download
+                            Model</Button>
                     </Box>
                 )}
                 <Typography>Loading model: {downloadStatus}</Typography>
@@ -113,8 +119,7 @@ export function App() {
                     >
                         <Typography
                             variant="body2" sx={{color: 'text.secondary', mb: 0.5}}>{message.role}:</Typography>
-                        {/* @ts-ignore */}
-                        <Markdown>{message.content}</Markdown>
+                        <Markdown>{message.content || ''}</Markdown>
                     </Paper>
                 ))}
             </Box>
@@ -147,10 +152,15 @@ export function App() {
                         sx={{ml: 1, flex: 1}}
                         InputProps={{disableUnderline: true}}
                     />
-                    <IconButton type="submit" sx={{p: '10px'}} aria-label="send">
-                        <Send/>
-                    </IconButton>
+                    {isGenerating ? (
+                        <Button onClick={() => interrupt()}>Stop</Button>
+                    ) : (
+                        <IconButton type="submit" sx={{p: '10px'}} aria-label="send" disabled={!inputValue}>
+                            <Send/>
+                        </IconButton>
+                    )}
                 </Paper>
+                {isGenerating && <Typography sx={{textAlign: 'center', mt: 1}}>Generating response...</Typography>}
             </Box>
         )}
     </Container>
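Testing Library is added above, but no component test for App appears in this diff. A smoke test along these lines would exercise the new wiring; it is only a sketch, and it assumes ./redux/store.ts exports the configured store and that the jest-dom matchers are loaded in the test setup file:

```tsx
// src/App.test.tsx (hypothetical, not part of this PR)
import {render, screen} from '@testing-library/react';
import {describe, expect, test} from 'vitest';
import {Provider} from 'react-redux';
import {App} from './App.tsx';
import {store} from './redux/store.ts';

describe('App', () => {
    test('renders the heading', () => {
        // The useEffect network calls will reject under jsdom;
        // the static layout should still render.
        render(
            <Provider store={store}>
                <App/>
            </Provider>
        );
        expect(screen.getByText(/Browser LLM demo/)).toBeInTheDocument();
    });
});
```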
47 changes: 47 additions & 0 deletions src/CheckWebGPU.test.ts
@@ -0,0 +1,47 @@
+import {isWebGPUok} from './CheckWebGPU.ts';
+import {describe, test, expect, vi} from 'vitest';
+
+describe('isWebGPUok', () => {
+    test('should return an error message if WebGPU is not supported', async () => {
+        // In jsdom, navigator.gpu is undefined, so this should fail.
+        const result = await isWebGPUok();
+        expect(result).toBe('WebGPU is NOT supported on this browser.');
+    });
+
+    test('should return an error message if adapter is not found', async () => {
+        // Mock navigator.gpu but make requestAdapter return null
+        vi.stubGlobal('navigator', {
+            gpu: {
+                requestAdapter: async () => null,
+            },
+        });
+
+        const result = await isWebGPUok();
+        expect(result).toBe('WebGPU Adapter not found.');
+    });
+
+    test('should return true if WebGPU is supported and working', async () => {
+        // A more complete mock of the WebGPU API
+        const mockDevice = {
+            createShaderModule: vi.fn(() => ({
+                getCompilationInfo: vi.fn().mockResolvedValue({
+                    messages: [],
+                }),
+            })),
+        };
+
+        const mockAdapter = {
+            requestDevice: async () => mockDevice,
+            features: new Set(['shader-f16']),
+        };
+
+        vi.stubGlobal('navigator', {
+            gpu: {
+                requestAdapter: async () => mockAdapter,
+            },
+        });
+
+        const result = await isWebGPUok();
+        expect(result).toBe(true);
+    });
+});
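One caveat: vi.stubGlobal replaces navigator for the rest of the file, so these tests pass only when run in source order (the first test needs the unstubbed jsdom navigator). A restore step would make them order-independent; this addition is not in the PR but uses a documented vitest API:

```ts
import {afterEach, vi} from 'vitest';

// Undo vi.stubGlobal after each test so every case starts from the
// plain jsdom environment, where navigator.gpu is undefined.
afterEach(() => {
    vi.unstubAllGlobals();
});
```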
74 changes: 45 additions & 29 deletions src/LLM.ts
@@ -1,7 +1,15 @@
 import type {ChatCompletionMessageParam} from "@mlc-ai/web-llm/lib/openai_api_protocols/chat_completion";
 import type {MLCEngine} from "@mlc-ai/web-llm";
-import {setDownloadStatus, setMessageHistory, setCriticalError} from "./redux/llmSlice.ts";
+import {
+    setDownloadStatus,
+    setCriticalError,
+    addUserMessage,
+    addBotMessage,
+    updateLastBotMessageContent,
+    setIsGenerating
+} from "./redux/llmSlice.ts";
 import {dispatch, getState} from "./redux/store.ts";
+import {DOWNLOADED_MODELS_KEY} from "./constants.ts";
 
 let libraryCache: any = null;
 
@@ -42,43 +50,51 @@ export async function downloadModel(name: string) {
         console.error(error);
         return;
     }
 
     dispatch(setDownloadStatus('done'));
-    localStorage.setItem('downloaded_models', JSON.stringify([name]));
+    localStorage.setItem(DOWNLOADED_MODELS_KEY, JSON.stringify([name]));
 }
 
 export async function sendPrompt(message: string, maxTokens = 1000) {
-    const messagesHistory = getState(state => state.llm.messageHistory);
-    const newUserMessage: ChatCompletionMessageParam = {role: 'user', content: message};
-    let updatedHistory = [...messagesHistory, newUserMessage];
-    dispatch(setMessageHistory(updatedHistory));
-
     if (!model) {
         throw new Error("Model not loaded");
     }
 
-    const stream = await model.chat.completions.create({
-        messages: updatedHistory,
-        stream: true,
-        max_tokens: maxTokens,
-    });
-    const response: ChatCompletionMessageParam = {
-        role: "assistant",
-        content: ""
-    };
-    updatedHistory = [...updatedHistory, response];
-    dispatch(setMessageHistory(updatedHistory));
+    dispatch(setIsGenerating(true));
 
-    for await (const chunk of stream) {
-        const delta = chunk?.choices?.[0]?.delta?.content ?? "";
-        if (delta) {
-            const current = getState(state => state.llm.messageHistory);
-            const updated = [...current];
-            const lastIndex = updated.length - 1;
-            updated[lastIndex] = {
-                ...updated[lastIndex],
-                content: updated[lastIndex].content + delta
-            };
-            dispatch(setMessageHistory(updated));
-        }
-    }
+    const newUserMessage: ChatCompletionMessageParam = {role: 'user', content: message};
+    dispatch(addUserMessage(newUserMessage));
+
+    const messagesHistory = getState().llm.messageHistory;
+
+    try {
+        const stream = await model.chat.completions.create({
+            messages: messagesHistory,
+            stream: true,
+            max_tokens: maxTokens,
+        });
+
+        const botMessage: ChatCompletionMessageParam = {
+            role: "assistant",
+            content: ""
+        };
+        dispatch(addBotMessage(botMessage));
+
+        for await (const chunk of stream) {
+            const delta = chunk?.choices?.[0]?.delta?.content ?? "";
+            if (delta) {
+                dispatch(updateLastBotMessageContent(delta));
+            }
        }
+    } catch (e) {
+        console.error(e)
+    } finally {
+        dispatch(setIsGenerating(false));
+    }
 }
+
+export function interrupt() {
+    if (model) {
+        model.interrupt();
+    }
+}
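The reducers imported here (addUserMessage, addBotMessage, updateLastBotMessageContent, setIsGenerating) live in llmSlice.ts, which is not shown in this diff. A rough sketch of what they presumably look like, with a simplified message type and assumed shapes, purely for orientation:

```ts
// Hypothetical reconstruction of the relevant llmSlice.ts reducers;
// the real slice is not part of this diff, so all shapes are assumed.
import {createSlice, type PayloadAction} from '@reduxjs/toolkit';

type Message = {role: 'user' | 'assistant'; content: string}; // simplified

const llmSlice = createSlice({
    name: 'llm',
    initialState: {messageHistory: [] as Message[], isGenerating: false},
    reducers: {
        addUserMessage(state, action: PayloadAction<Message>) {
            state.messageHistory.push(action.payload);
        },
        addBotMessage(state, action: PayloadAction<Message>) {
            state.messageHistory.push(action.payload);
        },
        // Appends a streamed delta to the most recent (assistant) message,
        // replacing the old copy-the-whole-history update per chunk.
        updateLastBotMessageContent(state, action: PayloadAction<string>) {
            const last = state.messageHistory[state.messageHistory.length - 1];
            if (last) last.content += action.payload;
        },
        setIsGenerating(state, action: PayloadAction<boolean>) {
            state.isGenerating = action.payload;
        },
    },
});
```

Dispatching only the delta per chunk is the point of the refactor: each streamed token touches one message instead of rebuilding the entire history array.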
25 changes: 25 additions & 0 deletions src/ModelSelector.tsx
@@ -0,0 +1,25 @@
+import {FormControl, InputLabel, Select, MenuItem} from "@mui/material";
+import {useTypedDispatch, useTypedSelector} from "./redux/store.ts";
+import {setSelectedModel} from "./redux/llmSlice.ts";
+
+export function ModelSelector() {
+    const dispatch = useTypedDispatch();
+    const {selectedModel, models} = useTypedSelector(state => state.llm);
+
+    return (
+        <FormControl fullWidth>
+            <InputLabel id="model-select-label">Model</InputLabel>
+            <Select
+                labelId="model-select-label"
+                id="model-select"
+                value={selectedModel}
+                label="Model"
+                onChange={(e) => dispatch(setSelectedModel(e.target.value))}
+            >
+                {models.map(model => (
+                    <MenuItem key={model.id} value={model.id}>{model.id}</MenuItem>
+                ))}
+            </Select>
+        </FormControl>
+    );
+}
1 change: 1 addition & 0 deletions src/constants.ts
@@ -0,0 +1 @@
+export const DOWNLOADED_MODELS_KEY = 'downloaded_models';
13 changes: 13 additions & 0 deletions src/huggingface.ts
@@ -0,0 +1,13 @@
+import {listModels} from "@huggingface/hub";
+
+export async function getCompatibleModels() {
+    const models = [];
+    for await (const model of listModels({
+        search: {
+            tags: ['gguf'],
+        }
+    })) {
+        models.push(model);
+    }
+    return models;
+}
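listModels is an async generator, and getCompatibleModels drains it completely, which can be slow against the full Hub catalog. A caller can stop early instead; this variant is hypothetical, not part of the PR:

```ts
import {listModels} from "@huggingface/hub";

// Hypothetical variant that stops paging once `max` models are collected.
export async function getFirstCompatibleModels(max = 20) {
    const models = [];
    for await (const model of listModels({
        search: {
            tags: ['gguf'],
        }
    })) {
        models.push(model);
        if (models.length >= max) break; // avoid walking the entire listing
    }
    return models;
}
```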