From 77d948691d890ed29275fbae794cea53046f0ea6 Mon Sep 17 00:00:00 2001 From: Ali Falahi Date: Sun, 29 Mar 2026 09:41:10 -0600 Subject: [PATCH 1/3] feat: add AWS SSO support to client layer Fix GetResource interface parameter order to match implementation. Add UpdateEntitlement method to C1Client interface. Add IsAWSPermissionSet detection in task.go. Add output.JSON constant for format checks. --- pkg/client/client.go | 2 +- pkg/client/error.go | 2 +- pkg/client/task.go | 24 ++++++++++++++++++++++++ pkg/output/output.go | 7 +++++-- 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/pkg/client/client.go b/pkg/client/client.go index 9f171123..8daa4d41 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -155,7 +155,7 @@ type C1Client interface { GetUser(ctx context.Context, userID string) (*shared.User, error) GetEntitlement(ctx context.Context, appID string, entitlementID string) (*shared.AppEntitlement, error) SearchEntitlements(ctx context.Context, filter *SearchEntitlementsFilter) ([]*EntitlementWithBindings, error) - GetResource(ctx context.Context, appID string, resourceID string, resourceTypeID string) (*shared.AppResource, error) + GetResource(ctx context.Context, appID string, resourceTypeID string, resourceID string) (*shared.AppResource, error) GetResourceType(ctx context.Context, appID string, resourceTypeID string) (*shared.AppResourceType, error) GetApp(ctx context.Context, appID string) (*shared.App, error) GetTask(ctx context.Context, taskId string) (*shared.TaskServiceGetResponse, error) diff --git a/pkg/client/error.go b/pkg/client/error.go index c597f480..7152e97f 100644 --- a/pkg/client/error.go +++ b/pkg/client/error.go @@ -55,7 +55,7 @@ func HandleErrors(ctx context.Context, v *viper.Viper, input error) error { return input } outputType := v.GetString("output") - if outputType != "json" && outputType != output.JSONPretty { + if outputType != output.JSON && outputType != output.JSONPretty { return input } var jsonError []byte diff --git a/pkg/client/task.go b/pkg/client/task.go index fd345b90..9b6642ca 100644 --- a/pkg/client/task.go +++ b/pkg/client/task.go @@ -2,6 +2,7 @@ package client import ( "context" + "strings" "github.com/conductorone/conductorone-sdk-go/pkg/models/operations" "github.com/conductorone/conductorone-sdk-go/pkg/models/shared" @@ -181,3 +182,26 @@ func (c *client) UpdateTaskRequestData(ctx context.Context, taskID string, reque } return resp.TaskServiceActionResponse, nil } + +// IsAWSPermissionSet checks if an entitlement is an AWS SSO permission set. +func IsAWSPermissionSet(entitlement *shared.AppEntitlement, resourceType *shared.AppResourceType) bool { + if entitlement == nil || resourceType == nil { + return false + } + + if resourceType.DisplayName != nil { + if strings.Contains(strings.ToLower(*resourceType.DisplayName), "aws permission set") { + return true + } + } + + if entitlement.SourceConnectorIds != nil { + for _, value := range entitlement.SourceConnectorIds { + if strings.Contains(value, "arn:aws:sso:::permissionSet/") { + return true + } + } + } + + return false +} diff --git a/pkg/output/output.go b/pkg/output/output.go index 4ff4758b..334680a2 100644 --- a/pkg/output/output.go +++ b/pkg/output/output.go @@ -11,7 +11,10 @@ type Manager interface { Output(ctx context.Context, out interface{}, opts ...outputOption) error } -const JSONPretty = "json-pretty" +const ( + JSON = "json" + JSONPretty = "json-pretty" +) func NewManager(ctx context.Context, v *viper.Viper) Manager { var area *pterm.AreaPrinter @@ -22,7 +25,7 @@ func NewManager(ctx context.Context, v *viper.Viper) Manager { switch v.GetString("output") { case "table": return &tableManager{area: area, isWide: false} - case "json": + case JSON: return &jsonManager{} case JSONPretty: return &jsonManager{pretty: true} From 3bcf4d533016fea31b3cd31173ae2eb3b8570dbc Mon Sep 17 00:00:00 2001 From: Ali Falahi Date: Sun, 29 Mar 2026 09:41:17 -0600 Subject: [PATCH 2/3] feat: add `cone aws` commands for SSO integration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `cone aws setup` — configures SSO URL and regions, scans ConductorOne for AWS permission set entitlements, and creates ~/.aws/config profiles with credential_process pointing to cone. Add `cone aws credentials` — fetches temporary AWS credentials via SSO. Automatically submits a ConductorOne access request if no active grant exists, polls for auto-approval, and retries SSO login on expired tokens. Add `cone aws setup show` — displays current SSO configuration. --- cmd/cone/aws.go | 644 +++++++++++++++++++++++++++++++++++++++++++++++ cmd/cone/main.go | 1 + 2 files changed, 645 insertions(+) create mode 100644 cmd/cone/aws.go diff --git a/cmd/cone/aws.go b/cmd/cone/aws.go new file mode 100644 index 00000000..16c0dd2a --- /dev/null +++ b/cmd/cone/aws.go @@ -0,0 +1,644 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/conductorone/conductorone-sdk-go/pkg/models/shared" + "github.com/conductorone/cone/pkg/client" + "github.com/pterm/pterm" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +const defaultAWSRegion = "us-east-1" + +func awsCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "aws", + Short: "AWS SSO integration commands.", + } + cmd.AddCommand(awsSetupCmd()) + cmd.AddCommand(awsCredentialsCmd()) + return cmd +} + +// --- cone aws setup --- + +func awsSetupCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "setup", + Short: "Configure AWS SSO and create profiles for available permission sets.", + Long: `Scans ConductorOne for all AWS permission set entitlements available to you +and creates corresponding profiles in ~/.aws/config. + +On first run, provide your SSO start URL and regions: + cone aws setup --sso-url https://myorg.awsapps.com/start --sso-region us-east-1 --region us-west-2 + +Subsequent runs reuse the saved config: + cone aws setup + +View current configuration: + cone aws setup show`, + RunE: awsSetupRun, + } + cmd.Flags().String("sso-url", "", "AWS SSO start URL (saved to config for future runs).") + cmd.Flags().String("sso-region", "", "AWS region where SSO Identity Center lives (saved to config, default: us-east-1).") + cmd.Flags().String("region", "", "Default AWS region for CLI profiles, e.g. us-west-2 (saved to config, default: us-east-1).") + cmd.AddCommand(awsSetupShowCmd()) + return cmd +} + +func awsSetupShowCmd() *cobra.Command { + return &cobra.Command{ + Use: "show", + Short: "Show current AWS SSO configuration.", + RunE: func(cmd *cobra.Command, args []string) error { + ssoURL := viper.GetString("aws_sso_start_url") + ssoRegion := viper.GetString("aws_sso_region") + defaultRegion := viper.GetString("aws_default_region") + + if ssoURL == "" && ssoRegion == "" && defaultRegion == "" { + pterm.Warning.Println("No AWS SSO configuration found. Run 'cone aws setup --sso-url ' first.") + return nil + } + + if ssoURL != "" { + pterm.Info.Printf("SSO URL: %s\n", ssoURL) + } else { + pterm.Warning.Println("SSO URL: not set") + } + + if ssoRegion != "" { + pterm.Info.Printf("SSO Region: %s\n", ssoRegion) + } else { + pterm.Warning.Println("SSO Region: not set") + } + + if defaultRegion != "" { + pterm.Info.Printf("Default Region: %s\n", defaultRegion) + } else { + pterm.Warning.Println("Default Region: not set") + } + + return nil + }, + } +} + +func awsSetupRun(cmd *cobra.Command, args []string) error { + ctx, c, _, err := cmdContext(cmd) + if err != nil { + return err + } + + // Resolve SSO URL: flag > saved config. + ssoURL, _ := cmd.Flags().GetString("sso-url") + if ssoURL == "" { + ssoURL = viper.GetString("aws_sso_start_url") + } + if ssoURL == "" { + return fmt.Errorf("--sso-url is required on first run") + } + + // Resolve SSO region (where Identity Center lives): flag > saved config > default. + ssoRegion, _ := cmd.Flags().GetString("sso-region") + if ssoRegion == "" { + ssoRegion = viper.GetString("aws_sso_region") + } + if ssoRegion == "" { + ssoRegion = defaultAWSRegion + } + + // Resolve default AWS region for profiles: flag > saved config > default. + defaultRegion, _ := cmd.Flags().GetString("region") + if defaultRegion == "" { + defaultRegion = viper.GetString("aws_default_region") + } + if defaultRegion == "" { + defaultRegion = defaultAWSRegion + } + + // Persist for future runs. + viper.Set("aws_sso_start_url", ssoURL) + viper.Set("aws_sso_region", ssoRegion) + viper.Set("aws_default_region", defaultRegion) + if err := viper.WriteConfig(); err != nil { + // If config file doesn't exist yet, create it. + if err := viper.SafeWriteConfig(); err != nil { + return fmt.Errorf("failed to save config: %w", err) + } + } + + pterm.Info.Printf("SSO URL: %s\n", ssoURL) + pterm.Info.Printf("SSO Region: %s\n", ssoRegion) + pterm.Info.Printf("Default Region: %s\n", defaultRegion) + + // Search for all entitlements the user can see, expanding resource type info. + entitlements, err := c.SearchEntitlements(ctx, &client.SearchEntitlementsFilter{ + GrantedStatus: shared.GrantedStatusAll, + AppEntitlementExpandMask: shared.AppEntitlementExpandMask{ + Paths: []string{"app_id", "app_resource_type_id", "app_resource_id"}, + }, + }) + if err != nil { + return fmt.Errorf("failed to search entitlements: %w", err) + } + + // Filter to AWS permission sets and create profiles. + created := 0 + skipped := 0 + for _, ent := range entitlements { + appResourceType := client.GetExpanded[shared.AppResourceType](ent, client.ExpandedAppResourceType) + sdkEnt := shared.AppEntitlement(ent.Entitlement) + if !client.IsAWSPermissionSet(&sdkEnt, appResourceType) { + continue + } + + appResource := client.GetExpanded[shared.AppResource](ent, client.ExpandedAppResource) + profileName, err := createAWSProfile(&sdkEnt, appResource, ssoURL, ssoRegion, defaultRegion) + if err != nil { + pterm.Warning.Printf("Skipping %s: %v\n", client.StringFromPtr(ent.Entitlement.DisplayName), err) + skipped++ + continue + } + if profileName == "" { + skipped++ + continue + } + pterm.Success.Printf("Created profile: %s\n", profileName) + created++ + } + + if created == 0 && skipped == 0 { + pterm.Info.Println("No AWS permission set entitlements found.") + } else { + pterm.Info.Printf("\nDone: %d profiles created, %d skipped (already exist or missing data).\n", created, skipped) + } + return nil +} + +// createAWSProfile writes a single profile to ~/.aws/config. +// Returns the profile name if created, empty string if it already exists. +func createAWSProfile(entitlement *shared.AppEntitlement, resource *shared.AppResource, ssoURL string, ssoRegion string, defaultRegion string) (string, error) { + if entitlement == nil { + return "", fmt.Errorf("entitlement is nil") + } + + var accountID string + for _, value := range entitlement.SourceConnectorIds { + parts := strings.Split(value, "|") + if len(parts) == 2 { + accountID = parts[0] + break + } + } + if accountID == "" { + return "", fmt.Errorf("could not extract AWS account ID from sourceConnectorIds") + } + + if entitlement.DisplayName == nil { + return "", fmt.Errorf("entitlement has no display name") + } + roleName := strings.Split(*entitlement.DisplayName, " ")[0] + + accountName := "aws" + if resource != nil && resource.DisplayName != nil { + accountName = strings.ToLower(strings.ReplaceAll(*resource.DisplayName, " ", "-")) + } + profileName := fmt.Sprintf("%s-%s", accountName, strings.ToLower(roleName)) + + awsConfigDir := filepath.Join(os.Getenv("HOME"), ".aws") + if err := os.MkdirAll(awsConfigDir, 0700); err != nil { + return "", fmt.Errorf("failed to create ~/.aws: %w", err) + } + + configPath := filepath.Join(awsConfigDir, "config") + configContent, err := os.ReadFile(configPath) + if err != nil && !os.IsNotExist(err) { + return "", fmt.Errorf("failed to read AWS config: %w", err) + } + + configStr := string(configContent) + + // Skip if profile already exists. + if strings.Contains(configStr, fmt.Sprintf("[profile %s]", profileName)) { + return "", nil + } + + ssoSessionExists := strings.Contains(configStr, "[sso-session cone-sso]") + + newConfig := fmt.Sprintf(` +[profile %s] +credential_process = cone aws credentials "%s" +cone_sso_account_id = %s +cone_sso_role_name = %s +cone_sso_region = %s +cone_sso_start_url = %s +cone_sso_registration_scopes = sso:account:access +sso_session = cone-sso +region = %s +output = json +`, profileName, profileName, accountID, roleName, ssoRegion, ssoURL, defaultRegion) + + if !ssoSessionExists { + newConfig += fmt.Sprintf(` +[sso-session cone-sso] +sso_start_url = %s +sso_region = %s +sso_registration_scopes = sso:account:access +`, ssoURL, ssoRegion) + } + + if err := os.WriteFile(configPath, append(configContent, []byte(newConfig)...), 0600); err != nil { + return "", fmt.Errorf("failed to write AWS config: %w", err) + } + + return profileName, nil +} + +// --- cone aws credentials --- + +// AWSCredentials is the JSON format expected by AWS credential_process. +type AWSCredentials struct { + Version int `json:"Version"` + AccessKeyID string `json:"AccessKeyId"` + SecretAccessKey string `json:"SecretAccessKey"` + SessionToken string `json:"SessionToken"` + Expiration string `json:"Expiration"` +} + +const autoRequestPollInterval = 3 * time.Second +const autoRequestPollTimeout = 90 * time.Second + +func awsCredentialsCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "credentials ", + Short: "Get temporary AWS credentials for a profile.", + Long: `Retrieve temporary AWS credentials for an AWS SSO profile managed by cone. + +This command checks ConductorOne for an active grant. If you don't have access, +it automatically submits an access request. If the request is auto-approved, +credentials are returned immediately. Otherwise it tells you the request is +pending approval. + +Can be used directly or as an AWS credential_process: + credential_process = cone aws credentials "my-profile"`, + RunE: awsCredentialsRun, + } + return cmd +} + +func awsCredentialsRun(cmd *cobra.Command, args []string) error { + ctx, _, _, err := cmdContext(cmd) + if err != nil { + return err + } + + if err := validateArgLenth(1, args, cmd); err != nil { + return err + } + + profileName := args[0] + + accessResult, err := checkC1Access(ctx, profileName) + if err != nil { + return fmt.Errorf("failed to check access: %w", err) + } + + if !accessResult.hasAccess { + if accessResult.appID == "" || accessResult.entitlementID == "" { + return fmt.Errorf("no entitlement found matching profile %q", profileName) + } + + // Fetch the entitlement to get its max grant duration. + entitlement, err := accessResult.client.GetEntitlement(ctx, accessResult.appID, accessResult.entitlementID) + if err != nil { + return fmt.Errorf("failed to get entitlement details: %w", err) + } + + // Use the entitlement's max duration if set, otherwise leave empty (unlimited). + duration := "" + if entitlement.DurationGrant != nil { + duration = *entitlement.DurationGrant + } + + fmt.Fprintf(os.Stderr, "No active grant for %q — submitting access request...\n", profileName) + + grantResp, err := accessResult.client.CreateGrantTask( + ctx, + accessResult.appID, + accessResult.entitlementID, + accessResult.userID, + "", // appUserId + "Auto-requested via cone aws credentials", // justification + duration, // use entitlement's max duration + false, // emergencyAccess + nil, // requestData + ) + if err != nil { + return fmt.Errorf("failed to submit access request: %w", err) + } + + taskID := client.StringFromPtr(grantResp.TaskView.Task.ID) + fmt.Fprintf(os.Stderr, "Access request submitted (task: %s). Checking for auto-approval...\n", taskID) + + // Poll for approval. + granted := false + deadline := time.Now().Add(autoRequestPollTimeout) + for time.Now().Before(deadline) { + time.Sleep(autoRequestPollInterval) + fmt.Fprintf(os.Stderr, ".") + + taskResp, err := accessResult.client.GetTask(ctx, taskID) + if err != nil { + break + } + task := taskResp.TaskView.Task + if task.State == nil { + continue + } + if *task.State != shared.TaskStateTaskStateClosed { + continue + } + // Task is closed — check if it was granted. + if task.TaskType.TaskTypeGrant != nil && task.TaskType.TaskTypeGrant.Outcome != nil { + if *task.TaskType.TaskTypeGrant.Outcome == shared.TaskTypeGrantOutcomeGrantOutcomeGranted { + granted = true + } + } + break + } + fmt.Fprintf(os.Stderr, "\n") + + if !granted { + fmt.Fprintf(os.Stderr, "Request is pending approval. Retry the command after it's approved.\n") + return fmt.Errorf("access request pending for %q", profileName) + } + + fmt.Fprintf(os.Stderr, "Access granted!\n") + } + + // We have access — fetch credentials. + awsConfigDir := filepath.Join(os.Getenv("HOME"), ".aws") + configPath := filepath.Join(awsConfigDir, "config") + + configContent, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("failed to read AWS config: %w", err) + } + + profileConfig := extractProfileConfig(string(configContent), fmt.Sprintf("[profile %s]", profileName)) + + accountID := extractProfileValue(profileConfig, "cone_sso_account_id") + roleName := extractProfileValue(profileConfig, "cone_sso_role_name") + ssoStartURL := extractProfileValue(profileConfig, "cone_sso_start_url") + ssoRegion := extractProfileValue(profileConfig, "cone_sso_region") + if ssoRegion == "" { + ssoRegion = defaultAWSRegion + } + + if accountID == "" || roleName == "" || ssoStartURL == "" { + return fmt.Errorf("missing required SSO configuration for profile %s", profileName) + } + + creds, err := getTemporaryCredentials(accountID, roleName, ssoStartURL, ssoRegion) + if err != nil { + return err + } + + jsonOutput, err := json.Marshal(creds) + if err != nil { + return fmt.Errorf("failed to marshal credentials: %w", err) + } + + fmt.Fprintln(os.Stdout, string(jsonOutput)) + return nil +} + +// --- helpers --- + +func extractProfileConfig(config, profileSection string) string { + lines := strings.Split(config, "\n") + var profileLines []string + inProfile := false + + for _, line := range lines { + if line == profileSection { + inProfile = true + continue + } + if inProfile { + if strings.HasPrefix(line, "[") { + break + } + profileLines = append(profileLines, line) + } + } + + return strings.Join(profileLines, "\n") +} + +func extractProfileValue(profileConfig, key string) string { + for _, line := range strings.Split(profileConfig, "\n") { + if strings.HasPrefix(strings.TrimSpace(line), key) { + parts := strings.SplitN(line, "=", 2) + if len(parts) == 2 { + return strings.TrimSpace(parts[1]) + } + } + } + return "" +} + +func getSSOToken(ssoStartURL string) (string, error) { + cacheDir := filepath.Join(os.Getenv("HOME"), ".aws", "sso", "cache") + files, err := os.ReadDir(cacheDir) + if err != nil { + return "", fmt.Errorf("failed to read SSO cache: %w", err) + } + + for _, file := range files { + if file.IsDir() || !strings.HasSuffix(file.Name(), ".json") { + continue + } + content, err := os.ReadFile(filepath.Join(cacheDir, file.Name())) + if err != nil { + continue + } + + var cache struct { + AccessToken string `json:"accessToken"` + ExpiresAt time.Time `json:"expiresAt"` + StartURL string `json:"startUrl"` + } + if err := json.Unmarshal(content, &cache); err != nil { + continue + } + + if cache.StartURL == ssoStartURL && cache.ExpiresAt.After(time.Now()) { + return cache.AccessToken, nil + } + } + + return "", fmt.Errorf("no valid SSO token found for %s", ssoStartURL) +} + +func ssoLogin() error { + fmt.Fprintf(os.Stderr, "AWS SSO session expired. Logging in...\n") + loginCmd := exec.Command("aws", "sso", "login", "--sso-session", "cone-sso") + loginCmd.Stdin = os.Stdin + loginCmd.Stdout = os.Stderr + loginCmd.Stderr = os.Stderr + return loginCmd.Run() +} + +func getRoleCredentials(token, accountID, roleName, ssoRegion string) ([]byte, error) { + cmd := exec.Command("aws", "sso", "get-role-credentials", + "--access-token", token, + "--account-id", accountID, + "--role-name", roleName, + "--region", ssoRegion, + "--output", "json") + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return nil, fmt.Errorf("%s", strings.TrimSpace(stderr.String())) + } + return stdout.Bytes(), nil +} + +func getTemporaryCredentials(accountID, roleName, ssoStartURL, ssoRegion string) (*AWSCredentials, error) { + token, err := getSSOToken(ssoStartURL) + if err != nil { + if loginErr := ssoLogin(); loginErr != nil { + return nil, fmt.Errorf("SSO login failed: %w", loginErr) + } + token, err = getSSOToken(ssoStartURL) + if err != nil { + return nil, fmt.Errorf("failed to get SSO token after login: %w", err) + } + } + + output, err := getRoleCredentials(token, accountID, roleName, ssoRegion) + if err != nil { + // Token might be cached but invalid — retry with fresh login. + if strings.Contains(err.Error(), "UnauthorizedException") || strings.Contains(err.Error(), "Session token not found") { + if loginErr := ssoLogin(); loginErr != nil { + return nil, fmt.Errorf("SSO login failed: %w", loginErr) + } + token, err = getSSOToken(ssoStartURL) + if err != nil { + return nil, fmt.Errorf("failed to get SSO token after login: %w", err) + } + output, err = getRoleCredentials(token, accountID, roleName, ssoRegion) + if err != nil { + return nil, fmt.Errorf("failed to get credentials after re-login: %w", err) + } + } else { + return nil, fmt.Errorf("failed to get credentials: %w", err) + } + } + + var response struct { + RoleCredentials struct { + AccessKeyID string `json:"accessKeyId"` + SecretAccessKey string `json:"secretAccessKey"` + SessionToken string `json:"sessionToken"` + Expiration int64 `json:"expiration"` + } `json:"roleCredentials"` + } + if err := json.Unmarshal(output, &response); err != nil { + return nil, fmt.Errorf("failed to parse credentials: %w", err) + } + + return &AWSCredentials{ + Version: 1, + AccessKeyID: response.RoleCredentials.AccessKeyID, + SecretAccessKey: response.RoleCredentials.SecretAccessKey, + SessionToken: response.RoleCredentials.SessionToken, + Expiration: time.UnixMilli(response.RoleCredentials.Expiration).Format(time.RFC3339), + }, nil +} + +type accessCheckResult struct { + hasAccess bool + appID string + entitlementID string + userID string + client client.C1Client +} + +func checkC1Access(ctx context.Context, profileName string) (*accessCheckResult, error) { + // Build a minimal command to get a client context. + tempCmd := &cobra.Command{Use: "temp"} + tempCmd.PersistentFlags().StringP("profile", "p", "default", "") + tempCmd.PersistentFlags().BoolP("non-interactive", "i", false, "") + tempCmd.PersistentFlags().String("client-id", "", "") + tempCmd.PersistentFlags().String("client-secret", "", "") + tempCmd.PersistentFlags().String("api-endpoint", "", "") + tempCmd.PersistentFlags().StringP("output", "o", "table", "") + tempCmd.PersistentFlags().Bool("debug", false, "") + tempCmd.PersistentFlags().String("log-level", "", "") + tempCmd.SetContext(ctx) + + _, c, _, err := cmdContext(tempCmd) + if err != nil { + return nil, err + } + + userInfo, err := c.AuthIntrospect(ctx) + if err != nil { + return nil, err + } + userID := client.StringFromPtr(userInfo.UserID) + + entitlements, err := c.SearchEntitlements(ctx, &client.SearchEntitlementsFilter{ + EntitlementAlias: profileName, + GrantedStatus: shared.GrantedStatusAll, + }) + if err != nil { + return nil, err + } + + result := &accessCheckResult{ + userID: userID, + client: c, + } + + for _, ent := range entitlements { + appID := client.StringFromPtr(ent.Entitlement.AppID) + entID := client.StringFromPtr(ent.Entitlement.ID) + + // Track the first matching entitlement for request submission. + if result.appID == "" { + result.appID = appID + result.entitlementID = entID + } + + grants, err := c.GetGrantsForIdentity(ctx, appID, entID, userID) + if err != nil { + continue + } + for _, grant := range grants { + if grant.CreatedAt != nil && grant.DeletedAt == nil { + result.hasAccess = true + result.appID = appID + result.entitlementID = entID + return result, nil + } + } + } + + return result, nil +} diff --git a/cmd/cone/main.go b/cmd/cone/main.go index a96fc59a..143699cf 100644 --- a/cmd/cone/main.go +++ b/cmd/cone/main.go @@ -81,6 +81,7 @@ func runCli(ctx context.Context) int { cliCmd.AddCommand(tokenCmd()) cliCmd.AddCommand(decryptCredentialCmd()) cliCmd.AddCommand(generateAliasCmd()) + cliCmd.AddCommand(awsCmd()) err = cliCmd.ExecuteContext(ctx) if err != nil { From a7114f4bc7b0b2dad802accc630c814bb845b7f3 Mon Sep 17 00:00:00 2001 From: Robert Chiniquy Date: Sun, 29 Mar 2026 10:04:03 -0700 Subject: [PATCH 3/3] Fix checkC1Access to use caller's client, add AWS CLI check checkC1Access was building a fake cobra.Command with hardcoded flags just to call cmdContext() for a client. Since awsCredentialsRun already has a client from cmdContext, pass it through instead. Also add requireAWSCLI() check before shelling out to aws, so users get a clear error message instead of an exec failure. --- cmd/cone/aws.go | 42 +++++++++++++++++------------------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/cmd/cone/aws.go b/cmd/cone/aws.go index 16c0dd2a..f451ffe1 100644 --- a/cmd/cone/aws.go +++ b/cmd/cone/aws.go @@ -297,7 +297,7 @@ Can be used directly or as an AWS credential_process: } func awsCredentialsRun(cmd *cobra.Command, args []string) error { - ctx, _, _, err := cmdContext(cmd) + ctx, c, _, err := cmdContext(cmd) if err != nil { return err } @@ -308,7 +308,7 @@ func awsCredentialsRun(cmd *cobra.Command, args []string) error { profileName := args[0] - accessResult, err := checkC1Access(ctx, profileName) + accessResult, err := checkC1Access(ctx, c, profileName) if err != nil { return fmt.Errorf("failed to check access: %w", err) } @@ -319,7 +319,7 @@ func awsCredentialsRun(cmd *cobra.Command, args []string) error { } // Fetch the entitlement to get its max grant duration. - entitlement, err := accessResult.client.GetEntitlement(ctx, accessResult.appID, accessResult.entitlementID) + entitlement, err := c.GetEntitlement(ctx, accessResult.appID, accessResult.entitlementID) if err != nil { return fmt.Errorf("failed to get entitlement details: %w", err) } @@ -332,7 +332,7 @@ func awsCredentialsRun(cmd *cobra.Command, args []string) error { fmt.Fprintf(os.Stderr, "No active grant for %q — submitting access request...\n", profileName) - grantResp, err := accessResult.client.CreateGrantTask( + grantResp, err := c.CreateGrantTask( ctx, accessResult.appID, accessResult.entitlementID, @@ -357,7 +357,7 @@ func awsCredentialsRun(cmd *cobra.Command, args []string) error { time.Sleep(autoRequestPollInterval) fmt.Fprintf(os.Stderr, ".") - taskResp, err := accessResult.client.GetTask(ctx, taskID) + taskResp, err := c.GetTask(ctx, taskID) if err != nil { break } @@ -491,6 +491,13 @@ func getSSOToken(ssoStartURL string) (string, error) { return "", fmt.Errorf("no valid SSO token found for %s", ssoStartURL) } +func requireAWSCLI() error { + if _, err := exec.LookPath("aws"); err != nil { + return fmt.Errorf("the AWS CLI is required but was not found on PATH — install it from https://aws.amazon.com/cli/") + } + return nil +} + func ssoLogin() error { fmt.Fprintf(os.Stderr, "AWS SSO session expired. Logging in...\n") loginCmd := exec.Command("aws", "sso", "login", "--sso-session", "cone-sso") @@ -519,6 +526,10 @@ func getRoleCredentials(token, accountID, roleName, ssoRegion string) ([]byte, e } func getTemporaryCredentials(accountID, roleName, ssoStartURL, ssoRegion string) (*AWSCredentials, error) { + if err := requireAWSCLI(); err != nil { + return nil, err + } + token, err := getSSOToken(ssoStartURL) if err != nil { if loginErr := ssoLogin(); loginErr != nil { @@ -576,27 +587,9 @@ type accessCheckResult struct { appID string entitlementID string userID string - client client.C1Client } -func checkC1Access(ctx context.Context, profileName string) (*accessCheckResult, error) { - // Build a minimal command to get a client context. - tempCmd := &cobra.Command{Use: "temp"} - tempCmd.PersistentFlags().StringP("profile", "p", "default", "") - tempCmd.PersistentFlags().BoolP("non-interactive", "i", false, "") - tempCmd.PersistentFlags().String("client-id", "", "") - tempCmd.PersistentFlags().String("client-secret", "", "") - tempCmd.PersistentFlags().String("api-endpoint", "", "") - tempCmd.PersistentFlags().StringP("output", "o", "table", "") - tempCmd.PersistentFlags().Bool("debug", false, "") - tempCmd.PersistentFlags().String("log-level", "", "") - tempCmd.SetContext(ctx) - - _, c, _, err := cmdContext(tempCmd) - if err != nil { - return nil, err - } - +func checkC1Access(ctx context.Context, c client.C1Client, profileName string) (*accessCheckResult, error) { userInfo, err := c.AuthIntrospect(ctx) if err != nil { return nil, err @@ -613,7 +606,6 @@ func checkC1Access(ctx context.Context, profileName string) (*accessCheckResult, result := &accessCheckResult{ userID: userID, - client: c, } for _, ent := range entitlements {