From 13f01893d355ef16fe12170e4a40725b06c91c29 Mon Sep 17 00:00:00 2001 From: Gourab Singha Date: Tue, 23 Jun 2026 00:47:53 +0530 Subject: [PATCH] fix: parse WaitTimeout and use it as poll timeout in ExecuteAndWait --- service/sql/ext_utilities.go | 10 +++- service/sql/ext_utilities_test.go | 86 +++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 service/sql/ext_utilities_test.go diff --git a/service/sql/ext_utilities.go b/service/sql/ext_utilities.go index aedc83cf0..779358d68 100644 --- a/service/sql/ext_utilities.go +++ b/service/sql/ext_utilities.go @@ -29,8 +29,14 @@ func (a *StatementExecutionAPI) ExecuteAndWait(ctx context.Context, request Exec } return nil, fmt.Errorf("%s", msg) default: - // TODO: parse request.WaitTimeout and use it here - return retries.Poll[StatementResponse](ctx, 20*time.Minute, + timeout := 20 * time.Minute + if request.WaitTimeout != "" { + wTimeout, err := time.ParseDuration(request.WaitTimeout) + if err == nil && wTimeout > 0 { + timeout = wTimeout + } + } + return retries.Poll[StatementResponse](ctx, timeout, func() (*StatementResponse, *retries.Err) { res, err := a.GetStatementByStatementId(ctx, immediateResponse.StatementId) if err != nil { diff --git a/service/sql/ext_utilities_test.go b/service/sql/ext_utilities_test.go new file mode 100644 index 000000000..72fb87f3a --- /dev/null +++ b/service/sql/ext_utilities_test.go @@ -0,0 +1,86 @@ +package sql + +import ( + "context" + "testing" + "time" + + "github.com/databricks/databricks-sdk-go/qa" +) + +func TestExecuteAndWait_DefaultTimeout(t *testing.T) { + client, server := qa.HTTPFixtures{ + { + Method: "POST", + Resource: "/api/2.0/sql/statements", + Response: StatementResponse{ + StatementId: "123", + Status: &StatementStatus{ + State: StatementStateSucceeded, + }, + }, + }, + }.Client(t) + defer server.Close() + + ctx := context.Background() + api := NewStatementExecution(client) + resp, err := api.ExecuteAndWait(ctx, ExecuteStatementRequest{ + WarehouseId: "wh-1", + Statement: "SELECT 1", + }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatementId != "123" { + t.Errorf("expected statement id 123, got %s", resp.StatementId) + } +} + +func TestExecuteAndWait_CustomTimeoutExceeded(t *testing.T) { + client, server := qa.HTTPFixtures{ + { + Method: "POST", + Resource: "/api/2.0/sql/statements", + Response: StatementResponse{ + StatementId: "123", + Status: &StatementStatus{ + State: StatementStatePending, + }, + }, + }, + { + Method: "GET", + ReuseRequest: true, + Resource: "/api/2.0/sql/statements/123", + Response: StatementResponse{ + StatementId: "123", + Status: &StatementStatus{ + State: StatementStateRunning, + }, + }, + }, + }.Client(t) + defer server.Close() + + ctx := context.Background() + api := NewStatementExecution(client) + + start := time.Now() + _, err := api.ExecuteAndWait(ctx, ExecuteStatementRequest{ + WarehouseId: "wh-1", + Statement: "SELECT 1", + WaitTimeout: "500ms", + }) + elapsed := time.Since(start) + + if err == nil { + t.Fatal("expected timeout error, got nil") + } + + // We expect the polling to timeout around 500ms. + // Allow some buffer (e.g., up to 2s) to prevent transient test failures. + if elapsed > 2*time.Second { + t.Errorf("expected poll to timeout quickly, took %s", elapsed) + } +}