Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,20 @@ package scanner

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"

"github.com/google/go-github/v72/github"
)

// Sentinel errors for per-repo scan failures.
var (
ErrEmptyRepo = errors.New("repository is empty")
ErrTruncatedTree = errors.New("tree truncated by GitHub API")
)

// FileEntry represents a file or directory in a repo.
type FileEntry struct {
Path string // full path relative to repo root (e.g., ".github/workflows/ci.yml")
Expand Down Expand Up @@ -61,6 +68,15 @@ func newTestGitHubClient(baseURL string) GitHubClient {
return &realGitHubClient{client: client}
}

// isRateLimitError checks whether an error is a GitHub rate limit error
// (primary or secondary). Rate limit errors must never be swallowed -
// they indicate a global problem that affects all subsequent API calls.
func isRateLimitError(err error) bool {
var rateLimitErr *github.RateLimitError
var abuseErr *github.AbuseRateLimitError
return errors.As(err, &rateLimitErr) || errors.As(err, &abuseErr)
}

func (c *realGitHubClient) ListRepos(ctx context.Context, org string) ([]Repo, error) {
var allRepos []Repo
opts := &github.RepositoryListByOrgOptions{
Expand All @@ -70,6 +86,9 @@ func (c *realGitHubClient) ListRepos(ctx context.Context, org string) ([]Repo, e
for {
ghRepos, resp, err := c.client.Repositories.ListByOrg(ctx, org, opts)
if err != nil {
if isRateLimitError(err) {
return nil, err
}
return nil, fmt.Errorf("list repos for org %s: %w", org, err)
}

Expand All @@ -92,11 +111,21 @@ func (c *realGitHubClient) ListRepos(ctx context.Context, org string) ([]Repo, e
}

func (c *realGitHubClient) GetTree(ctx context.Context, owner, repo, branch string) ([]FileEntry, error) {
tree, _, err := c.client.Git.GetTree(ctx, owner, repo, branch, true)
tree, resp, err := c.client.Git.GetTree(ctx, owner, repo, branch, true)
if err != nil {
if isRateLimitError(err) {
return nil, err
}
if resp != nil && resp.StatusCode == http.StatusConflict {
return nil, ErrEmptyRepo
}
return nil, fmt.Errorf("get tree for %s/%s: %w", owner, repo, err)
}

if tree.GetTruncated() {
return nil, ErrTruncatedTree
}

files := make([]FileEntry, len(tree.Entries))
for i, e := range tree.Entries {
files[i] = FileEntry{
Expand All @@ -111,6 +140,9 @@ func (c *realGitHubClient) GetTree(ctx context.Context, owner, repo, branch stri
func (c *realGitHubClient) GetBranchProtection(ctx context.Context, owner, repo, branch string) (*BranchProtection, error) {
prot, resp, err := c.client.Repositories.GetBranchProtection(ctx, owner, repo, branch)
if err != nil {
if isRateLimitError(err) {
return nil, err
}
if resp != nil && (resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusForbidden) {
return nil, nil
}
Expand All @@ -130,6 +162,9 @@ func (c *realGitHubClient) GetBranchProtection(ctx context.Context, owner, repo,
func (c *realGitHubClient) GetRulesets(ctx context.Context, owner, repo, branch string) (*BranchProtection, error) {
rules, resp, err := c.client.Repositories.GetRulesForBranch(ctx, owner, repo, branch, nil)
if err != nil {
if isRateLimitError(err) {
return nil, err
}
if resp != nil && (resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusForbidden) {
return nil, nil
}
Expand Down Expand Up @@ -167,6 +202,9 @@ func (c *realGitHubClient) CreateIssue(ctx context.Context, owner, repo, title,

_, _, err := c.client.Issues.Create(ctx, owner, repo, req)
if err != nil {
if isRateLimitError(err) {
return err
}
return fmt.Errorf("create issue in %s/%s: %w", owner, repo, err)
}
return nil
Expand Down
8 changes: 7 additions & 1 deletion client_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ type MockGitHubClient struct {
Repos []Repo
Err error
Tree map[string][]FileEntry // repo name -> file entries
TreeErr error
TreeErr error // global tree error (used if TreeErrs is nil)
TreeErrs map[string]error // repo name -> per-repo tree error
Protection map[string]*BranchProtection // repo name -> classic branch protection
ProtectionErr error
Rulesets map[string]*BranchProtection // repo name -> rulesets protection
Expand All @@ -24,6 +25,11 @@ func (m *MockGitHubClient) ListRepos(ctx context.Context, org string) ([]Repo, e
}

func (m *MockGitHubClient) GetTree(ctx context.Context, owner, repo, branch string) ([]FileEntry, error) {
if m.TreeErrs != nil {
if err, ok := m.TreeErrs[repo]; ok {
return nil, err
}
}
if m.TreeErr != nil {
return nil, m.TreeErr
}
Expand Down
119 changes: 119 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package scanner

import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -422,3 +423,121 @@ func TestCreateIssue_APIError(t *testing.T) {
t.Fatal("expected error, got nil")
}
}

// --- GetTree: empty repo and truncated ---

func TestGetTree_EmptyRepo_409(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/repos/org/repo/git/trees/main", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusConflict)
fmt.Fprint(w, `{"message": "Git Repository is empty."}`)
})
client := setupTestServer(t, mux)

_, err := client.GetTree(context.Background(), "org", "repo", "main")
if !errors.Is(err, ErrEmptyRepo) {
t.Fatalf("expected ErrEmptyRepo, got %v", err)
}
}

func TestGetTree_Truncated(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/repos/org/repo/git/trees/main", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, `{
"sha": "abc123",
"truncated": true,
"tree": [
{"path": "README.md", "type": "blob", "size": 100}
]
}`)
})
client := setupTestServer(t, mux)

_, err := client.GetTree(context.Background(), "org", "repo", "main")
if !errors.Is(err, ErrTruncatedTree) {
t.Fatalf("expected ErrTruncatedTree, got %v", err)
}
}

// --- Rate limit tests ---

// rateLimitHandler returns a handler that simulates a GitHub rate limit response.
func rateLimitHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-RateLimit-Remaining", "0")
w.Header().Set("X-RateLimit-Limit", "5000")
w.Header().Set("X-RateLimit-Reset", "1924905600")
w.WriteHeader(http.StatusForbidden)
fmt.Fprint(w, `{"message": "API rate limit exceeded for user.", "documentation_url": "https://docs.github.com/rest/overview/resources-in-the-rest-api#rate-limiting"}`)
}
}

func TestGetBranchProtection_RateLimit(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/repos/org/repo/branches/main/protection", rateLimitHandler())
client := setupTestServer(t, mux)

_, err := client.GetBranchProtection(context.Background(), "org", "repo", "main")
if err == nil {
t.Fatal("expected rate limit error, got nil")
}
if !isRateLimitError(err) {
t.Errorf("expected rate limit error type, got: %v", err)
}
}

func TestGetRulesets_RateLimit(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/repos/org/repo/rules/branches/main", rateLimitHandler())
client := setupTestServer(t, mux)

_, err := client.GetRulesets(context.Background(), "org", "repo", "main")
if err == nil {
t.Fatal("expected rate limit error, got nil")
}
if !isRateLimitError(err) {
t.Errorf("expected rate limit error type, got: %v", err)
}
}

func TestListRepos_RateLimit(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/orgs/test-org/repos", rateLimitHandler())
client := setupTestServer(t, mux)

_, err := client.ListRepos(context.Background(), "test-org")
if err == nil {
t.Fatal("expected rate limit error, got nil")
}
if !isRateLimitError(err) {
t.Errorf("expected rate limit error type, got: %v", err)
}
}

func TestGetTree_RateLimit(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/repos/org/repo/git/trees/main", rateLimitHandler())
client := setupTestServer(t, mux)

_, err := client.GetTree(context.Background(), "org", "repo", "main")
if err == nil {
t.Fatal("expected rate limit error, got nil")
}
if !isRateLimitError(err) {
t.Errorf("expected rate limit error type, got: %v", err)
}
}

func TestCreateIssue_RateLimit(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/repos/org/repo/issues", rateLimitHandler())
client := setupTestServer(t, mux)

err := client.CreateIssue(context.Background(), "org", "repo", "Test", "Body")
if err == nil {
t.Fatal("expected rate limit error, got nil")
}
if !isRateLimitError(err) {
t.Errorf("expected rate limit error type, got: %v", err)
}
}
45 changes: 38 additions & 7 deletions report.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,29 @@ func GenerateReport(org string, results []RepoResult) string {
func generateReport(org string, results []RepoResult, now time.Time) string {
var b strings.Builder

compliant, nonCompliant := splitByCompliance(results)
scanned, skipped := splitScanned(results)
compliant, nonCompliant := splitByCompliance(scanned)

b.WriteString("# Codatus - Org Compliance Report\n\n")
fmt.Fprintf(&b, "**Org:** %s\n", org)
fmt.Fprintf(&b, "**Scanned:** %s\n", now.UTC().Format("2006-01-02 15:04 UTC"))
fmt.Fprintf(&b, "**Repos scanned:** %d\n", len(results))
if len(results) > 0 {
fmt.Fprintf(&b, "**Compliant:** %d/%d (%d%%)\n", len(compliant), len(results), len(compliant)*100/len(results))
fmt.Fprintf(&b, "**Repos scanned:** %d\n", len(scanned))
if len(scanned) > 0 {
fmt.Fprintf(&b, "**Compliant:** %d/%d (%d%%)\n", len(compliant), len(scanned), len(compliant)*100/len(scanned))
}
if len(skipped) > 0 {
fmt.Fprintf(&b, "**Skipped:** %d\n", len(skipped))
}

if len(results) == 0 {
if len(scanned) == 0 && len(skipped) == 0 {
b.WriteString("\nNo repos found.\n")
return b.String()
}

b.WriteString("\n## Summary\n\n")
writeSummaryTable(&b, results)
if len(scanned) > 0 {
b.WriteString("\n## Summary\n\n")
writeSummaryTable(&b, scanned)
}

if len(compliant) > 0 {
writeCompliantSection(&b, org, compliant)
Expand All @@ -41,9 +47,24 @@ func generateReport(org string, results []RepoResult, now time.Time) string {
writeNonCompliantSection(&b, org, nonCompliant)
}

if len(skipped) > 0 {
writeSkippedSection(&b, org, skipped)
}

return b.String()
}

func splitScanned(results []RepoResult) (scanned, skipped []RepoResult) {
for _, rr := range results {
if rr.Skipped {
skipped = append(skipped, rr)
} else {
scanned = append(scanned, rr)
}
}
return
}

func splitByCompliance(results []RepoResult) (compliant, nonCompliant []RepoResult) {
for _, rr := range results {
if isFullyCompliant(rr) {
Expand Down Expand Up @@ -152,3 +173,13 @@ func writeNonCompliantSection(b *strings.Builder, org string, nonCompliant []Rep
b.WriteString("\n</details>\n\n")
}
}

func writeSkippedSection(b *strings.Builder, org string, skipped []RepoResult) {
fmt.Fprintf(b, "\n## ⚠️ Skipped (%s)\n\n", pluralRepos(len(skipped)))
for _, rr := range skipped {
fmt.Fprintf(b, "<details>\n<summary><a href=\"https://github.com/%s/%s\">%s</a> - %s</summary>\n\n",
org, rr.RepoName, rr.RepoName, rr.SkipReason)
b.WriteString("This repository was excluded from compliance results.\n")
b.WriteString("\n</details>\n\n")
}
}
Loading
Loading