Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
name = "StreamCallbacks"
uuid = "c1b9e933-98a0-46fc-8ea7-3b58b195fb0a"
version = "0.7.0"
authors = ["J S <49557684+svilupp@users.noreply.github.com> and contributors"]
version = "0.6.2"

[deps]
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
PromptingTools = "670122d1-24a8-4d70-bfce-740807c42192"

[compat]
Aqua = "0.8"
Expand Down
4 changes: 4 additions & 0 deletions src/StreamCallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ module StreamCallbacks

using HTTP, JSON3
using PrecompileTools
import PromptingTools
using PromptingTools: AbstractStreamCallback, AbstractPromptSchema
using PromptingTools: configure_callback!, streamed_request!


export StreamCallback, StreamChunk, OpenAIStream, AnthropicStream, OllamaStream,
streamed_request!
Expand Down
30 changes: 8 additions & 22 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

# This file defines the core interface for the StreamCallbacks.jl package
#
# The goal is to enable custom callbacks for streaming LLM APIs,
Expand All @@ -11,8 +12,6 @@
# which is a simple struct that holds the individual "chunks" (StreamChunk)
# and presents the logic necessary for processing
#
# Top-level interface that wraps the HTTP.POST request and handles the streaming
function streamed_request! end
# It composes of the following interface functions
# Extract the chunks from the received SSE blob. Returns a list of `StreamChunk`
# At the moment, it's assumed to be generic enough for ANY API provider (TBU).
Expand Down Expand Up @@ -42,21 +41,7 @@ It must have the following fields:
- `json`: The JSON object or `nothing` if the chunk does not contain JSON.
"""
abstract type AbstractStreamChunk end

"""
AbstractStreamCallback

Abstract type for the stream callback.

It must have the following fields:
- `out`: The output stream, eg, `stdout` or a pipe.
- `flavor`: The stream flavor which might or might not differ between different providers, eg, `OpenAIStream` or `AnthropicStream`.
- `chunks`: The list of received `AbstractStreamChunk` chunks.
- `verbose`: Whether to print verbose information.
- `throw_on_error`: Whether to throw an error if an error message is detected in the streaming response.
- `kwargs`: Any custom keyword arguments required for your use case.
"""
abstract type AbstractStreamCallback end
abstract type AbstractHTTPStreamCallback <: AbstractStreamCallback end

"""
AbstractStreamFlavor
Expand Down Expand Up @@ -156,7 +141,7 @@ msg = aigenerate("Count from 1 to 10."; streamcallback)
Note: If you provide a `StreamCallback` object to `aigenerate`, we will configure it and necessary `api_kwargs` via `configure_callback!` unless you specify the `flavor` field.
If you provide a `StreamCallback` with a specific `flavor`, we leave all configuration to the user (eg, you need to provide the correct `api_kwargs`).
"""
@kwdef mutable struct StreamCallback{T1 <: Any} <: AbstractStreamCallback
@kwdef mutable struct StreamCallback{T1 <: Any} <: AbstractHTTPStreamCallback
out::T1 = stdout
flavor::Union{AbstractStreamFlavor, Nothing} = nothing
chunks::Vector{<:StreamChunk} = StreamChunk[]
Expand All @@ -168,7 +153,8 @@ function Base.show(io::IO, cb::StreamCallback)
print(io,
"StreamCallback(out=$(cb.out), flavor=$(cb.flavor), chunks=$(length(cb.chunks)) items, $(cb.verbose ? "verbose" : "silent"), $(cb.throw_on_error ? "throw_on_error" : "no_throw"))")
end
Base.empty!(cb::AbstractStreamCallback) = empty!(cb.chunks)
Base.push!(cb::AbstractStreamCallback, chunk::StreamChunk) = push!(cb.chunks, chunk)
Base.isempty(cb::AbstractStreamCallback) = isempty(cb.chunks)
Base.length(cb::AbstractStreamCallback) = length(cb.chunks)

# Collection-like convenience methods: a callback behaves as a thin wrapper
# around its `chunks` vector, so we forward the common Base operations to it.
function Base.empty!(cb::AbstractHTTPStreamCallback)
    # Drop all accumulated chunks (eg, before reusing the callback for a new request).
    return empty!(cb.chunks)
end
function Base.push!(cb::AbstractHTTPStreamCallback, chunk::StreamChunk)
    # Record one received chunk.
    return push!(cb.chunks, chunk)
end
function Base.isempty(cb::AbstractHTTPStreamCallback)
    # True when no chunks have been received yet.
    return isempty(cb.chunks)
end
function Base.length(cb::AbstractHTTPStreamCallback)
    # Number of chunks received so far.
    return length(cb.chunks)
end
5 changes: 5 additions & 0 deletions src/precompilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,8 @@ flavor = OllamaStream()
is_done(flavor, example_chunk)
extract_content(flavor, example_chunk)
build_response_body(flavor, cb)

using PromptingTools: OpenAISchema
## Streaming configuration
# Exercise the callback-configuration path for the most common provider (OpenAI)
# so it is precompiled; result is discarded — only the compilation matters.
cb = StreamCallback()
configure_callback!(cb, OpenAISchema())
47 changes: 43 additions & 4 deletions src/shared_methods.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using PromptingTools: AbstractOpenAISchema, AbstractAnthropicSchema, AbstractOllamaSchema, AbstractOllamaManagedSchema

# ## Default methods

"""
extract_chunks(flavor::AbstractStreamFlavor, blob::AbstractString;
spillover::AbstractString = "", verbose::Bool = false, kwargs...)
Expand Down Expand Up @@ -152,7 +152,7 @@ end
print_content(::Nothing, ::AbstractString; kwargs...) = nothing

