Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions internal/github/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@ import (
"fmt"
"os/exec"
"strings"
"time"

gh "github.com/google/go-github/v75/github"
"golang.org/x/oauth2"
)

const apiTimeout = 30 * time.Second

// Client wraps go-github with auth from `gh auth token`.
type Client struct {
gh *gh.Client
}

// NewClient creates a GitHub client using the token from `gh auth token`.
func NewClient(ctx context.Context) (*Client, error) {
token, err := ghAuthToken()
token, err := ghAuthToken(ctx)
if err != nil {
return nil, fmt.Errorf("getting GitHub token: %w", err)
}
Expand All @@ -30,10 +33,15 @@ func NewClient(ctx context.Context) (*Client, error) {
}

// ghAuthToken runs `gh auth token` and returns the token string.
func ghAuthToken() (string, error) {
cmd := exec.Command("gh", "auth", "token")
func ghAuthToken(ctx context.Context) (string, error) {
ctx, cancel := withTimeout(ctx)
defer cancel()
cmd := exec.CommandContext(ctx, "gh", "auth", "token")
out, err := cmd.Output()
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return "", fmt.Errorf("gh auth token timed out after %s", apiTimeout)
}
return "", fmt.Errorf("gh auth token failed: %s (is gh CLI installed and authenticated?)", ghError(err))
}
return strings.TrimSpace(string(out)), nil
Expand Down
29 changes: 29 additions & 0 deletions internal/github/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@ import (
"strings"
)

// withTimeout returns a context with apiTimeout applied, unless the caller
// already set a deadline.
func withTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
if _, ok := ctx.Deadline(); ok {
return ctx, func() {}
}
return context.WithTimeout(ctx, apiTimeout)
}

// ghError extracts stderr from an exec.ExitError for better error messages.
func ghError(err error) string {
if ee, ok := err.(*exec.ExitError); ok && len(ee.Stderr) > 0 {
Expand Down Expand Up @@ -50,9 +59,14 @@ type ApprovedPR struct {

// GetCurrentUser returns the authenticated GitHub user's login.
func GetCurrentUser(ctx context.Context) (string, error) {
ctx, cancel := withTimeout(ctx)
defer cancel()
cmd := exec.CommandContext(ctx, "gh", "api", "user", "--jq", ".login")
out, err := cmd.Output()
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return "", fmt.Errorf("fetching current user timed out after %s", apiTimeout)
}
return "", fmt.Errorf("fetching current user: %s", ghError(err))
}
return strings.TrimSpace(string(out)), nil
Expand All @@ -61,6 +75,8 @@ func GetCurrentUser(ctx context.Context) (string, error) {
// GetReviewRequests fetches PRs where the user is a requested reviewer,
// including re-reviews. Uses GraphQL via `gh api graphql`.
func GetReviewRequests(ctx context.Context, repoFilter string) ([]ReviewRequest, error) {
ctx, cancel := withTimeout(ctx)
defer cancel()
query := `query($q1: String!, $q2: String!) {
requested: search(query: $q1, type: ISSUE, first: 50) {
nodes {
Expand Down Expand Up @@ -103,6 +119,9 @@ func GetReviewRequests(ctx context.Context, repoFilter string) ([]ReviewRequest,
)
out, err := cmd.Output()
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return nil, fmt.Errorf("review requests query timed out after %s", apiTimeout)
}
return nil, fmt.Errorf("GraphQL query failed: %s", ghError(err))
}

Expand Down Expand Up @@ -139,6 +158,8 @@ func GetReviewRequests(ctx context.Context, repoFilter string) ([]ReviewRequest,

// GetApprovedUnmerged fetches the user's own PRs that are approved but not yet merged.
func GetApprovedUnmerged(ctx context.Context, repoFilter string) ([]ApprovedPR, error) {
ctx, cancel := withTimeout(ctx)
defer cancel()
query := `query($q: String!) {
search(query: $q, type: ISSUE, first: 50) {
nodes {
Expand Down Expand Up @@ -168,6 +189,9 @@ func GetApprovedUnmerged(ctx context.Context, repoFilter string) ([]ApprovedPR,
)
out, err := cmd.Output()
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return nil, fmt.Errorf("approved PRs query timed out after %s", apiTimeout)
}
return nil, fmt.Errorf("GraphQL query failed: %s", ghError(err))
}

Expand All @@ -193,6 +217,8 @@ func GetApprovedUnmerged(ctx context.Context, repoFilter string) ([]ApprovedPR,

// ListOpenPRs lists open PRs for a repository using `gh pr list`.
func ListOpenPRs(ctx context.Context, fullRepo string, limit int) ([]ReviewRequest, error) {
ctx, cancel := withTimeout(ctx)
defer cancel()
cmd := exec.CommandContext(ctx, "gh", "pr", "list",
"-R", fullRepo,
"--state", "open",
Expand All @@ -201,6 +227,9 @@ func ListOpenPRs(ctx context.Context, fullRepo string, limit int) ([]ReviewReque
)
out, err := cmd.Output()
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return nil, fmt.Errorf("listing open PRs timed out after %s", apiTimeout)
}
return nil, err
}

Expand Down
91 changes: 91 additions & 0 deletions internal/github/queries_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package github

import (
"context"
"strings"
"testing"
"time"
)

func TestWithTimeout_addsDeadlineWhenNone(t *testing.T) {
ctx, cancel := withTimeout(context.Background())
defer cancel()

deadline, ok := ctx.Deadline()
if !ok {
t.Fatal("expected deadline to be set")
}
remaining := time.Until(deadline)
if remaining <= 0 || remaining > apiTimeout {
t.Fatalf("expected deadline within %s, got %s remaining", apiTimeout, remaining)
}
}

func TestWithTimeout_preservesExistingDeadline(t *testing.T) {
existing := time.Now().Add(5 * time.Second)
parent, parentCancel := context.WithDeadline(context.Background(), existing)
defer parentCancel()

ctx, cancel := withTimeout(parent)
defer cancel()

deadline, ok := ctx.Deadline()
if !ok {
t.Fatal("expected deadline to be set")
}
if !deadline.Equal(existing) {
t.Fatalf("expected existing deadline %v, got %v", existing, deadline)
}
}

func TestGetCurrentUser_timeoutError(t *testing.T) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
defer cancel()

_, err := GetCurrentUser(ctx)
if err == nil {
t.Fatal("expected error from expired context")
}
if !strings.Contains(err.Error(), "timed out") {
t.Fatalf("expected timeout error message, got: %s", err)
}
}

func TestGetReviewRequests_timeoutError(t *testing.T) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
defer cancel()

_, err := GetReviewRequests(ctx, "")
if err == nil {
t.Fatal("expected error from expired context")
}
if !strings.Contains(err.Error(), "timed out") {
t.Fatalf("expected timeout error message, got: %s", err)
}
}

func TestGetApprovedUnmerged_timeoutError(t *testing.T) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
defer cancel()

_, err := GetApprovedUnmerged(ctx, "")
if err == nil {
t.Fatal("expected error from expired context")
}
if !strings.Contains(err.Error(), "timed out") {
t.Fatalf("expected timeout error message, got: %s", err)
}
}

func TestListOpenPRs_timeoutError(t *testing.T) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
defer cancel()

_, err := ListOpenPRs(ctx, "owner/repo", 10)
if err == nil {
t.Fatal("expected error from expired context")
}
if !strings.Contains(err.Error(), "timed out") {
t.Fatalf("expected timeout error message, got: %s", err)
}
}
Loading