diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index 04564f9c..c7253f26 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -14,6 +14,7 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/delete" "github.com/brevdev/brev-cli/pkg/cmd/envvars" "github.com/brevdev/brev-cli/pkg/cmd/fu" + "github.com/brevdev/brev-cli/pkg/cmd/gpucreate" "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" "github.com/brevdev/brev-cli/pkg/cmd/healthcheck" "github.com/brevdev/brev-cli/pkg/cmd/hello" @@ -272,6 +273,7 @@ func createCmdTree(cmd *cobra.Command, t *terminal.Terminal, loginCmdStore *stor cmd.AddCommand(workspacegroups.NewCmdWorkspaceGroups(t, loginCmdStore)) cmd.AddCommand(scale.NewCmdScale(t, noLoginCmdStore)) cmd.AddCommand(gpusearch.NewCmdGPUSearch(t, noLoginCmdStore)) + cmd.AddCommand(gpucreate.NewCmdGPUCreate(t, loginCmdStore)) cmd.AddCommand(configureenvvars.NewCmdConfigureEnvVars(t, loginCmdStore)) cmd.AddCommand(importideconfig.NewCmdImportIDEConfig(t, noLoginCmdStore)) cmd.AddCommand(shell.NewCmdShell(t, loginCmdStore, noLoginCmdStore)) diff --git a/pkg/cmd/gpucreate/gpucreate.go b/pkg/cmd/gpucreate/gpucreate.go new file mode 100644 index 00000000..3a008fd6 --- /dev/null +++ b/pkg/cmd/gpucreate/gpucreate.go @@ -0,0 +1,824 @@ +// Package gpucreate provides a command to create GPU instances with retry logic +package gpucreate + +import ( + "encoding/json" + "fmt" + "io" + "os" + "strconv" + "strings" + "sync" + "time" + "unicode" + + "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" + "github.com/brevdev/brev-cli/pkg/cmd/util" + "github.com/brevdev/brev-cli/pkg/config" + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/featureflag" + "github.com/brevdev/brev-cli/pkg/store" + "github.com/brevdev/brev-cli/pkg/terminal" + "github.com/spf13/cobra" +) + +var ( + long = `Create GPU instances with automatic retry across multiple instance types. 
+ +This command attempts to create GPU instances, trying different instance types +until the desired number of instances are successfully created. Instance types +can be specified directly, piped from 'brev search', or auto-selected using defaults. + +Default Behavior: +If no instance types are specified (no --type flag and no piped input), the command +automatically searches for GPUs matching these criteria: + - Minimum 20GB total VRAM + - Minimum 500GB disk + - Compute capability 8.0+ (Ampere or newer) + - Boot time under 7 minutes +Results are sorted by price (cheapest first). + +Retry and Fallback Logic: +When multiple instance types are provided (via --type or piped input), the command +tries to create ALL instances using the first type before falling back to the next: + + 1. Try first type for all instances (using --parallel workers if specified) + 2. If first type succeeds for all instances, done + 3. If first type fails for some instances, try second type for remaining instances + 4. Continue until all instances are created or all types are exhausted + +Example with --count 2 and types [A, B]: + - Try A for instance-1 → success + - Try A for instance-2 → success + - Done! (both instances use type A) + +If type A fails for instance-2: + - Try A for instance-1 → success + - Try A for instance-2 → fail + - Try B for instance-2 → success + - Done! (instance-1 uses A, instance-2 uses B) + +Startup Scripts: +You can attach a startup script that runs when the instance boots using the +--startup-script flag. 
The script can be provided as: + - An inline string: --startup-script 'pip install torch' + - A file path (prefix with @): --startup-script @setup.sh + - An absolute file path: --startup-script @/path/to/setup.sh` + + example = ` + # Quick start: create an instance using smart defaults (sorted by price) + brev create my-instance + + # Create with explicit --name flag + brev create --name my-instance + + # Create and immediately open in VS Code + brev create my-instance | brev open + + # Create and SSH into the instance + brev shell $(brev create my-instance) + + # Create and run a command + brev create my-instance | brev shell -c "nvidia-smi" + + # Create with a specific GPU type + brev create my-instance --type g5.xlarge + + # Pipe instance types from brev search (tries first type, falls back if needed) + brev search --min-vram 24 | brev create my-instance + + # Create multiple instances (all use same type, with fallback) + brev create my-cluster --count 3 --type g5.xlarge + # Creates: my-cluster-1, my-cluster-2, my-cluster-3 (all g5.xlarge) + + # Create multiple instances with fallback types + brev search --gpu-name A100 | brev create my-cluster --count 2 + # Tries first A100 type for both instances, falls back to next type if needed + + # Create instances in parallel (faster, but may use more types on partial failures) + brev search --gpu-name A100 | brev create my-cluster --count 3 --parallel 3 + + # Try multiple specific types in order (fallback chain) + brev create my-instance --type g5.xlarge,g5.2xlarge,g4dn.xlarge + + # Attach a startup script from a file + brev create my-instance --type g5.xlarge --startup-script @setup.sh + + # Attach an inline startup script + brev create my-instance --startup-script 'pip install torch' + + # Combine: find cheapest A100, attach setup script + brev search --gpu-name A100 --sort price | brev create ml-box -s @ml-setup.sh +` +) + +// GPUCreateStore defines the interface for GPU create operations +type GPUCreateStore 
interface { + util.GetWorkspaceByNameOrIDErrStore + gpusearch.GPUSearchStore + GetActiveOrganizationOrDefault() (*entity.Organization, error) + GetCurrentUser() (*entity.User, error) + GetWorkspace(workspaceID string) (*entity.Workspace, error) + CreateWorkspace(organizationID string, options *store.CreateWorkspacesOptions) (*entity.Workspace, error) + DeleteWorkspace(workspaceID string) (*entity.Workspace, error) + GetAllInstanceTypesWithWorkspaceGroups(orgID string) (*gpusearch.AllInstanceTypesResponse, error) +} + +// Default filter values for automatic GPU selection +const ( + defaultMinTotalVRAM = 20.0 // GB + defaultMinDisk = 500.0 // GB + defaultMinCapability = 8.0 + defaultMaxBootTime = 7 // minutes +) + +// CreateResult holds the result of a workspace creation attempt +type CreateResult struct { + Workspace *entity.Workspace + InstanceType string + Error error +} + +// NewCmdGPUCreate creates the gpu-create command +func NewCmdGPUCreate(t *terminal.Terminal, gpuCreateStore GPUCreateStore) *cobra.Command { + var name string + var instanceTypes string + var count int + var parallel int + var detached bool + var timeout int + var startupScript string + + cmd := &cobra.Command{ + Annotations: map[string]string{"workspace": ""}, + Use: "create [name]", + Aliases: []string{"provision", "gpu-create", "gpu-retry", "gcreate"}, + DisableFlagsInUseLine: true, + Short: "Create GPU instances with automatic retry", + Long: long, + Example: example, + RunE: func(cmd *cobra.Command, args []string) error { + // Accept name as positional arg or --name flag + if len(args) > 0 && name == "" { + name = args[0] + } + + // Check if output is being piped (for chaining with brev shell/open) + piped := isStdoutPiped() + + // Parse instance types from flag or stdin + types, err := parseInstanceTypes(instanceTypes) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + // If no types provided, use default filters to find suitable GPUs + if len(types) == 0 { + msg := 
fmt.Sprintf("No instance types specified, using defaults: min-total-vram=%.0fGB, min-disk=%.0fGB, min-capability=%.1f, max-boot-time=%dm\n\n", + defaultMinTotalVRAM, defaultMinDisk, defaultMinCapability, defaultMaxBootTime) + if piped { + fmt.Fprint(os.Stderr, msg) + } else { + t.Vprint(msg) + } + + types, err = getDefaultInstanceTypes(gpuCreateStore) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if len(types) == 0 { + return breverrors.NewValidationError("no GPU instances match the default filters. Try 'brev search' to see available options") + } + } + + if name == "" { + return breverrors.NewValidationError("name is required (as argument or --name flag)") + } + + if count < 1 { + return breverrors.NewValidationError("--count must be at least 1") + } + + if parallel < 1 { + parallel = 1 + } + + // Parse startup script (can be a string or @filepath) + scriptContent, err := parseStartupScript(startupScript) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + opts := GPUCreateOptions{ + Name: name, + InstanceTypes: types, + Count: count, + Parallel: parallel, + Detached: detached, + Timeout: time.Duration(timeout) * time.Second, + StartupScript: scriptContent, + } + + err = RunGPUCreate(t, gpuCreateStore, opts) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil + }, + } + + cmd.Flags().StringVarP(&name, "name", "n", "", "Base name for the instances (or pass as first argument)") + cmd.Flags().StringVarP(&instanceTypes, "type", "t", "", "Comma-separated list of instance types to try") + cmd.Flags().IntVarP(&count, "count", "c", 1, "Number of instances to create") + cmd.Flags().IntVarP(¶llel, "parallel", "p", 1, "Number of parallel creation attempts") + cmd.Flags().BoolVarP(&detached, "detached", "d", false, "Don't wait for instances to be ready") + cmd.Flags().IntVar(&timeout, "timeout", 300, "Timeout in seconds for each instance to become ready") + cmd.Flags().StringVarP(&startupScript, "startup-script", "s", "", 
"Startup script to run on instance (string or @filepath)") + + return cmd +} + +// InstanceSpec holds an instance type and its target disk size +type InstanceSpec struct { + Type string + DiskGB float64 // Target disk size in GB, 0 means use default +} + +// GPUCreateOptions holds the options for GPU instance creation +type GPUCreateOptions struct { + Name string + InstanceTypes []InstanceSpec + Count int + Parallel int + Detached bool + Timeout time.Duration + StartupScript string +} + +// parseStartupScript parses the startup script from a string or file path +// If the value starts with @, it's treated as a file path +func parseStartupScript(value string) (string, error) { + if value == "" { + return "", nil + } + + // Check if it's a file path (prefixed with @) + if strings.HasPrefix(value, "@") { + filePath := strings.TrimPrefix(value, "@") + content, err := os.ReadFile(filePath) + if err != nil { + return "", breverrors.WrapAndTrace(err) + } + return string(content), nil + } + + // Otherwise, treat it as the script content directly + return value, nil +} + +// getDefaultInstanceTypes fetches GPU instance types using default filters +func getDefaultInstanceTypes(store GPUCreateStore) ([]InstanceSpec, error) { + response, err := store.GetInstanceTypes() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + if response == nil || len(response.Items) == 0 { + return nil, nil + } + + // Use gpusearch package to process, filter, and sort instances + instances := gpusearch.ProcessInstances(response.Items) + filtered := gpusearch.FilterInstances(instances, "", "", 0, defaultMinTotalVRAM, defaultMinCapability, defaultMinDisk, defaultMaxBootTime, false, false, false) + gpusearch.SortInstances(filtered, "price", false) + + // Convert to InstanceSpec with disk info + var specs []InstanceSpec + for _, inst := range filtered { + // For defaults, use the minimum disk size that meets the filter + diskGB := inst.DiskMin + if inst.DiskMin != inst.DiskMax && 
defaultMinDisk > inst.DiskMin && defaultMinDisk <= inst.DiskMax { + diskGB = defaultMinDisk + } + specs = append(specs, InstanceSpec{Type: inst.Type, DiskGB: diskGB}) + } + + return specs, nil +} + +// parseInstanceTypes parses instance types from flag value or stdin +// Returns InstanceSpec with type and optional disk size (from JSON input) +func parseInstanceTypes(flagValue string) ([]InstanceSpec, error) { + var specs []InstanceSpec + + // First check if there's a flag value + if flagValue != "" { + parts := strings.Split(flagValue, ",") + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + specs = append(specs, InstanceSpec{Type: p}) + } + } + } + + // Check if there's piped input from stdin + stat, _ := os.Stdin.Stat() + if (stat.Mode() & os.ModeCharDevice) == 0 { + // Data is being piped to stdin - read all input first + input, err := io.ReadAll(os.Stdin) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + inputStr := strings.TrimSpace(string(input)) + if inputStr == "" { + return specs, nil + } + + // Check if input is JSON (starts with '[') + if strings.HasPrefix(inputStr, "[") { + jsonSpecs, err := parseJSONInput(inputStr) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + specs = append(specs, jsonSpecs...) + } else { + // Parse as table format + tableSpecs := parseTableInput(inputStr) + specs = append(specs, tableSpecs...) 
+ } + } + + return specs, nil +} + +// parseJSONInput parses JSON array input from gpu-search --json +func parseJSONInput(input string) ([]InstanceSpec, error) { + var instances []gpusearch.GPUInstanceInfo + if err := json.Unmarshal([]byte(input), &instances); err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + var specs []InstanceSpec + for _, inst := range instances { + spec := InstanceSpec{ + Type: inst.Type, + DiskGB: inst.TargetDisk, + } + specs = append(specs, spec) + } + return specs, nil +} + +// parseTableInput parses table format input from gpu-search +// Table format: TYPE TARGET_DISK PROVIDER GPU COUNT ... +func parseTableInput(input string) []InstanceSpec { + var specs []InstanceSpec + lines := strings.Split(input, "\n") + + for i, line := range lines { + // Skip header line (first line typically contains column names) + if i == 0 && (strings.Contains(line, "TYPE") || strings.Contains(line, "GPU")) { + continue + } + + // Skip empty lines + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // Skip summary lines (e.g., "Found X GPU instance types") + if strings.HasPrefix(line, "Found ") { + continue + } + + // Extract TYPE (column 0) and TARGET_DISK (column 1) from the table output + // The format is: TYPE TARGET_DISK PROVIDER GPU COUNT ... + fields := strings.Fields(line) + if len(fields) > 0 { + instanceType := fields[0] + // Validate it looks like an instance type (contains letters and possibly numbers/dots) + if isValidInstanceType(instanceType) { + spec := InstanceSpec{Type: instanceType} + // Parse TARGET_DISK if present (column 1) + if len(fields) > 1 { + if diskGB, err := strconv.ParseFloat(fields[1], 64); err == nil && diskGB > 0 { + spec.DiskGB = diskGB + } + } + specs = append(specs, spec) + } + } + } + + return specs +} + +// isValidInstanceType checks if a string looks like a valid instance type. 
+// Instance types typically have formats like: g5.xlarge, p4d.24xlarge, n1-highmem-4:nvidia-tesla-t4:1 +func isValidInstanceType(s string) bool { + if len(s) < 2 { + return false + } + var hasLetter, hasDigit bool + for _, c := range s { + if unicode.IsLetter(c) { + hasLetter = true + } else if unicode.IsDigit(c) { + hasDigit = true + } + if hasLetter && hasDigit { + return true + } + } + return hasLetter && hasDigit +} + +// isStdoutPiped returns true if stdout is being piped (not a terminal) +func isStdoutPiped() bool { + stat, _ := os.Stdout.Stat() + return (stat.Mode() & os.ModeCharDevice) == 0 +} + +// formatInstanceSpecs formats a slice of InstanceSpec for display +func formatInstanceSpecs(specs []InstanceSpec) string { + var parts []string + for _, spec := range specs { + if spec.DiskGB > 0 { + parts = append(parts, fmt.Sprintf("%s (%.0fGB disk)", spec.Type, spec.DiskGB)) + } else { + parts = append(parts, spec.Type) + } + } + return strings.Join(parts, ", ") +} + +// createContext holds shared state for instance creation +type createContext struct { + t *terminal.Terminal + store GPUCreateStore + opts GPUCreateOptions + org *entity.Organization + user *entity.User + allInstanceTypes *gpusearch.AllInstanceTypesResponse + piped bool + logf func(format string, a ...interface{}) +} + +// newCreateContext initializes the context for instance creation +func newCreateContext(t *terminal.Terminal, store GPUCreateStore, opts GPUCreateOptions) (*createContext, error) { + piped := isStdoutPiped() + + ctx := &createContext{ + t: t, + store: store, + opts: opts, + piped: piped, + } + + // Set up logging function + ctx.logf = func(format string, a ...interface{}) { + if piped { + fmt.Fprintf(os.Stderr, format, a...) + } else { + t.Vprintf(format, a...) 
+ } + } + + // Get user + user, err := store.GetCurrentUser() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + ctx.user = user + + // Get organization + org, err := store.GetActiveOrganizationOrDefault() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if org == nil { + return nil, breverrors.NewValidationError("no organization found") + } + ctx.org = org + + // Fetch instance types with workspace groups + allInstanceTypes, err := store.GetAllInstanceTypesWithWorkspaceGroups(org.ID) + if err != nil { + ctx.logf("Warning: could not fetch instance types with workspace groups: %s\n", err.Error()) + ctx.logf("Falling back to default workspace group\n") + } + ctx.allInstanceTypes = allInstanceTypes + + return ctx, nil +} + +// typeCreateResult holds the result of creating instances with a single type +type typeCreateResult struct { + successes []*entity.Workspace + hadFailure bool + fatalError error +} + +// createInstancesWithType attempts to create instances using a specific type +func (c *createContext) createInstancesWithType(spec InstanceSpec, startIdx, count int) typeCreateResult { + result := typeCreateResult{} + + var mu sync.Mutex + var wg sync.WaitGroup + + // Determine worker count + workerCount := c.opts.Parallel + if workerCount > count { + workerCount = count + } + + // Create index channel + indicesToCreate := make(chan int, count) + for i := startIdx; i < startIdx+count; i++ { + indicesToCreate <- i + } + close(indicesToCreate) + + for i := 0; i < workerCount; i++ { + wg.Add(1) + go func(workerID int) { + defer wg.Done() + c.runWorker(workerID, spec, indicesToCreate, &result, &mu) + }(i) + } + + wg.Wait() + return result +} + +// runWorker processes instance creation requests from the channel +func (c *createContext) runWorker(workerID int, spec InstanceSpec, indices <-chan int, result *typeCreateResult, mu *sync.Mutex) { + for idx := range indices { + // Check if we've already created enough + mu.Lock() + if 
len(result.successes) >= c.opts.Count { + mu.Unlock() + return + } + mu.Unlock() + + // Determine instance name + instanceName := c.opts.Name + if c.opts.Count > 1 { + instanceName = fmt.Sprintf("%s-%d", c.opts.Name, idx+1) + } + + c.logf("[Worker %d] Trying %s for instance '%s'...\n", workerID+1, spec.Type, instanceName) + + // Attempt to create the workspace + workspace, err := createWorkspaceWithType(c.store, c.org.ID, instanceName, spec.Type, spec.DiskGB, c.user, c.allInstanceTypes, c.opts.StartupScript) + + mu.Lock() + if err != nil { + c.handleCreateError(workerID, spec.Type, instanceName, err, result) + } else { + c.handleCreateSuccess(workerID, spec.Type, instanceName, workspace, result) + } + mu.Unlock() + } +} + +// handleCreateError processes a failed instance creation (must be called with lock held) +func (c *createContext) handleCreateError(workerID int, instanceType, instanceName string, err error, result *typeCreateResult) { + errStr := err.Error() + if c.piped { + c.logf("[Worker %d] %s Failed: %s\n", workerID+1, instanceType, errStr) + } else { + c.logf("[Worker %d] %s Failed: %s\n", workerID+1, c.t.Yellow(instanceType), errStr) + } + + result.hadFailure = true + if strings.Contains(errStr, "duplicate workspace") { + result.fatalError = fmt.Errorf("workspace '%s' already exists. Use a different name or delete the existing workspace", instanceName) + } +} + +// handleCreateSuccess processes a successful instance creation (must be called with lock held) +func (c *createContext) handleCreateSuccess(workerID int, instanceType, instanceName string, workspace *entity.Workspace, result *typeCreateResult) { + if c.piped { + c.logf("[Worker %d] %s Success! Created instance '%s'\n", workerID+1, instanceType, instanceName) + } else { + c.logf("[Worker %d] %s Success! 
Created instance '%s'\n", workerID+1, c.t.Green(instanceType), instanceName) + } + result.successes = append(result.successes, workspace) +} + +// cleanupExtraInstances deletes instances beyond the requested count +func (c *createContext) cleanupExtraInstances(workspaces []*entity.Workspace) []*entity.Workspace { + if len(workspaces) <= c.opts.Count { + return workspaces + } + + extras := workspaces[c.opts.Count:] + c.logf("\nCleaning up %d extra instance(s)...\n", len(extras)) + + for _, ws := range extras { + c.logf(" Deleting %s...", ws.Name) + _, err := c.store.DeleteWorkspace(ws.ID) + if err != nil { + c.logf(" Failed\n") + } else { + c.logf(" Done\n") + } + } + + return workspaces[:c.opts.Count] +} + +// waitForInstances waits for all instances to be ready +func (c *createContext) waitForInstances(workspaces []*entity.Workspace) { + if c.opts.Detached { + return + } + + c.logf("\nWaiting for instance(s) to be ready...\n") + c.logf("You can safely ctrl+c to exit\n") + + for _, ws := range workspaces { + err := pollUntilReady(c.t, ws.ID, c.store, c.opts.Timeout, c.piped, c.logf) + if err != nil { + c.logf(" %s: Timeout waiting for ready state\n", ws.Name) + } + } +} + +// printSummary outputs the final creation summary +func (c *createContext) printSummary(workspaces []*entity.Workspace) { + if c.piped { + for _, ws := range workspaces { + fmt.Println(ws.Name) + } + return + } + + fmt.Print("\n") + c.t.Vprint(c.t.Green(fmt.Sprintf("Successfully created %d instance(s)!\n\n", len(workspaces)))) + + for _, ws := range workspaces { + c.t.Vprintf("Instance: %s\n", c.t.Green(ws.Name)) + c.t.Vprintf(" ID: %s\n", ws.ID) + c.t.Vprintf(" Type: %s\n", ws.InstanceType) + displayConnectBreadCrumb(c.t, ws) + fmt.Print("\n") + } +} + +// RunGPUCreate executes the GPU create with retry logic +func RunGPUCreate(t *terminal.Terminal, gpuCreateStore GPUCreateStore, opts GPUCreateOptions) error { + ctx, err := newCreateContext(t, gpuCreateStore, opts) + if err != nil { + return 
err + } + + ctx.logf("Attempting to create %d instance(s) with %d parallel attempts\n", opts.Count, opts.Parallel) + ctx.logf("Instance types to try: %s\n\n", formatInstanceSpecs(opts.InstanceTypes)) + + var successfulWorkspaces []*entity.Workspace + + // Try each instance type in order + for _, spec := range opts.InstanceTypes { + if len(successfulWorkspaces) >= opts.Count { + break + } + + remaining := opts.Count - len(successfulWorkspaces) + ctx.logf("Trying %s for %d instance(s)...\n", spec.Type, remaining) + + result := ctx.createInstancesWithType(spec, len(successfulWorkspaces), remaining) + successfulWorkspaces = append(successfulWorkspaces, result.successes...) + + if result.fatalError != nil { + ctx.logf("\nError: %s\n", result.fatalError.Error()) + break + } + + if !result.hadFailure && len(successfulWorkspaces) >= opts.Count { + break + } + + if len(successfulWorkspaces) < opts.Count && result.hadFailure { + ctx.logf("\nType %s had failures, trying next type...\n\n", spec.Type) + } + } + + // Check if we created enough instances + if len(successfulWorkspaces) < opts.Count { + ctx.logf("\nWarning: Only created %d/%d instances\n", len(successfulWorkspaces), opts.Count) + if len(successfulWorkspaces) > 0 { + ctx.logf("Successfully created instances:\n") + for _, ws := range successfulWorkspaces { + ctx.logf(" - %s (ID: %s)\n", ws.Name, ws.ID) + } + } + return breverrors.NewValidationError(fmt.Sprintf("could only create %d/%d instances", len(successfulWorkspaces), opts.Count)) + } + + successfulWorkspaces = ctx.cleanupExtraInstances(successfulWorkspaces) + ctx.waitForInstances(successfulWorkspaces) + ctx.printSummary(successfulWorkspaces) + + return nil +} + +// createWorkspaceWithType creates a workspace with the specified instance type +func createWorkspaceWithType(gpuCreateStore GPUCreateStore, orgID, name, instanceType string, diskGB float64, user *entity.User, allInstanceTypes *gpusearch.AllInstanceTypesResponse, startupScript string) 
(*entity.Workspace, error) {
	cwOptions := store.NewCreateWorkspacesOptions(config.GlobalConfig.GetDefaultClusterID(), name)
	cwOptions.WithInstanceType(instanceType)
	cwOptions = resolveWorkspaceUserOptions(cwOptions, user)

	// A positive disk size is passed on as a whole number with a "Gi" suffix.
	if diskGB > 0 {
		cwOptions.DiskStorage = fmt.Sprintf("%.0fGi", diskGB)
	}

	// Use the workspace group registered for this instance type, if any.
	if allInstanceTypes != nil {
		if groupID := allInstanceTypes.GetWorkspaceGroupID(instanceType); groupID != "" {
			cwOptions.WorkspaceGroupID = groupID
		}
	}

	// A startup script is delivered as the VM build's lifecycle script.
	if startupScript != "" {
		lifecycle := &store.LifeCycleScriptAttr{
			Script: startupScript,
		}
		cwOptions.VMBuild = &store.VMBuild{
			ForceJupyterInstall: true,
			LifeCycleScriptAttr: lifecycle,
		}
	}

	workspace, err := gpuCreateStore.CreateWorkspace(orgID, cwOptions)
	if err != nil {
		return nil, breverrors.WrapAndTrace(err)
	}

	return workspace, nil
}

