From f094164f3c2f126cc4ab2b2ac4fed3d02b722a1d Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Tue, 18 Mar 2025 03:07:40 +0000 Subject: [PATCH 1/5] fix to support multiple sampling --- functionary/sglang_inference.py | 83 ++++++++++--------- .../sglang_monkey_patch/tokenizer_manager.py | 9 +- pyproject.toml | 2 + server_sglang.py | 7 +- 4 files changed, 60 insertions(+), 41 deletions(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index 1d87a0b4..fcb89a0a 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -159,7 +159,6 @@ def v1_chat_generate_request( top_logprobs_num=request.top_logprobs, stream=request.stream, return_text_in_logprobs=True, - rid=f"chatcmpl-{uuid.uuid4().hex}", ) return adapted_request, request @@ -393,7 +392,7 @@ async def completion_stream_generator(params: ChatCompletionParams): async def v1_chat_generate_completion( params: ChatCompletionParams, -) -> Tuple[Union[StreamingResponse, str], Optional[JSONResponse]]: +) -> Tuple[Union[StreamingResponse, str, List[str]], Optional[JSONResponse]]: """ Generate a text completion. @@ -470,11 +469,16 @@ async def v1_chat_generate_completion( return None, create_error_response( status_code=HTTPStatus.BAD_REQUEST, message=str(e), param=None ) - return ret["text"], None + if ( + type(ret) == list + ): # if n > 1 (multiple samples), we return a list of strings + return [item["text"] for item in ret], None + else: + return ret["text"], None def v1_chat_generate_response( - output_text: str, params: ChatCompletionParams + output_text: Union[str, List[str]], params: ChatCompletionParams ) -> ChatCompletionResponse: """ Generate a ChatCompletionResponse from the output text and parameters. @@ -490,44 +494,49 @@ def v1_chat_generate_response( ChatCompletionResponse: An OpenAI-compatible response containing the assistant's message, usage information, and other metadata. 
""" + output_texts = output_text if type(output_text) == list else [output_text] + choices = [] + prompt_tokens, completion_tokens = 0, 0 # Parse the output text using the specific prompt template - chat_mess = params.prompt_template.parse_assistant_response( - llm_output=output_text, tool_choice=params.tool_func_choice - ) - # Convert tool_calls to function_call if request.functions is provided - chat_mess = convert_tool_calls_to_function_call( - functions=params.request.functions, chat_message=chat_mess - ) + for text in output_texts: + chat_mess = params.prompt_template.parse_assistant_response( + llm_output=text, tool_choice=params.tool_func_choice + ) + # Convert tool_calls to function_call if request.functions is provided + chat_mess = convert_tool_calls_to_function_call( + functions=params.request.functions, chat_message=chat_mess + ) - # Postprocess finish reason - finish_reason = "stop" - if params.tool_func_choice is None or params.tool_func_choice in [ - "auto", - "required", - ]: - if "function_call" in chat_mess and chat_mess["function_call"]: - finish_reason = "function_call" - if "tool_calls" in chat_mess and chat_mess["tool_calls"]: - finish_reason = "tool_calls" - - choices = [ - ChatCompletionResponseChoice( - index=0, - message=ChatMessage(**chat_mess), - finish_reason=finish_reason, + # Postprocess finish reason + finish_reason = "stop" + if params.tool_func_choice is None or params.tool_func_choice in [ + "auto", + "required", + ]: + if "function_call" in chat_mess and chat_mess["function_call"]: + finish_reason = "function_call" + if "tool_calls" in chat_mess and chat_mess["tool_calls"]: + finish_reason = "tool_calls" + + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(**chat_mess), + finish_reason=finish_reason, + ) ) - ] - prompt_tokens = ( - len(params.adapted_request.input_ids) - if params.adapted_request.input_ids - else len(params.tokenizer.encode(params.adapted_request.text)) - ) - completion_tokens = ( 
- len(params.tokenizer.encode(output_text, add_special_tokens=False)) + 1 - ) # +1 for the eos token + + prompt_tokens += ( + len(params.adapted_request.input_ids) + if params.adapted_request.input_ids + else len(params.tokenizer.encode(params.adapted_request.text)) + ) + completion_tokens += ( + len(params.tokenizer.encode(text, add_special_tokens=False)) + 1 + ) # +1 for the eos token response = ChatCompletionResponse( - id=params.adapted_request.rid, + id=f"chatcmpl-{uuid.uuid4().hex}", model=params.request.model, choices=choices, usage=UsageInfo( diff --git a/functionary/sglang_monkey_patch/tokenizer_manager.py b/functionary/sglang_monkey_patch/tokenizer_manager.py index 25f06a13..c15aa49e 100644 --- a/functionary/sglang_monkey_patch/tokenizer_manager.py +++ b/functionary/sglang_monkey_patch/tokenizer_manager.py @@ -88,7 +88,14 @@ async def _wait_for_response( # if self.server_args.log_requests and state.finished: if state.finished: if obj.text is None and obj.input_ids is not None: - obj.text = self.tokenizer.decode(obj.input_ids) + if ( + type(obj.input_ids) == list + and len(obj.input_ids) > 0 + and type(obj.input_ids[0]) == list + ): # this is for multiple sampling + obj.text = self.tokenizer.decode(obj.input_ids[0]) + else: + obj.text = self.tokenizer.decode(obj.input_ids) obj.input_ids = None logger.info(dict(input=obj.__dict__, output=out)) diff --git a/pyproject.toml b/pyproject.toml index 1361c498..1afb3d45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,8 @@ sglang = [ "orjson==3.10.10", "sglang[all]==0.3.4.post1", "flashinfer==0.1.6", + "transformers==4.48.3", + "sgl-kernel>=0.0.4.post3" ] [project.urls] diff --git a/server_sglang.py b/server_sglang.py index 8ed3e75e..9481232b 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -59,9 +59,6 @@ ) from functionary.sglang_inference import v1_chat_completions -from functionary.sglang_monkey_patch.tokenizer_manager import ( - MonkeyPatchTokenizerManager, -) logger = 
logging.getLogger(__name__) @@ -255,6 +252,10 @@ def launch_engine(server_args: ServerArgs): # Launch tokenizer process if args.logfile is not None: + from functionary.sglang_monkey_patch.tokenizer_manager import ( + MonkeyPatchTokenizerManager, + ) + tokenizer_manager = MonkeyPatchTokenizerManager( server_args, port_args, args.logfile ) From 298500dcf732b525a0b4dc1cdbbcd6d2843e9808 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Tue, 18 Mar 2025 10:06:38 +0000 Subject: [PATCH 2/5] migrate sglang --- functionary/sglang_inference.py | 8 +- pyproject.toml | 4 +- server_sglang.py | 229 ++++++-------------------------- 3 files changed, 44 insertions(+), 197 deletions(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index fcb89a0a..89bcc968 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -23,18 +23,16 @@ import uuid from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union, Any import sglang as sgl from fastapi import Request from fastapi.responses import JSONResponse, StreamingResponse -from outlines.fsm.json_schema import build_regex_from_schema from sglang.lang.choices import greedy_token_selection from sglang.lang.interpreter import ProgramState from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api.protocol import ErrorResponse -from sglang.srt.server import Runtime from transformers import AutoTokenizer from functionary.inference_stream import generate_openai_format_from_stream_async @@ -76,7 +74,7 @@ class ChatCompletionParams: request: ChatCompletionRequest tokenizer: AutoTokenizer tokenizer_manager: Optional[TokenizerManager] - srt_backend: Optional[Runtime] + srt_backend: Any prompt_template: PromptTemplate tools_or_functions: List[Dict] tool_func_choice: Optional[Union[str, Tool, 
Function]] @@ -550,7 +548,7 @@ def v1_chat_generate_response( async def v1_chat_completions( tokenizer_manager: Optional[TokenizerManager], - srt_backend: Optional[Runtime], + srt_backend: Any, raw_request: Request, served_model: List[str], ): diff --git a/pyproject.toml b/pyproject.toml index 1afb3d45..ef2e0895 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,10 +23,10 @@ vllm = [ sglang = [ "python-multipart==0.0.12", "orjson==3.10.10", - "sglang[all]==0.3.4.post1", + "sglang[all]==0.4.4.post1", "flashinfer==0.1.6", "transformers==4.48.3", - "sgl-kernel>=0.0.4.post3" + "sgl-kernel>=0.0.5" ] [project.urls] diff --git a/server_sglang.py b/server_sglang.py index 9481232b..fb049ff7 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -25,7 +25,7 @@ import threading import time from http import HTTPStatus -from typing import AsyncIterator, Dict, List, Optional, Union +from typing import AsyncIterator, Dict, List, Optional, Union, Any, Callable import orjson @@ -43,13 +43,19 @@ from sglang.srt.managers.data_parallel_controller import ( run_data_parallel_controller_process, ) +from sglang.srt.entrypoints.http_server import ( + _launch_subprocesses, + set_uvicorn_logging_configs, + _wait_and_warmup, + enable_func_timer, + add_prometheus_middleware, +) from sglang.srt.managers.detokenizer_manager import run_detokenizer_process from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api from sglang.srt.openai_api.protocol import ModelCard, ModelList -from sglang.srt.server import Runtime, _set_envs_and_config, _wait_and_warmup from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( add_api_key_middleware, @@ -184,133 +190,56 @@ def available_models(): return ModelList(data=model_cards) -def launch_engine(server_args: ServerArgs): 
- """ - Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess. - """ - - global tokenizer_manager - - # Configure global environment - configure_logger(server_args) - server_args.check_server_args() - _set_envs_and_config(server_args) - - # Allocate ports for inter-process communications - port_args = PortArgs.init_new(server_args) - logger.info(f"{server_args=}") - - # If using model from www.modelscope.cn, first download the model. - server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( - server_args.model_path, server_args.tokenizer_path - ) - - if server_args.dp_size == 1: - # Launch tensor parallel scheduler processes - scheduler_procs = [] - scheduler_pipe_readers = [] - tp_size_per_node = server_args.tp_size // server_args.nnodes - tp_rank_range = range( - tp_size_per_node * server_args.node_rank, - tp_size_per_node * (server_args.node_rank + 1), - ) - for tp_rank in tp_rank_range: - reader, writer = mp.Pipe(duplex=False) - gpu_id = tp_rank % tp_size_per_node - proc = mp.Process( - target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, None, writer), - ) - proc.start() - scheduler_procs.append(proc) - scheduler_pipe_readers.append(reader) - - if server_args.node_rank >= 1: - # For other nodes, they do not need to run tokenizer or detokenizer, - # so they can just wait here. 
- while True: - pass - else: - # Launch the data parallel controller - reader, writer = mp.Pipe(duplex=False) - scheduler_pipe_readers = [reader] - proc = mp.Process( - target=run_data_parallel_controller_process, - args=(server_args, port_args, writer), - ) - proc.start() - - # Launch detokenizer process - detoken_proc = mp.Process( - target=run_detokenizer_process, - args=( - server_args, - port_args, - ), - ) - detoken_proc.start() - - # Launch tokenizer process - if args.logfile is not None: - from functionary.sglang_monkey_patch.tokenizer_manager import ( - MonkeyPatchTokenizerManager, - ) - - tokenizer_manager = MonkeyPatchTokenizerManager( - server_args, port_args, args.logfile - ) - else: - tokenizer_manager = TokenizerManager(server_args, port_args) - if server_args.chat_template: - load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) - - # Wait for model to finish loading - for i in range(len(scheduler_pipe_readers)): - scheduler_pipe_readers[i].recv() - - def launch_server( server_args: ServerArgs, pipe_finish_writer: Optional[mp.connection.Connection] = None, + launch_callback: Optional[Callable[[], None]] = None, ): """ - Launch SRT (SGLang Runtime) Server + Launch SRT (SGLang Runtime) Server. - The SRT server consists of an HTTP server and the SRT engine. + The SRT server consists of an HTTP server and an SRT engine. - 1. HTTP server: A FastAPI server that routes requests to the engine. - 2. SRT engine: - 1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler. + - HTTP server: A FastAPI server that routes requests to the engine. + - The engine consists of three components: + 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. - 3. 
Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. + 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. Note: - 1. The HTTP server and Tokenizer Manager both run in the main process. - 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. + 1. The HTTP server, Engine, and TokenizerManager all run in the main process. + 2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library. """ - - launch_engine(server_args=server_args) + global tokenizer_manager, scheduler_info + tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args) # Add api key authorization if server_args.api_key: add_api_key_middleware(app, server_args.api_key) - # Send a warmup request - t = threading.Thread( - target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid()) + # Add prometheus middleware + if server_args.enable_metrics: + add_prometheus_middleware(app) + enable_func_timer() + + # Send a warmup request - we will create the thread and launch it + # in the lifespan after all other warmups have fired.
+ warmup_thread = threading.Thread( + target=_wait_and_warmup, + args=( + server_args, + pipe_finish_writer, + "", + launch_callback, + ), ) - t.start() + app.warmup_thread = warmup_thread try: + # Update logging configs + set_uvicorn_logging_configs() + app.server_args = server_args # Listen for HTTP requests - LOGGING_CONFIG["formatters"]["default"][ - "fmt" - ] = "[%(asctime)s] %(levelprefix)s %(message)s" - LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S" - LOGGING_CONFIG["formatters"]["access"][ - "fmt" - ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s' - LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S" uvicorn.run( app, host=server_args.host, @@ -320,62 +249,7 @@ def launch_server( loop="uvloop", ) finally: - t.join() - - -class FunctionaryRuntime(Runtime): - """ - A wrapper for the server. - This is used for launching the server in a python program without - using the commond line interface. - """ - - def __init__( - self, - log_level: str = "error", - *args, - **kwargs, - ): - """See the arguments in server_args.py::ServerArgs""" - self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) - - # before python program terminates, call shutdown implicitly. 
Therefore, users don't have to explicitly call .shutdown() - atexit.register(self.shutdown) - - # Pre-allocate ports - for port in range(10000, 40000): - if is_port_available(port): - break - port += 1 - self.server_args.port = port - - self.url = self.server_args.url() - self.generate_url = self.url + "/generate" - - # NOTE: We store pid instead of proc to fix some issues during __delete__ - self.pid = None - pipe_reader, pipe_writer = mp.Pipe(duplex=False) - - proc = mp.Process( - target=launch_server, - args=(self.server_args, pipe_writer), - ) - proc.start() - pipe_writer.close() - self.pid = proc.pid - - try: - init_state = pipe_reader.recv() - except EOFError: - init_state = "" - - if init_state != "ready": - self.shutdown() - raise RuntimeError( - "Initialization failed. Please see the error messages above." - ) - - self.endpoint = RuntimeEndpoint(self.url) + warmup_thread.join() if __name__ == "__main__": @@ -386,13 +260,6 @@ def __init__( default=None, help="enable detailed request input/output logging by providing logfile", ) - # parser.add_argument( - # "--enable-grammar-sampling", - # dest="grammar_sampling", - # action="store_true", - # default=False, - # help="enable grammar sampling for function names", - # ) ServerArgs.add_cli_args(parser) args = parser.parse_args() @@ -403,21 +270,3 @@ def __init__( server_args = ServerArgs.from_cli_args(args) launch_server(server_args) - - # if args.grammar_sampling: - # backend = FunctionaryRuntime(**vars(server_args)) - # sgl.set_default_backend( - # sgl.RuntimeEndpoint( - # f"http://{backend.server_args.host}:{backend.server_args.port}" - # ) - # ) - # uvicorn.run( - # app, - # host=server_args.host, - # port=server_args.port, - # log_level=server_args.log_level_http or server_args.log_level, - # timeout_keep_alive=5, - # loop="uvloop", - # ) - # else: - # launch_server(server_args) From 013aa16e80509e88333da75bec1cf4b6c0736980 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Wed, 19 Mar 2025 17:07:19 +0000 
Subject: [PATCH 3/5] fix sglang server --- server_sglang.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/server_sglang.py b/server_sglang.py index fb049ff7..291ea8fa 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -49,6 +49,7 @@ _wait_and_warmup, enable_func_timer, add_prometheus_middleware, + lifespan, ) from sglang.srt.managers.detokenizer_manager import run_detokenizer_process from sglang.srt.managers.io_struct import GenerateReqInput @@ -63,7 +64,7 @@ is_port_available, prepare_model_and_tokenizer, ) - +from sglang.srt.utils import kill_process_tree from functionary.sglang_inference import v1_chat_completions logger = logging.getLogger(__name__) @@ -71,7 +72,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -app = FastAPI() +app = FastAPI(lifespan=lifespan) tokenizer_manager = None served_model = [] @@ -269,4 +270,7 @@ def launch_server( server_args = ServerArgs.from_cli_args(args) - launch_server(server_args) + try: + launch_server(server_args) + finally: + kill_process_tree(os.getpid(), include_parent=False) From 48e2b07bc7f7e3d1a8f8a4de0aff5f521514cb11 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Thu, 20 Mar 2025 08:07:03 +0000 Subject: [PATCH 4/5] remove functions for grammar sampling --- functionary/sglang_inference.py | 244 +++++--------------------------- 1 file changed, 39 insertions(+), 205 deletions(-) diff --git a/functionary/sglang_inference.py b/functionary/sglang_inference.py index 89bcc968..056f5ffc 100644 --- a/functionary/sglang_inference.py +++ b/functionary/sglang_inference.py @@ -32,7 +32,6 @@ from sglang.lang.interpreter import ProgramState from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager -from sglang.srt.openai_api.protocol import ErrorResponse from transformers import AutoTokenizer from functionary.inference_stream import generate_openai_format_from_stream_async @@ -162,124 +161,6 @@ def 
v1_chat_generate_request( return adapted_request, request -@sgl.function -def generate_sglang_srt_response( - s: ProgramState, - prompt: str, - prompt_template, - tools_or_functions, - tool_func_choice, - tokenizer, -): - """ - Generate a response using SGLang Frontend Runtime (SRT). - - This function is used when grammar-sampling is enabled. It uses the SRT program - state to update the specific prompt-template Finite State Machine (FSM) generation - state. Constrained generation is performed at specific stages of the FSM. - - Args: - s (ProgramState): The current program state in SGLang. - prompt (str): The input prompt to generate a response for. - prompt_template: The template used to structure the prompt and response. - tools_or_functions (list): Available tools or functions for the model to use. - tool_func_choice (str): The chosen tool or function choice. - tokenizer: The tokenizer used for encoding and decoding text. - - Returns: - ProgramState: The updated program state after generating the response. 
- """ - completion_tokens = 0 - stop_tokens = prompt_template.get_stop_tokens_for_generation() - function_call_token = prompt_template.get_start_of_function_call_token() - gen_state = prompt_template.initialize_fsm_gen_state( - tool_choice=tool_func_choice, - curr_text="", - curr_tokens=None, - add_code_interpreter=( - True - if any( - [ - "type" in tool_or_func - and tool_or_func["type"] == "code_interpreter" - for tool_or_func in tools_or_functions - ] - ) - else False - ), - ) - # Form the options for the following stages - tools = [] - for tool in tools_or_functions: - if "type" in tool: - if tool["type"] == "function": - tools.append(tool["function"]) - else: - tools.append(tool) - options = prompt_template.get_options_from_gen_state( - gen_state=gen_state, tools_or_functions=tools - ) - - def check_stop_condition(): - stop_match = s.get_meta_info(CONTENT_VAR)["finish_reason"]["matched"] - if not isinstance(stop_match, str): - stop_match = tokenizer.decode(stop_match) - return stop_match in stop_tokens - - s += prompt - while True: - if gen_state["stage"] == "function": - choices = [ - tool["function"]["name"] - for tool in tools_or_functions - if tool["type"] == "function" - ] - if gen_state["add_all_recipient"]: - choices.append("all") - if gen_state["add_code_interpreter"]: - choices.append("python") - s += sgl.select( - name=CONTENT_VAR, - choices=choices, - choices_method=CHOICES_SAMPLING_METHOD, - ) - new_token = s[CONTENT_VAR] - completion_tokens += len( - tokenizer.encode(s[CONTENT_VAR], add_special_tokens=False) - ) - elif gen_state["stage"] == "pre-parameter": - s += prompt_template.fn_param_sep_token - new_token = prompt_template.fn_param_sep_token - elif gen_state["stage"] == "parameter": - tool = next(t for t in tools if t["name"] == gen_state["func_name"]) - regex = build_regex_from_schema(json.dumps(tool["parameters"])) - s += sgl.gen(name=CONTENT_VAR, regex=regex, stop=function_call_token) - new_token = s[CONTENT_VAR] - completion_tokens += 
s.get_meta_info(CONTENT_VAR)["completion_tokens"] - # Generate new token to determin if there is another tool call - s += sgl.gen(name=CONTENT_VAR, stop=function_call_token) - if check_stop_condition(): - break - elif gen_state["stage"] in ["text-gen", "code-interpreter"]: - s += sgl.gen(name=CONTENT_VAR, stop=function_call_token) - completion_tokens += s.get_meta_info(CONTENT_VAR)["completion_tokens"] - if check_stop_condition(): - break - else: - s += function_call_token - new_token = s[CONTENT_VAR] + function_call_token - elif gen_state["stage"] == "pre-function": - s += function_call_token - new_token = function_call_token - gen_state = prompt_template.update_fsm_gen_state( - gen_state=gen_state, - new_token=new_token, - new_token_id=None, - options=options, - tokenizer=tokenizer, - ) - - async def wrap_sgl_generator(params: ChatCompletionParams): """ This asynchronous generator function yields generated text chunks along @@ -295,36 +176,23 @@ async def wrap_sgl_generator(params: ChatCompletionParams): - str: The generated text chunk. - Optional[str]: The finish reason, if any (e.g., "stop", "length", etc.). 
""" - if params.grammar_sampling: - prompt = ( - params.adapted_request.text - if params.adapted_request.text - else params.tokenizer.decode(params.adapted_request.input_ids) - ) - # Iterates over the text generated by the SGLang Frontend Runtime - for out in params.frontend_state.text_iter(): - if out.startswith(prompt): - continue - yield out, None - yield "", "stop" - else: - # Iterates over the text generated by the tokenizer manager - stream_buffer = "" - async for content in params.tokenizer_manager.generate_request( - params.adapted_request, params.raw_request - ): - text = content["text"] - delta = text[len(stream_buffer) :] - stream_buffer = stream_buffer + delta - finish_reason = content["meta_info"]["finish_reason"] - - # If finish_reason is not None and delta_text is not empty, - # the delta_text is the eos_token and just remove it - if finish_reason is not None: - finish_reason = finish_reason["type"] - if len(delta) > 0: - delta = "" - yield delta, finish_reason + # Iterates over the text generated by the tokenizer manager + stream_buffer = "" + async for content in params.tokenizer_manager.generate_request( + params.adapted_request, params.raw_request + ): + text = content["text"] + delta = text[len(stream_buffer) :] + stream_buffer = stream_buffer + delta + finish_reason = content["meta_info"]["finish_reason"] + + # If finish_reason is not None and delta_text is not empty, + # the delta_text is the eos_token and just remove it + if finish_reason is not None: + finish_reason = finish_reason["type"] + if len(delta) > 0: + delta = "" + yield delta, finish_reason async def completion_stream_generator(params: ChatCompletionParams): @@ -413,66 +281,32 @@ async def v1_chat_generate_completion( - Streaming responses are handled by the completion_stream_generator function. 
""" # If streaming, return the StreamingResponse else return the text - if params.grammar_sampling: - # Form the text prompt and run the SGLang Frontend Runtime - prompt = ( - params.adapted_request.text - if params.adapted_request.text - else params.tokenizer.decode(params.adapted_request.input_ids) - ) - state = generate_sglang_srt_response.run( - prompt=prompt, - prompt_template=params.prompt_template, - tools_or_functions=params.tools_or_functions, - tool_func_choice=params.tool_func_choice, - tokenizer=params.tokenizer, - max_new_tokens=params.request.max_tokens, - temperature=params.request.temperature, - top_p=params.request.top_p, - top_k=params.request.top_k, - frequency_penalty=params.request.frequency_penalty, - presence_penalty=params.request.presence_penalty, - stream=params.request.stream, - ) - - if params.adapted_request.stream: - params.frontend_state = state - return ( - StreamingResponse( - completion_stream_generator(params), - media_type="text/event-stream", + if params.adapted_request.stream: + return ( + StreamingResponse( + completion_stream_generator(params), + media_type="text/event-stream", + background=params.tokenizer_manager.create_abort_task( + params.adapted_request ), - None, - ) - else: - return state.text()[len(prompt) :], None + ), + None, + ) else: - if params.adapted_request.stream: - return ( - StreamingResponse( - completion_stream_generator(params), - media_type="text/event-stream", - background=params.tokenizer_manager.create_abort_task( - params.adapted_request - ), - ), - None, + try: + ret = await params.tokenizer_manager.generate_request( + params.adapted_request, params.raw_request + ).__anext__() + except ValueError as e: + return None, create_error_response( + status_code=HTTPStatus.BAD_REQUEST, message=str(e), param=None ) + if ( + type(ret) == list + ): # if n > 1 (multiple samples), we return a list of strings + return [item["text"] for item in ret], None else: - try: - ret = await 
params.tokenizer_manager.generate_request( - params.adapted_request, params.raw_request - ).__anext__() - except ValueError as e: - return None, create_error_response( - status_code=HTTPStatus.BAD_REQUEST, message=str(e), param=None - ) - if ( - type(ret) == list - ): # if n > 1 (multiple samples), we return a list of strings - return [item["text"] for item in ret], None - else: - return ret["text"], None + return ret["text"], None def v1_chat_generate_response( From b2bd49b7ff874d99e5b56657654934409baaa480 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Thu, 20 Mar 2025 10:06:29 +0000 Subject: [PATCH 5/5] update requirements --- README.md | 2 +- pyproject.toml | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 7c53ca18..06b87b8a 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ pip install -e .[vllm] ``` **SGLang** ```shell -pip install -e .[sglang] --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/ +pip install -e .[sglang] --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python ``` ### Running the server diff --git a/pyproject.toml b/pyproject.toml index ef2e0895..83067ff5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,9 +24,7 @@ sglang = [ "python-multipart==0.0.12", "orjson==3.10.10", "sglang[all]==0.4.4.post1", - "flashinfer==0.1.6", - "transformers==4.48.3", - "sgl-kernel>=0.0.5" + "transformers==4.48.3" ] [project.urls]