diff --git a/p2p/kademlia/dht.go b/p2p/kademlia/dht.go
index 493f90f1..00c86fb8 100644
--- a/p2p/kademlia/dht.go
+++ b/p2p/kademlia/dht.go
@@ -40,9 +40,9 @@ const (
 	defaultDeleteDataInterval = 11 * time.Hour
 	delKeysCountThreshold     = 10
 	lowSpaceThreshold         = 50 // GB
-	batchStoreSize                   = 2500
+	batchRetrieveSize                = 1000
 	storeSameSymbolsBatchConcurrency = 3
-	storeSymbolsBatchConcurrency     = 3.0
+	fetchSymbolsBatchConcurrency     = 6
 	minimumDataStoreSuccessRate      = 75.0
 	maxIterations                    = 4
@@ -734,10 +734,10 @@ func (s *DHT) BatchRetrieve(ctx context.Context, keys []string, required int32,
 		return result, nil
 	}
-	batchSize := batchStoreSize
+	batchSize := batchRetrieveSize
 	var networkFound int32
 	totalBatches := int(math.Ceil(float64(required) / float64(batchSize)))
-	parallelBatches := int(math.Min(float64(totalBatches), storeSymbolsBatchConcurrency))
+	parallelBatches := int(math.Min(float64(totalBatches), fetchSymbolsBatchConcurrency))
 	semaphore := make(chan struct{}, parallelBatches)
 	var wg sync.WaitGroup
@@ -775,7 +775,13 @@ func (s *DHT) BatchRetrieve(ctx context.Context, keys []string, required int32,
 	wg.Wait()
 	netFound := int(atomic.LoadInt32(&networkFound))
-	s.metrics.RecordBatchRetrieve(len(keys), int(required), int(foundLocalCount), netFound, time.Duration(time.Since(start).Milliseconds())) // NEW
+	totalFound := int(foundLocalCount) + netFound
+
+	s.metrics.RecordBatchRetrieve(len(keys), int(required), int(foundLocalCount), netFound, time.Since(start))
+
+	if totalFound < int(required) {
+		return result, errors.Errorf("insufficient symbols: required=%d, found=%d", required, totalFound)
+	}
 	return result, nil
 }
@@ -800,7 +806,7 @@ func (s *DHT) processBatch(
 		defer wg.Done()
 		defer func() { <-semaphore }()
-		for i := 0; i < maxIterations; i++ {
+		for i := 0; i < 1; i++ {
 			select {
 			case <-ctx.Done():
 				return
@@ -822,9 +828,18 @@
 				}
 			}
+			knownMu.Lock()
+			nodesSnap := make(map[string]*Node, len(knownNodes))
+			for id, n := range knownNodes {
+				nodesSnap[id] = n
+			}
+			knownMu.Unlock()
+
 			foundCount, newClosestContacts, batchErr := s.iterateBatchGetValues(
-				ctx, knownNodes, batchKeys, batchHexKeys, fetchMap, resMap, required, foundLocalCount+atomic.LoadInt32(networkFound),
+				ctx, nodesSnap, batchKeys, batchHexKeys, fetchMap, resMap,
+				required, foundLocalCount+atomic.LoadInt32(networkFound),
 			)
+
 			if batchErr != nil {
 				logtrace.Error(ctx, "Iterate batch get values failed", logtrace.Fields{
 					logtrace.FieldModule: "dht", "txid": txID, logtrace.FieldError: batchErr.Error(),
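The snapshot added above copies knownNodes into nodesSnap while holding knownMu, so iterateBatchGetValues ranges over a private map while other goroutines keep mutating the shared one. A minimal sketch of the same copy-under-mutex pattern, with hypothetical names (registry, snapshot) standing in for the DHT's fields:

package main

import (
	"fmt"
	"sync"
)

// registry is a stand-in for the DHT's shared node table.
type registry struct {
	mu    sync.Mutex
	nodes map[string]string // nodeID -> address
}

// snapshot returns a private copy; callers may range over it
// without holding mu, even while writers mutate r.nodes.
func (r *registry) snapshot() map[string]string {
	r.mu.Lock()
	defer r.mu.Unlock()
	cp := make(map[string]string, len(r.nodes))
	for id, addr := range r.nodes {
		cp[id] = addr
	}
	return cp
}

func main() {
	r := &registry{nodes: map[string]string{"node-a": "10.0.0.1"}}
	for id, addr := range r.snapshot() { // safe: no concurrent map iteration/write
		fmt.Println(id, addr)
	}
}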
@@ -872,19 +887,36 @@
 	}
 }
-func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node, keys []string, hexKeys []string, fetchMap map[string][]int,
-	resMap *sync.Map, req, alreadyFound int32) (int, map[string]*NodeList, error) {
-	semaphore := make(chan struct{}, storeSameSymbolsBatchConcurrency) // Limit concurrency to 1
+func (s *DHT) iterateBatchGetValues(
+	ctx context.Context,
+	nodes map[string]*Node,
+	keys []string,
+	hexKeys []string,
+	fetchMap map[string][]int,
+	resMap *sync.Map,
+	req, alreadyFound int32,
+) (int, map[string]*NodeList, error) {
+
+	semaphore := make(chan struct{}, storeSameSymbolsBatchConcurrency)
 	closestContacts := make(map[string]*NodeList)
 	var wg sync.WaitGroup
 	contactsMap := make(map[string]map[string][]*Node)
 	var firstErr error
-	var mu sync.Mutex // To protect the firstErr
+	var mu sync.Mutex
 	foundCount := int32(0)
-	gctx, cancel := context.WithCancel(ctx) // Create a cancellable context
+	gctx, cancel := context.WithCancel(ctx)
 	defer cancel()
-	for nodeID, node := range nodes {
+
+	// ✅ Iterate ONLY nodes that actually have work according to fetchMap
+	for nodeID, idxs := range fetchMap {
+		if len(idxs) == 0 {
+			continue
+		}
+		node, ok := nodes[nodeID]
+		if !ok {
+			continue
+		}
 		if s.ignorelist.Banned(node) {
 			logtrace.Info(ctx, "Ignore banned node in iterate batch get values", logtrace.Fields{
 				logtrace.FieldModule: "dht",
@@ -894,8 +926,9 @@ func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node,
 		}
 		contactsMap[nodeID] = make(map[string][]*Node)
+
 		wg.Add(1)
-		go func(node *Node, nodeID string) {
+		go func(node *Node, nodeID string, indices []int) {
 			defer wg.Done()
 			select {
@@ -907,17 +940,15 @@ func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node,
 				defer func() { <-semaphore }()
 			}
-			indices := fetchMap[nodeID]
-			requestKeys := make(map[string]KeyValWithClosest)
+			// Build requestKeys from the provided indices only
+			requestKeys := make(map[string]KeyValWithClosest, len(indices))
 			for _, idx := range indices {
-				if idx < len(hexKeys) {
-					_, loaded := resMap.Load(hexKeys[idx]) // check if key is already there in resMap
-					if !loaded {
+				if idx >= 0 && idx < len(hexKeys) {
+					if _, loaded := resMap.Load(hexKeys[idx]); !loaded {
 						requestKeys[hexKeys[idx]] = KeyValWithClosest{}
 					}
 				}
 			}
-
 			if len(requestKeys) == 0 {
 				return
 			}
@@ -932,13 +963,12 @@ func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node,
 				return
 			}
+			// Merge values or closest contacts
 			for k, v := range decompressedData {
 				if len(v.Value) > 0 {
-					_, loaded := resMap.LoadOrStore(k, v.Value)
-					if !loaded {
-						atomic.AddInt32(&foundCount, 1)
-						if atomic.LoadInt32(&foundCount) >= int32(req-alreadyFound) {
-							cancel() // Cancel context to stop other goroutines
+					if _, loaded := resMap.LoadOrStore(k, v.Value); !loaded {
+						if atomic.AddInt32(&foundCount, 1) >= int32(req-alreadyFound) {
+							cancel()
 							return
 						}
 					}
@@ -946,7 +976,7 @@ func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node,
 					contactsMap[nodeID][k] = v.Closest
 				}
 			}
-		}(node, nodeID)
+		}(node, nodeID, idxs)
 	}
 	wg.Wait()
@@ -964,19 +994,19 @@ func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node,
 		})
 	}
+	// Build closestContacts from contactsMap (same as before)
 	for _, closestNodes := range contactsMap {
 		for key, nodes := range closestNodes {
 			comparator, err := hex.DecodeString(key)
 			if err != nil {
-				logtrace.Error(ctx, "Failed to decode hex key in closestNodes.Range", logtrace.Fields{
+				logtrace.Error(ctx, "Failed to decode hex key in closestNodes", logtrace.Fields{
 					logtrace.FieldModule: "dht", "key": key, logtrace.FieldError: err.Error(),
 				})
-				return 0, nil, err
+				return int(foundCount), nil, err
 			}
 			bkey := base58.Encode(comparator)
-
 			if _, ok := closestContacts[bkey]; !ok {
 				closestContacts[bkey] = &NodeList{Nodes: nodes, Comparator: comparator}
 			} else {
@@ -984,7 +1014,6 @@ func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node,
 			}
 		}
 	}
-
 	for key, nodes := range closestContacts {
 		nodes.Sort()
 		nodes.TopN(Alpha)
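The rewritten iterateBatchGetValues combines three controls: a buffered-channel semaphore bounding fan-out, an atomic counter of found symbols, and a cancellable context that stops the remaining goroutines once req-alreadyFound is met. A self-contained sketch of that shape, with hypothetical names (fetchAll, a simulated fetch) rather than the repo's API:

package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
)

// fetchAll runs at most `limit` fetches at once and cancels the rest
// as soon as `quota` successes have been counted.
func fetchAll(ctx context.Context, keys []string, quota int32, limit int) int32 {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	sem := make(chan struct{}, limit)
	var found int32
	var wg sync.WaitGroup
	for _, k := range keys {
		wg.Add(1)
		go func(k string) {
			defer wg.Done()
			select {
			case sem <- struct{}{}: // acquire a concurrency slot
				defer func() { <-sem }()
			case <-ctx.Done(): // another goroutine already met the quota
				return
			}
			// Simulate a successful fetch of key k; cancel once the quota is met.
			if atomic.AddInt32(&found, 1) >= quota {
				cancel()
			}
		}(k)
	}
	wg.Wait()
	return atomic.LoadInt32(&found)
}

func main() {
	got := fetchAll(context.Background(), []string{"k1", "k2", "k3", "k4"}, 2, 3)
	fmt.Println("found:", got)
}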
diff --git a/supernode/services/cascade/download.go b/supernode/services/cascade/download.go
index 78db1758..3c4e640e 100644
--- a/supernode/services/cascade/download.go
+++ b/supernode/services/cascade/download.go
@@ -17,7 +17,7 @@ import (
 )
 
 const (
-	requiredSymbolPercent = 9
+	requiredSymbolPercent = 17
 )
 
 type DownloadRequest struct {
@@ -36,8 +36,8 @@ func (task *CascadeRegistrationTask) Download(
 	req *DownloadRequest,
 	send func(resp *DownloadResponse) error,
 ) (err error) {
-	fields := logtrace.Fields{logtrace.FieldMethod: "Download", logtrace.FieldRequest: req}
-	logtrace.Info(ctx, "Cascade download request received", fields)
+	fields := logtrace.Fields{logtrace.FieldMethod: "Download", logtrace.FieldRequest: req}
+	logtrace.Info(ctx, "Cascade download request received", fields)
 
 	// Ensure task status is finalized regardless of outcome
 	defer func() {
@@ -54,8 +54,8 @@ func (task *CascadeRegistrationTask) Download(
 		fields[logtrace.FieldError] = err
 		return task.wrapErr(ctx, "failed to get action", err, fields)
 	}
-	logtrace.Info(ctx, "Action retrieved", fields)
-	task.streamDownloadEvent(SupernodeEventTypeActionRetrieved, "Action retrieved", "", "", send)
+	logtrace.Info(ctx, "Action retrieved", fields)
+	task.streamDownloadEvent(SupernodeEventTypeActionRetrieved, "Action retrieved", "", "", send)
 
 	if actionDetails.GetAction().State != actiontypes.ActionStateDone {
 		err = errors.New("action is not in a valid state")
@@ -63,27 +63,27 @@ func (task *CascadeRegistrationTask) Download(
 		fields[logtrace.FieldActionState] = actionDetails.GetAction().State
 		return task.wrapErr(ctx, "action not found", err, fields)
 	}
-	logtrace.Info(ctx, "Action state validated", fields)
+	logtrace.Info(ctx, "Action state validated", fields)
 
 	metadata, err := task.decodeCascadeMetadata(ctx, actionDetails.GetAction().Metadata, fields)
 	if err != nil {
 		fields[logtrace.FieldError] = err.Error()
 		return task.wrapErr(ctx, "error decoding cascade metadata", err, fields)
 	}
-	logtrace.Info(ctx, "Cascade metadata decoded", fields)
-	task.streamDownloadEvent(SupernodeEventTypeMetadataDecoded, "Cascade metadata decoded", "", "", send)
+	logtrace.Info(ctx, "Cascade metadata decoded", fields)
+	task.streamDownloadEvent(SupernodeEventTypeMetadataDecoded, "Cascade metadata decoded", "", "", send)
 
-	// Notify: network retrieval phase begins
-	task.streamDownloadEvent(SupernodeEventTypeNetworkRetrieveStarted, "Network retrieval started", "", "", send)
+	// Notify: network retrieval phase begins
+	task.streamDownloadEvent(SupernodeEventTypeNetworkRetrieveStarted, "Network retrieval started", "", "", send)
 
-	filePath, tmpDir, err := task.downloadArtifacts(ctx, actionDetails.GetAction().ActionID, metadata, fields)
-	if err != nil {
-		fields[logtrace.FieldError] = err.Error()
-		return task.wrapErr(ctx, "failed to download artifacts", err, fields)
-	}
-	logtrace.Info(ctx, "File reconstructed and hash verified", fields)
-	// Notify: decode completed, file ready on disk
-	task.streamDownloadEvent(SupernodeEventTypeDecodeCompleted, "Decode completed", filePath, tmpDir, send)
+	filePath, tmpDir, err := task.downloadArtifacts(ctx, actionDetails.GetAction().ActionID, metadata, fields)
+	if err != nil {
+		fields[logtrace.FieldError] = err.Error()
+		return task.wrapErr(ctx, "failed to download artifacts", err, fields)
+	}
+	logtrace.Info(ctx, "File reconstructed and hash verified", fields)
+	// Notify: decode completed, file ready on disk
+	task.streamDownloadEvent(SupernodeEventTypeDecodeCompleted, "Decode completed", filePath, tmpDir, send)
 
 	return nil
 }
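The jump from 9 to 17 percent feeds the integer ceiling division used later in progressive_decode.go: reqCount := (total*requiredSymbolPercent + 99) / 100, clamped to [1, total]. A small worked sketch of that arithmetic (values are illustrative):

package main

import "fmt"

// requiredCount computes ceil(total*pct/100) without floating point,
// then clamps the result to the valid range [1, total].
func requiredCount(total, pct int) int {
	n := (total*pct + 99) / 100
	if n < 1 {
		n = 1
	} else if n > total {
		n = total
	}
	return n
}

func main() {
	fmt.Println(requiredCount(1000, 17)) // 170
	fmt.Println(requiredCount(3, 17))    // 1: ceil(3*0.17) = 1
}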
@@ -147,15 +147,15 @@ func (task *CascadeRegistrationTask) restoreFileFromLayout(
 	fields["totalSymbols"] = totalSymbols
 	fields["requiredSymbols"] = requiredSymbols
 
-	logtrace.Info(ctx, "Symbols to be retrieved", fields)
+	logtrace.Info(ctx, "Symbols to be retrieved", fields)
 
-	// Progressive retrieval moved to helper for readability/testing
-	decodeInfo, err := task.retrieveAndDecodeProgressively(ctx, layout, actionID, fields)
-	if err != nil {
-		fields[logtrace.FieldError] = err.Error()
-		logtrace.Error(ctx, "failed to decode symbols progressively", fields)
-		return "", "", fmt.Errorf("decode symbols using RaptorQ: %w", err)
-	}
+	// Progressive retrieval moved to helper for readability/testing
+	decodeInfo, err := task.retrieveAndDecodeProgressively(ctx, layout, actionID, fields)
+	if err != nil {
+		fields[logtrace.FieldError] = err.Error()
+		logtrace.Error(ctx, "failed to decode symbols progressively", fields)
+		return "", "", fmt.Errorf("decode symbols using RaptorQ: %w", err)
+	}
 
 	fileHash, err := crypto.HashFileIncrementally(decodeInfo.FilePath, 0)
 	if err != nil {
@@ -175,7 +175,7 @@ func (task *CascadeRegistrationTask) restoreFileFromLayout(
 		fields[logtrace.FieldError] = err.Error()
 		return "", decodeInfo.DecodeTmpDir, err
 	}
-	logtrace.Info(ctx, "File successfully restored and hash verified", fields)
+	logtrace.Info(ctx, "File successfully restored and hash verified", fields)
 
 	return decodeInfo.FilePath, decodeInfo.DecodeTmpDir, nil
 }
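restoreFileFromLayout above verifies the reconstructed file by hashing it (crypto.HashFileIncrementally in this repo) and comparing against the expected digest. A rough stdlib-only equivalent of that verify step; SHA-256 and the hex encoding are assumptions for illustration, not necessarily the repo's choices:

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"os"
)

// verifyFileHash streams the file through SHA-256 (no full read into
// memory) and compares the digest to a hex-encoded expected value.
func verifyFileHash(path, expectedHex string) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return err
	}
	if got := hex.EncodeToString(h.Sum(nil)); got != expectedHex {
		return fmt.Errorf("hash mismatch: got %s, want %s", got, expectedHex)
	}
	return nil
}

func main() {
	// Hypothetical path and digest, for demonstration only.
	if err := verifyFileHash("restored.bin", "0000"); err != nil {
		fmt.Println(err)
	}
}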
diff --git a/supernode/services/cascade/progressive_decode.go b/supernode/services/cascade/progressive_decode.go
index 65cd6820..3db58988 100644
--- a/supernode/services/cascade/progressive_decode.go
+++ b/supernode/services/cascade/progressive_decode.go
@@ -10,16 +10,10 @@ import (
 )
 
 // retrieveAndDecodeProgressively performs a minimal two-step retrieval for a single-block layout:
-// 1) Fetch approximately requiredSymbolPercent of symbols and try decoding.
-// 2) If that fails, fetch all available symbols from the block and try again.
-// This replaces earlier multi-block balancing and multi-threshold escalation.
-func (task *CascadeRegistrationTask) retrieveAndDecodeProgressively(
-	ctx context.Context,
-	layout codec.Layout,
-	actionID string,
-	fields logtrace.Fields,
-) (adaptors.DecodeResponse, error) {
-	// Ensure base context fields are present for logs
+// 1) Send ALL keys with a minimum required count (requiredSymbolPercent).
+// 2) If decode fails, escalate by asking for ALL symbols (required = total).
+func (task *CascadeRegistrationTask) retrieveAndDecodeProgressively(ctx context.Context, layout codec.Layout, actionID string,
+	fields logtrace.Fields) (adaptors.DecodeResponse, error) {
 	if fields == nil {
 		fields = logtrace.Fields{}
 	}
@@ -29,7 +23,7 @@ func (task *CascadeRegistrationTask) retrieveAndDecodeProgressively(
 		return adaptors.DecodeResponse{}, fmt.Errorf("empty layout: no blocks")
 	}
 
-	// Single-block fast path
+	// Single-block path
 	if len(layout.Blocks) == 1 {
 		blk := layout.Blocks[0]
 		total := len(blk.Symbols)
@@ -37,20 +31,19 @@ func (task *CascadeRegistrationTask) retrieveAndDecodeProgressively(
 			return adaptors.DecodeResponse{}, fmt.Errorf("empty layout: no symbols")
 		}
 
-		// Step 1: try with requiredSymbolPercent of symbols
+		// Step 1: send ALL keys, require only reqCount
 		reqCount := (total*requiredSymbolPercent + 99) / 100
 		if reqCount < 1 {
 			reqCount = 1
-		}
-		if reqCount > total {
+		} else if reqCount > total {
 			reqCount = total
 		}
 		fields["targetPercent"] = requiredSymbolPercent
 		fields["targetCount"] = reqCount
+		fields["total"] = total
 		logtrace.Info(ctx, "retrieving initial symbols (single block)", fields)
 
-		keys := blk.Symbols[:reqCount]
-		symbols, err := task.P2PClient.BatchRetrieve(ctx, keys, reqCount, actionID)
+		symbols, err := task.P2PClient.BatchRetrieve(ctx, blk.Symbols, reqCount, actionID)
 		if err != nil {
 			fields[logtrace.FieldError] = err.Error()
 			logtrace.Error(ctx, "failed to retrieve symbols", fields)
@@ -66,14 +59,18 @@ func (task *CascadeRegistrationTask) retrieveAndDecodeProgressively(
 			return decodeInfo, nil
 		}
 
-		// Step 2: escalate to all symbols
-		logtrace.Info(ctx, "initial decode failed; retrieving all symbols (single block)", nil)
-		symbols, err = task.P2PClient.BatchRetrieve(ctx, blk.Symbols, total, actionID)
+		// Step 2: escalate to require ALL symbols
+		fields["escalating"] = true
+		fields["requiredCount"] = total
+		logtrace.Info(ctx, "initial decode failed; retrieving all symbols (single block)", fields)
+
+		symbols, err = task.P2PClient.BatchRetrieve(ctx, blk.Symbols, total, actionID)
 		if err != nil {
 			fields[logtrace.FieldError] = err.Error()
 			logtrace.Error(ctx, "failed to retrieve all symbols", fields)
 			return adaptors.DecodeResponse{}, fmt.Errorf("failed to retrieve symbols: %w", err)
 		}
+
 		return task.RQ.Decode(ctx, adaptors.DecodeRequest{
 			ActionID: actionID,
 			Symbols:  symbols,
@@ -81,6 +78,5 @@ func (task *CascadeRegistrationTask) retrieveAndDecodeProgressively(
 		})
 	}
 
-	// Multi-block layouts are not supported by current policy
 	return adaptors.DecodeResponse{}, fmt.Errorf("unsupported layout: expected 1 block, found %d", len(layout.Blocks))
 }
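Taken together, the new policy is one cheap attempt that sends every key while requiring only the computed minimum, then a single escalation that requires all of them. A compact model of that control flow, with hypothetical retriever/decoder types standing in for P2PClient.BatchRetrieve and RQ.Decode:

package main

import (
	"context"
	"errors"
	"fmt"
)

type retriever func(ctx context.Context, keys []string, required int) (map[string][]byte, error)
type decoder func(symbols map[string][]byte) ([]byte, error)

// decodeProgressively tries a minimum-count fetch first, then escalates
// to requiring every symbol if the first decode fails.
func decodeProgressively(ctx context.Context, keys []string, minRequired int, fetch retriever, decode decoder) ([]byte, error) {
	symbols, err := fetch(ctx, keys, minRequired) // step 1: all keys, minimum required
	if err != nil {
		return nil, fmt.Errorf("retrieve: %w", err)
	}
	if out, err := decode(symbols); err == nil {
		return out, nil
	}
	symbols, err = fetch(ctx, keys, len(keys)) // step 2: require every symbol
	if err != nil {
		return nil, fmt.Errorf("retrieve all: %w", err)
	}
	return decode(symbols)
}

func main() {
	keys := []string{"s1", "s2", "s3"}
	fetch := func(_ context.Context, ks []string, required int) (map[string][]byte, error) {
		m := make(map[string][]byte, required)
		for _, k := range ks[:required] { // simulate returning exactly `required` symbols
			m[k] = []byte(k)
		}
		return m, nil
	}
	decode := func(sym map[string][]byte) ([]byte, error) {
		if len(sym) < 2 { // simulate needing at least 2 symbols
			return nil, errors.New("not enough symbols")
		}
		return []byte("restored file"), nil
	}
	out, err := decodeProgressively(context.Background(), keys, 1, fetch, decode)
	fmt.Println(string(out), err) // first attempt fails, escalation succeeds
}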