From ca42836b8d9db042374a4d370d5ccd2a37643ee5 Mon Sep 17 00:00:00 2001 From: littlestone <342206015@qq.com> Date: Thu, 3 Apr 2025 18:56:47 +0800 Subject: [PATCH] feat: add max tokens control --- components/Chat/Chat.tsx | 27 ++++++++++++--- components/Chat/MaxTokens.tsx | 64 +++++++++++++++++++++++++++++++++++ pages/api/chat.ts | 27 ++++++++++++--- pages/api/home/home.state.tsx | 5 ++- public/locales/zh/chat.json | 7 +++- types/chat.ts | 5 ++- utils/app/const.ts | 4 +++ utils/server/index.ts | 32 ++++++++++++------ 8 files changed, 148 insertions(+), 23 deletions(-) create mode 100644 components/Chat/MaxTokens.tsx diff --git a/components/Chat/Chat.tsx b/components/Chat/Chat.tsx index 2ad942a..0c26450 100644 --- a/components/Chat/Chat.tsx +++ b/components/Chat/Chat.tsx @@ -10,6 +10,8 @@ import { } from 'react'; import toast from 'react-hot-toast'; + + import { useTranslation } from 'next-i18next'; import { getEndpoint } from '@/utils/app/api'; @@ -23,16 +25,21 @@ import { throttle } from '@/utils/data/throttle'; import { ChatBody, Conversation, Message } from '@/types/chat'; import { Plugin } from '@/types/plugin'; + + import HomeContext from '@/pages/api/home/home.context'; + + import Spinner from '../Spinner'; import { ChatInput } from './ChatInput'; import { ChatLoader } from './ChatLoader'; import { ErrorMessageDiv } from './ErrorMessageDiv'; +import { MaxTokensSlider } from './MaxTokens'; +import { MemoizedChatMessage } from './MemoizedChatMessage'; import { ModelSelect } from './ModelSelect'; import { SystemPrompt } from './SystemPrompt'; import { TemperatureSlider } from './Temperature'; -import { MemoizedChatMessage } from './MemoizedChatMessage'; interface Props { stopConversationRef: MutableRefObject; @@ -99,6 +106,7 @@ export const Chat = memo(({ stopConversationRef }: Props) => { key: apiKey, prompt: updatedConversation.prompt, temperature: updatedConversation.temperature, + maxTokens: updatedConversation.maxTokens, }; const endpoint = getEndpoint(plugin); let body; @@ -251,6 +259,7 @@ export const Chat = memo(({ stopConversationRef }: Props) => { pluginKeys, selectedConversation, stopConversationRef, + homeDispatch, ], ); @@ -433,6 +442,16 @@ export const Chat = memo(({ stopConversationRef }: Props) => { }) } /> + + + handleUpdateConversation(selectedConversation, { + key: 'maxTokens', + value: maxTokens, + }) + } + /> )} @@ -440,8 +459,8 @@ export const Chat = memo(({ stopConversationRef }: Props) => { ) : ( <>
- {t('Model')}: {selectedConversation?.model?.name} | {t('Temp')} - : {selectedConversation?.temperature} | + {t('Model')}: {selectedConversation?.model?.name} |{' '} + {t('Temp')}: {selectedConversation?.temperature} |
   );
 });
 
-Chat.displayName = 'Chat';
+Chat.displayName = 'Chat';
\ No newline at end of file
diff --git a/components/Chat/MaxTokens.tsx b/components/Chat/MaxTokens.tsx
new file mode 100644
index 0000000..87f4ae8
--- /dev/null
+++ b/components/Chat/MaxTokens.tsx
@@ -0,0 +1,64 @@
+import { FC, useContext, useState } from 'react';
+
+import { useTranslation } from 'next-i18next';
+
+import { DEFAULT_MAX_TOKENS } from '@/utils/app/const';
+
+import HomeContext from '@/pages/api/home/home.context';
+
+interface Props {
+  label: string;
+  onChangeMaxTokens: (maxTokens: number) => void;
+}
+
+export const MaxTokensSlider: FC<Props> = ({ label, onChangeMaxTokens }) => {
+  const {
+    state: { conversations },
+  } = useContext(HomeContext);
+  const lastConversation = conversations[conversations.length - 1];
+  const [maxTokens, setMaxTokens] = useState(
+    lastConversation?.maxTokens ?? DEFAULT_MAX_TOKENS,
+  );
+  const { t } = useTranslation('chat');
+  const handleChange = (event: React.ChangeEvent<HTMLInputElement>) => {
+    const newValue = parseFloat(event.target.value);
+    setMaxTokens(newValue);
+    onChangeMaxTokens(newValue);
+  };
+
+  return (
+    <div className="flex flex-col">
+      <label className="mb-2 text-left text-neutral-700 dark:text-neutral-400">
+        {label}
+      </label>
+      <span className="text-[12px] text-black/50 dark:text-white/50 text-sm">
+        {t(
+          'Higher values for max_tokens will allow the model to generate longer responses, while lower values will restrict the output to be more concise and constrained.',
+        )}
+      </span>
+      <span className="mb-1 mt-2 text-center text-neutral-900 dark:text-neutral-100">
+        {maxTokens}
+      </span>
+      {/* slider bounds below are assumed defaults; the original values were lost */}
+      <input
+        className="cursor-pointer"
+        type="range"
+        min={256}
+        max={4096}
+        step={256}
+        value={maxTokens}
+        onChange={handleChange}
+      />
+      <ul className="w mt-2 pb-8 flex justify-between px-[24px] text-neutral-900 dark:text-neutral-100">
+        <li className="flex justify-center">
+          <span className="absolute">{t('Concise')}</span>
+        </li>
+        <li className="flex justify-center">
+          <span className="absolute">{t('Moderate')}</span>
+        </li>
+        <li className="flex justify-center">
+          <span className="absolute">{t('Extended')}</span>
+        </li>
+      </ul>
+    </div>
+ ); +}; diff --git a/pages/api/chat.ts b/pages/api/chat.ts index 03b9e33..9f2b2b6 100644 --- a/pages/api/chat.ts +++ b/pages/api/chat.ts @@ -1,4 +1,8 @@ -import { DEFAULT_SYSTEM_PROMPT, DEFAULT_TEMPERATURE } from '@/utils/app/const'; +import { + DEFAULT_MAX_TOKENS, + DEFAULT_SYSTEM_PROMPT, + DEFAULT_TEMPERATURE, +} from '@/utils/app/const'; import { OpenAIError, OpenAIStream } from '@/utils/server'; import { ChatBody, Message } from '@/types/chat'; @@ -6,14 +10,14 @@ import { ChatBody, Message } from '@/types/chat'; // @ts-expect-error import wasm from '../../node_modules/@dqbd/tiktoken/lite/tiktoken_bg.wasm?module'; - export const config = { runtime: 'edge', }; const handler = async (req: Request): Promise => { try { - const { model, messages, key, prompt, temperature } = (await req.json()) as ChatBody; + const { model, messages, key, prompt, temperature, maxTokens } = + (await req.json()) as ChatBody; let promptToSend = prompt; if (!promptToSend) { @@ -24,11 +28,24 @@ const handler = async (req: Request): Promise => { if (temperatureToUse == null) { temperatureToUse = DEFAULT_TEMPERATURE; } + + let maxTokensToUse = maxTokens; + if (maxTokensToUse == null) { + maxTokensToUse = DEFAULT_MAX_TOKENS; + } + if (model == null) { throw new Error('No model specified'); } - - const stream = await OpenAIStream(model, promptToSend, temperatureToUse, key, messages); + + const stream = await OpenAIStream( + model, + promptToSend, + temperatureToUse, + maxTokensToUse, + key, + messages, + ); return new Response(stream); } catch (error) { diff --git a/pages/api/home/home.state.tsx b/pages/api/home/home.state.tsx index 4fcb082..bbf803f 100644 --- a/pages/api/home/home.state.tsx +++ b/pages/api/home/home.state.tsx @@ -5,6 +5,7 @@ import { OpenAIModel } from '@/types/openai'; import { PluginKey } from '@/types/plugin'; import { Prompt } from '@/types/prompt'; + export interface HomeInitialState { apiKey: string; pluginKeys: PluginKey[]; @@ -19,6 +20,7 @@ export interface HomeInitialState { currentMessage: Message | undefined; prompts: Prompt[]; temperature: number; + maxTokens: number; showChatbar: boolean; showPromptbar: boolean; currentFolder: FolderInterface | undefined; @@ -49,4 +51,5 @@ export const initialState: HomeInitialState = { searchTerm: '', serverSideApiKeyIsSet: false, serverSidePluginKeysSet: false, -}; + maxTokens: 1024, +}; \ No newline at end of file diff --git a/public/locales/zh/chat.json b/public/locales/zh/chat.json index b109d6f..98f1e0c 100644 --- a/public/locales/zh/chat.json +++ b/public/locales/zh/chat.json @@ -28,9 +28,14 @@ "Chatbot UI is an advanced chatbot kit for OpenAI's chat models aiming to mimic ChatGPT's interface and functionality.": "Chatbot UI 是一个高级聊天机器人工具包,旨在模仿 OpenAI 聊天模型的 ChatGPT 界面和功能。", "Are you sure you want to clear all messages?": "你确定要清除所有的消息吗?", "Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.": "较高的数值(例如0.8)会使输出更随机,而较低的数值(例如0.2)会使输出更加聚焦和确定性更强。", + "Max Tokens": "最大令牌数", + "Higher values for max_tokens will allow the model to generate longer responses, while lower values will restrict the output to be more concise and constrained.": "较高的 max_tokens 值将允许模型生成更长的响应,而较低的值将限制输出更加简洁和受限。", "View Account Usage": "查阅账户用量", "Temperature": "生成温度", "Precise": "保守", "Neutral": "中立", - "Creative": "随性" + "Creative": "随性", + "Concise": "简洁", + "Moderate": "适中", + "Extended": "较长" } diff --git a/types/chat.ts b/types/chat.ts index 5e47995..3a0fa9b 100644 --- a/types/chat.ts +++ 
b/types/chat.ts @@ -1,5 +1,6 @@ import { OpenAIModel } from './openai'; + export interface Message { role: Role; content: string; @@ -13,6 +14,7 @@ export interface ChatBody { key: string; prompt: string; temperature: number; + maxTokens: number; } export interface Conversation { @@ -22,5 +24,6 @@ export interface Conversation { model: OpenAIModel | null; prompt: string; temperature: number; + maxTokens: number; folderId: string | null; -} +} \ No newline at end of file diff --git a/utils/app/const.ts b/utils/app/const.ts index 7e8e18d..4715a05 100644 --- a/utils/app/const.ts +++ b/utils/app/const.ts @@ -19,3 +19,7 @@ export const OPENAI_ORGANIZATION = export const AZURE_DEPLOYMENT_ID = process.env.AZURE_DEPLOYMENT_ID || ''; + +export const DEFAULT_MAX_TOKENS = parseInt( + process.env.NEXT_PUBLIC_DEFAULT_MAX_TOKENS || '1024', +); diff --git a/utils/server/index.ts b/utils/server/index.ts index ada64f4..9794b29 100644 --- a/utils/server/index.ts +++ b/utils/server/index.ts @@ -1,7 +1,15 @@ import { Message } from '@/types/chat'; import { OpenAIModel } from '@/types/openai'; -import { AZURE_DEPLOYMENT_ID, OPENAI_API_HOST, OPENAI_API_TYPE, OPENAI_API_VERSION, OPENAI_ORGANIZATION } from '../app/const'; + + +import { + AZURE_DEPLOYMENT_ID, + OPENAI_API_HOST, + OPENAI_API_TYPE, + OPENAI_API_VERSION, + OPENAI_ORGANIZATION, +} from '../app/const'; import { ParsedEvent, @@ -26,7 +34,8 @@ export class OpenAIError extends Error { export const OpenAIStream = async ( model: OpenAIModel, systemPrompt: string, - temperature : number, + temperature: number, + maxTokens: number, key: string, messages: Message[], ) => { @@ -38,18 +47,19 @@ export const OpenAIStream = async ( headers: { 'Content-Type': 'application/json', ...(OPENAI_API_TYPE === 'openai' && { - Authorization: `Bearer ${key ? key : process.env.OPENAI_API_KEY}` + Authorization: `Bearer ${key ? key : process.env.OPENAI_API_KEY}`, }), ...(OPENAI_API_TYPE === 'azure' && { - 'api-key': `${key ? key : process.env.OPENAI_API_KEY}` - }), - ...((OPENAI_API_TYPE === 'openai' && OPENAI_ORGANIZATION) && { - 'OpenAI-Organization': OPENAI_ORGANIZATION, + 'api-key': `${key ? key : process.env.OPENAI_API_KEY}`, }), + ...(OPENAI_API_TYPE === 'openai' && + OPENAI_ORGANIZATION && { + 'OpenAI-Organization': OPENAI_ORGANIZATION, + }), }, method: 'POST', body: JSON.stringify({ - ...(OPENAI_API_TYPE === 'openai' && {model: model.id}), + ...(OPENAI_API_TYPE === 'openai' && { model: model.id }), messages: [ { role: 'system', @@ -57,7 +67,7 @@ export const OpenAIStream = async ( }, ...messages, ], - max_tokens: 2048, + max_tokens: maxTokens, temperature: temperature, stream: true, }), @@ -105,7 +115,7 @@ export const OpenAIStream = async ( const queue = encoder.encode(text); controller.enqueue(queue); } - + if (choices[0].finish_reason != null) { controller.close(); return; @@ -125,4 +135,4 @@ export const OpenAIStream = async ( }); return stream; -}; +}; \ No newline at end of file
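
For reference, a minimal client-side sketch (not part of the patch) of how the new `maxTokens` field travels in the request body. The endpoint, field names, and the 1024 fallback come from the diff above; the concrete model object, prompt, and message are placeholder values.

```ts
// Hypothetical usage sketch: POST the updated ChatBody shape to /api/chat.
// Field names follow types/chat.ts from this patch; values are placeholders.
async function sendChat() {
  const res = await fetch('/api/chat', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      model: { id: 'gpt-3.5-turbo', name: 'GPT-3.5' }, // OpenAIModel fields abbreviated
      messages: [{ role: 'user', content: 'Summarize this thread.' }],
      key: '', // empty key falls back to OPENAI_API_KEY on the server
      prompt: 'You are a helpful assistant.',
      temperature: 0.7,
      maxTokens: 1024, // new field; the handler falls back to DEFAULT_MAX_TOKENS when omitted
    }),
  });
  return res.body; // streamed tokens produced by OpenAIStream
}
```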