Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pkg/crypto/hash.go
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
package crypto

1 change: 0 additions & 1 deletion pkg/errors/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,3 @@ func TestRecover(t *testing.T) {
}

}

2 changes: 1 addition & 1 deletion pkg/lumera/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"context"
"fmt"

authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
sntypes "github.com/LumeraProtocol/lumera/x/supernode/v1/types"
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
)

type SecureKeyExchangeValidator struct {
Expand Down
79 changes: 39 additions & 40 deletions pkg/net/credentials/address_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,99 +2,98 @@ package credentials

import (
"fmt"
"strings"
"net"
"strconv"
"strings"
)

// LumeraAddress represents the components of a Lumera address
type LumeraAddress struct {
Identity string
Host string
Port uint16
Identity string
Host string
Port uint16
}

type LumeraAddresses []LumeraAddress

// String returns the address in the format "identity@host:port"
func (a LumeraAddress) String() string {
return fmt.Sprintf("%s@%s:%d", a.Identity, a.Host, a.Port)
return fmt.Sprintf("%s@%s:%d", a.Identity, a.Host, a.Port)
}

// HostPort returns just the "host:port" portion
func (a LumeraAddress) HostPort() string {
return fmt.Sprintf("%s:%d", a.Host, a.Port)
return fmt.Sprintf("%s:%d", a.Host, a.Port)
}

// ExtractIdentity extracts the identity part from an address in the format "identity@address"
// Returns the identity and the standard address
// If requireIdentity is true, an error is returned when identity is not found
func ExtractIdentity(address string, requireIdentity ...bool) (string, string, error) {
parts := strings.SplitN(address, "@", 2)

// Check if identity is required
identityRequired := false
if len(requireIdentity) > 0 {
identityRequired = requireIdentity[0]
}

if len(parts) != 2 {
// Not in Lumera format, return empty identity and original address
if identityRequired {
return "", "", fmt.Errorf("identity required but not found in address: %s", address)
}
return "", address, nil
}

identity := parts[0]
standardAddress := parts[1]

if identity == "" {
return "", "", fmt.Errorf("empty identity found in address: %s", address)
}

if standardAddress == "" {
return "", "", fmt.Errorf("missing address in: %s", address)
}

return identity, standardAddress, nil
}

// ParseLumeraAddress parses a Lumera address in the format "identity@host:port"
// and returns the components as a LumeraAddress struct.
// Returns an error if any component, including the port, is missing.
func ParseLumeraAddress(address string) (LumeraAddress, error) {
var result LumeraAddress
// Extract identity and the remainder (host:port)
identity, remainder, err := ExtractIdentity(address, true)
if err != nil {
return result, fmt.Errorf("failed to extract identity: %w", err)
}
result.Identity = identity
// Split the remainder into host and port
host, portStr, err := net.SplitHostPort(remainder)
if err != nil {
// If port is missing or any other format error, return an error
return result, fmt.Errorf("invalid host:port format: %w", err)
}
var result LumeraAddress

// Extract identity and the remainder (host:port)
identity, remainder, err := ExtractIdentity(address, true)
if err != nil {
return result, fmt.Errorf("failed to extract identity: %w", err)
}
result.Identity = identity

// Split the remainder into host and port
host, portStr, err := net.SplitHostPort(remainder)
if err != nil {
// If port is missing or any other format error, return an error
return result, fmt.Errorf("invalid host:port format: %w", err)
}
if host == "" {
return result, fmt.Errorf("missing host in address: %s", address)
}

result.Host = host

// Parse the port string to uint16
portInt, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return result, fmt.Errorf("invalid port number: %w", err)
}
result.Port = uint16(portInt)

return result, nil
}

result.Host = host

// Parse the port string to uint16
portInt, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return result, fmt.Errorf("invalid port number: %w", err)
}
result.Port = uint16(portInt)

return result, nil
}

