Skip to content
Merged
6 changes: 6 additions & 0 deletions MaxKernel/auto_agent/agent_client/run_batch_agent_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import os
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

Expand Down Expand Up @@ -89,6 +90,11 @@ def process_problem(
user_id = "user_0"
session_id = f"session_{problem_id}_attempt_{attempt}_{int(time.time())}"

# Add random jitter to avoid SQLite database lock contention
jitter = random.uniform(0.1, 2.0)
logger.info(f"Sleeping for {jitter:.2f}s (jitter) to avoid DB lock.")
time.sleep(jitter)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is in #46.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I did rebase and it showed here. I stack these two pr together. I think once the first one is merge and then rabasing, these overlap will disapper?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to rebase the code from other PRs. In fact, this will make merging more complex. If you are to revert the rebase, let's do this.

client = AutoAgentClient(
user_id=user_id,
session_id=session_id,
Expand Down
1 change: 1 addition & 0 deletions MaxKernel/auto_agent/server_utils/eval_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ async def _perform_evaluation(request: EvalRequest):
payload["search_space"] = request.search_space
backend_timeout = request.total_timeout
payload["total_timeout"] = request.total_timeout
payload["dependencies"] = request.dependencies
else:
payload["code"] = request.code
payload["dependencies"] = request.dependencies
Expand Down
78 changes: 60 additions & 18 deletions MaxKernel/auto_agent/server_utils/tpu_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import re
import shutil
import subprocess
import sys
import tempfile
Expand Down Expand Up @@ -49,6 +50,7 @@ class AutotuneRequest(BaseModel):
search_space: dict[str, list]
timeout: Optional[int] = 300
total_timeout: Optional[int] = None
dependencies: Optional[dict] = None


class GetTpuVersionResponse(BaseModel):
Expand Down Expand Up @@ -512,7 +514,17 @@ async def profile(request: CodeRequest):
async def autotune(request: AutotuneRequest):
logging.info("Starting autotune")
async with performance_semaphore:
temp_dir = None
try:
# Create unique temporary directory for this autotune request
temp_dir = tempfile.mkdtemp()
if request.dependencies:
for filename, content in request.dependencies.items():
file_path = os.path.join(temp_dir, filename)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w") as f:
f.write(content)

# Generate all combinations
keys = list(request.search_space.keys())
values = list(request.search_space.values())
Expand Down Expand Up @@ -543,11 +555,9 @@ async def autotune(request: AutotuneRequest):
continue

# Execute the code
with tempfile.NamedTemporaryFile(
mode="w", suffix=".py", prefix="hitl_eval_", delete=False
) as temp_file:
temp_file_path = os.path.join(temp_dir, "run_code.py")
with open(temp_file_path, "w") as temp_file:
temp_file.write(code_content)
temp_file_path = temp_file.name

process = None
try:
Expand All @@ -556,7 +566,7 @@ async def autotune(request: AutotuneRequest):
temp_file_path,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=tempfile.gettempdir(),
cwd=temp_dir,
)

stdout, stderr = await asyncio.wait_for(
Expand All @@ -568,24 +578,45 @@ async def autotune(request: AutotuneRequest):
exit_code = process.returncode

if exit_code == 0:
# Parse RESULT_TIME
match = re.search(r"RESULT_TIME:\s*([0-9.]+)\s*ms", output)
if match:
time_taken = float(match.group(1))
correctness_match = re.search(
r"CORRECTNESS:\s*(true|false)", output, re.IGNORECASE
)
time_match = re.search(
r"RESULT_TIME:\s*([0-9.]+)\s*ms", output, re.IGNORECASE
)

correctness_passed = False
if correctness_match:
correctness_passed = correctness_match.group(1).lower() == "true"

time_taken = float(time_match.group(1)) if time_match else None

if not correctness_passed:
logging.warning(
f"Correctness check failed or unknown for config {cfg}"
)
all_results.append(
{"cfg": cfg, "time": time_taken, "status": "success"}
{
"cfg": cfg,
"status": "correctness_failed_or_unknown",
"output": output,
}
)
if time_taken < best_time:
best_time = time_taken
best_cfg = cfg
best_output = output
else:
elif time_taken is None:
logging.warning(
f"No RESULT_TIME found in output for config {cfg}"
)
all_results.append(
{"cfg": cfg, "status": "no_result_time", "output": output}
)
else:
all_results.append(
{"cfg": cfg, "time": time_taken, "status": "success"}
)
if time_taken < best_time:
best_time = time_taken
best_cfg = cfg
best_output = output
else:
logging.warning(
f"Config {cfg} failed with exit code {exit_code}. Stderr: {error}"
Expand All @@ -602,16 +633,19 @@ async def autotune(request: AutotuneRequest):
except asyncio.TimeoutError:
logging.warning(f"Config {cfg} timed out")
if process:
process.kill()
await process.wait()
try:
process.kill()
await process.wait()
except Exception as e:
logging.error(f"Failed to kill process: {e}")
all_results.append({"cfg": cfg, "status": "timeout"})
except Exception as e:
logging.error(f"Error running config {cfg}: {e}")
all_results.append(
{"cfg": cfg, "status": "exception", "error": str(e)}
)
finally:
if "temp_file_path" in locals():
if "temp_file_path" in locals() and os.path.exists(temp_file_path):
try:
os.unlink(temp_file_path)
except OSError:
Expand All @@ -632,6 +666,14 @@ async def autotune(request: AutotuneRequest):
except Exception as e:
logging.error(f"Autotune failed with error: {str(e)}")
raise HTTPException(status_code=500, detail=f"Autotune error: {str(e)}")
finally:
if temp_dir and os.path.exists(temp_dir):
try:
shutil.rmtree(temp_dir)
except Exception as e:
logging.error(
f"Failed to clean up autotune temp directory {temp_dir}: {e}"
)


@app.post("/get_tpu_version", response_model=GetTpuVersionResponse)
Expand Down
75 changes: 21 additions & 54 deletions MaxKernel/auto_agent/subagents/autotuning/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from typing import AsyncGenerator, Optional

from google.adk.agents import BaseAgent
from google.adk.agents import BaseAgent, SequentialAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.events import Event, EventActions

Expand All @@ -14,7 +14,6 @@
from auto_agent.custom_types import CustomLlmAgent
from auto_agent.subagents.autotuning.autotune_tool import autotune_kernel
from auto_agent.subagents.autotuning.prompts import (
apply_best_config_prompt,
autotune_prompt,
summary_prompt,
)
Expand All @@ -25,7 +24,7 @@
)
from auto_agent.tools.search_api_tool import search_api_tool

# 1. Planner Agent (LLM)
# 1. Planner Agent
# This agent identifies parameters, creates the template, and defines the search space.
# It saves them to session state instead of calling the tool directly.
autotune_planner_agent = CustomLlmAgent(
Expand Down Expand Up @@ -106,6 +105,17 @@ async def _run_async_impl(
)
return

dependencies = {}
base_kernel_path = ctx.session.state.get("base_kernel_path", "")
if base_kernel_path and os.path.exists(base_kernel_path):
try:
with open(base_kernel_path, "r") as f:
dependencies[os.path.basename(base_kernel_path)] = f.read()
except Exception as e:
logging.warning(
f"[{self.name}] Failed to read base kernel file {base_kernel_path}: {e}"
)

logging.info(f"[{self.name}] Running autotune for {kernel_name}")

try:
Expand All @@ -114,6 +124,7 @@ async def _run_async_impl(
code_template=code_template,
search_space=search_space,
backend="tpu",
dependencies=dependencies,
)

try:
Expand Down Expand Up @@ -147,67 +158,23 @@ async def _run_async_impl(
output_key="autotune_results",
)

# 3. Apply Best Config Agent
apply_best_config_agent = CustomLlmAgent(
name="ApplyBestConfigAgent",
model=MODEL_NAME,
generate_content_config=model_config,
planner=thinking_planner,
instruction=apply_best_config_prompt.PROMPT,
description="Applies autotuning results to the optimized kernel file.",
tools=[filesystem_tool_r, write_optimized_kernel_tool],
)

# 4. Summarizer Agent
# 3. Summarizer Agent
# This agent reads results from state and talks to the user.
autotune_summary_agent = CustomLlmAgent(
name="AutotuneSummaryAgent",
model=MODEL_NAME,
generate_content_config=model_config,
planner=thinking_planner,
instruction=summary_prompt.PROMPT,
description="Summarizes autotuning results for the user.",
tools=[filesystem_tool_r],
description="Apply and Summarizes autotuning results.",
tools=[filesystem_tool_r, write_optimized_kernel_tool],
output_key="autotuning_summary",
)


class CombinedAutotuneAgent(BaseAgent):
"""Chains autotuning steps and conditionally applies best config."""

def __init__(self, name: str):
super().__init__(name=name)

async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
logging.info(f"[{self.name}] Running AutotunePlannerAgent...")
async for event in autotune_planner_agent.run_async(ctx):
yield event

logging.info(f"[{self.name}] Running AutotuneRunner...")
async for event in autotune_runner.run_async(ctx):
yield event

autotune_results = ctx.session.state.get("autotune_results", {})
if (
autotune_results.get("status") == "success"
and autotune_results.get("best_config") is not None
):
logging.info(f"[{self.name}] Running ApplyBestConfigAgent...")
async for event in apply_best_config_agent.run_async(ctx):
yield event
else:
logging.warning(
f"[{self.name}] Autotune was not successful or no best configuration"
" found. Skipping ApplyBestConfigAgent."
)

logging.info(f"[{self.name}] Running AutotuneSummaryAgent...")
async for event in autotune_summary_agent.run_async(ctx):
yield event


autotune_agent = CombinedAutotuneAgent(name="AutotuneAgent")
autotune_agent = SequentialAgent(
name="AutotuneAgent",
sub_agents=[autotune_planner_agent, autotune_runner, autotune_summary_agent],
)

__all__ = ["autotune_agent"]
4 changes: 4 additions & 0 deletions MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ async def autotune_kernel(
search_space: dict[str, list[Any]],
backend: str = None,
server_addr: str = "http://localhost",
dependencies: dict[str, str] = None,
) -> dict:
"""Runs a grid search to auto-tune a Pallas kernel on a remote server.

Expand All @@ -31,6 +32,8 @@ async def autotune_kernel(
values.
backend: 'tpu' or 'cpu'.
server_addr: Address of the server (default: http://localhost).
dependencies: A dictionary mapping dependency filenames to their contents.
Used to provide reference code for correctness check.

Returns:
A dictionary containing the status, optimal parameters, and a summary of
Expand All @@ -52,6 +55,7 @@ async def autotune_kernel(
"timeout": AUTOTUNE_INDIVIDUAL_TIMEOUT,
"backend_type": backend,
"total_timeout": AUTOTUNE_TOTAL_TIMEOUT,
"dependencies": dependencies,
}
result = await call_eval_server_async(
session,
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
PROMPT = """You are a specialized agent for preparing autotuning specifications for Pallas kernels.
Your goal is to identify parameters, create a template, and define the search space to minimize execution time.

CRITICAL: Do NOT attempt to optimize the kernel code, improve its logic, or fix any bugs. Your task is strictly to prepare the template for autotuning by replacing hardcoded parameters with placeholders and adding timing code.

To prepare for autotuning, you must:
1. Identify the parameters that can be tuned in the kernel (e.g., BLOCK_M, BLOCK_N).
2. Create a code template from the kernel code, replacing the specific parameter values with placeholders enclosed in curly braces (for example, if the parameter is BLOCK_M, use it enclosed in curly braces as the placeholder).
3. Ensure the template code prints "RESULT_TIME: <float> ms" to indicate the average execution time in microseconds. To get accurate and quick timing, wrap the kernel call in a loop of exactly 10 iterations (preceded by 1 warm-up execution) and use `jax.block_until_ready()`. Limit iterations strictly to 10 to keep profiling runs fast. WARNING: If you wrap the kernel call in a loop, check if the kernel donates its input buffers (look for `donate_argnames` in the kernel decorator). If it does, calling it repeatedly with the same inputs will fail. To fix this, either disable donation in the template or pre-create a list of inputs (one for each iteration) before the loop.
1. Use `read_file` tool to read the optimized kernel code located at {optimized_kernel_path?}.
2. Identify the parameters that can be tuned in the kernel (e.g., BLOCK_M, BLOCK_N).
3. Create a code_template from the optimized kernel code, replacing the specific parameter values with placeholders enclosed in curly braces (for example, if the parameter is BLOCK_M, use it enclosed in curly braces as the placeholder).
To help you build this `code_template`, you must read the reference kernel code located at {base_kernel_path?} and the testing script located at {test_file_path?} to understand the reference computation inputs, outputs, and validation logic.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes we have the test file ready. It won't apply to hitl agent. We need to be careful when we do ablation studies(when running without test agent).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without test file, it will rely only on the reference code.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does it generate correctness test without referring to test file and only on the reference code?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I put a reference to the test file here is just to help autotune to generate test. Without this prompt, I think it will come up with the test just like test agent.
When you adapt this to hitl, you may try if the agent can generate good test by removing this from prompt.

All correctness check and timing logic must be defined inside this `code_template`:
- **Reference Computation**: The reference kernel will be automatically written to a file named `base_kernel.py` in the execution directory. Import functions/implementations directly from it (e.g. `from base_kernel import computation as reference_computation, get_inputs`).
- **Correctness Check**: In the main block of the template, perform a correctness check comparing the tuned kernel's output against the reference implementation's output (using `jnp.allclose` or `np.testing.assert_allclose` with appropriate tolerances, e.g., atol={atol?}, rtol={rtol?}). Note that you must JIT compile both the reference and tuned computation function and invoke the jitted function to obtain its outputs for the correctness check, as Pallas kernels require compilation to execute correctly on TPU.
- **Timing/Warmup**: Wrap the tuned kernel call in a loop of exactly 10 iterations (preceded by 1 warm-up execution) and use `jax.block_until_ready()`. Limit iterations strictly to 10 to keep profiling runs fast. WARNING: If you wrap the kernel call in a loop, check if the kernel donates its input buffers. If it does, calling it repeatedly with the same inputs will fail. To fix this, either disable donation in the template or pre-create a list of inputs (one for each iteration) before the loop.
- **Printing Results**: The template code must always print "CORRECTNESS: <True/False>" and "RESULT_TIME: <float> ms".
4. Define a highly optimized, high-probability search space as a dictionary mapping placeholder names to lists of suggested values. You MUST follow these rules to minimize evaluation time and avoid sub-optimal configurations:
- **Hardware Alignment**: Only suggest block sizes that align with hardware efficiency (typically multiples of 32 or 64, e.g., `[32, 64, 128]`). Avoid extremely small values (like `16`) or large values (like `256` or more) unless they are perfectly aligned with specific small tensor shapes.
- **Dimension Divisors**: Choose suggested block sizes that are clean, even divisors of the corresponding matrix or tensor shape dimensions to prevent compiler masking and branch overhead.
Expand All @@ -25,7 +29,10 @@
1. **`search_api`**: Search for API definitions
2. **`read_file`**: Read the kernel code file.
- Required Argument: `path`
3. **`restricted_write_file`**: Write the json file
- Required Argument: `content` (The complete file content)
- Example: `restricted_write_file(content=...)`
3. **`restricted_write_file`**: Writes the structured autotuning specifications.
- Required Arguments:
- `kernel_name` (string): The name of the Pallas kernel.
- `code_template` (string): The kernel source code template with placeholders.
- `search_space` (dict): Dictionary mapping placeholder names to lists of suggested tuning values.
- Example: `restricted_write_file(kernel_name="pallas_kernel", code_template="...", search_space={"BLOCK_M": [32, 64]})`
"""
Loading