diff --git a/tls/common.go b/tls/common.go index 0e397d8a..2e41a6df 100644 --- a/tls/common.go +++ b/tls/common.go @@ -123,8 +123,8 @@ const ( // CurveID is the type of a TLS identifier for an elliptic curve. See // https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8. // -// In TLS 1.3, this type is called NamedGroup, but at this time this library -// only supports Elliptic Curve based groups. See RFC 8446, Section 4.2.7. +// In TLS 1.3, this type is called NamedGroup. This library historically used it +// for elliptic curves, but it can represent any TLS 1.3 (EC / hybrid / PQ) group. type CurveID uint16 const ( @@ -132,6 +132,10 @@ const ( CurveP384 CurveID = 24 CurveP521 CurveID = 25 X25519 CurveID = 29 + + // Hybrid PQ key exchange groups (TLS 1.3 NamedGroup) + SecP256r1MLKEM768 CurveID = 4587 + X25519MLKEM768 CurveID = 4588 ) func (curveID *CurveID) MarshalJSON() ([]byte, error) { diff --git a/tls/handshake_client.go b/tls/handshake_client.go index e4bffb52..8d864a7b 100644 --- a/tls/handshake_client.go +++ b/tls/handshake_client.go @@ -222,7 +222,7 @@ func (c *ClientFingerprintConfiguration) marshal(config *Config) ([]byte, error) return hello, nil } -func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { +func (c *Conn) makeClientHello() (*clientHelloMsg, map[CurveID]tls13KeyShare, error) { config := c.config if len(config.ServerName) == 0 && !config.InsecureSkipVerify { return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") @@ -306,22 +306,54 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms } - var params ecdheParameters + var keySharesByGroup map[CurveID]tls13KeyShare if hello.supportedVersions[0] == VersionTLS13 { hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13()...) - curveID := config.curvePreferences()[0] - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { - return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + prefs := config.curvePreferences() + if len(prefs) == 0 { + return nil, nil, errors.New("tls: no supported key exchange mechanisms (no curve preferences)") } - params, err = generateECDHEParameters(config.rand(), curveID) - if err != nil { - return nil, nil, err + + // By default, send a single key_share. + // If ML-KEM hybrid is explicitly enabled as the top preference, also send X25519 as fallback. + shareGroups := []CurveID{prefs[0]} + if prefs[0] == X25519MLKEM768 { + // Ensure compatibility with servers that don't support the hybrid group. + if prefs[0] != X25519 { + shareGroups = append(shareGroups, X25519) + } + } + + hello.keyShares = make([]keyShare, 0, len(shareGroups)) + keySharesByGroup = make(map[CurveID]tls13KeyShare, len(shareGroups)) + + seen := make(map[CurveID]struct{}, len(shareGroups)) + for _, group := range shareGroups { + if _, ok := seen[group]; ok { + continue + } + seen[group] = struct{}{} + + ks, genErr := generateTLS13KeyShare(config.rand(), group) + if genErr != nil { + // If a group is not supported/implemented, skip it. + continue + } + + hello.keyShares = append(hello.keyShares, keyShare{ + group: group, + data: ks.PublicKey(), + }) + keySharesByGroup[group] = ks + } + + if len(hello.keyShares) == 0 { + return nil, nil, errors.New("tls: no supported key exchange mechanisms (no key shares)") } - hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} } - return hello, params, nil + return hello, keySharesByGroup, nil } func (c *Conn) clientHandshake() (err error) { @@ -333,7 +365,7 @@ func (c *Conn) clientHandshake() (err error) { var session *ClientSessionState var sessionCache ClientSessionCache var cacheKey string - var ecdheParams ecdheParameters + var keySharesByGroup map[CurveID]tls13KeyShare // This may be a renegotiation handshake, in which case some fields // need to be reset. @@ -422,7 +454,7 @@ func (c *Conn) clientHandshake() (err error) { sessionCache = nil } else { - hello, ecdheParams, err = c.makeClientHello() + hello, keySharesByGroup, err = c.makeClientHello() if err != nil { return err } @@ -489,13 +521,13 @@ func (c *Conn) clientHandshake() (err error) { if c.vers == VersionTLS13 { hs := &clientHandshakeStateTLS13{ - c: c, - serverHello: serverHello, - hello: hello, - ecdheParams: ecdheParams, - session: session, - earlySecret: earlySecret, - binderKey: binderKey, + c: c, + serverHello: serverHello, + hello: hello, + keySharesByGroup: keySharesByGroup, + session: session, + earlySecret: earlySecret, + binderKey: binderKey, } // In TLS 1.3, session tickets are delivered after the handshake. diff --git a/tls/handshake_client_tls13.go b/tls/handshake_client_tls13.go index db8c1912..077edb19 100644 --- a/tls/handshake_client_tls13.go +++ b/tls/handshake_client_tls13.go @@ -16,10 +16,10 @@ import ( ) type clientHandshakeStateTLS13 struct { - c *Conn - serverHello *serverHelloMsg - hello *clientHelloMsg - ecdheParams ecdheParameters + c *Conn + serverHello *serverHelloMsg + hello *clientHelloMsg + keySharesByGroup map[CurveID]tls13KeyShare session *ClientSessionState earlySecret []byte @@ -34,7 +34,7 @@ type clientHandshakeStateTLS13 struct { trafficSecret []byte // client_application_traffic_secret_0 } -// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and, +// handshake requires hs.c, hs.hello, hs.serverHello, hs.keySharesByGroup, and, // optionally, hs.session, hs.earlySecret and hs.binderKey to be set. func (hs *clientHandshakeStateTLS13) handshake() error { // The server must not select TLS 1.3 in a renegotiation. See RFC 8446, @@ -45,7 +45,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } // Consistency check on the presence of a keyShare and its parameters. - if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 { + if len(hs.hello.keyShares) == 0 || hs.keySharesByGroup == nil { return hs.c.sendAlert(AlertInternalError) } @@ -219,21 +219,20 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { c.sendAlert(AlertIllegalParameter) return errors.New("tls: server selected unsupported group") } - if hs.ecdheParams.CurveID() == curveID { + if _, ok := hs.keySharesByGroup[curveID]; ok { c.sendAlert(AlertIllegalParameter) return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") } - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { - c.sendAlert(AlertInternalError) - return errors.New("tls: CurvePreferences includes unsupported curve") - } - params, err := generateECDHEParameters(c.config.rand(), curveID) + ks, err := generateTLS13KeyShare(c.config.rand(), curveID) if err != nil { c.sendAlert(AlertInternalError) return err } - hs.ecdheParams = params - hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + if hs.keySharesByGroup == nil { + hs.keySharesByGroup = make(map[CurveID]tls13KeyShare) + } + hs.keySharesByGroup[curveID] = ks + hs.hello.keyShares = []keyShare{{group: curveID, data: ks.PublicKey()}} } hs.hello.raw = nil @@ -307,7 +306,9 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { c.sendAlert(AlertIllegalParameter) return errors.New("tls: server did not send a key share") } - if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() { + + ks, ok := hs.keySharesByGroup[hs.serverHello.serverShare.group] + if !ok || ks == nil { c.sendAlert(AlertIllegalParameter) return errors.New("tls: server selected unsupported group") } @@ -345,10 +346,16 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { c := hs.c - sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data) - if sharedKey == nil { + ks, ok := hs.keySharesByGroup[hs.serverHello.serverShare.group] + if !ok || ks == nil { + c.sendAlert(AlertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + + sharedKey, err := ks.SharedKey(hs.serverHello.serverShare.data) + if err != nil { c.sendAlert(AlertIllegalParameter) - return errors.New("tls: invalid server key share") + return err } earlySecret := hs.earlySecret @@ -365,7 +372,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { serverHandshakeTrafficLabel, hs.transcript) c.in.setTrafficSecret(hs.suite, serverSecret) - err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) + err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) if err != nil { c.sendAlert(AlertInternalError) return err diff --git a/tls/handshake_server_tls13.go b/tls/handshake_server_tls13.go index 2cd9edf9..f1d766e3 100644 --- a/tls/handshake_server_tls13.go +++ b/tls/handshake_server_tls13.go @@ -216,21 +216,13 @@ GroupSelection: clientKeyShare = &hs.clientHello.keyShares[0] } - if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok { - c.sendAlert(AlertInternalError) - return errors.New("tls: CurvePreferences includes unsupported curve") - } - params, err := generateECDHEParameters(c.config.rand(), selectedGroup) + serverShareData, sharedKey, err := generateTLS13ServerShareAndSharedKey(c.config.rand(), selectedGroup, clientKeyShare.data) if err != nil { - c.sendAlert(AlertInternalError) - return err - } - hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()} - hs.sharedKey = params.SharedKey(clientKeyShare.data) - if hs.sharedKey == nil { c.sendAlert(AlertIllegalParameter) - return errors.New("tls: invalid client key share") + return err } + hs.hello.serverShare = keyShare{group: selectedGroup, data: serverShareData} + hs.sharedKey = sharedKey c.serverName = hs.clientHello.serverName return nil diff --git a/tls/key_agreement.go b/tls/key_agreement.go index 9bde1a39..b09529a5 100644 --- a/tls/key_agreement.go +++ b/tls/key_agreement.go @@ -388,6 +388,9 @@ type ecdheKeyAgreement struct { func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { var curveID CurveID for _, c := range clientHello.supportedCurves { + if c == X25519MLKEM768 { + continue // ML-KEM hybrid group is TLS 1.3 (key_share) only. + } if config.supportsCurve(c) { curveID = c break diff --git a/tls/key_schedule.go b/tls/key_schedule.go index b9d7a824..bf811985 100644 --- a/tls/key_schedule.go +++ b/tls/key_schedule.go @@ -7,6 +7,7 @@ package tls import ( "crypto/elliptic" "crypto/hmac" + "crypto/mlkem" "errors" "hash" "io" @@ -32,6 +33,13 @@ const ( trafficUpdateLabel = "traffic upd" ) +const ( + x25519ShareSize = 32 + mlkem768EKSize = mlkem.EncapsulationKeySize768 // 1184 + mlkem768CTSize = mlkem.CiphertextSize768 // 1088 + mlkemSSSize = 32 // ML-KEM shared secret size +) + // expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte { var hkdfLabel cryptobyte.Builder @@ -113,6 +121,12 @@ type ecdheParameters interface { MakeLog() (*jsonKeys.ECPoint, *jsonKeys.ECDHPrivateParams) } +type tls13KeyShare interface { + Group() CurveID + PublicKey() []byte + SharedKey(serverShare []byte) ([]byte, error) +} + func generateECDHEParameters(rand io.Reader, curveID CurveID) (ecdheParameters, error) { if curveID == X25519 { privateKey := make([]byte, curve25519.ScalarSize) @@ -279,3 +293,149 @@ func (p *x25519Parameters) MakeLog() (*jsonKeys.ECPoint, *jsonKeys.ECDHPrivatePa return public, private } + +type tls13ECDHEKeyShare struct { + group CurveID + params ecdheParameters +} + +func (k *tls13ECDHEKeyShare) Group() CurveID { return k.group } +func (k *tls13ECDHEKeyShare) PublicKey() []byte { return k.params.PublicKey() } + +func (k *tls13ECDHEKeyShare) SharedKey(serverShare []byte) ([]byte, error) { + sk := k.params.SharedKey(serverShare) + if sk == nil { + return nil, errors.New("tls: invalid server key share") + } + return sk, nil +} + +type tls13X25519MLKEM768KeyShare struct { + dk *mlkem.DecapsulationKey768 + xparams ecdheParameters +} + +func (k *tls13X25519MLKEM768KeyShare) Group() CurveID { return X25519MLKEM768 } + +// ClientHello.key_share.data = EK(1184) || X25519(32) +func (k *tls13X25519MLKEM768KeyShare) PublicKey() []byte { + ek := k.dk.EncapsulationKey().Bytes() + x := k.xparams.PublicKey() + out := make([]byte, 0, len(ek)+len(x)) + out = append(out, ek...) + out = append(out, x...) + return out +} + +// ServerHello.key_share.data = CT(1088) || X25519(32) +// SharedKey = KEM_ss || ECDHE_ss +func (k *tls13X25519MLKEM768KeyShare) SharedKey(serverShare []byte) ([]byte, error) { + if len(serverShare) != mlkem768CTSize+x25519ShareSize { + return nil, errors.New("tls: invalid server share length for X25519MLKEM768") + } + ct := serverShare[:mlkem768CTSize] + sx := serverShare[mlkem768CTSize:] + + kemSS, err := k.dk.Decapsulate(ct) + if err != nil { + return nil, err + } + if len(kemSS) != mlkemSSSize { + return nil, errors.New("tls: invalid ML-KEM shared secret size") + } + + ecdheSS := k.xparams.SharedKey(sx) + if ecdheSS == nil { + return nil, errors.New("tls: invalid server x25519 share") + } + + shared := make([]byte, 0, len(kemSS)+len(ecdheSS)) + shared = append(shared, kemSS...) + shared = append(shared, ecdheSS...) + return shared, nil +} + +func generateTLS13KeyShare(rand io.Reader, group CurveID) (tls13KeyShare, error) { + switch group { + case X25519MLKEM768: + dk, err := mlkem.GenerateKey768() + if err != nil { + return nil, err + } + xp, err := generateECDHEParameters(rand, X25519) + if err != nil { + return nil, err + } + return &tls13X25519MLKEM768KeyShare{dk: dk, xparams: xp}, nil + + default: + if _, ok := curveForCurveID(group); group != X25519 && !ok { + return nil, errors.New("tls: unsupported group") + } + p, err := generateECDHEParameters(rand, group) + if err != nil { + return nil, err + } + return &tls13ECDHEKeyShare{group: group, params: p}, nil + } +} + +func generateTLS13ServerShareAndSharedKey(rand io.Reader, group CurveID, clientShare []byte) ([]byte, []byte, error) { + switch group { + case X25519MLKEM768: + // ClientHello.share = EK(1184) || X25519(32) + if len(clientShare) != mlkem768EKSize+x25519ShareSize { + return nil, nil, errors.New("tls: invalid client share length for X25519MLKEM768") + } + ekBytes := clientShare[:mlkem768EKSize] + cx := clientShare[mlkem768EKSize:] + + ek, err := mlkem.NewEncapsulationKey768(ekBytes) + if err != nil { + return nil, nil, err + } + + kemSS, ct := ek.Encapsulate() + if len(ct) != mlkem768CTSize || len(kemSS) != mlkemSSSize { + return nil, nil, errors.New("tls: invalid ML-KEM encapsulation output size") + } + + sp, err := generateECDHEParameters(rand, X25519) + if err != nil { + return nil, nil, err + } + ecdheSS := sp.SharedKey(cx) + if ecdheSS == nil { + return nil, nil, errors.New("tls: invalid client x25519 share") + } + + // ServerHello.share = CT(1088) || X25519(32) + serverShare := make([]byte, 0, len(ct)+len(sp.PublicKey())) + serverShare = append(serverShare, ct...) + serverShare = append(serverShare, sp.PublicKey()...) + + // shared = KEM_ss || ECDHE_ss + shared := make([]byte, 0, len(kemSS)+len(ecdheSS)) + shared = append(shared, kemSS...) + shared = append(shared, ecdheSS...) + return serverShare, shared, nil + + default: + // Classical TLS 1.3 ECDHE (X25519, P-256, P-384, P-521, etc.) + if _, ok := curveForCurveID(group); group != X25519 && !ok { + return nil, nil, errors.New("tls: unsupported selected group") + } + + params, err := generateECDHEParameters(rand, group) + if err != nil { + return nil, nil, err + } + + sharedKey := params.SharedKey(clientShare) + if sharedKey == nil { + return nil, nil, errors.New("tls: invalid client key share") + } + + return params.PublicKey(), sharedKey, nil + } +} diff --git a/tls/tls_names.go b/tls/tls_names.go index 110cbc21..eebe0626 100644 --- a/tls/tls_names.go +++ b/tls/tls_names.go @@ -446,8 +446,8 @@ func init() { curveNames[258] = "ffdhe4096" curveNames[259] = "ffdhe6144" curveNames[260] = "ffdhe8192" - curveNames[4587] = "secp256r1mlkem768" // draft-kwiatkowski-tls-ecdhe-mlkem - curveNames[4588] = "x25519mlkem768" // draft-kwiatkowski-tls-ecdhe-mlkem + curveNames[4587] = "secp256r1mlkem768" + curveNames[4588] = "x25519mlkem768" curveNames[65281] = "arbitrary_explicit_prime_curves" curveNames[65282] = "arbitrary_explicit_char2_curves"