feat(vllm): expose AsyncEngineArgs via generic engine_args YAML map#9563
Draft
feat(vllm): expose AsyncEngineArgs via generic engine_args YAML map#9563
Conversation
The pinned flash-attn 2.8.3+cu12torch2.7 wheel breaks at import time
once vllm 0.19.1 upgrades torch to its hard-pinned 2.10.0:
ImportError: .../flash_attn_2_cuda...so: undefined symbol:
_ZN3c104cuda29c10_cuda_check_implementationEiPKcS2_ib
That C10 CUDA symbol is libtorch-version-specific. Dao-AILab has not yet
published flash-attn wheels for torch 2.10 -- the latest release (2.8.3)
tops out at torch 2.8 -- so any wheel pinned here is silently ABI-broken
the moment vllm completes its install.
vllm 0.19.1 lists flashinfer-python==0.6.6 as a hard dep, which already
covers the attention path. The only other use of flash-attn in vllm is
the rotary apply_rotary import in
vllm/model_executor/layers/rotary_embedding/common.py, which is guarded
by find_spec("flash_attn") and falls back cleanly when absent.
Also unpin torch in requirements-cublas12.txt: the 2.7.0 pin only
existed to give the flash-attn wheel a matching torch to link against.
With flash-attn gone, vllm's own torch==2.10.0 dep is the binding
constraint regardless of what we put here.
Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>
LocalAI's vLLM backend wraps a small typed subset of vLLM's
AsyncEngineArgs (quantization, tensor_parallel_size, dtype, etc.).
Anything outside that subset -- pipeline/data/expert parallelism,
speculative_config, kv_transfer_config, all2all_backend, prefix
caching, chunked prefill, etc. -- requires a new protobuf field, a
Go struct field, an options.go line, and a backend.py mapping per
feature. That cadence is the bottleneck on shipping vLLM's
production feature set.
Add a generic `engine_args:` map on the model YAML that is
JSON-serialised into a new ModelOptions.EngineArgs proto field and
applied verbatim to AsyncEngineArgs at LoadModel time. Validation
is done by the Python backend via dataclasses.fields(); unknown
keys fail with the closest valid name as a hint.
dataclasses.replace() is used so vLLM's __post_init__ re-runs and
auto-converts dict values into nested config dataclasses
(CompilationConfig, AttentionConfig, ...). speculative_config and
kv_transfer_config flow through as dicts; vLLM converts them at
engine init.
Operators can now write:
engine_args:
data_parallel_size: 8
enable_expert_parallel: true
all2all_backend: deepep_low_latency
speculative_config:
method: deepseek_mtp
num_speculative_tokens: 3
kv_cache_dtype: fp8
without further proto/Go/Python plumbing per field.
Production defaults seeded by hooks_vllm.go: enable_prefix_caching
and enable_chunked_prefill default to true unless explicitly set.
Existing typed YAML fields (gpu_memory_utilization,
tensor_parallel_size, etc.) remain for back-compat; engine_args
overrides them when both are set.
Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>
mudler
reviewed
Apr 25, 2026
| // EngineArgs carries a JSON-encoded map of backend-native engine arguments | ||
| // applied verbatim to the backend's engine constructor (e.g. vLLM AsyncEngineArgs). | ||
| // Unknown keys produce an error at LoadModel time. | ||
| string EngineArgs = 73; |
Owner
There was a problem hiding this comment.
while I'm ok with it in general, we already carry a repeated string Options = 62; that is used already for this purpose (which calls already for refactoring, as would make much more sense to have a map<string, string> instead)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Allow arbitrary vLLM options to be set exposing far more features
Notes for Reviewers
Signed commits