From fc670210a85566a9a289c840e3240c3faac01f82 Mon Sep 17 00:00:00 2001
From: rostislav
Date: Tue, 3 Feb 2026 22:36:12 +0200
Subject: [PATCH 1/7] Refactor embedding model handling and add AST parsing endpoints

- Updated import paths for IssueDTO in prompt_builder.py.
- Modified analysis_summary prompt in prompt_constants.py for clarity.
- Enhanced Dockerfile with CPU threading optimizations.
- Improved environment validation in main.py to support Ollama and OpenRouter embedding providers.
- Added new endpoints for parsing files and batch parsing in api.py, including AST metadata extraction.
- Introduced embedding_factory.py to manage embedding model creation for Ollama and OpenRouter.
- Implemented Ollama embedding wrapper in ollama_embedding.py for local model support.
- Updated index_manager to check vector dimensions before copying branches.
- Refactored RAGConfig to validate embedding provider configurations and auto-detect embedding dimensions.
- Adjusted query_service to utilize the new embedding factory for model instantiation.
---
 deployment/config/rag-pipeline/.env.sample | 17 +
 .../src/main/java/module-info.java | 1 +
 .../dto/request/ai/AiAnalysisRequestImpl.java | 17 +
 .../request/ai/enrichment/FileContentDto.java | 46 +
 .../ai/enrichment/FileRelationshipDto.java | 90 ++
 .../ai/enrichment/ParsedFileMetadataDto.java | 70 +
 .../ai/enrichment/PrEnrichmentDataDto.java | 63 +
 .../service/PrFileEnrichmentService.java | 539 +++++++
 .../codecrow/vcsclient/VcsClient.java | 35 +
 .../bitbucket/cloud/BitbucketCloudClient.java | 65 +
 .../vcsclient/github/GitHubClient.java | 109 ++
 .../vcsclient/gitlab/GitLabClient.java | 62 +
 .../platformmcp/db/PlatformDbService.java | 0
 .../service/BitbucketAiClientService.java | 35 +-
 .../github/service/GitHubAiClientService.java | 37 +-
 .../gitlab/service/GitLabAiClientService.java | 35 +-
 python-ecosystem/mcp-client/api/__init__.py | 8 +
 python-ecosystem/mcp-client/api/app.py | 34 +
 .../mcp-client/api/routers/__init__.py | 5 +
 .../mcp-client/api/routers/commands.py | 204 +++
 .../mcp-client/api/routers/health.py | 12 +
 .../mcp-client/api/routers/review.py | 157 ++
 python-ecosystem/mcp-client/main.py | 2 +-
 python-ecosystem/mcp-client/model/__init__.py | 98 ++
 .../mcp-client/model/{models.py => dtos.py} | 164 +-
 .../mcp-client/model/enrichment.py | 59 +
 python-ecosystem/mcp-client/model/enums.py | 31 +
 .../mcp-client/model/multi_stage.py | 102 ++
 .../mcp-client/model/output_schemas.py | 47 +
 .../mcp-client/server/stdin_handler.py | 4 +-
 .../mcp-client/server/web_server.py | 284 ----
 .../mcp-client/service/__init__.py | 44 +
 .../mcp-client/service/command/__init__.py | 8 +
 .../service/{ => command}/command_service.py | 5 +-
 .../service/multi_stage_orchestrator.py | 1429 -----------------
 .../service/pooled_review_service.py | 261 ---
 .../mcp-client/service/rag/__init__.py | 19 +
 .../service/{ => rag}/llm_reranker.py | 0
 .../service/{ => rag}/rag_client.py | 0
 .../mcp-client/service/review/__init__.py | 39 +
 .../issue_processor.py} | 0
 .../service/review/orchestrator/__init__.py | 25 +
 .../service/review/orchestrator/agents.py | 81 +
 .../review/orchestrator/context_helpers.py | 245 +++
 .../service/review/orchestrator/json_utils.py | 155 ++
 .../review/orchestrator/orchestrator.py | 247 +++
 .../review/orchestrator/reconciliation.py | 280 ++++
 .../service/review/orchestrator/stages.py | 666 ++++++++
 .../service/{ => review}/review_service.py | 10 +-
 .../mcp-client/tests/test_dependency_graph.py | 366 +++++
 .../mcp-client/utils/dependency_graph.py | 630 ++++++++
.../utils/prompts/prompt_builder.py | 2 +- .../utils/prompts/prompt_constants.py | 3 +- python-ecosystem/rag-pipeline/Dockerfile | 7 + python-ecosystem/rag-pipeline/main.py | 56 +- .../rag-pipeline/src/rag_pipeline/api/api.py | 154 ++ .../rag_pipeline/core/embedding_factory.py | 93 ++ .../core/index_manager/branch_manager.py | 35 + .../core/index_manager/indexer.py | 1 - .../core/index_manager/manager.py | 17 +- .../src/rag_pipeline/core/ollama_embedding.py | 252 +++ .../src/rag_pipeline/models/config.py | 55 +- .../rag_pipeline/services/query_service.py | 15 +- 63 files changed, 5429 insertions(+), 2203 deletions(-) create mode 100644 java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/FileContentDto.java create mode 100644 java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/FileRelationshipDto.java create mode 100644 java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/ParsedFileMetadataDto.java create mode 100644 java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/PrEnrichmentDataDto.java create mode 100644 java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/service/PrFileEnrichmentService.java delete mode 100644 java-ecosystem/mcp-servers/platform-mcp/src/main/java/org/rostilos/codecrow/platformmcp/db/PlatformDbService.java create mode 100644 python-ecosystem/mcp-client/api/__init__.py create mode 100644 python-ecosystem/mcp-client/api/app.py create mode 100644 python-ecosystem/mcp-client/api/routers/__init__.py create mode 100644 python-ecosystem/mcp-client/api/routers/commands.py create mode 100644 python-ecosystem/mcp-client/api/routers/health.py create mode 100644 python-ecosystem/mcp-client/api/routers/review.py create mode 100644 python-ecosystem/mcp-client/model/__init__.py rename python-ecosystem/mcp-client/model/{models.py => dtos.py} (52%) create mode 100644 python-ecosystem/mcp-client/model/enrichment.py create mode 100644 python-ecosystem/mcp-client/model/enums.py create mode 100644 python-ecosystem/mcp-client/model/multi_stage.py create mode 100644 python-ecosystem/mcp-client/model/output_schemas.py delete mode 100644 python-ecosystem/mcp-client/server/web_server.py create mode 100644 python-ecosystem/mcp-client/service/__init__.py create mode 100644 python-ecosystem/mcp-client/service/command/__init__.py rename python-ecosystem/mcp-client/service/{ => command}/command_service.py (99%) delete mode 100644 python-ecosystem/mcp-client/service/multi_stage_orchestrator.py delete mode 100644 python-ecosystem/mcp-client/service/pooled_review_service.py create mode 100644 python-ecosystem/mcp-client/service/rag/__init__.py rename python-ecosystem/mcp-client/service/{ => rag}/llm_reranker.py (100%) rename python-ecosystem/mcp-client/service/{ => rag}/rag_client.py (100%) create mode 100644 python-ecosystem/mcp-client/service/review/__init__.py rename python-ecosystem/mcp-client/service/{issue_post_processor.py => review/issue_processor.py} (100%) create mode 100644 python-ecosystem/mcp-client/service/review/orchestrator/__init__.py create mode 100644 python-ecosystem/mcp-client/service/review/orchestrator/agents.py create mode 100644 python-ecosystem/mcp-client/service/review/orchestrator/context_helpers.py create mode 100644 python-ecosystem/mcp-client/service/review/orchestrator/json_utils.py create mode 
100644 python-ecosystem/mcp-client/service/review/orchestrator/orchestrator.py create mode 100644 python-ecosystem/mcp-client/service/review/orchestrator/reconciliation.py create mode 100644 python-ecosystem/mcp-client/service/review/orchestrator/stages.py rename python-ecosystem/mcp-client/service/{ => review}/review_service.py (98%) create mode 100644 python-ecosystem/mcp-client/tests/test_dependency_graph.py create mode 100644 python-ecosystem/mcp-client/utils/dependency_graph.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/embedding_factory.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/ollama_embedding.py diff --git a/deployment/config/rag-pipeline/.env.sample b/deployment/config/rag-pipeline/.env.sample index e19d8008..a0bb870c 100644 --- a/deployment/config/rag-pipeline/.env.sample +++ b/deployment/config/rag-pipeline/.env.sample @@ -2,6 +2,23 @@ QDRANT_URL=http://qdrant:6333 QDRANT_COLLECTION_PREFIX=codecrow +# ollama/openrouter +EMBEDDING_PROVIDER=openrouter +OLLAMA_BASE_URL=http://localhost:11434 +OLLAMA_EMBEDDING_MODEL=qwen3-embedding:0.6b + +# Ollama Performance Tuning +# Batch size for embedding requests (higher = better throughput, more memory) +OLLAMA_BATCH_SIZE=100 +# Request timeout in seconds (increase for slow CPU) +OLLAMA_TIMEOUT=120 + +# CPU Threading Optimization (set based on your CPU cores) +# Recommended: physical_cores - 1 (leave 1 core for system) +OMP_NUM_THREADS=6 +MKL_NUM_THREADS=6 +OPENBLAS_NUM_THREADS=6 + # OpenRouter Configuration # Get your API key from https://openrouter.ai/ OPENROUTER_API_KEY=sk-or-v1-your-api-key-here diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/module-info.java b/java-ecosystem/libs/analysis-engine/src/main/java/module-info.java index d03cdf59..c61e453f 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/module-info.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/module-info.java @@ -23,6 +23,7 @@ exports org.rostilos.codecrow.analysisengine.aiclient; exports org.rostilos.codecrow.analysisengine.config; exports org.rostilos.codecrow.analysisengine.dto.request.ai; + exports org.rostilos.codecrow.analysisengine.dto.request.ai.enrichment; exports org.rostilos.codecrow.analysisengine.dto.request.processor; exports org.rostilos.codecrow.analysisengine.dto.request.validation; exports org.rostilos.codecrow.analysisengine.exception; diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiAnalysisRequestImpl.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiAnalysisRequestImpl.java index 745b5e42..54f05b5a 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiAnalysisRequestImpl.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/AiAnalysisRequestImpl.java @@ -1,6 +1,7 @@ package org.rostilos.codecrow.analysisengine.dto.request.ai; import com.fasterxml.jackson.annotation.JsonProperty; +import org.rostilos.codecrow.analysisengine.dto.request.ai.enrichment.PrEnrichmentDataDto; import org.rostilos.codecrow.core.model.ai.AIConnection; import org.rostilos.codecrow.core.model.ai.AIProviderKey; import org.rostilos.codecrow.core.model.codeanalysis.AnalysisMode; @@ -44,6 +45,9 @@ public class AiAnalysisRequestImpl implements AiAnalysisRequest{ protected final String deltaDiff; protected final String previousCommitHash; 
protected final String currentCommitHash; + + // File enrichment data (full file contents + dependency graph) + protected final PrEnrichmentDataDto enrichmentData; protected AiAnalysisRequestImpl(Builder builder) { this.projectId = builder.projectId; @@ -74,6 +78,8 @@ protected AiAnalysisRequestImpl(Builder builder) { this.deltaDiff = builder.deltaDiff; this.previousCommitHash = builder.previousCommitHash; this.currentCommitHash = builder.currentCommitHash; + // File enrichment data + this.enrichmentData = builder.enrichmentData; } public Long getProjectId() { @@ -181,6 +187,10 @@ public String getCurrentCommitHash() { return currentCommitHash; } + public PrEnrichmentDataDto getEnrichmentData() { + return enrichmentData; + } + public static Builder builder() { return new Builder<>(); @@ -216,6 +226,8 @@ public static class Builder> { private String deltaDiff; private String previousCommitHash; private String currentCommitHash; + // File enrichment data + private PrEnrichmentDataDto enrichmentData; protected Builder() { } @@ -461,6 +473,11 @@ public T withCurrentCommitHash(String currentCommitHash) { return self(); } + public T withEnrichmentData(PrEnrichmentDataDto enrichmentData) { + this.enrichmentData = enrichmentData; + return self(); + } + public AiAnalysisRequestImpl build() { return new AiAnalysisRequestImpl(this); } diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/FileContentDto.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/FileContentDto.java new file mode 100644 index 00000000..85566f95 --- /dev/null +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/FileContentDto.java @@ -0,0 +1,46 @@ +package org.rostilos.codecrow.analysisengine.dto.request.ai.enrichment; + +/** + * DTO representing the content of a single file retrieved from VCS. + * Used for file enrichment during PR analysis to provide full file context. + */ +public record FileContentDto( + String path, + String content, + long sizeBytes, + boolean skipped, + String skipReason +) { + /** + * Create a successful file content result. + */ + public static FileContentDto of(String path, String content) { + return new FileContentDto( + path, + content, + content != null ? content.getBytes().length : 0, + false, + null + ); + } + + /** + * Create a skipped file result (e.g., file too large, binary, or fetch failed). + */ + public static FileContentDto skipped(String path, String reason) { + return new FileContentDto(path, null, 0, true, reason); + } + + /** + * Create a skipped file result due to size limit. 
+ */ + public static FileContentDto skippedDueToSize(String path, long actualSize, long maxSize) { + return new FileContentDto( + path, + null, + actualSize, + true, + String.format("File size %d bytes exceeds limit %d bytes", actualSize, maxSize) + ); + } +} diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/FileRelationshipDto.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/FileRelationshipDto.java new file mode 100644 index 00000000..21adaa2d --- /dev/null +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/FileRelationshipDto.java @@ -0,0 +1,90 @@ +package org.rostilos.codecrow.analysisengine.dto.request.ai.enrichment; + +/** + * DTO representing a relationship between two files in the PR. + * Used for building the dependency graph for intelligent batching. + */ +public record FileRelationshipDto( + String sourceFile, + String targetFile, + RelationshipType relationshipType, + String matchedOn, + int strength +) { + /** + * Types of relationships between files. + */ + public enum RelationshipType { + IMPORTS, + EXTENDS, + IMPLEMENTS, + CALLS, + SAME_PACKAGE, + REFERENCES + } + + /** + * Create an import relationship. + */ + public static FileRelationshipDto imports(String sourceFile, String targetFile, String importStatement) { + return new FileRelationshipDto( + sourceFile, + targetFile, + RelationshipType.IMPORTS, + importStatement, + 10 // High strength for direct imports + ); + } + + /** + * Create an extends relationship. + */ + public static FileRelationshipDto extendsClass(String sourceFile, String targetFile, String className) { + return new FileRelationshipDto( + sourceFile, + targetFile, + RelationshipType.EXTENDS, + className, + 15 // Highest strength for inheritance + ); + } + + /** + * Create an implements relationship. + */ + public static FileRelationshipDto implementsInterface(String sourceFile, String targetFile, String interfaceName) { + return new FileRelationshipDto( + sourceFile, + targetFile, + RelationshipType.IMPLEMENTS, + interfaceName, + 15 // Highest strength for interface implementation + ); + } + + /** + * Create a calls relationship. + */ + public static FileRelationshipDto calls(String sourceFile, String targetFile, String methodName) { + return new FileRelationshipDto( + sourceFile, + targetFile, + RelationshipType.CALLS, + methodName, + 8 // Medium-high strength for method calls + ); + } + + /** + * Create a same-package relationship. 
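+     *
+     * Illustrative example (hypothetical file names, not part of the original code): for a PR
+     * where OrderService.java extends a class in BaseService.java, imports OrderRepository.java,
+     * and merely sits in the same directory as OrderDto.java, the factories above yield edges of
+     * strength 15 (EXTENDS), 10 (IMPORTS) and 3 (SAME_PACKAGE) respectively, so downstream
+     * batching can prioritize the strongly coupled files.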
+ */ + public static FileRelationshipDto samePackage(String sourceFile, String targetFile, String packageName) { + return new FileRelationshipDto( + sourceFile, + targetFile, + RelationshipType.SAME_PACKAGE, + packageName, + 3 // Low strength for implicit package relationship + ); + } +} diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/ParsedFileMetadataDto.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/ParsedFileMetadataDto.java new file mode 100644 index 00000000..c6017ab5 --- /dev/null +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/ParsedFileMetadataDto.java @@ -0,0 +1,70 @@ +package org.rostilos.codecrow.analysisengine.dto.request.ai.enrichment; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +/** + * DTO representing parsed AST metadata for a single file. + * Mirrors the response from RAG pipeline's /parse endpoint. + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public record ParsedFileMetadataDto( + @JsonProperty("path") String path, + @JsonProperty("language") String language, + @JsonProperty("imports") List imports, + @JsonProperty("extends") List extendsClasses, + @JsonProperty("implements") List implementsInterfaces, + @JsonProperty("semantic_names") List semanticNames, + @JsonProperty("parent_class") String parentClass, + @JsonProperty("namespace") String namespace, + @JsonProperty("calls") List calls, + @JsonProperty("error") String error +) { + /** + * Create a metadata result with only imports and extends (minimal parsing). + */ + public static ParsedFileMetadataDto minimal(String path, List imports, List extendsClasses) { + return new ParsedFileMetadataDto( + path, + null, + imports, + extendsClasses, + List.of(), + List.of(), + null, + null, + List.of(), + null + ); + } + + /** + * Create an error result for a file that couldn't be parsed. + */ + public static ParsedFileMetadataDto error(String path, String errorMessage) { + return new ParsedFileMetadataDto( + path, + null, + List.of(), + List.of(), + List.of(), + List.of(), + null, + null, + List.of(), + errorMessage + ); + } + + /** + * Check if this metadata has any relationships to extract. + */ + public boolean hasRelationships() { + return (imports != null && !imports.isEmpty()) || + (extendsClasses != null && !extendsClasses.isEmpty()) || + (implementsInterfaces != null && !implementsInterfaces.isEmpty()) || + (calls != null && !calls.isEmpty()); + } +} diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/PrEnrichmentDataDto.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/PrEnrichmentDataDto.java new file mode 100644 index 00000000..2b1dec3b --- /dev/null +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/ai/enrichment/PrEnrichmentDataDto.java @@ -0,0 +1,63 @@ +package org.rostilos.codecrow.analysisengine.dto.request.ai.enrichment; + +import java.util.List; +import java.util.Map; + +/** + * Aggregate DTO containing all file enrichment data for a PR. + * This is the result of the enrichment process and can be serialized for transfer. 
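+ *
+ * Rough serialized shape, assuming default Jackson record handling (illustrative sketch only;
+ * names and values below are made up, and nested metadata keys follow the @JsonProperty names
+ * declared on ParsedFileMetadataDto):
+ *   {
+ *     "fileContents":  [{"path": "A.java", "content": "...", "sizeBytes": 1204, "skipped": false, "skipReason": null}],
+ *     "fileMetadata":  [{"path": "A.java", "language": "java", "imports": ["com.x.B"], "error": null, ...}],
+ *     "relationships": [{"sourceFile": "A.java", "targetFile": "B.java", "relationshipType": "IMPORTS",
+ *                        "matchedOn": "com.x.B", "strength": 10}],
+ *     "stats": {"totalFilesRequested": 2, "filesEnriched": 2, "filesSkipped": 0, "relationshipsFound": 1,
+ *               "totalContentSizeBytes": 2408, "processingTimeMs": 130, "skipReasons": {}}
+ *   }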
+ */ +public record PrEnrichmentDataDto( + List fileContents, + List fileMetadata, + List relationships, + EnrichmentStats stats +) { + /** + * Statistics about the enrichment process. + */ + public record EnrichmentStats( + int totalFilesRequested, + int filesEnriched, + int filesSkipped, + int relationshipsFound, + long totalContentSizeBytes, + long processingTimeMs, + Map skipReasons + ) { + public static EnrichmentStats empty() { + return new EnrichmentStats(0, 0, 0, 0, 0, 0, Map.of()); + } + } + + /** + * Create empty enrichment data (when enrichment is disabled or not applicable). + */ + public static PrEnrichmentDataDto empty() { + return new PrEnrichmentDataDto( + List.of(), + List.of(), + List.of(), + EnrichmentStats.empty() + ); + } + + /** + * Check if enrichment data is present. + */ + public boolean hasData() { + return (fileContents != null && !fileContents.isEmpty()) || + (relationships != null && !relationships.isEmpty()); + } + + /** + * Get total size of all file contents in bytes. + */ + public long getTotalContentSize() { + if (fileContents == null) return 0; + return fileContents.stream() + .filter(f -> !f.skipped()) + .mapToLong(FileContentDto::sizeBytes) + .sum(); + } +} diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/service/PrFileEnrichmentService.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/service/PrFileEnrichmentService.java new file mode 100644 index 00000000..6df5ea55 --- /dev/null +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/service/PrFileEnrichmentService.java @@ -0,0 +1,539 @@ +package org.rostilos.codecrow.analysisengine.service; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import okhttp3.*; +import org.rostilos.codecrow.analysisengine.dto.request.ai.enrichment.*; +import org.rostilos.codecrow.vcsclient.VcsClient; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; + +import java.io.IOException; +import java.util.*; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +/** + * Service responsible for enriching PR analysis requests with full file contents + * and pre-computed dependency relationships. + * + * This service: + * 1. Fetches full file contents for changed files (with size limits) + * 2. Calls RAG pipeline's /parse endpoint to extract AST metadata + * 3. Builds a relationship graph from the parsed metadata + * 4. 
Returns enriched data for intelligent batching in Python + */ +@Service +public class PrFileEnrichmentService { + private static final Logger log = LoggerFactory.getLogger(PrFileEnrichmentService.class); + + private static final MediaType JSON_MEDIA_TYPE = MediaType.get("application/json; charset=utf-8"); + + @Value("${pr.enrichment.enabled:true}") + private boolean enrichmentEnabled; + + @Value("${pr.enrichment.max-file-size-bytes:102400}") // 100KB default + private long maxFileSizeBytes; + + @Value("${pr.enrichment.max-total-size-bytes:10485760}") // 10MB default + private long maxTotalSizeBytes; + + @Value("${pr.enrichment.rag-pipeline-url:http://localhost:8006}") + private String ragPipelineUrl; + + @Value("${pr.enrichment.request-timeout-seconds:60}") + private int requestTimeoutSeconds; + + private final ObjectMapper objectMapper; + private final OkHttpClient httpClient; + + public PrFileEnrichmentService() { + this.objectMapper = new ObjectMapper(); + this.httpClient = new OkHttpClient.Builder() + .connectTimeout(30, TimeUnit.SECONDS) + .readTimeout(60, TimeUnit.SECONDS) + .writeTimeout(30, TimeUnit.SECONDS) + .build(); + } + + /** + * Check if PR enrichment is enabled. + */ + public boolean isEnrichmentEnabled() { + return enrichmentEnabled; + } + + /** + * Enrich PR analysis with full file contents and dependency graph. + * + * @param vcsClient VCS client for fetching file contents + * @param workspace VCS workspace/owner + * @param repoSlug Repository slug + * @param branch Branch name or commit SHA for file content + * @param changedFiles List of changed file paths + * @return Enrichment data with file contents and relationships + */ + public PrEnrichmentDataDto enrichPrFiles( + VcsClient vcsClient, + String workspace, + String repoSlug, + String branch, + List changedFiles + ) { + if (!enrichmentEnabled) { + log.debug("PR enrichment is disabled"); + return PrEnrichmentDataDto.empty(); + } + + if (changedFiles == null || changedFiles.isEmpty()) { + log.debug("No changed files to enrich"); + return PrEnrichmentDataDto.empty(); + } + + long startTime = System.currentTimeMillis(); + Map skipReasons = new HashMap<>(); + + try { + // Step 1: Filter to supported file types + List supportedFiles = filterSupportedFiles(changedFiles, skipReasons); + if (supportedFiles.isEmpty()) { + log.info("No supported files to enrich after filtering"); + return createEmptyResultWithStats(changedFiles.size(), 0, skipReasons, startTime); + } + + // Step 2: Fetch file contents in batch + log.info("Fetching {} file contents for enrichment (branch: {})", supportedFiles.size(), branch); + Map fileContents = vcsClient.getFileContents( + workspace, repoSlug, supportedFiles, branch, (int) Math.min(maxFileSizeBytes, Integer.MAX_VALUE) + ); + + // Step 3: Build FileContentDto list + List contentDtos = buildFileContentDtos( + supportedFiles, fileContents, skipReasons + ); + + // Check total size limit + long totalSize = contentDtos.stream() + .filter(f -> !f.skipped()) + .mapToLong(FileContentDto::sizeBytes) + .sum(); + + if (totalSize > maxTotalSizeBytes) { + log.warn("Total file content size {} exceeds limit {} - truncating", + totalSize, maxTotalSizeBytes); + contentDtos = truncateToSizeLimit(contentDtos, maxTotalSizeBytes, skipReasons); + } + + // Step 4: Parse files to extract AST metadata + List metadata = parseFilesForMetadata(contentDtos); + + // Step 5: Build relationship graph from metadata + List relationships = buildRelationshipGraph( + metadata, changedFiles + ); + + long processingTime = 
System.currentTimeMillis() - startTime; + + // Build stats + int filesEnriched = (int) contentDtos.stream().filter(f -> !f.skipped()).count(); + int filesSkipped = changedFiles.size() - filesEnriched; + + PrEnrichmentDataDto.EnrichmentStats stats = new PrEnrichmentDataDto.EnrichmentStats( + changedFiles.size(), + filesEnriched, + filesSkipped, + relationships.size(), + totalSize, + processingTime, + skipReasons + ); + + log.info("PR enrichment completed: {} files enriched, {} skipped, {} relationships in {}ms", + filesEnriched, filesSkipped, relationships.size(), processingTime); + + return new PrEnrichmentDataDto(contentDtos, metadata, relationships, stats); + + } catch (Exception e) { + log.error("Failed to enrich PR files: {}", e.getMessage(), e); + return createEmptyResultWithStats( + changedFiles.size(), 0, + Map.of("error", changedFiles.size()), + startTime + ); + } + } + + /** + * Filter files to only those with supported extensions for parsing. + */ + private List filterSupportedFiles(List files, Map skipReasons) { + Set supportedExtensions = Set.of( + ".java", ".py", ".js", ".ts", ".jsx", ".tsx", + ".go", ".rs", ".rb", ".php", ".cs", ".cpp", ".c", ".h", + ".kt", ".scala", ".swift", ".m", ".mm" + ); + + List supported = new ArrayList<>(); + for (String file : files) { + String lower = file.toLowerCase(); + boolean isSupported = supportedExtensions.stream() + .anyMatch(lower::endsWith); + + if (isSupported) { + supported.add(file); + } else { + skipReasons.merge("unsupported_extension", 1, Integer::sum); + } + } + + return supported; + } + + /** + * Build FileContentDto list from fetched contents. + */ + private List buildFileContentDtos( + List requestedFiles, + Map fileContents, + Map skipReasons + ) { + List result = new ArrayList<>(); + + for (String path : requestedFiles) { + String content = fileContents.get(path); + + if (content == null) { + result.add(FileContentDto.skipped(path, "fetch_failed")); + skipReasons.merge("fetch_failed", 1, Integer::sum); + } else if (content.isEmpty()) { + result.add(FileContentDto.skipped(path, "empty_file")); + skipReasons.merge("empty_file", 1, Integer::sum); + } else { + result.add(FileContentDto.of(path, content)); + } + } + + return result; + } + + /** + * Truncate file list to stay within total size limit. + * Prioritizes smaller files to maximize coverage. + */ + private List truncateToSizeLimit( + List contents, + long maxTotalSize, + Map skipReasons + ) { + // Sort by size (smallest first) to include more files + List sorted = contents.stream() + .filter(f -> !f.skipped()) + .sorted(Comparator.comparingLong(FileContentDto::sizeBytes)) + .collect(Collectors.toCollection(ArrayList::new)); + + List result = new ArrayList<>(); + long currentSize = 0; + + for (FileContentDto file : sorted) { + if (currentSize + file.sizeBytes() <= maxTotalSize) { + result.add(file); + currentSize += file.sizeBytes(); + } else { + result.add(FileContentDto.skipped(file.path(), "total_size_limit_exceeded")); + skipReasons.merge("total_size_limit", 1, Integer::sum); + } + } + + // Add already-skipped files + contents.stream() + .filter(FileContentDto::skipped) + .forEach(result::add); + + return result; + } + + /** + * Call RAG pipeline's /parse/batch endpoint to extract AST metadata. 
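+     *
+     * Wire format implied by the request-building code below and by the ParseBatchResponse /
+     * ParsedFileMetadataDto DTOs (illustrative sketch, not taken from the RAG pipeline itself;
+     * paths and counts are placeholders):
+     *   POST {ragPipelineUrl}/parse/batch
+     *   Request:  {"files": [{"path": "src/Foo.java", "content": "..."}, ...]}
+     *   Response: {"results": [{"path": "src/Foo.java", "language": "java", "imports": [...],
+     *               "extends": [...], "implements": [...], "calls": [...], "error": null}, ...],
+     *              "total_files": 2, "successful": 2, "failed": 0}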
+ */ + private List parseFilesForMetadata(List contents) { + List filesToParse = contents.stream() + .filter(f -> !f.skipped()) + .toList(); + + if (filesToParse.isEmpty()) { + return Collections.emptyList(); + } + + try { + // Build batch request + List> files = filesToParse.stream() + .map(f -> Map.of("path", f.path(), "content", f.content())) + .toList(); + + Map requestBody = Map.of("files", files); + String jsonBody = objectMapper.writeValueAsString(requestBody); + + Request request = new Request.Builder() + .url(ragPipelineUrl + "/parse/batch") + .post(RequestBody.create(jsonBody, JSON_MEDIA_TYPE)) + .build(); + + try (Response response = httpClient.newCall(request).execute()) { + if (!response.isSuccessful()) { + log.warn("RAG /parse/batch failed with status {}: {}", + response.code(), response.message()); + return createFallbackMetadata(filesToParse); + } + + ResponseBody responseBody = response.body(); + if (responseBody == null) { + return createFallbackMetadata(filesToParse); + } + + ParseBatchResponse batchResponse = objectMapper.readValue( + responseBody.string(), + ParseBatchResponse.class + ); + + return batchResponse.results != null + ? batchResponse.results + : createFallbackMetadata(filesToParse); + } + + } catch (IOException e) { + log.warn("Failed to call RAG /parse/batch: {}", e.getMessage()); + return createFallbackMetadata(filesToParse); + } + } + + /** + * Create minimal metadata when RAG parsing fails. + */ + private List createFallbackMetadata(List files) { + return files.stream() + .map(f -> ParsedFileMetadataDto.error(f.path(), "parse_failed")) + .toList(); + } + + /** + * Build relationship graph from parsed metadata. + * Matches imports/extends/calls against changed files. + */ + private List buildRelationshipGraph( + List metadata, + List changedFiles + ) { + List relationships = new ArrayList<>(); + + // Build a map of possible file matches for quick lookup + Map nameToPath = buildNameToPathMap(changedFiles); + + for (ParsedFileMetadataDto file : metadata) { + if (file.error() != null) continue; + + // Process imports + if (file.imports() != null) { + for (String importStmt : file.imports()) { + String targetPath = findMatchingFile(importStmt, nameToPath, changedFiles); + if (targetPath != null && !targetPath.equals(file.path())) { + relationships.add(FileRelationshipDto.imports( + file.path(), targetPath, importStmt)); + } + } + } + + // Process extends + if (file.extendsClasses() != null) { + for (String className : file.extendsClasses()) { + String targetPath = findMatchingFile(className, nameToPath, changedFiles); + if (targetPath != null && !targetPath.equals(file.path())) { + relationships.add(FileRelationshipDto.extendsClass( + file.path(), targetPath, className)); + } + } + } + + // Process implements + if (file.implementsInterfaces() != null) { + for (String interfaceName : file.implementsInterfaces()) { + String targetPath = findMatchingFile(interfaceName, nameToPath, changedFiles); + if (targetPath != null && !targetPath.equals(file.path())) { + relationships.add(FileRelationshipDto.implementsInterface( + file.path(), targetPath, interfaceName)); + } + } + } + + // Process calls + if (file.calls() != null) { + for (String call : file.calls()) { + String targetPath = findMatchingFile(call, nameToPath, changedFiles); + if (targetPath != null && !targetPath.equals(file.path())) { + relationships.add(FileRelationshipDto.calls( + file.path(), targetPath, call)); + } + } + } + } + + // Add same-package relationships + 
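+        // For N changed files in one directory this adds N*(N-1)/2 weak edges
+        // (e.g. 3 files in the same package -> 3 pairwise SAME_PACKAGE edges of strength 3)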
addSamePackageRelationships(changedFiles, relationships); + + // Deduplicate relationships + return relationships.stream() + .distinct() + .toList(); + } + + /** + * Build a map of class/module names to file paths for matching. + */ + private Map buildNameToPathMap(List filePaths) { + Map map = new HashMap<>(); + + for (String path : filePaths) { + // Extract filename without extension + String fileName = path.contains("/") + ? path.substring(path.lastIndexOf('/') + 1) + : path; + String nameWithoutExt = fileName.contains(".") + ? fileName.substring(0, fileName.lastIndexOf('.')) + : fileName; + + map.put(nameWithoutExt, path); + map.put(nameWithoutExt.toLowerCase(), path); + + // For Java-style paths, also map the package structure + // e.g., com/example/MyClass.java -> MyClass + if (path.endsWith(".java")) { + String className = nameWithoutExt; + map.put(className, path); + + // Also map full package path without extension + String fullPath = path.replace('/', '.').replace(".java", ""); + map.put(fullPath, path); + } + + // For Python, map module paths + if (path.endsWith(".py")) { + String modulePath = path.replace('/', '.').replace(".py", ""); + map.put(modulePath, path); + } + } + + return map; + } + + /** + * Find a matching file path for an import/extends statement. + */ + private String findMatchingFile( + String reference, + Map nameToPath, + List changedFiles + ) { + if (reference == null || reference.isEmpty()) return null; + + // Try direct match + if (nameToPath.containsKey(reference)) { + return nameToPath.get(reference); + } + + // Extract last component (class name from qualified name) + String simpleName = reference.contains(".") + ? reference.substring(reference.lastIndexOf('.') + 1) + : reference; + + if (nameToPath.containsKey(simpleName)) { + return nameToPath.get(simpleName); + } + + // Try case-insensitive + String lowerName = simpleName.toLowerCase(); + if (nameToPath.containsKey(lowerName)) { + return nameToPath.get(lowerName); + } + + // Try partial path matching + String normalizedRef = reference.replace('.', '/'); + for (String path : changedFiles) { + if (path.contains(normalizedRef) || path.endsWith(simpleName + ".java") + || path.endsWith(simpleName + ".py") + || path.endsWith(simpleName + ".ts") + || path.endsWith(simpleName + ".js")) { + return path; + } + } + + return null; + } + + /** + * Add implicit relationships for files in the same package/directory. + */ + private void addSamePackageRelationships( + List changedFiles, + List relationships + ) { + // Group files by directory + Map> filesByDir = changedFiles.stream() + .collect(Collectors.groupingBy(path -> { + int lastSlash = path.lastIndexOf('/'); + return lastSlash > 0 ? 
path.substring(0, lastSlash) : ""; + })); + + // Add relationships within each directory + for (Map.Entry> entry : filesByDir.entrySet()) { + List filesInDir = entry.getValue(); + String packageName = entry.getKey(); + + if (filesInDir.size() > 1) { + for (int i = 0; i < filesInDir.size(); i++) { + for (int j = i + 1; j < filesInDir.size(); j++) { + relationships.add(FileRelationshipDto.samePackage( + filesInDir.get(i), filesInDir.get(j), packageName)); + } + } + } + } + } + + private PrEnrichmentDataDto createEmptyResultWithStats( + int totalFiles, + int enriched, + Map skipReasons, + long startTime + ) { + long processingTime = System.currentTimeMillis() - startTime; + return new PrEnrichmentDataDto( + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + new PrEnrichmentDataDto.EnrichmentStats( + totalFiles, enriched, totalFiles - enriched, + 0, 0, processingTime, skipReasons + ) + ); + } + + /** + * Response DTO for /parse/batch endpoint. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + private static class ParseBatchResponse { + @JsonProperty("results") + public List results; + + @JsonProperty("total_files") + public int totalFiles; + + @JsonProperty("successful") + public int successful; + + @JsonProperty("failed") + public int failed; + } +} diff --git a/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/VcsClient.java b/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/VcsClient.java index d6946c7d..ff7204d1 100644 --- a/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/VcsClient.java +++ b/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/VcsClient.java @@ -218,4 +218,39 @@ default List getRepositoryCollaborators(String workspaceId, Str default String getBranchDiff(String workspaceId, String repoIdOrSlug, String baseBranch, String compareBranch) throws IOException { throw new UnsupportedOperationException("Branch diff not supported by this provider"); } + + /** + * Batch fetch file contents from repository. 
+ * Implementations should use provider-optimal strategies: + * - GitHub: GraphQL batch query + * - Bitbucket/GitLab: Parallel fetch with exponential backoff + * + * @param workspaceId the external workspace/org ID + * @param repoIdOrSlug the repository ID or slug + * @param filePaths list of file paths to fetch + * @param branchOrCommit branch name or commit hash + * @param maxFileSizeBytes maximum file size to fetch (skip larger files) + * @return map of path -> content (missing/skipped files not in map) + */ + default java.util.Map getFileContents( + String workspaceId, + String repoIdOrSlug, + List filePaths, + String branchOrCommit, + int maxFileSizeBytes + ) throws IOException { + // Default implementation: sequential fetch with size check + java.util.Map results = new java.util.HashMap<>(); + for (String path : filePaths) { + try { + String content = getFileContent(workspaceId, repoIdOrSlug, path, branchOrCommit); + if (content != null && content.length() <= maxFileSizeBytes) { + results.put(path, content); + } + } catch (IOException e) { + // Skip files that fail to fetch + } + } + return results; + } } diff --git a/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/bitbucket/cloud/BitbucketCloudClient.java b/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/bitbucket/cloud/BitbucketCloudClient.java index 9c3e4b46..41989b9f 100644 --- a/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/bitbucket/cloud/BitbucketCloudClient.java +++ b/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/bitbucket/cloud/BitbucketCloudClient.java @@ -5,6 +5,8 @@ import okhttp3.*; import org.rostilos.codecrow.vcsclient.VcsClient; import org.rostilos.codecrow.vcsclient.model.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.InputStream; @@ -20,6 +22,7 @@ */ public class BitbucketCloudClient implements VcsClient { + private static final Logger log = LoggerFactory.getLogger(BitbucketCloudClient.class); private static final String API_BASE = "https://api.bitbucket.org/2.0"; private static final int DEFAULT_PAGE_SIZE = 50; private static final MediaType JSON_MEDIA_TYPE = MediaType.parse("application/json"); @@ -791,4 +794,66 @@ private VcsCollaborator parseCollaboratorPermission(JsonNode permNode) { return new VcsCollaborator(userId, username, displayName, avatarUrl, permission, htmlUrl); } + + /** + * Batch fetch file contents with parallel execution and exponential backoff. + * Bitbucket Cloud doesn't have a batch API, so we fetch in parallel with rate limit handling. 
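+     *
+     * Retry/backoff behaviour of the implementation below (summary): at most 10 files are fetched
+     * in parallel, each file is attempted up to 3 times, sleeping 1s and then 2s between attempts
+     * when the API answers with HTTP 429; files that still fail are simply left out of the
+     * returned map.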
+ */ + @Override + public java.util.Map getFileContents( + String workspaceId, + String repoIdOrSlug, + java.util.List filePaths, + String branchOrCommit, + int maxFileSizeBytes + ) throws IOException { + java.util.Map results = new java.util.concurrent.ConcurrentHashMap<>(); + + // Use parallel stream with controlled concurrency + int parallelism = Math.min(10, filePaths.size()); // Max 10 concurrent requests + java.util.concurrent.ForkJoinPool customPool = new java.util.concurrent.ForkJoinPool(parallelism); + + try { + customPool.submit(() -> + filePaths.parallelStream().forEach(path -> { + int maxRetries = 3; + int retryCount = 0; + long backoffMs = 1000; // Start with 1 second + + while (retryCount < maxRetries) { + try { + String content = getFileContent(workspaceId, repoIdOrSlug, path, branchOrCommit); + if (content != null && content.getBytes(java.nio.charset.StandardCharsets.UTF_8).length <= maxFileSizeBytes) { + results.put(path, content); + } + break; // Success, exit retry loop + } catch (IOException e) { + retryCount++; + if (e.getMessage() != null && e.getMessage().contains("429")) { + // Rate limited - exponential backoff + try { + Thread.sleep(backoffMs); + backoffMs *= 2; // Double the backoff + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + break; + } + } else if (retryCount >= maxRetries) { + // Log and skip this file + log.warn("Failed to fetch file {} after {} retries: {}", path, maxRetries, e.getMessage()); + } + } + } + }) + ).get(); + } catch (InterruptedException | java.util.concurrent.ExecutionException e) { + log.error("Error in parallel file fetch: {}", e.getMessage()); + throw new IOException("Batch file fetch failed", e); + } finally { + customPool.shutdown(); + } + + log.info("Batch fetched {}/{} files from Bitbucket", results.size(), filePaths.size()); + return results; + } } diff --git a/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/github/GitHubClient.java b/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/github/GitHubClient.java index d9a6de5e..e36ce4a4 100644 --- a/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/github/GitHubClient.java +++ b/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/github/GitHubClient.java @@ -933,6 +933,115 @@ private Request createGetRequest(String url) { .build(); } + /** + * Batch fetch file contents using GitHub GraphQL API for efficiency. + * GraphQL allows fetching multiple files in a single request. 
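+     *
+     * Example of the query generated below for two files on branch "main" (illustrative only,
+     * with placeholder owner, repo and paths):
+     *   query { repository(owner: "acme", name: "shop") {
+     *     file0: object(expression: "main:src/A.java") { ... on Blob { text byteSize } }
+     *     file1: object(expression: "main:src/B.java") { ... on Blob { text byteSize } }
+     *   }}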
+ */ + @Override + public java.util.Map getFileContents( + String workspaceId, + String repoIdOrSlug, + java.util.List filePaths, + String branchOrCommit, + int maxFileSizeBytes + ) throws IOException { + java.util.Map results = new java.util.HashMap<>(); + + // GitHub GraphQL API endpoint + String graphqlUrl = "https://api.github.com/graphql"; + + // Process in batches of 50 files (GraphQL complexity limits) + int batchSize = 50; + for (int i = 0; i < filePaths.size(); i += batchSize) { + java.util.List batch = filePaths.subList(i, Math.min(i + batchSize, filePaths.size())); + + // Build GraphQL query for this batch + StringBuilder queryBuilder = new StringBuilder(); + queryBuilder.append("query { repository(owner: \"").append(workspaceId) + .append("\", name: \"").append(repoIdOrSlug).append("\") {"); + + for (int j = 0; j < batch.size(); j++) { + String path = batch.get(j); + String alias = "file" + j; + String expression = branchOrCommit + ":" + path; + queryBuilder.append(alias).append(": object(expression: \"") + .append(expression.replace("\"", "\\\"")) + .append("\") { ... on Blob { text byteSize } } "); + } + queryBuilder.append("}}"); + + String query = queryBuilder.toString(); + String requestBody = objectMapper.writeValueAsString(java.util.Map.of("query", query)); + + Request request = new Request.Builder() + .url(graphqlUrl) + .header(ACCEPT_HEADER, GITHUB_ACCEPT_HEADER) + .header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION) + .post(RequestBody.create(requestBody, JSON_MEDIA_TYPE)) + .build(); + + int maxRetries = 3; + int retryCount = 0; + long backoffMs = 1000; + + while (retryCount < maxRetries) { + try (Response response = httpClient.newCall(request).execute()) { + if (response.code() == 429) { + // Rate limited - exponential backoff + retryCount++; + Thread.sleep(backoffMs); + backoffMs *= 2; + continue; + } + + if (!response.isSuccessful()) { + log.warn("GraphQL batch fetch failed with status {}, falling back to REST API", response.code()); + // Fall back to sequential REST API + for (String path : batch) { + try { + String content = getFileContent(workspaceId, repoIdOrSlug, path, branchOrCommit); + if (content != null && content.getBytes(java.nio.charset.StandardCharsets.UTF_8).length <= maxFileSizeBytes) { + results.put(path, content); + } + } catch (IOException e) { + log.debug("Skipping file {}: {}", path, e.getMessage()); + } + } + break; + } + + // Parse GraphQL response + JsonNode root = objectMapper.readTree(response.body().string()); + JsonNode data = root.path("data").path("repository"); + + for (int j = 0; j < batch.size(); j++) { + String path = batch.get(j); + String alias = "file" + j; + JsonNode fileNode = data.path(alias); + + if (!fileNode.isMissingNode() && fileNode.has("text")) { + int byteSize = fileNode.path("byteSize").asInt(0); + if (byteSize <= maxFileSizeBytes) { + String text = fileNode.get("text").asText(); + if (text != null) { + results.put(path, text); + } + } + } + } + break; // Success + + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Interrupted during batch fetch", e); + } + } + } + + log.info("Batch fetched {}/{} files from GitHub via GraphQL", results.size(), filePaths.size()); + return results; + } + private record GitHubWebhookRequest( String name, diff --git a/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/gitlab/GitLabClient.java b/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/gitlab/GitLabClient.java index 
d2df53e7..0fb1ec93 100644 --- a/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/gitlab/GitLabClient.java +++ b/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/gitlab/GitLabClient.java @@ -849,4 +849,66 @@ private Request createGetRequest(String url) { .get() .build(); } + + /** + * Batch fetch file contents with parallel execution and exponential backoff. + * GitLab doesn't have a batch API, so we fetch in parallel with rate limit handling. + */ + @Override + public java.util.Map getFileContents( + String workspaceId, + String repoIdOrSlug, + java.util.List filePaths, + String branchOrCommit, + int maxFileSizeBytes + ) throws IOException { + java.util.Map results = new java.util.concurrent.ConcurrentHashMap<>(); + + // Use parallel stream with controlled concurrency + int parallelism = Math.min(10, filePaths.size()); // Max 10 concurrent requests + java.util.concurrent.ForkJoinPool customPool = new java.util.concurrent.ForkJoinPool(parallelism); + + try { + customPool.submit(() -> + filePaths.parallelStream().forEach(path -> { + int maxRetries = 3; + int retryCount = 0; + long backoffMs = 1000; // Start with 1 second + + while (retryCount < maxRetries) { + try { + String content = getFileContent(workspaceId, repoIdOrSlug, path, branchOrCommit); + if (content != null && content.getBytes(java.nio.charset.StandardCharsets.UTF_8).length <= maxFileSizeBytes) { + results.put(path, content); + } + break; // Success, exit retry loop + } catch (IOException e) { + retryCount++; + if (e.getMessage() != null && (e.getMessage().contains("429") || e.getMessage().contains("rate limit"))) { + // Rate limited - exponential backoff + try { + Thread.sleep(backoffMs); + backoffMs *= 2; // Double the backoff + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + break; + } + } else if (retryCount >= maxRetries) { + // Log and skip this file + log.warn("Failed to fetch file {} after {} retries: {}", path, maxRetries, e.getMessage()); + } + } + } + }) + ).get(); + } catch (InterruptedException | java.util.concurrent.ExecutionException e) { + log.error("Error in parallel file fetch: {}", e.getMessage()); + throw new IOException("Batch file fetch failed", e); + } finally { + customPool.shutdown(); + } + + log.info("Batch fetched {}/{} files from GitLab", results.size(), filePaths.size()); + return results; + } } diff --git a/java-ecosystem/mcp-servers/platform-mcp/src/main/java/org/rostilos/codecrow/platformmcp/db/PlatformDbService.java b/java-ecosystem/mcp-servers/platform-mcp/src/main/java/org/rostilos/codecrow/platformmcp/db/PlatformDbService.java deleted file mode 100644 index e69de29b..00000000 diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java index 88d33c55..7060f11e 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java @@ -1,6 +1,8 @@ package org.rostilos.codecrow.pipelineagent.bitbucket.service; import okhttp3.OkHttpClient; +import org.rostilos.codecrow.analysisengine.dto.request.ai.enrichment.PrEnrichmentDataDto; +import 
org.rostilos.codecrow.analysisengine.service.PrFileEnrichmentService; import org.rostilos.codecrow.core.model.ai.AIConnection; import org.rostilos.codecrow.core.model.codeanalysis.AnalysisMode; import org.rostilos.codecrow.core.model.codeanalysis.AnalysisType; @@ -19,6 +21,7 @@ import org.rostilos.codecrow.analysisengine.util.DiffParser; import org.rostilos.codecrow.analysisengine.util.TokenEstimator; import org.rostilos.codecrow.security.oauth.TokenEncryptionService; +import org.rostilos.codecrow.vcsclient.VcsClient; import org.rostilos.codecrow.vcsclient.VcsClientProvider; import org.rostilos.codecrow.vcsclient.bitbucket.cloud.actions.GetCommitRangeDiffAction; import org.rostilos.codecrow.vcsclient.bitbucket.cloud.actions.GetPullRequestAction; @@ -27,6 +30,7 @@ import org.rostilos.codecrow.vcsclient.utils.VcsConnectionCredentialsExtractor.VcsConnectionCredentials; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.io.IOException; @@ -54,14 +58,17 @@ public class BitbucketAiClientService implements VcsAiClientService { private final TokenEncryptionService tokenEncryptionService; private final VcsClientProvider vcsClientProvider; private final VcsConnectionCredentialsExtractor credentialsExtractor; + private final PrFileEnrichmentService enrichmentService; public BitbucketAiClientService( TokenEncryptionService tokenEncryptionService, - VcsClientProvider vcsClientProvider + VcsClientProvider vcsClientProvider, + @Autowired(required = false) PrFileEnrichmentService enrichmentService ) { this.tokenEncryptionService = tokenEncryptionService; this.vcsClientProvider = vcsClientProvider; this.credentialsExtractor = new VcsConnectionCredentialsExtractor(tokenEncryptionService); + this.enrichmentService = enrichmentService; } @Override @@ -246,7 +253,27 @@ public AiAnalysisRequest buildPrAnalysisRequest( // Continue without metadata - RAG will use fallback } - var builder = AiAnalysisRequestImpl.builder() + // Enrich PR with full file contents and dependency graph + PrEnrichmentDataDto enrichmentData = PrEnrichmentDataDto.empty(); + if (enrichmentService != null && enrichmentService.isEnrichmentEnabled() && !changedFiles.isEmpty()) { + try { + VcsClient vcsClient = vcsClientProvider.getClient(vcsConnection); + enrichmentData = enrichmentService.enrichPrFiles( + vcsClient, + vcsInfo.workspace(), + vcsInfo.repoSlug(), + request.getSourceBranchName(), + changedFiles + ); + log.info("PR enrichment completed: {} files enriched, {} relationships", + enrichmentData.stats().filesEnriched(), + enrichmentData.stats().relationshipsFound()); + } catch (Exception e) { + log.warn("Failed to enrich PR files (non-critical): {}", e.getMessage()); + } + } + + AiAnalysisRequestImpl.Builder builder = AiAnalysisRequestImpl.builder() .withProjectId(project.getId()) .withPullRequestId(request.getPullRequestId()) .withProjectAiConnection(aiConnection) @@ -268,7 +295,9 @@ public AiAnalysisRequest buildPrAnalysisRequest( .withAnalysisMode(analysisMode) .withDeltaDiff(deltaDiff) .withPreviousCommitHash(previousCommitHash) - .withCurrentCommitHash(currentCommitHash); + .withCurrentCommitHash(currentCommitHash) + // File enrichment data + .withEnrichmentData(enrichmentData); // Add VCS credentials based on connection type addVcsCredentials(builder, vcsConnection); diff --git 
a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java index 69193350..fcd28019 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java @@ -2,6 +2,8 @@ import com.fasterxml.jackson.databind.JsonNode; import okhttp3.OkHttpClient; +import org.rostilos.codecrow.analysisengine.dto.request.ai.enrichment.PrEnrichmentDataDto; +import org.rostilos.codecrow.analysisengine.service.PrFileEnrichmentService; import org.rostilos.codecrow.core.model.ai.AIConnection; import org.rostilos.codecrow.core.model.codeanalysis.AnalysisMode; import org.rostilos.codecrow.core.model.codeanalysis.CodeAnalysis; @@ -19,6 +21,7 @@ import org.rostilos.codecrow.analysisengine.util.DiffParser; import org.rostilos.codecrow.analysisengine.util.TokenEstimator; import org.rostilos.codecrow.security.oauth.TokenEncryptionService; +import org.rostilos.codecrow.vcsclient.VcsClient; import org.rostilos.codecrow.vcsclient.VcsClientProvider; import org.rostilos.codecrow.vcsclient.github.actions.GetCommitRangeDiffAction; import org.rostilos.codecrow.vcsclient.github.actions.GetPullRequestAction; @@ -27,6 +30,7 @@ import org.rostilos.codecrow.vcsclient.utils.VcsConnectionCredentialsExtractor.VcsConnectionCredentials; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.io.IOException; @@ -53,14 +57,17 @@ public class GitHubAiClientService implements VcsAiClientService { private final TokenEncryptionService tokenEncryptionService; private final VcsClientProvider vcsClientProvider; private final VcsConnectionCredentialsExtractor credentialsExtractor; + private final PrFileEnrichmentService enrichmentService; public GitHubAiClientService( TokenEncryptionService tokenEncryptionService, - VcsClientProvider vcsClientProvider + VcsClientProvider vcsClientProvider, + @Autowired(required = false) PrFileEnrichmentService enrichmentService ) { this.tokenEncryptionService = tokenEncryptionService; this.vcsClientProvider = vcsClientProvider; this.credentialsExtractor = new VcsConnectionCredentialsExtractor(tokenEncryptionService); + this.enrichmentService = enrichmentService; } @Override @@ -238,7 +245,27 @@ private AiAnalysisRequest buildPrAnalysisRequest( log.warn("Failed to fetch/parse PR metadata/diff for RAG context: {}", e.getMessage()); } - var builder = AiAnalysisRequestImpl.builder() + // Enrich PR with full file contents and dependency graph + PrEnrichmentDataDto enrichmentData = PrEnrichmentDataDto.empty(); + if (enrichmentService != null && enrichmentService.isEnrichmentEnabled() && !changedFiles.isEmpty()) { + try { + VcsClient vcsClient = vcsClientProvider.getClient(vcsConnection); + enrichmentData = enrichmentService.enrichPrFiles( + vcsClient, + vcsInfo.owner(), + vcsInfo.repoSlug(), + request.getSourceBranchName(), + changedFiles + ); + log.info("PR enrichment completed: {} files enriched, {} relationships", + enrichmentData.stats().filesEnriched(), + enrichmentData.stats().relationshipsFound()); + } catch (Exception e) { + log.warn("Failed to enrich PR files (non-critical): {}", 
e.getMessage()); + } + } + + AiAnalysisRequestImpl.Builder builder = AiAnalysisRequestImpl.builder() .withProjectId(project.getId()) .withPullRequestId(request.getPullRequestId()) .withProjectAiConnection(aiConnection) @@ -260,7 +287,9 @@ private AiAnalysisRequest buildPrAnalysisRequest( .withAnalysisMode(analysisMode) .withDeltaDiff(deltaDiff) .withPreviousCommitHash(previousCommitHash) - .withCurrentCommitHash(currentCommitHash); + .withCurrentCommitHash(currentCommitHash) + // File enrichment data + .withEnrichmentData(enrichmentData); addVcsCredentials(builder, vcsConnection); @@ -326,7 +355,7 @@ private AiAnalysisRequest buildBranchAnalysisRequest( return builder.build(); } - private void addVcsCredentials(AiAnalysisRequestImpl.Builder builder, VcsConnection connection) + private void addVcsCredentials(AiAnalysisRequestImpl.Builder builder, VcsConnection connection) throws GeneralSecurityException { VcsConnectionCredentials credentials = credentialsExtractor.extractCredentials(connection); if (VcsConnectionCredentialsExtractor.hasAccessToken(credentials)) { diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java index fd736acf..e6b48c10 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java @@ -2,6 +2,8 @@ import com.fasterxml.jackson.databind.JsonNode; import okhttp3.OkHttpClient; +import org.rostilos.codecrow.analysisengine.dto.request.ai.enrichment.PrEnrichmentDataDto; +import org.rostilos.codecrow.analysisengine.service.PrFileEnrichmentService; import org.rostilos.codecrow.core.model.ai.AIConnection; import org.rostilos.codecrow.core.model.codeanalysis.AnalysisMode; import org.rostilos.codecrow.core.model.codeanalysis.CodeAnalysis; @@ -19,6 +21,7 @@ import org.rostilos.codecrow.analysisengine.util.DiffParser; import org.rostilos.codecrow.analysisengine.util.TokenEstimator; import org.rostilos.codecrow.security.oauth.TokenEncryptionService; +import org.rostilos.codecrow.vcsclient.VcsClient; import org.rostilos.codecrow.vcsclient.VcsClientProvider; import org.rostilos.codecrow.vcsclient.gitlab.actions.GetCommitRangeDiffAction; import org.rostilos.codecrow.vcsclient.gitlab.actions.GetMergeRequestAction; @@ -27,6 +30,7 @@ import org.rostilos.codecrow.vcsclient.utils.VcsConnectionCredentialsExtractor.VcsConnectionCredentials; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.io.IOException; @@ -53,14 +57,17 @@ public class GitLabAiClientService implements VcsAiClientService { private final TokenEncryptionService tokenEncryptionService; private final VcsClientProvider vcsClientProvider; private final VcsConnectionCredentialsExtractor credentialsExtractor; + private final PrFileEnrichmentService enrichmentService; public GitLabAiClientService( TokenEncryptionService tokenEncryptionService, - VcsClientProvider vcsClientProvider + VcsClientProvider vcsClientProvider, + @Autowired(required = false) PrFileEnrichmentService enrichmentService ) { this.tokenEncryptionService = tokenEncryptionService; this.vcsClientProvider = 
vcsClientProvider; this.credentialsExtractor = new VcsConnectionCredentialsExtractor(tokenEncryptionService); + this.enrichmentService = enrichmentService; } @Override @@ -239,7 +246,27 @@ private AiAnalysisRequest buildMrAnalysisRequest( log.warn("Failed to fetch/parse MR metadata/diff for RAG context: {}", e.getMessage()); } - var builder = AiAnalysisRequestImpl.builder() + // Enrich PR with full file contents and dependency graph + PrEnrichmentDataDto enrichmentData = PrEnrichmentDataDto.empty(); + if (enrichmentService != null && enrichmentService.isEnrichmentEnabled() && !changedFiles.isEmpty()) { + try { + VcsClient vcsClient = vcsClientProvider.getClient(vcsConnection); + enrichmentData = enrichmentService.enrichPrFiles( + vcsClient, + vcsInfo.namespace(), + vcsInfo.repoSlug(), + request.getSourceBranchName(), + changedFiles + ); + log.info("PR enrichment completed: {} files enriched, {} relationships", + enrichmentData.stats().filesEnriched(), + enrichmentData.stats().relationshipsFound()); + } catch (Exception e) { + log.warn("Failed to enrich MR files (non-critical): {}", e.getMessage()); + } + } + + AiAnalysisRequestImpl.Builder builder = AiAnalysisRequestImpl.builder() .withProjectId(project.getId()) .withPullRequestId(request.getPullRequestId()) .withProjectAiConnection(aiConnection) @@ -261,7 +288,9 @@ private AiAnalysisRequest buildMrAnalysisRequest( .withAnalysisMode(analysisMode) .withDeltaDiff(deltaDiff) .withPreviousCommitHash(previousCommitHash) - .withCurrentCommitHash(currentCommitHash); + .withCurrentCommitHash(currentCommitHash) + // File enrichment data + .withEnrichmentData(enrichmentData); addVcsCredentials(builder, vcsConnection); diff --git a/python-ecosystem/mcp-client/api/__init__.py b/python-ecosystem/mcp-client/api/__init__.py new file mode 100644 index 00000000..613c59d6 --- /dev/null +++ b/python-ecosystem/mcp-client/api/__init__.py @@ -0,0 +1,8 @@ +""" +API Package. + +Contains the FastAPI application and routers. +""" +from api.app import create_app, run_http_server + +__all__ = ["create_app", "run_http_server"] diff --git a/python-ecosystem/mcp-client/api/app.py b/python-ecosystem/mcp-client/api/app.py new file mode 100644 index 00000000..ac1a201a --- /dev/null +++ b/python-ecosystem/mcp-client/api/app.py @@ -0,0 +1,34 @@ +""" +FastAPI Application Factory. + +Creates and configures the FastAPI application with all routers. +""" +import os +from fastapi import FastAPI + +from api.routers import health, review, commands + + +def create_app() -> FastAPI: + """Create and configure FastAPI application.""" + app = FastAPI(title="codecrow-mcp-client") + + # Register routers + app.include_router(health.router) + app.include_router(review.router) + app.include_router(commands.router) + + return app + + +def run_http_server(host: str = "0.0.0.0", port: int = 8000): + """Run the FastAPI application.""" + app = create_app() + import uvicorn + uvicorn.run(app, host=host, port=port, log_level="info", timeout_keep_alive=300) + + +if __name__ == "__main__": + host = os.environ.get("AI_CLIENT_HOST", "0.0.0.0") + port = int(os.environ.get("AI_CLIENT_PORT", "8000")) + run_http_server(host=host, port=port) diff --git a/python-ecosystem/mcp-client/api/routers/__init__.py b/python-ecosystem/mcp-client/api/routers/__init__.py new file mode 100644 index 00000000..06f12462 --- /dev/null +++ b/python-ecosystem/mcp-client/api/routers/__init__.py @@ -0,0 +1,5 @@ +""" +API Routers Package. + +Contains FastAPI routers for different endpoint groups. 
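+
+The review and command routers support optional NDJSON streaming: when a client
+sends `Accept: application/x-ndjson`, events of the form {"type": "status" | "final" | "error", ...}
+are emitted one JSON object per line. A minimal client sketch (illustrative only;
+the `requests` dependency, base URL, and `payload` variable are assumptions, not
+part of this change):
+
+    import json
+    import requests
+
+    with requests.post(
+        "http://localhost:8000/review",
+        json=payload,  # a ReviewRequestDto-shaped dict
+        headers={"Accept": "application/x-ndjson"},
+        stream=True,
+    ) as resp:
+        for line in resp.iter_lines():
+            if not line:
+                continue
+            event = json.loads(line)
+            if event.get("type") in ("final", "error"):
+                break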
+""" diff --git a/python-ecosystem/mcp-client/api/routers/commands.py b/python-ecosystem/mcp-client/api/routers/commands.py new file mode 100644 index 00000000..9992a298 --- /dev/null +++ b/python-ecosystem/mcp-client/api/routers/commands.py @@ -0,0 +1,204 @@ +""" +Command API endpoints (summarize, ask). +""" +import json +import asyncio +from typing import Dict, Any +from fastapi import APIRouter, Request +from starlette.responses import StreamingResponse + +from model.dtos import ( + SummarizeRequestDto, SummarizeResponseDto, + AskRequestDto, AskResponseDto, +) +from service.command.command_service import CommandService + +router = APIRouter(tags=["commands"]) + +# Service instance +_command_service = None + + +def get_command_service() -> CommandService: + """Get or create the command service singleton.""" + global _command_service + if _command_service is None: + _command_service = CommandService() + return _command_service + + +@router.post("/review/summarize", response_model=SummarizeResponseDto) +async def summarize_endpoint(req: SummarizeRequestDto, request: Request): + """ + HTTP endpoint to process /codecrow summarize command. + + Generates a comprehensive PR summary with: + - Overview of changes + - Key files modified + - Impact analysis + - Architecture diagram (Mermaid or ASCII) + """ + command_service = get_command_service() + + try: + wants_stream = _wants_streaming(request) + + if not wants_stream: + # Non-streaming behavior + result = await command_service.process_summarize(req) + return SummarizeResponseDto( + summary=result.get("summary"), + diagram=result.get("diagram"), + diagramType=result.get("diagramType", "MERMAID"), + error=result.get("error") + ) + + # Streaming behavior + async def event_stream(): + queue = asyncio.Queue() + + yield _json_event({"type": "status", "state": "queued", "message": "summarize request received"}) + + def event_callback(event: Dict[str, Any]): + try: + queue.put_nowait(event) + except asyncio.QueueFull: + pass + + async def runner(): + try: + result = await command_service.process_summarize(req, event_callback=event_callback) + await queue.put({ + "type": "final", + "result": result + }) + except Exception as e: + await queue.put({"type": "error", "message": str(e)}) + + task = asyncio.create_task(runner()) + + async for event in _drain_queue_until_final(queue, task): + yield _json_event(event) + + return StreamingResponse(event_stream(), media_type="application/x-ndjson") + + except Exception as e: + return SummarizeResponseDto(error=f"Summarize failed: {str(e)}") + + +@router.post("/review/ask", response_model=AskResponseDto) +async def ask_endpoint(req: AskRequestDto, request: Request): + """ + HTTP endpoint to process /codecrow ask command. 
+ + Answers questions about: + - Specific issues + - PR changes + - Codebase (using RAG) + - Analysis results + """ + command_service = get_command_service() + + try: + wants_stream = _wants_streaming(request) + + if not wants_stream: + # Non-streaming behavior + result = await command_service.process_ask(req) + return AskResponseDto( + answer=result.get("answer"), + error=result.get("error") + ) + + # Streaming behavior + async def event_stream(): + queue = asyncio.Queue() + + yield _json_event({"type": "status", "state": "queued", "message": "ask request received"}) + + def event_callback(event: Dict[str, Any]): + try: + queue.put_nowait(event) + except asyncio.QueueFull: + pass + + async def runner(): + try: + result = await command_service.process_ask(req, event_callback=event_callback) + await queue.put({ + "type": "final", + "result": result + }) + except Exception as e: + await queue.put({"type": "error", "message": str(e)}) + + task = asyncio.create_task(runner()) + + async for event in _drain_queue_until_final(queue, task): + yield _json_event(event) + + return StreamingResponse(event_stream(), media_type="application/x-ndjson") + + except Exception as e: + return AskResponseDto(error=f"Ask failed: {str(e)}") + + +def _wants_streaming(request: Request) -> bool: + """Check if client wants streaming response.""" + accept_header = request.headers.get("accept", "") + return "application/x-ndjson" in accept_header.lower() + + +def _json_event(event: Dict[str, Any]) -> str: + """Serialize event to NDJSON line.""" + return json.dumps(event) + "\n" + + +async def _drain_queue_until_final(queue: asyncio.Queue, task: asyncio.Task): + """ + Drain events from queue until we see a final/error event and task is done. + Yields each event as it arrives. + """ + seen_final = False + + while True: + try: + # Wait for event with timeout to prevent hanging + event = await asyncio.wait_for(queue.get(), timeout=1.0) + yield event + + # Check if this is a terminal event + event_type = event.get("type") + if event_type in ("final", "error"): + seen_final = True + # Continue draining in case there are more events + + # If we've seen final and task is done, check for remaining events then exit + if seen_final and task.done(): + # Give a moment for any last events + await asyncio.sleep(0.1) + try: + while True: + event = queue.get_nowait() + yield event + except asyncio.QueueEmpty: + break + break + + except asyncio.TimeoutError: + # No event available, check if task is done + if task.done(): + # Task finished, drain any remaining events + try: + while True: + event = queue.get_nowait() + yield event + if event.get("type") in ("final", "error"): + seen_final = True + except asyncio.QueueEmpty: + pass + + # If we saw a final event or no more events, we're done + if seen_final or task.done(): + break + # Otherwise continue waiting for events diff --git a/python-ecosystem/mcp-client/api/routers/health.py b/python-ecosystem/mcp-client/api/routers/health.py new file mode 100644 index 00000000..98c53522 --- /dev/null +++ b/python-ecosystem/mcp-client/api/routers/health.py @@ -0,0 +1,12 @@ +""" +Health check endpoints. 
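+
+Example (illustrative only):
+
+    GET /health  ->  {"status": "ok"}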
+""" +from fastapi import APIRouter + +router = APIRouter(tags=["health"]) + + +@router.get("/health") +def health(): + """Health check endpoint.""" + return {"status": "ok"} diff --git a/python-ecosystem/mcp-client/api/routers/review.py b/python-ecosystem/mcp-client/api/routers/review.py new file mode 100644 index 00000000..68af6aa6 --- /dev/null +++ b/python-ecosystem/mcp-client/api/routers/review.py @@ -0,0 +1,157 @@ +""" +Review API endpoints. +""" +import json +import asyncio +from typing import Dict, Any +from fastapi import APIRouter, Request +from starlette.responses import StreamingResponse + +from model.dtos import ReviewRequestDto, ReviewResponseDto +from service.review.review_service import ReviewService +from utils.response_parser import ResponseParser + +router = APIRouter(tags=["review"]) + +# Service instance +_review_service = None + + +def get_review_service() -> ReviewService: + """Get or create the review service singleton.""" + global _review_service + if _review_service is None: + _review_service = ReviewService() + return _review_service + + +@router.post("/review", response_model=ReviewResponseDto) +async def review_endpoint(req: ReviewRequestDto, request: Request): + """ + HTTP endpoint to accept review requests from the pipeline agent. + + Behavior: + - If the client requests streaming via header `Accept: application/x-ndjson`, + the endpoint will return a StreamingResponse that yields NDJSON events as they occur. + - Otherwise it preserves the original behavior and returns a single ReviewResponseDto JSON body. + """ + review_service = get_review_service() + + try: + wants_stream = _wants_streaming(request) + + if not wants_stream: + # Non-streaming (legacy) behavior + result = await review_service.process_review_request(req) + return ReviewResponseDto( + result=result.get("result"), + error=result.get("error") + ) + + # Streaming behavior + async def event_stream(): + queue = asyncio.Queue() + + # Emit initial queued status + yield _json_event({"type": "status", "state": "queued", "message": "request received"}) + + # Event callback to capture service events + def event_callback(event: Dict[str, Any]): + try: + queue.put_nowait(event) + except asyncio.QueueFull: + pass # Skip if queue is full + + # Run processing in background + async def runner(): + try: + result = await review_service.process_review_request( + req, + event_callback=event_callback + ) + # Emit final event with result + final_event = { + "type": "final", + "result": result.get("result") + } + await queue.put(final_event) + except Exception as e: + await queue.put({ + "type": "error", + "message": str(e) + }) + + task = asyncio.create_task(runner()) + + # Drain queue and yield events + async for event in _drain_queue_until_final(queue, task): + yield _json_event(event) + + return StreamingResponse(event_stream(), media_type="application/x-ndjson") + + except Exception as e: + error_response = ResponseParser.create_error_response( + "HTTP request processing failed", str(e) + ) + return ReviewResponseDto(result=error_response) + + +def _wants_streaming(request: Request) -> bool: + """Check if client wants streaming response.""" + accept_header = request.headers.get("accept", "") + return "application/x-ndjson" in accept_header.lower() + + +def _json_event(event: Dict[str, Any]) -> str: + """Serialize event to NDJSON line.""" + return json.dumps(event) + "\n" + + +async def _drain_queue_until_final(queue: asyncio.Queue, task: asyncio.Task): + """ + Drain events from queue until we see a final/error 
event and task is done. + Yields each event as it arrives. + """ + seen_final = False + + while True: + try: + # Wait for event with timeout to prevent hanging + event = await asyncio.wait_for(queue.get(), timeout=1.0) + yield event + + # Check if this is a terminal event + event_type = event.get("type") + if event_type in ("final", "error"): + seen_final = True + # Continue draining in case there are more events + + # If we've seen final and task is done, check for remaining events then exit + if seen_final and task.done(): + # Give a moment for any last events + await asyncio.sleep(0.1) + try: + while True: + event = queue.get_nowait() + yield event + except asyncio.QueueEmpty: + break + break + + except asyncio.TimeoutError: + # No event available, check if task is done + if task.done(): + # Task finished, drain any remaining events + try: + while True: + event = queue.get_nowait() + yield event + if event.get("type") in ("final", "error"): + seen_final = True + except asyncio.QueueEmpty: + pass + + # If we saw a final event or no more events, we're done + if seen_final or task.done(): + break + # Otherwise continue waiting for events diff --git a/python-ecosystem/mcp-client/main.py b/python-ecosystem/mcp-client/main.py index abb28d6f..9d1ce045 100644 --- a/python-ecosystem/mcp-client/main.py +++ b/python-ecosystem/mcp-client/main.py @@ -18,6 +18,7 @@ import threading from server.stdin_handler import StdinHandler +from api.app import run_http_server # Configure logging - only if not already configured # This prevents duplicate handlers when libraries also configure logging @@ -79,7 +80,6 @@ def main(): handler.process_stdin_request() else: # Run in HTTP server mode (default) - from server.web_server import run_http_server host = os.environ.get("AI_CLIENT_HOST", "0.0.0.0") port = int(os.environ.get("AI_CLIENT_PORT", "8000")) run_http_server(host=host, port=port) diff --git a/python-ecosystem/mcp-client/model/__init__.py b/python-ecosystem/mcp-client/model/__init__.py new file mode 100644 index 00000000..db99d4b2 --- /dev/null +++ b/python-ecosystem/mcp-client/model/__init__.py @@ -0,0 +1,98 @@ +""" +Model package - Re-exports all models for backward compatibility. + +The models are split into logical modules: +- enums: IssueCategory, AnalysisMode, RelationshipType +- enrichment: File enrichment DTOs (FileContentDto, PrEnrichmentDataDto, etc.) +- dtos: Request/Response DTOs (ReviewRequestDto, SummarizeRequestDto, etc.) +- output_schemas: MCP Agent output schemas (CodeReviewOutput, CodeReviewIssue, etc.) +- multi_stage: Multi-stage review models (ReviewPlan, FileReviewOutput, etc.) 
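+
+Importing from the package root keeps existing call sites working, e.g.
+(illustrative only):
+
+    from model import ReviewRequestDto, CodeReviewOutput, IssueCategory
+    from model.enrichment import PrEnrichmentDataDto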
+""" + +# Enums +from model.enums import ( + IssueCategory, + AnalysisMode, + RelationshipType, +) + +# Enrichment models +from model.enrichment import ( + FileContentDto, + ParsedFileMetadataDto, + FileRelationshipDto, + EnrichmentStats, + PrEnrichmentDataDto, +) + +# DTOs +from model.dtos import ( + IssueDTO, + ReviewRequestDto, + ReviewResponseDto, + SummarizeRequestDto, + SummarizeResponseDto, + AskRequestDto, + AskResponseDto, +) + +# Output schemas +from model.output_schemas import ( + CodeReviewIssue, + CodeReviewOutput, + SummarizeOutput, + AskOutput, +) + +# Multi-stage review models +from model.multi_stage import ( + FileReviewOutput, + FileReviewBatchOutput, + ReviewFile, + FileGroup, + FileToSkip, + ReviewPlan, + CrossFileIssue, + DataFlowConcern, + ImmutabilityCheck, + DatabaseIntegrityCheck, + CrossFileAnalysisResult, +) + +__all__ = [ + # Enums + "IssueCategory", + "AnalysisMode", + "RelationshipType", + # Enrichment + "FileContentDto", + "ParsedFileMetadataDto", + "FileRelationshipDto", + "EnrichmentStats", + "PrEnrichmentDataDto", + # DTOs + "IssueDTO", + "ReviewRequestDto", + "ReviewResponseDto", + "SummarizeRequestDto", + "SummarizeResponseDto", + "AskRequestDto", + "AskResponseDto", + # Output schemas + "CodeReviewIssue", + "CodeReviewOutput", + "SummarizeOutput", + "AskOutput", + # Multi-stage + "FileReviewOutput", + "FileReviewBatchOutput", + "ReviewFile", + "FileGroup", + "FileToSkip", + "ReviewPlan", + "CrossFileIssue", + "DataFlowConcern", + "ImmutabilityCheck", + "DatabaseIntegrityCheck", + "CrossFileAnalysisResult", +] diff --git a/python-ecosystem/mcp-client/model/models.py b/python-ecosystem/mcp-client/model/dtos.py similarity index 52% rename from python-ecosystem/mcp-client/model/models.py rename to python-ecosystem/mcp-client/model/dtos.py index a7a0830c..192369fc 100644 --- a/python-ecosystem/mcp-client/model/models.py +++ b/python-ecosystem/mcp-client/model/dtos.py @@ -1,27 +1,8 @@ -from typing import Optional, Any, Dict, List +from typing import Optional, Any, List from pydantic import BaseModel, Field, AliasChoices from datetime import datetime -from enum import Enum - -class IssueCategory(str, Enum): - """Valid issue categories for code analysis.""" - SECURITY = "SECURITY" - PERFORMANCE = "PERFORMANCE" - CODE_QUALITY = "CODE_QUALITY" - BUG_RISK = "BUG_RISK" - STYLE = "STYLE" - DOCUMENTATION = "DOCUMENTATION" - BEST_PRACTICES = "BEST_PRACTICES" - ERROR_HANDLING = "ERROR_HANDLING" - TESTING = "TESTING" - ARCHITECTURE = "ARCHITECTURE" - - -class AnalysisMode(str, Enum): - """Analysis mode for PR reviews.""" - FULL = "FULL" # Full PR diff analysis (first review or escalation) - INCREMENTAL = "INCREMENTAL" # Delta diff analysis (subsequent reviews) +from model.enrichment import PrEnrichmentDataDto class IssueDTO(BaseModel): @@ -89,12 +70,16 @@ class ReviewRequestDto(BaseModel): deltaDiff: Optional[str] = Field(default=None, description="Delta diff between previous and current commit (only for INCREMENTAL mode)") previousCommitHash: Optional[str] = Field(default=None, description="Previously analyzed commit hash") currentCommitHash: Optional[str] = Field(default=None, description="Current commit hash being analyzed") + # File enrichment data (full file contents + pre-computed dependency graph) + enrichmentData: Optional[PrEnrichmentDataDto] = Field(default=None, description="Pre-computed file contents and dependency relationships from Java") + class ReviewResponseDto(BaseModel): result: Optional[Any] = None error: Optional[str] = None exception: Optional[str] 
= None + class SummarizeRequestDto(BaseModel): """Request model for PR summarization command.""" projectId: int @@ -116,6 +101,7 @@ class SummarizeRequestDto(BaseModel): maxAllowedTokens: Optional[int] = None vcsProvider: Optional[str] = Field(default=None, description="VCS provider type (github, bitbucket_cloud)") + class SummarizeResponseDto(BaseModel): """Response model for PR summarization command.""" summary: Optional[str] = None @@ -123,6 +109,7 @@ class SummarizeResponseDto(BaseModel): diagramType: Optional[str] = Field(default="MERMAID", description="MERMAID or ASCII") error: Optional[str] = None + class AskRequestDto(BaseModel): """Request model for ask command.""" projectId: int @@ -145,141 +132,8 @@ class AskRequestDto(BaseModel): analysisContext: Optional[str] = Field(default=None, description="Existing analysis data for context") issueReferences: Optional[List[str]] = Field(default_factory=list, description="Issue IDs referenced in the question") + class AskResponseDto(BaseModel): """Response model for ask command.""" answer: Optional[str] = None error: Optional[str] = None - -# ==================== Output Schemas for MCP Agent ==================== -# These Pydantic models are used with MCPAgent's output_schema parameter -# to ensure structured JSON output from the LLM. - -class CodeReviewIssue(BaseModel): - """Schema for a single code review issue.""" - # Optional issue identifier (preserve DB/client-side ids for reconciliation) - id: Optional[str] = Field(default=None, description="Optional issue id to link to existing issues") - severity: str = Field(description="Issue severity: HIGH, MEDIUM, LOW, or INFO") - category: str = Field(description="Issue category: SECURITY, PERFORMANCE, CODE_QUALITY, BUG_RISK, STYLE, DOCUMENTATION, BEST_PRACTICES, ERROR_HANDLING, TESTING, or ARCHITECTURE") - file: str = Field(description="File path where the issue is located") - line: str = Field(description="Line number or range (e.g., '42' or '42-45')") - reason: str = Field(description="Clear explanation of the issue") - suggestedFixDescription: str = Field(description="Description of the suggested fix") - suggestedFixDiff: Optional[str] = Field(default=None, description="Optional unified diff format patch for the fix") - isResolved: bool = Field(default=False, description="Whether this issue from previous analysis is resolved") - # Resolution tracking fields - resolutionExplanation: Optional[str] = Field(default=None, description="Explanation of how the issue was resolved (separate from original reason)") - resolvedInCommit: Optional[str] = Field(default=None, description="Commit hash where the issue was resolved") - # Additional fields preserved from previous issues during reconciliation - visibility: Optional[str] = Field(default=None, description="Issue visibility status") - codeSnippet: Optional[str] = Field(default=None, description="Code snippet associated with the issue") - - -class CodeReviewOutput(BaseModel): - """Schema for the complete code review output.""" - comment: str = Field(description="High-level summary of the PR analysis with key findings and recommendations") - issues: List[CodeReviewIssue] = Field(default_factory=list, description="List of identified issues in the code") - - -class SummarizeOutput(BaseModel): - """Schema for PR summarization output.""" - summary: str = Field(description="Comprehensive summary of the PR changes, purpose, and impact") - diagram: str = Field(default="", description="Visual diagram of the changes (Mermaid or ASCII format)") - diagramType: 
str = Field(default="MERMAID", description="Type of diagram: MERMAID or ASCII") - - -class AskOutput(BaseModel): - """Schema for ask command output.""" - answer: str = Field(description="Well-formatted markdown answer to the user's question") - - -# ==================== Multi-Stage Review Models ==================== - -class FileReviewOutput(BaseModel): - """Stage 1 Output: Single file review result.""" - file: str - analysis_summary: str - issues: List[CodeReviewIssue] = Field(default_factory=list) - confidence: str = Field(description="Confidence level (HIGH/MEDIUM/LOW)") - note: str = Field(default="", description="Optional analysis note") - - - -class FileReviewBatchOutput(BaseModel): - """Stage 1 Output: Batch of file reviews.""" - reviews: List[FileReviewOutput] = Field(description="List of review results for the files in the batch") - - - -class ReviewFile(BaseModel): - """File details for review planning.""" - path: str - focus_areas: List[str] = Field(default_factory=list, description="Specific areas to focus on (SECURITY, ARCHITECTURE, etc.)") - risk_level: str = Field(description="CRITICAL, HIGH, MEDIUM, or LOW") - estimated_issues: Optional[int] = Field(default=0) - - -class FileGroup(BaseModel): - """Group of files to be reviewed together.""" - group_id: str - priority: str = Field(description="CRITICAL, HIGH, MEDIUM, LOW") - rationale: str - files: List[ReviewFile] - - -class FileToSkip(BaseModel): - """File skipped from deep review.""" - path: str - reason: str - - -class ReviewPlan(BaseModel): - """Stage 0 Output: Plan for the review scanning.""" - analysis_summary: str - file_groups: List[FileGroup] - files_to_skip: List[FileToSkip] = Field(default_factory=list) - cross_file_concerns: List[str] = Field(default_factory=list, description="Hypotheses to verify in Stage 2") - - -class CrossFileIssue(BaseModel): - """Issue spanning multiple files (Stage 2).""" - id: str - severity: str - category: str - title: str - affected_files: List[str] - description: str - evidence: str - business_impact: str - suggestion: str - - -class DataFlowConcern(BaseModel): - """Stage 2: Data flow gap analysis.""" - flow: str - gap: str - files_involved: List[str] - severity: str - - -class ImmutabilityCheck(BaseModel): - """Stage 2: Immutability usage check.""" - rule: str - check_pass: bool = Field(alias="check_pass") - evidence: str - - -class DatabaseIntegrityCheck(BaseModel): - """Stage 2: DB integrity check.""" - concerns: List[str] - findings: List[str] - - -class CrossFileAnalysisResult(BaseModel): - """Stage 2 Output: Cross-file architectural analysis.""" - pr_risk_level: str - cross_file_issues: List[CrossFileIssue] - data_flow_concerns: List[DataFlowConcern] = Field(default_factory=list) - immutability_enforcement: Optional[ImmutabilityCheck] = None - database_integrity: Optional[DatabaseIntegrityCheck] = None - pr_recommendation: str - confidence: str \ No newline at end of file diff --git a/python-ecosystem/mcp-client/model/enrichment.py b/python-ecosystem/mcp-client/model/enrichment.py new file mode 100644 index 00000000..81484107 --- /dev/null +++ b/python-ecosystem/mcp-client/model/enrichment.py @@ -0,0 +1,59 @@ +from typing import Optional, Dict, List +from pydantic import BaseModel, Field + +from model.enums import RelationshipType + + +class FileContentDto(BaseModel): + """DTO representing the content of a single file retrieved from VCS.""" + path: str + content: Optional[str] = None + sizeBytes: int = 0 + skipped: bool = False + skipReason: Optional[str] = None + + +class 
ParsedFileMetadataDto(BaseModel): + """DTO representing parsed AST metadata for a single file.""" + path: str + language: Optional[str] = None + imports: List[str] = Field(default_factory=list) + extendsClasses: List[str] = Field(default_factory=list, alias="extends") + implementsInterfaces: List[str] = Field(default_factory=list, alias="implements") + semanticNames: List[str] = Field(default_factory=list, alias="semantic_names") + parentClass: Optional[str] = Field(default=None, alias="parent_class") + namespace: Optional[str] = None + calls: List[str] = Field(default_factory=list) + error: Optional[str] = None + + +class FileRelationshipDto(BaseModel): + """DTO representing a relationship between two files in the PR.""" + sourceFile: str + targetFile: str + relationshipType: RelationshipType + matchedOn: Optional[str] = None + strength: int = 0 + + +class EnrichmentStats(BaseModel): + """Statistics about the enrichment process.""" + totalFilesRequested: int = 0 + filesEnriched: int = 0 + filesSkipped: int = 0 + relationshipsFound: int = 0 + totalContentSizeBytes: int = 0 + processingTimeMs: int = 0 + skipReasons: Dict[str, int] = Field(default_factory=dict) + + +class PrEnrichmentDataDto(BaseModel): + """Aggregate DTO containing all file enrichment data for a PR.""" + fileContents: List[FileContentDto] = Field(default_factory=list) + fileMetadata: List[ParsedFileMetadataDto] = Field(default_factory=list) + relationships: List[FileRelationshipDto] = Field(default_factory=list) + stats: Optional[EnrichmentStats] = None + + def has_data(self) -> bool: + """Check if enrichment data is present.""" + return bool(self.fileContents) or bool(self.relationships) diff --git a/python-ecosystem/mcp-client/model/enums.py b/python-ecosystem/mcp-client/model/enums.py new file mode 100644 index 00000000..0541f03a --- /dev/null +++ b/python-ecosystem/mcp-client/model/enums.py @@ -0,0 +1,31 @@ +from enum import Enum + + +class IssueCategory(str, Enum): + """Valid issue categories for code analysis.""" + SECURITY = "SECURITY" + PERFORMANCE = "PERFORMANCE" + CODE_QUALITY = "CODE_QUALITY" + BUG_RISK = "BUG_RISK" + STYLE = "STYLE" + DOCUMENTATION = "DOCUMENTATION" + BEST_PRACTICES = "BEST_PRACTICES" + ERROR_HANDLING = "ERROR_HANDLING" + TESTING = "TESTING" + ARCHITECTURE = "ARCHITECTURE" + + +class AnalysisMode(str, Enum): + """Analysis mode for PR reviews.""" + FULL = "FULL" # Full PR diff analysis (first review or escalation) + INCREMENTAL = "INCREMENTAL" # Delta diff analysis (subsequent reviews) + + +class RelationshipType(str, Enum): + """Types of relationships between files.""" + IMPORTS = "IMPORTS" + EXTENDS = "EXTENDS" + IMPLEMENTS = "IMPLEMENTS" + CALLS = "CALLS" + SAME_PACKAGE = "SAME_PACKAGE" + REFERENCES = "REFERENCES" diff --git a/python-ecosystem/mcp-client/model/multi_stage.py b/python-ecosystem/mcp-client/model/multi_stage.py new file mode 100644 index 00000000..16e11fab --- /dev/null +++ b/python-ecosystem/mcp-client/model/multi_stage.py @@ -0,0 +1,102 @@ +""" +Multi-Stage Review Models. 
+ +These models are used for the multi-stage PR review process: +- Stage 0: Planning (ReviewPlan) +- Stage 1: File-by-file review (FileReviewOutput, FileReviewBatchOutput) +- Stage 2: Cross-file analysis (CrossFileAnalysisResult) +""" + +from typing import Optional, List +from pydantic import BaseModel, Field + +from model.output_schemas import CodeReviewIssue + + +class FileReviewOutput(BaseModel): + """Stage 1 Output: Single file review result.""" + file: str + analysis_summary: str + issues: List[CodeReviewIssue] = Field(default_factory=list) + confidence: str = Field(description="Confidence level (HIGH/MEDIUM/LOW)") + note: str = Field(default="", description="Optional analysis note") + + +class FileReviewBatchOutput(BaseModel): + """Stage 1 Output: Batch of file reviews.""" + reviews: List[FileReviewOutput] = Field(description="List of review results for the files in the batch") + + +class ReviewFile(BaseModel): + """File details for review planning.""" + path: str + focus_areas: List[str] = Field(default_factory=list, description="Specific areas to focus on (SECURITY, ARCHITECTURE, etc.)") + risk_level: str = Field(description="CRITICAL, HIGH, MEDIUM, or LOW") + estimated_issues: Optional[int] = Field(default=0) + + +class FileGroup(BaseModel): + """Group of files to be reviewed together.""" + group_id: str + priority: str = Field(description="CRITICAL, HIGH, MEDIUM, LOW") + rationale: str + files: List[ReviewFile] + + +class FileToSkip(BaseModel): + """File skipped from deep review.""" + path: str + reason: str + + +class ReviewPlan(BaseModel): + """Stage 0 Output: Plan for the review scanning.""" + analysis_summary: str + file_groups: List[FileGroup] + files_to_skip: List[FileToSkip] = Field(default_factory=list) + cross_file_concerns: List[str] = Field(default_factory=list, description="Hypotheses to verify in Stage 2") + + +class CrossFileIssue(BaseModel): + """Issue spanning multiple files (Stage 2).""" + id: str + severity: str + category: str + title: str + affected_files: List[str] + description: str + evidence: str + business_impact: str + suggestion: str + + +class DataFlowConcern(BaseModel): + """Stage 2: Data flow gap analysis.""" + flow: str + gap: str + files_involved: List[str] + severity: str + + +class ImmutabilityCheck(BaseModel): + """Stage 2: Immutability usage check.""" + rule: str + check_pass: bool = Field(alias="check_pass") + evidence: str + + +class DatabaseIntegrityCheck(BaseModel): + """Stage 2: DB integrity check.""" + concerns: List[str] + findings: List[str] + + +class CrossFileAnalysisResult(BaseModel): + """Stage 2 Output: Cross-file architectural analysis.""" + pr_risk_level: str + cross_file_issues: List[CrossFileIssue] + data_flow_concerns: List[DataFlowConcern] = Field(default_factory=list) + immutability_enforcement: Optional[ImmutabilityCheck] = None + database_integrity: Optional[DatabaseIntegrityCheck] = None + pr_recommendation: str + confidence: str diff --git a/python-ecosystem/mcp-client/model/output_schemas.py b/python-ecosystem/mcp-client/model/output_schemas.py new file mode 100644 index 00000000..41583db4 --- /dev/null +++ b/python-ecosystem/mcp-client/model/output_schemas.py @@ -0,0 +1,47 @@ +""" +Output Schemas for MCP Agent. + +These Pydantic models are used with MCPAgent's output_schema parameter +to ensure structured JSON output from the LLM. 
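+
+These are plain pydantic v2 models, so structured LLM output can be validated
+directly, e.g. (illustrative only):
+
+    raw = '{"comment": "Looks good overall.", "issues": []}'
+    review = CodeReviewOutput.model_validate_json(raw)
+    assert review.issues == []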
+""" + +from typing import Optional, List +from pydantic import BaseModel, Field + + +class CodeReviewIssue(BaseModel): + """Schema for a single code review issue.""" + # Optional issue identifier (preserve DB/client-side ids for reconciliation) + id: Optional[str] = Field(default=None, description="Optional issue id to link to existing issues") + severity: str = Field(description="Issue severity: HIGH, MEDIUM, LOW, or INFO") + category: str = Field(description="Issue category: SECURITY, PERFORMANCE, CODE_QUALITY, BUG_RISK, STYLE, DOCUMENTATION, BEST_PRACTICES, ERROR_HANDLING, TESTING, or ARCHITECTURE") + file: str = Field(description="File path where the issue is located") + line: str = Field(description="Line number or range (e.g., '42' or '42-45')") + reason: str = Field(description="Clear explanation of the issue") + suggestedFixDescription: str = Field(description="Description of the suggested fix") + suggestedFixDiff: Optional[str] = Field(default=None, description="Optional unified diff format patch for the fix") + isResolved: bool = Field(default=False, description="Whether this issue from previous analysis is resolved") + # Resolution tracking fields + resolutionExplanation: Optional[str] = Field(default=None, description="Explanation of how the issue was resolved (separate from original reason)") + resolvedInCommit: Optional[str] = Field(default=None, description="Commit hash where the issue was resolved") + # Additional fields preserved from previous issues during reconciliation + visibility: Optional[str] = Field(default=None, description="Issue visibility status") + codeSnippet: Optional[str] = Field(default=None, description="Code snippet associated with the issue") + + +class CodeReviewOutput(BaseModel): + """Schema for the complete code review output.""" + comment: str = Field(description="High-level summary of the PR analysis with key findings and recommendations") + issues: List[CodeReviewIssue] = Field(default_factory=list, description="List of identified issues in the code") + + +class SummarizeOutput(BaseModel): + """Schema for PR summarization output.""" + summary: str = Field(description="Comprehensive summary of the PR changes, purpose, and impact") + diagram: str = Field(default="", description="Visual diagram of the changes (Mermaid or ASCII format)") + diagramType: str = Field(default="MERMAID", description="Type of diagram: MERMAID or ASCII") + + +class AskOutput(BaseModel): + """Schema for ask command output.""" + answer: str = Field(description="Well-formatted markdown answer to the user's question") diff --git a/python-ecosystem/mcp-client/server/stdin_handler.py b/python-ecosystem/mcp-client/server/stdin_handler.py index c7be29fc..9817c3b8 100644 --- a/python-ecosystem/mcp-client/server/stdin_handler.py +++ b/python-ecosystem/mcp-client/server/stdin_handler.py @@ -3,8 +3,8 @@ import asyncio from typing import Optional, Dict, Any -from model.models import ReviewRequestDto -from service.review_service import ReviewService +from model.dtos import ReviewRequestDto +from service.review.review_service import ReviewService class StdinHandler: diff --git a/python-ecosystem/mcp-client/server/web_server.py b/python-ecosystem/mcp-client/server/web_server.py deleted file mode 100644 index 58833af8..00000000 --- a/python-ecosystem/mcp-client/server/web_server.py +++ /dev/null @@ -1,284 +0,0 @@ -import os -import json -import shutil -import tempfile -import tarfile -import asyncio -from typing import Dict, Any -from fastapi import FastAPI, Request, UploadFile, File, Form 
-from starlette.responses import StreamingResponse - -from model.models import ( - ReviewRequestDto, ReviewResponseDto, - SummarizeRequestDto, SummarizeResponseDto, - AskRequestDto, AskResponseDto -) -from service.review_service import ReviewService -from service.command_service import CommandService -from utils.response_parser import ResponseParser - - -def create_app(): - """Create and configure FastAPI application.""" - app = FastAPI(title="codecrow-mcp-client") - review_service = ReviewService() - command_service = CommandService() - - @app.post("/review", response_model=ReviewResponseDto) - async def review_endpoint(req: ReviewRequestDto, request: Request): - """ - HTTP endpoint to accept review requests from the pipeline agent. - - Behavior: - - If the client requests streaming via header `Accept: application/x-ndjson`, - the endpoint will return a StreamingResponse that yields NDJSON events as they occur. - - Otherwise it preserves the original behavior and returns a single ReviewResponseDto JSON body. - """ - try: - wants_stream = _wants_streaming(request) - - if not wants_stream: - # Non-streaming (legacy) behavior - result = await review_service.process_review_request(req) - return ReviewResponseDto( - result=result.get("result"), - error=result.get("error") - ) - - # Streaming behavior - async def event_stream(): - queue = asyncio.Queue() - - # Emit initial queued status - yield _json_event({"type": "status", "state": "queued", "message": "request received"}) - - # Event callback to capture service events - def event_callback(event: Dict[str, Any]): - try: - queue.put_nowait(event) - except asyncio.QueueFull: - pass # Skip if queue is full - - # Run processing in background - async def runner(): - try: - result = await review_service.process_review_request( - req, - event_callback=event_callback - ) - # Emit final event with result - final_event = { - "type": "final", - "result": result.get("result") - } - await queue.put(final_event) - except Exception as e: - await queue.put({ - "type": "error", - "message": str(e) - }) - - task = asyncio.create_task(runner()) - - # Drain queue and yield events - async for event in _drain_queue_until_final(queue, task): - yield _json_event(event) - - return StreamingResponse(event_stream(), media_type="application/x-ndjson") - - except Exception as e: - error_response = ResponseParser.create_error_response( - "HTTP request processing failed", str(e) - ) - return ReviewResponseDto(result=error_response) - - @app.post("/review/summarize", response_model=SummarizeResponseDto) - async def summarize_endpoint(req: SummarizeRequestDto, request: Request): - """ - HTTP endpoint to process /codecrow summarize command. 
- - Generates a comprehensive PR summary with: - - Overview of changes - - Key files modified - - Impact analysis - - Architecture diagram (Mermaid or ASCII) - """ - try: - wants_stream = _wants_streaming(request) - - if not wants_stream: - # Non-streaming behavior - result = await command_service.process_summarize(req) - return SummarizeResponseDto( - summary=result.get("summary"), - diagram=result.get("diagram"), - diagramType=result.get("diagramType", "MERMAID"), - error=result.get("error") - ) - - # Streaming behavior - async def event_stream(): - queue = asyncio.Queue() - - yield _json_event({"type": "status", "state": "queued", "message": "summarize request received"}) - - def event_callback(event: Dict[str, Any]): - try: - queue.put_nowait(event) - except asyncio.QueueFull: - pass - - async def runner(): - try: - result = await command_service.process_summarize(req, event_callback=event_callback) - await queue.put({ - "type": "final", - "result": result - }) - except Exception as e: - await queue.put({"type": "error", "message": str(e)}) - - task = asyncio.create_task(runner()) - - async for event in _drain_queue_until_final(queue, task): - yield _json_event(event) - - return StreamingResponse(event_stream(), media_type="application/x-ndjson") - - except Exception as e: - return SummarizeResponseDto(error=f"Summarize failed: {str(e)}") - - @app.post("/review/ask", response_model=AskResponseDto) - async def ask_endpoint(req: AskRequestDto, request: Request): - """ - HTTP endpoint to process /codecrow ask command. - - Answers questions about: - - Specific issues - - PR changes - - Codebase (using RAG) - - Analysis results - """ - try: - wants_stream = _wants_streaming(request) - - if not wants_stream: - # Non-streaming behavior - result = await command_service.process_ask(req) - return AskResponseDto( - answer=result.get("answer"), - error=result.get("error") - ) - - # Streaming behavior - async def event_stream(): - queue = asyncio.Queue() - - yield _json_event({"type": "status", "state": "queued", "message": "ask request received"}) - - def event_callback(event: Dict[str, Any]): - try: - queue.put_nowait(event) - except asyncio.QueueFull: - pass - - async def runner(): - try: - result = await command_service.process_ask(req, event_callback=event_callback) - await queue.put({ - "type": "final", - "result": result - }) - except Exception as e: - await queue.put({"type": "error", "message": str(e)}) - - task = asyncio.create_task(runner()) - - async for event in _drain_queue_until_final(queue, task): - yield _json_event(event) - - return StreamingResponse(event_stream(), media_type="application/x-ndjson") - - except Exception as e: - return AskResponseDto(error=f"Ask failed: {str(e)}") - - @app.get("/health") - def health(): - """Health check endpoint.""" - return {"status": "ok"} - - return app - - -def _wants_streaming(request: Request) -> bool: - """Check if client wants streaming response.""" - accept_header = request.headers.get("accept", "") - return "application/x-ndjson" in accept_header.lower() - - -def _json_event(event: Dict[str, Any]) -> str: - """Serialize event to NDJSON line.""" - return json.dumps(event) + "\n" - - -async def _drain_queue_until_final(queue: asyncio.Queue, task: asyncio.Task): - """ - Drain events from queue until we see a final/error event and task is done. - Yields each event as it arrives. 
- """ - seen_final = False - - while True: - try: - # Wait for event with timeout to prevent hanging - event = await asyncio.wait_for(queue.get(), timeout=1.0) - yield event - - # Check if this is a terminal event - event_type = event.get("type") - if event_type in ("final", "error"): - seen_final = True - # Continue draining in case there are more events - - # If we've seen final and task is done, check for remaining events then exit - if seen_final and task.done(): - # Give a moment for any last events - await asyncio.sleep(0.1) - try: - while True: - event = queue.get_nowait() - yield event - except asyncio.QueueEmpty: - break - break - - except asyncio.TimeoutError: - # No event available, check if task is done - if task.done(): - # Task finished, drain any remaining events - try: - while True: - event = queue.get_nowait() - yield event - if event.get("type") in ("final", "error"): - seen_final = True - except asyncio.QueueEmpty: - pass - - # If we saw a final event or no more events, we're done - if seen_final or task.done(): - break - # Otherwise continue waiting for events - - -def run_http_server(host: str = "0.0.0.0", port: int = 8000): - """Run the FastAPI application.""" - app = create_app() - import uvicorn - uvicorn.run(app, host=host, port=port, log_level="info", timeout_keep_alive=300) - - -if __name__ == "__main__": - host = os.environ.get("AI_CLIENT_HOST", "0.0.0.0") - port = int(os.environ.get("AI_CLIENT_PORT", "8000")) - run_http_server(host=host, port=port) \ No newline at end of file diff --git a/python-ecosystem/mcp-client/service/__init__.py b/python-ecosystem/mcp-client/service/__init__.py new file mode 100644 index 00000000..6a893b0b --- /dev/null +++ b/python-ecosystem/mcp-client/service/__init__.py @@ -0,0 +1,44 @@ +""" +Service Package. + +Contains business logic services organized into subpackages: +- review: Code review functionality (ReviewService, orchestrator, issue processing) +- rag: RAG pipeline client and reranking +- command: Command handling (summarize, ask) +""" + +# Re-export from subpackages for backward compatibility +from service.review import ( + ReviewService, + MultiStageReviewOrchestrator, + IssuePostProcessor, + IssueDeduplicator, + post_process_analysis_result, + restore_missing_diffs_from_previous, +) +from service.rag import ( + RagClient, + RAG_MIN_RELEVANCE_SCORE, + RAG_DEFAULT_TOP_K, + LLMReranker, + RerankResult, +) +from service.command import CommandService + +__all__ = [ + # Review + "ReviewService", + "MultiStageReviewOrchestrator", + "IssuePostProcessor", + "IssueDeduplicator", + "post_process_analysis_result", + "restore_missing_diffs_from_previous", + # RAG + "RagClient", + "RAG_MIN_RELEVANCE_SCORE", + "RAG_DEFAULT_TOP_K", + "LLMReranker", + "RerankResult", + # Command + "CommandService", +] diff --git a/python-ecosystem/mcp-client/service/command/__init__.py b/python-ecosystem/mcp-client/service/command/__init__.py new file mode 100644 index 00000000..3019afae --- /dev/null +++ b/python-ecosystem/mcp-client/service/command/__init__.py @@ -0,0 +1,8 @@ +""" +Command Service Package. + +Contains the CommandService for handling CodeCrow commands (summarize, ask). 
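+
+Example (illustrative only; `req` is a SummarizeRequestDto built by the caller):
+
+    from service.command import CommandService
+
+    service = CommandService()
+    result = await service.process_summarize(req)  # inside an async context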
+""" +from service.command.command_service import CommandService + +__all__ = ["CommandService"] diff --git a/python-ecosystem/mcp-client/service/command_service.py b/python-ecosystem/mcp-client/service/command/command_service.py similarity index 99% rename from python-ecosystem/mcp-client/service/command_service.py rename to python-ecosystem/mcp-client/service/command/command_service.py index 9de4898b..d7c287da 100644 --- a/python-ecosystem/mcp-client/service/command_service.py +++ b/python-ecosystem/mcp-client/service/command/command_service.py @@ -10,10 +10,11 @@ from mcp_use import MCPAgent, MCPClient from langchain_core.agents import AgentAction -from model.models import SummarizeRequestDto, AskRequestDto, SummarizeOutput, AskOutput +from model.dtos import SummarizeRequestDto, AskRequestDto +from model.output_schemas import SummarizeOutput, AskOutput from utils.mcp_config import MCPConfigBuilder from llm.llm_factory import LLMFactory -from service.rag_client import RagClient +from service.rag.rag_client import RagClient from utils.error_sanitizer import sanitize_error_for_display, create_user_friendly_error logger = logging.getLogger(__name__) diff --git a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py deleted file mode 100644 index b5686d77..00000000 --- a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py +++ /dev/null @@ -1,1429 +0,0 @@ -import logging -import asyncio -import json -import re -from typing import Dict, Any, List, Optional, Callable - -from model.models import ( - ReviewRequestDto, - ReviewPlan, - CodeReviewOutput, - CodeReviewIssue, - ReviewFile, - FileGroup, - CrossFileAnalysisResult, - FileReviewBatchOutput -) -from utils.prompts.prompt_builder import PromptBuilder -from utils.diff_processor import ProcessedDiff, DiffProcessor -from mcp_use import MCPAgent - -logger = logging.getLogger(__name__) - - -def extract_llm_response_text(response: Any) -> str: - """ - Extract text content from LLM response, handling different response formats. - Some LLM providers return content as a list of objects instead of a string. - """ - if hasattr(response, 'content'): - content = response.content - if isinstance(content, list): - # Handle list content (e.g., from Gemini or other providers) - text_parts = [] - for item in content: - if isinstance(item, str): - text_parts.append(item) - elif isinstance(item, dict): - if 'text' in item: - text_parts.append(item['text']) - elif 'content' in item: - text_parts.append(item['content']) - elif hasattr(item, 'text'): - text_parts.append(item.text) - return "".join(text_parts) - return str(content) - return str(response) - - -# Prevent duplicate logs from mcp_use -mcp_logger = logging.getLogger("mcp_use") -mcp_logger.propagate = False -if not mcp_logger.handlers: - handler = logging.StreamHandler() - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - handler.setFormatter(formatter) - mcp_logger.addHandler(handler) - -class RecursiveMCPAgent(MCPAgent): - """ - Subclass of MCPAgent that enforces a higher recursion limit on the internal agent executor. - """ - def __init__(self, *args, recursion_limit: int = 50, **kwargs): - self._custom_recursion_limit = recursion_limit - super().__init__(*args, **kwargs) - - async def stream(self, *args, **kwargs): - """ - Override stream to ensure recursion_limit is applied. 
- """ - # Ensure the executor exists - if self._agent_executor is None: - await self.initialize() - - # Patch the executor's astream if not already patched - executor = self._agent_executor - if executor and not getattr(executor, "_is_patched_recursion", False): - original_astream = executor.astream - limit = self._custom_recursion_limit - - async def patched_astream(input_data, config=None, **astream_kwargs): - if config is None: - config = {} - config["recursion_limit"] = limit - async for chunk in original_astream(input_data, config=config, **astream_kwargs): - yield chunk - - executor.astream = patched_astream - executor._is_patched_recursion = True - logger.info(f"RecursiveMCPAgent: Patched recursion limit to {limit}") - - # Call parent stream - async for item in super().stream(*args, **kwargs): - yield item - -class MultiStageReviewOrchestrator: - """ - Orchestrates the 4-stage AI code review pipeline: - Stage 0: Planning & Prioritization - Stage 1: Parallel File Review - Stage 2: Cross-File & Architectural Analysis - Stage 3: Aggregation & Final Report - """ - - def __init__( - self, - llm, - mcp_client, - rag_client=None, - event_callback: Optional[Callable[[Dict], None]] = None - ): - self.llm = llm - self.client = mcp_client - self.rag_client = rag_client - self.event_callback = event_callback - self.max_parallel_stage_1 = 5 # Limit parallel execution to avoid rate limits - # PR-specific RAG indexing (data goes into main collection with PR metadata) - self._pr_number: Optional[int] = None - self._pr_indexed: bool = False - - async def _index_pr_files( - self, - request: ReviewRequestDto, - processed_diff: Optional[ProcessedDiff] - ) -> None: - """ - Index PR files into the main RAG collection with PR-specific metadata. - This enables hybrid queries that prioritize PR data over stale branch data. 
- """ - if not self.rag_client or not processed_diff: - return - - pr_number = request.pullRequestId - if not pr_number: - logger.info("No PR number, skipping PR file indexing") - return - - # Prepare files for indexing - # Prefer full_content if available, otherwise use diff content - # Diff content still provides value for understanding what changed - files = [] - for f in processed_diff.get_included_files(): - content = f.full_content or f.content # Use full content if available, fallback to diff - change_type = f.change_type.value if hasattr(f.change_type, 'value') else str(f.change_type) - if content and change_type != "DELETED": - files.append({ - "path": f.path, - "content": content, - "change_type": change_type - }) - - if not files: - logger.info("No files to index for PR") - return - - try: - result = await self.rag_client.index_pr_files( - workspace=request.projectWorkspace, - project=request.projectNamespace, - pr_number=pr_number, - branch=request.targetBranchName or "unknown", - files=files - ) - if result.get("status") == "indexed": - self._pr_number = pr_number - self._pr_indexed = True - logger.info(f"Indexed PR #{pr_number}: {result.get('chunks_indexed', 0)} chunks") - else: - logger.warning(f"Failed to index PR files: {result}") - except Exception as e: - logger.warning(f"Error indexing PR files: {e}") - - async def _cleanup_pr_files(self, request: ReviewRequestDto) -> None: - """Delete PR-indexed data after analysis completes.""" - if not self._pr_indexed or not self._pr_number or not self.rag_client: - return - - try: - await self.rag_client.delete_pr_files( - workspace=request.projectWorkspace, - project=request.projectNamespace, - pr_number=self._pr_number - ) - logger.info(f"Cleaned up PR #{self._pr_number} indexed data") - except Exception as e: - logger.warning(f"Failed to cleanup PR files: {e}") - finally: - self._pr_number = None - self._pr_indexed = False - - async def execute_branch_analysis(self, prompt: str) -> Dict[str, Any]: - """ - Execute a single-pass branch analysis using the provided prompt. - """ - self._emit_status("branch_analysis_started", "Starting Branch Analysis & Reconciliation...") - - agent = RecursiveMCPAgent( - llm=self.llm, - client=self.client, - additional_instructions=PromptBuilder.get_additional_instructions() - ) - - - try: - final_text = "" - # Branch analysis expects standard CodeReviewOutput - async for item in agent.stream(prompt, max_steps=15, output_schema=CodeReviewOutput): - if isinstance(item, CodeReviewOutput): - # Convert to dict format expected by service - issues = [i.model_dump() for i in item.issues] if item.issues else [] - return { - "issues": issues, - "comment": item.comment or "Branch analysis completed." - } - - if isinstance(item, str): - final_text = item - - # If stream finished without object, try parsing text - if final_text: - data = await self._parse_response(final_text, CodeReviewOutput) - issues = [i.model_dump() for i in data.issues] if data.issues else [] - return { - "issues": issues, - "comment": data.comment or "Branch analysis completed." - } - - return {"issues": [], "comment": "No issues found."} - - except Exception as e: - logger.error(f"Branch analysis failed: {e}", exc_info=True) - self._emit_error(str(e)) - raise - - async def orchestrate_review( - self, - request: ReviewRequestDto, - rag_context: Optional[Dict[str, Any]] = None, - processed_diff: Optional[ProcessedDiff] = None - ) -> Dict[str, Any]: - """ - Main entry point for the multi-stage review. 
- Supports both FULL (initial review) and INCREMENTAL (follow-up review) modes. - The same pipeline is used, but with incremental-aware prompts and issue reconciliation. - """ - # Determine if this is an incremental review - is_incremental = ( - request.analysisMode == "INCREMENTAL" - and request.deltaDiff - ) - - if is_incremental: - logger.info(f"INCREMENTAL mode: reviewing delta diff, {len(request.previousCodeAnalysisIssues or [])} previous issues to reconcile") - else: - logger.info("FULL mode: initial PR review") - - # Generate unique ID for temp diff collection - analysis_id = f"{request.projectId}_{request.pullRequestId or request.commitHash or 'unknown'}" - - try: - # === Index PR files into RAG for hybrid queries === - # This indexes PR files with metadata (pr=true, pr_number=X) to enable - # queries that prioritize fresh PR data over potentially stale branch data - await self._index_pr_files(request, processed_diff) - - # === STAGE 0: Planning === - self._emit_status("stage_0_started", "Stage 0: Planning & Prioritization...") - review_plan = await self._execute_stage_0_planning(request, is_incremental) - - # Validate and fix the plan to ensure all files are included - review_plan = self._ensure_all_files_planned(review_plan, request.changedFiles or []) - self._emit_progress(10, "Stage 0 Complete: Review plan created") - - # === STAGE 1: File Reviews === - self._emit_status("stage_1_started", f"Stage 1: Analyzing {self._count_files(review_plan)} files...") - file_issues = await self._execute_stage_1_file_reviews( - request, review_plan, rag_context, processed_diff, is_incremental - ) - self._emit_progress(60, f"Stage 1 Complete: {len(file_issues)} issues found across files") - - # === STAGE 1.5: Issue Reconciliation === - # Run reconciliation if we have previous issues (both INCREMENTAL and FULL modes) - if request.previousCodeAnalysisIssues: - self._emit_status("reconciliation_started", "Reconciling previous issues...") - file_issues = await self._reconcile_previous_issues( - request, file_issues, processed_diff - ) - self._emit_progress(70, f"Reconciliation Complete: {len(file_issues)} total issues after reconciliation") - - # === STAGE 2: Cross-File Analysis === - self._emit_status("stage_2_started", "Stage 2: Analyzing cross-file patterns...") - cross_file_results = await self._execute_stage_2_cross_file(request, file_issues, review_plan) - self._emit_progress(85, "Stage 2 Complete: Cross-file analysis finished") - - # === STAGE 3: Aggregation === - self._emit_status("stage_3_started", "Stage 3: Generating final report...") - final_report = await self._execute_stage_3_aggregation( - request, review_plan, file_issues, cross_file_results, is_incremental - ) - self._emit_progress(100, "Stage 3 Complete: Report generated") - - # Return structure compatible with existing response expected by frontend/controller - return { - "comment": final_report, - "issues": [issue.model_dump() for issue in file_issues], - } - - except Exception as e: - logger.error(f"Multi-stage review failed: {e}", exc_info=True) - self._emit_error(str(e)) - raise - finally: - # Cleanup PR-indexed data - await self._cleanup_pr_files(request) - - async def _reconcile_previous_issues( - self, - request: ReviewRequestDto, - new_issues: List[CodeReviewIssue], - processed_diff: Optional[ProcessedDiff] = None - ) -> List[CodeReviewIssue]: - """ - Reconcile previous issues with new findings in incremental mode. 
- - Mark resolved issues as isResolved=true - - Update line numbers for persisting issues - - Merge with new issues found in delta diff - - PRESERVE original issue data (reason, suggestedFixDescription, suggestedFixDiff) - """ - if not request.previousCodeAnalysisIssues: - return new_issues - - logger.info(f"Reconciling {len(request.previousCodeAnalysisIssues)} previous issues with {len(new_issues)} new issues") - - # Current commit for resolution tracking - current_commit = request.currentCommitHash or request.commitHash - - # Get the delta diff content to check what files/lines changed - delta_diff = request.deltaDiff or "" - - # Build a set of files that changed in the delta - changed_files_in_delta = set() - if processed_diff: - for f in processed_diff.files: - changed_files_in_delta.add(f.path) - - # Build lookup of previous issues by ID for merging with LLM results - prev_issues_by_id = {} - for prev_issue in request.previousCodeAnalysisIssues: - if hasattr(prev_issue, 'model_dump'): - prev_data = prev_issue.model_dump() - else: - prev_data = prev_issue if isinstance(prev_issue, dict) else vars(prev_issue) - issue_id = prev_data.get('id') - if issue_id: - prev_issues_by_id[str(issue_id)] = prev_data - - reconciled_issues = [] - processed_prev_ids = set() # Track which previous issues we've handled - - # Process new issues from LLM - merge with previous issue data if they reference same ID - for new_issue in new_issues: - new_data = new_issue.model_dump() if hasattr(new_issue, 'model_dump') else new_issue - issue_id = new_data.get('id') - - # If this issue references a previous issue ID, merge data - if issue_id and str(issue_id) in prev_issues_by_id: - prev_data = prev_issues_by_id[str(issue_id)] - processed_prev_ids.add(str(issue_id)) - - # Check if LLM marked it resolved - is_resolved = new_data.get('isResolved', False) - - # PRESERVE original data, use LLM's reason as resolution explanation - merged_issue = CodeReviewIssue( - id=str(issue_id), - severity=(prev_data.get('severity') or prev_data.get('issueSeverity') or 'MEDIUM').upper(), - category=prev_data.get('category') or prev_data.get('issueCategory') or prev_data.get('type') or 'CODE_QUALITY', - file=prev_data.get('file') or prev_data.get('filePath') or new_data.get('file', 'unknown'), - line=str(prev_data.get('line') or prev_data.get('lineNumber') or new_data.get('line', '1')), - # PRESERVE original reason and fix description - reason=prev_data.get('reason') or prev_data.get('title') or prev_data.get('description') or '', - suggestedFixDescription=prev_data.get('suggestedFixDescription') or prev_data.get('suggestedFix') or '', - suggestedFixDiff=prev_data.get('suggestedFixDiff') or None, - isResolved=is_resolved, - # Store LLM's explanation separately if resolved - resolutionExplanation=new_data.get('reason') if is_resolved else None, - resolvedInCommit=current_commit if is_resolved else None, - visibility=prev_data.get('visibility'), - codeSnippet=prev_data.get('codeSnippet') - ) - reconciled_issues.append(merged_issue) - else: - # New issue not referencing previous - keep as is - reconciled_issues.append(new_issue) - - # Process remaining previous issues not handled by LLM - for prev_issue in request.previousCodeAnalysisIssues: - if hasattr(prev_issue, 'model_dump'): - prev_data = prev_issue.model_dump() - else: - prev_data = prev_issue if isinstance(prev_issue, dict) else vars(prev_issue) - - issue_id = prev_data.get('id') - if issue_id and str(issue_id) in processed_prev_ids: - continue # Already handled above - - 
file_path = prev_data.get('file', prev_data.get('filePath', '')) - - # Check if this issue was already found in new issues (by file+line) - already_reported = False - for new_issue in new_issues: - new_data = new_issue.model_dump() if hasattr(new_issue, 'model_dump') else new_issue - if (new_data.get('file') == file_path and - str(new_data.get('line')) == str(prev_data.get('line', prev_data.get('lineNumber')))): - already_reported = True - break - - if already_reported: - continue - - # Preserve all original issue data - persisting_issue = CodeReviewIssue( - id=str(issue_id) if issue_id else None, - severity=(prev_data.get('severity') or prev_data.get('issueSeverity') or 'MEDIUM').upper(), - category=prev_data.get('category') or prev_data.get('issueCategory') or prev_data.get('type') or 'CODE_QUALITY', - file=file_path or prev_data.get('file') or prev_data.get('filePath') or 'unknown', - line=str(prev_data.get('line') or prev_data.get('lineNumber') or '1'), - reason=prev_data.get('reason') or prev_data.get('title') or prev_data.get('description') or '', - suggestedFixDescription=prev_data.get('suggestedFixDescription') or prev_data.get('suggestedFix') or '', - suggestedFixDiff=prev_data.get('suggestedFixDiff') or None, - isResolved=False, - visibility=prev_data.get('visibility'), - codeSnippet=prev_data.get('codeSnippet') - ) - reconciled_issues.append(persisting_issue) - - logger.info(f"Reconciliation complete: {len(reconciled_issues)} total issues") - return reconciled_issues - - def _issue_matches_files(self, issue: Any, file_paths: List[str]) -> bool: - """Check if an issue is related to any of the given file paths.""" - if hasattr(issue, 'model_dump'): - issue_data = issue.model_dump() - elif isinstance(issue, dict): - issue_data = issue - else: - issue_data = vars(issue) if hasattr(issue, '__dict__') else {} - - issue_file = issue_data.get('file', issue_data.get('filePath', '')) - - for fp in file_paths: - if issue_file == fp or issue_file.endswith('/' + fp) or fp.endswith('/' + issue_file): - return True - # Also check basename match - if issue_file.split('/')[-1] == fp.split('/')[-1]: - return True - return False - - def _compute_issue_fingerprint(self, data: dict) -> str: - """Compute a fingerprint for issue deduplication. - - Uses file + normalized line (±3 tolerance) + severity + truncated reason. - """ - file_path = data.get('file', data.get('filePath', '')) - line = data.get('line', data.get('lineNumber', 0)) - line_group = int(line) // 3 if line else 0 - severity = data.get('severity', '') - reason = data.get('reason', data.get('description', '')) - reason_prefix = reason[:50].lower().strip() if reason else '' - - return f"{file_path}::{line_group}::{severity}::{reason_prefix}" - - def _deduplicate_issues(self, issues: List[Any]) -> List[dict]: - """Deduplicate issues by fingerprint, keeping most recent version. - - If an older version is resolved but newer isn't, preserves resolved status. 
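The fingerprint above buckets issues by file, a coarse line bucket (`line // 3`), severity, and the first 50 characters of the reason. A minimal standalone sketch (hypothetical issue dicts, the same logic restated outside the class) of how two reports of the same problem on adjacent lines collapse to a single key:

```python
# Standalone restatement of the fingerprint logic above; issue_a/issue_b are hypothetical examples.
def fingerprint(issue: dict) -> str:
    line = issue.get("line") or 0
    line_group = int(line) // 3 if line else 0
    reason = (issue.get("reason") or "")[:50].lower().strip()
    return f"{issue.get('file', '')}::{line_group}::{issue.get('severity', '')}::{reason}"

issue_a = {"file": "service/api.py", "line": 13, "severity": "HIGH",
           "reason": "Unvalidated user input passed to SQL query"}
issue_b = {"file": "service/api.py", "line": 14, "severity": "HIGH",
           "reason": "Unvalidated user input passed to SQL query"}

# Same file, same severity, same reason prefix, and 13 // 3 == 14 // 3 == 4 -> one fingerprint.
assert fingerprint(issue_a) == fingerprint(issue_b)
```

Note that the bucketing is approximate rather than a true ±3 tolerance: lines 14 and 15 fall into different buckets (14 // 3 == 4, 15 // 3 == 5), so near-duplicates on a bucket boundary can occasionally survive deduplication.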
- """ - if not issues: - return [] - - deduped: dict = {} - - for issue in issues: - if hasattr(issue, 'model_dump'): - data = issue.model_dump() - elif isinstance(issue, dict): - data = issue.copy() - else: - data = vars(issue).copy() if hasattr(issue, '__dict__') else {} - - fingerprint = self._compute_issue_fingerprint(data) - existing = deduped.get(fingerprint) - - if existing is None: - deduped[fingerprint] = data - else: - existing_version = existing.get('prVersion') or 0 - current_version = data.get('prVersion') or 0 - existing_resolved = existing.get('status', '').lower() == 'resolved' - current_resolved = data.get('status', '').lower() == 'resolved' - - if current_version > existing_version: - # Current is newer - if existing_resolved and not current_resolved: - # Preserve resolved status from older version - data['status'] = 'resolved' - data['resolvedDescription'] = existing.get('resolvedDescription') - data['resolvedByCommit'] = existing.get('resolvedByCommit') - data['resolvedInPrVersion'] = existing.get('resolvedInPrVersion') - deduped[fingerprint] = data - elif current_version == existing_version: - # Same version - prefer resolved - if current_resolved and not existing_resolved: - deduped[fingerprint] = data - - return list(deduped.values()) - - def _format_previous_issues_for_batch(self, issues: List[Any]) -> str: - """Format previous issues for inclusion in batch prompt. - - Deduplicates issues first, then formats with resolution tracking so LLM knows: - - Which issues were previously found - - Which have been resolved (and how) - - Which PR version each issue was found/resolved in - """ - if not issues: - return "" - - # Deduplicate issues to avoid confusing the LLM with duplicates - deduped_issues = self._deduplicate_issues(issues) - - # Separate OPEN and RESOLVED issues - open_issues = [i for i in deduped_issues if i.get('status', '').lower() != 'resolved'] - resolved_issues = [i for i in deduped_issues if i.get('status', '').lower() == 'resolved'] - - lines = ["=== PREVIOUS ISSUES HISTORY (check if resolved/persisting) ==="] - lines.append("Issues have been deduplicated. 
Only check OPEN issues - RESOLVED ones are for context only.") - lines.append("") - - if open_issues: - lines.append("--- OPEN ISSUES (check if now fixed) ---") - for data in open_issues: - issue_id = data.get('id', 'unknown') - severity = data.get('severity', 'MEDIUM') - file_path = data.get('file', data.get('filePath', 'unknown')) - line = data.get('line', data.get('lineNumber', '?')) - reason = data.get('reason', data.get('description', 'No description')) - pr_version = data.get('prVersion', '?') - - lines.append(f"[ID:{issue_id}] {severity} @ {file_path}:{line} (v{pr_version})") - lines.append(f" Issue: {reason}") - lines.append("") - - if resolved_issues: - lines.append("--- RESOLVED ISSUES (for context only, do NOT re-report) ---") - for data in resolved_issues: - issue_id = data.get('id', 'unknown') - severity = data.get('severity', 'MEDIUM') - file_path = data.get('file', data.get('filePath', 'unknown')) - line = data.get('line', data.get('lineNumber', '?')) - reason = data.get('reason', data.get('description', 'No description')) - pr_version = data.get('prVersion', '?') - resolved_desc = data.get('resolvedDescription', '') - resolved_in = data.get('resolvedInPrVersion', '') - - lines.append(f"[ID:{issue_id}] {severity} @ {file_path}:{line} (v{pr_version}) - RESOLVED") - if resolved_desc: - lines.append(f" Resolution: {resolved_desc}") - if resolved_in: - lines.append(f" Resolved in: v{resolved_in}") - lines.append(f" Original issue: {reason}") - lines.append("") - - lines.append("INSTRUCTIONS:") - lines.append("- For OPEN issues that are now FIXED: report with 'isResolved': true (boolean)") - lines.append("- For OPEN issues still present: report with 'isResolved': false (boolean)") - lines.append("- Do NOT re-report RESOLVED issues - they are only shown for context") - lines.append("- IMPORTANT: 'isResolved' MUST be a JSON boolean (true/false), not a string") - lines.append("- Preserve the 'id' field for all issues you report from previous issues") - lines.append("=== END PREVIOUS ISSUES ===") - return "\n".join(lines) - - def _extract_symbols_from_diff(self, diff_content: str) -> List[str]: - """ - Extract potential symbols (identifiers, class names, function names) from diff. - Used to query cross-file context for related changes. 
- """ - if not diff_content: - return [] - - # Common language keywords/stop-words to filter out - STOP_WORDS = { - # Python - 'import', 'from', 'class', 'def', 'return', 'if', 'else', 'elif', - 'for', 'while', 'try', 'except', 'finally', 'with', 'as', 'pass', - 'break', 'continue', 'raise', 'yield', 'lambda', 'async', 'await', - 'True', 'False', 'None', 'and', 'or', 'not', 'in', 'is', - # Java/TS/JS - 'public', 'private', 'protected', 'static', 'final', 'void', - 'new', 'this', 'super', 'extends', 'implements', 'interface', - 'abstract', 'const', 'let', 'var', 'function', 'export', 'default', - 'throw', 'throws', 'catch', 'instanceof', 'typeof', 'null', - # Common - 'true', 'false', 'null', 'undefined', 'self', 'args', 'kwargs', - 'string', 'number', 'boolean', 'object', 'array', 'list', 'dict', - } - - symbols = set() - - # Patterns for common identifiers - # Match CamelCase identifiers (likely class/component names) - camel_case = re.findall(r'\b([A-Z][a-z]+[A-Z][a-zA-Z]*)\b', diff_content) - symbols.update(camel_case) - - # Match snake_case identifiers (variables, functions) - snake_case = re.findall(r'\b([a-z][a-z0-9]*(?:_[a-z0-9]+)+)\b', diff_content) - symbols.update(s for s in snake_case if len(s) > 5) # Filter short ones - - # Match assignments and function calls - assignments = re.findall(r'\b(\w+)\s*[=:]\s*', diff_content) - symbols.update(a for a in assignments if len(a) > 3) - - # Match import statements - imports = re.findall(r'(?:from|import)\s+([a-zA-Z_][a-zA-Z0-9_.]+)', diff_content) - symbols.update(imports) - - # Filter out stop-words and return - filtered = [s for s in symbols if s.lower() not in STOP_WORDS and len(s) > 2] - return filtered[:20] # Limit to top 20 symbols - - def _extract_diff_snippets(self, diff_content: str) -> List[str]: - """ - Extract meaningful code snippets from diff content for RAG semantic search. - Focuses on added/modified lines that represent significant code changes. - """ - if not diff_content: - return [] - - snippets = [] - current_snippet_lines = [] - - for line in diff_content.splitlines(): - # Focus on added lines (new code) - if line.startswith("+") and not line.startswith("+++"): - clean_line = line[1:].strip() - # Skip trivial lines - if (clean_line and - len(clean_line) > 10 and # Minimum meaningful length - not clean_line.startswith("//") and # Skip comments - not clean_line.startswith("#") and - not clean_line.startswith("*") and - not clean_line == "{" and - not clean_line == "}" and - not clean_line == ""): - current_snippet_lines.append(clean_line) - - # Batch into snippets of 3-5 lines - if len(current_snippet_lines) >= 3: - snippets.append(" ".join(current_snippet_lines)) - current_snippet_lines = [] - - # Add remaining lines as final snippet - if current_snippet_lines: - snippets.append(" ".join(current_snippet_lines)) - - # Limit to most significant snippets - return snippets[:10] - - def _get_diff_snippets_for_batch( - self, - all_diff_snippets: List[str], - batch_file_paths: List[str] - ) -> List[str]: - """ - Filter diff snippets to only include those relevant to the batch files. - - Note: Java DiffParser.extractDiffSnippets() returns CLEAN CODE SNIPPETS (no file paths). - These snippets are just significant code lines like function signatures. - Since snippets don't contain file paths, we return all snippets for semantic search. - The embedding similarity will naturally prioritize relevant matches. 
- """ - if not all_diff_snippets: - return [] - - # Java snippets are clean code (no file paths), so we can't filter by path - # Return all snippets - the semantic search will find relevant matches - logger.info(f"Using {len(all_diff_snippets)} diff snippets for batch files {batch_file_paths}") - return all_diff_snippets - - async def _execute_stage_0_planning(self, request: ReviewRequestDto, is_incremental: bool = False) -> ReviewPlan: - """ - Stage 0: Analyze metadata and generate a review plan. - Uses structured output for reliable JSON parsing. - """ - # Prepare context for prompt - changed_files_summary = [] - if request.changedFiles: - for f in request.changedFiles: - changed_files_summary.append({ - "path": f, - "type": "MODIFIED", - "lines_added": "?", - "lines_deleted": "?" - }) - - prompt = PromptBuilder.build_stage_0_planning_prompt( - repo_slug=request.projectVcsRepoSlug, - pr_id=str(request.pullRequestId), - pr_title=request.prTitle or "", - author="Unknown", - branch_name="source-branch", - target_branch=request.targetBranchName or "main", - commit_hash=request.commitHash or "HEAD", - changed_files_json=json.dumps(changed_files_summary, indent=2) - ) - - # Stage 0 uses direct LLM call (no tools needed - all metadata is provided) - try: - structured_llm = self.llm.with_structured_output(ReviewPlan) - result = await structured_llm.ainvoke(prompt) - if result: - logger.info("Stage 0 planning completed with structured output") - return result - except Exception as e: - logger.warning(f"Structured output failed for Stage 0: {e}") - - # Fallback to manual parsing - try: - response = await self.llm.ainvoke(prompt) - content = extract_llm_response_text(response) - return await self._parse_response(content, ReviewPlan) - except Exception as e: - logger.error(f"Stage 0 planning failed: {e}") - raise ValueError(f"Stage 0 planning failed: {e}") - - def _chunk_files(self, file_groups: List[Any], max_files_per_batch: int = 5) -> List[List[Dict[str, Any]]]: - """Flatten file groups and chunk into batches.""" - all_files = [] - for group in file_groups: - for f in group.files: - # Attach priority context for the review - all_files.append({ - "file": f, - "priority": group.priority - }) - - return [all_files[i:i + max_files_per_batch] for i in range(0, len(all_files), max_files_per_batch)] - - async def _execute_stage_1_file_reviews( - self, - request: ReviewRequestDto, - plan: ReviewPlan, - rag_context: Optional[Dict[str, Any]] = None, - processed_diff: Optional[ProcessedDiff] = None, - is_incremental: bool = False - ) -> List[CodeReviewIssue]: - """ - Stage 1: Execute batch file reviews with per-batch RAG context. 
-        """
-        # Use smaller batches (5 files max) for better RAG relevance and review quality
-        batches = self._chunk_files(plan.file_groups, max_files_per_batch=5)
-
-        total_files = sum(len(batch) for batch in batches)
-        logger.info(f"Stage 1: Processing {total_files} files in {len(batches)} batches (max 5 files/batch)")
-
-        # Process batches with controlled parallelism
-        all_issues = []
-        batch_results = []
-
-        # Process in waves to avoid rate limits
-        for wave_start in range(0, len(batches), self.max_parallel_stage_1):
-            wave_end = min(wave_start + self.max_parallel_stage_1, len(batches))
-            wave_batches = batches[wave_start:wave_end]
-
-            logger.info(f"Stage 1: Processing wave {wave_start // self.max_parallel_stage_1 + 1}, "
-                        f"batches {wave_start + 1}-{wave_end} of {len(batches)}")
-
-            tasks = []
-            for batch_idx, batch in enumerate(wave_batches, start=wave_start + 1):
-                batch_paths = [item["file"].path for item in batch]
-                logger.debug(f"Batch {batch_idx}: {batch_paths}")
-                tasks.append(self._review_file_batch(
-                    request, batch, processed_diff, is_incremental,
-                    fallback_rag_context=rag_context
-                ))
-
-            results = await asyncio.gather(*tasks, return_exceptions=True)
-
-            for idx, res in enumerate(results):
-                batch_num = wave_start + idx + 1
-                if isinstance(res, Exception):
-                    logger.error(f"Error reviewing batch {batch_num}: {res}")
-                elif res:
-                    logger.info(f"Batch {batch_num} completed: {len(res)} issues found")
-                    all_issues.extend(res)
-                else:
-                    logger.info(f"Batch {batch_num} completed: no issues found")
-
-            # Update progress
-            progress = 10 + int((wave_end / len(batches)) * 50)
-            self._emit_progress(progress, f"Stage 1: Reviewed {wave_end}/{len(batches)} batches")
-
-        logger.info(f"Stage 1 Complete: {len(all_issues)} issues found across {total_files} files")
-        return all_issues
-
-    async def _fetch_batch_rag_context(
-        self,
-        request: ReviewRequestDto,
-        batch_file_paths: List[str],
-        batch_diff_snippets: List[str]
-    ) -> Optional[Dict[str, Any]]:
-        """
-        Fetch RAG context specifically for this batch of files.
-        Uses batch file paths and diff snippets for targeted semantic search.
-
-        In hybrid mode (when PR files are indexed), passes pr_number to enable
-        queries that prioritize fresh PR data over potentially stale branch data
- """ - if not self.rag_client: - return None - - try: - # Determine branch for RAG query - rag_branch = request.targetBranchName or request.commitHash or "main" - - logger.info(f"Fetching per-batch RAG context for {len(batch_file_paths)} files") - - # Use hybrid mode if PR files were indexed - pr_number = request.pullRequestId if self._pr_indexed else None - all_pr_files = request.changedFiles if self._pr_indexed else None - - rag_response = await self.rag_client.get_pr_context( - workspace=request.projectWorkspace, - project=request.projectNamespace, - branch=rag_branch, - changed_files=batch_file_paths, - diff_snippets=batch_diff_snippets, - pr_title=request.prTitle, - pr_description=request.prDescription, - top_k=10, # Fewer chunks per batch for focused context - pr_number=pr_number, - all_pr_changed_files=all_pr_files - ) - - if rag_response and rag_response.get("context"): - context = rag_response.get("context") - chunk_count = len(context.get("relevant_code", [])) - logger.info(f"Per-batch RAG: retrieved {chunk_count} chunks for files {batch_file_paths}") - return context - - return None - - except Exception as e: - logger.warning(f"Failed to fetch per-batch RAG context: {e}") - return None - - async def _review_file_batch( - self, - request: ReviewRequestDto, - batch_items: List[Dict[str, Any]], - processed_diff: Optional[ProcessedDiff] = None, - is_incremental: bool = False, - fallback_rag_context: Optional[Dict[str, Any]] = None - ) -> List[CodeReviewIssue]: - """ - Review a batch of files in a single LLM call with per-batch RAG context. - In incremental mode, uses delta diff and focuses on new changes only. - """ - batch_files_data = [] - batch_file_paths = [] - batch_diff_snippets = [] - #TODO: Project custom rules - project_rules = "" - - # For incremental mode, use deltaDiff instead of full diff - diff_source = None - if is_incremental and request.deltaDiff: - # Parse delta diff to extract per-file diffs - diff_source = DiffProcessor().process(request.deltaDiff) if request.deltaDiff else None - else: - diff_source = processed_diff - - # Collect file paths, diffs, and extract snippets for this batch - for item in batch_items: - file_info = item["file"] - batch_file_paths.append(file_info.path) - - # Extract diff from the appropriate source (delta for incremental, full for initial) - file_diff = "" - if diff_source: - for f in diff_source.files: - if f.path == file_info.path or f.path.endswith("/" + file_info.path): - file_diff = f.content - # Extract code snippets from diff for RAG semantic search - if file_diff: - batch_diff_snippets.extend(self._extract_diff_snippets(file_diff)) - break - - batch_files_data.append({ - "path": file_info.path, - "type": "MODIFIED", - "focus_areas": file_info.focus_areas, - "old_code": "", - "diff": file_diff or "(Diff unavailable)", - "is_incremental": is_incremental # Pass mode to prompt builder - }) - - # Fetch per-batch RAG context using batch-specific files and diff snippets - rag_context_text = "" - batch_rag_context = None - - if self.rag_client: - batch_rag_context = await self._fetch_batch_rag_context( - request, batch_file_paths, batch_diff_snippets - ) - - # Use batch-specific RAG context if available, otherwise fall back to initial context - # Hybrid mode: PR-indexed data is already included via _fetch_batch_rag_context - if batch_rag_context: - logger.info(f"Using per-batch RAG context for: {batch_file_paths}") - rag_context_text = self._format_rag_context( - batch_rag_context, - set(batch_file_paths), - 
pr_changed_files=request.changedFiles - ) - elif fallback_rag_context: - logger.info(f"Using fallback RAG context for batch: {batch_file_paths}") - rag_context_text = self._format_rag_context( - fallback_rag_context, - set(batch_file_paths), - pr_changed_files=request.changedFiles - ) - - logger.info(f"RAG context for batch: {len(rag_context_text)} chars") - - # For incremental mode, filter previous issues relevant to this batch - # Also pass previous issues in FULL mode if they exist (subsequent PR iterations) - previous_issues_for_batch = "" - has_previous_issues = request.previousCodeAnalysisIssues and len(request.previousCodeAnalysisIssues) > 0 - if has_previous_issues: - relevant_prev_issues = [ - issue for issue in request.previousCodeAnalysisIssues - if self._issue_matches_files(issue, batch_file_paths) - ] - if relevant_prev_issues: - previous_issues_for_batch = self._format_previous_issues_for_batch(relevant_prev_issues) - - # Build ONE prompt for the batch - prompt = PromptBuilder.build_stage_1_batch_prompt( - files=batch_files_data, - priority=batch_items[0]["priority"] if batch_items else "MEDIUM", - project_rules=project_rules, - rag_context=rag_context_text, - is_incremental=is_incremental, - previous_issues=previous_issues_for_batch - ) - - # Stage 1 uses direct LLM call (no tools needed - diff is already provided) - try: - # Try structured output first - structured_llm = self.llm.with_structured_output(FileReviewBatchOutput) - result = await structured_llm.ainvoke(prompt) - if result: - all_batch_issues = [] - for review in result.reviews: - all_batch_issues.extend(review.issues) - return all_batch_issues - except Exception as e: - logger.warning(f"Structured output failed for Stage 1 batch: {e}") - - # Fallback to manual parsing - try: - response = await self.llm.ainvoke(prompt) - content = extract_llm_response_text(response) - data = await self._parse_response(content, FileReviewBatchOutput) - all_batch_issues = [] - for review in data.reviews: - all_batch_issues.extend(review.issues) - return all_batch_issues - except Exception as parse_err: - logger.error(f"Batch review failed: {parse_err}") - return [] - - return [] - - async def _execute_stage_2_cross_file( - self, - request: ReviewRequestDto, - stage_1_issues: List[CodeReviewIssue], - plan: ReviewPlan - ) -> CrossFileAnalysisResult: - """ - Stage 2: Cross-file analysis. 
- """ - # Serialize Stage 1 findings - issues_json = json.dumps([i.model_dump() for i in stage_1_issues], indent=2) - - prompt = PromptBuilder.build_stage_2_cross_file_prompt( - repo_slug=request.projectVcsRepoSlug, - pr_title=request.prTitle or "", - commit_hash=request.commitHash or "HEAD", - stage_1_findings_json=issues_json, - architecture_context="(Architecture context from MCP or knowledge base)", - migrations="(Migration scripts found in PR)", - cross_file_concerns=plan.cross_file_concerns - ) - - # Stage 2 uses direct LLM call (no tools needed - all data is provided from Stage 1) - try: - structured_llm = self.llm.with_structured_output(CrossFileAnalysisResult) - result = await structured_llm.ainvoke(prompt) - if result: - logger.info("Stage 2 cross-file analysis completed with structured output") - return result - except Exception as e: - logger.warning(f"Structured output failed for Stage 2: {e}") - - # Fallback to manual parsing - try: - response = await self.llm.ainvoke(prompt) - content = extract_llm_response_text(response) - return await self._parse_response(content, CrossFileAnalysisResult) - except Exception as e: - logger.error(f"Stage 2 cross-file analysis failed: {e}") - raise - - async def _execute_stage_3_aggregation( - self, - request: ReviewRequestDto, - plan: ReviewPlan, - stage_1_issues: List[CodeReviewIssue], - stage_2_results: CrossFileAnalysisResult, - is_incremental: bool = False - ) -> str: - """ - Stage 3: Generate Markdown report. - In incremental mode, includes summary of resolved vs new issues. - """ - stage_1_json = json.dumps([i.model_dump() for i in stage_1_issues], indent=2) - stage_2_json = stage_2_results.model_dump_json(indent=2) - plan_json = plan.model_dump_json(indent=2) - - # Add incremental context to aggregation - incremental_context = "" - if is_incremental: - resolved_count = sum(1 for i in stage_1_issues if i.isResolved) - new_count = len(stage_1_issues) - resolved_count - previous_count = len(request.previousCodeAnalysisIssues or []) - incremental_context = f""" -## INCREMENTAL REVIEW SUMMARY -- Previous issues from last review: {previous_count} -- Issues resolved in this update: {resolved_count} -- New issues found in delta: {new_count} -- Total issues after reconciliation: {len(stage_1_issues)} -""" - - prompt = PromptBuilder.build_stage_3_aggregation_prompt( - repo_slug=request.projectVcsRepoSlug, - pr_id=str(request.pullRequestId), - author="Unknown", - pr_title=request.prTitle or "", - total_files=len(request.changedFiles or []), - additions=0, # Need accurate stats - deletions=0, - stage_0_plan=plan_json, - stage_1_issues_json=stage_1_json, - stage_2_findings_json=stage_2_json, - recommendation=stage_2_results.pr_recommendation - ) - - response = await self.llm.ainvoke(prompt) - return extract_llm_response_text(response) - - async def _parse_response(self, content: str, model_class: Any, retries: int = 2) -> Any: - """ - Robustly parse JSON response into a Pydantic model with retries. - Falls back to manual parsing if structured output wasn't used. 
- """ - last_error = None - - # Initial cleaning attempt - try: - cleaned = self._clean_json_text(content) - logger.debug(f"Cleaned JSON for {model_class.__name__} (first 500 chars): {cleaned[:500]}") - data = json.loads(cleaned) - return model_class(**data) - except Exception as e: - last_error = e - logger.warning(f"Initial parse failed for {model_class.__name__}: {e}") - logger.debug(f"Raw content (first 1000 chars): {content[:1000]}") - - # Retry with structured output if available - try: - logger.info(f"Attempting structured output retry for {model_class.__name__}") - structured_llm = self.llm.with_structured_output(model_class) - result = await structured_llm.ainvoke( - f"Parse and return this as valid {model_class.__name__}:\n{content[:4000]}" - ) - if result: - logger.info(f"Structured output retry succeeded for {model_class.__name__}") - return result - except Exception as e: - logger.warning(f"Structured output retry failed: {e}") - last_error = e - - # Final fallback: LLM repair loop - for attempt in range(retries): - try: - logger.info(f"Repairing JSON for {model_class.__name__}, attempt {attempt+1}") - repaired = await self._repair_json_with_llm( - content, - str(last_error), - model_class.model_json_schema() - ) - cleaned = self._clean_json_text(repaired) - logger.debug(f"Repaired JSON attempt {attempt+1} (first 500 chars): {cleaned[:500]}") - data = json.loads(cleaned) - return model_class(**data) - except Exception as e: - last_error = e - logger.warning(f"Retry {attempt+1} failed: {e}") - - raise ValueError(f"Failed to parse {model_class.__name__} after retries: {last_error}") - - async def _repair_json_with_llm(self, broken_json: str, error: str, schema: Any) -> str: - """ - Ask LLM to repair malformed JSON. - """ - # Truncate the broken JSON to avoid token limits but show enough context - truncated_json = broken_json[:3000] if len(broken_json) > 3000 else broken_json - - prompt = f"""You are a JSON repair expert. -The following JSON failed to parse/validate: -Error: {error} - -Broken JSON: -{truncated_json} - -Required Schema (the output MUST be a JSON object, not an array): -{json.dumps(schema, indent=2)} - -CRITICAL INSTRUCTIONS: -1. Return ONLY the fixed valid JSON object -2. The response MUST start with {{ and end with }} -3. All property names MUST be enclosed in double quotes -4. No markdown code blocks (no ```) -5. No explanatory text before or after the JSON -6. Ensure all required fields from the schema are present - -Output the corrected JSON object now:""" - response = await self.llm.ainvoke(prompt) - return extract_llm_response_text(response) - - def _clean_json_text(self, text: str) -> str: - """ - Clean markdown and extraneous text from JSON. 
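The core recovery step, after any markdown fences are stripped by the method below, is slicing the reply to its outermost braces and parsing that span. A minimal sketch with a hypothetical chatty LLM reply (no fences, so only the brace slicing applies):

```python
# Simplified sketch of the brace-slicing fallback; the reply string is hypothetical.
import json

reply = 'Sure! Here is the review: {"issues": [], "comment": "LGTM"} Let me know if you need more.'

start, end = reply.find("{"), reply.rfind("}")
if start != -1 and end != -1:
    data = json.loads(reply[start:end + 1])  # keep only the outermost JSON object
    print(data["comment"])                   # -> LGTM
```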
- """ - text = text.strip() - - # Remove markdown code blocks - if text.startswith("```"): - lines = text.split("\n") - # Skip the opening ``` line (with or without language identifier) - lines = lines[1:] - # Remove trailing ``` if present - if lines and lines[-1].strip() == "```": - lines = lines[:-1] - text = "\n".join(lines).strip() - - # Also handle case where ``` appears mid-text - if "```json" in text: - start_idx = text.find("```json") - end_idx = text.find("```", start_idx + 7) - if end_idx != -1: - text = text[start_idx + 7:end_idx].strip() - else: - text = text[start_idx + 7:].strip() - elif "```" in text: - # Generic code block without language - start_idx = text.find("```") - remaining = text[start_idx + 3:] - end_idx = remaining.find("```") - if end_idx != -1: - text = remaining[:end_idx].strip() - else: - text = remaining.strip() - - # Find JSON object boundaries - obj_start = text.find("{") - obj_end = text.rfind("}") - arr_start = text.find("[") - arr_end = text.rfind("]") - - # Determine if we have an object or array (whichever comes first) - if obj_start != -1 and obj_end != -1: - if arr_start == -1 or obj_start < arr_start: - # Object comes first or no array - text = text[obj_start:obj_end+1] - elif arr_start < obj_start and arr_end != -1: - # Array comes first - but we need an object for Pydantic - # Check if the object is nested inside the array or separate - if obj_end > arr_end: - # Object extends beyond array - likely the object we want - text = text[obj_start:obj_end+1] - else: - # Try to use the object anyway - text = text[obj_start:obj_end+1] - elif arr_start != -1 and arr_end != -1 and obj_start == -1: - # Only array found - log warning as Pydantic models expect objects - logger.warning(f"JSON cleaning found array instead of object, this may fail parsing") - text = text[arr_start:arr_end+1] - - return text - - def _format_rag_context( - self, - rag_context: Optional[Dict[str, Any]], - relevant_files: Optional[set] = None, - pr_changed_files: Optional[List[str]] = None - ) -> str: - """ - Format RAG context into a readable string for the prompt. - - IMPORTANT: We trust RAG's semantic similarity scores for relevance. - The RAG system already uses embeddings to find semantically related code. - We only filter out chunks from files being modified in the PR (stale data from main branch). 
- - Args: - rag_context: RAG response with code chunks - relevant_files: (UNUSED - kept for API compatibility) - we trust RAG scores instead - pr_changed_files: Files modified in the PR - chunks from these may be stale - """ - if not rag_context: - logger.debug("RAG context is empty or None") - return "" - - # Handle both "chunks" and "relevant_code" keys (RAG API uses "relevant_code") - chunks = rag_context.get("relevant_code", []) or rag_context.get("chunks", []) - if not chunks: - logger.debug("No chunks found in RAG context (keys: %s)", list(rag_context.keys())) - return "" - - logger.info(f"Processing {len(chunks)} RAG chunks (trusting semantic similarity scores)") - - # Normalize PR changed files for stale-data detection only - pr_changed_set = set() - if pr_changed_files: - for f in pr_changed_files: - pr_changed_set.add(f) - if "/" in f: - pr_changed_set.add(f.rsplit("/", 1)[-1]) - - formatted_parts = [] - included_count = 0 - skipped_stale = 0 - - for chunk in chunks: - if included_count >= 15: - logger.debug(f"Reached chunk limit of 15") - break - - metadata = chunk.get("metadata", {}) - path = metadata.get("path", chunk.get("path", "unknown")) - chunk_type = metadata.get("content_type", metadata.get("type", "code")) - score = chunk.get("score", chunk.get("relevance_score", 0)) - - # Only filter: chunks from PR-modified files with LOW scores (likely stale) - # High-score chunks from modified files may still be relevant (other parts of same file) - if pr_changed_set: - path_filename = path.rsplit("/", 1)[-1] if "/" in path else path - is_from_modified_file = ( - path in pr_changed_set or - path_filename in pr_changed_set or - any(path.endswith(f) or f.endswith(path) for f in pr_changed_set) - ) - - # Skip ONLY low-score chunks from modified files (likely stale/outdated) - if is_from_modified_file and score < 0.70: - logger.debug(f"Skipping stale chunk from modified file: {path} (score={score:.2f})") - skipped_stale += 1 - continue - - text = chunk.get("text", chunk.get("content", "")) - if not text: - continue - - included_count += 1 - - # Build rich metadata context - meta_lines = [f"File: {path}"] - - if metadata.get("namespace"): - meta_lines.append(f"Namespace: {metadata['namespace']}") - elif metadata.get("package"): - meta_lines.append(f"Package: {metadata['package']}") - - if metadata.get("primary_name"): - meta_lines.append(f"Definition: {metadata['primary_name']}") - elif metadata.get("semantic_names"): - meta_lines.append(f"Definitions: {', '.join(metadata['semantic_names'][:5])}") - - if metadata.get("extends"): - extends = metadata["extends"] - meta_lines.append(f"Extends: {', '.join(extends) if isinstance(extends, list) else extends}") - - if metadata.get("implements"): - implements = metadata["implements"] - meta_lines.append(f"Implements: {', '.join(implements) if isinstance(implements, list) else implements}") - - if metadata.get("imports"): - imports = metadata["imports"] - if isinstance(imports, list): - if len(imports) <= 5: - meta_lines.append(f"Imports: {'; '.join(imports)}") - else: - meta_lines.append(f"Imports: {'; '.join(imports[:5])}... 
(+{len(imports)-5} more)") - - if metadata.get("parent_context"): - parent_ctx = metadata["parent_context"] - if isinstance(parent_ctx, list): - meta_lines.append(f"Parent: {'.'.join(parent_ctx)}") - - if chunk_type and chunk_type != "code": - meta_lines.append(f"Type: {chunk_type}") - - meta_text = "\n".join(meta_lines) - # Use file path as primary identifier, not a number - # This encourages AI to reference by path rather than by chunk number - formatted_parts.append( - f"### Context from `{path}` (relevance: {score:.2f})\n" - f"{meta_text}\n" - f"```\n{text}\n```\n" - ) - - if not formatted_parts: - logger.warning(f"No RAG chunks included (total: {len(chunks)}, skipped_stale: {skipped_stale})") - return "" - - logger.info(f"Included {len(formatted_parts)} RAG chunks (skipped {skipped_stale} stale from modified files)") - return "\n".join(formatted_parts) - - def _emit_status(self, state: str, message: str): - if self.event_callback: - self.event_callback({ - "type": "status", - "state": state, - "message": message - }) - - def _emit_progress(self, percent: int, message: str): - if self.event_callback: - self.event_callback({ - "type": "progress", - "percent": percent, - "message": message - }) - - def _emit_error(self, message: str): - if self.event_callback: - self.event_callback({ - "type": "error", - "message": message - }) - - def _count_files(self, plan: ReviewPlan) -> int: - count = 0 - for group in plan.file_groups: - count += len(group.files) - return count - - def _ensure_all_files_planned(self, plan: ReviewPlan, all_changed_files: List[str]) -> ReviewPlan: - """ - Ensure all changed files are included in the review plan. - If LLM missed files, add them to a LOW priority group. - """ - # Collect files already in the plan - planned_files = set() - for group in plan.file_groups: - for f in group.files: - planned_files.add(f.path) - - # Also count skipped files - skipped_files = set() - for skip in plan.files_to_skip: - skipped_files.add(skip.path) - - # Find missing files - all_files_set = set(all_changed_files) - missing_files = all_files_set - planned_files - skipped_files - - if missing_files: - logger.warning( - f"Stage 0 plan missing {len(missing_files)} files out of {len(all_changed_files)}. " - f"Adding to LOW priority group." 
- ) - - # Create ReviewFile objects for missing files - missing_review_files = [] - for path in missing_files: - missing_review_files.append(ReviewFile( - path=path, - focus_areas=["GENERAL"], - risk_level="LOW", - estimated_issues=0 - )) - - # Add a new group for missing files or append to existing LOW group - low_group_found = False - for group in plan.file_groups: - if group.priority == "LOW": - group.files.extend(missing_review_files) - low_group_found = True - break - - if not low_group_found: - plan.file_groups.append(FileGroup( - group_id="GROUP_MISSING_FILES", - priority="LOW", - rationale="Files not categorized by planner - added automatically", - files=missing_review_files - )) - - logger.info(f"Plan now includes {self._count_files(plan)} files for review") - else: - logger.info( - f"Stage 0 plan complete: {len(planned_files)} files to review, " - f"{len(skipped_files)} files skipped" - ) - - return plan diff --git a/python-ecosystem/mcp-client/service/pooled_review_service.py b/python-ecosystem/mcp-client/service/pooled_review_service.py deleted file mode 100644 index dce735bb..00000000 --- a/python-ecosystem/mcp-client/service/pooled_review_service.py +++ /dev/null @@ -1,261 +0,0 @@ -""" -Example integration of MCP Process Pool with ReviewService. - -This file shows how to integrate the process pool for improved performance. -The actual integration requires careful testing due to the complexity of -the mcp_use library's MCPAgent/MCPClient. - -IMPORTANT: The mcp_use library creates its own subprocess internally. -To fully benefit from pooling, we need to either: -1. Modify mcp_use to accept an existing process -2. Or bypass mcp_use and communicate directly with our pooled processes - -This example shows approach #2 - direct communication with pooled processes. -""" - -import os -import asyncio -import json -import logging -from typing import Dict, Any, Optional, Callable -from dotenv import load_dotenv - -from utils.mcp_pool import get_mcp_pool, McpProcessPool -from model.models import ReviewRequestDto -from llm.llm_factory import LLMFactory -from utils.prompts.prompt_builder import PromptBuilder -from utils.response_parser import ResponseParser -from service.rag_client import RagClient - -logger = logging.getLogger(__name__) - - -class PooledReviewService: - """ - Review service using process pooling for MCP servers. - - This is an alternative implementation that bypasses mcp_use's internal - process management to use our own process pool. - - Benefits: - - Zero JVM startup overhead after pool warmup - - Shared memory footprint across requests - - Better resource utilization for SaaS deployments - - Trade-offs: - - Requires direct MCP protocol communication - - May not support all mcp_use features - """ - - MAX_FIX_RETRIES = 2 - - def __init__(self): - load_dotenv() - self.default_jar_path = os.environ.get( - "MCP_SERVER_JAR", - "/app/codecrow-vcs-mcp-1.0.jar" - ) - self.rag_client = RagClient() - self._pool: Optional[McpProcessPool] = None - - # Check if pooling is enabled - self.pooling_enabled = os.environ.get("MCP_POOLING_ENABLED", "false").lower() == "true" - - async def _get_pool(self) -> McpProcessPool: - """Get or initialize the process pool.""" - if self._pool is None: - self._pool = await get_mcp_pool(self.default_jar_path) - return self._pool - - async def process_review_request( - self, - request: ReviewRequestDto, - event_callback: Optional[Callable[[Dict], None]] = None - ) -> Dict[str, Any]: - """ - Process a review request. 
- - If pooling is enabled, uses the process pool. - Otherwise, falls back to the original mcp_use implementation. - """ - if self.pooling_enabled: - return await self._process_review_pooled(request, event_callback) - else: - # Fall back to original implementation - from service.review_service import ReviewService - original_service = ReviewService() - return await original_service.process_review_request(request, event_callback) - - async def _process_review_pooled( - self, - request: ReviewRequestDto, - event_callback: Optional[Callable[[Dict], None]] = None - ) -> Dict[str, Any]: - """ - Process review using pooled MCP server process. - - This method communicates directly with the MCP server via STDIO - instead of using mcp_use's internal process spawning. - """ - try: - self._emit_event(event_callback, { - "type": "status", - "state": "started", - "message": "Analysis starting (pooled mode)" - }) - - pool = await self._get_pool() - - self._emit_event(event_callback, { - "type": "status", - "state": "pool_acquired", - "message": "Acquired MCP server from pool" - }) - - async with pool.acquire() as pooled_process: - # Build the prompt and get LLM - pr_metadata = self._build_pr_metadata(request) - rag_context = await self._fetch_rag_context(request, event_callback) - prompt = self._build_prompt(request, pr_metadata, rag_context) - llm = self._create_llm(request) - - self._emit_event(event_callback, { - "type": "status", - "state": "executing", - "message": "Executing analysis with pooled process" - }) - - # Execute the analysis using the pooled process - # This is a simplified version - full implementation would - # need to handle the MCP protocol properly - result = await self._execute_with_pooled_process( - pooled_process, - prompt, - llm, - request, - event_callback - ) - - self._emit_event(event_callback, { - "type": "final", - "result": "Analysis completed" - }) - - return {"result": result} - - except Exception as e: - logger.error(f"Pooled review failed: {e}", exc_info=True) - error_response = ResponseParser.create_error_response( - "Agent execution failed (pooled)", str(e) - ) - self._emit_event(event_callback, { - "type": "error", - "message": str(e) - }) - return {"result": error_response} - - async def _execute_with_pooled_process( - self, - pooled_process, - prompt: str, - llm, - request: ReviewRequestDto, - event_callback: Optional[Callable[[Dict], None]] - ) -> Dict[str, Any]: - """ - Execute analysis using a pooled MCP server process. - - NOTE: This is a placeholder that shows the concept. - Full implementation requires proper MCP protocol handling. - - The mcp_use library's MCPClient handles the protocol, but it - creates its own subprocess. To use pooling, we'd need to either: - - 1. Fork mcp_use to accept external processes - 2. Implement MCP protocol directly - 3. Use the MCP SDK's Python client directly - - For now, this shows the architecture for option 2/3. 
- """ - # MCP JSON-RPC message format - # See: https://modelcontextprotocol.io/docs/spec/protocol - - # Initialize connection - init_request = { - "jsonrpc": "2.0", - "id": 1, - "method": "initialize", - "params": { - "protocolVersion": "2024-11-05", - "capabilities": {}, - "clientInfo": { - "name": "codecrow-mcp-client", - "version": "1.0.0" - } - } - } - - # Send initialization (this is conceptual - actual impl needs proper I/O handling) - # await self._send_mcp_message(pooled_process, init_request) - # response = await self._read_mcp_response(pooled_process) - - # For now, return a placeholder indicating pooling is working - # but full protocol implementation is pending - return { - "summary": "Pooled execution placeholder", - "issues": [], - "note": "Full MCP protocol integration pending. Process pool is working.", - "pool_metrics": (await self._get_pool()).get_metrics() - } - - def _emit_event(self, callback, event): - """Emit event to callback if provided.""" - if callback: - try: - callback(event) - except Exception as e: - logger.warning(f"Event callback error: {e}") - - def _build_pr_metadata(self, request: ReviewRequestDto) -> str: - """Build PR metadata string.""" - return f"PR #{request.pullRequestId} in {request.workspace}/{request.repoSlug}" - - async def _fetch_rag_context(self, request, event_callback) -> Optional[str]: - """Fetch RAG context if enabled.""" - if not request.ragEnabled: - return None - try: - return await self.rag_client.get_context(request) - except Exception as e: - logger.warning(f"RAG context fetch failed: {e}") - return None - - def _build_prompt(self, request, pr_metadata, rag_context) -> str: - """Build the analysis prompt.""" - return PromptBuilder.build_review_prompt( - pr_metadata=pr_metadata, - rag_context=rag_context, - custom_instructions=request.customInstructions - ) - - def _create_llm(self, request): - """Create LLM instance for the request.""" - return LLMFactory.create( - provider=request.aiProvider, - api_key=request.aiProviderApiKey, - model=request.aiModel, - max_tokens=request.maxAllowedTokens - ) - -#TODO: Implement pooling logic -# Example usage in web server -async def create_pooled_service(): - """Create and initialize a pooled review service.""" - service = PooledReviewService() - - # Pre-warm the pool - if service.pooling_enabled: - pool = await service._get_pool() - logger.info(f"MCP pool initialized: {pool.get_metrics()}") - - return service diff --git a/python-ecosystem/mcp-client/service/rag/__init__.py b/python-ecosystem/mcp-client/service/rag/__init__.py new file mode 100644 index 00000000..b25df464 --- /dev/null +++ b/python-ecosystem/mcp-client/service/rag/__init__.py @@ -0,0 +1,19 @@ +""" +RAG Service Package. 
+ +Contains components for RAG (Retrieval Augmented Generation) functionality: +- rag_client: Client for interacting with the RAG Pipeline API +- llm_reranker: LLM-based reranking for RAG results +""" +from service.rag.rag_client import RagClient, RAG_MIN_RELEVANCE_SCORE, RAG_DEFAULT_TOP_K +from service.rag.llm_reranker import LLMReranker, LLM_RERANK_ENABLED, LLM_RERANK_THRESHOLD, RerankResult + +__all__ = [ + "RagClient", + "RAG_MIN_RELEVANCE_SCORE", + "RAG_DEFAULT_TOP_K", + "LLMReranker", + "LLM_RERANK_ENABLED", + "LLM_RERANK_THRESHOLD", + "RerankResult", +] diff --git a/python-ecosystem/mcp-client/service/llm_reranker.py b/python-ecosystem/mcp-client/service/rag/llm_reranker.py similarity index 100% rename from python-ecosystem/mcp-client/service/llm_reranker.py rename to python-ecosystem/mcp-client/service/rag/llm_reranker.py diff --git a/python-ecosystem/mcp-client/service/rag_client.py b/python-ecosystem/mcp-client/service/rag/rag_client.py similarity index 100% rename from python-ecosystem/mcp-client/service/rag_client.py rename to python-ecosystem/mcp-client/service/rag/rag_client.py diff --git a/python-ecosystem/mcp-client/service/review/__init__.py b/python-ecosystem/mcp-client/service/review/__init__.py new file mode 100644 index 00000000..d136b174 --- /dev/null +++ b/python-ecosystem/mcp-client/service/review/__init__.py @@ -0,0 +1,39 @@ +""" +Review Service Package. + +Contains components for code review functionality: +- review_service: Main entry point for review requests +- orchestrator/: Multi-stage review orchestrator (subpackage) +- issue_processor: Issue post-processing and deduplication +""" +from service.review.review_service import ReviewService +from service.review.orchestrator import ( + MultiStageReviewOrchestrator, + RecursiveMCPAgent, + extract_llm_response_text, + parse_llm_response, + clean_json_text, + reconcile_previous_issues, + deduplicate_issues, +) +from service.review.issue_processor import ( + IssuePostProcessor, + IssueDeduplicator, + post_process_analysis_result, + restore_missing_diffs_from_previous, +) + +__all__ = [ + "ReviewService", + "MultiStageReviewOrchestrator", + "IssuePostProcessor", + "IssueDeduplicator", + "post_process_analysis_result", + "restore_missing_diffs_from_previous", + "RecursiveMCPAgent", + "extract_llm_response_text", + "parse_llm_response", + "clean_json_text", + "reconcile_previous_issues", + "deduplicate_issues", +] diff --git a/python-ecosystem/mcp-client/service/issue_post_processor.py b/python-ecosystem/mcp-client/service/review/issue_processor.py similarity index 100% rename from python-ecosystem/mcp-client/service/issue_post_processor.py rename to python-ecosystem/mcp-client/service/review/issue_processor.py diff --git a/python-ecosystem/mcp-client/service/review/orchestrator/__init__.py b/python-ecosystem/mcp-client/service/review/orchestrator/__init__.py new file mode 100644 index 00000000..36f086a5 --- /dev/null +++ b/python-ecosystem/mcp-client/service/review/orchestrator/__init__.py @@ -0,0 +1,25 @@ +""" +Orchestrator Package. 
+ +Multi-stage review orchestrator components: +- orchestrator: Main MultiStageReviewOrchestrator class +- stages: Stage 0-3 execution methods +- reconciliation: Issue reconciliation and deduplication +- context_helpers: RAG context and diff extraction +- json_utils: JSON parsing and repair +- agents: MCP agent with recursion limit support +""" +from service.review.orchestrator.orchestrator import MultiStageReviewOrchestrator +from service.review.orchestrator.agents import RecursiveMCPAgent, extract_llm_response_text +from service.review.orchestrator.json_utils import parse_llm_response, clean_json_text +from service.review.orchestrator.reconciliation import reconcile_previous_issues, deduplicate_issues + +__all__ = [ + "MultiStageReviewOrchestrator", + "RecursiveMCPAgent", + "extract_llm_response_text", + "parse_llm_response", + "clean_json_text", + "reconcile_previous_issues", + "deduplicate_issues", +] diff --git a/python-ecosystem/mcp-client/service/review/orchestrator/agents.py b/python-ecosystem/mcp-client/service/review/orchestrator/agents.py new file mode 100644 index 00000000..859ce6fc --- /dev/null +++ b/python-ecosystem/mcp-client/service/review/orchestrator/agents.py @@ -0,0 +1,81 @@ +""" +Custom MCP Agent with increased recursion limit. +""" +import logging +from typing import Any +from mcp_use import MCPAgent + +logger = logging.getLogger(__name__) + + +def extract_llm_response_text(response: Any) -> str: + """ + Extract text content from LLM response, handling different response formats. + Some LLM providers return content as a list of objects instead of a string. + """ + if hasattr(response, 'content'): + content = response.content + if isinstance(content, list): + # Handle list content (e.g., from Gemini or other providers) + text_parts = [] + for item in content: + if isinstance(item, str): + text_parts.append(item) + elif isinstance(item, dict): + if 'text' in item: + text_parts.append(item['text']) + elif 'content' in item: + text_parts.append(item['content']) + elif hasattr(item, 'text'): + text_parts.append(item.text) + return "".join(text_parts) + return str(content) + return str(response) + + +# Prevent duplicate logs from mcp_use +mcp_logger = logging.getLogger("mcp_use") +mcp_logger.propagate = False +if not mcp_logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + mcp_logger.addHandler(handler) + + +class RecursiveMCPAgent(MCPAgent): + """ + Subclass of MCPAgent that enforces a higher recursion limit on the internal agent executor. + """ + def __init__(self, *args, recursion_limit: int = 50, **kwargs): + self._custom_recursion_limit = recursion_limit + super().__init__(*args, **kwargs) + + async def stream(self, *args, **kwargs): + """ + Override stream to ensure recursion_limit is applied. 
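For reference, a hedged usage sketch of the class defined here: `llm` and `mcp_client` stand in for the LangChain chat model and `MCPClient` that the orchestrator builds elsewhere, and `recursion_limit=80` is an arbitrary value chosen only to illustrate overriding the default of 50.

```python
# Usage sketch, not part of the patch. Assumes llm/mcp_client/output_schema are provided by the caller,
# mirroring how execute_branch_analysis drove the agent in the removed multi_stage_orchestrator.py.
from service.review.orchestrator.agents import RecursiveMCPAgent

async def run_structured_review(llm, mcp_client, prompt, output_schema):
    agent = RecursiveMCPAgent(
        llm=llm,
        client=mcp_client,
        recursion_limit=80,  # assumption: any value above the default of 50 exercises the override
    )
    async for item in agent.stream(prompt, max_steps=15, output_schema=output_schema):
        if isinstance(item, output_schema):
            return item      # structured result, no manual JSON parsing needed
    return None
```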
+ """ + # Ensure the executor exists + if self._agent_executor is None: + await self.initialize() + + # Patch the executor's astream if not already patched + executor = self._agent_executor + if executor and not getattr(executor, "_is_patched_recursion", False): + original_astream = executor.astream + limit = self._custom_recursion_limit + + async def patched_astream(input_data, config=None, **astream_kwargs): + if config is None: + config = {} + config["recursion_limit"] = limit + async for chunk in original_astream(input_data, config=config, **astream_kwargs): + yield chunk + + executor.astream = patched_astream + executor._is_patched_recursion = True + logger.info(f"RecursiveMCPAgent: Patched recursion limit to {limit}") + + # Call parent stream + async for item in super().stream(*args, **kwargs): + yield item diff --git a/python-ecosystem/mcp-client/service/review/orchestrator/context_helpers.py b/python-ecosystem/mcp-client/service/review/orchestrator/context_helpers.py new file mode 100644 index 00000000..b70c3a23 --- /dev/null +++ b/python-ecosystem/mcp-client/service/review/orchestrator/context_helpers.py @@ -0,0 +1,245 @@ +""" +Context and diff extraction helpers for the review orchestrator. +""" +import re +import logging +from typing import Any, Dict, List, Optional, Set + +logger = logging.getLogger(__name__) + + +def extract_symbols_from_diff(diff_content: str) -> List[str]: + """ + Extract potential symbols (identifiers, class names, function names) from diff. + Used to query cross-file context for related changes. + """ + if not diff_content: + return [] + + # Common language keywords/stop-words to filter out + STOP_WORDS = { + # Python + 'import', 'from', 'class', 'def', 'return', 'if', 'else', 'elif', + 'for', 'while', 'try', 'except', 'finally', 'with', 'as', 'pass', + 'break', 'continue', 'raise', 'yield', 'lambda', 'async', 'await', + 'True', 'False', 'None', 'and', 'or', 'not', 'in', 'is', + # Java/TS/JS + 'public', 'private', 'protected', 'static', 'final', 'void', + 'new', 'this', 'super', 'extends', 'implements', 'interface', + 'abstract', 'const', 'let', 'var', 'function', 'export', 'default', + 'throw', 'throws', 'catch', 'instanceof', 'typeof', 'null', + # Common + 'true', 'false', 'null', 'undefined', 'self', 'args', 'kwargs', + 'string', 'number', 'boolean', 'object', 'array', 'list', 'dict', + } + + symbols = set() + + # Patterns for common identifiers + # Match CamelCase identifiers (likely class/component names) + camel_case = re.findall(r'\b([A-Z][a-z]+[A-Z][a-zA-Z]*)\b', diff_content) + symbols.update(camel_case) + + # Match snake_case identifiers (variables, functions) + snake_case = re.findall(r'\b([a-z][a-z0-9]*(?:_[a-z0-9]+)+)\b', diff_content) + symbols.update(s for s in snake_case if len(s) > 5) # Filter short ones + + # Match assignments and function calls + assignments = re.findall(r'\b(\w+)\s*[=:]\s*', diff_content) + symbols.update(a for a in assignments if len(a) > 3) + + # Match import statements + imports = re.findall(r'(?:from|import)\s+([a-zA-Z_][a-zA-Z0-9_.]+)', diff_content) + symbols.update(imports) + + # Filter out stop-words and return + filtered = [s for s in symbols if s.lower() not in STOP_WORDS and len(s) > 2] + return filtered[:20] # Limit to top 20 symbols + + +def extract_diff_snippets(diff_content: str) -> List[str]: + """ + Extract meaningful code snippets from diff content for RAG semantic search. + Focuses on added/modified lines that represent significant code changes. 
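A usage sketch for the symbol extractor above, assuming it is run from inside the `mcp-client` package so the module path resolves; the diff hunk is hypothetical:

```python
# Usage sketch, not part of the patch; shows the kind of identifiers pulled from a small hunk.
from service.review.orchestrator.context_helpers import extract_symbols_from_diff

diff_hunk = """\
+from service.rag.rag_client import RagClient
+class ReviewScheduler:
+    def schedule_incremental_review(self, review_request):
+        max_batch_size = 5
"""

symbols = extract_symbols_from_diff(diff_hunk)
# e.g. ['ReviewScheduler', 'RagClient', 'schedule_incremental_review', 'max_batch_size', ...]
# (built from a set, so order varies; keywords like 'class' and 'def' are filtered out)
print(symbols)
```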
+ """ + if not diff_content: + return [] + + snippets = [] + current_snippet_lines = [] + + for line in diff_content.splitlines(): + # Focus on added lines (new code) + if line.startswith("+") and not line.startswith("+++"): + clean_line = line[1:].strip() + # Skip trivial lines + if (clean_line and + len(clean_line) > 10 and # Minimum meaningful length + not clean_line.startswith("//") and # Skip comments + not clean_line.startswith("#") and + not clean_line.startswith("*") and + not clean_line == "{" and + not clean_line == "}" and + not clean_line == ""): + current_snippet_lines.append(clean_line) + + # Batch into snippets of 3-5 lines + if len(current_snippet_lines) >= 3: + snippets.append(" ".join(current_snippet_lines)) + current_snippet_lines = [] + + # Add remaining lines as final snippet + if current_snippet_lines: + snippets.append(" ".join(current_snippet_lines)) + + # Limit to most significant snippets + return snippets[:10] + + +def get_diff_snippets_for_batch( + all_diff_snippets: List[str], + batch_file_paths: List[str] +) -> List[str]: + """ + Filter diff snippets to only include those relevant to the batch files. + + Note: Java DiffParser.extractDiffSnippets() returns CLEAN CODE SNIPPETS (no file paths). + These snippets are just significant code lines like function signatures. + Since snippets don't contain file paths, we return all snippets for semantic search. + The embedding similarity will naturally prioritize relevant matches. + """ + if not all_diff_snippets: + return [] + + # Java snippets are clean code (no file paths), so we can't filter by path + # Return all snippets - the semantic search will find relevant matches + logger.info(f"Using {len(all_diff_snippets)} diff snippets for batch files {batch_file_paths}") + return all_diff_snippets + + +def format_rag_context( + rag_context: Optional[Dict[str, Any]], + relevant_files: Optional[Set[str]] = None, + pr_changed_files: Optional[List[str]] = None +) -> str: + """ + Format RAG context into a readable string for the prompt. + + IMPORTANT: We trust RAG's semantic similarity scores for relevance. + The RAG system already uses embeddings to find semantically related code. + We only filter out chunks from files being modified in the PR (stale data from main branch). 
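+    Chunks from PR-modified files are dropped only when their similarity score is
+    below 0.70; higher-scoring chunks from those files are kept, since other parts
+    of the same file may still be relevant. A chunk is expected to look roughly
+    like this (hypothetical values):
+
+        {"file_path": "src/service/auth.py", "content": "class AuthService: ...",
+         "score": 0.82, "metadata": {"primary_name": "AuthService"}}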
+ + Args: + rag_context: RAG response with code chunks + relevant_files: (UNUSED - kept for API compatibility) - we trust RAG scores instead + pr_changed_files: Files modified in the PR - chunks from these may be stale + """ + if not rag_context: + logger.debug("RAG context is empty or None") + return "" + + # Handle both "chunks" and "relevant_code" keys (RAG API uses "relevant_code") + chunks = rag_context.get("relevant_code", []) or rag_context.get("chunks", []) + if not chunks: + logger.debug("No chunks found in RAG context (keys: %s)", list(rag_context.keys())) + return "" + + logger.info(f"Processing {len(chunks)} RAG chunks (trusting semantic similarity scores)") + + # Normalize PR changed files for stale-data detection only + pr_changed_set = set() + if pr_changed_files: + for f in pr_changed_files: + pr_changed_set.add(f) + if "/" in f: + pr_changed_set.add(f.rsplit("/", 1)[-1]) + + formatted_parts = [] + included_count = 0 + skipped_stale = 0 + + for chunk in chunks: + if included_count >= 15: + logger.debug(f"Reached chunk limit of 15") + break + + metadata = chunk.get("metadata", {}) + # Support both 'path' and 'file_path' keys (deterministic uses file_path) + path = metadata.get("path") or chunk.get("path") or chunk.get("file_path", "unknown") + chunk_type = metadata.get("content_type", metadata.get("type", "code")) + score = chunk.get("score", chunk.get("relevance_score", 0)) + + # Only filter: chunks from PR-modified files with LOW scores (likely stale) + # High-score chunks from modified files may still be relevant (other parts of same file) + if pr_changed_set: + path_filename = path.rsplit("/", 1)[-1] if "/" in path else path + is_from_modified_file = ( + path in pr_changed_set or + path_filename in pr_changed_set or + any(path.endswith(f) or f.endswith(path) for f in pr_changed_set) + ) + + # Skip ONLY low-score chunks from modified files (likely stale/outdated) + if is_from_modified_file and score < 0.70: + logger.debug(f"Skipping stale chunk from modified file: {path} (score={score:.2f})") + skipped_stale += 1 + continue + + text = chunk.get("text", chunk.get("content", "")) + if not text: + continue + + included_count += 1 + + # Build rich metadata context + meta_lines = [f"File: {path}"] + + if metadata.get("namespace"): + meta_lines.append(f"Namespace: {metadata['namespace']}") + elif metadata.get("package"): + meta_lines.append(f"Package: {metadata['package']}") + + if metadata.get("primary_name"): + meta_lines.append(f"Definition: {metadata['primary_name']}") + elif metadata.get("semantic_names"): + meta_lines.append(f"Definitions: {', '.join(metadata['semantic_names'][:5])}") + + if metadata.get("extends"): + extends = metadata["extends"] + meta_lines.append(f"Extends: {', '.join(extends) if isinstance(extends, list) else extends}") + + if metadata.get("implements"): + implements = metadata["implements"] + meta_lines.append(f"Implements: {', '.join(implements) if isinstance(implements, list) else implements}") + + if metadata.get("imports"): + imports = metadata["imports"] + if isinstance(imports, list): + if len(imports) <= 5: + meta_lines.append(f"Imports: {'; '.join(imports)}") + else: + meta_lines.append(f"Imports: {'; '.join(imports[:5])}... 
(+{len(imports)-5} more)") + + if metadata.get("parent_context"): + parent_ctx = metadata["parent_context"] + if isinstance(parent_ctx, list): + meta_lines.append(f"Parent: {'.'.join(parent_ctx)}") + + if chunk_type and chunk_type != "code": + meta_lines.append(f"Type: {chunk_type}") + + meta_text = "\n".join(meta_lines) + # Use file path as primary identifier, not a number + # This encourages AI to reference by path rather than by chunk number + formatted_parts.append( + f"### Context from `{path}` (relevance: {score:.2f})\n" + f"{meta_text}\n" + f"```\n{text}\n```\n" + ) + + if not formatted_parts: + logger.warning(f"No RAG chunks included (total: {len(chunks)}, skipped_stale: {skipped_stale})") + return "" + + logger.info(f"Included {len(formatted_parts)} RAG chunks (skipped {skipped_stale} stale from modified files)") + return "\n".join(formatted_parts) diff --git a/python-ecosystem/mcp-client/service/review/orchestrator/json_utils.py b/python-ecosystem/mcp-client/service/review/orchestrator/json_utils.py new file mode 100644 index 00000000..82e15b03 --- /dev/null +++ b/python-ecosystem/mcp-client/service/review/orchestrator/json_utils.py @@ -0,0 +1,155 @@ +""" +JSON parsing, repair, and cleaning utilities for LLM responses. +""" +import json +import logging +from typing import Any, Dict, Optional + +from service.review.orchestrator.agents import extract_llm_response_text + +logger = logging.getLogger(__name__) + + +async def parse_llm_response(content: str, model_class: Any, llm, retries: int = 2) -> Any: + """ + Robustly parse JSON response into a Pydantic model with retries. + Falls back to manual parsing if structured output wasn't used. + """ + last_error = None + + # Initial cleaning attempt + try: + cleaned = clean_json_text(content) + logger.debug(f"Cleaned JSON for {model_class.__name__} (first 500 chars): {cleaned[:500]}") + data = json.loads(cleaned) + return model_class(**data) + except Exception as e: + last_error = e + logger.warning(f"Initial parse failed for {model_class.__name__}: {e}") + logger.debug(f"Raw content (first 1000 chars): {content[:1000]}") + + # Retry with structured output if available + try: + logger.info(f"Attempting structured output retry for {model_class.__name__}") + structured_llm = llm.with_structured_output(model_class) + result = await structured_llm.ainvoke( + f"Parse and return this as valid {model_class.__name__}:\n{content[:4000]}" + ) + if result: + logger.info(f"Structured output retry succeeded for {model_class.__name__}") + return result + except Exception as e: + logger.warning(f"Structured output retry failed: {e}") + last_error = e + + # Final fallback: LLM repair loop + for attempt in range(retries): + try: + logger.info(f"Repairing JSON for {model_class.__name__}, attempt {attempt+1}") + repaired = await repair_json_with_llm( + llm, + content, + str(last_error), + model_class.model_json_schema() + ) + cleaned = clean_json_text(repaired) + logger.debug(f"Repaired JSON attempt {attempt+1} (first 500 chars): {cleaned[:500]}") + data = json.loads(cleaned) + return model_class(**data) + except Exception as e: + last_error = e + logger.warning(f"Retry {attempt+1} failed: {e}") + + raise ValueError(f"Failed to parse {model_class.__name__} after retries: {last_error}") + + +async def repair_json_with_llm(llm, broken_json: str, error: str, schema: Any) -> str: + """ + Ask LLM to repair malformed JSON. 
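+    The broken payload is truncated to 3000 characters and sent back to the model
+    together with the parse/validation error and the target JSON schema; the reply
+    is returned as plain text for another cleaning and parsing pass by the caller.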
+ """ + # Truncate the broken JSON to avoid token limits but show enough context + truncated_json = broken_json[:3000] if len(broken_json) > 3000 else broken_json + + prompt = f"""You are a JSON repair expert. +The following JSON failed to parse/validate: +Error: {error} + +Broken JSON: +{truncated_json} + +Required Schema (the output MUST be a JSON object, not an array): +{json.dumps(schema, indent=2)} + +CRITICAL INSTRUCTIONS: +1. Return ONLY the fixed valid JSON object +2. The response MUST start with {{ and end with }} +3. All property names MUST be enclosed in double quotes +4. No markdown code blocks (no ```) +5. No explanatory text before or after the JSON +6. Ensure all required fields from the schema are present + +Output the corrected JSON object now:""" + response = await llm.ainvoke(prompt) + return extract_llm_response_text(response) + + +def clean_json_text(text: str) -> str: + """ + Clean markdown and extraneous text from JSON. + """ + text = text.strip() + + # Remove markdown code blocks + if text.startswith("```"): + lines = text.split("\n") + # Skip the opening ``` line (with or without language identifier) + lines = lines[1:] + # Remove trailing ``` if present + if lines and lines[-1].strip() == "```": + lines = lines[:-1] + text = "\n".join(lines).strip() + + # Also handle case where ``` appears mid-text + if "```json" in text: + start_idx = text.find("```json") + end_idx = text.find("```", start_idx + 7) + if end_idx != -1: + text = text[start_idx + 7:end_idx].strip() + else: + text = text[start_idx + 7:].strip() + elif "```" in text: + # Generic code block without language + start_idx = text.find("```") + remaining = text[start_idx + 3:] + end_idx = remaining.find("```") + if end_idx != -1: + text = remaining[:end_idx].strip() + else: + text = remaining.strip() + + # Find JSON object boundaries + obj_start = text.find("{") + obj_end = text.rfind("}") + arr_start = text.find("[") + arr_end = text.rfind("]") + + # Determine if we have an object or array (whichever comes first) + if obj_start != -1 and obj_end != -1: + if arr_start == -1 or obj_start < arr_start: + # Object comes first or no array + text = text[obj_start:obj_end+1] + elif arr_start < obj_start and arr_end != -1: + # Array comes first - but we need an object for Pydantic + # Check if the object is nested inside the array or separate + if obj_end > arr_end: + # Object extends beyond array - likely the object we want + text = text[obj_start:obj_end+1] + else: + # Try to use the object anyway + text = text[obj_start:obj_end+1] + elif arr_start != -1 and arr_end != -1 and obj_start == -1: + # Only array found - log warning as Pydantic models expect objects + logger.warning(f"JSON cleaning found array instead of object, this may fail parsing") + text = text[arr_start:arr_end+1] + + return text diff --git a/python-ecosystem/mcp-client/service/review/orchestrator/orchestrator.py b/python-ecosystem/mcp-client/service/review/orchestrator/orchestrator.py new file mode 100644 index 00000000..c843f22c --- /dev/null +++ b/python-ecosystem/mcp-client/service/review/orchestrator/orchestrator.py @@ -0,0 +1,247 @@ +""" +Multi-Stage Review Orchestrator. 
+ +Orchestrates the 4-stage AI code review pipeline: +- Stage 0: Planning & Prioritization +- Stage 1: Parallel File Review +- Stage 2: Cross-File & Architectural Analysis +- Stage 3: Aggregation & Final Report +""" +import logging +from typing import Dict, Any, List, Optional, Callable + +from model.dtos import ReviewRequestDto +from utils.diff_processor import ProcessedDiff + +from service.review.orchestrator.reconciliation import reconcile_previous_issues +from service.review.orchestrator.stages import ( + execute_branch_analysis, + execute_stage_0_planning, + execute_stage_1_file_reviews, + execute_stage_2_cross_file, + execute_stage_3_aggregation, + _emit_status, + _emit_progress, + _emit_error, +) + +logger = logging.getLogger(__name__) + + +class MultiStageReviewOrchestrator: + """ + Orchestrates the 4-stage AI code review pipeline: + Stage 0: Planning & Prioritization + Stage 1: Parallel File Review + Stage 2: Cross-File & Architectural Analysis + Stage 3: Aggregation & Final Report + """ + + def __init__( + self, + llm, + mcp_client, + rag_client=None, + event_callback: Optional[Callable[[Dict], None]] = None + ): + self.llm = llm + self.client = mcp_client + self.rag_client = rag_client + self.event_callback = event_callback + self.max_parallel_stage_1 = 5 + self._pr_number: Optional[int] = None + self._pr_indexed: bool = False + + async def _index_pr_files( + self, + request: ReviewRequestDto, + processed_diff: Optional[ProcessedDiff] + ) -> None: + """ + Index PR files into the main RAG collection with PR-specific metadata. + This enables hybrid queries that prioritize PR data over stale branch data. + """ + if not self.rag_client or not processed_diff: + return + + pr_number = request.pullRequestId + if not pr_number: + logger.info("No PR number, skipping PR file indexing") + return + + files = [] + for f in processed_diff.get_included_files(): + content = f.full_content or f.content + change_type = f.change_type.value if hasattr(f.change_type, 'value') else str(f.change_type) + if content and change_type != "DELETED": + files.append({ + "path": f.path, + "content": content, + "change_type": change_type + }) + + if not files: + logger.info("No files to index for PR") + return + + try: + result = await self.rag_client.index_pr_files( + workspace=request.projectWorkspace, + project=request.projectNamespace, + pr_number=pr_number, + branch=request.targetBranchName or "unknown", + files=files + ) + if result.get("status") == "indexed": + self._pr_number = pr_number + self._pr_indexed = True + logger.info(f"Indexed PR #{pr_number}: {result.get('chunks_indexed', 0)} chunks") + else: + logger.warning(f"Failed to index PR files: {result}") + except Exception as e: + logger.warning(f"Error indexing PR files: {e}") + + async def _cleanup_pr_files(self, request: ReviewRequestDto) -> None: + """Delete PR-indexed data after analysis completes.""" + if not self._pr_indexed or not self._pr_number or not self.rag_client: + return + + try: + await self.rag_client.delete_pr_files( + workspace=request.projectWorkspace, + project=request.projectNamespace, + pr_number=self._pr_number + ) + logger.info(f"Cleaned up PR #{self._pr_number} indexed data") + except Exception as e: + logger.warning(f"Failed to cleanup PR files: {e}") + finally: + self._pr_number = None + self._pr_indexed = False + + async def execute_branch_analysis(self, prompt: str) -> Dict[str, Any]: + """ + Execute a single-pass branch analysis using the provided prompt. 
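+        Thin wrapper that delegates to the module-level execute_branch_analysis in
+        stages, passing this orchestrator's LLM, MCP client and event callback.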
+ """ + return await execute_branch_analysis( + self.llm, + self.client, + prompt, + self.event_callback + ) + + async def orchestrate_review( + self, + request: ReviewRequestDto, + rag_context: Optional[Dict[str, Any]] = None, + processed_diff: Optional[ProcessedDiff] = None + ) -> Dict[str, Any]: + """ + Main entry point for the multi-stage review. + Supports both FULL (initial review) and INCREMENTAL (follow-up review) modes. + """ + is_incremental = ( + request.analysisMode == "INCREMENTAL" + and request.deltaDiff + ) + + if is_incremental: + logger.info(f"INCREMENTAL mode: reviewing delta diff, {len(request.previousCodeAnalysisIssues or [])} previous issues to reconcile") + else: + logger.info("FULL mode: initial PR review") + + try: + # Index PR files into RAG for hybrid queries + await self._index_pr_files(request, processed_diff) + + # === STAGE 0: Planning === + _emit_status(self.event_callback, "stage_0_started", "Stage 0: Planning & Prioritization...") + review_plan = await execute_stage_0_planning(self.llm, request, is_incremental) + + review_plan = self._ensure_all_files_planned(review_plan, request.changedFiles or []) + _emit_progress(self.event_callback, 10, "Stage 0 Complete: Review plan created") + + # === STAGE 1: File Reviews === + _emit_status(self.event_callback, "stage_1_started", f"Stage 1: Analyzing {self._count_files(review_plan)} files...") + file_issues = await execute_stage_1_file_reviews( + self.llm, + request, + review_plan, + self.rag_client, + rag_context, + processed_diff, + is_incremental, + self.max_parallel_stage_1, + self.event_callback, + self._pr_indexed + ) + _emit_progress(self.event_callback, 60, f"Stage 1 Complete: {len(file_issues)} issues found across files") + + # === STAGE 1.5: Issue Reconciliation === + if request.previousCodeAnalysisIssues: + _emit_status(self.event_callback, "reconciliation_started", "Reconciling previous issues...") + file_issues = await reconcile_previous_issues( + request, file_issues, processed_diff + ) + _emit_progress(self.event_callback, 70, f"Reconciliation Complete: {len(file_issues)} total issues after reconciliation") + + # === STAGE 2: Cross-File Analysis === + _emit_status(self.event_callback, "stage_2_started", "Stage 2: Analyzing cross-file patterns...") + cross_file_results = await execute_stage_2_cross_file( + self.llm, request, file_issues, review_plan + ) + _emit_progress(self.event_callback, 85, "Stage 2 Complete: Cross-file analysis finished") + + # === STAGE 3: Aggregation === + _emit_status(self.event_callback, "stage_3_started", "Stage 3: Generating final report...") + final_report = await execute_stage_3_aggregation( + self.llm, request, review_plan, file_issues, cross_file_results, is_incremental + ) + _emit_progress(self.event_callback, 100, "Stage 3 Complete: Report generated") + + return { + "comment": final_report, + "issues": [issue.model_dump() for issue in file_issues], + } + + except Exception as e: + logger.error(f"Multi-stage review failed: {e}", exc_info=True) + _emit_error(self.event_callback, str(e)) + raise + finally: + await self._cleanup_pr_files(request) + + def _count_files(self, plan) -> int: + """Count total files in review plan.""" + return sum(len(g.files) for g in plan.file_groups) + + def _ensure_all_files_planned(self, plan, changed_files: List[str]): + """ + Ensure all changed files are included in the review plan. + LLM may miss some files, so we add them to a catch-all group. 
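+        Missing files are appended as a MEDIUM-priority "uncategorized" group with a
+        generic "general review" focus area so that Stage 1 still covers them.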
+ """ + from model.multi_stage import ReviewFile, FileGroup + + planned_files = set() + for group in plan.file_groups: + for f in group.files: + planned_files.add(f.path) + + missing_files = [f for f in changed_files if f not in planned_files] + + if missing_files: + logger.warning(f"Stage 0 missed {len(missing_files)} files, adding to catch-all group") + catch_all_files = [ + ReviewFile(path=f, focus_areas=["general review"]) + for f in missing_files + ] + plan.file_groups.append( + FileGroup( + name="uncategorized", + priority="MEDIUM", + rationale="Files not categorized by initial planning", + files=catch_all_files + ) + ) + + return plan diff --git a/python-ecosystem/mcp-client/service/review/orchestrator/reconciliation.py b/python-ecosystem/mcp-client/service/review/orchestrator/reconciliation.py new file mode 100644 index 00000000..1a20008b --- /dev/null +++ b/python-ecosystem/mcp-client/service/review/orchestrator/reconciliation.py @@ -0,0 +1,280 @@ +""" +Issue reconciliation and deduplication logic for incremental reviews. +""" +import logging +from typing import Any, Dict, List, Optional + +from model.output_schemas import CodeReviewIssue + +logger = logging.getLogger(__name__) + + +def issue_matches_files(issue: Any, file_paths: List[str]) -> bool: + """Check if an issue is related to any of the given file paths.""" + if hasattr(issue, 'model_dump'): + issue_data = issue.model_dump() + elif isinstance(issue, dict): + issue_data = issue + else: + issue_data = vars(issue) if hasattr(issue, '__dict__') else {} + + issue_file = issue_data.get('file', issue_data.get('filePath', '')) + + for fp in file_paths: + if issue_file == fp or issue_file.endswith('/' + fp) or fp.endswith('/' + issue_file): + return True + # Also check basename match + if issue_file.split('/')[-1] == fp.split('/')[-1]: + return True + return False + + +def compute_issue_fingerprint(data: dict) -> str: + """Compute a fingerprint for issue deduplication. + + Uses file + normalized line (±3 tolerance) + severity + truncated reason. + """ + file_path = data.get('file', data.get('filePath', '')) + line = data.get('line', data.get('lineNumber', 0)) + line_group = int(line) // 3 if line else 0 + severity = data.get('severity', '') + reason = data.get('reason', data.get('description', '')) + reason_prefix = reason[:50].lower().strip() if reason else '' + + return f"{file_path}::{line_group}::{severity}::{reason_prefix}" + + +def deduplicate_issues(issues: List[Any]) -> List[dict]: + """Deduplicate issues by fingerprint, keeping most recent version. + + If an older version is resolved but newer isn't, preserves resolved status. 
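+    Illustration with hypothetical values: two reports of the same problem at lines
+    13 and 14 of src/app.py share the fingerprint
+    "src/app.py::4::HIGH::missing null check", so they collapse into one entry; if
+    the older copy was already resolved, the surviving entry keeps the resolved
+    status and its resolution metadata.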
+ """ + if not issues: + return [] + + deduped: dict = {} + + for issue in issues: + if hasattr(issue, 'model_dump'): + data = issue.model_dump() + elif isinstance(issue, dict): + data = issue.copy() + else: + data = vars(issue).copy() if hasattr(issue, '__dict__') else {} + + fingerprint = compute_issue_fingerprint(data) + existing = deduped.get(fingerprint) + + if existing is None: + deduped[fingerprint] = data + else: + existing_version = existing.get('prVersion') or 0 + current_version = data.get('prVersion') or 0 + existing_resolved = existing.get('status', '').lower() == 'resolved' + current_resolved = data.get('status', '').lower() == 'resolved' + + if current_version > existing_version: + # Current is newer + if existing_resolved and not current_resolved: + # Preserve resolved status from older version + data['status'] = 'resolved' + data['resolvedDescription'] = existing.get('resolvedDescription') + data['resolvedByCommit'] = existing.get('resolvedByCommit') + data['resolvedInPrVersion'] = existing.get('resolvedInPrVersion') + deduped[fingerprint] = data + elif current_version == existing_version: + # Same version - prefer resolved + if current_resolved and not existing_resolved: + deduped[fingerprint] = data + + return list(deduped.values()) + + +def format_previous_issues_for_batch(issues: List[Any]) -> str: + """Format previous issues for inclusion in batch prompt. + + Deduplicates issues first, then formats with resolution tracking so LLM knows: + - Which issues were previously found + - Which have been resolved (and how) + - Which PR version each issue was found/resolved in + """ + if not issues: + return "" + + # Deduplicate issues to avoid confusing the LLM with duplicates + deduped_issues = deduplicate_issues(issues) + + # Separate OPEN and RESOLVED issues + open_issues = [i for i in deduped_issues if i.get('status', '').lower() != 'resolved'] + resolved_issues = [i for i in deduped_issues if i.get('status', '').lower() == 'resolved'] + + lines = ["=== PREVIOUS ISSUES HISTORY (check if resolved/persisting) ==="] + lines.append("Issues have been deduplicated. 
Only check OPEN issues - RESOLVED ones are for context only.") + lines.append("") + + if open_issues: + lines.append("--- OPEN ISSUES (check if now fixed) ---") + for data in open_issues: + issue_id = data.get('id', 'unknown') + severity = data.get('severity', 'MEDIUM') + file_path = data.get('file', data.get('filePath', 'unknown')) + line = data.get('line', data.get('lineNumber', '?')) + reason = data.get('reason', data.get('description', 'No description')) + pr_version = data.get('prVersion', '?') + + lines.append(f"[ID:{issue_id}] {severity} @ {file_path}:{line} (v{pr_version})") + lines.append(f" Issue: {reason}") + lines.append("") + + if resolved_issues: + lines.append("--- RESOLVED ISSUES (for context only, do NOT re-report) ---") + for data in resolved_issues: + issue_id = data.get('id', 'unknown') + severity = data.get('severity', 'MEDIUM') + file_path = data.get('file', data.get('filePath', 'unknown')) + line = data.get('line', data.get('lineNumber', '?')) + reason = data.get('reason', data.get('description', 'No description')) + pr_version = data.get('prVersion', '?') + resolved_desc = data.get('resolvedDescription', '') + resolved_in = data.get('resolvedInPrVersion', '') + + lines.append(f"[ID:{issue_id}] {severity} @ {file_path}:{line} (v{pr_version}) - RESOLVED") + if resolved_desc: + lines.append(f" Resolution: {resolved_desc}") + if resolved_in: + lines.append(f" Resolved in: v{resolved_in}") + lines.append(f" Original issue: {reason}") + lines.append("") + + lines.append("INSTRUCTIONS:") + lines.append("- For OPEN issues that are now FIXED: report with 'isResolved': true (boolean)") + lines.append("- For OPEN issues still present: report with 'isResolved': false (boolean)") + lines.append("- Do NOT re-report RESOLVED issues - they are only shown for context") + lines.append("- IMPORTANT: 'isResolved' MUST be a JSON boolean (true/false), not a string") + lines.append("- Preserve the 'id' field for all issues you report from previous issues") + lines.append("=== END PREVIOUS ISSUES ===") + return "\n".join(lines) + + +async def reconcile_previous_issues( + request, + new_issues: List[CodeReviewIssue], + processed_diff = None +) -> List[CodeReviewIssue]: + """ + Reconcile previous issues with new findings in incremental mode. 
+ - Mark resolved issues as isResolved=true + - Update line numbers for persisting issues + - Merge with new issues found in delta diff + - PRESERVE original issue data (reason, suggestedFixDescription, suggestedFixDiff) + """ + if not request.previousCodeAnalysisIssues: + return new_issues + + logger.info(f"Reconciling {len(request.previousCodeAnalysisIssues)} previous issues with {len(new_issues)} new issues") + + # Current commit for resolution tracking + current_commit = request.currentCommitHash or request.commitHash + + # Get the delta diff content to check what files/lines changed + delta_diff = request.deltaDiff or "" + + # Build a set of files that changed in the delta + changed_files_in_delta = set() + if processed_diff: + for f in processed_diff.files: + changed_files_in_delta.add(f.path) + + # Build lookup of previous issues by ID for merging with LLM results + prev_issues_by_id = {} + for prev_issue in request.previousCodeAnalysisIssues: + if hasattr(prev_issue, 'model_dump'): + prev_data = prev_issue.model_dump() + else: + prev_data = prev_issue if isinstance(prev_issue, dict) else vars(prev_issue) + issue_id = prev_data.get('id') + if issue_id: + prev_issues_by_id[str(issue_id)] = prev_data + + reconciled_issues = [] + processed_prev_ids = set() # Track which previous issues we've handled + + # Process new issues from LLM - merge with previous issue data if they reference same ID + for new_issue in new_issues: + new_data = new_issue.model_dump() if hasattr(new_issue, 'model_dump') else new_issue + issue_id = new_data.get('id') + + # If this issue references a previous issue ID, merge data + if issue_id and str(issue_id) in prev_issues_by_id: + prev_data = prev_issues_by_id[str(issue_id)] + processed_prev_ids.add(str(issue_id)) + + # Check if LLM marked it resolved + is_resolved = new_data.get('isResolved', False) + + # PRESERVE original data, use LLM's reason as resolution explanation + merged_issue = CodeReviewIssue( + id=str(issue_id), + severity=(prev_data.get('severity') or prev_data.get('issueSeverity') or 'MEDIUM').upper(), + category=prev_data.get('category') or prev_data.get('issueCategory') or prev_data.get('type') or 'CODE_QUALITY', + file=prev_data.get('file') or prev_data.get('filePath') or new_data.get('file', 'unknown'), + line=str(prev_data.get('line') or prev_data.get('lineNumber') or new_data.get('line', '1')), + # PRESERVE original reason and fix description + reason=prev_data.get('reason') or prev_data.get('title') or prev_data.get('description') or '', + suggestedFixDescription=prev_data.get('suggestedFixDescription') or prev_data.get('suggestedFix') or '', + suggestedFixDiff=prev_data.get('suggestedFixDiff') or None, + isResolved=is_resolved, + # Store LLM's explanation separately if resolved + resolutionExplanation=new_data.get('reason') if is_resolved else None, + resolvedInCommit=current_commit if is_resolved else None, + visibility=prev_data.get('visibility'), + codeSnippet=prev_data.get('codeSnippet') + ) + reconciled_issues.append(merged_issue) + else: + # New issue not referencing previous - keep as is + reconciled_issues.append(new_issue) + + # Process remaining previous issues not handled by LLM + for prev_issue in request.previousCodeAnalysisIssues: + if hasattr(prev_issue, 'model_dump'): + prev_data = prev_issue.model_dump() + else: + prev_data = prev_issue if isinstance(prev_issue, dict) else vars(prev_issue) + + issue_id = prev_data.get('id') + if issue_id and str(issue_id) in processed_prev_ids: + continue # Already handled above + + 
file_path = prev_data.get('file', prev_data.get('filePath', '')) + + # Check if this issue was already found in new issues (by file+line) + already_reported = False + for new_issue in new_issues: + new_data = new_issue.model_dump() if hasattr(new_issue, 'model_dump') else new_issue + if (new_data.get('file') == file_path and + str(new_data.get('line')) == str(prev_data.get('line', prev_data.get('lineNumber')))): + already_reported = True + break + + if already_reported: + continue + + # Preserve all original issue data + persisting_issue = CodeReviewIssue( + id=str(issue_id) if issue_id else None, + severity=(prev_data.get('severity') or prev_data.get('issueSeverity') or 'MEDIUM').upper(), + category=prev_data.get('category') or prev_data.get('issueCategory') or prev_data.get('type') or 'CODE_QUALITY', + file=file_path or prev_data.get('file') or prev_data.get('filePath') or 'unknown', + line=str(prev_data.get('line') or prev_data.get('lineNumber') or '1'), + reason=prev_data.get('reason') or prev_data.get('title') or prev_data.get('description') or '', + suggestedFixDescription=prev_data.get('suggestedFixDescription') or prev_data.get('suggestedFix') or '', + suggestedFixDiff=prev_data.get('suggestedFixDiff') or None, + isResolved=False, + visibility=prev_data.get('visibility'), + codeSnippet=prev_data.get('codeSnippet') + ) + reconciled_issues.append(persisting_issue) + + logger.info(f"Reconciliation complete: {len(reconciled_issues)} total issues") + return reconciled_issues diff --git a/python-ecosystem/mcp-client/service/review/orchestrator/stages.py b/python-ecosystem/mcp-client/service/review/orchestrator/stages.py new file mode 100644 index 00000000..a2a98e3c --- /dev/null +++ b/python-ecosystem/mcp-client/service/review/orchestrator/stages.py @@ -0,0 +1,666 @@ +""" +Stage execution methods for the multi-stage review pipeline. +""" +import json +import asyncio +import logging +import time +from typing import Any, Dict, List, Optional, Callable + +from model.dtos import ReviewRequestDto +from model.output_schemas import CodeReviewOutput, CodeReviewIssue +from model.multi_stage import ( + ReviewPlan, + ReviewFile, + FileGroup, + CrossFileAnalysisResult, + FileReviewBatchOutput, +) +from utils.prompts.prompt_builder import PromptBuilder +from utils.diff_processor import ProcessedDiff, DiffProcessor +from utils.dependency_graph import create_smart_batches + +from service.review.orchestrator.agents import RecursiveMCPAgent, extract_llm_response_text +from service.review.orchestrator.json_utils import parse_llm_response +from service.review.orchestrator.reconciliation import ( + issue_matches_files, + format_previous_issues_for_batch, +) +from service.review.orchestrator.context_helpers import ( + extract_diff_snippets, + format_rag_context, +) + +logger = logging.getLogger(__name__) + + +async def execute_branch_analysis( + llm, + client, + prompt: str, + event_callback: Optional[Callable[[Dict], None]] = None +) -> Dict[str, Any]: + """ + Execute a single-pass branch analysis using the provided prompt. 
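+    Streams a RecursiveMCPAgent run (max_steps=15) expecting a CodeReviewOutput; if
+    the stream finishes without a structured object, the final text is parsed via
+    parse_llm_response as a fallback.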
+ """ + _emit_status(event_callback, "branch_analysis_started", "Starting Branch Analysis & Reconciliation...") + + agent = RecursiveMCPAgent( + llm=llm, + client=client, + additional_instructions=PromptBuilder.get_additional_instructions() + ) + + + try: + final_text = "" + # Branch analysis expects standard CodeReviewOutput + async for item in agent.stream(prompt, max_steps=15, output_schema=CodeReviewOutput): + if isinstance(item, CodeReviewOutput): + # Convert to dict format expected by service + issues = [i.model_dump() for i in item.issues] if item.issues else [] + return { + "issues": issues, + "comment": item.comment or "Branch analysis completed." + } + + if isinstance(item, str): + final_text = item + + # If stream finished without object, try parsing text + if final_text: + data = await parse_llm_response(final_text, CodeReviewOutput, llm) + issues = [i.model_dump() for i in data.issues] if data.issues else [] + return { + "issues": issues, + "comment": data.comment or "Branch analysis completed." + } + + return {"issues": [], "comment": "No issues found."} + + except Exception as e: + logger.error(f"Branch analysis failed: {e}", exc_info=True) + _emit_error(event_callback, str(e)) + raise + + +async def execute_stage_0_planning( + llm, + request: ReviewRequestDto, + is_incremental: bool = False +) -> ReviewPlan: + """ + Stage 0: Analyze metadata and generate a review plan. + Uses structured output for reliable JSON parsing. + """ + # Prepare context for prompt + changed_files_summary = [] + if request.changedFiles: + for f in request.changedFiles: + changed_files_summary.append({ + "path": f, + "type": "MODIFIED", + "lines_added": "?", + "lines_deleted": "?" + }) + + prompt = PromptBuilder.build_stage_0_planning_prompt( + repo_slug=request.projectVcsRepoSlug, + pr_id=str(request.pullRequestId), + pr_title=request.prTitle or "", + author="Unknown", + branch_name="source-branch", + target_branch=request.targetBranchName or "main", + commit_hash=request.commitHash or "HEAD", + changed_files_json=json.dumps(changed_files_summary, indent=2) + ) + + # Stage 0 uses direct LLM call (no tools needed - all metadata is provided) + try: + structured_llm = llm.with_structured_output(ReviewPlan) + result = await structured_llm.ainvoke(prompt) + if result: + logger.info("Stage 0 planning completed with structured output") + return result + except Exception as e: + logger.warning(f"Structured output failed for Stage 0: {e}") + + # Fallback to manual parsing + try: + response = await llm.ainvoke(prompt) + content = extract_llm_response_text(response) + return await parse_llm_response(content, ReviewPlan, llm) + except Exception as e: + logger.error(f"Stage 0 planning failed: {e}") + raise ValueError(f"Stage 0 planning failed: {e}") + + +def chunk_files(file_groups: List[Any], max_files_per_batch: int = 5) -> List[List[Dict[str, Any]]]: + """ + Flatten file groups and chunk into batches. + DEPRECATED: Use create_smart_batches for dependency-aware batching. + Kept for fallback when diff content is unavailable. 
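+    Files are flattened in group order, tagged with their group's priority, and
+    split into consecutive batches of at most max_files_per_batch entries.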
+ """ + all_files = [] + for group in file_groups: + for f in group.files: + # Attach priority context for the review + all_files.append({ + "file": f, + "priority": group.priority + }) + + return [all_files[i:i + max_files_per_batch] for i in range(0, len(all_files), max_files_per_batch)] + + +def create_smart_batches_wrapper( + file_groups: List[Any], + processed_diff: Optional[ProcessedDiff], + request: ReviewRequestDto, + rag_client, + max_files_per_batch: int = 7 +) -> List[List[Dict[str, Any]]]: + """ + Create dependency-aware batches that keep related files together. + + If enrichment data is available from Java (pre-computed relationships), + use it directly. Otherwise, use RAG's tree-sitter metadata to discover + file relationships: + - imports/exports: which files use symbols from other files + - class context: methods in the same class + - namespace context: files in the same package + + Falls back to directory-based grouping if both are unavailable. + """ + # Build branches list for RAG query + branches = [] + if request.targetBranchName: + branches.append(request.targetBranchName) + # Note: sourceBranchName is not in ReviewRequestDto, skip this check + if not branches: + branches = ['main', 'master'] # Fallback + + # Check for enrichment data from Java + enrichment_data = getattr(request, 'enrichmentData', None) + + try: + # Use RAG-based dependency analysis for intelligent batching + batches = create_smart_batches( + file_groups=file_groups, + workspace=request.projectWorkspace, + project=request.projectNamespace, + branches=branches, + rag_client=rag_client, + max_batch_size=max_files_per_batch, + enrichment_data=enrichment_data + ) + + # Log relationship analysis results + total_files = sum(len(b) for b in batches) + related_files = sum(1 for b in batches for f in b if f.get('has_relationships')) + enrichment_source = "enrichment data" if enrichment_data else "RAG discovery" + logger.info(f"Smart batching ({enrichment_source}): {total_files} files in {len(batches)} batches, " + f"{related_files} files have cross-file relationships") + + return batches + except Exception as e: + logger.warning(f"Smart batching failed, falling back to simple batching: {e}") + return chunk_files(file_groups, max_files_per_batch) + + +async def execute_stage_1_file_reviews( + llm, + request: ReviewRequestDto, + plan: ReviewPlan, + rag_client, + rag_context: Optional[Dict[str, Any]] = None, + processed_diff: Optional[ProcessedDiff] = None, + is_incremental: bool = False, + max_parallel: int = 5, + event_callback: Optional[Callable[[Dict], None]] = None, + pr_indexed: bool = False +) -> List[CodeReviewIssue]: + """ + Stage 1: Execute batch file reviews with per-batch RAG context. + Uses dependency-aware batching to keep related files together. 
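+    Batches are processed in waves of at most max_parallel concurrent reviews via
+    asyncio.gather to limit rate-limit pressure, with a progress event emitted
+    after each wave completes.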
+ """ + # Use smart batching with RAG-based relationship discovery + batches = create_smart_batches_wrapper( + plan.file_groups, processed_diff, request, rag_client, max_files_per_batch=7 + ) + + total_files = sum(len(batch) for batch in batches) + related_batches = sum(1 for b in batches if any(f.get('has_relationships') for f in b)) + logger.info(f"Stage 1: Processing {total_files} files in {len(batches)} batches " + f"({related_batches} batches with cross-file relationships)") + + # Process batches with controlled parallelism + all_issues = [] + + # Process in waves to avoid rate limits + for wave_start in range(0, len(batches), max_parallel): + wave_end = min(wave_start + max_parallel, len(batches)) + wave_batches = batches[wave_start:wave_end] + wave_num = wave_start // max_parallel + 1 + + logger.info(f"Stage 1: Processing wave {wave_num}, " + f"batches {wave_start + 1}-{wave_end} of {len(batches)} IN PARALLEL") + + wave_start_time = time.time() + + tasks = [] + for batch_idx, batch in enumerate(wave_batches, start=wave_start + 1): + batch_paths = [item["file"].path for item in batch] + has_rels = any(item.get('has_relationships') for item in batch) + logger.debug(f"Batch {batch_idx}: {batch_paths} (cross-file relationships: {has_rels})") + # Create coroutine with batch_idx for tracking + tasks.append(_review_batch_with_timing( + batch_idx, llm, request, batch, rag_client, processed_diff, + is_incremental, rag_context, pr_indexed + )) + + # asyncio.gather runs all tasks CONCURRENTLY + results = await asyncio.gather(*tasks, return_exceptions=True) + + wave_elapsed = time.time() - wave_start_time + logger.info(f"Wave {wave_num} completed in {wave_elapsed:.2f}s " + f"({len(wave_batches)} batches parallel)") + + for idx, res in enumerate(results): + batch_num = wave_start + idx + 1 + if isinstance(res, Exception): + logger.error(f"Error reviewing batch {batch_num}: {res}") + elif res: + logger.info(f"Batch {batch_num} completed: {len(res)} issues found") + all_issues.extend(res) + else: + logger.info(f"Batch {batch_num} completed: no issues found") + + # Update progress + progress = 10 + int((wave_end / len(batches)) * 50) + _emit_progress(event_callback, progress, f"Stage 1: Reviewed {wave_end}/{len(batches)} batches") + + logger.info(f"Stage 1 Complete: {len(all_issues)} issues found across {total_files} files") + return all_issues + + +async def _review_batch_with_timing( + batch_idx: int, + llm, + request: ReviewRequestDto, + batch: List[Dict[str, Any]], + rag_client, + processed_diff: Optional[ProcessedDiff], + is_incremental: bool, + fallback_rag_context: Optional[Dict[str, Any]], + pr_indexed: bool +) -> List[CodeReviewIssue]: + """ + Wrapper that adds timing logs to show parallel execution. 
+ """ + start_time = time.time() + batch_paths = [item["file"].path for item in batch] + logger.info(f"[Batch {batch_idx}] STARTED - files: {batch_paths}") + + try: + result = await review_file_batch( + llm, request, batch, rag_client, processed_diff, is_incremental, + fallback_rag_context=fallback_rag_context, pr_indexed=pr_indexed + ) + elapsed = time.time() - start_time + logger.info(f"[Batch {batch_idx}] FINISHED in {elapsed:.2f}s - {len(result)} issues") + return result + except Exception as e: + elapsed = time.time() - start_time + logger.error(f"[Batch {batch_idx}] FAILED after {elapsed:.2f}s: {e}") + raise + + +async def fetch_batch_rag_context( + rag_client, + request: ReviewRequestDto, + batch_file_paths: List[str], + batch_diff_snippets: List[str], + pr_indexed: bool = False +) -> Optional[Dict[str, Any]]: + """ + Fetch RAG context specifically for this batch of files. + + Two-pronged approach for comprehensive context: + 1. Semantic search using diff snippets (finds conceptually related code) + 2. Deterministic lookup using tree-sitter metadata (finds imported/referenced definitions) + + In hybrid mode (when PR files are indexed), passes pr_number to enable + queries that prioritize fresh PR data over potentially stale branch data. + """ + if not rag_client: + return None + + try: + # Determine branches for RAG query + rag_branch = request.targetBranchName or request.commitHash or "main" + base_branch = "main" # Default base branch for deterministic context + + logger.info(f"Fetching per-batch RAG context for {len(batch_file_paths)} files") + + # Use hybrid mode if PR files were indexed + pr_number = request.pullRequestId if pr_indexed else None + all_pr_files = request.changedFiles if pr_indexed else None + + # 1. Semantic search for conceptually related code + rag_response = await rag_client.get_pr_context( + workspace=request.projectWorkspace, + project=request.projectNamespace, + branch=rag_branch, + changed_files=batch_file_paths, + diff_snippets=batch_diff_snippets, + pr_title=request.prTitle, + pr_description=request.prDescription, + top_k=10, # Fewer chunks per batch for focused context + pr_number=pr_number, + all_pr_changed_files=all_pr_files + ) + + context = None + if rag_response and rag_response.get("context"): + context = rag_response.get("context") + chunk_count = len(context.get("relevant_code", [])) + logger.info(f"Semantic RAG: retrieved {chunk_count} chunks for batch") + + # 2. Deterministic lookup for cross-file dependencies (imports, extends, etc.) 
+ # This uses tree-sitter metadata indexed during repo indexing + try: + deterministic_response = await rag_client.get_deterministic_context( + workspace=request.projectWorkspace, + project=request.projectNamespace, + branches=[rag_branch, base_branch], + file_paths=batch_file_paths, + limit_per_file=5 # Limit to avoid overwhelming context + ) + + if deterministic_response and deterministic_response.get("context"): + det_context = deterministic_response.get("context") + related_defs = det_context.get("related_definitions", {}) + + if related_defs: + # Merge deterministic results into the main context + if context is None: + context = {"relevant_code": []} + + # Add related definitions as additional context chunks + for def_name, def_chunks in related_defs.items(): + for chunk in def_chunks[:3]: # Limit per definition + context["relevant_code"].append({ + "file_path": chunk.get("file_path", ""), + "content": chunk.get("content", ""), + "score": 0.85, # High score for deterministic matches + "source": "deterministic", + "definition_name": def_name + }) + + logger.info(f"Deterministic RAG: added {len(related_defs)} related definitions") + + except Exception as det_err: + # Deterministic context is optional enhancement, don't fail the whole request + logger.debug(f"Deterministic RAG lookup skipped: {det_err}") + + if context: + total_chunks = len(context.get("relevant_code", [])) + logger.info(f"Total RAG context: {total_chunks} chunks for files {batch_file_paths}") + return context + + return None + + except Exception as e: + logger.warning(f"Failed to fetch per-batch RAG context: {e}") + return None + + +async def review_file_batch( + llm, + request: ReviewRequestDto, + batch_items: List[Dict[str, Any]], + rag_client, + processed_diff: Optional[ProcessedDiff] = None, + is_incremental: bool = False, + fallback_rag_context: Optional[Dict[str, Any]] = None, + pr_indexed: bool = False +) -> List[CodeReviewIssue]: + """ + Review a batch of files in a single LLM call with per-batch RAG context. + In incremental mode, uses delta diff and focuses on new changes only. 
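+    Per-batch RAG context is fetched from the batch's file paths and diff snippets;
+    if that yields nothing, the initial PR-level context is used as a fallback.
+    Structured output is attempted first, then manual JSON parsing.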
+ """ + batch_files_data = [] + batch_file_paths = [] + batch_diff_snippets = [] + #TODO: Project custom rules + project_rules = "" + + # For incremental mode, use deltaDiff instead of full diff + diff_source = None + if is_incremental and request.deltaDiff: + # Parse delta diff to extract per-file diffs + diff_source = DiffProcessor().process(request.deltaDiff) if request.deltaDiff else None + else: + diff_source = processed_diff + + # Collect file paths, diffs, and extract snippets for this batch + for item in batch_items: + file_info = item["file"] + batch_file_paths.append(file_info.path) + + # Extract diff from the appropriate source (delta for incremental, full for initial) + file_diff = "" + if diff_source: + for f in diff_source.files: + if f.path == file_info.path or f.path.endswith("/" + file_info.path): + file_diff = f.content + # Extract code snippets from diff for RAG semantic search + if file_diff: + batch_diff_snippets.extend(extract_diff_snippets(file_diff)) + break + + batch_files_data.append({ + "path": file_info.path, + "type": "MODIFIED", + "focus_areas": file_info.focus_areas, + "old_code": "", + "diff": file_diff or "(Diff unavailable)", + "is_incremental": is_incremental # Pass mode to prompt builder + }) + + # Fetch per-batch RAG context using batch-specific files and diff snippets + rag_context_text = "" + batch_rag_context = None + + if rag_client: + batch_rag_context = await fetch_batch_rag_context( + rag_client, request, batch_file_paths, batch_diff_snippets, pr_indexed + ) + + # Use batch-specific RAG context if available, otherwise fall back to initial context + # Hybrid mode: PR-indexed data is already included via fetch_batch_rag_context + if batch_rag_context: + logger.info(f"Using per-batch RAG context for: {batch_file_paths}") + rag_context_text = format_rag_context( + batch_rag_context, + set(batch_file_paths), + pr_changed_files=request.changedFiles + ) + elif fallback_rag_context: + logger.info(f"Using fallback RAG context for batch: {batch_file_paths}") + rag_context_text = format_rag_context( + fallback_rag_context, + set(batch_file_paths), + pr_changed_files=request.changedFiles + ) + + logger.info(f"RAG context for batch: {len(rag_context_text)} chars") + + # For incremental mode, filter previous issues relevant to this batch + # Also pass previous issues in FULL mode if they exist (subsequent PR iterations) + previous_issues_for_batch = "" + has_previous_issues = request.previousCodeAnalysisIssues and len(request.previousCodeAnalysisIssues) > 0 + if has_previous_issues: + relevant_prev_issues = [ + issue for issue in request.previousCodeAnalysisIssues + if issue_matches_files(issue, batch_file_paths) + ] + if relevant_prev_issues: + previous_issues_for_batch = format_previous_issues_for_batch(relevant_prev_issues) + + # Build ONE prompt for the batch with cross-file awareness + prompt = PromptBuilder.build_stage_1_batch_prompt( + files=batch_files_data, + priority=batch_items[0]["priority"] if batch_items else "MEDIUM", + project_rules=project_rules, + rag_context=rag_context_text, + is_incremental=is_incremental, + previous_issues=previous_issues_for_batch, + all_pr_files=request.changedFiles # Enable cross-file awareness in prompt + ) + + # Stage 1 uses direct LLM call (no tools needed - diff is already provided) + try: + # Try structured output first + structured_llm = llm.with_structured_output(FileReviewBatchOutput) + result = await structured_llm.ainvoke(prompt) + if result: + all_batch_issues = [] + for review in result.reviews: + 
all_batch_issues.extend(review.issues) + return all_batch_issues + except Exception as e: + logger.warning(f"Structured output failed for Stage 1 batch: {e}") + + # Fallback to manual parsing + try: + response = await llm.ainvoke(prompt) + content = extract_llm_response_text(response) + data = await parse_llm_response(content, FileReviewBatchOutput, llm) + all_batch_issues = [] + for review in data.reviews: + all_batch_issues.extend(review.issues) + return all_batch_issues + except Exception as parse_err: + logger.error(f"Batch review failed: {parse_err}") + return [] + + return [] + + +async def execute_stage_2_cross_file( + llm, + request: ReviewRequestDto, + stage_1_issues: List[CodeReviewIssue], + plan: ReviewPlan +) -> CrossFileAnalysisResult: + """ + Stage 2: Cross-file analysis. + """ + # Serialize Stage 1 findings + issues_json = json.dumps([i.model_dump() for i in stage_1_issues], indent=2) + + prompt = PromptBuilder.build_stage_2_cross_file_prompt( + repo_slug=request.projectVcsRepoSlug, + pr_title=request.prTitle or "", + commit_hash=request.commitHash or "HEAD", + stage_1_findings_json=issues_json, + architecture_context="(Architecture context from MCP or knowledge base)", + migrations="(Migration scripts found in PR)", + cross_file_concerns=plan.cross_file_concerns + ) + + # Stage 2 uses direct LLM call (no tools needed - all data is provided from Stage 1) + try: + structured_llm = llm.with_structured_output(CrossFileAnalysisResult) + result = await structured_llm.ainvoke(prompt) + if result: + logger.info("Stage 2 cross-file analysis completed with structured output") + return result + except Exception as e: + logger.warning(f"Structured output failed for Stage 2: {e}") + + # Fallback to manual parsing + try: + response = await llm.ainvoke(prompt) + content = extract_llm_response_text(response) + return await parse_llm_response(content, CrossFileAnalysisResult, llm) + except Exception as e: + logger.error(f"Stage 2 cross-file analysis failed: {e}") + raise + + +async def execute_stage_3_aggregation( + llm, + request: ReviewRequestDto, + plan: ReviewPlan, + stage_1_issues: List[CodeReviewIssue], + stage_2_results: CrossFileAnalysisResult, + is_incremental: bool = False +) -> str: + """ + Stage 3: Generate Markdown report. + In incremental mode, includes summary of resolved vs new issues. 
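+    The report is produced by a single direct LLM call over the serialized Stage 0
+    plan, Stage 1 issues and Stage 2 findings; no agent or tool calls are made here.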
+ """ + stage_1_json = json.dumps([i.model_dump() for i in stage_1_issues], indent=2) + stage_2_json = stage_2_results.model_dump_json(indent=2) + plan_json = plan.model_dump_json(indent=2) + + # Add incremental context to aggregation + incremental_context = "" + if is_incremental: + resolved_count = sum(1 for i in stage_1_issues if i.isResolved) + new_count = len(stage_1_issues) - resolved_count + previous_count = len(request.previousCodeAnalysisIssues or []) + incremental_context = f""" +## INCREMENTAL REVIEW SUMMARY +- Previous issues from last review: {previous_count} +- Issues resolved in this update: {resolved_count} +- New issues found in delta: {new_count} +- Total issues after reconciliation: {len(stage_1_issues)} +""" + + prompt = PromptBuilder.build_stage_3_aggregation_prompt( + repo_slug=request.projectVcsRepoSlug, + pr_id=str(request.pullRequestId), + author="Unknown", + pr_title=request.prTitle or "", + total_files=len(request.changedFiles or []), + additions=0, # Need accurate stats + deletions=0, + stage_0_plan=plan_json, + stage_1_issues_json=stage_1_json, + stage_2_findings_json=stage_2_json, + recommendation=stage_2_results.pr_recommendation + ) + + response = await llm.ainvoke(prompt) + return extract_llm_response_text(response) + + +# Helper functions for event emission +def _emit_status(callback: Optional[Callable[[Dict], None]], state: str, message: str): + if callback: + callback({ + "type": "status", + "state": state, + "message": message + }) + + +def _emit_progress(callback: Optional[Callable[[Dict], None]], percent: int, message: str): + if callback: + callback({ + "type": "progress", + "percent": percent, + "message": message + }) + + +def _emit_error(callback: Optional[Callable[[Dict], None]], message: str): + if callback: + callback({ + "type": "error", + "message": message + }) diff --git a/python-ecosystem/mcp-client/service/review_service.py b/python-ecosystem/mcp-client/service/review/review_service.py similarity index 98% rename from python-ecosystem/mcp-client/service/review_service.py rename to python-ecosystem/mcp-client/service/review/review_service.py index ba57a085..24fe62d3 100644 --- a/python-ecosystem/mcp-client/service/review_service.py +++ b/python-ecosystem/mcp-client/service/review/review_service.py @@ -6,18 +6,18 @@ from dotenv import load_dotenv from mcp_use import MCPClient -from model.models import ReviewRequestDto +from model.dtos import ReviewRequestDto from utils.mcp_config import MCPConfigBuilder from llm.llm_factory import LLMFactory from utils.prompts.prompt_builder import PromptBuilder from utils.response_parser import ResponseParser -from service.rag_client import RagClient, RAG_DEFAULT_TOP_K -from service.llm_reranker import LLMReranker -from service.issue_post_processor import post_process_analysis_result +from service.rag.rag_client import RagClient, RAG_DEFAULT_TOP_K +from service.rag.llm_reranker import LLMReranker +from service.review.issue_processor import post_process_analysis_result from utils.context_builder import (RAGMetrics, get_rag_cache) from utils.diff_processor import DiffProcessor from utils.error_sanitizer import create_user_friendly_error -from service.multi_stage_orchestrator import MultiStageReviewOrchestrator +from service.review.orchestrator import MultiStageReviewOrchestrator logger = logging.getLogger(__name__) diff --git a/python-ecosystem/mcp-client/tests/test_dependency_graph.py b/python-ecosystem/mcp-client/tests/test_dependency_graph.py new file mode 100644 index 00000000..9a06969e --- /dev/null +++ 
b/python-ecosystem/mcp-client/tests/test_dependency_graph.py @@ -0,0 +1,366 @@ +""" +Unit tests for the RAG-based dependency graph builder. +""" +import pytest +from unittest.mock import MagicMock, AsyncMock +from utils.dependency_graph import ( + DependencyGraphBuilder, + create_smart_batches, + FileNode, + FileRelationship +) + + +class MockReviewFile: + """Mock ReviewFile for testing.""" + def __init__(self, path: str, focus_areas: list = None): + self.path = path + self.focus_areas = focus_areas or [] + self.risk_level = "MEDIUM" + + +class MockFileGroup: + """Mock FileGroup for testing.""" + def __init__(self, files: list, priority: str = "MEDIUM"): + self.files = files + self.priority = priority + self.group_id = "test-group" + self.rationale = "Test group" + + +class MockRAGClient: + """Mock RAG client that returns predefined responses.""" + + def __init__(self, response: dict = None): + self.response = response or {} + self.call_count = 0 + self.last_call_args = None + + def get_deterministic_context(self, workspace, project, branches, file_paths, limit_per_file=15): + self.call_count += 1 + self.last_call_args = { + "workspace": workspace, + "project": project, + "branches": branches, + "file_paths": file_paths, + "limit_per_file": limit_per_file + } + return self.response + + +class TestDependencyGraphBuilder: + """Tests for DependencyGraphBuilder class.""" + + def test_init_without_rag_client(self): + """Test initialization without RAG client.""" + builder = DependencyGraphBuilder() + assert builder.rag_client is None + assert builder.nodes == {} + assert builder.relationships == [] + + def test_init_with_rag_client(self): + """Test initialization with RAG client.""" + mock_rag = MockRAGClient() + builder = DependencyGraphBuilder(rag_client=mock_rag) + assert builder.rag_client is mock_rag + + def test_build_basic_graph_fallback(self): + """Test fallback to basic graph when no RAG client.""" + builder = DependencyGraphBuilder() + + file_groups = [ + MockFileGroup([ + MockReviewFile("src/service/user_service.py"), + MockReviewFile("src/service/auth_service.py"), + ], priority="HIGH"), + MockFileGroup([ + MockReviewFile("tests/test_user.py"), + ], priority="MEDIUM"), + ] + + nodes = builder._build_basic_graph(file_groups) + + assert len(nodes) == 3 + assert "src/service/user_service.py" in nodes + # Files in same directory should be related + assert "src/service/auth_service.py" in nodes["src/service/user_service.py"].related_files + + def test_build_graph_from_rag_with_relationships(self): + """Test building graph with RAG-discovered relationships.""" + # Create RAG response with relationships + rag_response = { + "changed_files": { + "src/service/user_service.py": [ + { + "metadata": { + "primary_name": "UserService", + "semantic_names": ["get_user", "create_user"], + "imports": ["AuthService", "UserRepository"], + "parent_class": None, + "namespace": "src.service", + "path": "src/service/user_service.py" + } + } + ], + "src/service/auth_service.py": [ + { + "metadata": { + "primary_name": "AuthService", + "semantic_names": ["authenticate", "validate_token"], + "imports": [], + "parent_class": None, + "namespace": "src.service", + "path": "src/service/auth_service.py" + } + } + ] + }, + "related_definitions": { + "AuthService": [ + { + "metadata": { + "primary_name": "AuthService", + "path": "src/service/auth_service.py" + } + } + ] + }, + "class_context": {}, + "namespace_context": { + "src.service": [ + {"metadata": {"path": "src/service/user_service.py"}}, + {"metadata": 
{"path": "src/service/auth_service.py"}} + ] + } + } + + mock_rag = MockRAGClient(response=rag_response) + builder = DependencyGraphBuilder(rag_client=mock_rag) + + file_groups = [ + MockFileGroup([ + MockReviewFile("src/service/user_service.py"), + MockReviewFile("src/service/auth_service.py"), + ], priority="HIGH"), + ] + + nodes = builder.build_graph_from_rag( + file_groups, + workspace="test-workspace", + project="test-project", + branches=["main"] + ) + + assert len(nodes) == 2 + assert mock_rag.call_count == 1 + + # Check metadata extraction + user_node = nodes["src/service/user_service.py"] + assert "AuthService" in user_node.imports_symbols + assert "UserService" in user_node.exports_symbols + + # Check relationships discovered + assert len(builder.relationships) > 0 + + def test_get_connected_components(self): + """Test finding connected components.""" + builder = DependencyGraphBuilder() + + # Manually set up nodes with relationships + builder.nodes = { + "a.py": FileNode(path="a.py", priority="HIGH", related_files={"b.py"}), + "b.py": FileNode(path="b.py", priority="HIGH", related_files={"a.py", "c.py"}), + "c.py": FileNode(path="c.py", priority="HIGH", related_files={"b.py"}), + "d.py": FileNode(path="d.py", priority="MEDIUM", related_files=set()), # Isolated + } + + components = builder.get_connected_components() + + assert len(components) == 2 + # One component with a, b, c + abc_component = next(c for c in components if "a.py" in c) + assert abc_component == {"a.py", "b.py", "c.py"} + # One isolated component with d + d_component = next(c for c in components if "d.py" in c) + assert d_component == {"d.py"} + + def test_smart_batches_keeps_related_files_together(self): + """Test that smart batching keeps related files in same batch.""" + rag_response = { + "changed_files": { + "src/model/user.py": [ + {"metadata": {"primary_name": "User", "path": "src/model/user.py"}} + ], + "src/service/user_service.py": [ + {"metadata": { + "primary_name": "UserService", + "imports": ["User"], + "path": "src/service/user_service.py" + }} + ], + "src/api/user_api.py": [ + {"metadata": { + "primary_name": "UserAPI", + "imports": ["UserService"], + "path": "src/api/user_api.py" + }} + ], + "unrelated/config.py": [ + {"metadata": {"primary_name": "Config", "path": "unrelated/config.py"}} + ] + }, + "related_definitions": { + "User": [{"metadata": {"path": "src/model/user.py"}}], + "UserService": [{"metadata": {"path": "src/service/user_service.py"}}] + }, + "class_context": {}, + "namespace_context": {} + } + + mock_rag = MockRAGClient(response=rag_response) + builder = DependencyGraphBuilder(rag_client=mock_rag) + + file_groups = [ + MockFileGroup([ + MockReviewFile("src/model/user.py"), + MockReviewFile("src/service/user_service.py"), + MockReviewFile("src/api/user_api.py"), + MockReviewFile("unrelated/config.py"), + ], priority="HIGH"), + ] + + batches = builder.get_smart_batches( + file_groups, + workspace="test", + project="test", + branches=["main"], + max_batch_size=3 + ) + + # Should have batches created + assert len(batches) > 0 + + # Get all paths per batch + batch_paths = [[f['file'].path for f in b] for b in batches] + + # Related files should be in the same batch if possible + # (User -> UserService -> UserAPI chain) + for batch in batch_paths: + if "src/model/user.py" in batch: + # If User is in batch, UserService should be too (they're related) + if len(batch) > 1: + assert "src/service/user_service.py" in batch or "src/api/user_api.py" in batch + + def 
test_relationship_summary(self): + """Test getting relationship summary.""" + builder = DependencyGraphBuilder() + + builder.nodes = { + "a.py": FileNode(path="a.py", priority="HIGH", related_files={"b.py"}), + "b.py": FileNode(path="b.py", priority="HIGH", related_files={"a.py"}), + "c.py": FileNode(path="c.py", priority="LOW", related_files=set()), + } + builder.relationships = [ + FileRelationship("a.py", "b.py", "definition", "SomeClass", 0.95), + ] + + summary = builder.get_relationship_summary() + + assert summary["total_files"] == 3 + assert summary["total_relationships"] == 1 + assert summary["files_with_relationships"] == 2 + assert summary["relationship_types"]["definition"] == 1 + + +class TestCreateSmartBatches: + """Tests for the convenience function.""" + + def test_without_rag_client(self): + """Test smart batches without RAG client (fallback).""" + file_groups = [ + MockFileGroup([ + MockReviewFile("a.py"), + MockReviewFile("b.py"), + ], priority="HIGH"), + ] + + batches = create_smart_batches( + file_groups, + workspace="test", + project="test", + branches=["main"], + rag_client=None, + max_batch_size=5 + ) + + assert len(batches) > 0 + total_files = sum(len(b) for b in batches) + assert total_files == 2 + + def test_with_rag_client(self): + """Test smart batches with RAG client.""" + mock_rag = MockRAGClient(response={ + "changed_files": {}, + "related_definitions": {}, + "class_context": {}, + "namespace_context": {} + }) + + file_groups = [ + MockFileGroup([ + MockReviewFile("a.py"), + MockReviewFile("b.py"), + MockReviewFile("c.py"), + ], priority="HIGH"), + ] + + batches = create_smart_batches( + file_groups, + workspace="test", + project="test", + branches=["main"], + rag_client=mock_rag, + max_batch_size=2 + ) + + assert mock_rag.call_count == 1 + assert len(batches) >= 1 + + +class TestMergingSmallBatches: + """Tests for batch merging optimization.""" + + def test_merge_same_priority_batches(self): + """Test that small batches of same priority get merged.""" + builder = DependencyGraphBuilder() + + batches = [ + [{"file": MockReviewFile("a.py"), "priority": "HIGH"}], + [{"file": MockReviewFile("b.py"), "priority": "HIGH"}], + [{"file": MockReviewFile("c.py"), "priority": "LOW"}], + ] + + merged = builder._merge_small_batches(batches, min_size=2, max_size=5) + + # HIGH priority files should be merged + high_batch = next((b for b in merged if b[0]["priority"] == "HIGH"), None) + if high_batch: + assert len(high_batch) <= 5 + + def test_no_merge_when_at_max_size(self): + """Test that batches at max size are not merged.""" + builder = DependencyGraphBuilder() + + batches = [ + [ + {"file": MockReviewFile("a.py"), "priority": "HIGH"}, + {"file": MockReviewFile("b.py"), "priority": "HIGH"}, + {"file": MockReviewFile("c.py"), "priority": "HIGH"}, + ], + [{"file": MockReviewFile("d.py"), "priority": "HIGH"}], + ] + + merged = builder._merge_small_batches(batches, min_size=2, max_size=3) + + # First batch is already at max, so d.py stays separate + assert any(len(b) == 3 for b in merged) diff --git a/python-ecosystem/mcp-client/utils/dependency_graph.py b/python-ecosystem/mcp-client/utils/dependency_graph.py new file mode 100644 index 00000000..e8d3cd0e --- /dev/null +++ b/python-ecosystem/mcp-client/utils/dependency_graph.py @@ -0,0 +1,630 @@ +""" +Dependency graph builder for intelligent file batching. + +SMART APPROACH: Leverages RAG's pre-indexed tree-sitter metadata to discover +file relationships instead of re-parsing diffs with regex. 
+ +The RAG system already has: +- semantic_names: function/method/class names +- imports: import statements parsed by tree-sitter +- extends: parent classes/interfaces +- parent_class: containing class +- namespace: package/namespace + +This module queries RAG to build a relationship graph, enabling intelligent +batching that keeps related files together for better cross-file context. +""" +import logging +from collections import defaultdict +from typing import Dict, List, Set, Any, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from service.rag.rag_client import RagClient + +from dataclasses import dataclass, field + +logger = logging.getLogger(__name__) + + +@dataclass +class FileNode: + """Represents a file in the dependency graph.""" + path: str + priority: str + # Relationships discovered from RAG tree-sitter metadata + related_files: Set[str] = field(default_factory=set) + imports_symbols: Set[str] = field(default_factory=set) + exports_symbols: Set[str] = field(default_factory=set) + parent_classes: Set[str] = field(default_factory=set) + namespaces: Set[str] = field(default_factory=set) + extends: Set[str] = field(default_factory=set) + focus_areas: List[str] = field(default_factory=list) + relationship_strength: float = 0.0 + + +@dataclass +class FileRelationship: + """Represents a relationship between two files.""" + source_file: str + target_file: str + relationship_type: str # 'definition', 'same_class', 'same_namespace' + matched_on: str + strength: float + + +class DependencyGraphBuilder: + """ + Builds a dependency graph using RAG's tree-sitter metadata or pre-computed relationships. + + SMART APPROACH (v2): When Java sends enrichment data with pre-computed relationships, + use those directly instead of querying RAG. This eliminates duplicate work since + Java already called RAG's /parse endpoint. + + Fallback: When no enrichment data available, query RAG's deterministic context API + which has the FULL file indexed with tree-sitter metadata. + + Relationship types discovered: + - definition: File A uses symbol defined in File B + - same_class: Files contain methods of the same class + - same_namespace: Files are in the same package/namespace + """ + + RELATIONSHIP_WEIGHTS = { + 'changed_file': 1.0, + 'definition': 0.95, + 'IMPORTS': 0.90, + 'EXTENDS': 0.95, + 'IMPLEMENTS': 0.95, + 'CALLS': 0.85, + 'class_context': 0.85, + 'namespace_context': 0.75, + 'SAME_PACKAGE': 0.60, + } + + def __init__(self, rag_client: Optional["RAGClient"] = None): + self.rag_client = rag_client + self.nodes: Dict[str, FileNode] = {} + self.relationships: List[FileRelationship] = [] + self._metadata_cache: Dict[str, Dict] = {} + + def build_graph_from_enrichment( + self, + file_groups: List[Any], + enrichment_data: Any, + ) -> Dict[str, FileNode]: + """ + Build dependency graph from pre-computed enrichment data sent by Java. 
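# Minimal sketch of the enrichment payload consumed by build_graph_from_enrichment,
# assuming only the attributes the method actually reads (sourceFile, targetFile,
# relationshipType, matchedOn, fileMetadata, has_data). These dataclasses are
# hypothetical stand-ins for illustration, not the real PrEnrichmentDataDto.
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class _RelSketch:
    sourceFile: str
    targetFile: str
    relationshipType: str          # e.g. "IMPORTS", "EXTENDS", "CALLS"
    matchedOn: Optional[str] = None

@dataclass
class _MetaSketch:
    path: str
    imports: List[str] = field(default_factory=list)
    semanticNames: List[str] = field(default_factory=list)
    extendsClasses: List[str] = field(default_factory=list)
    parentClass: Optional[str] = None
    namespace: Optional[str] = None

@dataclass
class _EnrichmentSketch:
    relationships: List[_RelSketch] = field(default_factory=list)
    fileMetadata: List[_MetaSketch] = field(default_factory=list)

    def has_data(self) -> bool:
        return bool(self.relationships or self.fileMetadata)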
+ + This is the preferred method when enrichment_data is available, as it: + - Eliminates redundant RAG calls (Java already parsed the files) + - Uses full file content for accurate relationship detection + - Provides relationships computed with proper AST parsing + + Args: + file_groups: List of FileGroup objects with files to analyze + enrichment_data: PrEnrichmentDataDto from Java with relationships and metadata + + Returns: + Dict of file paths to FileNode objects with relationships populated + """ + if not enrichment_data or not enrichment_data.has_data(): + logger.info("No enrichment data available, falling back to basic grouping") + return self._build_basic_graph(file_groups) + + # Initialize nodes from file groups + for group in file_groups: + for f in group.files: + self.nodes[f.path] = FileNode( + path=f.path, + priority=group.priority, + focus_areas=f.focus_areas if hasattr(f, 'focus_areas') else [] + ) + + # Process pre-computed relationships + relationships_by_file: Dict[str, Set[str]] = defaultdict(set) + + for rel in enrichment_data.relationships: + source = rel.sourceFile + target = rel.targetFile + rel_type = rel.relationshipType.value if hasattr(rel.relationshipType, 'value') else str(rel.relationshipType) + + # Only add relationships between files we're analyzing + if source in self.nodes and target in self.nodes: + relationships_by_file[source].add(target) + relationships_by_file[target].add(source) + + weight = self.RELATIONSHIP_WEIGHTS.get(rel_type, 0.5) + self.relationships.append(FileRelationship( + source_file=source, + target_file=target, + relationship_type=rel_type, + matched_on=rel.matchedOn or "", + strength=weight + )) + + # Process metadata to populate node symbols + for meta in enrichment_data.fileMetadata: + if meta.path in self.nodes: + node = self.nodes[meta.path] + if meta.imports: + node.imports_symbols.update(meta.imports) + if meta.semanticNames: + node.exports_symbols.update(meta.semanticNames) + if meta.extendsClasses: + node.extends.update(meta.extendsClasses) + if meta.parentClass: + node.parent_classes.add(meta.parentClass) + if meta.namespace: + node.namespaces.add(meta.namespace) + + # Update nodes with discovered relationships + for file_path, related in relationships_by_file.items(): + if file_path in self.nodes: + self.nodes[file_path].related_files.update(related) + self.nodes[file_path].relationship_strength = self._calculate_strength( + file_path, related + ) + + logger.info( + f"Dependency graph built from enrichment: {len(self.nodes)} files, " + f"{len(self.relationships)} relationships" + ) + + return self.nodes + + def build_graph_from_rag( + self, + file_groups: List[Any], + workspace: str, + project: str, + branches: List[str], + ) -> Dict[str, FileNode]: + """ + Build dependency graph by querying RAG's deterministic context API. 
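# Minimal sketch of the deterministic-context payload that
# _extract_relationships_from_rag expects; the top-level keys mirror the .get()
# calls in that method, while the paths, symbols and metadata values here are
# hypothetical examples.
_example_rag_response = {
    "changed_files": {
        "src/service/user_service.py": [
            {"metadata": {"primary_name": "UserService",
                          "semantic_names": ["get_user"],
                          "imports": ["AuthService"],
                          "namespace": "src.service"}}
        ]
    },
    "related_definitions": {
        "AuthService": [{"metadata": {"path": "src/service/auth_service.py"}}]
    },
    "class_context": {},       # parent_class -> chunks, yields 'same_class' edges
    "namespace_context": {},   # namespace -> chunks, yields 'same_namespace' edges
}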
+ + This leverages tree-sitter metadata extracted during indexing: + - imports, extends, parent_class, namespace, semantic_names + """ + if not self.rag_client: + logger.warning("No RAG client provided, falling back to basic grouping") + return self._build_basic_graph(file_groups) + + # Collect all file paths + all_file_paths = [] + file_priority_map = {} + file_info_map = {} + + for group in file_groups: + for f in group.files: + all_file_paths.append(f.path) + file_priority_map[f.path] = group.priority + file_info_map[f.path] = f + self.nodes[f.path] = FileNode( + path=f.path, + priority=group.priority, + focus_areas=f.focus_areas if hasattr(f, 'focus_areas') else [] + ) + + if not all_file_paths: + return self.nodes + + # Query RAG for deterministic context + try: + rag_response = self.rag_client.get_deterministic_context( + workspace=workspace, + project=project, + branches=branches, + file_paths=all_file_paths, + limit_per_file=15 + ) + self._metadata_cache['last_response'] = rag_response + except Exception as e: + logger.warning(f"RAG query failed, falling back to basic grouping: {e}") + return self._build_basic_graph(file_groups) + + # Extract relationships from RAG response + self._extract_relationships_from_rag(rag_response, all_file_paths) + + logger.info( + f"Dependency graph built: {len(self.nodes)} files, " + f"{len(self.relationships)} relationships" + ) + + return self.nodes + + def _extract_relationships_from_rag( + self, + rag_response: Dict, + changed_file_paths: List[str] + ) -> None: + """Extract file relationships from RAG deterministic context response.""" + changed_file_set = set(changed_file_paths) + file_relationships: Dict[str, Set[str]] = defaultdict(set) + + # Process changed_files to extract metadata + changed_files = rag_response.get('changed_files', {}) + for file_path, chunks in changed_files.items(): + norm_path = file_path.lstrip('/') + if norm_path in self.nodes: + for chunk in chunks: + metadata = chunk.get('metadata', {}) + + # Extract symbols this file exports (defines) + if metadata.get('primary_name'): + self.nodes[norm_path].exports_symbols.add(metadata['primary_name']) + if metadata.get('semantic_names'): + self.nodes[norm_path].exports_symbols.update(metadata['semantic_names']) + + # Extract what this file imports + if metadata.get('imports'): + for imp in metadata['imports']: + if isinstance(imp, str): + parts = imp.replace(';', '').split('\\') + if parts: + self.nodes[norm_path].imports_symbols.add(parts[-1].strip()) + + # Track class/namespace membership + if metadata.get('parent_class'): + self.nodes[norm_path].parent_classes.add(metadata['parent_class']) + if metadata.get('namespace'): + self.nodes[norm_path].namespaces.add(metadata['namespace']) + if metadata.get('extends'): + self.nodes[norm_path].extends.update(metadata['extends']) + + # Process related_definitions + related_definitions = rag_response.get('related_definitions', {}) + for symbol, chunks in related_definitions.items(): + for chunk in chunks: + metadata = chunk.get('metadata', {}) + related_path = metadata.get('path', '').lstrip('/') + + if related_path and related_path in self.nodes: + for file_path in changed_file_set: + norm_path = file_path.lstrip('/') + if norm_path in self.nodes: + node = self.nodes[norm_path] + if symbol in node.imports_symbols or symbol in node.exports_symbols: + file_relationships[norm_path].add(related_path) + file_relationships[related_path].add(norm_path) + self.relationships.append(FileRelationship( + source_file=norm_path, + 
target_file=related_path, + relationship_type='definition', + matched_on=symbol, + strength=self.RELATIONSHIP_WEIGHTS['definition'] + )) + + # Process class_context + class_context = rag_response.get('class_context', {}) + for parent_class, chunks in class_context.items(): + class_files = set() + for chunk in chunks: + metadata = chunk.get('metadata', {}) + related_path = metadata.get('path', '').lstrip('/') + if related_path in self.nodes: + class_files.add(related_path) + + for f1 in class_files: + for f2 in class_files: + if f1 != f2: + file_relationships[f1].add(f2) + if f1 < f2: + self.relationships.append(FileRelationship( + source_file=f1, + target_file=f2, + relationship_type='same_class', + matched_on=parent_class, + strength=self.RELATIONSHIP_WEIGHTS['class_context'] + )) + + # Process namespace_context + namespace_context = rag_response.get('namespace_context', {}) + for namespace, chunks in namespace_context.items(): + ns_files = set() + for chunk in chunks: + metadata = chunk.get('metadata', {}) + related_path = metadata.get('path', '').lstrip('/') + if related_path in self.nodes: + ns_files.add(related_path) + + for f1 in ns_files: + for f2 in ns_files: + if f1 != f2: + file_relationships[f1].add(f2) + if f1 < f2: + self.relationships.append(FileRelationship( + source_file=f1, + target_file=f2, + relationship_type='same_namespace', + matched_on=namespace, + strength=self.RELATIONSHIP_WEIGHTS['namespace_context'] + )) + + # Update nodes with discovered relationships + for file_path, related in file_relationships.items(): + if file_path in self.nodes: + self.nodes[file_path].related_files.update(related) + self.nodes[file_path].relationship_strength = self._calculate_strength( + file_path, related + ) + + def _calculate_strength(self, file_path: str, related_files: Set[str]) -> float: + total_strength = 0.0 + for rel in self.relationships: + if rel.source_file == file_path or rel.target_file == file_path: + total_strength += rel.strength + return min(total_strength, 5.0) + + def _build_basic_graph(self, file_groups: List[Any]) -> Dict[str, FileNode]: + """Fallback: build basic graph without RAG (by directory).""" + for group in file_groups: + for f in group.files: + self.nodes[f.path] = FileNode( + path=f.path, + priority=group.priority, + focus_areas=f.focus_areas if hasattr(f, 'focus_areas') else [] + ) + + # Files in same directory are related + dir_files: Dict[str, List[str]] = defaultdict(list) + for path in self.nodes: + dir_path = '/'.join(path.split('/')[:-1]) if '/' in path else '' + dir_files[dir_path].append(path) + + for dir_path, files in dir_files.items(): + if len(files) > 1: + for f1 in files: + for f2 in files: + if f1 != f2: + self.nodes[f1].related_files.add(f2) + + return self.nodes + + def get_connected_components(self) -> List[Set[str]]: + """Find connected components in the dependency graph.""" + visited = set() + components = [] + + def dfs(node_path: str, component: Set[str]): + if node_path in visited: + return + visited.add(node_path) + component.add(node_path) + + node = self.nodes.get(node_path) + if not node: + return + + for related_path in node.related_files: + if related_path in self.nodes: + dfs(related_path, component) + + for path in self.nodes: + if path not in visited: + component: Set[str] = set() + dfs(path, component) + if component: + components.append(component) + + return components + + def get_smart_batches( + self, + file_groups: List[Any], + workspace: str, + project: str, + branches: List[str], + max_batch_size: int = 7, + 
min_batch_size: int = 3, + enrichment_data: Any = None + ) -> List[List[Dict[str, Any]]]: + """ + Create intelligent batches that keep related files together. + + Strategy: + 1. If enrichment_data is available, use pre-computed relationships from Java + 2. Otherwise, query RAG to discover file relationships via tree-sitter metadata + 3. Find connected components (files that are related) + 4. Batch files within components together + 5. For large components, split by priority while keeping related files together + + Args: + file_groups: List of FileGroup objects with files + workspace: Repository workspace/owner + project: Repository slug + branches: Branch names for context + max_batch_size: Maximum files per batch + min_batch_size: Minimum files per batch + enrichment_data: Optional PrEnrichmentDataDto from Java with pre-computed relationships + """ + # Use enrichment data if available, otherwise fall back to RAG + if enrichment_data and hasattr(enrichment_data, 'has_data') and enrichment_data.has_data(): + logger.info("Using pre-computed enrichment data for dependency graph") + self.build_graph_from_enrichment(file_groups, enrichment_data) + else: + self.build_graph_from_rag(file_groups, workspace, project, branches) + components = self.get_connected_components() + + logger.info( + f"Dependency analysis: {len(self.nodes)} files, " + f"{len(components)} connected components, " + f"{len(self.relationships)} relationships" + ) + + file_priority_map = {} + file_info_map = {} + for group in file_groups: + for f in group.files: + file_priority_map[f.path] = group.priority + file_info_map[f.path] = f + + batches = [] + processed_files = set() + priority_order = ['CRITICAL', 'HIGH', 'MEDIUM', 'LOW'] + + def component_sort_key(comp): + max_priority = min( + priority_order.index(file_priority_map.get(f, 'LOW')) + for f in comp + ) + return (-len(comp), max_priority) + + for component in sorted(components, key=component_sort_key): + if all(f in processed_files for f in component): + continue + + component_files = [f for f in component if f not in processed_files] + if not component_files: + continue + + component_files_sorted = sorted( + component_files, + key=lambda f: ( + -self.nodes[f].relationship_strength, + priority_order.index(file_priority_map.get(f, 'LOW')), + f + ) + ) + + current_batch = [] + for file_path in component_files_sorted: + file_info = file_info_map.get(file_path) + if not file_info: + continue + + node = self.nodes[file_path] + current_batch.append({ + "file": file_info, + "priority": file_priority_map.get(file_path, 'MEDIUM'), + "has_relationships": len(node.related_files) > 0, + "relationship_strength": node.relationship_strength, + "related_in_batch": [ + r for r in node.related_files + if r in {b['file'].path for b in current_batch} + ] + }) + processed_files.add(file_path) + + if len(current_batch) >= max_batch_size: + batches.append(current_batch) + current_batch = [] + + if current_batch: + batches.append(current_batch) + + # Handle orphan files + orphan_files = [] + for group in file_groups: + for f in group.files: + if f.path not in processed_files: + orphan_files.append({ + "file": f, + "priority": group.priority, + "has_relationships": False, + "relationship_strength": 0.0, + "related_in_batch": [] + }) + processed_files.add(f.path) + + if orphan_files: + orphan_files_sorted = sorted( + orphan_files, + key=lambda x: (priority_order.index(x['priority']), x['file'].path) + ) + for i in range(0, len(orphan_files_sorted), max_batch_size): + 
batches.append(orphan_files_sorted[i:i + max_batch_size]) + + batches = self._merge_small_batches(batches, min_batch_size, max_batch_size) + + logger.info(f"Smart batching created {len(batches)} batches from {len(self.nodes)} files") + for i, batch in enumerate(batches): + paths = [b['file'].path for b in batch] + rel_count = sum(1 for b in batch if b.get('has_relationships')) + logger.debug(f"Batch {i+1}: {len(batch)} files ({rel_count} with relationships): {paths}") + + return batches + + def _merge_small_batches( + self, + batches: List[List[Dict[str, Any]]], + min_size: int, + max_size: int + ) -> List[List[Dict[str, Any]]]: + """Merge small batches if they have the same priority.""" + if not batches: + return batches + + priority_batches: Dict[str, List[List[Dict[str, Any]]]] = defaultdict(list) + for batch in batches: + if not batch: + continue + priorities = [b['priority'] for b in batch] + dominant = max(set(priorities), key=priorities.count) + priority_batches[dominant].append(batch) + + merged = [] + for priority, p_batches in priority_batches.items(): + current_merged = [] + for batch in p_batches: + if len(current_merged) + len(batch) <= max_size: + current_merged.extend(batch) + else: + if current_merged: + merged.append(current_merged) + current_merged = batch[:] + if current_merged: + merged.append(current_merged) + + return merged + + def get_relationship_summary(self) -> Dict[str, Any]: + """Get a summary of discovered relationships.""" + relationship_types = defaultdict(int) + for rel in self.relationships: + relationship_types[rel.relationship_type] += 1 + + return { + "total_files": len(self.nodes), + "total_relationships": len(self.relationships), + "relationship_types": dict(relationship_types), + "files_with_relationships": sum( + 1 for node in self.nodes.values() + if len(node.related_files) > 0 + ), + "avg_relationships_per_file": ( + len(self.relationships) * 2 / len(self.nodes) + if self.nodes else 0 + ) + } + + +def create_smart_batches( + file_groups: List[Any], + workspace: str, + project: str, + branches: List[str], + rag_client: Optional["RAGClient"] = None, + max_batch_size: int = 7, + enrichment_data: Any = None +) -> List[List[Dict[str, Any]]]: + """ + Convenience function to create smart batches from file groups. 
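# Usage sketch, assuming SimpleNamespace objects as stand-ins for the
# FileGroup/ReviewFile inputs (only .files, .priority, .path and .focus_areas
# are read); passing rag_client=None exercises the directory-based fallback,
# so no RAG service is required to run it.
from types import SimpleNamespace
from utils.dependency_graph import create_smart_batches

_files = [SimpleNamespace(path=p, focus_areas=[])
          for p in ("src/service/a.py", "src/service/b.py")]
_groups = [SimpleNamespace(files=_files, priority="HIGH")]

_batches = create_smart_batches(
    _groups,
    workspace="acme",          # repository owner / workspace (example value)
    project="web-shop",        # repository slug (example value)
    branches=["main"],
    rag_client=None,           # fall back to same-directory grouping
    max_batch_size=5,
)
print([[entry["file"].path for entry in batch] for batch in _batches])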
+ + Args: + file_groups: List of FileGroup objects with files + workspace: Repository workspace/owner + project: Repository slug + branches: Branch names for context + rag_client: Optional RAG client for relationship discovery + max_batch_size: Maximum files per batch + enrichment_data: Optional PrEnrichmentDataDto with pre-computed relationships from Java + """ + builder = DependencyGraphBuilder(rag_client=rag_client) + return builder.get_smart_batches( + file_groups, + workspace, + project, + branches, + max_batch_size, + enrichment_data=enrichment_data + ) diff --git a/python-ecosystem/mcp-client/utils/prompts/prompt_builder.py b/python-ecosystem/mcp-client/utils/prompts/prompt_builder.py index 4edba35f..7883d90d 100644 --- a/python-ecosystem/mcp-client/utils/prompts/prompt_builder.py +++ b/python-ecosystem/mcp-client/utils/prompts/prompt_builder.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional import json -from model.models import IssueDTO +from model.dtos import IssueDTO from utils.prompts.prompt_constants import ( ADDITIONAL_INSTRUCTIONS, BRANCH_REVIEW_PROMPT_TEMPLATE, diff --git a/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py b/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py index 1c86a621..a1aeb473 100644 --- a/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py +++ b/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py @@ -275,7 +275,7 @@ Create a prioritized review plan in this JSON format: {{ - "analysis_summary": "2-sentence overview of PR scope and risk level", + "analysis_summary": "overview of PR scope and risk level", "file_groups": [ {{ "group_id": "GROUP_A_SECURITY", @@ -350,6 +350,7 @@ When in doubt, assume the code is correct - the developer can see the full file, you cannot. 
{incremental_instructions} +{pr_files_context} PROJECT RULES: {project_rules} diff --git a/python-ecosystem/rag-pipeline/Dockerfile b/python-ecosystem/rag-pipeline/Dockerfile index 46f619c5..5537e2e7 100644 --- a/python-ecosystem/rag-pipeline/Dockerfile +++ b/python-ecosystem/rag-pipeline/Dockerfile @@ -34,6 +34,13 @@ ENV LLAMA_INDEX_CACHE_DIR=/tmp/.llama_index # Allow concurrent indexing by running multiple Uvicorn workers ENV UVICORN_WORKERS=4 +# CPU Threading Optimization (can be overridden via .env) +# These control parallelism for numerical operations +ENV OMP_NUM_THREADS=4 +ENV MKL_NUM_THREADS=4 +ENV OPENBLAS_NUM_THREADS=4 +ENV NUMEXPR_NUM_THREADS=4 + EXPOSE 8001 CMD ["python", "main.py"] diff --git a/python-ecosystem/rag-pipeline/main.py b/python-ecosystem/rag-pipeline/main.py index 7b3e146f..4918eaa5 100644 --- a/python-ecosystem/rag-pipeline/main.py +++ b/python-ecosystem/rag-pipeline/main.py @@ -17,32 +17,52 @@ # Validate critical environment variables before starting def validate_environment(): """Validate that required environment variables are set""" + embedding_provider = os.environ.get("EMBEDDING_PROVIDER", "ollama").lower() openrouter_key = os.environ.get("OPENROUTER_API_KEY", "") logger.info("=" * 60) logger.info("RAG Pipeline Starting - Environment Check") logger.info("=" * 60) - logger.info(f"QDRANT_URL: {os.getenv('QDRANT_URL', 'NOT SET')}") + logger.info(f"QDRANT_URL: {os.getenv('QDRANT_URL', 'http://localhost:6333')}") logger.info(f"QDRANT_COLLECTION_PREFIX: {os.getenv('QDRANT_COLLECTION_PREFIX', 'codecrow')}") - logger.info(f"OPENROUTER_MODEL: {os.getenv('OPENROUTER_MODEL', 'openai/text-embedding-3-small')}") + logger.info(f"EMBEDDING_PROVIDER: {embedding_provider}") - if not openrouter_key or openrouter_key.strip() == "": - logger.error("=" * 60) - logger.error("CRITICAL ERROR: OPENROUTER_API_KEY not set!") - logger.error("=" * 60) - logger.error("The OPENROUTER_API_KEY environment variable is required") - logger.error("but was not found or is empty.") - logger.error("") - logger.error("To fix this:") - logger.error("1. Set the environment variable:") - logger.error(" export OPENROUTER_API_KEY='sk-or-v1-...'") - logger.error("2. Or add it to docker-compose.yml") - logger.error("3. Or create a .env file with: OPENROUTER_API_KEY=sk-or-v1-...") - logger.error("=" * 60) - sys.exit(1) + if embedding_provider == "ollama": + ollama_url = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434') + ollama_model = os.getenv('OLLAMA_EMBEDDING_MODEL', 'qwen3-embedding:0.6b') + logger.info(f"OLLAMA_BASE_URL: {ollama_url}") + logger.info(f"OLLAMA_EMBEDDING_MODEL: {ollama_model}") + logger.info("=" * 60) + logger.info("Using Ollama for local embeddings ✓") + logger.info("Make sure Ollama is running: ollama serve") + logger.info(f"And model is pulled: ollama pull {ollama_model}") + logger.info("=" * 60) + elif embedding_provider == "openrouter": + logger.info(f"OPENROUTER_MODEL: {os.getenv('OPENROUTER_MODEL', 'qwen/qwen3-embedding-8b')}") + + if not openrouter_key or openrouter_key.strip() == "": + logger.error("=" * 60) + logger.error("CRITICAL ERROR: OPENROUTER_API_KEY not set!") + logger.error("=" * 60) + logger.error("The OPENROUTER_API_KEY environment variable is required") + logger.error("when EMBEDDING_PROVIDER=openrouter but was not found or is empty.") + logger.error("") + logger.error("To fix this:") + logger.error("1. Set the environment variable:") + logger.error(" export OPENROUTER_API_KEY='sk-or-v1-...'") + logger.error("2. 
Or add it to docker-compose.yml") + logger.error("3. Or create a .env file with: OPENROUTER_API_KEY=sk-or-v1-...") + logger.error("4. Or switch to local embeddings: EMBEDDING_PROVIDER=ollama") + logger.error("=" * 60) + sys.exit(1) + + logger.info(f"OPENROUTER_API_KEY: {openrouter_key[:15]}...{openrouter_key[-4:]} ✓") + logger.info("=" * 60) + logger.info("Using OpenRouter for cloud embeddings ✓") + logger.info("=" * 60) + else: + logger.warning(f"Unknown EMBEDDING_PROVIDER '{embedding_provider}', defaulting to 'ollama'") - logger.info(f"OPENROUTER_API_KEY: {openrouter_key[:15]}...{openrouter_key[-4:]} ✓") - logger.info("=" * 60) logger.info("Environment validation passed ✓") logger.info("Using Qdrant for vector storage") logger.info("=" * 60) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/api/api.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/api/api.py index 2d15b12d..8cb66131 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/api/api.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/api/api.py @@ -106,6 +106,160 @@ def health(): return {"status": "healthy"} +# ============================================================================= +# PARSE ENDPOINTS (AST metadata extraction without indexing) +# ============================================================================= + +class ParseFileRequest(BaseModel): + """Request to parse a single file and extract AST metadata.""" + path: str + content: str + language: Optional[str] = None # Auto-detected if not provided + + +class ParseBatchRequest(BaseModel): + """Request to parse multiple files in batch.""" + files: List[ParseFileRequest] + + +class ParsedFileMetadata(BaseModel): + """AST metadata extracted from a file.""" + path: str + language: Optional[str] = None + imports: List[str] = [] + extends: List[str] = [] + implements: List[str] = [] + semantic_names: List[str] = [] # Function/class/method names + parent_class: Optional[str] = None + namespace: Optional[str] = None + calls: List[str] = [] # Called functions/methods + success: bool = True + error: Optional[str] = None + + +@app.post("/parse", response_model=ParsedFileMetadata) +def parse_file(request: ParseFileRequest): + """ + Parse a single file and extract AST metadata WITHOUT indexing. + + Returns tree-sitter extracted metadata: + - imports: Import statements + - extends: Parent classes/interfaces + - implements: Implemented interfaces + - semantic_names: Function/class/method names defined + - namespace: Package/namespace + - calls: Called functions/methods + + Used by Java pipeline-agent to build dependency graph. + """ + try: + from ..core.splitter import ASTCodeSplitter + from ..core.splitter.languages import get_language_from_path, EXTENSION_TO_LANGUAGE + + # Detect language + language = request.language + if not language: + lang_enum = get_language_from_path(request.path) + language = lang_enum.value if lang_enum else None + + if not language: + # Try to infer from extension + ext = '.' + request.path.rsplit('.', 1)[-1] if '.' 
in request.path else '' + language = EXTENSION_TO_LANGUAGE.get(ext, {}).get('name') + + splitter = ASTCodeSplitter( + max_chunk_size=50000, # Large to avoid splitting + enrich_embedding_text=False + ) + + # Create a minimal document for parsing + from llama_index.core.schema import Document as LlamaDocument + doc = LlamaDocument(text=request.content, metadata={'path': request.path}) + + # Parse and extract chunks + nodes = splitter.split_documents([doc]) + + # Aggregate metadata from all chunks + imports = set() + extends = set() + implements = set() + semantic_names = set() + calls = set() + namespace = None + parent_classes = set() + + for node in nodes: + meta = node.metadata + + if meta.get('imports'): + imports.update(meta['imports']) + if meta.get('extends'): + extends.update(meta['extends']) + if meta.get('implements'): + implements.update(meta['implements']) + if meta.get('semantic_names'): + semantic_names.update(meta['semantic_names']) + if meta.get('calls'): + calls.update(meta['calls']) + if meta.get('namespace') and not namespace: + namespace = meta['namespace'] + if meta.get('parent_class'): + parent_classes.add(meta['parent_class']) + + # Primary parent class (first one found) + parent_class = list(parent_classes)[0] if parent_classes else None + + return ParsedFileMetadata( + path=request.path, + language=language, + imports=sorted(list(imports)), + extends=sorted(list(extends)), + implements=sorted(list(implements)), + semantic_names=sorted(list(semantic_names)), + parent_class=parent_class, + namespace=namespace, + calls=sorted(list(calls)), + success=True + ) + + except Exception as e: + logger.warning(f"Error parsing file {request.path}: {e}") + return ParsedFileMetadata( + path=request.path, + success=False, + error=str(e) + ) + + +@app.post("/parse/batch") +def parse_files_batch(request: ParseBatchRequest): + """ + Parse multiple files and extract AST metadata in batch. + + Returns list of ParsedFileMetadata for each file. + Continues processing even if individual files fail. + """ + results = [] + + for file_req in request.files: + result = parse_file(file_req) + results.append(result) + + successful = sum(1 for r in results if r.success) + failed = len(results) - successful + + logger.info(f"Batch parse: {successful} successful, {failed} failed out of {len(results)} files") + + return { + "results": results, + "summary": { + "total": len(results), + "successful": successful, + "failed": failed + } + } + + @app.get("/limits") def get_limits(): """Get current RAG indexing limits (for free plan info)""" diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/embedding_factory.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/embedding_factory.py new file mode 100644 index 00000000..899cd354 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/embedding_factory.py @@ -0,0 +1,93 @@ +""" +Embedding factory for creating embedding models based on configuration. +Supports switching between local (Ollama) and cloud (OpenRouter) providers. +""" + +import logging +from typing import Union + +from llama_index.core.base.embeddings.base import BaseEmbedding + +from ..models.config import RAGConfig +from .ollama_embedding import OllamaEmbedding +from .openrouter_embedding import OpenRouterEmbedding + +logger = logging.getLogger(__name__) + + +def create_embedding_model(config: RAGConfig) -> BaseEmbedding: + """ + Create an embedding model based on the configuration. 
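# Configuration sketch, assuming the environment variable names read by
# RAGConfig in models/config.py; the "rag_pipeline.*" import paths and the
# model tag are example assumptions, and a reachable Ollama server is only
# needed once real embeddings are requested.
import os

os.environ.setdefault("EMBEDDING_PROVIDER", "ollama")
os.environ.setdefault("OLLAMA_EMBEDDING_MODEL", "qwen3-embedding:0.6b")

from rag_pipeline.models.config import RAGConfig
from rag_pipeline.core.embedding_factory import (
    create_embedding_model,
    get_embedding_model_info,
)

_config = RAGConfig()
print(get_embedding_model_info(_config))        # provider / model / dimension summary
_embed_model = create_embedding_model(_config)  # OllamaEmbedding in this setup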
+ + Args: + config: RAGConfig with embedding provider settings + + Returns: + BaseEmbedding instance (OllamaEmbedding or OpenRouterEmbedding) + """ + provider = config.embedding_provider.lower() + + if provider == "ollama": + logger.info(f"Creating Ollama embedding model: {config.ollama_model}") + return OllamaEmbedding( + model=config.ollama_model, + base_url=config.ollama_base_url, + timeout=120.0, + expected_dim=config.embedding_dim + ) + + elif provider == "openrouter": + logger.info(f"Creating OpenRouter embedding model: {config.openrouter_model}") + return OpenRouterEmbedding( + api_key=config.openrouter_api_key, + model=config.openrouter_model, + api_base=config.openrouter_base_url, + timeout=60.0, + max_retries=3, + expected_dim=config.embedding_dim + ) + + else: + logger.warning(f"Unknown embedding provider '{provider}', defaulting to Ollama") + return OllamaEmbedding( + model=config.ollama_model, + base_url=config.ollama_base_url, + timeout=120.0, + expected_dim=config.embedding_dim + ) + + +def get_embedding_model_info(config: RAGConfig) -> dict: + """ + Get information about the configured embedding model. + + Args: + config: RAGConfig with embedding provider settings + + Returns: + Dictionary with provider info + """ + provider = config.embedding_provider.lower() + + if provider == "ollama": + return { + "provider": "ollama", + "model": config.ollama_model, + "base_url": config.ollama_base_url, + "embedding_dim": config.embedding_dim, + "type": "local" + } + elif provider == "openrouter": + return { + "provider": "openrouter", + "model": config.openrouter_model, + "base_url": config.openrouter_base_url, + "embedding_dim": config.embedding_dim, + "type": "cloud" + } + else: + return { + "provider": provider, + "embedding_dim": config.embedding_dim, + "type": "unknown" + } diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/branch_manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/branch_manager.py index 9b1eb7ec..5ea50f45 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/branch_manager.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/branch_manager.py @@ -157,7 +157,42 @@ def stream_copy_points_to_collection( """Stream copy points from one collection to another, excluding a branch. Memory-efficient alternative to preserve_other_branch_points + copy_points_to_collection. + Skips copying if vector dimensions don't match between collections. """ + # Check vector dimensions match before copying + try: + source_info = self.client.get_collection(source_collection) + target_info = self.client.get_collection(target_collection) + + # Get dimensions from vector config + source_dim = None + target_dim = None + + if hasattr(source_info.config.params, 'vectors'): + vectors_config = source_info.config.params.vectors + if hasattr(vectors_config, 'size'): + source_dim = vectors_config.size + elif isinstance(vectors_config, dict) and '' in vectors_config: + source_dim = vectors_config[''].size + + if hasattr(target_info.config.params, 'vectors'): + vectors_config = target_info.config.params.vectors + if hasattr(vectors_config, 'size'): + target_dim = vectors_config.size + elif isinstance(vectors_config, dict) and '' in vectors_config: + target_dim = vectors_config[''].size + + if source_dim and target_dim and source_dim != target_dim: + logger.warning( + f"Skipping branch preservation: dimension mismatch " + f"(source: {source_dim}, target: {target_dim}). 
" + f"Re-embedding required for all branches." + ) + return 0 + + except Exception as e: + logger.warning(f"Could not verify collection dimensions: {e}") + # Continue anyway - will fail at upsert if dimensions don't match total_copied = 0 for batch in self.preserve_other_branch_points(source_collection, exclude_branch, batch_size): diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py index 363962fc..259ccea5 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py @@ -242,7 +242,6 @@ def index_repository( self.collection_manager.delete_collection(temp_collection_name) raise e finally: - del existing_other_branch_points gc.collect() self.stats_manager.store_metadata( diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/manager.py index d6bd371b..abd1d174 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/manager.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/manager.py @@ -14,7 +14,7 @@ from ...utils.utils import make_namespace, make_project_namespace from ..splitter import ASTCodeSplitter from ..loader import DocumentLoader -from ..openrouter_embedding import OpenRouterEmbedding +from ..embedding_factory import create_embedding_model, get_embedding_model_info from .collection_manager import CollectionManager from .branch_manager import BranchManager @@ -38,15 +38,12 @@ def __init__(self, config: RAGConfig): self.qdrant_client = QdrantClient(url=config.qdrant_url) logger.info(f"Connected to Qdrant at {config.qdrant_url}") - # Embedding model - self.embed_model = OpenRouterEmbedding( - api_key=config.openrouter_api_key, - model=config.openrouter_model, - api_base=config.openrouter_base_url, - timeout=60.0, - max_retries=3, - expected_dim=config.embedding_dim - ) + # Embedding model (supports Ollama and OpenRouter via factory) + embed_info = get_embedding_model_info(config) + logger.info(f"Using embedding provider: {embed_info['provider']} ({embed_info['type']})") + logger.info(f"Embedding model: {embed_info['model']}, dimension: {embed_info['embedding_dim']}") + + self.embed_model = create_embedding_model(config) # Global settings Settings.embed_model = self.embed_model diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/ollama_embedding.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/ollama_embedding.py new file mode 100644 index 00000000..5155a2f8 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/ollama_embedding.py @@ -0,0 +1,252 @@ +""" +Ollama embedding wrapper for LlamaIndex. +Supports local embedding models running via Ollama. +""" + +import os +from typing import Any, List, Optional +from llama_index.core.base.embeddings.base import BaseEmbedding +import httpx +import logging + +from ..models.config import get_embedding_dim_for_model + +logger = logging.getLogger(__name__) + +# Default batch size - can be overridden via OLLAMA_BATCH_SIZE env var +DEFAULT_BATCH_SIZE = int(os.getenv("OLLAMA_BATCH_SIZE", "100")) +# Default timeout - can be overridden via OLLAMA_TIMEOUT env var +DEFAULT_TIMEOUT = float(os.getenv("OLLAMA_TIMEOUT", "120")) + + +class OllamaEmbedding(BaseEmbedding): + """ + Custom embedding class for Ollama API. 
+ + Supports local embedding models like qwen3-embedding:0.6b, nomic-embed-text, etc. + """ + + def __init__( + self, + model: str = "qwen3-embedding:0.6b", + base_url: str = "http://localhost:11434", + timeout: float = None, + embed_batch_size: int = None, + expected_dim: Optional[int] = None, + **kwargs: Any + ): + # Use env-configured defaults if not specified + if timeout is None: + timeout = DEFAULT_TIMEOUT + if embed_batch_size is None: + embed_batch_size = DEFAULT_BATCH_SIZE + + super().__init__(embed_batch_size=embed_batch_size, **kwargs) + + # Determine expected embedding dimension + if expected_dim is not None: + embedding_dim = expected_dim + else: + embedding_dim = get_embedding_dim_for_model(model) + + logger.info(f"OllamaEmbedding: Initializing with model: {model}") + logger.info(f"OllamaEmbedding: Expected embedding dimension: {embedding_dim}") + logger.info(f"OllamaEmbedding: Base URL: {base_url}") + logger.info(f"OllamaEmbedding: Batch size: {embed_batch_size}") + + # Store config using object.__setattr__ to bypass Pydantic validation + object.__setattr__(self, '_config', { + "model": model, + "base_url": base_url.rstrip('/'), + "timeout": timeout, + "embed_batch_size": embed_batch_size, + "embedding_dim": embedding_dim + }) + + # Initialize HTTP client + object.__setattr__(self, '_client', httpx.Client( + base_url=base_url.rstrip('/'), + timeout=timeout + )) + + # Test connection + self._test_connection() + logger.info(f"Ollama embeddings initialized successfully") + + def _test_connection(self): + """Test connection to Ollama server.""" + try: + response = self._client.get("/api/tags") + if response.status_code == 200: + models = response.json().get("models", []) + model_names = [m.get("name", "") for m in models] + logger.info(f"Connected to Ollama. Available models: {model_names}") + + # Check if our model is available + model_name = self._config["model"] + if not any(model_name in name or name in model_name for name in model_names): + logger.warning(f"Model '{model_name}' may not be available. Pull it with: ollama pull {model_name}") + else: + logger.warning(f"Could not list Ollama models: {response.status_code}") + except Exception as e: + logger.warning(f"Could not connect to Ollama at {self._config['base_url']}: {e}") + logger.warning("Make sure Ollama is running: ollama serve") + + def close(self): + """Close the HTTP client and free resources.""" + try: + if hasattr(self, '_client') and self._client: + self._client.close() + logger.info("Ollama embedding client closed") + except Exception as e: + logger.warning(f"Error closing Ollama client: {e}") + + def __del__(self): + """Destructor to ensure client is closed.""" + self.close() + + @property + def model(self) -> str: + """Get the model name.""" + return self._config["model"] + + def _get_query_embedding(self, query: str) -> List[float]: + """Get embedding for a query text.""" + return self._get_embedding(query) + + def _get_text_embedding(self, text: str) -> List[float]: + """Get embedding for a text.""" + return self._get_embedding(text) + + def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: + """ + Get embeddings for multiple texts using batch API. + Uses /api/embed which supports array input for batching. 
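# Sketch of the raw batch call this method wraps: Ollama's /api/embed accepts a
# list under "input" and returns one vector per text under "embeddings", which
# is exactly what the parsing below relies on. URL and model tag are example
# values.
import httpx

_resp = httpx.post(
    "http://localhost:11434/api/embed",
    json={"model": "qwen3-embedding:0.6b",
          "input": ["first snippet", "second snippet"]},
    timeout=120.0,
)
_resp.raise_for_status()
_vectors = _resp.json()["embeddings"]   # list of embeddings, one per input text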
+ """ + if not texts: + return [] + + expected_dim = self._config.get("embedding_dim", 1024) + batch_size = self._config.get("embed_batch_size", DEFAULT_BATCH_SIZE) + all_embeddings = [] + + # Process in batches + for i in range(0, len(texts), batch_size): + batch = texts[i:i + batch_size] + logger.debug(f"Embedding batch {i // batch_size + 1}: {len(batch)} texts") + + # Preprocess batch + processed_batch = [] + for text in batch: + if not text or not text.strip(): + processed_batch.append(" ") # Placeholder for empty + else: + # Truncate if too long + if len(text) > 24000: + text = text[:24000] + processed_batch.append(text.strip()) + + try: + # Use /api/embed with array input for batch processing + response = self._client.post( + "/api/embed", + json={ + "model": self._config["model"], + "input": processed_batch + } + ) + response.raise_for_status() + data = response.json() + + # /api/embed returns {"embeddings": [[...], [...], ...]} + if "embeddings" in data: + batch_embeddings = data["embeddings"] + all_embeddings.extend(batch_embeddings) + else: + logger.error(f"Unexpected batch response format: {list(data.keys())}") + # Fallback to zeros + all_embeddings.extend([[0.0] * expected_dim] * len(processed_batch)) + + except Exception as e: + logger.error(f"Batch embedding failed: {e}, falling back to single requests") + # Fallback to single embedding requests + for text in processed_batch: + try: + embedding = self._get_embedding(text) + all_embeddings.append(embedding) + except Exception: + all_embeddings.append([0.0] * expected_dim) + + return all_embeddings + + def _get_embedding(self, text: str) -> List[float]: + """Get embedding from Ollama API.""" + expected_dim = self._config.get("embedding_dim", 1024) + + try: + # Validate input + if not text or not text.strip(): + logger.warning("Empty text provided for embedding, using placeholder") + return [0.0] * expected_dim + + # Truncate if too long (Ollama typically handles ~8k tokens) + max_chars = 24000 + if len(text) > max_chars: + logger.warning(f"Text too long ({len(text)} chars), truncating to {max_chars}") + text = text[:max_chars] + + # Clean the text + text = text.strip() + if not text: + logger.warning("Text became empty after stripping") + return [0.0] * expected_dim + + # Call Ollama embeddings API + response = self._client.post( + "/api/embeddings", + json={ + "model": self._config["model"], + "prompt": text + } + ) + response.raise_for_status() + data = response.json() + + # Ollama returns embedding in 'embedding' field + if "embedding" in data: + embedding = data["embedding"] + elif "embeddings" in data and len(data["embeddings"]) > 0: + # Fallback for potential future API changes + embedding = data["embeddings"][0] + else: + logger.error(f"Unexpected response format from Ollama: {data}") + return [0.0] * expected_dim + + # Validate embedding dimensions + if len(embedding) != expected_dim: + logger.warning(f"Unexpected embedding dimension: {len(embedding)}, expected {expected_dim}") + # Update expected dim if this is consistently different + self._config["embedding_dim"] = len(embedding) + + return embedding + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error getting embedding from Ollama: {e}") + logger.error(f"Response: {e.response.text if e.response else 'No response'}") + raise + except Exception as e: + logger.error(f"Error getting embedding from Ollama: {e}") + logger.error(f"Text length: {len(text) if text else 0}, Text preview: {text[:100] if text else 'None'}...") + raise + + async def 
_aget_query_embedding(self, query: str) -> List[float]: + """Async get embedding for a query text.""" + return self._get_query_embedding(query) + + async def _aget_text_embedding(self, text: str) -> List[float]: + """Async get embedding for a text.""" + return self._get_text_embedding(text) + + async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: + """Async batch get embeddings for multiple texts.""" + return self._get_text_embeddings(texts) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py index 69cbc5e9..18826395 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py @@ -1,11 +1,14 @@ import os -from typing import Optional +from typing import Optional, Literal from pydantic import BaseModel, Field, field_validator, model_validator import logging from dotenv import load_dotenv logger = logging.getLogger(__name__) +# Embedding provider types +EmbeddingProvider = Literal["ollama", "openrouter"] + # Known embedding model dimensions EMBEDDING_MODEL_DIMENSIONS = { # OpenAI models @@ -18,9 +21,17 @@ "cohere/embed-multilingual-v3.0": 1024, "voyage/voyage-large-2": 1536, "voyage/voyage-code-2": 1536, - # Alibaba models (often 4096) + # Alibaba/Qwen models "qwen/qwen-embedding": 4096, + "qwen/qwen3-embedding-0.6b": 1024, "qwen/qwen3-embedding-8b": 4096, + # Ollama local models (same Qwen models) + "qwen3-embedding-0.6b": 1024, + "qwen3-embedding:0.6b": 1024, + "snowflake-arctic-embed": 1024, + "nomic-embed-text": 768, + "mxbai-embed-large": 1024, + "all-minilm": 384, # Default fallback "default": 1536, } @@ -45,7 +56,14 @@ class RAGConfig(BaseModel): qdrant_url: str = Field(default_factory=lambda: os.getenv("QDRANT_URL", "http://localhost:6333")) qdrant_collection_prefix: str = Field(default_factory=lambda: os.getenv("QDRANT_COLLECTION_PREFIX", "codecrow")) - # OpenRouter for embeddings + # Embedding provider selection: "ollama" (local) or "openrouter" (cloud) + embedding_provider: str = Field(default_factory=lambda: os.getenv("EMBEDDING_PROVIDER", "ollama")) + + # Ollama configuration (local embeddings - default) + ollama_base_url: str = Field(default_factory=lambda: os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")) + ollama_model: str = Field(default_factory=lambda: os.getenv("OLLAMA_EMBEDDING_MODEL", "qwen3-embedding:0.6b")) + + # OpenRouter configuration (cloud embeddings - optional) openrouter_api_key: str = Field(default_factory=lambda: os.getenv("OPENROUTER_API_KEY", "")) openrouter_model: str = Field(default_factory=lambda: os.getenv("OPENROUTER_MODEL", "qwen/qwen3-embedding-8b")) openrouter_base_url: str = Field(default="https://openrouter.ai/api/v1") @@ -58,21 +76,30 @@ class RAGConfig(BaseModel): def set_embedding_dim_from_model(self) -> 'RAGConfig': """Auto-detect embedding dimension from model if not explicitly set.""" if self.embedding_dim == 0: - self.embedding_dim = get_embedding_dim_for_model(self.openrouter_model) - logger.info(f"Auto-detected embedding dimension {self.embedding_dim} for model {self.openrouter_model}") + # Use the model for the selected provider + model = self.ollama_model if self.embedding_provider == "ollama" else self.openrouter_model + self.embedding_dim = get_embedding_dim_for_model(model) + logger.info(f"Auto-detected embedding dimension {self.embedding_dim} for model {model}") else: logger.info(f"Using configured embedding dimension: 
{self.embedding_dim}") return self - @field_validator('openrouter_api_key') - @classmethod - def validate_api_key(cls, v: str) -> str: - if not v or v.strip() == "": - logger.error("OPENROUTER_API_KEY is not set or is empty!") - logger.error("Please set the OPENROUTER_API_KEY environment variable") - raise ValueError("OPENROUTER_API_KEY is required but not set") - logger.info(f"OpenRouter API key loaded: {v[:10]}...{v[-4:]}") - return v + @model_validator(mode='after') + def validate_provider_config(self) -> 'RAGConfig': + """Validate that the selected provider has required configuration.""" + if self.embedding_provider == "openrouter": + if not self.openrouter_api_key or self.openrouter_api_key.strip() == "": + logger.error("OPENROUTER_API_KEY is not set but embedding_provider is 'openrouter'!") + raise ValueError("OPENROUTER_API_KEY is required when using OpenRouter provider") + logger.info(f"Using OpenRouter embeddings with model: {self.openrouter_model}") + logger.info(f"OpenRouter API key loaded: {self.openrouter_api_key[:10]}...{self.openrouter_api_key[-4:]}") + elif self.embedding_provider == "ollama": + logger.info(f"Using Ollama local embeddings with model: {self.ollama_model}") + logger.info(f"Ollama base URL: {self.ollama_base_url}") + else: + logger.warning(f"Unknown embedding provider '{self.embedding_provider}', defaulting to 'ollama'") + self.embedding_provider = "ollama" + return self # Chunk size for code files # text-embedding-3-small supports ~8191 tokens (~32K chars) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py index c5bfa64c..a6fb49ae 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py @@ -9,7 +9,7 @@ from ..models.config import RAGConfig from ..models.scoring_config import get_scoring_config from ..utils.utils import make_project_namespace -from ..core.openrouter_embedding import OpenRouterEmbedding +from ..core.embedding_factory import create_embedding_model, get_embedding_model_info from ..models.instructions import InstructionType, format_query logger = logging.getLogger(__name__) @@ -28,14 +28,11 @@ def __init__(self, config: RAGConfig): # Qdrant client self.qdrant_client = QdrantClient(url=config.qdrant_url) - # Embedding model - self.embed_model = OpenRouterEmbedding( - api_key=config.openrouter_api_key, - model=config.openrouter_model, - api_base=config.openrouter_base_url, - timeout=60.0, - max_retries=3 - ) + # Embedding model (supports Ollama and OpenRouter via factory) + embed_info = get_embedding_model_info(config) + logger.info(f"QueryService using embedding provider: {embed_info['provider']} ({embed_info['type']})") + + self.embed_model = create_embedding_model(config) def _collection_or_alias_exists(self, name: str) -> bool: """Check if a collection or alias with the given name exists.""" From c90cb4e9452c8da262b520e21bc0add542dccbba Mon Sep 17 00:00:00 2001 From: rostislav Date: Wed, 4 Feb 2026 21:56:56 +0200 Subject: [PATCH 2/7] feat: Implement optimistic locking in VcsConnection and enhance transaction management in JobService and WebhookAsyncProcessor --- .../core/model/vcs/VcsConnection.java | 8 ++ .../codecrow/core/service/JobService.java | 18 +++-- .../security/web/WebSecurityConfig.java | 77 +++++++++---------- .../codecrow/vcsclient/VcsClientProvider.java | 9 ++- .../processor/WebhookAsyncProcessor.java | 3 +- 5 files 
changed, 66 insertions(+), 49 deletions(-) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/vcs/VcsConnection.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/vcs/VcsConnection.java index 383ea9b1..2c0a74d7 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/vcs/VcsConnection.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/vcs/VcsConnection.java @@ -116,6 +116,14 @@ public class VcsConnection { @Column(name = "updated_at") private LocalDateTime updatedAt; + /** + * Version field for optimistic locking. + * Prevents concurrent token refresh operations from overwriting each other. + */ + @Version + @Column(name = "version") + private Long version; + /** * Provider-specific configuration (JSON column). * Stores additional settings like OAuth keys for manual connections. diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index cdae83c8..3feac5fb 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -380,8 +380,9 @@ public Job updateProgress(Job job, int progress, String currentStep) { /** * Add a log entry to a job. + * Uses REQUIRES_NEW to commit immediately and not hold the parent transaction open. */ - @Transactional + @Transactional(propagation = org.springframework.transaction.annotation.Propagation.REQUIRES_NEW, timeout = 10) public JobLog addLog(Job job, JobLogLevel level, String step, String message) { JobLog logEntry = new JobLog(); logEntry.setJob(job); @@ -425,8 +426,9 @@ public JobLog addLogInNewTransaction(Job job, JobLogLevel level, String step, St /** * Add a log entry with metadata. + * Uses REQUIRES_NEW to commit immediately and not hold the parent transaction open. */ - @Transactional + @Transactional(propagation = org.springframework.transaction.annotation.Propagation.REQUIRES_NEW, timeout = 10) public JobLog addLog(Job job, JobLogLevel level, String step, String message, Map metadata) { JobLog logEntry = new JobLog(); logEntry.setJob(job); @@ -449,32 +451,36 @@ public JobLog addLog(Job job, JobLogLevel level, String step, String message, Ma /** * Add an info log. + * Uses REQUIRES_NEW to commit immediately. */ - @Transactional + @Transactional(propagation = org.springframework.transaction.annotation.Propagation.REQUIRES_NEW, timeout = 10) public JobLog info(Job job, String step, String message) { return addLog(job, JobLogLevel.INFO, step, message); } /** * Add a warning log. + * Uses REQUIRES_NEW to commit immediately. */ - @Transactional + @Transactional(propagation = org.springframework.transaction.annotation.Propagation.REQUIRES_NEW, timeout = 10) public JobLog warn(Job job, String step, String message) { return addLog(job, JobLogLevel.WARN, step, message); } /** * Add an error log. + * Uses REQUIRES_NEW to commit immediately. */ - @Transactional + @Transactional(propagation = org.springframework.transaction.annotation.Propagation.REQUIRES_NEW, timeout = 10) public JobLog error(Job job, String step, String message) { return addLog(job, JobLogLevel.ERROR, step, message); } /** * Add a debug log. + * Uses REQUIRES_NEW to commit immediately. 
*/ - @Transactional + @Transactional(propagation = org.springframework.transaction.annotation.Propagation.REQUIRES_NEW, timeout = 10) public JobLog debug(Job job, String step, String message) { return addLog(job, JobLogLevel.DEBUG, step, message); } diff --git a/java-ecosystem/libs/security/src/main/java/org/rostilos/codecrow/security/web/WebSecurityConfig.java b/java-ecosystem/libs/security/src/main/java/org/rostilos/codecrow/security/web/WebSecurityConfig.java index 4c8eef38..10a65065 100644 --- a/java-ecosystem/libs/security/src/main/java/org/rostilos/codecrow/security/web/WebSecurityConfig.java +++ b/java-ecosystem/libs/security/src/main/java/org/rostilos/codecrow/security/web/WebSecurityConfig.java @@ -26,26 +26,23 @@ import org.springframework.web.cors.UrlBasedCorsConfigurationSource; @Configuration -@EnableMethodSecurity - (securedEnabled = true, - jsr250Enabled = true) +@EnableMethodSecurity(securedEnabled = true, jsr250Enabled = true) public class WebSecurityConfig { @Value("${codecrow.security.encryption-key}") private String encryptionKey; @Value("${codecrow.security.encryption-key-old}") private String oldEncryptionKey; - + @Value("${codecrow.internal.api.secret:}") private String internalApiSecret; - + private final UserDetailsServiceImpl userDetailsService; private final AuthEntryPoint unauthorizedHandler; public WebSecurityConfig( UserDetailsServiceImpl userDetailsService, - AuthEntryPoint unauthorizedHandler - ) { + AuthEntryPoint unauthorizedHandler) { this.userDetailsService = userDetailsService; this.unauthorizedHandler = unauthorizedHandler; } @@ -85,7 +82,8 @@ public CorsConfigurationSource corsConfigurationSource() { CorsConfiguration configuration = new CorsConfiguration(); // TODO: replace Arrays.asList("*") with the explicit domain(s) - // TODO: Example: Arrays.asList("http://localhost:8080", "https://frontend.rostilos.pp.ua") + // TODO: Example: Arrays.asList("http://localhost:8080", + // "https://frontend.rostilos.pp.ua") configuration.setAllowedOriginPatterns(Arrays.asList("*")); configuration.setAllowedMethods(Arrays.asList("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")); @@ -109,38 +107,37 @@ public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { .headers(headers -> headers .frameOptions(HeadersConfigurer.FrameOptionsConfig::disable) .contentSecurityPolicy(csp -> csp - .policyDirectives("frame-ancestors 'self' https://bitbucket.org https://*.bitbucket.org") - ) - ) - .authorizeHttpRequests(auth -> - auth - // Allow async dispatches to complete (SSE, streaming responses) - .dispatcherTypeMatchers(DispatcherType.ASYNC, DispatcherType.ERROR).permitAll() - // Allow all OPTIONS requests (CORS preflight) - .requestMatchers(org.springframework.http.HttpMethod.OPTIONS, "/**").permitAll() - // Allow error page to be rendered without authentication - .requestMatchers("/error").permitAll() - .requestMatchers("/api/auth/**").permitAll() - .requestMatchers("/api/test/**").permitAll() - // OAuth callbacks need to be public (called by VCS providers) - .requestMatchers("/api/*/integrations/*/app/callback").permitAll() - // Generic OAuth callbacks without workspace slug (for GitHub, etc.) 
- .requestMatchers("/api/integrations/*/app/callback").permitAll() - .requestMatchers("/actuator/**").permitAll() - .requestMatchers("/internal/projects/**").permitAll() - .requestMatchers("/api/internal/**").permitAll() - .requestMatchers("/swagger-ui-custom.html").permitAll() - .requestMatchers("/api-docs").permitAll() - // Bitbucket Connect App lifecycle callbacks (uses JWT auth) - .requestMatchers("/api/bitbucket/connect/descriptor").permitAll() - .requestMatchers("/api/bitbucket/connect/installed").permitAll() - .requestMatchers("/api/bitbucket/connect/uninstalled").permitAll() - .requestMatchers("/api/bitbucket/connect/enabled").permitAll() - .requestMatchers("/api/bitbucket/connect/disabled").permitAll() - .requestMatchers("/api/bitbucket/connect/status").permitAll() - .requestMatchers("/api/bitbucket/connect/configure").permitAll() - .anyRequest().authenticated() - ); + .policyDirectives( + "frame-ancestors 'self' https://bitbucket.org https://*.bitbucket.org"))) + .authorizeHttpRequests(auth -> auth + // Allow async dispatches to complete (SSE, streaming responses) + .dispatcherTypeMatchers(DispatcherType.ASYNC, DispatcherType.ERROR).permitAll() + // Allow all OPTIONS requests (CORS preflight) + .requestMatchers(org.springframework.http.HttpMethod.OPTIONS, "/**").permitAll() + // Allow error page to be rendered without authentication + .requestMatchers("/error").permitAll() + .requestMatchers("/api/auth/**").permitAll() + .requestMatchers("/api/test/**").permitAll() + // OAuth callbacks need to be public (called by VCS providers) + .requestMatchers("/api/*/integrations/*/app/callback").permitAll() + // Generic OAuth callbacks without workspace slug (for GitHub, etc.) + .requestMatchers("/api/integrations/*/app/callback").permitAll() + .requestMatchers("/actuator/**").permitAll() + .requestMatchers("/internal/projects/**").permitAll() + .requestMatchers("/api/internal/**").permitAll() + .requestMatchers("/swagger-ui-custom.html").permitAll() + .requestMatchers("/api-docs").permitAll() + // Bitbucket Connect App lifecycle callbacks (uses JWT auth) + .requestMatchers("/api/bitbucket/connect/descriptor").permitAll() + .requestMatchers("/api/bitbucket/connect/installed").permitAll() + .requestMatchers("/api/bitbucket/connect/uninstalled").permitAll() + .requestMatchers("/api/bitbucket/connect/enabled").permitAll() + .requestMatchers("/api/bitbucket/connect/disabled").permitAll() + .requestMatchers("/api/bitbucket/connect/status").permitAll() + .requestMatchers("/api/bitbucket/connect/configure").permitAll() + // Stripe webhooks (authenticate via signature verification, not JWT) + .requestMatchers("/api/webhooks/**").permitAll() + .anyRequest().authenticated()); http.authenticationProvider(authenticationProvider()); diff --git a/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/VcsClientProvider.java b/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/VcsClientProvider.java index 1a6ad7de..7190f68b 100644 --- a/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/VcsClientProvider.java +++ b/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/VcsClientProvider.java @@ -210,7 +210,7 @@ public boolean needsTokenRefresh(VcsConnection connection) { * @return updated connection with new tokens * @throws VcsClientException if refresh fails */ - @Transactional + @Transactional(timeout = 30) // 30 second timeout to prevent deadlocks public VcsConnection refreshToken(VcsConnection connection) { 
log.info("Refreshing access token for connection: {} (provider: {}, type: {})", connection.getId(), connection.getProviderType(), connection.getConnectionType()); @@ -415,7 +415,12 @@ private TokenResponse refreshGitLabToken(String refreshToken) throws IOException throw new IOException("GitLab OAuth credentials not configured. Set codecrow.gitlab.oauth.client-id and codecrow.gitlab.oauth.client-secret"); } - OkHttpClient httpClient = new OkHttpClient(); + // Use short timeouts to prevent holding database locks during slow network operations + OkHttpClient httpClient = new OkHttpClient.Builder() + .connectTimeout(10, java.util.concurrent.TimeUnit.SECONDS) + .readTimeout(15, java.util.concurrent.TimeUnit.SECONDS) + .writeTimeout(10, java.util.concurrent.TimeUnit.SECONDS) + .build(); // Determine GitLab token URL (support self-hosted) String tokenUrl = (gitlabBaseUrl != null && !gitlabBaseUrl.isBlank() && !gitlabBaseUrl.equals("https://gitlab.com")) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 10ac7b5d..da2a0787 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -92,8 +92,9 @@ public void processWebhookAsync( /** * Process webhook within a transaction. * Called from async method via self-injection to ensure transaction proxy works. + * Note: Transaction timeout set to 5 minutes to prevent indefinite 'idle in transaction' states. */ - @Transactional + @Transactional(timeout = 300) public void processWebhookInTransaction( EVcsProvider provider, Long projectId, From 2958327d28d378be3aa3dda72c802bff8ba87ab2 Mon Sep 17 00:00:00 2001 From: rostislav Date: Thu, 5 Feb 2026 20:36:51 +0200 Subject: [PATCH 3/7] feat: Add workspace management interface and implement service methods for CRUD operations. Cloud version preparations --- deployment/config/web-frontend/.env.sample | 17 ++ .../project/service/IProjectService.java | 113 +++++++++ .../project/service/ProjectService.java | 227 +++++++++--------- .../workspace/service/IWorkspaceService.java | 64 +++++ .../workspace/service/WorkspaceService.java | 59 +++-- 5 files changed, 342 insertions(+), 138 deletions(-) create mode 100644 java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/IProjectService.java create mode 100644 java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/workspace/service/IWorkspaceService.java diff --git a/deployment/config/web-frontend/.env.sample b/deployment/config/web-frontend/.env.sample index aa2fa5fd..c9f9f269 100644 --- a/deployment/config/web-frontend/.env.sample +++ b/deployment/config/web-frontend/.env.sample @@ -5,6 +5,23 @@ SERVER_PORT=8080 VITE_BLOG_URL=http://localhost:8083 +# ============================================================================ +# Feature Flags (Cloud-specific features - disabled by default for OSS) +# ============================================================================ +# Set to "true" to enable cloud features. Leave empty or "false" to disable. 
+# +# VITE_FEATURE_BILLING - Enables billing page, subscription management, payment methods +VITE_FEATURE_BILLING=false +# +# VITE_FEATURE_CLOUD_PLANS - Enables cloud subscription plans (Pro, Pro+, Enterprise) +VITE_FEATURE_CLOUD_PLANS=false +# +# VITE_FEATURE_USAGE_ANALYTICS - Enables usage tracking and quota management +VITE_FEATURE_USAGE_ANALYTICS=false +# +# VITE_FEATURE_ENTERPRISE - Enables SSO, SAML, and advanced team management +VITE_FEATURE_ENTERPRISE=false + # New Relic Browser Monitoring (Optional - leave empty to disable) # Get these values from: https://one.newrelic.com -> Browser -> Add data -> Browser monitoring diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/IProjectService.java b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/IProjectService.java new file mode 100644 index 00000000..154b949e --- /dev/null +++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/IProjectService.java @@ -0,0 +1,113 @@ +package org.rostilos.codecrow.webserver.project.service; + +import java.security.GeneralSecurityException; +import java.util.List; + +import org.rostilos.codecrow.core.model.branch.Branch; +import org.rostilos.codecrow.core.model.project.Project; +import org.rostilos.codecrow.core.model.project.config.BranchAnalysisConfig; +import org.rostilos.codecrow.core.model.project.config.CommentCommandsConfig; +import org.rostilos.codecrow.core.model.project.config.InstallationMethod; +import org.rostilos.codecrow.core.model.vcs.EVcsProvider; +import org.rostilos.codecrow.webserver.project.dto.request.BindAiConnectionRequest; +import org.rostilos.codecrow.webserver.project.dto.request.BindRepositoryRequest; +import org.rostilos.codecrow.webserver.project.dto.request.ChangeVcsConnectionRequest; +import org.rostilos.codecrow.webserver.project.dto.request.CreateProjectRequest; +import org.rostilos.codecrow.webserver.project.dto.request.UpdateCommentCommandsConfigRequest; +import org.rostilos.codecrow.webserver.project.dto.request.UpdateProjectRequest; +import org.rostilos.codecrow.webserver.project.dto.request.UpdateRepositorySettingsRequest; +import org.springframework.data.domain.Page; + +/** + * Interface for project management operations. + *
+ * This interface enables cloud implementations to extend or decorate + * the base project service with additional capabilities like billing limits. + */ +public interface IProjectService { + + // ==================== Core CRUD ==================== + + List listWorkspaceProjects(Long workspaceId); + + Page listWorkspaceProjectsPaginated(Long workspaceId, String search, int page, int size); + + Project createProject(Long workspaceId, CreateProjectRequest request); + + Project getProjectById(Long projectId); + + Project getProjectByWorkspaceAndNamespace(Long workspaceId, String namespace); + + Project updateProject(Long workspaceId, Long projectId, UpdateProjectRequest request); + + void deleteProject(Long workspaceId, Long projectId); + + void deleteProjectByNamespace(Long workspaceId, String namespace); + + // ==================== Repository Binding ==================== + + Project bindRepository(Long workspaceId, Long projectId, BindRepositoryRequest request); + + Project unbindRepository(Long workspaceId, Long projectId); + + void updateRepositorySettings(Long workspaceId, Long projectId, UpdateRepositorySettingsRequest request) + throws GeneralSecurityException; + + Project changeVcsConnection(Long workspaceId, Long projectId, ChangeVcsConnectionRequest request); + + // ==================== AI Connection ==================== + + boolean bindAiConnection(Long workspaceId, Long projectId, BindAiConnectionRequest request); + + // ==================== Branch Management ==================== + + List getProjectBranches(Long workspaceId, String namespace); + + Project setDefaultBranch(Long workspaceId, String namespace, Long branchId); + + Project setDefaultBranchByName(Long workspaceId, String namespace, String branchName); + + // ==================== Configuration ==================== + + BranchAnalysisConfig getBranchAnalysisConfig(Project project); + + Project updateBranchAnalysisConfig(Long workspaceId, Long projectId, + List prTargetBranches, List branchPushPatterns); + + Project updateRagConfig(Long workspaceId, Long projectId, boolean enabled, String branch, + List excludePatterns, Boolean multiBranchEnabled, Integer branchRetentionDays); + + Project updateRagConfig(Long workspaceId, Long projectId, boolean enabled, String branch, + List excludePatterns); + + Project updateAnalysisSettings(Long workspaceId, Long projectId, Boolean prAnalysisEnabled, + Boolean branchAnalysisEnabled, InstallationMethod installationMethod, Integer maxAnalysisTokenLimit); + + Project updateProjectQualityGate(Long workspaceId, Long projectId, Long qualityGateId); + + CommentCommandsConfig getCommentCommandsConfig(Project project); + + Project updateCommentCommandsConfig(Long workspaceId, Long projectId, UpdateCommentCommandsConfigRequest request); + + // ==================== Webhooks ==================== + + WebhookSetupResult setupWebhooks(Long workspaceId, Long projectId); + + WebhookInfo getWebhookInfo(Long workspaceId, Long projectId); + + // ==================== DTOs ==================== + + record WebhookSetupResult( + boolean success, + String webhookId, + String webhookUrl, + String message) { + } + + record WebhookInfo( + boolean webhooksConfigured, + String webhookId, + String webhookUrl, + EVcsProvider provider) { + } +} diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java index f14718a9..aa2498cc 
100644 --- a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java +++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java @@ -58,9 +58,9 @@ import org.slf4j.LoggerFactory; @Service -public class ProjectService { +public class ProjectService implements IProjectService { private static final Logger log = LoggerFactory.getLogger(ProjectService.class); - + private final ProjectRepository projectRepository; private final VcsConnectionRepository vcsConnectionRepository; private final TokenEncryptionService tokenEncryptionService; @@ -103,8 +103,7 @@ public ProjectService( JobLogRepository jobLogRepository, PrSummarizeCacheRepository prSummarizeCacheRepository, VcsClientProvider vcsClientProvider, - QualityGateRepository qualityGateRepository - ) { + QualityGateRepository qualityGateRepository) { this.projectRepository = projectRepository; this.vcsConnectionRepository = vcsConnectionRepository; this.tokenEncryptionService = tokenEncryptionService; @@ -128,21 +127,21 @@ public ProjectService( @Transactional(readOnly = true) public List listWorkspaceProjects(Long workspaceId) { - // Use the method that fetches default branch eagerly to include stats in project list + // Use the method that fetches default branch eagerly to include stats in + // project list return projectRepository.findByWorkspaceIdWithDefaultBranch(workspaceId); } @Transactional(readOnly = true) public org.springframework.data.domain.Page listWorkspaceProjectsPaginated( - Long workspaceId, - String search, - int page, + Long workspaceId, + String search, + int page, int size) { org.springframework.data.domain.Pageable pageable = org.springframework.data.domain.PageRequest.of( - page, - size, - org.springframework.data.domain.Sort.by(org.springframework.data.domain.Sort.Direction.DESC, "id") - ); + page, + size, + org.springframework.data.domain.Sort.by(org.springframework.data.domain.Sort.Direction.DESC, "id")); return projectRepository.findByWorkspaceIdWithSearchPaginated(workspaceId, search, pageable); } @@ -159,7 +158,9 @@ public Project createProject(Long workspaceId, CreateProjectRequest request) thr // ensure namespace uniqueness per workspace projectRepository.findByWorkspaceIdAndNamespace(workspaceId, request.getNamespace()) - .ifPresent(p -> { throw new InvalidProjectRequestException("Project namespace already exists in workspace"); }); + .ifPresent(p -> { + throw new InvalidProjectRequestException("Project namespace already exists in workspace"); + }); Project newProject = new Project(); newProject.setWorkspace(ws); @@ -179,10 +180,12 @@ public Project createProject(Long workspaceId, CreateProjectRequest request) thr newProject.setConfiguration(config); if (request.hasVcsConnection()) { - VcsConnection vcsConnection = vcsConnectionRepository.findByWorkspace_IdAndId(workspaceId, request.getVcsConnectionId()) + VcsConnection vcsConnection = vcsConnectionRepository + .findByWorkspace_IdAndId(workspaceId, request.getVcsConnectionId()) .orElseThrow(() -> new NoSuchElementException("VCS connection not found!")); - // Use VcsRepoBinding (provider-agnostic) instead of legacy ProjectVcsConnectionBinding + // Use VcsRepoBinding (provider-agnostic) instead of legacy + // ProjectVcsConnectionBinding VcsRepoBinding vcsRepoBinding = new VcsRepoBinding(); vcsRepoBinding.setProject(newProject); vcsRepoBinding.setWorkspace(ws); @@ -193,7 +196,8 @@ public Project createProject(Long workspaceId, 
CreateProjectRequest request) thr if (request.getRepositoryUUID() != null) { vcsRepoBinding.setExternalRepoId(request.getRepositoryUUID().toString()); } - // For Bitbucket, extract workspace from config; for other providers, use a default + // For Bitbucket, extract workspace from config; for other providers, use a + // default String externalNamespace = null; if (vcsConnection.getConfiguration() instanceof BitbucketCloudConfig bbConfig) { externalNamespace = bbConfig.workspaceId(); @@ -206,7 +210,8 @@ public Project createProject(Long workspaceId, CreateProjectRequest request) thr } if (request.getAiConnectionId() != null) { - AIConnection aiConnection = aiConnectionRepository.findByWorkspace_IdAndId(workspaceId, request.getAiConnectionId()) + AIConnection aiConnection = aiConnectionRepository + .findByWorkspace_IdAndId(workspaceId, request.getAiConnectionId()) .orElseThrow(() -> new NoSuchElementException("AI connection not found!")); ProjectAiConnectionBinding aiBinding = new ProjectAiConnectionBinding(); @@ -222,7 +227,8 @@ public Project createProject(Long workspaceId, CreateProjectRequest request) thr String plainToken = Base64.getUrlEncoder().withoutPadding().encodeToString(random); String encrypted = tokenEncryptionService.encrypt(plainToken); newProject.setAuthToken(encrypted); - // Note: plainToken is not returned in API responses; store encrypted token only. + // Note: plainToken is not returned in API responses; store encrypted token + // only. } catch (GeneralSecurityException e) { throw new SecurityException("Failed to generate project auth token"); } @@ -345,7 +351,7 @@ public Project bindRepository(Long workspaceId, Long projectId, BindRepositoryRe VcsConnection conn = vcsConnectionRepository.findByWorkspace_IdAndId(workspaceId, request.getConnectionId()) .orElseThrow(() -> new NoSuchElementException("Connection not found")); } - //TODO: bind implementation + // TODO: bind implementation return projectRepository.save(p); } @@ -353,14 +359,14 @@ public Project bindRepository(Long workspaceId, Long projectId, BindRepositoryRe public Project unbindRepository(Long workspaceId, Long projectId) { Project p = projectRepository.findByWorkspaceIdAndId(workspaceId, projectId) .orElseThrow(() -> new NoSuchElementException("Project not found")); - //TODO: unbind implementation + // TODO: unbind implementation // Clear settings placeholder if used in future return projectRepository.save(p); } - @Transactional - public void updateRepositorySettings(Long workspaceId, Long projectId, UpdateRepositorySettingsRequest request) throws GeneralSecurityException { + public void updateRepositorySettings(Long workspaceId, Long projectId, UpdateRepositorySettingsRequest request) + throws GeneralSecurityException { Project p = projectRepository.findByWorkspaceIdAndId(workspaceId, projectId) .orElseThrow(() -> new NoSuchElementException("Project not found")); @@ -368,17 +374,19 @@ public void updateRepositorySettings(Long workspaceId, Long projectId, UpdateRep String encrypted = tokenEncryptionService.encrypt(request.getToken()); p.setAuthToken(encrypted); } - //TODO: service implementation + // TODO: service implementation // apiBaseUrl and webhookSecret ignored for now (no fields yet) projectRepository.save(p); } @Transactional - public boolean bindAiConnection(Long workspaceId, Long projectId, BindAiConnectionRequest request) throws SecurityException { - // Use findByIdWithConnections to eagerly fetch aiBinding for proper orphan removal + public boolean bindAiConnection(Long workspaceId, Long 
projectId, BindAiConnectionRequest request) + throws SecurityException { + // Use findByIdWithConnections to eagerly fetch aiBinding for proper orphan + // removal Project project = projectRepository.findByIdWithConnections(projectId) .orElseThrow(() -> new NoSuchElementException("Project not found")); - + // Verify workspace ownership if (!project.getWorkspace().getId().equals(workspaceId)) { throw new NoSuchElementException("Project not found in workspace"); @@ -401,7 +409,7 @@ public boolean bindAiConnection(Long workspaceId, Long projectId, BindAiConnecti aiConnectionBinding.setAiConnection(aiConnection); project.setAiConnectionBinding(aiConnectionBinding); } - + projectRepository.save(project); return true; } @@ -461,7 +469,8 @@ public Project setDefaultBranchByName(Long workspaceId, String namespace, String Project project = getProjectByWorkspaceAndNamespace(workspaceId, namespace); Branch branch = branchRepository.findByProjectIdAndBranchName(project.getId(), branchName) - .orElseThrow(() -> new NoSuchElementException("Branch '" + branchName + "' not found for this project")); + .orElseThrow( + () -> new NoSuchElementException("Branch '" + branchName + "' not found for this project")); project.setDefaultBranch(branch); return projectRepository.save(project); @@ -482,10 +491,13 @@ public BranchAnalysisConfig getBranchAnalysisConfig(Project project) { /** * Update the branch analysis configuration for a project. * Main branch is always ensured to be in the patterns. - * @param workspaceId the workspace ID - * @param projectId the project ID - * @param prTargetBranches patterns for PR target branches (e.g., ["main", "develop", "release/*"]) - * @param branchPushPatterns patterns for branch push analysis (e.g., ["main", "develop"]) + * + * @param workspaceId the workspace ID + * @param projectId the project ID + * @param prTargetBranches patterns for PR target branches (e.g., ["main", + * "develop", "release/*"]) + * @param branchPushPatterns patterns for branch push analysis (e.g., ["main", + * "develop"]) * @return the updated project */ @Transactional @@ -493,8 +505,7 @@ public Project updateBranchAnalysisConfig( Long workspaceId, Long projectId, List prTargetBranches, - List branchPushPatterns - ) { + List branchPushPatterns) { Project project = projectRepository.findByWorkspaceIdAndId(workspaceId, projectId) .orElseThrow(() -> new NoSuchElementException("Project not found")); @@ -505,8 +516,7 @@ public Project updateBranchAnalysisConfig( BranchAnalysisConfig branchConfig = new BranchAnalysisConfig( prTargetBranches, - branchPushPatterns - ); + branchPushPatterns); currentConfig.setBranchAnalysis(branchConfig); // Ensure main branch is always included in patterns @@ -517,12 +527,14 @@ public Project updateBranchAnalysisConfig( /** * Update the RAG configuration for a project. 
- * @param workspaceId the workspace ID - * @param projectId the project ID - * @param enabled whether RAG indexing is enabled - * @param branch the branch to index (null uses defaultBranch or 'main') - * @param excludePatterns patterns to exclude from indexing - * @param multiBranchEnabled whether multi-branch indexing is enabled + * + * @param workspaceId the workspace ID + * @param projectId the project ID + * @param enabled whether RAG indexing is enabled + * @param branch the branch to index (null uses defaultBranch or + * 'main') + * @param excludePatterns patterns to exclude from indexing + * @param multiBranchEnabled whether multi-branch indexing is enabled * @param branchRetentionDays how long to keep branch index metadata * @return the updated project */ @@ -534,8 +546,7 @@ public Project updateRagConfig( String branch, java.util.List excludePatterns, Boolean multiBranchEnabled, - Integer branchRetentionDays - ) { + Integer branchRetentionDays) { Project project = projectRepository.findByWorkspaceIdAndId(workspaceId, projectId) .orElseThrow(() -> new NoSuchElementException("Project not found")); @@ -555,7 +566,7 @@ public Project updateRagConfig( prAnalysisEnabled, branchAnalysisEnabled, installationMethod, commentCommands)); return projectRepository.save(project); } - + /** * Simplified RAG config update (backward compatible). */ @@ -565,8 +576,7 @@ public Project updateRagConfig( Long projectId, boolean enabled, String branch, - java.util.List excludePatterns - ) { + java.util.List excludePatterns) { return updateRagConfig(workspaceId, projectId, enabled, branch, excludePatterns, null, null); } @@ -577,8 +587,7 @@ public Project updateAnalysisSettings( Boolean prAnalysisEnabled, Boolean branchAnalysisEnabled, InstallationMethod installationMethod, - Integer maxAnalysisTokenLimit - ) { + Integer maxAnalysisTokenLimit) { Project project = projectRepository.findByWorkspaceIdAndId(workspaceId, projectId) .orElseThrow(() -> new NoSuchElementException("Project not found")); @@ -588,18 +597,19 @@ public Project updateAnalysisSettings( var branchAnalysis = currentConfig != null ? currentConfig.branchAnalysis() : null; var ragConfig = currentConfig != null ? currentConfig.ragConfig() : null; var commentCommands = currentConfig != null ? currentConfig.commentCommands() : null; - int currentMaxTokenLimit = currentConfig != null ? currentConfig.maxAnalysisTokenLimit() : ProjectConfig.DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT; - - Boolean newPrAnalysis = prAnalysisEnabled != null ? prAnalysisEnabled : - (currentConfig != null ? currentConfig.prAnalysisEnabled() : true); - Boolean newBranchAnalysis = branchAnalysisEnabled != null ? branchAnalysisEnabled : - (currentConfig != null ? currentConfig.branchAnalysisEnabled() : true); - var newInstallationMethod = installationMethod != null ? installationMethod : - (currentConfig != null ? currentConfig.installationMethod() : null); + int currentMaxTokenLimit = currentConfig != null ? currentConfig.maxAnalysisTokenLimit() + : ProjectConfig.DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT; + + Boolean newPrAnalysis = prAnalysisEnabled != null ? prAnalysisEnabled + : (currentConfig != null ? currentConfig.prAnalysisEnabled() : true); + Boolean newBranchAnalysis = branchAnalysisEnabled != null ? branchAnalysisEnabled + : (currentConfig != null ? currentConfig.branchAnalysisEnabled() : true); + var newInstallationMethod = installationMethod != null ? installationMethod + : (currentConfig != null ? 
currentConfig.installationMethod() : null); int newMaxTokenLimit = maxAnalysisTokenLimit != null ? maxAnalysisTokenLimit : currentMaxTokenLimit; // Update both the direct column and the JSON config - //TODO: remove duplication + // TODO: remove duplication project.setPrAnalysisEnabled(newPrAnalysis != null ? newPrAnalysis : true); project.setBranchAnalysisEnabled(newBranchAnalysis != null ? newBranchAnalysis : true); @@ -610,8 +620,9 @@ public Project updateAnalysisSettings( /** * Update the quality gate for a project. - * @param workspaceId the workspace ID - * @param projectId the project ID + * + * @param workspaceId the workspace ID + * @param projectId the project ID * @param qualityGateId the quality gate ID (null to remove) * @return the updated project */ @@ -619,7 +630,7 @@ public Project updateAnalysisSettings( public Project updateProjectQualityGate(Long workspaceId, Long projectId, Long qualityGateId) { Project project = projectRepository.findByWorkspaceIdAndId(workspaceId, projectId) .orElseThrow(() -> new NoSuchElementException("Project not found")); - + if (qualityGateId != null) { QualityGate qualityGate = qualityGateRepository.findByIdAndWorkspaceId(qualityGateId, workspaceId) .orElseThrow(() -> new NoSuchElementException("Quality gate not found")); @@ -627,13 +638,14 @@ public Project updateProjectQualityGate(Long workspaceId, Long projectId, Long q } else { project.setQualityGate(null); } - + return projectRepository.save(project); } /** * Get the comment commands configuration for a project. - * Returns a CommentCommandsConfig record (never null, returns default disabled config if not configured). + * Returns a CommentCommandsConfig record (never null, returns default disabled + * config if not configured). */ @Transactional(readOnly = true) public CommentCommandsConfig getCommentCommandsConfig(Project project) { @@ -645,17 +657,17 @@ public CommentCommandsConfig getCommentCommandsConfig(Project project) { /** * Update the comment commands configuration for a project. + * * @param workspaceId the workspace ID - * @param projectId the project ID - * @param request the update request + * @param projectId the project ID + * @param request the update request * @return the updated project */ @Transactional public Project updateCommentCommandsConfig( Long workspaceId, Long projectId, - org.rostilos.codecrow.webserver.project.dto.request.UpdateCommentCommandsConfigRequest request - ) { + org.rostilos.codecrow.webserver.project.dto.request.UpdateCommentCommandsConfigRequest request) { Project project = projectRepository.findByWorkspaceIdAndId(workspaceId, projectId) .orElseThrow(() -> new NoSuchElementException("Project not found")); @@ -671,25 +683,27 @@ public Project updateCommentCommandsConfig( // Build new comment commands config var existingCommentConfig = currentConfig != null ? currentConfig.commentCommands() : null; - boolean enabled = request.enabled() != null ? request.enabled() : - (existingCommentConfig != null ? existingCommentConfig.enabled() : false); - Integer rateLimit = request.rateLimit() != null ? request.rateLimit() : - (existingCommentConfig != null ? existingCommentConfig.rateLimit() : CommentCommandsConfig.DEFAULT_RATE_LIMIT); - Integer rateLimitWindow = request.rateLimitWindowMinutes() != null ? request.rateLimitWindowMinutes() : - (existingCommentConfig != null ? existingCommentConfig.rateLimitWindowMinutes() : CommentCommandsConfig.DEFAULT_RATE_LIMIT_WINDOW_MINUTES); - Boolean allowPublicRepoCommands = request.allowPublicRepoCommands() != null ? 
request.allowPublicRepoCommands() : - (existingCommentConfig != null ? existingCommentConfig.allowPublicRepoCommands() : false); - List allowedCommands = request.allowedCommands() != null ? request.validatedAllowedCommands() : - (existingCommentConfig != null ? existingCommentConfig.allowedCommands() : null); - CommandAuthorizationMode authorizationMode = request.authorizationMode() != null ? request.authorizationMode() : - (existingCommentConfig != null ? existingCommentConfig.authorizationMode() : CommentCommandsConfig.DEFAULT_AUTHORIZATION_MODE); - Boolean allowPrAuthor = request.allowPrAuthor() != null ? request.allowPrAuthor() : - (existingCommentConfig != null ? existingCommentConfig.allowPrAuthor() : true); + boolean enabled = request.enabled() != null ? request.enabled() + : (existingCommentConfig != null ? existingCommentConfig.enabled() : false); + Integer rateLimit = request.rateLimit() != null ? request.rateLimit() + : (existingCommentConfig != null ? existingCommentConfig.rateLimit() + : CommentCommandsConfig.DEFAULT_RATE_LIMIT); + Integer rateLimitWindow = request.rateLimitWindowMinutes() != null ? request.rateLimitWindowMinutes() + : (existingCommentConfig != null ? existingCommentConfig.rateLimitWindowMinutes() + : CommentCommandsConfig.DEFAULT_RATE_LIMIT_WINDOW_MINUTES); + Boolean allowPublicRepoCommands = request.allowPublicRepoCommands() != null ? request.allowPublicRepoCommands() + : (existingCommentConfig != null ? existingCommentConfig.allowPublicRepoCommands() : false); + List allowedCommands = request.allowedCommands() != null ? request.validatedAllowedCommands() + : (existingCommentConfig != null ? existingCommentConfig.allowedCommands() : null); + CommandAuthorizationMode authorizationMode = request.authorizationMode() != null ? request.authorizationMode() + : (existingCommentConfig != null ? existingCommentConfig.authorizationMode() + : CommentCommandsConfig.DEFAULT_AUTHORIZATION_MODE); + Boolean allowPrAuthor = request.allowPrAuthor() != null ? request.allowPrAuthor() + : (existingCommentConfig != null ? 
existingCommentConfig.allowPrAuthor() : true); var commentCommands = new CommentCommandsConfig( enabled, rateLimit, rateLimitWindow, allowPublicRepoCommands, allowedCommands, - authorizationMode, allowPrAuthor - ); + authorizationMode, allowPrAuthor); project.setConfiguration(new ProjectConfig(useLocalMcp, mainBranch, branchAnalysis, ragConfig, prAnalysisEnabled, branchAnalysisEnabled, installationMethod, commentCommands)); @@ -738,14 +752,14 @@ public WebhookSetupResult setupWebhooks(Long workspaceId, Long projectId) { // Get VCS client and setup webhook org.rostilos.codecrow.vcsclient.VcsClient client = vcsClientProvider.getClient(connection); List events = getWebhookEvents(binding.getProvider()); - + String workspaceIdOrNamespace; String repoSlug; - + // For REPOSITORY_TOKEN connections, use the repositoryPath from the connection // because the token is scoped to that specific repository - if (connection.getConnectionType() == EVcsConnectionType.REPOSITORY_TOKEN - && connection.getRepositoryPath() != null + if (connection.getConnectionType() == EVcsConnectionType.REPOSITORY_TOKEN + && connection.getRepositoryPath() != null && !connection.getRepositoryPath().isBlank()) { String repositoryPath = connection.getRepositoryPath(); int lastSlash = repositoryPath.lastIndexOf('/'); @@ -756,16 +770,16 @@ public WebhookSetupResult setupWebhooks(Long workspaceId, Long projectId) { workspaceIdOrNamespace = binding.getExternalNamespace(); repoSlug = repositoryPath; } - log.info("REPOSITORY_TOKEN webhook setup - using repositoryPath: {}, namespace: {}, slug: {}", + log.info("REPOSITORY_TOKEN webhook setup - using repositoryPath: {}, namespace: {}, slug: {}", repositoryPath, workspaceIdOrNamespace, repoSlug); } else { workspaceIdOrNamespace = binding.getExternalNamespace(); repoSlug = binding.getExternalRepoSlug(); log.info("Standard webhook setup - namespace: {}, slug: {}", workspaceIdOrNamespace, repoSlug); } - + String webhookId = client.ensureWebhook(workspaceIdOrNamespace, repoSlug, webhookUrl, events); - + if (webhookId != null) { binding.setWebhookId(webhookId); binding.setWebhooksConfigured(true); @@ -799,8 +813,7 @@ public WebhookInfo getWebhookInfo(Long workspaceId, Long projectId) { binding.isWebhooksConfigured(), binding.getWebhookId(), webhookUrl, - binding.getProvider() - ); + binding.getProvider()); } private String generateWebhookUrl(EVcsProvider provider, Project project) { @@ -810,7 +823,8 @@ private String generateWebhookUrl(EVcsProvider provider, Project project) { private List getWebhookEvents(EVcsProvider provider) { return switch (provider) { - case BITBUCKET_CLOUD -> List.of("pullrequest:created", "pullrequest:updated", "pullrequest:fulfilled", "pullrequest:comment_created", "repo:push"); + case BITBUCKET_CLOUD -> List.of("pullrequest:created", "pullrequest:updated", "pullrequest:fulfilled", + "pullrequest:comment_created", "repo:push"); case GITHUB -> List.of("pull_request", "pull_request_review_comment", "issue_comment", "push"); case GITLAB -> List.of("merge_requests_events", "note_events", "push_events"); default -> List.of(); @@ -834,7 +848,8 @@ public Project changeVcsConnection(Long workspaceId, Long projectId, ChangeVcsCo Workspace workspace = workspaceRepository.findById(workspaceId) .orElseThrow(() -> new NoSuchElementException("Workspace not found")); - VcsConnection newConnection = vcsConnectionRepository.findByWorkspace_IdAndId(workspaceId, request.getConnectionId()) + VcsConnection newConnection = vcsConnectionRepository + .findByWorkspace_IdAndId(workspaceId, 
request.getConnectionId()) .orElseThrow(() -> new NoSuchElementException("VCS Connection not found")); // Get or create VcsRepoBinding @@ -857,9 +872,11 @@ public Project changeVcsConnection(Long workspaceId, Long projectId, ChangeVcsCo binding.setVcsConnection(newConnection); binding.setProvider(newConnection.getProviderType()); binding.setExternalRepoSlug(request.getRepositorySlug()); - binding.setExternalNamespace(request.getWorkspaceId() != null ? request.getWorkspaceId() : newConnection.getExternalWorkspaceSlug()); - binding.setExternalRepoId(request.getRepositoryId() != null ? request.getRepositoryId() : request.getRepositorySlug()); - + binding.setExternalNamespace( + request.getWorkspaceId() != null ? request.getWorkspaceId() : newConnection.getExternalWorkspaceSlug()); + binding.setExternalRepoId( + request.getRepositoryId() != null ? request.getRepositoryId() : request.getRepositorySlug()); + if (request.getDefaultBranch() != null && !request.getDefaultBranch().isBlank()) { binding.setDefaultBranch(request.getDefaultBranch()); } @@ -900,20 +917,4 @@ private void clearProjectAnalysisData(Long projectId) { analysisLockRepository.deleteByProjectId(projectId); ragIndexStatusRepository.deleteByProjectId(projectId); } - - // ==================== DTOs for Webhook Operations ==================== - - public record WebhookSetupResult( - boolean success, - String webhookId, - String webhookUrl, - String message - ) {} - - public record WebhookInfo( - boolean webhooksConfigured, - String webhookId, - String webhookUrl, - EVcsProvider provider - ) {} } diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/workspace/service/IWorkspaceService.java b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/workspace/service/IWorkspaceService.java new file mode 100644 index 00000000..b48331a9 --- /dev/null +++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/workspace/service/IWorkspaceService.java @@ -0,0 +1,64 @@ +package org.rostilos.codecrow.webserver.workspace.service; + +import java.util.List; + +import org.rostilos.codecrow.core.model.workspace.EWorkspaceRole; +import org.rostilos.codecrow.core.model.workspace.Workspace; +import org.rostilos.codecrow.core.model.workspace.WorkspaceMember; + +/** + * Interface for workspace management operations. + *
+ * This interface enables cloud implementations to extend or decorate + * the base workspace service with additional capabilities like billing limits. + */ +public interface IWorkspaceService { + + // ==================== Core CRUD ==================== + + Workspace createWorkspace(Long actorUserId, String slug, String name, String description); + + Workspace getWorkspaceById(Long id); + + Workspace getWorkspaceBySlug(String slug); + + void deleteWorkspace(Long actorUserId, String workspaceSlug); + + // ==================== Membership ==================== + + WorkspaceMember inviteToWorkspace(Long actorUserId, Long workspaceId, String username, EWorkspaceRole role); + + WorkspaceMember inviteToWorkspace(Long actorUserId, String workspaceSlug, String username, EWorkspaceRole role); + + void removeMemberFromWorkspace(Long actorUserId, Long workspaceId, String username); + + void removeMemberFromWorkspace(Long actorUserId, String workspaceSlug, String username); + + void changeWorkspaceRole(Long actorUserId, Long workspaceId, String targetUsername, EWorkspaceRole newRole); + + void changeWorkspaceRole(Long actorUserId, String workspaceSlug, String targetUsername, EWorkspaceRole newRole); + + WorkspaceMember acceptInvite(Long userId, Long workspaceId); + + WorkspaceMember acceptInvite(Long userId, String workspaceSlug); + + // ==================== Queries ==================== + + List listMembers(Long workspaceId); + + List listMembers(String workspaceSlug); + + List listMembers(Long workspaceId, List excludeUsernames); + + List listUserWorkspaces(Long userId); + + EWorkspaceRole getUserRole(Long workspaceId, Long userId); + + EWorkspaceRole getUserRole(String workspaceSlug, Long userId); + + // ==================== Deletion Scheduling ==================== + + Workspace scheduleDeletion(Long actorUserId, String workspaceSlug); + + Workspace cancelScheduledDeletion(Long actorUserId, String workspaceSlug); +} diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/workspace/service/WorkspaceService.java b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/workspace/service/WorkspaceService.java index 0b991dab..65b134b3 100644 --- a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/workspace/service/WorkspaceService.java +++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/workspace/service/WorkspaceService.java @@ -19,7 +19,7 @@ import jakarta.persistence.PersistenceContext; @Service -public class WorkspaceService { +public class WorkspaceService implements IWorkspaceService { @PersistenceContext private EntityManager entityManager; @@ -29,8 +29,8 @@ public class WorkspaceService { private final UserRepository userRepository; public WorkspaceService(WorkspaceRepository workspaceRepository, - WorkspaceMemberRepository workspaceMemberRepository, - UserRepository userRepository) { + WorkspaceMemberRepository workspaceMemberRepository, + UserRepository userRepository) { this.workspaceRepository = workspaceRepository; this.workspaceMemberRepository = workspaceMemberRepository; this.userRepository = userRepository; @@ -105,7 +105,8 @@ public WorkspaceMember inviteToWorkspace(Long actorUserId, Long workspaceId, Str } @Transactional - public WorkspaceMember inviteToWorkspace(Long actorUserId, String workspaceSlug, String username, EWorkspaceRole role) { + public WorkspaceMember inviteToWorkspace(Long actorUserId, String workspaceSlug, String username, + EWorkspaceRole 
role) { Workspace workspace = getWorkspaceBySlug(workspaceSlug); return inviteToWorkspace(actorUserId, workspace.getId(), username, role); } @@ -115,16 +116,19 @@ public void changeWorkspaceRole(Long actorUserId, Long workspaceId, String targe User target = userRepository.findByUsername(targetUsername) .orElseThrow(() -> new IllegalArgumentException("User not found")); - Optional optionalWorkspaceMemberActor = workspaceMemberRepository.findByWorkspaceIdAndUserId(workspaceId, actorUserId); - if(optionalWorkspaceMemberActor.isEmpty()) { + Optional optionalWorkspaceMemberActor = workspaceMemberRepository + .findByWorkspaceIdAndUserId(workspaceId, actorUserId); + if (optionalWorkspaceMemberActor.isEmpty()) { throw new SecurityException("The current user does not have editing privileges."); } WorkspaceMember workspaceMemberActor = optionalWorkspaceMemberActor.get(); - if(workspaceMemberActor.getRole() == EWorkspaceRole.ADMIN && ( newRole == EWorkspaceRole.ADMIN || newRole == EWorkspaceRole.OWNER )) { + if (workspaceMemberActor.getRole() == EWorkspaceRole.ADMIN + && (newRole == EWorkspaceRole.ADMIN || newRole == EWorkspaceRole.OWNER)) { throw new SecurityException("The current user does not have editing privileges."); } - WorkspaceMember workspaceMember = workspaceMemberRepository.findByWorkspaceIdAndUserId(workspaceId, target.getId()) + WorkspaceMember workspaceMember = workspaceMemberRepository + .findByWorkspaceIdAndUserId(workspaceId, target.getId()) .orElseThrow(() -> new NoSuchElementException("The user is not present in the workspace.")); workspaceMember.setRole(newRole); @@ -132,7 +136,8 @@ public void changeWorkspaceRole(Long actorUserId, Long workspaceId, String targe } @Transactional - public void changeWorkspaceRole(Long actorUserId, String workspaceSlug, String targetUsername, EWorkspaceRole newRole) { + public void changeWorkspaceRole(Long actorUserId, String workspaceSlug, String targetUsername, + EWorkspaceRole newRole) { Workspace workspace = getWorkspaceBySlug(workspaceSlug); changeWorkspaceRole(actorUserId, workspace.getId(), targetUsername, newRole); } @@ -172,7 +177,8 @@ public List listMembers(String workspaceSlug) { } /** - * Returns workspace members and optionally excludes members whose usernames are in excludeUsernames. + * Returns workspace members and optionally excludes members whose usernames are + * in excludeUsernames. * excludeUsernames can be null or empty (no exclusions). 
*/ @Transactional(readOnly = true) @@ -209,18 +215,19 @@ public EWorkspaceRole getUserRole(String workspaceSlug, Long userId) { @Transactional public void deleteWorkspace(Long actorUserId, String workspaceSlug) { Workspace workspace = getWorkspaceBySlug(workspaceSlug); - + // Verify actor is the owner - WorkspaceMember actorMember = workspaceMemberRepository.findByWorkspaceIdAndUserId(workspace.getId(), actorUserId) + WorkspaceMember actorMember = workspaceMemberRepository + .findByWorkspaceIdAndUserId(workspace.getId(), actorUserId) .orElseThrow(() -> new SecurityException("User is not a member of this workspace")); - + if (actorMember.getRole() != EWorkspaceRole.OWNER) { throw new SecurityException("Only the workspace owner can delete a workspace"); } - + // Delete all workspace members first workspaceMemberRepository.deleteAllByWorkspaceId(workspace.getId()); - + // Delete the workspace workspaceRepository.delete(workspace); } @@ -231,19 +238,20 @@ public void deleteWorkspace(Long actorUserId, String workspaceSlug) { @Transactional public Workspace scheduleDeletion(Long actorUserId, String workspaceSlug) { Workspace workspace = getWorkspaceBySlug(workspaceSlug); - + // Verify actor is the owner - WorkspaceMember actorMember = workspaceMemberRepository.findByWorkspaceIdAndUserId(workspace.getId(), actorUserId) + WorkspaceMember actorMember = workspaceMemberRepository + .findByWorkspaceIdAndUserId(workspace.getId(), actorUserId) .orElseThrow(() -> new SecurityException("User is not a member of this workspace")); - + if (actorMember.getRole() != EWorkspaceRole.OWNER) { throw new SecurityException("Only the workspace owner can schedule workspace deletion"); } - + if (workspace.isScheduledForDeletion()) { throw new IllegalStateException("Workspace is already scheduled for deletion"); } - + workspace.scheduleDeletion(actorUserId); return workspaceRepository.save(workspace); } @@ -254,19 +262,20 @@ public Workspace scheduleDeletion(Long actorUserId, String workspaceSlug) { @Transactional public Workspace cancelScheduledDeletion(Long actorUserId, String workspaceSlug) { Workspace workspace = getWorkspaceBySlug(workspaceSlug); - + // Verify actor is the owner - WorkspaceMember actorMember = workspaceMemberRepository.findByWorkspaceIdAndUserId(workspace.getId(), actorUserId) + WorkspaceMember actorMember = workspaceMemberRepository + .findByWorkspaceIdAndUserId(workspace.getId(), actorUserId) .orElseThrow(() -> new SecurityException("User is not a member of this workspace")); - + if (actorMember.getRole() != EWorkspaceRole.OWNER) { throw new SecurityException("Only the workspace owner can cancel workspace deletion"); } - + if (!workspace.isScheduledForDeletion()) { throw new IllegalStateException("Workspace is not scheduled for deletion"); } - + workspace.cancelDeletion(); return workspaceRepository.save(workspace); } From 5dfa058e7f7ed7da3faac9f24b55716f3f3f767d Mon Sep 17 00:00:00 2001 From: rostislav Date: Sun, 8 Feb 2026 18:36:14 +0200 Subject: [PATCH 4/7] feat: Enhance VcsConnection with optimistic locking and set default version for existing records; update database migration scripts for version handling and remove token limitation from ai_connection --- .../core/model/vcs/VcsConnection.java | 5 +- ...e_token_limitation_from_ai_connection.sql} | 0 ....0__set_default_vcs_connection_version.sql | 5 ++ .../processor/WebhookAsyncProcessor.java | 82 +++++++++---------- python-ecosystem/mcp-client/Dockerfile | 2 + .../mcp-client/model/multi_stage.py | 2 +- .../review/orchestrator/orchestrator.py 
| 4 +- 7 files changed, 52 insertions(+), 48 deletions(-) rename java-ecosystem/libs/core/src/main/resources/db/migration/{1.4.0/V1.4.0__remove_token_limitation_from_ai_connection.sql => 1.3.0/V1.3.0__remove_token_limitation_from_ai_connection.sql} (100%) create mode 100644 java-ecosystem/libs/core/src/main/resources/db/migration/1.3.0/V1.3.0__set_default_vcs_connection_version.sql diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/vcs/VcsConnection.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/vcs/VcsConnection.java index 2c0a74d7..9c3f231a 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/vcs/VcsConnection.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/vcs/VcsConnection.java @@ -119,10 +119,11 @@ public class VcsConnection { /** * Version field for optimistic locking. * Prevents concurrent token refresh operations from overwriting each other. + * Initialized to 0 to handle existing records that don't have this field yet. */ @Version - @Column(name = "version") - private Long version; + @Column(name = "version", nullable = false, columnDefinition = "BIGINT DEFAULT 0") + private Long version = 0L; /** * Provider-specific configuration (JSON column). diff --git a/java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.0__remove_token_limitation_from_ai_connection.sql b/java-ecosystem/libs/core/src/main/resources/db/migration/1.3.0/V1.3.0__remove_token_limitation_from_ai_connection.sql similarity index 100% rename from java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.0__remove_token_limitation_from_ai_connection.sql rename to java-ecosystem/libs/core/src/main/resources/db/migration/1.3.0/V1.3.0__remove_token_limitation_from_ai_connection.sql diff --git a/java-ecosystem/libs/core/src/main/resources/db/migration/1.3.0/V1.3.0__set_default_vcs_connection_version.sql b/java-ecosystem/libs/core/src/main/resources/db/migration/1.3.0/V1.3.0__set_default_vcs_connection_version.sql new file mode 100644 index 00000000..0fd68ede --- /dev/null +++ b/java-ecosystem/libs/core/src/main/resources/db/migration/1.3.0/V1.3.0__set_default_vcs_connection_version.sql @@ -0,0 +1,5 @@ +-- Add version column if not exists and set default for existing records +ALTER TABLE vcs_connection ADD COLUMN IF NOT EXISTS version BIGINT DEFAULT 0; +UPDATE vcs_connection SET version = 0 WHERE version IS NULL; +ALTER TABLE vcs_connection ALTER COLUMN version SET NOT NULL; +ALTER TABLE vcs_connection ALTER COLUMN version SET DEFAULT 0; \ No newline at end of file diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index da2a0787..4b7f2bec 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -1,6 +1,5 @@ package org.rostilos.codecrow.pipelineagent.generic.processor; -import jakarta.persistence.EntityManager; import org.rostilos.codecrow.core.model.job.Job; import org.rostilos.codecrow.core.model.project.Project; import org.rostilos.codecrow.core.model.vcs.EVcsProvider; @@ -24,8 +23,11 @@ import 
org.springframework.transaction.annotation.Transactional; /** - * Service for processing webhooks asynchronously within a transactional context. - * This ensures Hibernate sessions are available for lazy loading. + * Service for processing webhooks asynchronously. + * + * Uses a short read-only transaction to load the project entity with all lazy associations, + * then processes the webhook (including the potentially long-running AI analysis) WITHOUT + * holding a DB transaction open. All downstream services manage their own transactions. */ @Service public class WebhookAsyncProcessor { @@ -35,7 +37,6 @@ public class WebhookAsyncProcessor { private final ProjectRepository projectRepository; private final JobService jobService; private final VcsServiceFactory vcsServiceFactory; - private final EntityManager entityManager; // Self-injection for @Transactional proxy to work from @Async method @Autowired @@ -45,20 +46,17 @@ public class WebhookAsyncProcessor { public WebhookAsyncProcessor( ProjectRepository projectRepository, JobService jobService, - VcsServiceFactory vcsServiceFactory, - EntityManager entityManager + VcsServiceFactory vcsServiceFactory ) { this.projectRepository = projectRepository; this.jobService = jobService; this.vcsServiceFactory = vcsServiceFactory; - this.entityManager = entityManager; } /** * Process a webhook asynchronously. - * Delegates to a transactional method to ensure lazy associations can be loaded. - * NOTE: @Async and @Transactional cannot be on the same method - the transaction - * proxy gets bypassed. We use self-injection to call a separate @Transactional method. + * Delegates to processWebhookInTransaction via self-injection to ensure + * the @Transactional proxy on loadAndInitializeProject is invoked properly. */ @Async("webhookExecutor") public void processWebhookAsync( @@ -90,11 +88,29 @@ public void processWebhookAsync( } /** - * Process webhook within a transaction. - * Called from async method via self-injection to ensure transaction proxy works. - * Note: Transaction timeout set to 5 minutes to prevent indefinite 'idle in transaction' states. + * Load project and eagerly initialize all lazy associations within a short read-only transaction. + * This is separated from the main processing so the DB transaction is NOT held open + * during the potentially long-running AI analysis (which can take 10+ minutes for large PRs). + */ + @Transactional(timeout = 30, readOnly = true) + public Project loadAndInitializeProject(Long projectId) { + Project project = projectRepository.findById(projectId) + .orElseThrow(() -> new IllegalStateException("Project not found: " + projectId)); + initializeProjectAssociations(project); + return project; + } + + /** + * Process webhook - called from async method via self-injection. + * + * IMPORTANT: This method is intentionally NOT @Transactional. + * All downstream services (JobService, PullRequestService, CodeAnalysisService, etc.) + * manage their own transaction boundaries. Wrapping the entire webhook processing in a + * single transaction caused timeout failures on large PRs where AI analysis takes 10+ minutes, + * resulting in lost results after tokens were already spent. + * + * Project loading uses a short read-only transaction via {@link #loadAndInitializeProject(Long)}. 
*/ - @Transactional(timeout = 300) public void processWebhookInTransaction( EVcsProvider provider, Long projectId, @@ -107,12 +123,9 @@ public void processWebhookInTransaction( Project project = null; try { - // Re-fetch project to ensure all lazy associations are available - project = projectRepository.findById(projectId) - .orElseThrow(() -> new IllegalStateException("Project not found: " + projectId)); - - // Initialize lazy associations we'll need - initializeProjectAssociations(project); + // Load project in a short read-only transaction + // All lazy associations are eagerly initialized before the transaction closes + project = self.loadAndInitializeProject(projectId); log.info("Calling jobService.startJob for job {}", job.getExternalId()); jobService.startJob(job); @@ -191,19 +204,14 @@ public void processWebhookInTransaction( try { if (project == null) { - project = projectRepository.findById(projectId).orElse(null); - } - if (project != null) { - initializeProjectAssociations(project); - postInfoToVcs(provider, project, payload, skipMessage, placeholderCommentId, job); + project = self.loadAndInitializeProject(projectId); } + postInfoToVcs(provider, project, payload, skipMessage, placeholderCommentId, job); } catch (Exception postError) { log.error("Failed to post skip message to VCS: {}", postError.getMessage()); } try { - // Detach job from persistence context to prevent outer transaction from overwriting - entityManager.detach(job); jobService.skipJob(job, "Diff too large: " + diffEx.getEstimatedTokens() + " tokens > " + diffEx.getMaxAllowedTokens() + " limit"); } catch (Exception skipError) { log.error("Failed to skip job: {}", skipError.getMessage()); @@ -227,21 +235,15 @@ public void processWebhookInTransaction( try { if (project == null) { - project = projectRepository.findById(projectId).orElse(null); - } - if (project != null) { - initializeProjectAssociations(project); - postErrorToVcs(provider, project, payload, failMessage, placeholderCommentId, job); + project = self.loadAndInitializeProject(projectId); } + postErrorToVcs(provider, project, payload, failMessage, placeholderCommentId, job); } catch (Exception postError) { log.error("Failed to post lock error to VCS: {}", postError.getMessage()); } try { log.info("Marking job {} as FAILED due to lock acquisition timeout", job.getExternalId()); - // Detach job from persistence context to prevent outer transaction from overwriting - // the FAILED status when it commits (since failJob uses REQUIRES_NEW) - entityManager.detach(job); Job updatedJob = jobService.failJob(job, "Lock acquisition timeout: " + lockEx.getMessage()); log.info("Job {} marked as FAILED, new status: {}", job.getExternalId(), updatedJob.getStatus()); } catch (Exception failError) { @@ -255,21 +257,15 @@ public void processWebhookInTransaction( try { if (project == null) { - project = projectRepository.findById(projectId).orElse(null); - } - if (project != null) { - initializeProjectAssociations(project); - postErrorToVcs(provider, project, payload, "Processing failed: " + e.getMessage(), placeholderCommentId, job); + project = self.loadAndInitializeProject(projectId); } + postErrorToVcs(provider, project, payload, "Processing failed: " + e.getMessage(), placeholderCommentId, job); } catch (Exception postError) { log.error("Failed to post error to VCS: {}", postError.getMessage()); } try { log.info("Marking job {} as FAILED due to processing error: {}", job.getExternalId(), e.getMessage()); - // Detach job from persistence context to prevent outer 
transaction from overwriting - // the FAILED status when it commits (since failJob uses REQUIRES_NEW) - entityManager.detach(job); Job updatedJob = jobService.failJob(job, "Processing failed: " + e.getMessage()); log.info("Job {} marked as FAILED, new status: {}", job.getExternalId(), updatedJob.getStatus()); } catch (Exception failError) { diff --git a/python-ecosystem/mcp-client/Dockerfile b/python-ecosystem/mcp-client/Dockerfile index 9dd14dba..2c03cc27 100644 --- a/python-ecosystem/mcp-client/Dockerfile +++ b/python-ecosystem/mcp-client/Dockerfile @@ -22,6 +22,7 @@ RUN pip install --no-cache-dir -r requirements.txt # --- Builder Stage 2: Copy and Install Application Modules --- # Copy application source code and the JAR COPY main.py . +COPY api ./api/ COPY server ./server/ COPY model ./model/ COPY service ./service/ @@ -56,6 +57,7 @@ COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/pytho # Copy application code from the builder stage COPY --from=builder /app/main.py ./ +COPY --from=builder /app/api ./api/ COPY --from=builder /app/server ./server/ COPY --from=builder /app/model ./model/ COPY --from=builder /app/service ./service/ diff --git a/python-ecosystem/mcp-client/model/multi_stage.py b/python-ecosystem/mcp-client/model/multi_stage.py index 16e11fab..10cd1979 100644 --- a/python-ecosystem/mcp-client/model/multi_stage.py +++ b/python-ecosystem/mcp-client/model/multi_stage.py @@ -31,7 +31,7 @@ class ReviewFile(BaseModel): """File details for review planning.""" path: str focus_areas: List[str] = Field(default_factory=list, description="Specific areas to focus on (SECURITY, ARCHITECTURE, etc.)") - risk_level: str = Field(description="CRITICAL, HIGH, MEDIUM, or LOW") + risk_level: str = Field(default="MEDIUM", description="CRITICAL, HIGH, MEDIUM, or LOW") estimated_issues: Optional[int] = Field(default=0) diff --git a/python-ecosystem/mcp-client/service/review/orchestrator/orchestrator.py b/python-ecosystem/mcp-client/service/review/orchestrator/orchestrator.py index c843f22c..f50fd9f6 100644 --- a/python-ecosystem/mcp-client/service/review/orchestrator/orchestrator.py +++ b/python-ecosystem/mcp-client/service/review/orchestrator/orchestrator.py @@ -232,12 +232,12 @@ def _ensure_all_files_planned(self, plan, changed_files: List[str]): if missing_files: logger.warning(f"Stage 0 missed {len(missing_files)} files, adding to catch-all group") catch_all_files = [ - ReviewFile(path=f, focus_areas=["general review"]) + ReviewFile(path=f, focus_areas=["general review"], risk_level="MEDIUM") for f in missing_files ] plan.file_groups.append( FileGroup( - name="uncategorized", + group_id="uncategorized", priority="MEDIUM", rationale="Files not categorized by initial planning", files=catch_all_files From bb8918a3a9cc167fb4ad079305e7a1285802ac79 Mon Sep 17 00:00:00 2001 From: rostislav Date: Sun, 8 Feb 2026 21:55:05 +0200 Subject: [PATCH 5/7] feat: Enhance ReviewService with concurrency control and timeout handling - Introduced MAX_CONCURRENT_REVIEWS to limit simultaneous review requests. - Implemented asyncio.Semaphore to manage concurrent reviews. - Added a timeout of 600 seconds for review processing to prevent long-running requests. - Refactored LLM reranker initialization to be per-request, improving resource management. - Ensured MCP sessions are closed after review processing to release resources. - Enhanced error handling for timeouts and exceptions during review processing. 
refactor: Simplify context builder and remove unused components - Removed legacy context budget management and model context limits. - Streamlined context builder utilities for RAG metrics and caching. - Updated context fetching logic to align with new architecture. fix: Update prompt templates for clarity and accuracy - Revised Stage 2 cross-file prompt to focus on relevant aspects. - Changed references from "Database Migrations" to "Migration Files" for consistency. feat: Implement service-to-service authentication middleware - Added ServiceSecretMiddleware to validate internal requests with a shared secret. - Configured middleware to skip authentication for public endpoints. enhance: Improve collection management with payload indexing - Added functionality to create payload indexes for efficient filtering on common fields in Qdrant collections. fix: Adjust query service to handle path prefix mismatches - Updated fallback logic in RAGQueryService to improve handling of filename matches during queries. --- deployment/config/mcp-client/.env.sample | 10 + deployment/config/rag-pipeline/.env.sample | 12 + .../processor/WebhookAsyncProcessor.java | 23 +- python-ecosystem/mcp-client/api/app.py | 41 +- python-ecosystem/mcp-client/api/middleware.py | 46 ++ .../mcp-client/api/routers/commands.py | 16 +- .../mcp-client/api/routers/review.py | 14 +- python-ecosystem/mcp-client/main.py | 45 +- python-ecosystem/mcp-client/model/__init__.py | 4 - .../mcp-client/model/multi_stage.py | 15 - .../service/command/command_service.py | 74 ++- .../mcp-client/service/rag/rag_client.py | 27 +- .../review/orchestrator/orchestrator.py | 10 +- .../review/orchestrator/reconciliation.py | 12 +- .../service/review/orchestrator/stages.py | 298 ++++++++- .../service/review/review_service.py | 84 ++- .../mcp-client/utils/context_builder.py | 606 +----------------- .../utils/prompts/prompt_constants.py | 13 +- .../rag-pipeline/src/rag_pipeline/api/api.py | 36 +- .../src/rag_pipeline/api/middleware.py | 46 ++ .../core/index_manager/collection_manager.py | 24 +- .../rag_pipeline/services/query_service.py | 20 +- 22 files changed, 688 insertions(+), 788 deletions(-) create mode 100644 python-ecosystem/mcp-client/api/middleware.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/api/middleware.py diff --git a/deployment/config/mcp-client/.env.sample b/deployment/config/mcp-client/.env.sample index c92e89a9..05066c29 100644 --- a/deployment/config/mcp-client/.env.sample +++ b/deployment/config/mcp-client/.env.sample @@ -2,6 +2,16 @@ AI_CLIENT_PORT=8000 RAG_ENABLED=true RAG_API_URL=http://host.docker.internal:8001 +# === Service-to-Service Auth === +# Shared secret for authenticating requests between internal services. +# Must match the SERVICE_SECRET configured on rag-pipeline. +# Leave empty to disable auth (dev mode only). +SERVICE_SECRET=change-me-to-a-random-secret + +# === Concurrency === +# Max parallel review requests handled simultaneously (default: 4) +MAX_CONCURRENT_REVIEWS=4 + # API access for Platform MCP (internal network only) CODECROW_API_URL=http://codecrow-web-application:8081 diff --git a/deployment/config/rag-pipeline/.env.sample b/deployment/config/rag-pipeline/.env.sample index a0bb870c..1152a685 100644 --- a/deployment/config/rag-pipeline/.env.sample +++ b/deployment/config/rag-pipeline/.env.sample @@ -1,3 +1,15 @@ +# === Service-to-Service Auth === +# Shared secret for authenticating incoming requests from mcp-client. +# Must match the SERVICE_SECRET configured on mcp-client. 
+# Leave empty to disable auth (dev mode only). +SERVICE_SECRET=change-me-to-a-random-secret + +# === Path Traversal Guard === +# Root directory that repo_path arguments are allowed under. +# The rag-pipeline will reject any index/query request whose resolved +# path escapes this directory. Default: /tmp +ALLOWED_REPO_ROOT=/tmp + #QDRANT configuration QDRANT_URL=http://qdrant:6333 QDRANT_COLLECTION_PREFIX=codecrow diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 4b7f2bec..064c0061 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -275,7 +275,11 @@ public void processWebhookInTransaction( } /** - * Initialize lazy associations that will be needed for VCS operations. + * Initialize lazy associations that will be needed during webhook processing. + * Must be called within an active Hibernate session (i.e., inside a @Transactional method). + * + * Touches all lazy proxies so they are fully loaded before the session closes, + * allowing downstream services to access them outside a transaction. */ private void initializeProjectAssociations(Project project) { // Force initialization of VCS connections using unified accessor @@ -288,6 +292,23 @@ private void initializeProjectAssociations(Project project) { vcsConn.getProviderType(); } } + + // Force initialization of Workspace (lazy @ManyToOne) — accessed by all AiClientServices + // when building analysis requests via project.getWorkspace().getName() + var workspace = project.getWorkspace(); + if (workspace != null) { + workspace.getName(); + } + + // Force initialization of AI binding chain (lazy @OneToOne) — accessed by all AiClientServices + // via project.getAiBinding().getAiConnection() to get provider config, model, API key + var aiBinding = project.getAiBinding(); + if (aiBinding != null) { + var aiConn = aiBinding.getAiConnection(); + if (aiConn != null) { + aiConn.getProviderKey(); + } + } } /** diff --git a/python-ecosystem/mcp-client/api/app.py b/python-ecosystem/mcp-client/api/app.py index ac1a201a..fc513bc4 100644 --- a/python-ecosystem/mcp-client/api/app.py +++ b/python-ecosystem/mcp-client/api/app.py @@ -2,16 +2,55 @@ FastAPI Application Factory. Creates and configures the FastAPI application with all routers. +Uses lifespan context manager for proper startup/shutdown of shared resources. 
""" import os +import logging +from contextlib import asynccontextmanager from fastapi import FastAPI from api.routers import health, review, commands +from api.middleware import ServiceSecretMiddleware +from service.review.review_service import ReviewService +from service.command.command_service import CommandService + +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Manage application lifecycle: create services on startup, clean up on shutdown.""" + # --- Startup --- + logger.info("Initializing application services...") + review_service = ReviewService() + command_service = CommandService() + + app.state.review_service = review_service + app.state.command_service = command_service + logger.info("Application services ready") + + yield + + # --- Shutdown --- + logger.info("Shutting down application services...") + # Close the RagClient HTTP pools owned by each service + try: + await review_service.rag_client.close() + except Exception as e: + logger.warning(f"Error closing review RagClient: {e}") + try: + await command_service.rag_client.close() + except Exception as e: + logger.warning(f"Error closing command RagClient: {e}") + logger.info("Application services shut down") def create_app() -> FastAPI: """Create and configure FastAPI application.""" - app = FastAPI(title="codecrow-mcp-client") + app = FastAPI(title="codecrow-mcp-client", lifespan=lifespan) + + # Service-to-service auth + app.add_middleware(ServiceSecretMiddleware) # Register routers app.include_router(health.router) diff --git a/python-ecosystem/mcp-client/api/middleware.py b/python-ecosystem/mcp-client/api/middleware.py new file mode 100644 index 00000000..9e78d0d5 --- /dev/null +++ b/python-ecosystem/mcp-client/api/middleware.py @@ -0,0 +1,46 @@ +""" +Shared-secret authentication middleware. + +Validates that internal service-to-service requests carry +the correct X-Service-Secret header matching the SERVICE_SECRET env var. 
+""" +import os +import logging + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse + +logger = logging.getLogger(__name__) + +# Paths that skip auth (health checks, readiness probes) +_PUBLIC_PATHS = frozenset({"/health", "/docs", "/openapi.json", "/redoc"}) + + +class ServiceSecretMiddleware(BaseHTTPMiddleware): + """Reject requests that don't carry a valid shared service secret.""" + + def __init__(self, app, secret: str | None = None): + super().__init__(app) + self.secret = secret or os.environ.get("SERVICE_SECRET", "") + + async def dispatch(self, request: Request, call_next): + # Skip auth for health/doc endpoints + if request.url.path in _PUBLIC_PATHS: + return await call_next(request) + + # If no secret is configured, allow all (dev mode) + if not self.secret: + return await call_next(request) + + provided = request.headers.get("x-service-secret", "") + if provided != self.secret: + logger.warning( + f"Unauthorized request to {request.url.path} from {request.client.host if request.client else 'unknown'}" + ) + return JSONResponse( + status_code=401, + content={"detail": "Missing or invalid service secret"}, + ) + + return await call_next(request) diff --git a/python-ecosystem/mcp-client/api/routers/commands.py b/python-ecosystem/mcp-client/api/routers/commands.py index 9992a298..de6046e6 100644 --- a/python-ecosystem/mcp-client/api/routers/commands.py +++ b/python-ecosystem/mcp-client/api/routers/commands.py @@ -15,16 +15,10 @@ router = APIRouter(tags=["commands"]) -# Service instance -_command_service = None - -def get_command_service() -> CommandService: - """Get or create the command service singleton.""" - global _command_service - if _command_service is None: - _command_service = CommandService() - return _command_service +def get_command_service(request: Request) -> CommandService: + """Retrieve the CommandService instance created during app lifespan.""" + return request.app.state.command_service @router.post("/review/summarize", response_model=SummarizeResponseDto) @@ -38,7 +32,7 @@ async def summarize_endpoint(req: SummarizeRequestDto, request: Request): - Impact analysis - Architecture diagram (Mermaid or ASCII) """ - command_service = get_command_service() + command_service = get_command_service(request) try: wants_stream = _wants_streaming(request) @@ -97,7 +91,7 @@ async def ask_endpoint(req: AskRequestDto, request: Request): - Codebase (using RAG) - Analysis results """ - command_service = get_command_service() + command_service = get_command_service(request) try: wants_stream = _wants_streaming(request) diff --git a/python-ecosystem/mcp-client/api/routers/review.py b/python-ecosystem/mcp-client/api/routers/review.py index 68af6aa6..ca167662 100644 --- a/python-ecosystem/mcp-client/api/routers/review.py +++ b/python-ecosystem/mcp-client/api/routers/review.py @@ -13,16 +13,10 @@ router = APIRouter(tags=["review"]) -# Service instance -_review_service = None - -def get_review_service() -> ReviewService: - """Get or create the review service singleton.""" - global _review_service - if _review_service is None: - _review_service = ReviewService() - return _review_service +def get_review_service(request: Request) -> ReviewService: + """Retrieve the ReviewService instance created during app lifespan.""" + return request.app.state.review_service @router.post("/review", response_model=ReviewResponseDto) @@ -35,7 +29,7 @@ async def review_endpoint(req: ReviewRequestDto, request: 
Request): the endpoint will return a StreamingResponse that yields NDJSON events as they occur. - Otherwise it preserves the original behavior and returns a single ReviewResponseDto JSON body. """ - review_service = get_review_service() + review_service = get_review_service(request) try: wants_stream = _wants_streaming(request) diff --git a/python-ecosystem/mcp-client/main.py b/python-ecosystem/mcp-client/main.py index 9d1ce045..d62dc50b 100644 --- a/python-ecosystem/mcp-client/main.py +++ b/python-ecosystem/mcp-client/main.py @@ -15,7 +15,6 @@ import sys import logging import warnings -import threading from server.stdin_handler import StdinHandler from api.app import run_http_server @@ -39,38 +38,24 @@ # Suppress pydantic warnings warnings.filterwarnings('ignore', category=UserWarning, module='pydantic') -# Wrap stderr to filter out JSONRPC parsing errors from MCP library -# These occur when Java MCP servers leak log messages to stdout -class FilteredStderr: - def __init__(self, original_stderr): - self.original_stderr = original_stderr - self.buffer = "" - self._lock = threading.Lock() - self._suppress_next_lines = 0 - def write(self, text): - with self._lock: - # Check if this is the start of a JSONRPC parsing error - if "Failed to parse JSONRPC message from server" in text: - self._suppress_next_lines = 15 # Suppress the next ~15 lines (traceback) - return +class _McpJsonRpcFilter(logging.Filter): + """Filter out noisy JSONRPC parsing errors from the mcp_use library. + + These occur when Java MCP servers leak log messages to stdout, which + the mcp_use library then fails to parse as JSON-RPC. They are harmless + and clutter the logs. + """ + def filter(self, record: logging.LogRecord) -> bool: + msg = record.getMessage() + if "Failed to parse JSONRPC message from server" in msg: + return False + return True - # If we're suppressing lines, decrement counter - if self._suppress_next_lines > 0: - self._suppress_next_lines -= 1 - return - # Otherwise, write to original stderr - self.original_stderr.write(text) - - def flush(self): - self.original_stderr.flush() - - def __getattr__(self, name): - return getattr(self.original_stderr, name) - -# Install the filtered stderr wrapper -sys.stderr = FilteredStderr(sys.stderr) +# Install the filter on the mcp_use loggers instead of wrapping stderr +for _logger_name in ("mcp_use", "mcp_use.client", "mcp_use.client.session"): + logging.getLogger(_logger_name).addFilter(_McpJsonRpcFilter()) def main(): diff --git a/python-ecosystem/mcp-client/model/__init__.py b/python-ecosystem/mcp-client/model/__init__.py index db99d4b2..5e9aa368 100644 --- a/python-ecosystem/mcp-client/model/__init__.py +++ b/python-ecosystem/mcp-client/model/__init__.py @@ -54,8 +54,6 @@ ReviewPlan, CrossFileIssue, DataFlowConcern, - ImmutabilityCheck, - DatabaseIntegrityCheck, CrossFileAnalysisResult, ) @@ -92,7 +90,5 @@ "ReviewPlan", "CrossFileIssue", "DataFlowConcern", - "ImmutabilityCheck", - "DatabaseIntegrityCheck", "CrossFileAnalysisResult", ] diff --git a/python-ecosystem/mcp-client/model/multi_stage.py b/python-ecosystem/mcp-client/model/multi_stage.py index 10cd1979..ff82ff77 100644 --- a/python-ecosystem/mcp-client/model/multi_stage.py +++ b/python-ecosystem/mcp-client/model/multi_stage.py @@ -78,25 +78,10 @@ class DataFlowConcern(BaseModel): severity: str -class ImmutabilityCheck(BaseModel): - """Stage 2: Immutability usage check.""" - rule: str - check_pass: bool = Field(alias="check_pass") - evidence: str - - -class DatabaseIntegrityCheck(BaseModel): - """Stage 2: DB 
integrity check.""" - concerns: List[str] - findings: List[str] - - class CrossFileAnalysisResult(BaseModel): """Stage 2 Output: Cross-file architectural analysis.""" pr_risk_level: str cross_file_issues: List[CrossFileIssue] data_flow_concerns: List[DataFlowConcern] = Field(default_factory=list) - immutability_enforcement: Optional[ImmutabilityCheck] = None - database_integrity: Optional[DatabaseIntegrityCheck] = None pr_recommendation: str confidence: str diff --git a/python-ecosystem/mcp-client/service/command/command_service.py b/python-ecosystem/mcp-client/service/command/command_service.py index d7c287da..a0e59991 100644 --- a/python-ecosystem/mcp-client/service/command/command_service.py +++ b/python-ecosystem/mcp-client/service/command/command_service.py @@ -57,11 +57,12 @@ async def process_summarize( return {"error": error_msg} try: - self._emit_event(event_callback, { - "type": "status", - "state": "started", - "message": "Starting PR summarization" - }) + async with asyncio.timeout(300): # 5-minute ceiling for commands + self._emit_event(event_callback, { + "type": "status", + "state": "started", + "message": "Starting PR summarization" + }) # Build configuration jvm_props = self._build_jvm_props_for_summarize(request) @@ -93,13 +94,20 @@ async def process_summarize( # TODO: Mermaid diagrams disabled for now - AI-generated Mermaid often has syntax errors # that fail to render on GitHub. Using ASCII diagrams until we add validation/fixing. # Original: supports_mermaid=request.supportsMermaid - result = await self._execute_summarize( - llm=llm, - client=client, - prompt=prompt, - supports_mermaid=False, # Mermaid disabled - always use ASCII - event_callback=event_callback - ) + try: + result = await self._execute_summarize( + llm=llm, + client=client, + prompt=prompt, + supports_mermaid=False, # Mermaid disabled - always use ASCII + event_callback=event_callback + ) + finally: + # Always close MCP sessions to release JVM subprocesses + try: + await client.close_all_sessions() + except Exception as close_err: + logger.warning(f"Error closing MCP sessions: {close_err}") self._emit_event(event_callback, { "type": "final", @@ -108,6 +116,12 @@ async def process_summarize( return result + except TimeoutError: + timeout_msg = "Summarize command timed out after 300 seconds" + logger.error(timeout_msg) + self._emit_event(event_callback, {"type": "error", "message": timeout_msg}) + return {"error": timeout_msg} + except Exception as e: logger.error(f"Summarize failed: {str(e)}", exc_info=True) sanitized_msg = create_user_friendly_error(e) @@ -142,11 +156,12 @@ async def process_ask( ) try: - self._emit_event(event_callback, { - "type": "status", - "state": "started", - "message": "Processing your question" - }) + async with asyncio.timeout(300): # 5-minute ceiling for commands + self._emit_event(event_callback, { + "type": "status", + "state": "started", + "message": "Processing your question" + }) # Build configuration with both VCS and Platform MCP servers jvm_props = self._build_jvm_props_for_ask(request) @@ -189,12 +204,19 @@ async def process_ask( }) # Execute with MCP agent - result = await self._execute_ask( - llm=llm, - client=client, - prompt=prompt, - event_callback=event_callback - ) + try: + result = await self._execute_ask( + llm=llm, + client=client, + prompt=prompt, + event_callback=event_callback + ) + finally: + # Always close MCP sessions to release JVM subprocesses + try: + await client.close_all_sessions() + except Exception as close_err: + logger.warning(f"Error 
closing MCP sessions: {close_err}") self._emit_event(event_callback, { "type": "final", @@ -203,6 +225,12 @@ async def process_ask( return result + except TimeoutError: + timeout_msg = "Ask command timed out after 300 seconds" + logger.error(timeout_msg) + self._emit_event(event_callback, {"type": "error", "message": timeout_msg}) + return {"error": timeout_msg} + except Exception as e: logger.error(f"Ask failed: {str(e)}", exc_info=True) sanitized_msg = create_user_friendly_error(e) diff --git a/python-ecosystem/mcp-client/service/rag/rag_client.py b/python-ecosystem/mcp-client/service/rag/rag_client.py index c8254fdd..0c7e4972 100644 --- a/python-ecosystem/mcp-client/service/rag/rag_client.py +++ b/python-ecosystem/mcp-client/service/rag/rag_client.py @@ -16,9 +16,6 @@ class RagClient: """Client for interacting with the RAG Pipeline API.""" - - # Shared HTTP client for connection pooling - _shared_client: Optional[httpx.AsyncClient] = None def __init__(self, base_url: Optional[str] = None, enabled: Optional[bool] = None): """ @@ -31,6 +28,8 @@ def __init__(self, base_url: Optional[str] = None, enabled: Optional[bool] = Non self.base_url = base_url or os.environ.get("RAG_API_URL", "http://rag-pipeline:8001") self.enabled = enabled if enabled is not None else os.environ.get("RAG_ENABLED", "false").lower() == "true" self.timeout = 30.0 + self._client: Optional[httpx.AsyncClient] = None + self._service_secret = os.environ.get("SERVICE_SECRET", "") if self.enabled: logger.info(f"RAG client initialized: {self.base_url}") @@ -38,19 +37,23 @@ def __init__(self, base_url: Optional[str] = None, enabled: Optional[bool] = Non logger.info("RAG client disabled") async def _get_client(self) -> httpx.AsyncClient: - """Get or create a shared HTTP client for connection pooling.""" - if RagClient._shared_client is None or RagClient._shared_client.is_closed: - RagClient._shared_client = httpx.AsyncClient( + """Get or create an HTTP client for connection pooling (instance-level).""" + if self._client is None or self._client.is_closed: + headers = {} + if self._service_secret: + headers["x-service-secret"] = self._service_secret + self._client = httpx.AsyncClient( timeout=self.timeout, - limits=httpx.Limits(max_connections=10, max_keepalive_connections=5) + limits=httpx.Limits(max_connections=10, max_keepalive_connections=5), + headers=headers, ) - return RagClient._shared_client + return self._client async def close(self): - """Close the shared HTTP client.""" - if RagClient._shared_client is not None and not RagClient._shared_client.is_closed: - await RagClient._shared_client.aclose() - RagClient._shared_client = None + """Close this instance's HTTP client.""" + if self._client is not None and not self._client.is_closed: + await self._client.aclose() + self._client = None async def get_pr_context( self, diff --git a/python-ecosystem/mcp-client/service/review/orchestrator/orchestrator.py b/python-ecosystem/mcp-client/service/review/orchestrator/orchestrator.py index f50fd9f6..f6f138bd 100644 --- a/python-ecosystem/mcp-client/service/review/orchestrator/orchestrator.py +++ b/python-ecosystem/mcp-client/service/review/orchestrator/orchestrator.py @@ -156,7 +156,9 @@ async def orchestrate_review( # === STAGE 0: Planning === _emit_status(self.event_callback, "stage_0_started", "Stage 0: Planning & Prioritization...") - review_plan = await execute_stage_0_planning(self.llm, request, is_incremental) + review_plan = await execute_stage_0_planning( + self.llm, request, is_incremental, processed_diff=processed_diff + 
) review_plan = self._ensure_all_files_planned(review_plan, request.changedFiles or []) _emit_progress(self.event_callback, 10, "Stage 0 Complete: Review plan created") @@ -188,14 +190,16 @@ async def orchestrate_review( # === STAGE 2: Cross-File Analysis === _emit_status(self.event_callback, "stage_2_started", "Stage 2: Analyzing cross-file patterns...") cross_file_results = await execute_stage_2_cross_file( - self.llm, request, file_issues, review_plan + self.llm, request, file_issues, review_plan, + processed_diff=processed_diff, ) _emit_progress(self.event_callback, 85, "Stage 2 Complete: Cross-file analysis finished") # === STAGE 3: Aggregation === _emit_status(self.event_callback, "stage_3_started", "Stage 3: Generating final report...") final_report = await execute_stage_3_aggregation( - self.llm, request, review_plan, file_issues, cross_file_results, is_incremental + self.llm, request, review_plan, file_issues, cross_file_results, + is_incremental, processed_diff=processed_diff ) _emit_progress(self.event_callback, 100, "Stage 3 Complete: Report generated") diff --git a/python-ecosystem/mcp-client/service/review/orchestrator/reconciliation.py b/python-ecosystem/mcp-client/service/review/orchestrator/reconciliation.py index 1a20008b..61d69b4d 100644 --- a/python-ecosystem/mcp-client/service/review/orchestrator/reconciliation.py +++ b/python-ecosystem/mcp-client/service/review/orchestrator/reconciliation.py @@ -10,7 +10,12 @@ def issue_matches_files(issue: Any, file_paths: List[str]) -> bool: - """Check if an issue is related to any of the given file paths.""" + """Check if an issue is related to any of the given file paths. + + Matches on exact path, or when one path is a suffix of the other + (handles relative vs absolute paths). Does NOT match on basename alone + to avoid false positives (e.g., two different utils.py files). + """ if hasattr(issue, 'model_dump'): issue_data = issue.model_dump() elif isinstance(issue, dict): @@ -19,13 +24,12 @@ def issue_matches_files(issue: Any, file_paths: List[str]) -> bool: issue_data = vars(issue) if hasattr(issue, '__dict__') else {} issue_file = issue_data.get('file', issue_data.get('filePath', '')) + if not issue_file: + return False for fp in file_paths: if issue_file == fp or issue_file.endswith('/' + fp) or fp.endswith('/' + issue_file): return True - # Also check basename match - if issue_file.split('/')[-1] == fp.split('/')[-1]: - return True return False diff --git a/python-ecosystem/mcp-client/service/review/orchestrator/stages.py b/python-ecosystem/mcp-client/service/review/orchestrator/stages.py index a2a98e3c..90497c27 100644 --- a/python-ecosystem/mcp-client/service/review/orchestrator/stages.py +++ b/python-ecosystem/mcp-client/service/review/orchestrator/stages.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Callable from model.dtos import ReviewRequestDto +from model.enrichment import PrEnrichmentDataDto from model.output_schemas import CodeReviewOutput, CodeReviewIssue from model.multi_stage import ( ReviewPlan, @@ -87,21 +88,32 @@ async def execute_branch_analysis( async def execute_stage_0_planning( llm, request: ReviewRequestDto, - is_incremental: bool = False + is_incremental: bool = False, + processed_diff: Optional[ProcessedDiff] = None, ) -> ReviewPlan: """ Stage 0: Analyze metadata and generate a review plan. Uses structured output for reliable JSON parsing. 
""" + # Build a path → DiffFile lookup for real line stats + diff_by_path: Dict[str, Any] = {} + if processed_diff: + for df in processed_diff.files: + diff_by_path[df.path] = df + # Also index by basename for fuzzy matching + if '/' in df.path: + diff_by_path[df.path.rsplit('/', 1)[-1]] = df + # Prepare context for prompt changed_files_summary = [] if request.changedFiles: for f in request.changedFiles: + df = diff_by_path.get(f) or diff_by_path.get(f.rsplit('/', 1)[-1] if '/' in f else f) changed_files_summary.append({ "path": f, - "type": "MODIFIED", - "lines_added": "?", - "lines_deleted": "?" + "type": df.change_type.value.upper() if df else "MODIFIED", + "lines_added": df.additions if df else "?", + "lines_deleted": df.deletions if df else "?", }) prompt = PromptBuilder.build_stage_0_planning_prompt( @@ -416,6 +428,54 @@ async def fetch_batch_rag_context( return None +def _filter_rag_chunks_for_batch( + rag_context: Dict[str, Any], + batch_file_paths: List[str], +) -> Optional[Dict[str, Any]]: + """ + Pre-filter global RAG context to keep only chunks whose source path + is related to the current batch. This avoids injecting the full + 15-chunk global set into every batch when the per-batch fetch fails. + """ + chunks = rag_context.get("relevant_code", []) or rag_context.get("chunks", []) + if not chunks: + return rag_context + + batch_basenames = {p.rsplit("/", 1)[-1] if "/" in p else p for p in batch_file_paths} + batch_dirs = set() + for p in batch_file_paths: + parts = p.rsplit("/", 1) + if len(parts) == 2: + batch_dirs.add(parts[0]) + + filtered = [] + for chunk in chunks: + meta = chunk.get("metadata", {}) + chunk_path = meta.get("path") or chunk.get("path") or chunk.get("file_path", "") + if not chunk_path: + filtered.append(chunk) # keep chunks without path info + continue + + chunk_basename = chunk_path.rsplit("/", 1)[-1] if "/" in chunk_path else chunk_path + chunk_dir = chunk_path.rsplit("/", 1)[0] if "/" in chunk_path else "" + + # Keep if same file, same directory, or high-score (>= 0.8) regardless + score = chunk.get("score", chunk.get("relevance_score", 0)) + if (chunk_basename in batch_basenames + or chunk_dir in batch_dirs + or any(chunk_path.endswith(bp) or bp.endswith(chunk_path) for bp in batch_file_paths) + or score >= 0.8): + filtered.append(chunk) + + if not filtered: + # Don't return empty — keep original as fallback + return rag_context + + result = dict(rag_context) + result["relevant_code"] = filtered + return result + + async def review_file_batch( llm, request: ReviewRequestDto, @@ -488,9 +548,18 @@ async def review_file_batch( pr_changed_files=request.changedFiles ) elif fallback_rag_context: - logger.info(f"Using fallback RAG context for batch: {batch_file_paths}") + # Filter the global RAG context to chunks relevant to this batch. + # Without filtering, every batch receives the same 15-chunk global set + # (wasting tokens on context unrelated to the batch's files). 
+ filtered_fallback = _filter_rag_chunks_for_batch( + fallback_rag_context, batch_file_paths + ) + logger.info( + f"Using filtered fallback RAG context for batch: {batch_file_paths} " + f"({len((filtered_fallback or {}).get('relevant_code', []))} chunks)" + ) rag_context_text = format_rag_context( - fallback_rag_context, + filtered_fallback, set(batch_file_paths), pr_changed_files=request.changedFiles ) @@ -543,31 +612,216 @@ async def review_file_batch( all_batch_issues.extend(review.issues) return all_batch_issues except Exception as parse_err: - logger.error(f"Batch review failed: {parse_err}") + logger.error( + f"Batch review double parse failure for {batch_file_paths}: {parse_err}. " + "Zero issues will be reported for this batch — results may be incomplete." + ) return [] - - return [] + + +# --------------------------------------------------------------------------- +# Stage 2 / Stage 3 context helpers +# --------------------------------------------------------------------------- + +# Simple substrings that mark a path as a migration / schema file. +# Checked with case-insensitive containment — no compiled regexes needed. +_MIGRATION_PATH_MARKERS = ( + '/db/migrate/', '/migrations/', '/migration/', + '/flyway/', '/liquibase/', '/alembic/', '/changeset/', +) + +# Fields to strip from CodeReviewIssue dicts before sending to Stage 2. +# Stage 2 only needs location + severity + reason to detect cross-file patterns. +_STAGE_2_STRIP_FIELDS = { + 'suggestedFixDiff', 'suggestedFixDescription', 'codeSnippet', + 'resolutionExplanation', 'resolvedInCommit', 'visibility', +} + + +def _build_architecture_context( + enrichment: Optional[PrEnrichmentDataDto], + changed_files: Optional[List[str]], +) -> str: + """ + Synthesise an architecture-reference section from the enrichment data + that pipeline-agent already computed (class hierarchy, inter-file + relationships, key imports). Zero extra LLM / RAG cost. + """ + sections: List[str] = [] + + if enrichment and enrichment.relationships: + rel_lines = [] + for r in enrichment.relationships: + rel_lines.append( + f" {r.sourceFile} --[{r.relationshipType.value}]--> {r.targetFile}" + + (f" (matched on: {r.matchedOn})" if r.matchedOn else "") + ) + if rel_lines: + sections.append( + "### Inter-file relationships (from dependency analysis)\n" + + "\n".join(rel_lines) + ) + + if enrichment and enrichment.fileMetadata: + hierarchy_lines = [] + for meta in enrichment.fileMetadata: + parts = [] + if meta.extendsClasses: + parts.append(f"extends {', '.join(meta.extendsClasses)}") + if meta.implementsInterfaces: + parts.append(f"implements {', '.join(meta.implementsInterfaces)}") + if parts: + hierarchy_lines.append(f" {meta.path}: {'; '.join(parts)}") + if hierarchy_lines: + sections.append( + "### Class hierarchy in changed files\n" + + "\n".join(hierarchy_lines) + ) + + # Summarise cross-file imports between changed files + if changed_files: + changed_set = set(changed_files or []) + import_lines = [] + for meta in enrichment.fileMetadata: + cross_imports = [ + imp for imp in meta.imports + if any(imp in cf or cf.endswith(imp) for cf in changed_set) + ] + if cross_imports: + import_lines.append( + f" {meta.path} imports: {', '.join(cross_imports[:10])}" + ) + if import_lines: + sections.append( + "### Cross-file imports among changed files\n" + + "\n".join(import_lines) + ) + + if not sections: + return "No architecture context available (enrichment data not provided)." 
+ + return "\n\n".join(sections) + + +def _detect_migration_paths( + processed_diff: Optional[ProcessedDiff], +) -> str: + """ + Return a short list of migration file paths found in the diff. + + Stage 1 already reviews each migration file in full detail. + Stage 2 only needs to *know which files are migrations* so it can + reason about cross-file DB concerns (e.g. code referencing a column + that a migration drops). No raw SQL is injected. + """ + if not processed_diff: + return "No migration scripts detected." + + migration_files: List[str] = [] + for f in processed_diff.files: + path_lower = f.path.lower() + if path_lower.endswith('.sql') or any(m in path_lower for m in _MIGRATION_PATH_MARKERS): + migration_files.append(f.path) + + if not migration_files: + return "No migration scripts detected in this PR." + + listing = "\n".join(f"- {p}" for p in migration_files[:15]) # cap at 15 + return f"Migration files in this PR ({len(migration_files)}):\n{listing}" + + +def _slim_issues_for_stage_2( + issues: List[CodeReviewIssue], +) -> str: + """ + Serialize Stage 1 issues for Stage 2, stripping bulky fields that + Stage 2 does not need (fix diffs, code snippets, resolution details). + + Stage 2 detects *cross-file patterns* — it only needs: + file, line, severity, category, title/reason. + """ + slim = [] + for issue in issues: + d = issue.model_dump() + for key in _STAGE_2_STRIP_FIELDS: + d.pop(key, None) + slim.append(d) + return json.dumps(slim, indent=2) + + +def _summarize_issues_for_stage_3( + issues: List[CodeReviewIssue], +) -> str: + """ + Build a compact summary of Stage 1 issues for Stage 3 (executive report). + + The full issue list is posted as a separate comment, so Stage 3 only needs + aggregate counts and a short list of the most critical findings. + """ + if not issues: + return "No issues found in Stage 1." + + # --- Counts by severity --- + severity_counts: Dict[str, int] = {} + category_counts: Dict[str, int] = {} + for issue in issues: + sev = issue.severity.upper() + severity_counts[sev] = severity_counts.get(sev, 0) + 1 + cat = issue.category.upper() + category_counts[cat] = category_counts.get(cat, 0) + 1 + + lines = [ + f"Total issues: {len(issues)}", + "By severity: " + ", ".join(f"{k}: {v}" for k, v in sorted(severity_counts.items())), + "By category: " + ", ".join(f"{k}: {v}" for k, v in sorted(category_counts.items())), + ] + + # --- Top critical/high issues (title + file only) --- + priority_order = {'CRITICAL': 0, 'HIGH': 1, 'MEDIUM': 2, 'LOW': 3, 'INFO': 4} + ranked = sorted(issues, key=lambda i: priority_order.get(i.severity.upper(), 5)) + top_n = ranked[:10] + if top_n: + lines.append("\nTop findings:") + for i, issue in enumerate(top_n, 1): + lines.append(f" {i}. [{issue.severity}] {issue.file}: {issue.reason[:120]}") + + return "\n".join(lines) async def execute_stage_2_cross_file( llm, request: ReviewRequestDto, stage_1_issues: List[CodeReviewIssue], - plan: ReviewPlan + plan: ReviewPlan, + processed_diff: Optional[ProcessedDiff] = None, ) -> CrossFileAnalysisResult: """ Stage 2: Cross-file analysis. + + Uses enrichment data (relationships, class hierarchy) and diff-detected + migrations to provide the LLM with real architecture context instead of + placeholders. 
""" - # Serialize Stage 1 findings - issues_json = json.dumps([i.model_dump() for i in stage_1_issues], indent=2) - + # Slim Stage 1 findings (strip fix diffs, code snippets — Stage 2 only + # needs location + severity + reason for cross-file pattern detection) + issues_json = _slim_issues_for_stage_2(stage_1_issues) + + # Build architecture reference from enrichment data (zero-cost) + architecture_context = _build_architecture_context( + enrichment=request.enrichmentData, + changed_files=request.changedFiles, + ) + + # List migration file paths (no raw SQL — Stage 1 already reviewed them) + migrations = _detect_migration_paths(processed_diff) + prompt = PromptBuilder.build_stage_2_cross_file_prompt( repo_slug=request.projectVcsRepoSlug, pr_title=request.prTitle or "", commit_hash=request.commitHash or "HEAD", stage_1_findings_json=issues_json, - architecture_context="(Architecture context from MCP or knowledge base)", - migrations="(Migration scripts found in PR)", + architecture_context=architecture_context, + migrations=migrations, cross_file_concerns=plan.cross_file_concerns ) @@ -597,13 +851,15 @@ async def execute_stage_3_aggregation( plan: ReviewPlan, stage_1_issues: List[CodeReviewIssue], stage_2_results: CrossFileAnalysisResult, - is_incremental: bool = False + is_incremental: bool = False, + processed_diff: Optional[ProcessedDiff] = None ) -> str: """ Stage 3: Generate Markdown report. In incremental mode, includes summary of resolved vs new issues. """ - stage_1_json = json.dumps([i.model_dump() for i in stage_1_issues], indent=2) + # Compact summary — the full issue list is posted as a separate comment + stage_1_json = _summarize_issues_for_stage_3(stage_1_issues) stage_2_json = stage_2_results.model_dump_json(indent=2) plan_json = plan.model_dump_json(indent=2) @@ -621,14 +877,18 @@ async def execute_stage_3_aggregation( - Total issues after reconciliation: {len(stage_1_issues)} """ + # Use real diff stats when available, fall back to 0 + additions = processed_diff.total_additions if processed_diff else 0 + deletions = processed_diff.total_deletions if processed_diff else 0 + prompt = PromptBuilder.build_stage_3_aggregation_prompt( repo_slug=request.projectVcsRepoSlug, pr_id=str(request.pullRequestId), author="Unknown", pr_title=request.prTitle or "", total_files=len(request.changedFiles or []), - additions=0, # Need accurate stats - deletions=0, + additions=additions, + deletions=deletions, stage_0_plan=plan_json, stage_1_issues_json=stage_1_json, stage_2_findings_json=stage_2_json, diff --git a/python-ecosystem/mcp-client/service/review/review_service.py b/python-ecosystem/mcp-client/service/review/review_service.py index 24fe62d3..1466f4c8 100644 --- a/python-ecosystem/mcp-client/service/review/review_service.py +++ b/python-ecosystem/mcp-client/service/review/review_service.py @@ -30,6 +30,9 @@ class ReviewService: # Threshold for using LLM reranking (number of changed files) LLM_RERANK_FILE_THRESHOLD = 20 + # Maximum concurrent reviews (each spawns a JVM subprocess + LLM calls) + MAX_CONCURRENT_REVIEWS = int(os.environ.get("MAX_CONCURRENT_REVIEWS", "4")) + def __init__(self): load_dotenv() self.default_jar_path = os.environ.get( @@ -39,7 +42,7 @@ def __init__(self): ) self.rag_client = RagClient() self.rag_cache = get_rag_cache() - self.llm_reranker = None # Initialized lazily with LLM + self._review_semaphore = asyncio.Semaphore(self.MAX_CONCURRENT_REVIEWS) async def process_review_request( self, @@ -58,11 +61,12 @@ async def process_review_request( Returns: Dict with 
"result" key containing the analysis result or error """ - return await self._process_review( - request=request, - repo_path=None, - event_callback=event_callback - ) + async with self._review_semaphore: + return await self._process_review( + request=request, + repo_path=None, + event_callback=event_callback + ) async def _process_review( self, @@ -99,7 +103,8 @@ async def _process_review( has_raw_diff = bool(request.rawDiff) try: - context = "with pre-fetched diff" if has_raw_diff else "fetching diff via MCP" + async with asyncio.timeout(600): # 10-minute hard ceiling per review + context = "with pre-fetched diff" if has_raw_diff else "fetching diff via MCP" self._emit_event(event_callback, { "type": "status", "state": "started", @@ -120,9 +125,12 @@ async def _process_review( # Create LLM instance llm = self._create_llm(request) + + # Create a per-request reranker (not shared across concurrent requests) + llm_reranker = LLMReranker(llm_client=llm) # Fetch RAG context if enabled - rag_context = await self._fetch_rag_context(request, event_callback) + rag_context = await self._fetch_rag_context(request, event_callback, llm_reranker=llm_reranker) # Build processed_diff if rawDiff is available to optimize Stage 1 processed_diff = None @@ -164,22 +172,29 @@ async def _process_review( event_callback=event_callback ) - # Check for Branch Analysis / Reconciliation mode - if request.analysisType == "BRANCH_ANALYSIS": - logger.info("Executing Branch Analysis & Reconciliation mode") - # Build specific prompt for branch analysis - pr_metadata = self._build_pr_metadata(request) - prompt = PromptBuilder.build_branch_review_prompt_with_branch_issues_data(pr_metadata) - - result = await orchestrator.execute_branch_analysis(prompt) - else: - # Execute review with Multi-Stage Orchestrator - # Standard PR Review - result = await orchestrator.orchestrate_review( - request=request, - rag_context=rag_context, - processed_diff=processed_diff - ) + try: + # Check for Branch Analysis / Reconciliation mode + if request.analysisType == "BRANCH_ANALYSIS": + logger.info("Executing Branch Analysis & Reconciliation mode") + # Build specific prompt for branch analysis + pr_metadata = self._build_pr_metadata(request) + prompt = PromptBuilder.build_branch_review_prompt_with_branch_issues_data(pr_metadata) + + result = await orchestrator.execute_branch_analysis(prompt) + else: + # Execute review with Multi-Stage Orchestrator + # Standard PR Review + result = await orchestrator.orchestrate_review( + request=request, + rag_context=rag_context, + processed_diff=processed_diff + ) + finally: + # Always close MCP sessions to release JVM subprocesses + try: + await client.close_all_sessions() + except Exception as close_err: + logger.warning(f"Error closing MCP sessions: {close_err}") # Post-process issues to fix line numbers and merge duplicates @@ -221,6 +236,15 @@ async def _process_review( return {"result": result} + except TimeoutError: + timeout_msg = "Review timed out after 600 seconds" + logger.error(timeout_msg) + self._emit_event(event_callback, {"type": "error", "message": timeout_msg}) + error_response = ResponseParser.create_error_response( + "Review timed out", timeout_msg + ) + return {"result": error_response} + except Exception as e: # Log full error for debugging, but sanitize for user display logger.error(f"Review processing failed: {str(e)}", exc_info=True) @@ -256,7 +280,8 @@ def _build_jvm_props( async def _fetch_rag_context( self, request: ReviewRequestDto, - event_callback: Optional[Callable[[Dict], 
None]] + event_callback: Optional[Callable[[Dict], None]], + llm_reranker: Optional[LLMReranker] = None ) -> Optional[Dict[str, Any]]: """ Fetch relevant context from RAG pipeline. @@ -333,8 +358,8 @@ async def _fetch_rag_context( relevant_code = context.get("relevant_code", []) # Apply LLM reranking for large PRs - if len(changed_files) >= self.LLM_RERANK_FILE_THRESHOLD and self.llm_reranker: - reranked, rerank_result = await self.llm_reranker.rerank( + if len(changed_files) >= self.LLM_RERANK_FILE_THRESHOLD and llm_reranker: + reranked, rerank_result = await llm_reranker.rerank( relevant_code, pr_title=request.prTitle, pr_description=request.prDescription, @@ -391,7 +416,7 @@ def _create_mcp_client(self, config: Dict[str, Any]) -> MCPClient: raise Exception(f"Failed to construct MCPClient: {str(e)}") def _create_llm(self, request: ReviewRequestDto): - """Create LLM instance from request parameters and initialize reranker.""" + """Create LLM instance from request parameters.""" try: # Log the model being used for this request logger.info(f"Creating LLM for project {request.projectId}: provider={request.aiProvider}, model={request.aiModel}") @@ -402,9 +427,6 @@ def _create_llm(self, request: ReviewRequestDto): request.aiApiKey ) - # Initialize LLM reranker for large PRs - self.llm_reranker = LLMReranker(llm_client=llm) - return llm except Exception as e: raise Exception(f"Failed to create LLM instance: {str(e)}") diff --git a/python-ecosystem/mcp-client/utils/context_builder.py b/python-ecosystem/mcp-client/utils/context_builder.py index e414d35e..9ebd1e1d 100644 --- a/python-ecosystem/mcp-client/utils/context_builder.py +++ b/python-ecosystem/mcp-client/utils/context_builder.py @@ -1,37 +1,18 @@ """ -Context builder for structured code review with Lost-in-the-Middle protection. -Implements priority-based context assembly with token budget management. +Context builder utilities for RAG metrics and caching. + +Note: The original ContextBuilder, ContextBudget, and MODEL_CONTEXT_LIMITS +were removed — they are superseded by the multi-stage orchestrator pipeline. """ import logging import os -import re import hashlib -from typing import Dict, List, Any, Optional, Tuple +from typing import Dict, List, Any, Optional from dataclasses import dataclass, field -from enum import Enum from datetime import datetime -from .file_classifier import FileClassifier, FilePriority, ClassifiedFile - logger = logging.getLogger(__name__) -# === Environment-based configuration === -# Context Budget percentages (must sum to 1.0) -CONTEXT_BUDGET_HIGH_PRIORITY_PCT = float(os.environ.get("CONTEXT_BUDGET_HIGH_PRIORITY_PCT", "0.30")) -CONTEXT_BUDGET_MEDIUM_PRIORITY_PCT = float(os.environ.get("CONTEXT_BUDGET_MEDIUM_PRIORITY_PCT", "0.40")) -CONTEXT_BUDGET_LOW_PRIORITY_PCT = float(os.environ.get("CONTEXT_BUDGET_LOW_PRIORITY_PCT", "0.20")) -CONTEXT_BUDGET_RAG_PCT = float(os.environ.get("CONTEXT_BUDGET_RAG_PCT", "0.10")) - -# Validate budget percentages sum to 1.0 -_budget_sum = CONTEXT_BUDGET_HIGH_PRIORITY_PCT + CONTEXT_BUDGET_MEDIUM_PRIORITY_PCT + CONTEXT_BUDGET_LOW_PRIORITY_PCT + CONTEXT_BUDGET_RAG_PCT -if abs(_budget_sum - 1.0) > 0.01: - logger.warning(f"Context budget percentages sum to {_budget_sum}, expected 1.0. 
Normalizing...") - _factor = 1.0 / _budget_sum - CONTEXT_BUDGET_HIGH_PRIORITY_PCT *= _factor - CONTEXT_BUDGET_MEDIUM_PRIORITY_PCT *= _factor - CONTEXT_BUDGET_LOW_PRIORITY_PCT *= _factor - CONTEXT_BUDGET_RAG_PCT *= _factor - # RAG Cache settings RAG_CACHE_TTL_SECONDS = int(os.environ.get("RAG_CACHE_TTL_SECONDS", "300")) RAG_CACHE_MAX_SIZE = int(os.environ.get("RAG_CACHE_MAX_SIZE", "100")) @@ -41,583 +22,6 @@ RAG_DEFAULT_TOP_K = int(os.environ.get("RAG_DEFAULT_TOP_K", "15")) -# Model context limits (approximate usable tokens after system prompt) -MODEL_CONTEXT_LIMITS = { - # OpenAI models - "gpt-4-turbo": 90000, - "gpt-4-turbo-preview": 90000, - "gpt-4-0125-preview": 90000, - "gpt-4-1106-preview": 90000, - "gpt-4": 6000, - "gpt-4-32k": 24000, - "gpt-3.5-turbo": 12000, - "gpt-3.5-turbo-16k": 12000, - # Anthropic models - "claude-3-opus": 140000, - "claude-3-opus-20240229": 140000, - "claude-3-sonnet": 140000, - "claude-3-sonnet-20240229": 140000, - "claude-3-haiku": 140000, - "claude-3-haiku-20240307": 140000, - "claude-3-5-sonnet": 140000, - "claude-3-5-sonnet-20240620": 140000, - "claude-3-5-sonnet-20241022": 140000, - # OpenRouter models (common ones) - "anthropic/claude-3-opus": 140000, - "anthropic/claude-3-sonnet": 140000, - "anthropic/claude-3-haiku": 140000, - "openai/gpt-4-turbo": 90000, - "openai/gpt-4": 6000, - "google/gemini-pro": 24000, - "google/gemini-pro-1.5": 700000, - "meta-llama/llama-3-70b-instruct": 6000, - "mistralai/mistral-large": 24000, - "deepseek/deepseek-chat": 48000, - "deepseek/deepseek-coder": 48000, - - "openai/gpt-5-mini": 400000, - "openai/gpt-5.1-codex-mini": 400000, - "openai/gpt-5.2": 512000, - "openai/gpt-5.1": 400000, - "openai/gpt-5.1-thinking": 196000, - "openai/gpt-5": 128000, - "openai/o3-high": 200000, - "openai/o4-mini": 128000, - - "anthropic/claude-4.5-opus": 500000, - "anthropic/claude-4.5-sonnet": 200000, - "anthropic/claude-4.5-haiku": 200000, - "anthropic/claude-3.7-sonnet": 200000, - - "google/gemini-3-pro": 1000000, - "google/gemini-3-flash": 1000000, - "google/gemini-2.5-pro": 2000000, - "google/gemini-2.5-flash": 1000000, - "google/gemini-3-flash-preview": 1000000, - "google/gemini-3-pro-preview": 1000000, - - "llama-4-405b": 128000, - "llama-4-70b": 128000, - "llama-4-scout": 1000000, - - "deepseek-v3.1": 160000, - "deepseek-v3": 128000, - "mistral-large-2025": 128000, - - "gpt-4o": 128000, - "gpt-4-turbo": 128000, - "claude-3-5-sonnet": 200000, - - "x-ai/grok-4.1-fast": 2000000, - - "default": 200000 -} - - -@dataclass -class ContextBudget: - """Token budget allocation for different context sections.""" - total_tokens: int = 45000 # Default max context budget - - # Budget distribution (percentages) - loaded from environment - high_priority_pct: float = field(default_factory=lambda: CONTEXT_BUDGET_HIGH_PRIORITY_PCT) - medium_priority_pct: float = field(default_factory=lambda: CONTEXT_BUDGET_MEDIUM_PRIORITY_PCT) - low_priority_pct: float = field(default_factory=lambda: CONTEXT_BUDGET_LOW_PRIORITY_PCT) - rag_context_pct: float = field(default_factory=lambda: CONTEXT_BUDGET_RAG_PCT) - - @property - def high_priority_tokens(self) -> int: - return int(self.total_tokens * self.high_priority_pct) - - @property - def medium_priority_tokens(self) -> int: - return int(self.total_tokens * self.medium_priority_pct) - - @property - def low_priority_tokens(self) -> int: - return int(self.total_tokens * self.low_priority_pct) - - @property - def rag_tokens(self) -> int: - return int(self.total_tokens * self.rag_context_pct) - - @classmethod - def 
for_model(cls, model_name: str) -> "ContextBudget": - """ - Create a ContextBudget optimized for a specific model. - - Args: - model_name: The model identifier (e.g., "gpt-4-turbo", "claude-3-opus") - - Returns: - ContextBudget with appropriate token limits - """ - # Normalize model name for lookup - model_lower = model_name.lower() - - # Try exact match first - limit = MODEL_CONTEXT_LIMITS.get(model_lower) - - # Try partial matches - if limit is None: - for key, value in MODEL_CONTEXT_LIMITS.items(): - if key in model_lower or model_lower in key: - limit = value - break - - # Use default if no match - if limit is None: - limit = MODEL_CONTEXT_LIMITS["default"] - logger.warning(f"Unknown model '{model_name}', using default context budget: {limit}") - - logger.info(f"Context budget for model '{model_name}': {limit} tokens") - return cls(total_tokens=limit) - - -@dataclass -class ContextSection: - """A section of structured context.""" - priority: str - title: str - description: str - content: str - file_count: int - token_estimate: int - files_included: List[str] = field(default_factory=list) - files_truncated: List[str] = field(default_factory=list) - - -@dataclass -class StructuredContext: - """Complete structured context with all sections.""" - sections: List[ContextSection] - total_files: int - total_tokens_estimated: int - files_analyzed: List[str] - files_skipped: List[str] - metadata: Dict[str, Any] = field(default_factory=dict) - - def to_prompt_string(self) -> str: - """Convert to formatted string for prompt injection.""" - parts = [] - - # Add metadata header - parts.append("=" * 60) - parts.append("STRUCTURED CONTEXT WITH PRIORITY MARKERS") - parts.append(f"Total files: {self.total_files} | Estimated tokens: {self.total_tokens_estimated}") - parts.append("=" * 60) - parts.append("") - - for section in self.sections: - if section.content: - parts.append(f"=== {section.priority} PRIORITY: {section.title} ===") - parts.append(f"({section.description})") - parts.append(f"Files: {section.file_count} | ~{section.token_estimate} tokens") - parts.append("") - parts.append(section.content) - parts.append("") - parts.append(f"=== END {section.priority} PRIORITY ===") - parts.append("") - - return "\n".join(parts) - - -class ContextBuilder: - """ - Builds structured context for code review with Lost-in-the-Middle protection. - - Key features: - 1. Priority-based file ordering (HIGH -> MEDIUM -> LOW -> RAG) - 2. Token budget management per section - 3. Explicit section markers for LLM attention - 4. Smart truncation when budget exceeded - """ - - # Approximate tokens per character (conservative estimate) - TOKENS_PER_CHAR = 0.25 - - def __init__(self, budget: Optional[ContextBudget] = None): - self.budget = budget or ContextBudget() - self.file_classifier = FileClassifier() - - def build_structured_context( - self, - pr_metadata: Dict[str, Any], - diff_content: Dict[str, str], # file_path -> diff/content - rag_context: Optional[Dict[str, Any]] = None, - file_paths: Optional[List[str]] = None - ) -> StructuredContext: - """ - Build structured context from PR data and RAG results. - - Args: - pr_metadata: PR metadata (title, description, etc.) 
- diff_content: Dict mapping file paths to their diff content - rag_context: Optional RAG query results - file_paths: Optional list of changed file paths - - Returns: - StructuredContext with prioritized sections - """ - file_paths = file_paths or list(diff_content.keys()) - - # Step 1: Classify files by priority - classified = FileClassifier.classify_files(file_paths) - stats = FileClassifier.get_priority_stats(classified) - logger.info(f"File classification: {stats}") - - sections = [] - files_analyzed = [] - files_skipped = [] - total_tokens = 0 - - # Step 2: Build HIGH priority section - high_section = self._build_priority_section( - priority=FilePriority.HIGH, - files=classified[FilePriority.HIGH], - diff_content=diff_content, - token_budget=self.budget.high_priority_tokens, - title="Core Business Logic", - description="Analyze FIRST - security, auth, core services" - ) - sections.append(high_section) - files_analyzed.extend(high_section.files_included) - files_skipped.extend(high_section.files_truncated) - total_tokens += high_section.token_estimate - - # Step 3: Build MEDIUM priority section - medium_section = self._build_priority_section( - priority=FilePriority.MEDIUM, - files=classified[FilePriority.MEDIUM], - diff_content=diff_content, - token_budget=self.budget.medium_priority_tokens, - title="Dependencies & Shared Utils", - description="Models, DTOs, utilities, components" - ) - sections.append(medium_section) - files_analyzed.extend(medium_section.files_included) - files_skipped.extend(medium_section.files_truncated) - total_tokens += medium_section.token_estimate - - # Step 4: Build LOW priority section - low_section = self._build_priority_section( - priority=FilePriority.LOW, - files=classified[FilePriority.LOW], - diff_content=diff_content, - token_budget=self.budget.low_priority_tokens, - title="Tests & Config", - description="Test files, configurations (quick scan)" - ) - sections.append(low_section) - files_analyzed.extend(low_section.files_included) - files_skipped.extend(low_section.files_truncated) - total_tokens += low_section.token_estimate - - # Step 5: Build RAG context section - rag_section = self._build_rag_section( - rag_context=rag_context, - token_budget=self.budget.rag_tokens, - already_included=set(files_analyzed) - ) - sections.append(rag_section) - total_tokens += rag_section.token_estimate - - # Add skipped files from classification - for f in classified[FilePriority.SKIP]: - files_skipped.append(f.path) - - return StructuredContext( - sections=sections, - total_files=len(file_paths), - total_tokens_estimated=total_tokens, - files_analyzed=files_analyzed, - files_skipped=files_skipped, - metadata={ - "classification_stats": stats, - "budget": { - "total": self.budget.total_tokens, - "high": self.budget.high_priority_tokens, - "medium": self.budget.medium_priority_tokens, - "low": self.budget.low_priority_tokens, - "rag": self.budget.rag_tokens - } - } - ) - - def _build_priority_section( - self, - priority: FilePriority, - files: List[ClassifiedFile], - diff_content: Dict[str, str], - token_budget: int, - title: str, - description: str - ) -> ContextSection: - """Build a single priority section with budget management.""" - content_parts = [] - files_included = [] - files_truncated = [] - current_tokens = 0 - - for classified_file in files: - path = classified_file.path - if path not in diff_content: - continue - - file_content = diff_content[path] - file_tokens = self._estimate_tokens(file_content) - - # Check if adding this file would exceed budget - 
if current_tokens + file_tokens > token_budget: - # Try to include truncated version if significant budget remains - remaining_budget = token_budget - current_tokens - if remaining_budget > 500: # Only truncate if meaningful space left - truncated_content = self._truncate_content(file_content, remaining_budget) - content_parts.append(f"### {path} (TRUNCATED)") - content_parts.append(truncated_content) - content_parts.append("") - files_included.append(path) - current_tokens += self._estimate_tokens(truncated_content) - else: - files_truncated.append(path) - continue - - content_parts.append(f"### {path}") - content_parts.append(f"Category: {classified_file.category} | Importance: {classified_file.estimated_importance:.2f}") - content_parts.append("```") - content_parts.append(file_content) - content_parts.append("```") - content_parts.append("") - - files_included.append(path) - current_tokens += file_tokens - - return ContextSection( - priority=priority.value, - title=title, - description=description, - content="\n".join(content_parts), - file_count=len(files_included), - token_estimate=current_tokens, - files_included=files_included, - files_truncated=files_truncated - ) - - def _build_rag_section( - self, - rag_context: Optional[Dict[str, Any]], - token_budget: int, - already_included: set - ) -> ContextSection: - """Build RAG context section with deduplication.""" - if not rag_context or not rag_context.get("relevant_code"): - return ContextSection( - priority="RAG", - title="Additional Context from Codebase", - description="No RAG context available", - content="", - file_count=0, - token_estimate=0 - ) - - content_parts = [] - files_included = [] - current_tokens = 0 - max_rag_chunks = 5 # Limit RAG chunks - - relevant_code = rag_context.get("relevant_code", []) - chunk_count = 0 - - for chunk in relevant_code: - if chunk_count >= max_rag_chunks: - break - - # Skip if file already included in priority sections - chunk_path = chunk.get("metadata", {}).get("path", "unknown") - if chunk_path in already_included: - continue - - chunk_text = chunk.get("text", "") - chunk_score = chunk.get("score", 0) - chunk_tokens = self._estimate_tokens(chunk_text) - - if current_tokens + chunk_tokens > token_budget: - break - - content_parts.append(f"### RAG Context {chunk_count + 1}: {chunk_path}") - content_parts.append(f"Relevance score: {chunk_score:.3f}") - content_parts.append("```") - content_parts.append(chunk_text) - content_parts.append("```") - content_parts.append("") - - files_included.append(chunk_path) - current_tokens += chunk_tokens - chunk_count += 1 - - return ContextSection( - priority="RAG", - title="Additional Context from Codebase", - description="Semantically relevant code from repository (max 5 chunks)", - content="\n".join(content_parts), - file_count=len(files_included), - token_estimate=current_tokens, - files_included=files_included - ) - - def _estimate_tokens(self, text: str) -> int: - """Estimate token count for text.""" - return int(len(text) * self.TOKENS_PER_CHAR) - - def _truncate_content(self, content: str, target_tokens: int) -> str: - """Truncate content to fit within token budget.""" - target_chars = int(target_tokens / self.TOKENS_PER_CHAR) - - if len(content) <= target_chars: - return content - - # Try to truncate at a logical boundary - truncated = content[:target_chars] - - # Find last newline to avoid cutting mid-line - last_newline = truncated.rfind('\n') - if last_newline > target_chars * 0.8: # Only use if we don't lose too much - truncated = 
truncated[:last_newline] - - return truncated + "\n... [TRUNCATED - remaining content omitted for token budget]" - - -class RagReranker: - """ - Reranks RAG results for better relevance using multiple strategies. - """ - - @staticmethod - def rerank_by_file_priority( - rag_results: List[Dict[str, Any]], - classified_files: Dict[FilePriority, List[ClassifiedFile]], - boost_factor: float = 1.5 - ) -> List[Dict[str, Any]]: - """ - Boost RAG results that relate to high-priority files. - - Args: - rag_results: Original RAG results - classified_files: Files classified by priority - boost_factor: Multiplier for scores of related high-priority results - - Returns: - Reranked results with adjusted scores - """ - # Build set of high-priority file paths - high_priority_paths = { - f.path for f in classified_files.get(FilePriority.HIGH, []) - } - - # Get directory patterns from high-priority files - high_priority_dirs = set() - for path in high_priority_paths: - parts = path.split('/') - for i in range(1, len(parts)): - high_priority_dirs.add('/'.join(parts[:i])) - - reranked = [] - for result in rag_results: - result_copy = result.copy() - result_path = result.get("metadata", {}).get("path", "") - - # Boost if directly matches high-priority file - if result_path in high_priority_paths: - result_copy["score"] = result.get("score", 0) * boost_factor - result_copy["_boost_reason"] = "direct_high_priority_match" - # Smaller boost if in high-priority directory - elif any(result_path.startswith(d) for d in high_priority_dirs): - result_copy["score"] = result.get("score", 0) * (boost_factor * 0.7) - result_copy["_boost_reason"] = "high_priority_directory" - - reranked.append(result_copy) - - # Sort by adjusted score - reranked.sort(key=lambda x: x.get("score", 0), reverse=True) - return reranked - - @staticmethod - def filter_by_relevance_threshold( - results: List[Dict[str, Any]], - min_score: float = None, - min_results: int = 3 - ) -> List[Dict[str, Any]]: - """ - Filter results by minimum relevance score. - Always returns at least min_results if available. - Default min_score from RAG_MIN_RELEVANCE_SCORE env var. - """ - if min_score is None: - min_score = RAG_MIN_RELEVANCE_SCORE - filtered = [r for r in results if r.get("score", 0) >= min_score] - - # Ensure minimum results - if len(filtered) < min_results and len(results) >= min_results: - # Add top results regardless of score - for result in results: - if result not in filtered: - filtered.append(result) - if len(filtered) >= min_results: - break - - return filtered - - @staticmethod - def deduplicate_by_content( - results: List[Dict[str, Any]], - similarity_threshold: float = 0.8 - ) -> List[Dict[str, Any]]: - """ - Remove near-duplicate results based on content similarity. - Uses simple text overlap for efficiency. 
- """ - if not results: - return [] - - unique_results = [results[0]] - - for result in results[1:]: - is_duplicate = False - result_text = result.get("text", "") - - for existing in unique_results: - existing_text = existing.get("text", "") - - # Simple overlap check - overlap = RagReranker._calculate_overlap(result_text, existing_text) - if overlap > similarity_threshold: - is_duplicate = True - break - - if not is_duplicate: - unique_results.append(result) - - return unique_results - - @staticmethod - def _calculate_overlap(text1: str, text2: str) -> float: - """Calculate simple text overlap ratio.""" - if not text1 or not text2: - return 0.0 - - # Use set of words for simple comparison - words1 = set(text1.lower().split()) - words2 = set(text2.lower().split()) - - if not words1 or not words2: - return 0.0 - - intersection = len(words1 & words2) - union = len(words1 | words2) - - return intersection / union if union > 0 else 0.0 - - @dataclass class RAGMetrics: """Metrics for RAG processing quality and performance.""" diff --git a/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py b/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py index a1aeb473..f9d79c70 100644 --- a/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py +++ b/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py @@ -425,7 +425,7 @@ STAGE_2_CROSS_FILE_PROMPT_TEMPLATE = """SYSTEM ROLE: You are a staff architect reviewing this PR for systemic risks. -Focus on: data flow, authorization patterns, consistency, database integrity, service boundaries. +Focus on: data flow, authorization patterns, consistency, service boundaries. At temperature 0.1, you will be conservative—that is correct. Flag even low-confidence concerns. Return structured JSON. @@ -447,7 +447,7 @@ Architecture Reference {architecture_context} -Database Migrations in This PR +Migration Files in This PR {migrations} Output Format @@ -476,15 +476,6 @@ "severity": "HIGH" }} ], - "immutability_enforcement": {{ - "rule": "Analysis results immutable after status=FINAL", - "check_pass": true, - "evidence": "..." 
- }}, - "database_integrity": {{ - "concerns": ["FK constraints", "cascade deletes"], - "findings": [] - }}, "pr_recommendation": "PASS|PASS_WITH_WARNINGS|FAIL", "confidence": "HIGH|MEDIUM|LOW|INFO" }} diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/api/api.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/api/api.py index 8cb66131..35190448 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/api/api.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/api/api.py @@ -1,9 +1,10 @@ import logging import gc +import os import uuid from typing import Dict, List, Optional from fastapi import FastAPI, HTTPException, BackgroundTasks -from pydantic import BaseModel +from pydantic import BaseModel, field_validator from llama_index.core.schema import TextNode from qdrant_client.models import PointStruct @@ -16,11 +17,29 @@ app = FastAPI(title="CodeCrow RAG API", version="2.0.0") +# Service-to-service auth +from .middleware import ServiceSecretMiddleware +app.add_middleware(ServiceSecretMiddleware) + config = RAGConfig() index_manager = RAGIndexManager(config) query_service = RAGQueryService(config) +# Allowed base directory for all repo paths (set via env or default to /tmp) +_ALLOWED_REPO_ROOT = os.environ.get("ALLOWED_REPO_ROOT", "/tmp") + + +def _validate_repo_path(path: str) -> str: + """Validate that a repo path is within the allowed root and contains no traversal.""" + resolved = os.path.realpath(path) + if not resolved.startswith(os.path.realpath(_ALLOWED_REPO_ROOT)): + raise ValueError( + f"Path must be under {_ALLOWED_REPO_ROOT}, got: {path}" + ) + return path + + class IndexRequest(BaseModel): repo_path: str workspace: str @@ -29,6 +48,11 @@ class IndexRequest(BaseModel): commit: str exclude_patterns: Optional[List[str]] = None + @field_validator("repo_path") + @classmethod + def validate_repo_path(cls, v: str) -> str: + return _validate_repo_path(v) + class UpdateFilesRequest(BaseModel): file_paths: List[str] @@ -38,6 +62,11 @@ class UpdateFilesRequest(BaseModel): branch: str commit: str + @field_validator("repo_base") + @classmethod + def validate_repo_base(cls, v: str) -> str: + return _validate_repo_path(v) + class DeleteFilesRequest(BaseModel): file_paths: List[str] @@ -276,6 +305,11 @@ class EstimateRequest(BaseModel): repo_path: str exclude_patterns: Optional[List[str]] = None + @field_validator("repo_path") + @classmethod + def validate_repo_path(cls, v: str) -> str: + return _validate_repo_path(v) + class EstimateResponse(BaseModel): file_count: int diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/api/middleware.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/api/middleware.py new file mode 100644 index 00000000..9e78d0d5 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/api/middleware.py @@ -0,0 +1,46 @@ +""" +Shared-secret authentication middleware. + +Validates that internal service-to-service requests carry +the correct X-Service-Secret header matching the SERVICE_SECRET env var. 
+""" +import os +import logging + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse + +logger = logging.getLogger(__name__) + +# Paths that skip auth (health checks, readiness probes) +_PUBLIC_PATHS = frozenset({"/health", "/docs", "/openapi.json", "/redoc"}) + + +class ServiceSecretMiddleware(BaseHTTPMiddleware): + """Reject requests that don't carry a valid shared service secret.""" + + def __init__(self, app, secret: str | None = None): + super().__init__(app) + self.secret = secret or os.environ.get("SERVICE_SECRET", "") + + async def dispatch(self, request: Request, call_next): + # Skip auth for health/doc endpoints + if request.url.path in _PUBLIC_PATHS: + return await call_next(request) + + # If no secret is configured, allow all (dev mode) + if not self.secret: + return await call_next(request) + + provided = request.headers.get("x-service-secret", "") + if provided != self.secret: + logger.warning( + f"Unauthorized request to {request.url.path} from {request.client.host if request.client else 'unknown'}" + ) + return JSONResponse( + status_code=401, + content={"detail": "Missing or invalid service secret"}, + ) + + return await call_next(request) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/collection_manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/collection_manager.py index a55f1111..039f06bf 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/collection_manager.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/collection_manager.py @@ -11,7 +11,8 @@ from qdrant_client import QdrantClient from qdrant_client.models import ( Distance, VectorParams, - CreateAlias, DeleteAlias, CreateAliasOperation, DeleteAliasOperation + CreateAlias, DeleteAlias, CreateAliasOperation, DeleteAliasOperation, + PayloadSchemaType, TextIndexParams, TokenizerType ) logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ def ensure_collection_exists(self, collection_name: str) -> None: ) ) logger.info(f"Created collection {collection_name}") + self._ensure_payload_indexes(collection_name) else: logger.info(f"Collection {collection_name} already exists") @@ -63,8 +65,28 @@ def create_versioned_collection(self, base_name: str) -> str: distance=Distance.COSINE ) ) + self._ensure_payload_indexes(versioned_name) return versioned_name + def _ensure_payload_indexes(self, collection_name: str) -> None: + """Create payload indexes for efficient filtering on common fields.""" + try: + # Keyword index on 'path' for exact match and prefix filtering + self.client.create_payload_index( + collection_name=collection_name, + field_name="path", + field_schema=PayloadSchemaType.KEYWORD, + ) + # Keyword index on 'branch' for branch filtering + self.client.create_payload_index( + collection_name=collection_name, + field_name="branch", + field_schema=PayloadSchemaType.KEYWORD, + ) + logger.info(f"Payload indexes created for {collection_name}") + except Exception as e: + logger.warning(f"Failed to create payload indexes for {collection_name}: {e}") + def delete_collection(self, collection_name: str) -> bool: """Delete a collection.""" try: diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py index a6fb49ae..06768098 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py +++ 
b/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py @@ -4,7 +4,7 @@ from llama_index.core import VectorStoreIndex from llama_index.vector_stores.qdrant import QdrantVectorStore from qdrant_client import QdrantClient -from qdrant_client.http.models import Filter, FieldCondition, MatchValue, MatchAny +from qdrant_client.http.models import Filter, FieldCondition, MatchValue, MatchAny, MatchText from ..models.config import RAGConfig from ..models.scoring_config import get_scoring_config @@ -407,20 +407,20 @@ def _apply_branch_priority(points: list, target: str, existing_target_paths: set with_vectors=False ) - # Fallback: partial match if exact fails + # Fallback: try with filename only (handles path prefix mismatches) if not results: - all_results, _ = self.qdrant_client.scroll( + results, _ = self.qdrant_client.scroll( collection_name=collection_name, - scroll_filter=Filter(must=[branch_filter]), - limit=1000, + scroll_filter=Filter( + must=[ + branch_filter, + FieldCondition(key="path", match=MatchText(text=filename)) + ] + ), + limit=limit_per_file * len(branches), with_payload=True, with_vectors=False ) - results = [ - p for p in all_results - if normalized_path in p.payload.get("path", "") or - filename == p.payload.get("path", "").rsplit("/", 1)[-1] - ][:limit_per_file * len(branches)] # Apply branch priority: if file exists in target branch, only keep target branch version if target_branch and len(branches) > 1: From b89fd99a987cff176cb504516c9a17ae95e34a4e Mon Sep 17 00:00:00 2001 From: rostislav Date: Mon, 9 Feb 2026 11:04:23 +0200 Subject: [PATCH 6/7] feat: Implement diff fingerprint caching for pull request analysis - Added DiffFingerprintUtil to compute a stable fingerprint for code changes in pull requests. - Enhanced PullRequestAnalysisProcessor to utilize commit hash and diff fingerprint caches for reusing analysis results. - Updated CodeAnalysis model to include a diff_fingerprint field for storage. - Modified CodeAnalysisService to support retrieval and cloning of analyses based on diff fingerprints and commit hashes. - Added database migrations to introduce the diff_fingerprint column and create necessary indexes. - Improved error handling and logging in various components, including file existence checks in Bitbucket Cloud. - Refactored tests to accommodate new functionality and ensure coverage for caching mechanisms. 
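The fingerprint normalization this commit message describes (implemented in Java as DiffFingerprintUtil later in this patch) boils down to: keep only the +/- change lines, drop file headers, sort, and hash. A minimal Python sketch of the same idea, for illustration only; the function name diff_fingerprint is hypothetical and not part of the patch:

import hashlib
from typing import Optional

def diff_fingerprint(raw_diff: Optional[str]) -> Optional[str]:
    """Stable SHA-256 fingerprint of the +/- change lines in a unified diff.

    Context lines and hunk headers (@@) never start with +/-, so they are
    skipped implicitly; file headers (+++ / ---) are skipped explicitly.
    Change lines are sorted so the result does not depend on file ordering
    or on the merge base that produced the hunks.
    """
    if not raw_diff or not raw_diff.strip():
        return None

    change_lines = []
    for raw in raw_diff.replace("\r\n", "\n").replace("\r", "\n").split("\n"):
        line = raw.rstrip()  # trailing whitespace is not significant
        if not line or line[0] not in "+-":
            continue
        if line.startswith(("+++", "---")):
            continue  # file-level headers, not code changes
        change_lines.append(line)

    if not change_lines:
        return None

    digest = hashlib.sha256()
    for line in sorted(change_lines):
        digest.update(line.encode("utf-8"))
        digest.update(b"\n")
    return digest.hexdigest()

Two diffs that touch the same lines but were produced against different merge bases (and therefore have different hunk headers and context) map to the same fingerprint, which is exactly the property the cache relies on.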
--- deployment/config/mcp-client/.env.sample | 1 + deployment/config/rag-pipeline/.env.sample | 1 + .../analysis/BranchAnalysisProcessor.java | 113 +++++++---- .../PullRequestAnalysisProcessor.java | 56 +++++- .../util/DiffFingerprintUtil.java | 102 ++++++++++ .../analysis/BranchAnalysisProcessorTest.java | 7 +- .../PullRequestAnalysisProcessorTest.java | 10 +- .../core/model/codeanalysis/CodeAnalysis.java | 6 + .../codeanalysis/CodeAnalysisRepository.java | 40 ++++ .../core/service/CodeAnalysisService.java | 176 ++++++++++++++---- ..._add_diff_fingerprint_to_code_analysis.sql | 9 + .../V1.4.1__deduplicate_branch_issues.sql | 59 ++++++ .../CheckFileExistsInBranchAction.java | 70 +++++-- .../processor/WebhookAsyncProcessor.java | 7 + python-ecosystem/mcp-client/api/middleware.py | 12 +- .../service/command/command_service.py | 2 +- .../service/review/review_service.py | 2 +- python-ecosystem/rag-pipeline/main.py | 2 +- .../src/rag_pipeline/api/middleware.py | 12 +- .../src/rag_pipeline/models/config.py | 2 +- 20 files changed, 588 insertions(+), 101 deletions(-) create mode 100644 java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/util/DiffFingerprintUtil.java create mode 100644 java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.0__add_diff_fingerprint_to_code_analysis.sql create mode 100644 java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.1__deduplicate_branch_issues.sql diff --git a/deployment/config/mcp-client/.env.sample b/deployment/config/mcp-client/.env.sample index 05066c29..43168803 100644 --- a/deployment/config/mcp-client/.env.sample +++ b/deployment/config/mcp-client/.env.sample @@ -6,6 +6,7 @@ RAG_API_URL=http://host.docker.internal:8001 # Shared secret for authenticating requests between internal services. # Must match the SERVICE_SECRET configured on rag-pipeline. # Leave empty to disable auth (dev mode only). +# IMPORTANT: Avoid $ { } characters in the secret — they can cause dotenv parsing issues. SERVICE_SECRET=change-me-to-a-random-secret # === Concurrency === diff --git a/deployment/config/rag-pipeline/.env.sample b/deployment/config/rag-pipeline/.env.sample index 1152a685..67003ea4 100644 --- a/deployment/config/rag-pipeline/.env.sample +++ b/deployment/config/rag-pipeline/.env.sample @@ -2,6 +2,7 @@ # Shared secret for authenticating incoming requests from mcp-client. # Must match the SERVICE_SECRET configured on mcp-client. # Leave empty to disable auth (dev mode only). +# IMPORTANT: Avoid $ { } characters in the secret — they can cause dotenv parsing issues. 
SERVICE_SECRET=change-me-to-a-random-secret # === Path Traversal Guard === diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/BranchAnalysisProcessor.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/BranchAnalysisProcessor.java index 80f54fd0..41979a48 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/BranchAnalysisProcessor.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/BranchAnalysisProcessor.java @@ -224,10 +224,10 @@ public Map process(BranchProcessRequest request, Consumer existingFiles = updateBranchFiles(changedFiles, project, request.getTargetBranchName()); Branch branch = createOrUpdateProjectBranch(project, request); - mapCodeAnalysisIssuesToBranch(changedFiles, branch, project); + mapCodeAnalysisIssuesToBranch(changedFiles, existingFiles, branch, project); // Always update branch issue counts after mapping (even on first analysis) // Previously this was only done in reanalyzeCandidateIssues() which could be skipped @@ -282,10 +282,15 @@ public Set parseFilePathsFromDiff(String rawDiff) { return files; } - private void updateBranchFiles(Set changedFiles, Project project, String branchName) { + /** + * Updates branch file records for changed files. + * @return the set of file paths confirmed to exist in the branch (used to avoid redundant API calls) + */ + private Set updateBranchFiles(Set changedFiles, Project project, String branchName) { VcsInfo vcsInfo = getVcsInfo(project); EVcsProvider provider = getVcsProvider(project); VcsOperationsService operationsService = vcsServiceFactory.getOperationsService(provider); + Set filesExistingInBranch = new HashSet<>(); for (String filePath : changedFiles) { try { @@ -303,9 +308,12 @@ private void updateBranchFiles(Set changedFiles, Project project, String log.debug("Skipping file {} - does not exist in branch {}", filePath, branchName); continue; } + filesExistingInBranch.add(filePath); } catch (Exception e) { log.warn("Failed to check file existence for {} in branch {}: {}. 
Proceeding anyway.", filePath, branchName, e.getMessage()); + // On error, assume the file exists so we don't skip it + filesExistingInBranch.add(filePath); } List relatedIssues = codeAnalysisIssueRepository @@ -314,7 +322,14 @@ private void updateBranchFiles(Set changedFiles, Project project, String .filter(issue -> branchName.equals(issue.getAnalysis().getBranchName()) || branchName.equals(issue.getAnalysis().getSourceBranchName())) .toList(); - long unresolvedCount = branchSpecific.stream().filter(i -> !i.isResolved()).count(); + + // Deduplicate by content key before counting — multiple analyses may + // report the same logical issue with different DB ids + Set seenKeys = new HashSet<>(); + long unresolvedCount = branchSpecific.stream() + .filter(i -> !i.isResolved()) + .filter(i -> seenKeys.add(buildIssueContentKey(i))) + .count(); Optional projectFileOptional = branchFileRepository .findByProjectIdAndBranchNameAndFilePath(project.getId(), branchName, filePath); @@ -333,6 +348,7 @@ private void updateBranchFiles(Set changedFiles, Project project, String branchFileRepository.save(branchFile); } } + return filesExistingInBranch; } private Branch createOrUpdateProjectBranch(Project project, BranchProcessRequest request) { @@ -348,31 +364,14 @@ private Branch createOrUpdateProjectBranch(Project project, BranchProcessRequest return branchRepository.save(branch); } - private void mapCodeAnalysisIssuesToBranch(Set changedFiles, Branch branch, Project project) { - VcsInfo vcsInfo = getVcsInfo(project); - EVcsProvider provider = getVcsProvider(project); - VcsOperationsService operationsService = vcsServiceFactory.getOperationsService(provider); - + private void mapCodeAnalysisIssuesToBranch(Set changedFiles, Set filesExistingInBranch, + Branch branch, Project project) { for (String filePath : changedFiles) { - try { - OkHttpClient client = vcsClientProvider.getHttpClient(vcsInfo.vcsConnection()); - - boolean fileExistsInBranch = operationsService.checkFileExistsInBranch( - client, - vcsInfo.workspace(), - vcsInfo.repoSlug(), - branch.getBranchName(), - filePath - ); - - if (!fileExistsInBranch) { - log.debug("Skipping issue mapping for file {} - does not exist in branch {}", + // Use cached file existence from updateBranchFiles to avoid redundant API calls + if (!filesExistingInBranch.contains(filePath)) { + log.debug("Skipping issue mapping for file {} - does not exist in branch {} (cached)", filePath, branch.getBranchName()); - continue; - } - } catch (Exception e) { - log.warn("Failed to check file existence for {} in branch {}: {}. Proceeding with mapping.", - filePath, branch.getBranchName(), e.getMessage()); + continue; } List allIssues = codeAnalysisIssueRepository.findByProjectIdAndFilePath(project.getId(), filePath); @@ -387,27 +386,67 @@ private void mapCodeAnalysisIssuesToBranch(Set changedFiles, Branch bran }) .toList(); + // Content-based deduplication: build a map of existing BranchIssues by content key + // to prevent the same logical issue from being linked multiple times across analyses. + // Key = "lineNumber:severity:category" — unique enough within a single file context. 
+ List existingBranchIssues = branchIssueRepository + .findUnresolvedByBranchIdAndFilePath(branch.getId(), filePath); + Map contentKeyMap = new HashMap<>(); + for (BranchIssue bi : existingBranchIssues) { + String key = buildIssueContentKey(bi.getCodeAnalysisIssue()); + contentKeyMap.putIfAbsent(key, bi); + } + + int skipped = 0; for (CodeAnalysisIssue issue : branchSpecificIssues) { + // Tier 1: exact ID match — same CodeAnalysisIssue already linked Optional existing = branchIssueRepository .findByBranchIdAndCodeAnalysisIssueId(branch.getId(), issue.getId()); - BranchIssue bc; + if (existing.isPresent()) { - bc = existing.get(); - bc.setSeverity(issue.getSeverity()); - branchIssueRepository.saveAndFlush(bc); - } else { - bc = new BranchIssue(); - bc.setBranch(branch); - bc.setCodeAnalysisIssue(issue); - bc.setResolved(issue.isResolved()); + BranchIssue bc = existing.get(); bc.setSeverity(issue.getSeverity()); - bc.setFirstDetectedPrNumber(issue.getAnalysis() != null ? issue.getAnalysis().getPrNumber() : null); branchIssueRepository.saveAndFlush(bc); + continue; } + + // Tier 2: content-based dedup — same logical issue from a different analysis + String contentKey = buildIssueContentKey(issue); + if (contentKeyMap.containsKey(contentKey)) { + skipped++; + continue; + } + + // No match — create new BranchIssue + BranchIssue bc = new BranchIssue(); + bc.setBranch(branch); + bc.setCodeAnalysisIssue(issue); + bc.setResolved(issue.isResolved()); + bc.setSeverity(issue.getSeverity()); + bc.setFirstDetectedPrNumber(issue.getAnalysis() != null ? issue.getAnalysis().getPrNumber() : null); + branchIssueRepository.saveAndFlush(bc); + // Register in map so subsequent issues in this batch also dedup + contentKeyMap.put(contentKey, bc); + } + + if (skipped > 0) { + log.debug("Skipped {} duplicate issue(s) for file {} in branch {}", + skipped, filePath, branch.getBranchName()); } } } + /** + * Builds a content key for deduplication of branch issues. + * Two CodeAnalysisIssue records with the same key represent the same logical issue. + */ + private String buildIssueContentKey(CodeAnalysisIssue issue) { + return issue.getFilePath() + ":" + + issue.getLineNumber() + ":" + + issue.getSeverity() + ":" + + issue.getIssueCategory(); + } + private void reanalyzeCandidateIssues(Set changedFiles, Branch branch, Project project, BranchProcessRequest request, Consumer> consumer) { List candidateBranchIssues = new ArrayList<>(); for (String filePath : changedFiles) { diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java index 5c7783c0..052a8585 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java @@ -33,6 +33,8 @@ import java.util.Map; import java.util.Optional; +import org.rostilos.codecrow.analysisengine.util.DiffFingerprintUtil; + /** * Generic service that handles pull request analysis. * Uses VCS-specific services via VcsServiceFactory for provider-specific operations. 
@@ -151,6 +153,29 @@ public Map process( return Map.of("status", "cached", "cached", true); } + // --- Fallback cache: same commit hash, any PR number (handles close/reopen) --- + Optional commitHashHit = codeAnalysisService.getAnalysisByCommitHash( + project.getId(), request.getCommitHash()); + if (commitHashHit.isPresent()) { + log.info("Commit-hash cache hit for project={}, commit={} (source PR={}). Cloning for PR={}.", + project.getId(), request.getCommitHash(), + commitHashHit.get().getPrNumber(), request.getPullRequestId()); + CodeAnalysis cloned = codeAnalysisService.cloneAnalysisForPr( + commitHashHit.get(), project, request.getPullRequestId(), + request.getCommitHash(), request.getTargetBranchName(), + request.getSourceBranchName(), commitHashHit.get().getDiffFingerprint()); + try { + reportingService.postAnalysisResults(cloned, project, + request.getPullRequestId(), pullRequest.getId(), + request.getPlaceholderCommentId()); + } catch (IOException e) { + log.error("Failed to post commit-hash cached results to VCS: {}", e.getMessage(), e); + } + publishAnalysisCompletedEvent(project, request, correlationId, startTime, + AnalysisCompletedEvent.CompletionStatus.SUCCESS, 0, 0, null); + return Map.of("status", "cached_by_commit", "cached", true); + } + // Get all previous analyses for this PR to provide full issue history to AI List allPrAnalyses = codeAnalysisService.getAllPrAnalyses( project.getId(), @@ -170,6 +195,34 @@ public Map process( AiAnalysisRequest aiRequest = aiClientService.buildAiAnalysisRequest( project, request, previousAnalysis, allPrAnalyses); + // --- Diff fingerprint cache: same code changes, different PR/commit --- + String diffFingerprint = DiffFingerprintUtil.compute(aiRequest.getRawDiff()); + if (diffFingerprint != null) { + Optional fingerprintHit = codeAnalysisService.getAnalysisByDiffFingerprint( + project.getId(), diffFingerprint); + if (fingerprintHit.isPresent()) { + log.info("Diff fingerprint cache hit for project={}, fingerprint={} (source PR={}). Cloning for PR={}.", + project.getId(), diffFingerprint.substring(0, 8) + "...", + fingerprintHit.get().getPrNumber(), request.getPullRequestId()); + // TODO: Option B — LIGHTWEIGHT mode: instead of full clone, reuse Stage 1 issues + // but re-run Stage 2 cross-file analysis against the new target branch context. + CodeAnalysis cloned = codeAnalysisService.cloneAnalysisForPr( + fingerprintHit.get(), project, request.getPullRequestId(), + request.getCommitHash(), request.getTargetBranchName(), + request.getSourceBranchName(), diffFingerprint); + try { + reportingService.postAnalysisResults(cloned, project, + request.getPullRequestId(), pullRequest.getId(), + request.getPlaceholderCommentId()); + } catch (IOException e) { + log.error("Failed to post fingerprint-cached results to VCS: {}", e.getMessage(), e); + } + publishAnalysisCompletedEvent(project, request, correlationId, startTime, + AnalysisCompletedEvent.CompletionStatus.SUCCESS, 0, 0, null); + return Map.of("status", "cached_by_fingerprint", "cached", true); + } + } + Map aiResponse = aiAnalysisClient.performAnalysis(aiRequest, event -> { try { log.debug("Received event from AI client: type={}", event.get("type")); @@ -188,7 +241,8 @@ public Map process( request.getSourceBranchName(), request.getCommitHash(), request.getPrAuthorId(), - request.getPrAuthorUsername() + request.getPrAuthorUsername(), + diffFingerprint ); int issuesFound = newAnalysis.getIssues() != null ? 
newAnalysis.getIssues().size() : 0; diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/util/DiffFingerprintUtil.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/util/DiffFingerprintUtil.java new file mode 100644 index 00000000..6f8a381e --- /dev/null +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/util/DiffFingerprintUtil.java @@ -0,0 +1,102 @@ +package org.rostilos.codecrow.analysisengine.util; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Computes a content-based fingerprint of a unified diff. + *

+ * Only actual change lines ({@code +} / {@code -}) are included — context lines, + * hunk headers ({@code @@}), and file headers ({@code +++} / {@code ---} / {@code diff --git}) + * are excluded. The change lines are sorted to make the fingerprint stable regardless + * of file ordering within the diff. + *

+ * This allows detecting that two PRs carry the same code changes even if they target + * different branches (different merge-base → different context/hunk headers). + */ +public final class DiffFingerprintUtil { + + private DiffFingerprintUtil() { /* utility */ } + + /** + * Compute a SHA-256 hex digest of the normalised change lines in the given diff. + * + * @param rawDiff the filtered unified diff (may be {@code null} or empty) + * @return 64-char lowercase hex string, or {@code null} if the diff is blank + */ + public static String compute(String rawDiff) { + if (rawDiff == null || rawDiff.isBlank()) { + return null; + } + + List changeLines = extractChangeLines(rawDiff); + if (changeLines.isEmpty()) { + return null; + } + + // Sort for stability across different file orderings + Collections.sort(changeLines); + + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + for (String line : changeLines) { + digest.update(line.getBytes(StandardCharsets.UTF_8)); + digest.update((byte) '\n'); + } + return bytesToHex(digest.digest()); + } catch (NoSuchAlgorithmException e) { + // SHA-256 is guaranteed by the JVM spec — should never happen + throw new IllegalStateException("SHA-256 not available", e); + } + } + + /** + * Extract only the actual change lines from a unified diff. + * A "change line" starts with exactly one {@code +} or {@code -} and is NOT + * a file header ({@code +++}, {@code ---}) or a diff metadata line. + */ + private static List extractChangeLines(String diff) { + List lines = new ArrayList<>(); + // Normalise line endings + String normalised = diff.replace("\r\n", "\n").replace("\r", "\n"); + for (String raw : normalised.split("\n")) { + String line = trimTrailingWhitespace(raw); + if (line.isEmpty()) { + continue; + } + char first = line.charAt(0); + if (first != '+' && first != '-') { + continue; + } + // Skip file-level headers: "+++", "---", "diff --git" + if (line.startsWith("+++") || line.startsWith("---")) { + continue; + } + if (line.startsWith("diff ")) { + continue; + } + lines.add(line); + } + return lines; + } + + private static String trimTrailingWhitespace(String s) { + int end = s.length(); + while (end > 0 && Character.isWhitespace(s.charAt(end - 1))) { + end--; + } + return s.substring(0, end); + } + + private static String bytesToHex(byte[] bytes) { + StringBuilder sb = new StringBuilder(bytes.length * 2); + for (byte b : bytes) { + sb.append(String.format("%02x", b & 0xff)); + } + return sb.toString(); + } +} diff --git a/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/processor/analysis/BranchAnalysisProcessorTest.java b/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/processor/analysis/BranchAnalysisProcessorTest.java index c38168d3..f3b3d5e3 100644 --- a/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/processor/analysis/BranchAnalysisProcessorTest.java +++ b/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/processor/analysis/BranchAnalysisProcessorTest.java @@ -28,7 +28,6 @@ import org.rostilos.codecrow.core.persistence.repository.branch.BranchRepository; import org.rostilos.codecrow.core.persistence.repository.codeanalysis.CodeAnalysisIssueRepository; import org.rostilos.codecrow.vcsclient.VcsClientProvider; -import org.springframework.context.ApplicationEventPublisher; import java.io.IOException; import java.util.*; @@ -73,9 +72,6 @@ class BranchAnalysisProcessorTest { @Mock 
private RagOperationsService ragOperationsService; - @Mock - private ApplicationEventPublisher eventPublisher; - @Mock private VcsOperationsService operationsService; @@ -262,7 +258,8 @@ void shouldThrowAnalysisLockedExceptionWhenLockCannotBeAcquired() throws IOExcep assertThatThrownBy(() -> processor.process(request, consumer)) .isInstanceOf(AnalysisLockedException.class); - verify(eventPublisher, times(2)).publishEvent(any()); + // No consumer or event interactions should occur when lock is not acquired + verifyNoInteractions(consumer); } // Note: Full process() integration tests are complex and require extensive mocking. diff --git a/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessorTest.java b/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessorTest.java index 0b36db83..0a8c555d 100644 --- a/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessorTest.java +++ b/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessorTest.java @@ -140,10 +140,13 @@ void shouldSuccessfullyProcessPRAnalysis() throws Exception { when(codeAnalysisService.getCodeAnalysisCache(anyLong(), anyString(), anyLong())) .thenReturn(Optional.empty()); - when(codeAnalysisService.getPreviousVersionCodeAnalysis(anyLong(), anyLong())) + when(codeAnalysisService.getAnalysisByCommitHash(anyLong(), anyString())) .thenReturn(Optional.empty()); + when(codeAnalysisService.getAllPrAnalyses(anyLong(), anyLong())) + .thenReturn(List.of()); - when(aiClientService.buildAiAnalysisRequest(any(), any(), any())).thenReturn(aiAnalysisRequest); + when(aiClientService.buildAiAnalysisRequest(any(), any(), any(), anyList())).thenReturn(aiAnalysisRequest); + when(aiAnalysisRequest.getRawDiff()).thenReturn(""); Map aiResponse = Map.of( "comment", "Review comment", @@ -151,7 +154,8 @@ void shouldSuccessfullyProcessPRAnalysis() throws Exception { ); when(aiAnalysisClient.performAnalysis(any(), any())).thenReturn(aiResponse); - when(codeAnalysisService.createAnalysisFromAiResponse(any(), any(), anyLong(), anyString(), anyString(), anyString(), any(), any())) + when(codeAnalysisService.createAnalysisFromAiResponse( + any(), any(), anyLong(), anyString(), anyString(), anyString(), any(), any(), any())) .thenReturn(codeAnalysis); Map result = processor.process(request, consumer, project); diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/codeanalysis/CodeAnalysis.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/codeanalysis/CodeAnalysis.java index ae4377fa..14bd145d 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/codeanalysis/CodeAnalysis.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/codeanalysis/CodeAnalysis.java @@ -38,6 +38,9 @@ public class CodeAnalysis { @Column(name = "commit_hash", length = 40) private String commitHash; + @Column(name = "diff_fingerprint", length = 64) + private String diffFingerprint; + @Column(name = "target_branch_name") private String branchName; @@ -113,6 +116,9 @@ public void updateIssueCounts() { public String getCommitHash() { return commitHash; } public void setCommitHash(String commitHash) { this.commitHash = commitHash; } + public String getDiffFingerprint() { 
return diffFingerprint; } + public void setDiffFingerprint(String diffFingerprint) { this.diffFingerprint = diffFingerprint; } + public String getBranchName() { return branchName; } public void setBranchName(String branchName) { this.branchName = branchName; } diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/codeanalysis/CodeAnalysisRepository.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/codeanalysis/CodeAnalysisRepository.java index 27aff62d..516ca3c9 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/codeanalysis/CodeAnalysisRepository.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/codeanalysis/CodeAnalysisRepository.java @@ -112,6 +112,46 @@ Page searchAnalyses( @Query("SELECT ca FROM CodeAnalysis ca WHERE ca.id = :id") Optional findByIdWithIssues(@Param("id") Long id); + /** + * Find the most recent ACCEPTED analysis for a project with the same diff fingerprint. + * Used for content-based cache: reuse analysis when the same code changes appear in a different PR. + */ + @org.springframework.data.jpa.repository.EntityGraph(attributePaths = { + "issues", + "project", + "project.workspace", + "project.vcsBinding", + "project.vcsBinding.vcsConnection", + "project.aiBinding" + }) + @Query("SELECT ca FROM CodeAnalysis ca WHERE ca.project.id = :projectId " + + "AND ca.diffFingerprint = :diffFingerprint " + + "AND ca.status = org.rostilos.codecrow.core.model.codeanalysis.AnalysisStatus.ACCEPTED " + + "ORDER BY ca.createdAt DESC LIMIT 1") + Optional findTopByProjectIdAndDiffFingerprint( + @Param("projectId") Long projectId, + @Param("diffFingerprint") String diffFingerprint); + + /** + * Find the most recent ACCEPTED analysis for a project + commit hash (any PR number). + * Fallback cache for close/reopen scenarios where the same commit gets a new PR number. + */ + @org.springframework.data.jpa.repository.EntityGraph(attributePaths = { + "issues", + "project", + "project.workspace", + "project.vcsBinding", + "project.vcsBinding.vcsConnection", + "project.aiBinding" + }) + @Query("SELECT ca FROM CodeAnalysis ca WHERE ca.project.id = :projectId " + + "AND ca.commitHash = :commitHash " + + "AND ca.status = org.rostilos.codecrow.core.model.codeanalysis.AnalysisStatus.ACCEPTED " + + "ORDER BY ca.createdAt DESC LIMIT 1") + Optional findTopByProjectIdAndCommitHash( + @Param("projectId") Long projectId, + @Param("commitHash") String commitHash); + /** * Find all analyses for a PR across all versions, ordered by version descending. * Used to provide LLM with full issue history including resolved issues. 
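The repository lookups above and the CodeAnalysisService methods that follow back a three-tier cache in PullRequestAnalysisProcessor: exact (project, commit, PR) hit first, then commit hash only (close/reopen), then diff fingerprint (same changes in a different PR), and only then a full AI run. A minimal Python sketch of that lookup order; every callable and name here is a hypothetical stand-in, not the actual API:

from typing import Callable, Optional

def resolve_pr_analysis(
    project_id: int,
    pr_number: int,
    commit_hash: str,
    diff_fingerprint: Optional[str],
    exact_cache: Callable[[int, str, int], Optional[dict]],
    by_commit_hash: Callable[[int, str], Optional[dict]],
    by_fingerprint: Callable[[int, str], Optional[dict]],
    clone_for_pr: Callable[[dict, int], dict],
    run_full_analysis: Callable[[], dict],
) -> dict:
    """Tiered cache lookup, in the order the processor applies it.

    1. Exact hit (same project, commit hash, PR number): reuse as-is.
    2. Commit-hash hit (same commit under a new PR number, e.g. close/reopen):
       clone the analysis for the new PR.
    3. Diff-fingerprint hit (same change lines under a different commit/PR,
       e.g. branch cascades): clone the analysis for the new PR.
    4. Otherwise run the full AI analysis.
    """
    hit = exact_cache(project_id, commit_hash, pr_number)
    if hit is not None:
        return hit

    hit = by_commit_hash(project_id, commit_hash)
    if hit is not None:
        return clone_for_pr(hit, pr_number)

    if diff_fingerprint is not None:
        hit = by_fingerprint(project_id, diff_fingerprint)
        if hit is not None:
            return clone_for_pr(hit, pr_number)

    return run_full_analysis()

Each clone step corresponds to cloneAnalysisForPr in the service diff below, which deep-copies issues onto a fresh CodeAnalysis row so the cached result is attached to the new PR identity.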
diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/CodeAnalysisService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/CodeAnalysisService.java index d0fd4624..f34a58ea 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/CodeAnalysisService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/CodeAnalysisService.java @@ -24,23 +24,20 @@ @Transactional public class CodeAnalysisService { - private final CodeAnalysisRepository analysisRepository; - private final CodeAnalysisIssueRepository issueRepository; private final CodeAnalysisRepository codeAnalysisRepository; + private final CodeAnalysisIssueRepository issueRepository; private final QualityGateRepository qualityGateRepository; private final QualityGateEvaluator qualityGateEvaluator; private static final Logger log = LoggerFactory.getLogger(CodeAnalysisService.class); @Autowired public CodeAnalysisService( - CodeAnalysisRepository analysisRepository, - CodeAnalysisIssueRepository issueRepository, CodeAnalysisRepository codeAnalysisRepository, + CodeAnalysisIssueRepository issueRepository, QualityGateRepository qualityGateRepository ) { - this.analysisRepository = analysisRepository; - this.issueRepository = issueRepository; this.codeAnalysisRepository = codeAnalysisRepository; + this.issueRepository = issueRepository; this.qualityGateRepository = qualityGateRepository; this.qualityGateEvaluator = new QualityGateEvaluator(); } @@ -54,6 +51,21 @@ public CodeAnalysis createAnalysisFromAiResponse( String commitHash, String vcsAuthorId, String vcsAuthorUsername + ) { + return createAnalysisFromAiResponse(project, analysisData, pullRequestId, + targetBranchName, sourceBranchName, commitHash, vcsAuthorId, vcsAuthorUsername, null); + } + + public CodeAnalysis createAnalysisFromAiResponse( + Project project, + Map analysisData, + Long pullRequestId, + String targetBranchName, + String sourceBranchName, + String commitHash, + String vcsAuthorId, + String vcsAuthorUsername, + String diffFingerprint ) { try { // Check if analysis already exists for this commit (handles webhook retries) @@ -74,6 +86,7 @@ public CodeAnalysis createAnalysisFromAiResponse( analysis.setBranchName(targetBranchName); analysis.setSourceBranchName(sourceBranchName); analysis.setPrVersion(previousVersion + 1); + analysis.setDiffFingerprint(diffFingerprint); return fillAnalysisData(analysis, analysisData, commitHash, vcsAuthorId, vcsAuthorUsername); } catch (Exception e) { @@ -107,11 +120,11 @@ private CodeAnalysis fillAnalysisData( Object issuesObj = analysisData.get("issues"); if (issuesObj == null) { log.warn("No issues found in analysis data"); - return analysisRepository.save(analysis); + return codeAnalysisRepository.save(analysis); } // Save analysis first to get its ID for resolution tracking - CodeAnalysis savedAnalysis = analysisRepository.save(analysis); + CodeAnalysis savedAnalysis = codeAnalysisRepository.save(analysis); Long analysisId = savedAnalysis.getId(); Long prNumber = savedAnalysis.getPrNumber(); @@ -168,17 +181,22 @@ else if (issuesObj instanceof Map) { log.info("Successfully created analysis with {} issues", savedAnalysis.getIssues().size()); - // Evaluate quality gate - QualityGate qualityGate = getQualityGateForAnalysis(savedAnalysis); - if (qualityGate != null) { - QualityGateResult qgResult = qualityGateEvaluator.evaluate(savedAnalysis, qualityGate); - savedAnalysis.setAnalysisResult(qgResult.result()); - 
log.info("Quality gate '{}' evaluated with result: {}", qualityGate.getName(), qgResult.result()); - } else { - log.info("No quality gate found for analysis, skipping evaluation"); + // Evaluate quality gate — wrapped defensively so a QG failure + // (e.g. detached entity, lazy-init) does not abort the entire analysis + try { + QualityGate qualityGate = getQualityGateForAnalysis(savedAnalysis); + if (qualityGate != null) { + QualityGateResult qgResult = qualityGateEvaluator.evaluate(savedAnalysis, qualityGate); + savedAnalysis.setAnalysisResult(qgResult.result()); + log.info("Quality gate '{}' evaluated with result: {}", qualityGate.getName(), qgResult.result()); + } else { + log.info("No quality gate found for analysis, skipping evaluation"); + } + } catch (Exception qgEx) { + log.warn("Quality gate evaluation failed, analysis will be saved without QG result: {}", qgEx.getMessage()); } - return analysisRepository.save(savedAnalysis); + return codeAnalysisRepository.save(savedAnalysis); } catch (Exception e) { log.error("Error creating analysis from AI response: {}", e.getMessage(), e); @@ -230,6 +248,100 @@ public Optional getCodeAnalysisCache(Long projectId, String commit return codeAnalysisRepository.findByProjectIdAndCommitHashAndPrNumber(projectId, commitHash, prNumber).stream().findFirst(); } + /** + * Fallback cache lookup by commit hash only (ignoring PR number). + * Handles close/reopen scenarios where the same commit gets a new PR number. + */ + public Optional getAnalysisByCommitHash(Long projectId, String commitHash) { + return codeAnalysisRepository.findTopByProjectIdAndCommitHash(projectId, commitHash); + } + + /** + * Content-based cache lookup by diff fingerprint. + * Handles branch-cascade flows where the same code changes appear in different PRs + * (e.g. feature→release analyzed, then release→main opens with the same changes). + */ + public Optional getAnalysisByDiffFingerprint(Long projectId, String diffFingerprint) { + if (diffFingerprint == null || diffFingerprint.isBlank()) { + return Optional.empty(); + } + return codeAnalysisRepository.findTopByProjectIdAndDiffFingerprint(projectId, diffFingerprint); + } + + /** + * Clone an existing analysis for a new PR. + * Creates a new CodeAnalysis row with cloned issues, linked to the new PR identity. + * Used when a fingerprint/commit-hash cache hit matches a different PR. + * + * // TODO: Option B — LIGHTWEIGHT mode: instead of full clone, reuse Stage 1 issues + * // but re-run Stage 2 cross-file analysis against the new target branch context. + * // This would catch interaction differences when target branches differ. + * + * // TODO: Consider tracking storage growth from cloned analyses. If it becomes significant, + * // explore referencing the original analysis instead of deep-copying issues. + */ + public CodeAnalysis cloneAnalysisForPr( + CodeAnalysis source, + Project project, + Long newPrNumber, + String commitHash, + String targetBranchName, + String sourceBranchName, + String diffFingerprint + ) { + // Guard against duplicates (same idempotency check as createAnalysisFromAiResponse) + Optional existing = codeAnalysisRepository + .findByProjectIdAndCommitHashAndPrNumber(project.getId(), commitHash, newPrNumber); + if (existing.isPresent()) { + log.info("Cloned analysis already exists for project={}, commit={}, pr={}. 
Returning existing.", + project.getId(), commitHash, newPrNumber); + return existing.get(); + } + + int previousVersion = codeAnalysisRepository.findMaxPrVersion(project.getId(), newPrNumber).orElse(0); + + CodeAnalysis clone = new CodeAnalysis(); + clone.setProject(project); + clone.setAnalysisType(source.getAnalysisType()); + clone.setPrNumber(newPrNumber); + clone.setCommitHash(commitHash); + clone.setDiffFingerprint(diffFingerprint); + clone.setBranchName(targetBranchName); + clone.setSourceBranchName(sourceBranchName); + clone.setComment(source.getComment()); + clone.setStatus(source.getStatus()); + clone.setAnalysisResult(source.getAnalysisResult()); + clone.setPrVersion(previousVersion + 1); + + // Save first to get an ID + CodeAnalysis saved = codeAnalysisRepository.save(clone); + + // Deep-copy issues + for (CodeAnalysisIssue srcIssue : source.getIssues()) { + CodeAnalysisIssue issueClone = new CodeAnalysisIssue(); + issueClone.setSeverity(srcIssue.getSeverity()); + issueClone.setFilePath(srcIssue.getFilePath()); + issueClone.setLineNumber(srcIssue.getLineNumber()); + issueClone.setReason(srcIssue.getReason()); + issueClone.setSuggestedFixDescription(srcIssue.getSuggestedFixDescription()); + issueClone.setSuggestedFixDiff(srcIssue.getSuggestedFixDiff()); + issueClone.setIssueCategory(srcIssue.getIssueCategory()); + issueClone.setResolved(srcIssue.isResolved()); + issueClone.setResolvedDescription(srcIssue.getResolvedDescription()); + issueClone.setVcsAuthorId(srcIssue.getVcsAuthorId()); + issueClone.setVcsAuthorUsername(srcIssue.getVcsAuthorUsername()); + saved.addIssue(issueClone); + } + + saved.updateIssueCounts(); + CodeAnalysis result = codeAnalysisRepository.save(saved); + log.info("Cloned analysis {} → {} for PR {} (fingerprint={}, {} issues)", + source.getId(), result.getId(), newPrNumber, + diffFingerprint != null ? diffFingerprint.substring(0, 8) + "..." 
: "null", + result.getIssues().size()); + return result; + } + public Optional getPreviousVersionCodeAnalysis(Long projectId, Long prNumber) { return codeAnalysisRepository.findByProjectIdAndPrNumberWithMaxPrVersion(projectId, prNumber); } @@ -400,40 +512,40 @@ public CodeAnalysis createAnalysis(Project project, AnalysisType analysisType) { analysis.setProject(project); analysis.setAnalysisType(analysisType); analysis.setStatus(AnalysisStatus.PENDING); - return analysisRepository.save(analysis); + return codeAnalysisRepository.save(analysis); } public CodeAnalysis saveAnalysis(CodeAnalysis analysis) { analysis.updateIssueCounts(); - return analysisRepository.save(analysis); + return codeAnalysisRepository.save(analysis); } public Optional findById(Long id) { - return analysisRepository.findById(id); + return codeAnalysisRepository.findById(id); } public List findByProjectId(Long projectId) { - return analysisRepository.findByProjectIdOrderByCreatedAtDesc(projectId); + return codeAnalysisRepository.findByProjectIdOrderByCreatedAtDesc(projectId); } public List findByProjectIdAndType(Long projectId, AnalysisType analysisType) { - return analysisRepository.findByProjectIdAndAnalysisTypeOrderByCreatedAtDesc(projectId, analysisType); + return codeAnalysisRepository.findByProjectIdAndAnalysisTypeOrderByCreatedAtDesc(projectId, analysisType); } public Optional findByProjectIdAndPrNumber(Long projectId, Long prNumber) { - return analysisRepository.findByProjectIdAndPrNumber(projectId, prNumber); + return codeAnalysisRepository.findByProjectIdAndPrNumber(projectId, prNumber); } public Optional findByProjectIdAndPrNumberAndPrVersion(Long projectId, Long prNumber, int prVersion) { - return analysisRepository.findByProjectIdAndPrNumberAndPrVersion(projectId, prNumber, prVersion); + return codeAnalysisRepository.findByProjectIdAndPrNumberAndPrVersion(projectId, prNumber, prVersion); } public List findByProjectIdAndDateRange(Long projectId, OffsetDateTime startDate, OffsetDateTime endDate) { - return analysisRepository.findByProjectIdAndDateRange(projectId, startDate, endDate); + return codeAnalysisRepository.findByProjectIdAndDateRange(projectId, startDate, endDate); } public List findByProjectIdWithHighSeverityIssues(Long projectId) { - return analysisRepository.findByProjectIdWithHighSeverityIssues(projectId); + return codeAnalysisRepository.findByProjectIdWithHighSeverityIssues(projectId); } /** @@ -451,16 +563,16 @@ public org.springframework.data.domain.Page searchAnalyses( Long prNumber, AnalysisStatus status, org.springframework.data.domain.Pageable pageable) { - return analysisRepository.searchAnalyses(projectId, prNumber, status, pageable); + return codeAnalysisRepository.searchAnalyses(projectId, prNumber, status, pageable); } public Optional findLatestByProjectId(Long projectId) { - return analysisRepository.findLatestByProjectId(projectId); + return codeAnalysisRepository.findLatestByProjectId(projectId); } public AnalysisStats getProjectAnalysisStats(Long projectId) { - long totalAnalyses = analysisRepository.countByProjectId(projectId); - Double avgIssues = analysisRepository.getAverageIssuesPerAnalysis(projectId); + long totalAnalyses = codeAnalysisRepository.countByProjectId(projectId); + Double avgIssues = codeAnalysisRepository.getAverageIssuesPerAnalysis(projectId); long highSeverityCount = issueRepository.countByProjectIdAndSeverity(projectId, IssueSeverity.HIGH); long mediumSeverityCount = issueRepository.countByProjectIdAndSeverity(projectId, IssueSeverity.MEDIUM); @@ -493,11 
+605,11 @@ public void markIssueAsResolved(Long issueId) { } public void deleteAnalysis(Long analysisId) { - analysisRepository.deleteById(analysisId); + codeAnalysisRepository.deleteById(analysisId); } public void deleteAllAnalysesByProjectId(Long projectId) { - analysisRepository.deleteByProjectId(projectId); + codeAnalysisRepository.deleteByProjectId(projectId); } public static class AnalysisStats { diff --git a/java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.0__add_diff_fingerprint_to_code_analysis.sql b/java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.0__add_diff_fingerprint_to_code_analysis.sql new file mode 100644 index 00000000..9a68490a --- /dev/null +++ b/java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.0__add_diff_fingerprint_to_code_analysis.sql @@ -0,0 +1,9 @@ +-- Add diff_fingerprint column for content-based analysis caching. +-- Allows reusing analysis results when the same code changes appear in different PRs +-- (e.g. close/reopen with a new PR number, or branch-cascade flows like feature→release→main). +ALTER TABLE code_analysis ADD COLUMN IF NOT EXISTS diff_fingerprint VARCHAR(64); + +-- Index for fingerprint-based cache lookups: (project_id, diff_fingerprint) +CREATE INDEX IF NOT EXISTS idx_code_analysis_project_diff_fingerprint + ON code_analysis (project_id, diff_fingerprint) + WHERE diff_fingerprint IS NOT NULL; diff --git a/java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.1__deduplicate_branch_issues.sql b/java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.1__deduplicate_branch_issues.sql new file mode 100644 index 00000000..8ce9d6e9 --- /dev/null +++ b/java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.1__deduplicate_branch_issues.sql @@ -0,0 +1,59 @@ +-- V1.4.1: Remove duplicate branch_issue rows that accumulated because the +-- deduplication key was based on code_analysis_issue_id (database PK) rather +-- than on the issue's semantic content. Each PR analysis creates fresh +-- CodeAnalysisIssue rows with new IDs, so the same logical issue +-- (same file, line, severity, category) ended up with N BranchIssue rows. +-- +-- Strategy: for each (branch_id, file_path, line_number, severity, category) +-- group keep only the row with the LOWEST id (oldest / first-detected) and +-- delete the rest. Afterwards recompute the denormalized counters on branch. + +-- 1. Delete duplicate branch_issues, keeping the first (lowest id) per group +DELETE FROM branch_issue +WHERE id NOT IN ( + SELECT keeper_id FROM ( + SELECT MIN(bi.id) AS keeper_id + FROM branch_issue bi + JOIN code_analysis_issue cai ON bi.code_analysis_issue_id = cai.id + GROUP BY bi.branch_id, + cai.file_path, + cai.line_number, + cai.severity, + COALESCE(cai.issue_category, '__NONE__') + ) AS keepers +); + +-- 2. 
Recompute denormalized branch issue counts +UPDATE branch b SET + total_issues = COALESCE(sub.total_unresolved, 0), + high_severity_count = COALESCE(sub.high_count, 0), + medium_severity_count = COALESCE(sub.medium_count, 0), + low_severity_count = COALESCE(sub.low_count, 0), + info_severity_count = COALESCE(sub.info_count, 0), + resolved_count = COALESCE(sub.resolved_total, 0), + updated_at = NOW() +FROM ( + SELECT + bi.branch_id, + COUNT(*) FILTER (WHERE bi.is_resolved = false) AS total_unresolved, + COUNT(*) FILTER (WHERE bi.is_resolved = false AND cai.severity = 'HIGH') AS high_count, + COUNT(*) FILTER (WHERE bi.is_resolved = false AND cai.severity = 'MEDIUM') AS medium_count, + COUNT(*) FILTER (WHERE bi.is_resolved = false AND cai.severity = 'LOW') AS low_count, + COUNT(*) FILTER (WHERE bi.is_resolved = false AND cai.severity = 'INFO') AS info_count, + COUNT(*) FILTER (WHERE bi.is_resolved = true) AS resolved_total + FROM branch_issue bi + JOIN code_analysis_issue cai ON bi.code_analysis_issue_id = cai.id + GROUP BY bi.branch_id +) sub +WHERE b.id = sub.branch_id; + +-- Also zero-out branches that lost all issues after cleanup +UPDATE branch SET + total_issues = 0, + high_severity_count = 0, + medium_severity_count = 0, + low_severity_count = 0, + info_severity_count = 0, + resolved_count = 0, + updated_at = NOW() +WHERE id NOT IN (SELECT DISTINCT branch_id FROM branch_issue); diff --git a/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/bitbucket/cloud/actions/CheckFileExistsInBranchAction.java b/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/bitbucket/cloud/actions/CheckFileExistsInBranchAction.java index 143c017b..c61a5e93 100644 --- a/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/bitbucket/cloud/actions/CheckFileExistsInBranchAction.java +++ b/java-ecosystem/libs/vcs-client/src/main/java/org/rostilos/codecrow/vcsclient/bitbucket/cloud/actions/CheckFileExistsInBranchAction.java @@ -15,10 +15,14 @@ /** * Action to check if a file exists in a specific branch on Bitbucket Cloud. * Uses the Bitbucket Cloud API to verify file existence without downloading content. + * Includes retry logic with exponential backoff for rate-limited (429) responses. */ public class CheckFileExistsInBranchAction { private static final Logger log = LoggerFactory.getLogger(CheckFileExistsInBranchAction.class); + private static final int MAX_RETRIES = 3; + private static final long INITIAL_BACKOFF_MS = 2_000; + private final OkHttpClient authorizedOkHttpClient; public CheckFileExistsInBranchAction(OkHttpClient authorizedOkHttpClient) { @@ -28,18 +32,19 @@ public CheckFileExistsInBranchAction(OkHttpClient authorizedOkHttpClient) { /** * Checks if a file exists in the specified branch. * Uses HEAD request to check existence without downloading file content. + * Retries with exponential backoff on 429 (rate-limit) responses. 
* * @param workspace workspace or team slug * @param repoSlug repository slug * @param branchName branch name (or commit hash) * @param filePath file path relative to repository root * @return true if file exists in the branch, false otherwise - * @throws IOException on network errors + * @throws IOException on network errors after retries are exhausted */ public boolean fileExists(String workspace, String repoSlug, String branchName, String filePath) throws IOException { String ws = Optional.ofNullable(workspace).orElse(""); String encodedPath = encodeFilePath(filePath); - + // Use Bitbucket Cloud API endpoint to check file existence with HEAD request String apiUrl = String.format("%s/repositories/%s/%s/src/%s/%s", BitbucketCloudConfig.BITBUCKET_API_BASE, ws, repoSlug, branchName, encodedPath); @@ -49,22 +54,53 @@ public boolean fileExists(String workspace, String repoSlug, String branchName, .head() .build(); - try (Response resp = authorizedOkHttpClient.newCall(req).execute()) { - if (resp.isSuccessful()) { - return true; - } else if (resp.code() == 404) { - log.debug("File not found: {} in branch {} (workspace: {}, repo: {})", - filePath, branchName, workspace, repoSlug); - return false; - } else { - String msg = String.format("Unexpected response %d when checking file existence: %s in branch %s", - resp.code(), filePath, branchName); - log.warn(msg); - throw new IOException(msg); + int attempt = 0; + long backoffMs = INITIAL_BACKOFF_MS; + + while (true) { + try (Response resp = authorizedOkHttpClient.newCall(req).execute()) { + if (resp.isSuccessful()) { + return true; + } else if (resp.code() == 404) { + log.debug("File not found: {} in branch {} (workspace: {}, repo: {})", + filePath, branchName, workspace, repoSlug); + return false; + } else if (resp.code() == 429 && attempt < MAX_RETRIES) { + // Rate limited — honour Retry-After header if present, otherwise exponential backoff + long waitMs = backoffMs; + String retryAfter = resp.header("Retry-After"); + if (retryAfter != null) { + try { + waitMs = Long.parseLong(retryAfter.trim()) * 1_000; + } catch (NumberFormatException ignored) { + // Use default backoff + } + } + log.info("Rate limited (429) checking file {}. 
Retrying in {}ms (attempt {}/{})", + filePath, waitMs, attempt + 1, MAX_RETRIES); + try { + Thread.sleep(waitMs); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new IOException("Interrupted while waiting for rate-limit backoff", ie); + } + attempt++; + backoffMs *= 2; // Exponential backoff + } else { + String msg = String.format("Unexpected response %d when checking file existence: %s in branch %s", + resp.code(), filePath, branchName); + log.warn(msg); + throw new IOException(msg); + } + } catch (IOException e) { + if (attempt < MAX_RETRIES && e.getMessage() != null && e.getMessage().contains("429")) { + attempt++; + backoffMs *= 2; + continue; + } + log.error("Failed to check file existence for {}: {}", filePath, e.getMessage(), e); + throw e; } - } catch (IOException e) { - log.error("Failed to check file existence for {}: {}", filePath, e.getMessage(), e); - throw e; } } diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 064c0061..a1c67990 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -309,6 +309,13 @@ private void initializeProjectAssociations(Project project) { aiConn.getProviderKey(); } } + + // Force initialization of Quality Gate (lazy @ManyToOne) — accessed by + // CodeAnalysisService.getQualityGateForAnalysis() after session closes + var qualityGate = project.getQualityGate(); + if (qualityGate != null) { + qualityGate.getName(); + } } /** diff --git a/python-ecosystem/mcp-client/api/middleware.py b/python-ecosystem/mcp-client/api/middleware.py index 9e78d0d5..96bef76c 100644 --- a/python-ecosystem/mcp-client/api/middleware.py +++ b/python-ecosystem/mcp-client/api/middleware.py @@ -23,6 +23,10 @@ class ServiceSecretMiddleware(BaseHTTPMiddleware): def __init__(self, app, secret: str | None = None): super().__init__(app) self.secret = secret or os.environ.get("SERVICE_SECRET", "") + if self.secret: + logger.info("ServiceSecretMiddleware: secret configured (length=%d)", len(self.secret)) + else: + logger.warning("ServiceSecretMiddleware: no secret configured — auth disabled") async def dispatch(self, request: Request, call_next): # Skip auth for health/doc endpoints @@ -36,7 +40,13 @@ async def dispatch(self, request: Request, call_next): provided = request.headers.get("x-service-secret", "") if provided != self.secret: logger.warning( - f"Unauthorized request to {request.url.path} from {request.client.host if request.client else 'unknown'}" + "Unauthorized request to %s from %s — " + "provided_len=%d expected_len=%d match=%s", + request.url.path, + request.client.host if request.client else "unknown", + len(provided), + len(self.secret), + provided == self.secret, ) return JSONResponse( status_code=401, diff --git a/python-ecosystem/mcp-client/service/command/command_service.py b/python-ecosystem/mcp-client/service/command/command_service.py index a0e59991..f8743e6a 100644 --- a/python-ecosystem/mcp-client/service/command/command_service.py +++ b/python-ecosystem/mcp-client/service/command/command_service.py @@ -28,7 +28,7 @@ class CommandService: MAX_STEPS_ASK = 40 def 
__init__(self): - load_dotenv() + load_dotenv(interpolate=False) self.default_jar_path = os.environ.get( "MCP_SERVER_JAR", "/app/codecrow-vcs-mcp-1.0.jar" diff --git a/python-ecosystem/mcp-client/service/review/review_service.py b/python-ecosystem/mcp-client/service/review/review_service.py index 1466f4c8..b3f4eb4e 100644 --- a/python-ecosystem/mcp-client/service/review/review_service.py +++ b/python-ecosystem/mcp-client/service/review/review_service.py @@ -34,7 +34,7 @@ class ReviewService: MAX_CONCURRENT_REVIEWS = int(os.environ.get("MAX_CONCURRENT_REVIEWS", "4")) def __init__(self): - load_dotenv() + load_dotenv(interpolate=False) self.default_jar_path = os.environ.get( "MCP_SERVER_JAR", #"/var/www/html/codecrow/codecrow-public/java-ecosystem/mcp-servers/vcs-mcp/target/codecrow-vcs-mcp-1.0.jar", diff --git a/python-ecosystem/rag-pipeline/main.py b/python-ecosystem/rag-pipeline/main.py index 4918eaa5..3e18e67c 100644 --- a/python-ecosystem/rag-pipeline/main.py +++ b/python-ecosystem/rag-pipeline/main.py @@ -12,7 +12,7 @@ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) -load_dotenv() +load_dotenv(interpolate=False) # Validate critical environment variables before starting def validate_environment(): diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/api/middleware.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/api/middleware.py index 9e78d0d5..96bef76c 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/api/middleware.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/api/middleware.py @@ -23,6 +23,10 @@ class ServiceSecretMiddleware(BaseHTTPMiddleware): def __init__(self, app, secret: str | None = None): super().__init__(app) self.secret = secret or os.environ.get("SERVICE_SECRET", "") + if self.secret: + logger.info("ServiceSecretMiddleware: secret configured (length=%d)", len(self.secret)) + else: + logger.warning("ServiceSecretMiddleware: no secret configured — auth disabled") async def dispatch(self, request: Request, call_next): # Skip auth for health/doc endpoints @@ -36,7 +40,13 @@ async def dispatch(self, request: Request, call_next): provided = request.headers.get("x-service-secret", "") if provided != self.secret: logger.warning( - f"Unauthorized request to {request.url.path} from {request.client.host if request.client else 'unknown'}" + "Unauthorized request to %s from %s — " + "provided_len=%d expected_len=%d match=%s", + request.url.path, + request.client.host if request.client else "unknown", + len(provided), + len(self.secret), + provided == self.secret, ) return JSONResponse( status_code=401, diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py index 18826395..2cffc100 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py @@ -51,7 +51,7 @@ def get_embedding_dim_for_model(model: str) -> int: class RAGConfig(BaseModel): """Configuration for RAG pipeline""" - load_dotenv() + load_dotenv(interpolate=False) # Qdrant for vector storage qdrant_url: str = Field(default_factory=lambda: os.getenv("QDRANT_URL", "http://localhost:6333")) qdrant_collection_prefix: str = Field(default_factory=lambda: os.getenv("QDRANT_COLLECTION_PREFIX", "codecrow")) From c509092c050bd3ce4d646b0d149064c3842ccd4a Mon Sep 17 00:00:00 2001 From: rostislav Date: Mon, 9 Feb 2026 12:26:35 +0200 Subject: [PATCH 7/7] feat: Add service secret 
configuration for RAG API and update tests for header validation --- .../java-shared/application.properties.sample | 1 + .../analysis/BranchAnalysisProcessorTest.java | 1 - .../ragengine/client/RagPipelineClient.java | 56 ++++++++++++------- .../client/RagPipelineClientTest.java | 45 +++++++++++++-- 4 files changed, 77 insertions(+), 26 deletions(-) diff --git a/deployment/config/java-shared/application.properties.sample b/deployment/config/java-shared/application.properties.sample index 765c71cd..3f794130 100644 --- a/deployment/config/java-shared/application.properties.sample +++ b/deployment/config/java-shared/application.properties.sample @@ -232,6 +232,7 @@ llm.sync.anthropic.api-key= #RAG codecrow.rag.api.url=http://host.docker.internal:8001 codecrow.rag.api.enabled=true +codecrow.rag.api.secret=change-me-to-a-random-secret # RAG API timeouts (in seconds) codecrow.rag.api.timeout.connect=30 codecrow.rag.api.timeout.read=120 diff --git a/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/processor/analysis/BranchAnalysisProcessorTest.java b/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/processor/analysis/BranchAnalysisProcessorTest.java index f3b3d5e3..6e7d5bac 100644 --- a/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/processor/analysis/BranchAnalysisProcessorTest.java +++ b/java-ecosystem/libs/analysis-engine/src/test/java/org/rostilos/codecrow/analysisengine/processor/analysis/BranchAnalysisProcessorTest.java @@ -251,7 +251,6 @@ void shouldThrowAnalysisLockedExceptionWhenLockCannotBeAcquired() throws IOExcep when(projectService.getProjectWithConnections(1L)).thenReturn(project); when(project.getId()).thenReturn(1L); - when(project.getName()).thenReturn("Test Project"); when(analysisLockService.acquireLockWithWait(any(), anyString(), any(), anyString(), any(), any())) .thenReturn(Optional.empty()); diff --git a/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/client/RagPipelineClient.java b/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/client/RagPipelineClient.java index efb34d84..44456325 100644 --- a/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/client/RagPipelineClient.java +++ b/java-ecosystem/libs/rag-engine/src/main/java/org/rostilos/codecrow/ragengine/client/RagPipelineClient.java @@ -22,16 +22,19 @@ public class RagPipelineClient { private final ObjectMapper objectMapper; private final String ragApiUrl; private final boolean ragEnabled; + private final String serviceSecret; public RagPipelineClient( @Value("${codecrow.rag.api.url:http://rag-pipeline:8001}") String ragApiUrl, @Value("${codecrow.rag.api.enabled:false}") boolean ragEnabled, @Value("${codecrow.rag.api.timeout.connect:30}") int connectTimeout, @Value("${codecrow.rag.api.timeout.read:120}") int readTimeout, - @Value("${codecrow.rag.api.timeout.indexing:14400}") int indexingTimeout + @Value("${codecrow.rag.api.timeout.indexing:14400}") int indexingTimeout, + @Value("${codecrow.rag.api.secret:}") String serviceSecret ) { this.ragApiUrl = ragApiUrl; this.ragEnabled = ragEnabled; + this.serviceSecret = serviceSecret != null ? 
serviceSecret : ""; this.httpClient = new OkHttpClient.Builder() .connectTimeout(connectTimeout, java.util.concurrent.TimeUnit.SECONDS) @@ -209,10 +212,11 @@ public void deleteIndex(String workspace, String project, String branch) throws } String url = String.format("%s/index/%s/%s/%s", ragApiUrl, workspace, project, branch); - Request request = new Request.Builder() + Request.Builder builder = new Request.Builder() .url(url) - .delete() - .build(); + .delete(); + addAuthHeader(builder); + Request request = builder.build(); try (Response response = httpClient.newCall(request).execute()) { if (!response.isSuccessful()) { @@ -241,10 +245,11 @@ public boolean deleteBranch(String workspace, String project, String branch) thr String encodedBranch = java.net.URLEncoder.encode(branch, java.nio.charset.StandardCharsets.UTF_8); String url = String.format("%s/index/%s/%s/branch/%s", ragApiUrl, workspace, project, encodedBranch); - Request request = new Request.Builder() + Request.Builder builder = new Request.Builder() .url(url) - .delete() - .build(); + .delete(); + addAuthHeader(builder); + Request request = builder.build(); try (Response response = httpClient.newCall(request).execute()) { if (response.isSuccessful()) { @@ -271,10 +276,11 @@ public List getIndexedBranches(String workspace, String project) { try { String url = String.format("%s/index/%s/%s/branches", ragApiUrl, workspace, project); - Request request = new Request.Builder() + Request.Builder builder = new Request.Builder() .url(url) - .get() - .build(); + .get(); + addAuthHeader(builder); + Request request = builder.build(); try (Response response = httpClient.newCall(request).execute()) { if (response.isSuccessful() && response.body() != null) { @@ -311,10 +317,11 @@ public List> getIndexedBranchesWithStats(String workspace, S try { String url = String.format("%s/index/%s/%s/branches", ragApiUrl, workspace, project); - Request request = new Request.Builder() + Request.Builder builder = new Request.Builder() .url(url) - .get() - .build(); + .get(); + addAuthHeader(builder); + Request request = builder.build(); try (Response response = httpClient.newCall(request).execute()) { if (response.isSuccessful() && response.body() != null) { @@ -376,10 +383,11 @@ public boolean isHealthy() { } try { - Request request = new Request.Builder() + Request.Builder builder = new Request.Builder() .url(ragApiUrl + "/health") - .get() - .build(); + .get(); + addAuthHeader(builder); + Request request = builder.build(); try (Response response = httpClient.newCall(request).execute()) { return response.isSuccessful(); @@ -398,15 +406,25 @@ private Map postLongRunning(String url, Map payl return doRequest(url, payload, longRunningHttpClient); } + /** + * Adds the x-service-secret header to the request if a secret is configured. + */ + private void addAuthHeader(Request.Builder builder) { + if (!serviceSecret.isEmpty()) { + builder.addHeader("x-service-secret", serviceSecret); + } + } + @SuppressWarnings("unchecked") private Map doRequest(String url, Map payload, OkHttpClient client) throws IOException { String json = objectMapper.writeValueAsString(payload); RequestBody body = RequestBody.create(json, JSON); - Request request = new Request.Builder() + Request.Builder builder = new Request.Builder() .url(url) - .post(body) - .build(); + .post(body); + addAuthHeader(builder); + Request request = builder.build(); try (Response response = client.newCall(request).execute()) { String responseBody = response.body() != null ? 
response.body().string() : "{}"; diff --git a/java-ecosystem/libs/rag-engine/src/test/java/org/rostilos/codecrow/ragengine/client/RagPipelineClientTest.java b/java-ecosystem/libs/rag-engine/src/test/java/org/rostilos/codecrow/ragengine/client/RagPipelineClientTest.java index 274fc1c5..f2694048 100644 --- a/java-ecosystem/libs/rag-engine/src/test/java/org/rostilos/codecrow/ragengine/client/RagPipelineClientTest.java +++ b/java-ecosystem/libs/rag-engine/src/test/java/org/rostilos/codecrow/ragengine/client/RagPipelineClientTest.java @@ -33,7 +33,8 @@ void setUp() throws IOException { true, // enabled 5, // connect timeout 10, // read timeout - 20 // indexing timeout + 20, // indexing timeout + "test-secret" // service secret ); objectMapper = new ObjectMapper(); @@ -71,7 +72,7 @@ void testDeleteFiles_WhenDisabled() throws Exception { RagPipelineClient disabledClient = new RagPipelineClient( mockWebServer.url("/").toString(), false, // disabled - 5, 10, 20 + 5, 10, 20, "" ); List files = List.of("file1.java"); @@ -123,7 +124,7 @@ void testSemanticSearch_WhenDisabled() throws Exception { RagPipelineClient disabledClient = new RagPipelineClient( mockWebServer.url("/").toString(), false, - 5, 10, 20 + 5, 10, 20, "" ); Map result = disabledClient.semanticSearch( @@ -159,7 +160,7 @@ void testGetPRContext_WhenDisabled() throws Exception { RagPipelineClient disabledClient = new RagPipelineClient( mockWebServer.url("/").toString(), false, - 5, 10, 20 + 5, 10, 20, "" ); Map result = disabledClient.getPRContext( @@ -226,7 +227,7 @@ void testUpdateFiles_WhenDisabled() throws Exception { RagPipelineClient disabledClient = new RagPipelineClient( mockWebServer.url("/").toString(), false, - 5, 10, 20 + 5, 10, 20, "" ); Map result = disabledClient.updateFiles( @@ -243,7 +244,8 @@ void testConstructor_WithDefaults() { true, 30, 120, - 14400 + 14400, + "" ); assertThat(defaultClient).isNotNull(); @@ -268,4 +270,35 @@ void testNetworkError_ThrowsIOException() throws IOException { )) .isInstanceOf(IOException.class); } + + @Test + void testServiceSecretHeader_SentOnRequests() throws Exception { + Map mockResponse = Map.of("status", "success"); + mockWebServer.enqueue(new MockResponse() + .setBody(objectMapper.writeValueAsString(mockResponse)) + .addHeader("Content-Type", "application/json")); + + client.deleteFiles(List.of("file.java"), "ws", "proj", "main"); + + RecordedRequest request = mockWebServer.takeRequest(); + assertThat(request.getHeader("x-service-secret")).isEqualTo("test-secret"); + } + + @Test + void testServiceSecretHeader_NotSentWhenEmpty() throws Exception { + RagPipelineClient noSecretClient = new RagPipelineClient( + mockWebServer.url("/").toString(), + true, 5, 10, 20, "" + ); + + Map mockResponse = Map.of("status", "success"); + mockWebServer.enqueue(new MockResponse() + .setBody(objectMapper.writeValueAsString(mockResponse)) + .addHeader("Content-Type", "application/json")); + + noSecretClient.deleteFiles(List.of("file.java"), "ws", "proj", "main"); + + RecordedRequest request = mockWebServer.takeRequest(); + assertThat(request.getHeader("x-service-secret")).isNull(); + } }
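
Editor's note on the content-based cache introduced in this series: the patches add the diff_fingerprint column (V1.4.0), the findTopByProjectIdAndDiffFingerprint lookup, and cloneAnalysisForPr, but the fingerprint computation itself is not part of the hunks shown here. The sketch below is only an illustration of one way such a fingerprint could be derived, namely SHA-256 over normalized per-file patches, which yields exactly the 64 hex characters the VARCHAR(64) column allows (assumes Java 17+ for HexFormat). The DiffFingerprint class, its compute helper, and the normalization rules are assumptions for illustration, not code from this patch.

    import java.nio.charset.StandardCharsets;
    import java.security.MessageDigest;
    import java.security.NoSuchAlgorithmException;
    import java.util.HexFormat;
    import java.util.List;
    import java.util.stream.Collectors;

    // Hypothetical sketch; the real fingerprint computation is outside this patch.
    final class DiffFingerprint {

        private DiffFingerprint() {
        }

        /**
         * Builds a content-based fingerprint from the per-file unified diffs of a PR.
         * Hunk headers and index lines are stripped so the same logical change hashes
         * identically even when line offsets or blob ids differ between PRs.
         */
        static String compute(List<String> filePatches) {
            try {
                MessageDigest digest = MessageDigest.getInstance("SHA-256");
                filePatches.stream()
                        .sorted()                                  // file order must not affect the hash
                        .map(DiffFingerprint::normalize)
                        .forEach(p -> digest.update(p.getBytes(StandardCharsets.UTF_8)));
                return HexFormat.of().formatHex(digest.digest());  // 64 hex chars, fits VARCHAR(64)
            } catch (NoSuchAlgorithmException e) {
                throw new IllegalStateException("SHA-256 not available", e);
            }
        }

        private static String normalize(String patch) {
            return patch.lines()
                    .filter(line -> !line.startsWith("@@"))        // drop hunk headers (line offsets)
                    .filter(line -> !line.startsWith("index "))    // drop blob hashes
                    .collect(Collectors.joining("\n"));
        }
    }

Under that assumption, a webhook handler would compute the fingerprint once per PR event, try getAnalysisByDiffFingerprint(projectId, fingerprint) and then getAnalysisByCommitHash before starting a fresh analysis, and call cloneAnalysisForPr when the hit belongs to a different PR, which mirrors the lookup order the new service methods suggest.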