[BUG] Allocating a tensor makes calling cute.jit a lot slower even if it doesn't use the allocated tensor #3275

@kainzhong

Which component has the problem?

CuTe DSL

Bug Report

Describe the bug
Allocating a tensor with torch.empty makes each call to a compiled cute.jit function slower, even when that function does nothing and never receives the allocated tensor.

Steps/Code to reproduce bug

import time
import torch
from torch.cuda import nvtx

import cutlass
import cutlass.cute as cute

x_fake0 = cute.runtime.make_fake_compact_tensor(cute.BFloat16, (256, 4096), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16)
x_fake1 = cute.runtime.make_fake_compact_tensor(cute.BFloat16, (128, 2048), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16)
x_fake2 = cute.runtime.make_fake_compact_tensor(cute.BFloat16, (2048, 256), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16)
x_fake3 = cute.runtime.make_fake_compact_tensor(cute.BFloat16, (256, 128), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16)
x_fake4 = cute.runtime.make_fake_compact_tensor(cute.BFloat16, (128, 256), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16)
x_fake5 = cute.runtime.make_fake_compact_tensor(cute.BFloat16, (256, 1024), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16)
x_fake6 = cute.runtime.make_fake_compact_tensor(cute.BFloat16, (512, 2048), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16)
x_fake7 = cute.runtime.make_fake_compact_tensor(cute.BFloat16, (128, 4096), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16)
x_fake8 = cute.runtime.make_fake_compact_tensor(cute.BFloat16, (2048, 128), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16)
x0 = torch.empty((256, 4096), dtype=torch.bfloat16, device='cuda')
x1 = torch.empty((128, 2048), dtype=torch.bfloat16, device='cuda')
x2 = torch.empty((2048, 256), dtype=torch.bfloat16, device='cuda')
x3 = torch.empty((256, 128), dtype=torch.bfloat16, device='cuda')
x4 = torch.empty((128, 256), dtype=torch.bfloat16, device='cuda')
x5 = torch.empty((256, 1024), dtype=torch.bfloat16, device='cuda')
x6 = torch.empty((512, 2048), dtype=torch.bfloat16, device='cuda')
x7 = torch.empty((128, 4096), dtype=torch.bfloat16, device='cuda')
x8 = torch.empty((2048, 128), dtype=torch.bfloat16, device='cuda')

class NoLaunchKernel:

    @cute.jit
    def __call__(self, x0, x1, x2, x3, x4, x5, x6, x7, x8):
        return

class LaunchKernel:

    @cute.jit
    def __call__(self, x0, x1, x2, x3, x4, x5, x6, x7, x8):
        self.kernel(x0, x1, x2, x3, x4, x5, x6, x7, x8).launch(
            grid=(1,),
            block=(1,),
        )

    @cute.kernel
    def kernel(self, x0, x1, x2, x3, x4, x5, x6, x7, x8):
        return

WARMUP = 100
TRIALS = 10
ITERS = 400

def bench(kernel_cls, with_alloc):
    if kernel_cls is not None:
        compiled = cute.compile(kernel_cls(), x_fake0, x_fake1, x_fake2, x_fake3, x_fake4, x_fake5, x_fake6, x_fake7, x_fake8, options="--enable-tvm-ffi")
    else:
        compiled = lambda *args : None

    # Warmup (include the alloc path so caching allocator is hot)
    for i in range(WARMUP):
        if with_alloc:
            torch.empty(1, dtype=torch.uint8, device='cuda')
        compiled(x0, x1, x2, x3, x4, x5, x6, x7, x8)
    torch.cuda.synchronize()

    trial_us = []
    torch.cuda.cudart().cudaProfilerStart()
    for trial in range(TRIALS):
        t0 = time.perf_counter_ns()
        nvtx.range_push(f"Trial-{trial}")
        for i in range(ITERS):
            if with_alloc:
                torch.empty(1, dtype=torch.uint8, device='cuda')
            compiled(x0, x1, x2, x3, x4, x5, x6, x7, x8)
        torch.cuda.synchronize()
        nvtx.range_pop()
        t1 = time.perf_counter_ns()
        trial_us.append((t1 - t0) / ITERS / 1000)
    torch.cuda.cudart().cudaProfilerStop()

    mean = sum(trial_us) / len(trial_us)
    var = sum((x - mean) ** 2 for x in trial_us) / (len(trial_us) - 1)
    std = var ** 0.5
    suffix = '+Alloc' if with_alloc else ''
    name = "Nothing" if kernel_cls is None else kernel_cls.__name__
    print(f'RESULT {name}{suffix} {mean} {std}', flush=True)

bench(None, True)
bench(NoLaunchKernel, False)
bench(LaunchKernel,   False)
bench(NoLaunchKernel, True)
bench(LaunchKernel,   True)

Expected behavior
On GB200 this script outputs:

RESULT Nothing+Alloc 1.6487247500000002 0.11324282288056196
RESULT NoLaunchKernel 1.6442525 0.01118622689838812
RESULT LaunchKernel 3.37709725 0.048031188586312216
RESULT NoLaunchKernel+Alloc 5.747288 0.38881368030559027
RESULT LaunchKernel+Alloc 8.0602305 0.2851624457455979

"Nothing+Alloc" does not call cute.jit and only performs the allocation, so I assume the allocation alone takes about 1.65 us.
I would then expect "NoLaunchKernel+Alloc" and "LaunchKernel+Alloc" to cost roughly that much more than "NoLaunchKernel" and "LaunchKernel", especially since I don't even pass the allocated tensor to the compiled function. That would be about 1.65 + 1.64 = 3.29 us and 1.65 + 3.38 = 5.03 us, or even less, because the allocation and the call are independent and should be able to overlap instead of running sequentially. However, the measured times (5.75 us and 8.06 us) are much longer, which doesn't make sense to me.

It is worth mentioning that I also tried this on an AGX Thor on dlcluster, and that yields a reasonable result:

RESULT Nothing+Alloc 1.8182492499999998 0.029221814650331548
RESULT NoLaunchKernel 1.95853925 0.01587520056959913
RESULT LaunchKernel 4.24348425 0.0610531419708591
RESULT NoLaunchKernel+Alloc 3.7950477499999997 0.015231044660768708
RESULT LaunchKernel+Alloc 6.26364875 0.0763004097780063
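
To make the comparison concrete, here is a small sketch that subtracts the "additive" expectation (allocation alone plus kernel call alone) from the measured "+Alloc" time on each platform. The numbers are the mean values copied from the RESULT lines above; the helper name `excess_over_additive` is only illustrative, and the additive model itself is the assumption stated in this report, not something the library guarantees.

```python
# Mean per-iteration times in microseconds, copied from the RESULT lines above.
results = {
    "GB200": {
        "Nothing+Alloc": 1.6487247500000002,
        "NoLaunchKernel": 1.6442525,
        "LaunchKernel": 3.37709725,
        "NoLaunchKernel+Alloc": 5.747288,
        "LaunchKernel+Alloc": 8.0602305,
    },
    "AGX Thor": {
        "Nothing+Alloc": 1.8182492499999998,
        "NoLaunchKernel": 1.95853925,
        "LaunchKernel": 4.24348425,
        "NoLaunchKernel+Alloc": 3.7950477499999997,
        "LaunchKernel+Alloc": 6.26364875,
    },
}


def excess_over_additive(times):
    """Return measured (+Alloc) time minus (alloc alone + call alone), in us.

    Hypothetical helper: it assumes the two costs would simply add up if
    the allocation and the compiled call did not interfere with each other.
    """
    alloc = times["Nothing+Alloc"]
    return {
        name: times[f"{name}+Alloc"] - (alloc + times[name])
        for name in ("NoLaunchKernel", "LaunchKernel")
    }


excess = {platform: excess_over_additive(t) for platform, t in results.items()}

for platform, per_kernel in excess.items():
    for name, extra in per_kernel.items():
        print(f"{platform:9s} {name:15s} excess = {extra:+.2f} us")
```

With these inputs the unexplained overhead on GB200 is roughly 2.45 us (NoLaunchKernel) and 3.03 us (LaunchKernel), while on AGX Thor it is roughly 0.02 us and 0.20 us, which is why the Thor numbers look additive and the GB200 numbers do not.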

Environment details (please complete the following information):
On NVIDIA's ptyche cluster
Using Docker container with image gitlab-master.nvidia.com:5005/dl/transformerengine/transformerengine:main-pytorch-py3-devel-arm64

