Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 11 additions & 73 deletions pkg/fake/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,88 +338,26 @@ func SimulatedStreamCopy(c echo.Context, resp *http.Response, chunkDelay time.Du
ctx := c.Request().Context()
writer := c.Response().Writer

reader := bufio.NewReaderSize(resp.Body, 64*1024)
w := c.Response().Writer

// Reuse timer to avoid allocations per chunk
timer := time.NewTimer(chunkDelay)
defer timer.Stop()

dataPrefix := []byte("data:")
rf, ok := w.(io.ReaderFrom)
if !ok {
// fallback seguro
_, err := io.Copy(w, resp.Body)
return err
}

for {
select {
case <-ctx.Done():
slog.WarnContext(ctx, "client disconnected, stop streaming")
return nil
default:
}

line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
// Write any remaining data without newline
if len(line) > 0 {
_, _ = writer.Write(line)
c.Response().Flush()
n, err := rf.ReadFrom(io.LimitReader(resp.Body, 256))
if n > 0 {
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
return nil
}
return err
}

// Write the line (already includes newline from ReadBytes)
if _, err := writer.Write(line); err != nil {
return err
}

// Add delay after data lines (SSE events start with "data:")
if bytes.HasPrefix(line, dataPrefix) {
c.Response().Flush()
timer.Reset(chunkDelay)
select {
case <-ctx.Done():
return nil
case <-timer.C:
}
}
}
}

// streamReadResult holds the result of a streaming read operation.
type streamReadResult struct {
n int64
err error
}

