diff --git a/examples/cmd/benchmark_experimental.go b/examples/cmd/benchmark_experimental.go index f50c5cc348..f18ec0a4a0 100644 --- a/examples/cmd/benchmark_experimental.go +++ b/examples/cmd/benchmark_experimental.go @@ -2,9 +2,14 @@ package cmd import ( + "bytes" "context" "crypto/rand" + "crypto/tls" "fmt" + "io" + "net/http" + "os" "sync" "time" @@ -13,7 +18,6 @@ import ( kasp "github.com/opentdf/platform/protocol/go/kas" "github.com/opentdf/platform/protocol/go/kas/kasconnect" "github.com/opentdf/platform/protocol/go/policy" - "github.com/opentdf/platform/sdk/experimental/tdf" "github.com/opentdf/platform/sdk/httputil" "github.com/spf13/cobra" @@ -35,7 +39,7 @@ func init() { //nolint: mnd // no magic number, this is just default value for payload size benchmarkCmd.Flags().IntVar(&payloadSize, "payload-size", 1024*1024, "Payload size in bytes") // Default 1MB //nolint: mnd // same as above - benchmarkCmd.Flags().IntVar(&segmentChunk, "segment-chunks", 16*1024, "segment chunks ize") // Default 16 segments + benchmarkCmd.Flags().IntVar(&segmentChunk, "segment-chunks", 16*1024, "segment chunk size") // Default 16KB ExamplesCmd.AddCommand(benchmarkCmd) } @@ -46,16 +50,21 @@ func runExperimentalWriterBenchmark(_ *cobra.Command, _ []string) error { return fmt.Errorf("failed to generate random payload: %w", err) } - http := httputil.SafeHTTPClient() + var httpClient *http.Client + if insecureSkipVerify { + httpClient = httputil.SafeHTTPClientWithTLSConfig(&tls.Config{InsecureSkipVerify: true}) //nolint:gosec // user-requested flag + } else { + httpClient = httputil.SafeHTTPClient() + } fmt.Println("endpoint:", platformEndpoint) - serviceClient := kasconnect.NewAccessServiceClient(http, platformEndpoint) + serviceClient := kasconnect.NewAccessServiceClient(httpClient, platformEndpoint) resp, err := serviceClient.PublicKey(context.Background(), connect.NewRequest(&kasp.PublicKeyRequest{Algorithm: string(ocrypto.RSA2048Key)})) if err != nil { return fmt.Errorf("failed to get public key from KAS: %w", err) } var attrs []*policy.Value - simpleyKey := &policy.SimpleKasKey{ + simpleKey := &policy.SimpleKasKey{ KasUri: platformEndpoint, KasId: "id", PublicKey: &policy.SimpleKasPublicKey{ @@ -65,29 +74,31 @@ func runExperimentalWriterBenchmark(_ *cobra.Command, _ []string) error { }, } - attrs = append(attrs, &policy.Value{Fqn: testAttr, KasKeys: []*policy.SimpleKasKey{simpleyKey}, Attribute: &policy.Attribute{Namespace: &policy.Namespace{Name: "example.com"}, Fqn: testAttr}}) - writer, err := tdf.NewWriter(context.Background(), tdf.WithDefaultKASForWriter(simpleyKey), tdf.WithInitialAttributes(attrs), tdf.WithSegmentIntegrityAlgorithm(tdf.HS256)) + attrs = append(attrs, &policy.Value{Fqn: testAttr, KasKeys: []*policy.SimpleKasKey{simpleKey}, Attribute: &policy.Attribute{Namespace: &policy.Namespace{Name: "example.com"}, Fqn: testAttr}}) + writer, err := tdf.NewWriter(context.Background(), tdf.WithDefaultKASForWriter(simpleKey), tdf.WithInitialAttributes(attrs), tdf.WithSegmentIntegrityAlgorithm(tdf.HS256)) if err != nil { return fmt.Errorf("failed to create writer: %w", err) } - i := 0 + segs := (len(payload) + segmentChunk - 1) / segmentChunk + segResults := make([]*tdf.SegmentResult, segs) wg := sync.WaitGroup{} - segs := len(payload) / segmentChunk wg.Add(segs) start := time.Now() - for i < segs { - segment := i - go func() { - start := i * segmentChunk - end := min(start+segmentChunk, len(payload)) - _, err = writer.WriteSegment(context.Background(), segment, payload[start:end]) - if err != nil { - fmt.Println(err) - panic(err) + for i := 0; i < segs; i++ { + segStart := i * segmentChunk + segEnd := min(segStart+segmentChunk, len(payload)) + // Copy the chunk: EncryptInPlace overwrites the input buffer and + // appends a 16-byte auth tag, which would corrupt adjacent segments. + chunk := make([]byte, segEnd-segStart) + copy(chunk, payload[segStart:segEnd]) + go func(index int, data []byte) { + defer wg.Done() + sr, serr := writer.WriteSegment(context.Background(), index, data) + if serr != nil { + panic(serr) } - wg.Done() - }() - i++ + segResults[index] = sr + }(i, chunk) } wg.Wait() @@ -98,12 +109,48 @@ func runExperimentalWriterBenchmark(_ *cobra.Command, _ []string) error { } totalTime := end.Sub(start) + // Assemble the complete TDF: segment data (in order) + finalize data + var tdfBuf bytes.Buffer + for i, sr := range segResults { + if _, err := io.Copy(&tdfBuf, sr.TDFData); err != nil { + return fmt.Errorf("failed to read segment %d TDF data: %w", i, err) + } + } + tdfBuf.Write(result.Data) + + outPath := "/tmp/benchmark-experimental.tdf" + if err := os.WriteFile(outPath, tdfBuf.Bytes(), 0o600); err != nil { + return fmt.Errorf("failed to write TDF: %w", err) + } + fmt.Printf("# Benchmark Experimental TDF Writer Results:\n") fmt.Printf("| Metric | Value |\n") fmt.Printf("|--------------------|--------------|\n") fmt.Printf("| Payload Size (B) | %d |\n", payloadSize) - fmt.Printf("| Output Size (B) | %d |\n", len(result.Data)) + fmt.Printf("| Output Size (B) | %d |\n", tdfBuf.Len()) fmt.Printf("| Total Time | %s |\n", totalTime) + fmt.Printf("| TDF saved to | %s |\n", outPath) + + // Decrypt with production SDK to verify interoperability + s, err := newSDK() + if err != nil { + return fmt.Errorf("failed to create SDK: %w", err) + } + defer s.Close() + tdfReader, err := s.LoadTDF(bytes.NewReader(tdfBuf.Bytes())) + if err != nil { + return fmt.Errorf("failed to load TDF with production SDK: %w", err) + } + var decrypted bytes.Buffer + if _, err = io.Copy(&decrypted, tdfReader); err != nil { + return fmt.Errorf("failed to decrypt TDF with production SDK: %w", err) + } + + if bytes.Equal(payload, decrypted.Bytes()) { + fmt.Println("| Decrypt Verify | PASS - roundtrip matches |") + } else { + fmt.Printf("| Decrypt Verify | FAIL - payload %d bytes, decrypted %d bytes |\n", len(payload), decrypted.Len()) + } return nil } diff --git a/sdk/experimental/tdf/keysplit/xor_splitter.go b/sdk/experimental/tdf/keysplit/xor_splitter.go index 150a2c2028..accc1875f6 100644 --- a/sdk/experimental/tdf/keysplit/xor_splitter.go +++ b/sdk/experimental/tdf/keysplit/xor_splitter.go @@ -124,6 +124,22 @@ func (x *XORSplitter) GenerateSplits(_ context.Context, attrs []*policy.Value, d // 4. Collect all public keys from assignments allKeys := collectAllPublicKeys(assignments) + // 5. Merge the default KAS public key if not already present. + // Attribute grants may reference the default KAS URL without including the public key + // (e.g., legacy grants with only a URI). The default KAS key fills this gap. + if x.config.defaultKAS != nil && x.config.defaultKAS.GetPublicKey() != nil { + kasURL := x.config.defaultKAS.GetKasUri() + if _, exists := allKeys[kasURL]; !exists { + pubKey := x.config.defaultKAS.GetPublicKey() + allKeys[kasURL] = KASPublicKey{ + URL: kasURL, + KID: pubKey.GetKid(), + PEM: pubKey.GetPem(), + Algorithm: formatAlgorithm(pubKey.GetAlgorithm()), + } + } + } + slog.Debug("completed key split generation", slog.Int("num_splits", len(splits)), slog.Int("num_kas_keys", len(allKeys))) diff --git a/sdk/experimental/tdf/keysplit/xor_splitter_test.go b/sdk/experimental/tdf/keysplit/xor_splitter_test.go index e5bd229abf..e620c5cd9a 100644 --- a/sdk/experimental/tdf/keysplit/xor_splitter_test.go +++ b/sdk/experimental/tdf/keysplit/xor_splitter_test.go @@ -526,3 +526,73 @@ func TestXORSplitter_ComplexScenarios(t *testing.T) { assert.True(t, found, "Should find split with multiple KAS URLs") }) } + +// TestXORSplitter_DefaultKASMergedForURIOnlyGrant is a regression test +// ensuring that when an attribute grant references a KAS URL without +// embedding the public key (URI-only legacy grant), the default KAS's +// full public key info is merged into the result. Without the merge fix +// in GenerateSplits, collectAllPublicKeys returns an incomplete map and +// key wrapping fails. +func TestXORSplitter_DefaultKASMergedForURIOnlyGrant(t *testing.T) { + defaultKAS := &policy.SimpleKasKey{ + KasUri: kasUs, + PublicKey: &policy.SimpleKasPublicKey{ + Algorithm: policy.Algorithm_ALGORITHM_RSA_2048, + Kid: "default-kid", + Pem: mockRSAPublicKey1, + }, + } + splitter := NewXORSplitter(WithDefaultKAS(defaultKAS)) + + dek := make([]byte, 32) + _, err := rand.Read(dek) + require.NoError(t, err) + + // Create an attribute whose grant references kasUs by URI only (no KasKeys). + attr := createMockValue("https://test.com/attr/level/value/secret", "", "", policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF) + attr.Grants = []*policy.KeyAccessServer{ + {Uri: kasUs}, // URI-only, no embedded public key + } + + result, err := splitter.GenerateSplits(t.Context(), []*policy.Value{attr}, dek) + require.NoError(t, err) + require.NotNil(t, result) + + // The default KAS public key must be merged into the result. + require.Contains(t, result.KASPublicKeys, kasUs, "default KAS key should be merged for URI-only grant") + pubKey := result.KASPublicKeys[kasUs] + assert.Equal(t, "default-kid", pubKey.KID) + assert.Equal(t, mockRSAPublicKey1, pubKey.PEM) + assert.Equal(t, "rsa:2048", pubKey.Algorithm) +} + +// TestXORSplitter_DefaultKASDoesNotOverwriteExistingKey verifies that when +// an attribute grant already embeds a full public key for the same KAS URL +// as the default, the grant's key is preserved and not overwritten. +func TestXORSplitter_DefaultKASDoesNotOverwriteExistingKey(t *testing.T) { + defaultKAS := &policy.SimpleKasKey{ + KasUri: kasUs, + PublicKey: &policy.SimpleKasPublicKey{ + Algorithm: policy.Algorithm_ALGORITHM_RSA_2048, + Kid: "default-kid", + Pem: mockRSAPublicKey1, + }, + } + splitter := NewXORSplitter(WithDefaultKAS(defaultKAS)) + + dek := make([]byte, 32) + _, err := rand.Read(dek) + require.NoError(t, err) + + // Create an attribute with a fully-embedded grant for the same KAS URL + // but with a different KID. + attr := createMockValue("https://test.com/attr/level/value/secret", kasUs, "grant-kid", policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF) + + result, err := splitter.GenerateSplits(t.Context(), []*policy.Value{attr}, dek) + require.NoError(t, err) + require.NotNil(t, result) + + require.Contains(t, result.KASPublicKeys, kasUs) + pubKey := result.KASPublicKeys[kasUs] + assert.Equal(t, "grant-kid", pubKey.KID, "grant's key should not be overwritten by default KAS") +} diff --git a/sdk/experimental/tdf/options.go b/sdk/experimental/tdf/options.go index e2c7e70d31..4ea2301129 100644 --- a/sdk/experimental/tdf/options.go +++ b/sdk/experimental/tdf/options.go @@ -42,14 +42,14 @@ type BaseConfig struct{} // WriterConfig contains configuration options for TDF Writer creation. // // The configuration controls cryptographic algorithms and processing behavior: -// - integrityAlgorithm: Algorithm for root integrity signature calculation +// - rootIntegrityAlgorithm: Algorithm for root integrity signature calculation // - segmentIntegrityAlgorithm: Algorithm for individual segment hash calculation // // These can be set independently to optimize for different security/performance requirements. type WriterConfig struct { BaseConfig - // integrityAlgorithm specifies the algorithm for root integrity verification - integrityAlgorithm IntegrityAlgorithm + // rootIntegrityAlgorithm specifies the algorithm for root integrity verification + rootIntegrityAlgorithm IntegrityAlgorithm // segmentIntegrityAlgorithm specifies the algorithm for segment-level integrity segmentIntegrityAlgorithm IntegrityAlgorithm @@ -95,7 +95,7 @@ type Option[T any] func(T) // writer, err := NewWriter(ctx, WithIntegrityAlgorithm(GMAC)) func WithIntegrityAlgorithm(algo IntegrityAlgorithm) Option[*WriterConfig] { return func(c *WriterConfig) { - c.integrityAlgorithm = algo + c.rootIntegrityAlgorithm = algo } } diff --git a/sdk/experimental/tdf/writer.go b/sdk/experimental/tdf/writer.go index 02d2f8af53..d5c072e7cd 100644 --- a/sdk/experimental/tdf/writer.go +++ b/sdk/experimental/tdf/writer.go @@ -156,8 +156,8 @@ type Writer struct { func NewWriter(_ context.Context, opts ...Option[*WriterConfig]) (*Writer, error) { // Initialize Config config := &WriterConfig{ - integrityAlgorithm: HS256, - segmentIntegrityAlgorithm: HS256, + rootIntegrityAlgorithm: HS256, + segmentIntegrityAlgorithm: GMAC, } for _, opt := range opts { @@ -264,7 +264,11 @@ func (w *Writer) WriteSegment(ctx context.Context, index int, data []byte) (*Seg if err != nil { return nil, err } - segmentSig, err := calculateSignature(segmentCipher, w.dek, w.segmentIntegrityAlgorithm, false) // Don't ever hex encode new tdf's + // Hash must cover nonce + cipher to match the standard SDK reader's verification. + // The standard SDK's Encrypt() returns nonce prepended to cipher and hashes that; + // EncryptInPlace() returns them separately, so we must concatenate for hashing. + segmentData := append(nonce, segmentCipher...) //nolint:gocritic // nonce cap == len, so always allocates + segmentSig, err := calculateSignature(segmentData, w.dek, w.segmentIntegrityAlgorithm, false) if err != nil { return nil, err } @@ -562,13 +566,18 @@ func (w *Writer) getManifest(ctx context.Context, cfg *WriterFinalizeConfig) (*M return nil, 0, 0, errors.New("empty segment hash") } - rootSignature, err := calculateSignature(aggregateHash.Bytes(), w.dek, w.integrityAlgorithm, false) - if err != nil { - return nil, 0, 0, err - } - encryptInfo.RootSignature = RootSignature{ - Algorithm: w.integrityAlgorithm.String(), - Signature: string(ocrypto.Base64Encode([]byte(rootSignature))), + // Only compute root signature when segments have been written; stub + // manifests returned before any WriteSegment call leave the root + // signature empty. + if aggregateHash.Len() > 0 { + rootSignature, err := calculateSignature(aggregateHash.Bytes(), w.dek, w.rootIntegrityAlgorithm, false) + if err != nil { + return nil, 0, 0, err + } + encryptInfo.RootSignature = RootSignature{ + Algorithm: w.rootIntegrityAlgorithm.String(), + Signature: string(ocrypto.Base64Encode([]byte(rootSignature))), + } } keyAccessList, err := buildKeyAccessObjects(result, policyBytes, cfg.encryptedMetadata) diff --git a/sdk/experimental/tdf/writer_test.go b/sdk/experimental/tdf/writer_test.go index 296c582041..ff4b145b0c 100644 --- a/sdk/experimental/tdf/writer_test.go +++ b/sdk/experimental/tdf/writer_test.go @@ -4,8 +4,12 @@ package tdf import ( "bytes" + "crypto/hmac" "crypto/rand" + "crypto/sha256" + "encoding/base64" "encoding/json" + "io" "os" "path/filepath" "runtime" @@ -14,6 +18,7 @@ import ( "github.com/opentdf/platform/lib/ocrypto" "github.com/opentdf/platform/protocol/go/policy" + "github.com/opentdf/platform/sdk/internal/zipstream" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xeipuuv/gojsonschema" @@ -58,6 +63,8 @@ func TestWriterEndToEnd(t *testing.T) { {"GetManifestIncludesInitialPolicy", testGetManifestIncludesInitialPolicy}, {"SparseIndicesInOrder", testSparseIndicesInOrder}, {"SparseIndicesOutOfOrder", testSparseIndicesOutOfOrder}, + {"SegmentHashCoversNonceAndCipher", testSegmentHashCoversNonceAndCipher}, + {"FinalizeWithURIOnlyGrant", testFinalizeWithURIOnlyGrant}, } for _, tc := range testCases { @@ -104,9 +111,9 @@ func testGetManifestIncludesInitialPolicy(t *testing.T) { } assert.True(t, found, "provisional policy should include initial attribute FQN") - // Pre-finalize manifest should include kaos based on initial attributes, and estimated root signature + // Pre-finalize manifest should include kaos based on initial attributes. + // Root signature is empty when no segments have been written (GMAC requires data). assert.Len(t, m.KeyAccessObjs, 1) - assert.NotEmpty(t, m.Signature) } // Sparse indices end-to-end: write 0,1,2,5000,5001,5002 and verify manifest and totals. @@ -188,6 +195,95 @@ func testSparseIndicesOutOfOrder(t *testing.T) { assert.Equal(t, int64(expectedPlain), fin.TotalSize) } +// testSegmentHashCoversNonceAndCipher is a regression test ensuring that the +// HS256 segment hash covers nonce+ciphertext, not ciphertext alone. +// +// The standard SDK's Encrypt() returns nonce prepended to ciphertext and +// hashes that combined blob; the experimental SDK's EncryptInPlace() returns +// them separately, so the writer must concatenate before hashing. +// +// Only HS256 is tested because GMAC extracts the last 16 bytes of data as +// the tag — stripping the nonce prefix doesn't change the tail, so GMAC is +// structurally unable to detect a nonce-exclusion regression. +func testSegmentHashCoversNonceAndCipher(t *testing.T) { + ctx := t.Context() + + writer, err := NewWriter(ctx, WithSegmentIntegrityAlgorithm(HS256)) + require.NoError(t, err) + + testData := []byte("segment hash regression test payload") + result, err := writer.WriteSegment(ctx, 0, testData) + require.NoError(t, err) + + // Read all bytes from the TDFData reader to get the full segment output. + allBytes, err := io.ReadAll(result.TDFData) + require.NoError(t, err) + + // The last EncryptedSize bytes are the encrypted segment (nonce + cipher). + // Everything before that is the ZIP local file header. + encryptedData := allBytes[len(allBytes)-int(result.EncryptedSize):] + + // Positive assertion: independently compute HMAC-SHA256 over nonce+cipher + // using crypto/hmac directly (not the production calculateSignature path) + // and verify it matches the stored hash. + mac := hmac.New(sha256.New, writer.dek) + mac.Write(encryptedData) + expectedHash := base64.StdEncoding.EncodeToString(mac.Sum(nil)) + assert.Equal(t, expectedHash, result.Hash, "hash should equal independent HMAC-SHA256 over nonce+ciphertext") + + // Negative / regression assertion: independently compute HMAC-SHA256 over + // cipher-only (stripping the 12-byte GCM nonce). If someone reverts the + // fix so only cipher is hashed, the stored hash would match this value. + cipherOnly := encryptedData[ocrypto.GcmStandardNonceSize:] + wrongMac := hmac.New(sha256.New, writer.dek) + wrongMac.Write(cipherOnly) + wrongHash := base64.StdEncoding.EncodeToString(wrongMac.Sum(nil)) + assert.NotEqual(t, wrongHash, result.Hash, "hash must NOT match cipher-only (nonce must be included)") +} + +// testFinalizeWithURIOnlyGrant is an end-to-end regression test ensuring +// that Finalize succeeds when attribute grants reference a KAS URL without +// embedding the public key (URI-only legacy grants). The default KAS must +// supply the missing key information. Without the merge fix in +// GenerateSplits, key wrapping fails with "no valid key access objects". +func testFinalizeWithURIOnlyGrant(t *testing.T) { + ctx := t.Context() + + defaultKAS := &policy.SimpleKasKey{ + KasUri: testKAS1, + PublicKey: &policy.SimpleKasPublicKey{ + Algorithm: policy.Algorithm_ALGORITHM_RSA_2048, + Kid: "default-kid", + Pem: mockRSAPublicKey1, + }, + } + + writer, err := NewWriter(ctx, WithDefaultKASForWriter(defaultKAS)) + require.NoError(t, err) + + _, err = writer.WriteSegment(ctx, 0, []byte("uri-only grant test")) + require.NoError(t, err) + + // Create attribute with a URI-only grant (no KasKeys / no embedded public key). + uriOnlyAttr := createTestAttributeWithRule( + "https://example.com/attr/Level/value/Secret", + "", "", // no KAS URL → no grants added by helper + policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ALL_OF, + ) + uriOnlyAttr.Grants = []*policy.KeyAccessServer{ + {Uri: testKAS1}, // URI-only, no KasKeys + } + + fin, err := writer.Finalize(ctx, WithAttributeValues([]*policy.Value{uriOnlyAttr})) + require.NoError(t, err, "Finalize must succeed when default KAS fills in missing key for URI-only grant") + require.NotNil(t, fin.Manifest) + + // Verify the key access object references the right KAS + require.GreaterOrEqual(t, len(fin.Manifest.KeyAccessObjs), 1) + assert.Equal(t, testKAS1, fin.Manifest.KeyAccessObjs[0].KasURL) + assert.NotEmpty(t, fin.Manifest.KeyAccessObjs[0].WrappedKey) +} + // testInitialAttributesOnWriter verifies that attributes/KAS supplied at // NewWriter are used by Finalize when not overridden, and that Finalize // overrides take precedence. @@ -996,6 +1092,196 @@ func BenchmarkTDFCreation(b *testing.B) { }) } +// TestCrossDecryptWithSharedDEK verifies that the experimental writer's +// encryption format is compatible with the production SDK by injecting a +// shared DEK into the experimental writer and cross-validating with the +// same crypto primitives the production reader uses: +// +// - ocrypto.AesGcm.Encrypt() (production encrypt: returns nonce||ciphertext) +// - ocrypto.AesGcm.Decrypt() (production decrypt: expects nonce||ciphertext) +// - HMAC-SHA256(dek, nonce||ciphertext) (production segment hash verification) +// +// The test also assembles a complete TDF ZIP from the experimental writer +// and parses it with zipstream.TDFReader (the same reader the production +// SDK uses internally) to verify structural compatibility. +func TestCrossDecryptWithSharedDEK(t *testing.T) { + ctx := t.Context() + + sharedDEK, err := ocrypto.RandomBytes(kKeySize) + require.NoError(t, err) + + t.Run("SingleSegment", func(t *testing.T) { + original := []byte("Cross-SDK format compatibility: single segment") + + sharedCipher, err := ocrypto.NewAESGcm(sharedDEK) + require.NoError(t, err) + + // --- Experimental writer with injected DEK --- + writer, err := NewWriter(ctx, WithSegmentIntegrityAlgorithm(HS256)) + require.NoError(t, err) + writer.dek = sharedDEK + writer.block, err = ocrypto.NewAESGcm(sharedDEK) + require.NoError(t, err) + + expInput := append([]byte(nil), original...) + expResult, err := writer.WriteSegment(ctx, 0, expInput) + require.NoError(t, err) + + allBytes, err := io.ReadAll(expResult.TDFData) + require.NoError(t, err) + expEncrypted := allBytes[len(allBytes)-int(expResult.EncryptedSize):] + + // --- Production-style encrypt with the same DEK --- + prodEncrypted, err := sharedCipher.Encrypt(original) + require.NoError(t, err) + + // --- Cross-decrypt: production Decrypt() on experimental output --- + decryptedFromExp, err := sharedCipher.Decrypt(expEncrypted) + require.NoError(t, err, "production Decrypt must handle experimental output") + assert.Equal(t, decryptedFromExp, original) + + // --- Cross-decrypt: Decrypt() on production output --- + decryptedFromProd, err := sharedCipher.Decrypt(prodEncrypted) + require.NoError(t, err) + assert.Equal(t, original, decryptedFromProd) + + // --- Hash cross-verification --- + // The production reader computes HMAC-SHA256(payloadKey, encryptedSegment) + // and compares it against the manifest segment hash. Verify the + // experimental writer's stored hash matches this computation. + mac := hmac.New(sha256.New, sharedDEK) + mac.Write(expEncrypted) + independentHash := base64.StdEncoding.EncodeToString(mac.Sum(nil)) + assert.Equal(t, expResult.Hash, independentHash, + "experimental hash must equal production-style HMAC-SHA256") + + // Verify production-encrypted data also hashes correctly + prodMac := hmac.New(sha256.New, sharedDEK) + prodMac.Write(prodEncrypted) + prodHash := base64.StdEncoding.EncodeToString(prodMac.Sum(nil)) + assert.NotEmpty(t, prodHash) + // Both hashes are valid HMACs but differ because nonces are random + assert.NotEqual(t, independentHash, prodHash) + }) + + t.Run("MultiSegment", func(t *testing.T) { + sharedCipher, err := ocrypto.NewAESGcm(sharedDEK) + require.NoError(t, err) + + writer, err := NewWriter(ctx, WithSegmentIntegrityAlgorithm(HS256)) + require.NoError(t, err) + writer.dek = sharedDEK + writer.block, err = ocrypto.NewAESGcm(sharedDEK) + require.NoError(t, err) + + segments := [][]byte{ + []byte("segment zero"), + []byte("segment one with longer content for variety"), + []byte("s2"), + } + + for i, original := range segments { + input := append([]byte(nil), original...) + result, err := writer.WriteSegment(ctx, i, input) + require.NoError(t, err) + + raw, err := io.ReadAll(result.TDFData) + require.NoError(t, err) + encrypted := raw[len(raw)-int(result.EncryptedSize):] + + // Cross-decrypt each segment with production-style Decrypt + decrypted, err := sharedCipher.Decrypt(encrypted) + require.NoError(t, err, "segment %d cross-decrypt", i) + assert.Equal(t, original, decrypted, "segment %d plaintext", i) + + // Verify hash matches independent HMAC + mac := hmac.New(sha256.New, sharedDEK) + mac.Write(encrypted) + assert.Equal(t, + base64.StdEncoding.EncodeToString(mac.Sum(nil)), + result.Hash, "segment %d hash", i) + } + }) + + t.Run("FullTDFAssembly", func(t *testing.T) { + // Assemble a complete TDF ZIP from the experimental writer and + // parse it with the same zipstream.TDFReader the production SDK uses. + writer, err := NewWriter(ctx, WithSegmentIntegrityAlgorithm(HS256)) + require.NoError(t, err) + writer.dek = sharedDEK + writer.block, err = ocrypto.NewAESGcm(sharedDEK) + require.NoError(t, err) + + plainSegments := [][]byte{ + []byte("first segment payload"), + []byte("second segment payload - a bit longer"), + } + sharedCipher, err := ocrypto.NewAESGcm(sharedDEK) + require.NoError(t, err) + + // Collect segment TDFData (ZIP local headers + encrypted data) + var tdfBuf bytes.Buffer + for i, original := range plainSegments { + input := append([]byte(nil), original...) + result, err := writer.WriteSegment(ctx, i, input) + require.NoError(t, err) + _, err = io.Copy(&tdfBuf, result.TDFData) + require.NoError(t, err) + } + + // Finalize (adds central directory + manifest entry) + attrs := []*policy.Value{ + createTestAttribute("https://example.com/attr/Cross/value/Test", testKAS1, "kid1"), + } + fin, err := writer.Finalize(ctx, WithAttributeValues(attrs)) + require.NoError(t, err) + tdfBuf.Write(fin.Data) + + // Parse with zipstream.TDFReader — the production SDK's ZIP parser + tdfReader, err := zipstream.NewTDFReader(bytes.NewReader(tdfBuf.Bytes())) + require.NoError(t, err, "production TDFReader must parse experimental TDF ZIP") + + // Verify manifest is valid JSON with expected fields + manifestJSON, err := tdfReader.Manifest() + require.NoError(t, err) + assert.Contains(t, manifestJSON, `"algorithm":"AES-256-GCM"`) + assert.Contains(t, manifestJSON, `"isStreamable":true`) + + var manifest Manifest + require.NoError(t, json.Unmarshal([]byte(manifestJSON), &manifest)) + require.Len(t, manifest.Segments, len(plainSegments)) + assert.Equal(t, "HS256", manifest.SegmentHashAlgorithm) + assert.NotEmpty(t, manifest.Signature, "root signature must be present") + + // Verify payload is readable and each segment decrypts correctly + payloadSize, err := tdfReader.PayloadSize() + require.NoError(t, err) + + var offset int64 + for i, seg := range manifest.Segments { + require.LessOrEqual(t, offset+seg.EncryptedSize, payloadSize, + "segment %d exceeds payload bounds", i) + + readBuf, err := tdfReader.ReadPayload(offset, seg.EncryptedSize) + require.NoError(t, err, "segment %d ReadPayload", i) + + // This is exactly what the production reader does: + // 1. Verify segment hash + mac := hmac.New(sha256.New, sharedDEK) + mac.Write(readBuf) + computedHash := base64.StdEncoding.EncodeToString(mac.Sum(nil)) + assert.Equal(t, seg.Hash, computedHash, "segment %d hash verification", i) + + // 2. Decrypt + decrypted, err := sharedCipher.Decrypt(readBuf) + require.NoError(t, err, "segment %d decrypt", i) + assert.Equal(t, plainSegments[i], decrypted, "segment %d plaintext", i) + + offset += seg.EncryptedSize + } + }) +} + // testGetManifestBeforeAndAfterFinalize verifies GetManifest returns a stub // before finalization and the final manifest after finalization. func testGetManifestBeforeAndAfterFinalize(t *testing.T) {