-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfix_args2.py
More file actions
43 lines (36 loc) · 2.43 KB
/
fix_args2.py
File metadata and controls
43 lines (36 loc) · 2.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import glob
import re
# 1. Fix OpenCL and HIP kernels to accept uint stride, uint mask, uint iterations
#
# The kernels are rewritten from taking a `PushConstants pc` parameter to
# taking three scalar arguments. The text transformation lives in
# `transform_kernel_source` (pure, no I/O); `main` applies it to the files.

# Glob patterns, relative to the current working directory, of the kernel
# sources to rewrite. The script must therefore be run from the directory
# that contains `kernels/` and `hip_kernels/`.
KERNEL_GLOBS = (
    "kernels/opencl/cachebw*.cl",
    "kernels/opencl/cache_latency*.cl",
    "hip_kernels/cachebw*.hip",
    "hip_kernels/cache_latency*.hip",
    "hip_kernels/cache_bw_robust*.hip",
    "kernels/opencl/cache_bw_robust*.cl",
)

# Scalar parameter list that replaces the struct parameter.
_SCALAR_ARGS = "uint stride, uint mask, uint iterations"

# Condition text used both as the inserted guard and as the "already patched"
# marker that keeps the cache_latency rewrite from being applied twice.
_SENTINEL_TEST = "stride == 0xFFFFFFFF"

# OpenCL: `typedef struct { ... } PushConstants;` plus trailing whitespace.
_OPENCL_STRUCT_RE = re.compile(
    r'typedef struct \{.*?\} PushConstants;\s*', re.DOTALL)

# ROCm HIP: `struct PushConstants { ... uint32_t padding;\n};` plus trailing
# whitespace.
_HIP_STRUCT_RE = re.compile(
    r'struct PushConstants \{.*?uint32_t padding;\n\};\s*', re.DOTALL)

# A `for (...; i < <literal>; ++i)` header, optionally preceded by the tail of
# an `unroll` line, which is captured so it can be re-emitted unchanged.
_COUNTED_LOOP_RE = re.compile(
    r'(unroll\n\s*)?for\s*\([^;]+;\s*i\s*<\s*\d+;\s*\+\+i\)')

# Loop headers whose presence enables the loop-bound rewrite.
_LOOP_MARKERS = (
    "uint i = 0; i < 1024",
    "int i = 0; i < 1024",
    "int i = 0; i < 1000000",
)

# (anchor, replacement) pairs tried in order for cache_latency kernels; only
# the first anchor found in the source is patched.
_SENTINEL_ANCHORS = (
    ("if (val > 0)",
     "if (stride == 0xFFFFFFFF) { data[1] = mask; }\n if (val > 0)"),
    ("data[0] = index;",
     "if (stride == 0xFFFFFFFF) { data[1] = mask; }\n data[0] = index;"),
    ("buffer[0] = val;",
     "if (stride == 0xFFFFFFFF) { buffer[1] = mask; }\n buffer[0] = val;"),
)

# Guards in cache_bw_robust kernels that are swapped for the stride test.
_ROBUST_GUARDS = (
    "if (final_sum.x < -1e30f)",
    "if (final_x < -1e30f)",
)


def transform_kernel_source(content: str, filepath: str) -> str:
    """Return *content* rewritten to take scalar kernel arguments.

    The transformation is purely textual:

    * the ``PushConstants`` struct definition (OpenCL ``typedef`` form and
      HIP ``struct`` form) is removed;
    * a ``PushConstants pc`` parameter, or a ``uint32_t`` triple, becomes
      ``uint stride, uint mask, uint iterations``;
    * ``pc.stride`` / ``pc.mask`` / ``pc.iterations`` lose the ``pc.`` prefix;
    * depending on which family name appears in *filepath*, extra statements
      referencing ``stride``/``mask``/``iterations`` are introduced (the
      original script labels this "Force variables usage for DCE",
      presumably dead-code elimination -- TODO confirm).

    Args:
        content: Full text of one kernel source file.
        filepath: Path of that file; only substring tests for
            ``"cache_latency"`` and ``"cache_bw_robust"`` are made on it.

    Returns:
        The rewritten source text (identical to *content* if nothing
        matched).
    """
    # OpenCL
    content = _OPENCL_STRUCT_RE.sub('', content)
    content = content.replace("PushConstants pc", _SCALAR_ARGS)

    # ROCm HIP
    content = _HIP_STRUCT_RE.sub('', content)
    content = content.replace(
        "uint32_t stride, uint32_t mask, uint32_t iterations", _SCALAR_ARGS)

    for field in ("stride", "mask", "iterations"):
        content = content.replace("pc." + field, field)

    # Force variables usage for DCE
    if "cache_latency" in filepath:
        if any(marker in content for marker in _LOOP_MARKERS):
            # NOTE: this rewrites every matching counted loop in the file,
            # not only the ones listed in _LOOP_MARKERS.
            content = _COUNTED_LOOP_RE.sub(
                r'\1for (uint i = 0; i < iterations; i++)', content)
        if _SENTINEL_TEST not in content:
            for anchor, guarded in _SENTINEL_ANCHORS:
                if anchor in content:
                    content = content.replace(anchor, guarded)
                    break
    elif "cache_bw_robust" in filepath:
        for guard in _ROBUST_GUARDS:
            content = content.replace(guard, "if (stride == 0xFFFFFFFF)")

    return content


def main() -> int:
    """Rewrite every kernel file matched by ``KERNEL_GLOBS`` in place.

    Files whose text is not changed by the transformation are left
    untouched on disk.

    Returns:
        The number of files that were actually rewritten.
    """
    import sys  # local import: only needed for the warning below

    paths = []
    for pattern in KERNEL_GLOBS:
        paths.extend(glob.glob(pattern))

    if not paths:
        # Previously the script reported success even when nothing matched,
        # which hides the common mistake of running it from the wrong
        # directory.
        print("warning: no kernel files matched; run this script from the "
              "directory containing kernels/ and hip_kernels/",
              file=sys.stderr)
        return 0

    rewritten = 0
    for filepath in paths:
        with open(filepath, 'r', encoding='utf-8') as f:
            original = f.read()

        updated = transform_kernel_source(original, filepath)

        # Skip the write when nothing changed so unchanged files keep their
        # timestamps and line endings.
        if updated != original:
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(updated)
            rewritten += 1

    print("Applied CacheBench arg splits and DCE fixes to kernels")
    return rewritten


if __name__ == "__main__":
    main()