// StreamCopy copies a streaming response to the client.
// It properly handles context cancellation during blocking reads.
func StreamCopy(c echo.Context, resp *http.Response) error {
ctx := c.Request().Context()
writer := c.Response().Writer.(io.ReaderFrom)

// Use a channel to receive read results from a goroutine.
// This allows us to properly select on context cancellation
// even when the read is blocking.
resultCh := make(chan streamReadResult, 1)

for {
// Start a goroutine to perform the blocking read
go func() {
n, err := writer.ReadFrom(io.LimitReader(resp.Body, 256))
resultCh <- streamReadResult{n: n, err: err}
}()

// Wait for either context cancellation or read completion
select {
case <-ctx.Done():
slog.WarnContext(ctx, "client disconnected, stop streaming")
// Close the response body to unblock the read goroutine
resp.Body.Close()
<-resultCh
return nil
case result := <-resultCh:
if result.n > 0 {
c.Response().Flush() // keep flushing to client
}
if result.err != nil {
// io.EOF or context canceled means normal completion
Expand Down
11 changes: 7 additions & 4 deletions pkg/runtime/connectrpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,14 @@ func (c *ConnectRPCClient) convertProtoEventToRuntimeEvent(e *cagentv1.Event) Ev
LastMessage: convertProtoMessageUsage(ev.TokenUsage.Usage.LastMessage),
}
}

return &TokenUsageEvent{
Type: "token_usage",
SessionID: ev.TokenUsage.SessionId,
Usage: usage,
AgentContext: AgentContext{AgentName: ev.TokenUsage.AgentName},
Type: "token_usage",
SessionID: ev.TokenUsage.SessionId,
Usage: usage,
AgentContext: AgentContext{
AgentName: ev.TokenUsage.AgentName,
},
}

case *cagentv1.Event_SessionTitle:
Expand Down
93 changes: 50 additions & 43 deletions pkg/runtime/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package runtime

import (
"cmp"
"time"

"github.com/docker/cagent/pkg/chat"
"github.com/docker/cagent/pkg/config/types"
Expand Down Expand Up @@ -195,9 +196,14 @@ type TokenUsageEvent struct {
Type string `json:"type"`
SessionID string `json:"session_id"`
Usage *Usage `json:"usage"`

AgentContext
}

func (*TokenUsageEvent) GetType() string {
return "token_usage"
}

type Usage struct {
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
Expand All @@ -216,11 +222,29 @@ type MessageUsage struct {
Model string
}

func TokenUsage(sessionID, agentName string, inputTokens, outputTokens, contextLength, contextLimit int64, cost float64) Event {
return TokenUsageWithMessage(sessionID, agentName, inputTokens, outputTokens, contextLength, contextLimit, cost, nil)
}

func TokenUsageWithMessage(sessionID, agentName string, inputTokens, outputTokens, contextLength, contextLimit int64, cost float64, msgUsage *MessageUsage) Event {
func TokenUsage(
sessionID, agentName string,
inputTokens, outputTokens, contextLength, contextLimit int64,
cost float64,
) Event {
return TokenUsageWithMessage(
sessionID,
agentName,
inputTokens,
outputTokens,
contextLength,
contextLimit,
cost,
nil,
)
}

func TokenUsageWithMessage(
sessionID, agentName string,
inputTokens, outputTokens, contextLength, contextLimit int64,
cost float64,
msgUsage *MessageUsage,
) Event {
return &TokenUsageEvent{
Type: "token_usage",
SessionID: sessionID,
Expand All @@ -236,13 +260,6 @@ func TokenUsageWithMessage(sessionID, agentName string, inputTokens, outputToken
}
}

type SessionTitleEvent struct {
Type string `json:"type"`
SessionID string `json:"session_id"`
Title string `json:"title"`
AgentContext
}

func SessionTitle(sessionID, title string) Event {
return &SessionTitleEvent{
Type: "session_title",
Expand Down Expand Up @@ -530,42 +547,32 @@ func HookBlocked(toolCall tools.ToolCall, toolDefinition tools.Tool, message, ag
}
}

// MessageAddedEvent is emitted when a message is added to the session.
// This event is used by the PersistentRuntime wrapper to persist messages.
type MessageAddedEvent struct {
Type string `json:"type"`
SessionID string `json:"session_id"`
Message *session.Message `json:"-"`
AgentContext
}
type SessionMetricsEvent struct {
Type string `json:"type"` // "session_metrics"

func (e *MessageAddedEvent) GetAgentName() string { return e.AgentName }
SessionID string `json:"session_id"`

func MessageAdded(sessionID string, msg *session.Message, agentName string) Event {
return &MessageAddedEvent{
Type: "message_added",
SessionID: sessionID,
Message: msg,
AgentContext: AgentContext{AgentName: agentName},
}
UserMessages int `json:"user_messages"`
AssistantMessages int `json:"assistant_messages"`
ToolCalls int `json:"tool_calls"`
ToolErrors int `json:"tool_errors"`

StartedAt time.Time `json:"started_at"`
EndedAt time.Time `json:"ended_at"`
}

// SubSessionCompletedEvent is emitted when a sub-session completes and is added to parent.
// This event is used by the PersistentRuntime wrapper to persist sub-sessions.
type SubSessionCompletedEvent struct {
Type string `json:"type"`
ParentSessionID string `json:"parent_session_id"`
SubSession any `json:"sub_session"` // *session.Session
AgentContext
func (*SessionMetricsEvent) GetType() string {
return "session_metrics"
}

func (e *SubSessionCompletedEvent) GetAgentName() string { return e.AgentName }
type SessionTitleEvent struct {
Type string `json:"type"`
SessionID string `json:"session_id"`
Title string `json:"title"`

AgentContext
}

func SubSessionCompleted(parentSessionID string, subSession any, agentName string) Event {
return &SubSessionCompletedEvent{
Type: "sub_session_completed",
ParentSessionID: parentSessionID,
SubSession: subSession,
AgentContext: AgentContext{AgentName: agentName},
}
func (*SessionTitleEvent) GetType() string {
return "session_title"
}
69 changes: 52 additions & 17 deletions pkg/server/session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,37 +134,48 @@ func (sm *SessionManager) DeleteSession(ctx context.Context, sessionID string) e
}

// RunSession runs a session with the given messages.
func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilename, currentAgent string, messages []api.Message) (<-chan runtime.Event, error) {
func (sm *SessionManager) RunSession(
ctx context.Context,
sessionID, agentFilename, currentAgent string,
messages []api.Message,
) (<-chan runtime.Event, error) {
sm.mux.Lock()
defer sm.mux.Unlock()

// Load persisted session
sess, err := sm.sessionStore.GetSession(ctx, sessionID)
if err != nil {
sm.mux.Unlock()
return nil, err
}

// Mark execution start (observability only)
sess.Metrics = session.Metrics{}
sess.Metrics.StartedAt = time.Now()

// Clone runtime config and inherit working dir
rc := sm.runConfig.Clone()
rc.WorkingDir = sess.WorkingDir

// Collect user messages for potential title generation
var userMessages []string
// Append user messages and count them
for _, msg := range messages {
sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...))
if msg.Content != "" {
userMessages = append(userMessages, msg.Content)
}
sess.Metrics.UserMessages++
}

if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil {
sm.mux.Unlock()
return nil, err
}

// Get or create runtime for this session
runtimeSession, exists := sm.runtimeSessions.Load(sessionID)
streamCtx, cancel := context.WithCancel(ctx)
var model provider.Provider

if !exists {
var rt runtime.Runtime
rt, model, err = sm.runtimeForSession(ctx, sess, agentFilename, currentAgent, rc)
if err != nil {
sm.mux.Unlock()
cancel()
return nil, err
}
Expand All @@ -181,30 +192,54 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
model = runtimeSession.model
}

sm.mux.Unlock()

streamChan := make(chan runtime.Event)

// Check if we need to generate a title
needsTitle := sess.Title == "" && len(userMessages) > 0 && model != nil

go func() {
// Start title generation in parallel if needed
if needsTitle {
go sm.generateTitle(ctx, sess, model, userMessages, streamChan)
}

stream := runtimeSession.runtime.RunStream(streamCtx, sess)
defer cancel()
defer close(streamChan)

stream := runtimeSession.runtime.RunStream(streamCtx, sess)

for event := range stream {
if streamCtx.Err() != nil {
return
}

// Collect session-level observability metrics
if e, ok := event.(interface{ GetType() string }); ok {
switch e.GetType() {
case "assistant_message":
sess.Metrics.AssistantMessages++
case "tool_call":
sess.Metrics.ToolCalls++
case "tool_error":
sess.Metrics.ToolErrors++
}
}

streamChan <- event
}

if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil {
return
}
// Mark execution end
sess.Metrics.EndedAt = time.Now()

streamChan <- runtime.TokenUsage(
sess.ID,
currentAgent,
sess.InputTokens,
sess.OutputTokens,
0,
0,
sess.Cost,
)

// Persist updated session state (metrics are ephemeral)
_ = sm.sessionStore.UpdateSession(ctx, sess)
}()

return streamChan, nil
Expand Down
Loading