diff --git a/build2cmake/src/config/mod.rs b/build2cmake/src/config/mod.rs index 58798273..ceb6e823 100644 --- a/build2cmake/src/config/mod.rs +++ b/build2cmake/src/config/mod.rs @@ -5,7 +5,7 @@ pub mod v1; mod v2; use serde_value::Value; -pub use v2::{Backend, Build, Dependencies, Kernel, Torch}; +pub use v2::{Backend, Build, Dependency, General, Kernel, Torch}; #[derive(Debug)] pub enum BuildCompat { diff --git a/build2cmake/src/config/v1.rs b/build2cmake/src/config/v1.rs index 60ea2da3..8b095bb7 100644 --- a/build2cmake/src/config/v1.rs +++ b/build2cmake/src/config/v1.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, fmt::Display, path::PathBuf}; use serde::Deserialize; -use super::v2::Dependencies; +use super::v2::Dependency; #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] @@ -40,7 +40,7 @@ pub struct Kernel { pub rocm_archs: Option>, #[serde(default)] pub language: Language, - pub depends: Vec, + pub depends: Vec, pub include: Option>, pub src: Vec, } diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs index 59c16fc1..60b222f8 100644 --- a/build2cmake/src/config/v2.rs +++ b/build2cmake/src/config/v2.rs @@ -54,6 +54,8 @@ pub struct General { pub cuda_minver: Option, pub hub: Option, + + pub python_depends: Option>, } impl General { @@ -70,6 +72,22 @@ pub struct Hub { pub branch: Option, } +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub enum PythonDependency { + Einops, + NvidiaCutlassDsl, +} + +impl Display for PythonDependency { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PythonDependency::Einops => write!(f, "einops"), + PythonDependency::NvidiaCutlassDsl => write!(f, "nvidia-cutlass-dsl"), + } + } +} + #[derive(Debug, Deserialize, Clone, Serialize)] #[serde(deny_unknown_fields)] pub struct Torch { @@ -107,7 +125,7 @@ pub enum Kernel { #[serde(rename_all = "kebab-case")] Cpu { cxx_flags: Option>, - depends: Vec, + 
depends: Vec, include: Option>, src: Vec, }, @@ -117,21 +135,21 @@ pub enum Kernel { cuda_flags: Option>, cuda_minver: Option, cxx_flags: Option>, - depends: Vec, + depends: Vec, include: Option>, src: Vec, }, #[serde(rename_all = "kebab-case")] Metal { cxx_flags: Option>, - depends: Vec, + depends: Vec, include: Option>, src: Vec, }, #[serde(rename_all = "kebab-case")] Rocm { cxx_flags: Option>, - depends: Vec, + depends: Vec, rocm_archs: Option>, hip_flags: Option>, include: Option>, @@ -140,7 +158,7 @@ pub enum Kernel { #[serde(rename_all = "kebab-case")] Xpu { cxx_flags: Option>, - depends: Vec, + depends: Vec, sycl_flags: Option>, include: Option>, src: Vec, @@ -178,7 +196,7 @@ impl Kernel { } } - pub fn depends(&self) -> &[Dependencies] { + pub fn depends(&self) -> &[Dependency] { match self { Kernel::Cpu { depends, .. } | Kernel::Cuda { depends, .. } @@ -239,7 +257,7 @@ impl FromStr for Backend { #[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] #[non_exhaustive] #[serde(rename_all = "lowercase")] -pub enum Dependencies { +pub enum Dependency { #[serde(rename = "cutlass_2_10")] Cutlass2_10, #[serde(rename = "cutlass_3_5")] @@ -284,6 +302,7 @@ impl General { cuda_maxver: None, cuda_minver: None, hub: None, + python_depends: None, } } } diff --git a/build2cmake/src/templates/pyproject.toml b/build2cmake/src/templates/pyproject.toml index fa457aee..5785e836 100644 --- a/build2cmake/src/templates/pyproject.toml +++ b/build2cmake/src/templates/pyproject.toml @@ -6,5 +6,6 @@ requires = [ "setuptools>=61", "torch", "wheel", + {{python_dependencies}} ] build-backend = "setuptools.build_meta" diff --git a/build2cmake/src/templates/universal/pyproject.toml b/build2cmake/src/templates/universal/pyproject.toml index ce1ca054..e60657b9 100644 --- a/build2cmake/src/templates/universal/pyproject.toml +++ b/build2cmake/src/templates/universal/pyproject.toml @@ -2,7 +2,10 @@ name = "{{ name }}" version = "0.0.1" requires-python = ">= 3.9" 
-dependencies = ["torch>=2.4"] +dependencies = [ + "torch>=2.8", + {{python_dependencies}} +] [tool.setuptools] package-dir = { "" = "torch-ext" } diff --git a/build2cmake/src/torch/common.rs b/build2cmake/src/torch/common.rs new file mode 100644 index 00000000..ad5634d9 --- /dev/null +++ b/build2cmake/src/torch/common.rs @@ -0,0 +1,33 @@ +use eyre::{Context, Result}; +use itertools::Itertools; +use minijinja::{context, Environment}; + +use crate::{config::General, FileSet}; + +pub fn write_pyproject_toml( + env: &Environment, + general: &General, + file_set: &mut FileSet, +) -> Result<()> { + let writer = file_set.entry("pyproject.toml"); + + let python_dependencies = general + .python_depends + .as_ref() + .unwrap_or(&vec![]) + .iter() + .map(|d| format!("\"{d}\"")) + .join(", "); + + env.get_template("pyproject.toml") + .wrap_err("Cannot get pyproject.toml template")? + .render_to_write( + context! { + python_dependencies => python_dependencies, + }, + writer, + ) + .wrap_err("Cannot render kernel template")?; + + Ok(()) +} diff --git a/build2cmake/src/torch/cpu.rs b/build2cmake/src/torch/cpu.rs index a2ef77d6..9bec4bd3 100644 --- a/build2cmake/src/torch/cpu.rs +++ b/build2cmake/src/torch/cpu.rs @@ -4,7 +4,7 @@ use eyre::{bail, Context, Result}; use itertools::Itertools; use minijinja::{context, Environment}; -use super::kernel_ops_identifier; +use super::{common::write_pyproject_toml, kernel_ops_identifier}; use crate::{ config::{Build, Kernel, Torch}, fileset::FileSet, @@ -47,7 +47,7 @@ pub fn write_torch_ext_cpu( write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?; - write_pyproject_toml(env, &mut file_set)?; + write_pyproject_toml(env, &build.general, &mut file_set)?; write_torch_registration_macros(&mut file_set)?; @@ -209,17 +209,6 @@ fn write_ops_py( Ok(()) } -fn write_pyproject_toml(env: &Environment, file_set: &mut FileSet) -> Result<()> { - let writer = file_set.entry("pyproject.toml"); - - env.get_template("pyproject.toml") - 
.wrap_err("Cannot get pyproject.toml template")? - .render_to_write(context! {}, writer) - .wrap_err("Cannot render kernel template")?; - - Ok(()) -} - fn write_setup_py( env: &Environment, torch: &Torch, diff --git a/build2cmake/src/torch/cuda.rs b/build2cmake/src/torch/cuda.rs index 0a189104..faf3eb56 100644 --- a/build2cmake/src/torch/cuda.rs +++ b/build2cmake/src/torch/cuda.rs @@ -7,8 +7,9 @@ use eyre::{bail, Context, Result}; use itertools::Itertools; use minijinja::{context, Environment}; +use super::common::write_pyproject_toml; use super::kernel_ops_identifier; -use crate::config::{Backend, Build, Dependencies, Kernel, Torch}; +use crate::config::{Backend, Build, Dependency, Kernel, Torch}; use crate::version::Version; use crate::FileSet; @@ -60,7 +61,7 @@ pub fn write_torch_ext_cuda( write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?; - write_pyproject_toml(env, &mut file_set)?; + write_pyproject_toml(env, &build.general, &mut file_set)?; write_torch_registration_macros(&mut file_set)?; @@ -78,17 +79,6 @@ fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> { Ok(()) } -fn write_pyproject_toml(env: &Environment, file_set: &mut FileSet) -> Result<()> { - let writer = file_set.entry("pyproject.toml"); - - env.get_template("pyproject.toml") - .wrap_err("Cannot get pyproject.toml template")? - .render_to_write(context! {}, writer) - .wrap_err("Cannot render kernel template")?; - - Ok(()) -} - fn write_setup_py( env: &Environment, torch: &Torch, @@ -230,7 +220,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu for dep in deps { match dep { - Dependencies::Cutlass2_10 => { + Dependency::Cutlass2_10 => { env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? 
.render_to_write( @@ -241,7 +231,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu ) .wrap_err("Cannot render CUTLASS dependency template")?; } - Dependencies::Cutlass3_5 => { + Dependency::Cutlass3_5 => { env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? .render_to_write( @@ -252,7 +242,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu ) .wrap_err("Cannot render CUTLASS dependency template")?; } - Dependencies::Cutlass3_6 => { + Dependency::Cutlass3_6 => { env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? .render_to_write( @@ -263,7 +253,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu ) .wrap_err("Cannot render CUTLASS dependency template")?; } - Dependencies::Cutlass3_8 => { + Dependency::Cutlass3_8 => { env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? .render_to_write( @@ -274,7 +264,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu ) .wrap_err("Cannot render CUTLASS dependency template")?; } - Dependencies::Cutlass3_9 => { + Dependency::Cutlass3_9 => { env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? .render_to_write( @@ -285,7 +275,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu ) .wrap_err("Cannot render CUTLASS dependency template")?; } - Dependencies::Cutlass4_0 => { + Dependency::Cutlass4_0 => { env.get_template("cuda/dep-cutlass.cmake") .wrap_err("Cannot get CUTLASS dependency template")? 
.render_to_write( @@ -296,7 +286,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu ) .wrap_err("Cannot render CUTLASS dependency template")?; } - Dependencies::Torch => (), + Dependency::Torch => (), _ => { eprintln!("Warning: CUDA backend doesn't need/support dependency: {dep:?}"); } diff --git a/build2cmake/src/torch/metal.rs b/build2cmake/src/torch/metal.rs index 0d6198b8..ad09ac65 100644 --- a/build2cmake/src/torch/metal.rs +++ b/build2cmake/src/torch/metal.rs @@ -4,7 +4,7 @@ use eyre::{bail, Context, Result}; use itertools::Itertools; use minijinja::{context, Environment}; -use super::kernel_ops_identifier; +use super::{common::write_pyproject_toml, kernel_ops_identifier}; use crate::{ config::{Build, Kernel, Torch}, fileset::FileSet, @@ -49,7 +49,7 @@ pub fn write_torch_ext_metal( write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?; - write_pyproject_toml(env, &mut file_set)?; + write_pyproject_toml(env, &build.general, &mut file_set)?; write_torch_registration_macros(&mut file_set)?; @@ -225,17 +225,6 @@ fn write_ops_py( Ok(()) } -fn write_pyproject_toml(env: &Environment, file_set: &mut FileSet) -> Result<()> { - let writer = file_set.entry("pyproject.toml"); - - env.get_template("pyproject.toml") - .wrap_err("Cannot get pyproject.toml template")? - .render_to_write(context! 
{}, writer) - .wrap_err("Cannot render kernel template")?; - - Ok(()) -} - fn write_setup_py( env: &Environment, torch: &Torch, diff --git a/build2cmake/src/torch/mod.rs b/build2cmake/src/torch/mod.rs index 3637c53a..bc5ba1e2 100644 --- a/build2cmake/src/torch/mod.rs +++ b/build2cmake/src/torch/mod.rs @@ -4,6 +4,8 @@ pub use cpu::write_torch_ext_cpu; mod cuda; pub use cuda::write_torch_ext_cuda; +pub mod common; + mod metal; pub use metal::write_torch_ext_metal; diff --git a/build2cmake/src/torch/universal.rs b/build2cmake/src/torch/universal.rs index b52525a5..4622a234 100644 --- a/build2cmake/src/torch/universal.rs +++ b/build2cmake/src/torch/universal.rs @@ -1,10 +1,11 @@ use std::path::PathBuf; use eyre::{Context, Result}; +use itertools::Itertools; use minijinja::{context, Environment}; use crate::{ - config::{Build, Torch}, + config::{Build, General, Torch}, fileset::FileSet, torch::kernel_ops_identifier, }; @@ -20,12 +21,7 @@ pub fn write_torch_ext_universal( let ops_name = kernel_ops_identifier(&target_dir, &build.general.python_name(), ops_id); write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?; - write_pyproject_toml( - env, - build.torch.as_ref(), - &build.general.name, - &mut file_set, - )?; + write_pyproject_toml(env, build.torch.as_ref(), &build.general, &mut file_set)?; Ok(file_set) } @@ -58,18 +54,27 @@ fn write_ops_py( fn write_pyproject_toml( env: &Environment, torch: Option<&Torch>, - name: &str, + general: &General, file_set: &mut FileSet, ) -> Result<()> { let writer = file_set.entry("pyproject.toml"); + let name = &general.name; let data_globs = torch.and_then(|torch| torch.data_globs().map(|globs| globs.join(", "))); + let python_dependencies = general + .python_depends + .as_ref() + .unwrap_or(&vec![]) + .iter() + .map(|d| format!("\"{d}\"")) + .join(", "); env.get_template("universal/pyproject.toml") .wrap_err("Cannot get universal pyproject.toml template")? .render_to_write( context! 
{ data_globs => data_globs, + python_dependencies => python_dependencies, name => name, }, writer, diff --git a/build2cmake/src/torch/xpu.rs b/build2cmake/src/torch/xpu.rs index a515180e..5f045b4e 100644 --- a/build2cmake/src/torch/xpu.rs +++ b/build2cmake/src/torch/xpu.rs @@ -6,8 +6,9 @@ use eyre::{bail, Context, Result}; use itertools::Itertools; use minijinja::{context, Environment}; +use super::common::write_pyproject_toml; use super::kernel_ops_identifier; -use crate::config::{Build, Dependencies, Kernel, Torch}; +use crate::config::{Build, Dependency, Kernel, Torch}; use crate::FileSet; static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); @@ -47,7 +48,7 @@ pub fn write_torch_ext_xpu( write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?; - write_pyproject_toml(env, &mut file_set)?; + write_pyproject_toml(env, &build.general, &mut file_set)?; write_torch_registration_macros(&mut file_set)?; @@ -65,17 +66,6 @@ fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> { Ok(()) } -fn write_pyproject_toml(env: &Environment, file_set: &mut FileSet) -> Result<()> { - let writer = file_set.entry("pyproject.toml"); - - env.get_template("pyproject.toml") - .wrap_err("Cannot get pyproject.toml template")? - .render_to_write(context! {}, writer) - .wrap_err("Cannot render pyproject.toml template")?; - - Ok(()) -} - fn write_setup_py( env: &Environment, torch: &Torch, @@ -196,11 +186,11 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu for dep in deps { match dep { - Dependencies::CutlassSycl => { + Dependency::CutlassSycl => { env.get_template("xpu/dep-cutlass-sycl.cmake")? .render_to_write(context! 
{}, &mut *write)?; } - Dependencies::Torch => (), + Dependency::Torch => (), _ => { // XPU supports CUTLASS-SYCL instead of CUTLASS eprintln!("Warning: XPU backend doesn't need/support dependency: {dep:?}"); diff --git a/docs/writing-kernels.md b/docs/writing-kernels.md index 559fd69b..a1a5942c 100644 --- a/docs/writing-kernels.md +++ b/docs/writing-kernels.md @@ -108,6 +108,9 @@ depends = [ "torch" ] build variants that are [required for compliant kernels](https://github.com/huggingface/kernels/blob/main/docs/kernel-requirements.md). This option is provided for kernels that require functionality only provided by newer CUDA toolkits. +- `python-depends` (**experimental**): a list of additional Python dependencies + that the kernel requires. The only supported dependencies are `einops` + and `nvidia-cutlass-dsl`. ### `torch` diff --git a/flake.lock b/flake.lock index 006f1995..2f7dd440 100644 --- a/flake.lock +++ b/flake.lock @@ -73,11 +73,11 @@ "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1762833384, - "narHash": "sha256-YPLcoqJBlYYGgmTG/J7bDx7lHqPnJ8beLs0+9OLfFhM=", + "lastModified": 1763454682, + "narHash": "sha256-XJuN1/aO8ZTfI959EIcb07nePS78MtDDwe+BTCu4PM4=", "owner": "huggingface", "repo": "hf-nix", - "rev": "752645bcda8793906249809319fa9b8dc11d7af6", + "rev": "dacf34fa85ad80d437d3a51fe95ea3637344208e", "type": "github" }, "original": { @@ -88,11 +88,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1762764791, - "narHash": "sha256-mWl8rYSYDFWD+zCR0VkBjEjD9jYj1/nlkDOfNNu44NA=", + "lastModified": 1763291491, + "narHash": "sha256-eEYvm+45PPmy+Qe+nZDpn1uhoMUjJwx3PwVVQoO9ksA=", "owner": "nixos", "repo": "nixpkgs", - "rev": "b549734f6b3ec54bb9a611a4185d11ee31f52ee1", + "rev": "c543a59edf25ada193719764f3bc0c6ba835f94d", "type": "github" }, "original": { diff --git a/lib/build.nix b/lib/build.nix index ae2e6dbe..5b6c6f3c 100644 --- a/lib/build.nix +++ b/lib/build.nix @@ -27,8 +27,6 @@ let inherit (import ./build-variants.nix { inherit lib; }) 
computeFramework; in rec { - resolveDeps = import ./deps.nix { inherit lib; }; - readToml = path: builtins.fromTOML (builtins.readFile path); validateBuildConfig = @@ -120,10 +118,12 @@ rec { kernels = lib.filterAttrs (_: kernel: computeFramework buildConfig == kernel.backend) ( buildToml.kernel or { } ); - extraDeps = resolveDeps { - inherit pkgs torch; - deps = lib.unique (lib.flatten (lib.mapAttrsToList (_: kernel: kernel.depends) kernels)); - }; + extraDeps = + let + inherit (import ./deps.nix { inherit lib pkgs torch; }) resolveCppDeps; + kernelDeps = lib.unique (lib.flatten (lib.mapAttrsToList (_: kernel: kernel.depends) kernels)); + in + resolveCppDeps kernelDeps; # Use the mkSourceSet function to get the source src = mkSourceSet path; @@ -146,6 +146,7 @@ rec { doGetKernelCheck ; kernelName = buildToml.general.name; + pythonDeps = buildToml.general.python-depends or [ ]; } else extension.mkExtension { @@ -160,6 +161,7 @@ rec { ; kernelName = buildToml.general.name; + pythonDeps = buildToml.general.python-depends or [ ]; doAbiCheck = true; }; @@ -237,23 +239,27 @@ rec { pkgs = buildSet.pkgs; rocmSupport = pkgs.config.rocmSupport or false; mkShell = pkgs.mkShell.override { inherit (buildSet.extension) stdenv; }; + extension = mkTorchExtension buildSet { inherit path rev doGetKernelCheck; }; in { name = buildName buildSet.buildConfig; value = mkShell { nativeBuildInputs = with pkgs; pythonNativeCheckInputs python3.pkgs; - buildInputs = - with pkgs; - [ - buildSet.torch - python3.pkgs.pytest - ] - ++ (pythonCheckInputs python3.pkgs); + buildInputs = with pkgs; [ + (python3.withPackages ( + ps: + with ps; + extension.dependencies + ++ pythonCheckInputs ps + ++ [ + buildSet.torch + pytest + ] + )) + ]; shellHook = '' - export PYTHONPATH=''${PYTHONPATH}:${ - mkTorchExtension buildSet { inherit path rev doGetKernelCheck; } - } + export PYTHONPATH=''${PYTHONPATH}:${extension} ''; }; }; @@ -277,25 +284,51 @@ rec { rocmSupport =
pkgs.config.rocmSupport or false; xpuSupport = pkgs.config.xpuSupport or false; mkShell = pkgs.mkShell.override { inherit (buildSet.extension) stdenv; }; + extension = mkTorchExtension buildSet { inherit path rev doGetKernelCheck; }; + python = ( + pkgs.python3.withPackages ( + ps: + with ps; + extension.dependencies + ++ pythonCheckInputs ps + ++ [ + buildSet.torch + pip + pytest + ] + ) + ); in { name = buildName buildSet.buildConfig; - value = mkShell { + value = mkShell rec { nativeBuildInputs = with pkgs; [ build2cmake kernel-abi-check - python3.pkgs.venvShellHook ] ++ (pythonNativeCheckInputs python3.pkgs); - buildInputs = with pkgs; [ python3.pkgs.pytest ] ++ (pythonCheckInputs python3.pkgs); - inputsFrom = [ (mkTorchExtension buildSet { inherit path rev doGetKernelCheck; }) ]; + buildInputs = [ python ]; + inputsFrom = [ extension ]; env = lib.optionalAttrs rocmSupport { PYTORCH_ROCM_ARCH = lib.concatStringsSep ";" buildSet.torch.rocmArchs; HIP_PATH = pkgs.rocmPackages.clr; }; + venvDir = "./.venv"; + + # We don't use venvShellHook because we want to use our wrapped + # Python interpreter. 
+ shellHook = '' + if [ -d "${venvDir}" ]; then + echo "Skipping venv creation, '${venvDir}' already exists" + else + echo "Creating new venv environment in path: '${venvDir}'" + ${python}/bin/python -m venv --system-site-packages "${venvDir}" + fi + source "${venvDir}/bin/activate" + ''; }; }; in diff --git a/lib/deps.nix b/lib/deps.nix index 0da60528..d99206da 100644 --- a/lib/deps.nix +++ b/lib/deps.nix @@ -1,15 +1,11 @@ { lib, -}: - -{ pkgs, torch, - deps, }: let - knownDeps = with pkgs.cudaPackages; { + cppDeps = { "cutlass_2_10" = [ pkgs.cutlass_2_10 ]; @@ -30,18 +26,20 @@ let ]; "torch" = [ torch - #torch.cxxdev ]; "cutlass_sycl" = [ torch.xpuPackages.cutlass-sycl ]; "metal-cpp" = [ pkgs.metal-cpp.dev ]; }; + pythonDeps = with pkgs.python3.pkgs; { + "einops" = [ einops ]; + "nvidia-cutlass-dsl" = [ nvidia-cutlass-dsl ]; + }; + getCppDep = dep: cppDeps.${dep} or (throw "Unknown dependency: ${dep}"); + getPythonDep = dep: pythonDeps.${dep} or (throw "Unknown Python dependency: ${dep}"); in -let - depToPkg = - dep: - assert lib.assertMsg (builtins.hasAttr dep knownDeps) "Unknown dependency: ${dep}"; - knownDeps.${dep}; -in -lib.flatten (map depToPkg deps) +{ + resolveCppDeps = deps: lib.flatten (map getCppDep deps); + resolvePythonDeps = deps: lib.flatten (map getPythonDep deps); +} diff --git a/lib/torch-extension/arch.nix b/lib/torch-extension/arch.nix index f28f21ee..954e803a 100644 --- a/lib/torch-extension/arch.nix +++ b/lib/torch-extension/arch.nix @@ -4,7 +4,9 @@ xpuSupport ? torch.xpuSupport, lib, + pkgs, stdenv, + writeText, # Native build inputs build2cmake, @@ -49,6 +51,11 @@ nvccThreads, + # A stringly-typed list of Python dependencies. Ideally we'd take a + # list of derivations, but we also need to write the dependencies to + # the output. + pythonDeps, + # Wheter to strip rpath for non-nix use. stripRPath ? false, @@ -65,8 +72,18 @@ assert (buildConfig ? 
xpuVersion) -> xpuSupport; assert (buildConfig.metal or false) -> stdenv.hostPlatform.isDarwin; let + inherit (import ../deps.nix { inherit lib pkgs torch; }) resolvePythonDeps; + + dependencies = resolvePythonDeps pythonDeps ++ [ torch ]; + moduleName = builtins.replaceStrings [ "-" ] [ "_" ] kernelName; + metadata = builtins.toJSON { + python-depends = pythonDeps; + }; + + metadataFile = writeText "metadata.json" metadata; + # On Darwin, we need the host's xcrun for `xcrun metal` to compile Metal shaders. # It's not supported by the nixpkgs shim. xcrunHost = writeScriptBin "xcrunHost" '' @@ -139,7 +156,7 @@ stdenv.mkDerivation (prevAttrs: { remove-bytecode-hook ] ++ lib.optionals doGetKernelCheck [ - get-kernel-check + (get-kernel-check.override { python3 = python3.withPackages (ps: dependencies); }) ] ++ lib.optionals cudaSupport [ cmakeNvccThreadsHook @@ -241,6 +258,8 @@ stdenv.mkDerivation (prevAttrs: { # the updated kernels has been around for a while. mkdir $out/${moduleName} cp ${./compat.py} $out/${moduleName}/__init__.py + + cp ${metadataFile} $out/metadata.json '' + (lib.optionalString (stripRPath && stdenv.hostPlatform.isLinux)) '' find $out/ -name '*.so' \ @@ -261,6 +280,6 @@ stdenv.mkDerivation (prevAttrs: { __noChroot = metalSupport; passthru = { - inherit torch; + inherit dependencies torch; }; }) diff --git a/lib/torch-extension/no-arch.nix b/lib/torch-extension/no-arch.nix index 09dfa0e4..bae6471c 100644 --- a/lib/torch-extension/no-arch.nix +++ b/lib/torch-extension/no-arch.nix @@ -1,11 +1,15 @@ { lib, + pkgs, stdenv, build2cmake, get-kernel-check, kernel-layout-check, + python3, remove-bytecode-hook, + writeText, + torch, }: @@ -19,10 +23,21 @@ rev, src, + + # A stringly-typed list of Python dependencies. Ideally we'd take a + # list of derivations, but we also need to write the dependencies to + # the output. 
+ pythonDeps, }: let + inherit (import ../deps.nix { inherit lib pkgs torch; }) resolvePythonDeps; + dependencies = resolvePythonDeps pythonDeps ++ [ torch ]; moduleName = builtins.replaceStrings [ "-" ] [ "_" ] kernelName; + metadata = builtins.toJSON { + python-depends = pythonDeps; + }; + metadataFile = writeText "metadata.json" metadata; in stdenv.mkDerivation (prevAttrs: { @@ -40,7 +55,7 @@ stdenv.mkDerivation (prevAttrs: { remove-bytecode-hook ] ++ lib.optionals doGetKernelCheck [ - get-kernel-check + (get-kernel-check.override { python3 = python3.withPackages (ps: dependencies); }) ]; dontBuild = true; @@ -57,7 +72,12 @@ stdenv.mkDerivation (prevAttrs: { cp -r torch-ext/${moduleName}/* $out/ mkdir $out/${moduleName} cp ${./compat.py} $out/${moduleName}/__init__.py + cp ${metadataFile} $out/metadata.json ''; doInstallCheck = true; + + passthru = { + inherit dependencies; + }; }) diff --git a/overlay.nix b/overlay.nix index ccd54c75..f9cb0467 100644 --- a/overlay.nix +++ b/overlay.nix @@ -21,6 +21,21 @@ final: prev: { pythonPackagesExtensions = prev.pythonPackagesExtensions ++ [ ( python-self: python-super: with python-self; { + cuda-bindings = python-self.callPackage ./pkgs/python-modules/cuda-bindings { }; + + cuda-pathfinder = python-self.callPackage ./pkgs/python-modules/cuda-pathfinder { }; + + # Starting with the CUDA 12.8 version, cuda-python is a metapackage + # that pulls in relevant dependencies. For CUDA 12.6 it is just + # cuda-bindings. 
+ cuda-python = + if final.cudaPackages.cudaMajorMinorVersion == "12.6" then + python-self.cuda-bindings + else + python-self.callPackage ./pkgs/python-modules/cuda-python { }; + + nvidia-cutlass-dsl = python-self.callPackage ./pkgs/python-modules/nvidia-cutlass-dsl { }; + kernel-abi-check = callPackage ./pkgs/python-modules/kernel-abi-check { }; kernels = python-super.kernels.overrideAttrs (oldAttrs: { @@ -33,6 +48,8 @@ final: prev: { sha256 = "sha256-6N1W3jLQIS1yEAdNR2X9CuFdMw4Ia0yzBBVQ4Kujv8U="; }; }); + + pyclibrary = python-self.callPackage ./pkgs/python-modules/pyclibrary { }; } ) ]; diff --git a/pkgs/get-kernel-check/default.nix b/pkgs/get-kernel-check/default.nix index 892d88ab..51e49283 100644 --- a/pkgs/get-kernel-check/default.nix +++ b/pkgs/get-kernel-check/default.nix @@ -2,7 +2,8 @@ makeSetupHook { name = "get-kernel-check-hook"; - propagatedBuildInputs = [ - (python3.withPackages (ps: with ps; [ kernels ])) - ]; + substitutions = { + python3 = "${python3}/bin/python"; + kernels = "${with python3.pkgs; makePythonPath [ kernels ]}"; + }; } ./get-kernel-check-hook.sh diff --git a/pkgs/get-kernel-check/get-kernel-check-hook.sh b/pkgs/get-kernel-check/get-kernel-check-hook.sh index f78a3bee..381ed073 100755 --- a/pkgs/get-kernel-check/get-kernel-check-hook.sh +++ b/pkgs/get-kernel-check/get-kernel-check-hook.sh @@ -31,11 +31,13 @@ _getKernelCheckHook() { # Emulate the bundle layout that kernels expects. This even works # for universal kernels, since kernels checks the non-universal # path first. 
- BUILD_VARIANT=$(python -c "from kernels.utils import build_variant; print(build_variant())") + PYTHONPATH="@kernels@" \ + BUILD_VARIANT=$(@python3@ -c "from kernels.utils import build_variant; print(build_variant())") mkdir -p "${TMPDIR}/build" ln -s "$out" "${TMPDIR}/build/${BUILD_VARIANT}" - python -c "from pathlib import Path; import kernels; kernels.get_local_kernel(Path('${TMPDIR}'), '${moduleName}')" + PYTHONPATH="@kernels@" \ + @python3@ -c "from pathlib import Path; import kernels; kernels.get_local_kernel(Path('${TMPDIR}'), '${moduleName}')" } postInstallCheckHooks+=(_getKernelCheckHook) diff --git a/pkgs/python-modules/cuda-bindings/default.nix b/pkgs/python-modules/cuda-bindings/default.nix new file mode 100644 index 00000000..cb800e3f --- /dev/null +++ b/pkgs/python-modules/cuda-bindings/default.nix @@ -0,0 +1,90 @@ +{ + lib, + buildPythonPackage, + fetchFromGitHub, + symlinkJoin, + + autoAddDriverRunpath, + cython, + pyclibrary, + setuptools, + + cuda-pathfinder, + cudaPackages, + versioneer, +}: +let + outpaths = + with cudaPackages; + [ + cuda_cudart + cuda_nvcc + cuda_nvrtc + cuda_profiler_api + libcufile + ] + ++ lib.optionals (cudaAtLeast "13.0") [ cuda_crt ]; + + cudatoolkit_joined = symlinkJoin { + name = "cudatoolkit-joined-${cudaPackages.cudaMajorMinorVersion}"; + paths = + outpaths ++ lib.concatMap (outpath: lib.map (output: outpath.${output}) outpath.outputs) outpaths; + }; + + versionHashes = { + "12.6" = { + version = "12.6.2.post1"; + hash = "sha256-MG6q+Hyo0H4XKZLbtFQqfen6T2gxWzyk1M9jWryjjj4="; + }; + "12.8" = { + version = "12.8.0"; + hash = "sha256-7e9w70KkC6Pcvyu6Cwt5Asrc3W9TgsjiGvArRTer6Oc="; + }; + "12.9" = { + version = "12.9.4"; + hash = "sha256-eqdBBlcfuVCFNl0osKV4lfH0QjWxdyThTDLhEFZrPKM="; + }; + "13.0" = { + version = "13.0.3"; + hash = "sha256-Uq1oQWtilocQPh6cZ3P/L/L6caCHv17u1y67sm5fhhA="; + }; + }; + + versionHash = + versionHashes.${cudaPackages.cudaMajorMinorVersion} + or (throw "Unsupported CUDA version: 
${cudaPackages.cudaMajorMinorVersion}"); + inherit (versionHash) hash version; + +in +buildPythonPackage { + pname = "cuda-bindings"; + inherit version; + pyproject = true; + + src = fetchFromGitHub { + owner = "NVIDIA"; + repo = "cuda-python"; + rev = "v${version}"; + inherit hash; + }; + + sourceRoot = "source/cuda_bindings"; + + build-system = [ + cython + pyclibrary + setuptools + versioneer + ]; + + dependencies = [ cuda-pathfinder ]; + + nativeBuildInputs = [ + autoAddDriverRunpath + cudaPackages.cuda_nvcc + ]; + + env.CUDA_HOME = cudatoolkit_joined; + + pythonImportsCheck = [ "cuda.bindings" ]; +} diff --git a/pkgs/python-modules/cuda-pathfinder/default.nix b/pkgs/python-modules/cuda-pathfinder/default.nix new file mode 100644 index 00000000..01db00c5 --- /dev/null +++ b/pkgs/python-modules/cuda-pathfinder/default.nix @@ -0,0 +1,24 @@ +{ + buildPythonPackage, + fetchFromGitHub, + setuptools, +}: + +buildPythonPackage rec { + pname = "cuda-pathfinder"; + version = "1.3.2"; + pyproject = true; + + src = fetchFromGitHub { + owner = "NVIDIA"; + repo = "cuda-python"; + rev = "${pname}-v${version}"; + hash = "sha256-hm/LoOVpJVKkOuKrBdHnYi1JMCNeB2ozAvz/N6RG0zU="; + }; + + sourceRoot = "source/cuda_pathfinder"; + + build-system = [ setuptools ]; + + pythonImportsCheck = [ "cuda.pathfinder" ]; +} diff --git a/pkgs/python-modules/cuda-python/default.nix b/pkgs/python-modules/cuda-python/default.nix new file mode 100644 index 00000000..786559a2 --- /dev/null +++ b/pkgs/python-modules/cuda-python/default.nix @@ -0,0 +1,38 @@ +{ + buildPythonPackage, + fetchFromGitHub, + + pythonRelaxDepsHook, + setuptools, + + cuda-bindings, + cuda-pathfinder, +}: + +buildPythonPackage rec { + pname = "cuda-python"; + version = "13.0.3"; + pyproject = true; + + src = fetchFromGitHub { + owner = "NVIDIA"; + repo = "cuda-python"; + rev = "v${version}"; + hash = "sha256-Uq1oQWtilocQPh6cZ3P/L/L6caCHv17u1y67sm5fhhA="; + }; + + sourceRoot = "source/cuda_python"; + + nativeBuildInputs = [ + 
pythonRelaxDepsHook + ]; + + build-system = [ setuptools ]; + + dependencies = [ + cuda-bindings + cuda-pathfinder + ]; + + pythonRelaxDeps = [ "cuda-bindings" ]; +} diff --git a/pkgs/python-modules/nvidia-cutlass-dsl/default.nix b/pkgs/python-modules/nvidia-cutlass-dsl/default.nix new file mode 100644 index 00000000..488b8e25 --- /dev/null +++ b/pkgs/python-modules/nvidia-cutlass-dsl/default.nix @@ -0,0 +1,61 @@ +{ + stdenv, + fetchPypi, + python, + + buildPythonPackage, + autoPatchelfHook, + autoAddDriverRunpath, + pythonWheelDepsCheckHook, + + cudaPackages, + cuda-python, + numpy, + typing-extensions, +}: + +let + format = "wheel"; + pyShortVersion = "cp" + builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion; + hashes = { + cp313-x86_64-linux = "sha256-k1ZgSvyPYqrEZjSzoSuvjLPzpvLkTjmNz+bsmP8ajRs="; + }; + hash = + hashes."${pyShortVersion}-${stdenv.system}" + or (throw "Unsupported Python version: ${pyShortVersion}-${stdenv.system}"); + +in +buildPythonPackage rec { + pname = "nvidia-cutlass-dsl"; + version = "4.2.1"; + inherit format; + + src = fetchPypi { + pname = "nvidia_cutlass_dsl"; + python = pyShortVersion; + abi = pyShortVersion; + dist = pyShortVersion; + platform = "manylinux_2_28_${stdenv.hostPlatform.uname.processor}"; + inherit format hash version; + }; + + nativeBuildInputs = [ + autoAddDriverRunpath + autoPatchelfHook + pythonWheelDepsCheckHook + ]; + + dependencies = [ + cuda-python + numpy + typing-extensions + ]; + + autoPatchelfIgnoreMissingDeps = [ + "libcuda.so.1" + ]; + + meta = { + broken = !(cudaPackages.cudaAtLeast "12.8"); + }; +} diff --git a/pkgs/python-modules/pyclibrary/default.nix b/pkgs/python-modules/pyclibrary/default.nix new file mode 100644 index 00000000..8553d851 --- /dev/null +++ b/pkgs/python-modules/pyclibrary/default.nix @@ -0,0 +1,31 @@ +{ + buildPythonPackage, + fetchFromGitHub, + + setuptools, + setuptools-scm, + + pyparsing, +}: + +buildPythonPackage rec { + pname = "pyclibrary"; + version = "0.3.0"; + 
pyproject = true; + + src = fetchFromGitHub { + owner = "MatthieuDartiailh"; + repo = "pyclibrary"; + tag = version; + hash = "sha256-RyIbRySRWSZwKP5G6yXYCOnfKOV0165aPyjMf3nSbOM="; + }; + + build-system = [ + setuptools + setuptools-scm + ]; + + dependencies = [ pyparsing ]; + + pythonImportsCheck = [ "pyclibrary" ]; +}