Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
906 changes: 906 additions & 0 deletions internal/rollback/store.go

Large diffs are not rendered by default.

716 changes: 716 additions & 0 deletions internal/rollback/store_test.go

Large diffs are not rendered by default.

148 changes: 132 additions & 16 deletions internal/tools/apply_patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"strings"

"github.com/1024XEngineer/bytemind/internal/llm"
rollbackpkg "github.com/1024XEngineer/bytemind/internal/rollback"
)

type ApplyPatchTool struct{}
Expand All @@ -34,6 +35,13 @@ type patchHunk struct {
Lines []patchLine
}

type plannedPatchOperation struct {
Type string
OldPath string
NewPath string
Content string
}

var unifiedHunkHeaderPattern = regexp.MustCompile(`^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@(?: .*)?$`)

func (ApplyPatchTool) Definition() llm.ToolDefinition {
Expand All @@ -56,7 +64,7 @@ func (ApplyPatchTool) Definition() llm.ToolDefinition {
}
}

func (ApplyPatchTool) Run(_ context.Context, raw json.RawMessage, execCtx *ExecutionContext) (string, error) {
func (ApplyPatchTool) Run(ctx context.Context, raw json.RawMessage, execCtx *ExecutionContext) (string, error) {
var args struct {
Patch string `json:"patch"`
}
Expand All @@ -78,10 +86,31 @@ func (ApplyPatchTool) Run(_ context.Context, raw json.RawMessage, execCtx *Execu
i := 1
operations := make([]map[string]any, 0, 4)
diffFiles := make([]DiffFile, 0, 4)
planned := make([]plannedPatchOperation, 0, 4)
targets := make([]rollbackpkg.FileTarget, 0, 4)
touchedPaths := map[string]string{}
for i < len(lines) {
line := lines[i]
if line == "*** End Patch" {
return toJSON(buildPatchResult(operations, diffFiles))
tracker, err := beginRollbackOperation(ctx, execCtx, "apply_patch", targets)
if err != nil {
return "", err
}
if err := applyPlannedPatchOperations(planned); err != nil {
if tracker != nil {
tracker.abort(ctx, err.Error())
}
return "", err
}
operationID, err := tracker.commit(ctx)
if err != nil {
return "", err
}
result := buildPatchResult(operations, diffFiles)
if operationID != "" {
result["rollback_operation_id"] = operationID
}
return toJSON(result)
}

switch {
Expand All @@ -100,14 +129,26 @@ func (ApplyPatchTool) Run(_ context.Context, raw json.RawMessage, execCtx *Execu
if err != nil {
return "", err
}
if err := os.MkdirAll(filepath.Dir(resolved), 0o755); err != nil {
if _, err := os.Stat(resolved); err == nil {
return "", fmt.Errorf("add file target already exists: %s", path)
} else if err != nil && !errors.Is(err, os.ErrNotExist) {
return "", err
}
content := joinLines(contentLines, true, patchLineEnding)
if err := os.WriteFile(resolved, []byte(content), 0o644); err != nil {
if err := registerPatchTarget(touchedPaths, resolved, path); err != nil {
return "", err
}
content := joinLines(contentLines, true, patchLineEnding)
relPath := filepath.ToSlash(mustRel(execCtx.Workspace, resolved))
planned = append(planned, plannedPatchOperation{
Type: "add",
NewPath: resolved,
Content: content,
})
targets = append(targets, rollbackpkg.FileTarget{
Path: relPath,
AbsPath: resolved,
OpType: rollbackpkg.OpTypeAdd,
})
operations = append(operations, map[string]any{"type": "add", "path": relPath})
diffFiles = append(diffFiles, DiffFile{
Path: relPath,
Expand All @@ -122,10 +163,19 @@ func (ApplyPatchTool) Run(_ context.Context, raw json.RawMessage, execCtx *Execu
return "", err
}
removed := lineCount(resolved)
if err := os.Remove(resolved); err != nil {
relPath := filepath.ToSlash(mustRel(execCtx.Workspace, resolved))
if err := registerPatchTarget(touchedPaths, resolved, path); err != nil {
return "", err
}
relPath := filepath.ToSlash(mustRel(execCtx.Workspace, resolved))
planned = append(planned, plannedPatchOperation{
Type: "delete",
OldPath: resolved,
})
targets = append(targets, rollbackpkg.FileTarget{
Path: relPath,
AbsPath: resolved,
OpType: rollbackpkg.OpTypeDelete,
})
operations = append(operations, map[string]any{"type": "delete", "path": relPath})
diffFiles = append(diffFiles, DiffFile{
Path: relPath,
Expand All @@ -148,6 +198,21 @@ func (ApplyPatchTool) Run(_ context.Context, raw json.RawMessage, execCtx *Execu
}
i++
}
if newPath != oldPath {
if _, err := os.Stat(newPath); err == nil {
return "", fmt.Errorf("move target already exists: %s", filepath.ToSlash(mustRel(execCtx.Workspace, newPath)))
} else if err != nil && !errors.Is(err, os.ErrNotExist) {
return "", err
}
}
if err := registerPatchTarget(touchedPaths, oldPath, path); err != nil {
return "", err
}
if newPath != oldPath {
if err := registerPatchTarget(touchedPaths, newPath, filepath.ToSlash(mustRel(execCtx.Workspace, newPath))); err != nil {
return "", err
}
}
chunkLines := make([]string, 0, 64)
for i < len(lines) && !strings.HasPrefix(lines[i], "*** ") {
chunkLines = append(chunkLines, lines[i])
Expand All @@ -161,19 +226,28 @@ func (ApplyPatchTool) Run(_ context.Context, raw json.RawMessage, execCtx *Execu
if err != nil {
return "", err
}
if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
return "", err
relOld := filepath.ToSlash(mustRel(execCtx.Workspace, oldPath))
relNew := filepath.ToSlash(mustRel(execCtx.Workspace, newPath))
opType := rollbackpkg.OpTypeUpdate
if newPath != oldPath {
opType = rollbackpkg.OpTypeMove
}
if err := os.WriteFile(newPath, []byte(updated), 0o644); err != nil {
return "", err
planned = append(planned, plannedPatchOperation{
Type: "update",
OldPath: oldPath,
NewPath: newPath,
Content: updated,
})
target := rollbackpkg.FileTarget{
Path: relOld,
AbsPath: oldPath,
OpType: opType,
}
if newPath != oldPath {
if err := os.Remove(oldPath); err != nil {
return "", err
}
target.NewPath = relNew
target.NewAbsPath = newPath
}
relOld := filepath.ToSlash(mustRel(execCtx.Workspace, oldPath))
relNew := filepath.ToSlash(mustRel(execCtx.Workspace, newPath))
targets = append(targets, target)
operations = append(operations, map[string]any{
"type": "update",
"path": relOld,
Expand Down Expand Up @@ -201,6 +275,48 @@ func (ApplyPatchTool) Run(_ context.Context, raw json.RawMessage, execCtx *Execu
return "", errors.New("patch missing *** End Patch")
}

func registerPatchTarget(seen map[string]string, absPath, label string) error {
key := filepath.Clean(absPath)
if previous, ok := seen[key]; ok {
return fmt.Errorf("patch touches %s more than once (already used by %s)", label, previous)
}
seen[key] = label
return nil
}

func applyPlannedPatchOperations(operations []plannedPatchOperation) error {
for _, op := range operations {
switch op.Type {
case "add":
if err := os.MkdirAll(filepath.Dir(op.NewPath), 0o755); err != nil {
return err
}
if err := os.WriteFile(op.NewPath, []byte(op.Content), 0o644); err != nil {
return err
}
case "delete":
if err := os.Remove(op.OldPath); err != nil {
return err
}
case "update":
if err := os.MkdirAll(filepath.Dir(op.NewPath), 0o755); err != nil {
return err
}
if err := os.WriteFile(op.NewPath, []byte(op.Content), 0o644); err != nil {
return err
}
if op.NewPath != op.OldPath {
if err := os.Remove(op.OldPath); err != nil {
return err
}
}
default:
return fmt.Errorf("unsupported planned patch operation: %s", op.Type)
}
}
return nil
}

func applyStructuredPatch(original string, chunkLines []string) (string, error) {
hunks, err := parsePatchHunks(chunkLines)
if err != nil {
Expand Down
23 changes: 21 additions & 2 deletions internal/tools/replace_in_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"

"github.com/1024XEngineer/bytemind/internal/llm"
rollbackpkg "github.com/1024XEngineer/bytemind/internal/rollback"
)

type ReplaceInFileTool struct{}
Expand Down Expand Up @@ -45,7 +46,7 @@ func (ReplaceInFileTool) Definition() llm.ToolDefinition {
}
}

func (ReplaceInFileTool) Run(_ context.Context, raw json.RawMessage, execCtx *ExecutionContext) (string, error) {
func (ReplaceInFileTool) Run(ctx context.Context, raw json.RawMessage, execCtx *ExecutionContext) (string, error) {
var args struct {
Path string `json:"path"`
Old string `json:"old"`
Expand Down Expand Up @@ -78,12 +79,27 @@ func (ReplaceInFileTool) Run(_ context.Context, raw json.RawMessage, execCtx *Ex
} else {
updated = strings.Replace(content, args.Old, args.New, 1)
}
relPath := filepath.ToSlash(mustRel(execCtx.Workspace, path))
tracker, err := beginRollbackOperation(ctx, execCtx, "replace_in_file", []rollbackpkg.FileTarget{{
Path: relPath,
AbsPath: path,
OpType: rollbackpkg.OpTypeUpdate,
}})
if err != nil {
return "", err
}
if err := os.WriteFile(path, []byte(updated), 0o644); err != nil {
if tracker != nil {
tracker.abort(ctx, err.Error())
}
return "", err
}
operationID, err := tracker.commit(ctx)
if err != nil {
return "", err
}

var result map[string]any
relPath := filepath.ToSlash(mustRel(execCtx.Workspace, path))
if diffPreview := buildReplaceDiff(content, args.Old, args.New, args.ReplaceAll, relPath); diffPreview != nil {
result = map[string]any{
"ok": true,
Expand All @@ -100,6 +116,9 @@ func (ReplaceInFileTool) Run(_ context.Context, raw json.RawMessage, execCtx *Ex
"old_count": count,
}
}
if operationID != "" {
result["rollback_operation_id"] = operationID
}
return toJSON(result)
}

Expand Down
77 changes: 77 additions & 0 deletions internal/tools/rollback_helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package tools

import (
"context"
"strings"

rollbackpkg "github.com/1024XEngineer/bytemind/internal/rollback"
)

type rollbackTracker struct {
store *rollbackpkg.Store
op *rollbackpkg.Operation
roots []string
}

func beginRollbackOperation(ctx context.Context, execCtx *ExecutionContext, toolName string, targets []rollbackpkg.FileTarget) (*rollbackTracker, error) {
if !rollbackEnabled(execCtx) {
return nil, nil
}
if len(targets) == 0 {
return nil, nil
}
store, err := rollbackpkg.NewDefaultStore()
if err != nil {
return nil, err
}
op, err := store.Begin(ctx, rollbackpkg.BeginOptions{
Workspace: strings.TrimSpace(execCtx.Workspace),
SessionID: rollbackSessionID(execCtx),
TraceID: strings.TrimSpace(execCtx.RunID),
ToolName: toolName,
Actor: "agent",
}, targets)
if err != nil {
return nil, err
}
return &rollbackTracker{
store: store,
op: op,
roots: writableRootsFromExecContext(execCtx),
}, nil
}

func rollbackEnabled(execCtx *ExecutionContext) bool {
if execCtx == nil {
return false
}
return strings.TrimSpace(execCtx.RunID) != ""
}

func rollbackSessionID(execCtx *ExecutionContext) string {
if execCtx == nil || execCtx.Session == nil {
return ""
}
return strings.TrimSpace(execCtx.Session.ID)
}

func (t *rollbackTracker) commit(ctx context.Context) (string, error) {
if t == nil || t.store == nil || t.op == nil {
return "", nil
}
if err := t.store.Commit(ctx, t.op); err != nil {
restoreErr := t.store.AbortAndRestore(ctx, t.op, err.Error(), t.roots...)
if restoreErr != nil {
return "", restoreErr
}
return "", err
}
return t.op.OperationID, nil
}

func (t *rollbackTracker) abort(ctx context.Context, reason string) {
if t == nil || t.store == nil || t.op == nil {
return
}
_ = t.store.AbortAndRestore(ctx, t.op, reason, t.roots...)
}
Loading
Loading