Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/deploy-wheel-to-pypi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
if: github.repository == 'NVIDIA/tilus'
runs-on: linux-amd64-gpu-l4-latest-1
container:
image: nvidia/cuda:12.6.2-devel-ubuntu22.04
image: nvidia/cuda:13.0.0-devel-ubuntu22.04
options: --gpus all
outputs:
wheel-path: ${{ steps.setup-and-install.outputs.wheel-path }}
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
if: github.repository == 'NVIDIA/tilus' && needs.check-changes.outputs.should_run_tests == 'true'
runs-on: linux-amd64-gpu-l4-latest-1
container:
image: nvidia/cuda:12.6.2-devel-ubuntu22.04
image: nvidia/cuda:13.0.0-devel-ubuntu22.04
options: --gpus all
steps:
- name: Checkout Repository
Expand Down Expand Up @@ -106,7 +106,7 @@ jobs:
- linux-amd64-gpu-l4-latest-1
runs-on: ${{ matrix.runner }}
container:
image: nvidia/cuda:12.6.2-devel-ubuntu22.04
image: nvidia/cuda:13.0.0-devel-ubuntu22.04
options: --gpus all
steps:
- name: Checkout Repository
Expand All @@ -128,7 +128,7 @@ jobs:
if: github.repository == 'NVIDIA/tilus' && needs.check-changes.outputs.should_run_examples == 'true'
runs-on: linux-amd64-gpu-l4-latest-1
container:
image: nvidia/cuda:12.6.2-devel-ubuntu22.04
image: nvidia/cuda:13.0.0-devel-ubuntu22.04
options: --gpus all
steps:
- name: Checkout Repository
Expand Down
32 changes: 28 additions & 4 deletions python/tilus/lang/instantiated_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,33 @@ def __call__(self, *args, **kwargs):

return ret

def compile(self, *args: Any, **kwargs: Any) -> JitInstance:
"""Compile the script for the given arguments without executing it.

This transpiles every schedule in the autotune space into a Program and builds each Program
to a shared library, but does not run the kernel and does not benchmark/persist a dispatch
choice. Useful in CI to validate that a kernel compiles for a target architecture (e.g.,
sm100a) on a machine that does not support running it. Combine with
:func:`tilus.target.scope` to override the build target.

Parameters
----------
args:
The positional arguments to ``__call__``.

kwargs:
The keyword arguments to ``__call__``.

Returns
-------
jit_instance: JitInstance
The JIT instance for the script with the given arguments. The compiled programs are
available as ``jit_instance.valid_programs`` and ``jit_instance.compiled_programs``.
"""
jit_instance = self._jit_instance_for(*args, **kwargs)
jit_instance.programs()
return jit_instance

def _jit_instance_for(self, *args: Any, **kwargs: Any) -> JitInstance:
if kwargs or self.with_default:
# we allow the user to pass the keyword arguments to the script instance, or use the default values
Expand All @@ -794,10 +821,7 @@ def _jit_instance_for(self, *args: Any, **kwargs: Any) -> JitInstance:
"The number of arguments should be {}, but got {}.".format(len(self.params.param_names), len(args))
)

# extract the JIT key and the tuning key
keys = extract_keys(args, self.const_params, self.tuning_params)

jit_key, tuning_key = keys
jit_key, _ = extract_keys(args, self.const_params, self.tuning_params)
jit_instance: Optional[JitInstance] = self.jit_instances.get(jit_key, None)
if jit_instance is None:
jit_instance = JitInstance(self.script_cls, self.params, self.build_options, self.schedules, jit_key)
Expand Down
8 changes: 4 additions & 4 deletions python/tilus/lang/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tilus.lang.modules.cuda import cuda

if TYPE_CHECKING:
from tilus.lang.instantiated_script import InstantiatedScript, JitInstance
from tilus.lang.instantiated_script import InstantiatedScript, JitInstance # noqa: F401

Int: TypeAlias = int | Expr

Expand Down Expand Up @@ -70,9 +70,9 @@ def __init__(self) -> None:
def __call__(self, *args, **kwargs):
raise RuntimeError("This method should never be called.")

def jit_instance_for(self, *args: object, **kwargs: object) -> JitInstance:
def compile(self, *args: object, **kwargs: object) -> JitInstance:
"""
Instantiate the script program with the specified arguments and keyword arguments.
Transpile and build the script for the given arguments without executing it.

Parameters
----------
Expand All @@ -86,7 +86,7 @@ def jit_instance_for(self, *args: object, **kwargs: object) -> JitInstance:
ret: JitInstance
The JIT instance for the script with given arguments.
"""
raise RuntimeError("This method should never be called. See InstantiatedScript.jit_instance instead.")
raise RuntimeError("This method should never be called. See InstantiatedScript.compile instead.")