// resolveWorkspaceUserOptions fills in the workspace template and class IDs
// that are still empty, choosing the dev values for admin users and the user
// values otherwise. IDs that are already set are left untouched.
func resolveWorkspaceUserOptions(options *store.CreateWorkspacesOptions, user *entity.User) *store.CreateWorkspacesOptions {
	if options.WorkspaceTemplateID == "" {
		options.WorkspaceTemplateID = store.UserWorkspaceTemplateID
		if featureflag.IsAdmin(user.GlobalUserType) {
			options.WorkspaceTemplateID = store.DevWorkspaceTemplateID
		}
	}
	if options.WorkspaceClassID == "" {
		options.WorkspaceClassID = store.UserWorkspaceClassID
		if featureflag.IsAdmin(user.GlobalUserType) {
			options.WorkspaceClassID = store.DevWorkspaceClassID
		}
	}
	return options
}

// pollUntilReady waits for a workspace to reach the running state
func pollUntilReady(t *terminal.Terminal, wsID string, gpuCreateStore GPUCreateStore, timeout time.Duration, piped bool, logf func(string, ...interface{})) error {
	deadline :=
time.Now().Add(timeout) + + for time.Now().Before(deadline) { + ws, err := gpuCreateStore.GetWorkspace(wsID) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if ws.Status == entity.Running { + if piped { + logf(" %s: Ready\n", ws.Name) + } else { + logf(" %s: %s\n", ws.Name, t.Green("Ready")) + } + return nil + } + + if ws.Status == entity.Failure { + return breverrors.NewValidationError(fmt.Sprintf("instance %s failed", ws.Name)) + } + + time.Sleep(5 * time.Second) + } + + return breverrors.NewValidationError("timeout waiting for instance to be ready") +} + +// displayConnectBreadCrumb shows connection instructions +func displayConnectBreadCrumb(t *terminal.Terminal, workspace *entity.Workspace) { + t.Vprintf(" Connect:\n") + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("brev open %s", workspace.Name))) + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("brev shell %s", workspace.Name))) +} diff --git a/pkg/cmd/gpucreate/gpucreate_test.go b/pkg/cmd/gpucreate/gpucreate_test.go new file mode 100644 index 00000000..1609616b --- /dev/null +++ b/pkg/cmd/gpucreate/gpucreate_test.go @@ -0,0 +1,418 @@ +package gpucreate + +import ( + "strings" + "testing" + + "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/store" + "github.com/stretchr/testify/assert" +) + +// MockGPUCreateStore is a mock implementation of GPUCreateStore for testing +type MockGPUCreateStore struct { + User *entity.User + Org *entity.Organization + Workspaces map[string]*entity.Workspace + CreateError error + CreateErrorTypes map[string]error // Errors for specific instance types + DeleteError error + CreatedWorkspaces []*entity.Workspace + DeletedWorkspaceIDs []string +} + +func NewMockGPUCreateStore() *MockGPUCreateStore { + return &MockGPUCreateStore{ + User: &entity.User{ + ID: "user-123", + GlobalUserType: "Standard", + }, + Org: &entity.Organization{ + ID: "org-123", + Name: "test-org", + }, + Workspaces: 
make(map[string]*entity.Workspace), + CreateErrorTypes: make(map[string]error), + CreatedWorkspaces: []*entity.Workspace{}, + DeletedWorkspaceIDs: []string{}, + } +} + +func (m *MockGPUCreateStore) GetCurrentUser() (*entity.User, error) { + return m.User, nil +} + +func (m *MockGPUCreateStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + return m.Org, nil +} + +func (m *MockGPUCreateStore) GetWorkspace(workspaceID string) (*entity.Workspace, error) { + if ws, ok := m.Workspaces[workspaceID]; ok { + return ws, nil + } + return &entity.Workspace{ + ID: workspaceID, + Status: entity.Running, + }, nil +} + +func (m *MockGPUCreateStore) CreateWorkspace(organizationID string, options *store.CreateWorkspacesOptions) (*entity.Workspace, error) { + // Check for type-specific errors first + if err, ok := m.CreateErrorTypes[options.InstanceType]; ok { + return nil, err + } + + if m.CreateError != nil { + return nil, m.CreateError + } + + ws := &entity.Workspace{ + ID: "ws-" + options.Name, + Name: options.Name, + InstanceType: options.InstanceType, + Status: entity.Running, + } + m.Workspaces[ws.ID] = ws + m.CreatedWorkspaces = append(m.CreatedWorkspaces, ws) + return ws, nil +} + +func (m *MockGPUCreateStore) DeleteWorkspace(workspaceID string) (*entity.Workspace, error) { + if m.DeleteError != nil { + return nil, m.DeleteError + } + + m.DeletedWorkspaceIDs = append(m.DeletedWorkspaceIDs, workspaceID) + ws := m.Workspaces[workspaceID] + delete(m.Workspaces, workspaceID) + return ws, nil +} + +func (m *MockGPUCreateStore) GetWorkspaceByNameOrID(orgID string, nameOrID string) ([]entity.Workspace, error) { + return []entity.Workspace{}, nil +} + +func (m *MockGPUCreateStore) GetAllInstanceTypesWithWorkspaceGroups(orgID string) (*gpusearch.AllInstanceTypesResponse, error) { + return nil, nil +} + +func (m *MockGPUCreateStore) GetInstanceTypes() (*gpusearch.InstanceTypesResponse, error) { + // Return a default set of instance types for testing + return 
&gpusearch.InstanceTypesResponse{ + Items: []gpusearch.InstanceType{ + { + Type: "g5.xlarge", + SupportedGPUs: []gpusearch.GPU{ + {Count: 1, Name: "A10G", Manufacturer: "NVIDIA", Memory: "24GiB"}, + }, + SupportedStorage: []gpusearch.Storage{ + {Size: "500GiB"}, + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: gpusearch.BasePrice{Currency: "USD", Amount: "1.006"}, + EstimatedDeployTime: "5m0s", + }, + }, + }, nil +} + +func TestIsValidInstanceType(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"Valid AWS instance type", "g5.xlarge", true}, + {"Valid AWS large instance", "p4d.24xlarge", true}, + {"Valid GCP instance type", "n1-highmem-4:nvidia-tesla-t4:1", true}, + {"Single letter", "a", false}, + {"No numbers", "xlarge", false}, + {"No letters", "12345", false}, + {"Empty string", "", false}, + {"Single character", "1", false}, + {"Valid with colon", "g5:xlarge", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isValidInstanceType(tt.input) + assert.Equal(t, tt.expected, result, "Validation failed for %s", tt.input) + }) + } +} + +func TestParseInstanceTypesFromFlag(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + {"Single type", "g5.xlarge", []string{"g5.xlarge"}}, + {"Multiple types comma separated", "g5.xlarge,g5.2xlarge,p3.2xlarge", []string{"g5.xlarge", "g5.2xlarge", "p3.2xlarge"}}, + {"With spaces", "g5.xlarge, g5.2xlarge, p3.2xlarge", []string{"g5.xlarge", "g5.2xlarge", "p3.2xlarge"}}, + {"Empty string", "", []string{}}, + {"Only spaces", " ", []string{}}, + {"Trailing comma", "g5.xlarge,", []string{"g5.xlarge"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseInstanceTypes(tt.input) + assert.NoError(t, err) + + // Handle nil vs empty slice + if len(tt.expected) == 0 { + assert.Empty(t, result) + } else { + // Compare just the Type field of each InstanceSpec + var resultTypes 
[]string + for _, spec := range result { + resultTypes = append(resultTypes, spec.Type) + } + assert.Equal(t, tt.expected, resultTypes) + } + }) + } +} + +func TestGPUCreateOptions(t *testing.T) { + opts := GPUCreateOptions{ + Name: "my-instance", + InstanceTypes: []InstanceSpec{ + {Type: "g5.xlarge", DiskGB: 500}, + {Type: "g5.2xlarge"}, + }, + Count: 2, + Parallel: 3, + Detached: true, + } + + assert.Equal(t, "my-instance", opts.Name) + assert.Len(t, opts.InstanceTypes, 2) + assert.Equal(t, "g5.xlarge", opts.InstanceTypes[0].Type) + assert.Equal(t, 500.0, opts.InstanceTypes[0].DiskGB) + assert.Equal(t, "g5.2xlarge", opts.InstanceTypes[1].Type) + assert.Equal(t, 0.0, opts.InstanceTypes[1].DiskGB) + assert.Equal(t, 2, opts.Count) + assert.Equal(t, 3, opts.Parallel) + assert.True(t, opts.Detached) +} + +func TestResolveWorkspaceUserOptionsStandard(t *testing.T) { + user := &entity.User{ + ID: "user-123", + GlobalUserType: "Standard", + } + + options := &store.CreateWorkspacesOptions{} + result := resolveWorkspaceUserOptions(options, user) + + assert.Equal(t, store.UserWorkspaceTemplateID, result.WorkspaceTemplateID) + assert.Equal(t, store.UserWorkspaceClassID, result.WorkspaceClassID) +} + +func TestResolveWorkspaceUserOptionsAdmin(t *testing.T) { + user := &entity.User{ + ID: "user-123", + GlobalUserType: "Admin", + } + + options := &store.CreateWorkspacesOptions{} + result := resolveWorkspaceUserOptions(options, user) + + assert.Equal(t, store.DevWorkspaceTemplateID, result.WorkspaceTemplateID) + assert.Equal(t, store.DevWorkspaceClassID, result.WorkspaceClassID) +} + +func TestResolveWorkspaceUserOptionsPreserveExisting(t *testing.T) { + user := &entity.User{ + ID: "user-123", + GlobalUserType: "Standard", + } + + options := &store.CreateWorkspacesOptions{ + WorkspaceTemplateID: "custom-template", + WorkspaceClassID: "custom-class", + } + result := resolveWorkspaceUserOptions(options, user) + + // Should preserve existing values + assert.Equal(t, 
"custom-template", result.WorkspaceTemplateID) + assert.Equal(t, "custom-class", result.WorkspaceClassID) +} + +func TestMockGPUCreateStoreBasics(t *testing.T) { + mock := NewMockGPUCreateStore() + + user, err := mock.GetCurrentUser() + assert.NoError(t, err) + assert.Equal(t, "user-123", user.ID) + + org, err := mock.GetActiveOrganizationOrDefault() + assert.NoError(t, err) + assert.Equal(t, "org-123", org.ID) +} + +func TestMockGPUCreateStoreCreateWorkspace(t *testing.T) { + mock := NewMockGPUCreateStore() + + options := store.NewCreateWorkspacesOptions("cluster-1", "test-instance") + options.WithInstanceType("g5.xlarge") + + ws, err := mock.CreateWorkspace("org-123", options) + assert.NoError(t, err) + assert.Equal(t, "test-instance", ws.Name) + assert.Equal(t, "g5.xlarge", ws.InstanceType) + assert.Len(t, mock.CreatedWorkspaces, 1) +} + +func TestMockGPUCreateStoreDeleteWorkspace(t *testing.T) { + mock := NewMockGPUCreateStore() + + // First create a workspace + options := store.NewCreateWorkspacesOptions("cluster-1", "test-instance") + ws, _ := mock.CreateWorkspace("org-123", options) + + // Then delete it + _, err := mock.DeleteWorkspace(ws.ID) + assert.NoError(t, err) + assert.Contains(t, mock.DeletedWorkspaceIDs, ws.ID) +} + +func TestMockGPUCreateStoreTypeSpecificError(t *testing.T) { + mock := NewMockGPUCreateStore() + mock.CreateErrorTypes["g5.xlarge"] = assert.AnError + + options := store.NewCreateWorkspacesOptions("cluster-1", "test-instance") + options.WithInstanceType("g5.xlarge") + + _, err := mock.CreateWorkspace("org-123", options) + assert.Error(t, err) + + // Different type should work + options2 := store.NewCreateWorkspacesOptions("cluster-1", "test-instance-2") + options2.WithInstanceType("g5.2xlarge") + + ws, err := mock.CreateWorkspace("org-123", options2) + assert.NoError(t, err) + assert.NotNil(t, ws) +} + +func TestGetDefaultInstanceTypes(t *testing.T) { + mock := NewMockGPUCreateStore() + + // Get default instance types - the mock 
returns a g5.xlarge which has: + // - 24GB VRAM (>= 20GB total VRAM requirement) + // - 500GB disk (>= 500GB requirement) + // - A10G GPU = 8.6 capability (>= 8.0 requirement) + // - 5m boot time (< 7m requirement) + specs, err := getDefaultInstanceTypes(mock) + assert.NoError(t, err) + assert.Len(t, specs, 1) + assert.Equal(t, "g5.xlarge", specs[0].Type) + assert.Equal(t, 500.0, specs[0].DiskGB) // Should use the instance's disk size +} + +func TestGetDefaultInstanceTypesFiltersOut(t *testing.T) { + // The mock returns a g5.xlarge which meets all requirements + mock := NewMockGPUCreateStore() + + specs, err := getDefaultInstanceTypes(mock) + assert.NoError(t, err) + // Should return the A10G instance which meets all requirements + assert.Len(t, specs, 1) + assert.Equal(t, "g5.xlarge", specs[0].Type) +} + +func TestParseInstanceTypesFromTableOutput(t *testing.T) { + // Simulated table output from brev gpus command + // Note: This tests the parsing logic, not actual stdin reading + tableLines := []string{ + "TYPE GPU COUNT VRAM/GPU TOTAL VRAM CAPABILITY VCPUs $/HR", + "g5.xlarge A10G 1 24 GB 24 GB 8.6 4 $1.01", + "g5.2xlarge A10G 1 24 GB 24 GB 8.6 8 $1.21", + "p4d.24xlarge A100 8 40 GB 320 GB 8.0 96 $32.77", + "", + "Found 3 GPU instance types", + } + + // Test parsing each line (simulating the scanner behavior) + var types []string + lineNum := 0 + for _, line := range tableLines { + lineNum++ + + // Skip header line + if lineNum == 1 && (strings.Contains(line, "TYPE") || strings.Contains(line, "GPU")) { + continue + } + + // Skip empty lines and summary + if line == "" || strings.HasPrefix(line, "Found ") { + continue + } + + // Extract first column + fields := strings.Fields(line) + if len(fields) > 0 && isValidInstanceType(fields[0]) { + types = append(types, fields[0]) + } + } + + assert.Len(t, types, 3) + assert.Contains(t, types, "g5.xlarge") + assert.Contains(t, types, "g5.2xlarge") + assert.Contains(t, types, "p4d.24xlarge") +} + +func TestParseJSONInput(t 
*testing.T) { + // Simulated JSON output from gpu-search --json + jsonInput := `[ + { + "type": "g5.xlarge", + "provider": "aws", + "gpu_name": "A10G", + "target_disk_gb": 1000 + }, + { + "type": "p4d.24xlarge", + "provider": "aws", + "gpu_name": "A100", + "target_disk_gb": 500 + }, + { + "type": "g6.xlarge", + "provider": "aws", + "gpu_name": "L4" + } + ]` + + specs, err := parseJSONInput(jsonInput) + assert.NoError(t, err) + assert.Len(t, specs, 3) + + // Check first instance with disk + assert.Equal(t, "g5.xlarge", specs[0].Type) + assert.Equal(t, 1000.0, specs[0].DiskGB) + + // Check second instance with different disk + assert.Equal(t, "p4d.24xlarge", specs[1].Type) + assert.Equal(t, 500.0, specs[1].DiskGB) + + // Check third instance without disk (should be 0) + assert.Equal(t, "g6.xlarge", specs[2].Type) + assert.Equal(t, 0.0, specs[2].DiskGB) +} + +func TestFormatInstanceSpecs(t *testing.T) { + specs := []InstanceSpec{ + {Type: "g5.xlarge", DiskGB: 1000}, + {Type: "p4d.24xlarge", DiskGB: 0}, + {Type: "g6.xlarge", DiskGB: 500}, + } + + result := formatInstanceSpecs(specs) + assert.Equal(t, "g5.xlarge (1000GB disk), p4d.24xlarge, g6.xlarge (500GB disk)", result) +} diff --git a/pkg/store/workspace.go b/pkg/store/workspace.go index 5190d313..ebb5b4a9 100644 --- a/pkg/store/workspace.go +++ b/pkg/store/workspace.go @@ -34,6 +34,17 @@ type ModifyWorkspaceRequest struct { InstanceType string `json:"instanceType,omitempty"` } +// LifeCycleScriptAttr holds the lifecycle script configuration +type LifeCycleScriptAttr struct { + Script string `json:"script,omitempty"` +} + +// VMBuild holds VM-specific build configuration +type VMBuild struct { + ForceJupyterInstall bool `json:"forceJupyterInstall,omitempty"` + LifeCycleScriptAttr *LifeCycleScriptAttr `json:"lifeCycleScriptAttr,omitempty"` +} + type CreateWorkspacesOptions struct { Name string `json:"name"` WorkspaceGroupID string `json:"workspaceGroupId"` @@ -57,6 +68,7 @@ type CreateWorkspacesOptions struct { 
DiskStorage string `json:"diskStorage"` BaseImage string `json:"baseImage"` VMOnlyMode bool `json:"vmOnlyMode"` + VMBuild *VMBuild `json:"vmBuild,omitempty"` PortMappings map[string]string `json:"portMappings"` Files interface{} `json:"files"` Labels interface{} `json:"labels"` @@ -88,6 +100,7 @@ var ( var DefaultApplicationList = []entity.Application{DefaultApplication} func NewCreateWorkspacesOptions(clusterID, name string) *CreateWorkspacesOptions { + isStoppable := false return &CreateWorkspacesOptions{ BaseImage: "", Description: "", @@ -95,12 +108,12 @@ func NewCreateWorkspacesOptions(clusterID, name string) *CreateWorkspacesOptions ExecsV1: &entity.ExecsV1{}, Files: nil, InstanceType: "", - IsStoppable: nil, + IsStoppable: &isStoppable, Labels: nil, LaunchJupyterOnStart: false, Name: name, - PortMappings: nil, - ReposV1: nil, + PortMappings: map[string]string{}, + ReposV1: &entity.ReposV1{}, VMOnlyMode: true, WorkspaceGroupID: "GCP", WorkspaceTemplateID: DefaultWorkspaceTemplateID,