diff --git a/package.json b/package.json index 62ec2e4..841e419 100644 --- a/package.json +++ b/package.json @@ -66,7 +66,8 @@ "appId": "com.dortort.keystone", "productName": "Keystone", "directories": { - "output": "release" + "output": "release", + "buildResources": "resources" }, "files": [ "dist/**/*" @@ -90,4 +91,4 @@ ] } } -} +} \ No newline at end of file diff --git a/resources/icon.png b/resources/icon.png new file mode 100644 index 0000000..cf4fbf3 Binary files /dev/null and b/resources/icon.png differ diff --git a/resources/social-preview.png b/resources/social-preview.png new file mode 100644 index 0000000..b7355cb Binary files /dev/null and b/resources/social-preview.png differ diff --git a/src/agents/providers/GoogleAdapter.ts b/src/agents/providers/GoogleAdapter.ts index 4cfe3e6..77acb70 100644 --- a/src/agents/providers/GoogleAdapter.ts +++ b/src/agents/providers/GoogleAdapter.ts @@ -1,13 +1,41 @@ import { BaseLLMClient } from './BaseLLMClient' import type { ChatMessage, ChatOptions } from '@shared/types/provider' +interface GoogleAuthApiKey { + apiKey: string +} + +interface GoogleAuthOAuth { + oauthToken: string +} + +type GoogleAuth = GoogleAuthApiKey | GoogleAuthOAuth + export class GoogleAdapter extends BaseLLMClient { - private apiKey: string + private auth: GoogleAuth private baseUrl = 'https://generativelanguage.googleapis.com/v1beta' - constructor(apiKey: string) { + constructor(auth: string | GoogleAuth) { super() - this.apiKey = apiKey + this.auth = typeof auth === 'string' ? { apiKey: auth } : auth + } + + updateOAuthToken(token: string): void { + this.auth = { oauthToken: token } + } + + private getAuthHeaders(): Record { + if ('oauthToken' in this.auth) { + return { Authorization: `Bearer ${this.auth.oauthToken}` } + } + return { 'x-goog-api-key': this.auth.apiKey } + } + + private getUrl(model: string): string { + const base = `${this.baseUrl}/models/${model}:streamGenerateContent?alt=sse` + // When using API key (not OAuth), append key as query param is not needed + // since we send it via header. Just return base URL. + return base } async *chat(messages: ChatMessage[], options?: ChatOptions): AsyncIterable { @@ -35,13 +63,13 @@ export class GoogleAdapter extends BaseLLMClient { } } - const url = `${this.baseUrl}/models/${model}:streamGenerateContent?alt=sse` + const url = this.getUrl(model) const response = await fetch(url, { method: 'POST', headers: { 'Content-Type': 'application/json', - 'x-goog-api-key': this.apiKey, + ...this.getAuthHeaders(), }, body: JSON.stringify(body), }) diff --git a/src/agents/providers/OpenAIAdapter.ts b/src/agents/providers/OpenAIAdapter.ts index 47703f7..2470ca6 100644 --- a/src/agents/providers/OpenAIAdapter.ts +++ b/src/agents/providers/OpenAIAdapter.ts @@ -1,13 +1,45 @@ import { BaseLLMClient } from './BaseLLMClient' import type { ChatMessage, ChatOptions } from '@shared/types/provider' +interface OpenAIAuthApiKey { + apiKey: string +} + +interface OpenAIAuthOAuth { + oauthToken: string + accountId?: string +} + +type OpenAIAuth = OpenAIAuthApiKey | OpenAIAuthOAuth + export class OpenAIAdapter extends BaseLLMClient { - private apiKey: string + private auth: OpenAIAuth private baseUrl = 'https://api.openai.com/v1' - constructor(apiKey: string) { + constructor(auth: string | OpenAIAuth) { super() - this.apiKey = apiKey + this.auth = typeof auth === 'string' ? { apiKey: auth } : auth + } + + updateOAuthToken(token: string, accountId?: string): void { + this.auth = { oauthToken: token, accountId } + } + + private getHeaders(): Record { + const headers: Record = { + 'Content-Type': 'application/json', + } + + if ('oauthToken' in this.auth) { + headers['Authorization'] = `Bearer ${this.auth.oauthToken}` + if (this.auth.accountId) { + headers['chatgpt-account-id'] = this.auth.accountId + } + } else { + headers['Authorization'] = `Bearer ${this.auth.apiKey}` + } + + return headers } async *chat(messages: ChatMessage[], options?: ChatOptions): AsyncIterable { @@ -31,10 +63,7 @@ export class OpenAIAdapter extends BaseLLMClient { const response = await fetch(`${this.baseUrl}/chat/completions`, { method: 'POST', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${this.apiKey}`, - }, + headers: this.getHeaders(), body: JSON.stringify(body), }) diff --git a/src/agents/providers/ProviderManager.ts b/src/agents/providers/ProviderManager.ts index 5cf862e..0c519fb 100644 --- a/src/agents/providers/ProviderManager.ts +++ b/src/agents/providers/ProviderManager.ts @@ -1,4 +1,4 @@ -import type { ProviderType, ProviderConfig } from '@shared/types/provider' +import type { ProviderType, ProviderConfig, LegacyProviderConfig } from '@shared/types/provider' import { BaseLLMClient } from './BaseLLMClient' import { AnthropicAdapter } from './AnthropicAdapter' import { OpenAIAdapter } from './OpenAIAdapter' @@ -8,21 +8,36 @@ export class ProviderManager { private providers: Map = new Map() private activeProvider: ProviderType | null = null - configure(config: ProviderConfig): void { + configure(config: ProviderConfig | LegacyProviderConfig): void { let client: BaseLLMClient + // Normalize: legacy configs lack authMethod field + const authMethod = 'authMethod' in config ? config.authMethod : 'apiKey' + switch (config.type) { case 'anthropic': - client = new AnthropicAdapter(config.apiKey) + // Anthropic only supports API keys + if (authMethod === 'oauth') { + throw new Error('Anthropic does not support OAuth authentication') + } + client = new AnthropicAdapter('apiKey' in config ? config.apiKey : '') break case 'openai': - client = new OpenAIAdapter(config.apiKey) + if (authMethod === 'oauth' && 'oauthToken' in config) { + client = new OpenAIAdapter({ oauthToken: config.oauthToken, accountId: config.accountId }) + } else { + client = new OpenAIAdapter('apiKey' in config ? config.apiKey : '') + } break case 'google': - client = new GoogleAdapter(config.apiKey) + if (authMethod === 'oauth' && 'oauthToken' in config) { + client = new GoogleAdapter({ oauthToken: config.oauthToken }) + } else { + client = new GoogleAdapter('apiKey' in config ? config.apiKey : '') + } break default: - throw new Error(`Unknown provider: ${config.type}`) + throw new Error(`Unknown provider: ${(config as any).type}`) } this.providers.set(config.type, client) @@ -31,6 +46,17 @@ export class ProviderManager { } } + updateOAuthToken(provider: ProviderType, token: string, accountId?: string): void { + const client = this.providers.get(provider) + if (!client) return + + if (provider === 'openai' && client instanceof OpenAIAdapter) { + client.updateOAuthToken(token, accountId) + } else if (provider === 'google' && client instanceof GoogleAdapter) { + client.updateOAuthToken(token) + } + } + setActive(type: ProviderType): void { if (!this.providers.has(type)) { throw new Error(`Provider ${type} not configured`) @@ -60,4 +86,12 @@ export class ProviderManager { getConfiguredProviders(): ProviderType[] { return Array.from(this.providers.keys()) } + + removeProvider(type: ProviderType): void { + this.providers.delete(type) + if (this.activeProvider === type) { + const remaining = Array.from(this.providers.keys()) + this.activeProvider = remaining.length > 0 ? remaining[0] : null + } + } } diff --git a/src/main/ipc/ai.router.ts b/src/main/ipc/ai.router.ts index cc3641e..defcd9d 100644 --- a/src/main/ipc/ai.router.ts +++ b/src/main/ipc/ai.router.ts @@ -144,16 +144,29 @@ export const aiRouter = router({ .input( z.object({ type: z.enum(['openai', 'anthropic', 'google']), - apiKey: z.string().min(1), + apiKey: z.string().min(1).optional(), + authMethod: z.enum(['apiKey', 'oauth']).optional(), + oauthToken: z.string().optional(), + accountId: z.string().optional(), }), ) .mutation(async ({ input, ctx }) => { - // Configure the in-memory provider - ctx.providerManager.configure({ type: input.type, apiKey: input.apiKey }) - ctx.providerManager.setActive(input.type) + const authMethod = input.authMethod || 'apiKey' + + if (authMethod === 'oauth' && input.oauthToken) { + ctx.providerManager.configure({ + type: input.type, + authMethod: 'oauth', + oauthToken: input.oauthToken, + accountId: input.accountId, + }) + } else if (input.apiKey) { + ctx.providerManager.configure({ type: input.type, authMethod: 'apiKey', apiKey: input.apiKey }) + ctx.settingsService.setApiKey(input.type, input.apiKey) + ctx.settingsService.setAuthMethod(input.type, 'apiKey') + } - // Persist the API key and active provider selection - ctx.settingsService.setApiKey(input.type, input.apiKey) + ctx.providerManager.setActive(input.type) ctx.settingsService.setActiveProvider(input.type) return { diff --git a/src/main/ipc/context.ts b/src/main/ipc/context.ts index 47deda9..73acf90 100644 --- a/src/main/ipc/context.ts +++ b/src/main/ipc/context.ts @@ -1,4 +1,5 @@ import type { Context } from './trpc' +import type { ProviderType } from '@shared/types/provider' import { DatabaseService } from '../services/DatabaseService' import { ProjectService } from '../services/ProjectService' import { DocumentService } from '../services/DocumentService' @@ -6,6 +7,7 @@ import { ThreadService } from '../services/ThreadService' import { ProviderManager } from '../../agents/providers/ProviderManager' import { Orchestrator } from '../../agents/orchestrator/Orchestrator' import { SettingsService } from '../services/SettingsService' +import { OAuthService } from '../services/OAuthService' let _context: Context | null = null @@ -20,13 +22,58 @@ export async function createContext(): Promise { const settingsService = new SettingsService() const providerManager = new ProviderManager() const orchestrator = new Orchestrator() + const oauthService = new OAuthService() - // Restore any previously configured providers from persisted settings + // Wire OAuth token refresh callback + oauthService.onTokenRefresh((provider, tokens) => { + settingsService.setOAuthTokens(provider, tokens) + providerManager.updateOAuthToken(provider, tokens.accessToken, tokens.accountId) + }) + + // Restore API-key-based providers from persisted settings for (const provider of settingsService.getConfiguredProviders()) { - const apiKey = settingsService.getApiKey(provider) - if (apiKey) { + const authMethod = settingsService.getAuthMethod(provider) + if (authMethod === 'apiKey') { + const apiKey = settingsService.getApiKey(provider) + if (apiKey) { + try { + providerManager.configure({ type: provider as ProviderType, authMethod: 'apiKey', apiKey }) + } catch { + // Skip invalid providers silently + } + } + } + } + + // Restore OAuth-based providers from persisted settings + for (const provider of settingsService.getOAuthConfiguredProviders()) { + const tokens = settingsService.getOAuthTokens(provider) + const authMethod = settingsService.getAuthMethod(provider) + if (tokens && authMethod === 'oauth') { try { - providerManager.configure({ type: provider as 'openai' | 'anthropic' | 'google', apiKey }) + // Check if token is expired — attempt immediate refresh if so + if (tokens.expiresAt < Date.now() && tokens.refreshToken) { + const refreshed = await oauthService.refreshToken(provider as ProviderType, tokens) + if (refreshed) { + settingsService.setOAuthTokens(provider, refreshed) + providerManager.configure({ + type: provider as ProviderType, + authMethod: 'oauth', + oauthToken: refreshed.accessToken, + accountId: refreshed.accountId, + }) + oauthService.scheduleRefresh(provider as ProviderType, refreshed) + } + // If refresh fails, skip — user will need to re-authenticate + } else { + providerManager.configure({ + type: provider as ProviderType, + authMethod: 'oauth', + oauthToken: tokens.accessToken, + accountId: tokens.accountId, + }) + oauthService.scheduleRefresh(provider as ProviderType, tokens) + } } catch { // Skip invalid providers silently } @@ -37,12 +84,12 @@ export async function createContext(): Promise { const activeProvider = settingsService.getActiveProvider() if (activeProvider && providerManager.isConfigured()) { try { - providerManager.setActive(activeProvider as 'openai' | 'anthropic' | 'google') + providerManager.setActive(activeProvider as ProviderType) } catch { // Active provider may no longer be configured } } - _context = { db, projectService, documentService, threadService, settingsService, providerManager, orchestrator } + _context = { db, projectService, documentService, threadService, settingsService, providerManager, orchestrator, oauthService } return _context } diff --git a/src/main/ipc/oauth.router.ts b/src/main/ipc/oauth.router.ts new file mode 100644 index 0000000..ee90900 --- /dev/null +++ b/src/main/ipc/oauth.router.ts @@ -0,0 +1,74 @@ +import { z } from 'zod' +import { router, publicProcedure } from './trpc' + +export const oauthRouter = router({ + startFlow: publicProcedure + .input(z.object({ + provider: z.enum(['openai', 'anthropic', 'google']), + })) + .mutation(async ({ input, ctx }) => { + const tokens = await ctx.oauthService.startFlow(input.provider) + + // Store tokens and update auth method + ctx.settingsService.setOAuthTokens(input.provider, tokens) + ctx.settingsService.setAuthMethod(input.provider, 'oauth') + + // Configure provider with OAuth token + ctx.providerManager.configure({ + type: input.provider, + authMethod: 'oauth', + oauthToken: tokens.accessToken, + accountId: tokens.accountId, + }) + ctx.providerManager.setActive(input.provider) + ctx.settingsService.setActiveProvider(input.provider) + + // Schedule token refresh + ctx.oauthService.scheduleRefresh(input.provider, tokens) + + return { + success: true, + email: tokens.email, + provider: input.provider, + } + }), + + disconnect: publicProcedure + .input(z.object({ + provider: z.enum(['openai', 'anthropic', 'google']), + })) + .mutation(({ input, ctx }) => { + // Clear tokens and refresh timer + ctx.settingsService.deleteOAuthTokens(input.provider) + ctx.oauthService.clearRefreshTimer(input.provider) + + // Remove the provider from memory + ctx.providerManager.removeProvider(input.provider) + + return { success: true } + }), + + getStatus: publicProcedure + .input(z.object({ + provider: z.enum(['openai', 'anthropic', 'google']), + })) + .query(({ input, ctx }) => { + const tokens = ctx.settingsService.getOAuthTokens(input.provider) + const authMethod = ctx.settingsService.getAuthMethod(input.provider) + + return { + connected: authMethod === 'oauth' && tokens !== null, + email: tokens?.email ?? null, + expiresAt: tokens?.expiresAt ?? null, + authMethod, + } + }), + + getCapabilities: publicProcedure + .input(z.object({ + provider: z.enum(['openai', 'anthropic', 'google']), + })) + .query(({ input, ctx }) => { + return ctx.oauthService.getCapabilities(input.provider) + }), +}) diff --git a/src/main/ipc/router.ts b/src/main/ipc/router.ts index 38c5342..3bd60ce 100644 --- a/src/main/ipc/router.ts +++ b/src/main/ipc/router.ts @@ -4,6 +4,7 @@ import { documentRouter } from './document.router' import { threadRouter } from './thread.router' import { aiRouter } from './ai.router' import { settingsRouter } from './settings.router' +import { oauthRouter } from './oauth.router' export const appRouter = router({ project: projectRouter, @@ -11,6 +12,7 @@ export const appRouter = router({ thread: threadRouter, ai: aiRouter, settings: settingsRouter, + oauth: oauthRouter, }) export type AppRouter = typeof appRouter diff --git a/src/main/ipc/trpc.ts b/src/main/ipc/trpc.ts index 302e340..d1212fc 100644 --- a/src/main/ipc/trpc.ts +++ b/src/main/ipc/trpc.ts @@ -6,6 +6,7 @@ import type { ThreadService } from '../services/ThreadService' import type { ProviderManager } from '../../agents/providers/ProviderManager' import type { Orchestrator } from '../../agents/orchestrator/Orchestrator' import type { SettingsService } from '../services/SettingsService' +import type { OAuthService } from '../services/OAuthService' export interface Context { db: DatabaseService @@ -15,6 +16,7 @@ export interface Context { providerManager: ProviderManager orchestrator: Orchestrator settingsService: SettingsService + oauthService: OAuthService } const t = initTRPC.context().create() diff --git a/src/main/services/OAuthService.ts b/src/main/services/OAuthService.ts new file mode 100644 index 0000000..7e0cada --- /dev/null +++ b/src/main/services/OAuthService.ts @@ -0,0 +1,389 @@ +import * as http from 'http' +import * as crypto from 'crypto' +import { BrowserWindow, shell } from 'electron' +import { OAUTH_CAPABILITIES } from '@shared/constants' +import type { ProviderType, OAuthTokens, OAuthFlowStatus } from '@shared/types/provider' +import { OAuthError } from '@shared/errors' + +type StatusCallback = (status: OAuthFlowStatus) => void + +export class OAuthService { + private refreshTimers: Map = new Map() + private activeServer: http.Server | null = null + private statusCallback: StatusCallback | null = null + private tokenRefreshCallback: ((provider: ProviderType, tokens: OAuthTokens) => void) | null = null + + onStatus(callback: StatusCallback): void { + this.statusCallback = callback + } + + onTokenRefresh(callback: (provider: ProviderType, tokens: OAuthTokens) => void): void { + this.tokenRefreshCallback = callback + } + + private emitStatus(status: OAuthFlowStatus): void { + this.statusCallback?.(status) + const win = BrowserWindow.getAllWindows()[0] + if (win && !win.isDestroyed()) { + win.webContents.send('oauth:status', status) + } + } + + async startFlow(provider: ProviderType): Promise { + const capabilities = OAUTH_CAPABILITIES[provider] + if (!capabilities?.supported) { + throw new OAuthError(provider, 'OAuth not supported for this provider') + } + + // Cancel any existing flow + this.cancelFlow() + + this.emitStatus({ state: 'pending', provider }) + + // Generate PKCE pair + const codeVerifier = crypto.randomBytes(32).toString('base64url') + const codeChallenge = crypto + .createHash('sha256') + .update(codeVerifier) + .digest('base64url') + + const state = crypto.randomBytes(16).toString('hex') + + return new Promise((resolve, reject) => { + let redirectUri = '' + let settled = false + + const settle = (fn: () => void) => { + if (!settled) { + settled = true + fn() + } + } + + const callbackPath = capabilities.callbackPath || '/callback' + + const server = http.createServer(async (req, res) => { + const url = new URL(req.url || '/', `http://127.0.0.1`) + + if (url.pathname !== callbackPath) { + res.writeHead(404) + res.end('Not found') + return + } + + const code = url.searchParams.get('code') + const returnedState = url.searchParams.get('state') + const error = url.searchParams.get('error') + + if (error) { + res.writeHead(200, { 'Content-Type': 'text/html' }) + res.end(this.buildCallbackHtml(false, `Authorization denied: ${error}`)) + this.emitStatus({ state: 'error', provider, error: `Authorization denied: ${error}` }) + this.shutdownServer() + settle(() => reject(new OAuthError(provider, `Authorization denied: ${error}`))) + return + } + + if (!code || returnedState !== state) { + res.writeHead(200, { 'Content-Type': 'text/html' }) + res.end(this.buildCallbackHtml(false, 'Invalid callback parameters')) + this.emitStatus({ state: 'error', provider, error: 'Invalid callback parameters' }) + this.shutdownServer() + settle(() => reject(new OAuthError(provider, 'Invalid callback parameters'))) + return + } + + try { + const tokens = await this.exchangeCode(provider, code, codeVerifier, redirectUri) + res.writeHead(200, { 'Content-Type': 'text/html' }) + res.end(this.buildCallbackHtml(true, 'You can close this tab and return to Keystone.')) + this.emitStatus({ state: 'success', provider, email: tokens.email }) + this.shutdownServer() + settle(() => resolve(tokens)) + } catch (err) { + const msg = err instanceof Error ? err.message : 'Token exchange failed' + res.writeHead(200, { 'Content-Type': 'text/html' }) + res.end(this.buildCallbackHtml(false, msg)) + this.emitStatus({ state: 'error', provider, error: msg }) + this.shutdownServer() + settle(() => reject(new OAuthError(provider, msg))) + } + }) + + // Listen on provider-specific port (or random port if not specified) + const listenPort = capabilities.redirectPort || 0 + server.listen(listenPort, '127.0.0.1', () => { + const addr = server.address() + if (!addr || typeof addr === 'string') { + settle(() => reject(new OAuthError(provider, 'Failed to start loopback server'))) + return + } + + this.activeServer = server + // Use localhost (not 127.0.0.1) as some providers require it + redirectUri = `http://localhost:${addr.port}${callbackPath}` + + // Build authorization URL + const params = new URLSearchParams({ + client_id: capabilities.clientId, + redirect_uri: redirectUri, + response_type: 'code', + scope: capabilities.scopes.join(' '), + state, + code_challenge: codeChallenge, + code_challenge_method: 'S256', + ...capabilities.extraAuthParams, + }) + + const authUrl = `${capabilities.authorizationUrl}?${params.toString()}` + shell.openExternal(authUrl) + }) + + // Timeout after 5 minutes + setTimeout(() => { + if (this.activeServer === server) { + this.emitStatus({ state: 'error', provider, error: 'Authorization timed out' }) + this.shutdownServer() + settle(() => reject(new OAuthError(provider, 'Authorization timed out'))) + } + }, 5 * 60 * 1000) + }) + } + + private async exchangeCode( + provider: ProviderType, + code: string, + codeVerifier: string, + redirectUri: string, + ): Promise { + const capabilities = OAUTH_CAPABILITIES[provider] + if (!capabilities) throw new OAuthError(provider, 'Unknown provider') + + const body = new URLSearchParams({ + grant_type: 'authorization_code', + code, + redirect_uri: redirectUri, + client_id: capabilities.clientId, + code_verifier: codeVerifier, + }) + + const response = await fetch(capabilities.tokenUrl, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: body.toString(), + }) + + if (!response.ok) { + const text = await response.text() + throw new OAuthError(provider, `Token exchange failed: ${response.status} ${text}`) + } + + const data = await response.json() + + const tokens: OAuthTokens = { + accessToken: data.access_token, + refreshToken: data.refresh_token, + expiresAt: Date.now() + (data.expires_in || 3600) * 1000, + } + + // Store and parse ID token if present + if (data.id_token) { + tokens.idToken = data.id_token + try { + const payload = JSON.parse( + Buffer.from(data.id_token.split('.')[1], 'base64url').toString(), + ) + if (payload.email) tokens.email = payload.email + } catch { + // ID token parsing is best-effort + } + } + + // OpenAI-specific: extract account ID + if (provider === 'openai' && data.account_id) { + tokens.accountId = data.account_id + } + + // OpenAI token exchange: convert ID token to API key + if (provider === 'openai' && capabilities.supportsTokenExchange && data.id_token) { + const apiKey = await this.exchangeForApiKey(capabilities, data.id_token) + if (apiKey) { + tokens.accessToken = apiKey + } else { + throw new OAuthError(provider, 'Failed to exchange token for API key. Please try again.') + } + } + + return tokens + } + + private async exchangeForApiKey( + capabilities: { tokenUrl: string; clientId: string }, + idToken: string, + ): Promise { + try { + const body = new URLSearchParams({ + grant_type: 'urn:ietf:params:oauth:grant-type:token-exchange', + client_id: capabilities.clientId, + requested_token: 'openai-api-key', + subject_token: idToken, + subject_token_type: 'urn:ietf:params:oauth:token-type:id_token', + }) + + const response = await fetch(capabilities.tokenUrl, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: body.toString(), + }) + + if (!response.ok) { + console.error('Token exchange for API key failed:', response.status) + return null + } + + const data = await response.json() + return data.access_token || null + } catch (error) { + console.error('Token exchange error:', error) + return null + } + } + + async refreshToken(provider: ProviderType, currentTokens: OAuthTokens): Promise { + if (!currentTokens.refreshToken) return null + + const capabilities = OAUTH_CAPABILITIES[provider] + if (!capabilities?.supported) return null + + try { + const body = new URLSearchParams({ + grant_type: 'refresh_token', + refresh_token: currentTokens.refreshToken, + client_id: capabilities.clientId, + }) + + const response = await fetch(capabilities.tokenUrl, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: body.toString(), + }) + + if (!response.ok) { + console.error(`Token refresh failed for ${provider}: ${response.status}`) + return null + } + + const data = await response.json() + + const newTokens: OAuthTokens = { + accessToken: data.access_token, + refreshToken: data.refresh_token || currentTokens.refreshToken, + idToken: data.id_token || currentTokens.idToken, + expiresAt: Date.now() + (data.expires_in || 3600) * 1000, + accountId: currentTokens.accountId, + email: currentTokens.email, + } + + // OpenAI: re-exchange the new ID token for an API key + if (provider === 'openai' && capabilities.supportsTokenExchange && newTokens.idToken) { + const apiKey = await this.exchangeForApiKey(capabilities, newTokens.idToken) + if (apiKey) { + newTokens.accessToken = apiKey + } + } + + return newTokens + } catch (error) { + console.error(`Token refresh error for ${provider}:`, error) + return null + } + } + + scheduleRefresh(provider: ProviderType, tokens: OAuthTokens): void { + this.clearRefreshTimer(provider) + + if (!tokens.refreshToken) return + + const refreshAt = tokens.expiresAt - 5 * 60 * 1000 + const delay = Math.max(refreshAt - Date.now(), 10_000) + + const timer = setTimeout(async () => { + const newTokens = await this.refreshToken(provider, tokens) + if (newTokens) { + this.tokenRefreshCallback?.(provider, newTokens) + this.scheduleRefresh(provider, newTokens) + } else { + this.emitStatus({ state: 'error', provider, error: 'Token refresh failed. Please sign in again.' }) + } + }, delay) + + this.refreshTimers.set(provider, timer) + } + + clearRefreshTimer(provider: string): void { + const timer = this.refreshTimers.get(provider) + if (timer) { + clearTimeout(timer) + this.refreshTimers.delete(provider) + } + } + + cancelFlow(): void { + this.shutdownServer() + } + + private shutdownServer(): void { + if (this.activeServer) { + this.activeServer.close() + this.activeServer = null + } + } + + getCapabilities(provider: ProviderType): { supported: boolean; experimental?: boolean } { + const cap = OAUTH_CAPABILITIES[provider] + return { + supported: cap?.supported ?? false, + experimental: cap?.experimental, + } + } + + destroy(): void { + this.shutdownServer() + for (const timer of this.refreshTimers.values()) { + clearTimeout(timer) + } + this.refreshTimers.clear() + } + + private escapeHtml(str: string): string { + return str + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"') + .replace(/'/g, ''') + } + + private buildCallbackHtml(success: boolean, message: string): string { + return ` + + + Keystone - OAuth + + + +
+
${success ? '✓' : '✗'}
+

${success ? 'Connected!' : 'Connection Failed'}

+

${this.escapeHtml(message)}

+
+ +` + } +} diff --git a/src/main/services/SettingsService.ts b/src/main/services/SettingsService.ts index 8cbb6bf..8b3a723 100644 --- a/src/main/services/SettingsService.ts +++ b/src/main/services/SettingsService.ts @@ -5,6 +5,8 @@ import * as path from 'path' interface SettingsData { activeProvider: string | null apiKeys: Record // stored as base64-encoded encrypted strings + oauthTokens: Record // stored as base64-encoded encrypted JSON strings + authMethods: Record // 'apiKey' | 'oauth' } export class SettingsService { @@ -21,6 +23,9 @@ export class SettingsService { } this.settings = this.loadSettings() + // Migrate older settings files that lack OAuth fields + if (!this.settings.oauthTokens) this.settings.oauthTokens = {} + if (!this.settings.authMethods) this.settings.authMethods = {} } private loadSettings(): SettingsData { @@ -35,7 +40,9 @@ export class SettingsService { return { activeProvider: null, - apiKeys: {} + apiKeys: {}, + oauthTokens: {}, + authMethods: {}, } } @@ -101,10 +108,67 @@ export class SettingsService { } getConfiguredProviders(): string[] { - return Object.keys(this.settings.apiKeys) + return [...new Set([ + ...Object.keys(this.settings.apiKeys), + ...Object.keys(this.settings.oauthTokens), + ])] } hasApiKey(provider: string): boolean { return !!this.settings.apiKeys[provider] } + + getOAuthTokens(provider: string): import('@shared/types/provider').OAuthTokens | null { + const encrypted = this.settings.oauthTokens[provider] + if (!encrypted) return null + + try { + let json: string + if (this.encryptionAvailable) { + const buffer = Buffer.from(encrypted, 'base64') + json = safeStorage.decryptString(buffer) + } else { + json = encrypted + } + return JSON.parse(json) + } catch (error) { + console.error(`Failed to decrypt OAuth tokens for ${provider}:`, error) + return null + } + } + + setOAuthTokens(provider: string, tokens: import('@shared/types/provider').OAuthTokens): void { + try { + const json = JSON.stringify(tokens) + if (this.encryptionAvailable) { + const encrypted = safeStorage.encryptString(json) + this.settings.oauthTokens[provider] = encrypted.toString('base64') + } else { + this.settings.oauthTokens[provider] = json + } + this.saveSettings() + } catch (error) { + console.error(`Failed to encrypt and save OAuth tokens for ${provider}:`, error) + throw error + } + } + + deleteOAuthTokens(provider: string): void { + delete this.settings.oauthTokens[provider] + delete this.settings.authMethods[provider] + this.saveSettings() + } + + getAuthMethod(provider: string): 'apiKey' | 'oauth' { + return (this.settings.authMethods[provider] as 'apiKey' | 'oauth') || 'apiKey' + } + + setAuthMethod(provider: string, method: 'apiKey' | 'oauth'): void { + this.settings.authMethods[provider] = method + this.saveSettings() + } + + getOAuthConfiguredProviders(): string[] { + return Object.keys(this.settings.oauthTokens) + } } diff --git a/src/preload/index.d.ts b/src/preload/index.d.ts index 6208751..30be92c 100644 --- a/src/preload/index.d.ts +++ b/src/preload/index.d.ts @@ -5,6 +5,8 @@ interface KeystoneIPC { onAIDone: (callback: (data: { threadId: string; messageId: string | null }) => void) => unknown removeAIListeners: () => void selectDirectory: () => Promise + onOAuthStatus: (callback: (data: { state: string; provider?: string; email?: string; error?: string }) => void) => unknown + removeOAuthListeners: () => void } declare global { diff --git a/src/preload/index.ts b/src/preload/index.ts index 339a69d..9a6bce9 100644 --- a/src/preload/index.ts +++ b/src/preload/index.ts @@ -22,5 +22,14 @@ process.once('loaded', () => { ipcRenderer.removeAllListeners('ai:done') }, selectDirectory: () => ipcRenderer.invoke('dialog:openDirectory'), + onOAuthStatus: (callback: (data: { state: string; provider?: string; email?: string; error?: string }) => void) => { + const handler = (_event: Electron.IpcRendererEvent, data: { state: string; provider?: string; email?: string; error?: string }) => + callback(data) + ipcRenderer.on('oauth:status', handler) + return handler + }, + removeOAuthListeners: () => { + ipcRenderer.removeAllListeners('oauth:status') + }, }) }) diff --git a/src/renderer/env.d.ts b/src/renderer/env.d.ts index 9cdadca..abcf3e9 100644 --- a/src/renderer/env.d.ts +++ b/src/renderer/env.d.ts @@ -3,6 +3,8 @@ interface KeystoneIPC { onAIDone: (callback: (data: { threadId: string; messageId: string | null }) => void) => unknown removeAIListeners: () => void selectDirectory: () => Promise + onOAuthStatus: (callback: (data: { state: string; provider?: string; email?: string; error?: string }) => void) => unknown + removeOAuthListeners: () => void } interface Window { diff --git a/src/renderer/features/settings/ProviderCard.tsx b/src/renderer/features/settings/ProviderCard.tsx new file mode 100644 index 0000000..df7cccc --- /dev/null +++ b/src/renderer/features/settings/ProviderCard.tsx @@ -0,0 +1,113 @@ +import { Button } from '../../components/ui/Button' +import { useSettingsStore } from '../../stores/settingsStore' +import type { ProviderType } from '@shared/types' +import { OAUTH_CAPABILITIES } from '@shared/constants' + +interface ProviderCardProps { + type: ProviderType + name: string + isActive: boolean + onSelect: () => void +} + +export function ProviderCard({ type, name, isActive, onSelect }: ProviderCardProps) { + const { oauthStatus, oauthFlowStatus, startOAuthFlow, disconnectOAuth } = useSettingsStore() + + const oauth = oauthStatus[type] + const capabilities = OAUTH_CAPABILITIES[type] + const isOAuthSupported = capabilities?.supported ?? false + const isExperimental = capabilities?.experimental ?? false + const isPending = oauthFlowStatus.state === 'pending' && 'provider' in oauthFlowStatus && oauthFlowStatus.provider === type + const isError = oauthFlowStatus.state === 'error' && 'provider' in oauthFlowStatus && oauthFlowStatus.provider === type + + const handleSignIn = async () => { + await startOAuthFlow(type) + } + + const handleDisconnect = async () => { + await disconnectOAuth(type) + } + + return ( +
+
+ +
+
+ {name} + {isExperimental && ( + + Experimental + + )} +
+
+
+ + {/* OAuth section */} + {isOAuthSupported && ( +
+ {oauth.connected ? ( +
+
+
+ + Connected{oauth.email ? ` as ${oauth.email}` : ''} + +
+ +
+ ) : ( +
+ + {isError && 'error' in oauthFlowStatus && ( +

{oauthFlowStatus.error}

+ )} +
+ )} +
+ )} + + {/* API key only notice for Anthropic */} + {!isOAuthSupported && type === 'anthropic' && ( +
+ API key required (no OAuth available) +
+ )} +
+ ) +} diff --git a/src/renderer/features/settings/SettingsDialog.tsx b/src/renderer/features/settings/SettingsDialog.tsx index 294c1b1..50f0288 100644 --- a/src/renderer/features/settings/SettingsDialog.tsx +++ b/src/renderer/features/settings/SettingsDialog.tsx @@ -1,8 +1,9 @@ -import { useEffect } from 'react' +import { useEffect, useState } from 'react' import { Dialog } from '../../components/ui/Dialog' import { Input } from '../../components/ui/Input' import { Button } from '../../components/ui/Button' import { useSettingsStore } from '../../stores/settingsStore' +import { ProviderCard } from './ProviderCard' import type { ProviderType } from '@shared/types' interface SettingsDialogProps { @@ -11,13 +12,14 @@ interface SettingsDialogProps { } const providers: Array<{ type: ProviderType; name: string; placeholder: string }> = [ - { type: 'anthropic', name: 'Anthropic (Claude)', placeholder: 'sk-ant-...' }, { type: 'openai', name: 'OpenAI', placeholder: 'sk-...' }, + { type: 'anthropic', name: 'Anthropic (Claude)', placeholder: 'sk-ant-...' }, { type: 'google', name: 'Google (Gemini)', placeholder: 'AI...' }, ] export function SettingsDialog({ open, onClose }: SettingsDialogProps) { const { apiKeys, activeProvider, setApiKey, setActiveProvider, loadSettings } = useSettingsStore() + const [showAdvanced, setShowAdvanced] = useState(false) useEffect(() => { if (open) { @@ -25,34 +27,75 @@ export function SettingsDialog({ open, onClose }: SettingsDialogProps) { } }, [open, loadSettings]) + // Listen for OAuth status updates from main process + useEffect(() => { + if (!open) return + + window.keystoneIPC.onOAuthStatus((data) => { + useSettingsStore.getState().updateOAuthFlowStatus(data as any) + }) + + return () => { + window.keystoneIPC.removeOAuthListeners() + } + }, [open]) + return ( -
+
-

AI Provider

-
+

AI Providers

+
{providers.map((p) => ( -
-
- setActiveProvider(p.type)} - className="text-indigo-600" - /> - -
- setApiKey(p.type, e.target.value)} - placeholder={p.placeholder} - /> -
+ setActiveProvider(p.type)} + /> ))}
+ + {/* Collapsible Advanced section for API keys */} +
+ + + {showAdvanced && ( +
+ {providers.map((p) => ( +
+ + setApiKey(p.type, e.target.value)} + placeholder={p.placeholder} + /> +
+ ))} +
+ )} +
+
diff --git a/src/renderer/stores/settingsStore.ts b/src/renderer/stores/settingsStore.ts index 815493a..6fecac3 100644 --- a/src/renderer/stores/settingsStore.ts +++ b/src/renderer/stores/settingsStore.ts @@ -1,19 +1,36 @@ import { create } from 'zustand' -import type { ProviderType } from '@shared/types' +import type { ProviderType, OAuthFlowStatus, AuthMethod } from '@shared/types' import { trpc } from '../lib/trpc' +interface OAuthProviderStatus { + connected: boolean + email: string | null + authMethod: AuthMethod +} + interface SettingsState { activeProvider: ProviderType | null apiKeys: Record + oauthStatus: Record + oauthFlowStatus: OAuthFlowStatus loaded: boolean setActiveProvider: (provider: ProviderType | null) => void setApiKey: (provider: ProviderType, key: string) => void loadSettings: () => Promise + startOAuthFlow: (provider: ProviderType) => Promise + disconnectOAuth: (provider: ProviderType) => Promise + updateOAuthFlowStatus: (status: OAuthFlowStatus) => void } export const useSettingsStore = create((set, get) => ({ activeProvider: null, apiKeys: { openai: '', anthropic: '', google: '' }, + oauthStatus: { + openai: { connected: false, email: null, authMethod: 'apiKey' }, + anthropic: { connected: false, email: null, authMethod: 'apiKey' }, + google: { connected: false, email: null, authMethod: 'apiKey' }, + }, + oauthFlowStatus: { state: 'idle' }, loaded: false, setActiveProvider: async (provider) => { @@ -31,7 +48,52 @@ export const useSettingsStore = create((set, get) => ({ // Also configure the provider manager in-memory for immediate use if (key) { - await trpc.ai.configureProvider.mutate({ type: provider, apiKey: key }) + await trpc.ai.configureProvider.mutate({ type: provider, apiKey: key, authMethod: 'apiKey' }) + } + }, + + startOAuthFlow: async (provider) => { + set({ oauthFlowStatus: { state: 'pending', provider } }) + try { + const result = await trpc.oauth.startFlow.mutate({ provider }) + set((state) => ({ + oauthFlowStatus: { state: 'success', provider, email: result.email }, + oauthStatus: { + ...state.oauthStatus, + [provider]: { connected: true, email: result.email ?? null, authMethod: 'oauth' as AuthMethod }, + }, + activeProvider: provider, + })) + } catch (err) { + const message = err instanceof Error ? err.message : 'OAuth flow failed' + set({ oauthFlowStatus: { state: 'error', provider, error: message } }) + } + }, + + disconnectOAuth: async (provider) => { + await trpc.oauth.disconnect.mutate({ provider }) + set((state) => ({ + oauthStatus: { + ...state.oauthStatus, + [provider]: { connected: false, email: null, authMethod: 'apiKey' as AuthMethod }, + }, + oauthFlowStatus: { state: 'idle' }, + })) + }, + + updateOAuthFlowStatus: (status) => { + set({ oauthFlowStatus: status }) + if (status.state === 'success' && 'provider' in status) { + set((state) => ({ + oauthStatus: { + ...state.oauthStatus, + [status.provider]: { + connected: true, + email: status.email ?? null, + authMethod: 'oauth' as AuthMethod, + }, + }, + })) } }, @@ -50,13 +112,34 @@ export const useSettingsStore = create((set, get) => ({ } } + // Load OAuth status for each provider + const oauthStatus: Record = { + openai: { connected: false, email: null, authMethod: 'apiKey' }, + anthropic: { connected: false, email: null, authMethod: 'apiKey' }, + google: { connected: false, email: null, authMethod: 'apiKey' }, + } + + for (const provider of ['openai', 'anthropic', 'google'] as ProviderType[]) { + try { + const status = await trpc.oauth.getStatus.query({ provider }) + oauthStatus[provider] = { + connected: status.connected, + email: status.email, + authMethod: status.authMethod as AuthMethod, + } + } catch { + // OAuth status query failed, keep defaults + } + } + set({ activeProvider: settings.activeProvider as ProviderType | null, apiKeys, - loaded: true + oauthStatus, + loaded: true, }) } catch (error) { console.error('Failed to load settings:', error) } - } + }, })) diff --git a/src/shared/constants.ts b/src/shared/constants.ts index f17b369..9dbd557 100644 --- a/src/shared/constants.ts +++ b/src/shared/constants.ts @@ -8,6 +8,49 @@ export const MAX_DOCUMENT_SIZE_BYTES = 1024 * 1024 // 1MB export const DEFAULT_PROVIDER = 'anthropic' as const +export const OAUTH_CAPABILITIES: Record + supportsTokenExchange?: boolean + experimental?: boolean + redirectPort?: number // Fixed port for OAuth callback (required by some providers) + callbackPath?: string // Callback path (default: /callback) +}> = { + openai: { + supported: true, + clientId: 'app_EMoamEEZ73f0CkXaXp7hrann', + authorizationUrl: 'https://auth.openai.com/oauth/authorize', + tokenUrl: 'https://auth.openai.com/oauth/token', + scopes: ['openid', 'profile', 'email', 'offline_access'], + extraAuthParams: { + id_token_add_organizations: 'true', + codex_cli_simplified_flow: 'true', + }, + supportsTokenExchange: true, + redirectPort: 1455, + callbackPath: '/auth/callback', + }, + anthropic: { + supported: false, + clientId: '', + authorizationUrl: '', + tokenUrl: '', + scopes: [], + }, + google: { + supported: true, + clientId: '539167010789-g3ltv0osl0j74oab94klpj41sv7l4mqb.apps.googleusercontent.com', + authorizationUrl: 'https://accounts.google.com/o/oauth2/v2/auth', + tokenUrl: 'https://oauth2.googleapis.com/token', + scopes: ['openid', 'email', 'https://www.googleapis.com/auth/generative-language'], + experimental: true, + }, +} + export const ADR_TEMPLATE = `# ADR-{number}: {title} ## Status diff --git a/src/shared/errors.ts b/src/shared/errors.ts index f0b4a31..55f95a4 100644 --- a/src/shared/errors.ts +++ b/src/shared/errors.ts @@ -31,3 +31,15 @@ export class QuotaExhaustedError extends KeystoneError { super(`${provider} quota exhausted`, 'QUOTA_EXHAUSTED') } } + +export class OAuthError extends KeystoneError { + constructor(provider: string, message: string) { + super(`OAuth error (${provider}): ${message}`, 'OAUTH_ERROR') + } +} + +export class OAuthTokenExpiredError extends KeystoneError { + constructor(provider: string) { + super(`OAuth token expired for ${provider}`, 'OAUTH_TOKEN_EXPIRED') + } +} diff --git a/src/shared/types/index.ts b/src/shared/types/index.ts index 4d2410c..cd068d6 100644 --- a/src/shared/types/index.ts +++ b/src/shared/types/index.ts @@ -8,4 +8,10 @@ export type { UsageStatus, ProviderType, ProviderConfig, + ProviderConfigApiKey, + ProviderConfigOAuth, + LegacyProviderConfig, + AuthMethod, + OAuthTokens, + OAuthFlowStatus, } from './provider' diff --git a/src/shared/types/provider.ts b/src/shared/types/provider.ts index 1872f6b..37b5891 100644 --- a/src/shared/types/provider.ts +++ b/src/shared/types/provider.ts @@ -19,7 +19,42 @@ export interface UsageStatus { export type ProviderType = 'openai' | 'anthropic' | 'google' -export interface ProviderConfig { +export type AuthMethod = 'apiKey' | 'oauth' + +export interface OAuthTokens { + accessToken: string // For OpenAI: this is the exchanged API key + refreshToken?: string + idToken?: string // Original ID token, used for OpenAI token exchange + expiresAt: number // Unix timestamp in ms + accountId?: string // e.g. OpenAI chatgpt-account-id + email?: string +} + +export type OAuthFlowStatus = + | { state: 'idle' } + | { state: 'pending'; provider: ProviderType } + | { state: 'success'; provider: ProviderType; email?: string } + | { state: 'error'; provider: ProviderType; error: string } + +export interface ProviderConfigApiKey { + type: ProviderType + authMethod: 'apiKey' + apiKey: string + model?: string +} + +export interface ProviderConfigOAuth { + type: ProviderType + authMethod: 'oauth' + oauthToken: string + accountId?: string + model?: string +} + +export type ProviderConfig = ProviderConfigApiKey | ProviderConfigOAuth + +// Backward-compatible: accept legacy shape too +export interface LegacyProviderConfig { type: ProviderType apiKey: string model?: string