import time
import torch
from torch.cuda import nvtx
import cutlass
import cutlass.cute as cute
# Shapes of the nine bf16 operands passed to every benchmarked function.
# Both the compile-time fake tensors and the runtime torch tensors are built
# from this single tuple, so the compiled signature cannot drift away from the
# tensors actually passed at call time.
OPERAND_SHAPES = (
    (256, 4096),
    (128, 2048),
    (2048, 256),
    (256, 128),
    (128, 256),
    (256, 1024),
    (512, 2048),
    (128, 4096),
    (2048, 128),
)

# Compile-time placeholders handed to cute.compile: compact bf16 tensors in
# global memory (AddressSpace.gmem) with assumed_align=16.  stride_order=(1, 0)
# is presumably row-major to match the contiguous torch tensors below --
# confirm against the CuTe DSL documentation.
x_fakes = [
    cute.runtime.make_fake_compact_tensor(
        cute.BFloat16,
        shape,
        stride_order=(1, 0),
        memspace=cute.AddressSpace.gmem,
        assumed_align=16,
    )
    for shape in OPERAND_SHAPES
]

# Real CUDA tensors passed at call time.  They are left uninitialized
# (torch.empty) because the benchmarked functions never read them.
xs = [
    torch.empty(shape, dtype=torch.bfloat16, device='cuda')
    for shape in OPERAND_SHAPES
]

# Keep the individual module-level names that the rest of the script uses.
x_fake0, x_fake1, x_fake2, x_fake3, x_fake4, x_fake5, x_fake6, x_fake7, x_fake8 = x_fakes
x0, x1, x2, x3, x4, x5, x6, x7, x8 = xs
class NoLaunchKernel:
    """JIT host function that takes the nine operands and returns immediately.

    No kernel is launched and no argument is read, so timing the compiled
    form of this function isolates the overhead of invoking a compiled
    function.
    """

    # Body intentionally empty: this is the "call overhead only" variant.
    @cute.jit
    def __call__(self, x0, x1, x2, x3, x4, x5, x6, x7, x8):
        return
class LaunchKernel:
    """JIT host function that launches one empty ``@cute.kernel``.

    Compared with ``NoLaunchKernel``, this adds a single launch of a kernel
    that returns immediately. Comparing the two timings therefore
    approximates the extra cost of the launch itself.
    """

    @cute.jit
    def __call__(self, x0, x1, x2, x3, x4, x5, x6, x7, x8):
        # Launch configuration: grid=(1,), block=(1,) -- one block of one thread.
        self.kernel(x0, x1, x2, x3, x4, x5, x6, x7, x8).launch(
            grid=(1,),
            block=(1,),
        )

    # Device kernel: receives all nine operands but does no work.
    @cute.kernel
    def kernel(self, x0, x1, x2, x3, x4, x5, x6, x7, x8):
        return
# Benchmark sizing used by bench():
#   WARMUP -- untimed iterations run before measuring,
#   TRIALS -- number of timed trials (used for the mean / std),
#   ITERS  -- iterations per timed trial (the per-call time divides by this).
WARMUP, TRIALS, ITERS = 100, 10, 400
def _run_iters(compiled, with_alloc, iters):
    """Call ``compiled`` on the nine benchmark tensors ``iters`` times.

    If ``with_alloc`` is true, each call is preceded by a 1-byte uint8
    ``torch.empty`` on CUDA whose result is discarded immediately.  The
    warm-up and the timed loops both use this helper, so they always execute
    exactly the same per-iteration work.
    """
    for _ in range(iters):
        if with_alloc:
            torch.empty(1, dtype=torch.uint8, device='cuda')
        compiled(x0, x1, x2, x3, x4, x5, x6, x7, x8)


def _mean_std(samples):
    """Return ``(mean, sample standard deviation)`` of a non-empty sequence.

    The variance uses the n-1 (Bessel-corrected) divisor.  With a single
    sample the standard deviation is reported as 0.0 instead of raising
    ZeroDivisionError.
    """
    n = len(samples)
    mean = sum(samples) / n
    if n < 2:
        return mean, 0.0
    var = sum((x - mean) ** 2 for x in samples) / (n - 1)
    return mean, var ** 0.5


