Skip to content
25 changes: 10 additions & 15 deletions lib/ruby_llm/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,20 @@ def initialize(model: nil, provider: nil, assume_model_exists: false, context: n
@callbacks = Hash.new { |callbacks, name| callbacks[name] = [] }
end

# Adds a :user message whose content is built from `message` and `with`, then calls
# `complete`, forwarding the caller's block.
# NOTE(review): two `def ask` signatures appear back to back here, each followed by its own
# `add_message` call. This reads like the before/after lines of a diff rather than a nested
# definition; the second pair adds a `cache_point:` keyword (default false) and passes it to
# `add_message`.
def ask(message = nil, with: nil, &)
add_message role: :user, content: build_content(message, with)
def ask(message = nil, with: nil, cache_point: false, &)
add_message role: :user, content: build_content(message, with), cache_point: cache_point
complete(&)
end

alias say ask

def with_instructions(instructions, append: false, replace: nil)
def with_instructions(instructions, append: false, replace: nil, cache_point: false)
append ||= (replace == false) unless replace.nil?

if append
append_system_instruction(instructions)
append_system_instruction(instructions, cache_point: cache_point)
else
replace_system_instruction(instructions)
replace_system_instruction(instructions, cache_point: cache_point)
end

self
Expand Down Expand Up @@ -370,21 +370,16 @@ def content_like?(object)
object.is_a?(Content) || object.is_a?(Content::Raw)
end

# Appends a new :system message holding `instructions` after the existing system messages
# and rebuilds @messages so that all system messages come before the non-system ones.
# NOTE(review): two `def` lines and two `system_messages <<` lines are shown together. They
# look like the before/after lines of a diff; the later pair accepts `cache_point:` (default
# false) and passes it to Message.new.
def append_system_instruction(instructions)
def append_system_instruction(instructions, cache_point: false)
system_messages, non_system_messages = @messages.partition { |msg| msg.role == :system }
system_messages << Message.new(role: :system, content: instructions)
system_messages << Message.new(role: :system, content: instructions, cache_point: cache_point)
@messages = system_messages + non_system_messages
end

# Leaves @messages with exactly one :system message holding `instructions`, placed before
# all non-system messages.
# NOTE(review): this span interleaves what look like the before/after lines of a diff.
# One variant keeps the first existing system message and overwrites its content (creating a
# message only when none exists); the other discards the existing system messages and always
# builds a fresh Message, passing `cache_point:` through to it.
def replace_system_instruction(instructions)
system_messages, non_system_messages = @messages.partition { |msg| msg.role == :system }
def replace_system_instruction(instructions, cache_point: false)
_, non_system_messages = @messages.partition { |msg| msg.role == :system }

if system_messages.empty?
system_messages = [Message.new(role: :system, content: instructions)]
else
system_messages.first.content = instructions
system_messages = [system_messages.first]
end
system_messages = [Message.new(role: :system, content: instructions, cache_point: cache_point)]

@messages = system_messages + non_system_messages
end
Expand Down
7 changes: 5 additions & 2 deletions lib/ruby_llm/message.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ module RubyLLM
class Message
ROLES = %i[system user assistant tool].freeze

attr_reader :role, :model_id, :tool_calls, :tool_call_id, :raw, :thinking, :tokens
attr_reader :role, :model_id, :tool_calls, :tool_call_id, :raw, :thinking, :tokens, :cache_point
alias cache_point? cache_point
attr_writer :content

def initialize(options = {})
Expand All @@ -24,6 +25,7 @@ def initialize(options = {})
)
@raw = options[:raw]
@thinking = options[:thinking]
@cache_point = options.fetch(:cache_point, false)

ensure_valid_role
end
Expand Down Expand Up @@ -92,7 +94,8 @@ def to_h
tool_calls: tool_calls,
tool_call_id: tool_call_id,
thinking: thinking&.text,
thinking_signature: thinking&.signature
thinking_signature: thinking&.signature,
cache_point: @cache_point || nil
}.merge(tokens ? tokens.to_h : {}).compact
end

