diff --git a/cmd/ax/exec.go b/cmd/ax/exec.go index 4efc578..bd95b41 100644 --- a/cmd/ax/exec.go +++ b/cmd/ax/exec.go @@ -341,38 +341,8 @@ func runExecServer(ctx context.Context, d *internal.Display, req *proto.ExecRequ func displayContents(d *internal.Display, contents []*proto.Message) { for _, output := range contents { - content := output.GetContent() - if content == nil { - continue - } - switch o := content.Type.(type) { - case *proto.Content_Text: - d.DisplayText(o.Text.Text) - case *proto.Content_Confirmation: - // Let the confirmation prompt handle displaying the question. - case *proto.Content_ToolCall: - // No-op for cleaner CLI logs - case *proto.Content_ToolResult: - // Only print if the tool returned an error, otherwise skip - tr := o.ToolResult - if fr := tr.GetFunctionResult(); fr != nil { - if fr.GetResponse() != nil { - respMap := fr.GetResponse().AsMap() - if errStr, ok := respMap["error"]; ok { - d.DisplaySystem(fmt.Sprintf("[TOOL ERROR for %s]\n%v", fr.Name, errStr)) - } - } - } - case *proto.Content_Thought: - for _, summary := range o.Thought.GetSummary() { - if textContent := summary.GetText(); textContent != nil { - d.DisplayThought(textContent.Text) - } - } - case *proto.Content_Image, *proto.Content_Audio, *proto.Content_Video, *proto.Content_Document: - d.DisplaySystem(fmt.Sprintf("unsupported output type for display: %T", o)) - default: - d.DisplaySystem(fmt.Sprintf("unknown output type: %v", o)) + if content := output.GetContent(); content != nil { + d.Display(content) } } } diff --git a/cmd/ax/internal/display.go b/cmd/ax/internal/display.go index 46efbd5..5c413ff 100644 --- a/cmd/ax/internal/display.go +++ b/cmd/ax/internal/display.go @@ -21,6 +21,7 @@ import ( "charm.land/huh/v2" "charm.land/lipgloss/v2" + "github.com/google/ax/proto" ) const ( @@ -86,29 +87,61 @@ func (d *Display) DisplayInput(text string) { fmt.Fprintln(d.w) } -// DisplayText prints a chunk of model text response. -func (d *Display) DisplayText(text string) { - if d.state == stateThought { - fmt.Fprintln(d.w) // end the thinking line +// Display prints a content block according to its type. +func (d *Display) Display(content *proto.Content) { + if content == nil { + return } - d.state = stateText - fmt.Fprint(d.w, text) -} + switch o := content.Type.(type) { + case *proto.Content_Text: + if d.state == stateThought { + fmt.Fprintln(d.w) // end the thinking line + } + d.state = stateText + fmt.Fprint(d.w, o.Text.Text) + + case *proto.Content_Confirmation: + // Let the confirmation prompt handle displaying the question. + + case *proto.Content_ToolCall: + // No-op for cleaner CLI logs + + case *proto.Content_ToolResult: + // Only print if the tool returned an error, otherwise skip + tr := o.ToolResult + if fr := tr.GetFunctionResult(); fr != nil { + if fr.GetResponse() != nil { + respMap := fr.GetResponse().AsMap() + if errStr, ok := respMap["error"]; ok { + d.displaySystem(fmt.Sprintf("[TOOL ERROR for %s]\n%v", fr.Name, errStr)) + } + } + } -// DisplayThought prints a chunk of model thinking process. -func (d *Display) DisplayThought(text string) { - if d.state != stateThought { - if d.state == stateText { - fmt.Fprintln(d.w) + case *proto.Content_Thought: + for _, summary := range o.Thought.GetSummary() { + if textContent := summary.GetText(); textContent != nil { + if d.state != stateThought { + if d.state == stateText { + fmt.Fprintln(d.w) + } + fmt.Fprint(d.w, "Thinking: ") + } + d.state = stateThought + fmt.Fprint(d.w, textContent.Text) + } } - fmt.Fprint(d.w, "Thinking: ") + + case *proto.Content_Image, *proto.Content_Audio, *proto.Content_Video, *proto.Content_Document: + d.displaySystem(fmt.Sprintf("unsupported output type for display: %T", o)) + + default: + d.displaySystem(fmt.Sprintf("unknown output type: %v", o)) } - d.state = stateThought - fmt.Fprint(d.w, text) } -// DisplaySystem prints a system/error message on a new line. -func (d *Display) DisplaySystem(text string) { +// displaySystem prints a system/error message on a new line. +func (d *Display) displaySystem(text string) { if d.state != stateNone { fmt.Fprintln(d.w) } diff --git a/cmd/ax/internal/display_test.go b/cmd/ax/internal/display_test.go index 2551170..3f84d60 100644 --- a/cmd/ax/internal/display_test.go +++ b/cmd/ax/internal/display_test.go @@ -17,17 +17,30 @@ package internal import ( "bytes" "testing" + + "github.com/google/ax/proto" ) func TestDisplay_Streaming(t *testing.T) { + textContent := func(txt string) *proto.Content { + return &proto.Content{Type: &proto.Content_Text{Text: &proto.TextContent{Text: txt}}} + } + thoughtContent := func(txt string) *proto.Content { + return &proto.Content{Type: &proto.Content_Thought{Thought: &proto.ThoughtContent{ + Summary: []*proto.ThoughtSummaryContent{ + {Type: &proto.ThoughtSummaryContent_Text{Text: &proto.TextContent{Text: txt}}}, + }, + }}} + } + t.Run("consecutive text chunks are concatenated", func(t *testing.T) { t.Parallel() var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayText("Hello ") - d.DisplayText("world") - d.DisplayText("!") + d.Display(textContent("Hello ")) + d.Display(textContent("world")) + d.Display(textContent("!")) got := buf.String() want := "Hello world!" @@ -41,8 +54,8 @@ func TestDisplay_Streaming(t *testing.T) { var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayThought("thinking ") - d.DisplayThought("deeply") + d.Display(thoughtContent("thinking ")) + d.Display(thoughtContent("deeply")) got := buf.String() want := "Thinking: thinking deeply" @@ -56,8 +69,8 @@ func TestDisplay_Streaming(t *testing.T) { var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayThought("thinking") - d.DisplayText("Hello") + d.Display(thoughtContent("thinking")) + d.Display(textContent("Hello")) got := buf.String() want := "Thinking: thinking\nHello" @@ -71,8 +84,8 @@ func TestDisplay_Streaming(t *testing.T) { var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayText("Hello") - d.DisplayThought("thinking") + d.Display(textContent("Hello")) + d.Display(thoughtContent("thinking")) got := buf.String() want := "Hello\nThinking: thinking" @@ -86,7 +99,7 @@ func TestDisplay_Streaming(t *testing.T) { var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayText("Hello") + d.Display(textContent("Hello")) d.FinishOutput("") got := buf.String() @@ -101,7 +114,7 @@ func TestDisplay_Streaming(t *testing.T) { var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayText("Hello") + d.Display(textContent("Hello")) d.FinishOutput("seq=1") got := buf.String() @@ -113,13 +126,13 @@ func TestDisplay_Streaming(t *testing.T) { } }) - t.Run("DisplaySystem resets state and prints newline", func(t *testing.T) { + t.Run("displaySystem resets state and prints newline", func(t *testing.T) { t.Parallel() var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayText("Hello") - d.DisplaySystem("system message") + d.Display(textContent("Hello")) + d.displaySystem("system message") got := buf.String() want := "Hello\nsystem message\n" @@ -133,7 +146,7 @@ func TestDisplay_Streaming(t *testing.T) { var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayText("Hello") + d.Display(textContent("Hello")) d.DisplayInput("prompt") got := buf.String()