From e2d599791cb670b23a4e09b4a4555437ce4d4818 Mon Sep 17 00:00:00 2001 From: Collier King Date: Sun, 21 Jun 2026 17:35:13 -0500 Subject: [PATCH 1/2] add glm 5.2 and openai endpoint mode --- CHANGELOG.md | 9 + libs/langchain-cloudflare/README.md | 25 +++ .../examples/workers/src/entry.py | 1 + .../langchain_cloudflare/chat_models.py | 138 +++++++++--- libs/langchain-cloudflare/pyproject.toml | 2 +- .../test_worker_integration.py | 2 + .../test_workersai_models.py | 196 +++++++++++++++++- .../tests/unit_tests/test_chat_models.py | 140 +++++++++++++ 8 files changed, 472 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 38bad91..447b37c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## langchain-cloudflare +### [0.3.5] + +#### Added + +- **`@cf/zai-org/glm-5.2`**: Adds GLM-5.2 to the tested and example-supported model lists with reasoning content, tool calling, structured output, and model-specific parameter handling for Workers AI's OpenAI-compatible schema. +- **OpenAI-compatible endpoint mode**: Adds `endpoint_format="openai_compatible"` to `ChatCloudflareWorkersAI` for REST chat completions routing through `/ai/v1/chat/completions` or AI Gateway `/workers-ai/v1/chat/completions`, with focused text and structured vision regression coverage. + +--- + ### [0.3.4] #### Changed diff --git a/libs/langchain-cloudflare/README.md b/libs/langchain-cloudflare/README.md index 4a0c69a..9b11f5e 100644 --- a/libs/langchain-cloudflare/README.md +++ b/libs/langchain-cloudflare/README.md @@ -34,6 +34,31 @@ llm = ChatCloudflareWorkersAI() llm.invoke("Sing a ballad of LangChain.") ``` +### REST endpoint format + +By default, `ChatCloudflareWorkersAI` uses the native Workers AI run endpoint: + +```python +llm = ChatCloudflareWorkersAI( + model="@cf/moonshotai/kimi-k2.6", + endpoint_format="workers_ai", # default +) +``` + +For REST calls that need Cloudflare's OpenAI-compatible chat completions API, +set `endpoint_format="openai_compatible"`: + +```python +llm = ChatCloudflareWorkersAI( + model="@cf/moonshotai/kimi-k2.6", + endpoint_format="openai_compatible", +) +``` + +When `ai_gateway` is configured, OpenAI-compatible mode routes through the +Workers AI chat completions path on AI Gateway. This option is REST-only; Worker +bindings use `env.AI.run()` and do not expose a chat completions route. + ## Embeddings `CloudflareWorkersAIEmbeddings` class exposes embeddings from [CloudflareWorkersAI](https://developers.cloudflare.com/workers-ai/). diff --git a/libs/langchain-cloudflare/examples/workers/src/entry.py b/libs/langchain-cloudflare/examples/workers/src/entry.py index e580fbf..b0b16bb 100644 --- a/libs/langchain-cloudflare/examples/workers/src/entry.py +++ b/libs/langchain-cloudflare/examples/workers/src/entry.py @@ -33,6 +33,7 @@ "@cf/mistralai/mistral-small-3.1-24b-instruct", "@cf/qwen/qwen3-30b-a3b-fp8", "@cf/zai-org/glm-4.7-flash", + "@cf/zai-org/glm-5.2", "@cf/openai/gpt-oss-120b", "@cf/openai/gpt-oss-20b", "@cf/nvidia/nemotron-3-120b-a12b", diff --git a/libs/langchain-cloudflare/langchain_cloudflare/chat_models.py b/libs/langchain-cloudflare/langchain_cloudflare/chat_models.py index beee2b6..304fbc1 100644 --- a/libs/langchain-cloudflare/langchain_cloudflare/chat_models.py +++ b/libs/langchain-cloudflare/langchain_cloudflare/chat_models.py @@ -218,6 +218,11 @@ def _normalize_response_format_for_cloudflare( ) MODEL_BEHAVIORS: Dict[str, ModelBehavior] = { + "glm-5.2": ModelBehavior( + embed_tool_calls_in_content=False, + unsupported_params=("top_k", "repetition_penalty"), + supports_reasoning_content=True, + ), "glm": ModelBehavior( embed_tool_calls_in_content=False, unsupported_params=("max_tokens", "top_k", "repetition_penalty", "tool_choice"), @@ -369,6 +374,9 @@ async def fetch(self, request): base_url: Optional[str] Base URL path for API requests, leave blank if not using a proxy or service emulator. + endpoint_format: Literal["workers_ai", "openai_compatible"] + REST endpoint format. Defaults to native Workers AI run endpoints. + Set to ``"openai_compatible"`` to use chat completions endpoints. binding: Optional[Any] Workers AI binding (env.AI) for use in Python Workers. When provided, uses the binding instead of REST API calls. @@ -414,6 +422,16 @@ async def fetch(self, request): alias="cloudflare_ai_gateway", default_factory=from_env("AI_GATEWAY", default=None), ) + endpoint_format: Literal["workers_ai", "openai_compatible"] = "workers_ai" + """REST endpoint format to use. + + ``"workers_ai"`` uses Workers AI native run endpoints: + ``/ai/run/{model}`` or AI Gateway ``/workers-ai/run/{model}``. + ``"openai_compatible"`` uses chat completions endpoints and includes + ``model`` in the JSON payload: + ``/ai/v1/chat/completions`` or AI Gateway + ``/workers-ai/v1/chat/completions``. + """ request_timeout: Union[float, Tuple[float, float], Any, None] = Field( default=None, alias="timeout" ) @@ -495,6 +513,12 @@ def validate_environment(self) -> Self: # If binding is provided, skip REST API setup if self.binding is not None: + if self.endpoint_format != "workers_ai": + raise ValueError( + "endpoint_format='openai_compatible' is only supported for " + "REST API calls. Workers AI bindings use env.AI.run() and " + "do not expose a chat completions endpoint." + ) # When using binding, we don't need api_token or account_id return self @@ -813,6 +837,54 @@ def _translate_params_for_model(self, params: Dict[str, Any]) -> Dict[str, Any]: return params + def _get_api_url(self) -> str: + """Return the REST API path for the configured endpoint format.""" + if self.endpoint_format == "openai_compatible": + if self.ai_gateway: + return "workers-ai/v1/chat/completions" + return f"accounts/{self.account_id}/ai/v1/chat/completions" + + if self.ai_gateway: + return f"workers-ai/run/{self.model}" + return f"accounts/{self.account_id}/ai/run/{self.model}" + + def _create_request_payload( + self, + message_dicts: List[Dict[str, Any]], + params: Dict[str, Any], + ) -> Dict[str, Any]: + """Build the REST request payload for the configured endpoint format.""" + if self.endpoint_format == "openai_compatible": + return { + **params, + "model": self.model, + "messages": message_dicts, + } + + return {"messages": message_dicts, **params} + + def _create_openai_stream_chunk( + self, + chunk: Dict[str, Any], + ) -> Optional[ChatGenerationChunk]: + """Convert an OpenAI-compatible SSE chunk to a LangChain chunk.""" + choices = chunk.get("choices") or [] + if not choices: + return None + + choice = choices[0] + delta = choice.get("delta") or choice.get("message") or {} + response_text = _normalize_message_content(delta.get("content", "")) + + generation_info = {} + if "usage" in chunk: + generation_info["usage"] = chunk["usage"] + + return ChatGenerationChunk( + message=AIMessageChunk(content=response_text), + generation_info=generation_info or None, + ) + # MARK: - Generate def _generate( # type: ignore self, @@ -830,7 +902,7 @@ def _generate( # type: ignore params = self._translate_params_for_model(params) # Create the request payload - payload = {"messages": message_dicts, **params} + payload = self._create_request_payload(message_dicts, params) # Use binding if available (for Python Workers) if self.binding is not None: @@ -846,13 +918,7 @@ def _generate( # type: ignore response_data = loop.run_until_complete(self._call_binding(payload)) else: # Use REST API (httpx client) - # Construct the API URL - if self.ai_gateway: - # If using AI Gateway - api_url = f"workers-ai/run/{self.model}" - else: - # If using direct API - api_url = f"accounts/{self.account_id}/ai/run/{self.model}" + api_url = self._get_api_url() # Make the API request response = self.client.post(api_url, json=payload) @@ -875,18 +941,14 @@ async def _agenerate( # type: ignore params = self._translate_params_for_model(params) # Create the request payload - payload = {"messages": message_dicts, **params} + payload = self._create_request_payload(message_dicts, params) # Use binding if available (for Python Workers) if self.binding is not None: response_data = await self._call_binding(payload) else: # Use REST API (httpx async client) - # Construct the Cloudflare Workers AI API URL - if self.ai_gateway: - api_url = f"workers-ai/run/{self.model}" - else: - api_url = f"accounts/{self.account_id}/ai/run/{self.model}" + api_url = self._get_api_url() # Make the API request response = await self.async_client.post(api_url, json=payload) @@ -916,14 +978,9 @@ def _stream( params = {**params, **kwargs, "stream": True} params = self._translate_params_for_model(params) - # Construct the Cloudflare Workers AI API URL - if self.ai_gateway: - api_url = f"workers-ai/run/{self.model}" - else: - api_url = f"accounts/{self.account_id}/ai/run/{self.model}" - # Create the request payload - payload = {"messages": message_dicts, **params} + api_url = self._get_api_url() + payload = self._create_request_payload(message_dicts, params) # Make the streaming API request with self.client.stream("POST", api_url, json=payload) as response: @@ -950,8 +1007,21 @@ def _stream( except json.JSONDecodeError: continue + if "choices" in chunk: + generation_chunk = self._create_openai_stream_chunk(chunk) + if generation_chunk is None: + continue + + if run_manager: + run_manager.on_llm_new_token( + generation_chunk.text, + chunk=generation_chunk, + ) + + yield generation_chunk + # Process streaming response - if "streamed_data" in chunk: + elif "streamed_data" in chunk: # Handle streamed_data format from Cloudflare for stream_chunk in chunk["streamed_data"]: response_text = stream_chunk.get("response", "") @@ -1080,14 +1150,9 @@ async def _astream( params = {**params, **kwargs, "stream": True} params = self._translate_params_for_model(params) - # Construct the Cloudflare Workers AI API URL - if self.ai_gateway: - api_url = f"workers-ai/run/{self.model}" - else: - api_url = f"accounts/{self.account_id}/ai/run/{self.model}" - # Create the request payload - payload = {"messages": message_dicts, **params} + api_url = self._get_api_url() + payload = self._create_request_payload(message_dicts, params) # Make the streaming API request async with self.async_client.stream("POST", api_url, json=payload) as response: @@ -1114,8 +1179,21 @@ async def _astream( except json.JSONDecodeError: continue + if "choices" in chunk: + generation_chunk = self._create_openai_stream_chunk(chunk) + if generation_chunk is None: + continue + + if run_manager: + await run_manager.on_llm_new_token( + token=generation_chunk.text, + chunk=generation_chunk, + ) + + yield generation_chunk + # Handle the streamed_data format - if "streamed_data" in chunk: + elif "streamed_data" in chunk: for stream_chunk in chunk["streamed_data"]: response_text = stream_chunk.get("response", "") accumulated_content += response_text diff --git a/libs/langchain-cloudflare/pyproject.toml b/libs/langchain-cloudflare/pyproject.toml index 59383f5..dd7b829 100644 --- a/libs/langchain-cloudflare/pyproject.toml +++ b/libs/langchain-cloudflare/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "langchain-cloudflare" -version = "0.3.4" +version = "0.3.5" description = "Langchain Integrations for Cloudflare's WorkersAI and Vectorize" readme = "README.md" license = "MIT" diff --git a/libs/langchain-cloudflare/tests/integration_tests/test_worker_integration.py b/libs/langchain-cloudflare/tests/integration_tests/test_worker_integration.py index 5cb8b3d..fff7287 100644 --- a/libs/langchain-cloudflare/tests/integration_tests/test_worker_integration.py +++ b/libs/langchain-cloudflare/tests/integration_tests/test_worker_integration.py @@ -31,6 +31,7 @@ "@cf/mistralai/mistral-small-3.1-24b-instruct", "@cf/qwen/qwen3-30b-a3b-fp8", "@cf/zai-org/glm-4.7-flash", + "@cf/zai-org/glm-5.2", "@cf/openai/gpt-oss-120b", "@cf/openai/gpt-oss-20b", "@cf/nvidia/nemotron-3-120b-a12b", @@ -906,6 +907,7 @@ class TestWorkerReasoningContent: REASONING_MODELS = [ "@cf/qwen/qwen3-30b-a3b-fp8", "@cf/zai-org/glm-4.7-flash", + "@cf/zai-org/glm-5.2", "@cf/openai/gpt-oss-120b", "@cf/openai/gpt-oss-20b", "@cf/moonshotai/kimi-k2.5", diff --git a/libs/langchain-cloudflare/tests/integration_tests/test_workersai_models.py b/libs/langchain-cloudflare/tests/integration_tests/test_workersai_models.py index bb62b68..84df06a 100644 --- a/libs/langchain-cloudflare/tests/integration_tests/test_workersai_models.py +++ b/libs/langchain-cloudflare/tests/integration_tests/test_workersai_models.py @@ -60,6 +60,7 @@ "@cf/mistralai/mistral-small-3.1-24b-instruct", "@cf/qwen/qwen3-30b-a3b-fp8", "@cf/zai-org/glm-4.7-flash", + "@cf/zai-org/glm-5.2", "@cf/openai/gpt-oss-120b", "@cf/openai/gpt-oss-20b", "@cf/nvidia/nemotron-3-120b-a12b", @@ -80,6 +81,12 @@ "@cf/google/gemma-4-26b-a4b-it", ] +# Focused models for endpoint-format parity tests. Keep this small so adding +# OpenAI-compatible coverage does not double the full integration matrix. +OPENAI_COMPAT_MODELS = [ + "@cf/moonshotai/kimi-k2.6", +] + # Pydantic schema for structured output class Entity(BaseModel): @@ -110,6 +117,14 @@ class Data(BaseModel): announcements: List[Announcement] = Field(default_factory=list) +class ImageExtraction(BaseModel): + """Fields extracted from a labeled test image.""" + + ticker: str = Field(description="Ticker symbol shown in the image") + timeframe: str = Field(description="Timeframe shown in the image") + date: str = Field(description="Date shown in the image") + + # Tool for tool calling tests @tool def get_weather(city: str) -> str: @@ -140,7 +155,11 @@ def ai_gateway(): def create_llm( - model: str, account_id: str, api_token: str, ai_gateway: Optional[str] = None + model: str, + account_id: str, + api_token: str, + ai_gateway: Optional[str] = None, + endpoint_format: str = "workers_ai", ): """Create a ChatCloudflareWorkersAI instance.""" return ChatCloudflareWorkersAI( @@ -149,6 +168,7 @@ def create_llm( model=model, temperature=0.0, ai_gateway=ai_gateway, + endpoint_format=endpoint_format, ) @@ -747,6 +767,7 @@ class TestReasoningContent: REASONING_MODELS = [ "@cf/qwen/qwen3-30b-a3b-fp8", "@cf/zai-org/glm-4.7-flash", + "@cf/zai-org/glm-5.2", "@cf/openai/gpt-oss-120b", "@cf/openai/gpt-oss-20b", "@cf/moonshotai/kimi-k2.5", @@ -893,6 +914,19 @@ def test_no_reasoning_content_for_llama(self, account_id, api_token, ai_gateway) # MARK: - Multi-Modal Tests +def _png_chunk(chunk_type: bytes, data: bytes) -> bytes: + """Create a PNG chunk.""" + import struct + import zlib + + chunk = chunk_type + data + return ( + struct.pack(">I", len(data)) + + chunk + + struct.pack(">I", zlib.crc32(chunk) & 0xFFFFFFFF) + ) + + def create_test_image_base64() -> str: """Create a minimal 1x1 red pixel PNG and return as base64. @@ -901,14 +935,6 @@ def create_test_image_base64() -> str: import struct import zlib - def _png_chunk(chunk_type: bytes, data: bytes) -> bytes: - chunk = chunk_type + data - return ( - struct.pack(">I", len(data)) - + chunk - + struct.pack(">I", zlib.crc32(chunk) & 0xFFFFFFFF) - ) - width, height = 1, 1 ihdr_data = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0) raw_row = b"\x00" + b"\xff\x00\x00" @@ -922,11 +948,83 @@ def _png_chunk(chunk_type: bytes, data: bytes) -> bytes: return base64.standard_b64encode(png).decode("utf-8") +_BLOCK_FONT = { + "A": ["01110", "10001", "10001", "11111", "10001", "10001", "10001"], + "C": ["01111", "10000", "10000", "10000", "10000", "10000", "01111"], + "D": ["11110", "10001", "10001", "10001", "10001", "10001", "11110"], + "E": ["11111", "10000", "10000", "11110", "10000", "10000", "11111"], + "F": ["11111", "10000", "10000", "11110", "10000", "10000", "10000"], + "I": ["11111", "00100", "00100", "00100", "00100", "00100", "11111"], + "K": ["10001", "10010", "10100", "11000", "10100", "10010", "10001"], + "M": ["10001", "11011", "10101", "10101", "10001", "10001", "10001"], + "Q": ["01110", "10001", "10001", "10001", "10101", "10010", "01101"], + "R": ["11110", "10001", "10001", "11110", "10100", "10010", "10001"], + "T": ["11111", "00100", "00100", "00100", "00100", "00100", "00100"], + "0": ["01110", "10001", "10011", "10101", "11001", "10001", "01110"], + "1": ["00100", "01100", "00100", "00100", "00100", "00100", "01110"], + "2": ["01110", "10001", "00001", "00010", "00100", "01000", "11111"], + "4": ["00010", "00110", "01010", "10010", "11111", "00010", "00010"], + "6": ["01110", "10000", "10000", "11110", "10001", "10001", "01110"], + ":": ["00000", "00100", "00100", "00000", "00100", "00100", "00000"], + "-": ["00000", "00000", "00000", "11111", "00000", "00000", "00000"], + " ": ["00000", "00000", "00000", "00000", "00000", "00000", "00000"], +} + + +def _draw_block_text( + pixels: list, + width: int, + x: int, + y: int, + text: str, + scale: int = 7, +) -> None: + """Draw uppercase block-font text into an RGB pixel buffer.""" + height = len(pixels) // width + for char in text: + glyph = _BLOCK_FONT.get(char, _BLOCK_FONT[" "]) + for row, pattern in enumerate(glyph): + for col, bit in enumerate(pattern): + if bit != "1": + continue + for dy in range(scale): + for dx in range(scale): + px = x + col * scale + dx + py = y + row * scale + dy + if 0 <= px < width and 0 <= py < height: + pixels[py * width + px] = (0, 0, 0) + x += 6 * scale + + +def create_labeled_test_image_base64() -> str: + """Create a PNG with large text labels for vision extraction tests.""" + import struct + import zlib + + width, height = 700, 260 + pixels = [(255, 255, 255)] * (width * height) + _draw_block_text(pixels, width, 30, 35, "TICKER: ACME") + _draw_block_text(pixels, width, 30, 115, "TIMEFRAME: Q4") + _draw_block_text(pixels, width, 30, 195, "DATE: 2026-06-21") + + raw = b"".join( + b"\x00" + + b"".join(bytes(pixel) for pixel in pixels[y * width : (y + 1) * width]) + for y in range(height) + ) + png = b"\x89PNG\r\n\x1a\n" + png += _png_chunk(b"IHDR", struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)) + png += _png_chunk(b"IDAT", zlib.compress(raw)) + png += _png_chunk(b"IEND", b"") + + return base64.standard_b64encode(png).decode("utf-8") + + class TestMultiModal: """Test multi-modal image input across Workers AI models via REST API. Discovery test: Which Workers AI models accept image content blocks - when invoked via the REST API (/v1/chat/completions)? + when invoked via the native Workers AI REST API (/ai/run)? """ @pytest.mark.parametrize("model", MODELS) @@ -1016,6 +1114,84 @@ def test_vision_invoke(self, model, account_id, api_token, ai_gateway): assert len(text) > 0, f"Expected non-empty vision response from {model}" +# MARK: - OpenAI-Compatible Endpoint Tests + + +class TestOpenAICompatibleEndpoint: + """Focused coverage for endpoint_format='openai_compatible'.""" + + @pytest.mark.parametrize("model", OPENAI_COMPAT_MODELS) + @pytest.mark.parametrize("endpoint_format", ["workers_ai", "openai_compatible"]) + def test_basic_invoke_endpoint_formats( + self, model, endpoint_format, account_id, api_token, ai_gateway + ): + """Basic text invoke should work through both REST endpoint formats.""" + if not account_id or not api_token: + pytest.skip("Missing CF_ACCOUNT_ID or CF_AI_API_TOKEN") + + llm = create_llm( + model, + account_id, + api_token, + ai_gateway, + endpoint_format=endpoint_format, + ) + result = llm.invoke("Say hello in exactly one word.") + text = get_text_content(result.content) + + print(f"\n[{model}] endpoint_format={endpoint_format}:") + print(f" Response: {text[:200]}") + + assert len(text) > 0, ( + f"Expected non-empty response from {model} via {endpoint_format}" + ) + + def test_openai_compatible_structured_vision_extracts_labeled_fields( + self, account_id, api_token, ai_gateway + ): + """OpenAI-compatible mode should extract structured fields from image text.""" + if not account_id or not api_token: + pytest.skip("Missing CF_ACCOUNT_ID or CF_AI_API_TOKEN") + + model = "@cf/moonshotai/kimi-k2.6" + llm = create_llm( + model, + account_id, + api_token, + ai_gateway, + endpoint_format="openai_compatible", + ) + structured_llm = llm.with_structured_output( + ImageExtraction, + method="json_schema", + ) + image_b64 = create_labeled_test_image_base64() + message = HumanMessage( + content=[ + { + "type": "text", + "text": ( + "Extract the ticker, timeframe, and date from this image." + ), + }, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_b64}"}, + }, + ] + ) + + result = structured_llm.invoke([message]) + + print("\n[openai_compatible] Structured vision extraction:") + print(f" Result: {result}") + + assert isinstance(result, ImageExtraction) + assert result.ticker == "ACME" + assert result.timeframe == "Q4" + assert result.date == "2026-06-21" + + if __name__ == "__main__": # Run with: python -m pytest test_workersai_models.py -v -s # Or directly: python test_workersai_models.py diff --git a/libs/langchain-cloudflare/tests/unit_tests/test_chat_models.py b/libs/langchain-cloudflare/tests/unit_tests/test_chat_models.py index eca86ac..5e596a4 100644 --- a/libs/langchain-cloudflare/tests/unit_tests/test_chat_models.py +++ b/libs/langchain-cloudflare/tests/unit_tests/test_chat_models.py @@ -463,6 +463,25 @@ def test_glm_unsupported_params_removed(self): assert "tool_choice" not in translated assert translated["temperature"] == 0.7 + def test_glm_5_2_preserves_supported_params(self): + """GLM-5.2 should keep parameters supported by its OpenAI schema.""" + llm = self._create_llm("@cf/zai-org/glm-5.2") + params = { + "max_tokens": 100, + "top_k": 50, + "repetition_penalty": 1.1, + "tool_choice": "required", + "temperature": 0.7, + } + + translated = llm._translate_params_for_model(params) + + assert translated["max_tokens"] == 100 + assert "top_k" not in translated + assert "repetition_penalty" not in translated + assert translated["tool_choice"] == "required" + assert translated["temperature"] == 0.7 + # MARK: - GPT-OSS Model Tests @@ -763,6 +782,127 @@ def test_session_id_with_aig_headers(self): assert llm.client.headers["cf-aig-request-timeout"] == "5000" +# MARK: - Endpoint Format Tests +class TestEndpointFormat: + """Tests for native Workers AI vs OpenAI-compatible endpoint routing.""" + + def test_workers_ai_endpoint_format_uses_native_run_url_and_payload(self): + """Default endpoint format should preserve existing native run behavior.""" + llm = ChatCloudflareWorkersAI( + account_id="test_account", + api_token="test_token", + model="@cf/meta/llama-3.3-70b-instruct-fp8-fast", + ) + + messages, params = llm._create_message_dicts( + [HumanMessage(content="Hello")], + stop=None, + ) + payload = llm._create_request_payload(messages, params) + + assert llm._get_api_url() == ( + "accounts/test_account/ai/run/@cf/meta/llama-3.3-70b-instruct-fp8-fast" + ) + assert "model" not in payload + assert payload["messages"] == [{"role": "user", "content": "Hello"}] + + def test_openai_compatible_endpoint_format_uses_chat_completions_payload(self): + """OpenAI-compatible format should route to chat completions.""" + llm = ChatCloudflareWorkersAI( + account_id="test_account", + api_token="test_token", + model="@cf/moonshotai/kimi-k2.6", + endpoint_format="openai_compatible", + ) + + messages, params = llm._create_message_dicts( + [HumanMessage(content="Hello")], + stop=None, + ) + payload = llm._create_request_payload(messages, params) + + assert llm._get_api_url() == ("accounts/test_account/ai/v1/chat/completions") + assert payload["model"] == "@cf/moonshotai/kimi-k2.6" + assert payload["messages"] == [{"role": "user", "content": "Hello"}] + + def test_openai_compatible_endpoint_format_uses_gateway_chat_completions(self): + """AI Gateway should route OpenAI-compatible requests through Workers AI.""" + llm = ChatCloudflareWorkersAI( + account_id="test_account", + api_token="test_token", + model="@cf/moonshotai/kimi-k2.6", + ai_gateway="my-gateway", + endpoint_format="openai_compatible", + ) + + assert str(llm.client.base_url) == ( + "https://gateway.ai.cloudflare.com/v1/test_account/my-gateway/" + ) + assert llm._get_api_url() == "workers-ai/v1/chat/completions" + + def test_openai_compatible_endpoint_format_rejects_binding(self): + """Bindings use env.AI.run() and cannot select chat completions.""" + with pytest.raises(ValueError, match="openai_compatible"): + ChatCloudflareWorkersAI( + model="@cf/moonshotai/kimi-k2.6", + binding=object(), + endpoint_format="openai_compatible", + ) + + def test_create_chat_result_accepts_top_level_openai_response(self): + """OpenAI-compatible responses can arrive without a result wrapper.""" + llm = ChatCloudflareWorkersAI( + account_id="test_account", + api_token="test_token", + model="@cf/moonshotai/kimi-k2.6", + endpoint_format="openai_compatible", + ) + response = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Done", + "reasoning_content": "Short reasoning", + } + } + ], + "usage": { + "prompt_tokens": 3, + "completion_tokens": 4, + "total_tokens": 7, + }, + } + + result = llm._create_chat_result(response) + message = result.generations[0].message + + assert isinstance(message.content, list) + assert message.content[0]["type"] == "thinking" + assert message.content[1] == {"type": "text", "text": "Done"} + assert message.usage_metadata == { + "input_tokens": 3, + "output_tokens": 4, + "total_tokens": 7, + } + + def test_openai_compatible_stream_chunk_parsing(self): + """OpenAI-compatible streaming deltas should become message chunks.""" + llm = ChatCloudflareWorkersAI( + account_id="test_account", + api_token="test_token", + model="@cf/meta/llama-3.3-70b-instruct-fp8-fast", + endpoint_format="openai_compatible", + ) + + chunk = llm._create_openai_stream_chunk( + {"choices": [{"delta": {"content": "Hello"}}]} + ) + + assert chunk is not None + assert chunk.message.content == "Hello" + + # MARK: - with_structured_output Routing Tests From 2d6a11437187395541e64548fb4fa2f6a7658213 Mon Sep 17 00:00:00 2001 From: Collier King Date: Sun, 21 Jun 2026 17:48:01 -0500 Subject: [PATCH 2/2] fix chat model ci conformance --- .../langchain_cloudflare/chat_models.py | 14 +++++++------- .../tests/unit_tests/test_chat_models.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/libs/langchain-cloudflare/langchain_cloudflare/chat_models.py b/libs/langchain-cloudflare/langchain_cloudflare/chat_models.py index 304fbc1..7e683d9 100644 --- a/libs/langchain-cloudflare/langchain_cloudflare/chat_models.py +++ b/libs/langchain-cloudflare/langchain_cloudflare/chat_models.py @@ -629,7 +629,7 @@ def _get_ls_params( params = self._get_invocation_params(stop=stop, **kwargs) ls_params = LangSmithParams( ls_provider="cloudflare-workers-ai", - ls_model_name=self.model, + ls_model_name=params.get("model", self.model), ls_model_type="chat", ls_temperature=params.get("temperature", self.temperature), ) @@ -1904,14 +1904,14 @@ def _inject_schema_message( return [schema_system_msg] + list(messages) return messages - llm = self.bind( # type: ignore[assignment] + schema_bound_llm = self.bind( response_format={"type": "json_object"}, ls_structured_output_format={ "kwargs": {"method": "json_mode"}, "schema": schema, }, ) - pipeline = RunnableLambda(_inject_schema_message) | llm # type: ignore[arg-type] + pipeline = RunnableLambda(_inject_schema_message) | schema_bound_llm # type: ignore[arg-type] if include_raw: parser_assign = RunnablePassthrough.assign( @@ -1936,7 +1936,7 @@ def _inject_schema_message( formatted_tool = convert_to_openai_tool(schema) tool_name = formatted_tool["function"]["name"] - llm = self.bind_tools( + pipeline = self.bind_tools( [schema], ls_structured_output_format={ "kwargs": {"method": "function_calling"}, @@ -1954,7 +1954,7 @@ def _inject_schema_message( ) elif method == "json_mode": - llm = self.bind( # type: ignore[assignment] + pipeline = self.bind( response_format={"type": "json_object"}, ls_structured_output_format={ "kwargs": {"method": "json_mode"}, @@ -1985,9 +1985,9 @@ def _inject_schema_message( parser_with_fallback = parser_assign.with_fallbacks( [parser_none], exception_key="parsing_error" ) - return RunnableMap(raw=llm) | parser_with_fallback + return RunnableMap(raw=pipeline) | parser_with_fallback else: - return llm | output_parser + return pipeline | output_parser def _is_pydantic_class(obj: Any) -> bool: diff --git a/libs/langchain-cloudflare/tests/unit_tests/test_chat_models.py b/libs/langchain-cloudflare/tests/unit_tests/test_chat_models.py index 5e596a4..f91e53c 100644 --- a/libs/langchain-cloudflare/tests/unit_tests/test_chat_models.py +++ b/libs/langchain-cloudflare/tests/unit_tests/test_chat_models.py @@ -903,6 +903,23 @@ def test_openai_compatible_stream_chunk_parsing(self): assert chunk.message.content == "Hello" +# MARK: - LangSmith Params Tests +class TestLangSmithParams: + """Tests for LangSmith tracing parameters.""" + + def test_get_ls_params_uses_per_call_model_override(self): + """LangSmith params should reflect per-call model overrides.""" + llm = ChatCloudflareWorkersAI( + account_id="test_account", + api_token="test_token", + model="@cf/meta/llama-3.3-70b-instruct-fp8-fast", + ) + + params = llm._get_ls_params(model="test-model-override-sentinel") + + assert params["ls_model_name"] == "test-model-override-sentinel" + + # MARK: - with_structured_output Routing Tests