diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py
index 3e15846..8670c14 100644
--- a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py
+++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py
@@ -65,7 +65,10 @@ def get_gpu_specs(gpu_name: str) -> dict[str, Any] | None:
         ...     print(f"SM Count: {specs['sm_count']}")
     """
     if gpu_name in GPU_SPECS_DATABASE:
-        return GPU_SPECS_DATABASE[gpu_name].copy()
+        # GPU_SPECS_DATABASE is a MappingProxyType (read-only), so we return a
+        # mutable copy to allow callers to modify the result without affecting
+        # the database.
+        return dict(GPU_SPECS_DATABASE[gpu_name])

     logger.warning(
         "Unknown GPU: '%s'. Disable Optimization. Available GPUs: %s",
diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py
index cbc616d..0c984b4 100644
--- a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py
+++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py
@@ -27,7 +27,9 @@
 Last Updated: January 2026
 """

-GPU_SPECS_DATABASE: dict[str, dict[str, object]] = {
+from types import MappingProxyType
+
+_GPU_SPECS_DATABASE: dict[str, dict[str, object]] = {
     # NVIDIA A100 SKUs - SXM4 Variants
     "NVIDIA A100 SXM4 40GB": {
         "name": "NVIDIA A100 SXM4 40GB",
@@ -180,3 +182,6 @@
         "tdp_w": 360,
     },
 }
+
+# Make database read-only to prevent accidental modification
+GPU_SPECS_DATABASE = MappingProxyType(_GPU_SPECS_DATABASE)
diff --git a/triton_kernel_agent/prompt_manager.py b/triton_kernel_agent/prompt_manager.py
index 9534fc9..7252aea 100644
--- a/triton_kernel_agent/prompt_manager.py
+++ b/triton_kernel_agent/prompt_manager.py
@@ -88,6 +88,7 @@ def _load_templates(self):
             "test_generation": "test_generation.j2",
             "kernel_generation": "kernel_generation.j2",
             "kernel_refinement": "kernel_refinement.j2",
+            "kernel_optimization": "kernel_optimization.j2",
             "triton_guidelines": "triton_guidelines.j2",
         }

@@ -194,6 +195,64 @@ def render_kernel_refinement_prompt(
             no_cusolver=no_cusolver,
         )

+    def render_kernel_optimization_prompt(
+        self,
+        problem_description: str,
+        kernel_code: str,
+        gpu_specs: dict,
+        roofline: dict,
+        category: str,
+        summary: str,
+        reasoning: str,
+        root_cause: dict,
+        recommended_fix: dict,
+        pytorch_baseline_ms: float | None = None,
+        current_best_ms: float | None = None,
+        error_feedback: str | None = None,
+    ) -> str:
+        """
+        Render the kernel optimization prompt.
+
+        Args:
+            problem_description: Description of the problem
+            kernel_code: Current kernel implementation
+            gpu_specs: GPU hardware specifications dict
+            roofline: Roofline analysis result dict with keys:
+                bottleneck, compute_sol_pct, memory_sol_pct, efficiency_pct,
+                headroom_pct, at_roofline, uses_tensor_cores, warnings
+            category: Bottleneck category ("memory", "compute", "underutilized")
+            summary: One-line bottleneck summary
+            reasoning: Explanation citing metrics
+            root_cause: Single root cause dict {"cause": "...", "evidence": [...]}
+            recommended_fix: Single fix dict {"fix": "...", "rationale": "..."}
+            pytorch_baseline_ms: PyTorch Eager baseline time in ms
+            current_best_ms: Current best kernel time in ms (for iterative optimization)
+            error_feedback: Error message from previous failed attempt
+
+        Returns:
+            Rendered prompt string
+        """
+        template = self.templates["kernel_optimization"]
+
+        bottleneck = {
+            "category": category,
+            "summary": summary,
+            "reasoning": reasoning,
+            "root_cause": root_cause,
+            "recommended_fix": recommended_fix,
+        }
+
+        return template.render(
+            problem_description=problem_description,
+            kernel_code=kernel_code,
+            gpu_specs=gpu_specs,
+            roofline=roofline,
+            bottleneck=bottleneck,
+            pytorch_baseline_ms=pytorch_baseline_ms,
+            current_best_ms=current_best_ms,
+            error_feedback=error_feedback,
+        )
+
     def render_triton_guidelines(self) -> str:
         """
         Render the Triton guidelines.
diff --git a/triton_kernel_agent/templates/kernel_optimization.j2 b/triton_kernel_agent/templates/kernel_optimization.j2
new file mode 100644
index 0000000..92fe699
--- /dev/null
+++ b/triton_kernel_agent/templates/kernel_optimization.j2
@@ -0,0 +1,101 @@
+{#
+Copyright (c) Meta Platforms, Inc. and affiliates.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+#}
+
+TASK: Optimize the following Triton kernel based on hardware profiling analysis to achieve better performance.
+
+{% if gpu_specs %}
+## TARGET GPU
+- GPU: {{ gpu_specs.name }}
+- Architecture: {{ gpu_specs.architecture }}
+- Peak Memory Bandwidth: {{ gpu_specs.peak_memory_bw_gbps }} GB/s
+- Peak FP32: {{ gpu_specs.peak_fp32_tflops }} TFLOPS
+- Peak FP16: {{ gpu_specs.peak_fp16_tflops }} TFLOPS
+- Peak BF16: {{ gpu_specs.peak_bf16_tflops }} TFLOPS
+- SM Count: {{ gpu_specs.sm_count }}
+- Max Threads per SM: {{ gpu_specs.max_threads_per_sm }}
+- L1 Cache per SM: {{ gpu_specs.l1_cache_kb }} KB
+- L2 Cache: {{ gpu_specs.l2_cache_mb }} MB
+- Memory: {{ gpu_specs.memory_gb }} GB {{ gpu_specs.memory_type }}
+{% endif %}
+
+## PROBLEM DESCRIPTION
+{{ problem_description }}
+{% if pytorch_baseline_ms %}
+PyTorch Eager baseline: {{ "%.4f"|format(pytorch_baseline_ms) }} ms
+{% endif %}
+
+## CURRENT KERNEL
+```python
+{{ kernel_code }}
+```
+
+{% if roofline %}
+## ROOFLINE ANALYSIS
+- Primary Bottleneck: {{ roofline.bottleneck | upper }}
+- Compute SOL: {{ "%.1f"|format(roofline.compute_sol_pct) }}%
+- Memory SOL: {{ "%.1f"|format(roofline.memory_sol_pct) }}%
+- Efficiency: {{ "%.1f"|format(roofline.efficiency_pct) }}% (headroom: {{ "%.1f"|format(roofline.headroom_pct) }}%)
+- At Roofline: {{ "Yes" if roofline.at_roofline else "No" }}
+- Tensor Cores: {{ "Active" if roofline.uses_tensor_cores else "Inactive" }}
+{%- if roofline.warnings %}
+- Warnings: {{ roofline.warnings | join("; ") }}
+{%- endif %}
+{% endif %}
+
+## BOTTLENECK ANALYSIS
+### Category: {{ bottleneck.category | upper }}
+{{ bottleneck.summary }}
+
+**Reasoning:** {{ bottleneck.reasoning }}
+
+**Root Cause:** {{ bottleneck.root_cause.cause }}
+{%- if bottleneck.root_cause.evidence %}
+  Evidence: {% for e in bottleneck.root_cause.evidence %}{{ e.metric }}={{ e.value }}{% if not loop.last %}, {% endif %}{% endfor %}
+{%- endif %}
+
+**Recommended Fix:** {{ bottleneck.recommended_fix.fix }}
+{%- if bottleneck.recommended_fix.rationale %} ({{ bottleneck.recommended_fix.rationale }}){% endif %}
+
+{% if error_feedback %}
+## PREVIOUS ATTEMPT FAILED
+{{ error_feedback }}
+{% endif %}
+
+## PERFORMANCE TARGET
+{% if pytorch_baseline_ms %}
+- PyTorch Eager baseline: {{ "%.4f"|format(pytorch_baseline_ms) }} ms
+{% endif %}
+{% if current_best_ms %}
+- Current best kernel: {{ "%.4f"|format(current_best_ms) }} ms
+- Target: Improve by at least 10% (< {{ "%.4f"|format(current_best_ms * 0.9) }} ms)
+{% elif pytorch_baseline_ms %}
+- Target: Improve by at least 10% over Eager (< {{ "%.4f"|format(pytorch_baseline_ms * 0.9) }} ms)
+{% endif %}
+- Maintain numerical correctness (atol=1e-4 or rtol=1e-4)
+- Preserve public API (same inputs/outputs, shapes, dtypes)
+
+## REQUIREMENTS
+1. Apply the recommended fix above to address the {{ bottleneck.category | upper }} bottleneck
+2. The implementation must be a complete, valid Python file
+3. Main function must be named 'kernel_function' wrapping the Triton kernel
+4. Keep the wrapper free of PyTorch compute primitives
+
+## OUTPUT FORMAT
+Output complete optimized kernel code in ```python blocks.
+Include only: imports, Triton kernel (@triton.jit), wrapper function (kernel_function).
+No testing code, benchmarks, or explanatory comments.
+
+Generate the complete optimized kernel implementation:
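
Reviewer note: below is a minimal sketch of how the pieces in this patch compose end to end. It is illustrative only; the `PromptManager` class name, its no-argument constructor, and the import paths (derived from the file paths in this diff) are assumptions not confirmed by the patch, and all metric values are made up.

```python
from types import MappingProxyType

# Import paths assumed to mirror the file paths in this diff.
from kernel_perf_agent.kernel_opt.diagnose_prompt.gpu_specs import get_gpu_specs
from kernel_perf_agent.kernel_opt.diagnose_prompt.gpu_specs_database import GPU_SPECS_DATABASE
from triton_kernel_agent.prompt_manager import PromptManager  # class name assumed

# After this patch, the database is a read-only view; direct writes raise TypeError.
assert isinstance(GPU_SPECS_DATABASE, MappingProxyType)

# get_gpu_specs() now returns a plain dict copy, so callers may mutate it freely.
specs = get_gpu_specs("NVIDIA A100 SXM4 40GB")
assert specs is not None
specs["notes"] = "scratch field"  # does not touch the database

pm = PromptManager()  # constructor signature assumed
prompt = pm.render_kernel_optimization_prompt(
    problem_description="Row-wise softmax over a (4096, 4096) fp16 tensor",
    kernel_code="# current Triton kernel source here",
    gpu_specs=specs,
    # Roofline keys taken from the docstring above; values are invented.
    roofline={
        "bottleneck": "memory",
        "compute_sol_pct": 12.5,
        "memory_sol_pct": 78.0,
        "efficiency_pct": 61.0,
        "headroom_pct": 39.0,
        "at_roofline": False,
        "uses_tensor_cores": False,
        "warnings": [],
    },
    category="memory",
    summary="Kernel is memory-bound with low achieved bandwidth",
    reasoning="Memory SOL (78.0%) dominates compute SOL (12.5%)",
    root_cause={
        "cause": "Uncoalesced global loads",
        "evidence": [{"metric": "memory_sol_pct", "value": 78.0}],
    },
    recommended_fix={
        "fix": "Load contiguous BLOCK_SIZE tiles per program",
        "rationale": "Coalesced access raises achieved bandwidth",
    },
    pytorch_baseline_ms=1.2345,
    current_best_ms=0.9876,
)
print(prompt)
```

One design note: `dict(GPU_SPECS_DATABASE[gpu_name])` is a shallow copy, which suffices here because the per-GPU spec values are scalars; only the top-level mapping needs the `MappingProxyType` guard.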