// IsLumeraAddressFormat checks if the address is in Lumera format (contains @)
func IsLumeraAddressFormat(address string) bool {
Expand All @@ -104,4 +103,4 @@ func IsLumeraAddressFormat(address string) bool {
// FormatAddressWithIdentity creates a properly formatted address with identity@address
func FormatAddressWithIdentity(identity, address string) string {
return fmt.Sprintf("%s@%s", identity, address)
}
}
4 changes: 2 additions & 2 deletions pkg/net/credentials/address_helper_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package credentials

import (
"testing"
"reflect"
"testing"
)

func TestExtractIdentity(t *testing.T) {
Expand Down Expand Up @@ -378,4 +378,4 @@ func TestFormatAddressWithIdentity(t *testing.T) {
}
})
}
}
}
8 changes: 4 additions & 4 deletions pkg/net/credentials/alts/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ const (
// Record protocol using XChaCha20-Poly1305 with rekeying
RecordProtocolXChaCha20Poly1305ReKey = "ALTSRP_XCHACHA20_POLY1305_REKEY"

// Key sizes for different protocols
KeySizeAESGCM = 16
KeySizeAESGCMReKey = 44 // 32 bytes key + 12 bytes counter mask
KeySizeXChaCha20Poly1305ReKey = 56 // 32 bytes key + 24 bytes nonce
// Key sizes for different protocols
KeySizeAESGCM = 16
KeySizeAESGCMReKey = 44 // 32 bytes key + 12 bytes counter mask
KeySizeXChaCha20Poly1305ReKey = 56 // 32 bytes key + 24 bytes nonce
)

