diff --git a/Directory.Packages.props b/Directory.Packages.props index ee50dd7..4c447fc 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -9,7 +9,7 @@ - + diff --git a/src/DurableTask.SqlServer/Scripts/logic.sql b/src/DurableTask.SqlServer/Scripts/logic.sql index 708d74b..6fbc422 100644 --- a/src/DurableTask.SqlServer/Scripts/logic.sql +++ b/src/DurableTask.SqlServer/Scripts/logic.sql @@ -76,7 +76,8 @@ IF TYPE_ID(N'__SchemaNamePlaceholder__.TaskEvents') IS NULL [PayloadText] varchar(max) NULL, [PayloadID] uniqueidentifier NULL, [Version] varchar(100) NULL, - [TraceContext] varchar(800) NULL + [TraceContext] varchar(800) NULL, + [Tags] varchar(8000) NULL ) GO @@ -1031,7 +1032,8 @@ BEGIN [LockExpiration], [PayloadID], [Version], - [TraceContext] + [TraceContext], + [Tags] ) OUTPUT INSERTED.[SequenceNumber], @@ -1047,7 +1049,8 @@ BEGIN [LockExpiration], [PayloadID], [Version], - [TraceContext] + [TraceContext], + [Tags] FROM @NewTaskEvents COMMIT TRANSACTION @@ -1310,7 +1313,8 @@ BEGIN P.[InstanceID] = N.[InstanceID] AND P.[PayloadID] = N.[PayloadID]) AS [PayloadText], DATEDIFF(SECOND, [Timestamp], @now) AS [WaitTime], - [TraceContext] + [TraceContext], + [Tags] FROM NewTasks N WHERE [TaskHub] = @TaskHub AND [SequenceNumber] = @SequenceNumber diff --git a/src/DurableTask.SqlServer/Scripts/schema-1.3.0.sql b/src/DurableTask.SqlServer/Scripts/schema-1.6.0.sql similarity index 54% rename from src/DurableTask.SqlServer/Scripts/schema-1.3.0.sql rename to src/DurableTask.SqlServer/Scripts/schema-1.6.0.sql index d5596dc..2d8f7d2 100644 --- a/src/DurableTask.SqlServer/Scripts/schema-1.3.0.sql +++ b/src/DurableTask.SqlServer/Scripts/schema-1.6.0.sql @@ -11,14 +11,23 @@ IF NOT EXISTS (SELECT 1 FROM sys.columns WHERE object_id = OBJECT_ID('__SchemaNamePlaceholder__.Instances') AND name = 'Tags') ALTER TABLE __SchemaNamePlaceholder__.Instances ADD [Tags] varchar(8000) NULL --- Add a Tags column to the OrchestrationEvents table type so that merged tags --- flow through sub-orchestration creation events. To change a type we must first --- drop all stored procedures that reference it, then drop the type itself. --- The type and sprocs will be recreated by logic.sql which executes afterwards. +-- Add a new Tags column to the NewTasks table so that orchestration tags +-- propagate to activity task workers via OrchestrationExecutionContext. +IF NOT EXISTS (SELECT 1 FROM sys.columns WHERE object_id = OBJECT_ID('__SchemaNamePlaceholder__.NewTasks') AND name = 'Tags') + ALTER TABLE __SchemaNamePlaceholder__.NewTasks ADD [Tags] varchar(8000) NULL + +-- Add Tags columns to the OrchestrationEvents and TaskEvents table types. +-- To change a type we must first drop all stored procedures that reference it, +-- then drop the type itself. The types and sprocs will be recreated by logic.sql +-- which executes afterwards. IF OBJECT_ID('__SchemaNamePlaceholder__._AddOrchestrationEvents') IS NOT NULL DROP PROCEDURE __SchemaNamePlaceholder__._AddOrchestrationEvents IF OBJECT_ID('__SchemaNamePlaceholder__._CheckpointOrchestration') IS NOT NULL DROP PROCEDURE __SchemaNamePlaceholder__._CheckpointOrchestration +IF OBJECT_ID('__SchemaNamePlaceholder__._CompleteTasks') IS NOT NULL + DROP PROCEDURE __SchemaNamePlaceholder__._CompleteTasks IF TYPE_ID('__SchemaNamePlaceholder__.OrchestrationEvents') IS NOT NULL DROP TYPE __SchemaNamePlaceholder__.OrchestrationEvents +IF TYPE_ID('__SchemaNamePlaceholder__.TaskEvents') IS NOT NULL + DROP TYPE __SchemaNamePlaceholder__.TaskEvents diff --git a/src/DurableTask.SqlServer/SqlOrchestrationService.cs b/src/DurableTask.SqlServer/SqlOrchestrationService.cs index 04286fb..75f4ebd 100644 --- a/src/DurableTask.SqlServer/SqlOrchestrationService.cs +++ b/src/DurableTask.SqlServer/SqlOrchestrationService.cs @@ -379,7 +379,8 @@ public override async Task CompleteTaskOrchestrationWorkItemAsync( "@NewTaskEvents", outboundMessages, currentWorkItem.EventPayloadMappings, - this.settings.SchemaName); + this.settings.SchemaName, + this.traceHelper); command.Parameters.AddHistoryEventsParameter( "@NewHistoryEvents", @@ -485,7 +486,7 @@ public override async Task CompleteTaskActivityWorkItemAsync(TaskActivityWorkIte using SqlCommand command = this.GetSprocCommand(connection, $"{this.settings.SchemaName}._CompleteTasks"); command.Parameters.AddMessageIdParameter("@CompletedTasks", workItem.TaskMessage, this.settings.SchemaName); - command.Parameters.AddTaskEventsParameter("@Results", responseMessage, this.settings.SchemaName); + command.Parameters.AddTaskEventsParameter("@Results", responseMessage, this.settings.SchemaName, this.traceHelper); OrchestrationInstance instance = workItem.TaskMessage.OrchestrationInstance; try diff --git a/src/DurableTask.SqlServer/SqlTypes/TaskEventSqlType.cs b/src/DurableTask.SqlServer/SqlTypes/TaskEventSqlType.cs index 37b8858..275b41a 100644 --- a/src/DurableTask.SqlServer/SqlTypes/TaskEventSqlType.cs +++ b/src/DurableTask.SqlServer/SqlTypes/TaskEventSqlType.cs @@ -9,6 +9,7 @@ namespace DurableTask.SqlServer.SqlTypes using System.Data.SqlTypes; using System.Linq; using DurableTask.Core; + using DurableTask.SqlServer.Logging; using Microsoft.Data.SqlClient; using Microsoft.Data.SqlClient.Server; @@ -30,6 +31,7 @@ static class TaskEventSqlType new SqlMetaData("PayloadID", SqlDbType.UniqueIdentifier), new SqlMetaData("Version", SqlDbType.VarChar, 100), new SqlMetaData("TraceContext", SqlDbType.VarChar, 800), + new SqlMetaData("Tags", SqlDbType.VarChar, 8000), }; static class ColumnOrdinals @@ -48,6 +50,7 @@ static class ColumnOrdinals public const int PayloadId = 10; public const int Version = 11; public const int TraceContext = 12; + public const int Tags = 13; } public static SqlParameter AddTaskEventsParameter( @@ -55,11 +58,12 @@ public static SqlParameter AddTaskEventsParameter( string paramName, IList outboundMessages, EventPayloadMap eventPayloadMap, - string schemaName) + string schemaName, + LogHelper logHelper) { SqlParameter param = commandParameters.Add(paramName, SqlDbType.Structured); param.TypeName = $"{schemaName}.TaskEvents"; - param.Value = ToTaskMessagesParameter(outboundMessages, eventPayloadMap); + param.Value = ToTaskMessagesParameter(outboundMessages, eventPayloadMap, logHelper); return param; } @@ -67,17 +71,19 @@ public static SqlParameter AddTaskEventsParameter( this SqlParameterCollection commandParameters, string paramName, TaskMessage message, - string schemaName) + string schemaName, + LogHelper logHelper) { SqlParameter param = commandParameters.Add(paramName, SqlDbType.Structured); param.TypeName = $"{schemaName}.TaskEvents"; - param.Value = ToTaskMessageParameter(message); + param.Value = ToTaskMessageParameter(message, logHelper); return param; } static IEnumerable? ToTaskMessagesParameter( this IEnumerable messages, - EventPayloadMap? eventPayloadMap) + EventPayloadMap? eventPayloadMap, + LogHelper logHelper) { if (!messages.Any()) { @@ -92,21 +98,22 @@ IEnumerable GetTaskEventRecords() var record = new SqlDataRecord(TaskEventSchema); foreach (TaskMessage msg in messages) { - yield return PopulateTaskMessageRecord(msg, record, eventPayloadMap); + yield return PopulateTaskMessageRecord(msg, record, eventPayloadMap, logHelper); } } } - static IEnumerable ToTaskMessageParameter(TaskMessage msg) + static IEnumerable ToTaskMessageParameter(TaskMessage msg, LogHelper logHelper) { var record = new SqlDataRecord(TaskEventSchema); - yield return PopulateTaskMessageRecord(msg, record, eventPayloadMap: null); + yield return PopulateTaskMessageRecord(msg, record, eventPayloadMap: null, logHelper); } static SqlDataRecord PopulateTaskMessageRecord( TaskMessage msg, SqlDataRecord record, - EventPayloadMap? eventPayloadMap) + EventPayloadMap? eventPayloadMap, + LogHelper logHelper) { record.SetSqlString(ColumnOrdinals.InstanceID, msg.OrchestrationInstance.InstanceId); record.SetSqlString(ColumnOrdinals.ExecutionID, msg.OrchestrationInstance.ExecutionId); @@ -140,6 +147,7 @@ static SqlDataRecord PopulateTaskMessageRecord( record.SetSqlString(ColumnOrdinals.Version, SqlUtils.GetVersion(msg.Event)); record.SetSqlString(ColumnOrdinals.TraceContext, SqlUtils.GetTraceContext(msg.Event)); + record.SetSqlString(ColumnOrdinals.Tags, SqlUtils.GetMergedTaskTagsJson(msg, logHelper)); return record; } } diff --git a/src/DurableTask.SqlServer/SqlUtils.cs b/src/DurableTask.SqlServer/SqlUtils.cs index 8d6de4d..84983f4 100644 --- a/src/DurableTask.SqlServer/SqlUtils.cs +++ b/src/DurableTask.SqlServer/SqlUtils.cs @@ -185,6 +185,7 @@ public static HistoryEvent GetHistoryEvent(this DbDataReader reader, bool isOrch Input = GetPayloadText(reader), Name = GetName(reader), Version = GetVersion(reader), + Tags = GetTags(reader), ParentTraceContext = GetTraceContext(reader), }; break; @@ -515,21 +516,56 @@ internal static SqlString GetTagsJson(HistoryEvent e, LogHelper logHelper) { if (e is ExecutionStartedEvent startedEvent && startedEvent.Tags != null && startedEvent.Tags.Count > 0) { - string json = DTUtils.SerializeToJson(startedEvent.Tags); - int utf8Bytes = Encoding.UTF8.GetByteCount(json); - if (utf8Bytes > MaxTagsPayloadSize) + return SerializeTagsJson(startedEvent.Tags, logHelper, (e as ExecutionStartedEvent)?.ParentInstance?.OrchestrationInstance?.InstanceId); + } + + return SqlString.Null; + } + + internal static SqlString GetMergedTaskTagsJson(TaskMessage msg, LogHelper logHelper) + { + IDictionary? orchestrationTags = msg.OrchestrationExecutionContext?.OrchestrationTags; + IDictionary? activityTags = (msg.Event as TaskScheduledEvent)?.Tags; + + bool hasOrchTags = orchestrationTags != null && orchestrationTags.Count > 0; + bool hasActTags = activityTags != null && activityTags.Count > 0; + + if (!hasOrchTags && !hasActTags) + { + return SqlString.Null; + } + + // Merge flat: orchestration tags as base, activity tags override on key collision. + if (hasOrchTags && hasActTags) + { + var merged = new Dictionary(orchestrationTags!); + foreach (var kvp in activityTags!) { - logHelper.GenericWarning( - $"Dropping oversized tags ({utf8Bytes} bytes, max {MaxTagsPayloadSize}) for sub-orchestration. " + - $"The merged parent+child tags exceed the allowed limit and will not be persisted.", - instanceId: (e as ExecutionStartedEvent)?.ParentInstance?.OrchestrationInstance?.InstanceId); - return SqlString.Null; + merged[kvp.Key] = kvp.Value; } + return SerializeTagsJson(merged, logHelper, msg.OrchestrationInstance?.InstanceId); + } - return json; + return SerializeTagsJson( + hasOrchTags ? orchestrationTags! : activityTags!, + logHelper, + msg.OrchestrationInstance?.InstanceId); + } + + static SqlString SerializeTagsJson(IDictionary tags, LogHelper logHelper, string? instanceId) + { + string json = DTUtils.SerializeToJson(tags); + int utf8Bytes = Encoding.UTF8.GetByteCount(json); + if (utf8Bytes > MaxTagsPayloadSize) + { + logHelper.GenericWarning( + $"Dropping oversized tags ({utf8Bytes} bytes, max {MaxTagsPayloadSize}). " + + $"The tags exceed the allowed limit and will not be persisted.", + instanceId: instanceId); + return SqlString.Null; } - return SqlString.Null; + return json; } internal static void AddTagsParameter( diff --git a/src/common.props b/src/common.props index a952a5a..7f63f5c 100644 --- a/src/common.props +++ b/src/common.props @@ -16,8 +16,8 @@ 1 - 5 - 4 + 6 + 0 $(MajorVersion).$(MinorVersion).$(PatchVersion) $(MajorVersion).$(MinorVersion).0.0 diff --git a/test/DurableTask.SqlServer.Tests/Integration/DatabaseManagement.cs b/test/DurableTask.SqlServer.Tests/Integration/DatabaseManagement.cs index d4c1d02..75f95b3 100644 --- a/test/DurableTask.SqlServer.Tests/Integration/DatabaseManagement.cs +++ b/test/DurableTask.SqlServer.Tests/Integration/DatabaseManagement.cs @@ -46,7 +46,7 @@ public void CanEnumerateEmbeddedSqlScripts() "drop-schema.sql", "schema-1.0.0.sql", "schema-1.2.0.sql", - "schema-1.3.0.sql", + "schema-1.6.0.sql", "logic.sql", "permissions.sql", }; @@ -99,7 +99,7 @@ public async Task CanCreateAndDropSchema(bool isDatabaseMissing) LogAssert.ExecutedSqlScript("drop-schema.sql"), LogAssert.ExecutedSqlScript("schema-1.0.0.sql"), LogAssert.ExecutedSqlScript("schema-1.2.0.sql"), - LogAssert.ExecutedSqlScript("schema-1.3.0.sql"), + LogAssert.ExecutedSqlScript("schema-1.6.0.sql"), LogAssert.ExecutedSqlScript("logic.sql"), LogAssert.ExecutedSqlScript("permissions.sql"), LogAssert.SprocCompleted("dt._UpdateVersion")) @@ -160,7 +160,7 @@ public async Task CanCreateAndDropSchemaWithCustomSchemaName(bool isDatabaseMiss LogAssert.ExecutedSqlScript("drop-schema.sql"), LogAssert.ExecutedSqlScript("schema-1.0.0.sql"), LogAssert.ExecutedSqlScript("schema-1.2.0.sql"), - LogAssert.ExecutedSqlScript("schema-1.3.0.sql"), + LogAssert.ExecutedSqlScript("schema-1.6.0.sql"), LogAssert.ExecutedSqlScript("logic.sql"), LogAssert.ExecutedSqlScript("permissions.sql"), LogAssert.SprocCompleted($"{schemaName}._UpdateVersion")) @@ -223,7 +223,7 @@ public async Task CanCreateAndDropMultipleSchemas(bool isDatabaseMissing) LogAssert.ExecutedSqlScript("drop-schema.sql"), LogAssert.ExecutedSqlScript("schema-1.0.0.sql"), LogAssert.ExecutedSqlScript("schema-1.2.0.sql"), - LogAssert.ExecutedSqlScript("schema-1.3.0.sql"), + LogAssert.ExecutedSqlScript("schema-1.6.0.sql"), LogAssert.ExecutedSqlScript("logic.sql"), LogAssert.ExecutedSqlScript("permissions.sql"), LogAssert.SprocCompleted($"{firstTestSchemaName}._UpdateVersion")) @@ -234,7 +234,7 @@ public async Task CanCreateAndDropMultipleSchemas(bool isDatabaseMissing) LogAssert.ExecutedSqlScript("drop-schema.sql"), LogAssert.ExecutedSqlScript("schema-1.0.0.sql"), LogAssert.ExecutedSqlScript("schema-1.2.0.sql"), - LogAssert.ExecutedSqlScript("schema-1.3.0.sql"), + LogAssert.ExecutedSqlScript("schema-1.6.0.sql"), LogAssert.ExecutedSqlScript("logic.sql"), LogAssert.ExecutedSqlScript("permissions.sql"), LogAssert.SprocCompleted($"{secondTestSchemaName}._UpdateVersion")) @@ -318,7 +318,7 @@ public async Task CanCreateIfNotExists(bool isDatabaseMissing) LogAssert.SprocCompleted("dt._GetVersions"), LogAssert.ExecutedSqlScript("schema-1.0.0.sql"), LogAssert.ExecutedSqlScript("schema-1.2.0.sql"), - LogAssert.ExecutedSqlScript("schema-1.3.0.sql"), + LogAssert.ExecutedSqlScript("schema-1.6.0.sql"), LogAssert.ExecutedSqlScript("logic.sql"), LogAssert.ExecutedSqlScript("permissions.sql"), LogAssert.SprocCompleted("dt._UpdateVersion")) @@ -372,7 +372,7 @@ public async Task SchemaCreationIsSerializedAndIdempotent(bool isDatabaseMissing LogAssert.SprocCompleted("dt._GetVersions"), LogAssert.ExecutedSqlScript("schema-1.0.0.sql"), LogAssert.ExecutedSqlScript("schema-1.2.0.sql"), - LogAssert.ExecutedSqlScript("schema-1.3.0.sql"), + LogAssert.ExecutedSqlScript("schema-1.6.0.sql"), LogAssert.ExecutedSqlScript("logic.sql"), LogAssert.ExecutedSqlScript("permissions.sql"), LogAssert.SprocCompleted("dt._UpdateVersion"), @@ -388,6 +388,29 @@ public async Task SchemaCreationIsSerializedAndIdempotent(bool isDatabaseMissing .EndOfLog(); } + /// + /// Verifies that the schema-1.6.0 migration correctly adds the Tags column + /// to the Instances and NewTasks tables. Without the correctly named schema file, + /// existing databases would not be upgraded and the @Tags parameter would be unrecognized. + /// + [Fact] + public async Task SchemaUpgradeAddsTagsColumn() + { + using TestDatabase testDb = this.CreateTestDb(); + IOrchestrationService service = this.CreateServiceWithTestDb(testDb); + + // Create the full schema from scratch + await service.CreateAsync(recreateInstanceStore: true); + + // Verify the Tags column exists on the Instances table + IEnumerable instanceColumns = testDb.GetColumns("Instances"); + Assert.Contains("Tags", instanceColumns); + + // Verify the Tags column exists on the NewTasks table + IEnumerable taskColumns = testDb.GetColumns("NewTasks"); + Assert.Contains("Tags", taskColumns); + } + TestDatabase CreateTestDb(bool initializeDatabase = true) { var testDb = new TestDatabase(this.output); @@ -511,8 +534,8 @@ async Task ValidateDatabaseSchemaAsync(TestDatabase database, string schemaName database.ConnectionString, schemaName); Assert.Equal(1, currentSchemaVersion.Major); - Assert.Equal(5, currentSchemaVersion.Minor); - Assert.Equal(4, currentSchemaVersion.Patch); + Assert.Equal(6, currentSchemaVersion.Minor); + Assert.Equal(0, currentSchemaVersion.Patch); } sealed class TestDatabase : IDisposable @@ -574,6 +597,19 @@ public IEnumerable GetTables(string schemaName = "dt") } } + public IEnumerable GetColumns(string tableName, string schemaName = "dt") + { + this.testDb.Tables.Refresh(); + Table? table = this.testDb.Tables[tableName, schemaName]; + if (table != null) + { + foreach (Column column in table.Columns) + { + yield return column.Name; + } + } + } + public IEnumerable GetSprocs(string schemaName = "dt") { this.testDb.StoredProcedures.Refresh(); diff --git a/test/DurableTask.SqlServer.Tests/Integration/Orchestrations.cs b/test/DurableTask.SqlServer.Tests/Integration/Orchestrations.cs index fc3ff44..e80f8b5 100644 --- a/test/DurableTask.SqlServer.Tests/Integration/Orchestrations.cs +++ b/test/DurableTask.SqlServer.Tests/Integration/Orchestrations.cs @@ -11,6 +11,8 @@ namespace DurableTask.SqlServer.Tests.Integration using System.Threading.Tasks; using DurableTask.Core; using DurableTask.Core.Exceptions; + using DurableTask.Core.History; + using DurableTask.Core.Middleware; using DurableTask.SqlServer.Tests.Logging; using DurableTask.SqlServer.Tests.Utils; using Moq; @@ -769,26 +771,27 @@ public async Task TraceContextFlowCorrectly() foreach (Activity span in exportedItems) { this.outputHelper.WriteLine( - $"{span.Id}: Name={span.DisplayName}, Start={span.StartTimeUtc:o}, Duration={span.Duration}"); + $"{span.Id}: Name={span.DisplayName}, Kind={span.Kind}, Start={span.StartTimeUtc:o}, Duration={span.Duration}, TraceState={span.TraceStateString ?? "(null)"}"); } Assert.True(exportedItems.Count >= 4); // Validate the orchestration trace activity/span. Specifically, the IDs and time range. // We need to verify these because we use custom logic to store and retrieve this data (not serialization). + // Filter by Server kind to get the actual execution span, not the client-side scheduling span. Activity orchestratorSpan = exportedItems.LastOrDefault( - span => span.OperationName == $"orchestration:{orchestrationName}"); + span => span.OperationName == $"orchestration:{orchestrationName}" && span.Kind == ActivityKind.Server); Assert.NotNull(orchestratorSpan); Assert.Equal(clientSpan.TraceId, orchestratorSpan.TraceId); Assert.NotEqual(clientSpan.SpanId, orchestratorSpan.SpanId); // new span ID - Assert.Equal("TestTraceState", orchestratorSpan.TraceStateString); + Assert.Equal("TestTraceState (modified!)", orchestratorSpan.TraceStateString); Assert.True(orchestratorSpan.StartTimeUtc >= clientSpan.StartTimeUtc); Assert.True(orchestratorSpan.Duration > delay * 2); // two sleeps Assert.True(orchestratorSpan.StartTimeUtc + orchestratorSpan.Duration <= clientSpan.StartTimeUtc + clientSpan.Duration); // Validate the sub-orchestrator span, which should be a sub-set of the parent orchestration span. Activity subOrchestratorSpan = exportedItems.LastOrDefault( - span => span.OperationName == $"orchestration:{subOrchestrationName}"); + span => span.OperationName == $"orchestration:{subOrchestrationName}" && span.Kind == ActivityKind.Server); Assert.NotNull(subOrchestratorSpan); Assert.Equal(clientSpan.TraceId, subOrchestratorSpan.TraceId); Assert.NotEqual(orchestratorSpan.SpanId, subOrchestratorSpan.SpanId); // new span ID @@ -799,7 +802,7 @@ public async Task TraceContextFlowCorrectly() // Validate the activity span, which should be a subset of the sub-orchestration span Activity activitySpan = exportedItems.LastOrDefault( - span => span.OperationName == $"activity:{activityName}"); + span => span.OperationName == $"activity:{activityName}" && span.Kind == ActivityKind.Server); Assert.NotNull(activitySpan); Assert.Equal(clientSpan.TraceId, activitySpan.TraceId); Assert.NotEqual(subOrchestratorSpan.SpanId, activitySpan.SpanId); // new span ID @@ -1275,5 +1278,99 @@ await this.testService.OrchestrationServiceMock.Object.GetManyOrchestrationsAsyn Assert.True(foundTaggedInstance, "Did not find the tagged orchestration instance in query results."); } + + /// + /// Verifies that orchestration tags propagate to the activity middleware context + /// via OrchestrationExecutionContext, surviving the SQL persistence round-trip + /// through the NewTasks table. This test exposes the gap where tags were + /// serialized by TaskOrchestrationDispatcher but never persisted/restored + /// by the MSSQL backend. + /// + [Fact] + public async Task ActivityReceivesOrchestrationTags() + { + var tags = new Dictionary + { + { "env", "test" }, + { "team", "platform" }, + }; + + // Capture tags seen by activity middleware + IDictionary capturedTags = null; + + this.testService.AddActivityDispatcherMiddleware(async (context, next) => + { + var scheduledEvent = context.GetProperty(); + capturedTags = scheduledEvent?.Tags; + await next(); + }); + + TestInstance instance = await this.testService.RunOrchestrationWithTags( + input: "hello", + orchestrationName: "ActivityTagsPropagation", + tags: tags, + implementation: (ctx, input) => ctx.ScheduleTask("Echo", "", input), + activities: new[] + { + ("Echo", TestService.MakeActivity((TaskContext ctx, string input) => input)), + }); + + await instance.WaitForCompletion(expectedOutput: "hello"); + + // Verify the activity middleware received the orchestration's tags + Assert.NotNull(capturedTags); + Assert.Equal(2, capturedTags.Count); + Assert.Equal("test", capturedTags["env"]); + Assert.Equal("platform", capturedTags["team"]); + } + + /// + /// Verifies that per-activity tags (via ScheduleTaskOptions.Tags) are merged flat + /// with orchestration-level tags, and that activity tags override on key collision. + /// Both OrchestrationExecutionContext and TaskScheduledEvent carry the merged result. + /// + [Fact] + public async Task ActivityTagsMergedWithOrchestrationTags() + { + var orchestrationTags = new Dictionary + { + { "env", "prod" }, + { "team", "platform" }, + }; + + // Capture tags seen by activity middleware via TaskScheduledEvent.Tags + IDictionary capturedTags = null; + + this.testService.AddActivityDispatcherMiddleware(async (context, next) => + { + var scheduledEvent = context.GetProperty(); + capturedTags = scheduledEvent?.Tags; + await next(); + }); + + var activityOptions = ScheduleTaskOptions.CreateBuilder() + .AddTag("priority", "high") + .AddTag("env", "staging") // overrides orchestration-level "env" + .Build(); + + TestInstance instance = await this.testService.RunOrchestrationWithTags( + input: "hello", + orchestrationName: "MergedActivityTags", + tags: orchestrationTags, + implementation: (ctx, input) => ctx.ScheduleTask("Echo", "", activityOptions, input), + activities: new[] + { + ("Echo", TestService.MakeActivity((TaskContext ctx, string input) => input)), + }); + + await instance.WaitForCompletion(expectedOutput: "hello"); + + // Verify merged tags: activity "env=staging" overrides orchestration "env=prod" + Assert.NotNull(capturedTags); + Assert.Equal(3, capturedTags.Count); + Assert.Equal("staging", capturedTags["env"]); + Assert.Equal("platform", capturedTags["team"]); + Assert.Equal("high", capturedTags["priority"]); + } } } diff --git a/test/DurableTask.SqlServer.Tests/Utils/TestService.cs b/test/DurableTask.SqlServer.Tests/Utils/TestService.cs index b74907b..51fa8ae 100644 --- a/test/DurableTask.SqlServer.Tests/Utils/TestService.cs +++ b/test/DurableTask.SqlServer.Tests/Utils/TestService.cs @@ -9,6 +9,7 @@ namespace DurableTask.SqlServer.Tests.Utils using System.Linq; using System.Threading.Tasks; using DurableTask.Core; + using DurableTask.Core.Middleware; using DurableTask.SqlServer.Tests.Logging; using Microsoft.Extensions.Logging; using Moq; @@ -94,6 +95,11 @@ public Task GetOrchestrationStateAsync(string instanceId) public Task StartWorkerAsync() => this.worker?.StartAsync() ?? Task.CompletedTask; + public void AddActivityDispatcherMiddleware(Func, Task> middleware) + { + this.worker.AddActivityDispatcherMiddleware(middleware); + } + public Task PurgeAsync(DateTime maximumThreshold, OrchestrationStateTimeRangeFilterType filterType) { return this.client.PurgeOrchestrationInstanceHistoryAsync(