diff --git a/components/backend/handlers/corrections.go b/components/backend/handlers/corrections.go
new file mode 100644
index 000000000..7fa7741d5
--- /dev/null
+++ b/components/backend/handlers/corrections.go
@@ -0,0 +1,267 @@
+package handlers
+
+import (
+	"log"
+	"net/http"
+	"sync"
+	"time"
+
+	"github.com/gin-gonic/gin"
+)
+
+// --- Feature flag name ---
+
+const correctionsFeatureFlag = "learning-agent-loop" // checked per request; ConfigMap override takes precedence over Unleash
+
+// --- Types ---
+
+// CorrectionEvent represents a single correction event in the buffer.
+type CorrectionEvent struct {
+	SessionName    string    `json:"sessionName"`
+	CorrectionType string    `json:"correctionType"` // one of allowedCorrectionTypes
+	AgentAction    string    `json:"agentAction"`
+	UserCorrection string    `json:"userCorrection"`
+	Target         string    `json:"target,omitempty"` // optional grouping key used by the summary endpoint
+	Source         string    `json:"source"` // one of allowedCorrectionSources
+	Timestamp      string    `json:"timestamp,omitempty"` // client-supplied, opaque; not validated or parsed server-side
+	ReceivedAt     time.Time `json:"receivedAt"` // server receipt time; basis for the eventTTL expiry checks
+}
+
+// CorrectionRequest is the JSON body for POST /corrections.
+type CorrectionRequest struct {
+	SessionName    string `json:"sessionName" binding:"required"`
+	CorrectionType string `json:"correctionType" binding:"required"`
+	AgentAction    string `json:"agentAction" binding:"required"`
+	UserCorrection string `json:"userCorrection" binding:"required"`
+	Target         string `json:"target"` // optional
+	Source         string `json:"source" binding:"required"`
+	Timestamp      string `json:"timestamp"` // optional
+}
+
+// --- Allowed enum values ---
+
+var allowedCorrectionTypes = map[string]bool{
+	"incomplete":   true,
+	"incorrect":    true,
+	"out_of_scope": true,
+	"style":        true,
+}
+
+var allowedCorrectionSources = map[string]bool{
+	"human":  true,
+	"rubric": true,
+	"ui":     true,
+}
+
+// --- Per-project buffer ---
+
+const (
+	maxEventsPerProject = 10000          // FIFO cap per project buffer
+	eventTTL            = 24 * time.Hour // events older than this are hidden from list/summary reads
+)
+
+// projectBuffer is a goroutine-safe FIFO buffer for a single project.
+type projectBuffer struct {
+	mu     sync.RWMutex // guards events
+	events []CorrectionEvent
+}
+
+// append adds an event, evicting the oldest if the buffer is full.
+func (b *projectBuffer) append(event CorrectionEvent) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	if len(b.events) >= maxEventsPerProject {
+		// FIFO eviction: drop the oldest event
+		b.events = b.events[1:] // NOTE(review): reslice keeps the old backing array reachable until a later append reallocates
+	}
+	b.events = append(b.events, event)
+}
+
+// list returns non-expired events, optionally filtered by session and target.
+// Empty session/target arguments mean "no filter".
+func (b *projectBuffer) list(session, target string) []CorrectionEvent {
+	b.mu.RLock()
+	defer b.mu.RUnlock()
+
+	cutoff := time.Now().Add(-eventTTL)
+	result := make([]CorrectionEvent, 0) // non-nil so the JSON response encodes [] rather than null
+	for _, e := range b.events {
+		if e.ReceivedAt.Before(cutoff) {
+			continue // NOTE(review): expired events are skipped on read but never deleted; they hold a slot until FIFO eviction
+		}
+		if session != "" && e.SessionName != session {
+			continue
+		}
+		if target != "" && e.Target != target {
+			continue
+		}
+		result = append(result, e)
+	}
+	return result
+}
+
+// summary returns correction counts grouped by target, optionally filtered.
+// Events with no target are counted under the "(none)" key.
+func (b *projectBuffer) summary(target string) map[string]int {
+	b.mu.RLock()
+	defer b.mu.RUnlock()
+
+	cutoff := time.Now().Add(-eventTTL)
+	counts := make(map[string]int)
+	for _, e := range b.events {
+		if e.ReceivedAt.Before(cutoff) {
+			continue // same read-time TTL filter as list; no pruning happens here either
+		}
+		if target != "" && e.Target != target {
+			continue
+		}
+		key := e.Target
+		if key == "" {
+			key = "(none)" // sentinel bucket for untargeted corrections
+		}
+		counts[key]++
+	}
+	return counts
+}
+
+// --- Global buffer registry ---
+
+var (
+	buffersMu sync.RWMutex // guards buffers
+	buffers   = make(map[string]*projectBuffer)
+)
+
+func getProjectBuffer(project string) *projectBuffer { // lazily creates the per-project buffer
+	buffersMu.RLock()
+	buf, ok := buffers[project]
+	buffersMu.RUnlock()
+	if ok {
+		return buf // fast path: buffer already exists, read lock only
+	}
+
+	buffersMu.Lock()
+	defer buffersMu.Unlock()
+	// Double-check after acquiring write lock
+	if buf, ok = buffers[project]; ok {
+		return buf // another goroutine created it between the two lock acquisitions
+	}
+	buf = &projectBuffer{}
+	buffers[project] = buf
+	return buf
+}
+
+// ResetCorrectionsBuffers clears all buffers. Exported for tests only.
+func ResetCorrectionsBuffers() {
+	buffersMu.Lock()
+	defer buffersMu.Unlock()
+	buffers = make(map[string]*projectBuffer) // old buffers become garbage once callers drop their references
+}
+
+// isCorrectionsEnabled checks workspace ConfigMap override first, then
+// falls back to the Unleash SDK. This mirrors the pattern used by
+// isRunnerEnabledWithOverrides in runner_types.go.
+// Returns false (feature off) when no k8s client is available for the request.
+func isCorrectionsEnabled(c *gin.Context) bool {
+	namespace := sanitizeParam(c.Param("projectName"))
+	reqK8s, _ := GetK8sClientsForRequest(c) // error ignored; the nil check below covers failure — TODO(review) confirm intent
+	if reqK8s == nil {
+		return false
+	}
+
+	// Check workspace ConfigMap override first
+	overrides, err := getWorkspaceOverrides(c.Request.Context(), reqK8s, namespace)
+	if err == nil && overrides != nil {
+		if val, exists := overrides[correctionsFeatureFlag]; exists {
+			return val == "true" // any other value (including "false") disables via override
+		}
+	}
+
+	// Fall back to Unleash SDK
+	return FeatureEnabledForRequest(c, correctionsFeatureFlag)
+}
+
+// --- Handlers ---
+
+// PostCorrection handles POST /api/projects/:projectName/corrections
+// Validates the body, stamps ReceivedAt, and appends to the per-project buffer.
+func PostCorrection(c *gin.Context) {
+	project := sanitizeParam(c.Param("projectName"))
+
+	// Feature flag gate
+	if !isCorrectionsEnabled(c) {
+		c.JSON(http.StatusNotFound, gin.H{"error": "Feature not enabled"}) // 404 when the flag is off
+		return
+	}
+
+	var req CorrectionRequest
+	if err := c.ShouldBindJSON(&req); err != nil { // enforces the binding:"required" tags
+		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body: " + err.Error()})
+		return
+	}
+
+	// Validate correctionType enum
+	if !allowedCorrectionTypes[req.CorrectionType] {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"error": "Invalid correctionType. Must be one of: incomplete, incorrect, out_of_scope, style",
+		})
+		return
+	}
+
+	// Validate source enum
+	if !allowedCorrectionSources[req.Source] {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"error": "Invalid source. Must be one of: human, rubric, ui",
+		})
+		return
+	}
+
+	event := CorrectionEvent{
+		SessionName:    req.SessionName,
+		CorrectionType: req.CorrectionType,
+		AgentAction:    req.AgentAction,
+		UserCorrection: req.UserCorrection,
+		Target:         req.Target,
+		Source:         req.Source,
+		Timestamp:      req.Timestamp,
+		ReceivedAt:     time.Now(), // server clock, not the client Timestamp; drives TTL expiry
+	}
+
+	buf := getProjectBuffer(project)
+	buf.append(event)
+
+	log.Printf("Correction received: project=%s session=%s type=%s target=%s source=%s",
+		project, req.SessionName, req.CorrectionType, req.Target, req.Source)
+
+	c.JSON(http.StatusCreated, gin.H{"message": "Correction recorded"})
+}
+
+// ListCorrections handles GET /api/projects/:projectName/corrections
+// Optional query params: session, target (exact-match filters).
+func ListCorrections(c *gin.Context) {
+	// Feature flag gate
+	if !isCorrectionsEnabled(c) {
+		c.JSON(http.StatusNotFound, gin.H{"error": "Feature not enabled"})
+		return
+	}
+
+	project := sanitizeParam(c.Param("projectName"))
+	session := c.Query("session")
+	target := c.Query("target")
+
+	buf := getProjectBuffer(project)
+	events := buf.list(session, target)
+
+	c.JSON(http.StatusOK, gin.H{"corrections": events})
+}
+
+// GetCorrectionsSummary handles GET /api/projects/:projectName/corrections/summary
+// Optional query param: target (exact-match filter). Returns counts keyed by target.
+func GetCorrectionsSummary(c *gin.Context) {
+	// Feature flag gate
+	if !isCorrectionsEnabled(c) {
+		c.JSON(http.StatusNotFound, gin.H{"error": "Feature not enabled"})
+		return
+	}
+
+	project := sanitizeParam(c.Param("projectName"))
+	target := c.Query("target")
+
+	buf := getProjectBuffer(project)
+	counts := buf.summary(target)
+
+	c.JSON(http.StatusOK, gin.H{"summary": counts})
+}
diff --git a/components/backend/handlers/corrections_test.go b/components/backend/handlers/corrections_test.go
new file mode 100644
index 000000000..df2e75372
--- /dev/null
+++ b/components/backend/handlers/corrections_test.go
@@ -0,0 +1,701 @@
+//go:build test
+
+package handlers
+
+import (
+	"context"
+	"net/http"
+	"time"
+
+	test_constants "ambient-code-backend/tests/constants"
+	
"ambient-code-backend/tests/logger" + "ambient-code-backend/tests/test_utils" + + "github.com/gin-gonic/gin" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +var _ = Describe("Corrections Pipeline Handler", Label(test_constants.LabelUnit, test_constants.LabelHandlers, test_constants.LabelCorrections), func() { + var ( + localHTTP *test_utils.HTTPTestUtils + localK8s *test_utils.K8sTestUtils + testToken string + ) + + BeforeEach(func() { + logger.Log("Setting up Corrections Pipeline test") + + localK8s = test_utils.NewK8sTestUtils(false, "test-project") + SetupHandlerDependencies(localK8s) + + localHTTP = test_utils.NewHTTPTestUtils() + + // Create namespace + role and mint token + ctx := context.Background() + _, err := localK8s.K8sClient.CoreV1().Namespaces().Create(ctx, &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{Name: "test-project"}, + }, metav1.CreateOptions{}) + if err != nil && !errors.IsAlreadyExists(err) { + Expect(err).NotTo(HaveOccurred()) + } + _, err = localK8s.CreateTestRole(ctx, "test-project", "test-corrections-role", + []string{"get", "list", "create", "update", "delete", "patch"}, "*", "") + Expect(err).NotTo(HaveOccurred()) + + token, _, err := localHTTP.SetValidTestToken( + localK8s, "test-project", + []string{"get", "list", "create", "update", "delete", "patch"}, + "*", "", "test-corrections-role", + ) + Expect(err).NotTo(HaveOccurred()) + testToken = token + + // Enable the feature flag via ConfigMap override + cm := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: FeatureFlagOverridesConfigMap, + Namespace: "test-project", + }, + Data: map[string]string{ + correctionsFeatureFlag: "true", + }, + } + _, err = localK8s.K8sClient.CoreV1().ConfigMaps("test-project").Create(ctx, cm, metav1.CreateOptions{}) + if err != nil && !errors.IsAlreadyExists(err) { + Expect(err).NotTo(HaveOccurred()) + } + + 
// Reset the in-memory buffer between tests + ResetCorrectionsBuffers() + }) + + AfterEach(func() { + if localK8s != nil { + _ = localK8s.K8sClient.CoreV1().Namespaces().Delete(context.Background(), "test-project", metav1.DeleteOptions{}) + } + }) + + // --------------------------------------------------------------- + // POST /corrections + // --------------------------------------------------------------- + Context("POST /corrections", func() { + It("Should accept a valid correction and return 201", func() { + body := map[string]interface{}{ + "sessionName": "session-1", + "correctionType": "incorrect", + "agentAction": "Used if/else", + "userCorrection": "Should have used try/except", + "target": "my-workflow", + "source": "human", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + PostCorrection(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusCreated) + logger.Log("POST /corrections returned 201 for valid input") + }) + + It("Should reject missing correctionType with 400", func() { + body := map[string]interface{}{ + "sessionName": "session-1", + "agentAction": "test", + "userCorrection": "test", + "source": "human", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + PostCorrection(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusBadRequest) + logger.Log("POST /corrections rejected missing correctionType") + }) + + It("Should reject invalid correctionType with 400", func() { + body := map[string]interface{}{ + "sessionName": "session-1", + "correctionType": "bogus", + "agentAction": "test", + "userCorrection": "test", + "source": "human", + } + ginCtx := localHTTP.CreateTestGinContext("POST", 
"/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + PostCorrection(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusBadRequest) + localHTTP.AssertErrorMessage("Invalid correctionType") + logger.Log("POST /corrections rejected invalid correctionType") + }) + + It("Should reject invalid source with 400", func() { + body := map[string]interface{}{ + "sessionName": "session-1", + "correctionType": "incorrect", + "agentAction": "test", + "userCorrection": "test", + "source": "invalid-source", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + PostCorrection(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusBadRequest) + localHTTP.AssertErrorMessage("Invalid source") + logger.Log("POST /corrections rejected invalid source") + }) + + It("Should accept source=ui for frontend corrections", func() { + body := map[string]interface{}{ + "sessionName": "session-1", + "correctionType": "style", + "agentAction": "Used wrong pattern", + "userCorrection": "Use the standard pattern", + "source": "ui", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + PostCorrection(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusCreated) + logger.Log("POST /corrections accepted source=ui") + }) + + It("Should accept source=rubric", func() { + body := map[string]interface{}{ + "sessionName": "session-1", + "correctionType": "style", + "agentAction": "Originality scored low", + "userCorrection": "Use fresh humor", + "source": "rubric", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = 
gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + PostCorrection(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusCreated) + logger.Log("POST /corrections accepted source=rubric") + }) + + It("Should accept empty target", func() { + body := map[string]interface{}{ + "sessionName": "session-1", + "correctionType": "incomplete", + "agentAction": "Missed a step", + "userCorrection": "Include step 3", + "source": "human", + "target": "", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + PostCorrection(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusCreated) + logger.Log("POST /corrections accepted empty target") + }) + + It("Should require authentication", func() { + restore := WithAuthCheckEnabled() + defer restore() + + body := map[string]interface{}{ + "sessionName": "session-1", + "correctionType": "incorrect", + "agentAction": "test", + "userCorrection": "test", + "source": "human", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + // Don't set auth header + + PostCorrection(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusNotFound) + logger.Log("POST /corrections requires authentication") + }) + }) + + // --------------------------------------------------------------- + // GET /corrections + // --------------------------------------------------------------- + Context("GET /corrections", func() { + It("Should return empty list when no corrections exist", func() { + ginCtx := localHTTP.CreateTestGinContext("GET", "/api/projects/test-project/corrections", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + ListCorrections(ginCtx) + + 
localHTTP.AssertHTTPStatus(http.StatusOK) + + var response map[string]interface{} + localHTTP.GetResponseJSON(&response) + corrections := response["corrections"].([]interface{}) + Expect(corrections).To(HaveLen(0)) + logger.Log("GET /corrections returns empty list when no corrections") + }) + + It("Should return posted corrections", func() { + // Post a correction first + body := map[string]interface{}{ + "sessionName": "session-1", + "correctionType": "incorrect", + "agentAction": "Did X", + "userCorrection": "Should do Y", + "target": "workflow-a", + "source": "human", + } + postCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + postCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + PostCorrection(postCtx) + + // Now list + localHTTP = test_utils.NewHTTPTestUtils() + ginCtx := localHTTP.CreateTestGinContext("GET", "/api/projects/test-project/corrections", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + ListCorrections(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusOK) + + var response map[string]interface{} + localHTTP.GetResponseJSON(&response) + corrections := response["corrections"].([]interface{}) + Expect(corrections).To(HaveLen(1)) + + first := corrections[0].(map[string]interface{}) + Expect(first["sessionName"]).To(Equal("session-1")) + Expect(first["correctionType"]).To(Equal("incorrect")) + Expect(first["target"]).To(Equal("workflow-a")) + Expect(first["source"]).To(Equal("human")) + logger.Log("GET /corrections returns posted corrections") + }) + + It("Should filter by session query param", func() { + // Post corrections for different sessions + for _, sess := range []string{"session-a", "session-b", "session-a"} { + localHTTP = test_utils.NewHTTPTestUtils() + body := map[string]interface{}{ + "sessionName": sess, "correctionType": "incorrect", + "agentAction": "test", 
"userCorrection": "test", "source": "human", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + PostCorrection(ginCtx) + } + + // Filter for session-a + localHTTP = test_utils.NewHTTPTestUtils() + ginCtx := localHTTP.CreateTestGinContext("GET", "/api/projects/test-project/corrections?session=session-a", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + ginCtx.Request.URL.RawQuery = "session=session-a" + localHTTP.SetAuthHeader(testToken) + + ListCorrections(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusOK) + var response map[string]interface{} + localHTTP.GetResponseJSON(&response) + corrections := response["corrections"].([]interface{}) + Expect(corrections).To(HaveLen(2)) + logger.Log("GET /corrections filters by session") + }) + + It("Should filter by target query param", func() { + // Post corrections for different targets + for _, t := range []string{"wf-a", "wf-b", "wf-a"} { + localHTTP = test_utils.NewHTTPTestUtils() + body := map[string]interface{}{ + "sessionName": "s1", "correctionType": "style", + "agentAction": "test", "userCorrection": "test", + "target": t, "source": "human", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + PostCorrection(ginCtx) + } + + // Filter for wf-a + localHTTP = test_utils.NewHTTPTestUtils() + ginCtx := localHTTP.CreateTestGinContext("GET", "/api/projects/test-project/corrections?target=wf-a", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + ginCtx.Request.URL.RawQuery = "target=wf-a" + localHTTP.SetAuthHeader(testToken) + + ListCorrections(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusOK) + var response map[string]interface{} + 
localHTTP.GetResponseJSON(&response) + corrections := response["corrections"].([]interface{}) + Expect(corrections).To(HaveLen(2)) + logger.Log("GET /corrections filters by target") + }) + + It("Should return corrections from both sources", func() { + // Post from runner (human) + localHTTP = test_utils.NewHTTPTestUtils() + body := map[string]interface{}{ + "sessionName": "s1", "correctionType": "incorrect", + "agentAction": "test", "userCorrection": "test", + "source": "human", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + PostCorrection(ginCtx) + + // Post from UI + localHTTP = test_utils.NewHTTPTestUtils() + body = map[string]interface{}{ + "sessionName": "s1", "correctionType": "style", + "agentAction": "wrong approach", "userCorrection": "use standard pattern", + "source": "ui", + } + ginCtx = localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + PostCorrection(ginCtx) + + // List all + localHTTP = test_utils.NewHTTPTestUtils() + ginCtx = localHTTP.CreateTestGinContext("GET", "/api/projects/test-project/corrections", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + ListCorrections(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusOK) + var response map[string]interface{} + localHTTP.GetResponseJSON(&response) + corrections := response["corrections"].([]interface{}) + Expect(corrections).To(HaveLen(2)) + + // Verify both sources are present + sources := make(map[string]bool) + for _, c := range corrections { + evt := c.(map[string]interface{}) + sources[evt["source"].(string)] = true + } + Expect(sources).To(HaveKey("human")) + Expect(sources).To(HaveKey("ui")) + logger.Log("GET 
/corrections returns corrections from both sources") + }) + + It("Should require authentication", func() { + restore := WithAuthCheckEnabled() + defer restore() + + ginCtx := localHTTP.CreateTestGinContext("GET", "/api/projects/test-project/corrections", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + + ListCorrections(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusNotFound) + logger.Log("GET /corrections requires authentication") + }) + }) + + // --------------------------------------------------------------- + // GET /corrections/summary + // --------------------------------------------------------------- + Context("GET /corrections/summary", func() { + It("Should return empty summary when no corrections", func() { + ginCtx := localHTTP.CreateTestGinContext("GET", "/api/projects/test-project/corrections/summary", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + GetCorrectionsSummary(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusOK) + var response map[string]interface{} + localHTTP.GetResponseJSON(&response) + summary := response["summary"].(map[string]interface{}) + Expect(summary).To(BeEmpty()) + logger.Log("GET /corrections/summary returns empty when no corrections") + }) + + It("Should return counts grouped by target", func() { + // Post corrections for different targets + targets := []string{"workflow-a", "workflow-a", "workflow-a", "repo-b", "repo-b"} + for _, t := range targets { + localHTTP = test_utils.NewHTTPTestUtils() + body := map[string]interface{}{ + "sessionName": "s1", "correctionType": "incorrect", + "agentAction": "test", "userCorrection": "test", + "target": t, "source": "human", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + PostCorrection(ginCtx) + } + + localHTTP = 
test_utils.NewHTTPTestUtils() + ginCtx := localHTTP.CreateTestGinContext("GET", "/api/projects/test-project/corrections/summary", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + GetCorrectionsSummary(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusOK) + var response map[string]interface{} + localHTTP.GetResponseJSON(&response) + summary := response["summary"].(map[string]interface{}) + Expect(summary["workflow-a"]).To(BeNumerically("==", 3)) + Expect(summary["repo-b"]).To(BeNumerically("==", 2)) + logger.Log("GET /corrections/summary returns grouped counts") + }) + + It("Should filter summary by target query param", func() { + // Post corrections + for _, t := range []string{"workflow-a", "workflow-a", "repo-b"} { + localHTTP = test_utils.NewHTTPTestUtils() + body := map[string]interface{}{ + "sessionName": "s1", "correctionType": "incorrect", + "agentAction": "test", "userCorrection": "test", + "target": t, "source": "human", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + PostCorrection(ginCtx) + } + + localHTTP = test_utils.NewHTTPTestUtils() + ginCtx := localHTTP.CreateTestGinContext("GET", "/api/projects/test-project/corrections/summary?target=workflow-a", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + ginCtx.Request.URL.RawQuery = "target=workflow-a" + localHTTP.SetAuthHeader(testToken) + + GetCorrectionsSummary(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusOK) + var response map[string]interface{} + localHTTP.GetResponseJSON(&response) + summary := response["summary"].(map[string]interface{}) + Expect(summary).To(HaveLen(1)) + Expect(summary["workflow-a"]).To(BeNumerically("==", 2)) + logger.Log("GET /corrections/summary filters by target") + }) + + It("Should group empty targets 
under (none)", func() { + localHTTP = test_utils.NewHTTPTestUtils() + body := map[string]interface{}{ + "sessionName": "s1", "correctionType": "style", + "agentAction": "test", "userCorrection": "test", + "target": "", "source": "human", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + PostCorrection(ginCtx) + + localHTTP = test_utils.NewHTTPTestUtils() + ginCtx = localHTTP.CreateTestGinContext("GET", "/api/projects/test-project/corrections/summary", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + GetCorrectionsSummary(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusOK) + var response map[string]interface{} + localHTTP.GetResponseJSON(&response) + summary := response["summary"].(map[string]interface{}) + Expect(summary["(none)"]).To(BeNumerically("==", 1)) + logger.Log("GET /corrections/summary groups empty targets under (none)") + }) + + It("Should require authentication", func() { + restore := WithAuthCheckEnabled() + defer restore() + + ginCtx := localHTTP.CreateTestGinContext("GET", "/api/projects/test-project/corrections/summary", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + + GetCorrectionsSummary(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusNotFound) + logger.Log("GET /corrections/summary requires authentication") + }) + }) + + // --------------------------------------------------------------- + // Feature flag gating + // --------------------------------------------------------------- + Context("Feature Flag Gating", func() { + It("Should return 404 for POST when feature flag is disabled", func() { + // Delete the ConfigMap override to disable the flag + ctx := context.Background() + _ = localK8s.K8sClient.CoreV1().ConfigMaps("test-project").Delete(ctx, 
FeatureFlagOverridesConfigMap, metav1.DeleteOptions{}) + + body := map[string]interface{}{ + "sessionName": "s1", "correctionType": "incorrect", + "agentAction": "test", "userCorrection": "test", "source": "human", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + PostCorrection(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusNotFound) + logger.Log("POST /corrections returns 404 when flag disabled") + }) + + It("Should return 404 for GET when feature flag is disabled", func() { + ctx := context.Background() + _ = localK8s.K8sClient.CoreV1().ConfigMaps("test-project").Delete(ctx, FeatureFlagOverridesConfigMap, metav1.DeleteOptions{}) + + ginCtx := localHTTP.CreateTestGinContext("GET", "/api/projects/test-project/corrections", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + ListCorrections(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusNotFound) + logger.Log("GET /corrections returns 404 when flag disabled") + }) + + It("Should return 404 for summary when feature flag is disabled", func() { + ctx := context.Background() + _ = localK8s.K8sClient.CoreV1().ConfigMaps("test-project").Delete(ctx, FeatureFlagOverridesConfigMap, metav1.DeleteOptions{}) + + ginCtx := localHTTP.CreateTestGinContext("GET", "/api/projects/test-project/corrections/summary", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + + GetCorrectionsSummary(ginCtx) + + localHTTP.AssertHTTPStatus(http.StatusNotFound) + logger.Log("GET /corrections/summary returns 404 when flag disabled") + }) + }) + + // --------------------------------------------------------------- + // Buffer behavior + // --------------------------------------------------------------- + Context("Buffer Behavior", func() { + 
It("Should evict oldest events when buffer is full", func() { + buf := getProjectBuffer("eviction-test") + + // Fill buffer beyond max + for i := 0; i < maxEventsPerProject+5; i++ { + buf.append(CorrectionEvent{ + SessionName: "s1", + CorrectionType: "incorrect", + AgentAction: "test", + UserCorrection: "test", + Source: "human", + ReceivedAt: time.Now(), + }) + } + + buf.mu.RLock() + Expect(len(buf.events)).To(Equal(maxEventsPerProject)) + buf.mu.RUnlock() + logger.Log("Buffer evicts oldest events at max capacity") + }) + + It("Should not return expired events", func() { + buf := getProjectBuffer("expiry-test") + + // Add an event with a ReceivedAt in the past (>24h ago) + buf.mu.Lock() + buf.events = append(buf.events, CorrectionEvent{ + SessionName: "old-session", + CorrectionType: "incorrect", + AgentAction: "old action", + UserCorrection: "old correction", + Source: "human", + ReceivedAt: time.Now().Add(-25 * time.Hour), + }) + buf.mu.Unlock() + + // Add a fresh event + buf.append(CorrectionEvent{ + SessionName: "new-session", + CorrectionType: "style", + AgentAction: "new action", + UserCorrection: "new correction", + Source: "human", + ReceivedAt: time.Now(), + }) + + events := buf.list("", "") + Expect(events).To(HaveLen(1)) + Expect(events[0].SessionName).To(Equal("new-session")) + logger.Log("Buffer excludes expired events from list") + }) + + It("Should not count expired events in summary", func() { + buf := getProjectBuffer("expiry-summary-test") + + // Add an expired event + buf.mu.Lock() + buf.events = append(buf.events, CorrectionEvent{ + SessionName: "old", + CorrectionType: "incorrect", + Target: "wf-a", + Source: "human", + ReceivedAt: time.Now().Add(-25 * time.Hour), + }) + buf.mu.Unlock() + + // Add a fresh event + buf.append(CorrectionEvent{ + SessionName: "new", + CorrectionType: "style", + Target: "wf-a", + Source: "human", + ReceivedAt: time.Now(), + }) + + counts := buf.summary("") + Expect(counts["wf-a"]).To(Equal(1)) + logger.Log("Buffer 
excludes expired events from summary") + }) + + It("Should isolate corrections between projects", func() { + // Post to project test-project + body := map[string]interface{}{ + "sessionName": "s1", "correctionType": "incorrect", + "agentAction": "test", "userCorrection": "test", + "source": "human", + } + ginCtx := localHTTP.CreateTestGinContext("POST", "/api/projects/test-project/corrections", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + localHTTP.SetAuthHeader(testToken) + PostCorrection(ginCtx) + + // Check another project buffer is empty + otherBuf := getProjectBuffer("other-project") + events := otherBuf.list("", "") + Expect(events).To(HaveLen(0)) + logger.Log("Corrections are isolated between projects") + }) + }) +}) diff --git a/components/backend/handlers/extraction.go b/components/backend/handlers/extraction.go new file mode 100644 index 000000000..78175234c --- /dev/null +++ b/components/backend/handlers/extraction.go @@ -0,0 +1,858 @@ +// Package handlers: post-session insight extraction. +// +// When a session's run finishes (RUN_FINISHED or RUN_ERROR event), the +// backend optionally runs a lightweight LLM extraction pass against the +// session transcript and writes candidate insights as markdown files on +// a new branch, opening a draft PR for human review. +// +// Gated behind the "learning-agent-loop" feature flag. +// Configuration is read from .ambient/config.json in the workspace repo. +package handlers + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "regexp" + "sort" + "strings" + "time" + + "ambient-code-backend/types" + + "github.com/anthropics/anthropic-sdk-go" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" +) + +// LoadEventsForExtraction loads AG-UI events for a session. +// Set by the websocket package at init to avoid circular imports. 
+var LoadEventsForExtraction func(sessionName string) []map[string]interface{}
+
+// slugRegex is compiled once at package scope for use in slugify();
+// it matches every run of characters that is not a lowercase letter or digit.
+var slugRegex = regexp.MustCompile(`[^a-z0-9]+`)
+
+// ─── Extraction status constants ────────────────────────────────────
+
+// Lifecycle states for a post-session extraction run.
+const (
+	ExtractionStatusPending        = "pending"
+	ExtractionStatusRunning        = "running"
+	ExtractionStatusCompleted      = "completed"
+	ExtractionStatusSkipped        = "skipped"
+	ExtractionStatusFailed         = "failed"
+	ExtractionStatusPartialFailure = "partial-failure"
+)
+
+// ─── Default configuration ──────────────────────────────────────────
+
+// Defaults applied by parseExtractionConfig when .ambient/config.json
+// leaves a field unset or invalid.
+const (
+	// NOTE(review): confirm these model IDs are valid for the target
+	// Anthropic / Vertex endpoints — they are not verifiable from here.
+	defaultExtractionModel       = "claude-haiku-4-20250414"
+	defaultExtractionModelVertex = "claude-haiku-4@20250414"
+	// Upper bound on insight candidates kept per session (see rankAndCap).
+	defaultMaxMemoriesPerSession = 5
+	// Sessions with fewer user turns than this are skipped (see countUserTurns).
+	defaultMinTurnThreshold = 5
+	// Timeout for the extraction LLM API call.
+	extractionAPITimeout = 30 * time.Second
+	// Transcript size cap, in bytes, enforced by buildTranscriptText.
+	maxTranscriptChars = 50000
+)
+
+// ─── Configuration types ────────────────────────────────────────────
+
+// ExtractionConfig holds the extraction settings from .ambient/config.json.
+type ExtractionConfig struct {
+	// Enabled gates the extraction pass; false (or absent) disables it.
+	Enabled bool `json:"enabled"`
+	// Model is the LLM model ID; defaulted when empty.
+	Model string `json:"model"`
+	// MaxMemoriesPerSession caps insights kept per session; defaulted when <= 0.
+	MaxMemoriesPerSession int `json:"maxMemoriesPerSession"`
+	// MinTurnThreshold is the minimum user-turn count; defaulted when <= 0.
+	MinTurnThreshold int `json:"minTurnThreshold"`
+}
+
+// LearningConfig holds the top-level learning settings.
+type LearningConfig struct {
+	Enabled    bool              `json:"enabled"`
+	Extraction *ExtractionConfig `json:"extraction"`
+}
+
+// AmbientConfig represents the .ambient/config.json file.
+type AmbientConfig struct {
+	Learning *LearningConfig `json:"learning"`
+}
+
+// InsightCandidate represents a single extracted insight from the LLM.
+type InsightCandidate struct {
+	Title      string  `json:"title"`
+	Content    string  `json:"content"`
+	Type       string  `json:"type"`       // "correction" or "pattern"
+	Confidence float64 `json:"confidence"` // 0.0 - 1.0
+}
+
+// ─── Config parsing ─────────────────────────────────────────────────
+
+// parseExtractionConfig parses extraction config from raw JSON bytes.
+// Returns nil if extraction is not enabled or not configured. +func parseExtractionConfig(data []byte) *ExtractionConfig { + var cfg AmbientConfig + if err := json.Unmarshal(data, &cfg); err != nil { + log.Printf("Extraction: failed to parse .ambient/config.json: %v", err) + return nil + } + if cfg.Learning == nil || !cfg.Learning.Enabled { + return nil + } + if cfg.Learning.Extraction == nil || !cfg.Learning.Extraction.Enabled { + return nil + } + ext := cfg.Learning.Extraction + + // Apply defaults + if ext.Model == "" { + ext.Model = defaultExtractionModel + } + if ext.MaxMemoriesPerSession <= 0 { + ext.MaxMemoriesPerSession = defaultMaxMemoriesPerSession + } + if ext.MinTurnThreshold <= 0 { + ext.MinTurnThreshold = defaultMinTurnThreshold + } + return ext +} + +// ─── Transcript helpers ───────────────────────────────────────────── + +// countUserTurns counts the number of user messages in the event log. +// Prefers MESSAGES_SNAPSHOT (compacted sessions), falls back to +// counting TEXT_MESSAGE_START events with role=user. 
+func countUserTurns(events []map[string]interface{}) int {
+	// Walk backwards so the most recent MESSAGES_SNAPSHOT wins; it
+	// reflects the compacted conversation state.
+	for i := len(events) - 1; i >= 0; i-- {
+		if t, _ := events[i]["type"].(string); t != types.EventTypeMessagesSnapshot {
+			continue
+		}
+		messages, ok := events[i]["messages"].([]interface{})
+		if !ok {
+			// Malformed snapshot: keep scanning for an earlier usable one.
+			continue
+		}
+		total := 0
+		for _, raw := range messages {
+			m, ok := raw.(map[string]interface{})
+			if !ok {
+				continue
+			}
+			if role, _ := m["role"].(string); role == types.RoleUser {
+				total++
+			}
+		}
+		return total
+	}
+
+	// No snapshot found: count streaming user-message starts instead.
+	total := 0
+	for _, evt := range events {
+		if t, _ := evt["type"].(string); t != types.EventTypeTextMessageStart {
+			continue
+		}
+		if role, _ := evt["role"].(string); role == types.RoleUser {
+			total++
+		}
+	}
+	return total
+}
+
+// buildTranscriptText extracts a compact text transcript from AG-UI events.
+// Prefers MESSAGES_SNAPSHOT for compacted sessions. Truncates to maxTranscriptChars.
+func buildTranscriptText(events []map[string]interface{}) string {
+	var sb strings.Builder
+
+	// Walk backwards: the most recent MESSAGES_SNAPSHOT reflects the
+	// compacted conversation state, so only that one is rendered.
+	for i := len(events) - 1; i >= 0; i-- {
+		evt := events[i]
+		eventType, _ := evt["type"].(string)
+		if eventType != types.EventTypeMessagesSnapshot {
+			continue
+		}
+		messages, ok := evt["messages"].([]interface{})
+		if !ok {
+			// Malformed snapshot: keep scanning for an earlier usable one.
+			continue
+		}
+		for _, msg := range messages {
+			m, ok := msg.(map[string]interface{})
+			if !ok {
+				continue
+			}
+			role, _ := m["role"].(string)
+			content, _ := m["content"].(string)
+			if role == "" || content == "" {
+				continue
+			}
+			// Skip system/developer messages (not useful for extraction)
+			if role == types.RoleSystem || role == types.RoleDeveloper {
+				continue
+			}
+			fmt.Fprintf(&sb, "[%s]: %s\n\n", role, content)
+		}
+		break
+	}
+
+	text := sb.String()
+	if len(text) > maxTranscriptChars {
+		// Truncate on a rune boundary: a plain byte slice could split a
+		// multi-byte UTF-8 sequence and send invalid UTF-8 to the model.
+		// Bytes of the form 10xxxxxx are UTF-8 continuation bytes, so
+		// back the cut position up past them.
+		cut := maxTranscriptChars
+		for cut > 0 && text[cut]&0xC0 == 0x80 {
+			cut--
+		}
+		text = text[:cut] + "\n\n[transcript truncated]"
+	}
+	return text
+}
+
+// ─── LLM extraction ─────────────────────────────────────────────────
+
+const extractionPrompt = `You are an expert at identifying reusable engineering knowledge from coding session transcripts.
+
+Analyze the following transcript from an AI-assisted coding session. Extract reusable knowledge that would help future sessions avoid mistakes or follow better patterns.
+
+Focus on:
+- CORRECTIONS: Mistakes that were made and corrected. Things to avoid in the future.
+- PATTERNS: Conventions, idioms, or approaches that worked well and should be repeated.
+
+Ignore:
+- Session-specific details (file names, variable names, specific bugs)
+- Obvious or trivial knowledge
+- Anything that wouldn't generalize to other sessions
+
+Return a JSON array of candidates. Each candidate must have:
+- "title": A short descriptive title (max 80 chars)
+- "content": The reusable knowledge as markdown (2-5 sentences)
+- "type": Either "correction" or "pattern"
+- "confidence": A float from 0.0 to 1.0 indicating how reusable this knowledge is
+
+Return ONLY the JSON array, no markdown fences, no explanation. If nothing is worth extracting, return an empty array [].`
+
+// callExtractionModel sends the transcript to the extraction LLM and
+// returns parsed candidates. The caller is expected to bound the call
+// via ctx (e.g. extractionAPITimeout).
+func callExtractionModel(ctx context.Context, client anthropic.Client, transcript, modelName string) ([]InsightCandidate, error) {
+	message, err := client.Messages.New(ctx, anthropic.MessageNewParams{
+		Model:     anthropic.Model(modelName),
+		MaxTokens: 2048,
+		System: []anthropic.TextBlockParam{
+			{Text: extractionPrompt},
+		},
+		Messages: []anthropic.MessageParam{
+			anthropic.NewUserMessage(anthropic.NewTextBlock(transcript)),
+		},
+	})
+	if err != nil {
+		return nil, fmt.Errorf("LLM API call failed: %w", err)
+	}
+	if len(message.Content) == 0 {
+		return nil, fmt.Errorf("empty response from extraction model")
+	}
+	// Use the first text block in the response; any other block types
+	// (e.g. tool use) are ignored for extraction purposes.
+	var responseText string
+	for _, block := range message.Content {
+		if block.Type == "text" {
+			responseText = strings.TrimSpace(block.Text)
+			break
+		}
+	}
+	if responseText == "" {
+		return nil, fmt.Errorf("no text content in extraction response")
+	}
+	return parseExtractionResponse(responseText)
+}
+
+// parseExtractionResponse parses the LLM JSON response into InsightCandidate structs.
+func parseExtractionResponse(responseText string) ([]InsightCandidate, error) { + responseText = strings.TrimSpace(responseText) + // Strip markdown code fences if present + if strings.HasPrefix(responseText, "```") { + lines := strings.Split(responseText, "\n") + if len(lines) >= 3 { + responseText = strings.Join(lines[1:len(lines)-1], "\n") + } + } + + var candidates []InsightCandidate + if err := json.Unmarshal([]byte(responseText), &candidates); err != nil { + return nil, fmt.Errorf("failed to parse extraction JSON: %w (response: %.200s)", err, responseText) + } + + // Validate and filter candidates + var valid []InsightCandidate + for _, c := range candidates { + if c.Title == "" || c.Content == "" || c.Type == "" { + continue + } + if c.Type != "correction" && c.Type != "pattern" { + continue + } + if c.Confidence < 0 { + c.Confidence = 0 + } + if c.Confidence > 1 { + c.Confidence = 1 + } + valid = append(valid, c) + } + return valid, nil +} + +// rankAndCap sorts candidates by confidence (descending) and truncates to maxCount. +func rankAndCap(candidates []InsightCandidate, maxCount int) []InsightCandidate { + if len(candidates) == 0 { + return candidates + } + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].Confidence > candidates[j].Confidence + }) + if len(candidates) > maxCount { + candidates = candidates[:maxCount] + } + return candidates +} + +// ─── Markdown + file path generation ──────────────────────────────── + +// slugify converts a title into a URL-safe slug. +func slugify(title string) string { + s := strings.ToLower(title) + s = slugRegex.ReplaceAllString(s, "-") + s = strings.Trim(s, "-") + if len(s) > 60 { + s = s[:60] + s = strings.TrimRight(s, "-") + } + if s == "" { + s = "insight" + } + return s +} + +// formatInsightMarkdown formats an insight candidate as a markdown file. 
+func formatInsightMarkdown(c InsightCandidate, sessionName, projectName string) string { + var sb strings.Builder + fmt.Fprintf(&sb, "# %s\n\n", c.Title) + fmt.Fprintf(&sb, "**Type:** %s \n", c.Type) + fmt.Fprintf(&sb, "**Confidence:** %.2f \n", c.Confidence) + sb.WriteString("**Source:** insight-extraction \n") + fmt.Fprintf(&sb, "**Session:** %s/%s \n", projectName, sessionName) + fmt.Fprintf(&sb, "**Extracted:** %s \n\n", time.Now().UTC().Format(time.RFC3339)) + sb.WriteString("---\n\n") + sb.WriteString(c.Content) + sb.WriteString("\n") + return sb.String() +} + +// insightFilePath returns the path for an insight file in the docs/learned/ directory. +func insightFilePath(c InsightCandidate) string { + date := time.Now().UTC().Format("2006-01-02") + slug := slugify(c.Title) + typeDir := c.Type + "s" // "corrections" or "patterns" + return fmt.Sprintf("docs/learned/%s/%s-%s.md", typeDir, date, slug) +} + +// ─── GitHub API helpers ───────────────────────────────────────────── + +type gitHubFileContent struct { + Path string + Content string +} + +// parseGitHubOwnerRepo extracts owner and repo from a GitHub URL. +func parseGitHubOwnerRepo(repoURL string) (string, string, error) { + repoURL = strings.TrimSuffix(repoURL, ".git") + if strings.Contains(repoURL, "github.com") { + parts := strings.Split(repoURL, "github.com") + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid GitHub URL: %s", repoURL) + } + path := strings.Trim(parts[1], "/:") + pathParts := strings.Split(path, "/") + if len(pathParts) < 2 { + return "", "", fmt.Errorf("invalid GitHub URL path: %s", repoURL) + } + return pathParts[0], pathParts[1], nil + } + return "", "", fmt.Errorf("not a GitHub URL: %s", repoURL) +} + +// githubAPIRequest is a helper for making GitHub API requests. 
func githubAPIRequest(ctx context.Context, method, url, token string, body interface{}) ([]byte, int, error) {
	// Marshal the optional request body as JSON.
	var reqBody io.Reader
	if body != nil {
		bodyJSON, err := json.Marshal(body)
		if err != nil {
			return nil, 0, fmt.Errorf("failed to marshal request body: %w", err)
		}
		reqBody = bytes.NewReader(bodyJSON)
	}

	req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
	if err != nil {
		return nil, 0, err
	}
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Accept", "application/vnd.github+json")
	if body != nil {
		req.Header.Set("Content-Type", "application/json")
	}

	// http.DefaultClient has no timeout of its own; callers bound the
	// request lifetime via ctx.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, 0, err
	}
	defer resp.Body.Close()

	// The status code is returned alongside the body so callers can decide
	// which codes are acceptable; non-2xx is NOT treated as an error here.
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, resp.StatusCode, err
	}
	return respBody, resp.StatusCode, nil
}

// getDefaultBranchSHA returns the repository's default branch name and the
// SHA its head currently points at, via two GitHub API calls (repo metadata,
// then the branch ref).
func getDefaultBranchSHA(ctx context.Context, apiBase, owner, repo, token string) (string, string, error) {
	respBody, status, err := githubAPIRequest(ctx, "GET",
		fmt.Sprintf("%s/repos/%s/%s", apiBase, owner, repo), token, nil)
	if err != nil {
		return "", "", err
	}
	if status != http.StatusOK {
		return "", "", fmt.Errorf("GitHub API error %d: %s", status, string(respBody))
	}
	var repoInfo struct {
		DefaultBranch string `json:"default_branch"`
	}
	if err := json.Unmarshal(respBody, &repoInfo); err != nil {
		return "", "", err
	}

	refBody, refStatus, err := githubAPIRequest(ctx, "GET",
		fmt.Sprintf("%s/repos/%s/%s/git/ref/heads/%s", apiBase, owner, repo, repoInfo.DefaultBranch), token, nil)
	if err != nil {
		return "", "", err
	}
	if refStatus != http.StatusOK {
		return "", "", fmt.Errorf("GitHub ref API error %d: %s", refStatus, string(refBody))
	}
	var refInfo struct {
		Object struct {
			SHA string `json:"sha"`
		} `json:"object"`
	}
	if err := json.Unmarshal(refBody, &refInfo); err != nil {
		return "", "", err
	}
	return repoInfo.DefaultBranch, refInfo.Object.SHA, nil
}

// createGitRef creates a new branch ref pointing at sha. GitHub returns
// 201 Created on success; anything else is surfaced as an error.
func createGitRef(ctx context.Context, apiBase, owner, repo, token, branchName, sha string) error {
	body := map[string]string{
		"ref": fmt.Sprintf("refs/heads/%s", branchName),
		"sha": sha,
	}
	respBody, status, err := githubAPIRequest(ctx, "POST",
		fmt.Sprintf("%s/repos/%s/%s/git/refs", apiBase, owner, repo), token, body)
	if err != nil {
		return err
	}
	if status != http.StatusCreated {
		return fmt.Errorf("create ref failed %d: %s", status, string(respBody))
	}
	return nil
}

// createFileOnBranch commits a single file to branch via the contents API.
// The content is base64-encoded as the API requires. 201 = created,
// 200 = updated an existing file; both are success.
func createFileOnBranch(ctx context.Context, apiBase, owner, repo, token, branch, path, content, commitMsg string) error {
	body := map[string]string{
		"message": commitMsg,
		"content": base64.StdEncoding.EncodeToString([]byte(content)),
		"branch":  branch,
	}
	respBody, status, err := githubAPIRequest(ctx, "PUT",
		fmt.Sprintf("%s/repos/%s/%s/contents/%s", apiBase, owner, repo, path), token, body)
	if err != nil {
		return err
	}
	if status != http.StatusCreated && status != http.StatusOK {
		return fmt.Errorf("create file failed %d: %s", status, string(respBody))
	}
	return nil
}

// createDraftPR opens a draft pull request from head into baseBranch and
// returns the new PR number.
func createDraftPR(ctx context.Context, apiBase, owner, repo, token, head, baseBranch, title, prBodyText string) (int, error) {
	body := map[string]interface{}{
		"title": title,
		"head":  head,
		"base":  baseBranch,
		"body":  prBodyText,
		"draft": true,
	}
	respBody, status, err := githubAPIRequest(ctx, "POST",
		fmt.Sprintf("%s/repos/%s/%s/pulls", apiBase, owner, repo), token, body)
	if err != nil {
		return 0, err
	}
	if status != http.StatusCreated {
		return 0, fmt.Errorf("create PR failed %d: %s", status, string(respBody))
	}
	var prResult struct {
		Number int `json:"number"`
	}
	if err := json.Unmarshal(respBody, &prResult); err != nil {
		return 0, fmt.Errorf("failed to parse PR response: %w", err)
	}
	return prResult.Number, nil
}

// addLabelToPR adds the continuous-learning label to a PR. Best-effort.
func addLabelToPR(ctx context.Context, apiBase, owner, repo, token string, prNumber int) {
	// PRs are issues for labeling purposes, hence the /issues/ endpoint.
	body := map[string][]string{
		"labels": {"continuous-learning"},
	}
	_, status, err := githubAPIRequest(ctx, "POST",
		fmt.Sprintf("%s/repos/%s/%s/issues/%d/labels", apiBase, owner, repo, prNumber), token, body)
	if err != nil {
		// Best-effort: failures are logged, never propagated.
		log.Printf("Extraction: failed to add label: %v", err)
		return
	}
	if status != http.StatusOK {
		log.Printf("Extraction: add label returned status %d", status)
	}
}

// createExtractionPR creates a branch with insight files and opens a draft PR.
func createExtractionPR(ctx context.Context, repoURL, token, sessionName, projectName string, files []gitHubFileContent) error {
	owner, repo, err := parseGitHubOwnerRepo(repoURL)
	if err != nil {
		return fmt.Errorf("failed to parse repo URL: %w", err)
	}

	apiBase := "https://api.github.com"
	defaultBranch, baseSHA, err := getDefaultBranchSHA(ctx, apiBase, owner, repo, token)
	if err != nil {
		return fmt.Errorf("failed to get default branch: %w", err)
	}

	// Timestamped branch name avoids collisions across repeated sessions.
	branchName := fmt.Sprintf("learned/%s/%s", sessionName, time.Now().UTC().Format("20060102-150405"))

	if err := createGitRef(ctx, apiBase, owner, repo, token, branchName, baseSHA); err != nil {
		return fmt.Errorf("failed to create branch: %w", err)
	}

	// One commit per file via the contents API. NOTE(review): a failure
	// partway through leaves the branch behind with a partial set of files.
	for _, f := range files {
		commitMsg := fmt.Sprintf("learned: add insight from %s", sessionName)
		if err := createFileOnBranch(ctx, apiBase, owner, repo, token, branchName, f.Path, f.Content, commitMsg); err != nil {
			return fmt.Errorf("failed to create file %s: %w", f.Path, err)
		}
	}

	prTitle := fmt.Sprintf("learned: insights from %s", sessionName)
	prBody := fmt.Sprintf("## Extracted Insights\n\n"+
		"**source:** insight-extraction \n"+
		"**session:** %s/%s \n"+
		"**files:** %d insight(s) \n\n"+
		"These insights were automatically extracted from a completed agentic session. "+
		"Review the changes, edit as needed, and merge to include in future sessions.\n\n"+
		"---\n_Generated by the Ambient Code Platform continuous learning pipeline._",
		projectName, sessionName, len(files))

	prNumber, err := createDraftPR(ctx, apiBase, owner, repo, token, branchName, defaultBranch, prTitle, prBody)
	if err != nil {
		return fmt.Errorf("failed to create PR: %w", err)
	}

	// Labeling is best-effort and does not affect the result.
	addLabelToPR(ctx, apiBase, owner, repo, token, prNumber)
	log.Printf("Extraction: created draft PR #%d for %s/%s with %d files", prNumber, projectName, sessionName, len(files))
	return nil
}

// ─── Status update helper ───────────────────────────────────────────

// updateExtractionStatus writes status.extractionStatus on the named
// AgenticSession custom resource via the status subresource.
func updateExtractionStatus(projectName, sessionName, status string) error {
	if DynamicClient == nil {
		return fmt.Errorf("dynamic client not initialized")
	}
	gvr := GetAgenticSessionV1Alpha1Resource()
	ctx := context.Background()

	item, err := DynamicClient.Resource(gvr).Namespace(projectName).Get(ctx, sessionName, v1.GetOptions{})
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}

	// Preserve any existing status fields; only extractionStatus changes.
	statusMap, _, _ := unstructured.NestedMap(item.Object, "status")
	if statusMap == nil {
		statusMap = make(map[string]interface{})
	}
	statusMap["extractionStatus"] = status
	if err := unstructured.SetNestedMap(item.Object, statusMap, "status"); err != nil {
		return fmt.Errorf("failed to set status: %w", err)
	}

	_, err = DynamicClient.Resource(gvr).Namespace(projectName).UpdateStatus(ctx, item, v1.UpdateOptions{})
	if err != nil {
		return fmt.Errorf("failed to update session status: %w", err)
	}
	return nil
}

// ─── Main extraction orchestrator ───────────────────────────────────

// TriggerExtractionAsync is the entry point called when a session run
// finishes (RUN_FINISHED or RUN_ERROR). It runs the entire extraction
// pipeline in a background goroutine and does not block the caller.
func TriggerExtractionAsync(projectName, sessionName string) {
	go func() {
		// The goroutine must never crash the process; recover any panic.
		defer func() {
			if r := recover(); r != nil {
				log.Printf("Extraction: recovered from panic for %s/%s: %v", projectName, sessionName, r)
			}
		}()
		if err := runExtraction(projectName, sessionName); err != nil {
			log.Printf("Extraction: failed for %s/%s: %v", projectName, sessionName, err)
		}
	}()
}

// runExtraction executes the full extraction pipeline for one session:
// feature-flag gate, at-most-once claim, config/transcript gathering, LLM
// extraction, and draft-PR creation. It returns an error only for genuine
// failures; all "nothing to do" paths log, mark the session skipped, and
// return nil.
func runExtraction(projectName, sessionName string) error {
	// 1. Check feature flag
	if !FeatureEnabled(correctionsFeatureFlag) {
		log.Printf("Extraction: learning-agent-loop flag disabled, skipping %s/%s", projectName, sessionName)
		return nil
	}

	// 2. Set extraction status to pending (at-most-once guard)
	if err := claimExtraction(projectName, sessionName); err != nil {
		log.Printf("Extraction: at-most-once guard for %s/%s: %v", projectName, sessionName, err)
		return nil // Not an error — already claimed
	}

	// 3. Update status to running (best-effort; the claim already succeeded)
	if err := updateExtractionStatus(projectName, sessionName, ExtractionStatusRunning); err != nil {
		log.Printf("Extraction: failed to set running status for %s/%s: %v", projectName, sessionName, err)
	}

	// 4. Get repo URL from session spec
	// NOTE(review): a lookup error is folded into the "no repo URL" skip
	// path and its detail is lost — confirm that is intended.
	repoURL, err := getSessionRepoURL(projectName, sessionName)
	if err != nil || repoURL == "" {
		log.Printf("Extraction: no repo URL for %s/%s, skipping", projectName, sessionName)
		_ = updateExtractionStatus(projectName, sessionName, ExtractionStatusSkipped)
		return nil
	}

	// 5. Get GitHub token (ctx also bounds the config fetch in step 6)
	ctx, cancel := context.WithTimeout(context.Background(), extractionAPITimeout)
	defer cancel()

	token, err := getExtractionGitHubToken(ctx, projectName)
	if err != nil {
		log.Printf("Extraction: no GitHub token for %s/%s: %v", projectName, sessionName, err)
		_ = updateExtractionStatus(projectName, sessionName, ExtractionStatusSkipped)
		return nil
	}

	// 6. Fetch .ambient/config.json from the repo
	configData, err := fetchAmbientConfig(ctx, repoURL, token)
	if err != nil {
		log.Printf("Extraction: config fetch failed for %s/%s: %v", projectName, sessionName, err)
		_ = updateExtractionStatus(projectName, sessionName, ExtractionStatusSkipped)
		return nil
	}

	// nil config means extraction is not enabled for this repo.
	cfg := parseExtractionConfig(configData)
	if cfg == nil {
		log.Printf("Extraction: not enabled for %s/%s", projectName, sessionName)
		_ = updateExtractionStatus(projectName, sessionName, ExtractionStatusSkipped)
		return nil
	}

	// 7. Load transcript events (injected dependency; may be nil in tests)
	if LoadEventsForExtraction == nil {
		log.Printf("Extraction: LoadEventsForExtraction not initialized for %s/%s", projectName, sessionName)
		_ = updateExtractionStatus(projectName, sessionName, ExtractionStatusSkipped)
		return nil
	}
	events := LoadEventsForExtraction(sessionName)
	if len(events) == 0 {
		log.Printf("Extraction: empty transcript for %s/%s", projectName, sessionName)
		_ = updateExtractionStatus(projectName, sessionName, ExtractionStatusSkipped)
		return nil
	}

	// 8. Check minimum turn threshold
	turns := countUserTurns(events)
	if turns < cfg.MinTurnThreshold {
		log.Printf("Extraction: session %s/%s below minimum turn threshold (%d < %d)",
			projectName, sessionName, turns, cfg.MinTurnThreshold)
		_ = updateExtractionStatus(projectName, sessionName, ExtractionStatusSkipped)
		return nil
	}

	// 9. Build transcript text
	transcript := buildTranscriptText(events)
	if strings.TrimSpace(transcript) == "" {
		log.Printf("Extraction: empty transcript text for %s/%s", projectName, sessionName)
		_ = updateExtractionStatus(projectName, sessionName, ExtractionStatusSkipped)
		return nil
	}

	// 10. Get Anthropic client and call extraction model (fresh timeout,
	// independent of the token/config ctx above)
	llmCtx, llmCancel := context.WithTimeout(context.Background(), extractionAPITimeout)
	defer llmCancel()

	client, isVertex, err := getAnthropicClient(llmCtx, projectName)
	if err != nil {
		log.Printf("Extraction: Anthropic client error for %s/%s: %v", projectName, sessionName, err)
		_ = updateExtractionStatus(projectName, sessionName, ExtractionStatusFailed)
		return fmt.Errorf("failed to get Anthropic client: %w", err)
	}

	// Vertex model IDs contain "@"; substitute the Vertex default if the
	// configured model is not in Vertex form.
	modelName := cfg.Model
	if isVertex && !strings.Contains(modelName, "@") {
		modelName = defaultExtractionModelVertex
	}

	candidates, err := callExtractionModel(llmCtx, client, transcript, modelName)
	if err != nil {
		log.Printf("Extraction: LLM call failed for %s/%s: %v", projectName, sessionName, err)
		_ = updateExtractionStatus(projectName, sessionName, ExtractionStatusFailed)
		return fmt.Errorf("LLM extraction failed: %w", err)
	}

	if len(candidates) == 0 {
		log.Printf("Extraction: no candidates for %s/%s", projectName, sessionName)
		_ = updateExtractionStatus(projectName, sessionName, ExtractionStatusCompleted)
		return nil
	}

	// 11. Rank and cap
	candidates = rankAndCap(candidates, cfg.MaxMemoriesPerSession)

	// 12. Build file list
	var files []gitHubFileContent
	for _, c := range candidates {
		files = append(files, gitHubFileContent{
			Path:    insightFilePath(c),
			Content: formatInsightMarkdown(c, sessionName, projectName),
		})
	}

	// 13. Create PR (its own timeout: several sequential GitHub API calls)
	prCtx, prCancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer prCancel()

	if err := createExtractionPR(prCtx, repoURL, token, sessionName, projectName, files); err != nil {
		log.Printf("Extraction: PR creation failed for %s/%s: %v", projectName, sessionName, err)
		_ = updateExtractionStatus(projectName, sessionName, ExtractionStatusPartialFailure)
		return fmt.Errorf("PR creation failed: %w", err)
	}

	_ = updateExtractionStatus(projectName, sessionName, ExtractionStatusCompleted)
	log.Printf("Extraction: completed for %s/%s (%d insights)", projectName, sessionName, len(candidates))
	return nil
}

// claimExtraction atomically sets extractionStatus to "pending" only if it
// is currently unset. Returns an error if extraction was already claimed.
// The K8s optimistic concurrency (resourceVersion) provides the at-most-once
// guarantee: if two goroutines race, only one update will succeed.
func claimExtraction(projectName, sessionName string) error {
	if DynamicClient == nil {
		return fmt.Errorf("dynamic client not initialized")
	}
	gvr := GetAgenticSessionV1Alpha1Resource()
	ctx := context.Background()

	item, err := DynamicClient.Resource(gvr).Namespace(projectName).Get(ctx, sessionName, v1.GetOptions{})
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}

	// Check if already claimed
	currentStatus, _, _ := unstructured.NestedString(item.Object, "status", "extractionStatus")
	if currentStatus != "" {
		return fmt.Errorf("extraction already claimed (status=%s)", currentStatus)
	}

	// Claim it. The Get-check-Update sequence is safe because UpdateStatus
	// carries the resourceVersion from the Get: a racing claimant hits a
	// conflict and errors out below.
	statusMap, _, _ := unstructured.NestedMap(item.Object, "status")
	if statusMap == nil {
		statusMap = make(map[string]interface{})
	}
	statusMap["extractionStatus"] = ExtractionStatusPending
	if err := unstructured.SetNestedMap(item.Object, statusMap, "status"); err != nil {
		return fmt.Errorf("failed to set status: %w", err)
	}

	_, err = DynamicClient.Resource(gvr).Namespace(projectName).UpdateStatus(ctx, item, v1.UpdateOptions{})
	if err != nil {
		return fmt.Errorf("failed to update status (conflict expected on race): %w", err)
	}
	return nil
}

// getSessionRepoURL returns the URL of the first repo from the session spec.
// A session with no repos (or a malformed repo entry) yields "" with no error.
func getSessionRepoURL(projectName, sessionName string) (string, error) {
	if DynamicClient == nil {
		return "", fmt.Errorf("dynamic client not initialized")
	}
	gvr := GetAgenticSessionV1Alpha1Resource()
	ctx := context.Background()

	item, err := DynamicClient.Resource(gvr).Namespace(projectName).Get(ctx, sessionName, v1.GetOptions{})
	if err != nil {
		return "", fmt.Errorf("failed to get session: %w", err)
	}

	repos, found, err := unstructured.NestedSlice(item.Object, "spec", "repos")
	if err != nil || !found || len(repos) == 0 {
		return "", nil
	}

	firstRepo, ok := repos[0].(map[string]interface{})
	if !ok {
		return "", nil
	}

	repoURL, _ := firstRepo["url"].(string)
	return repoURL, nil
}

// getExtractionGitHubToken gets a GitHub token for the extraction pipeline.
func getExtractionGitHubToken(ctx context.Context, projectName string) (string, error) {
	if GetGitHubToken == nil {
		return "", fmt.Errorf("GetGitHubToken not initialized")
	}
	// Use the backend service account for extraction (internal operation).
	// Pass empty userID — the token function will fall back to project-level credentials.
	token, err := GetGitHubToken(ctx, nil, DynamicClient, projectName, "")
	if err != nil {
		return "", err
	}
	return token, nil
}

// fetchAmbientConfig fetches .ambient/config.json from the default branch of the repo.
func fetchAmbientConfig(ctx context.Context, repoURL, token string) ([]byte, error) {
	owner, repo, err := parseGitHubOwnerRepo(repoURL)
	if err != nil {
		return nil, err
	}

	// The "raw" Accept header makes GitHub return the file body directly
	// instead of the JSON contents envelope.
	apiURL := fmt.Sprintf("https://api.github.com/repos/%s/%s/contents/.ambient/config.json", owner, repo)
	req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Accept", "application/vnd.github.v3.raw")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusNotFound {
		return nil, fmt.Errorf(".ambient/config.json not found in repo")
	}
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("GitHub API error %d: %s", resp.StatusCode, string(body))
	}

	return io.ReadAll(resp.Body)
}
diff --git a/components/backend/handlers/extraction_test.go b/components/backend/handlers/extraction_test.go
new file mode 100644
index 000000000..01f429490
--- /dev/null
+++ b/components/backend/handlers/extraction_test.go
package handlers

import (
	"strings"
	"testing"

	"ambient-code-backend/types"
)

// TestParseExtractionConfig covers the enabled/disabled matrix and default
// substitution for the .ambient/config.json learning section.
func TestParseExtractionConfig(t *testing.T) {
	tests := []struct {
		name      string
		input     string
		wantNil   bool
		wantModel string
		wantMax   int
		wantMin   int
	}{
		{
			name:      "fully enabled with custom settings",
			input:     `{"learning":{"enabled":true,"extraction":{"enabled":true,"model":"claude-haiku-4","maxMemoriesPerSession":3,"minTurnThreshold":10}}}`,
			wantNil:   false,
			wantModel: "claude-haiku-4",
			wantMax:   3,
			wantMin:   10,
		},
		{
			name:      "enabled with defaults",
			input:     `{"learning":{"enabled":true,"extraction":{"enabled":true}}}`,
			wantNil:   false,
			wantModel: defaultExtractionModel,
			wantMax:   defaultMaxMemoriesPerSession,
			wantMin:   defaultMinTurnThreshold,
		},
		{
			name:    "learning disabled",
			input:   `{"learning":{"enabled":false,"extraction":{"enabled":true}}}`,
			wantNil: true,
		},
		{
			name:    "extraction disabled",
			input:   `{"learning":{"enabled":true,"extraction":{"enabled":false}}}`,
			wantNil: true,
		},
		{
			name:    "no extraction key",
			input:   `{"learning":{"enabled":true}}`,
			wantNil: true,
		},
		{
			name:    "no learning key",
			input:   `{}`,
			wantNil: true,
		},
		{
			name:    "invalid JSON",
			input:   `not json`,
			wantNil: true,
		},
		{
			name:    "empty string",
			input:   ``,
			wantNil: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := parseExtractionConfig([]byte(tt.input))
			if tt.wantNil {
				if cfg != nil {
					t.Errorf("expected nil config, got %+v", cfg)
				}
				return
			}
			if cfg == nil {
				t.Fatal("expected non-nil config, got nil")
			}
			if cfg.Model != tt.wantModel {
				t.Errorf("model: got %q, want %q", cfg.Model, tt.wantModel)
			}
			if cfg.MaxMemoriesPerSession != tt.wantMax {
				t.Errorf("maxMemoriesPerSession: got %d, want %d", cfg.MaxMemoriesPerSession, tt.wantMax)
			}
			if cfg.MinTurnThreshold != tt.wantMin {
				t.Errorf("minTurnThreshold: got %d, want %d", cfg.MinTurnThreshold, tt.wantMin)
			}
		})
	}
}

