Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions grypedb/grypedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ package grypedb
import (
"compress/gzip"
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -119,10 +121,14 @@ func New(dbPath string, opts ...Option) (*Source, error) {
// Download downloads the latest Grype database to the specified directory.
// Returns the path to the downloaded database file.
func Download(ctx context.Context, destDir string) (string, error) {
return downloadFrom(ctx, LatestDBURL, destDir)
}

func downloadFrom(ctx context.Context, listingURL, destDir string) (string, error) {
client := &http.Client{Timeout: DefaultTimeout}

// Fetch listing to get latest database URL
req, err := http.NewRequestWithContext(ctx, "GET", LatestDBURL, nil)
req, err := http.NewRequestWithContext(ctx, "GET", listingURL, nil)
if err != nil {
return "", fmt.Errorf("creating listing request: %w", err)
}
Expand Down Expand Up @@ -178,21 +184,34 @@ func Download(ctx context.Context, destDir string) (string, error) {
}
defer func() { _ = outFile.Close() }()

// Decompress if gzipped
var reader io.Reader = resp.Body
// Hash the compressed download to verify against the listing checksum
hasher := sha256.New()
body := io.TeeReader(resp.Body, hasher)

// Decompress if gzipped, with a 2 GB cap on decompressed output
const maxDecompressedSize = 2 << 30
reader := body
if strings.HasSuffix(latest.URL, ".gz") {
gzReader, err := gzip.NewReader(resp.Body)
gzReader, err := gzip.NewReader(body)
if err != nil {
return "", fmt.Errorf("creating gzip reader: %w", err)
}
defer func() { _ = gzReader.Close() }()
reader = gzReader
reader = io.LimitReader(gzReader, maxDecompressedSize)
}

if _, err := io.Copy(outFile, reader); err != nil {
return "", fmt.Errorf("writing database: %w", err)
}

if latest.Checksum != "" {
got := "sha256:" + hex.EncodeToString(hasher.Sum(nil))
if got != latest.Checksum {
_ = os.Remove(dbPath)
return "", fmt.Errorf("checksum mismatch: got %s, want %s", got, latest.Checksum)
}
}

return dbPath, nil
}

Expand Down
100 changes: 100 additions & 0 deletions grypedb/grypedb_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package grypedb

import (
"bytes"
"compress/gzip"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
)

func TestDownloadVerifiesChecksum(t *testing.T) {
dbContent := []byte("fake database content for checksum test")

var gzBuf bytes.Buffer
gw := gzip.NewWriter(&gzBuf)
if _, err := gw.Write(dbContent); err != nil {
t.Fatal(err)
}
if err := gw.Close(); err != nil {
t.Fatal(err)
}
gzData := gzBuf.Bytes()

h := sha256.Sum256(gzData)
goodChecksum := "sha256:" + hex.EncodeToString(h[:])

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/listing.json" {
listing := dbListing{
Available: []dbEntry{{
Built: time.Now(),
Version: 5,
URL: "http://" + r.Host + "/db.tar.gz",
Checksum: goodChecksum,
}},
}
_ = json.NewEncoder(w).Encode(listing)
return
}
_, _ = w.Write(gzData)
}))
defer ts.Close()

destDir := t.TempDir()
path, err := downloadFrom(context.Background(), ts.URL+"/listing.json", destDir)
if err != nil {
t.Fatalf("download with good checksum failed: %v", err)
}

content, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(content, dbContent) {
t.Error("downloaded content does not match expected")
}
}

func TestDownloadRejectsChecksumMismatch(t *testing.T) {
dbContent := []byte("database content")

var gzBuf bytes.Buffer
gw := gzip.NewWriter(&gzBuf)
_, _ = gw.Write(dbContent)
_ = gw.Close()
gzData := gzBuf.Bytes()

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/listing.json" {
listing := dbListing{
Available: []dbEntry{{
Built: time.Now(),
Version: 5,
URL: "http://" + r.Host + "/db.tar.gz",
Checksum: "sha256:0000000000000000000000000000000000000000000000000000000000000000",
}},
}
_ = json.NewEncoder(w).Encode(listing)
return
}
_, _ = w.Write(gzData)
}))
defer ts.Close()

destDir := t.TempDir()
_, err := downloadFrom(context.Background(), ts.URL+"/listing.json", destDir)
if err == nil {
t.Fatal("expected checksum mismatch error, got nil")
}
if !strings.Contains(err.Error(), "checksum mismatch") {
t.Fatalf("expected checksum mismatch error, got: %v", err)
}
}