def bench(kernel_cls, with_alloc):
    """Benchmark one configuration and print a ``RESULT`` line.

    Args:
        kernel_cls: A class whose instance is compiled with ``cute.compile``
            (with ``--enable-tvm-ffi``) against the fake tensors.  ``None``
            benchmarks a plain Python no-op instead.
        with_alloc: If true, each iteration also performs a 1-byte CUDA
            ``torch.empty`` before the call.  The tensor is never passed to
            the callee and is dropped right away.

    Prints ``RESULT <name>[+Alloc] <mean_us> <std_us>``.  Here ``mean_us`` is
    the mean wall-clock time per iteration in microseconds over ``TRIALS``
    trials of ``ITERS`` iterations each, including a final
    ``torch.cuda.synchronize()`` per trial.

    NOTE(review): because the allocated tensor is discarded immediately, the
    "+Alloc" variants presumably also pay for returning it to PyTorch's
    caching allocator on every iteration.
    """
    if kernel_cls is None:
        # Baseline: a Python no-op, so only loop (+ allocation) cost is timed.
        def compiled(*args):
            return None
    else:
        compiled = cute.compile(
            kernel_cls(),
            x_fake0, x_fake1, x_fake2, x_fake3, x_fake4,
            x_fake5, x_fake6, x_fake7, x_fake8,
            options="--enable-tvm-ffi",
        )

    # Warmup (include the alloc path so the caching allocator is hot).
    _run_iters(compiled, with_alloc, WARMUP)
    torch.cuda.synchronize()

    trial_us = []
    cudart = torch.cuda.cudart()
    cudart.cudaProfilerStart()
    try:
        for trial in range(TRIALS):
            t0 = time.perf_counter_ns()
            nvtx.range_push(f"Trial-{trial}")
            try:
                _run_iters(compiled, with_alloc, ITERS)
                # Wait for any launched GPU work so it is inside the timing.
                torch.cuda.synchronize()
            finally:
                # Always close the NVTX range, even if a call raises.
                nvtx.range_pop()
            t1 = time.perf_counter_ns()
            trial_us.append((t1 - t0) / ITERS / 1000)
    finally:
        # Never leave the CUDA profiler capture region open on error.
        cudart.cudaProfilerStop()

    mean, std = _mean_std(trial_us)
    suffix = '+Alloc' if with_alloc else ''
    name = "Nothing" if kernel_cls is None else kernel_cls.__name__
    print(f'RESULT {name}{suffix} {mean} {std}', flush=True)
# Baselines.  "Nothing" times only the Python loop plus a no-op call, and
# "Nothing+Alloc" adds the per-iteration torch.empty.  The no-alloc baseline
# lets the loop overhead be subtracted.  Otherwise that overhead is counted
# twice when "Nothing+Alloc" is added to a kernel-only timing.
bench(None, False)
bench(None, True)
# Compiled functions, first without and then with the per-iteration allocation.
bench(NoLaunchKernel, False)
bench(LaunchKernel, False)
bench(NoLaunchKernel, True)
bench(LaunchKernel, True)
"Nothing+Alloc" does not call the cute.jit function; it only performs the allocation. So I assume the allocation alone takes 1.64 us.
Adding this time to the "NoLaunchKernel" and "LaunchKernel" results should therefore give roughly the "NoLaunchKernel+Alloc" and "LaunchKernel+Alloc" results, especially because I do not even pass the allocated empty tensor to the compiled function. However, the measured durations are much longer, which does not make sense to me. They should be about 1.64 + 1.64 = 3.28 us and 1.64 + 3.37 = 5.01 us. They could be even less, because the allocated tensor is unused, so the two operations should be able to overlap instead of running sequentially. (These sums are, if anything, overestimates: the Python loop overhead is included both in "Nothing+Alloc" and in the kernel-only timings.)
It is worth mentioning that I also tried this on an AGX Thor on dlcluster, and there the results look reasonable (that output is not included here).
Which component has the problem?
CuTe DSL
Bug Report
Describe the bug
If I allocate a tensor with torch.empty, calling a compiled cute.jit function becomes slower. This happens even though the function does nothing and does not use the allocated tensor.
Steps/Code to reproduce bug
Expected behavior
On GB200, this script outputs the following (the output itself is not included here):
"Nothing+Alloc" does not call the cute.jit function; it only performs the allocation. So I assume the allocation alone takes 1.64 us.
Adding this time to the "NoLaunchKernel" and "LaunchKernel" results should therefore give roughly the "NoLaunchKernel+Alloc" and "LaunchKernel+Alloc" results, especially because I do not even pass the allocated empty tensor to the compiled function. However, the measured durations are much longer, which does not make sense to me. They should be about 1.64 + 1.64 = 3.28 us and 1.64 + 3.37 = 5.01 us. They could be even less, because the allocated tensor is unused, so the two operations should be able to overlap instead of running sequentially. (These sums are, if anything, overestimates: the Python loop overhead is included both in "Nothing+Alloc" and in the kernel-only timings.)
It is worth mentioning that I also tried this on an AGX Thor on dlcluster, and there the results look reasonable (that output is not included here).
Environment details (please complete the following information):
On NVIDIA's ptyche cluster
Using a Docker container with the image
gitlab-master.nvidia.com:5005/dl/transformerengine/transformerengine:main-pytorch-py3-devel-arm64
Additional context
Add any other context about the problem here.