// TestCountUserTurns checks both transcript representations (snapshot and
// streaming events) plus the snapshot-wins precedence rule.
func TestCountUserTurns(t *testing.T) {
	tests := []struct {
		name   string
		events []map[string]interface{}
		want   int
	}{
		{
			name: "messages snapshot with 3 user turns",
			events: []map[string]interface{}{
				{"type": types.EventTypeMessagesSnapshot, "messages": []interface{}{
					map[string]interface{}{"role": types.RoleUser, "content": "hello"},
					map[string]interface{}{"role": types.RoleAssistant, "content": "hi"},
					map[string]interface{}{"role": types.RoleUser, "content": "help me"},
					map[string]interface{}{"role": types.RoleAssistant, "content": "sure"},
					map[string]interface{}{"role": types.RoleUser, "content": "thanks"},
				}},
			},
			want: 3,
		},
		{
			name: "streaming events with 2 user turns",
			events: []map[string]interface{}{
				{"type": types.EventTypeTextMessageStart, "role": types.RoleUser},
				{"type": types.EventTypeTextMessageStart, "role": types.RoleAssistant},
				{"type": types.EventTypeTextMessageStart, "role": types.RoleUser},
			},
			want: 2,
		},
		{
			name:   "empty events",
			events: []map[string]interface{}{},
			want:   0,
		},
		{
			name:   "nil events",
			events: nil,
			want:   0,
		},
		{
			name: "snapshot preferred over streaming events",
			events: []map[string]interface{}{
				{"type": types.EventTypeTextMessageStart, "role": types.RoleUser},
				{"type": types.EventTypeTextMessageStart, "role": types.RoleUser},
				{"type": types.EventTypeMessagesSnapshot, "messages": []interface{}{
					map[string]interface{}{"role": types.RoleUser, "content": "hello"},
				}},
			},
			want: 1,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := countUserTurns(tt.events)
			if got != tt.want {
				t.Errorf("countUserTurns() = %d, want %d", got, tt.want)
			}
		})
	}
}

