Fix autotune agent from passing incorrect configurations and enforce test agent to use input from reference code #47
Changes from all commits
3a427db
47fd938
98a4b27
509bbb3
1044c1a
d6c2662
34928cd
f12ce95
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,12 +3,16 @@ | |
| PROMPT = """You are a specialized agent for preparing autotuning specifications for Pallas kernels. | ||
| Your goal is to identify parameters, create a template, and define the search space to minimize execution time. | ||
|
|
||
| CRITICAL: Do NOT attempt to optimize the kernel code, improve its logic, or fix any bugs. Your task is strictly to prepare the template for autotuning by replacing hardcoded parameters with placeholders and adding timing code. | ||
|
|
||
| To prepare for autotuning, you must: | ||
| 1. Identify the parameters that can be tuned in the kernel (e.g., BLOCK_M, BLOCK_N). | ||
| 2. Create a code template from the kernel code, replacing the specific parameter values with placeholders enclosed in curly braces (for example, if the parameter is BLOCK_M, use it enclosed in curly braces as the placeholder). | ||
| 3. Ensure the template code prints "RESULT_TIME: <float> ms" to indicate the average execution time in microseconds. To get accurate and quick timing, wrap the kernel call in a loop of exactly 10 iterations (preceded by 1 warm-up execution) and use `jax.block_until_ready()`. Limit iterations strictly to 10 to keep profiling runs fast. WARNING: If you wrap the kernel call in a loop, check if the kernel donates its input buffers (look for `donate_argnames` in the kernel decorator). If it does, calling it repeatedly with the same inputs will fail. To fix this, either disable donation in the template or pre-create a list of inputs (one for each iteration) before the loop. | ||
| 1. Use `read_file` tool to read the optimized kernel code located at {optimized_kernel_path?}. | ||
| 2. Identify the parameters that can be tuned in the kernel (e.g., BLOCK_M, BLOCK_N). | ||
| 3. Create a code_template from the optimized kernel code, replacing the specific parameter values with placeholders enclosed in curly braces (for example, if the parameter is BLOCK_M, use it enclosed in curly braces as the placeholder). | ||
| To help you build this `code_template`, you must read the reference kernel code located at {base_kernel_path?} and the testing script located at {test_file_path?} to understand the reference computation inputs, outputs, and validation logic. | ||
|
Collaborator
This assumes the test file is already available, so it won't apply to the HITL agent. We need to be careful when we run ablation studies (i.e., when running without the test agent).
Collaborator
Author
Without a test file, it will rely only on the reference code.
Collaborator
How does it generate a correctness test from the reference code alone, without referring to the test file?
Collaborator
Author
I added the reference to the test file here only to help the autotune agent generate the test. Without this prompt, I think it will come up with the test on its own, just like the test agent does. |
||
| All correctness check and timing logic must be defined inside this `code_template`: | ||
| - **Reference Computation**: The reference kernel will be automatically written to a file named `base_kernel.py` in the execution directory. Import functions/implementations directly from it (e.g. `from base_kernel import computation as reference_computation, get_inputs`). | ||
| - **Correctness Check**: In the main block of the template, perform a correctness check comparing the tuned kernel's output against the reference implementation's output (using `jnp.allclose` or `np.testing.assert_allclose` with appropriate tolerances, e.g., atol={atol?}, rtol={rtol?}). Note that you must JIT compile both the reference and tuned computation function and invoke the jitted function to obtain its outputs for the correctness check, as Pallas kernels require compilation to execute correctly on TPU. | ||
| - **Timing/Warmup**: Wrap the tuned kernel call in a loop of exactly 10 iterations (preceded by 1 warm-up execution) and use `jax.block_until_ready()`. Limit iterations strictly to 10 to keep profiling runs fast. WARNING: If you wrap the kernel call in a loop, check if the kernel donates its input buffers. If it does, calling it repeatedly with the same inputs will fail. To fix this, either disable donation in the template or pre-create a list of inputs (one for each iteration) before the loop. | ||
| - **Printing Results**: The template code must always print "CORRECTNESS: <True/False>" and "RESULT_TIME: <float> ms". | ||
| 4. Define a highly optimized, high-probability search space as a dictionary mapping placeholder names to lists of suggested values. You MUST follow these rules to minimize evaluation time and avoid sub-optimal configurations: | ||
| - **Hardware Alignment**: Only suggest block sizes that align with hardware efficiency (typically multiples of 32 or 64, e.g., `[32, 64, 128]`). Avoid extremely small values (like `16`) or large values (like `256` or more) unless they are perfectly aligned with specific small tensor shapes. | ||
| - **Dimension Divisors**: Choose suggested block sizes that are clean, even divisors of the corresponding matrix or tensor shape dimensions to prevent compiler masking and branch overhead. | ||
|
|
@@ -25,7 +29,10 @@ | |
| 1. **`search_api`**: Search for API definitions | ||
| 2. **`read_file`**: Read the kernel code file. | ||
| - Required Argument: `path` | ||
| 3. **`restricted_write_file`**: Write the json file | ||
| - Required Argument: `content` (The complete file content) | ||
| - Example: `restricted_write_file(content=...)` | ||
| 3. **`restricted_write_file`**: Writes the structured autotuning specifications. | ||
| - Required Arguments: | ||
| - `kernel_name` (string): The name of the Pallas kernel. | ||
| - `code_template` (string): The kernel source code template with placeholders. | ||
| - `search_space` (dict): Dictionary mapping placeholder names to lists of suggested tuning values. | ||
| - Example: `restricted_write_file(kernel_name="pallas_kernel", code_template="...", search_space={"BLOCK_M": [32, 64]})` | ||
| """ | ||
This part is in #46.
Yes, I rebased and it showed up here. I stacked these two PRs together. I think once the first one is merged and I rebase again, this overlap will disappear?
You don't need to rebase the code from other PRs. In fact, this will make merging more complex. If you are to revert the rebase, let's do this.
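For readers following the prompt change in this PR, here is a minimal sketch of how a `code_template` and `search_space` pair, as passed to `restricted_write_file`, might be consumed downstream. It is not taken from this repository: the helper names `expand_search_space`, `render_template`, and `parse_run_output` are hypothetical, and the plain string replacement of `{NAME}` placeholders is an assumption about how substitution could work. Only the `CORRECTNESS: <True/False>` and `RESULT_TIME: <float> ms` output lines come from the prompt itself.

```python
import itertools
import re


def expand_search_space(search_space):
    """Yield one {placeholder: value} dict per point in the Cartesian product."""
    names = sorted(search_space)
    for values in itertools.product(*(search_space[n] for n in names)):
        yield dict(zip(names, values))


def render_template(code_template, config):
    """Replace each `{NAME}` placeholder with its value.

    Assumption: plain string replacement is used instead of str.format so
    that unrelated braces in the kernel source (dict literals, f-strings)
    are left untouched.
    """
    rendered = code_template
    for name, value in config.items():
        rendered = rendered.replace("{" + name + "}", str(value))
    return rendered


# Patterns for the two lines the prompt requires every template to print.
_CORRECT_RE = re.compile(r"^CORRECTNESS:\s*(True|False)\s*$", re.MULTILINE)
_TIME_RE = re.compile(r"^RESULT_TIME:\s*([0-9.eE+-]+)\s*ms\s*$", re.MULTILINE)


def parse_run_output(stdout):
    """Return (is_correct, time_ms); each is None if its line is missing."""
    correct = _CORRECT_RE.search(stdout)
    timing = _TIME_RE.search(stdout)
    is_correct = (correct.group(1) == "True") if correct else None
    time_ms = float(timing.group(1)) if timing else None
    return is_correct, time_ms


# Hypothetical template fragment; a real template would hold the full kernel.
template = "BLOCK_M = {BLOCK_M}\nBLOCK_N = {BLOCK_N}\nshape = {'m': BLOCK_M}\n"
search_space = {"BLOCK_M": [32, 64], "BLOCK_N": [64, 128]}

configs = list(expand_search_space(search_space))
first = render_template(template, configs[0])
```

Each rendered variant would then be executed and its stdout passed to `parse_run_output`, so a configuration that prints `CORRECTNESS: False` can be discarded before its timing is compared with the others.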