diff --git a/backend/backend.proto b/backend/backend.proto
index 0c54d7307e33..43b6abe6c69f 100644
--- a/backend/backend.proto
+++ b/backend/backend.proto
@@ -310,6 +310,11 @@ message ModelOptions {
   bool Reranking = 71;
   repeated string Overrides = 72;
+
+  // EngineArgs carries a JSON-encoded map of backend-native engine arguments
+  // applied verbatim to the backend's engine constructor (e.g. vLLM AsyncEngineArgs).
+  // Unknown keys produce an error at LoadModel time.
+  string EngineArgs = 73;
 }
 
 message Result {
diff --git a/backend/python/vllm/backend.py b/backend/python/vllm/backend.py
index 95ae95a9d4e6..fcdbb96cde62 100644
--- a/backend/python/vllm/backend.py
+++ b/backend/python/vllm/backend.py
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3
 import asyncio
+import dataclasses
+import difflib
 from concurrent import futures
 import argparse
 import signal
@@ -101,6 +103,36 @@ def _parse_options(self, options_list):
             opts[key.strip()] = value.strip()
         return opts
 
+    def _apply_engine_args(self, engine_args, engine_args_json):
+        """Apply user-supplied engine_args (JSON object) onto an AsyncEngineArgs.
+
+        Returns a new AsyncEngineArgs with the typed fields preserved and the
+        user's overrides layered on top. Uses ``dataclasses.replace`` so vLLM's
+        ``__post_init__`` re-runs and auto-converts dict-valued fields like
+        ``compilation_config`` / ``attention_config`` into their dataclass form.
+        ``speculative_config`` and ``kv_transfer_config`` are accepted as dicts
+        directly (vLLM converts them at engine init).
+
+        Unknown keys raise ValueError with the closest valid field as a hint.
+        """
+        if not engine_args_json:
+            return engine_args
+        try:
+            extra = json.loads(engine_args_json)
+        except json.JSONDecodeError as e:
+            raise ValueError(f"engine_args is not valid JSON: {e}") from e
+        if not isinstance(extra, dict):
+            raise ValueError(
+                f"engine_args must be a JSON object, got {type(extra).__name__}"
+            )
+        valid = {f.name for f in dataclasses.fields(type(engine_args))}
+        for key in extra:
+            if key not in valid:
+                suggestion = difflib.get_close_matches(key, valid, n=1)
+                hint = f" did you mean {suggestion[0]!r}?" if suggestion else ""
+                raise ValueError(f"unknown engine_args key {key!r}.{hint}")
+        return dataclasses.replace(engine_args, **extra)
+
     def _messages_to_dicts(self, messages):
         """Convert proto Messages to list of dicts suitable for apply_chat_template()."""
         result = []
@@ -176,6 +208,15 @@ async def LoadModel(self, request, context):
             "audio": max(request.LimitAudioPerPrompt, 1)
         }
 
+        # engine_args from YAML overrides typed fields above so operators can
+        # tune anything the AsyncEngineArgs dataclass exposes without waiting
+        # on protobuf changes.
+        try:
+            engine_args = self._apply_engine_args(engine_args, request.EngineArgs)
+        except ValueError as err:
+            print(f"engine_args error: {err}", file=sys.stderr)
+            return backend_pb2.Result(success=False, message=str(err))
+
         try:
             self.llm = AsyncLLMEngine.from_engine_args(engine_args)
         except Exception as err:
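The overlay semantics are worth seeing in isolation. Below is a minimal sketch using a stand-in dataclass (`FakeEngineArgs` is hypothetical; the real `AsyncEngineArgs` has far more fields). It demonstrates the two stdlib behaviours `_apply_engine_args` leans on: `dataclasses.replace` builds a new instance through `__init__`, so `__post_init__` re-runs, and `difflib.get_close_matches` produces the typo hint.

```python
import dataclasses
import difflib


@dataclasses.dataclass
class FakeEngineArgs:
    # Hypothetical stand-in for vLLM's AsyncEngineArgs, trimmed to the
    # fields the tests exercise; the real dataclass is far larger.
    model: str
    trust_remote_code: bool = False
    max_num_seqs: int = 256

    def __post_init__(self):
        # The real __post_init__ normalises dict-valued fields; here we
        # just record that replace() routed through __init__ again.
        self.post_init_ran = True


base = FakeEngineArgs(model="facebook/opt-125m")
out = dataclasses.replace(base, trust_remote_code=True, max_num_seqs=32)

assert out.post_init_ran                 # __post_init__ re-ran on the copy
assert out.model == "facebook/opt-125m"  # untouched fields carry over
assert base.max_num_seqs == 256          # the original is never mutated

# The unknown-key hint: closest valid field name via difflib.
valid = {f.name for f in dataclasses.fields(FakeEngineArgs)}
print(difflib.get_close_matches("trustremotecode", valid, n=1))
# -> ['trust_remote_code']
```

Going through `replace` rather than `setattr` is what lets vLLM's own normalisation of dict-valued fields run again over the merged arguments.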
diff --git a/backend/python/vllm/requirements-cublas12-after.txt b/backend/python/vllm/requirements-cublas12-after.txt
index cab27c888e27..e6a61ea11ea2 100644
--- a/backend/python/vllm/requirements-cublas12-after.txt
+++ b/backend/python/vllm/requirements-cublas12-after.txt
@@ -1,2 +1,9 @@
-https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
+# flash-attn wheels are ABI-tied to a specific torch version. vllm forces
+# torch==2.10.0 as a hard dep, but flash-attn 2.8.3 (latest) only ships
+# prebuilt wheels up to torch 2.8 — any wheel we pin here gets silently
+# broken when vllm upgrades torch during install, producing an undefined
+# libc10_cuda symbol at import time. FlashInfer (required by vllm) covers
+# attention, and rotary_embedding/common.py guards the flash_attn import
+# with find_spec(), so skipping flash-attn is safe and the only stable
+# choice until upstream ships a torch-2.10 wheel.
 vllm
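The `find_spec()` guard the comment relies on is the standard optional-dependency probe. A sketch of the pattern (not the verbatim upstream code in `rotary_embedding/common.py`):

```python
from importlib.util import find_spec

# find_spec() probes the import machinery without importing the package,
# returning None when flash_attn is absent instead of raising ImportError.
HAS_FLASH_ATTN = find_spec("flash_attn") is not None

if HAS_FLASH_ATTN:
    import flash_attn  # only reached when a compatible wheel is installed
```

This is why dropping the wheel is safe: import-time code never touches `flash_attn` unless a working install is actually present.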
diff --git a/backend/python/vllm/requirements-cublas12.txt b/backend/python/vllm/requirements-cublas12.txt
index 8bd72ae125fd..e007f0946daa 100644
--- a/backend/python/vllm/requirements-cublas12.txt
+++ b/backend/python/vllm/requirements-cublas12.txt
@@ -1,4 +1,4 @@
 accelerate
-torch==2.7.0
+torch
 transformers
 bitsandbytes
\ No newline at end of file
diff --git a/backend/python/vllm/test.py b/backend/python/vllm/test.py
index 21aaf4cf785e..25a7f54e6354 100644
--- a/backend/python/vllm/test.py
+++ b/backend/python/vllm/test.py
@@ -168,6 +168,58 @@ def test_parse_options(self):
         self.assertEqual(opts["key_with_colons"], "a:b:c")
         self.assertNotIn("invalid_no_colon", opts)
 
+    def test_apply_engine_args_known_keys(self):
+        """
+        Tests _apply_engine_args overlays user-supplied JSON onto AsyncEngineArgs.
+        """
+        import sys, os, json as _json
+        sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+        from backend import BackendServicer
+        from vllm.engine.arg_utils import AsyncEngineArgs
+
+        servicer = BackendServicer()
+        base = AsyncEngineArgs(model="facebook/opt-125m")
+        extras = _json.dumps({
+            "trust_remote_code": True,
+            "max_num_seqs": 32,
+        })
+        out = servicer._apply_engine_args(base, extras)
+        self.assertTrue(out.trust_remote_code)
+        self.assertEqual(out.max_num_seqs, 32)
+        # untouched fields preserved
+        self.assertEqual(out.model, "facebook/opt-125m")
+
+    def test_apply_engine_args_unknown_key_raises(self):
+        """
+        Tests _apply_engine_args rejects unknown keys with a helpful suggestion.
+        """
+        import sys, os, json as _json
+        sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+        from backend import BackendServicer
+        from vllm.engine.arg_utils import AsyncEngineArgs
+
+        servicer = BackendServicer()
+        base = AsyncEngineArgs(model="facebook/opt-125m")
+        with self.assertRaises(ValueError) as ctx:
+            servicer._apply_engine_args(base, _json.dumps({"trustremotecode": True}))
+        self.assertIn("trustremotecode", str(ctx.exception))
+        # close-match hint for the typo
+        self.assertIn("trust_remote_code", str(ctx.exception))
+
+    def test_apply_engine_args_empty_passthrough(self):
+        """
+        Tests that empty engine_args returns the base unchanged.
+        """
+        import sys, os
+        sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+        from backend import BackendServicer
+        from vllm.engine.arg_utils import AsyncEngineArgs
+
+        servicer = BackendServicer()
+        base = AsyncEngineArgs(model="facebook/opt-125m")
+        self.assertIs(servicer._apply_engine_args(base, ""), base)
+        self.assertIs(servicer._apply_engine_args(base, None), base)
+
     def test_tokenize_string(self):
         """
         Tests the TokenizeString RPC returns valid tokens.
diff --git a/core/backend/options.go b/core/backend/options.go
index b09782ce2ca7..5b42a6354749 100644
--- a/core/backend/options.go
+++ b/core/backend/options.go
@@ -1,6 +1,7 @@
 package backend
 
 import (
+	"encoding/json"
 	"math/rand/v2"
 	"os"
 	"path/filepath"
@@ -159,6 +160,18 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
 		})
 	}
 
+	engineArgsJSON := ""
+	if len(c.EngineArgs) > 0 {
+		if buf, err := json.Marshal(c.EngineArgs); err != nil {
+			// EngineArgs failing to marshal is a config bug, not a runtime
+			// condition — surface it loudly but don't break load (the backend
+			// will see an empty string and proceed with typed args only).
+			xlog.Warn("engine_args failed to marshal; ignoring", "model", c.Model, "err", err)
+		} else {
+			engineArgsJSON = string(buf)
+		}
+	}
+
 	opts := &pb.ModelOptions{
 		CUDA:          c.CUDA || c.Diffusers.CUDA,
 		SchedulerType: c.Diffusers.SchedulerType,
@@ -176,6 +189,7 @@ func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions {
 		CLIPSubfolder: c.Diffusers.ClipSubFolder,
 		Options:       c.Options,
 		Overrides:     c.Overrides,
+		EngineArgs:    engineArgsJSON,
 		CLIPSkip:      int32(c.Diffusers.ClipSkip),
 		ControlNet:    c.Diffusers.ControlNet,
 		ContextSize:   int32(ctxSize),
diff --git a/core/backend/options_internal_test.go b/core/backend/options_internal_test.go
new file mode 100644
index 000000000000..f44678c1ad49
--- /dev/null
+++ b/core/backend/options_internal_test.go
@@ -0,0 +1,61 @@
+package backend
+
+import (
+	"encoding/json"
+	"testing"
+
+	"github.com/mudler/LocalAI/core/config"
+)
+
+func TestGrpcModelOpts_EngineArgsSerialization(t *testing.T) {
+	threads := 1
+	cfg := config.ModelConfig{
+		Threads: &threads,
+		LLMConfig: config.LLMConfig{
+			EngineArgs: map[string]any{
+				"data_parallel_size":     8,
+				"enable_expert_parallel": true,
+				"speculative_config": map[string]any{
+					"method":                 "ngram",
+					"num_speculative_tokens": 4,
+				},
+			},
+		},
+	}
+
+	opts := grpcModelOpts(cfg, "/tmp/models")
+
+	if opts.EngineArgs == "" {
+		t.Fatal("EngineArgs proto field is empty; expected JSON-marshalled map")
+	}
+
+	var round map[string]any
+	if err := json.Unmarshal([]byte(opts.EngineArgs), &round); err != nil {
+		t.Fatalf("EngineArgs is not valid JSON: %v", err)
+	}
+
+	if round["data_parallel_size"].(float64) != 8 {
+		t.Errorf("data_parallel_size lost in roundtrip: %v", round["data_parallel_size"])
+	}
+	if round["enable_expert_parallel"].(bool) != true {
+		t.Errorf("enable_expert_parallel lost in roundtrip: %v", round["enable_expert_parallel"])
+	}
+	spec, ok := round["speculative_config"].(map[string]any)
+	if !ok {
+		t.Fatalf("speculative_config not preserved as nested object: %T", round["speculative_config"])
+	}
+	if spec["method"].(string) != "ngram" {
+		t.Errorf("speculative_config.method lost in roundtrip: %v", spec["method"])
+	}
+}
+
+func TestGrpcModelOpts_EngineArgsEmptyWhenUnset(t *testing.T) {
+	threads := 1
+	cfg := config.ModelConfig{Threads: &threads}
+
+	opts := grpcModelOpts(cfg, "/tmp/models")
+
+	if opts.EngineArgs != "" {
+		t.Errorf("expected empty EngineArgs when unset, got %q", opts.EngineArgs)
+	}
+}
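One cross-language detail the roundtrip test above encodes: Go's `json.Unmarshal` into `map[string]any` decodes every JSON number as `float64` (hence the `.(float64)` assertion), while `json.loads` on the Python side yields native ints. A quick stdlib-only check of the same payload as the backend would see it:

```python
import json

# The wire form grpcModelOpts would produce for the config in the Go test.
payload = json.dumps({
    "data_parallel_size": 8,
    "enable_expert_parallel": True,
    "speculative_config": {"method": "ngram", "num_speculative_tokens": 4},
})

extra = json.loads(payload)
assert isinstance(extra["data_parallel_size"], int)   # int again, not float
assert isinstance(extra["speculative_config"], dict)  # nested object survives
```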
diff --git a/core/config/hooks_test.go b/core/config/hooks_test.go
index b97077564e96..12aad2558564 100644
--- a/core/config/hooks_test.go
+++ b/core/config/hooks_test.go
@@ -110,5 +110,30 @@ var _ = Describe("Backend hooks and parser defaults", func() {
 			}
 			Expect(count).To(Equal(1))
 		})
+
+		It("seeds production engine_args defaults", func() {
+			cfg := &ModelConfig{Backend: "vllm"}
+			cfg.SetDefaults()
+
+			Expect(cfg.EngineArgs).NotTo(BeNil())
+			Expect(cfg.EngineArgs["enable_prefix_caching"]).To(Equal(true))
+			Expect(cfg.EngineArgs["enable_chunked_prefill"]).To(Equal(true))
+		})
+
+		It("does not override user-set engine_args", func() {
+			cfg := &ModelConfig{
+				Backend: "vllm",
+				LLMConfig: LLMConfig{
+					EngineArgs: map[string]any{
+						"enable_prefix_caching": false,
+					},
+				},
+			}
+			cfg.SetDefaults()
+
+			Expect(cfg.EngineArgs["enable_prefix_caching"]).To(Equal(false))
+			// chunked_prefill is still seeded since user didn't set it
+			Expect(cfg.EngineArgs["enable_chunked_prefill"]).To(Equal(true))
+		})
 	})
 })
diff --git a/core/config/hooks_vllm.go b/core/config/hooks_vllm.go
index 3f7abd9b393a..ffdd1a52a2aa 100644
--- a/core/config/hooks_vllm.go
+++ b/core/config/hooks_vllm.go
@@ -45,8 +45,34 @@ func MatchParserDefaults(modelID string) map[string]string {
 	return nil
 }
 
+// productionEngineArgsDefaults are vLLM ≥ 0.6 features that production deployments
+// almost always want. Applied at load time when the user hasn't set the key in
+// engine_args. Anything user-supplied wins; we never silently override.
+var productionEngineArgsDefaults = map[string]any{
+	"enable_prefix_caching":  true,
+	"enable_chunked_prefill": true,
+}
+
 func vllmDefaults(cfg *ModelConfig, modelPath string) {
-	// Check if user already set tool_parser or reasoning_parser in Options
+	applyEngineArgDefaults(cfg)
+	applyParserDefaults(cfg)
+}
+
+// applyEngineArgDefaults seeds production-friendly engine_args without overwriting
+// anything the user already set.
+func applyEngineArgDefaults(cfg *ModelConfig) {
+	if cfg.EngineArgs == nil {
+		cfg.EngineArgs = map[string]any{}
+	}
+	for k, v := range productionEngineArgsDefaults {
+		if _, set := cfg.EngineArgs[k]; set {
+			continue
+		}
+		cfg.EngineArgs[k] = v
+	}
+}
+
+func applyParserDefaults(cfg *ModelConfig) {
 	hasToolParser := false
 	hasReasoningParser := false
 	for _, opt := range cfg.Options {
@@ -61,7 +87,6 @@ func vllmDefaults(cfg *ModelConfig, modelPath string) {
 		return
 	}
 
-	// Try matching against Model field, then Name
 	parsers := MatchParserDefaults(cfg.Model)
 	if parsers == nil {
 		parsers = MatchParserDefaults(cfg.Name)
diff --git a/core/config/model_config.go b/core/config/model_config.go
index 1184d8452a71..4229ce5a7e3f 100644
--- a/core/config/model_config.go
+++ b/core/config/model_config.go
@@ -241,7 +241,13 @@ type LLMConfig struct {
 	DisableLogStatus bool             `yaml:"disable_log_stats,omitempty" json:"disable_log_stats,omitempty"` // vLLM
 	DType            string           `yaml:"dtype,omitempty" json:"dtype,omitempty"`                         // vLLM
 	LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt,omitempty" json:"limit_mm_per_prompt,omitempty"` // vLLM
-	MMProj           string           `yaml:"mmproj,omitempty" json:"mmproj,omitempty"`
+	// EngineArgs is a backend-native passthrough applied to the engine constructor
+	// (e.g. vLLM AsyncEngineArgs). Values may be primitives or nested maps; nested
+	// maps materialise into the backend's nested config dataclasses (e.g.
+	// SpeculativeConfig, KVTransferConfig, CompilationConfig). Unknown keys cause
+	// the backend to fail LoadModel with the closest valid name as a hint.
+	EngineArgs map[string]any `yaml:"engine_args,omitempty" json:"engine_args,omitempty"`
+	MMProj     string         `yaml:"mmproj,omitempty" json:"mmproj,omitempty"`
 	FlashAttention *string `yaml:"flash_attention,omitempty" json:"flash_attention,omitempty"`
 	NoKVOffloading bool    `yaml:"no_kv_offloading,omitempty" json:"no_kv_offloading,omitempty"`
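End to end, an operator would exercise the new field from a model YAML along these lines (the model name and source are placeholders; the `engine_args` keys are the ones used in the tests above). Since `enable_prefix_caching` and `enable_chunked_prefill` are seeded for vLLM by `SetDefaults`, setting one explicitly, as here, wins over the seeded value:

```yaml
name: my-vllm-model               # placeholder
backend: vllm
parameters:
  model: facebook/opt-125m        # placeholder
engine_args:                      # passed through verbatim to AsyncEngineArgs
  enable_prefix_caching: false    # overrides the seeded production default
  data_parallel_size: 8
  speculative_config:             # nested map, handed to vLLM as a dict
    method: ngram
    num_speculative_tokens: 4
```

An unknown key (say `trustremotecode`) fails `LoadModel` with the backend's close-match hint rather than being silently dropped.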