// TestBuildTranscriptText verifies system/developer messages are excluded
// and user/assistant messages are rendered with role prefixes.
func TestBuildTranscriptText(t *testing.T) {
	events := []map[string]interface{}{
		{"type": types.EventTypeMessagesSnapshot, "messages": []interface{}{
			map[string]interface{}{"role": types.RoleSystem, "content": "system msg"},
			map[string]interface{}{"role": types.RoleDeveloper, "content": "developer msg"},
			map[string]interface{}{"role": types.RoleUser, "content": "user msg"},
			map[string]interface{}{"role": types.RoleAssistant, "content": "assistant msg"},
		}},
	}

	text := buildTranscriptText(events)

	if strings.Contains(text, "system msg") {
		t.Error("transcript should not contain system messages")
	}
	if strings.Contains(text, "developer msg") {
		t.Error("transcript should not contain developer messages")
	}
	if !strings.Contains(text, "[user]: user msg") {
		t.Error("transcript should contain user messages")
	}
	if !strings.Contains(text, "[assistant]: assistant msg") {
		t.Error("transcript should contain assistant messages")
	}
}

// TestBuildTranscriptTextEmpty: nil input yields an empty transcript.
func TestBuildTranscriptTextEmpty(t *testing.T) {
	text := buildTranscriptText(nil)
	if text != "" {
		t.Errorf("expected empty transcript for nil events, got %q", text)
	}
}

// TestBuildTranscriptTextTruncation: content beyond maxTranscriptChars is
// cut and marked with a truncation suffix.
func TestBuildTranscriptTextTruncation(t *testing.T) {
	// Build events with very long content
	longContent := strings.Repeat("a", maxTranscriptChars+1000)
	events := []map[string]interface{}{
		{"type": types.EventTypeMessagesSnapshot, "messages": []interface{}{
			map[string]interface{}{"role": types.RoleUser, "content": longContent},
		}},
	}

	text := buildTranscriptText(events)
	if !strings.HasSuffix(text, "[transcript truncated]") {
		t.Error("long transcript should be truncated")
	}
}

// TestParseExtractionResponse covers fence stripping, validation filtering,
// and confidence clamping of the LLM JSON output.
func TestParseExtractionResponse(t *testing.T) {
	tests := []struct {
		name    string
		input   string
		wantLen int
		wantErr bool
	}{
		{
			name:    "valid JSON array",
			input:   `[{"title":"Test","content":"Content","type":"pattern","confidence":0.8}]`,
			wantLen: 1,
			wantErr: false,
		},
		{
			name:    "with markdown fences",
			input:   "```json\n[{\"title\":\"Test\",\"content\":\"Content\",\"type\":\"correction\",\"confidence\":0.9}]\n```",
			wantLen: 1,
			wantErr: false,
		},
		{
			name:    "empty array",
			input:   `[]`,
			wantLen: 0,
			wantErr: false,
		},
		{
			name:    "invalid JSON",
			input:   `not json`,
			wantLen: 0,
			wantErr: true,
		},
		{
			name:    "missing required fields filtered out",
			input:   `[{"title":"","content":"Content","type":"pattern","confidence":0.5},{"title":"Good","content":"Content","type":"pattern","confidence":0.8}]`,
			wantLen: 1,
			wantErr: false,
		},
		{
			name:    "invalid type filtered out",
			input:   `[{"title":"Test","content":"Content","type":"invalid","confidence":0.8}]`,
			wantLen: 0,
			wantErr: false,
		},
		{
			name:    "confidence clamped to 1.0",
			input:   `[{"title":"Test","content":"Content","type":"pattern","confidence":1.5}]`,
			wantLen: 1,
			wantErr: false,
		},
		{
			name:    "negative confidence clamped to 0.0",
			input:   `[{"title":"Test","content":"Content","type":"correction","confidence":-0.5}]`,
			wantLen: 1,
			wantErr: false,
		},
		{
			name:    "multiple valid candidates",
			input:   `[{"title":"A","content":"C","type":"pattern","confidence":0.7},{"title":"B","content":"D","type":"correction","confidence":0.9}]`,
			wantLen: 2,
			wantErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := parseExtractionResponse(tt.input)
			if (err != nil) != tt.wantErr {
				t.Errorf("parseExtractionResponse() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !tt.wantErr && len(got) != tt.wantLen {
				t.Errorf("parseExtractionResponse() returned %d candidates, want %d", len(got), tt.wantLen)
			}
		})
	}
}

// TestParseExtractionResponseConfidenceClamping asserts the exact clamped
// values, complementing the length-only checks above.
func TestParseExtractionResponseConfidenceClamping(t *testing.T) {
	candidates, err := parseExtractionResponse(`[{"title":"T","content":"C","type":"pattern","confidence":1.5}]`)
	if err != nil {
		t.Fatal(err)
	}
	if len(candidates) != 1 {
		t.Fatal("expected 1 candidate")
	}
	if candidates[0].Confidence != 1.0 {
		t.Errorf("confidence should be clamped to 1.0, got %f", candidates[0].Confidence)
	}

	candidates, err = parseExtractionResponse(`[{"title":"T","content":"C","type":"pattern","confidence":-0.5}]`)
	if err != nil {
		t.Fatal(err)
	}
	if len(candidates) != 1 {
		t.Fatal("expected 1 candidate")
	}
	if candidates[0].Confidence != 0.0 {
		t.Errorf("confidence should be clamped to 0.0, got %f", candidates[0].Confidence)
	}
}

// TestRankAndCap: descending confidence order, truncated to the cap.
func TestRankAndCap(t *testing.T) {
	candidates := []InsightCandidate{
		{Title: "Low", Confidence: 0.3},
		{Title: "High", Confidence: 0.9},
		{Title: "Mid", Confidence: 0.6},
		{Title: "VeryHigh", Confidence: 0.95},
	}

	result := rankAndCap(candidates, 2)
	if len(result) != 2 {
		t.Fatalf("expected 2 candidates, got %d", len(result))
	}
	if result[0].Title != "VeryHigh" {
		t.Errorf("expected first candidate to be VeryHigh, got %s", result[0].Title)
	}
	if result[1].Title != "High" {
		t.Errorf("expected second candidate to be High, got %s", result[1].Title)
	}
}

// TestRankAndCapEmptySlice: nil in, nil out.
func TestRankAndCapEmptySlice(t *testing.T) {
	result := rankAndCap(nil, 5)
	if result != nil {
		t.Errorf("expected nil, got %v", result)
	}
}

// TestRankAndCapNoTruncation: a cap above len leaves the slice intact.
func TestRankAndCapNoTruncation(t *testing.T) {
	candidates := []InsightCandidate{
		{Title: "A", Confidence: 0.8},
		{Title: "B", Confidence: 0.5},
	}
	result := rankAndCap(candidates, 10)
	if len(result) != 2 {
		t.Errorf("expected 2 candidates (no truncation needed), got %d", len(result))
	}
}

// TestSlugify covers lowercasing, special-character collapsing, the 60-char
// cap, dash trimming, and the empty-input fallback.
func TestSlugify(t *testing.T) {
	tests := []struct {
		input string
		want  string
	}{
		{"Hello World!", "hello-world"},
		{"use-kebab-case", "use-kebab-case"},
		{"Special @#$ Characters", "special-characters"},
		{"", "insight"},
		{"A Very Long Title That Exceeds The Maximum Length Of Sixty Characters For Slugs And More", "a-very-long-title-that-exceeds-the-maximum-length-of-sixty-c"},
		{"---dashes---", "dashes"},
		{"123 Numbers", "123-numbers"},
	}

	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			got := slugify(tt.input)
			if got != tt.want {
				t.Errorf("slugify(%q) = %q, want %q", tt.input, got, tt.want)
			}
		})
	}
}

// TestInsightFilePath: type picks the subdirectory, title drives the slug.
func TestInsightFilePath(t *testing.T) {
	c := InsightCandidate{
		Title: "Use Context Managers",
		Type:  "pattern",
	}
	path := insightFilePath(c)
	if !strings.Contains(path, "docs/learned/patterns/") {
		t.Errorf("expected path to contain docs/learned/patterns/, got %s", path)
	}
	if !strings.Contains(path, "use-context-managers") {
		t.Errorf("expected path to contain use-context-managers, got %s", path)
	}

	c2 := InsightCandidate{
		Title: "Avoid Panic in Production",
		Type:  "correction",
	}
	path2 := insightFilePath(c2)
	if !strings.Contains(path2, "docs/learned/corrections/") {
		t.Errorf("expected path to contain docs/learned/corrections/, got %s", path2)
	}
}

// TestFormatInsightMarkdown checks each metadata line and the body appear in
// the rendered markdown.
func TestFormatInsightMarkdown(t *testing.T) {
	c := InsightCandidate{
		Title:      "Test Insight",
		Content:    "This is the content.",
		Type:       "pattern",
		Confidence: 0.85,
	}
	md := formatInsightMarkdown(c, "test-session", "test-project")

	checks := []struct {
		label    string
		contains string
	}{
		{"title", "# Test Insight"},
		{"type", "**Type:** pattern"},
		{"source", "**Source:** insight-extraction"},
		{"session ref", "test-project/test-session"},
		{"content", "This is the content."},
		{"confidence", "**Confidence:** 0.85"},
	}

	for _, check := range checks {
		if !strings.Contains(md, check.contains) {
			t.Errorf("markdown should contain %s (%q)", check.label, check.contains)
		}
	}
}

