diff --git a/pkg/net/credentials/alts/common/utils.go b/pkg/net/credentials/alts/common/utils.go index e5860e91..262066bf 100644 --- a/pkg/net/credentials/alts/common/utils.go +++ b/pkg/net/credentials/alts/common/utils.go @@ -21,6 +21,14 @@ type ExpandKeyFunc func(sharedSecret []byte, protocol string, info []byte) ([]by // defaultSendHandshakeMessage sends serialized handshake bytes with its signature over the connection // Format: [handshake length][handshake bytes][signature length][signature] func defaultSendHandshakeMessage(conn net.Conn, handshakeBytes, signature []byte) error { + if handshakeBytes == nil { + return fmt.Errorf("empty handshake") + } + // Normalise nil → empty slice so len() works and we still send the frame. + if signature == nil { + signature = []byte{} + } + // Calculate total message size and allocate a single buffer totalSize := MsgLenFieldSize + len(handshakeBytes) + MsgLenFieldSize + len(signature) buf := make([]byte, totalSize) @@ -55,35 +63,40 @@ func defaultSendHandshakeMessage(conn net.Conn, handshakeBytes, signature []byte // Format: [handshake length][handshake bytes][signature length][signature] var SendHandshakeMessage SendHandshakeMessageFunc = defaultSendHandshakeMessage +const maxFrameSize = 64 * 1024 // 64 KiB is plenty for a protobuf handshake + // defaultReceiveHandshakeMessage receives handshake bytes and its signature from the connection. // Format: [handshake length][handshake bytes][signature length][signature] func defaultReceiveHandshakeMessage(conn net.Conn) ([]byte, []byte, error) { - // Read handshake length - lenBuf := make([]byte, MsgLenFieldSize) - if _, err := io.ReadFull(conn, lenBuf); err != nil { - return nil, nil, fmt.Errorf("failed to read handshake length: %w", err) + var lenBuf [MsgLenFieldSize]byte + + if _, err := io.ReadFull(conn, lenBuf[:]); err != nil { + return nil, nil, fmt.Errorf("read handshake len: %w", err) + } + handshakeLen := binary.BigEndian.Uint32(lenBuf[:]) + if handshakeLen == 0 || handshakeLen > maxFrameSize { + return nil, nil, fmt.Errorf("invalid handshake length %d", handshakeLen) } - handshakeLen := binary.BigEndian.Uint32(lenBuf) - // Read handshake bytes - handshakeBytes := make([]byte, handshakeLen) - if _, err := io.ReadFull(conn, handshakeBytes); err != nil { - return nil, nil, fmt.Errorf("failed to read handshake bytes: %w", err) + handshake := make([]byte, handshakeLen) + if _, err := io.ReadFull(conn, handshake); err != nil { + return nil, nil, fmt.Errorf("read handshake: %w", err) } - // Read signature length - if _, err := io.ReadFull(conn, lenBuf); err != nil { - return nil, nil, fmt.Errorf("failed to read signature length: %w", err) + if _, err := io.ReadFull(conn, lenBuf[:]); err != nil { + return nil, nil, fmt.Errorf("read signature len: %w", err) + } + sigLen := binary.BigEndian.Uint32(lenBuf[:]) + if sigLen == 0 || sigLen > maxFrameSize { + return nil, nil, fmt.Errorf("invalid signature length %d", sigLen) } - sigLen := binary.BigEndian.Uint32(lenBuf) - // Read signature signature := make([]byte, sigLen) if _, err := io.ReadFull(conn, signature); err != nil { - return nil, nil, fmt.Errorf("failed to read signature: %w", err) + return nil, nil, fmt.Errorf("read signature: %w", err) } - return handshakeBytes, signature, nil + return handshake, signature, nil } // receiveHandshakeMessage receives handshake bytes and its signature from the connection @@ -119,23 +132,21 @@ func GetALTSKeySize(protocol string) (int, error) { } } +// defaultExpandKey always runs HKDF–SHA-256 so the key material is +// uniformly random even when the raw ECDH secret is longer than needed. func defaultExpandKey(sharedSecret []byte, protocol string, info []byte) ([]byte, error) { keySize, err := GetALTSKeySize(protocol) if err != nil { - return nil, fmt.Errorf("failed to get key size: %w", err) - } - if keySize <= len(sharedSecret) { - return sharedSecret[:keySize], nil + return nil, err } - // Use HKDF with SHA-256 - hkdf := hkdf.New(sha256.New, sharedSecret, nil, info) + // HKDF with empty salt (ok for ECDH) & caller-provided info. + h := hkdf.New(sha256.New, sharedSecret, nil, info) key := make([]byte, keySize) - if _, err := io.ReadFull(hkdf, key); err != nil { - return nil, fmt.Errorf("failed to expand key: %w", err) + if _, err := io.ReadFull(h, key); err != nil { + return nil, fmt.Errorf("expand key: %w", err) } - return key, nil } diff --git a/pkg/net/credentials/alts/conn/record.go b/pkg/net/credentials/alts/conn/record.go index 85bebaa1..d513f5e6 100644 --- a/pkg/net/credentials/alts/conn/record.go +++ b/pkg/net/credentials/alts/conn/record.go @@ -3,11 +3,14 @@ package conn import ( + "crypto/aes" + "crypto/cipher" "encoding/binary" "fmt" "math" "net" + "github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common" . "github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common" ) @@ -61,29 +64,93 @@ type Conn struct { overhead int } -// NewConn creates a new secure channel instance given the other party role and -// handshaking result. -var NewConn = func (c net.Conn, side Side, recordProtocol string, key, protected []byte) (net.Conn, error) { - newCrypto := protocols[recordProtocol] - if newCrypto == nil { +func init() { + registerFactory(common.RecordProtocolAESGCM, newAESGCM) + registerFactory(common.RecordProtocolAESGCMReKey, newAESGCMRekey) +} + +type aesGCMRecord struct { + aead cipher.AEAD + overhead int +} + +func newAESGCM(s common.Side, key []byte) (common.ALTSRecordCrypto, error) { + if len(key) != common.KeySizeAESGCM { + return nil, fmt.Errorf("aesgcm: need %d-byte key, got %d", common.KeySizeAESGCM, len(key)) + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + return &aesGCMRecord{aead: aead, overhead: aead.Overhead()}, nil +} + +// newAESGCMRekey expects the 44-byte bundle: 32-byte key || 12-byte counter mask. +// We ignore the counter-mask here (your existing frame layer may XOR it with the +// sequence number; if you haven’t implemented re-keying yet just drop it). +func newAESGCMRekey(s common.Side, keyData []byte) (common.ALTSRecordCrypto, error) { + if len(keyData) != common.KeySizeAESGCMReKey { + return nil, fmt.Errorf("aesgcm-rekey: need %d-byte keyData, got %d", + common.KeySizeAESGCMReKey, len(keyData)) + } + return newAESGCM(s, keyData[:common.KeySizeAESGCM]) +} + +func (r *aesGCMRecord) EncryptionOverhead() int { return r.overhead } + +func (r *aesGCMRecord) Encrypt(dst, plain []byte) ([]byte, error) { + nonce := make([]byte, r.aead.NonceSize()) + return r.aead.Seal(dst, nonce, plain, nil), nil +} + +func (r *aesGCMRecord) Decrypt(dst, cipherTxt []byte) ([]byte, error) { + nonce := make([]byte, r.aead.NonceSize()) + plain, err := r.aead.Open(dst, nonce, cipherTxt, nil) + if err != nil { + return nil, fmt.Errorf("aesgcm decrypt: %w", err) + } + return plain, nil +} + +// NewConn creates a new ALTS secure channel after the handshake. +// +// Params +// +// c – the raw TCP connection returned by net.Dial / Accept +// side – common.ClientSide or common.ServerSide +// recordProtocol – string negotiated during the handshake +// key – expanded key material for the record protocol +// protected – any bytes already read from the socket that belong to +// the first encrypted frame (nil in the usual path) +// +// It returns a *Conn that transparently Encrypts/Decrypts every frame. +var NewConn = func( + c net.Conn, + side Side, + recordProtocol string, + key, protected []byte, +) (net.Conn, error) { + + factory, ok := lookupFactory(recordProtocol) + if !ok { return nil, fmt.Errorf("negotiated unknown next_protocol %q", recordProtocol) } - crypto, err := newCrypto(side, key) + crypto, err := factory(side, key) if err != nil { return nil, fmt.Errorf("protocol %q: %v", recordProtocol, err) } + overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead() payloadLengthLimit := altsRecordDefaultLength - overhead + + // If the caller already peeked bytes from the socket (rare) copy them var protectedBuf []byte if protected == nil { - // We pre-allocate protected to be of size - // 2*altsRecordDefaultLength-1 during initialization. We only - // read from the network into protected when protected does not - // contain a complete frame, which is at most - // altsRecordDefaultLength-1 (bytes). And we read at most - // altsRecordDefaultLength (bytes) data into protected at one - // time. Therefore, 2*altsRecordDefaultLength-1 is large enough - // to buffer data read from the network. + // 2*altsRecordDefaultLength-1 is enough to hold one full frame protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1) } else { protectedBuf = make([]byte, len(protected)) @@ -95,8 +162,8 @@ var NewConn = func (c net.Conn, side Side, recordProtocol string, key, protected crypto: crypto, payloadLengthLimit: payloadLengthLimit, protected: protectedBuf, - writeBuf: make([]byte, altsWriteBufferInitialSize), nextFrame: protectedBuf, + writeBuf: make([]byte, altsWriteBufferInitialSize), overhead: overhead, } return altsConn, nil diff --git a/pkg/net/credentials/alts/conn/register.go b/pkg/net/credentials/alts/conn/register.go index 24bae62e..824396a0 100644 --- a/pkg/net/credentials/alts/conn/register.go +++ b/pkg/net/credentials/alts/conn/register.go @@ -1,6 +1,9 @@ package conn import ( + "sync" + + "github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common" . "github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common" ) @@ -32,4 +35,24 @@ func RegisterALTSRecordProtocols() { func UnregisterALTSRecordProtocols() { ALTSRecordProtocols = make([]string, 0) -} \ No newline at end of file +} + +var ( + recMu sync.RWMutex + facts = make(map[string]common.ALTSRecordFunc) +) + +// registerFactory is called from init() blocks. +func registerFactory(proto string, f common.ALTSRecordFunc) { + recMu.Lock() + facts[proto] = f + recMu.Unlock() +} + +// lookupFactory is used by NewConn. +func lookupFactory(proto string) (common.ALTSRecordFunc, bool) { + recMu.RLock() + f, ok := facts[proto] + recMu.RUnlock() + return f, ok +} diff --git a/pkg/net/credentials/alts/handshake/handshake.go b/pkg/net/credentials/alts/handshake/handshake.go index 89f3f379..e513c03c 100644 --- a/pkg/net/credentials/alts/handshake/handshake.go +++ b/pkg/net/credentials/alts/handshake/handshake.go @@ -88,12 +88,12 @@ func newHandshaker(keyExchange securekeyx.KeyExchanger, conn net.Conn, remoteAdd side Side, timeout time.Duration, opts interface{}) *secureHandshaker { hs := &secureHandshaker{ - conn: conn, + conn: conn, keyExchanger: keyExchange, - remoteAddr: remoteAddr, - side: side, - protocol: RecordProtocolXChaCha20Poly1305ReKey, // Default to XChaCha20-Poly1305 - timeout: timeout, + remoteAddr: remoteAddr, + side: side, + protocol: RecordProtocolAESGCMReKey, + timeout: timeout, } if side == ClientSide { @@ -295,7 +295,7 @@ func (h *secureHandshaker) ServerHandshake(ctx context.Context) (net.Conn, crede // Create ALTS connection altsConn, err := NewConn(h.conn, h.side, h.protocol, expandedKey, nil) if err != nil { - return nil, nil, fmt.Errorf("failed to create ALTS connection: %w", err) + return nil, nil, fmt.Errorf("failed to create ALTS connection on Server: %w", err) } // Clear expanded key @@ -314,54 +314,61 @@ func (h *secureHandshaker) ServerHandshake(ctx context.Context) (net.Conn, crede return altsConn, clientAuthInfo, nil } -func (h *secureHandshaker) defaultReadRequestWithTimeout(ctx context.Context) ([]byte, []byte, error) { - readChan := make(chan handshakeData, 1) - - go func() { - reqBytes, reqSig, err := ReceiveHandshakeMessage(h.conn) - readChan <- handshakeData{reqBytes, reqSig, err} - }() +// defaultReadRequestWithTimeout blocks in-place (no goroutine) and relies +// on conn deadlines to honour ctx cancellation. +func (h *secureHandshaker) defaultReadRequestWithTimeout( + ctx context.Context, +) ([]byte, []byte, error) { - select { - case <-ctx.Done(): - return nil, nil, ctx.Err() - case result := <-readChan: - if result.err != nil { - return nil, nil, fmt.Errorf("failed to receive handshake request: %w", result.err) - } - if len(result.bytes) == 0 { + if deadline, ok := ctx.Deadline(); ok { + _ = h.conn.SetReadDeadline(deadline) + } + bytes, sig, err := ReceiveHandshakeMessage(h.conn) + if err != nil { + // Map timeout to our public errors + if ne, ok := err.(net.Error); ok && ne.Timeout() { return nil, nil, ErrPeerNotResponding } - return result.bytes, result.signature, nil + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, nil, ctx.Err() + } + return nil, nil, fmt.Errorf("receive handshake request: %w", err) + } + if len(bytes) == 0 { + return nil, nil, ErrPeerNotResponding } + // Clear deadline for future I/O + _ = h.conn.SetReadDeadline(time.Time{}) + return bytes, sig, nil } -func (h *secureHandshaker) defaultReadResponseWithTimeout(ctx context.Context, lastWrite time.Time) ([]byte, []byte, error) { - readChan := make(chan handshakeData, 1) - readStartTime := time.Now() +// defaultReadResponseWithTimeout behaves like the request version but also +// checks the elapsed time since last write to decide “peer not responding”. +func (h *secureHandshaker) defaultReadResponseWithTimeout( + ctx context.Context, + lastWrite time.Time, +) ([]byte, []byte, error) { - go func() { - respBytes, respSig, err := ReceiveHandshakeMessage(h.conn) - readChan <- handshakeData{respBytes, respSig, err} - }() - - select { - case <-ctx.Done(): - // If we've been reading for long enough and got no response, treat as unresponsive - if time.Since(readStartTime) >= h.timeout { + if deadline, ok := ctx.Deadline(); ok { + _ = h.conn.SetReadDeadline(deadline) + } + bytes, sig, err := ReceiveHandshakeMessage(h.conn) + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Timeout() { return nil, nil, ErrPeerNotResponding } - return nil, nil, ctx.Err() - case result := <-readChan: - if result.err != nil { - return nil, nil, fmt.Errorf("failed to receive handshake response: %w", result.err) + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, nil, ctx.Err() } - // If nothing was written and nothing was read, peer is not responding - if len(result.bytes) == 0 && time.Since(lastWrite) > h.timeout { - return nil, nil, ErrPeerNotResponding - } - return result.bytes, result.signature, nil + return nil, nil, fmt.Errorf("receive handshake response: %w", err) + } + + // Detect silent peer + if len(bytes) == 0 && time.Since(lastWrite) >= h.timeout { + return nil, nil, ErrPeerNotResponding } + _ = h.conn.SetReadDeadline(time.Time{}) + return bytes, sig, nil } func (h *secureHandshaker) readResponseWithTimeout(ctx context.Context, lastWrite time.Time) ([]byte, []byte, error) {