From 348cf913db21dada660275f51dd9cb08c254ac61 Mon Sep 17 00:00:00 2001
From: Tai An
Date: Sat, 25 Apr 2026 03:19:16 -0700
Subject: [PATCH] fix(middleware): parse OpenAI-spec tool_choice in
 /v1/chat/completions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Follows up on #9526 (the 3-site setter fix) by addressing the remaining
clause of #9508: both the string mode and the OpenAI-spec
specific-function shape of tool_choice silently failed in the
/v1/chat/completions parsing path.
---
 core/http/middleware/request.go | 1335 ++++++++++++++++---------------
 1 file changed, 684 insertions(+), 651 deletions(-)

diff --git a/core/http/middleware/request.go b/core/http/middleware/request.go
index 1c0ee0ec7cf1..b03f6f7e6d7c 100644
--- a/core/http/middleware/request.go
+++ b/core/http/middleware/request.go
@@ -1,651 +1,684 @@
-package middleware
-
-import (
-	"context"
-	"encoding/json"
-	"fmt"
-	"net/http"
-	"strconv"
-	"strings"
-
-	"github.com/google/uuid"
-	"github.com/labstack/echo/v4"
-	"github.com/mudler/LocalAI/core/config"
-	"github.com/mudler/LocalAI/core/schema"
-	"github.com/mudler/LocalAI/core/services/galleryop"
-	"github.com/mudler/LocalAI/core/templates"
-	"github.com/mudler/LocalAI/pkg/functions"
-	"github.com/mudler/LocalAI/pkg/model"
-	"github.com/mudler/LocalAI/pkg/utils"
-	"github.com/mudler/xlog"
-)
-
-type correlationIDKeyType string
-
-// CorrelationIDKey to track request across process boundary
-const CorrelationIDKey correlationIDKeyType = "correlationID"
-
-type RequestExtractor struct {
-	modelConfigLoader *config.ModelConfigLoader
-	modelLoader       *model.ModelLoader
-	applicationConfig *config.ApplicationConfig
-}
-
-func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
-	return &RequestExtractor{
-		modelConfigLoader: modelConfigLoader,
-		modelLoader:       modelLoader,
-		applicationConfig: applicationConfig,
-	}
-}
-
-const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
-const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
-const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
-
-// TODO: Refactor to not return error if unchanged
-func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) {
-	model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
-	if ok && model != "" {
-		return
-	}
-	model = c.Param("model")
-
-	if model == "" {
-		model = c.QueryParam("model")
-	}
-
-	// Check FormValue for multipart/form-data requests (e.g., /v1/images/inpainting)
-	if model == "" {
-		model = c.FormValue("model")
-	}
-
-	if model == "" {
-		// Set model from bearer token, if available
-		auth := c.Request().Header.Get("Authorization")
-		bearer := strings.TrimPrefix(auth, "Bearer ")
-		if bearer != "" && bearer != auth {
-			exists, err := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, galleryop.ALWAYS_INCLUDE)
-			if err == nil && exists {
-				model = bearer
-			}
-		}
-	}
-
-	c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
-}
-
-func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc {
-	return func(next echo.HandlerFunc) echo.HandlerFunc {
-		return func(c echo.Context) error {
-			re.setModelNameFromRequest(c)
-			localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
-			if !ok || localModelName == "" {
-				c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
-				xlog.Debug("context local model name not found, setting to default", "defaultModelName", defaultModelName)
- } - return next(c) - } - } -} - -func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - re.setModelNameFromRequest(c) - localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) - if localModelName != "" { // Don't overwrite existing values - return next(c) - } - - modelNames, err := galleryop.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, galleryop.SKIP_IF_CONFIGURED) - if err != nil { - xlog.Error("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()", "error", err) - return next(c) - } - - if len(modelNames) == 0 { - xlog.Warn("SetDefaultModelNameToFirstAvailable used with no matching models installed") - // This is non-fatal - making it so was breaking the case of direct installation of raw models - // return errors.New("this endpoint requires at least one model to be installed") - return next(c) - } - - c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0]) - xlog.Debug("context local model name not found, setting to the first model", "first model name", modelNames[0]) - return next(c) - } - } -} - -// TODO: If context and cancel above belong on all methods, move that part of above into here! -// Otherwise, it's in its own method below for now -func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - input := initializer() - if input == nil { - return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body") - } - if err := c.Bind(input); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err)) - } - - // If this request doesn't have an associated model name, fetch it from earlier in the middleware chain - if input.ModelName(nil) == "" { - localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) - if ok && localModelName != "" { - xlog.Debug("overriding empty model name in request body with value found earlier in middleware chain", "context localModelName", localModelName) - input.ModelName(&localModelName) - } - } - - modelName := input.ModelName(nil) - cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(modelName, re.applicationConfig) - - if err != nil { - xlog.Warn("Model Configuration File not found", "model", modelName, "error", err) - } else if cfg.Model == "" && modelName != "" { - xlog.Debug("config does not include model, using input", "input.ModelName", modelName) - cfg.Model = modelName - } - - // If a model name was specified, verify it actually exists before proceeding. - // Check both configured models and loose model files in the model path. - // Skip the check for HuggingFace model IDs (contain "/") since backends - // like diffusers may download these on the fly. - if modelName != "" && !strings.Contains(modelName, "/") { - exists, existsErr := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, modelName, galleryop.ALWAYS_INCLUDE) - if existsErr == nil && !exists { - return c.JSON(http.StatusNotFound, schema.ErrorResponse{ - Error: &schema.APIError{ - Message: fmt.Sprintf("model %q not found. 
To see available models, call GET /v1/models", modelName), - Code: http.StatusNotFound, - Type: "invalid_request_error", - }, - }) - } - } - - // Check if the model is disabled - if cfg != nil && cfg.IsDisabled() { - return c.JSON(http.StatusForbidden, schema.ErrorResponse{ - Error: &schema.APIError{ - Message: fmt.Sprintf("model %q is disabled and cannot be loaded. Enable it via the System page or API to use it.", modelName), - Code: http.StatusForbidden, - Type: "model_disabled", - }, - }) - } - - c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) - c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) - - return next(c) - } - } -} - -func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error { - input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) - if !ok || input.Model == "" { - return echo.ErrBadRequest - } - - cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) - if !ok || cfg == nil { - return echo.ErrBadRequest - } - - // Extract or generate the correlation ID - correlationID := c.Request().Header.Get("X-Correlation-ID") - if correlationID == "" { - correlationID = uuid.New().String() - } - c.Response().Header().Set("X-Correlation-ID", correlationID) - - // Use the request context directly - Echo properly supports context cancellation! - // No need for workarounds like handleConnectionCancellation - reqCtx := c.Request().Context() - c1, cancel := context.WithCancel(re.applicationConfig.Context) - - // Cancel when request context is cancelled (client disconnects) - go func() { - select { - case <-reqCtx.Done(): - cancel() - case <-c1.Done(): - // Already cancelled - } - }() - - // Add the correlation ID to the new context - ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) - - input.Context = ctxWithCorrelationID - input.Cancel = cancel - - err := mergeOpenAIRequestAndModelConfig(cfg, input) - if err != nil { - return err - } - - if cfg.Model == "" { - xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) - cfg.Model = input.Model - } - - c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) - c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) - - return nil -} - -func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error { - if input.Echo { - config.Echo = input.Echo - } - if input.TopK != nil { - config.TopK = input.TopK - } - if input.TopP != nil { - config.TopP = input.TopP - } - if input.MinP != nil { - config.MinP = input.MinP - } - - if input.Backend != "" { - config.Backend = input.Backend - } - - if input.ClipSkip != 0 { - config.Diffusers.ClipSkip = input.ClipSkip - } - - if input.NegativePromptScale != 0 { - config.NegativePromptScale = input.NegativePromptScale - } - - if input.NegativePrompt != "" { - config.NegativePrompt = input.NegativePrompt - } - - if input.RopeFreqBase != 0 { - config.RopeFreqBase = input.RopeFreqBase - } - - if input.RopeFreqScale != 0 { - config.RopeFreqScale = input.RopeFreqScale - } - - if input.Grammar != "" { - config.Grammar = input.Grammar - } - - if input.Temperature != nil { - config.Temperature = input.Temperature - } - - if input.Maxtokens != nil { - config.Maxtokens = input.Maxtokens - } - - if input.ResponseFormat != nil { - switch responseFormat := input.ResponseFormat.(type) { - case string: - config.ResponseFormat = responseFormat - case map[string]any: - config.ResponseFormatMap = responseFormat - } - } - - switch stop := input.Stop.(type) { - case string: - if stop != "" { - config.StopWords = 
append(config.StopWords, stop) - } - case []any: - for _, pp := range stop { - if s, ok := pp.(string); ok { - config.StopWords = append(config.StopWords, s) - } - } - } - - if len(input.Tools) > 0 { - for _, tool := range input.Tools { - input.Functions = append(input.Functions, tool.Function) - } - } - - if input.ToolsChoice != nil { - var toolChoice functions.Tool - - switch content := input.ToolsChoice.(type) { - case string: - _ = json.Unmarshal([]byte(content), &toolChoice) - case map[string]any: - dat, _ := json.Marshal(content) - _ = json.Unmarshal(dat, &toolChoice) - } - input.FunctionCall = map[string]any{ - "name": toolChoice.Function.Name, - } - } - - // Decode each request's message content - imgIndex, vidIndex, audioIndex := 0, 0, 0 - for i, m := range input.Messages { - nrOfImgsInMessage := 0 - nrOfVideosInMessage := 0 - nrOfAudiosInMessage := 0 - - switch content := m.Content.(type) { - case string: - input.Messages[i].StringContent = content - case []any: - dat, _ := json.Marshal(content) - c := []schema.Content{} - json.Unmarshal(dat, &c) - - textContent := "" - // we will template this at the end - - CONTENT: - for _, pp := range c { - switch pp.Type { - case "text": - textContent += pp.Text - //input.Messages[i].StringContent = pp.Text - case "video", "video_url": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) - if err != nil { - xlog.Error("Failed encoding video", "error", err) - continue CONTENT - } - input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff - vidIndex++ - nrOfVideosInMessage++ - case "audio_url", "audio": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL) - if err != nil { - xlog.Error("Failed encoding audio", "error", err) - continue CONTENT - } - input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff - audioIndex++ - nrOfAudiosInMessage++ - case "input_audio": - // TODO: make sure that we only return base64 stuff - input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data) - audioIndex++ - nrOfAudiosInMessage++ - case "image_url", "image": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) - if err != nil { - xlog.Error("Failed encoding image", "error", err) - continue CONTENT - } - - input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff - - imgIndex++ - nrOfImgsInMessage++ - } - } - - // When the backend handles templating itself (UseTokenizerTemplate), - // it also injects media markers server-side (see - // oaicompat_chat_params_parse in llama.cpp). Emitting our own markers - // here would double-mark them and downstream consumers ignore - // StringContent in that path anyway, so just pass through plain text. 
- if config.TemplateConfig.UseTokenizerTemplate { - input.Messages[i].StringContent = textContent - } else { - input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ - TotalImages: imgIndex, - TotalVideos: vidIndex, - TotalAudios: audioIndex, - ImagesInMessage: nrOfImgsInMessage, - VideosInMessage: nrOfVideosInMessage, - AudiosInMessage: nrOfAudiosInMessage, - }, textContent) - } - } - } - - if input.RepeatPenalty != 0 { - config.RepeatPenalty = input.RepeatPenalty - } - - if input.FrequencyPenalty != 0 { - config.FrequencyPenalty = input.FrequencyPenalty - } - - if input.PresencePenalty != 0 { - config.PresencePenalty = input.PresencePenalty - } - - if input.Keep != 0 { - config.Keep = input.Keep - } - - if input.Batch != 0 { - config.Batch = input.Batch - } - - if input.IgnoreEOS { - config.IgnoreEOS = input.IgnoreEOS - } - - if input.Seed != nil { - config.Seed = input.Seed - } - - if input.TypicalP != nil { - config.TypicalP = input.TypicalP - } - - xlog.Debug("input.Input", "input", fmt.Sprintf("%+v", input.Input)) - - switch inputs := input.Input.(type) { - case string: - if inputs != "" { - config.InputStrings = append(config.InputStrings, inputs) - } - case []any: - for _, pp := range inputs { - switch i := pp.(type) { - case string: - config.InputStrings = append(config.InputStrings, i) - case []any: - tokens := []int{} - inputStrings := []string{} - for _, ii := range i { - switch ii := ii.(type) { - case int: - tokens = append(tokens, ii) - case float64: - tokens = append(tokens, int(ii)) - case string: - inputStrings = append(inputStrings, ii) - default: - xlog.Error("Unknown input type", "type", fmt.Sprintf("%T", ii)) - } - } - config.InputToken = append(config.InputToken, tokens) - config.InputStrings = append(config.InputStrings, inputStrings...) 
- } - } - } - - // Can be either a string or an object - switch fnc := input.FunctionCall.(type) { - case string: - if fnc != "" { - config.SetFunctionCallString(fnc) - } - case map[string]any: - var name string - n, exists := fnc["name"] - if exists { - nn, e := n.(string) - if e { - name = nn - } - } - config.SetFunctionCallNameString(name) - } - - switch p := input.Prompt.(type) { - case string: - config.PromptStrings = append(config.PromptStrings, p) - case []any: - for _, pp := range p { - if s, ok := pp.(string); ok { - config.PromptStrings = append(config.PromptStrings, s) - } - } - } - - // If a quality was defined as number, convert it to step - if input.Quality != "" { - q, err := strconv.Atoi(input.Quality) - if err == nil { - config.Step = q - } - } - - if valid, _ := config.Validate(); valid { - return nil - } - return fmt.Errorf("unable to validate configuration after merging") -} - -func (re *RequestExtractor) SetOpenResponsesRequest(c echo.Context) error { - input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenResponsesRequest) - if !ok || input.Model == "" { - return echo.ErrBadRequest - } - - // Convert input items to Messages (this will be done in the endpoint handler) - // We store the input in the request for the endpoint to process - cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) - if !ok || cfg == nil { - return echo.ErrBadRequest - } - - // Extract or generate the correlation ID (Open Responses uses x-request-id) - correlationID := c.Request().Header.Get("x-request-id") - if correlationID == "" { - correlationID = uuid.New().String() - } - c.Response().Header().Set("x-request-id", correlationID) - - // Use the request context directly - Echo properly supports context cancellation! - reqCtx := c.Request().Context() - c1, cancel := context.WithCancel(re.applicationConfig.Context) - - // Cancel when request context is cancelled (client disconnects) - go func() { - select { - case <-reqCtx.Done(): - cancel() - case <-c1.Done(): - // Already cancelled - } - }() - - // Add the correlation ID to the new context - ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) - - input.Context = ctxWithCorrelationID - input.Cancel = cancel - - err := MergeOpenResponsesConfig(cfg, input) - if err != nil { - return err - } - - if cfg.Model == "" { - xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) - cfg.Model = input.Model - } - - c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) - c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) - - return nil -} - -// MergeOpenResponsesConfig merges request parameters into the model configuration. 
-func MergeOpenResponsesConfig(config *config.ModelConfig, input *schema.OpenResponsesRequest) error { - // Temperature - if input.Temperature != nil { - config.Temperature = input.Temperature - } - - // TopP - if input.TopP != nil { - config.TopP = input.TopP - } - - // MaxOutputTokens -> Maxtokens - if input.MaxOutputTokens != nil { - config.Maxtokens = input.MaxOutputTokens - } - - // Convert tools to functions - this will be handled in the endpoint handler - // We just validate that tools are present if needed - - // Handle tool_choice - if input.ToolChoice != nil { - switch tc := input.ToolChoice.(type) { - case string: - // "auto", "required", or "none" - if tc == "required" { - config.SetFunctionCallString("required") - } else if tc == "none" { - // Don't use tools - handled in endpoint - } - // "auto" is default - let model decide - case map[string]any: - // Specific tool. OpenAI spec nests the function name under "function": - // {"type":"function", "function":{"name":"..."}} - // Legacy/Anthropic-compat form puts it at the top level: - // {"type":"function", "name":"..."} - // The old code only handled the legacy shape AND used the wrong - // setter (SetFunctionCallString writes the mode field; the - // specific-function name lives in a separate field read by - // ShouldCallSpecificFunction / FunctionToCall). Net effect: a - // correctly-formed OpenAI tool_choice never engaged grammar-based - // forcing, the model got the tools but no selection hint, and - // streamed raw JSON as delta.content instead of delta.tool_calls. - if tcType, ok := tc["type"].(string); ok && tcType == "function" { - var name string - if fn, ok := tc["function"].(map[string]any); ok { - if n, ok := fn["name"].(string); ok { - name = n - } - } - if name == "" { - if n, ok := tc["name"].(string); ok { - name = n - } - } - if name != "" { - config.SetFunctionCallNameString(name) - } - } - } - } - - if valid, _ := config.Validate(); valid { - return nil - } - return fmt.Errorf("unable to validate configuration after merging") -} +package middleware + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/galleryop" + "github.com/mudler/LocalAI/core/templates" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/utils" + "github.com/mudler/xlog" +) + +type correlationIDKeyType string + +// CorrelationIDKey to track request across process boundary +const CorrelationIDKey correlationIDKeyType = "correlationID" + +type RequestExtractor struct { + modelConfigLoader *config.ModelConfigLoader + modelLoader *model.ModelLoader + applicationConfig *config.ApplicationConfig +} + +func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor { + return &RequestExtractor{ + modelConfigLoader: modelConfigLoader, + modelLoader: modelLoader, + applicationConfig: applicationConfig, + } +} + +const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME" +const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST" +const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG" + +// TODO: Refactor to not return error if unchanged +func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) { + model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if ok && model != "" { + return + } + model = 
c.Param("model") + + if model == "" { + model = c.QueryParam("model") + } + + // Check FormValue for multipart/form-data requests (e.g., /v1/images/inpainting) + if model == "" { + model = c.FormValue("model") + } + + if model == "" { + // Set model from bearer token, if available + auth := c.Request().Header.Get("Authorization") + bearer := strings.TrimPrefix(auth, "Bearer ") + if bearer != "" && bearer != auth { + exists, err := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, galleryop.ALWAYS_INCLUDE) + if err == nil && exists { + model = bearer + } + } + } + + c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model) +} + +func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + re.setModelNameFromRequest(c) + localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if !ok || localModelName == "" { + c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName) + xlog.Debug("context local model name not found, setting to default", "defaultModelName", defaultModelName) + } + return next(c) + } + } +} + +func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + re.setModelNameFromRequest(c) + localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if localModelName != "" { // Don't overwrite existing values + return next(c) + } + + modelNames, err := galleryop.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, galleryop.SKIP_IF_CONFIGURED) + if err != nil { + xlog.Error("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()", "error", err) + return next(c) + } + + if len(modelNames) == 0 { + xlog.Warn("SetDefaultModelNameToFirstAvailable used with no matching models installed") + // This is non-fatal - making it so was breaking the case of direct installation of raw models + // return errors.New("this endpoint requires at least one model to be installed") + return next(c) + } + + c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0]) + xlog.Debug("context local model name not found, setting to the first model", "first model name", modelNames[0]) + return next(c) + } + } +} + +// TODO: If context and cancel above belong on all methods, move that part of above into here! 
+// Otherwise, it's in its own method below for now +func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + input := initializer() + if input == nil { + return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body") + } + if err := c.Bind(input); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err)) + } + + // If this request doesn't have an associated model name, fetch it from earlier in the middleware chain + if input.ModelName(nil) == "" { + localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if ok && localModelName != "" { + xlog.Debug("overriding empty model name in request body with value found earlier in middleware chain", "context localModelName", localModelName) + input.ModelName(&localModelName) + } + } + + modelName := input.ModelName(nil) + cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(modelName, re.applicationConfig) + + if err != nil { + xlog.Warn("Model Configuration File not found", "model", modelName, "error", err) + } else if cfg.Model == "" && modelName != "" { + xlog.Debug("config does not include model, using input", "input.ModelName", modelName) + cfg.Model = modelName + } + + // If a model name was specified, verify it actually exists before proceeding. + // Check both configured models and loose model files in the model path. + // Skip the check for HuggingFace model IDs (contain "/") since backends + // like diffusers may download these on the fly. + if modelName != "" && !strings.Contains(modelName, "/") { + exists, existsErr := galleryop.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, modelName, galleryop.ALWAYS_INCLUDE) + if existsErr == nil && !exists { + return c.JSON(http.StatusNotFound, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: fmt.Sprintf("model %q not found. To see available models, call GET /v1/models", modelName), + Code: http.StatusNotFound, + Type: "invalid_request_error", + }, + }) + } + } + + // Check if the model is disabled + if cfg != nil && cfg.IsDisabled() { + return c.JSON(http.StatusForbidden, schema.ErrorResponse{ + Error: &schema.APIError{ + Message: fmt.Sprintf("model %q is disabled and cannot be loaded. Enable it via the System page or API to use it.", modelName), + Code: http.StatusForbidden, + Type: "model_disabled", + }, + }) + } + + c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + return next(c) + } + } +} + +func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error { + input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + if !ok || input.Model == "" { + return echo.ErrBadRequest + } + + cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || cfg == nil { + return echo.ErrBadRequest + } + + // Extract or generate the correlation ID + correlationID := c.Request().Header.Get("X-Correlation-ID") + if correlationID == "" { + correlationID = uuid.New().String() + } + c.Response().Header().Set("X-Correlation-ID", correlationID) + + // Use the request context directly - Echo properly supports context cancellation! 
+ // No need for workarounds like handleConnectionCancellation + reqCtx := c.Request().Context() + c1, cancel := context.WithCancel(re.applicationConfig.Context) + + // Cancel when request context is cancelled (client disconnects) + go func() { + select { + case <-reqCtx.Done(): + cancel() + case <-c1.Done(): + // Already cancelled + } + }() + + // Add the correlation ID to the new context + ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) + + input.Context = ctxWithCorrelationID + input.Cancel = cancel + + err := mergeOpenAIRequestAndModelConfig(cfg, input) + if err != nil { + return err + } + + if cfg.Model == "" { + xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) + cfg.Model = input.Model + } + + c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + return nil +} + +func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error { + if input.Echo { + config.Echo = input.Echo + } + if input.TopK != nil { + config.TopK = input.TopK + } + if input.TopP != nil { + config.TopP = input.TopP + } + if input.MinP != nil { + config.MinP = input.MinP + } + + if input.Backend != "" { + config.Backend = input.Backend + } + + if input.ClipSkip != 0 { + config.Diffusers.ClipSkip = input.ClipSkip + } + + if input.NegativePromptScale != 0 { + config.NegativePromptScale = input.NegativePromptScale + } + + if input.NegativePrompt != "" { + config.NegativePrompt = input.NegativePrompt + } + + if input.RopeFreqBase != 0 { + config.RopeFreqBase = input.RopeFreqBase + } + + if input.RopeFreqScale != 0 { + config.RopeFreqScale = input.RopeFreqScale + } + + if input.Grammar != "" { + config.Grammar = input.Grammar + } + + if input.Temperature != nil { + config.Temperature = input.Temperature + } + + if input.Maxtokens != nil { + config.Maxtokens = input.Maxtokens + } + + if input.ResponseFormat != nil { + switch responseFormat := input.ResponseFormat.(type) { + case string: + config.ResponseFormat = responseFormat + case map[string]any: + config.ResponseFormatMap = responseFormat + } + } + + switch stop := input.Stop.(type) { + case string: + if stop != "" { + config.StopWords = append(config.StopWords, stop) + } + case []any: + for _, pp := range stop { + if s, ok := pp.(string); ok { + config.StopWords = append(config.StopWords, s) + } + } + } + + if len(input.Tools) > 0 { + for _, tool := range input.Tools { + input.Functions = append(input.Functions, tool.Function) + } + } + + if input.ToolsChoice != nil { + // OpenAI tool_choice has three valid shapes: + // + // 1. string mode: "auto" | "none" | "required" + // 2. specific tool: {"type":"function", "function":{"name":"..."}} (current spec) + // 3. legacy: {"type":"function", "name":"..."} (older / Anthropic-compat) + // + // The previous code unmarshalled all three into functions.Tool via + // json.Unmarshal and then unconditionally set input.FunctionCall = + // {"name": toolChoice.Function.Name}. That had two consequences: + // + // - For string modes, json.Unmarshal([]byte("required"), &Tool{}) fails; + // the error was silently discarded, name was "", and downstream + // SetFunctionCallNameString("") meant the requested mode never applied. + // - For the OpenAI-spec map shape, the json keys did not match + // functions.Tool's field tags, so name was "" again. 
+ // + // Mirror the parsing pattern from MergeOpenResponsesConfig (#9509) and + // route results through the existing input.FunctionCall string/map + // dispatch downstream (see the switch on input.FunctionCall in this + // same function). Tracked in #9508; sibling fix in #9526. + switch content := input.ToolsChoice.(type) { + case string: + // "auto" is the default and needs no override; "none" is handled + // at the endpoint layer by skipping tool wiring. "required" must + // reach SetFunctionCallString to engage the mode field. + if content != "" && content != "auto" { + input.FunctionCall = content + } + case map[string]any: + if tcType, ok := content["type"].(string); ok && tcType == "function" { + var name string + if fn, ok := content["function"].(map[string]any); ok { + if n, ok := fn["name"].(string); ok { + name = n + } + } + if name == "" { + if n, ok := content["name"].(string); ok { + name = n + } + } + if name != "" { + input.FunctionCall = map[string]any{"name": name} + } + } + } + } + + // Decode each request's message content + imgIndex, vidIndex, audioIndex := 0, 0, 0 + for i, m := range input.Messages { + nrOfImgsInMessage := 0 + nrOfVideosInMessage := 0 + nrOfAudiosInMessage := 0 + + switch content := m.Content.(type) { + case string: + input.Messages[i].StringContent = content + case []any: + dat, _ := json.Marshal(content) + c := []schema.Content{} + json.Unmarshal(dat, &c) + + textContent := "" + // we will template this at the end + + CONTENT: + for _, pp := range c { + switch pp.Type { + case "text": + textContent += pp.Text + //input.Messages[i].StringContent = pp.Text + case "video", "video_url": + // Decode content as base64 either if it's an URL or base64 text + base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) + if err != nil { + xlog.Error("Failed encoding video", "error", err) + continue CONTENT + } + input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff + vidIndex++ + nrOfVideosInMessage++ + case "audio_url", "audio": + // Decode content as base64 either if it's an URL or base64 text + base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL) + if err != nil { + xlog.Error("Failed encoding audio", "error", err) + continue CONTENT + } + input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff + audioIndex++ + nrOfAudiosInMessage++ + case "input_audio": + // TODO: make sure that we only return base64 stuff + input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data) + audioIndex++ + nrOfAudiosInMessage++ + case "image_url", "image": + // Decode content as base64 either if it's an URL or base64 text + base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) + if err != nil { + xlog.Error("Failed encoding image", "error", err) + continue CONTENT + } + + input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff + + imgIndex++ + nrOfImgsInMessage++ + } + } + + // When the backend handles templating itself (UseTokenizerTemplate), + // it also injects media markers server-side (see + // oaicompat_chat_params_parse in llama.cpp). Emitting our own markers + // here would double-mark them and downstream consumers ignore + // StringContent in that path anyway, so just pass through plain text. 
+ if config.TemplateConfig.UseTokenizerTemplate { + input.Messages[i].StringContent = textContent + } else { + input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ + TotalImages: imgIndex, + TotalVideos: vidIndex, + TotalAudios: audioIndex, + ImagesInMessage: nrOfImgsInMessage, + VideosInMessage: nrOfVideosInMessage, + AudiosInMessage: nrOfAudiosInMessage, + }, textContent) + } + } + } + + if input.RepeatPenalty != 0 { + config.RepeatPenalty = input.RepeatPenalty + } + + if input.FrequencyPenalty != 0 { + config.FrequencyPenalty = input.FrequencyPenalty + } + + if input.PresencePenalty != 0 { + config.PresencePenalty = input.PresencePenalty + } + + if input.Keep != 0 { + config.Keep = input.Keep + } + + if input.Batch != 0 { + config.Batch = input.Batch + } + + if input.IgnoreEOS { + config.IgnoreEOS = input.IgnoreEOS + } + + if input.Seed != nil { + config.Seed = input.Seed + } + + if input.TypicalP != nil { + config.TypicalP = input.TypicalP + } + + xlog.Debug("input.Input", "input", fmt.Sprintf("%+v", input.Input)) + + switch inputs := input.Input.(type) { + case string: + if inputs != "" { + config.InputStrings = append(config.InputStrings, inputs) + } + case []any: + for _, pp := range inputs { + switch i := pp.(type) { + case string: + config.InputStrings = append(config.InputStrings, i) + case []any: + tokens := []int{} + inputStrings := []string{} + for _, ii := range i { + switch ii := ii.(type) { + case int: + tokens = append(tokens, ii) + case float64: + tokens = append(tokens, int(ii)) + case string: + inputStrings = append(inputStrings, ii) + default: + xlog.Error("Unknown input type", "type", fmt.Sprintf("%T", ii)) + } + } + config.InputToken = append(config.InputToken, tokens) + config.InputStrings = append(config.InputStrings, inputStrings...) 
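+				// Note: a single inner array may mix token IDs (JSON numbers
+				// decode as float64) with strings; the loop above keeps both,
+				// as InputToken and InputStrings respectively.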
+ } + } + } + + // Can be either a string or an object + switch fnc := input.FunctionCall.(type) { + case string: + if fnc != "" { + config.SetFunctionCallString(fnc) + } + case map[string]any: + var name string + n, exists := fnc["name"] + if exists { + nn, e := n.(string) + if e { + name = nn + } + } + config.SetFunctionCallNameString(name) + } + + switch p := input.Prompt.(type) { + case string: + config.PromptStrings = append(config.PromptStrings, p) + case []any: + for _, pp := range p { + if s, ok := pp.(string); ok { + config.PromptStrings = append(config.PromptStrings, s) + } + } + } + + // If a quality was defined as number, convert it to step + if input.Quality != "" { + q, err := strconv.Atoi(input.Quality) + if err == nil { + config.Step = q + } + } + + if valid, _ := config.Validate(); valid { + return nil + } + return fmt.Errorf("unable to validate configuration after merging") +} + +func (re *RequestExtractor) SetOpenResponsesRequest(c echo.Context) error { + input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenResponsesRequest) + if !ok || input.Model == "" { + return echo.ErrBadRequest + } + + // Convert input items to Messages (this will be done in the endpoint handler) + // We store the input in the request for the endpoint to process + cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || cfg == nil { + return echo.ErrBadRequest + } + + // Extract or generate the correlation ID (Open Responses uses x-request-id) + correlationID := c.Request().Header.Get("x-request-id") + if correlationID == "" { + correlationID = uuid.New().String() + } + c.Response().Header().Set("x-request-id", correlationID) + + // Use the request context directly - Echo properly supports context cancellation! + reqCtx := c.Request().Context() + c1, cancel := context.WithCancel(re.applicationConfig.Context) + + // Cancel when request context is cancelled (client disconnects) + go func() { + select { + case <-reqCtx.Done(): + cancel() + case <-c1.Done(): + // Already cancelled + } + }() + + // Add the correlation ID to the new context + ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) + + input.Context = ctxWithCorrelationID + input.Cancel = cancel + + err := MergeOpenResponsesConfig(cfg, input) + if err != nil { + return err + } + + if cfg.Model == "" { + xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) + cfg.Model = input.Model + } + + c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + return nil +} + +// MergeOpenResponsesConfig merges request parameters into the model configuration. 
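+//
+// For reference, the tool_choice shapes handled below (illustrative payloads;
+// "get_weather" is a placeholder function name):
+//
+//	{"tool_choice": "required"}
+//	{"tool_choice": {"type": "function", "function": {"name": "get_weather"}}}
+//	{"tool_choice": {"type": "function", "name": "get_weather"}}
+//
+// The string form sets the function-call mode ("auto" and "none" need no
+// config change here); both map forms force the named function via
+// SetFunctionCallNameString.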
+func MergeOpenResponsesConfig(config *config.ModelConfig, input *schema.OpenResponsesRequest) error { + // Temperature + if input.Temperature != nil { + config.Temperature = input.Temperature + } + + // TopP + if input.TopP != nil { + config.TopP = input.TopP + } + + // MaxOutputTokens -> Maxtokens + if input.MaxOutputTokens != nil { + config.Maxtokens = input.MaxOutputTokens + } + + // Convert tools to functions - this will be handled in the endpoint handler + // We just validate that tools are present if needed + + // Handle tool_choice + if input.ToolChoice != nil { + switch tc := input.ToolChoice.(type) { + case string: + // "auto", "required", or "none" + if tc == "required" { + config.SetFunctionCallString("required") + } else if tc == "none" { + // Don't use tools - handled in endpoint + } + // "auto" is default - let model decide + case map[string]any: + // Specific tool. OpenAI spec nests the function name under "function": + // {"type":"function", "function":{"name":"..."}} + // Legacy/Anthropic-compat form puts it at the top level: + // {"type":"function", "name":"..."} + // The old code only handled the legacy shape AND used the wrong + // setter (SetFunctionCallString writes the mode field; the + // specific-function name lives in a separate field read by + // ShouldCallSpecificFunction / FunctionToCall). Net effect: a + // correctly-formed OpenAI tool_choice never engaged grammar-based + // forcing, the model got the tools but no selection hint, and + // streamed raw JSON as delta.content instead of delta.tool_calls. + if tcType, ok := tc["type"].(string); ok && tcType == "function" { + var name string + if fn, ok := tc["function"].(map[string]any); ok { + if n, ok := fn["name"].(string); ok { + name = n + } + } + if name == "" { + if n, ok := tc["name"].(string); ok { + name = n + } + } + if name != "" { + config.SetFunctionCallNameString(name) + } + } + } + } + + if valid, _ := config.Validate(); valid { + return nil + } + return fmt.Errorf("unable to validate configuration after merging") +}
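+
+// Illustrative end-to-end sketch of the /v1/chat/completions fix (model and
+// function names are hypothetical):
+//
+//	POST /v1/chat/completions
+//	{"model": "my-model",
+//	 "tools": [{"type": "function", "function": {"name": "get_weather"}}],
+//	 "tool_choice": {"type": "function", "function": {"name": "get_weather"}}}
+//
+// Previously the spec-shaped tool_choice parsed to an empty function name, so
+// grammar-based forcing never engaged and tool output streamed as
+// delta.content; with the parsing in mergeOpenAIRequestAndModelConfig it now
+// forces get_weather and streams delta.tool_calls.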