From 39a83548a02d3d5862a92fa2c118090131384ea8 Mon Sep 17 00:00:00 2001 From: Tieg Zaharia Date: Tue, 11 Mar 2025 21:56:31 -0400 Subject: [PATCH] Adds archive type detection and zip support to ExtractArchiveFile(). Signed-off-by: Tieg Zaharia --- internal/utils/archive_extract.go | 99 ++++++++++++-- internal/utils/archive_extract_test.go | 175 ++++++++++++++++++++++--- 2 files changed, 250 insertions(+), 24 deletions(-) diff --git a/internal/utils/archive_extract.go b/internal/utils/archive_extract.go index 18d9dfc2..27e95889 100644 --- a/internal/utils/archive_extract.go +++ b/internal/utils/archive_extract.go @@ -2,27 +2,99 @@ package utils import ( "archive/tar" + "archive/zip" "compress/gzip" "fmt" "io" + "log" "log/slog" + "net/http" "os" "path/filepath" "strings" ) -// ExtractArchiveFile extracts a .tar.gz / .tgz file located at archivePath, -// using outputDir as the root of the extracted files. +// ExtractArchiveFile extracts a .tar.gz / .tgz file or .zip file located +// at archivePath, using outputDir as the root of the extracted files. func ExtractArchiveFile(archivePath string, outputDir string) error { + if outputDir == "" { + return fmt.Errorf("outputDir is empty") + } + f, err := os.Open(archivePath) if err != nil { return err } defer f.Close() - return processGzipFile(f, func(reader io.Reader) error { - return extractTar(reader, outputDir) - }) + fileType, err := detectFileType(f) + if err != nil { + return err + } + + if fileType == "application/zip" { + return processZipFile(archivePath, outputDir) + } else if fileType == "application/x-gzip" { + return processGzipFile(f, func(reader io.Reader) error { + return extractTar(reader, outputDir) + }) + } else { + return fmt.Errorf("%s is not a supported archive file type: %s", archivePath, fileType) + } +} + +func processZipFile(filePath, outputDir string) error { + if err := os.MkdirAll(outputDir, os.ModePerm); err != nil { + return fmt.Errorf("create dir for %s failed: %w", outputDir, err) + } + + zipReader, err := zip.OpenReader(filePath) + if err != nil { + log.Fatal(err) + } + defer zipReader.Close() + + for _, file := range zipReader.File { + outputPath := filepath.Join(outputDir, file.Name) + + // check for ZipSlip (https://snyk.io/research/zip-slip-vulnerability) by ensuring + // outputPath (cleaned) actually is inside output directory that was specified + if !strings.HasPrefix(outputPath, filepath.Join(outputDir)+string(os.PathSeparator)) { + // Note: this error string is used in a test + return fmt.Errorf("archive path escapes output dir: %s", file.Name) + } + + if file.FileInfo().IsDir() { + if err := os.MkdirAll(outputPath, 0o755); err != nil { + return err + } + continue + } + + // Ensure parent directories exist before creating the file + if err := os.MkdirAll(filepath.Dir(outputPath), 0o755); err != nil { + return err + } + + srcFile, err := file.Open() + if err != nil { + return err + } + defer srcFile.Close() + + destFile, err := os.Create(outputPath) + if err != nil { + return err + } + defer destFile.Close() + + _, err = io.Copy(destFile, srcFile) + if err != nil { + return err + } + } + + return nil } func processGzipFile(gzFile *os.File, process func(io.Reader) error) error { @@ -48,10 +120,6 @@ extractTar extracts the contents of the given stream of bytes of a tar archive, outputDir as the root of the extracted files. */ func extractTar(tarStream io.Reader, outputDir string) error { - if outputDir == "" { - return fmt.Errorf("outputDir is empty") - } - tarReader := tar.NewReader(tarStream) var header *tar.Header @@ -107,3 +175,16 @@ func extractTar(tarStream io.Reader, outputDir string) error { return nil } + +func detectFileType(archiveFile *os.File) (string, error) { + // DetectContentType never uses more than the first 512 bytes. + buffer := make([]byte, 512) + _, err := archiveFile.Read(buffer) + if err != nil { + return "", err + } + + mimeType := http.DetectContentType(buffer) + + return mimeType, nil +} diff --git a/internal/utils/archive_extract_test.go b/internal/utils/archive_extract_test.go index dba1c2f5..7bdbbf61 100644 --- a/internal/utils/archive_extract_test.go +++ b/internal/utils/archive_extract_test.go @@ -2,6 +2,7 @@ package utils import ( "archive/tar" + "archive/zip" "compress/gzip" "fmt" "os" @@ -87,10 +88,48 @@ func createTgzFile(path string, headers []*tar.Header) (err error) { return tarWriter.Close() } -func makePaths(t *testing.T, testName string) (workDir, archivePath, extractPath string, err error) { +func createZipFile(path string, files map[string]string) (err error) { + zipFile, err := os.Create(path) + if err != nil { + return fmt.Errorf("failed to create temp archive file: %w", err) + } + defer func() { + closeErr := zipFile.Close() + if closeErr != nil && err == nil { + err = fmt.Errorf("failed to close temp archive file: %w", closeErr) + } + }() + + zipWriter := zip.NewWriter(zipFile) + + for filePath, fileBody := range files { + writer, err := zipWriter.Create(filePath) + if err != nil { + return err + } + + // _, err = io.Copy(writer, strings.NewReader(fileBody)) + _, err = writer.Write([]byte(fileBody)) + if err != nil { + return err + } + } + + return zipWriter.Close() +} + +// extension: e.g. ".tar.gz", ".zip" +func makePaths(t *testing.T, testName, extension string) (workDir, archivePath, extractPath string, err error) { t.Helper() workDir = t.TempDir() - archivePath = filepath.Join(workDir, testName+".tar.gz") + // On MacOS, the temp dir is a symlink from /var to /private/var + workDir, err = filepath.EvalSymlinks(workDir) + if err != nil { + fmt.Println("Error reading symlink:", err) + return + } + + archivePath = filepath.Join(workDir, testName+extension) extractPath = filepath.Join(workDir, "extracted") if err = os.Mkdir(extractPath, 0o700); err != nil { @@ -100,7 +139,7 @@ func makePaths(t *testing.T, testName string) (workDir, archivePath, extractPath return } -func doExtractionTest(archivePath, extractPath string, archiveHeaders []*tar.Header, runChecks func() error) (err error) { +func doTgzExtractionTest(archivePath, extractPath string, archiveHeaders []*tar.Header, runChecks func() error) (err error) { if err = createTgzFile(archivePath, archiveHeaders); err != nil { return fmt.Errorf("failed to create test tgz file: %w", err) } @@ -114,10 +153,24 @@ func doExtractionTest(archivePath, extractPath string, archiveHeaders []*tar.Hea return runChecks() } +func doZipExtractionTest(archivePath, extractPath string, testFiles map[string]string, runChecks func() error) (err error) { + if err = createZipFile(archivePath, testFiles); err != nil { + return fmt.Errorf("failed to create test tgz file: %w", err) + } + + log.Initialize("") + + if err = ExtractArchiveFile(archivePath, extractPath); err != nil { + return fmt.Errorf("extract failed: %w", err) + } + + return runChecks() +} + func TestExtractSimpleTarGzFile(t *testing.T) { testName := "simple" - _, archivePath, extractPath, err := makePaths(t, testName) + _, archivePath, extractPath, err := makePaths(t, testName, ".tar.gz") if err != nil { t.Errorf("%v", err) return @@ -128,7 +181,53 @@ func TestExtractSimpleTarGzFile(t *testing.T) { makeFileHeader("test/1.txt", 10), } - err = doExtractionTest(archivePath, extractPath, testHeaders, func() error { + err = doTgzExtractionTest(archivePath, extractPath, testHeaders, func() error { + dirInfo, err := os.Stat(filepath.Join(extractPath, "test")) + if err != nil { + return fmt.Errorf("stat extracted dir: %w", err) + } + if dirInfo.Name() != "test" { + return fmt.Errorf("expected extracted directory name 'test', got %s", dirInfo.Name()) + } + if !dirInfo.IsDir() { + return fmt.Errorf("expected to extract directory but it was not a directory") + } + + fileInfo, err := os.Stat(filepath.Join(extractPath, "test", "1.txt")) + if err != nil { + return fmt.Errorf("stat extracted file: %w", err) + } + if fileInfo.Name() != "1.txt" { + return fmt.Errorf("expected to extract file with name '1.txt' but it has name %s", fileInfo.Name()) + } + if fileInfo.Size() != 10 { + return fmt.Errorf("expected to extract file with size 10 but it has size %d", fileInfo.Size()) + } + if fileInfo.IsDir() { + return fmt.Errorf("expected to extract file but it was a directory") + } + return nil + }) + + if err != nil { + t.Errorf("Error: %v", err) + } +} + +func TestExtractSimpleZipFile(t *testing.T) { + testName := "simple" + + _, archivePath, extractPath, err := makePaths(t, testName, ".zip") + if err != nil { + t.Errorf("%v", err) + return + } + + testFiles := map[string]string{ + "test/1.txt": "Some text.", + } + + err = doZipExtractionTest(archivePath, extractPath, testFiles, func() error { dirInfo, err := os.Stat(filepath.Join(extractPath, "test")) if err != nil { return fmt.Errorf("stat extracted dir: %w", err) @@ -164,7 +263,7 @@ func TestExtractSimpleTarGzFile(t *testing.T) { func TestExtractMissingParentDir(t *testing.T) { testName := "simple" - _, archivePath, extractPath, err := makePaths(t, testName) + _, archivePath, extractPath, err := makePaths(t, testName, ".tar.gz") if err != nil { t.Errorf("%v", err) return @@ -174,7 +273,7 @@ func TestExtractMissingParentDir(t *testing.T) { makeFileHeader("test/1.txt", 10), } - err = doExtractionTest(archivePath, extractPath, testHeaders, func() error { + err = doTgzExtractionTest(archivePath, extractPath, testHeaders, func() error { dirInfo, err := os.Stat(filepath.Join(extractPath, "test")) if err != nil { return fmt.Errorf("stat extracted dir: %w", err) @@ -210,7 +309,7 @@ func TestExtractMissingParentDir(t *testing.T) { func TestExtractAbsolutePathTarGzFile(t *testing.T) { testName := "abs-path" - _, archivePath, extractPath, err := makePaths(t, testName) + _, archivePath, extractPath, err := makePaths(t, testName, ".tar.gz") if err != nil { t.Errorf("%v", err) return @@ -221,7 +320,53 @@ func TestExtractAbsolutePathTarGzFile(t *testing.T) { makeFileHeader("/2.txt", 0), } - err = doExtractionTest(archivePath, extractPath, testHeaders, func() error { + err = doTgzExtractionTest(archivePath, extractPath, testHeaders, func() error { + dirInfo, err := os.Stat(filepath.Join(extractPath, "test")) + if err != nil { + return fmt.Errorf("stat extracted dir: %w", err) + } + if dirInfo.Name() != "test" { + return fmt.Errorf("expected extracted directory name 'test', got %s", dirInfo.Name()) + } + if !dirInfo.IsDir() { + return fmt.Errorf("expected to extract directory but it was not a directory") + } + + fileInfo, err := os.Stat(filepath.Join(extractPath, "2.txt")) + if err != nil { + return fmt.Errorf("stat extracted file: %w", err) + } + if fileInfo.Name() != "2.txt" { + return fmt.Errorf("expected to extract file with name '1.txt' but it has name %s", fileInfo.Name()) + } + if fileInfo.Size() != 0 { + return fmt.Errorf("expected to extract file with size 0 but it has size %d", fileInfo.Size()) + } + if fileInfo.IsDir() { + return fmt.Errorf("expected to extract file but it was a directory") + } + return nil + }) + + if err != nil { + t.Errorf("Error: %v", err) + } +} + +func TestExtractAbsolutePathZipFile(t *testing.T) { + testName := "abs-path" + + _, archivePath, extractPath, err := makePaths(t, testName, ".zip") + if err != nil { + t.Errorf("%v", err) + return + } + + testFiles := map[string]string{ + "/test/2.txt": "", + } + + err = doZipExtractionTest(archivePath, extractPath, testFiles, func() error { dirInfo, err := os.Stat(filepath.Join(extractPath, "test")) if err != nil { return fmt.Errorf("stat extracted dir: %w", err) @@ -257,7 +402,7 @@ func TestExtractAbsolutePathTarGzFile(t *testing.T) { func TestExtractZipSlip(t *testing.T) { testName := "zipslip" - _, archivePath, extractPath, err := makePaths(t, testName) + _, archivePath, extractPath, err := makePaths(t, testName, ".tar.gz") if err != nil { t.Errorf("%v", err) return @@ -268,7 +413,7 @@ func TestExtractZipSlip(t *testing.T) { makeFileHeader("test/../../bad.txt", 1), } - err = doExtractionTest(archivePath, extractPath, testHeaders, func() error { + err = doTgzExtractionTest(archivePath, extractPath, testHeaders, func() error { t.Fatal("Extraction should have returned an error") return nil }) @@ -281,7 +426,7 @@ func TestExtractZipSlip(t *testing.T) { func TestExtractZipSlip2(t *testing.T) { testName := "zipslip2" - _, archivePath, extractPath, err := makePaths(t, testName) + _, archivePath, extractPath, err := makePaths(t, testName, ".tar.gz") if err != nil { t.Errorf("%v", err) return @@ -295,7 +440,7 @@ func TestExtractZipSlip2(t *testing.T) { makeFileHeader(filepath.Join("..", filepath.Base(similarlyNamedDir), "bad2.txt"), 1), } - err = doExtractionTest(archivePath, extractPath, testHeaders, func() error { + err = doTgzExtractionTest(archivePath, extractPath, testHeaders, func() error { bad2Info, err := os.Stat(filepath.Join(similarlyNamedDir, "bad2.txt")) if err == nil && bad2Info.Size() == 1 { t.Errorf("Found file in similarly named directory") @@ -312,7 +457,7 @@ func TestExtractZipSlip2(t *testing.T) { func TestExtractZipSlip3(t *testing.T) { testName := "zipslip3" - workDir, archivePath, extractPath, err := makePaths(t, testName) + workDir, archivePath, extractPath, err := makePaths(t, testName, ".tar.gz") if err != nil { t.Errorf("%v", err) return @@ -322,7 +467,7 @@ func TestExtractZipSlip3(t *testing.T) { makeFileHeader("../bad3.txt", 1), } - err = doExtractionTest(archivePath, extractPath, testHeaders, func() error { + err = doTgzExtractionTest(archivePath, extractPath, testHeaders, func() error { bad3Info, err := os.Stat(filepath.Join(workDir, "bad3.txt")) if err == nil && bad3Info.Size() == 1 { t.Errorf("Found file in parent directory")