diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index 41b46f7529..dec009bc14 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -763,6 +763,195 @@ "default": 4 } } + }, + "embeddings": { + "type": "object", + "description": "Configuration for text embedding/vectorization service. Supports OpenAI and Azure OpenAI providers.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Whether the embedding service is enabled. Defaults to true.", + "default": true + }, + "provider": { + "type": "string", + "description": "The embedding provider type.", + "enum": ["azure-openai", "openai"] + }, + "base-url": { + "type": "string", + "description": "The provider base URL. For Azure OpenAI, use the Azure resource endpoint. For OpenAI, use https://api.openai.com." + }, + "api-key": { + "type": "string", + "description": "The API key for authentication. Supports environment variable substitution with @env('VAR_NAME')." + }, + "model": { + "type": "string", + "description": "The model or deployment name. Required for Azure OpenAI (deployment name). For OpenAI, defaults to 'text-embedding-3-small' if not specified." + }, + "api-version": { + "type": "string", + "description": "Azure API version. Only used for Azure OpenAI provider.", + "default": "2023-05-15" + }, + "dimensions": { + "type": "integer", + "description": "Output vector dimensions. Defaults to 1536 if not specified. Useful for Redis schema alignment.", + "default": 1536, + "minimum": 1 + }, + "timeout-ms": { + "type": "integer", + "description": "Request timeout in milliseconds.", + "default": 30000, + "minimum": 1, + "maximum": 300000 + }, + "endpoint": { + "type": "object", + "description": "REST endpoint configuration for the embedding service.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Whether the /embed REST endpoint is enabled. Defaults to false.", + "default": false + }, + "path": { + "type": "string", + "description": "The URL path for the embedding endpoint. Defaults to '/embed'.", + "default": "/embed" + }, + "roles": { + "type": "array", + "description": "The roles allowed to access the embedding endpoint. Defaults to ['authenticated'].", + "default": ["authenticated"], + "items": { + "type": "string" + } + } + } + }, + "health": { + "type": "object", + "description": "Health check configuration for the embedding service.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Whether health checks are enabled for embeddings. Defaults to false.", + "default": false + }, + "threshold-ms": { + "type": "integer", + "description": "The maximum response time in milliseconds to be considered healthy.", + "default": 1000, + "minimum": 1, + "maximum": 300000 + }, + "test-text": { + "type": "string", + "description": "The text to use for health check validation.", + "default": "health check" + }, + "expected-dimensions": { + "type": "integer", + "description": "The expected number of dimensions in the embedding result. If specified, dimension validation is performed.", + "minimum": 1 + } + } + }, + "cache": { + "type": "object", + "description": "Cache configuration for embedding results.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Whether caching is enabled for embeddings. Defaults to true.", + "default": true + }, + "level": { + "type": "string", + "description": "Cache level (L1 for in-memory only, L1L2 for in-memory + distributed). Defaults to L1.", + "enum": ["L1", "L1L2"], + "default": "L1" + }, + "ttl-seconds": { + "type": "integer", + "description": "Time-to-live for cached embeddings in seconds. Defaults to 86400 (24 hours).", + "default": 86400, + "minimum": 1 + } + } + }, + "chunking": { + "type": "object", + "description": "Chunking configuration for text processing before embedding. Used to split large text inputs into smaller chunks.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Whether chunking is enabled. Defaults to true.", + "default": true + }, + "size-chars": { + "type": "integer", + "description": "The size of each chunk in characters.", + "default": 800, + "minimum": 1 + }, + "overlap-chars": { + "type": "integer", + "description": "The number of characters to overlap between consecutive chunks. Overlap helps maintain context across chunk boundaries.", + "default": 100, + "minimum": 0 + } + } + } + }, + "required": ["provider", "base-url", "api-key"], + "allOf": [ + { + "$comment": "Azure OpenAI requires the model (deployment name) to be specified.", + "if": { + "properties": { + "provider": { + "const": "azure-openai" + } + }, + "required": ["provider"] + }, + "then": { + "required": ["model"], + "properties": { + "api-version": { + "type": "string", + "description": "Azure API version. Required for Azure OpenAI provider.", + "default": "2023-05-15" + } + } + } + }, + { + "$comment": "OpenAI does not require model (defaults to text-embedding-3-small) and does not use api-version.", + "if": { + "properties": { + "provider": { + "const": "openai" + } + }, + "required": ["provider"] + }, + "then": { + "properties": { + "api-version": false + } + } + } + ] } } }, diff --git a/src/Cli.Tests/ConfigureOptionsTests.cs b/src/Cli.Tests/ConfigureOptionsTests.cs index 8c5ece5c3b..b885a3e62d 100644 --- a/src/Cli.Tests/ConfigureOptionsTests.cs +++ b/src/Cli.Tests/ConfigureOptionsTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Serilog; namespace Cli.Tests @@ -16,6 +17,12 @@ public class ConfigureOptionsTests : VerifyBase private const string TEST_RUNTIME_CONFIG_FILE = "test-update-runtime-setting.json"; private const string TEST_DATASOURCE_HEALTH_NAME = "My Data Source"; + // Embeddings test constants + private const string TEST_AZURE_OPENAI_BASE_URL = "https://myservice.openai.azure.com"; + private const string TEST_OPENAI_BASE_URL = "https://api.openai.com"; + private const string TEST_EMBEDDINGS_API_KEY = "test-api-key"; + private const string TEST_EMBEDDINGS_MODEL = "text-embedding-ada-002"; + [TestInitialize] public void TestInitialize() { @@ -1438,6 +1445,95 @@ private void SetupFileSystemWithInitialConfig(string jsonConfig) Assert.IsNotNull(config.Runtime); } + /// + /// Helper method to create a RuntimeConfig with embeddings configuration for testing. + /// + private static RuntimeConfig CreateConfigWithEmbeddings( + EmbeddingProviderType provider, + string baseUrl, + string apiKey, + string? model = null, + EmbeddingsEndpointOptions? endpoint = null, + EmbeddingsHealthCheckConfig? health = null) + { + RuntimeConfigLoader.TryParseConfig(INITIAL_CONFIG, out RuntimeConfig? config); + Assert.IsNotNull(config); + + return config with + { + Runtime = config.Runtime! with + { + Embeddings = new EmbeddingsOptions( + Provider: provider, + BaseUrl: baseUrl, + ApiKey: apiKey, + Model: model, + Endpoint: endpoint, + Health: health) + } + }; + } + + /// + /// Helper method to assert common embeddings configuration after an update. + /// + private RuntimeConfig AssertEmbeddingsConfigUpdate(bool isSuccess) + { + Assert.IsTrue(isSuccess); + string updatedConfig = _fileSystem!.File.ReadAllText(TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(updatedConfig, out RuntimeConfig? config)); + Assert.IsNotNull(config.Runtime?.Embeddings); + return config; + } + + /// + /// Helper method to assert embeddings endpoint settings. + /// + private static void AssertEmbeddingsEndpoint( + RuntimeConfig config, + bool expectedEnabled, + string[] expectedRoles) + { + Assert.IsNotNull(config.Runtime?.Embeddings); + Assert.IsNotNull(config.Runtime.Embeddings.Endpoint); + Assert.AreEqual(expectedEnabled, config.Runtime.Embeddings.Endpoint.Enabled); + Assert.IsNotNull(config.Runtime.Embeddings.Endpoint.Roles); + CollectionAssert.AreEqual(expectedRoles, config.Runtime.Embeddings.Endpoint.Roles); + } + + /// + /// Helper method to assert embeddings health settings. + /// + private static void AssertEmbeddingsHealth( + RuntimeConfig config, + bool expectedEnabled, + int expectedThresholdMs, + string expectedTestText, + int expectedDimensions) + { + Assert.IsNotNull(config.Runtime?.Embeddings); + Assert.IsNotNull(config.Runtime.Embeddings.Health); + Assert.AreEqual(expectedEnabled, config.Runtime.Embeddings.Health.Enabled); + Assert.AreEqual(expectedThresholdMs, config.Runtime.Embeddings.Health.ThresholdMs); + Assert.AreEqual(expectedTestText, config.Runtime.Embeddings.Health.TestText); + Assert.AreEqual(expectedDimensions, config.Runtime.Embeddings.Health.ExpectedDimensions); + } + + /// + /// Helper method to assert base embeddings provider settings are preserved. + /// + private static void AssertBaseEmbeddingsSettings( + RuntimeConfig config, + EmbeddingProviderType expectedProvider, + string expectedBaseUrl, + string expectedApiKey) + { + Assert.IsNotNull(config.Runtime?.Embeddings); + Assert.AreEqual(expectedProvider, config.Runtime.Embeddings.Provider); + Assert.AreEqual(expectedBaseUrl, config.Runtime.Embeddings.BaseUrl); + Assert.AreEqual(expectedApiKey, config.Runtime.Embeddings.ApiKey); + } + /// /// A simple ILogger implementation that records all log messages to a list, /// enabling tests to assert on log output without redirecting console streams. @@ -1466,7 +1562,7 @@ public void Log( } } - /// + /// /// Tests adding user-delegated-auth configuration options individually or together. /// Verifies that enabled and database-audience properties can be set independently or combined. /// Also verifies default values for properties not explicitly set. @@ -1591,6 +1687,178 @@ public void TestUpdateUserDelegatedAuthDatabaseAudience() Assert.AreEqual("EntraId", (string?)userDelegatedAuthSection["provider"]); } + /// + /// Tests that running "dab configure" with embeddings endpoint options on a config with existing embeddings + /// results in the endpoint options being added to the embeddings configuration. + /// + [TestMethod] + public void TestAddEmbeddingsEndpointOptions() + { + // Arrange: Create a config with embeddings but no endpoint/health + RuntimeConfig config = CreateConfigWithEmbeddings( + EmbeddingProviderType.AzureOpenAI, + TEST_AZURE_OPENAI_BASE_URL, + TEST_EMBEDDINGS_API_KEY, + model: TEST_EMBEDDINGS_MODEL); + _fileSystem!.AddFile(TEST_RUNTIME_CONFIG_FILE, new MockFileData(config.ToJson())); + + // Act: Configure embeddings endpoint options + ConfigureOptions options = new( + runtimeEmbeddingsEndpointEnabled: CliBool.True, + runtimeEmbeddingsEndpointRoles: new List { "admin", "reader" }, + config: TEST_RUNTIME_CONFIG_FILE); + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert + config = AssertEmbeddingsConfigUpdate(isSuccess); + AssertEmbeddingsEndpoint(config, expectedEnabled: true, expectedRoles: new[] { "admin", "reader" }); + AssertBaseEmbeddingsSettings(config, EmbeddingProviderType.AzureOpenAI, + TEST_AZURE_OPENAI_BASE_URL, TEST_EMBEDDINGS_API_KEY); + } + + /// + /// Tests that running "dab configure" with embeddings health options on a config with existing embeddings + /// results in the health options being added to the embeddings configuration. + /// + [TestMethod] + public void TestAddEmbeddingsHealthOptions() + { + // Arrange: Create a config with embeddings but no health config + RuntimeConfig config = CreateConfigWithEmbeddings( + EmbeddingProviderType.OpenAI, + TEST_OPENAI_BASE_URL, + TEST_EMBEDDINGS_API_KEY); + _fileSystem!.AddFile(TEST_RUNTIME_CONFIG_FILE, new MockFileData(config.ToJson())); + + // Act: Configure embeddings health options + ConfigureOptions options = new( + runtimeEmbeddingsHealthEnabled: CliBool.True, + runtimeEmbeddingsHealthThresholdMs: 3000, + runtimeEmbeddingsHealthTestText: "hello world", + runtimeEmbeddingsHealthExpectedDimensions: 1536, + config: TEST_RUNTIME_CONFIG_FILE); + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert + config = AssertEmbeddingsConfigUpdate(isSuccess); + AssertEmbeddingsHealth(config, expectedEnabled: true, expectedThresholdMs: 3000, + expectedTestText: "hello world", expectedDimensions: 1536); + Assert.IsNotNull(config.Runtime?.Embeddings); + Assert.AreEqual(EmbeddingProviderType.OpenAI, config.Runtime.Embeddings.Provider); + } + + /// + /// Tests that running "dab configure" with both embeddings endpoint and health options + /// on a config with existing embeddings results in both being added. + /// + [TestMethod] + public void TestAddEmbeddingsEndpointAndHealthOptionsTogether() + { + // Arrange + RuntimeConfig config = CreateConfigWithEmbeddings( + EmbeddingProviderType.AzureOpenAI, + TEST_AZURE_OPENAI_BASE_URL, + TEST_EMBEDDINGS_API_KEY, + model: TEST_EMBEDDINGS_MODEL); + _fileSystem!.AddFile(TEST_RUNTIME_CONFIG_FILE, new MockFileData(config.ToJson())); + + // Act: Configure both endpoint and health options at once + ConfigureOptions options = new( + runtimeEmbeddingsEndpointEnabled: CliBool.True, + runtimeEmbeddingsEndpointRoles: new List { "authenticated" }, + runtimeEmbeddingsHealthEnabled: CliBool.True, + runtimeEmbeddingsHealthThresholdMs: 5000, + runtimeEmbeddingsHealthTestText: "test embedding", + runtimeEmbeddingsHealthExpectedDimensions: 768, + config: TEST_RUNTIME_CONFIG_FILE); + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert + config = AssertEmbeddingsConfigUpdate(isSuccess); + AssertEmbeddingsEndpoint(config, expectedEnabled: true, expectedRoles: new[] { "authenticated" }); + AssertEmbeddingsHealth(config, expectedEnabled: true, expectedThresholdMs: 5000, + expectedTestText: "test embedding", expectedDimensions: 768); + } + + /// + /// Tests that updating endpoint roles on a config that already has endpoint and health settings + /// preserves the existing health settings. + /// + [TestMethod] + public void TestUpdateExistingEmbeddingsEndpointRolesPreservesHealth() + { + // Arrange: Create a config with embeddings that already has endpoint and health + RuntimeConfig config = CreateConfigWithEmbeddings( + EmbeddingProviderType.AzureOpenAI, + TEST_AZURE_OPENAI_BASE_URL, + TEST_EMBEDDINGS_API_KEY, + model: TEST_EMBEDDINGS_MODEL, + endpoint: new EmbeddingsEndpointOptions(enabled: true, roles: new[] { "old-role" }), + health: new EmbeddingsHealthCheckConfig(enabled: true, thresholdMs: 2000, + testText: "existing text", expectedDimensions: 512)); + _fileSystem!.AddFile(TEST_RUNTIME_CONFIG_FILE, new MockFileData(config.ToJson())); + + // Act: Update only endpoint roles + ConfigureOptions options = new( + runtimeEmbeddingsEndpointRoles: new List { "new-role" }, + config: TEST_RUNTIME_CONFIG_FILE); + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert: Endpoint roles updated, health preserved + config = AssertEmbeddingsConfigUpdate(isSuccess); + AssertEmbeddingsEndpoint(config, expectedEnabled: true, expectedRoles: new[] { "new-role" }); + AssertEmbeddingsHealth(config, expectedEnabled: true, expectedThresholdMs: 2000, + expectedTestText: "existing text", expectedDimensions: 512); + } + + /// + /// Tests that configuring embeddings health with an invalid (negative) threshold fails. + /// + [TestMethod] + public void TestConfigureEmbeddingsHealthWithInvalidThresholdFails() + { + // Arrange + RuntimeConfig config = CreateConfigWithEmbeddings( + EmbeddingProviderType.OpenAI, + TEST_OPENAI_BASE_URL, + TEST_EMBEDDINGS_API_KEY); + _fileSystem!.AddFile(TEST_RUNTIME_CONFIG_FILE, new MockFileData(config.ToJson())); + + // Act: Configure with invalid threshold + ConfigureOptions options = new( + runtimeEmbeddingsHealthEnabled: CliBool.True, + runtimeEmbeddingsHealthThresholdMs: -1, + config: TEST_RUNTIME_CONFIG_FILE); + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert: Should fail + Assert.IsFalse(isSuccess); + } + + /// + /// Tests that configuring embeddings health with an invalid (negative) expected-dimensions fails. + /// + [TestMethod] + public void TestConfigureEmbeddingsHealthWithInvalidExpectedDimensionsFails() + { + // Arrange + RuntimeConfig config = CreateConfigWithEmbeddings( + EmbeddingProviderType.OpenAI, + TEST_OPENAI_BASE_URL, + TEST_EMBEDDINGS_API_KEY); + _fileSystem!.AddFile(TEST_RUNTIME_CONFIG_FILE, new MockFileData(config.ToJson())); + + // Act: Configure with invalid expected dimensions + ConfigureOptions options = new( + runtimeEmbeddingsHealthEnabled: CliBool.True, + runtimeEmbeddingsHealthExpectedDimensions: 0, + config: TEST_RUNTIME_CONFIG_FILE); + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert: Should fail + Assert.IsFalse(isSuccess); + } + /// /// Tests adding pagination options to a config that doesn't have a pagination section. /// Command: dab configure --runtime.pagination.max-page-size 500 --runtime.pagination.default-page-size 50 --runtime.pagination.next-link-relative true diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index 66c560405b..458194dd92 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -4,6 +4,7 @@ using System.IO.Abstractions; using Azure.DataApiBuilder.Config; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Azure.DataApiBuilder.Product; using Cli.Constants; using CommandLine; @@ -91,6 +92,20 @@ public ConfigureOptions( long? fileSinkFileSizeLimitBytes = null, IEnumerable? runtimeTelemetryLogLevel = null, bool showEffectivePermissions = false, + CliBool? runtimeEmbeddingsEnabled = null, + EmbeddingProviderType? runtimeEmbeddingsProvider = null, + string? runtimeEmbeddingsBaseUrl = null, + string? runtimeEmbeddingsApiKey = null, + string? runtimeEmbeddingsModel = null, + string? runtimeEmbeddingsApiVersion = null, + int? runtimeEmbeddingsDimensions = null, + int? runtimeEmbeddingsTimeoutMs = null, + CliBool? runtimeEmbeddingsEndpointEnabled = null, + IEnumerable? runtimeEmbeddingsEndpointRoles = null, + CliBool? runtimeEmbeddingsHealthEnabled = null, + int? runtimeEmbeddingsHealthThresholdMs = null, + string? runtimeEmbeddingsHealthTestText = null, + int? runtimeEmbeddingsHealthExpectedDimensions = null, string? config = null) : base(config) { @@ -176,6 +191,23 @@ public ConfigureOptions( // Telemetry Log Level RuntimeTelemetryLogLevel = runtimeTelemetryLogLevel; ShowEffectivePermissions = showEffectivePermissions; + // Embeddings + RuntimeEmbeddingsEnabled = runtimeEmbeddingsEnabled; + RuntimeEmbeddingsProvider = runtimeEmbeddingsProvider; + RuntimeEmbeddingsBaseUrl = runtimeEmbeddingsBaseUrl; + RuntimeEmbeddingsApiKey = runtimeEmbeddingsApiKey; + RuntimeEmbeddingsModel = runtimeEmbeddingsModel; + RuntimeEmbeddingsApiVersion = runtimeEmbeddingsApiVersion; + RuntimeEmbeddingsDimensions = runtimeEmbeddingsDimensions; + RuntimeEmbeddingsTimeoutMs = runtimeEmbeddingsTimeoutMs; + // Embeddings Endpoint + RuntimeEmbeddingsEndpointEnabled = runtimeEmbeddingsEndpointEnabled; + RuntimeEmbeddingsEndpointRoles = runtimeEmbeddingsEndpointRoles; + // Embeddings Health + RuntimeEmbeddingsHealthEnabled = runtimeEmbeddingsHealthEnabled; + RuntimeEmbeddingsHealthThresholdMs = runtimeEmbeddingsHealthThresholdMs; + RuntimeEmbeddingsHealthTestText = runtimeEmbeddingsHealthTestText; + RuntimeEmbeddingsHealthExpectedDimensions = runtimeEmbeddingsHealthExpectedDimensions; } [Option("data-source.database-type", Required = false, HelpText = "Database type. Allowed values: MSSQL, PostgreSQL, CosmosDB_NoSQL, MySQL.")] @@ -384,6 +416,47 @@ public ConfigureOptions( [Option("show-effective-permissions", Required = false, HelpText = "Display effective permissions for all entities, including inherited permissions. Entities are listed in alphabetical order.")] public bool ShowEffectivePermissions { get; } + [Option("runtime.embeddings.enabled", Required = false, HelpText = "Enable/disable the embedding service. Default: true")] + public CliBool? RuntimeEmbeddingsEnabled { get; } + + [Option("runtime.embeddings.provider", Required = false, HelpText = "Configure embedding provider type. Allowed values: azure-openai, openai.")] + public EmbeddingProviderType? RuntimeEmbeddingsProvider { get; } + + [Option("runtime.embeddings.base-url", Required = false, HelpText = "Configure the embedding provider base URL.")] + public string? RuntimeEmbeddingsBaseUrl { get; } + + [Option("runtime.embeddings.api-key", Required = false, HelpText = "Configure the embedding API key for authentication.")] + public string? RuntimeEmbeddingsApiKey { get; } + + [Option("runtime.embeddings.model", Required = false, HelpText = "Configure the model/deployment name. Required for Azure OpenAI, defaults to text-embedding-3-small for OpenAI.")] + public string? RuntimeEmbeddingsModel { get; } + + [Option("runtime.embeddings.api-version", Required = false, HelpText = "Configure the Azure API version. Only used for Azure OpenAI provider. Default: 2024-02-01")] + public string? RuntimeEmbeddingsApiVersion { get; } + + [Option("runtime.embeddings.dimensions", Required = false, HelpText = "Configure the output vector dimensions. Optional, uses model default if not specified.")] + public int? RuntimeEmbeddingsDimensions { get; } + + [Option("runtime.embeddings.timeout-ms", Required = false, HelpText = "Configure the request timeout in milliseconds. Default: 30000")] + public int? RuntimeEmbeddingsTimeoutMs { get; } + + [Option("runtime.embeddings.endpoint.enabled", Required = false, HelpText = "Enable/disable the endpoint for embeddings. Default: false")] + public CliBool? RuntimeEmbeddingsEndpointEnabled { get; } + + [Option("runtime.embeddings.endpoint.roles", Required = false, Separator = ',', HelpText = "Configure the roles allowed to access the embedding endpoint. Comma-separated list. In development mode defaults to 'anonymous'.")] + public IEnumerable? RuntimeEmbeddingsEndpointRoles { get; } + + [Option("runtime.embeddings.health.enabled", Required = false, HelpText = "Enable/disable health checks for the embedding service. Default: true")] + public CliBool? RuntimeEmbeddingsHealthEnabled { get; } + + [Option("runtime.embeddings.health.threshold-ms", Required = false, HelpText = "Configure the health check threshold in milliseconds. Default: 5000")] + public int? RuntimeEmbeddingsHealthThresholdMs { get; } + + [Option("runtime.embeddings.health.test-text", Required = false, HelpText = "Configure the test text for health check validation. Default: 'health check'")] + public string? RuntimeEmbeddingsHealthTestText { get; } + + [Option("runtime.embeddings.health.expected-dimensions", Required = false, HelpText = "Configure the expected dimensions for health check validation. Optional.")] + public int? RuntimeEmbeddingsHealthExpectedDimensions { get; } public int Handler(ILogger logger, FileSystemRuntimeConfigLoader loader, IFileSystem fileSystem) { diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index 88591f2e86..1a43812c8c 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -10,6 +10,7 @@ using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.NamingPolicies; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Azure.DataApiBuilder.Core; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Core.Resolvers; @@ -1182,6 +1183,33 @@ options.FileSinkRetainedFileCountLimit is not null || }; } + // Embeddings: Provider, Endpoint, ApiKey, Model, ApiVersion, Dimensions, TimeoutMs, Enabled, Endpoint.Enabled/Roles, Health.* + if (options.RuntimeEmbeddingsProvider is not null || + options.RuntimeEmbeddingsBaseUrl is not null || + options.RuntimeEmbeddingsApiKey is not null || + options.RuntimeEmbeddingsModel is not null || + options.RuntimeEmbeddingsApiVersion is not null || + options.RuntimeEmbeddingsDimensions is not null || + options.RuntimeEmbeddingsTimeoutMs is not null || + options.RuntimeEmbeddingsEnabled is not null || + options.RuntimeEmbeddingsEndpointEnabled is not null || + options.RuntimeEmbeddingsEndpointRoles is not null || + options.RuntimeEmbeddingsHealthEnabled is not null || + options.RuntimeEmbeddingsHealthThresholdMs is not null || + options.RuntimeEmbeddingsHealthTestText is not null || + options.RuntimeEmbeddingsHealthExpectedDimensions is not null) + { + bool status = TryUpdateConfiguredEmbeddingsValues(options, runtimeConfig?.Runtime?.Embeddings, out EmbeddingsOptions? updatedEmbeddingsOptions); + if (status && updatedEmbeddingsOptions is not null) + { + runtimeConfig = runtimeConfig! with { Runtime = runtimeConfig.Runtime! with { Embeddings = updatedEmbeddingsOptions } }; + } + else + { + return false; + } + } + return runtimeConfig != null; } @@ -1850,6 +1878,163 @@ private static bool TryUpdateConfiguredFileOptions( } } + /// + /// Attempts to update the embeddings configuration based on the provided options. + /// Creates a new EmbeddingsOptions object if the configuration is valid. + /// Provider, endpoint, and API key are required when configuring embeddings. + /// + /// The configuration options provided by the user. + /// The existing embeddings options from the runtime configuration. + /// The resulting embeddings options if successful. + /// True if the embeddings options were successfully configured; otherwise, false. + private static bool TryUpdateConfiguredEmbeddingsValues( + ConfigureOptions options, + EmbeddingsOptions? existingEmbeddingsOptions, + out EmbeddingsOptions? updatedEmbeddingsOptions) + { + updatedEmbeddingsOptions = null; + + try + { + // Get values from options or fall back to existing configuration + EmbeddingProviderType? provider = options.RuntimeEmbeddingsProvider ?? existingEmbeddingsOptions?.Provider; + string? baseUrl = options.RuntimeEmbeddingsBaseUrl ?? existingEmbeddingsOptions?.BaseUrl; + string? apiKey = options.RuntimeEmbeddingsApiKey ?? existingEmbeddingsOptions?.ApiKey; + string? model = options.RuntimeEmbeddingsModel ?? existingEmbeddingsOptions?.Model; + string? apiVersion = options.RuntimeEmbeddingsApiVersion ?? existingEmbeddingsOptions?.ApiVersion; + int? dimensions = options.RuntimeEmbeddingsDimensions ?? existingEmbeddingsOptions?.Dimensions; + int? timeoutMs = options.RuntimeEmbeddingsTimeoutMs ?? existingEmbeddingsOptions?.TimeoutMs; + bool? enabled = options.RuntimeEmbeddingsEnabled.HasValue + ? options.RuntimeEmbeddingsEnabled.Value == CliBool.True + : existingEmbeddingsOptions?.Enabled; + + // Validate required fields + if (provider is null) + { + _logger.LogError("Failed to configure embeddings: provider is required. Use --runtime.embeddings.provider to specify the provider (azure-openai or openai)."); + return false; + } + + if (string.IsNullOrEmpty(baseUrl)) + { + _logger.LogError("Failed to configure embeddings: base-url is required. Use --runtime.embeddings.base-url to specify the provider base URL."); + return false; + } + + if (string.IsNullOrEmpty(apiKey)) + { + _logger.LogError("Failed to configure embeddings: api-key is required. Use --runtime.embeddings.api-key to specify the authentication key."); + return false; + } + + // Validate Azure OpenAI requires model/deployment name + if (provider == EmbeddingProviderType.AzureOpenAI && string.IsNullOrEmpty(model)) + { + _logger.LogError("Failed to configure embeddings: model/deployment name is required for Azure OpenAI provider. Use --runtime.embeddings.model to specify the deployment name."); + return false; + } + + // Validate dimensions if provided + if (dimensions is not null && dimensions <= 0) + { + _logger.LogError("Failed to configure embeddings: dimensions must be a positive integer."); + return false; + } + + // Validate timeout if provided + if (timeoutMs is not null && timeoutMs <= 0) + { + _logger.LogError("Failed to configure embeddings: timeout-ms must be a positive integer."); + return false; + } + + // Build EmbeddingsEndpointOptions from CLI flags or existing config + EmbeddingsEndpointOptions? existingEndpoint = existingEmbeddingsOptions?.Endpoint; + EmbeddingsEndpointOptions? endpointOptions = null; + + if (options.RuntimeEmbeddingsEndpointEnabled is not null || + options.RuntimeEmbeddingsEndpointRoles is not null || + existingEndpoint is not null) + { + bool? endpointEnabled = options.RuntimeEmbeddingsEndpointEnabled.HasValue + ? options.RuntimeEmbeddingsEndpointEnabled.Value == CliBool.True + : existingEndpoint?.Enabled; + + string[]? endpointRoles = options.RuntimeEmbeddingsEndpointRoles is not null && options.RuntimeEmbeddingsEndpointRoles.Any() + ? options.RuntimeEmbeddingsEndpointRoles.ToArray() + : existingEndpoint?.Roles; + + endpointOptions = new EmbeddingsEndpointOptions( + enabled: endpointEnabled, + roles: endpointRoles); + + _logger.LogInformation("Updated RuntimeConfig with Runtime.Embeddings.Endpoint configuration."); + } + + // Build EmbeddingsHealthCheckConfig from CLI flags or existing config + EmbeddingsHealthCheckConfig? existingHealth = existingEmbeddingsOptions?.Health; + EmbeddingsHealthCheckConfig? healthOptions = null; + + if (options.RuntimeEmbeddingsHealthEnabled is not null || + options.RuntimeEmbeddingsHealthThresholdMs is not null || + options.RuntimeEmbeddingsHealthTestText is not null || + options.RuntimeEmbeddingsHealthExpectedDimensions is not null || + existingHealth is not null) + { + bool? healthEnabled = options.RuntimeEmbeddingsHealthEnabled.HasValue + ? options.RuntimeEmbeddingsHealthEnabled.Value == CliBool.True + : existingHealth?.Enabled; + + int? healthThresholdMs = options.RuntimeEmbeddingsHealthThresholdMs ?? existingHealth?.ThresholdMs; + string? healthTestText = options.RuntimeEmbeddingsHealthTestText ?? existingHealth?.TestText; + int? healthExpectedDimensions = options.RuntimeEmbeddingsHealthExpectedDimensions ?? existingHealth?.ExpectedDimensions; + + // Validate threshold if provided + if (healthThresholdMs is not null && healthThresholdMs <= 0) + { + _logger.LogError("Failed to configure embeddings health: threshold-ms must be a positive integer."); + return false; + } + + // Validate expected dimensions if provided + if (healthExpectedDimensions is not null && healthExpectedDimensions <= 0) + { + _logger.LogError("Failed to configure embeddings health: expected-dimensions must be a positive integer."); + return false; + } + + healthOptions = new EmbeddingsHealthCheckConfig( + enabled: healthEnabled, + thresholdMs: healthThresholdMs, + testText: healthTestText, + expectedDimensions: healthExpectedDimensions); + + _logger.LogInformation("Updated RuntimeConfig with Runtime.Embeddings.Health configuration."); + } + + // Create the embeddings options + updatedEmbeddingsOptions = new EmbeddingsOptions( + Provider: (EmbeddingProviderType)provider, + BaseUrl: baseUrl, + ApiKey: apiKey, + Enabled: enabled, + Model: model, + ApiVersion: apiVersion, + Dimensions: dimensions, + TimeoutMs: timeoutMs, + Endpoint: endpointOptions, + Health: healthOptions); + + _logger.LogInformation("Updated RuntimeConfig with Runtime.Embeddings configuration."); + return true; + } + catch (Exception ex) + { + _logger.LogError("Failed to update RuntimeConfig.Embeddings with exception message: {exceptionMessage}.", ex.Message); + return false; + } + } + /// /// Parse permission string to create PermissionSetting array. /// diff --git a/src/Config/Converters/EmbeddingsOptionsConverterFactory.cs b/src/Config/Converters/EmbeddingsOptionsConverterFactory.cs new file mode 100644 index 0000000000..3d48b7325a --- /dev/null +++ b/src/Config/Converters/EmbeddingsOptionsConverterFactory.cs @@ -0,0 +1,361 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json; +using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +namespace Azure.DataApiBuilder.Config.Converters; + +/// +/// Custom JSON converter for EmbeddingsOptions that handles proper deserialization +/// of the configuration properties including environment variable replacement. +/// +internal class EmbeddingsOptionsConverterFactory : JsonConverterFactory +{ + public EmbeddingsOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) + { + // Note: replacementSettings is not used in this converter because the environment variable + // replacement is handled by the string deserializers registered in the JsonSerializerOptions. + } + + /// + public override bool CanConvert(Type typeToConvert) + { + return typeToConvert.IsAssignableTo(typeof(EmbeddingsOptions)); + } + + /// + public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) + { + return new EmbeddingsOptionsConverter(); + } + + private class EmbeddingsOptionsConverter : JsonConverter + { + public override EmbeddingsOptions? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException("Expected start of object."); + } + + bool? enabled = null; + EmbeddingProviderType? provider = null; + string? baseUrl = null; + string? apiKey = null; + string? model = null; + string? apiVersion = null; + int? dimensions = null; + int? timeoutMs = null; + EmbeddingsEndpointOptions? endpoint = null; + EmbeddingsHealthCheckConfig? health = null; + EmbeddingsChunkingOptions? chunking = null; + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + throw new JsonException("Expected property name."); + } + + string? propertyName = reader.GetString()?.ToLowerInvariant(); + reader.Read(); + + switch (propertyName) + { + case "enabled": + enabled = JsonSerializer.Deserialize(ref reader, options); + break; + case "provider": + string? providerStr = reader.GetString(); + if (providerStr is not null) + { + provider = providerStr.ToLowerInvariant() switch + { + "azure-openai" => EmbeddingProviderType.AzureOpenAI, + "openai" => EmbeddingProviderType.OpenAI, + _ => throw new JsonException($"Unknown provider: {providerStr}") + }; + } + break; + case "base-url": + baseUrl = JsonSerializer.Deserialize(ref reader, options); + break; + case "api-key": + apiKey = JsonSerializer.Deserialize(ref reader, options); + break; + case "model": + model = JsonSerializer.Deserialize(ref reader, options); + break; + case "api-version": + apiVersion = JsonSerializer.Deserialize(ref reader, options); + break; + case "dimensions": + dimensions = JsonSerializer.Deserialize(ref reader, options); + break; + case "timeout-ms": + timeoutMs = JsonSerializer.Deserialize(ref reader, options); + break; + case "endpoint": + endpoint = ReadEndpointOptions(ref reader, options); + break; + case "health": + health = ReadHealthCheckConfig(ref reader, options); + break; + case "chunking": + chunking = ReadChunkingOptions(ref reader, options); + break; + default: + reader.Skip(); + break; + } + } + + if (provider is null) + { + throw new JsonException("Missing required property: provider"); + } + + if (baseUrl is null) + { + throw new JsonException("Missing required property: base-url"); + } + + if (apiKey is null) + { + throw new JsonException("Missing required property: api-key"); + } + + return new EmbeddingsOptions( + Provider: provider.Value, + BaseUrl: baseUrl, + ApiKey: apiKey, + Enabled: enabled, + Model: model, + ApiVersion: apiVersion, + Dimensions: dimensions, + TimeoutMs: timeoutMs, + Endpoint: endpoint, + Health: health, + Chunking: chunking); + } + + /// + /// Manually deserializes EmbeddingsEndpointOptions to handle the type mismatch + /// between nullable constructor parameters and non-nullable properties. + /// Follows the same pattern as FileSinkConverter. + /// + private static EmbeddingsEndpointOptions ReadEndpointOptions(ref Utf8JsonReader reader, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException("Expected start of object for endpoint."); + } + + bool? enabled = null; + string[]? roles = null; + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + return new EmbeddingsEndpointOptions(enabled: enabled, roles: roles); + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + throw new JsonException("Expected property name in endpoint."); + } + + string? propName = reader.GetString()?.ToLowerInvariant(); + reader.Read(); + + switch (propName) + { + case "enabled": + enabled = JsonSerializer.Deserialize(ref reader, options); + break; + case "roles": + roles = JsonSerializer.Deserialize(ref reader, options); + break; + default: + reader.Skip(); + break; + } + } + + throw new JsonException("Failed to read the EmbeddingsEndpointOptions."); + } + + /// + /// Manually deserializes EmbeddingsChunkingOptions to handle the type mismatch + /// between nullable constructor parameters and non-nullable properties. + /// Follows the same pattern as FileSinkConverter. + /// + private static EmbeddingsChunkingOptions ReadChunkingOptions(ref Utf8JsonReader reader, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException("Expected start of object for chunking."); + } + + bool? enabled = null; + int? sizeChars = null; + int? overlapChars = null; + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + return new EmbeddingsChunkingOptions(Enabled: enabled, SizeChars: sizeChars, OverlapChars: overlapChars); + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + throw new JsonException("Expected property name in chunking."); + } + + string? propName = reader.GetString()?.ToLowerInvariant(); + reader.Read(); + + switch (propName) + { + case "enabled": + enabled = reader.TokenType == JsonTokenType.Null ? null : reader.GetBoolean(); + break; + case "size-chars": + sizeChars = reader.TokenType == JsonTokenType.Null ? null : reader.GetInt32(); + break; + case "overlap-chars": + overlapChars = reader.TokenType == JsonTokenType.Null ? null : reader.GetInt32(); + break; + default: + reader.Skip(); + break; + } + } + + throw new JsonException("Failed to read the EmbeddingsChunkingOptions."); + } + + /// + /// Manually deserializes EmbeddingsHealthCheckConfig to handle the type mismatch + /// between nullable constructor parameters and non-nullable properties. + /// Follows the same pattern as FileSinkConverter. + /// + private static EmbeddingsHealthCheckConfig ReadHealthCheckConfig(ref Utf8JsonReader reader, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException("Expected start of object for health."); + } + + bool? enabled = null; + int? thresholdMs = null; + string? testText = null; + int? expectedDimensions = null; + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + return new EmbeddingsHealthCheckConfig(enabled: enabled, thresholdMs: thresholdMs, testText: testText, expectedDimensions: expectedDimensions); + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + throw new JsonException("Expected property name in health."); + } + + string? propName = reader.GetString()?.ToLowerInvariant(); + reader.Read(); + + switch (propName) + { + case "enabled": + enabled = JsonSerializer.Deserialize(ref reader, options); + break; + case "threshold-ms": + thresholdMs = JsonSerializer.Deserialize(ref reader, options); + break; + case "test-text": + testText = JsonSerializer.Deserialize(ref reader, options); + break; + case "expected-dimensions": + expectedDimensions = JsonSerializer.Deserialize(ref reader, options); + break; + default: + reader.Skip(); + break; + } + } + + throw new JsonException("Failed to read the EmbeddingsHealthCheckConfig."); + } + + public override void Write(Utf8JsonWriter writer, EmbeddingsOptions value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + + writer.WriteBoolean("enabled", value.Enabled); + + // Write provider + string providerStr = value.Provider switch + { + EmbeddingProviderType.AzureOpenAI => "azure-openai", + EmbeddingProviderType.OpenAI => "openai", + _ => throw new JsonException($"Unknown provider: {value.Provider}") + }; + writer.WriteString("provider", providerStr); + + writer.WriteString("base-url", value.BaseUrl); + writer.WriteString("api-key", value.ApiKey); + + if (value.Model is not null) + { + writer.WriteString("model", value.Model); + } + + if (value.ApiVersion is not null) + { + writer.WriteString("api-version", value.ApiVersion); + } + + if (value.Dimensions is not null) + { + writer.WriteNumber("dimensions", value.Dimensions.Value); + } + + if (value.TimeoutMs is not null) + { + writer.WriteNumber("timeout-ms", value.TimeoutMs.Value); + } + + if (value.Endpoint is not null) + { + writer.WritePropertyName("endpoint"); + JsonSerializer.Serialize(writer, value.Endpoint, options); + } + + if (value.Health is not null) + { + writer.WritePropertyName("health"); + JsonSerializer.Serialize(writer, value.Health, options); + } + + if (value.Chunking is not null) + { + writer.WritePropertyName("chunking"); + JsonSerializer.Serialize(writer, value.Chunking, options); + } + + writer.WriteEndObject(); + } + } +} diff --git a/src/Config/HealthCheck/HealthCheckConstants.cs b/src/Config/HealthCheck/HealthCheckConstants.cs index fd5901575c..b57526fb75 100644 --- a/src/Config/HealthCheck/HealthCheckConstants.cs +++ b/src/Config/HealthCheck/HealthCheckConstants.cs @@ -12,6 +12,7 @@ public static class HealthCheckConstants public const string DATASOURCE = "data-source"; public const string REST = "rest"; public const string GRAPHQL = "graphql"; + public const string EMBEDDING = "embedding"; public const int ERROR_RESPONSE_TIME_MS = -1; public const int DEFAULT_THRESHOLD_RESPONSE_TIME_MS = 1000; public const int DEFAULT_FIRST_VALUE = 100; diff --git a/src/Config/ObjectModel/Embeddings/EmbeddingProviderType.cs b/src/Config/ObjectModel/Embeddings/EmbeddingProviderType.cs new file mode 100644 index 0000000000..9b2efc994b --- /dev/null +++ b/src/Config/ObjectModel/Embeddings/EmbeddingProviderType.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Runtime.Serialization; +using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config.Converters; + +namespace Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +/// +/// Represents the supported embedding provider types. +/// +[JsonConverter(typeof(EnumMemberJsonEnumConverterFactory))] +public enum EmbeddingProviderType +{ + /// + /// Azure OpenAI embedding provider. + /// + [EnumMember(Value = "azure-openai")] + AzureOpenAI, + + /// + /// OpenAI embedding provider. + /// + [EnumMember(Value = "openai")] + OpenAI +} diff --git a/src/Config/ObjectModel/Embeddings/EmbeddingsChunkingOptions.cs b/src/Config/ObjectModel/Embeddings/EmbeddingsChunkingOptions.cs new file mode 100644 index 0000000000..3bf406007f --- /dev/null +++ b/src/Config/ObjectModel/Embeddings/EmbeddingsChunkingOptions.cs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +/// +/// Represents the chunking options for text processing before embedding. +/// Used to split large text inputs into smaller chunks for embedding. +/// +public record EmbeddingsChunkingOptions +{ + /// + /// Default chunk size in characters. + /// + public const int DEFAULT_SIZE_CHARS = 800; + + /// + /// Default overlap size in characters between consecutive chunks. + /// + public const int DEFAULT_OVERLAP_CHARS = 100; + + /// + /// Whether chunking is enabled. Defaults to true. + /// When enabled, text inputs will be split into smaller chunks before embedding. + /// + [JsonPropertyName("enabled")] + public bool Enabled { get; init; } = true; + + /// + /// The size of each chunk in characters. + /// Defaults to 800 characters. + /// + [JsonPropertyName("size-chars")] + public int SizeChars { get; init; } + + /// + /// The number of characters to overlap between consecutive chunks. + /// Defaults to 100 characters. + /// Overlap helps maintain context across chunk boundaries. + /// + [JsonPropertyName("overlap-chars")] + public int OverlapChars { get; init; } + + [JsonConstructor] + public EmbeddingsChunkingOptions( + bool? Enabled = null, + int? SizeChars = null, + int? OverlapChars = null) + { + this.Enabled = Enabled ?? true; + this.SizeChars = SizeChars ?? DEFAULT_SIZE_CHARS; + this.OverlapChars = Math.Max(0, OverlapChars ?? DEFAULT_OVERLAP_CHARS); + } + + /// + /// Gets the effective chunk size, ensuring it's at least as large as the overlap. + /// + [JsonIgnore] + public int EffectiveSizeChars => Math.Max(SizeChars, OverlapChars + 1); +} diff --git a/src/Config/ObjectModel/Embeddings/EmbeddingsEndpointOptions.cs b/src/Config/ObjectModel/Embeddings/EmbeddingsEndpointOptions.cs new file mode 100644 index 0000000000..e620cdb619 --- /dev/null +++ b/src/Config/ObjectModel/Embeddings/EmbeddingsEndpointOptions.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +/// +/// Endpoint configuration for the embedding service. +/// +public record EmbeddingsEndpointOptions +{ + /// + /// Default path for the embedding endpoint. + /// + public const string DEFAULT_PATH = "/embed"; + + /// + /// Default roles for the embedding endpoint. + /// + public static readonly string[] DEFAULT_ROLES = new[] { "authenticated" }; + + /// + /// Anonymous role constant. + /// + public const string ANONYMOUS_ROLE = "anonymous"; + + /// + /// Whether the endpoint is enabled. Defaults to false. + /// + [JsonPropertyName("enabled")] + public bool Enabled { get; init; } + + /// + /// Flag indicating whether the user provided the enabled setting. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedEnabled { get; init; } + + /// + /// The roles allowed to access the embedding endpoint. + /// When null, GetEffectiveRoles returns ["authenticated"] by default. + /// In production mode, must be explicitly configured (cannot be null). + /// + [JsonPropertyName("roles")] + public string[]? Roles { get; init; } + + /// + /// Gets the effective roles. + /// Returns configured roles if specified, otherwise defaults to ["authenticated"]. + /// + /// Whether the host is in development mode (kept for API compatibility). + /// Array of allowed roles. + public string[] GetEffectiveRoles(bool isDevelopmentMode) + { + if (Roles is not null && Roles.Length > 0) + { + return Roles; + } + + return DEFAULT_ROLES; + } + + /// + /// Checks if the given role is allowed to access the embedding endpoint. + /// + /// The role to check. + /// Whether the host is in development mode. + /// True if the role is allowed; otherwise, false. + public bool IsRoleAllowed(string role, bool isDevelopmentMode) + { + string[] effectiveRoles = GetEffectiveRoles(isDevelopmentMode); + return effectiveRoles.Contains(role, StringComparer.OrdinalIgnoreCase); + } + + /// + /// Default constructor. + /// + public EmbeddingsEndpointOptions() + { + Enabled = false; + } + + /// + /// Constructor with optional parameters. + /// + [JsonConstructor] + public EmbeddingsEndpointOptions( + bool? enabled = null, + string[]? roles = null) + { + if (enabled.HasValue) + { + Enabled = enabled.Value; + UserProvidedEnabled = true; + } + else + { + Enabled = false; + } + + // Keep roles as-is (null if not provided) so validation can check it + // GetEffectiveRoles() will provide the default when needed + Roles = roles; + } +} diff --git a/src/Config/ObjectModel/Embeddings/EmbeddingsHealthCheckConfig.cs b/src/Config/ObjectModel/Embeddings/EmbeddingsHealthCheckConfig.cs new file mode 100644 index 0000000000..31ef87415b --- /dev/null +++ b/src/Config/ObjectModel/Embeddings/EmbeddingsHealthCheckConfig.cs @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +/// +/// Health check configuration for embeddings. +/// Validates that the embedding service is responding within threshold and returning expected results. +/// +public record EmbeddingsHealthCheckConfig : HealthCheckConfig +{ + /// + /// Default threshold for embedding health check in milliseconds. + /// + public const int DEFAULT_THRESHOLD_MS = 1000; + + /// + /// Default test text used for health check validation. + /// + public const string DEFAULT_TEST_TEXT = "health check"; + + /// + /// The expected milliseconds the embedding request should complete within to be considered healthy. + /// If the request takes longer than this value, the health check will be considered unhealthy. + /// Requests completing at exactly the threshold are considered healthy. + /// Default: 1000ms (1 second) + /// + [JsonPropertyName("threshold-ms")] + public int ThresholdMs { get; init; } + + /// + /// Flag indicating whether the user provided a custom threshold. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedThresholdMs { get; init; } + + /// + /// The test text to use for health check validation. + /// This text will be embedded and the result validated. + /// Default: "health check" + /// + [JsonPropertyName("test-text")] + public string TestText { get; init; } + + /// + /// Flag indicating whether the user provided custom test text. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedTestText { get; init; } + + /// + /// The expected number of dimensions in the embedding result. + /// If specified, the health check will verify the embedding has this many dimensions. + /// If not specified, dimension validation is skipped. + /// + [JsonPropertyName("expected-dimensions")] + public int? ExpectedDimensions { get; init; } + + /// + /// Flag indicating whether the user provided expected dimensions. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedExpectedDimensions { get; init; } + + /// + /// Default constructor with default values. + /// + public EmbeddingsHealthCheckConfig() : base() + { + ThresholdMs = DEFAULT_THRESHOLD_MS; + TestText = DEFAULT_TEST_TEXT; + } + + /// + /// Constructor with optional parameters. + /// + [JsonConstructor] + public EmbeddingsHealthCheckConfig( + bool? enabled = null, + int? thresholdMs = null, + string? testText = null, + int? expectedDimensions = null) : base(enabled) + { + if (thresholdMs is not null) + { + ThresholdMs = (int)thresholdMs; + UserProvidedThresholdMs = true; + } + else + { + ThresholdMs = DEFAULT_THRESHOLD_MS; + } + + if (testText is not null) + { + TestText = testText; + UserProvidedTestText = true; + } + else + { + TestText = DEFAULT_TEST_TEXT; + } + + if (expectedDimensions is not null) + { + ExpectedDimensions = expectedDimensions; + UserProvidedExpectedDimensions = true; + } + } +} diff --git a/src/Config/ObjectModel/Embeddings/EmbeddingsOptions.cs b/src/Config/ObjectModel/Embeddings/EmbeddingsOptions.cs new file mode 100644 index 0000000000..3a44422dbc --- /dev/null +++ b/src/Config/ObjectModel/Embeddings/EmbeddingsOptions.cs @@ -0,0 +1,238 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +/// +/// Represents the options for configuring the embedding service. +/// Used for text embedding/vectorization with OpenAI or Azure OpenAI providers. +/// +public record EmbeddingsOptions +{ + /// + /// Default timeout in milliseconds for embedding requests. + /// + public const int DEFAULT_TIMEOUT_MS = 30000; + + /// + /// Default dimensions for embedding vectors. + /// + public const int DEFAULT_DIMENSIONS = 1536; + + /// + /// Default API version for Azure OpenAI. + /// + public const string DEFAULT_AZURE_API_VERSION = "2023-05-15"; + + /// + /// Default model for OpenAI embeddings. + /// + public const string DEFAULT_OPENAI_MODEL = "text-embedding-3-small"; + + /// + /// Whether the embedding service is enabled. Defaults to true. + /// When false, the embedding service will not be used. + /// + [JsonPropertyName("enabled")] + public bool Enabled { get; init; } = true; + + /// + /// Flag indicating whether the user provided the enabled setting. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedEnabled { get; init; } + + /// + /// The embedding provider type (azure-openai or openai). + /// Required. + /// + [JsonPropertyName("provider")] + public EmbeddingProviderType Provider { get; init; } + + /// + /// The provider base URL. + /// Required. + /// + [JsonPropertyName("base-url")] + public string BaseUrl { get; init; } + + /// + /// The API key for authentication. + /// Required. + /// + [JsonPropertyName("api-key")] + public string ApiKey { get; init; } + + /// + /// The model or deployment name. + /// For Azure OpenAI, this is the deployment name. + /// For OpenAI, this is the model name (defaults to text-embedding-3-small if not specified). + /// + [JsonPropertyName("model")] + public string? Model { get; init; } + + /// + /// Azure API version. Only used for Azure OpenAI provider. + /// Defaults to 2024-02-01. + /// + [JsonPropertyName("api-version")] + public string? ApiVersion { get; init; } + + /// + /// Output vector dimensions. Optional, uses model default if not specified. + /// + [JsonPropertyName("dimensions")] + public int? Dimensions { get; init; } + + /// + /// Request timeout in milliseconds. Defaults to 30000 (30 seconds). + /// + [JsonPropertyName("timeout-ms")] + public int? TimeoutMs { get; init; } + + /// + /// Endpoint configuration for the embedding service. + /// + [JsonPropertyName("endpoint")] + public EmbeddingsEndpointOptions? Endpoint { get; init; } + + /// + /// Health check configuration for the embedding service. + /// + [JsonPropertyName("health")] + public EmbeddingsHealthCheckConfig? Health { get; init; } + + /// + /// Chunking configuration for text processing before embedding. + /// + [JsonPropertyName("chunking")] + public EmbeddingsChunkingOptions? Chunking { get; init; } + + /// + /// Flag which informs whether the user provided a custom timeout value. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(TimeoutMs))] + public bool UserProvidedTimeoutMs { get; init; } + + /// + /// Flag which informs whether the user provided a custom API version. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(ApiVersion))] + public bool UserProvidedApiVersion { get; init; } + + /// + /// Flag which informs whether the user provided custom dimensions. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(Dimensions))] + public bool UserProvidedDimensions { get; init; } + + /// + /// Flag which informs whether the user provided a custom model. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(Model))] + public bool UserProvidedModel { get; init; } + + /// + /// Gets the effective timeout in milliseconds, using default if not specified. + /// + [JsonIgnore] + public int EffectiveTimeoutMs => TimeoutMs ?? DEFAULT_TIMEOUT_MS; + + /// + /// Gets the effective API version for Azure OpenAI, using default if not specified. + /// + [JsonIgnore] + public string EffectiveApiVersion => ApiVersion ?? DEFAULT_AZURE_API_VERSION; + + /// + /// Gets the effective model name, using default for OpenAI if not specified. + /// For Azure OpenAI, model is required (no default). + /// + [JsonIgnore] + public string? EffectiveModel => Model ?? (Provider == EmbeddingProviderType.OpenAI ? DEFAULT_OPENAI_MODEL : null); + + /// + /// Returns true if embedding health check is enabled. + /// + [JsonIgnore] + public bool IsHealthCheckEnabled => Health?.Enabled ?? false; + + /// + /// Returns true if embedding endpoint is enabled. + /// + [JsonIgnore] + public bool IsEndpointEnabled => Endpoint?.Enabled ?? false; + + /// + /// Returns true if chunking is enabled. + /// + [JsonIgnore] + public bool IsChunkingEnabled => Chunking?.Enabled ?? false; + + [JsonConstructor] + public EmbeddingsOptions( + EmbeddingProviderType Provider, + string BaseUrl, + string ApiKey, + bool? Enabled = null, + string? Model = null, + string? ApiVersion = null, + int? Dimensions = null, + int? TimeoutMs = null, + EmbeddingsEndpointOptions? Endpoint = null, + EmbeddingsHealthCheckConfig? Health = null, + EmbeddingsChunkingOptions? Chunking = null) + { + this.Provider = Provider; + this.BaseUrl = BaseUrl; + this.ApiKey = ApiKey; + this.Endpoint = Endpoint; + this.Health = Health; + this.Chunking = Chunking; + + if (Enabled.HasValue) + { + this.Enabled = Enabled.Value; + UserProvidedEnabled = true; + } + else + { + this.Enabled = true; // Default to enabled + } + + if (Model is not null) + { + this.Model = Model; + UserProvidedModel = true; + } + + if (ApiVersion is not null) + { + this.ApiVersion = ApiVersion; + UserProvidedApiVersion = true; + } + + if (Dimensions.HasValue) + { + this.Dimensions = Dimensions.Value; + UserProvidedDimensions = true; + } + else + { + this.Dimensions = DEFAULT_DIMENSIONS; + } + + if (TimeoutMs is not null) + { + this.TimeoutMs = TimeoutMs; + UserProvidedTimeoutMs = true; + } + } +} diff --git a/src/Config/ObjectModel/RuntimeOptions.cs b/src/Config/ObjectModel/RuntimeOptions.cs index 525ea8d089..1d5ad86db0 100644 --- a/src/Config/ObjectModel/RuntimeOptions.cs +++ b/src/Config/ObjectModel/RuntimeOptions.cs @@ -3,6 +3,7 @@ using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; namespace Azure.DataApiBuilder.Config.ObjectModel; @@ -17,6 +18,7 @@ public record RuntimeOptions public RuntimeCacheOptions? Cache { get; init; } public PaginationOptions? Pagination { get; init; } public RuntimeHealthCheckConfig? Health { get; init; } + public EmbeddingsOptions? Embeddings { get; init; } public CompressionOptions? Compression { get; init; } [JsonConstructor] @@ -30,6 +32,7 @@ public RuntimeOptions( RuntimeCacheOptions? Cache = null, PaginationOptions? Pagination = null, RuntimeHealthCheckConfig? Health = null, + EmbeddingsOptions? Embeddings = null, CompressionOptions? Compression = null) { this.Rest = Rest; @@ -41,6 +44,7 @@ public RuntimeOptions( this.Cache = Cache; this.Pagination = Pagination; this.Health = Health; + this.Embeddings = Embeddings; this.Compression = Compression; } @@ -77,4 +81,12 @@ Mcp is null || Health is null || Health?.Enabled is null || Health?.Enabled is true; + + /// + /// Indicates whether embeddings are configured. + /// Embeddings are considered configured when the Embeddings property is not null. + /// + [JsonIgnore] + [MemberNotNullWhen(true, nameof(Embeddings))] + public bool IsEmbeddingsConfigured => Embeddings is not null; } diff --git a/src/Config/RuntimeConfigLoader.cs b/src/Config/RuntimeConfigLoader.cs index 83b7b3969e..1d417a0924 100644 --- a/src/Config/RuntimeConfigLoader.cs +++ b/src/Config/RuntimeConfigLoader.cs @@ -342,6 +342,9 @@ public static JsonSerializerOptions GetSerializationOptions( // Add AzureKeyVaultOptionsConverterFactory to ensure AKV config is deserialized properly options.Converters.Add(new AzureKeyVaultOptionsConverterFactory(replacementSettings)); + // Add EmbeddingsOptionsConverterFactory to handle embeddings configuration + options.Converters.Add(new EmbeddingsOptionsConverterFactory(replacementSettings)); + // Only add the extensible string converter if we have replacement settings if (replacementSettings is not null) { diff --git a/src/Core/Configurations/RuntimeConfigValidator.cs b/src/Core/Configurations/RuntimeConfigValidator.cs index c8d86e8e11..2eba39337e 100644 --- a/src/Core/Configurations/RuntimeConfigValidator.cs +++ b/src/Core/Configurations/RuntimeConfigValidator.cs @@ -6,6 +6,7 @@ using System.Text.RegularExpressions; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Azure.DataApiBuilder.Core.AuthenticationHelpers; using Azure.DataApiBuilder.Core.Authorization; using Azure.DataApiBuilder.Core.Models; @@ -96,6 +97,19 @@ public void ValidateConfigProperties() ValidateLoggerFilters(runtimeConfig); ValidateAzureLogAnalyticsAuth(runtimeConfig); ValidateFileSinkPath(runtimeConfig); + ValidateEmbeddingsOptions(runtimeConfig); + + // Running these graphQL validations only in development mode to ensure + // fast startup of engine in production mode. + if (runtimeConfig.IsDevelopmentMode()) + { + ValidateEntityConfiguration(runtimeConfig); + + if (runtimeConfig.IsGraphQLEnabled) + { + ValidateEntitiesDoNotGenerateDuplicateQueriesOrMutation(runtimeConfig.DataSource.DatabaseType, runtimeConfig.Entities); + } + } } /// @@ -296,6 +310,132 @@ public void ValidateFileSinkPath(RuntimeConfig runtimeConfig) } } + /// + /// Validates the embeddings configuration options when embeddings are configured. + /// Checks required fields, URL format, numeric constraints, and endpoint constraints. + /// + public void ValidateEmbeddingsOptions(RuntimeConfig runtimeConfig) + { + // Skip validation if embeddings are not configured. + if (runtimeConfig.Runtime?.Embeddings is null) + { + return; + } + + EmbeddingsOptions embeddingsOptions = runtimeConfig.Runtime.Embeddings; + + // Skip further validation if embeddings are explicitly disabled. + if (!embeddingsOptions.Enabled) + { + return; + } + + // base-url is required and must be a valid URL. + if (string.IsNullOrWhiteSpace(embeddingsOptions.BaseUrl)) + { + HandleOrRecordException(new DataApiBuilderException( + message: "Embeddings 'base-url' cannot be null or empty when embeddings are enabled.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + else if (!Uri.TryCreate(embeddingsOptions.BaseUrl, UriKind.Absolute, out Uri? baseUri) || + (baseUri.Scheme != Uri.UriSchemeHttps && baseUri.Scheme != Uri.UriSchemeHttp)) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"Embeddings 'base-url' must be a valid HTTP or HTTPS URL. Got: {embeddingsOptions.BaseUrl}", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + // api-key is required. + if (string.IsNullOrWhiteSpace(embeddingsOptions.ApiKey)) + { + HandleOrRecordException(new DataApiBuilderException( + message: "Embeddings 'api-key' cannot be null or empty when embeddings are enabled.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + // For Azure OpenAI provider, model (deployment name) is required. + if (embeddingsOptions.Provider == EmbeddingProviderType.AzureOpenAI && string.IsNullOrWhiteSpace(embeddingsOptions.Model)) + { + HandleOrRecordException(new DataApiBuilderException( + message: "Embeddings 'model' (deployment name) is required when using the Azure OpenAI provider.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + // timeout-ms must be positive if provided. + if (embeddingsOptions.TimeoutMs is not null && embeddingsOptions.TimeoutMs <= 0) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"Embeddings 'timeout-ms' must be a positive integer. Got: {embeddingsOptions.TimeoutMs}", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + // dimensions must be positive if provided. + if (embeddingsOptions.Dimensions is not null && embeddingsOptions.Dimensions <= 0) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"Embeddings 'dimensions' must be a positive integer. Got: {embeddingsOptions.Dimensions}", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + // Validate endpoint configuration. + if (embeddingsOptions.Endpoint is not null && embeddingsOptions.Endpoint.Enabled) + { + // In production mode, roles must be explicitly configured (cannot be null). + if (!runtimeConfig.IsDevelopmentMode() && + embeddingsOptions.Endpoint.Roles is null) + { + HandleOrRecordException(new DataApiBuilderException( + message: "Embeddings endpoint 'roles' must be explicitly configured in production mode.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + // Empty roles array is not allowed (checked after production null check) + if (embeddingsOptions.Endpoint.Roles is not null && embeddingsOptions.Endpoint.Roles.Length == 0) + { + HandleOrRecordException(new DataApiBuilderException( + message: "Embeddings endpoint 'roles' cannot be empty when endpoint is enabled.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + } + + // Validate health check configuration. + if (embeddingsOptions.Health is not null && embeddingsOptions.Health.Enabled) + { + if (embeddingsOptions.Health.ThresholdMs <= 0) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"Embeddings health check 'threshold-ms' must be a positive integer. Got: {embeddingsOptions.Health.ThresholdMs}", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + if (string.IsNullOrWhiteSpace(embeddingsOptions.Health.TestText)) + { + HandleOrRecordException(new DataApiBuilderException( + message: "Embeddings health check 'test-text' cannot be null or empty when health check is enabled.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + if (embeddingsOptions.Health.ExpectedDimensions is not null && embeddingsOptions.Health.ExpectedDimensions <= 0) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"Embeddings health check 'expected-dimensions' must be a positive integer. Got: {embeddingsOptions.Health.ExpectedDimensions}", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + } + + } + /// /// This method runs several validations against the config file such as schema validation, /// validation of entities metadata, validation of permissions, validation of entity configuration. diff --git a/src/Core/Services/Embeddings/EmbeddingService.cs b/src/Core/Services/Embeddings/EmbeddingService.cs new file mode 100644 index 0000000000..468c27622f --- /dev/null +++ b/src/Core/Services/Embeddings/EmbeddingService.cs @@ -0,0 +1,594 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Net.Http.Headers; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; +using Microsoft.Extensions.Logging; +using ZiggyCreatures.Caching.Fusion; + +namespace Azure.DataApiBuilder.Core.Services.Embeddings; + +/// +/// Service implementation for text embedding/vectorization. +/// Supports both OpenAI and Azure OpenAI providers. +/// Caches embeddings using FusionCache L1 memory cache. +/// L2/distributed cache is optional globally and is used by this service when configured. +/// +public class EmbeddingService : IEmbeddingService +{ + private readonly HttpClient _httpClient; + private readonly EmbeddingsOptions _options; + private readonly ILogger _logger; + private readonly IFusionCache _cache; + private readonly string _providerName; + + // Constants + private const char KEY_DELIMITER = ':'; + private const string CACHE_KEY_PREFIX = "embedding"; + + /// + /// Maximum number of text chunks accepted in one batch embedding request. + /// This protects the system from accidentally submitting extremely large arrays. + /// + public const int MAX_BATCH_TEXT_COUNT = 2048; + + /// + /// Default cache TTL in hours. Set high since embeddings are deterministic and don't get outdated. + /// + private const int DEFAULT_CACHE_TTL_HOURS = 24; + + /// + /// JSON serializer options for request/response handling. + /// + private static readonly JsonSerializerOptions _jsonSerializerOptions = new() + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }; + + /// + /// Initializes a new instance of the EmbeddingService. + /// + /// The HTTP client for making API requests. + /// The embedding configuration options. + /// The logger instance. + /// The FusionCache instance used for caching embedding vectors. + public EmbeddingService( + HttpClient httpClient, + EmbeddingsOptions options, + ILogger logger, + IFusionCache cache) + { + _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient)); + _options = options ?? throw new ArgumentNullException(nameof(options)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _cache = cache ?? throw new ArgumentNullException(nameof(cache)); + + // Cache provider name for telemetry to avoid repeated string allocations + _providerName = _options.Provider.ToString().ToLowerInvariant(); + + // Validate required options + if (string.IsNullOrEmpty(_options.BaseUrl)) + { + throw new ArgumentException("BaseUrl is required in EmbeddingsOptions.", nameof(options)); + } + + if (string.IsNullOrEmpty(_options.ApiKey)) + { + throw new ArgumentException("ApiKey is required in EmbeddingsOptions.", nameof(options)); + } + + // Azure OpenAI requires model/deployment name + if (_options.Provider == EmbeddingProviderType.AzureOpenAI && string.IsNullOrEmpty(_options.EffectiveModel)) + { + throw new InvalidOperationException("Model/deployment name is required for Azure OpenAI provider."); + } + + ConfigureHttpClient(); + } + + /// + /// Configures the HTTP client with timeout and authentication headers. + /// + private void ConfigureHttpClient() + { + _httpClient.Timeout = TimeSpan.FromMilliseconds(_options.EffectiveTimeoutMs); + + if (_options.Provider == EmbeddingProviderType.AzureOpenAI) + { + _httpClient.DefaultRequestHeaders.Add("api-key", _options.ApiKey); + } + else + { + _httpClient.DefaultRequestHeaders.Authorization = + new AuthenticationHeaderValue("Bearer", _options.ApiKey); + } + + _httpClient.DefaultRequestHeaders.Accept.Clear(); + _httpClient.DefaultRequestHeaders.Accept.Add( + new MediaTypeWithQualityHeaderValue("application/json")); + } + + /// + public bool IsEnabled => _options.Enabled; + + /// + public async Task TryEmbedAsync(string text, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + _logger.LogDebug("Embedding service is disabled, skipping embed request"); + return new EmbeddingResult(false, null, "Embedding service is disabled."); + } + + if (string.IsNullOrEmpty(text)) + { + _logger.LogWarning("TryEmbedAsync called with null or empty text"); + return new EmbeddingResult(false, null, "Text cannot be null or empty."); + } + + Stopwatch stopwatch = Stopwatch.StartNew(); + using Activity? activity = EmbeddingTelemetryHelper.StartEmbeddingActivity("TryEmbedAsync"); + activity?.SetEmbeddingActivityTags(_providerName, _options.EffectiveModel, textCount: 1); + + try + { + EmbeddingTelemetryHelper.TrackEmbeddingRequest(_providerName, textCount: 1); + + (float[] embedding, bool fromCache) = await EmbedWithCacheInfoAsync(text, cancellationToken); + + stopwatch.Stop(); + activity?.SetEmbeddingActivitySuccess(stopwatch.Elapsed.TotalMilliseconds, embedding.Length); + EmbeddingTelemetryHelper.TrackTotalDuration(_providerName, stopwatch.Elapsed, fromCache: fromCache); + EmbeddingTelemetryHelper.TrackDimensions(_providerName, embedding.Length); + + if (fromCache) + { + EmbeddingTelemetryHelper.TrackCacheHit(_providerName); + } + else + { + EmbeddingTelemetryHelper.TrackCacheMiss(_providerName); + } + + return new EmbeddingResult(true, embedding); + } + catch (Exception ex) + { + stopwatch.Stop(); + _logger.LogError(ex, "Failed to generate embedding for text"); + activity?.SetEmbeddingActivityError(ex); + EmbeddingTelemetryHelper.TrackError(_providerName, ex.GetType().Name); + + return new EmbeddingResult(false, null, "Failed to generate embedding."); + } + } + + /// + public async Task TryEmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + _logger.LogDebug("Embedding service is disabled, skipping batch embed request"); + return new EmbeddingBatchResult(false, null, "Embedding service is disabled."); + } + + if (texts is null || texts.Length == 0) + { + _logger.LogWarning("TryEmbedBatchAsync called with null or empty texts array"); + return new EmbeddingBatchResult(false, null, "Texts array cannot be null or empty."); + } + + if (texts.Any(string.IsNullOrEmpty)) + { + _logger.LogWarning("TryEmbedBatchAsync called with one or more null or empty texts"); + return new EmbeddingBatchResult(false, null, "Texts array must not contain null or empty entries."); + } + + if (texts.Length > MAX_BATCH_TEXT_COUNT) + { + _logger.LogWarning( + "TryEmbedBatchAsync called with {Count} texts, which exceeds max supported batch size {MaxBatchSize}", + texts.Length, + MAX_BATCH_TEXT_COUNT); + return new EmbeddingBatchResult( + false, + null, + $"Texts array exceeds max supported batch size of {MAX_BATCH_TEXT_COUNT}."); + } + + Stopwatch stopwatch = Stopwatch.StartNew(); + using Activity? activity = EmbeddingTelemetryHelper.StartEmbeddingActivity("TryEmbedBatchAsync"); + activity?.SetEmbeddingActivityTags(_providerName, _options.EffectiveModel, texts.Length); + + try + { + EmbeddingTelemetryHelper.TrackEmbeddingRequest(_providerName, texts.Length); + + float[][] embeddings = await EmbedBatchAsync(texts, cancellationToken); + + stopwatch.Stop(); + int dimensions = embeddings.Length > 0 ? embeddings[0].Length : 0; + activity?.SetEmbeddingActivitySuccess(stopwatch.Elapsed.TotalMilliseconds, dimensions); + EmbeddingTelemetryHelper.TrackTotalDuration(_providerName, stopwatch.Elapsed, fromCache: false); + if (dimensions > 0) + { + EmbeddingTelemetryHelper.TrackDimensions(_providerName, dimensions); + } + + return new EmbeddingBatchResult(true, embeddings); + } + catch (Exception ex) + { + stopwatch.Stop(); + _logger.LogError(ex, "Failed to generate embeddings for batch of {Count} texts", texts.Length); + activity?.SetEmbeddingActivityError(ex); + EmbeddingTelemetryHelper.TrackError(_providerName, ex.GetType().Name); + + return new EmbeddingBatchResult(false, null, "Failed to generate embeddings."); + } + } + + /// + public async Task EmbedAsync(string text, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + throw new InvalidOperationException("Embedding service is disabled."); + } + + if (string.IsNullOrEmpty(text)) + { + throw new ArgumentException("Text cannot be null or empty.", nameof(text)); + } + + (float[] embedding, _) = await EmbedWithCacheInfoAsync(text, cancellationToken); + return embedding; + } + + /// + public async Task EmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + throw new InvalidOperationException("Embedding service is disabled."); + } + + if (texts is null || texts.Length == 0) + { + throw new ArgumentException("Texts cannot be null or empty.", nameof(texts)); + } + + if (texts.Any(string.IsNullOrEmpty)) + { + throw new ArgumentException("Texts array must not contain null or empty entries.", nameof(texts)); + } + + if (texts.Length > MAX_BATCH_TEXT_COUNT) + { + throw new ArgumentException( + $"Texts array exceeds max supported batch size of {MAX_BATCH_TEXT_COUNT}.", + nameof(texts)); + } + + // For batch, check cache for each text individually + string[] cacheKeys = texts.Select(CreateCacheKey).ToArray(); + float[]?[] results = new float[texts.Length][]; + List uncachedIndices = new(); + int cacheHits = 0; + + // Check cache for each text + for (int i = 0; i < texts.Length; i++) + { + MaybeValue cached = _cache.TryGet(key: cacheKeys[i]); + + if (cached.HasValue) + { + _logger.LogDebug("Embedding cache hit for text hash {TextHash}", cacheKeys[i]); + results[i] = cached.Value; + cacheHits++; + EmbeddingTelemetryHelper.TrackCacheHit(_providerName); + } + else + { + uncachedIndices.Add(i); + EmbeddingTelemetryHelper.TrackCacheMiss(_providerName); + } + } + + // If all texts were cached, return immediately + if (uncachedIndices.Count == 0) + { + _logger.LogDebug("All {Count} texts were cache hits, returning cached embeddings", texts.Length); + return results!; + } + + _logger.LogDebug("Embedding cache miss for {Count} text(s), calling API", uncachedIndices.Count); + + // Call API for uncached texts only + string[] uncachedTexts = uncachedIndices.Select(i => texts[i]).ToArray(); + + Stopwatch apiStopwatch = Stopwatch.StartNew(); + float[][] apiResults = await EmbedFromApiAsync(uncachedTexts, cancellationToken); + apiStopwatch.Stop(); + + // Track API call telemetry + EmbeddingTelemetryHelper.TrackApiCall(_providerName, uncachedTexts.Length); + EmbeddingTelemetryHelper.TrackApiDuration(_providerName, apiStopwatch.Elapsed, uncachedTexts.Length); + + // Cache new results and merge with cached results + for (int i = 0; i < uncachedIndices.Count; i++) + { + int originalIndex = uncachedIndices[i]; + results[originalIndex] = apiResults[i]; + + // Store embeddings using the configured FusionCache stack. + _cache.Set( + key: cacheKeys[originalIndex], + value: apiResults[i], + options => + { + options.SetDuration(TimeSpan.FromHours(DEFAULT_CACHE_TTL_HOURS)); + }); + } + + return results!; + } + + /// + /// Internal helper that embeds text using cache and returns whether the result came from cache. + /// + private async Task<(float[] Embedding, bool FromCache)> EmbedWithCacheInfoAsync(string text, CancellationToken cancellationToken) + { + string cacheKey = CreateCacheKey(text); + bool fromCache = true; + + float[]? embedding = await _cache.GetOrSetAsync( + key: cacheKey, + async (FusionCacheFactoryExecutionContext ctx, CancellationToken ct) => + { + fromCache = false; + _logger.LogDebug("Embedding cache miss, calling API for text hash {TextHash}", cacheKey); + + float[][] results = await EmbedFromApiAsync(new[] { text }, ct); + float[] result = results[0]; + + // Validate the embedding result is not empty + if (result.Length == 0) + { + throw new InvalidOperationException("API returned empty embedding array."); + } + + // Respect configured cache layers (L1 and optional L2). + ctx.Options.SetDuration(TimeSpan.FromHours(DEFAULT_CACHE_TTL_HOURS)); + + return result; + }, + token: cancellationToken); + + if (embedding is null || embedding.Length == 0) + { + throw new InvalidOperationException("Failed to get embedding from cache or API."); + } + + return (embedding, fromCache); + } + + /// + /// Creates a cache key from the text using SHA256 hash. + /// Format: embedding:{provider}:{model}:{SHA256_hash} + /// Includes provider and model to prevent cross-configuration collisions. + /// Uses hash to keep cache keys small and deterministic. + /// + /// The text to create a cache key for. + /// Cache key string. + private string CreateCacheKey(string text) + { + // Include provider and model in hash to avoid cross-provider/model collisions + string keyInput = $"{_options.Provider}:{_options.EffectiveModel}:{text}"; + byte[] textBytes = Encoding.UTF8.GetBytes(keyInput); + byte[] hashBytes = SHA256.HashData(textBytes); + string hashHex = Convert.ToHexString(hashBytes); + string model = _options.EffectiveModel ?? "unknown"; + + StringBuilder cacheKeyBuilder = new(); + cacheKeyBuilder.Append(CACHE_KEY_PREFIX); + cacheKeyBuilder.Append(KEY_DELIMITER); + cacheKeyBuilder.Append(_providerName); + cacheKeyBuilder.Append(KEY_DELIMITER); + cacheKeyBuilder.Append(model); + cacheKeyBuilder.Append(KEY_DELIMITER); + cacheKeyBuilder.Append(hashHex); + + return cacheKeyBuilder.ToString(); + } + + /// + /// Calls the embedding API to get embeddings for the provided texts. + /// + private async Task EmbedFromApiAsync(string[] texts, CancellationToken cancellationToken) + { + string requestUrl = BuildRequestUrl(); + object requestBody = BuildRequestBody(texts); + + string requestJson = JsonSerializer.Serialize(requestBody, _jsonSerializerOptions); + using HttpContent content = new StringContent(requestJson, Encoding.UTF8, "application/json"); + + _logger.LogDebug("Sending embedding request to {Url} with {Count} text(s)", requestUrl, texts.Length); + + using HttpResponseMessage response = await _httpClient.PostAsync(requestUrl, content, cancellationToken); + + if (!response.IsSuccessStatusCode) + { + string errorContent = await response.Content.ReadAsStringAsync(cancellationToken); + _logger.LogError("Embedding request failed with status {StatusCode}: {ErrorContent}", + response.StatusCode, errorContent); + throw new HttpRequestException( + $"Embedding request failed with status code {(int)response.StatusCode}."); + } + + string responseJson = await response.Content.ReadAsStringAsync(cancellationToken); + EmbeddingResponse? embeddingResponse = JsonSerializer.Deserialize(responseJson, _jsonSerializerOptions); + + if (embeddingResponse?.Data is null || embeddingResponse.Data.Count == 0) + { + throw new InvalidOperationException("No embedding data received from the provider."); + } + + List data = embeddingResponse.Data; + int expectedCount = texts.Length; + + // Validate that we received exactly one embedding per input text. + if (data.Count != expectedCount) + { + _logger.LogError( + "Embedding provider returned {ActualCount} embeddings for {ExpectedCount} input text(s).", + data.Count, + expectedCount); + throw new InvalidOperationException( + $"Embedding provider returned {data.Count} embeddings for {expectedCount} input text(s)."); + } + + // Validate indices are within range and unique. + int minIndex = data.Min(d => d.Index); + int maxIndex = data.Max(d => d.Index); + if (minIndex < 0 || maxIndex >= expectedCount) + { + _logger.LogError( + "Embedding provider returned out-of-range indices. MinIndex: {MinIndex}, MaxIndex: {MaxIndex}, ExpectedCount: {ExpectedCount}.", + minIndex, + maxIndex, + expectedCount); + throw new InvalidOperationException( + $"Embedding provider returned out-of-range indices. MinIndex: {minIndex}, MaxIndex: {maxIndex}, ExpectedCount: {expectedCount}."); + } + + int distinctIndexCount = data.Select(d => d.Index).Distinct().Count(); + if (distinctIndexCount != expectedCount) + { + _logger.LogError( + "Embedding provider returned duplicate or missing indices. DistinctIndexCount: {DistinctIndexCount}, ExpectedCount: {ExpectedCount}.", + distinctIndexCount, + expectedCount); + throw new InvalidOperationException( + $"Embedding provider returned duplicate or missing indices. DistinctIndexCount: {distinctIndexCount}, ExpectedCount: {expectedCount}."); + } + + // Sort by index to ensure correct order and extract embeddings + List sortedData = data.OrderBy(d => d.Index).ToList(); + return sortedData.Select(d => d.Embedding).ToArray(); + } + + /// + /// Builds the request URL based on the provider type. + /// + private string BuildRequestUrl() + { + string baseUrl = _options.BaseUrl.TrimEnd('/'); + + if (_options.Provider == EmbeddingProviderType.AzureOpenAI) + { + // Azure OpenAI: {baseUrl}/openai/deployments/{deployment}/embeddings?api-version={version} + string model = _options.EffectiveModel + ?? throw new InvalidOperationException("Model/deployment name is required for Azure OpenAI."); + + string encodedModel = global::System.Uri.EscapeDataString(model); + + return $"{baseUrl}/openai/deployments/{encodedModel}/embeddings?api-version={_options.EffectiveApiVersion}"; + } + else + { + // OpenAI: {baseUrl}/v1/embeddings + return $"{baseUrl}/v1/embeddings"; + } + } + + /// + /// Builds the request body based on the provider type. + /// + private object BuildRequestBody(string[] texts) + { + // Use single string for single text, array for batch + object input = texts.Length == 1 ? texts[0] : texts; + + if (_options.Provider == EmbeddingProviderType.AzureOpenAI) + { + // Azure OpenAI request body + if (_options.UserProvidedDimensions) + { + return new + { + input, + dimensions = _options.Dimensions + }; + } + + return new { input }; + } + else + { + // OpenAI request body - includes model in body + string model = _options.EffectiveModel ?? EmbeddingsOptions.DEFAULT_OPENAI_MODEL; + + if (_options.UserProvidedDimensions) + { + return new + { + model, + input, + dimensions = _options.Dimensions + }; + } + + return new + { + model, + input + }; + } + } + + /// + /// Response model for embedding API responses. + /// + private sealed class EmbeddingResponse + { + [JsonPropertyName("data")] + public List? Data { get; set; } + + [JsonPropertyName("model")] + public string? Model { get; set; } + + [JsonPropertyName("usage")] + public EmbeddingUsage? Usage { get; set; } + } + + /// + /// Individual embedding data in the response. + /// + private sealed class EmbeddingData + { + [JsonPropertyName("index")] + public int Index { get; set; } + + [JsonPropertyName("embedding")] + public float[] Embedding { get; set; } = Array.Empty(); + } + + /// + /// Token usage information in the response. + /// + private sealed class EmbeddingUsage + { + [JsonPropertyName("prompt_tokens")] + public int PromptTokens { get; set; } + + [JsonPropertyName("total_tokens")] + public int TotalTokens { get; set; } + } +} diff --git a/src/Core/Services/Embeddings/EmbeddingTelemetryHelper.cs b/src/Core/Services/Embeddings/EmbeddingTelemetryHelper.cs new file mode 100644 index 0000000000..5c3f425af1 --- /dev/null +++ b/src/Core/Services/Embeddings/EmbeddingTelemetryHelper.cs @@ -0,0 +1,264 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Diagnostics.Metrics; +using Azure.DataApiBuilder.Core.Telemetry; +using OpenTelemetry.Trace; + +namespace Azure.DataApiBuilder.Core.Services.Embeddings; + +/// +/// Helper class for tracking embedding-related telemetry metrics and traces. +/// +public static class EmbeddingTelemetryHelper +{ + /// + /// Meter name for embedding metrics. + /// + public static readonly string MeterName = "DataApiBuilder.Embeddings"; + + // Metrics + private static readonly Meter _meter = new(MeterName); + + // Counters + private static readonly Counter _embeddingRequests = _meter.CreateCounter( + "embedding_requests_total", + description: "Total number of embedding requests"); + + private static readonly Counter _embeddingApiCalls = _meter.CreateCounter( + "embedding_api_calls_total", + description: "Total number of embedding API calls (excludes cache hits)"); + + private static readonly Counter _embeddingCacheHits = _meter.CreateCounter( + "embedding_cache_hits_total", + description: "Total number of embedding cache hits"); + + private static readonly Counter _embeddingCacheMisses = _meter.CreateCounter( + "embedding_cache_misses_total", + description: "Total number of embedding cache misses"); + + private static readonly Counter _embeddingErrors = _meter.CreateCounter( + "embedding_errors_total", + description: "Total number of embedding errors"); + + private static readonly Counter _embeddingTextsProcessed = _meter.CreateCounter( + "embedding_texts_processed_total", + description: "Total number of texts processed for embedding"); + + // Histograms for timing and sizing + private static readonly Histogram _embeddingApiDuration = _meter.CreateHistogram( + "embedding_api_duration_ms", + unit: "ms", + description: "Duration of embedding API calls in milliseconds"); + + private static readonly Histogram _embeddingTotalDuration = _meter.CreateHistogram( + "embedding_total_duration_ms", + unit: "ms", + description: "Total duration of embedding operations including cache lookup"); + + private static readonly Histogram _embeddingTokens = _meter.CreateHistogram( + "embedding_tokens_total", + description: "Total tokens used in embedding requests"); + + private static readonly Histogram _embeddingDimensions = _meter.CreateHistogram( + "embedding_dimensions", + description: "Number of dimensions in embedding vectors"); + + /// + /// Tracks an embedding request (entry point, includes cache hits). + /// + /// The embedding provider (e.g., azure-openai, openai). + /// Number of texts being embedded. + public static void TrackEmbeddingRequest(string provider, int textCount) + { + _embeddingRequests.Add(1, + new KeyValuePair("provider", provider)); + _embeddingTextsProcessed.Add(textCount, + new KeyValuePair("provider", provider)); + } + + /// + /// Tracks an embedding API call (cache miss, actual API call made). + /// + /// The embedding provider. + /// Number of texts sent to API. + public static void TrackApiCall(string provider, int textCount) + { + _embeddingApiCalls.Add(1, + new KeyValuePair("provider", provider), + new KeyValuePair("text_count", textCount)); + } + + /// + /// Tracks an embedding cache hit. + /// + /// The embedding provider. + public static void TrackCacheHit(string provider) + { + _embeddingCacheHits.Add(1, new KeyValuePair("provider", provider)); + } + + /// + /// Tracks an embedding cache miss. + /// + /// The embedding provider. + public static void TrackCacheMiss(string provider) + { + _embeddingCacheMisses.Add(1, new KeyValuePair("provider", provider)); + } + + /// + /// Tracks an embedding error. + /// + /// The embedding provider. + /// The type of error that occurred. + public static void TrackError(string provider, string errorType) + { + _embeddingErrors.Add(1, + new KeyValuePair("provider", provider), + new KeyValuePair("error_type", errorType)); + } + + /// + /// Tracks the duration of an embedding API call. + /// + /// The embedding provider. + /// The duration of the API call. + /// Number of texts embedded. + public static void TrackApiDuration(string provider, TimeSpan duration, int textCount) + { + _embeddingApiDuration.Record(duration.TotalMilliseconds, + new KeyValuePair("provider", provider), + new KeyValuePair("text_count", textCount)); + } + + /// + /// Tracks the total duration of an embedding operation (including cache lookup). + /// + /// The embedding provider. + /// The total duration. + /// Whether result was from cache. + public static void TrackTotalDuration(string provider, TimeSpan duration, bool fromCache) + { + _embeddingTotalDuration.Record(duration.TotalMilliseconds, + new KeyValuePair("provider", provider), + new KeyValuePair("from_cache", fromCache)); + } + + /// + /// Tracks token usage from an embedding request. + /// + /// The embedding provider. + /// Total tokens used. + public static void TrackTokenUsage(string provider, long totalTokens) + { + _embeddingTokens.Record(totalTokens, new KeyValuePair("provider", provider)); + } + + /// + /// Tracks embedding vector dimensions. + /// + /// The embedding provider. + /// Number of dimensions in the vector. + public static void TrackDimensions(string provider, int dimensions) + { + _embeddingDimensions.Record(dimensions, new KeyValuePair("provider", provider)); + } + + /// + /// Starts an activity for embedding operations. + /// + /// Name of the operation (e.g., "EmbedAsync", "EmbedBatchAsync"). + /// The started activity, or null if tracing is not enabled. + public static Activity? StartEmbeddingActivity(string operationName) + { + return TelemetryTracesHelper.DABActivitySource.StartActivity( + name: $"Embedding.{operationName}", + kind: ActivityKind.Client); + } + + /// + /// Sets embedding-specific tags on an activity. + /// + /// The activity to tag. + /// The embedding provider. + /// The model being used. + /// Number of texts being embedded. + public static void SetEmbeddingActivityTags( + this Activity activity, + string provider, + string? model, + int textCount) + { + if (activity.IsAllDataRequested) + { + activity.SetTag("embedding.provider", provider); + if (!string.IsNullOrEmpty(model)) + { + activity.SetTag("embedding.model", model); + } + + activity.SetTag("embedding.text_count", textCount); + } + } + + /// + /// Records cache status on an activity. + /// + /// The activity to tag. + /// Number of cache hits. + /// Number of cache misses. + public static void SetCacheActivityTags( + this Activity activity, + int cacheHits, + int cacheMisses) + { + if (activity.IsAllDataRequested) + { + activity.SetTag("embedding.cache_hits", cacheHits); + activity.SetTag("embedding.cache_misses", cacheMisses); + } + } + + /// + /// Records successful completion of an embedding activity. + /// + /// The activity to complete. + /// Duration in milliseconds. + /// Number of dimensions in the result. + public static void SetEmbeddingActivitySuccess( + this Activity activity, + double durationMs, + int? dimensions = null) + { + if (activity.IsAllDataRequested) + { + activity.SetTag("embedding.duration_ms", durationMs); + if (dimensions.HasValue) + { + activity.SetTag("embedding.dimensions", dimensions.Value); + } + + activity.SetStatus(ActivityStatusCode.Ok); + } + } + + /// + /// Records an error on an embedding activity. + /// + /// The activity to record error on. + /// The exception that occurred. + public static void SetEmbeddingActivityError( + this Activity activity, + Exception ex) + { + if (activity.IsAllDataRequested) + { + activity.SetStatus(ActivityStatusCode.Error, ex.Message); + activity.AddException(ex); + activity.SetTag("error.type", ex.GetType().Name); + activity.SetTag("error.message", ex.Message); + } + } +} diff --git a/src/Core/Services/Embeddings/IEmbeddingService.cs b/src/Core/Services/Embeddings/IEmbeddingService.cs new file mode 100644 index 0000000000..ef5a9e490c --- /dev/null +++ b/src/Core/Services/Embeddings/IEmbeddingService.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.DataApiBuilder.Core.Services.Embeddings; + +/// +/// Result of a TryEmbed operation. +/// +/// Whether the embedding was generated successfully. +/// The embedding vector, or null if unsuccessful. +/// Error message if unsuccessful, or null if successful. +public record EmbeddingResult(bool Success, float[]? Embedding, string? ErrorMessage = null); + +/// +/// Result of a TryEmbedBatch operation. +/// +/// Whether the embeddings were generated successfully. +/// The embedding vectors, or null if unsuccessful. +/// Error message if unsuccessful, or null if successful. +public record EmbeddingBatchResult(bool Success, float[][]? Embeddings, string? ErrorMessage = null); + +/// +/// Service interface for text embedding/vectorization. +/// Supports both single text and batch embedding operations. +/// +public interface IEmbeddingService +{ + /// + /// Gets whether the embedding service is enabled. + /// + bool IsEnabled { get; } + + /// + /// Attempts to generate an embedding vector for a single text input. + /// Returns a result indicating success or failure without throwing exceptions. + /// + /// The text to embed. + /// Cancellation token for the operation. + /// Result containing the embedding if successful, or error information if not. + Task TryEmbedAsync(string text, CancellationToken cancellationToken = default); + + /// + /// Attempts to generate embedding vectors for multiple text inputs in a batch. + /// Returns a result indicating success or failure without throwing exceptions. + /// + /// The texts to embed. + /// Cancellation token for the operation. + /// Result containing the embeddings if successful, or error information if not. + Task TryEmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default); + + /// + /// Generates an embedding vector for a single text input. + /// Throws if the service is disabled or an error occurs. + /// + /// The text to embed. + /// Cancellation token for the operation. + /// The embedding vector as an array of floats. + /// Thrown when the service is disabled. + Task EmbedAsync(string text, CancellationToken cancellationToken = default); + + /// + /// Generates embedding vectors for multiple text inputs in a batch. + /// Throws if the service is disabled or an error occurs. + /// + /// The texts to embed. + /// Cancellation token for the operation. + /// The embedding vectors as an array of float arrays, matching input order. + /// Thrown when the service is disabled. + Task EmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default); +} diff --git a/src/Service.Tests/Configuration/RuntimeConfigLoaderTests.cs b/src/Service.Tests/Configuration/RuntimeConfigLoaderTests.cs index 3ff9e58531..737e7bb618 100644 --- a/src/Service.Tests/Configuration/RuntimeConfigLoaderTests.cs +++ b/src/Service.Tests/Configuration/RuntimeConfigLoaderTests.cs @@ -398,9 +398,9 @@ public async Task ChildConfigWithMissingEnvVarsLoadsSuccessfully() }"; // Save original env var values and clear them to ensure they don't exist. - string? origEndpoint = Environment.GetEnvironmentVariable("NONEXISTENT_OTEL_ENDPOINT"); - string? origHeaders = Environment.GetEnvironmentVariable("NONEXISTENT_OTEL_HEADERS"); - string? origServiceName = Environment.GetEnvironmentVariable("NONEXISTENT_OTEL_SERVICE_NAME"); + string origEndpoint = Environment.GetEnvironmentVariable("NONEXISTENT_OTEL_ENDPOINT"); + string origHeaders = Environment.GetEnvironmentVariable("NONEXISTENT_OTEL_HEADERS"); + string origServiceName = Environment.GetEnvironmentVariable("NONEXISTENT_OTEL_SERVICE_NAME"); Environment.SetEnvironmentVariable("NONEXISTENT_OTEL_ENDPOINT", null); Environment.SetEnvironmentVariable("NONEXISTENT_OTEL_HEADERS", null); Environment.SetEnvironmentVariable("NONEXISTENT_OTEL_SERVICE_NAME", null); diff --git a/src/Service.Tests/UnitTests/ChunkTextTests.cs b/src/Service.Tests/UnitTests/ChunkTextTests.cs new file mode 100644 index 0000000000..9769e84b99 --- /dev/null +++ b/src/Service.Tests/UnitTests/ChunkTextTests.cs @@ -0,0 +1,325 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; +using Azure.DataApiBuilder.Service.Helpers; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests; + +/// +/// Unit tests for the ChunkText functionality in EmbeddingController. +/// +[TestClass] +public class ChunkTextTests +{ + + /// + /// Tests that ChunkText returns single chunk for text smaller than chunk size. + /// + [TestMethod] + public void ChunkText_ReturnsSingleChunk_ForSmallText() + { + // Arrange + string text = "Short text"; + int chunkSize = 100; + int overlap = 10; + + // Act + List chunks = ChunkText(text, chunkSize, overlap); + + // Assert + Assert.AreEqual(1, chunks.Count); + Assert.AreEqual(text, chunks[0]); + } + + /// + /// Tests that ChunkText splits text into multiple chunks. + /// + [TestMethod] + public void ChunkText_SplitsIntoMultipleChunks() + { + // Arrange + string text = new string('A', 250); // 250 characters + int chunkSize = 100; + int overlap = 0; + + // Act + List chunks = ChunkText(text, chunkSize, overlap); + + // Assert + Assert.AreEqual(3, chunks.Count); + Assert.AreEqual(100, chunks[0].Length); + Assert.AreEqual(100, chunks[1].Length); + Assert.AreEqual(50, chunks[2].Length); + } + + /// + /// Tests that ChunkText creates overlapping chunks. + /// + [TestMethod] + public void ChunkText_CreatesOverlappingChunks() + { + // Arrange + string text = "0123456789ABCDEFGHIJ"; // 20 characters + int chunkSize = 10; + int overlap = 3; + + // Act + List chunks = ChunkText(text, chunkSize, overlap); + + // Assert + Assert.IsTrue(chunks.Count >= 2, "Should have multiple chunks"); + + // First chunk: chars 0-9 + Assert.AreEqual("0123456789", chunks[0]); + + // Second chunk should start at position 7 (10 - 3 overlap) + // and include chars 7-16 + if (chunks.Count >= 2) + { + Assert.IsTrue(chunks[1].StartsWith("789"), "Second chunk should start with overlap from first chunk"); + } + } + + /// + /// Tests that ChunkText with zero overlap creates adjacent chunks. + /// + [TestMethod] + public void ChunkText_WithZeroOverlap_CreatesAdjacentChunks() + { + // Arrange + string text = "AAAABBBBCCCCDDDD"; // 16 characters + int chunkSize = 4; + int overlap = 0; + + // Act + List chunks = ChunkText(text, chunkSize, overlap); + + // Assert + Assert.AreEqual(4, chunks.Count); + Assert.AreEqual("AAAA", chunks[0]); + Assert.AreEqual("BBBB", chunks[1]); + Assert.AreEqual("CCCC", chunks[2]); + Assert.AreEqual("DDDD", chunks[3]); + } + + /// + /// Tests that ChunkText handles overlap equal to chunk size. + /// + [TestMethod] + public void ChunkText_HandlesOverlapEqualToChunkSize() + { + // Arrange + string text = "0123456789ABCDEF"; // 16 characters + int chunkSize = 5; + int overlap = 5; + + // Act + List chunks = ChunkText(text, chunkSize, overlap); + + // Assert - each chunk should start at the same position as previous (overlap = size) + // This should still terminate and not create infinite chunks + Assert.IsTrue(chunks.Count > 0); + Assert.IsTrue(chunks.Count < 100, "Should not create excessive chunks"); + } + + /// + /// Tests that ChunkText handles overlap larger than chunk size. + /// + [TestMethod] + public void ChunkText_HandlesOverlapLargerThanChunkSize() + { + // Arrange + string text = "0123456789ABCDEF"; // 16 characters + int chunkSize = 5; + int overlap = 10; + + // Act + List chunks = ChunkText(text, chunkSize, overlap); + + // Assert - should handle gracefully without infinite loop + Assert.IsTrue(chunks.Count > 0); + Assert.IsTrue(chunks.Count < 100, "Should not create excessive chunks"); + } + + /// + /// Tests that ChunkText handles empty string. + /// + [TestMethod] + public void ChunkText_HandlesEmptyString() + { + // Arrange + string text = ""; + int chunkSize = 100; + int overlap = 10; + + // Act + List chunks = ChunkText(text, chunkSize, overlap); + + // Assert + Assert.AreEqual(0, chunks.Count, "Empty text should produce no chunks"); + } + + /// + /// Tests that ChunkText handles single character. + /// + [TestMethod] + public void ChunkText_HandlesSingleCharacter() + { + // Arrange + string text = "A"; + int chunkSize = 100; + int overlap = 10; + + // Act + List chunks = ChunkText(text, chunkSize, overlap); + + // Assert + Assert.AreEqual(1, chunks.Count); + Assert.AreEqual("A", chunks[0]); + } + + /// + /// Tests that ChunkText with chunk size of 1 creates individual character chunks. + /// + [TestMethod] + public void ChunkText_WithChunkSizeOne_CreatesCharacterChunks() + { + // Arrange + string text = "ABCDE"; + int chunkSize = 1; + int overlap = 0; + + // Act + List chunks = ChunkText(text, chunkSize, overlap); + + // Assert + Assert.AreEqual(5, chunks.Count); + Assert.AreEqual("A", chunks[0]); + Assert.AreEqual("B", chunks[1]); + Assert.AreEqual("C", chunks[2]); + Assert.AreEqual("D", chunks[3]); + Assert.AreEqual("E", chunks[4]); + } + + /// + /// Tests that ChunkText preserves whitespace and special characters. + /// + [TestMethod] + public void ChunkText_PreservesWhitespaceAndSpecialCharacters() + { + // Arrange + string text = "Hello World!\nNew Line\tTab"; + int chunkSize = 15; + int overlap = 0; + + // Act + List chunks = ChunkText(text, chunkSize, overlap); + + // Assert + string reconstructed = string.Concat(chunks); + Assert.AreEqual(text, reconstructed, "Reconstructed text should match original"); + } + + /// + /// Tests that ChunkText handles Unicode characters correctly. + /// + [TestMethod] + public void ChunkText_HandlesUnicodeCharacters() + { + // Arrange + string text = "Hello 世界 🌍 Émoji"; + int chunkSize = 10; + int overlap = 2; + + // Act + List chunks = ChunkText(text, chunkSize, overlap); + + // Assert + Assert.IsTrue(chunks.Count > 0); + string reconstructedStart = chunks[0]; + Assert.IsTrue(reconstructedStart.Contains("Hello") || reconstructedStart.Contains("世") || reconstructedStart.Contains("🌍"), + "Should preserve Unicode characters"); + } + + /// + /// Tests that overlapping chunks share common text. + /// + [TestMethod] + public void ChunkText_OverlappingChunksShareCommonText() + { + // Arrange + string text = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + int chunkSize = 10; + int overlap = 3; + + // Act + List chunks = ChunkText(text, chunkSize, overlap); + + // Assert + for (int i = 0; i < chunks.Count - 1; i++) + { + string currentChunk = chunks[i]; + string nextChunk = chunks[i + 1]; + + // Last 'overlap' characters of current chunk should match first 'overlap' of next chunk + string currentEnd = currentChunk.Substring(Math.Max(0, currentChunk.Length - overlap)); + string nextStart = nextChunk.Substring(0, Math.Min(overlap, nextChunk.Length)); + + Assert.AreEqual(currentEnd, nextStart, + $"Chunks {i} and {i + 1} should have overlapping content"); + } + } + + /// + /// Tests that text can be reconstructed from non-overlapping chunks. + /// + [TestMethod] + public void ChunkText_NonOverlappingChunks_CanReconstructText() + { + // Arrange + string text = "The quick brown fox jumps over the lazy dog"; + int chunkSize = 10; + int overlap = 0; + + // Act + List chunks = ChunkText(text, chunkSize, overlap); + + // Assert + string reconstructed = string.Concat(chunks); + Assert.AreEqual(text, reconstructed); + } + + /// + /// Tests ChunkText with very large text. + /// + [TestMethod] + public void ChunkText_HandlesLargeText() + { + // Arrange + string text = new string('X', 10000); + int chunkSize = 1000; + int overlap = 100; + + // Act + List chunks = ChunkText(text, chunkSize, overlap); + + // Assert + Assert.IsTrue(chunks.Count >= 10, "Large text should be split into multiple chunks"); + Assert.AreEqual(1000, chunks[0].Length); + Assert.IsTrue(chunks[chunks.Count - 1].Length <= 1000); + } + + /// + /// Helper method that delegates to the production + /// implementation so tests exercise real controller logic rather than a local re-implementation. + /// + private static List ChunkText(string text, int chunkSize, int overlap) + { + return TextChunker.ChunkText(text, chunkSize, overlap).ToList(); + } +} diff --git a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs index 05561e4cf9..bbb4874d1a 100644 --- a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs +++ b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs @@ -14,6 +14,7 @@ using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.MetadataProviders; @@ -1623,114 +1624,6 @@ private static void ValidateExceptionForDuplicateQueriesDueToEntityDefinitions(S Assert.AreEqual(expected: DataApiBuilderException.SubStatusCodes.ConfigValidationError, actual: dabException.SubStatusCode); } - /// - /// Method to create a sample entity with GraphQL enabled, - /// with given source and relationship Info. - /// Rest is disabled by default, unless specified otherwise. - /// - /// Database name of entity. - /// Dictionary containing {relationshipName, Relationship} - private static Entity GetSampleEntityUsingSourceAndRelationshipMap( - string source, - Dictionary relationshipMap, - EntityGraphQLOptions graphQLDetails, - EntityRestOptions restDetails = null - ) - { - EntityAction actionForRole = new( - Action: EntityActionOperation.Create, - Fields: null, - Policy: null); - - EntityPermission permissionForEntity = new( - Role: "anonymous", - Actions: new[] { actionForRole }); - - Entity sampleEntity = new( - Source: new(source, EntitySourceType.Table, null, null), - Fields: null, - Rest: restDetails ?? new(Enabled: false), - GraphQL: graphQLDetails, - Permissions: new[] { permissionForEntity }, - Relationships: relationshipMap, - Mappings: null - ); - - return sampleEntity; - } - - /// - /// Returns Dictionary containing pair of string and entity. - /// It creates two sample entities and forms relationship between them. - /// - /// Name of the source entity. - /// Name of the target entity. - /// List of strings representing the source field names. - /// List of strings representing the target field names. - /// Name of the linking object. - /// List of strings representing the linking source field names. - /// List of strings representing the linking target field names. - private static Dictionary GetSampleEntityMap( - string sourceEntity, - string targetEntity, - string[] sourceFields, - string[] targetFields, - string linkingObject, - string[] linkingSourceFields, - string[] linkingTargetFields - ) - { - Dictionary relationshipMap = new(); - - // Creating relationship between source and target entity. - EntityRelationship sampleRelationship = new( - Cardinality: Cardinality.One, - TargetEntity: targetEntity, - SourceFields: sourceFields, - TargetFields: targetFields, - LinkingObject: linkingObject, - LinkingSourceFields: linkingSourceFields, - LinkingTargetFields: linkingTargetFields - ); - - relationshipMap.Add("rname1", sampleRelationship); - - Entity sampleEntity1 = GetSampleEntityUsingSourceAndRelationshipMap( - source: "TEST_SOURCE1", - relationshipMap: relationshipMap, - graphQLDetails: new("rname1", "rname1s", true) - ); - - sampleRelationship = new( - Cardinality: Cardinality.One, - TargetEntity: sourceEntity, - SourceFields: targetFields, - TargetFields: sourceFields, - LinkingObject: linkingObject, - LinkingSourceFields: linkingTargetFields, - LinkingTargetFields: linkingSourceFields - ); - - relationshipMap = new() - { - { "rname2", sampleRelationship } - }; - - Entity sampleEntity2 = GetSampleEntityUsingSourceAndRelationshipMap( - source: "TEST_SOURCE2", - relationshipMap: relationshipMap, - graphQLDetails: new("rname2", "rname2s", true) - ); - - Dictionary entityMap = new() - { - { sourceEntity, sampleEntity1 }, - { targetEntity, sampleEntity2 } - }; - - return entityMap; - } - /// /// Tests whether the API path prefix is well formed or not. /// @@ -3040,12 +2933,748 @@ public void ValidateMaxResponseSizeInConfig( } } - private static RuntimeConfigValidator InitializeRuntimeConfigValidator() + /// + /// Validates that embeddings validation is skipped when embeddings are null or disabled. + /// No exception should be thrown. + /// + [DataTestMethod] + [DataRow(true, DisplayName = "Embeddings is null - validation skipped.")] + [DataRow(false, DisplayName = "Embeddings is disabled - validation skipped.")] + public void ValidateEmbeddingsOptions_SkipsValidation_WhenNullOrDisabled(bool isNull) { - MockFileSystem fileSystem = new(); - FileSystemRuntimeConfigLoader loader = new(fileSystem); - RuntimeConfigProvider provider = new(loader); - return new(provider, fileSystem, new Mock>().Object); + EmbeddingsOptions embeddingsOptions = isNull + ? null + : new EmbeddingsOptions( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "", + ApiKey: "", + Enabled: false); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(null, null), + Embeddings: embeddingsOptions + ), + Entities: new(new Dictionary()) + ); + + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + // Should not throw any exception. + configValidator.ValidateEmbeddingsOptions(runtimeConfig); + } + + /// + /// Validates that embeddings base-url is required and must be a valid HTTP or HTTPS URL. + /// + [DataTestMethod] + [DataRow(null, true, "Embeddings 'base-url' cannot be null or empty when embeddings are enabled.", + DisplayName = "Embeddings base-url is null.")] + [DataRow("", true, "Embeddings 'base-url' cannot be null or empty when embeddings are enabled.", + DisplayName = "Embeddings base-url is empty.")] + [DataRow(" ", true, "Embeddings 'base-url' cannot be null or empty when embeddings are enabled.", + DisplayName = "Embeddings base-url is whitespace.")] + [DataRow("not-a-url", true, "Embeddings 'base-url' must be a valid HTTP or HTTPS URL. Got: not-a-url", + DisplayName = "Embeddings base-url is not a valid URL.")] + [DataRow("ftp://example.com", true, "Embeddings 'base-url' must be a valid HTTP or HTTPS URL. Got: ftp://example.com", + DisplayName = "Embeddings base-url is FTP, not HTTP/HTTPS.")] + [DataRow("https://api.openai.com", false, null, + DisplayName = "Embeddings base-url is valid HTTPS URL.")] + [DataRow("http://localhost:8080", false, null, + DisplayName = "Embeddings base-url is valid HTTP URL.")] + public void ValidateEmbeddingsOptions_BaseUrl(string baseUrl, bool exceptionExpected, string expectedErrorMessage) + { + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: baseUrl, + ApiKey: "test-api-key", + Enabled: true); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(null, null), + Embeddings: embeddingsOptions + ), + Entities: new(new Dictionary()) + ); + + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + if (exceptionExpected) + { + DataApiBuilderException ex = Assert.ThrowsException( + () => configValidator.ValidateEmbeddingsOptions(runtimeConfig)); + Assert.AreEqual(expectedErrorMessage, ex.Message); + Assert.AreEqual(HttpStatusCode.ServiceUnavailable, ex.StatusCode); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ConfigValidationError, ex.SubStatusCode); + } + else + { + configValidator.ValidateEmbeddingsOptions(runtimeConfig); + } + } + + /// + /// Validates that embeddings api-key is required when embeddings are enabled. + /// + [DataTestMethod] + [DataRow(null, true, DisplayName = "Embeddings api-key is null.")] + [DataRow("", true, DisplayName = "Embeddings api-key is empty.")] + [DataRow(" ", true, DisplayName = "Embeddings api-key is whitespace.")] + [DataRow("sk-valid-key", false, DisplayName = "Embeddings api-key is valid.")] + public void ValidateEmbeddingsOptions_ApiKey(string apiKey, bool exceptionExpected) + { + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: apiKey, + Enabled: true); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(null, null), + Embeddings: embeddingsOptions + ), + Entities: new(new Dictionary()) + ); + + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + if (exceptionExpected) + { + DataApiBuilderException ex = Assert.ThrowsException( + () => configValidator.ValidateEmbeddingsOptions(runtimeConfig)); + Assert.AreEqual("Embeddings 'api-key' cannot be null or empty when embeddings are enabled.", ex.Message); + Assert.AreEqual(HttpStatusCode.ServiceUnavailable, ex.StatusCode); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ConfigValidationError, ex.SubStatusCode); + } + else + { + configValidator.ValidateEmbeddingsOptions(runtimeConfig); + } + } + + /// + /// Validates that for Azure OpenAI provider, model (deployment name) is required. + /// For OpenAI provider, model is not required. + /// + [DataTestMethod] + [DataRow(EmbeddingProviderType.AzureOpenAI, null, true, + DisplayName = "AzureOpenAI with null model fails.")] + [DataRow(EmbeddingProviderType.AzureOpenAI, "", true, + DisplayName = "AzureOpenAI with empty model fails.")] + [DataRow(EmbeddingProviderType.AzureOpenAI, " ", true, + DisplayName = "AzureOpenAI with whitespace model fails.")] + [DataRow(EmbeddingProviderType.AzureOpenAI, "my-deployment", false, + DisplayName = "AzureOpenAI with valid model passes.")] + [DataRow(EmbeddingProviderType.OpenAI, null, false, + DisplayName = "OpenAI with null model passes.")] + [DataRow(EmbeddingProviderType.OpenAI, "", false, + DisplayName = "OpenAI with empty model passes.")] + public void ValidateEmbeddingsOptions_ModelRequiredForAzureOpenAI( + EmbeddingProviderType provider, string model, bool exceptionExpected) + { + EmbeddingsOptions embeddingsOptions = new( + Provider: provider, + BaseUrl: "https://myinstance.openai.azure.com", + ApiKey: "test-api-key", + Enabled: true, + Model: model); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(null, null), + Embeddings: embeddingsOptions + ), + Entities: new(new Dictionary()) + ); + + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + if (exceptionExpected) + { + DataApiBuilderException ex = Assert.ThrowsException( + () => configValidator.ValidateEmbeddingsOptions(runtimeConfig)); + Assert.AreEqual("Embeddings 'model' (deployment name) is required when using the Azure OpenAI provider.", ex.Message); + Assert.AreEqual(HttpStatusCode.ServiceUnavailable, ex.StatusCode); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ConfigValidationError, ex.SubStatusCode); + } + else + { + configValidator.ValidateEmbeddingsOptions(runtimeConfig); + } + } + + /// + /// Validates that timeout-ms must be positive if provided. + /// + [DataTestMethod] + [DataRow(0, true, DisplayName = "Embeddings timeout-ms is zero.")] + [DataRow(-1, true, DisplayName = "Embeddings timeout-ms is negative.")] + [DataRow(-100, true, DisplayName = "Embeddings timeout-ms is large negative.")] + [DataRow(1, false, DisplayName = "Embeddings timeout-ms is 1 (valid).")] + [DataRow(30000, false, DisplayName = "Embeddings timeout-ms is 30000 (valid).")] + [DataRow(null, false, DisplayName = "Embeddings timeout-ms is null (valid, uses default).")] + public void ValidateEmbeddingsOptions_TimeoutMs(int? timeoutMs, bool exceptionExpected) + { + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-api-key", + Enabled: true, + TimeoutMs: timeoutMs); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(null, null), + Embeddings: embeddingsOptions + ), + Entities: new(new Dictionary()) + ); + + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + if (exceptionExpected) + { + DataApiBuilderException ex = Assert.ThrowsException( + () => configValidator.ValidateEmbeddingsOptions(runtimeConfig)); + Assert.AreEqual($"Embeddings 'timeout-ms' must be a positive integer. Got: {timeoutMs}", ex.Message); + Assert.AreEqual(HttpStatusCode.ServiceUnavailable, ex.StatusCode); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ConfigValidationError, ex.SubStatusCode); + } + else + { + configValidator.ValidateEmbeddingsOptions(runtimeConfig); + } + } + + /// + /// Validates that dimensions must be positive if provided. + /// + [DataTestMethod] + [DataRow(0, true, DisplayName = "Embeddings dimensions is zero.")] + [DataRow(-1, true, DisplayName = "Embeddings dimensions is negative.")] + [DataRow(-512, true, DisplayName = "Embeddings dimensions is large negative.")] + [DataRow(1, false, DisplayName = "Embeddings dimensions is 1 (valid).")] + [DataRow(1536, false, DisplayName = "Embeddings dimensions is 1536 (valid).")] + [DataRow(null, false, DisplayName = "Embeddings dimensions is null (valid, uses model default).")] + public void ValidateEmbeddingsOptions_Dimensions(int? dimensions, bool exceptionExpected) + { + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-api-key", + Enabled: true, + Dimensions: dimensions); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(null, null), + Embeddings: embeddingsOptions + ), + Entities: new(new Dictionary()) + ); + + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + if (exceptionExpected) + { + DataApiBuilderException ex = Assert.ThrowsException( + () => configValidator.ValidateEmbeddingsOptions(runtimeConfig)); + Assert.AreEqual($"Embeddings 'dimensions' must be a positive integer. Got: {dimensions}", ex.Message); + Assert.AreEqual(HttpStatusCode.ServiceUnavailable, ex.StatusCode); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ConfigValidationError, ex.SubStatusCode); + } + else + { + configValidator.ValidateEmbeddingsOptions(runtimeConfig); + } + } + + /// + /// Validates endpoint roles behavior: + /// - Production mode requires explicitly configured roles (even though null defaults to ['authenticated']) + /// - Development mode allows default roles + /// - Empty roles array is not allowed in either mode + /// + [DataTestMethod] + [DataRow(HostMode.Production, null, true, + DisplayName = "Production mode with null roles fails (requires explicit config).")] + [DataRow(HostMode.Production, new string[0], true, + DisplayName = "Production mode with empty roles fails.")] + [DataRow(HostMode.Production, new string[] { "authenticated" }, false, + DisplayName = "Production mode with explicit roles passes.")] + [DataRow(HostMode.Development, null, false, + DisplayName = "Development mode with null roles uses default ['authenticated'].")] + [DataRow(HostMode.Development, new string[0], true, + DisplayName = "Development mode with empty roles fails.")] + public void ValidateEmbeddingsOptions_EndpointRolesInProductionMode( + HostMode hostMode, + string[] roles, + bool exceptionExpected) + { + EmbeddingsEndpointOptions endpointOptions = new( + enabled: true, + roles: roles); + + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-api-key", + Enabled: true, + Endpoint: endpointOptions); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null, Mode: hostMode), + Embeddings: embeddingsOptions + ), + Entities: new(new Dictionary()) + ); + + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + if (exceptionExpected) + { + DataApiBuilderException ex = Assert.ThrowsException( + () => configValidator.ValidateEmbeddingsOptions(runtimeConfig)); + + // Production with null gets caught first, empty array gets caught second + string expectedMessage = (hostMode == HostMode.Production && roles is null) + ? "Embeddings endpoint 'roles' must be explicitly configured in production mode." + : "Embeddings endpoint 'roles' cannot be empty when endpoint is enabled."; + + Assert.AreEqual(expectedMessage, ex.Message); + Assert.AreEqual(HttpStatusCode.ServiceUnavailable, ex.StatusCode); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ConfigValidationError, ex.SubStatusCode); + } + else + { + configValidator.ValidateEmbeddingsOptions(runtimeConfig); + } + } + + /// + /// Validates that health check threshold-ms must be positive when health check is enabled. + /// + [DataTestMethod] + [DataRow(0, true, DisplayName = "Health check threshold-ms is zero.")] + [DataRow(-1, true, DisplayName = "Health check threshold-ms is negative.")] + [DataRow(-500, true, DisplayName = "Health check threshold-ms is large negative.")] + [DataRow(1, false, DisplayName = "Health check threshold-ms is 1 (valid).")] + [DataRow(5000, false, DisplayName = "Health check threshold-ms is 5000 (valid).")] + public void ValidateEmbeddingsOptions_HealthCheckThresholdMs(int thresholdMs, bool exceptionExpected) + { + EmbeddingsHealthCheckConfig healthConfig = new( + enabled: true, + thresholdMs: thresholdMs, + testText: "health check"); + + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-api-key", + Enabled: true, + Health: healthConfig); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(null, null), + Embeddings: embeddingsOptions + ), + Entities: new(new Dictionary()) + ); + + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + if (exceptionExpected) + { + DataApiBuilderException ex = Assert.ThrowsException( + () => configValidator.ValidateEmbeddingsOptions(runtimeConfig)); + Assert.AreEqual($"Embeddings health check 'threshold-ms' must be a positive integer. Got: {thresholdMs}", ex.Message); + Assert.AreEqual(HttpStatusCode.ServiceUnavailable, ex.StatusCode); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ConfigValidationError, ex.SubStatusCode); + } + else + { + configValidator.ValidateEmbeddingsOptions(runtimeConfig); + } + } + + /// + /// Validates that health check test-text cannot be null or empty when health check is enabled. + /// + [DataTestMethod] + [DataRow(null, false, DisplayName = "Health check test-text is null (uses default).")] + [DataRow("", true, DisplayName = "Health check test-text is empty.")] + [DataRow(" ", true, DisplayName = "Health check test-text is whitespace.")] + [DataRow("health check", false, DisplayName = "Health check test-text is valid.")] + public void ValidateEmbeddingsOptions_HealthCheckTestText(string testText, bool exceptionExpected) + { + EmbeddingsHealthCheckConfig healthConfig = new( + enabled: true, + thresholdMs: 5000, + testText: testText); + + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-api-key", + Enabled: true, + Health: healthConfig); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(null, null), + Embeddings: embeddingsOptions + ), + Entities: new(new Dictionary()) + ); + + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + if (exceptionExpected) + { + DataApiBuilderException ex = Assert.ThrowsException( + () => configValidator.ValidateEmbeddingsOptions(runtimeConfig)); + Assert.AreEqual("Embeddings health check 'test-text' cannot be null or empty when health check is enabled.", ex.Message); + Assert.AreEqual(HttpStatusCode.ServiceUnavailable, ex.StatusCode); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ConfigValidationError, ex.SubStatusCode); + } + else + { + configValidator.ValidateEmbeddingsOptions(runtimeConfig); + } + } + + /// + /// Validates that health check expected-dimensions must be positive if provided. + /// + [DataTestMethod] + [DataRow(0, true, DisplayName = "Health check expected-dimensions is zero.")] + [DataRow(-1, true, DisplayName = "Health check expected-dimensions is negative.")] + [DataRow(-256, true, DisplayName = "Health check expected-dimensions is large negative.")] + [DataRow(1, false, DisplayName = "Health check expected-dimensions is 1 (valid).")] + [DataRow(1536, false, DisplayName = "Health check expected-dimensions is 1536 (valid).")] + [DataRow(null, false, DisplayName = "Health check expected-dimensions is null (valid, skips validation).")] + public void ValidateEmbeddingsOptions_HealthCheckExpectedDimensions(int? expectedDimensions, bool exceptionExpected) + { + EmbeddingsHealthCheckConfig healthConfig = new( + enabled: true, + thresholdMs: 5000, + testText: "health check", + expectedDimensions: expectedDimensions); + + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-api-key", + Enabled: true, + Health: healthConfig); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(null, null), + Embeddings: embeddingsOptions + ), + Entities: new(new Dictionary()) + ); + + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + if (exceptionExpected) + { + DataApiBuilderException ex = Assert.ThrowsException( + () => configValidator.ValidateEmbeddingsOptions(runtimeConfig)); + Assert.AreEqual($"Embeddings health check 'expected-dimensions' must be a positive integer. Got: {expectedDimensions}", ex.Message); + Assert.AreEqual(HttpStatusCode.ServiceUnavailable, ex.StatusCode); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.ConfigValidationError, ex.SubStatusCode); + } + else + { + configValidator.ValidateEmbeddingsOptions(runtimeConfig); + } + } + + /// + /// Validates that a fully valid embeddings configuration passes all validation checks. + /// + [TestMethod] + public void ValidateEmbeddingsOptions_FullyValidConfig_Passes() + { + EmbeddingsEndpointOptions endpointOptions = new( + enabled: true, + roles: new[] { "authenticated" }); + + EmbeddingsHealthCheckConfig healthConfig = new( + enabled: true, + thresholdMs: 5000, + testText: "test embedding", + expectedDimensions: 1536); + + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://myinstance.openai.azure.com", + ApiKey: "my-api-key", + Enabled: true, + Model: "text-embedding-ada-002", + TimeoutMs: 15000, + Dimensions: 1536, + Endpoint: endpointOptions, + Health: healthConfig); + + RuntimeCacheLevel2Options level2Options = new( + Enabled: true, + Provider: "redis", + ConnectionString: "localhost:6379"); + + RuntimeCacheOptions cacheOptions = new(Enabled: true, TtlSeconds: 5) + { + Level2 = level2Options + }; + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null, Mode: HostMode.Production), + Cache: cacheOptions, + Embeddings: embeddingsOptions + ), + Entities: new(new Dictionary()) + ); + + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + // Should not throw any exception. + configValidator.ValidateEmbeddingsOptions(runtimeConfig); + } + + /// + /// Validates that health check validation is skipped when health check is disabled. + /// Even invalid values should not cause an exception. + /// + [TestMethod] + public void ValidateEmbeddingsOptions_HealthCheckDisabled_SkipsValidation() + { + EmbeddingsHealthCheckConfig healthConfig = new( + enabled: false, + thresholdMs: -100, + testText: "", + expectedDimensions: -50); + + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-api-key", + Enabled: true, + Health: healthConfig); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(null, null), + Embeddings: embeddingsOptions + ), + Entities: new(new Dictionary()) + ); + + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + // Should not throw any exception since health check is disabled. + configValidator.ValidateEmbeddingsOptions(runtimeConfig); + } + + /// + /// Validates that endpoint validation is skipped when endpoint is disabled. + /// Even invalid values should not cause an exception. + /// + [TestMethod] + public void ValidateEmbeddingsOptions_EndpointDisabled_SkipsValidation() + { + EmbeddingsEndpointOptions endpointOptions = new( + enabled: false, + roles: null); + + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-api-key", + Enabled: true, + Endpoint: endpointOptions); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "", Options: null), + Runtime: new( + Rest: new(Path: "/api"), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null, Mode: HostMode.Production), + Embeddings: embeddingsOptions + ), + Entities: new(new Dictionary()) + ); + + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + // Should not throw even though the path conflicts with REST and roles are null in production mode, + // because the endpoint is disabled. + configValidator.ValidateEmbeddingsOptions(runtimeConfig); + } + + private static RuntimeConfigValidator InitializeRuntimeConfigValidator() + { + MockFileSystem fileSystem = new(); + FileSystemRuntimeConfigLoader loader = new(fileSystem); + RuntimeConfigProvider provider = new(loader); + return new(provider, fileSystem, new Mock>().Object); + } + + private static Entity GetSampleEntityUsingSourceAndRelationshipMap( + string source, + Dictionary relationshipMap, + EntityGraphQLOptions graphQLDetails, + EntityRestOptions restDetails = null + ) + { + EntityAction actionForRole = new( + Action: EntityActionOperation.Create, + Fields: null, + Policy: null); + EntityPermission permissionForEntity = new( + Role: "anonymous", + Actions: new[] { actionForRole }); + Entity sampleEntity = new( + Source: new(source, EntitySourceType.Table, null, null), + Fields: null, + Rest: restDetails ?? new(Enabled: false), + GraphQL: graphQLDetails, + Permissions: new[] { permissionForEntity }, + Relationships: relationshipMap, + Mappings: null + ); + return sampleEntity; + } + + /// + /// Returns Dictionary containing pair of string and entity. + /// It creates two sample entities and forms relationship between them. + /// + /// Name of the source entity. + /// Name of the target entity. + /// List of strings representing the source field names. + /// List of strings representing the target field names. + /// Name of the linking object. + /// List of strings representing the linking source field names. + /// List of strings representing the linking target field names. + private static Dictionary GetSampleEntityMap( + string sourceEntity, + string targetEntity, + string[] sourceFields, + string[] targetFields, + string linkingObject, + string[] linkingSourceFields, + string[] linkingTargetFields + ) + { + Dictionary relationshipMap = new(); + // Creating relationship between source and target entity. + EntityRelationship sampleRelationship = new( + Cardinality: Cardinality.One, + TargetEntity: targetEntity, + SourceFields: sourceFields, + TargetFields: targetFields, + LinkingObject: linkingObject, + LinkingSourceFields: linkingSourceFields, + LinkingTargetFields: linkingTargetFields + ); + relationshipMap.Add("rname1", sampleRelationship); + Entity sampleEntity1 = GetSampleEntityUsingSourceAndRelationshipMap( + source: "TEST_SOURCE1", + relationshipMap: relationshipMap, + graphQLDetails: new("rname1", "rname1s", true) + ); + sampleRelationship = new( + Cardinality: Cardinality.One, + TargetEntity: sourceEntity, + SourceFields: targetFields, + TargetFields: sourceFields, + LinkingObject: linkingObject, + LinkingSourceFields: linkingTargetFields, + LinkingTargetFields: linkingSourceFields + ); + relationshipMap = new() + { + { "rname2", sampleRelationship } + }; + Entity sampleEntity2 = GetSampleEntityUsingSourceAndRelationshipMap( + source: "TEST_SOURCE2", + relationshipMap: relationshipMap, + graphQLDetails: new("rname2", "rname2s", true) + ); + Dictionary entityMap = new() + { + { sourceEntity, sampleEntity1 }, + { targetEntity, sampleEntity2 } + }; + return entityMap; } } } diff --git a/src/Service.Tests/UnitTests/EmbeddingControllerTests.cs b/src/Service.Tests/UnitTests/EmbeddingControllerTests.cs new file mode 100644 index 0000000000..8b3a627a85 --- /dev/null +++ b/src/Service.Tests/UnitTests/EmbeddingControllerTests.cs @@ -0,0 +1,1979 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; +using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Services.Embeddings; +using Azure.DataApiBuilder.Service.Controllers; +using Azure.DataApiBuilder.Service.Models; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.Logging; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests; + +/// +/// Unit tests for EmbeddingController. +/// Covers fixed route metadata, authorization, request body parsing, +/// service availability, error handling, and integration with IEmbeddingService. +/// +[TestClass] +public class EmbeddingControllerTests +{ + private Mock> _mockLogger = null!; + private Mock _mockEmbeddingService = null!; + + [TestInitialize] + public void Setup() + { + _mockLogger = new Mock>(); + _mockEmbeddingService = new Mock(); + _mockEmbeddingService.Setup(s => s.IsEnabled).Returns(true); + } + + #region Fixed Endpoint Route Tests + + /// + /// Tests that the controller action is bound to the fixed "embed" route. + /// + [TestMethod] + public void PostAsync_UsesFixedEmbedRoute() + { + RouteAttribute? routeAttribute = typeof(EmbeddingController) + .GetMethod(nameof(EmbeddingController.PostAsync))? + .GetCustomAttributes(typeof(RouteAttribute), inherit: false) + .Cast() + .SingleOrDefault(); + + Assert.IsNotNull(routeAttribute); + Assert.AreEqual("embed", routeAttribute.Template); + } + + /// + /// Tests that embedding requests succeed at the fixed endpoint route. + /// + [TestMethod] + public async Task PostAsync_SucceedsAtFixedEndpointRoute() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f, 0.3f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + } + + #endregion + + #region Embeddings and Endpoint Enabled/Disabled Tests + + /// + /// Tests that the controller returns NotFound when embeddings config is null. + /// + [TestMethod] + public async Task PostAsync_ReturnsNotFound_WhenEmbeddingsIsNull() + { + // Arrange + Mock mockProvider = CreateMockConfigProvider(embeddingsOptions: null); + EmbeddingController controller = new(mockProvider.Object, _mockLogger.Object, _mockEmbeddingService.Object); + controller.ControllerContext = CreateControllerContext("/embed"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(NotFoundResult)); + } + + /// + /// Tests that the controller returns NotFound when embeddings is disabled. + /// + [TestMethod] + public async Task PostAsync_ReturnsNotFound_WhenEmbeddingsIsDisabled() + { + // Arrange + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "key", + Enabled: false, + Endpoint: new EmbeddingsEndpointOptions(enabled: true)); + + Mock mockProvider = CreateMockConfigProvider( + embeddingsOptions: embeddingsOptions, hostMode: HostMode.Development); + EmbeddingController controller = new(mockProvider.Object, _mockLogger.Object, _mockEmbeddingService.Object); + controller.ControllerContext = CreateControllerContext("/embed"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(NotFoundResult)); + } + + /// + /// Tests that the controller returns NotFound when endpoint config is null. + /// + [TestMethod] + public async Task PostAsync_ReturnsNotFound_WhenEndpointIsNull() + { + // Arrange + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "key", + Endpoint: null); + + Mock mockProvider = CreateMockConfigProvider( + embeddingsOptions: embeddingsOptions, hostMode: HostMode.Development); + EmbeddingController controller = new(mockProvider.Object, _mockLogger.Object, _mockEmbeddingService.Object); + controller.ControllerContext = CreateControllerContext("/embed"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(NotFoundResult)); + } + + /// + /// Tests that the controller returns NotFound when endpoint is disabled. + /// + [TestMethod] + public async Task PostAsync_ReturnsNotFound_WhenEndpointIsDisabled() + { + // Arrange + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "key", + Endpoint: new EmbeddingsEndpointOptions(enabled: false)); + + Mock mockProvider = CreateMockConfigProvider( + embeddingsOptions: embeddingsOptions, hostMode: HostMode.Development); + EmbeddingController controller = new(mockProvider.Object, _mockLogger.Object, _mockEmbeddingService.Object); + controller.ControllerContext = CreateControllerContext("/embed"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(NotFoundResult)); + } + + #endregion + + #region Service Availability Tests + + /// + /// Tests that the controller returns ServiceUnavailable when embedding service is null. + /// + [TestMethod] + public async Task PostAsync_ReturnsServiceUnavailable_WhenServiceIsNull() + { + // Arrange + EmbeddingController controller = CreateController( + requestPath: "/embed", + hostMode: HostMode.Development, + embeddingService: null, + useClassMockService: false); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.ServiceUnavailable, (int)value!.error.status); + } + + /// + /// Tests that the controller returns ServiceUnavailable when embedding service is disabled. + /// + [TestMethod] + public async Task PostAsync_ReturnsServiceUnavailable_WhenServiceIsDisabled() + { + // Arrange + Mock disabledService = new(); + disabledService.Setup(s => s.IsEnabled).Returns(false); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + hostMode: HostMode.Development, + embeddingService: disabledService.Object, + useClassMockService: false); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.ServiceUnavailable, (int)value!.error.status); + } + + #endregion + + #region Authorization Tests + + /// + /// Tests that anonymous access is allowed in development mode when no roles are configured + /// (development mode defaults to allowing anonymous). + /// + [TestMethod] + public async Task PostAsync_AllowsAnonymous_InDevelopmentMode_WithNoRolesConfigured() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Development, + endpointRoles: null, // no roles configured — dev mode defaults to anonymous + clientRole: null); // no role header — defaults to anonymous + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + } + + /// + /// Tests that anonymous access is denied when default authenticated role is used. + /// + [TestMethod] + public async Task PostAsync_ReturnsForbidden_InProductionMode_WithNoRolesConfigured() + { + // Arrange + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Production, + endpointRoles: UseConfigDefault, // use config default ["authenticated"] + clientRole: null); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.Forbidden, (int)value!.error.status); + } + + /// + /// Tests that a request with an unauthorized role is denied. + /// + [TestMethod] + public async Task PostAsync_ReturnsForbidden_WhenRoleIsNotAuthorized() + { + // Arrange + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Production, + endpointRoles: new[] { "admin" }, + clientRole: "reader"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.Forbidden, (int)value!.error.status); + } + + /// + /// Tests that a request with an authorized role is accepted. + /// + [TestMethod] + public async Task PostAsync_AllowsAccess_WhenRoleIsAuthorized() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Production, + endpointRoles: new[] { "admin", "reader" }, + clientRole: "admin"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + } + + /// + /// Tests that role matching is case-insensitive. + /// + [TestMethod] + public async Task PostAsync_RoleMatchingIsCaseInsensitive() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Production, + endpointRoles: new[] { "Admin" }, + clientRole: "ADMIN"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + } + + /// + /// Tests that when no X-MS-API-ROLE header is provided, the anonymous role is used. + /// + [TestMethod] + public async Task PostAsync_UsesAnonymousRole_WhenNoRoleHeaderProvided() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Production, + endpointRoles: new[] { "anonymous" }, + clientRole: null); // no role header + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + } + + #endregion + + #region Request Body Parsing Tests + + /// + /// Tests successful embedding with a plain text request body. + /// + [TestMethod] + public async Task PostAsync_ReturnsEmbedding_ForPlainTextBody() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f, 0.3f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "Hello, world!", + contentType: "text/plain", + hostMode: HostMode.Development, + acceptHeader: "text/plain"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(ContentResult)); + ContentResult contentResult = (ContentResult)result; + Assert.AreEqual("0.1,0.2,0.3", contentResult.Content); + Assert.AreEqual("text/plain", contentResult.ContentType); + } + + /// + /// Tests successful embedding with a JSON-wrapped string request body. + /// + [TestMethod] + public async Task PostAsync_ReturnsEmbedding_ForJsonWrappedStringBody() + { + // Arrange + float[] embedding = new[] { 0.4f, 0.5f }; + string expectedText = "Hello, world!"; + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(expectedText, It.IsAny())) + .ReturnsAsync(new EmbeddingResult(true, embedding)); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "\"Hello, world!\"", // JSON-wrapped string + contentType: "application/json", + hostMode: HostMode.Development, + acceptHeader: "text/plain"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(ContentResult)); + ContentResult contentResult = (ContentResult)result; + Assert.AreEqual("0.4,0.5", contentResult.Content); + + // Verify the service was called with the unwrapped string + _mockEmbeddingService.Verify( + s => s.TryEmbedAsync(expectedText, It.IsAny()), + Times.Once()); + } + + /// + /// Tests that an application/json body that is neither a string nor a document array returns BadRequest. + /// + [TestMethod] + public async Task PostAsync_ReturnsBadRequest_ForInvalidJsonBody() + { + // Arrange — a JSON object is not a valid string or document array + string rawBody = "{\"foo\":\"bar\"}"; + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: rawBody, + contentType: "application/json", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert — controller must reject the body with a descriptive message + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.BadRequest, (int)value!.error.status); + Assert.IsTrue( + jsonResult.Value?.ToString()?.Contains("application/json") == true, + "Error message should mention 'application/json'."); + + // Embedding service must NOT be called + _mockEmbeddingService.Verify( + s => s.TryEmbedAsync(It.IsAny(), It.IsAny()), + Times.Never()); + } + + #endregion + + #region Empty Request Body Validation Tests + + /// + /// Tests that an empty request body returns BadRequest. + /// + [TestMethod] + public async Task PostAsync_ReturnsBadRequest_ForEmptyBody() + { + // Arrange + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.BadRequest, (int)value!.error.status); + } + + /// + /// Tests that a whitespace-only request body returns BadRequest. + /// + [TestMethod] + public async Task PostAsync_ReturnsBadRequest_ForWhitespaceOnlyBody() + { + // Arrange + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: " \n\t ", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.BadRequest, (int)value!.error.status); + } + + #endregion + + #region Error Response Handling Tests + + /// + /// Tests that InternalServerError is returned when TryEmbedAsync fails. + /// + [TestMethod] + public async Task PostAsync_ReturnsInternalServerError_WhenEmbeddingFails() + { + // Arrange + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new EmbeddingResult(false, null, "Provider returned an error.")); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.InternalServerError, (int)value!.error.status); + // Error message must NOT expose internal provider details + Assert.IsFalse( + jsonResult.Value?.ToString()?.Contains("Provider returned an error.") == true, + "Internal error details must not be exposed to the client."); + } + + /// + /// Tests that InternalServerError is returned when embedding result is null. + /// + [TestMethod] + public async Task PostAsync_ReturnsInternalServerError_WhenEmbeddingIsNull() + { + // Arrange + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new EmbeddingResult(true, null)); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.InternalServerError, (int)value!.error.status); + } + + /// + /// Tests that InternalServerError is returned when embedding result is empty array. + /// + [TestMethod] + public async Task PostAsync_ReturnsInternalServerError_WhenEmbeddingIsEmpty() + { + // Arrange + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new EmbeddingResult(true, Array.Empty())); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.InternalServerError, (int)value!.error.status); + } + + /// + /// Tests that when TryEmbedAsync fails with no error message, a default error message is returned. + /// + [TestMethod] + public async Task PostAsync_ReturnsDefaultErrorMessage_WhenNoErrorMessageProvided() + { + // Arrange + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new EmbeddingResult(false, null, null)); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.InternalServerError, (int)value!.error.status); + // The generic error message should be returned, not internal details + Assert.IsTrue(jsonResult.Value?.ToString()?.Contains("Failed to generate embedding.") == true); + } + + #endregion + + #region Integration with IEmbeddingService Tests + + /// + /// Tests that the embedding service is called with the correct text from the request body. + /// + [TestMethod] + public async Task PostAsync_CallsEmbeddingService_WithCorrectText() + { + // Arrange + string inputText = "This is the text to embed"; + float[] embedding = new[] { 0.1f, 0.2f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(inputText, It.IsAny())) + .ReturnsAsync(new EmbeddingResult(true, embedding)); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: inputText, + hostMode: HostMode.Development); + + // Act + await controller.PostAsync(); + + // Assert + _mockEmbeddingService.Verify( + s => s.TryEmbedAsync(inputText, It.IsAny()), + Times.Once()); + } + + /// + /// Tests that the embedding vector is returned as comma-separated floats in plain text + /// when Accept: text/plain is requested. + /// + [TestMethod] + public async Task PostAsync_ReturnsCommaSeparatedFloats() + { + // Arrange + float[] embedding = new[] { 1.5f, -0.25f, 3.14159f, 0f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test", + hostMode: HostMode.Development, + acceptHeader: "text/plain"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(ContentResult)); + ContentResult contentResult = (ContentResult)result; + Assert.AreEqual("1.5,-0.25,3.14159,0", contentResult.Content); + } + + /// + /// Tests that the embedding service is not called when the service is unavailable. + /// + [TestMethod] + public async Task PostAsync_DoesNotCallService_WhenServiceIsUnavailable() + { + // Arrange + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Development, + embeddingService: null, + useClassMockService: false); + + // Act + await controller.PostAsync(); + + // Assert + _mockEmbeddingService.Verify( + s => s.TryEmbedAsync(It.IsAny(), It.IsAny()), + Times.Never()); + } + + /// + /// Tests that the embedding service is not called when the request body is empty. + /// + [TestMethod] + public async Task PostAsync_DoesNotCallService_WhenBodyIsEmpty() + { + // Arrange + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "", + hostMode: HostMode.Development); + + // Act + await controller.PostAsync(); + + // Assert + _mockEmbeddingService.Verify( + s => s.TryEmbedAsync(It.IsAny(), It.IsAny()), + Times.Never()); + } + + /// + /// Tests that the embedding service is not called when authorization fails. + /// + [TestMethod] + public async Task PostAsync_DoesNotCallService_WhenAuthorizationFails() + { + // Arrange + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Production, + endpointRoles: new[] { "admin" }, + clientRole: "unauthorized-role"); + + // Act + await controller.PostAsync(); + + // Assert + _mockEmbeddingService.Verify( + s => s.TryEmbedAsync(It.IsAny(), It.IsAny()), + Times.Never()); + } + + #endregion + + #region Development vs Production Mode Tests + + /// + /// Tests that development mode allows anonymous access by default even without explicit roles. + /// + [TestMethod] + public async Task PostAsync_DevelopmentMode_DefaultsToAnonymousAccess() + { + // Arrange + float[] embedding = new[] { 0.1f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test", + hostMode: HostMode.Development, + endpointRoles: new[] { "anonymous" }, // explicitly allow anonymous + clientRole: null); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert - should succeed because anonymous is explicitly allowed + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + } + + /// + /// Tests that production mode requires authenticated role when using default. + /// + [TestMethod] + public async Task PostAsync_ProductionMode_DeniesAccessByDefault() + { + // Arrange + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test", + hostMode: HostMode.Production, + endpointRoles: UseConfigDefault, // use config default ["authenticated"] + clientRole: null); // anonymous - not allowed + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.Forbidden, (int)value!.error.status); + } + + /// + /// Tests that production mode allows access when the client role is in the configured roles. + /// + [TestMethod] + public async Task PostAsync_ProductionMode_AllowsConfiguredRole() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test", + hostMode: HostMode.Production, + endpointRoles: new[] { "authenticated", "admin" }, + clientRole: "authenticated"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + } + + #endregion + + #region Content Negotiation Tests + + /// + /// Tests that the default response (no Accept header) is JSON with EmbeddingResponse. + /// + [TestMethod] + public async Task PostAsync_ReturnsJson_WhenNoAcceptHeader() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f, 0.3f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Development, + acceptHeader: null); // no Accept header + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + OkObjectResult okResult = (OkObjectResult)result; + Assert.IsInstanceOfType(okResult.Value, typeof(EmbeddingResponse)); + EmbeddingResponse response = (EmbeddingResponse)okResult.Value!; + CollectionAssert.AreEqual(embedding, response.Embedding); + Assert.AreEqual(3, response.Dimensions); + } + + /// + /// Tests that Accept: application/json returns JSON with EmbeddingResponse. + /// + [TestMethod] + public async Task PostAsync_ReturnsJson_WhenAcceptIsApplicationJson() + { + // Arrange + float[] embedding = new[] { 0.5f, 0.6f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Development, + acceptHeader: "application/json"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + OkObjectResult okResult = (OkObjectResult)result; + Assert.IsInstanceOfType(okResult.Value, typeof(EmbeddingResponse)); + EmbeddingResponse response = (EmbeddingResponse)okResult.Value!; + CollectionAssert.AreEqual(embedding, response.Embedding); + Assert.AreEqual(2, response.Dimensions); + } + + /// + /// Tests that Accept: text/plain returns plain text with comma-separated floats. + /// + [TestMethod] + public async Task PostAsync_ReturnsTextPlain_WhenAcceptIsTextPlain() + { + // Arrange + float[] embedding = new[] { 0.7f, 0.8f, 0.9f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Development, + acceptHeader: "text/plain"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(ContentResult)); + ContentResult contentResult = (ContentResult)result; + Assert.AreEqual("0.7,0.8,0.9", contentResult.Content); + Assert.AreEqual("text/plain", contentResult.ContentType); + } + + /// + /// Tests that when Accept includes both application/json and text/plain, JSON wins. + /// + [TestMethod] + public async Task PostAsync_ReturnsJson_WhenAcceptIncludesBothJsonAndTextPlain() + { + // Arrange + float[] embedding = new[] { 1.0f, 2.0f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Development, + acceptHeader: "text/plain, application/json"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert - JSON wins when both are present + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + OkObjectResult okResult = (OkObjectResult)result; + Assert.IsInstanceOfType(okResult.Value, typeof(EmbeddingResponse)); + } + + /// + /// Tests that Accept: */* returns JSON (default format). + /// + [TestMethod] + public async Task PostAsync_ReturnsJson_WhenAcceptIsWildcard() + { + // Arrange + float[] embedding = new[] { 0.1f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "test text", + hostMode: HostMode.Development, + acceptHeader: "*/*"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert - wildcard does not trigger text/plain + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + } + + #endregion + + #region Document Array with Chunking Tests + + /// + /// Tests that document array requests are properly processed. + /// + [TestMethod] + public async Task PostAsync_ReturnsEmbeddings_ForDocumentArray() + { + // Arrange — controller uses TryEmbedBatchAsync per document + float[] embedding1 = new[] { 0.1f, 0.2f }; + float[] embedding2 = new[] { 0.3f, 0.4f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync( + It.Is(texts => texts.Length == 1 && texts[0] == "First document"), + It.IsAny())) + .ReturnsAsync(new EmbeddingBatchResult(true, new[] { embedding1 })); + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync( + It.Is(texts => texts.Length == 1 && texts[0] == "Second document"), + It.IsAny())) + .ReturnsAsync(new EmbeddingBatchResult(true, new[] { embedding2 })); + + string requestBody = """ + [ + {"key": "doc-1", "text": "First document"}, + {"key": "doc-2", "text": "Second document"} + ] + """; + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: requestBody, + contentType: "application/json", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + OkObjectResult okResult = (OkObjectResult)result; + Assert.IsNotNull(okResult.Value); + + EmbedDocumentResponse[]? responses = okResult.Value as EmbedDocumentResponse[]; + Assert.IsNotNull(responses); + Assert.AreEqual(2, responses.Length); + Assert.AreEqual("doc-1", responses[0].Key); + Assert.AreEqual("doc-2", responses[1].Key); + Assert.AreEqual(1, responses[0].Data.Length); // no chunking by default + Assert.AreEqual(1, responses[1].Data.Length); + } + + /// + /// Tests that document array with chunking enabled splits text into multiple embeddings. + /// + [TestMethod] + public async Task PostAsync_ChunksDocuments_WhenChunkingEnabled() + { + // Arrange + float[] embedding1 = new[] { 0.1f, 0.2f }; + float[] embedding2 = new[] { 0.3f, 0.4f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((string[] texts, CancellationToken _) => + new EmbeddingBatchResult(true, texts.Select(_ => embedding1).ToArray())); + + // Create a long text that will be chunked (default chunk size is 1000) + string longText = new string('A', 1500); + + string requestBody = $$""" + [ + {"key": "doc-1", "text": "{{longText}}"} + ] + """; + + EmbeddingsEndpointOptions endpointOptions = new(enabled: true, roles: new[] { "anonymous" }); + EmbeddingsChunkingOptions chunkingOptions = new(Enabled: true, SizeChars: 1000, OverlapChars: 250); + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key", + Enabled: true, + Endpoint: endpointOptions, + Chunking: chunkingOptions); + + Mock mockProvider = CreateMockConfigProvider( + embeddingsOptions: embeddingsOptions, + hostMode: HostMode.Development); + + EmbeddingController controller = new( + mockProvider.Object, + _mockLogger.Object, + _mockEmbeddingService.Object); + + controller.ControllerContext = CreateControllerContext( + "/embed", + requestBody, + "application/json"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + OkObjectResult okResult = (OkObjectResult)result; + EmbedDocumentResponse[]? responses = okResult.Value as EmbedDocumentResponse[]; + Assert.IsNotNull(responses); + Assert.AreEqual(1, responses.Length); + Assert.AreEqual("doc-1", responses[0].Key); + Assert.IsTrue(responses[0].Data.Length > 1, "Text should be chunked into multiple embeddings"); + } + + /// + /// Tests that query parameter $chunking.enabled=true overrides config. + /// + [TestMethod] + public async Task PostAsync_ChunkingQueryParameter_EnablesChunking() + { + // Arrange — controller calls TryEmbedBatchAsync (not TryEmbedAsync) for document arrays + float[] embedding = new[] { 0.1f, 0.2f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((string[] texts, CancellationToken _) => + new EmbeddingBatchResult(true, texts.Select(_ => embedding).ToArray())); + + string longText = new string('A', 1500); + string requestBody = $$""" + [ + {"key": "doc-1", "text": "{{longText}}"} + ] + """; + + EmbeddingController controller = CreateController( + requestPath: "/embed?$chunking.enabled=true&$chunking.size-chars=500", + requestBody: requestBody, + contentType: "application/json", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + OkObjectResult okResult = (OkObjectResult)result; + EmbedDocumentResponse[]? responses = okResult.Value as EmbedDocumentResponse[]; + Assert.IsNotNull(responses); + Assert.AreEqual("doc-1", responses[0].Key); + Assert.IsTrue(responses[0].Data.Length >= 3, "Text should be chunked into at least 3 embeddings with 500 char chunks"); + } + + /// + /// Tests that query parameter $chunking.size-chars overrides config. + /// + [TestMethod] + public async Task PostAsync_ChunkingQueryParameter_OverridesChunkSize() + { + // Arrange — controller sends all chunks as a single batch per document + float[] embedding = new[] { 0.1f, 0.2f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((string[] texts, CancellationToken _) => + new EmbeddingBatchResult(true, texts.Select(_ => embedding).ToArray())); + + string text = new string('A', 1000); + string requestBody = $$""" + [ + {"key": "doc-1", "text": "{{text}}"} + ] + """; + + EmbeddingController controller = CreateController( + requestPath: "/embed?$chunking.enabled=true&$chunking.size-chars=300&$chunking.overlap-chars=0", + requestBody: requestBody, + contentType: "application/json", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + OkObjectResult okResult = (OkObjectResult)result; + EmbedDocumentResponse[]? responses = okResult.Value as EmbedDocumentResponse[]; + Assert.IsNotNull(responses); + // 1000 chars with 300-char chunks and no overlap = 4 chunks (300, 300, 300, 100) + Assert.IsTrue(responses[0].Data.Length >= 4, $"Expected at least 4 chunks, but got {responses[0].Data.Length}"); + } + + /// + /// Tests that query parameter $chunking.overlap-chars is respected. + /// + [TestMethod] + public async Task PostAsync_ChunkingQueryParameter_OverridesOverlapChars() + { + // Arrange — capture the chunks batch to verify overlap + float[] embedding = new[] { 0.1f, 0.2f }; + List capturedBatches = new(); + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((string[] texts, CancellationToken _) => + new EmbeddingBatchResult(true, texts.Select(_ => embedding).ToArray())) + .Callback((texts, _) => capturedBatches.Add(texts)); + + string text = "0123456789" + "ABCDEFGHIJ" + "abcdefghij"; // 30 chars + string requestBody = $$""" + [ + {"key": "doc-1", "text": "{{text}}"} + ] + """; + + EmbeddingController controller = CreateController( + requestPath: "/embed?$chunking.enabled=true&$chunking.size-chars=15&$chunking.overlap-chars=5", + requestBody: requestBody, + contentType: "application/json", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + Assert.IsTrue(capturedBatches.Count > 0, "TryEmbedBatchAsync should be called"); + string[] chunks = capturedBatches[0]; + Assert.IsTrue(chunks.Length >= 2, "Should have multiple chunks"); + + // Check overlap: last 5 chars of chunk[i] should match first 5 chars of chunk[i+1] + if (chunks.Length >= 2) + { + string chunk1End = chunks[0].Substring(Math.Max(0, chunks[0].Length - 5)); + string chunk2Start = chunks[1].Substring(0, Math.Min(5, chunks[1].Length)); + Assert.AreEqual(chunk1End, chunk2Start, "Chunks should have overlapping content"); + } + } + + /// + /// Tests that $chunking.enabled=false disables chunking even if config enables it. + /// + [TestMethod] + public async Task PostAsync_ChunkingQueryParameter_DisablesChunking() + { + // Arrange — controller calls TryEmbedBatchAsync; with chunking disabled the batch has 1 element + float[] embedding = new[] { 0.1f, 0.2f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((string[] texts, CancellationToken _) => + new EmbeddingBatchResult(true, texts.Select(_ => embedding).ToArray())); + + string longText = new string('A', 2000); + string requestBody = $$""" + [ + {"key": "doc-1", "text": "{{longText}}"} + ] + """; + + EmbeddingsEndpointOptions endpointOptions = new(enabled: true, roles: new[] { "anonymous" }); + EmbeddingsChunkingOptions chunkingOptions = new(Enabled: true, SizeChars: 500, OverlapChars: 100); + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key", + Enabled: true, + Endpoint: endpointOptions, + Chunking: chunkingOptions); + + Mock mockProvider = CreateMockConfigProvider( + embeddingsOptions: embeddingsOptions, + hostMode: HostMode.Development); + + EmbeddingController controller = new( + mockProvider.Object, + _mockLogger.Object, + _mockEmbeddingService.Object); + + controller.ControllerContext = CreateControllerContext( + "/embed?$chunking.enabled=false", + requestBody, + "application/json"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + OkObjectResult okResult = (OkObjectResult)result; + EmbedDocumentResponse[]? responses = okResult.Value as EmbedDocumentResponse[]; + Assert.IsNotNull(responses); + Assert.AreEqual(1, responses[0].Data.Length, "Should not chunk when disabled via query parameter"); + } + + /// + /// Tests that empty document array returns BadRequest. + /// + [TestMethod] + public async Task PostAsync_ReturnsBadRequest_ForEmptyDocumentArray() + { + // Arrange + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "[]", + contentType: "application/json", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.BadRequest, (int)value!.error.status); + } + + /// + /// Tests that document with missing key returns InternalServerError. + /// + [TestMethod] + public async Task PostAsync_HandlesDocumentWithMissingKey() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new EmbeddingResult(true, embedding)); + + string requestBody = """ + [ + {"text": "Document without key"} + ] + """; + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: requestBody, + contentType: "application/json", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert - document without key should be rejected with 400 + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.BadRequest, (int)value!.error.status); + } + + /// + /// Tests that document with empty text is skipped or returns error. + /// + [TestMethod] + public async Task PostAsync_HandlesDocumentWithEmptyText() + { + // Arrange + string requestBody = """ + [ + {"key": "doc-1", "text": ""} + ] + """; + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: requestBody, + contentType: "application/json", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert - empty text should result in a 400 error + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.BadRequest, (int)value!.error.status); + } + + /// + /// Tests that chunking respects minimum chunk size. + /// + [TestMethod] + public async Task PostAsync_ChunkingHandlesVerySmallChunkSize() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((string[] texts, CancellationToken _) => + new EmbeddingBatchResult(true, texts.Select(_ => embedding).ToArray())); + + string requestBody = """ + [ + {"key": "doc-1", "text": "Short"} + ] + """; + + EmbeddingController controller = CreateController( + requestPath: "/embed?$chunking.enabled=true&$chunking.size-chars=1", + requestBody: requestBody, + contentType: "application/json", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert — size=1 produces one chunk per character; must not crash + Assert.IsNotNull(result, "Result should not be null"); + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + } + + /// + /// Tests chunking with overlap larger than chunk size. + /// + [TestMethod] + public async Task PostAsync_ChunkingHandlesOverlapLargerThanChunkSize() + { + // Arrange — EffectiveSizeChars clamps to overlap+1, so chunking terminates safely + float[] embedding = new[] { 0.1f, 0.2f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((string[] texts, CancellationToken _) => + new EmbeddingBatchResult(true, texts.Select(_ => embedding).ToArray())); + + string text = new string('A', 100); + string requestBody = $$""" + [ + {"key": "doc-1", "text": "{{text}}"} + ] + """; + + EmbeddingController controller = CreateController( + requestPath: "/embed?$chunking.enabled=true&$chunking.size-chars=50&$chunking.overlap-chars=60", + requestBody: requestBody, + contentType: "application/json", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert — overlap clamped via EffectiveSizeChars; result must be Ok + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + } + + /// + /// Tests that failed embeddings in document array process are handled. + /// + [TestMethod] + public async Task PostAsync_HandlesEmbeddingFailure_InDocumentArray() + { + // Arrange — first doc succeeds, second fails; controller uses TryEmbedBatchAsync per doc + float[] embedding = new[] { 0.1f, 0.2f }; + _mockEmbeddingService + .SetupSequence(s => s.TryEmbedBatchAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new EmbeddingBatchResult(true, new[] { embedding })) + .ReturnsAsync(new EmbeddingBatchResult(false, null, "Provider error")); + + string requestBody = """ + [ + {"key": "doc-1", "text": "First document"}, + {"key": "doc-2", "text": "Second document"} + ] + """; + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: requestBody, + contentType: "application/json", + hostMode: HostMode.Development); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert - should return error when any embedding fails + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult jsonResult = (JsonResult)result; + dynamic? value = jsonResult.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.InternalServerError, (int)value!.error.status); + } + + #endregion + + #region Invalid Query Parameter Tests + + /// + /// Tests that an invalid $chunking.enabled value returns BadRequest. + /// + [TestMethod] + public async Task PostAsync_ReturnsBadRequest_ForInvalidChunkingEnabled() + { + EmbeddingController controller = CreateController( + requestPath: "/embed?$chunking.enabled=notabool", + requestBody: "test", + hostMode: HostMode.Development); + + IActionResult result = await controller.PostAsync(); + + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult bad = (JsonResult)result; + dynamic? value = bad.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.BadRequest, (int)value!.error.status); + Assert.IsTrue(bad.Value?.ToString()?.Contains("$chunking.enabled") == true); + } + + /// + /// Tests that a non-positive $chunking.size-chars returns BadRequest. + /// + [TestMethod] + public async Task PostAsync_ReturnsBadRequest_ForNonPositiveChunkSize() + { + EmbeddingController controller = CreateController( + requestPath: "/embed?$chunking.size-chars=0", + requestBody: "test", + hostMode: HostMode.Development); + + IActionResult result = await controller.PostAsync(); + + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult bad = (JsonResult)result; + dynamic? value = bad.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.BadRequest, (int)value!.error.status); + Assert.IsTrue(bad.Value?.ToString()?.Contains("$chunking.size-chars") == true); + } + + /// + /// Tests that a negative $chunking.overlap-chars returns BadRequest. + /// + [TestMethod] + public async Task PostAsync_ReturnsBadRequest_ForNegativeOverlapChars() + { + EmbeddingController controller = CreateController( + requestPath: "/embed?$chunking.overlap-chars=-1", + requestBody: "test", + hostMode: HostMode.Development); + + IActionResult result = await controller.PostAsync(); + + Assert.IsInstanceOfType(result, typeof(JsonResult)); + JsonResult bad = (JsonResult)result; + dynamic? value = bad.Value; + Assert.IsNotNull(value); + Assert.AreEqual((int)HttpStatusCode.BadRequest, (int)value!.error.status); + Assert.IsTrue(bad.Value?.ToString()?.Contains("$chunking.overlap-chars") == true); + } + + #endregion + + #region Single Text with Chunking Tests + + /// + /// Tests that a plain-text body with chunking enabled is routed through the + /// document-array path and returns multiple embeddings. + /// + [TestMethod] + public async Task PostAsync_SingleText_WithChunkingEnabled_ReturnsDocumentResponse() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((string[] texts, CancellationToken _) => + new EmbeddingBatchResult(true, texts.Select(_ => embedding).ToArray())); + + string longText = new string('X', 1500); + + EmbeddingsEndpointOptions endpointOptions = new(enabled: true, roles: new[] { "anonymous" }); + EmbeddingsChunkingOptions chunkingOptions = new(Enabled: true, SizeChars: 1000, OverlapChars: 250); + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key", + Enabled: true, + Endpoint: endpointOptions, + Chunking: chunkingOptions); + + Mock mockProvider = CreateMockConfigProvider( + embeddingsOptions: embeddingsOptions, + hostMode: HostMode.Development); + + EmbeddingController controller = new(mockProvider.Object, _mockLogger.Object, _mockEmbeddingService.Object); + controller.ControllerContext = CreateControllerContext("/embed", longText, "text/plain"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert — chunking routes through document-array path; returns EmbedDocumentResponse[] + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + OkObjectResult okResult = (OkObjectResult)result; + EmbedDocumentResponse[]? responses = okResult.Value as EmbedDocumentResponse[]; + Assert.IsNotNull(responses, "Chunked single-text should return EmbedDocumentResponse[]"); + Assert.AreEqual("input", responses[0].Key); + Assert.IsTrue(responses[0].Data.Length > 1, "Text should be split into multiple chunks"); + } + + /// + /// Tests that a plain-text body with chunking disabled returns the legacy EmbeddingResponse. + /// + [TestMethod] + public async Task PostAsync_SingleText_WithChunkingDisabled_ReturnsEmbeddingResponse() + { + float[] embedding = new[] { 0.1f, 0.2f }; + SetupSuccessfulEmbedding(embedding); + + EmbeddingController controller = CreateController( + requestPath: "/embed", + requestBody: "hello world", + contentType: "text/plain", + hostMode: HostMode.Development); + + IActionResult result = await controller.PostAsync(); + + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + OkObjectResult okResult = (OkObjectResult)result; + Assert.IsInstanceOfType(okResult.Value, typeof(EmbeddingResponse)); + } + + #endregion + + #region Accept: text/plain Consistency with Chunking Tests + + /// + /// Single text + chunking enabled + Accept: text/plain must return ContentResult (not JSON), + /// with one line per chunk where each line is comma-separated floats. + /// This validates that the Accept header is honoured consistently regardless of whether + /// chunking routes through the document-array path. + /// + [TestMethod] + public async Task PostAsync_SingleText_ChunkingEnabled_AcceptTextPlain_ReturnsPlainTextLines() + { + // Arrange — a 1500-char text with 1000-char chunks and no overlap produces exactly 2 chunks + float[] embedding = new[] { 0.1f, 0.2f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((string[] texts, CancellationToken _) => + new EmbeddingBatchResult(true, texts.Select(_ => embedding).ToArray())); + + EmbeddingController controller = CreateControllerWithChunking( + requestBody: new string('X', 1500), + acceptHeader: "text/plain"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert — ContentResult, not OkObjectResult + Assert.IsInstanceOfType(result, typeof(ContentResult)); + ContentResult contentResult = (ContentResult)result; + Assert.AreEqual("text/plain", contentResult.ContentType); + Assert.IsNotNull(contentResult.Content); + + // Two chunks → two newline-separated lines + string[] lines = contentResult.Content!.Split('\n'); + Assert.AreEqual(2, lines.Length, "Each chunk produces one line."); + foreach (string line in lines) + { + Assert.IsTrue(line.Contains(','), "Each line must contain comma-separated floats."); + } + } + + /// + /// Validates the exact text/plain format for a chunked single-text request: + /// line N contains the comma-separated floats of chunk N's embedding vector. + /// + [TestMethod] + public async Task PostAsync_SingleText_ChunkingEnabled_AcceptTextPlain_ExactLineFormat() + { + // Arrange — deterministic embeddings: chunk 0 → [0.1, 0.2, 0.3], chunk 1 → [0.4, 0.5, 0.6] + float[] chunkEmbedding1 = new[] { 0.1f, 0.2f, 0.3f }; + float[] chunkEmbedding2 = new[] { 0.4f, 0.5f, 0.6f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync( + It.Is(t => t.Length == 2), + It.IsAny())) + .ReturnsAsync(new EmbeddingBatchResult(true, new[] { chunkEmbedding1, chunkEmbedding2 })); + + // 1500 chars, 1000-char chunk size, 0 overlap → exactly 2 chunks sent as one batch + EmbeddingController controller = CreateControllerWithChunking( + requestBody: new string('X', 1500), + acceptHeader: "text/plain"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(ContentResult)); + ContentResult contentResult = (ContentResult)result; + string[] lines = contentResult.Content!.Split('\n'); + Assert.AreEqual(2, lines.Length); + Assert.AreEqual("0.1,0.2,0.3", lines[0]); + Assert.AreEqual("0.4,0.5,0.6", lines[1]); + } + + /// + /// Single text + chunking enabled + no Accept header must return JSON (OkObjectResult), + /// preserving the default JSON behaviour even when chunking is active. + /// + [TestMethod] + public async Task PostAsync_SingleText_ChunkingEnabled_NoAcceptHeader_ReturnsJson() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((string[] texts, CancellationToken _) => + new EmbeddingBatchResult(true, texts.Select(_ => embedding).ToArray())); + + EmbeddingController controller = CreateControllerWithChunking( + requestBody: new string('X', 1500), + acceptHeader: null); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert — no Accept header → JSON (EmbedDocumentResponse[]) + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + OkObjectResult okResult = (OkObjectResult)result; + Assert.IsInstanceOfType(okResult.Value, typeof(EmbedDocumentResponse[])); + } + + /// + /// Single text + chunking enabled + Accept: application/json must return JSON, + /// consistent with the non-chunked path where JSON wins over text/plain. + /// + [TestMethod] + public async Task PostAsync_SingleText_ChunkingEnabled_AcceptJson_ReturnsJson() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((string[] texts, CancellationToken _) => + new EmbeddingBatchResult(true, texts.Select(_ => embedding).ToArray())); + + EmbeddingController controller = CreateControllerWithChunking( + requestBody: new string('X', 1500), + acceptHeader: "application/json"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + OkObjectResult okResult = (OkObjectResult)result; + Assert.IsInstanceOfType(okResult.Value, typeof(EmbedDocumentResponse[])); + } + + /// + /// Single text + chunking enabled + Accept: text/plain, application/json → JSON wins, + /// matching the same precedence rule applied in the non-chunked single-text path. + /// + [TestMethod] + public async Task PostAsync_SingleText_ChunkingEnabled_AcceptBothJsonAndTextPlain_JsonWins() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f }; + _mockEmbeddingService + .Setup(s => s.TryEmbedBatchAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((string[] texts, CancellationToken _) => + new EmbeddingBatchResult(true, texts.Select(_ => embedding).ToArray())); + + EmbeddingController controller = CreateControllerWithChunking( + requestBody: new string('X', 1500), + acceptHeader: "text/plain, application/json"); + + // Act + IActionResult result = await controller.PostAsync(); + + // Assert — JSON takes precedence + Assert.IsInstanceOfType(result, typeof(OkObjectResult)); + OkObjectResult okResult = (OkObjectResult)result; + Assert.IsInstanceOfType(okResult.Value, typeof(EmbedDocumentResponse[])); + } + + /// + /// Helper: creates a controller wired with chunking enabled (1000-char chunks, no overlap) + /// and the class-level mock embedding service. + /// + private EmbeddingController CreateControllerWithChunking( + string requestBody, + string? acceptHeader, + int sizeChars = 1000, + int overlapChars = 0) + { + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key", + Enabled: true, + Endpoint: new EmbeddingsEndpointOptions(enabled: true, roles: new[] { "anonymous" }), + Chunking: new EmbeddingsChunkingOptions(Enabled: true, SizeChars: sizeChars, OverlapChars: overlapChars)); + + Mock mockProvider = CreateMockConfigProvider( + embeddingsOptions: embeddingsOptions, + hostMode: HostMode.Development); + + EmbeddingController controller = new(mockProvider.Object, _mockLogger.Object, _mockEmbeddingService.Object); + controller.ControllerContext = CreateControllerContext( + "/embed", + requestBody, + contentType: "text/plain", + acceptHeader: acceptHeader); + return controller; + } + + #endregion + + #region Helper Methods + + /// + /// Sets up the mock embedding service to return a successful result with the given embedding. + /// + private void SetupSuccessfulEmbedding(float[] embedding) + { + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new EmbeddingResult(true, embedding)); + } + + /// + /// Sentinel array to indicate the test wants to use config defaults (not test defaults). + /// Use this in tests that explicitly want to test the default role behavior. + /// + private static readonly string[] UseConfigDefault = Array.Empty(); + + /// + /// Creates an EmbeddingController with all the necessary mocks wired up. + /// + private EmbeddingController CreateController( + string requestPath, + string? requestBody = null, + string? contentType = "text/plain", + HostMode hostMode = HostMode.Development, + string[]? endpointRoles = null, + string? clientRole = null, + IEmbeddingService? embeddingService = null, + bool useClassMockService = true, + string? acceptHeader = null) + { + // Determine roles to use: + // - If UseConfigDefault sentinel: pass null to use actual config defaults + // - If null: default to anonymous for test convenience + // - Otherwise: use provided roles + string[]? rolesToUse; + if (ReferenceEquals(endpointRoles, UseConfigDefault)) + { + rolesToUse = null; // Will use config default ["authenticated"] + } + else + { + rolesToUse = endpointRoles ?? new[] { "anonymous" }; // Test default for convenience + } + + EmbeddingsEndpointOptions endpointOptions = new( + enabled: true, + roles: rolesToUse); + + EmbeddingsOptions embeddingsOptions = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key", + Enabled: true, + Endpoint: endpointOptions); + + Mock mockProvider = CreateMockConfigProvider( + embeddingsOptions: embeddingsOptions, + hostMode: hostMode); + + // If useClassMockService is true and no explicit service provided, use the class-level mock + IEmbeddingService? serviceToUse = useClassMockService && embeddingService is null + ? _mockEmbeddingService.Object + : embeddingService; + + EmbeddingController controller = new( + mockProvider.Object, + _mockLogger.Object, + serviceToUse); + + controller.ControllerContext = CreateControllerContext( + requestPath, + requestBody, + contentType, + clientRole, + acceptHeader); + + return controller; + } + + /// + /// Creates a mock RuntimeConfigProvider that returns a config with the specified embeddings and host options. + /// + private static Mock CreateMockConfigProvider( + EmbeddingsOptions? embeddingsOptions, + HostMode hostMode = HostMode.Development) + { + HostOptions hostOptions = new( + Cors: null, + Authentication: null, + Mode: hostMode); + + RuntimeOptions runtimeOptions = new( + Rest: null, + GraphQL: null, + Mcp: null, + Host: hostOptions, + Embeddings: embeddingsOptions); + + DataSource dataSource = new(DatabaseType.MSSQL, string.Empty); + RuntimeEntities entities = new(new System.Collections.Generic.Dictionary()); + + RuntimeConfig config = new( + Schema: null, + DataSource: dataSource, + Entities: entities, + Runtime: runtimeOptions); + + Mock mockLoader = new(null, null); + Mock mockProvider = new(mockLoader.Object); + mockProvider + .Setup(p => p.GetConfig()) + .Returns(config); + + return mockProvider; + } + + /// + /// Creates a ControllerContext with a configured HttpContext for testing. + /// + private static ControllerContext CreateControllerContext( + string requestPath, + string? requestBody = null, + string? contentType = "text/plain", + string? clientRole = null, + string? acceptHeader = null) + { + DefaultHttpContext httpContext = new(); + + // Parse path and query string + int queryIndex = requestPath.IndexOf('?'); + if (queryIndex >= 0) + { + httpContext.Request.Path = requestPath.Substring(0, queryIndex); + httpContext.Request.QueryString = new QueryString(requestPath.Substring(queryIndex)); + } + else + { + httpContext.Request.Path = requestPath; + } + + httpContext.Request.Method = "POST"; + httpContext.Request.ContentType = contentType; + + if (requestBody is not null) + { + byte[] bodyBytes = Encoding.UTF8.GetBytes(requestBody); + httpContext.Request.Body = new MemoryStream(bodyBytes); + httpContext.Request.ContentLength = bodyBytes.Length; + } + else + { + httpContext.Request.Body = new MemoryStream(); + } + + if (!string.IsNullOrEmpty(clientRole)) + { + httpContext.Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER] = clientRole; + } + + if (!string.IsNullOrEmpty(acceptHeader)) + { + httpContext.Request.Headers.Accept = acceptHeader; + } + + return new ControllerContext + { + HttpContext = httpContext + }; + } + + #endregion +} + diff --git a/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs b/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs new file mode 100644 index 0000000000..b780aa608f --- /dev/null +++ b/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs @@ -0,0 +1,1486 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; +using Azure.DataApiBuilder.Core.Services.Embeddings; +using Microsoft.Extensions.Logging; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using Moq.Protected; +using ZiggyCreatures.Caching.Fusion; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests; + +/// +/// Unit tests for EmbeddingService. +/// +[TestClass] +public class EmbeddingServiceTests +{ + private Mock> _mockLogger = null!; + private Mock _mockCache = null!; + + [TestInitialize] + public void Setup() + { + _mockLogger = new Mock>(); + _mockCache = new Mock(); + } + + /// + /// Tests that IsEnabled returns true when embeddings are enabled. + /// + [TestMethod] + public void IsEnabled_ReturnsTrue_WhenEnabled() + { + // Arrange + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + + // Assert + Assert.IsTrue(service.IsEnabled); + } + + /// + /// Tests that IsEnabled returns false when embeddings are disabled. + /// + [TestMethod] + public void IsEnabled_ReturnsFalse_WhenDisabled() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://test.openai.azure.com", + ApiKey: "test-api-key", + Enabled: false, + Model: "text-embedding-ada-002"); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + + // Assert + Assert.IsFalse(service.IsEnabled); + } + + /// + /// Tests that TryEmbedAsync returns failure when service is disabled. + /// + [TestMethod] + public async Task TryEmbedAsync_ReturnsFailure_WhenDisabled() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://test.openai.azure.com", + ApiKey: "test-api-key", + Enabled: false, + Model: "text-embedding-ada-002"); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + + // Act + EmbeddingResult result = await service.TryEmbedAsync("test"); + + // Assert + Assert.IsFalse(result.Success); + Assert.IsNull(result.Embedding); + Assert.IsNotNull(result.ErrorMessage); + } + + /// + /// Tests that TryEmbedAsync returns failure for null or empty text. + /// + [DataTestMethod] + [DataRow(null, DisplayName = "Null text returns failure")] + [DataRow("", DisplayName = "Empty text returns failure")] + public async Task TryEmbedAsync_ReturnsFailure_ForNullOrEmptyText(string text) + { + // Arrange + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + + // Act + EmbeddingResult result = await service.TryEmbedAsync(text!); + + // Assert + Assert.IsFalse(result.Success); + } + + /// + /// Tests that EffectiveModel returns the default model for OpenAI when not specified. + /// + [TestMethod] + public void EmbeddingsOptions_OpenAI_DefaultModel() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key"); + + // Assert + Assert.IsNull(options.Model); + Assert.AreEqual(EmbeddingsOptions.DEFAULT_OPENAI_MODEL, options.EffectiveModel); + } + + /// + /// Tests that EffectiveModel returns null for Azure OpenAI when model not specified. + /// + [TestMethod] + public void EmbeddingsOptions_AzureOpenAI_NoDefaultModel() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://my.openai.azure.com", + ApiKey: "test-key"); + + // Assert + Assert.IsNull(options.Model); + Assert.IsNull(options.EffectiveModel); + } + + /// + /// Tests that EffectiveTimeoutMs returns the default timeout when not specified. + /// + [TestMethod] + public void EmbeddingsOptions_DefaultTimeout() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key"); + + // Assert + Assert.IsNull(options.TimeoutMs); + Assert.AreEqual(EmbeddingsOptions.DEFAULT_TIMEOUT_MS, options.EffectiveTimeoutMs); + } + + /// + /// Tests that custom timeout is used when specified. + /// + [TestMethod] + public void EmbeddingsOptions_CustomTimeout() + { + // Arrange + int customTimeout = 60000; + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key", + TimeoutMs: customTimeout); + + // Assert + Assert.AreEqual(customTimeout, options.TimeoutMs); + Assert.AreEqual(customTimeout, options.EffectiveTimeoutMs); + Assert.IsTrue(options.UserProvidedTimeoutMs); + } + + #region Successful API Call Tests + + /// + /// Tests that TryEmbedAsync returns a successful result with correct embedding values + /// when the Azure OpenAI API returns a valid response. + /// + [TestMethod] + public async Task TryEmbedAsync_ReturnsSuccess_WithValidAzureOpenAIResponse() + { + // Arrange + float[] expectedEmbedding = new[] { 0.1f, 0.2f, 0.3f }; + string responseJson = CreateEmbeddingResponseJson(expectedEmbedding); + + Mock mockHandler = CreateMockHttpMessageHandler(HttpStatusCode.OK, responseJson); + HttpClient httpClient = new(mockHandler.Object); + + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + EmbeddingResult result = await service.TryEmbedAsync("test text"); + + // Assert + Assert.IsTrue(result.Success); + Assert.IsNotNull(result.Embedding); + Assert.IsNull(result.ErrorMessage); + CollectionAssert.AreEqual(expectedEmbedding, result.Embedding); + + // Verify HTTP call was made + mockHandler.Protected().Verify( + "SendAsync", + Times.Once(), + ItExpr.IsAny(), + ItExpr.IsAny()); + } + + /// + /// Tests that TryEmbedAsync returns a successful result with correct embedding values + /// when the OpenAI API returns a valid response. + /// + [TestMethod] + public async Task TryEmbedAsync_ReturnsSuccess_WithValidOpenAIResponse() + { + // Arrange + float[] expectedEmbedding = new[] { 0.4f, 0.5f, 0.6f, 0.7f }; + string responseJson = CreateEmbeddingResponseJson(expectedEmbedding); + + Mock mockHandler = CreateMockHttpMessageHandler(HttpStatusCode.OK, responseJson); + HttpClient httpClient = new(mockHandler.Object); + + EmbeddingsOptions options = CreateOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + EmbeddingResult result = await service.TryEmbedAsync("test text"); + + // Assert + Assert.IsTrue(result.Success); + Assert.IsNotNull(result.Embedding); + CollectionAssert.AreEqual(expectedEmbedding, result.Embedding); + } + + /// + /// Tests that EmbedAsync returns the expected embedding array on a successful API call. + /// + [TestMethod] + public async Task EmbedAsync_ReturnsEmbedding_OnSuccessfulApiCall() + { + // Arrange + float[] expectedEmbedding = new[] { 1.0f, 2.0f, 3.0f }; + string responseJson = CreateEmbeddingResponseJson(expectedEmbedding); + + Mock mockHandler = CreateMockHttpMessageHandler(HttpStatusCode.OK, responseJson); + HttpClient httpClient = new(mockHandler.Object); + + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + float[] result = await service.EmbedAsync("test text"); + + // Assert + CollectionAssert.AreEqual(expectedEmbedding, result); + } + + #endregion + + #region HTTP Error Handling Tests + + /// + /// Tests that TryEmbedAsync returns failure with error message when the API returns an HTTP error. + /// + [DataTestMethod] + [DataRow(HttpStatusCode.BadRequest, "Bad Request", DisplayName = "400 Bad Request")] + [DataRow(HttpStatusCode.Unauthorized, "Invalid API key", DisplayName = "401 Unauthorized")] + [DataRow(HttpStatusCode.TooManyRequests, "Rate limit exceeded", DisplayName = "429 Too Many Requests")] + [DataRow(HttpStatusCode.InternalServerError, "Internal server error", DisplayName = "500 Internal Server Error")] + public async Task TryEmbedAsync_ReturnsFailure_OnHttpError(HttpStatusCode statusCode, string errorBody) + { + // Arrange + Mock mockHandler = CreateMockHttpMessageHandler(statusCode, errorBody); + HttpClient httpClient = new(mockHandler.Object); + + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + EmbeddingResult result = await service.TryEmbedAsync("test text"); + + // Assert + Assert.IsFalse(result.Success); + Assert.IsNull(result.Embedding); + Assert.IsNotNull(result.ErrorMessage); + // The error message is a generic message when the service encounters any error + Assert.AreEqual("Failed to generate embedding.", result.ErrorMessage, + $"Error message should be the generic failure message. Actual: {result.ErrorMessage}"); + } + + /// + /// Tests that EmbedAsync throws an exception when the API returns an HTTP error. + /// + [TestMethod] + public async Task EmbedAsync_ThrowsException_OnHttpError() + { + // Arrange + Mock mockHandler = CreateMockHttpMessageHandler( + HttpStatusCode.InternalServerError, "Server error"); + HttpClient httpClient = new(mockHandler.Object); + + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act & Assert + await Assert.ThrowsExceptionAsync( + () => service.EmbedAsync("test text")); + } + + #endregion + + #region Response Parsing and Validation Tests + + /// + /// Tests that TryEmbedAsync returns failure when the API returns an empty data array. + /// + [TestMethod] + public async Task TryEmbedAsync_ReturnsFailure_WhenApiReturnsEmptyData() + { + // Arrange + string responseJson = JsonSerializer.Serialize(new { data = Array.Empty(), model = "test" }); + + Mock mockHandler = CreateMockHttpMessageHandler(HttpStatusCode.OK, responseJson); + HttpClient httpClient = new(mockHandler.Object); + + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + EmbeddingResult result = await service.TryEmbedAsync("test text"); + + // Assert + Assert.IsFalse(result.Success); + Assert.IsNotNull(result.ErrorMessage); + } + + /// + /// Tests that TryEmbedAsync returns failure when the API returns null data. + /// + [TestMethod] + public async Task TryEmbedAsync_ReturnsFailure_WhenApiReturnsNullData() + { + // Arrange + string responseJson = JsonSerializer.Serialize(new { model = "test" }); + + Mock mockHandler = CreateMockHttpMessageHandler(HttpStatusCode.OK, responseJson); + HttpClient httpClient = new(mockHandler.Object); + + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + EmbeddingResult result = await service.TryEmbedAsync("test text"); + + // Assert + Assert.IsFalse(result.Success); + } + + /// + /// Tests that TryEmbedBatchAsync returns failure when the API returns a mismatched number + /// of embeddings compared to the input count. + /// + [TestMethod] + public async Task TryEmbedBatchAsync_ReturnsFailure_WhenEmbeddingCountMismatches() + { + // Arrange - send 2 texts but API returns 1 embedding + string responseJson = CreateEmbeddingResponseJson(new[] { 0.1f, 0.2f }); // single embedding + + Mock mockHandler = CreateMockHttpMessageHandler(HttpStatusCode.OK, responseJson); + HttpClient httpClient = new(mockHandler.Object); + + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + EmbeddingBatchResult result = await service.TryEmbedBatchAsync(new[] { "text1", "text2" }); + + // Assert + Assert.IsFalse(result.Success); + Assert.IsNotNull(result.ErrorMessage); + } + + /// + /// Tests that TryEmbedBatchAsync returns failure when the API returns out-of-range indices. + /// + [TestMethod] + public async Task TryEmbedBatchAsync_ReturnsFailure_WhenIndicesOutOfRange() + { + // Arrange - 1 text but embedding has index 5 + string responseJson = CreateEmbeddingResponseJsonWithIndices( + new[] { (5, new[] { 0.1f, 0.2f }) }); + + Mock mockHandler = CreateMockHttpMessageHandler(HttpStatusCode.OK, responseJson); + HttpClient httpClient = new(mockHandler.Object); + + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + EmbeddingBatchResult result = await service.TryEmbedBatchAsync(new[] { "text1" }); + + // Assert + Assert.IsFalse(result.Success); + Assert.IsNotNull(result.ErrorMessage); + } + + /// + /// Tests that TryEmbedBatchAsync returns failure when the API returns duplicate indices. + /// + [TestMethod] + public async Task TryEmbedBatchAsync_ReturnsFailure_WhenDuplicateIndices() + { + // Arrange - 2 texts but both embeddings have index 0 + string responseJson = CreateEmbeddingResponseJsonWithIndices( + new[] { (0, new[] { 0.1f, 0.2f }), (0, new[] { 0.3f, 0.4f }) }); + + Mock mockHandler = CreateMockHttpMessageHandler(HttpStatusCode.OK, responseJson); + HttpClient httpClient = new(mockHandler.Object); + + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + EmbeddingBatchResult result = await service.TryEmbedBatchAsync(new[] { "text1", "text2" }); + + // Assert + Assert.IsFalse(result.Success); + Assert.IsNotNull(result.ErrorMessage); + } + + /// + /// Tests that batch embeddings are returned in the correct order even when the API + /// returns them out of order (by index). + /// + [TestMethod] + public async Task TryEmbedBatchAsync_ReturnsCorrectOrder_WhenApiReturnsOutOfOrder() + { + // Arrange - API returns index 1 before index 0 + float[] embedding0 = new[] { 0.1f, 0.2f }; + float[] embedding1 = new[] { 0.3f, 0.4f }; + string responseJson = CreateEmbeddingResponseJsonWithIndices( + new[] { (1, embedding1), (0, embedding0) }); + + Mock mockHandler = CreateMockHttpMessageHandler(HttpStatusCode.OK, responseJson); + HttpClient httpClient = new(mockHandler.Object); + + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + EmbeddingBatchResult result = await service.TryEmbedBatchAsync(new[] { "text0", "text1" }); + + // Assert + Assert.IsTrue(result.Success); + Assert.IsNotNull(result.Embeddings); + Assert.AreEqual(2, result.Embeddings.Length); + CollectionAssert.AreEqual(embedding0, result.Embeddings[0]); + CollectionAssert.AreEqual(embedding1, result.Embeddings[1]); + } + + #endregion + + #region Cache Hit/Miss Tests + + /// + /// Tests that the second call to TryEmbedAsync with the same text returns the cached result + /// and does not make a second API call. + /// + [TestMethod] + public async Task TryEmbedAsync_ReturnsCachedResult_OnSecondCallWithSameText() + { + // Arrange + float[] expectedEmbedding = new[] { 0.1f, 0.2f, 0.3f }; + string responseJson = CreateEmbeddingResponseJson(expectedEmbedding); + + int callCount = 0; + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(() => + { + callCount++; + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(responseJson, Encoding.UTF8, "application/json") + }; + }); + + HttpClient httpClient = new(mockHandler.Object); + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act - first call triggers API + EmbeddingResult result1 = await service.TryEmbedAsync("same text"); + // Act - second call should use cache + EmbeddingResult result2 = await service.TryEmbedAsync("same text"); + + // Assert + Assert.IsTrue(result1.Success); + Assert.IsTrue(result2.Success); + CollectionAssert.AreEqual(expectedEmbedding, result1.Embedding); + CollectionAssert.AreEqual(expectedEmbedding, result2.Embedding); + Assert.AreEqual(1, callCount, "HTTP API should only be called once; second call should use cache."); + } + + /// + /// Tests that different texts result in separate API calls (cache misses). + /// + [TestMethod] + public async Task TryEmbedAsync_MakesSeparateApiCalls_ForDifferentTexts() + { + // Arrange + float[] embedding1 = new[] { 0.1f, 0.2f }; + float[] embedding2 = new[] { 0.3f, 0.4f }; + + int callCount = 0; + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(() => + { + callCount++; + float[] embedding = callCount == 1 ? embedding1 : embedding2; + string json = CreateEmbeddingResponseJson(embedding); + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(json, Encoding.UTF8, "application/json") + }; + }); + + HttpClient httpClient = new(mockHandler.Object); + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + EmbeddingResult result1 = await service.TryEmbedAsync("text one"); + EmbeddingResult result2 = await service.TryEmbedAsync("text two"); + + // Assert + Assert.IsTrue(result1.Success); + Assert.IsTrue(result2.Success); + Assert.AreEqual(2, callCount, "Each unique text should trigger a separate API call."); + } + + #endregion + + #region Batch Embedding Tests + + /// + /// Tests that TryEmbedBatchAsync returns success with correct embeddings for multiple texts. + /// + [TestMethod] + public async Task TryEmbedBatchAsync_ReturnsSuccess_ForMultipleTexts() + { + // Arrange + float[] embedding0 = new[] { 0.1f, 0.2f }; + float[] embedding1 = new[] { 0.3f, 0.4f }; + float[] embedding2 = new[] { 0.5f, 0.6f }; + + string responseJson = CreateEmbeddingResponseJsonWithIndices(new[] + { + (0, embedding0), + (1, embedding1), + (2, embedding2) + }); + + Mock mockHandler = CreateMockHttpMessageHandler(HttpStatusCode.OK, responseJson); + HttpClient httpClient = new(mockHandler.Object); + + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + EmbeddingBatchResult result = await service.TryEmbedBatchAsync(new[] { "text0", "text1", "text2" }); + + // Assert + Assert.IsTrue(result.Success); + Assert.IsNotNull(result.Embeddings); + Assert.AreEqual(3, result.Embeddings.Length); + CollectionAssert.AreEqual(embedding0, result.Embeddings[0]); + CollectionAssert.AreEqual(embedding1, result.Embeddings[1]); + CollectionAssert.AreEqual(embedding2, result.Embeddings[2]); + } + + /// + /// Tests that TryEmbedBatchAsync returns failure when the service is disabled. + /// + [TestMethod] + public async Task TryEmbedBatchAsync_ReturnsFailure_WhenDisabled() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://test.openai.azure.com", + ApiKey: "test-api-key", + Enabled: false, + Model: "text-embedding-ada-002"); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + + // Act + EmbeddingBatchResult result = await service.TryEmbedBatchAsync(new[] { "text1" }); + + // Assert + Assert.IsFalse(result.Success); + Assert.IsNull(result.Embeddings); + Assert.IsNotNull(result.ErrorMessage); + } + + /// + /// Tests that TryEmbedBatchAsync returns failure for null texts array. + /// + [TestMethod] + public async Task TryEmbedBatchAsync_ReturnsFailure_ForNullTexts() + { + // Arrange + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + + // Act + EmbeddingBatchResult result = await service.TryEmbedBatchAsync(null!); + + // Assert + Assert.IsFalse(result.Success); + Assert.IsNull(result.Embeddings); + } + + /// + /// Tests that TryEmbedBatchAsync returns failure for empty texts array. + /// + [TestMethod] + public async Task TryEmbedBatchAsync_ReturnsFailure_ForEmptyTexts() + { + // Arrange + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + + // Act + EmbeddingBatchResult result = await service.TryEmbedBatchAsync(Array.Empty()); + + // Assert + Assert.IsFalse(result.Success); + Assert.IsNull(result.Embeddings); + } + + /// + /// Tests that TryEmbedBatchAsync returns failure when texts array exceeds max batch size. + /// + [TestMethod] + public async Task TryEmbedBatchAsync_ReturnsFailure_WhenTextsExceedMaxBatchSize() + { + // Arrange + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + string[] oversizedTexts = Enumerable.Repeat("chunk", EmbeddingService.MAX_BATCH_TEXT_COUNT + 1).ToArray(); + + // Act + EmbeddingBatchResult result = await service.TryEmbedBatchAsync(oversizedTexts); + + // Assert + Assert.IsFalse(result.Success); + Assert.IsNull(result.Embeddings); + Assert.IsNotNull(result.ErrorMessage); + StringAssert.Contains(result.ErrorMessage, EmbeddingService.MAX_BATCH_TEXT_COUNT.ToString()); + } + + /// + /// Tests that EmbedBatchAsync throws when the service is disabled. + /// + [TestMethod] + public async Task EmbedBatchAsync_Throws_WhenDisabled() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://test.openai.azure.com", + ApiKey: "test-api-key", + Enabled: false, + Model: "text-embedding-ada-002"); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + + // Act & Assert + await Assert.ThrowsExceptionAsync( + () => service.EmbedBatchAsync(new[] { "text1" })); + } + + /// + /// Tests that EmbedBatchAsync throws when texts array exceeds max batch size. + /// + [TestMethod] + public async Task EmbedBatchAsync_Throws_WhenTextsExceedMaxBatchSize() + { + // Arrange + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + string[] oversizedTexts = Enumerable.Repeat("chunk", EmbeddingService.MAX_BATCH_TEXT_COUNT + 1).ToArray(); + + // Act & Assert + ArgumentException exception = await Assert.ThrowsExceptionAsync( + () => service.EmbedBatchAsync(oversizedTexts)); + StringAssert.Contains(exception.Message, EmbeddingService.MAX_BATCH_TEXT_COUNT.ToString()); + } + + /// + /// Tests that EmbedBatchAsync uses cached results for previously embedded texts + /// and only calls the API for uncached texts. + /// + [TestMethod] + public async Task EmbedBatchAsync_OnlyCallsApiForUncachedTexts() + { + // Arrange + float[] embeddingA = new[] { 0.1f, 0.2f }; + float[] embeddingB = new[] { 0.3f, 0.4f }; + float[] embeddingC = new[] { 0.5f, 0.6f }; + + int apiCallCount = 0; + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync((HttpRequestMessage request, CancellationToken _) => + { + apiCallCount++; + string body = request.Content!.ReadAsStringAsync().Result; + + string json; + if (apiCallCount == 1) + { + // First call embeds "textA" via TryEmbedAsync + json = CreateEmbeddingResponseJson(embeddingA); + } + else + { + // Second call should only embed "textB" and "textC" (textA is cached) + json = CreateEmbeddingResponseJsonWithIndices(new[] + { + (0, embeddingB), + (1, embeddingC) + }); + } + + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(json, Encoding.UTF8, "application/json") + }; + }); + + HttpClient httpClient = new(mockHandler.Object); + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // First: embed "textA" so it's cached + EmbeddingResult preResult = await service.TryEmbedAsync("textA"); + Assert.IsTrue(preResult.Success); + Assert.AreEqual(1, apiCallCount); + + // Act: batch embed ["textA", "textB", "textC"] - textA should come from cache + float[][] batchResults = await service.EmbedBatchAsync(new[] { "textA", "textB", "textC" }); + + // Assert + Assert.AreEqual(2, apiCallCount, "Only 1 additional API call should be made for the 2 uncached texts."); + Assert.AreEqual(3, batchResults.Length); + CollectionAssert.AreEqual(embeddingA, batchResults[0], "textA should come from cache."); + CollectionAssert.AreEqual(embeddingB, batchResults[1]); + CollectionAssert.AreEqual(embeddingC, batchResults[2]); + } + + #endregion + + #region Provider-Specific URL Construction Tests + + /// + /// Tests that the Azure OpenAI provider constructs the correct URL format: + /// {baseUrl}/openai/deployments/{deployment}/embeddings?api-version={version} + /// + [TestMethod] + public async Task AzureOpenAI_BuildsCorrectRequestUrl() + { + // Arrange + float[] expectedEmbedding = new[] { 0.1f, 0.2f }; + string responseJson = CreateEmbeddingResponseJson(expectedEmbedding); + + Uri capturedUri = null!; + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync((HttpRequestMessage request, CancellationToken _) => + { + capturedUri = request.RequestUri!; + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(responseJson, Encoding.UTF8, "application/json") + }; + }); + + HttpClient httpClient = new(mockHandler.Object); + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://myservice.openai.azure.com", + ApiKey: "test-key", + Enabled: true, + Model: "my-deployment", + ApiVersion: "2024-06-01"); + + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + await service.TryEmbedAsync("test"); + + // Assert + Assert.IsNotNull(capturedUri); + Assert.AreEqual( + "https://myservice.openai.azure.com/openai/deployments/my-deployment/embeddings?api-version=2024-06-01", + capturedUri.ToString()); + } + + /// + /// Tests that the OpenAI provider constructs the correct URL format: + /// {baseUrl}/v1/embeddings + /// + [TestMethod] + public async Task OpenAI_BuildsCorrectRequestUrl() + { + // Arrange + float[] expectedEmbedding = new[] { 0.1f, 0.2f }; + string responseJson = CreateEmbeddingResponseJson(expectedEmbedding); + + Uri capturedUri = null!; + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync((HttpRequestMessage request, CancellationToken _) => + { + capturedUri = request.RequestUri!; + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(responseJson, Encoding.UTF8, "application/json") + }; + }); + + HttpClient httpClient = new(mockHandler.Object); + EmbeddingsOptions options = CreateOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + await service.TryEmbedAsync("test"); + + // Assert + Assert.IsNotNull(capturedUri); + Assert.AreEqual("https://api.openai.com/v1/embeddings", capturedUri.ToString()); + } + + /// + /// Tests that Azure OpenAI uses the default API version when none is specified. + /// + [TestMethod] + public async Task AzureOpenAI_UsesDefaultApiVersion_WhenNotSpecified() + { + // Arrange + float[] expectedEmbedding = new[] { 0.1f, 0.2f }; + string responseJson = CreateEmbeddingResponseJson(expectedEmbedding); + + Uri capturedUri = null!; + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync((HttpRequestMessage request, CancellationToken _) => + { + capturedUri = request.RequestUri!; + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(responseJson, Encoding.UTF8, "application/json") + }; + }); + + HttpClient httpClient = new(mockHandler.Object); + EmbeddingsOptions options = CreateAzureOpenAIOptions(); // no explicit api-version + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + await service.TryEmbedAsync("test"); + + // Assert + Assert.IsNotNull(capturedUri); + Assert.IsTrue(capturedUri.ToString().Contains($"api-version={EmbeddingsOptions.DEFAULT_AZURE_API_VERSION}")); + } + + #endregion + + #region Request Body Building Tests + + /// + /// Tests that the OpenAI request body includes the model name. + /// + [TestMethod] + public async Task OpenAI_RequestBody_IncludesModel() + { + // Arrange + float[] expectedEmbedding = new[] { 0.1f, 0.2f }; + string responseJson = CreateEmbeddingResponseJson(expectedEmbedding); + + string capturedRequestBody = null!; + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync((HttpRequestMessage request, CancellationToken _) => + { + capturedRequestBody = request.Content!.ReadAsStringAsync().Result; + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(responseJson, Encoding.UTF8, "application/json") + }; + }); + + HttpClient httpClient = new(mockHandler.Object); + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key", + Enabled: true, + Model: "text-embedding-3-large"); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + await service.TryEmbedAsync("test"); + + // Assert + Assert.IsNotNull(capturedRequestBody); + using JsonDocument doc = JsonDocument.Parse(capturedRequestBody); + Assert.IsTrue(doc.RootElement.TryGetProperty("model", out JsonElement modelElement)); + Assert.AreEqual("text-embedding-3-large", modelElement.GetString()); + } + + /// + /// Tests that the Azure OpenAI request body does NOT include the model name + /// (it's in the URL as the deployment name instead). + /// + [TestMethod] + public async Task AzureOpenAI_RequestBody_DoesNotIncludeModel() + { + // Arrange + float[] expectedEmbedding = new[] { 0.1f, 0.2f }; + string responseJson = CreateEmbeddingResponseJson(expectedEmbedding); + + string capturedRequestBody = null!; + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync((HttpRequestMessage request, CancellationToken _) => + { + capturedRequestBody = request.Content!.ReadAsStringAsync().Result; + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(responseJson, Encoding.UTF8, "application/json") + }; + }); + + HttpClient httpClient = new(mockHandler.Object); + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + await service.TryEmbedAsync("test"); + + // Assert + Assert.IsNotNull(capturedRequestBody); + using JsonDocument doc = JsonDocument.Parse(capturedRequestBody); + Assert.IsFalse(doc.RootElement.TryGetProperty("model", out _), + "Azure OpenAI request body should not contain 'model' property."); + } + + /// + /// Tests that dimensions are included in the request body when specified. + /// + [TestMethod] + public async Task RequestBody_IncludesDimensions_WhenSpecified() + { + // Arrange + float[] expectedEmbedding = new[] { 0.1f, 0.2f }; + string responseJson = CreateEmbeddingResponseJson(expectedEmbedding); + + string capturedRequestBody = null!; + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync((HttpRequestMessage request, CancellationToken _) => + { + capturedRequestBody = request.Content!.ReadAsStringAsync().Result; + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(responseJson, Encoding.UTF8, "application/json") + }; + }); + + HttpClient httpClient = new(mockHandler.Object); + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key", + Enabled: true, + Dimensions: 256); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + await service.TryEmbedAsync("test"); + + // Assert + Assert.IsNotNull(capturedRequestBody); + using JsonDocument doc = JsonDocument.Parse(capturedRequestBody); + Assert.IsTrue(doc.RootElement.TryGetProperty("dimensions", out JsonElement dimElement)); + Assert.AreEqual(256, dimElement.GetInt32()); + } + + /// + /// Tests that dimensions are NOT included in the request body when not specified. + /// + [TestMethod] + public async Task RequestBody_ExcludesDimensions_WhenNotSpecified() + { + // Arrange + float[] expectedEmbedding = new[] { 0.1f, 0.2f }; + string responseJson = CreateEmbeddingResponseJson(expectedEmbedding); + + string capturedRequestBody = null!; + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync((HttpRequestMessage request, CancellationToken _) => + { + capturedRequestBody = request.Content!.ReadAsStringAsync().Result; + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(responseJson, Encoding.UTF8, "application/json") + }; + }); + + HttpClient httpClient = new(mockHandler.Object); + EmbeddingsOptions options = CreateAzureOpenAIOptions(); // no dimensions + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + await service.TryEmbedAsync("test"); + + // Assert + Assert.IsNotNull(capturedRequestBody); + using JsonDocument doc = JsonDocument.Parse(capturedRequestBody); + Assert.IsFalse(doc.RootElement.TryGetProperty("dimensions", out _), + "Request body should not contain 'dimensions' when not specified."); + } + + /// + /// Tests that a single text is sent as a string (not an array) in the request body. + /// + [TestMethod] + public async Task RequestBody_SendsSingleTextAsString() + { + // Arrange + float[] expectedEmbedding = new[] { 0.1f, 0.2f }; + string responseJson = CreateEmbeddingResponseJson(expectedEmbedding); + + string capturedRequestBody = null!; + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync((HttpRequestMessage request, CancellationToken _) => + { + capturedRequestBody = request.Content!.ReadAsStringAsync().Result; + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(responseJson, Encoding.UTF8, "application/json") + }; + }); + + HttpClient httpClient = new(mockHandler.Object); + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + await service.TryEmbedAsync("single text input"); + + // Assert + Assert.IsNotNull(capturedRequestBody); + using JsonDocument doc = JsonDocument.Parse(capturedRequestBody); + Assert.IsTrue(doc.RootElement.TryGetProperty("input", out JsonElement inputElement)); + Assert.AreEqual(JsonValueKind.String, inputElement.ValueKind, + "Single text should be sent as a string, not an array."); + Assert.AreEqual("single text input", inputElement.GetString()); + } + + /// + /// Tests that multiple texts in a batch are sent as an array in the request body. + /// + [TestMethod] + public async Task RequestBody_SendsBatchTextsAsArray() + { + // Arrange + string responseJson = CreateEmbeddingResponseJsonWithIndices(new[] + { + (0, new[] { 0.1f, 0.2f }), + (1, new[] { 0.3f, 0.4f }) + }); + + string capturedRequestBody = null!; + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync((HttpRequestMessage request, CancellationToken _) => + { + capturedRequestBody = request.Content!.ReadAsStringAsync().Result; + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(responseJson, Encoding.UTF8, "application/json") + }; + }); + + HttpClient httpClient = new(mockHandler.Object); + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + await service.TryEmbedBatchAsync(new[] { "text one", "text two" }); + + // Assert + Assert.IsNotNull(capturedRequestBody); + using JsonDocument doc = JsonDocument.Parse(capturedRequestBody); + Assert.IsTrue(doc.RootElement.TryGetProperty("input", out JsonElement inputElement)); + Assert.AreEqual(JsonValueKind.Array, inputElement.ValueKind, + "Batch texts should be sent as an array."); + Assert.AreEqual(2, inputElement.GetArrayLength()); + } + + #endregion + + #region Authentication Header Tests + + /// + /// Tests that Azure OpenAI uses the api-key header for authentication. + /// + [TestMethod] + public async Task AzureOpenAI_UsesApiKeyHeader() + { + // Arrange + float[] expectedEmbedding = new[] { 0.1f, 0.2f }; + string responseJson = CreateEmbeddingResponseJson(expectedEmbedding); + + HttpRequestMessage capturedRequest = null!; + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync((HttpRequestMessage request, CancellationToken _) => + { + capturedRequest = request; + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(responseJson, Encoding.UTF8, "application/json") + }; + }); + + HttpClient httpClient = new(mockHandler.Object); + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://test.openai.azure.com", + ApiKey: "my-azure-key", + Model: "text-embedding-ada-002"); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + await service.TryEmbedAsync("test"); + + // Assert + Assert.IsTrue(httpClient.DefaultRequestHeaders.Contains("api-key"), + "Azure OpenAI should use api-key header."); + IEnumerable values = httpClient.DefaultRequestHeaders.GetValues("api-key"); + Assert.AreEqual("my-azure-key", values.First()); + } + + /// + /// Tests that OpenAI uses the Bearer token Authorization header. + /// + [TestMethod] + public async Task OpenAI_UsesBearerAuthorizationHeader() + { + // Arrange + float[] expectedEmbedding = new[] { 0.1f, 0.2f }; + string responseJson = CreateEmbeddingResponseJson(expectedEmbedding); + + Mock mockHandler = CreateMockHttpMessageHandler(HttpStatusCode.OK, responseJson); + HttpClient httpClient = new(mockHandler.Object); + + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "my-openai-key"); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + await service.TryEmbedAsync("test"); + + // Assert + Assert.IsNotNull(httpClient.DefaultRequestHeaders.Authorization); + Assert.AreEqual("Bearer", httpClient.DefaultRequestHeaders.Authorization.Scheme); + Assert.AreEqual("my-openai-key", httpClient.DefaultRequestHeaders.Authorization.Parameter); + } + + #endregion + + #region Timeout Tests + + /// + /// Tests that TryEmbedAsync returns failure when the HTTP request times out. + /// + [TestMethod] + public async Task TryEmbedAsync_ReturnsFailure_OnTimeout() + { + // Arrange + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ThrowsAsync(new TaskCanceledException("The request was canceled due to the configured HttpClient.Timeout.")); + + HttpClient httpClient = new(mockHandler.Object); + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + using IFusionCache cache = new FusionCache(new FusionCacheOptions()); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, cache); + + // Act + EmbeddingResult result = await service.TryEmbedAsync("test text"); + + // Assert + Assert.IsFalse(result.Success); + Assert.IsNull(result.Embedding); + Assert.IsNotNull(result.ErrorMessage); + } + + /// + /// Tests that the HttpClient timeout is set from the EmbeddingsOptions configuration. + /// + [TestMethod] + public void Constructor_SetsHttpClientTimeout_FromOptions() + { + // Arrange + int customTimeoutMs = 15000; + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://test.openai.azure.com", + ApiKey: "test-key", + Model: "text-embedding-ada-002", + TimeoutMs: customTimeoutMs); + HttpClient httpClient = new(); + + // Act + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + + // Assert + Assert.AreEqual(TimeSpan.FromMilliseconds(customTimeoutMs), httpClient.Timeout); + } + + #endregion + + #region Constructor Validation Tests + + /// + /// Tests that constructor throws when BaseUrl is empty. + /// + [TestMethod] + public void Constructor_Throws_WhenBaseUrlIsEmpty() + { + // Arrange & Act & Assert + Assert.ThrowsException(() => + new EmbeddingService( + new HttpClient(), + new EmbeddingsOptions( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "", + ApiKey: "key"), + _mockLogger.Object, + _mockCache.Object)); + } + + /// + /// Tests that constructor throws when ApiKey is empty. + /// + [TestMethod] + public void Constructor_Throws_WhenApiKeyIsEmpty() + { + // Arrange & Act & Assert + Assert.ThrowsException(() => + new EmbeddingService( + new HttpClient(), + new EmbeddingsOptions( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: ""), + _mockLogger.Object, + _mockCache.Object)); + } + + /// + /// Tests that constructor throws when Azure OpenAI provider is used without a model. + /// + [TestMethod] + public void Constructor_Throws_WhenAzureOpenAIHasNoModel() + { + // Arrange & Act & Assert + Assert.ThrowsException(() => + new EmbeddingService( + new HttpClient(), + new EmbeddingsOptions( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://test.openai.azure.com", + ApiKey: "key"), + _mockLogger.Object, + _mockCache.Object)); + } + + #endregion + + #region Helper Methods + + private static EmbeddingsOptions CreateAzureOpenAIOptions() + { + return new EmbeddingsOptions( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://test.openai.azure.com", + ApiKey: "test-api-key", + Enabled: true, + Model: "text-embedding-ada-002"); + } + + private static EmbeddingsOptions CreateOpenAIOptions() + { + return new EmbeddingsOptions( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-api-key", + Enabled: true); + } + + /// + /// Creates a mock HttpMessageHandler that returns the specified status code and response body. + /// + private static Mock CreateMockHttpMessageHandler(HttpStatusCode statusCode, string responseBody) + { + Mock mockHandler = new(MockBehavior.Strict); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage(statusCode) + { + Content = new StringContent(responseBody, Encoding.UTF8, "application/json") + }); + + return mockHandler; + } + + /// + /// Creates an embedding API response JSON with a single embedding at index 0. + /// + private static string CreateEmbeddingResponseJson(float[] embedding) + { + return CreateEmbeddingResponseJsonWithIndices(new[] { (0, embedding) }); + } + + /// + /// Creates an embedding API response JSON with multiple embeddings at specified indices. + /// + private static string CreateEmbeddingResponseJsonWithIndices((int Index, float[] Embedding)[] embeddings) + { + var data = embeddings.Select(e => new + { + index = e.Index, + embedding = e.Embedding, + @object = "embedding" + }).ToArray(); + + var response = new + { + data, + model = "text-embedding-ada-002", + @object = "list", + usage = new + { + prompt_tokens = 5, + total_tokens = 5 + } + }; + + return JsonSerializer.Serialize(response, new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase + }); + } + + #endregion +} diff --git a/src/Service.Tests/UnitTests/EmbeddingsChunkingOptionsTests.cs b/src/Service.Tests/UnitTests/EmbeddingsChunkingOptionsTests.cs new file mode 100644 index 0000000000..c647258fe2 --- /dev/null +++ b/src/Service.Tests/UnitTests/EmbeddingsChunkingOptionsTests.cs @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests; + +/// +/// Unit tests for EmbeddingsChunkingOptions configuration class. +/// +[TestClass] +public class EmbeddingsChunkingOptionsTests +{ + /// + /// Tests that default values are correctly set. + /// + [TestMethod] + public void Constructor_SetsDefaultValues() + { + // Arrange & Act + EmbeddingsChunkingOptions options = new(Enabled: true); + + // Assert + Assert.IsTrue(options.Enabled); + Assert.AreEqual(EmbeddingsChunkingOptions.DEFAULT_SIZE_CHARS, options.SizeChars); + Assert.AreEqual(EmbeddingsChunkingOptions.DEFAULT_OVERLAP_CHARS, options.OverlapChars); + } + + /// + /// Tests that custom values override defaults. + /// + [TestMethod] + public void Constructor_SetsCustomValues() + { + // Arrange & Act + EmbeddingsChunkingOptions options = new( + Enabled: true, + SizeChars: 500, + OverlapChars: 100); + + // Assert + Assert.IsTrue(options.Enabled); + Assert.AreEqual(500, options.SizeChars); + Assert.AreEqual(100, options.OverlapChars); + } + + /// + /// Tests that EffectiveSizeChars returns configured value when valid. + /// + [TestMethod] + public void EffectiveSizeChars_ReturnsConfiguredValue_WhenValid() + { + // Arrange + EmbeddingsChunkingOptions options = new( + Enabled: true, + SizeChars: 750, + OverlapChars: 50); + + // Act + int effectiveSize = options.EffectiveSizeChars; + + // Assert + Assert.AreEqual(750, effectiveSize); + } + + /// + /// Tests that EffectiveSizeChars ensures size is at least overlap+1 when value is too small. + /// + [TestMethod] + public void EffectiveSizeChars_ReturnsMinimumValid_WhenValueTooSmall() + { + // Arrange + EmbeddingsChunkingOptions options = new( + Enabled: true, + SizeChars: 0, + OverlapChars: 50); + + // Act + int effectiveSize = options.EffectiveSizeChars; + + // Assert - should be at least overlap + 1 + Assert.AreEqual(51, effectiveSize); + } + + /// + /// Tests that EffectiveSizeChars ensures size is at least overlap+1 when value is negative. + /// + [TestMethod] + public void EffectiveSizeChars_ReturnsMinimumValid_WhenValueNegative() + { + // Arrange + EmbeddingsChunkingOptions options = new( + Enabled: true, + SizeChars: -100, + OverlapChars: 50); + + // Act + int effectiveSize = options.EffectiveSizeChars; + + // Assert - should be at least overlap + 1 + Assert.AreEqual(51, effectiveSize); + } + + /// + /// Tests that disabled chunking still has valid configuration. + /// + [TestMethod] + public void Constructor_AllowsDisabledChunking() + { + // Arrange & Act + EmbeddingsChunkingOptions options = new( + Enabled: false, + SizeChars: 500, + OverlapChars: 100); + + // Assert + Assert.IsFalse(options.Enabled); + Assert.AreEqual(500, options.SizeChars); + Assert.AreEqual(100, options.OverlapChars); + } + + /// + /// Tests that zero overlap is valid. + /// + [TestMethod] + public void Constructor_AllowsZeroOverlap() + { + // Arrange & Act + EmbeddingsChunkingOptions options = new( + Enabled: true, + SizeChars: 1000, + OverlapChars: 0); + + // Assert + Assert.AreEqual(0, options.OverlapChars); + } + + /// + /// Tests that negative overlap is clamped to zero. + /// + [TestMethod] + public void Constructor_NegativeOverlapClampedToZero() + { + // Arrange & Act + EmbeddingsChunkingOptions options = new( + Enabled: true, + SizeChars: 1000, + OverlapChars: -50); + + // Assert: negative overlap must be clamped to 0 + Assert.AreEqual(0, options.OverlapChars); + } + + /// + /// Tests that very large chunk sizes are accepted. + /// + [TestMethod] + public void Constructor_AllowsLargeChunkSize() + { + // Arrange & Act + EmbeddingsChunkingOptions options = new( + Enabled: true, + SizeChars: 100000, + OverlapChars: 1000); + + // Assert + Assert.AreEqual(100000, options.SizeChars); + Assert.AreEqual(100000, options.EffectiveSizeChars); + } + + /// + /// Tests that overlap can be larger than chunk size (edge case). + /// + [TestMethod] + public void Constructor_AllowsOverlapLargerThanChunkSize() + { + // Arrange & Act + EmbeddingsChunkingOptions options = new( + Enabled: true, + SizeChars: 100, + OverlapChars: 200); + + // Assert + Assert.AreEqual(100, options.SizeChars); + Assert.AreEqual(200, options.OverlapChars); + } +} diff --git a/src/Service.Tests/UnitTests/EmbeddingsHealthCheckTests.cs b/src/Service.Tests/UnitTests/EmbeddingsHealthCheckTests.cs new file mode 100644 index 0000000000..8a7991287e --- /dev/null +++ b/src/Service.Tests/UnitTests/EmbeddingsHealthCheckTests.cs @@ -0,0 +1,661 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config; +using Azure.DataApiBuilder.Config.HealthCheck; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Services.Embeddings; +using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Service.HealthCheck; +using Microsoft.Extensions.Logging; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests; + +/// +/// Unit tests for the embeddings health check logic in . +/// The private method UpdateEmbeddingsHealthCheckResultsAsync is tested indirectly +/// through the public method. +/// Data source and entity health checks are disabled to isolate embeddings health check behavior. +/// +[TestClass] +public class EmbeddingsHealthCheckTests +{ + private Mock> _mockLogger = null!; + private Mock _mockEmbeddingService = null!; + private HttpUtilities _httpUtilities = null!; + + private const string TIME_EXCEEDED_ERROR_MESSAGE = "The threshold for executing the request has exceeded."; + private const string DIMENSIONS_MISMATCH_ERROR_MESSAGE = "The embedding dimensions do not match the expected dimensions."; + + [TestInitialize] + public void Setup() + { + _mockLogger = new Mock>(); + _mockEmbeddingService = new Mock(); + + // Create HttpUtilities with mocked dependencies. + // HttpUtilities won't be called since data source and entity health checks are disabled. + Mock> httpLogger = new(); + Mock metadataProviderFactory = new(); + Mock mockLoader = new(null, null); + Mock mockConfigProvider = new(mockLoader.Object); + Mock mockHttpClientFactory = new(); + mockHttpClientFactory + .Setup(f => f.CreateClient(It.IsAny())) + .Returns(new HttpClient { BaseAddress = new Uri("http://localhost:5000") }); + + _httpUtilities = new HttpUtilities( + httpLogger.Object, + metadataProviderFactory.Object, + mockConfigProvider.Object, + mockHttpClientFactory.Object); + } + + #region Healthy Scenarios + + /// + /// Validates that when embedding succeeds within threshold and no dimension check is configured, + /// the health check entry reports Healthy status. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_ReportsHealthy_WhenEmbeddingSucceedsWithinThreshold() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f, 0.3f }; + SetupSuccessfulEmbedding(embedding); + + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig(enabled: true, thresholdMs: 60000)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + HealthCheckResultEntry embeddingCheck = GetEmbeddingCheck(report); + Assert.AreEqual(HealthStatus.Healthy, embeddingCheck.Status); + Assert.AreEqual("embeddings", embeddingCheck.Name); + Assert.IsNull(embeddingCheck.Exception); + Assert.IsNotNull(embeddingCheck.ResponseTimeData); + Assert.IsTrue(embeddingCheck.ResponseTimeData!.ResponseTimeMs >= 0); + Assert.AreEqual(60000, embeddingCheck.ResponseTimeData.ThresholdMs); + CollectionAssert.Contains(embeddingCheck.Tags!, HealthCheckConstants.EMBEDDING); + } + + /// + /// Validates that when embedding succeeds and dimensions match the expected value, + /// the health check entry reports Healthy status. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_ReportsHealthy_WhenDimensionsMatch() + { + // Arrange + float[] embedding = new[] { 0.1f, 0.2f, 0.3f }; + SetupSuccessfulEmbedding(embedding); + + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig( + enabled: true, + thresholdMs: 60000, + expectedDimensions: 3)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + HealthCheckResultEntry embeddingCheck = GetEmbeddingCheck(report); + Assert.AreEqual(HealthStatus.Healthy, embeddingCheck.Status); + Assert.IsNull(embeddingCheck.Exception); + } + + /// + /// Validates that the overall report status is Healthy when the only check is a healthy embedding check. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_OverallStatusHealthy_WhenEmbeddingCheckIsHealthy() + { + // Arrange + SetupSuccessfulEmbedding(new[] { 0.1f }); + + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig(enabled: true, thresholdMs: 60000)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + Assert.AreEqual(HealthStatus.Healthy, report.Status); + } + + #endregion + + #region Unhealthy - Time Exceeded + + /// + /// Validates that when the response time exceeds the threshold, + /// the health check entry reports Unhealthy status with the time exceeded error message. + /// Uses a threshold of -1 to guarantee the threshold is always exceeded regardless of execution time. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_ReportsUnhealthy_WhenResponseTimeExceedsThreshold() + { + // Arrange + SetupSuccessfulEmbedding(new[] { 0.1f }); + + // Threshold of -1 guarantees any response time (>=0) will exceed the threshold. + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig(enabled: true, thresholdMs: -1)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + HealthCheckResultEntry embeddingCheck = GetEmbeddingCheck(report); + Assert.AreEqual(HealthStatus.Unhealthy, embeddingCheck.Status); + Assert.IsNotNull(embeddingCheck.Exception); + Assert.IsTrue(embeddingCheck.Exception!.Contains(TIME_EXCEEDED_ERROR_MESSAGE)); + } + + /// + /// Validates that the overall report status is Unhealthy when the embedding check is unhealthy. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_OverallStatusUnhealthy_WhenEmbeddingCheckIsUnhealthy() + { + // Arrange + SetupSuccessfulEmbedding(new[] { 0.1f }); + + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig(enabled: true, thresholdMs: -1)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + Assert.AreEqual(HealthStatus.Unhealthy, report.Status); + } + + #endregion + + #region Unhealthy - Dimensions Mismatch + + /// + /// Validates that when the embedding dimensions don't match the expected value, + /// the health check entry reports Unhealthy status with the dimensions mismatch error message. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_ReportsUnhealthy_WhenDimensionsMismatch() + { + // Arrange: Embedding returns 3 dimensions but config expects 5 + float[] embedding = new[] { 0.1f, 0.2f, 0.3f }; + SetupSuccessfulEmbedding(embedding); + + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig( + enabled: true, + thresholdMs: 60000, + expectedDimensions: 5)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + HealthCheckResultEntry embeddingCheck = GetEmbeddingCheck(report); + Assert.AreEqual(HealthStatus.Unhealthy, embeddingCheck.Status); + Assert.IsNotNull(embeddingCheck.Exception); + Assert.IsTrue(embeddingCheck.Exception!.Contains(DIMENSIONS_MISMATCH_ERROR_MESSAGE)); + Assert.IsTrue(embeddingCheck.Exception.Contains("Expected: 5")); + Assert.IsTrue(embeddingCheck.Exception.Contains("Actual: 3")); + // Response time should still be recorded (not ERROR_RESPONSE_TIME_MS) + Assert.IsTrue(embeddingCheck.ResponseTimeData!.ResponseTimeMs >= 0); + } + + #endregion + + #region Unhealthy - Combined Failures + + /// + /// Validates that when both dimensions mismatch and response time exceeds the threshold, + /// the health check entry reports Unhealthy with both error messages combined. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_ReportsUnhealthy_WhenBothDimensionsMismatchAndTimeExceeded() + { + // Arrange: 3 dimensions, but expect 10; threshold of -1 guarantees time exceeded + float[] embedding = new[] { 0.1f, 0.2f, 0.3f }; + SetupSuccessfulEmbedding(embedding); + + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig( + enabled: true, + thresholdMs: -1, + expectedDimensions: 10)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + HealthCheckResultEntry embeddingCheck = GetEmbeddingCheck(report); + Assert.AreEqual(HealthStatus.Unhealthy, embeddingCheck.Status); + Assert.IsNotNull(embeddingCheck.Exception); + Assert.IsTrue(embeddingCheck.Exception!.Contains(DIMENSIONS_MISMATCH_ERROR_MESSAGE)); + Assert.IsTrue(embeddingCheck.Exception.Contains(TIME_EXCEEDED_ERROR_MESSAGE)); + } + + #endregion + + #region Unhealthy - Embedding Failure + + /// + /// Validates that when the embedding service returns a failure with an error message, + /// the health check entry reports Unhealthy with the error message and ERROR_RESPONSE_TIME_MS. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_ReportsUnhealthy_WhenEmbeddingFails() + { + // Arrange + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new EmbeddingResult(false, null, "Provider API returned 401 Unauthorized.")); + + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig(enabled: true, thresholdMs: 60000)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + HealthCheckResultEntry embeddingCheck = GetEmbeddingCheck(report); + Assert.AreEqual(HealthStatus.Unhealthy, embeddingCheck.Status); + Assert.AreEqual("Provider API returned 401 Unauthorized.", embeddingCheck.Exception); + Assert.AreEqual(HealthCheckConstants.ERROR_RESPONSE_TIME_MS, embeddingCheck.ResponseTimeData!.ResponseTimeMs); + } + + /// + /// Validates that when the embedding service returns a failure with no error message, + /// the health check entry reports Unhealthy with the default "Embedding request failed." message. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_ReportsUnhealthy_WithDefaultErrorMessage_WhenNoErrorMessageProvided() + { + // Arrange + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new EmbeddingResult(false, null, null)); + + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig(enabled: true, thresholdMs: 60000)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + HealthCheckResultEntry embeddingCheck = GetEmbeddingCheck(report); + Assert.AreEqual(HealthStatus.Unhealthy, embeddingCheck.Status); + Assert.AreEqual("Embedding request failed.", embeddingCheck.Exception); + Assert.AreEqual(HealthCheckConstants.ERROR_RESPONSE_TIME_MS, embeddingCheck.ResponseTimeData!.ResponseTimeMs); + } + + #endregion + + #region Unhealthy - Exception Handling + + /// + /// Validates that when the embedding service throws an exception, + /// the health check entry reports Unhealthy with the exception message and ERROR_RESPONSE_TIME_MS. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_ReportsUnhealthy_WhenExceptionThrown() + { + // Arrange + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(It.IsAny(), It.IsAny())) + .ThrowsAsync(new InvalidOperationException("Connection timed out.")); + + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig(enabled: true, thresholdMs: 60000)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + HealthCheckResultEntry embeddingCheck = GetEmbeddingCheck(report); + Assert.AreEqual(HealthStatus.Unhealthy, embeddingCheck.Status); + Assert.AreEqual("Connection timed out.", embeddingCheck.Exception); + Assert.AreEqual(HealthCheckConstants.ERROR_RESPONSE_TIME_MS, embeddingCheck.ResponseTimeData!.ResponseTimeMs); + CollectionAssert.Contains(embeddingCheck.Tags!, HealthCheckConstants.EMBEDDING); + } + + #endregion + + #region Skip Scenarios + + /// + /// Validates that when embeddings options are null, + /// no embedding health check entry is added to the report. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_Skipped_WhenEmbeddingsOptionsNull() + { + // Arrange + RuntimeConfig config = CreateRuntimeConfig(embeddingsOptions: null, embeddingsHealth: null); + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + Assert.IsFalse(HasEmbeddingCheck(report)); + } + + /// + /// Validates that when embeddings are disabled, + /// no embedding health check entry is added to the report. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_Skipped_WhenEmbeddingsDisabled() + { + // Arrange + RuntimeConfig config = CreateRuntimeConfig( + embeddingsEnabled: false, + embeddingsHealth: new EmbeddingsHealthCheckConfig(enabled: true)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + Assert.IsFalse(HasEmbeddingCheck(report)); + } + + /// + /// Validates that when the embeddings health check config is null, + /// no embedding health check entry is added to the report. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_Skipped_WhenHealthConfigNull() + { + // Arrange + RuntimeConfig config = CreateRuntimeConfig(embeddingsHealth: null); + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + Assert.IsFalse(HasEmbeddingCheck(report)); + } + + /// + /// Validates that when the embeddings health check is explicitly disabled, + /// no embedding health check entry is added to the report. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_Skipped_WhenHealthCheckDisabled() + { + // Arrange + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig(enabled: false)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + Assert.IsFalse(HasEmbeddingCheck(report)); + } + + /// + /// Validates that when the embedding service is null, + /// no embedding health check entry is added to the report. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_Skipped_WhenEmbeddingServiceNull() + { + // Arrange + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig(enabled: true)); + + HealthCheckHelper helper = new(_mockLogger.Object, _httpUtilities, embeddingService: null); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + Assert.IsFalse(HasEmbeddingCheck(report)); + } + + #endregion + + #region Test Text Validation + + /// + /// Validates that the configured test text is passed to the embedding service. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_UsesConfiguredTestText() + { + // Arrange + string customTestText = "custom health check text"; + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(customTestText, It.IsAny())) + .ReturnsAsync(new EmbeddingResult(true, new[] { 0.1f })); + + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig( + enabled: true, + thresholdMs: 60000, + testText: customTestText)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + _mockEmbeddingService.Verify( + s => s.TryEmbedAsync(customTestText, It.IsAny()), + Times.Once()); + } + + /// + /// Validates that the default test text is used when no custom test text is configured. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_UsesDefaultTestText_WhenNotConfigured() + { + // Arrange + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(EmbeddingsHealthCheckConfig.DEFAULT_TEST_TEXT, It.IsAny())) + .ReturnsAsync(new EmbeddingResult(true, new[] { 0.1f })); + + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig(enabled: true, thresholdMs: 60000)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + _mockEmbeddingService.Verify( + s => s.TryEmbedAsync(EmbeddingsHealthCheckConfig.DEFAULT_TEST_TEXT, It.IsAny()), + Times.Once()); + } + + #endregion + + #region Tags Validation + + /// + /// Validates that the embedding health check entry always includes the "embedding" tag. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_AlwaysIncludesEmbeddingTag() + { + // Arrange + SetupSuccessfulEmbedding(new[] { 0.1f }); + + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig(enabled: true, thresholdMs: 60000)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + HealthCheckResultEntry embeddingCheck = GetEmbeddingCheck(report); + Assert.IsNotNull(embeddingCheck.Tags); + Assert.AreEqual(1, embeddingCheck.Tags!.Count); + Assert.AreEqual(HealthCheckConstants.EMBEDDING, embeddingCheck.Tags[0]); + } + + /// + /// Validates that even on failure, the embedding health check entry includes the "embedding" tag. + /// + [TestMethod] + public async Task EmbeddingsHealthCheck_IncludesEmbeddingTag_OnFailure() + { + // Arrange + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new EmbeddingResult(false, null, "Error")); + + RuntimeConfig config = CreateRuntimeConfig( + embeddingsHealth: new EmbeddingsHealthCheckConfig(enabled: true, thresholdMs: 60000)); + + HealthCheckHelper helper = CreateHealthCheckHelper(); + + // Act + ComprehensiveHealthCheckReport report = await helper.GetHealthCheckResponseAsync(config, "", ""); + + // Assert + HealthCheckResultEntry embeddingCheck = GetEmbeddingCheck(report); + CollectionAssert.Contains(embeddingCheck.Tags!, HealthCheckConstants.EMBEDDING); + } + + #endregion + + #region Helper Methods + + /// + /// Sets up the mock embedding service to return a successful result with the given embedding. + /// + private void SetupSuccessfulEmbedding(float[] embedding) + { + _mockEmbeddingService + .Setup(s => s.TryEmbedAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new EmbeddingResult(true, embedding)); + } + + /// + /// Creates a using the class-level mocked dependencies. + /// + private HealthCheckHelper CreateHealthCheckHelper() + { + return new HealthCheckHelper(_mockLogger.Object, _httpUtilities, _mockEmbeddingService.Object); + } + + /// + /// Creates a with data source and entity health checks disabled + /// to isolate embeddings health check behavior. + /// + /// The embeddings health check config. Pass null to omit. + /// Override the entire EmbeddingsOptions. When provided, embeddingsHealth and embeddingsEnabled are ignored. + /// Whether embeddings are enabled. Defaults to true. + private static RuntimeConfig CreateRuntimeConfig( + EmbeddingsHealthCheckConfig? embeddingsHealth = null, + EmbeddingsOptions? embeddingsOptions = null, + bool embeddingsEnabled = true) + { + // If embeddingsOptions is not explicitly provided, build one from parameters + if (embeddingsOptions is null && (embeddingsHealth is not null || embeddingsEnabled)) + { + embeddingsOptions = new EmbeddingsOptions( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key", + Enabled: embeddingsEnabled, + Health: embeddingsHealth); + } + + DataSource dataSource = new( + DatabaseType.MSSQL, + "Server=localhost;Database=test;", + Options: null, + Health: new DatasourceHealthCheckConfig(enabled: false)); + + RuntimeOptions runtimeOptions = new( + Rest: new RestRuntimeOptions(Enabled: true), + GraphQL: new GraphQLRuntimeOptions(Enabled: true), + Mcp: new McpRuntimeOptions(Enabled: true), + Host: new HostOptions(Cors: null, Authentication: null, Mode: HostMode.Development), + Health: new RuntimeHealthCheckConfig(enabled: true), + Embeddings: embeddingsOptions); + + RuntimeEntities entities = new(new Dictionary()); + + return new RuntimeConfig( + Schema: null, + DataSource: dataSource, + Entities: entities, + Runtime: runtimeOptions); + } + + /// + /// Gets the embedding health check entry from the report, asserting it exists. + /// + private static HealthCheckResultEntry GetEmbeddingCheck(ComprehensiveHealthCheckReport report) + { + Assert.IsNotNull(report.Checks, "Checks should not be null."); + HealthCheckResultEntry? embeddingCheck = report.Checks! + .FirstOrDefault(c => c.Tags != null && c.Tags.Contains(HealthCheckConstants.EMBEDDING)); + Assert.IsNotNull(embeddingCheck, "Expected an embedding health check entry in the report."); + return embeddingCheck!; + } + + /// + /// Checks if the report contains an embedding health check entry. + /// + private static bool HasEmbeddingCheck(ComprehensiveHealthCheckReport report) + { + return report.Checks != null && + report.Checks.Any(c => c.Tags != null && c.Tags.Contains(HealthCheckConstants.EMBEDDING)); + } + + #endregion +} diff --git a/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs b/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs new file mode 100644 index 0000000000..663aa8b547 --- /dev/null +++ b/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs @@ -0,0 +1,385 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Text.Json; +using Azure.DataApiBuilder.Config; +using Azure.DataApiBuilder.Config.Converters; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests; + +/// +/// Unit tests for EmbeddingsOptions deserialization and EmbeddingProviderType enum. +/// +[TestClass] +public class EmbeddingsOptionsTests +{ + private const string BASIC_CONFIG_WITH_EMBEDDINGS = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""runtime"": { + ""embeddings"": { + ""provider"": ""azure-openai"", + ""base-url"": ""https://my-openai.openai.azure.com"", + ""api-key"": ""test-api-key"", + ""model"": ""text-embedding-ada-002"", + ""api-version"": ""2024-02-01"", + ""dimensions"": 1536, + ""timeout-ms"": 30000 + } + }, + ""entities"": {} + }"; + + private const string OPENAI_CONFIG = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""runtime"": { + ""embeddings"": { + ""provider"": ""openai"", + ""base-url"": ""https://api.openai.com"", + ""api-key"": ""sk-test-key"" + } + }, + ""entities"": {} + }"; + + private const string MINIMAL_AZURE_CONFIG = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""runtime"": { + ""embeddings"": { + ""provider"": ""azure-openai"", + ""base-url"": ""https://my-openai.openai.azure.com"", + ""api-key"": ""test-api-key"", + ""model"": ""my-deployment"" + } + }, + ""entities"": {} + }"; + + private const string CONFIG_WITHOUT_EMBEDDINGS = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""entities"": {} + }"; + + /// + /// Tests that Azure OpenAI embeddings configuration deserializes correctly. + /// + [TestMethod] + public void TestAzureOpenAIEmbeddingsConfigDeserialization() + { + // Act + bool success = RuntimeConfigLoader.TryParseConfig(BASIC_CONFIG_WITH_EMBEDDINGS, out RuntimeConfig? runtimeConfig); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig); + Assert.IsNotNull(runtimeConfig.Runtime); + Assert.IsNotNull(runtimeConfig.Runtime.Embeddings); + + EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; + Assert.AreEqual(EmbeddingProviderType.AzureOpenAI, embeddings.Provider); + Assert.AreEqual("https://my-openai.openai.azure.com", embeddings.BaseUrl); + Assert.AreEqual("test-api-key", embeddings.ApiKey); + Assert.AreEqual("text-embedding-ada-002", embeddings.Model); + Assert.AreEqual("2024-02-01", embeddings.ApiVersion); + Assert.AreEqual(1536, embeddings.Dimensions); + Assert.AreEqual(30000, embeddings.TimeoutMs); + } + + /// + /// Tests that OpenAI embeddings configuration deserializes correctly with defaults. + /// + [TestMethod] + public void TestOpenAIEmbeddingsConfigWithDefaults() + { + // Act + bool success = RuntimeConfigLoader.TryParseConfig(OPENAI_CONFIG, out RuntimeConfig? runtimeConfig); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + + EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; + Assert.AreEqual(EmbeddingProviderType.OpenAI, embeddings.Provider); + Assert.AreEqual("https://api.openai.com", embeddings.BaseUrl); + Assert.AreEqual("sk-test-key", embeddings.ApiKey); + Assert.IsNull(embeddings.Model); + Assert.AreEqual(EmbeddingsOptions.DEFAULT_OPENAI_MODEL, embeddings.EffectiveModel); + Assert.IsNull(embeddings.ApiVersion); + Assert.AreEqual(EmbeddingsOptions.DEFAULT_DIMENSIONS, embeddings.Dimensions); + Assert.IsNull(embeddings.TimeoutMs); + Assert.AreEqual(EmbeddingsOptions.DEFAULT_TIMEOUT_MS, embeddings.EffectiveTimeoutMs); + } + + /// + /// Tests that minimal Azure OpenAI config deserializes correctly. + /// + [TestMethod] + public void TestMinimalAzureOpenAIConfig() + { + // Act + bool success = RuntimeConfigLoader.TryParseConfig(MINIMAL_AZURE_CONFIG, out RuntimeConfig? runtimeConfig); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + + EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; + Assert.AreEqual(EmbeddingProviderType.AzureOpenAI, embeddings.Provider); + Assert.AreEqual("my-deployment", embeddings.Model); + Assert.AreEqual("my-deployment", embeddings.EffectiveModel); + } + + /// + /// Tests that configuration without embeddings section deserializes correctly. + /// + [TestMethod] + public void TestConfigWithoutEmbeddings() + { + // Act + bool success = RuntimeConfigLoader.TryParseConfig(CONFIG_WITHOUT_EMBEDDINGS, out RuntimeConfig? runtimeConfig); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig); + Assert.IsNull(runtimeConfig.Runtime?.Embeddings); + } + + /// + /// Tests that EmbeddingProviderType enum deserializes correctly from JSON. + /// + [DataTestMethod] + [DataRow("azure-openai", EmbeddingProviderType.AzureOpenAI)] + [DataRow("openai", EmbeddingProviderType.OpenAI)] + public void TestEmbeddingProviderTypeDeserialization(string jsonValue, EmbeddingProviderType expected) + { + // Arrange + string config = $@" + {{ + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": {{ + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }}, + ""runtime"": {{ + ""embeddings"": {{ + ""provider"": ""{jsonValue}"", + ""base-url"": ""https://example.com"", + ""api-key"": ""test-key"", + ""model"": ""test-model"" + }} + }}, + ""entities"": {{}} + }}"; + + // Act + bool success = RuntimeConfigLoader.TryParseConfig(config, out RuntimeConfig? runtimeConfig); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + Assert.AreEqual(expected, runtimeConfig.Runtime.Embeddings.Provider); + } + + /// + /// Tests EmbeddingsOptions serialization to JSON. + /// + [TestMethod] + public void TestEmbeddingsOptionsSerialization() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://my-endpoint.openai.azure.com", + ApiKey: "my-api-key", + Model: "my-model", + ApiVersion: "2024-02-01", + Dimensions: 1536, + TimeoutMs: 60000); + + // Act + JsonSerializerOptions serializerOptions = RuntimeConfigLoader.GetSerializationOptions(replacementSettings: null); + string json = JsonSerializer.Serialize(options, serializerOptions); + + // Normalize json for comparison (remove whitespace) + string normalizedJson = json.Replace(" ", "").Replace("\n", "").Replace("\r", ""); + + // Assert + Assert.IsTrue(normalizedJson.Contains("\"provider\":\"azure-openai\""), $"Expected provider in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"base-url\":\"https://my-endpoint.openai.azure.com\""), $"Expected base-url in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"api-key\":\"my-api-key\""), $"Expected api-key in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"model\":\"my-model\""), $"Expected model in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"api-version\":\"2024-02-01\""), $"Expected api-version in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"dimensions\":1536"), $"Expected dimensions in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"timeout-ms\":60000"), $"Expected timeout-ms in JSON: {json}"); + } + + /// + /// Tests that environment variable replacement works for embeddings configuration. + /// + [TestMethod] + public void TestEmbeddingsConfigWithEnvVarReplacement() + { + // Arrange + string config = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""runtime"": { + ""embeddings"": { + ""provider"": ""azure-openai"", + ""base-url"": ""@env('EMBEDDINGS_ENDPOINT')"", + ""api-key"": ""@env('EMBEDDINGS_API_KEY')"", + ""model"": ""@env('EMBEDDINGS_MODEL')"" + } + }, + ""entities"": {} + }"; + + // Set environment variables + Environment.SetEnvironmentVariable("EMBEDDINGS_ENDPOINT", "https://test.openai.azure.com"); + Environment.SetEnvironmentVariable("EMBEDDINGS_API_KEY", "test-key-from-env"); + Environment.SetEnvironmentVariable("EMBEDDINGS_MODEL", "test-model-from-env"); + + // Create replacement settings to enable env var replacement + DeserializationVariableReplacementSettings replacementSettings = new( + doReplaceEnvVar: true, + doReplaceAkvVar: false, + envFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + + try + { + // Act + bool success = RuntimeConfigLoader.TryParseConfig(config, out RuntimeConfig? runtimeConfig, replacementSettings); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + + EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; + Assert.AreEqual("https://test.openai.azure.com", embeddings.BaseUrl); + Assert.AreEqual("test-key-from-env", embeddings.ApiKey); + Assert.AreEqual("test-model-from-env", embeddings.Model); + } + finally + { + // Cleanup + Environment.SetEnvironmentVariable("EMBEDDINGS_ENDPOINT", null); + Environment.SetEnvironmentVariable("EMBEDDINGS_API_KEY", null); + Environment.SetEnvironmentVariable("EMBEDDINGS_MODEL", null); + } + } + + /// + /// Tests that Enabled defaults to true when not present in config JSON. + /// + [TestMethod] + public void TestEmbeddingsEnabled_DefaultsToTrue_WhenNotSpecified() + { + // Act + bool success = RuntimeConfigLoader.TryParseConfig(OPENAI_CONFIG, out RuntimeConfig? runtimeConfig); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + Assert.IsTrue(runtimeConfig.Runtime.Embeddings.Enabled, + "Enabled should default to true when not specified in config."); + } + + /// + /// Tests that Enabled: true deserializes correctly. + /// + [TestMethod] + public void TestEmbeddingsEnabled_TrueDeserializesCorrectly() + { + // Arrange + string config = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""runtime"": { + ""embeddings"": { + ""provider"": ""openai"", + ""base-url"": ""https://api.openai.com"", + ""api-key"": ""sk-test"", + ""enabled"": true + } + }, + ""entities"": {} + }"; + + // Act + bool success = RuntimeConfigLoader.TryParseConfig(config, out RuntimeConfig? runtimeConfig); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + Assert.IsTrue(runtimeConfig.Runtime.Embeddings.Enabled, + "Enabled should be true when explicitly set to true in config."); + } + + /// + /// Tests that Enabled: false deserializes correctly and results in false. + /// + [TestMethod] + public void TestEmbeddingsEnabled_FalseDeserializesCorrectly() + { + // Arrange + string config = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""runtime"": { + ""embeddings"": { + ""provider"": ""openai"", + ""base-url"": ""https://api.openai.com"", + ""api-key"": ""sk-test"", + ""enabled"": false + } + }, + ""entities"": {} + }"; + + // Act + bool success = RuntimeConfigLoader.TryParseConfig(config, out RuntimeConfig? runtimeConfig); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + Assert.IsFalse(runtimeConfig.Runtime.Embeddings.Enabled, + "Enabled should be false when explicitly set to false in config."); + } +} diff --git a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs index 4da3266271..0aeb309c61 100644 --- a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs +++ b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs @@ -35,7 +35,7 @@ public class RequestParserUnitTests public void ExtractRawQueryParameter_PreservesEncoding(string queryString, string parameterName, string expectedValue) { // Call the internal method directly (no reflection needed) - string? result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); + string result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); Assert.AreEqual(expectedValue, result, $"Expected '{expectedValue}' but got '{result}' for parameter '{parameterName}' in query '{queryString}'"); @@ -49,10 +49,10 @@ public void ExtractRawQueryParameter_PreservesEncoding(string queryString, strin [DataRow("", "$filter", DisplayName = "Empty query string")] [DataRow(null, "$filter", DisplayName = "Null query string")] [DataRow("?otherParam=value", "$filter", DisplayName = "Different parameter")] - public void ExtractRawQueryParameter_ReturnsNull_WhenParameterNotFound(string? queryString, string parameterName) + public void ExtractRawQueryParameter_ReturnsNull_WhenParameterNotFound(string queryString, string parameterName) { // Call the internal method directly (no reflection needed) - string? result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); + string result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); Assert.IsNull(result, $"Expected null but got '{result}' for parameter '{parameterName}' in query '{queryString}'"); @@ -71,7 +71,7 @@ public void ExtractRawQueryParameter_ReturnsNull_WhenParameterNotFound(string? q public void ExtractRawQueryParameter_HandlesEdgeCases(string queryString, string parameterName, string expectedValue) { // Call the internal method directly (no reflection needed) - string? result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); + string result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); Assert.AreEqual(expectedValue, result, $"Expected '{expectedValue}' but got '{result}' for parameter '{parameterName}' in query '{queryString}'"); diff --git a/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs b/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs index fa2e0e33f6..b640c79cd9 100644 --- a/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs +++ b/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs @@ -951,8 +951,8 @@ public void TestOboNoUserContext_UsesBaseConnectionString() [DataRow(null, null, "iss and oid/sub", DisplayName = "Authenticated user with no claims throws OboAuthenticationFailure")] public void TestOboEnabled_AuthenticatedUserMissingClaims_ThrowsException( - string? issuer, - string? objectId, + string issuer, + string objectId, string missingClaimDescription) { // Arrange - Create an authenticated HttpContext with incomplete claims @@ -987,8 +987,8 @@ public void TestOboEnabled_AuthenticatedUserMissingClaims_ThrowsException( /// The oid claim value, or null to omit. /// A configured HttpContextAccessor mock with authenticated user. private static Mock CreateHttpContextAccessorWithAuthenticatedUserMissingClaims( - string? issuer, - string? objectId) + string issuer, + string objectId) { Mock httpContextAccessor = new(); DefaultHttpContext context = new(); diff --git a/src/Service/Controllers/EmbeddingController.cs b/src/Service/Controllers/EmbeddingController.cs new file mode 100644 index 0000000000..084cb70a74 --- /dev/null +++ b/src/Service/Controllers/EmbeddingController.cs @@ -0,0 +1,432 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Mime; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; +using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Services.Embeddings; +using Azure.DataApiBuilder.Service.Helpers; +using Azure.DataApiBuilder.Service.Models; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Primitives; + +namespace Azure.DataApiBuilder.Service.Controllers; + +/// +/// Controller to serve embedding requests at the fixed endpoint path: /embed. +/// Accepts plain text or JSON input and returns embedding vector as JSON by default, +/// or as plain text (comma-separated floats) when the client sends Accept: text/plain. +/// Uses a dedicated "embed" route to avoid conflicts with other API routes. +/// +[ApiController] +public class EmbeddingController : ControllerBase +{ + private readonly IEmbeddingService? _embeddingService; + private readonly RuntimeConfigProvider _runtimeConfigProvider; + private readonly ILogger _logger; + + /// + /// Constructor. + /// + public EmbeddingController( + RuntimeConfigProvider runtimeConfigProvider, + ILogger logger, + IEmbeddingService? embeddingService = null) + { + _runtimeConfigProvider = runtimeConfigProvider; + _logger = logger; + _embeddingService = embeddingService; + } + + /// + /// POST endpoint for generating embeddings. + /// Accepts plain text, JSON string, or array of documents with key/text pairs. + /// Supports query parameters to override chunking settings. + /// Default response is JSON: { "embedding": [...], "dimensions": N } for single text, + /// or [{ "key": "...", "data": [[...], [...]] }] for document arrays. + /// + /// Embedding vector(s) as JSON, or an error response. + [HttpPost] + [Route("embed")] + [Consumes("text/plain", "application/json")] + [Produces("application/json", "text/plain")] + public async Task PostAsync() + { + // Get embeddings configuration + EmbeddingsOptions? embeddingsOptions = _runtimeConfigProvider.GetConfig()?.Runtime?.Embeddings; + EmbeddingsEndpointOptions? endpointOptions = embeddingsOptions?.Endpoint; + + // Check if embeddings and endpoint are enabled + if (embeddingsOptions is null || !embeddingsOptions.Enabled) + { + return NotFound(); + } + + if (endpointOptions is null || !endpointOptions.Enabled) + { + return NotFound(); + } + + // Check if embedding service is available + if (_embeddingService is null || !_embeddingService.IsEnabled) + { + _logger.LogWarning("Embedding endpoint called but embedding service is not available or disabled."); + Response.StatusCode = (int)HttpStatusCode.ServiceUnavailable; + return RestController.ErrorResponse( + "UnexpectedError", + "Embedding service is not available.", + HttpStatusCode.ServiceUnavailable); + } + + // Check authorization + bool isDevelopmentMode = _runtimeConfigProvider.GetConfig()?.Runtime?.Host?.Mode == HostMode.Development; + string clientRole = GetClientRole(); + + if (!endpointOptions.IsRoleAllowed(clientRole, isDevelopmentMode)) + { + _logger.LogWarning("Embedding endpoint access denied for role: {Role}", clientRole); + Response.StatusCode = (int)HttpStatusCode.Forbidden; + return RestController.ErrorResponse( + "AuthorizationCheckFailed", + "Access denied.", + HttpStatusCode.Forbidden); + } + + // Parse query parameters for chunking options + EmbeddingsChunkingOptions? queryChunkingOptions = ParseChunkingOptionsFromQuery(out string? paramValidationError); + if (paramValidationError is not null) + { + Response.StatusCode = (int)HttpStatusCode.BadRequest; + return RestController.ErrorResponse("BadRequest", paramValidationError, HttpStatusCode.BadRequest); + } + + // Read request body + string requestBody; + try + { + using StreamReader reader = new(Request.Body); + requestBody = await reader.ReadToEndAsync(); + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to read request body for embedding."); + Response.StatusCode = (int)HttpStatusCode.BadRequest; + return RestController.ErrorResponse("BadRequest", "Failed to read request body.", HttpStatusCode.BadRequest); + } + + if (string.IsNullOrWhiteSpace(requestBody)) + { + Response.StatusCode = (int)HttpStatusCode.BadRequest; + return RestController.ErrorResponse("BadRequest", "Request body cannot be empty.", HttpStatusCode.BadRequest); + } + + CancellationToken cancellationToken = HttpContext.RequestAborted; + + // Try to parse as document array first (if JSON content type) + if (Request.ContentType?.Contains("application/json", StringComparison.OrdinalIgnoreCase) == true) + { + try + { + EmbedDocumentRequest[]? documents = JsonSerializer.Deserialize(requestBody); + + if (documents is not null && documents.Length > 0) + { + // Handle as document array + return await ProcessDocumentArrayAsync(documents, embeddingsOptions, queryChunkingOptions, cancellationToken); + } + else if (documents is not null && documents.Length == 0) + { + // Empty document array + Response.StatusCode = (int)HttpStatusCode.BadRequest; + return RestController.ErrorResponse("BadRequest", "Document array cannot be empty.", HttpStatusCode.BadRequest); + } + } + catch (JsonException) + { + // Not a document array, try as single text + _logger.LogDebug("Request body is not a document array, trying as single text."); + } + + // Try to parse as single JSON string + try + { + string? jsonString = JsonSerializer.Deserialize(requestBody); + if (jsonString is not null) + { + requestBody = jsonString; + } + else + { + Response.StatusCode = (int)HttpStatusCode.BadRequest; + return RestController.ErrorResponse("BadRequest", "JSON request body must be a non-null string or a document array.", HttpStatusCode.BadRequest); + } + } + catch (JsonException) + { + // Body is application/json but neither an array nor a string (e.g. {"foo":"bar"}) + Response.StatusCode = (int)HttpStatusCode.BadRequest; + return RestController.ErrorResponse("BadRequest", "Request body with content type 'application/json' must be a JSON string or a document array.", HttpStatusCode.BadRequest); + } + } + + // Handle as single text, applying chunking when enabled + return await ProcessSingleTextAsync(requestBody, embeddingsOptions, queryChunkingOptions, cancellationToken); + } + + /// + /// Processes a document array request and returns embeddings for each document. + /// Uses batch embedding (TryEmbedBatchAsync) per document to reduce round-trips. + /// + private async Task ProcessDocumentArrayAsync( + EmbedDocumentRequest[] documents, + EmbeddingsOptions embeddingsOptions, + EmbeddingsChunkingOptions? queryChunkingOptions, + CancellationToken cancellationToken) + { + List responses = new(); + + foreach (EmbedDocumentRequest doc in documents) + { + if (string.IsNullOrEmpty(doc.Key)) + { + Response.StatusCode = (int)HttpStatusCode.BadRequest; + return RestController.ErrorResponse("BadRequest", "Each document must have a non-empty key.", HttpStatusCode.BadRequest); + } + + if (string.IsNullOrEmpty(doc.Text)) + { + Response.StatusCode = (int)HttpStatusCode.BadRequest; + return RestController.ErrorResponse("BadRequest", "Document with key has empty text.", HttpStatusCode.BadRequest); + } + + try + { + // Use query params if provided, otherwise fall back to config + EmbeddingsChunkingOptions? effectiveChunking = queryChunkingOptions ?? embeddingsOptions.Chunking; + + // Chunk the text if chunking is enabled + string[] chunks = TextChunker.ChunkText(doc.Text, effectiveChunking); + + // Batch-embed all chunks for this document in a single request + EmbeddingBatchResult batchResult = await _embeddingService!.TryEmbedBatchAsync(chunks, cancellationToken); + + if (!batchResult.Success || batchResult.Embeddings is null) + { + _logger.LogError("Failed to embed document chunks: {Error}", batchResult.ErrorMessage); + Response.StatusCode = (int)HttpStatusCode.InternalServerError; + return RestController.ErrorResponse( + "UnexpectedError", + "Failed to generate embeddings.", + HttpStatusCode.InternalServerError); + } + + responses.Add(new EmbedDocumentResponse(doc.Key, batchResult.Embeddings)); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error processing document."); + Response.StatusCode = (int)HttpStatusCode.InternalServerError; + return RestController.ErrorResponse( + "UnexpectedError", + "Failed to generate embeddings.", + HttpStatusCode.InternalServerError); + } + } + + return Ok(responses.ToArray()); + } + + /// + /// Routes a single-text request through chunking when enabled, falling back to the + /// legacy single-embedding response for backward compatibility when not chunked. + /// + private async Task ProcessSingleTextAsync( + string text, + EmbeddingsOptions embeddingsOptions, + EmbeddingsChunkingOptions? queryChunkingOptions, + CancellationToken cancellationToken) + { + EmbeddingsChunkingOptions? effectiveChunking = queryChunkingOptions ?? embeddingsOptions.Chunking; + + if (effectiveChunking is not null && effectiveChunking.Enabled) + { + // Route through document-array path to produce a multi-chunk response + EmbedDocumentRequest[] documents = + [ + new EmbedDocumentRequest { Key = "input", Text = text } + ]; + IActionResult result = await ProcessDocumentArrayAsync(documents, embeddingsOptions, effectiveChunking, cancellationToken); + + // Apply text/plain format when requested, consistent with the non-chunked path. + // Each chunk's embedding is output as one line of comma-separated floats. + if (ClientAcceptsTextPlain() && result is OkObjectResult okResult && okResult.Value is EmbedDocumentResponse[] docResponses) + { + IEnumerable lines = docResponses + .SelectMany(d => d.Data) + .Select(embedding => string.Join(",", embedding.Select(f => f.ToString("G", CultureInfo.InvariantCulture)))); + return Content(string.Join("\n", lines), MediaTypeNames.Text.Plain); + } + + return result; + } + + return await ProcessSingleTextAsync(text, cancellationToken); + } + + /// + /// Processes a single text request and returns embedding (backward compatible, no chunking). + /// + private async Task ProcessSingleTextAsync(string text, CancellationToken cancellationToken) + { + if (string.IsNullOrWhiteSpace(text)) + { + Response.StatusCode = (int)HttpStatusCode.BadRequest; + return RestController.ErrorResponse("BadRequest", "Request body cannot be empty.", HttpStatusCode.BadRequest); + } + + // Generate embedding + EmbeddingResult result = await _embeddingService!.TryEmbedAsync(text, cancellationToken); + + if (!result.Success) + { + _logger.LogError("Embedding request failed: {Error}", result.ErrorMessage); + Response.StatusCode = (int)HttpStatusCode.InternalServerError; + return RestController.ErrorResponse( + "UnexpectedError", + "Failed to generate embedding.", + HttpStatusCode.InternalServerError); + } + + if (result.Embedding is null || result.Embedding.Length == 0) + { + _logger.LogError("Embedding request returned empty result."); + Response.StatusCode = (int)HttpStatusCode.InternalServerError; + return RestController.ErrorResponse( + "UnexpectedError", + "Failed to generate embedding.", + HttpStatusCode.InternalServerError); + } + + // Return embedding as plain text (comma-separated floats) when explicitly requested via Accept header. + if (ClientAcceptsTextPlain()) + { + string embeddingText = string.Join(",", result.Embedding.Select(f => f.ToString("G", CultureInfo.InvariantCulture))); + return Content(embeddingText, MediaTypeNames.Text.Plain); + } + + // Default: return structured JSON response. + return Ok(new EmbeddingResponse(result.Embedding)); + } + + /// + /// Parses query parameters and creates EmbeddingsChunkingOptions. + /// Returns null if no query parameters are provided (use config defaults). + /// Sets to a non-null message if any provided param is invalid. + /// + private EmbeddingsChunkingOptions? ParseChunkingOptionsFromQuery(out string? validationError) + { + validationError = null; + bool? enabled = null; + int? sizeChars = null; + int? overlapChars = null; + + if (Request.Query.TryGetValue("$chunking.enabled", out StringValues enabledValue)) + { + if (bool.TryParse(enabledValue, out bool parsedEnabled)) + { + enabled = parsedEnabled; + } + else + { + validationError = $"Invalid value for '$chunking.enabled': must be 'true' or 'false'."; + return null; + } + } + + if (Request.Query.TryGetValue("$chunking.size-chars", out StringValues sizeValue)) + { + if (int.TryParse(sizeValue, out int size) && size > 0) + { + sizeChars = size; + } + else + { + validationError = $"Invalid value for '$chunking.size-chars': must be a positive integer."; + return null; + } + } + + if (Request.Query.TryGetValue("$chunking.overlap-chars", out StringValues overlapValue)) + { + if (int.TryParse(overlapValue, out int overlap) && overlap >= 0) + { + overlapChars = overlap; + } + else + { + validationError = $"Invalid value for '$chunking.overlap-chars': must be a non-negative integer."; + return null; + } + } + + // If no query parameters provided, return null to use config defaults + if (!enabled.HasValue && !sizeChars.HasValue && !overlapChars.HasValue) + { + return null; + } + + // Create new options with query parameters (using defaults for unspecified values) + return new EmbeddingsChunkingOptions(enabled, sizeChars, overlapChars); + } + + /// + /// Gets the client role from request headers. + /// + private string GetClientRole() + { + StringValues roleHeader = Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER]; + string? firstRole = roleHeader.Count == 1 ? roleHeader[0] : null; + + if (!string.IsNullOrEmpty(firstRole)) + { + return firstRole.ToLowerInvariant(); + } + + return EmbeddingsEndpointOptions.ANONYMOUS_ROLE; + } + + /// + /// Checks whether the client explicitly requests text/plain via the Accept header. + /// Returns true only when text/plain is present and application/json is not, + /// so that the default response format remains JSON. + /// + private bool ClientAcceptsTextPlain() + { + StringValues acceptHeader = Request.Headers.Accept; + if (acceptHeader.Count == 0) + { + return false; + } + + string accept = acceptHeader.ToString(); + bool wantsText = accept.Contains(MediaTypeNames.Text.Plain, StringComparison.OrdinalIgnoreCase); + bool wantsJson = accept.Contains(MediaTypeNames.Application.Json, StringComparison.OrdinalIgnoreCase); + + // Only return text/plain when the client explicitly asks for it + // and does NOT also ask for JSON (in which case JSON wins). + return wantsText && !wantsJson; + } +} + diff --git a/src/Service/HealthCheck/HealthCheckHelper.cs b/src/Service/HealthCheck/HealthCheckHelper.cs index 2a5f6f5ddf..3263be0ed7 100644 --- a/src/Service/HealthCheck/HealthCheckHelper.cs +++ b/src/Service/HealthCheck/HealthCheckHelper.cs @@ -10,7 +10,9 @@ using System.Threading.Tasks; using Azure.DataApiBuilder.Config.HealthCheck; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Services.Embeddings; using Azure.DataApiBuilder.Product; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; @@ -27,18 +29,24 @@ public class HealthCheckHelper // Dependencies private ILogger _logger; private HttpUtilities _httpUtility; + private IEmbeddingService? _embeddingService; + private string _incomingRoleHeader = string.Empty; + private string _incomingRoleToken = string.Empty; private const string TIME_EXCEEDED_ERROR_MESSAGE = "The threshold for executing the request has exceeded."; + private const string DIMENSIONS_MISMATCH_ERROR_MESSAGE = "The embedding dimensions do not match the expected dimensions."; /// /// Constructor to inject the logger and HttpUtility class. /// /// Logger to track the log statements. /// HttpUtility to call methods from the internal class. - public HealthCheckHelper(ILogger logger, HttpUtilities httpUtility) + /// Optional embedding service for embedding health checks. + public HealthCheckHelper(ILogger logger, HttpUtilities httpUtility, IEmbeddingService? embeddingService = null) { _logger = logger; _httpUtility = httpUtility; + _embeddingService = embeddingService; } /// @@ -49,7 +57,7 @@ public HealthCheckHelper(ILogger logger, HttpUtilities httpUt /// The effective role header for the current request. /// The bearer token for the current request. /// This function returns the comprehensive health report after calculating the response time of each datasource, rest and graphql health queries. - public async Task GetHealthCheckResponseAsync(RuntimeConfig runtimeConfig, string roleHeader, string roleToken) + public async Task GetHealthCheckResponseAsync(RuntimeConfig runtimeConfig, string roleHeader = "", string roleToken = "") { // Create a JSON response for the comprehensive health check endpoint using the provided basic health report. // If the response has already been created, it will be reused. @@ -141,14 +149,19 @@ private static void UpdateTimestampOfResponse(ref ComprehensiveHealthCheckReport // Updates the DAB configuration details coming from RuntimeConfig for the Health report. private static void UpdateDabConfigurationDetails(ref ComprehensiveHealthCheckReport comprehensiveHealthCheckReport, RuntimeConfig runtimeConfig) { + bool embeddingsEnabled = runtimeConfig?.Runtime?.Embeddings?.Enabled ?? false; + bool embeddingsEndpointEnabled = embeddingsEnabled && (runtimeConfig?.Runtime?.Embeddings?.IsEndpointEnabled ?? false); + comprehensiveHealthCheckReport.ConfigurationDetails = new ConfigurationDetails { - Rest = runtimeConfig.IsRestEnabled, - GraphQL = runtimeConfig.IsGraphQLEnabled, - Mcp = runtimeConfig.IsMcpEnabled, - Caching = runtimeConfig.IsCachingEnabled, + Rest = runtimeConfig?.IsRestEnabled ?? false, + GraphQL = runtimeConfig?.IsGraphQLEnabled ?? false, + Mcp = runtimeConfig?.IsMcpEnabled ?? false, + Caching = runtimeConfig?.IsCachingEnabled ?? false, Telemetry = runtimeConfig?.Runtime?.Telemetry != null, - Mode = runtimeConfig?.Runtime?.Host?.Mode ?? HostMode.Production, // Modify to runtimeConfig.HostMode in Roles PR + Mode = runtimeConfig?.Runtime?.Host?.Mode ?? HostMode.Production, + Embeddings = embeddingsEnabled, + EmbeddingsEndpoint = embeddingsEndpointEnabled }; } @@ -158,6 +171,7 @@ private async Task UpdateHealthCheckDetailsAsync(ComprehensiveHealthCheckReport comprehensiveHealthCheckReport.Checks = new List(); await UpdateDataSourceHealthCheckResultsAsync(comprehensiveHealthCheckReport, runtimeConfig); await UpdateEntityHealthCheckResultsAsync(comprehensiveHealthCheckReport, runtimeConfig, roleHeader, roleToken); + await UpdateEmbeddingsHealthCheckResultsAsync(comprehensiveHealthCheckReport, runtimeConfig); } // Updates the DataSource Health Check Results in the response. @@ -350,5 +364,108 @@ private async Task PopulateEntityHealthAsync(ComprehensiveHealthCheckReport comp return (HealthCheckConstants.ERROR_RESPONSE_TIME_MS, errorMessage); } + + /// + /// Updates the Embeddings Health Check Results in the response. + /// Executes a test embedding and validates response time and optionally dimensions. + /// + private async Task UpdateEmbeddingsHealthCheckResultsAsync(ComprehensiveHealthCheckReport comprehensiveHealthCheckReport, RuntimeConfig runtimeConfig) + { + EmbeddingsOptions? embeddingsOptions = runtimeConfig?.Runtime?.Embeddings; + EmbeddingsHealthCheckConfig? healthConfig = embeddingsOptions?.Health; + + // Only run health check if embeddings is enabled, health check is enabled, and embedding service is available + if (embeddingsOptions is null || !embeddingsOptions.Enabled || healthConfig is null || !healthConfig.Enabled || _embeddingService is null) + { + return; + } + + if (comprehensiveHealthCheckReport.Checks is null) + { + comprehensiveHealthCheckReport.Checks = new List(); + } + + string testText = healthConfig.TestText; + int thresholdMs = healthConfig.ThresholdMs; + int? expectedDimensions = healthConfig.ExpectedDimensions; + + try + { + Stopwatch stopwatch = new(); + stopwatch.Start(); + EmbeddingResult result = await _embeddingService.TryEmbedAsync(testText); + stopwatch.Stop(); + + int responseTimeMs = (int)stopwatch.ElapsedMilliseconds; + bool isResponseTimeWithinThreshold = responseTimeMs <= thresholdMs; + bool isDimensionsValid = true; + string? errorMessage = null; + + if (!result.Success) + { + errorMessage = result.ErrorMessage ?? "Embedding request failed."; + comprehensiveHealthCheckReport.Checks.Add(new HealthCheckResultEntry + { + Name = "embeddings", + ResponseTimeData = new ResponseTimeData + { + ResponseTimeMs = HealthCheckConstants.ERROR_RESPONSE_TIME_MS, + ThresholdMs = thresholdMs + }, + Exception = errorMessage, + Tags = new List { HealthCheckConstants.EMBEDDING }, + Status = HealthStatus.Unhealthy + }); + return; + } + + // Validate dimensions if expected dimensions is specified + if (expectedDimensions.HasValue && result.Embedding is not null) + { + isDimensionsValid = result.Embedding.Length == expectedDimensions.Value; + if (!isDimensionsValid) + { + errorMessage = $"{DIMENSIONS_MISMATCH_ERROR_MESSAGE} Expected: {expectedDimensions.Value}, Actual: {result.Embedding.Length}"; + } + } + + // Check response time threshold + if (!isResponseTimeWithinThreshold) + { + errorMessage = errorMessage is null ? TIME_EXCEEDED_ERROR_MESSAGE : $"{errorMessage} {TIME_EXCEEDED_ERROR_MESSAGE}"; + } + + bool isHealthy = isResponseTimeWithinThreshold && isDimensionsValid; + + comprehensiveHealthCheckReport.Checks.Add(new HealthCheckResultEntry + { + Name = "embeddings", + ResponseTimeData = new ResponseTimeData + { + ResponseTimeMs = responseTimeMs, + ThresholdMs = thresholdMs + }, + Exception = errorMessage, + Tags = new List { HealthCheckConstants.EMBEDDING }, + Status = isHealthy ? HealthStatus.Healthy : HealthStatus.Unhealthy + }); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error executing embeddings health check."); + comprehensiveHealthCheckReport.Checks.Add(new HealthCheckResultEntry + { + Name = "embeddings", + ResponseTimeData = new ResponseTimeData + { + ResponseTimeMs = HealthCheckConstants.ERROR_RESPONSE_TIME_MS, + ThresholdMs = thresholdMs + }, + Exception = ex.Message, + Tags = new List { HealthCheckConstants.EMBEDDING }, + Status = HealthStatus.Unhealthy + }); + } + } } } diff --git a/src/Service/HealthCheck/Model/ConfigurationDetails.cs b/src/Service/HealthCheck/Model/ConfigurationDetails.cs index 9ff007754e..e73497e3e0 100644 --- a/src/Service/HealthCheck/Model/ConfigurationDetails.cs +++ b/src/Service/HealthCheck/Model/ConfigurationDetails.cs @@ -29,5 +29,11 @@ public record ConfigurationDetails [JsonPropertyName("mode")] public HostMode Mode { get; init; } + + [JsonPropertyName("embeddings")] + public bool Embeddings { get; init; } + + [JsonPropertyName("embeddings-endpoint")] + public bool EmbeddingsEndpoint { get; init; } } } diff --git a/src/Service/Helpers/TextChunker.cs b/src/Service/Helpers/TextChunker.cs new file mode 100644 index 0000000000..f1c1522fdd --- /dev/null +++ b/src/Service/Helpers/TextChunker.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +namespace Azure.DataApiBuilder.Service.Helpers; + +/// +/// Static helper for splitting text into overlapping chunks before embedding. +/// Encapsulates the chunking algorithm so it can be tested and reused independently of the controller. +/// +public static class TextChunker +{ + /// + /// Splits into chunks of at most characters, + /// with each consecutive chunk overlapping by characters. + /// Returns an empty array for null or empty input. + /// The step size is always at least 1 (Math.Max(1, chunkSize - overlap)), + /// so this method always terminates regardless of the overlap value. + /// + public static string[] ChunkText(string text, int chunkSize, int overlap) + { + if (string.IsNullOrEmpty(text)) + { + return Array.Empty(); + } + + // Guarantee at least one character of forward progress per iteration. + int step = Math.Max(1, chunkSize - overlap); + + if (text.Length <= chunkSize) + { + return new[] { text }; + } + + List chunks = new(); + int position = 0; + + while (position < text.Length) + { + int remaining = text.Length - position; + chunks.Add(text.Substring(position, Math.Min(chunkSize, remaining))); + position += step; + } + + return chunks.ToArray(); + } + + /// + /// Splits text into chunks based on the provided . + /// When chunking is disabled or options are null, returns the text as a single-element array. + /// Uses to guarantee step >= 1. + /// + public static string[] ChunkText(string text, EmbeddingsChunkingOptions? chunkingOptions) + { + if (chunkingOptions is null || !chunkingOptions.Enabled) + { + return new[] { text }; + } + + // EffectiveSizeChars = Math.Max(SizeChars, OverlapChars + 1), guaranteeing step >= 1. + return ChunkText(text, chunkingOptions.EffectiveSizeChars, chunkingOptions.OverlapChars); + } +} diff --git a/src/Service/Models/EmbedDocumentRequest.cs b/src/Service/Models/EmbedDocumentRequest.cs new file mode 100644 index 0000000000..06985e0334 --- /dev/null +++ b/src/Service/Models/EmbedDocumentRequest.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Service.Models; + +/// +/// Request model for a single document in a batch embedding request. +/// +public record EmbedDocumentRequest +{ + /// + /// Unique key/identifier for this document. + /// + [JsonPropertyName("key")] + public string Key { get; init; } = string.Empty; + + /// + /// The text content to embed. + /// + [JsonPropertyName("text")] + public string Text { get; init; } = string.Empty; +} diff --git a/src/Service/Models/EmbedDocumentResponse.cs b/src/Service/Models/EmbedDocumentResponse.cs new file mode 100644 index 0000000000..5d1b79217f --- /dev/null +++ b/src/Service/Models/EmbedDocumentResponse.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Service.Models; + +/// +/// Response model for a single document in a batch embedding response. +/// +public record EmbedDocumentResponse +{ + /// + /// The unique key/identifier for this document (matches request key). + /// + [JsonPropertyName("key")] + public string Key { get; init; } + + /// + /// The embedding vectors for this document. + /// If chunking is disabled or text fits in one chunk, this will contain one vector. + /// If chunking is enabled and text is split, this will contain multiple vectors (one per chunk). + /// + [JsonPropertyName("data")] + public float[][] Data { get; init; } + + public EmbedDocumentResponse(string key, float[][] data) + { + Key = key; + Data = data; + } +} diff --git a/src/Service/Models/EmbeddingResponse.cs b/src/Service/Models/EmbeddingResponse.cs new file mode 100644 index 0000000000..512881a454 --- /dev/null +++ b/src/Service/Models/EmbeddingResponse.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Service.Models; + +/// +/// JSON response model for the embedding endpoint. +/// Provides a structured, extensible format instead of raw comma-separated text. +/// +public record EmbeddingResponse +{ + /// + /// The embedding vector as an array of floating-point values. + /// + [JsonPropertyName("embedding")] + public float[] Embedding { get; init; } + + /// + /// The number of dimensions in the embedding vector. + /// + [JsonPropertyName("dimensions")] + public int Dimensions { get; init; } + + public EmbeddingResponse(float[] embedding) + { + Embedding = embedding; + Dimensions = embedding.Length; + } +} diff --git a/src/Service/Startup.cs b/src/Service/Startup.cs index 70c162a078..3c9a6f4e17 100644 --- a/src/Service/Startup.cs +++ b/src/Service/Startup.cs @@ -12,6 +12,7 @@ using Azure.DataApiBuilder.Config; using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Azure.DataApiBuilder.Config.Utilities; using Azure.DataApiBuilder.Core.AuthenticationHelpers; using Azure.DataApiBuilder.Core.AuthenticationHelpers.AuthenticationSimulator; @@ -24,6 +25,7 @@ using Azure.DataApiBuilder.Core.Resolvers.Factories; using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.Cache; +using Azure.DataApiBuilder.Core.Services.Embeddings; using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Core.Services.OpenAPI; using Azure.DataApiBuilder.Core.Telemetry; @@ -169,7 +171,8 @@ public void ConfigureServices(IServiceCollection services) configure.Headers = runtimeConfig.Runtime.Telemetry.OpenTelemetry.Headers; configure.Protocol = OtlpExportProtocol.Grpc; }) - .AddMeter(TelemetryMetricsHelper.MeterName); + .AddMeter(TelemetryMetricsHelper.MeterName) + .AddMeter(EmbeddingTelemetryHelper.MeterName); }) .WithTracing(tracing => { @@ -435,6 +438,50 @@ public void ConfigureServices(IServiceCollection services) services.AddSingleton(); services.AddSingleton(); + // Register embedding service if configured and enabled. + // NOTE: IEmbeddingService is only registered when enabled to avoid constructor + // failures when config has empty/placeholder values for disabled embeddings. + // TODO: To support hot-reload for embeddings (toggling enabled on/off at runtime), + // EmbeddingService would need to read config dynamically from RuntimeConfigProvider + // and defer constructor validation. Track as a separate work item. + if (runtimeConfigAvailable + && runtimeConfig?.Runtime?.IsEmbeddingsConfigured == true) + { + EmbeddingsOptions embeddingsOptions = runtimeConfig.Runtime.Embeddings; + services.AddSingleton(embeddingsOptions); + + string providerName = embeddingsOptions.Provider.ToString().ToLowerInvariant(); + + if (embeddingsOptions.Enabled) + { + services.AddHttpClient(); + _logger.LogInformation( + "Embeddings service enabled with provider: {Provider}, model: {Model}, base-url: {BaseUrl}", + providerName, + embeddingsOptions.EffectiveModel ?? "(default)", + embeddingsOptions.BaseUrl); + + // Endpoint is only available if both embeddings and endpoint are enabled + if (embeddingsOptions.IsEndpointEnabled) + { + _logger.LogInformation( + "Embeddings endpoint enabled at path: {Path}", + EmbeddingsEndpointOptions.DEFAULT_PATH); + } + + if (embeddingsOptions.IsHealthCheckEnabled) + { + _logger.LogInformation( + "Embeddings health check enabled with threshold: {ThresholdMs}ms", + embeddingsOptions.Health!.ThresholdMs); + } + } + else + { + _logger.LogInformation("Embeddings service is configured but disabled."); + } + } + AddGraphQLService(services, runtimeConfig?.Runtime?.GraphQL); // Subscribe the GraphQL schema refresh method to the specific hot-reload event