Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 2 additions & 32 deletions cmd/ax/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,38 +341,8 @@ func runExecServer(ctx context.Context, d *internal.Display, req *proto.ExecRequ

func displayContents(d *internal.Display, contents []*proto.Message) {
for _, output := range contents {
content := output.GetContent()
if content == nil {
continue
}
switch o := content.Type.(type) {
case *proto.Content_Text:
d.DisplayText(o.Text.Text)
case *proto.Content_Confirmation:
// Let the confirmation prompt handle displaying the question.
case *proto.Content_ToolCall:
// No-op for cleaner CLI logs
case *proto.Content_ToolResult:
// Only print if the tool returned an error, otherwise skip
tr := o.ToolResult
if fr := tr.GetFunctionResult(); fr != nil {
if fr.GetResponse() != nil {
respMap := fr.GetResponse().AsMap()
if errStr, ok := respMap["error"]; ok {
d.DisplaySystem(fmt.Sprintf("[TOOL ERROR for %s]\n%v", fr.Name, errStr))
}
}
}
case *proto.Content_Thought:
for _, summary := range o.Thought.GetSummary() {
if textContent := summary.GetText(); textContent != nil {
d.DisplayThought(textContent.Text)
}
}
case *proto.Content_Image, *proto.Content_Audio, *proto.Content_Video, *proto.Content_Document:
d.DisplaySystem(fmt.Sprintf("unsupported output type for display: %T", o))
default:
d.DisplaySystem(fmt.Sprintf("unknown output type: %v", o))
if content := output.GetContent(); content != nil {
d.Display(content)
}
}
}
Expand Down
67 changes: 50 additions & 17 deletions cmd/ax/internal/display.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"charm.land/huh/v2"
"charm.land/lipgloss/v2"
"github.com/google/ax/proto"
)

const (
Expand Down Expand Up @@ -86,29 +87,61 @@ func (d *Display) DisplayInput(text string) {
fmt.Fprintln(d.w)
}

// DisplayText prints a chunk of model text response.
func (d *Display) DisplayText(text string) {
if d.state == stateThought {
fmt.Fprintln(d.w) // end the thinking line
// Display prints a content block according to its type.
func (d *Display) Display(content *proto.Content) {
if content == nil {
return
}
d.state = stateText
fmt.Fprint(d.w, text)
}
switch o := content.Type.(type) {
case *proto.Content_Text:
if d.state == stateThought {
fmt.Fprintln(d.w) // end the thinking line
}
d.state = stateText
fmt.Fprint(d.w, o.Text.Text)

case *proto.Content_Confirmation:
// Let the confirmation prompt handle displaying the question.

case *proto.Content_ToolCall:
// No-op for cleaner CLI logs

case *proto.Content_ToolResult:
// Only print if the tool returned an error, otherwise skip
tr := o.ToolResult
if fr := tr.GetFunctionResult(); fr != nil {
if fr.GetResponse() != nil {
respMap := fr.GetResponse().AsMap()
if errStr, ok := respMap["error"]; ok {
d.displaySystem(fmt.Sprintf("[TOOL ERROR for %s]\n%v", fr.Name, errStr))
}
}
}

// DisplayThought prints a chunk of model thinking process.
func (d *Display) DisplayThought(text string) {
if d.state != stateThought {
if d.state == stateText {
fmt.Fprintln(d.w)
case *proto.Content_Thought:
for _, summary := range o.Thought.GetSummary() {
if textContent := summary.GetText(); textContent != nil {
if d.state != stateThought {
if d.state == stateText {
fmt.Fprintln(d.w)
}
fmt.Fprint(d.w, "Thinking: ")
}
d.state = stateThought
fmt.Fprint(d.w, textContent.Text)
}
}
fmt.Fprint(d.w, "Thinking: ")

case *proto.Content_Image, *proto.Content_Audio, *proto.Content_Video, *proto.Content_Document:
d.displaySystem(fmt.Sprintf("unsupported output type for display: %T", o))

default:
d.displaySystem(fmt.Sprintf("unknown output type: %v", o))
}
d.state = stateThought
fmt.Fprint(d.w, text)
}

// DisplaySystem prints a system/error message on a new line.
func (d *Display) DisplaySystem(text string) {
// displaySystem prints a system/error message on a new line.
func (d *Display) displaySystem(text string) {
if d.state != stateNone {
fmt.Fprintln(d.w)
}
Expand Down
43 changes: 28 additions & 15 deletions cmd/ax/internal/display_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,30 @@ package internal
import (
"bytes"
"testing"

"github.com/google/ax/proto"
)

func TestDisplay_Streaming(t *testing.T) {
textContent := func(txt string) *proto.Content {
return &proto.Content{Type: &proto.Content_Text{Text: &proto.TextContent{Text: txt}}}
}
thoughtContent := func(txt string) *proto.Content {
return &proto.Content{Type: &proto.Content_Thought{Thought: &proto.ThoughtContent{
Summary: []*proto.ThoughtSummaryContent{
{Type: &proto.ThoughtSummaryContent_Text{Text: &proto.TextContent{Text: txt}}},
},
}}}
}

t.Run("consecutive text chunks are concatenated", func(t *testing.T) {
t.Parallel()
var buf bytes.Buffer
d := NewDisplay("test-id", &buf)

d.DisplayText("Hello ")
d.DisplayText("world")
d.DisplayText("!")
d.Display(textContent("Hello "))
d.Display(textContent("world"))
d.Display(textContent("!"))

got := buf.String()
want := "Hello world!"
Expand All @@ -41,8 +54,8 @@ func TestDisplay_Streaming(t *testing.T) {
var buf bytes.Buffer
d := NewDisplay("test-id", &buf)

d.DisplayThought("thinking ")
d.DisplayThought("deeply")
d.Display(thoughtContent("thinking "))
d.Display(thoughtContent("deeply"))

got := buf.String()
want := "Thinking: thinking deeply"
Expand All @@ -56,8 +69,8 @@ func TestDisplay_Streaming(t *testing.T) {
var buf bytes.Buffer
d := NewDisplay("test-id", &buf)

d.DisplayThought("thinking")
d.DisplayText("Hello")
d.Display(thoughtContent("thinking"))
d.Display(textContent("Hello"))

got := buf.String()
want := "Thinking: thinking\nHello"
Expand All @@ -71,8 +84,8 @@ func TestDisplay_Streaming(t *testing.T) {
var buf bytes.Buffer
d := NewDisplay("test-id", &buf)

d.DisplayText("Hello")
d.DisplayThought("thinking")
d.Display(textContent("Hello"))
d.Display(thoughtContent("thinking"))

got := buf.String()
want := "Hello\nThinking: thinking"
Expand All @@ -86,7 +99,7 @@ func TestDisplay_Streaming(t *testing.T) {
var buf bytes.Buffer
d := NewDisplay("test-id", &buf)

d.DisplayText("Hello")
d.Display(textContent("Hello"))
d.FinishOutput("")

got := buf.String()
Expand All @@ -101,7 +114,7 @@ func TestDisplay_Streaming(t *testing.T) {
var buf bytes.Buffer
d := NewDisplay("test-id", &buf)

d.DisplayText("Hello")
d.Display(textContent("Hello"))
d.FinishOutput("seq=1")

got := buf.String()
Expand All @@ -113,13 +126,13 @@ func TestDisplay_Streaming(t *testing.T) {
}
})

t.Run("DisplaySystem resets state and prints newline", func(t *testing.T) {
t.Run("displaySystem resets state and prints newline", func(t *testing.T) {
t.Parallel()
var buf bytes.Buffer
d := NewDisplay("test-id", &buf)

d.DisplayText("Hello")
d.DisplaySystem("system message")
d.Display(textContent("Hello"))
d.displaySystem("system message")

got := buf.String()
want := "Hello\nsystem message\n"
Expand All @@ -133,7 +146,7 @@ func TestDisplay_Streaming(t *testing.T) {
var buf bytes.Buffer
d := NewDisplay("test-id", &buf)

d.DisplayText("Hello")
d.Display(textContent("Hello"))
d.DisplayInput("prompt")

got := buf.String()
Expand Down
Loading