diff --git a/jitter/buffer.go b/jitter/buffer.go index 3defff9..888e4e3 100644 --- a/jitter/buffer.go +++ b/jitter/buffer.go @@ -35,6 +35,7 @@ type Buffer struct { initialized bool prevSN uint16 + prevTS uint32 head *packet tail *packet @@ -326,6 +327,7 @@ func (b *Buffer) popSample() []*rtp.Packet { func (b *Buffer) popHead() *packet { c := b.head b.prevSN = c.packet.SequenceNumber + b.prevTS = c.packet.Timestamp b.head = c.next if b.head == nil { b.tail = nil @@ -335,6 +337,14 @@ func (b *Buffer) popHead() *packet { return c } +func (b *Buffer) LastSequenceNumber() uint16 { + return b.prevSN +} + +func (b *Buffer) LastTimestamp() uint32 { + return b.prevTS +} + func before(a, b uint16) bool { return (b-a)&0x8000 == 0 } diff --git a/opus/opus.go b/opus/opus.go index 4b2af7c..4e1e593 100644 --- a/opus/opus.go +++ b/opus/opus.go @@ -25,7 +25,6 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/media-sdk" - "github.com/livekit/media-sdk/rtp" "github.com/livekit/media-sdk/webm" ) @@ -72,7 +71,7 @@ func Encode(w Writer, channels int, logger logger.Logger) (media.PCM16Writer, er return &encoder{ w: w, enc: enc, - buf: make([]byte, w.SampleRate()/rtp.DefFramesPerSec*channels), + buf: make([]byte, w.SampleRate()/media.DefFramesPerSec*channels), logger: logger, }, nil } @@ -132,10 +131,26 @@ func (d *decoder) WriteSample(in Sample) error { media.StereoToMono(d.buf2, returnData) returnData = d.buf2[:n2] } - return d.w.WriteSample(returnData) } +// If FEC data is not available, it falls back to PLC automatically +func (d *decoder) DecodeFEC(data []byte, pcm []int16) error { + if d.dec == nil { + return fmt.Errorf("decoder not initialized") + } + + return d.dec.DecodeFEC(data, pcm) +} + +func (d *decoder) DecodePLC(pcm []int16) error { + if d.dec == nil { + return fmt.Errorf("decoder not initialized") + } + + return d.dec.DecodePLC(pcm) +} + func (d *decoder) resetForSample(in Sample) (int, error) { channels := int(C.opus_packet_get_nb_channels((*C.uchar)(&in[0]))) @@ -147,7 +162,7 @@ func (d *decoder) resetForSample(in Sample) (int, error) { } d.dec = dec - d.buf = make([]int16, d.w.SampleRate()/rtp.DefFramesPerSec*channels) + d.buf = make([]int16, d.w.SampleRate()/media.DefFramesPerSec*channels) d.lastChannels = channels } diff --git a/opus/opus_jitter.go b/opus/opus_jitter.go new file mode 100644 index 0000000..39c66d2 --- /dev/null +++ b/opus/opus_jitter.go @@ -0,0 +1,163 @@ +package opus + +import ( + "time" + + "github.com/livekit/media-sdk" + "github.com/livekit/media-sdk/jitter" + "github.com/livekit/media-sdk/rtp" + "github.com/livekit/protocol/logger" +) + +const ( + opusJitterMaxLatency = 60 * time.Millisecond + opusDTXFrameLength = 1 +) + +func HandleOpusJitter(h rtp.Handler, pcmWriter media.PCM16Writer, targetChannels int) rtp.Handler { + handler := &opusJitterHandler{ + h: h, + err: make(chan error, 1), + logger: logger.GetLogger(), + } + + dec, err := Decode(pcmWriter, targetChannels, handler.logger) + if err != nil { + handler.err <- err + return handler + } + handler.decoder = dec.(*decoder) + + handler.buf = jitter.NewBuffer( + rtp.AudioDepacketizer{}, + opusJitterMaxLatency, + func(packets []*rtp.Packet) { + for _, p := range packets { + handler.handleRTP(p) + } + }, + jitter.WithPacketLossHandler(func() { + handler.pendingLoss = true + }), + ) + + return handler +} + +type opusJitterHandler struct { + h rtp.Handler + buf *jitter.Buffer + decoder *decoder + err chan error + logger logger.Logger + nextPacket *rtp.Packet + lastPacket *rtp.Packet + pendingLoss bool +} + +func (r *opusJitterHandler) String() string { + return "OpusJitter -> " + r.h.String() +} + +func (r *opusJitterHandler) HandleRTP(h *rtp.Header, payload []byte) error { + r.buf.Push(&rtp.Packet{Header: *h, Payload: payload}) + select { + case err := <-r.err: + return err + default: + return nil + } +} + +func (r *opusJitterHandler) handleRTP(p *rtp.Packet) { + isDtx := len(p.Payload) == opusDTXFrameLength + + // Not sure what to do if we have a pending loss and the packet is DTX. + if r.pendingLoss && !isDtx { + // Store the next packet for FEC + r.nextPacket = p + r.handlePacketLoss() + r.pendingLoss = false + } + + if r.lastPacket != nil && (isDtx || len(r.lastPacket.Payload) == opusDTXFrameLength) { + silenceSamples := int(p.Timestamp - r.lastPacket.Timestamp) + if silenceSamples > 0 { + silenceBuf := make([]int16, silenceSamples*r.decoder.targetChannels) + if err := r.decoder.w.WriteSample(silenceBuf); err != nil { + r.logger.Warnw("failed to write silence", err) + } + } + + if isDtx { + r.lastPacket = p + return + } + } + + if err := r.decoder.WriteSample(p.Payload); err != nil { + r.logger.Warnw("failed to decode packet", err) + } + + r.lastPacket = p +} + +func (r *opusJitterHandler) handlePacketLoss() { + if r.decoder == nil || r.nextPacket == nil { + return + } + + lostPackets := int(r.nextPacket.SequenceNumber - r.buf.LastSequenceNumber() - 1) + if lostPackets <= 0 { + return + } + + lastTs := r.buf.LastTimestamp() + nextTs := r.nextPacket.Timestamp + + totalSamples := int(nextTs - lastTs) + if totalSamples <= 0 { + return + } + + samplesPerPacket := totalSamples / lostPackets + + if lostPackets > 1 { + // For mono audio, if we call DecodePLC right after a + // SFU generated mono silence, the concealment might not be proper. + // But, we need to pass the buffer for the exact duration of the lost audio. + plcSamples := samplesPerPacket * (lostPackets - 1) * r.decoder.lastChannels + buf := make([]int16, plcSamples) + + err := r.decoder.DecodePLC(buf) + if err != nil { + r.logger.Warnw("failed to recover lost packets with PLC", err) + return + } + _ = r.decoder.w.WriteSample(buf) + } + + // Should we reset for the next packet before calling DecodeFEC? + // This will update the decoder's state for the next packet so it might help. + // But, it might also cause some issues if the next packet is SFU generated silence. + channels, err := r.decoder.resetForSample(r.nextPacket.Payload) + if err != nil { + r.logger.Warnw("failed to reset decoder for FEC", err) + return + } + + buf := make([]int16, samplesPerPacket*channels) + err = r.decoder.DecodeFEC(r.nextPacket.Payload, buf) + if err != nil { + r.logger.Warnw("failed to recover last lost packet with FEC", err) + return + } + _ = r.decoder.w.WriteSample(buf) +} + +func (r *opusJitterHandler) Close() error { + if r.decoder != nil { + return r.decoder.Close() + } + return nil +} diff --git a/rtp/jitter.go b/rtp/jitter.go index 804a81d..e161e3c 100644 --- a/rtp/jitter.go +++ b/rtp/jitter.go @@ -33,7 +33,7 @@ func HandleJitter(h Handler) Handler { } // Jitter buffer expects to be closed (to stop the timer), but handler interface doesn't allow it. // This should be fine, because GC can now collect timers and goroutines blocked on them if they are not referenced. - handler.buf = jitter.NewBuffer(audioDepacketizer{}, jitterMaxLatency, func(packets []*rtp.Packet) { + handler.buf = jitter.NewBuffer(AudioDepacketizer{}, jitterMaxLatency, func(packets []*rtp.Packet) { for _, p := range packets { handler.handleRTP(p) } @@ -73,16 +73,16 @@ func (r *jitterHandler) HandleRTP(h *rtp.Header, payload []byte) error { } } -type audioDepacketizer struct{} +type AudioDepacketizer struct{} -func (d audioDepacketizer) Unmarshal(packet []byte) ([]byte, error) { +func (d AudioDepacketizer) Unmarshal(packet []byte) ([]byte, error) { return packet, nil } -func (d audioDepacketizer) IsPartitionHead(payload []byte) bool { +func (d AudioDepacketizer) IsPartitionHead(payload []byte) bool { return true } -func (d audioDepacketizer) IsPartitionTail(marker bool, payload []byte) bool { +func (d AudioDepacketizer) IsPartitionTail(marker bool, payload []byte) bool { return true }