"""
callback(cb::AbstractStreamCallback, chunk::AbstractStreamChunk; kwargs...)
callback(cb::AbstractHTTPStreamCallback, chunk::AbstractStreamChunk; kwargs...)

Process the chunk to be printed and print it. It's a wrapper for two operations:
- extract the content from the chunk using `extract_content`
Expand Down Expand Up @@ -198,7 +198,7 @@ Handles error messages from the streaming response.
end

"""
streamed_request!(cb::AbstractStreamCallback, url, headers, input; kwargs...)
streamed_request!(cb::AbstractHTTPStreamCallback, url, headers, input; kwargs...)

End-to-end wrapper for POST streaming requests.
In-place modification of the callback object (`cb.chunks`) with the results of the request being returned.
Expand All @@ -213,7 +213,7 @@ Returns the response object.
- `input`: A buffer with the request body.
- `kwargs`: Additional keyword arguments.
"""
function streamed_request!(cb::AbstractStreamCallback, url, headers, input; kwargs...)
function PromptingTools.streamed_request!(cb::AbstractHTTPStreamCallback, url, headers, input; kwargs...)
verbose = get(kwargs, :verbose, false) || cb.verbose
resp = HTTP.open("POST", url, headers; kwargs...) do stream
write(stream, String(take!(input)))
Expand Down Expand Up @@ -274,3 +274,42 @@ function streamed_request!(cb::AbstractStreamCallback, url, headers, input; kwar

return resp
end

"""
    configure_callback!(cb::AbstractHTTPStreamCallback, schema::AbstractPromptSchema;
        api_kwargs...)

Configures the callback `cb` for streaming with a given prompt schema.

If no `cb.flavor` is provided, adjusts the `flavor` and the provided `api_kwargs` as necessary.
Eg, for most schemas, we add kwargs like `stream = true` to the `api_kwargs`.

If `cb.flavor` is provided, both `callback` and `api_kwargs` are left unchanged! You need to configure them yourself!

# Returns
A tuple `(cb, api_kwargs)` — the (possibly mutated) callback and the keyword
arguments to pass to the underlying API request.

# Throws
- `ErrorException` if the schema is `AbstractOllamaManagedSchema` (no streaming support)
  or an unrecognized schema type.
"""
function PromptingTools.configure_callback!(cb::AbstractHTTPStreamCallback, schema::AbstractPromptSchema;
        api_kwargs...)
    ## Check if we are in passthrough mode or if we should configure the callback
    ## (a user-provided flavor means the user also manages the api_kwargs themselves)
    if isnothing(cb.flavor)
        if schema isa AbstractOpenAISchema
            ## Enable streaming for all OpenAI-compatible APIs;
            ## `stream_options` requests token usage in the final chunk
            api_kwargs = (;
                api_kwargs..., stream = true, stream_options = (; include_usage = true))
            flavor = OpenAIStream()
        elseif schema isa Union{AbstractAnthropicSchema, AbstractOllamaSchema}
            api_kwargs = (; api_kwargs..., stream = true)
            flavor = schema isa AbstractOllamaSchema ? OllamaStream() : AnthropicStream()
        elseif schema isa AbstractOllamaManagedSchema
            throw(ErrorException("OllamaManagedSchema is not supported for streaming. Use OllamaSchema instead."))
        else
            ## Keep this list in sync with the branches above
            error("Unsupported schema type: $(typeof(schema)). Currently supported: OpenAISchema, AnthropicSchema and OllamaSchema.")
        end
        cb.flavor = flavor
    end
    return cb, api_kwargs
end
# Convenience method: accept a raw output sink (`IO` or `Channel`), wrap it in a
# default `StreamCallback`, and delegate to the main configuration method.
function PromptingTools.configure_callback!(
        output_stream::Union{IO, Channel}, schema::AbstractPromptSchema)
    wrapped = StreamCallback(out = output_stream)
    return configure_callback!(wrapped, schema)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ using StreamCallbacks: AbstractStreamFlavor, OpenAIStream, AnthropicStream, Stre
include("stream_openai.jl")
include("stream_anthropic.jl")
include("stream_ollama.jl")
include("streaming.jl")
end
29 changes: 29 additions & 0 deletions test/streaming.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using StreamCallbacks: StreamCallback, StreamChunk, OpenAIStream, AnthropicStream,
                       configure_callback!, OllamaStream
# NOTE: OllamaManagedSchema must be imported explicitly — it is referenced in the
# error-path test below and would otherwise raise UndefVarError instead of the
# expected ErrorException.
using PromptingTools: OpenAISchema, AnthropicSchema, GoogleSchema, OllamaSchema,
                      OllamaManagedSchema

@testset "configure_callback!" begin
    # Test configure_callback! method
    cb, api_kwargs = configure_callback!(StreamCallback(), OpenAISchema())
    @test cb.flavor isa OpenAIStream
    @test api_kwargs[:stream] == true
    @test api_kwargs[:stream_options] == (include_usage = true,)

    cb, api_kwargs = configure_callback!(StreamCallback(), AnthropicSchema())
    @test cb.flavor isa AnthropicStream
    @test api_kwargs[:stream] == true

    cb, api_kwargs = configure_callback!(StreamCallback(), OllamaSchema())
    @test cb.flavor isa OllamaStream
    @test api_kwargs[:stream] == true

    # Test error for unsupported schema
    @test_throws ErrorException configure_callback!(StreamCallback(), GoogleSchema())
    @test_throws ErrorException configure_callback!(StreamCallback(), OllamaManagedSchema())

    # Test configure_callback! with output stream
    cb, _ = configure_callback!(IOBuffer(), OpenAISchema())
    @test cb isa StreamCallback
    @test cb.out isa IOBuffer
    @test cb.flavor isa OpenAIStream
end
Loading