Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions pkg/crypto/hash.go
Original file line number Diff line number Diff line change
@@ -1 +1,40 @@
package crypto

import (
"fmt"
"io"
"lukechampine.com/blake3"
"os"
)

// defaultHashBufferSize is the read-buffer size used when the caller passes a
// non-positive bufferSize. 1 MiB balances memory use against syscall count.
const defaultHashBufferSize = 1024 * 1024 // 1 MiB

// HashFileIncrementally computes the 32-byte BLAKE3 digest of the file at
// filePath by streaming its contents through the hasher, so arbitrarily large
// files are hashed in constant memory.
//
// bufferSize is the size in bytes of the read buffer; values <= 0 fall back to
// defaultHashBufferSize. It returns the raw digest bytes, or an error if the
// file cannot be opened or read.
func HashFileIncrementally(filePath string, bufferSize int) ([]byte, error) {
	f, err := os.Open(filePath)
	if err != nil {
		return nil, fmt.Errorf("open file for hashing: %w", err)
	}
	defer f.Close()

	// Guard against zero AND negative sizes: make([]byte, n) panics for n < 0.
	if bufferSize <= 0 {
		bufferSize = defaultHashBufferSize
	}

	hasher := blake3.New(32, nil) // 32-byte digest, unkeyed
	buf := make([]byte, bufferSize)

	// io.CopyBuffer streams f into the hasher using our buffer, handling the
	// io.Reader contract (short reads, EOF) correctly. hash.Hash.Write never
	// returns an error, so any error here comes from reading the file.
	if _, err := io.CopyBuffer(hasher, f, buf); err != nil {
		return nil, fmt.Errorf("streaming file read failed: %w", err)
	}

	return hasher.Sum(nil), nil
}
97 changes: 97 additions & 0 deletions pkg/crypto/hash_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package crypto

import (
"encoding/hex"
"os"
"path/filepath"
"testing"

"lukechampine.com/blake3"
)

// TestHashFileIncrementally verifies the streaming BLAKE3 file hasher against
// reference digests computed in one shot, covering small, empty, and large
// files, a missing file, and the zero-buffer-size default path.
func TestHashFileIncrementally(t *testing.T) {
	// Reference implementation: hash the whole payload in memory at once.
	expectedBlake3 := func(data []byte) string {
		h := blake3.New(32, nil)
		h.Write(data)
		return hex.EncodeToString(h.Sum(nil))
	}

	testData := []byte("hello world")
	emptyData := []byte("")
	largeData := make([]byte, 5*1024*1024) // zero-filled; spans many buffer reads

	// Temp dir for test files; cleaned up automatically by the test framework.
	tmpDir := t.TempDir()

	// createTempFile writes content under tmpDir and returns the file path.
	createTempFile := func(name string, content []byte) string {
		filePath := filepath.Join(tmpDir, name)
		if err := os.WriteFile(filePath, content, 0644); err != nil {
			t.Fatalf("failed to create temp file: %v", err)
		}
		return filePath
	}

	smallFile := createTempFile("small.txt", testData)
	emptyFile := createTempFile("empty.txt", emptyData)
	largeFile := createTempFile("large.bin", largeData)

	tests := []struct {
		name       string
		filePath   string
		bufferSize int
		wantHash   string
		wantErr    bool
	}{
		{
			name:       "small file",
			filePath:   smallFile,
			bufferSize: 4 * 1024, // 4KB buffer
			wantHash:   expectedBlake3(testData),
			wantErr:    false,
		},
		{
			name:       "empty file",
			filePath:   emptyFile,
			bufferSize: 1024, // small buffer
			wantHash:   expectedBlake3(emptyData),
			wantErr:    false,
		},
		{
			name:       "large file",
			filePath:   largeFile,
			bufferSize: 1024 * 1024, // 1MB buffer
			wantHash:   expectedBlake3(largeData),
			wantErr:    false,
		},
		{
			name:       "file does not exist",
			filePath:   filepath.Join(tmpDir, "doesnotexist.txt"),
			bufferSize: 4096,
			wantHash:   "",
			wantErr:    true,
		},
		{
			name:       "zero buffer size (should use default)",
			filePath:   smallFile,
			bufferSize: 0,
			wantHash:   expectedBlake3(testData),
			wantErr:    false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotHash, err := HashFileIncrementally(tt.filePath, tt.bufferSize)

			if (err != nil) != tt.wantErr {
				t.Fatalf("expected error=%v, got err=%v", tt.wantErr, err)
			}
			if tt.wantErr {
				return
			}

			// Compare and report in hex: printing the raw digest bytes with
			// %s would emit unreadable binary in the failure message.
			if got := hex.EncodeToString(gotHash); got != tt.wantHash {
				t.Errorf("hash mismatch!\n got: %s\n want: %s", got, tt.wantHash)
			}
		})
	}
}
45 changes: 40 additions & 5 deletions supernode/node/action/server/cascade/cascade_action_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func (server *ActionServer) Download(req *pb.DownloadRequest, stream pb.CascadeS
}
}

