diff --git a/sdk/adapters/supernodeservice/adapter.go b/sdk/adapters/supernodeservice/adapter.go index 0c2d12f3..ea3d97ed 100644 --- a/sdk/adapters/supernodeservice/adapter.go +++ b/sdk/adapters/supernodeservice/adapter.go @@ -75,6 +75,8 @@ func calculateOptimalChunkSize(fileSize int64) int { return chunkSize } +const maxFileSize = 1 * 1024 * 1024 * 1024 // 1GB limit + func (a *cascadeAdapter) CascadeSupernodeRegister(ctx context.Context, in *CascadeSupernodeRegisterRequest, opts ...grpc.CallOption) (*CascadeSupernodeRegisterResponse, error) { // Create the client stream ctx = net.AddCorrelationID(ctx) @@ -102,6 +104,15 @@ func (a *cascadeAdapter) CascadeSupernodeRegister(ctx context.Context, in *Casca } totalBytes := fileInfo.Size() + // Validate file size before starting upload + if totalBytes > maxFileSize { + a.logger.Error(ctx, "File exceeds maximum size limit", + "filePath", in.FilePath, + "fileSize", totalBytes, + "maxSize", maxFileSize) + return nil, fmt.Errorf("file size %d bytes exceeds maximum allowed size of 1GB", totalBytes) + } + // Define adaptive chunk size based on file size chunkSize := calculateOptimalChunkSize(totalBytes) diff --git a/sdk/net/factory.go b/sdk/net/factory.go index 3456c999..7df449bf 100644 --- a/sdk/net/factory.go +++ b/sdk/net/factory.go @@ -36,10 +36,17 @@ func NewClientFactory(ctx context.Context, logger log.Logger, keyring keyring.Ke logger.Debug(ctx, "Creating supernode client factory", "localAddress", config.LocalCosmosAddress) + // Optimized for streaming 1GB files with 4MB chunks (10 concurrent streams) + opts := client.DefaultClientOptions() + opts.MaxRecvMsgSize = 16 * 1024 * 1024 // 16MB to match server + opts.MaxSendMsgSize = 16 * 1024 * 1024 // 16MB to match server + opts.InitialWindowSize = 16 * 1024 * 1024 // 16MB per stream (4x chunk size) + opts.InitialConnWindowSize = 160 * 1024 * 1024 // 160MB (16MB x 10 streams) + return &ClientFactory{ logger: logger, keyring: keyring, - clientOptions: client.DefaultClientOptions(), + clientOptions: opts, config: config, lumeraClient: lumeraClient, } diff --git a/sdk/task/cache.go b/sdk/task/cache.go index 4395efad..f541e548 100644 --- a/sdk/task/cache.go +++ b/sdk/task/cache.go @@ -23,6 +23,7 @@ type TaskEntry struct { Events []event.Event CreatedAt time.Time LastUpdatedAt time.Time + Cancel context.CancelFunc // For cancelling long-running tasks } type TaskCache struct { @@ -64,8 +65,8 @@ func (tc *TaskCache) getOrCreateMutex(taskID string) *sync.Mutex { return mu.(*sync.Mutex) } -// Set stores a task in the cache with initial metadata -func (tc *TaskCache) Set(ctx context.Context, taskID string, task Task, taskType TaskType, actionID string) bool { +// Set stores a task in the cache with initial metadata and optional cancel function +func (tc *TaskCache) Set(ctx context.Context, taskID string, task Task, taskType TaskType, actionID string, cancel context.CancelFunc) bool { mu := tc.getOrCreateMutex(taskID) mu.Lock() defer mu.Unlock() @@ -82,6 +83,7 @@ func (tc *TaskCache) Set(ctx context.Context, taskID string, task Task, taskType Events: make([]event.Event, 0), CreatedAt: now, LastUpdatedAt: now, + Cancel: cancel, } success := tc.cache.Set(taskID, entry, 1) diff --git a/sdk/task/cascade.go b/sdk/task/cascade.go index ba15fdbe..8dd96fdf 100644 --- a/sdk/task/cascade.go +++ b/sdk/task/cascade.go @@ -33,9 +33,14 @@ func NewCascadeTask(base BaseTask, filePath string, actionId string) *CascadeTas // Run executes the full cascade‐task lifecycle. func (t *CascadeTask) Run(ctx context.Context) error { - t.LogEvent(ctx, event.SDKTaskStarted, "Running cascade task", nil) + // Validate file size before proceeding + if err := ValidateFileSize(t.filePath); err != nil { + t.LogEvent(ctx, event.SDKTaskFailed, "File validation failed", event.EventData{event.KeyError: err.Error()}) + return err + } + // 1 - Fetch the supernodes supernodes, err := t.fetchSupernodes(ctx, t.Action.Height) diff --git a/sdk/task/helpers.go b/sdk/task/helpers.go index 44c9a662..301539f8 100644 --- a/sdk/task/helpers.go +++ b/sdk/task/helpers.go @@ -4,12 +4,29 @@ import ( "context" "encoding/base64" "fmt" + "os" "path/filepath" "strings" "github.com/LumeraProtocol/supernode/sdk/adapters/lumera" ) +const maxFileSize = 1 * 1024 * 1024 * 1024 // 1GB limit + +// ValidateFileSize checks if a file size is within the allowed 1GB limit +func ValidateFileSize(filePath string) error { + fileInfo, err := os.Stat(filePath) + if err != nil { + return fmt.Errorf("failed to check file: %w", err) + } + + if fileInfo.Size() > maxFileSize { + return fmt.Errorf("file size %d bytes exceeds maximum allowed size of 1GB", fileInfo.Size()) + } + + return nil +} + func (m *ManagerImpl) validateAction(ctx context.Context, actionID string) (lumera.Action, error) { action, err := m.lumeraClient.GetAction(ctx, actionID) if err != nil { diff --git a/sdk/task/manager.go b/sdk/task/manager.go index 10170b30..f46478f9 100644 --- a/sdk/task/manager.go +++ b/sdk/task/manager.go @@ -92,20 +92,25 @@ func NewManagerWithLumeraClient(ctx context.Context, config config.Config, logge // CreateCascadeTask creates and starts a Cascade task using the new pattern func (m *ManagerImpl) CreateCascadeTask(ctx context.Context, filePath string, actionID, signature string) (string, error) { + // Create a detached context immediately to prevent HTTP request cancellation + taskCtx, cancel := context.WithCancel(context.Background()) + // First validate the action before creating the task - action, err := m.validateAction(ctx, actionID) + action, err := m.validateAction(taskCtx, actionID) if err != nil { + cancel() // Clean up if validation fails return "", err } // verify signature - if err := m.validateSignature(ctx, action, signature); err != nil { + if err := m.validateSignature(taskCtx, action, signature); err != nil { + cancel() // Clean up if signature validation fails return "", err } taskID := uuid.New().String()[:8] - m.logger.Debug(ctx, "Generated task ID", "taskID", taskID) + m.logger.Debug(taskCtx, "Generated task ID", "taskID", taskID) baseTask := BaseTask{ TaskID: taskID, @@ -122,23 +127,23 @@ func (m *ManagerImpl) CreateCascadeTask(ctx context.Context, filePath string, ac // Create cascade-specific task task := NewCascadeTask(baseTask, filePath, actionID) - // Store task in cache - m.taskCache.Set(ctx, taskID, task, TaskTypeCascade, actionID) + // Store task in cache with cancel function + m.taskCache.Set(taskCtx, taskID, task, TaskTypeCascade, actionID, cancel) // Ensure task is stored before returning m.taskCache.Wait() go func() { - m.logger.Debug(ctx, "Starting cascade task asynchronously", "taskID", taskID) - err := task.Run(ctx) + m.logger.Debug(taskCtx, "Starting cascade task asynchronously", "taskID", taskID) + err := task.Run(taskCtx) if err != nil { // Error handling is done via events in the task.Run method // This is just a failsafe in case something goes wrong - m.logger.Error(ctx, "Cascade task failed with error", "taskID", taskID, "error", err) + m.logger.Error(taskCtx, "Cascade task failed with error", "taskID", taskID, "error", err) } }() - m.logger.Info(ctx, "Cascade task created successfully", "taskID", taskID) + m.logger.Info(taskCtx, "Cascade task created successfully", "taskID", taskID) return taskID, nil } @@ -152,13 +157,19 @@ func (m *ManagerImpl) GetTask(ctx context.Context, taskID string) (*TaskEntry, b func (m *ManagerImpl) DeleteTask(ctx context.Context, taskID string) error { m.logger.Info(ctx, "Deleting task", "taskID", taskID) - // First check if the task exists - _, exists := m.taskCache.Get(ctx, taskID) + // First check if the task exists and get its entry + taskEntry, exists := m.taskCache.Get(ctx, taskID) if !exists { m.logger.Warn(ctx, "Task not found for deletion", "taskID", taskID) return fmt.Errorf("task not found: %s", taskID) } + // Cancel the task if it has a cancel function + if taskEntry.Cancel != nil { + m.logger.Info(ctx, "Cancelling task before deletion", "taskID", taskID) + taskEntry.Cancel() + } + // Delete the task from the cache m.taskCache.Del(ctx, taskID) @@ -242,19 +253,25 @@ func (m *ManagerImpl) Close(ctx context.Context) { } func (m *ManagerImpl) CreateDownloadTask(ctx context.Context, actionID string, outputDir string, signature string) (string, error) { + // Create a detached context immediately to prevent HTTP request cancellation + taskCtx, cancel := context.WithCancel(context.Background()) + // First validate the action before creating the task - action, err := m.validateDownloadAction(ctx, actionID) + action, err := m.validateDownloadAction(taskCtx, actionID) if err != nil { + cancel() // Clean up if validation fails return "", err } // Decode metadata to get the filename - metadata, err := m.lumeraClient.DecodeCascadeMetadata(ctx, action) + metadata, err := m.lumeraClient.DecodeCascadeMetadata(taskCtx, action) if err != nil { + cancel() // Clean up if metadata decode fails return "", fmt.Errorf("failed to decode cascade metadata: %w", err) } // Ensure we have a filename from metadata if metadata.FileName == "" { + cancel() // Clean up if no filename return "", fmt.Errorf("no filename found in cascade metadata") } @@ -263,7 +280,7 @@ func (m *ManagerImpl) CreateDownloadTask(ctx context.Context, actionID string, o taskID := uuid.New().String()[:8] - m.logger.Debug(ctx, "Generated download task ID", "task_id", taskID, "final_output_path", finalOutputPath) + m.logger.Debug(taskCtx, "Generated download task ID", "task_id", taskID, "final_output_path", finalOutputPath) baseTask := BaseTask{ TaskID: taskID, @@ -280,22 +297,22 @@ func (m *ManagerImpl) CreateDownloadTask(ctx context.Context, actionID string, o // Use the final output path with the correct filename task := NewCascadeDownloadTask(baseTask, actionID, finalOutputPath, signature) - // Store task in cache - m.taskCache.Set(ctx, taskID, task, TaskTypeCascade, actionID) + // Store task in cache with cancel function + m.taskCache.Set(taskCtx, taskID, task, TaskTypeCascade, actionID, cancel) // Ensure task is stored before returning m.taskCache.Wait() go func() { - m.logger.Debug(ctx, "Starting download cascade task asynchronously", "taskID", taskID) - err := task.Run(ctx) + m.logger.Debug(taskCtx, "Starting download cascade task asynchronously", "taskID", taskID) + err := task.Run(taskCtx) if err != nil { // Error handling is done via events in the task.Run method // This is just a failsafe in case something goes wrong - m.logger.Error(ctx, "Download Cascade task failed with error", "taskID", taskID, "error", err) + m.logger.Error(taskCtx, "Download Cascade task failed with error", "taskID", taskID, "error", err) } }() - m.logger.Info(ctx, "Download Cascade task created successfully", "taskID", taskID, "outputPath", finalOutputPath) + m.logger.Info(taskCtx, "Download Cascade task created successfully", "taskID", taskID, "outputPath", finalOutputPath) return taskID, nil } diff --git a/supernode/node/action/server/cascade/cascade_action_server.go b/supernode/node/action/server/cascade/cascade_action_server.go index 5fe889ba..54b7cd3b 100644 --- a/supernode/node/action/server/cascade/cascade_action_server.go +++ b/supernode/node/action/server/cascade/cascade_action_server.go @@ -76,6 +76,8 @@ func (server *ActionServer) Register(stream pb.CascadeService_RegisterServer) er ctx := stream.Context() logtrace.Info(ctx, "client streaming request to upload cascade input data received", fields) + const maxFileSize = 1 * 1024 * 1024 * 1024 // 1GB limit + var ( metadata *pb.Metadata totalSize int @@ -130,6 +132,14 @@ func (server *ActionServer) Register(stream pb.CascadeService_RegisterServer) er } totalSize += len(x.Chunk.Data) + // Validate total size doesn't exceed limit + if totalSize > maxFileSize { + fields[logtrace.FieldError] = "file size exceeds 1GB limit" + fields["total_size"] = totalSize + logtrace.Error(ctx, "upload rejected: file too large", fields) + return fmt.Errorf("file size %d exceeds maximum allowed size of 1GB", totalSize) + } + logtrace.Info(ctx, "received data chunk", logtrace.Fields{ "chunk_size": len(x.Chunk.Data), "total_size_so_far": totalSize, diff --git a/supernode/node/supernode/server/server.go b/supernode/node/supernode/server/server.go index 1693a2cd..1672cf01 100644 --- a/supernode/node/supernode/server/server.go +++ b/supernode/node/supernode/server/server.go @@ -58,15 +58,16 @@ func (server *Server) Run(ctx context.Context) error { logtrace.Fatal(ctx, "Failed to setup gRPC server", logtrace.Fields{logtrace.FieldModule: "server", logtrace.FieldError: err.Error()}) } - // Custom server options + // Optimized for streaming 1GB files with 4MB chunks (10 concurrent streams) opts := grpcserver.DefaultServerOptions() - opts.MaxRecvMsgSize = 2 * 1024 * 1024 * 1024 // 2 GB - opts.MaxSendMsgSize = 2 * 1024 * 1024 * 1024 // 2 GB - opts.InitialWindowSize = 32 * 1024 * 1024 // 32 MB - opts.InitialConnWindowSize = 32 * 1024 * 1024 // 32 MB - opts.WriteBufferSize = 1024 * 1024 // 1 MB - opts.ReadBufferSize = 1024 * 1024 // 1 MB + opts.MaxRecvMsgSize = (16 * 1024 * 1024) // 16MB (supports 4MB chunks + overhead) + opts.MaxSendMsgSize = (16 * 1024 * 1024) // 16MB for download streaming + opts.InitialWindowSize = (16 * 1024 * 1024) // 16MB per stream (4x chunk size) + opts.InitialConnWindowSize = (160 * 1024 * 1024) // 160MB (16MB x 10 streams) + opts.MaxConcurrentStreams = 20 // Limit to prevent resource exhaustion + opts.ReadBufferSize = (8 * 1024 * 1024) // 8MB TCP buffer + opts.WriteBufferSize = (8 * 1024 * 1024) // 8MB TCP buffer for _, address := range addresses { addr := net.JoinHostPort(strings.TrimSpace(address), strconv.Itoa(server.config.Port)) @@ -110,20 +111,20 @@ func (server *Server) setupGRPCServer() error { for _, service := range server.services { server.grpcServer.RegisterService(service.Desc(), service) server.healthServer.SetServingStatus(service.Desc().ServiceName, healthpb.HealthCheckResponse_SERVING) - + // Keep reference to SupernodeServer if ss, ok := service.(*SupernodeServer); ok { supernodeServer = ss } } - + // After all services are registered, update SupernodeServer with the list if supernodeServer != nil { // Register all custom services for _, svc := range server.services { supernodeServer.RegisterService(svc.Desc().ServiceName, svc.Desc()) } - + // Also register the health service healthDesc := healthpb.Health_ServiceDesc supernodeServer.RegisterService(healthDesc.ServiceName, &healthDesc)