// ALTSRecordCrypto is the interface for gRPC ALTS record protocol.
Expand Down
10 changes: 5 additions & 5 deletions pkg/net/credentials/alts/common/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import (
"bytes"
"encoding/binary"
"errors"
"github.com/cosmos/gogoproto/proto"
"github.com/stretchr/testify/require"
"io"
"net"
"strings"
"testing"
"time"
"github.com/cosmos/gogoproto/proto"
"github.com/stretchr/testify/require"

"github.com/LumeraProtocol/lumera/x/lumeraid/securekeyx"
lumeraidtypes "github.com/LumeraProtocol/lumera/x/lumeraid/types"
Expand Down Expand Up @@ -231,10 +231,10 @@ func TestReceiveHandshakeMessageFailures(t *testing.T) {
func TestParseValidHandshakeMessage(t *testing.T) {
// Create a valid handshake message.
validHandshake := &lumeraidtypes.HandshakeInfo{
Address: "127.0.0.1:8080",
PeerType: int32(securekeyx.Supernode),
Address: "127.0.0.1:8080",
PeerType: int32(securekeyx.Supernode),
PublicKey: []byte("public-key"),
Curve: "curve-name",
Curve: "curve-name",
}
handshakeBytes, err := proto.Marshal(validHandshake)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions pkg/net/credentials/alts/conn/aes128gcm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package conn

import (
"bytes"
"testing"
"runtime"
"fmt"
"time"
"runtime"
"strconv"
"testing"
"time"

. "github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common"
)
Expand Down
1 change: 0 additions & 1 deletion pkg/net/credentials/alts/conn/counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,3 @@ func (c *Counter) Inc() {
c.invalid = true
}
}

2 changes: 1 addition & 1 deletion pkg/net/credentials/alts/conn/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type Conn struct {

// NewConn creates a new secure channel instance given the other party role and
// handshaking result.
var NewConn = func (c net.Conn, side Side, recordProtocol string, key, protected []byte) (net.Conn, error) {
var NewConn = func(c net.Conn, side Side, recordProtocol string, key, protected []byte) (net.Conn, error) {
newCrypto := protocols[recordProtocol]
if newCrypto == nil {
return nil, fmt.Errorf("negotiated unknown next_protocol %q", recordProtocol)
Expand Down
6 changes: 3 additions & 3 deletions pkg/net/credentials/alts/handshake/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package handshake
import (
"testing"

"github.com/stretchr/testify/assert"
. "github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common"
"github.com/LumeraProtocol/lumera/x/lumeraid/securekeyx"
. "github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common"
"github.com/stretchr/testify/assert"
)

func TestNewAuthInfo(t *testing.T) {
Expand All @@ -15,4 +15,4 @@ func TestNewAuthInfo(t *testing.T) {
assert.Equal(t, ClientSide, autInfo.(*AuthInfo).Side, "Side should match")
assert.Equal(t, securekeyx.Simplenode, autInfo.(*AuthInfo).RemotePeerType, "RemotePeerType should match")
assert.Equal(t, "cosmos1", autInfo.(*AuthInfo).RemoteIdentity, "RemoteIdentity should match")
}
}
6 changes: 3 additions & 3 deletions pkg/net/credentials/alts/handshake/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ import (
lumeraidmocks "github.com/LumeraProtocol/lumera/x/lumeraid/mocks"
"github.com/LumeraProtocol/lumera/x/lumeraid/securekeyx"
lumeraidtypes "github.com/LumeraProtocol/lumera/x/lumeraid/types"
sntypes "github.com/LumeraProtocol/lumera/x/supernode/v1/types"
. "github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common"
"github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/conn"
"github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/testutil"
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
sntypes "github.com/LumeraProtocol/lumera/x/supernode/v1/types"
. "github.com/LumeraProtocol/supernode/pkg/testutil"
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
)

const defaultTestTimeout = 100 * time.Second
Expand Down Expand Up @@ -331,7 +331,7 @@ func TestHandshakerConcurrentHandshakes(t *testing.T) {
SupernodeAccount: serverAddr,
}, nil).
Times(1)

serverMockValidator := lumeraidmocks.NewMockKeyExchangerValidator(ctrl)
serverMockValidator.EXPECT().
GetSupernodeBySupernodeAddress(gomock.Any(), serverAddr).
Expand Down
16 changes: 8 additions & 8 deletions pkg/net/credentials/alts/testutil/fake_handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package testutil
import (
"bytes"
"fmt"
"time"
"net"
"time"

"github.com/cosmos/gogoproto/proto"
lumeraidtypes "github.com/LumeraProtocol/lumera/x/lumeraid/types"
. "github.com/LumeraProtocol/supernode/pkg/net/credentials/alts/common"
"github.com/cosmos/gogoproto/proto"
)

// FakeHandshakeConn implements a fake handshake connection
Expand Down Expand Up @@ -52,11 +52,11 @@ func (c *FakeHandshakeConn) Close() error {

// FakeHandshaker simulates a peer in the handshake process
type FakeHandshaker struct {
conn *FakeHandshakeConn
signature []byte
pubKey []byte
peerType int32
curve string
conn *FakeHandshakeConn
signature []byte
pubKey []byte
peerType int32
curve string
}

func NewFakeHandshaker(conn *FakeHandshakeConn) *FakeHandshaker {
Expand Down Expand Up @@ -101,4 +101,4 @@ func (h *FakeHandshaker) SimulateHandshake(isClient bool) error {

func (h *FakeHandshaker) SimulateError(err error) {
h.conn.handshakeErr = err
}
}
2 changes: 1 addition & 1 deletion pkg/net/grpc/internal/leakcheck/leakcheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ import (
"sync/atomic"
"time"

"github.com/LumeraProtocol/supernode/pkg/net/grpc/mem"
"github.com/LumeraProtocol/supernode/pkg/net/grpc/internal"
"github.com/LumeraProtocol/supernode/pkg/net/grpc/mem"
)

// failTestsOnLeakedBuffers is a special flag that will cause tests to fail if
Expand Down
1 change: 1 addition & 0 deletions pkg/net/grpc/internal/leakcheck/leakcheck_enabled.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//go:build checkbuffers

package leakcheck

func init() {
Expand Down
10 changes: 5 additions & 5 deletions pkg/testutil/accounts.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package testutil

import (
"testing"
"crypto/ecdh"
"github.com/stretchr/testify/require"
"testing"

"github.com/cosmos/go-bip39"
"github.com/cosmos/cosmos-sdk/codec"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec"
"github.com/cosmos/cosmos-sdk/crypto/hd"
"github.com/cosmos/cosmos-sdk/crypto/keyring"
cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec"
cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types"
"github.com/cosmos/go-bip39"

"github.com/LumeraProtocol/lumera/x/lumeraid/securekeyx"
)
Expand All @@ -27,9 +27,9 @@ type TestAccount struct {
Address string
PubKey cryptotypes.PubKey
}

// setupTestKeyExchange creates a key exchange instance for testing
func SetupTestKeyExchange(t *testing.T, kb keyring.Keyring, addr string,
func SetupTestKeyExchange(t *testing.T, kb keyring.Keyring, addr string,
peerType securekeyx.PeerType, validator securekeyx.KeyExchangerValidator) *securekeyx.SecureKeyExchange {
ke, err := securekeyx.NewSecureKeyExchange(kb, addr, peerType, ecdh.P256(), validator)
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/testutil/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ func GetFreePortInRange(start, end int) (int, error) {
}
}
return 0, fmt.Errorf("no free port found in range %d-%d", start, end)
}
}
Loading