Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 134 additions & 7 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,64 @@ var Models = []Model{
// import cycle. Signature matches api.QuotaCheck.
var QuotaCheck func(r *http.Request, op string) (bool, int, error)

// Load initialises the agent package (no-op for now; reserved for future use).
func Load() {}
// AppAgent defines a tailored agent for a specific app or service.
// Each app agent has its own system prompt and a restricted set of tools.
type AppAgent struct {
// ID is the unique identifier used in POST requests (e.g. "blog", "places").
ID string
// Name is the human-readable label shown in the UI.
Name string
// SystemPrompt overrides the default synthesis system prompt for this app.
SystemPrompt string
// Tools restricts which tools the planner may call. An empty slice means
// the full tool catalogue is available.
Tools []string
}

// appAgents holds all registered per-app agents, keyed by AppAgent.ID.
var appAgents = map[string]*AppAgent{}

// RegisterAppAgent registers (or replaces) an app-specific agent configuration.
func RegisterAppAgent(a *AppAgent) {
if a != nil && a.ID != "" {
appAgents[a.ID] = a
}
}

// Load initialises the agent package and registers built-in per-app agents.
func Load() {
RegisterAppAgent(&AppAgent{
ID: "blog",
Name: "Blog Writing Assistant",
SystemPrompt: "You are a helpful writing assistant for blog posts. " +
"Help the user with grammar, spelling, clarity, tone, and structure. " +
"Suggest relevant tags or topics when asked. " +
"When asked to check for moderation concerns (hate speech, spam, off-topic), " +
"flag them constructively. Be encouraging and supportive.",
Tools: []string{"blog_list", "search", "web_search"},
})

RegisterAppAgent(&AppAgent{
ID: "places",
Name: "Places Discovery Assistant",
SystemPrompt: "You are a helpful local guide and place recommendation assistant. " +
"Understand the user's intent from their query — whether they want food, " +
"entertainment, services, or something else — and recommend suitable places. " +
"Consider proximity, category, opening hours, and any preferences mentioned. " +
"Provide practical tips such as what to expect, best time to visit, and alternatives.",
Tools: []string{"places_search", "places_nearby", "weather_forecast"},
})

RegisterAppAgent(&AppAgent{
ID: "weather",
Name: "Weather Advisor",
SystemPrompt: "You are a practical weather advisor. " +
"Given the current forecast, suggest appropriate clothing, gear, and activities. " +
"Mention whether it is suitable for outdoor activities, and recommend indoor " +
"alternatives when the weather is poor. Keep advice concise and actionable.",
Tools: []string{"weather_forecast", "places_nearby"},
})
}

// Handler dispatches GET (page) and POST (query) at /agent.
func Handler(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -222,11 +278,42 @@ const agentToolsDesc = `Available tools (use exact name):
- wallet_balance: Check your wallet credit balance (no args)
- wallet_topup: Get available topup options to add credits to your wallet (no args)`

// allToolsDescMap maps a tool name to its single-line description used by
// buildRestrictedToolsDesc.
var allToolsDescMap = map[string]string{
"news": `news: Get latest news feed (no args)`,
"news_search": `news_search: Search news articles (args: {"query":"search term"})`,
"web_search": `web_search: Search the web for current information (args: {"query":"search term"})`,
"video_search": `video_search: Search for videos (args: {"query":"search term"})`,
"markets": `markets: Get live market prices (args: {"category":"crypto|futures|commodities"})`,
"weather_forecast": `weather_forecast: Get weather forecast (args: {"lat":number,"lon":number})`,
"places_search": `places_search: Search for places (args: {"q":"search name","near":"location"})`,
"places_nearby": `places_nearby: Find places near a location (args: {"address":"location","radius":number})`,
"reminder": `reminder: Get Islamic daily reminder (no args)`,
"search": `search: Search all Mu content (args: {"q":"search term"})`,
"blog_list": `blog_list: Get recent blog posts (no args)`,
"wallet_balance": `wallet_balance: Check your wallet credit balance (no args)`,
"wallet_topup": `wallet_topup: Get available topup options to add credits to your wallet (no args)`,
}

// buildRestrictedToolsDesc returns a tool description string limited to the
// given set of tool names.
func buildRestrictedToolsDesc(tools []string) string {
var lines []string
for _, t := range tools {
if desc, ok := allToolsDescMap[t]; ok {
lines = append(lines, "- "+desc)
}
}
return "Available tools (use exact name):\n" + strings.Join(lines, "\n")
}

// handleQuery processes an agent query request with SSE streaming.
func handleQuery(w http.ResponseWriter, r *http.Request) {
var req struct {
Prompt string `json:"prompt"`
Model string `json:"model"`
App string `json:"app"` // optional: selects an app-specific agent
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || strings.TrimSpace(req.Prompt) == "" {
http.Error(w, `{"error":"prompt required"}`, http.StatusBadRequest)
Expand Down Expand Up @@ -263,6 +350,12 @@ func handleQuery(w http.ResponseWriter, r *http.Request) {
}
}

// Resolve optional app-specific agent configuration.
var appAgent *AppAgent
if req.App != "" {
appAgent = appAgents[req.App]
}

// Start SSE stream
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
Expand All @@ -271,9 +364,16 @@ func handleQuery(w http.ResponseWriter, r *http.Request) {
// --- Step 1: plan tool calls ---
sse(w, map[string]any{"type": "thinking", "message": "Planning your request…"})

// Build the tool catalogue shown to the planner, restricted to the app's
// allowed tools when an app-specific agent is active.
toolsDesc := agentToolsDesc
if appAgent != nil && len(appAgent.Tools) > 0 {
toolsDesc = buildRestrictedToolsDesc(appAgent.Tools)
}

planPrompt := &ai.Prompt{
System: "You are an AI agent. Given a user question, output ONLY a JSON array of tool calls (no other text, no markdown).\n\n" +
agentToolsDesc +
toolsDesc +
"\n\nOutput format: [{\"tool\":\"tool_name\",\"args\":{}}]\nUse at most 5 tool calls. When the question asks for cross-source insights or correlations (e.g. news + markets, news + video), call multiple relevant tools. If no tools are needed output [].",
Question: req.Prompt,
Priority: ai.PriorityHigh,
Expand All @@ -294,7 +394,22 @@ func handleQuery(w http.ResponseWriter, r *http.Request) {
}
planJSON := extractJSONArray(planResult)
var toolCalls []toolCall
json.Unmarshal([]byte(planJSON), &toolCalls) //nolint:errcheck — fallback to empty slice
json.Unmarshal([]byte(planJSON), &toolCalls) //nolint:errcheck - fallback to empty slice

// If an app agent restricts tools, filter out any calls outside that set.
if appAgent != nil && len(appAgent.Tools) > 0 {
allowed := make(map[string]bool, len(appAgent.Tools))
for _, t := range appAgent.Tools {
allowed[t] = true
}
filtered := toolCalls[:0]
for _, tc := range toolCalls {
if allowed[tc.Tool] {
filtered = append(filtered, tc)
}
}
toolCalls = filtered
}

// --- Step 2: execute tool calls ---
type toolResult struct {
Expand Down Expand Up @@ -346,8 +461,16 @@ func handleQuery(w http.ResponseWriter, r *http.Request) {
}

today := time.Now().UTC().Format("Monday, 2 January 2006 (UTC)")
synthPrompt := &ai.Prompt{
System: "You are a helpful assistant. Today's date is " + today + ". " +

// Use the app-specific system prompt when available; otherwise the default.
var synthSystem string
if appAgent != nil && appAgent.SystemPrompt != "" {
synthSystem = appAgent.SystemPrompt + "\n\nToday's date is " + today + "."
if len(ragParts) > 0 {
synthSystem += "\n\nThe tool results below come from live data — treat them as current information."
}
} else {
synthSystem = "You are a helpful assistant. Today's date is " + today + ". " +
"The tool results below come from live data feeds — treat them as current information and use the article publication dates when reasoning about recency.\n\n" +
"Answer the user's question using ONLY the tool results provided below.\n\n" +
"IMPORTANT: For any prices, market values, weather conditions, or other real-time data, you MUST use " +
Expand All @@ -356,7 +479,11 @@ func handleQuery(w http.ResponseWriter, r *http.Request) {
"When results come from multiple sources (news, video, markets, weather, etc.), identify and highlight " +
"connections and correlations between them — for example, how a market move relates to a news story, " +
"or how videos cover the same topic appearing in the news.\n\n" +
"Use markdown formatting. Summarise key information from any news articles, weather data, market prices or other structured data.",
"Use markdown formatting. Summarise key information from any news articles, weather data, market prices or other structured data."
}

synthPrompt := &ai.Prompt{
System: synthSystem,
Rag: ragParts,
Question: req.Prompt,
Priority: ai.PriorityHigh,
Expand Down
149 changes: 149 additions & 0 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,3 +509,152 @@ func TestStripHTMLTags(t *testing.T) {
t.Errorf("expected text content preserved, got %q", got)
}
}

func TestRegisterAppAgent(t *testing.T) {
// Registering a new agent should make it retrievable.
RegisterAppAgent(&AppAgent{
ID: "test_app",
Name: "Test App Agent",
SystemPrompt: "You are a test agent.",
Tools: []string{"news", "web_search"},
})
got, ok := appAgents["test_app"]
if !ok {
t.Fatal("expected test_app agent to be registered")
}
if got.Name != "Test App Agent" {
t.Errorf("expected name 'Test App Agent', got %q", got.Name)
}
if len(got.Tools) != 2 {
t.Errorf("expected 2 tools, got %d", len(got.Tools))
}
// Clean up
delete(appAgents, "test_app")
}

func TestRegisterAppAgent_NilAndEmpty(t *testing.T) {
// Nil agent must not panic or register anything.
before := len(appAgents)
RegisterAppAgent(nil)
if len(appAgents) != before {
t.Error("expected no change in registry after registering nil agent")
}

// Agent with empty ID must not be registered.
RegisterAppAgent(&AppAgent{ID: "", Name: "No ID"})
if len(appAgents) != before {
t.Error("expected no change in registry after registering agent with empty ID")
}
}

func TestLoadRegistersBuiltInAgents(t *testing.T) {
// Load() should register the blog, places, and weather agents.
Load()
for _, id := range []string{"blog", "places", "weather"} {
if _, ok := appAgents[id]; !ok {
t.Errorf("expected built-in agent %q to be registered after Load()", id)
}
}
}

func TestBuildRestrictedToolsDesc(t *testing.T) {
tools := []string{"news", "blog_list"}
got := buildRestrictedToolsDesc(tools)
if !strings.Contains(got, "news") {
t.Errorf("expected 'news' in restricted tools desc, got %q", got)
}
if !strings.Contains(got, "blog_list") {
t.Errorf("expected 'blog_list' in restricted tools desc, got %q", got)
}
if strings.Contains(got, "markets") {
t.Errorf("expected 'markets' to be excluded from restricted tools desc, got %q", got)
}
}

func TestBuildRestrictedToolsDesc_UnknownTool(t *testing.T) {
// Unknown tool names should be silently ignored.
got := buildRestrictedToolsDesc([]string{"unknown_tool", "news"})
if strings.Contains(got, "unknown_tool") {
t.Errorf("expected unknown_tool to be excluded, got %q", got)
}
if !strings.Contains(got, "news") {
t.Errorf("expected 'news' in result, got %q", got)
}
}

func TestBuiltInAgentSystemPrompts(t *testing.T) {
Load()

blog := appAgents["blog"]
if blog == nil {
t.Fatal("blog agent not registered")
}
if !strings.Contains(blog.SystemPrompt, "grammar") {
t.Errorf("expected blog agent system prompt to mention grammar, got %q", blog.SystemPrompt)
}

places := appAgents["places"]
if places == nil {
t.Fatal("places agent not registered")
}
if !strings.Contains(places.SystemPrompt, "recommend") {
t.Errorf("expected places agent system prompt to mention recommend, got %q", places.SystemPrompt)
}

weather := appAgents["weather"]
if weather == nil {
t.Fatal("weather agent not registered")
}
if !strings.Contains(weather.SystemPrompt, "clothing") {
t.Errorf("expected weather agent system prompt to mention clothing, got %q", weather.SystemPrompt)
}
}

func TestBuiltInAgentToolRestrictions(t *testing.T) {
Load()

// Blog agent should only allow blog/search tools.
blog := appAgents["blog"]
if blog == nil {
t.Fatal("blog agent not registered")
}
hasBlogList := false
for _, toolName := range blog.Tools {
if toolName == "blog_list" {
hasBlogList = true
}
}
if !hasBlogList {
t.Error("expected blog agent to include blog_list tool")
}

// Places agent should include places_search.
places := appAgents["places"]
if places == nil {
t.Fatal("places agent not registered")
}
hasPlacesSearch := false
for _, toolName := range places.Tools {
if toolName == "places_search" {
hasPlacesSearch = true
}
}
if !hasPlacesSearch {
t.Error("expected places agent to include places_search tool")
}

// Weather agent should include weather_forecast.
weather := appAgents["weather"]
if weather == nil {
t.Fatal("weather agent not registered")
}
hasWeatherForecast := false
for _, toolName := range weather.Tools {
if toolName == "weather_forecast" {
hasWeatherForecast = true
}
}
if !hasWeatherForecast {
t.Error("expected weather agent to include weather_forecast tool")
}
}
Loading