Expand Down
2 changes: 1 addition & 1 deletion lib/ruby_llm/provider.rb
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def maybe_normalize_temperature(temperature, _model)

# Posts `payload` to `completion_url` on `connection`, applies `additional_headers` to the
# request when any are given, and returns the result of `parse_completion_response`.
# NOTE(review): the two `req.headers` lines look like the before/after lines of a diff.
# They differ in precedence on conflicting keys: `additional_headers.merge(req.headers)` keeps
# the request's existing values, whereas `req.headers.merge!(additional_headers)` presumably
# lets the additional headers win — confirm against the headers object's merge! semantics.
def sync_response(connection, payload, additional_headers = {})
response = connection.post completion_url, payload do |req|
req.headers = additional_headers.merge(req.headers) unless additional_headers.empty?
req.headers.merge!(additional_headers) unless additional_headers.empty?
end
parse_completion_response response
end
Expand Down
7 changes: 7 additions & 0 deletions lib/ruby_llm/providers/anthropic.rb
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ def headers
}
end

# Runs a completion, opting in to Anthropic prompt caching when any message is flagged
# as a cache point.
#
# The beta flag is combined with any `anthropic-beta` value the caller already supplied
# (comma-separated, without duplicates) instead of overwriting it, so other beta features
# requested by the caller stay enabled. The caller's headers hash is never mutated.
#
# @param messages [Array] messages responding to `cache_point?`
# @param headers [Hash] extra request headers
# @return whatever the parent implementation returns
def complete(messages, headers: {}, **kwargs, &block)
  if messages.any?(&:cache_point?)
    headers = headers.merge('anthropic-beta' => 'prompt-caching-2024-07-31') do |_name, existing, caching|
      flags = existing.to_s.split(',').map(&:strip).reject(&:empty?)
      flags.include?(caching) ? existing : (flags << caching).join(',')
    end
  end

  # Explicit arguments are required because `headers` was reassigned above and the
  # modified value must reach the parent implementation.
  super(messages, headers: headers, **kwargs, &block) # rubocop:disable Style/SuperArguments
end

class << self
def capabilities
Anthropic::Capabilities
Expand Down
24 changes: 19 additions & 5 deletions lib/ruby_llm/providers/anthropic/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,13 @@ def build_system_content(system_messages)
system_messages.flat_map do |msg|
content = msg.content

if content.is_a?(RubyLLM::Content::Raw)
content.value
else
Media.format_content(content)
end
blocks = if content.is_a?(RubyLLM::Content::Raw)
Array(content.value)
else
Array(Media.format_content(content))
end

msg.cache_point? ? inject_cache_control(blocks) : blocks
end
end

Expand Down Expand Up @@ -155,6 +157,7 @@ def format_basic_message_with_thinking(msg, thinking_enabled)
end

append_formatted_content(content_blocks, msg.content)
inject_cache_control(content_blocks) if msg.cache_point?

