From 8fc46f3bd4a3bdc4e74ebb99d80ec448ba1eb25b Mon Sep 17 00:00:00 2001 From: rostislav Date: Mon, 26 Jan 2026 15:12:42 +0200 Subject: [PATCH 01/20] Refactor AI connection handling and improve job deletion logic - Updated JobService to use REQUIRES_NEW transaction propagation for deleting ignored jobs, ensuring fresh entity retrieval and preventing issues with the calling transaction. - Removed token limitation from AI connection model and related DTOs, transitioning to project-level configuration for token limits. - Adjusted AIConnectionDTO tests to reflect the removal of token limitation. - Enhanced Bitbucket, GitHub, and GitLab AI client services to check token limits before analysis, throwing DiffTooLargeException when limits are exceeded. - Updated command processors to utilize project-level token limits instead of AI connection-specific limits. - Modified webhook processing to handle diff size issues gracefully, posting informative messages to VCS when analysis is skipped due to large diffs. - Cleaned up integration tests to remove references to token limitation in AI connection creation and updates. --- java-ecosystem/libs/analysis-engine/pom.xml | 6 ++ .../src/main/java/module-info.java | 1 + .../exception/DiffTooLargeException.java | 47 ++++++++++ .../analysisengine/util/TokenEstimator.java | 83 +++++++++++++++++ .../codecrow/core/dto/ai/AIConnectionDTO.java | 6 +- .../codecrow/core/model/ai/AIConnection.java | 11 --- .../model/project/config/ProjectConfig.java | 34 ++++++- .../codecrow/core/service/JobService.java | 19 +++- ...ve_token_limitation_from_ai_connection.sql | 5 ++ .../core/dto/ai/AIConnectionDTOTest.java | 35 ++------ .../core/model/ai/AIConnectionTest.java | 15 ---- .../service/BitbucketAiClientService.java | 23 ++++- .../processor/WebhookAsyncProcessor.java | 89 ++++++++++++++++++- .../command/AskCommandProcessor.java | 2 +- .../command/ReviewCommandProcessor.java | 2 +- .../command/SummarizeCommandProcessor.java | 2 +- .../github/service/GitHubAiClientService.java | 23 ++++- .../gitlab/service/GitLabAiClientService.java | 23 ++++- .../GitLabMergeRequestWebhookHandler.java | 27 +++++- .../request/CreateAIConnectionRequest.java | 2 - .../request/UpdateAiConnectionRequest.java | 1 - .../ai/service/AIConnectionService.java | 4 - .../integration/ai/AIConnectionCrudIT.java | 34 +++---- .../integration/auth/UserAuthFlowIT.java | 3 +- .../builder/AIConnectionBuilder.java | 7 -- .../integration/util/AuthTestHelper.java | 1 - 26 files changed, 393 insertions(+), 112 deletions(-) create mode 100644 java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/exception/DiffTooLargeException.java create mode 100644 java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/util/TokenEstimator.java create mode 100644 java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.0__remove_token_limitation_from_ai_connection.sql diff --git a/java-ecosystem/libs/analysis-engine/pom.xml b/java-ecosystem/libs/analysis-engine/pom.xml index 4d16d658..3fba0c89 100644 --- a/java-ecosystem/libs/analysis-engine/pom.xml +++ b/java-ecosystem/libs/analysis-engine/pom.xml @@ -68,6 +68,12 @@ okhttp + + + com.knuddels + jtokkit + + org.junit.jupiter diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/module-info.java b/java-ecosystem/libs/analysis-engine/src/main/java/module-info.java index b3e30345..d03cdf59 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/module-info.java +++ 
b/java-ecosystem/libs/analysis-engine/src/main/java/module-info.java @@ -18,6 +18,7 @@ requires com.fasterxml.jackson.annotation; requires jakarta.persistence; requires kotlin.stdlib; + requires jtokkit; exports org.rostilos.codecrow.analysisengine.aiclient; exports org.rostilos.codecrow.analysisengine.config; diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/exception/DiffTooLargeException.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/exception/DiffTooLargeException.java new file mode 100644 index 00000000..7304448c --- /dev/null +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/exception/DiffTooLargeException.java @@ -0,0 +1,47 @@ +package org.rostilos.codecrow.analysisengine.exception; + +/** + * Exception thrown when a diff exceeds the configured token limit for analysis. + * This is a soft skip - the analysis is not performed but the job is not marked as failed. + */ +public class DiffTooLargeException extends RuntimeException { + + private final int estimatedTokens; + private final int maxAllowedTokens; + private final Long projectId; + private final Long pullRequestId; + + public DiffTooLargeException(int estimatedTokens, int maxAllowedTokens, Long projectId, Long pullRequestId) { + super(String.format( + "PR diff exceeds token limit: estimated %d tokens, max allowed %d tokens (project=%d, PR=%d)", + estimatedTokens, maxAllowedTokens, projectId, pullRequestId + )); + this.estimatedTokens = estimatedTokens; + this.maxAllowedTokens = maxAllowedTokens; + this.projectId = projectId; + this.pullRequestId = pullRequestId; + } + + public int getEstimatedTokens() { + return estimatedTokens; + } + + public int getMaxAllowedTokens() { + return maxAllowedTokens; + } + + public Long getProjectId() { + return projectId; + } + + public Long getPullRequestId() { + return pullRequestId; + } + + /** + * Returns the percentage of the token limit that would be used. + */ + public double getUtilizationPercentage() { + return maxAllowedTokens > 0 ? (estimatedTokens * 100.0 / maxAllowedTokens) : 0; + } +} diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/util/TokenEstimator.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/util/TokenEstimator.java new file mode 100644 index 00000000..4ccb613e --- /dev/null +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/util/TokenEstimator.java @@ -0,0 +1,83 @@ +package org.rostilos.codecrow.analysisengine.util; + +import com.knuddels.jtokkit.Encodings; +import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.EncodingRegistry; +import com.knuddels.jtokkit.api.EncodingType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Utility class for estimating token counts in text content. + * Uses the cl100k_base encoding (used by GPT-4, Claude, and most modern LLMs). + */ +public class TokenEstimator { + private static final Logger log = LoggerFactory.getLogger(TokenEstimator.class); + + private static final EncodingRegistry ENCODING_REGISTRY = Encodings.newDefaultEncodingRegistry(); + private static final Encoding ENCODING = ENCODING_REGISTRY.getEncoding(EncodingType.CL100K_BASE); + + /** + * Estimate the number of tokens in the given text. 
+ * + * @param text The text to estimate tokens for + * @return The estimated token count, or 0 if text is null/empty + */ + public static int estimateTokens(String text) { + if (text == null || text.isEmpty()) { + return 0; + } + try { + return ENCODING.countTokens(text); + } catch (Exception e) { + log.warn("Failed to count tokens, using fallback estimation: {}", e.getMessage()); + // Fallback: rough estimate of ~4 characters per token + return text.length() / 4; + } + } + + /** + * Check if the estimated token count exceeds the given limit. + * + * @param text The text to check + * @param maxTokens The maximum allowed tokens + * @return true if the text exceeds the limit, false otherwise + */ + public static boolean exceedsLimit(String text, int maxTokens) { + return estimateTokens(text) > maxTokens; + } + + /** + * Result of a token estimation check with details. + */ + public record TokenEstimationResult( + int estimatedTokens, + int maxAllowedTokens, + boolean exceedsLimit, + double utilizationPercentage + ) { + public String toLogString() { + return String.format("Tokens: %d / %d (%.1f%%) - %s", + estimatedTokens, maxAllowedTokens, utilizationPercentage, + exceedsLimit ? "EXCEEDS LIMIT" : "within limit"); + } + } + + /** + * Estimate tokens and check against limit, returning detailed result. + * + * @param text The text to check + * @param maxTokens The maximum allowed tokens + * @return Detailed estimation result + */ + public static TokenEstimationResult estimateAndCheck(String text, int maxTokens) { + int estimated = estimateTokens(text); + double utilization = maxTokens > 0 ? (estimated * 100.0 / maxTokens) : 0; + return new TokenEstimationResult( + estimated, + maxTokens, + estimated > maxTokens, + utilization + ); + } +} diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTO.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTO.java index b9e16fb5..b04b4434 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTO.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTO.java @@ -11,8 +11,7 @@ public record AIConnectionDTO( AIProviderKey providerKey, String aiModel, OffsetDateTime createdAt, - OffsetDateTime updatedAt, - int tokenLimitation + OffsetDateTime updatedAt ) { public static AIConnectionDTO fromAiConnection(AIConnection aiConnection) { @@ -22,8 +21,7 @@ public static AIConnectionDTO fromAiConnection(AIConnection aiConnection) { aiConnection.getProviderKey(), aiConnection.getAiModel(), aiConnection.getCreatedAt(), - aiConnection.getUpdatedAt(), - aiConnection.getTokenLimitation() + aiConnection.getUpdatedAt() ); } } diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/ai/AIConnection.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/ai/AIConnection.java index 2ca682c3..f6558f75 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/ai/AIConnection.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/ai/AIConnection.java @@ -39,9 +39,6 @@ public class AIConnection { @Column(name = "updated_at", nullable = false) private OffsetDateTime updatedAt = OffsetDateTime.now(); - @Column(name= "token_limitation", nullable = false) - private int tokenLimitation = 100000; - @PreUpdate public void onUpdate() { this.updatedAt = OffsetDateTime.now(); @@ -98,12 +95,4 @@ public OffsetDateTime 
getCreatedAt() { public OffsetDateTime getUpdatedAt() { return updatedAt; } - - public void setTokenLimitation(int tokenLimitation) { - this.tokenLimitation = tokenLimitation; - } - - public int getTokenLimitation() { - return tokenLimitation; - } } diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/config/ProjectConfig.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/config/ProjectConfig.java index 66335185..99d18764 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/config/ProjectConfig.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/config/ProjectConfig.java @@ -24,6 +24,8 @@ * - branchAnalysisEnabled: whether to analyze branch pushes (default: true). * - installationMethod: how the project integration is installed (WEBHOOK, PIPELINE, GITHUB_ACTION). * - commentCommands: configuration for PR comment-triggered commands (/codecrow analyze, summarize, ask). + * - maxAnalysisTokenLimit: maximum allowed tokens for PR analysis (default: 200000). + * Analysis will be skipped if the diff exceeds this limit. * * @see BranchAnalysisConfig * @see RagConfig @@ -32,6 +34,8 @@ */ @JsonIgnoreProperties(ignoreUnknown = true) public class ProjectConfig { + public static final int DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT = 200000; + @JsonProperty("useLocalMcp") private boolean useLocalMcp; @@ -56,16 +60,27 @@ public class ProjectConfig { private InstallationMethod installationMethod; @JsonProperty("commentCommands") private CommentCommandsConfig commentCommands; + @JsonProperty("maxAnalysisTokenLimit") + private Integer maxAnalysisTokenLimit; public ProjectConfig() { this.useLocalMcp = false; this.prAnalysisEnabled = true; this.branchAnalysisEnabled = true; + this.maxAnalysisTokenLimit = DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT; } public ProjectConfig(boolean useLocalMcp, String mainBranch, BranchAnalysisConfig branchAnalysis, RagConfig ragConfig, Boolean prAnalysisEnabled, Boolean branchAnalysisEnabled, InstallationMethod installationMethod, CommentCommandsConfig commentCommands) { + this(useLocalMcp, mainBranch, branchAnalysis, ragConfig, prAnalysisEnabled, branchAnalysisEnabled, + installationMethod, commentCommands, DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT); + } + + public ProjectConfig(boolean useLocalMcp, String mainBranch, BranchAnalysisConfig branchAnalysis, + RagConfig ragConfig, Boolean prAnalysisEnabled, Boolean branchAnalysisEnabled, + InstallationMethod installationMethod, CommentCommandsConfig commentCommands, + Integer maxAnalysisTokenLimit) { this.useLocalMcp = useLocalMcp; this.mainBranch = mainBranch; this.defaultBranch = mainBranch; // Keep in sync for backward compatibility @@ -75,6 +90,7 @@ public ProjectConfig(boolean useLocalMcp, String mainBranch, BranchAnalysisConfi this.branchAnalysisEnabled = branchAnalysisEnabled; this.installationMethod = installationMethod; this.commentCommands = commentCommands; + this.maxAnalysisTokenLimit = maxAnalysisTokenLimit != null ? maxAnalysisTokenLimit : DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT; } public ProjectConfig(boolean useLocalMcp, String mainBranch) { @@ -112,6 +128,14 @@ public String defaultBranch() { public InstallationMethod installationMethod() { return installationMethod; } public CommentCommandsConfig commentCommands() { return commentCommands; } + /** + * Get the maximum token limit for PR analysis. + * Returns the configured value or the default (200000) if not set. 
+ */ + public int maxAnalysisTokenLimit() { + return maxAnalysisTokenLimit != null ? maxAnalysisTokenLimit : DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT; + } + // Setters for Jackson public void setUseLocalMcp(boolean useLocalMcp) { this.useLocalMcp = useLocalMcp; } @@ -149,6 +173,9 @@ public void setDefaultBranch(String defaultBranch) { public void setBranchAnalysisEnabled(Boolean branchAnalysisEnabled) { this.branchAnalysisEnabled = branchAnalysisEnabled; } public void setInstallationMethod(InstallationMethod installationMethod) { this.installationMethod = installationMethod; } public void setCommentCommands(CommentCommandsConfig commentCommands) { this.commentCommands = commentCommands; } + public void setMaxAnalysisTokenLimit(Integer maxAnalysisTokenLimit) { + this.maxAnalysisTokenLimit = maxAnalysisTokenLimit != null ? maxAnalysisTokenLimit : DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT; + } public void ensureMainBranchInPatterns() { String main = mainBranch(); @@ -230,13 +257,15 @@ public boolean equals(Object o) { Objects.equals(prAnalysisEnabled, that.prAnalysisEnabled) && Objects.equals(branchAnalysisEnabled, that.branchAnalysisEnabled) && installationMethod == that.installationMethod && - Objects.equals(commentCommands, that.commentCommands); + Objects.equals(commentCommands, that.commentCommands) && + Objects.equals(maxAnalysisTokenLimit, that.maxAnalysisTokenLimit); } @Override public int hashCode() { return Objects.hash(useLocalMcp, mainBranch, branchAnalysis, ragConfig, - prAnalysisEnabled, branchAnalysisEnabled, installationMethod, commentCommands); + prAnalysisEnabled, branchAnalysisEnabled, installationMethod, + commentCommands, maxAnalysisTokenLimit); } @Override @@ -250,6 +279,7 @@ public String toString() { ", branchAnalysisEnabled=" + branchAnalysisEnabled + ", installationMethod=" + installationMethod + ", commentCommands=" + commentCommands + + ", maxAnalysisTokenLimit=" + maxAnalysisTokenLimit + '}'; } } diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index f239b9c1..04036c47 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -311,14 +311,27 @@ public Job skipJob(Job job, String reason) { * Used for jobs that were created but then determined to be unnecessary * (e.g., branch not matching pattern, PR analysis disabled). * This prevents DB clutter from ignored webhooks. + * Uses REQUIRES_NEW to ensure this runs in its own transaction, + * allowing it to work even if the calling transaction has issues. 
      */
-    @Transactional
+    @Transactional(propagation = org.springframework.transaction.annotation.Propagation.REQUIRES_NEW)
     public void deleteIgnoredJob(Job job, String reason) {
         log.info("Deleting ignored job {} ({}): {}", job.getExternalId(), job.getJobType(), reason);
+        // Re-fetch the job to ensure we have a fresh entity in this new transaction
+        Long jobId = job.getId();
+        if (jobId == null) {
+            log.warn("Cannot delete ignored job - job ID is null");
+            return;
+        }
+        Optional<Job> existingJob = jobRepository.findById(jobId);
+        if (existingJob.isEmpty()) {
+            log.warn("Cannot delete ignored job {} - not found in database", job.getExternalId());
+            return;
+        }
         // Delete any logs first (foreign key constraint)
-        jobLogRepository.deleteByJobId(job.getId());
+        jobLogRepository.deleteByJobId(jobId);
         // Delete the job
-        jobRepository.delete(job);
+        jobRepository.delete(existingJob.get());
     }
 
     /**
diff --git a/java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.0__remove_token_limitation_from_ai_connection.sql b/java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.0__remove_token_limitation_from_ai_connection.sql
new file mode 100644
index 00000000..fcd0faaf
--- /dev/null
+++ b/java-ecosystem/libs/core/src/main/resources/db/migration/1.4.0/V1.4.0__remove_token_limitation_from_ai_connection.sql
@@ -0,0 +1,5 @@
+-- Remove token_limitation column from ai_connection table
+-- Token limitation is now configured per-project in the project configuration JSON
+-- Default value is 200000 tokens, configured in ProjectConfig.maxAnalysisTokenLimit
+
+ALTER TABLE ai_connection DROP COLUMN IF EXISTS token_limitation;
diff --git a/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTOTest.java b/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTOTest.java
index 5a8af6ef..7536258c 100644
--- a/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTOTest.java
+++ b/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/ai/AIConnectionDTOTest.java
@@ -24,7 +24,7 @@ void shouldCreateWithAllFields() {
             OffsetDateTime now = OffsetDateTime.now();
             AIConnectionDTO dto = new AIConnectionDTO(
                     1L, "Test Connection", AIProviderKey.ANTHROPIC, "claude-3-opus",
-                    now, now, 100000
+                    now, now
             );
 
             assertThat(dto.id()).isEqualTo(1L);
@@ -33,14 +33,13 @@ void shouldCreateWithAllFields() {
             assertThat(dto.aiModel()).isEqualTo("claude-3-opus");
             assertThat(dto.createdAt()).isEqualTo(now);
             assertThat(dto.updatedAt()).isEqualTo(now);
-            assertThat(dto.tokenLimitation()).isEqualTo(100000);
         }
 
         @Test
         @DisplayName("should create AIConnectionDTO with null optional fields")
         void shouldCreateWithNullOptionalFields() {
             AIConnectionDTO dto = new AIConnectionDTO(
-                    1L, null, AIProviderKey.OPENAI, null, null, null, 50000
+                    1L, null, AIProviderKey.OPENAI, null, null, null
             );
 
             assertThat(dto.id()).isEqualTo(1L);
@@ -53,24 +52,14 @@ void shouldCreateWithNullOptionalFields() {
         @Test
         @DisplayName("should create AIConnectionDTO with different providers")
         void shouldCreateWithDifferentProviders() {
-            AIConnectionDTO openai = new AIConnectionDTO(1L, "OpenAI", AIProviderKey.OPENAI, "gpt-4", null, null, 100000);
-            AIConnectionDTO anthropic = new AIConnectionDTO(2L, "Anthropic", AIProviderKey.ANTHROPIC, "claude-3", null, null, 200000);
-            AIConnectionDTO google = new AIConnectionDTO(3L, "Google", AIProviderKey.GOOGLE, "gemini-pro", null, null, 150000);
+            AIConnectionDTO openai = new AIConnectionDTO(1L, "OpenAI",
AIProviderKey.OPENAI, "gpt-4", null, null); + AIConnectionDTO anthropic = new AIConnectionDTO(2L, "Anthropic", AIProviderKey.ANTHROPIC, "claude-3", null, null); + AIConnectionDTO google = new AIConnectionDTO(3L, "Google", AIProviderKey.GOOGLE, "gemini-pro", null, null); assertThat(openai.providerKey()).isEqualTo(AIProviderKey.OPENAI); assertThat(anthropic.providerKey()).isEqualTo(AIProviderKey.ANTHROPIC); assertThat(google.providerKey()).isEqualTo(AIProviderKey.GOOGLE); } - - @Test - @DisplayName("should support different token limitations") - void shouldSupportDifferentTokenLimitations() { - AIConnectionDTO small = new AIConnectionDTO(1L, "Small", AIProviderKey.OPENAI, "gpt-3.5", null, null, 16000); - AIConnectionDTO large = new AIConnectionDTO(2L, "Large", AIProviderKey.ANTHROPIC, "claude-3", null, null, 200000); - - assertThat(small.tokenLimitation()).isEqualTo(16000); - assertThat(large.tokenLimitation()).isEqualTo(200000); - } } @Nested @@ -85,7 +74,6 @@ void shouldConvertWithAllFields() { connection.setName("Production AI"); setField(connection, "providerKey", AIProviderKey.ANTHROPIC); setField(connection, "aiModel", "claude-3-opus"); - setField(connection, "tokenLimitation", 100000); AIConnectionDTO dto = AIConnectionDTO.fromAiConnection(connection); @@ -93,7 +81,6 @@ void shouldConvertWithAllFields() { assertThat(dto.name()).isEqualTo("Production AI"); assertThat(dto.providerKey()).isEqualTo(AIProviderKey.ANTHROPIC); assertThat(dto.aiModel()).isEqualTo("claude-3-opus"); - assertThat(dto.tokenLimitation()).isEqualTo(100000); } @Test @@ -104,7 +91,6 @@ void shouldConvertWithNullName() { connection.setName(null); setField(connection, "providerKey", AIProviderKey.OPENAI); setField(connection, "aiModel", "gpt-4"); - setField(connection, "tokenLimitation", 50000); AIConnectionDTO dto = AIConnectionDTO.fromAiConnection(connection); @@ -120,7 +106,6 @@ void shouldConvertWithNullModel() { connection.setName("Test"); setField(connection, "providerKey", AIProviderKey.GOOGLE); setField(connection, "aiModel", null); - setField(connection, "tokenLimitation", 75000); AIConnectionDTO dto = AIConnectionDTO.fromAiConnection(connection); @@ -134,7 +119,6 @@ void shouldConvertAllProviderTypes() { AIConnection connection = new AIConnection(); setField(connection, "id", 1L); setField(connection, "providerKey", providerKey); - setField(connection, "tokenLimitation", 100000); AIConnectionDTO dto = AIConnectionDTO.fromAiConnection(connection); @@ -149,7 +133,6 @@ void shouldHandleTimestamps() { setField(connection, "id", 1L); connection.setName("Test"); setField(connection, "providerKey", AIProviderKey.ANTHROPIC); - setField(connection, "tokenLimitation", 100000); AIConnectionDTO dto = AIConnectionDTO.fromAiConnection(connection); @@ -165,8 +148,8 @@ class EqualityTests { @DisplayName("should be equal for same values") void shouldBeEqualForSameValues() { OffsetDateTime now = OffsetDateTime.now(); - AIConnectionDTO dto1 = new AIConnectionDTO(1L, "Test", AIProviderKey.OPENAI, "gpt-4", now, now, 100000); - AIConnectionDTO dto2 = new AIConnectionDTO(1L, "Test", AIProviderKey.OPENAI, "gpt-4", now, now, 100000); + AIConnectionDTO dto1 = new AIConnectionDTO(1L, "Test", AIProviderKey.OPENAI, "gpt-4", now, now); + AIConnectionDTO dto2 = new AIConnectionDTO(1L, "Test", AIProviderKey.OPENAI, "gpt-4", now, now); assertThat(dto1).isEqualTo(dto2); assertThat(dto1.hashCode()).isEqualTo(dto2.hashCode()); @@ -176,8 +159,8 @@ void shouldBeEqualForSameValues() { @DisplayName("should not be equal for different values") void 
shouldNotBeEqualForDifferentValues() { OffsetDateTime now = OffsetDateTime.now(); - AIConnectionDTO dto1 = new AIConnectionDTO(1L, "Test1", AIProviderKey.OPENAI, "gpt-4", now, now, 100000); - AIConnectionDTO dto2 = new AIConnectionDTO(2L, "Test2", AIProviderKey.ANTHROPIC, "claude", now, now, 200000); + AIConnectionDTO dto1 = new AIConnectionDTO(1L, "Test1", AIProviderKey.OPENAI, "gpt-4", now, now); + AIConnectionDTO dto2 = new AIConnectionDTO(2L, "Test2", AIProviderKey.ANTHROPIC, "claude", now, now); assertThat(dto1).isNotEqualTo(dto2); } diff --git a/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/model/ai/AIConnectionTest.java b/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/model/ai/AIConnectionTest.java index d2dd3d4d..6adcfaa6 100644 --- a/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/model/ai/AIConnectionTest.java +++ b/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/model/ai/AIConnectionTest.java @@ -67,25 +67,12 @@ void shouldSetAndGetApiKeyEncrypted() { aiConnection.setApiKeyEncrypted("encrypted-api-key-xyz"); assertThat(aiConnection.getApiKeyEncrypted()).isEqualTo("encrypted-api-key-xyz"); } - - @Test - @DisplayName("Should set and get tokenLimitation") - void shouldSetAndGetTokenLimitation() { - aiConnection.setTokenLimitation(50000); - assertThat(aiConnection.getTokenLimitation()).isEqualTo(50000); - } } @Nested @DisplayName("Default value tests") class DefaultValueTests { - @Test - @DisplayName("Default tokenLimitation should be 100000") - void defaultTokenLimitationShouldBe100000() { - assertThat(aiConnection.getTokenLimitation()).isEqualTo(100000); - } - @Test @DisplayName("Id should be null for new entity") void idShouldBeNullForNewEntity() { @@ -154,14 +141,12 @@ void shouldBeAbleToUpdateAllFields() { aiConnection.setProviderKey(AIProviderKey.ANTHROPIC); aiConnection.setAiModel("claude-3-opus"); aiConnection.setApiKeyEncrypted("new-encrypted-key"); - aiConnection.setTokenLimitation(200000); assertThat(aiConnection.getName()).isEqualTo("Updated Name"); assertThat(aiConnection.getWorkspace()).isSameAs(workspace); assertThat(aiConnection.getProviderKey()).isEqualTo(AIProviderKey.ANTHROPIC); assertThat(aiConnection.getAiModel()).isEqualTo("claude-3-opus"); assertThat(aiConnection.getApiKeyEncrypted()).isEqualTo("new-encrypted-key"); - assertThat(aiConnection.getTokenLimitation()).isEqualTo(200000); } @Test diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java index 0094252d..142cba75 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java @@ -13,9 +13,11 @@ import org.rostilos.codecrow.analysisengine.dto.request.processor.BranchProcessRequest; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; import org.rostilos.codecrow.analysisengine.dto.request.ai.AiAnalysisRequest; +import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.service.vcs.VcsAiClientService; import org.rostilos.codecrow.analysisengine.util.DiffContentFilter; import 
org.rostilos.codecrow.analysisengine.util.DiffParser; +import org.rostilos.codecrow.analysisengine.util.TokenEstimator; import org.rostilos.codecrow.security.oauth.TokenEncryptionService; import org.rostilos.codecrow.vcsclient.VcsClientProvider; import org.rostilos.codecrow.vcsclient.bitbucket.cloud.actions.GetCommitRangeDiffAction; @@ -172,6 +174,23 @@ public AiAnalysisRequest buildPrAnalysisRequest( originalSize > 0 ? (100 - (filteredSize * 100 / originalSize)) : 0); } + // Check token limit before proceeding with analysis + int maxTokenLimit = project.getEffectiveConfig().maxAnalysisTokenLimit(); + TokenEstimator.TokenEstimationResult tokenEstimate = TokenEstimator.estimateAndCheck(rawDiff, maxTokenLimit); + log.info("Token estimation for PR diff: {}", tokenEstimate.toLogString()); + + if (tokenEstimate.exceedsLimit()) { + log.warn("PR diff exceeds token limit - skipping analysis. Project={}, PR={}, Tokens={}/{}", + project.getId(), request.getPullRequestId(), + tokenEstimate.estimatedTokens(), tokenEstimate.maxAllowedTokens()); + throw new DiffTooLargeException( + tokenEstimate.estimatedTokens(), + tokenEstimate.maxAllowedTokens(), + project.getId(), + request.getPullRequestId() + ); + } + // Determine analysis mode: INCREMENTAL if we have previous analysis with different commit boolean canUseIncremental = previousAnalysis.isPresent() && previousCommitHash != null @@ -232,7 +251,7 @@ public AiAnalysisRequest buildPrAnalysisRequest( .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(projectAiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) .withAllPrAnalysesData(allPrAnalyses) // Use full PR history instead of just previous version - .withMaxAllowedTokens(aiConnection.getTokenLimitation()) + .withMaxAllowedTokens(project.getEffectiveConfig().maxAnalysisTokenLimit()) .withAnalysisType(request.getAnalysisType()) .withPrTitle(prTitle) .withPrDescription(prDescription) @@ -303,7 +322,7 @@ public AiAnalysisRequest buildBranchAnalysisRequest( .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(projectAiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) .withPreviousAnalysisData(previousAnalysis) - .withMaxAllowedTokens(aiConnection.getTokenLimitation()) + .withMaxAllowedTokens(project.getEffectiveConfig().maxAnalysisTokenLimit()) .withAnalysisType(request.getAnalysisType()) .withTargetBranchName(request.getTargetBranchName()) .withCurrentCommitHash(request.getCommitHash()) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index a216316a..039eb86c 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -7,6 +7,7 @@ import org.rostilos.codecrow.core.service.JobService; import org.rostilos.codecrow.pipelineagent.generic.dto.webhook.WebhookPayload; import org.rostilos.codecrow.pipelineagent.generic.webhookhandler.WebhookHandler; +import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.service.vcs.VcsReportingService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsServiceFactory; import org.slf4j.Logger; @@ -127,7 
+128,18 @@ public void processWebhookAsync( deletePlaceholderComment(provider, project, payload, finalPlaceholderCommentId); } // Delete the job entirely - don't clutter DB with ignored webhooks - jobService.deleteIgnoredJob(job, result.message()); + // If deletion fails, skip the job instead + try { + jobService.deleteIgnoredJob(job, result.message()); + } catch (Exception deleteError) { + log.warn("Failed to delete ignored job {}, skipping instead: {}", + job.getExternalId(), deleteError.getMessage()); + try { + jobService.skipJob(job, result.message()); + } catch (Exception skipError) { + log.error("Failed to skip job {}: {}", job.getExternalId(), skipError.getMessage()); + } + } return; } @@ -151,6 +163,42 @@ public void processWebhookAsync( jobService.failJob(job, result.message()); } + } catch (DiffTooLargeException diffEx) { + // Handle diff too large - this is a soft skip, not an error + log.warn("Diff too large for analysis - skipping: {}", diffEx.getMessage()); + + String skipMessage = String.format( + "⚠️ **Analysis Skipped - PR Too Large**\n\n" + + "This PR's diff exceeds the configured token limit:\n" + + "- **Estimated tokens:** %,d\n" + + "- **Maximum allowed:** %,d (%.1f%% of limit)\n\n" + + "To analyze this PR, consider:\n" + + "1. Breaking it into smaller PRs\n" + + "2. Increasing the token limit in project settings\n" + + "3. Using `/codecrow analyze` command on specific commits", + diffEx.getEstimatedTokens(), + diffEx.getMaxAllowedTokens(), + diffEx.getUtilizationPercentage() + ); + + try { + if (project == null) { + project = projectRepository.findById(projectId).orElse(null); + } + if (project != null) { + initializeProjectAssociations(project); + postInfoToVcs(provider, project, payload, skipMessage, placeholderCommentId, job); + } + } catch (Exception postError) { + log.error("Failed to post skip message to VCS: {}", postError.getMessage()); + } + + try { + jobService.skipJob(job, "Diff too large: " + diffEx.getEstimatedTokens() + " tokens > " + diffEx.getMaxAllowedTokens() + " limit"); + } catch (Exception skipError) { + log.error("Failed to skip job: {}", skipError.getMessage()); + } + } catch (Exception e) { log.error("Error processing webhook for job {}", job.getExternalId(), e); @@ -390,6 +438,45 @@ private void postErrorToVcs(EVcsProvider provider, Project project, WebhookPaylo } } + /** + * Post an info message to VCS as a comment (for skipped/info scenarios). + * If placeholderCommentId is provided, update that comment with the info. 
+ */ + private void postInfoToVcs(EVcsProvider provider, Project project, WebhookPayload payload, + String infoMessage, String placeholderCommentId, Job job) { + try { + if (payload.pullRequestId() == null) { + return; + } + + VcsReportingService reportingService = vcsServiceFactory.getReportingService(provider); + + // If we have a placeholder comment, update it with the info + if (placeholderCommentId != null) { + reportingService.updateComment( + project, + Long.parseLong(payload.pullRequestId()), + placeholderCommentId, + infoMessage, + CODECROW_COMMAND_MARKER + ); + log.info("Updated placeholder comment {} with info message for PR {}", placeholderCommentId, payload.pullRequestId()); + } else { + // No placeholder - post new info comment + reportingService.postComment( + project, + Long.parseLong(payload.pullRequestId()), + infoMessage, + CODECROW_COMMAND_MARKER + ); + log.info("Posted info message to PR {}", payload.pullRequestId()); + } + + } catch (Exception e) { + log.error("Failed to post info to VCS: {}", e.getMessage()); + } + } + /** * Sanitize error messages for display on VCS platforms. * Removes sensitive technical details like API keys, quotas, and internal stack traces. diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/AskCommandProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/AskCommandProcessor.java index b8ceb92f..68c5f87d 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/AskCommandProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/AskCommandProcessor.java @@ -375,7 +375,7 @@ private AskRequest buildAskRequest( credentials.oAuthClient(), credentials.oAuthSecret(), credentials.accessToken(), - aiConnection.getTokenLimitation(), + project.getEffectiveConfig().maxAnalysisTokenLimit(), credentials.vcsProviderString(), analysisContext, context.issueReferences() diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java index 78a8dcb2..8d381e79 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java @@ -180,7 +180,7 @@ private ReviewRequest buildReviewRequest(Project project, WebhookPayload payload credentials.oAuthClient(), credentials.oAuthSecret(), credentials.accessToken(), - aiConnection.getTokenLimitation(), + project.getEffectiveConfig().maxAnalysisTokenLimit(), credentials.vcsProviderString() ); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/SummarizeCommandProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/SummarizeCommandProcessor.java index c2312413..8e3f992b 100644 --- 
a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/SummarizeCommandProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/SummarizeCommandProcessor.java @@ -292,7 +292,7 @@ private SummarizeRequest buildSummarizeRequest( credentials.oAuthSecret(), credentials.accessToken(), diagramType == PrSummarizeCache.DiagramType.MERMAID, - aiConnection.getTokenLimitation(), + project.getEffectiveConfig().maxAnalysisTokenLimit(), credentials.vcsProviderString() ); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java index b39202de..f8f0dc09 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java @@ -13,9 +13,11 @@ import org.rostilos.codecrow.analysisengine.dto.request.processor.AnalysisProcessRequest; import org.rostilos.codecrow.analysisengine.dto.request.processor.BranchProcessRequest; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; +import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.service.vcs.VcsAiClientService; import org.rostilos.codecrow.analysisengine.util.DiffContentFilter; import org.rostilos.codecrow.analysisengine.util.DiffParser; +import org.rostilos.codecrow.analysisengine.util.TokenEstimator; import org.rostilos.codecrow.security.oauth.TokenEncryptionService; import org.rostilos.codecrow.vcsclient.VcsClientProvider; import org.rostilos.codecrow.vcsclient.github.actions.GetCommitRangeDiffAction; @@ -165,6 +167,23 @@ private AiAnalysisRequest buildPrAnalysisRequest( originalSize > 0 ? (100 - (filteredSize * 100 / originalSize)) : 0); } + // Check token limit before proceeding with analysis + int maxTokenLimit = project.getEffectiveConfig().maxAnalysisTokenLimit(); + TokenEstimator.TokenEstimationResult tokenEstimate = TokenEstimator.estimateAndCheck(rawDiff, maxTokenLimit); + log.info("Token estimation for PR diff: {}", tokenEstimate.toLogString()); + + if (tokenEstimate.exceedsLimit()) { + log.warn("PR diff exceeds token limit - skipping analysis. 
Project={}, PR={}, Tokens={}/{}", + project.getId(), request.getPullRequestId(), + tokenEstimate.estimatedTokens(), tokenEstimate.maxAllowedTokens()); + throw new DiffTooLargeException( + tokenEstimate.estimatedTokens(), + tokenEstimate.maxAllowedTokens(), + project.getId(), + request.getPullRequestId() + ); + } + // Determine analysis mode: INCREMENTAL if we have previous analysis with different commit boolean canUseIncremental = previousAnalysis.isPresent() && previousCommitHash != null @@ -223,7 +242,7 @@ private AiAnalysisRequest buildPrAnalysisRequest( .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(aiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) .withAllPrAnalysesData(allPrAnalyses) // Use full PR history instead of just previous version - .withMaxAllowedTokens(aiConnection.getTokenLimitation()) + .withMaxAllowedTokens(project.getEffectiveConfig().maxAnalysisTokenLimit()) .withAnalysisType(request.getAnalysisType()) .withPrTitle(prTitle) .withPrDescription(prDescription) @@ -291,7 +310,7 @@ private AiAnalysisRequest buildBranchAnalysisRequest( .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(aiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) .withPreviousAnalysisData(previousAnalysis) - .withMaxAllowedTokens(aiConnection.getTokenLimitation()) + .withMaxAllowedTokens(project.getEffectiveConfig().maxAnalysisTokenLimit()) .withAnalysisType(request.getAnalysisType()) .withTargetBranchName(request.getTargetBranchName()) .withCurrentCommitHash(request.getCommitHash()) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java index 1c697fc7..ed2be10a 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java @@ -13,9 +13,11 @@ import org.rostilos.codecrow.analysisengine.dto.request.processor.AnalysisProcessRequest; import org.rostilos.codecrow.analysisengine.dto.request.processor.BranchProcessRequest; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; +import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.service.vcs.VcsAiClientService; import org.rostilos.codecrow.analysisengine.util.DiffContentFilter; import org.rostilos.codecrow.analysisengine.util.DiffParser; +import org.rostilos.codecrow.analysisengine.util.TokenEstimator; import org.rostilos.codecrow.security.oauth.TokenEncryptionService; import org.rostilos.codecrow.vcsclient.VcsClientProvider; import org.rostilos.codecrow.vcsclient.gitlab.actions.GetCommitRangeDiffAction; @@ -166,6 +168,23 @@ private AiAnalysisRequest buildMrAnalysisRequest( originalSize > 0 ? (100 - (filteredSize * 100 / originalSize)) : 0); } + // Check token limit before proceeding with analysis + int maxTokenLimit = project.getEffectiveConfig().maxAnalysisTokenLimit(); + TokenEstimator.TokenEstimationResult tokenEstimate = TokenEstimator.estimateAndCheck(rawDiff, maxTokenLimit); + log.info("Token estimation for MR diff: {}", tokenEstimate.toLogString()); + + if (tokenEstimate.exceedsLimit()) { + log.warn("MR diff exceeds token limit - skipping analysis. 
Project={}, PR={}, Tokens={}/{}", + project.getId(), request.getPullRequestId(), + tokenEstimate.estimatedTokens(), tokenEstimate.maxAllowedTokens()); + throw new DiffTooLargeException( + tokenEstimate.estimatedTokens(), + tokenEstimate.maxAllowedTokens(), + project.getId(), + request.getPullRequestId() + ); + } + // Determine analysis mode: INCREMENTAL if we have previous analysis with different commit boolean canUseIncremental = previousAnalysis.isPresent() && previousCommitHash != null @@ -224,7 +243,7 @@ private AiAnalysisRequest buildMrAnalysisRequest( .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(aiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) .withAllPrAnalysesData(allPrAnalyses) // Use full PR history instead of just previous version - .withMaxAllowedTokens(aiConnection.getTokenLimitation()) + .withMaxAllowedTokens(project.getEffectiveConfig().maxAnalysisTokenLimit()) .withAnalysisType(request.getAnalysisType()) .withPrTitle(mrTitle) .withPrDescription(mrDescription) @@ -292,7 +311,7 @@ private AiAnalysisRequest buildBranchAnalysisRequest( .withProjectAiConnectionTokenDecrypted(tokenEncryptionService.decrypt(aiConnection.getApiKeyEncrypted())) .withUseLocalMcp(true) .withPreviousAnalysisData(previousAnalysis) - .withMaxAllowedTokens(aiConnection.getTokenLimitation()) + .withMaxAllowedTokens(project.getEffectiveConfig().maxAnalysisTokenLimit()) .withAnalysisType(request.getAnalysisType()) .withTargetBranchName(request.getTargetBranchName()) .withCurrentCommitHash(request.getCommitHash()) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/webhookhandler/GitLabMergeRequestWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/webhookhandler/GitLabMergeRequestWebhookHandler.java index 15fefd59..77f62466 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/webhookhandler/GitLabMergeRequestWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/webhookhandler/GitLabMergeRequestWebhookHandler.java @@ -1,10 +1,12 @@ package org.rostilos.codecrow.pipelineagent.gitlab.webhookhandler; +import org.rostilos.codecrow.core.model.analysis.AnalysisLockType; import org.rostilos.codecrow.core.model.codeanalysis.AnalysisType; import org.rostilos.codecrow.core.model.project.Project; import org.rostilos.codecrow.core.model.vcs.EVcsProvider; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; import org.rostilos.codecrow.analysisengine.processor.analysis.PullRequestAnalysisProcessor; +import org.rostilos.codecrow.analysisengine.service.AnalysisLockService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsReportingService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsServiceFactory; import org.rostilos.codecrow.pipelineagent.generic.dto.webhook.WebhookPayload; @@ -15,6 +17,7 @@ import org.springframework.stereotype.Component; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.Consumer; @@ -47,13 +50,16 @@ public class GitLabMergeRequestWebhookHandler extends AbstractWebhookHandler imp private final PullRequestAnalysisProcessor pullRequestAnalysisProcessor; private final VcsServiceFactory vcsServiceFactory; + private final AnalysisLockService analysisLockService; public GitLabMergeRequestWebhookHandler( PullRequestAnalysisProcessor 
pullRequestAnalysisProcessor,
-            VcsServiceFactory vcsServiceFactory
+            VcsServiceFactory vcsServiceFactory,
+            AnalysisLockService analysisLockService
     ) {
         this.pullRequestAnalysisProcessor = pullRequestAnalysisProcessor;
         this.vcsServiceFactory = vcsServiceFactory;
+        this.analysisLockService = analysisLockService;
     }
 
     @Override
@@ -119,6 +125,25 @@ private WebhookResult handleMergeRequestEvent(
         String placeholderCommentId = null;
 
         try {
+            // Try to acquire lock atomically BEFORE posting placeholder
+            // This prevents TOCTOU race where multiple webhooks could pass isLocked() check simultaneously
+            // Note: PullRequestAnalysisProcessor.process() uses acquireLockWithWait() which will
+            // reuse this lock since it's for the same project/branch/type
+            String sourceBranch = payload.sourceBranch();
+            Optional<String> earlyLock = analysisLockService.acquireLock(
+                    project, sourceBranch, AnalysisLockType.PR_ANALYSIS,
+                    payload.commitHash(), Long.parseLong(payload.pullRequestId()));
+
+            if (earlyLock.isEmpty()) {
+                log.info("MR analysis already in progress for project={}, branch={}, MR={} - skipping duplicate webhook",
+                        project.getId(), sourceBranch, payload.pullRequestId());
+                return WebhookResult.ignored("MR analysis already in progress for this branch");
+            }
+
+            // Lock acquired - placeholder posting is now protected from race conditions
+            // Note: We don't release this lock here - PullRequestAnalysisProcessor will manage it
+            // since acquireLockWithWait() will detect the existing lock and use it
+
             // Post placeholder comment immediately to show analysis has started
             placeholderCommentId = postPlaceholderComment(project, Long.parseLong(payload.pullRequestId()));
 
diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/CreateAIConnectionRequest.java b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/CreateAIConnectionRequest.java
index 558f236e..307dc44e 100644
--- a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/CreateAIConnectionRequest.java
+++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/CreateAIConnectionRequest.java
@@ -13,6 +13,4 @@ public class CreateAIConnectionRequest {
     public String aiModel;
     @NotBlank(message = "API key is required")
     public String apiKey;
-    @NotBlank(message = "Please specify max token limit")
-    public String tokenLimitation;
 }
diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/UpdateAiConnectionRequest.java b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/UpdateAiConnectionRequest.java
index ae834bb5..f7805b90 100644
--- a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/UpdateAiConnectionRequest.java
+++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/dto/request/UpdateAiConnectionRequest.java
@@ -9,5 +9,4 @@ public class UpdateAiConnectionRequest {
     public AIProviderKey providerKey;
     public String aiModel;
     public String apiKey;
-    public String tokenLimitation;
 }
diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/service/AIConnectionService.java b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/service/AIConnectionService.java
index 510f0901..48949fc7 100644
---
a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/service/AIConnectionService.java +++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/ai/service/AIConnectionService.java @@ -53,7 +53,6 @@ public AIConnection createAiConnection(Long workspaceId, CreateAIConnectionReque newAiConnection.setProviderKey(request.providerKey); newAiConnection.setAiModel(request.aiModel); newAiConnection.setApiKeyEncrypted(apiKeyEncrypted); - newAiConnection.setTokenLimitation(Integer.parseInt(request.tokenLimitation)); return connectionRepository.save(newAiConnection); } @@ -77,9 +76,6 @@ public AIConnection updateAiConnection(Long workspaceId, Long connectionId, Upda connection.setApiKeyEncrypted(apiKeyEncrypted); } - if(request.tokenLimitation != null && !request.tokenLimitation.isEmpty()) { - connection.setTokenLimitation(Integer.parseInt(request.tokenLimitation)); - } return connectionRepository.save(connection); } diff --git a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/ai/AIConnectionCrudIT.java b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/ai/AIConnectionCrudIT.java index 2c94d528..69d3bef6 100644 --- a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/ai/AIConnectionCrudIT.java +++ b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/ai/AIConnectionCrudIT.java @@ -36,8 +36,7 @@ void shouldCreateOpenRouterConnection() { { "providerKey": "OPENROUTER", "aiModel": "anthropic/claude-3-haiku", - "apiKey": "test-api-key-openrouter", - "tokenLimitation": "200000" + "apiKey": "test-api-key-openrouter" } """; @@ -50,7 +49,6 @@ void shouldCreateOpenRouterConnection() { .statusCode(201) .body("providerKey", equalTo("OPENROUTER")) .body("aiModel", equalTo("anthropic/claude-3-haiku")) - .body("tokenLimitation", equalTo(200000)) .body("id", notNullValue()); } @@ -62,8 +60,7 @@ void shouldCreateOpenAIConnection() { { "providerKey": "OPENAI", "aiModel": "gpt-4o-mini", - "apiKey": "test-api-key-openai", - "tokenLimitation": "128000" + "apiKey": "test-api-key-openai" } """; @@ -86,8 +83,7 @@ void shouldCreateAnthropicConnection() { { "providerKey": "ANTHROPIC", "aiModel": "claude-3-haiku-20240307", - "apiKey": "test-api-key-anthropic", - "tokenLimitation": "200000" + "apiKey": "test-api-key-anthropic" } """; @@ -110,8 +106,7 @@ void shouldListAIConnections() { { "providerKey": "OPENROUTER", "aiModel": "test-model", - "apiKey": "test-key", - "tokenLimitation": "100000" + "apiKey": "test-key" } """; @@ -140,8 +135,7 @@ void shouldUpdateAIConnection() { { "providerKey": "OPENROUTER", "aiModel": "original-model", - "apiKey": "original-key", - "tokenLimitation": "100000" + "apiKey": "original-key" } """; @@ -158,8 +152,7 @@ void shouldUpdateAIConnection() { { "providerKey": "OPENROUTER", "aiModel": "updated-model", - "apiKey": "updated-key", - "tokenLimitation": "150000" + "apiKey": "updated-key" } """; @@ -170,8 +163,7 @@ void shouldUpdateAIConnection() { .patch("/api/{workspaceSlug}/ai/{connectionId}", testWorkspace.getSlug(), connectionId) .then() .statusCode(200) - .body("aiModel", equalTo("updated-model")) - .body("tokenLimitation", equalTo(150000)); + .body("aiModel", equalTo("updated-model")); } @Test @@ -182,8 +174,7 @@ void shouldDeleteAIConnection() { { "providerKey": "OPENROUTER", "aiModel": "to-delete", - "apiKey": "delete-key", - "tokenLimitation": "100000" + "apiKey": "delete-key" } """; 
@@ -212,8 +203,7 @@ void shouldRequireAdminRightsForAIOperations() { { "providerKey": "OPENROUTER", "aiModel": "test-model", - "apiKey": "test-key", - "tokenLimitation": "100000" + "apiKey": "test-key" } """; @@ -272,8 +262,7 @@ void shouldValidateProviderKey() { { "providerKey": "INVALID_PROVIDER", "aiModel": "test-model", - "apiKey": "test-key", - "tokenLimitation": "100000" + "apiKey": "test-key" } """; @@ -300,8 +289,7 @@ void shouldPreventCrossWorkspaceAccess() { { "providerKey": "OPENROUTER", "aiModel": "other-ws-model", - "apiKey": "other-ws-key", - "tokenLimitation": "100000" + "apiKey": "other-ws-key" } """; diff --git a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/auth/UserAuthFlowIT.java b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/auth/UserAuthFlowIT.java index 5d0ffaca..bb8d4220 100644 --- a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/auth/UserAuthFlowIT.java +++ b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/auth/UserAuthFlowIT.java @@ -313,8 +313,7 @@ void shouldHandleWorkspaceRoleDowngrade() { { "providerKey": "OPENROUTER", "aiModel": "test", - "apiKey": "key", - "tokenLimitation": "100000" + "apiKey": "key" } """; given() diff --git a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/builder/AIConnectionBuilder.java b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/builder/AIConnectionBuilder.java index ef858489..69ecb04e 100644 --- a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/builder/AIConnectionBuilder.java +++ b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/builder/AIConnectionBuilder.java @@ -14,7 +14,6 @@ public class AIConnectionBuilder { private AIProviderKey providerKey = AIProviderKey.OPENROUTER; private String aiModel = "anthropic/claude-3-haiku"; private String apiKeyEncrypted = "encrypted-test-key"; - private int tokenLimitation = 200000; public static AIConnectionBuilder anAIConnection() { return new AIConnectionBuilder(); @@ -45,11 +44,6 @@ public AIConnectionBuilder withApiKeyEncrypted(String apiKeyEncrypted) { return this; } - public AIConnectionBuilder withTokenLimitation(int tokenLimitation) { - this.tokenLimitation = tokenLimitation; - return this; - } - public AIConnectionBuilder openAI() { this.providerKey = AIProviderKey.OPENAI; this.aiModel = "gpt-4o-mini"; @@ -75,7 +69,6 @@ public AIConnection build() { connection.setProviderKey(providerKey); connection.setAiModel(aiModel); connection.setApiKeyEncrypted(apiKeyEncrypted); - connection.setTokenLimitation(tokenLimitation); return connection; } } diff --git a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/util/AuthTestHelper.java b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/util/AuthTestHelper.java index c819e219..8902eaf5 100644 --- a/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/util/AuthTestHelper.java +++ b/java-ecosystem/tests/integration-tests/src/test/java/org/rostilos/codecrow/integration/util/AuthTestHelper.java @@ -254,7 +254,6 @@ public AIConnection createTestAiConnection(Workspace workspace, String name, AIP aiConnection.setProviderKey(provider); aiConnection.setAiModel("gpt-4"); aiConnection.setApiKeyEncrypted("test-encrypted-api-key-" + 
UUID.randomUUID().toString().substring(0, 8)); - aiConnection.setTokenLimitation(100000); return aiConnectionRepository.save(aiConnection); } From 7c780573801fc3dd8c52c39af26c616385237306 Mon Sep 17 00:00:00 2001 From: rostislav Date: Mon, 26 Jan 2026 18:46:00 +0200 Subject: [PATCH 02/20] feat: Add pre-acquired lock key to prevent double-locking in PR analysis processing. Project PR analysis max analysis token limit implementation --- .../request/processor/PrProcessRequest.java | 9 +++ .../PullRequestAnalysisProcessor.java | 63 +++++++++++-------- .../codecrow/core/dto/project/ProjectDTO.java | 9 ++- .../codecrow/core/model/project/Project.java | 9 +++ .../core/dto/project/ProjectDTOTest.java | 4 +- ...tbucketCloudPullRequestWebhookHandler.java | 8 +-- .../GitHubPullRequestWebhookHandler.java | 8 +-- .../project/controller/ProjectController.java | 6 +- .../project/service/ProjectService.java | 7 ++- 9 files changed, 78 insertions(+), 45 deletions(-) diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/processor/PrProcessRequest.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/processor/PrProcessRequest.java index 5dc47f1e..752efc2a 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/processor/PrProcessRequest.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/dto/request/processor/PrProcessRequest.java @@ -35,6 +35,13 @@ public class PrProcessRequest implements AnalysisProcessRequest { public String prAuthorId; public String prAuthorUsername; + + /** + * Optional pre-acquired lock key. If set, the processor will skip lock acquisition + * and use this lock key directly. This prevents double-locking when the webhook handler + * has already acquired the lock before calling the processor. + */ + public String preAcquiredLockKey; public Long getProjectId() { @@ -64,4 +71,6 @@ public String getSourceBranchName() { public String getPrAuthorId() { return prAuthorId; } public String getPrAuthorUsername() { return prAuthorUsername; } + + public String getPreAcquiredLockKey() { return preAcquiredLockKey; } } diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java index f5db459a..ce7b3292 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java @@ -91,34 +91,43 @@ public Map process( // Publish analysis started event publishAnalysisStartedEvent(project, request, correlationId); - Optional lockKey = analysisLockService.acquireLockWithWait( - project, - request.getSourceBranchName(), - AnalysisLockType.PR_ANALYSIS, - request.getCommitHash(), - request.getPullRequestId(), - consumer::accept - ); - - if (lockKey.isEmpty()) { - String message = String.format( - "Failed to acquire lock after %d minutes for project=%s, PR=%d, branch=%s. 
Another analysis is still in progress.", - analysisLockService.getLockWaitTimeoutMinutes(), - project.getId(), - request.getPullRequestId(), - request.getSourceBranchName() - ); - log.warn(message); - - // Publish failed event due to lock timeout - publishAnalysisCompletedEvent(project, request, correlationId, startTime, - AnalysisCompletedEvent.CompletionStatus.FAILED, 0, 0, "Lock acquisition timeout"); - - throw new AnalysisLockedException( - AnalysisLockType.PR_ANALYSIS.name(), + // Check if a lock was already acquired by the caller (e.g., webhook handler) + // to prevent double-locking which causes unnecessary 2-minute waits + String lockKey; + if (request.getPreAcquiredLockKey() != null && !request.getPreAcquiredLockKey().isBlank()) { + lockKey = request.getPreAcquiredLockKey(); + log.info("Using pre-acquired lock: {} for project={}, PR={}", lockKey, project.getId(), request.getPullRequestId()); + } else { + Optional acquiredLock = analysisLockService.acquireLockWithWait( + project, request.getSourceBranchName(), - project.getId() + AnalysisLockType.PR_ANALYSIS, + request.getCommitHash(), + request.getPullRequestId(), + consumer::accept ); + + if (acquiredLock.isEmpty()) { + String message = String.format( + "Failed to acquire lock after %d minutes for project=%s, PR=%d, branch=%s. Another analysis is still in progress.", + analysisLockService.getLockWaitTimeoutMinutes(), + project.getId(), + request.getPullRequestId(), + request.getSourceBranchName() + ); + log.warn(message); + + // Publish failed event due to lock timeout + publishAnalysisCompletedEvent(project, request, correlationId, startTime, + AnalysisCompletedEvent.CompletionStatus.FAILED, 0, 0, "Lock acquisition timeout"); + + throw new AnalysisLockedException( + AnalysisLockType.PR_ANALYSIS.name(), + request.getSourceBranchName(), + project.getId() + ); + } + lockKey = acquiredLock.get(); } try { @@ -216,7 +225,7 @@ public Map process( return Map.of("status", "error", "message", e.getMessage()); } finally { - analysisLockService.releaseLock(lockKey.get()); + analysisLockService.releaseLock(lockKey); } } diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/project/ProjectDTO.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/project/ProjectDTO.java index 9c2a70a3..98027a0d 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/project/ProjectDTO.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/dto/project/ProjectDTO.java @@ -32,7 +32,8 @@ public record ProjectDTO( String installationMethod, CommentCommandsConfigDTO commentCommandsConfig, Boolean webhooksConfigured, - Long qualityGateId + Long qualityGateId, + Integer maxAnalysisTokenLimit ) { public static ProjectDTO fromProject(Project project) { Long vcsConnectionId = null; @@ -123,6 +124,9 @@ public static ProjectDTO fromProject(Project project) { if (project.getVcsRepoBinding() != null) { webhooksConfigured = project.getVcsRepoBinding().isWebhooksConfigured(); } + + // Get maxAnalysisTokenLimit from config + Integer maxAnalysisTokenLimit = config != null ? config.maxAnalysisTokenLimit() : ProjectConfig.DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT; return new ProjectDTO( project.getId(), @@ -146,7 +150,8 @@ public static ProjectDTO fromProject(Project project) { installationMethod, commentCommandsConfigDTO, webhooksConfigured, - project.getQualityGate() != null ? project.getQualityGate().getId() : null + project.getQualityGate() != null ? 
project.getQualityGate().getId() : null, + maxAnalysisTokenLimit ); } diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/Project.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/Project.java index 1a956edd..b12b4227 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/Project.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/model/project/Project.java @@ -222,6 +222,15 @@ public void setConfiguration(org.rostilos.codecrow.core.model.project.config.Pro this.configuration = configuration; } + /** + * Returns the effective project configuration. + * If configuration is null, returns a new default ProjectConfig. + * This ensures callers always get a valid config with default values. + */ + public org.rostilos.codecrow.core.model.project.config.ProjectConfig getEffectiveConfig() { + return configuration != null ? configuration : new org.rostilos.codecrow.core.model.project.config.ProjectConfig(); + } + public org.rostilos.codecrow.core.model.branch.Branch getDefaultBranch() { return defaultBranch; } diff --git a/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/project/ProjectDTOTest.java b/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/project/ProjectDTOTest.java index 63096a11..2fe071a6 100644 --- a/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/project/ProjectDTOTest.java +++ b/java-ecosystem/libs/core/src/test/java/org/rostilos/codecrow/core/dto/project/ProjectDTOTest.java @@ -50,7 +50,7 @@ void shouldCreateWithAllFields() { 20L, "namespace", "main", "main", 100L, stats, ragConfig, true, false, "WEBHOOK", - commandsConfig, true, 50L + commandsConfig, true, 50L, 200000 ); assertThat(dto.id()).isEqualTo(1L); @@ -84,7 +84,7 @@ void shouldCreateWithNullOptionalFields() { 1L, "Test", null, true, null, null, null, null, null, null, null, null, null, null, null, - null, null, null, null, null, null, null + null, null, null, null, null, null, null, null ); assertThat(dto.description()).isNull(); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java index 04036114..71124e1d 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java @@ -111,9 +111,7 @@ private WebhookResult handlePullRequestEvent(WebhookPayload payload, Project pro try { // Try to acquire lock atomically BEFORE posting placeholder - // This prevents TOCTOU race where multiple webhooks could pass isLocked() check simultaneously - // Note: PullRequestAnalysisProcessor.process() uses acquireLockWithWait() which will - // reuse this lock since it's for the same project/branch/type + // This prevents race condition where multiple webhooks could post duplicate placeholders String sourceBranch = payload.sourceBranch(); Optional earlyLock = analysisLockService.acquireLock( project, sourceBranch, AnalysisLockType.PR_ANALYSIS, @@ -126,8 +124,6 @@ private WebhookResult 
handlePullRequestEvent(WebhookPayload payload, Project pro } // Lock acquired - placeholder posting is now protected from race conditions - // Note: We don't release this lock here - PullRequestAnalysisProcessor will manage it - // since acquireLockWithWait() will detect the existing lock and use it // Post placeholder comment immediately to show analysis has started placeholderCommentId = postPlaceholderComment(project, Long.parseLong(payload.pullRequestId())); @@ -143,6 +139,8 @@ private WebhookResult handlePullRequestEvent(WebhookPayload payload, Project pro request.placeholderCommentId = placeholderCommentId; request.prAuthorId = payload.prAuthorId(); request.prAuthorUsername = payload.prAuthorUsername(); + // Pass the pre-acquired lock key to avoid double-locking in the processor + request.preAcquiredLockKey = earlyLock.get(); log.info("Processing PR analysis: project={}, PR={}, source={}, target={}, placeholderCommentId={}", project.getId(), request.pullRequestId, request.sourceBranchName, request.targetBranchName, placeholderCommentId); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java index 2997a833..7f72aeaf 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java @@ -124,9 +124,7 @@ private WebhookResult handlePullRequestEvent( try { // Try to acquire lock atomically BEFORE posting placeholder - // This prevents TOCTOU race where multiple webhooks could pass isLocked() check simultaneously - // Note: PullRequestAnalysisProcessor.process() uses acquireLockWithWait() which will - // reuse this lock since it's for the same project/branch/type + // This prevents race condition where multiple webhooks could post duplicate placeholders String sourceBranch = payload.sourceBranch(); Optional earlyLock = analysisLockService.acquireLock( project, sourceBranch, AnalysisLockType.PR_ANALYSIS, @@ -139,8 +137,6 @@ private WebhookResult handlePullRequestEvent( } // Lock acquired - placeholder posting is now protected from race conditions - // Note: We don't release this lock here - PullRequestAnalysisProcessor will manage it - // since acquireLockWithWait() will detect the existing lock and use it // Post placeholder comment immediately to show analysis has started placeholderCommentId = postPlaceholderComment(project, Long.parseLong(payload.pullRequestId())); @@ -156,6 +152,8 @@ private WebhookResult handlePullRequestEvent( request.placeholderCommentId = placeholderCommentId; request.prAuthorId = payload.prAuthorId(); request.prAuthorUsername = payload.prAuthorUsername(); + // Pass the pre-acquired lock key to avoid double-locking in the processor + request.preAcquiredLockKey = earlyLock.get(); log.info("Processing PR analysis: project={}, PR={}, source={}, target={}, placeholderCommentId={}", project.getId(), request.pullRequestId, request.sourceBranchName, request.targetBranchName, placeholderCommentId); diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/controller/ProjectController.java 
b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/controller/ProjectController.java index 9704f48b..774b41fd 100644 --- a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/controller/ProjectController.java +++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/controller/ProjectController.java @@ -584,7 +584,8 @@ public ResponseEntity updateAnalysisSettings( project.getId(), request.prAnalysisEnabled(), request.branchAnalysisEnabled(), - installationMethod + installationMethod, + request.maxAnalysisTokenLimit() ); return new ResponseEntity<>(ProjectDTO.fromProject(updated), HttpStatus.OK); } @@ -592,7 +593,8 @@ public ResponseEntity updateAnalysisSettings( public record UpdateAnalysisSettingsRequest( Boolean prAnalysisEnabled, Boolean branchAnalysisEnabled, - String installationMethod + String installationMethod, + Integer maxAnalysisTokenLimit ) {} /** diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java index 07aeb2be..ac98db3a 100644 --- a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java +++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java @@ -561,7 +561,8 @@ public Project updateAnalysisSettings( Long projectId, Boolean prAnalysisEnabled, Boolean branchAnalysisEnabled, - InstallationMethod installationMethod + InstallationMethod installationMethod, + Integer maxAnalysisTokenLimit ) { Project project = projectRepository.findByWorkspaceIdAndId(workspaceId, projectId) .orElseThrow(() -> new NoSuchElementException("Project not found")); @@ -572,6 +573,7 @@ public Project updateAnalysisSettings( var branchAnalysis = currentConfig != null ? currentConfig.branchAnalysis() : null; var ragConfig = currentConfig != null ? currentConfig.ragConfig() : null; var commentCommands = currentConfig != null ? currentConfig.commentCommands() : null; + int currentMaxTokenLimit = currentConfig != null ? currentConfig.maxAnalysisTokenLimit() : ProjectConfig.DEFAULT_MAX_ANALYSIS_TOKEN_LIMIT; Boolean newPrAnalysis = prAnalysisEnabled != null ? prAnalysisEnabled : (currentConfig != null ? currentConfig.prAnalysisEnabled() : true); @@ -579,6 +581,7 @@ public Project updateAnalysisSettings( (currentConfig != null ? currentConfig.branchAnalysisEnabled() : true); var newInstallationMethod = installationMethod != null ? installationMethod : (currentConfig != null ? currentConfig.installationMethod() : null); + int newMaxTokenLimit = maxAnalysisTokenLimit != null ? maxAnalysisTokenLimit : currentMaxTokenLimit; // Update both the direct column and the JSON config //TODO: remove duplication @@ -586,7 +589,7 @@ public Project updateAnalysisSettings( project.setBranchAnalysisEnabled(newBranchAnalysis != null ? 
newBranchAnalysis : true); project.setConfiguration(new ProjectConfig(useLocalMcp, mainBranch, branchAnalysis, ragConfig, - newPrAnalysis, newBranchAnalysis, newInstallationMethod, commentCommands)); + newPrAnalysis, newBranchAnalysis, newInstallationMethod, commentCommands, newMaxTokenLimit)); return projectRepository.save(project); } From 6d80d71283010177ebe0abac77e3e2552630c254 Mon Sep 17 00:00:00 2001 From: rostislav Date: Mon, 26 Jan 2026 20:59:16 +0200 Subject: [PATCH 03/20] feat: Implement handling for AnalysisLockedException and DiffTooLargeException in webhook processors --- ...tbucketCloudPullRequestWebhookHandler.java | 6 ++++ .../processor/WebhookAsyncProcessor.java | 35 +++++++++++++++++++ .../CommentCommandWebhookHandler.java | 10 ++++++ .../GitHubPullRequestWebhookHandler.java | 6 ++++ 4 files changed, 57 insertions(+) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java index 71124e1d..68b30ae4 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudPullRequestWebhookHandler.java @@ -5,6 +5,8 @@ import org.rostilos.codecrow.core.model.project.Project; import org.rostilos.codecrow.core.model.vcs.EVcsProvider; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; +import org.rostilos.codecrow.analysisengine.exception.AnalysisLockedException; +import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.processor.analysis.PullRequestAnalysisProcessor; import org.rostilos.codecrow.analysisengine.service.AnalysisLockService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsReportingService; @@ -166,6 +168,10 @@ private WebhookResult handlePullRequestEvent(WebhookPayload payload, Project pro return WebhookResult.success("PR analysis completed", result); + } catch (DiffTooLargeException | AnalysisLockedException e) { + // Re-throw these exceptions so WebhookAsyncProcessor can handle them properly + log.warn("PR analysis failed with recoverable exception for project {}: {}", project.getId(), e.getMessage()); + throw e; } catch (Exception e) { log.error("PR analysis failed for project {}", project.getId(), e); // Try to update placeholder with error message diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 039eb86c..d65c7236 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -7,6 +7,7 @@ import org.rostilos.codecrow.core.service.JobService; import org.rostilos.codecrow.pipelineagent.generic.dto.webhook.WebhookPayload; import 
org.rostilos.codecrow.pipelineagent.generic.webhookhandler.WebhookHandler; +import org.rostilos.codecrow.analysisengine.exception.AnalysisLockedException; import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.service.vcs.VcsReportingService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsServiceFactory; @@ -199,6 +200,40 @@ public void processWebhookAsync( log.error("Failed to skip job: {}", skipError.getMessage()); } + } catch (AnalysisLockedException lockEx) { + // Handle lock acquisition failure - mark job as failed + log.warn("Lock acquisition failed for analysis: {}", lockEx.getMessage()); + + String failMessage = String.format( + "⚠️ **Analysis Failed - Resource Locked**\n\n" + + "Could not acquire analysis lock after timeout:\n" + + "- **Lock type:** %s\n" + + "- **Branch:** %s\n" + + "- **Project:** %d\n\n" + + "Another analysis may be in progress. Please try again later.", + lockEx.getLockType(), + lockEx.getBranchName(), + lockEx.getProjectId() + ); + + try { + if (project == null) { + project = projectRepository.findById(projectId).orElse(null); + } + if (project != null) { + initializeProjectAssociations(project); + postErrorToVcs(provider, project, payload, failMessage, placeholderCommentId, job); + } + } catch (Exception postError) { + log.error("Failed to post lock error to VCS: {}", postError.getMessage()); + } + + try { + jobService.failJob(job, "Lock acquisition timeout: " + lockEx.getMessage()); + } catch (Exception failError) { + log.error("Failed to fail job: {}", failError.getMessage()); + } + } catch (Exception e) { log.error("Error processing webhook for job {}", job.getExternalId(), e); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/CommentCommandWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/CommentCommandWebhookHandler.java index 6267cbd7..0a120fb7 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/CommentCommandWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/CommentCommandWebhookHandler.java @@ -13,6 +13,8 @@ import org.rostilos.codecrow.core.persistence.repository.codeanalysis.PrSummarizeCacheRepository; import org.rostilos.codecrow.core.service.CodeAnalysisService; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; +import org.rostilos.codecrow.analysisengine.exception.AnalysisLockedException; +import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.processor.analysis.PullRequestAnalysisProcessor; import org.rostilos.codecrow.pipelineagent.generic.dto.webhook.WebhookPayload; import org.rostilos.codecrow.pipelineagent.generic.service.CommandAuthorizationService; @@ -445,6 +447,14 @@ private WebhookResult runPrAnalysis( // If we got here, the processor posted results directly (which it does) return WebhookResult.success("Analysis completed", Map.of("commandType", commandType)); + } catch (DiffTooLargeException e) { + // Re-throw DiffTooLargeException so WebhookAsyncProcessor can handle it with proper job status + log.warn("PR diff too large for {} command: {}", commandType, e.getMessage()); + throw e; + } catch (AnalysisLockedException e) { + // Re-throw 
AnalysisLockedException so WebhookAsyncProcessor can handle it with proper job status + log.warn("Lock acquisition failed for {} command: {}", commandType, e.getMessage()); + throw e; } catch (Exception e) { log.error("Error running PR analysis for {} command: {}", commandType, e.getMessage(), e); return WebhookResult.error("Analysis failed: " + e.getMessage()); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java index 7f72aeaf..87f949f9 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/webhookhandler/GitHubPullRequestWebhookHandler.java @@ -5,6 +5,8 @@ import org.rostilos.codecrow.core.model.project.Project; import org.rostilos.codecrow.core.model.vcs.EVcsProvider; import org.rostilos.codecrow.analysisengine.dto.request.processor.PrProcessRequest; +import org.rostilos.codecrow.analysisengine.exception.AnalysisLockedException; +import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; import org.rostilos.codecrow.analysisengine.processor.analysis.PullRequestAnalysisProcessor; import org.rostilos.codecrow.analysisengine.service.AnalysisLockService; import org.rostilos.codecrow.analysisengine.service.vcs.VcsReportingService; @@ -179,6 +181,10 @@ private WebhookResult handlePullRequestEvent( return WebhookResult.success("PR analysis completed", result); + } catch (DiffTooLargeException | AnalysisLockedException e) { + // Re-throw these exceptions so WebhookAsyncProcessor can handle them properly + log.warn("PR analysis failed with recoverable exception for project {}: {}", project.getId(), e.getMessage()); + throw e; } catch (Exception e) { log.error("PR analysis failed for project {}", project.getId(), e); // Try to update placeholder with error message From e2c14743383242c33891d4d6f989ceec2bb92b11 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 10:35:19 +0200 Subject: [PATCH 04/20] feat: Re-fetch job entities in transaction methods to handle detached entities from async contexts --- .../codecrow/core/service/JobService.java | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index 04036c47..0a50c269 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -221,9 +221,12 @@ private String getCommandJobTitle(JobType type, Long prNumber) { /** * Start a job (transition from PENDING to RUNNING). + * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job startJob(Job job) { + // Re-fetch to ensure attached entity in current transaction (handles async context) + job = jobRepository.findById(job.getId()).orElse(job); job.start(); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "start", "Job started"); @@ -241,9 +244,12 @@ public Job startJob(String externalId) { /** * Complete a job successfully. 
+ * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job completeJob(Job job) { + // Re-fetch to ensure attached entity in current transaction (handles async context) + job = jobRepository.findById(job.getId()).orElse(job); job.complete(); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "complete", "Job completed successfully"); @@ -253,9 +259,12 @@ public Job completeJob(Job job) { /** * Complete a job and link it to a code analysis. + * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job completeJob(Job job, CodeAnalysis codeAnalysis) { + // Re-fetch to ensure attached entity in current transaction (handles async context) + job = jobRepository.findById(job.getId()).orElse(job); job.setCodeAnalysis(codeAnalysis); job.complete(); job = jobRepository.save(job); @@ -284,9 +293,12 @@ public Job failJob(Job job, String errorMessage) { /** * Cancel a job. + * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job cancelJob(Job job) { + // Re-fetch to ensure attached entity in current transaction (handles async context) + job = jobRepository.findById(job.getId()).orElse(job); job.cancel(); job = jobRepository.save(job); addLog(job, JobLogLevel.WARN, "cancel", "Job cancelled"); @@ -296,9 +308,12 @@ public Job cancelJob(Job job) { /** * Skip a job (e.g., due to branch pattern settings). + * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job skipJob(Job job, String reason) { + // Re-fetch to ensure attached entity in current transaction (handles async context) + job = jobRepository.findById(job.getId()).orElse(job); job.skip(reason); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "skipped", reason); From 342c4fadbbe9b09dbe5ab56c69c02af4493fb13f Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 11:01:57 +0200 Subject: [PATCH 05/20] feat: Update JobService and WebhookAsyncProcessor to manage job entities without re-fetching in async contexts --- .../codecrow/core/service/JobService.java | 15 --- .../processor/WebhookAsyncProcessor.java | 92 ++++++++++++++----- 2 files changed, 70 insertions(+), 37 deletions(-) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index 0a50c269..04036c47 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -221,12 +221,9 @@ private String getCommandJobTitle(JobType type, Long prNumber) { /** * Start a job (transition from PENDING to RUNNING). - * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job startJob(Job job) { - // Re-fetch to ensure attached entity in current transaction (handles async context) - job = jobRepository.findById(job.getId()).orElse(job); job.start(); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "start", "Job started"); @@ -244,12 +241,9 @@ public Job startJob(String externalId) { /** * Complete a job successfully. - * Re-fetches the job to handle detached entities from async contexts. 
*/ @Transactional public Job completeJob(Job job) { - // Re-fetch to ensure attached entity in current transaction (handles async context) - job = jobRepository.findById(job.getId()).orElse(job); job.complete(); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "complete", "Job completed successfully"); @@ -259,12 +253,9 @@ public Job completeJob(Job job) { /** * Complete a job and link it to a code analysis. - * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job completeJob(Job job, CodeAnalysis codeAnalysis) { - // Re-fetch to ensure attached entity in current transaction (handles async context) - job = jobRepository.findById(job.getId()).orElse(job); job.setCodeAnalysis(codeAnalysis); job.complete(); job = jobRepository.save(job); @@ -293,12 +284,9 @@ public Job failJob(Job job, String errorMessage) { /** * Cancel a job. - * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job cancelJob(Job job) { - // Re-fetch to ensure attached entity in current transaction (handles async context) - job = jobRepository.findById(job.getId()).orElse(job); job.cancel(); job = jobRepository.save(job); addLog(job, JobLogLevel.WARN, "cancel", "Job cancelled"); @@ -308,12 +296,9 @@ public Job cancelJob(Job job) { /** * Skip a job (e.g., due to branch pattern settings). - * Re-fetches the job to handle detached entities from async contexts. */ @Transactional public Job skipJob(Job job, String reason) { - // Re-fetch to ensure attached entity in current transaction (handles async context) - job = jobRepository.findById(job.getId()).orElse(job); job.skip(reason); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "skipped", reason); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index d65c7236..133c148f 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -82,10 +82,11 @@ public WebhookAsyncProcessor( } /** - * Process a webhook asynchronously with proper transactional context. + * Process a webhook asynchronously. + * Note: This method is NOT transactional to avoid issues with nested transactions + * (e.g., failJob uses REQUIRES_NEW). Each inner operation manages its own transaction. 
*/ @Async("webhookExecutor") - @Transactional public void processWebhookAsync( EVcsProvider provider, Long projectId, @@ -96,19 +97,35 @@ public void processWebhookAsync( String placeholderCommentId = null; Project project = null; + // Store job external ID for re-fetching - the passed Job entity is detached + // since it was created in the HTTP request transaction which has already committed + String jobExternalId = job.getExternalId(); + + // Declare managed job reference that will be set after re-fetching + // This needs to be accessible in catch blocks for error handling + Job managedJob = null; + try { - // Re-fetch project within transaction to ensure all lazy associations are available + // Re-fetch project to ensure all lazy associations are available project = projectRepository.findById(projectId) .orElseThrow(() -> new IllegalStateException("Project not found: " + projectId)); // Initialize lazy associations we'll need initializeProjectAssociations(project); - jobService.startJob(job); + // Re-fetch the job by external ID to get a managed entity in the current context + // This is necessary because the Job was created in the HTTP request transaction + // which has already committed by the time this async method runs + managedJob = jobService.findByExternalIdOrThrow(jobExternalId); + + // Create final reference for use in lambda + final Job jobForLambda = managedJob; + + jobService.startJob(managedJob); // Post placeholder comment immediately if this is a CodeCrow command on a PR if (payload.hasCodecrowCommand() && payload.pullRequestId() != null) { - placeholderCommentId = postPlaceholderComment(provider, project, payload, job); + placeholderCommentId = postPlaceholderComment(provider, project, payload, managedJob); } // Store placeholder ID for use in result posting @@ -118,7 +135,7 @@ public void processWebhookAsync( WebhookHandler.WebhookResult result = handler.handle(payload, project, event -> { String message = (String) event.getOrDefault("message", "Processing..."); String state = (String) event.getOrDefault("state", "processing"); - jobService.info(job, state, message); + jobService.info(jobForLambda, state, message); }); // Check if the webhook was ignored (e.g., branch not matching pattern, analysis disabled) @@ -131,14 +148,14 @@ public void processWebhookAsync( // Delete the job entirely - don't clutter DB with ignored webhooks // If deletion fails, skip the job instead try { - jobService.deleteIgnoredJob(job, result.message()); + jobService.deleteIgnoredJob(managedJob, result.message()); } catch (Exception deleteError) { log.warn("Failed to delete ignored job {}, skipping instead: {}", - job.getExternalId(), deleteError.getMessage()); + managedJob.getExternalId(), deleteError.getMessage()); try { - jobService.skipJob(job, result.message()); + jobService.skipJob(managedJob, result.message()); } catch (Exception skipError) { - log.error("Failed to skip job {}: {}", job.getExternalId(), skipError.getMessage()); + log.error("Failed to skip job {}: {}", managedJob.getExternalId(), skipError.getMessage()); } } return; @@ -146,28 +163,38 @@ public void processWebhookAsync( if (result.success()) { // Post result to VCS if there's content to post - postResultToVcs(provider, project, payload, result, finalPlaceholderCommentId, job); + postResultToVcs(provider, project, payload, result, finalPlaceholderCommentId, managedJob); if (result.data().containsKey("analysisId")) { Long analysisId = ((Number) result.data().get("analysisId")).longValue(); - jobService.info(job, "complete", 
"Analysis completed. Analysis ID: " + analysisId); + jobService.info(managedJob, "complete", "Analysis completed. Analysis ID: " + analysisId); } - jobService.completeJob(job); + jobService.completeJob(managedJob); } else { // Post error to VCS (update placeholder if exists) - but ensure failJob is always called try { - postErrorToVcs(provider, project, payload, result.message(), finalPlaceholderCommentId, job); + postErrorToVcs(provider, project, payload, result.message(), finalPlaceholderCommentId, managedJob); } catch (Exception postError) { log.error("Failed to post error to VCS: {}", postError.getMessage()); } // Always mark the job as failed, even if posting to VCS failed - jobService.failJob(job, result.message()); + jobService.failJob(managedJob, result.message()); } } catch (DiffTooLargeException diffEx) { // Handle diff too large - this is a soft skip, not an error log.warn("Diff too large for analysis - skipping: {}", diffEx.getMessage()); + // Re-fetch job if not yet fetched + if (managedJob == null) { + try { + managedJob = jobService.findByExternalIdOrThrow(jobExternalId); + } catch (Exception fetchError) { + log.error("Failed to fetch job {} for skip operation: {}", jobExternalId, fetchError.getMessage()); + return; + } + } + String skipMessage = String.format( "⚠️ **Analysis Skipped - PR Too Large**\n\n" + "This PR's diff exceeds the configured token limit:\n" + @@ -188,14 +215,14 @@ public void processWebhookAsync( } if (project != null) { initializeProjectAssociations(project); - postInfoToVcs(provider, project, payload, skipMessage, placeholderCommentId, job); + postInfoToVcs(provider, project, payload, skipMessage, placeholderCommentId, managedJob); } } catch (Exception postError) { log.error("Failed to post skip message to VCS: {}", postError.getMessage()); } try { - jobService.skipJob(job, "Diff too large: " + diffEx.getEstimatedTokens() + " tokens > " + diffEx.getMaxAllowedTokens() + " limit"); + jobService.skipJob(managedJob, "Diff too large: " + diffEx.getEstimatedTokens() + " tokens > " + diffEx.getMaxAllowedTokens() + " limit"); } catch (Exception skipError) { log.error("Failed to skip job: {}", skipError.getMessage()); } @@ -204,6 +231,16 @@ public void processWebhookAsync( // Handle lock acquisition failure - mark job as failed log.warn("Lock acquisition failed for analysis: {}", lockEx.getMessage()); + // Re-fetch job if not yet fetched + if (managedJob == null) { + try { + managedJob = jobService.findByExternalIdOrThrow(jobExternalId); + } catch (Exception fetchError) { + log.error("Failed to fetch job {} for fail operation: {}", jobExternalId, fetchError.getMessage()); + return; + } + } + String failMessage = String.format( "⚠️ **Analysis Failed - Resource Locked**\n\n" + "Could not acquire analysis lock after timeout:\n" + @@ -222,20 +259,31 @@ public void processWebhookAsync( } if (project != null) { initializeProjectAssociations(project); - postErrorToVcs(provider, project, payload, failMessage, placeholderCommentId, job); + postErrorToVcs(provider, project, payload, failMessage, placeholderCommentId, managedJob); } } catch (Exception postError) { log.error("Failed to post lock error to VCS: {}", postError.getMessage()); } try { - jobService.failJob(job, "Lock acquisition timeout: " + lockEx.getMessage()); + jobService.failJob(managedJob, "Lock acquisition timeout: " + lockEx.getMessage()); } catch (Exception failError) { log.error("Failed to fail job: {}", failError.getMessage()); } } catch (Exception e) { - log.error("Error processing webhook for job 
{}", job.getExternalId(), e); + // Re-fetch job if not yet fetched + if (managedJob == null) { + try { + managedJob = jobService.findByExternalIdOrThrow(jobExternalId); + } catch (Exception fetchError) { + log.error("Failed to fetch job {} for fail operation: {}", jobExternalId, fetchError.getMessage()); + log.error("Original error processing webhook", e); + return; + } + } + + log.error("Error processing webhook for job {}", managedJob.getExternalId(), e); try { if (project == null) { @@ -243,14 +291,14 @@ public void processWebhookAsync( } if (project != null) { initializeProjectAssociations(project); - postErrorToVcs(provider, project, payload, "Processing failed: " + e.getMessage(), placeholderCommentId, job); + postErrorToVcs(provider, project, payload, "Processing failed: " + e.getMessage(), placeholderCommentId, managedJob); } } catch (Exception postError) { log.error("Failed to post error to VCS: {}", postError.getMessage()); } try { - jobService.failJob(job, "Processing failed: " + e.getMessage()); + jobService.failJob(managedJob, "Processing failed: " + e.getMessage()); } catch (Exception failError) { log.error("Failed to mark job as failed: {}", failError.getMessage()); } From 409c42df5ffaf13b6ab101f74496941d87977d3a Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 11:05:40 +0200 Subject: [PATCH 06/20] feat: Enable transaction management in processWebhookAsync to support lazy loading of associations --- .../generic/processor/WebhookAsyncProcessor.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 133c148f..52ce4565 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -83,10 +83,11 @@ public WebhookAsyncProcessor( /** * Process a webhook asynchronously. - * Note: This method is NOT transactional to avoid issues with nested transactions - * (e.g., failJob uses REQUIRES_NEW). Each inner operation manages its own transaction. + * This method uses a transaction to ensure lazy associations can be loaded. + * Inner operations like failJob use REQUIRES_NEW which creates nested transactions as needed. 
*/ @Async("webhookExecutor") + @Transactional public void processWebhookAsync( EVcsProvider provider, Long projectId, From 11c983c4b5377c330f3066cc5541b483f802547c Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 11:57:26 +0200 Subject: [PATCH 07/20] feat: Re-fetch job entities in JobService methods to ensure consistency across transaction contexts --- .../codecrow/core/service/JobService.java | 8 ++ .../processor/WebhookAsyncProcessor.java | 88 ++++--------------- 2 files changed, 27 insertions(+), 69 deletions(-) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index 04036c47..29148529 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -224,6 +224,8 @@ private String getCommandJobTitle(JobType type, Long prNumber) { */ @Transactional public Job startJob(Job job) { + // Re-fetch the job in case it was passed from a different transaction context + job = jobRepository.findById(job.getId()).orElse(job); job.start(); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "start", "Job started"); @@ -244,6 +246,8 @@ public Job startJob(String externalId) { */ @Transactional public Job completeJob(Job job) { + // Re-fetch the job in case it was passed from a different transaction context + job = jobRepository.findById(job.getId()).orElse(job); job.complete(); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "complete", "Job completed successfully"); @@ -287,6 +291,8 @@ public Job failJob(Job job, String errorMessage) { */ @Transactional public Job cancelJob(Job job) { + // Re-fetch the job in case it was passed from a different transaction context + job = jobRepository.findById(job.getId()).orElse(job); job.cancel(); job = jobRepository.save(job); addLog(job, JobLogLevel.WARN, "cancel", "Job cancelled"); @@ -299,6 +305,8 @@ public Job cancelJob(Job job) { */ @Transactional public Job skipJob(Job job, String reason) { + // Re-fetch the job in case it was passed from a different transaction context + job = jobRepository.findById(job.getId()).orElse(job); job.skip(reason); job = jobRepository.save(job); addLog(job, JobLogLevel.INFO, "skipped", reason); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 52ce4565..1d41b507 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -83,11 +83,8 @@ public WebhookAsyncProcessor( /** * Process a webhook asynchronously. - * This method uses a transaction to ensure lazy associations can be loaded. - * Inner operations like failJob use REQUIRES_NEW which creates nested transactions as needed. 
*/ @Async("webhookExecutor") - @Transactional public void processWebhookAsync( EVcsProvider provider, Long projectId, @@ -98,14 +95,6 @@ public void processWebhookAsync( String placeholderCommentId = null; Project project = null; - // Store job external ID for re-fetching - the passed Job entity is detached - // since it was created in the HTTP request transaction which has already committed - String jobExternalId = job.getExternalId(); - - // Declare managed job reference that will be set after re-fetching - // This needs to be accessible in catch blocks for error handling - Job managedJob = null; - try { // Re-fetch project to ensure all lazy associations are available project = projectRepository.findById(projectId) @@ -114,19 +103,11 @@ public void processWebhookAsync( // Initialize lazy associations we'll need initializeProjectAssociations(project); - // Re-fetch the job by external ID to get a managed entity in the current context - // This is necessary because the Job was created in the HTTP request transaction - // which has already committed by the time this async method runs - managedJob = jobService.findByExternalIdOrThrow(jobExternalId); - - // Create final reference for use in lambda - final Job jobForLambda = managedJob; - - jobService.startJob(managedJob); + jobService.startJob(job); // Post placeholder comment immediately if this is a CodeCrow command on a PR if (payload.hasCodecrowCommand() && payload.pullRequestId() != null) { - placeholderCommentId = postPlaceholderComment(provider, project, payload, managedJob); + placeholderCommentId = postPlaceholderComment(provider, project, payload, job); } // Store placeholder ID for use in result posting @@ -136,7 +117,7 @@ public void processWebhookAsync( WebhookHandler.WebhookResult result = handler.handle(payload, project, event -> { String message = (String) event.getOrDefault("message", "Processing..."); String state = (String) event.getOrDefault("state", "processing"); - jobService.info(jobForLambda, state, message); + jobService.info(job, state, message); }); // Check if the webhook was ignored (e.g., branch not matching pattern, analysis disabled) @@ -149,14 +130,14 @@ public void processWebhookAsync( // Delete the job entirely - don't clutter DB with ignored webhooks // If deletion fails, skip the job instead try { - jobService.deleteIgnoredJob(managedJob, result.message()); + jobService.deleteIgnoredJob(job, result.message()); } catch (Exception deleteError) { log.warn("Failed to delete ignored job {}, skipping instead: {}", - managedJob.getExternalId(), deleteError.getMessage()); + job.getExternalId(), deleteError.getMessage()); try { - jobService.skipJob(managedJob, result.message()); + jobService.skipJob(job, result.message()); } catch (Exception skipError) { - log.error("Failed to skip job {}: {}", managedJob.getExternalId(), skipError.getMessage()); + log.error("Failed to skip job {}: {}", job.getExternalId(), skipError.getMessage()); } } return; @@ -164,38 +145,28 @@ public void processWebhookAsync( if (result.success()) { // Post result to VCS if there's content to post - postResultToVcs(provider, project, payload, result, finalPlaceholderCommentId, managedJob); + postResultToVcs(provider, project, payload, result, finalPlaceholderCommentId, job); if (result.data().containsKey("analysisId")) { Long analysisId = ((Number) result.data().get("analysisId")).longValue(); - jobService.info(managedJob, "complete", "Analysis completed. Analysis ID: " + analysisId); + jobService.info(job, "complete", "Analysis completed. 
Analysis ID: " + analysisId); } - jobService.completeJob(managedJob); + jobService.completeJob(job); } else { // Post error to VCS (update placeholder if exists) - but ensure failJob is always called try { - postErrorToVcs(provider, project, payload, result.message(), finalPlaceholderCommentId, managedJob); + postErrorToVcs(provider, project, payload, result.message(), finalPlaceholderCommentId, job); } catch (Exception postError) { log.error("Failed to post error to VCS: {}", postError.getMessage()); } // Always mark the job as failed, even if posting to VCS failed - jobService.failJob(managedJob, result.message()); + jobService.failJob(job, result.message()); } } catch (DiffTooLargeException diffEx) { // Handle diff too large - this is a soft skip, not an error log.warn("Diff too large for analysis - skipping: {}", diffEx.getMessage()); - // Re-fetch job if not yet fetched - if (managedJob == null) { - try { - managedJob = jobService.findByExternalIdOrThrow(jobExternalId); - } catch (Exception fetchError) { - log.error("Failed to fetch job {} for skip operation: {}", jobExternalId, fetchError.getMessage()); - return; - } - } - String skipMessage = String.format( "⚠️ **Analysis Skipped - PR Too Large**\n\n" + "This PR's diff exceeds the configured token limit:\n" + @@ -216,14 +187,14 @@ public void processWebhookAsync( } if (project != null) { initializeProjectAssociations(project); - postInfoToVcs(provider, project, payload, skipMessage, placeholderCommentId, managedJob); + postInfoToVcs(provider, project, payload, skipMessage, placeholderCommentId, job); } } catch (Exception postError) { log.error("Failed to post skip message to VCS: {}", postError.getMessage()); } try { - jobService.skipJob(managedJob, "Diff too large: " + diffEx.getEstimatedTokens() + " tokens > " + diffEx.getMaxAllowedTokens() + " limit"); + jobService.skipJob(job, "Diff too large: " + diffEx.getEstimatedTokens() + " tokens > " + diffEx.getMaxAllowedTokens() + " limit"); } catch (Exception skipError) { log.error("Failed to skip job: {}", skipError.getMessage()); } @@ -232,16 +203,6 @@ public void processWebhookAsync( // Handle lock acquisition failure - mark job as failed log.warn("Lock acquisition failed for analysis: {}", lockEx.getMessage()); - // Re-fetch job if not yet fetched - if (managedJob == null) { - try { - managedJob = jobService.findByExternalIdOrThrow(jobExternalId); - } catch (Exception fetchError) { - log.error("Failed to fetch job {} for fail operation: {}", jobExternalId, fetchError.getMessage()); - return; - } - } - String failMessage = String.format( "⚠️ **Analysis Failed - Resource Locked**\n\n" + "Could not acquire analysis lock after timeout:\n" + @@ -260,31 +221,20 @@ public void processWebhookAsync( } if (project != null) { initializeProjectAssociations(project); - postErrorToVcs(provider, project, payload, failMessage, placeholderCommentId, managedJob); + postErrorToVcs(provider, project, payload, failMessage, placeholderCommentId, job); } } catch (Exception postError) { log.error("Failed to post lock error to VCS: {}", postError.getMessage()); } try { - jobService.failJob(managedJob, "Lock acquisition timeout: " + lockEx.getMessage()); + jobService.failJob(job, "Lock acquisition timeout: " + lockEx.getMessage()); } catch (Exception failError) { log.error("Failed to fail job: {}", failError.getMessage()); } } catch (Exception e) { - // Re-fetch job if not yet fetched - if (managedJob == null) { - try { - managedJob = jobService.findByExternalIdOrThrow(jobExternalId); - } catch (Exception 
fetchError) { - log.error("Failed to fetch job {} for fail operation: {}", jobExternalId, fetchError.getMessage()); - log.error("Original error processing webhook", e); - return; - } - } - - log.error("Error processing webhook for job {}", managedJob.getExternalId(), e); + log.error("Error processing webhook for job {}", job.getExternalId(), e); try { if (project == null) { @@ -292,14 +242,14 @@ public void processWebhookAsync( } if (project != null) { initializeProjectAssociations(project); - postErrorToVcs(provider, project, payload, "Processing failed: " + e.getMessage(), placeholderCommentId, managedJob); + postErrorToVcs(provider, project, payload, "Processing failed: " + e.getMessage(), placeholderCommentId, job); } } catch (Exception postError) { log.error("Failed to post error to VCS: {}", postError.getMessage()); } try { - jobService.failJob(managedJob, "Processing failed: " + e.getMessage()); + jobService.failJob(job, "Processing failed: " + e.getMessage()); } catch (Exception failError) { log.error("Failed to mark job as failed: {}", failError.getMessage()); } From c75eaba20c13abbc0a170271f3a94667093c2745 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 12:04:31 +0200 Subject: [PATCH 08/20] feat: Add @Transactional annotation to processWebhookAsync for lazy loading of associations --- .../pipelineagent/generic/processor/WebhookAsyncProcessor.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 1d41b507..f9e210a2 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -83,8 +83,10 @@ public WebhookAsyncProcessor( /** * Process a webhook asynchronously. + * Uses @Transactional to ensure lazy associations can be loaded. 
*/ @Async("webhookExecutor") + @Transactional public void processWebhookAsync( EVcsProvider provider, Long projectId, From 8afc0ada6a03ed71399c630d65cdbe00519d865a Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 12:27:09 +0200 Subject: [PATCH 09/20] feat: Implement self-injection in WebhookAsyncProcessor for proper transaction management in async context --- .../processor/WebhookAsyncProcessor.java | 30 +++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index f9e210a2..85dd25b2 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -15,6 +15,8 @@ import java.io.IOException; import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Lazy; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -71,6 +73,11 @@ public class WebhookAsyncProcessor { private final JobService jobService; private final VcsServiceFactory vcsServiceFactory; + // Self-injection for @Transactional proxy to work from @Async method + @Autowired + @Lazy + private WebhookAsyncProcessor self; + public WebhookAsyncProcessor( ProjectRepository projectRepository, JobService jobService, @@ -83,16 +90,35 @@ public WebhookAsyncProcessor( /** * Process a webhook asynchronously. - * Uses @Transactional to ensure lazy associations can be loaded. + * Delegates to a transactional method to ensure lazy associations can be loaded. + * NOTE: @Async and @Transactional cannot be on the same method - the transaction + * proxy gets bypassed. We use self-injection to call a separate @Transactional method. */ @Async("webhookExecutor") - @Transactional public void processWebhookAsync( EVcsProvider provider, Long projectId, WebhookPayload payload, WebhookHandler handler, Job job + ) { + log.info("processWebhookAsync started for job {} (projectId={}, event={})", + job.getExternalId(), projectId, payload.eventType()); + // Delegate to transactional method via self-reference to ensure proxy is used + self.processWebhookInTransaction(provider, projectId, payload, handler, job); + } + + /** + * Process webhook within a transaction. + * Called from async method via self-injection to ensure transaction proxy works. 
+ */ + @Transactional + public void processWebhookInTransaction( + EVcsProvider provider, + Long projectId, + WebhookPayload payload, + WebhookHandler handler, + Job job ) { String placeholderCommentId = null; Project project = null; From 402486b97def35d3098f8e2a6a8b4c2b4c72f4c4 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 13:09:26 +0200 Subject: [PATCH 10/20] feat: Enhance logging and error handling in processWebhookAsync for improved job management --- .../processor/WebhookAsyncProcessor.java | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 85dd25b2..819855ee 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -104,8 +104,19 @@ public void processWebhookAsync( ) { log.info("processWebhookAsync started for job {} (projectId={}, event={})", job.getExternalId(), projectId, payload.eventType()); - // Delegate to transactional method via self-reference to ensure proxy is used - self.processWebhookInTransaction(provider, projectId, payload, handler, job); + try { + // Delegate to transactional method via self-reference to ensure proxy is used + self.processWebhookInTransaction(provider, projectId, payload, handler, job); + log.info("processWebhookAsync completed normally for job {}", job.getExternalId()); + } catch (Exception e) { + log.error("processWebhookAsync FAILED for job {}: {}", job.getExternalId(), e.getMessage(), e); + // Try to fail the job so it doesn't stay in PENDING + try { + jobService.failJob(job, "Async processing failed: " + e.getMessage()); + } catch (Exception failError) { + log.error("Failed to mark job {} as failed: {}", job.getExternalId(), failError.getMessage()); + } + } } /** @@ -120,6 +131,7 @@ public void processWebhookInTransaction( WebhookHandler handler, Job job ) { + log.info("processWebhookInTransaction ENTERED for job {}", job.getExternalId()); String placeholderCommentId = null; Project project = null; @@ -131,7 +143,9 @@ public void processWebhookInTransaction( // Initialize lazy associations we'll need initializeProjectAssociations(project); + log.info("Calling jobService.startJob for job {}", job.getExternalId()); jobService.startJob(job); + log.info("jobService.startJob completed for job {}", job.getExternalId()); // Post placeholder comment immediately if this is a CodeCrow command on a PR if (payload.hasCodecrowCommand() && payload.pullRequestId() != null) { @@ -142,15 +156,17 @@ public void processWebhookInTransaction( final String finalPlaceholderCommentId = placeholderCommentId; // Create event consumer that logs to job + log.info("Calling handler.handle for job {}", job.getExternalId()); WebhookHandler.WebhookResult result = handler.handle(payload, project, event -> { String message = (String) event.getOrDefault("message", "Processing..."); String state = (String) event.getOrDefault("state", "processing"); jobService.info(job, state, message); }); + log.info("handler.handle completed for job {}, result status={}", job.getExternalId(), result.status()); // Check if the webhook was ignored 
(e.g., branch not matching pattern, analysis disabled) if ("ignored".equals(result.status())) { - log.info("Webhook ignored: {}", result.message()); + log.info("Webhook ignored for job {}: {}", job.getExternalId(), result.message()); // Delete placeholder if we posted one for an ignored command if (finalPlaceholderCommentId != null) { deletePlaceholderComment(provider, project, payload, finalPlaceholderCommentId); @@ -158,7 +174,9 @@ public void processWebhookInTransaction( // Delete the job entirely - don't clutter DB with ignored webhooks // If deletion fails, skip the job instead try { + log.info("Deleting ignored job {}", job.getExternalId()); jobService.deleteIgnoredJob(job, result.message()); + log.info("Successfully deleted ignored job {}", job.getExternalId()); } catch (Exception deleteError) { log.warn("Failed to delete ignored job {}, skipping instead: {}", job.getExternalId(), deleteError.getMessage()); @@ -179,7 +197,9 @@ public void processWebhookInTransaction( Long analysisId = ((Number) result.data().get("analysisId")).longValue(); jobService.info(job, "complete", "Analysis completed. Analysis ID: " + analysisId); } + log.info("Calling jobService.completeJob for job {}", job.getExternalId()); jobService.completeJob(job); + log.info("jobService.completeJob completed for job {}", job.getExternalId()); } else { // Post error to VCS (update placeholder if exists) - but ensure failJob is always called try { From fdcdca0aac15e47ba8f99826d810384a55cfa854 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 14:24:23 +0200 Subject: [PATCH 11/20] feat: Implement webhook deduplication service to prevent duplicate commit analysis --- frontend | 2 +- .../BitbucketCloudBranchWebhookHandler.java | 18 +++- .../pipelineagent/config/AsyncConfig.java | 11 ++- .../controller/ProviderWebhookController.java | 5 ++ .../service/WebhookDeduplicationService.java | 86 +++++++++++++++++++ 5 files changed, 119 insertions(+), 3 deletions(-) create mode 100644 java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/service/WebhookDeduplicationService.java diff --git a/frontend b/frontend index fdbb0555..d97b8264 160000 --- a/frontend +++ b/frontend @@ -1 +1 @@ -Subproject commit fdbb055524794f49a0299fd7f020177243855e58 +Subproject commit d97b826464e13edb3da8a2b9a3e4b32680f23001 diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudBranchWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudBranchWebhookHandler.java index 81dd08c9..4b6a4a9e 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudBranchWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/webhookhandler/BitbucketCloudBranchWebhookHandler.java @@ -7,6 +7,7 @@ import org.rostilos.codecrow.analysisengine.dto.request.processor.BranchProcessRequest; import org.rostilos.codecrow.analysisengine.processor.analysis.BranchAnalysisProcessor; import org.rostilos.codecrow.pipelineagent.generic.dto.webhook.WebhookPayload; +import org.rostilos.codecrow.pipelineagent.generic.service.WebhookDeduplicationService; import org.rostilos.codecrow.pipelineagent.generic.webhookhandler.AbstractWebhookHandler; import 
org.rostilos.codecrow.pipelineagent.generic.webhookhandler.WebhookHandler; import org.slf4j.Logger; @@ -38,13 +39,16 @@ public class BitbucketCloudBranchWebhookHandler extends AbstractWebhookHandler i private final BranchAnalysisProcessor branchAnalysisProcessor; private final RagOperationsService ragOperationsService; + private final WebhookDeduplicationService deduplicationService; public BitbucketCloudBranchWebhookHandler( BranchAnalysisProcessor branchAnalysisProcessor, - @Autowired(required = false) RagOperationsService ragOperationsService + @Autowired(required = false) RagOperationsService ragOperationsService, + WebhookDeduplicationService deduplicationService ) { this.branchAnalysisProcessor = branchAnalysisProcessor; this.ragOperationsService = ragOperationsService; + this.deduplicationService = deduplicationService; } @Override @@ -87,6 +91,12 @@ public WebhookResult handle(WebhookPayload payload, Project project, Consumer { + log.error("WEBHOOK EXECUTOR REJECTED TASK! Queue is full. Pool size: {}, Active: {}, Queue size: {}", + e.getPoolSize(), e.getActiveCount(), e.getQueue().size()); + // Try to run in caller thread as fallback + if (!e.isShutdown()) { + r.run(); + } + }); executor.initialize(); - log.info("Webhook executor initialized with core={}, max={}", 4, 8); + log.info("Webhook executor initialized with core={}, max={}, queueCapacity={}", 4, 8, 100); return executor; } diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/controller/ProviderWebhookController.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/controller/ProviderWebhookController.java index d475014f..95ecfc0a 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/controller/ProviderWebhookController.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/controller/ProviderWebhookController.java @@ -260,6 +260,9 @@ private ResponseEntity processWebhook(EVcsProvider provider, WebhookPayload p String jobUrl = buildJobUrl(project, job); String logsStreamUrl = buildJobLogsStreamUrl(job); + log.info("Dispatching webhook to async processor: job={}, event={}", + job.getExternalId(), payload.eventType()); + // Process webhook asynchronously with proper transactional context webhookAsyncProcessor.processWebhookAsync( provider, @@ -269,6 +272,8 @@ private ResponseEntity processWebhook(EVcsProvider provider, WebhookPayload p job ); + log.info("Webhook dispatched to async processor: job={}", job.getExternalId()); + return ResponseEntity.accepted().body(Map.of( "status", "accepted", "message", "Webhook received, processing started", diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/service/WebhookDeduplicationService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/service/WebhookDeduplicationService.java new file mode 100644 index 00000000..a7f87088 --- /dev/null +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/service/WebhookDeduplicationService.java @@ -0,0 +1,86 @@ +package org.rostilos.codecrow.pipelineagent.generic.service; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Service; + +import java.time.Instant; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + 
+/** + * Service to deduplicate webhook events based on commit hash. + * + * When a PR is merged in Bitbucket, it sends both: + * - pullrequest:fulfilled (with merge commit) + * - repo:push (with same merge commit) + * + * Both events would trigger analysis for the same commit, causing duplicate processing. + * This service tracks recently analyzed commits and skips duplicates within a time window. + */ +@Service +public class WebhookDeduplicationService { + + private static final Logger log = LoggerFactory.getLogger(WebhookDeduplicationService.class); + + /** + * Time window in seconds to consider events as duplicates. + */ + private static final long DEDUP_WINDOW_SECONDS = 30; + + /** + * Cache of recently analyzed commits. + * Key: "projectId:commitHash" + * Value: timestamp when the analysis was triggered + */ + private final Map recentCommitAnalyses = new ConcurrentHashMap<>(); + + /** + * Check if a commit analysis should be skipped as a duplicate. + * If not a duplicate, records this commit for future deduplication. + * + * @param projectId The project ID + * @param commitHash The commit being analyzed + * @param eventType The webhook event type (for logging) + * @return true if this is a duplicate and should be skipped, false if it should proceed + */ + public boolean isDuplicateCommitAnalysis(Long projectId, String commitHash, String eventType) { + if (commitHash == null || commitHash.isBlank()) { + return false; + } + + String key = projectId + ":" + commitHash; + Instant now = Instant.now(); + + Instant lastAnalysis = recentCommitAnalyses.get(key); + + if (lastAnalysis != null) { + long secondsSinceLastAnalysis = now.getEpochSecond() - lastAnalysis.getEpochSecond(); + + if (secondsSinceLastAnalysis < DEDUP_WINDOW_SECONDS) { + log.info("Skipping duplicate commit analysis: project={}, commit={}, event={}, " + + "lastAnalysis={}s ago (within {}s window)", + projectId, commitHash, eventType, secondsSinceLastAnalysis, DEDUP_WINDOW_SECONDS); + return true; + } + } + + // Record this analysis + recentCommitAnalyses.put(key, now); + + // Cleanup old entries + cleanupOldEntries(now); + + return false; + } + + /** + * Remove entries older than the dedup window to prevent memory growth. 
+ */ + private void cleanupOldEntries(Instant now) { + recentCommitAnalyses.entrySet().removeIf(entry -> { + long age = now.getEpochSecond() - entry.getValue().getEpochSecond(); + return age > DEDUP_WINDOW_SECONDS * 2; + }); + } +} From e3213617d76e45d1d15cdbec1bc89452420f06a3 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 14:35:33 +0200 Subject: [PATCH 12/20] feat: Enhance job deletion process with logging and persistence context management --- .../rostilos/codecrow/core/service/JobService.java | 3 +++ .../pipeline-agent/src/main/java/module-info.java | 1 + .../generic/processor/WebhookAsyncProcessor.java | 11 +++++++++++ 3 files changed, 15 insertions(+) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index 29148529..39a9aca2 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -338,8 +338,11 @@ public void deleteIgnoredJob(Job job, String reason) { } // Delete any logs first (foreign key constraint) jobLogRepository.deleteByJobId(jobId); + log.info("Deleted job logs for ignored job {}", job.getExternalId()); // Delete the job jobRepository.delete(existingJob.get()); + jobRepository.flush(); // Force immediate execution + log.info("Successfully deleted ignored job {} from database", job.getExternalId()); } /** diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/module-info.java b/java-ecosystem/services/pipeline-agent/src/main/java/module-info.java index 739f411f..3188b03a 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/module-info.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/module-info.java @@ -9,6 +9,7 @@ requires spring.beans; requires org.slf4j; requires jakarta.validation; + requires jakarta.persistence; requires spring.web; requires jjwt.api; requires okhttp3; diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index 819855ee..f374edea 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -13,6 +13,8 @@ import org.rostilos.codecrow.analysisengine.service.vcs.VcsServiceFactory; import org.slf4j.Logger; +import jakarta.persistence.EntityManager; +import jakarta.persistence.PersistenceContext; import java.io.IOException; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -73,6 +75,9 @@ public class WebhookAsyncProcessor { private final JobService jobService; private final VcsServiceFactory vcsServiceFactory; + @PersistenceContext + private EntityManager entityManager; + // Self-injection for @Transactional proxy to work from @Async method @Autowired @Lazy @@ -176,6 +181,12 @@ public void processWebhookInTransaction( try { log.info("Deleting ignored job {}", job.getExternalId()); jobService.deleteIgnoredJob(job, result.message()); + // CRITICAL: Detach the job from this transaction's persistence context + // to prevent JPA 
from re-saving it when the outer transaction commits + if (entityManager.contains(job)) { + entityManager.detach(job); + log.info("Detached deleted job {} from persistence context", job.getExternalId()); + } log.info("Successfully deleted ignored job {}", job.getExternalId()); } catch (Exception deleteError) { log.warn("Failed to delete ignored job {}, skipping instead: {}", From ebd0fad92d106dec8dbf6cfd550c1c74ca3f26b2 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 14:43:36 +0200 Subject: [PATCH 13/20] feat: Improve job deletion process with enhanced logging and error handling --- .../codecrow/core/service/JobService.java | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index 39a9aca2..31840dcf 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -336,13 +336,21 @@ public void deleteIgnoredJob(Job job, String reason) { log.warn("Cannot delete ignored job {} - not found in database", job.getExternalId()); return; } - // Delete any logs first (foreign key constraint) - jobLogRepository.deleteByJobId(jobId); - log.info("Deleted job logs for ignored job {}", job.getExternalId()); - // Delete the job - jobRepository.delete(existingJob.get()); - jobRepository.flush(); // Force immediate execution - log.info("Successfully deleted ignored job {} from database", job.getExternalId()); + try { + // Delete any logs first (foreign key constraint) + jobLogRepository.deleteByJobId(jobId); + jobLogRepository.flush(); + log.info("Deleted job logs for ignored job {}", job.getExternalId()); + + // Delete the job + log.info("About to delete job entity {} (id={})", job.getExternalId(), jobId); + jobRepository.deleteById(jobId); + jobRepository.flush(); // Force immediate execution + log.info("Successfully deleted ignored job {} from database", job.getExternalId()); + } catch (Exception e) { + log.error("Failed to delete ignored job {}: {} - {}", job.getExternalId(), e.getClass().getSimpleName(), e.getMessage(), e); + throw e; // Re-throw so caller knows deletion failed + } } /** From 092b36138bf1dabf66c74f3d53a7dc768921aa31 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 14:56:56 +0200 Subject: [PATCH 14/20] feat: Add method to delete job by ID in JobRepository and update JobService for direct deletion --- .../repository/job/JobRepository.java | 4 ++++ .../codecrow/core/service/JobService.java | 18 ++++++------------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/job/JobRepository.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/job/JobRepository.java index 6dd799b5..b7f75a97 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/job/JobRepository.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/persistence/repository/job/JobRepository.java @@ -101,4 +101,8 @@ Page findByProjectIdAndDateRange( @Modifying @Query("DELETE FROM Job j WHERE j.project.id = :projectId") void deleteByProjectId(@Param("projectId") Long projectId); + + @Modifying + @Query("DELETE FROM Job j WHERE j.id = :jobId") + 
void deleteJobById(@Param("jobId") Long jobId); } diff --git a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java index 31840dcf..db182728 100644 --- a/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java +++ b/java-ecosystem/libs/core/src/main/java/org/rostilos/codecrow/core/service/JobService.java @@ -325,31 +325,25 @@ public Job skipJob(Job job, String reason) { @Transactional(propagation = org.springframework.transaction.annotation.Propagation.REQUIRES_NEW) public void deleteIgnoredJob(Job job, String reason) { log.info("Deleting ignored job {} ({}): {}", job.getExternalId(), job.getJobType(), reason); - // Re-fetch the job to ensure we have a fresh entity in this new transaction Long jobId = job.getId(); if (jobId == null) { log.warn("Cannot delete ignored job - job ID is null"); return; } - Optional existingJob = jobRepository.findById(jobId); - if (existingJob.isEmpty()) { - log.warn("Cannot delete ignored job {} - not found in database", job.getExternalId()); - return; - } + try { - // Delete any logs first (foreign key constraint) + // Use direct JPQL queries to avoid JPA entity lifecycle issues + // Delete logs first (foreign key constraint) jobLogRepository.deleteByJobId(jobId); - jobLogRepository.flush(); log.info("Deleted job logs for ignored job {}", job.getExternalId()); - // Delete the job + // Delete the job using direct JPQL query (bypasses entity state tracking) log.info("About to delete job entity {} (id={})", job.getExternalId(), jobId); - jobRepository.deleteById(jobId); - jobRepository.flush(); // Force immediate execution + jobRepository.deleteJobById(jobId); log.info("Successfully deleted ignored job {} from database", job.getExternalId()); } catch (Exception e) { log.error("Failed to delete ignored job {}: {} - {}", job.getExternalId(), e.getClass().getSimpleName(), e.getMessage(), e); - throw e; // Re-throw so caller knows deletion failed + throw e; } } From 61d2620c9bb1977d7e52d3f8d4c38918c0be4d93 Mon Sep 17 00:00:00 2001 From: rostislav Date: Tue, 27 Jan 2026 15:04:14 +0200 Subject: [PATCH 15/20] feat: Simplify job handling by marking ignored jobs as SKIPPED instead of deleting --- .../processor/WebhookAsyncProcessor.java | 30 +++---------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index f374edea..b1144c53 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -13,8 +13,6 @@ import org.rostilos.codecrow.analysisengine.service.vcs.VcsServiceFactory; import org.slf4j.Logger; -import jakarta.persistence.EntityManager; -import jakarta.persistence.PersistenceContext; import java.io.IOException; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -75,9 +73,6 @@ public class WebhookAsyncProcessor { private final JobService jobService; private final VcsServiceFactory vcsServiceFactory; - @PersistenceContext - private EntityManager 
entityManager; - // Self-injection for @Transactional proxy to work from @Async method @Autowired @Lazy @@ -176,27 +171,10 @@ public void processWebhookInTransaction( if (finalPlaceholderCommentId != null) { deletePlaceholderComment(provider, project, payload, finalPlaceholderCommentId); } - // Delete the job entirely - don't clutter DB with ignored webhooks - // If deletion fails, skip the job instead - try { - log.info("Deleting ignored job {}", job.getExternalId()); - jobService.deleteIgnoredJob(job, result.message()); - // CRITICAL: Detach the job from this transaction's persistence context - // to prevent JPA from re-saving it when the outer transaction commits - if (entityManager.contains(job)) { - entityManager.detach(job); - log.info("Detached deleted job {} from persistence context", job.getExternalId()); - } - log.info("Successfully deleted ignored job {}", job.getExternalId()); - } catch (Exception deleteError) { - log.warn("Failed to delete ignored job {}, skipping instead: {}", - job.getExternalId(), deleteError.getMessage()); - try { - jobService.skipJob(job, result.message()); - } catch (Exception skipError) { - log.error("Failed to skip job {}: {}", job.getExternalId(), skipError.getMessage()); - } - } + // Mark job as SKIPPED - simpler and more reliable than deletion + // which can have transaction/lock issues with concurrent requests + jobService.skipJob(job, result.message()); + log.info("Marked ignored job {} as SKIPPED", job.getExternalId()); return; } From 704a7a256429f514695d8266af9ca7e9724ed45a Mon Sep 17 00:00:00 2001 From: rostislav Date: Wed, 28 Jan 2026 02:42:43 +0200 Subject: [PATCH 16/20] feat: Enhance AI connection logging and refactor placeholder management in webhook processing --- .../service/BitbucketAiClientService.java | 4 + .../processor/WebhookAsyncProcessor.java | 78 ++----------------- .../command/ReviewCommandProcessor.java | 5 ++ .../generic/utils/CommentPlaceholders.java | 66 ++++++++++++++++ .../AbstractWebhookHandler.java | 1 - .../github/service/GitHubAiClientService.java | 4 + .../gitlab/service/GitLabAiClientService.java | 4 + .../project/service/ProjectService.java | 25 ++++-- 8 files changed, 111 insertions(+), 76 deletions(-) create mode 100644 java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/utils/CommentPlaceholders.java diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java index 142cba75..6ebb0894 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/bitbucket/service/BitbucketAiClientService.java @@ -124,6 +124,10 @@ public AiAnalysisRequest buildPrAnalysisRequest( VcsConnection vcsConnection = vcsInfo.vcsConnection(); AIConnection aiConnection = project.getAiBinding().getAiConnection(); AIConnection projectAiConnection = project.getAiBinding().getAiConnection(); + + // CRITICAL: Log the AI connection being used for debugging + log.info("Building PR analysis request for project={}, AI model={}, provider={}, aiConnectionId={}", + project.getId(), aiConnection.getAiModel(), aiConnection.getProviderKey(), aiConnection.getId()); // Initialize variables List changedFiles = 
Collections.emptyList(); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java index b1144c53..d509c93c 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/WebhookAsyncProcessor.java @@ -6,6 +6,7 @@ import org.rostilos.codecrow.core.persistence.repository.project.ProjectRepository; import org.rostilos.codecrow.core.service.JobService; import org.rostilos.codecrow.pipelineagent.generic.dto.webhook.WebhookPayload; +import org.rostilos.codecrow.pipelineagent.generic.utils.CommentPlaceholders; import org.rostilos.codecrow.pipelineagent.generic.webhookhandler.WebhookHandler; import org.rostilos.codecrow.analysisengine.exception.AnalysisLockedException; import org.rostilos.codecrow.analysisengine.exception.DiffTooLargeException; @@ -30,45 +31,6 @@ public class WebhookAsyncProcessor { private static final Logger log = LoggerFactory.getLogger(WebhookAsyncProcessor.class); - /** Comment markers for CodeCrow command responses */ - private static final String CODECROW_COMMAND_MARKER = ""; - private static final String CODECROW_SUMMARY_MARKER = ""; - private static final String CODECROW_REVIEW_MARKER = ""; - - /** Placeholder messages for commands */ - private static final String PLACEHOLDER_ANALYZE = """ - 🔄 **CodeCrow is analyzing this PR...** - - This may take a few minutes depending on the size of the changes. - I'll update this comment with the results when the analysis is complete. - """; - - private static final String PLACEHOLDER_SUMMARIZE = """ - 🔄 **CodeCrow is generating a summary...** - - I'm analyzing the changes and creating diagrams. - This comment will be updated with the summary when ready. - """; - - private static final String PLACEHOLDER_REVIEW = """ - 🔄 **CodeCrow is reviewing this PR...** - - I'm examining the code changes for potential issues. - This comment will be updated with the review results when complete. - """; - - private static final String PLACEHOLDER_ASK = """ - 🔄 **CodeCrow is processing your question...** - - I'm analyzing the context to provide a helpful answer. - """; - - private static final String PLACEHOLDER_DEFAULT = """ - 🔄 **CodeCrow is processing...** - - Please wait while I complete this task. 
- """; - private final ProjectRepository projectRepository; private final JobService jobService; private final VcsServiceFactory vcsServiceFactory; @@ -419,7 +381,7 @@ private void postAskReply(VcsReportingService reportingService, Project project, private void postWithMarker(VcsReportingService reportingService, Project project, WebhookPayload payload, String content, String commandType, String placeholderCommentId, Job job) throws IOException { - String marker = getMarkerForCommandType(commandType); + String marker = CommentPlaceholders.getMarkerForCommandType(commandType); // If we have a placeholder comment, update it instead of creating a new one if (placeholderCommentId != null) { @@ -490,7 +452,7 @@ private void postErrorToVcs(EVcsProvider provider, Project project, WebhookPaylo Long.parseLong(payload.pullRequestId()), placeholderCommentId, content, - CODECROW_COMMAND_MARKER + CommentPlaceholders.CODECROW_COMMAND_MARKER ); log.info("Updated placeholder comment {} with error for PR {}", placeholderCommentId, payload.pullRequestId()); } else { @@ -499,7 +461,7 @@ private void postErrorToVcs(EVcsProvider provider, Project project, WebhookPaylo project, Long.parseLong(payload.pullRequestId()), content, - CODECROW_COMMAND_MARKER + CommentPlaceholders.CODECROW_COMMAND_MARKER ); log.info("Posted error to PR {}", payload.pullRequestId()); } @@ -529,7 +491,7 @@ private void postInfoToVcs(EVcsProvider provider, Project project, WebhookPayloa Long.parseLong(payload.pullRequestId()), placeholderCommentId, infoMessage, - CODECROW_COMMAND_MARKER + CommentPlaceholders.CODECROW_COMMAND_MARKER ); log.info("Updated placeholder comment {} with info message for PR {}", placeholderCommentId, payload.pullRequestId()); } else { @@ -538,7 +500,7 @@ private void postInfoToVcs(EVcsProvider provider, Project project, WebhookPayloa project, Long.parseLong(payload.pullRequestId()), infoMessage, - CODECROW_COMMAND_MARKER + CommentPlaceholders.CODECROW_COMMAND_MARKER ); log.info("Posted info message to PR {}", payload.pullRequestId()); } @@ -641,8 +603,8 @@ private String postPlaceholderComment(EVcsProvider provider, Project project, ? payload.getCodecrowCommand().type().name().toLowerCase() : "default"; - String placeholderContent = getPlaceholderMessage(commandType); - String marker = getMarkerForCommandType(commandType); + String placeholderContent = CommentPlaceholders.getPlaceholderMessage(commandType); + String marker = CommentPlaceholders.getMarkerForCommandType(commandType); // Delete any previous comments with the same marker before posting placeholder try { @@ -691,28 +653,4 @@ private void deletePlaceholderComment(EVcsProvider provider, Project project, log.warn("Failed to delete placeholder comment {}: {}", commentId, e.getMessage()); } } - - /** - * Get the placeholder message for a command type. - */ - private String getPlaceholderMessage(String commandType) { - return switch (commandType.toLowerCase()) { - case "analyze" -> PLACEHOLDER_ANALYZE; - case "summarize" -> PLACEHOLDER_SUMMARIZE; - case "review" -> PLACEHOLDER_REVIEW; - case "ask" -> PLACEHOLDER_ASK; - default -> PLACEHOLDER_DEFAULT; - }; - } - - /** - * Get the comment marker for a command type. 
- */ - private String getMarkerForCommandType(String commandType) { - return switch (commandType.toLowerCase()) { - case "summarize" -> CODECROW_SUMMARY_MARKER; - case "review" -> CODECROW_REVIEW_MARKER; - default -> CODECROW_COMMAND_MARKER; - }; - } } diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java index 8d381e79..92a00954 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/processor/command/ReviewCommandProcessor.java @@ -154,6 +154,11 @@ private ReviewRequest buildReviewRequest(Project project, WebhookPayload payload } AIConnection aiConnection = project.getAiBinding().getAiConnection(); + + // CRITICAL: Log the AI connection being used for debugging + log.info("Building review command request for project={}, AI model={}, provider={}, aiConnectionId={}", + project.getId(), aiConnection.getAiModel(), aiConnection.getProviderKey(), aiConnection.getId()); + String decryptedApiKey = tokenEncryptionService.decrypt(aiConnection.getApiKeyEncrypted()); // Get VCS credentials using centralized extractor diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/utils/CommentPlaceholders.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/utils/CommentPlaceholders.java new file mode 100644 index 00000000..1145895a --- /dev/null +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/utils/CommentPlaceholders.java @@ -0,0 +1,66 @@ +package org.rostilos.codecrow.pipelineagent.generic.utils; + +public class CommentPlaceholders { + /** Comment markers for CodeCrow command responses */ + public static final String CODECROW_COMMAND_MARKER = ""; + public static final String CODECROW_SUMMARY_MARKER = ""; + public static final String CODECROW_REVIEW_MARKER = ""; + + /** Placeholder messages for commands */ + public static final String PLACEHOLDER_ANALYZE = """ + 🔄 **CodeCrow is analyzing this PR...** + + This may take a few minutes depending on the size of the changes. + I'll update this comment with the results when the analysis is complete. + """; + + public static final String PLACEHOLDER_SUMMARIZE = """ + 🔄 **CodeCrow is generating a summary...** + + I'm analyzing the changes and creating diagrams. + This comment will be updated with the summary when ready. + """; + + public static final String PLACEHOLDER_REVIEW = """ + 🔄 **CodeCrow is reviewing this PR...** + + I'm examining the code changes for potential issues. + This comment will be updated with the review results when complete. + """; + + public static final String PLACEHOLDER_ASK = """ + 🔄 **CodeCrow is processing your question...** + + I'm analyzing the context to provide a helpful answer. + """; + + public static final String PLACEHOLDER_DEFAULT = """ + 🔄 **CodeCrow is processing...** + + Please wait while I complete this task. + """; + + /** + * Get the placeholder message for a command type. 
+ */ + public static String getPlaceholderMessage(String commandType) { + return switch (commandType.toLowerCase()) { + case "analyze" -> PLACEHOLDER_ANALYZE; + case "summarize" -> PLACEHOLDER_SUMMARIZE; + case "review" -> PLACEHOLDER_REVIEW; + case "ask" -> PLACEHOLDER_ASK; + default -> PLACEHOLDER_DEFAULT; + }; + } + + /** + * Get the comment marker for a command type. + */ + public static String getMarkerForCommandType(String commandType) { + return switch (commandType.toLowerCase()) { + case "summarize" -> CODECROW_SUMMARY_MARKER; + case "review" -> CODECROW_REVIEW_MARKER; + default -> CODECROW_COMMAND_MARKER; + }; + } +} diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/AbstractWebhookHandler.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/AbstractWebhookHandler.java index 989c4438..e4f5e0e3 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/AbstractWebhookHandler.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/generic/webhookhandler/AbstractWebhookHandler.java @@ -3,7 +3,6 @@ import org.rostilos.codecrow.core.model.codeanalysis.AnalysisType; import org.rostilos.codecrow.core.model.project.Project; import org.rostilos.codecrow.core.model.project.config.BranchAnalysisConfig; -import org.rostilos.codecrow.core.model.project.config.ProjectConfig; import org.rostilos.codecrow.core.util.BranchPatternMatcher; import java.util.List; diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java index f8f0dc09..69193350 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/github/service/GitHubAiClientService.java @@ -118,6 +118,10 @@ private AiAnalysisRequest buildPrAnalysisRequest( VcsInfo vcsInfo = getVcsInfo(project); VcsConnection vcsConnection = vcsInfo.vcsConnection(); AIConnection aiConnection = project.getAiBinding().getAiConnection(); + + // CRITICAL: Log the AI connection being used for debugging + log.info("Building PR analysis request for project={}, AI model={}, provider={}, aiConnectionId={}", + project.getId(), aiConnection.getAiModel(), aiConnection.getProviderKey(), aiConnection.getId()); // Initialize variables List changedFiles = Collections.emptyList(); diff --git a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java index ed2be10a..fd736acf 100644 --- a/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java +++ b/java-ecosystem/services/pipeline-agent/src/main/java/org/rostilos/codecrow/pipelineagent/gitlab/service/GitLabAiClientService.java @@ -118,6 +118,10 @@ private AiAnalysisRequest buildMrAnalysisRequest( VcsInfo vcsInfo = getVcsInfo(project); VcsConnection vcsConnection = vcsInfo.vcsConnection(); AIConnection aiConnection = 
project.getAiBinding().getAiConnection(); + + // CRITICAL: Log the AI connection being used for debugging + log.info("Building MR analysis request for project={}, AI model={}, provider={}, aiConnectionId={}", + project.getId(), aiConnection.getAiModel(), aiConnection.getProviderKey(), aiConnection.getId()); // Initialize variables List changedFiles = Collections.emptyList(); diff --git a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java index ac98db3a..f14718a9 100644 --- a/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java +++ b/java-ecosystem/services/web-server/src/main/java/org/rostilos/codecrow/webserver/project/service/ProjectService.java @@ -375,18 +375,33 @@ public void updateRepositorySettings(Long workspaceId, Long projectId, UpdateRep @Transactional public boolean bindAiConnection(Long workspaceId, Long projectId, BindAiConnectionRequest request) throws SecurityException { - Project project = projectRepository.findByWorkspaceIdAndId(workspaceId, projectId) + // Use findByIdWithConnections to eagerly fetch aiBinding for proper orphan removal + Project project = projectRepository.findByIdWithConnections(projectId) .orElseThrow(() -> new NoSuchElementException("Project not found")); + + // Verify workspace ownership + if (!project.getWorkspace().getId().equals(workspaceId)) { + throw new NoSuchElementException("Project not found in workspace"); + } if (request.getAiConnectionId() != null) { Long aiConnectionId = request.getAiConnectionId(); AIConnection aiConnection = aiConnectionRepository.findByWorkspace_IdAndId(workspaceId, aiConnectionId) .orElseThrow(() -> new NoSuchElementException("Ai connection not found")); - ProjectAiConnectionBinding aiConnectionBinding = new ProjectAiConnectionBinding(); - aiConnectionBinding.setProject(project); - aiConnectionBinding.setAiConnection(aiConnection); - project.setAiConnectionBinding(aiConnectionBinding); + // Check if there's an existing binding that needs to be updated + ProjectAiConnectionBinding existingBinding = project.getAiBinding(); + if (existingBinding != null) { + // Update existing binding instead of creating new one + existingBinding.setAiConnection(aiConnection); + } else { + // Create new binding + ProjectAiConnectionBinding aiConnectionBinding = new ProjectAiConnectionBinding(); + aiConnectionBinding.setProject(project); + aiConnectionBinding.setAiConnection(aiConnection); + project.setAiConnectionBinding(aiConnectionBinding); + } + projectRepository.save(project); return true; } From 2e42ebc3d7ca43a21adca084d7eb2e957e319b1b Mon Sep 17 00:00:00 2001 From: rostislav Date: Wed, 28 Jan 2026 02:43:30 +0200 Subject: [PATCH 17/20] feat: Add logging for LLM creation and enhance diff snippet extraction for RAG context --- .../mcp-client/llm/llm_factory.py | 3 + .../service/multi_stage_orchestrator.py | 196 ++++++++++++------ .../mcp-client/service/review_service.py | 3 + 3 files changed, 136 insertions(+), 66 deletions(-) diff --git a/python-ecosystem/mcp-client/llm/llm_factory.py b/python-ecosystem/mcp-client/llm/llm_factory.py index 470e7c4e..61b14e85 100644 --- a/python-ecosystem/mcp-client/llm/llm_factory.py +++ b/python-ecosystem/mcp-client/llm/llm_factory.py @@ -139,6 +139,9 @@ def create_llm(ai_model: str, ai_provider: str, ai_api_key: str, temperature: Op # Normalize provider 
provider = LLMFactory._normalize_provider(ai_provider) + # CRITICAL: Log the model being used for debugging + logger.info(f"Creating LLM instance: provider={provider}, model={ai_model}, temperature={temperature}") + # Check for unsupported Gemini thinking models (applies to all providers) LLMFactory._check_unsupported_gemini_model(ai_model) diff --git a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py index f58c702e..e4b9846b 100644 --- a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py +++ b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py @@ -399,6 +399,44 @@ def _format_previous_issues_for_batch(self, issues: List[Any]) -> str: lines.append("=== END PREVIOUS ISSUES ===") return "\n".join(lines) + def _extract_diff_snippets(self, diff_content: str) -> List[str]: + """ + Extract meaningful code snippets from diff content for RAG semantic search. + Focuses on added/modified lines that represent significant code changes. + """ + if not diff_content: + return [] + + snippets = [] + current_snippet_lines = [] + + for line in diff_content.split("\n"): + # Focus on added lines (new code) + if line.startswith("+") and not line.startswith("+++"): + clean_line = line[1:].strip() + # Skip trivial lines + if (clean_line and + len(clean_line) > 10 and # Minimum meaningful length + not clean_line.startswith("//") and # Skip comments + not clean_line.startswith("#") and + not clean_line.startswith("*") and + not clean_line == "{" and + not clean_line == "}" and + not clean_line == ""): + current_snippet_lines.append(clean_line) + + # Batch into snippets of 3-5 lines + if len(current_snippet_lines) >= 3: + snippets.append(" ".join(current_snippet_lines)) + current_snippet_lines = [] + + # Add remaining lines as final snippet + if current_snippet_lines: + snippets.append(" ".join(current_snippet_lines)) + + # Limit to most significant snippets + return snippets[:10] + def _get_diff_snippets_for_batch( self, all_diff_snippets: List[str], @@ -536,6 +574,48 @@ async def _execute_stage_1_file_reviews( logger.info(f"Stage 1 Complete: {len(all_issues)} issues found across {total_files} files") return all_issues + async def _fetch_batch_rag_context( + self, + request: ReviewRequestDto, + batch_file_paths: List[str], + batch_diff_snippets: List[str] + ) -> Optional[Dict[str, Any]]: + """ + Fetch RAG context specifically for this batch of files. + Uses batch file paths and diff snippets for targeted semantic search. 
+ """ + if not self.rag_client: + return None + + try: + # Determine branch for RAG query + rag_branch = request.targetBranchName or request.commitHash or "main" + + logger.info(f"Fetching per-batch RAG context for {len(batch_file_paths)} files") + + rag_response = await self.rag_client.get_pr_context( + workspace=request.projectWorkspace, + project=request.projectNamespace, + branch=rag_branch, + changed_files=batch_file_paths, + diff_snippets=batch_diff_snippets, + pr_title=request.prTitle, + pr_description=request.prDescription, + top_k=10 # Fewer chunks per batch for focused context + ) + + if rag_response and rag_response.get("context"): + context = rag_response.get("context") + chunk_count = len(context.get("relevant_code", [])) + logger.info(f"Per-batch RAG: retrieved {chunk_count} chunks for files {batch_file_paths}") + return context + + return None + + except Exception as e: + logger.warning(f"Failed to fetch per-batch RAG context: {e}") + return None + async def _review_file_batch( self, request: ReviewRequestDto, @@ -550,6 +630,7 @@ async def _review_file_batch( """ batch_files_data = [] batch_file_paths = [] + batch_diff_snippets = [] project_rules = "1. No hardcoded secrets.\n2. Use dependency injection.\n3. Verify all inputs." # For incremental mode, use deltaDiff instead of full diff @@ -560,7 +641,7 @@ async def _review_file_batch( else: diff_source = processed_diff - # Collect file paths and diffs for this batch + # Collect file paths, diffs, and extract snippets for this batch for item in batch_items: file_info = item["file"] batch_file_paths.append(file_info.path) @@ -571,6 +652,9 @@ async def _review_file_batch( for f in diff_source.files: if f.path == file_info.path or f.path.endswith("/" + file_info.path): file_diff = f.content + # Extract code snippets from diff for RAG semantic search + if file_diff: + batch_diff_snippets.extend(self._extract_diff_snippets(file_diff)) break batch_files_data.append({ @@ -582,18 +666,32 @@ async def _review_file_batch( "is_incremental": is_incremental # Pass mode to prompt builder }) - # Use initial RAG context (already fetched with all files/snippets) - # The initial query is more comprehensive - it uses ALL changed files and snippets - # Per-batch filtering is done in _format_rag_context via relevant_files param + # Fetch per-batch RAG context using batch-specific files and diff snippets rag_context_text = "" - if fallback_rag_context: - logger.info(f"Using initial RAG context for batch: {batch_file_paths}") + batch_rag_context = None + + if self.rag_client: + batch_rag_context = await self._fetch_batch_rag_context( + request, batch_file_paths, batch_diff_snippets + ) + + # Use batch-specific RAG context if available, otherwise fall back to initial context + if batch_rag_context: + logger.info(f"Using per-batch RAG context for: {batch_file_paths}") + rag_context_text = self._format_rag_context( + batch_rag_context, + set(batch_file_paths), + pr_changed_files=request.changedFiles + ) + elif fallback_rag_context: + logger.info(f"Using fallback RAG context for batch: {batch_file_paths}") rag_context_text = self._format_rag_context( fallback_rag_context, set(batch_file_paths), pr_changed_files=request.changedFiles ) - logger.info(f"RAG context for batch: {len(rag_context_text)} chars") + + logger.info(f"RAG context for batch: {len(rag_context_text)} chars") # For incremental mode, filter previous issues relevant to this batch # Also pass previous issues in FULL mode if they exist (subsequent PR iterations) @@ -882,13 +980,15 @@ def 
_format_rag_context( ) -> str: """ Format RAG context into a readable string for the prompt. - Includes rich AST metadata (imports, extends, implements) for better LLM context. + + IMPORTANT: We trust RAG's semantic similarity scores for relevance. + The RAG system already uses embeddings to find semantically related code. + We only filter out chunks from files being modified in the PR (stale data from main branch). Args: rag_context: RAG response with code chunks - relevant_files: Files in current batch to prioritize - pr_changed_files: ALL files modified in the PR - chunks from these files - are marked as potentially stale (from main branch, not PR branch) + relevant_files: (UNUSED - kept for API compatibility) - we trust RAG scores instead + pr_changed_files: Files modified in the PR - chunks from these may be stale """ if not rag_context: logger.debug("RAG context is empty or None") @@ -900,35 +1000,32 @@ def _format_rag_context( logger.debug("No chunks found in RAG context (keys: %s)", list(rag_context.keys())) return "" - logger.debug(f"Processing {len(chunks)} RAG chunks for context") - logger.debug(f"PR changed files for filtering: {pr_changed_files[:5] if pr_changed_files else 'none'}...") + logger.info(f"Processing {len(chunks)} RAG chunks (trusting semantic similarity scores)") - # Normalize PR changed files for comparison + # Normalize PR changed files for stale-data detection only pr_changed_set = set() if pr_changed_files: for f in pr_changed_files: pr_changed_set.add(f) - # Also add just the filename for matching if "/" in f: pr_changed_set.add(f.rsplit("/", 1)[-1]) formatted_parts = [] included_count = 0 - skipped_modified = 0 - skipped_relevance = 0 + skipped_stale = 0 + for chunk in chunks: - if included_count >= 15: # Increased from 10 for more context - logger.debug(f"Reached chunk limit of 15, stopping") + if included_count >= 15: + logger.debug(f"Reached chunk limit of 15") break - # Extract metadata metadata = chunk.get("metadata", {}) path = metadata.get("path", chunk.get("path", "unknown")) chunk_type = metadata.get("content_type", metadata.get("type", "code")) score = chunk.get("score", chunk.get("relevance_score", 0)) - # Check if this chunk is from a file being modified in the PR - is_from_modified_file = False + # Only filter: chunks from PR-modified files with LOW scores (likely stale) + # High-score chunks from modified files may still be relevant (other parts of same file) if pr_changed_set: path_filename = path.rsplit("/", 1)[-1] if "/" in path else path is_from_modified_file = ( @@ -936,30 +1033,11 @@ def _format_rag_context( path_filename in pr_changed_set or any(path.endswith(f) or f.endswith(path) for f in pr_changed_set) ) - - # For chunks from modified files: - # - Skip if very low relevance (score < 0.70) - likely not useful - # - Include if moderate+ relevance (score >= 0.70) - context is valuable - if is_from_modified_file: - if score < 0.70: - logger.debug(f"Skipping RAG chunk from modified file (low score): {path} (score={score})") - skipped_modified += 1 - continue - else: - logger.debug(f"Including RAG chunk from modified file (relevant): {path} (score={score})") - - # Optionally filter by relevance to batch files - if relevant_files: - # Include if the chunk's path relates to any batch file - is_relevant = any( - path in f or f in path or - path.rsplit("/", 1)[-1] == f.rsplit("/", 1)[-1] - for f in relevant_files - ) - # Also include chunks with moderate+ score regardless - if not is_relevant and score < 0.70: - logger.debug(f"Skipping RAG 
chunk (not relevant to batch and low score): {path} (score={score})") - skipped_relevance += 1 + + # Skip ONLY low-score chunks from modified files (likely stale/outdated) + if is_from_modified_file and score < 0.70: + logger.debug(f"Skipping stale chunk from modified file: {path} (score={score:.2f})") + skipped_stale += 1 continue text = chunk.get("text", chunk.get("content", "")) @@ -968,38 +1046,27 @@ def _format_rag_context( included_count += 1 - # Build rich metadata context from AST-extracted info - meta_lines = [] - meta_lines.append(f"File: {path}") + # Build rich metadata context + meta_lines = [f"File: {path}"] - # Include namespace/package if available if metadata.get("namespace"): meta_lines.append(f"Namespace: {metadata['namespace']}") elif metadata.get("package"): meta_lines.append(f"Package: {metadata['package']}") - # Include class/function name if metadata.get("primary_name"): meta_lines.append(f"Definition: {metadata['primary_name']}") elif metadata.get("semantic_names"): meta_lines.append(f"Definitions: {', '.join(metadata['semantic_names'][:5])}") - # Include inheritance info (extends, implements) if metadata.get("extends"): extends = metadata["extends"] - if isinstance(extends, list): - meta_lines.append(f"Extends: {', '.join(extends)}") - else: - meta_lines.append(f"Extends: {extends}") + meta_lines.append(f"Extends: {', '.join(extends) if isinstance(extends, list) else extends}") if metadata.get("implements"): implements = metadata["implements"] - if isinstance(implements, list): - meta_lines.append(f"Implements: {', '.join(implements)}") - else: - meta_lines.append(f"Implements: {implements}") + meta_lines.append(f"Implements: {', '.join(implements) if isinstance(implements, list) else implements}") - # Include imports (abbreviated if too many) if metadata.get("imports"): imports = metadata["imports"] if isinstance(imports, list): @@ -1008,13 +1075,11 @@ def _format_rag_context( else: meta_lines.append(f"Imports: {'; '.join(imports[:5])}... (+{len(imports)-5} more)") - # Include parent context (for nested methods) if metadata.get("parent_context"): parent_ctx = metadata["parent_context"] if isinstance(parent_ctx, list): meta_lines.append(f"Parent: {'.'.join(parent_ctx)}") - # Include content type for understanding chunk nature if chunk_type and chunk_type != "code": meta_lines.append(f"Type: {chunk_type}") @@ -1026,11 +1091,10 @@ def _format_rag_context( ) if not formatted_parts: - logger.warning(f"No RAG chunks included in prompt (total: {len(chunks)}, skipped_modified: {skipped_modified}, skipped_relevance: {skipped_relevance}). 
" - f"PR changed files: {pr_changed_files[:5] if pr_changed_files else 'none'}...") + logger.warning(f"No RAG chunks included (total: {len(chunks)}, skipped_stale: {skipped_stale})") return "" - logger.info(f"Included {len(formatted_parts)} RAG chunks in prompt context (total: {len(chunks)}, skipped: {skipped_modified} low-score modified, {skipped_relevance} low relevance)") + logger.info(f"Included {len(formatted_parts)} RAG chunks (skipped {skipped_stale} stale from modified files)") return "\n".join(formatted_parts) def _emit_status(self, state: str, message: str): diff --git a/python-ecosystem/mcp-client/service/review_service.py b/python-ecosystem/mcp-client/service/review_service.py index ef9bb770..ba57a085 100644 --- a/python-ecosystem/mcp-client/service/review_service.py +++ b/python-ecosystem/mcp-client/service/review_service.py @@ -393,6 +393,9 @@ def _create_mcp_client(self, config: Dict[str, Any]) -> MCPClient: def _create_llm(self, request: ReviewRequestDto): """Create LLM instance from request parameters and initialize reranker.""" try: + # Log the model being used for this request + logger.info(f"Creating LLM for project {request.projectId}: provider={request.aiProvider}, model={request.aiModel}") + llm = LLMFactory.create_llm( request.aiModel, request.aiProvider, From d036fa972b1efdc551da88bb1179045f76faffe0 Mon Sep 17 00:00:00 2001 From: rostislav Date: Wed, 28 Jan 2026 02:44:12 +0200 Subject: [PATCH 18/20] feat: Implement AST-based code splitter and scoring configuration - Added AST-based code splitter using Tree-sitter for accurate code parsing. - Introduced TreeSitterParser for dynamic language loading and caching. - Created scoring configuration for RAG query result reranking with configurable boost factors and priority patterns. - Refactored RAGQueryService to utilize the new scoring configuration for enhanced result ranking. - Improved metadata extraction and handling for better context in scoring. 
--- .../rag-pipeline/src/rag_pipeline/__init__.py | 7 +- .../src/rag_pipeline/core/__init__.py | 11 +- .../src/rag_pipeline/core/ast_splitter.py | 1401 ----------------- .../src/rag_pipeline/core/chunking.py | 171 -- .../src/rag_pipeline/core/index_manager.py | 31 +- .../core/index_manager/__init__.py | 10 + .../core/index_manager/branch_manager.py | 172 ++ .../core/index_manager/collection_manager.py | 164 ++ .../core/index_manager/indexer.py | 398 +++++ .../core/index_manager/manager.py | 290 ++++ .../core/index_manager/point_operations.py | 151 ++ .../core/index_manager/stats_manager.py | 156 ++ .../rag_pipeline/core/semantic_splitter.py | 455 ------ .../rag_pipeline/core/splitter/__init__.py | 53 + .../rag_pipeline/core/splitter/languages.py | 139 ++ .../rag_pipeline/core/splitter/metadata.py | 339 ++++ .../core/splitter/queries/c_sharp.scm | 56 + .../rag_pipeline/core/splitter/queries/go.scm | 26 + .../core/splitter/queries/java.scm | 45 + .../core/splitter/queries/javascript.scm | 42 + .../core/splitter/queries/php.scm | 40 + .../core/splitter/queries/python.scm | 28 + .../core/splitter/queries/rust.scm | 46 + .../core/splitter/queries/typescript.scm | 52 + .../core/splitter/query_runner.py | 360 +++++ .../rag_pipeline/core/splitter/splitter.py | 720 +++++++++ .../rag_pipeline/core/splitter/tree_parser.py | 129 ++ .../src/rag_pipeline/models/scoring_config.py | 232 +++ .../rag_pipeline/services/query_service.py | 75 +- 29 files changed, 3682 insertions(+), 2117 deletions(-) delete mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/ast_splitter.py delete mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/chunking.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/__init__.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/branch_manager.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/collection_manager.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/manager.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/stats_manager.py delete mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/semantic_splitter.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/__init__.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/languages.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/metadata.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/c_sharp.scm create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/go.scm create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/java.scm create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/javascript.scm create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/php.scm create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/python.scm create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/rust.scm create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/typescript.scm create mode 
100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/query_runner.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/tree_parser.py create mode 100644 python-ecosystem/rag-pipeline/src/rag_pipeline/models/scoring_config.py diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/__init__.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/__init__.py index 3dc56254..e0fc412d 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/__init__.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/__init__.py @@ -2,7 +2,7 @@ CodeCrow RAG Pipeline A RAG (Retrieval-Augmented Generation) pipeline for code repositories. -Provides indexing and querying capabilities for code using LlamaIndex and MongoDB. +Provides indexing and querying capabilities for code using LlamaIndex, Tree-sitter and Qdrant. """ __version__ = "1.0.0" @@ -11,7 +11,7 @@ from .core.index_manager import RAGIndexManager from .services.query_service import RAGQueryService from .core.loader import DocumentLoader -from .core.chunking import CodeAwareSplitter, FunctionAwareSplitter +from .core.splitter import ASTCodeSplitter from .utils.utils import make_namespace, detect_language_from_path __all__ = [ @@ -21,8 +21,7 @@ "RAGIndexManager", "RAGQueryService", "DocumentLoader", - "CodeAwareSplitter", - "FunctionAwareSplitter", + "ASTCodeSplitter", "make_namespace", "detect_language_from_path", ] diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/__init__.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/__init__.py index dcb908ff..43cff58b 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/__init__.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/__init__.py @@ -1,17 +1,10 @@ """Core functionality for indexing and document processing""" __all__ = [ "DocumentLoader", - "CodeAwareSplitter", - "FunctionAwareSplitter", - "SemanticCodeSplitter", "ASTCodeSplitter", "RAGIndexManager" ] from .index_manager import RAGIndexManager -from .chunking import CodeAwareSplitter, FunctionAwareSplitter -from .semantic_splitter import SemanticCodeSplitter -from .ast_splitter import ASTCodeSplitter -from .loader import DocumentLoader - - +from .splitter import ASTCodeSplitter +from .loader import DocumentLoader \ No newline at end of file diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/ast_splitter.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/ast_splitter.py deleted file mode 100644 index 691a3b0d..00000000 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/ast_splitter.py +++ /dev/null @@ -1,1401 +0,0 @@ -""" -AST-based Code Splitter using Tree-sitter for accurate code parsing. - -This module provides true AST-aware code chunking that: -1. Uses Tree-sitter for accurate AST parsing (15+ languages) -2. Splits code into semantic units (classes, functions, methods) -3. Uses RecursiveCharacterTextSplitter for oversized chunks (large methods) -4. Enriches metadata for better RAG retrieval -5. Maintains parent context ("breadcrumbs") for nested structures -6. Uses deterministic IDs for Qdrant deduplication - -Key benefits over regex-based splitting: -- Accurate function/class boundary detection -- Language-aware parsing for 15+ languages -- Better metadata: content_type, language, semantic_names, parent_class -- Handles edge cases (nested functions, decorators, etc.) 
-- Deterministic chunk IDs prevent duplicates on re-indexing -""" - -import re -import hashlib -import logging -from typing import List, Dict, Any, Optional, Set -from pathlib import Path -from dataclasses import dataclass, field -from enum import Enum - -from langchain_text_splitters import RecursiveCharacterTextSplitter, Language -from llama_index.core.schema import Document as LlamaDocument, TextNode - -logger = logging.getLogger(__name__) - - -class ContentType(Enum): - """Content type as determined by AST parsing""" - FUNCTIONS_CLASSES = "functions_classes" # Full function/class definition - SIMPLIFIED_CODE = "simplified_code" # Remaining code with placeholders - FALLBACK = "fallback" # Non-AST parsed content - OVERSIZED_SPLIT = "oversized_split" # Large chunk split by RecursiveCharacterTextSplitter - - -# Map file extensions to LangChain Language enum -EXTENSION_TO_LANGUAGE: Dict[str, Language] = { - # Python - '.py': Language.PYTHON, - '.pyw': Language.PYTHON, - '.pyi': Language.PYTHON, - - # Java/JVM - '.java': Language.JAVA, - '.kt': Language.KOTLIN, - '.kts': Language.KOTLIN, - '.scala': Language.SCALA, - - # JavaScript/TypeScript - '.js': Language.JS, - '.jsx': Language.JS, - '.mjs': Language.JS, - '.cjs': Language.JS, - '.ts': Language.TS, - '.tsx': Language.TS, - - # Systems languages - '.go': Language.GO, - '.rs': Language.RUST, - '.c': Language.C, - '.h': Language.C, - '.cpp': Language.CPP, - '.cc': Language.CPP, - '.cxx': Language.CPP, - '.hpp': Language.CPP, - '.hxx': Language.CPP, - '.cs': Language.CSHARP, - - # Web/Scripting - '.php': Language.PHP, - '.phtml': Language.PHP, # PHP template files (Magento, Zend, etc.) - '.php3': Language.PHP, - '.php4': Language.PHP, - '.php5': Language.PHP, - '.phps': Language.PHP, - '.inc': Language.PHP, # PHP include files - '.rb': Language.RUBY, - '.erb': Language.RUBY, # Ruby template files - '.lua': Language.LUA, - '.pl': Language.PERL, - '.pm': Language.PERL, - '.swift': Language.SWIFT, - - # Markup/Config - '.md': Language.MARKDOWN, - '.markdown': Language.MARKDOWN, - '.html': Language.HTML, - '.htm': Language.HTML, - '.rst': Language.RST, - '.tex': Language.LATEX, - '.proto': Language.PROTO, - '.sol': Language.SOL, - '.hs': Language.HASKELL, - '.cob': Language.COBOL, - '.cbl': Language.COBOL, - '.xml': Language.HTML, # Use HTML splitter for XML -} - -# Languages that support full AST parsing via tree-sitter -AST_SUPPORTED_LANGUAGES = { - Language.PYTHON, Language.JAVA, Language.KOTLIN, Language.JS, Language.TS, - Language.GO, Language.RUST, Language.C, Language.CPP, Language.CSHARP, - Language.PHP, Language.RUBY, Language.SCALA, Language.LUA, Language.PERL, - Language.SWIFT, Language.HASKELL, Language.COBOL -} - -# Tree-sitter language name mapping (tree-sitter-languages uses these names) -LANGUAGE_TO_TREESITTER: Dict[Language, str] = { - Language.PYTHON: 'python', - Language.JAVA: 'java', - Language.KOTLIN: 'kotlin', - Language.JS: 'javascript', - Language.TS: 'typescript', - Language.GO: 'go', - Language.RUST: 'rust', - Language.C: 'c', - Language.CPP: 'cpp', - Language.CSHARP: 'c_sharp', - Language.PHP: 'php', - Language.RUBY: 'ruby', - Language.SCALA: 'scala', - Language.LUA: 'lua', - Language.PERL: 'perl', - Language.SWIFT: 'swift', - Language.HASKELL: 'haskell', -} - -# Node types that represent semantic CHUNKING units (classes, functions) -# NOTE: imports, namespace, inheritance are now extracted DYNAMICALLY from AST -# by pattern matching on node type names - no hardcoded mappings needed! 
-SEMANTIC_NODE_TYPES: Dict[str, Dict[str, List[str]]] = { - 'python': { - 'class': ['class_definition'], - 'function': ['function_definition', 'async_function_definition'], - }, - 'java': { - 'class': ['class_declaration', 'interface_declaration', 'enum_declaration'], - 'function': ['method_declaration', 'constructor_declaration'], - }, - 'javascript': { - 'class': ['class_declaration'], - 'function': ['function_declaration', 'method_definition', 'arrow_function', 'generator_function_declaration'], - }, - 'typescript': { - 'class': ['class_declaration', 'interface_declaration'], - 'function': ['function_declaration', 'method_definition', 'arrow_function'], - }, - 'go': { - 'class': ['type_declaration'], # structs, interfaces - 'function': ['function_declaration', 'method_declaration'], - }, - 'rust': { - 'class': ['struct_item', 'impl_item', 'trait_item', 'enum_item'], - 'function': ['function_item'], - }, - 'c_sharp': { - 'class': ['class_declaration', 'interface_declaration', 'struct_declaration'], - 'function': ['method_declaration', 'constructor_declaration'], - }, - 'kotlin': { - 'class': ['class_declaration', 'object_declaration', 'interface_declaration'], - 'function': ['function_declaration'], - }, - 'php': { - 'class': ['class_declaration', 'interface_declaration', 'trait_declaration'], - 'function': ['function_definition', 'method_declaration'], - }, - 'ruby': { - 'class': ['class', 'module'], - 'function': ['method', 'singleton_method'], - }, - 'cpp': { - 'class': ['class_specifier', 'struct_specifier'], - 'function': ['function_definition'], - }, - 'c': { - 'class': ['struct_specifier'], - 'function': ['function_definition'], - }, - 'scala': { - 'class': ['class_definition', 'object_definition', 'trait_definition'], - 'function': ['function_definition'], - }, -} - -# Metadata extraction patterns (fallback when AST doesn't provide names) -METADATA_PATTERNS = { - 'python': { - 'class': re.compile(r'^class\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'^(?:async\s+)?def\s+(\w+)\s*\(', re.MULTILINE), - }, - 'java': { - 'class': re.compile(r'(?:public\s+|private\s+|protected\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)', re.MULTILINE), - 'interface': re.compile(r'(?:public\s+)?interface\s+(\w+)', re.MULTILINE), - 'method': re.compile(r'(?:public|private|protected)\s+(?:static\s+)?[\w<>,\s]+\s+(\w+)\s*\(', re.MULTILINE), - }, - 'javascript': { - 'class': re.compile(r'(?:export\s+)?(?:default\s+)?class\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'(?:export\s+)?(?:async\s+)?function\s*\*?\s*(\w+)\s*\(', re.MULTILINE), - }, - 'typescript': { - 'class': re.compile(r'(?:export\s+)?(?:default\s+)?class\s+(\w+)', re.MULTILINE), - 'interface': re.compile(r'(?:export\s+)?interface\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'(?:export\s+)?(?:async\s+)?function\s*\*?\s*(\w+)\s*\(', re.MULTILINE), - }, - 'go': { - 'function': re.compile(r'^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(', re.MULTILINE), - 'struct': re.compile(r'^type\s+(\w+)\s+struct\s*\{', re.MULTILINE), - }, - 'rust': { - 'function': re.compile(r'^(?:pub\s+)?(?:async\s+)?fn\s+(\w+)', re.MULTILINE), - 'struct': re.compile(r'^(?:pub\s+)?struct\s+(\w+)', re.MULTILINE), - }, - 'c_sharp': { - 'class': re.compile(r'(?:public\s+|private\s+|internal\s+)?(?:abstract\s+|sealed\s+)?class\s+(\w+)', re.MULTILINE), - 'method': re.compile(r'(?:public|private|protected|internal)\s+(?:static\s+)?[\w<>,\s]+\s+(\w+)\s*\(', re.MULTILINE), - }, - 'kotlin': { - 'class': re.compile(r'(?:data\s+|sealed\s+|open\s+)?class\s+(\w+)', re.MULTILINE), 
- 'function': re.compile(r'(?:fun|suspend\s+fun)\s+(\w+)\s*\(', re.MULTILINE), - }, - 'php': { - 'class': re.compile(r'(?:abstract\s+|final\s+)?class\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'(?:public|private|protected|static|\s)*function\s+(\w+)\s*\(', re.MULTILINE), - }, -} - -# Patterns for extracting class inheritance, interfaces, and imports -CLASS_INHERITANCE_PATTERNS = { - 'php': { - 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w\\]+)', re.MULTILINE), - 'implements': re.compile(r'class\s+\w+(?:\s+extends\s+[\w\\]+)?\s+implements\s+([\w\\,\s]+)', re.MULTILINE), - 'use': re.compile(r'^use\s+([\w\\]+)(?:\s+as\s+\w+)?;', re.MULTILINE), - 'namespace': re.compile(r'^namespace\s+([\w\\]+);', re.MULTILINE), - 'type_hint': re.compile(r'@var\s+(\\?[\w\\|]+)', re.MULTILINE), - # PHTML template type hints: /** @var \Namespace\Class $variable */ - 'template_type': re.compile(r'/\*\*\s*@var\s+([\w\\]+)\s+\$\w+\s*\*/', re.MULTILINE), - # Variable type hints in PHPDoc: @param Type $name, @return Type - 'phpdoc_types': re.compile(r'@(?:param|return|throws)\s+([\w\\|]+)', re.MULTILINE), - }, - 'java': { - 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w.]+)', re.MULTILINE), - 'implements': re.compile(r'class\s+\w+(?:\s+extends\s+[\w.]+)?\s+implements\s+([\w.,\s]+)', re.MULTILINE), - 'import': re.compile(r'^import\s+([\w.]+(?:\.\*)?);', re.MULTILINE), - 'package': re.compile(r'^package\s+([\w.]+);', re.MULTILINE), - }, - 'kotlin': { - 'extends': re.compile(r'class\s+\w+\s*:\s*([\w.]+)(?:\([^)]*\))?', re.MULTILINE), - 'import': re.compile(r'^import\s+([\w.]+(?:\.\*)?)', re.MULTILINE), - 'package': re.compile(r'^package\s+([\w.]+)', re.MULTILINE), - }, - 'python': { - 'extends': re.compile(r'class\s+\w+\s*\(\s*([\w.,\s]+)\s*\)\s*:', re.MULTILINE), - 'import': re.compile(r'^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s*]+)', re.MULTILINE), - }, - 'typescript': { - 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w.]+)', re.MULTILINE), - 'implements': re.compile(r'class\s+\w+(?:\s+extends\s+[\w.]+)?\s+implements\s+([\w.,\s]+)', re.MULTILINE), - 'import': re.compile(r'^import\s+(?:[\w{},\s*]+\s+from\s+)?["\']([^"\']+)["\'];?', re.MULTILINE), - }, - 'javascript': { - 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w.]+)', re.MULTILINE), - 'import': re.compile(r'^import\s+(?:[\w{},\s*]+\s+from\s+)?["\']([^"\']+)["\'];?', re.MULTILINE), - 'require': re.compile(r'require\s*\(\s*["\']([^"\']+)["\']\s*\)', re.MULTILINE), - }, - 'c_sharp': { - 'extends': re.compile(r'class\s+\w+\s*:\s*([\w.]+)', re.MULTILINE), - 'implements': re.compile(r'class\s+\w+\s*:\s*(?:[\w.]+\s*,\s*)*([\w.,\s]+)', re.MULTILINE), - 'using': re.compile(r'^using\s+([\w.]+);', re.MULTILINE), - 'namespace': re.compile(r'^namespace\s+([\w.]+)', re.MULTILINE), - }, - 'go': { - 'import': re.compile(r'^import\s+(?:\(\s*)?"([^"]+)"', re.MULTILINE), - 'package': re.compile(r'^package\s+(\w+)', re.MULTILINE), - }, - 'rust': { - 'use': re.compile(r'^use\s+([\w:]+(?:::\{[^}]+\})?);', re.MULTILINE), - 'impl_for': re.compile(r'impl\s+(?:<[^>]+>\s+)?([\w:]+)\s+for\s+([\w:]+)', re.MULTILINE), - }, - 'scala': { - 'extends': re.compile(r'(?:class|object|trait)\s+\w+\s+extends\s+([\w.]+)', re.MULTILINE), - 'with': re.compile(r'with\s+([\w.]+)', re.MULTILINE), - 'import': re.compile(r'^import\s+([\w._{}]+)', re.MULTILINE), - 'package': re.compile(r'^package\s+([\w.]+)', re.MULTILINE), - }, -} - - -@dataclass -class ASTChunk: - """Represents a chunk of code from AST parsing""" - content: str - content_type: ContentType - language: str - 
path: str - semantic_names: List[str] = field(default_factory=list) - parent_context: List[str] = field(default_factory=list) # Breadcrumb: ["MyClass", "inner_method"] - docstring: Optional[str] = None - signature: Optional[str] = None - start_line: int = 0 - end_line: int = 0 - node_type: Optional[str] = None - class_metadata: Dict[str, Any] = field(default_factory=dict) # extends, implements from AST - file_metadata: Dict[str, Any] = field(default_factory=dict) # imports, namespace from AST - - -def generate_deterministic_id(path: str, content: str, chunk_index: int = 0) -> str: - """ - Generate a deterministic ID for a chunk based on file path and content. - - This ensures the same code chunk always gets the same ID, preventing - duplicates in Qdrant during re-indexing. - - Args: - path: File path - content: Chunk content - chunk_index: Index of chunk within file (for disambiguation) - - Returns: - Deterministic hex ID string - """ - # Use path + content hash + index for uniqueness - hash_input = f"{path}:{chunk_index}:{content[:500]}" # First 500 chars for efficiency - return hashlib.sha256(hash_input.encode('utf-8')).hexdigest()[:32] - - -def compute_file_hash(content: str) -> str: - """Compute hash of file content for change detection""" - return hashlib.sha256(content.encode('utf-8')).hexdigest() - - -class ASTCodeSplitter: - """ - AST-based code splitter using Tree-sitter for accurate parsing. - - Features: - - True AST parsing via tree-sitter for accurate code structure detection - - Splits code into semantic units (classes, functions, methods) - - Maintains parent context (breadcrumbs) for nested structures - - Falls back to RecursiveCharacterTextSplitter for oversized chunks - - Uses deterministic IDs for Qdrant deduplication - - Enriches metadata for improved RAG retrieval - - Usage: - splitter = ASTCodeSplitter(max_chunk_size=2000) - nodes = splitter.split_documents(documents) - """ - - DEFAULT_MAX_CHUNK_SIZE = 2000 - DEFAULT_MIN_CHUNK_SIZE = 100 - DEFAULT_CHUNK_OVERLAP = 200 - DEFAULT_PARSER_THRESHOLD = 10 # Minimum lines for AST parsing - - def __init__( - self, - max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE, - min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE, - chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, - parser_threshold: int = DEFAULT_PARSER_THRESHOLD - ): - """ - Initialize AST code splitter. - - Args: - max_chunk_size: Maximum characters per chunk. Larger chunks are split. - min_chunk_size: Minimum characters for a valid chunk. - chunk_overlap: Overlap between chunks when splitting oversized content. - parser_threshold: Minimum lines for AST parsing (smaller files use fallback). - """ - self.max_chunk_size = max_chunk_size - self.min_chunk_size = min_chunk_size - self.chunk_overlap = chunk_overlap - self.parser_threshold = parser_threshold - - # Cache text splitters for oversized chunks - self._splitter_cache: Dict[Language, RecursiveCharacterTextSplitter] = {} - - # Default text splitter for unknown languages - self._default_splitter = RecursiveCharacterTextSplitter( - chunk_size=max_chunk_size, - chunk_overlap=chunk_overlap, - length_function=len, - ) - - # Track if tree-sitter is available - self._tree_sitter_available: Optional[bool] = None - # Cache for language modules and parsers - self._language_cache: Dict[str, Any] = {} - - def _get_tree_sitter_language(self, lang_name: str): - """ - Get tree-sitter Language object for a language name. - Uses the new tree-sitter API with individual language packages. 
- - Note: Different packages have different APIs: - - Most use: module.language() - - PHP uses: module.language_php() - - TypeScript uses: module.language_typescript() - """ - if lang_name in self._language_cache: - return self._language_cache[lang_name] - - try: - from tree_sitter import Language - - # Map language names to their package modules and function names - # Format: (module_name, function_name or None for 'language') - lang_modules = { - 'python': ('tree_sitter_python', 'language'), - 'java': ('tree_sitter_java', 'language'), - 'javascript': ('tree_sitter_javascript', 'language'), - 'typescript': ('tree_sitter_typescript', 'language_typescript'), - 'go': ('tree_sitter_go', 'language'), - 'rust': ('tree_sitter_rust', 'language'), - 'c': ('tree_sitter_c', 'language'), - 'cpp': ('tree_sitter_cpp', 'language'), - 'c_sharp': ('tree_sitter_c_sharp', 'language'), - 'ruby': ('tree_sitter_ruby', 'language'), - 'php': ('tree_sitter_php', 'language_php'), - } - - lang_info = lang_modules.get(lang_name) - if not lang_info: - return None - - module_name, func_name = lang_info - - # Dynamic import of language module - import importlib - lang_module = importlib.import_module(module_name) - - # Get the language function - lang_func = getattr(lang_module, func_name, None) - if not lang_func: - logger.debug(f"Module {module_name} has no {func_name} function") - return None - - # Create Language object using the new API - language = Language(lang_func()) - self._language_cache[lang_name] = language - return language - - except Exception as e: - logger.debug(f"Could not load tree-sitter language '{lang_name}': {e}") - return None - - def _check_tree_sitter(self) -> bool: - """Check if tree-sitter is available""" - if self._tree_sitter_available is None: - try: - from tree_sitter import Parser, Language - import tree_sitter_python as tspython - - # Test with the new API - py_language = Language(tspython.language()) - parser = Parser(py_language) - parser.parse(b"def test(): pass") - - self._tree_sitter_available = True - logger.info("tree-sitter is available and working") - except ImportError as e: - logger.warning(f"tree-sitter not installed: {e}") - self._tree_sitter_available = False - except Exception as e: - logger.warning(f"tree-sitter error: {type(e).__name__}: {e}") - self._tree_sitter_available = False - return self._tree_sitter_available - - def _get_language_from_path(self, path: str) -> Optional[Language]: - """Determine Language enum from file path""" - ext = Path(path).suffix.lower() - return EXTENSION_TO_LANGUAGE.get(ext) - - def _get_treesitter_language(self, language: Language) -> Optional[str]: - """Get tree-sitter language name from Language enum""" - return LANGUAGE_TO_TREESITTER.get(language) - - def _get_text_splitter(self, language: Language) -> RecursiveCharacterTextSplitter: - """Get language-specific text splitter for oversized chunks""" - if language not in self._splitter_cache: - try: - self._splitter_cache[language] = RecursiveCharacterTextSplitter.from_language( - language=language, - chunk_size=self.max_chunk_size, - chunk_overlap=self.chunk_overlap, - ) - except Exception: - # Fallback if language not supported - self._splitter_cache[language] = self._default_splitter - return self._splitter_cache[language] - - def _parse_inheritance_clause(self, clause_text: str, language: str) -> List[str]: - """ - Parse inheritance clause from AST node text to extract class/interface names. 
- - Handles various formats: - - PHP: "extends ParentClass" or "implements Interface1, Interface2" - - Java: "extends Parent" or "implements I1, I2" - - Python: "(Parent1, Parent2)" - - TypeScript/JS: "extends Parent implements Interface" - """ - if not clause_text: - return [] - - # Clean up the clause - text = clause_text.strip() - - # Remove common keywords - for keyword in ['extends', 'implements', 'with', ':']: - text = text.replace(keyword, ' ') - - # Handle parentheses (Python style) - text = text.strip('()') - - # Split by comma and clean up - names = [] - for part in text.split(','): - name = part.strip() - # Remove any generic type parameters - if '<' in name: - name = name.split('<')[0].strip() - # Remove any constructor calls () - if '(' in name: - name = name.split('(')[0].strip() - if name and name not in ('', ' '): - names.append(name) - - return names - - def _parse_namespace(self, ns_text: str, language: str) -> Optional[str]: - """ - Extract namespace/package name from AST node text. - - Handles various formats: - - PHP: "namespace Vendor\\Package\\Module;" - - Java/Kotlin: "package com.example.app;" - - C#: "namespace MyNamespace { ... }" - """ - if not ns_text: - return None - - # Clean up - text = ns_text.strip() - - # Remove keywords and semicolons - for keyword in ['namespace', 'package']: - text = text.replace(keyword, ' ') - - text = text.strip().rstrip(';').rstrip('{').strip() - - return text if text else None - - def _parse_with_ast( - self, - text: str, - language: Language, - path: str - ) -> List[ASTChunk]: - """ - Parse code using AST via tree-sitter. - - Returns list of ASTChunk objects with content and metadata. - """ - if not self._check_tree_sitter(): - return [] - - ts_lang = self._get_treesitter_language(language) - if not ts_lang: - logger.debug(f"No tree-sitter mapping for {language}, using fallback") - return [] - - try: - from tree_sitter import Parser - - # Get Language object for this language - lang_obj = self._get_tree_sitter_language(ts_lang) - if not lang_obj: - logger.debug(f"tree-sitter language '{ts_lang}' not available") - return [] - - # Create parser with the language - parser = Parser(lang_obj) - tree = parser.parse(bytes(text, "utf8")) - - # Extract chunks with breadcrumb context - chunks = self._extract_ast_chunks_with_context( - tree.root_node, - text, - ts_lang, - path - ) - - return chunks - - except Exception as e: - logger.warning(f"AST parsing failed for {path}: {e}") - return [] - - def _extract_ast_chunks_with_context( - self, - root_node, - source_code: str, - language: str, - path: str - ) -> List[ASTChunk]: - """ - Extract function/class chunks from AST tree with parent context (breadcrumbs). - - This solves the "context loss" problem by tracking parent classes/modules - so that a method knows it belongs to a specific class. - - Also extracts file-level metadata dynamically from AST - no hardcoded mappings needed. 
- """ - chunks = [] - processed_ranges: Set[tuple] = set() # Track (start, end) to avoid duplicates - - # IMPORTANT: Tree-sitter returns byte positions, not character positions - # We need to slice bytes and decode, not slice the string directly - source_bytes = source_code.encode('utf-8') - - # File-level metadata collected dynamically from AST - file_metadata: Dict[str, Any] = { - 'imports': [], - 'types_referenced': [], - } - - # Get node types for this language (only for chunking - class/function boundaries) - lang_node_types = SEMANTIC_NODE_TYPES.get(language, {}) - class_types = set(lang_node_types.get('class', [])) - function_types = set(lang_node_types.get('function', [])) - all_semantic_types = class_types | function_types - - def get_node_text(node) -> str: - """Get full text content of a node using byte positions""" - return source_bytes[node.start_byte:node.end_byte].decode('utf-8', errors='replace') - - def extract_identifiers(node) -> List[str]: - """Recursively extract all identifier names from a node""" - identifiers = [] - if node.type in ('identifier', 'name', 'type_identifier', 'qualified_name', 'scoped_identifier'): - identifiers.append(get_node_text(node)) - for child in node.children: - identifiers.extend(extract_identifiers(child)) - return identifiers - - def extract_ast_metadata(node, metadata: Dict[str, Any]) -> None: - """ - Dynamically extract metadata from AST node based on node type patterns. - No hardcoded language mappings - uses common naming conventions in tree-sitter. - """ - node_type = node.type - node_text = get_node_text(node).strip() - - # === IMPORTS (pattern: *import*, *use*, *require*, *include*) === - if any(kw in node_type for kw in ('import', 'use_', 'require', 'include', 'using')): - if node_text and len(node_text) < 500: # Skip huge nodes - metadata['imports'].append(node_text) - return # Don't recurse into import children - - # === NAMESPACE/PACKAGE (pattern: *namespace*, *package*, *module*) === - if any(kw in node_type for kw in ('namespace', 'package', 'module_declaration')): - # Extract the name part - names = extract_identifiers(node) - if names: - metadata['namespace'] = names[0] if len(names) == 1 else '.'.join(names) - elif node_text: - # Fallback: parse from text - metadata['namespace'] = self._parse_namespace(node_text, language) - return - - # Recurse into children - for child in node.children: - extract_ast_metadata(child, metadata) - - def extract_class_metadata_from_ast(node) -> Dict[str, Any]: - """ - Dynamically extract class metadata (extends, implements) from AST. - Uses common tree-sitter naming patterns - no manual mapping needed. 
- """ - meta: Dict[str, Any] = {} - - def find_inheritance(n, depth=0): - """Recursively find inheritance-related nodes""" - node_type = n.type - - # === EXTENDS / SUPERCLASS (pattern: *super*, *base*, *extends*, *heritage*) === - if any(kw in node_type for kw in ('super', 'base_clause', 'extends', 'heritage', 'parent')): - names = extract_identifiers(n) - if names: - meta.setdefault('extends', []).extend(names) - meta['parent_types'] = meta.get('extends', []) - return # Found it, don't go deeper - - # === IMPLEMENTS / INTERFACES (pattern: *implement*, *interface_clause*, *conform*) === - if any(kw in node_type for kw in ('implement', 'interface_clause', 'conform', 'protocol')): - names = extract_identifiers(n) - if names: - meta.setdefault('implements', []).extend(names) - return - - # === TRAIT/MIXIN (pattern: *trait*, *mixin*, *with*) === - if any(kw in node_type for kw in ('trait', 'mixin', 'with_clause')): - names = extract_identifiers(n) - if names: - meta.setdefault('traits', []).extend(names) - return - - # === TYPE PARAMETERS / GENERICS === - if any(kw in node_type for kw in ('type_parameter', 'generic', 'type_argument')): - names = extract_identifiers(n) - if names: - meta.setdefault('type_params', []).extend(names) - return - - # Recurse but limit depth to avoid going too deep - if depth < 5: - for child in n.children: - find_inheritance(child, depth + 1) - - for child in node.children: - find_inheritance(child) - - # Deduplicate - for key in meta: - if isinstance(meta[key], list): - meta[key] = list(dict.fromkeys(meta[key])) # Preserve order, remove dupes - - return meta - - def get_node_name(node) -> Optional[str]: - """Extract name from a node (class/function name)""" - for child in node.children: - if child.type in ('identifier', 'name', 'type_identifier', 'property_identifier'): - return get_node_text(child) - return None - - def traverse(node, parent_context: List[str], depth: int = 0): - """ - Recursively traverse AST and extract semantic chunks with breadcrumbs. 
- - Args: - node: Current AST node - parent_context: List of parent class/function names (breadcrumb) - depth: Current depth in tree - """ - node_range = (node.start_byte, node.end_byte) - - # Check if this is a semantic unit - if node.type in all_semantic_types: - # Skip if already processed (nested in another chunk) - if node_range in processed_ranges: - return - - # Use bytes for slicing since tree-sitter returns byte positions - content = source_bytes[node.start_byte:node.end_byte].decode('utf-8', errors='replace') - - # Calculate line numbers (use bytes for consistency) - start_line = source_bytes[:node.start_byte].count(b'\n') + 1 - end_line = start_line + content.count('\n') - - # Get the name of this node - node_name = get_node_name(node) - - # Determine content type - is_class = node.type in class_types - - # Extract class inheritance metadata dynamically from AST - class_metadata = {} - if is_class: - class_metadata = extract_class_metadata_from_ast(node) - - chunk = ASTChunk( - content=content, - content_type=ContentType.FUNCTIONS_CLASSES, - language=language, - path=path, - semantic_names=[node_name] if node_name else [], - parent_context=list(parent_context), # Copy the breadcrumb - start_line=start_line, - end_line=end_line, - node_type=node.type, - class_metadata=class_metadata, - ) - - chunks.append(chunk) - processed_ranges.add(node_range) - - # If this is a class, traverse children with updated context - if is_class and node_name: - new_context = parent_context + [node_name] - for child in node.children: - traverse(child, new_context, depth + 1) - else: - # Continue traversing children with current context - for child in node.children: - traverse(child, parent_context, depth + 1) - - # First pass: extract file-level metadata (imports, namespace) from entire AST - extract_ast_metadata(root_node, file_metadata) - - # Second pass: extract semantic chunks (classes, functions) - traverse(root_node, []) - - # Clean up file_metadata - remove empty values - clean_file_metadata = {k: v for k, v in file_metadata.items() if v} - - # Create simplified code (skeleton with placeholders) - simplified = self._create_simplified_code(source_code, chunks, language) - if simplified and simplified.strip() and len(simplified.strip()) > 50: - chunks.append(ASTChunk( - content=simplified, - content_type=ContentType.SIMPLIFIED_CODE, - language=language, - path=path, - semantic_names=[], - parent_context=[], - start_line=1, - end_line=source_code.count('\n') + 1, - node_type='simplified', - file_metadata=clean_file_metadata, # Include imports/namespace from AST - )) - - # Also attach file_metadata to all chunks for enriched metadata - for chunk in chunks: - if not chunk.file_metadata: - chunk.file_metadata = clean_file_metadata - - return chunks - - def _create_simplified_code( - self, - source_code: str, - chunks: List[ASTChunk], - language: str - ) -> str: - """ - Create simplified code with placeholders for extracted chunks. - - This gives RAG context about the overall file structure without - including full function/class bodies. 
- - Example output: - # Code for: class MyClass: - # Code for: def my_function(): - if __name__ == "__main__": - main() - """ - if not chunks: - return source_code - - # Get chunks that are functions_classes type (not simplified) - semantic_chunks = [c for c in chunks if c.content_type == ContentType.FUNCTIONS_CLASSES] - - if not semantic_chunks: - return source_code - - # Sort by start position (reverse) to replace from end - sorted_chunks = sorted( - semantic_chunks, - key=lambda x: source_code.find(x.content), - reverse=True - ) - - result = source_code - - # Comment style by language - comment_prefix = { - 'python': '#', - 'javascript': '//', - 'typescript': '//', - 'java': '//', - 'kotlin': '//', - 'go': '//', - 'rust': '//', - 'c': '//', - 'cpp': '//', - 'c_sharp': '//', - 'php': '//', - 'ruby': '#', - 'lua': '--', - 'perl': '#', - 'scala': '//', - }.get(language, '//') - - for chunk in sorted_chunks: - # Find the position of this chunk in the source - pos = result.find(chunk.content) - if pos == -1: - continue - - # Extract first line for placeholder - first_line = chunk.content.split('\n')[0].strip() - # Truncate if too long - if len(first_line) > 60: - first_line = first_line[:60] + '...' - - # Add breadcrumb context to placeholder - breadcrumb = "" - if chunk.parent_context: - breadcrumb = f" (in {'.'.join(chunk.parent_context)})" - - placeholder = f"{comment_prefix} Code for: {first_line}{breadcrumb}\n" - - result = result[:pos] + placeholder + result[pos + len(chunk.content):] - - return result.strip() - - def _extract_metadata( - self, - chunk: ASTChunk, - base_metadata: Dict[str, Any] - ) -> Dict[str, Any]: - """Extract and enrich metadata from an AST chunk""" - metadata = dict(base_metadata) - - # Core AST metadata - metadata['content_type'] = chunk.content_type.value - metadata['node_type'] = chunk.node_type - - # Breadcrumb context (critical for RAG) - if chunk.parent_context: - metadata['parent_context'] = chunk.parent_context - metadata['parent_class'] = chunk.parent_context[-1] if chunk.parent_context else None - metadata['full_path'] = '.'.join(chunk.parent_context + chunk.semantic_names[:1]) - - # Semantic names - if chunk.semantic_names: - metadata['semantic_names'] = chunk.semantic_names[:10] - metadata['primary_name'] = chunk.semantic_names[0] - - # Line numbers - metadata['start_line'] = chunk.start_line - metadata['end_line'] = chunk.end_line - - # === Use AST-extracted metadata (from tree-sitter) === - - # File-level metadata from AST (imports, namespace, package) - if chunk.file_metadata: - if chunk.file_metadata.get('imports'): - metadata['imports'] = chunk.file_metadata['imports'][:20] - if chunk.file_metadata.get('namespace'): - metadata['namespace'] = chunk.file_metadata['namespace'] - if chunk.file_metadata.get('package'): - metadata['package'] = chunk.file_metadata['package'] - - # Class-level metadata from AST (extends, implements) - if chunk.class_metadata: - if chunk.class_metadata.get('extends'): - metadata['extends'] = chunk.class_metadata['extends'] - metadata['parent_types'] = chunk.class_metadata['extends'] - if chunk.class_metadata.get('implements'): - metadata['implements'] = chunk.class_metadata['implements'] - - # Try to extract additional metadata via regex patterns - patterns = METADATA_PATTERNS.get(chunk.language, {}) - - # Extract docstring - docstring = self._extract_docstring(chunk.content, chunk.language) - if docstring: - metadata['docstring'] = docstring[:500] - - # Extract signature - signature = 
self._extract_signature(chunk.content, chunk.language) - if signature: - metadata['signature'] = signature - - # Extract additional names not caught by AST - if not chunk.semantic_names: - names = [] - for pattern_type, pattern in patterns.items(): - matches = pattern.findall(chunk.content) - names.extend(matches) - if names: - metadata['semantic_names'] = list(set(names))[:10] - metadata['primary_name'] = names[0] - - # Fallback: Extract inheritance via regex if AST didn't find it - if 'extends' not in metadata and 'implements' not in metadata: - self._extract_inheritance_metadata(chunk.content, chunk.language, metadata) - - return metadata - - def _extract_inheritance_metadata( - self, - content: str, - language: str, - metadata: Dict[str, Any] - ) -> None: - """Extract inheritance, interfaces, and imports from code chunk""" - inheritance_patterns = CLASS_INHERITANCE_PATTERNS.get(language, {}) - - if not inheritance_patterns: - return - - # Extract extends (parent class) - if 'extends' in inheritance_patterns: - match = inheritance_patterns['extends'].search(content) - if match: - extends = match.group(1).strip() - # Clean up and split multiple classes (for multiple inheritance) - extends_list = [e.strip() for e in extends.split(',') if e.strip()] - if extends_list: - metadata['extends'] = extends_list - metadata['parent_types'] = extends_list # Alias for searchability - - # Extract implements (interfaces) - if 'implements' in inheritance_patterns: - match = inheritance_patterns['implements'].search(content) - if match: - implements = match.group(1).strip() - implements_list = [i.strip() for i in implements.split(',') if i.strip()] - if implements_list: - metadata['implements'] = implements_list - - # Extract imports/use statements - import_key = None - for key in ['import', 'use', 'using', 'require']: - if key in inheritance_patterns: - import_key = key - break - - if import_key: - matches = inheritance_patterns[import_key].findall(content) - if matches: - # Flatten if matches are tuples (from groups in regex) - imports = [] - for m in matches: - if isinstance(m, tuple): - imports.extend([x.strip() for x in m if x and x.strip()]) - else: - imports.append(m.strip()) - if imports: - metadata['imports'] = imports[:20] # Limit to 20 - - # Extract namespace/package - for key in ['namespace', 'package']: - if key in inheritance_patterns: - match = inheritance_patterns[key].search(content) - if match: - metadata[key] = match.group(1).strip() - break - - # Extract Rust impl for - if 'impl_for' in inheritance_patterns: - matches = inheritance_patterns['impl_for'].findall(content) - if matches: - # matches are tuples of (trait, type) - metadata['impl_traits'] = [m[0] for m in matches if m[0]] - metadata['impl_types'] = [m[1] for m in matches if m[1]] - - # Extract Scala with traits - if 'with' in inheritance_patterns: - matches = inheritance_patterns['with'].findall(content) - if matches: - metadata['with_traits'] = matches - - # Extract PHP type hints from docblocks - if 'type_hint' in inheritance_patterns: - matches = inheritance_patterns['type_hint'].findall(content) - if matches: - # Extract unique types referenced in docblocks - type_refs = list(set(matches))[:10] - metadata['type_references'] = type_refs - - # Extract PHTML template type hints (/** @var \Class $var */) - if 'template_type' in inheritance_patterns: - matches = inheritance_patterns['template_type'].findall(content) - if matches: - template_types = list(set(matches))[:10] - # Merge with type_references if exists - existing = 
metadata.get('type_references', []) - metadata['type_references'] = list(set(existing + template_types))[:15] - # Also add to related_classes for better searchability - metadata['related_classes'] = template_types - - # Extract PHP PHPDoc types (@param, @return, @throws) - if 'phpdoc_types' in inheritance_patterns: - matches = inheritance_patterns['phpdoc_types'].findall(content) - if matches: - # Filter and clean type names, handle union types - phpdoc_types = [] - for m in matches: - for t in m.split('|'): - t = t.strip().lstrip('\\') - if t and t[0].isupper(): # Only class names - phpdoc_types.append(t) - if phpdoc_types: - existing = metadata.get('type_references', []) - metadata['type_references'] = list(set(existing + phpdoc_types))[:20] - - def _extract_docstring(self, content: str, language: str) -> Optional[str]: - """Extract docstring from code chunk""" - if language == 'python': - match = re.search(r'"""([\s\S]*?)"""|\'\'\'([\s\S]*?)\'\'\'', content) - if match: - return (match.group(1) or match.group(2)).strip() - - elif language in ('javascript', 'typescript', 'java', 'kotlin', 'c_sharp', 'php', 'go', 'scala'): - match = re.search(r'/\*\*([\s\S]*?)\*/', content) - if match: - doc = match.group(1) - doc = re.sub(r'^\s*\*\s?', '', doc, flags=re.MULTILINE) - return doc.strip() - - elif language == 'rust': - lines = [] - for line in content.split('\n'): - if line.strip().startswith('///'): - lines.append(line.strip()[3:].strip()) - elif lines: - break - if lines: - return '\n'.join(lines) - - return None - - def _extract_signature(self, content: str, language: str) -> Optional[str]: - """Extract function/method signature from code chunk""" - lines = content.split('\n') - - for line in lines[:15]: - line = line.strip() - - if language == 'python': - if line.startswith(('def ', 'async def ', 'class ')): - sig = line - if line.startswith('class ') and ':' in line: - return line.split(':')[0] + ':' - if ')' not in sig and ':' not in sig: - idx = -1 - for i, l in enumerate(lines): - if l.strip() == line: - idx = i - break - if idx >= 0: - for next_line in lines[idx+1:idx+5]: - sig += ' ' + next_line.strip() - if ')' in next_line: - break - if ':' in sig: - return sig.split(':')[0] + ':' - return sig - - elif language in ('java', 'kotlin', 'c_sharp'): - if any(kw in line for kw in ['public ', 'private ', 'protected ', 'internal ', 'fun ']): - if '(' in line and not line.startswith('//'): - return line.split('{')[0].strip() - - elif language in ('javascript', 'typescript'): - if line.startswith(('function ', 'async function ', 'class ')): - return line.split('{')[0].strip() - if '=>' in line and '(' in line: - return line.split('=>')[0].strip() + ' =>' - - elif language == 'go': - if line.startswith('func ') or line.startswith('type '): - return line.split('{')[0].strip() - - elif language == 'rust': - if line.startswith(('fn ', 'pub fn ', 'async fn ', 'pub async fn ', 'impl ', 'struct ', 'trait ')): - return line.split('{')[0].strip() - - return None - - def _split_oversized_chunk( - self, - chunk: ASTChunk, - language: Optional[Language], - base_metadata: Dict[str, Any], - path: str - ) -> List[TextNode]: - """ - Split an oversized chunk using RecursiveCharacterTextSplitter. - - This is used when AST-parsed chunks (e.g., very large classes/functions) - still exceed the max_chunk_size. 
- """ - splitter = ( - self._get_text_splitter(language) - if language and language in AST_SUPPORTED_LANGUAGES - else self._default_splitter - ) - - sub_chunks = splitter.split_text(chunk.content) - nodes = [] - - # Parent ID for linking sub-chunks - parent_id = generate_deterministic_id(path, chunk.content, 0) - - for i, sub_chunk in enumerate(sub_chunks): - if not sub_chunk or not sub_chunk.strip(): - continue - - if len(sub_chunk.strip()) < self.min_chunk_size and len(sub_chunks) > 1: - continue - - metadata = dict(base_metadata) - metadata['content_type'] = ContentType.OVERSIZED_SPLIT.value - metadata['original_content_type'] = chunk.content_type.value - metadata['parent_chunk_id'] = parent_id - metadata['sub_chunk_index'] = i - metadata['total_sub_chunks'] = len(sub_chunks) - - # Preserve breadcrumb context - if chunk.parent_context: - metadata['parent_context'] = chunk.parent_context - metadata['parent_class'] = chunk.parent_context[-1] - - if chunk.semantic_names: - metadata['semantic_names'] = chunk.semantic_names - metadata['primary_name'] = chunk.semantic_names[0] - - # Deterministic ID for this sub-chunk - chunk_id = generate_deterministic_id(path, sub_chunk, i) - - node = TextNode( - id_=chunk_id, - text=sub_chunk, - metadata=metadata - ) - nodes.append(node) - - return nodes - - def split_documents(self, documents: List[LlamaDocument]) -> List[TextNode]: - """ - Split LlamaIndex documents using AST-based parsing. - - Args: - documents: List of LlamaIndex Document objects - - Returns: - List of TextNode objects with enriched metadata and deterministic IDs - """ - all_nodes = [] - - for doc in documents: - path = doc.metadata.get('path', 'unknown') - - # Determine Language enum - language = self._get_language_from_path(path) - - # Check if AST parsing is supported and beneficial - line_count = doc.text.count('\n') + 1 - use_ast = ( - language is not None - and language in AST_SUPPORTED_LANGUAGES - and line_count >= self.parser_threshold - and self._check_tree_sitter() - ) - - if use_ast: - nodes = self._split_with_ast(doc, language) - else: - nodes = self._split_fallback(doc, language) - - all_nodes.extend(nodes) - logger.debug(f"Split {path} into {len(nodes)} chunks (AST={use_ast})") - - return all_nodes - - def _split_with_ast( - self, - doc: LlamaDocument, - language: Language - ) -> List[TextNode]: - """Split document using AST parsing with breadcrumb context""" - text = doc.text - path = doc.metadata.get('path', 'unknown') - - # Try AST parsing - ast_chunks = self._parse_with_ast(text, language, path) - - if not ast_chunks: - return self._split_fallback(doc, language) - - nodes = [] - chunk_counter = 0 - - for ast_chunk in ast_chunks: - # Check if chunk is oversized - if len(ast_chunk.content) > self.max_chunk_size: - # Split oversized chunk - sub_nodes = self._split_oversized_chunk( - ast_chunk, - language, - doc.metadata, - path - ) - nodes.extend(sub_nodes) - chunk_counter += len(sub_nodes) - else: - # Create node with enriched metadata - metadata = self._extract_metadata(ast_chunk, doc.metadata) - metadata['chunk_index'] = chunk_counter - metadata['total_chunks'] = len(ast_chunks) - - # Deterministic ID - chunk_id = generate_deterministic_id(path, ast_chunk.content, chunk_counter) - - node = TextNode( - id_=chunk_id, - text=ast_chunk.content, - metadata=metadata - ) - nodes.append(node) - chunk_counter += 1 - - return nodes - - def _split_fallback( - self, - doc: LlamaDocument, - language: Optional[Language] = None - ) -> List[TextNode]: - """Fallback splitting using 
RecursiveCharacterTextSplitter""" - text = doc.text - path = doc.metadata.get('path', 'unknown') - - if not text or not text.strip(): - return [] - - splitter = ( - self._get_text_splitter(language) - if language and language in AST_SUPPORTED_LANGUAGES - else self._default_splitter - ) - - chunks = splitter.split_text(text) - nodes = [] - text_offset = 0 - - for i, chunk in enumerate(chunks): - if not chunk or not chunk.strip(): - continue - - if len(chunk.strip()) < self.min_chunk_size and len(chunks) > 1: - continue - - # Truncate if too large - if len(chunk) > 30000: - chunk = chunk[:30000] - - # Calculate line numbers - start_line = text[:text_offset].count('\n') + 1 if text_offset > 0 else 1 - chunk_pos = text.find(chunk, text_offset) - if chunk_pos >= 0: - text_offset = chunk_pos + len(chunk) - end_line = start_line + chunk.count('\n') - - # Extract metadata using regex patterns - lang_str = doc.metadata.get('language', 'text') - metadata = dict(doc.metadata) - metadata['content_type'] = ContentType.FALLBACK.value - metadata['chunk_index'] = i - metadata['total_chunks'] = len(chunks) - metadata['start_line'] = start_line - metadata['end_line'] = end_line - - # Try to extract semantic names - patterns = METADATA_PATTERNS.get(lang_str, {}) - names = [] - for pattern_type, pattern in patterns.items(): - matches = pattern.findall(chunk) - names.extend(matches) - if names: - metadata['semantic_names'] = list(set(names))[:10] - metadata['primary_name'] = names[0] - - # Extract class inheritance, interfaces, and imports (also for fallback) - self._extract_inheritance_metadata(chunk, lang_str, metadata) - - # Deterministic ID - chunk_id = generate_deterministic_id(path, chunk, i) - - node = TextNode( - id_=chunk_id, - text=chunk, - metadata=metadata - ) - nodes.append(node) - - return nodes - - @staticmethod - def get_supported_languages() -> List[str]: - """Return list of languages with AST support""" - return list(LANGUAGE_TO_TREESITTER.values()) - - @staticmethod - def is_ast_supported(path: str) -> bool: - """Check if AST parsing is supported for a file""" - ext = Path(path).suffix.lower() - lang = EXTENSION_TO_LANGUAGE.get(ext) - return lang is not None and lang in AST_SUPPORTED_LANGUAGES diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/chunking.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/chunking.py deleted file mode 100644 index 3c8cd8cc..00000000 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/chunking.py +++ /dev/null @@ -1,171 +0,0 @@ -from typing import List -import uuid -from llama_index.core.node_parser import SentenceSplitter -from llama_index.core.schema import Document, TextNode -from ..utils.utils import is_code_file - - -class CodeAwareSplitter: - """ - Code-aware text splitter that handles code and text differently. - - DEPRECATED: Use SemanticCodeSplitter instead, which provides: - - Full AST-aware parsing for multiple languages - - Better metadata extraction (docstrings, signatures, imports) - - Smarter chunk merging and boundary detection - - This class just wraps SentenceSplitter with different chunk sizes for - code vs text. For truly semantic code splitting, use SemanticCodeSplitter. 
- """ - - def __init__(self, code_chunk_size: int = 800, code_overlap: int = 200, - text_chunk_size: int = 1000, text_overlap: int = 200): - self.code_splitter = SentenceSplitter( - chunk_size=code_chunk_size, - chunk_overlap=code_overlap, - separator="\n\n", - ) - - self.text_splitter = SentenceSplitter( - chunk_size=text_chunk_size, - chunk_overlap=text_overlap, - ) - - def split_documents(self, documents: List[Document]) -> List[TextNode]: - """Split documents into chunks based on their language type""" - result = [] - - for doc in documents: - language = doc.metadata.get("language", "text") - is_code = is_code_file(language) - - splitter = self.code_splitter if is_code else self.text_splitter - - nodes = splitter.get_nodes_from_documents([doc]) - - for i, node in enumerate(nodes): - # Skip empty or whitespace-only chunks - if not node.text or not node.text.strip(): - continue - - # Truncate text if too large (>30k chars ≈ 7.5k tokens) - text = node.text - if len(text) > 30000: - text = text[:30000] - - metadata = dict(doc.metadata) - metadata["chunk_index"] = i - metadata["total_chunks"] = len(nodes) - - # Create TextNode with explicit UUID - chunk_node = TextNode( - id_=str(uuid.uuid4()), - text=text, - metadata=metadata - ) - result.append(chunk_node) - - return result - - def split_text_for_language(self, text: str, language: str) -> List[str]: - """Split text based on language type""" - is_code = is_code_file(language) - splitter = self.code_splitter if is_code else self.text_splitter - - temp_doc = Document(text=text, metadata={"language": language}) - nodes = splitter.get_nodes_from_documents([temp_doc]) - - return [node.text for node in nodes] - - -class FunctionAwareSplitter: - """ - Advanced splitter that tries to preserve function boundaries. - - DEPRECATED: Use SemanticCodeSplitter instead, which provides: - - Full AST-aware parsing for multiple languages - - Better metadata extraction (docstrings, signatures, imports) - - Smarter chunk merging and boundary detection - - This class is kept for backward compatibility only. 
- """ - - def __init__(self, max_chunk_size: int = 800, overlap: int = 200): - self.max_chunk_size = max_chunk_size - self.overlap = overlap - self.fallback_splitter = SentenceSplitter( - chunk_size=max_chunk_size, - chunk_overlap=overlap, - ) - - def split_by_functions(self, text: str, language: str) -> List[str]: - """Try to split code by functions/classes""" - - if language == 'python': - return self._split_python(text) - elif language in ['javascript', 'typescript', 'java', 'cpp', 'c', 'go', 'rust', 'php']: - return self._split_brace_language(text) - else: - temp_doc = Document(text=text) - nodes = self.fallback_splitter.get_nodes_from_documents([temp_doc]) - return [node.text for node in nodes] - - def _split_python(self, text: str) -> List[str]: - """Split Python code by top-level definitions""" - lines = text.split('\n') - chunks = [] - current_chunk = [] - - for line in lines: - stripped = line.lstrip() - - if stripped.startswith(('def ', 'class ', 'async def ')): - if current_chunk and len('\n'.join(current_chunk)) > 50: - chunks.append('\n'.join(current_chunk)) - current_chunk = [] - - current_chunk.append(line) - - if len('\n'.join(current_chunk)) > self.max_chunk_size: - chunks.append('\n'.join(current_chunk)) - current_chunk = [] - - if current_chunk: - chunks.append('\n'.join(current_chunk)) - - return chunks if chunks else [text] - - def _split_brace_language(self, text: str) -> List[str]: - """Split brace-based languages by functions/classes""" - chunks = [] - current_chunk = [] - brace_count = 0 - in_function = False - - lines = text.split('\n') - - for line in lines: - if any(keyword in line for keyword in - ['function ', 'class ', 'def ', 'fn ', 'func ', 'public ', 'private ', 'protected ']): - if '{' in line: - in_function = True - - current_chunk.append(line) - - brace_count += line.count('{') - line.count('}') - - if in_function and brace_count == 0 and len(current_chunk) > 3: - chunks.append('\n'.join(current_chunk)) - current_chunk = [] - in_function = False - - if len('\n'.join(current_chunk)) > self.max_chunk_size: - chunks.append('\n'.join(current_chunk)) - current_chunk = [] - brace_count = 0 - - if current_chunk: - chunks.append('\n'.join(current_chunk)) - - return chunks if chunks else [text] - diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py index 1694fcd9..7eebf78e 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py @@ -20,8 +20,7 @@ from ..models.config import RAGConfig, IndexStats from ..utils.utils import make_namespace, make_project_namespace -from .semantic_splitter import SemanticCodeSplitter -from .ast_splitter import ASTCodeSplitter +from .splitter import ASTCodeSplitter from .loader import DocumentLoader from .openrouter_embedding import OpenRouterEmbedding @@ -57,25 +56,15 @@ def __init__(self, config: RAGConfig): Settings.chunk_size = config.chunk_size Settings.chunk_overlap = config.chunk_overlap - # Choose splitter based on environment variable or config - # AST splitter provides better semantic chunking for supported languages - use_ast_splitter = os.environ.get('RAG_USE_AST_SPLITTER', 'true').lower() == 'true' - - if use_ast_splitter: - logger.info("Using ASTCodeSplitter for code chunking (tree-sitter based)") - self.splitter = ASTCodeSplitter( - max_chunk_size=config.chunk_size, - min_chunk_size=min(200, config.chunk_size // 4), - 
chunk_overlap=config.chunk_overlap, - parser_threshold=10 # Minimum lines for AST parsing - ) - else: - logger.info("Using SemanticCodeSplitter for code chunking (regex-based)") - self.splitter = SemanticCodeSplitter( - max_chunk_size=config.chunk_size, - min_chunk_size=min(200, config.chunk_size // 4), - overlap=config.chunk_overlap - ) + # AST splitter with tree-sitter query-based parsing + # Falls back internally when tree-sitter unavailable + logger.info("Using ASTCodeSplitter for code chunking (tree-sitter query-based)") + self.splitter = ASTCodeSplitter( + max_chunk_size=config.chunk_size, + min_chunk_size=min(200, config.chunk_size // 4), + chunk_overlap=config.chunk_overlap, + parser_threshold=10 # Minimum lines for AST parsing + ) self.loader = DocumentLoader(config) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/__init__.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/__init__.py new file mode 100644 index 00000000..d06d664f --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/__init__.py @@ -0,0 +1,10 @@ +""" +Index Manager module for RAG indexing operations. + +Provides RAGIndexManager as the main entry point for indexing repositories, +managing collections, and handling branch-level operations. +""" + +from .manager import RAGIndexManager + +__all__ = ["RAGIndexManager"] diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/branch_manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/branch_manager.py new file mode 100644 index 00000000..3782a4f6 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/branch_manager.py @@ -0,0 +1,172 @@ +""" +Branch-level operations for RAG indices. + +Handles branch-specific point management within project collections. 
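The new package __init__ above keeps `from rag_pipeline.core.index_manager import RAGIndexManager` working while the implementation moves into submodules, and the splitter is now always the ASTCodeSplitter sized from RAGConfig. A minimal construction sketch, assuming the absolute import paths match the package layout and that RAGConfig accepts these fields with defaults for the rest (both are assumptions, not guaranteed by this patch):

    from rag_pipeline.models.config import RAGConfig            # path assumed from the relative imports
    from rag_pipeline.core.index_manager import RAGIndexManager # re-exported by the new __init__.py

    config = RAGConfig(
        qdrant_url="http://localhost:6333",      # placeholder endpoint
        openrouter_api_key="sk-placeholder",     # placeholder credential
        chunk_size=1500,
        chunk_overlap=200,
    )
    manager = RAGIndexManager(config)
    # With chunk_size=1500 the splitter is built with min_chunk_size=min(200, 1500 // 4) == 200
    # and parser_threshold=10, exactly as wired in manager.py later in this patch.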
+""" + +import logging +from typing import List, Set, Optional + +from qdrant_client import QdrantClient +from qdrant_client.models import Filter, FieldCondition, MatchValue, PointStruct + +logger = logging.getLogger(__name__) + + +class BranchManager: + """Manages branch-level operations within project collections.""" + + def __init__(self, client: QdrantClient): + self.client = client + + def delete_branch_points( + self, + collection_name: str, + branch: str + ) -> bool: + """Delete all points for a specific branch from the collection.""" + logger.info(f"Deleting all points for branch '{branch}' from {collection_name}") + + try: + self.client.delete( + collection_name=collection_name, + points_selector=Filter( + must=[ + FieldCondition( + key="branch", + match=MatchValue(value=branch) + ) + ] + ) + ) + logger.info(f"Successfully deleted all points for branch '{branch}'") + return True + except Exception as e: + logger.error(f"Failed to delete branch '{branch}': {e}") + return False + + def get_branch_point_count( + self, + collection_name: str, + branch: str + ) -> int: + """Get the number of points for a specific branch.""" + try: + result = self.client.count( + collection_name=collection_name, + count_filter=Filter( + must=[ + FieldCondition( + key="branch", + match=MatchValue(value=branch) + ) + ] + ) + ) + return result.count + except Exception as e: + logger.error(f"Failed to get point count for branch '{branch}': {e}") + return 0 + + def get_indexed_branches(self, collection_name: str) -> List[str]: + """Get list of branches that have points in the collection.""" + try: + branches: Set[str] = set() + offset = None + limit = 100 + + while True: + results = self.client.scroll( + collection_name=collection_name, + limit=limit, + offset=offset, + with_payload=["branch"], + with_vectors=False + ) + + points, next_offset = results + + for point in points: + if point.payload and "branch" in point.payload: + branches.add(point.payload["branch"]) + + if next_offset is None or len(points) < limit: + break + offset = next_offset + + return list(branches) + except Exception as e: + logger.error(f"Failed to get indexed branches: {e}") + return [] + + def preserve_other_branch_points( + self, + collection_name: str, + exclude_branch: str + ) -> List[PointStruct]: + """Preserve points from branches other than the one being reindexed. + + Used during full reindex to keep data from other branches. 
+ """ + logger.info(f"Preserving points from branches other than '{exclude_branch}'...") + + preserved_points = [] + offset = None + + try: + while True: + results = self.client.scroll( + collection_name=collection_name, + limit=100, + offset=offset, + scroll_filter=Filter( + must_not=[ + FieldCondition( + key="branch", + match=MatchValue(value=exclude_branch) + ) + ] + ), + with_payload=True, + with_vectors=True + ) + points, next_offset = results + preserved_points.extend(points) + + if next_offset is None or len(points) < 100: + break + offset = next_offset + + logger.info(f"Found {len(preserved_points)} points from other branches to preserve") + return preserved_points + except Exception as e: + logger.warning(f"Could not read existing points: {e}") + return [] + + def copy_points_to_collection( + self, + points: List, + target_collection: str, + batch_size: int = 50 + ) -> None: + """Copy preserved points to a new collection.""" + if not points: + return + + logger.info(f"Copying {len(points)} points to {target_collection}...") + + for i in range(0, len(points), batch_size): + batch = points[i:i + batch_size] + points_to_upsert = [ + PointStruct( + id=p.id, + vector=p.vector, + payload=p.payload + ) for p in batch + ] + self.client.upsert( + collection_name=target_collection, + points=points_to_upsert + ) + + logger.info("Points copied successfully") diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/collection_manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/collection_manager.py new file mode 100644 index 00000000..81a3a407 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/collection_manager.py @@ -0,0 +1,164 @@ +""" +Qdrant collection and alias management utilities. + +Handles collection creation, alias operations, and resolution. +""" + +import logging +import time +from typing import Optional, List + +from qdrant_client import QdrantClient +from qdrant_client.models import ( + Distance, VectorParams, + CreateAlias, DeleteAlias, CreateAliasOperation, DeleteAliasOperation +) + +logger = logging.getLogger(__name__) + + +class CollectionManager: + """Manages Qdrant collections and aliases.""" + + def __init__(self, client: QdrantClient, embedding_dim: int): + self.client = client + self.embedding_dim = embedding_dim + + def ensure_collection_exists(self, collection_name: str) -> None: + """Ensure Qdrant collection exists with proper configuration. + + If the collection_name is actually an alias, use the aliased collection instead. 
+ """ + if self.alias_exists(collection_name): + logger.info(f"Collection name {collection_name} is an alias, using existing aliased collection") + return + + collections = self.client.get_collections().collections + collection_names = [c.name for c in collections] + logger.debug(f"Existing collections: {collection_names}") + + if collection_name not in collection_names: + logger.info(f"Creating Qdrant collection: {collection_name}") + self.client.create_collection( + collection_name=collection_name, + vectors_config=VectorParams( + size=self.embedding_dim, + distance=Distance.COSINE + ) + ) + logger.info(f"Created collection {collection_name}") + else: + logger.info(f"Collection {collection_name} already exists") + + def create_versioned_collection(self, base_name: str) -> str: + """Create a new versioned collection for atomic swap indexing.""" + versioned_name = f"{base_name}_v{int(time.time())}" + logger.info(f"Creating versioned collection: {versioned_name}") + + self.client.create_collection( + collection_name=versioned_name, + vectors_config=VectorParams( + size=self.embedding_dim, + distance=Distance.COSINE + ) + ) + return versioned_name + + def delete_collection(self, collection_name: str) -> bool: + """Delete a collection.""" + try: + self.client.delete_collection(collection_name) + logger.info(f"Deleted collection: {collection_name}") + return True + except Exception as e: + logger.warning(f"Failed to delete collection {collection_name}: {e}") + return False + + def collection_exists(self, collection_name: str) -> bool: + """Check if a collection exists (not alias).""" + collections = self.client.get_collections().collections + return collection_name in [c.name for c in collections] + + def get_collection_names(self) -> List[str]: + """Get all collection names.""" + collections = self.client.get_collections().collections + return [c.name for c in collections] + + # Alias operations + + def alias_exists(self, alias_name: str) -> bool: + """Check if an alias exists.""" + try: + aliases = self.client.get_aliases() + exists = any(a.alias_name == alias_name for a in aliases.aliases) + logger.debug(f"Checking if alias '{alias_name}' exists: {exists}") + return exists + except Exception as e: + logger.warning(f"Error checking alias {alias_name}: {e}") + return False + + def resolve_alias(self, alias_name: str) -> Optional[str]: + """Resolve an alias to its underlying collection name.""" + try: + aliases = self.client.get_aliases() + for alias in aliases.aliases: + if alias.alias_name == alias_name: + return alias.collection_name + except Exception as e: + logger.debug(f"Error resolving alias {alias_name}: {e}") + return None + + def atomic_alias_swap( + self, + alias_name: str, + new_collection: str, + old_alias_exists: bool + ) -> None: + """Perform atomic alias swap for zero-downtime reindexing.""" + alias_operations = [] + + if old_alias_exists: + alias_operations.append( + DeleteAliasOperation(delete_alias=DeleteAlias(alias_name=alias_name)) + ) + + alias_operations.append( + CreateAliasOperation(create_alias=CreateAlias( + alias_name=alias_name, + collection_name=new_collection + )) + ) + + self.client.update_collection_aliases( + change_aliases_operations=alias_operations + ) + logger.info(f"Alias swap completed: {alias_name} -> {new_collection}") + + def delete_alias(self, alias_name: str) -> bool: + """Delete an alias.""" + try: + self.client.delete_alias(alias_name) + logger.info(f"Deleted alias: {alias_name}") + return True + except Exception as e: + logger.warning(f"Failed 
to delete alias {alias_name}: {e}") + return False + + def cleanup_orphaned_versioned_collections( + self, + base_name: str, + current_target: Optional[str] = None, + exclude_name: Optional[str] = None + ) -> int: + """Clean up orphaned versioned collections from failed indexing attempts.""" + cleaned = 0 + collection_names = self.get_collection_names() + + for coll_name in collection_names: + if coll_name.startswith(f"{base_name}_v") and coll_name != exclude_name: + if current_target != coll_name: + logger.info(f"Cleaning up orphaned versioned collection: {coll_name}") + if self.delete_collection(coll_name): + cleaned += 1 + + return cleaned diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py new file mode 100644 index 00000000..2bf145a4 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py @@ -0,0 +1,398 @@ +""" +Repository indexing operations. + +Handles full repository indexing with atomic swap and streaming processing. +""" + +import gc +import logging +import time +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional, List + +from qdrant_client.models import Filter, FieldCondition, MatchAny, MatchValue + +from ...models.config import RAGConfig, IndexStats +from ...utils.utils import make_namespace +from .collection_manager import CollectionManager +from .branch_manager import BranchManager +from .point_operations import PointOperations +from .stats_manager import StatsManager + +logger = logging.getLogger(__name__) + +# Memory-efficient batch sizes +DOCUMENT_BATCH_SIZE = 50 +INSERT_BATCH_SIZE = 50 + + +class RepositoryIndexer: + """Handles repository indexing operations.""" + + def __init__( + self, + config: RAGConfig, + collection_manager: CollectionManager, + branch_manager: BranchManager, + point_ops: PointOperations, + stats_manager: StatsManager, + splitter, + loader + ): + self.config = config + self.collection_manager = collection_manager + self.branch_manager = branch_manager + self.point_ops = point_ops + self.stats_manager = stats_manager + self.splitter = splitter + self.loader = loader + + def estimate_repository_size( + self, + repo_path: str, + exclude_patterns: Optional[List[str]] = None + ) -> tuple[int, int]: + """Estimate repository size (file count and chunk count) without actually indexing.""" + logger.info(f"Estimating repository size for: {repo_path}") + + repo_path_obj = Path(repo_path) + file_list = list(self.loader.iter_repository_files(repo_path_obj, exclude_patterns)) + file_count = len(file_list) + logger.info(f"Found {file_count} files for estimation") + + if file_count == 0: + return 0, 0 + + SAMPLE_SIZE = 100 + chunk_count = 0 + + if file_count <= SAMPLE_SIZE: + for i in range(0, file_count, DOCUMENT_BATCH_SIZE): + batch = file_list[i:i + DOCUMENT_BATCH_SIZE] + documents = self.loader.load_file_batch( + batch, repo_path_obj, "estimate", "estimate", "estimate", "estimate" + ) + if documents: + chunks = self.splitter.split_documents(documents) + chunk_count += len(chunks) + del chunks + del documents + gc.collect() + else: + import random + sample_files = random.sample(file_list, SAMPLE_SIZE) + sample_chunk_count = 0 + + for i in range(0, len(sample_files), DOCUMENT_BATCH_SIZE): + batch = sample_files[i:i + DOCUMENT_BATCH_SIZE] + documents = self.loader.load_file_batch( + batch, repo_path_obj, "estimate", "estimate", "estimate", "estimate" + ) + if 
documents: + chunks = self.splitter.split_documents(documents) + sample_chunk_count += len(chunks) + del chunks + del documents + + avg_chunks_per_file = sample_chunk_count / SAMPLE_SIZE + chunk_count = int(avg_chunks_per_file * file_count) + logger.info(f"Estimated ~{avg_chunks_per_file:.1f} chunks/file from {SAMPLE_SIZE} samples") + gc.collect() + + logger.info(f"Estimated {chunk_count} chunks from {file_count} files") + return file_count, chunk_count + + def index_repository( + self, + repo_path: str, + workspace: str, + project: str, + branch: str, + commit: str, + alias_name: str, + exclude_patterns: Optional[List[str]] = None + ) -> IndexStats: + """Index entire repository for a branch using atomic swap strategy.""" + logger.info(f"Indexing repository: {workspace}/{project}/{branch} from {repo_path}") + + repo_path_obj = Path(repo_path) + temp_collection_name = self.collection_manager.create_versioned_collection(alias_name) + + # Check existing collection and preserve other branch data + old_collection_exists = self.collection_manager.alias_exists(alias_name) + if not old_collection_exists: + old_collection_exists = self.collection_manager.collection_exists(alias_name) + + existing_other_branch_points = [] + if old_collection_exists: + actual_collection = self.collection_manager.resolve_alias(alias_name) or alias_name + existing_other_branch_points = self.branch_manager.preserve_other_branch_points( + actual_collection, branch + ) + + # Clean up orphaned versioned collections + current_target = self.collection_manager.resolve_alias(alias_name) + self.collection_manager.cleanup_orphaned_versioned_collections( + alias_name, current_target, temp_collection_name + ) + + # Get file list + file_list = list(self.loader.iter_repository_files(repo_path_obj, exclude_patterns)) + total_files = len(file_list) + logger.info(f"Found {total_files} files to index for branch '{branch}'") + + if total_files == 0: + logger.warning("No documents to index") + self.collection_manager.delete_collection(temp_collection_name) + return self.stats_manager.get_branch_stats( + workspace, project, branch, + self.collection_manager.resolve_alias(alias_name) or alias_name + ) + + # Validate limits + if self.config.max_files_per_index > 0 and total_files > self.config.max_files_per_index: + self.collection_manager.delete_collection(temp_collection_name) + raise ValueError( + f"Repository exceeds file limit: {total_files} files (max: {self.config.max_files_per_index})." + ) + + if self.config.max_chunks_per_index > 0: + logger.info("Estimating chunk count before indexing...") + _, estimated_chunks = self.estimate_repository_size(repo_path, exclude_patterns) + if estimated_chunks > self.config.max_chunks_per_index * 1.2: + self.collection_manager.delete_collection(temp_collection_name) + raise ValueError( + f"Repository estimated to exceed chunk limit: ~{estimated_chunks} chunks (max: {self.config.max_chunks_per_index})." 
+ ) + + document_count = 0 + chunk_count = 0 + successful_chunks = 0 + failed_chunks = 0 + + try: + # Copy preserved points from other branches + if existing_other_branch_points: + self.branch_manager.copy_points_to_collection( + existing_other_branch_points, + temp_collection_name, + INSERT_BATCH_SIZE + ) + + # Stream process files in batches + logger.info("Starting memory-efficient streaming indexing...") + batch_num = 0 + total_batches = (total_files + DOCUMENT_BATCH_SIZE - 1) // DOCUMENT_BATCH_SIZE + + for i in range(0, total_files, DOCUMENT_BATCH_SIZE): + batch_num += 1 + file_batch = file_list[i:i + DOCUMENT_BATCH_SIZE] + + documents = self.loader.load_file_batch( + file_batch, repo_path_obj, workspace, project, branch, commit + ) + document_count += len(documents) + + if not documents: + continue + + chunks = self.splitter.split_documents(documents) + batch_chunk_count = len(chunks) + chunk_count += batch_chunk_count + + # Check chunk limit + if self.config.max_chunks_per_index > 0 and chunk_count > self.config.max_chunks_per_index: + self.collection_manager.delete_collection(temp_collection_name) + raise ValueError(f"Repository exceeds chunk limit: {chunk_count}+ chunks.") + + # Process and upsert + success, failed = self.point_ops.process_and_upsert_chunks( + chunks, temp_collection_name, workspace, project, branch + ) + successful_chunks += success + failed_chunks += failed + + logger.info( + f"Batch {batch_num}/{total_batches}: processed {len(documents)} files, " + f"{batch_chunk_count} chunks" + ) + + del documents + del chunks + + if batch_num % 5 == 0: + gc.collect() + + logger.info( + f"Streaming indexing complete: {document_count} files, " + f"{successful_chunks}/{chunk_count} chunks indexed ({failed_chunks} failed)" + ) + + # Verify and perform atomic swap + temp_info = self.point_ops.client.get_collection(temp_collection_name) + if temp_info.points_count == 0: + raise Exception("Temporary collection is empty after indexing") + + self._perform_atomic_swap( + alias_name, temp_collection_name, old_collection_exists + ) + + except Exception as e: + logger.error(f"Indexing failed: {e}") + self.collection_manager.delete_collection(temp_collection_name) + raise e + finally: + del existing_other_branch_points + gc.collect() + + self.stats_manager.store_metadata( + workspace, project, branch, commit, document_count, chunk_count + ) + + namespace = make_namespace(workspace, project, branch) + return IndexStats( + namespace=namespace, + document_count=document_count, + chunk_count=successful_chunks, + last_updated=datetime.now(timezone.utc).isoformat(), + workspace=workspace, + project=project, + branch=branch + ) + + def _perform_atomic_swap( + self, + alias_name: str, + temp_collection_name: str, + old_collection_exists: bool + ) -> None: + """Perform atomic alias swap with migration handling.""" + logger.info("Performing atomic alias swap...") + + is_direct_collection = ( + self.collection_manager.collection_exists(alias_name) and + not self.collection_manager.alias_exists(alias_name) + ) + + old_versioned_name = None + if old_collection_exists and not is_direct_collection: + old_versioned_name = self.collection_manager.resolve_alias(alias_name) + + try: + self.collection_manager.atomic_alias_swap( + alias_name, temp_collection_name, + old_collection_exists and not is_direct_collection + ) + except Exception as alias_err: + if is_direct_collection and "already exists" in str(alias_err).lower(): + logger.info("Migrating from direct collection to alias-based indexing...") + 
self.collection_manager.delete_collection(alias_name) + self.collection_manager.atomic_alias_swap(alias_name, temp_collection_name, False) + else: + raise alias_err + + if old_versioned_name and old_versioned_name != temp_collection_name: + self.collection_manager.delete_collection(old_versioned_name) + + +class FileOperations: + """Handles individual file update and delete operations.""" + + def __init__( + self, + client, + point_ops: PointOperations, + collection_manager: CollectionManager, + stats_manager: StatsManager, + splitter, + loader + ): + self.client = client + self.point_ops = point_ops + self.collection_manager = collection_manager + self.stats_manager = stats_manager + self.splitter = splitter + self.loader = loader + + def update_files( + self, + file_paths: List[str], + repo_base: str, + workspace: str, + project: str, + branch: str, + commit: str, + collection_name: str + ) -> IndexStats: + """Update specific files in the index (Delete Old -> Insert New).""" + logger.info(f"Updating {len(file_paths)} files in {workspace}/{project} for branch '{branch}'") + + repo_base_obj = Path(repo_base) + file_path_objs = [Path(fp) for fp in file_paths] + + self.collection_manager.ensure_collection_exists(collection_name) + + # Delete old chunks for these files and branch + logger.info(f"Purging existing vectors for {len(file_paths)} files in branch '{branch}'...") + self.client.delete( + collection_name=collection_name, + points_selector=Filter( + must=[ + FieldCondition(key="path", match=MatchAny(any=file_paths)), + FieldCondition(key="branch", match=MatchValue(value=branch)) + ] + ) + ) + + # Load and split new content + documents = self.loader.load_specific_files( + file_paths=file_path_objs, + repo_base=repo_base_obj, + workspace=workspace, + project=project, + branch=branch, + commit=commit + ) + + if not documents: + logger.warning("No documents loaded from provided paths.") + return self.stats_manager.get_project_stats(workspace, project, collection_name) + + chunks = self.splitter.split_documents(documents) + logger.info(f"Generated {len(chunks)} new chunks") + + # Process and upsert + self.point_ops.process_and_upsert_chunks( + chunks, collection_name, workspace, project, branch + ) + + logger.info(f"Successfully updated {len(chunks)} chunks for branch '{branch}'") + return self.stats_manager.get_project_stats(workspace, project, collection_name) + + def delete_files( + self, + file_paths: List[str], + workspace: str, + project: str, + branch: str, + collection_name: str + ) -> IndexStats: + """Delete specific files from the index for a specific branch.""" + logger.info(f"Deleting {len(file_paths)} files from {workspace}/{project} branch '{branch}'") + + try: + self.client.delete( + collection_name=collection_name, + points_selector=Filter( + must=[ + FieldCondition(key="path", match=MatchAny(any=file_paths)), + FieldCondition(key="branch", match=MatchValue(value=branch)) + ] + ) + ) + logger.info(f"Deleted {len(file_paths)} files from branch '{branch}'") + except Exception as e: + logger.warning(f"Error deleting files: {e}") + + return self.stats_manager.get_project_stats(workspace, project, collection_name) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/manager.py new file mode 100644 index 00000000..7197f2ad --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/manager.py @@ -0,0 +1,290 @@ +""" +Main RAG Index Manager. 
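The indexer above builds a fresh versioned collection, fills it, and only then repoints the stable alias, so readers never query a half-built index. A condensed sketch of that lifecycle with placeholder names (the patch additionally guards the delete-alias step for first-time indexing and direct-collection migration):

    import time
    from qdrant_client import QdrantClient
    from qdrant_client.models import (
        CreateAlias, CreateAliasOperation, DeleteAlias, DeleteAliasOperation,
        Distance, VectorParams,
    )

    client = QdrantClient(url="http://localhost:6333")   # placeholder endpoint
    alias = "codecrow_acme__shop"                         # stable name queried by readers
    fresh = f"{alias}_v{int(time.time())}"                # versioned collection, as in create_versioned_collection

    client.create_collection(collection_name=fresh,
                             vectors_config=VectorParams(size=1536, distance=Distance.COSINE))
    # ... index the branch into `fresh`, copying preserved other-branch points first ...

    # Delete-old-alias and create-new-alias are applied in a single request,
    # so the alias never points at nothing.
    client.update_collection_aliases(change_aliases_operations=[
        DeleteAliasOperation(delete_alias=DeleteAlias(alias_name=alias)),  # only when the alias already exists
        CreateAliasOperation(create_alias=CreateAlias(alias_name=alias, collection_name=fresh)),
    ])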
+ +Composes all index management components and provides the public API. +""" + +import logging +from typing import Optional, List + +from llama_index.core import Settings +from qdrant_client import QdrantClient + +from ...models.config import RAGConfig, IndexStats +from ...utils.utils import make_namespace, make_project_namespace +from ..splitter import ASTCodeSplitter +from ..loader import DocumentLoader +from ..openrouter_embedding import OpenRouterEmbedding + +from .collection_manager import CollectionManager +from .branch_manager import BranchManager +from .point_operations import PointOperations +from .stats_manager import StatsManager +from .indexer import RepositoryIndexer, FileOperations + +logger = logging.getLogger(__name__) + + +class RAGIndexManager: + """Manage RAG indices for code repositories using Qdrant. + + This is the main entry point for all indexing operations. + """ + + def __init__(self, config: RAGConfig): + self.config = config + + # Qdrant client + self.qdrant_client = QdrantClient(url=config.qdrant_url) + logger.info(f"Connected to Qdrant at {config.qdrant_url}") + + # Embedding model + self.embed_model = OpenRouterEmbedding( + api_key=config.openrouter_api_key, + model=config.openrouter_model, + api_base=config.openrouter_base_url, + timeout=60.0, + max_retries=3, + expected_dim=config.embedding_dim + ) + + # Global settings + Settings.embed_model = self.embed_model + Settings.chunk_size = config.chunk_size + Settings.chunk_overlap = config.chunk_overlap + + # Splitter and loader + logger.info("Using ASTCodeSplitter for code chunking (tree-sitter query-based)") + self.splitter = ASTCodeSplitter( + max_chunk_size=config.chunk_size, + min_chunk_size=min(200, config.chunk_size // 4), + chunk_overlap=config.chunk_overlap, + parser_threshold=10 + ) + self.loader = DocumentLoader(config) + + # Component managers + self._collection_manager = CollectionManager( + self.qdrant_client, config.embedding_dim + ) + self._branch_manager = BranchManager(self.qdrant_client) + self._point_ops = PointOperations( + self.qdrant_client, self.embed_model, batch_size=50 + ) + self._stats_manager = StatsManager( + self.qdrant_client, config.qdrant_collection_prefix + ) + + # Higher-level operations + self._indexer = RepositoryIndexer( + config=config, + collection_manager=self._collection_manager, + branch_manager=self._branch_manager, + point_ops=self._point_ops, + stats_manager=self._stats_manager, + splitter=self.splitter, + loader=self.loader + ) + self._file_ops = FileOperations( + client=self.qdrant_client, + point_ops=self._point_ops, + collection_manager=self._collection_manager, + stats_manager=self._stats_manager, + splitter=self.splitter, + loader=self.loader + ) + + # Collection naming + + def _get_project_collection_name(self, workspace: str, project: str) -> str: + """Generate Qdrant collection name from workspace/project.""" + namespace = make_project_namespace(workspace, project) + return f"{self.config.qdrant_collection_prefix}_{namespace}" + + def _get_collection_name(self, workspace: str, project: str, branch: str) -> str: + """Generate collection name (DEPRECATED - use _get_project_collection_name).""" + namespace = make_namespace(workspace, project, branch) + return f"{self.config.qdrant_collection_prefix}_{namespace}" + + # Repository indexing + + def estimate_repository_size( + self, + repo_path: str, + exclude_patterns: Optional[List[str]] = None + ) -> tuple[int, int]: + """Estimate repository size (file count and chunk count).""" + return 
self._indexer.estimate_repository_size(repo_path, exclude_patterns) + + def index_repository( + self, + repo_path: str, + workspace: str, + project: str, + branch: str, + commit: str, + exclude_patterns: Optional[List[str]] = None + ) -> IndexStats: + """Index entire repository for a branch using atomic swap strategy.""" + alias_name = self._get_project_collection_name(workspace, project) + return self._indexer.index_repository( + repo_path=repo_path, + workspace=workspace, + project=project, + branch=branch, + commit=commit, + alias_name=alias_name, + exclude_patterns=exclude_patterns + ) + + # File operations + + def update_files( + self, + file_paths: List[str], + repo_base: str, + workspace: str, + project: str, + branch: str, + commit: str + ) -> IndexStats: + """Update specific files in the index (Delete Old -> Insert New).""" + collection_name = self._get_project_collection_name(workspace, project) + return self._file_ops.update_files( + file_paths=file_paths, + repo_base=repo_base, + workspace=workspace, + project=project, + branch=branch, + commit=commit, + collection_name=collection_name + ) + + def delete_files( + self, + file_paths: List[str], + workspace: str, + project: str, + branch: str + ) -> IndexStats: + """Delete specific files from the index for a specific branch.""" + collection_name = self._get_project_collection_name(workspace, project) + return self._file_ops.delete_files( + file_paths=file_paths, + workspace=workspace, + project=project, + branch=branch, + collection_name=collection_name + ) + + # Branch operations + + def delete_branch(self, workspace: str, project: str, branch: str) -> bool: + """Delete all points for a specific branch from the project collection.""" + collection_name = self._get_project_collection_name(workspace, project) + + if not self._collection_manager.collection_exists(collection_name): + if not self._collection_manager.alias_exists(collection_name): + logger.warning(f"Collection {collection_name} does not exist") + return False + + return self._branch_manager.delete_branch_points(collection_name, branch) + + def get_branch_point_count(self, workspace: str, project: str, branch: str) -> int: + """Get the number of points for a specific branch.""" + collection_name = self._get_project_collection_name(workspace, project) + + if not self._collection_manager.collection_exists(collection_name): + if not self._collection_manager.alias_exists(collection_name): + return 0 + + return self._branch_manager.get_branch_point_count(collection_name, branch) + + def get_indexed_branches(self, workspace: str, project: str) -> List[str]: + """Get list of branches that have points in the collection.""" + collection_name = self._get_project_collection_name(workspace, project) + + if not self._collection_manager.collection_exists(collection_name): + if not self._collection_manager.alias_exists(collection_name): + return [] + + return self._branch_manager.get_indexed_branches(collection_name) + + # Index management + + def delete_index(self, workspace: str, project: str, branch: str): + """Delete branch data from project index.""" + if branch and branch != "*": + self.delete_branch(workspace, project, branch) + else: + self.delete_project_index(workspace, project) + + def delete_project_index(self, workspace: str, project: str): + """Delete entire project collection (all branches).""" + collection_name = self._get_project_collection_name(workspace, project) + namespace = make_project_namespace(workspace, project) + + logger.info(f"Deleting entire project index 
for {namespace}") + + try: + if self._collection_manager.alias_exists(collection_name): + actual_collection = self._collection_manager.resolve_alias(collection_name) + self._collection_manager.delete_alias(collection_name) + if actual_collection: + self._collection_manager.delete_collection(actual_collection) + else: + self._collection_manager.delete_collection(collection_name) + logger.info(f"Deleted Qdrant collection: {collection_name}") + except Exception as e: + logger.warning(f"Failed to delete Qdrant collection: {e}") + + # Statistics + + def _get_index_stats(self, workspace: str, project: str, branch: str) -> IndexStats: + """Get statistics about a branch index (backward compatibility).""" + return self._get_branch_index_stats(workspace, project, branch) + + def _get_branch_index_stats(self, workspace: str, project: str, branch: str) -> IndexStats: + """Get statistics about a specific branch within a project collection.""" + collection_name = self._get_project_collection_name(workspace, project) + return self._stats_manager.get_branch_stats( + workspace, project, branch, collection_name + ) + + def _get_project_index_stats(self, workspace: str, project: str) -> IndexStats: + """Get statistics about a project's index (all branches combined).""" + collection_name = self._get_project_collection_name(workspace, project) + return self._stats_manager.get_project_stats(workspace, project, collection_name) + + def list_indices(self) -> List[IndexStats]: + """List all project indices with branch breakdown.""" + return self._stats_manager.list_all_indices( + self._collection_manager.alias_exists + ) + + # Legacy/compatibility methods + + def _ensure_collection_exists(self, collection_name: str): + """Ensure Qdrant collection exists (legacy compatibility).""" + self._collection_manager.ensure_collection_exists(collection_name) + + def _alias_exists(self, alias_name: str) -> bool: + """Check if an alias exists (legacy compatibility).""" + return self._collection_manager.alias_exists(alias_name) + + def _resolve_alias_to_collection(self, alias_name: str) -> Optional[str]: + """Resolve an alias to its collection (legacy compatibility).""" + return self._collection_manager.resolve_alias(alias_name) + + def _generate_point_id( + self, + workspace: str, + project: str, + branch: str, + path: str, + chunk_index: int + ) -> str: + """Generate deterministic point ID (legacy compatibility).""" + return PointOperations.generate_point_id( + workspace, project, branch, path, chunk_index + ) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py new file mode 100644 index 00000000..b9682bf0 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py @@ -0,0 +1,151 @@ +""" +Point operations for embedding and upserting vectors. + +Handles embedding generation, point creation, and batch upsert operations. 
+""" + +import logging +import uuid +from datetime import datetime, timezone +from typing import List, Dict, Tuple + +from llama_index.core.schema import TextNode +from qdrant_client import QdrantClient +from qdrant_client.models import PointStruct + +logger = logging.getLogger(__name__) + + +class PointOperations: + """Handles point embedding and upsert operations.""" + + def __init__(self, client: QdrantClient, embed_model, batch_size: int = 50): + self.client = client + self.embed_model = embed_model + self.batch_size = batch_size + + @staticmethod + def generate_point_id( + workspace: str, + project: str, + branch: str, + path: str, + chunk_index: int + ) -> str: + """Generate deterministic point ID for upsert (same content = same ID = replace).""" + key = f"{workspace}:{project}:{branch}:{path}:{chunk_index}" + return str(uuid.uuid5(uuid.NAMESPACE_DNS, key)) + + def prepare_chunks_for_embedding( + self, + chunks: List[TextNode], + workspace: str, + project: str, + branch: str + ) -> List[Tuple[str, TextNode]]: + """Prepare chunks with deterministic IDs for embedding. + + Returns list of (point_id, chunk) tuples. + """ + # Group chunks by file path + chunks_by_file: Dict[str, List[TextNode]] = {} + for chunk in chunks: + path = chunk.metadata.get("path", "unknown") + if path not in chunks_by_file: + chunks_by_file[path] = [] + chunks_by_file[path].append(chunk) + + # Assign deterministic IDs + chunk_data = [] + for path, file_chunks in chunks_by_file.items(): + for chunk_index, chunk in enumerate(file_chunks): + point_id = self.generate_point_id(workspace, project, branch, path, chunk_index) + chunk.metadata["indexed_at"] = datetime.now(timezone.utc).isoformat() + chunk_data.append((point_id, chunk)) + + return chunk_data + + def embed_and_create_points( + self, + chunk_data: List[Tuple[str, TextNode]] + ) -> List[PointStruct]: + """Embed chunks and create Qdrant points. + + Args: + chunk_data: List of (point_id, chunk) tuples + + Returns: + List of PointStruct ready for upsert + """ + if not chunk_data: + return [] + + # Batch embed all chunks at once + texts_to_embed = [chunk.text for _, chunk in chunk_data] + embeddings = self.embed_model.get_text_embedding_batch(texts_to_embed) + + # Build points with embeddings + points = [] + for (point_id, chunk), embedding in zip(chunk_data, embeddings): + points.append(PointStruct( + id=point_id, + vector=embedding, + payload={ + **chunk.metadata, + "text": chunk.text, + "_node_content": chunk.text, + } + )) + + return points + + def upsert_points( + self, + collection_name: str, + points: List[PointStruct] + ) -> Tuple[int, int]: + """Upsert points to collection in batches. + + Returns: + Tuple of (successful_count, failed_count) + """ + successful = 0 + failed = 0 + + for i in range(0, len(points), self.batch_size): + batch = points[i:i + self.batch_size] + try: + self.client.upsert( + collection_name=collection_name, + points=batch + ) + successful += len(batch) + except Exception as e: + logger.error(f"Failed to upsert batch starting at {i}: {e}") + failed += len(batch) + + return successful, failed + + def process_and_upsert_chunks( + self, + chunks: List[TextNode], + collection_name: str, + workspace: str, + project: str, + branch: str + ) -> Tuple[int, int]: + """Full pipeline: prepare, embed, and upsert chunks. 
+ + Returns: + Tuple of (successful_count, failed_count) + """ + # Prepare chunks with IDs + chunk_data = self.prepare_chunks_for_embedding( + chunks, workspace, project, branch + ) + + # Embed and create points + points = self.embed_and_create_points(chunk_data) + + # Upsert to collection + return self.upsert_points(collection_name, points) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/stats_manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/stats_manager.py new file mode 100644 index 00000000..226f021f --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/stats_manager.py @@ -0,0 +1,156 @@ +""" +Index statistics and metadata operations. +""" + +import logging +from datetime import datetime, timezone +from typing import List, Optional + +from qdrant_client import QdrantClient +from qdrant_client.models import Filter, FieldCondition, MatchValue + +from ...models.config import IndexStats +from ...utils.utils import make_namespace, make_project_namespace + +logger = logging.getLogger(__name__) + + +class StatsManager: + """Manages index statistics and metadata.""" + + def __init__(self, client: QdrantClient, collection_prefix: str): + self.client = client + self.collection_prefix = collection_prefix + + def get_branch_stats( + self, + workspace: str, + project: str, + branch: str, + collection_name: str + ) -> IndexStats: + """Get statistics about a specific branch within a project collection.""" + namespace = make_namespace(workspace, project, branch) + + try: + count_result = self.client.count( + collection_name=collection_name, + count_filter=Filter( + must=[ + FieldCondition( + key="branch", + match=MatchValue(value=branch) + ) + ] + ) + ) + chunk_count = count_result.count + + return IndexStats( + namespace=namespace, + document_count=0, + chunk_count=chunk_count, + last_updated=datetime.now(timezone.utc).isoformat(), + workspace=workspace, + project=project, + branch=branch + ) + except Exception: + return IndexStats( + namespace=namespace, + document_count=0, + chunk_count=0, + last_updated="", + workspace=workspace, + project=project, + branch=branch + ) + + def get_project_stats( + self, + workspace: str, + project: str, + collection_name: str + ) -> IndexStats: + """Get statistics about a project's index (all branches combined).""" + namespace = make_project_namespace(workspace, project) + + try: + collection_info = self.client.get_collection(collection_name) + chunk_count = collection_info.points_count + + return IndexStats( + namespace=namespace, + document_count=0, + chunk_count=chunk_count, + last_updated=datetime.now(timezone.utc).isoformat(), + workspace=workspace, + project=project, + branch="*" + ) + except Exception: + return IndexStats( + namespace=namespace, + document_count=0, + chunk_count=0, + last_updated="", + workspace=workspace, + project=project, + branch="*" + ) + + def list_all_indices(self, alias_checker) -> List[IndexStats]: + """List all project indices with branch breakdown. 
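list_all_indices below distinguishes project-level collections from legacy per-branch ones purely by how many "__"-separated parts follow the collection prefix. A small sketch of that parsing, with a hypothetical prefix and names:

    prefix = "codecrow"                                   # hypothetical qdrant_collection_prefix

    def classify(collection_name: str) -> dict:
        namespace = collection_name[len(prefix) + 1:]     # strip "<prefix>_"
        parts = namespace.split("__")
        if len(parts) == 2:
            return {"format": "project", "workspace": parts[0], "project": parts[1]}
        if len(parts) == 3:
            return {"format": "legacy", "workspace": parts[0], "project": parts[1], "branch": parts[2]}
        return {"format": "unknown", "namespace": namespace}

    print(classify("codecrow_acme__shop"))         # new project-level format
    print(classify("codecrow_acme__shop__main"))   # legacy per-branch format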
+ + Args: + alias_checker: Function to check if name is an alias + """ + indices = [] + collections = self.client.get_collections().collections + + for collection in collections: + if collection.name.startswith(f"{self.collection_prefix}_"): + namespace = collection.name[len(f"{self.collection_prefix}_"):] + parts = namespace.split("__") + + if len(parts) == 2: + # New format: workspace__project + workspace, project = parts + stats = self.get_project_stats( + workspace, project, collection.name + ) + indices.append(stats) + elif len(parts) == 3: + # Legacy format: workspace__project__branch + workspace, project, branch = parts + stats = self.get_branch_stats( + workspace, project, branch, collection.name + ) + indices.append(stats) + + return indices + + def store_metadata( + self, + workspace: str, + project: str, + branch: str, + commit: str, + document_count: int, + chunk_count: int + ) -> None: + """Store/log metadata for an indexing operation.""" + namespace = make_namespace(workspace, project, branch) + + metadata = { + "namespace": namespace, + "workspace": workspace, + "project": project, + "branch": branch, + "commit": commit, + "document_count": document_count, + "chunk_count": chunk_count, + "last_updated": datetime.now(timezone.utc).isoformat(), + } + + logger.info(f"Indexed {namespace}: {document_count} docs, {chunk_count} chunks") diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/semantic_splitter.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/semantic_splitter.py deleted file mode 100644 index 349cdf23..00000000 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/semantic_splitter.py +++ /dev/null @@ -1,455 +0,0 @@ -""" -Semantic Code Splitter - Intelligent code splitting using LangChain's language-aware splitters. - -This module provides smart code chunking that: -1. Uses LangChain's RecursiveCharacterTextSplitter with language-specific separators -2. Supports 25+ programming languages out of the box -3. Enriches metadata with semantic information (function names, imports, etc.) -4. 
Falls back gracefully for unsupported languages -""" - -import re -import hashlib -import logging -from typing import List, Dict, Any, Optional -from dataclasses import dataclass, field -from enum import Enum - -from langchain_text_splitters import RecursiveCharacterTextSplitter, Language -from llama_index.core.schema import Document, TextNode - -logger = logging.getLogger(__name__) - - -class ChunkType(Enum): - """Type of code chunk for semantic understanding""" - CLASS = "class" - FUNCTION = "function" - METHOD = "method" - INTERFACE = "interface" - MODULE = "module" - IMPORTS = "imports" - CONSTANTS = "constants" - DOCUMENTATION = "documentation" - CONFIG = "config" - MIXED = "mixed" - UNKNOWN = "unknown" - - -@dataclass -class CodeBlock: - """Represents a logical block of code""" - content: str - chunk_type: ChunkType - name: Optional[str] = None - parent_name: Optional[str] = None - start_line: int = 0 - end_line: int = 0 - imports: List[str] = field(default_factory=list) - docstring: Optional[str] = None - signature: Optional[str] = None - - -# Map internal language names to LangChain Language enum -LANGUAGE_MAP: Dict[str, Language] = { - 'python': Language.PYTHON, - 'java': Language.JAVA, - 'kotlin': Language.KOTLIN, - 'javascript': Language.JS, - 'typescript': Language.TS, - 'go': Language.GO, - 'rust': Language.RUST, - 'php': Language.PHP, - 'ruby': Language.RUBY, - 'scala': Language.SCALA, - 'swift': Language.SWIFT, - 'c': Language.C, - 'cpp': Language.CPP, - 'csharp': Language.CSHARP, - 'markdown': Language.MARKDOWN, - 'html': Language.HTML, - 'latex': Language.LATEX, - 'rst': Language.RST, - 'lua': Language.LUA, - 'perl': Language.PERL, - 'haskell': Language.HASKELL, - 'solidity': Language.SOL, - 'proto': Language.PROTO, - 'cobol': Language.COBOL, -} - -# Patterns for metadata extraction -METADATA_PATTERNS = { - 'python': { - 'class': re.compile(r'^class\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'^(?:async\s+)?def\s+(\w+)\s*\(', re.MULTILINE), - 'import': re.compile(r'^(?:from\s+[\w.]+\s+)?import\s+.+$', re.MULTILINE), - 'docstring': re.compile(r'"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\''), - }, - 'java': { - 'class': re.compile(r'(?:public\s+|private\s+|protected\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)', re.MULTILINE), - 'interface': re.compile(r'(?:public\s+)?interface\s+(\w+)', re.MULTILINE), - 'method': re.compile(r'(?:public|private|protected)\s+(?:static\s+)?(?:final\s+)?[\w<>,\s]+\s+(\w+)\s*\(', re.MULTILINE), - 'import': re.compile(r'^import\s+[\w.*]+;', re.MULTILINE), - }, - 'javascript': { - 'class': re.compile(r'(?:export\s+)?(?:default\s+)?class\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'(?:export\s+)?(?:async\s+)?function\s*\*?\s*(\w+)\s*\(', re.MULTILINE), - 'arrow': re.compile(r'(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?\([^)]*\)\s*=>', re.MULTILINE), - 'import': re.compile(r'^import\s+.*?from\s+[\'"]([^\'"]+)[\'"]', re.MULTILINE), - }, - 'typescript': { - 'class': re.compile(r'(?:export\s+)?(?:default\s+)?class\s+(\w+)', re.MULTILINE), - 'interface': re.compile(r'(?:export\s+)?interface\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'(?:export\s+)?(?:async\s+)?function\s*\*?\s*(\w+)\s*\(', re.MULTILINE), - 'type': re.compile(r'(?:export\s+)?type\s+(\w+)', re.MULTILINE), - 'import': re.compile(r'^import\s+.*?from\s+[\'"]([^\'"]+)[\'"]', re.MULTILINE), - }, - 'go': { - 'function': re.compile(r'^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(', re.MULTILINE), - 'struct': re.compile(r'^type\s+(\w+)\s+struct\s*\{', re.MULTILINE), - 'interface': 
re.compile(r'^type\s+(\w+)\s+interface\s*\{', re.MULTILINE), - }, - 'rust': { - 'function': re.compile(r'^(?:pub\s+)?(?:async\s+)?fn\s+(\w+)', re.MULTILINE), - 'struct': re.compile(r'^(?:pub\s+)?struct\s+(\w+)', re.MULTILINE), - 'impl': re.compile(r'^impl(?:<[^>]+>)?\s+(?:\w+\s+for\s+)?(\w+)', re.MULTILINE), - 'trait': re.compile(r'^(?:pub\s+)?trait\s+(\w+)', re.MULTILINE), - }, - 'php': { - 'class': re.compile(r'(?:abstract\s+|final\s+)?class\s+(\w+)', re.MULTILINE), - 'interface': re.compile(r'interface\s+(\w+)', re.MULTILINE), - 'function': re.compile(r'(?:public|private|protected|static|\s)*function\s+(\w+)\s*\(', re.MULTILINE), - }, - 'csharp': { - 'class': re.compile(r'(?:public\s+|private\s+|protected\s+)?(?:abstract\s+|sealed\s+)?class\s+(\w+)', re.MULTILINE), - 'interface': re.compile(r'(?:public\s+)?interface\s+(\w+)', re.MULTILINE), - 'method': re.compile(r'(?:public|private|protected)\s+(?:static\s+)?(?:async\s+)?[\w<>,\s]+\s+(\w+)\s*\(', re.MULTILINE), - }, -} - - -class SemanticCodeSplitter: - """ - Intelligent code splitter using LangChain's language-aware text splitters. - - Features: - - Uses LangChain's RecursiveCharacterTextSplitter with language-specific separators - - Supports 25+ programming languages (Python, Java, JS/TS, Go, Rust, PHP, etc.) - - Enriches chunks with semantic metadata (function names, classes, imports) - - Graceful fallback for unsupported languages - """ - - DEFAULT_CHUNK_SIZE = 1500 - DEFAULT_CHUNK_OVERLAP = 200 - DEFAULT_MIN_CHUNK_SIZE = 100 - - def __init__( - self, - max_chunk_size: int = DEFAULT_CHUNK_SIZE, - min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE, - overlap: int = DEFAULT_CHUNK_OVERLAP - ): - self.max_chunk_size = max_chunk_size - self.min_chunk_size = min_chunk_size - self.overlap = overlap - - # Cache splitters for reuse - self._splitter_cache: Dict[str, RecursiveCharacterTextSplitter] = {} - - # Default splitter for unknown languages - self._default_splitter = RecursiveCharacterTextSplitter( - chunk_size=max_chunk_size, - chunk_overlap=overlap, - length_function=len, - is_separator_regex=False, - ) - - @staticmethod - def _make_deterministic_id(namespace: str, path: str, chunk_index: int) -> str: - """Generate deterministic chunk ID for idempotent indexing""" - key = f"{namespace}:{path}:{chunk_index}" - return hashlib.sha256(key.encode()).hexdigest()[:32] - - def _get_splitter(self, language: str) -> RecursiveCharacterTextSplitter: - """Get or create a language-specific splitter""" - if language in self._splitter_cache: - return self._splitter_cache[language] - - lang_enum = LANGUAGE_MAP.get(language.lower()) - - if lang_enum: - splitter = RecursiveCharacterTextSplitter.from_language( - language=lang_enum, - chunk_size=self.max_chunk_size, - chunk_overlap=self.overlap, - ) - self._splitter_cache[language] = splitter - return splitter - - return self._default_splitter - - def split_documents(self, documents: List[Document]) -> List[TextNode]: - """Split documents into semantic chunks with enriched metadata""" - return list(self.iter_split_documents(documents)) - - def iter_split_documents(self, documents: List[Document]): - """Generator that yields chunks one at a time for memory efficiency""" - for doc in documents: - language = doc.metadata.get("language", "text") - path = doc.metadata.get("path", "unknown") - - try: - for node in self._split_document(doc, language): - yield node - except Exception as e: - logger.warning(f"Splitting failed for {path}: {e}, using fallback") - for node in self._fallback_split(doc): - yield node - - 
def _split_document(self, doc: Document, language: str) -> List[TextNode]: - """Split a single document using language-aware splitter""" - text = doc.text - - if not text or not text.strip(): - return [] - - # Get language-specific splitter - splitter = self._get_splitter(language) - - # Split the text - chunks = splitter.split_text(text) - - # Filter empty chunks and convert to nodes with metadata - nodes = [] - text_offset = 0 - - for i, chunk in enumerate(chunks): - if not chunk or not chunk.strip(): - continue - - # Skip very small chunks unless they're standalone - if len(chunk.strip()) < self.min_chunk_size and len(chunks) > 1: - # Try to find and merge with adjacent chunk - continue - - # Calculate approximate line numbers - start_line = text[:text_offset].count('\n') + 1 if text_offset > 0 else 1 - chunk_pos = text.find(chunk, text_offset) - if chunk_pos >= 0: - text_offset = chunk_pos + len(chunk) - end_line = start_line + chunk.count('\n') - - # Extract semantic metadata - metadata = self._extract_metadata(chunk, language, doc.metadata) - metadata.update({ - 'chunk_index': i, - 'total_chunks': len(chunks), - 'start_line': start_line, - 'end_line': end_line, - }) - - chunk_id = self._make_deterministic_id( - metadata.get('namespace', ''), - metadata.get('path', ''), - i - ) - node = TextNode( - id_=chunk_id, - text=chunk, - metadata=metadata - ) - nodes.append(node) - - return nodes - - def _extract_metadata( - self, - chunk: str, - language: str, - base_metadata: Dict[str, Any] - ) -> Dict[str, Any]: - """Extract semantic metadata from a code chunk""" - metadata = dict(base_metadata) - - # Determine chunk type and extract names - chunk_type = ChunkType.MIXED - names = [] - imports = [] - - patterns = METADATA_PATTERNS.get(language.lower(), {}) - - # Check for classes - if 'class' in patterns: - matches = patterns['class'].findall(chunk) - if matches: - chunk_type = ChunkType.CLASS - names.extend(matches) - - # Check for interfaces - if 'interface' in patterns: - matches = patterns['interface'].findall(chunk) - if matches: - chunk_type = ChunkType.INTERFACE - names.extend(matches) - - # Check for functions/methods - if chunk_type == ChunkType.MIXED: - for key in ['function', 'method', 'arrow']: - if key in patterns: - matches = patterns[key].findall(chunk) - if matches: - chunk_type = ChunkType.FUNCTION - names.extend(matches) - break - - # Check for imports - if 'import' in patterns: - import_matches = patterns['import'].findall(chunk) - if import_matches: - imports = import_matches[:10] # Limit - if not names: # Pure import block - chunk_type = ChunkType.IMPORTS - - # Check for documentation files - if language in ('markdown', 'rst', 'text'): - chunk_type = ChunkType.DOCUMENTATION - - # Check for config files - if language in ('json', 'yaml', 'yml', 'toml', 'xml', 'ini'): - chunk_type = ChunkType.CONFIG - - # Extract docstring if present - docstring = self._extract_docstring(chunk, language) - - # Extract function signature - signature = self._extract_signature(chunk, language) - - # Update metadata - metadata['chunk_type'] = chunk_type.value - - if names: - metadata['semantic_names'] = names[:5] # Limit to 5 names - metadata['primary_name'] = names[0] - - if imports: - metadata['imports'] = imports - - if docstring: - metadata['docstring'] = docstring[:500] # Limit size - - if signature: - metadata['signature'] = signature - - return metadata - - def _extract_docstring(self, chunk: str, language: str) -> Optional[str]: - """Extract docstring from code chunk""" - if language == 
'python': - # Python docstrings - match = re.search(r'"""([\s\S]*?)"""|\'\'\'([\s\S]*?)\'\'\'', chunk) - if match: - return (match.group(1) or match.group(2)).strip() - - elif language in ('javascript', 'typescript', 'java', 'csharp', 'php', 'go'): - # JSDoc / JavaDoc style - match = re.search(r'/\*\*([\s\S]*?)\*/', chunk) - if match: - # Clean up the comment - doc = match.group(1) - doc = re.sub(r'^\s*\*\s?', '', doc, flags=re.MULTILINE) - return doc.strip() - - return None - - def _extract_signature(self, chunk: str, language: str) -> Optional[str]: - """Extract function/method signature from code chunk""" - lines = chunk.split('\n') - - for line in lines[:10]: # Check first 10 lines - line = line.strip() - - if language == 'python': - if line.startswith(('def ', 'async def ')): - # Get full signature including multi-line params - sig = line - if ')' not in sig: - # Multi-line signature - idx = lines.index(line.strip()) if line.strip() in lines else -1 - if idx >= 0: - for next_line in lines[idx+1:idx+5]: - sig += ' ' + next_line.strip() - if ')' in next_line: - break - return sig.split(':')[0] + ':' - - elif language in ('java', 'csharp', 'kotlin'): - if any(kw in line for kw in ['public ', 'private ', 'protected ', 'internal ']): - if '(' in line and not line.startswith('//'): - return line.split('{')[0].strip() - - elif language in ('javascript', 'typescript'): - if line.startswith(('function ', 'async function ')): - return line.split('{')[0].strip() - if '=>' in line and '(' in line: - return line.split('=>')[0].strip() + ' =>' - - elif language == 'go': - if line.startswith('func '): - return line.split('{')[0].strip() - - elif language == 'rust': - if line.startswith(('fn ', 'pub fn ', 'async fn ', 'pub async fn ')): - return line.split('{')[0].strip() - - return None - - def _fallback_split(self, doc: Document) -> List[TextNode]: - """Fallback splitting for problematic documents""" - text = doc.text - - if not text or not text.strip(): - return [] - - # Use default splitter - chunks = self._default_splitter.split_text(text) - - nodes = [] - for i, chunk in enumerate(chunks): - if not chunk or not chunk.strip(): - continue - - # Truncate if too large - if len(chunk) > 30000: - chunk = chunk[:30000] - - metadata = dict(doc.metadata) - metadata['chunk_index'] = i - metadata['total_chunks'] = len(chunks) - metadata['chunk_type'] = 'fallback' - - chunk_id = self._make_deterministic_id( - metadata.get('namespace', ''), - metadata.get('path', ''), - i - ) - nodes.append(TextNode( - id_=chunk_id, - text=chunk, - metadata=metadata - )) - - return nodes - - @staticmethod - def get_supported_languages() -> List[str]: - """Return list of supported languages""" - return list(LANGUAGE_MAP.keys()) - - @staticmethod - def get_separators_for_language(language: str) -> Optional[List[str]]: - """Get the separators used for a specific language""" - lang_enum = LANGUAGE_MAP.get(language.lower()) - if lang_enum: - return RecursiveCharacterTextSplitter.get_separators_for_language(lang_enum) - return None diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/__init__.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/__init__.py new file mode 100644 index 00000000..74a91cf0 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/__init__.py @@ -0,0 +1,53 @@ +""" +AST-based code splitter module using Tree-sitter. 
+ +Provides semantic code chunking with: +- Tree-sitter query-based extraction (.scm files) +- Fallback to manual AST traversal +- RecursiveCharacterTextSplitter for oversized chunks +- Rich metadata extraction for RAG +""" + +from .splitter import ASTCodeSplitter, ASTChunk, generate_deterministic_id, compute_file_hash +from .languages import ( + get_language_from_path, + get_treesitter_name, + is_ast_supported, + get_supported_languages, + EXTENSION_TO_LANGUAGE, + AST_SUPPORTED_LANGUAGES, + LANGUAGE_TO_TREESITTER, +) +from .metadata import ContentType, ChunkMetadata, MetadataExtractor +from .tree_parser import TreeSitterParser, get_parser +from .query_runner import QueryRunner, QueryMatch, CapturedNode, get_query_runner + +__all__ = [ + # Main splitter + "ASTCodeSplitter", + "ASTChunk", + "generate_deterministic_id", + "compute_file_hash", + + # Languages + "get_language_from_path", + "get_treesitter_name", + "is_ast_supported", + "get_supported_languages", + "EXTENSION_TO_LANGUAGE", + "AST_SUPPORTED_LANGUAGES", + "LANGUAGE_TO_TREESITTER", + + # Metadata + "ContentType", + "ChunkMetadata", + "MetadataExtractor", + + # Tree-sitter + "TreeSitterParser", + "get_parser", + "QueryRunner", + "QueryMatch", + "CapturedNode", + "get_query_runner", +] diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/languages.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/languages.py new file mode 100644 index 00000000..f4758ff1 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/languages.py @@ -0,0 +1,139 @@ +""" +Language detection and mapping for AST-based code splitting. + +Maps file extensions to tree-sitter language names and LangChain Language enum. +""" + +from pathlib import Path +from typing import Dict, Optional, Set +from langchain_text_splitters import Language + + +# Map file extensions to LangChain Language enum (for RecursiveCharacterTextSplitter fallback) +EXTENSION_TO_LANGUAGE: Dict[str, Language] = { + # Python + '.py': Language.PYTHON, + '.pyw': Language.PYTHON, + '.pyi': Language.PYTHON, + + # Java/JVM + '.java': Language.JAVA, + '.kt': Language.KOTLIN, + '.kts': Language.KOTLIN, + '.scala': Language.SCALA, + + # JavaScript/TypeScript + '.js': Language.JS, + '.jsx': Language.JS, + '.mjs': Language.JS, + '.cjs': Language.JS, + '.ts': Language.TS, + '.tsx': Language.TS, + + # Systems languages + '.go': Language.GO, + '.rs': Language.RUST, + '.c': Language.C, + '.h': Language.C, + '.cpp': Language.CPP, + '.cc': Language.CPP, + '.cxx': Language.CPP, + '.hpp': Language.CPP, + '.hxx': Language.CPP, + '.cs': Language.CSHARP, + + # Web/Scripting + '.php': Language.PHP, + '.phtml': Language.PHP, + '.php3': Language.PHP, + '.php4': Language.PHP, + '.php5': Language.PHP, + '.phps': Language.PHP, + '.inc': Language.PHP, + '.rb': Language.RUBY, + '.erb': Language.RUBY, + '.lua': Language.LUA, + '.pl': Language.PERL, + '.pm': Language.PERL, + '.swift': Language.SWIFT, + + # Markup/Config + '.md': Language.MARKDOWN, + '.markdown': Language.MARKDOWN, + '.html': Language.HTML, + '.htm': Language.HTML, + '.rst': Language.RST, + '.tex': Language.LATEX, + '.proto': Language.PROTO, + '.sol': Language.SOL, + '.hs': Language.HASKELL, + '.cob': Language.COBOL, + '.cbl': Language.COBOL, + '.xml': Language.HTML, +} + +# Languages that support full AST parsing via tree-sitter +AST_SUPPORTED_LANGUAGES: Set[Language] = { + Language.PYTHON, Language.JAVA, Language.KOTLIN, Language.JS, Language.TS, + Language.GO, Language.RUST, Language.C, 
Language.CPP, Language.CSHARP, + Language.PHP, Language.RUBY, Language.SCALA, Language.LUA, Language.PERL, + Language.SWIFT, Language.HASKELL, Language.COBOL +} + +# Map LangChain Language enum to tree-sitter language name +LANGUAGE_TO_TREESITTER: Dict[Language, str] = { + Language.PYTHON: 'python', + Language.JAVA: 'java', + Language.KOTLIN: 'kotlin', + Language.JS: 'javascript', + Language.TS: 'typescript', + Language.GO: 'go', + Language.RUST: 'rust', + Language.C: 'c', + Language.CPP: 'cpp', + Language.CSHARP: 'c_sharp', + Language.PHP: 'php', + Language.RUBY: 'ruby', + Language.SCALA: 'scala', + Language.LUA: 'lua', + Language.PERL: 'perl', + Language.SWIFT: 'swift', + Language.HASKELL: 'haskell', +} + +# Map tree-sitter language name to module info: (module_name, function_name) +TREESITTER_MODULES: Dict[str, tuple] = { + 'python': ('tree_sitter_python', 'language'), + 'java': ('tree_sitter_java', 'language'), + 'javascript': ('tree_sitter_javascript', 'language'), + 'typescript': ('tree_sitter_typescript', 'language_typescript'), + 'go': ('tree_sitter_go', 'language'), + 'rust': ('tree_sitter_rust', 'language'), + 'c': ('tree_sitter_c', 'language'), + 'cpp': ('tree_sitter_cpp', 'language'), + 'c_sharp': ('tree_sitter_c_sharp', 'language'), + 'ruby': ('tree_sitter_ruby', 'language'), + 'php': ('tree_sitter_php', 'language_php'), +} + + +def get_language_from_path(path: str) -> Optional[Language]: + """Determine LangChain Language enum from file path.""" + ext = Path(path).suffix.lower() + return EXTENSION_TO_LANGUAGE.get(ext) + + +def get_treesitter_name(language: Language) -> Optional[str]: + """Get tree-sitter language name from LangChain Language enum.""" + return LANGUAGE_TO_TREESITTER.get(language) + + +def is_ast_supported(path: str) -> bool: + """Check if AST parsing is supported for a file.""" + language = get_language_from_path(path) + return language is not None and language in AST_SUPPORTED_LANGUAGES + + +def get_supported_languages() -> list: + """Return list of languages with AST support.""" + return list(LANGUAGE_TO_TREESITTER.values()) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/metadata.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/metadata.py new file mode 100644 index 00000000..6a2fccd7 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/metadata.py @@ -0,0 +1,339 @@ +""" +Metadata extraction from AST chunks. + +Extracts semantic metadata like docstrings, signatures, inheritance info +from parsed code chunks for improved RAG retrieval. 
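The language helpers introduced in languages.py above are plain dictionary lookups over the extension map, so they can be sanity-checked without tree-sitter installed. A minimal illustrative sketch, not part of the patch (module path taken from the file layout in this patch; the file names are made up):

    from rag_pipeline.core.splitter.languages import (
        get_language_from_path,
        get_treesitter_name,
        is_ast_supported,
    )

    # '.py' maps to Language.PYTHON, which is AST-supported and has a
    # tree-sitter name, so such a file takes the query/AST path.
    assert is_ast_supported("src/app/main.py")
    assert get_treesitter_name(get_language_from_path("src/app/main.py")) == "python"

    # '.md' maps to Language.MARKDOWN, which is not in AST_SUPPORTED_LANGUAGES,
    # so the splitter falls back to RecursiveCharacterTextSplitter.
    assert not is_ast_supported("docs/README.md")

    # Unknown extensions return None and also go down the fallback path.
    assert get_language_from_path("data/records.bin") is None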
+""" + +import re +import logging +from typing import Dict, Any, List, Optional +from dataclasses import dataclass, field +from enum import Enum + +logger = logging.getLogger(__name__) + + +class ContentType(Enum): + """Content type as determined by AST parsing.""" + FUNCTIONS_CLASSES = "functions_classes" + SIMPLIFIED_CODE = "simplified_code" + FALLBACK = "fallback" + OVERSIZED_SPLIT = "oversized_split" + + +@dataclass +class ChunkMetadata: + """Structured metadata for a code chunk.""" + content_type: ContentType + language: str + path: str + semantic_names: List[str] = field(default_factory=list) + parent_context: List[str] = field(default_factory=list) + docstring: Optional[str] = None + signature: Optional[str] = None + start_line: int = 0 + end_line: int = 0 + node_type: Optional[str] = None + # Class-level metadata + extends: List[str] = field(default_factory=list) + implements: List[str] = field(default_factory=list) + # File-level metadata + imports: List[str] = field(default_factory=list) + namespace: Optional[str] = None + + +class MetadataExtractor: + """ + Extract semantic metadata from code chunks. + + Uses both AST-derived information and regex fallbacks for + comprehensive metadata extraction. + """ + + # Comment prefixes by language + COMMENT_PREFIX: Dict[str, str] = { + 'python': '#', + 'javascript': '//', + 'typescript': '//', + 'java': '//', + 'kotlin': '//', + 'go': '//', + 'rust': '//', + 'c': '//', + 'cpp': '//', + 'c_sharp': '//', + 'php': '//', + 'ruby': '#', + 'lua': '--', + 'perl': '#', + 'scala': '//', + } + + def extract_docstring(self, content: str, language: str) -> Optional[str]: + """Extract docstring from code chunk.""" + if language == 'python': + match = re.search(r'"""([\s\S]*?)"""|\'\'\'([\s\S]*?)\'\'\'', content) + if match: + return (match.group(1) or match.group(2)).strip() + + elif language in ('javascript', 'typescript', 'java', 'kotlin', + 'c_sharp', 'php', 'go', 'scala', 'c', 'cpp'): + # JSDoc / JavaDoc style + match = re.search(r'/\*\*([\s\S]*?)\*/', content) + if match: + doc = match.group(1) + doc = re.sub(r'^\s*\*\s?', '', doc, flags=re.MULTILINE) + return doc.strip() + + elif language == 'rust': + # Rust doc comments + lines = [] + for line in content.split('\n'): + stripped = line.strip() + if stripped.startswith('///'): + lines.append(stripped[3:].strip()) + elif stripped.startswith('//!'): + lines.append(stripped[3:].strip()) + elif lines: + break + if lines: + return '\n'.join(lines) + + return None + + def extract_signature(self, content: str, language: str) -> Optional[str]: + """Extract function/method signature from code chunk.""" + lines = content.split('\n') + + for line in lines[:15]: + line = line.strip() + + if language == 'python': + if line.startswith(('def ', 'async def ', 'class ')): + sig = line + if line.startswith('class ') and ':' in line: + return line.split(':')[0] + ':' + if ')' not in sig and ':' not in sig: + idx = next((i for i, l in enumerate(lines) if l.strip() == line), -1) + if idx >= 0: + for next_line in lines[idx+1:idx+5]: + sig += ' ' + next_line.strip() + if ')' in next_line: + break + if ':' in sig: + return sig.split(':')[0] + ':' + return sig + + elif language in ('java', 'kotlin', 'c_sharp'): + if any(kw in line for kw in ['public ', 'private ', 'protected ', 'internal ', 'fun ']): + if '(' in line and not line.startswith('//'): + return line.split('{')[0].strip() + + elif language in ('javascript', 'typescript'): + if line.startswith(('function ', 'async function ', 'class ')): + return 
line.split('{')[0].strip() + if '=>' in line and '(' in line: + return line.split('=>')[0].strip() + ' =>' + + elif language == 'go': + if line.startswith('func ') or line.startswith('type '): + return line.split('{')[0].strip() + + elif language == 'rust': + if line.startswith(('fn ', 'pub fn ', 'async fn ', 'pub async fn ', + 'impl ', 'struct ', 'trait ', 'enum ')): + return line.split('{')[0].strip() + + elif language == 'php': + if 'function ' in line and '(' in line: + return line.split('{')[0].strip() + if line.startswith('class ') or line.startswith('interface '): + return line.split('{')[0].strip() + + return None + + def extract_names_from_content(self, content: str, language: str) -> List[str]: + """Extract semantic names (function/class names) using regex patterns.""" + patterns = self._get_name_patterns(language) + names = [] + + for pattern in patterns: + matches = pattern.findall(content) + names.extend(matches) + + # Deduplicate while preserving order + seen = set() + unique_names = [] + for name in names: + if name not in seen: + seen.add(name) + unique_names.append(name) + + return unique_names[:10] # Limit to 10 names + + def _get_name_patterns(self, language: str) -> List[re.Pattern]: + """Get regex patterns for extracting names by language.""" + patterns = { + 'python': [ + re.compile(r'^class\s+(\w+)', re.MULTILINE), + re.compile(r'^(?:async\s+)?def\s+(\w+)\s*\(', re.MULTILINE), + ], + 'java': [ + re.compile(r'(?:public\s+|private\s+|protected\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)', re.MULTILINE), + re.compile(r'(?:public\s+)?interface\s+(\w+)', re.MULTILINE), + re.compile(r'(?:public|private|protected)\s+(?:static\s+)?[\w<>,\s]+\s+(\w+)\s*\(', re.MULTILINE), + ], + 'javascript': [ + re.compile(r'(?:export\s+)?(?:default\s+)?class\s+(\w+)', re.MULTILINE), + re.compile(r'(?:export\s+)?(?:async\s+)?function\s*\*?\s*(\w+)\s*\(', re.MULTILINE), + re.compile(r'(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?\([^)]*\)\s*=>', re.MULTILINE), + ], + 'typescript': [ + re.compile(r'(?:export\s+)?(?:default\s+)?class\s+(\w+)', re.MULTILINE), + re.compile(r'(?:export\s+)?interface\s+(\w+)', re.MULTILINE), + re.compile(r'(?:export\s+)?(?:async\s+)?function\s*\*?\s*(\w+)\s*\(', re.MULTILINE), + re.compile(r'(?:export\s+)?type\s+(\w+)', re.MULTILINE), + ], + 'go': [ + re.compile(r'^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(', re.MULTILINE), + re.compile(r'^type\s+(\w+)\s+(?:struct|interface)\s*\{', re.MULTILINE), + ], + 'rust': [ + re.compile(r'^(?:pub\s+)?(?:async\s+)?fn\s+(\w+)', re.MULTILINE), + re.compile(r'^(?:pub\s+)?struct\s+(\w+)', re.MULTILINE), + re.compile(r'^(?:pub\s+)?trait\s+(\w+)', re.MULTILINE), + re.compile(r'^(?:pub\s+)?enum\s+(\w+)', re.MULTILINE), + ], + 'php': [ + re.compile(r'(?:abstract\s+|final\s+)?class\s+(\w+)', re.MULTILINE), + re.compile(r'interface\s+(\w+)', re.MULTILINE), + re.compile(r'(?:public|private|protected|static|\s)*function\s+(\w+)\s*\(', re.MULTILINE), + ], + 'c_sharp': [ + re.compile(r'(?:public\s+|private\s+|internal\s+)?(?:abstract\s+|sealed\s+)?class\s+(\w+)', re.MULTILINE), + re.compile(r'(?:public\s+)?interface\s+(\w+)', re.MULTILINE), + re.compile(r'(?:public|private|protected|internal)\s+(?:static\s+)?[\w<>,\s]+\s+(\w+)\s*\(', re.MULTILINE), + ], + } + return patterns.get(language, []) + + def extract_inheritance(self, content: str, language: str) -> Dict[str, List[str]]: + """Extract inheritance information (extends, implements).""" + result = {'extends': [], 'implements': [], 'imports': []} + + patterns = 
self._get_inheritance_patterns(language) + + if 'extends' in patterns: + match = patterns['extends'].search(content) + if match: + extends = match.group(1).strip() + result['extends'] = [e.strip() for e in extends.split(',') if e.strip()] + + if 'implements' in patterns: + match = patterns['implements'].search(content) + if match: + implements = match.group(1).strip() + result['implements'] = [i.strip() for i in implements.split(',') if i.strip()] + + for key in ('import', 'use', 'using', 'require'): + if key in patterns: + matches = patterns[key].findall(content) + for m in matches: + if isinstance(m, tuple): + result['imports'].extend([x.strip() for x in m if x and x.strip()]) + else: + result['imports'].append(m.strip()) + + # Limit imports + result['imports'] = result['imports'][:20] + + return result + + def _get_inheritance_patterns(self, language: str) -> Dict[str, re.Pattern]: + """Get regex patterns for inheritance extraction.""" + patterns = { + 'python': { + 'extends': re.compile(r'class\s+\w+\s*\(\s*([\w.,\s]+)\s*\)\s*:', re.MULTILINE), + 'import': re.compile(r'^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s*]+)', re.MULTILINE), + }, + 'java': { + 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w.]+)', re.MULTILINE), + 'implements': re.compile(r'class\s+\w+(?:\s+extends\s+[\w.]+)?\s+implements\s+([\w.,\s]+)', re.MULTILINE), + 'import': re.compile(r'^import\s+([\w.]+(?:\.\*)?);', re.MULTILINE), + }, + 'typescript': { + 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w.]+)', re.MULTILINE), + 'implements': re.compile(r'class\s+\w+(?:\s+extends\s+[\w.]+)?\s+implements\s+([\w.,\s]+)', re.MULTILINE), + 'import': re.compile(r'^import\s+(?:[\w{},\s*]+\s+from\s+)?["\']([^"\']+)["\'];?', re.MULTILINE), + }, + 'javascript': { + 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w.]+)', re.MULTILINE), + 'import': re.compile(r'^import\s+(?:[\w{},\s*]+\s+from\s+)?["\']([^"\']+)["\'];?', re.MULTILINE), + 'require': re.compile(r'require\s*\(\s*["\']([^"\']+)["\']\s*\)', re.MULTILINE), + }, + 'php': { + 'extends': re.compile(r'class\s+\w+\s+extends\s+([\w\\]+)', re.MULTILINE), + 'implements': re.compile(r'class\s+\w+(?:\s+extends\s+[\w\\]+)?\s+implements\s+([\w\\,\s]+)', re.MULTILINE), + 'use': re.compile(r'^use\s+([\w\\]+)(?:\s+as\s+\w+)?;', re.MULTILINE), + }, + 'c_sharp': { + 'extends': re.compile(r'class\s+\w+\s*:\s*([\w.]+)', re.MULTILINE), + 'using': re.compile(r'^using\s+([\w.]+);', re.MULTILINE), + }, + 'go': { + 'import': re.compile(r'^import\s+(?:\(\s*)?"([^"]+)"', re.MULTILINE), + }, + 'rust': { + 'use': re.compile(r'^use\s+([\w:]+(?:::\{[^}]+\})?);', re.MULTILINE), + }, + } + return patterns.get(language, {}) + + def get_comment_prefix(self, language: str) -> str: + """Get comment prefix for a language.""" + return self.COMMENT_PREFIX.get(language, '//') + + def build_metadata_dict( + self, + chunk_metadata: ChunkMetadata, + base_metadata: Dict[str, Any] + ) -> Dict[str, Any]: + """Build final metadata dictionary from ChunkMetadata.""" + metadata = dict(base_metadata) + + metadata['content_type'] = chunk_metadata.content_type.value + metadata['node_type'] = chunk_metadata.node_type + metadata['start_line'] = chunk_metadata.start_line + metadata['end_line'] = chunk_metadata.end_line + + if chunk_metadata.parent_context: + metadata['parent_context'] = chunk_metadata.parent_context + metadata['parent_class'] = chunk_metadata.parent_context[-1] + full_path_parts = chunk_metadata.parent_context + chunk_metadata.semantic_names[:1] + metadata['full_path'] = '.'.join(full_path_parts) + + if 
chunk_metadata.semantic_names: + metadata['semantic_names'] = chunk_metadata.semantic_names + metadata['primary_name'] = chunk_metadata.semantic_names[0] + + if chunk_metadata.docstring: + metadata['docstring'] = chunk_metadata.docstring[:500] + + if chunk_metadata.signature: + metadata['signature'] = chunk_metadata.signature + + if chunk_metadata.extends: + metadata['extends'] = chunk_metadata.extends + metadata['parent_types'] = chunk_metadata.extends + + if chunk_metadata.implements: + metadata['implements'] = chunk_metadata.implements + + if chunk_metadata.imports: + metadata['imports'] = chunk_metadata.imports + + if chunk_metadata.namespace: + metadata['namespace'] = chunk_metadata.namespace + + return metadata diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/c_sharp.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/c_sharp.scm new file mode 100644 index 00000000..4a6d7270 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/c_sharp.scm @@ -0,0 +1,56 @@ +; C# tree-sitter queries for AST-based code splitting + +; Using directives +(using_directive) @using + +; Namespace declarations +(namespace_declaration + name: (identifier) @name) @definition.namespace + +(file_scoped_namespace_declaration + name: (identifier) @name) @definition.namespace + +; Class declarations +(class_declaration + name: (identifier) @name) @definition.class + +; Struct declarations +(struct_declaration + name: (identifier) @name) @definition.struct + +; Interface declarations +(interface_declaration + name: (identifier) @name) @definition.interface + +; Enum declarations +(enum_declaration + name: (identifier) @name) @definition.enum + +; Record declarations +(record_declaration + name: (identifier) @name) @definition.record + +; Delegate declarations +(delegate_declaration + name: (identifier) @name) @definition.delegate + +; Method declarations +(method_declaration + name: (identifier) @name) @definition.method + +; Constructor declarations +(constructor_declaration + name: (identifier) @name) @definition.constructor + +; Property declarations +(property_declaration + name: (identifier) @name) @definition.property + +; Field declarations +(field_declaration) @definition.field + +; Event declarations +(event_declaration) @definition.event + +; Attributes +(attribute) @attribute diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/go.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/go.scm new file mode 100644 index 00000000..87bf86ff --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/go.scm @@ -0,0 +1,26 @@ +; Go tree-sitter queries for AST-based code splitting + +; Package clause +(package_clause) @package + +; Import declarations +(import_declaration) @import + +; Type declarations (struct, interface, type alias) +(type_declaration + (type_spec + name: (type_identifier) @name)) @definition.type + +; Function declarations +(function_declaration + name: (identifier) @name) @definition.function + +; Method declarations +(method_declaration + name: (field_identifier) @name) @definition.method + +; Variable declarations +(var_declaration) @definition.variable + +; Constant declarations +(const_declaration) @definition.const diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/java.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/java.scm new file mode 100644 index 
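As a quick illustration of the regex-based extractor defined in metadata.py above (a sketch for reviewers, not part of the patch; the expected values follow from the patterns shown):

    from rag_pipeline.core.splitter.metadata import MetadataExtractor

    src = '''
from app.db import Session

class UserRepo(BaseRepo):
    """Reads and writes User rows."""

    def find_by_id(self, user_id: int):
        return None
'''

    ex = MetadataExtractor()
    ex.extract_docstring(src, "python")                # 'Reads and writes User rows.'
    ex.extract_signature(src, "python")                # 'class UserRepo(BaseRepo):'
    ex.extract_names_from_content(src, "python")       # ['UserRepo'] -- the ^-anchored
                                                       # def pattern only sees top-level defs
    ex.extract_inheritance(src, "python")["extends"]   # ['BaseRepo']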
00000000..f9e0bdc9 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/java.scm @@ -0,0 +1,45 @@ +; Java tree-sitter queries for AST-based code splitting + +; Package declaration +(package_declaration) @package + +; Import declarations +(import_declaration) @import + +; Class declarations +(class_declaration + name: (identifier) @name) @definition.class + +; Interface declarations +(interface_declaration + name: (identifier) @name) @definition.interface + +; Enum declarations +(enum_declaration + name: (identifier) @name) @definition.enum + +; Record declarations (Java 14+) +(record_declaration + name: (identifier) @name) @definition.record + +; Annotation type declarations +(annotation_type_declaration + name: (identifier) @name) @definition.annotation + +; Method declarations +(method_declaration + name: (identifier) @name) @definition.method + +; Constructor declarations +(constructor_declaration + name: (identifier) @name) @definition.constructor + +; Field declarations +(field_declaration) @definition.field + +; Annotations (for metadata) +(marker_annotation + name: (identifier) @name) @annotation + +(annotation + name: (identifier) @name) @annotation diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/javascript.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/javascript.scm new file mode 100644 index 00000000..dd8c3561 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/javascript.scm @@ -0,0 +1,42 @@ +; JavaScript/TypeScript tree-sitter queries for AST-based code splitting + +; Import statements +(import_statement) @import + +; Export statements +(export_statement) @export + +; Class declarations +(class_declaration + name: (identifier) @name) @definition.class + +; Function declarations +(function_declaration + name: (identifier) @name) @definition.function + +; Arrow functions assigned to variables +(lexical_declaration + (variable_declarator + name: (identifier) @name + value: (arrow_function))) @definition.function + +; Generator functions +(generator_function_declaration + name: (identifier) @name) @definition.function + +; Method definitions (inside class body) +(method_definition + name: (property_identifier) @name) @definition.method + +; Variable declarations (module-level) +(lexical_declaration) @definition.variable + +(variable_declaration) @definition.variable + +; Interface declarations (TypeScript) +(interface_declaration + name: (type_identifier) @name) @definition.interface + +; Type alias declarations (TypeScript) +(type_alias_declaration + name: (type_identifier) @name) @definition.type diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/php.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/php.scm new file mode 100644 index 00000000..52c29810 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/php.scm @@ -0,0 +1,40 @@ +; PHP tree-sitter queries for AST-based code splitting + +; Namespace definition +(namespace_definition) @namespace + +; Use statements +(namespace_use_declaration) @use + +; Class declarations +(class_declaration + name: (name) @name) @definition.class + +; Interface declarations +(interface_declaration + name: (name) @name) @definition.interface + +; Trait declarations +(trait_declaration + name: (name) @name) @definition.trait + +; Enum declarations (PHP 8.1+) +(enum_declaration + name: (name) @name) @definition.enum + +; 
Function definitions +(function_definition + name: (name) @name) @definition.function + +; Method declarations +(method_declaration + name: (name) @name) @definition.method + +; Property declarations +(property_declaration) @definition.property + +; Const declarations +(const_declaration) @definition.const + +; Attributes (PHP 8.0+) +(attribute) @attribute diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/python.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/python.scm new file mode 100644 index 00000000..6530ad79 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/python.scm @@ -0,0 +1,28 @@ +; Python tree-sitter queries for AST-based code splitting + +; Import statements +(import_statement) @import + +(import_from_statement) @import + +; Class definitions +(class_definition + name: (identifier) @name) @definition.class + +; Function definitions +(function_definition + name: (identifier) @name) @definition.function + +; Decorated definitions (class or function with decorators) +(decorated_definition) @definition.decorated + +; Decorators +(decorator) @decorator + +; Assignment statements (module-level constants) +(assignment + left: (identifier) @name) @definition.assignment + +; Type alias (Python 3.12+) +(type_alias_statement + name: (type) @name) @definition.type_alias diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/rust.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/rust.scm new file mode 100644 index 00000000..b3315f6a --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/rust.scm @@ -0,0 +1,46 @@ +; Rust tree-sitter queries for AST-based code splitting + +; Use declarations +(use_declaration) @use + +; Module declarations +(mod_item + name: (identifier) @name) @definition.module + +; Struct definitions +(struct_item + name: (type_identifier) @name) @definition.struct + +; Enum definitions +(enum_item + name: (type_identifier) @name) @definition.enum + +; Trait definitions +(trait_item + name: (type_identifier) @name) @definition.trait + +; Implementation blocks +(impl_item) @definition.impl + +; Function definitions +(function_item + name: (identifier) @name) @definition.function + +; Type alias +(type_item + name: (type_identifier) @name) @definition.type + +; Constant definitions +(const_item + name: (identifier) @name) @definition.const + +; Static definitions +(static_item + name: (identifier) @name) @definition.static + +; Macro definitions +(macro_definition + name: (identifier) @name) @definition.macro + +; Attributes +(attribute_item) @attribute diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/typescript.scm b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/typescript.scm new file mode 100644 index 00000000..71807c96 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/queries/typescript.scm @@ -0,0 +1,52 @@ +; TypeScript tree-sitter queries for AST-based code splitting +; Uses same patterns as JavaScript plus TypeScript-specific nodes + +; Import statements +(import_statement) @import + +; Export statements +(export_statement) @export + +; Class declarations +(class_declaration + name: (type_identifier) @name) @definition.class + +; Abstract class declarations +(abstract_class_declaration + name: (type_identifier) @name) @definition.class + +; Function declarations +(function_declaration + name: 
(identifier) @name) @definition.function + +; Arrow functions assigned to variables +(lexical_declaration + (variable_declarator + name: (identifier) @name + value: (arrow_function))) @definition.function + +; Method definitions (inside class body) +(method_definition + name: (property_identifier) @name) @definition.method + +; Interface declarations +(interface_declaration + name: (type_identifier) @name) @definition.interface + +; Type alias declarations +(type_alias_declaration + name: (type_identifier) @name) @definition.type + +; Enum declarations +(enum_declaration + name: (identifier) @name) @definition.enum + +; Module declarations +(module + name: (identifier) @name) @definition.module + +; Variable declarations (module-level) +(lexical_declaration) @definition.variable + +; Ambient declarations +(ambient_declaration) @definition.ambient diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/query_runner.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/query_runner.py new file mode 100644 index 00000000..47da9f26 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/query_runner.py @@ -0,0 +1,360 @@ +""" +Tree-sitter query runner using custom query files with built-in fallback. + +Prefers custom .scm query files for rich metadata extraction (extends, implements, imports), +falling back to built-in TAGS_QUERY only when custom query is unavailable. +""" + +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Any, Optional + +from .tree_parser import get_parser +from .languages import TREESITTER_MODULES + +logger = logging.getLogger(__name__) + +# Directory containing custom .scm query files +QUERIES_DIR = Path(__file__).parent / "queries" + +# Languages that have built-in TAGS_QUERY (used as fallback only) +LANGUAGES_WITH_BUILTIN_TAGS = {'python', 'java', 'javascript', 'go', 'rust', 'php'} + +# Languages with custom .scm files for rich metadata (extends, implements, imports) +LANGUAGES_WITH_CUSTOM_QUERY = { + 'python', 'java', 'javascript', 'typescript', 'c_sharp', 'go', 'rust', 'php' +} + + +@dataclass +class CapturedNode: + """Represents a captured AST node from a query.""" + name: str # Capture name (e.g., 'function.name', 'class.body') + text: str # Node text content + start_byte: int + end_byte: int + start_point: tuple # (row, column) + end_point: tuple + node_type: str # Tree-sitter node type + + @property + def start_line(self) -> int: + return self.start_point[0] + 1 # Convert to 1-based + + @property + def end_line(self) -> int: + return self.end_point[0] + 1 + + +@dataclass +class QueryMatch: + """A complete match from a query pattern.""" + pattern_name: str # e.g., 'function', 'class', 'import' + captures: Dict[str, CapturedNode] = field(default_factory=dict) + + def get(self, capture_name: str) -> Optional[CapturedNode]: + """Get a captured node by name.""" + return self.captures.get(capture_name) + + @property + def full_text(self) -> Optional[str]: + """Get the full text of the main capture (pattern_name without suffix).""" + main_capture = self.captures.get(self.pattern_name) + return main_capture.text if main_capture else None + + +class QueryRunner: + """ + Executes tree-sitter queries using custom .scm files with built-in fallback. + + Strategy: + 1. Prefer custom .scm files for rich metadata (extends, implements, imports, decorators) + 2. 
Fall back to built-in TAGS_QUERY only when no custom query exists + + Custom queries capture: @class.extends, @class.implements, @import, @decorator, + @method.visibility, @function.return_type, etc. + + Built-in TAGS_QUERY only captures: @definition.function, @definition.class, @name, @doc + """ + + def __init__(self): + self._query_cache: Dict[str, Any] = {} # lang -> compiled query + self._scm_cache: Dict[str, str] = {} # lang -> raw scm string + self._parser = get_parser() + + def _get_builtin_tags_query(self, lang_name: str) -> Optional[str]: + """Get built-in TAGS_QUERY from language package if available.""" + if lang_name not in LANGUAGES_WITH_BUILTIN_TAGS: + return None + + lang_info = TREESITTER_MODULES.get(lang_name) + if not lang_info: + return None + + module_name = lang_info[0] + try: + import importlib + lang_module = importlib.import_module(module_name) + tags_query = getattr(lang_module, 'TAGS_QUERY', None) + if tags_query: + logger.debug(f"Using built-in TAGS_QUERY for {lang_name}") + return tags_query + except (ImportError, AttributeError) as e: + logger.debug(f"Could not load built-in query for {lang_name}: {e}") + + return None + + def _load_custom_query_file(self, lang_name: str) -> Optional[str]: + """Load custom .scm query file for languages without built-in queries.""" + if lang_name in self._scm_cache: + return self._scm_cache[lang_name] + + query_file = QUERIES_DIR / f"{lang_name}.scm" + + if not query_file.exists(): + logger.debug(f"No custom query file for {lang_name}") + return None + + try: + scm_content = query_file.read_text(encoding='utf-8') + self._scm_cache[lang_name] = scm_content + logger.debug(f"Loaded custom query file for {lang_name}") + return scm_content + except Exception as e: + logger.warning(f"Failed to load query file {query_file}: {e}") + return None + + def _get_query_string(self, lang_name: str) -> Optional[str]: + """Get query string - custom first, then built-in fallback.""" + # Prefer custom .scm for rich metadata (extends, implements, imports) + custom = self._load_custom_query_file(lang_name) + if custom: + return custom + + # Fall back to built-in TAGS_QUERY (limited metadata) + return self._get_builtin_tags_query(lang_name) + + def _try_compile_query(self, lang_name: str, scm_content: str, language: Any) -> Optional[Any]: + """Try to compile a query string, returning None on failure.""" + try: + from tree_sitter import Query + return Query(language, scm_content) + except Exception as e: + logger.debug(f"Query compilation failed for {lang_name}: {e}") + return None + + def _get_compiled_query(self, lang_name: str) -> Optional[Any]: + """Get or compile the query for a language with fallback.""" + if lang_name in self._query_cache: + return self._query_cache[lang_name] + + language = self._parser.get_language(lang_name) + if not language: + return None + + # Try custom .scm first + custom_scm = self._load_custom_query_file(lang_name) + if custom_scm: + query = self._try_compile_query(lang_name, custom_scm, language) + if query: + logger.debug(f"Using custom query for {lang_name}") + self._query_cache[lang_name] = query + return query + else: + logger.debug(f"Custom query failed for {lang_name}, trying built-in") + + # Fallback to built-in TAGS_QUERY + builtin_scm = self._get_builtin_tags_query(lang_name) + if builtin_scm: + query = self._try_compile_query(lang_name, builtin_scm, language) + if query: + logger.debug(f"Using built-in TAGS_QUERY for {lang_name}") + self._query_cache[lang_name] = query + return query + + logger.debug(f"No 
working query available for {lang_name}") + return None + + def run_query( + self, + source_code: str, + lang_name: str, + tree: Optional[Any] = None + ) -> List[QueryMatch]: + """ + Run the query for a language and return all matches. + + Args: + source_code: Source code string + lang_name: Tree-sitter language name + tree: Optional pre-parsed tree (will parse if not provided) + + Returns: + List of QueryMatch objects with captured nodes + """ + query = self._get_compiled_query(lang_name) + if not query: + return [] + + if tree is None: + tree = self._parser.parse(source_code, lang_name) + if tree is None: + return [] + + source_bytes = source_code.encode('utf-8') + + try: + # Use QueryCursor.matches() for pattern-grouped results + # Each match is (pattern_id, {capture_name: [nodes]}) + from tree_sitter import QueryCursor + cursor = QueryCursor(query) + raw_matches = list(cursor.matches(tree.root_node)) + except Exception as e: + logger.warning(f"Query execution failed for {lang_name}: {e}") + return [] + + results: List[QueryMatch] = [] + + for pattern_id, captures_dict in raw_matches: + # Determine pattern type from captures + # Built-in: @definition.function, @definition.class, @name + # Custom: @function, @class, @function.name + + pattern_name = None + main_node = None + name_node = None + doc_node = None + + for capture_name, nodes in captures_dict.items(): + if not nodes: + continue + node = nodes[0] # Take first node for each capture + + # Built-in definition captures + if capture_name.startswith('definition.'): + pattern_name = capture_name[len('definition.'):] + main_node = node + # Built-in @name capture (associated with this pattern) + elif capture_name == 'name': + name_node = node + # Built-in @doc capture + elif capture_name == 'doc': + doc_node = node + # Skip reference captures + elif capture_name.startswith('reference.'): + continue + # Custom query captures: @function, @class + elif '.' 
not in capture_name: + pattern_name = capture_name + main_node = node + + # Skip if no definition pattern found + if not pattern_name or not main_node: + continue + + # Build the QueryMatch + match = QueryMatch(pattern_name=pattern_name) + + # Add main capture + match.captures[pattern_name] = CapturedNode( + name=pattern_name, + text=source_bytes[main_node.start_byte:main_node.end_byte].decode('utf-8', errors='replace'), + start_byte=main_node.start_byte, + end_byte=main_node.end_byte, + start_point=(main_node.start_point.row, main_node.start_point.column), + end_point=(main_node.end_point.row, main_node.end_point.column), + node_type=main_node.type + ) + + # Add name capture if present + if name_node: + match.captures[f'{pattern_name}.name'] = CapturedNode( + name=f'{pattern_name}.name', + text=source_bytes[name_node.start_byte:name_node.end_byte].decode('utf-8', errors='replace'), + start_byte=name_node.start_byte, + end_byte=name_node.end_byte, + start_point=(name_node.start_point.row, name_node.start_point.column), + end_point=(name_node.end_point.row, name_node.end_point.column), + node_type=name_node.type + ) + + # Add doc capture if present + if doc_node: + match.captures[f'{pattern_name}.doc'] = CapturedNode( + name=f'{pattern_name}.doc', + text=source_bytes[doc_node.start_byte:doc_node.end_byte].decode('utf-8', errors='replace'), + start_byte=doc_node.start_byte, + end_byte=doc_node.end_byte, + start_point=(doc_node.start_point.row, doc_node.start_point.column), + end_point=(doc_node.end_point.row, doc_node.end_point.column), + node_type=doc_node.type + ) + + # Process any additional sub-captures from custom queries + for capture_name, nodes in captures_dict.items(): + if '.' in capture_name and not capture_name.startswith(('definition.', 'reference.')): + node = nodes[0] + match.captures[capture_name] = CapturedNode( + name=capture_name, + text=source_bytes[node.start_byte:node.end_byte].decode('utf-8', errors='replace'), + start_byte=node.start_byte, + end_byte=node.end_byte, + start_point=(node.start_point.row, node.start_point.column), + end_point=(node.end_point.row, node.end_point.column), + node_type=node.type + ) + + results.append(match) + + return results + + def get_functions(self, source_code: str, lang_name: str) -> List[QueryMatch]: + """Convenience method to get function/method matches.""" + matches = self.run_query(source_code, lang_name) + return [m for m in matches if m.pattern_name in ('function', 'method')] + + def get_classes(self, source_code: str, lang_name: str) -> List[QueryMatch]: + """Convenience method to get class/struct/interface matches.""" + matches = self.run_query(source_code, lang_name) + return [m for m in matches if m.pattern_name in ('class', 'struct', 'interface', 'trait')] + + def get_imports(self, source_code: str, lang_name: str) -> List[QueryMatch]: + """Convenience method to get import statement matches.""" + matches = self.run_query(source_code, lang_name) + return [m for m in matches if m.pattern_name == 'import'] + + def has_query(self, lang_name: str) -> bool: + """Check if a query is available for this language (custom or built-in).""" + # Check custom file first + query_file = QUERIES_DIR / f"{lang_name}.scm" + if query_file.exists(): + return True + # Check built-in fallback + return lang_name in LANGUAGES_WITH_BUILTIN_TAGS + + def uses_custom_query(self, lang_name: str) -> bool: + """Check if this language uses custom .scm query (rich metadata).""" + query_file = QUERIES_DIR / f"{lang_name}.scm" + return query_file.exists() + + 
def uses_builtin_query(self, lang_name: str) -> bool: + """Check if this language uses built-in TAGS_QUERY (limited metadata).""" + return lang_name in LANGUAGES_WITH_BUILTIN_TAGS and not self.uses_custom_query(lang_name) + + def clear_cache(self): + """Clear compiled query cache.""" + self._query_cache.clear() + self._scm_cache.clear() + + +# Global singleton +_runner_instance: Optional[QueryRunner] = None + + +def get_query_runner() -> QueryRunner: + """Get the global QueryRunner instance.""" + global _runner_instance + if _runner_instance is None: + _runner_instance = QueryRunner() + return _runner_instance diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py new file mode 100644 index 00000000..622892d9 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py @@ -0,0 +1,720 @@ +""" +AST-based Code Splitter using Tree-sitter for accurate code parsing. + +This module provides true AST-aware code chunking that: +1. Uses Tree-sitter queries for efficient pattern matching (15+ languages) +2. Splits code into semantic units (classes, functions, methods) +3. Uses RecursiveCharacterTextSplitter for oversized chunks +4. Enriches metadata for better RAG retrieval +5. Maintains parent context ("breadcrumbs") for nested structures +6. Uses deterministic IDs for Qdrant deduplication +""" + +import hashlib +import logging +from typing import List, Dict, Any, Optional, Set +from pathlib import Path +from dataclasses import dataclass, field + +from langchain_text_splitters import RecursiveCharacterTextSplitter, Language +from llama_index.core.schema import Document as LlamaDocument, TextNode + +from .languages import ( + EXTENSION_TO_LANGUAGE, AST_SUPPORTED_LANGUAGES, LANGUAGE_TO_TREESITTER, + get_language_from_path, get_treesitter_name, is_ast_supported +) +from .tree_parser import get_parser +from .query_runner import get_query_runner, QueryMatch +from .metadata import MetadataExtractor, ContentType, ChunkMetadata + +logger = logging.getLogger(__name__) + + +def generate_deterministic_id(path: str, content: str, chunk_index: int = 0) -> str: + """ + Generate a deterministic ID for a chunk based on file path and content. + + This ensures the same code chunk always gets the same ID, preventing + duplicates in Qdrant during re-indexing. + """ + hash_input = f"{path}:{chunk_index}:{content[:500]}" + return hashlib.sha256(hash_input.encode('utf-8')).hexdigest()[:32] + + +def compute_file_hash(content: str) -> str: + """Compute hash of file content for change detection.""" + return hashlib.sha256(content.encode('utf-8')).hexdigest() + + +@dataclass +class ASTChunk: + """Represents a chunk of code from AST parsing.""" + content: str + content_type: ContentType + language: str + path: str + semantic_names: List[str] = field(default_factory=list) + parent_context: List[str] = field(default_factory=list) + docstring: Optional[str] = None + signature: Optional[str] = None + start_line: int = 0 + end_line: int = 0 + node_type: Optional[str] = None + extends: List[str] = field(default_factory=list) + implements: List[str] = field(default_factory=list) + imports: List[str] = field(default_factory=list) + namespace: Optional[str] = None + + +class ASTCodeSplitter: + """ + AST-based code splitter using Tree-sitter queries for accurate parsing. 
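Before moving on to the splitter itself, a rough standalone sketch of the query runner defined above (illustrative, not part of the patch; it assumes the tree_sitter and tree_sitter_python packages referenced in TREESITTER_MODULES are installed, and the exact matches depend on which query compiles):

    from rag_pipeline.core.splitter.query_runner import get_query_runner

    src = (
        "import os\n"
        "\n"
        "class Greeter:\n"
        "    def hello(self, name):\n"
        "        return 'hi ' + name\n"
    )

    runner = get_query_runner()
    if runner.has_query("python"):
        for match in runner.run_query(src, "python"):
            name = match.get(f"{match.pattern_name}.name")
            main = match.captures[match.pattern_name]
            print(match.pattern_name, name.text if name else None, main.start_line)

    # Expected shape of the output (order and exact rows depend on the query):
    #   import   None     1
    #   class    Greeter  3
    #   function hello    4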
+ + Features: + - Uses .scm query files for declarative pattern matching + - Splits code into semantic units (classes, functions, methods) + - Falls back to RecursiveCharacterTextSplitter when needed + - Uses deterministic IDs for Qdrant deduplication + - Enriches metadata for improved RAG retrieval + + Usage: + splitter = ASTCodeSplitter(max_chunk_size=2000) + nodes = splitter.split_documents(documents) + """ + + DEFAULT_MAX_CHUNK_SIZE = 2000 + DEFAULT_MIN_CHUNK_SIZE = 100 + DEFAULT_CHUNK_OVERLAP = 200 + DEFAULT_PARSER_THRESHOLD = 10 + + def __init__( + self, + max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE, + min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE, + chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, + parser_threshold: int = DEFAULT_PARSER_THRESHOLD + ): + """ + Initialize AST code splitter. + + Args: + max_chunk_size: Maximum characters per chunk + min_chunk_size: Minimum characters for a valid chunk + chunk_overlap: Overlap between chunks when splitting oversized content + parser_threshold: Minimum lines for AST parsing + """ + self.max_chunk_size = max_chunk_size + self.min_chunk_size = min_chunk_size + self.chunk_overlap = chunk_overlap + self.parser_threshold = parser_threshold + + # Components + self._parser = get_parser() + self._query_runner = get_query_runner() + self._metadata_extractor = MetadataExtractor() + + # Cache text splitters + self._splitter_cache: Dict[Language, RecursiveCharacterTextSplitter] = {} + + # Default splitter + self._default_splitter = RecursiveCharacterTextSplitter( + chunk_size=max_chunk_size, + chunk_overlap=chunk_overlap, + length_function=len, + ) + + def split_documents(self, documents: List[LlamaDocument]) -> List[TextNode]: + """ + Split LlamaIndex documents using AST-based parsing. + + Args: + documents: List of LlamaIndex Document objects + + Returns: + List of TextNode objects with enriched metadata + """ + all_nodes = [] + + for doc in documents: + path = doc.metadata.get('path', 'unknown') + language = get_language_from_path(path) + + line_count = doc.text.count('\n') + 1 + use_ast = ( + language is not None + and language in AST_SUPPORTED_LANGUAGES + and line_count >= self.parser_threshold + and self._parser.is_available() + ) + + if use_ast: + nodes = self._split_with_ast(doc, language) + else: + nodes = self._split_fallback(doc, language) + + all_nodes.extend(nodes) + logger.debug(f"Split {path} into {len(nodes)} chunks (AST={use_ast})") + + return all_nodes + + def _split_with_ast(self, doc: LlamaDocument, language: Language) -> List[TextNode]: + """Split document using AST parsing with query-based extraction.""" + text = doc.text + path = doc.metadata.get('path', 'unknown') + ts_lang = get_treesitter_name(language) + + if not ts_lang: + return self._split_fallback(doc, language) + + # Try query-based extraction first + chunks = self._extract_with_queries(text, ts_lang, path) + + # If no queries available, fall back to traversal-based extraction + if not chunks: + chunks = self._extract_with_traversal(text, ts_lang, path) + + # Still no chunks? 
Use fallback + if not chunks: + return self._split_fallback(doc, language) + + return self._process_chunks(chunks, doc, language, path) + + def _extract_with_queries( + self, + text: str, + lang_name: str, + path: str + ) -> List[ASTChunk]: + """Extract chunks using tree-sitter query files with rich metadata.""" + if not self._query_runner.has_query(lang_name): + return [] + + tree = self._parser.parse(text, lang_name) + if not tree: + return [] + + matches = self._query_runner.run_query(text, lang_name, tree) + if not matches: + return [] + + source_bytes = text.encode('utf-8') + chunks = [] + processed_ranges: Set[tuple] = set() + + # Collect file-level metadata from all matches + imports = [] + namespace = None + decorators_map: Dict[int, List[str]] = {} # line -> decorators + + for match in matches: + # Handle imports (multiple capture variations) + if match.pattern_name in ('import', 'use'): + import_path = ( + match.get('import.path') or + match.get('import') or + match.get('use.path') or + match.get('use') + ) + if import_path: + imports.append(import_path.text.strip().strip('"\'')) + continue + + # Handle namespace/package/module + if match.pattern_name in ('namespace', 'package', 'module'): + ns_cap = match.get(f'{match.pattern_name}.name') or match.get(match.pattern_name) + if ns_cap: + namespace = ns_cap.text.strip() + continue + + # Handle standalone decorators/attributes + if match.pattern_name in ('decorator', 'attribute', 'annotation'): + dec_cap = match.get(f'{match.pattern_name}.name') or match.get(match.pattern_name) + if dec_cap: + line = dec_cap.start_line + if line not in decorators_map: + decorators_map[line] = [] + decorators_map[line].append(dec_cap.text.strip()) + continue + + # Handle main constructs: functions, classes, methods, etc. 
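            # For example, a Python file containing "class Greeter:" with a
            # "def hello" method arrives here as matches with pattern_name 'class'
            # and 'function' plus a '<pattern>.name' capture; each unique byte
            # range below becomes one ASTChunk, and the imports/namespace gathered
            # above are attached to every chunk after this loop.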
+ semantic_patterns = ( + 'function', 'method', 'class', 'interface', 'struct', 'trait', + 'enum', 'impl', 'constructor', 'closure', 'arrow', 'const', + 'var', 'static', 'type', 'record' + ) + if match.pattern_name in semantic_patterns: + main_cap = match.get(match.pattern_name) + if not main_cap: + continue + + range_key = (main_cap.start_byte, main_cap.end_byte) + if range_key in processed_ranges: + continue + processed_ranges.add(range_key) + + # Get name from various capture patterns + name_cap = ( + match.get(f'{match.pattern_name}.name') or + match.get('name') + ) + name = name_cap.text if name_cap else None + + # Get inheritance (extends/implements/embeds/supertrait) + extends = [] + implements = [] + + for ext_capture in ('extends', 'embeds', 'supertrait', 'base_type'): + cap = match.get(f'{match.pattern_name}.{ext_capture}') + if cap: + extends.extend(self._parse_type_list(cap.text)) + + for impl_capture in ('implements', 'trait'): + cap = match.get(f'{match.pattern_name}.{impl_capture}') + if cap: + implements.extend(self._parse_type_list(cap.text)) + + # Get additional metadata from captures + visibility = match.get(f'{match.pattern_name}.visibility') + return_type = match.get(f'{match.pattern_name}.return_type') + params = match.get(f'{match.pattern_name}.params') + modifiers = [] + + for mod in ('static', 'abstract', 'final', 'async', 'readonly', 'const', 'unsafe'): + if match.get(f'{match.pattern_name}.{mod}'): + modifiers.append(mod) + + chunk = ASTChunk( + content=main_cap.text, + content_type=ContentType.FUNCTIONS_CLASSES, + language=lang_name, + path=path, + semantic_names=[name] if name else [], + parent_context=[], + start_line=main_cap.start_line, + end_line=main_cap.end_line, + node_type=match.pattern_name, + extends=extends, + implements=implements, + ) + + # Extract docstring and signature + chunk.docstring = self._metadata_extractor.extract_docstring(main_cap.text, lang_name) + chunk.signature = self._metadata_extractor.extract_signature(main_cap.text, lang_name) + + chunks.append(chunk) + + # Add imports and namespace to all chunks + for chunk in chunks: + chunk.imports = imports[:30] + chunk.namespace = namespace + + # Create simplified code chunk + if chunks: + simplified = self._create_simplified_code(text, chunks, lang_name) + if simplified and len(simplified.strip()) > 50: + chunks.append(ASTChunk( + content=simplified, + content_type=ContentType.SIMPLIFIED_CODE, + language=lang_name, + path=path, + start_line=1, + end_line=text.count('\n') + 1, + node_type='simplified', + imports=imports[:30], + namespace=namespace, + )) + + return chunks + + def _extract_with_traversal( + self, + text: str, + lang_name: str, + path: str + ) -> List[ASTChunk]: + """Fallback: extract chunks using manual AST traversal.""" + tree = self._parser.parse(text, lang_name) + if not tree: + return [] + + source_bytes = text.encode('utf-8') + chunks = [] + processed_ranges: Set[tuple] = set() + + # Node types for semantic chunking + semantic_types = self._get_semantic_node_types(lang_name) + class_types = set(semantic_types.get('class', [])) + function_types = set(semantic_types.get('function', [])) + all_types = class_types | function_types + + def get_node_text(node) -> str: + return source_bytes[node.start_byte:node.end_byte].decode('utf-8', errors='replace') + + def get_node_name(node) -> Optional[str]: + for child in node.children: + if child.type in ('identifier', 'name', 'type_identifier', 'property_identifier'): + return get_node_text(child) + return None + + def traverse(node, 
parent_context: List[str]): + node_range = (node.start_byte, node.end_byte) + + if node.type in all_types: + if node_range in processed_ranges: + return + + content = get_node_text(node) + start_line = source_bytes[:node.start_byte].count(b'\n') + 1 + end_line = start_line + content.count('\n') + node_name = get_node_name(node) + is_class = node.type in class_types + + chunk = ASTChunk( + content=content, + content_type=ContentType.FUNCTIONS_CLASSES, + language=lang_name, + path=path, + semantic_names=[node_name] if node_name else [], + parent_context=list(parent_context), + start_line=start_line, + end_line=end_line, + node_type=node.type, + ) + + chunk.docstring = self._metadata_extractor.extract_docstring(content, lang_name) + chunk.signature = self._metadata_extractor.extract_signature(content, lang_name) + + # Extract inheritance via regex + inheritance = self._metadata_extractor.extract_inheritance(content, lang_name) + chunk.extends = inheritance.get('extends', []) + chunk.implements = inheritance.get('implements', []) + chunk.imports = inheritance.get('imports', []) + + chunks.append(chunk) + processed_ranges.add(node_range) + + if is_class and node_name: + for child in node.children: + traverse(child, parent_context + [node_name]) + else: + for child in node.children: + traverse(child, parent_context) + + traverse(tree.root_node, []) + + # Create simplified code + if chunks: + simplified = self._create_simplified_code(text, chunks, lang_name) + if simplified and len(simplified.strip()) > 50: + chunks.append(ASTChunk( + content=simplified, + content_type=ContentType.SIMPLIFIED_CODE, + language=lang_name, + path=path, + start_line=1, + end_line=text.count('\n') + 1, + node_type='simplified', + )) + + return chunks + + def _process_chunks( + self, + chunks: List[ASTChunk], + doc: LlamaDocument, + language: Language, + path: str + ) -> List[TextNode]: + """Process AST chunks into TextNodes, handling oversized chunks.""" + nodes = [] + chunk_counter = 0 + + for ast_chunk in chunks: + if len(ast_chunk.content) > self.max_chunk_size: + sub_nodes = self._split_oversized_chunk(ast_chunk, language, doc.metadata, path) + nodes.extend(sub_nodes) + chunk_counter += len(sub_nodes) + else: + metadata = self._build_metadata(ast_chunk, doc.metadata, chunk_counter, len(chunks)) + chunk_id = generate_deterministic_id(path, ast_chunk.content, chunk_counter) + + node = TextNode( + id_=chunk_id, + text=ast_chunk.content, + metadata=metadata + ) + nodes.append(node) + chunk_counter += 1 + + return nodes + + def _split_oversized_chunk( + self, + chunk: ASTChunk, + language: Optional[Language], + base_metadata: Dict[str, Any], + path: str + ) -> List[TextNode]: + """Split an oversized chunk using RecursiveCharacterTextSplitter.""" + splitter = self._get_text_splitter(language) if language else self._default_splitter + sub_chunks = splitter.split_text(chunk.content) + + nodes = [] + parent_id = generate_deterministic_id(path, chunk.content, 0) + + for i, sub_chunk in enumerate(sub_chunks): + if not sub_chunk or not sub_chunk.strip(): + continue + if len(sub_chunk.strip()) < self.min_chunk_size and len(sub_chunks) > 1: + continue + + metadata = dict(base_metadata) + metadata['content_type'] = ContentType.OVERSIZED_SPLIT.value + metadata['original_content_type'] = chunk.content_type.value + metadata['parent_chunk_id'] = parent_id + metadata['sub_chunk_index'] = i + metadata['total_sub_chunks'] = len(sub_chunks) + + if chunk.parent_context: + metadata['parent_context'] = chunk.parent_context + 
metadata['parent_class'] = chunk.parent_context[-1] + + if chunk.semantic_names: + metadata['semantic_names'] = chunk.semantic_names + metadata['primary_name'] = chunk.semantic_names[0] + + chunk_id = generate_deterministic_id(path, sub_chunk, i) + nodes.append(TextNode(id_=chunk_id, text=sub_chunk, metadata=metadata)) + + return nodes + + def _split_fallback( + self, + doc: LlamaDocument, + language: Optional[Language] = None + ) -> List[TextNode]: + """Fallback splitting using RecursiveCharacterTextSplitter.""" + text = doc.text + path = doc.metadata.get('path', 'unknown') + + if not text or not text.strip(): + return [] + + splitter = self._get_text_splitter(language) if language else self._default_splitter + chunks = splitter.split_text(text) + + nodes = [] + lang_str = doc.metadata.get('language', 'text') + text_offset = 0 + + for i, chunk in enumerate(chunks): + if not chunk or not chunk.strip(): + continue + if len(chunk.strip()) < self.min_chunk_size and len(chunks) > 1: + continue + if len(chunk) > 30000: + chunk = chunk[:30000] + + # Calculate line numbers + start_line = text[:text_offset].count('\n') + 1 if text_offset > 0 else 1 + chunk_pos = text.find(chunk, text_offset) + if chunk_pos >= 0: + text_offset = chunk_pos + len(chunk) + end_line = start_line + chunk.count('\n') + + metadata = dict(doc.metadata) + metadata['content_type'] = ContentType.FALLBACK.value + metadata['chunk_index'] = i + metadata['total_chunks'] = len(chunks) + metadata['start_line'] = start_line + metadata['end_line'] = end_line + + # Extract names via regex + names = self._metadata_extractor.extract_names_from_content(chunk, lang_str) + if names: + metadata['semantic_names'] = names + metadata['primary_name'] = names[0] + + # Extract inheritance + inheritance = self._metadata_extractor.extract_inheritance(chunk, lang_str) + if inheritance.get('extends'): + metadata['extends'] = inheritance['extends'] + metadata['parent_types'] = inheritance['extends'] + if inheritance.get('implements'): + metadata['implements'] = inheritance['implements'] + if inheritance.get('imports'): + metadata['imports'] = inheritance['imports'] + + chunk_id = generate_deterministic_id(path, chunk, i) + nodes.append(TextNode(id_=chunk_id, text=chunk, metadata=metadata)) + + return nodes + + def _build_metadata( + self, + chunk: ASTChunk, + base_metadata: Dict[str, Any], + chunk_index: int, + total_chunks: int + ) -> Dict[str, Any]: + """Build metadata dictionary from ASTChunk.""" + metadata = dict(base_metadata) + + metadata['content_type'] = chunk.content_type.value + metadata['node_type'] = chunk.node_type + metadata['chunk_index'] = chunk_index + metadata['total_chunks'] = total_chunks + metadata['start_line'] = chunk.start_line + metadata['end_line'] = chunk.end_line + + if chunk.parent_context: + metadata['parent_context'] = chunk.parent_context + metadata['parent_class'] = chunk.parent_context[-1] + metadata['full_path'] = '.'.join(chunk.parent_context + chunk.semantic_names[:1]) + + if chunk.semantic_names: + metadata['semantic_names'] = chunk.semantic_names + metadata['primary_name'] = chunk.semantic_names[0] + + if chunk.docstring: + metadata['docstring'] = chunk.docstring[:500] + + if chunk.signature: + metadata['signature'] = chunk.signature + + if chunk.extends: + metadata['extends'] = chunk.extends + metadata['parent_types'] = chunk.extends + + if chunk.implements: + metadata['implements'] = chunk.implements + + if chunk.imports: + metadata['imports'] = chunk.imports + + if chunk.namespace: + metadata['namespace'] = 
chunk.namespace + + return metadata + + def _get_text_splitter(self, language: Language) -> RecursiveCharacterTextSplitter: + """Get language-specific text splitter.""" + if language not in self._splitter_cache: + try: + self._splitter_cache[language] = RecursiveCharacterTextSplitter.from_language( + language=language, + chunk_size=self.max_chunk_size, + chunk_overlap=self.chunk_overlap, + ) + except Exception: + self._splitter_cache[language] = self._default_splitter + return self._splitter_cache[language] + + def _create_simplified_code( + self, + source_code: str, + chunks: List[ASTChunk], + language: str + ) -> str: + """Create simplified code with placeholders for extracted chunks.""" + semantic_chunks = [c for c in chunks if c.content_type == ContentType.FUNCTIONS_CLASSES] + if not semantic_chunks: + return source_code + + sorted_chunks = sorted( + semantic_chunks, + key=lambda x: source_code.find(x.content), + reverse=True + ) + + result = source_code + comment_prefix = self._metadata_extractor.get_comment_prefix(language) + + for chunk in sorted_chunks: + pos = result.find(chunk.content) + if pos == -1: + continue + + first_line = chunk.content.split('\n')[0].strip() + if len(first_line) > 60: + first_line = first_line[:60] + '...' + + breadcrumb = "" + if chunk.parent_context: + breadcrumb = f" (in {'.'.join(chunk.parent_context)})" + + placeholder = f"{comment_prefix} Code for: {first_line}{breadcrumb}\n" + result = result[:pos] + placeholder + result[pos + len(chunk.content):] + + return result.strip() + + def _parse_type_list(self, text: str) -> List[str]: + """Parse a comma-separated list of types.""" + if not text: + return [] + + text = text.strip().strip('()[]') + + # Remove keywords + for kw in ('extends', 'implements', 'with', ':'): + text = text.replace(kw, ' ') + + types = [] + for part in text.split(','): + name = part.strip() + if '<' in name: + name = name.split('<')[0].strip() + if '(' in name: + name = name.split('(')[0].strip() + if name: + types.append(name) + + return types + + def _get_semantic_node_types(self, language: str) -> Dict[str, List[str]]: + """Get semantic node types for manual traversal fallback.""" + types = { + 'python': { + 'class': ['class_definition'], + 'function': ['function_definition'], + }, + 'java': { + 'class': ['class_declaration', 'interface_declaration', 'enum_declaration'], + 'function': ['method_declaration', 'constructor_declaration'], + }, + 'javascript': { + 'class': ['class_declaration'], + 'function': ['function_declaration', 'method_definition', 'arrow_function'], + }, + 'typescript': { + 'class': ['class_declaration', 'interface_declaration'], + 'function': ['function_declaration', 'method_definition', 'arrow_function'], + }, + 'go': { + 'class': ['type_declaration'], + 'function': ['function_declaration', 'method_declaration'], + }, + 'rust': { + 'class': ['struct_item', 'impl_item', 'trait_item', 'enum_item'], + 'function': ['function_item'], + }, + 'c_sharp': { + 'class': ['class_declaration', 'interface_declaration', 'struct_declaration'], + 'function': ['method_declaration', 'constructor_declaration'], + }, + 'php': { + 'class': ['class_declaration', 'interface_declaration', 'trait_declaration'], + 'function': ['function_definition', 'method_declaration'], + }, + } + return types.get(language, {'class': [], 'function': []}) + + @staticmethod + def get_supported_languages() -> List[str]: + """Return list of languages with AST support.""" + return list(LANGUAGE_TO_TREESITTER.values()) + + @staticmethod + def 
is_ast_supported(path: str) -> bool: + """Check if AST parsing is supported for a file.""" + return is_ast_supported(path) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/tree_parser.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/tree_parser.py new file mode 100644 index 00000000..da378a38 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/tree_parser.py @@ -0,0 +1,129 @@ +""" +Tree-sitter parser wrapper with caching and language loading. + +Handles dynamic loading of tree-sitter language modules using the new API (v0.23+). +""" + +import logging +from typing import Dict, Any, Optional + +from .languages import TREESITTER_MODULES + +logger = logging.getLogger(__name__) + + +class TreeSitterParser: + """ + Wrapper for tree-sitter parser with language caching. + + Uses the new tree-sitter API (v0.23+) with individual language packages. + """ + + def __init__(self): + self._language_cache: Dict[str, Any] = {} + self._available: Optional[bool] = None + + def is_available(self) -> bool: + """Check if tree-sitter is available and working.""" + if self._available is None: + try: + from tree_sitter import Parser, Language + import tree_sitter_python as tspython + + py_language = Language(tspython.language()) + parser = Parser(py_language) + parser.parse(b"def test(): pass") + + self._available = True + logger.info("tree-sitter is available and working") + except ImportError as e: + logger.warning(f"tree-sitter not installed: {e}") + self._available = False + except Exception as e: + logger.warning(f"tree-sitter error: {type(e).__name__}: {e}") + self._available = False + return self._available + + def get_language(self, lang_name: str) -> Optional[Any]: + """ + Get tree-sitter Language object for a language name. + + Args: + lang_name: Tree-sitter language name (e.g., 'python', 'java', 'php') + + Returns: + tree_sitter.Language object or None if unavailable + """ + if lang_name in self._language_cache: + return self._language_cache[lang_name] + + if not self.is_available(): + return None + + try: + from tree_sitter import Language + + lang_info = TREESITTER_MODULES.get(lang_name) + if not lang_info: + logger.debug(f"No tree-sitter module mapping for '{lang_name}'") + return None + + module_name, func_name = lang_info + + import importlib + lang_module = importlib.import_module(module_name) + + lang_func = getattr(lang_module, func_name, None) + if not lang_func: + logger.debug(f"Module {module_name} has no {func_name} function") + return None + + language = Language(lang_func()) + self._language_cache[lang_name] = language + return language + + except Exception as e: + logger.debug(f"Could not load tree-sitter language '{lang_name}': {e}") + return None + + def parse(self, source_code: str, lang_name: str) -> Optional[Any]: + """ + Parse source code and return the AST tree. 
+ + Args: + source_code: Source code string + lang_name: Tree-sitter language name + + Returns: + tree_sitter.Tree object or None if parsing failed + """ + language = self.get_language(lang_name) + if not language: + return None + + try: + from tree_sitter import Parser + + parser = Parser(language) + tree = parser.parse(bytes(source_code, "utf8")) + return tree + + except Exception as e: + logger.warning(f"Failed to parse code with tree-sitter ({lang_name}): {e}") + return None + + def clear_cache(self): + """Clear the language cache.""" + self._language_cache.clear() + + +# Global singleton instance +_parser_instance: Optional[TreeSitterParser] = None + + +def get_parser() -> TreeSitterParser: + """Get the global TreeSitterParser instance.""" + global _parser_instance + if _parser_instance is None: + _parser_instance = TreeSitterParser() + return _parser_instance diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/models/scoring_config.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/models/scoring_config.py new file mode 100644 index 00000000..e1cba601 --- /dev/null +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/models/scoring_config.py @@ -0,0 +1,232 @@ +""" +Scoring configuration for RAG query result reranking. + +Provides configurable boost factors and priority patterns that can be +overridden via environment variables. +""" + +import os +from typing import Dict, List +from pydantic import BaseModel, Field +import logging + +logger = logging.getLogger(__name__) + + +def _parse_list_env(env_var: str, default: List[str]) -> List[str]: + """Parse comma-separated environment variable into list.""" + value = os.getenv(env_var) + if not value: + return default + return [item.strip() for item in value.split(',') if item.strip()] + + +def _parse_float_env(env_var: str, default: float) -> float: + """Parse float from environment variable.""" + value = os.getenv(env_var) + if not value: + return default + try: + return float(value) + except ValueError: + logger.warning(f"Invalid float value for {env_var}: {value}, using default {default}") + return default + + +class ContentTypeBoost(BaseModel): + """Boost factors for different content types from AST parsing.""" + + functions_classes: float = Field( + default_factory=lambda: _parse_float_env("RAG_BOOST_FUNCTIONS_CLASSES", 1.2), + description="Boost for full function/class definitions (highest value)" + ) + fallback: float = Field( + default_factory=lambda: _parse_float_env("RAG_BOOST_FALLBACK", 1.0), + description="Boost for regex-based split chunks" + ) + oversized_split: float = Field( + default_factory=lambda: _parse_float_env("RAG_BOOST_OVERSIZED", 0.95), + description="Boost for large chunks that were split" + ) + simplified_code: float = Field( + default_factory=lambda: _parse_float_env("RAG_BOOST_SIMPLIFIED", 0.7), + description="Boost for code with placeholders (context only)" + ) + + def get(self, content_type: str) -> float: + """Get boost factor for a content type.""" + return getattr(self, content_type, 1.0) + + +class FilePriorityPatterns(BaseModel): + """File path patterns for priority-based boosting.""" + + high: List[str] = Field( + default_factory=lambda: _parse_list_env( + "RAG_HIGH_PRIORITY_PATTERNS", + ['service', 'controller', 'handler', 'api', 'core', 'auth', 'security', + 'permission', 'repository', 'dao', 'migration'] + ), + description="Patterns for high-priority files (1.3x boost)" + ) + + medium: List[str] = Field( + default_factory=lambda: _parse_list_env( + "RAG_MEDIUM_PRIORITY_PATTERNS", + 
['model', 'entity', 'dto', 'schema', 'util', 'helper', 'common', + 'shared', 'component', 'hook', 'client', 'integration'] + ), + description="Patterns for medium-priority files (1.1x boost)" + ) + + low: List[str] = Field( + default_factory=lambda: _parse_list_env( + "RAG_LOW_PRIORITY_PATTERNS", + ['test', 'spec', 'config', 'mock', 'fixture', 'stub'] + ), + description="Patterns for low-priority files (0.8x penalty)" + ) + + high_boost: float = Field( + default_factory=lambda: _parse_float_env("RAG_HIGH_PRIORITY_BOOST", 1.3) + ) + medium_boost: float = Field( + default_factory=lambda: _parse_float_env("RAG_MEDIUM_PRIORITY_BOOST", 1.1) + ) + low_boost: float = Field( + default_factory=lambda: _parse_float_env("RAG_LOW_PRIORITY_BOOST", 0.8) + ) + + def get_priority(self, file_path: str) -> tuple: + """ + Get priority level and boost factor for a file path. + + Returns: + Tuple of (priority_name, boost_factor) + """ + path_lower = file_path.lower() + + if any(p in path_lower for p in self.high): + return ('HIGH', self.high_boost) + elif any(p in path_lower for p in self.medium): + return ('MEDIUM', self.medium_boost) + elif any(p in path_lower for p in self.low): + return ('LOW', self.low_boost) + else: + return ('MEDIUM', 1.0) + + +class MetadataBonus(BaseModel): + """Bonus multipliers for metadata presence.""" + + semantic_names: float = Field( + default_factory=lambda: _parse_float_env("RAG_BONUS_SEMANTIC_NAMES", 1.1), + description="Bonus for chunks with extracted semantic names" + ) + docstring: float = Field( + default_factory=lambda: _parse_float_env("RAG_BONUS_DOCSTRING", 1.05), + description="Bonus for chunks with docstrings" + ) + signature: float = Field( + default_factory=lambda: _parse_float_env("RAG_BONUS_SIGNATURE", 1.02), + description="Bonus for chunks with function signatures" + ) + + +class ScoringConfig(BaseModel): + """ + Complete scoring configuration for RAG query reranking. + + All values can be overridden via environment variables: + - RAG_BOOST_FUNCTIONS_CLASSES, RAG_BOOST_FALLBACK, etc. + - RAG_HIGH_PRIORITY_PATTERNS (comma-separated) + - RAG_HIGH_PRIORITY_BOOST, RAG_MEDIUM_PRIORITY_BOOST, etc. + - RAG_BONUS_SEMANTIC_NAMES, RAG_BONUS_DOCSTRING, etc. + + Usage: + config = ScoringConfig() + boost = config.content_type_boost.get('functions_classes') + priority, boost = config.file_priority.get_priority('/src/UserService.java') + """ + + content_type_boost: ContentTypeBoost = Field(default_factory=ContentTypeBoost) + file_priority: FilePriorityPatterns = Field(default_factory=FilePriorityPatterns) + metadata_bonus: MetadataBonus = Field(default_factory=MetadataBonus) + + # Score thresholds + min_relevance_score: float = Field( + default_factory=lambda: _parse_float_env("RAG_MIN_RELEVANCE_SCORE", 0.7), + description="Minimum score threshold for results" + ) + + max_score_cap: float = Field( + default_factory=lambda: _parse_float_env("RAG_MAX_SCORE_CAP", 1.0), + description="Maximum score cap after boosting" + ) + + def calculate_boosted_score( + self, + base_score: float, + file_path: str, + content_type: str, + has_semantic_names: bool = False, + has_docstring: bool = False, + has_signature: bool = False + ) -> tuple: + """ + Calculate final boosted score for a result. + + Args: + base_score: Original similarity score + file_path: File path of the chunk + content_type: Content type (functions_classes, fallback, etc.) 
+ has_semantic_names: Whether chunk has semantic names + has_docstring: Whether chunk has docstring + has_signature: Whether chunk has signature + + Returns: + Tuple of (boosted_score, priority_level) + """ + score = base_score + + # File priority boost + priority, priority_boost = self.file_priority.get_priority(file_path) + score *= priority_boost + + # Content type boost + content_boost = self.content_type_boost.get(content_type) + score *= content_boost + + # Metadata bonuses + if has_semantic_names: + score *= self.metadata_bonus.semantic_names + if has_docstring: + score *= self.metadata_bonus.docstring + if has_signature: + score *= self.metadata_bonus.signature + + # Cap the score + score = min(score, self.max_score_cap) + + return (score, priority) + + +# Global singleton +_scoring_config: ScoringConfig | None = None + + +def get_scoring_config() -> ScoringConfig: + """Get the global ScoringConfig instance.""" + global _scoring_config + if _scoring_config is None: + _scoring_config = ScoringConfig() + logger.info("ScoringConfig initialized with:") + logger.info(f" High priority patterns: {_scoring_config.file_priority.high[:5]}...") + logger.info(f" Content type boosts: functions_classes={_scoring_config.content_type_boost.functions_classes}") + return _scoring_config + + +def reset_scoring_config(): + """Reset the global config (useful for testing).""" + global _scoring_config + _scoring_config = None diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py index dc222980..f9513fb4 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/services/query_service.py @@ -8,36 +8,13 @@ from qdrant_client.http.models import Filter, FieldCondition, MatchValue, MatchAny from ..models.config import RAGConfig +from ..models.scoring_config import get_scoring_config, ScoringConfig from ..utils.utils import make_namespace, make_project_namespace from ..core.openrouter_embedding import OpenRouterEmbedding from ..models.instructions import InstructionType, format_query logger = logging.getLogger(__name__) -# File priority patterns for smart RAG -HIGH_PRIORITY_PATTERNS = [ - 'service', 'controller', 'handler', 'api', 'core', 'auth', 'security', - 'permission', 'repository', 'dao', 'migration' -] - -MEDIUM_PRIORITY_PATTERNS = [ - 'model', 'entity', 'dto', 'schema', 'util', 'helper', 'common', - 'shared', 'component', 'hook', 'client', 'integration' -] - -LOW_PRIORITY_PATTERNS = [ - 'test', 'spec', 'config', 'mock', 'fixture', 'stub' -] - -# Content type priorities for AST-based chunks -# functions_classes are more valuable than simplified_code (placeholders) -CONTENT_TYPE_BOOST = { - 'functions_classes': 1.2, # Full function/class definitions - highest value - 'fallback': 1.0, # Regex-based split - normal value - 'oversized_split': 0.95, # Large chunks that were split - slightly lower - 'simplified_code': 0.7, # Code with placeholders - lower value (context only) -} - class RAGQueryService: """Service for querying RAG indices using Qdrant. @@ -864,11 +841,12 @@ def _merge_and_rank_results(self, results: List[Dict], min_score_threshold: floa """ Deduplicate matches and filter by relevance score with priority-based reranking. - Applies three types of boosting: + Uses ScoringConfig for configurable boosting factors: 1. File path priority (service/controller vs test/config) 2. 
Content type priority (functions_classes vs simplified_code) 3. Semantic name bonus (chunks with extracted function/class names) """ + scoring_config = get_scoring_config() grouped = {} # Deduplicate by file_path + content hash @@ -884,43 +862,28 @@ def _merge_and_rank_results(self, results: List[Dict], min_score_threshold: floa unique_results = list(grouped.values()) - # Apply multi-factor score boosting + # Apply multi-factor score boosting using ScoringConfig for result in unique_results: metadata = result.get('metadata', {}) - file_path = metadata.get('path', metadata.get('file_path', '')).lower() + file_path = metadata.get('path', metadata.get('file_path', '')) content_type = metadata.get('content_type', 'fallback') semantic_names = metadata.get('semantic_names', []) + has_docstring = bool(metadata.get('docstring')) + has_signature = bool(metadata.get('signature')) + + boosted_score, priority = scoring_config.calculate_boosted_score( + base_score=result['score'], + file_path=file_path, + content_type=content_type, + has_semantic_names=bool(semantic_names), + has_docstring=has_docstring, + has_signature=has_signature + ) - base_score = result['score'] - - # 1. File path priority boosting - if any(p in file_path for p in HIGH_PRIORITY_PATTERNS): - base_score *= 1.3 - result['_priority'] = 'HIGH' - elif any(p in file_path for p in MEDIUM_PRIORITY_PATTERNS): - base_score *= 1.1 - result['_priority'] = 'MEDIUM' - elif any(p in file_path for p in LOW_PRIORITY_PATTERNS): - base_score *= 0.8 # Penalize test/config files - result['_priority'] = 'LOW' - else: - result['_priority'] = 'MEDIUM' - - # 2. Content type boosting (AST-based metadata) - content_boost = CONTENT_TYPE_BOOST.get(content_type, 1.0) - base_score *= content_boost + result['score'] = boosted_score + result['_priority'] = priority result['_content_type'] = content_type - - # 3. Semantic name bonus - chunks with extracted names are more valuable - if semantic_names: - base_score *= 1.1 # 10% bonus for having semantic names - result['_has_semantic_names'] = True - - # 4. 
Docstring bonus - chunks with docstrings provide better context - if metadata.get('docstring'): - base_score *= 1.05 # 5% bonus for having docstring - - result['score'] = min(1.0, base_score) + result['_has_semantic_names'] = bool(semantic_names) # Filter by threshold filtered = [r for r in unique_results if r['score'] >= min_score_threshold] From 642bda00bc9ad63cd0f8f2fffbb4586a99524876 Mon Sep 17 00:00:00 2001 From: rostislav Date: Wed, 28 Jan 2026 02:55:20 +0200 Subject: [PATCH 19/20] feat: Enhance lock management in PullRequestAnalysisProcessor and improve code referencing in prompts --- .../processor/analysis/PullRequestAnalysisProcessor.java | 6 +++++- .../mcp-client/service/multi_stage_orchestrator.py | 6 ++++-- .../mcp-client/utils/prompts/prompt_constants.py | 6 ++++++ .../src/rag_pipeline/core/index_manager/indexer.py | 1 + .../src/rag_pipeline/core/index_manager/point_operations.py | 2 +- 5 files changed, 17 insertions(+), 4 deletions(-) diff --git a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java index ce7b3292..87443d21 100644 --- a/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java +++ b/java-ecosystem/libs/analysis-engine/src/main/java/org/rostilos/codecrow/analysisengine/processor/analysis/PullRequestAnalysisProcessor.java @@ -94,8 +94,10 @@ public Map process( // Check if a lock was already acquired by the caller (e.g., webhook handler) // to prevent double-locking which causes unnecessary 2-minute waits String lockKey; + boolean isPreAcquired = false; if (request.getPreAcquiredLockKey() != null && !request.getPreAcquiredLockKey().isBlank()) { lockKey = request.getPreAcquiredLockKey(); + isPreAcquired = true; log.info("Using pre-acquired lock: {} for project={}, PR={}", lockKey, project.getId(), request.getPullRequestId()); } else { Optional acquiredLock = analysisLockService.acquireLockWithWait( @@ -225,7 +227,9 @@ public Map process( return Map.of("status", "error", "message", e.getMessage()); } finally { - analysisLockService.releaseLock(lockKey); + if (!isPreAcquired) { + analysisLockService.releaseLock(lockKey); + } } } diff --git a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py index e4b9846b..bf662f22 100644 --- a/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py +++ b/python-ecosystem/mcp-client/service/multi_stage_orchestrator.py @@ -410,7 +410,7 @@ def _extract_diff_snippets(self, diff_content: str) -> List[str]: snippets = [] current_snippet_lines = [] - for line in diff_content.split("\n"): + for line in diff_content.splitlines(): # Focus on added lines (new code) if line.startswith("+") and not line.startswith("+++"): clean_line = line[1:].strip() @@ -1084,8 +1084,10 @@ def _format_rag_context( meta_lines.append(f"Type: {chunk_type}") meta_text = "\n".join(meta_lines) + # Use file path as primary identifier, not a number + # This encourages AI to reference by path rather than by chunk number formatted_parts.append( - f"### Related Code #{included_count} (relevance: {score:.2f})\n" + f"### Context from `{path}` (relevance: {score:.2f})\n" f"{meta_text}\n" f"```\n{text}\n```\n" ) diff --git 
a/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py b/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py index 47e65d1b..1c86a621 100644 --- a/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py +++ b/python-ecosystem/mcp-client/utils/prompts/prompt_constants.py @@ -356,6 +356,12 @@ CODEBASE CONTEXT (from RAG): {rag_context} +IMPORTANT: When referencing codebase context in your analysis: +- ALWAYS cite the actual file path (e.g., "as seen in `src/service/UserService.java`") +- NEVER reference context by number (e.g., DO NOT say "Related Code #1" or "chunk #3") +- Quote relevant code snippets when needed to support your analysis +- The numbered headers are for your reference only, not for output + {previous_issues} SUGGESTED_FIX_DIFF_FORMAT: diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py index 2bf145a4..d2f3f323 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/indexer.py @@ -94,6 +94,7 @@ def estimate_repository_size( sample_chunk_count += len(chunks) del chunks del documents + gc.collect() avg_chunks_per_file = sample_chunk_count / SAMPLE_SIZE chunk_count = int(avg_chunks_per_file * file_count) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py index b9682bf0..62ca6fd2 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager/point_operations.py @@ -50,7 +50,7 @@ def prepare_chunks_for_embedding( # Group chunks by file path chunks_by_file: Dict[str, List[TextNode]] = {} for chunk in chunks: - path = chunk.metadata.get("path", "unknown") + path = chunk.metadata.get("path", str(uuid.uuid4())) if path not in chunks_by_file: chunks_by_file[path] = [] chunks_by_file[path].append(chunk) From 5add89cf8ba4a2485e139e51d17b265b37635147 Mon Sep 17 00:00:00 2001 From: rostislav Date: Wed, 28 Jan 2026 22:05:19 +0200 Subject: [PATCH 20/20] feat: Enhance AST processing and metadata extraction in RAG pipeline components --- .../src/rag_pipeline/core/index_manager.py | 3 +- .../src/rag_pipeline/core/loader.py | 21 +- .../rag_pipeline/core/splitter/splitter.py | 646 +++++++++++++++++- .../src/rag_pipeline/models/config.py | 8 +- .../src/rag_pipeline/utils/utils.py | 47 ++ 5 files changed, 696 insertions(+), 29 deletions(-) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py index 7eebf78e..fae2373f 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/index_manager.py @@ -63,7 +63,8 @@ def __init__(self, config: RAGConfig): max_chunk_size=config.chunk_size, min_chunk_size=min(200, config.chunk_size // 4), chunk_overlap=config.chunk_overlap, - parser_threshold=10 # Minimum lines for AST parsing + parser_threshold=3, # Low threshold - AST benefits even small files + enrich_embedding_text=True # Prepend semantic context for better embeddings ) self.loader = DocumentLoader(config) diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/loader.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/loader.py 
index 25d92966..05235278 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/loader.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/loader.py @@ -3,7 +3,7 @@ import logging from llama_index.core.schema import Document -from ..utils.utils import detect_language_from_path, should_exclude_file, is_binary_file +from ..utils.utils import detect_language_from_path, should_exclude_file, is_binary_file, clean_archive_path from ..models.config import RAGConfig logger = logging.getLogger(__name__) @@ -106,11 +106,14 @@ def load_file_batch( language = detect_language_from_path(str(full_path)) filetype = full_path.suffix.lstrip('.') + # Clean archive root prefix from path (e.g., 'owner-repo-commit/src/file.php' -> 'src/file.php') + clean_path = clean_archive_path(relative_path_str) + metadata = { "workspace": workspace, "project": project, "branch": branch, - "path": relative_path_str, + "path": clean_path, "commit": commit, "language": language, "filetype": filetype, @@ -190,11 +193,14 @@ def load_from_directory( language = detect_language_from_path(str(file_path)) filetype = file_path.suffix.lstrip('.') + # Clean archive root prefix from path + clean_path = clean_archive_path(relative_path) + metadata = { "workspace": workspace, "project": project, "branch": branch, - "path": relative_path, + "path": clean_path, "commit": commit, "language": language, "filetype": filetype, @@ -207,7 +213,7 @@ def load_from_directory( ) documents.append(doc) - logger.debug(f"Loaded document: {relative_path} ({language})") + logger.debug(f"Loaded document: {clean_path} ({language})") logger.info(f"Loaded {len(documents)} documents from {repo_path} (excluded {excluded_count} files by patterns)") return documents @@ -257,11 +263,14 @@ def load_specific_files( language = detect_language_from_path(str(full_path)) filetype = full_path.suffix.lstrip('.') + # Clean archive root prefix from path + clean_path = clean_archive_path(relative_path) + metadata = { "workspace": workspace, "project": project, "branch": branch, - "path": relative_path, + "path": clean_path, "commit": commit, "language": language, "filetype": filetype, @@ -274,7 +283,7 @@ def load_specific_files( ) documents.append(doc) - logger.debug(f"Loaded document: {relative_path}") + logger.debug(f"Loaded document: {clean_path}") return documents diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py index 622892d9..25c12d8a 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/core/splitter/splitter.py @@ -48,22 +48,69 @@ def compute_file_hash(content: str) -> str: @dataclass class ASTChunk: - """Represents a chunk of code from AST parsing.""" + """Represents a chunk of code from AST parsing with rich metadata.""" content: str content_type: ContentType language: str path: str + + # Identity semantic_names: List[str] = field(default_factory=list) - parent_context: List[str] = field(default_factory=list) - docstring: Optional[str] = None - signature: Optional[str] = None + node_type: Optional[str] = None + namespace: Optional[str] = None + + # Location start_line: int = 0 end_line: int = 0 - node_type: Optional[str] = None + + # Hierarchy & Context + parent_context: List[str] = field(default_factory=list) # Breadcrumb path + + # Documentation + docstring: Optional[str] = None + signature: Optional[str] = None + + # Type relationships extends: 
List[str] = field(default_factory=list) implements: List[str] = field(default_factory=list) + + # Dependencies imports: List[str] = field(default_factory=list) - namespace: Optional[str] = None + + # --- RICH AST FIELDS (extracted from tree-sitter) --- + + # Methods/functions within this chunk (for classes) + methods: List[str] = field(default_factory=list) + + # Properties/fields within this chunk (for classes) + properties: List[str] = field(default_factory=list) + + # Parameters (for functions/methods) + parameters: List[str] = field(default_factory=list) + + # Return type (for functions/methods) + return_type: Optional[str] = None + + # Decorators/annotations + decorators: List[str] = field(default_factory=list) + + # Modifiers (public, private, static, async, abstract, etc.) + modifiers: List[str] = field(default_factory=list) + + # Called functions/methods (dependencies) + calls: List[str] = field(default_factory=list) + + # Referenced types (type annotations, generics) + referenced_types: List[str] = field(default_factory=list) + + # Variables declared in this chunk + variables: List[str] = field(default_factory=list) + + # Constants defined + constants: List[str] = field(default_factory=list) + + # Generic type parameters (e.g., ) + type_parameters: List[str] = field(default_factory=list) class ASTCodeSplitter: @@ -76,23 +123,36 @@ class ASTCodeSplitter: - Falls back to RecursiveCharacterTextSplitter when needed - Uses deterministic IDs for Qdrant deduplication - Enriches metadata for improved RAG retrieval + - Prepares embedding-optimized text with semantic context + + Chunk Size Strategy: + - text-embedding-3-small supports ~8191 tokens (~32K chars) + - We use 8000 chars as default to keep semantic units intact + - Only truly massive classes/functions get split + - Splitting loses AST benefits, so we avoid it when possible Usage: - splitter = ASTCodeSplitter(max_chunk_size=2000) + splitter = ASTCodeSplitter(max_chunk_size=8000) nodes = splitter.split_documents(documents) """ - DEFAULT_MAX_CHUNK_SIZE = 2000 + # Chunk size considerations: + # - Embedding models (text-embedding-3-small): ~8191 tokens = ~32K chars + # - Most classes/functions: 500-5000 chars + # - Keeping semantic units whole improves retrieval quality + # - Only split when absolutely necessary + DEFAULT_MAX_CHUNK_SIZE = 8000 # ~2000 tokens, fits most semantic units DEFAULT_MIN_CHUNK_SIZE = 100 DEFAULT_CHUNK_OVERLAP = 200 - DEFAULT_PARSER_THRESHOLD = 10 + DEFAULT_PARSER_THRESHOLD = 3 # Low threshold - AST benefits even small files def __init__( self, max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE, min_chunk_size: int = DEFAULT_MIN_CHUNK_SIZE, chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, - parser_threshold: int = DEFAULT_PARSER_THRESHOLD + parser_threshold: int = DEFAULT_PARSER_THRESHOLD, + enrich_embedding_text: bool = True ): """ Initialize AST code splitter. 
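[Reviewer note, not part of the patch] A minimal usage sketch of the splitter configured with the new defaults described above. The constructor keywords and split_documents() come from this diff's class docstring; the import path is inferred from the file layout in this series, and the sample Document content is purely illustrative.

# sketch.py -- illustrative only; assumes the package is installed as laid out in this patch
from llama_index.core.schema import Document
from rag_pipeline.core.splitter.splitter import ASTCodeSplitter  # path assumed from this diff

splitter = ASTCodeSplitter(
    max_chunk_size=8000,         # ~2000 tokens; keeps most classes/functions intact
    min_chunk_size=100,
    chunk_overlap=200,
    parser_threshold=3,          # run AST parsing even for small files
    enrich_embedding_text=True,  # prepend "[File: ... | Extends: ...]" context headers
)

# A tiny sample document; real callers get these from DocumentLoader.
doc = Document(
    text=(
        "class UserService:\n"
        "    def get_user(self, user_id):\n"
        "        return self.repo.find(user_id)\n"
    ),
    metadata={"path": "src/service/user_service.py", "language": "python"},
)

nodes = splitter.split_documents([doc])
for node in nodes:
    # primary_name / content_type are set by _build_metadata (or the fallback path)
    print(node.metadata.get("primary_name"), node.metadata.get("content_type"))

With the 8000-character default, a unit like this stays a single chunk; only genuinely oversized classes fall through to the oversized-split path shown later in this hunk.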
@@ -101,12 +161,15 @@ def __init__( max_chunk_size: Maximum characters per chunk min_chunk_size: Minimum characters for a valid chunk chunk_overlap: Overlap between chunks when splitting oversized content - parser_threshold: Minimum lines for AST parsing + parser_threshold: Minimum lines for AST parsing (3 recommended) + enrich_embedding_text: Whether to prepend semantic context to chunk text + for better embedding quality """ self.max_chunk_size = max_chunk_size self.min_chunk_size = min_chunk_size self.chunk_overlap = chunk_overlap self.parser_threshold = parser_threshold + self.enrich_embedding_text = enrich_embedding_text # Components self._parser = get_parser() @@ -295,12 +358,16 @@ def _extract_with_queries( node_type=match.pattern_name, extends=extends, implements=implements, + modifiers=modifiers, ) # Extract docstring and signature chunk.docstring = self._metadata_extractor.extract_docstring(main_cap.text, lang_name) chunk.signature = self._metadata_extractor.extract_signature(main_cap.text, lang_name) + # Extract rich AST details (methods, properties, params, calls, etc.) + self._extract_rich_ast_details(chunk, tree, main_cap, lang_name) + chunks.append(chunk) # Add imports and namespace to all chunks @@ -390,6 +457,9 @@ def traverse(node, parent_context: List[str]): chunk.implements = inheritance.get('implements', []) chunk.imports = inheritance.get('imports', []) + # Extract rich AST details directly from this node + self._extract_rich_details_from_node(chunk, node, source_bytes, lang_name) + chunks.append(chunk) processed_ranges.add(node_range) @@ -418,6 +488,339 @@ def traverse(node, parent_context: List[str]): return chunks + def _extract_rich_ast_details( + self, + chunk: ASTChunk, + tree: Any, + captured_node: Any, + lang_name: str + ) -> None: + """ + Extract rich AST details from tree-sitter node by traversing its children. 
+ + This extracts: + - Methods (for classes) + - Properties/fields (for classes) + - Parameters (for functions/methods) + - Return type + - Decorators/annotations + - Called functions/methods + - Referenced types + - Variables + - Type parameters (generics) + """ + source_bytes = chunk.content.encode('utf-8') + + # Find the actual tree-sitter node for this capture + node = self._find_node_at_position( + tree.root_node, + captured_node.start_byte, + captured_node.end_byte + ) + if not node: + return + + # Language-specific node type mappings + node_types = self._get_rich_node_types(lang_name) + + def get_text(n) -> str: + """Get text for a node relative to chunk content.""" + start = n.start_byte - captured_node.start_byte + end = n.end_byte - captured_node.start_byte + if 0 <= start < len(source_bytes) and start < end <= len(source_bytes): + return source_bytes[start:end].decode('utf-8', errors='replace') + return '' + + def extract_identifier(n) -> Optional[str]: + """Extract identifier name from a node.""" + for child in n.children: + if child.type in node_types['identifier']: + return get_text(child) + return None + + def traverse_for_details(n, depth: int = 0): + """Recursively traverse to extract details.""" + if depth > 10: # Prevent infinite recursion + return + + node_type = n.type + + # Extract methods (for classes) + if node_type in node_types['method']: + method_name = extract_identifier(n) + if method_name and method_name not in chunk.methods: + chunk.methods.append(method_name) + + # Extract properties/fields + if node_type in node_types['property']: + prop_name = extract_identifier(n) + if prop_name and prop_name not in chunk.properties: + chunk.properties.append(prop_name) + + # Extract parameters + if node_type in node_types['parameter']: + param_name = extract_identifier(n) + if param_name and param_name not in chunk.parameters: + chunk.parameters.append(param_name) + + # Extract decorators/annotations + if node_type in node_types['decorator']: + dec_text = get_text(n).strip() + if dec_text and dec_text not in chunk.decorators: + # Clean up decorator text + if dec_text.startswith('@'): + dec_text = dec_text[1:] + if '(' in dec_text: + dec_text = dec_text.split('(')[0] + chunk.decorators.append(dec_text) + + # Extract function calls + if node_type in node_types['call']: + call_name = extract_identifier(n) + if call_name and call_name not in chunk.calls: + chunk.calls.append(call_name) + + # Extract type references + if node_type in node_types['type_ref']: + type_text = get_text(n).strip() + if type_text and type_text not in chunk.referenced_types: + # Clean generic params + if '<' in type_text: + type_text = type_text.split('<')[0] + chunk.referenced_types.append(type_text) + + # Extract return type + if node_type in node_types['return_type'] and not chunk.return_type: + chunk.return_type = get_text(n).strip() + + # Extract type parameters (generics) + if node_type in node_types['type_param']: + param_text = get_text(n).strip() + if param_text and param_text not in chunk.type_parameters: + chunk.type_parameters.append(param_text) + + # Extract variables + if node_type in node_types['variable']: + var_name = extract_identifier(n) + if var_name and var_name not in chunk.variables: + chunk.variables.append(var_name) + + # Recurse into children + for child in n.children: + traverse_for_details(child, depth + 1) + + traverse_for_details(node) + + # Limit list sizes to prevent bloat + chunk.methods = chunk.methods[:30] + chunk.properties = chunk.properties[:30] + chunk.parameters 
= chunk.parameters[:20] + chunk.decorators = chunk.decorators[:10] + chunk.calls = chunk.calls[:50] + chunk.referenced_types = chunk.referenced_types[:30] + chunk.variables = chunk.variables[:30] + chunk.type_parameters = chunk.type_parameters[:10] + + def _find_node_at_position(self, root, start_byte: int, end_byte: int) -> Optional[Any]: + """Find the tree-sitter node at the given byte position.""" + def find(node): + if node.start_byte == start_byte and node.end_byte == end_byte: + return node + for child in node.children: + if child.start_byte <= start_byte and child.end_byte >= end_byte: + result = find(child) + if result: + return result + return None + return find(root) + + def _get_rich_node_types(self, language: str) -> Dict[str, List[str]]: + """Get tree-sitter node types for extracting rich details.""" + # Common patterns across languages + common = { + 'identifier': ['identifier', 'name', 'type_identifier', 'property_identifier'], + 'call': ['call_expression', 'call', 'function_call', 'method_invocation'], + 'type_ref': ['type_identifier', 'generic_type', 'type_annotation', 'type'], + 'type_param': ['type_parameter', 'type_parameters', 'generic_parameter'], + } + + types = { + 'python': { + **common, + 'method': ['function_definition'], + 'property': ['assignment', 'expression_statement'], + 'parameter': ['parameter', 'default_parameter', 'typed_parameter'], + 'decorator': ['decorator'], + 'return_type': ['type'], + 'variable': ['assignment'], + }, + 'java': { + **common, + 'method': ['method_declaration', 'constructor_declaration'], + 'property': ['field_declaration'], + 'parameter': ['formal_parameter', 'spread_parameter'], + 'decorator': ['annotation', 'marker_annotation'], + 'return_type': ['type_identifier', 'generic_type', 'void_type'], + 'variable': ['local_variable_declaration'], + }, + 'javascript': { + **common, + 'method': ['method_definition', 'function_declaration'], + 'property': ['field_definition', 'public_field_definition'], + 'parameter': ['formal_parameters', 'required_parameter'], + 'decorator': ['decorator'], + 'return_type': ['type_annotation'], + 'variable': ['variable_declarator'], + }, + 'typescript': { + **common, + 'method': ['method_definition', 'method_signature', 'function_declaration'], + 'property': ['public_field_definition', 'property_signature'], + 'parameter': ['required_parameter', 'optional_parameter'], + 'decorator': ['decorator'], + 'return_type': ['type_annotation'], + 'variable': ['variable_declarator'], + }, + 'go': { + **common, + 'method': ['method_declaration', 'function_declaration'], + 'property': ['field_declaration'], + 'parameter': ['parameter_declaration'], + 'decorator': [], # Go doesn't have decorators + 'return_type': ['type_identifier', 'pointer_type'], + 'variable': ['short_var_declaration', 'var_declaration'], + }, + 'rust': { + **common, + 'method': ['function_item', 'associated_item'], + 'property': ['field_declaration'], + 'parameter': ['parameter'], + 'decorator': ['attribute_item'], + 'return_type': ['type_identifier', 'generic_type'], + 'variable': ['let_declaration'], + }, + 'c_sharp': { + **common, + 'method': ['method_declaration', 'constructor_declaration'], + 'property': ['property_declaration', 'field_declaration'], + 'parameter': ['parameter'], + 'decorator': ['attribute_list', 'attribute'], + 'return_type': ['predefined_type', 'generic_name'], + 'variable': ['variable_declaration'], + }, + 'php': { + **common, + 'method': ['method_declaration', 'function_definition'], + 'property': ['property_declaration'], 
+ 'parameter': ['simple_parameter'], + 'decorator': ['attribute_list'], + 'return_type': ['named_type', 'union_type'], + 'variable': ['property_declaration', 'simple_variable'], + }, + } + + return types.get(language, { + **common, + 'method': [], + 'property': [], + 'parameter': [], + 'decorator': [], + 'return_type': [], + 'variable': [], + }) + + def _extract_rich_details_from_node( + self, + chunk: ASTChunk, + node: Any, + source_bytes: bytes, + lang_name: str + ) -> None: + """ + Extract rich AST details directly from a tree-sitter node. + Used by traversal-based extraction when we already have the node. + """ + node_types = self._get_rich_node_types(lang_name) + + def get_text(n) -> str: + return source_bytes[n.start_byte:n.end_byte].decode('utf-8', errors='replace') + + def extract_identifier(n) -> Optional[str]: + for child in n.children: + if child.type in node_types['identifier']: + return get_text(child) + return None + + def traverse(n, depth: int = 0): + if depth > 10: + return + + node_type = n.type + + if node_type in node_types['method']: + name = extract_identifier(n) + if name and name not in chunk.methods: + chunk.methods.append(name) + + if node_type in node_types['property']: + name = extract_identifier(n) + if name and name not in chunk.properties: + chunk.properties.append(name) + + if node_type in node_types['parameter']: + name = extract_identifier(n) + if name and name not in chunk.parameters: + chunk.parameters.append(name) + + if node_type in node_types['decorator']: + dec_text = get_text(n).strip() + if dec_text and dec_text not in chunk.decorators: + if dec_text.startswith('@'): + dec_text = dec_text[1:] + if '(' in dec_text: + dec_text = dec_text.split('(')[0] + chunk.decorators.append(dec_text) + + if node_type in node_types['call']: + name = extract_identifier(n) + if name and name not in chunk.calls: + chunk.calls.append(name) + + if node_type in node_types['type_ref']: + type_text = get_text(n).strip() + if type_text and type_text not in chunk.referenced_types: + if '<' in type_text: + type_text = type_text.split('<')[0] + chunk.referenced_types.append(type_text) + + if node_type in node_types['return_type'] and not chunk.return_type: + chunk.return_type = get_text(n).strip() + + if node_type in node_types['type_param']: + param_text = get_text(n).strip() + if param_text and param_text not in chunk.type_parameters: + chunk.type_parameters.append(param_text) + + if node_type in node_types['variable']: + name = extract_identifier(n) + if name and name not in chunk.variables: + chunk.variables.append(name) + + for child in n.children: + traverse(child, depth + 1) + + traverse(node) + + # Limit sizes + chunk.methods = chunk.methods[:30] + chunk.properties = chunk.properties[:30] + chunk.parameters = chunk.parameters[:20] + chunk.decorators = chunk.decorators[:10] + chunk.calls = chunk.calls[:50] + chunk.referenced_types = chunk.referenced_types[:30] + chunk.variables = chunk.variables[:30] + chunk.type_parameters = chunk.type_parameters[:10] + def _process_chunks( self, chunks: List[ASTChunk], @@ -438,9 +841,12 @@ def _process_chunks( metadata = self._build_metadata(ast_chunk, doc.metadata, chunk_counter, len(chunks)) chunk_id = generate_deterministic_id(path, ast_chunk.content, chunk_counter) + # Create embedding-enriched text with semantic context + enriched_text = self._create_embedding_text(ast_chunk.content, metadata) + node = TextNode( id_=chunk_id, - text=ast_chunk.content, + text=enriched_text, metadata=metadata ) nodes.append(node) @@ -455,36 +861,87 
@@ def _split_oversized_chunk( base_metadata: Dict[str, Any], path: str ) -> List[TextNode]: - """Split an oversized chunk using RecursiveCharacterTextSplitter.""" + """ + Split an oversized chunk using RecursiveCharacterTextSplitter. + + IMPORTANT: Splitting an AST chunk loses semantic integrity. + We try to preserve what we can: + - Parent context and primary name are kept (they're still relevant) + - Detailed lists (methods, properties, calls) are NOT copied to sub-chunks + because they describe the whole unit, not the fragment + - A summary of the original unit is prepended to help embeddings + """ splitter = self._get_text_splitter(language) if language else self._default_splitter sub_chunks = splitter.split_text(chunk.content) nodes = [] parent_id = generate_deterministic_id(path, chunk.content, 0) + total_sub = len([s for s in sub_chunks if s and s.strip()]) + # Build a brief summary of the original semantic unit + # This helps embeddings understand context even in fragments + unit_summary_parts = [] + if chunk.semantic_names: + unit_summary_parts.append(f"{chunk.node_type or 'code'}: {chunk.semantic_names[0]}") + if chunk.extends: + unit_summary_parts.append(f"extends {', '.join(chunk.extends[:3])}") + if chunk.implements: + unit_summary_parts.append(f"implements {', '.join(chunk.implements[:3])}") + if chunk.methods: + unit_summary_parts.append(f"has {len(chunk.methods)} methods") + + unit_summary = " | ".join(unit_summary_parts) if unit_summary_parts else None + + sub_idx = 0 for i, sub_chunk in enumerate(sub_chunks): if not sub_chunk or not sub_chunk.strip(): continue - if len(sub_chunk.strip()) < self.min_chunk_size and len(sub_chunks) > 1: + if len(sub_chunk.strip()) < self.min_chunk_size and total_sub > 1: continue + # Build metadata for this fragment + # DO NOT copy detailed lists - they don't apply to fragments metadata = dict(base_metadata) metadata['content_type'] = ContentType.OVERSIZED_SPLIT.value metadata['original_content_type'] = chunk.content_type.value metadata['parent_chunk_id'] = parent_id - metadata['sub_chunk_index'] = i - metadata['total_sub_chunks'] = len(sub_chunks) + metadata['sub_chunk_index'] = sub_idx + metadata['total_sub_chunks'] = total_sub + metadata['start_line'] = chunk.start_line + metadata['end_line'] = chunk.end_line + # Keep parent context - still relevant if chunk.parent_context: metadata['parent_context'] = chunk.parent_context metadata['parent_class'] = chunk.parent_context[-1] + # Keep primary name - this fragment belongs to this unit if chunk.semantic_names: - metadata['semantic_names'] = chunk.semantic_names + metadata['semantic_names'] = chunk.semantic_names[:1] # Just the main name metadata['primary_name'] = chunk.semantic_names[0] - chunk_id = generate_deterministic_id(path, sub_chunk, i) - nodes.append(TextNode(id_=chunk_id, text=sub_chunk, metadata=metadata)) + # Add note that this is a fragment + metadata['is_fragment'] = True + metadata['fragment_of'] = chunk.semantic_names[0] if chunk.semantic_names else None + + # For embedding: prepend fragment context + if unit_summary: + fragment_header = f"[Fragment {sub_idx + 1}/{total_sub} of {unit_summary}]" + enriched_text = f"{fragment_header}\n\n{sub_chunk}" + else: + enriched_text = sub_chunk + + chunk_id = generate_deterministic_id(path, sub_chunk, sub_idx) + nodes.append(TextNode(id_=chunk_id, text=enriched_text, metadata=metadata)) + sub_idx += 1 + + # Log when splitting happens - it's a signal the chunk_size might need adjustment + if nodes: + logger.info( + f"Split oversized 
{chunk.node_type or 'chunk'} " + f"'{chunk.semantic_names[0] if chunk.semantic_names else 'unknown'}' " + f"({len(chunk.content)} chars) into {len(nodes)} fragments" + ) return nodes @@ -545,8 +1002,11 @@ def _split_fallback( if inheritance.get('imports'): metadata['imports'] = inheritance['imports'] + # Create embedding-enriched text with semantic context + enriched_text = self._create_embedding_text(chunk, metadata) + chunk_id = generate_deterministic_id(path, chunk, i) - nodes.append(TextNode(id_=chunk_id, text=chunk, metadata=metadata)) + nodes.append(TextNode(id_=chunk_id, text=enriched_text, metadata=metadata)) return nodes @@ -595,8 +1055,154 @@ def _build_metadata( if chunk.namespace: metadata['namespace'] = chunk.namespace + # --- RICH AST METADATA --- + + if chunk.methods: + metadata['methods'] = chunk.methods + + if chunk.properties: + metadata['properties'] = chunk.properties + + if chunk.parameters: + metadata['parameters'] = chunk.parameters + + if chunk.return_type: + metadata['return_type'] = chunk.return_type + + if chunk.decorators: + metadata['decorators'] = chunk.decorators + + if chunk.modifiers: + metadata['modifiers'] = chunk.modifiers + + if chunk.calls: + metadata['calls'] = chunk.calls + + if chunk.referenced_types: + metadata['referenced_types'] = chunk.referenced_types + + if chunk.variables: + metadata['variables'] = chunk.variables + + if chunk.constants: + metadata['constants'] = chunk.constants + + if chunk.type_parameters: + metadata['type_parameters'] = chunk.type_parameters + return metadata + def _create_embedding_text(self, content: str, metadata: Dict[str, Any]) -> str: + """ + Create embedding-optimized text by prepending concise semantic context. + + Design principles: + 1. Keep it SHORT - long headers can skew embeddings for small code chunks + 2. Avoid redundancy - don't repeat info that's obvious from the code + 3. Clean paths - strip commit hashes and archive prefixes + 4. 
Add VALUE - include info that helps semantic matching + + What we include (selectively): + - Clean file path (without commit/archive prefixes) + - Parent context (for nested structures - very valuable) + - Extends/implements (inheritance is critical for understanding) + - Docstring (helps semantic matching) + - For CLASSES: method count (helps identify scope) + - For METHODS: skip redundant method list + """ + if not self.enrich_embedding_text: + return content + + context_parts = [] + + # Clean file path - remove commit hash prefixes and archive structure + path = metadata.get('path', '') + if path: + path = self._clean_path(path) + context_parts.append(f"File: {path}") + + # Parent context - valuable for nested structures + parent_context = metadata.get('parent_context', []) + if parent_context: + context_parts.append(f"In: {'.'.join(parent_context)}") + + # Clean namespace - strip keyword if present + namespace = metadata.get('namespace', '') + if namespace: + ns_clean = namespace.replace('namespace ', '').replace('package ', '').strip().rstrip(';') + if ns_clean: + context_parts.append(f"Namespace: {ns_clean}") + + # Type relationships - very valuable for understanding code structure + extends = metadata.get('extends', []) + implements = metadata.get('implements', []) + if extends: + context_parts.append(f"Extends: {', '.join(extends[:3])}") + if implements: + context_parts.append(f"Implements: {', '.join(implements[:3])}") + + # For CLASSES: show method/property counts (helps understand scope) + # For METHODS/FUNCTIONS: skip - it's redundant + node_type = metadata.get('node_type', '') + is_container = node_type in ('class', 'interface', 'struct', 'trait', 'enum', 'impl') + + if is_container: + methods = metadata.get('methods', []) + properties = metadata.get('properties', []) + if methods and len(methods) > 1: + # Only show if there are multiple methods + context_parts.append(f"Methods({len(methods)}): {', '.join(methods[:8])}") + if properties and len(properties) > 1: + context_parts.append(f"Fields({len(properties)}): {', '.join(properties[:5])}") + + # Docstring - valuable for semantic matching + docstring = metadata.get('docstring', '') + if docstring: + # Take just the first sentence or 100 chars + brief = docstring.split('.')[0][:100].strip() + if brief: + context_parts.append(f"Desc: {brief}") + + # Build final text - only if we have meaningful context + if context_parts: + context_header = " | ".join(context_parts) + return f"[{context_header}]\n\n{content}" + + return content + + def _clean_path(self, path: str) -> str: + """ + Clean file path for embedding text. 
+ + Removes: + - Commit hash prefixes (e.g., 'owner-repo-abc123def/') + - Archive extraction paths + - Redundant path components + """ + if not path: + return path + + # Split by '/' and look for src/, lib/, app/ etc as anchor points + parts = path.split('/') + + # Common source directory markers + source_markers = {'src', 'lib', 'app', 'source', 'main', 'test', 'tests', 'pkg', 'cmd', 'internal'} + + # Find the first source marker and start from there + for i, part in enumerate(parts): + if part.lower() in source_markers: + return '/'.join(parts[i:]) + + # If no marker found but path has commit-hash-like prefix (40 hex chars or similar) + if parts and len(parts) > 1: + first_part = parts[0] + # Check if first part looks like "owner-repo-commithash" pattern + if '-' in first_part and len(first_part) > 40: + # Skip the first part + return '/'.join(parts[1:]) + + return path + def _get_text_splitter(self, language: Language) -> RecursiveCharacterTextSplitter: """Get language-specific text splitter.""" if language not in self._splitter_cache: diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py index bde5e2bf..72013d78 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/models/config.py @@ -74,10 +74,14 @@ def validate_api_key(cls, v: str) -> str: logger.info(f"OpenRouter API key loaded: {v[:10]}...{v[-4:]}") return v - chunk_size: int = Field(default=800) + # Chunk size for code files + # text-embedding-3-small supports ~8191 tokens (~32K chars) + # 8000 chars keeps most semantic units (classes, functions) intact + chunk_size: int = Field(default=8000) chunk_overlap: int = Field(default=200) - text_chunk_size: int = Field(default=1000) + # Text chunk size for non-code files (markdown, docs) + text_chunk_size: int = Field(default=2000) text_chunk_overlap: int = Field(default=200) base_index_namespace: str = Field(default="code_rag") diff --git a/python-ecosystem/rag-pipeline/src/rag_pipeline/utils/utils.py b/python-ecosystem/rag-pipeline/src/rag_pipeline/utils/utils.py index 4b965cb2..8810113b 100644 --- a/python-ecosystem/rag-pipeline/src/rag_pipeline/utils/utils.py +++ b/python-ecosystem/rag-pipeline/src/rag_pipeline/utils/utils.py @@ -69,6 +69,53 @@ def make_project_namespace(workspace: str, project: str) -> str: return f"{workspace}__{project}".replace("/", "_").replace(".", "_").lower() +def clean_archive_path(path: str) -> str: + """ + Clean archive root prefix from file paths. + + Bitbucket and other VCS archives often create a root folder like: + - 'owner-repo-commitHash/' (Bitbucket) + - 'repo-branch/' (GitHub) + + This function strips that prefix to get clean paths like 'src/file.php'. 
+ + Args: + path: File path potentially with archive prefix + + Returns: + Clean path without archive prefix + """ + if not path: + return path + + parts = Path(path).parts + if len(parts) <= 1: + return path + + first_part = parts[0] + + # Common source directory markers - if first part is one of these, path is already clean + source_markers = {'src', 'lib', 'app', 'source', 'main', 'test', 'tests', + 'pkg', 'cmd', 'internal', 'bin', 'scripts', 'docs'} + if first_part.lower() in source_markers: + return path + + # Check if first part looks like archive root: + # - Contains hyphens (owner-repo-commit pattern) + # - Or is very long (40+ chars for commit hash) + # - Or matches pattern like 'name-hexstring' + looks_like_archive = ( + '-' in first_part and len(first_part) > 20 or # owner-repo-commit + len(first_part) >= 40 or # Just commit hash + (first_part.count('-') >= 2 and any(c.isdigit() for c in first_part)) # Has digits and multiple hyphens + ) + + if looks_like_archive: + return '/'.join(parts[1:]) + + return path + + def should_exclude_file(path: str, excluded_patterns: list[str]) -> bool: """Check if file should be excluded based on patterns.