diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml deleted file mode 100644 index fd65da18..00000000 --- a/.github/workflows/claude-code-review.yml +++ /dev/null @@ -1,57 +0,0 @@ -name: Claude Code Review - -on: - pull_request: - types: [opened, synchronize] - # Optional: Only run on specific file changes - # paths: - # - "src/**/*.ts" - # - "src/**/*.tsx" - # - "src/**/*.js" - # - "src/**/*.jsx" - -jobs: - claude-review: - # Optional: Filter by PR author - # if: | - # github.event.pull_request.user.login == 'external-contributor' || - # github.event.pull_request.user.login == 'new-developer' || - # github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR' - - runs-on: ubuntu-latest - permissions: - contents: read - pull-requests: read - issues: read - id-token: write - - steps: - - name: Checkout repository - uses: actions/checkout@v6 - with: - fetch-depth: 1 - - - name: Run Claude Code Review - id: claude-review - uses: anthropics/claude-code-action@v1 - with: - claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} - prompt: | - REPO: ${{ github.repository }} - PR NUMBER: ${{ github.event.pull_request.number }} - - Please review this pull request and provide feedback on: - - Code quality and best practices - - Potential bugs or issues - - Performance considerations - - Security concerns - - Test coverage - - Use the repository's CLAUDE.md for guidance on style and conventions. Be constructive and helpful in your feedback. - - Use `gh pr comment` with your Bash tool to leave your review as a comment on the PR. - - # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md - # or https://code.claude.com/docs/en/cli-reference for available options - claude_args: '--allowed-tools "Bash(gh issue view:*),Bash(gh search:*),Bash(gh issue list:*),Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*)"' - diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml deleted file mode 100644 index 94dcce55..00000000 --- a/.github/workflows/claude.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: Claude Code - -on: - issue_comment: - types: [created] - pull_request_review_comment: - types: [created] - issues: - types: [opened, assigned] - pull_request_review: - types: [submitted] - -jobs: - claude: - if: | - (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) || - (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) || - (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) || - (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude'))) - runs-on: ubuntu-latest - permissions: - contents: read - pull-requests: read - issues: read - id-token: write - actions: read # Required for Claude to read CI results on PRs - steps: - - name: Checkout repository - uses: actions/checkout@v6 - with: - fetch-depth: 1 - - - name: Run Claude Code - id: claude - uses: anthropics/claude-code-action@v1 - with: - claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} - - # This is an optional setting that allows Claude to read CI results on PRs - additional_permissions: | - actions: read - - # Optional: Give a custom prompt to Claude. If this is not specified, Claude will perform the instructions specified in the comment that tagged it. - # prompt: 'Update the pull request description to include a summary of changes.' - - # Optional: Add claude_args to customize behavior and configuration - # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md - # or https://code.claude.com/docs/en/cli-reference for available options - # claude_args: '--allowed-tools Bash(gh pr:*)' - diff --git a/.gitignore b/.gitignore index 278435e3..d7cf440d 100644 --- a/.gitignore +++ b/.gitignore @@ -39,6 +39,7 @@ Thumbs.db # AI assistant CLAUDE.md AGENTS.md +.opencode/ .serena/ .sisyphus/ diff --git a/custom/ec2/instances/actions.go b/custom/ec2/instances/actions.go index 6cf1df9d..333b1123 100644 --- a/custom/ec2/instances/actions.go +++ b/custom/ec2/instances/actions.go @@ -45,7 +45,7 @@ func init() { Name: "SSM Session", Shortcut: "x", Type: action.ActionTypeExec, - Command: "aws ssm start-session --target ${ID}", + Args: []string{"aws", "ssm", "start-session", "--target", "${ID}"}, }, }) diff --git a/custom/rds/instances/actions.go b/custom/rds/instances/actions.go index bbd9b782..c6cbe89f 100644 --- a/custom/rds/instances/actions.go +++ b/custom/rds/instances/actions.go @@ -8,7 +8,6 @@ import ( rdsClient "github.com/clawscli/claws/custom/rds" "github.com/clawscli/claws/internal/action" - appaws "github.com/clawscli/claws/internal/aws" "github.com/clawscli/claws/internal/dao" ) @@ -152,22 +151,14 @@ func executeDeleteInstance(ctx context.Context, resource dao.Resource) action.Ac return action.InvalidResourceResult() } - client, err := rdsClient.GetClient(ctx) + d, err := NewInstanceDAO(ctx) if err != nil { return action.ActionResult{Success: false, Error: err} } identifier := instance.GetID() - skipFinalSnapshot := true - input := &rds.DeleteDBInstanceInput{ - DBInstanceIdentifier: &identifier, - SkipFinalSnapshot: &skipFinalSnapshot, - DeleteAutomatedBackups: appaws.BoolPtr(true), - } - - _, err = client.DeleteDBInstance(ctx, input) - if err != nil { - return action.ActionResult{Success: false, Error: fmt.Errorf("delete db instance: %w", err)} + if err := d.Delete(ctx, identifier); err != nil { + return action.ActionResult{Success: false, Error: err} } return action.ActionResult{ diff --git a/custom/vpc/tgw-attachments/dao.go b/custom/vpc/tgw-attachments/dao.go index cc4961b3..f16c6414 100644 --- a/custom/vpc/tgw-attachments/dao.go +++ b/custom/vpc/tgw-attachments/dao.go @@ -85,7 +85,10 @@ func (d *TGWAttachmentDAO) Delete(ctx context.Context, id string) error { if err != nil { return err } - attRes := att.(*TGWAttachmentResource) + attRes, ok := att.(*TGWAttachmentResource) + if !ok { + return fmt.Errorf("unexpected transit gateway attachment resource type %T", att) + } switch attRes.ResourceType() { case "vpc": diff --git a/custom/vpc/vpcs/integration_test.go b/custom/vpc/vpcs/integration_test.go index 9b4c10e2..34150d49 100644 --- a/custom/vpc/vpcs/integration_test.go +++ b/custom/vpc/vpcs/integration_test.go @@ -98,8 +98,8 @@ func TestIntegration_VPCDAO_ServiceInfo(t *testing.T) { t.Fatalf("Failed to create VPCDAO: %v", err) } - if dao.ServiceName() != "ec2" { - t.Errorf("ServiceName() = %q, want %q", dao.ServiceName(), "ec2") + if dao.ServiceName() != "vpc" { + t.Errorf("ServiceName() = %q, want %q", dao.ServiceName(), "vpc") } if dao.ResourceType() != "vpcs" { t.Errorf("ResourceType() = %q, want %q", dao.ResourceType(), "vpcs") diff --git a/custom/wafv2/web-acls/dao.go b/custom/wafv2/web-acls/dao.go index b499fba8..0079e05c 100644 --- a/custom/wafv2/web-acls/dao.go +++ b/custom/wafv2/web-acls/dao.go @@ -2,16 +2,20 @@ package webacls import ( "context" + "errors" "fmt" "github.com/aws/aws-sdk-go-v2/service/wafv2" "github.com/aws/aws-sdk-go-v2/service/wafv2/types" appaws "github.com/clawscli/claws/internal/aws" + appconfig "github.com/clawscli/claws/internal/config" "github.com/clawscli/claws/internal/dao" apperrors "github.com/clawscli/claws/internal/errors" ) +const cloudFrontWebACLRegion = "us-east-1" + // WebACLDAO provides data access for WAFv2 Web ACLs type WebACLDAO struct { dao.BaseDAO @@ -41,19 +45,26 @@ func (d *WebACLDAO) List(ctx context.Context) ([]dao.Resource, error) { } resources = append(resources, regionalResources...) - // List CLOUDFRONT Web ACLs (only available in us-east-1) - // We'll try to list CloudFront scope but it may fail if not in us-east-1 + // List CLOUDFRONT Web ACLs only from us-east-1, where WAFv2 exposes CloudFront scope. + if currentRegion(ctx) != cloudFrontWebACLRegion { + return resources, nil + } cloudfrontResources, err := d.listByScope(ctx, types.ScopeCloudfront) if err != nil { - // CloudFront scope may fail if not in us-east-1, ignore this error - // and just return regional resources - return resources, nil + return resources, fmt.Errorf("list cloudfront web acls: %w", err) } resources = append(resources, cloudfrontResources...) return resources, nil } +func currentRegion(ctx context.Context) string { + if region := appaws.GetRegionFromContext(ctx); region != "" { + return region + } + return appconfig.Global().Region() +} + func (d *WebACLDAO) listByScope(ctx context.Context, scope types.Scope) ([]dao.Resource, error) { acls, err := appaws.Paginate(ctx, func(token *string) ([]types.WebACLSummary, *string, error) { output, err := d.client.ListWebACLs(ctx, &wafv2.ListWebACLsInput{ @@ -82,9 +93,11 @@ func (d *WebACLDAO) listByScope(ctx context.Context, scope types.Scope) ([]dao.R func (d *WebACLDAO) Get(ctx context.Context, id string) (dao.Resource, error) { // Parse the composite ID (scope/name/id) // For simplicity, we'll search through both scopes - for _, scope := range []types.Scope{types.ScopeRegional, types.ScopeCloudfront} { + var scopeErrs []error + for _, scope := range d.scopes(ctx) { resources, err := d.listByScope(ctx, scope) if err != nil { + scopeErrs = append(scopeErrs, fmt.Errorf("list %s web acls: %w", scope, err)) continue } @@ -98,9 +111,19 @@ func (d *WebACLDAO) Get(ctx context.Context, id string) (dao.Resource, error) { } } + if len(scopeErrs) > 0 { + return nil, errors.Join(append([]error{fmt.Errorf("web acl %s not found", id)}, scopeErrs...)...) + } return nil, fmt.Errorf("web acl %s not found", id) } +func (d *WebACLDAO) scopes(ctx context.Context) []types.Scope { + if currentRegion(ctx) == cloudFrontWebACLRegion { + return []types.Scope{types.ScopeRegional, types.ScopeCloudfront} + } + return []types.Scope{types.ScopeRegional} +} + func (d *WebACLDAO) getWebACLDetail(ctx context.Context, summary *WebACLResource) (*WebACLResource, error) { input := &wafv2.GetWebACLInput{ Name: summary.Summary.Name, diff --git a/custom/wafv2/web-acls/dao_test.go b/custom/wafv2/web-acls/dao_test.go new file mode 100644 index 00000000..62722449 --- /dev/null +++ b/custom/wafv2/web-acls/dao_test.go @@ -0,0 +1,90 @@ +package webacls + +import ( + "bytes" + "context" + "io" + "net/http" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/wafv2" + + appaws "github.com/clawscli/claws/internal/aws" +) + +func TestListSkipsCloudFrontScopeOutsideUSEast1(t *testing.T) { + client := &recordingHTTPClient{} + d := newTestWebACLDAO(client) + + _, err := d.List(appaws.WithRegionOverride(context.Background(), "us-west-2")) + if err != nil { + t.Fatalf("List() returned error: %v", err) + } + + if got := client.scopeCalls("REGIONAL"); got != 1 { + t.Fatalf("REGIONAL calls = %d, want 1; bodies=%v", got, client.bodies) + } + if got := client.scopeCalls("CLOUDFRONT"); got != 0 { + t.Fatalf("CLOUDFRONT calls = %d, want 0; bodies=%v", got, client.bodies) + } +} + +func TestListIncludesCloudFrontScopeInUSEast1(t *testing.T) { + client := &recordingHTTPClient{} + d := newTestWebACLDAO(client) + + _, err := d.List(appaws.WithRegionOverride(context.Background(), cloudFrontWebACLRegion)) + if err != nil { + t.Fatalf("List() returned error: %v", err) + } + + if got := client.scopeCalls("REGIONAL"); got != 1 { + t.Fatalf("REGIONAL calls = %d, want 1; bodies=%v", got, client.bodies) + } + if got := client.scopeCalls("CLOUDFRONT"); got != 1 { + t.Fatalf("CLOUDFRONT calls = %d, want 1; bodies=%v", got, client.bodies) + } +} + +func newTestWebACLDAO(httpClient aws.HTTPClient) *WebACLDAO { + return &WebACLDAO{ + client: wafv2.NewFromConfig(aws.Config{ + Region: cloudFrontWebACLRegion, + HTTPClient: httpClient, + Credentials: aws.CredentialsProviderFunc(func(context.Context) (aws.Credentials, error) { + return aws.Credentials{AccessKeyID: "test", SecretAccessKey: "test", Source: "test"}, nil + }), + }), + } +} + +type recordingHTTPClient struct { + bodies []string +} + +func (c *recordingHTTPClient) Do(req *http.Request) (*http.Response, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + c.bodies = append(c.bodies, string(body)) + + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/x-amz-json-1.1"}}, + Body: io.NopCloser(bytes.NewBufferString(`{"WebACLs":[]}`)), + }, nil +} + +func (c *recordingHTTPClient) scopeCalls(scope string) int { + needle := `"Scope":"` + scope + `"` + count := 0 + for _, body := range c.bodies { + if strings.Contains(body, needle) { + count++ + } + } + return count +} diff --git a/docs/ai-chat.ja.md b/docs/ai-chat.ja.md index bfb9da52..65a2f290 100644 --- a/docs/ai-chat.ja.md +++ b/docs/ai-chat.ja.md @@ -60,10 +60,12 @@ ai: - サービスやリージョンをまたいでAWSリソースの一覧取得やクエリを実行 - 特定のリソースの詳細情報を取得 - 対応リソース(Lambda、ECS、CodeBuildなど)のCloudWatchログを取得 -- AWSドキュメントを検索 +- 一般的で公開可能なAWS用語でAWSドキュメントを検索 AIは現在のプロファイル、リージョン、リソースコンテキストを自動的に使用します。 +AWSドキュメント検索では、検索クエリがAWSの公開ドキュメント検索エンドポイントへ送信されます。プライベートなAWSコンテキストの漏えいを避けるため、clawsはリソースID、アカウントID、ARN、プロファイル名、ログ行、タグ値、シークレットを含むドキュメント検索クエリを拒否します。 + ### コンテキスト認識 アシスタントは現在のビューに基づいてコンテキストを自動的に受け取ります: diff --git a/docs/ai-chat.ko.md b/docs/ai-chat.ko.md index 52ff6a53..2afb126c 100644 --- a/docs/ai-chat.ko.md +++ b/docs/ai-chat.ko.md @@ -60,10 +60,12 @@ ai: - 서비스 및 리전에 걸쳐 AWS 리소스 목록 조회 및 쿼리 실행 - 특정 리소스의 상세 정보 가져오기 - 지원 리소스(Lambda, ECS, CodeBuild 등)의 CloudWatch 로그 가져오기 -- AWS 문서 검색 +- 일반적이고 공개 가능한 AWS 용어로 AWS 문서 검색 AI는 현재 프로필, 리전, 리소스 컨텍스트를 자동으로 사용합니다. +AWS 문서 검색은 검색 쿼리를 AWS의 공개 문서 검색 엔드포인트로 전송합니다. 비공개 AWS 컨텍스트가 유출되지 않도록 claws는 리소스 ID, 계정 ID, ARN, 프로필 이름, 로그 행, 태그 값 또는 시크릿이 포함된 문서 검색 쿼리를 거부합니다. + ### 컨텍스트 인식 어시스턴트는 현재 뷰에 따라 컨텍스트를 자동으로 수신합니다: diff --git a/docs/ai-chat.md b/docs/ai-chat.md index 4fd3fe4e..f26c7fb5 100644 --- a/docs/ai-chat.md +++ b/docs/ai-chat.md @@ -60,10 +60,12 @@ Press `A` in list/detail/diff views to open the AI Chat overlay. - List and query AWS resources across services and regions - Get detailed information about specific resources - Fetch CloudWatch logs for supported resources (Lambda, ECS, CodeBuild, etc.) -- Search AWS documentation +- Search AWS documentation with general, public AWS terms The AI automatically uses the current profile, region, and resource context from your view. +AWS documentation search sends the search query to AWS's public documentation search endpoint. To avoid leaking private AWS context, claws rejects documentation queries that contain resource IDs, account IDs, ARNs, profile names, logs, tag values, or secrets. + ### Context Awareness The assistant automatically receives context based on your current view: diff --git a/docs/ai-chat.zh-CN.md b/docs/ai-chat.zh-CN.md index c77c2ec3..7f4acdc9 100644 --- a/docs/ai-chat.zh-CN.md +++ b/docs/ai-chat.zh-CN.md @@ -60,10 +60,12 @@ ai: - 跨服务和区域列出和查询 AWS 资源 - 获取特定资源的详细信息 - 获取支持的资源(Lambda、ECS、CodeBuild 等)的 CloudWatch 日志 -- 搜索 AWS 文档 +- 使用通用且可公开的 AWS 术语搜索 AWS 文档 AI 会自动使用当前视图中的配置文件、区域和资源上下文。 +AWS 文档搜索会将搜索查询发送到 AWS 的公共文档搜索端点。为避免泄露私有 AWS 上下文,claws 会拒绝包含资源 ID、账户 ID、ARN、配置文件名称、日志行、标签值或密钥的文档搜索查询。 + ### 上下文感知 助手会根据您当前的视图自动接收上下文信息: diff --git a/flake.nix b/flake.nix index 330375c4..4953ec9f 100644 --- a/flake.nix +++ b/flake.nix @@ -8,7 +8,34 @@ outputs = { self, nixpkgs, flake-utils }: flake-utils.lib.eachDefaultSystem (system: - let pkgs = import nixpkgs { inherit system; }; + let + pkgs = import nixpkgs { inherit system; }; + lib = pkgs.lib; + + indexionPlatform = { + "aarch64-darwin" = "darwin-arm64"; + "x86_64-linux" = "linux-x64"; + }.${system} or (throw "indexion: unsupported system ${system}"); + + indexionHash = { + "darwin-arm64" = "1gbjfzwy9rgn7n79hj354w1jh2cqc6fvsj1m2zscvg6va6b1hdhl"; + "linux-x64" = "1pqll5vkb50fygq7ibqdry0lby54r50p17f75fv2s95xqy515c3i"; + }.${indexionPlatform}; + + indexion = pkgs.stdenvNoCC.mkDerivation { + pname = "indexion"; + version = "0.11.0"; + src = pkgs.fetchzip { + url = "https://github.com/trkbt10/indexion/releases/download/v0.11.0/indexion-${indexionPlatform}.tar.gz"; + sha256 = indexionHash; + stripRoot = true; + }; + installPhase = '' + mkdir -p $out/bin $out/share/indexion + cp indexion $out/bin/ + cp -r kgfs $out/share/indexion/ + ''; + }; in { devShells.default = pkgs.mkShell { packages = with pkgs; [ @@ -17,6 +44,10 @@ gopls golangci-lint vhs + ttyd + indexion + nodejs + bash ]; env.GOROOT = "${pkgs.go_1_25}/share/go"; diff --git a/internal/action/action.go b/internal/action/action.go index 4a656b77..a7df2cfd 100644 --- a/internal/action/action.go +++ b/internal/action/action.go @@ -64,6 +64,7 @@ type Action struct { Shortcut string Type ActionType Command string + Args []string Operation string Confirm ConfirmLevel @@ -191,26 +192,17 @@ func ConfirmTokenName(r dao.Resource) string { return r.GetName() } -// MinConfirmChars is the minimum number of characters required for dangerous confirmation. -// For tokens longer than this, only the last MinConfirmChars characters need to be typed. -const MinConfirmChars = 6 - -// ConfirmSuffix returns the suffix of the token that the user must type. +// ConfirmSuffix returns the token that the user must type. // For empty tokens, returns "CONFIRM" as a fallback to prevent accidental confirmation. -// For tokens <= MinConfirmChars, returns the full token. -// For longer tokens, returns the last MinConfirmChars characters. func ConfirmSuffix(token string) string { if token == "" { return "CONFIRM" } - if len(token) <= MinConfirmChars { - return token - } - return token[len(token)-MinConfirmChars:] + return token } // ConfirmMatches checks if the user input matches the required confirmation. -// Returns true if input equals the suffix returned by ConfirmSuffix. +// Returns true if input equals the token returned by ConfirmSuffix. func ConfirmMatches(token, input string) bool { return input == ConfirmSuffix(token) } @@ -299,6 +291,32 @@ func ExecuteWithDAO(ctx context.Context, action Action, resource dao.Resource, s } func executeExec(ctx context.Context, action Action, resource dao.Resource) ActionResult { + if len(action.Args) > 0 { + args, err := ExpandArgs(action.Args, resource) + if err != nil { + return ActionResult{Success: false, Error: err} + } + if len(args) == 0 || args[0] == "" { + return ActionResult{Success: false, Error: ErrEmptyCommand} + } + args, err = ResolveArgsExecutable(args) + if err != nil { + return ActionResult{Success: false, Error: err} + } + + execCmd := exec.CommandContext(ctx, args[0], args[1:]...) + execCmd.Stdin = os.Stdin + execCmd.Stdout = os.Stdout + execCmd.Stderr = os.Stderr + if !action.SkipAWSEnv { + setAWSEnv(execCmd, aws.GetRegionFromContext(ctx)) + } + if err := execCmd.Run(); err != nil { + return ActionResult{Success: false, Error: err} + } + return ActionResult{Success: true, Message: "Command executed successfully"} + } + cmd, err := ExpandVariables(action.Command, resource) if err != nil { return ActionResult{Success: false, Error: err} @@ -325,6 +343,41 @@ func executeExec(ctx context.Context, action Action, resource dao.Resource) Acti return ActionResult{Success: true, Message: "Command executed successfully"} } +// ExpandArgs replaces variables in command arguments with resource values. +// Arguments are executed without a shell, so shell metacharacters are preserved as literals. +func ExpandArgs(args []string, resource dao.Resource) ([]string, error) { + expanded := make([]string, len(args)) + for i, arg := range args { + replacements := map[string]string{ + "${ID}": resource.GetID(), + "${NAME}": resource.GetName(), + "${ARN}": resource.GetARN(), + "${INSTANCE_ID}": resource.GetID(), + "${BUCKET}": resource.GetID(), + } + + if p, ok := resource.(PrivateIPProvider); ok { + replacements["${PRIVATE_IP}"] = p.PrivateIP() + } + if p, ok := resource.(ClusterArnProvider); ok { + replacements["${CLUSTER}"] = p.ClusterArn() + } + if p, ok := resource.(ContainerNameProvider); ok { + replacements["${CONTAINER}"] = p.FirstContainerName() + } + if p, ok := resource.(LogGroupNameProvider); ok { + replacements["${LOG_GROUP}"] = p.LogGroupName() + } + + expandedArg := arg + for k, v := range replacements { + expandedArg = strings.ReplaceAll(expandedArg, k, v) + } + expanded[i] = expandedArg + } + return expanded, nil +} + // Optional interfaces for variable expansion in action commands. // Resources can implement these to provide additional variables. type ( diff --git a/internal/action/action_test.go b/internal/action/action_test.go index 358715af..8536a0f9 100644 --- a/internal/action/action_test.go +++ b/internal/action/action_test.go @@ -1,6 +1,7 @@ package action import ( + "bytes" "context" "errors" "os/exec" @@ -19,6 +20,56 @@ type mockResource struct { tags map[string]string } +func TestSimpleExecArgsTreatShellMetacharactersAsLiteral(t *testing.T) { + var stdout bytes.Buffer + execCmd := &SimpleExec{ + Args: []string{"/bin/echo", "profile; echo injected"}, + SkipAWSEnv: true, + stdout: &stdout, + } + + if err := execCmd.Run(); err != nil { + t.Fatalf("Run() returned error: %v", err) + } + + if got, want := stdout.String(), "profile; echo injected\n"; got != want { + t.Fatalf("stdout = %q, want %q", got, want) + } +} + +func TestResolveArgsExecutableReturnsCopy(t *testing.T) { + original := []string{"/bin/echo", "hello"} + resolved, err := ResolveArgsExecutable(original) + if err != nil { + t.Fatalf("ResolveArgsExecutable() returned error: %v", err) + } + + if resolved[0] != "/bin/echo" { + t.Fatalf("resolved executable = %q, want /bin/echo", resolved[0]) + } + resolved[0] = "/tmp/changed" + if original[0] != "/bin/echo" { + t.Fatalf("original args mutated: %q", original[0]) + } +} + +func TestExpandArgsTreatsMetacharactersAsLiteralValues(t *testing.T) { + resource := &mockResource{id: "i-123; echo injected", name: "test"} + got, err := ExpandArgs([]string{"/bin/echo", "${ID}"}, resource) + if err != nil { + t.Fatalf("ExpandArgs() returned error: %v", err) + } + want := []string{"/bin/echo", "i-123; echo injected"} + if len(got) != len(want) { + t.Fatalf("ExpandArgs() len = %d, want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("ExpandArgs()[%d] = %q, want %q", i, got[i], want[i]) + } + } +} + func (m *mockResource) GetID() string { return m.id } func (m *mockResource) GetName() string { return m.name } func (m *mockResource) GetARN() string { return m.arn } @@ -500,6 +551,19 @@ func TestExecuteWithDAO_ExecType(t *testing.T) { } }) + t.Run("args with variable expansion", func(t *testing.T) { + action := Action{ + Type: ActionTypeExec, + Args: []string{"/bin/echo", "${ID}"}, + } + + result := ExecuteWithDAO(context.Background(), action, &mockResource{id: "test-id"}, "test", "resource") + + if !result.Success { + t.Errorf("ExecuteWithDAO should succeed, got error: %v", result.Error) + } + }) + t.Run("failing command", func(t *testing.T) { action := Action{ Type: ActionTypeExec, @@ -1035,9 +1099,9 @@ func TestConfirmSuffix(t *testing.T) { }{ {"abc", "abc"}, {"abcdef", "abcdef"}, - {"abcdefg", "bcdefg"}, - {"i-1234567890abcdef0", "bcdef0"}, - {"arn:aws:iam::123456789012:policy/MyPolicy", "Policy"}, + {"abcdefg", "abcdefg"}, + {"i-1234567890abcdef0", "i-1234567890abcdef0"}, + {"arn:aws:iam::123456789012:policy/MyPolicy", "arn:aws:iam::123456789012:policy/MyPolicy"}, {"", "CONFIRM"}, } @@ -1060,14 +1124,14 @@ func TestConfirmMatches(t *testing.T) { }{ {"exact match short", "abc", "abc", true}, {"exact match 6 chars", "abcdef", "abcdef", true}, - {"suffix match long token", "i-1234567890abcdef0", "bcdef0", true}, - {"suffix match ARN", "arn:aws:iam::123456789012:policy/MyPolicy", "Policy", true}, + {"suffix rejected for long token", "i-1234567890abcdef0", "bcdef0", false}, + {"suffix rejected for ARN", "arn:aws:iam::123456789012:policy/MyPolicy", "Policy", false}, {"wrong suffix", "i-1234567890abcdef0", "wrong", false}, {"partial suffix", "i-1234567890abcdef0", "def0", false}, {"empty input", "abcdef", "", false}, {"empty token requires CONFIRM", "", "CONFIRM", true}, {"empty token rejects empty input", "", "", false}, - {"full token when suffix expected", "i-1234567890abcdef0", "i-1234567890abcdef0", false}, + {"full token required", "i-1234567890abcdef0", "i-1234567890abcdef0", true}, } for _, tt := range tests { diff --git a/internal/action/exec_with_header.go b/internal/action/exec_with_header.go index d84593e9..e3f74781 100644 --- a/internal/action/exec_with_header.go +++ b/internal/action/exec_with_header.go @@ -7,6 +7,7 @@ import ( "io" "os" "os/exec" + "path/filepath" "strings" "golang.org/x/term" @@ -26,7 +27,9 @@ func setAWSEnv(cmd *exec.Cmd, region string) { // SimpleExec represents a simple exec command without header. // Implements tea.ExecCommand interface. type SimpleExec struct { + Context context.Context Command string + Args []string ActionName string // Action name for read-only allowlist check SkipAWSEnv bool // If true, don't inject AWS env vars (for commands that need to write to ~/.aws) @@ -50,7 +53,7 @@ func (e *SimpleExec) Run() error { return ErrReadOnlyDenied } - if e.Command == "" { + if e.Command == "" && len(e.Args) == 0 { return ErrEmptyCommand } @@ -67,7 +70,14 @@ func (e *SimpleExec) Run() error { stderr = os.Stderr } - cmd := exec.CommandContext(context.Background(), "/bin/sh", "-c", e.Command) + cmdCtx := e.Context + if cmdCtx == nil { + cmdCtx = context.Background() + } + cmd, err := e.command(cmdCtx) + if err != nil { + return err + } cmd.Stdin = stdin cmd.Stdout = stdout cmd.Stderr = stderr @@ -78,10 +88,59 @@ func (e *SimpleExec) Run() error { return cmd.Run() } +func (e *SimpleExec) command(ctx context.Context) (*exec.Cmd, error) { + if len(e.Args) > 0 { + if e.Args[0] == "" { + return nil, ErrEmptyCommand + } + args, err := ResolveArgsExecutable(e.Args) + if err != nil { + return nil, err + } + return exec.CommandContext(ctx, args[0], args[1:]...), nil + } + if e.Command == "" { + return nil, ErrEmptyCommand + } + return exec.CommandContext(ctx, "/bin/sh", "-c", e.Command), nil +} + +// ResolveExecutable resolves name to the executable path that will be invoked. +// Absolute or relative paths containing a path separator are returned unchanged. +func ResolveExecutable(name string) (string, error) { + if name == "" { + return "", ErrEmptyCommand + } + if filepath.IsAbs(name) || strings.ContainsRune(name, os.PathSeparator) { + return name, nil + } + path, err := exec.LookPath(name) + if err != nil { + return "", fmt.Errorf("resolve executable %q: %w", name, err) + } + return path, nil +} + +// ResolveArgsExecutable returns a copy of args with args[0] resolved to the executable path. +func ResolveArgsExecutable(args []string) ([]string, error) { + if len(args) == 0 || args[0] == "" { + return nil, ErrEmptyCommand + } + resolved, err := ResolveExecutable(args[0]) + if err != nil { + return nil, err + } + out := append([]string(nil), args...) + out[0] = resolved + return out, nil +} + // ExecWithHeader represents an exec command that should run with a fixed header // Implements tea.ExecCommand interface type ExecWithHeader struct { + Context context.Context Command string + Args []string ActionName string Resource dao.Resource Service string @@ -160,12 +219,15 @@ func (e *ExecWithHeader) Run() error { // Move cursor to scroll region _, _ = fmt.Fprintf(stdout, "\x1b[%d;1H", scrollTop) - // Prepare command - run through shell to support quoting and pipes - if e.Command == "" { - return ErrEmptyCommand + cmdCtx := e.Context + if cmdCtx == nil { + cmdCtx = context.Background() } - cmd := exec.CommandContext(context.Background(), "/bin/sh", "-c", e.Command) + cmd, err := e.command(cmdCtx) + if err != nil { + return err + } cmd.Stdin = stdin cmd.Stdout = stdout cmd.Stderr = stderr @@ -174,7 +236,7 @@ func (e *ExecWithHeader) Run() error { } // Run the command - err := cmd.Run() + err = cmd.Run() // Reset scroll region _, _ = fmt.Fprint(stdout, "\x1b[r") @@ -200,6 +262,20 @@ func (e *ExecWithHeader) Run() error { return err } +func (e *ExecWithHeader) command(ctx context.Context) (*exec.Cmd, error) { + if len(e.Args) > 0 { + args, err := ResolveArgsExecutable(e.Args) + if err != nil { + return nil, err + } + return exec.CommandContext(ctx, args[0], args[1:]...), nil + } + if e.Command == "" { + return nil, ErrEmptyCommand + } + return exec.CommandContext(ctx, "/bin/sh", "-c", e.Command), nil +} + func (e *ExecWithHeader) buildHeader(_ int) string { profileDisplay := config.Global().Selection().DisplayName() region := e.Region diff --git a/internal/ai/tools.go b/internal/ai/tools.go index a6f971ae..e66a249c 100644 --- a/internal/ai/tools.go +++ b/internal/ai/tools.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "net/http" + "regexp" "strings" "time" "unicode" @@ -18,6 +19,7 @@ import ( "github.com/clawscli/claws/internal/dao" "github.com/clawscli/claws/internal/log" "github.com/clawscli/claws/internal/registry" + "github.com/clawscli/claws/internal/sanitize" apigatewayStages "github.com/clawscli/claws/custom/apigateway/stages" apigatewayStagesV2 "github.com/clawscli/claws/custom/apigateway/stages-v2" @@ -31,15 +33,161 @@ import ( ) type ToolExecutor struct { - registry *registry.Registry + registry *registry.Registry + aiCtx *Context + docsSearcher func(context.Context, string) string } -func NewToolExecutor(_ context.Context, reg *registry.Registry) (*ToolExecutor, error) { +var ( + docsSearchARNPattern = regexp.MustCompile(`\barn:[^\s]+`) + docsSearchAccountIDPattern = regexp.MustCompile(`\b\d{12}\b`) +) + +func NewToolExecutor(_ context.Context, reg *registry.Registry, contexts ...*Context) (*ToolExecutor, error) { + var aiCtx *Context + if len(contexts) > 0 { + aiCtx = contexts[0] + } return &ToolExecutor{ registry: reg, + aiCtx: aiCtx, }, nil } +func (e *ToolExecutor) validateScope(service, resourceType, region, profile, id, cluster string) (string, string, error) { + ctx := e.aiCtx + if ctx == nil { + return profile, cluster, nil + } + if ctx.Service != "" && service != ctx.Service { + return "", "", fmt.Errorf("service %s is outside the current AI context", service) + } + if ctx.ResourceType != "" && resourceType != ctx.ResourceType { + return "", "", fmt.Errorf("resource type %s is outside the current AI context", resourceType) + } + if region != "" && !regionAllowed(ctx, region) { + return "", "", fmt.Errorf("region %s is outside the current AI context", region) + } + profile = defaultProfile(ctx, profile) + if profile != "" && !profileAllowed(ctx, profile) { + return "", "", fmt.Errorf("profile %s is outside the current AI context", profile) + } + if id != "" && !resourceAllowed(ctx, id) { + return "", "", fmt.Errorf("resource %s is outside the current AI context", id) + } + cluster = defaultCluster(ctx, cluster) + if cluster != "" && !clusterAllowed(ctx, cluster) { + return "", "", fmt.Errorf("cluster %s is outside the current AI context", cluster) + } + return profile, cluster, nil +} + +func defaultProfile(ctx *Context, profile string) string { + if profile != "" { + return profile + } + if ctx.ResourceProfile != "" { + return ctx.ResourceProfile + } + if refsHaveSameProfile(ctx.DiffLeft, ctx.DiffRight) { + return ctx.DiffLeft.Profile + } + return "" +} + +func defaultCluster(ctx *Context, cluster string) string { + if cluster != "" { + return cluster + } + if ctx.Cluster != "" { + return ctx.Cluster + } + if refsHaveSameCluster(ctx.DiffLeft, ctx.DiffRight) { + return ctx.DiffLeft.Cluster + } + return "" +} + +func regionAllowed(ctx *Context, region string) bool { + if ctx.ResourceRegion != "" { + return region == ctx.ResourceRegion + } + if ctx.DiffLeft != nil || ctx.DiffRight != nil { + return resourceRefHasRegion(ctx.DiffLeft, region) || resourceRefHasRegion(ctx.DiffRight, region) + } + if len(ctx.UserRegions) == 0 { + return true + } + for _, allowed := range ctx.UserRegions { + if region == allowed { + return true + } + } + return false +} + +func profileAllowed(ctx *Context, profile string) bool { + if ctx.ResourceProfile != "" { + return profile == ctx.ResourceProfile + } + if ctx.DiffLeft != nil || ctx.DiffRight != nil { + return resourceRefHasProfile(ctx.DiffLeft, profile) || resourceRefHasProfile(ctx.DiffRight, profile) + } + if len(ctx.UserProfiles) == 0 { + return true + } + for _, allowed := range ctx.UserProfiles { + if profile == allowed { + return true + } + } + return false +} + +func clusterAllowed(ctx *Context, cluster string) bool { + if ctx.Cluster != "" { + return cluster == ctx.Cluster + } + if ctx.DiffLeft != nil || ctx.DiffRight != nil { + return resourceRefHasCluster(ctx.DiffLeft, cluster) || resourceRefHasCluster(ctx.DiffRight, cluster) + } + return true +} + +func resourceAllowed(ctx *Context, id string) bool { + if ctx.ResourceID != "" { + return id == ctx.ResourceID + } + if ctx.DiffLeft != nil || ctx.DiffRight != nil { + return resourceRefHasID(ctx.DiffLeft, id) || resourceRefHasID(ctx.DiffRight, id) + } + return true +} + +func resourceRefHasID(ref *ResourceRef, id string) bool { + return ref != nil && ref.ID == id +} + +func resourceRefHasRegion(ref *ResourceRef, region string) bool { + return ref != nil && ref.Region == region +} + +func resourceRefHasProfile(ref *ResourceRef, profile string) bool { + return ref != nil && ref.Profile == profile +} + +func resourceRefHasCluster(ref *ResourceRef, cluster string) bool { + return ref != nil && ref.Cluster == cluster +} + +func refsHaveSameProfile(left, right *ResourceRef) bool { + return left != nil && right != nil && left.Profile != "" && left.Profile == right.Profile +} + +func refsHaveSameCluster(left, right *ResourceRef) bool { + return left != nil && right != nil && left.Cluster != "" && left.Cluster == right.Cluster +} + func (e *ToolExecutor) Tools() []Tool { return []Tool{ { @@ -176,7 +324,7 @@ func (e *ToolExecutor) Tools() []Tool { }, { Name: "search_aws_docs", - Description: "Search AWS documentation for information", + Description: "Search AWS documentation for information. Queries containing private or sensitive context are rejected before external search.", InputSchema: map[string]any{ "type": "object", "properties": map[string]any{ @@ -237,12 +385,23 @@ func (e *ToolExecutor) Execute(ctx context.Context, call *ToolUseContent) ToolRe content, isError = e.tailLogs(ctx, service, resourceType, region, id, cluster, profile, filter, since, int(limit)) case "search_aws_docs": query, _ := call.Input["query"].(string) - content = e.searchDocs(ctx, query) + var err error + query, err = e.prepareDocsSearchQuery(query) + if err != nil { + content = "Error: " + err.Error() + isError = true + } else { + content = e.runDocsSearch(ctx, query) + } default: content = fmt.Sprintf("Unknown tool: %s", call.Name) isError = true } + if isPrivateDataTool(call.Name) && isError { + content = e.redactPrivateToolOutput(content) + } + return ToolResultContent{ ToolUseID: call.ID, Content: content, @@ -250,6 +409,73 @@ func (e *ToolExecutor) Execute(ctx context.Context, call *ToolUseContent) ToolRe } } +func isPrivateDataTool(toolName string) bool { + switch toolName { + case "query_resources", "get_resource_detail", "tail_logs": + return true + default: + return false + } +} + +func (e *ToolExecutor) runDocsSearch(ctx context.Context, query string) string { + if e.docsSearcher != nil { + return e.docsSearcher(ctx, query) + } + return e.searchDocs(ctx, query) +} + +func (e *ToolExecutor) prepareDocsSearchQuery(query string) (string, error) { + query = strings.TrimSpace(query) + if query == "" { + return "", fmt.Errorf("query parameter is required") + } + if sanitized := e.redactPrivateDocsSearchQuery(query); sanitized != query { + return "", fmt.Errorf("AWS documentation search query contains private or sensitive context; ask a general AWS documentation question without resource IDs, account IDs, ARNs, profile names, logs, tags, or secrets") + } + return query, nil +} + +func (e *ToolExecutor) redactPrivateDocsSearchQuery(query string) string { + redacted := sanitize.SensitiveText(query) + redacted = docsSearchARNPattern.ReplaceAllString(redacted, sanitize.Redacted) + redacted = docsSearchAccountIDPattern.ReplaceAllString(redacted, sanitize.Redacted) + return redactDocsSearchContextValues(redacted, e.aiCtx) +} + +func (e *ToolExecutor) redactPrivateToolOutput(output string) string { + return e.redactPrivateDocsSearchQuery(output) +} + +func redactDocsSearchContextValues(query string, ctx *Context) string { + if ctx == nil { + return query + } + values := []string{ + ctx.ResourceID, + ctx.ResourceProfile, + ctx.Cluster, + ctx.FilterText, + } + values = append(values, ctx.UserProfiles...) + values = append(values, resourceRefPrivateValues(ctx.DiffLeft)...) + values = append(values, resourceRefPrivateValues(ctx.DiffRight)...) + for _, value := range values { + if value == "" { + continue + } + query = strings.ReplaceAll(query, value, sanitize.Redacted) + } + return query +} + +func resourceRefPrivateValues(ref *ResourceRef) []string { + if ref == nil { + return nil + } + return []string{ref.ID, ref.Name, ref.Profile, ref.Cluster} +} + func (e *ToolExecutor) listResources(service string) string { resources := e.registry.ListResources(service) if len(resources) == 0 { @@ -274,6 +500,11 @@ func (e *ToolExecutor) queryResources(ctx context.Context, service, resourceType if region == "" { return "Error: region parameter is required", true } + var err error + profile, _, err = e.validateScope(service, resourceType, region, profile, "", "") + if err != nil { + return "Error: " + err.Error(), true + } // Validate and apply limit if limit <= 0 { @@ -349,6 +580,11 @@ func (e *ToolExecutor) getResourceDetail(ctx context.Context, service, resourceT if region == "" { return "Error: region parameter is required", true } + var err error + profile, cluster, err = e.validateScope(service, resourceType, region, profile, id, cluster) + if err != nil { + return "Error: " + err.Error(), true + } if profile != "" { ctx = appaws.WithSelectionOverride(ctx, appconfig.ProfileSelectionFromID(profile)) @@ -383,6 +619,11 @@ func (e *ToolExecutor) tailLogs(ctx context.Context, service, resourceType, regi if region == "" { return "Error: region parameter is required", true } + var err error + profile, cluster, err = e.validateScope(service, resourceType, region, profile, id, cluster) + if err != nil { + return "Error: " + err.Error(), true + } if limit <= 0 { limit = 100 } @@ -440,7 +681,7 @@ func (e *ToolExecutor) tailLogs(ctx context.Context, service, resourceType, regi result := fmt.Sprintf("Logs from %s (%d events):\n\n", logGroup, len(output.Events)) for _, event := range output.Events { ts := time.UnixMilli(aws.ToInt64(event.Timestamp)) - result += fmt.Sprintf("[%s] %s\n", ts.Format("15:04:05"), aws.ToString(event.Message)) + result += fmt.Sprintf("[%s] %s\n", ts.Format("15:04:05"), sanitize.LogText(aws.ToString(event.Message))) } return result, false @@ -755,6 +996,11 @@ func formatResourceDetail(r dao.Resource) string { if tags := r.GetTags(); len(tags) > 0 { result += "\nTags:\n" for k, v := range tags { + if isSensitiveRawKey(k) { + v = sanitize.Redacted + } else { + v = sanitize.SensitiveText(v) + } result += fmt.Sprintf(" %s: %s\n", k, v) } } @@ -819,6 +1065,8 @@ func redactSensitiveValue(v any) any { redacted[i] = redactSensitiveValue(nested) } return redacted + case string: + return sanitize.SensitiveText(value) default: return value } diff --git a/internal/ai/tools_test.go b/internal/ai/tools_test.go index 6dde6703..df3696ec 100644 --- a/internal/ai/tools_test.go +++ b/internal/ai/tools_test.go @@ -2,8 +2,12 @@ package ai import ( "context" + "errors" "strings" "testing" + + "github.com/clawscli/claws/internal/dao" + "github.com/clawscli/claws/internal/registry" ) func TestToolExecutorTools(t *testing.T) { @@ -211,6 +215,154 @@ func TestToolExecuteTailLogsMissingRegion(t *testing.T) { } } +func TestToolExecuteRejectsOutOfScopeProfile(t *testing.T) { + executor := &ToolExecutor{ + registry: nil, + aiCtx: &Context{ + Mode: ContextModeList, + Service: "ec2", + ResourceType: "instances", + UserRegions: []string{"us-east-1"}, + UserProfiles: []string{"dev"}, + }, + } + + result := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "test-123", + Name: "query_resources", + Input: map[string]any{ + "service": "ec2", + "resource_type": "instances", + "region": "us-east-1", + "profile": "prod", + }, + }) + + if !result.IsError { + t.Fatal("expected out-of-scope profile to be rejected") + } + if !strings.Contains(result.Content, "profile prod is outside the current AI context") { + t.Fatalf("unexpected error content: %q", result.Content) + } +} + +func TestToolExecuteRejectsOutOfScopeResource(t *testing.T) { + executor := &ToolExecutor{ + registry: nil, + aiCtx: &Context{ + Mode: ContextModeSingle, + Service: "ec2", + ResourceType: "instances", + ResourceID: "i-allowed", + ResourceRegion: "us-east-1", + ResourceProfile: "dev", + }, + } + + result := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "test-123", + Name: "get_resource_detail", + Input: map[string]any{ + "service": "ec2", + "resource_type": "instances", + "region": "us-east-1", + "profile": "dev", + "id": "i-other", + }, + }) + + if !result.IsError { + t.Fatal("expected out-of-scope resource to be rejected") + } + if !strings.Contains(result.Content, "resource i-other is outside the current AI context") { + t.Fatalf("unexpected error content: %q", result.Content) + } +} + +func TestToolExecutorDefaultsSingleResourceProfileScope(t *testing.T) { + executor := &ToolExecutor{ + aiCtx: &Context{ + Mode: ContextModeSingle, + Service: "ec2", + ResourceType: "instances", + ResourceID: "i-allowed", + ResourceRegion: "us-east-1", + ResourceProfile: "dev", + }, + } + + profile, cluster, err := executor.validateScope("ec2", "instances", "us-east-1", "", "i-allowed", "") + if err != nil { + t.Fatalf("validateScope() returned error: %v", err) + } + if profile != "dev" { + t.Fatalf("profile = %q, want context resource profile", profile) + } + if cluster != "" { + t.Fatalf("cluster = %q, want empty", cluster) + } +} + +func TestToolExecuteRejectsOutOfScopeCluster(t *testing.T) { + executor := &ToolExecutor{ + registry: nil, + aiCtx: &Context{ + Mode: ContextModeSingle, + Service: "ecs", + ResourceType: "services", + ResourceID: "svc-allowed", + ResourceRegion: "us-east-1", + ResourceProfile: "dev", + Cluster: "cluster-a", + }, + } + + result := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "test-123", + Name: "get_resource_detail", + Input: map[string]any{ + "service": "ecs", + "resource_type": "services", + "region": "us-east-1", + "profile": "dev", + "id": "svc-allowed", + "cluster": "cluster-b", + }, + }) + + if !result.IsError { + t.Fatal("expected out-of-scope cluster to be rejected") + } + if !strings.Contains(result.Content, "cluster cluster-b is outside the current AI context") { + t.Fatalf("unexpected error content: %q", result.Content) + } +} + +func TestToolExecutorDefaultsSingleResourceClusterScope(t *testing.T) { + executor := &ToolExecutor{ + aiCtx: &Context{ + Mode: ContextModeSingle, + Service: "ecs", + ResourceType: "services", + ResourceID: "svc-allowed", + ResourceRegion: "us-east-1", + ResourceProfile: "dev", + Cluster: "cluster-a", + }, + } + + profile, cluster, err := executor.validateScope("ecs", "services", "us-east-1", "", "svc-allowed", "") + if err != nil { + t.Fatalf("validateScope() returned error: %v", err) + } + if profile != "dev" { + t.Fatalf("profile = %q, want context resource profile", profile) + } + if cluster != "cluster-a" { + t.Fatalf("cluster = %q, want context cluster", cluster) + } +} + func TestToolExecuteSearchDocsEmptyQuery(t *testing.T) { executor := &ToolExecutor{registry: nil} @@ -225,6 +377,299 @@ func TestToolExecuteSearchDocsEmptyQuery(t *testing.T) { } } +func TestPrepareDocsSearchQueryAllowsGeneralQueryBeforeAWSData(t *testing.T) { + executor := &ToolExecutor{ + aiCtx: &Context{ + Mode: ContextModeSingle, + Service: "ec2", + ResourceType: "instances", + ResourceID: "i-private123", + ResourceProfile: "production-profile", + }, + } + + query, err := executor.prepareDocsSearchQuery("EC2 instance metadata options") + if err != nil { + t.Fatalf("prepareDocsSearchQuery returned error: %v", err) + } + if query != "EC2 instance metadata options" { + t.Fatalf("query = %q, want unchanged general query", query) + } +} + +func TestToolExecuteRejectsSensitiveDocsSearchQuery(t *testing.T) { + executor := &ToolExecutor{ + aiCtx: &Context{ + Mode: ContextModeSingle, + Service: "ec2", + ResourceType: "instances", + ResourceID: "i-private123", + ResourceProfile: "production-profile", + }, + } + + result := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "test-123", + Name: "search_aws_docs", + Input: map[string]any{ + "query": "why is i-private123 in production-profile failing with token=plain-secret", + }, + }) + + if !result.IsError { + t.Fatal("expected sensitive documentation query to be rejected") + } + if !strings.Contains(result.Content, "private or sensitive context") { + t.Fatalf("unexpected rejection message: %q", result.Content) + } + for _, leaked := range []string{"i-private123", "production-profile", "plain-secret"} { + if strings.Contains(result.Content, leaked) { + t.Fatalf("rejection message leaked %q: %q", leaked, result.Content) + } + } +} + +func TestPrepareDocsSearchQueryRejectsAWSIdentifiers(t *testing.T) { + executor := &ToolExecutor{} + + for _, query := range []string{ + "explain arn:aws:lambda:us-east-1:123456789012:function:prod-handler timeout", + "why does account 123456789012 see access denied", + } { + t.Run(query, func(t *testing.T) { + if _, err := executor.prepareDocsSearchQuery(query); err == nil { + t.Fatal("expected AWS identifier query to be rejected") + } + }) + } +} + +func TestToolExecuteAllowsDocsSearchAfterAWSDataTool(t *testing.T) { + reg := registry.New() + reg.RegisterCustom("ec2", "instances", registry.Entry{ + DAOFactory: func(ctx context.Context) (dao.DAO, error) { + return &mockDAO{ + BaseDAO: dao.NewBaseDAO("ec2", "instances"), + resources: []dao.Resource{ + &mockResource{id: "i-123", name: "app-server"}, + }, + }, nil + }, + }) + executor := &ToolExecutor{ + registry: reg, + docsSearcher: func(ctx context.Context, query string) string { + return "docs: " + query + }, + } + + queryResult := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "query-123", + Name: "query_resources", + Input: map[string]any{ + "service": "ec2", + "resource_type": "instances", + "region": "us-east-1", + }, + }) + if queryResult.IsError { + t.Fatalf("expected query_resources to succeed, got %q", queryResult.Content) + } + + result := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "test-123", + Name: "search_aws_docs", + Input: map[string]any{ + "query": "how to rotate access keys", + }, + }) + + if result.IsError { + t.Fatalf("expected documentation search to remain allowed after AWS data tools, got %q", result.Content) + } + if result.Content != "docs: how to rotate access keys" { + t.Fatalf("unexpected documentation search result: %q", result.Content) + } +} + +func TestToolExecuteAllowsDocsSearchAfterResourceDetailTool(t *testing.T) { + reg := registry.New() + reg.RegisterCustom("ec2", "instances", registry.Entry{ + DAOFactory: func(ctx context.Context) (dao.DAO, error) { + return &mockDAO{ + BaseDAO: dao.NewBaseDAO("ec2", "instances"), + resources: []dao.Resource{ + &mockResource{id: "i-123", name: "app-server"}, + }, + }, nil + }, + }) + executor := &ToolExecutor{ + registry: reg, + docsSearcher: func(ctx context.Context, query string) string { + return "docs: " + query + }, + } + + detailResult := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "detail-123", + Name: "get_resource_detail", + Input: map[string]any{ + "service": "ec2", + "resource_type": "instances", + "region": "us-east-1", + "id": "i-123", + }, + }) + if detailResult.IsError { + t.Fatalf("expected get_resource_detail to succeed, got %q", detailResult.Content) + } + + result := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "test-123", + Name: "search_aws_docs", + Input: map[string]any{ + "query": "EC2 instance metadata options", + }, + }) + if result.IsError { + t.Fatalf("expected documentation search to remain allowed after get_resource_detail, got %q", result.Content) + } + if result.Content != "docs: EC2 instance metadata options" { + t.Fatalf("unexpected documentation search result: %q", result.Content) + } +} + +func TestToolExecuteAllowsDocsSearchAfterFailedAWSDataTool(t *testing.T) { + reg := registry.New() + reg.RegisterCustom("ec2", "instances", registry.Entry{ + DAOFactory: func(ctx context.Context) (dao.DAO, error) { + return &mockDAO{ + BaseDAO: dao.NewBaseDAO("ec2", "instances"), + listErr: errors.New("operation failed for arn:aws:ec2:us-east-1:123456789012:instance/i-private token=plain-secret"), + }, nil + }, + }) + executor := &ToolExecutor{ + registry: reg, + docsSearcher: func(ctx context.Context, query string) string { + return "docs: " + query + }, + } + + queryResult := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "query-123", + Name: "query_resources", + Input: map[string]any{ + "service": "ec2", + "resource_type": "instances", + "region": "us-east-1", + }, + }) + if !queryResult.IsError { + t.Fatal("expected query_resources to fail") + } + for _, leaked := range []string{"arn:aws:ec2", "123456789012", "plain-secret"} { + if strings.Contains(queryResult.Content, leaked) { + t.Fatalf("expected failed data tool output to redact %q, got %q", leaked, queryResult.Content) + } + } + + result := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "test-123", + Name: "search_aws_docs", + Input: map[string]any{ + "query": "EC2 instance metadata options", + }, + }) + if result.IsError { + t.Fatalf("expected documentation search to remain allowed after failed AWS data tool, got %q", result.Content) + } + if result.Content != "docs: EC2 instance metadata options" { + t.Fatalf("unexpected documentation search result: %q", result.Content) + } +} + +func TestToolExecuteRedactsFailedResourceDetailOutput(t *testing.T) { + reg := registry.New() + reg.RegisterCustom("ec2", "instances", registry.Entry{ + DAOFactory: func(ctx context.Context) (dao.DAO, error) { + return &mockDAO{ + BaseDAO: dao.NewBaseDAO("ec2", "instances"), + getErr: errors.New("lookup failed for arn:aws:ec2:us-east-1:123456789012:instance/i-private password=plain-secret"), + }, nil + }, + }) + executor := &ToolExecutor{registry: reg} + + result := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "detail-123", + Name: "get_resource_detail", + Input: map[string]any{ + "service": "ec2", + "resource_type": "instances", + "region": "us-east-1", + "id": "i-private", + }, + }) + if !result.IsError { + t.Fatal("expected get_resource_detail to fail") + } + for _, leaked := range []string{"arn:aws:ec2", "123456789012", "plain-secret"} { + if strings.Contains(result.Content, leaked) { + t.Fatalf("expected failed detail output to redact %q, got %q", leaked, result.Content) + } + } +} + +func TestToolExecuteAllowsDocsSearchAfterMissingDataToolParam(t *testing.T) { + executor := &ToolExecutor{ + docsSearcher: func(ctx context.Context, query string) string { + return "docs: " + query + }, + } + + missingParamResult := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "query-123", + Name: "query_resources", + Input: map[string]any{"service": "ec2", "resource_type": "instances"}, + }) + if !missingParamResult.IsError { + t.Fatal("expected missing parameter to fail") + } + + result := executor.Execute(context.TODO(), &ToolUseContent{ + ID: "test-123", + Name: "search_aws_docs", + Input: map[string]any{ + "query": "EC2 instance metadata options", + }, + }) + if result.IsError { + t.Fatalf("expected documentation search to remain allowed after failed AWS data tool attempt, got %q", result.Content) + } + if result.Content != "docs: EC2 instance metadata options" { + t.Fatalf("unexpected documentation search result: %q", result.Content) + } +} + +func TestIsPrivateDataToolIncludesDataTools(t *testing.T) { + for _, toolName := range []string{"query_resources", "get_resource_detail", "tail_logs"} { + t.Run(toolName, func(t *testing.T) { + if !isPrivateDataTool(toolName) { + t.Fatalf("expected %s to be treated as private data tool", toolName) + } + }) + } + for _, toolName := range []string{"list_resources", "search_aws_docs", "unknown"} { + t.Run(toolName, func(t *testing.T) { + if isPrivateDataTool(toolName) { + t.Fatalf("expected %s not to be treated as private data tool", toolName) + } + }) + } +} + func TestExtractLogGroupNameFromArn(t *testing.T) { tests := []struct { arn string @@ -344,6 +789,49 @@ func TestFormatResourceDetailRedactsSensitiveRawData(t *testing.T) { } } +func TestFormatResourceDetailRedactsSensitiveTags(t *testing.T) { + resource := &mockResource{ + id: "resource-1", + name: "resource", + tags: map[string]string{ + "Environment": "prod", + "ApiToken": "plain-secret-token", + }, + } + + result := formatResourceDetail(resource) + + if strings.Contains(result, "plain-secret-token") { + t.Fatalf("expected sensitive tag value to be redacted, got %q", result) + } + if !strings.Contains(result, "ApiToken") || !strings.Contains(result, "[REDACTED]") { + t.Fatalf("expected sensitive tag key with redaction marker, got %q", result) + } + if !strings.Contains(result, "Environment: prod") { + t.Fatalf("expected non-sensitive tag to remain, got %q", result) + } +} + +func TestFormatResourceDetailRedactsSensitiveTagValuePatterns(t *testing.T) { + resource := &mockResource{ + id: "resource-1", + name: "resource", + tags: map[string]string{ + "Environment": "prod", + "Endpoint": "postgres://app:super-secret-password@db.example.com:5432/app", + }, + } + + result := formatResourceDetail(resource) + + if strings.Contains(result, "super-secret-password") { + t.Fatalf("expected sensitive tag value pattern to be redacted, got %q", result) + } + if !strings.Contains(result, "Environment: prod") { + t.Fatalf("expected non-sensitive tag to remain, got %q", result) + } +} + func TestFormatResourceDetailRedactsSensitiveLabelValueRecords(t *testing.T) { resource := &mockResource{ id: "stack-1", @@ -444,6 +932,33 @@ func TestFormatResourceDetailPreservesMultipleSensitiveKeyNames(t *testing.T) { } } +func TestFormatResourceDetailRedactsSensitiveValuePatterns(t *testing.T) { + resource := &mockResource{ + id: "resource-1", + name: "resource", + raw: map[string]any{ + "DatabaseURL": "postgres://app:super-secret-password@db.example.com:5432/app", + "Header": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature", + "Certificate": "-----BEGIN PRIVATE KEY-----\nplain-private-key\n-----END PRIVATE KEY-----", + "PublicURL": "https://example.com/health", + }, + } + + result := formatResourceDetail(resource) + + for _, secret := range []string{"super-secret-password", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.payload.signature", "plain-private-key"} { + if strings.Contains(result, secret) { + t.Fatalf("expected value-only secret %q to be redacted, got %q", secret, result) + } + } + if !strings.Contains(result, "https://example.com/health") { + t.Fatalf("expected non-sensitive URL to remain, got %q", result) + } + if !strings.Contains(result, "[REDACTED]") { + t.Fatalf("expected redaction marker, got %q", result) + } +} + type mockResource struct { id string name string @@ -457,3 +972,33 @@ func (m *mockResource) GetName() string { return m.name } func (m *mockResource) GetARN() string { return m.arn } func (m *mockResource) GetTags() map[string]string { return m.tags } func (m *mockResource) Raw() any { return m.raw } + +type mockDAO struct { + dao.BaseDAO + resources []dao.Resource + listErr error + getErr error +} + +func (d *mockDAO) List(ctx context.Context) ([]dao.Resource, error) { + if d.listErr != nil { + return nil, d.listErr + } + return d.resources, nil +} + +func (d *mockDAO) Get(ctx context.Context, id string) (dao.Resource, error) { + if d.getErr != nil { + return nil, d.getErr + } + for _, resource := range d.resources { + if resource.GetID() == id { + return resource, nil + } + } + return nil, nil +} + +func (d *mockDAO) Delete(ctx context.Context, id string) error { + return nil +} diff --git a/internal/app/app.go b/internal/app/app.go index 5e7e5c3f..e7051dca 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -174,6 +174,13 @@ func (a *App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } + // Lifecycle messages must run before modal/command-mode focus, otherwise + // async results (e.g. awsContextReadyMsg) get swallowed while a modal is open + // and state flags like awsInitializing never clear. + if model, cmd, handled := a.handleAppLifecycleMsg(msg); handled { + return model, cmd + } + if a.modal != nil { return a.handleModalUpdate(msg) } @@ -410,47 +417,6 @@ func (a *App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.clipboardFlash = "" return a, nil - case awsContextReadyMsg: - a.awsInitializing = false - if msg.err != nil { - errStr := msg.err.Error() - // IMDS errors are expected on non-EC2 environments - log only, no warning - if strings.Contains(errStr, "ec2imds") { - log.Debug("IMDS region detection failed (expected on non-EC2)", "error", msg.err) - } else { - log.Debug("AWS context initialization failed", "error", msg.err) - config.Global().AddWarning("AWS init failed: " + errStr) - a.showWarnings = true - } - } - return a, nil - - case profileRefreshDoneMsg: - if msg.refreshID != a.profileRefreshID { - log.Debug("ignoring stale profile refresh", "got", msg.refreshID, "want", a.profileRefreshID) - return a, nil - } - a.profileRefreshing = false - a.profileRefreshError = msg.err - if msg.err != nil { - log.Warn("profile refresh failed", "error", msg.err) - return a, nil - } - if msg.region != "" { - config.Global().AddRegion(msg.region) - } - if config.File().PersistenceEnabled() { - if err := config.File().SaveRegions(config.Global().Regions()); err != nil { - log.Warn("failed to persist regions", "error", err) - } - } - if len(msg.accountIDs) > 0 { - for profileID, accountID := range msg.accountIDs { - config.Global().SetAccountIDForProfile(profileID, accountID) - } - } - return a, nil - case startupResourceMsg: if a.startupPath == nil { return a, nil @@ -605,6 +571,52 @@ func (a *App) renderWarnings() string { ) } +func (a *App) handleAppLifecycleMsg(msg tea.Msg) (tea.Model, tea.Cmd, bool) { + switch msg := msg.(type) { + case awsContextReadyMsg: + a.awsInitializing = false + if msg.err != nil { + errStr := msg.err.Error() + // IMDS errors are expected on non-EC2 environments - log only, no warning + if strings.Contains(errStr, "ec2imds") { + log.Debug("IMDS region detection failed (expected on non-EC2)", "error", msg.err) + } else { + log.Debug("AWS context initialization failed", "error", msg.err) + config.Global().AddWarning("AWS init failed: " + errStr) + a.showWarnings = true + } + } + return a, nil, true + + case profileRefreshDoneMsg: + if msg.refreshID != a.profileRefreshID { + log.Debug("ignoring stale profile refresh", "got", msg.refreshID, "want", a.profileRefreshID) + return a, nil, true + } + a.profileRefreshing = false + a.profileRefreshError = msg.err + if msg.err != nil { + log.Warn("profile refresh failed", "error", msg.err) + return a, nil, true + } + if msg.region != "" { + config.Global().AddRegion(msg.region) + } + if config.File().PersistenceEnabled() { + if err := config.File().SaveRegions(config.Global().Regions()); err != nil { + log.Warn("failed to persist regions", "error", err) + } + } + if len(msg.accountIDs) > 0 { + for profileID, accountID := range msg.accountIDs { + config.Global().SetAccountIDForProfile(profileID, accountID) + } + } + return a, nil, true + } + return a, nil, false +} + func (a *App) handleModalUpdate(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { case view.HideModalMsg: diff --git a/internal/aws/integration_test.go b/internal/aws/integration_test.go index 8024768b..f6f69125 100644 --- a/internal/aws/integration_test.go +++ b/internal/aws/integration_test.go @@ -6,6 +6,11 @@ import ( "context" "os" "testing" + + "github.com/aws/aws-sdk-go-v2/service/cloudformation" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/s3" ) // Integration tests require LocalStack to be running @@ -19,10 +24,11 @@ func TestIntegration_EC2Client(t *testing.T) { ctx := context.Background() - client, err := Global().EC2(ctx) + cfg, err := NewConfig(ctx) if err != nil { - t.Fatalf("Failed to get EC2 client: %v", err) + t.Fatalf("Failed to create config: %v", err) } + client := ec2.NewFromConfig(cfg) if client == nil { t.Fatal("EC2 client should not be nil") @@ -36,10 +42,11 @@ func TestIntegration_S3Client(t *testing.T) { ctx := context.Background() - client, err := Global().S3(ctx) + cfg, err := NewConfig(ctx) if err != nil { - t.Fatalf("Failed to get S3 client: %v", err) + t.Fatalf("Failed to create config: %v", err) } + client := s3.NewFromConfig(cfg) if client == nil { t.Fatal("S3 client should not be nil") @@ -53,10 +60,11 @@ func TestIntegration_IAMClient(t *testing.T) { ctx := context.Background() - client, err := Global().IAM(ctx) + cfg, err := NewConfig(ctx) if err != nil { - t.Fatalf("Failed to get IAM client: %v", err) + t.Fatalf("Failed to create config: %v", err) } + client := iam.NewFromConfig(cfg) if client == nil { t.Fatal("IAM client should not be nil") @@ -70,10 +78,11 @@ func TestIntegration_CloudFormationClient(t *testing.T) { ctx := context.Background() - client, err := Global().CloudFormation(ctx) + cfg, err := NewConfig(ctx) if err != nil { - t.Fatalf("Failed to get CloudFormation client: %v", err) + t.Fatalf("Failed to create config: %v", err) } + client := cloudformation.NewFromConfig(cfg) if client == nil { t.Fatal("CloudFormation client should not be nil") diff --git a/internal/aws/profiles.go b/internal/aws/profiles.go index c64a19ef..f1e05298 100644 --- a/internal/aws/profiles.go +++ b/internal/aws/profiles.go @@ -11,6 +11,7 @@ import ( "gopkg.in/ini.v1" + appconfig "github.com/clawscli/claws/internal/config" "github.com/clawscli/claws/internal/log" ) @@ -65,6 +66,10 @@ func LoadProfiles() ([]ProfileInfo, error) { } else { continue } + if !appconfig.IsValidProfileName(profileName) { + log.Debug("skipping invalid aws profile name", "profile", profileName) + continue + } ssoStartURL := section.Key("sso_start_url").String() ssoSession := section.Key("sso_session").String() @@ -106,6 +111,10 @@ func LoadProfiles() ([]ProfileInfo, error) { if name == "DEFAULT" { continue } + if !appconfig.IsValidProfileName(name) { + log.Debug("skipping invalid aws credentials profile name", "profile", name) + continue + } accessKeyID := section.Key("aws_access_key_id").String() hasCredentials := accessKeyID != "" diff --git a/internal/aws/profiles_test.go b/internal/aws/profiles_test.go new file mode 100644 index 00000000..805a851b --- /dev/null +++ b/internal/aws/profiles_test.go @@ -0,0 +1,39 @@ +package aws + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadProfilesSkipsInvalidProfileNames(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config") + credentialsPath := filepath.Join(dir, "credentials") + + configData := []byte("[profile valid-profile]\nregion = us-east-1\n[profile bad; echo injected]\nregion = us-west-2\n") + if err := os.WriteFile(configPath, configData, 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + credentialsData := []byte("[valid-profile]\naws_access_key_id = AKIA1234567890ABCD\n[bad; echo injected]\naws_access_key_id = AKIA1234567890EFGH\n") + if err := os.WriteFile(credentialsPath, credentialsData, 0o600); err != nil { + t.Fatalf("write credentials: %v", err) + } + + t.Setenv("AWS_CONFIG_FILE", configPath) + t.Setenv("AWS_SHARED_CREDENTIALS_FILE", credentialsPath) + + profiles, err := LoadProfiles() + if err != nil { + t.Fatalf("LoadProfiles() returned error: %v", err) + } + + for _, profile := range profiles { + if profile.Name == "bad; echo injected" { + t.Fatalf("LoadProfiles() returned invalid profile name: %+v", profile) + } + } + if len(profiles) != 1 || profiles[0].Name != "valid-profile" { + t.Fatalf("profiles = %+v, want only valid-profile", profiles) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index e5446a1d..aa81d3b3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -336,7 +336,7 @@ func (c *Config) GetAccountIDForProfile(profileID string) string { } func (c *Config) Warnings() []string { - return withRLock(&c.mu, func() []string { return c.warnings }) + return withRLock(&c.mu, func() []string { return append([]string(nil), c.warnings...) }) } func (c *Config) ReadOnly() bool { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 6f5373ec..d09bff99 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -134,6 +134,27 @@ func TestConfig_Warnings(t *testing.T) { } } +func TestConfig_WarningsReturnsDefensiveCopy(t *testing.T) { + cfg := &Config{} + cfg.AddWarning("warning 1") + cfg.AddWarning("warning 2") + + warnings := cfg.Warnings() + warnings[0] = "mutated" + mutated := append(warnings, "injected") + if len(mutated) != 3 { + t.Fatalf("mutated copy length = %d, want 3", len(mutated)) + } + + got := cfg.Warnings() + if len(got) != 2 { + t.Fatalf("Warnings() length = %d, want 2", len(got)) + } + if got[0] != "warning 1" { + t.Fatalf("Warnings()[0] = %q, want original warning", got[0]) + } +} + func TestGlobal(t *testing.T) { // Should return non-nil config cfg := Global() diff --git a/internal/sanitize/text.go b/internal/sanitize/text.go new file mode 100644 index 00000000..dc113a25 --- /dev/null +++ b/internal/sanitize/text.go @@ -0,0 +1,49 @@ +package sanitize + +import ( + "regexp" + "strings" + "unicode" +) + +const Redacted = "[REDACTED]" + +var sensitiveAssignmentPattern = regexp.MustCompile(`(?i)(^|[^A-Za-z0-9_])((?:aws[_-]?)?secret[_-]?access[_-]?key|password|passwd|pwd|secret|token|api[_-]?key|access[_-]?key(?:[_-]?id)?|credential)(\s*[:=]\s*)("[^"]*"|'[^']*'|[^\s,;]+)`) +var uriCredentialPattern = regexp.MustCompile(`(?i)\b([a-z][a-z0-9+.-]*://)([^/\s:@]+):([^@\s/]+)@`) +var bearerCredentialPattern = regexp.MustCompile(`(?i)\bbearer\s+[A-Za-z0-9._~+/=-]{16,}`) +var basicCredentialPattern = regexp.MustCompile(`\b[Bb]asic\s+[A-Za-z0-9+/=]*[A-Z0-9+/=][A-Za-z0-9+/=]{7,}`) +var jwtPattern = regexp.MustCompile(`\beyJ[A-Za-z0-9_-]*\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\b`) +var awsAccessKeyPattern = regexp.MustCompile(`\b(?:AKIA|ASIA)[A-Z0-9]{16}\b`) +var pemBlockPattern = regexp.MustCompile(`(?s)-----BEGIN [A-Z0-9 ]+-----.*?-----END [A-Z0-9 ]+-----`) +var ansiEscapePattern = regexp.MustCompile(`\x1b\[[0-?]*[ -/]*[@-~]|\x1b\][^\x07]*(\x07|\x1b\\)|\x1b[@-Z\\-_]`) + +// TerminalText removes ANSI escape sequences and control characters that can alter terminal state. +func TerminalText(s string) string { + s = ansiEscapePattern.ReplaceAllString(s, "") + return strings.Map(func(r rune) rune { + if r == '\t' { + return r + } + if unicode.IsControl(r) { + return -1 + } + return r + }, s) +} + +// SensitiveText redacts common key=value or key:value secret assignments. +func SensitiveText(s string) string { + s = sensitiveAssignmentPattern.ReplaceAllString(s, `${1}${2}${3}`+Redacted) + s = uriCredentialPattern.ReplaceAllString(s, `${1}`+Redacted+`@`) + s = bearerCredentialPattern.ReplaceAllString(s, `Bearer `+Redacted) + s = basicCredentialPattern.ReplaceAllString(s, `Basic `+Redacted) + s = jwtPattern.ReplaceAllString(s, Redacted) + s = awsAccessKeyPattern.ReplaceAllString(s, Redacted) + s = pemBlockPattern.ReplaceAllString(s, Redacted) + return s +} + +// LogText prepares untrusted log text for display or AI output. +func LogText(s string) string { + return SensitiveText(TerminalText(s)) +} diff --git a/internal/sanitize/text_test.go b/internal/sanitize/text_test.go new file mode 100644 index 00000000..841bfc35 --- /dev/null +++ b/internal/sanitize/text_test.go @@ -0,0 +1,111 @@ +package sanitize + +import ( + "strings" + "testing" +) + +func TestLogTextRedactsCommonSecretAssignments(t *testing.T) { + tests := []struct { + name string + input string + secret string + }{ + { + name: "aws secret access key", + input: "AWS_SECRET_ACCESS_KEY=plain-secret", + secret: "plain-secret", + }, + { + name: "quoted token with spaces", + input: `token="plain secret with spaces"`, + secret: "plain secret with spaces", + }, + { + name: "colon secret", + input: "secret:plain-secret", + secret: "plain-secret", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := LogText(tt.input) + if strings.Contains(got, tt.secret) { + t.Fatalf("LogText(%q) leaked secret in %q", tt.input, got) + } + if !strings.Contains(got, Redacted) { + t.Fatalf("LogText(%q) = %q, want redaction marker", tt.input, got) + } + }) + } +} + +func TestLogTextRemovesTerminalEscapeSequences(t *testing.T) { + got := LogText("ok \x1b[31mred\x1b[0m") + if strings.Contains(got, "\x1b") || strings.Contains(got, "[31m") { + t.Fatalf("LogText left terminal escape sequence in %q", got) + } + if !strings.Contains(got, "ok red") { + t.Fatalf("LogText removed visible text, got %q", got) + } +} + +func TestSensitiveTextRedactsValueOnlySecretPatterns(t *testing.T) { + tests := []struct { + name string + input string + secret string + }{ + { + name: "uri credentials", + input: "postgres://app:super-secret-password@db.example.com:5432/app", + secret: "super-secret-password", + }, + { + name: "bearer token", + input: "Authorization: Bearer abcdefghijklmnop", + secret: "abcdefghijklmnop", + }, + { + name: "basic token", + input: "Authorization: Basic dXNlcjpwYXNz", + secret: "dXNlcjpwYXNz", + }, + { + name: "jwt", + input: "jwt eyJhbGciOiJIUzI1NiJ9.payload.signature", + secret: "eyJhbGciOiJIUzI1NiJ9.payload.signature", + }, + { + name: "aws access key id", + input: "caller AKIAIOSFODNN7EXAMPLE", + secret: "AKIAIOSFODNN7EXAMPLE", + }, + { + name: "pem block", + input: "-----BEGIN PRIVATE KEY-----\nplain-private-key\n-----END PRIVATE KEY-----", + secret: "plain-private-key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SensitiveText(tt.input) + if strings.Contains(got, tt.secret) { + t.Fatalf("SensitiveText(%q) leaked secret in %q", tt.input, got) + } + if !strings.Contains(got, Redacted) { + t.Fatalf("SensitiveText(%q) = %q, want redaction marker", tt.input, got) + } + }) + } +} + +func TestSensitiveTextPreservesBasicDocumentationPhrase(t *testing.T) { + input := "basic authentication for CloudFront" + got := SensitiveText(input) + if got != input { + t.Fatalf("SensitiveText(%q) = %q, want unchanged documentation phrase", input, got) + } +} diff --git a/internal/view/action_menu.go b/internal/view/action_menu.go index a8f9b0d3..2979fec5 100644 --- a/internal/view/action_menu.go +++ b/internal/view/action_menu.go @@ -250,14 +250,21 @@ func (m *ActionMenu) getConfirmToken(act action.Action) string { func (m *ActionMenu) executeAction(act action.Action) (tea.Model, tea.Cmd) { if act.Type == action.ActionTypeExec { m.lastExecAction = &act - execCmd, err := action.ExpandVariables(act.Command, m.resource) + var execCommand string + var execArgs []string + var err error + if len(act.Args) > 0 { + execArgs, err = action.ExpandArgs(act.Args, m.resource) + } else { + execCommand, err = action.ExpandVariables(act.Command, m.resource) + } if err != nil { - return m, func() tea.Msg { - return execResultMsg{success: false, err: err} - } + return m, func() tea.Msg { return execResultMsg{success: false, err: err} } } exec := &action.ExecWithHeader{ - Command: execCmd, + Context: m.ctx, + Command: execCommand, + Args: execArgs, ActionName: act.Name, Resource: m.resource, Service: m.service, @@ -350,18 +357,14 @@ func (m *ActionMenu) renderDangerousConfirm(act action.Action) string { content += fmt.Sprintf("You are about to %s:\n", s.no.Render(act.Name)) content += s.bold.Render(m.dangerous.token) + "\n\n" - suffix := action.ConfirmSuffix(m.dangerous.token) - if len(suffix) < len(m.dangerous.token) { - content += fmt.Sprintf("Type last %d chars: ...%s\n", len(suffix), suffix) - } else { - content += "Type to confirm:\n" - } + confirmText := action.ConfirmSuffix(m.dangerous.token) + content += "Type the full confirmation token:\n" inputStyle := s.input matched := action.ConfirmMatches(m.dangerous.token, m.dangerous.input) if matched { inputStyle = inputStyle.BorderForeground(t.Success) - } else if len(m.dangerous.input) > 0 && strings.HasPrefix(suffix, m.dangerous.input) { + } else if len(m.dangerous.input) > 0 && strings.HasPrefix(confirmText, m.dangerous.input) { inputStyle = inputStyle.BorderForeground(t.Warning) } content += inputStyle.Render(m.dangerous.input+"▌") + "\n\n" @@ -390,14 +393,11 @@ func (m *ActionMenu) SetSize(_, _ int) tea.Cmd { func (m *ActionMenu) StatusLine() string { if m.dangerous.active { - suffix := action.ConfirmSuffix(m.dangerous.token) - if m.dangerous.input != "" && !strings.HasPrefix(suffix, m.dangerous.input) { + confirmText := action.ConfirmSuffix(m.dangerous.token) + if m.dangerous.input != "" && !strings.HasPrefix(confirmText, m.dangerous.input) { return "Token does not match" } - if len(suffix) < len(m.dangerous.token) { - return fmt.Sprintf("Type last %d chars to confirm", len(suffix)) - } - return "Type resource ID to confirm" + return "Type full confirmation token" } if m.confirming { return "Confirm: Y/N" diff --git a/internal/view/action_menu_test.go b/internal/view/action_menu_test.go index 75389354..16d29a1a 100644 --- a/internal/view/action_menu_test.go +++ b/internal/view/action_menu_test.go @@ -36,18 +36,18 @@ func TestActionMenuConfirmDangerousCorrectToken(t *testing.T) { menu.dangerous.token = "i-12345" // Default: uses GetID() menu.dangerous.input = "" - // Type the correct suffix (last 6 chars of "i-12345" = "-12345") - suffix := action.ConfirmSuffix("i-12345") - for _, r := range suffix { + // Type the full confirmation token. + confirmText := action.ConfirmSuffix("i-12345") + for _, r := range confirmText { msg := tea.KeyPressMsg{Text: string(r), Code: r} menu.Update(msg) } - if menu.dangerous.input != suffix { - t.Errorf("dangerousInput = %q, want %q", menu.dangerous.input, suffix) + if menu.dangerous.input != confirmText { + t.Errorf("dangerousInput = %q, want %q", menu.dangerous.input, confirmText) } - // Press enter - should accept since input matches suffix + // Press enter - should accept since input matches the full token. enterMsg := tea.KeyPressMsg{Code: tea.KeyEnter} menu.Update(enterMsg) @@ -63,6 +63,26 @@ func TestActionMenuConfirmDangerousCorrectToken(t *testing.T) { } } +func TestActionMenuConfirmDangerousSuffixOnlyRejected(t *testing.T) { + ctx := context.Background() + resource := &mockResource{id: "i-1234567890abcdef0", name: "test-instance"} + + menu := NewActionMenu(ctx, resource, "test", "items") + menu.dangerous.active = true + menu.confirmIdx = 0 + menu.dangerous.token = "i-1234567890abcdef0" + menu.dangerous.input = "bcdef0" + + menu.Update(tea.KeyPressMsg{Code: tea.KeyEnter}) + + if !menu.dangerous.active { + t.Error("Expected dangerousConfirm to remain true after suffix-only token + enter") + } + if menu.dangerous.input != "bcdef0" { + t.Errorf("Expected dangerousInput to remain %q, got %q", "bcdef0", menu.dangerous.input) + } +} + func TestActionMenuConfirmDangerousWrongToken(t *testing.T) { ctx := context.Background() resource := &mockResource{id: "i-12345", name: "test-instance"} @@ -218,3 +238,16 @@ func TestActionMenuConfirmDangerousHasActiveInput(t *testing.T) { t.Error("Expected HasActiveInput() to be true when dangerousConfirm is active") } } + +func TestActionMenuDangerousStatusLineFullToken(t *testing.T) { + ctx := context.Background() + resource := &mockResource{id: "i-12345", name: "test-instance"} + + menu := NewActionMenu(ctx, resource, "test", "items") + menu.dangerous.active = true + menu.dangerous.token = resource.GetID() + + if got, want := menu.StatusLine(), "Type full confirmation token"; got != want { + t.Errorf("StatusLine() = %q, want %q", got, want) + } +} diff --git a/internal/view/chat_overlay.go b/internal/view/chat_overlay.go index 70811784..889acd00 100644 --- a/internal/view/chat_overlay.go +++ b/internal/view/chat_overlay.go @@ -160,7 +160,7 @@ func (c *ChatOverlay) Init() tea.Cmd { } func (c *ChatOverlay) initClient() tea.Msg { - executor, err := ai.NewToolExecutor(c.ctx, c.registry) + executor, err := ai.NewToolExecutor(c.ctx, c.registry, c.aiCtx) if err != nil { return chatInitMsg{err: apperrors.Wrap(err, "init tool executor")} } diff --git a/internal/view/chat_overlay_prompt.go b/internal/view/chat_overlay_prompt.go index 2e7d1eb9..ae80650d 100644 --- a/internal/view/chat_overlay_prompt.go +++ b/internal/view/chat_overlay_prompt.go @@ -39,7 +39,7 @@ Available tools: - tail_logs(service, resource_type, region, id, cluster?, profile?): Fetches CloudWatch logs for a resource - Supported: lambda/functions, ecs/services, ecs/tasks, ecs/task-definitions, codebuild/projects, codebuild/builds, cloudtrail/trails, apigateway/stages, apigateway/stages-v2, stepfunctions/state-machines - cluster parameter required for ecs/services and ecs/tasks -- search_aws_docs(query): Search AWS documentation +- search_aws_docs(query): Search AWS documentation using only general, public AWS terms. This sends the query to an external AWS documentation search endpoint; never include resource IDs, account IDs, ARNs, profile names, log lines, tag values, or secrets. Queries containing private or sensitive context are rejected before external search. diff --git a/internal/view/command_input.go b/internal/view/command_input.go index 20294763..5b463d76 100644 --- a/internal/view/command_input.go +++ b/internal/view/command_input.go @@ -612,7 +612,8 @@ func (c *CommandInput) parseSortArgs(args string) tea.Cmd { func (c *CommandInput) executeLogin(profileName string) tea.Cmd { exec := &action.SimpleExec{ - Command: fmt.Sprintf("aws login --remote --profile %s", profileName), + Context: c.ctx, + Args: []string{"aws", "login", "--remote", "--profile", profileName}, ActionName: action.ActionNameLogin, SkipAWSEnv: true, } diff --git a/internal/view/log_view.go b/internal/view/log_view.go index 90253f93..3403dada 100644 --- a/internal/view/log_view.go +++ b/internal/view/log_view.go @@ -18,6 +18,7 @@ import ( "github.com/clawscli/claws/internal/config" apperrors "github.com/clawscli/claws/internal/errors" "github.com/clawscli/claws/internal/log" + "github.com/clawscli/claws/internal/sanitize" "github.com/clawscli/claws/internal/ui" ) @@ -228,7 +229,7 @@ func (v *LogView) processLogEvents(events []types.FilteredLogEvent, older bool) for _, event := range events { ts := time.UnixMilli(appaws.Int64(event.Timestamp)) - msg := appaws.Str(event.Message) + msg := sanitize.LogText(appaws.Str(event.Message)) entries = append(entries, logEntry{ timestamp: ts, message: strings.TrimSuffix(msg, "\n"), @@ -275,7 +276,7 @@ func (v *LogView) Update(msg tea.Msg) (tea.Model, tea.Cmd) { v.err = nil if msg.older { if len(msg.entries) > 0 { - v.logs = append(msg.entries, v.logs...) + v.logs = append(sanitizeLogEntries(msg.entries), v.logs...) if len(v.logs) > maxLogBufferSize { v.logs = v.logs[:maxLogBufferSize] } @@ -295,7 +296,7 @@ func (v *LogView) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if v.oldestEventTime == 0 && len(msg.entries) > 0 { v.oldestEventTime = msg.entries[0].timestamp.UnixMilli() } - v.logs = append(v.logs, msg.entries...) + v.logs = append(v.logs, sanitizeLogEntries(msg.entries)...) if len(v.logs) > maxLogBufferSize { v.logs = v.logs[len(v.logs)-maxLogBufferSize:] } @@ -413,6 +414,15 @@ func (v *LogView) updateViewportContent() { v.vp.Model.SetContent(sb.String()) } +func sanitizeLogEntries(entries []logEntry) []logEntry { + sanitized := make([]logEntry, len(entries)) + for i, entry := range entries { + sanitized[i] = entry + sanitized[i].message = sanitize.LogText(entry.message) + } + return sanitized +} + func (v *LogView) handleFilterInput(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "esc": diff --git a/internal/view/log_view_test.go b/internal/view/log_view_test.go index 62038ceb..93f64db4 100644 --- a/internal/view/log_view_test.go +++ b/internal/view/log_view_test.go @@ -70,6 +70,28 @@ func TestLogViewLogsLoadedSuccess(t *testing.T) { } } +func TestLogViewSanitizesAndRedactsLogMessages(t *testing.T) { + ctx := context.Background() + lv := NewLogView(ctx, "/aws/test") + lv.SetSize(80, 24) + + entries := []logEntry{ + {timestamp: time.Now(), message: "token=plain-secret \x1b[31mred"}, + } + lv.Update(logsLoadedMsg{entries: entries, lastEventTime: 1}) + + if strings.Contains(lv.logs[0].message, "plain-secret") { + t.Fatalf("stored log message should redact sensitive values, got %q", lv.logs[0].message) + } + if strings.Contains(lv.logs[0].message, "\x1b") || strings.Contains(lv.logs[0].message, "[31m") { + t.Fatalf("stored log message should remove terminal escape sequences, got %q", lv.logs[0].message) + } + view := lv.ViewString() + if strings.Contains(view, "plain-secret") || strings.Contains(view, "[31m") { + t.Fatalf("rendered view leaked unsafe log message: %q", view) + } +} + func TestLogViewLogsLoadedError(t *testing.T) { ctx := context.Background() lv := NewLogView(ctx, "/aws/test") diff --git a/internal/view/profile_selector.go b/internal/view/profile_selector.go index 922337f8..4f09d28f 100644 --- a/internal/view/profile_selector.go +++ b/internal/view/profile_selector.go @@ -1,10 +1,7 @@ package view import ( - "context" "fmt" - "io" - "os/exec" "strings" tea "charm.land/bubbletea/v2" @@ -205,18 +202,24 @@ func (p *ProfileSelector) ssoLoginCurrentProfile() (tea.Model, tea.Cmd) { return p, nil } - if _, err := exec.LookPath("aws"); err != nil { + awsPath, err := action.ResolveExecutable("aws") + if err != nil { p.loginResult = &loginResultMsg{ profileID: profile.id, success: false, - err: fmt.Errorf("aws CLI not found in PATH"), + err: fmt.Errorf("aws CLI not found in PATH: %w", err), } p.updateExtraHeight() return p, nil } profileID := profile.id - return p, tea.Exec(&ssoLoginCmd{profileName: profileID}, func(err error) tea.Msg { + execCmd := &action.SimpleExec{ + Args: []string{awsPath, "sso", "login", "--profile", profileID}, + ActionName: action.ActionNameSSOLogin, + SkipAWSEnv: true, + } + return p, tea.Exec(execCmd, func(err error) tea.Msg { if err != nil { return loginResultMsg{profileID: profileID, success: false, err: err} } @@ -224,25 +227,6 @@ func (p *ProfileSelector) ssoLoginCurrentProfile() (tea.Model, tea.Cmd) { }) } -type ssoLoginCmd struct { - profileName string - stdin io.Reader - stdout io.Writer - stderr io.Writer -} - -func (s *ssoLoginCmd) Run() error { - cmd := exec.CommandContext(context.Background(), "aws", "sso", "login", "--profile", s.profileName) - cmd.Stdin = s.stdin - cmd.Stdout = s.stdout - cmd.Stderr = s.stderr - return cmd.Run() -} - -func (s *ssoLoginCmd) SetStdin(r io.Reader) { s.stdin = r } -func (s *ssoLoginCmd) SetStdout(w io.Writer) { s.stdout = w } -func (s *ssoLoginCmd) SetStderr(w io.Writer) { s.stderr = w } - func (p *ProfileSelector) consoleLoginCurrentProfile() (tea.Model, tea.Cmd) { profile, ok := p.selector.CurrentItem() if !ok { @@ -271,11 +255,11 @@ func (p *ProfileSelector) consoleLoginCurrentProfile() (tea.Model, tea.Cmd) { return p, nil } - if _, err := exec.LookPath("aws"); err != nil { + if _, err := action.ResolveExecutable("aws"); err != nil { p.loginResult = &loginResultMsg{ profileID: profile.id, success: false, - err: fmt.Errorf("aws CLI not found in PATH"), + err: fmt.Errorf("aws CLI not found in PATH: %w", err), isConsoleLogin: true, } p.updateExtraHeight() @@ -283,10 +267,11 @@ func (p *ProfileSelector) consoleLoginCurrentProfile() (tea.Model, tea.Cmd) { } profileID := profile.id - execCmd := &action.SimpleExec{ - Command: "aws login --remote --profile " + profileID, - ActionName: action.ActionNameLogin, - SkipAWSEnv: true, + execCmd, err := newProfileLoginExec(profileID) + if err != nil { + p.loginResult = &loginResultMsg{profileID: profileID, success: false, err: err, isConsoleLogin: true} + p.updateExtraHeight() + return p, nil } return p, tea.Exec(execCmd, func(err error) tea.Msg { if err != nil { @@ -298,6 +283,21 @@ func (p *ProfileSelector) consoleLoginCurrentProfile() (tea.Model, tea.Cmd) { }) } +func newProfileLoginExec(profileID string) (*action.SimpleExec, error) { + if !config.IsValidProfileName(profileID) { + return nil, fmt.Errorf("invalid profile name: %s", profileID) + } + awsPath, err := action.ResolveExecutable("aws") + if err != nil { + return nil, fmt.Errorf("aws CLI not found in PATH: %w", err) + } + return &action.SimpleExec{ + Args: []string{awsPath, "login", "--remote", "--profile", profileID}, + ActionName: action.ActionNameLogin, + SkipAWSEnv: true, + }, nil +} + func (p *ProfileSelector) ViewString() string { content := p.selector.ViewString() diff --git a/internal/view/resource_browser_fetch.go b/internal/view/resource_browser_fetch.go index a6ac697d..10b21226 100644 --- a/internal/view/resource_browser_fetch.go +++ b/internal/view/resource_browser_fetch.go @@ -391,10 +391,7 @@ type resourcesErrorMsg struct { } func (r *ResourceBrowser) shouldLoadNextPage() bool { - if !r.hasMorePages || r.isLoadingMore || r.loading { - return false - } - if r.nextPageToken == "" && len(r.nextPageTokens) == 0 && len(r.nextMultiPageTokens) == 0 { + if !r.hasLoadableNextPage() { return false } if r.filterText != "" && len(r.filtered) < 10 { @@ -407,6 +404,13 @@ func (r *ResourceBrowser) shouldLoadNextPage() bool { return r.tc.Cursor() >= len(r.filtered)-buffer } +func (r *ResourceBrowser) hasLoadableNextPage() bool { + if !r.hasMorePages || r.isLoadingMore || r.loading { + return false + } + return r.nextPageToken != "" || len(r.nextPageTokens) > 0 || len(r.nextMultiPageTokens) > 0 +} + func (r *ResourceBrowser) loadNextPage() tea.Msg { if len(r.nextMultiPageTokens) > 0 { return r.loadNextPageMultiProfile() diff --git a/internal/view/resource_browser_input.go b/internal/view/resource_browser_input.go index b2f9f5a4..d9a85da9 100644 --- a/internal/view/resource_browser_input.go +++ b/internal/view/resource_browser_input.go @@ -221,7 +221,7 @@ func (r *ResourceBrowser) handleNumberKey(key string) (tea.Model, tea.Cmd) { } func (r *ResourceBrowser) handleLoadNextPage() (tea.Model, tea.Cmd) { - if r.hasMorePages && !r.isLoadingMore && (r.nextPageToken != "" || len(r.nextPageTokens) > 0) { + if r.hasLoadableNextPage() { r.isLoadingMore = true return r, r.loadNextPage } diff --git a/internal/view/resource_browser_test.go b/internal/view/resource_browser_test.go index 7c5a08a5..c1128c34 100644 --- a/internal/view/resource_browser_test.go +++ b/internal/view/resource_browser_test.go @@ -523,6 +523,75 @@ func TestFetchParallelWithPageTokens(t *testing.T) { } } +func TestHandleLoadNextPageWithMultiProfileTokens(t *testing.T) { + browser := NewResourceBrowser(context.Background(), registry.New(), "ec2") + browser.loading = false + browser.hasMorePages = true + browser.nextMultiPageTokens = map[profileRegionKey]string{ + {Profile: "dev", Region: "us-east-1"}: "token-1", + } + + _, cmd := browser.handleLoadNextPage() + + if cmd == nil { + t.Fatal("handleLoadNextPage() returned nil cmd, want loadNextPage cmd") + } + if !browser.isLoadingMore { + t.Fatal("handleLoadNextPage() did not set isLoadingMore") + } +} + +func TestShouldLoadNextPageWithMultiProfileTokens(t *testing.T) { + browser := NewResourceBrowser(context.Background(), registry.New(), "ec2") + browser.loading = false + browser.hasMorePages = true + browser.nextMultiPageTokens = map[profileRegionKey]string{ + {Profile: "dev", Region: "us-east-1"}: "token-1", + } + for i := 0; i < 20; i++ { + browser.filtered = append(browser.filtered, &mockResource{id: "item"}) + } + browser.tc.SetCursor(10, len(browser.filtered)) + + if !browser.shouldLoadNextPage() { + t.Fatal("shouldLoadNextPage() = false, want true near bottom with nextMultiPageTokens") + } + + browser.isLoadingMore = true + if browser.shouldLoadNextPage() { + t.Fatal("shouldLoadNextPage() = true while isLoadingMore, want false") + } +} + +func TestHandleNextPageLoadedUpdatesMultiProfileTokens(t *testing.T) { + browser := NewResourceBrowser(context.Background(), registry.New(), "ec2") + browser.isLoadingMore = true + browser.resources = []dao.Resource{&mockResource{id: "existing"}} + browser.filtered = browser.resources + remaining := map[profileRegionKey]string{ + {Profile: "prod", Region: "ap-northeast-1"}: "token-2", + } + + browser.handleNextPageLoaded(nextPageLoadedMsg{ + resources: []dao.Resource{&mockResource{id: "next"}}, + nextMultiPageTokens: remaining, + hasMorePages: true, + }) + + if browser.isLoadingMore { + t.Fatal("handleNextPageLoaded() left isLoadingMore true") + } + if len(browser.resources) != 2 { + t.Fatalf("resources len = %d, want 2", len(browser.resources)) + } + if browser.nextMultiPageTokens[profileRegionKey{Profile: "prod", Region: "ap-northeast-1"}] != "token-2" { + t.Fatalf("nextMultiPageTokens not updated: %#v", browser.nextMultiPageTokens) + } + if !browser.hasMorePages { + t.Fatal("hasMorePages = false, want true") + } +} + func TestFetchParallelPartialErrors(t *testing.T) { ctx := context.Background() keys := []string{"ok", "fail", "ok2"}