From 2b8db86d1c3eb27d67d21a57695dc00bf6b304cc Mon Sep 17 00:00:00 2001 From: suifeng <369202865@qq.com> Date: Thu, 14 Aug 2025 23:57:03 +0800 Subject: [PATCH] =?UTF-8?q?[dev]=20=E6=94=AF=E6=8C=81=E4=BC=9A=E8=AF=9D?= =?UTF-8?q?=E4=B8=8A=E4=B8=8B=E6=96=87=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sfchain/core/AIService.java | 213 +++++++++++----- .../sfchain/core/BaseAIOperation.java | 239 ++++++++++++------ .../context/ChatContextService.java | 87 +++++++ .../persistence/context/ChatMessage.java | 62 +++++ .../context/MapBasedChatContextService.java | 177 +++++++++++++ prompto-lab-ui/src/assets/icons/default.svg | 1 + 6 files changed, 642 insertions(+), 137 deletions(-) create mode 100644 prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/persistence/context/ChatContextService.java create mode 100644 prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/persistence/context/ChatMessage.java create mode 100644 prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/persistence/context/MapBasedChatContextService.java create mode 100644 prompto-lab-ui/src/assets/icons/default.svg diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/core/AIService.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/core/AIService.java index ca5aa26..f914e0f 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/core/AIService.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/core/AIService.java @@ -1,5 +1,6 @@ package io.github.timemachinelab.sfchain.core; +import io.github.timemachinelab.sfchain.persistence.context.ChatContextService; import lombok.Getter; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; @@ -13,25 +14,28 @@ /** * 描述: AI服务类 - 新框架版本 * 统一管理AI操作的执行 - * + * * @author suifeng * 日期: 2025/8/11 */ @Slf4j @Service public class AIService { - + @Resource private AIOperationRegistry operationRegistry; + @Resource + private ChatContextService chatContextService; + /** * 操作执行统计 */ private final Map executionStats = new ConcurrentHashMap<>(); - + /** * 执行AI操作 - * + * * @param operationType 操作类型 * @param input 输入参数 * @param 输入类型 @@ -40,54 +44,123 @@ public class AIService { */ @SuppressWarnings("unchecked") public OUTPUT execute(String operationType, INPUT input) { - return execute(operationType, input, null); + return execute(operationType, input, null, null); } - + + + @SuppressWarnings("unchecked") + public OUTPUT execute(String operationType, INPUT input, String sessionId) { + return execute(operationType, input,null, sessionId); + } + /** - * 执行AI操作(指定模型) - * + * 执行AI操作(带上下文支持) + * * @param operationType 操作类型 * @param input 输入参数 * @param modelName 指定的模型名称 + * @param sessionId 会话ID,用于上下文管理 * @param 输入类型 * @param 输出类型 * @return 执行结果 */ @SuppressWarnings("unchecked") - public OUTPUT execute(String operationType, INPUT input, String modelName) { + public OUTPUT execute(String operationType, INPUT input, String modelName, String sessionId) { long startTime = System.currentTimeMillis(); - + try { // 获取操作实例 BaseAIOperation operation = (BaseAIOperation) operationRegistry.getOperation(operationType); - + // 检查操作是否启用 if (!operation.isEnabled()) { throw new IllegalStateException("操作已禁用: " + operationType); } - + + // 如果有会话ID,记录用户输入到上下文 + if (sessionId != null && input != null) { + chatContextService.addUserMessage(sessionId, input.toString()); + } + // 执行操作 - OUTPUT result = operation.execute(input, modelName); - + OUTPUT result = operation.execute(input, modelName, sessionId); + + // 如果有会话ID,记录AI回复到上下文 + if (sessionId != null && result != null) { + chatContextService.addAiResponse(sessionId, result.toString()); + } + // 记录执行统计 recordExecution(operationType, true, System.currentTimeMillis() - startTime); - + log.debug("AI操作执行成功: {} - 耗时: {}ms", operationType, System.currentTimeMillis() - startTime); - + return result; - + } catch (Exception e) { // 记录执行统计 recordExecution(operationType, false, System.currentTimeMillis() - startTime); - + log.error("AI操作执行失败: {} - {}", operationType, e.getMessage(), e); throw new RuntimeException("AI操作执行失败: " + e.getMessage(), e); } } - + + /** + * 设置会话系统提示词 + * + * @param sessionId 会话ID + * @param systemPrompt 系统提示词 + */ + public void setSystemPrompt(String sessionId, String systemPrompt) { + chatContextService.setSystemPrompt(sessionId, systemPrompt); + log.info("设置会话系统提示词: sessionId={}", sessionId); + } + + /** + * 获取会话上下文 + * + * @param sessionId 会话ID + * @param includeSystemPrompt 是否包含系统提示词 + * @return 上下文字符串 + */ + public String getSessionContext(String sessionId, boolean includeSystemPrompt) { + return chatContextService.getContextAsString(sessionId, includeSystemPrompt); + } + + /** + * 清除会话对话历史 + * + * @param sessionId 会话ID + */ + public void clearSessionConversation(String sessionId) { + chatContextService.clearConversation(sessionId); + log.info("清除会话对话历史: sessionId={}", sessionId); + } + + /** + * 完全清除会话 + * + * @param sessionId 会话ID + */ + public void clearSession(String sessionId) { + chatContextService.clearSession(sessionId); + log.info("完全清除会话: sessionId={}", sessionId); + } + + /** + * 检查会话是否存在 + * + * @param sessionId 会话ID + * @return 是否存在 + */ + public boolean sessionExists(String sessionId) { + return chatContextService.sessionExists(sessionId); + } + /** * 异步执行AI操作 - * + * * @param operationType 操作类型 * @param input 输入参数 * @param 输入类型 @@ -96,12 +169,12 @@ public OUTPUT execute(String operationType, INPUT input, String */ @SuppressWarnings("unchecked") public CompletableFuture executeAsync(String operationType, INPUT input) { - return executeAsync(operationType, input, null); + return executeAsync(operationType, input, null, null); } - + /** * 异步执行AI操作(指定模型) - * + * * @param operationType 操作类型 * @param input 输入参数 * @param modelName 指定的模型名称 @@ -111,12 +184,30 @@ public CompletableFuture executeAsync(String operationTy */ @SuppressWarnings("unchecked") public CompletableFuture executeAsync(String operationType, INPUT input, String modelName) { - return CompletableFuture.supplyAsync(() -> execute(operationType, input, modelName)); + return executeAsync(operationType, input, modelName, null); + } + + /** + * 异步执行AI操作(带上下文支持) + * + * @param operationType 操作类型 + * @param input 输入参数 + * @param modelName 指定的模型名称 + * @param sessionId 会话ID + * @param 输入类型 + * @param 输出类型 + * @return 异步执行结果 + */ + @SuppressWarnings("unchecked") + public CompletableFuture executeAsync(String operationType, INPUT input, String modelName, String sessionId) { + return CompletableFuture.supplyAsync(() -> execute(operationType, input, modelName, sessionId)); } - + + // ... existing code ... + /** * 批量执行AI操作 - * + * * @param operationType 操作类型 * @param inputs 输入参数列表 * @param 输入类型 @@ -126,10 +217,10 @@ public CompletableFuture executeAsync(String operationTy public List executeBatch(String operationType, List inputs) { return executeBatch(operationType, inputs, null); } - + /** * 批量执行AI操作(指定模型) - * + * * @param operationType 操作类型 * @param inputs 输入参数列表 * @param modelName 指定的模型名称 @@ -142,10 +233,10 @@ public List executeBatch(String operationType, List this.execute(operationType, input, modelName)) .toList(); } - + /** * 异步批量执行AI操作 - * + * * @param operationType 操作类型 * @param inputs 输入参数列表 * @param 输入类型 @@ -155,10 +246,10 @@ public List executeBatch(String operationType, List CompletableFuture> executeBatchAsync(String operationType, List inputs) { return executeBatchAsync(operationType, inputs, null); } - + /** * 异步批量执行AI操作(指定模型) - * + * * @param operationType 操作类型 * @param inputs 输入参数列表 * @param modelName 指定的模型名称 @@ -170,16 +261,16 @@ public CompletableFuture> executeBatchAsync(String List> futures = inputs.stream() .map(input -> this.executeAsync(operationType, input, modelName)) .toList(); - + return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])) .thenApply(v -> futures.stream() .map(CompletableFuture::join) .toList()); } - + /** * 获取所有可用的操作 - * + * * @return 操作类型列表 */ public List getAvailableOperations() { @@ -200,10 +291,10 @@ public boolean isOperationAvailable(String operationType) { return false; } } - + /** * 获取操作信息 - * + * * @param operationType 操作类型 * @return 操作信息 */ @@ -212,7 +303,7 @@ public OperationInfo getOperationInfo(String operationType) { if (operation == null) { return null; } - + return OperationInfo.builder() .operationType(operationType) .description(operation.getDescription()) @@ -223,36 +314,36 @@ public OperationInfo getOperationInfo(String operationType) { .defaultModel(operation.getAnnotation().defaultModel()) .build(); } - + /** * 获取操作执行统计 - * + * * @param operationType 操作类型 * @return 执行统计 */ public ExecutionStats getExecutionStats(String operationType) { return executionStats.getOrDefault(operationType, new ExecutionStats()); } - + /** * 获取所有操作的执行统计 - * + * * @return 执行统计映射 */ public Map getAllExecutionStats() { return Map.copyOf(executionStats); } - + /** * 清空执行统计 */ public void clearExecutionStats() { executionStats.clear(); } - + /** * 记录执行统计 - * + * * @param operationType 操作类型 * @param success 是否成功 * @param duration 执行时长 @@ -261,7 +352,9 @@ private void recordExecution(String operationType, boolean success, long duratio executionStats.computeIfAbsent(operationType, k -> new ExecutionStats()) .record(success, duration); } - + + // ... existing inner classes remain the same ... + /** * 操作信息类 */ @@ -275,55 +368,55 @@ public static class OperationInfo { private boolean enabled; private String[] supportedModels; private String defaultModel; - + public static OperationInfoBuilder builder() { return new OperationInfoBuilder(); } public static class OperationInfoBuilder { private OperationInfo info = new OperationInfo(); - + public OperationInfoBuilder operationType(String operationType) { info.operationType = operationType; return this; } - + public OperationInfoBuilder description(String description) { info.description = description; return this; } - + public OperationInfoBuilder inputType(Class inputType) { info.inputType = inputType; return this; } - + public OperationInfoBuilder outputType(Class outputType) { info.outputType = outputType; return this; } - + public OperationInfoBuilder enabled(boolean enabled) { info.enabled = enabled; return this; } - + public OperationInfoBuilder supportedModels(String[] supportedModels) { info.supportedModels = supportedModels; return this; } - + public OperationInfoBuilder defaultModel(String defaultModel) { info.defaultModel = defaultModel; return this; } - + public OperationInfo build() { return info; } } } - + /** * 执行统计类 */ @@ -337,7 +430,7 @@ public static class ExecutionStats { private long totalDuration = 0; private long minDuration = Long.MAX_VALUE; private long maxDuration = 0; - + public synchronized void record(boolean success, long duration) { totalExecutions++; if (success) { @@ -345,16 +438,16 @@ public synchronized void record(boolean success, long duration) { } else { failedExecutions++; } - + totalDuration += duration; minDuration = Math.min(minDuration, duration); maxDuration = Math.max(maxDuration, duration); } - + public double getSuccessRate() { return totalExecutions > 0 ? (double) successfulExecutions / totalExecutions : 0.0; } - + public double getAverageDuration() { return totalExecutions > 0 ? (double) totalDuration / totalExecutions : 0.0; } diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/core/BaseAIOperation.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/core/BaseAIOperation.java index 499ac2f..1b6bc95 100644 --- a/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/core/BaseAIOperation.java +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/core/BaseAIOperation.java @@ -5,6 +5,8 @@ import io.github.timemachinelab.sfchain.core.logging.AICallLog; import io.github.timemachinelab.sfchain.core.logging.AICallLogManager; import io.github.timemachinelab.sfchain.core.openai.OpenAICompatibleModel; +import io.github.timemachinelab.sfchain.persistence.context.ChatContextService; +import io.github.timemachinelab.sfchain.persistence.context.ChatMessage; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import lombok.Getter; @@ -15,6 +17,7 @@ import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.time.LocalDateTime; +import java.util.List; import java.util.UUID; import static io.github.timemachinelab.sfchain.constants.AIOperationConstant.JSON_REPAIR_OP; @@ -22,22 +25,25 @@ /** * 描述: AI操作抽象基类 - 新框架版本 * 提供统一的AI操作接口和实现 - * + * * @author suifeng * 日期: 2025/8/11 */ @Slf4j public abstract class BaseAIOperation { - + @Autowired protected AIOperationRegistry operationRegistry; - + @Autowired protected ModelRegistry modelRegistry; - + @Autowired protected ObjectMapper objectMapper; + @Autowired + protected ChatContextService chatContextService; + /** * 操作的注解信息 * -- GETTER -- @@ -47,7 +53,7 @@ public abstract class BaseAIOperation { */ @Getter private AIOp annotation; - + /** * 输入类型 * -- GETTER -- @@ -58,7 +64,7 @@ public abstract class BaseAIOperation { */ @Getter private Class inputType; - + /** * 输出类型 * -- GETTER -- @@ -69,7 +75,7 @@ public abstract class BaseAIOperation { */ @Getter private Class outputType; - + /** * 初始化方法 */ @@ -81,7 +87,7 @@ public void init() { if (annotation == null) { throw new IllegalStateException("AI操作类必须使用@AIOp注解: " + this.getClass().getSimpleName()); } - + // 获取泛型类型 Type superClass = this.getClass().getGenericSuperclass(); if (superClass instanceof ParameterizedType parameterizedType) { @@ -91,10 +97,10 @@ public void init() { this.outputType = (Class) typeArguments[1]; } } - + // 注册到操作注册中心 operationRegistry.registerOperation(annotation.value(), this); - + // 如果注解中有默认模型且当前没有设置模型映射,则自动设置 if (!annotation.defaultModel().isEmpty()) { String currentModel = operationRegistry.getModelForOperation(annotation.value()); @@ -110,41 +116,53 @@ public void init() { } } } - - log.info("初始化AI操作: {} [{}] -> 输入类型: {}, 输出类型: {}", - annotation.value(), this.getClass().getSimpleName(), + + log.info("初始化AI操作: {} [{}] -> 输入类型: {}, 输出类型: {}", + annotation.value(), this.getClass().getSimpleName(), inputType != null ? inputType.getSimpleName() : "Unknown", outputType != null ? outputType.getSimpleName() : "Unknown"); } - + /** * 执行AI操作 - * + * * @param input 输入参数 * @return 输出结果 */ public OUTPUT execute(INPUT input) { - return execute(input, null); + return execute(input, null, null); } - + /** * 执行AI操作(指定模型) - * + * * @param input 输入参数 * @param modelName 指定的模型名称,为null时使用默认模型 * @return 输出结果 */ + public OUTPUT execute(INPUT input, String modelName) { + return execute(input, modelName, null); + } + + /** + * 执行AI操作(带上下文支持) + * + * @param input 输入参数 + * @param modelName 指定的模型名称,为null时使用默认模型 + * @param sessionId 会话ID,用于上下文管理 + * @return 输出结果 + */ // 在BaseAIOperation类中添加以下字段和方法 - + @Autowired private AICallLogManager logManager; - - // 在execute方法中添加详细日志记录 - public OUTPUT execute(INPUT input, String modelName) { + + // 在execute方法中添加详细日志记录和上下文支持 + public OUTPUT execute(INPUT input, String modelName, String sessionId) { String callId = UUID.randomUUID().toString(); LocalDateTime startTime = LocalDateTime.now(); long startMillis = System.currentTimeMillis(); - + AICallLog.AICallLogBuilder logBuilder = AICallLog.builder() .callId(callId) .operationType(annotation.value()) @@ -153,25 +171,25 @@ public OUTPUT execute(INPUT input, String modelName) { .modelName(modelName) .frequency(1) .lastAccessTime(startTime); - + try { // 获取模型 AIModel model = getModel(modelName); logBuilder.modelName(model.getName()); - - // 构建提示词 - String prompt = buildPrompt(input); + + // 构建提示词(带上下文支持) + String prompt = buildPromptWithContext(input, sessionId); logBuilder.prompt(prompt); - + // 获取操作配置 AIOperationRegistry.OperationConfig config = operationRegistry.getOperationConfig(annotation.value()); - + // 合并配置 Integer finalMaxTokens = config.getMaxTokens() > 0 ? Integer.valueOf(config.getMaxTokens()) : (annotation.defaultMaxTokens() > 0 ? annotation.defaultMaxTokens() : null); Double finalTemperature = config.getTemperature() >= 0 ? Double.valueOf(config.getTemperature()) : (annotation.defaultTemperature() >= 0 ? annotation.defaultTemperature() : null); Boolean finalJsonOutput = config.isRequireJsonOutput() || annotation.requireJsonOutput(); boolean finalThinking = config.isSupportThinking() || annotation.supportThinking(); - + // 记录请求参数 AICallLog.AIRequestParams requestParams = AICallLog.AIRequestParams.builder() .maxTokens(finalMaxTokens) @@ -180,7 +198,7 @@ public OUTPUT execute(INPUT input, String modelName) { .thinking(finalThinking) .build(); logBuilder.requestParams(requestParams); - + // 调用AI模型 String response; if (model instanceof OpenAICompatibleModel openAIModel) { @@ -192,12 +210,12 @@ public OUTPUT execute(INPUT input, String modelName) { } else { response = model.generate(prompt); } - + logBuilder.rawResponse(response); - + // 解析响应 OUTPUT result = parseResponse(response, input); - + // 记录成功日志 long duration = System.currentTimeMillis() - startMillis; AICallLog log = logBuilder @@ -205,11 +223,11 @@ public OUTPUT execute(INPUT input, String modelName) { .duration(duration) .output(result) .build(); - + logManager.addLog(log); - + return result; - + } catch (Exception e) { // 记录失败日志 long duration = System.currentTimeMillis() - startMillis; @@ -218,25 +236,94 @@ public OUTPUT execute(INPUT input, String modelName) { .duration(duration) .errorMessage(e.getMessage()) .build(); - + logManager.addLog(callLog); - + log.error("执行AI操作失败: {} - {}", annotation.value(), e.getMessage(), e); throw new RuntimeException("AI操作执行失败: " + e.getMessage(), e); } } + /** + * 构建带上下文的提示词 + * + * @param input 输入参数 + * @param sessionId 会话ID + * @return 完整的提示词 + */ + protected String buildPromptWithContext(INPUT input, String sessionId) { + // 构建基础提示词 + String basePrompt = buildPrompt(input); + + // 如果没有会话ID,直接返回基础提示词 + if (sessionId == null || !chatContextService.sessionExists(sessionId)) { + return basePrompt; + } + + // 获取上下文信息 + List contextMessages = chatContextService.getFullContext(sessionId); + if (contextMessages.isEmpty()) { + return basePrompt; + } + + // 构建带上下文的提示词 + StringBuilder contextPrompt = new StringBuilder(); + + // 添加系统提示词(如果存在) + String systemPrompt = chatContextService.getSystemPrompt(sessionId); + if (systemPrompt != null && !systemPrompt.trim().isEmpty()) { + contextPrompt.append("系统提示: ").append(systemPrompt).append("\n\n"); + } + + // 添加对话历史 + List conversationHistory = chatContextService.getConversationHistory(sessionId); + if (!conversationHistory.isEmpty()) { + contextPrompt.append("对话历史:\n"); + for (ChatMessage message : conversationHistory) { + String role = getRoleString(message.getType()); + contextPrompt.append(role).append(": ").append(message.getContent()).append("\n"); + } + contextPrompt.append("\n"); + } + + // 添加当前任务提示词 + contextPrompt.append("当前任务:\n").append(basePrompt); + + return contextPrompt.toString(); + } + + /** + * 获取角色字符串 + * + * @param type 消息类型 + * @return 角色字符串 + */ + private String getRoleString(ChatMessage.MessageType type) { + switch (type) { + case SYSTEM: + return "系统"; + case USER: + return "用户"; + case ASSISTANT: + return "助手"; + default: + return "未知"; + } + } + /** * 构建提示词(子类实现) - * + * * @param input 输入参数 * @return 提示词 */ protected abstract String buildPrompt(INPUT input); - + + // ... existing code remains the same ... + /** * 解析AI响应(最终方法,子类不应重写) - * + * * @param response AI响应 * @param input 输入参数 * @return 解析后的结果 @@ -245,20 +332,20 @@ protected final OUTPUT parseResponse(String response, INPUT input) { if (outputType == String.class) { return (OUTPUT) response; } - + try { // 1. 预处理响应(子类可自定义) String processedResponse = preprocessResponse(response, input); - + // 2. 提取JSON内容 String jsonContent = extractJsonFromResponse(processedResponse); - + // 3. 预处理JSON内容(子类可自定义) String processedJson = preprocessJson(jsonContent, input); - + // 4. 解析为对象(子类可自定义解析逻辑) return parseJsonToResult(processedJson, input, response); - + } catch (JsonProcessingException e) { // 如果启用了自动JSON修复且需要JSON输出,尝试修复JSON if (annotation.requireJsonOutput() && annotation.autoRepairJson()) { @@ -279,16 +366,16 @@ protected final OUTPUT parseResponse(String response, INPUT input) { throw new RuntimeException("JSON解析和修复都失败: 原始错误=" + e.getMessage() + ", 修复错误=" + repairException.getMessage(), e); } } - + log.error("解析AI响应失败: {}", e.getMessage(), e); throw new RuntimeException("解析AI响应失败: " + e.getMessage(), e); } } - + /** * 预处理AI响应(子类可重写) * 在提取JSON之前对原始响应进行处理 - * + * * @param response 原始AI响应 * @param input 输入参数 * @return 处理后的响应 @@ -296,11 +383,11 @@ protected final OUTPUT parseResponse(String response, INPUT input) { protected String preprocessResponse(String response, INPUT input) { return response; } - + /** * 预处理JSON内容(子类可重写) * 在JSON解析之前对提取的JSON字符串进行处理 - * + * * @param jsonContent 提取的JSON字符串 * @param input 输入参数 * @return 处理后的JSON字符串 @@ -308,10 +395,10 @@ protected String preprocessResponse(String response, INPUT input) { protected String preprocessJson(String jsonContent, INPUT input) { return jsonContent; } - + /** * 将JSON字符串解析为结果对象(高级用法,一般用户无需重写) - * + * * @param jsonContent JSON内容 * @param input 输入参数 * @param originalResponse 原始响应 @@ -324,15 +411,15 @@ protected OUTPUT parseJsonToResult(String jsonContent, INPUT input, String origi if (customResult != null) { return customResult; } - + // 如果用户没有自定义解析,使用默认的JSON解析 return objectMapper.readValue(jsonContent, outputType); } - + /** * 解析AI返回的JSON为最终结果(推荐用户重写此方法) * 用户可以在此方法中处理AI返回的原始JSON,并转换为最终的结果对象 - * + * * @param jsonContent AI返回的JSON字符串 * @param input 输入参数 * @return 最终结果对象,如果返回null则使用默认的JSON解析 @@ -340,10 +427,10 @@ protected OUTPUT parseJsonToResult(String jsonContent, INPUT input, String origi protected OUTPUT parseResult(String jsonContent, INPUT input) { return null; // 默认返回null,表示使用框架的默认JSON解析 } - + /** * 工具方法:将JSON字符串解析为指定类型的对象 - * + * * @param jsonContent JSON字符串 * @param clazz 目标类型 * @param 泛型类型 @@ -353,12 +440,10 @@ protected OUTPUT parseResult(String jsonContent, INPUT input) { protected T parseJsonToObject(String jsonContent, Class clazz) throws JsonProcessingException { return objectMapper.readValue(jsonContent, clazz); } - - /** * 从响应中提取JSON内容 - * + * * @param response 原始响应 * @return JSON字符串 */ @@ -366,7 +451,7 @@ protected String extractJsonFromResponse(String response) { // 查找JSON代码块 String jsonStart = "```json"; String jsonEnd = "```"; - + int startIndex = response.indexOf(jsonStart); if (startIndex != -1) { startIndex += jsonStart.length(); @@ -375,21 +460,21 @@ protected String extractJsonFromResponse(String response) { return response.substring(startIndex, endIndex).trim(); } } - + // 查找花括号包围的JSON int braceStart = response.indexOf('{'); int braceEnd = response.lastIndexOf('}'); if (braceStart != -1 && braceEnd != -1 && braceEnd > braceStart) { return response.substring(braceStart, braceEnd + 1); } - + // 如果都找不到,返回原始响应 return response; } - + /** * 获取模型实例 - * + * * @param modelName 模型名称,为null时使用默认模型 * @return 模型实例 */ @@ -398,27 +483,27 @@ private AIModel getModel(String modelName) { // 使用注册中心配置的模型 modelName = operationRegistry.getModelForOperation(annotation.value()); } - + if (modelName == null) { // 使用注解中的默认模型 modelName = annotation.defaultModel(); } - + if (modelName == null || modelName.isEmpty()) { throw new IllegalStateException("未配置模型: " + annotation.value()); } - + AIModel model = modelRegistry.getModel(modelName); if (model == null) { throw new IllegalArgumentException("模型不存在: " + modelName); } - + return model; } /** * 获取操作类型 - * + * * @return 操作类型 */ public String getOperationType() { @@ -427,26 +512,26 @@ public String getOperationType() { /** * 检查操作是否启用 - * + * * @return 是否启用 */ public boolean isEnabled() { AIOperationRegistry.OperationConfig config = operationRegistry.getOperationConfig(annotation.value()); return config.isEnabled() && annotation.enabled(); } - + /** * 获取操作描述 - * + * * @return 操作描述 */ public String getDescription() { return annotation.description(); } - + /** * 获取支持的模型列表 - * + * * @return 支持的模型列表 */ public String[] getSupportedModels() { diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/persistence/context/ChatContextService.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/persistence/context/ChatContextService.java new file mode 100644 index 0000000..cf6bec3 --- /dev/null +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/persistence/context/ChatContextService.java @@ -0,0 +1,87 @@ +package io.github.timemachinelab.sfchain.persistence.context; + +import java.util.List; + +/** + * 描述: + * @author suifeng + * 日期: 2025/8/14 + */ +public interface ChatContextService { + + /** + * 设置或更新系统提示词 + * @param sessionId 会话ID + * @param systemPrompt 系统提示词内容 + */ + void setSystemPrompt(String sessionId, String systemPrompt); + + /** + * 获取系统提示词 + * @param sessionId 会话ID + * @return 系统提示词内容,如果不存在返回null + */ + String getSystemPrompt(String sessionId); + + /** + * 添加用户消息到会话上下文 + * @param sessionId 会话ID + * @param userMessage 用户消息内容 + */ + void addUserMessage(String sessionId, String userMessage); + + /** + * 添加AI回复到会话上下文 + * @param sessionId 会话ID + * @param aiResponse AI回复内容 + */ + void addAiResponse(String sessionId, String aiResponse); + + /** + * 获取完整的对话上下文(包含系统提示词) + * @param sessionId 会话ID + * @return 完整的消息列表,系统提示词在首位 + */ + List getFullContext(String sessionId); + + /** + * 获取对话历史(不包含系统提示词) + * @param sessionId 会话ID + * @return 用户和AI的对话历史 + */ + List getConversationHistory(String sessionId); + + /** + * 获取格式化的上下文字符串 + * @param sessionId 会话ID + * @param includeSystemPrompt 是否包含系统提示词 + * @return 格式化的上下文字符串 + */ + String getContextAsString(String sessionId, boolean includeSystemPrompt); + + /** + * 清除会话的对话历史(保留系统提示词) + * @param sessionId 会话ID + */ + void clearConversation(String sessionId); + + /** + * 完全清除会话(包括系统提示词) + * @param sessionId 会话ID + */ + void clearSession(String sessionId); + + /** + * 检查会话是否存在 + * @param sessionId 会话ID + * @return 是否存在 + */ + boolean sessionExists(String sessionId); + + /** + * 获取对话消息数量(不包含系统提示词) + * @param sessionId 会话ID + * @return 消息数量 + */ + int getConversationMessageCount(String sessionId); +} \ No newline at end of file diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/persistence/context/ChatMessage.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/persistence/context/ChatMessage.java new file mode 100644 index 0000000..138b5d8 --- /dev/null +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/persistence/context/ChatMessage.java @@ -0,0 +1,62 @@ +package io.github.timemachinelab.sfchain.persistence.context; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.time.LocalDateTime; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class ChatMessage { + + /** + * 消息类型:SYSTEM, USER, ASSISTANT + */ + public enum MessageType { + SYSTEM, // 系统提示词 + USER, // 用户消息 + ASSISTANT // AI回复 + } + + private String id; + private MessageType type; + private String content; + private LocalDateTime timestamp; + private String sessionId; + + public static ChatMessage systemMessage(String sessionId, String content) { + return new ChatMessage( + generateId(), + MessageType.SYSTEM, + content, + LocalDateTime.now(), + sessionId + ); + } + + public static ChatMessage userMessage(String sessionId, String content) { + return new ChatMessage( + generateId(), + MessageType.USER, + content, + LocalDateTime.now(), + sessionId + ); + } + + public static ChatMessage assistantMessage(String sessionId, String content) { + return new ChatMessage( + generateId(), + MessageType.ASSISTANT, + content, + LocalDateTime.now(), + sessionId + ); + } + + private static String generateId() { + return System.currentTimeMillis() + "_" + (int)(Math.random() * 1000); + } +} \ No newline at end of file diff --git a/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/persistence/context/MapBasedChatContextService.java b/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/persistence/context/MapBasedChatContextService.java new file mode 100644 index 0000000..29e4c3b --- /dev/null +++ b/prompto-lab-app/src/main/java/io/github/timemachinelab/sfchain/persistence/context/MapBasedChatContextService.java @@ -0,0 +1,177 @@ +package io.github.timemachinelab.sfchain.persistence.context; + + +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * 描述: + * @author suifeng + * 日期: 2025/8/14 + */ +@Slf4j +@Service +public class MapBasedChatContextService implements ChatContextService { + + // 系统提示词存储 + private final Map systemPrompts = new ConcurrentHashMap<>(); + + // 对话历史存储 + private final Map> conversationHistories = new ConcurrentHashMap<>(); + + private static final int MAX_MESSAGES_PER_SESSION = 20; + + @Override + public void setSystemPrompt(String sessionId, String systemPrompt) { + if (sessionId == null || systemPrompt == null) { + log.warn("会话ID或系统提示词为空,跳过设置"); + return; + } + + systemPrompts.put(sessionId, systemPrompt); + log.debug("设置系统提示词: sessionId={}", sessionId); + } + + @Override + public String getSystemPrompt(String sessionId) { + if (sessionId == null) { + return null; + } + return systemPrompts.get(sessionId); + } + + @Override + public void addUserMessage(String sessionId, String userMessage) { + if (sessionId == null || userMessage == null) { + log.warn("会话ID或用户消息为空,跳过添加"); + return; + } + + ChatMessage message = ChatMessage.userMessage(sessionId, userMessage); + addConversationMessage(sessionId, message); + log.debug("添加用户消息: sessionId={}", sessionId); + } + + @Override + public void addAiResponse(String sessionId, String aiResponse) { + if (sessionId == null || aiResponse == null) { + log.warn("会话ID或AI回复为空,跳过添加"); + return; + } + + ChatMessage message = ChatMessage.assistantMessage(sessionId, aiResponse); + addConversationMessage(sessionId, message); + log.debug("添加AI回复: sessionId={}", sessionId); + } + + private void addConversationMessage(String sessionId, ChatMessage message) { + conversationHistories.computeIfAbsent(sessionId, k -> new ArrayList<>()).add(message); + + List messages = conversationHistories.get(sessionId); + while (messages.size() > MAX_MESSAGES_PER_SESSION) { + messages.remove(0); + log.debug("对话历史超限,移除最旧消息: sessionId={}", sessionId); + } + } + + @Override + public List getFullContext(String sessionId) { + if (sessionId == null) { + return new ArrayList<>(); + } + + List fullContext = new ArrayList<>(); + + // 添加系统提示词 + String systemPrompt = getSystemPrompt(sessionId); + if (systemPrompt != null) { + fullContext.add(ChatMessage.systemMessage(sessionId, systemPrompt)); + } + + // 添加对话历史 + fullContext.addAll(getConversationHistory(sessionId)); + + return fullContext; + } + + @Override + public List getConversationHistory(String sessionId) { + if (sessionId == null) { + return new ArrayList<>(); + } + + return new ArrayList<>(conversationHistories.getOrDefault(sessionId, new ArrayList<>())); + } + + @Override + public String getContextAsString(String sessionId, boolean includeSystemPrompt) { + List messages; + + if (includeSystemPrompt) { + messages = getFullContext(sessionId); + } else { + messages = getConversationHistory(sessionId); + } + + if (messages.isEmpty()) { + return ""; + } + + StringBuilder context = new StringBuilder(); + for (ChatMessage message : messages) { + String role = getRoleString(message.getType()); + context.append(role).append(": ").append(message.getContent()).append("\n"); + } + + return context.toString(); + } + + @Override + public void clearConversation(String sessionId) { + if (sessionId != null) { + conversationHistories.remove(sessionId); + log.info("清除对话历史: sessionId={}", sessionId); + } + } + + @Override + public void clearSession(String sessionId) { + if (sessionId != null) { + systemPrompts.remove(sessionId); + conversationHistories.remove(sessionId); + log.info("完全清除会话: sessionId={}", sessionId); + } + } + + @Override + public boolean sessionExists(String sessionId) { + return sessionId != null && + (systemPrompts.containsKey(sessionId) || conversationHistories.containsKey(sessionId)); + } + + @Override + public int getConversationMessageCount(String sessionId) { + if (sessionId == null) { + return 0; + } + return conversationHistories.getOrDefault(sessionId, new ArrayList<>()).size(); + } + + private String getRoleString(ChatMessage.MessageType type) { + switch (type) { + case SYSTEM: + return "系统"; + case USER: + return "用户"; + case ASSISTANT: + return "助手"; + default: + return "未知"; + } + } +} \ No newline at end of file diff --git a/prompto-lab-ui/src/assets/icons/default.svg b/prompto-lab-ui/src/assets/icons/default.svg new file mode 100644 index 0000000..f72b3a4 --- /dev/null +++ b/prompto-lab-ui/src/assets/icons/default.svg @@ -0,0 +1 @@ + \ No newline at end of file