diff --git a/internal/github/client.go b/internal/github/client.go index 18f9840..bb04f72 100644 --- a/internal/github/client.go +++ b/internal/github/client.go @@ -5,11 +5,14 @@ import ( "fmt" "os/exec" "strings" + "time" gh "github.com/google/go-github/v75/github" "golang.org/x/oauth2" ) +const apiTimeout = 30 * time.Second + // Client wraps go-github with auth from `gh auth token`. type Client struct { gh *gh.Client @@ -17,7 +20,7 @@ type Client struct { // NewClient creates a GitHub client using the token from `gh auth token`. func NewClient(ctx context.Context) (*Client, error) { - token, err := ghAuthToken() + token, err := ghAuthToken(ctx) if err != nil { return nil, fmt.Errorf("getting GitHub token: %w", err) } @@ -30,10 +33,15 @@ func NewClient(ctx context.Context) (*Client, error) { } // ghAuthToken runs `gh auth token` and returns the token string. -func ghAuthToken() (string, error) { - cmd := exec.Command("gh", "auth", "token") +func ghAuthToken(ctx context.Context) (string, error) { + ctx, cancel := withTimeout(ctx) + defer cancel() + cmd := exec.CommandContext(ctx, "gh", "auth", "token") out, err := cmd.Output() if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return "", fmt.Errorf("gh auth token timed out after %s", apiTimeout) + } return "", fmt.Errorf("gh auth token failed: %s (is gh CLI installed and authenticated?)", ghError(err)) } return strings.TrimSpace(string(out)), nil diff --git a/internal/github/queries.go b/internal/github/queries.go index e947236..09bd6fe 100644 --- a/internal/github/queries.go +++ b/internal/github/queries.go @@ -8,6 +8,15 @@ import ( "strings" ) +// withTimeout returns a context with apiTimeout applied, unless the caller +// already set a deadline. +func withTimeout(ctx context.Context) (context.Context, context.CancelFunc) { + if _, ok := ctx.Deadline(); ok { + return ctx, func() {} + } + return context.WithTimeout(ctx, apiTimeout) +} + // ghError extracts stderr from an exec.ExitError for better error messages. func ghError(err error) string { if ee, ok := err.(*exec.ExitError); ok && len(ee.Stderr) > 0 { @@ -50,9 +59,14 @@ type ApprovedPR struct { // GetCurrentUser returns the authenticated GitHub user's login. func GetCurrentUser(ctx context.Context) (string, error) { + ctx, cancel := withTimeout(ctx) + defer cancel() cmd := exec.CommandContext(ctx, "gh", "api", "user", "--jq", ".login") out, err := cmd.Output() if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return "", fmt.Errorf("fetching current user timed out after %s", apiTimeout) + } return "", fmt.Errorf("fetching current user: %s", ghError(err)) } return strings.TrimSpace(string(out)), nil @@ -61,6 +75,8 @@ func GetCurrentUser(ctx context.Context) (string, error) { // GetReviewRequests fetches PRs where the user is a requested reviewer, // including re-reviews. Uses GraphQL via `gh api graphql`. func GetReviewRequests(ctx context.Context, repoFilter string) ([]ReviewRequest, error) { + ctx, cancel := withTimeout(ctx) + defer cancel() query := `query($q1: String!, $q2: String!) { requested: search(query: $q1, type: ISSUE, first: 50) { nodes { @@ -103,6 +119,9 @@ func GetReviewRequests(ctx context.Context, repoFilter string) ([]ReviewRequest, ) out, err := cmd.Output() if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return nil, fmt.Errorf("review requests query timed out after %s", apiTimeout) + } return nil, fmt.Errorf("GraphQL query failed: %s", ghError(err)) } @@ -139,6 +158,8 @@ func GetReviewRequests(ctx context.Context, repoFilter string) ([]ReviewRequest, // GetApprovedUnmerged fetches the user's own PRs that are approved but not yet merged. func GetApprovedUnmerged(ctx context.Context, repoFilter string) ([]ApprovedPR, error) { + ctx, cancel := withTimeout(ctx) + defer cancel() query := `query($q: String!) { search(query: $q, type: ISSUE, first: 50) { nodes { @@ -168,6 +189,9 @@ func GetApprovedUnmerged(ctx context.Context, repoFilter string) ([]ApprovedPR, ) out, err := cmd.Output() if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return nil, fmt.Errorf("approved PRs query timed out after %s", apiTimeout) + } return nil, fmt.Errorf("GraphQL query failed: %s", ghError(err)) } @@ -193,6 +217,8 @@ func GetApprovedUnmerged(ctx context.Context, repoFilter string) ([]ApprovedPR, // ListOpenPRs lists open PRs for a repository using `gh pr list`. func ListOpenPRs(ctx context.Context, fullRepo string, limit int) ([]ReviewRequest, error) { + ctx, cancel := withTimeout(ctx) + defer cancel() cmd := exec.CommandContext(ctx, "gh", "pr", "list", "-R", fullRepo, "--state", "open", @@ -201,6 +227,9 @@ func ListOpenPRs(ctx context.Context, fullRepo string, limit int) ([]ReviewReque ) out, err := cmd.Output() if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return nil, fmt.Errorf("listing open PRs timed out after %s", apiTimeout) + } return nil, err } diff --git a/internal/github/queries_test.go b/internal/github/queries_test.go new file mode 100644 index 0000000..913f698 --- /dev/null +++ b/internal/github/queries_test.go @@ -0,0 +1,91 @@ +package github + +import ( + "context" + "strings" + "testing" + "time" +) + +func TestWithTimeout_addsDeadlineWhenNone(t *testing.T) { + ctx, cancel := withTimeout(context.Background()) + defer cancel() + + deadline, ok := ctx.Deadline() + if !ok { + t.Fatal("expected deadline to be set") + } + remaining := time.Until(deadline) + if remaining <= 0 || remaining > apiTimeout { + t.Fatalf("expected deadline within %s, got %s remaining", apiTimeout, remaining) + } +} + +func TestWithTimeout_preservesExistingDeadline(t *testing.T) { + existing := time.Now().Add(5 * time.Second) + parent, parentCancel := context.WithDeadline(context.Background(), existing) + defer parentCancel() + + ctx, cancel := withTimeout(parent) + defer cancel() + + deadline, ok := ctx.Deadline() + if !ok { + t.Fatal("expected deadline to be set") + } + if !deadline.Equal(existing) { + t.Fatalf("expected existing deadline %v, got %v", existing, deadline) + } +} + +func TestGetCurrentUser_timeoutError(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) + defer cancel() + + _, err := GetCurrentUser(ctx) + if err == nil { + t.Fatal("expected error from expired context") + } + if !strings.Contains(err.Error(), "timed out") { + t.Fatalf("expected timeout error message, got: %s", err) + } +} + +func TestGetReviewRequests_timeoutError(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) + defer cancel() + + _, err := GetReviewRequests(ctx, "") + if err == nil { + t.Fatal("expected error from expired context") + } + if !strings.Contains(err.Error(), "timed out") { + t.Fatalf("expected timeout error message, got: %s", err) + } +} + +func TestGetApprovedUnmerged_timeoutError(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) + defer cancel() + + _, err := GetApprovedUnmerged(ctx, "") + if err == nil { + t.Fatal("expected error from expired context") + } + if !strings.Contains(err.Error(), "timed out") { + t.Fatalf("expected timeout error message, got: %s", err) + } +} + +func TestListOpenPRs_timeoutError(t *testing.T) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) + defer cancel() + + _, err := ListOpenPRs(ctx, "owner/repo", 10) + if err == nil { + t.Fatal("expected error from expired context") + } + if !strings.Contains(err.Error(), "timed out") { + t.Fatalf("expected timeout error message, got: %s", err) + } +}