diff --git a/README.md b/README.md index 2da35f3..1bb7564 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,7 @@ Commands with JSON output support: - **Profiles**: `create`, `list`, `get` - **Extensions**: `upload`, `list` - **Proxies**: `create`, `list`, `get` +- **API Keys**: `create`, `list`, `get`, `update` - **Apps**: `list`, `history` - **Deploy**: `deploy` (JSONL streaming), `history` - **Invoke**: `invoke` (JSONL streaming), `history` @@ -512,6 +513,29 @@ Automated authentication for web services. The `run` command orchestrates the fu - `kernel credentials totp-code ` - Get current TOTP code - `--output json`, `-o json` - Output raw JSON object +### API Keys + +- `kernel api-keys create` - Create a new API key + - `--name ` - API key name (required) + - `--days-to-expire ` - Number of days until expiry (1-3650); omit for never + - `--project-id ` - Create a project-scoped API key for this project ID; omit for org-wide. This is different from global `--project`, which only scopes the CLI request. + - `--output json`, `-o json` - Output raw JSON object, including the one-time plaintext key + +- `kernel api-keys list` - List API keys + - `--limit ` - Maximum number of results to return + - `--offset ` - Number of results to skip + - `--output json`, `-o json` - Output raw JSON array + +- `kernel api-keys get ` - Get an API key + - `--output json`, `-o json` - Output raw JSON object + +- `kernel api-keys update ` - Update an API key + - `--name ` - New API key name + - `--output json`, `-o json` - Output raw JSON object + +- `kernel api-keys delete ` - Delete an API key + - `-y, --yes` - Skip confirmation prompt + ## Examples ### Create a new app diff --git a/cmd/api_keys.go b/cmd/api_keys.go new file mode 100644 index 0000000..77374cf --- /dev/null +++ b/cmd/api_keys.go @@ -0,0 +1,392 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/kernel/cli/pkg/util" + "github.com/kernel/kernel-go-sdk" + "github.com/kernel/kernel-go-sdk/option" + "github.com/kernel/kernel-go-sdk/packages/pagination" + "github.com/pterm/pterm" + "github.com/spf13/cobra" +) + +type APIKeysService interface { + New(ctx context.Context, body kernel.APIKeyNewParams, opts ...option.RequestOption) (*kernel.CreatedAPIKey, error) + Get(ctx context.Context, id string, opts ...option.RequestOption) (*kernel.APIKey, error) + Update(ctx context.Context, id string, body kernel.APIKeyUpdateParams, opts ...option.RequestOption) (*kernel.APIKey, error) + List(ctx context.Context, query kernel.APIKeyListParams, opts ...option.RequestOption) (*pagination.OffsetPagination[kernel.APIKey], error) + Delete(ctx context.Context, id string, opts ...option.RequestOption) error +} + +type APIKeysCmd struct { + apiKeys APIKeysService +} + +type APIKeysCreateInput struct { + Name string + DaysToExpire Int64Flag + ProjectID string + Output string +} + +type APIKeysListInput struct { + Limit int + Offset int + Output string +} + +type APIKeysGetInput struct { + ID string + Output string +} + +type APIKeysUpdateInput struct { + ID string + Name string + Output string +} + +type APIKeysDeleteInput struct { + ID string + SkipConfirm bool +} + +func (c APIKeysCmd) Create(ctx context.Context, in APIKeysCreateInput) error { + if err := validateJSONOutput(in.Output); err != nil { + return err + } + if in.Name == "" { + return fmt.Errorf("--name is required") + } + + params := kernel.APIKeyNewParams{Name: in.Name} + if in.DaysToExpire.Set { + if in.DaysToExpire.Value < 1 || in.DaysToExpire.Value > 3650 { + return fmt.Errorf("--days-to-expire must be between 1 and 3650") + } + params.DaysToExpire = kernel.Int(in.DaysToExpire.Value) + } + if in.ProjectID != "" { + params.ProjectID = kernel.String(in.ProjectID) + } + + key, err := c.apiKeys.New(ctx, params) + if err != nil { + return util.CleanedUpSdkError{Err: err} + } + + if in.Output == "json" { + return util.PrintPrettyJSON(key) + } + + pterm.Success.Printf("Created API key: %s\n", key.ID) + renderCreatedAPIKey(key) + return nil +} + +func (c APIKeysCmd) List(ctx context.Context, in APIKeysListInput) error { + if err := validateJSONOutput(in.Output); err != nil { + return err + } + if in.Limit < 0 { + return fmt.Errorf("--limit must be non-negative") + } + if in.Offset < 0 { + return fmt.Errorf("--offset must be non-negative") + } + + params := kernel.APIKeyListParams{} + if in.Limit > 0 { + params.Limit = kernel.Int(int64(in.Limit)) + } + if in.Offset > 0 { + params.Offset = kernel.Int(int64(in.Offset)) + } + + page, err := c.apiKeys.List(ctx, params) + if err != nil { + return util.CleanedUpSdkError{Err: err} + } + + var keys []kernel.APIKey + if page != nil { + keys = page.Items + } + + if in.Output == "json" { + return util.PrintPrettyJSONSlice(keys) + } + + if len(keys) == 0 { + pterm.Info.Println("No API keys found") + return nil + } + + table := pterm.TableData{{"ID", "Name", "Scope", "Project", "Masked Key", "Expires At", "Created At"}} + for _, key := range keys { + table = append(table, []string{ + key.ID, + key.Name, + formatAPIKeyScope(key), + formatAPIKeyProject(key), + key.MaskedKey, + formatAPIKeyExpiresAt(key), + util.FormatLocal(key.CreatedAt), + }) + } + PrintTableNoPad(table, true) + return nil +} + +func (c APIKeysCmd) Get(ctx context.Context, in APIKeysGetInput) error { + if err := validateJSONOutput(in.Output); err != nil { + return err + } + + key, err := c.apiKeys.Get(ctx, in.ID) + if err != nil { + return util.CleanedUpSdkError{Err: err} + } + + if in.Output == "json" { + return util.PrintPrettyJSON(key) + } + + renderAPIKeyDetails(key) + return nil +} + +func (c APIKeysCmd) Update(ctx context.Context, in APIKeysUpdateInput) error { + if err := validateJSONOutput(in.Output); err != nil { + return err + } + if in.Name == "" { + return fmt.Errorf("--name is required") + } + + key, err := c.apiKeys.Update(ctx, in.ID, kernel.APIKeyUpdateParams{Name: in.Name}) + if err != nil { + return util.CleanedUpSdkError{Err: err} + } + + if in.Output == "json" { + return util.PrintPrettyJSON(key) + } + + pterm.Success.Printf("Updated API key: %s\n", key.ID) + return nil +} + +func (c APIKeysCmd) Delete(ctx context.Context, in APIKeysDeleteInput) error { + if !in.SkipConfirm { + msg := fmt.Sprintf("Are you sure you want to delete API key '%s'?", in.ID) + pterm.DefaultInteractiveConfirm.DefaultText = msg + ok, _ := pterm.DefaultInteractiveConfirm.Show() + if !ok { + pterm.Info.Println("Deletion cancelled") + return nil + } + } + + if err := c.apiKeys.Delete(ctx, in.ID); err != nil { + if util.IsNotFound(err) { + return fmt.Errorf("API key %q not found", in.ID) + } + return util.CleanedUpSdkError{Err: err} + } + + pterm.Success.Printf("Deleted API key: %s\n", in.ID) + return nil +} + +func renderCreatedAPIKey(key *kernel.CreatedAPIKey) { + rows := pterm.TableData{ + {"Field", "Value"}, + {"ID", key.ID}, + {"Name", key.Name}, + {"Key", key.Key}, + {"Scope", formatAPIKeyScope(key.APIKey)}, + {"Project", formatAPIKeyProject(key.APIKey)}, + {"Masked Key", key.MaskedKey}, + {"Expires At", formatAPIKeyExpiresAt(key.APIKey)}, + } + PrintTableNoPad(rows, true) +} + +func renderAPIKeyDetails(key *kernel.APIKey) { + rows := pterm.TableData{ + {"Field", "Value"}, + {"ID", key.ID}, + {"Name", key.Name}, + {"Scope", formatAPIKeyScope(*key)}, + {"Project", formatAPIKeyProject(*key)}, + {"Masked Key", key.MaskedKey}, + {"Created By", formatAPIKeyCreator(*key)}, + {"Expires At", formatAPIKeyExpiresAt(*key)}, + {"Created At", util.FormatLocal(key.CreatedAt)}, + } + PrintTableNoPad(rows, true) +} + +func formatAPIKeyProject(key kernel.APIKey) string { + if key.JSON.ProjectName.Valid() && key.ProjectName != "" { + return key.ProjectName + } + if key.JSON.ProjectID.Valid() && key.ProjectID != "" { + return key.ProjectID + } + return "-" +} + +func formatAPIKeyScope(key kernel.APIKey) string { + if key.JSON.ProjectID.Valid() && key.ProjectID != "" { + return "Project" + } + return "Org" +} + +func formatAPIKeyCreator(key kernel.APIKey) string { + if key.CreatedBy.JSON.Name.Valid() && key.CreatedBy.Name != "" { + return key.CreatedBy.Name + } + if key.CreatedBy.JSON.Email.Valid() && key.CreatedBy.Email != "" { + return key.CreatedBy.Email + } + return "-" +} + +func formatAPIKeyExpiresAt(key kernel.APIKey) string { + if !key.JSON.ExpiresAt.Valid() { + return "Never" + } + return util.FormatLocal(key.ExpiresAt) +} + +func getAPIKeysHandler(cmd *cobra.Command) APIKeysCmd { + client := getKernelClient(cmd) + return APIKeysCmd{apiKeys: &client.APIKeys} +} + +func runAPIKeysCreate(cmd *cobra.Command, args []string) error { + c := getAPIKeysHandler(cmd) + name, _ := cmd.Flags().GetString("name") + daysToExpire, _ := cmd.Flags().GetInt64("days-to-expire") + projectID, _ := cmd.Flags().GetString("project-id") + output, _ := cmd.Flags().GetString("output") + + return c.Create(cmd.Context(), APIKeysCreateInput{ + Name: name, + DaysToExpire: Int64Flag{ + Set: cmd.Flags().Changed("days-to-expire"), + Value: daysToExpire, + }, + ProjectID: projectID, + Output: output, + }) +} + +func runAPIKeysList(cmd *cobra.Command, args []string) error { + c := getAPIKeysHandler(cmd) + limit, _ := cmd.Flags().GetInt("limit") + offset, _ := cmd.Flags().GetInt("offset") + output, _ := cmd.Flags().GetString("output") + return c.List(cmd.Context(), APIKeysListInput{ + Limit: limit, + Offset: offset, + Output: output, + }) +} + +func runAPIKeysGet(cmd *cobra.Command, args []string) error { + c := getAPIKeysHandler(cmd) + output, _ := cmd.Flags().GetString("output") + return c.Get(cmd.Context(), APIKeysGetInput{ID: args[0], Output: output}) +} + +func runAPIKeysUpdate(cmd *cobra.Command, args []string) error { + c := getAPIKeysHandler(cmd) + name, _ := cmd.Flags().GetString("name") + output, _ := cmd.Flags().GetString("output") + return c.Update(cmd.Context(), APIKeysUpdateInput{ID: args[0], Name: name, Output: output}) +} + +func runAPIKeysDelete(cmd *cobra.Command, args []string) error { + c := getAPIKeysHandler(cmd) + skip, _ := cmd.Flags().GetBool("yes") + return c.Delete(cmd.Context(), APIKeysDeleteInput{ID: args[0], SkipConfirm: skip}) +} + +var apiKeysCmd = &cobra.Command{ + Use: "api-keys", + Aliases: []string{"api-key", "apikeys", "apikey"}, + Short: "Manage API keys", + Run: func(cmd *cobra.Command, args []string) { + _ = cmd.Help() + }, +} + +var apiKeysCreateCmd = &cobra.Command{ + Use: "create", + Short: "Create an API key", + Long: "Create an API key.\n\nBy default the new key is org-wide. Use --project-id to create a key whose own access is scoped to that project. The global --project flag only scopes this CLI request.", + Args: cobra.NoArgs, + RunE: runAPIKeysCreate, +} + +var apiKeysListCmd = &cobra.Command{ + Use: "list", + Short: "List API keys", + Args: cobra.NoArgs, + RunE: runAPIKeysList, +} + +var apiKeysGetCmd = &cobra.Command{ + Use: "get ", + Short: "Get an API key", + Args: cobra.ExactArgs(1), + RunE: runAPIKeysGet, +} + +var apiKeysUpdateCmd = &cobra.Command{ + Use: "update ", + Short: "Update an API key", + Args: cobra.ExactArgs(1), + RunE: runAPIKeysUpdate, +} + +var apiKeysDeleteCmd = &cobra.Command{ + Use: "delete ", + Short: "Delete an API key", + Args: cobra.ExactArgs(1), + RunE: runAPIKeysDelete, +} + +func init() { + addJSONOutputFlag(apiKeysCreateCmd) + apiKeysCreateCmd.Flags().String("name", "", "API key name (required)") + apiKeysCreateCmd.Flags().Int64("days-to-expire", 0, "Number of days until expiry (1-3650); omit for never") + apiKeysCreateCmd.Flags().String("project-id", "", "Create a project-scoped API key for this project ID; omit for org-wide") + _ = apiKeysCreateCmd.MarkFlagRequired("name") + + addJSONOutputFlag(apiKeysListCmd) + apiKeysListCmd.Flags().Int("limit", 0, "Maximum number of results to return") + apiKeysListCmd.Flags().Int("offset", 0, "Number of results to skip") + + addJSONOutputFlag(apiKeysGetCmd) + + addJSONOutputFlag(apiKeysUpdateCmd) + apiKeysUpdateCmd.Flags().String("name", "", "New API key name (required)") + _ = apiKeysUpdateCmd.MarkFlagRequired("name") + + apiKeysDeleteCmd.Flags().BoolP("yes", "y", false, "Skip confirmation prompt") + + apiKeysCmd.AddCommand(apiKeysCreateCmd) + apiKeysCmd.AddCommand(apiKeysListCmd) + apiKeysCmd.AddCommand(apiKeysGetCmd) + apiKeysCmd.AddCommand(apiKeysUpdateCmd) + apiKeysCmd.AddCommand(apiKeysDeleteCmd) + + rootCmd.AddCommand(apiKeysCmd) +} diff --git a/cmd/api_keys_test.go b/cmd/api_keys_test.go new file mode 100644 index 0000000..70ce814 --- /dev/null +++ b/cmd/api_keys_test.go @@ -0,0 +1,269 @@ +package cmd + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "testing" + + "github.com/kernel/kernel-go-sdk" + "github.com/kernel/kernel-go-sdk/option" + "github.com/kernel/kernel-go-sdk/packages/pagination" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type FakeAPIKeysService struct { + NewFunc func(ctx context.Context, body kernel.APIKeyNewParams, opts ...option.RequestOption) (*kernel.CreatedAPIKey, error) + GetFunc func(ctx context.Context, id string, opts ...option.RequestOption) (*kernel.APIKey, error) + UpdateFunc func(ctx context.Context, id string, body kernel.APIKeyUpdateParams, opts ...option.RequestOption) (*kernel.APIKey, error) + ListFunc func(ctx context.Context, query kernel.APIKeyListParams, opts ...option.RequestOption) (*pagination.OffsetPagination[kernel.APIKey], error) + DeleteFunc func(ctx context.Context, id string, opts ...option.RequestOption) error +} + +func (f *FakeAPIKeysService) New(ctx context.Context, body kernel.APIKeyNewParams, opts ...option.RequestOption) (*kernel.CreatedAPIKey, error) { + if f.NewFunc != nil { + return f.NewFunc(ctx, body, opts...) + } + return createdAPIKeyFromJSON(`{"id":"key_123","name":"default","key":"sk_test","masked_key":"sk_...test","created_at":"2026-05-27T12:00:00Z","created_by":{"id":"user_123","email":"dev@example.com","name":"Dev"},"expires_at":null,"project_id":null,"project_name":null}`), nil +} + +func (f *FakeAPIKeysService) Get(ctx context.Context, id string, opts ...option.RequestOption) (*kernel.APIKey, error) { + if f.GetFunc != nil { + return f.GetFunc(ctx, id, opts...) + } + return apiKeyFromJSON(`{"id":"` + id + `","name":"default","masked_key":"sk_...test","created_at":"2026-05-27T12:00:00Z","created_by":{"id":"user_123","email":"dev@example.com","name":"Dev"},"expires_at":null,"project_id":null,"project_name":null}`), nil +} + +func (f *FakeAPIKeysService) Update(ctx context.Context, id string, body kernel.APIKeyUpdateParams, opts ...option.RequestOption) (*kernel.APIKey, error) { + if f.UpdateFunc != nil { + return f.UpdateFunc(ctx, id, body, opts...) + } + return apiKeyFromJSON(`{"id":"` + id + `","name":"` + body.Name + `","masked_key":"sk_...test","created_at":"2026-05-27T12:00:00Z","created_by":{"id":"user_123","email":"dev@example.com","name":"Dev"},"expires_at":null,"project_id":null,"project_name":null}`), nil +} + +func (f *FakeAPIKeysService) List(ctx context.Context, query kernel.APIKeyListParams, opts ...option.RequestOption) (*pagination.OffsetPagination[kernel.APIKey], error) { + if f.ListFunc != nil { + return f.ListFunc(ctx, query, opts...) + } + return &pagination.OffsetPagination[kernel.APIKey]{Items: []kernel.APIKey{}}, nil +} + +func (f *FakeAPIKeysService) Delete(ctx context.Context, id string, opts ...option.RequestOption) error { + if f.DeleteFunc != nil { + return f.DeleteFunc(ctx, id, opts...) + } + return nil +} + +func createdAPIKeyFromJSON(raw string) *kernel.CreatedAPIKey { + var key kernel.CreatedAPIKey + if err := json.Unmarshal([]byte(raw), &key); err != nil { + panic(err) + } + return &key +} + +func apiKeyFromJSON(raw string) *kernel.APIKey { + var key kernel.APIKey + if err := json.Unmarshal([]byte(raw), &key); err != nil { + panic(err) + } + return &key +} + +func TestAPIKeysCreateBuildsParamsAndPrintsPlaintextKey(t *testing.T) { + buf := capturePtermOutput(t) + fake := &FakeAPIKeysService{ + NewFunc: func(ctx context.Context, body kernel.APIKeyNewParams, opts ...option.RequestOption) (*kernel.CreatedAPIKey, error) { + assert.Equal(t, "ci", body.Name) + assert.True(t, body.DaysToExpire.Valid()) + assert.Equal(t, int64(30), body.DaysToExpire.Value) + assert.True(t, body.ProjectID.Valid()) + assert.Equal(t, "proj_123", body.ProjectID.Value) + return createdAPIKeyFromJSON(`{"id":"key_123","name":"ci","key":"sk_live_123","masked_key":"sk_...123","created_at":"2026-05-27T12:00:00Z","created_by":{"id":"user_123","email":"dev@example.com","name":"Dev"},"expires_at":null,"project_id":"proj_123","project_name":"Prod"}`), nil + }, + } + c := APIKeysCmd{apiKeys: fake} + + err := c.Create(context.Background(), APIKeysCreateInput{ + Name: "ci", + DaysToExpire: Int64Flag{ + Set: true, + Value: 30, + }, + ProjectID: "proj_123", + }) + require.NoError(t, err) + + out := buf.String() + assert.Contains(t, out, "Created API key: key_123") + assert.Contains(t, out, "key_123") + assert.Contains(t, out, "sk_live_123") + assert.Contains(t, out, "Prod") +} + +func TestAPIKeysCreateRejectsInvalidDaysToExpire(t *testing.T) { + c := APIKeysCmd{apiKeys: &FakeAPIKeysService{}} + + err := c.Create(context.Background(), APIKeysCreateInput{ + Name: "ci", + DaysToExpire: Int64Flag{ + Set: true, + Value: 0, + }, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "--days-to-expire must be between 1 and 3650") +} + +func TestAPIKeysRejectInvalidOutputBeforeCallingAPI(t *testing.T) { + fake := &FakeAPIKeysService{ + NewFunc: func(ctx context.Context, body kernel.APIKeyNewParams, opts ...option.RequestOption) (*kernel.CreatedAPIKey, error) { + t.Fatal("New should not be called") + return nil, nil + }, + GetFunc: func(ctx context.Context, id string, opts ...option.RequestOption) (*kernel.APIKey, error) { + t.Fatal("Get should not be called") + return nil, nil + }, + UpdateFunc: func(ctx context.Context, id string, body kernel.APIKeyUpdateParams, opts ...option.RequestOption) (*kernel.APIKey, error) { + t.Fatal("Update should not be called") + return nil, nil + }, + ListFunc: func(ctx context.Context, query kernel.APIKeyListParams, opts ...option.RequestOption) (*pagination.OffsetPagination[kernel.APIKey], error) { + t.Fatal("List should not be called") + return nil, nil + }, + } + c := APIKeysCmd{apiKeys: fake} + + tests := []struct { + name string + run func() error + }{ + { + name: "create", + run: func() error { return c.Create(context.Background(), APIKeysCreateInput{Name: "ci", Output: "yaml"}) }, + }, + { + name: "list", + run: func() error { return c.List(context.Background(), APIKeysListInput{Output: "yaml"}) }, + }, + { + name: "get", + run: func() error { return c.Get(context.Background(), APIKeysGetInput{ID: "key_123", Output: "yaml"}) }, + }, + { + name: "update", + run: func() error { + return c.Update(context.Background(), APIKeysUpdateInput{ID: "key_123", Name: "ci", Output: "yaml"}) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.run() + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported --output value") + }) + } +} + +func TestAPIKeysListPassesPaginationAndRendersRows(t *testing.T) { + buf := capturePtermOutput(t) + key := *apiKeyFromJSON(`{"id":"key_123","name":"ci","masked_key":"sk_...123","created_at":"2026-05-27T12:00:00Z","created_by":{"id":"user_123","email":"dev@example.com","name":"Dev"},"expires_at":null,"project_id":null,"project_name":null}`) + fake := &FakeAPIKeysService{ + ListFunc: func(ctx context.Context, query kernel.APIKeyListParams, opts ...option.RequestOption) (*pagination.OffsetPagination[kernel.APIKey], error) { + assert.True(t, query.Limit.Valid()) + assert.Equal(t, int64(10), query.Limit.Value) + assert.True(t, query.Offset.Valid()) + assert.Equal(t, int64(20), query.Offset.Value) + return &pagination.OffsetPagination[kernel.APIKey]{Items: []kernel.APIKey{key}}, nil + }, + } + c := APIKeysCmd{apiKeys: fake} + + err := c.List(context.Background(), APIKeysListInput{Limit: 10, Offset: 20}) + require.NoError(t, err) + + out := buf.String() + assert.Contains(t, out, "key_123") + assert.Contains(t, out, "ci") + assert.Contains(t, out, "Never") +} + +func TestAPIKeysUpdateRequiresName(t *testing.T) { + c := APIKeysCmd{apiKeys: &FakeAPIKeysService{}} + err := c.Update(context.Background(), APIKeysUpdateInput{ID: "key_123"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "--name is required") +} + +func TestAPIKeysUpdatePrintsTerseSuccess(t *testing.T) { + buf := capturePtermOutput(t) + fake := &FakeAPIKeysService{ + UpdateFunc: func(ctx context.Context, id string, body kernel.APIKeyUpdateParams, opts ...option.RequestOption) (*kernel.APIKey, error) { + assert.Equal(t, "key_123", id) + assert.Equal(t, "renamed", body.Name) + return apiKeyFromJSON(`{"id":"key_123","name":"renamed","masked_key":"sk_...123","created_at":"2026-05-27T12:00:00Z","created_by":{"id":"user_123","email":"dev@example.com","name":"Dev"},"expires_at":null,"project_id":null,"project_name":null}`), nil + }, + } + c := APIKeysCmd{apiKeys: fake} + + err := c.Update(context.Background(), APIKeysUpdateInput{ID: "key_123", Name: "renamed"}) + require.NoError(t, err) + + out := buf.String() + assert.Contains(t, out, "Updated API key: key_123") + assert.NotContains(t, out, "Masked Key") + assert.NotContains(t, out, "sk_...123") +} + +func TestAPIKeysDeleteSkipsConfirmation(t *testing.T) { + buf := capturePtermOutput(t) + deleted := false + fake := &FakeAPIKeysService{ + DeleteFunc: func(ctx context.Context, id string, opts ...option.RequestOption) error { + assert.Equal(t, "key_123", id) + deleted = true + return nil + }, + } + c := APIKeysCmd{apiKeys: fake} + + err := c.Delete(context.Background(), APIKeysDeleteInput{ID: "key_123", SkipConfirm: true}) + require.NoError(t, err) + assert.True(t, deleted) + assert.Contains(t, buf.String(), "Deleted API key: key_123") +} + +func TestAPIKeysDeleteReturnsNotFoundError(t *testing.T) { + fake := &FakeAPIKeysService{ + DeleteFunc: func(ctx context.Context, id string, opts ...option.RequestOption) error { + assert.Equal(t, "missing_key", id) + return &kernel.Error{StatusCode: http.StatusNotFound} + }, + } + c := APIKeysCmd{apiKeys: fake} + + err := c.Delete(context.Background(), APIKeysDeleteInput{ID: "missing_key", SkipConfirm: true}) + require.Error(t, err) + assert.Contains(t, err.Error(), `API key "missing_key" not found`) +} + +func TestAPIKeysDeleteReturnsAPIError(t *testing.T) { + fake := &FakeAPIKeysService{ + DeleteFunc: func(ctx context.Context, id string, opts ...option.RequestOption) error { + return errors.New("API error") + }, + } + c := APIKeysCmd{apiKeys: fake} + + err := c.Delete(context.Background(), APIKeysDeleteInput{ID: "key_123", SkipConfirm: true}) + require.Error(t, err) + assert.Contains(t, err.Error(), "API error") +} diff --git a/cmd/app.go b/cmd/app.go index 08273b5..7271dcc 100644 --- a/cmd/app.go +++ b/cmd/app.go @@ -58,11 +58,11 @@ func init() { appListCmd.Flags().Int("limit", 20, "Max apps to return (default 20)") appListCmd.Flags().Int("per-page", 20, "Items per page (alias of --limit)") appListCmd.Flags().Int("page", 1, "Page number (1-based)") - appListCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(appListCmd) // Limit rows returned for app history (0 = all) appHistoryCmd.Flags().Int("limit", 20, "Max deployments to return (default 20)") - appHistoryCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(appHistoryCmd) } func runAppList(cmd *cobra.Command, args []string) error { @@ -75,8 +75,8 @@ func runAppList(cmd *cobra.Command, args []string) error { page, _ := cmd.Flags().GetInt("page") output, _ := cmd.Flags().GetString("output") - if output != "" && output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(output); err != nil { + return err } // Determine pagination inputs: prefer page/per-page if provided; else map legacy --limit @@ -303,8 +303,8 @@ func runAppHistory(cmd *cobra.Command, args []string) error { lim, _ := cmd.Flags().GetInt("limit") output, _ := cmd.Flags().GetString("output") - if output != "" && output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(output); err != nil { + return err } if output != "json" { diff --git a/cmd/auth_connections.go b/cmd/auth_connections.go index 72d5c5d..153a498 100644 --- a/cmd/auth_connections.go +++ b/cmd/auth_connections.go @@ -113,8 +113,8 @@ type AuthConnectionFollowInput struct { } func (c AuthConnectionCmd) Create(ctx context.Context, in AuthConnectionCreateInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if in.Domain == "" { @@ -207,7 +207,7 @@ func printManagedAuthSummary(auth *kernel.ManagedAuth) { {"Can Reauth", fmt.Sprintf("%t", auth.CanReauth)}, } if auth.CanReauthReason != "" { - tableData = append(tableData, []string{"Can Reauth Reason", auth.CanReauthReason}) + tableData = append(tableData, []string{"Can Reauth Reason", string(auth.CanReauthReason)}) } if auth.Credential.Name != "" { tableData = append(tableData, []string{"Credential Name", auth.Credential.Name}) @@ -222,8 +222,8 @@ func printManagedAuthSummary(auth *kernel.ManagedAuth) { } func (c AuthConnectionCmd) Update(ctx context.Context, in AuthConnectionUpdateInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } params := kernel.AuthConnectionUpdateParams{ @@ -304,8 +304,8 @@ func (c AuthConnectionCmd) Update(ctx context.Context, in AuthConnectionUpdateIn } func (c AuthConnectionCmd) Get(ctx context.Context, in AuthConnectionGetInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } auth, err := c.svc.Get(ctx, in.ID) @@ -326,7 +326,7 @@ func (c AuthConnectionCmd) Get(ctx context.Context, in AuthConnectionGetInput) e {"Can Reauth", fmt.Sprintf("%t", auth.CanReauth)}, } if auth.CanReauthReason != "" { - tableData = append(tableData, []string{"Can Reauth Reason", auth.CanReauthReason}) + tableData = append(tableData, []string{"Can Reauth Reason", string(auth.CanReauthReason)}) } if auth.Credential.Name != "" { tableData = append(tableData, []string{"Credential Name", auth.Credential.Name}) @@ -426,8 +426,8 @@ func (c AuthConnectionCmd) Get(ctx context.Context, in AuthConnectionGetInput) e } func (c AuthConnectionCmd) List(ctx context.Context, in AuthConnectionListInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } params := kernel.AuthConnectionListParams{} @@ -512,8 +512,8 @@ func (c AuthConnectionCmd) Delete(ctx context.Context, in AuthConnectionDeleteIn } func (c AuthConnectionCmd) Login(ctx context.Context, in AuthConnectionLoginInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } params := kernel.AuthConnectionLoginParams{} @@ -558,8 +558,8 @@ func (c AuthConnectionCmd) Login(ctx context.Context, in AuthConnectionLoginInpu } func (c AuthConnectionCmd) Submit(ctx context.Context, in AuthConnectionSubmitInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } // Validate that we have some input to submit @@ -653,8 +653,8 @@ func (c AuthConnectionCmd) Submit(ctx context.Context, in AuthConnectionSubmitIn } func (c AuthConnectionCmd) Follow(ctx context.Context, in AuthConnectionFollowInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } stream := c.svc.FollowStreaming(ctx, in.ID) @@ -797,7 +797,7 @@ var authConnectionsFollowCmd = &cobra.Command{ func init() { // Create flags - authConnectionsCreateCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(authConnectionsCreateCmd) authConnectionsCreateCmd.Flags().String("domain", "", "Target domain for authentication (required)") authConnectionsCreateCmd.Flags().String("profile-name", "", "Name of the profile to manage (required)") authConnectionsCreateCmd.Flags().String("login-url", "", "Optional login page URL to skip discovery") @@ -815,10 +815,10 @@ func init() { authConnectionsCreateCmd.MarkFlagsMutuallyExclusive("credential-name", "credential-provider") // Get flags - authConnectionsGetCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(authConnectionsGetCmd) // Update flags - authConnectionsUpdateCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(authConnectionsUpdateCmd) authConnectionsUpdateCmd.Flags().String("login-url", "", "Login page URL (set to empty string to clear)") authConnectionsUpdateCmd.Flags().StringSlice("allowed-domain", []string{}, "Additional allowed domains (replaces existing list)") authConnectionsUpdateCmd.Flags().String("credential-name", "", "Kernel credential name to use") @@ -834,7 +834,7 @@ func init() { authConnectionsUpdateCmd.MarkFlagsMutuallyExclusive("save-credentials", "no-save-credentials") // List flags - authConnectionsListCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(authConnectionsListCmd) authConnectionsListCmd.Flags().String("domain", "", "Filter by domain") authConnectionsListCmd.Flags().String("profile-name", "", "Filter by profile name") authConnectionsListCmd.Flags().Int("limit", 0, "Maximum number of results to return") @@ -844,12 +844,12 @@ func init() { authConnectionsDeleteCmd.Flags().BoolP("yes", "y", false, "Skip confirmation prompt") // Login flags - authConnectionsLoginCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(authConnectionsLoginCmd) authConnectionsLoginCmd.Flags().String("proxy-id", "", "Proxy ID to use for this login") authConnectionsLoginCmd.Flags().String("proxy-name", "", "Proxy name to use for this login") // Submit flags - authConnectionsSubmitCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(authConnectionsSubmitCmd) authConnectionsSubmitCmd.Flags().StringArray("field", []string{}, "Field name=value pair (repeatable)") authConnectionsSubmitCmd.Flags().String("mfa-option-id", "", "MFA option ID if user selected an MFA method") authConnectionsSubmitCmd.Flags().String("sign-in-option-id", "", "Sign-in option ID if the flow returned non-MFA choices") @@ -857,7 +857,7 @@ func init() { authConnectionsSubmitCmd.Flags().String("sso-provider", "", "SSO provider if user chose an SSO button by provider (e.g. google, github)") // Follow flags - authConnectionsFollowCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(authConnectionsFollowCmd) // Wire up commands authConnectionsCmd.AddCommand(authConnectionsCreateCmd) diff --git a/cmd/browser_pools.go b/cmd/browser_pools.go index ca5fa96..eaacc68 100644 --- a/cmd/browser_pools.go +++ b/cmd/browser_pools.go @@ -33,8 +33,8 @@ type BrowserPoolsListInput struct { } func (c BrowserPoolsCmd) List(ctx context.Context, in BrowserPoolsListInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } pools, err := c.client.List(ctx) @@ -93,8 +93,8 @@ type BrowserPoolsCreateInput struct { } func (c BrowserPoolsCmd) Create(ctx context.Context, in BrowserPoolsCreateInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if err := validateStartURLFlag(in.StartURL); err != nil { return err @@ -173,8 +173,8 @@ type BrowserPoolsGetInput struct { } func (c BrowserPoolsCmd) Get(ctx context.Context, in BrowserPoolsGetInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } pool, err := c.client.Get(ctx, in.IDOrName) @@ -234,8 +234,8 @@ type BrowserPoolsUpdateInput struct { } func (c BrowserPoolsCmd) Update(ctx context.Context, in BrowserPoolsUpdateInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if err := validateStartURLFlag(in.StartURL); err != nil { return err @@ -342,8 +342,8 @@ type BrowserPoolsAcquireInput struct { } func (c BrowserPoolsCmd) Acquire(ctx context.Context, in BrowserPoolsAcquireInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } params := kernel.BrowserPoolAcquireParams{} @@ -481,9 +481,9 @@ var browserPoolsFlushCmd = &cobra.Command{ } func init() { - browserPoolsListCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(browserPoolsListCmd) - browserPoolsCreateCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(browserPoolsCreateCmd) browserPoolsCreateCmd.Flags().String("name", "", "Optional unique name for the pool") browserPoolsCreateCmd.Flags().Int64("size", 0, "Number of browsers in the pool") _ = browserPoolsCreateCmd.MarkFlagRequired("size") @@ -500,7 +500,7 @@ func init() { browserPoolsCreateCmd.Flags().StringSlice("extension", []string{}, "Extension IDs or names") browserPoolsCreateCmd.Flags().String("viewport", "", "Viewport size (e.g. 1280x800)") - browserPoolsGetCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(browserPoolsGetCmd) browserPoolsUpdateCmd.Flags().String("name", "", "Update the pool name") browserPoolsUpdateCmd.Flags().Int64("size", 0, "Number of browsers in the pool") @@ -518,12 +518,12 @@ func init() { browserPoolsUpdateCmd.Flags().StringSlice("extension", []string{}, "Extension IDs or names") browserPoolsUpdateCmd.Flags().String("viewport", "", "Viewport size (e.g. 1280x800)") browserPoolsUpdateCmd.Flags().Bool("discard-all-idle", false, "Discard all idle browsers") - browserPoolsUpdateCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(browserPoolsUpdateCmd) browserPoolsDeleteCmd.Flags().Bool("force", false, "Force delete even if browsers are leased") browserPoolsAcquireCmd.Flags().Int64("timeout", 0, "Acquire timeout in seconds") - browserPoolsAcquireCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(browserPoolsAcquireCmd) browserPoolsReleaseCmd.Flags().String("session-id", "", "Browser session ID to release") _ = browserPoolsReleaseCmd.MarkFlagRequired("session-id") diff --git a/cmd/browsers.go b/cmd/browsers.go index 7457d51..d73e115 100644 --- a/cmd/browsers.go +++ b/cmd/browsers.go @@ -108,18 +108,6 @@ type BrowserComputerService interface { WriteClipboard(ctx context.Context, id string, body kernel.BrowserComputerWriteClipboardParams, opts ...option.RequestOption) (err error) } -// BoolFlag captures whether a boolean flag was set explicitly and its value. -type BoolFlag struct { - Set bool - Value bool -} - -// Int64Flag captures whether an int64 flag was set explicitly and its value. -type Int64Flag struct { - Set bool - Value int64 -} - // Regular expression to validate CUID2 identifiers (starts with a letter, 24 lowercase alphanumeric characters). var cuidRegex = regexp.MustCompile(`^[a-z][a-z0-9]{23}$`) @@ -239,8 +227,8 @@ type BrowsersListInput struct { } func (b BrowsersCmd) List(ctx context.Context, in BrowsersListInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } params := kernel.BrowserListParams{} @@ -336,8 +324,8 @@ func (b BrowsersCmd) List(ctx context.Context, in BrowsersListInput) error { } func (b BrowsersCmd) Create(ctx context.Context, in BrowsersCreateInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if err := validateStartURLFlag(in.StartURL); err != nil { return err @@ -473,8 +461,8 @@ func (b BrowsersCmd) Delete(ctx context.Context, in BrowsersDeleteInput) error { } func (b BrowsersCmd) View(ctx context.Context, in BrowsersViewInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } browser, err := b.browsers.Get(ctx, in.Identifier, kernel.BrowserGetParams{}) @@ -507,8 +495,8 @@ func (b BrowsersCmd) View(ctx context.Context, in BrowsersViewInput) error { } func (b BrowsersCmd) Get(ctx context.Context, in BrowsersGetInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } query := kernel.BrowserGetParams{} @@ -559,8 +547,8 @@ func (b BrowsersCmd) Get(ctx context.Context, in BrowsersGetInput) error { } func (b BrowsersCmd) Update(ctx context.Context, in BrowsersUpdateInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } // Validate profile selection: at most one of profile-id or profile-name must be provided @@ -1066,8 +1054,8 @@ func (b BrowsersCmd) ComputerBatch(ctx context.Context, in BrowsersComputerBatch } func (b BrowsersCmd) ComputerReadClipboard(ctx context.Context, in BrowsersComputerReadClipboardInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if b.computer == nil { pterm.Error.Println("computer service not available") @@ -1131,8 +1119,8 @@ type BrowsersReplaysDownloadInput struct { } func (b BrowsersCmd) ReplaysList(ctx context.Context, in BrowsersReplaysListInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } items, err := b.replays.List(ctx, in.Identifier) @@ -1161,8 +1149,8 @@ func (b BrowsersCmd) ReplaysList(ctx context.Context, in BrowsersReplaysListInpu } func (b BrowsersCmd) ReplaysStart(ctx context.Context, in BrowsersReplaysStartInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } br, err := b.browsers.Get(ctx, in.Identifier, kernel.BrowserGetParams{}) @@ -1307,8 +1295,8 @@ type BrowsersPlaywrightExecuteInput struct { } func (b BrowsersCmd) PlaywrightExecute(ctx context.Context, in BrowsersPlaywrightExecuteInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if b.playwright == nil { @@ -1357,8 +1345,8 @@ func (b BrowsersCmd) PlaywrightExecute(ctx context.Context, in BrowsersPlaywrigh } func (b BrowsersCmd) ProcessExec(ctx context.Context, in BrowsersProcessExecInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if b.process == nil { @@ -1424,8 +1412,8 @@ func (b BrowsersCmd) ProcessExec(ctx context.Context, in BrowsersProcessExecInpu } func (b BrowsersCmd) ProcessSpawn(ctx context.Context, in BrowsersProcessSpawnInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if b.process == nil { @@ -1573,8 +1561,8 @@ func (b BrowsersCmd) ProcessResize(ctx context.Context, in BrowsersProcessResize // FS Watch func (b BrowsersCmd) FSWatchStart(ctx context.Context, in BrowsersFSWatchStartInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if b.fsWatch == nil { @@ -1814,8 +1802,8 @@ func (b BrowsersCmd) FSDownloadDirZip(ctx context.Context, in BrowsersFSDownload } func (b BrowsersCmd) FSFileInfo(ctx context.Context, in BrowsersFSFileInfoInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if b.fs == nil { @@ -1841,8 +1829,8 @@ func (b BrowsersCmd) FSFileInfo(ctx context.Context, in BrowsersFSFileInfoInput) } func (b BrowsersCmd) FSListFiles(ctx context.Context, in BrowsersFSListFilesInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if b.fs == nil { @@ -2218,7 +2206,7 @@ Note: Profiles can only be loaded into sessions that don't already have a profil func init() { // list flags - browsersListCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(browsersListCmd) browsersListCmd.Flags().Bool("include-deleted", false, "DEPRECATED: Use --status instead. Include soft-deleted browser sessions in the results") browsersListCmd.Flags().String("status", "", "Filter by status: 'active' (default), 'deleted', or 'all'") browsersListCmd.Flags().Int("limit", 0, "Maximum number of results to return (default 20, max 100)") @@ -2226,14 +2214,14 @@ func init() { browsersListCmd.Flags().String("query", "", "Search browsers by session ID, profile ID, or proxy ID") // get flags - browsersGetCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(browsersGetCmd) browsersGetCmd.Flags().Bool("include-deleted", false, "Include soft-deleted browser sessions in the lookup") // view flags - browsersViewCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(browsersViewCmd) // update flags - browsersUpdateCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(browsersUpdateCmd) browsersUpdateCmd.Flags().String("proxy-id", "", "ID of the proxy to use for the browser session") browsersUpdateCmd.Flags().Bool("clear-proxy", false, "Remove the proxy from the browser session") browsersUpdateCmd.Flags().String("profile-id", "", "Profile ID to load into the browser session (mutually exclusive with --profile-name)") @@ -2266,11 +2254,11 @@ func init() { // replays replaysRoot := &cobra.Command{Use: "replays", Short: "Manage browser replays"} replaysList := &cobra.Command{Use: "list ", Short: "List replays for a browser", Args: cobra.ExactArgs(1), RunE: runBrowsersReplaysList} - replaysList.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(replaysList) replaysStart := &cobra.Command{Use: "start ", Short: "Start a replay recording", Args: cobra.ExactArgs(1), RunE: runBrowsersReplaysStart} replaysStart.Flags().Int("framerate", 0, "Recording framerate (fps)") replaysStart.Flags().Int("max-duration", 0, "Maximum duration in seconds") - replaysStart.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(replaysStart) replaysStop := &cobra.Command{Use: "stop ", Short: "Stop a replay recording", Args: cobra.ExactArgs(2), RunE: runBrowsersReplaysStop} replaysDownload := &cobra.Command{Use: "download ", Short: "Download a replay video", Args: cobra.ExactArgs(2), RunE: runBrowsersReplaysDownload} replaysDownload.Flags().StringP("output-file", "f", "", "Output file path for the replay video") @@ -2286,7 +2274,7 @@ func init() { procExec.Flags().Int("timeout", 0, "Timeout in seconds") procExec.Flags().String("as-user", "", "Run as user") procExec.Flags().Bool("as-root", false, "Run as root") - procExec.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(procExec) procSpawn := &cobra.Command{Use: "spawn [--] [command...]", Short: "Execute a command asynchronously", Args: cobra.MinimumNArgs(1), RunE: runBrowsersProcessSpawn} procSpawn.Flags().String("command", "", "Command to execute (optional; if omitted, trailing args are executed via /bin/bash -c)") procSpawn.Flags().StringSlice("args", []string{}, "Command arguments") @@ -2294,7 +2282,7 @@ func init() { procSpawn.Flags().Int("timeout", 0, "Timeout in seconds") procSpawn.Flags().String("as-user", "", "Run as user") procSpawn.Flags().Bool("as-root", false, "Run as root") - procSpawn.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(procSpawn) procKill := &cobra.Command{Use: "kill ", Short: "Send a signal to a process", Args: cobra.ExactArgs(2), RunE: runBrowsersProcessKill} procKill.Flags().String("signal", "TERM", "Signal to send (TERM, KILL, INT, HUP)") procStatus := &cobra.Command{Use: "status ", Short: "Get process status", Args: cobra.ExactArgs(2), RunE: runBrowsersProcessStatus} @@ -2329,11 +2317,11 @@ func init() { fsFileInfo := &cobra.Command{Use: "file-info ", Short: "Get file or directory info", Args: cobra.ExactArgs(1), RunE: runBrowsersFSFileInfo} fsFileInfo.Flags().String("path", "", "Absolute file or directory path") _ = fsFileInfo.MarkFlagRequired("path") - fsFileInfo.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(fsFileInfo) fsListFiles := &cobra.Command{Use: "list-files ", Short: "List files in a directory", Args: cobra.ExactArgs(1), RunE: runBrowsersFSListFiles} fsListFiles.Flags().String("path", "", "Absolute directory path") _ = fsListFiles.MarkFlagRequired("path") - fsListFiles.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(fsListFiles) fsMove := &cobra.Command{Use: "move ", Short: "Move or rename a file or directory", Args: cobra.ExactArgs(1), RunE: runBrowsersFSMove} fsMove.Flags().String("src", "", "Absolute source path") fsMove.Flags().String("dest", "", "Absolute destination path") @@ -2378,7 +2366,7 @@ func init() { fsWatchStart.Flags().String("path", "", "Directory to watch (required)") _ = fsWatchStart.MarkFlagRequired("path") fsWatchStart.Flags().Bool("recursive", false, "Watch recursively") - fsWatchStart.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(fsWatchStart) fsWatchStop := &cobra.Command{Use: "stop ", Short: "Stop watching a directory", Args: cobra.ExactArgs(2), RunE: runBrowsersFSWatchStop} fsWatchEvents := &cobra.Command{Use: "events ", Short: "Stream filesystem events", Args: cobra.ExactArgs(2), RunE: runBrowsersFSWatchEvents} fsWatchRoot.AddCommand(fsWatchStart, fsWatchStop, fsWatchEvents) @@ -2461,7 +2449,7 @@ func init() { // computer get-mouse-position computerGetMousePosition := &cobra.Command{Use: "get-mouse-position ", Short: "Get current mouse cursor position", Args: cobra.ExactArgs(1), RunE: runBrowsersComputerGetMousePosition} - computerGetMousePosition.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(computerGetMousePosition) // computer batch computerBatch := &cobra.Command{Use: "batch ", Short: "Execute a batch of computer actions from JSON", Args: cobra.ExactArgs(1), RunE: runBrowsersComputerBatch} @@ -2470,7 +2458,7 @@ func init() { // computer read-clipboard computerReadClipboard := &cobra.Command{Use: "read-clipboard ", Short: "Read text from the browser clipboard", Args: cobra.ExactArgs(1), RunE: runBrowsersComputerReadClipboard} - computerReadClipboard.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(computerReadClipboard) // computer write-clipboard computerWriteClipboard := &cobra.Command{Use: "write-clipboard ", Short: "Write text to the browser clipboard", Args: cobra.ExactArgs(1), RunE: runBrowsersComputerWriteClipboard} @@ -2484,12 +2472,12 @@ func init() { playwrightRoot := &cobra.Command{Use: "playwright", Short: "Playwright operations"} playwrightExecute := &cobra.Command{Use: "execute [code]", Short: "Execute Playwright/TypeScript code against the browser", Args: cobra.MinimumNArgs(1), RunE: runBrowsersPlaywrightExecute} playwrightExecute.Flags().Int64("timeout", 0, "Maximum execution time in seconds (default per server)") - playwrightExecute.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(playwrightExecute) playwrightRoot.AddCommand(playwrightExecute) browsersCmd.AddCommand(playwrightRoot) // Add flags for create command - browsersCreateCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(browsersCreateCmd) browsersCreateCmd.Flags().BoolP("stealth", "s", false, "Launch browser in stealth mode to avoid detection") browsersCreateCmd.Flags().BoolP("headless", "H", false, "Launch browser without GUI access") browsersCreateCmd.Flags().Bool("gpu", false, "Launch browser with hardware-accelerated GPU rendering") diff --git a/cmd/credential_providers.go b/cmd/credential_providers.go index 22f983d..a0f5808 100644 --- a/cmd/credential_providers.go +++ b/cmd/credential_providers.go @@ -71,8 +71,8 @@ type CredentialProvidersListItemsInput struct { } func (c CredentialProvidersCmd) List(ctx context.Context, in CredentialProvidersListInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } providers, err := c.providers.List(ctx) @@ -109,8 +109,8 @@ func (c CredentialProvidersCmd) List(ctx context.Context, in CredentialProviders } func (c CredentialProvidersCmd) Get(ctx context.Context, in CredentialProvidersGetInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } provider, err := c.providers.Get(ctx, in.ID) @@ -137,8 +137,8 @@ func (c CredentialProvidersCmd) Get(ctx context.Context, in CredentialProvidersG } func (c CredentialProvidersCmd) Create(ctx context.Context, in CredentialProvidersCreateInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if in.ProviderType == "" { @@ -197,8 +197,8 @@ func (c CredentialProvidersCmd) Create(ctx context.Context, in CredentialProvide } func (c CredentialProvidersCmd) Update(ctx context.Context, in CredentialProvidersUpdateInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } params := kernel.CredentialProviderUpdateParams{ @@ -260,8 +260,8 @@ func (c CredentialProvidersCmd) Delete(ctx context.Context, in CredentialProvide } func (c CredentialProvidersCmd) Test(ctx context.Context, in CredentialProvidersTestInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if in.Output != "json" { @@ -298,8 +298,8 @@ func (c CredentialProvidersCmd) Test(ctx context.Context, in CredentialProviders } func (c CredentialProvidersCmd) ListItems(ctx context.Context, in CredentialProvidersListItemsInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if in.Output != "json" { @@ -424,13 +424,13 @@ func init() { credentialProvidersCmd.AddCommand(credentialProvidersListItemsCmd) // List flags - credentialProvidersListCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(credentialProvidersListCmd) // Get flags - credentialProvidersGetCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(credentialProvidersGetCmd) // Create flags - credentialProvidersCreateCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(credentialProvidersCreateCmd) credentialProvidersCreateCmd.Flags().String("name", "", "Human-readable name for this provider instance") credentialProvidersCreateCmd.Flags().String("provider-type", "", "Provider type (e.g., onepassword)") credentialProvidersCreateCmd.Flags().String("token", "", "Service account token for the provider") @@ -440,7 +440,7 @@ func init() { _ = credentialProvidersCreateCmd.MarkFlagRequired("token") // Update flags - credentialProvidersUpdateCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(credentialProvidersUpdateCmd) credentialProvidersUpdateCmd.Flags().String("name", "", "New human-readable name for this provider instance") credentialProvidersUpdateCmd.Flags().String("token", "", "New service account token (to rotate credentials)") credentialProvidersUpdateCmd.Flags().Int64("cache-ttl", 0, "How long to cache credential lists in seconds") @@ -451,10 +451,10 @@ func init() { credentialProvidersDeleteCmd.Flags().BoolP("yes", "y", false, "Skip confirmation prompt") // Test flags - credentialProvidersTestCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(credentialProvidersTestCmd) // ListItems flags - credentialProvidersListItemsCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(credentialProvidersListItemsCmd) } func runCredentialProvidersList(cmd *cobra.Command, args []string) error { diff --git a/cmd/credentials.go b/cmd/credentials.go index 15952c6..b17a37e 100644 --- a/cmd/credentials.go +++ b/cmd/credentials.go @@ -69,8 +69,8 @@ type CredentialsTotpCodeInput struct { } func (c CredentialsCmd) List(ctx context.Context, in CredentialsListInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } params := kernel.CredentialListParams{} @@ -132,8 +132,8 @@ func (c CredentialsCmd) List(ctx context.Context, in CredentialsListInput) error } func (c CredentialsCmd) Get(ctx context.Context, in CredentialsGetInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } cred, err := c.credentials.Get(ctx, in.Identifier) @@ -170,8 +170,8 @@ func (c CredentialsCmd) Get(ctx context.Context, in CredentialsGetInput) error { } func (c CredentialsCmd) Create(ctx context.Context, in CredentialsCreateInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if in.Name == "" { @@ -242,8 +242,8 @@ func (c CredentialsCmd) Create(ctx context.Context, in CredentialsCreateInput) e } func (c CredentialsCmd) Update(ctx context.Context, in CredentialsUpdateInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } params := kernel.CredentialUpdateParams{ @@ -302,8 +302,8 @@ func (c CredentialsCmd) Delete(ctx context.Context, in CredentialsDeleteInput) e } func (c CredentialsCmd) TotpCode(ctx context.Context, in CredentialsTotpCodeInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } resp, err := c.credentials.TotpCode(ctx, in.Identifier) @@ -398,16 +398,16 @@ func init() { credentialsCmd.AddCommand(credentialsTotpCodeCmd) // List flags - credentialsListCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(credentialsListCmd) credentialsListCmd.Flags().String("domain", "", "Filter by domain") credentialsListCmd.Flags().Int("limit", 0, "Maximum number of results to return") credentialsListCmd.Flags().Int("offset", 0, "Number of results to skip") // Get flags - credentialsGetCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(credentialsGetCmd) // Create flags - credentialsCreateCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(credentialsCreateCmd) credentialsCreateCmd.Flags().String("name", "", "Unique name for the credential (required)") credentialsCreateCmd.Flags().String("domain", "", "Target domain this credential is for (required)") credentialsCreateCmd.Flags().StringArray("value", []string{}, "Field name=value pair (repeatable, e.g., --value username=myuser --value password=mypass)") @@ -417,7 +417,7 @@ func init() { _ = credentialsCreateCmd.MarkFlagRequired("domain") // Update flags - credentialsUpdateCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(credentialsUpdateCmd) credentialsUpdateCmd.Flags().String("name", "", "New name for the credential") credentialsUpdateCmd.Flags().String("sso-provider", "", "SSO provider (set to empty string to remove)") credentialsUpdateCmd.Flags().String("totp-secret", "", "Base32-encoded TOTP secret (set to empty string to remove)") @@ -427,7 +427,7 @@ func init() { credentialsDeleteCmd.Flags().BoolP("yes", "y", false, "Skip confirmation prompt") // TOTP code flags - credentialsTotpCodeCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(credentialsTotpCodeCmd) } func runCredentialsList(cmd *cobra.Command, args []string) error { diff --git a/cmd/deploy.go b/cmd/deploy.go index 40417f7..f399d6c 100644 --- a/cmd/deploy.go +++ b/cmd/deploy.go @@ -77,7 +77,7 @@ func init() { deployCmd.Flags().StringP("output", "o", "", "Output format: json for JSONL streaming output") // Subcommands under deploy - deployGetCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(deployGetCmd) deployCmd.AddCommand(deployGetCmd) deployLogsCmd.Flags().BoolP("follow", "f", false, "Follow logs in real-time (stream continuously)") @@ -92,7 +92,7 @@ func init() { deployHistoryCmd.Flags().Int("per-page", 20, "Items per page (alias of --limit)") deployHistoryCmd.Flags().Int("page", 1, "Page number (1-based)") deployHistoryCmd.Flags().String("app-version", "", "Filter by application version (requires app_name)") - deployHistoryCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(deployHistoryCmd) deployCmd.AddCommand(deployHistoryCmd) // Flags for GitHub deploy @@ -122,8 +122,8 @@ func runDeployGithub(cmd *cobra.Command, args []string) error { force, _ := cmd.Flags().GetBool("force") output, _ := cmd.Flags().GetString("output") - if output != "" && output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(output); err != nil { + return err } // Collect env vars similar to runDeploy @@ -244,8 +244,8 @@ func runDeploy(cmd *cobra.Command, args []string) (err error) { region, _ := cmd.Flags().GetString("region") output, _ := cmd.Flags().GetString("output") - if output != "" && output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(output); err != nil { + return err } if version == "" { @@ -341,8 +341,8 @@ func runDeployGet(cmd *cobra.Command, args []string) error { client := getKernelClient(cmd) output, _ := cmd.Flags().GetString("output") - if output != "" && output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(output); err != nil { + return err } deployment, err := client.Deployments.Get(cmd.Context(), args[0]) @@ -489,8 +489,8 @@ func runDeployHistory(cmd *cobra.Command, args []string) error { appVersionFilter, _ := cmd.Flags().GetString("app-version") output, _ := cmd.Flags().GetString("output") - if output != "" && output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(output); err != nil { + return err } // Prefer page/per-page when provided; map legacy --limit otherwise diff --git a/cmd/extensions.go b/cmd/extensions.go index b941336..7138e0c 100644 --- a/cmd/extensions.go +++ b/cmd/extensions.go @@ -82,8 +82,8 @@ type ExtensionsCmd struct { } func (e ExtensionsCmd) List(ctx context.Context, in ExtensionsListInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if in.Output != "json" { @@ -301,8 +301,8 @@ func (e ExtensionsCmd) DownloadWebStore(ctx context.Context, in ExtensionsDownlo } func (e ExtensionsCmd) Upload(ctx context.Context, in ExtensionsUploadInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if in.Dir == "" { @@ -518,12 +518,12 @@ func init() { extensionsCmd.AddCommand(extensionsUploadCmd) extensionsCmd.AddCommand(extensionsBuildWebBotAuthCmd) - extensionsListCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(extensionsListCmd) extensionsDeleteCmd.Flags().BoolP("yes", "y", false, "Skip confirmation prompt") extensionsDownloadCmd.Flags().String("to", "", "Output zip file path") extensionsDownloadWebStoreCmd.Flags().String("to", "", "Output zip file path for the downloaded archive") extensionsDownloadWebStoreCmd.Flags().String("os", "", "Target OS: mac, win, or linux (default linux)") - extensionsUploadCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(extensionsUploadCmd) extensionsUploadCmd.Flags().String("name", "", "Optional unique extension name") extensionsBuildWebBotAuthCmd.Flags().String("to", "./web-bot-auth", "Output directory for the prepared extension") extensionsBuildWebBotAuthCmd.Flags().String("url", "http://127.0.0.1:10001", "Base URL for update.xml and policy templates") diff --git a/cmd/extensions_test.go b/cmd/extensions_test.go index 4a1770c..8be3c44 100644 --- a/cmd/extensions_test.go +++ b/cmd/extensions_test.go @@ -14,34 +14,9 @@ import ( "github.com/kernel/kernel-go-sdk" "github.com/kernel/kernel-go-sdk/option" - "github.com/pterm/pterm" "github.com/stretchr/testify/assert" ) -// captureExtensionsOutput sets pterm writers for tests in this file -func captureExtensionsOutput(t *testing.T) *bytes.Buffer { - var buf bytes.Buffer - pterm.SetDefaultOutput(&buf) - pterm.Info.Writer = &buf - pterm.Error.Writer = &buf - pterm.Success.Writer = &buf - pterm.Warning.Writer = &buf - pterm.Debug.Writer = &buf - pterm.Fatal.Writer = &buf - pterm.DefaultTable = *pterm.DefaultTable.WithWriter(&buf) - t.Cleanup(func() { - pterm.SetDefaultOutput(os.Stdout) - pterm.Info.Writer = os.Stdout - pterm.Error.Writer = os.Stdout - pterm.Success.Writer = os.Stdout - pterm.Warning.Writer = os.Stdout - pterm.Debug.Writer = os.Stdout - pterm.Fatal.Writer = os.Stdout - pterm.DefaultTable = *pterm.DefaultTable.WithWriter(os.Stdout) - }) - return &buf -} - // FakeExtensionsService implements ExtensionsService type FakeExtensionsService struct { ListFunc func(ctx context.Context, opts ...option.RequestOption) (*[]kernel.ExtensionListResponse, error) @@ -84,7 +59,7 @@ func (f *FakeExtensionsService) Upload(ctx context.Context, body kernel.Extensio } func TestExtensionsList_Empty(t *testing.T) { - buf := captureExtensionsOutput(t) + buf := capturePtermOutput(t) fake := &FakeExtensionsService{} e := ExtensionsCmd{extensions: fake} _ = e.List(context.Background(), ExtensionsListInput{}) @@ -92,7 +67,7 @@ func TestExtensionsList_Empty(t *testing.T) { } func TestExtensionsList_WithRows(t *testing.T) { - buf := captureExtensionsOutput(t) + buf := capturePtermOutput(t) created := time.Unix(0, 0) rows := []kernel.ExtensionListResponse{{ID: "e1", Name: "alpha", CreatedAt: created, SizeBytes: 10}, {ID: "e2", Name: "", CreatedAt: created, SizeBytes: 20}} fake := &FakeExtensionsService{ListFunc: func(ctx context.Context, opts ...option.RequestOption) (*[]kernel.ExtensionListResponse, error) { @@ -107,7 +82,7 @@ func TestExtensionsList_WithRows(t *testing.T) { } func TestExtensionsDelete_SkipConfirm(t *testing.T) { - buf := captureExtensionsOutput(t) + buf := capturePtermOutput(t) fake := &FakeExtensionsService{} e := ExtensionsCmd{extensions: fake} _ = e.Delete(context.Background(), ExtensionsDeleteInput{Identifier: "e1", SkipConfirm: true}) @@ -115,7 +90,7 @@ func TestExtensionsDelete_SkipConfirm(t *testing.T) { } func TestExtensionsDelete_NotFound(t *testing.T) { - buf := captureExtensionsOutput(t) + buf := capturePtermOutput(t) fake := &FakeExtensionsService{DeleteFunc: func(ctx context.Context, idOrName string, opts ...option.RequestOption) error { return &kernel.Error{StatusCode: http.StatusNotFound} }} @@ -125,7 +100,7 @@ func TestExtensionsDelete_NotFound(t *testing.T) { } func TestExtensionsDownload_MissingOutput(t *testing.T) { - buf := captureExtensionsOutput(t) + buf := capturePtermOutput(t) fake := &FakeExtensionsService{DownloadFunc: func(ctx context.Context, idOrName string, opts ...option.RequestOption) (*http.Response, error) { return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("content")), Header: http.Header{}}, nil }} @@ -135,7 +110,7 @@ func TestExtensionsDownload_MissingOutput(t *testing.T) { } func TestExtensionsDownload_ExtractsToDir(t *testing.T) { - buf := captureExtensionsOutput(t) + buf := capturePtermOutput(t) // Create a small in-memory zip var zbuf bytes.Buffer zw := zip.NewWriter(&zbuf) @@ -160,7 +135,7 @@ func TestExtensionsDownload_ExtractsToDir(t *testing.T) { } func TestExtensionsDownloadWebStore_ExtractsToDir(t *testing.T) { - buf := captureExtensionsOutput(t) + buf := capturePtermOutput(t) var zbuf bytes.Buffer zw := zip.NewWriter(&zbuf) w, _ := zw.Create("manifest.json") @@ -183,7 +158,7 @@ func TestExtensionsDownloadWebStore_ExtractsToDir(t *testing.T) { } func TestExtensionsDownloadWebStore_InvalidOS(t *testing.T) { - buf := captureExtensionsOutput(t) + buf := capturePtermOutput(t) fake := &FakeExtensionsService{} e := ExtensionsCmd{extensions: fake} _ = e.DownloadWebStore(context.Background(), ExtensionsDownloadWebStoreInput{URL: "https://store/link", Output: "x", OS: "freebsd"}) @@ -191,7 +166,7 @@ func TestExtensionsDownloadWebStore_InvalidOS(t *testing.T) { } func TestExtensionsUpload_Success(t *testing.T) { - buf := captureExtensionsOutput(t) + buf := capturePtermOutput(t) dir := t.TempDir() // create a sample file inside dir err := os.WriteFile(filepath.Join(dir, "manifest.json"), []byte("{}"), 0644) diff --git a/cmd/flag_values.go b/cmd/flag_values.go new file mode 100644 index 0000000..b92b315 --- /dev/null +++ b/cmd/flag_values.go @@ -0,0 +1,13 @@ +package cmd + +// BoolFlag captures whether a boolean flag was set explicitly and its value. +type BoolFlag struct { + Set bool + Value bool +} + +// Int64Flag captures whether an int64 flag was set explicitly and its value. +type Int64Flag struct { + Set bool + Value int64 +} diff --git a/cmd/invoke.go b/cmd/invoke.go index 1c1543e..cca3804 100644 --- a/cmd/invoke.go +++ b/cmd/invoke.go @@ -84,13 +84,13 @@ func init() { invocationHistoryCmd.Flags().String("since", "", "Show invocations that started since the given time") invocationHistoryCmd.Flags().String("status", "", "Filter by invocation status: queued, running, succeeded, failed") invocationHistoryCmd.Flags().String("version", "", "Filter by invocation version") - invocationHistoryCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(invocationHistoryCmd) invokeCmd.AddCommand(invocationHistoryCmd) - invocationBrowsersCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(invocationBrowsersCmd) invokeCmd.AddCommand(invocationBrowsersCmd) - invocationGetCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(invocationGetCmd) invokeCmd.AddCommand(invocationGetCmd) invocationUpdateCmd.Flags().String("status", "", "New invocation status: succeeded or failed") @@ -112,8 +112,8 @@ func runInvoke(cmd *cobra.Command, args []string) error { version, _ := cmd.Flags().GetString("version") output, _ := cmd.Flags().GetString("output") - if output != "" && output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(output); err != nil { + return err } jsonOutput := output == "json" @@ -427,8 +427,8 @@ func runInvocationHistory(cmd *cobra.Command, args []string) error { versionFilter, _ := cmd.Flags().GetString("version") output, _ := cmd.Flags().GetString("output") - if output != "" && output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(output); err != nil { + return err } // Build parameters for the API call @@ -552,8 +552,8 @@ func runInvocationBrowsers(cmd *cobra.Command, args []string) error { invocationID := args[0] output, _ := cmd.Flags().GetString("output") - if output != "" && output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(output); err != nil { + return err } resp, err := client.Invocations.ListBrowsers(cmd.Context(), invocationID) @@ -608,8 +608,8 @@ func runInvocationGet(cmd *cobra.Command, args []string) error { client := getKernelClient(cmd) output, _ := cmd.Flags().GetString("output") - if output != "" && output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(output); err != nil { + return err } resp, err := client.Invocations.Get(cmd.Context(), args[0]) diff --git a/cmd/output.go b/cmd/output.go new file mode 100644 index 0000000..277b449 --- /dev/null +++ b/cmd/output.go @@ -0,0 +1,14 @@ +package cmd + +import ( + "github.com/kernel/cli/pkg/util" + "github.com/spf13/cobra" +) + +func validateJSONOutput(output string) error { + return util.ValidateJSONOutput(output) +} + +func addJSONOutputFlag(cmd *cobra.Command) { + util.AddJSONOutputFlag(cmd) +} diff --git a/cmd/output_test.go b/cmd/output_test.go new file mode 100644 index 0000000..1544b92 --- /dev/null +++ b/cmd/output_test.go @@ -0,0 +1,19 @@ +package cmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateJSONOutput(t *testing.T) { + require.NoError(t, validateJSONOutput("")) + require.NoError(t, validateJSONOutput("json")) + + err := validateJSONOutput("yaml") + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported --output value") + assert.Contains(t, err.Error(), `"yaml"`) + assert.Contains(t, err.Error(), "omit --output") +} diff --git a/cmd/profiles.go b/cmd/profiles.go index c19d3ca..bfc055c 100644 --- a/cmd/profiles.go +++ b/cmd/profiles.go @@ -63,8 +63,8 @@ type ProfilesCmd struct { } func (p ProfilesCmd) List(ctx context.Context, in ProfilesListInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } page := in.Page @@ -145,8 +145,8 @@ func (p ProfilesCmd) List(ctx context.Context, in ProfilesListInput) error { } func (p ProfilesCmd) Get(ctx context.Context, in ProfilesGetInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } item, err := p.profiles.Get(ctx, in.Identifier) @@ -181,8 +181,8 @@ func (p ProfilesCmd) Get(ctx context.Context, in ProfilesGetInput) error { } func (p ProfilesCmd) Create(ctx context.Context, in ProfilesCreateInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } params := kernel.ProfileNewParams{} @@ -388,12 +388,12 @@ func init() { profilesCmd.AddCommand(profilesDeleteCmd) profilesCmd.AddCommand(profilesDownloadCmd) - profilesListCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(profilesListCmd) profilesListCmd.Flags().Int("per-page", 20, "Items per page (default 20)") profilesListCmd.Flags().Int("page", 1, "Page number (1-based)") profilesListCmd.Flags().String("query", "", "Search profiles by name or ID") - profilesGetCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") - profilesCreateCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(profilesGetCmd) + addJSONOutputFlag(profilesCreateCmd) profilesCreateCmd.Flags().String("name", "", "Optional unique profile name") profilesDeleteCmd.Flags().BoolP("yes", "y", false, "Skip confirmation prompt") profilesDownloadCmd.Flags().String("to", "", "Directory to extract the profile into (required)") diff --git a/cmd/profiles_test.go b/cmd/profiles_test.go index 73f84c0..d104384 100644 --- a/cmd/profiles_test.go +++ b/cmd/profiles_test.go @@ -18,34 +18,9 @@ import ( "github.com/kernel/kernel-go-sdk/option" "github.com/kernel/kernel-go-sdk/packages/pagination" "github.com/klauspost/compress/zstd" - "github.com/pterm/pterm" "github.com/stretchr/testify/assert" ) -// captureProfilesOutput sets pterm writers for tests in this file -func captureProfilesOutput(t *testing.T) *bytes.Buffer { - var buf bytes.Buffer - pterm.SetDefaultOutput(&buf) - pterm.Info.Writer = &buf - pterm.Error.Writer = &buf - pterm.Success.Writer = &buf - pterm.Warning.Writer = &buf - pterm.Debug.Writer = &buf - pterm.Fatal.Writer = &buf - pterm.DefaultTable = *pterm.DefaultTable.WithWriter(&buf) - t.Cleanup(func() { - pterm.SetDefaultOutput(os.Stdout) - pterm.Info.Writer = os.Stdout - pterm.Error.Writer = os.Stdout - pterm.Success.Writer = os.Stdout - pterm.Warning.Writer = os.Stdout - pterm.Debug.Writer = os.Stdout - pterm.Fatal.Writer = os.Stdout - pterm.DefaultTable = *pterm.DefaultTable.WithWriter(os.Stdout) - }) - return &buf -} - // FakeProfilesService implements ProfilesService type FakeProfilesService struct { GetFunc func(ctx context.Context, idOrName string, opts ...option.RequestOption) (*kernel.Profile, error) @@ -87,7 +62,7 @@ func (f *FakeProfilesService) New(ctx context.Context, body kernel.ProfileNewPar } func TestProfilesList_Empty(t *testing.T) { - buf := captureProfilesOutput(t) + buf := capturePtermOutput(t) fake := &FakeProfilesService{} p := ProfilesCmd{profiles: fake} _ = p.List(context.Background(), ProfilesListInput{Page: 1, PerPage: 20}) @@ -95,7 +70,7 @@ func TestProfilesList_Empty(t *testing.T) { } func TestProfilesList_WithRows(t *testing.T) { - buf := captureProfilesOutput(t) + buf := capturePtermOutput(t) created := time.Unix(0, 0) rows := []kernel.Profile{{ID: "p1", Name: "alpha", CreatedAt: created, UpdatedAt: created}, {ID: "p2", Name: "", CreatedAt: created, UpdatedAt: created}} fake := &FakeProfilesService{ListFunc: func(ctx context.Context, query kernel.ProfileListParams, opts ...option.RequestOption) (*pagination.OffsetPagination[kernel.Profile], error) { @@ -111,7 +86,7 @@ func TestProfilesList_WithRows(t *testing.T) { } func TestProfilesList_HasMore(t *testing.T) { - buf := captureProfilesOutput(t) + buf := capturePtermOutput(t) created := time.Unix(0, 0) perPage := 2 items := make([]kernel.Profile, perPage+1) @@ -132,7 +107,7 @@ func TestProfilesList_HasMore(t *testing.T) { } func TestProfilesList_QueryInNextHint(t *testing.T) { - buf := captureProfilesOutput(t) + buf := capturePtermOutput(t) created := time.Unix(0, 0) items := make([]kernel.Profile, 3) for i := range items { @@ -148,7 +123,7 @@ func TestProfilesList_QueryInNextHint(t *testing.T) { } func TestProfilesList_QueryWithSpacesQuoted(t *testing.T) { - buf := captureProfilesOutput(t) + buf := capturePtermOutput(t) created := time.Unix(0, 0) items := make([]kernel.Profile, 3) for i := range items { @@ -164,7 +139,7 @@ func TestProfilesList_QueryWithSpacesQuoted(t *testing.T) { } func TestProfilesGet_Success(t *testing.T) { - buf := captureProfilesOutput(t) + buf := capturePtermOutput(t) fake := &FakeProfilesService{GetFunc: func(ctx context.Context, idOrName string, opts ...option.RequestOption) (*kernel.Profile, error) { return &kernel.Profile{ID: "p1", Name: "alpha", CreatedAt: time.Unix(0, 0), UpdatedAt: time.Unix(0, 0)}, nil }} @@ -188,7 +163,7 @@ func TestProfilesGet_Error(t *testing.T) { } func TestProfilesCreate_Success(t *testing.T) { - buf := captureProfilesOutput(t) + buf := capturePtermOutput(t) fake := &FakeProfilesService{NewFunc: func(ctx context.Context, body kernel.ProfileNewParams, opts ...option.RequestOption) (*kernel.Profile, error) { return &kernel.Profile{ID: "pnew", Name: body.Name.Value, CreatedAt: time.Unix(0, 0), UpdatedAt: time.Unix(0, 0)}, nil }} @@ -210,7 +185,7 @@ func TestProfilesCreate_Error(t *testing.T) { } func TestProfilesDelete_ConfirmNotFound(t *testing.T) { - buf := captureProfilesOutput(t) + buf := capturePtermOutput(t) fake := &FakeProfilesService{GetFunc: func(ctx context.Context, idOrName string, opts ...option.RequestOption) (*kernel.Profile, error) { return nil, &kernel.Error{StatusCode: http.StatusNotFound} }} @@ -220,7 +195,7 @@ func TestProfilesDelete_ConfirmNotFound(t *testing.T) { } func TestProfilesDelete_SkipConfirm(t *testing.T) { - buf := captureProfilesOutput(t) + buf := capturePtermOutput(t) fake := &FakeProfilesService{} p := ProfilesCmd{profiles: fake} _ = p.Delete(context.Background(), ProfilesDeleteInput{Identifier: "a", SkipConfirm: true}) @@ -255,7 +230,7 @@ func TestProfilesDownload_MissingTo(t *testing.T) { } func TestProfilesDownload_ExtractSuccess(t *testing.T) { - buf := captureProfilesOutput(t) + buf := capturePtermOutput(t) dir, err := os.MkdirTemp("", "profile-*") assert.NoError(t, err) defer os.RemoveAll(dir) @@ -283,7 +258,7 @@ func TestProfilesDownload_ExtractSuccess(t *testing.T) { } func TestProfilesDownload_202NoData(t *testing.T) { - buf := captureProfilesOutput(t) + buf := capturePtermOutput(t) dir, err := os.MkdirTemp("", "profile-*") assert.NoError(t, err) defer os.RemoveAll(dir) diff --git a/cmd/projects.go b/cmd/projects.go index a223165..4054961 100644 --- a/cmd/projects.go +++ b/cmd/projects.go @@ -148,8 +148,8 @@ func (c ProjectsCmd) Delete(ctx context.Context, in ProjectsDeleteInput) error { } func (c ProjectsCmd) LimitsGet(ctx context.Context, in ProjectsLimitsGetInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } projectID, err := resolveProjectArg(ctx, c.projects, in.Identifier) @@ -175,8 +175,8 @@ func (c ProjectsCmd) LimitsGet(ctx context.Context, in ProjectsLimitsGetInput) e } func (c ProjectsCmd) LimitsSet(ctx context.Context, in ProjectsLimitsSetInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } projectID, err := resolveProjectArg(ctx, c.projects, in.Identifier) @@ -311,15 +311,11 @@ func runProjectsLimitsSet(cmd *cobra.Command, args []string) error { }) } -func addProjectsLimitsOutputFlag(cmd *cobra.Command) { - cmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") -} - func addProjectsLimitsSetFlags(cmd *cobra.Command) { cmd.Flags().Int64("max-concurrent-sessions", 0, "Maximum concurrent browser sessions (0 to remove cap)") cmd.Flags().Int64("max-concurrent-invocations", 0, "Maximum concurrent app invocations (0 to remove cap)") cmd.Flags().Int64("max-pooled-sessions", 0, "Maximum pooled sessions capacity (0 to remove cap)") - addProjectsLimitsOutputFlag(cmd) + addJSONOutputFlag(cmd) } var projectsCmd = &cobra.Command{ @@ -396,9 +392,9 @@ var projectsSetLimitsCompatCmd = &cobra.Command{ } func init() { - addProjectsLimitsOutputFlag(projectsLimitsGetCmd) + addJSONOutputFlag(projectsLimitsGetCmd) addProjectsLimitsSetFlags(projectsLimitsSetCmd) - addProjectsLimitsOutputFlag(projectsGetLimitsCompatCmd) + addJSONOutputFlag(projectsGetLimitsCompatCmd) addProjectsLimitsSetFlags(projectsSetLimitsCompatCmd) projectsLimitsCmd.AddCommand(projectsLimitsGetCmd) diff --git a/cmd/projects_test.go b/cmd/projects_test.go index 33ae60d..3855cdf 100644 --- a/cmd/projects_test.go +++ b/cmd/projects_test.go @@ -1,43 +1,17 @@ package cmd import ( - "bytes" "context" "errors" - "os" "testing" "github.com/kernel/kernel-go-sdk" "github.com/kernel/kernel-go-sdk/option" "github.com/kernel/kernel-go-sdk/packages/pagination" "github.com/kernel/kernel-go-sdk/packages/respjson" - "github.com/pterm/pterm" "github.com/stretchr/testify/assert" ) -func captureProjectsOutput(t *testing.T) *bytes.Buffer { - var buf bytes.Buffer - pterm.SetDefaultOutput(&buf) - pterm.Info.Writer = &buf - pterm.Error.Writer = &buf - pterm.Success.Writer = &buf - pterm.Warning.Writer = &buf - pterm.Debug.Writer = &buf - pterm.Fatal.Writer = &buf - pterm.DefaultTable = *pterm.DefaultTable.WithWriter(&buf) - t.Cleanup(func() { - pterm.SetDefaultOutput(os.Stdout) - pterm.Info.Writer = os.Stdout - pterm.Error.Writer = os.Stdout - pterm.Success.Writer = os.Stdout - pterm.Warning.Writer = os.Stdout - pterm.Debug.Writer = os.Stdout - pterm.Fatal.Writer = os.Stdout - pterm.DefaultTable = *pterm.DefaultTable.WithWriter(os.Stdout) - }) - return &buf -} - type FakeProjectsService struct { ListFunc func(ctx context.Context, query kernel.ProjectListParams, opts ...option.RequestOption) (*pagination.OffsetPagination[kernel.Project], error) NewFunc func(ctx context.Context, body kernel.ProjectNewParams, opts ...option.RequestOption) (*kernel.Project, error) @@ -93,7 +67,7 @@ func (f *FakeProjectLimitsService) Update(ctx context.Context, id string, body k } func TestProjectsLimitsGet_DefaultOutput(t *testing.T) { - buf := captureProjectsOutput(t) + buf := capturePtermOutput(t) limits := &kernel.ProjectLimits{ MaxConcurrentSessions: 10, MaxConcurrentInvocations: 5, @@ -129,6 +103,24 @@ func TestProjectsLimitsGet_InvalidOutput(t *testing.T) { assert.Contains(t, err.Error(), "unsupported --output value") } +func TestProjectsLimitsSet_InvalidOutput(t *testing.T) { + c := ProjectsCmd{ + projects: &FakeProjectsService{}, + limits: &FakeProjectLimitsService{ + UpdateFunc: func(ctx context.Context, id string, body kernel.ProjectLimitUpdateParams, opts ...option.RequestOption) (*kernel.ProjectLimits, error) { + t.Fatal("Update should not be called") + return nil, nil + }, + }, + } + err := c.LimitsSet(context.Background(), ProjectsLimitsSetInput{ + Identifier: "a12345678901234567890123", + Output: "yaml", + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported --output value") +} + func TestProjectsLimitsSet_RejectsNegativeValues(t *testing.T) { c := ProjectsCmd{projects: &FakeProjectsService{}, limits: &FakeProjectLimitsService{}} err := c.LimitsSet(context.Background(), ProjectsLimitsSetInput{ @@ -143,7 +135,7 @@ func TestProjectsLimitsSet_RejectsNegativeValues(t *testing.T) { } func TestProjectsLimitsSet_Success(t *testing.T) { - buf := captureProjectsOutput(t) + buf := capturePtermOutput(t) fakeProjects := &FakeProjectsService{} fakeLimits := &FakeProjectLimitsService{ UpdateFunc: func(ctx context.Context, id string, body kernel.ProjectLimitUpdateParams, opts ...option.RequestOption) (*kernel.ProjectLimits, error) { diff --git a/cmd/proxies/check.go b/cmd/proxies/check.go index a820ff4..2fb2650 100644 --- a/cmd/proxies/check.go +++ b/cmd/proxies/check.go @@ -12,8 +12,8 @@ import ( ) func (p ProxyCmd) Check(ctx context.Context, in ProxyCheckInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if in.Output != "json" { diff --git a/cmd/proxies/create.go b/cmd/proxies/create.go index 746e1d5..44de57f 100644 --- a/cmd/proxies/create.go +++ b/cmd/proxies/create.go @@ -13,8 +13,8 @@ import ( ) func (p ProxyCmd) Create(ctx context.Context, in ProxyCreateInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } // Validate proxy type diff --git a/cmd/proxies/get.go b/cmd/proxies/get.go index 1c5a54b..83c1f65 100644 --- a/cmd/proxies/get.go +++ b/cmd/proxies/get.go @@ -12,8 +12,8 @@ import ( ) func (p ProxyCmd) Get(ctx context.Context, in ProxyGetInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } item, err := p.proxies.Get(ctx, in.ID) diff --git a/cmd/proxies/list.go b/cmd/proxies/list.go index 7300611..762503e 100644 --- a/cmd/proxies/list.go +++ b/cmd/proxies/list.go @@ -13,8 +13,8 @@ import ( ) func (p ProxyCmd) List(ctx context.Context, in ProxyListInput) error { - if in.Output != "" && in.Output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(in.Output); err != nil { + return err } if in.Output != "json" { diff --git a/cmd/proxies/output.go b/cmd/proxies/output.go new file mode 100644 index 0000000..e036121 --- /dev/null +++ b/cmd/proxies/output.go @@ -0,0 +1,14 @@ +package proxies + +import ( + "github.com/kernel/cli/pkg/util" + "github.com/spf13/cobra" +) + +func validateJSONOutput(output string) error { + return util.ValidateJSONOutput(output) +} + +func addJSONOutputFlag(cmd *cobra.Command) { + util.AddJSONOutputFlag(cmd) +} diff --git a/cmd/proxies/proxies.go b/cmd/proxies/proxies.go index 2440d3a..ee93faf 100644 --- a/cmd/proxies/proxies.go +++ b/cmd/proxies/proxies.go @@ -80,9 +80,9 @@ func init() { ProxiesCmd.AddCommand(proxiesCheckCmd) // Add output flags - proxiesListCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") - proxiesGetCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") - proxiesCreateCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(proxiesListCmd) + addJSONOutputFlag(proxiesGetCmd) + addJSONOutputFlag(proxiesCreateCmd) // Add flags for create command proxiesCreateCmd.Flags().String("name", "", "Proxy configuration name") @@ -114,5 +114,5 @@ func init() { proxiesDeleteCmd.Flags().BoolP("yes", "y", false, "Skip confirmation prompt") // Check flags - proxiesCheckCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") + addJSONOutputFlag(proxiesCheckCmd) } diff --git a/cmd/ssh.go b/cmd/ssh.go index 307149c..5de7096 100644 --- a/cmd/ssh.go +++ b/cmd/ssh.go @@ -73,8 +73,8 @@ func runSSH(cmd *cobra.Command, args []string) error { setupOnly, _ := cmd.Flags().GetBool("setup-only") output, _ := cmd.Flags().GetString("output") - if output != "" && output != "json" { - return fmt.Errorf("unsupported --output value: use 'json'") + if err := validateJSONOutput(output); err != nil { + return err } if output == "json" && !setupOnly { return fmt.Errorf("--output json is only supported with --setup-only") diff --git a/cmd/status.go b/cmd/status.go index 4d1ca70..f225e22 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -35,11 +35,14 @@ var statusCmd = &cobra.Command{ } func init() { - statusCmd.Flags().StringP("output", "o", "", "Output format (json)") + addJSONOutputFlag(statusCmd) } func runStatus(cmd *cobra.Command, args []string) error { output, _ := cmd.Flags().GetString("output") + if err := validateJSONOutput(output); err != nil { + return err + } client := &http.Client{Timeout: 10 * time.Second} resp, err := client.Get(util.GetBaseURL() + "/status") diff --git a/cmd/test_helpers_test.go b/cmd/test_helpers_test.go new file mode 100644 index 0000000..779e58b --- /dev/null +++ b/cmd/test_helpers_test.go @@ -0,0 +1,32 @@ +package cmd + +import ( + "bytes" + "os" + "testing" + + "github.com/pterm/pterm" +) + +func capturePtermOutput(t *testing.T) *bytes.Buffer { + var buf bytes.Buffer + pterm.SetDefaultOutput(&buf) + pterm.Info.Writer = &buf + pterm.Error.Writer = &buf + pterm.Success.Writer = &buf + pterm.Warning.Writer = &buf + pterm.Debug.Writer = &buf + pterm.Fatal.Writer = &buf + pterm.DefaultTable = *pterm.DefaultTable.WithWriter(&buf) + t.Cleanup(func() { + pterm.SetDefaultOutput(os.Stdout) + pterm.Info.Writer = os.Stdout + pterm.Error.Writer = os.Stdout + pterm.Success.Writer = os.Stdout + pterm.Warning.Writer = os.Stdout + pterm.Debug.Writer = os.Stdout + pterm.Fatal.Writer = os.Stdout + pterm.DefaultTable = *pterm.DefaultTable.WithWriter(os.Stdout) + }) + return &buf +} diff --git a/go.mod b/go.mod index 62cf135..5ed2523 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/joho/godotenv v1.5.1 - github.com/kernel/kernel-go-sdk v0.53.0 + github.com/kernel/kernel-go-sdk v0.58.0 github.com/klauspost/compress v1.18.5 github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/pterm/pterm v0.12.80 diff --git a/go.sum b/go.sum index b29cf9b..96b230c 100644 --- a/go.sum +++ b/go.sum @@ -64,8 +64,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= -github.com/kernel/kernel-go-sdk v0.53.0 h1:XgcuJv3G4a6nr9LYBZ21gLUWvsIDLSG4YhZAngNrqE0= -github.com/kernel/kernel-go-sdk v0.53.0/go.mod h1:EeZzSuHZVeHKxKCPUzxou2bovNGhXaz0RXrSqKNf1AQ= +github.com/kernel/kernel-go-sdk v0.58.0 h1:FcvqZXgK5D3IbHJarvRJVPJpKE3+Pd7i4z4kBgElpIk= +github.com/kernel/kernel-go-sdk v0.58.0/go.mod h1:EeZzSuHZVeHKxKCPUzxou2bovNGhXaz0RXrSqKNf1AQ= github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= diff --git a/pkg/util/output.go b/pkg/util/output.go new file mode 100644 index 0000000..68ecd58 --- /dev/null +++ b/pkg/util/output.go @@ -0,0 +1,20 @@ +package util + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +const JSONOutputFlagDescription = "Output format: json for raw API response" + +func ValidateJSONOutput(output string) error { + if output == "" || output == "json" { + return nil + } + return fmt.Errorf("unsupported --output value %q; use \"json\" or omit --output for human-readable output", output) +} + +func AddJSONOutputFlag(cmd *cobra.Command) { + cmd.Flags().StringP("output", "o", "", JSONOutputFlagDescription) +}