// TestParseGitHubOwnerRepo covers https, https+.git, SSH, and rejection of
// non-GitHub hosts and malformed input.
func TestParseGitHubOwnerRepo(t *testing.T) {
	tests := []struct {
		input     string
		wantOwner string
		wantRepo  string
		wantErr   bool
	}{
		{"https://github.com/owner/repo", "owner", "repo", false},
		{"https://github.com/owner/repo.git", "owner", "repo", false},
		{"git@github.com:owner/repo.git", "owner", "repo", false},
		{"https://gitlab.com/owner/repo", "", "", true},
		{"invalid-url", "", "", true},
	}

	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			owner, repo, err := parseGitHubOwnerRepo(tt.input)
			if (err != nil) != tt.wantErr {
				t.Errorf("error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if owner != tt.wantOwner {
				t.Errorf("owner = %q, want %q", owner, tt.wantOwner)
			}
			if repo != tt.wantRepo {
				t.Errorf("repo = %q, want %q", repo, tt.wantRepo)
			}
		})
	}
}
diff --git a/components/backend/handlers/feedback_loop_config.go b/components/backend/handlers/feedback_loop_config.go
new file mode 100644
index 000000000..434fa8b7e
--- /dev/null
+++ b/components/backend/handlers/feedback_loop_config.go
package handlers

import (
	"context"
	"encoding/json"
	"log"
	"net/http"

	"github.com/gin-gonic/gin"
	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/errors"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

const (
	// feedbackLoopConfigMap is the ConfigMap name for feedback loop config and state.
	// Stores per-project configuration, deduplication state, and trigger history.
+ // State is persisted here so that backend restarts do not lose deduplication + // state or correction counts (FR-011). + feedbackLoopConfigMap = "feedback-loop-state" + + defaultMinCorrections = 2 + defaultTimeWindowHours = 24 +) + +// FeedbackLoopConfig holds per-project feedback loop settings. +type FeedbackLoopConfig struct { + MinCorrections int `json:"minCorrections"` + TimeWindowHours int `json:"timeWindowHours"` + AutoTriggerEnabled bool `json:"autoTriggerEnabled"` +} + +// FeedbackLoopHistoryEntry records a triggered improvement session. +type FeedbackLoopHistoryEntry struct { + SessionName string `json:"sessionName"` + CreatedAt string `json:"createdAt"` + Source string `json:"source"` // "event-driven" or "github-action" + TargetType string `json:"targetType"` // "workflow" or "repo" + TargetRepoURL string `json:"targetRepoURL"` + TargetBranch string `json:"targetBranch,omitempty"` + TargetPath string `json:"targetPath,omitempty"` + CorrectionIDs []string `json:"correctionIds"` +} + +// defaultFeedbackLoopConfig returns the default configuration. +func defaultFeedbackLoopConfig() FeedbackLoopConfig { + return FeedbackLoopConfig{ + MinCorrections: defaultMinCorrections, + TimeWindowHours: defaultTimeWindowHours, + AutoTriggerEnabled: true, + } +} + +// loadFeedbackLoopConfig reads the config from the ConfigMap. Returns defaults if not found. 
+func loadFeedbackLoopConfig(ctx context.Context, namespace string) (FeedbackLoopConfig, error) { + cm, err := K8sClient.CoreV1().ConfigMaps(namespace).Get(ctx, feedbackLoopConfigMap, metav1.GetOptions{}) + if errors.IsNotFound(err) { + return defaultFeedbackLoopConfig(), nil + } + if err != nil { + return FeedbackLoopConfig{}, err + } + + configData, ok := cm.Data["config"] + if !ok || configData == "" { + return defaultFeedbackLoopConfig(), nil + } + + var config FeedbackLoopConfig + if err := json.Unmarshal([]byte(configData), &config); err != nil { + log.Printf("Failed to parse feedback loop config in %s, using defaults: %v", namespace, err) + return defaultFeedbackLoopConfig(), nil + } + + return config, nil +} + +// saveFeedbackLoopConfig persists config to the ConfigMap. Creates it if absent. +func saveFeedbackLoopConfig(ctx context.Context, namespace string, config FeedbackLoopConfig) error { + data, err := json.Marshal(config) + if err != nil { + return err + } + + cm, err := K8sClient.CoreV1().ConfigMaps(namespace).Get(ctx, feedbackLoopConfigMap, metav1.GetOptions{}) + if errors.IsNotFound(err) { + newCM := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: feedbackLoopConfigMap, + Namespace: namespace, + Labels: map[string]string{ + "app.kubernetes.io/managed-by": "ambient-code", + "app.kubernetes.io/component": "feedback-loop", + }, + }, + Data: map[string]string{ + "config": string(data), + }, + } + _, err = K8sClient.CoreV1().ConfigMaps(namespace).Create(ctx, newCM, metav1.CreateOptions{}) + return err + } + if err != nil { + return err + } + + if cm.Data == nil { + cm.Data = map[string]string{} + } + cm.Data["config"] = string(data) + _, err = K8sClient.CoreV1().ConfigMaps(namespace).Update(ctx, cm, metav1.UpdateOptions{}) + return err +} + +// loadFeedbackLoopHistory reads the history from the ConfigMap. 
+func loadFeedbackLoopHistory(ctx context.Context, namespace string) ([]FeedbackLoopHistoryEntry, error) { + cm, err := K8sClient.CoreV1().ConfigMaps(namespace).Get(ctx, feedbackLoopConfigMap, metav1.GetOptions{}) + if errors.IsNotFound(err) { + return []FeedbackLoopHistoryEntry{}, nil + } + if err != nil { + return nil, err + } + + historyData, ok := cm.Data["history"] + if !ok || historyData == "" { + return []FeedbackLoopHistoryEntry{}, nil + } + + var entries []FeedbackLoopHistoryEntry + if err := json.Unmarshal([]byte(historyData), &entries); err != nil { + log.Printf("Failed to parse feedback loop history in %s: %v", namespace, err) + return []FeedbackLoopHistoryEntry{}, nil + } + + return entries, nil +} + +// appendFeedbackLoopHistory adds an entry to the history in the ConfigMap. +func appendFeedbackLoopHistory(ctx context.Context, namespace string, entry FeedbackLoopHistoryEntry) error { + entries, err := loadFeedbackLoopHistory(ctx, namespace) + if err != nil { + return err + } + + entries = append(entries, entry) + data, err := json.Marshal(entries) + if err != nil { + return err + } + + cm, err := K8sClient.CoreV1().ConfigMaps(namespace).Get(ctx, feedbackLoopConfigMap, metav1.GetOptions{}) + if errors.IsNotFound(err) { + newCM := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: feedbackLoopConfigMap, + Namespace: namespace, + Labels: map[string]string{ + "app.kubernetes.io/managed-by": "ambient-code", + "app.kubernetes.io/component": "feedback-loop", + }, + }, + Data: map[string]string{ + "history": string(data), + }, + } + _, err = K8sClient.CoreV1().ConfigMaps(namespace).Create(ctx, newCM, metav1.CreateOptions{}) + return err + } + if err != nil { + return err + } + + if cm.Data == nil { + cm.Data = map[string]string{} + } + cm.Data["history"] = string(data) + _, err = K8sClient.CoreV1().ConfigMaps(namespace).Update(ctx, cm, metav1.UpdateOptions{}) + return err +} + +// GetFeedbackLoopConfig handles GET 
/api/projects/:projectName/feedback-loop/config +func GetFeedbackLoopConfig(c *gin.Context) { + namespace := sanitizeParam(c.Param("projectName")) + + reqK8s, _ := GetK8sClientsForRequest(c) + if reqK8s == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "User token required"}) + c.Abort() + return + } + + config, err := loadFeedbackLoopConfig(c.Request.Context(), namespace) + if err != nil { + log.Printf("Failed to load feedback loop config for %s: %v", namespace, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to load feedback loop configuration"}) + return + } + + c.JSON(http.StatusOK, config) +} + +// PutFeedbackLoopConfig handles PUT /api/projects/:projectName/feedback-loop/config +func PutFeedbackLoopConfig(c *gin.Context) { + namespace := sanitizeParam(c.Param("projectName")) + + reqK8s, _ := GetK8sClientsForRequest(c) + if reqK8s == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "User token required"}) + c.Abort() + return + } + + // Check admin permission (ability to patch ConfigMaps in namespace) + allowed, err := checkConfigMapPermission(c.Request.Context(), reqK8s, namespace, "patch") + if err != nil { + log.Printf("Failed to check ConfigMap permissions for feedback loop in %s: %v", namespace, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"}) + return + } + if !allowed { + c.JSON(http.StatusForbidden, gin.H{"error": "Admin permissions required to modify feedback loop configuration"}) + return + } + + var req FeedbackLoopConfig + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"}) + return + } + + // Validate + if req.MinCorrections < 1 { + c.JSON(http.StatusBadRequest, gin.H{"error": "minCorrections must be >= 1"}) + return + } + if req.TimeWindowHours < 1 { + c.JSON(http.StatusBadRequest, gin.H{"error": "timeWindowHours must be >= 1"}) + return + } + + if err := saveFeedbackLoopConfig(c.Request.Context(), 
namespace, req); err != nil { + log.Printf("Failed to save feedback loop config for %s: %v", namespace, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to save feedback loop configuration"}) + return + } + + c.JSON(http.StatusOK, req) +} + +// GetFeedbackLoopHistory handles GET /api/projects/:projectName/feedback-loop/history +func GetFeedbackLoopHistory(c *gin.Context) { + namespace := sanitizeParam(c.Param("projectName")) + + reqK8s, _ := GetK8sClientsForRequest(c) + if reqK8s == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "User token required"}) + c.Abort() + return + } + + entries, err := loadFeedbackLoopHistory(c.Request.Context(), namespace) + if err != nil { + log.Printf("Failed to load feedback loop history for %s: %v", namespace, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to load feedback loop history"}) + return + } + + c.JSON(http.StatusOK, gin.H{"sessions": entries}) +} diff --git a/components/backend/handlers/feedback_loop_config_test.go b/components/backend/handlers/feedback_loop_config_test.go new file mode 100644 index 000000000..feab4fad5 --- /dev/null +++ b/components/backend/handlers/feedback_loop_config_test.go @@ -0,0 +1,291 @@ +//go:build test + +package handlers + +import ( + "context" + "encoding/json" + "net/http" + + test_constants "ambient-code-backend/tests/constants" + "ambient-code-backend/tests/logger" + "ambient-code-backend/tests/test_utils" + + "github.com/gin-gonic/gin" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +var _ = Describe("Feedback Loop Config Handler", Label(test_constants.LabelUnit, test_constants.LabelHandlers), func() { + var ( + httpUtils *test_utils.HTTPTestUtils + k8sUtils *test_utils.K8sTestUtils + testToken string + ) + + BeforeEach(func() { + logger.Log("Setting up Feedback Loop Config test") + + k8sUtils = test_utils.NewK8sTestUtils(false, "test-project") + SetupHandlerDependencies(k8sUtils) + + httpUtils = test_utils.NewHTTPTestUtils() + + ctx := context.Background() + _, err := k8sUtils.K8sClient.CoreV1().Namespaces().Create(ctx, &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{Name: "test-project"}, + }, metav1.CreateOptions{}) + if err != nil && !errors.IsAlreadyExists(err) { + Expect(err).NotTo(HaveOccurred()) + } + _, err = k8sUtils.CreateTestRole(ctx, "test-project", "feedback-loop-role", []string{"get", "list", "create", "update", "delete", "patch"}, "*", "") + Expect(err).NotTo(HaveOccurred()) + + token, _, err := httpUtils.SetValidTestToken( + k8sUtils, + "test-project", + []string{"get", "list", "create", "update", "delete", "patch"}, + "*", + "", + "feedback-loop-role", + ) + Expect(err).NotTo(HaveOccurred()) + testToken = token + }) + + AfterEach(func() { + if k8sUtils != nil { + _ = k8sUtils.K8sClient.CoreV1().Namespaces().Delete(context.Background(), "test-project", metav1.DeleteOptions{}) + } + }) + + Describe("GetFeedbackLoopConfig", func() { + It("Should return defaults when no config exists", func() { + ginCtx := httpUtils.CreateTestGinContext("GET", "/api/projects/test-project/feedback-loop/config", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + httpUtils.SetProjectContext("test-project") + httpUtils.SetAuthHeader(testToken) + + GetFeedbackLoopConfig(ginCtx) + + httpUtils.AssertHTTPStatus(http.StatusOK) + + var resp FeedbackLoopConfig + 
httpUtils.GetResponseJSON(&resp) + Expect(resp.MinCorrections).To(Equal(defaultMinCorrections)) + Expect(resp.TimeWindowHours).To(Equal(defaultTimeWindowHours)) + Expect(resp.AutoTriggerEnabled).To(BeTrue()) + + logger.Log("GetFeedbackLoopConfig returns defaults correctly") + }) + + It("Should return stored config when it exists", func() { + ctx := context.Background() + cm := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: feedbackLoopConfigMap, + Namespace: "test-project", + }, + Data: map[string]string{ + "config": `{"minCorrections":5,"timeWindowHours":12,"autoTriggerEnabled":false}`, + }, + } + _, err := k8sUtils.K8sClient.CoreV1().ConfigMaps("test-project").Create(ctx, cm, metav1.CreateOptions{}) + Expect(err).NotTo(HaveOccurred()) + + ginCtx := httpUtils.CreateTestGinContext("GET", "/api/projects/test-project/feedback-loop/config", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + httpUtils.SetProjectContext("test-project") + httpUtils.SetAuthHeader(testToken) + + GetFeedbackLoopConfig(ginCtx) + + httpUtils.AssertHTTPStatus(http.StatusOK) + + var resp FeedbackLoopConfig + httpUtils.GetResponseJSON(&resp) + Expect(resp.MinCorrections).To(Equal(5)) + Expect(resp.TimeWindowHours).To(Equal(12)) + Expect(resp.AutoTriggerEnabled).To(BeFalse()) + + logger.Log("GetFeedbackLoopConfig returns stored config") + }) + + It("Should require authentication", func() { + ginCtx := httpUtils.CreateTestGinContext("GET", "/api/projects/test-project/feedback-loop/config", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + // No auth header + + GetFeedbackLoopConfig(ginCtx) + + httpUtils.AssertHTTPStatus(http.StatusUnauthorized) + + logger.Log("GetFeedbackLoopConfig requires auth") + }) + }) + + Describe("PutFeedbackLoopConfig", func() { + It("Should store valid config", func() { + body := map[string]interface{}{ + "minCorrections": 3, + "timeWindowHours": 12, + "autoTriggerEnabled": true, + } + ginCtx := 
httpUtils.CreateTestGinContext("PUT", "/api/projects/test-project/feedback-loop/config", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + httpUtils.SetProjectContext("test-project") + httpUtils.SetAuthHeader(testToken) + + PutFeedbackLoopConfig(ginCtx) + + httpUtils.AssertHTTPStatus(http.StatusOK) + + // Verify it was persisted + ctx := context.Background() + cm, err := k8sUtils.K8sClient.CoreV1().ConfigMaps("test-project").Get(ctx, feedbackLoopConfigMap, metav1.GetOptions{}) + Expect(err).NotTo(HaveOccurred()) + Expect(cm.Data["config"]).To(ContainSubstring(`"minCorrections":3`)) + + logger.Log("PutFeedbackLoopConfig stores valid config") + }) + + It("Should reject minCorrections less than 1", func() { + body := map[string]interface{}{ + "minCorrections": 0, + "timeWindowHours": 24, + "autoTriggerEnabled": true, + } + ginCtx := httpUtils.CreateTestGinContext("PUT", "/api/projects/test-project/feedback-loop/config", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + httpUtils.SetProjectContext("test-project") + httpUtils.SetAuthHeader(testToken) + + PutFeedbackLoopConfig(ginCtx) + + httpUtils.AssertHTTPStatus(http.StatusBadRequest) + + logger.Log("PutFeedbackLoopConfig rejects invalid minCorrections") + }) + + It("Should reject timeWindowHours less than 1", func() { + body := map[string]interface{}{ + "minCorrections": 2, + "timeWindowHours": 0, + "autoTriggerEnabled": true, + } + ginCtx := httpUtils.CreateTestGinContext("PUT", "/api/projects/test-project/feedback-loop/config", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + httpUtils.SetProjectContext("test-project") + httpUtils.SetAuthHeader(testToken) + + PutFeedbackLoopConfig(ginCtx) + + httpUtils.AssertHTTPStatus(http.StatusBadRequest) + + logger.Log("PutFeedbackLoopConfig rejects invalid timeWindowHours") + }) + + It("Should require authentication", func() { + body := map[string]interface{}{ + "minCorrections": 
2, + "timeWindowHours": 24, + "autoTriggerEnabled": true, + } + ginCtx := httpUtils.CreateTestGinContext("PUT", "/api/projects/test-project/feedback-loop/config", body) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + // No auth header + + PutFeedbackLoopConfig(ginCtx) + + httpUtils.AssertHTTPStatus(http.StatusUnauthorized) + + logger.Log("PutFeedbackLoopConfig requires auth") + }) + }) + + Describe("GetFeedbackLoopHistory", func() { + It("Should return empty list when no history exists", func() { + ginCtx := httpUtils.CreateTestGinContext("GET", "/api/projects/test-project/feedback-loop/history", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + httpUtils.SetProjectContext("test-project") + httpUtils.SetAuthHeader(testToken) + + GetFeedbackLoopHistory(ginCtx) + + httpUtils.AssertHTTPStatus(http.StatusOK) + + var resp map[string]json.RawMessage + httpUtils.GetResponseJSON(&resp) + var sessions []FeedbackLoopHistoryEntry + err := json.Unmarshal(resp["sessions"], &sessions) + Expect(err).NotTo(HaveOccurred()) + Expect(sessions).To(BeEmpty()) + + logger.Log("GetFeedbackLoopHistory returns empty list") + }) + + It("Should return stored history entries", func() { + ctx := context.Background() + entries := []FeedbackLoopHistoryEntry{ + { + SessionName: "session-123", + CreatedAt: "2026-04-16T10:00:00Z", + Source: "event-driven", + TargetType: "workflow", + TargetRepoURL: "https://github.com/org/repo", + TargetBranch: "main", + TargetPath: ".ambient/workflows/review", + CorrectionIDs: []string{"trace-1", "trace-2"}, + }, + } + data, err := json.Marshal(entries) + Expect(err).NotTo(HaveOccurred()) + + cm := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: feedbackLoopConfigMap, + Namespace: "test-project", + }, + Data: map[string]string{ + "history": string(data), + }, + } + _, err = k8sUtils.K8sClient.CoreV1().ConfigMaps("test-project").Create(ctx, cm, metav1.CreateOptions{}) + 
Expect(err).NotTo(HaveOccurred()) + + ginCtx := httpUtils.CreateTestGinContext("GET", "/api/projects/test-project/feedback-loop/history", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + httpUtils.SetProjectContext("test-project") + httpUtils.SetAuthHeader(testToken) + + GetFeedbackLoopHistory(ginCtx) + + httpUtils.AssertHTTPStatus(http.StatusOK) + + var resp map[string]json.RawMessage + httpUtils.GetResponseJSON(&resp) + var sessions []FeedbackLoopHistoryEntry + err = json.Unmarshal(resp["sessions"], &sessions) + Expect(err).NotTo(HaveOccurred()) + Expect(sessions).To(HaveLen(1)) + Expect(sessions[0].SessionName).To(Equal("session-123")) + Expect(sessions[0].Source).To(Equal("event-driven")) + + logger.Log("GetFeedbackLoopHistory returns stored entries") + }) + + It("Should require authentication", func() { + ginCtx := httpUtils.CreateTestGinContext("GET", "/api/projects/test-project/feedback-loop/history", nil) + ginCtx.Params = gin.Params{{Key: "projectName", Value: "test-project"}} + + GetFeedbackLoopHistory(ginCtx) + + httpUtils.AssertHTTPStatus(http.StatusUnauthorized) + + logger.Log("GetFeedbackLoopHistory requires auth") + }) + }) +}) diff --git a/components/backend/handlers/feedback_loop_prompt.go b/components/backend/handlers/feedback_loop_prompt.go new file mode 100644 index 000000000..59e3be94b --- /dev/null +++ b/components/backend/handlers/feedback_loop_prompt.go @@ -0,0 +1,249 @@ +package handlers + +import ( + "fmt" + "sort" + "strings" +) + +// maxCorrectionsPerPrompt caps the number of corrections included in a single +// improvement prompt. Any remainder is summarized with a count. +const maxCorrectionsPerPrompt = 50 + +// correctionDetail holds one correction for prompt rendering. 
type correctionDetail struct {
	CorrectionType string `json:"correction_type"`
	Source         string `json:"source"`
	AgentAction    string `json:"agent_action"`
	UserCorrection string `json:"user_correction"`
	SessionName    string `json:"session_name"`
	TraceID        string `json:"trace_id"`
}

// correctionGroup aggregates corrections for a single target.
type correctionGroup struct {
	TargetType           string             `json:"target_type"`
	TargetRepoURL        string             `json:"target_repo_url"`
	TargetBranch         string             `json:"target_branch"`
	TargetPath           string             `json:"target_path"`
	Corrections          []correctionDetail `json:"corrections"`
	TotalCount           int                `json:"total_count"`
	CorrectionTypeCounts map[string]int     `json:"correction_type_counts"`
	SourceCounts         map[string]int     `json:"source_counts"`
}

// correctionTypeDescriptions maps each correction type to the human-readable
// explanation embedded in improvement prompts. Keys mirror
// allowedCorrectionTypes in corrections.go.
var correctionTypeDescriptions = map[string]string{
	"incomplete":   "missed something that should have been done",
	"incorrect":    "did the wrong thing",
	"out_of_scope": "worked on wrong files or area",
	"style":        "right result, wrong approach or pattern",
}

// correctionSourceDescriptions maps each correction source to a short
// explanation. Keys mirror allowedCorrectionSources in corrections.go;
// unknown sources fall back to the bare key in the prompt builder.
var correctionSourceDescriptions = map[string]string{
	"human":  "user-provided correction during a session",
	"rubric": "automatically detected from a rubric evaluation",
	// Added for consistency with allowedCorrectionSources, which also
	// accepts "ui" — previously this source rendered with no description.
	"ui": "correction submitted through the web UI",
}

// sanitizePromptText removes shell-interpreted characters from text embedded
// in prompts. The prompt may be passed through bash eval by ambient-action,
// so backticks, $, and angle brackets must be stripped or replaced.
// Replacements: backtick -> single quote, "$" -> removed, "<" -> "(", ">" -> ")".
func sanitizePromptText(text string) string {
	r := strings.NewReplacer("`", "'", "$", "", "<", "(", ">", ")")
	return r.Replace(text)
}

