Fix autotune agent from passing incorrect configurations and enforce test agent to use input from reference code #47
73d4f4a to 10f7c2b
…t into AutotuneSummaryAgent and change write_autotune_specs_tool to force JSON format
…l tolerance into session state and test prompts
10f7c2b to 1044c1a
…ructions and error handling logic
9dbcaf7 to 0a77e2e
…ver and client pipelines
0a77e2e to 34928cd
    jitter = random.uniform(0.1, 2.0)
    logger.info(f"Sleeping for {jitter:.2f}s (jitter) to avoid DB lock.")
    time.sleep(jitter)
Yeah, I rebased and the changes showed up here. I stacked these two PRs together. I think once the first one is merged and I rebase again, this overlap will disappear?
You don't need to rebase onto the code from other PRs. In fact, this makes merging more complex. If you are able to revert the rebase, let's do that.
    ),
    )

    def _save_iteration_files(
This code change is in the previous PR too.
Same as before.
Let's remove the changes from the other PR.
    "code": mock_code_content,
    "timeout": MOCK_EXECUTION_TIMEOUT,
    "backend_type": "cpu",
    "backend_type": "tpu",
If the original code has a Pallas kernel, the mock test can fail.
    COMPILE_VALIDATION_TIMEOUT = 60 * 1
    MOCK_EXECUTION_TIMEOUT = 60 * 3
    TEST_EXECUTION_TIMEOUT = 60 * 5
    TEST_EXECUTION_POLL_INTERVAL = 20
Also in previous PR.
Where is this variable used?
Same as above.
    To help you build this `code_template`, you must read the reference kernel code located at {base_kernel_path?} and the testing script located at {test_file_path?} to understand the reference computation inputs, outputs, and validation logic.
    All correctness check and timing logic must be defined inside this `code_template`:
    - **Reference Computation**: The reference kernel will be automatically written to a file named `base_kernel.py` in the execution directory. Import functions/implementations directly from it (e.g. `from base_kernel import computation as reference_computation, get_inputs`).
    - **Correctness Check**: In the main block of the template, perform a correctness check comparing the tuned kernel's output against the reference implementation's output (using `jnp.allclose` or `np.testing.assert_allclose` with appropriate tolerances, e.g., atol={atol?}, rtol={rtol?}).
What are the default values for atol and rtol? I think we should use the same values throughout testing, autotuning, and benchmarking.
In my experiments it is set to 1e-2. The autotune agent looks at that value as well as the test file, so that it uses the same tolerance.
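For reference, a minimal sketch of the kind of correctness check the prompt asks for. It is written in plain Python so it runs without dependencies, and it uses the elementwise criterion that NumPy documents for `allclose`, `|actual - desired| <= atol + rtol * |desired|`. The constants `ATOL`/`RTOL`, the helper `within_tolerance`, and the sample outputs are illustrative assumptions, not code from this PR.

```python
# Assumed shared defaults; 1e-2 is the value mentioned in this thread.
ATOL = 1e-2
RTOL = 1e-2


def within_tolerance(actual, desired, atol=ATOL, rtol=RTOL):
    """Return True if every element satisfies |a - d| <= atol + rtol * |d|."""
    if len(actual) != len(desired):
        return False
    return all(abs(a - d) <= atol + rtol * abs(d) for a, d in zip(actual, desired))


# Stand-ins for the reference and tuned kernel outputs (made-up numbers).
reference_out = [1.0, 2.0, 3.0]
tuned_out = [1.005, 2.0, 2.99]

ok = within_tolerance(tuned_out, reference_out)
```

In the real template the two lists would be replaced by the arrays returned from the reference and tuned kernels, and the check would be `np.testing.assert_allclose` or `jnp.allclose` with the same `atol` and `rtol`.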
    1. Use `read_file` tool to read the optimized kernel code located at {optimized_kernel_path?}.
    2. Identify the parameters that can be tuned in the kernel (e.g., BLOCK_M, BLOCK_N).
    3. Create a code_template from the optimized kernel code, replacing the specific parameter values with placeholders enclosed in curly braces (for example, if the parameter is BLOCK_M, use it enclosed in curly braces as the placeholder).
    To help you build this `code_template`, you must read the reference kernel code located at {base_kernel_path?} and the testing script located at {test_file_path?} to understand the reference computation inputs, outputs, and validation logic.
This assumes we have the test file ready. It won't apply to the HITL agent. We need to be careful when we do ablation studies (when running without the test agent).
Without a test file, it will rely only on the reference code.
How does it generate a correctness test from the reference code alone, without referring to the test file?
I put a reference to the test file here just to help the autotune agent generate the test. Without it in the prompt, I think it will come up with the test on its own, just like the test agent does.
When you adapt this to HITL, you can check whether the agent still generates a good test with this removed from the prompt.
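To make the placeholder step concrete, here is a minimal sketch of how a `code_template` with curly-brace placeholders could be expanded into one candidate per configuration. The template body, `expand_template`, and the search space values are hypothetical and only illustrate the mechanism; they are not the autotune tool's actual implementation.

```python
import itertools

# Hypothetical template: placeholders in curly braces stand for tunable
# parameters, as the prompt describes. The kernel body is omitted.
CODE_TEMPLATE = """\
BLOCK_M = {BLOCK_M}
BLOCK_N = {BLOCK_N}
# ... kernel, correctness check, and timing logic would follow here ...
"""


def expand_template(template, search_space):
    """Yield (config, rendered_code) for every combination in search_space."""
    names = sorted(search_space)
    for values in itertools.product(*(search_space[name] for name in names)):
        config = dict(zip(names, values))
        yield config, template.format(**config)


# Made-up search space for illustration.
search_space = {"BLOCK_M": [128, 256], "BLOCK_N": [128, 512]}
candidates = list(expand_template(CODE_TEMPLATE, search_space))
```

One caveat with `str.format`: any literal braces in the kernel source (dict literals, f-strings) must be doubled as `{{` and `}}`, otherwise they are parsed as placeholders.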
    3. **Keep all test classes and test methods** - just fix syntax/import/structure issues
    4. Focus ONLY on making the test file valid Python code with correct imports and pytest structure
    5. **Numerical Tolerance**: Use the specified tolerances: atol={atol?}, rtol={rtol?}. If they are not specified, default to atol=1e-3, rtol=1e-3.
    6. **Input Generation**: Ensure that if the base kernel file (`{base_kernel_path?}`) defines an input generation function (e.g. `create_inputs`), it is reused/imported and used directly.
atol and rtol should have the same values across the different agents and the benchmark.
atol and rtol are set to 1e-2 for all agents, since they are state variables hardcoded in the pipeline agent. Using the same tolerance everywhere is perhaps not correct, though, because some tasks may inherently be unable to reach the same level of accuracy. For example, in minibench each task is paired with its own specified atol and rtol.
For jaxbench we are using 1e-1, and for autotune and test we are using 1e-2. So the rule is stricter for the agents than for the benchmark, right? Meaning we filter out many autotune candidates that would actually be valid during the benchmark?
I have observed that the agent sometimes realizes the tolerance is too strict for the given precision and sets a looser one. For example, in the GQA attention case, I saw the agent change the tolerance from 1e-2 to 5e-2.
I would suggest using a 1e-2 tolerance for the agent, and ideally we should use 1e-2 for the benchmark as well (I will actually try this). But given that the benchmark is sometimes not that reasonable, I have set a loose 1e-1 tolerance in the benchmark for now.
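A small sketch of the two ideas in this thread: a single default tolerance with optional per-task overrides, and the gap between the agent's 1e-2 and the benchmark's 1e-1. The `Tolerance` class, `resolve_tolerance`, the task names, and the error value are all hypothetical; only the 1e-2, 5e-2, and 1e-1 figures come from the discussion above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Tolerance:
    atol: float
    rtol: float


# Values taken from this thread: 1e-2 for the agents, 1e-1 for the benchmark.
AGENT_DEFAULT = Tolerance(atol=1e-2, rtol=1e-2)
BENCHMARK_DEFAULT = Tolerance(atol=1e-1, rtol=1e-1)


def resolve_tolerance(task, overrides, default):
    """Use the task's own tolerance when one is specified, else the default."""
    return overrides.get(task, default)


def passes(abs_err, ref_magnitude, tol):
    """Apply the allclose-style bound: err <= atol + rtol * |reference|."""
    return abs_err <= tol.atol + tol.rtol * ref_magnitude


# Hypothetical per-task override, e.g. a low-precision attention kernel.
overrides = {"gqa_attention": Tolerance(atol=5e-2, rtol=5e-2)}

# Made-up candidate: absolute error 0.08 against a reference value of 1.0.
err, mag = 0.08, 1.0
agent_tol = resolve_tolerance("matmul", overrides, AGENT_DEFAULT)
rejected_by_agent = not passes(err, mag, agent_tol)
accepted_by_benchmark = passes(err, mag, BENCHMARK_DEFAULT)
```

With these made-up numbers the candidate is rejected under the agent default but accepted under the benchmark default, which is the filtering gap raised above. Keeping the overrides in one place would let the test agent, autotune agent, and benchmark all resolve the same values for a given task.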
…ation for correctness checks and improve write instructions
No description provided.