# the following properties should only be access in the __call__ function
@property
Expand Down
30 changes: 29 additions & 1 deletion python/tilus/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple
from typing import Iterator, Optional, Sequence, Tuple


@dataclass(frozen=True)
Expand Down Expand Up @@ -293,6 +294,33 @@ def set_current_target(target: Target) -> None:
_target = target


@contextmanager
def scope(target: Target) -> Iterator[Target]:
"""Temporarily set the current compilation target.

Useful to compile a kernel for a specific architecture (e.g., sm100a) on a machine that does not
support running it. Restores the previous target on exit.

Parameters
----------
target: Target
The target to use within the scope.

Yields
------
target: Target
The target that is now active.
"""
global _target
assert isinstance(target, Target)
prev = _target
_target = target
try:
yield target
finally:
_target = prev


@functools.cache
def get_default_target() -> Target:
import torch
Expand Down
68 changes: 49 additions & 19 deletions python/tilus/testing/_requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,70 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
import functools
from typing import Any, Callable

from tilus.target import Target, get_current_target, nvgpu_sm80, nvgpu_sm90, nvgpu_sm100, nvgpu_sm100a
from tilus.target import Target, get_current_target, nvgpu_sm80, nvgpu_sm90, nvgpu_sm100, nvgpu_sm100a, scope


class _CompileOnlyDone(Exception):
"""Raised inside a compile-only test to short-circuit execution after a successful compile."""


def _requires(target: Target) -> Callable[[Callable], Callable]:
"""
Pytest fixture decorator that skips tests if the current GPU doesn't support the required architecture.
Pytest decorator that adapts test behavior to the current GPU.

If the current GPU supports the required target, the test runs unchanged.

Otherwise, the test runs in *compile-only* mode:
- The current compilation target is overridden to ``target`` for the duration of the test.
- The first ``InstantiatedScript.__call__`` invocation is redirected to
:py:meth:`InstantiatedScript.compile <tilus.InstantiatedScript.compile>`, which transpiles +
builds every schedule in the autotune space without running the kernel.
- After the compile succeeds, a sentinel exception is raised to short-circuit the rest of the
test body; the decorator catches the sentinel and treats the test as passed.

Parameters
----------
target : Target
The required target architecture. Examples include 'sm_90a', 'sm_80',
The required target architecture, e.g. ``nvgpu_sm100a``.
"""

def decorator(test_func):
import pytest

def decorator(test_func: Callable) -> Callable:
try:
required_target = target
current_target = get_current_target()
current_capability = current_target.properties.compute_capability
supports_target = current_target.supports(target)
except Exception:
# Could not determine the current target (e.g. no GPU available).
# Fall through to compile-only mode -- compilation does not need a runtime GPU.
supports_target = False

if not current_target.supports(required_target):
return pytest.mark.skip(
f"Test requires architecture {required_target}, but current GPU capability is {current_capability}"
)(test_func)
if supports_target:
return test_func
except ValueError as e:
# If we can't parse the architecture string, skip the test
return pytest.mark.skip(f"Invalid architecture requirement: {e}")(test_func)
except Exception as e:
# If we can't determine current capability, skip the test
return pytest.mark.skip(f"Cannot determine current GPU capability: {e}")(test_func)

@functools.wraps(test_func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
# Importing here avoids a top-level circular import: tilus.testing is imported eagerly
# by tests but tilus.lang.instantiated_script depends on the rest of the package.
from tilus.lang.instantiated_script import InstantiatedScript

original_call = InstantiatedScript.__call__

def compile_only_call(self: InstantiatedScript, *call_args: Any, **call_kwargs: Any) -> Any:
self.compile(*call_args, **call_kwargs)
raise _CompileOnlyDone()

InstantiatedScript.__call__ = compile_only_call # type: ignore[method-assign]
try:
with scope(target):
test_func(*args, **kwargs)
except _CompileOnlyDone:
pass
finally:
InstantiatedScript.__call__ = original_call # type: ignore[method-assign]

return wrapper

return decorator

Expand Down
2 changes: 1 addition & 1 deletion tests/instructions/test_cluster_launch_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __call__(self, n: int32, p_out: ~int32) -> None:
break


@tilus.testing.requires.nvgpu_sm100
@tilus.testing.requires.nvgpu_sm100a
@pytest.mark.parametrize("cluster_blocks", [2, 4])
@pytest.mark.parametrize("num_stages", [2, 3, 4])
@pytest.mark.parametrize("warps", [4, 8])
Expand Down
2 changes: 1 addition & 1 deletion tests/instructions/test_copy_async_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __call__(self, m_size: int32, n_size: int, x_ptr: ~float16, y_ptr: ~float16)
self.tma.wait_group(0)


@requires.nvgpu_sm90
@requires.nvgpu_sm100a
def test_copy_async_tensor_cta():
m = 123
n = 64 * 8
Expand Down
Loading