Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions srtp/srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ package srtp

import (
"crypto/rand"
"errors"
"fmt"
"net"
"sync"

prtp "github.com/pion/rtp"
"github.com/pion/srtp/v3"
Expand Down Expand Up @@ -99,6 +101,7 @@ type Profile struct {

type Config = srtp.Config
type ContextOption = srtp.ContextOption
type SessionOption = func(session *session) error
type SessionKeys = srtp.SessionKeys

// Expects a byte slice containing MKI value encoded in big-endian.
Expand All @@ -107,25 +110,50 @@ func MasterKeyIndicator(mki []byte) ContextOption {
return srtp.MasterKeyIndicator(mki)
}

func NewSession(log logger.Logger, conn net.Conn, conf *Config) (rtp.Session, error) {
type SessionExpiredError error

// Lifetime, in packets encrypted, before session keys are invalidated
// Once limit is reached, WriteRTP will return SessionExpiredError
// Zero is the same as no limit
func WithLifetime(lifetime uint64) SessionOption {
return func(session *session) error {
session.lifetime = lifetime
return nil
}
}

func NewSession(log logger.Logger, conn net.Conn, conf *Config, opts ...SessionOption) (rtp.Session, error) {
s, err := srtp.NewSessionSRTP(conn, conf)
if err != nil {
return nil, err
}
return &session{log: log, s: s}, nil
ses := &session{log: log, s: s}
for i, opt := range opts {
if err := opt(ses); err != nil {
return nil, fmt.Errorf("option %d: %w", i, err)
}
}
return ses, nil
}

type session struct {
log logger.Logger
s *srtp.SessionSRTP
log logger.Logger
s *srtp.SessionSRTP
lifetime uint64
}

func (s *session) OpenWriteStream() (rtp.WriteStream, error) {
w, err := s.s.OpenWriteStream()
if err != nil {
return nil, err
}
return writeStream{w: w}, nil
newLifetime := uint64(s.lifetime)
return writeStream{
w: w,
lifetimeSet: s.lifetime > 0,
lifetimeRemaining: &newLifetime,
mu: &sync.RWMutex{},
}, nil
}

func (s *session) AcceptStream() (rtp.ReadStream, uint32, error) {
Expand All @@ -141,14 +169,29 @@ func (s *session) Close() error {
}

type writeStream struct {
w *srtp.WriteStreamSRTP
w *srtp.WriteStreamSRTP
lifetimeSet bool

mu *sync.RWMutex
lifetimeRemaining *uint64
}

func (w writeStream) String() string {

@alexlivekit alexlivekit Nov 6, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No clue why we're passing w writeStream by value and not by ref (w *writeStream). This is basically churning memory.

But this is why I had to do the ugly thing with the lock and *uint64.

return "SRTPWriteStream"
}

func (w writeStream) WriteRTP(h *prtp.Header, payload []byte) (int, error) {
if w.lifetimeSet {
w.mu.Lock()
remaining := *w.lifetimeRemaining
if remaining > 0 {
*w.lifetimeRemaining = remaining - 1
}
w.mu.Unlock()
if remaining == 0 {
return 0, SessionExpiredError(errors.New("SRTP profile key lifetime expired"))
}
}
return w.w.WriteRTP(h, payload)
}

Expand Down
Loading