// groupCorrectionKey builds a deduplication key for a correction target.
// Repo corrections exclude the branch (corrections apply regardless of branch).
// Workflow corrections include the branch since different branches may have
// different workflow instructions.
+func groupCorrectionKey(targetType, repoURL, branch, path string) string { + groupBranch := "" + if targetType == "workflow" { + groupBranch = branch + } + return fmt.Sprintf("%s|%s|%s|%s", targetType, repoURL, groupBranch, path) +} + +// repoShortName extracts the short name from a repo URL. +func repoShortName(url string) string { + if url == "" { + return "unknown" + } + url = strings.TrimRight(url, "/") + parts := strings.Split(url, "/") + name := parts[len(parts)-1] + return strings.TrimSuffix(name, ".git") +} + +// buildImprovementPrompt constructs the prompt for an improvement session. +// This is the Go port of build_improvement_prompt() from +// scripts/feedback-loop/query_corrections.py. +func buildImprovementPrompt(group correctionGroup) string { + total := group.TotalCount + typeCounts := group.CorrectionTypeCounts + + // Find the most common correction type + topType := "N/A" + topCount := 0 + for t, count := range typeCounts { + if count > topCount { + topType = t + topCount = count + } + } + + // Build type breakdown sorted by count descending + type kv struct { + Key string + Value int + } + sortedTypes := make([]kv, 0, len(typeCounts)) + for k, v := range typeCounts { + sortedTypes = append(sortedTypes, kv{k, v}) + } + sort.Slice(sortedTypes, func(i, j int) bool { return sortedTypes[i].Value > sortedTypes[j].Value }) + + var typeBreakdown strings.Builder + for _, item := range sortedTypes { + desc := correctionTypeDescriptions[item.Key] + if desc == "" { + desc = item.Key + } + fmt.Fprintf(&typeBreakdown, "- **%s** (%s): %d\n", item.Key, desc, item.Value) + } + + // Build source breakdown sorted by count descending + sortedSources := make([]kv, 0, len(group.SourceCounts)) + for k, v := range group.SourceCounts { + sortedSources = append(sortedSources, kv{k, v}) + } + sort.Slice(sortedSources, func(i, j int) bool { return sortedSources[i].Value > sortedSources[j].Value }) + + var sourceBreakdown strings.Builder + for _, item := range sortedSources 
{ + desc := correctionSourceDescriptions[item.Key] + if desc == "" { + desc = item.Key + } + fmt.Fprintf(&sourceBreakdown, "- **%s** (%s): %d\n", item.Key, desc, item.Value) + } + + // Build corrections detail (capped at maxCorrectionsPerPrompt) + corrections := group.Corrections + displayCount := len(corrections) + if displayCount > maxCorrectionsPerPrompt { + displayCount = maxCorrectionsPerPrompt + } + + var correctionsDetail strings.Builder + for i := 0; i < displayCount; i++ { + c := corrections[i] + sourceTag := "" + if c.Source == "rubric" { + sourceTag = " [rubric]" + } + agentAction := sanitizePromptText(c.AgentAction) + userCorrection := sanitizePromptText(c.UserCorrection) + fmt.Fprintf(&correctionsDetail, "### Correction %d (%s%s)\n", i+1, c.CorrectionType, sourceTag) + fmt.Fprintf(&correctionsDetail, "- **Agent did**: %s\n", agentAction) + fmt.Fprintf(&correctionsDetail, "- **User corrected to**: %s\n", userCorrection) + if c.SessionName != "" { + fmt.Fprintf(&correctionsDetail, "- **Session**: %s\n", c.SessionName) + } + correctionsDetail.WriteString("\n") + } + + if len(corrections) > maxCorrectionsPerPrompt { + remainder := len(corrections) - maxCorrectionsPerPrompt + fmt.Fprintf(&correctionsDetail, "\n*(%d additional corrections not shown — review Langfuse for full details)*\n\n", remainder) + } + + // Target description and task instructions differ by target type + var targetDescription, taskInstructions string + branchLabel := group.TargetBranch + if branchLabel == "" { + branchLabel = "default" + } + + if group.TargetType == "workflow" { + targetDescription = fmt.Sprintf( + "- **Target type**: workflow\n- **Workflow path**: %s\n- **Workflow repo**: %s (branch: %s)", + group.TargetPath, group.TargetRepoURL, branchLabel, + ) + taskInstructions = fmt.Sprintf( + "2. 
**Make targeted improvements**:\n"+ + " - Update workflow files in %s (system prompt, instructions)\n"+ + " where the workflow is guiding the agent incorrectly or incompletely\n"+ + " - Update rubric criteria if rubric-sourced corrections indicate misaligned expectations\n"+ + " - Update .claude/patterns/ files if the agent consistently used wrong patterns", + group.TargetPath, + ) + } else { + targetDescription = fmt.Sprintf( + "- **Target type**: repository\n- **Repository**: %s (branch: %s)", + group.TargetRepoURL, branchLabel, + ) + taskInstructions = "2. **Make targeted improvements**:\n" + + " - Update CLAUDE.md or .claude/ context files where the agent\n" + + " lacked necessary knowledge about this repository\n" + + " - Update .claude/patterns/ files if the agent consistently used wrong patterns\n" + + " - Add missing documentation that would have prevented these corrections" + } + + return fmt.Sprintf(`# Feedback Loop: Improvement Session + +## Context + +You are analyzing %d corrections collected from Ambient Code Platform sessions. + +%s +- **Most common correction type**: %s (%d occurrences) + +## Correction Type Breakdown + +%s +## Correction Sources + +%s +## Detailed Corrections + +%s## Your Task + +1. **Analyze patterns**: Look for recurring themes across the corrections. + Single incidents may be agent errors, but patterns indicate systemic gaps. + +%s + +3. **Use the corrections as a guide**: For each change, ask "would this correction + have been prevented if this information existed in the context?" + +4. **Be surgical**: Only update files directly related to the corrections. + Preserve existing content. Add or modify — do not replace wholesale. + +5. **Commit, push, and open a PR**: Commit your changes with a descriptive + message, push to a feature branch, then create a pull request targeting the + default branch. NEVER push directly to main or master. 
+ + **Include a link to this improvement session in the PR body.** Build the URL + by reading the environment variables AMBIENT_UI_URL, AGENTIC_SESSION_NAMESPACE, + and AGENTIC_SESSION_NAME, then construct: + AMBIENT_UI_URL/projects/AGENTIC_SESSION_NAMESPACE/sessions/AGENTIC_SESSION_NAME + Add it under a "Session" heading so reviewers can trace the PR back to this session. + +## Requirements + +- Do NOT over-generalize from isolated incidents +- Focus on the most frequent correction types first +- Each improvement should directly address one or more specific corrections +- Keep changes minimal and focused +- Test that any modified configuration files are still valid +`, + total, + targetDescription, + topType, typeCounts[topType], + typeBreakdown.String(), + sourceBreakdown.String(), + correctionsDetail.String(), + taskInstructions, + ) +} diff --git a/components/backend/handlers/feedback_loop_prompt_test.go b/components/backend/handlers/feedback_loop_prompt_test.go new file mode 100644 index 000000000..22a0c8117 --- /dev/null +++ b/components/backend/handlers/feedback_loop_prompt_test.go @@ -0,0 +1,174 @@ +//go:build test + +package handlers + +import ( + "strings" + + test_constants "ambient-code-backend/tests/constants" + "ambient-code-backend/tests/logger" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("Feedback Loop Prompt Builder", Label(test_constants.LabelUnit, test_constants.LabelHandlers), func() { + Describe("buildImprovementPrompt", func() { + It("Should build a prompt for workflow corrections", func() { + group := correctionGroup{ + TargetType: "workflow", + TargetRepoURL: "https://github.com/org/repo", + TargetBranch: "main", + TargetPath: ".ambient/workflows/review", + TotalCount: 2, + CorrectionTypeCounts: map[string]int{ + "incorrect": 1, + "style": 1, + }, + SourceCounts: map[string]int{ + "human": 2, + }, + Corrections: []correctionDetail{ + { + CorrectionType: "incorrect", + Source: "human", + AgentAction: "Deleted the test file", + UserCorrection: "Should have updated the test file", + SessionName: "session-abc", + TraceID: "trace-1", + }, + { + CorrectionType: "style", + Source: "human", + AgentAction: "Used fmt.Println for logging", + UserCorrection: "Use log.Printf instead", + SessionName: "session-def", + TraceID: "trace-2", + }, + }, + } + + prompt := buildImprovementPrompt(group) + + Expect(prompt).To(ContainSubstring("2 corrections")) + Expect(prompt).To(ContainSubstring("workflow")) + Expect(prompt).To(ContainSubstring(".ambient/workflows/review")) + Expect(prompt).To(ContainSubstring("Deleted the test file")) + Expect(prompt).To(ContainSubstring("Should have updated the test file")) + Expect(prompt).To(ContainSubstring("Update workflow files")) + + logger.Log("Workflow improvement prompt built successfully") + }) + + It("Should build a prompt for repo corrections", func() { + group := correctionGroup{ + TargetType: "repo", + TargetRepoURL: "https://github.com/org/repo", + TargetBranch: "main", + TargetPath: "", + TotalCount: 3, + CorrectionTypeCounts: map[string]int{ + "incomplete": 3, + }, + SourceCounts: map[string]int{ + "human": 2, + "rubric": 1, + }, + Corrections: []correctionDetail{ + {CorrectionType: "incomplete", Source: "human", AgentAction: "a1", UserCorrection: "c1", TraceID: 
"t1"}, + {CorrectionType: "incomplete", Source: "human", AgentAction: "a2", UserCorrection: "c2", TraceID: "t2"}, + {CorrectionType: "incomplete", Source: "rubric", AgentAction: "a3", UserCorrection: "c3", TraceID: "t3"}, + }, + } + + prompt := buildImprovementPrompt(group) + + Expect(prompt).To(ContainSubstring("3 corrections")) + Expect(prompt).To(ContainSubstring("repository")) + Expect(prompt).To(ContainSubstring("Update CLAUDE.md")) + Expect(prompt).NotTo(ContainSubstring("Update workflow files")) + + logger.Log("Repo improvement prompt built successfully") + }) + + It("Should cap corrections at 50 and summarize remainder", func() { + corrections := make([]correctionDetail, 55) + for i := range corrections { + corrections[i] = correctionDetail{ + CorrectionType: "style", + Source: "human", + AgentAction: "action", + UserCorrection: "correction", + TraceID: "trace", + } + } + + group := correctionGroup{ + TargetType: "repo", + TargetRepoURL: "https://github.com/org/repo", + TotalCount: 55, + CorrectionTypeCounts: map[string]int{"style": 55}, + SourceCounts: map[string]int{"human": 55}, + Corrections: corrections, + } + + prompt := buildImprovementPrompt(group) + + Expect(prompt).To(ContainSubstring("55 corrections")) + Expect(prompt).To(ContainSubstring("5 additional corrections")) + // Should only have 50 numbered correction sections + Expect(strings.Count(prompt, "### Correction ")).To(Equal(50)) + + logger.Log("Prompt correctly caps at 50 corrections") + }) + + It("Should sanitize shell-interpreted characters", func() { + group := correctionGroup{ + TargetType: "repo", + TargetRepoURL: "https://github.com/org/repo", + TotalCount: 1, + CorrectionTypeCounts: map[string]int{"style": 1}, + SourceCounts: map[string]int{"human": 1}, + Corrections: []correctionDetail{ + { + CorrectionType: "style", + Source: "human", + AgentAction: "Used `backticks` and $VAR", + UserCorrection: "Don't use brackets", + TraceID: "t1", + }, + }, + } + + prompt := 
buildImprovementPrompt(group) + + // The sanitized prompt should not contain the raw shell characters + // from the agent_action and user_correction fields. + // The prompt template itself contains markdown formatting which is fine. + Expect(prompt).To(ContainSubstring("Used 'backticks' and VAR")) + Expect(prompt).To(ContainSubstring("Don't use (angle) brackets")) + + logger.Log("Prompt sanitizes shell characters correctly") + }) + }) + + Describe("groupCorrectionKey", func() { + It("Should include branch for workflow targets", func() { + key := groupCorrectionKey("workflow", "https://github.com/org/repo", "main", ".ambient/workflows/review") + Expect(key).To(Equal("workflow|https://github.com/org/repo|main|.ambient/workflows/review")) + }) + + It("Should exclude branch for repo targets", func() { + key := groupCorrectionKey("repo", "https://github.com/org/repo", "feature-branch", "") + Expect(key).To(Equal("repo|https://github.com/org/repo||")) + }) + }) + + Describe("repoShortName", func() { + It("Should extract repo name from URL", func() { + Expect(repoShortName("https://github.com/org/my-repo.git")).To(Equal("my-repo")) + Expect(repoShortName("https://github.com/org/my-repo")).To(Equal("my-repo")) + Expect(repoShortName("")).To(Equal("unknown")) + }) + }) +}) diff --git a/components/backend/handlers/feedback_loop_watcher.go b/components/backend/handlers/feedback_loop_watcher.go new file mode 100644 index 000000000..fe37d380f --- /dev/null +++ b/components/backend/handlers/feedback_loop_watcher.go @@ -0,0 +1,311 @@ +package handlers + +import ( + "context" + "fmt" + "log" + "strings" + "sync" + "time" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" +) + +// IsFeedbackLoopEnabled is a package-level var so tests can override the +// feature flag check. In production it delegates to FeatureEnabled(). 
+// Uses correctionsFeatureFlag (defined in corrections.go) — all learning
+// agent loop features share one flag.
+var IsFeedbackLoopEnabled = func() bool {
+	return FeatureEnabled(correctionsFeatureFlag)
+}
+
+// CorrectionNotification is the payload passed to the watcher when a correction
+// is logged via the corrections pipeline (spec 003).
+type CorrectionNotification struct {
+	Project        string `json:"project"`
+	TargetType     string `json:"target_type"` // "workflow" or "repo"
+	TargetRepoURL  string `json:"target_repo_url"`
+	TargetBranch   string `json:"target_branch"`
+	TargetPath     string `json:"target_path"`
+	CorrectionType string `json:"correction_type"` // "incomplete", "incorrect", "out_of_scope", "style"
+	Source         string `json:"source"` // "human" or "rubric"
+	AgentAction    string `json:"agent_action"`
+	UserCorrection string `json:"user_correction"`
+	SessionName    string `json:"session_name"`
+	TraceID        string `json:"trace_id"`
+	// Timestamp is the caller-supplied event time. NOTE(review): it may be the
+	// zero value if a producer omits it — confirm all callers set it, since
+	// NotifyCorrection consults it when pruning the buffer.
+	Timestamp time.Time `json:"timestamp"`
+}
+
+// bufferedCorrection stores a correction in the per-target buffer.
+type bufferedCorrection struct {
+	CorrectionNotification
+	// ReceivedAt is stamped with server time when the correction is buffered
+	// (set in NotifyCorrection).
+	ReceivedAt time.Time
+}
+
+// targetBuffer holds buffered corrections for a single target key.
+type targetBuffer struct {
+	Corrections []bufferedCorrection
+}
+
+// FeedbackLoopWatcher evaluates corrections against per-project thresholds
+// and creates improvement sessions when thresholds are crossed.
+//
+// The watcher maintains an in-memory buffer of recent corrections per target
+// for fast evaluation. Deduplication state and history are persisted in a
+// ConfigMap so that backend restarts do not create duplicate sessions.
+//
+// NOTE (v1): The in-memory correction buffer is lost on backend restart.
+// This means corrections logged before a restart will not count toward the
+// threshold after restart. This is acceptable for v1 because:
+// 1. The weekly GHA sweep catches anything the real-time path misses.
+// 2. 
// Persisting the full correction buffer in ConfigMap would add write
// amplification on every correction log (acceptable trade-off for v2).
//
// FeedbackLoopWatcher buffers correction notifications per (project, target)
// pair and triggers an improvement session once a configured threshold is
// crossed within a time window. Buffers are in-memory only, so a backend
// restart drops pending (sub-threshold) corrections.
type FeedbackLoopWatcher struct {
	mu      sync.Mutex               // guards buffers and every targetBuffer reached through it
	buffers map[string]*targetBuffer // key: "project|targetKey"
}

// NewFeedbackLoopWatcher creates a new watcher instance.
func NewFeedbackLoopWatcher() *FeedbackLoopWatcher {
	return &FeedbackLoopWatcher{
		buffers: make(map[string]*targetBuffer),
	}
}

// NotifyCorrection is called asynchronously when a correction is logged.
// It buffers the correction and evaluates whether the threshold has been crossed.
// Returns true if an improvement session was triggered (for testing).
//
// This function MUST NOT add latency to the correction logging path (NFR-002).
// Callers should invoke it in a goroutine.
//
// NOTE(review): w.mu is held across the API calls in isDuplicate,
// createImprovementSession, and appendFeedbackLoopHistory below, so all
// concurrent corrections (every project, every target) serialize on this one
// lock while those network round-trips complete — confirm this is acceptable
// for NFR-002 under load.
func (w *FeedbackLoopWatcher) NotifyCorrection(ctx context.Context, n CorrectionNotification) bool {
	// Gate behind feature flag
	if !IsFeedbackLoopEnabled() {
		return false
	}

	// Load project config
	config, err := loadFeedbackLoopConfig(ctx, n.Project)
	if err != nil {
		log.Printf("feedback-loop: failed to load config for %s: %v", n.Project, err)
		return false
	}

	if !config.AutoTriggerEnabled {
		return false
	}

	// One buffer per (project, target); targetKey encodes type/repo/branch/path.
	targetKey := groupCorrectionKey(n.TargetType, n.TargetRepoURL, n.TargetBranch, n.TargetPath)
	bufferKey := fmt.Sprintf("%s|%s", n.Project, targetKey)

	w.mu.Lock()
	defer w.mu.Unlock()

	buf, ok := w.buffers[bufferKey]
	if !ok {
		buf = &targetBuffer{}
		w.buffers[bufferKey] = buf
	}

	// Add the correction
	buf.Corrections = append(buf.Corrections, bufferedCorrection{
		CorrectionNotification: n,
		ReceivedAt:             time.Now(),
	})

	// Prune corrections outside the time window.
	// NOTE(review): pruning keys on the producer-supplied c.Timestamp rather
	// than the locally recorded ReceivedAt — assumes producers send sane,
	// clock-synced timestamps; TODO confirm that is intended.
	cutoff := time.Now().Add(-time.Duration(config.TimeWindowHours) * time.Hour)
	pruned := buf.Corrections[:0] // in-place filter reusing the backing array
	for _, c := range buf.Corrections {
		if c.Timestamp.After(cutoff) {
			pruned = append(pruned, c)
		}
	}
	buf.Corrections = pruned

	// Check threshold
	if len(buf.Corrections) < config.MinCorrections {
		return false
	}

	// Check deduplication: was an improvement session already created for this target
	// within the time window? Checks the persisted history in ConfigMap.
	if w.isDuplicate(ctx, n.Project, targetKey, cutoff) {
		return false
	}

	// Threshold crossed -- create improvement session
	group := w.buildGroupFromBuffer(buf, n)
	sessionName, err := w.createImprovementSession(ctx, n.Project, group)
	if err != nil {
		log.Printf("feedback-loop: failed to create improvement session for %s in %s: %v", targetKey, n.Project, err)
		return false
	}

	// Record in history for deduplication and the history endpoint.
	traceIDs := make([]string, len(buf.Corrections))
	for i, c := range buf.Corrections {
		traceIDs[i] = c.TraceID
	}

	entry := FeedbackLoopHistoryEntry{
		SessionName:   sessionName,
		CreatedAt:     time.Now().UTC().Format(time.RFC3339),
		Source:        "event-driven",
		TargetType:    n.TargetType,
		TargetRepoURL: n.TargetRepoURL,
		TargetBranch:  n.TargetBranch,
		TargetPath:    n.TargetPath,
		CorrectionIDs: traceIDs,
	}
	if err := appendFeedbackLoopHistory(ctx, n.Project, entry); err != nil {
		// Best-effort: on failure the session exists but is missing from
		// history, so a later threshold crossing for this target may create
		// a duplicate session (isDuplicate reads only persisted history).
		log.Printf("feedback-loop: failed to record history for %s: %v", n.Project, err)
	}

	// Clear the buffer for this target to prevent re-triggering
	buf.Corrections = nil

	log.Printf("feedback-loop: triggered improvement session %s for target %s in project %s", sessionName, targetKey, n.Project)
	return true
}

// isDuplicate checks whether an improvement session was already created for
// this target within the time window by reading the persisted history.
// isDuplicate returns true when the persisted history already contains an
// improvement session for targetKey created after cutoff. project selects the
// history store; cutoff is the start of the current time window. History
// entries with unparseable CreatedAt timestamps are skipped (treated as
// too old to count).
func (w *FeedbackLoopWatcher) isDuplicate(ctx context.Context, project, targetKey string, cutoff time.Time) bool {
	entries, err := loadFeedbackLoopHistory(ctx, project)
	if err != nil {
		log.Printf("feedback-loop: failed to load history for dedup check in %s: %v", project, err)
		// Fail open: prefer creating a possible duplicate over silently dropping
		return false
	}

	for _, entry := range entries {
		// Recompute the target key from the entry's fields so comparison uses
		// the same canonical form as the live buffer keys.
		entryKey := groupCorrectionKey(entry.TargetType, entry.TargetRepoURL, entry.TargetBranch, entry.TargetPath)
		if entryKey != targetKey {
			continue
		}
		createdAt, err := time.Parse(time.RFC3339, entry.CreatedAt)
		if err != nil {
			continue
		}
		if createdAt.After(cutoff) {
			return true
		}
	}

	return false
}

// buildGroupFromBuffer constructs a correctionGroup from the buffered
// corrections: per-type and per-source counts plus one correctionDetail per
// buffered item. Target identity fields are taken from n (the most recent
// notification), which by construction matches every correction in buf.
// Caller must hold w.mu (buf is read without additional locking).
func (w *FeedbackLoopWatcher) buildGroupFromBuffer(buf *targetBuffer, n CorrectionNotification) correctionGroup {
	typeCounts := map[string]int{}
	sourceCounts := map[string]int{}
	details := make([]correctionDetail, len(buf.Corrections))

	for i, c := range buf.Corrections {
		typeCounts[c.CorrectionType]++
		sourceCounts[c.Source]++
		details[i] = correctionDetail{
			CorrectionType: c.CorrectionType,
			Source:         c.Source,
			AgentAction:    c.AgentAction,
			UserCorrection: c.UserCorrection,
			SessionName:    c.SessionName,
			TraceID:        c.TraceID,
		}
	}

	return correctionGroup{
		TargetType:           n.TargetType,
		TargetRepoURL:        n.TargetRepoURL,
		TargetBranch:         n.TargetBranch,
		TargetPath:           n.TargetPath,
		Corrections:          details,
		TotalCount:           len(buf.Corrections),
		CorrectionTypeCounts: typeCounts,
		SourceCounts:         sourceCounts,
	}
}

// createImprovementSession creates a new AgenticSession CR for the improvement.
// Returns the session name on success.
+func (w *FeedbackLoopWatcher) createImprovementSession(ctx context.Context, project string, group correctionGroup) (string, error) { + prompt := buildImprovementPrompt(group) + displayName := buildSessionDisplayName(group.TargetType, group.TargetRepoURL, group.TargetPath) + + labels := map[string]interface{}{ + "feedback-loop": "true", + "source": "event-driven", + "target-type": group.TargetType, + } + + spec := map[string]interface{}{ + "initialPrompt": prompt, + "displayName": displayName, + "timeout": 300, + "llmSettings": map[string]interface{}{ + "model": "claude-sonnet-4-6", + "temperature": 0.7, + "maxTokens": 4000, + }, + "environmentVariables": map[string]interface{}{ + "LANGFUSE_MASK_MESSAGES": "false", + }, + } + + // Add repo if available + if group.TargetRepoURL != "" && strings.HasPrefix(group.TargetRepoURL, "http") { + repo := map[string]interface{}{ + "url": group.TargetRepoURL, + "autoPush": true, + } + if group.TargetBranch != "" { + repo["branch"] = group.TargetBranch + } + spec["repos"] = []interface{}{repo} + } + + sessionName := fmt.Sprintf("feedback-%s", time.Now().UTC().Format("20060102-150405")) + + sessionObj := &unstructured.Unstructured{ + Object: map[string]interface{}{ + "apiVersion": "vteam.ambient-code/v1alpha1", + "kind": "AgenticSession", + "metadata": map[string]interface{}{ + "name": sessionName, + "namespace": project, + "labels": labels, + }, + "spec": spec, + }, + } + + gvr := GetAgenticSessionV1Alpha1Resource() + _, err := DynamicClient.Resource(gvr).Namespace(project).Create(ctx, sessionObj, metav1.CreateOptions{}) + if err != nil { + return "", fmt.Errorf("failed to create improvement session CR: %w", err) + } + + return sessionName, nil +} + +// buildSessionDisplayName constructs a human-readable display name for an +// improvement session. Matches the naming convention from +// scripts/feedback-loop/query_corrections.py. 
+func buildSessionDisplayName(targetType, repoURL, targetPath string) string { + repoShort := repoShortName(repoURL) + if targetType == "workflow" { + pathShort := "" + if targetPath != "" { + parts := strings.Split(strings.TrimRight(targetPath, "/"), "/") + pathShort = parts[len(parts)-1] + } + name := "Feedback Loop: " + repoShort + if pathShort != "" { + name += " (" + pathShort + ")" + } + return name + } + return "Feedback Loop: " + repoShort + " (repo)" +} diff --git a/components/backend/handlers/feedback_loop_watcher_test.go b/components/backend/handlers/feedback_loop_watcher_test.go new file mode 100644 index 000000000..8f7075cfb --- /dev/null +++ b/components/backend/handlers/feedback_loop_watcher_test.go @@ -0,0 +1,365 @@ +//go:build test + +package handlers + +import ( + "context" + "time" + + test_constants "ambient-code-backend/tests/constants" + "ambient-code-backend/tests/logger" + "ambient-code-backend/tests/test_utils" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +var _ = Describe("Feedback Loop Watcher", Label(test_constants.LabelUnit, test_constants.LabelHandlers), func() { + var ( + k8sUtils *test_utils.K8sTestUtils + ) + + BeforeEach(func() { + k8sUtils = test_utils.NewK8sTestUtils(false, "test-project") + SetupHandlerDependencies(k8sUtils) + + // Enable the feature flag for tests (Unleash is not configured in test env) + IsFeedbackLoopEnabled = func() bool { return true } + + ctx := context.Background() + _, err := k8sUtils.K8sClient.CoreV1().Namespaces().Create(ctx, &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{Name: "test-project"}, + }, metav1.CreateOptions{}) + if err != nil && !errors.IsAlreadyExists(err) { + Expect(err).NotTo(HaveOccurred()) + } + }) + + AfterEach(func() { + // Restore default feature flag behavior + IsFeedbackLoopEnabled = func() bool { return 
FeatureEnabled(feedbackLoopFeatureFlag) } + + if k8sUtils != nil { + _ = k8sUtils.K8sClient.CoreV1().Namespaces().Delete(context.Background(), "test-project", metav1.DeleteOptions{}) + } + }) + + Describe("NotifyCorrection", func() { + It("Should not trigger when below threshold", func() { + watcher := NewFeedbackLoopWatcher() + + correction := CorrectionNotification{ + Project: "test-project", + TargetType: "workflow", + TargetRepoURL: "https://github.com/org/repo", + TargetBranch: "main", + TargetPath: ".ambient/workflows/review", + CorrectionType: "incorrect", + Source: "human", + AgentAction: "did wrong thing", + UserCorrection: "do right thing", + SessionName: "session-1", + TraceID: "trace-1", + Timestamp: time.Now(), + } + + triggered := watcher.NotifyCorrection(context.Background(), correction) + Expect(triggered).To(BeFalse()) + + logger.Log("Single correction does not trigger threshold") + }) + + It("Should trigger when threshold is met", func() { + watcher := NewFeedbackLoopWatcher() + + ctx := context.Background() + config := FeedbackLoopConfig{ + MinCorrections: 2, + TimeWindowHours: 24, + AutoTriggerEnabled: true, + } + err := saveFeedbackLoopConfig(ctx, "test-project", config) + Expect(err).NotTo(HaveOccurred()) + + base := CorrectionNotification{ + Project: "test-project", + TargetType: "workflow", + TargetRepoURL: "https://github.com/org/repo", + TargetBranch: "main", + TargetPath: ".ambient/workflows/review", + CorrectionType: "incorrect", + Source: "human", + AgentAction: "did wrong thing", + UserCorrection: "do right thing", + Timestamp: time.Now(), + } + + c1 := base + c1.TraceID = "trace-1" + c1.SessionName = "session-1" + triggered := watcher.NotifyCorrection(ctx, c1) + Expect(triggered).To(BeFalse()) + + c2 := base + c2.TraceID = "trace-2" + c2.SessionName = "session-2" + triggered = watcher.NotifyCorrection(ctx, c2) + Expect(triggered).To(BeTrue()) + + logger.Log("Threshold crossing triggers improvement session") + }) + + It("Should 
deduplicate within time window", func() { + watcher := NewFeedbackLoopWatcher() + + ctx := context.Background() + config := FeedbackLoopConfig{ + MinCorrections: 2, + TimeWindowHours: 24, + AutoTriggerEnabled: true, + } + err := saveFeedbackLoopConfig(ctx, "test-project", config) + Expect(err).NotTo(HaveOccurred()) + + base := CorrectionNotification{ + Project: "test-project", + TargetType: "workflow", + TargetRepoURL: "https://github.com/org/repo", + TargetBranch: "main", + TargetPath: ".ambient/workflows/review", + CorrectionType: "incorrect", + Source: "human", + AgentAction: "did wrong thing", + UserCorrection: "do right thing", + Timestamp: time.Now(), + } + + // First threshold crossing + c1 := base + c1.TraceID = "trace-1" + c1.SessionName = "session-1" + watcher.NotifyCorrection(ctx, c1) + + c2 := base + c2.TraceID = "trace-2" + c2.SessionName = "session-2" + triggered := watcher.NotifyCorrection(ctx, c2) + Expect(triggered).To(BeTrue()) + + // More corrections should NOT trigger again (dedup within window) + c3 := base + c3.TraceID = "trace-3" + c3.SessionName = "session-3" + triggered = watcher.NotifyCorrection(ctx, c3) + Expect(triggered).To(BeFalse()) + + c4 := base + c4.TraceID = "trace-4" + c4.SessionName = "session-4" + triggered = watcher.NotifyCorrection(ctx, c4) + Expect(triggered).To(BeFalse()) + + logger.Log("Deduplication prevents multiple triggers within window") + }) + + It("Should not trigger when feature flag is disabled", func() { + IsFeedbackLoopEnabled = func() bool { return false } + + watcher := NewFeedbackLoopWatcher() + + ctx := context.Background() + config := FeedbackLoopConfig{ + MinCorrections: 2, + TimeWindowHours: 24, + AutoTriggerEnabled: true, + } + err := saveFeedbackLoopConfig(ctx, "test-project", config) + Expect(err).NotTo(HaveOccurred()) + + base := CorrectionNotification{ + Project: "test-project", + TargetType: "repo", + TargetRepoURL: "https://github.com/org/repo", + CorrectionType: "style", + Source: "human", + 
AgentAction: "action", + UserCorrection: "correction", + Timestamp: time.Now(), + } + + c1 := base + c1.TraceID = "trace-1" + watcher.NotifyCorrection(ctx, c1) + + c2 := base + c2.TraceID = "trace-2" + triggered := watcher.NotifyCorrection(ctx, c2) + Expect(triggered).To(BeFalse()) + + logger.Log("Feature flag disabled prevents triggering") + }) + + It("Should not trigger when autoTriggerEnabled is false", func() { + watcher := NewFeedbackLoopWatcher() + + ctx := context.Background() + config := FeedbackLoopConfig{ + MinCorrections: 2, + TimeWindowHours: 24, + AutoTriggerEnabled: false, + } + err := saveFeedbackLoopConfig(ctx, "test-project", config) + Expect(err).NotTo(HaveOccurred()) + + base := CorrectionNotification{ + Project: "test-project", + TargetType: "repo", + TargetRepoURL: "https://github.com/org/repo", + CorrectionType: "style", + Source: "human", + AgentAction: "action", + UserCorrection: "correction", + Timestamp: time.Now(), + } + + c1 := base + c1.TraceID = "trace-1" + watcher.NotifyCorrection(ctx, c1) + + c2 := base + c2.TraceID = "trace-2" + triggered := watcher.NotifyCorrection(ctx, c2) + Expect(triggered).To(BeFalse()) + + logger.Log("autoTriggerEnabled=false prevents triggering") + }) + + It("Should track different targets independently", func() { + watcher := NewFeedbackLoopWatcher() + + ctx := context.Background() + config := FeedbackLoopConfig{ + MinCorrections: 2, + TimeWindowHours: 24, + AutoTriggerEnabled: true, + } + err := saveFeedbackLoopConfig(ctx, "test-project", config) + Expect(err).NotTo(HaveOccurred()) + + // One correction for target A (workflow) + cA := CorrectionNotification{ + Project: "test-project", + TargetType: "workflow", + TargetRepoURL: "https://github.com/org/repo", + TargetBranch: "main", + TargetPath: ".ambient/workflows/review", + CorrectionType: "incorrect", + Source: "human", + AgentAction: "action", + UserCorrection: "correction", + TraceID: "trace-a1", + Timestamp: time.Now(), + } + 
watcher.NotifyCorrection(ctx, cA) + + // One correction for target B (repo) + cB := CorrectionNotification{ + Project: "test-project", + TargetType: "repo", + TargetRepoURL: "https://github.com/org/other-repo", + CorrectionType: "style", + Source: "human", + AgentAction: "action", + UserCorrection: "correction", + TraceID: "trace-b1", + Timestamp: time.Now(), + } + watcher.NotifyCorrection(ctx, cB) + + // Second correction for target B triggers + cB2 := cB + cB2.TraceID = "trace-b2" + triggered := watcher.NotifyCorrection(ctx, cB2) + Expect(triggered).To(BeTrue()) + + // Verify history has only target B + entries, err := loadFeedbackLoopHistory(ctx, "test-project") + Expect(err).NotTo(HaveOccurred()) + Expect(entries).To(HaveLen(1)) + Expect(entries[0].TargetRepoURL).To(Equal("https://github.com/org/other-repo")) + Expect(entries[0].TargetType).To(Equal("repo")) + Expect(entries[0].Source).To(Equal("event-driven")) + + logger.Log("Targets tracked independently") + }) + + It("Should record session labels correctly", func() { + watcher := NewFeedbackLoopWatcher() + + ctx := context.Background() + config := FeedbackLoopConfig{ + MinCorrections: 2, + TimeWindowHours: 24, + AutoTriggerEnabled: true, + } + err := saveFeedbackLoopConfig(ctx, "test-project", config) + Expect(err).NotTo(HaveOccurred()) + + base := CorrectionNotification{ + Project: "test-project", + TargetType: "workflow", + TargetRepoURL: "https://github.com/org/repo", + TargetBranch: "main", + TargetPath: ".ambient/workflows/review", + CorrectionType: "incorrect", + Source: "human", + AgentAction: "action", + UserCorrection: "correction", + Timestamp: time.Now(), + } + + c1 := base + c1.TraceID = "trace-1" + watcher.NotifyCorrection(ctx, c1) + + c2 := base + c2.TraceID = "trace-2" + triggered := watcher.NotifyCorrection(ctx, c2) + Expect(triggered).To(BeTrue()) + + // Verify the session was created with correct labels + gvr := GetAgenticSessionV1Alpha1Resource() + sessions, err := 
k8sUtils.DynamicClient.Resource(gvr).Namespace("test-project").List(ctx, metav1.ListOptions{}) + Expect(err).NotTo(HaveOccurred()) + Expect(sessions.Items).To(HaveLen(1)) + + session := sessions.Items[0] + labels := session.GetLabels() + Expect(labels["feedback-loop"]).To(Equal("true")) + Expect(labels["source"]).To(Equal("event-driven")) + Expect(labels["target-type"]).To(Equal("workflow")) + + logger.Log("Session labels set correctly") + }) + }) + + Describe("buildSessionDisplayName", func() { + It("Should format workflow display name", func() { + name := buildSessionDisplayName("workflow", "https://github.com/org/my-repo.git", ".ambient/workflows/review") + Expect(name).To(Equal("Feedback Loop: my-repo (review)")) + }) + + It("Should format repo display name", func() { + name := buildSessionDisplayName("repo", "https://github.com/org/my-repo", "") + Expect(name).To(Equal("Feedback Loop: my-repo (repo)")) + }) + + It("Should handle empty URL", func() { + name := buildSessionDisplayName("repo", "", "") + Expect(name).To(Equal("Feedback Loop: unknown (repo)")) + }) + }) +}) diff --git a/components/backend/handlers/learned.go b/components/backend/handlers/learned.go new file mode 100644 index 000000000..e311868e7 --- /dev/null +++ b/components/backend/handlers/learned.go @@ -0,0 +1,468 @@ +package handlers + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "strings" + "time" + + "ambient-code-backend/types" + + "github.com/gin-gonic/gin" +) + +// LearnedEntry represents a parsed learned file entry from docs/learned/ +type LearnedEntry struct { + Type string `json:"type"` + Date string `json:"date"` + Title string `json:"title"` + Session string `json:"session,omitempty"` + Project string `json:"project,omitempty"` + Author string `json:"author,omitempty"` + Content string `json:"content"` + FilePath string `json:"filePath"` +} + +// parseFrontmatter extracts YAML-like frontmatter key-value pairs from a +// 
// markdown string delimited by "---". Returns the frontmatter map and the
// body text after the closing delimiter. Returns nil if no valid
// frontmatter is present.
func parseFrontmatter(content string) (map[string]string, string) {
	if !strings.HasPrefix(content, "---") {
		return nil, content
	}

	// Everything between the opening "---" and the next "---" is the header;
	// the remainder is the document body. No closing marker => no frontmatter.
	header, rest, found := strings.Cut(content[len("---"):], "---")
	if !found {
		return nil, content
	}

	fm := make(map[string]string)
	for _, raw := range strings.Split(strings.TrimSpace(header), "\n") {
		raw = strings.TrimSpace(raw)
		if raw == "" {
			continue
		}
		key, val, ok := strings.Cut(raw, ":")
		if !ok {
			// Not a key: value line — ignore it.
			continue
		}
		key = strings.TrimSpace(key)
		if key == "" {
			continue
		}
		// Values may be wrapped in single or double quotes; strip them.
		fm[key] = strings.Trim(strings.TrimSpace(val), "\"'")
	}

	return fm, strings.TrimSpace(rest)
}

