diff --git a/src/DurableTask.SqlServer/Scripts/logic.sql b/src/DurableTask.SqlServer/Scripts/logic.sql index ca7135d..708d74b 100644 --- a/src/DurableTask.SqlServer/Scripts/logic.sql +++ b/src/DurableTask.SqlServer/Scripts/logic.sql @@ -56,7 +56,8 @@ IF TYPE_ID(N'__SchemaNamePlaceholder__.OrchestrationEvents') IS NULL [PayloadID] uniqueidentifier NULL, [ParentInstanceID] varchar(100) NULL, [Version] varchar(100) NULL, - [TraceContext] varchar(800) NULL + [TraceContext] varchar(800) NULL, + [Tags] varchar(8000) NULL ) GO @@ -233,7 +234,8 @@ CREATE OR ALTER PROCEDURE __SchemaNamePlaceholder__.CreateInstance @InputText varchar(MAX) = NULL, @StartTime datetime2 = NULL, @DedupeStatuses varchar(MAX) = 'Pending,Running', - @TraceContext varchar(800) = NULL + @TraceContext varchar(800) = NULL, + @Tags varchar(8000) = NULL AS BEGIN DECLARE @TaskHub varchar(50) = __SchemaNamePlaceholder__.CurrentTaskHub() @@ -302,7 +304,8 @@ BEGIN [ExecutionID], [RuntimeStatus], [InputPayloadID], - [TraceContext]) + [TraceContext], + [Tags]) VALUES ( @Name, @Version, @@ -311,7 +314,8 @@ BEGIN @ExecutionID, @RuntimeStatus, @InputPayloadID, - @TraceContext + @TraceContext, + @Tags ) INSERT INTO NewEvents ( @@ -348,10 +352,12 @@ BEGIN DECLARE @TaskHub varchar(50) = __SchemaNamePlaceholder__.CurrentTaskHub() DECLARE @ParentInstanceID varchar(100) DECLARE @Version varchar(100) + DECLARE @Tags varchar(8000) SELECT @ParentInstanceID = [ParentInstanceID], - @Version = [Version] + @Version = [Version], + @Tags = [Tags] FROM Instances WHERE [InstanceID] = @InstanceID SELECT @@ -370,7 +376,8 @@ BEGIN [PayloadID], @ParentInstanceID as [ParentInstanceID], @Version as [Version], - H.[TraceContext] + H.[TraceContext], + @Tags as [Tags] FROM History H WITH (INDEX (PK_History)) LEFT OUTER JOIN Payloads P ON P.[TaskHub] = @TaskHub AND @@ -635,6 +642,7 @@ BEGIN DECLARE @parentInstanceID varchar(100) DECLARE @version varchar(100) DECLARE @runtimeStatus varchar(30) + DECLARE @tags varchar(8000) DECLARE @TaskHub varchar(50) = __SchemaNamePlaceholder__.CurrentTaskHub() BEGIN TRANSACTION @@ -654,7 +662,8 @@ BEGIN @instanceID = I.[InstanceID], @parentInstanceID = I.[ParentInstanceID], @runtimeStatus = I.[RuntimeStatus], - @version = I.[Version] + @version = I.[Version], + @tags = I.[Tags] FROM Instances I WITH (READPAST) INNER JOIN NewEvents E WITH (READPAST) ON E.[TaskHub] = @TaskHub AND @@ -684,7 +693,8 @@ BEGIN DATEDIFF(SECOND, [Timestamp], @now) AS [WaitTime], @parentInstanceID as [ParentInstanceID], @version as [Version], - N.[TraceContext] + N.[TraceContext], + @tags as [Tags] FROM NewEvents N LEFT OUTER JOIN __SchemaNamePlaceholder__.[Payloads] P ON P.[TaskHub] = @TaskHub AND @@ -724,7 +734,8 @@ BEGIN [PayloadID], @parentInstanceID as [ParentInstanceID], @version as [Version], - H.[TraceContext] + H.[TraceContext], + @tags as [Tags] FROM History H WITH (INDEX (PK_History)) LEFT OUTER JOIN Payloads P ON P.[TaskHub] = @TaskHub AND @@ -907,7 +918,8 @@ BEGIN [Version], [ParentInstanceID], [RuntimeStatus], - [TraceContext]) + [TraceContext], + [Tags]) SELECT DISTINCT @TaskHub, E.[InstanceID], @@ -916,7 +928,8 @@ BEGIN E.[Version], E.[ParentInstanceID], 'Pending', - E.[TraceContext] + E.[TraceContext], + E.[Tags] FROM @NewOrchestrationEvents E WHERE E.[EventType] IN ('ExecutionStarted') AND NOT EXISTS ( @@ -1185,7 +1198,8 @@ BEGIN P.[TaskHub] = @TaskHub AND P.[InstanceID] = I.[InstanceID] AND P.[PayloadID] = I.[OutputPayloadID]) ELSE NULL END AS [OutputText], - I.[TraceContext] + I.[TraceContext], + I.[Tags] FROM Instances I WHERE I.[TaskHub] = @TaskHub AND @@ -1231,7 +1245,8 @@ BEGIN P.[TaskHub] = @TaskHub AND P.[InstanceID] = I.[InstanceID] AND P.[PayloadID] = I.[OutputPayloadID]) ELSE NULL END AS [OutputText], - I.[TraceContext] + I.[TraceContext], + I.[Tags] FROM Instances I WHERE diff --git a/src/DurableTask.SqlServer/Scripts/schema-1.3.0.sql b/src/DurableTask.SqlServer/Scripts/schema-1.3.0.sql new file mode 100644 index 0000000..d5596dc --- /dev/null +++ b/src/DurableTask.SqlServer/Scripts/schema-1.3.0.sql @@ -0,0 +1,24 @@ +-- Copyright (c) Microsoft Corporation. +-- Licensed under the MIT License. + +-- PERSISTENT SCHEMA OBJECTS (tables, indexes, types, etc.) +-- +-- The contents of this file must never be changed after +-- being published. Any schema changes must be done in +-- new schema-{major}.{minor}.{patch}.sql scripts. + +-- Add a new Tags column to the Instances table (JSON blob of string key-value pairs). +IF NOT EXISTS (SELECT 1 FROM sys.columns WHERE object_id = OBJECT_ID('__SchemaNamePlaceholder__.Instances') AND name = 'Tags') + ALTER TABLE __SchemaNamePlaceholder__.Instances ADD [Tags] varchar(8000) NULL + +-- Add a Tags column to the OrchestrationEvents table type so that merged tags +-- flow through sub-orchestration creation events. To change a type we must first +-- drop all stored procedures that reference it, then drop the type itself. +-- The type and sprocs will be recreated by logic.sql which executes afterwards. +IF OBJECT_ID('__SchemaNamePlaceholder__._AddOrchestrationEvents') IS NOT NULL + DROP PROCEDURE __SchemaNamePlaceholder__._AddOrchestrationEvents +IF OBJECT_ID('__SchemaNamePlaceholder__._CheckpointOrchestration') IS NOT NULL + DROP PROCEDURE __SchemaNamePlaceholder__._CheckpointOrchestration + +IF TYPE_ID('__SchemaNamePlaceholder__.OrchestrationEvents') IS NOT NULL + DROP TYPE __SchemaNamePlaceholder__.OrchestrationEvents diff --git a/src/DurableTask.SqlServer/SqlOrchestrationService.cs b/src/DurableTask.SqlServer/SqlOrchestrationService.cs index 3dfe161..04286fb 100644 --- a/src/DurableTask.SqlServer/SqlOrchestrationService.cs +++ b/src/DurableTask.SqlServer/SqlOrchestrationService.cs @@ -372,7 +372,8 @@ public override async Task CompleteTaskOrchestrationWorkItemAsync( timerMessages, continuedAsNewMessage, currentWorkItem.EventPayloadMappings, - this.settings.SchemaName); + this.settings.SchemaName, + this.traceHelper); command.Parameters.AddTaskEventsParameter( "@NewTaskEvents", @@ -388,6 +389,7 @@ public override async Task CompleteTaskOrchestrationWorkItemAsync( currentWorkItem.EventPayloadMappings, this.settings.SchemaName); + try { await SqlUtils.ExecuteNonQueryAsync(command, this.traceHelper, instance.InstanceId); @@ -522,6 +524,8 @@ public override async Task CreateTaskOrchestrationAsync(TaskMessage creationMess command.Parameters.Add("@StartTime", SqlDbType.DateTime2).Value = startEvent.ScheduledStartTime; command.Parameters.Add("@TraceContext", SqlDbType.VarChar, size: 800).Value = SqlUtils.GetTraceContext(startEvent); + command.Parameters.AddTagsParameter(startEvent.Tags); + if (dedupeStatuses?.Length > 0) { command.Parameters.Add("@DedupeStatuses", SqlDbType.VarChar).Value = string.Join(",", dedupeStatuses); @@ -543,7 +547,7 @@ public override async Task SendTaskOrchestrationMessageAsync(TaskMessage message using SqlConnection connection = await this.GetAndOpenConnectionAsync(); using SqlCommand command = this.GetSprocCommand(connection, $"{this.settings.SchemaName}._AddOrchestrationEvents"); - command.Parameters.AddOrchestrationEventsParameter("@NewOrchestrationEvents", message, this.settings.SchemaName); + command.Parameters.AddOrchestrationEventsParameter("@NewOrchestrationEvents", message, this.settings.SchemaName, this.traceHelper); string instanceId = message.OrchestrationInstance.InstanceId; await SqlUtils.ExecuteNonQueryAsync(command, this.traceHelper, instanceId); diff --git a/src/DurableTask.SqlServer/SqlTypes/OrchestrationEventSqlType.cs b/src/DurableTask.SqlServer/SqlTypes/OrchestrationEventSqlType.cs index 9c3e98c..1694aba 100644 --- a/src/DurableTask.SqlServer/SqlTypes/OrchestrationEventSqlType.cs +++ b/src/DurableTask.SqlServer/SqlTypes/OrchestrationEventSqlType.cs @@ -31,6 +31,7 @@ static class OrchestrationEventSqlType new SqlMetaData("ParentInstanceID", SqlDbType.VarChar, 100), new SqlMetaData("Version", SqlDbType.VarChar, 100), new SqlMetaData("TraceContext", SqlDbType.VarChar, 800), + new SqlMetaData("Tags", SqlDbType.VarChar, 8000), }; static class ColumnOrdinals @@ -50,6 +51,7 @@ static class ColumnOrdinals public const int ParentInstanceID = 11; public const int Version = 12; public const int TraceContext = 13; + public const int Tags = 14; } public static SqlParameter AddOrchestrationEventsParameter( @@ -59,7 +61,8 @@ public static SqlParameter AddOrchestrationEventsParameter( IList timerMessages, TaskMessage continuedAsNewMessage, EventPayloadMap eventPayloadMap, - string schemaName) + string schemaName, + LogHelper logHelper) { SqlParameter param = commandParameters.Add(paramName, SqlDbType.Structured); param.TypeName = $"{schemaName}.OrchestrationEvents"; @@ -70,7 +73,7 @@ public static SqlParameter AddOrchestrationEventsParameter( messages = messages.Append(continuedAsNewMessage); } - param.Value = ToOrchestrationMessageParameter(messages, eventPayloadMap); + param.Value = ToOrchestrationMessageParameter(messages, eventPayloadMap, logHelper); return param; } @@ -78,17 +81,19 @@ public static SqlParameter AddOrchestrationEventsParameter( this SqlParameterCollection commandParameters, string paramName, TaskMessage message, - string schemaName) + string schemaName, + LogHelper logHelper) { SqlParameter param = commandParameters.Add(paramName, SqlDbType.Structured); param.TypeName = $"{schemaName}.OrchestrationEvents"; - param.Value = ToOrchestrationMessageParameter(message); + param.Value = ToOrchestrationMessageParameter(message, logHelper); return param; } static IEnumerable? ToOrchestrationMessageParameter( this IEnumerable messages, - EventPayloadMap eventPayloadMap) + EventPayloadMap eventPayloadMap, + LogHelper logHelper) { if (!messages.Any()) { @@ -105,18 +110,18 @@ IEnumerable GetOrchestrationMessageRecords() var record = new SqlDataRecord(OrchestrationEventSchema); foreach (TaskMessage msg in messages) { - yield return PopulateOrchestrationMessage(msg, record, eventPayloadMap); + yield return PopulateOrchestrationMessage(msg, record, eventPayloadMap, logHelper); } } } - static IEnumerable ToOrchestrationMessageParameter(TaskMessage msg) + static IEnumerable ToOrchestrationMessageParameter(TaskMessage msg, LogHelper logHelper) { var record = new SqlDataRecord(OrchestrationEventSchema); - yield return PopulateOrchestrationMessage(msg, record, eventPayloadMap: null); + yield return PopulateOrchestrationMessage(msg, record, eventPayloadMap: null, logHelper); } - static SqlDataRecord PopulateOrchestrationMessage(TaskMessage msg, SqlDataRecord record, EventPayloadMap? eventPayloadMap) + static SqlDataRecord PopulateOrchestrationMessage(TaskMessage msg, SqlDataRecord record, EventPayloadMap? eventPayloadMap, LogHelper logHelper) { string instanceId = msg.OrchestrationInstance.InstanceId; @@ -152,6 +157,7 @@ static SqlDataRecord PopulateOrchestrationMessage(TaskMessage msg, SqlDataRecord record.SetSqlString(ColumnOrdinals.ParentInstanceID, SqlUtils.GetParentInstanceId(msg.Event)); record.SetSqlString(ColumnOrdinals.Version, SqlUtils.GetVersion(msg.Event)); record.SetSqlString(ColumnOrdinals.TraceContext, SqlUtils.GetTraceContext(msg.Event)); + record.SetSqlString(ColumnOrdinals.Tags, SqlUtils.GetTagsJson(msg.Event, logHelper)); return record; } diff --git a/src/DurableTask.SqlServer/SqlUtils.cs b/src/DurableTask.SqlServer/SqlUtils.cs index 7fc58ee..8d6de4d 100644 --- a/src/DurableTask.SqlServer/SqlUtils.cs +++ b/src/DurableTask.SqlServer/SqlUtils.cs @@ -24,6 +24,7 @@ static class SqlUtils { static readonly Random random = new Random(); static readonly char[] TraceContextSeparators = new char[] { '\n' }; + const int MaxTagsPayloadSize = 8000; public static string? GetStringOrNull(this DbDataReader reader, int columnIndex) { @@ -75,17 +76,17 @@ public static HistoryEvent GetHistoryEvent(this DbDataReader reader, bool isOrch InstanceId = GetInstanceId(reader), }; break; - case EventType.ExecutionCompleted: - FailureDetails? executionFailedDetails = null; - OrchestrationStatus orchestrationStatus = GetRuntimeStatus(reader); - if (orchestrationStatus == OrchestrationStatus.Failed) - { - TryGetFailureDetails(reader, out executionFailedDetails); + case EventType.ExecutionCompleted: + FailureDetails? executionFailedDetails = null; + OrchestrationStatus orchestrationStatus = GetRuntimeStatus(reader); + if (orchestrationStatus == OrchestrationStatus.Failed) + { + TryGetFailureDetails(reader, out executionFailedDetails); } historyEvent = new ExecutionCompletedEvent( eventId, result: GetPayloadText(reader), - orchestrationStatus: orchestrationStatus, + orchestrationStatus: orchestrationStatus, failureDetails: executionFailedDetails); break; case EventType.ExecutionStarted: @@ -97,7 +98,7 @@ public static HistoryEvent GetHistoryEvent(this DbDataReader reader, bool isOrch InstanceId = GetInstanceId(reader), ExecutionId = GetExecutionId(reader), }, - Tags = null, // TODO + Tags = GetTags(reader), Version = GetVersion(reader), ParentTraceContext = GetTraceContext(reader), }; @@ -259,7 +260,8 @@ public static OrchestrationState GetOrchestrationState(this DbDataReader reader) }, OrchestrationStatus = GetRuntimeStatus(reader), Status = GetStringOrNull(reader, reader.GetOrdinal("CustomStatusText")), - ParentInstance = parentInstance + ParentInstance = parentInstance, + Tags = GetTags(reader), }; // The OutputText column is overloaded to contain either orchestration output or failure details @@ -483,6 +485,74 @@ internal static SqlString GetTraceContext(HistoryEvent e) return traceContext; } + internal static IDictionary? GetTags(DbDataReader reader) + { + int ordinal = reader.GetOrdinal("Tags"); + + if (reader.IsDBNull(ordinal)) + { + return null; + } + + string json = reader.GetString(ordinal); + if (string.IsNullOrEmpty(json)) + { + return null; + } + + try + { + return DTUtils.DeserializeFromJson>(json); + } + catch (Exception ex) + { + Debug.WriteLine($"Failed to deserialize Tags JSON payload. Treating as null. Error: {ex}"); + return null; + } + } + + internal static SqlString GetTagsJson(HistoryEvent e, LogHelper logHelper) + { + if (e is ExecutionStartedEvent startedEvent && startedEvent.Tags != null && startedEvent.Tags.Count > 0) + { + string json = DTUtils.SerializeToJson(startedEvent.Tags); + int utf8Bytes = Encoding.UTF8.GetByteCount(json); + if (utf8Bytes > MaxTagsPayloadSize) + { + logHelper.GenericWarning( + $"Dropping oversized tags ({utf8Bytes} bytes, max {MaxTagsPayloadSize}) for sub-orchestration. " + + $"The merged parent+child tags exceed the allowed limit and will not be persisted.", + instanceId: (e as ExecutionStartedEvent)?.ParentInstance?.OrchestrationInstance?.InstanceId); + return SqlString.Null; + } + + return json; + } + + return SqlString.Null; + } + + internal static void AddTagsParameter( + this SqlParameterCollection parameters, + IDictionary? tags) + { + string? json = tags != null && tags.Count > 0 + ? DTUtils.SerializeToJson(tags) + : null; + + if (json != null) + { + int utf8Bytes = Encoding.UTF8.GetByteCount(json); + if (utf8Bytes > MaxTagsPayloadSize) + { + throw new ArgumentException( + $"The serialized tags payload is {utf8Bytes} bytes, which exceeds the maximum allowed size of {MaxTagsPayloadSize} bytes."); + } + } + + parameters.Add("@Tags", SqlDbType.VarChar, MaxTagsPayloadSize).Value = (object?)json ?? DBNull.Value; + } + public static SqlParameter AddInstanceIDsParameter( this SqlParameterCollection commandParameters, string paramName, diff --git a/test/DurableTask.SqlServer.Tests/Integration/DatabaseManagement.cs b/test/DurableTask.SqlServer.Tests/Integration/DatabaseManagement.cs index 11c0aec..d4c1d02 100644 --- a/test/DurableTask.SqlServer.Tests/Integration/DatabaseManagement.cs +++ b/test/DurableTask.SqlServer.Tests/Integration/DatabaseManagement.cs @@ -46,6 +46,7 @@ public void CanEnumerateEmbeddedSqlScripts() "drop-schema.sql", "schema-1.0.0.sql", "schema-1.2.0.sql", + "schema-1.3.0.sql", "logic.sql", "permissions.sql", }; @@ -98,6 +99,7 @@ public async Task CanCreateAndDropSchema(bool isDatabaseMissing) LogAssert.ExecutedSqlScript("drop-schema.sql"), LogAssert.ExecutedSqlScript("schema-1.0.0.sql"), LogAssert.ExecutedSqlScript("schema-1.2.0.sql"), + LogAssert.ExecutedSqlScript("schema-1.3.0.sql"), LogAssert.ExecutedSqlScript("logic.sql"), LogAssert.ExecutedSqlScript("permissions.sql"), LogAssert.SprocCompleted("dt._UpdateVersion")) @@ -158,6 +160,7 @@ public async Task CanCreateAndDropSchemaWithCustomSchemaName(bool isDatabaseMiss LogAssert.ExecutedSqlScript("drop-schema.sql"), LogAssert.ExecutedSqlScript("schema-1.0.0.sql"), LogAssert.ExecutedSqlScript("schema-1.2.0.sql"), + LogAssert.ExecutedSqlScript("schema-1.3.0.sql"), LogAssert.ExecutedSqlScript("logic.sql"), LogAssert.ExecutedSqlScript("permissions.sql"), LogAssert.SprocCompleted($"{schemaName}._UpdateVersion")) @@ -220,6 +223,7 @@ public async Task CanCreateAndDropMultipleSchemas(bool isDatabaseMissing) LogAssert.ExecutedSqlScript("drop-schema.sql"), LogAssert.ExecutedSqlScript("schema-1.0.0.sql"), LogAssert.ExecutedSqlScript("schema-1.2.0.sql"), + LogAssert.ExecutedSqlScript("schema-1.3.0.sql"), LogAssert.ExecutedSqlScript("logic.sql"), LogAssert.ExecutedSqlScript("permissions.sql"), LogAssert.SprocCompleted($"{firstTestSchemaName}._UpdateVersion")) @@ -230,6 +234,7 @@ public async Task CanCreateAndDropMultipleSchemas(bool isDatabaseMissing) LogAssert.ExecutedSqlScript("drop-schema.sql"), LogAssert.ExecutedSqlScript("schema-1.0.0.sql"), LogAssert.ExecutedSqlScript("schema-1.2.0.sql"), + LogAssert.ExecutedSqlScript("schema-1.3.0.sql"), LogAssert.ExecutedSqlScript("logic.sql"), LogAssert.ExecutedSqlScript("permissions.sql"), LogAssert.SprocCompleted($"{secondTestSchemaName}._UpdateVersion")) @@ -313,6 +318,7 @@ public async Task CanCreateIfNotExists(bool isDatabaseMissing) LogAssert.SprocCompleted("dt._GetVersions"), LogAssert.ExecutedSqlScript("schema-1.0.0.sql"), LogAssert.ExecutedSqlScript("schema-1.2.0.sql"), + LogAssert.ExecutedSqlScript("schema-1.3.0.sql"), LogAssert.ExecutedSqlScript("logic.sql"), LogAssert.ExecutedSqlScript("permissions.sql"), LogAssert.SprocCompleted("dt._UpdateVersion")) @@ -366,6 +372,7 @@ public async Task SchemaCreationIsSerializedAndIdempotent(bool isDatabaseMissing LogAssert.SprocCompleted("dt._GetVersions"), LogAssert.ExecutedSqlScript("schema-1.0.0.sql"), LogAssert.ExecutedSqlScript("schema-1.2.0.sql"), + LogAssert.ExecutedSqlScript("schema-1.3.0.sql"), LogAssert.ExecutedSqlScript("logic.sql"), LogAssert.ExecutedSqlScript("permissions.sql"), LogAssert.SprocCompleted("dt._UpdateVersion"), diff --git a/test/DurableTask.SqlServer.Tests/Integration/Orchestrations.cs b/test/DurableTask.SqlServer.Tests/Integration/Orchestrations.cs index fa7a9c7..fc3ff44 100644 --- a/test/DurableTask.SqlServer.Tests/Integration/Orchestrations.cs +++ b/test/DurableTask.SqlServer.Tests/Integration/Orchestrations.cs @@ -50,9 +50,12 @@ public async Task EmptyOrchestration() orchestrationName, implementation: (ctx, input) => Task.FromResult(input)); - await instance.WaitForCompletion( + OrchestrationState state = await instance.WaitForCompletion( expectedOutput: input); + // Verify backward compatibility: tags should be null when none are specified + Assert.Null(state.Tags); + // Validate logs LogAssert.NoWarningsOrErrors(this.testService.LogProvider); LogAssert.Sequence( @@ -877,22 +880,22 @@ public async Task TerminateScheduledOrchestration() instanceId: null, scheduledStartTime: DateTime.UtcNow.AddSeconds(30), implementation: (ctx, input) => Task.FromResult("done")); - - // Confirm that the orchestration is pending - OrchestrationState state = await instance.GetStateAsync(); - Assert.Equal(OrchestrationStatus.Pending, state.OrchestrationStatus); + + // Confirm that the orchestration is pending + OrchestrationState state = await instance.GetStateAsync(); + Assert.Equal(OrchestrationStatus.Pending, state.OrchestrationStatus); // Terminate the orchestration before it starts await instance.TerminateAsync("Bye!"); - + // Confirm the orchestration was terminated await instance.WaitForCompletion( expectedStatus: OrchestrationStatus.Terminated, expectedOutput: "Bye!"); LogAssert.NoWarningsOrErrors(this.testService.LogProvider); - } - + } + [Fact] public async Task TerminateSuspendedOrchestration() { @@ -910,18 +913,18 @@ public async Task TerminateSuspendedOrchestration() await instance.WaitForStart(); // Suspend the orchestration so that it won't process any new events - await instance.SuspendAsync(); - - // Wait for the orchestration to become suspended - OrchestrationState state = await instance.GetStateAsync(); - TimeSpan waitForSuspendTimeout = TimeSpan.FromSeconds(5); - using CancellationTokenSource cts = new(waitForSuspendTimeout); - while (!cts.IsCancellationRequested && state.OrchestrationStatus != OrchestrationStatus.Suspended) - { - state = await instance.GetStateAsync(); + await instance.SuspendAsync(); + + // Wait for the orchestration to become suspended + OrchestrationState state = await instance.GetStateAsync(); + TimeSpan waitForSuspendTimeout = TimeSpan.FromSeconds(5); + using CancellationTokenSource cts = new(waitForSuspendTimeout); + while (!cts.IsCancellationRequested && state.OrchestrationStatus != OrchestrationStatus.Suspended) + { + state = await instance.GetStateAsync(); } - Assert.Equal(OrchestrationStatus.Suspended, state.OrchestrationStatus); - + Assert.Equal(OrchestrationStatus.Suspended, state.OrchestrationStatus); + // Now terminate the orchestration await instance.TerminateAsync("Bye!"); @@ -937,11 +940,340 @@ public async Task TerminateSuspendedOrchestration() this.testService.LogProvider, LogAssert.AcquiredAppLock(), LogAssert.CheckpointStarting(orchestrationName), - LogAssert.CheckpointCompleted(orchestrationName), + LogAssert.CheckpointCompleted(orchestrationName), LogAssert.CheckpointStarting(orchestrationName), LogAssert.CheckpointCompleted(orchestrationName), LogAssert.CheckpointStarting(orchestrationName), LogAssert.CheckpointCompleted(orchestrationName)); } + + [Fact] + public async Task OrchestrationWithTags() + { + string input = $"Hello {DateTime.UtcNow:o}"; + var tags = new Dictionary + { + { "key1", "value1" }, + { "key2", "value2" }, + }; + + TestInstance instance = await this.testService.RunOrchestrationWithTags( + input, + orchestrationName: "OrchestrationWithTags", + tags: tags, + implementation: (ctx, input) => Task.FromResult(input)); + + OrchestrationState state = await instance.WaitForCompletion(expectedOutput: input); + + Assert.NotNull(state.Tags); + Assert.Equal(2, state.Tags.Count); + Assert.Equal("value1", state.Tags["key1"]); + Assert.Equal("value2", state.Tags["key2"]); + + LogAssert.NoWarningsOrErrors(this.testService.LogProvider); + } + + [Fact] + public async Task OrchestrationWithEmptyTags() + { + string input = $"Hello {DateTime.UtcNow:o}"; + + TestInstance instance = await this.testService.RunOrchestrationWithTags( + input, + orchestrationName: "OrchestrationWithEmptyTags", + tags: new Dictionary(), + implementation: (ctx, input) => Task.FromResult(input)); + + OrchestrationState state = await instance.WaitForCompletion(expectedOutput: input); + + Assert.Null(state.Tags); + + LogAssert.NoWarningsOrErrors(this.testService.LogProvider); + } + + [Fact] + public async Task TagsSurviveContinueAsNew() + { + var tags = new Dictionary + { + { "key1", "value1" }, + { "key2", "value2" }, + }; + + TestInstance instance = await this.testService.RunOrchestrationWithTags( + input: 0, + orchestrationName: "TagsContinueAsNewTest", + tags: tags, + implementation: async (ctx, input) => + { + if (input < 3) + { + await ctx.CreateTimer(DateTime.MinValue, null); + ctx.ContinueAsNew(input + 1); + } + + return input; + }); + + OrchestrationState state = await instance.WaitForCompletion( + expectedOutput: 3, + continuedAsNew: true); + + Assert.NotNull(state.Tags); + Assert.Equal(2, state.Tags.Count); + Assert.Equal("value1", state.Tags["key1"]); + Assert.Equal("value2", state.Tags["key2"]); + } + + [Fact] + public async Task SubOrchestrationInheritsTags() + { + var tags = new Dictionary + { + { "key1", "value1" }, + { "key2", "value2" }, + }; + + string subOrchestrationName = "SubOrchestrationForTagTest"; + string subInstanceId = $"sub-{Guid.NewGuid():N}"; + + this.testService.RegisterInlineOrchestration( + subOrchestrationName, + implementation: (ctx, input) => Task.FromResult("done")); + + TestInstance instance = await this.testService.RunOrchestrationWithTags( + input: (string)null, + orchestrationName: "ParentOrchestrationForTagTest", + tags: tags, + implementation: async (ctx, input) => + { + return await ctx.CreateSubOrchestrationInstance( + subOrchestrationName, string.Empty, subInstanceId, null); + }); + + OrchestrationState state = await instance.WaitForCompletion( + timeout: TimeSpan.FromSeconds(15), + expectedOutput: "done"); + + // Verify parent orchestration tags + Assert.NotNull(state.Tags); + Assert.Equal("value1", state.Tags["key1"]); + Assert.Equal("value2", state.Tags["key2"]); + + // Verify sub-orchestration inherited the tags + OrchestrationState subState = await this.testService.GetOrchestrationStateAsync(subInstanceId); + Assert.NotNull(subState); + Assert.NotNull(subState.Tags); + Assert.Equal("value1", subState.Tags["key1"]); + Assert.Equal("value2", subState.Tags["key2"]); + } + + [Fact] + public async Task TagsWithSpecialCharacters() + { + string input = $"Hello {DateTime.UtcNow:o}"; + var tags = new Dictionary + { + { "key with spaces", "value with spaces" }, + { "unicode-key-日本語", "unicode-value-中文" }, + { "special\"chars", "value'with\"quotes" }, + }; + + TestInstance instance = await this.testService.RunOrchestrationWithTags( + input, + orchestrationName: "TagsSpecialCharsTest", + tags: tags, + implementation: (ctx, input) => Task.FromResult(input)); + + OrchestrationState state = await instance.WaitForCompletion(expectedOutput: input); + + Assert.NotNull(state.Tags); + Assert.Equal(3, state.Tags.Count); + Assert.Equal("value with spaces", state.Tags["key with spaces"]); + Assert.Equal("unicode-value-中文", state.Tags["unicode-key-日本語"]); + Assert.Equal("value'with\"quotes", state.Tags["special\"chars"]); + } + + [Fact] + public async Task SubOrchestrationMergesTags() + { + // Parent tags + var parentTags = new Dictionary + { + { "env", "prod" }, + { "shared", "parent-value" }, + }; + + // Sub-orchestration-specific tags (will be merged with parent tags by Core) + var subTags = new Dictionary + { + { "team", "backend" }, + { "shared", "child-override" }, // should override parent's value + }; + + string subOrchName = "SubOrchForMergeTest"; + string subInstanceId = $"sub-merge-{Guid.NewGuid():N}"; + + this.testService.RegisterInlineOrchestration( + subOrchName, + implementation: (ctx, input) => Task.FromResult("done")); + + TestInstance instance = await this.testService.RunOrchestrationWithTags( + input: (string)null, + orchestrationName: "ParentOrchForMergeTest", + tags: parentTags, + implementation: async (ctx, input) => + { + // Use the 5-arg overload that passes sub-orch-specific tags + return await ctx.CreateSubOrchestrationInstance( + subOrchName, string.Empty, subInstanceId, null, subTags); + }); + + await instance.WaitForCompletion( + timeout: TimeSpan.FromSeconds(15), + expectedOutput: "done"); + + // Verify sub-orchestration has MERGED tags (parent + child, child overrides) + OrchestrationState subState = await this.testService.GetOrchestrationStateAsync(subInstanceId); + Assert.NotNull(subState); + Assert.NotNull(subState.Tags); + Assert.Equal("prod", subState.Tags["env"]); // inherited from parent + Assert.Equal("backend", subState.Tags["team"]); // from sub-orch + Assert.Equal("child-override", subState.Tags["shared"]); // child overrides parent + } + + [Fact] + public async Task MultipleSubOrchestrationsMergeDifferentTags() + { + var parentTags = new Dictionary + { + { "env", "prod" }, + }; + + string subOrchName = "SubOrchForFanOutTest"; + string subId1 = $"sub-fanout1-{Guid.NewGuid():N}"; + string subId2 = $"sub-fanout2-{Guid.NewGuid():N}"; + + this.testService.RegisterInlineOrchestration( + subOrchName, + implementation: (ctx, input) => Task.FromResult("done")); + + TestInstance instance = await this.testService.RunOrchestrationWithTags( + input: (string)null, + orchestrationName: "ParentOrchForFanOutTest", + tags: parentTags, + implementation: async (ctx, input) => + { + // Fan-out: create two sub-orchestrations with different tags + var tags1 = new Dictionary { { "region", "us" } }; + var tags2 = new Dictionary { { "region", "eu" } }; + + Task t1 = ctx.CreateSubOrchestrationInstance( + subOrchName, string.Empty, subId1, null, tags1); + Task t2 = ctx.CreateSubOrchestrationInstance( + subOrchName, string.Empty, subId2, null, tags2); + + await Task.WhenAll(t1, t2); + return "done"; + }); + + await instance.WaitForCompletion( + timeout: TimeSpan.FromSeconds(15), + expectedOutput: "done"); + + // Verify each sub-orchestration got its own correctly-merged tags + OrchestrationState sub1 = await this.testService.GetOrchestrationStateAsync(subId1); + Assert.NotNull(sub1?.Tags); + Assert.Equal("prod", sub1.Tags["env"]); // inherited from parent + Assert.Equal("us", sub1.Tags["region"]); // specific to sub-orch 1 + + OrchestrationState sub2 = await this.testService.GetOrchestrationStateAsync(subId2); + Assert.NotNull(sub2?.Tags); + Assert.Equal("prod", sub2.Tags["env"]); // inherited from parent + Assert.Equal("eu", sub2.Tags["region"]); // specific to sub-orch 2 + } + + [Fact] + public async Task MergedTagsExceedMaxSize_OversizedTagsDropped() + { + // Parent and child tags are each within the 8000-char limit, + // but exceed it after Core's MergeTags() combines them. + // Expected behavior: oversized merged tags are silently dropped + // (with a trace warning), the sub-orchestration is created with + // null tags, and the parent completes normally. + + var parentTags = new Dictionary + { + { "parentKey", new string('p', 4500) }, + }; + + var childTags = new Dictionary + { + { "childKey", new string('c', 4500) }, + }; + + string subOrchName = "SubOrchForOverflowTest"; + string subInstanceId = $"sub-overflow-{Guid.NewGuid():N}"; + + this.testService.RegisterInlineOrchestration( + subOrchName, + implementation: (ctx, input) => Task.FromResult("done")); + + TestInstance instance = await this.testService.RunOrchestrationWithTags( + input: (string)null, + orchestrationName: "ParentOrchForOverflowTest", + tags: parentTags, + implementation: async (ctx, input) => + { + return await ctx.CreateSubOrchestrationInstance( + subOrchName, string.Empty, subInstanceId, null, childTags); + }); + + // Parent should complete normally (sub-orch returns "done") + await instance.WaitForCompletion(expectedOutput: "done"); + + // Sub-orchestration should have been created, but with null tags + // because the merged tags exceeded the maximum size. + OrchestrationState subState = await this.testService.GetOrchestrationStateAsync(subInstanceId); + Assert.NotNull(subState); + Assert.Null(subState.Tags); + } + + [Fact] + public async Task TagsOnManyOrchestrations() + { + string input = $"Hello {DateTime.UtcNow:o}"; + var tags = new Dictionary + { + { "key1", "value1" }, + }; + + TestInstance instance = await this.testService.RunOrchestrationWithTags( + input, + orchestrationName: "TagsManyQueryTest", + tags: tags, + implementation: (ctx, input) => Task.FromResult(input)); + + await instance.WaitForCompletion(expectedOutput: input); + + var filter = new SqlOrchestrationQuery(); + IReadOnlyCollection results = + await this.testService.OrchestrationServiceMock.Object.GetManyOrchestrationsAsync( + filter, CancellationToken.None); + + Assert.NotEmpty(results); + bool foundTaggedInstance = false; + foreach (OrchestrationState result in results) + { + if (result.OrchestrationInstance.InstanceId == instance.InstanceId) + { + Assert.NotNull(result.Tags); + Assert.Equal("value1", result.Tags["key1"]); + foundTaggedInstance = true; + } + } + + Assert.True(foundTaggedInstance, "Did not find the tagged orchestration instance in query results."); + } } } diff --git a/test/DurableTask.SqlServer.Tests/Unit/SqlUtilsTagTests.cs b/test/DurableTask.SqlServer.Tests/Unit/SqlUtilsTagTests.cs new file mode 100644 index 0000000..df61341 --- /dev/null +++ b/test/DurableTask.SqlServer.Tests/Unit/SqlUtilsTagTests.cs @@ -0,0 +1,163 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace DurableTask.SqlServer.Tests.Unit +{ + using System; + using System.Collections.Generic; + using System.Data; + using System.Data.SqlTypes; + using DurableTask.Core.History; + using Microsoft.Data.SqlClient; + using Microsoft.Extensions.Logging.Abstractions; + using Newtonsoft.Json; + using Xunit; + + public class SqlUtilsTagTests + { + static readonly LogHelper TestLogHelper = new LogHelper(NullLogger.Instance); + [Fact] + public void AddTagsParameter_WithTags_SetsJsonValue() + { + // Arrange + var tags = new Dictionary + { + { "key1", "value1" }, + { "key2", "value2" }, + }; + + using var command = new SqlCommand(); + + // Act + command.Parameters.AddTagsParameter(tags); + + // Assert + SqlParameter param = command.Parameters["@Tags"]; + Assert.NotNull(param); + Assert.Equal(SqlDbType.VarChar, param.SqlDbType); + + string json = (string)param.Value; + var deserialized = JsonConvert.DeserializeObject>(json); + Assert.Equal(2, deserialized.Count); + Assert.Equal("value1", deserialized["key1"]); + Assert.Equal("value2", deserialized["key2"]); + } + + [Fact] + public void AddTagsParameter_NullTags_SetsDBNull() + { + using var command = new SqlCommand(); + + // Act + command.Parameters.AddTagsParameter(null); + + // Assert + Assert.Equal(DBNull.Value, command.Parameters["@Tags"].Value); + } + + [Fact] + public void AddTagsParameter_EmptyTags_SetsDBNull() + { + using var command = new SqlCommand(); + + // Act + command.Parameters.AddTagsParameter(new Dictionary()); + + // Assert + Assert.Equal(DBNull.Value, command.Parameters["@Tags"].Value); + } + + [Fact] + public void AddTagsParameter_SpecialCharacters_RoundTrips() + { + // Arrange + var tags = new Dictionary + { + { "special\"key", "value'with\"quotes" }, + { "unicode-日本語", "中文" }, + { "key with spaces", "value with spaces" }, + }; + + using var command = new SqlCommand(); + + // Act + command.Parameters.AddTagsParameter(tags); + + // Assert + string json = (string)command.Parameters["@Tags"].Value; + var deserialized = JsonConvert.DeserializeObject>(json); + Assert.Equal(3, deserialized.Count); + Assert.Equal("value'with\"quotes", deserialized["special\"key"]); + Assert.Equal("中文", deserialized["unicode-日本語"]); + Assert.Equal("value with spaces", deserialized["key with spaces"]); + } + + [Fact] + public void AddTagsParameter_TagsExceedMaxSize_ThrowsArgumentException() + { + // Arrange: create tags whose JSON serialization exceeds 8000 chars + var tags = new Dictionary + { + { "key", new string('x', 8000) }, + }; + + using var command = new SqlCommand(); + + // Act & Assert + var ex = Assert.Throws(() => command.Parameters.AddTagsParameter(tags)); + Assert.Contains("exceeds the maximum allowed size of 8000 bytes", ex.Message); + } + + [Fact] + public void GetTagsJson_TagsExceedMaxSize_ReturnsNullAndDropsTags() + { + // Arrange: simulate merged tags that exceed 8000 chars. + // This covers the sub-orchestration merge path where individually-valid + // parent + child tags combine to exceed the limit. + // Expected: returns SqlString.Null (tags silently dropped with a warning). + var tags = new Dictionary + { + { "key", new string('x', 8000) }, + }; + + var startedEvent = new ExecutionStartedEvent(-1, null) { Tags = tags }; + + // Act + SqlString result = SqlUtils.GetTagsJson(startedEvent, TestLogHelper); + + // Assert: oversized tags are dropped, not thrown + Assert.True(result.IsNull); + } + + [Fact] + public void GetTagsJson_NonExecutionStartedEvent_ReturnsNull() + { + // Non-ExecutionStartedEvent should return SqlString.Null + var timerEvent = new TimerFiredEvent(-1); + + SqlString result = SqlUtils.GetTagsJson(timerEvent, TestLogHelper); + + Assert.True(result.IsNull); + } + + [Fact] + public void GetTagsJson_ExecutionStartedWithTags_ReturnsJson() + { + var tags = new Dictionary + { + { "env", "prod" }, + { "team", "backend" }, + }; + + var startedEvent = new ExecutionStartedEvent(-1, null) { Tags = tags }; + + SqlString result = SqlUtils.GetTagsJson(startedEvent, TestLogHelper); + + Assert.False(result.IsNull); + var deserialized = JsonConvert.DeserializeObject>(result.Value); + Assert.Equal(2, deserialized.Count); + Assert.Equal("prod", deserialized["env"]); + Assert.Equal("backend", deserialized["team"]); + } + } +} diff --git a/test/DurableTask.SqlServer.Tests/Utils/TestService.cs b/test/DurableTask.SqlServer.Tests/Utils/TestService.cs index 6545e6a..b74907b 100644 --- a/test/DurableTask.SqlServer.Tests/Utils/TestService.cs +++ b/test/DurableTask.SqlServer.Tests/Utils/TestService.cs @@ -87,6 +87,11 @@ public async Task InitializeAsync(bool startWorker = true, bool legacyErrorPropa this.client = new TaskHubClient(this.OrchestrationServiceMock.Object, loggerFactory: this.loggerFactory); } + public Task GetOrchestrationStateAsync(string instanceId) + { + return this.client.GetOrchestrationStateAsync(new OrchestrationInstance { InstanceId = instanceId }); + } + public Task StartWorkerAsync() => this.worker?.StartAsync() ?? Task.CompletedTask; public Task PurgeAsync(DateTime maximumThreshold, OrchestrationStateTimeRangeFilterType filterType) @@ -263,6 +268,40 @@ public async Task>> RunOrchestrations> RunOrchestrationWithTags( + TInput input, + string orchestrationName, + IDictionary tags, + Func> implementation, + params (string name, TaskActivity activity)[] activities) + { + // Register the inline orchestration + this.RegisterInlineOrchestration(orchestrationName, string.Empty, implementation); + + foreach ((string name, TaskActivity activity) in activities) + { + this.RegisterInlineActivity(name, string.Empty, activity); + } + + string instanceId = Guid.NewGuid().ToString("N"); + DateTime utcNow = DateTime.UtcNow; + + OrchestrationInstance instance = await this.client.CreateOrchestrationInstanceAsync( + orchestrationName, + string.Empty, + instanceId, + input, + tags); + + return new TestInstance( + this.client, + instance, + orchestrationName, + string.Empty, + utcNow, + input); + } + public void RegisterInlineActivity(string name, string version, TaskActivity activity) { this.worker.AddTaskActivities(new TestObjectCreator(name, version, activity));