diff --git a/jitter/buffer.go b/jitter/buffer.go index fa0864e..5c95579 100644 --- a/jitter/buffer.go +++ b/jitter/buffer.go @@ -15,6 +15,7 @@ package jitter import ( + "math" "sync" "time" @@ -43,6 +44,7 @@ type Buffer struct { initialized bool prevSN uint16 + prevTS uint32 head *packet tail *packet @@ -51,6 +53,9 @@ type Buffer struct { pool *packet size int + + maxSequenceJump *uint16 + maxTimestampJump *uint32 } type Option func(*Buffer) @@ -112,6 +117,18 @@ func WithPacketLossHandler(handler func()) Option { } } +func WithMaxSequenceJump(max uint16) Option { + return func(b *Buffer) { + b.maxSequenceJump = &max + } +} + +func WithMaxTimestampJump(max uint32) Option { + return func(b *Buffer) { + b.maxTimestampJump = &max + } +} + func (b *Buffer) WithLogger(logger logger.Logger) *Buffer { b.logger = logger return b @@ -187,6 +204,86 @@ func (b *Buffer) Close() { b.closed.Break() } +func (b *Buffer) isLargeTimestampJump(current, prev uint32) bool { + if b.maxTimestampJump == nil || !b.initialized { + return false + } + + maxJump := uint32(*b.maxTimestampJump) + + cur := int64(current) + prv := int64(prev) + + forwardDiff := cur - prv + if forwardDiff < 0 { + forwardDiff += int64(math.MaxUint32) + 1 + } + + backwardDiff := prv - cur + if backwardDiff < 0 { + backwardDiff += int64(math.MaxUint32) + 1 + } + + return min(backwardDiff, forwardDiff) > int64(maxJump) +} + +func (b *Buffer) isLargeSequenceJump(current, prev uint16) bool { + if b.maxSequenceJump == nil || !b.initialized { + return false + } + + maxJump := int32(*b.maxSequenceJump) + + cur := int32(current) + prv := int32(prev) + + forwardDiff := cur - prv + if forwardDiff < 0 { + forwardDiff += int32(math.MaxUint16) + 1 + } + + backwardDiff := prv - cur + if backwardDiff < 0 { + backwardDiff += int32(math.MaxUint16) + 1 + } + + return min(backwardDiff, forwardDiff) > maxJump +} + +func (b *Buffer) reset() { + b.logger.Infow("resetting jitter buffer due to RTP discontinuity") + + dropped := 0 + for b.head != nil { + next := b.head.next + if !b.head.extPacket.Padding { + dropped++ + } + b.free(b.head) + b.head = next + } + b.tail = nil + + if dropped > 0 { + b.stats.PacketsDropped += uint64(dropped) + if b.onPacketLoss != nil { + b.onPacketLoss() + } + } + + b.initialized = false + b.prevSN = 0 + b.prevTS = 0 + + if !b.timer.Stop() { + select { + case <-b.timer.C: + default: + } + } + b.timer.Reset(b.latency) +} + // push adds a packet to the buffer func (b *Buffer) push(pkt *rtp.Packet, receivedAt time.Time) { b.stats.PacketsPushed++ @@ -197,8 +294,19 @@ func (b *Buffer) push(pkt *rtp.Packet, receivedAt time.Time) { } } + if b.isLargeTimestampJump(pkt.Timestamp, b.prevTS) || + b.isLargeSequenceJump(pkt.SequenceNumber, b.prevSN) { + b.logger.Infow("large RTP discontinuity detected", + "current_ts", pkt.Timestamp, + "prev_ts", b.prevTS, + "current_sn", pkt.SequenceNumber, + "prev_sn", b.prevSN, + ) + b.reset() + } + if b.initialized && before(pkt.SequenceNumber, b.prevSN) { - // packet expired + // packet expired (not after discontinuity reset) if !pkt.Padding { b.stats.PacketsDropped++ if b.onPacketLoss != nil { @@ -354,6 +462,7 @@ func (b *Buffer) popSample() []ExtPacket { func (b *Buffer) popHead() *packet { c := b.head b.prevSN = c.extPacket.SequenceNumber + b.prevTS = c.extPacket.Timestamp b.head = c.next if b.head == nil { b.tail = nil diff --git a/jitter/buffer_test.go b/jitter/buffer_test.go index aa73735..0cdc9c4 100644 --- a/jitter/buffer_test.go +++ b/jitter/buffer_test.go @@ -232,6 +232,80 @@ func TestDroppedPackets(t *testing.T) { }) } +func TestLargeSequenceJump(t *testing.T) { + out := make(chan []ExtPacket, 100) + b := NewBuffer(&testDepacketizer{}, testBufferLatency, chanFunc(t, out), WithMaxSequenceJump(1000)) + s := newTestStream() + + // push some normal packets + for i := 0; i < 10; i++ { + b.Push(s.gen(true, true)) + checkSample(t, out, 1) + } + + // simulate large sequence jump (should trigger reset) + s.largeSeqJump() + + // buffer should reset and accept new packets + for i := 0; i < 10; i++ { + b.Push(s.gen(true, true)) + checkSample(t, out, 1) + } + + stats := b.Stats() + require.Equal(t, uint64(20), stats.PacketsPushed) + require.Equal(t, uint64(20), stats.PacketsPopped) + require.Equal(t, uint64(20), stats.SamplesPopped) +} + +func TestLargeTimestampJump(t *testing.T) { + out := make(chan []ExtPacket, 100) + b := NewBuffer(&testDepacketizer{}, testBufferLatency, chanFunc(t, out), WithMaxTimestampJump(48000*30)) + s := ×tampStream{ + seq: uint16(rand.Uint32()), + ts: uint32(rand.Uint32()), + } + + // push some normal packets + for i := 0; i < 10; i++ { + b.Push(s.gen(true, true)) + checkSample(t, out, 1) + } + + // simulate large timestamp jump (should trigger reset) + s.largeTimestampJump() + + // buffer should reset and accept new packets + for i := 0; i < 10; i++ { + b.Push(s.gen(true, true)) + checkSample(t, out, 1) + } + + stats := b.Stats() + require.Equal(t, uint64(20), stats.PacketsPushed) + require.Equal(t, uint64(20), stats.PacketsPopped) + require.Equal(t, uint64(20), stats.SamplesPopped) +} + +func TestSequenceWraparound(t *testing.T) { + out := make(chan []ExtPacket, 100) + b := NewBuffer(&testDepacketizer{}, testBufferLatency, chanFunc(t, out)) + s := &stream{ + seq: math.MaxUint16 - 5, // start near wrap point + } + + // push packets across wraparound boundary + for i := 0; i < 15; i++ { + b.Push(s.gen(true, true)) + checkSample(t, out, 1) + } + + stats := b.Stats() + require.Equal(t, uint64(15), stats.PacketsPushed) + require.Equal(t, uint64(15), stats.PacketsPopped) + require.Equal(t, uint64(0), stats.PacketsLost) +} + func checkSample(t *testing.T, out chan []ExtPacket, expected int) { select { case sample := <-out: @@ -285,6 +359,36 @@ func (s *stream) discont() { s.seq += math.MaxUint16 / 2 } +func (s *stream) largeSeqJump() { + s.seq += 2000 // more than MAX_SEQUENCE_JUMP (1000) +} + +type timestampStream struct { + seq uint16 + ts uint32 +} + +func (s *timestampStream) gen(head, tail bool) *rtp.Packet { + p := &rtp.Packet{ + Header: rtp.Header{ + Marker: tail, + SequenceNumber: s.seq, + Timestamp: s.ts, + }, + Payload: make([]byte, defaultPacketSize), + } + if head { + copy(p.Payload, headerBytes) + } + s.seq++ + s.ts += 960 // typical increment for 20ms at 48kHz + return p +} + +func (s *timestampStream) largeTimestampJump() { + s.ts += 8000 * 60 // 60 seconds worth of samples (more than MAX_TIMESTAMP_JUMP) +} + const defaultPacketSize = 200 var headerBytes = []byte{0xaa, 0xaa}