diff --git a/examples/realtime_encoder_integration.go b/examples/realtime_encoder_integration.go index 9b74ef5..fcef0bc 100644 --- a/examples/realtime_encoder_integration.go +++ b/examples/realtime_encoder_integration.go @@ -57,7 +57,7 @@ func main() { // Create RTCSender with BWE capabilities rtcSender, err := sender.NewRTCSender( sender.DefaultInterceptors(), - sender.GCC(initialBitrate), // Initial bitrate 500 kbps + sender.GCC(initialBitrate, 0), // Initial bitrate 500 kbps, no max cap sender.SetLoggerFactory(loggerFactory), ) if err != nil { diff --git a/sender/option.go b/sender/option.go index 911beac..6158556 100644 --- a/sender/option.go +++ b/sender/option.go @@ -75,11 +75,24 @@ func CCLogWriter(w io.Writer) Option { } } -// GCC returns an Option that configures Google Congestion Control with the specified initial bitrate. -func GCC(initialBitrate int) Option { +// GCC returns an Option that configures Google Congestion Control with the +// specified initial bitrate and max bitrate (in bps). A maxBitrate of 0 means +// no cap (uses GCC default of 50 Mbps). +func GCC(initialBitrate, maxBitrate int) Option { return func(sender ConfigurableWebRTCSender) error { + if rtcSender, ok := sender.(*RTCSender); ok { + rtcSender.gccConfigured = true + + return rtcSender.setupGCC(initialBitrate, maxBitrate) + } + // Fallback for other ConfigurableWebRTCSender types. controller, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) { - return gcc.NewSendSideBWE(gcc.SendSideBWEInitialBitrate(initialBitrate)) + opts := []gcc.Option{gcc.SendSideBWEInitialBitrate(initialBitrate)} + if maxBitrate > 0 { + opts = append(opts, gcc.SendSideBWEMaxBitrate(maxBitrate)) + } + + return gcc.NewSendSideBWE(opts...) }) if err != nil { return err diff --git a/sender/option_test.go b/sender/option_test.go index c0705b7..989d7cd 100644 --- a/sender/option_test.go +++ b/sender/option_test.go @@ -98,7 +98,7 @@ func TestCCLogWriter(t *testing.T) { func TestGCC(t *testing.T) { initialBitrate := 1000000 - option := GCC(initialBitrate) + option := GCC(initialBitrate, 0) require.NotNil(t, option) // Test that option is a function @@ -107,6 +107,40 @@ func TestGCC(t *testing.T) { // Note: Full testing would require WebRTC setup, so we just test the option creation } +func TestGCC_WithMaxBitrate(t *testing.T) { + option := GCC(500_000, 1_500_000) + require.NotNil(t, option) + assert.IsType(t, Option(nil), option) +} + +func TestGCC_AppliedToMock(t *testing.T) { + // Exercises the fallback path for non-RTCSender types. + mock := &MockConfigurableWebRTCSender{ + mediaEngine: &webrtc.MediaEngine{}, + registry: &interceptor.Registry{}, + } + err := mock.mediaEngine.RegisterDefaultCodecs() + require.NoError(t, err) + + option := GCC(500_000, 0) + err = option(mock) + require.NoError(t, err) +} + +func TestGCC_AppliedToMockWithMaxBitrate(t *testing.T) { + // Exercises the fallback path with maxBitrate > 0. + mock := &MockConfigurableWebRTCSender{ + mediaEngine: &webrtc.MediaEngine{}, + registry: &interceptor.Registry{}, + } + err := mock.mediaEngine.RegisterDefaultCodecs() + require.NoError(t, err) + + option := GCC(500_000, 1_500_000) + err = option(mock) + require.NoError(t, err) +} + func TestSetLoggerFactory(t *testing.T) { loggerFactory := plogging.NewDefaultLoggerFactory() diff --git a/sender/rtc_sender.go b/sender/rtc_sender.go index 3f4e612..73266bc 100644 --- a/sender/rtc_sender.go +++ b/sender/rtc_sender.go @@ -118,6 +118,9 @@ type RTCSender struct { // Logging ccLogWriter io.Writer log logging.LeveledLogger + + // gccConfigured is true when GCC was set up via the GCC option. + gccConfigured bool } // SetOnEncodedFrame registers a callback invoked after each VP8 frame is @@ -148,31 +151,39 @@ func NewRTCSender(opts ...Option) (*RTCSender, error) { return nil, err } - // Set up GCC bandwidth estimation by default - if err := sender.setupGCC(1_000_000); err != nil { // Default initial bitrate: 1Mbps - return nil, err - } - // Register the stats interceptor so GetTrackStats can return RTP/RTCP // counters (PacketsSent, RoundTripTime) per track. if err := sender.setupStats(); err != nil { return nil, err } - // Apply options directly to RTCSender + // Apply options first (may include custom GCC config) for _, opt := range opts { if err := opt(sender); err != nil { return nil, err } } + // Set up default GCC only if no GCC option was provided + if !sender.gccConfigured { + if err := sender.setupGCC(1_000_000, 0); err != nil { // Default initial bitrate: 1Mbps, no max + return nil, err + } + } + return sender, nil } -// setupGCC sets up Google Congestion Control with the specified initial bitrate. -func (s *RTCSender) setupGCC(initialBitrate int) error { +// setupGCC sets up Google Congestion Control with the specified initial and max bitrate. +// A maxBitrate of 0 means no cap (uses GCC default of 50 Mbps). +func (s *RTCSender) setupGCC(initialBitrate, maxBitrate int) error { controller, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) { - return gcc.NewSendSideBWE(gcc.SendSideBWEInitialBitrate(initialBitrate)) + opts := []gcc.Option{gcc.SendSideBWEInitialBitrate(initialBitrate)} + if maxBitrate > 0 { + opts = append(opts, gcc.SendSideBWEMaxBitrate(maxBitrate)) + } + + return gcc.NewSendSideBWE(opts...) }) if err != nil { return err diff --git a/sender/rtc_sender_test.go b/sender/rtc_sender_test.go index cfdc581..680a2e9 100644 --- a/sender/rtc_sender_test.go +++ b/sender/rtc_sender_test.go @@ -66,6 +66,29 @@ func TestNewRTCSender(t *testing.T) { var _ ConfigurableWebRTCSender = sender } +func TestNewRTCSender_WithGCCOption(t *testing.T) { + // When GCC option is provided, the default setupGCC should be skipped. + sender, err := NewRTCSender(GCC(500_000, 0)) + require.NoError(t, err) + require.NotNil(t, sender) + assert.True(t, sender.gccConfigured) +} + +func TestNewRTCSender_WithGCCMaxBitrate(t *testing.T) { + sender, err := NewRTCSender(GCC(500_000, 1_500_000)) + require.NoError(t, err) + require.NotNil(t, sender) + assert.True(t, sender.gccConfigured) +} + +func TestNewRTCSender_DefaultGCC(t *testing.T) { + // Without GCC option, default GCC should be set up. + sender, err := NewRTCSender() + require.NoError(t, err) + require.NotNil(t, sender) + assert.False(t, sender.gccConfigured) +} + func TestVideoTrackInfo_Validation(t *testing.T) { tests := []struct { name string diff --git a/vnet/flow.go b/vnet/flow.go index f95720a..72fbee7 100644 --- a/vnet/flow.go +++ b/vnet/flow.go @@ -219,7 +219,7 @@ func createWebRTCSender( commonOpts := []sender.Option{ sender.SetVnet(leftVnet, []string{publicIPLeft}), sender.PacketLogWriter(loggers.rtpLogger, loggers.rtcpLogger), - sender.GCC(100_000), + sender.GCC(100_000, 0), sender.CCLogWriter(loggers.ccLogger), sender.SetLoggerFactory(loggerFactory), }