{
role: convert_role(msg.role),
Expand Down Expand Up @@ -224,6 +227,17 @@ def append_formatted_content(content_blocks, content)
end
end

# Marks the final content block as an Anthropic prompt-cache breakpoint.
#
# The array is modified in place: its last element is replaced by a copy that carries
# `cache_control: { type: 'ephemeral' }`. The same array is also returned.
#
# @param blocks [Array] formatted content blocks
# @return [Array] the same `blocks` array
def inject_cache_control(blocks)
  last = blocks.last
  # Nothing to mark when the list is empty or the trailing element is not a Hash block;
  # calling #merge on such an element would raise NoMethodError.
  return blocks unless last.is_a?(Hash)
  # Don't duplicate if already present (e.g. Content::Raw with cache_control), whether the
  # raw block was written with symbol or string keys.
  return blocks if last[:cache_control] || last['cache_control']

  blocks[-1] = last.merge(cache_control: { type: 'ephemeral' })
  blocks
end

def convert_role(role)
case role
when :tool, :user then 'user'
Expand Down
32 changes: 20 additions & 12 deletions lib/ruby_llm/providers/bedrock/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,26 @@ def render_message_content(msg)
text_and_media_blocks = Media.render_content(msg.content, used_document_names: @used_document_names)
blocks.concat(text_and_media_blocks) if text_and_media_blocks

if msg.tool_call?
msg.tool_calls.each_value do |tool_call|
blocks << {
toolUse: {
toolUseId: tool_call.id,
name: tool_call.name,
input: tool_call.arguments
}
}
end
end
append_tool_use_blocks(blocks, msg)
blocks << { cachePoint: { type: 'default' } } if msg.cache_point?

blocks
end

# Appends one `toolUse` block to `blocks` for every tool call carried by `msg`.
# Does nothing (and returns nil) when the message has no tool calls.
def append_tool_use_blocks(blocks, msg)
  if msg.tool_call?
    msg.tool_calls.each_value do |call|
      tool_use = { toolUseId: call.id, name: call.name, input: call.arguments }
      blocks << { toolUse: tool_use }
    end
  end
end

def render_raw_content(content)
value = content.value
value.is_a?(Array) ? value : [value]
Expand Down Expand Up @@ -211,7 +216,10 @@ def render_role(role)
end

# Renders the given system messages into one flat list of content blocks using
# Media.render_content (passing @used_document_names along).
# NOTE(review): the single-line `flat_map { ... }` and the multi-line `flat_map do ... end`
# look like the before/after lines of a diff, not two sequential steps. The multi-line form
# additionally appends a `{ cachePoint: { type: 'default' } }` block after the blocks of any
# message for which `cache_point?` is truthy.
def render_system(messages)
messages.flat_map { |msg| Media.render_content(msg.content, used_document_names: @used_document_names) }
messages.flat_map do |msg|
blocks = Media.render_content(msg.content, used_document_names: @used_document_names)
msg.cache_point? ? blocks + [{ cachePoint: { type: 'default' } }] : blocks
end
end

def render_inference_config(_model, temperature)
Expand Down
18 changes: 17 additions & 1 deletion lib/ruby_llm/providers/openrouter/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,31 @@ def thinking_tokens(usage)

# Converts each message into a request hash with role, content, tool_calls and tool_call_id
# (nil entries removed by `compact`), merged with the output of `format_thinking`.
# NOTE(review): two `content:` entries appear in the hash literal; they look like the
# before/after lines of a diff. The later one uses the local `content`, which has been passed
# through `inject_cache_control` when the message's `cache_point?` is truthy.
def format_messages(messages)
messages.map do |msg|
content = format_content(msg.content)
content = inject_cache_control(content) if msg.cache_point?

{
role: format_role(msg.role),
content: format_content(msg.content),
content: content,
tool_calls: OpenAI::Tools.format_tool_calls(msg.tool_calls),
tool_call_id: msg.tool_call_id
}.compact.merge(format_thinking(msg))
end
end

# Attaches an Anthropic-style `cache_control` marker to the last content block.
# For other models the marker will be ignored by the respective provider.
#
# A plain string is wrapped into a text block first so the marker can be attached.
# An Array is copied (shallowly) before its last element is replaced, so the caller's
# array is not modified.
#
# @param content [String, Array, nil] formatted message content
# @return [Array, String, nil] content blocks with the marker on the last one, or the
#   original content when there is nothing to mark
def inject_cache_control(content)
  # nil or empty content has no block to carry the marker; wrapping it would create a
  # text block with nil/empty text.
  return content if content.nil? || (content.respond_to?(:empty?) && content.empty?)

  blocks = content.is_a?(Array) ? content.dup : [{ type: 'text', text: content }]
  last = blocks.last
  # A trailing non-Hash element cannot take the marker (#merge would raise).
  return blocks unless last.is_a?(Hash)
  # Already marked, with either symbol or string keys.
  return blocks if last[:cache_control] || last['cache_control']

  blocks[-1] = last.merge(cache_control: { type: 'ephemeral' })
  blocks
end

# Delegates content formatting to the OpenAI media formatter.
def format_content(content) = OpenAI::Media.format_content(content)
Expand Down
2 changes: 1 addition & 1 deletion lib/ruby_llm/streaming.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def stream_response(connection, payload, additional_headers = {}, &block)
accumulator = StreamAccumulator.new

response = connection.post stream_url, payload do |req|
req.headers = additional_headers.merge(req.headers) unless additional_headers.empty?
req.headers.merge!(additional_headers) unless additional_headers.empty?
if faraday_1?
req.options[:on_data] = handle_stream do |chunk|
accumulator.add chunk
Expand Down
58 changes: 58 additions & 0 deletions spec/ruby_llm/chat_cache_point_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# frozen_string_literal: true

require 'spec_helper'

# Verifies that Chat#with_instructions and Chat#ask forward the `cache_point:` flag to the
# messages they create.
RSpec.describe RubyLLM::Chat do
  include_context 'with configured RubyLLM'

  describe 'cache_point forwarding' do
    let(:chat) { RubyLLM.chat }

    # Including groups must define `action`, a lambda that accepts an options hash and
    # performs the call under test. `message_finder` locates the resulting message on a chat.
    shared_examples 'a method that supports cache_point' do |message_finder|
      it 'sets cache_point? true when cache_point: true' do
        action.call(cache_point: true)
        message = message_finder.call(chat)
        expect(message).not_to be_nil
        expect(message.cache_point?).to be true
      end

      it 'sets cache_point? false when cache_point is omitted' do
        action.call
        message = message_finder.call(chat)
        expect(message).not_to be_nil
        expect(message.cache_point?).to be false
      end
    end

    describe '#with_instructions' do
      let(:action) { ->(opts = {}) { chat.with_instructions('Be helpful', **opts) } }

      it_behaves_like 'a method that supports cache_point', ->(c) { c.messages.find { |m| m.role == :system } }

      it 'sets cache_point? true on appended message only' do
        chat.with_instructions('First instruction')
        chat.with_instructions('Second instruction', append: true, cache_point: true)
        system_msgs = chat.messages.select { |m| m.role == :system }
        expect(system_msgs.last.cache_point?).to be true
        expect(system_msgs.first.cache_point?).to be false
      end

      it 'preserves cache_point: true when replacing' do
        chat.with_instructions('Old instruction', cache_point: false)
        chat.with_instructions('New instruction', replace: true, cache_point: true)
        system_msgs = chat.messages.select { |m| m.role == :system }
        expect(system_msgs.size).to eq(1)
        expect(system_msgs.first.cache_point?).to be true
      end
    end

    describe '#ask' do
      # Stub the completion so the example only exercises message creation.
      before { allow(chat).to receive(:complete) }

      let(:action) { ->(opts = {}) { chat.ask('Hello', **opts) } }

      it_behaves_like 'a method that supports cache_point', ->(c) { c.messages.find { |m| m.role == :user } }
    end
  end
end
24 changes: 24 additions & 0 deletions spec/ruby_llm/message_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,30 @@
)
end

describe '#cache_point?' do
  # Builds a user message, forwarding any extra constructor options.
  def build_message(**extra)
    described_class.new(role: :user, content: 'hello', **extra)
  end

  it 'returns false by default' do
    expect(build_message.cache_point?).to be false
  end

  it 'returns true when constructed with cache_point: true' do
    expect(build_message(cache_point: true).cache_point?).to be true
  end
end

describe '#to_h' do
  # Serializes a user message built with the given extra constructor options.
  def serialized(**extra)
    described_class.new(role: :user, content: 'hello', **extra).to_h
  end

  it 'omits cache_point key when false' do
    expect(serialized).not_to have_key(:cache_point)
  end

  it 'includes cache_point: true when set' do
    expect(serialized(cache_point: true)[:cache_point]).to be true
  end
end

describe '#content' do
it 'normalizes nil content to empty string for assistant tool-call messages' do
tool_call = RubyLLM::ToolCall.new(id: 'call_1', name: 'weather', arguments: {})
Expand Down
Loading
Loading