Skip to content

Commit e053cdd

Browse files
experimental/ssh: clarify GPU compute provisioning during ssh connect startup
GPU_8xH100 serverless capacity takes ~10 minutes at P50 and ~30 minutes at P90 to acquire, but `ssh connect` gave up after a hard 10-minute startup timeout with an opaque error: Error: failed to ensure that ssh server is running: failed to submit and start ssh server job: timed out: waiting for task to start (current state: PENDING) Users read this as a service outage rather than compute still being provisioned (see the Zillow report in #remote-development-help). - Raise the startup timeout to 40 minutes when --accelerator is set, keeping 10 minutes otherwise. - Print an upfront notice that GPU provisioning can take 10-30 minutes, and reflect provisioning in the spinner text. - On startup timeout, append guidance to the error: the run ID and run page URL, that compute is likely still provisioning, and that the run was left in place so re-running the command connects once it starts. Co-authored-by: Isaac
1 parent afe30fe commit e053cdd

5 files changed

Lines changed: 97 additions & 16 deletions

File tree

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
### CLI
88
* Show a once-per-day notice after a command when a newer CLI release is available, with a link to the release and the upgrade command for the detected install method. Suppressed for non-interactive/CI runs, JSON output, the Databricks Runtime, and development builds, and can be disabled with `DATABRICKS_CLI_DISABLE_UPDATE_CHECK` ([#5470](https://github.com/databricks/cli/pull/5470)).
9+
* `ssh connect`: Increase the SSH server startup timeout from 10 to 40 minutes for GPU accelerators, show "Waiting for compute to start" while compute spins up (with a notice for GPU accelerators that provisioning can take upwards of 10 minutes), and explain on timeout that the job run was left in place so re-running the command connects once compute is available.
910

1011
### Bundles
1112
* Remove API enum values and types that are still in development from the `databricks-bundles` Python package; these were never accepted by the backend ([#5484](https://github.com/databricks/cli/pull/5484)).

experimental/ssh/cmd/connect.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ Connect to a dedicated cluster:
9090
if connectionName == "" && clusterID == "" && !proxyMode {
9191
connectionName = client.GenerateDefaultConnectionName(wsClient.Config.Host, accelerator)
9292
}
93+
startupTimeout := taskStartupTimeout
94+
if accelerator != "" {
95+
startupTimeout = gpuTaskStartupTimeout
96+
}
9397
opts := client.ClientOptions{
9498
Profile: wsClient.Config.Profile,
9599
ClusterID: clusterID,
@@ -103,7 +107,7 @@ Connect to a dedicated cluster:
103107
HandoverTimeout: handoverTimeout,
104108
ReleasesDir: releasesDir,
105109
ServerTimeout: max(serverTimeout, shutdownDelay),
106-
TaskStartupTimeout: taskStartupTimeout,
110+
TaskStartupTimeout: startupTimeout,
107111
AutoStartCluster: autoStartCluster,
108112
ClientPublicKeyName: clientPublicKeyName,
109113
ClientPrivateKeyName: clientPrivateKeyName,

experimental/ssh/cmd/constants.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,16 @@ const (
99
defaultHandoverTimeout = 30 * time.Minute
1010
defaultEnvironmentVersion = 4
1111

12-
serverTimeout = 24 * time.Hour
13-
taskStartupTimeout = 10 * time.Minute
14-
serverPortRange = 100
15-
serverConfigDir = ".ssh-tunnel"
16-
serverPrivateKeyName = "server-private-key"
17-
serverPublicKeyName = "server-public-key"
18-
clientPrivateKeyName = "client-private-key"
19-
clientPublicKeyName = "client-public-key"
12+
serverTimeout = 24 * time.Hour
13+
taskStartupTimeout = 10 * time.Minute
14+
// Serverless GPU capacity is acquired on demand: launch latency for GPU_8xH100 is
15+
// ~10 minutes at P50 and ~30 minutes at P90, so GPU sessions need a much longer
16+
// startup timeout than the default to avoid giving up on runs that would succeed.
17+
gpuTaskStartupTimeout = 40 * time.Minute
18+
serverPortRange = 100
19+
serverConfigDir = ".ssh-tunnel"
20+
serverPrivateKeyName = "server-private-key"
21+
serverPublicKeyName = "server-public-key"
22+
clientPrivateKeyName = "client-private-key"
23+
clientPublicKeyName = "client-public-key"
2024
)

experimental/ssh/internal/client/client.go

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
578578
cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", waiter.RunId))
579579

580580
// Return the run ID even on error so callers can fetch the run's failure details.
581-
return waiter.RunId, waitForJobToStart(ctx, client, waiter.RunId, opts.TaskStartupTimeout)
581+
return waiter.RunId, waitForJobToStart(ctx, client, waiter.RunId, opts)
582582
}
583583

584584
func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, clusterID string, opts ClientOptions) error {
@@ -642,7 +642,7 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient,
642642
sp := cmdio.NewSpinner(ctx, cmdio.WithElapsedTime())
643643
defer sp.Close()
644644
if autoStart {
645-
sp.Update("Ensuring the cluster is running...")
645+
sp.Update("Waiting for compute to start...")
646646
err := client.Clusters.EnsureClusterIsRunning(ctx, clusterID)
647647
if err != nil {
648648
return fmt.Errorf("failed to ensure that the cluster is running: %w", err)
@@ -662,13 +662,21 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient,
662662

663663
// waitForJobToStart polls the task status until the SSH server task is in RUNNING state or terminates.
664664
// Returns an error if the task fails to start or if polling times out.
665-
func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, taskStartupTimeout time.Duration) error {
665+
func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, opts ClientOptions) error {
666+
waitingMessage := "Waiting for compute to start..."
667+
if opts.Accelerator != "" {
668+
// GPU capacity is acquired on demand and routinely takes 10+ minutes; without
669+
// this notice users assume a long PENDING wait means the service is down.
670+
cmdio.LogString(ctx, fmt.Sprintf("Waiting for %s compute to be provisioned. This can take upwards of 10 minutes depending on capacity...", opts.Accelerator))
671+
waitingMessage = fmt.Sprintf("Waiting for %s compute to be provisioned...", opts.Accelerator)
672+
}
673+
666674
sp := cmdio.NewSpinner(ctx, cmdio.WithElapsedTime())
667675
defer sp.Close()
668-
sp.Update("Starting SSH server...")
676+
sp.Update(waitingMessage)
669677
var prevState jobs.RunLifecycleStateV2State
670678

671-
_, err := retries.Poll(ctx, taskStartupTimeout, func() (*jobs.RunTask, *retries.Err) {
679+
_, err := retries.Poll(ctx, opts.TaskStartupTimeout, func() (*jobs.RunTask, *retries.Err) {
672680
run, err := client.Jobs.GetRun(ctx, jobs.GetRunRequest{
673681
RunId: runID,
674682
})
@@ -697,7 +705,7 @@ func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient,
697705

698706
// Update spinner if state changed
699707
if currentState != prevState {
700-
sp.Update(fmt.Sprintf("Starting SSH server... (task: %s)", currentState))
708+
sp.Update(fmt.Sprintf("%s (task: %s)", waitingMessage, currentState))
701709
prevState = currentState
702710
}
703711

@@ -716,9 +724,33 @@ func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient,
716724
return nil, retries.Continues(fmt.Sprintf("waiting for task to start (current state: %s)", currentState))
717725
})
718726

727+
// A startup timeout almost always means compute is still being provisioned (the task
728+
// never left PENDING), not an outage. The run is intentionally not cancelled: if
729+
// capacity arrives later the server starts, and re-running the command connects to it.
730+
if _, ok := errors.AsType[*retries.ErrTimedOut](err); ok {
731+
return fmt.Errorf("%w\n%s", err, describeStartupTimeout(ctx, client, runID, opts))
732+
}
719733
return err
720734
}
721735

736+
// describeStartupTimeout formats guidance for when the SSH server task did not reach RUNNING
737+
// within the startup timeout. It is best-effort: failures to fetch the run page URL are
738+
// silently ignored so the guidance can always be embedded in the returned error.
739+
func describeStartupTimeout(ctx context.Context, client *databricks.WorkspaceClient, runID int64, opts ClientOptions) string {
740+
var b strings.Builder
741+
fmt.Fprintf(&b, " The SSH server job (run ID: %d) did not start within %s; its compute is most likely still being provisioned.\n", runID, opts.TaskStartupTimeout)
742+
if opts.Accelerator != "" {
743+
fmt.Fprintf(&b, " %s capacity can take longer than this to acquire when demand is high.\n", opts.Accelerator)
744+
}
745+
runLocation := "in the workspace UI (Jobs & Pipelines > Job Runs)"
746+
if run, err := client.Jobs.GetRun(ctx, jobs.GetRunRequest{RunId: runID}); err == nil && run.RunPageUrl != "" {
747+
runLocation = "at " + run.RunPageUrl
748+
}
749+
fmt.Fprintf(&b, " The run was left in place and may still start: track it %s,\n", runLocation)
750+
fmt.Fprintf(&b, " then re-run this command to connect once the run is running, or cancel the run to give up.")
751+
return b.String()
752+
}
753+
722754
// maxRunFailureTraceBytes bounds how much of a failed run's error trace we print to the
723755
// terminal; the full output is always available via the run page URL.
724756
const maxRunFailureTraceBytes = 2000

experimental/ssh/internal/client/client_internal_test.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,48 @@ func TestWaitForJobToStartSurfacesFailure(t *testing.T) {
110110
api.EXPECT().GetRunOutput(mock.Anything, jobs.GetRunOutputRequest{RunId: 99}).Return(
111111
&jobs.RunOutput{}, nil)
112112

113-
err := waitForJobToStart(ctx, m.WorkspaceClient, 1, 30*time.Second)
113+
err := waitForJobToStart(ctx, m.WorkspaceClient, 1, ClientOptions{TaskStartupTimeout: 30 * time.Second})
114114
require.Error(t, err)
115115
assert.Contains(t, err.Error(), "ssh server bootstrap job failed")
116116
assert.Contains(t, err.Error(), "Could not reach driver of cluster 0605-x.")
117117
}
118+
119+
func TestWaitForJobToStartTimeoutExplainsPendingCompute(t *testing.T) {
120+
ctx := cmdio.MockDiscard(t.Context())
121+
m := mocks.NewMockWorkspaceClient(t)
122+
api := m.GetMockJobsAPI()
123+
// The run stays PENDING for the whole (tiny) startup timeout; the same response also
124+
// serves the post-timeout lookup of the run page URL.
125+
api.EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 1}).Return(&jobs.Run{
126+
RunId: 1,
127+
RunPageUrl: "https://example.test/run/1",
128+
Tasks: []jobs.RunTask{{
129+
TaskKey: sshServerTaskKey,
130+
Status: &jobs.RunStatus{State: jobs.RunLifecycleStateV2StatePending},
131+
}},
132+
}, nil)
133+
134+
err := waitForJobToStart(ctx, m.WorkspaceClient, 1, ClientOptions{
135+
TaskStartupTimeout: 10 * time.Millisecond,
136+
Accelerator: "GPU_8xH100",
137+
})
138+
require.Error(t, err)
139+
assert.Contains(t, err.Error(), "current state: PENDING")
140+
assert.Contains(t, err.Error(), "did not start within 10ms")
141+
assert.Contains(t, err.Error(), "still being provisioned")
142+
assert.Contains(t, err.Error(), "GPU_8xH100 capacity can take longer")
143+
assert.Contains(t, err.Error(), "https://example.test/run/1")
144+
assert.Contains(t, err.Error(), "re-run this command")
145+
}
146+
147+
func TestDescribeStartupTimeoutWithoutRunPageURL(t *testing.T) {
148+
ctx := cmdio.MockDiscard(t.Context())
149+
m := mocks.NewMockWorkspaceClient(t)
150+
api := m.GetMockJobsAPI()
151+
api.EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 1}).Return(nil, assert.AnError)
152+
153+
out := describeStartupTimeout(ctx, m.WorkspaceClient, 1, ClientOptions{TaskStartupTimeout: 10 * time.Minute})
154+
assert.Contains(t, out, "run ID: 1")
155+
assert.Contains(t, out, "did not start within 10m0s")
156+
assert.Contains(t, out, "in the workspace UI")
157+
}

0 commit comments

Comments
 (0)