From 8c379c58d052da1091c22c77c0acf1f594572eea Mon Sep 17 00:00:00 2001 From: shangkunwang Date: Wed, 10 Jun 2026 16:57:48 +0000 Subject: [PATCH] feat: enhance autotuning process with ApplyBestConfigAgent and update prompts --- .../agent_client/auto_agent_client.py | 2 +- MaxKernel/auto_agent/custom_types.py | 27 +++++++- .../auto_agent/subagents/autotuning/agent.py | 64 ++++++++++++++++--- .../prompts/apply_best_config_prompt.py | 17 +++++ .../autotuning/prompts/summary_prompt.py | 18 ++---- 5 files changed, 106 insertions(+), 22 deletions(-) create mode 100644 MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py diff --git a/MaxKernel/auto_agent/agent_client/auto_agent_client.py b/MaxKernel/auto_agent/agent_client/auto_agent_client.py index a214f06..e1eb2fb 100644 --- a/MaxKernel/auto_agent/agent_client/auto_agent_client.py +++ b/MaxKernel/auto_agent/agent_client/auto_agent_client.py @@ -4,7 +4,7 @@ import requests -REQUEST_TIMEOUT = 60 * 60 * 3 +REQUEST_TIMEOUT = 60 * 60 * 5 # Configure logging diff --git a/MaxKernel/auto_agent/custom_types.py b/MaxKernel/auto_agent/custom_types.py index b0c4b67..471f973 100644 --- a/MaxKernel/auto_agent/custom_types.py +++ b/MaxKernel/auto_agent/custom_types.py @@ -1,17 +1,40 @@ import logging +from functools import cached_property from typing import AsyncGenerator from google.adk.agents import LlmAgent from google.adk.agents.invocation_context import InvocationContext from google.adk.events import Event, EventActions from google.adk.models.google_llm import Gemini -from google.genai import types +from google.genai import Client, types from auto_agent.constants import ( MODEL_NAME, ) +class TimeoutGemini(Gemini): + @cached_property + def api_client(self) -> Client: + base_url, api_version = self._base_url_and_api_version + kwargs_for_http_options = { + "headers": self._tracking_headers(), + "retry_options": self.retry_options, + "base_url": base_url, + "timeout": 240000, # 240 seconds in milliseconds + } + if api_version: + kwargs_for_http_options["api_version"] = api_version + + kwargs = { + "http_options": types.HttpOptions(**kwargs_for_http_options), + } + if self.model.startswith("projects/"): + kwargs["vertexai"] = True + + return Client(**kwargs) + + class CustomLlmAgent(LlmAgent): """Agent that allows early exit from the loop if a condition is met. @@ -22,7 +45,7 @@ def __init__(self, *args, **kwargs): """Initialize CustomLlmAgent with automatic Gemini model (with retry) wrapping.""" # If model is a string, use the pre-configured gemini_model with retry support if "model" in kwargs and isinstance(kwargs["model"], str): - gemini_model = Gemini( + gemini_model = TimeoutGemini( model=MODEL_NAME, retry_options=types.HttpRetryOptions( initial_delay=1, diff --git a/MaxKernel/auto_agent/subagents/autotuning/agent.py b/MaxKernel/auto_agent/subagents/autotuning/agent.py index d3a5145..df2eb6b 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/agent.py +++ b/MaxKernel/auto_agent/subagents/autotuning/agent.py @@ -5,7 +5,7 @@ import os from typing import AsyncGenerator, Optional -from google.adk.agents import BaseAgent, SequentialAgent +from google.adk.agents import BaseAgent from google.adk.agents.invocation_context import InvocationContext from google.adk.events import Event, EventActions @@ -14,6 +14,7 @@ from auto_agent.custom_types import CustomLlmAgent from auto_agent.subagents.autotuning.autotune_tool import autotune_kernel from auto_agent.subagents.autotuning.prompts import ( + apply_best_config_prompt, autotune_prompt, summary_prompt, ) @@ -159,7 +160,19 @@ async def _run_async_impl( ) -# 3. Summarizer Agent +# 3. Apply Best Config Agent +apply_best_config_agent = CustomLlmAgent( + name="ApplyBestConfigAgent", + model=MODEL_NAME, + generate_content_config=model_config, + planner=thinking_planner, + instruction=apply_best_config_prompt.PROMPT, + description="Applies autotuning results to the optimized kernel file.", + tools=[filesystem_tool_r, write_optimized_kernel_tool], +) + + +# 4. Summarizer Agent # This agent reads results from state and talks to the user. autotune_summary_agent = CustomLlmAgent( name="AutotuneSummaryAgent", @@ -167,14 +180,49 @@ async def _run_async_impl( generate_content_config=model_config, planner=thinking_planner, instruction=summary_prompt.PROMPT, - description="Apply and Summarizes autotuning results.", - tools=[filesystem_tool_r, write_optimized_kernel_tool], + description="Summarizes autotuning results.", + tools=[filesystem_tool_r], output_key="autotuning_summary", ) -autotune_agent = SequentialAgent( - name="AutotuneAgent", - sub_agents=[autotune_planner_agent, autotune_runner, autotune_summary_agent], -) + +class CombinedAutotuneAgent(BaseAgent): + """Chains autotuning steps and conditionally applies best config.""" + + def __init__(self, name: str): + super().__init__(name=name) + + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + logging.info(f"[{self.name}] Running AutotunePlannerAgent...") + async for event in autotune_planner_agent.run_async(ctx): + yield event + + logging.info(f"[{self.name}] Running AutotuneRunner...") + async for event in autotune_runner.run_async(ctx): + yield event + + autotune_results = ctx.session.state.get("autotune_results", {}) + if ( + autotune_results.get("status") == "success" + and autotune_results.get("best_config") is not None + and autotune_results.get("best_time_ms") is not None + ): + logging.info(f"[{self.name}] Running ApplyBestConfigAgent...") + async for event in apply_best_config_agent.run_async(ctx): + yield event + else: + logging.warning( + f"[{self.name}] Autotune was not successful or no best configuration" + " found. Skipping ApplyBestConfigAgent." + ) + + logging.info(f"[{self.name}] Running AutotuneSummaryAgent...") + async for event in autotune_summary_agent.run_async(ctx): + yield event + + +autotune_agent = CombinedAutotuneAgent(name="AutotuneAgent") __all__ = ["autotune_agent"] diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py new file mode 100644 index 0000000..b145038 --- /dev/null +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py @@ -0,0 +1,17 @@ +"""Prompt for ApplyBestConfigAgent.""" + +PROMPT = """You are a specialized agent for applying autotuning results to a Pallas kernel file. +Your goal is to read the best configuration from autotuning results and update the `optimized_kernel.py` file with these values. + +You must: +1. Use the `read_file` tool to read the autotuning specifications file at {autotune_specs_path?} to understand the code template and the placeholders that were tuned. +2. Use the `read_file` tool to read the autotuning results file at {autotune_results_path?} and extract the `"best_config"` from it. +3. Use the `read_file` tool to read the current optimized kernel code located at {optimized_kernel_path?}. +4. Apply `"best_config"` to the optimized kernel code by: + - Comparing the template structure from the specifications with the actual kernel code. + - Replacing the parameter values in the kernel code with the corresponding best values found in `"best_config"` (e.g., replace `BLOCK_M = 32` with `BLOCK_M = 128` if `best_config` contains `"BLOCK_M": 128`). Ensure the formatting of the script remains valid. + - Write the updated optimized kernel code back using the `restricted_write_file` tool. +5. Verify the best configuration is applied correctly by reading the updated file. + +Be precise and ensure you only change the specific parameter values identified in the best configuration. +""" diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py index e606f92..aa7f8c5 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py @@ -1,9 +1,9 @@ """Prompt for AutotuneSummarizerAgent.""" PROMPT = """ -You are applying the best configuration and providing a summary of autotuning results. +You are providing a summary of autotuning results. -Your goal is to summarize the autotuning results provided below, report the best configuration and latency, and apply the best configuration if the status is success. +Your goal is to summarize the autotuning results provided below, report the best configuration and latency, and verify if the best configuration was applied if the status was success. Autotuning Results: {autotune_results?} @@ -13,20 +13,16 @@ ### Case 1: If the status is "success" You must: -1. Extract the `"best_cfg"` and `"best_time_ms"` from the results above. -2. Apply `"best_cfg"` to the kernel code located at {optimized_kernel_path?} by: - a. Use the `read_file` tool to read the kernel code {optimized_kernel_path?} - b. Replace their configured values with the values found in `best_config` (e.g., replace `BLOCK_M = 32` with `BLOCK_M = 128` if `best_config` contains `"BLOCK_M": 128`). Ensure the formatting of the script remains valid. - c. Write the updated optimized kernel code back using `restricted_write_file` tool. -3. Verify the best configuration is applied correctly by reading the updated file. -4. Provide a clear summary in your response. Do NOT list all tested configurations from `all_results`. +1. Extract the `"best_config"` and `"best_time_ms"` from the results above. +2. Verify that the best configuration was applied correctly to the kernel code by reading the file located at {optimized_kernel_path?}. +3. Provide a clear summary in your response. Do NOT list all tested configurations from `all_results`. ### Case 2: If the status is "failed" or "error" You must: -1. Report the error message and do NOT apply any configuration. +1. Report the error message. In all cases, you must: -1. Provide a clear summary in your response. Do NOT list all tested configurations from `all_results`. +Provide a clear summary in your response. Do NOT list all tested configurations from `all_results`. Please use the following format for your summary: ### Autotuning Results