// ListLearnedEntries handles GET /api/projects/:projectName/learned
//
// Reads docs/learned/ from the workspace repo via GitHub API and returns
// parsed entries with frontmatter metadata and content.
//
// Query parameters:
// - repo: repository URL (required)
// - ref: git ref/branch (required)
// - type: filter by entry type (optional, e.g. "correction")
//
// Uses GetK8sClientsForRequest for user-scoped RBAC.
+func ListLearnedEntries(c *gin.Context) { + project := c.Param("projectName") + repo := c.Query("repo") + ref := c.Query("ref") + typeFilter := c.Query("type") + + if repo == "" || ref == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "repo and ref query parameters required"}) + return + } + + userID, _ := c.Get("userID") + reqK8s, reqDyn := GetK8sClientsForRequest(c) + + if reqK8s == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or missing token"}) + c.Abort() + return + } + + if userID == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing user context"}) + return + } + + // Detect provider — only GitHub is supported for learned files + provider := types.DetectProvider(repo) + if provider != types.ProviderGitHub { + c.JSON(http.StatusBadRequest, gin.H{"error": "learned files endpoint only supports GitHub repositories"}) + return + } + + token, err := GetGitHubTokenRepo(c.Request.Context(), reqK8s, reqDyn, project, userID.(string)) + if err != nil { + log.Printf("Failed to get GitHub token for learned endpoint, project %s: %v", project, err) + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or missing token"}) + return + } + + owner, repoName, err := parseOwnerRepo(repo) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + entries, err := fetchLearnedFiles(c, owner, repoName, ref, token) + if err != nil { + // If docs/learned/ doesn't exist, return empty array (not 404) + if strings.Contains(err.Error(), "404") || strings.Contains(err.Error(), "Not Found") { + c.JSON(http.StatusOK, gin.H{"entries": []LearnedEntry{}}) + return + } + c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("failed to fetch learned files: %v", err)}) + return + } + + // Apply type filter + if typeFilter != "" { + filtered := make([]LearnedEntry, 0, len(entries)) + for _, e := range entries { + if e.Type == typeFilter { + filtered = append(filtered, e) + } + } + entries = filtered + } + + 
c.JSON(http.StatusOK, gin.H{"entries": entries}) +} + +// fetchLearnedFiles retrieves and parses learned markdown files from the +// GitHub Contents API. It reads the top-level docs/learned/ directory and +// the corrections/ and patterns/ subdirectories. +func fetchLearnedFiles(c *gin.Context, owner, repo, ref, token string) ([]LearnedEntry, error) { + api := githubAPIBaseURL("github.com") + + // Collect .md file paths from docs/learned/ and its subdirectories + var mdPaths []string + + for _, dirPath := range []string{"docs/learned", "docs/learned/corrections", "docs/learned/patterns"} { + url := fmt.Sprintf("%s/repos/%s/%s/contents/%s?ref=%s", api, owner, repo, dirPath, ref) + resp, err := doGitHubRequest(c.Request.Context(), http.MethodGet, url, "Bearer "+token, "", nil) + if err != nil { + if dirPath == "docs/learned" { + return nil, fmt.Errorf("GitHub API request failed: %w", err) + } + continue + } + + if resp.StatusCode == http.StatusNotFound { + resp.Body.Close() + if dirPath == "docs/learned" { + return nil, fmt.Errorf("404 Not Found") + } + continue + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + if dirPath == "docs/learned" { + b, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return nil, fmt.Errorf("GitHub API error %d: %s", resp.StatusCode, string(b)) + } + resp.Body.Close() + continue + } + + var decoded interface{} + if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil { + resp.Body.Close() + continue + } + resp.Body.Close() + mdPaths = append(mdPaths, collectMDPaths(decoded)...) 
+ } + + // Deduplicate paths + seen := make(map[string]bool) + uniquePaths := make([]string, 0, len(mdPaths)) + for _, p := range mdPaths { + if !seen[p] { + seen[p] = true + uniquePaths = append(uniquePaths, p) + } + } + + // Fetch and parse each file + var entries []LearnedEntry + for _, filePath := range uniquePaths { + fileURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s?ref=%s", api, owner, repo, filePath, ref) + fileResp, fileErr := doGitHubRequest(c.Request.Context(), http.MethodGet, fileURL, "Bearer "+token, "", nil) + if fileErr != nil { + log.Printf("Failed to fetch learned file %s: %v", filePath, fileErr) + continue + } + + if fileResp.StatusCode != http.StatusOK { + fileResp.Body.Close() + continue + } + + var fileObj map[string]interface{} + if json.NewDecoder(fileResp.Body).Decode(&fileObj) != nil { + fileResp.Body.Close() + continue + } + fileResp.Body.Close() + + rawContent, _ := fileObj["content"].(string) + encoding, _ := fileObj["encoding"].(string) + + var textContent string + if strings.ToLower(encoding) == "base64" { + raw := strings.ReplaceAll(rawContent, "\n", "") + data, decErr := base64.StdEncoding.DecodeString(raw) + if decErr != nil { + continue + } + textContent = string(data) + } else { + textContent = rawContent + } + + fm, body := parseFrontmatter(textContent) + if fm == nil { + continue + } + + entryType := fm["type"] + title := fm["title"] + date := fm["date"] + if entryType == "" || title == "" || date == "" { + continue + } + + entries = append(entries, LearnedEntry{ + Type: entryType, + Date: date, + Title: title, + Session: fm["session"], + Project: fm["project"], + Author: fm["author"], + Content: body, + FilePath: filePath, + }) + } + + return entries, nil +} + +// CreateLearnedPR handles POST /api/projects/:projectName/learned/create +// +// Creates a learned file on a new branch and opens a draft PR. 
+// Body: {"owner":"...","repo":"...","title":"...","content":"...","type":"correction|pattern"} +func CreateLearnedPR(c *gin.Context) { + project := c.Param("projectName") + + var req struct { + Owner string `json:"owner" binding:"required"` + Repo string `json:"repo" binding:"required"` + Title string `json:"title" binding:"required"` + Content string `json:"content" binding:"required"` + Type string `json:"type" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if req.Type != "correction" && req.Type != "pattern" { + c.JSON(http.StatusBadRequest, gin.H{"error": "type must be 'correction' or 'pattern'"}) + return + } + + userID, _ := c.Get("userID") + reqK8s, reqDyn := GetK8sClientsForRequest(c) + if reqK8s == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or missing token"}) + return + } + if userID == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing user context"}) + return + } + + repoURL := fmt.Sprintf("https://github.com/%s/%s", req.Owner, req.Repo) + token, err := GetGitHubTokenRepo(c.Request.Context(), reqK8s, reqDyn, project, userID.(string)) + if err != nil { + if fallback := os.Getenv("GITHUB_FALLBACK_TOKEN"); fallback != "" { + token = fallback + } else { + log.Printf("Failed to get GitHub token for learned PR, project %s: %v", project, err) + c.JSON(http.StatusUnauthorized, gin.H{"error": "GitHub authentication required. Connect GitHub via the Integrations page (GitHub App or PAT)."}) + return + } + } + _ = repoURL + + api := "https://api.github.com" + auth := "Bearer " + token + + // 1. 
Get default branch SHA + refResp, err := doGitHubRequest(c.Request.Context(), "GET", + fmt.Sprintf("%s/repos/%s/%s/git/ref/heads/main", api, req.Owner, req.Repo), + auth, "", nil) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("failed to get default branch: %v", err)}) + return + } + defer refResp.Body.Close() + if refResp.StatusCode != 200 { + body, _ := io.ReadAll(refResp.Body) + c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("failed to get default branch (%d): %s", refResp.StatusCode, string(body))}) + return + } + var refData struct { + Object struct { + SHA string `json:"sha"` + } `json:"object"` + } + json.NewDecoder(refResp.Body).Decode(&refData) + baseSHA := refData.Object.SHA + + // 2. Create branch + date := time.Now().Format("2006-01-02") + slug := slugRegex.ReplaceAllString(strings.ToLower(req.Title), "-") + slug = strings.Trim(slug, "-") + if len(slug) > 60 { + slug = slug[:60] + } + if slug == "" { + slug = "memory" + } + branchName := fmt.Sprintf("learned/%s-%s-%s", req.Type, date, slug) + + branchBody, _ := json.Marshal(map[string]interface{}{ + "ref": "refs/heads/" + branchName, + "sha": baseSHA, + }) + branchResp, err := doGitHubRequest(c.Request.Context(), "POST", + fmt.Sprintf("%s/repos/%s/%s/git/refs", api, req.Owner, req.Repo), + auth, "", bytes.NewReader(branchBody)) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("failed to create branch: %v", err)}) + return + } + defer branchResp.Body.Close() + if branchResp.StatusCode != 201 { + body, _ := io.ReadAll(branchResp.Body) + c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("failed to create branch (%d): %s", branchResp.StatusCode, string(body))}) + return + } + + // 3. 
Create file on branch + filePath := fmt.Sprintf("docs/learned/%ss/%s-%s.md", req.Type, date, slug) + now := time.Now().UTC().Format(time.RFC3339) + author := userID.(string) + fileContent := fmt.Sprintf("---\ntype: %s\ndate: %s\nauthor: %s\ntitle: \"%s\"\n---\n\n%s\n", + req.Type, now, author, req.Title, req.Content) + + fileBody, _ := json.Marshal(map[string]interface{}{ + "message": fmt.Sprintf("learned: %s", req.Title), + "content": base64.StdEncoding.EncodeToString([]byte(fileContent)), + "branch": branchName, + }) + fileResp, err := doGitHubRequest(c.Request.Context(), "PUT", + fmt.Sprintf("%s/repos/%s/%s/contents/%s", api, req.Owner, req.Repo, filePath), + auth, "", bytes.NewReader(fileBody)) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("failed to create file: %v", err)}) + return + } + defer fileResp.Body.Close() + if fileResp.StatusCode != 201 { + body, _ := io.ReadAll(fileResp.Body) + c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("failed to create file (%d): %s", fileResp.StatusCode, string(body))}) + return + } + + // 4. 
Create draft PR + prBody, _ := json.Marshal(map[string]interface{}{ + "title": fmt.Sprintf("learned: %s", req.Title), + "body": fmt.Sprintf("## New Memory\n\n**Type:** %s\n**Source:** Manual entry\n\n---\n\n%s", req.Type, req.Content), + "head": branchName, + "base": "main", + "draft": true, + }) + prResp, err := doGitHubRequest(c.Request.Context(), "POST", + fmt.Sprintf("%s/repos/%s/%s/pulls", api, req.Owner, req.Repo), + auth, "", bytes.NewReader(prBody)) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("failed to create PR: %v", err)}) + return + } + defer prResp.Body.Close() + prRespBody, _ := io.ReadAll(prResp.Body) + if prResp.StatusCode != 201 { + c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("failed to create PR (%d): %s", prResp.StatusCode, string(prRespBody))}) + return + } + + var prResult struct { + HTMLURL string `json:"html_url"` + Number int `json:"number"` + } + json.Unmarshal(prRespBody, &prResult) + + // 5. Add continuous-learning label (best-effort) + labelBody, _ := json.Marshal(map[string]interface{}{ + "labels": []string{"continuous-learning"}, + }) + labelResp, _ := doGitHubRequest(c.Request.Context(), "POST", + fmt.Sprintf("%s/repos/%s/%s/issues/%d/labels", api, req.Owner, req.Repo, prResult.Number), + auth, "", bytes.NewReader(labelBody)) + if labelResp != nil { + labelResp.Body.Close() + } + + c.JSON(http.StatusCreated, gin.H{ + "prUrl": prResult.HTMLURL, + "prNumber": prResult.Number, + }) +} + +// collectMDPaths extracts .md file paths from a GitHub API directory listing. 
+func collectMDPaths(decoded interface{}) []string { + var paths []string + + switch v := decoded.(type) { + case []interface{}: + for _, item := range v { + if m, ok := item.(map[string]interface{}); ok { + name, _ := m["name"].(string) + path, _ := m["path"].(string) + typ, _ := m["type"].(string) + if strings.ToLower(typ) == "file" && strings.HasSuffix(strings.ToLower(name), ".md") { + paths = append(paths, path) + } + } + } + case map[string]interface{}: + name, _ := v["name"].(string) + path, _ := v["path"].(string) + typ, _ := v["type"].(string) + if strings.ToLower(typ) == "file" && strings.HasSuffix(strings.ToLower(name), ".md") { + paths = append(paths, path) + } + } + + return paths +} diff --git a/components/backend/handlers/learned_test.go b/components/backend/handlers/learned_test.go new file mode 100644 index 000000000..7f49fe582 --- /dev/null +++ b/components/backend/handlers/learned_test.go @@ -0,0 +1,122 @@ +//go:build test + +package handlers + +import ( + test_constants "ambient-code-backend/tests/constants" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Learned Handler >", Label(test_constants.LabelUnit, test_constants.LabelHandlers), func() { + + Describe("parseFrontmatter", func() { + It("parses valid frontmatter with all fields", func() { + content := "---\ntype: correction\ndate: 2026-04-01T14:30:00Z\ntitle: Use Pydantic v2\nsession: session-1\nproject: my-project\nauthor: Agent\n---\n\nAlways use Pydantic v2 BaseModel." 
+ fm, body := parseFrontmatter(content) + Expect(fm).NotTo(BeNil()) + Expect(fm["type"]).To(Equal("correction")) + Expect(fm["date"]).To(Equal("2026-04-01T14:30:00Z")) + Expect(fm["title"]).To(Equal("Use Pydantic v2")) + Expect(fm["session"]).To(Equal("session-1")) + Expect(fm["project"]).To(Equal("my-project")) + Expect(fm["author"]).To(Equal("Agent")) + Expect(body).To(Equal("Always use Pydantic v2 BaseModel.")) + }) + + It("handles missing frontmatter", func() { + content := "Just plain text." + fm, body := parseFrontmatter(content) + Expect(fm).To(BeNil()) + Expect(body).To(Equal("Just plain text.")) + }) + + It("handles incomplete frontmatter delimiters", func() { + content := "---\ntype: correction\nNo closing delimiter" + fm, body := parseFrontmatter(content) + Expect(fm).To(BeNil()) + Expect(body).To(Equal(content)) + }) + + It("strips quotes from values", func() { + content := "---\ntitle: \"Quoted Title\"\ntype: 'pattern'\ndate: 2026-04-01\n---\n\nBody." + fm, body := parseFrontmatter(content) + Expect(fm).NotTo(BeNil()) + Expect(fm["title"]).To(Equal("Quoted Title")) + Expect(fm["type"]).To(Equal("pattern")) + Expect(body).To(Equal("Body.")) + }) + + It("handles empty body", func() { + content := "---\ntype: pattern\ntitle: Empty\ndate: 2026-04-01\n---\n" + fm, body := parseFrontmatter(content) + Expect(fm).NotTo(BeNil()) + Expect(fm["type"]).To(Equal("pattern")) + Expect(body).To(Equal("")) + }) + + It("handles optional fields being absent", func() { + content := "---\ntype: pattern\ntitle: Minimal\ndate: 2026-04-01\n---\n\nBody text." 
+ fm, body := parseFrontmatter(content) + Expect(fm).NotTo(BeNil()) + Expect(fm["type"]).To(Equal("pattern")) + Expect(fm["title"]).To(Equal("Minimal")) + Expect(fm["session"]).To(Equal("")) + Expect(fm["project"]).To(Equal("")) + Expect(fm["author"]).To(Equal("")) + Expect(body).To(Equal("Body text.")) + }) + }) + + Describe("collectMDPaths", func() { + It("collects .md files from array", func() { + input := []interface{}{ + map[string]interface{}{"name": "fix.md", "path": "docs/learned/corrections/fix.md", "type": "file"}, + map[string]interface{}{"name": "readme.txt", "path": "docs/learned/readme.txt", "type": "file"}, + map[string]interface{}{"name": "corrections", "path": "docs/learned/corrections", "type": "dir"}, + } + paths := collectMDPaths(input) + Expect(paths).To(HaveLen(1)) + Expect(paths[0]).To(Equal("docs/learned/corrections/fix.md")) + }) + + It("returns empty for empty array", func() { + paths := collectMDPaths([]interface{}{}) + Expect(paths).To(BeEmpty()) + }) + + It("handles single file object", func() { + input := map[string]interface{}{"name": "fix.md", "path": "docs/learned/fix.md", "type": "file"} + paths := collectMDPaths(input) + Expect(paths).To(HaveLen(1)) + Expect(paths[0]).To(Equal("docs/learned/fix.md")) + }) + + It("skips directories", func() { + input := []interface{}{ + map[string]interface{}{"name": "corrections", "path": "docs/learned/corrections", "type": "dir"}, + map[string]interface{}{"name": "patterns", "path": "docs/learned/patterns", "type": "dir"}, + } + paths := collectMDPaths(input) + Expect(paths).To(BeEmpty()) + }) + + It("skips non-md files", func() { + input := []interface{}{ + map[string]interface{}{"name": "image.png", "path": "docs/learned/image.png", "type": "file"}, + map[string]interface{}{"name": "notes.txt", "path": "docs/learned/notes.txt", "type": "file"}, + } + paths := collectMDPaths(input) + Expect(paths).To(BeEmpty()) + }) + + It("handles case-insensitive .MD extension", func() { + input := 
[]interface{}{ + map[string]interface{}{"name": "FIX.MD", "path": "docs/learned/FIX.MD", "type": "file"}, + } + paths := collectMDPaths(input) + Expect(paths).To(HaveLen(1)) + }) + }) +}) diff --git a/components/backend/handlers/learning.go b/components/backend/handlers/learning.go new file mode 100644 index 000000000..ead2bed69 --- /dev/null +++ b/components/backend/handlers/learning.go @@ -0,0 +1,103 @@ +package handlers + +import ( + "net/http" + "strconv" + + "github.com/gin-gonic/gin" +) + +// LearningSummary represents aggregated learning metrics for a project. +type LearningSummary struct { + TotalCorrections int `json:"totalCorrections"` + CorrectionsByType map[string]int `json:"correctionsByType"` + ImprovementSessions int `json:"improvementSessions"` + MemoriesCreated int `json:"memoriesCreated"` + MemoryCitations int `json:"memoryCitations"` +} + +// TimelineEntry represents a single event in the learning timeline. +type TimelineEntry struct { + ID string `json:"id"` + Timestamp string `json:"timestamp"` + EventType string `json:"eventType"` + Summary string `json:"summary"` + CorrectionType string `json:"correctionType,omitempty"` + ImprovementSession string `json:"improvementSession,omitempty"` + MemoryID string `json:"memoryId,omitempty"` +} + +// TimelineResponse wraps timeline entries with pagination metadata. +type TimelineResponse struct { + Items []TimelineEntry `json:"items"` + TotalCount int `json:"totalCount"` + Page int `json:"page"` + PageSize int `json:"pageSize"` +} + +// GetLearningSummary returns aggregated learning metrics for a project. +// GET /api/projects/:projectName/learning/summary +// +// This endpoint returns correction counts, improvement session counts, +// and memory creation counts. Data is sourced from the corrections +// pipeline (spec 003) and project memory store (spec 002). +// +// Until those specs are implemented, this returns zero-value metrics. 
+func GetLearningSummary(c *gin.Context) { + _, dynClient := GetK8sClientsForRequest(c) + if dynClient == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"}) + return + } + + // projectName is validated by ValidateProjectContext middleware + // _ = c.Param("projectName") + + // TODO(spec-002, spec-003): Query actual data from corrections pipeline + // and project memory store once those specs are implemented. + // For now, return empty/zero-value summary. + summary := LearningSummary{ + TotalCorrections: 0, + CorrectionsByType: map[string]int{}, + ImprovementSessions: 0, + MemoriesCreated: 0, + MemoryCitations: 0, + } + + c.JSON(http.StatusOK, summary) +} + +// GetLearningTimeline returns a paginated, reverse-chronological list of +// correction events for a project. +// GET /api/projects/:projectName/learning/timeline?page=1&pageSize=20 +// +// Until specs 002/003 are implemented, this returns an empty list. +func GetLearningTimeline(c *gin.Context) { + _, dynClient := GetK8sClientsForRequest(c) + if dynClient == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"}) + return + } + + // Parse pagination params with defaults + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20")) + + if page < 1 { + page = 1 + } + if pageSize < 1 || pageSize > 100 { + pageSize = 20 + } + + // TODO(spec-002, spec-003): Query actual timeline events from + // corrections pipeline once implemented. 
+ response := TimelineResponse{ + Items: []TimelineEntry{}, + TotalCount: 0, + Page: page, + PageSize: pageSize, + } + + c.JSON(http.StatusOK, response) +} diff --git a/components/backend/handlers/learning_test.go b/components/backend/handlers/learning_test.go new file mode 100644 index 000000000..b8285935b --- /dev/null +++ b/components/backend/handlers/learning_test.go @@ -0,0 +1,114 @@ +//go:build test + +package handlers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + + test_constants "ambient-code-backend/tests/constants" + + "github.com/gin-gonic/gin" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Learning Endpoints", Label(test_constants.LabelUnit), func() { + + var router *gin.Engine + + BeforeEach(func() { + gin.SetMode(gin.TestMode) + router = gin.New() + // Register routes with project context middleware + group := router.Group("/api/projects/:projectName", ValidateProjectContext()) + group.GET("/learning/summary", GetLearningSummary) + group.GET("/learning/timeline", GetLearningTimeline) + }) + + Describe("learning summary", func() { + It("returns 401 without auth header", func() { + req := httptest.NewRequest(http.MethodGet, "/api/projects/test-project/learning/summary", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("returns empty summary with auth", func() { + req := httptest.NewRequest(http.MethodGet, "/api/projects/test-project/learning/summary", nil) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + + var summary LearningSummary + err := json.Unmarshal(w.Body.Bytes(), &summary) + Expect(err).NotTo(HaveOccurred()) + Expect(summary.TotalCorrections).To(Equal(0)) + Expect(summary.CorrectionsByType).To(BeEmpty()) + Expect(summary.ImprovementSessions).To(Equal(0)) + Expect(summary.MemoriesCreated).To(Equal(0)) 
+ Expect(summary.MemoryCitations).To(Equal(0)) + }) + }) + + Describe("learning timeline", func() { + It("returns 401 without auth header", func() { + req := httptest.NewRequest(http.MethodGet, "/api/projects/test-project/learning/timeline", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("returns empty timeline with default pagination", func() { + req := httptest.NewRequest(http.MethodGet, "/api/projects/test-project/learning/timeline", nil) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + + var response TimelineResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + Expect(err).NotTo(HaveOccurred()) + Expect(response.Items).To(BeEmpty()) + Expect(response.TotalCount).To(Equal(0)) + Expect(response.Page).To(Equal(1)) + Expect(response.PageSize).To(Equal(20)) + }) + + It("respects custom pagination parameters", func() { + req := httptest.NewRequest(http.MethodGet, "/api/projects/test-project/learning/timeline?page=3&pageSize=10", nil) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + + var response TimelineResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + Expect(err).NotTo(HaveOccurred()) + Expect(response.Page).To(Equal(3)) + Expect(response.PageSize).To(Equal(10)) + }) + + It("clamps invalid pagination values", func() { + req := httptest.NewRequest(http.MethodGet, "/api/projects/test-project/learning/timeline?page=-1&pageSize=999", nil) + req.Header.Set("Authorization", "Bearer test-token") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + + var response TimelineResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + Expect(err).NotTo(HaveOccurred()) + Expect(response.Page).To(Equal(1)) + 
Expect(response.PageSize).To(Equal(20)) + }) + }) +}) diff --git a/components/backend/handlers/sessions.go b/components/backend/handlers/sessions.go index 55df5788d..891e049d1 100755 --- a/components/backend/handlers/sessions.go +++ b/components/backend/handlers/sessions.go @@ -309,6 +309,10 @@ func parseStatus(status map[string]interface{}) *types.AgenticSessionStatus { } } + if extractionStatus, ok := status["extractionStatus"].(string); ok { + result.ExtractionStatus = extractionStatus + } + if repos, ok := status["reconciledRepos"].([]interface{}); ok && len(repos) > 0 { result.ReconciledRepos = make([]types.ReconciledRepo, 0, len(repos)) for _, entry := range repos { diff --git a/components/backend/main.go b/components/backend/main.go index c75827a70..2413220c6 100644 --- a/components/backend/main.go +++ b/components/backend/main.go @@ -191,6 +191,8 @@ func main() { // Initialize websocket package websocket.StateBaseDir = server.StateBaseDir handlers.DeriveAgentStatusFromEvents = websocket.DeriveAgentStatus + handlers.LoadEventsForExtraction = websocket.LoadEventsForSession + websocket.OnSessionRunComplete = handlers.TriggerExtractionAsync // Normal server mode if err := server.Run(registerRoutes); err != nil { diff --git a/components/backend/routes.go b/components/backend/routes.go index 9fb63050f..29f70095a 100755 --- a/components/backend/routes.go +++ b/components/backend/routes.go @@ -132,6 +132,15 @@ func registerRoutes(r *gin.Engine) { projectGroup.POST("/feature-flags/:flagName/enable", handlers.EnableFeatureFlag) projectGroup.POST("/feature-flags/:flagName/disable", handlers.DisableFeatureFlag) + // Learned files endpoints (project memory store) + projectGroup.GET("/learned", handlers.ListLearnedEntries) + projectGroup.POST("/learned/create", handlers.CreateLearnedPR) + + // Corrections pipeline endpoints (gated by learning-agent-loop feature flag) + projectGroup.POST("/corrections", handlers.PostCorrection) + projectGroup.GET("/corrections", 
handlers.ListCorrections) + projectGroup.GET("/corrections/summary", handlers.GetCorrectionsSummary) + // GitLab authentication endpoints (DEPRECATED - moved to cluster-scoped) // Kept for backward compatibility, will be removed in future version projectGroup.POST("/auth/gitlab/connect", handlers.ConnectGitLabGlobal) diff --git a/components/backend/tests/constants/labels.go b/components/backend/tests/constants/labels.go index bdc7a5d57..3f1e8f112 100644 --- a/components/backend/tests/constants/labels.go +++ b/components/backend/tests/constants/labels.go @@ -26,6 +26,7 @@ const ( LabelFeatureFlags = "feature-flags" LabelDisplayName = "display-name" LabelHealth = "health" + LabelCorrections = "corrections" // Specific component labels for other areas LabelOperations = "operations" // for git operations diff --git a/components/backend/types/session.go b/components/backend/types/session.go index 022822c57..4a76f42bc 100755 --- a/components/backend/types/session.go +++ b/components/backend/types/session.go @@ -59,6 +59,7 @@ type AgenticSessionStatus struct { ReconciledWorkflow *ReconciledWorkflow `json:"reconciledWorkflow,omitempty"` SDKSessionID string `json:"sdkSessionId,omitempty"` SDKRestartCount int `json:"sdkRestartCount,omitempty"` + ExtractionStatus string `json:"extractionStatus,omitempty"` Conditions []Condition `json:"conditions,omitempty"` } diff --git a/components/backend/websocket/agui_store.go b/components/backend/websocket/agui_store.go index 2b76550f4..a83f3e49b 100644 --- a/components/backend/websocket/agui_store.go +++ b/components/backend/websocket/agui_store.go @@ -24,6 +24,12 @@ import ( "time" ) +// OnSessionRunComplete is called when a RUN_FINISHED or RUN_ERROR event +// is persisted. Set by the handlers package at init to trigger post-session +// processing (e.g., insight extraction) without circular imports. 
+// Signature: func(projectName, sessionName string) +var OnSessionRunComplete func(projectName, sessionName string) + // ─── Write mutex eviction ──────────────────────────────────────────── // writeMutexes entries are evicted after writeMutexEvictAge of inactivity // to prevent unbounded sync.Map growth on long-running backends. @@ -209,6 +215,16 @@ func persistEvent(sessionID string, event map[string]interface{}) { default: log.Printf("AGUI Store: compaction skipped for %s (too many in-flight)", sessionID) } + + // Notify session completion callback (e.g., insight extraction) + if OnSessionRunComplete != nil { + if projectName, ok := sessionProjectMap.Load(sessionID); ok { + if pn, ok := projectName.(string); ok && pn != "" { + // Fire-and-forget — the callback is responsible for its own goroutine management + OnSessionRunComplete(pn, sessionID) + } + } + } } } @@ -472,6 +488,13 @@ func loadEventsForReplay(sessionID string) []map[string]interface{} { return events } +// LoadEventsForSession loads all AG-UI events for a session. +// Exported for cross-package access (e.g., post-session insight extraction). +// Returns nil if the session has no events or the session ID is invalid. +func LoadEventsForSession(sessionID string) []map[string]interface{} { + return loadEvents(sessionID) +} + // compactFinishedRun replaces the raw event log with snapshot-only events. // // Per AG-UI serialization spec, finished runs should only store: diff --git a/components/frontend/src/app/api/projects/[name]/corrections/route.ts b/components/frontend/src/app/api/projects/[name]/corrections/route.ts new file mode 100644 index 000000000..42c10b77f --- /dev/null +++ b/components/frontend/src/app/api/projects/[name]/corrections/route.ts @@ -0,0 +1,44 @@ +/** + * Corrections Endpoint Proxy + * Forwards user corrections to backend for Langfuse persistence and optional runner forwarding. 
+ */ + +import { BACKEND_URL } from '@/lib/config' +import { buildForwardHeadersAsync } from '@/lib/auth' + +export const runtime = 'nodejs' +export const dynamic = 'force-dynamic' + +export async function POST( + request: Request, + { params }: { params: Promise<{ name: string }> }, +) { + try { + const { name } = await params + const headers = await buildForwardHeadersAsync(request) + const body = await request.text() + + const backendUrl = `${BACKEND_URL}/projects/${encodeURIComponent(name)}/corrections` + + const resp = await fetch(backendUrl, { + method: 'POST', + headers: { + ...headers, + 'Content-Type': 'application/json', + }, + body, + }) + + const data = await resp.text() + return new Response(data, { + status: resp.status, + headers: { 'Content-Type': 'application/json' }, + }) + } catch (error) { + console.error('Error submitting correction:', error) + return Response.json( + { error: 'Failed to submit correction', details: error instanceof Error ? error.message : String(error) }, + { status: 500 } + ) + } +} diff --git a/components/frontend/src/app/api/projects/[name]/learned/create/route.ts b/components/frontend/src/app/api/projects/[name]/learned/create/route.ts new file mode 100644 index 000000000..98670ea65 --- /dev/null +++ b/components/frontend/src/app/api/projects/[name]/learned/create/route.ts @@ -0,0 +1,105 @@ +import { NextRequest, NextResponse } from "next/server"; +import { BACKEND_URL } from "@/lib/config"; +import { buildForwardHeadersAsync } from "@/lib/auth"; +import { parseOwnerRepo } from "@/lib/github-utils"; + +/** + * POST /api/projects/:name/learned/create + * + * Creates a new memory by opening a draft PR in the target repo. + * Proxies to backend POST /projects/:name/learned/create. 
+ */ +export async function POST( + request: NextRequest, + { params }: { params: Promise<{ name: string }> } +) { + try { + const { name: projectName } = await params; + const headers = await buildForwardHeadersAsync(request); + const body = await request.json(); + + const { title, content, type, repo } = body as { + title: string; + content: string; + type: "correction" | "pattern"; + repo?: string; + }; + + if (!title || !content || !type) { + return NextResponse.json( + { error: "title, content, and type are required" }, + { status: 400 } + ); + } + + if (!["correction", "pattern"].includes(type)) { + return NextResponse.json( + { error: "type must be 'correction' or 'pattern'" }, + { status: 400 } + ); + } + + // Resolve repo: use explicit repo from request, fall back to project annotation + let repoStr = repo || ""; + if (!repoStr) { + const projectRes = await fetch( + `${BACKEND_URL}/projects/${projectName}`, + { method: "GET", headers } + ); + if (projectRes.ok) { + const project = await projectRes.json(); + repoStr = + project?.data?.annotations?.["ambient.ai/repo"] || + project?.annotations?.["ambient.ai/repo"] || + ""; + } + } + + if (!repoStr) { + return NextResponse.json( + { error: "No repository specified. Enter a target repository (owner/repo)." }, + { status: 400 } + ); + } + + const ownerRepo = parseOwnerRepo(repoStr); + if (!ownerRepo) { + return NextResponse.json( + { error: "Invalid repository format. Use owner/repo (e.g. 
jeremyeder/continuous-learning-example)" }, + { status: 400 } + ); + } + + const res = await fetch( + `${BACKEND_URL}/projects/${projectName}/learned/create`, + { + method: "POST", + headers: { ...headers, "Content-Type": "application/json" }, + body: JSON.stringify({ + owner: ownerRepo.owner, + repo: ownerRepo.repo, + title, + content, + type, + }), + } + ); + + if (!res.ok) { + const errBody = await res.text(); + return NextResponse.json( + { error: errBody }, + { status: res.status } + ); + } + + const result = await res.json(); + return NextResponse.json(result); + } catch (error) { + console.error("Failed to create memory:", error); + return NextResponse.json( + { error: "Failed to create memory" }, + { status: 500 } + ); + } +} diff --git a/components/frontend/src/app/api/projects/[name]/learned/prs/route.ts b/components/frontend/src/app/api/projects/[name]/learned/prs/route.ts new file mode 100644 index 000000000..27a64115b --- /dev/null +++ b/components/frontend/src/app/api/projects/[name]/learned/prs/route.ts @@ -0,0 +1,93 @@ +import { NextRequest, NextResponse } from "next/server"; +import { BACKEND_URL } from "@/lib/config"; +import { buildForwardHeadersAsync } from "@/lib/auth"; + +/** + * GET /api/projects/:name/learned/prs + * + * Lists open draft PRs with the "continuous-learning" label. 
+ */
+export async function GET(
+  request: NextRequest,
+  { params }: { params: Promise<{ name: string }> }
+) {
+  try {
+    const { name: projectName } = await params;
+    const headers = await buildForwardHeadersAsync(request);
+
+    // Get repo info from project
+    const projectRes = await fetch(
+      `${BACKEND_URL}/projects/${projectName}`,
+      { method: "GET", headers }
+    );
+    if (!projectRes.ok) {
+      return NextResponse.json({ prs: [] });
+    }
+    const project = await projectRes.json();
+    const repoAnnotation =
+      project?.data?.annotations?.["ambient.ai/repo"] ||
+      project?.annotations?.["ambient.ai/repo"] ||
+      "";
+
+    if (!repoAnnotation) {
+      return NextResponse.json({ prs: [] });
+    }
+
+    const ownerRepo = parseOwnerRepo(repoAnnotation);
+    if (!ownerRepo) {
+      return NextResponse.json({ prs: [] });
+    }
+
+    // Fetch open PRs from GitHub API
+    const searchUrl = `https://api.github.com/repos/${ownerRepo.owner}/${ownerRepo.repo}/pulls?state=open&per_page=50`;
+    const ghHeaders: Record<string, string> = {
+      Accept: "application/vnd.github.v3+json",
+    };
+
+    try {
+      const ghRes = await fetch(searchUrl, { headers: ghHeaders });
+      if (!ghRes.ok) {
+        return NextResponse.json({ prs: [] });
+      }
+
+      const allPrs = await ghRes.json();
+
+      type GHLabel = { name: string };
+      type GHUser = { login: string };
+      type GHPR = {
+        draft: boolean;
+        number: number;
+        title: string;
+        html_url: string;
+        created_at: string;
+        user: GHUser;
+        body: string;
+        labels: GHLabel[];
+      };
+
+      const filteredPrs = (allPrs as GHPR[]).filter(
+        (pr: GHPR) =>
+          pr.draft === true &&
+          pr.labels.some((l: GHLabel) => l.name === "continuous-learning")
+      );
+
+      const prs = filteredPrs.map((pr: GHPR) => ({
+        number: pr.number,
+        title: pr.title,
+        url: pr.html_url,
+        createdAt: pr.created_at,
+        author: pr.user?.login || "",
+        body: pr.body || "",
+      }));
+
+      return NextResponse.json({ prs });
+    } catch {
+      return NextResponse.json({ prs: [] });
+    }
+  } catch (error) {
+    console.error("Failed to fetch learned draft PRs:", error);
+ return NextResponse.json({ prs: [] }); + } +} + +import { parseOwnerRepo } from "@/lib/github-utils"; diff --git a/components/frontend/src/app/api/projects/[name]/learned/route.ts b/components/frontend/src/app/api/projects/[name]/learned/route.ts new file mode 100644 index 000000000..8e2b30deb --- /dev/null +++ b/components/frontend/src/app/api/projects/[name]/learned/route.ts @@ -0,0 +1,183 @@ +import { NextRequest, NextResponse } from "next/server"; +import { BACKEND_URL } from "@/lib/config"; +import { buildForwardHeadersAsync } from "@/lib/auth"; + +/** + * GET /api/projects/:name/learned + * + * Reads docs/learned/ from the workspace repo via the backend repo tree/blob + * endpoints. Parses frontmatter from each markdown file to build entries. + */ +export async function GET( + request: NextRequest, + { params }: { params: Promise<{ name: string }> } +) { + try { + const { name: projectName } = await params; + const headers = await buildForwardHeadersAsync(request); + const searchParams = request.nextUrl.searchParams; + const typeFilter = searchParams.get("type") || ""; + const page = parseInt(searchParams.get("page") || "0", 10); + const pageSize = parseInt(searchParams.get("pageSize") || "50", 10); + + // Get the project to find repo annotation + const projectRes = await fetch( + `${BACKEND_URL}/projects/${projectName}`, + { method: "GET", headers } + ); + if (!projectRes.ok) { + return NextResponse.json({ entries: [], totalCount: 0 }); + } + const project = await projectRes.json(); + const repoAnnotation = + project?.data?.annotations?.["ambient.ai/repo"] || + project?.annotations?.["ambient.ai/repo"] || + ""; + + if (!repoAnnotation) { + return NextResponse.json({ entries: [], totalCount: 0 }); + } + + // Fetch tree for docs/learned/corrections/ and docs/learned/patterns/ + const types = typeFilter ? 
[typeFilter] : ["correction", "pattern"]; + type ParsedEntry = { + title: string; + type: "correction" | "pattern"; + date: string; + author: string; + contentPreview: string; + filePath: string; + source: string; + session: string; + }; + const allEntries: ParsedEntry[] = []; + + for (const t of types) { + const dir = `docs/learned/${t}s`; + const treeParams = new URLSearchParams({ + repo: repoAnnotation, + ref: "HEAD", + path: dir, + }); + + const treeRes = await fetch( + `${BACKEND_URL}/projects/${projectName}/repo/tree?${treeParams.toString()}`, + { method: "GET", headers } + ); + + if (!treeRes.ok) continue; + + const treeData = await treeRes.json(); + const entries = + treeData?.data?.entries || treeData?.entries || []; + + for (const entry of entries as Array<{ path?: string; name?: string }>) { + const fileName = entry.path || entry.name || ""; + if (!fileName.endsWith(".md")) continue; + + const blobParams = new URLSearchParams({ + repo: repoAnnotation, + ref: "HEAD", + path: `${dir}/${fileName}`, + }); + + try { + const blobRes = await fetch( + `${BACKEND_URL}/projects/${projectName}/repo/blob?${blobParams.toString()}`, + { method: "GET", headers } + ); + + if (!blobRes.ok) continue; + + const blobText = await blobRes.text(); + const parsed = parseFrontmatter( + blobText, + `${dir}/${fileName}`, + t as "correction" | "pattern" + ); + if (parsed) { + allEntries.push(parsed); + } + } catch { + continue; + } + } + } + + // Sort by date descending + allEntries.sort((a, b) => b.date.localeCompare(a.date)); + + // Paginate + const totalCount = allEntries.length; + const start = page * pageSize; + const paged = allEntries.slice(start, start + pageSize); + + return NextResponse.json({ entries: paged, totalCount }); + } catch (error) { + console.error("Failed to fetch learned files:", error); + return NextResponse.json({ entries: [], totalCount: 0 }); + } +} + +/** + * Parse YAML frontmatter from a markdown file. 
+ */
+function parseFrontmatter(
+  content: string,
+  filePath: string,
+  fallbackType: "correction" | "pattern"
+): {
+  title: string;
+  type: "correction" | "pattern";
+  date: string;
+  author: string;
+  contentPreview: string;
+  filePath: string;
+  source: string;
+  session: string;
+} | null {
+  const fmMatch = content.match(/^---\s*\n([\s\S]*?)\n---\s*\n([\s\S]*)/);
+  if (!fmMatch) {
+    const filename = filePath.split("/").pop() || "";
+    const bodyContent = content.trim();
+    return {
+      title: filename.replace(/\.md$/, ""),
+      type: fallbackType,
+      date: "",
+      author: "",
+      contentPreview: bodyContent.slice(0, 200),
+      filePath,
+      source: "",
+      session: "",
+    };
+  }
+
+  const frontmatterBlock = fmMatch[1];
+  const body = fmMatch[2].trim();
+
+  // Simple YAML key: value parsing (no nested structures needed)
+  const fm: Record<string, string> = {};
+  for (const line of frontmatterBlock.split("\n")) {
+    const colonIdx = line.indexOf(":");
+    if (colonIdx === -1) continue;
+    const key = line.slice(0, colonIdx).trim();
+    const value = line.slice(colonIdx + 1).trim();
+    fm[key] = value;
+  }
+
+  const validTypes = ["correction", "pattern"];
+  const parsedType = validTypes.includes(fm.type)
+    ?
(fm.type as "correction" | "pattern") + : fallbackType; + + return { + title: fm.title || filePath.split("/").pop()?.replace(/\.md$/, "") || "", + type: parsedType, + date: fm.date || "", + author: fm.author || fm.source || "", + contentPreview: body.slice(0, 200), + filePath, + source: fm.source || "", + session: fm.session || "", + }; +} diff --git a/components/frontend/src/app/api/projects/[name]/learning/summary/route.ts b/components/frontend/src/app/api/projects/[name]/learning/summary/route.ts new file mode 100644 index 000000000..caa89cd91 --- /dev/null +++ b/components/frontend/src/app/api/projects/[name]/learning/summary/route.ts @@ -0,0 +1,33 @@ +import { BACKEND_URL } from "@/lib/config"; +import { buildForwardHeadersAsync } from "@/lib/auth"; + +export async function GET( + request: Request, + { params }: { params: Promise<{ name: string }> } +) { + try { + const { name } = await params; + const headers = await buildForwardHeadersAsync(request); + + const response = await fetch( + `${BACKEND_URL}/projects/${encodeURIComponent(name)}/learning/summary`, + { headers } + ); + + if (!response.ok) { + const errorData = await response + .json() + .catch(() => ({ error: "Unknown error" })); + return Response.json(errorData, { status: response.status }); + } + + const data = await response.json(); + return Response.json(data); + } catch (error) { + console.error("Error fetching learning summary:", error); + return Response.json( + { error: "Failed to fetch learning summary" }, + { status: 500 } + ); + } +} diff --git a/components/frontend/src/app/api/projects/[name]/learning/timeline/route.ts b/components/frontend/src/app/api/projects/[name]/learning/timeline/route.ts new file mode 100644 index 000000000..a7e09d1a1 --- /dev/null +++ b/components/frontend/src/app/api/projects/[name]/learning/timeline/route.ts @@ -0,0 +1,36 @@ +import { BACKEND_URL } from "@/lib/config"; +import { buildForwardHeadersAsync } from "@/lib/auth"; + +export async function GET( + 
request: Request, + { params }: { params: Promise<{ name: string }> } +) { + try { + const { name } = await params; + const headers = await buildForwardHeadersAsync(request); + const url = new URL(request.url); + const page = url.searchParams.get("page") || "1"; + const pageSize = url.searchParams.get("pageSize") || "20"; + + const response = await fetch( + `${BACKEND_URL}/projects/${encodeURIComponent(name)}/learning/timeline?page=${page}&pageSize=${pageSize}`, + { headers } + ); + + if (!response.ok) { + const errorData = await response + .json() + .catch(() => ({ error: "Unknown error" })); + return Response.json(errorData, { status: response.status }); + } + + const data = await response.json(); + return Response.json(data); + } catch (error) { + console.error("Error fetching learning timeline:", error); + return Response.json( + { error: "Failed to fetch learning timeline" }, + { status: 500 } + ); + } +} diff --git a/components/frontend/src/app/projects/[name]/learning/__tests__/learning-page.test.tsx b/components/frontend/src/app/projects/[name]/learning/__tests__/learning-page.test.tsx new file mode 100644 index 000000000..c5db81aa2 --- /dev/null +++ b/components/frontend/src/app/projects/[name]/learning/__tests__/learning-page.test.tsx @@ -0,0 +1,129 @@ +import { render, screen } from "@testing-library/react"; +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import LearningPage from "../page"; +import type { LearningSummary, TimelineResponse } from "@/types/learning"; + +// Mock Next.js navigation +vi.mock("next/navigation", () => ({ + useParams: () => ({ name: "test-project" }), + useRouter: () => ({ push: vi.fn() }), + usePathname: () => "/projects/test-project/learning", +})); + +// Mock feature flag hook - use a mutable object so tests can override +const flagState = { enabled: true, isLoading: false, source: "unleash" as const, error: null }; 
+vi.mock("@/services/queries/use-feature-flags-admin", () => ({
+  useWorkspaceFlag: () => flagState,
+}));
+
+// Mock learning hooks - use mutable state objects
+const summaryState: {
+  data: LearningSummary | undefined;
+  isLoading: boolean;
+  error: null;
+} = {
+  data: undefined,
+  isLoading: false,
+  error: null,
+};
+const timelineState: {
+  data: TimelineResponse | undefined;
+  isLoading: boolean;
+  error: null;
+} = {
+  data: undefined,
+  isLoading: false,
+  error: null,
+};
+
+vi.mock("@/services/queries/use-learning", () => ({
+  useLearningSummary: () => summaryState,
+  useLearningTimeline: () => timelineState,
+}));
+
+function renderPage() {
+  const queryClient = new QueryClient({
+    defaultOptions: { queries: { retry: false } },
+  });
+  return render(
+    <QueryClientProvider client={queryClient}>
+      <LearningPage />
+    </QueryClientProvider>
+  );
+}
+
+describe("LearningPage", () => {
+  beforeEach(() => {
+    summaryState.data = undefined;
+    summaryState.isLoading = false;
+    timelineState.data = undefined;
+    timelineState.isLoading = false;
+    flagState.enabled = true;
+    flagState.isLoading = false;
+  });
+
+  it("shows empty state when no data", () => {
+    summaryState.data = {
+      totalCorrections: 0,
+      correctionsByType: {},
+      improvementSessions: 0,
+      memoriesCreated: 0,
+      memoryCitations: 0,
+    };
+    timelineState.data = {
+      items: [],
+      totalCount: 0,
+      page: 1,
+      pageSize: 20,
+    };
+
+    renderPage();
+    expect(screen.getByText("No learning data yet")).toBeTruthy();
+  });
+
+  it("shows disabled message when flag is off", () => {
+    flagState.enabled = false;
+
+    renderPage();
+    expect(
+      screen.getByText(/learning agent loop feature is not enabled/)
+    ).toBeTruthy();
+  });
+
+  it("renders summary cards with data", () => {
+    summaryState.data = {
+      totalCorrections: 10,
+      correctionsByType: { incomplete: 4, style: 6 },
+      improvementSessions: 3,
+      memoriesCreated: 5,
+      memoryCitations: 12,
+    };
+    timelineState.data = {
+      items: [],
+      totalCount: 0,
+      page: 1,
+      pageSize: 20,
+    };
+
+    renderPage();
+    expect(screen.getByText("10")).toBeTruthy();
expect(screen.getByText("3")).toBeTruthy(); + expect(screen.getByText("5")).toBeTruthy(); + expect(screen.getByText("12")).toBeTruthy(); + }); + + it("renders heading", () => { + summaryState.data = { + totalCorrections: 0, + correctionsByType: {}, + improvementSessions: 0, + memoriesCreated: 0, + memoryCitations: 0, + }; + timelineState.data = { items: [], totalCount: 0, page: 1, pageSize: 20 }; + + renderPage(); + expect(screen.getByText("Learning")).toBeTruthy(); + }); +}); diff --git a/components/frontend/src/app/projects/[name]/learning/page.tsx b/components/frontend/src/app/projects/[name]/learning/page.tsx new file mode 100644 index 000000000..bac1f1241 --- /dev/null +++ b/components/frontend/src/app/projects/[name]/learning/page.tsx @@ -0,0 +1,303 @@ +"use client"; + +import { useParams } from "next/navigation"; +import { + BookOpen, + GitPullRequestArrow, + Brain, + MessageSquareWarning, + Sparkles, +} from "lucide-react"; +import { + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { Skeleton } from "@/components/ui/skeleton"; +import { + useLearningSummary, + useLearningTimeline, +} from "@/services/queries/use-learning"; +import { useWorkspaceFlag } from "@/services/queries/use-feature-flags-admin"; +import type { TimelineEntry } from "@/types/learning"; +import { CORRECTION_TYPE_LABELS } from "@/services/api/corrections"; + +/** Colors for correction type breakdown bars. */ +const CORRECTION_TYPE_COLORS: Record = { + incomplete: "bg-amber-500", + incorrect: "bg-red-500", + out_of_scope: "bg-blue-500", + style: "bg-purple-500", +}; + +function SummaryCard({ + title, + value, + icon: Icon, + description, + isLoading, +}: { + title: string; + value: number; + icon: React.ComponentType<{ className?: string }>; + description: string; + isLoading: boolean; +}) { + return ( + + +
+ + {title} +
+
+ + {isLoading ? ( + + ) : ( +
{value}
+ )} + {description} +
+
+ ); +} + +function CorrectionBreakdown({ + correctionsByType, + total, + isLoading, +}: { + correctionsByType: Record; + total: number; + isLoading: boolean; +}) { + if (isLoading) { + return ( + + + Corrections by Type + + +
+ {Array.from({ length: 4 }).map((_, i) => ( + + ))} +
+
+
+ ); + } + + const types = Object.entries(correctionsByType); + + if (types.length === 0) { + return ( + + + Corrections by Type + + +

+ No correction data yet +

+
+
+ ); + } + + return ( + + + Corrections by Type + + +
+ {types.map(([type, count]) => { + const pct = total > 0 ? (count / total) * 100 : 0; + return ( +
+
+ + {CORRECTION_TYPE_LABELS[type as keyof typeof CORRECTION_TYPE_LABELS] || type} + + + {count} ({Math.round(pct)}%) + +
+
+
+
+
+ ); + })} +
+ + + ); +} + +function TimelineItem({ entry }: { entry: TimelineEntry }) { + return ( +
+
+
+
+
+

{entry.summary}

+
+ {new Date(entry.timestamp).toLocaleDateString()} + {entry.correctionType && ( + + {CORRECTION_TYPE_LABELS[entry.correctionType as keyof typeof CORRECTION_TYPE_LABELS] || + entry.correctionType} + + )} + {entry.memoryId && ( + {entry.memoryId} + )} +
+
+
+ ); +} + +function EmptyState() { + return ( +
+ +

No learning data yet

+

+ As your team works with agents, corrections and feedback are captured + automatically. The learning pipeline transforms these into project + memories that improve future sessions. +

+

+ Start by running sessions and providing corrections when the agent makes + mistakes. Each correction feeds the improvement loop. +

+
+ ); +} + +export default function LearningPage() { + const params = useParams(); + const projectName = params?.name as string; + + const { enabled: flagEnabled, isLoading: flagLoading } = useWorkspaceFlag( + projectName, + "learning-agent-loop" + ); + + const { data: summary, isLoading: summaryLoading } = + useLearningSummary(projectName); + + const { data: timeline, isLoading: timelineLoading } = + useLearningTimeline(projectName); + + // Hide the page entirely when the flag is off (not just loading) + if (!flagLoading && !flagEnabled) { + return ( +
+

+ The learning agent loop feature is not enabled for this workspace. +

+
+ ); + } + + const isLoading = summaryLoading || timelineLoading; + const isEmpty = + !isLoading && + summary?.totalCorrections === 0 && + (!timeline?.items || timeline.items.length === 0); + + return ( +
+
+

Learning

+

+ Track how corrections improve your agents over time. +

+
+ + {isEmpty ? ( + + ) : ( + <> + {/* Summary cards */} +
+ + + + +
+ +
+ {/* Correction breakdown */} + + + {/* Timeline */} + + + Recent Activity + + + {timelineLoading ? ( +
+ {Array.from({ length: 5 }).map((_, i) => ( + + ))} +
+ ) : timeline?.items && timeline.items.length > 0 ? ( +
+ {timeline.items.map((entry) => ( + + ))} +
+ ) : ( +

+ No recent activity +

+ )} +
+
+
+ + )} +
+ ); +} diff --git a/components/frontend/src/app/projects/[name]/memory/page.tsx b/components/frontend/src/app/projects/[name]/memory/page.tsx new file mode 100644 index 000000000..77c27d07e --- /dev/null +++ b/components/frontend/src/app/projects/[name]/memory/page.tsx @@ -0,0 +1,50 @@ +'use client'; + +import { useParams } from 'next/navigation'; +import { ProjectMemorySection } from '@/components/workspace-sections/project-memory-section'; +import { useWorkspaceFlag } from '@/services/queries/use-feature-flags-admin'; + +export default function ProjectMemoryPage() { + const params = useParams(); + const projectName = params?.name as string; + const { enabled: memoryEnabled, isLoading: flagLoading } = useWorkspaceFlag( + projectName, + 'learning-agent-loop' + ); + + if (!projectName) return null; + + if (flagLoading) { + return ( +
+
+
+
+
+
+ ); + } + + if (!memoryEnabled) { + return ( +
+
+

Project Memory

+

+ This feature is not enabled for this workspace. Enable the{" "} + + learning-agent-loop + {" "} + feature flag in Workspace Settings to use Project Memory. +

+
+
+ ); + } + + return ( +
+ +
+ ); +} diff --git a/components/frontend/src/app/projects/[name]/sessions/[sessionName]/components/sessions-sidebar.tsx b/components/frontend/src/app/projects/[name]/sessions/[sessionName]/components/sessions-sidebar.tsx index ba40b4020..6c3bfa44c 100644 --- a/components/frontend/src/app/projects/[name]/sessions/[sessionName]/components/sessions-sidebar.tsx +++ b/components/frontend/src/app/projects/[name]/sessions/[sessionName]/components/sessions-sidebar.tsx @@ -28,6 +28,8 @@ import { Share2, Key, Settings, + BookOpen, + Brain, MoreHorizontal, MoreVertical, Cpu, @@ -51,6 +53,7 @@ import { useProjectAccess } from "@/services/queries/use-project-access"; import { useVersion } from "@/services/queries/use-version"; import { cn } from "@/lib/utils"; import type { AgenticSession } from "@/types/api"; +import { useWorkspaceFlag } from "@/services/queries/use-feature-flags-admin"; type SessionsSidebarProps = { projectName: string; @@ -117,6 +120,8 @@ export function SessionsSidebar({ const hasMore = sessions.length > INITIAL_RECENTS_COUNT && !showAll; + const { enabled: learningEnabled } = useWorkspaceFlag(projectName, "learning-agent-loop"); + const navItems: NavItem[] = useMemo( () => [ { @@ -139,13 +144,25 @@ export function SessionsSidebar({ icon: Key, href: `/projects/${projectName}/keys`, }, + ...(learningEnabled ? 
[ + { + label: "Project Memory", + icon: Brain, + href: `/projects/${projectName}/memory`, + }, + { + label: "Learning", + icon: BookOpen, + href: `/projects/${projectName}/learning`, + }, + ] : []), { label: "Workspace Settings", icon: Settings, href: `/projects/${projectName}/settings`, }, ], - [projectName] + [projectName, learningEnabled] ); if (collapsed) return null; diff --git a/components/frontend/src/components/__tests__/memory-citation-badge.test.tsx b/components/frontend/src/components/__tests__/memory-citation-badge.test.tsx new file mode 100644 index 000000000..869b5e6c4 --- /dev/null +++ b/components/frontend/src/components/__tests__/memory-citation-badge.test.tsx @@ -0,0 +1,112 @@ +import { render, screen } from "@testing-library/react"; +import { describe, it, expect, vi } from "vitest"; +import { + MemoryCitationBadge, + MemoryCitationSummary, +} from "../memory-citation-badge"; + +// Mock Radix UI Popover for testability +vi.mock("radix-ui", () => ({ + Popover: { + Root: ({ children }: { children: React.ReactNode }) =>
{children}
, + Trigger: ({ + children, + ...props + }: { children: React.ReactNode } & Record) => ( + + ), + Portal: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), + Content: ({ + children, + ...props + }: { children: React.ReactNode } & Record) => ( +
+ {children} +
+ ), + }, + Tooltip: { + Provider: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), + Root: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), + Trigger: ({ children }: { children: React.ReactNode }) => ( + {children} + ), + Portal: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), + Content: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), + Arrow: () => null, + }, +})); + +describe("MemoryCitationBadge", () => { + it("renders memory ID in badge", () => { + render(); + // Badge text appears in both the trigger and the popover header + const matches = screen.getAllByText("PM-042"); + expect(matches.length).toBeGreaterThanOrEqual(1); + }); + + it("renders with summary when provided", () => { + render( + + ); + const matches = screen.getAllByText("PM-042"); + expect(matches.length).toBeGreaterThanOrEqual(1); + }); + + it("renders warning state for non-existent memory (FR-015)", () => { + render(); + const badge = screen.getByText("PM-999"); + expect(badge).toBeTruthy(); + // Should have warning indicator via tooltip + expect(screen.getByTestId("tooltip-content")).toBeTruthy(); + }); + + it("renders deprecated state with data attribute (FR-014)", () => { + render(); + const matches = screen.getAllByText("PM-042"); + // Find the one inside a button with data-deprecated + const deprecatedButton = matches + .map((el) => el.closest("[data-deprecated]")) + .find(Boolean); + expect(deprecatedButton).toBeTruthy(); + }); + + it("shows popover content with full details", () => { + render( + + ); + expect(screen.getByTestId("popover-content")).toBeTruthy(); + expect(screen.getByText(/Full memory content here/)).toBeTruthy(); + }); +}); + +describe("MemoryCitationSummary", () => { + it("renders citation count when over 10 (FR-016)", () => { + render(); + expect(screen.getByText("12 memories cited")).toBeTruthy(); + }); + + it("does not render when count is 10 or fewer", () => { + const { container } = render(); + expect(container.textContent).toBe(""); + }); +}); diff --git a/components/frontend/src/components/feedback/CorrectionPopover.tsx b/components/frontend/src/components/feedback/CorrectionPopover.tsx new file mode 100644 index 000000000..0924ab18b --- /dev/null +++ b/components/frontend/src/components/feedback/CorrectionPopover.tsx @@ -0,0 +1,277 @@ +"use client"; + +import React, { useState, 
useCallback } from "react"; +import { PencilLine, Check, Loader2 } from "lucide-react"; +import { cn } from "@/lib/utils"; +import { + Popover, + PopoverContent, + PopoverTrigger, +} from "@/components/ui/popover"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Textarea } from "@/components/ui/textarea"; +import { Checkbox } from "@/components/ui/checkbox"; +import { Button } from "@/components/ui/button"; +import { Label } from "@/components/ui/label"; +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from "@/components/ui/tooltip"; +import { useFeedbackContextOptional } from "@/contexts/FeedbackContext"; +import type { CorrectionType } from "@/services/api/corrections"; +import { submitCorrection, CORRECTION_TYPE_LABELS } from "@/services/api/corrections"; + +const MIN_CORRECTION_LENGTH = 10; +const MAX_CORRECTION_LENGTH = 2000; +const SUCCESS_COOLDOWN_MS = 2000; + +type CorrectionPopoverProps = { + messageId?: string; + messageContent?: string; + className?: string; +}; + +export function CorrectionPopover({ + messageId, + messageContent, + className, +}: CorrectionPopoverProps) { + const [open, setOpen] = useState(false); + const [correctionType, setCorrectionType] = useState(""); + const [correctionText, setCorrectionText] = useState(""); + const [includeContent, setIncludeContent] = useState(false); + const [isSubmitting, setIsSubmitting] = useState(false); + const [error, setError] = useState(null); + const [submittedCount, setSubmittedCount] = useState(0); + const [submitCooldown, setSubmitCooldown] = useState(false); + + const feedbackContext = useFeedbackContextOptional(); + + const charCount = correctionText.length; + const isValid = + correctionType !== "" && + charCount >= MIN_CORRECTION_LENGTH && + charCount <= MAX_CORRECTION_LENGTH; + const canSubmit = isValid && !isSubmitting && !submitCooldown; + + const resetForm = useCallback(() => { + 
setCorrectionType(""); + setCorrectionText(""); + setIncludeContent(false); + setError(null); + }, []); + + const handleSubmit = async () => { + if (!canSubmit || !feedbackContext || !messageId) return; + + setIsSubmitting(true); + setError(null); + + try { + await submitCorrection(feedbackContext.projectName, { + correction_type: correctionType as CorrectionType, + user_correction: correctionText, + session_name: feedbackContext.sessionName, + message_id: messageId, + message_content: includeContent ? messageContent : undefined, + source: "ui", + }); + + setSubmittedCount((prev) => prev + 1); + resetForm(); + setOpen(false); + + // Prevent rapid re-submission + setSubmitCooldown(true); + setTimeout(() => setSubmitCooldown(false), SUCCESS_COOLDOWN_MS); + } catch (err) { + setError( + err instanceof Error ? err.message : "Failed to submit correction" + ); + } finally { + setIsSubmitting(false); + } + }; + + const handleCancel = () => { + resetForm(); + setOpen(false); + }; + + // Don't render if no context available + if (!feedbackContext) { + return null; + } + + const hasSubmitted = submittedCount > 0; + + return ( + + + + + + + + + + {hasSubmitted ? "Correction submitted" : "Correct this"} + + + + { + // Prevent closing when submitting (Select dropdown renders in a portal) + if (isSubmitting) { + e.preventDefault(); + } + }} + > +
+
Correct this response
+ + {/* Correction type */} +
+ + +
+ + {/* Free-text correction */} +
+ + +
+ +
+ +
+
+
+
+ +
Click ✎ Correct this to try the interactive popover
+ + + +
+
+ 006 +

Visible Attribution

+

Phase 3

+
+
Dependencies: 002, 003, 005
+
+
+

Agent cites memories inline

+
+
+
+ session-debug-latency +
+
+
+
A
+
I've analyzed the production error rates. Per the project's established practice PM-042, I excluded all int-* instances.

The filtered p99 latency is 280ms with a 0.8% error rate across 712 production instances. Running analysis using table-driven tests PM-071 as the project prefers.
+
+
+
+
+
+

Corrections Impact Dashboard

+
+
Corrections
47
↑ 12 this week
+
Improvement Sessions
18
↑ 4 this week
+
Memories Created
23
↑ 6 this week
+
Citations
89
↑ 31 this week
+
+
+

Recent Activity

+
+
2h ago
Memory PM-089 created from insight extraction
+
5h ago
Improvement session triggered for review-workflow
+
6h ago
Correction: incorrect — included internal instances
+
+
+
+
+
+
GET/api/projects/:name/learning/summaryDashboard metrics
+
GET/api/projects/:name/learning/timelineActivity feed
+
+
+ + +
+
+ 007 +

Event-Driven Feedback Loop

+

Phase 2

+
+
Dependencies: 003
+
+
+
+
+

Correction Threshold Monitor

+ +
+
+
+ review-workflow +
+ 4 ✓ +
+
+ platform/backend +
+ 1 +
+
+ deploy-workflow +
+ 0 +
+
+
+ Improvement session auto-triggered for review-workflow (4 corrections in 3h) +
+
+
+
+
+

Feedback Loop Configuration

+
+
Auto-triggerCreate sessions on threshold
+
+
+
+
Min CorrectionsPer target to trigger
+ +
+
+
Time WindowRolling window in hours
+ +
+
+
Weekly SweepGHA batch still runs
+
✓ Active — Mon 9am UTC
+
+
+
+
+
+
GET/api/projects/:name/feedback-loop/configGet config
+
PUT/api/projects/:name/feedback-loop/configUpdate config
+
GET/api/projects/:name/feedback-loop/historyTriggered sessions
+
+
+ + +
+
+ 008 +

Cross-Session Memory

+

Phase 2

+
+
Dependencies: 002
+
+
+

Agent suggests a memory via MCP tool

+
+
session-onboard-redis
+
+
+
A
+
+
🔧 Tool call: suggest_memory
+
+
💡 New Memory Suggested
+
"The staging cluster uses a shared Redis at redis-staging.internal:6379. Production uses per-tenant Redis with TLS."
+
+ Environment + Pending Review +
+
+ I've noted this for the team. This discovery will be available for review in the Project Memory panel. +
+
+
+
+
+
+

Injected into system prompt

+
+
## Project Memory

+
### Environment
+
- [PM-042] Staging uses shared Redis at redis-staging.internal:6379
+
- [PM-055] CI runners use Ubuntu 22.04 with Go 1.22 and Node 20

+
### Procedure
+
- [PM-018] Always run make lint before committing Go changes
+
- [PM-033] Use table-driven tests for Go backend test cases

+
### Preference
+
- [PM-071] Exclude int-* instances from production error analysis

+
5 memories loaded · 1,240 tokens · Budget: 4,000
+
+
+
+
+ + +
+
+ 009 +

Post-Session Insight Extraction

+

Phase 3

+
+
Dependencies: 002, 008
+
+
01
Session Completes
Status changed
12 turns detected
+
02
📜
Fetch Transcript
AG-UI event store
Compact and format
+
03
🤖
LLM Extraction
Haiku-class model
Structured JSON output
+
04
🗃
Write Candidates
Project Memory Store
status = candidate
+
05
👤
Human Review
Approve / Edit
Dismiss irrelevant
+
+
+ +
+
+
+

Extracted Candidates

+
+
+ Candidate + Environment + confidence: 0.92 +
+
The staging cluster uses a shared Redis instance at redis-staging.internal:6379. Production uses per-tenant Redis with TLS.
+ +
+
+
+ Candidate + Procedure + confidence: 0.78 +
+
When debugging Redis connectivity, check the ConfigMap redis-config first for connection strings and TLS settings.
+ +
+
+
+

Extraction Configuration

+
+
EnabledAuto-extract after sessions
+
ModelFor extraction calls
+
Max per sessionCandidate limit
+
Min turnsSkip short sessions
+
+
+
GET/api/projects/:name/insight-extraction/configGet config
+
PUT/api/projects/:name/insight-extraction/configUpdate config
+
+
+
+
+ + + +
+ + + + diff --git a/specs/learning-agent-realistic-mockup.html b/specs/learning-agent-realistic-mockup.html new file mode 100644 index 000000000..e69584412 --- /dev/null +++ b/specs/learning-agent-realistic-mockup.html @@ -0,0 +1,1745 @@ + + + + + +Ambient Code Platform - Learning Agent Mockup + + + + + + + + +
+ + +
+ + +
+ +
+
+
+ +
+ + + + +