Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions lib/ocrypto/algorithm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package ocrypto

// Algorithm identifiers for BaseTDF v4.4.0 Key Access Objects.
// These are the string values used in the "alg" field of KAOs.
const (
AlgRSAOAEP = "RSA-OAEP"
AlgRSAOAEP256 = "RSA-OAEP-256"
AlgECDHHKDF = "ECDH-HKDF"
AlgMLKEM768 = "ML-KEM-768"
AlgMLKEM1024 = "ML-KEM-1024"
AlgHybridECDH = "X-ECDH-ML-KEM-768"
)

// AlgForKeyType returns the BaseTDF algorithm identifier for the given KeyType.
// This maps internal key types to the explicit algorithm strings used in v4.4.0 KAOs.
func AlgForKeyType(kt KeyType) string {
switch kt {
case RSA2048Key, RSA4096Key:
return AlgRSAOAEP
case EC256Key, EC384Key, EC521Key:
return AlgECDHHKDF
default:
return string(kt)
}
}

// KeyTypeForAlg returns the KeyType for a given BaseTDF algorithm identifier.
// Returns the algorithm string as a KeyType if no specific mapping exists.
func KeyTypeForAlg(alg string) KeyType {
switch alg {
case AlgRSAOAEP, AlgRSAOAEP256:
return RSA2048Key
case AlgECDHHKDF:
return EC256Key
default:
return KeyType(alg)
}
Comment on lines +29 to +37
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The mapping in KeyTypeForAlg is lossy. For example, AlgForKeyType maps both RSA2048Key and RSA4096Key to AlgRSAOAEP, but this function maps AlgRSAOAEP back to only RSA2048Key. Similarly, EC256Key, EC384Key, and EC521Key are all mapped to AlgECDHHKDF, which is then mapped back only to EC256Key.

This loss of information about the key size could lead to unexpected behavior or bugs if the caller of this function relies on getting a precise key type.

If this lossy conversion is intentional and safe within the current design, it would be beneficial to add a comment explaining why it's acceptable. Otherwise, consider making the mapping more specific or returning an error for ambiguous cases.

}

// AlgForLegacyType maps the v4.3.0 "type" field values to algorithm identifiers.
func AlgForLegacyType(legacyType string) string {
switch legacyType {
case "wrapped":
return AlgRSAOAEP
case "ec-wrapped":
return AlgECDHHKDF
default:
return ""
}
}
58 changes: 58 additions & 0 deletions lib/ocrypto/algorithm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package ocrypto

import "testing"

func TestAlgForKeyType(t *testing.T) {
tests := []struct {
keyType KeyType
want string
}{
{RSA2048Key, AlgRSAOAEP},
{RSA4096Key, AlgRSAOAEP},
{EC256Key, AlgECDHHKDF},
{EC384Key, AlgECDHHKDF},
{EC521Key, AlgECDHHKDF},
{"", ""},
{KeyType("unknown"), "unknown"},
}
for _, tt := range tests {
if got := AlgForKeyType(tt.keyType); got != tt.want {
t.Errorf("AlgForKeyType(%q) = %q, want %q", tt.keyType, got, tt.want)
}
}
}
Comment on lines +5 to +23
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better test readability and output, consider using table-driven subtests with t.Run. This makes it easier to identify which specific test case fails.

Example:

func TestAlgForKeyType(t *testing.T) {
	tests := []struct {
		name    string
		keyType KeyType
		want    string
	}{
		{"RSA2048Key", RSA2048Key, AlgRSAOAEP},
		{"RSA4096Key", RSA4096Key, AlgRSAOAEP},
		// ... other test cases
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := AlgForKeyType(tt.keyType); got != tt.want {
				t.Errorf("AlgForKeyType(%q) = %q, want %q", tt.keyType, got, tt.want)
			}
		})
	}
}

This can be applied to the other test functions in this file as well.


func TestKeyTypeForAlg(t *testing.T) {
tests := []struct {
alg string
want KeyType
}{
{AlgRSAOAEP, RSA2048Key},
{AlgRSAOAEP256, RSA2048Key},
{AlgECDHHKDF, EC256Key},
{"", KeyType("")},
{"unknown", KeyType("unknown")},
}
for _, tt := range tests {
if got := KeyTypeForAlg(tt.alg); got != tt.want {
t.Errorf("KeyTypeForAlg(%q) = %q, want %q", tt.alg, got, tt.want)
}
}
}

