Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 36 additions & 25 deletions pkg/net/credentials/alts/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ type ExpandKeyFunc func(sharedSecret []byte, protocol string, info []byte) ([]by
// defaultSendHandshakeMessage sends serialized handshake bytes with its signature over the connection
// Format: [handshake length][handshake bytes][signature length][signature]
func defaultSendHandshakeMessage(conn net.Conn, handshakeBytes, signature []byte) error {
if handshakeBytes == nil {
return fmt.Errorf("empty handshake")
}
// Normalise nil → empty slice so len() works and we still send the frame.
if signature == nil {
signature = []byte{}
}

// Calculate total message size and allocate a single buffer
totalSize := MsgLenFieldSize + len(handshakeBytes) + MsgLenFieldSize + len(signature)
buf := make([]byte, totalSize)
Expand Down Expand Up @@ -55,35 +63,40 @@ func defaultSendHandshakeMessage(conn net.Conn, handshakeBytes, signature []byte
// Format: [handshake length][handshake bytes][signature length][signature]
var SendHandshakeMessage SendHandshakeMessageFunc = defaultSendHandshakeMessage

const maxFrameSize = 64 * 1024 // 64 KiB is plenty for a protobuf handshake

// defaultReceiveHandshakeMessage receives handshake bytes and its signature from the connection.
// Format: [handshake length][handshake bytes][signature length][signature]
func defaultReceiveHandshakeMessage(conn net.Conn) ([]byte, []byte, error) {
// Read handshake length
lenBuf := make([]byte, MsgLenFieldSize)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return nil, nil, fmt.Errorf("failed to read handshake length: %w", err)
var lenBuf [MsgLenFieldSize]byte

if _, err := io.ReadFull(conn, lenBuf[:]); err != nil {
return nil, nil, fmt.Errorf("read handshake len: %w", err)
}
handshakeLen := binary.BigEndian.Uint32(lenBuf[:])
if handshakeLen == 0 || handshakeLen > maxFrameSize {
return nil, nil, fmt.Errorf("invalid handshake length %d", handshakeLen)
}
handshakeLen := binary.BigEndian.Uint32(lenBuf)

// Read handshake bytes
handshakeBytes := make([]byte, handshakeLen)
if _, err := io.ReadFull(conn, handshakeBytes); err != nil {
return nil, nil, fmt.Errorf("failed to read handshake bytes: %w", err)
handshake := make([]byte, handshakeLen)
if _, err := io.ReadFull(conn, handshake); err != nil {
return nil, nil, fmt.Errorf("read handshake: %w", err)
}

// Read signature length
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return nil, nil, fmt.Errorf("failed to read signature length: %w", err)
if _, err := io.ReadFull(conn, lenBuf[:]); err != nil {
return nil, nil, fmt.Errorf("read signature len: %w", err)
}
sigLen := binary.BigEndian.Uint32(lenBuf[:])
if sigLen == 0 || sigLen > maxFrameSize {
return nil, nil, fmt.Errorf("invalid signature length %d", sigLen)
}
sigLen := binary.BigEndian.Uint32(lenBuf)

// Read signature
signature := make([]byte, sigLen)
if _, err := io.ReadFull(conn, signature); err != nil {
return nil, nil, fmt.Errorf("failed to read signature: %w", err)
return nil, nil, fmt.Errorf("read signature: %w", err)
}

return handshakeBytes, signature, nil
return handshake, signature, nil
}

// receiveHandshakeMessage receives handshake bytes and its signature from the connection
Expand Down Expand Up @@ -119,23 +132,21 @@ func GetALTSKeySize(protocol string) (int, error) {
}
}

// defaultExpandKey always runs HKDF–SHA-256 so the key material is
// uniformly random even when the raw ECDH secret is longer than needed.
func defaultExpandKey(sharedSecret []byte, protocol string, info []byte) ([]byte, error) {
keySize, err := GetALTSKeySize(protocol)
if err != nil {
return nil, fmt.Errorf("failed to get key size: %w", err)
}
if keySize <= len(sharedSecret) {
return sharedSecret[:keySize], nil
return nil, err
}

// Use HKDF with SHA-256
hkdf := hkdf.New(sha256.New, sharedSecret, nil, info)
// HKDF with empty salt (ok for ECDH) & caller-provided info.
h := hkdf.New(sha256.New, sharedSecret, nil, info)

key := make([]byte, keySize)
if _, err := io.ReadFull(hkdf, key); err != nil {
return nil, fmt.Errorf("failed to expand key: %w", err)
if _, err := io.ReadFull(h, key); err != nil {
return nil, fmt.Errorf("expand key: %w", err)
}

return key, nil
}

Expand Down
97 changes: 82 additions & 15 deletions pkg/net/credentials/alts/conn/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
package conn

import (
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"fmt"
"math"
"net"

"github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common"
. "github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common"
)

Expand Down Expand Up @@ -61,29 +64,93 @@ type Conn struct {
overhead int
}

// NewConn creates a new secure channel instance given the other party role and
// handshaking result.
var NewConn = func (c net.Conn, side Side, recordProtocol string, key, protected []byte) (net.Conn, error) {
newCrypto := protocols[recordProtocol]
if newCrypto == nil {
func init() {
registerFactory(common.RecordProtocolAESGCM, newAESGCM)
registerFactory(common.RecordProtocolAESGCMReKey, newAESGCMRekey)
}

const (
	// Direction markers stored in the first nonce byte. Both peers are
	// expected to build their record crypto from the same key, so each
	// direction needs its own nonce space.
	nonceDirClientToServer byte = 0x00
	nonceDirServerToClient byte = 0x80

	// nonceSeqSize is the number of trailing nonce bytes that carry the
	// big-endian record counter.
	nonceSeqSize = 8
)

// aesGCMRecord implements the record crypto on top of AES-GCM.
//
// GCM is only secure when every Seal under a given key uses a distinct
// nonce. The nonce is never sent on the wire: it is rebuilt on both ends
// from a direction marker and a per-direction record counter, which the two
// peers advance in lock-step (one step per sealed / successfully opened
// record).
//
// Encrypt only touches the send counter and Decrypt only touches the receive
// counter; concurrent calls to the same method are not synchronised here.
type aesGCMRecord struct {
	aead     cipher.AEAD
	overhead int

	sendDir byte   // marker placed in nonce[0] for records this side seals
	recvDir byte   // marker expected in nonce[0] for records this side opens
	sendSeq uint64 // number of records sealed so far
	recvSeq uint64 // number of records opened successfully so far
}

// newAESGCM builds an AES-GCM record crypto for side s.
//
// key must be exactly common.KeySizeAESGCM bytes. s selects which nonce
// direction is used for sealing and which for opening, so the two peers
// (which pass opposite sides) never reuse a nonce under the shared key.
func newAESGCM(s common.Side, key []byte) (common.ALTSRecordCrypto, error) {
	if len(key) != common.KeySizeAESGCM {
		return nil, fmt.Errorf("aesgcm: need %d-byte key, got %d", common.KeySizeAESGCM, len(key))
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}

	rec := &aesGCMRecord{aead: aead, overhead: aead.Overhead()}
	if s == common.ClientSide {
		rec.sendDir, rec.recvDir = nonceDirClientToServer, nonceDirServerToClient
	} else {
		rec.sendDir, rec.recvDir = nonceDirServerToClient, nonceDirClientToServer
	}
	return rec, nil
}

// newAESGCMRekey builds the record crypto for the re-key protocol variant.
//
// keyData must be exactly common.KeySizeAESGCMReKey bytes. Only the leading
// common.KeySizeAESGCM bytes are used, as the AES key; the remainder
// (described upstream as a 12-byte counter mask) is currently ignored, and
// no re-keying is performed.
func newAESGCMRekey(s common.Side, keyData []byte) (common.ALTSRecordCrypto, error) {
	if len(keyData) != common.KeySizeAESGCMReKey {
		return nil, fmt.Errorf("aesgcm-rekey: need %d-byte keyData, got %d",
			common.KeySizeAESGCMReKey, len(keyData))
	}
	return newAESGCM(s, keyData[:common.KeySizeAESGCM])
}

// EncryptionOverhead reports how many bytes Encrypt adds to a plaintext.
func (r *aesGCMRecord) EncryptionOverhead() int { return r.overhead }

// nonce builds the nonce for record number seq in direction dir: dir in the
// first byte, seq big-endian in the last nonceSeqSize bytes, zero elsewhere.
// It relies on the AEAD nonce being longer than nonceSeqSize bytes, which
// holds for the standard 12-byte GCM nonce produced by cipher.NewGCM.
func (r *aesGCMRecord) nonce(dir byte, seq uint64) []byte {
	n := make([]byte, r.aead.NonceSize())
	n[0] = dir
	binary.BigEndian.PutUint64(n[len(n)-nonceSeqSize:], seq)
	return n
}

// Encrypt seals plain, appends the ciphertext and tag to dst and returns the
// extended slice. It fails once the send counter is exhausted rather than
// wrapping around and reusing a nonce.
func (r *aesGCMRecord) Encrypt(dst, plain []byte) ([]byte, error) {
	if r.sendSeq == math.MaxUint64 {
		return nil, fmt.Errorf("aesgcm encrypt: record counter exhausted")
	}
	out := r.aead.Seal(dst, r.nonce(r.sendDir, r.sendSeq), plain, nil)
	r.sendSeq++
	return out, nil
}

// Decrypt authenticates and opens cipherTxt, appends the plaintext to dst
// and returns the extended slice. The receive counter advances only after a
// record has been authenticated, so a rejected record does not desynchronise
// the two peers.
func (r *aesGCMRecord) Decrypt(dst, cipherTxt []byte) ([]byte, error) {
	if r.recvSeq == math.MaxUint64 {
		return nil, fmt.Errorf("aesgcm decrypt: record counter exhausted")
	}
	plain, err := r.aead.Open(dst, r.nonce(r.recvDir, r.recvSeq), cipherTxt, nil)
	if err != nil {
		return nil, fmt.Errorf("aesgcm decrypt: %w", err)
	}
	r.recvSeq++
	return plain, nil
}

// NewConn creates a new ALTS secure channel after the handshake.
//
// Params
//
// c – the raw TCP connection returned by net.Dial / Accept
// side – common.ClientSide or common.ServerSide
// recordProtocol – string negotiated during the handshake
// key – expanded key material for the record protocol
// protected – any bytes already read from the socket that belong to
// the first encrypted frame (nil in the usual path)
//
// It returns a *Conn that transparently Encrypts/Decrypts every frame.
var NewConn = func(
c net.Conn,
side Side,
recordProtocol string,
key, protected []byte,
) (net.Conn, error) {

factory, ok := lookupFactory(recordProtocol)
if !ok {
return nil, fmt.Errorf("negotiated unknown next_protocol %q", recordProtocol)
}
crypto, err := newCrypto(side, key)
crypto, err := factory(side, key)
if err != nil {
return nil, fmt.Errorf("protocol %q: %v", recordProtocol, err)
}

overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
payloadLengthLimit := altsRecordDefaultLength - overhead

// If the caller already peeked bytes from the socket (rare) copy them
var protectedBuf []byte
if protected == nil {
// We pre-allocate protected to be of size
// 2*altsRecordDefaultLength-1 during initialization. We only
// read from the network into protected when protected does not
// contain a complete frame, which is at most
// altsRecordDefaultLength-1 (bytes). And we read at most
// altsRecordDefaultLength (bytes) data into protected at one
// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
// to buffer data read from the network.
// 2*altsRecordDefaultLength-1 is enough to hold one full frame
protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1)
} else {
protectedBuf = make([]byte, len(protected))
Expand All @@ -95,8 +162,8 @@ var NewConn = func (c net.Conn, side Side, recordProtocol string, key, protected
crypto: crypto,
payloadLengthLimit: payloadLengthLimit,
protected: protectedBuf,
writeBuf: make([]byte, altsWriteBufferInitialSize),
nextFrame: protectedBuf,
writeBuf: make([]byte, altsWriteBufferInitialSize),
overhead: overhead,
}
return altsConn, nil
Expand Down
25 changes: 24 additions & 1 deletion pkg/net/credentials/alts/conn/register.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package conn

import (
"sync"

"github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common"
. "github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common"
)

Expand Down Expand Up @@ -32,4 +35,24 @@ func RegisterALTSRecordProtocols() {

func UnregisterALTSRecordProtocols() {
ALTSRecordProtocols = make([]string, 0)
}
}

// recMu guards facts.
var recMu sync.RWMutex

// facts maps a record-protocol name to the factory registered for it.
var facts = make(map[string]common.ALTSRecordFunc)

// registerFactory stores f as the factory for record protocol proto,
// replacing any factory previously stored under the same name. It is called
// from init() blocks.
func registerFactory(proto string, f common.ALTSRecordFunc) {
	recMu.Lock()
	defer recMu.Unlock()

	facts[proto] = f
}

// lookupFactory returns the factory registered for record protocol proto.
// The boolean is false when nothing has been registered under that name.
// It is used by NewConn.
func lookupFactory(proto string) (common.ALTSRecordFunc, bool) {
	recMu.RLock()
	defer recMu.RUnlock()

	factory, found := facts[proto]
	return factory, found
}
93 changes: 50 additions & 43 deletions pkg/net/credentials/alts/handshake/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ func newHandshaker(keyExchange securekeyx.KeyExchanger, conn net.Conn, remoteAdd
side Side, timeout time.Duration, opts interface{}) *secureHandshaker {

hs := &secureHandshaker{
conn: conn,
conn: conn,
keyExchanger: keyExchange,
remoteAddr: remoteAddr,
side: side,
protocol: RecordProtocolXChaCha20Poly1305ReKey, // Default to XChaCha20-Poly1305
timeout: timeout,
remoteAddr: remoteAddr,
side: side,
protocol: RecordProtocolAESGCMReKey,
timeout: timeout,
}

if side == ClientSide {
Expand Down Expand Up @@ -295,7 +295,7 @@ func (h *secureHandshaker) ServerHandshake(ctx context.Context) (net.Conn, crede
// Create ALTS connection
altsConn, err := NewConn(h.conn, h.side, h.protocol, expandedKey, nil)
if err != nil {
return nil, nil, fmt.Errorf("failed to create ALTS connection: %w", err)
return nil, nil, fmt.Errorf("failed to create ALTS connection on Server: %w", err)
}

// Clear expanded key
Expand All @@ -314,54 +314,61 @@ func (h *secureHandshaker) ServerHandshake(ctx context.Context) (net.Conn, crede
return altsConn, clientAuthInfo, nil
}

func (h *secureHandshaker) defaultReadRequestWithTimeout(ctx context.Context) ([]byte, []byte, error) {
readChan := make(chan handshakeData, 1)

go func() {
reqBytes, reqSig, err := ReceiveHandshakeMessage(h.conn)
readChan <- handshakeData{reqBytes, reqSig, err}
}()
// defaultReadRequestWithTimeout blocks in-place (no goroutine) and relies
// on conn deadlines to honour ctx cancellation.
func (h *secureHandshaker) defaultReadRequestWithTimeout(
ctx context.Context,
) ([]byte, []byte, error) {

select {
case <-ctx.Done():
return nil, nil, ctx.Err()
case result := <-readChan:
if result.err != nil {
return nil, nil, fmt.Errorf("failed to receive handshake request: %w", result.err)
}
if len(result.bytes) == 0 {
if deadline, ok := ctx.Deadline(); ok {
_ = h.conn.SetReadDeadline(deadline)
}
bytes, sig, err := ReceiveHandshakeMessage(h.conn)
if err != nil {
// Map timeout to our public errors
if ne, ok := err.(net.Error); ok && ne.Timeout() {
return nil, nil, ErrPeerNotResponding
}
return result.bytes, result.signature, nil
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, nil, ctx.Err()
}
return nil, nil, fmt.Errorf("receive handshake request: %w", err)
}
if len(bytes) == 0 {
return nil, nil, ErrPeerNotResponding
}
// Clear deadline for future I/O
_ = h.conn.SetReadDeadline(time.Time{})
return bytes, sig, nil
}

func (h *secureHandshaker) defaultReadResponseWithTimeout(ctx context.Context, lastWrite time.Time) ([]byte, []byte, error) {
readChan := make(chan handshakeData, 1)
readStartTime := time.Now()
// defaultReadResponseWithTimeout behaves like the request version but also
// checks the elapsed time since last write to decide “peer not responding”.
func (h *secureHandshaker) defaultReadResponseWithTimeout(
ctx context.Context,
lastWrite time.Time,
) ([]byte, []byte, error) {

go func() {
respBytes, respSig, err := ReceiveHandshakeMessage(h.conn)
readChan <- handshakeData{respBytes, respSig, err}
}()

select {
case <-ctx.Done():
// If we've been reading for long enough and got no response, treat as unresponsive
if time.Since(readStartTime) >= h.timeout {
if deadline, ok := ctx.Deadline(); ok {
_ = h.conn.SetReadDeadline(deadline)
}
bytes, sig, err := ReceiveHandshakeMessage(h.conn)
if err != nil {
if ne, ok := err.(net.Error); ok && ne.Timeout() {
return nil, nil, ErrPeerNotResponding
}
return nil, nil, ctx.Err()
case result := <-readChan:
if result.err != nil {
return nil, nil, fmt.Errorf("failed to receive handshake response: %w", result.err)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, nil, ctx.Err()
}
// If nothing was written and nothing was read, peer is not responding
if len(result.bytes) == 0 && time.Since(lastWrite) > h.timeout {
return nil, nil, ErrPeerNotResponding
}
return result.bytes, result.signature, nil
return nil, nil, fmt.Errorf("receive handshake response: %w", err)
}

// Detect silent peer
if len(bytes) == 0 && time.Since(lastWrite) >= h.timeout {
return nil, nil, ErrPeerNotResponding
}
_ = h.conn.SetReadDeadline(time.Time{})
return bytes, sig, nil
}

func (h *secureHandshaker) readResponseWithTimeout(ctx context.Context, lastWrite time.Time) ([]byte, []byte, error) {
Expand Down