Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions p2p/kademlia/conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,21 +203,21 @@ func (conn *connWrapper) RemoteAddr() net.Addr {
func (conn *connWrapper) SetDeadline(t time.Time) error {
conn.mtx.Lock()
defer conn.mtx.Unlock()
return conn.rawConn.SetDeadline(t)
return conn.secureConn.SetDeadline(t)
}

// SetReadDeadline implements net.Conn's SetReadDeadline interface
func (conn *connWrapper) SetReadDeadline(t time.Time) error {
conn.mtx.Lock()
defer conn.mtx.Unlock()
return conn.rawConn.SetReadDeadline(t)
return conn.secureConn.SetReadDeadline(t)
}

// SetWriteDeadline implements net.Conn's SetWriteDeadline interface
func (conn *connWrapper) SetWriteDeadline(t time.Time) error {
conn.mtx.Lock()
defer conn.mtx.Unlock()
return conn.rawConn.SetWriteDeadline(t)
return conn.secureConn.SetWriteDeadline(t)
}

// StartConnEviction starts a goroutine that periodically evicts idle connections.
Expand Down
15 changes: 14 additions & 1 deletion p2p/kademlia/dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,19 @@ func NewDHT(ctx context.Context, store Store, metaStore MetaStore, options *Opti
return nil, fmt.Errorf("failed to create client credentials: %w", err)
}

// server creds for incoming
serverCreds, err := ltc.NewServerCreds(&ltc.ServerOptions{
CommonOptions: ltc.CommonOptions{
Keyring: options.Keyring,
LocalIdentity: string(options.ID),
PeerType: securekeyx.Supernode,
Validator: lumera.NewSecureKeyExchangeValidator(options.LumeraClient),
},
})
if err != nil {
return nil, fmt.Errorf("failed to create server credentials: %w", err)
}

// new a hashtable with options
ht, err := NewHashTable(options)
if err != nil {
Expand All @@ -139,7 +152,7 @@ func NewDHT(ctx context.Context, store Store, metaStore MetaStore, options *Opti
s.skipBadBootstrapAddrs()

// new network service for dht
network, err := NewNetwork(ctx, s, ht.self, clientCreds)
network, err := NewNetwork(ctx, s, ht.self, clientCreds, serverCreds)
if err != nil {
return nil, fmt.Errorf("new network: %v", err)
}
Expand Down
26 changes: 17 additions & 9 deletions p2p/kademlia/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,21 @@ type Network struct {
done chan struct{} // network is stopped

// For secure connection
tc credentials.TransportCredentials
clientTC credentials.TransportCredentials // outgoing
serverTC credentials.TransportCredentials // incoming
connPool *ConnPool
connPoolMtx sync.Mutex
sem *semaphore.Weighted
}

// NewNetwork returns a network service
func NewNetwork(ctx context.Context, dht *DHT, self *Node, tc credentials.TransportCredentials) (*Network, error) {
func NewNetwork(ctx context.Context, dht *DHT, self *Node, clientTC, serverTC credentials.TransportCredentials) (*Network, error) {
s := &Network{
dht: dht,
self: self,
done: make(chan struct{}),
tc: tc,
clientTC: clientTC,
serverTC: serverTC,
connPool: NewConnPool(ctx),
sem: semaphore.NewWeighted(maxConcurrentFindBatchValsRequests),
}
Expand Down Expand Up @@ -103,7 +105,7 @@ func (s *Network) Start(ctx context.Context) error {

// Stop the network
func (s *Network) Stop(ctx context.Context) {
if s.tc != nil {
if s.clientTC != nil || s.serverTC != nil {
s.connPool.Release()
}
// close the socket
Expand Down Expand Up @@ -344,8 +346,8 @@ func (s *Network) handleConn(ctx context.Context, rawConn net.Conn) {
"remote-addr": rawConn.RemoteAddr().String(),
})
// do secure handshaking
if s.tc != nil {
conn, err = NewSecureServerConn(ctx, s.tc, rawConn)
if s.serverTC != nil {
conn, err = NewSecureServerConn(ctx, s.serverTC, rawConn)
if err != nil {
rawConn.Close()
logtrace.Warn(ctx, "Server secure handshake failed", logtrace.Fields{
Expand Down Expand Up @@ -607,15 +609,15 @@ func (s *Network) Call(ctx context.Context, request *Message, isLong bool) (*Mes

remoteAddr := fmt.Sprintf("%s@%s:%d", string(request.Receiver.ID), request.Receiver.IP, request.Receiver.Port)

if s.tc == nil {
if s.clientTC == nil {
return nil, errors.New("secure transport credentials are not set")
}

// do secure handshaking
s.connPoolMtx.Lock()
conn, err := s.connPool.Get(remoteAddr)
if err != nil {
conn, err = NewSecureClientConn(ctx, s.tc, remoteAddr)
conn, err = NewSecureClientConn(ctx, s.clientTC, remoteAddr)
if err != nil {
s.connPoolMtx.Unlock()
return nil, errors.Errorf("client secure establish %q: %w", remoteAddr, err)
Expand All @@ -625,7 +627,7 @@ func (s *Network) Call(ctx context.Context, request *Message, isLong bool) (*Mes
s.connPoolMtx.Unlock()

defer func() {
if err != nil && s.tc != nil {
if err != nil && s.clientTC != nil {
s.connPoolMtx.Lock()
defer s.connPoolMtx.Unlock()

Expand All @@ -634,6 +636,12 @@ func (s *Network) Call(ctx context.Context, request *Message, isLong bool) (*Mes
}
}()

// refresh deadline for pooled connections
operationDeadline := time.Now().Add(timeout)
if err := conn.SetDeadline(operationDeadline); err != nil {
return nil, errors.Errorf("failed to set connection deadline: %w", err)
}

// encode and send the request message
data, err := encode(request)
if err != nil {
Expand Down
6 changes: 4 additions & 2 deletions pkg/net/credentials/lumeratc.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ func NewTransportCredentials(side Side, opts interface{}) (credentials.Transport
keyExMutex.Lock()
defer keyExMutex.Unlock()

keyExchanger, exists := keyExchangers[optsCommon.LocalIdentity]
// use side in cache key to separate client/server instances
cacheKey := fmt.Sprintf("%s-%d", optsCommon.LocalIdentity, side)
keyExchanger, exists := keyExchangers[cacheKey]
if !exists {
keyExchanger, err = securekeyx.NewSecureKeyExchange(
optsCommon.Keyring,
Expand All @@ -112,7 +114,7 @@ func NewTransportCredentials(side Side, opts interface{}) (credentials.Transport
if err != nil {
return nil, fmt.Errorf("failed to create secure key exchange: %w", err)
}
keyExchangers[optsCommon.LocalIdentity] = keyExchanger
keyExchangers[cacheKey] = keyExchanger
}

return &LumeraTC{
Expand Down