From f85e8d7a08e940b7752bbe75992e16c344d1de84 Mon Sep 17 00:00:00 2001 From: Tim Schindler Date: Mon, 20 Apr 2026 16:44:46 +0200 Subject: [PATCH 1/9] feat: add access request workflow commands (grant request) Implement grant request subcommands: submit, list, get, cancel, approve, reject. Includes workflows service with ISP auth, interactive submit prompts, list filter validation, timezone validation via time.LoadLocation, and injectable DI for testing. --- CHANGELOG.md | 12 + CLAUDE.md | 26 +- cmd/commands.go | 1 + cmd/interfaces.go | 12 + cmd/output_types.go | 29 + cmd/request.go | 199 ++++++ cmd/request_cancel.go | 53 ++ cmd/request_finalize.go | 89 +++ cmd/request_get.go | 47 ++ cmd/request_list.go | 129 ++++ cmd/request_submit.go | 313 +++++++++ cmd/request_test.go | 735 ++++++++++++++++++++++ cmd/test_mocks.go | 37 ++ internal/workflows/logging_client.go | 62 ++ internal/workflows/models/cancel.go | 6 + internal/workflows/models/finalize.go | 7 + internal/workflows/models/form.go | 55 ++ internal/workflows/models/form_test.go | 107 ++++ internal/workflows/models/request.go | 87 +++ internal/workflows/models/request_test.go | 139 ++++ internal/workflows/models/submit.go | 7 + internal/workflows/service.go | 288 +++++++++ internal/workflows/service_config.go | 14 + internal/workflows/service_test.go | 340 ++++++++++ 24 files changed, 2792 insertions(+), 2 deletions(-) create mode 100644 cmd/request.go create mode 100644 cmd/request_cancel.go create mode 100644 cmd/request_finalize.go create mode 100644 cmd/request_get.go create mode 100644 cmd/request_list.go create mode 100644 cmd/request_submit.go create mode 100644 cmd/request_test.go create mode 100644 internal/workflows/logging_client.go create mode 100644 internal/workflows/models/cancel.go create mode 100644 internal/workflows/models/finalize.go create mode 100644 internal/workflows/models/form.go create mode 100644 internal/workflows/models/form_test.go create mode 100644 internal/workflows/models/request.go create mode 100644 internal/workflows/models/request_test.go create mode 100644 internal/workflows/models/submit.go create mode 100644 internal/workflows/service.go create mode 100644 internal/workflows/service_config.go create mode 100644 internal/workflows/service_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index d825489..8f6755e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +### Added + +- `grant request` command group for managing access requests through the approval workflow + - `grant request submit` — submit a new access request with target selection from eligibility, reason, priority, date/time scheduling + - `grant request list` — list access requests with filtering (state, result, priority, role), sorting, and free-text search + - `grant request get ` — view full details of a specific access request + - `grant request cancel ` — cancel an open request with optional reason + - `grant request approve ` — approve a pending request with optional reason + - `grant request reject ` — reject a pending request with optional reason +- All `grant request` subcommands support `--output json` for machine-readable output +- New `internal/workflows/` package implementing the CyberArk Access Requests API client (`/api/workflows/requests`) + ## [0.6.1] - 2026-04-08 ### Fixed diff --git a/CLAUDE.md b/CLAUDE.md index 98cc1a8..ce812c9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -44,6 +44,22 @@ Custom `SCAAccessService` follows SDK conventions: - `POST /api/access/elevate/groups` — request group membership elevation (response wrapped in `response` key, same as cloud elevation) - **Headers:** `Authorization: Bearer {jwt}`, `X-API-Version: 2.0`, `Content-Type: application/json` +## Access Requests API (Workflows) +- **Base URL:** `https://{subdomain}.uar.{platform_domain}/api` +- **Package:** `internal/workflows/` — `AccessRequestService` (mirrors SCA service pattern with ISP client for "uar" service) +- **Models:** `internal/workflows/models/` — `AccessRequest`, `RequestState`, `RequestResult`, `SubmitAccessRequest`, `CancelAccessRequest`, `FinalizeAccessRequest`, `RequestFormResponse` +- **Endpoints:** + - `GET /api/workflows/request-forms` — get form structure for target category + request type + - `GET /api/workflows/requests` — list access requests (offset/limit pagination, filter/sort/freeText) + - `GET /api/workflows/requests/{requestId}` — get single request details + - `POST /api/workflows/requests` — submit new access request + - `POST /api/workflows/requests/{requestId}/cancel` — cancel an open request + - `POST /api/workflows/requests/{requestId}/finalize` — approve or reject a request +- **Pagination:** offset/limit (not nextToken); `ListRequests` fetches all pages automatically +- **DI interfaces:** `accessRequestService` in `cmd/interfaces.go` +- **Target category:** `CLOUD_CONSOLE` (hardcoded for v1) +- **Headers:** `Authorization: Bearer {jwt}`, `Content-Type: application/json` + ## Testing - TDD: write `_test.go` before `.go` for every package - Table-driven tests @@ -57,6 +73,12 @@ Custom `SCAAccessService` follows SDK conventions: - `grant env` — performs elevation, outputs only `export` statements (no human text); usage: `eval $(grant env --provider aws)`; supports `--refresh` - `grant list` — list eligible targets and groups without triggering elevation; supports `--provider`, `--groups`, `--refresh`, `--output json`; used by LLMs to discover available targets programmatically - `grant revoke` — revoke sessions: direct (`grant revoke `), `--all`, or interactive multi-select; `--yes` skips confirmation +- `grant request` — manage access requests through approval workflow; subcommands: `submit`, `list`, `get`, `cancel`, `approve`, `reject` +- `grant request submit` — submit access request; reuses SCA eligibility for target selection; flags: `--target`, `--role`, `--provider`, `--reason`, `--priority`, `--date`, `--timezone`, `--from`, `--to` +- `grant request list` — list access requests; flags: `--state`, `--result`, `--priority`, `--role` (CREATOR/APPROVER), `--search`, `--sort`, `--desc` +- `grant request get ` — get full request details +- `grant request cancel ` — cancel an open request; optional `--reason` +- `grant request approve ` / `grant request reject ` — finalize a request; optional `--reason` - `grant update` — self-update binary via GitHub Releases (`rhysd/go-github-selfupdate`); guards against dev builds - `--groups` flag on root command shows only Entra ID groups in the interactive selector - `--group` / `-g` flag on root command for direct group membership elevation (`grant --group "Cloud Admins"`) @@ -75,8 +97,8 @@ Custom `SCAAccessService` follows SDK conventions: - `--output` / `-o` persistent flag on root command: `text` (default) or `json` - Validated in `PersistentPreRunE`; JSON mode forces `IsTerminalFunc` to return false (non-interactive) - `cmd/output.go` — `outputFormat` var, `isJSONOutput()`, `writeJSON(w, data)` -- `cmd/output_types.go` — JSON structs: `cloudElevationOutput`, `groupElevationJSON`, `sessionOutput`, `statusOutput`, `revocationOutput`, `favoriteOutput`, `awsCredentialOutput` -- All commands support JSON: root elevation, `env`, `status`, `revoke`, `favorites list` +- `cmd/output_types.go` — JSON structs: `cloudElevationOutput`, `groupElevationJSON`, `sessionOutput`, `statusOutput`, `revocationOutput`, `favoriteOutput`, `awsCredentialOutput`, `accessRequestOutput`, `accessRequestListOutput` +- All commands support JSON: root elevation, `env`, `status`, `revoke`, `favorites list`, `request list`, `request get`, `request submit`, `request cancel`, `request approve`, `request reject` - `config.Favorite` has both `yaml:"..."` and `json:"..."` struct tags ## Cache diff --git a/cmd/commands.go b/cmd/commands.go index eb4aae9..1c6d7d2 100644 --- a/cmd/commands.go +++ b/cmd/commands.go @@ -12,5 +12,6 @@ func init() { NewRevokeCommand(), NewUpdateCommand(), NewListCommand(), + NewRequestCommand(), ) } diff --git a/cmd/interfaces.go b/cmd/interfaces.go index 8e41b35..4d8677e 100644 --- a/cmd/interfaces.go +++ b/cmd/interfaces.go @@ -4,6 +4,8 @@ import ( "context" "github.com/aaearon/grant-cli/internal/sca/models" + "github.com/aaearon/grant-cli/internal/workflows" + wfmodels "github.com/aaearon/grant-cli/internal/workflows/models" "github.com/blang/semver" sdkmodels "github.com/cyberark/idsec-sdk-golang/pkg/models" authmodels "github.com/cyberark/idsec-sdk-golang/pkg/models/auth" @@ -89,3 +91,13 @@ type unifiedSelector interface { type selfUpdater interface { UpdateSelf(current semver.Version, slug string) (*selfupdate.Release, error) } + +// accessRequestService interface for access request operations +type accessRequestService interface { + ListRequests(ctx context.Context, params workflows.ListRequestsParams) ([]wfmodels.AccessRequest, int, error) + GetRequest(ctx context.Context, requestID string) (*wfmodels.AccessRequest, error) + SubmitRequest(ctx context.Context, req *wfmodels.SubmitAccessRequest) (*wfmodels.AccessRequest, error) + CancelRequest(ctx context.Context, requestID string, reason *string) (*wfmodels.AccessRequest, error) + FinalizeRequest(ctx context.Context, requestID string, result string, reason *string) (*wfmodels.AccessRequest, error) +} + diff --git a/cmd/output_types.go b/cmd/output_types.go index add19b0..91dff4f 100644 --- a/cmd/output_types.go +++ b/cmd/output_types.go @@ -64,3 +64,32 @@ type favoriteOutput struct { Group string `json:"group,omitempty"` DirectoryID string `json:"directoryId,omitempty"` } + +// accessRequestOutput is the JSON representation of an access request. +type accessRequestOutput struct { + RequestID string `json:"requestId"` + TargetCategory string `json:"targetCategory"` + State string `json:"state"` + Result string `json:"result"` + Priority string `json:"priority,omitempty"` + Reason string `json:"reason,omitempty"` + Provider string `json:"provider,omitempty"` + Target string `json:"target,omitempty"` + Role string `json:"role,omitempty"` + RequestDate string `json:"requestDate,omitempty"` + Timezone string `json:"timezone,omitempty"` + TimeFrom string `json:"timeFrom,omitempty"` + TimeTo string `json:"timeTo,omitempty"` + FinalizationReason string `json:"finalizationReason,omitempty"` + RequestLink string `json:"requestLink,omitempty"` + CreatedBy string `json:"createdBy"` + CreatedAt string `json:"createdAt"` + UpdatedBy string `json:"updatedBy"` + UpdatedAt string `json:"updatedAt"` +} + +// accessRequestListOutput is the JSON representation of a list of access requests. +type accessRequestListOutput struct { + Requests []accessRequestOutput `json:"requests"` + TotalCount int `json:"totalCount"` +} diff --git a/cmd/request.go b/cmd/request.go new file mode 100644 index 0000000..ffd8584 --- /dev/null +++ b/cmd/request.go @@ -0,0 +1,199 @@ +package cmd + +import ( + "fmt" + "strings" + "text/tabwriter" + + "github.com/aaearon/grant-cli/internal/workflows" + "github.com/aaearon/grant-cli/internal/workflows/models" + "github.com/cyberark/idsec-sdk-golang/pkg/auth" + authmodels "github.com/cyberark/idsec-sdk-golang/pkg/models/auth" + "github.com/cyberark/idsec-sdk-golang/pkg/profiles" + "github.com/spf13/cobra" +) + +// NewRequestCommand creates the "grant request" parent command. +func NewRequestCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "request", + Short: "Manage access requests", + Long: "Create, list, and manage access requests through the approval workflow.", + } + + cmd.AddCommand( + newRequestListCommand(nil), + newRequestGetCommand(nil), + newRequestSubmitCommand(nil), + newRequestCancelCommand(nil), + newRequestApproveCommand(nil), + newRequestRejectCommand(nil), + ) + + return cmd +} + +// NewRequestCommandWithDeps creates the request parent with injected dependencies for testing. +func NewRequestCommandWithDeps(reqSvc accessRequestService) *cobra.Command { + cmd := &cobra.Command{ + Use: "request", + Short: "Manage access requests", + } + + cmd.AddCommand( + newRequestListCommand(reqSvc), + newRequestGetCommand(reqSvc), + newRequestSubmitCommand(reqSvc), + newRequestCancelCommand(reqSvc), + newRequestApproveCommand(reqSvc), + newRequestRejectCommand(reqSvc), + ) + + return cmd +} + +// bootstrapWorkflowsService creates an authenticated AccessRequestService. +func bootstrapWorkflowsService() (*workflows.AccessRequestService, error) { + loader := profiles.DefaultProfilesLoader() + profile, err := (*loader).LoadProfile("grant") + if err != nil { + return nil, fmt.Errorf("failed to load profile: %w", err) + } + + ispAuth := auth.NewIdsecISPAuth(true) + + _, err = ispAuth.Authenticate(profile, nil, &authmodels.IdsecSecret{Secret: ""}, false, true) + if err != nil { + return nil, fmt.Errorf("authentication failed: %w", err) + } + + svc, err := workflows.NewAccessRequestService(ispAuth) + if err != nil { + return nil, fmt.Errorf("failed to create access request service: %w", err) + } + + return svc, nil +} + +// formatRequestTable writes a table of access requests to the command output. +func formatRequestTable(cmd *cobra.Command, requests []models.AccessRequest) { + w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "ID\tSTATE\tRESULT\tTARGET\tROLE\tPRIORITY\tCREATED BY\tCREATED AT") + for _, r := range requests { + fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n", + r.RequestID, + r.RequestState, + r.RequestResult, + r.DetailString("workspaceName"), + r.DetailString("roleName"), + r.DetailString("priority"), + r.CreatedBy, + formatTimestamp(r.CreatedAt), + ) + } + w.Flush() +} + +// formatRequestDetail writes a detailed view of a single access request. +func formatRequestDetail(cmd *cobra.Command, r *models.AccessRequest) { + w := cmd.OutOrStdout() + fmt.Fprintf(w, "Request ID: %s\n", r.RequestID) + fmt.Fprintf(w, "State: %s\n", r.RequestState) + fmt.Fprintf(w, "Result: %s\n", r.RequestResult) + fmt.Fprintf(w, "Category: %s\n", r.TargetCategory) + + if v := r.DetailString("locationType"); v != "" { + fmt.Fprintf(w, "Provider: %s\n", v) + } + if v := r.DetailString("workspaceName"); v != "" { + fmt.Fprintf(w, "Target: %s\n", v) + } + if v := r.DetailString("roleName"); v != "" { + fmt.Fprintf(w, "Role: %s\n", v) + } + if v := r.DetailString("reason"); v != "" { + fmt.Fprintf(w, "Reason: %s\n", v) + } + if v := r.DetailString("priority"); v != "" { + fmt.Fprintf(w, "Priority: %s\n", v) + } + if v := r.DetailString("requestDate"); v != "" { + fmt.Fprintf(w, "Request Date: %s\n", v) + } + if v := r.DetailString("timezone"); v != "" { + fmt.Fprintf(w, "Timezone: %s\n", v) + } + if v := r.DetailString("timeFrom"); v != "" { + fmt.Fprintf(w, "Time From: %s\n", v) + } + if v := r.DetailString("timeTo"); v != "" { + fmt.Fprintf(w, "Time To: %s\n", v) + } + + fmt.Fprintf(w, "Created By: %s\n", r.CreatedBy) + fmt.Fprintf(w, "Created At: %s\n", formatTimestamp(r.CreatedAt)) + fmt.Fprintf(w, "Updated By: %s\n", r.UpdatedBy) + fmt.Fprintf(w, "Updated At: %s\n", formatTimestamp(r.UpdatedAt)) + + if r.FinalizationReason != "" { + fmt.Fprintf(w, "Finalization: %s\n", r.FinalizationReason) + } + if r.RequestLink != "" { + fmt.Fprintf(w, "Link: %s\n", r.RequestLink) + } + + if len(r.AssignedApprovers) > 0 { + names := make([]string, len(r.AssignedApprovers)) + for i, a := range r.AssignedApprovers { + if a.EntityDisplayName != "" { + names[i] = fmt.Sprintf("%s (%s)", a.EntityDisplayName, a.EntityEmail) + } else { + names[i] = a.EntityName + } + } + fmt.Fprintf(w, "Approvers: %s\n", strings.Join(names, ", ")) + } + + if len(r.RequestApprovers) > 0 { + for _, a := range r.RequestApprovers { + name := a.Approver.EntityDisplayName + if name == "" { + name = a.Approver.EntityName + } + fmt.Fprintf(w, "Acted: %s - %s\n", name, a.Result) + } + } +} + +// formatTimestamp truncates a timestamp to just the date+time portion (no microseconds). +func formatTimestamp(ts string) string { + if len(ts) > 19 { + return ts[:19] + } + return ts +} + +// toAccessRequestOutput converts a model to the JSON output type. +func toAccessRequestOutput(r *models.AccessRequest) accessRequestOutput { + return accessRequestOutput{ + RequestID: r.RequestID, + TargetCategory: r.TargetCategory, + State: string(r.RequestState), + Result: string(r.RequestResult), + Priority: r.DetailString("priority"), + Reason: r.DetailString("reason"), + Provider: r.DetailString("locationType"), + Target: r.DetailString("workspaceName"), + Role: r.DetailString("roleName"), + RequestDate: r.DetailString("requestDate"), + Timezone: r.DetailString("timezone"), + TimeFrom: r.DetailString("timeFrom"), + TimeTo: r.DetailString("timeTo"), + FinalizationReason: r.FinalizationReason, + RequestLink: r.RequestLink, + CreatedBy: r.CreatedBy, + CreatedAt: r.CreatedAt, + UpdatedBy: r.UpdatedBy, + UpdatedAt: r.UpdatedAt, + } +} diff --git a/cmd/request_cancel.go b/cmd/request_cancel.go new file mode 100644 index 0000000..8845f5c --- /dev/null +++ b/cmd/request_cancel.go @@ -0,0 +1,53 @@ +package cmd + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +func newRequestCancelCommand(svc accessRequestService) *cobra.Command { + cmd := &cobra.Command{ + Use: "cancel ", + Short: "Cancel an open access request", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if svc == nil { + bootstrapped, err := bootstrapWorkflowsService() + if err != nil { + return err + } + svc = bootstrapped + } + return runRequestCancel(cmd, args[0], svc) + }, + } + + cmd.Flags().String("reason", "", "Reason for cancellation") + + return cmd +} + +func runRequestCancel(cmd *cobra.Command, requestID string, svc accessRequestService) error { + ctx := cmd.Context() + + var reason *string + if v, _ := cmd.Flags().GetString("reason"); v != "" { + reason = &v + } + + log.Info("Canceling access request %s", requestID) + + result, err := svc.CancelRequest(ctx, requestID, reason) + if err != nil { + return fmt.Errorf("failed to cancel request: %w", err) + } + + if isJSONOutput() { + return writeJSON(cmd.OutOrStdout(), toAccessRequestOutput(result)) + } + + fmt.Fprintf(cmd.OutOrStdout(), "Request %s canceled.\n", result.RequestID) + fmt.Fprintf(cmd.OutOrStdout(), "Result: %s\n", result.RequestResult) + return nil +} diff --git a/cmd/request_finalize.go b/cmd/request_finalize.go new file mode 100644 index 0000000..106f8a3 --- /dev/null +++ b/cmd/request_finalize.go @@ -0,0 +1,89 @@ +package cmd + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +func newRequestApproveCommand(svc accessRequestService) *cobra.Command { + cmd := &cobra.Command{ + Use: "approve ", + Short: "Approve an access request", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if svc == nil { + bootstrapped, err := bootstrapWorkflowsService() + if err != nil { + return err + } + svc = bootstrapped + } + return runFinalize(cmd, args[0], "APPROVED", svc) + }, + } + + cmd.Flags().String("reason", "", "Reason for approval") + + return cmd +} + +func newRequestRejectCommand(svc accessRequestService) *cobra.Command { + cmd := &cobra.Command{ + Use: "reject ", + Short: "Reject an access request", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if svc == nil { + bootstrapped, err := bootstrapWorkflowsService() + if err != nil { + return err + } + svc = bootstrapped + } + return runFinalize(cmd, args[0], "REJECTED", svc) + }, + } + + cmd.Flags().String("reason", "", "Reason for rejection") + + return cmd +} + +func runFinalize(cmd *cobra.Command, requestID, decision string, svc accessRequestService) error { + ctx := cmd.Context() + + var reason *string + if v, _ := cmd.Flags().GetString("reason"); v != "" { + reason = &v + } + + log.Info("Finalizing access request %s with result %s", requestID, decision) + + result, err := svc.FinalizeRequest(ctx, requestID, decision, reason) + if err != nil { + return fmt.Errorf("failed to %s request: %w", decisionVerb(decision), err) + } + + if isJSONOutput() { + return writeJSON(cmd.OutOrStdout(), toAccessRequestOutput(result)) + } + + fmt.Fprintf(cmd.OutOrStdout(), "Request %s %s.\n", result.RequestID, decisionPastTense(decision)) + fmt.Fprintf(cmd.OutOrStdout(), "Result: %s\n", result.RequestResult) + return nil +} + +func decisionVerb(decision string) string { + if decision == "APPROVED" { + return "approve" + } + return "reject" +} + +func decisionPastTense(decision string) string { + if decision == "APPROVED" { + return "approved" + } + return "rejected" +} diff --git a/cmd/request_get.go b/cmd/request_get.go new file mode 100644 index 0000000..a61ff87 --- /dev/null +++ b/cmd/request_get.go @@ -0,0 +1,47 @@ +package cmd + +import ( + "errors" + "fmt" + + "github.com/spf13/cobra" +) + +func newRequestGetCommand(svc accessRequestService) *cobra.Command { + return &cobra.Command{ + Use: "get ", + Short: "Get details of an access request", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if svc == nil { + bootstrapped, err := bootstrapWorkflowsService() + if err != nil { + return err + } + svc = bootstrapped + } + return runRequestGet(cmd, args[0], svc) + }, + } +} + +func runRequestGet(cmd *cobra.Command, requestID string, svc accessRequestService) error { + if requestID == "" { + return errors.New("request ID is required") + } + + ctx := cmd.Context() + log.Info("Getting access request %s", requestID) + + result, err := svc.GetRequest(ctx, requestID) + if err != nil { + return fmt.Errorf("failed to get request: %w", err) + } + + if isJSONOutput() { + return writeJSON(cmd.OutOrStdout(), toAccessRequestOutput(result)) + } + + formatRequestDetail(cmd, result) + return nil +} diff --git a/cmd/request_list.go b/cmd/request_list.go new file mode 100644 index 0000000..475c3d1 --- /dev/null +++ b/cmd/request_list.go @@ -0,0 +1,129 @@ +package cmd + +import ( + "errors" + "fmt" + "strings" + + "github.com/aaearon/grant-cli/internal/workflows" + "github.com/spf13/cobra" +) + +func newRequestListCommand(svc accessRequestService) *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "List access requests", + Long: "Retrieve a list of access requests with optional filtering, sorting, and pagination.", + RunE: func(cmd *cobra.Command, args []string) error { + if svc == nil { + bootstrapped, err := bootstrapWorkflowsService() + if err != nil { + return err + } + svc = bootstrapped + } + return runRequestList(cmd, svc) + }, + } + + cmd.Flags().String("state", "", "Filter by state: STARTING, RUNNING, PENDING, FINISHED, EXPIRED") + cmd.Flags().String("result", "", "Filter by result: APPROVED, REJECTED, CANCELED, FAILED, UNKNOWN") + cmd.Flags().String("priority", "", "Filter by priority: High, Medium, Low") + cmd.Flags().String("role", "", "Request role: CREATOR, APPROVER") + cmd.Flags().String("search", "", "Free text search") + cmd.Flags().String("sort", "createdAt", "Sort field: createdAt, updatedAt, calculatedRequestStartTime") + cmd.Flags().Bool("desc", true, "Sort descending") + + return cmd +} + +var ( + validStates = map[string]bool{"STARTING": true, "RUNNING": true, "PENDING": true, "FINISHED": true, "EXPIRED": true} + validResults = map[string]bool{"APPROVED": true, "REJECTED": true, "CANCELED": true, "FAILED": true, "UNKNOWN": true} + validPriorities = map[string]bool{"High": true, "Medium": true, "Low": true} + validSorts = map[string]bool{"createdAt": true, "updatedAt": true, "calculatedRequestStartTime": true} +) + +func runRequestList(cmd *cobra.Command, svc accessRequestService) error { + ctx := cmd.Context() + + params := workflows.ListRequestsParams{} + + var filters []string + if v, _ := cmd.Flags().GetString("state"); v != "" { + upper := strings.ToUpper(v) + if !validStates[upper] { + return fmt.Errorf("--state must be one of STARTING, RUNNING, PENDING, FINISHED, EXPIRED (got %q)", v) + } + filters = append(filters, fmt.Sprintf("(requestState eq %s)", upper)) + } + if v, _ := cmd.Flags().GetString("result"); v != "" { + upper := strings.ToUpper(v) + if !validResults[upper] { + return fmt.Errorf("--result must be one of APPROVED, REJECTED, CANCELED, FAILED, UNKNOWN (got %q)", v) + } + filters = append(filters, fmt.Sprintf("(requestResult eq %s)", upper)) + } + if v, _ := cmd.Flags().GetString("priority"); v != "" { + if !validPriorities[v] { + return fmt.Errorf("--priority must be one of High, Medium, Low (got %q)", v) + } + filters = append(filters, fmt.Sprintf("(priority eq '%s')", v)) + } + if len(filters) > 0 { + params.Filter = "(" + strings.Join(filters, " and ") + ")" + } + + if v, _ := cmd.Flags().GetString("search"); v != "" { + params.FreeText = v + } + + if v, _ := cmd.Flags().GetString("role"); v != "" { + role := strings.ToUpper(v) + if role != "CREATOR" && role != "APPROVER" { + return errors.New("--role must be CREATOR or APPROVER") + } + params.RequestRole = role + } + + sortField, _ := cmd.Flags().GetString("sort") + desc, _ := cmd.Flags().GetBool("desc") + if sortField != "" { + if !validSorts[sortField] { + return fmt.Errorf("--sort must be one of createdAt, updatedAt, calculatedRequestStartTime (got %q)", sortField) + } + order := "asc" + if desc { + order = "desc" + } + params.Sort = sortField + " " + order + } + + log.Info("Listing access requests with params: filter=%q freeText=%q role=%q sort=%q", + params.Filter, params.FreeText, params.RequestRole, params.Sort) + + items, totalCount, err := svc.ListRequests(ctx, params) + if err != nil { + return fmt.Errorf("failed to list requests: %w", err) + } + + if isJSONOutput() { + outputs := make([]accessRequestOutput, len(items)) + for i := range items { + outputs[i] = toAccessRequestOutput(&items[i]) + } + return writeJSON(cmd.OutOrStdout(), accessRequestListOutput{ + Requests: outputs, + TotalCount: totalCount, + }) + } + + if len(items) == 0 { + fmt.Fprintln(cmd.OutOrStdout(), "No access requests found.") + return nil + } + + formatRequestTable(cmd, items) + fmt.Fprintf(cmd.OutOrStdout(), "\nTotal: %d\n", totalCount) + return nil +} diff --git a/cmd/request_submit.go b/cmd/request_submit.go new file mode 100644 index 0000000..2bcf0fc --- /dev/null +++ b/cmd/request_submit.go @@ -0,0 +1,313 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "time" + + survey "github.com/Iilun/survey/v2" + "github.com/aaearon/grant-cli/internal/config" + "github.com/aaearon/grant-cli/internal/sca/models" + "github.com/aaearon/grant-cli/internal/ui" + wfmodels "github.com/aaearon/grant-cli/internal/workflows/models" + "github.com/spf13/cobra" +) + +func newRequestSubmitCommand(svc accessRequestService) *cobra.Command { + cmd := &cobra.Command{ + Use: "submit", + Short: "Submit an access request", + Long: "Submit a new access request for cloud resource access through the approval workflow.", + RunE: func(cmd *cobra.Command, args []string) error { + if svc == nil { + bootstrapped, err := bootstrapWorkflowsService() + if err != nil { + return err + } + svc = bootstrapped + } + return runRequestSubmit(cmd, svc) + }, + } + + cmd.Flags().StringP("provider", "p", "", "Cloud provider: azure, aws") + cmd.Flags().StringP("target", "t", "", "Target workspace name") + cmd.Flags().StringP("role", "r", "", "Role name") + cmd.Flags().String("reason", "", "Reason for the request (required)") + cmd.Flags().String("priority", "Medium", "Priority: High, Medium, Low") + cmd.Flags().String("date", "", "Request date (YYYY-MM-DD)") + cmd.Flags().String("timezone", "", "Timezone (TZ identifier, e.g. America/New_York)") + cmd.Flags().String("from", "", "Start time (HH:MM)") + cmd.Flags().String("to", "", "End time (HH:MM)") + + return cmd +} + +// submitPromptFn is injectable for testing interactive prompts. +var submitPromptFn = defaultSubmitPrompt + +// resolveSubmitTargetFn is injectable for testing target resolution. +var resolveSubmitTargetFn = resolveSubmitTarget + +type submitFields struct { + reason string + priority string + date string + timezone string + timeFrom string + timeTo string +} + +func defaultSubmitPrompt() (*submitFields, error) { + stdio := survey.WithStdio(os.Stdin, os.Stderr, os.Stderr) + + var reason string + if err := survey.AskOne(&survey.Input{Message: "Reason:"}, &reason, survey.WithValidator(survey.Required), stdio); err != nil { + return nil, err + } + + var priority string + if err := survey.AskOne(&survey.Select{ + Message: "Priority:", + Options: []string{"High", "Medium", "Low"}, + Default: "Medium", + }, &priority, stdio); err != nil { + return nil, err + } + + today := time.Now().Format("2006-01-02") + var date string + if err := survey.AskOne(&survey.Input{Message: "Date (YYYY-MM-DD):", Default: today}, &date, stdio); err != nil { + return nil, err + } + + localTZ := time.Now().Location().String() + var timezone string + if err := survey.AskOne(&survey.Input{Message: "Timezone:", Default: localTZ}, &timezone, stdio); err != nil { + return nil, err + } + + var timeFrom string + if err := survey.AskOne(&survey.Input{Message: "Start time (HH:MM):"}, &timeFrom, survey.WithValidator(survey.Required), stdio); err != nil { + return nil, err + } + + var timeTo string + if err := survey.AskOne(&survey.Input{Message: "End time (HH:MM):"}, &timeTo, survey.WithValidator(survey.Required), stdio); err != nil { + return nil, err + } + + return &submitFields{ + reason: reason, + priority: priority, + date: date, + timezone: timezone, + timeFrom: timeFrom, + timeTo: timeTo, + }, nil +} + +func runRequestSubmit(cmd *cobra.Command, svc accessRequestService) error { + ctx := cmd.Context() + + fields, err := resolveSubmitFields(cmd) + if err != nil { + return err + } + + if err := validateSubmitFields(fields); err != nil { + return err + } + + provider, _ := cmd.Flags().GetString("provider") + targetName, _ := cmd.Flags().GetString("target") + roleName, _ := cmd.Flags().GetString("role") + + target, err := resolveSubmitTargetFn(ctx, provider, targetName, roleName) + if err != nil { + return err + } + + details := buildRequestDetails(target, fields) + + log.Info("Submitting access request for %s / %s", target.WorkspaceName, target.RoleInfo.Name) + + result, err := svc.SubmitRequest(ctx, &wfmodels.SubmitAccessRequest{ + TargetCategory: "CLOUD_CONSOLE", + RequestDetails: details, + }) + if err != nil { + return fmt.Errorf("failed to submit request: %w", err) + } + + if isJSONOutput() { + return writeJSON(cmd.OutOrStdout(), toAccessRequestOutput(result)) + } + + fmt.Fprintf(cmd.OutOrStdout(), "Access request submitted successfully.\n") + fmt.Fprintf(cmd.OutOrStdout(), "Request ID: %s\n", result.RequestID) + fmt.Fprintf(cmd.OutOrStdout(), "State: %s\n", result.RequestState) + if result.RequestLink != "" { + fmt.Fprintf(cmd.OutOrStdout(), "Link: %s\n", result.RequestLink) + } + return nil +} + +func resolveSubmitFields(cmd *cobra.Command) (*submitFields, error) { + f := &submitFields{} + f.reason, _ = cmd.Flags().GetString("reason") + f.priority, _ = cmd.Flags().GetString("priority") + f.date, _ = cmd.Flags().GetString("date") + f.timezone, _ = cmd.Flags().GetString("timezone") + f.timeFrom, _ = cmd.Flags().GetString("from") + f.timeTo, _ = cmd.Flags().GetString("to") + + if f.reason != "" && f.date != "" && f.timezone != "" && f.timeFrom != "" && f.timeTo != "" { + return f, nil + } + + if !ui.IsInteractive() { + return nil, errors.New("non-interactive mode requires --reason, --date, --timezone, --from, --to") + } + + prompted, err := submitPromptFn() + if err != nil { + return nil, err + } + if f.reason == "" { + f.reason = prompted.reason + } + if !cmd.Flags().Changed("priority") && prompted.priority != "" { + f.priority = prompted.priority + } + if f.date == "" { + f.date = prompted.date + } + if f.timezone == "" { + f.timezone = prompted.timezone + } + if f.timeFrom == "" { + f.timeFrom = prompted.timeFrom + } + if f.timeTo == "" { + f.timeTo = prompted.timeTo + } + return f, nil +} + +func resolveSubmitTarget(ctx context.Context, provider, targetName, roleName string) (*models.EligibleTarget, error) { + _, scaSvc, _, err := bootstrapSCAService() + if err != nil { + return nil, fmt.Errorf("failed to bootstrap SCA service: %w", err) + } + + cfg, _, _ := config.LoadDefaultWithPath() + if cfg == nil { + cfg = config.DefaultConfig() + } + cachedLister := buildCachedLister(cfg, false, scaSvc, nil) + + targets, err := fetchEligibility(ctx, cachedLister, provider) + if err != nil { + return nil, fmt.Errorf("failed to fetch eligibility: %w", err) + } + + if targetName != "" && roleName != "" { + target := findMatchingTarget(targets, targetName, roleName) + if target == nil { + return nil, fmt.Errorf("no eligible target found matching target=%q role=%q", targetName, roleName) + } + resolveTargetCSP(target, targets, provider) + return target, nil + } + + if !ui.IsInteractive() { + return nil, errors.New("non-interactive mode requires --target and --role") + } + items := buildCloudSelectionItems(targets) + sel := &uiUnifiedSelector{} + selected, err := sel.SelectItem(items) + if err != nil { + return nil, err + } + resolveTargetCSP(selected.cloud, targets, provider) + return selected.cloud, nil +} + +// API submit payload uses camelCase keys (per spec example), not the snake_case +// form question keys from GET /request-forms. +func buildRequestDetails(target *models.EligibleTarget, f *submitFields) map[string]interface{} { + locationType := string(target.CSP) + if target.CSP == models.CSPAzure { + locationType = "Azure" + } else if target.CSP == models.CSPAWS { + locationType = "AWS" + } + + return map[string]interface{}{ + "locationType": locationType, + "roleId": target.RoleInfo.ID, + "roleName": target.RoleInfo.Name, + "workspaceId": target.WorkspaceID, + "workspaceName": target.WorkspaceName, + "workspaceType": string(target.WorkspaceType), + "orgId": target.OrganizationID, + "reason": f.reason, + "priority": f.priority, + "requestDate": f.date, + "timezone": f.timezone, + "timeFrom": f.timeFrom, + "timeTo": f.timeTo, + } +} + +func validateSubmitFields(f *submitFields) error { + if f.reason == "" { + return errors.New("--reason is required") + } + + validPriorities := map[string]bool{"High": true, "Medium": true, "Low": true} + if !validPriorities[f.priority] { + return fmt.Errorf("--priority must be High, Medium, or Low (got %q)", f.priority) + } + + if f.date != "" { + if _, err := time.Parse("2006-01-02", f.date); err != nil { + return fmt.Errorf("--date must be in YYYY-MM-DD format (got %q)", f.date) + } + } + + if f.timeFrom != "" { + if _, err := time.Parse("15:04", f.timeFrom); err != nil { + return fmt.Errorf("--from must be in HH:MM format (got %q)", f.timeFrom) + } + } + + if f.timeTo != "" { + if _, err := time.Parse("15:04", f.timeTo); err != nil { + return fmt.Errorf("--to must be in HH:MM format (got %q)", f.timeTo) + } + } + + if f.timezone != "" { + if _, err := time.LoadLocation(f.timezone); err != nil { + return fmt.Errorf("--timezone must be a valid TZ identifier (e.g. America/New_York, UTC), got %q", f.timezone) + } + } + + return nil +} + +// buildCloudSelectionItems wraps cloud targets in selectionItems for the unified selector. +func buildCloudSelectionItems(targets []models.EligibleTarget) []selectionItem { + items := make([]selectionItem, len(targets)) + for i := range targets { + items[i] = selectionItem{ + kind: selectionCloud, + cloud: &targets[i], + } + } + return items +} diff --git a/cmd/request_test.go b/cmd/request_test.go new file mode 100644 index 0000000..a662931 --- /dev/null +++ b/cmd/request_test.go @@ -0,0 +1,735 @@ +package cmd + +import ( + "context" + "encoding/json" + "errors" + "strings" + "testing" + + "github.com/aaearon/grant-cli/internal/sca/models" + "github.com/aaearon/grant-cli/internal/ui" + wfmodels "github.com/aaearon/grant-cli/internal/workflows/models" + "github.com/spf13/cobra" +) + +func TestRequestListCommand(t *testing.T) { + tests := []struct { + name string + svc *mockAccessRequestService + args []string + wantContain []string + wantErr bool + }{ + { + name: "list requests text output", + svc: &mockAccessRequestService{ + listItems: []wfmodels.AccessRequest{ + { + RequestID: "req-1", + RequestState: wfmodels.RequestStatePending, + RequestResult: wfmodels.RequestResultUnknown, + RequestDetails: map[string]interface{}{ + "workspaceName": "Azure Subscription", + "roleName": "Contributor", + "priority": "Medium", + }, + CreatedBy: "user@test.com", + CreatedAt: "2025-08-12T09:41:00.594008", + }, + }, + listTotalCount: 1, + }, + args: []string{"list"}, + wantContain: []string{"req-1", "PENDING", "Azure Subscription", "Contributor", "Total: 1"}, + }, + { + name: "list requests JSON output", + svc: &mockAccessRequestService{ + listItems: []wfmodels.AccessRequest{ + { + RequestID: "req-1", + RequestState: wfmodels.RequestStatePending, + RequestResult: wfmodels.RequestResultUnknown, + CreatedBy: "user@test.com", + CreatedAt: "t", + UpdatedBy: "SYSTEM", + UpdatedAt: "t", + }, + }, + listTotalCount: 1, + }, + args: []string{"list", "--output", "json"}, + wantContain: []string{`"requestId"`, `"totalCount"`}, + }, + { + name: "list empty", + svc: &mockAccessRequestService{ + listItems: []wfmodels.AccessRequest{}, + listTotalCount: 0, + }, + args: []string{"list"}, + wantContain: []string{"No access requests found"}, + }, + { + name: "list error", + svc: &mockAccessRequestService{listErr: errors.New("API error")}, + args: []string{"list"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := NewRequestCommandWithDeps(tt.svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + args := append([]string{"request"}, tt.args...) + output, err := executeCommand(root, args...) + + if (err != nil) != tt.wantErr { + t.Fatalf("error = %v, wantErr %v\noutput: %s", err, tt.wantErr, output) + } + for _, want := range tt.wantContain { + if !strings.Contains(output, want) { + t.Errorf("output missing %q\ngot:\n%s", want, output) + } + } + }) + } +} + +func TestRequestGetCommand(t *testing.T) { + tests := []struct { + name string + svc *mockAccessRequestService + args []string + wantContain []string + wantErr bool + }{ + { + name: "get request text", + svc: &mockAccessRequestService{ + getResult: &wfmodels.AccessRequest{ + RequestID: "req-1", + RequestState: wfmodels.RequestStateFinished, + RequestResult: wfmodels.RequestResultApproved, + TargetCategory: "CLOUD_CONSOLE", + RequestDetails: map[string]interface{}{ + "workspaceName": "Azure Sub", + "roleName": "Reader", + "priority": "High", + "reason": "Need access", + }, + FinalizationReason: "Looks good", + CreatedBy: "user@test.com", + CreatedAt: "2025-08-12T09:41:00", + UpdatedBy: "SYSTEM", + UpdatedAt: "2025-08-12T09:42:00", + }, + }, + args: []string{"get", "req-1"}, + wantContain: []string{"req-1", "FINISHED", "APPROVED", "Azure Sub", "Reader", "Looks good"}, + }, + { + name: "get request JSON", + svc: &mockAccessRequestService{ + getResult: &wfmodels.AccessRequest{ + RequestID: "req-1", + RequestState: wfmodels.RequestStateFinished, + RequestResult: wfmodels.RequestResultApproved, + CreatedBy: "user@test.com", + CreatedAt: "t", + UpdatedBy: "SYSTEM", + UpdatedAt: "t", + }, + }, + args: []string{"get", "req-1", "--output", "json"}, + wantContain: []string{`"requestId"`, `"state"`}, + }, + { + name: "get no args", + svc: &mockAccessRequestService{}, + args: []string{"get"}, + wantErr: true, + }, + { + name: "get error", + svc: &mockAccessRequestService{getErr: errors.New("not found")}, + args: []string{"get", "bad-id"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := NewRequestCommandWithDeps(tt.svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + args := append([]string{"request"}, tt.args...) + output, err := executeCommand(root, args...) + + if (err != nil) != tt.wantErr { + t.Fatalf("error = %v, wantErr %v\noutput: %s", err, tt.wantErr, output) + } + for _, want := range tt.wantContain { + if !strings.Contains(output, want) { + t.Errorf("output missing %q\ngot:\n%s", want, output) + } + } + }) + } +} + +func TestRequestCancelCommand(t *testing.T) { + tests := []struct { + name string + svc *mockAccessRequestService + args []string + wantContain []string + wantErr bool + }{ + { + name: "cancel success", + svc: &mockAccessRequestService{ + cancelResult: &wfmodels.AccessRequest{ + RequestID: "req-1", + RequestResult: wfmodels.RequestResultCanceled, + }, + }, + args: []string{"cancel", "req-1"}, + wantContain: []string{"req-1", "canceled"}, + }, + { + name: "cancel with reason", + svc: &mockAccessRequestService{ + cancelResult: &wfmodels.AccessRequest{ + RequestID: "req-1", + RequestResult: wfmodels.RequestResultCanceled, + }, + }, + args: []string{"cancel", "req-1", "--reason", "no longer needed"}, + wantContain: []string{"req-1", "canceled"}, + }, + { + name: "cancel error", + svc: &mockAccessRequestService{cancelErr: errors.New("forbidden")}, + args: []string{"cancel", "req-1"}, + wantErr: true, + }, + { + name: "cancel no args", + svc: &mockAccessRequestService{}, + args: []string{"cancel"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := NewRequestCommandWithDeps(tt.svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + args := append([]string{"request"}, tt.args...) + output, err := executeCommand(root, args...) + + if (err != nil) != tt.wantErr { + t.Fatalf("error = %v, wantErr %v\noutput: %s", err, tt.wantErr, output) + } + for _, want := range tt.wantContain { + if !strings.Contains(output, want) { + t.Errorf("output missing %q\ngot:\n%s", want, output) + } + } + }) + } +} + +func TestRequestApproveCommand(t *testing.T) { + svc := &mockAccessRequestService{ + finalizeResult: &wfmodels.AccessRequest{ + RequestID: "req-1", + RequestResult: wfmodels.RequestResultApproved, + }, + } + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + output, err := executeCommand(root, "request", "approve", "req-1", "--reason", "looks good") + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, output) + } + if !strings.Contains(output, "approved") { + t.Errorf("expected 'approved' in output, got:\n%s", output) + } +} + +func TestRequestRejectCommand(t *testing.T) { + svc := &mockAccessRequestService{ + finalizeResult: &wfmodels.AccessRequest{ + RequestID: "req-1", + RequestResult: wfmodels.RequestResultRejected, + }, + } + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + output, err := executeCommand(root, "request", "reject", "req-1") + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, output) + } + if !strings.Contains(output, "rejected") { + t.Errorf("expected 'rejected' in output, got:\n%s", output) + } +} + +func TestRequestFinalizeJSON(t *testing.T) { + svc := &mockAccessRequestService{ + finalizeResult: &wfmodels.AccessRequest{ + RequestID: "req-1", + RequestState: wfmodels.RequestStateRunning, + RequestResult: wfmodels.RequestResultApproved, + CreatedBy: "user@test", + CreatedAt: "t", + UpdatedBy: "SYSTEM", + UpdatedAt: "t", + }, + } + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + output, err := executeCommand(root, "request", "approve", "req-1", "--output", "json") + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, output) + } + + var result accessRequestOutput + if err := json.Unmarshal([]byte(output), &result); err != nil { + t.Fatalf("failed to parse JSON: %v\nraw: %s", err, output) + } + if result.RequestID != "req-1" { + t.Errorf("requestId: got %q", result.RequestID) + } + if result.Result != "APPROVED" { + t.Errorf("result: got %q", result.Result) + } +} + +func TestRequestListJSON(t *testing.T) { + svc := &mockAccessRequestService{ + listItems: []wfmodels.AccessRequest{ + { + RequestID: "req-1", + RequestState: wfmodels.RequestStatePending, + RequestResult: wfmodels.RequestResultUnknown, + RequestDetails: map[string]interface{}{ + "priority": "High", + "reason": "Need access", + }, + CreatedBy: "user@test", + CreatedAt: "t", + UpdatedBy: "SYSTEM", + UpdatedAt: "t", + }, + }, + listTotalCount: 1, + } + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + output, err := executeCommand(root, "request", "list", "--output", "json") + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, output) + } + + var result accessRequestListOutput + if err := json.Unmarshal([]byte(output), &result); err != nil { + t.Fatalf("failed to parse JSON: %v\nraw: %s", err, output) + } + if result.TotalCount != 1 { + t.Errorf("totalCount: got %d", result.TotalCount) + } + if len(result.Requests) != 1 { + t.Fatalf("expected 1 request, got %d", len(result.Requests)) + } + if result.Requests[0].Priority != "High" { + t.Errorf("priority: got %q", result.Requests[0].Priority) + } +} + +func TestValidateSubmitFields(t *testing.T) { + tests := []struct { + name string + fields *submitFields + wantErr bool + }{ + {"valid", &submitFields{"need access", "High", "2026-04-21", "America/New_York", "09:00", "17:00"}, false}, + {"missing reason", &submitFields{"", "Medium", "2026-04-21", "America/New_York", "09:00", "17:00"}, true}, + {"bad priority", &submitFields{"reason", "Urgent", "2026-04-21", "America/New_York", "09:00", "17:00"}, true}, + {"bad date format", &submitFields{"reason", "Medium", "04-21-2026", "America/New_York", "09:00", "17:00"}, true}, + {"bad time from", &submitFields{"reason", "Medium", "2026-04-21", "America/New_York", "9am", "17:00"}, true}, + {"bad time to", &submitFields{"reason", "Medium", "2026-04-21", "America/New_York", "09:00", "5pm"}, true}, + {"bad timezone", &submitFields{"reason", "Medium", "2026-04-21", "Eastern", "09:00", "17:00"}, true}, + {"UTC timezone", &submitFields{"reason", "Medium", "2026-04-21", "UTC", "09:00", "17:00"}, false}, + {"CET timezone", &submitFields{"reason", "Medium", "2026-04-21", "CET", "09:00", "17:00"}, false}, + {"invalid timezone", &submitFields{"reason", "Medium", "2026-04-21", "NotAZone", "09:00", "17:00"}, true}, + {"empty optional fields", &submitFields{"reason", "Medium", "", "", "", ""}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateSubmitFields(tt.fields) + if (err != nil) != tt.wantErr { + t.Errorf("validateSubmitFields() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestFormatTimestamp(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"2025-08-12T09:41:00.594008", "2025-08-12T09:41:00"}, + {"2025-08-12T09:41:00", "2025-08-12T09:41:00"}, + {"short", "short"}, + } + for _, tt := range tests { + if got := formatTimestamp(tt.input); got != tt.want { + t.Errorf("formatTimestamp(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestRequestCancelJSON(t *testing.T) { + svc := &mockAccessRequestService{ + cancelResult: &wfmodels.AccessRequest{ + RequestID: "req-1", + RequestResult: wfmodels.RequestResultCanceled, + CreatedBy: "user@test", + CreatedAt: "t", + UpdatedBy: "SYSTEM", + UpdatedAt: "t", + }, + } + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + output, err := executeCommand(root, "request", "cancel", "req-1", "--output", "json") + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, output) + } + + var result accessRequestOutput + if err := json.Unmarshal([]byte(output), &result); err != nil { + t.Fatalf("failed to parse JSON: %v\nraw: %s", err, output) + } + if result.Result != "CANCELED" { + t.Errorf("result: got %q", result.Result) + } +} + +func TestRequestListValidation(t *testing.T) { + svc := &mockAccessRequestService{} + + tests := []struct { + name string + args []string + wantErr string + }{ + {"invalid state", []string{"list", "--state", "INVALID"}, "--state must be one of"}, + {"invalid result", []string{"list", "--result", "BOGUS"}, "--result must be one of"}, + {"invalid priority", []string{"list", "--priority", "Urgent"}, "--priority must be one of"}, + {"invalid sort", []string{"list", "--sort", "badField"}, "--sort must be one of"}, + {"injection attempt", []string{"list", "--state", "') or 1=1--"}, "--state must be one of"}, + {"valid lowercase state", []string{"list", "--state", "running"}, ""}, + {"valid lowercase result", []string{"list", "--result", "approved"}, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + args := append([]string{"request"}, tt.args...) + _, err := executeCommand(root, args...) + + if tt.wantErr == "" { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } else { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error %q does not contain %q", err.Error(), tt.wantErr) + } + } + }) + } +} + +func TestRunRequestSubmit_NonInteractive(t *testing.T) { + original := resolveSubmitTargetFn + defer func() { resolveSubmitTargetFn = original }() + + resolveSubmitTargetFn = func(_ context.Context, _, _, _ string) (*models.EligibleTarget, error) { + return &models.EligibleTarget{ + WorkspaceName: "Test Sub", + WorkspaceID: "ws-1", + WorkspaceType: models.WorkspaceTypeSubscription, + CSP: models.CSPAzure, + RoleInfo: models.RoleInfo{ID: "role-1", Name: "Contributor"}, + }, nil + } + + svc := &mockAccessRequestService{ + submitResult: &wfmodels.AccessRequest{ + RequestID: "req-new", + RequestState: wfmodels.RequestStatePending, + }, + } + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + output, err := executeCommand(root, "request", "submit", + "--target", "Test Sub", "--role", "Contributor", + "--reason", "need access", "--date", "2026-04-21", + "--timezone", "UTC", "--from", "09:00", "--to", "17:00") + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, output) + } + if !strings.Contains(output, "req-new") { + t.Errorf("expected request ID in output, got:\n%s", output) + } +} + +func TestRunRequestSubmit_JSONOutput(t *testing.T) { + original := resolveSubmitTargetFn + defer func() { resolveSubmitTargetFn = original }() + + resolveSubmitTargetFn = func(_ context.Context, _, _, _ string) (*models.EligibleTarget, error) { + return &models.EligibleTarget{ + WorkspaceName: "Test Sub", + WorkspaceID: "ws-1", + CSP: models.CSPAzure, + RoleInfo: models.RoleInfo{ID: "role-1", Name: "Contributor"}, + }, nil + } + + svc := &mockAccessRequestService{ + submitResult: &wfmodels.AccessRequest{ + RequestID: "req-json", + RequestState: wfmodels.RequestStatePending, + RequestResult: wfmodels.RequestResultUnknown, + CreatedBy: "user@test", + CreatedAt: "t", + UpdatedBy: "SYSTEM", + UpdatedAt: "t", + }, + } + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + output, err := executeCommand(root, "request", "submit", + "--target", "Test Sub", "--role", "Contributor", + "--reason", "test", "--date", "2026-04-21", + "--timezone", "UTC", "--from", "09:00", "--to", "17:00", + "--output", "json") + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, output) + } + + var result accessRequestOutput + if err := json.Unmarshal([]byte(output), &result); err != nil { + t.Fatalf("failed to parse JSON: %v\nraw: %s", err, output) + } + if result.RequestID != "req-json" { + t.Errorf("requestId: got %q", result.RequestID) + } +} + +func TestRunRequestSubmit_ServiceError(t *testing.T) { + original := resolveSubmitTargetFn + defer func() { resolveSubmitTargetFn = original }() + + resolveSubmitTargetFn = func(_ context.Context, _, _, _ string) (*models.EligibleTarget, error) { + return &models.EligibleTarget{ + WorkspaceName: "Sub", + WorkspaceID: "ws-1", + CSP: models.CSPAzure, + RoleInfo: models.RoleInfo{ID: "r1", Name: "Reader"}, + }, nil + } + + svc := &mockAccessRequestService{ + submitErr: errors.New("API failure"), + } + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + _, err := executeCommand(root, "request", "submit", + "--target", "Sub", "--role", "Reader", + "--reason", "test", "--date", "2026-04-21", + "--timezone", "UTC", "--from", "09:00", "--to", "17:00") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "API failure") { + t.Errorf("error %q does not contain 'API failure'", err.Error()) + } +} + +func TestRunRequestSubmit_MissingFlags_NonInteractive(t *testing.T) { + original := resolveSubmitTargetFn + defer func() { resolveSubmitTargetFn = original }() + + resolveSubmitTargetFn = func(_ context.Context, _, _, _ string) (*models.EligibleTarget, error) { + return &models.EligibleTarget{ + WorkspaceName: "Sub", + WorkspaceID: "ws-1", + CSP: models.CSPAzure, + RoleInfo: models.RoleInfo{ID: "r1", Name: "Reader"}, + }, nil + } + + svc := &mockAccessRequestService{} + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + _, err := executeCommand(root, "request", "submit", + "--target", "Sub", "--role", "Reader", + "--reason", "test") + if err == nil { + t.Fatal("expected error for missing --date/--timezone/--from/--to, got nil") + } + if !strings.Contains(err.Error(), "non-interactive") { + t.Errorf("error %q does not mention non-interactive", err.Error()) + } +} + +func TestResolveSubmitFields_Interactive(t *testing.T) { + originalPrompt := submitPromptFn + defer func() { submitPromptFn = originalPrompt }() + + submitPromptFn = func() (*submitFields, error) { + return &submitFields{ + reason: "prompted reason", + priority: "High", + date: "2026-05-01", + timezone: "America/Chicago", + timeFrom: "10:00", + timeTo: "18:00", + }, nil + } + + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("reason", "", "") + cmd.Flags().String("priority", "Medium", "") + cmd.Flags().String("date", "", "") + cmd.Flags().String("timezone", "", "") + cmd.Flags().String("from", "", "") + cmd.Flags().String("to", "", "") + + originalTTY := ui.IsTerminalFunc + defer func() { ui.IsTerminalFunc = originalTTY }() + ui.IsTerminalFunc = func(fd uintptr) bool { return true } + + f, err := resolveSubmitFields(cmd) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if f.reason != "prompted reason" { + t.Errorf("reason: got %q", f.reason) + } + if f.priority != "High" { + t.Errorf("priority: got %q", f.priority) + } + if f.date != "2026-05-01" { + t.Errorf("date: got %q", f.date) + } +} + +func TestResolveSubmitFields_FlagOverridesPrompt(t *testing.T) { + originalPrompt := submitPromptFn + defer func() { submitPromptFn = originalPrompt }() + + submitPromptFn = func() (*submitFields, error) { + return &submitFields{ + reason: "prompted", + priority: "Low", + date: "2026-05-01", + timezone: "America/Chicago", + timeFrom: "10:00", + timeTo: "18:00", + }, nil + } + + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("reason", "", "") + cmd.Flags().String("priority", "Medium", "") + cmd.Flags().String("date", "", "") + cmd.Flags().String("timezone", "", "") + cmd.Flags().String("from", "", "") + cmd.Flags().String("to", "", "") + // Simulate user passing --reason flag + cmd.SetArgs([]string{"--reason", "flag reason"}) + _ = cmd.Execute() + + originalTTY := ui.IsTerminalFunc + defer func() { ui.IsTerminalFunc = originalTTY }() + ui.IsTerminalFunc = func(fd uintptr) bool { return true } + + f, err := resolveSubmitFields(cmd) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if f.reason != "flag reason" { + t.Errorf("reason: got %q, want 'flag reason'", f.reason) + } +} + +func TestResolveSubmitFields_NonInteractive_Error(t *testing.T) { + originalTTY := ui.IsTerminalFunc + defer func() { ui.IsTerminalFunc = originalTTY }() + ui.IsTerminalFunc = func(fd uintptr) bool { return false } + + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("reason", "", "") + cmd.Flags().String("priority", "Medium", "") + cmd.Flags().String("date", "", "") + cmd.Flags().String("timezone", "", "") + cmd.Flags().String("from", "", "") + cmd.Flags().String("to", "", "") + + _, err := resolveSubmitFields(cmd) + if err == nil { + t.Fatal("expected error for non-interactive mode") + } + if !strings.Contains(err.Error(), "non-interactive") { + t.Errorf("error %q does not mention non-interactive", err.Error()) + } +} diff --git a/cmd/test_mocks.go b/cmd/test_mocks.go index 3c45f4e..8a94395 100644 --- a/cmd/test_mocks.go +++ b/cmd/test_mocks.go @@ -6,6 +6,8 @@ import ( "sync" "github.com/aaearon/grant-cli/internal/sca/models" + "github.com/aaearon/grant-cli/internal/workflows" + wfmodels "github.com/aaearon/grant-cli/internal/workflows/models" "github.com/blang/semver" sdkmodels "github.com/cyberark/idsec-sdk-golang/pkg/models" authmodels "github.com/cyberark/idsec-sdk-golang/pkg/models/auth" @@ -238,6 +240,41 @@ func (m *mockSelfUpdater) UpdateSelf(current semver.Version, slug string) (*self return m.release, m.updateErr } +// mockAccessRequestService implements accessRequestService for testing +type mockAccessRequestService struct { + listItems []wfmodels.AccessRequest + listTotalCount int + listErr error + getResult *wfmodels.AccessRequest + getErr error + submitResult *wfmodels.AccessRequest + submitErr error + cancelResult *wfmodels.AccessRequest + cancelErr error + finalizeResult *wfmodels.AccessRequest + finalizeErr error +} + +func (m *mockAccessRequestService) ListRequests(_ context.Context, _ workflows.ListRequestsParams) ([]wfmodels.AccessRequest, int, error) { + return m.listItems, m.listTotalCount, m.listErr +} + +func (m *mockAccessRequestService) GetRequest(_ context.Context, _ string) (*wfmodels.AccessRequest, error) { + return m.getResult, m.getErr +} + +func (m *mockAccessRequestService) SubmitRequest(_ context.Context, _ *wfmodels.SubmitAccessRequest) (*wfmodels.AccessRequest, error) { + return m.submitResult, m.submitErr +} + +func (m *mockAccessRequestService) CancelRequest(_ context.Context, _ string, _ *string) (*wfmodels.AccessRequest, error) { + return m.cancelResult, m.cancelErr +} + +func (m *mockAccessRequestService) FinalizeRequest(_ context.Context, _, _ string, _ *string) (*wfmodels.AccessRequest, error) { + return m.finalizeResult, m.finalizeErr +} + // countingEligibilityLister wraps an eligibilityLister and counts calls per CSP. // Thread-safe for concurrent access from goroutines in fetchStatusData etc. type countingEligibilityLister struct { diff --git a/internal/workflows/logging_client.go b/internal/workflows/logging_client.go new file mode 100644 index 0000000..400652d --- /dev/null +++ b/internal/workflows/logging_client.go @@ -0,0 +1,62 @@ +package workflows + +import ( + "context" + "net/http" + "time" +) + +type logger interface { + Info(msg string, v ...interface{}) + Error(msg string, v ...interface{}) + Debug(msg string, v ...interface{}) +} + +type loggingClient struct { + inner httpClient + logger logger +} + +func newLoggingClient(inner httpClient, l logger) *loggingClient { + return &loggingClient{inner: inner, logger: l} +} + +func (c *loggingClient) Get(ctx context.Context, route string, params interface{}) (*http.Response, error) { + return c.do("GET", route, func() (*http.Response, error) { + return c.inner.Get(ctx, route, params) + }) +} + +func (c *loggingClient) Post(ctx context.Context, route string, body interface{}) (*http.Response, error) { + return c.do("POST", route, func() (*http.Response, error) { + return c.inner.Post(ctx, route, body) + }) +} + +func (c *loggingClient) do(method, route string, fn func() (*http.Response, error)) (*http.Response, error) { + c.logger.Info("%s %s", method, route) + start := time.Now() + + resp, err := fn() + elapsed := time.Since(start) + + if err != nil { + c.logger.Error("%s %s failed: %v", method, route, err) + return nil, err + } + + c.logger.Info("%s %s -> %d (%dms)", method, route, resp.StatusCode, elapsed.Milliseconds()) + if resp.Header != nil { + c.logger.Debug("Response headers: %v", redactHeaders(resp.Header)) + } + + return resp, nil +} + +func redactHeaders(h http.Header) http.Header { + redacted := h.Clone() + if redacted.Get("Authorization") != "" { + redacted.Set("Authorization", "Bearer [REDACTED]") + } + return redacted +} diff --git a/internal/workflows/models/cancel.go b/internal/workflows/models/cancel.go new file mode 100644 index 0000000..64df23b --- /dev/null +++ b/internal/workflows/models/cancel.go @@ -0,0 +1,6 @@ +package models + +// CancelAccessRequest is the request body for canceling an access request. +type CancelAccessRequest struct { + CancelReason *string `json:"cancelReason"` +} diff --git a/internal/workflows/models/finalize.go b/internal/workflows/models/finalize.go new file mode 100644 index 0000000..5f24cfd --- /dev/null +++ b/internal/workflows/models/finalize.go @@ -0,0 +1,7 @@ +package models + +// FinalizeAccessRequest is the request body for approving or rejecting an access request. +type FinalizeAccessRequest struct { + Result string `json:"result"` + FinalizationReason *string `json:"finalizationReason,omitempty"` +} diff --git a/internal/workflows/models/form.go b/internal/workflows/models/form.go new file mode 100644 index 0000000..deecf43 --- /dev/null +++ b/internal/workflows/models/form.go @@ -0,0 +1,55 @@ +package models + +import "encoding/json" + +// RequestFormResponse wraps the list of request forms returned by the API. +type RequestFormResponse struct { + RequestForms []RequestFormEntry `json:"requestForms"` +} + +// RequestFormEntry represents a single form entry for a target category and request type. +type RequestFormEntry struct { + TargetCategory string `json:"targetCategory"` + RequestType string `json:"requestType"` + RequestForm RequestForm `json:"requestForm"` +} + +// RequestForm contains the questions that make up an access request form. +type RequestForm struct { + Questions []FormQuestion `json:"questions"` +} + +// FormQuestion represents a single question in a request form. +type FormQuestion struct { + Key string `json:"key"` + Required json.RawMessage `json:"required"` + Default interface{} `json:"default,omitempty"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + ValueType string `json:"valueType,omitempty"` + ValueChoices []interface{} `json:"valueChoices,omitempty"` + Validators []Validator `json:"validators,omitempty"` +} + +// IsRequired returns true if the question is unconditionally required. +// Returns false for conditional requirements (which are JSON objects). +func (q *FormQuestion) IsRequired() bool { + if len(q.Required) == 0 { + return false + } + var b bool + if err := json.Unmarshal(q.Required, &b); err != nil { + return false + } + return b +} + +// Validator represents a validation rule for a form question. +type Validator struct { + Name string `json:"name"` + Regex string `json:"regex,omitempty"` + ErrorMessage string `json:"errorMessage,omitempty"` + Format string `json:"format,omitempty"` + MinLength *int `json:"minLength,omitempty"` + MaxLength *int `json:"maxLength,omitempty"` +} diff --git a/internal/workflows/models/form_test.go b/internal/workflows/models/form_test.go new file mode 100644 index 0000000..7cffe0d --- /dev/null +++ b/internal/workflows/models/form_test.go @@ -0,0 +1,107 @@ +package models + +import ( + "encoding/json" + "testing" +) + +func TestFormQuestion_IsRequired(t *testing.T) { + tests := []struct { + name string + required string + want bool + }{ + {"true", "true", true}, + {"false", "false", false}, + {"conditional object", `{"operator":"OR","conditions":[{"name":"regex_condition","key":"location_type","condition":"^(GCP|Azure)$"}]}`, false}, + {"empty", "", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := FormQuestion{Required: json.RawMessage(tt.required)} + if got := q.IsRequired(); got != tt.want { + t.Errorf("IsRequired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRequestFormResponse_Unmarshal(t *testing.T) { + raw := `{ + "requestForms": [{ + "targetCategory": "CLOUD_CONSOLE", + "requestType": "ON_DEMAND", + "requestForm": { + "questions": [ + { + "key": "reason", + "required": true, + "title": "Reason", + "valueType": "TEXT", + "validators": [{"name": "length_validator", "minLength": 0, "maxLength": 4096}] + }, + { + "key": "priority", + "required": true, + "default": "Medium", + "title": "Priority", + "valueType": "CHOICE", + "valueChoices": ["High", "Medium", "Low"] + }, + { + "key": "org_id", + "required": {"operator": "OR", "conditions": [{"name": "regex_condition", "key": "location_type", "condition": "^(GCP|Azure)$"}]}, + "title": "ORG Id", + "valueType": "STRING" + } + ] + } + }] + }` + + var resp RequestFormResponse + if err := json.Unmarshal([]byte(raw), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(resp.RequestForms) != 1 { + t.Fatalf("expected 1 form, got %d", len(resp.RequestForms)) + } + + form := resp.RequestForms[0] + if form.TargetCategory != "CLOUD_CONSOLE" { + t.Errorf("targetCategory: got %q", form.TargetCategory) + } + if form.RequestType != "ON_DEMAND" { + t.Errorf("requestType: got %q", form.RequestType) + } + + questions := form.RequestForm.Questions + if len(questions) != 3 { + t.Fatalf("expected 3 questions, got %d", len(questions)) + } + + if !questions[0].IsRequired() { + t.Error("reason should be required") + } + if questions[0].Title != "Reason" { + t.Errorf("reason title: got %q", questions[0].Title) + } + if len(questions[0].Validators) != 1 { + t.Fatalf("expected 1 validator for reason, got %d", len(questions[0].Validators)) + } + if questions[0].Validators[0].Name != "length_validator" { + t.Errorf("validator name: got %q", questions[0].Validators[0].Name) + } + + if !questions[1].IsRequired() { + t.Error("priority should be required") + } + if len(questions[1].ValueChoices) != 3 { + t.Errorf("priority choices: expected 3, got %d", len(questions[1].ValueChoices)) + } + + if questions[2].IsRequired() { + t.Error("org_id has conditional required, IsRequired should return false") + } +} diff --git a/internal/workflows/models/request.go b/internal/workflows/models/request.go new file mode 100644 index 0000000..f7dc12d --- /dev/null +++ b/internal/workflows/models/request.go @@ -0,0 +1,87 @@ +package models + +// RequestState represents the system activity status of an access request. +type RequestState string + +const ( + RequestStateStarting RequestState = "STARTING" + RequestStateRunning RequestState = "RUNNING" + RequestStatePending RequestState = "PENDING" + RequestStateFinished RequestState = "FINISHED" + RequestStateExpired RequestState = "EXPIRED" +) + +// RequestResult represents the outcome of an access request. +type RequestResult string + +const ( + RequestResultApproved RequestResult = "APPROVED" + RequestResultRejected RequestResult = "REJECTED" + RequestResultCanceled RequestResult = "CANCELED" + RequestResultFailed RequestResult = "FAILED" + RequestResultUnknown RequestResult = "UNKNOWN" +) + +// AccessRequest represents a single access request from the API. +type AccessRequest struct { + RequestID string `json:"requestId"` + TargetCategory string `json:"targetCategory"` + RequestState RequestState `json:"requestState"` + RequestResult RequestResult `json:"requestResult"` + RequestLink string `json:"requestLink,omitempty"` + RequestDetails map[string]interface{} `json:"requestDetails,omitempty"` + RequestApprovers []ApproverAction `json:"requestApprovers,omitempty"` + AssignedApprovers []Entity `json:"assignedApprovers,omitempty"` + Requester *Entity `json:"requester,omitempty"` + RequestOutcomes map[string]string `json:"requestOutcomes,omitempty"` + FinalizationReason string `json:"finalizationReason,omitempty"` + CreatedBy string `json:"createdBy"` + CreatedAt string `json:"createdAt"` + UpdatedBy string `json:"updatedBy"` + UpdatedAt string `json:"updatedAt"` +} + +// Entity represents a user or approver identity. +type Entity struct { + EntityID string `json:"entityId"` + EntityName string `json:"entityName"` + EntityDisplayName string `json:"entityDisplayName,omitempty"` + EntityEmail string `json:"entityEmail,omitempty"` + EntityDirectorySource *DirectorySource `json:"entityDirectorySource,omitempty"` +} + +// DirectorySource represents the directory source of an entity. +type DirectorySource struct { + DirectoryID string `json:"directoryId"` + DirectoryName string `json:"directoryName"` + DirectoryType string `json:"directoryType,omitempty"` +} + +// ApproverAction represents an action taken by an approver on a request. +type ApproverAction struct { + Approver Entity `json:"approver"` + Result RequestResult `json:"result"` +} + +// ListRequestsResponse represents the paginated response from the list requests endpoint. +type ListRequestsResponse struct { + Items []AccessRequest `json:"items"` + Count int `json:"count"` + TotalCount int `json:"totalCount"` +} + +// DetailString returns a human-readable detail from requestDetails for the given key. +func (r *AccessRequest) DetailString(key string) string { + if r.RequestDetails == nil { + return "" + } + v, ok := r.RequestDetails[key] + if !ok { + return "" + } + s, ok := v.(string) + if !ok { + return "" + } + return s +} diff --git a/internal/workflows/models/request_test.go b/internal/workflows/models/request_test.go new file mode 100644 index 0000000..4eb60b3 --- /dev/null +++ b/internal/workflows/models/request_test.go @@ -0,0 +1,139 @@ +package models + +import ( + "encoding/json" + "testing" +) + +func TestAccessRequest_UnmarshalJSON(t *testing.T) { + raw := `{ + "requestId": "8a45155d-0273-4bc8-8d45-9fe3f4d4de6d", + "targetCategory": "CLOUD_CONSOLE", + "requestState": "FINISHED", + "requestResult": "APPROVED", + "requestLink": "https://tenant.cyberark.cloud/userportal/ars", + "requestDetails": { + "locationType": "Azure", + "roleName": "Load Test Reader", + "workspaceName": "Azure Subscription", + "priority": "Low", + "reason": "I need access" + }, + "requestApprovers": [{ + "approver": { + "entityId": "279bd5a1-83db-4ec0-89e4-569396b0044c", + "entityName": "approver_2@cyberark.cloud", + "entityDisplayName": "Approver Two", + "entityEmail": "approver_2@cyberark.com", + "entityDirectorySource": { + "directoryId": "09B9A9B0-6CE8-465F-AB03-65766D33B05E", + "directoryName": "CyberArk Cloud Directory", + "directoryType": "CDS" + } + }, + "result": "APPROVED" + }], + "assignedApprovers": [{ + "entityId": "0e061076-c3bc-4027-ac35-f864d98cdef7", + "entityName": "approver_1@cyberark.cloud" + }], + "requestOutcomes": {"policyId": "aws_23f89256"}, + "finalizationReason": "Approved your access", + "createdBy": "user@cyberark.cloud", + "createdAt": "2025-08-12T09:41:00.594008", + "updatedBy": "SYSTEM", + "updatedAt": "2025-08-12T09:42:31.886399" + }` + + var req AccessRequest + if err := json.Unmarshal([]byte(raw), &req); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + tests := []struct { + name string + got string + want string + }{ + {"RequestID", req.RequestID, "8a45155d-0273-4bc8-8d45-9fe3f4d4de6d"}, + {"TargetCategory", req.TargetCategory, "CLOUD_CONSOLE"}, + {"RequestState", string(req.RequestState), "FINISHED"}, + {"RequestResult", string(req.RequestResult), "APPROVED"}, + {"RequestLink", req.RequestLink, "https://tenant.cyberark.cloud/userportal/ars"}, + {"FinalizationReason", req.FinalizationReason, "Approved your access"}, + {"CreatedBy", req.CreatedBy, "user@cyberark.cloud"}, + {"UpdatedBy", req.UpdatedBy, "SYSTEM"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.got != tt.want { + t.Errorf("got %q, want %q", tt.got, tt.want) + } + }) + } + + if len(req.RequestApprovers) != 1 { + t.Fatalf("expected 1 approver action, got %d", len(req.RequestApprovers)) + } + if req.RequestApprovers[0].Approver.EntityDisplayName != "Approver Two" { + t.Errorf("approver display name: got %q", req.RequestApprovers[0].Approver.EntityDisplayName) + } + if req.RequestApprovers[0].Approver.EntityDirectorySource.DirectoryType != "CDS" { + t.Errorf("directory type: got %q", req.RequestApprovers[0].Approver.EntityDirectorySource.DirectoryType) + } + + if len(req.AssignedApprovers) != 1 { + t.Fatalf("expected 1 assigned approver, got %d", len(req.AssignedApprovers)) + } + + if req.RequestOutcomes["policyId"] != "aws_23f89256" { + t.Errorf("outcomes policyId: got %q", req.RequestOutcomes["policyId"]) + } +} + +func TestAccessRequest_DetailString(t *testing.T) { + tests := []struct { + name string + details map[string]interface{} + key string + want string + }{ + {"existing key", map[string]interface{}{"reason": "need access"}, "reason", "need access"}, + {"missing key", map[string]interface{}{"reason": "need access"}, "priority", ""}, + {"nil details", nil, "reason", ""}, + {"non-string value", map[string]interface{}{"roleType": 0}, "roleType", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &AccessRequest{RequestDetails: tt.details} + if got := r.DetailString(tt.key); got != tt.want { + t.Errorf("DetailString(%q) = %q, want %q", tt.key, got, tt.want) + } + }) + } +} + +func TestListRequestsResponse_Unmarshal(t *testing.T) { + raw := `{ + "items": [ + {"requestId": "id-1", "requestState": "PENDING", "requestResult": "UNKNOWN", "createdBy": "a", "createdAt": "t", "updatedBy": "b", "updatedAt": "t"}, + {"requestId": "id-2", "requestState": "FINISHED", "requestResult": "APPROVED", "createdBy": "a", "createdAt": "t", "updatedBy": "b", "updatedAt": "t"} + ], + "count": 2, + "totalCount": 10 + }` + + var resp ListRequestsResponse + if err := json.Unmarshal([]byte(raw), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if len(resp.Items) != 2 { + t.Fatalf("expected 2 items, got %d", len(resp.Items)) + } + if resp.TotalCount != 10 { + t.Errorf("totalCount: got %d, want 10", resp.TotalCount) + } + if resp.Items[0].RequestState != RequestStatePending { + t.Errorf("item[0] state: got %q", resp.Items[0].RequestState) + } +} diff --git a/internal/workflows/models/submit.go b/internal/workflows/models/submit.go new file mode 100644 index 0000000..7a37d7d --- /dev/null +++ b/internal/workflows/models/submit.go @@ -0,0 +1,7 @@ +package models + +// SubmitAccessRequest is the request body for creating a new access request. +type SubmitAccessRequest struct { + TargetCategory string `json:"targetCategory"` + RequestDetails map[string]interface{} `json:"requestDetails"` +} diff --git a/internal/workflows/service.go b/internal/workflows/service.go new file mode 100644 index 0000000..9d7ea86 --- /dev/null +++ b/internal/workflows/service.go @@ -0,0 +1,288 @@ +// Package workflows provides the CyberArk Access Requests API client. +package workflows + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + + "github.com/aaearon/grant-cli/internal/workflows/models" + "github.com/cyberark/idsec-sdk-golang/pkg/auth" + "github.com/cyberark/idsec-sdk-golang/pkg/common" + "github.com/cyberark/idsec-sdk-golang/pkg/common/isp" + "github.com/cyberark/idsec-sdk-golang/pkg/services" +) + +type httpClient interface { + Get(ctx context.Context, route string, params interface{}) (*http.Response, error) + Post(ctx context.Context, route string, body interface{}) (*http.Response, error) +} + +// AccessRequestService provides access to the Access Requests API endpoints. +type AccessRequestService struct { + services.IdsecService + *services.IdsecBaseService + ispAuth *auth.IdsecISPAuth + httpClient httpClient +} + +// NewAccessRequestService creates a new Access Request Service instance. +func NewAccessRequestService(authenticators ...auth.IdsecAuth) (*AccessRequestService, error) { + svc := &AccessRequestService{} + var svcIface services.IdsecService = svc + + base, err := services.NewIdsecBaseService(svcIface, authenticators...) + if err != nil { + return nil, fmt.Errorf("failed to create base service: %w", err) + } + svc.IdsecBaseService = base + + ispAuthIface, err := base.Authenticator("isp") + if err != nil { + return nil, fmt.Errorf("isp authenticator required: %w", err) + } + + ispAuth, ok := ispAuthIface.(*auth.IdsecISPAuth) + if !ok { + return nil, errors.New("authenticator is not *auth.IdsecISPAuth") + } + svc.ispAuth = ispAuth + + client, err := isp.FromISPAuth(ispAuth, "uar", ".", "", svc.refreshAuth) + if err != nil { + return nil, fmt.Errorf("failed to create ISP client: %w", err) + } + + svc.httpClient = newLoggingClient(client, common.GetLogger("grant", -1)) + + return svc, nil +} + +// NewAccessRequestServiceWithClient creates a service with a custom HTTP client for testing. +func NewAccessRequestServiceWithClient(client httpClient) *AccessRequestService { + return &AccessRequestService{ + httpClient: client, + } +} + +func (s *AccessRequestService) refreshAuth(client *common.IdsecClient) error { + return isp.RefreshClient(client, s.ispAuth) +} + +// ServiceConfig returns the service configuration. +func (s *AccessRequestService) ServiceConfig() services.IdsecServiceConfig { + return ServiceConfig() +} + +// GetRequestForms retrieves the access request form structure. +// GET /api/workflows/request-forms +func (s *AccessRequestService) GetRequestForms(ctx context.Context, targetCategory, requestType string) (*models.RequestFormResponse, error) { + params := map[string]string{ + "targetCategory": targetCategory, + "requestType": requestType, + } + + resp, err := s.httpClient.Get(ctx, "/api/workflows/request-forms", params) + if err != nil { + return nil, fmt.Errorf("failed to get request forms: %w", err) + } + defer resp.Body.Close() + + if err := checkResponse(resp, "request forms"); err != nil { + return nil, err + } + + var result models.RequestFormResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode request forms response: %w", err) + } + + return &result, nil +} + +// ListRequestsParams holds query parameters for listing access requests. +type ListRequestsParams struct { + Filter string + FreeText string + Limit int + Offset int + RequestRole string + Sort string +} + +const defaultPageSize = 50 + +// ListRequests retrieves all access requests matching the given parameters, +// fetching all pages via offset/limit pagination. +// GET /api/workflows/requests +func (s *AccessRequestService) ListRequests(ctx context.Context, params ListRequestsParams) ([]models.AccessRequest, int, error) { + limit := params.Limit + if limit <= 0 { + limit = defaultPageSize + } + + var allItems []models.AccessRequest + totalCount := 0 + offset := params.Offset + + for range maxPages { + qp := make(map[string]string) + qp["limit"] = strconv.Itoa(limit) + qp["offset"] = strconv.Itoa(offset) + if params.Filter != "" { + qp["filter"] = params.Filter + } + if params.FreeText != "" { + qp["freeText"] = params.FreeText + } + if params.RequestRole != "" { + qp["requestRole"] = params.RequestRole + } + if params.Sort != "" { + qp["sort"] = params.Sort + } + + resp, err := s.httpClient.Get(ctx, "/api/workflows/requests", qp) + if err != nil { + return nil, 0, fmt.Errorf("failed to list requests: %w", err) + } + + if err := checkResponse(resp, "list requests"); err != nil { + resp.Body.Close() + return nil, 0, err + } + + var page models.ListRequestsResponse + if err := json.NewDecoder(resp.Body).Decode(&page); err != nil { + resp.Body.Close() + return nil, 0, fmt.Errorf("failed to decode list requests response: %w", err) + } + resp.Body.Close() + + allItems = append(allItems, page.Items...) + totalCount = page.TotalCount + + if len(allItems) >= page.TotalCount || len(page.Items) < limit { + break + } + offset += len(page.Items) + } + + return allItems, totalCount, nil +} + +// GetRequest retrieves a single access request by ID. +// GET /api/workflows/requests/{requestId} +func (s *AccessRequestService) GetRequest(ctx context.Context, requestID string) (*models.AccessRequest, error) { + route := "/api/workflows/requests/" + requestID + + resp, err := s.httpClient.Get(ctx, route, nil) + if err != nil { + return nil, fmt.Errorf("failed to get request: %w", err) + } + defer resp.Body.Close() + + if err := checkResponse(resp, "get request"); err != nil { + return nil, err + } + + var result models.AccessRequest + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode request response: %w", err) + } + + return &result, nil +} + +// SubmitRequest creates a new access request. +// POST /api/workflows/requests +func (s *AccessRequestService) SubmitRequest(ctx context.Context, req *models.SubmitAccessRequest) (*models.AccessRequest, error) { + if req == nil { + return nil, errors.New("submit request cannot be nil") + } + + resp, err := s.httpClient.Post(ctx, "/api/workflows/requests", req) + if err != nil { + return nil, fmt.Errorf("failed to submit request: %w", err) + } + defer resp.Body.Close() + + if err := checkResponse(resp, "submit request"); err != nil { + return nil, err + } + + var result models.AccessRequest + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode submit response: %w", err) + } + + return &result, nil +} + +// CancelRequest cancels an open access request. +// POST /api/workflows/requests/{requestId}/cancel +func (s *AccessRequestService) CancelRequest(ctx context.Context, requestID string, reason *string) (*models.AccessRequest, error) { + route := fmt.Sprintf("/api/workflows/requests/%s/cancel", requestID) + body := &models.CancelAccessRequest{CancelReason: reason} + + resp, err := s.httpClient.Post(ctx, route, body) + if err != nil { + return nil, fmt.Errorf("failed to cancel request: %w", err) + } + defer resp.Body.Close() + + if err := checkResponse(resp, "cancel request"); err != nil { + return nil, err + } + + var result models.AccessRequest + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode cancel response: %w", err) + } + + return &result, nil +} + +// FinalizeRequest approves or rejects an access request. +// POST /api/workflows/requests/{requestId}/finalize +func (s *AccessRequestService) FinalizeRequest(ctx context.Context, requestID, result string, reason *string) (*models.AccessRequest, error) { + route := fmt.Sprintf("/api/workflows/requests/%s/finalize", requestID) + body := &models.FinalizeAccessRequest{ + Result: result, + FinalizationReason: reason, + } + + resp, err := s.httpClient.Post(ctx, route, body) + if err != nil { + return nil, fmt.Errorf("failed to finalize request: %w", err) + } + defer resp.Body.Close() + + if err := checkResponse(resp, "finalize request"); err != nil { + return nil, err + } + + var reqResult models.AccessRequest + if err := json.NewDecoder(resp.Body).Decode(&reqResult); err != nil { + return nil, fmt.Errorf("failed to decode finalize response: %w", err) + } + + return &reqResult, nil +} + +const maxPages = 100 + +func checkResponse(resp *http.Response, operation string) error { + if resp.StatusCode == http.StatusOK { + return nil + } + body, readErr := io.ReadAll(io.LimitReader(resp.Body, 4096)) + if readErr != nil { + body = []byte("(failed to read response body)") + } + return fmt.Errorf("%s failed with status %d: %s", operation, resp.StatusCode, string(body)) +} diff --git a/internal/workflows/service_config.go b/internal/workflows/service_config.go new file mode 100644 index 0000000..5fe5e14 --- /dev/null +++ b/internal/workflows/service_config.go @@ -0,0 +1,14 @@ +package workflows + +import "github.com/cyberark/idsec-sdk-golang/pkg/services" + +// ServiceConfig returns the configuration for the Access Requests Service. +// It specifies the service name "access-requests" and requires the "isp" authenticator. +func ServiceConfig() services.IdsecServiceConfig { + return services.IdsecServiceConfig{ + ServiceName: "access-requests", + RequiredAuthenticatorNames: []string{"isp"}, + OptionalAuthenticatorNames: []string{}, + ActionsConfigurations: nil, + } +} diff --git a/internal/workflows/service_test.go b/internal/workflows/service_test.go new file mode 100644 index 0000000..43cb5c9 --- /dev/null +++ b/internal/workflows/service_test.go @@ -0,0 +1,340 @@ +package workflows + +import ( + "context" + "encoding/json" + "io" + "net/http" + "strings" + "testing" + + "github.com/aaearon/grant-cli/internal/workflows/models" +) + +type mockHTTPClient struct { + getFn func(ctx context.Context, route string, params interface{}) (*http.Response, error) + postFn func(ctx context.Context, route string, body interface{}) (*http.Response, error) +} + +func (m *mockHTTPClient) Get(ctx context.Context, route string, params interface{}) (*http.Response, error) { + return m.getFn(ctx, route, params) +} + +func (m *mockHTTPClient) Post(ctx context.Context, route string, body interface{}) (*http.Response, error) { + return m.postFn(ctx, route, body) +} + +func jsonResponse(status int, body interface{}) *http.Response { + b, _ := json.Marshal(body) + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(strings.NewReader(string(b))), + Header: make(http.Header), + } +} + +func TestGetRequestForms(t *testing.T) { + expected := models.RequestFormResponse{ + RequestForms: []models.RequestFormEntry{ + { + TargetCategory: "CLOUD_CONSOLE", + RequestType: "ON_DEMAND", + RequestForm: models.RequestForm{ + Questions: []models.FormQuestion{ + {Key: "reason", Title: "Reason", ValueType: "TEXT"}, + }, + }, + }, + }, + } + + mock := &mockHTTPClient{ + getFn: func(_ context.Context, route string, params interface{}) (*http.Response, error) { + if route != "/api/workflows/request-forms" { + t.Errorf("unexpected route: %s", route) + } + p := params.(map[string]string) + if p["targetCategory"] != "CLOUD_CONSOLE" { + t.Errorf("expected targetCategory CLOUD_CONSOLE, got %s", p["targetCategory"]) + } + if p["requestType"] != "ON_DEMAND" { + t.Errorf("expected requestType ON_DEMAND, got %s", p["requestType"]) + } + return jsonResponse(200, expected), nil + }, + } + + svc := NewAccessRequestServiceWithClient(mock) + result, err := svc.GetRequestForms(t.Context(), "CLOUD_CONSOLE", "ON_DEMAND") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result.RequestForms) != 1 { + t.Fatalf("expected 1 form, got %d", len(result.RequestForms)) + } + if result.RequestForms[0].RequestForm.Questions[0].Key != "reason" { + t.Errorf("unexpected question key: %s", result.RequestForms[0].RequestForm.Questions[0].Key) + } +} + +func TestListRequests(t *testing.T) { + callCount := 0 + mock := &mockHTTPClient{ + getFn: func(_ context.Context, _ string, params interface{}) (*http.Response, error) { + callCount++ + p := params.(map[string]string) + + if callCount == 1 { + if p["offset"] != "0" { + t.Errorf("first call offset: got %s, want 0", p["offset"]) + } + return jsonResponse(200, models.ListRequestsResponse{ + Items: []models.AccessRequest{ + {RequestID: "id-1", CreatedBy: "user@test"}, + {RequestID: "id-2", CreatedBy: "user@test"}, + }, + Count: 2, + TotalCount: 3, + }), nil + } + + if p["offset"] != "2" { + t.Errorf("second call offset: got %s, want 2", p["offset"]) + } + return jsonResponse(200, models.ListRequestsResponse{ + Items: []models.AccessRequest{{RequestID: "id-3", CreatedBy: "user@test"}}, + Count: 1, + TotalCount: 3, + }), nil + }, + } + + svc := NewAccessRequestServiceWithClient(mock) + items, total, err := svc.ListRequests(t.Context(), ListRequestsParams{Limit: 2}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(items) != 3 { + t.Errorf("expected 3 items, got %d", len(items)) + } + if total != 3 { + t.Errorf("expected totalCount 3, got %d", total) + } + if callCount != 2 { + t.Errorf("expected 2 API calls, got %d", callCount) + } +} + +func TestListRequests_WithFilters(t *testing.T) { + mock := &mockHTTPClient{ + getFn: func(_ context.Context, _ string, params interface{}) (*http.Response, error) { + p := params.(map[string]string) + if p["filter"] != "(requestState eq PENDING)" { + t.Errorf("filter: got %q", p["filter"]) + } + if p["freeText"] != "azure" { + t.Errorf("freeText: got %q", p["freeText"]) + } + if p["requestRole"] != "CREATOR" { + t.Errorf("requestRole: got %q", p["requestRole"]) + } + if p["sort"] != "createdAt desc" { + t.Errorf("sort: got %q", p["sort"]) + } + return jsonResponse(200, models.ListRequestsResponse{ + Items: []models.AccessRequest{{RequestID: "id-1"}}, + Count: 1, + TotalCount: 1, + }), nil + }, + } + + svc := NewAccessRequestServiceWithClient(mock) + items, _, err := svc.ListRequests(t.Context(), ListRequestsParams{ + Filter: "(requestState eq PENDING)", + FreeText: "azure", + RequestRole: "CREATOR", + Sort: "createdAt desc", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(items) != 1 { + t.Errorf("expected 1 item, got %d", len(items)) + } +} + +func TestGetRequest(t *testing.T) { + expected := models.AccessRequest{ + RequestID: "8a45155d-0273-4bc8-8d45-9fe3f4d4de6d", + RequestState: models.RequestStateFinished, + RequestResult: models.RequestResultApproved, + CreatedBy: "user@test", + CreatedAt: "2025-08-12T09:41:00", + UpdatedBy: "SYSTEM", + UpdatedAt: "2025-08-12T09:42:31", + } + + mock := &mockHTTPClient{ + getFn: func(_ context.Context, route string, _ interface{}) (*http.Response, error) { + if route != "/api/workflows/requests/8a45155d-0273-4bc8-8d45-9fe3f4d4de6d" { + t.Errorf("unexpected route: %s", route) + } + return jsonResponse(200, expected), nil + }, + } + + svc := NewAccessRequestServiceWithClient(mock) + result, err := svc.GetRequest(t.Context(), "8a45155d-0273-4bc8-8d45-9fe3f4d4de6d") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.RequestID != expected.RequestID { + t.Errorf("requestId: got %q, want %q", result.RequestID, expected.RequestID) + } + if result.RequestState != models.RequestStateFinished { + t.Errorf("state: got %q", result.RequestState) + } +} + +func TestSubmitRequest(t *testing.T) { + mock := &mockHTTPClient{ + postFn: func(_ context.Context, route string, body interface{}) (*http.Response, error) { + if route != "/api/workflows/requests" { + t.Errorf("unexpected route: %s", route) + } + req := body.(*models.SubmitAccessRequest) + if req.TargetCategory != "CLOUD_CONSOLE" { + t.Errorf("targetCategory: got %q", req.TargetCategory) + } + return jsonResponse(200, models.AccessRequest{ + RequestID: "new-id", + RequestState: models.RequestStateStarting, + RequestResult: models.RequestResultUnknown, + CreatedBy: "user@test", + CreatedAt: "2025-08-12T09:41:00", + UpdatedBy: "SYSTEM", + UpdatedAt: "2025-08-12T09:41:00", + }), nil + }, + } + + svc := NewAccessRequestServiceWithClient(mock) + result, err := svc.SubmitRequest(t.Context(), &models.SubmitAccessRequest{ + TargetCategory: "CLOUD_CONSOLE", + RequestDetails: map[string]interface{}{"reason": "need access"}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.RequestID != "new-id" { + t.Errorf("requestId: got %q", result.RequestID) + } + if result.RequestState != models.RequestStateStarting { + t.Errorf("state: got %q", result.RequestState) + } +} + +func TestSubmitRequest_NilRequest(t *testing.T) { + svc := NewAccessRequestServiceWithClient(&mockHTTPClient{}) + _, err := svc.SubmitRequest(t.Context(), nil) + if err == nil { + t.Fatal("expected error for nil request") + } +} + +func TestCancelRequest(t *testing.T) { + mock := &mockHTTPClient{ + postFn: func(_ context.Context, route string, body interface{}) (*http.Response, error) { + if route != "/api/workflows/requests/req-id/cancel" { + t.Errorf("unexpected route: %s", route) + } + cancelReq := body.(*models.CancelAccessRequest) + if cancelReq.CancelReason == nil || *cancelReq.CancelReason != "no longer needed" { + t.Errorf("unexpected cancel reason") + } + return jsonResponse(200, models.AccessRequest{ + RequestID: "req-id", + RequestState: models.RequestStateRunning, + RequestResult: models.RequestResultCanceled, + CreatedBy: "user@test", + CreatedAt: "t", + UpdatedBy: "SYSTEM", + UpdatedAt: "t", + }), nil + }, + } + + svc := NewAccessRequestServiceWithClient(mock) + reason := "no longer needed" + result, err := svc.CancelRequest(t.Context(), "req-id", &reason) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.RequestResult != models.RequestResultCanceled { + t.Errorf("result: got %q", result.RequestResult) + } +} + +func TestFinalizeRequest(t *testing.T) { + tests := []struct { + name string + result string + wantResult models.RequestResult + }{ + {"approve", "APPROVED", models.RequestResultApproved}, + {"reject", "REJECTED", models.RequestResultRejected}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := &mockHTTPClient{ + postFn: func(_ context.Context, route string, body interface{}) (*http.Response, error) { + if !strings.HasSuffix(route, "/finalize") { + t.Errorf("route should end with /finalize: %s", route) + } + fin := body.(*models.FinalizeAccessRequest) + if fin.Result != tt.result { + t.Errorf("result: got %q, want %q", fin.Result, tt.result) + } + return jsonResponse(200, models.AccessRequest{ + RequestID: "req-id", + RequestState: models.RequestStateRunning, + RequestResult: models.RequestResult(tt.result), + CreatedBy: "user@test", + CreatedAt: "t", + UpdatedBy: "SYSTEM", + UpdatedAt: "t", + }), nil + }, + } + + svc := NewAccessRequestServiceWithClient(mock) + reason := "looks good" + result, err := svc.FinalizeRequest(t.Context(), "req-id", tt.result, &reason) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.RequestResult != tt.wantResult { + t.Errorf("result: got %q, want %q", result.RequestResult, tt.wantResult) + } + }) + } +} + +func TestCheckResponse_Error(t *testing.T) { + resp := &http.Response{ + StatusCode: 401, + Body: io.NopCloser(strings.NewReader(`{"code":"UNAUTHORIZED","message":"Unauthorized"}`)), + } + err := checkResponse(resp, "test operation") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "401") { + t.Errorf("error should contain status code: %v", err) + } + if !strings.Contains(err.Error(), "UNAUTHORIZED") { + t.Errorf("error should contain response body: %v", err) + } +} From 39b4c420496efa71d1a54a08a387ddf9209df56b Mon Sep 17 00:00:00 2001 From: Tim Schindler Date: Mon, 20 Apr 2026 17:13:42 +0200 Subject: [PATCH 2/9] feat: harden grant request submit interactive UX - Resolve real IANA timezone instead of "Local" default - Validate provider flag via parseProvider() before API calls - Add context.WithTimeout to prevent indefinite hangs - Add confirmation prompt with --yes flag to skip - Add inline survey validators for date/time/timezone formats - Handle partial --target/--role by filtering selector - Prompt only for missing fields (pass existing to promptFn) - Add IsInteractive() guard in defaultSubmitPrompt() - Reorder prompts: timezone before date for correct defaults - Enforce presence of date/timezone/from/to in validation --- CLAUDE.md | 2 +- cmd/request_submit.go | 249 +++++++++++++++++++++++++++++++--------- cmd/request_test.go | 258 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 449 insertions(+), 60 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index ce812c9..d2500dd 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -74,7 +74,7 @@ Custom `SCAAccessService` follows SDK conventions: - `grant list` — list eligible targets and groups without triggering elevation; supports `--provider`, `--groups`, `--refresh`, `--output json`; used by LLMs to discover available targets programmatically - `grant revoke` — revoke sessions: direct (`grant revoke `), `--all`, or interactive multi-select; `--yes` skips confirmation - `grant request` — manage access requests through approval workflow; subcommands: `submit`, `list`, `get`, `cancel`, `approve`, `reject` -- `grant request submit` — submit access request; reuses SCA eligibility for target selection; flags: `--target`, `--role`, `--provider`, `--reason`, `--priority`, `--date`, `--timezone`, `--from`, `--to` +- `grant request submit` — submit access request; reuses SCA eligibility for target selection; shows summary + confirmation before submitting; partial `--target` or `--role` filters the selector; flags: `--target`, `--role`, `--provider`, `--reason`, `--priority`, `--date`, `--timezone`, `--from`, `--to`, `--yes` - `grant request list` — list access requests; flags: `--state`, `--result`, `--priority`, `--role` (CREATOR/APPROVER), `--search`, `--sort`, `--desc` - `grant request get ` — get full request details - `grant request cancel ` — cancel an open request; optional `--reason` diff --git a/cmd/request_submit.go b/cmd/request_submit.go index 2bcf0fc..142a5de 100644 --- a/cmd/request_submit.go +++ b/cmd/request_submit.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "os" + "strings" "time" survey "github.com/Iilun/survey/v2" @@ -41,6 +42,7 @@ func newRequestSubmitCommand(svc accessRequestService) *cobra.Command { cmd.Flags().String("timezone", "", "Timezone (TZ identifier, e.g. America/New_York)") cmd.Flags().String("from", "", "Start time (HH:MM)") cmd.Flags().String("to", "", "End time (HH:MM)") + cmd.Flags().BoolP("yes", "y", false, "Skip confirmation prompt") return cmd } @@ -48,6 +50,9 @@ func newRequestSubmitCommand(svc accessRequestService) *cobra.Command { // submitPromptFn is injectable for testing interactive prompts. var submitPromptFn = defaultSubmitPrompt +// confirmSubmitFn is injectable for testing the confirmation prompt. +var confirmSubmitFn = confirmSubmit + // resolveSubmitTargetFn is injectable for testing target resolution. var resolveSubmitTargetFn = resolveSubmitTarget @@ -60,57 +65,128 @@ type submitFields struct { timeTo string } -func defaultSubmitPrompt() (*submitFields, error) { +func resolveLocalTimezone() string { + tz := time.Now().Location().String() + if tz == "Local" { + if env := os.Getenv("TZ"); env != "" { + return env + } + return "UTC" + } + return tz +} + +func defaultSubmitPrompt(existing *submitFields) (*submitFields, error) { + if !ui.IsInteractive() { + return nil, fmt.Errorf("%w; use --reason, --date, --timezone, --from, --to flags", ui.ErrNotInteractive) + } + stdio := survey.WithStdio(os.Stdin, os.Stderr, os.Stderr) + f := &submitFields{} - var reason string - if err := survey.AskOne(&survey.Input{Message: "Reason:"}, &reason, survey.WithValidator(survey.Required), stdio); err != nil { - return nil, err + // 1. Reason + if existing.reason == "" { + if err := survey.AskOne(&survey.Input{Message: "Reason:"}, &f.reason, survey.WithValidator(survey.Required), stdio); err != nil { + return nil, err + } } - var priority string - if err := survey.AskOne(&survey.Select{ - Message: "Priority:", - Options: []string{"High", "Medium", "Low"}, - Default: "Medium", - }, &priority, stdio); err != nil { - return nil, err + // 2. Priority + if existing.priority == "" || existing.priority == "Medium" { + var priority string + if err := survey.AskOne(&survey.Select{ + Message: "Priority:", + Options: []string{"High", "Medium", "Low"}, + Default: "Medium", + }, &priority, stdio); err != nil { + return nil, err + } + f.priority = priority } - today := time.Now().Format("2006-01-02") - var date string - if err := survey.AskOne(&survey.Input{Message: "Date (YYYY-MM-DD):", Default: today}, &date, stdio); err != nil { - return nil, err + // 3. Timezone (before date so we can compute correct default date) + if existing.timezone == "" { + localTZ := resolveLocalTimezone() + if err := survey.AskOne(&survey.Input{Message: "Timezone:", Default: localTZ}, &f.timezone, + survey.WithValidator(func(val interface{}) error { + s, _ := val.(string) + if _, err := time.LoadLocation(s); err != nil { + return errors.New("must be a valid timezone (e.g. America/New_York, UTC)") + } + return nil + }), stdio); err != nil { + return nil, err + } } - localTZ := time.Now().Location().String() - var timezone string - if err := survey.AskOne(&survey.Input{Message: "Timezone:", Default: localTZ}, &timezone, stdio); err != nil { - return nil, err + // 4. Date (default: today in selected timezone) + if existing.date == "" { + tz := f.timezone + if tz == "" { + tz = existing.timezone + } + loc, _ := time.LoadLocation(tz) + today := time.Now().In(loc).Format("2006-01-02") + if err := survey.AskOne(&survey.Input{Message: "Date (YYYY-MM-DD):", Default: today}, &f.date, + survey.WithValidator(func(val interface{}) error { + s, _ := val.(string) + if _, err := time.Parse("2006-01-02", s); err != nil { + return errors.New("must be YYYY-MM-DD format") + } + return nil + }), stdio); err != nil { + return nil, err + } } - var timeFrom string - if err := survey.AskOne(&survey.Input{Message: "Start time (HH:MM):"}, &timeFrom, survey.WithValidator(survey.Required), stdio); err != nil { - return nil, err + // 5. Start time + if existing.timeFrom == "" { + if err := survey.AskOne(&survey.Input{Message: "Start time (HH:MM):"}, &f.timeFrom, + survey.WithValidator(func(val interface{}) error { + s, _ := val.(string) + if _, err := time.Parse("15:04", s); err != nil { + return errors.New("must be HH:MM format") + } + return nil + }), stdio); err != nil { + return nil, err + } } - var timeTo string - if err := survey.AskOne(&survey.Input{Message: "End time (HH:MM):"}, &timeTo, survey.WithValidator(survey.Required), stdio); err != nil { - return nil, err + // 6. End time + if existing.timeTo == "" { + if err := survey.AskOne(&survey.Input{Message: "End time (HH:MM):"}, &f.timeTo, + survey.WithValidator(func(val interface{}) error { + s, _ := val.(string) + if _, err := time.Parse("15:04", s); err != nil { + return errors.New("must be HH:MM format") + } + return nil + }), stdio); err != nil { + return nil, err + } } - return &submitFields{ - reason: reason, - priority: priority, - date: date, - timezone: timezone, - timeFrom: timeFrom, - timeTo: timeTo, - }, nil + return f, nil +} + +func confirmSubmit() (bool, error) { + if !ui.IsInteractive() { + return false, fmt.Errorf("%w; use --yes to skip confirmation", ui.ErrNotInteractive) + } + var confirmed bool + err := survey.AskOne(&survey.Confirm{Message: "Submit this request?"}, &confirmed, + survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) + return confirmed, err } func runRequestSubmit(cmd *cobra.Command, svc accessRequestService) error { - ctx := cmd.Context() + provider, _ := cmd.Flags().GetString("provider") + if provider != "" { + if _, err := parseProvider(provider); err != nil { + return err + } + } fields, err := resolveSubmitFields(cmd) if err != nil { @@ -121,15 +197,39 @@ func runRequestSubmit(cmd *cobra.Command, svc accessRequestService) error { return err } - provider, _ := cmd.Flags().GetString("provider") targetName, _ := cmd.Flags().GetString("target") roleName, _ := cmd.Flags().GetString("role") + ctx, cancel := context.WithTimeout(cmd.Context(), apiTimeout) + defer cancel() + target, err := resolveSubmitTargetFn(ctx, provider, targetName, roleName) if err != nil { return err } + // Summary before submission + if !isJSONOutput() { + fmt.Fprintf(cmd.ErrOrStderr(), "\nTarget: %s / %s\n", target.WorkspaceName, target.RoleInfo.Name) + fmt.Fprintf(cmd.ErrOrStderr(), "Date: %s\n", fields.date) + fmt.Fprintf(cmd.ErrOrStderr(), "Time: %s – %s (%s)\n", fields.timeFrom, fields.timeTo, fields.timezone) + fmt.Fprintf(cmd.ErrOrStderr(), "Priority: %s\n", fields.priority) + fmt.Fprintf(cmd.ErrOrStderr(), "Reason: %s\n\n", fields.reason) + } + + // Confirmation + yesFlag, _ := cmd.Flags().GetBool("yes") + if !yesFlag && !isJSONOutput() { + confirmed, confirmErr := confirmSubmitFn() + if confirmErr != nil { + return confirmErr + } + if !confirmed { + fmt.Fprintln(cmd.OutOrStdout(), "Submission canceled.") + return nil + } + } + details := buildRequestDetails(target, fields) log.Info("Submitting access request for %s / %s", target.WorkspaceName, target.RoleInfo.Name) @@ -172,7 +272,7 @@ func resolveSubmitFields(cmd *cobra.Command) (*submitFields, error) { return nil, errors.New("non-interactive mode requires --reason, --date, --timezone, --from, --to") } - prompted, err := submitPromptFn() + prompted, err := submitPromptFn(f) if err != nil { return nil, err } @@ -209,7 +309,10 @@ func resolveSubmitTarget(ctx context.Context, provider, targetName, roleName str } cachedLister := buildCachedLister(cfg, false, scaSvc, nil) - targets, err := fetchEligibility(ctx, cachedLister, provider) + fetchCtx, fetchCancel := context.WithTimeout(ctx, apiTimeout) + defer fetchCancel() + + targets, err := fetchEligibility(fetchCtx, cachedLister, provider) if err != nil { return nil, fmt.Errorf("failed to fetch eligibility: %w", err) } @@ -223,6 +326,42 @@ func resolveSubmitTarget(ctx context.Context, provider, targetName, roleName str return target, nil } + // Partial --target: filter to matching workspaces + if targetName != "" && roleName == "" { + var filtered []models.EligibleTarget + for i := range targets { + if strings.EqualFold(targets[i].WorkspaceName, targetName) { + filtered = append(filtered, targets[i]) + } + } + if len(filtered) == 0 { + return nil, fmt.Errorf("no eligible target found matching target=%q", targetName) + } + if len(filtered) == 1 { + resolveTargetCSP(&filtered[0], targets, provider) + return &filtered[0], nil + } + targets = filtered + } + + // Partial --role: filter to matching roles + if roleName != "" && targetName == "" { + var filtered []models.EligibleTarget + for i := range targets { + if strings.EqualFold(targets[i].RoleInfo.Name, roleName) { + filtered = append(filtered, targets[i]) + } + } + if len(filtered) == 0 { + return nil, fmt.Errorf("no eligible target found matching role=%q", roleName) + } + if len(filtered) == 1 { + resolveTargetCSP(&filtered[0], targets, provider) + return &filtered[0], nil + } + targets = filtered + } + if !ui.IsInteractive() { return nil, errors.New("non-interactive mode requires --target and --role") } @@ -273,28 +412,32 @@ func validateSubmitFields(f *submitFields) error { return fmt.Errorf("--priority must be High, Medium, or Low (got %q)", f.priority) } - if f.date != "" { - if _, err := time.Parse("2006-01-02", f.date); err != nil { - return fmt.Errorf("--date must be in YYYY-MM-DD format (got %q)", f.date) - } + if f.date == "" { + return errors.New("--date is required") + } + if _, err := time.Parse("2006-01-02", f.date); err != nil { + return fmt.Errorf("--date must be in YYYY-MM-DD format (got %q)", f.date) } - if f.timeFrom != "" { - if _, err := time.Parse("15:04", f.timeFrom); err != nil { - return fmt.Errorf("--from must be in HH:MM format (got %q)", f.timeFrom) - } + if f.timezone == "" { + return errors.New("--timezone is required") + } + if _, err := time.LoadLocation(f.timezone); err != nil { + return fmt.Errorf("--timezone must be a valid TZ identifier (e.g. America/New_York, UTC), got %q", f.timezone) } - if f.timeTo != "" { - if _, err := time.Parse("15:04", f.timeTo); err != nil { - return fmt.Errorf("--to must be in HH:MM format (got %q)", f.timeTo) - } + if f.timeFrom == "" { + return errors.New("--from is required") + } + if _, err := time.Parse("15:04", f.timeFrom); err != nil { + return fmt.Errorf("--from must be in HH:MM format (got %q)", f.timeFrom) } - if f.timezone != "" { - if _, err := time.LoadLocation(f.timezone); err != nil { - return fmt.Errorf("--timezone must be a valid TZ identifier (e.g. America/New_York, UTC), got %q", f.timezone) - } + if f.timeTo == "" { + return errors.New("--to is required") + } + if _, err := time.Parse("15:04", f.timeTo); err != nil { + return fmt.Errorf("--to must be in HH:MM format (got %q)", f.timeTo) } return nil diff --git a/cmd/request_test.go b/cmd/request_test.go index a662931..5997770 100644 --- a/cmd/request_test.go +++ b/cmd/request_test.go @@ -384,7 +384,10 @@ func TestValidateSubmitFields(t *testing.T) { {"UTC timezone", &submitFields{"reason", "Medium", "2026-04-21", "UTC", "09:00", "17:00"}, false}, {"CET timezone", &submitFields{"reason", "Medium", "2026-04-21", "CET", "09:00", "17:00"}, false}, {"invalid timezone", &submitFields{"reason", "Medium", "2026-04-21", "NotAZone", "09:00", "17:00"}, true}, - {"empty optional fields", &submitFields{"reason", "Medium", "", "", "", ""}, false}, + {"missing date", &submitFields{"reason", "Medium", "", "UTC", "09:00", "17:00"}, true}, + {"missing timezone", &submitFields{"reason", "Medium", "2026-04-21", "", "09:00", "17:00"}, true}, + {"missing from", &submitFields{"reason", "Medium", "2026-04-21", "UTC", "", "17:00"}, true}, + {"missing to", &submitFields{"reason", "Medium", "2026-04-21", "UTC", "09:00", ""}, true}, } for _, tt := range tests { @@ -513,7 +516,8 @@ func TestRunRequestSubmit_NonInteractive(t *testing.T) { output, err := executeCommand(root, "request", "submit", "--target", "Test Sub", "--role", "Contributor", "--reason", "need access", "--date", "2026-04-21", - "--timezone", "UTC", "--from", "09:00", "--to", "17:00") + "--timezone", "UTC", "--from", "09:00", "--to", "17:00", + "--yes") if err != nil { t.Fatalf("unexpected error: %v\noutput: %s", err, output) } @@ -555,7 +559,7 @@ func TestRunRequestSubmit_JSONOutput(t *testing.T) { "--target", "Test Sub", "--role", "Contributor", "--reason", "test", "--date", "2026-04-21", "--timezone", "UTC", "--from", "09:00", "--to", "17:00", - "--output", "json") + "--output", "json", "--yes") if err != nil { t.Fatalf("unexpected error: %v\noutput: %s", err, output) } @@ -593,7 +597,8 @@ func TestRunRequestSubmit_ServiceError(t *testing.T) { _, err := executeCommand(root, "request", "submit", "--target", "Sub", "--role", "Reader", "--reason", "test", "--date", "2026-04-21", - "--timezone", "UTC", "--from", "09:00", "--to", "17:00") + "--timezone", "UTC", "--from", "09:00", "--to", "17:00", + "--yes") if err == nil { t.Fatal("expected error, got nil") } @@ -635,7 +640,7 @@ func TestResolveSubmitFields_Interactive(t *testing.T) { originalPrompt := submitPromptFn defer func() { submitPromptFn = originalPrompt }() - submitPromptFn = func() (*submitFields, error) { + submitPromptFn = func(_ *submitFields) (*submitFields, error) { return &submitFields{ reason: "prompted reason", priority: "High", @@ -677,7 +682,7 @@ func TestResolveSubmitFields_FlagOverridesPrompt(t *testing.T) { originalPrompt := submitPromptFn defer func() { submitPromptFn = originalPrompt }() - submitPromptFn = func() (*submitFields, error) { + submitPromptFn = func(_ *submitFields) (*submitFields, error) { return &submitFields{ reason: "prompted", priority: "Low", @@ -733,3 +738,244 @@ func TestResolveSubmitFields_NonInteractive_Error(t *testing.T) { t.Errorf("error %q does not mention non-interactive", err.Error()) } } + +func TestResolveLocalTimezone(t *testing.T) { + tz := resolveLocalTimezone() + if tz == "Local" { + t.Error("resolveLocalTimezone() returned 'Local'") + } + if tz == "" { + t.Error("resolveLocalTimezone() returned empty string") + } +} + +func TestRunRequestSubmit_InvalidProvider(t *testing.T) { + original := resolveSubmitTargetFn + defer func() { resolveSubmitTargetFn = original }() + + resolveSubmitTargetFn = func(_ context.Context, _, _, _ string) (*models.EligibleTarget, error) { + t.Fatal("resolveSubmitTarget should not be called with invalid provider") + return nil, nil + } + + svc := &mockAccessRequestService{} + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + _, err := executeCommand(root, "request", "submit", + "--provider", "gcp", + "--target", "Sub", "--role", "Reader", + "--reason", "test", "--date", "2026-04-21", + "--timezone", "UTC", "--from", "09:00", "--to", "17:00", + "--yes") + if err == nil { + t.Fatal("expected error for invalid provider") + } + if !strings.Contains(err.Error(), "invalid provider") { + t.Errorf("error %q does not mention invalid provider", err.Error()) + } +} + +func TestRunRequestSubmit_ConfirmationDenied(t *testing.T) { + original := resolveSubmitTargetFn + originalConfirm := confirmSubmitFn + defer func() { + resolveSubmitTargetFn = original + confirmSubmitFn = originalConfirm + }() + + resolveSubmitTargetFn = func(_ context.Context, _, _, _ string) (*models.EligibleTarget, error) { + return &models.EligibleTarget{ + WorkspaceName: "Test Sub", + WorkspaceID: "ws-1", + CSP: models.CSPAzure, + RoleInfo: models.RoleInfo{ID: "role-1", Name: "Contributor"}, + }, nil + } + confirmSubmitFn = func() (bool, error) { + return false, nil + } + + originalTTY := ui.IsTerminalFunc + defer func() { ui.IsTerminalFunc = originalTTY }() + ui.IsTerminalFunc = func(fd uintptr) bool { return true } + + svc := &mockAccessRequestService{ + submitResult: &wfmodels.AccessRequest{RequestID: "should-not-reach"}, + } + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + output, err := executeCommand(root, "request", "submit", + "--target", "Test Sub", "--role", "Contributor", + "--reason", "test", "--date", "2026-04-21", + "--timezone", "UTC", "--from", "09:00", "--to", "17:00") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(output, "canceled") { + t.Errorf("expected 'canceled' in output, got: %s", output) + } +} + +func TestRunRequestSubmit_YesFlagSkipsConfirmation(t *testing.T) { + original := resolveSubmitTargetFn + originalConfirm := confirmSubmitFn + defer func() { + resolveSubmitTargetFn = original + confirmSubmitFn = originalConfirm + }() + + resolveSubmitTargetFn = func(_ context.Context, _, _, _ string) (*models.EligibleTarget, error) { + return &models.EligibleTarget{ + WorkspaceName: "Test Sub", + WorkspaceID: "ws-1", + CSP: models.CSPAzure, + RoleInfo: models.RoleInfo{ID: "role-1", Name: "Contributor"}, + }, nil + } + confirmSubmitFn = func() (bool, error) { + t.Fatal("confirmSubmitFn should not be called with --yes") + return false, nil + } + + svc := &mockAccessRequestService{ + submitResult: &wfmodels.AccessRequest{ + RequestID: "req-yes", + RequestState: wfmodels.RequestStatePending, + }, + } + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + output, err := executeCommand(root, "request", "submit", + "--target", "Test Sub", "--role", "Contributor", + "--reason", "test", "--date", "2026-04-21", + "--timezone", "UTC", "--from", "09:00", "--to", "17:00", + "--yes") + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, output) + } + if !strings.Contains(output, "req-yes") { + t.Errorf("expected request ID in output, got: %s", output) + } +} + +func TestResolveSubmitTarget_PartialTarget(t *testing.T) { + original := resolveSubmitTargetFn + defer func() { resolveSubmitTargetFn = original }() + + targets := []models.EligibleTarget{ + {WorkspaceName: "Sub A", WorkspaceID: "ws-a", CSP: models.CSPAzure, RoleInfo: models.RoleInfo{ID: "r1", Name: "Reader"}}, + {WorkspaceName: "Sub A", WorkspaceID: "ws-a", CSP: models.CSPAzure, RoleInfo: models.RoleInfo{ID: "r2", Name: "Contributor"}}, + {WorkspaceName: "Sub B", WorkspaceID: "ws-b", CSP: models.CSPAzure, RoleInfo: models.RoleInfo{ID: "r3", Name: "Reader"}}, + } + + // Single match with --target filters to one result + var filtered []models.EligibleTarget + for i := range targets { + if strings.EqualFold(targets[i].WorkspaceName, "Sub B") { + filtered = append(filtered, targets[i]) + } + } + if len(filtered) != 1 { + t.Fatalf("expected 1 match for 'Sub B', got %d", len(filtered)) + } + if filtered[0].RoleInfo.Name != "Reader" { + t.Errorf("expected Reader, got %s", filtered[0].RoleInfo.Name) + } +} + +func TestResolveSubmitTarget_PartialRole(t *testing.T) { + targets := []models.EligibleTarget{ + {WorkspaceName: "Sub A", WorkspaceID: "ws-a", CSP: models.CSPAzure, RoleInfo: models.RoleInfo{ID: "r1", Name: "Reader"}}, + {WorkspaceName: "Sub A", WorkspaceID: "ws-a", CSP: models.CSPAzure, RoleInfo: models.RoleInfo{ID: "r2", Name: "Contributor"}}, + {WorkspaceName: "Sub B", WorkspaceID: "ws-b", CSP: models.CSPAzure, RoleInfo: models.RoleInfo{ID: "r3", Name: "Contributor"}}, + } + + // --role "Reader" should match only Sub A/Reader + var filtered []models.EligibleTarget + for i := range targets { + if strings.EqualFold(targets[i].RoleInfo.Name, "Reader") { + filtered = append(filtered, targets[i]) + } + } + if len(filtered) != 1 { + t.Fatalf("expected 1 match for 'Reader', got %d", len(filtered)) + } + if filtered[0].WorkspaceName != "Sub A" { + t.Errorf("expected Sub A, got %s", filtered[0].WorkspaceName) + } +} + +func TestDefaultSubmitPrompt_NonInteractive(t *testing.T) { + originalTTY := ui.IsTerminalFunc + defer func() { ui.IsTerminalFunc = originalTTY }() + ui.IsTerminalFunc = func(fd uintptr) bool { return false } + + _, err := defaultSubmitPrompt(&submitFields{}) + if err == nil { + t.Fatal("expected error in non-interactive mode") + } + if !errors.Is(err, ui.ErrNotInteractive) { + t.Errorf("expected ErrNotInteractive, got: %v", err) + } +} + +func TestResolveSubmitFields_PromptOnlyMissing(t *testing.T) { + originalPrompt := submitPromptFn + defer func() { submitPromptFn = originalPrompt }() + + var receivedExisting *submitFields + submitPromptFn = func(existing *submitFields) (*submitFields, error) { + receivedExisting = existing + return &submitFields{ + reason: "prompted", + priority: "High", + date: "2026-05-01", + timezone: "America/Chicago", + timeFrom: "10:00", + timeTo: "18:00", + }, nil + } + + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("reason", "", "") + cmd.Flags().String("priority", "Medium", "") + cmd.Flags().String("date", "", "") + cmd.Flags().String("timezone", "", "") + cmd.Flags().String("from", "", "") + cmd.Flags().String("to", "", "") + cmd.SetArgs([]string{"--reason", "my reason", "--date", "2026-06-01"}) + _ = cmd.Execute() + + originalTTY := ui.IsTerminalFunc + defer func() { ui.IsTerminalFunc = originalTTY }() + ui.IsTerminalFunc = func(fd uintptr) bool { return true } + + f, err := resolveSubmitFields(cmd) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if receivedExisting == nil { + t.Fatal("submitPromptFn was not called with existing fields") + } + if receivedExisting.reason != "my reason" { + t.Errorf("existing.reason: got %q, want 'my reason'", receivedExisting.reason) + } + if receivedExisting.date != "2026-06-01" { + t.Errorf("existing.date: got %q, want '2026-06-01'", receivedExisting.date) + } + // Flag values should override prompted values + if f.reason != "my reason" { + t.Errorf("f.reason: got %q, want 'my reason'", f.reason) + } + if f.date != "2026-06-01" { + t.Errorf("f.date: got %q, want '2026-06-01'", f.date) + } +} From da5bacdd54bbfafff1e10f871713b1914ec64ce8 Mon Sep 17 00:00:00 2001 From: Tim Schindler Date: Mon, 20 Apr 2026 17:17:26 +0200 Subject: [PATCH 3/9] feat: default start time to now, end time to start + 1 hour --- cmd/request_submit.go | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/cmd/request_submit.go b/cmd/request_submit.go index 142a5de..1e79863 100644 --- a/cmd/request_submit.go +++ b/cmd/request_submit.go @@ -139,9 +139,15 @@ func defaultSubmitPrompt(existing *submitFields) (*submitFields, error) { } } - // 5. Start time + // 5. Start time (default: current time in selected timezone) if existing.timeFrom == "" { - if err := survey.AskOne(&survey.Input{Message: "Start time (HH:MM):"}, &f.timeFrom, + tz := f.timezone + if tz == "" { + tz = existing.timezone + } + loc, _ := time.LoadLocation(tz) + defaultStart := time.Now().In(loc).Format("15:04") + if err := survey.AskOne(&survey.Input{Message: "Start time (HH:MM):", Default: defaultStart}, &f.timeFrom, survey.WithValidator(func(val interface{}) error { s, _ := val.(string) if _, err := time.Parse("15:04", s); err != nil { @@ -153,9 +159,17 @@ func defaultSubmitPrompt(existing *submitFields) (*submitFields, error) { } } - // 6. End time + // 6. End time (default: start + 1 hour) if existing.timeTo == "" { - if err := survey.AskOne(&survey.Input{Message: "End time (HH:MM):"}, &f.timeTo, + startTime := f.timeFrom + if startTime == "" { + startTime = existing.timeFrom + } + defaultEnd := "" + if parsed, parseErr := time.Parse("15:04", startTime); parseErr == nil { + defaultEnd = parsed.Add(time.Hour).Format("15:04") + } + if err := survey.AskOne(&survey.Input{Message: "End time (HH:MM):", Default: defaultEnd}, &f.timeTo, survey.WithValidator(func(val interface{}) error { s, _ := val.(string) if _, err := time.Parse("15:04", s); err != nil { From 82963140d94753f6baaba2f2dc5c01eef5f96fe4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 20 Apr 2026 15:20:54 +0000 Subject: [PATCH 4/9] fix: address reviewer feedback on ListRequests pagination and formatTimestamp - ListRequests: return errPaginationLimit sentinel error when maxPages exhausted instead of silently returning truncated data (matches SCA paginate() pattern) - formatTimestamp: parse RFC3339Nano and reformat without fractional seconds while preserving timezone offset; fall back to trimming at '.' for non-RFC3339 timestamps - Extract pagination limit error to package-level var errPaginationLimit - Test: use errors.Is(err, errPaginationLimit) for precise assertion - Test: add RFC3339 timezone cases to TestFormatTimestamp Agent-Logs-Url: https://github.com/aaearon/grant-cli/sessions/a09d2658-41d5-414b-9336-32ddff12742f Co-authored-by: aaearon <812640+aaearon@users.noreply.github.com> --- cmd/request.go | 15 ++++++++++++--- cmd/request_test.go | 9 +++++++++ internal/workflows/service.go | 8 ++++++-- internal/workflows/service_test.go | 25 +++++++++++++++++++++++++ 4 files changed, 52 insertions(+), 5 deletions(-) diff --git a/cmd/request.go b/cmd/request.go index ffd8584..59068ac 100644 --- a/cmd/request.go +++ b/cmd/request.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" "text/tabwriter" + "time" "github.com/aaearon/grant-cli/internal/workflows" "github.com/aaearon/grant-cli/internal/workflows/models" @@ -165,10 +166,18 @@ func formatRequestDetail(cmd *cobra.Command, r *models.AccessRequest) { } } -// formatTimestamp truncates a timestamp to just the date+time portion (no microseconds). +// formatTimestamp strips fractional seconds from a timestamp while preserving +// timezone offset information. RFC3339 timestamps (with Z or ±HH:MM timezone +// offset) are parsed with RFC3339Nano and reformatted as RFC3339 (no subseconds), +// keeping the original offset. Non-RFC3339 timestamps have the fractional-seconds +// portion trimmed if present. func formatTimestamp(ts string) string { - if len(ts) > 19 { - return ts[:19] + if t, err := time.Parse(time.RFC3339Nano, ts); err == nil { + return t.Format("2006-01-02T15:04:05Z07:00") + } + // Non-RFC3339 timestamps (e.g. no timezone): trim fractional seconds if present. + if i := strings.IndexByte(ts, '.'); i >= 0 { + return ts[:i] } return ts } diff --git a/cmd/request_test.go b/cmd/request_test.go index 5997770..e2a516c 100644 --- a/cmd/request_test.go +++ b/cmd/request_test.go @@ -405,8 +405,17 @@ func TestFormatTimestamp(t *testing.T) { input string want string }{ + // No timezone, with fractional seconds: trim fraction {"2025-08-12T09:41:00.594008", "2025-08-12T09:41:00"}, + // No timezone, no fractional seconds: return as-is {"2025-08-12T09:41:00", "2025-08-12T09:41:00"}, + // RFC3339 UTC with fractional seconds: strip fraction, keep Z + {"2025-08-12T09:41:00.594008Z", "2025-08-12T09:41:00Z"}, + // RFC3339 UTC without fractional seconds: return normalised form + {"2025-08-12T09:41:00Z", "2025-08-12T09:41:00Z"}, + // RFC3339 with offset and fractional seconds: strip fraction, keep offset + {"2025-08-12T09:41:00.123+05:30", "2025-08-12T09:41:00+05:30"}, + // Arbitrary short string: return as-is {"short", "short"}, } for _, tt := range tests { diff --git a/internal/workflows/service.go b/internal/workflows/service.go index 9d7ea86..5c95e20 100644 --- a/internal/workflows/service.go +++ b/internal/workflows/service.go @@ -167,12 +167,12 @@ func (s *AccessRequestService) ListRequests(ctx context.Context, params ListRequ totalCount = page.TotalCount if len(allItems) >= page.TotalCount || len(page.Items) < limit { - break + return allItems, totalCount, nil } offset += len(page.Items) } - return allItems, totalCount, nil + return nil, 0, errPaginationLimit } // GetRequest retrieves a single access request by ID. @@ -276,6 +276,10 @@ func (s *AccessRequestService) FinalizeRequest(ctx context.Context, requestID, r const maxPages = 100 +// errPaginationLimit is returned when ListRequests exhausts the maximum number of +// page fetches without completing pagination. +var errPaginationLimit = errors.New("list requests pagination exceeded maximum page limit") + func checkResponse(resp *http.Response, operation string) error { if resp.StatusCode == http.StatusOK { return nil diff --git a/internal/workflows/service_test.go b/internal/workflows/service_test.go index 43cb5c9..f49b324 100644 --- a/internal/workflows/service_test.go +++ b/internal/workflows/service_test.go @@ -3,6 +3,7 @@ package workflows import ( "context" "encoding/json" + "errors" "io" "net/http" "strings" @@ -125,6 +126,30 @@ func TestListRequests(t *testing.T) { } } +func TestListRequests_MaxPagesExceeded(t *testing.T) { + // Every page returns 1 item but claims totalCount is enormous, so the loop + // never hits the break condition and must eventually return an error. + mock := &mockHTTPClient{ + getFn: func(_ context.Context, _ string, _ interface{}) (*http.Response, error) { + return jsonResponse(200, models.ListRequestsResponse{ + Items: []models.AccessRequest{{RequestID: "id-1"}}, + Count: 1, + TotalCount: 99999, + }), nil + }, + } + + svc := NewAccessRequestServiceWithClient(mock) + _, _, err := svc.ListRequests(t.Context(), ListRequestsParams{Limit: 1}) + if err == nil { + t.Fatal("expected error when maxPages exceeded, got nil") + } + if !errors.Is(err, errPaginationLimit) { + t.Errorf("expected errPaginationLimit, got: %v", err) + } +} + + func TestListRequests_WithFilters(t *testing.T) { mock := &mockHTTPClient{ getFn: func(_ context.Context, _ string, params interface{}) (*http.Response, error) { From c0390c5380c78d6eb206d24fe882b5869f12e05f Mon Sep 17 00:00:00 2001 From: Tim Schindler Date: Mon, 20 Apr 2026 17:33:54 +0200 Subject: [PATCH 5/9] refactor: workspace-based target selection for access requests The interactive selector now shows deduplicated workspaces (not role/workspace combos) since on-demand requests are for roles the user doesn't already have a JIT policy for. Users select a workspace they have access to, then provide the role ID they want to request. - Replace role-based target selector with workspace selector - Add --role-id flag (required) for the role being requested - --role flag becomes optional display name (defaults to role ID) - Deduplicate eligibility targets to unique workspaces by ID - Add workspace type labels in selector (Account, Subscription, etc.) --- CLAUDE.md | 2 +- cmd/request_submit.go | 198 ++++++++++++++++++++++++++---------------- cmd/request_test.go | 148 +++++++++++++++---------------- 3 files changed, 190 insertions(+), 158 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index d2500dd..97f80ac 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -74,7 +74,7 @@ Custom `SCAAccessService` follows SDK conventions: - `grant list` — list eligible targets and groups without triggering elevation; supports `--provider`, `--groups`, `--refresh`, `--output json`; used by LLMs to discover available targets programmatically - `grant revoke` — revoke sessions: direct (`grant revoke `), `--all`, or interactive multi-select; `--yes` skips confirmation - `grant request` — manage access requests through approval workflow; subcommands: `submit`, `list`, `get`, `cancel`, `approve`, `reject` -- `grant request submit` — submit access request; reuses SCA eligibility for target selection; shows summary + confirmation before submitting; partial `--target` or `--role` filters the selector; flags: `--target`, `--role`, `--provider`, `--reason`, `--priority`, `--date`, `--timezone`, `--from`, `--to`, `--yes` +- `grant request submit` — submit on-demand access request; workspace selector uses SCA eligibility (deduplicated to unique workspaces); user provides role ID for the role they want; shows summary + confirmation before submitting; flags: `--target`, `--role-id`, `--role`, `--provider`, `--reason`, `--priority`, `--date`, `--timezone`, `--from`, `--to`, `--yes` - `grant request list` — list access requests; flags: `--state`, `--result`, `--priority`, `--role` (CREATOR/APPROVER), `--search`, `--sort`, `--desc` - `grant request get ` — get full request details - `grant request cancel ` — cancel an open request; optional `--reason` diff --git a/cmd/request_submit.go b/cmd/request_submit.go index 1e79863..6e6cad1 100644 --- a/cmd/request_submit.go +++ b/cmd/request_submit.go @@ -35,7 +35,8 @@ func newRequestSubmitCommand(svc accessRequestService) *cobra.Command { cmd.Flags().StringP("provider", "p", "", "Cloud provider: azure, aws") cmd.Flags().StringP("target", "t", "", "Target workspace name") - cmd.Flags().StringP("role", "r", "", "Role name") + cmd.Flags().String("role-id", "", "Role ID to request access for (required)") + cmd.Flags().StringP("role", "r", "", "Role name (display only)") cmd.Flags().String("reason", "", "Reason for the request (required)") cmd.Flags().String("priority", "Medium", "Priority: High, Medium, Low") cmd.Flags().String("date", "", "Request date (YYYY-MM-DD)") @@ -56,6 +57,9 @@ var confirmSubmitFn = confirmSubmit // resolveSubmitTargetFn is injectable for testing target resolution. var resolveSubmitTargetFn = resolveSubmitTarget +// submitWorkspaceSelectorFn is injectable for testing workspace selection. +var submitWorkspaceSelectorFn = selectSubmitWorkspace + type submitFields struct { reason string priority string @@ -65,6 +69,15 @@ type submitFields struct { timeTo string } +// submitWorkspace holds deduplicated workspace info derived from eligibility. +type submitWorkspace struct { + WorkspaceID string + WorkspaceName string + WorkspaceType models.WorkspaceType + CSP models.CSP + OrganizationID string +} + func resolveLocalTimezone() string { tz := time.Now().Location().String() if tz == "Local" { @@ -212,23 +225,42 @@ func runRequestSubmit(cmd *cobra.Command, svc accessRequestService) error { } targetName, _ := cmd.Flags().GetString("target") + roleID, _ := cmd.Flags().GetString("role-id") roleName, _ := cmd.Flags().GetString("role") ctx, cancel := context.WithTimeout(cmd.Context(), apiTimeout) defer cancel() - target, err := resolveSubmitTargetFn(ctx, provider, targetName, roleName) + workspace, err := resolveSubmitTargetFn(ctx, provider, targetName) if err != nil { return err } + // Resolve role ID + if roleID == "" { + if !ui.IsInteractive() { + return errors.New("non-interactive mode requires --role-id") + } + stdio := survey.WithStdio(os.Stdin, os.Stderr, os.Stderr) + if err := survey.AskOne(&survey.Input{Message: "Role ID:"}, &roleID, + survey.WithValidator(survey.Required), stdio); err != nil { + return err + } + } + + // Resolve role name (optional, defaults to role ID if not provided) + if roleName == "" { + roleName = roleID + } + // Summary before submission if !isJSONOutput() { - fmt.Fprintf(cmd.ErrOrStderr(), "\nTarget: %s / %s\n", target.WorkspaceName, target.RoleInfo.Name) - fmt.Fprintf(cmd.ErrOrStderr(), "Date: %s\n", fields.date) - fmt.Fprintf(cmd.ErrOrStderr(), "Time: %s – %s (%s)\n", fields.timeFrom, fields.timeTo, fields.timezone) - fmt.Fprintf(cmd.ErrOrStderr(), "Priority: %s\n", fields.priority) - fmt.Fprintf(cmd.ErrOrStderr(), "Reason: %s\n\n", fields.reason) + fmt.Fprintf(cmd.ErrOrStderr(), "\nWorkspace: %s\n", workspace.WorkspaceName) + fmt.Fprintf(cmd.ErrOrStderr(), "Role: %s (ID: %s)\n", roleName, roleID) + fmt.Fprintf(cmd.ErrOrStderr(), "Date: %s\n", fields.date) + fmt.Fprintf(cmd.ErrOrStderr(), "Time: %s – %s (%s)\n", fields.timeFrom, fields.timeTo, fields.timezone) + fmt.Fprintf(cmd.ErrOrStderr(), "Priority: %s\n", fields.priority) + fmt.Fprintf(cmd.ErrOrStderr(), "Reason: %s\n\n", fields.reason) } // Confirmation @@ -244,9 +276,9 @@ func runRequestSubmit(cmd *cobra.Command, svc accessRequestService) error { } } - details := buildRequestDetails(target, fields) + details := buildRequestDetails(workspace, roleID, roleName, fields) - log.Info("Submitting access request for %s / %s", target.WorkspaceName, target.RoleInfo.Name) + log.Info("Submitting access request for %s / %s", workspace.WorkspaceName, roleName) result, err := svc.SubmitRequest(ctx, &wfmodels.SubmitAccessRequest{ TargetCategory: "CLOUD_CONSOLE", @@ -311,7 +343,7 @@ func resolveSubmitFields(cmd *cobra.Command) (*submitFields, error) { return f, nil } -func resolveSubmitTarget(ctx context.Context, provider, targetName, roleName string) (*models.EligibleTarget, error) { +func resolveSubmitTarget(ctx context.Context, provider, targetName string) (*submitWorkspace, error) { _, scaSvc, _, err := bootstrapSCAService() if err != nil { return nil, fmt.Errorf("failed to bootstrap SCA service: %w", err) @@ -331,82 +363,106 @@ func resolveSubmitTarget(ctx context.Context, provider, targetName, roleName str return nil, fmt.Errorf("failed to fetch eligibility: %w", err) } - if targetName != "" && roleName != "" { - target := findMatchingTarget(targets, targetName, roleName) - if target == nil { - return nil, fmt.Errorf("no eligible target found matching target=%q role=%q", targetName, roleName) - } - resolveTargetCSP(target, targets, provider) - return target, nil + workspaces := deduplicateWorkspaces(targets) + if len(workspaces) == 0 { + return nil, errors.New("no eligible workspaces found") } - // Partial --target: filter to matching workspaces - if targetName != "" && roleName == "" { - var filtered []models.EligibleTarget - for i := range targets { - if strings.EqualFold(targets[i].WorkspaceName, targetName) { - filtered = append(filtered, targets[i]) + // Non-interactive: match by --target flag + if targetName != "" { + for i := range workspaces { + if strings.EqualFold(workspaces[i].WorkspaceName, targetName) { + return &workspaces[i], nil } } - if len(filtered) == 0 { - return nil, fmt.Errorf("no eligible target found matching target=%q", targetName) - } - if len(filtered) == 1 { - resolveTargetCSP(&filtered[0], targets, provider) - return &filtered[0], nil - } - targets = filtered + return nil, fmt.Errorf("no eligible workspace found matching target=%q", targetName) } - // Partial --role: filter to matching roles - if roleName != "" && targetName == "" { - var filtered []models.EligibleTarget - for i := range targets { - if strings.EqualFold(targets[i].RoleInfo.Name, roleName) { - filtered = append(filtered, targets[i]) - } - } - if len(filtered) == 0 { - return nil, fmt.Errorf("no eligible target found matching role=%q", roleName) - } - if len(filtered) == 1 { - resolveTargetCSP(&filtered[0], targets, provider) - return &filtered[0], nil - } - targets = filtered + // Single workspace: auto-select + if len(workspaces) == 1 { + return &workspaces[0], nil } + // Interactive selection + return submitWorkspaceSelectorFn(workspaces) +} + +func selectSubmitWorkspace(workspaces []submitWorkspace) (*submitWorkspace, error) { if !ui.IsInteractive() { - return nil, errors.New("non-interactive mode requires --target and --role") + return nil, errors.New("non-interactive mode requires --target") } - items := buildCloudSelectionItems(targets) - sel := &uiUnifiedSelector{} - selected, err := sel.SelectItem(items) + + options := make([]string, len(workspaces)) + for i, ws := range workspaces { + options[i] = formatWorkspaceOption(ws) + } + + var selected int + err := survey.AskOne(&survey.Select{ + Message: "Select a workspace:", + Options: options, + }, &selected, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) if err != nil { return nil, err } - resolveTargetCSP(selected.cloud, targets, provider) - return selected.cloud, nil + return &workspaces[selected], nil } -// API submit payload uses camelCase keys (per spec example), not the snake_case -// form question keys from GET /request-forms. -func buildRequestDetails(target *models.EligibleTarget, f *submitFields) map[string]interface{} { - locationType := string(target.CSP) - if target.CSP == models.CSPAzure { +func formatWorkspaceOption(ws submitWorkspace) string { + label := workspaceTypeLabel(ws.WorkspaceType) + return fmt.Sprintf("%s: %s (%s)", label, ws.WorkspaceName, strings.ToLower(string(ws.CSP))) +} + +func workspaceTypeLabel(wt models.WorkspaceType) string { + switch wt { + case models.WorkspaceTypeSubscription: + return "Subscription" + case models.WorkspaceTypeManagementGroup: + return "Management Group" + case models.WorkspaceTypeDirectory: + return "Directory" + case models.WorkspaceTypeAccount: + return "Account" + default: + return string(wt) + } +} + +func deduplicateWorkspaces(targets []models.EligibleTarget) []submitWorkspace { + seen := make(map[string]bool) + var result []submitWorkspace + for _, t := range targets { + if seen[t.WorkspaceID] { + continue + } + seen[t.WorkspaceID] = true + result = append(result, submitWorkspace{ + WorkspaceID: t.WorkspaceID, + WorkspaceName: t.WorkspaceName, + WorkspaceType: t.WorkspaceType, + CSP: t.CSP, + OrganizationID: t.OrganizationID, + }) + } + return result +} + +func buildRequestDetails(ws *submitWorkspace, roleID, roleName string, f *submitFields) map[string]interface{} { + locationType := string(ws.CSP) + if ws.CSP == models.CSPAzure { locationType = "Azure" - } else if target.CSP == models.CSPAWS { + } else if ws.CSP == models.CSPAWS { locationType = "AWS" } return map[string]interface{}{ "locationType": locationType, - "roleId": target.RoleInfo.ID, - "roleName": target.RoleInfo.Name, - "workspaceId": target.WorkspaceID, - "workspaceName": target.WorkspaceName, - "workspaceType": string(target.WorkspaceType), - "orgId": target.OrganizationID, + "roleId": roleID, + "roleName": roleName, + "workspaceId": ws.WorkspaceID, + "workspaceName": ws.WorkspaceName, + "workspaceType": string(ws.WorkspaceType), + "orgId": ws.OrganizationID, "reason": f.reason, "priority": f.priority, "requestDate": f.date, @@ -456,15 +512,3 @@ func validateSubmitFields(f *submitFields) error { return nil } - -// buildCloudSelectionItems wraps cloud targets in selectionItems for the unified selector. -func buildCloudSelectionItems(targets []models.EligibleTarget) []selectionItem { - items := make([]selectionItem, len(targets)) - for i := range targets { - items[i] = selectionItem{ - kind: selectionCloud, - cloud: &targets[i], - } - } - return items -} diff --git a/cmd/request_test.go b/cmd/request_test.go index e2a516c..6429345 100644 --- a/cmd/request_test.go +++ b/cmd/request_test.go @@ -501,13 +501,13 @@ func TestRunRequestSubmit_NonInteractive(t *testing.T) { original := resolveSubmitTargetFn defer func() { resolveSubmitTargetFn = original }() - resolveSubmitTargetFn = func(_ context.Context, _, _, _ string) (*models.EligibleTarget, error) { - return &models.EligibleTarget{ - WorkspaceName: "Test Sub", - WorkspaceID: "ws-1", - WorkspaceType: models.WorkspaceTypeSubscription, - CSP: models.CSPAzure, - RoleInfo: models.RoleInfo{ID: "role-1", Name: "Contributor"}, + resolveSubmitTargetFn = func(_ context.Context, _, _ string) (*submitWorkspace, error) { + return &submitWorkspace{ + WorkspaceName: "Test Sub", + WorkspaceID: "ws-1", + WorkspaceType: models.WorkspaceTypeSubscription, + CSP: models.CSPAzure, + OrganizationID: "org-1", }, nil } @@ -523,7 +523,7 @@ func TestRunRequestSubmit_NonInteractive(t *testing.T) { root.AddCommand(cmd) output, err := executeCommand(root, "request", "submit", - "--target", "Test Sub", "--role", "Contributor", + "--target", "Test Sub", "--role-id", "role-1", "--role", "Contributor", "--reason", "need access", "--date", "2026-04-21", "--timezone", "UTC", "--from", "09:00", "--to", "17:00", "--yes") @@ -539,12 +539,12 @@ func TestRunRequestSubmit_JSONOutput(t *testing.T) { original := resolveSubmitTargetFn defer func() { resolveSubmitTargetFn = original }() - resolveSubmitTargetFn = func(_ context.Context, _, _, _ string) (*models.EligibleTarget, error) { - return &models.EligibleTarget{ - WorkspaceName: "Test Sub", - WorkspaceID: "ws-1", - CSP: models.CSPAzure, - RoleInfo: models.RoleInfo{ID: "role-1", Name: "Contributor"}, + resolveSubmitTargetFn = func(_ context.Context, _, _ string) (*submitWorkspace, error) { + return &submitWorkspace{ + WorkspaceName: "Test Sub", + WorkspaceID: "ws-1", + CSP: models.CSPAzure, + OrganizationID: "org-1", }, nil } @@ -565,7 +565,7 @@ func TestRunRequestSubmit_JSONOutput(t *testing.T) { root.AddCommand(cmd) output, err := executeCommand(root, "request", "submit", - "--target", "Test Sub", "--role", "Contributor", + "--target", "Test Sub", "--role-id", "role-1", "--role", "Contributor", "--reason", "test", "--date", "2026-04-21", "--timezone", "UTC", "--from", "09:00", "--to", "17:00", "--output", "json", "--yes") @@ -586,12 +586,12 @@ func TestRunRequestSubmit_ServiceError(t *testing.T) { original := resolveSubmitTargetFn defer func() { resolveSubmitTargetFn = original }() - resolveSubmitTargetFn = func(_ context.Context, _, _, _ string) (*models.EligibleTarget, error) { - return &models.EligibleTarget{ - WorkspaceName: "Sub", - WorkspaceID: "ws-1", - CSP: models.CSPAzure, - RoleInfo: models.RoleInfo{ID: "r1", Name: "Reader"}, + resolveSubmitTargetFn = func(_ context.Context, _, _ string) (*submitWorkspace, error) { + return &submitWorkspace{ + WorkspaceName: "Sub", + WorkspaceID: "ws-1", + CSP: models.CSPAzure, + OrganizationID: "org-1", }, nil } @@ -604,7 +604,7 @@ func TestRunRequestSubmit_ServiceError(t *testing.T) { root.AddCommand(cmd) _, err := executeCommand(root, "request", "submit", - "--target", "Sub", "--role", "Reader", + "--target", "Sub", "--role-id", "r1", "--reason", "test", "--date", "2026-04-21", "--timezone", "UTC", "--from", "09:00", "--to", "17:00", "--yes") @@ -620,12 +620,12 @@ func TestRunRequestSubmit_MissingFlags_NonInteractive(t *testing.T) { original := resolveSubmitTargetFn defer func() { resolveSubmitTargetFn = original }() - resolveSubmitTargetFn = func(_ context.Context, _, _, _ string) (*models.EligibleTarget, error) { - return &models.EligibleTarget{ - WorkspaceName: "Sub", - WorkspaceID: "ws-1", - CSP: models.CSPAzure, - RoleInfo: models.RoleInfo{ID: "r1", Name: "Reader"}, + resolveSubmitTargetFn = func(_ context.Context, _, _ string) (*submitWorkspace, error) { + return &submitWorkspace{ + WorkspaceName: "Sub", + WorkspaceID: "ws-1", + CSP: models.CSPAzure, + OrganizationID: "org-1", }, nil } @@ -635,7 +635,7 @@ func TestRunRequestSubmit_MissingFlags_NonInteractive(t *testing.T) { root.AddCommand(cmd) _, err := executeCommand(root, "request", "submit", - "--target", "Sub", "--role", "Reader", + "--target", "Sub", "--role-id", "r1", "--reason", "test") if err == nil { t.Fatal("expected error for missing --date/--timezone/--from/--to, got nil") @@ -762,7 +762,7 @@ func TestRunRequestSubmit_InvalidProvider(t *testing.T) { original := resolveSubmitTargetFn defer func() { resolveSubmitTargetFn = original }() - resolveSubmitTargetFn = func(_ context.Context, _, _, _ string) (*models.EligibleTarget, error) { + resolveSubmitTargetFn = func(_ context.Context, _, _ string) (*submitWorkspace, error) { t.Fatal("resolveSubmitTarget should not be called with invalid provider") return nil, nil } @@ -774,7 +774,7 @@ func TestRunRequestSubmit_InvalidProvider(t *testing.T) { _, err := executeCommand(root, "request", "submit", "--provider", "gcp", - "--target", "Sub", "--role", "Reader", + "--target", "Sub", "--role-id", "r1", "--reason", "test", "--date", "2026-04-21", "--timezone", "UTC", "--from", "09:00", "--to", "17:00", "--yes") @@ -794,12 +794,12 @@ func TestRunRequestSubmit_ConfirmationDenied(t *testing.T) { confirmSubmitFn = originalConfirm }() - resolveSubmitTargetFn = func(_ context.Context, _, _, _ string) (*models.EligibleTarget, error) { - return &models.EligibleTarget{ - WorkspaceName: "Test Sub", - WorkspaceID: "ws-1", - CSP: models.CSPAzure, - RoleInfo: models.RoleInfo{ID: "role-1", Name: "Contributor"}, + resolveSubmitTargetFn = func(_ context.Context, _, _ string) (*submitWorkspace, error) { + return &submitWorkspace{ + WorkspaceName: "Test Sub", + WorkspaceID: "ws-1", + CSP: models.CSPAzure, + OrganizationID: "org-1", }, nil } confirmSubmitFn = func() (bool, error) { @@ -819,7 +819,7 @@ func TestRunRequestSubmit_ConfirmationDenied(t *testing.T) { root.AddCommand(cmd) output, err := executeCommand(root, "request", "submit", - "--target", "Test Sub", "--role", "Contributor", + "--target", "Test Sub", "--role-id", "role-1", "--reason", "test", "--date", "2026-04-21", "--timezone", "UTC", "--from", "09:00", "--to", "17:00") if err != nil { @@ -838,12 +838,12 @@ func TestRunRequestSubmit_YesFlagSkipsConfirmation(t *testing.T) { confirmSubmitFn = originalConfirm }() - resolveSubmitTargetFn = func(_ context.Context, _, _, _ string) (*models.EligibleTarget, error) { - return &models.EligibleTarget{ - WorkspaceName: "Test Sub", - WorkspaceID: "ws-1", - CSP: models.CSPAzure, - RoleInfo: models.RoleInfo{ID: "role-1", Name: "Contributor"}, + resolveSubmitTargetFn = func(_ context.Context, _, _ string) (*submitWorkspace, error) { + return &submitWorkspace{ + WorkspaceName: "Test Sub", + WorkspaceID: "ws-1", + CSP: models.CSPAzure, + OrganizationID: "org-1", }, nil } confirmSubmitFn = func() (bool, error) { @@ -863,7 +863,7 @@ func TestRunRequestSubmit_YesFlagSkipsConfirmation(t *testing.T) { root.AddCommand(cmd) output, err := executeCommand(root, "request", "submit", - "--target", "Test Sub", "--role", "Contributor", + "--target", "Test Sub", "--role-id", "role-1", "--reason", "test", "--date", "2026-04-21", "--timezone", "UTC", "--from", "09:00", "--to", "17:00", "--yes") @@ -875,50 +875,38 @@ func TestRunRequestSubmit_YesFlagSkipsConfirmation(t *testing.T) { } } -func TestResolveSubmitTarget_PartialTarget(t *testing.T) { - original := resolveSubmitTargetFn - defer func() { resolveSubmitTargetFn = original }() - +func TestDeduplicateWorkspaces(t *testing.T) { targets := []models.EligibleTarget{ - {WorkspaceName: "Sub A", WorkspaceID: "ws-a", CSP: models.CSPAzure, RoleInfo: models.RoleInfo{ID: "r1", Name: "Reader"}}, - {WorkspaceName: "Sub A", WorkspaceID: "ws-a", CSP: models.CSPAzure, RoleInfo: models.RoleInfo{ID: "r2", Name: "Contributor"}}, - {WorkspaceName: "Sub B", WorkspaceID: "ws-b", CSP: models.CSPAzure, RoleInfo: models.RoleInfo{ID: "r3", Name: "Reader"}}, + {WorkspaceName: "Sub A", WorkspaceID: "ws-a", CSP: models.CSPAzure, WorkspaceType: models.WorkspaceTypeSubscription, RoleInfo: models.RoleInfo{ID: "r1", Name: "Reader"}}, + {WorkspaceName: "Sub A", WorkspaceID: "ws-a", CSP: models.CSPAzure, WorkspaceType: models.WorkspaceTypeSubscription, RoleInfo: models.RoleInfo{ID: "r2", Name: "Contributor"}}, + {WorkspaceName: "Sub B", WorkspaceID: "ws-b", CSP: models.CSPAzure, WorkspaceType: models.WorkspaceTypeSubscription, RoleInfo: models.RoleInfo{ID: "r3", Name: "Reader"}}, + {WorkspaceName: "AWS Account", WorkspaceID: "ws-c", CSP: models.CSPAWS, WorkspaceType: models.WorkspaceTypeAccount, RoleInfo: models.RoleInfo{ID: "r4", Name: "Admin"}}, } - // Single match with --target filters to one result - var filtered []models.EligibleTarget - for i := range targets { - if strings.EqualFold(targets[i].WorkspaceName, "Sub B") { - filtered = append(filtered, targets[i]) - } + workspaces := deduplicateWorkspaces(targets) + if len(workspaces) != 3 { + t.Fatalf("expected 3 unique workspaces, got %d", len(workspaces)) } - if len(filtered) != 1 { - t.Fatalf("expected 1 match for 'Sub B', got %d", len(filtered)) + + names := make(map[string]bool) + for _, ws := range workspaces { + names[ws.WorkspaceName] = true } - if filtered[0].RoleInfo.Name != "Reader" { - t.Errorf("expected Reader, got %s", filtered[0].RoleInfo.Name) + if !names["Sub A"] || !names["Sub B"] || !names["AWS Account"] { + t.Errorf("unexpected workspace names: %v", names) } } -func TestResolveSubmitTarget_PartialRole(t *testing.T) { - targets := []models.EligibleTarget{ - {WorkspaceName: "Sub A", WorkspaceID: "ws-a", CSP: models.CSPAzure, RoleInfo: models.RoleInfo{ID: "r1", Name: "Reader"}}, - {WorkspaceName: "Sub A", WorkspaceID: "ws-a", CSP: models.CSPAzure, RoleInfo: models.RoleInfo{ID: "r2", Name: "Contributor"}}, - {WorkspaceName: "Sub B", WorkspaceID: "ws-b", CSP: models.CSPAzure, RoleInfo: models.RoleInfo{ID: "r3", Name: "Contributor"}}, - } - - // --role "Reader" should match only Sub A/Reader - var filtered []models.EligibleTarget - for i := range targets { - if strings.EqualFold(targets[i].RoleInfo.Name, "Reader") { - filtered = append(filtered, targets[i]) - } - } - if len(filtered) != 1 { - t.Fatalf("expected 1 match for 'Reader', got %d", len(filtered)) +func TestFormatWorkspaceOption(t *testing.T) { + ws := submitWorkspace{ + WorkspaceName: "Production Account", + WorkspaceType: models.WorkspaceTypeAccount, + CSP: models.CSPAWS, } - if filtered[0].WorkspaceName != "Sub A" { - t.Errorf("expected Sub A, got %s", filtered[0].WorkspaceName) + got := formatWorkspaceOption(ws) + want := "Account: Production Account (aws)" + if got != want { + t.Errorf("formatWorkspaceOption() = %q, want %q", got, want) } } From 8f471edda5b4b8dc86ebb117715a5660e34fbfb4 Mon Sep 17 00:00:00 2001 From: Tim Schindler Date: Mon, 20 Apr 2026 17:36:27 +0200 Subject: [PATCH 6/9] fix: case-insensitive workspace type labels in selector API returns mixed-case workspace types (e.g. "directory", "management_group") while Go constants are uppercase. Use case-insensitive comparison for display labels. --- cmd/request_submit.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/cmd/request_submit.go b/cmd/request_submit.go index 6e6cad1..ed8e68f 100644 --- a/cmd/request_submit.go +++ b/cmd/request_submit.go @@ -414,15 +414,19 @@ func formatWorkspaceOption(ws submitWorkspace) string { } func workspaceTypeLabel(wt models.WorkspaceType) string { - switch wt { - case models.WorkspaceTypeSubscription: + switch strings.ToUpper(string(wt)) { + case "SUBSCRIPTION": return "Subscription" - case models.WorkspaceTypeManagementGroup: + case "MANAGEMENT_GROUP": return "Management Group" - case models.WorkspaceTypeDirectory: + case "DIRECTORY": return "Directory" - case models.WorkspaceTypeAccount: + case "ACCOUNT": return "Account" + case "RESOURCE_GROUP": + return "Resource Group" + case "RESOURCE": + return "Resource" default: return string(wt) } From a0d2e1b56190f87c207b7c957a92eff70fb604c8 Mon Sep 17 00:00:00 2001 From: Tim Schindler Date: Mon, 20 Apr 2026 19:48:07 +0200 Subject: [PATCH 7/9] feat: interactive role selector for grant request submit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the manual Role ID prompt with a fuzzy-filterable list fetched from the SCA on-demand role discovery endpoints. Supported workspace types: DIRECTORY (azure_ad), ACCOUNT (aws), MANAGEMENT_GROUP (azure_resource). Other Azure-resource scopes reject with a clear error pointing to --role-id. Roles are cached in ~/.grant/cache/ondemand_roles__.json (shared 4h Store). Interactive prompt order is now: workspace → role → reason → priority → timezone → date → start time → end time. --- CHANGELOG.md | 4 + CLAUDE.md | 4 +- cmd/request_submit.go | 119 +++++++++++++++--- cmd/request_test.go | 116 +++++++++++++++++ cmd/test_mocks.go | 4 +- internal/cache/cached_roles.go | 61 +++++++++ internal/cache/cached_roles_test.go | 111 +++++++++++++++++ internal/sca/models/ondemand.go | 22 ++++ internal/sca/models/ondemand_test.go | 89 +++++++++++++ internal/sca/service.go | 75 +++++++++++ internal/sca/service_test.go | 180 +++++++++++++++++++++++++++ internal/ui/role_selector.go | 75 +++++++++++ internal/ui/role_selector_test.go | 93 ++++++++++++++ 13 files changed, 936 insertions(+), 17 deletions(-) create mode 100644 internal/cache/cached_roles.go create mode 100644 internal/cache/cached_roles_test.go create mode 100644 internal/sca/models/ondemand.go create mode 100644 internal/sca/models/ondemand_test.go create mode 100644 internal/ui/role_selector.go create mode 100644 internal/ui/role_selector_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f6755e..3082ff4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,10 @@ All notable changes to this project will be documented in this file. - `grant request reject ` — reject a pending request with optional reason - All `grant request` subcommands support `--output json` for machine-readable output - New `internal/workflows/` package implementing the CyberArk Access Requests API client (`/api/workflows/requests`) +- Interactive role selector for `grant request submit`: after workspace selection, fuzzy-filterable list of requestable roles is fetched from the SCA on-demand role discovery endpoints (`/api/cloud/resources/ondemand`, `/api/cloud/cloud-roles/ondemand`) + - Supported workspace types: `DIRECTORY` (azure_ad), `ACCOUNT` (aws), `MANAGEMENT_GROUP` (azure_resource) + - Other Azure-resource scopes (subscription, resource group, resource) still require `--role-id` until validated + - Roles cached in `~/.grant/cache/ondemand_roles__.json` (4h TTL) ## [0.6.1] - 2026-04-08 diff --git a/CLAUDE.md b/CLAUDE.md index 97f80ac..7da39a1 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -74,7 +74,9 @@ Custom `SCAAccessService` follows SDK conventions: - `grant list` — list eligible targets and groups without triggering elevation; supports `--provider`, `--groups`, `--refresh`, `--output json`; used by LLMs to discover available targets programmatically - `grant revoke` — revoke sessions: direct (`grant revoke `), `--all`, or interactive multi-select; `--yes` skips confirmation - `grant request` — manage access requests through approval workflow; subcommands: `submit`, `list`, `get`, `cancel`, `approve`, `reject` -- `grant request submit` — submit on-demand access request; workspace selector uses SCA eligibility (deduplicated to unique workspaces); user provides role ID for the role they want; shows summary + confirmation before submitting; flags: `--target`, `--role-id`, `--role`, `--provider`, `--reason`, `--priority`, `--date`, `--timezone`, `--from`, `--to`, `--yes` +- `grant request submit` — submit on-demand access request; workspace selector uses SCA eligibility (deduplicated to unique workspaces); after workspace selection, interactive role selector fetches roles via SCA on-demand endpoints (GET `/api/cloud/resources/ondemand` for `azure_ad`/`aws`, POST `/api/cloud/cloud-roles/ondemand` for `azure_resource`); shows summary + confirmation before submitting; flags: `--target`, `--role-id`, `--role`, `--provider`, `--reason`, `--priority`, `--date`, `--timezone`, `--from`, `--to`, `--yes` + - Interactive role selection supports `DIRECTORY`, `ACCOUNT` (AWS), and `MANAGEMENT_GROUP` workspaces; other Azure-resource scopes (subscription, resource group, resource) require `--role-id` + - On-demand role cache: `~/.grant/cache/ondemand_roles__.json` (4h TTL); no `--refresh` flag — delete manually to invalidate - `grant request list` — list access requests; flags: `--state`, `--result`, `--priority`, `--role` (CREATOR/APPROVER), `--search`, `--sort`, `--desc` - `grant request get ` — get full request details - `grant request cancel ` — cancel an open request; optional `--reason` diff --git a/cmd/request_submit.go b/cmd/request_submit.go index ed8e68f..f4fb771 100644 --- a/cmd/request_submit.go +++ b/cmd/request_submit.go @@ -9,10 +9,12 @@ import ( "time" survey "github.com/Iilun/survey/v2" + "github.com/aaearon/grant-cli/internal/cache" "github.com/aaearon/grant-cli/internal/config" "github.com/aaearon/grant-cli/internal/sca/models" "github.com/aaearon/grant-cli/internal/ui" wfmodels "github.com/aaearon/grant-cli/internal/workflows/models" + "github.com/cyberark/idsec-sdk-golang/pkg/common" "github.com/spf13/cobra" ) @@ -60,6 +62,9 @@ var resolveSubmitTargetFn = resolveSubmitTarget // submitWorkspaceSelectorFn is injectable for testing workspace selection. var submitWorkspaceSelectorFn = selectSubmitWorkspace +// resolveRoleFn is injectable for testing role resolution. +var resolveRoleFn = resolveSubmitRole + type submitFields struct { reason string priority string @@ -215,15 +220,6 @@ func runRequestSubmit(cmd *cobra.Command, svc accessRequestService) error { } } - fields, err := resolveSubmitFields(cmd) - if err != nil { - return err - } - - if err := validateSubmitFields(fields); err != nil { - return err - } - targetName, _ := cmd.Flags().GetString("target") roleID, _ := cmd.Flags().GetString("role-id") roleName, _ := cmd.Flags().GetString("role") @@ -231,28 +227,41 @@ func runRequestSubmit(cmd *cobra.Command, svc accessRequestService) error { ctx, cancel := context.WithTimeout(cmd.Context(), apiTimeout) defer cancel() + // 1. Workspace workspace, err := resolveSubmitTargetFn(ctx, provider, targetName) if err != nil { return err } - // Resolve role ID + // 2. Role if roleID == "" { if !ui.IsInteractive() { return errors.New("non-interactive mode requires --role-id") } - stdio := survey.WithStdio(os.Stdin, os.Stderr, os.Stderr) - if err := survey.AskOne(&survey.Input{Message: "Role ID:"}, &roleID, - survey.WithValidator(survey.Required), stdio); err != nil { - return err + resolvedID, resolvedName, err := resolveRoleFn(ctx, workspace) + if err != nil { + return fmt.Errorf("%w; retry with --role-id to bypass interactive role selection", err) + } + roleID = resolvedID + if roleName == "" { + roleName = resolvedName } } - // Resolve role name (optional, defaults to role ID if not provided) if roleName == "" { roleName = roleID } + // 3–8. Reason, priority, timezone, date, start time, end time + fields, err := resolveSubmitFields(cmd) + if err != nil { + return err + } + + if err := validateSubmitFields(fields); err != nil { + return err + } + // Summary before submission if !isJSONOutput() { fmt.Fprintf(cmd.ErrOrStderr(), "\nWorkspace: %s\n", workspace.WorkspaceName) @@ -476,6 +485,86 @@ func buildRequestDetails(ws *submitWorkspace, roleID, roleName string, f *submit } } +// resolveSubmitRole fetches on-demand roles for the selected workspace and +// prompts the user to choose one. Returns the role's resource_id and resource_name. +func resolveSubmitRole(ctx context.Context, ws *submitWorkspace) (roleID, roleName string, _ error) { + req, err := buildOnDemandRequest(ws) + if err != nil { + return "", "", err + } + + _, scaSvc, _, err := bootstrapSCAService() + if err != nil { + return "", "", fmt.Errorf("failed to bootstrap SCA service: %w", err) + } + + cfg, _, _ := config.LoadDefaultWithPath() + if cfg == nil { + cfg = config.DefaultConfig() + } + + var lister cache.OnDemandRolesLister = scaSvc + cacheDir, cacheErr := cache.CacheDir() + if cacheErr == nil { + ttl := config.ParseCacheTTL(cfg) + store := cache.NewStore(cacheDir, ttl) + lister = cache.NewCachedRolesLister(scaSvc, store, common.GetLogger("grant", -1)) + } + + fetchCtx, cancel := context.WithTimeout(ctx, apiTimeout) + defer cancel() + + roles, err := lister.ListOnDemandResources(fetchCtx, req) + if err != nil { + return "", "", fmt.Errorf("failed to fetch on-demand roles: %w", err) + } + + selected, err := ui.SelectRole(roles) + if err != nil { + return "", "", err + } + return selected.ResourceID, selected.ResourceName, nil +} + +// buildOnDemandRequest maps a workspace into the on-demand discovery request. +// Only directory / AWS account / management-group workspaces are supported in v1. +func buildOnDemandRequest(ws *submitWorkspace) (models.OnDemandRequest, error) { + wt := strings.ToUpper(string(ws.WorkspaceType)) + switch wt { + case "DIRECTORY": + return models.OnDemandRequest{ + WorkspaceID: ws.WorkspaceID, + PlatformName: "azure_ad", + OrgID: ws.OrganizationID, + }, nil + case "ACCOUNT": + return models.OnDemandRequest{ + WorkspaceID: ws.WorkspaceID, + PlatformName: "aws", + OrgID: ws.OrganizationID, + }, nil + case "MANAGEMENT_GROUP": + return models.OnDemandRequest{ + WorkspaceID: ws.WorkspaceID, + PlatformName: "azure_resource", + OrgID: ws.OrganizationID, + ResourceType: "management_group", + Ancestors: []string{ + "/" + ws.OrganizationID, + "/" + ws.WorkspaceID, + }, + }, nil + case "SUBSCRIPTION", "RESOURCE_GROUP", "RESOURCE": + return models.OnDemandRequest{}, fmt.Errorf( + "interactive role selection not supported for %s workspaces yet; use --role-id (see docs/handoff-ondemand-roles-poc.md)", + strings.ToLower(wt)) + default: + return models.OnDemandRequest{}, fmt.Errorf( + "interactive role selection not supported for workspace type %q; use --role-id", + ws.WorkspaceType) + } +} + func validateSubmitFields(f *submitFields) error { if f.reason == "" { return errors.New("--reason is required") diff --git a/cmd/request_test.go b/cmd/request_test.go index 6429345..79875ed 100644 --- a/cmd/request_test.go +++ b/cmd/request_test.go @@ -645,6 +645,122 @@ func TestRunRequestSubmit_MissingFlags_NonInteractive(t *testing.T) { } } +func TestBuildOnDemandRequest_UnsupportedType(t *testing.T) { + tests := []struct { + name string + wt models.WorkspaceType + }{ + {"subscription", models.WorkspaceTypeSubscription}, + {"resource_group", models.WorkspaceType("RESOURCE_GROUP")}, + {"resource", models.WorkspaceType("RESOURCE")}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := buildOnDemandRequest(&submitWorkspace{ + WorkspaceID: "ws-1", + WorkspaceType: tt.wt, + OrganizationID: "org-1", + }) + if err == nil { + t.Fatalf("expected error for %s", tt.name) + } + if !strings.Contains(err.Error(), "not supported") { + t.Errorf("error should mention not supported: %v", err) + } + if !strings.Contains(err.Error(), "--role-id") { + t.Errorf("error should point to --role-id: %v", err) + } + }) + } +} + +func TestBuildOnDemandRequest_SupportedTypes(t *testing.T) { + tests := []struct { + name string + wt models.WorkspaceType + wsID string + orgID string + wantPlatform string + wantAnces int + }{ + {"directory", models.WorkspaceType("DIRECTORY"), "dir-1", "dir-1", "azure_ad", 0}, + {"account", models.WorkspaceType("ACCOUNT"), "123", "123", "aws", 0}, + {"management_group", models.WorkspaceType("MANAGEMENT_GROUP"), "providers/Microsoft.Management/managementGroups/root", "dir-456", "azure_resource", 2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := buildOnDemandRequest(&submitWorkspace{ + WorkspaceID: tt.wsID, + WorkspaceType: tt.wt, + OrganizationID: tt.orgID, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.PlatformName != tt.wantPlatform { + t.Errorf("platform: got %q want %q", req.PlatformName, tt.wantPlatform) + } + if len(req.Ancestors) != tt.wantAnces { + t.Errorf("ancestors: got %d want %d", len(req.Ancestors), tt.wantAnces) + } + }) + } +} + +func TestRunRequestSubmit_InteractiveRoleSelection(t *testing.T) { + origTarget := resolveSubmitTargetFn + origRole := resolveRoleFn + origTTY := ui.IsTerminalFunc + defer func() { + resolveSubmitTargetFn = origTarget + resolveRoleFn = origRole + ui.IsTerminalFunc = origTTY + }() + ui.IsTerminalFunc = func(fd uintptr) bool { return true } + + resolveSubmitTargetFn = func(_ context.Context, _, _ string) (*submitWorkspace, error) { + return &submitWorkspace{ + WorkspaceName: "Dir", + WorkspaceID: "dir-1", + WorkspaceType: models.WorkspaceType("DIRECTORY"), + CSP: models.CSPAzure, + OrganizationID: "dir-1", + }, nil + } + resolveRoleFn = func(_ context.Context, ws *submitWorkspace) (string, string, error) { + if ws.WorkspaceID != "dir-1" { + t.Errorf("expected ws dir-1, got %s", ws.WorkspaceID) + } + return "arn:aws:iam::1:role/Admin", "Admin", nil + } + + svc := &mockAccessRequestService{ + submitResult: &wfmodels.AccessRequest{RequestID: "req-x", RequestState: wfmodels.RequestStatePending}, + } + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + output, err := executeCommand(root, "request", "submit", + "--reason", "test", "--date", "2026-04-21", + "--timezone", "UTC", "--from", "09:00", "--to", "17:00", + "--yes") + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, output) + } + + submitted := svc.submitRequest + if submitted == nil { + t.Fatal("expected a submitted request") + } + if submitted.RequestDetails["roleId"] != "arn:aws:iam::1:role/Admin" { + t.Errorf("roleId: got %v", submitted.RequestDetails["roleId"]) + } + if submitted.RequestDetails["roleName"] != "Admin" { + t.Errorf("roleName: got %v", submitted.RequestDetails["roleName"]) + } +} + func TestResolveSubmitFields_Interactive(t *testing.T) { originalPrompt := submitPromptFn defer func() { submitPromptFn = originalPrompt }() diff --git a/cmd/test_mocks.go b/cmd/test_mocks.go index 8a94395..5e01aa3 100644 --- a/cmd/test_mocks.go +++ b/cmd/test_mocks.go @@ -249,6 +249,7 @@ type mockAccessRequestService struct { getErr error submitResult *wfmodels.AccessRequest submitErr error + submitRequest *wfmodels.SubmitAccessRequest cancelResult *wfmodels.AccessRequest cancelErr error finalizeResult *wfmodels.AccessRequest @@ -263,7 +264,8 @@ func (m *mockAccessRequestService) GetRequest(_ context.Context, _ string) (*wfm return m.getResult, m.getErr } -func (m *mockAccessRequestService) SubmitRequest(_ context.Context, _ *wfmodels.SubmitAccessRequest) (*wfmodels.AccessRequest, error) { +func (m *mockAccessRequestService) SubmitRequest(_ context.Context, req *wfmodels.SubmitAccessRequest) (*wfmodels.AccessRequest, error) { + m.submitRequest = req return m.submitResult, m.submitErr } diff --git a/internal/cache/cached_roles.go b/internal/cache/cached_roles.go new file mode 100644 index 0000000..a10dc7e --- /dev/null +++ b/internal/cache/cached_roles.go @@ -0,0 +1,61 @@ +package cache + +import ( + "context" + "crypto/sha256" + "encoding/hex" + + scamodels "github.com/aaearon/grant-cli/internal/sca/models" +) + +// OnDemandRolesLister mirrors the service method for on-demand role discovery. +type OnDemandRolesLister interface { + ListOnDemandResources(ctx context.Context, req scamodels.OnDemandRequest) ([]scamodels.OnDemandResource, error) +} + +// CachedRolesLister decorates an OnDemandRolesLister with file-based caching. +type CachedRolesLister struct { + inner OnDemandRolesLister + store *Store + log Logger +} + +// NewCachedRolesLister creates a caching decorator for on-demand role discovery. +func NewCachedRolesLister(inner OnDemandRolesLister, store *Store, log Logger) *CachedRolesLister { + if log == nil { + log = nopLogger{} + } + return &CachedRolesLister{inner: inner, store: store, log: log} +} + +// ListOnDemandResources checks the cache first, then falls through to the inner lister. +func (c *CachedRolesLister) ListOnDemandResources(ctx context.Context, req scamodels.OnDemandRequest) ([]scamodels.OnDemandResource, error) { + key := onDemandRolesCacheKey(req.PlatformName, req.WorkspaceID) + + var cached []scamodels.OnDemandResource + if Get(c.store, key, &cached) { + c.log.Info("Cache hit for on-demand roles (%s, %d roles)", req.PlatformName, len(cached)) + return cached, nil + } + c.log.Info("Cache miss for on-demand roles (%s), fetching from API", req.PlatformName) + + roles, err := c.inner.ListOnDemandResources(ctx, req) + if err != nil { + return nil, err + } + + if err := Set(c.store, key, roles); err != nil { + c.log.Info("Cache write failed for on-demand roles (%s): %v", req.PlatformName, err) + } else { + c.log.Info("Cached on-demand roles (%s, %d roles)", req.PlatformName, len(roles)) + } + return roles, nil +} + +// onDemandRolesCacheKey builds a cache key from platformName and a sha256 hash of workspaceID. +// Hashing eliminates unsafe filesystem characters from workspace identifiers like +// "/providers/Microsoft.Management/managementGroups/...". +func onDemandRolesCacheKey(platformName, workspaceID string) string { + sum := sha256.Sum256([]byte(workspaceID)) + return "ondemand_roles_" + platformName + "_" + hex.EncodeToString(sum[:]) +} diff --git a/internal/cache/cached_roles_test.go b/internal/cache/cached_roles_test.go new file mode 100644 index 0000000..8623121 --- /dev/null +++ b/internal/cache/cached_roles_test.go @@ -0,0 +1,111 @@ +package cache + +import ( + "context" + "errors" + "testing" + "time" + + scamodels "github.com/aaearon/grant-cli/internal/sca/models" +) + +type fakeRolesLister struct { + callCount int + roles []scamodels.OnDemandResource + err error +} + +func (f *fakeRolesLister) ListOnDemandResources(_ context.Context, _ scamodels.OnDemandRequest) ([]scamodels.OnDemandResource, error) { + f.callCount++ + if f.err != nil { + return nil, f.err + } + return f.roles, nil +} + +func newTestStore(t *testing.T) *Store { + t.Helper() + return NewStore(t.TempDir(), time.Hour) +} + +func TestCachedRolesLister_MissThenHit(t *testing.T) { + fake := &fakeRolesLister{roles: []scamodels.OnDemandResource{{ResourceID: "r1", ResourceName: "Role 1"}}} + cached := NewCachedRolesLister(fake, newTestStore(t), nil) + + req := scamodels.OnDemandRequest{WorkspaceID: "ws-1", PlatformName: "azure_ad", OrgID: "ws-1"} + + roles1, err := cached.ListOnDemandResources(t.Context(), req) + if err != nil { + t.Fatalf("first call: %v", err) + } + if len(roles1) != 1 { + t.Errorf("first call: expected 1 role, got %d", len(roles1)) + } + + roles2, err := cached.ListOnDemandResources(t.Context(), req) + if err != nil { + t.Fatalf("second call: %v", err) + } + if len(roles2) != 1 { + t.Errorf("second call: expected 1 role, got %d", len(roles2)) + } + if fake.callCount != 1 { + t.Errorf("expected inner to be called once, got %d", fake.callCount) + } +} + +func TestCachedRolesLister_DifferentWorkspacesDistinct(t *testing.T) { + fake := &fakeRolesLister{roles: []scamodels.OnDemandResource{{ResourceID: "r1"}}} + cached := NewCachedRolesLister(fake, newTestStore(t), nil) + + reqA := scamodels.OnDemandRequest{WorkspaceID: "ws-A", PlatformName: "azure_ad", OrgID: "ws-A"} + reqB := scamodels.OnDemandRequest{WorkspaceID: "ws-B", PlatformName: "azure_ad", OrgID: "ws-B"} + + if _, err := cached.ListOnDemandResources(t.Context(), reqA); err != nil { + t.Fatal(err) + } + if _, err := cached.ListOnDemandResources(t.Context(), reqB); err != nil { + t.Fatal(err) + } + if fake.callCount != 2 { + t.Errorf("expected 2 inner calls for distinct workspaces, got %d", fake.callCount) + } +} + +func TestCachedRolesLister_DifferentPlatformsDistinct(t *testing.T) { + fake := &fakeRolesLister{roles: []scamodels.OnDemandResource{{ResourceID: "r1"}}} + cached := NewCachedRolesLister(fake, newTestStore(t), nil) + + reqAD := scamodels.OnDemandRequest{WorkspaceID: "same-id", PlatformName: "azure_ad", OrgID: "same-id"} + reqAWS := scamodels.OnDemandRequest{WorkspaceID: "same-id", PlatformName: "aws", OrgID: "same-id"} + + if _, err := cached.ListOnDemandResources(t.Context(), reqAD); err != nil { + t.Fatal(err) + } + if _, err := cached.ListOnDemandResources(t.Context(), reqAWS); err != nil { + t.Fatal(err) + } + if fake.callCount != 2 { + t.Errorf("expected 2 inner calls for distinct platforms with same workspaceID, got %d", fake.callCount) + } +} + +func TestCachedRolesLister_InnerError(t *testing.T) { + fake := &fakeRolesLister{err: errors.New("api failure")} + cached := NewCachedRolesLister(fake, newTestStore(t), nil) + + req := scamodels.OnDemandRequest{WorkspaceID: "ws-1", PlatformName: "aws", OrgID: "ws-1"} + _, err := cached.ListOnDemandResources(t.Context(), req) + if err == nil { + t.Fatal("expected error from inner") + } +} + +func TestOnDemandRolesCacheKey_HandlesSlashes(t *testing.T) { + key := onDemandRolesCacheKey("azure_resource", "/providers/Microsoft.Management/managementGroups/abc") + for _, c := range key { + if c == '/' { + t.Errorf("cache key should not contain slashes: %s", key) + } + } +} diff --git a/internal/sca/models/ondemand.go b/internal/sca/models/ondemand.go new file mode 100644 index 0000000..2577d30 --- /dev/null +++ b/internal/sca/models/ondemand.go @@ -0,0 +1,22 @@ +package models + +// OnDemandResource is a role returned by the on-demand role discovery endpoints. +type OnDemandResource struct { + ResourceID string `json:"resource_id"` + ResourceName string `json:"resource_name"` + Provider string `json:"provider"` + Custom bool `json:"custom"` + Description string `json:"description,omitempty"` + RoleType int `json:"role_type,omitempty"` + AssignableScope []string `json:"assignable_scope,omitempty"` + AssignableWorkspaceType string `json:"assignable_workspace_type,omitempty"` +} + +// OnDemandRequest describes an on-demand roles lookup for a single workspace. +type OnDemandRequest struct { + WorkspaceID string + PlatformName string + OrgID string + ResourceType string + Ancestors []string +} diff --git a/internal/sca/models/ondemand_test.go b/internal/sca/models/ondemand_test.go new file mode 100644 index 0000000..dfde0a6 --- /dev/null +++ b/internal/sca/models/ondemand_test.go @@ -0,0 +1,89 @@ +package models + +import ( + "encoding/json" + "testing" +) + +func TestOnDemandResource_JSONRoundtrip(t *testing.T) { + tests := []struct { + name string + raw string + want OnDemandResource + }{ + { + name: "azure_ad GUID resource_id", + raw: `{ + "resource_id": "62e90394-69f5-4237-9190-012177145e10", + "resource_name": "Global Administrator", + "provider": "azure_ad", + "custom": false, + "role_type": 0 + }`, + want: OnDemandResource{ + ResourceID: "62e90394-69f5-4237-9190-012177145e10", + ResourceName: "Global Administrator", + Provider: "azure_ad", + Custom: false, + RoleType: 0, + }, + }, + { + name: "aws ARN resource_id", + raw: `{ + "resource_id": "arn:aws:iam::547375531250:role/AdministratorAccess", + "resource_name": "AdministratorAccess", + "provider": "aws", + "custom": false, + "description": "Admin role", + "role_type": 1 + }`, + want: OnDemandResource{ + ResourceID: "arn:aws:iam::547375531250:role/AdministratorAccess", + ResourceName: "AdministratorAccess", + Provider: "aws", + Custom: false, + Description: "Admin role", + RoleType: 1, + }, + }, + { + name: "azure_resource custom scoped role", + raw: `{ + "resource_id": "/providers/Microsoft.Authorization/roleDefinitions/09600ded-d1e9-4c99-a3da-704f2df23384", + "resource_name": "CyberArk-SIA-Role", + "provider": "azure_resource", + "custom": true, + "assignable_scope": ["/providers/Microsoft.Management/managementGroups/abc"] + }`, + want: OnDemandResource{ + ResourceID: "/providers/Microsoft.Authorization/roleDefinitions/09600ded-d1e9-4c99-a3da-704f2df23384", + ResourceName: "CyberArk-SIA-Role", + Provider: "azure_resource", + Custom: true, + AssignableScope: []string{"/providers/Microsoft.Management/managementGroups/abc"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got OnDemandResource + if err := json.Unmarshal([]byte(tt.raw), &got); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + if got.ResourceID != tt.want.ResourceID { + t.Errorf("ResourceID: got %q want %q", got.ResourceID, tt.want.ResourceID) + } + if got.ResourceName != tt.want.ResourceName { + t.Errorf("ResourceName: got %q want %q", got.ResourceName, tt.want.ResourceName) + } + if got.Provider != tt.want.Provider { + t.Errorf("Provider: got %q want %q", got.Provider, tt.want.Provider) + } + if got.Custom != tt.want.Custom { + t.Errorf("Custom: got %v want %v", got.Custom, tt.want.Custom) + } + }) + } +} diff --git a/internal/sca/service.go b/internal/sca/service.go index 812a136..80e3f78 100644 --- a/internal/sca/service.go +++ b/internal/sca/service.go @@ -293,6 +293,81 @@ func (s *SCAAccessService) ListGroupsEligibility(ctx context.Context, csp models return &models.GroupsEligibilityResponse{Response: items, Total: total}, nil } +// ListOnDemandResources fetches the list of roles available for on-demand access for the given workspace. +// Dispatches on PlatformName: azure_ad and aws use GET /api/cloud/resources/ondemand, +// azure_resource uses POST /api/cloud/cloud-roles/ondemand with a body including resourceType and ancestors. +func (s *SCAAccessService) ListOnDemandResources(ctx context.Context, req models.OnDemandRequest) ([]models.OnDemandResource, error) { + switch req.PlatformName { + case "azure_ad", "aws": + return s.getOnDemandResources(ctx, req) + case "azure_resource": + return s.postOnDemandResources(ctx, req) + default: + return nil, fmt.Errorf("unsupported platform: %s", req.PlatformName) + } +} + +func (s *SCAAccessService) getOnDemandResources(ctx context.Context, req models.OnDemandRequest) ([]models.OnDemandResource, error) { + searchObj := map[string]interface{}{ + "workspaceId": req.WorkspaceID, + "pageNumber": -1, + "pageSize": -1, + "platformName": req.PlatformName, + "org_id": req.OrgID, + "target_category": "cloud_console", + } + searchJSON, err := json.Marshal(searchObj) + if err != nil { + return nil, fmt.Errorf("failed to encode on-demand search params: %w", err) + } + params := map[string]string{"search": string(searchJSON)} + + resp, err := s.httpClient.Get(ctx, "/api/cloud/resources/ondemand", params) + if err != nil { + return nil, fmt.Errorf("failed to list on-demand resources: %w", err) + } + defer resp.Body.Close() + + if err := checkResponse(resp, "on-demand resources request"); err != nil { + return nil, err + } + + var result []models.OnDemandResource + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode on-demand resources response: %w", err) + } + return result, nil +} + +func (s *SCAAccessService) postOnDemandResources(ctx context.Context, req models.OnDemandRequest) ([]models.OnDemandResource, error) { + body := map[string]interface{}{ + "workspaceId": req.WorkspaceID, + "resourceType": req.ResourceType, + "pageNumber": -1, + "pageSize": -1, + "platformName": req.PlatformName, + "org_id": req.OrgID, + "ancestors": req.Ancestors, + "target_category": "cloud_console", + } + + resp, err := s.httpClient.Post(ctx, "/api/cloud/cloud-roles/ondemand", body) + if err != nil { + return nil, fmt.Errorf("failed to list on-demand cloud-roles: %w", err) + } + defer resp.Body.Close() + + if err := checkResponse(resp, "on-demand cloud-roles request"); err != nil { + return nil, err + } + + var result []models.OnDemandResource + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode on-demand cloud-roles response: %w", err) + } + return result, nil +} + // ElevateGroups requests JIT elevation for the specified Entra ID groups. // POST /api/access/elevate/groups func (s *SCAAccessService) ElevateGroups(ctx context.Context, req *models.GroupsElevateRequest) (*models.GroupsElevateResponse, error) { diff --git a/internal/sca/service_test.go b/internal/sca/service_test.go index fa2c9a3..7c6e47e 100644 --- a/internal/sca/service_test.go +++ b/internal/sca/service_test.go @@ -19,6 +19,8 @@ type mockHTTPClient struct { postError error // getFunc, when set, overrides getResponse/getError for dynamic responses. getFunc func(ctx context.Context, route string, params interface{}) (*http.Response, error) + // postFunc, when set, overrides postResponse/postError for dynamic responses. + postFunc func(ctx context.Context, route string, body interface{}) (*http.Response, error) } func (m *mockHTTPClient) Get(ctx context.Context, route string, params interface{}) (*http.Response, error) { @@ -32,6 +34,9 @@ func (m *mockHTTPClient) Get(ctx context.Context, route string, params interface } func (m *mockHTTPClient) Post(ctx context.Context, route string, body interface{}) (*http.Response, error) { + if m.postFunc != nil { + return m.postFunc(ctx, route, body) + } if m.postError != nil { return nil, m.postError } @@ -875,6 +880,181 @@ func TestListEligibility_PaginationMaxPagesExceeded(t *testing.T) { } } +func TestListOnDemandResources_AzureAD(t *testing.T) { + var capturedParams interface{} + var capturedRoute string + + roles := []models.OnDemandResource{ + {ResourceID: "role-1", ResourceName: "Global Administrator", Provider: "azure_ad", Custom: false}, + } + body, _ := json.Marshal(roles) + + mock := &mockHTTPClient{ + getFunc: func(_ context.Context, route string, params interface{}) (*http.Response, error) { + capturedRoute = route + capturedParams = params + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(string(body))), + }, nil + }, + } + + svc := &SCAAccessService{httpClient: mock} + result, err := svc.ListOnDemandResources(t.Context(), models.OnDemandRequest{ + WorkspaceID: "dir-123", + PlatformName: "azure_ad", + OrgID: "dir-123", + }) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result) != 1 { + t.Errorf("expected 1 role, got %d", len(result)) + } + if capturedRoute != "/api/cloud/resources/ondemand" { + t.Errorf("route: got %q", capturedRoute) + } + p, ok := capturedParams.(map[string]string) + if !ok { + t.Fatalf("expected map[string]string params, got %T", capturedParams) + } + if !strings.Contains(p["search"], `"platformName":"azure_ad"`) { + t.Errorf("search param missing platformName=azure_ad: %s", p["search"]) + } + if !strings.Contains(p["search"], `"workspaceId":"dir-123"`) { + t.Errorf("search param missing workspaceId: %s", p["search"]) + } +} + +func TestListOnDemandResources_AWS(t *testing.T) { + roles := []models.OnDemandResource{ + {ResourceID: "arn:aws:iam::123:role/Admin", ResourceName: "Admin", Provider: "aws"}, + } + body, _ := json.Marshal(roles) + + mock := &mockHTTPClient{ + getResponse: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(string(body))), + }, + } + + svc := &SCAAccessService{httpClient: mock} + result, err := svc.ListOnDemandResources(t.Context(), models.OnDemandRequest{ + WorkspaceID: "123", + PlatformName: "aws", + OrgID: "123", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result) != 1 || result[0].ResourceID != "arn:aws:iam::123:role/Admin" { + t.Errorf("unexpected result: %+v", result) + } +} + +func TestListOnDemandResources_AzureResource_POST(t *testing.T) { + var capturedRoute string + var capturedBody interface{} + + roles := []models.OnDemandResource{ + {ResourceID: "/providers/Microsoft.Authorization/roleDefinitions/abc", ResourceName: "Contributor", Provider: "azure_resource"}, + } + body, _ := json.Marshal(roles) + + mock := &mockHTTPClient{ + postFunc: func(_ context.Context, route string, b interface{}) (*http.Response, error) { + capturedRoute = route + capturedBody = b + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(string(body))), + }, nil + }, + } + + svc := &SCAAccessService{httpClient: mock} + result, err := svc.ListOnDemandResources(t.Context(), models.OnDemandRequest{ + WorkspaceID: "providers/Microsoft.Management/managementGroups/root", + PlatformName: "azure_resource", + OrgID: "dir-456", + ResourceType: "management_group", + Ancestors: []string{"/dir-456", "/providers/Microsoft.Management/managementGroups/root"}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result) != 1 { + t.Errorf("expected 1 role, got %d", len(result)) + } + if capturedRoute != "/api/cloud/cloud-roles/ondemand" { + t.Errorf("route: got %q", capturedRoute) + } + bm, ok := capturedBody.(map[string]interface{}) + if !ok { + t.Fatalf("expected map body, got %T", capturedBody) + } + if bm["resourceType"] != "management_group" { + t.Errorf("resourceType: got %v", bm["resourceType"]) + } + anc, ok := bm["ancestors"].([]string) + if !ok || len(anc) != 2 { + t.Errorf("ancestors: got %v", bm["ancestors"]) + } +} + +func TestListOnDemandResources_UnknownPlatform(t *testing.T) { + mock := &mockHTTPClient{} + svc := &SCAAccessService{httpClient: mock} + _, err := svc.ListOnDemandResources(t.Context(), models.OnDemandRequest{PlatformName: "gcp"}) + if err == nil { + t.Fatal("expected error for unknown platform") + } + if !strings.Contains(err.Error(), "unsupported platform") { + t.Errorf("got: %v", err) + } +} + +func TestListOnDemandResources_Unauthorized(t *testing.T) { + mock := &mockHTTPClient{ + getResponse: &http.Response{ + StatusCode: http.StatusUnauthorized, + Body: io.NopCloser(strings.NewReader(`{"error":"unauthorized"}`)), + }, + } + svc := &SCAAccessService{httpClient: mock} + _, err := svc.ListOnDemandResources(t.Context(), models.OnDemandRequest{ + WorkspaceID: "w", PlatformName: "azure_ad", OrgID: "o", + }) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "401") { + t.Errorf("got: %v", err) + } +} + +func TestListOnDemandResources_EmptyArray(t *testing.T) { + mock := &mockHTTPClient{ + getResponse: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("[]")), + }, + } + svc := &SCAAccessService{httpClient: mock} + result, err := svc.ListOnDemandResources(t.Context(), models.OnDemandRequest{ + WorkspaceID: "w", PlatformName: "aws", OrgID: "o", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result) != 0 { + t.Errorf("expected empty slice, got %d", len(result)) + } +} + func TestListEligibility_Pagination_TotalFromFirstPage(t *testing.T) { token := "page2" callCount := 0 diff --git a/internal/ui/role_selector.go b/internal/ui/role_selector.go new file mode 100644 index 0000000..aafcda4 --- /dev/null +++ b/internal/ui/role_selector.go @@ -0,0 +1,75 @@ +package ui + +import ( + "errors" + "fmt" + "os" + "sort" + "strings" + + "github.com/Iilun/survey/v2" + "github.com/aaearon/grant-cli/internal/sca/models" +) + +// FormatRoleOption formats an on-demand role into a display string. +// Custom roles are marked with a [custom] tag. Descriptions are truncated to 70 chars. +func FormatRoleOption(r models.OnDemandResource) string { + name := r.ResourceName + if r.Custom { + name += " [custom]" + } + if r.Description != "" { + desc := r.Description + if len(desc) > 70 { + desc = desc[:70] + } + return fmt.Sprintf("%s — %s", name, desc) + } + return name +} + +// BuildRoleOptions builds display strings in custom-first-then-alphabetic order. +// Returns a parallel slice of roles that matches the display strings by index. +func BuildRoleOptions(roles []models.OnDemandResource) ([]string, []models.OnDemandResource) { + sorted := make([]models.OnDemandResource, len(roles)) + copy(sorted, roles) + sort.SliceStable(sorted, func(i, j int) bool { + if sorted[i].Custom != sorted[j].Custom { + return sorted[i].Custom + } + return strings.ToLower(sorted[i].ResourceName) < strings.ToLower(sorted[j].ResourceName) + }) + opts := make([]string, len(sorted)) + for i, r := range sorted { + opts[i] = FormatRoleOption(r) + } + return opts, sorted +} + +// SelectRole prompts the user to pick a role from the list. Uses the selected +// index (not display text) to recover the role, so duplicate display strings are safe. +func SelectRole(roles []models.OnDemandResource) (*models.OnDemandResource, error) { + if !IsInteractive() { + return nil, fmt.Errorf("%w; use --role-id for non-interactive mode", ErrNotInteractive) + } + if len(roles) == 0 { + return nil, errors.New("no on-demand roles available for this workspace; use --role-id if this is unexpected") + } + + options, sorted := BuildRoleOptions(roles) + + var selectedIdx int + prompt := &survey.Select{ + Message: "Select a role:", + Options: options, + PageSize: 15, + Filter: nil, + } + if err := survey.AskOne(prompt, &selectedIdx, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)); err != nil { + return nil, fmt.Errorf("role selection failed: %w", err) + } + if selectedIdx < 0 || selectedIdx >= len(sorted) { + return nil, fmt.Errorf("invalid role selection index %d", selectedIdx) + } + return &sorted[selectedIdx], nil +} diff --git a/internal/ui/role_selector_test.go b/internal/ui/role_selector_test.go new file mode 100644 index 0000000..821d5c7 --- /dev/null +++ b/internal/ui/role_selector_test.go @@ -0,0 +1,93 @@ +package ui + +import ( + "errors" + "strings" + "testing" + + "github.com/aaearon/grant-cli/internal/sca/models" +) + +func TestFormatRoleOption(t *testing.T) { + tests := []struct { + name string + in models.OnDemandResource + want string + }{ + {"plain", models.OnDemandResource{ResourceName: "Reader"}, "Reader"}, + {"custom", models.OnDemandResource{ResourceName: "CyberArk-SIA", Custom: true}, "CyberArk-SIA [custom]"}, + {"description", models.OnDemandResource{ResourceName: "Admin", Description: "Full access"}, "Admin — Full access"}, + { + name: "description truncated", + in: models.OnDemandResource{ResourceName: "X", Description: strings.Repeat("a", 100)}, + want: "X — " + strings.Repeat("a", 70), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := FormatRoleOption(tt.in); got != tt.want { + t.Errorf("got %q want %q", got, tt.want) + } + }) + } +} + +func TestBuildRoleOptions_CustomFirstThenAlpha(t *testing.T) { + roles := []models.OnDemandResource{ + {ResourceName: "Reader"}, + {ResourceName: "Custom-B", Custom: true}, + {ResourceName: "Contributor"}, + {ResourceName: "Custom-A", Custom: true}, + } + opts, sorted := BuildRoleOptions(roles) + + if len(opts) != 4 { + t.Fatalf("expected 4 options, got %d", len(opts)) + } + if sorted[0].ResourceName != "Custom-A" { + t.Errorf("expected first = Custom-A, got %s", sorted[0].ResourceName) + } + if sorted[1].ResourceName != "Custom-B" { + t.Errorf("expected second = Custom-B, got %s", sorted[1].ResourceName) + } + if sorted[2].ResourceName != "Contributor" { + t.Errorf("expected third = Contributor, got %s", sorted[2].ResourceName) + } + if sorted[3].ResourceName != "Reader" { + t.Errorf("expected fourth = Reader, got %s", sorted[3].ResourceName) + } + if !strings.Contains(opts[0], "[custom]") { + t.Errorf("custom role option should include [custom] marker: %s", opts[0]) + } +} + +func TestSelectRole_NonInteractive(t *testing.T) { + orig := IsTerminalFunc + defer func() { IsTerminalFunc = orig }() + IsTerminalFunc = func(fd uintptr) bool { return false } + + _, err := SelectRole([]models.OnDemandResource{{ResourceID: "r1", ResourceName: "Reader"}}) + if err == nil { + t.Fatal("expected error in non-interactive mode") + } + if !errors.Is(err, ErrNotInteractive) { + t.Errorf("expected ErrNotInteractive, got %v", err) + } + if !strings.Contains(err.Error(), "--role-id") { + t.Errorf("error should suggest --role-id: %v", err) + } +} + +func TestSelectRole_EmptyList(t *testing.T) { + orig := IsTerminalFunc + defer func() { IsTerminalFunc = orig }() + IsTerminalFunc = func(fd uintptr) bool { return true } + + _, err := SelectRole(nil) + if err == nil { + t.Fatal("expected error for empty list") + } + if !strings.Contains(err.Error(), "no on-demand roles") { + t.Errorf("got: %v", err) + } +} From 426af27d3082be217397adff6e49d3e797f38b65 Mon Sep 17 00:00:00 2001 From: Tim Schindler Date: Mon, 20 Apr 2026 20:07:13 +0200 Subject: [PATCH 8/9] feat: interactive request picker for cancel/approve/reject/get When is omitted in a terminal, each command now shows a fuzzy-filterable picker scoped to actionable requests: - cancel: STARTING/RUNNING/PENDING requests you created (role=CREATOR) - approve/reject: PENDING requests assigned to you (role=APPROVER) - get: all requests (no filter) Positional arg path is unchanged. Non-TTY invocation (pipe, CI) returns ErrNotInteractive with a hint to pass the ID or run `grant request list`. Adds internal/ui/request_selector.go (Format/Build/Select trio) and cmd/request_picker.go (pickerScope + resolveRequestIDFn injectable var). --- CHANGELOG.md | 1 + CLAUDE.md | 7 +- cmd/request_cancel.go | 22 ++- cmd/request_finalize.go | 35 +++- cmd/request_get.go | 20 ++- cmd/request_picker.go | 50 ++++++ cmd/request_picker_test.go | 254 +++++++++++++++++++++++++++ internal/ui/request_selector.go | 93 ++++++++++ internal/ui/request_selector_test.go | 93 ++++++++++ 9 files changed, 560 insertions(+), 15 deletions(-) create mode 100644 cmd/request_picker.go create mode 100644 cmd/request_picker_test.go create mode 100644 internal/ui/request_selector.go create mode 100644 internal/ui/request_selector_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 3082ff4..0b76491 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ All notable changes to this project will be documented in this file. - Supported workspace types: `DIRECTORY` (azure_ad), `ACCOUNT` (aws), `MANAGEMENT_GROUP` (azure_resource) - Other Azure-resource scopes (subscription, resource group, resource) still require `--role-id` until validated - Roles cached in `~/.grant/cache/ondemand_roles__.json` (4h TTL) +- Interactive request picker for `grant request cancel`, `approve`, `reject`, and `get` — omit the `` positional argument in a terminal to pick from a scoped, fuzzy-filterable list (cancel: open requests you created; approve/reject: pending requests assigned to you; get: any request). Non-TTY invocation still requires the positional argument. ## [0.6.1] - 2026-04-08 diff --git a/CLAUDE.md b/CLAUDE.md index 7da39a1..3363e78 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -78,9 +78,10 @@ Custom `SCAAccessService` follows SDK conventions: - Interactive role selection supports `DIRECTORY`, `ACCOUNT` (AWS), and `MANAGEMENT_GROUP` workspaces; other Azure-resource scopes (subscription, resource group, resource) require `--role-id` - On-demand role cache: `~/.grant/cache/ondemand_roles__.json` (4h TTL); no `--refresh` flag — delete manually to invalidate - `grant request list` — list access requests; flags: `--state`, `--result`, `--priority`, `--role` (CREATOR/APPROVER), `--search`, `--sort`, `--desc` -- `grant request get ` — get full request details -- `grant request cancel ` — cancel an open request; optional `--reason` -- `grant request approve ` / `grant request reject ` — finalize a request; optional `--reason` +- `grant request get [id]` — get full request details; omitting `` in a TTY opens a fuzzy-filterable picker of all your requests +- `grant request cancel [id]` — cancel an open request; optional `--reason`. Omitting `` in a TTY opens a picker scoped to STARTING/RUNNING/PENDING requests you created (role=CREATOR) +- `grant request approve [id]` / `grant request reject [id]` — finalize a request; optional `--reason`. Omitting `` in a TTY opens a picker scoped to PENDING requests assigned to you (role=APPROVER) +- Request picker: `internal/ui/request_selector.go` mirrors the role-selector Format/Build/Select quartet; `resolveRequestIDFn` in `cmd/request_picker.go` is injectable for tests. Non-TTY invocation without `` returns `ErrNotInteractive` with a hint to run `grant request list` - `grant update` — self-update binary via GitHub Releases (`rhysd/go-github-selfupdate`); guards against dev builds - `--groups` flag on root command shows only Entra ID groups in the interactive selector - `--group` / `-g` flag on root command for direct group membership elevation (`grant --group "Cloud Admins"`) diff --git a/cmd/request_cancel.go b/cmd/request_cancel.go index 8845f5c..416ce18 100644 --- a/cmd/request_cancel.go +++ b/cmd/request_cancel.go @@ -8,9 +8,10 @@ import ( func newRequestCancelCommand(svc accessRequestService) *cobra.Command { cmd := &cobra.Command{ - Use: "cancel ", + Use: "cancel [requestId]", Short: "Cancel an open access request", - Args: cobra.ExactArgs(1), + Long: "Cancel an open access request. If is omitted in a terminal, an interactive picker of open requests you created is shown.", + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { if svc == nil { bootstrapped, err := bootstrapWorkflowsService() @@ -19,7 +20,22 @@ func newRequestCancelCommand(svc accessRequestService) *cobra.Command { } svc = bootstrapped } - return runRequestCancel(cmd, args[0], svc) + requestID := "" + if len(args) > 0 { + requestID = args[0] + } + if requestID == "" { + id, err := resolveRequestIDFn(cmd.Context(), svc, pickerScope{ + filter: "((requestState eq STARTING) or (requestState eq RUNNING) or (requestState eq PENDING))", + requestRole: "CREATOR", + emptyMsg: "open requests you created", + }) + if err != nil { + return err + } + requestID = id + } + return runRequestCancel(cmd, requestID, svc) }, } diff --git a/cmd/request_finalize.go b/cmd/request_finalize.go index 106f8a3..1d1d886 100644 --- a/cmd/request_finalize.go +++ b/cmd/request_finalize.go @@ -8,9 +8,10 @@ import ( func newRequestApproveCommand(svc accessRequestService) *cobra.Command { cmd := &cobra.Command{ - Use: "approve ", + Use: "approve [requestId]", Short: "Approve an access request", - Args: cobra.ExactArgs(1), + Long: "Approve an access request. If is omitted in a terminal, an interactive picker of pending requests assigned to you is shown.", + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { if svc == nil { bootstrapped, err := bootstrapWorkflowsService() @@ -19,7 +20,11 @@ func newRequestApproveCommand(svc accessRequestService) *cobra.Command { } svc = bootstrapped } - return runFinalize(cmd, args[0], "APPROVED", svc) + requestID, err := resolveFinalizeRequestID(cmd, args, svc) + if err != nil { + return err + } + return runFinalize(cmd, requestID, "APPROVED", svc) }, } @@ -30,9 +35,10 @@ func newRequestApproveCommand(svc accessRequestService) *cobra.Command { func newRequestRejectCommand(svc accessRequestService) *cobra.Command { cmd := &cobra.Command{ - Use: "reject ", + Use: "reject [requestId]", Short: "Reject an access request", - Args: cobra.ExactArgs(1), + Long: "Reject an access request. If is omitted in a terminal, an interactive picker of pending requests assigned to you is shown.", + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { if svc == nil { bootstrapped, err := bootstrapWorkflowsService() @@ -41,7 +47,11 @@ func newRequestRejectCommand(svc accessRequestService) *cobra.Command { } svc = bootstrapped } - return runFinalize(cmd, args[0], "REJECTED", svc) + requestID, err := resolveFinalizeRequestID(cmd, args, svc) + if err != nil { + return err + } + return runFinalize(cmd, requestID, "REJECTED", svc) }, } @@ -50,6 +60,19 @@ func newRequestRejectCommand(svc accessRequestService) *cobra.Command { return cmd } +// resolveFinalizeRequestID returns the positional requestId, or falls back to +// the interactive picker scoped to approver-pending requests. +func resolveFinalizeRequestID(cmd *cobra.Command, args []string, svc accessRequestService) (string, error) { + if len(args) > 0 && args[0] != "" { + return args[0], nil + } + return resolveRequestIDFn(cmd.Context(), svc, pickerScope{ + filter: "(requestState eq PENDING)", + requestRole: "APPROVER", + emptyMsg: "pending requests assigned to you", + }) +} + func runFinalize(cmd *cobra.Command, requestID, decision string, svc accessRequestService) error { ctx := cmd.Context() diff --git a/cmd/request_get.go b/cmd/request_get.go index a61ff87..2d90bf1 100644 --- a/cmd/request_get.go +++ b/cmd/request_get.go @@ -9,9 +9,10 @@ import ( func newRequestGetCommand(svc accessRequestService) *cobra.Command { return &cobra.Command{ - Use: "get ", + Use: "get [requestId]", Short: "Get details of an access request", - Args: cobra.ExactArgs(1), + Long: "Get details of an access request. If is omitted in a terminal, an interactive picker of your access requests is shown.", + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { if svc == nil { bootstrapped, err := bootstrapWorkflowsService() @@ -20,7 +21,20 @@ func newRequestGetCommand(svc accessRequestService) *cobra.Command { } svc = bootstrapped } - return runRequestGet(cmd, args[0], svc) + requestID := "" + if len(args) > 0 { + requestID = args[0] + } + if requestID == "" { + id, err := resolveRequestIDFn(cmd.Context(), svc, pickerScope{ + emptyMsg: "access requests", + }) + if err != nil { + return err + } + requestID = id + } + return runRequestGet(cmd, requestID, svc) }, } } diff --git a/cmd/request_picker.go b/cmd/request_picker.go new file mode 100644 index 0000000..f79d464 --- /dev/null +++ b/cmd/request_picker.go @@ -0,0 +1,50 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + + "github.com/aaearon/grant-cli/internal/ui" + "github.com/aaearon/grant-cli/internal/workflows" +) + +// pickerScope describes how to scope the list of requests surfaced in the picker. +type pickerScope struct { + filter string // OData filter; empty = no filter + requestRole string // "CREATOR" | "APPROVER" | "" + emptyMsg string // e.g. "open requests you created" +} + +// resolveRequestIDFn is injectable for testing. +var resolveRequestIDFn = resolveRequestIDInteractive + +// resolveRequestIDInteractive fetches a scoped list of access requests and +// shows the interactive picker, returning the chosen request ID. +func resolveRequestIDInteractive(ctx context.Context, svc accessRequestService, scope pickerScope) (string, error) { + if !ui.IsInteractive() { + if isJSONOutput() { + return "", errors.New("request ID is required with --output json; run `grant request list --output json` to find it") + } + return "", fmt.Errorf("%w; pass the request ID as a positional argument (run `grant request list` to find it)", ui.ErrNotInteractive) + } + + items, _, err := svc.ListRequests(ctx, workflows.ListRequestsParams{ + Filter: scope.filter, + RequestRole: scope.requestRole, + Sort: "createdAt desc", + }) + if err != nil { + return "", fmt.Errorf("failed to list requests: %w", err) + } + + if len(items) == 0 { + return "", errors.New("no " + scope.emptyMsg + "; nothing to act on") + } + + chosen, err := ui.SelectRequest(items) + if err != nil { + return "", err + } + return chosen.RequestID, nil +} diff --git a/cmd/request_picker_test.go b/cmd/request_picker_test.go new file mode 100644 index 0000000..5097470 --- /dev/null +++ b/cmd/request_picker_test.go @@ -0,0 +1,254 @@ +package cmd + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/aaearon/grant-cli/internal/ui" + "github.com/aaearon/grant-cli/internal/workflows" + wfmodels "github.com/aaearon/grant-cli/internal/workflows/models" +) + +// capturingMockAccessRequestService embeds mockAccessRequestService and captures list params. +type capturingMockAccessRequestService struct { + mockAccessRequestService + lastListParams workflows.ListRequestsParams +} + +func (m *capturingMockAccessRequestService) ListRequests(_ context.Context, params workflows.ListRequestsParams) ([]wfmodels.AccessRequest, int, error) { + m.lastListParams = params + return m.listItems, m.listTotalCount, m.listErr +} + +func withInteractiveTTY(t *testing.T, interactive bool) { + t.Helper() + orig := ui.IsTerminalFunc + t.Cleanup(func() { ui.IsTerminalFunc = orig }) + ui.IsTerminalFunc = func(_ uintptr) bool { return interactive } +} + +func TestResolveRequestIDInteractive_NonInteractive(t *testing.T) { + withInteractiveTTY(t, false) + + svc := &capturingMockAccessRequestService{} + _, err := resolveRequestIDInteractive(t.Context(), svc, pickerScope{emptyMsg: "access requests"}) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, ui.ErrNotInteractive) { + t.Errorf("expected ErrNotInteractive, got %v", err) + } + if !strings.Contains(err.Error(), "grant request list") { + t.Errorf("expected hint to 'grant request list', got %v", err) + } +} + +func TestResolveRequestIDInteractive_JSONMode(t *testing.T) { + withInteractiveTTY(t, false) + orig := outputFormat + outputFormat = "json" + t.Cleanup(func() { outputFormat = orig }) + + svc := &capturingMockAccessRequestService{} + _, err := resolveRequestIDInteractive(t.Context(), svc, pickerScope{emptyMsg: "access requests"}) + if err == nil { + t.Fatal("expected error") + } + if errors.Is(err, ui.ErrNotInteractive) { + t.Errorf("JSON mode error should not wrap ErrNotInteractive") + } + if !strings.Contains(err.Error(), "--output json") { + t.Errorf("expected --output json hint, got %v", err) + } + if strings.Contains(err.Error(), "requires a terminal") { + t.Errorf("JSON mode error should not mention terminal: %v", err) + } +} + +func TestResolveRequestIDInteractive_EmptyList(t *testing.T) { + withInteractiveTTY(t, true) + + svc := &capturingMockAccessRequestService{} + _, err := resolveRequestIDInteractive(t.Context(), svc, pickerScope{ + filter: "(requestState eq PENDING)", + requestRole: "APPROVER", + emptyMsg: "pending requests assigned to you", + }) + if err == nil { + t.Fatal("expected error on empty list") + } + if !strings.Contains(err.Error(), "pending requests assigned to you") { + t.Errorf("expected emptyMsg in error, got %v", err) + } + if svc.lastListParams.Filter != "(requestState eq PENDING)" { + t.Errorf("filter: got %q", svc.lastListParams.Filter) + } + if svc.lastListParams.RequestRole != "APPROVER" { + t.Errorf("requestRole: got %q", svc.lastListParams.RequestRole) + } + if svc.lastListParams.Sort != "createdAt desc" { + t.Errorf("sort: got %q", svc.lastListParams.Sort) + } +} + +func TestResolveRequestIDInteractive_ListError(t *testing.T) { + withInteractiveTTY(t, true) + + svc := &capturingMockAccessRequestService{ + mockAccessRequestService: mockAccessRequestService{listErr: errors.New("boom")}, + } + _, err := resolveRequestIDInteractive(t.Context(), svc, pickerScope{emptyMsg: "x"}) + if err == nil || !strings.Contains(err.Error(), "boom") { + t.Fatalf("expected list error, got %v", err) + } +} + +// stubResolver replaces resolveRequestIDFn for testing the command integration. +func stubResolver(t *testing.T, id string, err error) *struct { + scope pickerScope + called bool +} { + t.Helper() + capture := &struct { + scope pickerScope + called bool + }{} + orig := resolveRequestIDFn + t.Cleanup(func() { resolveRequestIDFn = orig }) + resolveRequestIDFn = func(_ context.Context, _ accessRequestService, scope pickerScope) (string, error) { + capture.called = true + capture.scope = scope + return id, err + } + return capture +} + +func TestRequestCancel_PickerFallback(t *testing.T) { + svc := &mockAccessRequestService{ + cancelResult: &wfmodels.AccessRequest{RequestID: "picked-id", RequestResult: wfmodels.RequestResultCanceled}, + } + capture := stubResolver(t, "picked-id", nil) + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + output, err := executeCommand(root, "request", "cancel") + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, output) + } + if !capture.called { + t.Fatal("resolver was not called") + } + if capture.scope.requestRole != "CREATOR" { + t.Errorf("expected CREATOR scope, got %q", capture.scope.requestRole) + } + if !strings.Contains(capture.scope.filter, "STARTING") { + t.Errorf("expected filter with STARTING, got %q", capture.scope.filter) + } + if !strings.Contains(output, "picked-id") { + t.Errorf("expected picked-id in output: %s", output) + } +} + +func TestRequestApprove_PickerFallback(t *testing.T) { + svc := &mockAccessRequestService{ + finalizeResult: &wfmodels.AccessRequest{RequestID: "picked-id", RequestResult: wfmodels.RequestResultApproved}, + } + capture := stubResolver(t, "picked-id", nil) + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + output, err := executeCommand(root, "request", "approve") + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, output) + } + if !capture.called { + t.Fatal("resolver was not called") + } + if capture.scope.requestRole != "APPROVER" { + t.Errorf("expected APPROVER scope, got %q", capture.scope.requestRole) + } + if capture.scope.filter != "(requestState eq PENDING)" { + t.Errorf("unexpected filter: %q", capture.scope.filter) + } + if !strings.Contains(output, "approved") { + t.Errorf("expected approved in output: %s", output) + } +} + +func TestRequestReject_PickerFallback(t *testing.T) { + svc := &mockAccessRequestService{ + finalizeResult: &wfmodels.AccessRequest{RequestID: "picked-id", RequestResult: wfmodels.RequestResultRejected}, + } + capture := stubResolver(t, "picked-id", nil) + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + output, err := executeCommand(root, "request", "reject") + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, output) + } + if capture.scope.requestRole != "APPROVER" { + t.Errorf("expected APPROVER scope, got %q", capture.scope.requestRole) + } + if !strings.Contains(output, "rejected") { + t.Errorf("expected rejected in output: %s", output) + } +} + +func TestRequestGet_PickerFallback(t *testing.T) { + svc := &mockAccessRequestService{ + getResult: &wfmodels.AccessRequest{ + RequestID: "picked-id", + RequestState: wfmodels.RequestStateFinished, + RequestResult: wfmodels.RequestResultApproved, + CreatedBy: "user@test", + CreatedAt: "t", + UpdatedBy: "SYSTEM", + UpdatedAt: "t", + }, + } + capture := stubResolver(t, "picked-id", nil) + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + output, err := executeCommand(root, "request", "get") + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, output) + } + if capture.scope.filter != "" { + t.Errorf("get scope should have no filter, got %q", capture.scope.filter) + } + if capture.scope.requestRole != "" { + t.Errorf("get scope should have no requestRole, got %q", capture.scope.requestRole) + } + if !strings.Contains(output, "picked-id") { + t.Errorf("expected picked-id in output: %s", output) + } +} + +func TestRequestCancel_PickerError(t *testing.T) { + svc := &mockAccessRequestService{} + stubResolver(t, "", errors.New("no open requests")) + + cmd := NewRequestCommandWithDeps(svc) + root := newTestRootCommand() + root.AddCommand(cmd) + + _, err := executeCommand(root, "request", "cancel") + if err == nil { + t.Fatal("expected error from picker") + } + if !strings.Contains(err.Error(), "no open requests") { + t.Errorf("expected picker error, got %v", err) + } +} diff --git a/internal/ui/request_selector.go b/internal/ui/request_selector.go new file mode 100644 index 0000000..2842f6e --- /dev/null +++ b/internal/ui/request_selector.go @@ -0,0 +1,93 @@ +package ui + +import ( + "errors" + "fmt" + "os" + "sort" + "strings" + "time" + + "github.com/Iilun/survey/v2" + wfmodels "github.com/aaearon/grant-cli/internal/workflows/models" +) + +// formatSelectorTimestamp strips fractional seconds from a timestamp while preserving +// timezone offset information. Duplicated locally to avoid import cycle with cmd package. +func formatSelectorTimestamp(ts string) string { + if t, err := time.Parse(time.RFC3339Nano, ts); err == nil { + return t.Format("2006-01-02T15:04:05Z07:00") + } + if i := strings.IndexByte(ts, '.'); i >= 0 { + return ts[:i] + } + return ts +} + +// FormatRequestOption formats an access request into a display string for the picker. +// Format: / (by , ) [] +func FormatRequestOption(r wfmodels.AccessRequest) string { + workspace := r.DetailString("workspaceName") + if workspace == "" { + workspace = "-" + } + role := r.DetailString("roleName") + if role == "" { + role = "-" + } + shortID := r.RequestID + if len(shortID) > 8 { + shortID = shortID[:8] + } + return fmt.Sprintf("%s %s / %s (by %s, %s) [%s]", + r.RequestState, + workspace, + role, + r.CreatedBy, + formatSelectorTimestamp(r.CreatedAt), + shortID, + ) +} + +// BuildRequestOptions returns display strings and a parallel slice of requests, +// sorted by CreatedAt descending (most recent first). +func BuildRequestOptions(requests []wfmodels.AccessRequest) ([]string, []wfmodels.AccessRequest) { + sorted := make([]wfmodels.AccessRequest, len(requests)) + copy(sorted, requests) + sort.SliceStable(sorted, func(i, j int) bool { + return sorted[i].CreatedAt > sorted[j].CreatedAt + }) + opts := make([]string, len(sorted)) + for i, r := range sorted { + opts[i] = FormatRequestOption(r) + } + return opts, sorted +} + +// SelectRequest prompts the user to pick a request from the list. Uses the selected +// index (not display text) to recover the request, so duplicate display strings are safe. +func SelectRequest(requests []wfmodels.AccessRequest) (*wfmodels.AccessRequest, error) { + if !IsInteractive() { + return nil, fmt.Errorf("%w; pass the request ID as a positional argument for non-interactive mode", ErrNotInteractive) + } + if len(requests) == 0 { + return nil, errors.New("no access requests available to select") + } + + options, sorted := BuildRequestOptions(requests) + + var selectedIdx int + prompt := &survey.Select{ + Message: "Select a request:", + Options: options, + PageSize: 15, + Filter: nil, + } + if err := survey.AskOne(prompt, &selectedIdx, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)); err != nil { + return nil, fmt.Errorf("request selection failed: %w", err) + } + if selectedIdx < 0 || selectedIdx >= len(sorted) { + return nil, fmt.Errorf("invalid request selection index %d", selectedIdx) + } + return &sorted[selectedIdx], nil +} diff --git a/internal/ui/request_selector_test.go b/internal/ui/request_selector_test.go new file mode 100644 index 0000000..7bb3cde --- /dev/null +++ b/internal/ui/request_selector_test.go @@ -0,0 +1,93 @@ +package ui + +import ( + "errors" + "strings" + "testing" + + wfmodels "github.com/aaearon/grant-cli/internal/workflows/models" +) + +func TestFormatRequestOption(t *testing.T) { + r := wfmodels.AccessRequest{ + RequestID: "abcdef12-3456-7890-aaaa-bbbbccccdddd", + RequestState: wfmodels.RequestStatePending, + CreatedBy: "user@test", + CreatedAt: "2026-04-20T10:00:00Z", + RequestDetails: map[string]any{ + "workspaceName": "prod-account", + "roleName": "Admin", + }, + } + got := FormatRequestOption(r) + for _, want := range []string{"PENDING", "prod-account", "Admin", "user@test", "2026-04-20T10:00:00Z", "[abcdef12]"} { + if !strings.Contains(got, want) { + t.Errorf("FormatRequestOption missing %q: got %q", want, got) + } + } +} + +func TestFormatRequestOption_MissingDetails(t *testing.T) { + r := wfmodels.AccessRequest{ + RequestID: "short", + RequestState: wfmodels.RequestStateRunning, + CreatedBy: "u", + CreatedAt: "x", + } + got := FormatRequestOption(r) + if !strings.Contains(got, "- / -") { + t.Errorf("expected '- / -' placeholders, got %q", got) + } + if !strings.Contains(got, "[short]") { + t.Errorf("expected short id [short], got %q", got) + } +} + +func TestBuildRequestOptions_SortedByCreatedAtDesc(t *testing.T) { + reqs := []wfmodels.AccessRequest{ + {RequestID: "a", CreatedAt: "2026-04-01T00:00:00Z"}, + {RequestID: "b", CreatedAt: "2026-04-20T00:00:00Z"}, + {RequestID: "c", CreatedAt: "2026-04-10T00:00:00Z"}, + } + opts, sorted := BuildRequestOptions(reqs) + if len(opts) != 3 { + t.Fatalf("expected 3 options, got %d", len(opts)) + } + want := []string{"b", "c", "a"} + for i, id := range want { + if sorted[i].RequestID != id { + t.Errorf("position %d: got %q want %q", i, sorted[i].RequestID, id) + } + } +} + +func TestSelectRequest_NonInteractive(t *testing.T) { + orig := IsTerminalFunc + defer func() { IsTerminalFunc = orig }() + IsTerminalFunc = func(fd uintptr) bool { return false } + + _, err := SelectRequest([]wfmodels.AccessRequest{{RequestID: "r1"}}) + if err == nil { + t.Fatal("expected error in non-interactive mode") + } + if !errors.Is(err, ErrNotInteractive) { + t.Errorf("expected ErrNotInteractive, got %v", err) + } + if !strings.Contains(err.Error(), "positional argument") { + t.Errorf("error should hint at positional arg: %v", err) + } +} + +func TestSelectRequest_EmptyList(t *testing.T) { + orig := IsTerminalFunc + defer func() { IsTerminalFunc = orig }() + IsTerminalFunc = func(fd uintptr) bool { return true } + + _, err := SelectRequest(nil) + if err == nil { + t.Fatal("expected error for empty list") + } + if !strings.Contains(err.Error(), "no access requests") { + t.Errorf("unexpected error: %v", err) + } +} From 54eb3b369b9bcd94e65f6cd962cfad69775acd2e Mon Sep 17 00:00:00 2001 From: Tim Schindler Date: Mon, 20 Apr 2026 20:18:58 +0200 Subject: [PATCH 9/9] fix: bootstrap after TTY check; parse timestamps for correct sort order Guard against auth errors in non-interactive contexts by checking ErrNotInteractive before bootstrapping the workflows service. Fix timezone-offset timestamp sort by parsing to time.Time instead of relying on lexicographic string comparison. --- cmd/request_cancel.go | 11 +++--- cmd/request_finalize.go | 22 +++++++++--- cmd/request_get.go | 11 +++--- cmd/request_picker.go | 9 +++++ cmd/request_picker_test.go | 52 ++++++++++++++++++++++++++++ internal/ui/request_selector.go | 5 +++ internal/ui/request_selector_test.go | 16 +++++++++ 7 files changed, 114 insertions(+), 12 deletions(-) diff --git a/cmd/request_cancel.go b/cmd/request_cancel.go index 416ce18..5fd1abf 100644 --- a/cmd/request_cancel.go +++ b/cmd/request_cancel.go @@ -13,6 +13,13 @@ func newRequestCancelCommand(svc accessRequestService) *cobra.Command { Long: "Cancel an open access request. If is omitted in a terminal, an interactive picker of open requests you created is shown.", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { + requestID := "" + if len(args) > 0 { + requestID = args[0] + } + if err := earlyNonInteractiveCheck(requestID); err != nil { + return err + } if svc == nil { bootstrapped, err := bootstrapWorkflowsService() if err != nil { @@ -20,10 +27,6 @@ func newRequestCancelCommand(svc accessRequestService) *cobra.Command { } svc = bootstrapped } - requestID := "" - if len(args) > 0 { - requestID = args[0] - } if requestID == "" { id, err := resolveRequestIDFn(cmd.Context(), svc, pickerScope{ filter: "((requestState eq STARTING) or (requestState eq RUNNING) or (requestState eq PENDING))", diff --git a/cmd/request_finalize.go b/cmd/request_finalize.go index 1d1d886..078c74f 100644 --- a/cmd/request_finalize.go +++ b/cmd/request_finalize.go @@ -13,6 +13,13 @@ func newRequestApproveCommand(svc accessRequestService) *cobra.Command { Long: "Approve an access request. If is omitted in a terminal, an interactive picker of pending requests assigned to you is shown.", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { + requestID := "" + if len(args) > 0 { + requestID = args[0] + } + if err := earlyNonInteractiveCheck(requestID); err != nil { + return err + } if svc == nil { bootstrapped, err := bootstrapWorkflowsService() if err != nil { @@ -20,11 +27,11 @@ func newRequestApproveCommand(svc accessRequestService) *cobra.Command { } svc = bootstrapped } - requestID, err := resolveFinalizeRequestID(cmd, args, svc) + id, err := resolveFinalizeRequestID(cmd, args, svc) if err != nil { return err } - return runFinalize(cmd, requestID, "APPROVED", svc) + return runFinalize(cmd, id, "APPROVED", svc) }, } @@ -40,6 +47,13 @@ func newRequestRejectCommand(svc accessRequestService) *cobra.Command { Long: "Reject an access request. If is omitted in a terminal, an interactive picker of pending requests assigned to you is shown.", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { + requestID := "" + if len(args) > 0 { + requestID = args[0] + } + if err := earlyNonInteractiveCheck(requestID); err != nil { + return err + } if svc == nil { bootstrapped, err := bootstrapWorkflowsService() if err != nil { @@ -47,11 +61,11 @@ func newRequestRejectCommand(svc accessRequestService) *cobra.Command { } svc = bootstrapped } - requestID, err := resolveFinalizeRequestID(cmd, args, svc) + id, err := resolveFinalizeRequestID(cmd, args, svc) if err != nil { return err } - return runFinalize(cmd, requestID, "REJECTED", svc) + return runFinalize(cmd, id, "REJECTED", svc) }, } diff --git a/cmd/request_get.go b/cmd/request_get.go index 2d90bf1..a752857 100644 --- a/cmd/request_get.go +++ b/cmd/request_get.go @@ -14,6 +14,13 @@ func newRequestGetCommand(svc accessRequestService) *cobra.Command { Long: "Get details of an access request. If is omitted in a terminal, an interactive picker of your access requests is shown.", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { + requestID := "" + if len(args) > 0 { + requestID = args[0] + } + if err := earlyNonInteractiveCheck(requestID); err != nil { + return err + } if svc == nil { bootstrapped, err := bootstrapWorkflowsService() if err != nil { @@ -21,10 +28,6 @@ func newRequestGetCommand(svc accessRequestService) *cobra.Command { } svc = bootstrapped } - requestID := "" - if len(args) > 0 { - requestID = args[0] - } if requestID == "" { id, err := resolveRequestIDFn(cmd.Context(), svc, pickerScope{ emptyMsg: "access requests", diff --git a/cmd/request_picker.go b/cmd/request_picker.go index f79d464..f87ad85 100644 --- a/cmd/request_picker.go +++ b/cmd/request_picker.go @@ -9,6 +9,15 @@ import ( "github.com/aaearon/grant-cli/internal/workflows" ) +// earlyNonInteractiveCheck fails fast before bootstrap when no requestID is +// provided and stdin is not a TTY. +func earlyNonInteractiveCheck(requestID string) error { + if requestID == "" && !ui.IsInteractive() { + return fmt.Errorf("%w; pass the request ID as a positional argument (run `grant request list` to find it)", ui.ErrNotInteractive) + } + return nil +} + // pickerScope describes how to scope the list of requests surfaced in the picker. type pickerScope struct { filter string // OData filter; empty = no filter diff --git a/cmd/request_picker_test.go b/cmd/request_picker_test.go index 5097470..d17bf05 100644 --- a/cmd/request_picker_test.go +++ b/cmd/request_picker_test.go @@ -126,6 +126,7 @@ func stubResolver(t *testing.T, id string, err error) *struct { } func TestRequestCancel_PickerFallback(t *testing.T) { + withInteractiveTTY(t, true) svc := &mockAccessRequestService{ cancelResult: &wfmodels.AccessRequest{RequestID: "picked-id", RequestResult: wfmodels.RequestResultCanceled}, } @@ -154,6 +155,7 @@ func TestRequestCancel_PickerFallback(t *testing.T) { } func TestRequestApprove_PickerFallback(t *testing.T) { + withInteractiveTTY(t, true) svc := &mockAccessRequestService{ finalizeResult: &wfmodels.AccessRequest{RequestID: "picked-id", RequestResult: wfmodels.RequestResultApproved}, } @@ -182,6 +184,7 @@ func TestRequestApprove_PickerFallback(t *testing.T) { } func TestRequestReject_PickerFallback(t *testing.T) { + withInteractiveTTY(t, true) svc := &mockAccessRequestService{ finalizeResult: &wfmodels.AccessRequest{RequestID: "picked-id", RequestResult: wfmodels.RequestResultRejected}, } @@ -204,6 +207,7 @@ func TestRequestReject_PickerFallback(t *testing.T) { } func TestRequestGet_PickerFallback(t *testing.T) { + withInteractiveTTY(t, true) svc := &mockAccessRequestService{ getResult: &wfmodels.AccessRequest{ RequestID: "picked-id", @@ -237,6 +241,7 @@ func TestRequestGet_PickerFallback(t *testing.T) { } func TestRequestCancel_PickerError(t *testing.T) { + withInteractiveTTY(t, true) svc := &mockAccessRequestService{} stubResolver(t, "", errors.New("no open requests")) @@ -252,3 +257,50 @@ func TestRequestCancel_PickerError(t *testing.T) { t.Errorf("expected picker error, got %v", err) } } + +func TestEarlyNonInteractiveCheck_NoID(t *testing.T) { + withInteractiveTTY(t, false) + err := earlyNonInteractiveCheck("") + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, ui.ErrNotInteractive) { + t.Errorf("expected ErrNotInteractive, got %v", err) + } + if !strings.Contains(err.Error(), "grant request list") { + t.Errorf("expected hint to 'grant request list', got %v", err) + } +} + +func TestEarlyNonInteractiveCheck_WithID(t *testing.T) { + withInteractiveTTY(t, false) + if err := earlyNonInteractiveCheck("some-id"); err != nil { + t.Errorf("expected nil when ID provided, got %v", err) + } +} + +func TestEarlyNonInteractiveCheck_Interactive(t *testing.T) { + withInteractiveTTY(t, true) + if err := earlyNonInteractiveCheck(""); err != nil { + t.Errorf("expected nil in interactive mode, got %v", err) + } +} + +// TestRequestCancel_NonInteractiveNoArgs verifies bootstrap is not reached when +// stdin is non-interactive and no requestID is provided. +func TestRequestCancel_NonInteractiveNoArgs(t *testing.T) { + withInteractiveTTY(t, false) + + // Pass nil svc so bootstrap would be attempted if early check is bypassed. + cmd := newRequestCancelCommand(nil) + root := newTestRootCommand() + root.AddCommand(cmd) + + _, err := executeCommand(root, "cancel") + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, ui.ErrNotInteractive) { + t.Errorf("expected ErrNotInteractive (bootstrap not reached), got %v", err) + } +} diff --git a/internal/ui/request_selector.go b/internal/ui/request_selector.go index 2842f6e..331ffd9 100644 --- a/internal/ui/request_selector.go +++ b/internal/ui/request_selector.go @@ -55,6 +55,11 @@ func BuildRequestOptions(requests []wfmodels.AccessRequest) ([]string, []wfmodel sorted := make([]wfmodels.AccessRequest, len(requests)) copy(sorted, requests) sort.SliceStable(sorted, func(i, j int) bool { + ti, erri := time.Parse(time.RFC3339Nano, sorted[i].CreatedAt) + tj, errj := time.Parse(time.RFC3339Nano, sorted[j].CreatedAt) + if erri == nil && errj == nil { + return ti.After(tj) + } return sorted[i].CreatedAt > sorted[j].CreatedAt }) opts := make([]string, len(sorted)) diff --git a/internal/ui/request_selector_test.go b/internal/ui/request_selector_test.go index 7bb3cde..9495809 100644 --- a/internal/ui/request_selector_test.go +++ b/internal/ui/request_selector_test.go @@ -78,6 +78,22 @@ func TestSelectRequest_NonInteractive(t *testing.T) { } } +func TestBuildRequestOptions_SortedCorrectlyWithOffsets(t *testing.T) { + // "2026-04-20T10:00:00+02:00" == 08:00 UTC — earlier than 09:30Z + // String sort would put +02:00 after Z; time sort must put 09:30Z first. + reqs := []wfmodels.AccessRequest{ + {RequestID: "offset", CreatedAt: "2026-04-20T10:00:00+02:00"}, // 08:00 UTC + {RequestID: "utc", CreatedAt: "2026-04-20T09:30:00Z"}, // 09:30 UTC (more recent) + } + _, sorted := BuildRequestOptions(reqs) + if sorted[0].RequestID != "utc" { + t.Errorf("expected utc (09:30Z) first, got %q", sorted[0].RequestID) + } + if sorted[1].RequestID != "offset" { + t.Errorf("expected offset (08:00 UTC) second, got %q", sorted[1].RequestID) + } +} + func TestSelectRequest_EmptyList(t *testing.T) { orig := IsTerminalFunc defer func() { IsTerminalFunc = orig }()