diff --git a/frontend/components/ChatInput.tsx b/frontend/components/ChatInput.tsx index 0630f3b..691ee85 100644 --- a/frontend/components/ChatInput.tsx +++ b/frontend/components/ChatInput.tsx @@ -8,6 +8,9 @@ import { DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger, + DropdownMenuLabel, + DropdownMenuSeparator, + DropdownMenuGroup, } from '@/frontend/components/ui/dropdown-menu'; import useAutoResizeTextarea from '@/hooks/useAutoResizeTextArea'; import { UseChatHelpers, useCompletion } from '@ai-sdk/react'; @@ -187,6 +190,7 @@ const ChatInput = memo(PureChatInput, (prevProps, nextProps) => { const PureChatModelDropdown = () => { const getKey = useAPIKeyStore((state) => state.getKey); const { selectedModel, setModel } = useModelStore(); + const navigate = useNavigate(); const isModelEnabled = useCallback( (model: AIModel) => { @@ -197,6 +201,35 @@ const PureChatModelDropdown = () => { [getKey] ); + const groupedModels = useMemo(() => { + const groups: Record = {}; + + AI_MODELS.forEach((model) => { + const config = getModelConfig(model); + const provider = config.provider; + const apiKey = getKey(provider); + const hasKey = !!apiKey; + + if (!groups[provider]) { + groups[provider] = { + models: [], + hasApiKey: hasKey, + keyStatus: hasKey ? `API key configured` : 'No API key' + }; + } + + groups[provider].models.push(model); + }); + + return groups; + }, [getKey]); + + const providerDisplayNames = { + google: 'Google AI', + openai: 'OpenAI', + openrouter: 'OpenRouter' + }; + return (
@@ -213,30 +246,53 @@ const PureChatModelDropdown = () => { - {AI_MODELS.map((model) => { - const isEnabled = isModelEnabled(model); - return ( - isEnabled && setModel(model)} - disabled={!isEnabled} + {Object.entries(groupedModels).map(([provider, group], index) => ( + + navigate('/settings') : undefined} > - {model} - {selectedModel === model && ( - + + {providerDisplayNames[provider as keyof typeof providerDisplayNames]} + + {!group.hasApiKey && ( + + Click to configure + )} - - ); - })} + + {group.models.map((model) => { + const isEnabled = isModelEnabled(model); + return ( + isEnabled && setModel(model)} + disabled={!isEnabled} + className={cn( + 'flex items-center justify-between gap-2', + 'cursor-pointer ml-2' + )} + > + {model} + {selectedModel === model && ( + + )} + + ); + })} + {index < Object.keys(groupedModels).length - 1 && ( + + )} + + ))}