diff --git a/global/config_test.go b/global/config_test.go index 31314e5bf3..77421274bb 100644 --- a/global/config_test.go +++ b/global/config_test.go @@ -18,6 +18,7 @@ package global import ( + "encoding/json" "reflect" "testing" ) @@ -1160,13 +1161,22 @@ func TestDefaultTripleConfig(t *testing.T) { func TestHttp3ConfigClone(t *testing.T) { t.Run("clone_http3_config", func(t *testing.T) { http3 := &Http3Config{ - Enable: true, - Negotiation: false, + Enable: true, + Negotiation: false, + KeepAlivePeriod: "15s", + MaxIdleTimeout: "30s", + MaxIncomingStreams: 128, + MaxIncomingUniStreams: 64, } cloned := http3.Clone() assert.NotNil(t, cloned) + assert.NotSame(t, http3, cloned) assert.Equal(t, http3.Enable, cloned.Enable) assert.Equal(t, http3.Negotiation, cloned.Negotiation) + assert.Equal(t, http3.KeepAlivePeriod, cloned.KeepAlivePeriod) + assert.Equal(t, http3.MaxIdleTimeout, cloned.MaxIdleTimeout) + assert.Equal(t, http3.MaxIncomingStreams, cloned.MaxIncomingStreams) + assert.Equal(t, http3.MaxIncomingUniStreams, cloned.MaxIncomingUniStreams) }) t.Run("clone_nil_http3_config", func(t *testing.T) { @@ -1179,8 +1189,13 @@ func TestHttp3ConfigClone(t *testing.T) { http3 := DefaultHttp3Config() cloned := http3.Clone() assert.NotNil(t, cloned) + assert.NotSame(t, http3, cloned) assert.Equal(t, http3.Enable, cloned.Enable) assert.Equal(t, http3.Negotiation, cloned.Negotiation) + assert.Equal(t, http3.KeepAlivePeriod, cloned.KeepAlivePeriod) + assert.Equal(t, http3.MaxIdleTimeout, cloned.MaxIdleTimeout) + assert.Equal(t, http3.MaxIncomingStreams, cloned.MaxIncomingStreams) + assert.Equal(t, http3.MaxIncomingUniStreams, cloned.MaxIncomingUniStreams) }) } @@ -1189,8 +1204,91 @@ func TestDefaultHttp3Config(t *testing.T) { t.Run("default_http3_config", func(t *testing.T) { http3 := DefaultHttp3Config() assert.NotNil(t, http3) + assert.False(t, http3.Enable) + assert.True(t, http3.Negotiation) + assert.Empty(t, http3.KeepAlivePeriod) + assert.Empty(t, http3.MaxIdleTimeout) + assert.Zero(t, http3.MaxIncomingStreams) + assert.Zero(t, http3.MaxIncomingUniStreams) }) } + +func TestHttp3ConfigJSONTags(t *testing.T) { + http3 := &Http3Config{ + Enable: true, + Negotiation: true, + KeepAlivePeriod: "15s", + MaxIdleTimeout: "30s", + MaxIncomingStreams: 128, + MaxIncomingUniStreams: 64, + } + + data, err := json.Marshal(http3) + require.NoError(t, err) + assert.Contains(t, string(data), "\"keep-alive-period\":\"15s\"") + assert.Contains(t, string(data), "\"max-idle-timeout\":\"30s\"") + assert.Contains(t, string(data), "\"max-incoming-streams\":128") + assert.Contains(t, string(data), "\"max-incoming-uni-streams\":64") + + var decoded Http3Config + err = json.Unmarshal([]byte(`{ + "enable": true, + "negotiation": true, + "keep-alive-period": "15s", + "max-idle-timeout": "30s", + "max-incoming-streams": 128, + "max-incoming-uni-streams": 64 + }`), &decoded) + require.NoError(t, err) + assert.Equal(t, http3, &decoded) + + var compatDecoded Http3Config + err = json.Unmarshal([]byte(`{ + "enable": true, + "negotiation": true, + "keepAlivePeriod": "15s", + "maxIdleTimeout": "30s", + "maxIncomingStreams": 128, + "maxIncomingUniStreams": 64 + }`), &compatDecoded) + require.NoError(t, err) + assert.Equal(t, http3, &compatDecoded) + + var preferCanonical Http3Config + err = json.Unmarshal([]byte(`{ + "keep-alive-period": "15s", + "keepAlivePeriod": "99s", + "max-idle-timeout": "30s", + "maxIdleTimeout": "99s", + "max-incoming-streams": 128, + "maxIncomingStreams": 999, + "max-incoming-uni-streams": 64, + "maxIncomingUniStreams": 999 + }`), &preferCanonical) + require.NoError(t, err) + assert.Equal(t, "15s", preferCanonical.KeepAlivePeriod) + assert.Equal(t, "30s", preferCanonical.MaxIdleTimeout) + assert.Equal(t, int64(128), preferCanonical.MaxIncomingStreams) + assert.Equal(t, int64(64), preferCanonical.MaxIncomingUniStreams) + + var nullCanonical Http3Config + err = json.Unmarshal([]byte(`{ + "keep-alive-period": null, + "keepAlivePeriod": "99s", + "max-idle-timeout": null, + "maxIdleTimeout": "99s", + "max-incoming-streams": null, + "maxIncomingStreams": 999, + "max-incoming-uni-streams": null, + "maxIncomingUniStreams": 999 + }`), &nullCanonical) + require.NoError(t, err) + assert.Empty(t, nullCanonical.KeepAlivePeriod) + assert.Empty(t, nullCanonical.MaxIdleTimeout) + assert.Zero(t, nullCanonical.MaxIncomingStreams) + assert.Zero(t, nullCanonical.MaxIncomingUniStreams) +} + func TestConsumerConfigClone(t *testing.T) { t.Run("clone_full_consumer_config", func(t *testing.T) { consumer := &ConsumerConfig{ diff --git a/global/http3_config.go b/global/http3_config.go index c079afa526..a5b04b661c 100644 --- a/global/http3_config.go +++ b/global/http3_config.go @@ -17,6 +17,10 @@ package global +import ( + "encoding/json" +) + // Http3Config represents the config of http3 type Http3Config struct { // Whether to enable HTTP/3 support. @@ -34,14 +38,97 @@ type Http3Config struct { // ref: https://quic-go.net/docs/http3/server/#advertising-http3-via-alt-svc Negotiation bool `yaml:"negotiation" json:"negotiation,omitempty"` - // TODO: add more params about http3 + // KeepAlivePeriod defines how often to send keep-alive packets. + KeepAlivePeriod string `yaml:"keep-alive-period" json:"keep-alive-period,omitempty"` + + // MaxIdleTimeout defines the maximum idle timeout for QUIC connections. + MaxIdleTimeout string `yaml:"max-idle-timeout" json:"max-idle-timeout,omitempty"` + + // MaxIncomingStreams defines the maximum number of concurrent bidirectional streams. + MaxIncomingStreams int64 `yaml:"max-incoming-streams" json:"max-incoming-streams,omitempty"` + + // MaxIncomingUniStreams defines the maximum number of concurrent unidirectional streams. + MaxIncomingUniStreams int64 `yaml:"max-incoming-uni-streams" json:"max-incoming-uni-streams,omitempty"` +} + +func (t *Http3Config) UnmarshalJSON(data []byte) error { + type canonicalJSON struct { + Enable *bool `json:"enable"` + Negotiation *bool `json:"negotiation"` + KeepAlivePeriod *string `json:"keep-alive-period"` + MaxIdleTimeout *string `json:"max-idle-timeout"` + MaxIncomingStreams *int64 `json:"max-incoming-streams"` + MaxIncomingUniStreams *int64 `json:"max-incoming-uni-streams"` + } + type compatJSON struct { + KeepAlivePeriod *string `json:"keepAlivePeriod"` + MaxIdleTimeout *string `json:"maxIdleTimeout"` + MaxIncomingStreams *int64 `json:"maxIncomingStreams"` + MaxIncomingUniStreams *int64 `json:"maxIncomingUniStreams"` + } + + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + var canonical canonicalJSON + if err := json.Unmarshal(data, &canonical); err != nil { + return err + } + + var compat compatJSON + if err := json.Unmarshal(data, &compat); err != nil { + return err + } + + if canonical.Enable != nil { + t.Enable = *canonical.Enable + } + if canonical.Negotiation != nil { + t.Negotiation = *canonical.Negotiation + } + if _, ok := raw["keep-alive-period"]; ok { + if canonical.KeepAlivePeriod != nil { + t.KeepAlivePeriod = *canonical.KeepAlivePeriod + } + } else if compat.KeepAlivePeriod != nil { + t.KeepAlivePeriod = *compat.KeepAlivePeriod + } + if _, ok := raw["max-idle-timeout"]; ok { + if canonical.MaxIdleTimeout != nil { + t.MaxIdleTimeout = *canonical.MaxIdleTimeout + } + } else if compat.MaxIdleTimeout != nil { + t.MaxIdleTimeout = *compat.MaxIdleTimeout + } + if _, ok := raw["max-incoming-streams"]; ok { + if canonical.MaxIncomingStreams != nil { + t.MaxIncomingStreams = *canonical.MaxIncomingStreams + } + } else if compat.MaxIncomingStreams != nil { + t.MaxIncomingStreams = *compat.MaxIncomingStreams + } + if _, ok := raw["max-incoming-uni-streams"]; ok { + if canonical.MaxIncomingUniStreams != nil { + t.MaxIncomingUniStreams = *canonical.MaxIncomingUniStreams + } + } else if compat.MaxIncomingUniStreams != nil { + t.MaxIncomingUniStreams = *compat.MaxIncomingUniStreams + } + + return nil } // DefaultHttp3Config returns a default Http3Config instance. func DefaultHttp3Config() *Http3Config { return &Http3Config{ - Enable: false, - Negotiation: true, + Enable: false, + Negotiation: true, + KeepAlivePeriod: "", + MaxIdleTimeout: "", + MaxIncomingStreams: 0, + MaxIncomingUniStreams: 0, } } @@ -52,7 +139,11 @@ func (t *Http3Config) Clone() *Http3Config { } return &Http3Config{ - Enable: t.Enable, - Negotiation: t.Negotiation, + Enable: t.Enable, + Negotiation: t.Negotiation, + KeepAlivePeriod: t.KeepAlivePeriod, + MaxIdleTimeout: t.MaxIdleTimeout, + MaxIncomingStreams: t.MaxIncomingStreams, + MaxIncomingUniStreams: t.MaxIncomingUniStreams, } } diff --git a/protocol/triple/triple_protocol/http3_config.go b/protocol/triple/triple_protocol/http3_config.go new file mode 100644 index 0000000000..93a520ad89 --- /dev/null +++ b/protocol/triple/triple_protocol/http3_config.go @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package triple_protocol + +import ( + "fmt" + "time" +) + +import ( + "github.com/quic-go/quic-go" +) + +import ( + "dubbo.apache.org/dubbo-go/v3/global" +) + +func newQUICConfig(http3Config *global.Http3Config) (*quic.Config, error) { + quicConfig := &quic.Config{} + if http3Config == nil { + return quicConfig, nil + } + + if http3Config.KeepAlivePeriod != "" { + keepAlivePeriod, err := time.ParseDuration(http3Config.KeepAlivePeriod) + if err != nil { + return nil, fmt.Errorf("invalid http3 keep-alive-period %q: %w", http3Config.KeepAlivePeriod, err) + } + quicConfig.KeepAlivePeriod = keepAlivePeriod + } + + if http3Config.MaxIdleTimeout != "" { + maxIdleTimeout, err := time.ParseDuration(http3Config.MaxIdleTimeout) + if err != nil { + return nil, fmt.Errorf("invalid http3 max-idle-timeout %q: %w", http3Config.MaxIdleTimeout, err) + } + quicConfig.MaxIdleTimeout = maxIdleTimeout + } + + // Preserve quic-go defaults when these fields are left unset in config. + if http3Config.MaxIncomingStreams != 0 { + quicConfig.MaxIncomingStreams = http3Config.MaxIncomingStreams + } + if http3Config.MaxIncomingUniStreams != 0 { + quicConfig.MaxIncomingUniStreams = http3Config.MaxIncomingUniStreams + } + + return quicConfig, nil +} diff --git a/protocol/triple/triple_protocol/server.go b/protocol/triple/triple_protocol/server.go index 9859ed38bd..28240c6291 100644 --- a/protocol/triple/triple_protocol/server.go +++ b/protocol/triple/triple_protocol/server.go @@ -29,7 +29,6 @@ import ( "github.com/dubbogo/grpc-go" - "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" "golang.org/x/net/http2" @@ -214,15 +213,24 @@ func (s *Server) startHttp3(tlsConf *tls.Config) error { return fmt.Errorf("TRIPLE HTTP/3 Server must have TLS config, but TLS config is nil") } + var http3Config *global.Http3Config + if s.tripleConfig != nil { + http3Config = s.tripleConfig.Http3 + } + + quicConfig, err := newQUICConfig(http3Config) + if err != nil { + return err + } + s.http3Srv = &http3.Server{ Addr: s.addr, Handler: s.mux, // Adapt and enhance a generic tls.Config object into a configuration // specifically for HTTP/3 services. // ref: https://quic-go.net/docs/http3/server/#setting-up-a-http3server - TLSConfig: http3.ConfigureTLSConfig(tlsConf), - // TODO: Detailed QUIC configuration. - QUICConfig: &quic.Config{}, + TLSConfig: http3.ConfigureTLSConfig(tlsConf), + QUICConfig: quicConfig, } logger.Debugf("TRIPLE HTTP/3 Server starting on %v", s.addr) @@ -236,12 +244,22 @@ func (s *Server) startHttp2AndHttp3(tlsConf *tls.Config) error { return fmt.Errorf("TRIPLE HTTP/2 and HTTP/3 Server must have TLS config, but TLS config is nil") } + var http3Config *global.Http3Config + if s.tripleConfig != nil { + http3Config = s.tripleConfig.Http3 + } + + quicConfig, err := newQUICConfig(http3Config) + if err != nil { + return err + } + // Start HTTP/3 server first to get its configuration s.http3Srv = &http3.Server{ Addr: s.addr, Handler: s.mux, TLSConfig: http3.ConfigureTLSConfig(tlsConf), - QUICConfig: &quic.Config{}, + QUICConfig: quicConfig, } // Create Alt-Svc handler wrapper for HTTP/2 server diff --git a/protocol/triple/triple_protocol/server_test.go b/protocol/triple/triple_protocol/server_test.go index 33e0242368..a09733c96a 100644 --- a/protocol/triple/triple_protocol/server_test.go +++ b/protocol/triple/triple_protocol/server_test.go @@ -18,9 +18,11 @@ package triple_protocol import ( + "crypto/tls" "net/http" "net/url" "testing" + "time" ) import ( @@ -89,7 +91,7 @@ func TestServer_RegisterMuxHandle(t *testing.T) { }, }) for _, test := range tests { - err := srv.RegisterUnaryHandler(test.path, nil, nil) + err := test.registerFunc(srv, test.path) require.NoError(t, err) _, pattern := srv.mux.Handler(&http.Request{ URL: &url.URL{ @@ -99,3 +101,77 @@ func TestServer_RegisterMuxHandle(t *testing.T) { assert.Equal(t, test.path, pattern) } } + +func TestNewQUICConfig(t *testing.T) { + t.Run("defaults_preserved_when_unset", func(t *testing.T) { + quicConfig, err := newQUICConfig(&global.Http3Config{}) + require.NoError(t, err) + require.NotNil(t, quicConfig) + assert.Zero(t, quicConfig.KeepAlivePeriod) + assert.Zero(t, quicConfig.MaxIdleTimeout) + assert.Zero(t, quicConfig.MaxIncomingStreams) + assert.Zero(t, quicConfig.MaxIncomingUniStreams) + }) + + t.Run("explicit_fields_are_mapped", func(t *testing.T) { + quicConfig, err := newQUICConfig(&global.Http3Config{ + KeepAlivePeriod: "15s", + MaxIdleTimeout: "30s", + MaxIncomingStreams: 128, + MaxIncomingUniStreams: 64, + }) + require.NoError(t, err) + require.NotNil(t, quicConfig) + assert.Equal(t, 15*time.Second, quicConfig.KeepAlivePeriod) + assert.Equal(t, 30*time.Second, quicConfig.MaxIdleTimeout) + assert.Equal(t, int64(128), quicConfig.MaxIncomingStreams) + assert.Equal(t, int64(64), quicConfig.MaxIncomingUniStreams) + }) + + t.Run("invalid_keep_alive_period_returns_error", func(t *testing.T) { + quicConfig, err := newQUICConfig(&global.Http3Config{ + KeepAlivePeriod: "invalid", + }) + require.Error(t, err) + assert.Nil(t, quicConfig) + assert.ErrorContains(t, err, "keep-alive-period") + }) + + t.Run("invalid_max_idle_timeout_returns_error", func(t *testing.T) { + quicConfig, err := newQUICConfig(&global.Http3Config{ + MaxIdleTimeout: "invalid", + }) + require.Error(t, err) + assert.Nil(t, quicConfig) + assert.ErrorContains(t, err, "max-idle-timeout") + }) +} + +func TestServer_HTTP3PathsUseQUICConfigHelper(t *testing.T) { + t.Run("start_http3_returns_parse_error", func(t *testing.T) { + srv := NewServer("127.0.0.1:0", &global.TripleConfig{ + Http3: &global.Http3Config{ + KeepAlivePeriod: "invalid", + }, + }) + + err := srv.startHttp3(&tls.Config{}) + require.Error(t, err) + require.ErrorContains(t, err, "keep-alive-period") + assert.Nil(t, srv.http3Srv) + }) + + t.Run("start_http2_and_http3_returns_parse_error", func(t *testing.T) { + srv := NewServer("127.0.0.1:0", &global.TripleConfig{ + Http3: &global.Http3Config{ + MaxIdleTimeout: "invalid", + }, + }) + + err := srv.startHttp2AndHttp3(&tls.Config{}) + require.Error(t, err) + require.ErrorContains(t, err, "max-idle-timeout") + assert.Nil(t, srv.http3Srv) + assert.Nil(t, srv.httpSrv) + }) +}