From a5b80f2e3dcdd5356c043b72e1706b6b869068a5 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 6 Nov 2025 15:22:16 -0800 Subject: [PATCH 1/2] Initial draft --- srtp/srtp.go | 51 +++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/srtp/srtp.go b/srtp/srtp.go index 09de8e6..f5d271f 100644 --- a/srtp/srtp.go +++ b/srtp/srtp.go @@ -16,8 +16,10 @@ package srtp import ( "crypto/rand" + "errors" "fmt" "net" + "sync" prtp "github.com/pion/rtp" "github.com/pion/srtp/v3" @@ -99,6 +101,7 @@ type Profile struct { type Config = srtp.Config type ContextOption = srtp.ContextOption +type SessionOption = func(session *session) error type SessionKeys = srtp.SessionKeys // Expects a byte slice containing MKI value encoded in big-endian. @@ -107,17 +110,38 @@ func MasterKeyIndicator(mki []byte) ContextOption { return srtp.MasterKeyIndicator(mki) } -func NewSession(log logger.Logger, conn net.Conn, conf *Config) (rtp.Session, error) { +const ( + SessionExpiredError error = errors.New("SRTP profile key lifetime expired") +) + +// Lifetime, in packets encrypted, before session keys are invalidated +// Once limit is reached, WriteRTP will return SessionExpiredError +// Zero is the same as no limit +func WithLifetime(lifetime uint64) SessionOption { + return func(session *session) error { + session.lifetime = lifetime + return nil + } +} + +func NewSession(log logger.Logger, conn net.Conn, conf *Config, opts ...SessionOption) (rtp.Session, error) { s, err := srtp.NewSessionSRTP(conn, conf) if err != nil { return nil, err } - return &session{log: log, s: s}, nil + ses := &session{log: log, s: s} + for i, opt := range opts { + if err := opt(ses); err != nil { + return nil, fmt.Errorf("option %d: %w", i, err) + } + } + return ses, nil } type session struct { - log logger.Logger - s *srtp.SessionSRTP + log logger.Logger + s *srtp.SessionSRTP + lifetime uint64 } func (s *session) OpenWriteStream() (rtp.WriteStream, error) { @@ -125,7 +149,7 @@ func (s *session) OpenWriteStream() (rtp.WriteStream, error) { if err != nil { return nil, err } - return writeStream{w: w}, nil + return writeStream{w: w, lifetimeSet: s.lifetime > 0, lifetimeRemaining: s.lifetime}, nil } func (s *session) AcceptStream() (rtp.ReadStream, uint32, error) { @@ -141,7 +165,11 @@ func (s *session) Close() error { } type writeStream struct { - w *srtp.WriteStreamSRTP + w *srtp.WriteStreamSRTP + lifetimeSet bool + + mu sync.RWMutex + lifetimeRemaining uint64 } func (w writeStream) String() string { @@ -149,6 +177,17 @@ func (w writeStream) String() string { } func (w writeStream) WriteRTP(h *prtp.Header, payload []byte) (int, error) { + if w.lifetimeSet { + w.mu.Lock() + remaining := w.lifetimeRemaining + if w.lifetimeRemaining > 0 { + w.lifetimeRemaining-- + } + w.mu.Unlock() + if remaining == 0 { + return 0, SessionExpiredError + } + } return w.w.WriteRTP(h, payload) } From e21a83dc9e3c00d1975a8ee75496275d8817cea3 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 6 Nov 2025 15:35:13 -0800 Subject: [PATCH 2/2] Respecting lifetime limits --- srtp/srtp.go | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/srtp/srtp.go b/srtp/srtp.go index f5d271f..636180d 100644 --- a/srtp/srtp.go +++ b/srtp/srtp.go @@ -110,9 +110,7 @@ func MasterKeyIndicator(mki []byte) ContextOption { return srtp.MasterKeyIndicator(mki) } -const ( - SessionExpiredError error = errors.New("SRTP profile key lifetime expired") -) +type SessionExpiredError error // Lifetime, in packets encrypted, before session keys are invalidated // Once limit is reached, WriteRTP will return SessionExpiredError @@ -149,7 +147,13 @@ func (s *session) OpenWriteStream() (rtp.WriteStream, error) { if err != nil { return nil, err } - return writeStream{w: w, lifetimeSet: s.lifetime > 0, lifetimeRemaining: s.lifetime}, nil + newLifetime := uint64(s.lifetime) + return writeStream{ + w: w, + lifetimeSet: s.lifetime > 0, + lifetimeRemaining: &newLifetime, + mu: &sync.RWMutex{}, + }, nil } func (s *session) AcceptStream() (rtp.ReadStream, uint32, error) { @@ -168,8 +172,8 @@ type writeStream struct { w *srtp.WriteStreamSRTP lifetimeSet bool - mu sync.RWMutex - lifetimeRemaining uint64 + mu *sync.RWMutex + lifetimeRemaining *uint64 } func (w writeStream) String() string { @@ -179,13 +183,13 @@ func (w writeStream) String() string { func (w writeStream) WriteRTP(h *prtp.Header, payload []byte) (int, error) { if w.lifetimeSet { w.mu.Lock() - remaining := w.lifetimeRemaining - if w.lifetimeRemaining > 0 { - w.lifetimeRemaining-- + remaining := *w.lifetimeRemaining + if remaining > 0 { + *w.lifetimeRemaining = remaining - 1 } w.mu.Unlock() if remaining == 0 { - return 0, SessionExpiredError + return 0, SessionExpiredError(errors.New("SRTP profile key lifetime expired")) } } return w.w.WriteRTP(h, payload)