diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8f777af..bd4dc63 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+## [0.6.2]
+
+### Fixed
+- Fixes a bug in processing SSE messages that contain `data: ` strings in the payload of the message
+
 ## [0.6.1]
 
 ### Fixed
diff --git a/Project.toml b/Project.toml
index 7dfc0b7..6d131ef 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,17 +1,19 @@
 name = "StreamCallbacks"
 uuid = "c1b9e933-98a0-46fc-8ea7-3b58b195fb0a"
 authors = ["J S <49557684+svilupp@users.noreply.github.com> and contributors"]
-version = "0.6.1"
+version = "0.6.2"
 
 [deps]
 HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
+LibCURL = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 
 [compat]
 Aqua = "0.8"
 HTTP = "1.10"
 JSON3 = "1.14"
+LibCURL = "0.6.4"
 PrecompileTools = "1.2"
 Test = "1"
 julia = "1.9"
diff --git a/examples/error_handling_test.jl b/examples/error_handling_test.jl
new file mode 100644
index 0000000..2e10e25
--- /dev/null
+++ b/examples/error_handling_test.jl
@@ -0,0 +1,46 @@
+# Test error handling with custom IO that fails on "5"
+using HTTP, JSON3
+using StreamCallbacks
+using StreamCallbacks: OpenAIStream, streamed_request_http!, streamed_request_libcurl!
+
+# Prepare target and auth
+url = "https://api.openai.com/v1/chat/completions"
+headers = [
+    "Content-Type" => "application/json",
+    "Authorization" => "Bearer $(get(ENV, "OPENAI_API_KEY", ""))"
+]
+
+# Custom IO type that throws when it sees "5"
+struct ErrorOnFiveIO <: IO
+    buffer::Vector{String}
+end
+ErrorOnFiveIO() = ErrorOnFiveIO(String[])
+
+function StreamCallbacks.print_content(out::ErrorOnFiveIO, text::AbstractString; kwargs...)
+    push!(out.buffer, text)
+    if occursin("5", text)
+        error("Custom IO error: Found forbidden number '5' in: $(text)")
+    end
+end
+
+messages = [Dict("role" => "user", "content" => "Count from 1 to 10.")]
+payload = IOBuffer()
+JSON3.write(payload, (; stream = true, messages, model = "gpt-4o-mini", stream_options = (; include_usage = true)))
+payload_str = String(take!(payload))
+
+println("=== Testing Error Handling ===")
+
+# Test 1: HTTP.jl with error handling
+println("\n1. Testing HTTP.jl error handling...")
+cb_http = StreamCallback(; out = ErrorOnFiveIO(), flavor = OpenAIStream(), throw_on_error = true)
+# resp_http = streamed_request_http!(cb_http, url, headers, IOBuffer(payload_str))
+# println("HTTP: No error occurred (unexpected)")
+
+# Test 2: LibCURL with error handling
+println("\n2. Testing LibCURL error handling...")
+cb_curl = StreamCallback(; out = ErrorOnFiveIO(), flavor = OpenAIStream(), throw_on_error = true)
+
+resp_curl = streamed_request_libcurl!(cb_curl, url, headers, payload_str)
+println("LibCURL: No error occurred (unexpected)")
+
+println("\n=== Error Handling Test Complete ===")
\ No newline at end of file
diff --git a/examples/google_openai_streaming_example.jl b/examples/google_openai_streaming_example.jl
new file mode 100644
index 0000000..6ac16de
--- /dev/null
+++ b/examples/google_openai_streaming_example.jl
@@ -0,0 +1,23 @@
+# Calling Google AI with OpenAI Schema using StreamCallbacks
+using HTTP, JSON3
+using StreamCallbacks
+
+# Prepare target and auth for Google AI Studio
+url = "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions"
+headers = [
+    "Content-Type" => "application/json",
+    "Authorization" => "Bearer $(get(ENV, "GOOGLE_API_KEY", ""))"
+]
+
+# Send the request with OpenAI-compatible format
+cb = StreamCallback(; out = stdout, flavor = OpenAIStream()) # Use OpenAIStream for Google's OpenAI schema
+messages = [Dict("role" => "user",
+    "content" => "Count from 1 to 10. Start with numbers only.")]
+payload = IOBuffer()
+JSON3.write(payload,
+    (; stream = true, messages, model = "gemini-2.5-flash", stream_options = (; include_usage = true)))
+
+resp = streamed_request!(cb, url, headers, payload);
+
+println("Response status: ", resp.status)
+println("Collected chunks: ", length(cb.chunks))
\ No newline at end of file
diff --git a/examples/long_context_test.jl b/examples/long_context_test.jl
new file mode 100644
index 0000000..8f203fd
--- /dev/null
+++ b/examples/long_context_test.jl
@@ -0,0 +1,38 @@
+# Test long context with HTTP vs LibCURL performance comparison
+using HTTP, JSON3
+using StreamCallbacks
+using StreamCallbacks: OpenAIStream, streamed_request_libcurl!
+
+# Prepare target and auth
+url = "https://api.openai.com/v1/chat/completions"
+headers = [
+    "Content-Type" => "application/json",
+    "Authorization" => "Bearer $(get(ENV, "OPENAI_API_KEY", ""))"
+]
+
+# Create very long context
+very_long_text = ["(Random text chunk $i.) " for i in 1:100_000] |> join
+very_long_text = ["(Random text chunk $i.) " for i in 1:1] |> join
+messages = [Dict("role" => "user", "content" => very_long_text * "Count from 1 to 10.")]
+
+payload = IOBuffer()
+JSON3.write(payload, (; stream = true, messages, model = "gpt-4o-mini", stream_options = (; include_usage = true)))
+payload_str = String(take!(payload))
+
+println("=== Testing Long Context Performance ===")
+println("Context length: $(length(very_long_text)) characters")
+
+# Test 1: HTTP.jl based streaming
+println("\n1. Testing HTTP.jl streaming...")
+# cb_http = StreamCallback(; out = stdout, flavor = OpenAIStream(), throw_on_error = true)
+# resp_http = @time streamed_request!(cb_http, url, headers, IOBuffer(payload_str))
+# @show resp_http
+# println("HTTP chunks received: $(length(cb_http.chunks))")
+
+# Test 2: LibCURL based streaming
+println("\n2. Testing LibCURL streaming...")
+cb_curl = StreamCallback(; out = stdout, flavor = OpenAIStream(), throw_on_error = true)
+resp_curl = @time streamed_request_libcurl!(cb_curl, url, headers, payload_str)
+println("LibCURL chunks received: $(length(cb_curl.chunks))")
+
+println("\n=== Performance Comparison Complete ===")
\ No newline at end of file
diff --git a/examples/openai_example.jl b/examples/openai_example.jl
index 7f233bb..afe8b5d 100644
--- a/examples/openai_example.jl
+++ b/examples/openai_example.jl
@@ -1,24 +1,58 @@
 # Calling OpenAI with StreamCallbacks
 using HTTP, JSON3
 using StreamCallbacks
