From edd79cb2a69281ffac34edc07173af774e432e44 Mon Sep 17 00:00:00 2001 From: nightshift Date: Tue, 14 Apr 2026 10:08:03 +0000 Subject: [PATCH] feat(deps): add dependency risk scanner Add 'nightshift deps' subcommand that scans Go module dependencies for security vulnerabilities (OSV.dev API), maintenance risks (GitHub API), and license concerns (GitHub raw content + SPDX matching). Key implementation details: - CVSS 3.1 vector parsing computes base scores from vector strings (not raw numeric scores) per the CVSS specification - API errors are surfaced to callers rather than silently swallowed - Partial results returned alongside errors for resilience - Concurrent scanning with semaphores (10 for OSV, 5 for GitHub) - Results persisted to dep_scans/dep_findings SQLite tables - Colored CLI output with lipgloss, JSON output via --json flag - Exit code 1 when critical findings detected Nightshift-Task: dependency-risk Nightshift-Ref: https://github.com/marcus/nightshift --- cmd/nightshift/commands/deps.go | 163 ++++++++++++++++ go.mod | 3 +- go.sum | 8 +- internal/db/migrations.go | 33 ++++ internal/deps/gomod.go | 54 ++++++ internal/deps/gomod_test.go | 85 +++++++++ internal/deps/license.go | 153 +++++++++++++++ internal/deps/maintenance.go | 154 +++++++++++++++ internal/deps/models.go | 66 +++++++ internal/deps/risk.go | 45 +++++ internal/deps/risk_test.go | 98 ++++++++++ internal/deps/scanner.go | 129 +++++++++++++ internal/deps/scanner_test.go | 150 +++++++++++++++ internal/deps/store.go | 180 ++++++++++++++++++ internal/deps/vulns.go | 322 ++++++++++++++++++++++++++++++++ internal/deps/vulns_test.go | 191 +++++++++++++++++++ 16 files changed, 1829 insertions(+), 5 deletions(-) create mode 100644 cmd/nightshift/commands/deps.go create mode 100644 internal/deps/gomod.go create mode 100644 internal/deps/gomod_test.go create mode 100644 internal/deps/license.go create mode 100644 internal/deps/maintenance.go create mode 100644 internal/deps/models.go create mode 100644 
internal/deps/risk.go create mode 100644 internal/deps/risk_test.go create mode 100644 internal/deps/scanner.go create mode 100644 internal/deps/scanner_test.go create mode 100644 internal/deps/store.go create mode 100644 internal/deps/vulns.go create mode 100644 internal/deps/vulns_test.go diff --git a/cmd/nightshift/commands/deps.go b/cmd/nightshift/commands/deps.go new file mode 100644 index 0000000..80a2b3f --- /dev/null +++ b/cmd/nightshift/commands/deps.go @@ -0,0 +1,163 @@ +package commands + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/charmbracelet/lipgloss" + "github.com/spf13/cobra" + + "github.com/marcus/nightshift/internal/config" + "github.com/marcus/nightshift/internal/db" + "github.com/marcus/nightshift/internal/deps" + "github.com/marcus/nightshift/internal/logging" +) + +var depsCmd = &cobra.Command{ + Use: "deps", + Short: "Scan dependencies for security, maintenance, and license risks", + Long: `Scan Go module dependencies across managed projects for: + - Security vulnerabilities (via OSV.dev) + - Maintenance risks (archived repos, stale projects) + - License concerns (copyleft, unknown licenses) + +Results are scored Critical/High/Medium/Low and stored in the database.`, + RunE: func(cmd *cobra.Command, args []string) error { + project, _ := cmd.Flags().GetString("project") + jsonOutput, _ := cmd.Flags().GetBool("json") + dbPath, _ := cmd.Flags().GetString("db") + + return runDeps(project, jsonOutput, dbPath) + }, +} + +func init() { + depsCmd.Flags().StringP("project", "p", "", "Scan a specific project path") + depsCmd.Flags().Bool("json", false, "Output results as JSON") + depsCmd.Flags().String("db", "", "Database path (uses config if not set)") + rootCmd.AddCommand(depsCmd) +} + +func runDeps(project string, jsonOutput bool, dbPath string) error { + logger := logging.Component("deps") + + if dbPath == "" { + cfg, err := config.Load() + if err != nil { + logger.Warnf("could not load config for db path, 
using default: %v", err) + } else { + dbPath = cfg.ExpandedDBPath() + } + } + + database, err := db.Open(dbPath) + if err != nil { + return fmt.Errorf("opening database: %w", err) + } + defer func() { _ = database.Close() }() + + store := deps.NewStore(database.SQL()) + if err := store.InitTables(); err != nil { + return fmt.Errorf("initializing tables: %w", err) + } + + githubToken := os.Getenv("GITHUB_TOKEN") + scanner := deps.NewScanner(store, githubToken, logger) + + ctx := context.Background() + + var paths []string + if project != "" { + paths = []string{project} + } else { + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + for _, p := range cfg.Projects { + paths = append(paths, p.Path) + } + if len(paths) == 0 { + return fmt.Errorf("no projects configured; use --project to specify a path") + } + } + + results, scanErr := scanner.ScanAll(ctx, paths) + + if jsonOutput { + enc := json.NewEncoder(os.Stdout) + enc.SetIndent("", " ") + if err := enc.Encode(results); err != nil { + return fmt.Errorf("encoding JSON: %w", err) + } + } else { + renderResults(results, scanErr) + } + + for _, r := range results { + for _, f := range r.Findings { + if f.RiskLevel == deps.RiskCritical { + if scanErr != nil { + fmt.Fprintf(os.Stderr, "\nWarning: scan completed with errors: %v\n", scanErr) + } + os.Exit(1) + } + } + } + + if scanErr != nil { + fmt.Fprintf(os.Stderr, "\nWarning: scan completed with errors: %v\n", scanErr) + } + + return nil +} + +func renderResults(results []*deps.ScanResult, scanErr error) { + criticalStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("196")).Bold(true) + highStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("208")).Bold(true) + mediumStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("226")) + lowStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("246")) + headerStyle := lipgloss.NewStyle().Bold(true).Underline(true) + + for _, result := range results { + critical, high, 
medium, low := deps.SummarizeResults(result) + fmt.Printf("\n%s (%d deps)\n", headerStyle.Render(result.Project), len(result.Deps)) + fmt.Printf(" Critical: %s High: %s Medium: %s Low: %s\n", + criticalStyle.Render(fmt.Sprintf("%d", critical)), + highStyle.Render(fmt.Sprintf("%d", high)), + mediumStyle.Render(fmt.Sprintf("%d", medium)), + lowStyle.Render(fmt.Sprintf("%d", low)), + ) + + if len(result.Findings) == 0 { + fmt.Println(" No issues found.") + continue + } + + fmt.Println() + for _, f := range result.Findings { + var styled string + badge := strings.ToUpper(string(f.RiskLevel)) + switch f.RiskLevel { + case deps.RiskCritical: + styled = criticalStyle.Render(fmt.Sprintf(" [%s]", badge)) + case deps.RiskHigh: + styled = highStyle.Render(fmt.Sprintf(" [%s]", badge)) + case deps.RiskMedium: + styled = mediumStyle.Render(fmt.Sprintf(" [%s]", badge)) + default: + styled = lowStyle.Render(fmt.Sprintf(" [%s]", badge)) + } + fmt.Printf("%s %s\n", styled, f.Detail) + } + } + + if scanErr != nil { + fmt.Fprintf(os.Stderr, "\nWarning: %v\n", scanErr) + } + + fmt.Println() +} diff --git a/go.mod b/go.mod index f739510..05de43e 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/marcus/nightshift -go 1.24.0 +go 1.25.0 require ( github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 @@ -13,6 +13,7 @@ require ( github.com/rs/zerolog v1.34.0 github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 + golang.org/x/mod v0.35.0 modernc.org/sqlite v1.35.0 ) diff --git a/go.sum b/go.sum index 8e43c0e..5f7dd2b 100644 --- a/go.sum +++ b/go.sum @@ -106,8 +106,8 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtDIs4eL21OuM9nyAADmo= golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= -golang.org/x/mod v0.26.0 
h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= -golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= +golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= +golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -118,8 +118,8 @@ golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= -golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= -golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/db/migrations.go b/internal/db/migrations.go index 3b7d11e..071742e 100644 --- a/internal/db/migrations.go +++ b/internal/db/migrations.go @@ -40,6 +40,11 @@ var migrations = []Migration{ Description: "add branch column to run_history", SQL: migration005SQL, }, + { + Version: 6, + Description: "add dep_scans and dep_findings tables for dependency risk scanning", + SQL: migration006SQL, + }, } const migration002SQL = ` @@ -121,6 +126,34 @@ const migration005SQL = ` 
ALTER TABLE run_history ADD COLUMN branch TEXT NOT NULL DEFAULT ''; ` +const migration006SQL = ` +CREATE TABLE IF NOT EXISTS dep_scans ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + project TEXT NOT NULL, + scanned_at DATETIME NOT NULL, + total_deps INTEGER, + critical_count INTEGER, + high_count INTEGER, + medium_count INTEGER, + low_count INTEGER +); + +CREATE TABLE IF NOT EXISTS dep_findings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + scan_id INTEGER NOT NULL REFERENCES dep_scans(id), + module TEXT NOT NULL, + version TEXT NOT NULL, + risk_level TEXT NOT NULL, + category TEXT NOT NULL, + detail TEXT NOT NULL, + cve_id TEXT, + cvss_score REAL +); + +CREATE INDEX IF NOT EXISTS idx_dep_scans_project ON dep_scans(project, scanned_at DESC); +CREATE INDEX IF NOT EXISTS idx_dep_findings_scan ON dep_findings(scan_id); +` + // Migrate runs all pending migrations inside transactions. func Migrate(db *sql.DB) error { if db == nil { diff --git a/internal/deps/gomod.go b/internal/deps/gomod.go new file mode 100644 index 0000000..61cd56c --- /dev/null +++ b/internal/deps/gomod.go @@ -0,0 +1,54 @@ +package deps + +import ( + "fmt" + "os" + "path/filepath" + + "golang.org/x/mod/modfile" +) + +// ParseGoMod reads and parses a go.mod file from the given project path, +// returning all required dependencies. 
+func ParseGoMod(projectPath string) ([]Dependency, error) { + gomodPath := filepath.Join(projectPath, "go.mod") + data, err := os.ReadFile(gomodPath) + if err != nil { + return nil, fmt.Errorf("reading go.mod: %w", err) + } + + f, err := modfile.Parse(gomodPath, data, nil) + if err != nil { + return nil, fmt.Errorf("parsing go.mod: %w", err) + } + + replacements := make(map[string]struct { + mod string + version string + }) + for _, rep := range f.Replace { + if rep.New.Path != "" { + replacements[rep.Old.Path] = struct { + mod string + version string + }{mod: rep.New.Path, version: rep.New.Version} + } + } + + var deps []Dependency + for _, req := range f.Require { + mod := req.Mod.Path + ver := req.Mod.Version + if rep, ok := replacements[mod]; ok { + mod = rep.mod + ver = rep.version + } + deps = append(deps, Dependency{ + Module: mod, + Version: ver, + Indirect: req.Indirect, + }) + } + + return deps, nil +} diff --git a/internal/deps/gomod_test.go b/internal/deps/gomod_test.go new file mode 100644 index 0000000..ab87e99 --- /dev/null +++ b/internal/deps/gomod_test.go @@ -0,0 +1,85 @@ +package deps + +import ( + "os" + "path/filepath" + "testing" +) + +func TestParseGoMod(t *testing.T) { + tests := []struct { + name string + content string + wantDeps []Dependency + wantErr bool + }{ + { + name: "direct and indirect deps", + content: `module example.com/mymod + +go 1.21 + +require ( + github.com/foo/bar v1.2.3 + github.com/baz/qux v0.1.0 // indirect +) +`, + wantDeps: []Dependency{ + {Module: "github.com/foo/bar", Version: "v1.2.3", Indirect: false}, + {Module: "github.com/baz/qux", Version: "v0.1.0", Indirect: true}, + }, + }, + { + name: "replace directive", + content: `module example.com/mymod + +go 1.21 + +require github.com/old/mod v1.0.0 + +replace github.com/old/mod => github.com/new/mod v2.0.0 +`, + wantDeps: []Dependency{ + {Module: "github.com/new/mod", Version: "v2.0.0", Indirect: false}, + }, + }, + { + name: "empty go.mod", + content: "module 
example.com/mymod\n\ngo 1.21\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte(tt.content), 0644); err != nil { + t.Fatal(err) + } + + deps, err := ParseGoMod(dir) + if (err != nil) != tt.wantErr { + t.Fatalf("ParseGoMod() error = %v, wantErr %v", err, tt.wantErr) + } + if err != nil { + return + } + + if len(deps) != len(tt.wantDeps) { + t.Fatalf("got %d deps, want %d", len(deps), len(tt.wantDeps)) + } + for i, got := range deps { + want := tt.wantDeps[i] + if got.Module != want.Module || got.Version != want.Version || got.Indirect != want.Indirect { + t.Errorf("dep[%d] = %+v, want %+v", i, got, want) + } + } + }) + } +} + +func TestParseGoModMissingFile(t *testing.T) { + _, err := ParseGoMod(t.TempDir()) + if err == nil { + t.Fatal("expected error for missing go.mod") + } +} diff --git a/internal/deps/license.go b/internal/deps/license.go new file mode 100644 index 0000000..3b0435c --- /dev/null +++ b/internal/deps/license.go @@ -0,0 +1,153 @@ +package deps + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" +) + +// LicenseChecker detects license types for dependencies. +type LicenseChecker struct { + client *http.Client + sem chan struct{} +} + +// NewLicenseChecker creates a LicenseChecker with sensible defaults. +func NewLicenseChecker() *LicenseChecker { + return &LicenseChecker{ + client: &http.Client{Timeout: 10 * time.Second}, + sem: make(chan struct{}, 10), + } +} + +// CheckLicenses checks the license of each dependency. +// Non-GitHub modules are marked as unknown (Medium risk). +// Returns an error if any license fetch fails. 
+func (lc *LicenseChecker) CheckLicenses(ctx context.Context, deps []Dependency) ([]Finding, error) { + var ( + mu sync.Mutex + findings []Finding + errs []string + ) + + var wg sync.WaitGroup + for _, dep := range deps { + owner, repo := parseGitHubModule(dep.Module) + if owner == "" { + mu.Lock() + findings = append(findings, Finding{ + Module: dep.Module, + Version: dep.Version, + RiskLevel: RiskMedium, + Category: CategoryLicense, + Detail: fmt.Sprintf("%s: unable to determine license (non-GitHub module)", dep.Module), + }) + mu.Unlock() + continue + } + + wg.Add(1) + go func(d Dependency, owner, repo string) { + defer wg.Done() + lc.sem <- struct{}{} + defer func() { <-lc.sem }() + + f, err := lc.checkLicense(ctx, d, owner, repo) + mu.Lock() + defer mu.Unlock() + if err != nil { + errs = append(errs, fmt.Sprintf("%s: %v", d.Module, err)) + return + } + if f != nil { + findings = append(findings, *f) + } + }(dep, owner, repo) + } + wg.Wait() + + if len(errs) > 0 { + return findings, fmt.Errorf("license check errors: %s", strings.Join(errs, "; ")) + } + return findings, nil +} + +func (lc *LicenseChecker) checkLicense(ctx context.Context, dep Dependency, owner, repo string) (*Finding, error) { + content, err := lc.fetchLicenseContent(ctx, owner, repo) + if err != nil { + return nil, err + } + + license, risk := classifyLicense(content) + if risk == RiskNone { + return nil, nil + } + + return &Finding{ + Module: dep.Module, + Version: dep.Version, + RiskLevel: risk, + Category: CategoryLicense, + Detail: fmt.Sprintf("%s: %s license", dep.Module, license), + }, nil +} + +func (lc *LicenseChecker) fetchLicenseContent(ctx context.Context, owner, repo string) (string, error) { + for _, filename := range []string{"LICENSE", "LICENSE.md", "LICENSE.txt", "LICENCE", "COPYING"} { + url := fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/HEAD/%s", owner, repo, filename) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + continue 
+ } + + resp, err := lc.client.Do(req) + if err != nil { + continue + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + body, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + if err != nil { + continue + } + return string(body), nil + } + } + return "", fmt.Errorf("no license file found for %s/%s", owner, repo) +} + +// classifyLicense identifies the license type from file content using keyword matching. +func classifyLicense(content string) (string, RiskLevel) { + upper := strings.ToUpper(content) + + switch { + case strings.Contains(upper, "GNU AFFERO GENERAL PUBLIC LICENSE"): + return "AGPL", RiskHigh + case strings.Contains(upper, "GNU GENERAL PUBLIC LICENSE"): + if strings.Contains(upper, "LESSER") { + return "LGPL", RiskMedium + } + return "GPL", RiskHigh + case strings.Contains(upper, "MOZILLA PUBLIC LICENSE"): + return "MPL", RiskMedium + case strings.Contains(upper, "MIT LICENSE") || strings.Contains(upper, "PERMISSION IS HEREBY GRANTED, FREE OF CHARGE"): + return "MIT", RiskNone + case strings.Contains(upper, "APACHE LICENSE"): + return "Apache-2.0", RiskNone + case strings.Contains(upper, "BSD ") && (strings.Contains(upper, "REDISTRIBUTION AND USE") || strings.Contains(upper, "REDISTRIBUTIONS")): + return "BSD", RiskNone + case strings.Contains(upper, "ISC LICENSE"): + return "ISC", RiskNone + case strings.Contains(upper, "THE UNLICENSE") || strings.Contains(upper, "UNLICENSE"): + return "Unlicense", RiskNone + case strings.Contains(upper, "CREATIVE COMMONS"): + return "CC", RiskMedium + default: + return "Unknown", RiskMedium + } +} diff --git a/internal/deps/maintenance.go b/internal/deps/maintenance.go new file mode 100644 index 0000000..e96ea85 --- /dev/null +++ b/internal/deps/maintenance.go @@ -0,0 +1,154 @@ +package deps + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" +) + +// MaintenanceChecker assesses dependency maintenance health via the GitHub API. 
+type MaintenanceChecker struct { + client *http.Client + githubToken string + baseURL string + sem chan struct{} + cache sync.Map +} + +type repoInfo struct { + Archived bool `json:"archived"` + PushedAt time.Time `json:"pushed_at"` +} + +// NewMaintenanceChecker creates a MaintenanceChecker. +func NewMaintenanceChecker(githubToken string) *MaintenanceChecker { + return &MaintenanceChecker{ + client: &http.Client{Timeout: 10 * time.Second}, + githubToken: githubToken, + baseURL: "https://api.github.com", + sem: make(chan struct{}, 5), + } +} + +// CheckMaintenance checks each dependency's maintenance health. +// Non-GitHub modules are skipped. Returns an error if any GitHub API calls fail. +func (mc *MaintenanceChecker) CheckMaintenance(ctx context.Context, deps []Dependency) ([]Finding, error) { + var ( + mu sync.Mutex + findings []Finding + errs []string + ) + + var wg sync.WaitGroup + for _, dep := range deps { + owner, repo := parseGitHubModule(dep.Module) + if owner == "" { + continue + } + + wg.Add(1) + go func(d Dependency, owner, repo string) { + defer wg.Done() + mc.sem <- struct{}{} + defer func() { <-mc.sem }() + + info, err := mc.getRepoInfo(ctx, owner, repo) + mu.Lock() + defer mu.Unlock() + if err != nil { + errs = append(errs, fmt.Sprintf("%s/%s: %v", owner, repo, err)) + return + } + + if info.Archived { + findings = append(findings, Finding{ + Module: d.Module, + Version: d.Version, + RiskLevel: RiskCritical, + Category: CategoryMaintenance, + Detail: fmt.Sprintf("%s/%s is archived", owner, repo), + }) + } + + age := time.Since(info.PushedAt) + switch { + case age > 3*365*24*time.Hour: + findings = append(findings, Finding{ + Module: d.Module, + Version: d.Version, + RiskLevel: RiskHigh, + Category: CategoryMaintenance, + Detail: fmt.Sprintf("%s/%s: last commit %.0f days ago", owner, repo, age.Hours()/24), + }) + case age > 365*24*time.Hour: + findings = append(findings, Finding{ + Module: d.Module, + Version: d.Version, + RiskLevel: RiskMedium, + 
Category: CategoryMaintenance, + Detail: fmt.Sprintf("%s/%s: last commit %.0f days ago", owner, repo, age.Hours()/24), + }) + } + }(dep, owner, repo) + } + wg.Wait() + + if len(errs) > 0 { + return findings, fmt.Errorf("maintenance check errors: %s", strings.Join(errs, "; ")) + } + return findings, nil +} + +func (mc *MaintenanceChecker) getRepoInfo(ctx context.Context, owner, repo string) (*repoInfo, error) { + cacheKey := owner + "/" + repo + if cached, ok := mc.cache.Load(cacheKey); ok { + return cached.(*repoInfo), nil + } + + url := fmt.Sprintf("%s/repos/%s/%s", mc.baseURL, owner, repo) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + + if mc.githubToken != "" { + req.Header.Set("Authorization", "Bearer "+mc.githubToken) + } + + resp, err := mc.client.Do(req) + if err != nil { + return nil, fmt.Errorf("querying GitHub: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("GitHub returned %d: %s", resp.StatusCode, string(respBody)) + } + + var info repoInfo + if err := json.NewDecoder(resp.Body).Decode(&info); err != nil { + return nil, fmt.Errorf("decoding response: %w", err) + } + + mc.cache.Store(cacheKey, &info) + return &info, nil +} + +// parseGitHubModule extracts owner/repo from a Go module path. +// Returns empty strings for non-GitHub modules. +func parseGitHubModule(module string) (string, string) { + if !strings.HasPrefix(module, "github.com/") { + return "", "" + } + parts := strings.SplitN(module, "/", 4) + if len(parts) < 3 { + return "", "" + } + return parts[1], parts[2] +} diff --git a/internal/deps/models.go b/internal/deps/models.go new file mode 100644 index 0000000..da025d5 --- /dev/null +++ b/internal/deps/models.go @@ -0,0 +1,66 @@ +package deps + +import "time" + +// RiskLevel represents the severity of a dependency risk finding. 
+type RiskLevel string + +const ( + RiskCritical RiskLevel = "critical" + RiskHigh RiskLevel = "high" + RiskMedium RiskLevel = "medium" + RiskLow RiskLevel = "low" + RiskNone RiskLevel = "none" +) + +// Severity returns a numeric severity for sorting (higher = more severe). +func (r RiskLevel) Severity() int { + switch r { + case RiskCritical: + return 4 + case RiskHigh: + return 3 + case RiskMedium: + return 2 + case RiskLow: + return 1 + default: + return 0 + } +} + +// Category classifies what type of risk a finding represents. +type Category string + +const ( + CategoryVulnerability Category = "vulnerability" + CategoryMaintenance Category = "maintenance" + CategoryLicense Category = "license" +) + +// Dependency represents a single Go module dependency. +type Dependency struct { + Module string `json:"module"` + Version string `json:"version"` + Indirect bool `json:"indirect"` +} + +// Finding represents a single risk finding for a dependency. +type Finding struct { + Module string `json:"module"` + Version string `json:"version"` + RiskLevel RiskLevel `json:"risk_level"` + Category Category `json:"category"` + Detail string `json:"detail"` + CVEID string `json:"cve_id,omitempty"` + CVSSScore float64 `json:"cvss_score,omitempty"` +} + +// ScanResult holds the complete results of scanning a project's dependencies. +type ScanResult struct { + ID int64 `json:"id,omitempty"` + Project string `json:"project"` + ScannedAt time.Time `json:"scanned_at"` + Deps []Dependency `json:"deps"` + Findings []Finding `json:"findings"` +} diff --git a/internal/deps/risk.go b/internal/deps/risk.go new file mode 100644 index 0000000..5989777 --- /dev/null +++ b/internal/deps/risk.go @@ -0,0 +1,45 @@ +package deps + +import ( + "sort" +) + +// ScoreDependency returns the highest risk level across all findings for a module. 
+func ScoreDependency(findings []Finding) RiskLevel { + max := RiskNone + for _, f := range findings { + if f.RiskLevel.Severity() > max.Severity() { + max = f.RiskLevel + } + } + return max +} + +// SummarizeResults counts findings by risk level. +func SummarizeResults(result *ScanResult) (critical, high, medium, low int) { + for _, f := range result.Findings { + switch f.RiskLevel { + case RiskCritical: + critical++ + case RiskHigh: + high++ + case RiskMedium: + medium++ + case RiskLow: + low++ + } + } + return +} + +// SortFindings sorts findings by severity descending, then by module name ascending. +func SortFindings(findings []Finding) { + sort.Slice(findings, func(i, j int) bool { + si := findings[i].RiskLevel.Severity() + sj := findings[j].RiskLevel.Severity() + if si != sj { + return si > sj + } + return findings[i].Module < findings[j].Module + }) +} diff --git a/internal/deps/risk_test.go b/internal/deps/risk_test.go new file mode 100644 index 0000000..8aeacea --- /dev/null +++ b/internal/deps/risk_test.go @@ -0,0 +1,98 @@ +package deps + +import ( + "testing" +) + +func TestScoreDependency(t *testing.T) { + tests := []struct { + name string + findings []Finding + want RiskLevel + }{ + { + name: "no findings", + findings: nil, + want: RiskNone, + }, + { + name: "single critical", + findings: []Finding{ + {RiskLevel: RiskCritical}, + }, + want: RiskCritical, + }, + { + name: "mixed levels returns highest", + findings: []Finding{ + {RiskLevel: RiskLow}, + {RiskLevel: RiskHigh}, + {RiskLevel: RiskMedium}, + }, + want: RiskHigh, + }, + { + name: "all low", + findings: []Finding{ + {RiskLevel: RiskLow}, + {RiskLevel: RiskLow}, + }, + want: RiskLow, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ScoreDependency(tt.findings) + if got != tt.want { + t.Errorf("ScoreDependency() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSummarizeResults(t *testing.T) { + result := &ScanResult{ + Findings: []Finding{ + 
{RiskLevel: RiskCritical}, + {RiskLevel: RiskHigh}, + {RiskLevel: RiskHigh}, + {RiskLevel: RiskMedium}, + {RiskLevel: RiskMedium}, + {RiskLevel: RiskMedium}, + {RiskLevel: RiskLow}, + }, + } + + c, h, m, l := SummarizeResults(result) + if c != 1 || h != 2 || m != 3 || l != 1 { + t.Errorf("SummarizeResults() = (%d, %d, %d, %d), want (1, 2, 3, 1)", c, h, m, l) + } +} + +func TestSortFindings(t *testing.T) { + findings := []Finding{ + {Module: "b-mod", RiskLevel: RiskLow}, + {Module: "a-mod", RiskLevel: RiskCritical}, + {Module: "c-mod", RiskLevel: RiskHigh}, + {Module: "a-mod", RiskLevel: RiskHigh}, + } + + SortFindings(findings) + + expected := []struct { + module string + level RiskLevel + }{ + {"a-mod", RiskCritical}, + {"a-mod", RiskHigh}, + {"c-mod", RiskHigh}, + {"b-mod", RiskLow}, + } + + for i, f := range findings { + if f.Module != expected[i].module || f.RiskLevel != expected[i].level { + t.Errorf("findings[%d] = {%s, %s}, want {%s, %s}", + i, f.Module, f.RiskLevel, expected[i].module, expected[i].level) + } + } +} diff --git a/internal/deps/scanner.go b/internal/deps/scanner.go new file mode 100644 index 0000000..7b91c4d --- /dev/null +++ b/internal/deps/scanner.go @@ -0,0 +1,129 @@ +package deps + +import ( + "context" + "fmt" + "path/filepath" + "strings" + "time" +) + +// Logger defines the logging interface used by the scanner. +type Logger interface { + Infof(format string, args ...any) + Warnf(format string, args ...any) + Errorf(format string, args ...any) +} + +// Scanner orchestrates dependency scanning across projects. +type Scanner struct { + store *Store + githubToken string + logger Logger +} + +// NewScanner creates a Scanner with the given dependencies. +func NewScanner(store *Store, githubToken string, logger Logger) *Scanner { + return &Scanner{ + store: store, + githubToken: githubToken, + logger: logger, + } +} + +// ScanProject scans a single project's dependencies for risks. 
+// It runs vulnerability, maintenance, and license checks concurrently +// and aggregates the results. Partial results are returned alongside +// any errors encountered during scanning. +func (s *Scanner) ScanProject(ctx context.Context, projectPath string) (*ScanResult, error) { + project := filepath.Base(projectPath) + s.logger.Infof("scanning dependencies for %s", project) + + deps, err := ParseGoMod(projectPath) + if err != nil { + return nil, fmt.Errorf("parsing go.mod for %s: %w", project, err) + } + + s.logger.Infof("parsed %d dependencies for %s", len(deps), project) + + type checkResult struct { + findings []Finding + err error + name string + } + + results := make(chan checkResult, 3) + + go func() { + vc := NewVulnChecker() + findings, err := vc.CheckVulnerabilities(ctx, deps) + results <- checkResult{findings: findings, err: err, name: "vulnerability"} + }() + + go func() { + mc := NewMaintenanceChecker(s.githubToken) + findings, err := mc.CheckMaintenance(ctx, deps) + results <- checkResult{findings: findings, err: err, name: "maintenance"} + }() + + go func() { + lc := NewLicenseChecker() + findings, err := lc.CheckLicenses(ctx, deps) + results <- checkResult{findings: findings, err: err, name: "license"} + }() + + var allFindings []Finding + var errs []string + + for i := 0; i < 3; i++ { + r := <-results + if r.err != nil { + s.logger.Warnf("%s check completed with errors: %v", r.name, r.err) + errs = append(errs, r.err.Error()) + } + allFindings = append(allFindings, r.findings...) 
+	}
+
+	SortFindings(allFindings)
+
+	scanResult := &ScanResult{
+		Project:   project,
+		ScannedAt: time.Now(),
+		Deps:      deps,
+		Findings:  allFindings,
+	}
+
+	// Persistence is best-effort: a storage failure is logged but does not
+	// fail the scan, so callers still receive the in-memory result.
+	if s.store != nil {
+		if _, err := s.store.SaveScan(scanResult); err != nil {
+			s.logger.Errorf("failed to save scan results: %v", err)
+		}
+	}
+
+	// Partial results are returned alongside the error so callers can use
+	// whatever findings were gathered before the failures.
+	if len(errs) > 0 {
+		return scanResult, fmt.Errorf("scan completed with errors: %s", strings.Join(errs, "; "))
+	}
+
+	return scanResult, nil
+}
+
+// ScanAll scans multiple projects sequentially. Per-project failures are
+// logged and accumulated; every non-nil result is still collected, and a
+// joined error is returned at the end if any project's scan reported errors.
+func (s *Scanner) ScanAll(ctx context.Context, projectPaths []string) ([]*ScanResult, error) {
+	var results []*ScanResult
+	var errs []string
+
+	for _, path := range projectPaths {
+		result, err := s.ScanProject(ctx, path)
+		if err != nil {
+			s.logger.Warnf("scan for %s completed with errors: %v", path, err)
+			errs = append(errs, err.Error())
+		}
+		// ScanProject may return both a partial result and an error; keep
+		// the result either way.
+		if result != nil {
+			results = append(results, result)
+		}
+	}
+
+	if len(errs) > 0 {
+		return results, fmt.Errorf("some scans had errors: %s", strings.Join(errs, "; "))
+	}
+	return results, nil
+}
diff --git a/internal/deps/scanner_test.go b/internal/deps/scanner_test.go
new file mode 100644
index 0000000..50fa3ba
--- /dev/null
+++ b/internal/deps/scanner_test.go
@@ -0,0 +1,150 @@
+package deps
+
+import (
+	"context"
+	"database/sql"
+	"os"
+	"path/filepath"
+	"testing"
+
+	_ "modernc.org/sqlite"
+)
+
+// testLogger routes scanner log output through the test's own logger so
+// messages are attached to the failing test.
+type testLogger struct{ t *testing.T }
+
+func (l *testLogger) Infof(format string, args ...any)  { l.t.Logf("INFO: "+format, args...) }
+func (l *testLogger) Warnf(format string, args ...any)  { l.t.Logf("WARN: "+format, args...) }
+func (l *testLogger) Errorf(format string, args ...any) { l.t.Logf("ERROR: "+format, args...) }
+
+func TestScanProjectWithMinimalGoMod(t *testing.T) {
+	dir := t.TempDir()
+	gomod := `module example.com/testmod
+
+go 1.21
+
+require github.com/rs/zerolog v1.34.0
+`
+	if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte(gomod), 0644); err != nil {
+		t.Fatal(err)
+	}
+
+	db := setupTestDB(t)
+	store := NewStore(db)
+	if err := store.InitTables(); err != nil {
+		t.Fatal(err)
+	}
+
+	scanner := NewScanner(store, "", &testLogger{t})
+
+	// NOTE(review): the error is deliberately discarded — the scan appears
+	// to reach live OSV/GitHub endpoints, so offline runs would otherwise
+	// fail here; consider stubbing the checkers. TODO confirm.
+	result, _ := scanner.ScanProject(context.Background(), dir)
+	if result == nil {
+		t.Fatal("expected non-nil result")
+	}
+	if result.Project != filepath.Base(dir) {
+		t.Errorf("project = %q, want %q", result.Project, filepath.Base(dir))
+	}
+	if len(result.Deps) != 1 {
+		t.Errorf("got %d deps, want 1", len(result.Deps))
+	}
+
+	// The scan above should have been persisted via the store.
+	saved, err := store.LatestScan(result.Project)
+	if err != nil {
+		t.Fatalf("LatestScan error: %v", err)
+	}
+	if saved == nil {
+		t.Fatal("expected saved scan")
+	}
+	if saved.Project != result.Project {
+		t.Errorf("saved project = %q, want %q", saved.Project, result.Project)
+	}
+}
+
+func TestStoreRoundTrip(t *testing.T) {
+	db := setupTestDB(t)
+	store := NewStore(db)
+	if err := store.InitTables(); err != nil {
+		t.Fatal(err)
+	}
+
+	result := &ScanResult{
+		Project: "test-project",
+		Deps: []Dependency{
+			{Module: "github.com/foo/bar", Version: "v1.0.0"},
+		},
+		Findings: []Finding{
+			{
+				Module:    "github.com/foo/bar",
+				Version:   "v1.0.0",
+				RiskLevel: RiskHigh,
+				Category:  CategoryVulnerability,
+				Detail:    "CVE-2024-1234",
+				CVEID:     "CVE-2024-1234",
+				CVSSScore: 7.5,
+			},
+		},
+	}
+
+	id, err := store.SaveScan(result)
+	if err != nil {
+		t.Fatalf("SaveScan error: %v", err)
+	}
+	if id <= 0 {
+		t.Fatalf("expected positive scan id, got %d", id)
+	}
+
+	loaded, err := store.LatestScan("test-project")
+	if err != nil {
+		t.Fatalf("LatestScan error: %v", err)
+	}
+	if loaded == nil {
+		t.Fatal("expected loaded scan")
+	}
+	if len(loaded.Findings) != 1 {
+		t.Fatalf("got %d findings, want 1", len(loaded.Findings))
+	}
+	if loaded.Findings[0].CVEID != "CVE-2024-1234" {
+		t.Errorf("CVEID = %q, want CVE-2024-1234", loaded.Findings[0].CVEID)
+	}
+}
+
+func TestAllLatestScans(t *testing.T) {
+	db := setupTestDB(t)
+	store := NewStore(db)
+	if err := store.InitTables(); err != nil {
+		t.Fatal(err)
+	}
+
+	for _, project := range []string{"proj-a", "proj-b"} {
+		result := &ScanResult{
+			Project:  project,
+			Deps:     []Dependency{{Module: "github.com/x/y", Version: "v1.0.0"}},
+			Findings: []Finding{{Module: "github.com/x/y", Version: "v1.0.0", RiskLevel: RiskLow, Category: CategoryLicense, Detail: "MIT"}},
+		}
+		if _, err := store.SaveScan(result); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	scans, err := store.AllLatestScans()
+	if err != nil {
+		t.Fatalf("AllLatestScans error: %v", err)
+	}
+	if len(scans) != 2 {
+		t.Fatalf("got %d scans, want 2", len(scans))
+	}
+}
+
+// setupTestDB opens a fresh file-backed SQLite database under a temp dir
+// and registers a cleanup to close it when the test ends.
+func setupTestDB(t *testing.T) *sql.DB {
+	t.Helper()
+	dbPath := filepath.Join(t.TempDir(), "test.db")
+	db, err := sql.Open("sqlite", dbPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Cleanup(func() { _ = db.Close() })
+
+	if _, err := db.Exec("PRAGMA journal_mode=WAL; PRAGMA foreign_keys=ON;"); err != nil {
+		t.Fatal(err)
+	}
+	return db
+}
diff --git a/internal/deps/store.go b/internal/deps/store.go
new file mode 100644
index 0000000..63ab9c2
--- /dev/null
+++ b/internal/deps/store.go
@@ -0,0 +1,180 @@
+package deps
+
+import (
+	"database/sql"
+	"fmt"
+	"time"
+)
+
+// Store handles SQLite persistence for dependency scan results.
+type Store struct {
+	db *sql.DB
+}
+
+// NewStore creates a new Store backed by the given database.
+func NewStore(db *sql.DB) *Store {
+	return &Store{db: db}
+}
+
+// InitTables creates the dep_scans and dep_findings tables if they don't exist.
+func (s *Store) InitTables() error {
+	_, err := s.db.Exec(`
+	CREATE TABLE IF NOT EXISTS dep_scans (
+		id INTEGER PRIMARY KEY AUTOINCREMENT,
+		project TEXT NOT NULL,
+		scanned_at DATETIME NOT NULL,
+		total_deps INTEGER,
+		critical_count INTEGER,
+		high_count INTEGER,
+		medium_count INTEGER,
+		low_count INTEGER
+	);
+
+	CREATE TABLE IF NOT EXISTS dep_findings (
+		id INTEGER PRIMARY KEY AUTOINCREMENT,
+		scan_id INTEGER NOT NULL REFERENCES dep_scans(id),
+		module TEXT NOT NULL,
+		version TEXT NOT NULL,
+		risk_level TEXT NOT NULL,
+		category TEXT NOT NULL,
+		detail TEXT NOT NULL,
+		cve_id TEXT,
+		cvss_score REAL
+	);
+
+	CREATE INDEX IF NOT EXISTS idx_dep_scans_project ON dep_scans(project, scanned_at DESC);
+	CREATE INDEX IF NOT EXISTS idx_dep_findings_scan ON dep_findings(scan_id);
+	`)
+	if err != nil {
+		return fmt.Errorf("creating dep tables: %w", err)
+	}
+	return nil
+}
+
+// scanTimeLayout is the canonical text representation of the scanned_at
+// column. Storing the timestamp explicitly (rather than handing the SQL
+// driver a time.Time and letting it pick its own encoding) guarantees the
+// read path below can always parse the value back.
+const scanTimeLayout = time.RFC3339Nano
+
+// SaveScan persists a scan result and its findings in a transaction.
+// On success the new scan's row id is stored into result.ID and returned.
+func (s *Store) SaveScan(result *ScanResult) (int64, error) {
+	tx, err := s.db.Begin()
+	if err != nil {
+		return 0, fmt.Errorf("begin transaction: %w", err)
+	}
+	// Rollback after a successful Commit is a harmless no-op.
+	defer func() { _ = tx.Rollback() }()
+
+	critical, high, medium, low := SummarizeResults(result)
+
+	res, err := tx.Exec(
+		`INSERT INTO dep_scans (project, scanned_at, total_deps, critical_count, high_count, medium_count, low_count)
+		 VALUES (?, ?, ?, ?, ?, ?, ?)`,
+		result.Project, result.ScannedAt.UTC().Format(scanTimeLayout), len(result.Deps), critical, high, medium, low,
+	)
+	if err != nil {
+		return 0, fmt.Errorf("inserting scan: %w", err)
+	}
+
+	scanID, err := res.LastInsertId()
+	if err != nil {
+		return 0, fmt.Errorf("getting scan id: %w", err)
+	}
+
+	for _, f := range result.Findings {
+		_, err := tx.Exec(
+			`INSERT INTO dep_findings (scan_id, module, version, risk_level, category, detail, cve_id, cvss_score)
+			 VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
+			scanID, f.Module, f.Version, f.RiskLevel, f.Category, f.Detail, f.CVEID, f.CVSSScore,
+		)
+		if err != nil {
+			return 0, fmt.Errorf("inserting finding: %w", err)
+		}
+	}
+
+	if err := tx.Commit(); err != nil {
+		return 0, fmt.Errorf("committing scan: %w", err)
+	}
+
+	result.ID = scanID
+	return scanID, nil
+}
+
+// parseScanTime parses a scanned_at column value. New rows are written as
+// RFC 3339 text by SaveScan; the fallback layouts accept rows written by
+// SQL drivers that encoded a time.Time with their own datetime format.
+// A parse failure is reported instead of being silently mapped to the
+// zero time.
+func parseScanTime(value string) (time.Time, error) {
+	layouts := []string{
+		time.RFC3339Nano,
+		time.RFC3339,
+		"2006-01-02 15:04:05.999999999-07:00",
+		"2006-01-02 15:04:05",
+	}
+	var err error
+	for _, layout := range layouts {
+		var t time.Time
+		if t, err = time.Parse(layout, value); err == nil {
+			return t, nil
+		}
+	}
+	return time.Time{}, fmt.Errorf("unrecognized scanned_at value %q: %w", value, err)
+}
+
+// LatestScan retrieves the most recent scan for a project, or (nil, nil)
+// when the project has never been scanned. The id is used as a tie-break so
+// two scans sharing a timestamp resolve deterministically.
+func (s *Store) LatestScan(project string) (*ScanResult, error) {
+	row := s.db.QueryRow(
+		`SELECT id, project, scanned_at FROM dep_scans WHERE project = ? ORDER BY scanned_at DESC, id DESC LIMIT 1`,
+		project,
+	)
+
+	var result ScanResult
+	var scannedAt string
+	if err := row.Scan(&result.ID, &result.Project, &scannedAt); err != nil {
+		if err == sql.ErrNoRows {
+			return nil, nil
+		}
+		return nil, fmt.Errorf("querying scan: %w", err)
+	}
+	ts, err := parseScanTime(scannedAt)
+	if err != nil {
+		return nil, fmt.Errorf("reading scan %d: %w", result.ID, err)
+	}
+	result.ScannedAt = ts
+
+	findings, err := s.loadFindings(result.ID)
+	if err != nil {
+		return nil, err
+	}
+	result.Findings = findings
+
+	return &result, nil
+}
+
+// AllLatestScans retrieves the most recent scan for each project.
+func (s *Store) AllLatestScans() ([]*ScanResult, error) {
+	// Select the single newest row id per project (tie-broken by id) rather
+	// than matching on MAX(scanned_at): the latter returns duplicate rows
+	// for a project when two scans share the same timestamp.
+	rows, err := s.db.Query(
+		`SELECT id, project, scanned_at FROM dep_scans ds
+		 WHERE id = (SELECT id FROM dep_scans WHERE project = ds.project ORDER BY scanned_at DESC, id DESC LIMIT 1)
+		 ORDER BY project`,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("querying scans: %w", err)
+	}
+	defer rows.Close()
+
+	var results []*ScanResult
+	for rows.Next() {
+		var r ScanResult
+		var scannedAt string
+		if err := rows.Scan(&r.ID, &r.Project, &scannedAt); err != nil {
+			return nil, fmt.Errorf("scanning row: %w", err)
+		}
+		ts, err := parseScanTime(scannedAt)
+		if err != nil {
+			return nil, fmt.Errorf("reading scan %d: %w", r.ID, err)
+		}
+		r.ScannedAt = ts
+
+		findings, err := s.loadFindings(r.ID)
+		if err != nil {
+			return nil, err
+		}
+		r.Findings = findings
+		results = append(results, &r)
+	}
+
+	return results, rows.Err()
+}
+
+// loadFindings returns the findings attached to a scan, in insertion order.
+// NULL cve_id / cvss_score columns are mapped to their zero values.
+func (s *Store) loadFindings(scanID int64) ([]Finding, error) {
+	rows, err := s.db.Query(
+		`SELECT module, version, risk_level, category, detail, cve_id, cvss_score
+		 FROM dep_findings WHERE scan_id = ? ORDER BY id`,
+		scanID,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("querying findings: %w", err)
+	}
+	defer rows.Close()
+
+	var findings []Finding
+	for rows.Next() {
+		var f Finding
+		var cveID sql.NullString
+		var cvss sql.NullFloat64
+		if err := rows.Scan(&f.Module, &f.Version, &f.RiskLevel, &f.Category, &f.Detail, &cveID, &cvss); err != nil {
+			return nil, fmt.Errorf("scanning finding: %w", err)
+		}
+		f.CVEID = cveID.String
+		f.CVSSScore = cvss.Float64
+		findings = append(findings, f)
+	}
+
+	return findings, rows.Err()
+}
diff --git a/internal/deps/vulns.go b/internal/deps/vulns.go
new file mode 100644
index 0000000..fc3412a
--- /dev/null
+++ b/internal/deps/vulns.go
@@ -0,0 +1,322 @@
+package deps
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"math"
+	"net/http"
+	"strings"
+	"sync"
+	"time"
+)
+
+// VulnChecker queries the OSV.dev API for known vulnerabilities.
+type VulnChecker struct {
+	client  *http.Client
+	baseURL string
+	sem     chan struct{}
+}
+
+// NewVulnChecker creates a VulnChecker with sensible defaults.
+func NewVulnChecker() *VulnChecker {
+	return &VulnChecker{
+		client:  &http.Client{Timeout: 10 * time.Second},
+		baseURL: "https://api.osv.dev",
+		sem:     make(chan struct{}, 10),
+	}
+}
+
+type osvQuery struct {
+	Package osvPackage `json:"package"`
+	Version string     `json:"version"`
+}
+
+type osvPackage struct {
+	Name      string `json:"name"`
+	Ecosystem string `json:"ecosystem"`
+}
+
+type osvResponse struct {
+	Vulns []osvVuln `json:"vulns"`
+}
+
+type osvVuln struct {
+	ID       string        `json:"id"`
+	Aliases  []string      `json:"aliases"`
+	Summary  string        `json:"summary"`
+	Severity []osvSeverity `json:"severity"`
+}
+
+type osvSeverity struct {
+	Type  string `json:"type"`
+	Score string `json:"score"`
+}
+
+// CheckVulnerabilities queries OSV.dev for each dependency and returns findings.
+// Returns a non-nil error if any API calls fail, along with whatever findings
+// were successfully retrieved.
+func (vc *VulnChecker) CheckVulnerabilities(ctx context.Context, deps []Dependency) ([]Finding, error) {
+	var (
+		mu       sync.Mutex // guards findings and errs across worker goroutines
+		findings []Finding
+		errs     []string
+	)
+
+	var wg sync.WaitGroup
+	for _, dep := range deps {
+		wg.Add(1)
+		go func(d Dependency) {
+			defer wg.Done()
+			// Bound in-flight OSV requests to the semaphore's capacity.
+			// NOTE(review): the acquire does not watch ctx.Done(); after a
+			// cancel, queued goroutines still reach queryModule and only
+			// fail inside the HTTP call — confirm this is acceptable.
+			vc.sem <- struct{}{}
+			defer func() { <-vc.sem }()
+
+			fs, err := vc.queryModule(ctx, d)
+			mu.Lock()
+			defer mu.Unlock()
+			if err != nil {
+				errs = append(errs, fmt.Sprintf("%s@%s: %v", d.Module, d.Version, err))
+				return
+			}
+			findings = append(findings, fs...)
+		}(dep)
+	}
+	wg.Wait()
+
+	// Partial findings are returned alongside the joined error so callers
+	// keep whatever was successfully retrieved.
+	if len(errs) > 0 {
+		return findings, fmt.Errorf("vulnerability check errors: %s", strings.Join(errs, "; "))
+	}
+	return findings, nil
+}
+
+// queryModule POSTs a single module@version query to the OSV.dev /v1/query
+// endpoint and converts each reported vulnerability into a Finding.
+func (vc *VulnChecker) queryModule(ctx context.Context, dep Dependency) ([]Finding, error) {
+	query := osvQuery{
+		Package: osvPackage{Name: dep.Module, Ecosystem: "Go"},
+		Version: dep.Version,
+	}
+
+	body, err := json.Marshal(query)
+	if err != nil {
+		return nil, fmt.Errorf("marshaling query: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, vc.baseURL+"/v1/query", bytes.NewReader(body))
+	if err != nil {
+		return nil, fmt.Errorf("creating request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := vc.client.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("querying OSV.dev: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		// Include the response body in the error to aid debugging; a read
+		// failure here is deliberately ignored (the status is the signal).
+		respBody, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("OSV.dev returned %d: %s", resp.StatusCode, string(respBody))
+	}
+
+	var result osvResponse
+	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+		return nil, fmt.Errorf("decoding response: %w", err)
+	}
+
+	var findings []Finding
+	for _, vuln := range result.Vulns {
+		cveID := extractCVEID(vuln)
+		cvss := extractCVSSScore(vuln) // -1 when no CVSS_V3 severity is present
+		findings = append(findings, Finding{
+			Module:    dep.Module,
+			Version:   dep.Version,
+			RiskLevel: cvssToRiskLevel(cvss),
+			Category:  CategoryVulnerability,
+			Detail:    formatVulnDetail(vuln, cveID, cvss),
+			CVEID:     cveID,
+			CVSSScore: cvss,
+		})
+	}
+
+	return findings, nil
+}
+
+// extractCVEID returns the first CVE alias of a vulnerability, falling back
+// to the advisory's own id (e.g. GHSA-… or GO-…) when no CVE is listed.
+func extractCVEID(vuln osvVuln) string {
+	for _, alias := range vuln.Aliases {
+		if strings.HasPrefix(alias, "CVE-") {
+			return alias
+		}
+	}
+	return vuln.ID
+}
+
+// extractCVSSScore extracts the CVSS score from a vulnerability's severity data.
+// OSV.dev returns CVSS vector strings, not raw numeric scores. We parse the
+// vector and compute the base score per the CVSS 3.1 specification.
+// Returns -1 if no severity data is available.
+func extractCVSSScore(vuln osvVuln) float64 {
+	for _, sev := range vuln.Severity {
+		if sev.Type == "CVSS_V3" {
+			score := parseCVSS3Vector(sev.Score)
+			if score >= 0 {
+				return score
+			}
+		}
+	}
+	return -1
+}
+
+// parseCVSS3Vector computes a CVSS 3.1 base score from a vector string.
+// Example input: "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H"
+// The equations and constants below follow CVSS v3.1 specification §7.1.
+// Returns -1 when the vector is empty or missing any base metric.
+func parseCVSS3Vector(vector string) float64 {
+	if vector == "" {
+		return -1
+	}
+
+	// Split "KEY:VALUE" segments into a metric map; the leading "CVSS:3.1"
+	// segment lands in the map too but is simply never looked up.
+	metrics := make(map[string]string)
+	for _, segment := range strings.Split(vector, "/") {
+		kv := strings.SplitN(segment, ":", 2)
+		if len(kv) == 2 {
+			metrics[kv[0]] = kv[1]
+		}
+	}
+
+	av, acv, prv, uiv := metrics["AV"], metrics["AC"], metrics["PR"], metrics["UI"]
+	cv, iv, avail := metrics["C"], metrics["I"], metrics["A"]
+	sv := metrics["S"]
+
+	// All eight base metrics are mandatory; bail out if any is absent.
+	if av == "" || acv == "" || prv == "" || uiv == "" || cv == "" || iv == "" || avail == "" || sv == "" {
+		return -1
+	}
+
+	scopeChanged := sv == "C"
+
+	// ISS = 1 - (1-C)(1-I)(1-A)   (Impact Sub-Score)
+	iss := 1 - (1-impactWeight(cv))*(1-impactWeight(iv))*(1-impactWeight(avail))
+
+	// Impact: the Scope-changed branch uses the v3.1 polynomial
+	// (exponent 13 on ISS*0.9731-0.02), not the v3.0 form.
+	var impact float64
+	if scopeChanged {
+		impact = 7.52*(iss-0.029) - 3.25*math.Pow(iss*0.9731-0.02, 13)
+	} else {
+		impact = 6.42 * iss
+	}
+
+	// Spec: if Impact <= 0 the base score is 0.
+	if impact <= 0 {
+		return 0.0
+	}
+
+	exploitability := 8.22 * avWeight(av) * acWeight(acv) * prWeight(prv, scopeChanged) * uiWeight(uiv)
+
+	var base float64
+	if scopeChanged {
+		base = math.Min(1.08*(impact+exploitability), 10.0)
+	} else {
+		base = math.Min(impact+exploitability, 10.0)
+	}
+
+	return cvssRoundUp(base)
+}
+
+// avWeight maps Attack Vector values to their CVSS v3.1 numeric weights.
+// NOTE(review): unrecognized values fall back to the most severe weight
+// instead of failing the parse — confirm this fail-open choice is intended.
+func avWeight(v string) float64 {
+	switch v {
+	case "N":
+		return 0.85
+	case "A":
+		return 0.62
+	case "L":
+		return 0.55
+	case "P":
+		return 0.20
+	default:
+		return 0.85
+	}
+}
+
+// acWeight maps Attack Complexity values to their CVSS v3.1 weights.
+func acWeight(v string) float64 {
+	switch v {
+	case "L":
+		return 0.77
+	case "H":
+		return 0.44
+	default:
+		return 0.77
+	}
+}
+
+// prWeight maps Privileges Required to its weight; L and H are worth more
+// when Scope is changed, per the CVSS v3.1 metric table.
+func prWeight(v string, scopeChanged bool) float64 {
+	switch v {
+	case "N":
+		return 0.85
+	case "L":
+		if scopeChanged {
+			return 0.68
+		}
+		return 0.62
+	case "H":
+		if scopeChanged {
+			return 0.50
+		}
+		return 0.27
+	default:
+		return 0.85
+	}
+}
+
+// uiWeight maps User Interaction values to their CVSS v3.1 weights.
+func uiWeight(v string) float64 {
+	switch v {
+	case "N":
+		return 0.85
+	case "R":
+		return 0.62
+	default:
+		return 0.85
+	}
+}
+
+// impactWeight maps C/I/A impact values (High/Low/None) to their weights.
+func impactWeight(v string) float64 {
+	switch v {
+	case "H":
+		return 0.56
+	case "L":
+		return 0.22
+	case "N":
+		return 0.0
+	default:
+		return 0.56
+	}
+}
+
+// cvssRoundUp rounds up to one decimal place per the CVSS spec.
+func cvssRoundUp(x float64) float64 {
+	// CVSS v3.1 Appendix A "Roundup": compare in integer space so binary
+	// floating-point noise cannot change the result. The previous
+	// floor(x*10) + fractional-remainder test mis-rounded values such as
+	// 4.2, whose nearest double is 4.2000000000000002 (x*10 - floor(x*10)
+	// is then > 0, yielding 4.3 instead of 4.2).
+	scaled := int64(math.Round(x * 100000))
+	if scaled%10000 == 0 {
+		return float64(scaled) / 100000.0
+	}
+	return (math.Floor(float64(scaled)/10000.0) + 1) / 10.0
+}
+
+// cvssToRiskLevel buckets a CVSS base score using the CVSS qualitative
+// severity rating scale (9.0+ critical, 7.0+ high, 4.0+ medium, else low).
+// A negative score means "no severity data"; it is treated as medium rather
+// than low so unknowns are not silently deprioritized.
+func cvssToRiskLevel(score float64) RiskLevel {
+	switch {
+	case score < 0:
+		return RiskMedium
+	case score >= 9.0:
+		return RiskCritical
+	case score >= 7.0:
+		return RiskHigh
+	case score >= 4.0:
+		return RiskMedium
+	default:
+		return RiskLow
+	}
+}
+
+// formatVulnDetail builds the human-readable finding detail line, including
+// the CVSS score only when one was actually computed (cvss >= 0).
+func formatVulnDetail(vuln osvVuln, cveID string, cvss float64) string {
+	summary := vuln.Summary
+	if summary == "" {
+		summary = "No description available"
+	}
+	if cvss >= 0 {
+		return fmt.Sprintf("%s (CVSS %.1f): %s", cveID, cvss, summary)
+	}
+	return fmt.Sprintf("%s: %s", cveID, summary)
+}
diff --git a/internal/deps/vulns_test.go b/internal/deps/vulns_test.go
new file mode 100644
index 0000000..0f1b9e1
--- /dev/null
+++ b/internal/deps/vulns_test.go
@@ -0,0 +1,191 @@
+package deps
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+)
+
+func TestCVSSToRiskLevel(t *testing.T) {
+	tests := []struct {
+		score float64
+		want  RiskLevel
+	}{
+		{9.8, RiskCritical},
+		{9.0, RiskCritical},
+		{8.5, RiskHigh},
+		{7.0, RiskHigh},
+		{5.0, RiskMedium},
+		{4.0, RiskMedium},
+		{3.9, RiskLow},
+		{0.0, RiskLow},
+		{-1, RiskMedium},
+	}
+	for _, tt := range tests {
+		got := cvssToRiskLevel(tt.score)
+		if got != tt.want {
+			t.Errorf("cvssToRiskLevel(%v) = %v, want %v", tt.score, got, tt.want)
+		}
+	}
+}
+
+func TestParseCVSS3Vector(t *testing.T) {
+	tests := []struct {
+		name   string
+		vector string
+		want   float64
+	}{
+		{
+			name:   "critical severity",
+			vector: "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H",
+			want:   9.8,
+		},
+		{
+			name:   "high severity",
+			vector: "CVSS:3.1/AV:N/AC:L/PR:N/UI:R/S:U/C:H/I:H/A:H",
+			want:   8.8,
+		},
+		{
+			name:   "medium severity",
+			vector: "CVSS:3.1/AV:N/AC:H/PR:N/UI:R/S:U/C:L/I:L/A:N",
+			want:   4.2,
+		},
+		{
+			name:   "low severity",
+			vector: "CVSS:3.1/AV:L/AC:H/PR:H/UI:R/S:U/C:L/I:N/A:N",
+			want:   1.8,
+		},
+		{
+			name:   "empty",
+			vector: "",
+			want:   -1,
+		},
+		{
+			name:   "incomplete vector",
+			vector: "CVSS:3.1/AV:N",
+			want:   -1,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := parseCVSS3Vector(tt.vector)
+			if got != tt.want {
+				t.Errorf("parseCVSS3Vector(%q) = %v, want %v", tt.vector, got, tt.want)
+			}
+		})
+	}
+}
+
+func TestCheckVulnerabilities(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		var query osvQuery
+		if err := json.NewDecoder(r.Body).Decode(&query); err != nil {
+			http.Error(w, "bad request", http.StatusBadRequest)
+			return
+		}
+
+		if query.Package.Name == "github.com/vuln/pkg" {
+			resp := osvResponse{
+				Vulns: []osvVuln{
+					{
+						ID:      "GHSA-1234",
+						Aliases: []string{"CVE-2024-1234"},
+						Summary: "Remote code execution",
+						Severity: []osvSeverity{
+							{Type: "CVSS_V3", Score: "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H"},
+						},
+					},
+				},
+			}
+			json.NewEncoder(w).Encode(resp)
+			return
+		}
+
+		json.NewEncoder(w).Encode(osvResponse{})
+	}))
+	defer srv.Close()
+
+	vc := NewVulnChecker()
+	vc.baseURL = srv.URL
+
+	deps := []Dependency{
+		{Module: "github.com/vuln/pkg", Version: "v1.0.0"},
+		{Module: "github.com/safe/pkg", Version: "v2.0.0"},
+	}
+
+	findings, err := vc.CheckVulnerabilities(context.Background(), deps)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if len(findings) != 1 {
+		t.Fatalf("got %d findings, want 1", len(findings))
+	}
+
+	f := findings[0]
+	if f.CVEID != "CVE-2024-1234" {
+		t.Errorf("CVEID = %q, want CVE-2024-1234", f.CVEID)
+	}
+	if f.CVSSScore != 9.8 {
+		t.Errorf("CVSSScore = %v, want 9.8", f.CVSSScore)
+	}
+	if f.RiskLevel != RiskCritical {
+		t.Errorf("RiskLevel = %v, want critical", f.RiskLevel)
+	}
+}
+
+func TestCheckVulnerabilitiesAPIError(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		http.Error(w, "service unavailable", http.StatusServiceUnavailable)
+	}))
+	defer srv.Close()
+
+	vc := NewVulnChecker()
+	vc.baseURL = srv.URL
+
+	deps := []Dependency{
+		{Module: "github.com/some/pkg", Version: "v1.0.0"},
+	}
+
+	findings, err := vc.CheckVulnerabilities(context.Background(), deps)
+	if err == nil {
+		t.Fatal("expected error for API failure")
+	}
+	if len(findings) != 0 {
+		t.Errorf("expected no findings on error, got %d", len(findings))
+	}
+}
+
+func TestExtractCVEID(t *testing.T) {
+	tests := []struct {
+		name    string
+		vuln    osvVuln
+		wantCVE string
+	}{
+		{
+			name:    "has CVE alias",
+			vuln:    osvVuln{ID: "GHSA-1234", Aliases: []string{"CVE-2024-5678"}},
+			wantCVE: "CVE-2024-5678",
+		},
+		{
+			name:    "no CVE alias",
+			vuln:    osvVuln{ID: "GHSA-1234", Aliases: []string{"GHSA-abcd"}},
+			wantCVE: "GHSA-1234",
+		},
+		{
+			name:    "no aliases",
+			vuln:    osvVuln{ID: "GO-2024-001"},
+			wantCVE: "GO-2024-001",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := extractCVEID(tt.vuln)
+			if got != tt.wantCVE {
+				t.Errorf("extractCVEID() = %q, want %q", got, tt.wantCVE)
+			}
+		})
+	}
+}