From fd059d1f0371344007504573de27c23a9032ebb8 Mon Sep 17 00:00:00 2001 From: j-rafique Date: Fri, 18 Jul 2025 16:38:24 +0500 Subject: [PATCH] calculate download file hash incrementally --- pkg/crypto/hash.go | 39 ++++++++ pkg/crypto/hash_test.go | 97 +++++++++++++++++++ .../server/cascade/cascade_action_server.go | 45 ++++++++- supernode/services/cascade/download.go | 50 +++++----- 4 files changed, 200 insertions(+), 31 deletions(-) create mode 100644 pkg/crypto/hash_test.go diff --git a/pkg/crypto/hash.go b/pkg/crypto/hash.go index 5871506e..f45fa4ad 100644 --- a/pkg/crypto/hash.go +++ b/pkg/crypto/hash.go @@ -1 +1,40 @@ package crypto + +import ( + "fmt" + "io" + "lukechampine.com/blake3" + "os" +) + +const defaultHashBufferSize = 1024 * 1024 // 1 MB + +func HashFileIncrementally(filePath string, bufferSize int) ([]byte, error) { + f, err := os.Open(filePath) + if err != nil { + return nil, fmt.Errorf("open decoded file: %w", err) + } + defer f.Close() + + if bufferSize == 0 { + bufferSize = defaultHashBufferSize + } + + hasher := blake3.New(32, nil) + buf := make([]byte, bufferSize) // bufferSize-byte buffer (1 MB default) to balance memory vs I/O + + for { + n, readErr := f.Read(buf) + if n > 0 { + hasher.Write(buf[:n]) + } + if readErr == io.EOF { + break + } + if readErr != nil { + return nil, fmt.Errorf("streaming file read failed: %w", readErr) + } + } + + return hasher.Sum(nil), nil +} diff --git a/pkg/crypto/hash_test.go b/pkg/crypto/hash_test.go new file mode 100644 index 00000000..7814a772 --- /dev/null +++ b/pkg/crypto/hash_test.go @@ -0,0 +1,97 @@ +package crypto + +import ( + "encoding/hex" + "os" + "path/filepath" + "testing" + + "lukechampine.com/blake3" +) + +func TestHashFileIncrementally(t *testing.T) { + expectedBlake3 := func(data []byte) string { + h := blake3.New(32, nil) + h.Write(data) + return hex.EncodeToString(h.Sum(nil)) + } + + testData := []byte("hello world") + emptyData := []byte("") + largeData := make([]byte, 5*1024*1024) + + // Temp dir for test files + 
tmpDir := t.TempDir() + + // Create helper function for file creation + createTempFile := func(name string, content []byte) string { + filePath := filepath.Join(tmpDir, name) + if err := os.WriteFile(filePath, content, 0644); err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + return filePath + } + + // Create test files + smallFile := createTempFile("small.txt", testData) + emptyFile := createTempFile("empty.txt", emptyData) + largeFile := createTempFile("large.bin", largeData) + + tests := []struct { + name string + filePath string + bufferSize int + wantHash string + wantErr bool + }{ + { + name: "small file", + filePath: smallFile, + bufferSize: 4 * 1024, // 4KB buffer + wantHash: expectedBlake3(testData), + wantErr: false, + }, + { + name: "empty file", + filePath: emptyFile, + bufferSize: 1024, // small buffer + wantHash: expectedBlake3(emptyData), + wantErr: false, + }, + { + name: "large file", + filePath: largeFile, + bufferSize: 1024 * 1024, // 1MB buffer + wantHash: expectedBlake3(largeData), + wantErr: false, + }, + { + name: "file does not exist", + filePath: filepath.Join(tmpDir, "doesnotexist.txt"), + bufferSize: 4096, + wantHash: "", + wantErr: true, + }, + { + name: "zero buffer size (should use default)", + filePath: smallFile, + bufferSize: 0, + wantHash: expectedBlake3(testData), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotHash, err := HashFileIncrementally(tt.filePath, tt.bufferSize) + + if (err != nil) != tt.wantErr { + t.Fatalf("expected error=%v, got err=%v", tt.wantErr, err) + } + + if !tt.wantErr && hex.EncodeToString(gotHash) != tt.wantHash { + t.Errorf("hash mismatch!\n got: %s\n want: %s", gotHash, tt.wantHash) + } + }) + } +} diff --git a/supernode/node/action/server/cascade/cascade_action_server.go b/supernode/node/action/server/cascade/cascade_action_server.go index fb3cc6f6..5fe889ba 100644 --- a/supernode/node/action/server/cascade/cascade_action_server.go 
+++ b/supernode/node/action/server/cascade/cascade_action_server.go @@ -235,7 +235,7 @@ func (server *ActionServer) Download(req *pb.DownloadRequest, stream pb.CascadeS } } - var restoredFile []byte + var restoredFilePath string var tmpDir string err := task.Download(ctx, &cascadeService.DownloadRequest{ @@ -250,8 +250,8 @@ func (server *ActionServer) Download(req *pb.DownloadRequest, stream pb.CascadeS }, } - if len(resp.Artefacts) > 0 { - restoredFile = resp.Artefacts + if resp.FilePath != "" { + restoredFilePath = resp.FilePath tmpDir = resp.DownloadedDir } @@ -265,15 +265,23 @@ func (server *ActionServer) Download(req *pb.DownloadRequest, stream pb.CascadeS return err } - if len(restoredFile) == 0 { + if restoredFilePath == "" { logtrace.Error(ctx, "no artefact file retrieved", fields) return fmt.Errorf("no artefact to stream") } logtrace.Info(ctx, "streaming artefact file in chunks", fields) + restoredFile, err := readFileContentsInChunks(restoredFilePath) + if err != nil { + logtrace.Error(ctx, "failed to read restored file", logtrace.Fields{ + logtrace.FieldError: err.Error(), + }) + return err + } + logtrace.Info(ctx, "file has been read in chunks", fields) + // Calculate optimal chunk size based on file size chunkSize := calculateOptimalChunkSize(int64(len(restoredFile))) - logtrace.Info(ctx, "calculated optimal chunk size for download", logtrace.Fields{ "file_size": len(restoredFile), "chunk_size": chunkSize, @@ -314,3 +322,30 @@ func (server *ActionServer) Download(req *pb.DownloadRequest, stream pb.CascadeS logtrace.Info(ctx, "completed streaming all chunks", fields) return nil } + +func readFileContentsInChunks(filePath string) ([]byte, error) { + f, err := os.Open(filePath) + if err != nil { + return nil, err + } + defer f.Close() + + buf := make([]byte, 1024*1024) + var fileBytes []byte + + for { + n, readErr := f.Read(buf) + if n > 0 { + // Process chunk + fileBytes = append(fileBytes, buf[:n]...) 
+ } + if readErr == io.EOF { + break + } + if readErr != nil { + return nil, fmt.Errorf("chunked read failed: %w", readErr) + } + } + + return fileBytes, nil +} diff --git a/supernode/services/cascade/download.go b/supernode/services/cascade/download.go index dbc1ba4d..ddd80516 100644 --- a/supernode/services/cascade/download.go +++ b/supernode/services/cascade/download.go @@ -9,6 +9,7 @@ import ( actiontypes "github.com/LumeraProtocol/lumera/x/action/v1/types" "github.com/LumeraProtocol/supernode/pkg/codec" + "github.com/LumeraProtocol/supernode/pkg/crypto" "github.com/LumeraProtocol/supernode/pkg/errors" "github.com/LumeraProtocol/supernode/pkg/logtrace" "github.com/LumeraProtocol/supernode/pkg/utils" @@ -26,7 +27,7 @@ type DownloadRequest struct { type DownloadResponse struct { EventType SupernodeEventType Message string - Artefacts []byte + FilePath string DownloadedDir string } @@ -44,7 +45,7 @@ func (task *CascadeRegistrationTask) Download( return task.wrapErr(ctx, "failed to get action", err, fields) } logtrace.Info(ctx, "action has been retrieved", fields) - task.streamDownloadEvent(SupernodeEventTypeActionRetrieved, "action has been retrieved", nil, "", send) + task.streamDownloadEvent(SupernodeEventTypeActionRetrieved, "action has been retrieved", "", "", send) if actionDetails.GetAction().State != actiontypes.ActionStateDone { err = errors.New("action is not in a valid state") @@ -53,7 +54,7 @@ func (task *CascadeRegistrationTask) Download( return task.wrapErr(ctx, "action not found", err, fields) } logtrace.Info(ctx, "action has been validated", fields) - task.streamDownloadEvent(SupernodeEventTypeActionFinalized, "action state has been validated", nil, "", send) + task.streamDownloadEvent(SupernodeEventTypeActionFinalized, "action state has been validated", "", "", send) metadata, err := task.decodeCascadeMetadata(ctx, actionDetails.GetAction().Metadata, fields) if err != nil { @@ -61,20 +62,20 @@ func (task *CascadeRegistrationTask) Download( return 
task.wrapErr(ctx, "error decoding cascade metadata", err, fields) } logtrace.Info(ctx, "cascade metadata has been decoded", fields) - task.streamDownloadEvent(SupernodeEventTypeMetadataDecoded, "metadata has been decoded", nil, "", send) + task.streamDownloadEvent(SupernodeEventTypeMetadataDecoded, "metadata has been decoded", "", "", send) - file, tmpDir, err := task.downloadArtifacts(ctx, actionDetails.GetAction().ActionID, metadata, fields) + filePath, tmpDir, err := task.downloadArtifacts(ctx, actionDetails.GetAction().ActionID, metadata, fields) if err != nil { fields[logtrace.FieldError] = err.Error() return task.wrapErr(ctx, "failed to download artifacts", err, fields) } logtrace.Info(ctx, "artifacts have been downloaded", fields) - task.streamDownloadEvent(SupernodeEventTypeArtefactsDownloaded, "artifacts have been downloaded", file, tmpDir, send) + task.streamDownloadEvent(SupernodeEventTypeArtefactsDownloaded, "artifacts have been downloaded", filePath, tmpDir, send) return nil } -func (task *CascadeRegistrationTask) downloadArtifacts(ctx context.Context, actionID string, metadata actiontypes.CascadeMetadata, fields logtrace.Fields) ([]byte, string, error) { +func (task *CascadeRegistrationTask) downloadArtifacts(ctx context.Context, actionID string, metadata actiontypes.CascadeMetadata, fields logtrace.Fields) (string, string, error) { logtrace.Info(ctx, "started downloading the artifacts", fields) var layout codec.Layout @@ -106,7 +107,7 @@ func (task *CascadeRegistrationTask) downloadArtifacts(ctx context.Context, acti } if len(layout.Blocks) == 0 { - return nil, "", errors.New("no symbols found in RQ metadata") + return "", "", errors.New("no symbols found in RQ metadata") } return task.restoreFileFromLayout(ctx, layout, metadata.DataHash, actionID) @@ -117,7 +118,7 @@ func (task *CascadeRegistrationTask) restoreFileFromLayout( layout codec.Layout, dataHash string, actionID string, -) ([]byte, string, error) { +) (string, string, error) { fields := 
logtrace.Fields{ logtrace.FieldActionID: actionID, @@ -139,7 +140,7 @@ func (task *CascadeRegistrationTask) restoreFileFromLayout( if err != nil { fields[logtrace.FieldError] = err.Error() logtrace.Error(ctx, "failed to retrieve symbols", fields) - return nil, "", fmt.Errorf("failed to retrieve symbols: %w", err) + return "", "", fmt.Errorf("failed to retrieve symbols: %w", err) } fields["retrievedSymbols"] = len(symbols) @@ -154,40 +155,37 @@ func (task *CascadeRegistrationTask) restoreFileFromLayout( if err != nil { fields[logtrace.FieldError] = err.Error() logtrace.Error(ctx, "failed to decode symbols", fields) - return nil, "", fmt.Errorf("decode symbols using RaptorQ: %w", err) + return "", "", fmt.Errorf("decode symbols using RaptorQ: %w", err) } - file, err := os.ReadFile(decodeInfo.FilePath) + fileHash, err := crypto.HashFileIncrementally(decodeInfo.FilePath, 0) if err != nil { fields[logtrace.FieldError] = err.Error() - logtrace.Error(ctx, "failed to read file", fields) - return nil, "", fmt.Errorf("read decoded file: %w", err) + logtrace.Error(ctx, "failed to hash file", fields) + return "", "", fmt.Errorf("hash file: %w", err) } - - // 3. 
Validate hash (Blake3) - fileHash, err := utils.Blake3Hash(file) - if err != nil { - fields[logtrace.FieldError] = err.Error() - logtrace.Error(ctx, "failed to do hash", fields) - return nil, "", fmt.Errorf("hash file: %w", err) + if fileHash == nil { + fields[logtrace.FieldError] = "file hash is nil" + logtrace.Error(ctx, "failed to hash file", fields) + return "", "", errors.New("file hash is nil") } err = task.verifyDataHash(ctx, fileHash, dataHash, fields) if err != nil { logtrace.Error(ctx, "failed to verify hash", fields) fields[logtrace.FieldError] = err.Error() - return nil, decodeInfo.DecodeTmpDir, err + return "", decodeInfo.DecodeTmpDir, err } - logtrace.Info(ctx, "file successfully restored and hash verified", fields) - return file, decodeInfo.DecodeTmpDir, nil + + return decodeInfo.FilePath, decodeInfo.DecodeTmpDir, nil } -func (task *CascadeRegistrationTask) streamDownloadEvent(eventType SupernodeEventType, msg string, file []byte, tmpDir string, send func(resp *DownloadResponse) error) { +func (task *CascadeRegistrationTask) streamDownloadEvent(eventType SupernodeEventType, msg string, filePath string, tmpDir string, send func(resp *DownloadResponse) error) { _ = send(&DownloadResponse{ EventType: eventType, Message: msg, - Artefacts: file, + FilePath: filePath, DownloadedDir: tmpDir, })