diff --git a/cagent-schema.json b/cagent-schema.json
index 677be8384..e7afbf9ce 100644
--- a/cagent-schema.json
+++ b/cagent-schema.json
@@ -593,7 +593,8 @@
           "api",
           "a2a",
           "lsp",
-          "user_prompt"
+          "user_prompt",
+          "switch_model"
         ]
       },
       "instruction": {
@@ -700,6 +701,17 @@
           "description": "Timeout in seconds for the fetch tool",
           "minimum": 1
         },
+        "models": {
+          "type": "array",
+          "description": "List of allowed model references for the switch_model tool. At least one model must be specified.",
+          "items": {
+            "type": "string"
+          },
+          "examples": [
+            ["fast_model", "powerful_model"],
+            ["openai/gpt-4o-mini", "anthropic/claude-sonnet-4-0"]
+          ]
+        },
         "url": {
           "type": "string",
           "description": "URL for the a2a tool",
@@ -757,7 +769,8 @@
           "memory",
           "script",
           "fetch",
-          "user_prompt"
+          "user_prompt",
+          "switch_model"
         ]
       }
     }
diff --git a/e2e/switch_model_test.go b/e2e/switch_model_test.go
new file mode 100644
index 000000000..0d86b4d07
--- /dev/null
+++ b/e2e/switch_model_test.go
@@ -0,0 +1,110 @@
+package e2e_test
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/docker/cagent/pkg/agent"
+	"github.com/docker/cagent/pkg/chat"
+	"github.com/docker/cagent/pkg/config"
+	"github.com/docker/cagent/pkg/runtime"
+	"github.com/docker/cagent/pkg/session"
+	"github.com/docker/cagent/pkg/teamloader"
+)
+
+// setupSwitchModelTest creates a runtime with model switching support.
+func setupSwitchModelTest(t *testing.T) (runtime.Runtime, *agent.Agent) {
+	t.Helper()
+
+	ctx := t.Context()
+	agentSource, err := config.Resolve("testdata/switch_model.yaml")
+	require.NoError(t, err)
+
+	_, runConfig := startRecordingAIProxy(t)
+	loadResult, err := teamloader.LoadWithConfig(ctx, agentSource, runConfig)
+	require.NoError(t, err)
+
+	modelSwitcherCfg := &runtime.ModelSwitcherConfig{
+		Models:             loadResult.Models,
+		Providers:          loadResult.Providers,
+		ModelsGateway:      runConfig.ModelsGateway,
+		EnvProvider:        runConfig.EnvProvider(),
+		AgentDefaultModels: loadResult.AgentDefaultModels,
+	}
+
+	rt, err := runtime.New(loadResult.Team, runtime.WithModelSwitcherConfig(modelSwitcherCfg))
+	require.NoError(t, err)
+
+	rootAgent, err := loadResult.Team.Agent("root")
+	require.NoError(t, err)
+
+	return rt, rootAgent
+}
+
+// findSwitchModelCall searches session messages for a switch_model tool call containing the given model name.
+func findSwitchModelCall(sess *session.Session, modelName string) bool {
+	for _, msg := range sess.GetAllMessages() {
+		if msg.Message.Role != chat.MessageRoleAssistant || msg.Message.ToolCalls == nil {
+			continue
+		}
+		for _, tc := range msg.Message.ToolCalls {
+			if tc.Function.Name == "switch_model" && strings.Contains(tc.Function.Arguments, modelName) {
+				return true
+			}
+		}
+	}
+	return false
+}
+
+// TestSwitchModel_AgentCanSwitchModels verifies that an agent can use the switch_model tool
+// to change between models during a conversation.
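+// Provider traffic is replayed from the recorded cassette in
+// testdata/cassettes/TestSwitchModel_AgentCanSwitchModels.yaml; the assertions
+// below look for a tool call of the (illustrative) shape
+// switch_model({"model": "smart"}) in the session history.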
+func TestSwitchModel_AgentCanSwitchModels(t *testing.T) { + t.Parallel() + + ctx := t.Context() + rt, _ := setupSwitchModelTest(t) + + // Switch to smart model + sess := session.New(session.WithUserMessage("Switch to the smart model, then say hi")) + _, err := rt.Run(ctx, sess) + require.NoError(t, err) + + assert.True(t, findSwitchModelCall(sess, "smart"), "Expected switch_model tool call with 'smart' model") + assert.NotEmpty(t, sess.GetLastAssistantMessageContent(), "Expected a response after switching") + + // Switch back to fast model + sess.AddMessage(session.UserMessage("Now switch back to the fast model and say goodbye")) + _, err = rt.Run(ctx, sess) + require.NoError(t, err) + + assert.True(t, findSwitchModelCall(sess, "fast"), "Expected switch_model tool call with 'fast' model") + assert.NotEmpty(t, sess.GetLastAssistantMessageContent(), "Expected a response after switching back") +} + +// TestSwitchModel_ModelActuallyChanges verifies that after calling switch_model, +// the agent's model object is updated to the new model. +func TestSwitchModel_ModelActuallyChanges(t *testing.T) { + t.Parallel() + + ctx := t.Context() + rt, rootAgent := setupSwitchModelTest(t) + + assert.Contains(t, rootAgent.Model().ID(), "gpt-4o-mini", "Should start with gpt-4o-mini") + + // Switch to smart model + sess := session.New(session.WithUserMessage("Use the switch_model tool to switch to smart model, then just say 'done'")) + _, err := rt.Run(ctx, sess) + require.NoError(t, err) + + assert.Contains(t, rootAgent.Model().ID(), "claude", "Model should have changed to claude") + + // Verify the new model works + sess.AddMessage(session.UserMessage("What is 2+2? Answer with just the number.")) + _, err = rt.Run(ctx, sess) + require.NoError(t, err) + + assert.NotEmpty(t, sess.GetLastAssistantMessageContent()) +} diff --git a/e2e/testdata/cassettes/TestSwitchModel_AgentCanSwitchModels.yaml b/e2e/testdata/cassettes/TestSwitchModel_AgentCanSwitchModels.yaml new file mode 100644 index 000000000..1eb30dc59 --- /dev/null +++ b/e2e/testdata/cassettes/TestSwitchModel_AgentCanSwitchModels.yaml @@ -0,0 +1,190 @@ +--- +version: 2 +interactions: + - id: 0 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: api.openai.com + body: '{"messages":[{"content":"You are a helpful assistant that can switch between models.\nWhen asked to switch to a better/smarter model, use the switch_model tool to switch to \"smart\".\nWhen asked to switch to the default/fast model, use the switch_model tool to switch to \"fast\".\nAfter switching, respond briefly confirming the switch and then answer any question.\nKeep responses very short (one sentence max).\n","role":"system"},{"content":"## Model Switching Guidelines\n\nYou have access to multiple AI models and can switch between them strategically.\n\n### When to Consider Switching Models\n\n**Switch to a faster/cheaper model when:**\n- Performing simple, routine tasks (formatting, basic Q&A, short summaries)\n- The current task doesn''t require advanced reasoning\n- Processing straightforward requests that any model can handle well\n- Optimizing for response speed or cost efficiency\n\n**Switch to a more powerful model when:**\n- Facing complex reasoning or multi-step problems\n- Writing or reviewing code that requires careful analysis\n- Handling nuanced or ambiguous requests\n- Generating detailed technical content\n- The current model is struggling with the task quality\n\n**Switch back to the default model when:**\n- A specialized task is 
complete\n- Returning to general conversation\n- The extra capability is no longer needed\n\n### Best Practices\n\n1. Check the tool description to see available models and which one is currently active\n2. Don''t switch unnecessarily - there''s overhead in changing models\n3. Consider switching proactively before a complex task rather than after struggling\n4. When in doubt about task complexity, prefer the more capable model","role":"system"},{"content":"Switch to the smart model, then say hi","role":"user"}],"model":"gpt-4o-mini","max_tokens":16384,"parallel_tool_calls":true,"stream_options":{"include_usage":true},"tools":[{"function":{"name":"switch_model","description":"Switch the AI model used for subsequent responses.\n\n**Available models:**\n- fast (default) (current)\n- smart\n\nOnly the models listed above can be selected. Any other model will be rejected.","parameters":{"additionalProperties":false,"properties":{"model":{"description":"The model to switch to. Must be one of the allowed models listed in the tool description.","type":"string"}},"required":["model"],"type":"object"}},"type":"function"}],"stream":true}' + url: https://api.openai.com/v1/chat/completions + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + data: {"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_7kJnnLoJFIbRnRHe1CoYTiXn","type":"function","function":{"name":"switch_model","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"47P"} + + data: {"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"Brogprarad7mtQ"} + + data: {"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"model"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"HyE8RFFAvAHg"} + + data: {"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"fVnOy0tVTM1B"} + + data: {"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"smart"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"OcEq07Zj7pQR"} + + data: 
{"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"8mxN23j9adGW6p"} + + data: {"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null,"obfuscation":"Kk3toyFyDAr59xO"} + + data: {"id":"chatcmpl-D3onWDXckNsqxEavlG94tgXadv48f","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[],"usage":{"prompt_tokens":424,"completion_tokens":14,"total_tokens":438,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}},"obfuscation":"hlCnskHF"} + + data: [DONE] + + headers: {} + status: 200 OK + code: 200 + duration: 925.604834ms + - id: 1 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: api.anthropic.com + body: '{"max_tokens":64000,"messages":[{"content":[{"text":"Switch to the smart model, then say hi","type":"text"}],"role":"user"},{"content":[{"id":"call_7kJnnLoJFIbRnRHe1CoYTiXn","input":{"model":"smart"},"name":"switch_model","cache_control":{"type":"ephemeral"},"type":"tool_use"}],"role":"assistant"},{"content":[{"tool_use_id":"call_7kJnnLoJFIbRnRHe1CoYTiXn","is_error":false,"cache_control":{"type":"ephemeral"},"content":[{"text":"Switched model from \"fast\" to \"smart\".","type":"text"}],"type":"tool_result"}],"role":"user"}],"model":"claude-sonnet-4-20250514","system":[{"text":"You are a helpful assistant that can switch between models.\nWhen asked to switch to a better/smarter model, use the switch_model tool to switch to \"smart\".\nWhen asked to switch to the default/fast model, use the switch_model tool to switch to \"fast\".\nAfter switching, respond briefly confirming the switch and then answer any question.\nKeep responses very short (one sentence max).","type":"text"},{"text":"## Model Switching Guidelines\n\nYou have access to multiple AI models and can switch between them strategically.\n\n### When to Consider Switching Models\n\n**Switch to a faster/cheaper model when:**\n- Performing simple, routine tasks (formatting, basic Q\u0026A, short summaries)\n- The current task doesn''t require advanced reasoning\n- Processing straightforward requests that any model can handle well\n- Optimizing for response speed or cost efficiency\n\n**Switch to a more powerful model when:**\n- Facing complex reasoning or multi-step problems\n- Writing or reviewing code that requires careful analysis\n- Handling nuanced or ambiguous requests\n- Generating detailed technical content\n- The current model is struggling with the task quality\n\n**Switch back to the default model when:**\n- A specialized task is complete\n- Returning to general conversation\n- The extra capability is no longer needed\n\n### Best Practices\n\n1. Check the tool description to see available models and which one is currently active\n2. Don''t switch unnecessarily - there''s overhead in changing models\n3. 
Consider switching proactively before a complex task rather than after struggling\n4. When in doubt about task complexity, prefer the more capable model","cache_control":{"type":"ephemeral"},"type":"text"}],"tools":[{"input_schema":{"properties":{"model":{"description":"The model to switch to. Must be one of the allowed models listed in the tool description.","type":"string"}},"required":["model"],"type":"object"},"name":"switch_model","description":"Switch the AI model used for subsequent responses.\n\n**Available models:**\n- fast (default)\n- smart (current)\n\nOnly the models listed above can be selected. Any other model will be rejected."}],"stream":true}' + url: https://api.anthropic.com/v1/messages + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + event: message_start + data: {"type":"message_start","message":{"model":"claude-sonnet-4-20250514","id":"msg_01MD9TRKxyqCVZAZ7ensFgbo","type":"message","role":"assistant","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":864,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":2,"service_tier":"standard"}} } + + event: content_block_start + data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} } + + event: ping + data: {"type": "ping"} + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Switched"}} + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to smart"} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" model -"} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" hi"} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" there"} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"!"} } + + event: content_block_stop + data: {"type":"content_block_stop","index":0 } + + event: message_delta + data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":864,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":12} } + + event: message_stop + data: {"type":"message_stop" } + + headers: {} + status: 200 OK + code: 200 + duration: 2.564459334s + - id: 2 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: api.anthropic.com + body: '{"max_tokens":64000,"messages":[{"content":[{"text":"Switch to the smart model, then say hi","type":"text"}],"role":"user"},{"content":[{"id":"call_7kJnnLoJFIbRnRHe1CoYTiXn","input":{"model":"smart"},"name":"switch_model","type":"tool_use"}],"role":"assistant"},{"content":[{"tool_use_id":"call_7kJnnLoJFIbRnRHe1CoYTiXn","is_error":false,"content":[{"text":"Switched model from \"fast\" to \"smart\".","type":"text"}],"type":"tool_result"}],"role":"user"},{"content":[{"text":"Switched to smart model - hi there!","cache_control":{"type":"ephemeral"},"type":"text"}],"role":"assistant"},{"content":[{"text":"Now switch back to the fast model and say goodbye","cache_control":{"type":"ephemeral"},"type":"text"}],"role":"user"}],"model":"claude-sonnet-4-20250514","system":[{"text":"You are a helpful 
assistant that can switch between models.\nWhen asked to switch to a better/smarter model, use the switch_model tool to switch to \"smart\".\nWhen asked to switch to the default/fast model, use the switch_model tool to switch to \"fast\".\nAfter switching, respond briefly confirming the switch and then answer any question.\nKeep responses very short (one sentence max).","type":"text"},{"text":"## Model Switching Guidelines\n\nYou have access to multiple AI models and can switch between them strategically.\n\n### When to Consider Switching Models\n\n**Switch to a faster/cheaper model when:**\n- Performing simple, routine tasks (formatting, basic Q\u0026A, short summaries)\n- The current task doesn''t require advanced reasoning\n- Processing straightforward requests that any model can handle well\n- Optimizing for response speed or cost efficiency\n\n**Switch to a more powerful model when:**\n- Facing complex reasoning or multi-step problems\n- Writing or reviewing code that requires careful analysis\n- Handling nuanced or ambiguous requests\n- Generating detailed technical content\n- The current model is struggling with the task quality\n\n**Switch back to the default model when:**\n- A specialized task is complete\n- Returning to general conversation\n- The extra capability is no longer needed\n\n### Best Practices\n\n1. Check the tool description to see available models and which one is currently active\n2. Don''t switch unnecessarily - there''s overhead in changing models\n3. Consider switching proactively before a complex task rather than after struggling\n4. When in doubt about task complexity, prefer the more capable model","cache_control":{"type":"ephemeral"},"type":"text"}],"tools":[{"input_schema":{"properties":{"model":{"description":"The model to switch to. Must be one of the allowed models listed in the tool description.","type":"string"}},"required":["model"],"type":"object"},"name":"switch_model","description":"Switch the AI model used for subsequent responses.\n\n**Available models:**\n- fast (default)\n- smart (current)\n\nOnly the models listed above can be selected. 
Any other model will be rejected."}],"stream":true}' + url: https://api.anthropic.com/v1/messages + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + event: message_start + data: {"type":"message_start","message":{"model":"claude-sonnet-4-20250514","id":"msg_01GwhF1NoMHrKCtnok3BqVmV","type":"message","role":"assistant","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":889,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":8,"service_tier":"standard"}} } + + event: content_block_start + data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_019FLZJHzXX2txzzYkeYERu3","name":"switch_model","input":{}} } + + event: ping + data: {"type": "ping"} + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\"model\": \"fast"} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"\"}"} } + + event: content_block_stop + data: {"type":"content_block_stop","index":0 } + + event: message_delta + data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":889,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":53} } + + event: message_stop + data: {"type":"message_stop" } + + headers: {} + status: 200 OK + code: 200 + duration: 1.50779325s + - id: 3 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: api.openai.com + body: '{"messages":[{"content":"You are a helpful assistant that can switch between models.\nWhen asked to switch to a better/smarter model, use the switch_model tool to switch to \"smart\".\nWhen asked to switch to the default/fast model, use the switch_model tool to switch to \"fast\".\nAfter switching, respond briefly confirming the switch and then answer any question.\nKeep responses very short (one sentence max).\n","role":"system"},{"content":"## Model Switching Guidelines\n\nYou have access to multiple AI models and can switch between them strategically.\n\n### When to Consider Switching Models\n\n**Switch to a faster/cheaper model when:**\n- Performing simple, routine tasks (formatting, basic Q&A, short summaries)\n- The current task doesn''t require advanced reasoning\n- Processing straightforward requests that any model can handle well\n- Optimizing for response speed or cost efficiency\n\n**Switch to a more powerful model when:**\n- Facing complex reasoning or multi-step problems\n- Writing or reviewing code that requires careful analysis\n- Handling nuanced or ambiguous requests\n- Generating detailed technical content\n- The current model is struggling with the task quality\n\n**Switch back to the default model when:**\n- A specialized task is complete\n- Returning to general conversation\n- The extra capability is no longer needed\n\n### Best Practices\n\n1. Check the tool description to see available models and which one is currently active\n2. Don''t switch unnecessarily - there''s overhead in changing models\n3. Consider switching proactively before a complex task rather than after struggling\n4. 
When in doubt about task complexity, prefer the more capable model","role":"system"},{"content":"Switch to the smart model, then say hi","role":"user"},{"tool_calls":[{"id":"call_7kJnnLoJFIbRnRHe1CoYTiXn","function":{"arguments":"{\"model\":\"smart\"}","name":"switch_model"},"type":"function"}],"role":"assistant"},{"content":"Switched model from \"fast\" to \"smart\".","tool_call_id":"call_7kJnnLoJFIbRnRHe1CoYTiXn","role":"tool"},{"content":"Switched to smart model - hi there!","role":"assistant"},{"content":"Now switch back to the fast model and say goodbye","role":"user"},{"tool_calls":[{"id":"toolu_019FLZJHzXX2txzzYkeYERu3","function":{"arguments":"{\"model\": \"fast\"}","name":"switch_model"},"type":"function"}],"role":"assistant"},{"content":"Switched model from \"smart\" to \"fast\".","tool_call_id":"toolu_019FLZJHzXX2txzzYkeYERu3","role":"tool"}],"model":"gpt-4o-mini","max_tokens":16384,"parallel_tool_calls":true,"stream_options":{"include_usage":true},"tools":[{"function":{"name":"switch_model","description":"Switch the AI model used for subsequent responses.\n\n**Available models:**\n- fast (default) (current)\n- smart\n\nOnly the models listed above can be selected. Any other model will be rejected.","parameters":{"additionalProperties":false,"properties":{"model":{"description":"The model to switch to. Must be one of the allowed models listed in the tool description.","type":"string"}},"required":["model"],"type":"object"}},"type":"function"}],"stream":true}' + url: https://api.openai.com/v1/chat/completions + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"EBiZHklkg"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":"Sw"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"glrvUYASg"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":"itched"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"xLW9V"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":" back"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"tLdkd3"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"C1BtunrF"} + + data: 
{"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":" fast"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"eqHrjd"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":" model"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"hKBvv"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":" -"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"wUJj0lES4"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":" goodbye"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"Saf"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"ZA7kxYZV20"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null,"obfuscation":"rwnJA"} + + data: {"id":"chatcmpl-D3oncDuEbfU7fwzaXtBpP6KzBraTH","object":"chat.completion.chunk","created":1769802628,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[],"usage":{"prompt_tokens":517,"completion_tokens":10,"total_tokens":527,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}},"obfuscation":"xeWAz0J3"} + + data: [DONE] + + headers: {} + status: 200 OK + code: 200 + duration: 825.417875ms diff --git a/e2e/testdata/cassettes/TestSwitchModel_ModelActuallyChanges.yaml b/e2e/testdata/cassettes/TestSwitchModel_ModelActuallyChanges.yaml new file mode 100644 index 000000000..77d8210a2 --- /dev/null +++ b/e2e/testdata/cassettes/TestSwitchModel_ModelActuallyChanges.yaml @@ -0,0 +1,126 @@ +--- +version: 2 +interactions: + - id: 0 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: api.openai.com + body: '{"messages":[{"content":"You are a helpful assistant that can switch between models.\nWhen asked to switch to a better/smarter model, use the switch_model tool to switch to \"smart\".\nWhen asked to switch to the default/fast model, use the switch_model tool to switch to \"fast\".\nAfter switching, respond briefly confirming the switch and then answer any question.\nKeep responses very short (one sentence max).\n","role":"system"},{"content":"## Model Switching Guidelines\n\nYou have access to multiple AI models and can switch between 
them strategically.\n\n### When to Consider Switching Models\n\n**Switch to a faster/cheaper model when:**\n- Performing simple, routine tasks (formatting, basic Q&A, short summaries)\n- The current task doesn''t require advanced reasoning\n- Processing straightforward requests that any model can handle well\n- Optimizing for response speed or cost efficiency\n\n**Switch to a more powerful model when:**\n- Facing complex reasoning or multi-step problems\n- Writing or reviewing code that requires careful analysis\n- Handling nuanced or ambiguous requests\n- Generating detailed technical content\n- The current model is struggling with the task quality\n\n**Switch back to the default model when:**\n- A specialized task is complete\n- Returning to general conversation\n- The extra capability is no longer needed\n\n### Best Practices\n\n1. Check the tool description to see available models and which one is currently active\n2. Don''t switch unnecessarily - there''s overhead in changing models\n3. Consider switching proactively before a complex task rather than after struggling\n4. When in doubt about task complexity, prefer the more capable model","role":"system"},{"content":"Use the switch_model tool to switch to smart model, then just say ''done''","role":"user"}],"model":"gpt-4o-mini","max_tokens":16384,"parallel_tool_calls":true,"stream_options":{"include_usage":true},"tools":[{"function":{"name":"switch_model","description":"Switch the AI model used for subsequent responses.\n\n**Available models:**\n- fast (default) (current)\n- smart\n\nOnly the models listed above can be selected. Any other model will be rejected.","parameters":{"additionalProperties":false,"properties":{"model":{"description":"The model to switch to. Must be one of the allowed models listed in the tool description.","type":"string"}},"required":["model"],"type":"object"}},"type":"function"}],"stream":true}' + url: https://api.openai.com/v1/chat/completions + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + data: {"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_f4epGCysqn5sr2YvsxOSbSgd","type":"function","function":{"name":"switch_model","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"PSx"} + + data: {"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"fzow185TGFrgOP"} + + data: {"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"model"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"DdF9c1FUwkdS"} + + data: 
{"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"9sVOZKCWUxmi"} + + data: {"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"smart"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"HqItFmdno1nY"} + + data: {"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"ccabCw5RGwhlbI"} + + data: {"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null,"obfuscation":"5XcPTFKUh1OVXGN"} + + data: {"id":"chatcmpl-D3onWKk4DaJZDBtyG0jSyKTvAjWCN","object":"chat.completion.chunk","created":1769802622,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_1590f93f9d","choices":[],"usage":{"prompt_tokens":432,"completion_tokens":14,"total_tokens":446,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}},"obfuscation":"Xr45WNpE"} + + data: [DONE] + + headers: {} + status: 200 OK + code: 200 + duration: 977.696458ms + - id: 1 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: api.anthropic.com + body: '{"max_tokens":64000,"messages":[{"content":[{"text":"Use the switch_model tool to switch to smart model, then just say ''done''","type":"text"}],"role":"user"},{"content":[{"id":"call_f4epGCysqn5sr2YvsxOSbSgd","input":{"model":"smart"},"name":"switch_model","cache_control":{"type":"ephemeral"},"type":"tool_use"}],"role":"assistant"},{"content":[{"tool_use_id":"call_f4epGCysqn5sr2YvsxOSbSgd","is_error":false,"cache_control":{"type":"ephemeral"},"content":[{"text":"Switched model from \"fast\" to \"smart\".","type":"text"}],"type":"tool_result"}],"role":"user"}],"model":"claude-sonnet-4-20250514","system":[{"text":"You are a helpful assistant that can switch between models.\nWhen asked to switch to a better/smarter model, use the switch_model tool to switch to \"smart\".\nWhen asked to switch to the default/fast model, use the switch_model tool to switch to \"fast\".\nAfter switching, respond briefly confirming the switch and then answer any question.\nKeep responses very short (one sentence max).","type":"text"},{"text":"## Model Switching Guidelines\n\nYou have access to multiple AI models and can switch between them strategically.\n\n### When to Consider Switching Models\n\n**Switch to a faster/cheaper model when:**\n- Performing simple, routine tasks (formatting, basic Q\u0026A, short summaries)\n- The current task doesn''t require advanced 
reasoning\n- Processing straightforward requests that any model can handle well\n- Optimizing for response speed or cost efficiency\n\n**Switch to a more powerful model when:**\n- Facing complex reasoning or multi-step problems\n- Writing or reviewing code that requires careful analysis\n- Handling nuanced or ambiguous requests\n- Generating detailed technical content\n- The current model is struggling with the task quality\n\n**Switch back to the default model when:**\n- A specialized task is complete\n- Returning to general conversation\n- The extra capability is no longer needed\n\n### Best Practices\n\n1. Check the tool description to see available models and which one is currently active\n2. Don''t switch unnecessarily - there''s overhead in changing models\n3. Consider switching proactively before a complex task rather than after struggling\n4. When in doubt about task complexity, prefer the more capable model","cache_control":{"type":"ephemeral"},"type":"text"}],"tools":[{"input_schema":{"properties":{"model":{"description":"The model to switch to. Must be one of the allowed models listed in the tool description.","type":"string"}},"required":["model"],"type":"object"},"name":"switch_model","description":"Switch the AI model used for subsequent responses.\n\n**Available models:**\n- fast (default)\n- smart (current)\n\nOnly the models listed above can be selected. Any other model will be rejected."}],"stream":true}' + url: https://api.anthropic.com/v1/messages + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + event: message_start + data: {"type":"message_start","message":{"model":"claude-sonnet-4-20250514","id":"msg_01Vg6kQujtxt1Trp3nhDeLoU","type":"message","role":"assistant","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":874,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":1,"service_tier":"standard"}} } + + event: content_block_start + data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + + event: ping + data: {"type": "ping"} + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Done"} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"."} } + + event: content_block_stop + data: {"type":"content_block_stop","index":0 } + + event: message_delta + data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":874,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":5} } + + event: message_stop + data: {"type":"message_stop" } + + headers: {} + status: 200 OK + code: 200 + duration: 1.468132875s + - id: 2 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 0 + host: api.anthropic.com + body: '{"max_tokens":64000,"messages":[{"content":[{"text":"Use the switch_model tool to switch to smart model, then just say ''done''","type":"text"}],"role":"user"},{"content":[{"id":"call_f4epGCysqn5sr2YvsxOSbSgd","input":{"model":"smart"},"name":"switch_model","type":"tool_use"}],"role":"assistant"},{"content":[{"tool_use_id":"call_f4epGCysqn5sr2YvsxOSbSgd","is_error":false,"content":[{"text":"Switched model from \"fast\" to 
\"smart\".","type":"text"}],"type":"tool_result"}],"role":"user"},{"content":[{"text":"Done.","cache_control":{"type":"ephemeral"},"type":"text"}],"role":"assistant"},{"content":[{"text":"What is 2+2? Answer with just the number.","cache_control":{"type":"ephemeral"},"type":"text"}],"role":"user"}],"model":"claude-sonnet-4-20250514","system":[{"text":"You are a helpful assistant that can switch between models.\nWhen asked to switch to a better/smarter model, use the switch_model tool to switch to \"smart\".\nWhen asked to switch to the default/fast model, use the switch_model tool to switch to \"fast\".\nAfter switching, respond briefly confirming the switch and then answer any question.\nKeep responses very short (one sentence max).","type":"text"},{"text":"## Model Switching Guidelines\n\nYou have access to multiple AI models and can switch between them strategically.\n\n### When to Consider Switching Models\n\n**Switch to a faster/cheaper model when:**\n- Performing simple, routine tasks (formatting, basic Q\u0026A, short summaries)\n- The current task doesn''t require advanced reasoning\n- Processing straightforward requests that any model can handle well\n- Optimizing for response speed or cost efficiency\n\n**Switch to a more powerful model when:**\n- Facing complex reasoning or multi-step problems\n- Writing or reviewing code that requires careful analysis\n- Handling nuanced or ambiguous requests\n- Generating detailed technical content\n- The current model is struggling with the task quality\n\n**Switch back to the default model when:**\n- A specialized task is complete\n- Returning to general conversation\n- The extra capability is no longer needed\n\n### Best Practices\n\n1. Check the tool description to see available models and which one is currently active\n2. Don''t switch unnecessarily - there''s overhead in changing models\n3. Consider switching proactively before a complex task rather than after struggling\n4. When in doubt about task complexity, prefer the more capable model","cache_control":{"type":"ephemeral"},"type":"text"}],"tools":[{"input_schema":{"properties":{"model":{"description":"The model to switch to. Must be one of the allowed models listed in the tool description.","type":"string"}},"required":["model"],"type":"object"},"name":"switch_model","description":"Switch the AI model used for subsequent responses.\n\n**Available models:**\n- fast (default)\n- smart (current)\n\nOnly the models listed above can be selected. 
Any other model will be rejected."}],"stream":true}' + url: https://api.anthropic.com/v1/messages + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + event: message_start + data: {"type":"message_start","message":{"model":"claude-sonnet-4-20250514","id":"msg_017Ko3yzvcdekdx9jBjSoPqY","type":"message","role":"assistant","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":895,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":1,"service_tier":"standard"}} } + + event: content_block_start + data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} } + + event: ping + data: {"type": "ping"} + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"4"} } + + event: content_block_stop + data: {"type":"content_block_stop","index":0 } + + event: message_delta + data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":895,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":5} } + + event: message_stop + data: {"type":"message_stop" } + + headers: {} + status: 200 OK + code: 200 + duration: 1.430636s diff --git a/e2e/testdata/switch_model.yaml b/e2e/testdata/switch_model.yaml new file mode 100644 index 000000000..20dfa515a --- /dev/null +++ b/e2e/testdata/switch_model.yaml @@ -0,0 +1,23 @@ +version: 4 + +models: + fast: + provider: openai + model: gpt-4o-mini + smart: + provider: anthropic + model: claude-sonnet-4-0 + +agents: + root: + model: fast + description: Agent with model switching capability + instruction: | + You are a helpful assistant that can switch between models. + When asked to switch to a better/smarter model, use the switch_model tool to switch to "smart". + When asked to switch to the default/fast model, use the switch_model tool to switch to "fast". + After switching, respond briefly confirming the switch and then answer any question. + Keep responses very short (one sentence max). + toolsets: + - type: switch_model + models: [fast, smart] diff --git a/examples/switch_model.yaml b/examples/switch_model.yaml new file mode 100644 index 000000000..4723f9dd5 --- /dev/null +++ b/examples/switch_model.yaml @@ -0,0 +1,40 @@ +#!/usr/bin/env cagent run + +# This example demonstrates how an agent can dynamically switch between different +# AI models during a conversation. This is useful for: +# - Cost optimization: Use cheaper models for simple tasks +# - Performance: Use more powerful models for complex reasoning +# - Specialization: Use domain-specific models when appropriate + +models: + fast: + provider: anthropic + model: claude-haiku-4-5-20251001 + + powerful: + provider: anthropic + model: claude-sonnet-4-5-20250929 + +agents: + root: + model: fast + description: An adaptive assistant that can switch models based on task complexity + instruction: | + You are a helpful assistant with the ability to switch between different AI models. 
+ + Use switch_model strategically: + - For simple tasks (formatting, basic Q&A, summaries), use the faster/cheaper model + - For complex tasks (code generation, analysis, reasoning), use the more powerful model + - The switch_model tool description shows available models and which one is current + - After completing a specialized task, consider switching back to the default model + + Example workflow: + 1. User asks a simple question -> switch to 'fast' for efficiency + 2. User asks for complex code -> switch to 'powerful' for quality + 3. Task complete -> switch back to 'fast' + + toolsets: + - type: filesystem + - type: shell + - type: switch_model + models: [fast, powerful] diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index aab024847..b281361d1 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -269,6 +269,9 @@ type Toolset struct { // For the `fetch` tool Timeout int `json:"timeout,omitempty"` + + // For the `switch_model` tool - list of allowed models + Models []string `json:"models,omitempty"` } func (t *Toolset) UnmarshalYAML(unmarshal func(any) error) error { diff --git a/pkg/config/latest/validate.go b/pkg/config/latest/validate.go index 752c29987..50a9fb2bb 100644 --- a/pkg/config/latest/validate.go +++ b/pkg/config/latest/validate.go @@ -2,6 +2,7 @@ package latest import ( "errors" + "fmt" "strings" ) @@ -56,6 +57,9 @@ func (t *Toolset) validate() error { if t.Shared && t.Type != "todo" { return errors.New("shared can only be used with type 'todo'") } + if len(t.Models) > 0 && t.Type != "switch_model" { + return errors.New("models can only be used with type 'switch_model'") + } if t.Command != "" && t.Type != "mcp" && t.Type != "lsp" { return errors.New("command can only be used with type 'mcp' or 'lsp'") } @@ -86,6 +90,15 @@ func (t *Toolset) validate() error { if t.Sandbox != nil && len(t.Sandbox.Paths) == 0 { return errors.New("sandbox requires at least one path to be set") } + case "switch_model": + if len(t.Models) == 0 { + return errors.New("switch_model toolset requires at least one model") + } + for i, m := range t.Models { + if strings.TrimSpace(m) == "" { + return fmt.Errorf("switch_model toolset: model at index %d is empty", i) + } + } case "memory": if t.Path == "" { return errors.New("memory toolset requires a path to be set") diff --git a/pkg/config/latest/validate_test.go b/pkg/config/latest/validate_test.go index c68d95cf1..e700bff0e 100644 --- a/pkg/config/latest/validate_test.go +++ b/pkg/config/latest/validate_test.go @@ -189,3 +189,121 @@ agents: }) } } + +func TestToolset_Validate_SwitchModel(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config string + wantErr string + }{ + { + name: "valid switch_model with models", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: switch_model + models: [fast, powerful] +`, + wantErr: "", + }, + { + name: "valid switch_model with single model", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: switch_model + models: + - only_one +`, + wantErr: "", + }, + { + name: "switch_model without models", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: switch_model +`, + wantErr: "switch_model toolset requires at least one model", + }, + { + name: "switch_model with empty models list", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: switch_model + models: [] +`, + wantErr: "switch_model toolset 
requires at least one model", + }, + { + name: "switch_model with empty string in models", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: switch_model + models: [fast, "", powerful] +`, + wantErr: "switch_model toolset: model at index 1 is empty", + }, + { + name: "switch_model with whitespace-only string in models", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: switch_model + models: [fast, " ", powerful] +`, + wantErr: "switch_model toolset: model at index 1 is empty", + }, + { + name: "models on non-switch_model toolset", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: shell + models: [fast, powerful] +`, + wantErr: "models can only be used with type 'switch_model'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var cfg Config + err := yaml.Unmarshal([]byte(tt.config), &cfg) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index f84ea553f..76110f1f5 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -83,172 +83,101 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st return nil } - // Check if modelRef is a named model from config - if modelConfig, exists := r.modelSwitcherCfg.Models[modelRef]; exists { - // Check if this is an alloy model (no provider, comma-separated models) - if isAlloyModelConfig(modelConfig) { - providers, err := r.createProvidersFromAlloyConfig(ctx, modelConfig) - if err != nil { - return fmt.Errorf("failed to create alloy model from config: %w", err) - } - a.SetModelOverride(providers...) - slog.Info("Set agent model override (alloy)", "agent", agentName, "config_name", modelRef, "model_count", len(providers)) - return nil - } + // Parse model references (handles both single and comma-separated refs) + refs := splitModelRefs(modelRef) - prov, err := r.createProviderFromConfig(ctx, &modelConfig) + // Single reference - could be a config name or inline "provider/model" + if len(refs) == 1 { + prov, err := r.resolveModelRef(ctx, refs[0]) if err != nil { - return fmt.Errorf("failed to create model from config: %w", err) + return err } a.SetModelOverride(prov) - slog.Info("Set agent model override", "agent", agentName, "model", prov.ID(), "config_name", modelRef) + slog.Info("Set agent model override", "agent", agentName, "model", prov.ID()) return nil } - // Check if this is an inline alloy spec (comma-separated provider/model specs) - // e.g., "openai/gpt-4o,anthropic/claude-sonnet-4-0" - if isInlineAlloySpec(modelRef) { - providers, err := r.createProvidersFromInlineAlloy(ctx, modelRef) - if err != nil { - return fmt.Errorf("failed to create inline alloy model: %w", err) - } - a.SetModelOverride(providers...) 
- slog.Info("Set agent model override (inline alloy)", "agent", agentName, "model_count", len(providers)) - return nil - } - - // Try parsing as inline spec (provider/model) - providerName, modelName, ok := strings.Cut(modelRef, "/") - if !ok { - return fmt.Errorf("invalid model reference %q: expected a model name from config or 'provider/model' format", modelRef) - } - - inlineCfg := &latest.ModelConfig{ - Provider: providerName, - Model: modelName, - } - prov, err := r.createProviderFromConfig(ctx, inlineCfg) + // Multiple references - create an alloy (multiple providers) + providers, err := r.resolveModelRefs(ctx, refs) if err != nil { - return fmt.Errorf("failed to create inline model: %w", err) + return fmt.Errorf("failed to create alloy model: %w", err) } - a.SetModelOverride(prov) - slog.Info("Set agent model override (inline)", "agent", agentName, "model", prov.ID()) + a.SetModelOverride(providers...) + slog.Info("Set agent model override (alloy)", "agent", agentName, "model_count", len(providers)) return nil } -// isAlloyModelConfig checks if a model config is an alloy model (multiple models). -func isAlloyModelConfig(cfg latest.ModelConfig) bool { - return cfg.Provider == "" && strings.Contains(cfg.Model, ",") +// splitModelRefs splits a comma-separated model reference string into individual refs. +func splitModelRefs(refs string) []string { + var result []string + for ref := range strings.SplitSeq(refs, ",") { + if ref = strings.TrimSpace(ref); ref != "" { + result = append(result, ref) + } + } + return result } // isInlineAlloySpec checks if a model reference is an inline alloy specification. // An inline alloy is comma-separated provider/model specs like "openai/gpt-4o,anthropic/claude-sonnet-4-0". +// All parts must contain a "/" (i.e., be inline specs) for it to be considered an inline alloy. func isInlineAlloySpec(modelRef string) bool { - if !strings.Contains(modelRef, ",") { + refs := splitModelRefs(modelRef) + if len(refs) < 2 { return false } - // Check that each part looks like a provider/model spec - // and count valid parts (need at least 2 for an alloy) - validParts := 0 - for part := range strings.SplitSeq(modelRef, ",") { - part = strings.TrimSpace(part) - if part == "" { - continue - } - if !strings.Contains(part, "/") { + for _, ref := range refs { + if !strings.Contains(ref, "/") { return false } - validParts++ } - return validParts >= 2 + return true } -// createProvidersFromInlineAlloy creates providers from an inline alloy spec. -// An inline alloy is comma-separated provider/model specs like "openai/gpt-4o,anthropic/claude-sonnet-4-0". -func (r *LocalRuntime) createProvidersFromInlineAlloy(ctx context.Context, modelRef string) ([]provider.Provider, error) { - var providers []provider.Provider - - for part := range strings.SplitSeq(modelRef, ",") { - part = strings.TrimSpace(part) - if part == "" { - continue - } - - // Check if this part exists as a named model in config - if modelCfg, exists := r.modelSwitcherCfg.Models[part]; exists { - prov, err := r.createProviderFromConfig(ctx, &modelCfg) +// resolveModelRef resolves a single model reference to a provider. +// The ref can be either a named model from config or an inline "provider/model" spec. 
+func (r *LocalRuntime) resolveModelRef(ctx context.Context, ref string) (provider.Provider, error) {
+	// Check if it's a named model from config
+	if modelCfg, exists := r.modelSwitcherCfg.Models[ref]; exists {
+		// If the config itself contains comma-separated models, expand it
+		if modelCfg.Provider == "" && strings.Contains(modelCfg.Model, ",") {
+			// Named alloy config: expand its comma-separated models. A single ref may
+			// legitimately resolve to one provider; resolving to several is an error here.
+			providers, err := r.resolveModelRefs(ctx, splitModelRefs(modelCfg.Model))
 			if err != nil {
-				return nil, fmt.Errorf("failed to create provider for %q: %w", part, err)
+				return nil, err
 			}
-			providers = append(providers, prov)
-			continue
-		}
-
-		// Parse as provider/model
-		providerName, modelName, ok := strings.Cut(part, "/")
-		if !ok {
-			return nil, fmt.Errorf("invalid model reference %q in inline alloy: expected 'provider/model' format", part)
-		}
-
-		inlineCfg := &latest.ModelConfig{
-			Provider: providerName,
-			Model:    modelName,
-		}
-		prov, err := r.createProviderFromConfig(ctx, inlineCfg)
-		if err != nil {
-			return nil, fmt.Errorf("failed to create provider for %q: %w", part, err)
+			if len(providers) == 1 {
+				return providers[0], nil
+			}
+			return nil, fmt.Errorf("model %q resolves to multiple providers; use resolveModelRefs instead", ref)
 		}
-		providers = append(providers, prov)
+		return r.createProviderFromConfig(ctx, &modelCfg)
 	}
 
-	if len(providers) == 0 {
-		return nil, fmt.Errorf("inline alloy spec has no valid models")
+	// Try to parse as inline "provider/model" spec
+	providerName, modelName, ok := strings.Cut(ref, "/")
+	if !ok {
+		return nil, fmt.Errorf("invalid model reference %q: expected a model name from config or 'provider/model' format", ref)
 	}
-	return providers, nil
+	cfg := &latest.ModelConfig{
+		Provider: providerName,
+		Model:    modelName,
+	}
+	return r.createProviderFromConfig(ctx, cfg)
 }
 
-// createProvidersFromAlloyConfig creates providers for each model in an alloy configuration.
-func (r *LocalRuntime) createProvidersFromAlloyConfig(ctx context.Context, alloyCfg latest.ModelConfig) ([]provider.Provider, error) {
-	var providers []provider.Provider
-
-	for modelRef := range strings.SplitSeq(alloyCfg.Model, ",") {
-		modelRef = strings.TrimSpace(modelRef)
-		if modelRef == "" {
-			continue
-		}
-
-		// Check if this model reference exists in the config
-		if modelCfg, exists := r.modelSwitcherCfg.Models[modelRef]; exists {
-			prov, err := r.createProviderFromConfig(ctx, &modelCfg)
-			if err != nil {
-				return nil, fmt.Errorf("failed to create provider for %q: %w", modelRef, err)
-			}
-			providers = append(providers, prov)
-			continue
-		}
-
-		// Try parsing as inline spec (provider/model)
-		providerName, modelName, ok := strings.Cut(modelRef, "/")
-		if !ok {
-			return nil, fmt.Errorf("invalid model reference %q in alloy config: expected 'provider/model' format", modelRef)
-		}
-
-		inlineCfg := &latest.ModelConfig{
-			Provider: providerName,
-			Model:    modelName,
-		}
-		prov, err := r.createProviderFromConfig(ctx, inlineCfg)
+// resolveModelRefs resolves multiple model references to providers.
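+//
+// For example (illustrative only), an inline alloy spec yields one provider per ref:
+//
+//	providers, err := r.resolveModelRefs(ctx, []string{"openai/gpt-4o-mini", "anthropic/claude-sonnet-4-0"})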
+func (r *LocalRuntime) resolveModelRefs(ctx context.Context, refs []string) ([]provider.Provider, error) { + providers := make([]provider.Provider, 0, len(refs)) + for _, ref := range refs { + prov, err := r.resolveModelRef(ctx, ref) if err != nil { - return nil, fmt.Errorf("failed to create provider for %q: %w", modelRef, err) + return nil, fmt.Errorf("model %q: %w", ref, err) } providers = append(providers, prov) } - - if len(providers) == 0 { - return nil, fmt.Errorf("alloy model config has no valid models") - } - return providers, nil } diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index ad951f337..166b453bb 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -162,7 +162,7 @@ type LocalRuntime struct { modelsStore ModelStore sessionCompaction bool managedOAuth bool - startupInfoEmitted bool // Track if startup info has been emitted to avoid unnecessary duplication + startupInfoEmitted atomic.Bool // Track if startup info has been emitted to avoid unnecessary duplication elicitationRequestCh chan ElicitationResult // Channel for receiving elicitation responses elicitationEventsChannel chan Event // Current events channel for sending elicitation requests elicitationEventsChannelMux sync.RWMutex // Protects elicitationEventsChannel @@ -554,16 +554,15 @@ func (r *LocalRuntime) PermissionsInfo() *PermissionsInfo { // This should be called when replacing a session to allow re-emission of // agent, team, and toolset info to the UI. func (r *LocalRuntime) ResetStartupInfo() { - r.startupInfoEmitted = false + r.startupInfoEmitted.Store(false) } // EmitStartupInfo emits initial agent, team, and toolset information for immediate sidebar display func (r *LocalRuntime) EmitStartupInfo(ctx context.Context, events chan Event) { - // Prevent duplicate emissions - if r.startupInfoEmitted { + // Prevent duplicate emissions using atomic compare-and-swap + if !r.startupInfoEmitted.CompareAndSwap(false, true) { return } - r.startupInfoEmitted = true a := r.CurrentAgent() @@ -993,7 +992,7 @@ func (r *LocalRuntime) getTools(ctx context.Context, a *agent.Agent, sessionSpan return agentTools, nil } -// configureToolsetHandlers sets up elicitation and OAuth handlers for all toolsets of an agent. +// configureToolsetHandlers sets up elicitation, OAuth, and model switch handlers for all toolsets of an agent. func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events chan Event) { for _, toolset := range a.ToolSets() { tools.ConfigureHandlers(toolset, @@ -1001,9 +1000,44 @@ func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events chan Even func() { events <- Authorization(tools.ElicitationActionAccept, r.currentAgent) }, r.managedOAuth, ) + + // Configure switch_model callback if this is a SwitchModelToolset + if switchModelToolset := unwrapSwitchModelToolset(toolset); switchModelToolset != nil { + r.configureSwitchModelCallback(a, switchModelToolset, events) + } } } +// unwrapSwitchModelToolset extracts a SwitchModelToolset from a potentially wrapped toolset. +// It uses tools.DeepAs to recursively unwrap any wrapper toolsets (StartableToolSet, +// filterTools, replaceInstruction, etc.) until it finds the SwitchModelToolset. +// Returns the SwitchModelToolset if found, or nil if the toolset is not a SwitchModelToolset. 
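+//
+// For example, a toolset built by teamloader with an instruction override is
+// wrapped roughly like StartableToolSet -> replaceInstruction ->
+// SwitchModelToolset; DeepAs peels each layer until the concrete type matches.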
+func unwrapSwitchModelToolset(toolset tools.ToolSet) *builtin.SwitchModelToolset { + switchModelTS, _ := tools.DeepAs[*builtin.SwitchModelToolset](toolset) + return switchModelTS +} + +// configureSwitchModelCallback sets up the callback for the switch_model toolset +// so that when the model is switched, the agent's model override is updated +// and the TUI is notified via events. +func (r *LocalRuntime) configureSwitchModelCallback(a *agent.Agent, switchModelToolset *builtin.SwitchModelToolset, events chan Event) { + switchModelToolset.SetOnSwitchCallback(func(newModel string) error { + ctx := context.Background() + if err := r.SetAgentModel(ctx, a.Name(), newModel); err != nil { + slog.Error("Failed to switch model via switch_model tool", "agent", a.Name(), "model", newModel, "error", err) + return err + } + slog.Debug("Model switched via switch_model tool", "agent", a.Name(), "model", newModel) + + // Emit events to update the TUI sidebar with the new model + if events != nil { + events <- AgentInfo(a.Name(), getAgentModelID(a), a.Description(), a.WelcomeMessage()) + events <- TeamInfo(r.agentDetailsFromTeam(), r.currentAgent) + } + return nil + }) +} + // emitAgentWarningsWithSend emits agent warnings using the provided send function for context-aware sending. func (r *LocalRuntime) emitAgentWarningsWithSend(a *agent.Agent, send func(Event) bool) { warnings := a.DrainWarnings() diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 0d01a82d9..a11f2619c 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace" @@ -26,6 +27,7 @@ import ( "github.com/docker/cagent/pkg/session" "github.com/docker/cagent/pkg/team" "github.com/docker/cagent/pkg/tools" + tools_builtin "github.com/docker/cagent/pkg/tools/builtin" ) type stubToolSet struct { @@ -1317,3 +1319,342 @@ func TestToolRejectionWithoutReason(t *testing.T) { require.Equal(t, "The user rejected the tool call.", toolResponse.Response) require.NotContains(t, toolResponse.Response, "Reason:") } + +func TestSwitchModelTool_IntegrationWithRuntime(t *testing.T) { + // This test verifies that the switch_model tool correctly updates the + // agent's model override when used through the runtime. 
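+	//
+	// Expected wiring, in outline (a sketch of the flow exercised below, not
+	// an API contract): the switch_model handler invokes the runtime callback,
+	// the callback calls SetAgentModel to resolve the new provider and set the
+	// agent's model override, and AgentInfo/TeamInfo events are emitted so the
+	// TUI sidebar reflects the change.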
+ + // Create a switch_model toolset + switchModelToolset, err := tools_builtin.NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + // Initial model + defaultModel := &mockProvider{id: "default-model", stream: &mockStream{}} + + // Create an agent with the switch_model toolset + root := agent.New("root", "You are a test agent", + agent.WithModel(defaultModel), + agent.WithToolSets(switchModelToolset), + ) + tm := team.New(team.WithAgents(root)) + + // Create runtime with model switcher config + modelSwitcherCfg := &ModelSwitcherConfig{ + Models: map[string]latest.ModelConfig{ + "fast": { + Provider: "openai", + Model: "gpt-4o-mini", + }, + "powerful": { + Provider: "openai", + Model: "gpt-4o", + }, + }, + Providers: nil, + EnvProvider: &mockEnvProvider{vars: map[string]string{"ANTHROPIC_API_KEY": "test-key"}}, + AgentDefaultModels: map[string]string{"root": "fast"}, + } + + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + WithModelSwitcherConfig(modelSwitcherCfg), + ) + require.NoError(t, err) + + // Verify initial model + require.Equal(t, "default-model", root.Model().ID()) + require.False(t, root.HasModelOverride()) + + // Manually configure the toolset handlers (simulating what RunStream does) + events := make(chan Event, 10) + rt.configureToolsetHandlers(root, events) + + // Get the switch_model tool from the agent + agentTools, err := root.Tools(t.Context()) + require.NoError(t, err) + + var switchModelTool tools.Tool + for _, tool := range agentTools { + if tool.Name == "switch_model" { + switchModelTool = tool + break + } + } + require.NotEmpty(t, switchModelTool.Name, "switch_model tool should be available") + + // Call the switch_model tool to switch to "powerful" + toolCall := tools.ToolCall{ + ID: "test-call-1", + Function: tools.FunctionCall{ + Name: "switch_model", + Arguments: `{"model": "powerful"}`, + }, + } + + result, err := switchModelTool.Handler(t.Context(), toolCall) + require.NoError(t, err) + require.False(t, result.IsError, "switch_model should succeed: %s", result.Output) + require.Contains(t, result.Output, `Switched model from "fast" to "powerful"`) + + // Verify the model override was set + require.True(t, root.HasModelOverride(), "agent should have model override after switch") + require.Equal(t, "openai/gpt-4o", root.Model().ID(), "model should be switched to powerful") + + // Verify that AgentInfoEvent and TeamInfoEvent were emitted with the new model + var agentInfoEvent *AgentInfoEvent + var teamInfoEvent *TeamInfoEvent + for { + select { + case evt := <-events: + switch e := evt.(type) { + case *AgentInfoEvent: + agentInfoEvent = e + case *TeamInfoEvent: + teamInfoEvent = e + } + default: + // No more events in the channel + goto done + } + } +done: + require.NotNil(t, agentInfoEvent, "AgentInfoEvent should be emitted after model switch") + assert.Equal(t, "root", agentInfoEvent.AgentName) + assert.Equal(t, "openai/gpt-4o", agentInfoEvent.Model, "AgentInfoEvent should contain the new model ID") + + require.NotNil(t, teamInfoEvent, "TeamInfoEvent should be emitted after model switch") + require.Len(t, teamInfoEvent.AvailableAgents, 1) + assert.Equal(t, "gpt-4o", teamInfoEvent.AvailableAgents[0].Model, "TeamInfoEvent should contain the new model name") +} + +func TestSwitchModelTool_WithInstructionWrapper(t *testing.T) { + // This test verifies that the switch_model tool works correctly even when + // wrapped with a custom wrapper that implements the Unwrapper interface. 
+ // This simulates what happens when the toolset has an 'instruction' field + // in the config (like in gopher.yaml), which causes teamloader to wrap + // the toolset with replaceInstruction. + + // Create a switch_model toolset + switchModelToolset, err := tools_builtin.NewSwitchModelToolset([]string{"haiku", "opus"}) + require.NoError(t, err) + + // Create a wrapper that implements the Unwrapper interface (simulating + // what teamloader's WithInstructions does when instruction is set). + wrappedToolset := &testInstructionWrapper{ToolSet: switchModelToolset} + + // Verify it implements Unwrapper + _, ok := tools.ToolSet(wrappedToolset).(tools.Unwrapper) + require.True(t, ok, "testInstructionWrapper should implement Unwrapper") + + // Initial model + defaultModel := &mockProvider{id: "anthropic/claude-haiku-4-5", stream: &mockStream{}} + + // Create an agent with the wrapped toolset + root := agent.New("root", "You are a test agent", + agent.WithModel(defaultModel), + agent.WithToolSets(wrappedToolset), + ) + tm := team.New(team.WithAgents(root)) + + // Create runtime with model switcher config + modelSwitcherCfg := &ModelSwitcherConfig{ + Models: map[string]latest.ModelConfig{ + "haiku": { + Provider: "anthropic", + Model: "claude-haiku-4-5", + }, + "opus": { + Provider: "anthropic", + Model: "claude-opus-4-5", + }, + }, + EnvProvider: &mockEnvProvider{vars: map[string]string{"ANTHROPIC_API_KEY": "test-key"}}, + AgentDefaultModels: map[string]string{"root": "haiku"}, + } + + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + WithModelSwitcherConfig(modelSwitcherCfg), + ) + require.NoError(t, err) + + // Verify initial model + require.Equal(t, "anthropic/claude-haiku-4-5", root.Model().ID()) + require.False(t, root.HasModelOverride()) + + // Configure toolset handlers - this should find the SwitchModelToolset + // by unwrapping the instructionWrapper using the Unwrapper interface. + events := make(chan Event, 10) + rt.configureToolsetHandlers(root, events) + + // Get the switch_model tool from the agent + agentTools, err := root.Tools(t.Context()) + require.NoError(t, err) + + var switchModelTool tools.Tool + for _, tool := range agentTools { + if tool.Name == "switch_model" { + switchModelTool = tool + break + } + } + require.NotEmpty(t, switchModelTool.Name, "switch_model tool should be available") + + // Call the switch_model tool to switch to "opus" + toolCall := tools.ToolCall{ + ID: "test-call-1", + Function: tools.FunctionCall{ + Name: "switch_model", + Arguments: `{"model": "opus"}`, + }, + } + + result, err := switchModelTool.Handler(t.Context(), toolCall) + require.NoError(t, err) + require.False(t, result.IsError, "switch_model should succeed: %s", result.Output) + require.Contains(t, result.Output, `Switched model from "haiku" to "opus"`) + + // Verify the model override was set - this is the key assertion! + // This works because the unwrapSwitchModelToolset function now uses the + // Unwrapper interface to recursively unwrap any wrapper that implements it. + require.True(t, root.HasModelOverride(), "agent should have model override after switch") + require.Equal(t, "anthropic/claude-opus-4-5", root.Model().ID(), "model should be switched to opus") +} + +// testInstructionWrapper is a test helper that simulates teamloader's replaceInstruction wrapper. +// It implements the tools.Unwrapper interface to allow unwrapping. 
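+// Only the single Unwrap method matters here: it is what lets tools.DeepAs
+// see through the wrapper, while the embedded tools.ToolSet supplies every
+// other method.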
+type testInstructionWrapper struct { + tools.ToolSet +} + +func (w *testInstructionWrapper) Unwrap() tools.ToolSet { return w.ToolSet } + +func TestSwitchModelTool_MultiAgentScenario(t *testing.T) { + // This test simulates the multi-agent scenario from gopher.yaml + // where root has switch_model but librarian does not + + // Create mock providers + haikuModel := &mockProvider{id: "anthropic/claude-haiku-4-5"} + + // Create switch_model toolset for root + switchModelToolset, err := tools_builtin.NewSwitchModelToolset([]string{"haiku", "opus"}) + require.NoError(t, err) + + // Create agents: root with switch_model, librarian without + root := agent.New("root", "You are the root agent", + agent.WithModel(haikuModel), + agent.WithToolSets(switchModelToolset), + ) + + librarian := agent.New("librarian", "You are the librarian", + agent.WithModel(haikuModel), + ) + + tm := team.New(team.WithAgents(root, librarian)) + + // Create runtime with model switcher config + modelSwitcherCfg := &ModelSwitcherConfig{ + Models: map[string]latest.ModelConfig{ + "haiku": { + Provider: "anthropic", + Model: "claude-haiku-4-5", + }, + "opus": { + Provider: "anthropic", + Model: "claude-opus-4-5", + }, + }, + EnvProvider: &mockEnvProvider{vars: map[string]string{"ANTHROPIC_API_KEY": "test-key"}}, + AgentDefaultModels: map[string]string{"root": "haiku", "librarian": "haiku"}, + } + + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + WithModelSwitcherConfig(modelSwitcherCfg), + ) + require.NoError(t, err) + + // Verify initial state + require.Equal(t, "anthropic/claude-haiku-4-5", root.Model().ID()) + require.False(t, root.HasModelOverride()) + + // Create events channel and configure handlers for root agent + events := make(chan Event, 20) + rt.configureToolsetHandlers(root, events) + + // Get tools for root agent (this creates the handler) + agentTools, err := root.Tools(t.Context()) + require.NoError(t, err) + + var switchModelTool tools.Tool + for _, tool := range agentTools { + if tool.Name == "switch_model" { + switchModelTool = tool + break + } + } + require.NotEmpty(t, switchModelTool.Name, "switch_model tool should be available") + + // Call the switch_model tool to switch to opus + toolCall := tools.ToolCall{ + ID: "test-multi-agent", + Function: tools.FunctionCall{ + Name: "switch_model", + Arguments: `{"model": "opus"}`, + }, + } + + t.Log("Calling switch_model tool in multi-agent scenario...") + result, err := switchModelTool.Handler(t.Context(), toolCall) + require.NoError(t, err) + require.False(t, result.IsError, "switch_model should succeed: %s", result.Output) + t.Logf("Result: %s", result.Output) + + // Verify model was switched + require.True(t, root.HasModelOverride(), "root should have model override") + require.Equal(t, "anthropic/claude-opus-4-5", root.Model().ID(), "root model should be switched to opus") + + // Verify events were emitted + var agentInfoEvent *AgentInfoEvent + var teamInfoEvent *TeamInfoEvent + for { + select { + case evt := <-events: + t.Logf("Event: %T", evt) + switch e := evt.(type) { + case *AgentInfoEvent: + agentInfoEvent = e + case *TeamInfoEvent: + teamInfoEvent = e + } + default: + goto done + } + } +done: + + require.NotNil(t, agentInfoEvent, "AgentInfoEvent should be emitted") + assert.Equal(t, "root", agentInfoEvent.AgentName) + assert.Equal(t, "anthropic/claude-opus-4-5", agentInfoEvent.Model) + + require.NotNil(t, teamInfoEvent, "TeamInfoEvent should be emitted") + // Team info should include both agents + 
require.Len(t, teamInfoEvent.AvailableAgents, 2) + + // Find root agent in team info and verify its model was updated + var rootAgentInfo *AgentDetails + for i := range teamInfoEvent.AvailableAgents { + if teamInfoEvent.AvailableAgents[i].Name == "root" { + rootAgentInfo = &teamInfoEvent.AvailableAgents[i] + break + } + } + require.NotNil(t, rootAgentInfo, "root agent should be in team info") + assert.Equal(t, "anthropic", rootAgentInfo.Provider) + assert.Equal(t, "claude-opus-4-5", rootAgentInfo.Model) + + t.Log("Multi-agent scenario test passed!") +} diff --git a/pkg/teamloader/filter.go b/pkg/teamloader/filter.go index 3c38092bb..e6ccb2132 100644 --- a/pkg/teamloader/filter.go +++ b/pkg/teamloader/filter.go @@ -43,7 +43,15 @@ type filterTools struct { } // Verify interface compliance -var _ tools.Instructable = (*filterTools)(nil) +var ( + _ tools.Instructable = (*filterTools)(nil) + _ tools.Unwrapper = (*filterTools)(nil) +) + +// Unwrap returns the underlying ToolSet. +func (f *filterTools) Unwrap() tools.ToolSet { + return f.ToolSet +} // Instructions implements tools.Instructable by delegating to the inner toolset. func (f *filterTools) Instructions() string { diff --git a/pkg/teamloader/instructions.go b/pkg/teamloader/instructions.go index 729392787..635b5eb69 100644 --- a/pkg/teamloader/instructions.go +++ b/pkg/teamloader/instructions.go @@ -23,7 +23,15 @@ type replaceInstruction struct { } // Verify interface compliance -var _ tools.Instructable = (*replaceInstruction)(nil) +var ( + _ tools.Instructable = (*replaceInstruction)(nil) + _ tools.Unwrapper = (*replaceInstruction)(nil) +) + +// Unwrap returns the underlying ToolSet. +func (a replaceInstruction) Unwrap() tools.ToolSet { + return a.ToolSet +} func (a replaceInstruction) Instructions() string { original := tools.GetInstructions(a.ToolSet) diff --git a/pkg/teamloader/registry.go b/pkg/teamloader/registry.go index 7374605cc..da70d0bcf 100644 --- a/pkg/teamloader/registry.go +++ b/pkg/teamloader/registry.go @@ -70,6 +70,7 @@ func NewDefaultToolsetRegistry() *ToolsetRegistry { r.Register("a2a", createA2ATool) r.Register("lsp", createLSPTool) r.Register("user_prompt", createUserPromptTool) + r.Register("switch_model", createSwitchModelTool) return r } @@ -284,3 +285,7 @@ func createLSPTool(ctx context.Context, toolset latest.Toolset, _ string, runCon func createUserPromptTool(_ context.Context, _ latest.Toolset, _ string, _ *config.RuntimeConfig) (tools.ToolSet, error) { return builtin.NewUserPromptTool(), nil } + +func createSwitchModelTool(_ context.Context, toolset latest.Toolset, _ string, _ *config.RuntimeConfig) (tools.ToolSet, error) { + return builtin.NewSwitchModelToolset(toolset.Models) +} diff --git a/pkg/teamloader/toon.go b/pkg/teamloader/toon.go index 3cb333f89..c2620e3b4 100644 --- a/pkg/teamloader/toon.go +++ b/pkg/teamloader/toon.go @@ -16,6 +16,14 @@ type toonTools struct { toolRegexps []*regexp.Regexp } +// Verify interface compliance +var _ tools.Unwrapper = (*toonTools)(nil) + +// Unwrap returns the underlying ToolSet. 
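+// This lets tools.DeepAs see through the TOON formatting wrapper when the
+// runtime searches for a concrete toolset such as SwitchModelToolset.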
+func (f *toonTools) Unwrap() tools.ToolSet {
+	return f.ToolSet
+}
+
 func (f *toonTools) Tools(ctx context.Context) ([]tools.Tool, error) {
 	allTools, err := f.ToolSet.Tools(ctx)
 	if err != nil {
diff --git a/pkg/tools/builtin/switch_model.go b/pkg/tools/builtin/switch_model.go
new file mode 100644
index 000000000..683bdb8e4
--- /dev/null
+++ b/pkg/tools/builtin/switch_model.go
@@ -0,0 +1,193 @@
+package builtin
+
+import (
+	"context"
+	"fmt"
+	"slices"
+	"strings"
+	"sync"
+
+	"github.com/docker/cagent/pkg/tools"
+)
+
+const ToolNameSwitchModel = "switch_model"
+
+// ModelSwitchCallback is called when the model is switched.
+// It returns an error if the switch failed.
+type ModelSwitchCallback func(newModel string) error
+
+// SwitchModelToolset provides a tool that allows agents to switch between
+// a predefined set of models during a conversation.
+type SwitchModelToolset struct {
+	mu           sync.RWMutex
+	models       []string
+	currentModel string              // currently selected model
+	onSwitch     ModelSwitchCallback // optional callback when model changes
+}
+
+// Verify interface compliance
+var (
+	_ tools.ToolSet      = (*SwitchModelToolset)(nil)
+	_ tools.Instructable = (*SwitchModelToolset)(nil)
+)
+
+type SwitchModelArgs struct {
+	Model string `json:"model" jsonschema:"The model to switch to. Must be one of the allowed models listed in the tool description."`
+}
+
+// NewSwitchModelToolset creates a new switch_model toolset with the given allowed models.
+// The first model in the list becomes the default and initially selected model.
+// Returns an error if models is empty or contains an empty or whitespace-only entry.
+func NewSwitchModelToolset(models []string) (*SwitchModelToolset, error) {
+	if len(models) == 0 {
+		return nil, fmt.Errorf("switch_model toolset requires at least one model")
+	}
+	for i, m := range models {
+		if strings.TrimSpace(m) == "" {
+			return nil, fmt.Errorf("switch_model toolset: model at index %d is empty", i)
+		}
+	}
+
+	return &SwitchModelToolset{
+		models:       slices.Clone(models),
+		currentModel: models[0],
+	}, nil
+}
+
+// CurrentModel returns the currently selected model.
+func (t *SwitchModelToolset) CurrentModel() string {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+	return t.currentModel
+}
+
+// SetOnSwitchCallback sets a callback that will be invoked whenever the model is switched.
+// The callback receives the new model name. This allows the runtime to react to model changes.
+func (t *SwitchModelToolset) SetOnSwitchCallback(callback ModelSwitchCallback) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	t.onSwitch = callback
+}
+
+// Instructions returns guidance for when to use model switching.
+func (t *SwitchModelToolset) Instructions() string {
+	return `## Model Switching Guidelines
+
+You have access to multiple AI models and can switch between them strategically.
+ +### When to Consider Switching Models + +**Switch to a faster/cheaper model when:** +- Performing simple, routine tasks (formatting, basic Q&A, short summaries) +- The current task doesn't require advanced reasoning +- Processing straightforward requests that any model can handle well +- Optimizing for response speed or cost efficiency + +**Switch to a more powerful model when:** +- Facing complex reasoning or multi-step problems +- Writing or reviewing code that requires careful analysis +- Handling nuanced or ambiguous requests +- Generating detailed technical content +- The current model is struggling with the task quality + +**Switch back to the default model when:** +- A specialized task is complete +- Returning to general conversation +- The extra capability is no longer needed + +### Best Practices + +1. Check the tool description to see available models and which one is currently active +2. Don't switch unnecessarily - there's overhead in changing models +3. Consider switching proactively before a complex task rather than after struggling +4. When in doubt about task complexity, prefer the more capable model` +} + +// callTool handles the switch_model tool invocation. +func (t *SwitchModelToolset) callTool(_ context.Context, params SwitchModelArgs) (*tools.ToolCallResult, error) { + requestedModel := strings.TrimSpace(params.Model) + if requestedModel == "" { + return tools.ResultError("model parameter is required and cannot be empty"), nil + } + + // Check if the requested model is in the allowed list + if !slices.Contains(t.models, requestedModel) { + return tools.ResultError(fmt.Sprintf( + "model %q is not allowed. Available models: %s", + requestedModel, + strings.Join(t.models, ", "), + )), nil + } + + // Get current state and callback atomically + t.mu.RLock() + previousModel := t.currentModel + callback := t.onSwitch + t.mu.RUnlock() + + // No-op if already on the requested model + if previousModel == requestedModel { + return tools.ResultSuccess(fmt.Sprintf("Model is already set to %q.", requestedModel)), nil + } + + // Notify the runtime about the model change (before updating internal state) + if callback != nil { + if err := callback(requestedModel); err != nil { + return tools.ResultError(fmt.Sprintf("Failed to switch model: %v", err)), nil + } + } + + // Update internal state after successful callback + t.mu.Lock() + t.currentModel = requestedModel + t.mu.Unlock() + + return tools.ResultSuccess(fmt.Sprintf("Switched model from %q to %q.", previousModel, requestedModel)), nil +} + +// Tools returns the switch_model tool definition. +func (t *SwitchModelToolset) Tools(context.Context) ([]tools.Tool, error) { + t.mu.RLock() + currentModel := t.currentModel + t.mu.RUnlock() + + description := t.buildDescription(currentModel) + + return []tools.Tool{ + { + Name: ToolNameSwitchModel, + Category: "model", + Description: description, + Parameters: tools.MustSchemaFor[SwitchModelArgs](), + OutputSchema: tools.MustSchemaFor[string](), + Handler: tools.NewHandler(t.callTool), + Annotations: tools.ToolAnnotations{ + ReadOnlyHint: true, + Title: "Switch Model", + }, + }, + }, nil +} + +// buildDescription generates the tool description with current state. 
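+// For models ["fast", "powerful"] with "powerful" currently selected, the
+// rendered description looks like this (derived from the logic below, shown
+// here for illustration):
+//
+//	Switch the AI model used for subsequent responses.
+//
+//	**Available models:**
+//	- fast (default)
+//	- powerful (current)
+//
+//	Only the models listed above can be selected. Any other model will be rejected.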
+func (t *SwitchModelToolset) buildDescription(currentModel string) string { + var sb strings.Builder + + sb.WriteString("Switch the AI model used for subsequent responses.\n\n") + sb.WriteString("**Available models:**\n") + for _, m := range t.models { + fmt.Fprintf(&sb, "- %s", m) + if m == t.models[0] { + sb.WriteString(" (default)") + } + if m == currentModel { + sb.WriteString(" (current)") + } + sb.WriteString("\n") + } + sb.WriteString("\n") + sb.WriteString("Only the models listed above can be selected. ") + sb.WriteString("Any other model will be rejected.") + + return sb.String() +} diff --git a/pkg/tools/builtin/switch_model_test.go b/pkg/tools/builtin/switch_model_test.go new file mode 100644 index 000000000..873af74af --- /dev/null +++ b/pkg/tools/builtin/switch_model_test.go @@ -0,0 +1,238 @@ +package builtin + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewSwitchModelToolset(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + models []string + wantErr bool + }{ + {"valid models", []string{"fast", "powerful"}, false}, + {"empty list", []string{}, true}, + {"nil list", nil, true}, + {"empty model in list", []string{"fast", "", "powerful"}, true}, + {"whitespace-only model", []string{"fast", " ", "powerful"}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + toolset, err := NewSwitchModelToolset(tt.models) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.models[0], toolset.CurrentModel()) + }) + } +} + +func TestSwitchModelToolset_callTool(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + model string + wantError bool + wantOutput string + wantCurrent string + }{ + { + name: "switches to allowed model", + model: "powerful", + wantOutput: "Switched model from \"fast\" to \"powerful\"", + wantCurrent: "powerful", + }, + { + name: "already on requested model", + model: "fast", + wantOutput: "Model is already set to \"fast\"", + wantCurrent: "fast", + }, + { + name: "rejects unknown model", + model: "unknown", + wantError: true, + wantOutput: "model \"unknown\" is not allowed", + wantCurrent: "fast", + }, + { + name: "rejects empty model", + model: "", + wantError: true, + wantOutput: "model parameter is required", + wantCurrent: "fast", + }, + { + name: "rejects whitespace-only model", + model: " ", + wantError: true, + wantOutput: "model parameter is required", + wantCurrent: "fast", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := t.Context() + toolset, err := NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + result, err := toolset.callTool(ctx, SwitchModelArgs{Model: tt.model}) + + require.NoError(t, err) + assert.Equal(t, tt.wantError, result.IsError) + assert.Contains(t, result.Output, tt.wantOutput) + assert.Equal(t, tt.wantCurrent, toolset.CurrentModel()) + }) + } +} + +func TestSwitchModelToolset_Tools(t *testing.T) { + t.Parallel() + + ctx := t.Context() + toolset, err := NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + tools, err := toolset.Tools(ctx) + require.NoError(t, err) + require.Len(t, tools, 1) + + tool := tools[0] + assert.Equal(t, ToolNameSwitchModel, tool.Name) + assert.Equal(t, "model", tool.Category) + assert.True(t, tool.Annotations.ReadOnlyHint) + assert.NotNil(t, tool.Handler) + + // Description includes model info + 
assert.Contains(t, tool.Description, "fast (default) (current)") + assert.Contains(t, tool.Description, "powerful") + assert.Contains(t, tool.Description, "rejected") + + // After switching, description updates + _, _ = toolset.callTool(ctx, SwitchModelArgs{Model: "powerful"}) + tools, _ = toolset.Tools(ctx) + assert.Contains(t, tools[0].Description, "fast (default)") + assert.Contains(t, tools[0].Description, "powerful (current)") +} + +func TestSwitchModelToolset_Instructions(t *testing.T) { + t.Parallel() + + toolset, err := NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + instructions := toolset.Instructions() + + assert.Contains(t, instructions, "Model Switching Guidelines") + assert.Contains(t, instructions, "Switch to a faster/cheaper model") + assert.Contains(t, instructions, "Switch to a more powerful model") + assert.Contains(t, instructions, "Best Practices") +} + +func TestSwitchModelToolset_OnSwitchCallback(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + initialModel string + targetModel string + wantCallbackCall bool + wantCallbackArg string + }{ + { + name: "callback called on successful switch", + initialModel: "fast", + targetModel: "powerful", + wantCallbackCall: true, + wantCallbackArg: "powerful", + }, + { + name: "callback not called when already on model", + initialModel: "fast", + targetModel: "fast", + wantCallbackCall: false, + }, + { + name: "callback not called for invalid model", + initialModel: "fast", + targetModel: "unknown", + wantCallbackCall: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := t.Context() + + toolset, err := NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + var callbackCalled bool + var callbackArg string + toolset.SetOnSwitchCallback(func(newModel string) error { + callbackCalled = true + callbackArg = newModel + return nil + }) + + _, _ = toolset.callTool(ctx, SwitchModelArgs{Model: tt.targetModel}) + + assert.Equal(t, tt.wantCallbackCall, callbackCalled, "callback called mismatch") + if tt.wantCallbackCall { + assert.Equal(t, tt.wantCallbackArg, callbackArg, "callback argument mismatch") + } + }) + } +} + +func TestSwitchModelToolset_OnSwitchCallback_NilCallback(t *testing.T) { + t.Parallel() + + ctx := t.Context() + toolset, err := NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + // Ensure no panic when callback is nil + result, err := toolset.callTool(ctx, SwitchModelArgs{Model: "powerful"}) + require.NoError(t, err) + assert.False(t, result.IsError) + assert.Equal(t, "powerful", toolset.CurrentModel()) +} + +func TestSwitchModelToolset_OnSwitchCallback_WithError(t *testing.T) { + t.Parallel() + + ctx := t.Context() + toolset, err := NewSwitchModelToolset([]string{"fast", "powerful"}) + require.NoError(t, err) + + // Set callback that returns an error + callbackErr := fmt.Errorf("API key not configured") + toolset.SetOnSwitchCallback(func(newModel string) error { + return callbackErr + }) + + // Call the tool - should fail because callback returns error + result, err := toolset.callTool(ctx, SwitchModelArgs{Model: "powerful"}) + require.NoError(t, err) // No Go error, but tool error + assert.True(t, result.IsError, "should be a tool error") + assert.Contains(t, result.Output, "Failed to switch model") + assert.Contains(t, result.Output, "API key not configured") + + // Verify internal state was rolled back + assert.Equal(t, "fast", toolset.CurrentModel(), 
"internal state should be rolled back to previous model") +} diff --git a/pkg/tools/startable.go b/pkg/tools/startable.go index 9e737c538..a21985358 100644 --- a/pkg/tools/startable.go +++ b/pkg/tools/startable.go @@ -77,3 +77,43 @@ func As[T any](ts ToolSet) (T, bool) { result, ok := ts.(T) return result, ok } + +// Unwrapper is implemented by toolsets that wrap another toolset. +// This allows recursive unwrapping to find the underlying toolset. +type Unwrapper interface { + Unwrap() ToolSet +} + +// DeepAs performs a type assertion on a ToolSet, recursively unwrapping +// any wrapper toolsets (StartableToolSet and those implementing Unwrapper) +// until it finds a match or runs out of wrappers. +// +// Example: +// +// if switchModel, ok := tools.DeepAs[*builtin.SwitchModelToolset](toolset); ok { +// switchModel.SetCallback(...) +// } +func DeepAs[T any](ts ToolSet) (T, bool) { + for { + // Try to match the current toolset + if result, ok := ts.(T); ok { + return result, true + } + + // Try to unwrap StartableToolSet + if startable, ok := ts.(*StartableToolSet); ok { + ts = startable.ToolSet + continue + } + + // Try to unwrap via Unwrapper interface + if unwrapper, ok := ts.(Unwrapper); ok { + ts = unwrapper.Unwrap() + continue + } + + // No more unwrapping possible + var zero T + return zero, false + } +} diff --git a/pkg/tui/components/sidebar/agent_info_test.go b/pkg/tui/components/sidebar/agent_info_test.go new file mode 100644 index 000000000..474303b9b --- /dev/null +++ b/pkg/tui/components/sidebar/agent_info_test.go @@ -0,0 +1,165 @@ +package sidebar + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/runtime" + "github.com/docker/cagent/pkg/session" + "github.com/docker/cagent/pkg/tui/service" +) + +func TestSidebar_SetAgentInfoUpdatesAvailableAgents(t *testing.T) { + t.Parallel() + + sess := session.New() + sessionState := service.NewSessionState(sess) + sb := New(sessionState) + + m := sb.(*model) + + // Set initial team info with original model + m.SetTeamInfo([]runtime.AgentDetails{ + { + Name: "root", + Description: "Test agent", + Provider: "openai", + Model: "gpt-4o-mini", + }, + }) + + // Verify initial state + require.Len(t, m.availableAgents, 1) + assert.Equal(t, "openai", m.availableAgents[0].Provider) + assert.Equal(t, "gpt-4o-mini", m.availableAgents[0].Model) + + // Now simulate a model switch via SetAgentInfo with new model + m.SetAgentInfo("root", "anthropic/claude-sonnet-4-0", "Test agent") + + // Verify the model was updated in availableAgents + require.Len(t, m.availableAgents, 1) + assert.Equal(t, "anthropic", m.availableAgents[0].Provider, "Provider should be updated") + assert.Equal(t, "claude-sonnet-4-0", m.availableAgents[0].Model, "Model should be updated") +} + +func TestSidebar_SetAgentInfoUpdatesCorrectAgent(t *testing.T) { + t.Parallel() + + sess := session.New() + sessionState := service.NewSessionState(sess) + sb := New(sessionState) + + m := sb.(*model) + + // Set up multiple agents + m.SetTeamInfo([]runtime.AgentDetails{ + { + Name: "root", + Description: "Root agent", + Provider: "openai", + Model: "gpt-4o", + }, + { + Name: "helper", + Description: "Helper agent", + Provider: "anthropic", + Model: "claude-sonnet-4-0", + }, + }) + + // Switch the model for the helper agent + m.SetAgentInfo("helper", "google/gemini-2.0-flash", "Helper agent") + + // Verify only the helper agent's model was updated + require.Len(t, m.availableAgents, 2) + assert.Equal(t, 
"openai", m.availableAgents[0].Provider, "Root provider should not change") + assert.Equal(t, "gpt-4o", m.availableAgents[0].Model, "Root model should not change") + assert.Equal(t, "google", m.availableAgents[1].Provider, "Helper provider should be updated") + assert.Equal(t, "gemini-2.0-flash", m.availableAgents[1].Model, "Helper model should be updated") +} + +func TestSidebar_SetAgentInfoWithModelIDWithoutProvider(t *testing.T) { + t.Parallel() + + sess := session.New() + sessionState := service.NewSessionState(sess) + sb := New(sessionState) + + m := sb.(*model) + + // Set initial team info + m.SetTeamInfo([]runtime.AgentDetails{ + { + Name: "root", + Description: "Test agent", + Provider: "openai", + Model: "gpt-4o-mini", + }, + }) + + // Switch to a model ID without provider prefix (shouldn't happen but handle gracefully) + m.SetAgentInfo("root", "some-model", "Test agent") + + // Verify the model was set (provider should remain unchanged) + require.Len(t, m.availableAgents, 1) + assert.Equal(t, "openai", m.availableAgents[0].Provider, "Provider should not change for non-prefixed model") + assert.Equal(t, "some-model", m.availableAgents[0].Model, "Model should be updated to the full ID") +} + +func TestSidebar_SetAgentInfoForNonExistentAgent(t *testing.T) { + t.Parallel() + + sess := session.New() + sessionState := service.NewSessionState(sess) + sb := New(sessionState) + + m := sb.(*model) + + // Set initial team info + m.SetTeamInfo([]runtime.AgentDetails{ + { + Name: "root", + Description: "Test agent", + Provider: "openai", + Model: "gpt-4o-mini", + }, + }) + + // Try to set info for a non-existent agent (should not panic or modify existing agents) + m.SetAgentInfo("nonexistent", "anthropic/claude-sonnet-4-0", "Some agent") + + // Verify the existing agent was not modified + require.Len(t, m.availableAgents, 1) + assert.Equal(t, "openai", m.availableAgents[0].Provider) + assert.Equal(t, "gpt-4o-mini", m.availableAgents[0].Model) +} + +func TestSidebar_SetAgentInfoWithEmptyModelID(t *testing.T) { + t.Parallel() + + sess := session.New() + sessionState := service.NewSessionState(sess) + sb := New(sessionState) + + m := sb.(*model) + + // Set initial team info + m.SetTeamInfo([]runtime.AgentDetails{ + { + Name: "root", + Description: "Test agent", + Provider: "openai", + Model: "gpt-4o-mini", + }, + }) + + // Call SetAgentInfo with empty modelID (should not modify availableAgents) + m.SetAgentInfo("root", "", "Test agent") + + // Verify the existing agent's model was not modified + require.Len(t, m.availableAgents, 1) + assert.Equal(t, "openai", m.availableAgents[0].Provider) + assert.Equal(t, "gpt-4o-mini", m.availableAgents[0].Model) +} diff --git a/pkg/tui/components/sidebar/sidebar.go b/pkg/tui/components/sidebar/sidebar.go index 6913094a0..ac8f66e59 100644 --- a/pkg/tui/components/sidebar/sidebar.go +++ b/pkg/tui/components/sidebar/sidebar.go @@ -224,23 +224,27 @@ func (m *model) SetTodos(result *tools.ToolCallResult) error { return m.todoComp.SetTodos(result) } -// SetAgentInfo sets the current agent information and updates the model in availableAgents +// SetAgentInfo sets the current agent information func (m *model) SetAgentInfo(agentName, modelID, description string) { m.currentAgent = agentName m.agentModel = modelID m.agentDescription = description m.reasoningSupported = modelsdev.ModelSupportsReasoning(context.Background(), modelID) - // Update the model in availableAgents for the current agent - // This is important when model routing selects a different model than 
configured - // Extract just the model name from "provider/model" format to match TeamInfoEvent format + // Update the model in availableAgents to ensure the sidebar displays the correct model + // when the model is switched via the switch_model tool. + if modelID == "" { + return + } for i := range m.availableAgents { - if m.availableAgents[i].Name == agentName && modelID != "" { - modelName := modelID - if idx := strings.LastIndex(modelName, "/"); idx != -1 { - modelName = modelName[idx+1:] + if m.availableAgents[i].Name == agentName { + // Parse the modelID to extract provider and model name + if prov, modelName, found := strings.Cut(modelID, "/"); found { + m.availableAgents[i].Provider = prov + m.availableAgents[i].Model = modelName + } else { + m.availableAgents[i].Model = modelID } - m.availableAgents[i].Model = modelName break } }