+using StreamCallbacks: OpenAIStream
+using StreamCallbacks: streamed_request_libcurl!
 
-## Prepare target and auth
+# Prepare target and auth
 url = "https://api.openai.com/v1/chat/completions"
 headers = [
     "Content-Type" => "application/json",
     "Authorization" => "Bearer $(get(ENV, "OPENAI_API_KEY", ""))"
 ];
+# Custom IO type that throws when it sees "5"
+struct ErrorOnFiveIO <: IO
+    buffer::Vector{String}
+end
+ErrorOnFiveIO() = ErrorOnFiveIO(String[])
 
-## Send the request
-cb = StreamCallback(; out = stdout, flavor = OpenAIStream())
-messages = [Dict("role" => "user",
-    "content" => "Count from 1 to 100.")]
+function StreamCallbacks.print_content(out::ErrorOnFiveIO, text::AbstractString; kwargs...)
+    push!(out.buffer, text)
+    if occursin("5", text)
+        error("Custom IO error: Found forbidden number '5' in: $(text)")
+    end
+end
+
+# Send the request
+cb = StreamCallback(; out = stdout, flavor = OpenAIStream(), throw_on_error = true)
+# cb = StreamCallback(; out = ErrorOnFiveIO(), flavor = OpenAIStream(), throw_on_error = true)
+very_long_text = ["(Just some random text $i.) " for i in 1:1] |> join
+# very_long_text = ""
+messages = [Dict("role" => "user", "content" => very_long_text * "Count from 1 to 100.")]
+using LLMRateLimiters
+# @show LLMRateLimiters.estimate_tokens(messages[1]["content"])
+
+#
 payload = IOBuffer()
 JSON3.write(payload,
-    (; stream = true, messages, model = "gpt-4o-mini",
+    (; stream = true, messages, model = "gpt-5-mini",
         stream_options = (; include_usage = true)))
 
-resp = streamed_request!(cb, url, headers, payload);
+# Test different streaming methods:
+# 1. HTTP.jl based (default)
+# payload_str = String(take!(payload))
+# resp = @time streamed_request!(cb, url, headers, IOBuffer(payload_str));
+# @show resp
+
+# 2. Socket-based streaming
+# resp = socket_streamed_request!(cb, url, headers, String(take!(payload)));
+
+# 3. LibCURL-based streaming (recommended)
+# Clear chunks from previous test to avoid accumulation
+empty!(cb.chunks)
+resp = @time streamed_request_libcurl!(cb, url, headers, payload);
+@show resp
+;
 
 ## Check the response
 resp # should be a `HTTP.Response` object with a message body like if we wouldn't use streaming
diff --git a/examples/promptingtools_aigenerate_example.jl b/examples/promptingtools_aigenerate_example.jl
new file mode 100644
index 0000000..b973c82
--- /dev/null
+++ b/examples/promptingtools_aigenerate_example.jl
@@ -0,0 +1,24 @@
+using PromptingTools
+using StreamCallbacks
+using GoogleGenAI
+using PromptingTools: GoogleSchema
+
+# Stream to stdout with callback collecting chunks
+cb = StreamCallback(out = stdout, verbose = false)
+
+# Use a real model id; adjust as needed
+msg = @time aigenerate(GoogleSchema(), "Tell me a short story of humanity:";
+    model = "gemini-2.5-pro-preview-06-05",
+    streamcallback = cb)
+
+println("\n\nFinal content:\n", msg.content)
+#%%
+using GoogleGenAI
+
+models = list_models()
+for m in models
+    if "createCachedContent" in m[:supported_generation_methods]
+        println(m[:name])
+    end
+end
+
diff --git a/examples/responses_stream_example.jl b/examples/responses_stream_example.jl
new file mode 100644
index 0000000..9542665
--- /dev/null
+++ b/examples/responses_stream_example.jl
@@ -0,0 +1,21 @@
+using HTTP
+using JSON3
+using OpenAI
+using PromptingTools
+const PT = PromptingTools
+using PromptingTools: OpenAIResponseSchema, AbstractResponseSchema, airespond
+using StreamCallbacks: StreamCallback
+
+# This example demonstrates the use of the OpenAI Responses API
+# with proper schema support and streaming capabilities
+
+# Make sure your OpenAI API key is set in the environment variable OPENAI_API_KEY
+
+# Basic usage with the new schema
+schema = OpenAIResponseSchema()
+cb = StreamCallback(out=stdout)
+
+response = airespond(schema, "What is the 6th largest city in the Czech Republic? you can think, but in the answer I only want to see the city.";
+model = "gpt-5.1-codex", streamcallback=cb)
+@show response.tokens
+@show response.extras[:usage]
\ No newline at end of file
diff --git a/examples/stream_comparison_example.jl b/examples/stream_comparison_example.jl
new file mode 100644
index 0000000..7d2cb1a
--- /dev/null
+++ b/examples/stream_comparison_example.jl
@@ -0,0 +1,72 @@
+# Calling OpenAI with StreamCallbacks
+using HTTP, JSON3
+using StreamCallbacks
+using StreamCallbacks: OpenAIStream
+using StreamCallbacks: streamed_request_libcurl!
+
+# Prepare target and auth
+url = "https://api.openai.com/v1/chat/completions"
+headers = [
+    "Content-Type" => "application/json",
+    "Authorization" => "Bearer $(get(ENV, "OPENAI_API_KEY", ""))"
+];
+# Custom IO type that throws when it sees "5"
+struct ErrorOnFiveIO <: IO
+    buffer::Vector{String}
+end
+ErrorOnFiveIO() = ErrorOnFiveIO(String[])
+
+function StreamCallbacks.print_content(out::ErrorOnFiveIO, text::AbstractString; kwargs...)
+    push!(out.buffer, text)
+    print(text)
+    if occursin("5", text)
+        # error("Custom IO error: Found forbidden number '5' in: $(text)")
+    end
+end
+
+# Send the request
+# cb = StreamCallback(; out = stdout, flavor = OpenAIStream(), throw_on_error = false)
+cb = StreamCallback(; out = ErrorOnFiveIO(), flavor = OpenAIStream(), throw_on_error = true)
+# cb = StreamCallback(; out = ErrorOnFiveIO(), flavor = AnthropicStream(), throw_on_error = true)
+using JLD2
+# @load "call_error.jld2" url headers body kwargs
+# @show typeof(headers)
+very_long_text = ["(Just some random text $i.) " for i in 1:1] |> join
+# very_long_text = ""
+messages = [Dict("role" => "user", "content" => very_long_text * "Count from 1 to 10.")]
+using LLMRateLimiters
+# @show LLMRateLimiters.estimate_tokens(messages[1]["content"])
+
+#
+payload = IOBuffer()
+JSON3.write(payload,
+    (; stream = true, messages, model = "gpt-4o-mini",
+        stream_options = (; include_usage = true)))
+
+# Test different streaming methods:
+# 1. HTTP.jl based (default)
+# payload_str = String(take!(payload))
+# resp = @time streamed_request!(cb, url, headers, IOBuffer(payload_str));
+# @show resp
+
+# 2. Socket-based streaming
+# resp = socket_streamed_request!(cb, url, headers, String(take!(payload)));
+
+# 3. LibCURL-based streaming (recommended)
+# Clear chunks from previous test to avoid accumulation
+body_dict = JSON3.read(String(take!(payload)), Dict)
+# @show body_dict
+# body_dict["system"][1]["text"] = "Count from 1 to 10."
+body_str = String(JSON3.write(body_dict))
+@show body_str
+# empty!(cb.chunks)
+resp = @time streamed_request_libcurl!(cb, url, headers, body_str);
+@show resp
+;
+## Check the response
+resp # should be a `HTTP.Response` object with a message body like if we wouldn't use streaming
+
+## Check the callback
+cb.chunks # should be a vector of `StreamChunk` objects, each with a `json` field with received data from the API
+
+# TIP: For debugging, use `cb.verbose = true` in the `StreamCallback` constructor to get more details on each chunk and enable DEBUG loglevel.
diff --git a/src/StreamCallbacks.jl b/src/StreamCallbacks.jl
index 8230bb7..c53e287 100644
--- a/src/StreamCallbacks.jl
+++ b/src/StreamCallbacks.jl
@@ -2,12 +2,15 @@ module StreamCallbacks
 
 using HTTP, JSON3
 using PrecompileTools
+using LibCURL
+
+export StreamCallback, StreamChunk, OpenAIStream, AnthropicStream, OllamaStream, ResponseStream,
+    streamed_request!, streamed_request_libcurl!
 
-export StreamCallback, StreamChunk, OpenAIStream, AnthropicStream, OllamaStream,
-       streamed_request!
 
 include("interface.jl")
 
 include("shared_methods.jl")
+include("shared_methods_libcurl.jl")
 
 include("stream_openai.jl")
 
@@ -15,6 +18,8 @@ include("stream_anthropic.jl")
 
 include("stream_ollama.jl")
 
+include("stream_response.jl")
+
 @compile_workload begin
     include("precompilation.jl")
 end
diff --git a/src/interface.jl b/src/interface.jl
index d150772..2b4ca26 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -66,11 +66,14 @@ Abstract type for the stream flavor, ie, the API provider.
 Available flavors:
 - `OpenAIStream` for OpenAI API
 - `AnthropicStream` for Anthropic API
+- `OllamaStream` for Ollama API
+- `ResponseStream` for OpenAI Response API
 """
 abstract type AbstractStreamFlavor end
 struct OpenAIStream <: AbstractStreamFlavor end
 struct AnthropicStream <: AbstractStreamFlavor end
 struct OllamaStream <: AbstractStreamFlavor end
+struct ResponseStream <: AbstractStreamFlavor end
 
 ## Default implementations
 """
diff --git a/src/precompilation.jl b/src/precompilation.jl
index 5e860ea..bb74472 100644
--- a/src/precompilation.jl
+++ b/src/precompilation.jl
@@ -30,4 +30,4 @@ build_response_body(flavor, cb)
 flavor = OllamaStream()
 is_done(flavor, example_chunk)
 extract_content(flavor, example_chunk)
-build_response_body(flavor, cb)
\ No newline at end of file
+build_response_body(flavor, cb)
diff --git a/src/shared_methods.jl b/src/shared_methods.jl
index 7c6120b..8cb71c5 100644
--- a/src/shared_methods.jl
+++ b/src/shared_methods.jl
@@ -5,87 +5,110 @@
     extract_chunks(flavor::AbstractStreamFlavor, blob::AbstractString;
         spillover::AbstractString = "", verbose::Bool = false, kwargs...)
 
-Extract the chunks from the received SSE blob. Shared by all streaming flavors currently.
+Extract the chunks from the received SSE blob. Correctly implements SSE spec field parsing.
 
 Returns a list of `StreamChunk` and the next spillover (if message was incomplete).
 """
 @inline function extract_chunks(flavor::AbstractStreamFlavor, blob::AbstractString;
         spillover::AbstractString = "", verbose::Bool = false, kwargs...)
-    chunks = StreamChunk[]
+
+    # Handle any spillover from previous incomplete message
+    full_blob = spillover * blob
+
+    # Split on blank lines (SSE message boundaries); tolerate CRLF line endings
+    messages = split(full_blob, r"\r?\n\r?\n")
+
+    # Check if last message is incomplete (no trailing blank line)
     next_spillover = ""
-    ## SSE come separated by double-newlines
-    blob_split = split(blob, "\n\n")
-    for (bi, chunk) in enumerate(blob_split)
-        isempty(chunk) && continue
-        event_split = split(chunk, "event: ")
-        has_event = length(event_split) > 1
-        # if length>1, we know it was there!
-        for event_blob in event_split
-            isempty(event_blob) && continue
-            event_name = nothing
-            data_buf = IOBuffer()
-            data_splits = split(event_blob, "data: ")
-            for i in eachindex(data_splits)
-                isempty(data_splits[i]) && continue
-                if i == 1 & has_event && !isempty(data_splits[i])
-                    ## we have an event name
-                    event_name = strip(data_splits[i]) |> Symbol
-                elseif bi == 1 && i == 1 && !isempty(data_splits[i])
-                    ## in the first part of the first blob, it must be a spillover
-                    spillover = string(spillover, rstrip(data_splits[i], '\n'))
-                    verbose && @info "Buffer spillover detected: $(spillover)"
-                elseif i > 1
-                    ## any subsequent data blobs are accummulated into the data buffer
-                    ## there can be multiline data that must be concatenated
-                    data_chunk = rstrip(data_splits[i], '\n')
-                    write(data_buf, data_chunk)
+    if !endswith(full_blob, r"\r?\n\r?\n") && !isempty(messages)
+        # Last message might be incomplete, save it for next time
+        next_spillover = pop!(messages)
+        verbose && @info "Incomplete message detected, spillover: $(repr(next_spillover))"
+    end
+
+    chunks = StreamChunk[]
+
+    for message in messages
+        isempty(strip(message)) && continue
+
+        # Parse line starts
+        event_name = nothing
+        data_parts = String[]
+
+        for line in split(message, '\n')
+            try
+                line = rstrip(line, '\r') # Handle \r\n
+
+                # Handle comments (lines starting with ":")
+                if startswith(line, ":")
+                    continue # Ignore comment lines per SSE spec
+                end
+
+                # Parse field:value lines
+                colon_pos = findfirst(':', line)
+                if isnothing(colon_pos)
+                    continue # Skip lines without colons
+                end
+
+                field_name = line[1:prevind(line, colon_pos)]
+                field_value = line[(colon_pos + 1):end]
+
+                # Strip UTF-8 BOM from first field name if present (SSE spec compliance)
+                if !isempty(field_name) && field_name[1] == '\ufeff'
+                    field_name = field_name[nextind(field_name, 1):end]
                 end
-            end
-            ## Parse the spillover
-            if bi == 1 && !isempty(spillover)
-                data = spillover
-                json = if startswith(data, '{') && endswith(data, '}')
-                    try
-                        JSON3.read(data)
-                    catch e
-                        verbose && @warn "Cannot parse JSON: $data"
-                        nothing
+                # Remove leading space from field value if present (SSE spec)
+                if startswith(field_value, " ")
+                    field_value = field_value[2:end]
+                end
+
+                # Handle data fields: keep every value, including empty ones
+                if field_name == "data"
+                    # SSE spec: empty data fields should contribute empty string, not be skipped
+                    push!(data_parts, field_value)
+                elseif field_name == "event"
+                    # Event field should not be empty
+                    if !isempty(field_value)
+                        event_name = Symbol(field_value)
                     end
-                else
-                    nothing
                 end
-                ## ignore event name
-                push!(chunks, StreamChunk(; data = spillover, json = json))
-                # reset the spillover
-                spillover = ""
+                # Note: id and retry fields are ignored but could be parsed if needed
+            catch e
+                # Handle malformed lines gracefully
+                verbose && @warn "Malformed SSE line ignored: $(repr(line)). Error: $e"
+                continue
             end
-            ## On the last iteration of the blob, check if we spilled over
-            if bi == length(blob_split) && length(data_splits) > 1 &&
-               !isempty(strip(data_splits[end]))
-                verbose && @info "Incomplete message detected: $(data_splits[end])"
-                next_spillover = String(take!(data_buf))
-                ## Do not save this chunk
-            else
-                ## Try to parse the data as JSON
-                data = String(take!(data_buf))
-                isempty(data) && continue
-                ## try to build a JSON object if it's a well-formed JSON string
-                json = if startswith(data, '{') && endswith(data, '}')
-                    try
-                        JSON3.read(data)
-                    catch e
-                        verbose && @warn "Cannot parse JSON: $data"
-                        nothing
-                    end
-                else
+        end
+
+        isempty(data_parts) && continue
+
+        # Join multiple data lines with newlines (SSE spec)
+        # Keep raw_data exactly as received from LLM for debugging and testing
+        raw_data = join(data_parts, '\n')
+
+        # More robust JSON detection - handle both objects and arrays
+        parsed_json = if !isempty(strip(raw_data))
+            stripped = strip(raw_data)
+            is_json = (startswith(stripped, '{') && endswith(stripped, '}')) ||
+                      (startswith(stripped, '[') && endswith(stripped, ']'))
+            if is_json
+                try
+                    JSON3.read(raw_data)
+                catch e
+                    verbose && @warn "Cannot parse JSON: $(repr(raw_data))"
                     nothing
                 end
-                ## Create a new chunk
-                push!(chunks, StreamChunk(event_name, data, json))
+            else
+                nothing
             end
+        else
+            nothing
         end
+
+        push!(chunks, StreamChunk(event_name, raw_data, parsed_json))
     end
+
     return chunks, next_spillover
 end
 
@@ -168,9 +191,9 @@ Handles error messages from the streaming response.
         ## Define whether to throw an error
         error_msg = "Error detected in the streaming response: $(error_str)"
         if throw_on_error
-            throw(Exception(error_msg))
+            throw(ErrorException(error_msg))
         else
             @warn error_msg
         end
     end
     return nothing
@@ -192,10 +215,24 @@ Returns the response object.
 - `input`: A buffer with the request body.
 - `kwargs`: Additional keyword arguments.
 """
-function streamed_request!(cb::AbstractStreamCallback, url, headers, input; kwargs...)
+function streamed_request!(cb::AbstractStreamCallback, url, headers, input::IOBuffer; kwargs...)
+    streamed_request!(cb, url, headers, String(take!(input)); kwargs...)
+end
+function streamed_request!(cb::AbstractStreamCallback, url, headers, input::IO; kwargs...)
+    streamed_request!(cb, url, headers, read(input, String); kwargs...)
+end
+function streamed_request!(cb::AbstractStreamCallback, url, headers, input::Dict; kwargs...)
+    streamed_request!(cb, url, headers, String(JSON3.write(input)); kwargs...)
+end
+function streamed_request!(cb::AbstractStreamCallback, url, headers, input::String; kwargs...)
+    streamed_request_http!(cb, url, headers, input; kwargs...)
+    # streamed_request_libcurl!(cb, url, headers, input; kwargs...)
+end
+# HTTP.jl-based transport; `streamed_request!` forwards to it once the body is a `String`
+function streamed_request_http!(cb::AbstractStreamCallback, url, headers, input::String; kwargs...)
     verbose = get(kwargs, :verbose, false) || cb.verbose
     resp = HTTP.open("POST", url, headers; kwargs...) do stream
-        write(stream, String(take!(input)))
+        write(stream, input)
         HTTP.closewrite(stream)
         response = HTTP.startread(stream)
 
@@ -226,7 +263,8 @@ function streamed_request!(cb::AbstractStreamCallback, url, headers, input; kwar
         spillover = ""
         while !eof(stream) || !isdone
             masterchunk = String(readavailable(stream))
-            chunks, spillover = extract_chunks(
+            chunks,
+            spillover = extract_chunks(
                 cb.flavor, masterchunk; verbose, spillover, cb.kwargs...)
 
             for chunk in chunks
diff --git a/src/shared_methods_libcurl.jl b/src/shared_methods_libcurl.jl
new file mode 100644
index 0000000..f142314
--- /dev/null
+++ b/src/shared_methods_libcurl.jl
@@ -0,0 +1,218 @@
+# LibCURL-based streaming implementation for StreamCallbacks.jl
+using LibCURL
+
+"""
+    stream_write_callback(ptr::Ptr{UInt8}, size::Csize_t, nmemb::Csize_t, userdata::Ptr{Cvoid})
+
+Callback function for processing streaming response data from libcurl.
+
+`userdata` must point to the `Ref` created in `streamed_request_libcurl!`. An exception raised
+while processing is stored in the last slot of that `Ref` and the transfer is aborted by
+returning a byte count other than the one received, so nothing unwinds through libcurl.
+"""
+function stream_write_callback(ptr::Ptr{UInt8}, size::Csize_t, nmemb::Csize_t, userdata::Ptr{Cvoid})::Csize_t
+    callback_data = unsafe_pointer_to_objref(userdata)
+    cb, spillover_ref, isdone_ref, verbose, error_body, failure = callback_data[]
+
+    # Read the data
+    data_size = size * nmemb
+    chunk_data = unsafe_string(ptr, data_size)
+
+    try
+        # Always capture raw data for potential error responses
+        write(error_body, chunk_data)
+
+        # Extract chunks using existing logic
+        chunks, new_spillover = extract_chunks(
+            cb.flavor, chunk_data; verbose, spillover=spillover_ref, cb.kwargs...)
+
+        # Update spillover
+        callback_data[] = (cb, new_spillover, isdone_ref, verbose, error_body, failure)
+
+        # Process chunks
+        for chunk in chunks
+            verbose && @debug "Chunk Data: $(chunk.data)"
+            handle_error_message(chunk; throw_on_error=cb.throw_on_error, verbose, cb.kwargs...)
+            if is_done(cb.flavor, chunk; verbose, cb.kwargs...)
+                callback_data[] = (cb, new_spillover, true, verbose, error_body, failure)
+            end
+            callback(cb, chunk)
+            push!(cb, chunk)
+        end
+    catch e
+        # Hand the exception over to `streamed_request_libcurl!`, which rethrows it once
+        # `curl_easy_perform` has returned; any count != data_size aborts the transfer
+        failure[] = e
+        return data_size == 0 ? Csize_t(1) : Csize_t(0)
+    end
+
+    return data_size
+end
+
+"""
+    stream_header_callback(ptr::Ptr{UInt8}, size::Csize_t, nmemb::Csize_t, userdata::Ptr{Cvoid})
+
+Callback function for processing response headers from libcurl.
+
+Stores the status code of each status line and the lower-cased headers of the most recent
+response, so interim responses (e.g. `100 Continue`) do not leak headers into the final one.
+"""
+function stream_header_callback(ptr::Ptr{UInt8}, size::Csize_t, nmemb::Csize_t, userdata::Ptr{Cvoid})::Csize_t
+    header_data = unsafe_pointer_to_objref(userdata)
+    response_headers, status_code = header_data[]
+
+    # Read header line
+    header_size = size * nmemb
+    header_line = strip(unsafe_string(ptr, header_size))
+
+    if startswith(header_line, "HTTP/")
+        # Status line: a new response starts, drop the headers of any previous response
+        empty!(response_headers)
+        parts = split(header_line)
+        code = length(parts) >= 2 ? tryparse(Int, parts[2]) : nothing
+        isnothing(code) || (status_code[] = code)
+    else
+        # Header line: split on the first colon only, the value may contain colons
+        parts = split(header_line, ':'; limit = 2)
+        if length(parts) == 2
+            response_headers[lowercase(strip(parts[1]))] = strip(parts[2])
+        end
+    end
+
+    return header_size
+end
+
+"""
+    streamed_request_libcurl!(cb::AbstractStreamCallback, url::String, headers::Vector, body::String; kwargs...)
+
+LibCURL-based implementation of `streamed_request!`.
+
+Sends `body` as a POST request to `url` and feeds the streamed response to `cb`.
+Returns a named tuple `(; status, headers, body)`, where `body` is the JSON-encoded result
+of `build_response_body`. Only the `verbose` keyword is read from `kwargs`.
+
+Throws an `ErrorException` if the transfer fails, the HTTP status is >= 400 or the response
+Content-Type is not a streaming one; exceptions raised while processing chunks are rethrown.
+"""
+function streamed_request_libcurl!(cb::AbstractStreamCallback, url::String, headers::Vector, body::String; kwargs...)
+    verbose = get(kwargs, :verbose, false) || cb.verbose
+
+    # Initialize curl handle
+    curl = LibCURL.curl_easy_init()
+    curl == C_NULL && error("Failed to initialize curl")
+
+    # Response data collection
+    response_headers = Dict{String,String}()
+    status_code = Ref{Int}(0)
+    spillover = ""
+    isdone = false
+    error_body = IOBuffer()
+    header_list = C_NULL
+    # Exception raised inside the write callback, if any (rethrown after the transfer)
+    callback_failure = Ref{Any}(nothing)
+
+    try
+        # Set basic options
+        LibCURL.curl_easy_setopt(curl, LibCURL.CURLOPT_URL, url)
+        LibCURL.curl_easy_setopt(curl, LibCURL.CURLOPT_POST, 1)
+        LibCURL.curl_easy_setopt(curl, LibCURL.CURLOPT_POSTFIELDS, body)
+        LibCURL.curl_easy_setopt(curl, LibCURL.CURLOPT_POSTFIELDSIZE, sizeof(body))
+
+        # Set headers
+        for (key, value) in headers
+            header_str = "$key: $value"
+            header_list = LibCURL.curl_slist_append(header_list, header_str)
+        end
+        header_list != C_NULL && LibCURL.curl_easy_setopt(curl, LibCURL.CURLOPT_HTTPHEADER, header_list)
+
+        # Write callback for streaming response data
+        write_callback = @cfunction(stream_write_callback, Csize_t, (Ptr{UInt8}, Csize_t, Csize_t, Ptr{Cvoid}))
+        callback_data = Ref((cb, spillover, isdone, verbose, error_body, callback_failure))
+        LibCURL.curl_easy_setopt(curl, LibCURL.CURLOPT_WRITEFUNCTION, write_callback)
+        LibCURL.curl_easy_setopt(curl, LibCURL.CURLOPT_WRITEDATA, pointer_from_objref(callback_data))
+
+        # Header callback for response headers
+        header_callback = @cfunction(stream_header_callback, Csize_t, (Ptr{UInt8}, Csize_t, Csize_t, Ptr{Cvoid}))
+        header_data = Ref((response_headers, status_code))
+        LibCURL.curl_easy_setopt(curl, LibCURL.CURLOPT_HEADERFUNCTION, header_callback)
+        LibCURL.curl_easy_setopt(curl, LibCURL.CURLOPT_HEADERDATA, pointer_from_objref(header_data))
+
+        # SSL options
+        LibCURL.curl_easy_setopt(curl, LibCURL.CURLOPT_CAINFO, LibCURL.cacert)
+        LibCURL.curl_easy_setopt(curl, LibCURL.CURLOPT_SSL_VERIFYPEER, 1)
+        LibCURL.curl_easy_setopt(curl, LibCURL.CURLOPT_SSL_VERIFYHOST, 2)
+
+        # Perform the request; libcurl only keeps raw pointers to these objects, so they
+        # must stay rooted for the whole transfer
+        res = GC.@preserve body callback_data header_data LibCURL.curl_easy_perform(curl)
+
+        # Rethrow an exception captured inside the write callback
+        isnothing(callback_failure[]) || throw(callback_failure[])
+
+        # `curl_easy_strerror` returns a C string pointer; convert it to get the message
+        res != LibCURL.CURLE_OK &&
+            error("curl_easy_perform failed: $(unsafe_string(LibCURL.curl_easy_strerror(res)))")
+
+        # Get final status code
+        status_ref = Ref{Clong}(0)
+        LibCURL.curl_easy_getinfo(curl, LibCURL.CURLINFO_RESPONSE_CODE, status_ref)
+        final_status = Int(status_ref[])
+
+        # Check for HTTP error status codes first
+        if final_status >= 400
+            content_type = get(response_headers, "content-type", "")
+            error_body_str = String(take!(error_body))
+
+            error_msg = """
+            HTTP Error $(final_status): Request failed"""
+
+            if occursin("application/json", lowercase(content_type)) && !isempty(error_body_str)
+                error_msg *= "\nError response body: $(error_body_str)"
+            end
+
+            error(error_msg * "\nPlease check your request parameters, API key, and model availability.")
+        end
+
+        # Verify content type for successful responses
+        content_type = get(response_headers, "content-type", "")
+        expected_type = cb.flavor isa OllamaStream ? "application/x-ndjson" : "text/event-stream"
+
+        if !occursin(expected_type, lowercase(content_type))
+            flavor_name = cb.flavor isa OllamaStream ? "OllamaStream" : "streaming"
+            error("""
+            For $(flavor_name) flavor, Content-Type must be $(expected_type).
+            Received type: $(content_type)
+            Status code: $(final_status)
+            Please check the model you are using and that you set `stream=true`.
+            """)
+        end
+
+        # Aesthetic newline for stdout
+        cb.out == stdout && (println(); flush(stdout))
+
+        # Build response body
+        body_content = build_response_body(cb.flavor, cb; verbose, cb.kwargs...)
+
+        # Create response object
+        resp = (
+            status = Int16(final_status),
+            headers = collect(response_headers),
+            body = JSON3.write(body_content)
+        )
+
+        return resp
+
+    finally
+        # Cleanup
+        header_list != C_NULL && LibCURL.curl_slist_free_all(header_list)
+        LibCURL.curl_easy_cleanup(curl)
+    end
+end
+
+"""
+    streamed_request_libcurl!(cb::AbstractStreamCallback, url::String, headers::Vector, body::IOBuffer; kwargs...)
+
+LibCURL-based implementation that accepts IOBuffer input.
+""" +streamed_request_libcurl!(cb::AbstractStreamCallback, url::String, headers::Vector, body::IOBuffer; kwargs...) = + streamed_request_libcurl!(cb, url, headers, String(take!(body)); kwargs...) \ No newline at end of file diff --git a/src/stream_ollama.jl b/src/stream_ollama.jl index c1ba18f..433b6cb 100644 --- a/src/stream_ollama.jl +++ b/src/stream_ollama.jl @@ -71,4 +71,4 @@ function build_response_body( !isnothing(usage) && merge!(response, usage) end return response -end \ No newline at end of file +end diff --git a/src/stream_response.jl b/src/stream_response.jl new file mode 100644 index 0000000..278498d --- /dev/null +++ b/src/stream_response.jl @@ -0,0 +1,132 @@ +# Custom methods for OpenAI Response API streaming -- flavor=ResponseStream() + +""" + is_done(flavor::ResponseStream, chunk::AbstractStreamChunk; kwargs...) + +Check if the streaming is done for Response API. +Response API sends "response.completed" event when done. +""" +@inline function is_done(flavor::ResponseStream, chunk::AbstractStreamChunk; kwargs...) + !isnothing(chunk.json) && get(chunk.json, :type, "") == "response.completed" +end + +""" + extract_content(flavor::ResponseStream, chunk::AbstractStreamChunk; kwargs...) + +Extract the content from Response API streaming chunks. +Response API uses event-based streaming with `response.output_text.delta` events. +""" +@inline function extract_content( + flavor::ResponseStream, chunk::AbstractStreamChunk; kwargs...) 
+ if !isnothing(chunk.json) + # Response API uses different structure: {"type":"response.output_text.delta", "delta":"text", ...} + chunk_type = get(chunk.json, :type, "") + + # Handle regular output text deltas + if chunk_type == "response.output_text.delta" + return get(chunk.json, :delta, nothing) + + # Handle reasoning summary text deltas (for reasoning traces) + elseif chunk_type == "response.reasoning_summary_text.delta" + delta_text = get(chunk.json, :delta, nothing) + if !isnothing(delta_text) + # Italic reasoning summary segments + return "\e[3m" * delta_text * "\e[23m" + end + return nothing + + # When reasoning summary text is done, emit a newline separator + elseif chunk_type == "response.reasoning_summary_text.done" + return "\n" + end + end + return nothing +end + +""" + build_response_body(flavor::ResponseStream, cb::AbstractStreamCallback; verbose::Bool = false, kwargs...) + +Build the response body from the chunks to mimic receiving a standard response from the API. +Reconstructs the Response API format from streaming chunks. +""" +function build_response_body( + flavor::ResponseStream, cb::AbstractStreamCallback; verbose::Bool = false, kwargs...) 
+ isempty(cb.chunks) && return nothing + + response = nothing + content_parts = String[] + + for chunk in cb.chunks + isnothing(chunk.json) && continue + + chunk_type = get(chunk.json, :type, "") + + # Initialize response from the first response.created event + if chunk_type == "response.created" && isnothing(response) + response = get(chunk.json, :response, Dict()) |> copy + end + + # Update response from response.completed event (final state) + if chunk_type == "response.completed" + final_response = get(chunk.json, :response, Dict()) + if !isnothing(response) + # Merge the final response data + response = merge(response, final_response) + else + response = final_response |> copy + end + end + + # Collect content from delta events + if chunk_type == "response.output_text.delta" + delta_content = get(chunk.json, :delta, "") + if !isempty(delta_content) + push!(content_parts, delta_content) + end + end + end + + # If we have response but need to reconstruct content + if !isnothing(response) && !isempty(content_parts) + full_content = join(content_parts) + + # Ensure we have the output structure + if !haskey(response, :output) || isempty(response[:output]) + # Create a basic message output structure + response[:output] = [ + Dict( + :type => "message", + :status => "completed", + :content => [ + Dict( + :type => "output_text", + :text => full_content + ) + ], + :role => "assistant" + ) + ] + else + # Convert output array to mutable and update existing output with reconstructed content + output_array = [] + for output_item in response[:output] + output_dict = Dict(output_item) # Convert JSON3.Object to Dict + if get(output_dict, :type, "") == "message" + content_array = [] + for content_item in get(output_dict, :content, []) + content_dict = Dict(content_item) # Convert JSON3.Object to Dict + if get(content_dict, :type, "") == "output_text" + content_dict[:text] = full_content + end + push!(content_array, content_dict) + end + output_dict[:content] = content_array + 
end + push!(output_array, output_dict) + end + response[:output] = output_array + end + end + + return response +end \ No newline at end of file diff --git a/test/shared_methods.jl b/test/shared_methods.jl index 012e570..c6ee9f5 100644 --- a/test/shared_methods.jl +++ b/test/shared_methods.jl @@ -50,21 +50,23 @@ end @test chunks[2].json == JSON3.read("{\"status\": \"complete\"}") @test spillover == "" - # Test with spillover + # Test with spillover - SSE spec compliant blob_with_spillover = "event: start\ndata: {\"key\": \"value\"}\n\nevent: continue\ndata: {\"partial\": \"data" - @test_logs (:info, r"Incomplete message detected") chunks, spillover=extract_chunks( + @test_logs (:info, r"Incomplete message detected") chunks, + spillover=extract_chunks( OpenAIStream(), blob_with_spillover; verbose = true) chunks, spillover = extract_chunks( OpenAIStream(), blob_with_spillover; verbose = true) @test length(chunks) == 1 @test chunks[1].event == :start @test chunks[1].json == JSON3.read("{\"key\": \"value\"}") - @test spillover == "{\"partial\": \"data" + @test spillover == "event: continue\ndata: {\"partial\": \"data" # Test with incoming spillover incoming_spillover = spillover blob_after_spillover = "\"}\n\nevent: end\ndata: {\"status\": \"complete\"}\n\n" - chunks, spillover = extract_chunks( + chunks, + spillover = extract_chunks( OpenAIStream(), blob_after_spillover; spillover = incoming_spillover) @test length(chunks) == 2 @test chunks[1].json == JSON3.read("{\"partial\": \"data\"}") @@ -72,12 +74,12 @@ end @test chunks[2].json == JSON3.read("{\"status\": \"complete\"}") @test spillover == "" - # Test with multiple data fields per event + # Test with multiple data fields per event - SSE spec compliant (joined with newlines) multi_data_blob = "event: multi\ndata: line1\ndata: line2\n\n" chunks, spillover = extract_chunks(OpenAIStream(), multi_data_blob) @test length(chunks) == 1 @test chunks[1].event == :multi - @test chunks[1].data == "line1line2" + @test 
chunks[1].data == "line1\nline2" # Test with non-JSON data non_json_blob = "event: text\ndata: This is plain text\n\n" @@ -145,7 +147,7 @@ end @test chunks[3].data == "[DONE]" @test spillover == "" - # Test case for s3: Simple data chunks + # Test case for s3: Simple data chunks - SSE spec compliant (joined with newlines) s3 = """data: a data: b data: c @@ -155,7 +157,7 @@ end """ chunks, spillover = extract_chunks(OpenAIStream(), s3) @test length(chunks) == 2 - @test chunks[1].data == "abc" + @test chunks[1].data == "a\nb\nc" @test chunks[2].data == "[DONE]" @test spillover == "" @@ -180,6 +182,139 @@ end @test length(chunks) == 2 @test chunks[2].data == "[DONE]" @test final_spillover == "" + + # Test with real Anthropic LLM response streams captured from test_data_clip_issue.jl + # This tests SSE spec compliance with actual data that includes "data:" patterns in the content + real_anthropic_blob = """event: message_start +data: {"type":"message_start","message":{"id":"msg_01Kf4tf7utCTCiPTBYteMCUS","type":"message","role":"assistant","model":"claude-3-5-haiku-20241022","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":32,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":2,"service_tier":"standard"}}} + +event: content_block_start +data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + +event: ping +data: {"type": "ping"} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Here you"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" go:\\n\\ndata"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":": [1], data"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":": [2], data"}} + +event: content_block_delta +data: 
{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":": [3], data"}} + +event: message_stop +data: {"type":"message_stop"} + +""" + chunks, spillover = extract_chunks(AnthropicStream(), real_anthropic_blob) + @test length(chunks) == 9 + @test spillover == "" + + # Test that SSE spec compliance correctly handles "data:" patterns in content + content_deltas = filter(chunk -> chunk.event == :content_block_delta, chunks) + @test length(content_deltas) == 5 + + # Verify that raw data contains "data" patterns exactly as sent by LLM + # Note: Looking for "data" in JSON content, not the SSE "data:" field + data_containing_chunks = filter( + chunk -> contains(chunk.data, "\"text\":\"") && contains(chunk.data, "data"), + content_deltas) + @test length(data_containing_chunks) == 4 # chunks 5,6,7,8 contain "data" in the text field + + # Test specific content extraction + @test chunks[1].event == :message_start + @test chunks[2].event == :content_block_start + @test chunks[3].event == :ping + @test chunks[9].event == :message_stop + + # Test that content contains the exact "data:" patterns as generated by LLM + text_delta_chunk = chunks[6] # chunk with ": [1], data" + @test text_delta_chunk.event == :content_block_delta + @test contains(text_delta_chunk.data, "data") + parsed_json = text_delta_chunk.json + @test parsed_json.delta.text == ": [1], data" +end + +@testset "SSE spec compliance fixes" begin + # Test 1: BOM handling - UTF-8 BOM should be stripped from field names + bom_blob = "\ufeffdata: message with BOM\nevent: test_event\n\n" + chunks, spillover = extract_chunks(OpenAIStream(), bom_blob) + @test length(chunks) == 1 + @test chunks[1].data == "message with BOM" + @test chunks[1].event == :test_event + @test spillover == "" + + # Test 2: BOM in field value should be preserved + bom_value_blob = "data: \ufeffmessage with BOM in value\n\n" + chunks, spillover = extract_chunks(OpenAIStream(), bom_value_blob) + @test length(chunks) == 1 + @test 
chunks[1].data == "\ufeffmessage with BOM in value" + + # Test 3: Empty data fields should create proper empty strings (no artifacts) + empty_data_blob = "data:\ndata: \ndata: \nevent: test\n\n" + chunks, spillover = extract_chunks(OpenAIStream(), empty_data_blob) + @test length(chunks) == 1 + @test chunks[1].data == "\n\n " # Three data fields: empty, space, two spaces + @test chunks[1].event == :test + + # Test 4: Multiple empty data fields + multi_empty_blob = "data:\ndata:\ndata:\n\n" + chunks, spillover = extract_chunks(OpenAIStream(), multi_empty_blob) + @test length(chunks) == 1 + @test chunks[1].data == "\n\n" # Three empty data fields joined with newlines + + # Test 5: Empty event field should be ignored + empty_event_blob = "data: test message\nevent:\nevent: \n\n" + chunks, spillover = extract_chunks(OpenAIStream(), empty_event_blob) + @test length(chunks) == 1 + @test chunks[1].data == "test message" + @test chunks[1].event === nothing # Empty event fields should be ignored + + # Test 6: Error handling - malformed lines should be handled gracefully + # The error handling prevents crashes but may not always generate warnings + malformed_blob = "data: valid message\n\x00invalid: line with null\ndata: another valid\n\n" + chunks, spillover = extract_chunks(OpenAIStream(), malformed_blob; verbose = false) + @test length(chunks) == 1 + @test chunks[1].data == "valid message\nanother valid" + + # Test 7: Invalid UTF-8 sequences should be handled gracefully + invalid_utf8_blob = "data: valid\n\xff\xfe: invalid utf8\ndata: also valid\n\n" + chunks, spillover = extract_chunks(OpenAIStream(), invalid_utf8_blob; verbose = false) # No warnings for silent test + @test length(chunks) == 1 + @test chunks[1].data == "valid\nalso valid" + + # Test 8: Unicode field names with BOM + unicode_bom_blob = "\ufeff测试: unicode field\ndata: test data\n\n" + chunks, spillover = extract_chunks(OpenAIStream(), unicode_bom_blob) + @test length(chunks) == 1 + @test chunks[1].data == 
"test data" # Unicode field name should be ignored (not 'data') + + # Test 9: Mixed valid/invalid lines + mixed_blob = """data: line1 + invalid line without colon + event: test_event + data: line2 + : this is a comment + data: line3 + + """ + chunks, spillover = extract_chunks(OpenAIStream(), mixed_blob) + @test length(chunks) == 1 + @test chunks[1].data == "line1\nline2\nline3" + @test chunks[1].event == :test_event + + # Test 10: Edge case - field name that becomes empty after BOM removal + edge_bom_blob = "\ufeff: value after empty field\ndata: valid data\n\n" + chunks, spillover = extract_chunks(OpenAIStream(), edge_bom_blob) + @test length(chunks) == 1 + @test chunks[1].data == "valid data" # BOM+colon line should be treated as comment end @testset "print_content" begin @@ -253,6 +388,13 @@ end @test_throws Exception handle_error_message(error_chunk, throw_on_error = true) end +@testset "extract_content" begin + # Test unimplemented flavor throws error + struct TestFlavor <: AbstractStreamFlavor end + test_chunk = StreamChunk(nothing, "test data", nothing) + @test_throws ArgumentError extract_content(TestFlavor(), test_chunk) +end + ## Not working yet!! # @testset "streamed_request!" 
begin # # Setup mock server diff --git a/test/test_data_clip_issue2.jl b/test/test_data_clip_issue2.jl new file mode 100644 index 0000000..74a0e01 --- /dev/null +++ b/test/test_data_clip_issue2.jl @@ -0,0 +1,83 @@ +using Test +using HTTP +using JSON3 +using StreamCallbacks + +@testset "Anthropic Integration Test" begin + # Skip if no API key is available + api_key = get(ENV, "ANTHROPIC_API_KEY", "") + if isempty(api_key) + @test_skip "Skipping Anthropic integration test - no API key found" + return + end + + # Prepare target and auth + url = "https://api.anthropic.com/v1/messages" + headers = [ + "anthropic-version" => "2023-06-01", + "x-api-key" => api_key + ] + + # Test streaming callback + cb = StreamCallback(; out = nothing, flavor = AnthropicStream()) + + messages = [Dict("role" => "user", "content" => "Write me data: [1], data: [2], data: [3] ... do it 10 times. All into separate lines.")] + payload = IOBuffer() + JSON3.write(payload, + (; stream = true, messages, model = "claude-3-5-haiku-latest", max_tokens = 2048)) + + # Send the request + resp = streamed_request!(cb, url, headers, payload) + + # Test response structure - it's a NamedTuple with HTTP.Response fields + @test resp.status == 200 + @test !isempty(cb.chunks) + + # Build response body + response_body = StreamCallbacks.build_response_body(AnthropicStream(), cb) + @test !isnothing(response_body) + @test haskey(response_body, :content) + @test length(response_body[:content]) >= 1 + @test response_body[:content][1][:type] == "text" + + # Extract the generated text + generated_text = response_body[:content][1][:text] + @test !isempty(generated_text) + + # Test that the response contains exactly 10 "data: " patterns + # Count occurrences of "data: [" followed by a digit and "]" + data_pattern_count = length(collect(eachmatch(r"data: \[\d+\]", generated_text))) + @test data_pattern_count == 10 # Should be exactly 10 + + # Alternative: count simple "data: " occurrences + data_prefix_count = 
length(collect(eachmatch(r"data: ", generated_text))) + @test data_prefix_count >= 10 # Should be at least 10 + + # Test that chunks contain valid JSON + json_chunks = filter(chunk -> !isnothing(chunk.json), cb.chunks) + @test !isempty(json_chunks) + + # Test that we have message_start and message_stop events + events = [chunk.event for chunk in cb.chunks if !isnothing(chunk.event)] + @test :message_start in events + @test :message_stop in events + + # Test content extraction from chunks + content_chunks = String[] + for chunk in cb.chunks + content = extract_content(AnthropicStream(), chunk) + if !isnothing(content) && !isempty(content) + push!(content_chunks, content) + end + end + @test !isempty(content_chunks) + + # Verify that concatenating content chunks gives us the full text + reconstructed_text = join(content_chunks, "") + @test reconstructed_text == generated_text + + # Debug output to see what we actually got + println("Generated text: ", repr(generated_text)) + println("Data pattern count: ", data_pattern_count) + println("Data prefix count: ", data_prefix_count) +end \ No newline at end of file