Skip to content

Fix autotune agent from passing incorrect configurations and enforce test agent to use input from reference code#47

Merged
shangkunwang01 merged 8 commits into
mainfrom
shangkun-fix-autotune
Jun 5, 2026
Merged

Fix autotune agent from passing incorrect configurations and enforce test agent to use input from reference code#47
shangkunwang01 merged 8 commits into
mainfrom
shangkun-fix-autotune

Conversation

@shangkunwang01

Copy link
Copy Markdown
Collaborator

No description provided.

@shangkunwang01 shangkunwang01 requested a review from NinaCai June 3, 2026 02:07
@shangkunwang01 shangkunwang01 force-pushed the shangkun-fix-autotune branch from 73d4f4a to 10f7c2b Compare June 3, 2026 04:49
…t into AutotuneSummaryAgent and change write_autotune_specs_tool to force JSON format
…l tolerance into session state and test prompts
@shangkunwang01 shangkunwang01 force-pushed the shangkun-fix-autotune branch from 10f7c2b to 1044c1a Compare June 3, 2026 04:53
@shangkunwang01 shangkunwang01 force-pushed the shangkun-fix-autotune branch from 9dbcaf7 to 0a77e2e Compare June 5, 2026 02:29
@shangkunwang01 shangkunwang01 force-pushed the shangkun-fix-autotune branch from 0a77e2e to 34928cd Compare June 5, 2026 03:53
jitter = random.uniform(0.1, 2.0)
logger.info(f"Sleeping for {jitter:.2f}s (jitter) to avoid DB lock.")
time.sleep(jitter)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is in #46.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I did rebase and it showed here. I stack these two pr together. I think once the first one is merge and then rabasing, these overlap will disapper?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to rebase the code from other PRs. In fact, this will make merging more complex. If you are to revert the rebase, let's do this.

),
)

def _save_iteration_files(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code change is in the previous PR too.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as before.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove the changes from the other PR.

"code": mock_code_content,
"timeout": MOCK_EXECUTION_TIMEOUT,
"backend_type": "cpu",
"backend_type": "tpu",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it changed to tpu?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the original code has pallas kernel, mock test can fail.

COMPILE_VALIDATION_TIMEOUT = 60 * 1
MOCK_EXECUTION_TIMEOUT = 60 * 3
TEST_EXECUTION_TIMEOUT = 60 * 5
TEST_EXECUTION_POLL_INTERVAL = 20

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also in previous PR.

Where is this variable used?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

To help you build this `code_template`, you must read the reference kernel code located at {base_kernel_path?} and the testing script located at {test_file_path?} to understand the reference computation inputs, outputs, and validation logic.
All correctness check and timing logic must be defined inside this `code_template`:
- **Reference Computation**: The reference kernel will be automatically written to a file named `base_kernel.py` in the execution directory. Import functions/implementations directly from it (e.g. `from base_kernel import computation as reference_computation, get_inputs`).
- **Correctness Check**: In the main block of the template, perform a correctness check comparing the tuned kernel's output against the reference implementation's output (using `jnp.allclose` or `np.testing.assert_allclose` with appropriate tolerances, e.g., atol={atol?}, rtol={rtol?}).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the default values for atol and rtol? I think we should use the same values throughout testing, autotuning, and benchmarking.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my experiment it is set as 1e-2. autotune will look at that value as well as the test file to use the same tol.

1. Use `read_file` tool to read the optimized kernel code located at {optimized_kernel_path?}.
2. Identify the parameters that can be tuned in the kernel (e.g., BLOCK_M, BLOCK_N).
3. Create a code_template from the optimized kernel code, replacing the specific parameter values with placeholders enclosed in curly braces (for example, if the parameter is BLOCK_M, use it enclosed in curly braces as the placeholder).
To help you build this `code_template`, you must read the reference kernel code located at {base_kernel_path?} and the testing script located at {test_file_path?} to understand the reference computation inputs, outputs, and validation logic.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes we have the test file ready. It won't apply to hitl agent. We need to be careful when we do ablation studies(when running without test agent).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without test file, it will rely only on the reference code.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does it generate correctness test without referring to test file and only on the reference code?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I put a reference to the test file here is just to help autotune to generate test. Without this prompt, I think it will come up with the test just like test agent.
When you adapt this to hitl, you may try if the agent can generate good test by removing this from prompt.

3. **Keep all test classes and test methods** - just fix syntax/import/structure issues
4. Focus ONLY on making the test file valid Python code with correct imports and pytest structure
5. **Numerical Tolerance**: Use the specified tolerances: atol={atol?}, rtol={rtol?}. If they are not specified, default to atol=1e-3, rtol=1e-3.
6. **Input Generation**: Ensure that if the base kernel file (`{base_kernel_path?}`) defines an input generation function (e.g. `create_inputs`), it is reused/imported and used directly.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

atol and rtol should be the same values across different agents and benchmark.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

atol and rtol is set as 1e-2 for all agent as they are the state variable hardcoded in pipeline agent. The idea of using same tol value is perhaps not correct because some task may inherently be difficult to achieve the same level of accuracy. For example in minibench, each task is paired with its own specified atol and rtol.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For jaxbench we are using 1e-1, and for autotune and test we are using 1e-2. This is a more strict rule for agent than for benchmark, right? Meaning, we filter out many candidates from autotune but they are actually valid during benchmark?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I observed that agent sometimes realize the tol is too strict for that precision and then set a looser one. For example, in the gqa attention case, I see the agent change tol from 1e-2 to 5e-2.
I would suggest using a 1e-2 tol for the agent and ideally we should use 1e-2 for benchmark as well (I will actually try this). But give the benchmark sometimes is not that reasonable, I set a loose 1e-1 tol in benchmark now.

@shangkunwang01 shangkunwang01 changed the title refactor: simplify autotuning workflow by merging ApplyBestConfigAgent into AutotuneSummaryAgent and change write_autotune_specs_tool to force JSON format Fix autotune agent from passing incorrect configurations and enforce test agent to use input from reference code Jun 5, 2026
…ation for correctness checks and improve write instructions
@NinaCai NinaCai self-requested a review June 5, 2026 18:49

@NinaCai NinaCai left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR rebased #46.

After merging this PR, let's close PR 46. Let's make separate PRs without rebasing each other next time.

@shangkunwang01 shangkunwang01 merged commit e6ec447 into main Jun 5, 2026
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants