Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions tls/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,19 @@ const (
// CurveID is the type of a TLS identifier for an elliptic curve. See
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8.
//
// In TLS 1.3, this type is called NamedGroup, but at this time this library
// only supports Elliptic Curve based groups. See RFC 8446, Section 4.2.7.
// In TLS 1.3, this type is called NamedGroup. This library historically used it
// for elliptic curves, but it can represent any TLS 1.3 (EC / hybrid / PQ) group.
type CurveID uint16

const (
CurveP256 CurveID = 23
CurveP384 CurveID = 24
CurveP521 CurveID = 25
X25519 CurveID = 29

// Hybrid PQ key exchange groups (TLS 1.3 NamedGroup)
SecP256r1MLKEM768 CurveID = 4587
X25519MLKEM768 CurveID = 4588
)

func (curveID *CurveID) MarshalJSON() ([]byte, error) {
Expand Down
70 changes: 51 additions & 19 deletions tls/handshake_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func (c *ClientFingerprintConfiguration) marshal(config *Config) ([]byte, error)
return hello, nil
}

func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) {
func (c *Conn) makeClientHello() (*clientHelloMsg, map[CurveID]tls13KeyShare, error) {
config := c.config
if len(config.ServerName) == 0 && !config.InsecureSkipVerify {
return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
Expand Down Expand Up @@ -306,22 +306,54 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) {
hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms
}

var params ecdheParameters
var keySharesByGroup map[CurveID]tls13KeyShare
if hello.supportedVersions[0] == VersionTLS13 {
hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13()...)

curveID := config.curvePreferences()[0]
if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
prefs := config.curvePreferences()
if len(prefs) == 0 {
return nil, nil, errors.New("tls: no supported key exchange mechanisms (no curve preferences)")
}
params, err = generateECDHEParameters(config.rand(), curveID)
if err != nil {
return nil, nil, err

// By default, send a single key_share.
// If ML-KEM hybrid is explicitly enabled as the top preference, also send X25519 as fallback.
shareGroups := []CurveID{prefs[0]}
if prefs[0] == X25519MLKEM768 {
// Ensure compatibility with servers that don't support the hybrid group.
if prefs[0] != X25519 {
shareGroups = append(shareGroups, X25519)
}
}

hello.keyShares = make([]keyShare, 0, len(shareGroups))
keySharesByGroup = make(map[CurveID]tls13KeyShare, len(shareGroups))

seen := make(map[CurveID]struct{}, len(shareGroups))
for _, group := range shareGroups {
if _, ok := seen[group]; ok {
continue
}
seen[group] = struct{}{}

ks, genErr := generateTLS13KeyShare(config.rand(), group)
if genErr != nil {
// If a group is not supported/implemented, skip it.
continue
}

hello.keyShares = append(hello.keyShares, keyShare{
group: group,
data: ks.PublicKey(),
})
keySharesByGroup[group] = ks
}

if len(hello.keyShares) == 0 {
return nil, nil, errors.New("tls: no supported key exchange mechanisms (no key shares)")
}
hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}}
}

return hello, params, nil
return hello, keySharesByGroup, nil
}

func (c *Conn) clientHandshake() (err error) {
Expand All @@ -333,7 +365,7 @@ func (c *Conn) clientHandshake() (err error) {
var session *ClientSessionState
var sessionCache ClientSessionCache
var cacheKey string
var ecdheParams ecdheParameters
var keySharesByGroup map[CurveID]tls13KeyShare

// This may be a renegotiation handshake, in which case some fields
// need to be reset.
Expand Down Expand Up @@ -422,7 +454,7 @@ func (c *Conn) clientHandshake() (err error) {
sessionCache = nil
} else {

hello, ecdheParams, err = c.makeClientHello()
hello, keySharesByGroup, err = c.makeClientHello()
if err != nil {
return err
}
Expand Down Expand Up @@ -489,13 +521,13 @@ func (c *Conn) clientHandshake() (err error) {

if c.vers == VersionTLS13 {
hs := &clientHandshakeStateTLS13{
c: c,
serverHello: serverHello,
hello: hello,
ecdheParams: ecdheParams,
session: session,
earlySecret: earlySecret,
binderKey: binderKey,
c: c,
serverHello: serverHello,
hello: hello,
keySharesByGroup: keySharesByGroup,
session: session,
earlySecret: earlySecret,
binderKey: binderKey,
}

// In TLS 1.3, session tickets are delivered after the handshake.
Expand Down
45 changes: 26 additions & 19 deletions tls/handshake_client_tls13.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ import (
)

type clientHandshakeStateTLS13 struct {
c *Conn
serverHello *serverHelloMsg
hello *clientHelloMsg
ecdheParams ecdheParameters
c *Conn
serverHello *serverHelloMsg
hello *clientHelloMsg
keySharesByGroup map[CurveID]tls13KeyShare

session *ClientSessionState
earlySecret []byte
Expand All @@ -34,7 +34,7 @@ type clientHandshakeStateTLS13 struct {
trafficSecret []byte // client_application_traffic_secret_0
}

// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and,
// handshake requires hs.c, hs.hello, hs.serverHello, hs.keySharesByGroup, and,
// optionally, hs.session, hs.earlySecret and hs.binderKey to be set.
func (hs *clientHandshakeStateTLS13) handshake() error {
// The server must not select TLS 1.3 in a renegotiation. See RFC 8446,
Expand All @@ -45,7 +45,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
}

// Consistency check on the presence of a keyShare and its parameters.
if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 {
if len(hs.hello.keyShares) == 0 || hs.keySharesByGroup == nil {
return hs.c.sendAlert(AlertInternalError)
}

Expand Down Expand Up @@ -219,21 +219,20 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
c.sendAlert(AlertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
if hs.ecdheParams.CurveID() == curveID {
if _, ok := hs.keySharesByGroup[curveID]; ok {
c.sendAlert(AlertIllegalParameter)
return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
}
if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
c.sendAlert(AlertInternalError)
return errors.New("tls: CurvePreferences includes unsupported curve")
}
params, err := generateECDHEParameters(c.config.rand(), curveID)
ks, err := generateTLS13KeyShare(c.config.rand(), curveID)
if err != nil {
c.sendAlert(AlertInternalError)
return err
}
hs.ecdheParams = params
hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}}
if hs.keySharesByGroup == nil {
hs.keySharesByGroup = make(map[CurveID]tls13KeyShare)
}
hs.keySharesByGroup[curveID] = ks
hs.hello.keyShares = []keyShare{{group: curveID, data: ks.PublicKey()}}
}

hs.hello.raw = nil
Expand Down Expand Up @@ -307,7 +306,9 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error {
c.sendAlert(AlertIllegalParameter)
return errors.New("tls: server did not send a key share")
}
if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() {

ks, ok := hs.keySharesByGroup[hs.serverHello.serverShare.group]
if !ok || ks == nil {
c.sendAlert(AlertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
Expand Down Expand Up @@ -345,10 +346,16 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error {
func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
c := hs.c

sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data)
if sharedKey == nil {
ks, ok := hs.keySharesByGroup[hs.serverHello.serverShare.group]
if !ok || ks == nil {
c.sendAlert(AlertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}

sharedKey, err := ks.SharedKey(hs.serverHello.serverShare.data)
if err != nil {
c.sendAlert(AlertIllegalParameter)
return errors.New("tls: invalid server key share")
return err
}

earlySecret := hs.earlySecret
Expand All @@ -365,7 +372,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
serverHandshakeTrafficLabel, hs.transcript)
c.in.setTrafficSecret(hs.suite, serverSecret)

err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
if err != nil {
c.sendAlert(AlertInternalError)
return err
Expand Down
16 changes: 4 additions & 12 deletions tls/handshake_server_tls13.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,21 +216,13 @@ GroupSelection:
clientKeyShare = &hs.clientHello.keyShares[0]
}

if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok {
c.sendAlert(AlertInternalError)
return errors.New("tls: CurvePreferences includes unsupported curve")
}
params, err := generateECDHEParameters(c.config.rand(), selectedGroup)
serverShareData, sharedKey, err := generateTLS13ServerShareAndSharedKey(c.config.rand(), selectedGroup, clientKeyShare.data)
if err != nil {
c.sendAlert(AlertInternalError)
return err
}
hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()}
hs.sharedKey = params.SharedKey(clientKeyShare.data)
if hs.sharedKey == nil {
c.sendAlert(AlertIllegalParameter)
return errors.New("tls: invalid client key share")
return err
}
hs.hello.serverShare = keyShare{group: selectedGroup, data: serverShareData}
hs.sharedKey = sharedKey

c.serverName = hs.clientHello.serverName
return nil
Expand Down
3 changes: 3 additions & 0 deletions tls/key_agreement.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,9 @@ type ecdheKeyAgreement struct {
func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
var curveID CurveID
for _, c := range clientHello.supportedCurves {
if c == X25519MLKEM768 {
continue // ML-KEM hybrid group is TLS 1.3 (key_share) only.
}
if config.supportsCurve(c) {
curveID = c
break
Expand Down
Loading
Loading