diff --git a/cmd/ax/exec.go b/cmd/ax/exec.go index 0dfba85..bd95b41 100644 --- a/cmd/ax/exec.go +++ b/cmd/ax/exec.go @@ -124,7 +124,7 @@ func runExec(cmd *cobra.Command, args []string) error { } func execLoop(ctx context.Context, id string, agentID string, input string, lastSeq int32) error { - d := internal.NewDisplay(id) + d := internal.NewDisplay(id, os.Stdout) d.DisplayHeader() var inputs []*proto.Message @@ -341,38 +341,8 @@ func runExecServer(ctx context.Context, d *internal.Display, req *proto.ExecRequ func displayContents(d *internal.Display, contents []*proto.Message) { for _, output := range contents { - content := output.GetContent() - if content == nil { - continue - } - switch o := content.Type.(type) { - case *proto.Content_Text: - d.DisplayOutput(o.Text.Text) - case *proto.Content_Confirmation: - // Let the confirmation prompt handle displaying the question. - case *proto.Content_ToolCall: - // No-op for cleaner CLI logs - case *proto.Content_ToolResult: - // Only print if the tool returned an error, otherwise skip - tr := o.ToolResult - if fr := tr.GetFunctionResult(); fr != nil { - if fr.GetResponse() != nil { - respMap := fr.GetResponse().AsMap() - if errStr, ok := respMap["error"]; ok { - d.DisplayOutput(fmt.Sprintf("\n[TOOL ERROR for %s]\n%v\n", fr.Name, errStr)) - } - } - } - case *proto.Content_Thought: - for _, summary := range o.Thought.GetSummary() { - if textContent := summary.GetText(); textContent != nil { - d.DisplayOutput(fmt.Sprintf("Thinking: %s", textContent.Text)) - } - } - case *proto.Content_Image, *proto.Content_Audio, *proto.Content_Video, *proto.Content_Document: - d.DisplayOutput(fmt.Sprintf("unsupported output type for display: %T", o)) - default: - d.DisplayOutput(fmt.Sprintf("unknown output type: %v", o)) + if content := output.GetContent(); content != nil { + d.Display(content) } } } diff --git a/cmd/ax/internal/display.go b/cmd/ax/internal/display.go index 22434e3..5c413ff 100644 --- a/cmd/ax/internal/display.go +++ b/cmd/ax/internal/display.go @@ -16,11 +16,12 @@ package internal import ( "fmt" + "io" "os" - "sync/atomic" "charm.land/huh/v2" "charm.land/lipgloss/v2" + "github.com/google/ax/proto" ) const ( @@ -38,55 +39,131 @@ var ( // ErrUserAborted is returned when the user aborts a prompt. var ErrUserAborted = huh.ErrUserAborted +type displayState int + +const ( + stateNone displayState = iota + stateText + stateThought +) + type Display struct { id string + w io.Writer // Target output writer, e.g., os.Stdout or a test buffer userStyle lipgloss.Style checkpointStyle lipgloss.Style idStyle lipgloss.Style resumeStyle lipgloss.Style - loadingVisible atomic.Bool - loadingStopCh chan bool + state displayState // Tracks the last printed chunk type to correctly format transition newlines } -func NewDisplay(id string) *Display { +func NewDisplay(id string, w io.Writer) *Display { + if w == nil { + w = os.Stdout + } return &Display{ id: id, + w: w, userStyle: lipgloss.NewStyle().Foreground(purple), checkpointStyle: lipgloss.NewStyle().Foreground(comment), idStyle: lipgloss.NewStyle().Foreground(comment), resumeStyle: lipgloss.NewStyle().Foreground(comment), - loadingStopCh: make(chan bool), + state: stateNone, } } // DisplayInput displays the user input. func (d *Display) DisplayInput(text string) { - fmt.Printf("%s %s\n", + if d.state != stateNone { + fmt.Fprintln(d.w) + } + d.state = stateNone + fmt.Fprintf(d.w, "%s %s\n", d.userStyle.Render("⏺"), text, ) - fmt.Println() + fmt.Fprintln(d.w) } -// DisplayOutput displays an output fragment. -func (d *Display) DisplayOutput(text string) { - fmt.Println(text) - fmt.Println() +// Display prints a content block according to its type. +func (d *Display) Display(content *proto.Content) { + if content == nil { + return + } + switch o := content.Type.(type) { + case *proto.Content_Text: + if d.state == stateThought { + fmt.Fprintln(d.w) // end the thinking line + } + d.state = stateText + fmt.Fprint(d.w, o.Text.Text) + + case *proto.Content_Confirmation: + // Let the confirmation prompt handle displaying the question. + + case *proto.Content_ToolCall: + // No-op for cleaner CLI logs + + case *proto.Content_ToolResult: + // Only print if the tool returned an error, otherwise skip + tr := o.ToolResult + if fr := tr.GetFunctionResult(); fr != nil { + if fr.GetResponse() != nil { + respMap := fr.GetResponse().AsMap() + if errStr, ok := respMap["error"]; ok { + d.displaySystem(fmt.Sprintf("[TOOL ERROR for %s]\n%v", fr.Name, errStr)) + } + } + } + + case *proto.Content_Thought: + for _, summary := range o.Thought.GetSummary() { + if textContent := summary.GetText(); textContent != nil { + if d.state != stateThought { + if d.state == stateText { + fmt.Fprintln(d.w) + } + fmt.Fprint(d.w, "Thinking: ") + } + d.state = stateThought + fmt.Fprint(d.w, textContent.Text) + } + } + + case *proto.Content_Image, *proto.Content_Audio, *proto.Content_Video, *proto.Content_Document: + d.displaySystem(fmt.Sprintf("unsupported output type for display: %T", o)) + + default: + d.displaySystem(fmt.Sprintf("unknown output type: %v", o)) + } +} + +// displaySystem prints a system/error message on a new line. +func (d *Display) displaySystem(text string) { + if d.state != stateNone { + fmt.Fprintln(d.w) + } + d.state = stateNone + fmt.Fprintln(d.w, text) } // FinishOutput completes the streaming output and shows info if provided func (d *Display) FinishOutput(info string) { + if d.state != stateNone { + fmt.Fprintln(d.w) + } + d.state = stateNone if info != "" { - fmt.Println(d.checkpointStyle.Render(info)) + fmt.Fprintln(d.w, d.checkpointStyle.Render(info)) } - fmt.Println() + fmt.Fprintln(d.w) } func (d *Display) DisplayHeader() { - fmt.Println(d.idStyle.Render("Conversation: " + d.id)) - fmt.Println() + fmt.Fprintln(d.w, d.idStyle.Render("Conversation: " + d.id)) + fmt.Fprintln(d.w) } // PromptForApproval shows an accept/reject dialog @@ -128,10 +205,10 @@ func (d *Display) PromptForInput() (string, error) { } func (d *Display) ShowResumption(id string, server string) { - fmt.Println(d.resumeStyle.Render("To resume the conversation,")) + fmt.Fprintln(d.w, d.resumeStyle.Render("To resume the conversation,")) if server != "" { - fmt.Println(d.resumeStyle.Render(fmt.Sprintf("ax exec --conversation %s --server %s", id, server))) + fmt.Fprintln(d.w, d.resumeStyle.Render(fmt.Sprintf("ax exec --conversation %s --server %s", id, server))) } else { - fmt.Println(d.resumeStyle.Render(fmt.Sprintf("ax exec --conversation %s", id))) + fmt.Fprintln(d.w, d.resumeStyle.Render(fmt.Sprintf("ax exec --conversation %s", id))) } } diff --git a/cmd/ax/internal/display_test.go b/cmd/ax/internal/display_test.go new file mode 100644 index 0000000..3f84d60 --- /dev/null +++ b/cmd/ax/internal/display_test.go @@ -0,0 +1,160 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "bytes" + "testing" + + "github.com/google/ax/proto" +) + +func TestDisplay_Streaming(t *testing.T) { + textContent := func(txt string) *proto.Content { + return &proto.Content{Type: &proto.Content_Text{Text: &proto.TextContent{Text: txt}}} + } + thoughtContent := func(txt string) *proto.Content { + return &proto.Content{Type: &proto.Content_Thought{Thought: &proto.ThoughtContent{ + Summary: []*proto.ThoughtSummaryContent{ + {Type: &proto.ThoughtSummaryContent_Text{Text: &proto.TextContent{Text: txt}}}, + }, + }}} + } + + t.Run("consecutive text chunks are concatenated", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.Display(textContent("Hello ")) + d.Display(textContent("world")) + d.Display(textContent("!")) + + got := buf.String() + want := "Hello world!" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("consecutive thought chunks are concatenated with prefix", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.Display(thoughtContent("thinking ")) + d.Display(thoughtContent("deeply")) + + got := buf.String() + want := "Thinking: thinking deeply" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("transition from thought to text adds newline", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.Display(thoughtContent("thinking")) + d.Display(textContent("Hello")) + + got := buf.String() + want := "Thinking: thinking\nHello" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("transition from text to thought adds newline", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.Display(textContent("Hello")) + d.Display(thoughtContent("thinking")) + + got := buf.String() + want := "Hello\nThinking: thinking" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("FinishOutput empty resets state and adds newlines", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.Display(textContent("Hello")) + d.FinishOutput("") + + got := buf.String() + want := "Hello\n\n" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("FinishOutput with info prints info and resets state", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.Display(textContent("Hello")) + d.FinishOutput("seq=1") + + got := buf.String() + if !bytes.HasPrefix([]byte(got), []byte("Hello\n")) { + t.Errorf("expected Hello to end with newline, got %q", got) + } + if !bytes.Contains([]byte(got), []byte("seq=1")) { + t.Errorf("expected output to contain seq=1, got %q", got) + } + }) + + t.Run("displaySystem resets state and prints newline", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.Display(textContent("Hello")) + d.displaySystem("system message") + + got := buf.String() + want := "Hello\nsystem message\n" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("DisplayInput resets state and adds separation newlines", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.Display(textContent("Hello")) + d.DisplayInput("prompt") + + got := buf.String() + if !bytes.HasPrefix([]byte(got), []byte("Hello\n")) { + t.Errorf("expected Hello to end with newline, got %q", got) + } + if !bytes.Contains([]byte(got), []byte("prompt")) { + t.Errorf("expected output to contain prompt, got %q", got) + } + }) +}