Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions p2p/kademlia/conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,21 +203,21 @@ func (conn *connWrapper) RemoteAddr() net.Addr {
func (conn *connWrapper) SetDeadline(t time.Time) error {
conn.mtx.Lock()
defer conn.mtx.Unlock()
return conn.rawConn.SetDeadline(t)
return conn.secureConn.SetDeadline(t)
}

// SetReadDeadline implements net.Conn's SetReadDeadline interface
func (conn *connWrapper) SetReadDeadline(t time.Time) error {
conn.mtx.Lock()
defer conn.mtx.Unlock()
return conn.rawConn.SetReadDeadline(t)
return conn.secureConn.SetReadDeadline(t)
}

// SetWriteDeadline implements net.Conn's SetWriteDeadline interface
func (conn *connWrapper) SetWriteDeadline(t time.Time) error {
conn.mtx.Lock()
defer conn.mtx.Unlock()
return conn.rawConn.SetWriteDeadline(t)
return conn.secureConn.SetWriteDeadline(t)
}

// StartConnEviction starts a goroutine that periodically evicts idle connections.
Expand Down
15 changes: 14 additions & 1 deletion p2p/kademlia/dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,19 @@ func NewDHT(ctx context.Context, store Store, metaStore MetaStore, options *Opti
return nil, fmt.Errorf("failed to create client credentials: %w", err)
}

// Initialize server credentials for incoming connections
serverCreds, err := ltc.NewServerCreds(&ltc.ServerOptions{
CommonOptions: ltc.CommonOptions{
Keyring: options.Keyring,
LocalIdentity: string(options.ID),
PeerType: securekeyx.Supernode,
Validator: lumera.NewSecureKeyExchangeValidator(options.LumeraClient),
},
})
if err != nil {
return nil, fmt.Errorf("failed to create server credentials: %w", err)
}

// new a hashtable with options
ht, err := NewHashTable(options)
if err != nil {
Expand All @@ -139,7 +152,7 @@ func NewDHT(ctx context.Context, store Store, metaStore MetaStore, options *Opti
s.skipBadBootstrapAddrs()

// new network service for dht
network, err := NewNetwork(ctx, s, ht.self, clientCreds)
network, err := NewNetwork(ctx, s, ht.self, clientCreds, serverCreds)
if err != nil {
return nil, fmt.Errorf("new network: %v", err)
}
Expand Down
20 changes: 11 additions & 9 deletions p2p/kademlia/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,21 @@ type Network struct {
done chan struct{} // network is stopped

// For secure connection
tc credentials.TransportCredentials
clientTC credentials.TransportCredentials // for outgoing connections
serverTC credentials.TransportCredentials // for incoming connections
connPool *ConnPool
connPoolMtx sync.Mutex
sem *semaphore.Weighted
}

// NewNetwork returns a network service
func NewNetwork(ctx context.Context, dht *DHT, self *Node, tc credentials.TransportCredentials) (*Network, error) {
func NewNetwork(ctx context.Context, dht *DHT, self *Node, clientTC, serverTC credentials.TransportCredentials) (*Network, error) {
s := &Network{
dht: dht,
self: self,
done: make(chan struct{}),
tc: tc,
clientTC: clientTC,
serverTC: serverTC,
connPool: NewConnPool(ctx),
sem: semaphore.NewWeighted(maxConcurrentFindBatchValsRequests),
}
Expand Down Expand Up @@ -103,7 +105,7 @@ func (s *Network) Start(ctx context.Context) error {

// Stop the network
func (s *Network) Stop(ctx context.Context) {
if s.tc != nil {
if s.clientTC != nil || s.serverTC != nil {
s.connPool.Release()
}
// close the socket
Expand Down Expand Up @@ -344,8 +346,8 @@ func (s *Network) handleConn(ctx context.Context, rawConn net.Conn) {
"remote-addr": rawConn.RemoteAddr().String(),
})
// do secure handshaking
if s.tc != nil {
conn, err = NewSecureServerConn(ctx, s.tc, rawConn)
if s.serverTC != nil {
conn, err = NewSecureServerConn(ctx, s.serverTC, rawConn)
if err != nil {
rawConn.Close()
logtrace.Warn(ctx, "Server secure handshake failed", logtrace.Fields{
Expand Down Expand Up @@ -607,15 +609,15 @@ func (s *Network) Call(ctx context.Context, request *Message, isLong bool) (*Mes

remoteAddr := fmt.Sprintf("%s@%s:%d", string(request.Receiver.ID), request.Receiver.IP, request.Receiver.Port)

if s.tc == nil {
if s.clientTC == nil {
return nil, errors.New("secure transport credentials are not set")
}

// do secure handshaking
s.connPoolMtx.Lock()
conn, err := s.connPool.Get(remoteAddr)
if err != nil {
conn, err = NewSecureClientConn(ctx, s.tc, remoteAddr)
conn, err = NewSecureClientConn(ctx, s.clientTC, remoteAddr)
if err != nil {
s.connPoolMtx.Unlock()
return nil, errors.Errorf("client secure establish %q: %w", remoteAddr, err)
Expand All @@ -625,7 +627,7 @@ func (s *Network) Call(ctx context.Context, request *Message, isLong bool) (*Mes
s.connPoolMtx.Unlock()

defer func() {
if err != nil && s.tc != nil {
if err != nil && s.clientTC != nil {
s.connPoolMtx.Lock()
defer s.connPoolMtx.Unlock()

Expand Down
7 changes: 5 additions & 2 deletions pkg/net/credentials/lumeratc.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ func NewTransportCredentials(side Side, opts interface{}) (credentials.Transport
keyExMutex.Lock()
defer keyExMutex.Unlock()

keyExchanger, exists := keyExchangers[optsCommon.LocalIdentity]
// Create unique cache key that includes both identity and side (client/server)
// This ensures client and server credentials get separate KeyExchanger instances
cacheKey := fmt.Sprintf("%s-%d", optsCommon.LocalIdentity, side)
keyExchanger, exists := keyExchangers[cacheKey]
if !exists {
keyExchanger, err = securekeyx.NewSecureKeyExchange(
optsCommon.Keyring,
Expand All @@ -112,7 +115,7 @@ func NewTransportCredentials(side Side, opts interface{}) (credentials.Transport
if err != nil {
return nil, fmt.Errorf("failed to create secure key exchange: %w", err)
}
keyExchangers[optsCommon.LocalIdentity] = keyExchanger
keyExchangers[cacheKey] = keyExchanger
}

return &LumeraTC{
Expand Down