var restoredFile []byte
var restoredFilePath string
var tmpDir string

err := task.Download(ctx, &cascadeService.DownloadRequest{
Expand All @@ -250,8 +250,8 @@ func (server *ActionServer) Download(req *pb.DownloadRequest, stream pb.CascadeS
},
}

if len(resp.Artefacts) > 0 {
restoredFile = resp.Artefacts
if resp.FilePath != "" {
restoredFilePath = resp.FilePath
tmpDir = resp.DownloadedDir
}

Expand All @@ -265,15 +265,23 @@ func (server *ActionServer) Download(req *pb.DownloadRequest, stream pb.CascadeS
return err
}

if len(restoredFile) == 0 {
if restoredFilePath == "" {
logtrace.Error(ctx, "no artefact file retrieved", fields)
return fmt.Errorf("no artefact to stream")
}
logtrace.Info(ctx, "streaming artefact file in chunks", fields)

restoredFile, err := readFileContentsInChunks(restoredFilePath)
if err != nil {
logtrace.Error(ctx, "failed to read restored file", logtrace.Fields{
logtrace.FieldError: err.Error(),
})
return err
}
logtrace.Info(ctx, "file has been read in chunks", fields)

// Calculate optimal chunk size based on file size
chunkSize := calculateOptimalChunkSize(int64(len(restoredFile)))

logtrace.Info(ctx, "calculated optimal chunk size for download", logtrace.Fields{
"file_size": len(restoredFile),
"chunk_size": chunkSize,
Expand Down Expand Up @@ -314,3 +322,30 @@ func (server *ActionServer) Download(req *pb.DownloadRequest, stream pb.CascadeS
logtrace.Info(ctx, "completed streaming all chunks", fields)
return nil
}

func readFileContentsInChunks(filePath string) ([]byte, error) {
f, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer f.Close()

buf := make([]byte, 1024*1024)
var fileBytes []byte

for {
n, readErr := f.Read(buf)
if n > 0 {
// Process chunk
fileBytes = append(fileBytes, buf[:n]...)
}
if readErr == io.EOF {
break
}
if readErr != nil {
return nil, fmt.Errorf("chunked read failed: %w", readErr)
}
}

return fileBytes, nil
}
50 changes: 24 additions & 26 deletions supernode/services/cascade/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

actiontypes "github.com/LumeraProtocol/lumera/x/action/v1/types"
"github.com/LumeraProtocol/supernode/pkg/codec"
"github.com/LumeraProtocol/supernode/pkg/crypto"
"github.com/LumeraProtocol/supernode/pkg/errors"
"github.com/LumeraProtocol/supernode/pkg/logtrace"
"github.com/LumeraProtocol/supernode/pkg/utils"
Expand All @@ -26,7 +27,7 @@ type DownloadRequest struct {
type DownloadResponse struct {
EventType SupernodeEventType
Message string
Artefacts []byte
FilePath string
DownloadedDir string
}

Expand All @@ -44,7 +45,7 @@ func (task *CascadeRegistrationTask) Download(
return task.wrapErr(ctx, "failed to get action", err, fields)
}
logtrace.Info(ctx, "action has been retrieved", fields)
task.streamDownloadEvent(SupernodeEventTypeActionRetrieved, "action has been retrieved", nil, "", send)
task.streamDownloadEvent(SupernodeEventTypeActionRetrieved, "action has been retrieved", "", "", send)

if actionDetails.GetAction().State != actiontypes.ActionStateDone {
err = errors.New("action is not in a valid state")
Expand All @@ -53,28 +54,28 @@ func (task *CascadeRegistrationTask) Download(
return task.wrapErr(ctx, "action not found", err, fields)
}
logtrace.Info(ctx, "action has been validated", fields)
task.streamDownloadEvent(SupernodeEventTypeActionFinalized, "action state has been validated", nil, "", send)
task.streamDownloadEvent(SupernodeEventTypeActionFinalized, "action state has been validated", "", "", send)

metadata, err := task.decodeCascadeMetadata(ctx, actionDetails.GetAction().Metadata, fields)
if err != nil {
fields[logtrace.FieldError] = err.Error()
return task.wrapErr(ctx, "error decoding cascade metadata", err, fields)
}
logtrace.Info(ctx, "cascade metadata has been decoded", fields)
task.streamDownloadEvent(SupernodeEventTypeMetadataDecoded, "metadata has been decoded", nil, "", send)
task.streamDownloadEvent(SupernodeEventTypeMetadataDecoded, "metadata has been decoded", "", "", send)

file, tmpDir, err := task.downloadArtifacts(ctx, actionDetails.GetAction().ActionID, metadata, fields)
filePath, tmpDir, err := task.downloadArtifacts(ctx, actionDetails.GetAction().ActionID, metadata, fields)
if err != nil {
fields[logtrace.FieldError] = err.Error()
return task.wrapErr(ctx, "failed to download artifacts", err, fields)
}
logtrace.Info(ctx, "artifacts have been downloaded", fields)
task.streamDownloadEvent(SupernodeEventTypeArtefactsDownloaded, "artifacts have been downloaded", file, tmpDir, send)
task.streamDownloadEvent(SupernodeEventTypeArtefactsDownloaded, "artifacts have been downloaded", filePath, tmpDir, send)

return nil
}

func (task *CascadeRegistrationTask) downloadArtifacts(ctx context.Context, actionID string, metadata actiontypes.CascadeMetadata, fields logtrace.Fields) ([]byte, string, error) {
func (task *CascadeRegistrationTask) downloadArtifacts(ctx context.Context, actionID string, metadata actiontypes.CascadeMetadata, fields logtrace.Fields) (string, string, error) {
logtrace.Info(ctx, "started downloading the artifacts", fields)

var layout codec.Layout
Expand Down Expand Up @@ -106,7 +107,7 @@ func (task *CascadeRegistrationTask) downloadArtifacts(ctx context.Context, acti
}

if len(layout.Blocks) == 0 {
return nil, "", errors.New("no symbols found in RQ metadata")
return "", "", errors.New("no symbols found in RQ metadata")
}

return task.restoreFileFromLayout(ctx, layout, metadata.DataHash, actionID)
Expand All @@ -117,7 +118,7 @@ func (task *CascadeRegistrationTask) restoreFileFromLayout(
layout codec.Layout,
dataHash string,
actionID string,
) ([]byte, string, error) {
) (string, string, error) {

fields := logtrace.Fields{
logtrace.FieldActionID: actionID,
Expand All @@ -139,7 +140,7 @@ func (task *CascadeRegistrationTask) restoreFileFromLayout(
if err != nil {
fields[logtrace.FieldError] = err.Error()
logtrace.Error(ctx, "failed to retrieve symbols", fields)
return nil, "", fmt.Errorf("failed to retrieve symbols: %w", err)
return "", "", fmt.Errorf("failed to retrieve symbols: %w", err)
}

fields["retrievedSymbols"] = len(symbols)
Expand All @@ -154,40 +155,37 @@ func (task *CascadeRegistrationTask) restoreFileFromLayout(
if err != nil {
fields[logtrace.FieldError] = err.Error()
logtrace.Error(ctx, "failed to decode symbols", fields)
return nil, "", fmt.Errorf("decode symbols using RaptorQ: %w", err)
return "", "", fmt.Errorf("decode symbols using RaptorQ: %w", err)
}

file, err := os.ReadFile(decodeInfo.FilePath)
fileHash, err := crypto.HashFileIncrementally(decodeInfo.FilePath, 0)
if err != nil {
fields[logtrace.FieldError] = err.Error()
logtrace.Error(ctx, "failed to read file", fields)
return nil, "", fmt.Errorf("read decoded file: %w", err)
logtrace.Error(ctx, "failed to hash file", fields)
return "", "", fmt.Errorf("hash file: %w", err)
}

// 3. Validate hash (Blake3)
fileHash, err := utils.Blake3Hash(file)
if err != nil {
fields[logtrace.FieldError] = err.Error()
logtrace.Error(ctx, "failed to do hash", fields)
return nil, "", fmt.Errorf("hash file: %w", err)
if fileHash == nil {
fields[logtrace.FieldError] = "file hash is nil"
logtrace.Error(ctx, "failed to hash file", fields)
return "", "", errors.New("file hash is nil")
}

err = task.verifyDataHash(ctx, fileHash, dataHash, fields)
if err != nil {
logtrace.Error(ctx, "failed to verify hash", fields)
fields[logtrace.FieldError] = err.Error()
return nil, decodeInfo.DecodeTmpDir, err
return "", decodeInfo.DecodeTmpDir, err
}

logtrace.Info(ctx, "file successfully restored and hash verified", fields)
return file, decodeInfo.DecodeTmpDir, nil

return decodeInfo.FilePath, decodeInfo.DecodeTmpDir, nil
}

func (task *CascadeRegistrationTask) streamDownloadEvent(eventType SupernodeEventType, msg string, file []byte, tmpDir string, send func(resp *DownloadResponse) error) {
func (task *CascadeRegistrationTask) streamDownloadEvent(eventType SupernodeEventType, msg string, filePath string, tmpDir string, send func(resp *DownloadResponse) error) {
_ = send(&DownloadResponse{
EventType: eventType,
Message: msg,
Artefacts: file,
FilePath: filePath,
DownloadedDir: tmpDir,
})

Expand Down