diff --git a/srtp/srtp.go b/srtp/srtp.go index 09de8e6..636180d 100644 --- a/srtp/srtp.go +++ b/srtp/srtp.go @@ -16,8 +16,10 @@ package srtp import ( "crypto/rand" + "errors" "fmt" "net" + "sync" prtp "github.com/pion/rtp" "github.com/pion/srtp/v3" @@ -99,6 +101,7 @@ type Profile struct { type Config = srtp.Config type ContextOption = srtp.ContextOption +type SessionOption = func(session *session) error type SessionKeys = srtp.SessionKeys // Expects a byte slice containing MKI value encoded in big-endian. @@ -107,17 +110,36 @@ func MasterKeyIndicator(mki []byte) ContextOption { return srtp.MasterKeyIndicator(mki) } -func NewSession(log logger.Logger, conn net.Conn, conf *Config) (rtp.Session, error) { +type SessionExpiredError error + +// Lifetime, in packets encrypted, before session keys are invalidated +// Once limit is reached, WriteRTP will return SessionExpiredError +// Zero is the same as no limit +func WithLifetime(lifetime uint64) SessionOption { + return func(session *session) error { + session.lifetime = lifetime + return nil + } +} + +func NewSession(log logger.Logger, conn net.Conn, conf *Config, opts ...SessionOption) (rtp.Session, error) { s, err := srtp.NewSessionSRTP(conn, conf) if err != nil { return nil, err } - return &session{log: log, s: s}, nil + ses := &session{log: log, s: s} + for i, opt := range opts { + if err := opt(ses); err != nil { + return nil, fmt.Errorf("option %d: %w", i, err) + } + } + return ses, nil } type session struct { - log logger.Logger - s *srtp.SessionSRTP + log logger.Logger + s *srtp.SessionSRTP + lifetime uint64 } func (s *session) OpenWriteStream() (rtp.WriteStream, error) { @@ -125,7 +147,13 @@ func (s *session) OpenWriteStream() (rtp.WriteStream, error) { if err != nil { return nil, err } - return writeStream{w: w}, nil + newLifetime := uint64(s.lifetime) + return writeStream{ + w: w, + lifetimeSet: s.lifetime > 0, + lifetimeRemaining: &newLifetime, + mu: &sync.RWMutex{}, + }, nil } func (s *session) AcceptStream() (rtp.ReadStream, uint32, error) { @@ -141,7 +169,11 @@ func (s *session) Close() error { } type writeStream struct { - w *srtp.WriteStreamSRTP + w *srtp.WriteStreamSRTP + lifetimeSet bool + + mu *sync.RWMutex + lifetimeRemaining *uint64 } func (w writeStream) String() string { @@ -149,6 +181,17 @@ func (w writeStream) String() string { } func (w writeStream) WriteRTP(h *prtp.Header, payload []byte) (int, error) { + if w.lifetimeSet { + w.mu.Lock() + remaining := *w.lifetimeRemaining + if remaining > 0 { + *w.lifetimeRemaining = remaining - 1 + } + w.mu.Unlock() + if remaining == 0 { + return 0, SessionExpiredError(errors.New("SRTP profile key lifetime expired")) + } + } return w.w.WriteRTP(h, payload) }