Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## langchain-cloudflare

### [0.3.5]

#### Added

- **`@cf/zai-org/glm-5.2`**: Adds GLM-5.2 to the tested and example-supported model lists with reasoning content, tool calling, structured output, and model-specific parameter handling for Workers AI's OpenAI-compatible schema.
- **OpenAI-compatible endpoint mode**: Adds `endpoint_format="openai_compatible"` to `ChatCloudflareWorkersAI` for REST chat completions routing through `/ai/v1/chat/completions` or AI Gateway `/workers-ai/v1/chat/completions`, with focused text and structured vision regression coverage.

---

### [0.3.4]

#### Changed
Expand Down
25 changes: 25 additions & 0 deletions libs/langchain-cloudflare/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,31 @@ llm = ChatCloudflareWorkersAI()
llm.invoke("Sing a ballad of LangChain.")
```

### REST endpoint format

By default, `ChatCloudflareWorkersAI` uses the native Workers AI run endpoint:

```python
llm = ChatCloudflareWorkersAI(
model="@cf/moonshotai/kimi-k2.6",
endpoint_format="workers_ai", # default
)
```

For REST calls that need Cloudflare's OpenAI-compatible chat completions API,
set `endpoint_format="openai_compatible"`:

```python
llm = ChatCloudflareWorkersAI(
model="@cf/moonshotai/kimi-k2.6",
endpoint_format="openai_compatible",
)
```

When `ai_gateway` is configured, OpenAI-compatible mode routes through the
Workers AI chat completions path on AI Gateway. This option is REST-only; Worker
bindings use `env.AI.run()` and do not expose a chat completions route.

## Embeddings

`CloudflareWorkersAIEmbeddings` class exposes embeddings from [CloudflareWorkersAI](https://developers.cloudflare.com/workers-ai/).
Expand Down
1 change: 1 addition & 0 deletions libs/langchain-cloudflare/examples/workers/src/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"@cf/mistralai/mistral-small-3.1-24b-instruct",
"@cf/qwen/qwen3-30b-a3b-fp8",
"@cf/zai-org/glm-4.7-flash",
"@cf/zai-org/glm-5.2",
"@cf/openai/gpt-oss-120b",
"@cf/openai/gpt-oss-20b",
"@cf/nvidia/nemotron-3-120b-a12b",
Expand Down
152 changes: 115 additions & 37 deletions libs/langchain-cloudflare/langchain_cloudflare/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ def _normalize_response_format_for_cloudflare(
)

MODEL_BEHAVIORS: Dict[str, ModelBehavior] = {
"glm-5.2": ModelBehavior(
embed_tool_calls_in_content=False,
unsupported_params=("top_k", "repetition_penalty"),
supports_reasoning_content=True,
),
"glm": ModelBehavior(
embed_tool_calls_in_content=False,
unsupported_params=("max_tokens", "top_k", "repetition_penalty", "tool_choice"),
Expand Down Expand Up @@ -369,6 +374,9 @@ async def fetch(self, request):
base_url: Optional[str]
Base URL path for API requests, leave blank if not using a proxy
or service emulator.
endpoint_format: Literal["workers_ai", "openai_compatible"]
REST endpoint format. Defaults to native Workers AI run endpoints.
Set to ``"openai_compatible"`` to use chat completions endpoints.
binding: Optional[Any]
Workers AI binding (env.AI) for use in Python Workers.
When provided, uses the binding instead of REST API calls.
Expand Down Expand Up @@ -414,6 +422,16 @@ async def fetch(self, request):
alias="cloudflare_ai_gateway",
default_factory=from_env("AI_GATEWAY", default=None),
)
endpoint_format: Literal["workers_ai", "openai_compatible"] = "workers_ai"
"""REST endpoint format to use.

``"workers_ai"`` uses Workers AI native run endpoints:
``/ai/run/{model}`` or AI Gateway ``/workers-ai/run/{model}``.
``"openai_compatible"`` uses chat completions endpoints and includes
``model`` in the JSON payload:
``/ai/v1/chat/completions`` or AI Gateway
``/workers-ai/v1/chat/completions``.
"""
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
Expand Down Expand Up @@ -495,6 +513,12 @@ def validate_environment(self) -> Self:

# If binding is provided, skip REST API setup
if self.binding is not None:
if self.endpoint_format != "workers_ai":
raise ValueError(
"endpoint_format='openai_compatible' is only supported for "
"REST API calls. Workers AI bindings use env.AI.run() and "
"do not expose a chat completions endpoint."
)
# When using binding, we don't need api_token or account_id
return self

Expand Down Expand Up @@ -605,7 +629,7 @@ def _get_ls_params(
params = self._get_invocation_params(stop=stop, **kwargs)
ls_params = LangSmithParams(
ls_provider="cloudflare-workers-ai",
ls_model_name=self.model,
ls_model_name=params.get("model", self.model),
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)
Expand Down Expand Up @@ -813,6 +837,54 @@ def _translate_params_for_model(self, params: Dict[str, Any]) -> Dict[str, Any]:

return params

def _get_api_url(self) -> str:
"""Return the REST API path for the configured endpoint format."""
if self.endpoint_format == "openai_compatible":
if self.ai_gateway:
return "workers-ai/v1/chat/completions"
return f"accounts/{self.account_id}/ai/v1/chat/completions"

if self.ai_gateway:
return f"workers-ai/run/{self.model}"
return f"accounts/{self.account_id}/ai/run/{self.model}"

def _create_request_payload(
self,
message_dicts: List[Dict[str, Any]],
params: Dict[str, Any],
) -> Dict[str, Any]:
"""Build the REST request payload for the configured endpoint format."""
if self.endpoint_format == "openai_compatible":
return {
**params,
"model": self.model,
"messages": message_dicts,
}

return {"messages": message_dicts, **params}

def _create_openai_stream_chunk(
self,
chunk: Dict[str, Any],
) -> Optional[ChatGenerationChunk]:
"""Convert an OpenAI-compatible SSE chunk to a LangChain chunk."""
choices = chunk.get("choices") or []
if not choices:
return None

choice = choices[0]
delta = choice.get("delta") or choice.get("message") or {}
response_text = _normalize_message_content(delta.get("content", ""))

generation_info = {}
if "usage" in chunk:
generation_info["usage"] = chunk["usage"]

return ChatGenerationChunk(
message=AIMessageChunk(content=response_text),
generation_info=generation_info or None,
)

# MARK: - Generate
def _generate( # type: ignore
self,
Expand All @@ -830,7 +902,7 @@ def _generate( # type: ignore
params = self._translate_params_for_model(params)

# Create the request payload
payload = {"messages": message_dicts, **params}
payload = self._create_request_payload(message_dicts, params)

# Use binding if available (for Python Workers)
if self.binding is not None:
Expand All @@ -846,13 +918,7 @@ def _generate( # type: ignore
response_data = loop.run_until_complete(self._call_binding(payload))
else:
# Use REST API (httpx client)
# Construct the API URL
if self.ai_gateway:
# If using AI Gateway
api_url = f"workers-ai/run/{self.model}"
else:
# If using direct API
api_url = f"accounts/{self.account_id}/ai/run/{self.model}"
api_url = self._get_api_url()

# Make the API request
response = self.client.post(api_url, json=payload)
Expand All @@ -875,18 +941,14 @@ async def _agenerate( # type: ignore
params = self._translate_params_for_model(params)

# Create the request payload
payload = {"messages": message_dicts, **params}
payload = self._create_request_payload(message_dicts, params)

# Use binding if available (for Python Workers)
if self.binding is not None:
response_data = await self._call_binding(payload)
else:
# Use REST API (httpx async client)
# Construct the Cloudflare Workers AI API URL
if self.ai_gateway:
api_url = f"workers-ai/run/{self.model}"
else:
api_url = f"accounts/{self.account_id}/ai/run/{self.model}"
api_url = self._get_api_url()

# Make the API request
response = await self.async_client.post(api_url, json=payload)
Expand Down Expand Up @@ -916,14 +978,9 @@ def _stream(
params = {**params, **kwargs, "stream": True}
params = self._translate_params_for_model(params)

# Construct the Cloudflare Workers AI API URL
if self.ai_gateway:
api_url = f"workers-ai/run/{self.model}"
else:
api_url = f"accounts/{self.account_id}/ai/run/{self.model}"

# Create the request payload
payload = {"messages": message_dicts, **params}
api_url = self._get_api_url()
payload = self._create_request_payload(message_dicts, params)

# Make the streaming API request
with self.client.stream("POST", api_url, json=payload) as response:
Expand All @@ -950,8 +1007,21 @@ def _stream(
except json.JSONDecodeError:
continue

if "choices" in chunk:
generation_chunk = self._create_openai_stream_chunk(chunk)
if generation_chunk is None:
continue

if run_manager:
run_manager.on_llm_new_token(
generation_chunk.text,
chunk=generation_chunk,
)

yield generation_chunk

# Process streaming response
if "streamed_data" in chunk:
elif "streamed_data" in chunk:
# Handle streamed_data format from Cloudflare
for stream_chunk in chunk["streamed_data"]:
response_text = stream_chunk.get("response", "")
Expand Down Expand Up @@ -1080,14 +1150,9 @@ async def _astream(
params = {**params, **kwargs, "stream": True}
params = self._translate_params_for_model(params)

# Construct the Cloudflare Workers AI API URL
if self.ai_gateway:
api_url = f"workers-ai/run/{self.model}"
else:
api_url = f"accounts/{self.account_id}/ai/run/{self.model}"

# Create the request payload
payload = {"messages": message_dicts, **params}
api_url = self._get_api_url()
payload = self._create_request_payload(message_dicts, params)

# Make the streaming API request
async with self.async_client.stream("POST", api_url, json=payload) as response:
Expand All @@ -1114,8 +1179,21 @@ async def _astream(
except json.JSONDecodeError:
continue

if "choices" in chunk:
generation_chunk = self._create_openai_stream_chunk(chunk)
if generation_chunk is None:
continue

if run_manager:
await run_manager.on_llm_new_token(
token=generation_chunk.text,
chunk=generation_chunk,
)

yield generation_chunk

# Handle the streamed_data format
if "streamed_data" in chunk:
elif "streamed_data" in chunk:
for stream_chunk in chunk["streamed_data"]:
response_text = stream_chunk.get("response", "")
accumulated_content += response_text
Expand Down Expand Up @@ -1826,14 +1904,14 @@ def _inject_schema_message(
return [schema_system_msg] + list(messages)
return messages

llm = self.bind( # type: ignore[assignment]
schema_bound_llm = self.bind(
response_format={"type": "json_object"},
ls_structured_output_format={
"kwargs": {"method": "json_mode"},
"schema": schema,
},
)
pipeline = RunnableLambda(_inject_schema_message) | llm # type: ignore[arg-type]
pipeline = RunnableLambda(_inject_schema_message) | schema_bound_llm # type: ignore[arg-type]

if include_raw:
parser_assign = RunnablePassthrough.assign(
Expand All @@ -1858,7 +1936,7 @@ def _inject_schema_message(

formatted_tool = convert_to_openai_tool(schema)
tool_name = formatted_tool["function"]["name"]
llm = self.bind_tools(
pipeline = self.bind_tools(
[schema],
ls_structured_output_format={
"kwargs": {"method": "function_calling"},
Expand All @@ -1876,7 +1954,7 @@ def _inject_schema_message(
)

elif method == "json_mode":
llm = self.bind( # type: ignore[assignment]
pipeline = self.bind(
response_format={"type": "json_object"},
ls_structured_output_format={
"kwargs": {"method": "json_mode"},
Expand Down Expand Up @@ -1907,9 +1985,9 @@ def _inject_schema_message(
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
return RunnableMap(raw=pipeline) | parser_with_fallback
else:
return llm | output_parser
return pipeline | output_parser


def _is_pydantic_class(obj: Any) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain-cloudflare/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "langchain-cloudflare"
version = "0.3.4"
version = "0.3.5"
description = "Langchain Integrations for Cloudflare's WorkersAI and Vectorize"
readme = "README.md"
license = "MIT"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"@cf/mistralai/mistral-small-3.1-24b-instruct",
"@cf/qwen/qwen3-30b-a3b-fp8",
"@cf/zai-org/glm-4.7-flash",
"@cf/zai-org/glm-5.2",
"@cf/openai/gpt-oss-120b",
"@cf/openai/gpt-oss-20b",
"@cf/nvidia/nemotron-3-120b-a12b",
Expand Down Expand Up @@ -906,6 +907,7 @@ class TestWorkerReasoningContent:
REASONING_MODELS = [
"@cf/qwen/qwen3-30b-a3b-fp8",
"@cf/zai-org/glm-4.7-flash",
"@cf/zai-org/glm-5.2",
"@cf/openai/gpt-oss-120b",
"@cf/openai/gpt-oss-20b",
"@cf/moonshotai/kimi-k2.5",
Expand Down
Loading
Loading