From cd96aa900e7fbdf9b94f29b80e953f384182defd Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Fri, 12 Jun 2026 23:24:52 -0700 Subject: [PATCH 1/3] refactor: implement stateful Display and inject io.Writer Closes #64. Adds stateful Display printing for streaming outputs. Removes os.Stdout.Sync() overhead and cleans up dead code. Replaces global stdout redirecting with injected io.Writer dependency, enabling safe parallel unit tests. --- cmd/ax/exec.go | 12 +-- cmd/ax/internal/display.go | 80 +++++++++++++---- cmd/ax/internal/display_test.go | 147 ++++++++++++++++++++++++++++++++ 3 files changed, 215 insertions(+), 24 deletions(-) create mode 100644 cmd/ax/internal/display_test.go diff --git a/cmd/ax/exec.go b/cmd/ax/exec.go index 0dfba85..4efc578 100644 --- a/cmd/ax/exec.go +++ b/cmd/ax/exec.go @@ -124,7 +124,7 @@ func runExec(cmd *cobra.Command, args []string) error { } func execLoop(ctx context.Context, id string, agentID string, input string, lastSeq int32) error { - d := internal.NewDisplay(id) + d := internal.NewDisplay(id, os.Stdout) d.DisplayHeader() var inputs []*proto.Message @@ -347,7 +347,7 @@ func displayContents(d *internal.Display, contents []*proto.Message) { } switch o := content.Type.(type) { case *proto.Content_Text: - d.DisplayOutput(o.Text.Text) + d.DisplayText(o.Text.Text) case *proto.Content_Confirmation: // Let the confirmation prompt handle displaying the question. case *proto.Content_ToolCall: @@ -359,20 +359,20 @@ func displayContents(d *internal.Display, contents []*proto.Message) { if fr.GetResponse() != nil { respMap := fr.GetResponse().AsMap() if errStr, ok := respMap["error"]; ok { - d.DisplayOutput(fmt.Sprintf("\n[TOOL ERROR for %s]\n%v\n", fr.Name, errStr)) + d.DisplaySystem(fmt.Sprintf("[TOOL ERROR for %s]\n%v", fr.Name, errStr)) } } } case *proto.Content_Thought: for _, summary := range o.Thought.GetSummary() { if textContent := summary.GetText(); textContent != nil { - d.DisplayOutput(fmt.Sprintf("Thinking: %s", textContent.Text)) + d.DisplayThought(textContent.Text) } } case *proto.Content_Image, *proto.Content_Audio, *proto.Content_Video, *proto.Content_Document: - d.DisplayOutput(fmt.Sprintf("unsupported output type for display: %T", o)) + d.DisplaySystem(fmt.Sprintf("unsupported output type for display: %T", o)) default: - d.DisplayOutput(fmt.Sprintf("unknown output type: %v", o)) + d.DisplaySystem(fmt.Sprintf("unknown output type: %v", o)) } } } diff --git a/cmd/ax/internal/display.go b/cmd/ax/internal/display.go index 22434e3..46efbd5 100644 --- a/cmd/ax/internal/display.go +++ b/cmd/ax/internal/display.go @@ -16,8 +16,8 @@ package internal import ( "fmt" + "io" "os" - "sync/atomic" "charm.land/huh/v2" "charm.land/lipgloss/v2" @@ -38,55 +38,99 @@ var ( // ErrUserAborted is returned when the user aborts a prompt. var ErrUserAborted = huh.ErrUserAborted +type displayState int + +const ( + stateNone displayState = iota + stateText + stateThought +) + type Display struct { id string + w io.Writer // Target output writer, e.g., os.Stdout or a test buffer userStyle lipgloss.Style checkpointStyle lipgloss.Style idStyle lipgloss.Style resumeStyle lipgloss.Style - loadingVisible atomic.Bool - loadingStopCh chan bool + state displayState // Tracks the last printed chunk type to correctly format transition newlines } -func NewDisplay(id string) *Display { +func NewDisplay(id string, w io.Writer) *Display { + if w == nil { + w = os.Stdout + } return &Display{ id: id, + w: w, userStyle: lipgloss.NewStyle().Foreground(purple), checkpointStyle: lipgloss.NewStyle().Foreground(comment), idStyle: lipgloss.NewStyle().Foreground(comment), resumeStyle: lipgloss.NewStyle().Foreground(comment), - loadingStopCh: make(chan bool), + state: stateNone, } } // DisplayInput displays the user input. func (d *Display) DisplayInput(text string) { - fmt.Printf("%s %s\n", + if d.state != stateNone { + fmt.Fprintln(d.w) + } + d.state = stateNone + fmt.Fprintf(d.w, "%s %s\n", d.userStyle.Render("⏺"), text, ) - fmt.Println() + fmt.Fprintln(d.w) +} + +// DisplayText prints a chunk of model text response. +func (d *Display) DisplayText(text string) { + if d.state == stateThought { + fmt.Fprintln(d.w) // end the thinking line + } + d.state = stateText + fmt.Fprint(d.w, text) } -// DisplayOutput displays an output fragment. -func (d *Display) DisplayOutput(text string) { - fmt.Println(text) - fmt.Println() +// DisplayThought prints a chunk of model thinking process. +func (d *Display) DisplayThought(text string) { + if d.state != stateThought { + if d.state == stateText { + fmt.Fprintln(d.w) + } + fmt.Fprint(d.w, "Thinking: ") + } + d.state = stateThought + fmt.Fprint(d.w, text) +} + +// DisplaySystem prints a system/error message on a new line. +func (d *Display) DisplaySystem(text string) { + if d.state != stateNone { + fmt.Fprintln(d.w) + } + d.state = stateNone + fmt.Fprintln(d.w, text) } // FinishOutput completes the streaming output and shows info if provided func (d *Display) FinishOutput(info string) { + if d.state != stateNone { + fmt.Fprintln(d.w) + } + d.state = stateNone if info != "" { - fmt.Println(d.checkpointStyle.Render(info)) + fmt.Fprintln(d.w, d.checkpointStyle.Render(info)) } - fmt.Println() + fmt.Fprintln(d.w) } func (d *Display) DisplayHeader() { - fmt.Println(d.idStyle.Render("Conversation: " + d.id)) - fmt.Println() + fmt.Fprintln(d.w, d.idStyle.Render("Conversation: " + d.id)) + fmt.Fprintln(d.w) } // PromptForApproval shows an accept/reject dialog @@ -128,10 +172,10 @@ func (d *Display) PromptForInput() (string, error) { } func (d *Display) ShowResumption(id string, server string) { - fmt.Println(d.resumeStyle.Render("To resume the conversation,")) + fmt.Fprintln(d.w, d.resumeStyle.Render("To resume the conversation,")) if server != "" { - fmt.Println(d.resumeStyle.Render(fmt.Sprintf("ax exec --conversation %s --server %s", id, server))) + fmt.Fprintln(d.w, d.resumeStyle.Render(fmt.Sprintf("ax exec --conversation %s --server %s", id, server))) } else { - fmt.Println(d.resumeStyle.Render(fmt.Sprintf("ax exec --conversation %s", id))) + fmt.Fprintln(d.w, d.resumeStyle.Render(fmt.Sprintf("ax exec --conversation %s", id))) } } diff --git a/cmd/ax/internal/display_test.go b/cmd/ax/internal/display_test.go new file mode 100644 index 0000000..2551170 --- /dev/null +++ b/cmd/ax/internal/display_test.go @@ -0,0 +1,147 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "bytes" + "testing" +) + +func TestDisplay_Streaming(t *testing.T) { + t.Run("consecutive text chunks are concatenated", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayText("Hello ") + d.DisplayText("world") + d.DisplayText("!") + + got := buf.String() + want := "Hello world!" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("consecutive thought chunks are concatenated with prefix", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayThought("thinking ") + d.DisplayThought("deeply") + + got := buf.String() + want := "Thinking: thinking deeply" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("transition from thought to text adds newline", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayThought("thinking") + d.DisplayText("Hello") + + got := buf.String() + want := "Thinking: thinking\nHello" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("transition from text to thought adds newline", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayText("Hello") + d.DisplayThought("thinking") + + got := buf.String() + want := "Hello\nThinking: thinking" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("FinishOutput empty resets state and adds newlines", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayText("Hello") + d.FinishOutput("") + + got := buf.String() + want := "Hello\n\n" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("FinishOutput with info prints info and resets state", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayText("Hello") + d.FinishOutput("seq=1") + + got := buf.String() + if !bytes.HasPrefix([]byte(got), []byte("Hello\n")) { + t.Errorf("expected Hello to end with newline, got %q", got) + } + if !bytes.Contains([]byte(got), []byte("seq=1")) { + t.Errorf("expected output to contain seq=1, got %q", got) + } + }) + + t.Run("DisplaySystem resets state and prints newline", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayText("Hello") + d.DisplaySystem("system message") + + got := buf.String() + want := "Hello\nsystem message\n" + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("DisplayInput resets state and adds separation newlines", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + d := NewDisplay("test-id", &buf) + + d.DisplayText("Hello") + d.DisplayInput("prompt") + + got := buf.String() + if !bytes.HasPrefix([]byte(got), []byte("Hello\n")) { + t.Errorf("expected Hello to end with newline, got %q", got) + } + if !bytes.Contains([]byte(got), []byte("prompt")) { + t.Errorf("expected output to contain prompt, got %q", got) + } + }) +} From 29b23f2b9db0277e1df7e79c151508578b32b2e0 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Wed, 17 Jun 2026 14:33:14 -0700 Subject: [PATCH 2/3] refactor: simplify Display interface by accepting proto.Content directly --- cmd/ax/exec.go | 34 ++--------------------- cmd/ax/internal/display.go | 49 +++++++++++++++++++++++++++++---- cmd/ax/internal/display_test.go | 43 +++++++++++++++++++---------- 3 files changed, 73 insertions(+), 53 deletions(-) diff --git a/cmd/ax/exec.go b/cmd/ax/exec.go index 4efc578..bd95b41 100644 --- a/cmd/ax/exec.go +++ b/cmd/ax/exec.go @@ -341,38 +341,8 @@ func runExecServer(ctx context.Context, d *internal.Display, req *proto.ExecRequ func displayContents(d *internal.Display, contents []*proto.Message) { for _, output := range contents { - content := output.GetContent() - if content == nil { - continue - } - switch o := content.Type.(type) { - case *proto.Content_Text: - d.DisplayText(o.Text.Text) - case *proto.Content_Confirmation: - // Let the confirmation prompt handle displaying the question. - case *proto.Content_ToolCall: - // No-op for cleaner CLI logs - case *proto.Content_ToolResult: - // Only print if the tool returned an error, otherwise skip - tr := o.ToolResult - if fr := tr.GetFunctionResult(); fr != nil { - if fr.GetResponse() != nil { - respMap := fr.GetResponse().AsMap() - if errStr, ok := respMap["error"]; ok { - d.DisplaySystem(fmt.Sprintf("[TOOL ERROR for %s]\n%v", fr.Name, errStr)) - } - } - } - case *proto.Content_Thought: - for _, summary := range o.Thought.GetSummary() { - if textContent := summary.GetText(); textContent != nil { - d.DisplayThought(textContent.Text) - } - } - case *proto.Content_Image, *proto.Content_Audio, *proto.Content_Video, *proto.Content_Document: - d.DisplaySystem(fmt.Sprintf("unsupported output type for display: %T", o)) - default: - d.DisplaySystem(fmt.Sprintf("unknown output type: %v", o)) + if content := output.GetContent(); content != nil { + d.Display(content) } } } diff --git a/cmd/ax/internal/display.go b/cmd/ax/internal/display.go index 46efbd5..a0cbcf2 100644 --- a/cmd/ax/internal/display.go +++ b/cmd/ax/internal/display.go @@ -21,6 +21,7 @@ import ( "charm.land/huh/v2" "charm.land/lipgloss/v2" + "github.com/google/ax/proto" ) const ( @@ -86,8 +87,8 @@ func (d *Display) DisplayInput(text string) { fmt.Fprintln(d.w) } -// DisplayText prints a chunk of model text response. -func (d *Display) DisplayText(text string) { +// displayText prints a chunk of model text response. +func (d *Display) displayText(text string) { if d.state == stateThought { fmt.Fprintln(d.w) // end the thinking line } @@ -95,8 +96,8 @@ func (d *Display) DisplayText(text string) { fmt.Fprint(d.w, text) } -// DisplayThought prints a chunk of model thinking process. -func (d *Display) DisplayThought(text string) { +// displayThought prints a chunk of model thinking process. +func (d *Display) displayThought(text string) { if d.state != stateThought { if d.state == stateText { fmt.Fprintln(d.w) @@ -107,8 +108,8 @@ func (d *Display) DisplayThought(text string) { fmt.Fprint(d.w, text) } -// DisplaySystem prints a system/error message on a new line. -func (d *Display) DisplaySystem(text string) { +// displaySystem prints a system/error message on a new line. +func (d *Display) displaySystem(text string) { if d.state != stateNone { fmt.Fprintln(d.w) } @@ -116,6 +117,42 @@ func (d *Display) DisplaySystem(text string) { fmt.Fprintln(d.w, text) } +// Display prints a content block according to its type. +func (d *Display) Display(content *proto.Content) { + if content == nil { + return + } + switch o := content.Type.(type) { + case *proto.Content_Text: + d.displayText(o.Text.Text) + case *proto.Content_Confirmation: + // Let the confirmation prompt handle displaying the question. + case *proto.Content_ToolCall: + // No-op for cleaner CLI logs + case *proto.Content_ToolResult: + // Only print if the tool returned an error, otherwise skip + tr := o.ToolResult + if fr := tr.GetFunctionResult(); fr != nil { + if fr.GetResponse() != nil { + respMap := fr.GetResponse().AsMap() + if errStr, ok := respMap["error"]; ok { + d.displaySystem(fmt.Sprintf("[TOOL ERROR for %s]\n%v", fr.Name, errStr)) + } + } + } + case *proto.Content_Thought: + for _, summary := range o.Thought.GetSummary() { + if textContent := summary.GetText(); textContent != nil { + d.displayThought(textContent.Text) + } + } + case *proto.Content_Image, *proto.Content_Audio, *proto.Content_Video, *proto.Content_Document: + d.displaySystem(fmt.Sprintf("unsupported output type for display: %T", o)) + default: + d.displaySystem(fmt.Sprintf("unknown output type: %v", o)) + } +} + // FinishOutput completes the streaming output and shows info if provided func (d *Display) FinishOutput(info string) { if d.state != stateNone { diff --git a/cmd/ax/internal/display_test.go b/cmd/ax/internal/display_test.go index 2551170..3f84d60 100644 --- a/cmd/ax/internal/display_test.go +++ b/cmd/ax/internal/display_test.go @@ -17,17 +17,30 @@ package internal import ( "bytes" "testing" + + "github.com/google/ax/proto" ) func TestDisplay_Streaming(t *testing.T) { + textContent := func(txt string) *proto.Content { + return &proto.Content{Type: &proto.Content_Text{Text: &proto.TextContent{Text: txt}}} + } + thoughtContent := func(txt string) *proto.Content { + return &proto.Content{Type: &proto.Content_Thought{Thought: &proto.ThoughtContent{ + Summary: []*proto.ThoughtSummaryContent{ + {Type: &proto.ThoughtSummaryContent_Text{Text: &proto.TextContent{Text: txt}}}, + }, + }}} + } + t.Run("consecutive text chunks are concatenated", func(t *testing.T) { t.Parallel() var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayText("Hello ") - d.DisplayText("world") - d.DisplayText("!") + d.Display(textContent("Hello ")) + d.Display(textContent("world")) + d.Display(textContent("!")) got := buf.String() want := "Hello world!" @@ -41,8 +54,8 @@ func TestDisplay_Streaming(t *testing.T) { var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayThought("thinking ") - d.DisplayThought("deeply") + d.Display(thoughtContent("thinking ")) + d.Display(thoughtContent("deeply")) got := buf.String() want := "Thinking: thinking deeply" @@ -56,8 +69,8 @@ func TestDisplay_Streaming(t *testing.T) { var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayThought("thinking") - d.DisplayText("Hello") + d.Display(thoughtContent("thinking")) + d.Display(textContent("Hello")) got := buf.String() want := "Thinking: thinking\nHello" @@ -71,8 +84,8 @@ func TestDisplay_Streaming(t *testing.T) { var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayText("Hello") - d.DisplayThought("thinking") + d.Display(textContent("Hello")) + d.Display(thoughtContent("thinking")) got := buf.String() want := "Hello\nThinking: thinking" @@ -86,7 +99,7 @@ func TestDisplay_Streaming(t *testing.T) { var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayText("Hello") + d.Display(textContent("Hello")) d.FinishOutput("") got := buf.String() @@ -101,7 +114,7 @@ func TestDisplay_Streaming(t *testing.T) { var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayText("Hello") + d.Display(textContent("Hello")) d.FinishOutput("seq=1") got := buf.String() @@ -113,13 +126,13 @@ func TestDisplay_Streaming(t *testing.T) { } }) - t.Run("DisplaySystem resets state and prints newline", func(t *testing.T) { + t.Run("displaySystem resets state and prints newline", func(t *testing.T) { t.Parallel() var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayText("Hello") - d.DisplaySystem("system message") + d.Display(textContent("Hello")) + d.displaySystem("system message") got := buf.String() want := "Hello\nsystem message\n" @@ -133,7 +146,7 @@ func TestDisplay_Streaming(t *testing.T) { var buf bytes.Buffer d := NewDisplay("test-id", &buf) - d.DisplayText("Hello") + d.Display(textContent("Hello")) d.DisplayInput("prompt") got := buf.String() From 9fd2e64e4eebc8b266f4e2966495b951feee7497 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Wed, 17 Jun 2026 14:55:51 -0700 Subject: [PATCH 3/3] refactor: inline displayText and displayThought helper methods --- cmd/ax/internal/display.go | 60 ++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 32 deletions(-) diff --git a/cmd/ax/internal/display.go b/cmd/ax/internal/display.go index a0cbcf2..5c413ff 100644 --- a/cmd/ax/internal/display.go +++ b/cmd/ax/internal/display.go @@ -87,36 +87,6 @@ func (d *Display) DisplayInput(text string) { fmt.Fprintln(d.w) } -// displayText prints a chunk of model text response. -func (d *Display) displayText(text string) { - if d.state == stateThought { - fmt.Fprintln(d.w) // end the thinking line - } - d.state = stateText - fmt.Fprint(d.w, text) -} - -// displayThought prints a chunk of model thinking process. -func (d *Display) displayThought(text string) { - if d.state != stateThought { - if d.state == stateText { - fmt.Fprintln(d.w) - } - fmt.Fprint(d.w, "Thinking: ") - } - d.state = stateThought - fmt.Fprint(d.w, text) -} - -// displaySystem prints a system/error message on a new line. -func (d *Display) displaySystem(text string) { - if d.state != stateNone { - fmt.Fprintln(d.w) - } - d.state = stateNone - fmt.Fprintln(d.w, text) -} - // Display prints a content block according to its type. func (d *Display) Display(content *proto.Content) { if content == nil { @@ -124,11 +94,18 @@ func (d *Display) Display(content *proto.Content) { } switch o := content.Type.(type) { case *proto.Content_Text: - d.displayText(o.Text.Text) + if d.state == stateThought { + fmt.Fprintln(d.w) // end the thinking line + } + d.state = stateText + fmt.Fprint(d.w, o.Text.Text) + case *proto.Content_Confirmation: // Let the confirmation prompt handle displaying the question. + case *proto.Content_ToolCall: // No-op for cleaner CLI logs + case *proto.Content_ToolResult: // Only print if the tool returned an error, otherwise skip tr := o.ToolResult @@ -140,19 +117,38 @@ func (d *Display) Display(content *proto.Content) { } } } + case *proto.Content_Thought: for _, summary := range o.Thought.GetSummary() { if textContent := summary.GetText(); textContent != nil { - d.displayThought(textContent.Text) + if d.state != stateThought { + if d.state == stateText { + fmt.Fprintln(d.w) + } + fmt.Fprint(d.w, "Thinking: ") + } + d.state = stateThought + fmt.Fprint(d.w, textContent.Text) } } + case *proto.Content_Image, *proto.Content_Audio, *proto.Content_Video, *proto.Content_Document: d.displaySystem(fmt.Sprintf("unsupported output type for display: %T", o)) + default: d.displaySystem(fmt.Sprintf("unknown output type: %v", o)) } } +// displaySystem prints a system/error message on a new line. +func (d *Display) displaySystem(text string) { + if d.state != stateNone { + fmt.Fprintln(d.w) + } + d.state = stateNone + fmt.Fprintln(d.w, text) +} + // FinishOutput completes the streaming output and shows info if provided func (d *Display) FinishOutput(info string) { if d.state != stateNone {