5 changes: 5 additions & 0 deletions backend/backend.proto
@@ -310,6 +310,11 @@ message ModelOptions {
bool Reranking = 71;

repeated string Overrides = 72;

// EngineArgs carries a JSON-encoded map of backend-native engine arguments
// applied verbatim to the backend's engine constructor (e.g. vLLM AsyncEngineArgs).
// Unknown keys produce an error at LoadModel time.
string EngineArgs = 73;
Owner:

While I'm OK with this in general, we already carry `repeated string Options = 62;`, which is already used for this purpose. That field itself calls for refactoring: a `map<string, string>` would make much more sense.

}

message Result {
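As a hedged illustration of the wire contract, here is how a gRPC caller might populate the new field, assuming a `backend_pb2` module generated from this proto (the `Model` field and the option values are assumptions for the example; only `EngineArgs = 73` is defined above):

```python
# Illustrative sketch only: assumes backend_pb2 is generated from backend.proto.
import json

import backend_pb2

opts = backend_pb2.ModelOptions(
    Model="facebook/opt-125m",   # assumed existing field in ModelOptions
    EngineArgs=json.dumps({      # new field 73: JSON-encoded map of engine args
        "max_num_seqs": 32,
        "enable_prefix_caching": True,
    }),
)
```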
41 changes: 41 additions & 0 deletions backend/python/vllm/backend.py
@@ -1,5 +1,7 @@
#!/usr/bin/env python3
import asyncio
import dataclasses
import difflib
from concurrent import futures
import argparse
import signal
@@ -101,6 +103,36 @@ def _parse_options(self, options_list):
opts[key.strip()] = value.strip()
return opts

def _apply_engine_args(self, engine_args, engine_args_json):
"""Apply user-supplied engine_args (JSON object) onto an AsyncEngineArgs.

Returns a new AsyncEngineArgs with the typed fields preserved and the
user's overrides layered on top. Uses ``dataclasses.replace`` so vLLM's
``__post_init__`` re-runs and auto-converts dict-valued fields like
``compilation_config`` / ``attention_config`` into their dataclass form.
``speculative_config`` and ``kv_transfer_config`` are accepted as dicts
directly (vLLM converts them at engine init).

Unknown keys raise ValueError with the closest valid field as a hint.
"""
if not engine_args_json:
return engine_args
try:
extra = json.loads(engine_args_json)
except json.JSONDecodeError as e:
raise ValueError(f"engine_args is not valid JSON: {e}") from e
if not isinstance(extra, dict):
raise ValueError(
f"engine_args must be a JSON object, got {type(extra).__name__}"
)
valid = {f.name for f in dataclasses.fields(type(engine_args))}
for key in extra:
if key not in valid:
suggestion = difflib.get_close_matches(key, valid, n=1)
hint = f" did you mean {suggestion[0]!r}?" if suggestion else ""
raise ValueError(f"unknown engine_args key {key!r}.{hint}")
return dataclasses.replace(engine_args, **extra)
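Why `dataclasses.replace` rather than plain `setattr`: `replace()` builds a new instance, so `__post_init__` runs again over the merged fields. A minimal self-contained sketch of that behavior, using a toy dataclass rather than vLLM's:

```python
# Toy example of the dataclasses.replace pattern used in _apply_engine_args:
# replace() constructs a fresh instance, so __post_init__ re-runs and can
# validate or convert the overridden fields. Not vLLM code.
import dataclasses

@dataclasses.dataclass
class ToyArgs:
    size: int = 1

    def __post_init__(self):
        # validation/conversion hook, analogous to vLLM's dict-to-dataclass logic
        if self.size < 1:
            raise ValueError("size must be >= 1")

base = ToyArgs()
merged = dataclasses.replace(base, size=8)  # __post_init__ ran again here
assert (base.size, merged.size) == (1, 8)
```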

def _messages_to_dicts(self, messages):
"""Convert proto Messages to list of dicts suitable for apply_chat_template()."""
result = []
@@ -176,6 +208,15 @@ async def LoadModel(self, request, context):
"audio": max(request.LimitAudioPerPrompt, 1)
}

# engine_args from YAML overrides typed fields above so operators can
# tune anything the AsyncEngineArgs dataclass exposes without waiting
# on protobuf changes.
try:
engine_args = self._apply_engine_args(engine_args, request.EngineArgs)
except ValueError as err:
print(f"engine_args error: {err}", file=sys.stderr)
return backend_pb2.Result(success=False, message=str(err))

try:
self.llm = AsyncLLMEngine.from_engine_args(engine_args)
except Exception as err:
9 changes: 8 additions & 1 deletion backend/python/vllm/requirements-cublas12-after.txt
@@ -1,2 +1,9 @@
https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
# flash-attn wheels are ABI-tied to a specific torch version. vllm forces
# torch==2.10.0 as a hard dep, but flash-attn 2.8.3 (latest) only ships
# prebuilt wheels up to torch 2.8 — any wheel we pin here gets silently
# broken when vllm upgrades torch during install, producing an undefined
# libc10_cuda symbol at import time. FlashInfer (required by vllm) covers
# attention, and rotary_embedding/common.py guards the flash_attn import
# with find_spec(), so skipping flash-attn is safe and the only stable
# choice until upstream ships a torch-2.10 wheel.
vllm
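For reference, a sketch of the import guard the comment above refers to; the real check lives in vLLM's rotary_embedding/common.py, so this is an illustrative reconstruction of the pattern, not the upstream code:

```python
# Optional-import guard pattern (illustrative, not vLLM's actual source).
from importlib.util import find_spec

if find_spec("flash_attn") is not None:
    from flash_attn import flash_attn_func  # fast path when a compatible wheel is installed
else:
    flash_attn_func = None  # callers fall back to the non-flash attention path
```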
2 changes: 1 addition & 1 deletion backend/python/vllm/requirements-cublas12.txt
@@ -1,4 +1,4 @@
accelerate
torch==2.7.0
torch
transformers
bitsandbytes
52 changes: 52 additions & 0 deletions backend/python/vllm/test.py
@@ -168,6 +168,58 @@ def test_parse_options(self):
self.assertEqual(opts["key_with_colons"], "a:b:c")
self.assertNotIn("invalid_no_colon", opts)

def test_apply_engine_args_known_keys(self):
"""
Tests _apply_engine_args overlays user-supplied JSON onto AsyncEngineArgs.
"""
import sys, os, json as _json
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from backend import BackendServicer
from vllm.engine.arg_utils import AsyncEngineArgs

servicer = BackendServicer()
base = AsyncEngineArgs(model="facebook/opt-125m")
extras = _json.dumps({
"trust_remote_code": True,
"max_num_seqs": 32,
})
out = servicer._apply_engine_args(base, extras)
self.assertTrue(out.trust_remote_code)
self.assertEqual(out.max_num_seqs, 32)
# untouched fields preserved
self.assertEqual(out.model, "facebook/opt-125m")

def test_apply_engine_args_unknown_key_raises(self):
"""
Tests _apply_engine_args rejects unknown keys with a helpful suggestion.
"""
import sys, os, json as _json
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from backend import BackendServicer
from vllm.engine.arg_utils import AsyncEngineArgs

servicer = BackendServicer()
base = AsyncEngineArgs(model="facebook/opt-125m")
with self.assertRaises(ValueError) as ctx:
servicer._apply_engine_args(base, _json.dumps({"trustremotecode": True}))
self.assertIn("trustremotecode", str(ctx.exception))
# close-match hint for the typo
self.assertIn("trust_remote_code", str(ctx.exception))

def test_apply_engine_args_empty_passthrough(self):
"""
Tests that empty engine_args returns the base unchanged.
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from backend import BackendServicer
from vllm.engine.arg_utils import AsyncEngineArgs

servicer = BackendServicer()
base = AsyncEngineArgs(model="facebook/opt-125m")
self.assertIs(servicer._apply_engine_args(base, ""), base)
self.assertIs(servicer._apply_engine_args(base, None), base)

def test_tokenize_string(self):
"""
Tests the TokenizeString RPC returns valid tokens.
14 changes: 14 additions & 0 deletions core/backend/options.go
@@ -1,6 +1,7 @@
package backend

import (
"encoding/json"
"math/rand/v2"
"os"
"path/filepath"
@@ -159,6 +160,18 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
})
}

engineArgsJSON := ""
if len(c.EngineArgs) > 0 {
if buf, err := json.Marshal(c.EngineArgs); err != nil {
// EngineArgs failing to marshal is a config bug, not a runtime
// condition — surface it loudly but don't break load (the backend
// will see an empty string and proceed with typed args only).
xlog.Warn("engine_args failed to marshal; ignoring", "model", c.Model, "err", err)
} else {
engineArgsJSON = string(buf)
}
}

opts := &pb.ModelOptions{
CUDA: c.CUDA || c.Diffusers.CUDA,
SchedulerType: c.Diffusers.SchedulerType,
@@ -176,6 +189,7 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
CLIPSubfolder: c.Diffusers.ClipSubFolder,
Options: c.Options,
Overrides: c.Overrides,
EngineArgs: engineArgsJSON,
CLIPSkip: int32(c.Diffusers.ClipSkip),
ControlNet: c.Diffusers.ControlNet,
ContextSize: int32(ctxSize),
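To make the contract concrete, a short sketch of what the Python backend sees in request.EngineArgs after this block runs (the keys are illustrative vLLM options; nested maps survive the roundtrip as nested JSON objects):

```python
# What arrives on the Python side once grpcModelOpts JSON-marshals engine_args.
import json

wire = '{"data_parallel_size":8,"speculative_config":{"method":"ngram"}}'
extra = json.loads(wire)
assert extra["data_parallel_size"] == 8
assert extra["speculative_config"]["method"] == "ngram"
```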
61 changes: 61 additions & 0 deletions core/backend/options_internal_test.go
@@ -0,0 +1,61 @@
package backend

import (
"encoding/json"
"testing"

"github.com/mudler/LocalAI/core/config"
)

func TestGrpcModelOpts_EngineArgsSerialization(t *testing.T) {
threads := 1
cfg := config.ModelConfig{
Threads: &threads,
LLMConfig: config.LLMConfig{
EngineArgs: map[string]any{
"data_parallel_size": 8,
"enable_expert_parallel": true,
"speculative_config": map[string]any{
"method": "ngram",
"num_speculative_tokens": 4,
},
},
},
}

opts := grpcModelOpts(cfg, "/tmp/models")

if opts.EngineArgs == "" {
t.Fatal("EngineArgs proto field is empty; expected JSON-marshalled map")
}

var round map[string]any
if err := json.Unmarshal([]byte(opts.EngineArgs), &round); err != nil {
t.Fatalf("EngineArgs is not valid JSON: %v", err)
}

if round["data_parallel_size"].(float64) != 8 {
t.Errorf("data_parallel_size lost in roundtrip: %v", round["data_parallel_size"])
}
if round["enable_expert_parallel"].(bool) != true {
t.Errorf("enable_expert_parallel lost in roundtrip: %v", round["enable_expert_parallel"])
}
spec, ok := round["speculative_config"].(map[string]any)
if !ok {
t.Fatalf("speculative_config not preserved as nested object: %T", round["speculative_config"])
}
if spec["method"].(string) != "ngram" {
t.Errorf("speculative_config.method lost in roundtrip: %v", spec["method"])
}
}

func TestGrpcModelOpts_EngineArgsEmptyWhenUnset(t *testing.T) {
threads := 1
cfg := config.ModelConfig{Threads: &threads}

opts := grpcModelOpts(cfg, "/tmp/models")

if opts.EngineArgs != "" {
t.Errorf("expected empty EngineArgs when unset, got %q", opts.EngineArgs)
}
}
25 changes: 25 additions & 0 deletions core/config/hooks_test.go
@@ -110,5 +110,30 @@ var _ = Describe("Backend hooks and parser defaults", func() {
}
Expect(count).To(Equal(1))
})

It("seeds production engine_args defaults", func() {
cfg := &ModelConfig{Backend: "vllm"}
cfg.SetDefaults()

Expect(cfg.EngineArgs).NotTo(BeNil())
Expect(cfg.EngineArgs["enable_prefix_caching"]).To(Equal(true))
Expect(cfg.EngineArgs["enable_chunked_prefill"]).To(Equal(true))
})

It("does not override user-set engine_args", func() {
cfg := &ModelConfig{
Backend: "vllm",
LLMConfig: LLMConfig{
EngineArgs: map[string]any{
"enable_prefix_caching": false,
},
},
}
cfg.SetDefaults()

Expect(cfg.EngineArgs["enable_prefix_caching"]).To(Equal(false))
// chunked_prefill is still seeded since user didn't set it
Expect(cfg.EngineArgs["enable_chunked_prefill"]).To(Equal(true))
})
})
})
29 changes: 27 additions & 2 deletions core/config/hooks_vllm.go
@@ -45,8 +45,34 @@ func MatchParserDefaults(modelID string) map[string]string {
return nil
}

// productionEngineArgsDefaults are vLLM ≥ 0.6 features that production deployments
// almost always want. Applied at load time when the user hasn't set the key in
// engine_args. Anything user-supplied wins; we never silently override.
var productionEngineArgsDefaults = map[string]any{
"enable_prefix_caching": true,
"enable_chunked_prefill": true,
}

func vllmDefaults(cfg *ModelConfig, modelPath string) {
// Check if user already set tool_parser or reasoning_parser in Options
applyEngineArgDefaults(cfg)
applyParserDefaults(cfg)
}

// applyEngineArgDefaults seeds production-friendly engine_args without overwriting
// anything the user already set.
func applyEngineArgDefaults(cfg *ModelConfig) {
if cfg.EngineArgs == nil {
cfg.EngineArgs = map[string]any{}
}
for k, v := range productionEngineArgsDefaults {
if _, set := cfg.EngineArgs[k]; set {
continue
}
cfg.EngineArgs[k] = v
}
}

func applyParserDefaults(cfg *ModelConfig) {
hasToolParser := false
hasReasoningParser := false
for _, opt := range cfg.Options {
Expand All @@ -61,7 +87,6 @@ func vllmDefaults(cfg *ModelConfig, modelPath string) {
return
}

// Try matching against Model field, then Name
parsers := MatchParserDefaults(cfg.Model)
if parsers == nil {
parsers = MatchParserDefaults(cfg.Name)
8 changes: 7 additions & 1 deletion core/config/model_config.go
@@ -241,7 +241,13 @@ type LLMConfig struct {
DisableLogStatus bool `yaml:"disable_log_stats,omitempty" json:"disable_log_stats,omitempty"` // vLLM
DType string `yaml:"dtype,omitempty" json:"dtype,omitempty"` // vLLM
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt,omitempty" json:"limit_mm_per_prompt,omitempty"` // vLLM
MMProj string `yaml:"mmproj,omitempty" json:"mmproj,omitempty"`
// EngineArgs is a backend-native passthrough applied to the engine constructor
// (e.g. vLLM AsyncEngineArgs). Values may be primitives or nested maps; nested
// maps materialise into the backend's nested config dataclasses (e.g.
// SpeculativeConfig, KVTransferConfig, CompilationConfig). Unknown keys cause
// the backend to fail LoadModel with a list of valid names.
EngineArgs map[string]any `yaml:"engine_args,omitempty" json:"engine_args,omitempty"`
MMProj string `yaml:"mmproj,omitempty" json:"mmproj,omitempty"`

FlashAttention *string `yaml:"flash_attention,omitempty" json:"flash_attention,omitempty"`
NoKVOffloading bool `yaml:"no_kv_offloading,omitempty" json:"no_kv_offloading,omitempty"`