diff --git a/cmd/test-streammanager/bridge.go b/cmd/test-streammanager/bridge.go new file mode 100644 index 0000000000..501adc3d32 --- /dev/null +++ b/cmd/test-streammanager/bridge.go @@ -0,0 +1,40 @@ +// Copyright 2026, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "fmt" + + "github.com/wavetermdev/waveterm/pkg/wshrpc" +) + +// WriterBridge - used by the writer broker +// Sends data to the pipe, receives acks from the pipe +type WriterBridge struct { + pipe *DeliveryPipe +} + +func (b *WriterBridge) StreamDataCommand(data wshrpc.CommandStreamData, opts *wshrpc.RpcOpts) error { + b.pipe.EnqueueData(data) + return nil +} + +func (b *WriterBridge) StreamDataAckCommand(ack wshrpc.CommandStreamAckData, opts *wshrpc.RpcOpts) error { + return fmt.Errorf("writer bridge should not send acks") +} + +// ReaderBridge - used by the reader broker +// Sends acks to the pipe, receives data from the pipe +type ReaderBridge struct { + pipe *DeliveryPipe +} + +func (b *ReaderBridge) StreamDataCommand(data wshrpc.CommandStreamData, opts *wshrpc.RpcOpts) error { + return fmt.Errorf("reader bridge should not send data") +} + +func (b *ReaderBridge) StreamDataAckCommand(ack wshrpc.CommandStreamAckData, opts *wshrpc.RpcOpts) error { + b.pipe.EnqueueAck(ack) + return nil +} diff --git a/cmd/test-streammanager/deliverypipe.go b/cmd/test-streammanager/deliverypipe.go new file mode 100644 index 0000000000..8f8451f45a --- /dev/null +++ b/cmd/test-streammanager/deliverypipe.go @@ -0,0 +1,249 @@ +// Copyright 2026, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "encoding/base64" + "math/rand" + "sort" + "sync" + "time" + + "github.com/wavetermdev/waveterm/pkg/wshrpc" +) + +type DeliveryConfig struct { + Delay time.Duration + Skew time.Duration +} + +type taggedPacket struct { + seq uint64 + deliveryTime time.Time + isData bool + dataPk wshrpc.CommandStreamData + ackPk wshrpc.CommandStreamAckData + dataSize int +} + +type DeliveryPipe struct { + lock sync.Mutex + config DeliveryConfig + + // Sequence counters (separate for data and ack) + dataSeq uint64 + ackSeq uint64 + + // Pending packets sorted by (deliveryTime, seq) + dataPending []taggedPacket + ackPending []taggedPacket + + // Delivery targets + dataTarget func(wshrpc.CommandStreamData) + ackTarget func(wshrpc.CommandStreamAckData) + + // Control + closed bool + wg sync.WaitGroup + + // Metrics + metrics *Metrics + lastDataSeqNum int64 + lastAckSeqNum int64 + + // Byte tracking for high water mark + currentBytes int64 +} + +func NewDeliveryPipe(config DeliveryConfig, metrics *Metrics) *DeliveryPipe { + return &DeliveryPipe{ + config: config, + metrics: metrics, + lastDataSeqNum: -1, + lastAckSeqNum: -1, + } +} + +func (dp *DeliveryPipe) SetDataTarget(fn func(wshrpc.CommandStreamData)) { + dp.lock.Lock() + defer dp.lock.Unlock() + dp.dataTarget = fn +} + +func (dp *DeliveryPipe) SetAckTarget(fn func(wshrpc.CommandStreamAckData)) { + dp.lock.Lock() + defer dp.lock.Unlock() + dp.ackTarget = fn +} + +func (dp *DeliveryPipe) EnqueueData(pkt wshrpc.CommandStreamData) { + dp.lock.Lock() + defer dp.lock.Unlock() + + if dp.closed { + return + } + + dataSize := base64.StdEncoding.DecodedLen(len(pkt.Data64)) + dp.dataSeq++ + tagged := taggedPacket{ + seq: dp.dataSeq, + deliveryTime: dp.computeDeliveryTime(), + isData: true, + dataPk: pkt, + dataSize: dataSize, + } + + dp.dataPending = append(dp.dataPending, tagged) + dp.sortPending(&dp.dataPending) + + dp.currentBytes += int64(dataSize) + if dp.metrics != nil { + dp.metrics.AddDataPacket() + dp.metrics.UpdatePipeHighWaterMark(dp.currentBytes) + } +} + +func (dp *DeliveryPipe) EnqueueAck(pkt wshrpc.CommandStreamAckData) { + dp.lock.Lock() + defer dp.lock.Unlock() + + if dp.closed { + return + } + + dp.ackSeq++ + tagged := taggedPacket{ + seq: dp.ackSeq, + deliveryTime: dp.computeDeliveryTime(), + isData: false, + ackPk: pkt, + } + + dp.ackPending = append(dp.ackPending, tagged) + dp.sortPending(&dp.ackPending) + + if dp.metrics != nil { + dp.metrics.AddAckPacket() + } +} + +func (dp *DeliveryPipe) computeDeliveryTime() time.Time { + base := time.Now().Add(dp.config.Delay) + + if dp.config.Skew == 0 { + return base + } + + // Random skew: -skew to +skew + skewNs := dp.config.Skew.Nanoseconds() + randomSkew := time.Duration(rand.Int63n(2*skewNs+1) - skewNs) + return base.Add(randomSkew) +} + +func (dp *DeliveryPipe) sortPending(pending *[]taggedPacket) { + sort.Slice(*pending, func(i, j int) bool { + pi, pj := (*pending)[i], (*pending)[j] + if pi.deliveryTime.Equal(pj.deliveryTime) { + return pi.seq < pj.seq + } + return pi.deliveryTime.Before(pj.deliveryTime) + }) +} + +func (dp *DeliveryPipe) Start() { + dp.wg.Add(2) + go dp.dataDeliveryLoop() + go dp.ackDeliveryLoop() +} + +func (dp *DeliveryPipe) dataDeliveryLoop() { + defer dp.wg.Done() + dp.deliveryLoop( + func() *[]taggedPacket { return &dp.dataPending }, + func(pkt taggedPacket) { + if dp.dataTarget != nil { + // Track out-of-order packets + if dp.metrics != nil && dp.lastDataSeqNum != -1 { + if pkt.dataPk.Seq < dp.lastDataSeqNum { + dp.metrics.AddOOOPacket() + } + } + dp.lastDataSeqNum = pkt.dataPk.Seq + dp.dataTarget(pkt.dataPk) + + dp.lock.Lock() + dp.currentBytes -= int64(pkt.dataSize) + dp.lock.Unlock() + } + }, + ) +} + +func (dp *DeliveryPipe) ackDeliveryLoop() { + defer dp.wg.Done() + dp.deliveryLoop( + func() *[]taggedPacket { return &dp.ackPending }, + func(pkt taggedPacket) { + if dp.ackTarget != nil { + // Track out-of-order acks + if dp.metrics != nil && dp.lastAckSeqNum != -1 { + if pkt.ackPk.Seq < dp.lastAckSeqNum { + dp.metrics.AddOOOPacket() + } + } + dp.lastAckSeqNum = pkt.ackPk.Seq + dp.ackTarget(pkt.ackPk) + } + }, + ) +} + +func (dp *DeliveryPipe) deliveryLoop( + getPending func() *[]taggedPacket, + deliver func(taggedPacket), +) { + for { + dp.lock.Lock() + if dp.closed { + dp.lock.Unlock() + return + } + + pending := getPending() + now := time.Now() + + // Find all packets ready for delivery (deliveryTime <= now) + readyCount := 0 + for _, pkt := range *pending { + if pkt.deliveryTime.After(now) { + break + } + readyCount++ + } + + // Extract ready packets + ready := make([]taggedPacket, readyCount) + copy(ready, (*pending)[:readyCount]) + *pending = (*pending)[readyCount:] + + dp.lock.Unlock() + + // Deliver all ready packets (outside lock) + for _, pkt := range ready { + deliver(pkt) + } + + // Always sleep 1ms - simple busy loop + time.Sleep(1 * time.Millisecond) + } +} + +func (dp *DeliveryPipe) Close() { + dp.lock.Lock() + dp.closed = true + dp.lock.Unlock() + + dp.wg.Wait() +} diff --git a/cmd/test-streammanager/generator.go b/cmd/test-streammanager/generator.go new file mode 100644 index 0000000000..5cfc92b4b3 --- /dev/null +++ b/cmd/test-streammanager/generator.go @@ -0,0 +1,40 @@ +// Copyright 2026, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "io" +) + +// Base64 charset: all printable, easy to inspect manually +const Base64Chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + +type TestDataGenerator struct { + totalBytes int64 + generated int64 +} + +func NewTestDataGenerator(totalBytes int64) *TestDataGenerator { + return &TestDataGenerator{totalBytes: totalBytes} +} + +func (g *TestDataGenerator) Read(p []byte) (n int, err error) { + if g.generated >= g.totalBytes { + return 0, io.EOF + } + + remaining := g.totalBytes - g.generated + toRead := int64(len(p)) + if toRead > remaining { + toRead = remaining + } + + // Sequential pattern using base64 chars (0-63 cycling) + for i := int64(0); i < toRead; i++ { + p[i] = Base64Chars[(g.generated+i)%64] + } + + g.generated += toRead + return int(toRead), nil +} diff --git a/cmd/test-streammanager/main-test-streammanager.go b/cmd/test-streammanager/main-test-streammanager.go new file mode 100644 index 0000000000..4e6702e790 --- /dev/null +++ b/cmd/test-streammanager/main-test-streammanager.go @@ -0,0 +1,254 @@ +// Copyright 2026, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "fmt" + "io" + "log" + "os" + "time" + + "github.com/spf13/cobra" + "github.com/wavetermdev/waveterm/pkg/jobmanager" + "github.com/wavetermdev/waveterm/pkg/streamclient" + "github.com/wavetermdev/waveterm/pkg/wshrpc" +) + +type TestConfig struct { + Mode string + DataSize int64 + Delay time.Duration + Skew time.Duration + WindowSize int + SlowReader int + Verbose bool +} + +var config TestConfig + +var rootCmd = &cobra.Command{ + Use: "test-streammanager", + Short: "Integration test for StreamManager streaming system", + RunE: func(cmd *cobra.Command, args []string) error { + return runTest(config) + }, +} + +func init() { + rootCmd.Flags().StringVar(&config.Mode, "mode", "streammanager", "Writer mode: 'streammanager' or 'writer'") + rootCmd.Flags().Int64Var(&config.DataSize, "size", 10*1024*1024, "Total data to transfer (bytes)") + rootCmd.Flags().DurationVar(&config.Delay, "delay", 0, "Base delivery delay (e.g., 10ms)") + rootCmd.Flags().DurationVar(&config.Skew, "skew", 0, "Delivery skew +/- (e.g., 5ms)") + rootCmd.Flags().IntVar(&config.WindowSize, "windowsize", 64*1024, "Window size for both sender and receiver") + rootCmd.Flags().IntVar(&config.SlowReader, "slowreader", 0, "Slow reader mode: bytes per second (0=disabled, e.g., 1024)") + rootCmd.Flags().BoolVar(&config.Verbose, "verbose", false, "Enable verbose logging") +} + +func main() { + if err := rootCmd.Execute(); err != nil { + os.Exit(1) + } +} + +func runTest(config TestConfig) error { + if config.Mode != "streammanager" && config.Mode != "writer" { + return fmt.Errorf("invalid mode: %s (must be 'streammanager' or 'writer')", config.Mode) + } + + fmt.Printf("Starting Streaming Integration Test\n") + fmt.Printf(" Mode: %s\n", config.Mode) + fmt.Printf(" Data Size: %d bytes\n", config.DataSize) + fmt.Printf(" Delay: %v, Skew: %v\n", config.Delay, config.Skew) + fmt.Printf(" Window Size: %d\n", config.WindowSize) + if config.SlowReader > 0 { + fmt.Printf(" Slow Reader: %d bytes/sec\n", config.SlowReader) + } + + // 1. Create metrics + metrics := NewMetrics() + + // 2. Create the delivery pipe + pipe := NewDeliveryPipe(DeliveryConfig{ + Delay: config.Delay, + Skew: config.Skew, + }, metrics) + + // 3. Create brokers with bridges + writerBridge := &WriterBridge{pipe: pipe} + readerBridge := &ReaderBridge{pipe: pipe} + + writerBroker := streamclient.NewBroker(writerBridge) + readerBroker := streamclient.NewBroker(readerBridge) + + // 4. Wire up delivery targets + pipe.SetDataTarget(readerBroker.RecvData) + pipe.SetAckTarget(writerBroker.RecvAck) + + // 5. Start the delivery pipe + pipe.Start() + + // 6. Create the reader side + reader, streamMeta := readerBroker.CreateStreamReader("reader-route", "writer-route", int64(config.WindowSize)) + + // 7. Set up writer side based on mode + var writerDone chan error + if config.Mode == "streammanager" { + writerDone = runStreamManagerMode(config, writerBroker, streamMeta) + } else { + writerDone = runWriterMode(config, writerBroker, streamMeta) + } + + // 8. Create verifier + verifier := NewVerifier(config.DataSize) + + // 9. Create metrics writer wrapper + metricsWriter := &MetricsWriter{ + writer: verifier, + metrics: metrics, + } + + // 10. Wrap reader with slow reader if configured + var actualReader io.Reader = reader + if config.SlowReader > 0 { + actualReader = NewSlowReader(reader, config.SlowReader) + } + + // 11. Start reading from stream reader and writing to verifier + metrics.Start() + + readerDone := make(chan error) + go func() { + _, err := io.Copy(metricsWriter, actualReader) + readerDone <- err + }() + + // 12. Wait for completion + var writerErr, readerErr error + if writerDone != nil { + writerErr = <-writerDone + } + readerErr = <-readerDone + metrics.End() + + // 13. Cleanup + pipe.Close() + writerBroker.Close() + readerBroker.Close() + + // 14. Report results + fmt.Println(metrics.Report()) + fmt.Printf("Verification: received=%d, mismatches=%d\n", + verifier.TotalReceived(), verifier.Mismatches()) + + if writerErr != nil && writerErr != io.EOF { + return fmt.Errorf("writer error: %w", writerErr) + } + + if readerErr != nil && readerErr != io.EOF { + return fmt.Errorf("reader error: %w", readerErr) + } + + if verifier.Mismatches() > 0 { + return fmt.Errorf("data corruption: %d mismatches, first at byte %d", + verifier.Mismatches(), verifier.FirstMismatch()) + } + + fmt.Println("TEST PASSED") + return nil +} + +func runStreamManagerMode(config TestConfig, writerBroker *streamclient.Broker, streamMeta *wshrpc.StreamMeta) chan error { + streamManager := jobmanager.MakeStreamManagerWithSizes(config.WindowSize, 2*1024*1024) + writerBroker.AttachStreamWriter(streamMeta, streamManager) + + dataSender := &BrokerDataSender{broker: writerBroker} + startSeq, err := streamManager.ClientConnected(streamMeta.Id, dataSender, config.WindowSize, 0) + if err != nil { + fmt.Printf("failed to connect stream manager: %v\n", err) + return nil + } + fmt.Printf(" Stream connected, startSeq: %d\n", startSeq) + + generator := NewTestDataGenerator(config.DataSize) + if err := streamManager.AttachReader(generator); err != nil { + fmt.Printf("failed to attach reader: %v\n", err) + return nil + } + + return nil +} + +func runWriterMode(config TestConfig, writerBroker *streamclient.Broker, streamMeta *wshrpc.StreamMeta) chan error { + writer, err := writerBroker.CreateStreamWriter(streamMeta) + if err != nil { + fmt.Printf("failed to create stream writer: %v\n", err) + return nil + } + fmt.Printf(" Stream writer created\n") + + generator := NewTestDataGenerator(config.DataSize) + + done := make(chan error, 1) + go func() { + _, copyErr := io.Copy(writer, generator) + closeErr := writer.Close() + if copyErr != nil && copyErr != io.EOF { + done <- copyErr + } else { + done <- closeErr + } + }() + + return done +} + +// BrokerDataSender implements DataSender interface +type BrokerDataSender struct { + broker *streamclient.Broker +} + +func (s *BrokerDataSender) SendData(dataPk wshrpc.CommandStreamData) { + s.broker.SendData(dataPk) +} + +// MetricsWriter wraps an io.Writer and records bytes written to metrics +type MetricsWriter struct { + writer io.Writer + metrics *Metrics +} + +func (mw *MetricsWriter) Write(p []byte) (n int, err error) { + n, err = mw.writer.Write(p) + if n > 0 { + mw.metrics.AddBytes(int64(n)) + } + return n, err +} + +// SlowReader wraps an io.Reader and rate-limits reads to a specified bytes/sec +type SlowReader struct { + reader io.Reader + bytesPerSec int +} + +func NewSlowReader(reader io.Reader, bytesPerSec int) *SlowReader { + return &SlowReader{ + reader: reader, + bytesPerSec: bytesPerSec, + } +} + +func (sr *SlowReader) Read(p []byte) (n int, err error) { + time.Sleep(1 * time.Second) + + readSize := sr.bytesPerSec + if readSize > len(p) { + readSize = len(p) + } + + n, err = sr.reader.Read(p[:readSize]) + log.Printf("SlowReader: read %d bytes, err=%v", n, err) + return n, err +} diff --git a/cmd/test-streammanager/metrics.go b/cmd/test-streammanager/metrics.go new file mode 100644 index 0000000000..94b4f4169b --- /dev/null +++ b/cmd/test-streammanager/metrics.go @@ -0,0 +1,110 @@ +// Copyright 2026, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "fmt" + "sync" + "time" +) + +type Metrics struct { + lock sync.Mutex + + // Timing + startTime time.Time + endTime time.Time + + // Data transfer + totalBytes int64 + + // Packet counts + dataPackets int64 + ackPackets int64 + + // Out of order tracking + oooPackets int64 + + // High water mark for pipe bytes + pipeHighWaterMark int64 +} + +func NewMetrics() *Metrics { + return &Metrics{} +} + +func (m *Metrics) Start() { + m.lock.Lock() + defer m.lock.Unlock() + m.startTime = time.Now() +} + +func (m *Metrics) End() { + m.lock.Lock() + defer m.lock.Unlock() + m.endTime = time.Now() +} + +func (m *Metrics) AddDataPacket() { + m.lock.Lock() + defer m.lock.Unlock() + m.dataPackets++ +} + +func (m *Metrics) AddAckPacket() { + m.lock.Lock() + defer m.lock.Unlock() + m.ackPackets++ +} + +func (m *Metrics) AddOOOPacket() { + m.lock.Lock() + defer m.lock.Unlock() + m.oooPackets++ +} + +func (m *Metrics) AddBytes(n int64) { + m.lock.Lock() + defer m.lock.Unlock() + m.totalBytes += n +} + +func (m *Metrics) UpdatePipeHighWaterMark(currentBytes int64) { + m.lock.Lock() + defer m.lock.Unlock() + if currentBytes > m.pipeHighWaterMark { + m.pipeHighWaterMark = currentBytes + } +} + +func (m *Metrics) GetPipeHighWaterMark() int64 { + m.lock.Lock() + defer m.lock.Unlock() + return m.pipeHighWaterMark +} + +func (m *Metrics) Report() string { + m.lock.Lock() + defer m.lock.Unlock() + + duration := m.endTime.Sub(m.startTime) + durationSecs := duration.Seconds() + if durationSecs == 0 { + durationSecs = 1.0 + } + throughput := float64(m.totalBytes) / durationSecs / 1024 / 1024 + + return fmt.Sprintf(` +StreamManager Integration Test Results +====================================== +Duration: %v +Total Bytes: %d +Throughput: %.2f MB/s +Data Packets: %d +Ack Packets: %d +OOO Packets: %d +Pipe High Water: %d bytes (%.2f KB) +`, duration, m.totalBytes, throughput, m.dataPackets, m.ackPackets, m.oooPackets, + m.pipeHighWaterMark, float64(m.pipeHighWaterMark)/1024) +} diff --git a/cmd/test-streammanager/verifier.go b/cmd/test-streammanager/verifier.go new file mode 100644 index 0000000000..e6abe518a5 --- /dev/null +++ b/cmd/test-streammanager/verifier.go @@ -0,0 +1,63 @@ +// Copyright 2026, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "sync" +) + +type Verifier struct { + lock sync.Mutex + expectedGen *TestDataGenerator + totalReceived int64 + mismatches int + firstMismatch int64 +} + +func NewVerifier(totalBytes int64) *Verifier { + return &Verifier{ + expectedGen: NewTestDataGenerator(totalBytes), + firstMismatch: -1, + } +} + +func (v *Verifier) Write(p []byte) (n int, err error) { + v.lock.Lock() + defer v.lock.Unlock() + + expected := make([]byte, len(p)) + // expectedGen.Read() error ignored: TestDataGenerator is deterministic and won't fail, + // and any data length mismatch will be caught by byte comparison below + v.expectedGen.Read(expected) + + for i := 0; i < len(p); i++ { + if p[i] != expected[i] { + v.mismatches++ + if v.firstMismatch == -1 { + v.firstMismatch = v.totalReceived + int64(i) + } + } + } + + v.totalReceived += int64(len(p)) + return len(p), nil +} + +func (v *Verifier) TotalReceived() int64 { + v.lock.Lock() + defer v.lock.Unlock() + return v.totalReceived +} + +func (v *Verifier) Mismatches() int { + v.lock.Lock() + defer v.lock.Unlock() + return v.mismatches +} + +func (v *Verifier) FirstMismatch() int64 { + v.lock.Lock() + defer v.lock.Unlock() + return v.firstMismatch +} diff --git a/pkg/jobmanager/cirbuf.go b/pkg/jobmanager/cirbuf.go index fae4063b85..e3a5e415a9 100644 --- a/pkg/jobmanager/cirbuf.go +++ b/pkg/jobmanager/cirbuf.go @@ -4,7 +4,6 @@ package jobmanager import ( - "context" "fmt" "sync" ) @@ -63,46 +62,12 @@ func (cb *CirBuf) SetEffectiveWindow(syncMode bool, windowSize int) { } } -// Write will never block if syncMode is false -// If syncMode is true, write will block until enough data is consumed to allow the write to finish -// to cancel a write in progress use WriteCtx -func (cb *CirBuf) Write(data []byte) (int, error) { - return cb.WriteCtx(context.Background(), data) -} - -// WriteCtx writes data to the circular buffer with context support for cancellation. -// In sync mode, blocks when buffer is full until space is available or context is cancelled. -// Returns partial byte count and context error if cancelled mid-write. +// WriteAvailable attempts to write as much data as possible without blocking. +// Returns the number of bytes written and a channel to wait on if buffer is full (nil if not blocking). +// In sync mode when buffer is full, returns 0 written and a channel that will be closed when space is available. +// The caller should wait on the channel and retry the write. // NOTE: Only one concurrent blocked write is allowed. Multiple blocked writes will panic. -func (cb *CirBuf) WriteCtx(ctx context.Context, data []byte) (int, error) { - if len(data) == 0 { - return 0, nil - } - - bytesWritten := 0 - for bytesWritten < len(data) { - if err := ctx.Err(); err != nil { - return bytesWritten, err - } - - n, spaceAvailable := cb.writeAvailable(data[bytesWritten:]) - bytesWritten += n - - if spaceAvailable != nil { - select { - case <-spaceAvailable: - continue - case <-ctx.Done(): - tryReadCh(cb.waiterChan) - return bytesWritten, ctx.Err() - } - } - } - - return bytesWritten, nil -} - -func (cb *CirBuf) writeAvailable(data []byte) (int, chan struct{}) { +func (cb *CirBuf) WriteAvailable(data []byte) (int, <-chan struct{}) { cb.lock.Lock() defer cb.lock.Unlock() @@ -111,11 +76,14 @@ func (cb *CirBuf) writeAvailable(data []byte) (int, chan struct{}) { for i := 0; i < len(data); i++ { if cb.syncMode && cb.count >= cb.windowSize { + if written > 0 { + return written, nil + } spaceAvailable := make(chan struct{}) if !tryWriteCh(cb.waiterChan, spaceAvailable) { panic("CirBuf: multiple concurrent blocked writes not allowed") } - return written, spaceAvailable + return 0, spaceAvailable } cb.buf[cb.writePos] = data[i] diff --git a/pkg/jobmanager/streammanager.go b/pkg/jobmanager/streammanager.go index 8af2d64d2d..4d77ed5acc 100644 --- a/pkg/jobmanager/streammanager.go +++ b/pkg/jobmanager/streammanager.go @@ -53,6 +53,10 @@ type StreamManager struct { sentNotAcked int64 terminalEventSent bool + // track max acked to handle out-of-order ACKs (reset on disconnect) + maxAckedSeq int64 + maxAckedRwnd int64 + // terminal state - once true, stream is complete terminalEventAcked bool closed bool @@ -174,6 +178,8 @@ func (sm *StreamManager) ClientDisconnected() { sm.connected = false sm.dataSender = nil sm.sentNotAcked = 0 + sm.maxAckedSeq = 0 + sm.maxAckedRwnd = 0 if !sm.terminalEventAcked { sm.terminalEventSent = false } @@ -198,6 +204,19 @@ func (sm *StreamManager) RecvAck(ackPk wshrpc.CommandStreamAckData) { } seq := ackPk.Seq + rwnd := ackPk.RWnd + + // Ignore stale ACKs using tuple comparison (seq, rwnd) + if seq < sm.maxAckedSeq || (seq == sm.maxAckedSeq && rwnd <= sm.maxAckedRwnd) { + // log.Printf("streammanager ignoring stale ACK: seq=%d rwnd=%d (max: seq=%d rwnd=%d)", + // seq, rwnd, sm.maxAckedSeq, sm.maxAckedRwnd) + return + } + + // Update max acked tuple + sm.maxAckedSeq = seq + sm.maxAckedRwnd = rwnd + headPos := sm.buf.HeadPos() if seq < headPos { return @@ -287,11 +306,20 @@ func (sm *StreamManager) readLoop() { } func (sm *StreamManager) handleReadData(data []byte) { - sm.buf.Write(data) - sm.lock.Lock() - defer sm.lock.Unlock() - if sm.connected { - sm.drainCond.Signal() + offset := 0 + for offset < len(data) { + n, waitCh := sm.buf.WriteAvailable(data[offset:]) + offset += n + + if n > 0 { + sm.lock.Lock() + sm.drainCond.Signal() + sm.lock.Unlock() + } + + if waitCh != nil { + <-waitCh + } } } diff --git a/pkg/streamclient/streamwriter.go b/pkg/streamclient/streamwriter.go index 862f0c9cfb..730bbf3e16 100644 --- a/pkg/streamclient/streamwriter.go +++ b/pkg/streamclient/streamwriter.go @@ -22,7 +22,8 @@ type Writer struct { nextSeq int64 buffer []byte sentNotAcked int64 - lastAckedSeq int64 + maxAckedSeq int64 + maxAckedRwnd int64 finAcked bool canceled bool canceledChan chan struct{} @@ -38,7 +39,7 @@ func NewWriter(id string, readWindow int64, dataSender DataSender) *Writer { dataSender: dataSender, nextSeq: 0, sentNotAcked: 0, - lastAckedSeq: 0, + maxAckedSeq: 0, canceledChan: make(chan struct{}), } w.cond = sync.NewCond(&w.lock) @@ -54,12 +55,12 @@ func (w *Writer) RecvAck(ackPk wshrpc.CommandStreamAckData) { } ackedSeq := ackPk.Seq - if ackedSeq > w.lastAckedSeq { - w.lastAckedSeq = ackedSeq - } + rwnd := ackPk.RWnd if ackPk.Fin { w.finAcked = true + w.maxAckedSeq = ackedSeq + return } if ackPk.Cancel && !w.canceled { @@ -72,6 +73,15 @@ func (w *Writer) RecvAck(ackPk wshrpc.CommandStreamAckData) { return } + // Ignore stale ACKs using tuple comparison (seq, rwnd) + if ackedSeq < w.maxAckedSeq || (ackedSeq == w.maxAckedSeq && rwnd <= w.maxAckedRwnd) { + return + } + + // Update max acked tuple + w.maxAckedSeq = ackedSeq + w.maxAckedRwnd = rwnd + if !w.closed { if ackedSeq > (w.nextSeq - w.sentNotAcked) { ackedBytes := ackedSeq - (w.nextSeq - w.sentNotAcked) @@ -81,16 +91,16 @@ func (w *Writer) RecvAck(ackPk wshrpc.CommandStreamAckData) { } } - w.readWindow = ackPk.RWnd + w.readWindow = rwnd w.cond.Broadcast() } } -func (w *Writer) GetAckState() (lastAckedSeq int64, finAcked bool, canceled bool) { +func (w *Writer) GetAckState() (maxAckedSeq int64, finAcked bool, canceled bool) { w.lock.Lock() defer w.lock.Unlock() - return w.lastAckedSeq, w.finAcked, w.canceled + return w.maxAckedSeq, w.finAcked, w.canceled } func (w *Writer) GetCanceledChan() <-chan struct{} {