Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cmd/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
ChatBasePath: viper.GetString(FlagChatBasePath),
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
InitialPrompt: viper.GetString(FlagInitialPrompt),
})
if err != nil {
return xerrors.Errorf("failed to create server: %w", err)
Expand Down Expand Up @@ -174,6 +175,7 @@ const (
FlagAllowedHosts = "allowed-hosts"
FlagAllowedOrigins = "allowed-origins"
FlagExit = "exit"
FlagInitialPrompt = "initial-prompt"
)

func CreateServerCmd() *cobra.Command {
Expand Down Expand Up @@ -211,6 +213,7 @@ func CreateServerCmd() *cobra.Command {
{FlagAllowedHosts, "a", []string{"localhost", "127.0.0.1", "[::1]"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
{FlagInitialPrompt, "I", "", "Initial prompt for the agent (recommended only if the agent doesn't support initial prompt in interaction mode)", "string"},
}

for _, spec := range flagSpecs {
Expand Down
17 changes: 15 additions & 2 deletions lib/httpapi/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ type ServerConfig struct {
ChatBasePath string
AllowedHosts []string
AllowedOrigins []string
InitialPrompt string
}

// Validate allowed hosts don't contain whitespace, commas, schemes, or ports.
Expand Down Expand Up @@ -230,7 +231,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
SnapshotInterval: snapshotInterval,
ScreenStabilityLength: 2 * time.Second,
FormatMessage: formatMessage,
})
}, config.InitialPrompt)
emitter := NewEventEmitter(1024)
s := &Server{
router: router,
Expand Down Expand Up @@ -306,7 +307,19 @@ func (s *Server) StartSnapshotLoop(ctx context.Context) {
s.conversation.StartSnapshotLoop(ctx)
go func() {
for {
s.emitter.UpdateStatusAndEmitChanges(s.conversation.Status())
currentStatus := s.conversation.Status()

// Send initial prompt when agent becomes stable for the first time
if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable {
if err := s.conversation.SendMessage(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil {
s.logger.Error("Failed to send initial prompt", "error", err)
} else {
s.conversation.InitialPromptSent = true
currentStatus = st.ConversationStatusChanging
s.logger.Info("Initial prompt sent successfully")
}
}
Comment on lines 312 to 321
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's an inherent race condition here: we will report "stable" status for a short time period before "changing". Is it possible to prevent this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think line 322 would prevent that:

currentStatus = st.ConversationStatusChanging

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see now, I assumed that the status updates were done in a separate goroutine but it's actually done in UpdateStatusAndEmitChanges() below in line 326.

I think it's worth adding a test for this behaviour so that we can validate that we don't send an extraneous status update.

I also think this logic might be better encapsulated inside the Conversation.

s.emitter.UpdateStatusAndEmitChanges(currentStatus)
s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages())
s.emitter.UpdateScreenAndEmitChanges(s.conversation.Screen())
time.Sleep(snapshotInterval)
Expand Down
8 changes: 7 additions & 1 deletion lib/screentracker/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ type Conversation struct {
messages []ConversationMessage
screenBeforeLastUserMessage string
lock sync.Mutex
// InitialPrompt is the initial prompt passed to the agent
InitialPrompt string
// InitialPromptSent keeps track if the InitialPrompt has been successfully sent to the agents
InitialPromptSent bool
}

type ConversationStatus string
Expand All @@ -94,7 +98,7 @@ func getStableSnapshotsThreshold(cfg ConversationConfig) int {
return threshold + 1
}

func NewConversation(ctx context.Context, cfg ConversationConfig) *Conversation {
func NewConversation(ctx context.Context, cfg ConversationConfig, initialPrompt string) *Conversation {
threshold := getStableSnapshotsThreshold(cfg)
c := &Conversation{
cfg: cfg,
Expand All @@ -107,6 +111,8 @@ func NewConversation(ctx context.Context, cfg ConversationConfig) *Conversation
Time: cfg.GetTime(),
},
},
InitialPrompt: initialPrompt,
InitialPromptSent: len(initialPrompt) == 0,
}
return c
}
Expand Down
4 changes: 2 additions & 2 deletions lib/screentracker/conversation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func statusTest(t *testing.T, params statusTestParams) {
if params.cfg.GetTime == nil {
params.cfg.GetTime = func() time.Time { return time.Now() }
}
c := st.NewConversation(ctx, params.cfg)
c := st.NewConversation(ctx, params.cfg, "")
assert.Equal(t, st.ConversationStatusInitializing, c.Status())

for i, step := range params.steps {
Expand Down Expand Up @@ -147,7 +147,7 @@ func TestMessages(t *testing.T) {
for _, opt := range opts {
opt(&cfg)
}
return st.NewConversation(context.Background(), cfg)
return st.NewConversation(context.Background(), cfg, "")
}

t.Run("messages are copied", func(t *testing.T) {
Expand Down