Conversation
danieldk
left a comment
There was a problem hiding this comment.
Really cool!
Added some comments. Maybe you can also copy over the test from the relu example?
| #[serde(rename = "cutlass_sycl")] | ||
| CutlassSycl, | ||
| #[serde(rename = "metal-cpp")] | ||
| MetalCpp, |
There was a problem hiding this comment.
Do we need a version like cutlass? My first guess is not, since on Mac we always have everything the latest, but I thought I'd check.
There was a problem hiding this comment.
good question! I agree we probably do not need a version here and we'll always prefer latest
examples/relu-metal-cpp/build.toml
Outdated
| "relu.cpp", | ||
| "metallib_loader.mm", | ||
| "relu_cpp.metal", | ||
| "common.h", |
There was a problem hiding this comment.
Would be nicer to have these in a directory, since it's an example, we want best practices :).
There was a problem hiding this comment.
totally agree, i've updated the folder struct and build.toml in latest
lib/torch-extension/arch.nix
Outdated
| # Build inputs | ||
| apple-sdk_26, | ||
| clr, | ||
| metal-cpp, |
danieldk
left a comment
There was a problem hiding this comment.
One more small comment (file to remove), CI should work again after a rebase on main.
test-kernel.py
Outdated
| @@ -0,0 +1,23 @@ | |||
| # /// script | |||
There was a problem hiding this comment.
I think this is another development file that can be removed? (I think I missed it the first time.)
There was a problem hiding this comment.
ahh I think I actually added it in 914ed28 thanks for catching!
removed
danieldk
left a comment
There was a problem hiding this comment.
Awesome, let's do it! (merging myself because I need to rebase my cutlass/deps PR)
This PR supersedes #291 --- This PR adds support for a new `metal-cpp` kernel dependency. This is a follow up to the metal-cpp support in hf nix: huggingface/hf-nix#128 and enables kernels to use the cpp headers to drive metal kernels. Changes: - adds dep to build2cmake - adds new relu-metal-cpp example - builds example in CI example usage ```bash cd examples/relu-metal-cpp nix build -L . cd ... uv run test_relu_metal_cpp.py ``` `test_relu_metal_cpp.py` ```python # /// script # requires-python = ">=3.10" # dependencies = ["kernels", "torch", "numpy"] # /// from kernels import get_local_kernel import torch from pathlib import Path relu = get_local_kernel(Path("examples/relu-metal-cpp/result"), "relu").relu input = torch.tensor([-1.0, -1.5, 0.0, 2.0, 3.5], device="mps", dtype=torch.float16) out = relu(input) ref = torch.relu(input) assert torch.allclose(out, ref), f"Float16 failed: {out} != {ref}" print(out.cpu().numpy()) print(ref.cpu().numpy()) print("PASS") ``` output ``` [0. 0. 0. 2. 3.5] [0. 0. 0. 2. 3.5] PASS ```
This PR supersedes #291
This PR adds support for a new
metal-cppkernel dependency. This is a follow up to the metal-cpp support in hf nix: huggingface/hf-nix#128 and enables kernels to use the cpp headers to drive metal kernels.Changes:
example usage
test_relu_metal_cpp.pyoutput