diff --git a/MaxKernel/auto_agent/agent_client/run_batch_agent_call.py b/MaxKernel/auto_agent/agent_client/run_batch_agent_call.py index a63f2c9..f6927e3 100644 --- a/MaxKernel/auto_agent/agent_client/run_batch_agent_call.py +++ b/MaxKernel/auto_agent/agent_client/run_batch_agent_call.py @@ -2,6 +2,7 @@ import json import logging import os +import random import time from concurrent.futures import ThreadPoolExecutor, as_completed @@ -89,6 +90,11 @@ def process_problem( user_id = "user_0" session_id = f"session_{problem_id}_attempt_{attempt}_{int(time.time())}" + # Add random jitter to avoid SQLite database lock contention + jitter = random.uniform(0.1, 2.0) + logger.info(f"Sleeping for {jitter:.2f}s (jitter) to avoid DB lock.") + time.sleep(jitter) + client = AutoAgentClient( user_id=user_id, session_id=session_id, diff --git a/MaxKernel/auto_agent/server_utils/eval_server.py b/MaxKernel/auto_agent/server_utils/eval_server.py index 0ba22a5..c6d3385 100644 --- a/MaxKernel/auto_agent/server_utils/eval_server.py +++ b/MaxKernel/auto_agent/server_utils/eval_server.py @@ -228,6 +228,7 @@ async def _perform_evaluation(request: EvalRequest): payload["search_space"] = request.search_space backend_timeout = request.total_timeout payload["total_timeout"] = request.total_timeout + payload["dependencies"] = request.dependencies else: payload["code"] = request.code payload["dependencies"] = request.dependencies diff --git a/MaxKernel/auto_agent/server_utils/tpu_server.py b/MaxKernel/auto_agent/server_utils/tpu_server.py index f23ab31..0ceb994 100644 --- a/MaxKernel/auto_agent/server_utils/tpu_server.py +++ b/MaxKernel/auto_agent/server_utils/tpu_server.py @@ -4,6 +4,7 @@ import logging import os import re +import shutil import subprocess import sys import tempfile @@ -49,6 +50,7 @@ class AutotuneRequest(BaseModel): search_space: dict[str, list] timeout: Optional[int] = 300 total_timeout: Optional[int] = None + dependencies: Optional[dict] = None class GetTpuVersionResponse(BaseModel): @@ -512,7 +514,17 @@ async def profile(request: CodeRequest): async def autotune(request: AutotuneRequest): logging.info("Starting autotune") async with performance_semaphore: + temp_dir = None try: + # Create unique temporary directory for this autotune request + temp_dir = tempfile.mkdtemp() + if request.dependencies: + for filename, content in request.dependencies.items(): + file_path = os.path.join(temp_dir, filename) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "w") as f: + f.write(content) + # Generate all combinations keys = list(request.search_space.keys()) values = list(request.search_space.values()) @@ -543,11 +555,9 @@ async def autotune(request: AutotuneRequest): continue # Execute the code - with tempfile.NamedTemporaryFile( - mode="w", suffix=".py", prefix="hitl_eval_", delete=False - ) as temp_file: + temp_file_path = os.path.join(temp_dir, "run_code.py") + with open(temp_file_path, "w") as temp_file: temp_file.write(code_content) - temp_file_path = temp_file.name process = None try: @@ -556,7 +566,7 @@ async def autotune(request: AutotuneRequest): temp_file_path, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - cwd=tempfile.gettempdir(), + cwd=temp_dir, ) stdout, stderr = await asyncio.wait_for( @@ -568,24 +578,45 @@ async def autotune(request: AutotuneRequest): exit_code = process.returncode if exit_code == 0: - # Parse RESULT_TIME - match = re.search(r"RESULT_TIME:\s*([0-9.]+)\s*ms", output) - if match: - time_taken = float(match.group(1)) + correctness_match = re.search( + r"CORRECTNESS:\s*(true|false)", output, re.IGNORECASE + ) + time_match = re.search( + r"RESULT_TIME:\s*([0-9.]+)\s*ms", output, re.IGNORECASE + ) + + correctness_passed = False + if correctness_match: + correctness_passed = correctness_match.group(1).lower() == "true" + + time_taken = float(time_match.group(1)) if time_match else None + + if not correctness_passed: + logging.warning( + f"Correctness check failed or unknown for config {cfg}" + ) all_results.append( - {"cfg": cfg, "time": time_taken, "status": "success"} + { + "cfg": cfg, + "status": "correctness_failed_or_unknown", + "output": output, + } ) - if time_taken < best_time: - best_time = time_taken - best_cfg = cfg - best_output = output - else: + elif time_taken is None: logging.warning( f"No RESULT_TIME found in output for config {cfg}" ) all_results.append( {"cfg": cfg, "status": "no_result_time", "output": output} ) + else: + all_results.append( + {"cfg": cfg, "time": time_taken, "status": "success"} + ) + if time_taken < best_time: + best_time = time_taken + best_cfg = cfg + best_output = output else: logging.warning( f"Config {cfg} failed with exit code {exit_code}. Stderr: {error}" @@ -602,8 +633,11 @@ async def autotune(request: AutotuneRequest): except asyncio.TimeoutError: logging.warning(f"Config {cfg} timed out") if process: - process.kill() - await process.wait() + try: + process.kill() + await process.wait() + except Exception as e: + logging.error(f"Failed to kill process: {e}") all_results.append({"cfg": cfg, "status": "timeout"}) except Exception as e: logging.error(f"Error running config {cfg}: {e}") @@ -611,7 +645,7 @@ async def autotune(request: AutotuneRequest): {"cfg": cfg, "status": "exception", "error": str(e)} ) finally: - if "temp_file_path" in locals(): + if "temp_file_path" in locals() and os.path.exists(temp_file_path): try: os.unlink(temp_file_path) except OSError: @@ -632,6 +666,14 @@ async def autotune(request: AutotuneRequest): except Exception as e: logging.error(f"Autotune failed with error: {str(e)}") raise HTTPException(status_code=500, detail=f"Autotune error: {str(e)}") + finally: + if temp_dir and os.path.exists(temp_dir): + try: + shutil.rmtree(temp_dir) + except Exception as e: + logging.error( + f"Failed to clean up autotune temp directory {temp_dir}: {e}" + ) @app.post("/get_tpu_version", response_model=GetTpuVersionResponse) diff --git a/MaxKernel/auto_agent/subagents/autotuning/agent.py b/MaxKernel/auto_agent/subagents/autotuning/agent.py index 80fff69..d3a5145 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/agent.py +++ b/MaxKernel/auto_agent/subagents/autotuning/agent.py @@ -5,7 +5,7 @@ import os from typing import AsyncGenerator, Optional -from google.adk.agents import BaseAgent +from google.adk.agents import BaseAgent, SequentialAgent from google.adk.agents.invocation_context import InvocationContext from google.adk.events import Event, EventActions @@ -14,7 +14,6 @@ from auto_agent.custom_types import CustomLlmAgent from auto_agent.subagents.autotuning.autotune_tool import autotune_kernel from auto_agent.subagents.autotuning.prompts import ( - apply_best_config_prompt, autotune_prompt, summary_prompt, ) @@ -25,7 +24,7 @@ ) from auto_agent.tools.search_api_tool import search_api_tool -# 1. Planner Agent (LLM) +# 1. Planner Agent # This agent identifies parameters, creates the template, and defines the search space. # It saves them to session state instead of calling the tool directly. autotune_planner_agent = CustomLlmAgent( @@ -106,6 +105,17 @@ async def _run_async_impl( ) return + dependencies = {} + base_kernel_path = ctx.session.state.get("base_kernel_path", "") + if base_kernel_path and os.path.exists(base_kernel_path): + try: + with open(base_kernel_path, "r") as f: + dependencies[os.path.basename(base_kernel_path)] = f.read() + except Exception as e: + logging.warning( + f"[{self.name}] Failed to read base kernel file {base_kernel_path}: {e}" + ) + logging.info(f"[{self.name}] Running autotune for {kernel_name}") try: @@ -114,6 +124,7 @@ async def _run_async_impl( code_template=code_template, search_space=search_space, backend="tpu", + dependencies=dependencies, ) try: @@ -147,18 +158,8 @@ async def _run_async_impl( output_key="autotune_results", ) -# 3. Apply Best Config Agent -apply_best_config_agent = CustomLlmAgent( - name="ApplyBestConfigAgent", - model=MODEL_NAME, - generate_content_config=model_config, - planner=thinking_planner, - instruction=apply_best_config_prompt.PROMPT, - description="Applies autotuning results to the optimized kernel file.", - tools=[filesystem_tool_r, write_optimized_kernel_tool], -) -# 4. Summarizer Agent +# 3. Summarizer Agent # This agent reads results from state and talks to the user. autotune_summary_agent = CustomLlmAgent( name="AutotuneSummaryAgent", @@ -166,48 +167,14 @@ async def _run_async_impl( generate_content_config=model_config, planner=thinking_planner, instruction=summary_prompt.PROMPT, - description="Summarizes autotuning results for the user.", - tools=[filesystem_tool_r], + description="Apply and Summarizes autotuning results.", + tools=[filesystem_tool_r, write_optimized_kernel_tool], output_key="autotuning_summary", ) - -class CombinedAutotuneAgent(BaseAgent): - """Chains autotuning steps and conditionally applies best config.""" - - def __init__(self, name: str): - super().__init__(name=name) - - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - logging.info(f"[{self.name}] Running AutotunePlannerAgent...") - async for event in autotune_planner_agent.run_async(ctx): - yield event - - logging.info(f"[{self.name}] Running AutotuneRunner...") - async for event in autotune_runner.run_async(ctx): - yield event - - autotune_results = ctx.session.state.get("autotune_results", {}) - if ( - autotune_results.get("status") == "success" - and autotune_results.get("best_config") is not None - ): - logging.info(f"[{self.name}] Running ApplyBestConfigAgent...") - async for event in apply_best_config_agent.run_async(ctx): - yield event - else: - logging.warning( - f"[{self.name}] Autotune was not successful or no best configuration" - " found. Skipping ApplyBestConfigAgent." - ) - - logging.info(f"[{self.name}] Running AutotuneSummaryAgent...") - async for event in autotune_summary_agent.run_async(ctx): - yield event - - -autotune_agent = CombinedAutotuneAgent(name="AutotuneAgent") +autotune_agent = SequentialAgent( + name="AutotuneAgent", + sub_agents=[autotune_planner_agent, autotune_runner, autotune_summary_agent], +) __all__ = ["autotune_agent"] diff --git a/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py b/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py index 6cf3151..524ade7 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py +++ b/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py @@ -19,6 +19,7 @@ async def autotune_kernel( search_space: dict[str, list[Any]], backend: str = None, server_addr: str = "http://localhost", + dependencies: dict[str, str] = None, ) -> dict: """Runs a grid search to auto-tune a Pallas kernel on a remote server. @@ -31,6 +32,8 @@ async def autotune_kernel( values. backend: 'tpu' or 'cpu'. server_addr: Address of the server (default: http://localhost). + dependencies: A dictionary mapping dependency filenames to their contents. + Used to provide reference code for correctness check. Returns: A dictionary containing the status, optimal parameters, and a summary of @@ -52,6 +55,7 @@ async def autotune_kernel( "timeout": AUTOTUNE_INDIVIDUAL_TIMEOUT, "backend_type": backend, "total_timeout": AUTOTUNE_TOTAL_TIMEOUT, + "dependencies": dependencies, } result = await call_eval_server_async( session, diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py deleted file mode 100644 index 363b123..0000000 --- a/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Prompt for ApplyBestConfigAgent.""" - -PROMPT = """You are a specialized agent for applying autotuning results to a Pallas kernel file. -Your goal is to read the best configuration from autotuning results and update the `optimized_kernel.py` file with these values. - -You must: -1. use the `read_file` tool to read the file at {autotune_specs_path?} to get the context of autotuning experiment. -2. Use the `read_file` tool to read the file at {autotune_results_path?} and parse the JSON content of the autotune results to find the best configuration `best_config`. -3. Use the `read_file` tool to read the current kernel file at {optimized_kernel_path?}. -4. Use the `restricted_write_file` tool to save the updated kernel file, replacing the old parameter values with the values from `best_config`. - -Be precise and ensure you only change the specific parameter values identified in the best configuration. -""" diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py index 23a5e2a..34db510 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py @@ -3,12 +3,16 @@ PROMPT = """You are a specialized agent for preparing autotuning specifications for Pallas kernels. Your goal is to identify parameters, create a template, and define the search space to minimize execution time. -CRITICAL: Do NOT attempt to optimize the kernel code, improve its logic, or fix any bugs. Your task is strictly to prepare the template for autotuning by replacing hardcoded parameters with placeholders and adding timing code. - To prepare for autotuning, you must: -1. Identify the parameters that can be tuned in the kernel (e.g., BLOCK_M, BLOCK_N). -2. Create a code template from the kernel code, replacing the specific parameter values with placeholders enclosed in curly braces (for example, if the parameter is BLOCK_M, use it enclosed in curly braces as the placeholder). -3. Ensure the template code prints "RESULT_TIME: ms" to indicate the average execution time in microseconds. To get accurate and quick timing, wrap the kernel call in a loop of exactly 10 iterations (preceded by 1 warm-up execution) and use `jax.block_until_ready()`. Limit iterations strictly to 10 to keep profiling runs fast. WARNING: If you wrap the kernel call in a loop, check if the kernel donates its input buffers (look for `donate_argnames` in the kernel decorator). If it does, calling it repeatedly with the same inputs will fail. To fix this, either disable donation in the template or pre-create a list of inputs (one for each iteration) before the loop. +1. Use `read_file` tool to read the optimized kernel code located at {optimized_kernel_path?}. +2. Identify the parameters that can be tuned in the kernel (e.g., BLOCK_M, BLOCK_N). +3. Create a code_template from the optimized kernel code, replacing the specific parameter values with placeholders enclosed in curly braces (for example, if the parameter is BLOCK_M, use it enclosed in curly braces as the placeholder). + To help you build this `code_template`, you must read the reference kernel code located at {base_kernel_path?} and the testing script located at {test_file_path?} to understand the reference computation inputs, outputs, and validation logic. + All correctness check and timing logic must be defined inside this `code_template`: + - **Reference Computation**: The reference kernel will be automatically written to a file named `base_kernel.py` in the execution directory. Import functions/implementations directly from it (e.g. `from base_kernel import computation as reference_computation, get_inputs`). + - **Correctness Check**: In the main block of the template, perform a correctness check comparing the tuned kernel's output against the reference implementation's output (using `jnp.allclose` or `np.testing.assert_allclose` with appropriate tolerances, e.g., atol={atol?}, rtol={rtol?}). Note that you must JIT compile both the reference and tuned computation function and invoke the jitted function to obtain its outputs for the correctness check, as Pallas kernels require compilation to execute correctly on TPU. + - **Timing/Warmup**: Wrap the tuned kernel call in a loop of exactly 10 iterations (preceded by 1 warm-up execution) and use `jax.block_until_ready()`. Limit iterations strictly to 10 to keep profiling runs fast. WARNING: If you wrap the kernel call in a loop, check if the kernel donates its input buffers. If it does, calling it repeatedly with the same inputs will fail. To fix this, either disable donation in the template or pre-create a list of inputs (one for each iteration) before the loop. + - **Printing Results**: The template code must always print "CORRECTNESS: " and "RESULT_TIME: ms". 4. Define a highly optimized, high-probability search space as a dictionary mapping placeholder names to lists of suggested values. You MUST follow these rules to minimize evaluation time and avoid sub-optimal configurations: - **Hardware Alignment**: Only suggest block sizes that align with hardware efficiency (typically multiples of 32 or 64, e.g., `[32, 64, 128]`). Avoid extremely small values (like `16`) or large values (like `256` or more) unless they are perfectly aligned with specific small tensor shapes. - **Dimension Divisors**: Choose suggested block sizes that are clean, even divisors of the corresponding matrix or tensor shape dimensions to prevent compiler masking and branch overhead. @@ -25,7 +29,10 @@ 1. **`search_api`**: Search for API definitions 2. **`read_file`**: Read the kernel code file. - Required Argument: `path` -3. **`restricted_write_file`**: Write the json file - - Required Argument: `content` (The complete file content) - - Example: `restricted_write_file(content=...)` +3. **`restricted_write_file`**: Writes the structured autotuning specifications. + - Required Arguments: + - `kernel_name` (string): The name of the Pallas kernel. + - `code_template` (string): The kernel source code template with placeholders. + - `search_space` (dict): Dictionary mapping placeholder names to lists of suggested tuning values. + - Example: `restricted_write_file(kernel_name="pallas_kernel", code_template="...", search_space={"BLOCK_M": [32, 64]})` """ diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py index 70adb3e..e606f92 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py @@ -1,18 +1,32 @@ """Prompt for AutotuneSummarizerAgent.""" PROMPT = """ -You are an AI assistant summarizing autotuning results for a Pallas kernel optimization task. +You are applying the best configuration and providing a summary of autotuning results. -Your goal is to summarize the autotuning results provided below, report the best configuration and latency to the user, and state whether this configuration was applied. +Your goal is to summarize the autotuning results provided below, report the best configuration and latency, and apply the best configuration if the status is success. Autotuning Results: -{autotune_results} +{autotune_results?} -Instructions: -1. **Extract Metrics**: Find the `"best_cfg"` and `"best_time_ms"` in the results above. -2. **Summarize**: Provide a clear summary in your response. Do NOT list all tested configurations from `all_results`. -3. **Verify Application**: To determine if the best configuration was applied, read the file at {optimized_kernel_path?} and verify that the configuration parameters in the file match the values listed in `"best_config"` from the autotuning results. State whether it was applied. -4. **Handle Errors**: If the status is `"failed"` or `"error"`, report the error message provided in the file. + +Check the status of the autotuning results: + +### Case 1: If the status is "success" +You must: +1. Extract the `"best_cfg"` and `"best_time_ms"` from the results above. +2. Apply `"best_cfg"` to the kernel code located at {optimized_kernel_path?} by: + a. Use the `read_file` tool to read the kernel code {optimized_kernel_path?} + b. Replace their configured values with the values found in `best_config` (e.g., replace `BLOCK_M = 32` with `BLOCK_M = 128` if `best_config` contains `"BLOCK_M": 128`). Ensure the formatting of the script remains valid. + c. Write the updated optimized kernel code back using `restricted_write_file` tool. +3. Verify the best configuration is applied correctly by reading the updated file. +4. Provide a clear summary in your response. Do NOT list all tested configurations from `all_results`. + +### Case 2: If the status is "failed" or "error" +You must: +1. Report the error message and do NOT apply any configuration. + +In all cases, you must: +1. Provide a clear summary in your response. Do NOT list all tested configurations from `all_results`. Please use the following format for your summary: ### Autotuning Results @@ -20,6 +34,4 @@ - **Best Configuration**: `[JSON or description of best config]` - **Latency**: `[Time]` ms - **Applied to File**: [Yes / No] - -[Any additional brief notes or error messages] """ diff --git a/MaxKernel/auto_agent/subagents/pipeline_agent.py b/MaxKernel/auto_agent/subagents/pipeline_agent.py index ad06020..cea9a88 100644 --- a/MaxKernel/auto_agent/subagents/pipeline_agent.py +++ b/MaxKernel/auto_agent/subagents/pipeline_agent.py @@ -3,6 +3,7 @@ import logging import os import re +import shutil from typing import AsyncGenerator from google.adk.agents import BaseAgent @@ -85,6 +86,9 @@ async def _run_async_impl( logging.error( f"[{self.name}] Compilation failed. Looping back to planning." ) + self._save_iteration_files( + ctx, iteration, keys_to_save=["optimized_kernel_path"] + ) iteration += 1 continue @@ -99,6 +103,11 @@ async def _run_async_impl( logging.error( f"[{self.name}] Test generation/validation failed. Looping back to planning." ) + self._save_iteration_files( + ctx, + iteration, + keys_to_save=["optimized_kernel_path", "test_file_path"], + ) iteration += 1 continue @@ -111,6 +120,11 @@ async def _run_async_impl( test_results = ctx.session.state.get("test_results", {}) if not test_results.get("success", False): logging.error(f"[{self.name}] Tests failed. Looping back to planning.") + self._save_iteration_files( + ctx, + iteration, + keys_to_save=["optimized_kernel_path", "test_file_path"], + ) iteration += 1 continue @@ -161,9 +175,11 @@ async def _run_async_impl( ) logging.info(f"[{self.name}] Saved snapshot for iteration {iteration}") - # Step 7: Check if improvement is needed - needs_improvement = ctx.session.state.get("needs_improvement", False) + self._save_iteration_files(ctx, iteration) + # Step 7: Check if improvement is needed + # needs_improvement = ctx.session.state.get("needs_improvement", False) + needs_improvement = True if not needs_improvement: logging.info( f"[{self.name}] No further improvement needed or agent decided to stop. Stopping pipeline." @@ -193,13 +209,42 @@ async def _run_async_impl( ), ) + def _save_iteration_files( + self, + ctx: InvocationContext, + iteration: int, + keys_to_save: list[str] | None = None, + ): + """Saves artifacts with an iteration suffix.""" + if keys_to_save is None: + keys_to_save = [ + "optimized_kernel_path", + "test_file_path", + "autotune_specs_path", + "autotune_results_path", + ] + for path_key in keys_to_save: + path = ctx.session.state.get(path_key) + if path and os.path.exists(path): + directory, filename = os.path.split(path) + name, ext = os.path.splitext(filename) + new_filename = f"{name}_{iteration}{ext}" + new_path = os.path.join(directory, new_filename) + try: + shutil.copy2(path, new_path) + logging.info(f"[{self.name}] Copied {path_key} to {new_path}") + except Exception as e: + logging.error( + f"[{self.name}] Failed to copy {path_key} to {new_path}: {e}" + ) + def _initialize_state(self, ctx: InvocationContext) -> Event: """Initializes session state with standard paths and returns the event.""" # Initialize history if "history" not in ctx.session.state: ctx.session.state["history"] = [] - # Explicitly dictate standard paths in state + # Path related states session_dir = os.path.join(WORKDIR, ctx.session.id) os.makedirs(session_dir, exist_ok=True) @@ -263,6 +308,15 @@ def _initialize_state(self, ctx: InvocationContext) -> Event: f"[{self.name}] Set autotune_results_path: {ctx.session.state['autotune_results_path']}" ) + # Test related states + if "atol" not in ctx.session.state: + ctx.session.state["atol"] = 1e-2 + logging.info(f"[{self.name}] Set atol: {ctx.session.state['atol']}") + + if "rtol" not in ctx.session.state: + ctx.session.state["rtol"] = 1e-2 + logging.info(f"[{self.name}] Set rtol: {ctx.session.state['rtol']}") + logging.info(f"[{self.name}] Published explicit path state update Event.") return Event( author=self.name, @@ -274,6 +328,10 @@ def _initialize_state(self, ctx: InvocationContext) -> Event: "kernel_plan_path": ctx.session.state["kernel_plan_path"], "test_file_path": ctx.session.state["test_file_path"], "profiling_script_path": ctx.session.state["profiling_script_path"], + "autotune_specs_path": ctx.session.state["autotune_specs_path"], + "autotune_results_path": ctx.session.state["autotune_results_path"], + "atol": ctx.session.state["atol"], + "rtol": ctx.session.state["rtol"], } ), ) diff --git a/MaxKernel/auto_agent/subagents/testing/agent.py b/MaxKernel/auto_agent/subagents/testing/agent.py index c120de1..4cfb265 100644 --- a/MaxKernel/auto_agent/subagents/testing/agent.py +++ b/MaxKernel/auto_agent/subagents/testing/agent.py @@ -33,6 +33,7 @@ COMPILE_VALIDATION_TIMEOUT = 60 * 1 MOCK_EXECUTION_TIMEOUT = 60 * 3 TEST_EXECUTION_TIMEOUT = 60 * 5 +TEST_EXECUTION_POLL_INTERVAL = 20 class TestRunner(BaseAgent): @@ -619,7 +620,7 @@ async def _run_async_impl( "eval_type": "unified_test", "code": mock_code_content, "timeout": MOCK_EXECUTION_TIMEOUT, - "backend_type": "cpu", + "backend_type": "tpu", "dependencies": dependencies, } diff --git a/MaxKernel/auto_agent/subagents/testing/prompts/fix_test_script.py b/MaxKernel/auto_agent/subagents/testing/prompts/fix_test_script.py index 5b79012..d2a2b38 100644 --- a/MaxKernel/auto_agent/subagents/testing/prompts/fix_test_script.py +++ b/MaxKernel/auto_agent/subagents/testing/prompts/fix_test_script.py @@ -72,6 +72,8 @@ 2. **DO NOT modify the actual kernel implementations** - only fix the test file 3. **Keep all test classes and test methods** - just fix syntax/import/structure issues 4. Focus ONLY on making the test file valid Python code with correct imports and pytest structure +5. **Numerical Tolerance**: Use the specified tolerances: atol={atol?}, rtol={rtol?}. If they are not specified, default to atol=1e-3, rtol=1e-3. +6. **Input Generation**: Ensure that if the base kernel file (`{base_kernel_path?}`) defines an input generation function (e.g. `create_inputs`), it is reused/imported and used directly. ### Step 4: Write the Fixed Test File diff --git a/MaxKernel/auto_agent/subagents/testing/prompts/gen_test_file.py b/MaxKernel/auto_agent/subagents/testing/prompts/gen_test_file.py index b139157..6e2a141 100644 --- a/MaxKernel/auto_agent/subagents/testing/prompts/gen_test_file.py +++ b/MaxKernel/auto_agent/subagents/testing/prompts/gen_test_file.py @@ -34,7 +34,6 @@ All files are located in: {workdir}" -**Then wait for the user's response before proceeding.** ## Tool Usage @@ -66,6 +65,7 @@ - The function names and signatures - Input/output shapes and types - Any configuration parameters (block_size, tile_size, etc.) + - Check if the base kernel (`{base_kernel_path?}`) contains an input generation function (e.g., `get_inputs` or similar). If it is defined, you should directly copy and reuse it in the test file. 2. **Generate a complete pytest test file** that includes: @@ -74,18 +74,19 @@ - Test that the base kernel compiles (for reference) 2. **TestCorrectness class**: Tests that verify numerical correctness - - Compare outputs between base and optimized kernels - - Test with multiple input sizes using pytest.mark.parametrize - - Test edge cases (zeros, ones, random inputs) - - Use appropriate tolerance (rtol=1e-5, atol=1e-5 or adjust based on kernel) + - Compare outputs between base and optimized kernels with the tolerance specified in the state: rtol={rtol?}, atol={atol?}. If they are not specified, use appropriate tolerance (rtol=1e-3, atol=1e-3 or adjust based on kernel) + - If the base kernel file (`{base_kernel_path?}`) contains an input generation function (e.g., `get_inputs` or similar). If it is defined, you should directly copy and reuse it in the test file. + - If no input generation function is defined, + - Test with multiple input sizes using pytest.mark.parametrize + - Test edge cases (zeros, ones, random inputs) - **Note**: During validation, the optimized kernel import will be temporarily disabled to verify the test structure works with baseline only 3. **TestPerformance class**: Tests that benchmark performance + - If the base kernel file (`{base_kernel_path?}`) contains a function to generate inputs (e.g., `get_inputs` or similar). If it is defined, you should directly copy and reuse it in the test file. - Compare execution time between base and optimized kernels - - Test different tiling/block size configurations if applicable - Include warmup runs before timing - Use .block_until_ready() for accurate JAX timing - - Run 20 iterations for each benchmark to get reliable timing measurements + - Run 10 iterations for each benchmark to get reliable timing measurements ## Requirements @@ -175,10 +176,11 @@ def test_performance_comparison(self): - Look for functions that match the operation (e.g., matmul, flash_attention, conv2d) - Test with appropriate input shapes based on the kernel type -3. **Realistic test data**: - - Use appropriate input sizes based on the kernel - - Generate random data with jax.random for correctness tests - - Use larger sizes for performance tests +3. **Realistic test data & Input Generation**: + - If the base kernel file (`{base_kernel_path?}`) defines a function to generate inputs (such as `get_test_inputs(...)`), you MUST directly reuse/import or copy and use it to construct test inputs. + - If no input generation function is defined, generate random data with `jax.random` suitable for the kernel. + - Use appropriate input sizes based on the kernel configuration. + - Use larger sizes for performance tests if needed. 4. **Error handling**: Include try-catch where compilation might fail diff --git a/MaxKernel/auto_agent/tools/file_tools.py b/MaxKernel/auto_agent/tools/file_tools.py index bc24a5f..8fc47ae 100644 --- a/MaxKernel/auto_agent/tools/file_tools.py +++ b/MaxKernel/auto_agent/tools/file_tools.py @@ -1,7 +1,9 @@ """File-related tools for subagents.""" +import json import os from pathlib import Path +from typing import Any, Dict, List from google.adk.tools import FunctionTool, ToolContext from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams @@ -74,6 +76,45 @@ def _write_file(content: str, tool_context: ToolContext) -> str: return FunctionTool(_write_file) +def write_autotune_specs_tool_fn( + kernel_name: str, + code_template: str, + search_space: Dict[str, List[Any]], + tool_context: ToolContext, +) -> str: + """Writes the structured autotuning specifications to autotune_specs_path in session state. + + Args: + kernel_name: The name of the Pallas kernel. + code_template: The kernel source code template with placeholders like {BLOCK_M}. + search_space: Dictionary mapping placeholder names to lists of suggested tuning values. + """ + target_path = tool_context.state.get("autotune_specs_path") + if not target_path: + return ( + "Error: Path variable 'autotune_specs_path' not found in session state." + ) + + base = Path(WORKDIR).resolve() + target = Path(target_path).resolve() + + try: + if not target.is_relative_to(base): + return f"Error: Access denied. Path is outside {WORKDIR}" + except ValueError: + return "Error: Invalid path or access denied." + + content_dict = { + "kernel_name": kernel_name, + "code_template": code_template, + "search_space": search_space, + } + + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(json.dumps(content_dict, indent=2)) + return f"Successfully wrote structured autotuning specs to {target}" + + write_test_file_tool = restricted_write_file( "test_file_path", "Writes the generated pytest file." ) @@ -86,9 +127,8 @@ def _write_file(content: str, tool_context: ToolContext) -> str: write_profiling_script_tool = restricted_write_file( "profiling_script_path", "Writes the profiling script." ) -write_autotune_specs_tool = restricted_write_file( - "autotune_specs_path", "Writes the autotuning specifications." -) +write_autotune_specs_tool_fn.__name__ = "restricted_write_file" +write_autotune_specs_tool = FunctionTool(write_autotune_specs_tool_fn) __all__ = [ "filesystem_tool_r",