diff --git a/Project.toml b/Project.toml index dbd9cd9..78d26a4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,13 @@ name = "StreamCallbacks" uuid = "c1b9e933-98a0-46fc-8ea7-3b58b195fb0a" +version = "0.7.0" authors = ["J S <49557684+svilupp@users.noreply.github.com> and contributors"] -version = "0.6.2" [deps] HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +PromptingTools = "670122d1-24a8-4d70-bfce-740807c42192" [compat] Aqua = "0.8" diff --git a/src/StreamCallbacks.jl b/src/StreamCallbacks.jl index 8230bb7..bdd2f87 100644 --- a/src/StreamCallbacks.jl +++ b/src/StreamCallbacks.jl @@ -2,6 +2,10 @@ module StreamCallbacks using HTTP, JSON3 using PrecompileTools +import PromptingTools +using PromptingTools: AbstractStreamCallback, AbstractPromptSchema +using PromptingTools: configure_callback!, streamed_request! + export StreamCallback, StreamChunk, OpenAIStream, AnthropicStream, OllamaStream, streamed_request! diff --git a/src/interface.jl b/src/interface.jl index d150772..33fb111 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,3 +1,4 @@ + # This file defines the core interface for the StreamCallbacks.jl package # # The goal is to enable custom callbacks for streaming LLM APIs, @@ -11,8 +12,6 @@ # which is a simple struct that holds the individual "chunks" (StreamChunk) # and presents the logic necessary for processing # -# Top-level interface that wraps the HTTP.POST request and handles the streaming -function streamed_request! end # It composes of the following interface functions # Extract the chunks from the received SSE blob. Returns a list of `StreamChunk` # At the moment, it's assumed to be generic enough for ANY API provider (TBU). @@ -42,21 +41,7 @@ It must have the following fields: - `json`: The JSON object or `nothing` if the chunk does not contain JSON. 
""" abstract type AbstractStreamChunk end - -""" - AbstractStreamCallback - -Abstract type for the stream callback. - -It must have the following fields: -- `out`: The output stream, eg, `stdout` or a pipe. -- `flavor`: The stream flavor which might or might not differ between different providers, eg, `OpenAIStream` or `AnthropicStream`. -- `chunks`: The list of received `AbstractStreamChunk` chunks. -- `verbose`: Whether to print verbose information. -- `throw_on_error`: Whether to throw an error if an error message is detected in the streaming response. -- `kwargs`: Any custom keyword arguments required for your use case. -""" -abstract type AbstractStreamCallback end +abstract type AbstractHTTPStreamCallback <: AbstractStreamCallback end """ AbstractStreamFlavor @@ -156,7 +141,7 @@ msg = aigenerate("Count from 1 to 10."; streamcallback) Note: If you provide a `StreamCallback` object to `aigenerate`, we will configure it and necessary `api_kwargs` via `configure_callback!` unless you specify the `flavor` field. If you provide a `StreamCallback` with a specific `flavor`, we leave all configuration to the user (eg, you need to provide the correct `api_kwargs`). """ -@kwdef mutable struct StreamCallback{T1 <: Any} <: AbstractStreamCallback +@kwdef mutable struct StreamCallback{T1 <: Any} <: AbstractHTTPStreamCallback out::T1 = stdout flavor::Union{AbstractStreamFlavor, Nothing} = nothing chunks::Vector{<:StreamChunk} = StreamChunk[] @@ -168,7 +153,8 @@ function Base.show(io::IO, cb::StreamCallback) print(io, "StreamCallback(out=$(cb.out), flavor=$(cb.flavor), chunks=$(length(cb.chunks)) items, $(cb.verbose ? "verbose" : "silent"), $(cb.throw_on_error ? 
"throw_on_error" : "no_throw"))") end -Base.empty!(cb::AbstractStreamCallback) = empty!(cb.chunks) -Base.push!(cb::AbstractStreamCallback, chunk::StreamChunk) = push!(cb.chunks, chunk) -Base.isempty(cb::AbstractStreamCallback) = isempty(cb.chunks) -Base.length(cb::AbstractStreamCallback) = length(cb.chunks) + +Base.empty!(cb::AbstractHTTPStreamCallback) = empty!(cb.chunks) +Base.push!(cb::AbstractHTTPStreamCallback, chunk::StreamChunk) = push!(cb.chunks, chunk) +Base.isempty(cb::AbstractHTTPStreamCallback) = isempty(cb.chunks) +Base.length(cb::AbstractHTTPStreamCallback) = length(cb.chunks) \ No newline at end of file diff --git a/src/precompilation.jl b/src/precompilation.jl index bb74472..6bd6a3c 100644 --- a/src/precompilation.jl +++ b/src/precompilation.jl @@ -31,3 +31,8 @@ flavor = OllamaStream() is_done(flavor, example_chunk) extract_content(flavor, example_chunk) build_response_body(flavor, cb) + +using PromptingTools: OpenAISchema +## Streaming configuration +cb = StreamCallback() +configure_callback!(cb, OpenAISchema()) \ No newline at end of file diff --git a/src/shared_methods.jl b/src/shared_methods.jl index 3423c41..ec201a3 100644 --- a/src/shared_methods.jl +++ b/src/shared_methods.jl @@ -1,6 +1,6 @@ +using PromptingTools: AbstractOpenAISchema, AbstractAnthropicSchema, AbstractOllamaSchema, AbstractOllamaManagedSchema # ## Default methods - """ extract_chunks(flavor::AbstractStreamFlavor, blob::AbstractString; spillover::AbstractString = "", verbose::Bool = false, kwargs...) @@ -152,7 +152,7 @@ end print_content(::Nothing, ::AbstractString; kwargs...) = nothing """ - callback(cb::AbstractStreamCallback, chunk::AbstractStreamChunk; kwargs...) + callback(cb::AbstractHTTPStreamCallback, chunk::AbstractStreamChunk; kwargs...) Process the chunk to be printed and print it. It's a wrapper for two operations: - extract the content from the chunk using `extract_content` @@ -198,7 +198,7 @@ Handles error messages from the streaming response. 
end """ - streamed_request!(cb::AbstractStreamCallback, url, headers, input; kwargs...) + streamed_request!(cb::AbstractHTTPStreamCallback, url, headers, input; kwargs...) End-to-end wrapper for POST streaming requests. In-place modification of the callback object (`cb.chunks`) with the results of the request being returned. @@ -213,7 +213,7 @@ Returns the response object. - `input`: A buffer with the request body. - `kwargs`: Additional keyword arguments. """ -function streamed_request!(cb::AbstractStreamCallback, url, headers, input; kwargs...) +function PromptingTools.streamed_request!(cb::AbstractHTTPStreamCallback, url, headers, input; kwargs...) verbose = get(kwargs, :verbose, false) || cb.verbose resp = HTTP.open("POST", url, headers; kwargs...) do stream write(stream, String(take!(input))) @@ -274,3 +274,42 @@ function streamed_request!(cb::AbstractStreamCallback, url, headers, input; kwar return resp end + +""" + configure_callback!(cb::AbstractHTTPStreamCallback, schema::AbstractPromptSchema; + api_kwargs...) + +Configures the callback `cb` for streaming with a given prompt schema. + +If no `cb.flavor` is provided, adjusts the `flavor` and the provided `api_kwargs` as necessary. +Eg, for most schemas, we add kwargs like `stream = true` to the `api_kwargs`. + +If `cb.flavor` is provided, both `callback` and `api_kwargs` are left unchanged! You need to configure them yourself! +""" +function PromptingTools.configure_callback!(cb::AbstractHTTPStreamCallback, schema::AbstractPromptSchema; + api_kwargs...) 
+    ## Check if we are in passthrough mode or if we should configure the callback +    if isnothing(cb.flavor) +        if schema isa AbstractOpenAISchema +            ## Enable streaming for all OpenAI-compatible APIs +            api_kwargs = (; +                api_kwargs..., stream = true, stream_options = (; include_usage = true)) +            flavor = OpenAIStream() +        elseif schema isa Union{AbstractAnthropicSchema, AbstractOllamaSchema} +            api_kwargs = (; api_kwargs..., stream = true) +            flavor = schema isa AbstractOllamaSchema ? OllamaStream() : AnthropicStream() +        elseif schema isa AbstractOllamaManagedSchema +            throw(ErrorException("OllamaManagedSchema is not supported for streaming. Use OllamaSchema instead.")) +        else +            error("Unsupported schema type: $(typeof(schema)). Currently supported: OpenAISchema, AnthropicSchema and OllamaSchema.") +        end +        cb.flavor = flavor +    end +    return cb, api_kwargs +end +# method to build a callback from IO or Channel +function PromptingTools.configure_callback!( +    output_stream::Union{IO, Channel}, schema::AbstractPromptSchema) +    cb = StreamCallback(out = output_stream) +    return configure_callback!(cb, schema) +end diff --git a/test/runtests.jl b/test/runtests.jl index 2511a8f..57cdfa6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,4 +16,5 @@ using StreamCallbacks: AbstractStreamFlavor, OpenAIStream, AnthropicStream, Stre include("stream_openai.jl") include("stream_anthropic.jl") include("stream_ollama.jl") + include("streaming.jl") end diff --git a/test/streaming.jl b/test/streaming.jl new file mode 100644 index 0000000..34a3d9e --- /dev/null +++ b/test/streaming.jl @@ -0,0 +1,29 @@ +using StreamCallbacks: StreamCallback, StreamChunk, OpenAIStream, AnthropicStream, +    configure_callback!, OllamaStream +using PromptingTools: OpenAISchema, AnthropicSchema, GoogleSchema, OllamaSchema, OllamaManagedSchema + +@testset "configure_callback!" begin +    # Test configure_callback!
method + cb, api_kwargs = configure_callback!(StreamCallback(), OpenAISchema()) + @test cb.flavor isa OpenAIStream + @test api_kwargs[:stream] == true + @test api_kwargs[:stream_options] == (include_usage = true,) + + cb, api_kwargs = configure_callback!(StreamCallback(), AnthropicSchema()) + @test cb.flavor isa AnthropicStream + @test api_kwargs[:stream] == true + + cb, api_kwargs = configure_callback!(StreamCallback(), OllamaSchema()) + @test cb.flavor isa OllamaStream + @test api_kwargs[:stream] == true + + # Test error for unsupported schema + @test_throws ErrorException configure_callback!(StreamCallback(), GoogleSchema()) + @test_throws ErrorException configure_callback!(StreamCallback(), OllamaManagedSchema()) + + # Test configure_callback! with output stream + cb, _ = configure_callback!(IOBuffer(), OpenAISchema()) + @test cb isa StreamCallback + @test cb.out isa IOBuffer + @test cb.flavor isa OpenAIStream +end