Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 59 additions & 30 deletions p2p/kademlia/dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ const (
defaultDeleteDataInterval = 11 * time.Hour
delKeysCountThreshold = 10
lowSpaceThreshold = 50 // GB
batchStoreSize = 2500
batchRetrieveSize = 1000
storeSameSymbolsBatchConcurrency = 3
storeSymbolsBatchConcurrency = 3.0
fetchSymbolsBatchConcurrency = 6
minimumDataStoreSuccessRate = 75.0

maxIterations = 4
Expand Down Expand Up @@ -734,10 +734,10 @@ func (s *DHT) BatchRetrieve(ctx context.Context, keys []string, required int32,
return result, nil
}

batchSize := batchStoreSize
batchSize := batchRetrieveSize
var networkFound int32
totalBatches := int(math.Ceil(float64(required) / float64(batchSize)))
parallelBatches := int(math.Min(float64(totalBatches), storeSymbolsBatchConcurrency))
parallelBatches := int(math.Min(float64(totalBatches), fetchSymbolsBatchConcurrency))

semaphore := make(chan struct{}, parallelBatches)
var wg sync.WaitGroup
Expand Down Expand Up @@ -775,7 +775,13 @@ func (s *DHT) BatchRetrieve(ctx context.Context, keys []string, required int32,
wg.Wait()

netFound := int(atomic.LoadInt32(&networkFound))
s.metrics.RecordBatchRetrieve(len(keys), int(required), int(foundLocalCount), netFound, time.Duration(time.Since(start).Milliseconds())) // NEW
totalFound := int(foundLocalCount) + netFound

s.metrics.RecordBatchRetrieve(len(keys), int(required), int(foundLocalCount), netFound, time.Since(start))

if totalFound < int(required) {
return result, errors.Errorf("insufficient symbols: required=%d, found=%d", required, totalFound)
}

return result, nil
}
Expand All @@ -800,7 +806,7 @@ func (s *DHT) processBatch(
defer wg.Done()
defer func() { <-semaphore }()

for i := 0; i < maxIterations; i++ {
for i := 0; i < 1; i++ {
select {
case <-ctx.Done():
return
Expand All @@ -822,9 +828,18 @@ func (s *DHT) processBatch(
}
}

knownMu.Lock()
nodesSnap := make(map[string]*Node, len(knownNodes))
for id, n := range knownNodes {
nodesSnap[id] = n
}
knownMu.Unlock()

foundCount, newClosestContacts, batchErr := s.iterateBatchGetValues(
ctx, knownNodes, batchKeys, batchHexKeys, fetchMap, resMap, required, foundLocalCount+atomic.LoadInt32(networkFound),
ctx, nodesSnap, batchKeys, batchHexKeys, fetchMap, resMap,
required, foundLocalCount+atomic.LoadInt32(networkFound),
)

if batchErr != nil {
logtrace.Error(ctx, "Iterate batch get values failed", logtrace.Fields{
logtrace.FieldModule: "dht", "txid": txID, logtrace.FieldError: batchErr.Error(),
Expand Down Expand Up @@ -872,19 +887,36 @@ func (s *DHT) processBatch(
}
}

func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node, keys []string, hexKeys []string, fetchMap map[string][]int,
resMap *sync.Map, req, alreadyFound int32) (int, map[string]*NodeList, error) {
semaphore := make(chan struct{}, storeSameSymbolsBatchConcurrency) // Limit concurrency to 1
func (s *DHT) iterateBatchGetValues(
ctx context.Context,
nodes map[string]*Node,
keys []string,
hexKeys []string,
fetchMap map[string][]int,
resMap *sync.Map,
req, alreadyFound int32,
) (int, map[string]*NodeList, error) {

semaphore := make(chan struct{}, storeSameSymbolsBatchConcurrency)
closestContacts := make(map[string]*NodeList)
var wg sync.WaitGroup
contactsMap := make(map[string]map[string][]*Node)
var firstErr error
var mu sync.Mutex // To protect the firstErr
var mu sync.Mutex
foundCount := int32(0)

gctx, cancel := context.WithCancel(ctx) // Create a cancellable context
gctx, cancel := context.WithCancel(ctx)
defer cancel()
for nodeID, node := range nodes {

// Iterate only over nodes that have work assigned in fetchMap; entries with no indices or no matching node are skipped.
for nodeID, idxs := range fetchMap {
if len(idxs) == 0 {
continue
}
node, ok := nodes[nodeID]
if !ok {
continue
}
if s.ignorelist.Banned(node) {
logtrace.Info(ctx, "Ignore banned node in iterate batch get values", logtrace.Fields{
logtrace.FieldModule: "dht",
Expand All @@ -894,8 +926,9 @@ func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node,
}

contactsMap[nodeID] = make(map[string][]*Node)

wg.Add(1)
go func(node *Node, nodeID string) {
go func(node *Node, nodeID string, indices []int) {
defer wg.Done()

select {
Expand All @@ -907,17 +940,15 @@ func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node,
defer func() { <-semaphore }()
}

indices := fetchMap[nodeID]
requestKeys := make(map[string]KeyValWithClosest)
// Build requestKeys from the provided indices only
requestKeys := make(map[string]KeyValWithClosest, len(indices))
for _, idx := range indices {
if idx < len(hexKeys) {
_, loaded := resMap.Load(hexKeys[idx]) // check if key is already there in resMap
if !loaded {
if idx >= 0 && idx < len(hexKeys) {
if _, loaded := resMap.Load(hexKeys[idx]); !loaded {
requestKeys[hexKeys[idx]] = KeyValWithClosest{}
}
}
}

if len(requestKeys) == 0 {
return
}
Expand All @@ -932,21 +963,20 @@ func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node,
return
}

// Merge values or closest contacts
for k, v := range decompressedData {
if len(v.Value) > 0 {
_, loaded := resMap.LoadOrStore(k, v.Value)
if !loaded {
atomic.AddInt32(&foundCount, 1)
if atomic.LoadInt32(&foundCount) >= int32(req-alreadyFound) {
cancel() // Cancel context to stop other goroutines
if _, loaded := resMap.LoadOrStore(k, v.Value); !loaded {
if atomic.AddInt32(&foundCount, 1) >= int32(req-alreadyFound) {
cancel()
return
}
}
} else {
contactsMap[nodeID][k] = v.Closest
}
}
}(node, nodeID)
}(node, nodeID, idxs)
}

wg.Wait()
Expand All @@ -964,27 +994,26 @@ func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node,
})
}

// Build closestContacts by merging each node's per-key closest-contact lists from contactsMap, keyed by the base58-encoded key.
for _, closestNodes := range contactsMap {
for key, nodes := range closestNodes {
comparator, err := hex.DecodeString(key)
if err != nil {
logtrace.Error(ctx, "Failed to decode hex key in closestNodes.Range", logtrace.Fields{
logtrace.Error(ctx, "Failed to decode hex key in closestNodes", logtrace.Fields{
logtrace.FieldModule: "dht",
"key": key,
logtrace.FieldError: err.Error(),
})
return 0, nil, err
return int(foundCount), nil, err
}
bkey := base58.Encode(comparator)

if _, ok := closestContacts[bkey]; !ok {
closestContacts[bkey] = &NodeList{Nodes: nodes, Comparator: comparator}
} else {
closestContacts[bkey].AddNodes(nodes)
}
}
}

for key, nodes := range closestContacts {
nodes.Sort()
nodes.TopN(Alpha)
Expand Down
54 changes: 27 additions & 27 deletions supernode/services/cascade/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

const (
requiredSymbolPercent = 9
requiredSymbolPercent = 17
)

type DownloadRequest struct {
Expand All @@ -36,8 +36,8 @@ func (task *CascadeRegistrationTask) Download(
req *DownloadRequest,
send func(resp *DownloadResponse) error,
) (err error) {
fields := logtrace.Fields{logtrace.FieldMethod: "Download", logtrace.FieldRequest: req}
logtrace.Info(ctx, "Cascade download request received", fields)
fields := logtrace.Fields{logtrace.FieldMethod: "Download", logtrace.FieldRequest: req}
logtrace.Info(ctx, "Cascade download request received", fields)

// Ensure task status is finalized regardless of outcome
defer func() {
Expand All @@ -54,36 +54,36 @@ func (task *CascadeRegistrationTask) Download(
fields[logtrace.FieldError] = err
return task.wrapErr(ctx, "failed to get action", err, fields)
}
logtrace.Info(ctx, "Action retrieved", fields)
task.streamDownloadEvent(SupernodeEventTypeActionRetrieved, "Action retrieved", "", "", send)
logtrace.Info(ctx, "Action retrieved", fields)
task.streamDownloadEvent(SupernodeEventTypeActionRetrieved, "Action retrieved", "", "", send)

if actionDetails.GetAction().State != actiontypes.ActionStateDone {
err = errors.New("action is not in a valid state")
fields[logtrace.FieldError] = "action state is not done yet"
fields[logtrace.FieldActionState] = actionDetails.GetAction().State
return task.wrapErr(ctx, "action not found", err, fields)
}
logtrace.Info(ctx, "Action state validated", fields)
logtrace.Info(ctx, "Action state validated", fields)

metadata, err := task.decodeCascadeMetadata(ctx, actionDetails.GetAction().Metadata, fields)
if err != nil {
fields[logtrace.FieldError] = err.Error()
return task.wrapErr(ctx, "error decoding cascade metadata", err, fields)
}
logtrace.Info(ctx, "Cascade metadata decoded", fields)
task.streamDownloadEvent(SupernodeEventTypeMetadataDecoded, "Cascade metadata decoded", "", "", send)
logtrace.Info(ctx, "Cascade metadata decoded", fields)
task.streamDownloadEvent(SupernodeEventTypeMetadataDecoded, "Cascade metadata decoded", "", "", send)

// Notify: network retrieval phase begins
task.streamDownloadEvent(SupernodeEventTypeNetworkRetrieveStarted, "Network retrieval started", "", "", send)
// Notify: network retrieval phase begins
task.streamDownloadEvent(SupernodeEventTypeNetworkRetrieveStarted, "Network retrieval started", "", "", send)

filePath, tmpDir, err := task.downloadArtifacts(ctx, actionDetails.GetAction().ActionID, metadata, fields)
if err != nil {
fields[logtrace.FieldError] = err.Error()
return task.wrapErr(ctx, "failed to download artifacts", err, fields)
}
logtrace.Info(ctx, "File reconstructed and hash verified", fields)
// Notify: decode completed, file ready on disk
task.streamDownloadEvent(SupernodeEventTypeDecodeCompleted, "Decode completed", filePath, tmpDir, send)
filePath, tmpDir, err := task.downloadArtifacts(ctx, actionDetails.GetAction().ActionID, metadata, fields)
if err != nil {
fields[logtrace.FieldError] = err.Error()
return task.wrapErr(ctx, "failed to download artifacts", err, fields)
}
logtrace.Info(ctx, "File reconstructed and hash verified", fields)
// Notify: decode completed, file ready on disk
task.streamDownloadEvent(SupernodeEventTypeDecodeCompleted, "Decode completed", filePath, tmpDir, send)

return nil
}
Expand Down Expand Up @@ -147,15 +147,15 @@ func (task *CascadeRegistrationTask) restoreFileFromLayout(

fields["totalSymbols"] = totalSymbols
fields["requiredSymbols"] = requiredSymbols
logtrace.Info(ctx, "Symbols to be retrieved", fields)
logtrace.Info(ctx, "Symbols to be retrieved", fields)

// Progressive retrieval moved to helper for readability/testing
decodeInfo, err := task.retrieveAndDecodeProgressively(ctx, layout, actionID, fields)
if err != nil {
fields[logtrace.FieldError] = err.Error()
logtrace.Error(ctx, "failed to decode symbols progressively", fields)
return "", "", fmt.Errorf("decode symbols using RaptorQ: %w", err)
}
// Progressive retrieval moved to helper for readability/testing
decodeInfo, err := task.retrieveAndDecodeProgressively(ctx, layout, actionID, fields)
if err != nil {
fields[logtrace.FieldError] = err.Error()
logtrace.Error(ctx, "failed to decode symbols progressively", fields)
return "", "", fmt.Errorf("decode symbols using RaptorQ: %w", err)
}

fileHash, err := crypto.HashFileIncrementally(decodeInfo.FilePath, 0)
if err != nil {
Expand All @@ -175,7 +175,7 @@ func (task *CascadeRegistrationTask) restoreFileFromLayout(
fields[logtrace.FieldError] = err.Error()
return "", decodeInfo.DecodeTmpDir, err
}
logtrace.Info(ctx, "File successfully restored and hash verified", fields)
logtrace.Info(ctx, "File successfully restored and hash verified", fields)

return decodeInfo.FilePath, decodeInfo.DecodeTmpDir, nil
}
Expand Down
36 changes: 16 additions & 20 deletions supernode/services/cascade/progressive_decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,10 @@ import (
)

// retrieveAndDecodeProgressively performs a minimal two-step retrieval for a single-block layout:
// 1) Fetch approximately requiredSymbolPercent of symbols and try decoding.
// 2) If that fails, fetch all available symbols from the block and try again.
// This replaces earlier multi-block balancing and multi-threshold escalation.
func (task *CascadeRegistrationTask) retrieveAndDecodeProgressively(
ctx context.Context,
layout codec.Layout,
actionID string,
fields logtrace.Fields,
) (adaptors.DecodeResponse, error) {
// Ensure base context fields are present for logs
// 1) Send ALL keys with a minimum required count (requiredSymbolPercent).
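// 2) If decode fails, escalate by retrying with ALL keys and a larger required count.
// NOTE(review): the escalation call passes reqCount*2, not total — confirm which is intended.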
func (task *CascadeRegistrationTask) retrieveAndDecodeProgressively(ctx context.Context, layout codec.Layout, actionID string,
fields logtrace.Fields) (adaptors.DecodeResponse, error) {
if fields == nil {
fields = logtrace.Fields{}
}
Expand All @@ -29,28 +23,27 @@ func (task *CascadeRegistrationTask) retrieveAndDecodeProgressively(
return adaptors.DecodeResponse{}, fmt.Errorf("empty layout: no blocks")
}

// Single-block fast path
// Single-block path
if len(layout.Blocks) == 1 {
blk := layout.Blocks[0]
total := len(blk.Symbols)
if total == 0 {
return adaptors.DecodeResponse{}, fmt.Errorf("empty layout: no symbols")
}

// Step 1: try with requiredSymbolPercent of symbols
// Step 1: send ALL keys, require only reqCount
reqCount := (total*requiredSymbolPercent + 99) / 100
if reqCount < 1 {
reqCount = 1
}
if reqCount > total {
} else if reqCount > total {
reqCount = total
}
fields["targetPercent"] = requiredSymbolPercent
fields["targetCount"] = reqCount
fields["total"] = total
logtrace.Info(ctx, "retrieving initial symbols (single block)", fields)

keys := blk.Symbols[:reqCount]
symbols, err := task.P2PClient.BatchRetrieve(ctx, keys, reqCount, actionID)
symbols, err := task.P2PClient.BatchRetrieve(ctx, blk.Symbols, reqCount, actionID)
if err != nil {
fields[logtrace.FieldError] = err.Error()
logtrace.Error(ctx, "failed to retrieve symbols", fields)
Expand All @@ -66,21 +59,24 @@ func (task *CascadeRegistrationTask) retrieveAndDecodeProgressively(
return decodeInfo, nil
}

// Step 2: escalate to all symbols
logtrace.Info(ctx, "initial decode failed; retrieving all symbols (single block)", nil)
symbols, err = task.P2PClient.BatchRetrieve(ctx, blk.Symbols, total, actionID)
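// Step 2: escalate by raising the required count.
// NOTE(review): requiredCount is logged as total, but the call below requires reqCount*2,
// which is below total for most layouts and exceeds it when total == 1 — confirm intent.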
fields["escalating"] = true
fields["requiredCount"] = total
logtrace.Info(ctx, "initial decode failed; retrieving all symbols (single block)", fields)

symbols, err = task.P2PClient.BatchRetrieve(ctx, blk.Symbols, reqCount*2, actionID)
if err != nil {
fields[logtrace.FieldError] = err.Error()
logtrace.Error(ctx, "failed to retrieve all symbols", fields)
return adaptors.DecodeResponse{}, fmt.Errorf("failed to retrieve symbols: %w", err)
}

return task.RQ.Decode(ctx, adaptors.DecodeRequest{
ActionID: actionID,
Symbols: symbols,
Layout: layout,
})
}

// Multi-block layouts are not supported by current policy
return adaptors.DecodeResponse{}, fmt.Errorf("unsupported layout: expected 1 block, found %d", len(layout.Blocks))
}