func TestAlgForLegacyType(t *testing.T) {
tests := []struct {
legacy string
want string
}{
{"wrapped", AlgRSAOAEP},
{"ec-wrapped", AlgECDHHKDF},
{"", ""},
{"unknown", ""},
}
for _, tt := range tests {
if got := AlgForLegacyType(tt.legacy); got != tt.want {
t.Errorf("AlgForLegacyType(%q) = %q, want %q", tt.legacy, got, tt.want)
}
}
}
8 changes: 4 additions & 4 deletions sdk/codegen/runner/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,12 @@ func New%s%s%sConnectWrapper(httpClient connect.HTTPClient, baseURL string, opts
func generateInterfaceType(interfaceName string, methods []string, packageName, prefix, suffix string) string {
// Generate the interface type definition
var builder strings.Builder
builder.WriteString(fmt.Sprintf(`
fmt.Fprintf(&builder, `
type %s%s%s interface {
`, prefix, interfaceName, suffix))
`, prefix, interfaceName, suffix)
for _, method := range methods {
builder.WriteString(fmt.Sprintf(` %s(ctx context.Context, req *%s.%sRequest) (*%s.%sResponse, error)
`, method, packageName, method, packageName, method))
fmt.Fprintf(&builder, ` %s(ctx context.Context, req *%s.%sRequest) (*%s.%sResponse, error)
`, method, packageName, method, packageName, method)
}
builder.WriteString("}\n")
return builder.String()
Expand Down
4 changes: 2 additions & 2 deletions sdk/kas_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ type KASKeyFetcher interface {

func (s SDK) getPublicKey(ctx context.Context, kasurl, algorithm, kidToFind string) (*KASInfo, error) {
if s.kasKeyCache != nil {
if cachedValue := s.kasKeyCache.get(kasurl, algorithm, kidToFind); nil != cachedValue {
if cachedValue := s.get(kasurl, algorithm, kidToFind); nil != cachedValue {
return cachedValue, nil
}
}
Expand Down Expand Up @@ -459,7 +459,7 @@ func (s SDK) getPublicKey(ctx context.Context, kasurl, algorithm, kidToFind stri
PublicKey: resp.Msg.GetPublicKey(),
}
if s.kasKeyCache != nil {
s.kasKeyCache.store(ki)
s.store(ki)
}
return &ki, nil
}
12 changes: 6 additions & 6 deletions sdk/kas_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,19 +141,19 @@ func Test_StoreKASKeys(t *testing.T) {
)
require.NoError(t, err)

assert.Nil(t, s.kasKeyCache.get("https://localhost:8080", "ec:secp256r1", "e1"))
assert.Nil(t, s.kasKeyCache.get("https://localhost:8080", "rsa:2048", "r1"))
assert.Nil(t, s.get("https://localhost:8080", "ec:secp256r1", "e1"))
assert.Nil(t, s.get("https://localhost:8080", "rsa:2048", "r1"))

require.NoError(t, s.StoreKASKeys("https://localhost:8080", &policy.KasPublicKeySet{
Keys: []*policy.KasPublicKey{
{Pem: "sample", Kid: "e1", Alg: policy.KasPublicKeyAlgEnum_KAS_PUBLIC_KEY_ALG_ENUM_EC_SECP256R1},
{Pem: "sample", Kid: "r1", Alg: policy.KasPublicKeyAlgEnum_KAS_PUBLIC_KEY_ALG_ENUM_RSA_2048},
},
}))
assert.Nil(t, s.kasKeyCache.get("https://nowhere", "alg:unknown", ""))
assert.Nil(t, s.kasKeyCache.get("https://localhost:8080", "alg:unknown", ""))
ecKey := s.kasKeyCache.get("https://localhost:8080", "ec:secp256r1", "e1")
rsaKey := s.kasKeyCache.get("https://localhost:8080", "rsa:2048", "r1")
assert.Nil(t, s.get("https://nowhere", "alg:unknown", ""))
assert.Nil(t, s.get("https://localhost:8080", "alg:unknown", ""))
ecKey := s.get("https://localhost:8080", "ec:secp256r1", "e1")
rsaKey := s.get("https://localhost:8080", "rsa:2048", "r1")
require.NotNil(t, ecKey)
require.Equal(t, "e1", ecKey.KID)
require.NotNil(t, rsaKey)
Expand Down
43 changes: 33 additions & 10 deletions sdk/manifest.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package sdk

import (
"encoding/json"

"github.com/opentdf/platform/lib/ocrypto"
)

type Segment struct {
Hash string `json:"hash"`
Size int64 `json:"segmentSize"`
Expand All @@ -20,16 +26,33 @@ type IntegrityInformation struct {
}

type KeyAccess struct {
KeyType string `json:"type"`
KasURL string `json:"url"`
Protocol string `json:"protocol"`
WrappedKey string `json:"wrappedKey"`
PolicyBinding interface{} `json:"policyBinding"`
EncryptedMetadata string `json:"encryptedMetadata,omitempty"`
KID string `json:"kid,omitempty"`
SplitID string `json:"sid,omitempty"`
SchemaVersion string `json:"schemaVersion,omitempty"`
EphemeralPublicKey string `json:"ephemeralPublicKey,omitempty"`
KeyType string `json:"type"`
Algorithm string `json:"alg,omitempty"`
KasURL string `json:"url"`
Protocol string `json:"protocol"`
WrappedKey string `json:"wrappedKey"`
PolicyBinding any `json:"policyBinding"`
EncryptedMetadata string `json:"encryptedMetadata,omitempty"`
KID string `json:"kid,omitempty"`
SplitID string `json:"sid,omitempty"`
SchemaVersion string `json:"schemaVersion,omitempty"`
EphemeralPublicKey string `json:"ephemeralPublicKey,omitempty"`
}

func (ka *KeyAccess) UnmarshalJSON(data []byte) error {
// Use an alias to avoid infinite recursion
type keyAccessAlias KeyAccess
var raw keyAccessAlias
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
*ka = KeyAccess(raw)

// If Algorithm not set but KeyType is, infer algorithm from legacy type
if ka.Algorithm == "" && ka.KeyType != "" {
ka.Algorithm = ocrypto.AlgForLegacyType(ka.KeyType)
}
return nil
}

type PolicyBinding struct {
Expand Down
116 changes: 116 additions & 0 deletions sdk/manifest_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package sdk

import (
"encoding/json"
"testing"
)

func TestKeyAccessUnmarshalJSON_V43Compat(t *testing.T) {
tests := []struct {
name string
input string
wantKeyType string
wantAlg string
}{
{
name: "v4.3 wrapped infers RSA-OAEP",
input: `{"type":"wrapped","url":"https://kas.example.com","protocol":"kas","wrappedKey":"abc"}`,
wantKeyType: "wrapped",
wantAlg: "RSA-OAEP",
},
{
name: "v4.3 ec-wrapped infers ECDH-HKDF",
input: `{"type":"ec-wrapped","url":"https://kas.example.com","protocol":"kas","wrappedKey":"abc"}`,
wantKeyType: "ec-wrapped",
wantAlg: "ECDH-HKDF",
},
{
name: "v4.4 explicit alg preserved",
input: `{"type":"wrapped","alg":"RSA-OAEP-256","url":"https://kas.example.com","protocol":"kas","wrappedKey":"abc"}`,
wantKeyType: "wrapped",
wantAlg: "RSA-OAEP-256",
},
{
name: "v4.4 ML-KEM-768 preserved",
input: `{"type":"","alg":"ML-KEM-768","url":"https://kas.example.com","wrappedKey":"abc"}`,
wantKeyType: "",
wantAlg: "ML-KEM-768",
},
{
name: "unknown type does not infer alg",
input: `{"type":"remote","url":"https://kas.example.com","wrappedKey":"abc"}`,
wantKeyType: "remote",
wantAlg: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var ka KeyAccess
if err := json.Unmarshal([]byte(tt.input), &ka); err != nil {
t.Fatalf("UnmarshalJSON() error = %v", err)
}
if ka.KeyType != tt.wantKeyType {
t.Errorf("KeyType = %q, want %q", ka.KeyType, tt.wantKeyType)
}
if ka.Algorithm != tt.wantAlg {
t.Errorf("Algorithm = %q, want %q", ka.Algorithm, tt.wantAlg)
}
})
}
}

func TestKeyAccessMarshalJSON_V44Fields(t *testing.T) {
ka := KeyAccess{
KeyType: "wrapped",
Algorithm: "RSA-OAEP",
KasURL: "https://kas.example.com",
Protocol: "kas",
WrappedKey: "abc",
PolicyBinding: PolicyBinding{Alg: "HS256", Hash: "def"},
KID: "k1",
SplitID: "s1",
}

data, err := json.Marshal(ka)
if err != nil {
t.Fatalf("Marshal() error = %v", err)
}

var roundTrip KeyAccess
if err := json.Unmarshal(data, &roundTrip); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if roundTrip.Algorithm != "RSA-OAEP" {
t.Errorf("Algorithm = %q, want %q", roundTrip.Algorithm, "RSA-OAEP")
}
if roundTrip.KID != "k1" {
t.Errorf("KID = %q, want %q", roundTrip.KID, "k1")
}
if roundTrip.SplitID != "s1" {
t.Errorf("SplitID = %q, want %q", roundTrip.SplitID, "s1")
}
}

func TestKeyAccessMarshalJSON_OmitEmptyAlg(t *testing.T) {
ka := KeyAccess{
KeyType: "wrapped",
KasURL: "https://kas.example.com",
Protocol: "kas",
WrappedKey: "abc",
PolicyBinding: "hash",
}

data, err := json.Marshal(ka)
if err != nil {
t.Fatalf("Marshal() error = %v", err)
}

// Verify "alg" is not present when empty
var raw map[string]any
if err := json.Unmarshal(data, &raw); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if _, ok := raw["alg"]; ok {
t.Error("expected alg field to be omitted when empty")
}
}
2 changes: 1 addition & 1 deletion sdk/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ func getTokenEndpoint(c config) (string, error) {
// so only store the most recent known key per url & algorithm pair.
func (s *SDK) StoreKASKeys(url string, keys *policy.KasPublicKeySet) error {
for _, key := range keys.GetKeys() {
s.kasKeyCache.store(KASInfo{
s.store(KASInfo{
URL: url,
PublicKey: key.GetPem(),
KID: key.GetKid(),
Expand Down
14 changes: 11 additions & 3 deletions sdk/tdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,8 +586,15 @@ func (s SDK) prepareManifest(ctx context.Context, t *TDFObject, tdfConfig TDFCon
symKeys = append(symKeys, symKey)

// policy binding
policyBindingHash := hex.EncodeToString(ocrypto.CalculateSHA256Hmac(symKey, base64PolicyObject))
pbstring := string(ocrypto.Base64Encode([]byte(policyBindingHash)))
hmacBytes := ocrypto.CalculateSHA256Hmac(symKey, base64PolicyObject)
var pbstring string
if tdfConfig.useHex {
// Legacy format: hex encode then base64
pbstring = string(ocrypto.Base64Encode([]byte(hex.EncodeToString(hmacBytes))))
} else {
// v4.4.0 format: direct base64 of HMAC bytes
pbstring = string(ocrypto.Base64Encode(hmacBytes))
}
policyBinding := PolicyBinding{
Alg: "HS256",
Hash: pbstring,
Expand Down Expand Up @@ -662,8 +669,10 @@ func encryptMetadata(symKey []byte, metaData string) (string, error) {
}

func createKeyAccess(kasInfo KASInfo, symKey []byte, policyBinding PolicyBinding, encryptedMetadata, splitID string) (KeyAccess, error) {
ktype := ocrypto.KeyType(kasInfo.Algorithm)
keyAccess := KeyAccess{
KeyType: kWrapped,
Algorithm: ocrypto.AlgForKeyType(ktype),
KasURL: kasInfo.URL,
KID: kasInfo.KID,
Protocol: kKasProtocol,
Expand All @@ -673,7 +682,6 @@ func createKeyAccess(kasInfo KASInfo, symKey []byte, policyBinding PolicyBinding
SchemaVersion: keyAccessSchemaVersion,
}

ktype := ocrypto.KeyType(kasInfo.Algorithm)
if ocrypto.IsECKeyType(ktype) {
mode, err := ocrypto.ECKeyTypeToMode(ktype)
if err != nil {
Expand Down
Loading
Loading