diff --git a/internal/processing/checksum.go b/internal/processing/checksum.go new file mode 100644 index 00000000..5f05d0aa --- /dev/null +++ b/internal/processing/checksum.go @@ -0,0 +1,125 @@ +package processing + +import ( + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "hash" + "io" + "os" + "strings" +) + +// ChecksumResult holds the outcome of a checksum verification. +type ChecksumResult struct { + Algorithm string + Expected string + Actual string + Match bool +} + +// VerifyChecksum computes the hash of a file and compares it to the expected value. +// algorithm should be one of: md5, sha1, sha256. +// expected should be a hex-encoded hash string. +func VerifyChecksum(filePath string, algorithm string, expected string) (*ChecksumResult, error) { + if filePath == "" || algorithm == "" || expected == "" { + return nil, fmt.Errorf("filepath, algorithm, and expected hash are all required") + } + + algorithm = strings.ToLower(algorithm) + expected = strings.ToLower(strings.TrimSpace(expected)) + + var h hash.Hash + switch algorithm { + case "md5": + h = md5.New() + case "sha1", "sha-1": + algorithm = "sha1" + h = sha1.New() + case "sha256", "sha-256": + algorithm = "sha256" + h = sha256.New() + default: + return nil, fmt.Errorf("unsupported checksum algorithm: %s", algorithm) + } + + f, err := os.Open(filePath) + if err != nil { + return nil, fmt.Errorf("failed to open file for checksum: %w", err) + } + defer func() { _ = f.Close() }() + + if _, err := io.Copy(h, f); err != nil { + return nil, fmt.Errorf("failed to read file for checksum: %w", err) + } + + actual := hex.EncodeToString(h.Sum(nil)) + return &ChecksumResult{ + Algorithm: algorithm, + Expected: expected, + Actual: actual, + Match: actual == expected, + }, nil +} + +// ParseDigestHeader parses an HTTP Digest header (RFC 3230) and returns +// the algorithm and hex-encoded hash. +// Example header: "sha-256=base64hash" or "SHA-256=base64hash" +func ParseDigestHeader(header string) (algorithm string, hexHash string, err error) { + parts := strings.SplitN(header, "=", 2) + if len(parts) != 2 { + return "", "", nil + } + + algo := strings.ToLower(strings.TrimSpace(parts[0])) + value := strings.TrimSpace(parts[1]) + + switch algo { + case "sha-256": + algo = "sha256" + case "sha-1": + algo = "sha1" + case "md5": + // no normalization needed + default: + return "", "", nil + } + + expectedBytes := 0 + switch algo { + case "md5": + expectedBytes = md5.Size + case "sha1": + expectedBytes = sha1.Size + case "sha256": + expectedBytes = sha256.Size + } + expectedHexLen := expectedBytes * 2 + if len(value) == expectedHexLen { + if decoded, err := hex.DecodeString(value); err == nil { + if len(decoded) != expectedBytes { + return "", "", fmt.Errorf("digest length mismatch for %s", algo) + } + return algo, strings.ToLower(value), nil + } + } + + for _, enc := range []*base64.Encoding{ + base64.StdEncoding, + base64.URLEncoding, + base64.RawStdEncoding, + base64.RawURLEncoding, + } { + if decoded, err := enc.DecodeString(value); err == nil { + if len(decoded) != expectedBytes { + return "", "", fmt.Errorf("digest length mismatch for %s", algo) + } + return algo, hex.EncodeToString(decoded), nil + } + } + + return "", "", nil +} diff --git a/internal/processing/checksum_test.go b/internal/processing/checksum_test.go new file mode 100644 index 00000000..2e90d188 --- /dev/null +++ b/internal/processing/checksum_test.go @@ -0,0 +1,127 @@ +package processing + +import ( + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "encoding/hex" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestVerifyChecksum_SHA256(t *testing.T) { + // Create a temp file with known content + dir := t.TempDir() + path := filepath.Join(dir, "test.bin") + content := []byte("hello surge") + require.NoError(t, os.WriteFile(path, content, 0o644)) + + // Compute expected hash + h := sha256.Sum256(content) + expected := hex.EncodeToString(h[:]) + + result, err := VerifyChecksum(path, "sha256", expected) + require.NoError(t, err) + assert.True(t, result.Match) + assert.Equal(t, expected, result.Actual) +} + +func TestVerifyChecksum_MD5(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.bin") + content := []byte("hello surge") + require.NoError(t, os.WriteFile(path, content, 0o644)) + + h := md5.Sum(content) + expected := hex.EncodeToString(h[:]) + + result, err := VerifyChecksum(path, "md5", expected) + require.NoError(t, err) + assert.True(t, result.Match) + assert.Equal(t, "md5", result.Algorithm) + assert.Equal(t, expected, result.Actual) +} + +func TestVerifyChecksum_SHA1(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.bin") + content := []byte("hello surge") + require.NoError(t, os.WriteFile(path, content, 0o644)) + + h := sha1.Sum(content) + expected := hex.EncodeToString(h[:]) + + result, err := VerifyChecksum(path, "sha-1", expected) + require.NoError(t, err) + assert.True(t, result.Match) + assert.Equal(t, "sha1", result.Algorithm) + assert.Equal(t, expected, result.Actual) +} + +func TestVerifyChecksum_Mismatch(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.bin") + require.NoError(t, os.WriteFile(path, []byte("hello"), 0o644)) + + result, err := VerifyChecksum(path, "sha256", "0000000000000000000000000000000000000000000000000000000000000000") + require.NoError(t, err) + assert.False(t, result.Match) +} + +func TestVerifyChecksum_UnsupportedAlgorithm(t *testing.T) { + _, err := VerifyChecksum("/tmp/test", "sha512", "abc") + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported") +} + +func TestVerifyChecksum_EmptyArgs(t *testing.T) { + _, err := VerifyChecksum("", "sha256", "abc") + assert.Error(t, err) +} + +func mustParseDigestHeader(t *testing.T, header string) (string, string) { + t.Helper() + algo, hash, err := ParseDigestHeader(header) + require.NoError(t, err) + return algo, hash +} + +func TestParseDigestHeader_SHA256Base64(t *testing.T) { + // sha256 of empty string in base64 + algo, hash := mustParseDigestHeader(t, "sha-256=47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=") + assert.Equal(t, "sha256", algo) + assert.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hash) +} + +func TestParseDigestHeader_MD5Hex(t *testing.T) { + algo, hash := mustParseDigestHeader(t, "md5=d41d8cd98f00b204e9800998ecf8427e") + assert.Equal(t, "md5", algo) + assert.Equal(t, "d41d8cd98f00b204e9800998ecf8427e", hash) +} + +func TestParseDigestHeader_Invalid(t *testing.T) { + algo, hash := mustParseDigestHeader(t, "invalid") + assert.Empty(t, algo) + assert.Empty(t, hash) +} + +func TestParseDigestHeader_UnsupportedAlgo(t *testing.T) { + algo, hash := mustParseDigestHeader(t, "sha-512=abc") + assert.Empty(t, algo) + assert.Empty(t, hash) +} + +func TestParseDigestHeader_SHA256UnpaddedBase64(t *testing.T) { + algo, hash := mustParseDigestHeader(t, "sha-256=47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU") + assert.Equal(t, "sha256", algo) + assert.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hash) +} + +func TestParseDigestHeader_SHA256WrongLengthHex(t *testing.T) { + _, _, err := ParseDigestHeader("sha-256=d41d8cd98f00b204e9800998ecf8427e") + require.Error(t, err) +}