Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 69 additions & 22 deletions examples/cmd/benchmark_experimental.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
package cmd

import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"fmt"
"io"
"net/http"
"os"
"sync"
"time"

Expand All @@ -13,7 +18,6 @@ import (
kasp "github.com/opentdf/platform/protocol/go/kas"
"github.com/opentdf/platform/protocol/go/kas/kasconnect"
"github.com/opentdf/platform/protocol/go/policy"

"github.com/opentdf/platform/sdk/experimental/tdf"
"github.com/opentdf/platform/sdk/httputil"
"github.com/spf13/cobra"
Expand All @@ -35,7 +39,7 @@ func init() {
//nolint: mnd // no magic number, this is just default value for payload size
benchmarkCmd.Flags().IntVar(&payloadSize, "payload-size", 1024*1024, "Payload size in bytes") // Default 1MB
//nolint: mnd // same as above
benchmarkCmd.Flags().IntVar(&segmentChunk, "segment-chunks", 16*1024, "segment chunks ize") // Default 16 segments
benchmarkCmd.Flags().IntVar(&segmentChunk, "segment-chunks", 16*1024, "segment chunk size") // Default 16KB
ExamplesCmd.AddCommand(benchmarkCmd)
}

Expand All @@ -46,16 +50,21 @@ func runExperimentalWriterBenchmark(_ *cobra.Command, _ []string) error {
return fmt.Errorf("failed to generate random payload: %w", err)
}

http := httputil.SafeHTTPClient()
var httpClient *http.Client
if insecureSkipVerify {
httpClient = httputil.SafeHTTPClientWithTLSConfig(&tls.Config{InsecureSkipVerify: true}) //nolint:gosec // user-requested flag
} else {
httpClient = httputil.SafeHTTPClient()
}
fmt.Println("endpoint:", platformEndpoint)
serviceClient := kasconnect.NewAccessServiceClient(http, platformEndpoint)
serviceClient := kasconnect.NewAccessServiceClient(httpClient, platformEndpoint)
resp, err := serviceClient.PublicKey(context.Background(), connect.NewRequest(&kasp.PublicKeyRequest{Algorithm: string(ocrypto.RSA2048Key)}))
if err != nil {
return fmt.Errorf("failed to get public key from KAS: %w", err)
}
var attrs []*policy.Value

simpleyKey := &policy.SimpleKasKey{
simpleKey := &policy.SimpleKasKey{
KasUri: platformEndpoint,
KasId: "id",
PublicKey: &policy.SimpleKasPublicKey{
Expand All @@ -65,29 +74,31 @@ func runExperimentalWriterBenchmark(_ *cobra.Command, _ []string) error {
},
}

attrs = append(attrs, &policy.Value{Fqn: testAttr, KasKeys: []*policy.SimpleKasKey{simpleyKey}, Attribute: &policy.Attribute{Namespace: &policy.Namespace{Name: "example.com"}, Fqn: testAttr}})
writer, err := tdf.NewWriter(context.Background(), tdf.WithDefaultKASForWriter(simpleyKey), tdf.WithInitialAttributes(attrs), tdf.WithSegmentIntegrityAlgorithm(tdf.HS256))
attrs = append(attrs, &policy.Value{Fqn: testAttr, KasKeys: []*policy.SimpleKasKey{simpleKey}, Attribute: &policy.Attribute{Namespace: &policy.Namespace{Name: "example.com"}, Fqn: testAttr}})
writer, err := tdf.NewWriter(context.Background(), tdf.WithDefaultKASForWriter(simpleKey), tdf.WithInitialAttributes(attrs), tdf.WithSegmentIntegrityAlgorithm(tdf.HS256))
if err != nil {
return fmt.Errorf("failed to create writer: %w", err)
}
i := 0
segs := (len(payload) + segmentChunk - 1) / segmentChunk
segResults := make([]*tdf.SegmentResult, segs)
wg := sync.WaitGroup{}
segs := len(payload) / segmentChunk
wg.Add(segs)
start := time.Now()
for i < segs {
segment := i
go func() {
start := i * segmentChunk
end := min(start+segmentChunk, len(payload))
_, err = writer.WriteSegment(context.Background(), segment, payload[start:end])
if err != nil {
fmt.Println(err)
panic(err)
for i := 0; i < segs; i++ {
segStart := i * segmentChunk
segEnd := min(segStart+segmentChunk, len(payload))
// Copy the chunk: EncryptInPlace overwrites the input buffer and
// appends a 16-byte auth tag, which would corrupt adjacent segments.
chunk := make([]byte, segEnd-segStart)
copy(chunk, payload[segStart:segEnd])
go func(index int, data []byte) {
defer wg.Done()
sr, serr := writer.WriteSegment(context.Background(), index, data)
if serr != nil {
panic(serr)
}
wg.Done()
}()
i++
segResults[index] = sr
}(i, chunk)
}
wg.Wait()

Expand All @@ -98,12 +109,48 @@ func runExperimentalWriterBenchmark(_ *cobra.Command, _ []string) error {
}
totalTime := end.Sub(start)

// Assemble the complete TDF: segment data (in order) + finalize data
var tdfBuf bytes.Buffer
for i, sr := range segResults {
if _, err := io.Copy(&tdfBuf, sr.TDFData); err != nil {
return fmt.Errorf("failed to read segment %d TDF data: %w", i, err)
}
}
tdfBuf.Write(result.Data)

outPath := "/tmp/benchmark-experimental.tdf"
if err := os.WriteFile(outPath, tdfBuf.Bytes(), 0o600); err != nil {
return fmt.Errorf("failed to write TDF: %w", err)
}

fmt.Printf("# Benchmark Experimental TDF Writer Results:\n")
fmt.Printf("| Metric | Value |\n")
fmt.Printf("|--------------------|--------------|\n")
fmt.Printf("| Payload Size (B) | %d |\n", payloadSize)
fmt.Printf("| Output Size (B) | %d |\n", len(result.Data))
fmt.Printf("| Output Size (B) | %d |\n", tdfBuf.Len())
fmt.Printf("| Total Time | %s |\n", totalTime)
fmt.Printf("| TDF saved to | %s |\n", outPath)

// Decrypt with production SDK to verify interoperability
s, err := newSDK()
if err != nil {
return fmt.Errorf("failed to create SDK: %w", err)
}
defer s.Close()
tdfReader, err := s.LoadTDF(bytes.NewReader(tdfBuf.Bytes()))
if err != nil {
return fmt.Errorf("failed to load TDF with production SDK: %w", err)
}
var decrypted bytes.Buffer
if _, err = io.Copy(&decrypted, tdfReader); err != nil {
return fmt.Errorf("failed to decrypt TDF with production SDK: %w", err)
}

if bytes.Equal(payload, decrypted.Bytes()) {
fmt.Println("| Decrypt Verify | PASS - roundtrip matches |")
} else {
fmt.Printf("| Decrypt Verify | FAIL - payload %d bytes, decrypted %d bytes |\n", len(payload), decrypted.Len())
}

return nil
}
16 changes: 16 additions & 0 deletions sdk/experimental/tdf/keysplit/xor_splitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,22 @@ func (x *XORSplitter) GenerateSplits(_ context.Context, attrs []*policy.Value, d
// 4. Collect all public keys from assignments
allKeys := collectAllPublicKeys(assignments)

// 5. Merge the default KAS public key if not already present.
// Attribute grants may reference the default KAS URL without including the public key
// (e.g., legacy grants with only a URI). The default KAS key fills this gap.
if x.config.defaultKAS != nil && x.config.defaultKAS.GetPublicKey() != nil {
kasURL := x.config.defaultKAS.GetKasUri()
if _, exists := allKeys[kasURL]; !exists {
pubKey := x.config.defaultKAS.GetPublicKey()
allKeys[kasURL] = KASPublicKey{
URL: kasURL,
KID: pubKey.GetKid(),
PEM: pubKey.GetPem(),
Algorithm: formatAlgorithm(pubKey.GetAlgorithm()),
}
}
}

slog.Debug("completed key split generation",
slog.Int("num_splits", len(splits)),
slog.Int("num_kas_keys", len(allKeys)))
Expand Down
70 changes: 70 additions & 0 deletions sdk/experimental/tdf/keysplit/xor_splitter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,3 +526,73 @@ func TestXORSplitter_ComplexScenarios(t *testing.T) {
assert.True(t, found, "Should find split with multiple KAS URLs")
})
}

// TestXORSplitter_DefaultKASMergedForURIOnlyGrant is a regression test
// ensuring that when an attribute grant references a KAS URL without
// embedding the public key (URI-only legacy grant), the default KAS's
// full public key info is merged into the result. Without the merge fix
// in GenerateSplits, collectAllPublicKeys returns an incomplete map and
// key wrapping fails.
func TestXORSplitter_DefaultKASMergedForURIOnlyGrant(t *testing.T) {
defaultKAS := &policy.SimpleKasKey{
KasUri: kasUs,
PublicKey: &policy.SimpleKasPublicKey{
Algorithm: policy.Algorithm_ALGORITHM_RSA_2048,
Kid: "default-kid",
Pem: mockRSAPublicKey1,
},
}
splitter := NewXORSplitter(WithDefaultKAS(defaultKAS))

dek := make([]byte, 32)
_, err := rand.Read(dek)
require.NoError(t, err)

// Create an attribute whose grant references kasUs by URI only (no KasKeys).
attr := createMockValue("https://test.com/attr/level/value/secret", "", "", policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF)
attr.Grants = []*policy.KeyAccessServer{
{Uri: kasUs}, // URI-only, no embedded public key
}

result, err := splitter.GenerateSplits(t.Context(), []*policy.Value{attr}, dek)
require.NoError(t, err)
require.NotNil(t, result)

// The default KAS public key must be merged into the result.
require.Contains(t, result.KASPublicKeys, kasUs, "default KAS key should be merged for URI-only grant")
pubKey := result.KASPublicKeys[kasUs]
assert.Equal(t, "default-kid", pubKey.KID)
assert.Equal(t, mockRSAPublicKey1, pubKey.PEM)
assert.Equal(t, "rsa:2048", pubKey.Algorithm)
}

// TestXORSplitter_DefaultKASDoesNotOverwriteExistingKey verifies that when
// an attribute grant already embeds a full public key for the same KAS URL
// as the default, the grant's key is preserved and not overwritten.
func TestXORSplitter_DefaultKASDoesNotOverwriteExistingKey(t *testing.T) {
defaultKAS := &policy.SimpleKasKey{
KasUri: kasUs,
PublicKey: &policy.SimpleKasPublicKey{
Algorithm: policy.Algorithm_ALGORITHM_RSA_2048,
Kid: "default-kid",
Pem: mockRSAPublicKey1,
},
}
splitter := NewXORSplitter(WithDefaultKAS(defaultKAS))

dek := make([]byte, 32)
_, err := rand.Read(dek)
require.NoError(t, err)

// Create an attribute with a fully-embedded grant for the same KAS URL
// but with a different KID.
attr := createMockValue("https://test.com/attr/level/value/secret", kasUs, "grant-kid", policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF)

result, err := splitter.GenerateSplits(t.Context(), []*policy.Value{attr}, dek)
require.NoError(t, err)
require.NotNil(t, result)

require.Contains(t, result.KASPublicKeys, kasUs)
pubKey := result.KASPublicKeys[kasUs]
assert.Equal(t, "grant-kid", pubKey.KID, "grant's key should not be overwritten by default KAS")
}
8 changes: 4 additions & 4 deletions sdk/experimental/tdf/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ type BaseConfig struct{}
// WriterConfig contains configuration options for TDF Writer creation.
//
// The configuration controls cryptographic algorithms and processing behavior:
// - integrityAlgorithm: Algorithm for root integrity signature calculation
// - rootIntegrityAlgorithm: Algorithm for root integrity signature calculation
// - segmentIntegrityAlgorithm: Algorithm for individual segment hash calculation
//
// These can be set independently to optimize for different security/performance requirements.
type WriterConfig struct {
BaseConfig
// integrityAlgorithm specifies the algorithm for root integrity verification
integrityAlgorithm IntegrityAlgorithm
// rootIntegrityAlgorithm specifies the algorithm for root integrity verification
rootIntegrityAlgorithm IntegrityAlgorithm
// segmentIntegrityAlgorithm specifies the algorithm for segment-level integrity
segmentIntegrityAlgorithm IntegrityAlgorithm

Expand Down Expand Up @@ -95,7 +95,7 @@ type Option[T any] func(T)
// writer, err := NewWriter(ctx, WithIntegrityAlgorithm(GMAC))
func WithIntegrityAlgorithm(algo IntegrityAlgorithm) Option[*WriterConfig] {
return func(c *WriterConfig) {
c.integrityAlgorithm = algo
c.rootIntegrityAlgorithm = algo
}
}

Expand Down
29 changes: 19 additions & 10 deletions sdk/experimental/tdf/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ type Writer struct {
func NewWriter(_ context.Context, opts ...Option[*WriterConfig]) (*Writer, error) {
// Initialize Config
config := &WriterConfig{
integrityAlgorithm: HS256,
segmentIntegrityAlgorithm: HS256,
rootIntegrityAlgorithm: HS256,
segmentIntegrityAlgorithm: GMAC,
}

for _, opt := range opts {
Expand Down Expand Up @@ -264,7 +264,11 @@ func (w *Writer) WriteSegment(ctx context.Context, index int, data []byte) (*Seg
if err != nil {
return nil, err
}
segmentSig, err := calculateSignature(segmentCipher, w.dek, w.segmentIntegrityAlgorithm, false) // Don't ever hex encode new tdf's
// Hash must cover nonce + cipher to match the standard SDK reader's verification.
// The standard SDK's Encrypt() returns nonce prepended to cipher and hashes that;
// EncryptInPlace() returns them separately, so we must concatenate for hashing.
segmentData := append(nonce, segmentCipher...) //nolint:gocritic // nonce cap == len, so always allocates
segmentSig, err := calculateSignature(segmentData, w.dek, w.segmentIntegrityAlgorithm, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -562,13 +566,18 @@ func (w *Writer) getManifest(ctx context.Context, cfg *WriterFinalizeConfig) (*M
return nil, 0, 0, errors.New("empty segment hash")
}

rootSignature, err := calculateSignature(aggregateHash.Bytes(), w.dek, w.integrityAlgorithm, false)
if err != nil {
return nil, 0, 0, err
}
encryptInfo.RootSignature = RootSignature{
Algorithm: w.integrityAlgorithm.String(),
Signature: string(ocrypto.Base64Encode([]byte(rootSignature))),
// Only compute root signature when segments have been written; stub
// manifests returned before any WriteSegment call leave the root
// signature empty.
if aggregateHash.Len() > 0 {
rootSignature, err := calculateSignature(aggregateHash.Bytes(), w.dek, w.rootIntegrityAlgorithm, false)
if err != nil {
return nil, 0, 0, err
}
encryptInfo.RootSignature = RootSignature{
Algorithm: w.rootIntegrityAlgorithm.String(),
Signature: string(ocrypto.Base64Encode([]byte(rootSignature))),
}
}

keyAccessList, err := buildKeyAccessObjects(result, policyBytes, cfg.encryptedMetadata)
Expand Down
Loading
Loading