diff --git a/build2cmake/Cargo.lock b/build2cmake/Cargo.lock index 20f6a5a5..c4d239b0 100644 --- a/build2cmake/Cargo.lock +++ b/build2cmake/Cargo.lock @@ -84,6 +84,7 @@ dependencies = [ "serde", "serde-value", "serde_json", + "thiserror", "toml", ] @@ -691,6 +692,26 @@ dependencies = [ "syn", ] +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tinystr" version = "0.8.1" diff --git a/build2cmake/Cargo.toml b/build2cmake/Cargo.toml index 9399842d..ce3104a8 100644 --- a/build2cmake/Cargo.toml +++ b/build2cmake/Cargo.toml @@ -20,6 +20,7 @@ rand = "0.8" serde = { version = "1", features = ["derive"] } serde_json = "1" serde-value = "0.7" +thiserror = "1" toml = "0.8" [build-dependencies] diff --git a/build2cmake/src/config/v3.rs b/build2cmake/src/config/v3.rs index 5043bc04..14c25418 100644 --- a/build2cmake/src/config/v3.rs +++ b/build2cmake/src/config/v3.rs @@ -3,15 +3,75 @@ use std::{ fmt::Display, path::PathBuf, str::FromStr, + sync::LazyLock, }; use eyre::Result; use itertools::Itertools; use serde::{Deserialize, Serialize}; +use thiserror::Error; +use super::{common::Dependency, v2}; use crate::version::Version; -use super::{common::Dependency, v2}; +#[derive(Debug, Error)] +enum DependencyError { + #[error("No dependencies are defined for backend: {backend:?}")] + Backend { backend: String }, + #[error("Unknown dependency `{dependency:?}` for backend `{backend:?}`")] + Dependency { backend: String, dependency: String }, + #[error("Unknown dependency: `{dependency:?}`")] + GeneralDependency { dependency: String }, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +struct PythonDependencies { + general: HashMap, + backends: HashMap>, +} + +impl PythonDependencies { + fn get_dependency(&self, dependency: &str) -> Result<&[String], DependencyError> { + match self.general.get(dependency) { + None => Err(DependencyError::GeneralDependency { + dependency: dependency.to_string(), + }), + Some(dep) => Ok(&dep.python), + } + } + + fn get_backend_dependency( + &self, + backend: Backend, + dependency: &str, + ) -> Result<&[String], DependencyError> { + let backend_deps = match self.backends.get(&backend) { + None => { + return Err(DependencyError::Backend { + backend: backend.to_string(), + }) + } + Some(backend_deps) => backend_deps, + }; + match backend_deps.get(dependency) { + None => Err(DependencyError::Dependency { + backend: backend.to_string(), + dependency: dependency.to_string(), + }), + Some(dep) => Ok(&dep.python), + } + } +} + +#[derive(Debug, Deserialize, Serialize)] +struct PythonDependency { + nix: Vec, + python: Vec, +} + +static PYTHON_DEPENDENCIES: LazyLock = + LazyLock::new(|| serde_json::from_str(include_str!("../python_dependencies.json")).unwrap()); #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields)] @@ -44,7 +104,9 @@ pub struct General { pub hub: Option, - pub python_depends: Option>, + pub python_depends: Option>, + + pub xpu: Option, } impl General { @@ -52,6 +114,53 @@ impl General { pub fn python_name(&self) -> String { self.name.replace("-", "_") } + + pub fn python_depends(&self) -> Box> + '_> { + let general_python_deps = match self.python_depends.as_ref() { + Some(deps) => deps, + None => { + return Box::new(std::iter::empty()); + } + }; + + Box::new(general_python_deps.iter().flat_map(move |dep| { + match PYTHON_DEPENDENCIES.get_dependency(dep) { + Ok(deps) => deps.iter().map(|s| Ok(s.clone())).collect::>(), + Err(e) => vec![Err(e.into())], + } + })) + } + + pub fn backend_python_depends( + &self, + backend: Backend, + ) -> Box> + '_> { + let backend_python_deps = match backend { + Backend::Cuda => self + .cuda + .as_ref() + .and_then(|cuda| cuda.python_depends.as_ref()), + Backend::Xpu => self + .xpu + .as_ref() + .and_then(|xpu| xpu.python_depends.as_ref()), + _ => None, + }; + + let backend_python_deps = match backend_python_deps { + Some(deps) => deps, + None => { + return Box::new(std::iter::empty()); + } + }; + + Box::new(backend_python_deps.iter().flat_map(move |dep| { + match PYTHON_DEPENDENCIES.get_backend_dependency(backend, dep) { + Ok(deps) => deps.iter().map(|s| Ok(s.clone())).collect::>(), + Err(e) => vec![Err(e.into())], + } + })) + } } #[derive(Debug, Deserialize, Serialize)] @@ -59,29 +168,20 @@ impl General { pub struct CudaGeneral { pub minver: Option, pub maxver: Option, + pub python_depends: Option>, } #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case")] -pub struct Hub { - pub repo_id: Option, - pub branch: Option, +pub struct XpuGeneral { + pub python_depends: Option>, } -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case")] -pub enum PythonDependency { - Einops, - NvidiaCutlassDsl, -} - -impl Display for PythonDependency { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - PythonDependency::Einops => write!(f, "einops"), - PythonDependency::NvidiaCutlassDsl => write!(f, "nvidia-cutlass-dsl"), - } - } +pub struct Hub { + pub repo_id: Option, + pub branch: Option, } #[derive(Debug, Deserialize, Clone, Serialize)] @@ -215,7 +315,7 @@ impl Kernel { } } -#[derive(Clone, Copy, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize)] +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case")] pub enum Backend { Cpu, @@ -290,6 +390,7 @@ impl General { Some(CudaGeneral { minver: general.cuda_minver, maxver: general.cuda_maxver, + python_depends: None, }) } else { None @@ -300,9 +401,8 @@ impl General { backends, cuda, hub: general.hub.map(Into::into), - python_depends: general - .python_depends - .map(|deps| deps.into_iter().map(Into::into).collect()), + python_depends: None, + xpu: None, } } } @@ -316,15 +416,6 @@ impl From for Hub { } } -impl From for PythonDependency { - fn from(dep: v2::PythonDependency) -> Self { - match dep { - v2::PythonDependency::Einops => PythonDependency::Einops, - v2::PythonDependency::NvidiaCutlassDsl => PythonDependency::NvidiaCutlassDsl, - } - } -} - impl From for Torch { fn from(torch: v2::Torch) -> Self { Self { diff --git a/build2cmake/src/main.rs b/build2cmake/src/main.rs index b4223031..54c0619c 100644 --- a/build2cmake/src/main.rs +++ b/build2cmake/src/main.rs @@ -172,7 +172,7 @@ fn generate_torch( }; let file_set = if build.is_noarch() { - write_torch_ext_noarch(&env, &build, target_dir.clone(), ops_id)? + write_torch_ext_noarch(&env, backend, &build, target_dir.clone(), ops_id)? } else { match backend { Backend::Cpu => write_torch_ext_cpu(&env, &build, target_dir.clone(), ops_id)?, @@ -375,13 +375,11 @@ fn get_generated_files( ) -> Result> { let mut all_set = FileSet::new(); - if build.is_noarch() { - let set = write_torch_ext_noarch(env, build, target_dir.clone(), ops_id.clone())?; - - all_set.extend(set); - } else { - for backend in &build.general.backends { - let set = match backend { + for backend in &build.general.backends { + let set = if build.is_noarch() { + write_torch_ext_noarch(env, *backend, build, target_dir.clone(), ops_id.clone())? + } else { + match backend { Backend::Cpu => { write_torch_ext_cpu(env, build, target_dir.clone(), ops_id.clone())? } @@ -394,10 +392,9 @@ fn get_generated_files( Backend::Xpu => { write_torch_ext_xpu(env, build, target_dir.clone(), ops_id.clone())? } - }; - - all_set.extend(set); - } + } + }; + all_set.extend(set); } Ok(all_set.into_names()) diff --git a/build2cmake/src/python_dependencies.json b/build2cmake/src/python_dependencies.json new file mode 100644 index 00000000..25b5d9e3 --- /dev/null +++ b/build2cmake/src/python_dependencies.json @@ -0,0 +1,25 @@ +{ + "general": { + "einops": { + "nix": ["einops"], + "python": ["einops"] + } + }, + "backends": { + "cpu": {}, + "cuda": { + "nvidia-cutlass-dsl": { + "nix": ["nvidia-cutlass-dsl"], + "python": ["nvidia-cutlass-dsl"] + } + }, + "metal": {}, + "rocm": {}, + "xpu": { + "onednn": { + "nix": [], + "python": ["onednn-devel"] + } + } + } +} diff --git a/build2cmake/src/torch/common.rs b/build2cmake/src/torch/common.rs index ad5634d9..e3eea573 100644 --- a/build2cmake/src/torch/common.rs +++ b/build2cmake/src/torch/common.rs @@ -2,22 +2,23 @@ use eyre::{Context, Result}; use itertools::Itertools; use minijinja::{context, Environment}; -use crate::{config::General, FileSet}; +use crate::config::{Backend, General}; +use crate::FileSet; pub fn write_pyproject_toml( env: &Environment, + backend: Backend, general: &General, file_set: &mut FileSet, ) -> Result<()> { let writer = file_set.entry("pyproject.toml"); - let python_dependencies = general - .python_depends - .as_ref() - .unwrap_or(&vec![]) - .iter() - .map(|d| format!("\"{d}\"")) - .join(", "); + let python_dependencies = itertools::process_results( + general + .python_depends() + .chain(general.backend_python_depends(backend)), + |iter| iter.map(|d| format!("\"{d}\"")).join(", "), + )?; env.get_template("pyproject.toml") .wrap_err("Cannot get pyproject.toml template")? diff --git a/build2cmake/src/torch/cpu.rs b/build2cmake/src/torch/cpu.rs index 5b2829f4..56f5f0b2 100644 --- a/build2cmake/src/torch/cpu.rs +++ b/build2cmake/src/torch/cpu.rs @@ -6,7 +6,7 @@ use minijinja::{context, Environment}; use super::{common::write_pyproject_toml, kernel_ops_identifier}; use crate::{ - config::{Build, Kernel, Torch}, + config::{Backend, Build, Kernel, Torch}, fileset::FileSet, version::Version, }; @@ -48,7 +48,7 @@ pub fn write_torch_ext_cpu( write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?; - write_pyproject_toml(env, &build.general, &mut file_set)?; + write_pyproject_toml(env, Backend::Cpu, &build.general, &mut file_set)?; write_torch_registration_macros(&mut file_set)?; diff --git a/build2cmake/src/torch/cuda.rs b/build2cmake/src/torch/cuda.rs index e96432fe..11aedca1 100644 --- a/build2cmake/src/torch/cuda.rs +++ b/build2cmake/src/torch/cuda.rs @@ -61,7 +61,7 @@ pub fn write_torch_ext_cuda( write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?; - write_pyproject_toml(env, &build.general, &mut file_set)?; + write_pyproject_toml(env, backend, &build.general, &mut file_set)?; write_torch_registration_macros(&mut file_set)?; diff --git a/build2cmake/src/torch/metal.rs b/build2cmake/src/torch/metal.rs index b4317c68..1af2228b 100644 --- a/build2cmake/src/torch/metal.rs +++ b/build2cmake/src/torch/metal.rs @@ -6,7 +6,7 @@ use minijinja::{context, Environment}; use super::{common::write_pyproject_toml, kernel_ops_identifier}; use crate::{ - config::{Build, Kernel, Torch}, + config::{Backend, Build, Kernel, Torch}, fileset::FileSet, version::Version, }; @@ -50,7 +50,7 @@ pub fn write_torch_ext_metal( write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?; - write_pyproject_toml(env, &build.general, &mut file_set)?; + write_pyproject_toml(env, Backend::Metal, &build.general, &mut file_set)?; write_torch_registration_macros(&mut file_set)?; diff --git a/build2cmake/src/torch/noarch.rs b/build2cmake/src/torch/noarch.rs index 3fd441e1..902da62b 100644 --- a/build2cmake/src/torch/noarch.rs +++ b/build2cmake/src/torch/noarch.rs @@ -5,13 +5,14 @@ use itertools::Itertools; use minijinja::{context, Environment}; use crate::{ - config::{Build, General, Torch}, + config::{Backend, Build, General, Torch}, fileset::FileSet, torch::kernel_ops_identifier, }; pub fn write_torch_ext_noarch( env: &Environment, + backend: Backend, build: &Build, target_dir: PathBuf, ops_id: Option, @@ -21,7 +22,13 @@ pub fn write_torch_ext_noarch( let ops_name = kernel_ops_identifier(&target_dir, &build.general.python_name(), ops_id); write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?; - write_pyproject_toml(env, build.torch.as_ref(), &build.general, &mut file_set)?; + write_pyproject_toml( + env, + backend, + build.torch.as_ref(), + &build.general, + &mut file_set, + )?; Ok(file_set) } @@ -53,6 +60,7 @@ fn write_ops_py( fn write_pyproject_toml( env: &Environment, + backend: Backend, torch: Option<&Torch>, general: &General, file_set: &mut FileSet, @@ -61,13 +69,12 @@ fn write_pyproject_toml( let name = &general.name; let data_globs = torch.and_then(|torch| torch.data_globs().map(|globs| globs.join(", "))); - let python_dependencies = general - .python_depends - .as_ref() - .unwrap_or(&vec![]) - .iter() - .map(|d| format!("\"{d}\"")) - .join(", "); + let python_dependencies = itertools::process_results( + general + .python_depends() + .chain(general.backend_python_depends(backend)), + |iter| iter.map(|d| format!("\"{d}\"")).join(", "), + )?; env.get_template("noarch/pyproject.toml") .wrap_err("Cannot get noarch pyproject.toml template")? diff --git a/build2cmake/src/torch/xpu.rs b/build2cmake/src/torch/xpu.rs index 3ef285b1..73a04c50 100644 --- a/build2cmake/src/torch/xpu.rs +++ b/build2cmake/src/torch/xpu.rs @@ -8,7 +8,7 @@ use minijinja::{context, Environment}; use super::common::write_pyproject_toml; use super::kernel_ops_identifier; -use crate::config::{Build, Dependency, Kernel, Torch}; +use crate::config::{Backend, Build, Dependency, Kernel, Torch}; use crate::version::Version; use crate::FileSet; @@ -49,7 +49,7 @@ pub fn write_torch_ext_xpu( write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?; - write_pyproject_toml(env, &build.general, &mut file_set)?; + write_pyproject_toml(env, Backend::Xpu, &build.general, &mut file_set)?; write_torch_registration_macros(&mut file_set)?; diff --git a/lib/build.nix b/lib/build.nix index 8d8a4dbe..32bdabb3 100644 --- a/lib/build.nix +++ b/lib/build.nix @@ -139,6 +139,8 @@ rec { _: kernel: builtins.length (kernel.cuda-capabilities or supportedCudaCapabilities) ) buildToml.kernel ); + pythonDeps = (buildToml.general.python-depends or [ ]); + backendPythonDeps = lib.attrByPath [ buildConfig.backend "python-depends" ] [ ] buildToml.general; in if !kernelBackends'.${buildConfig.backend} then # No compiled kernel files? Treat it as a noarch package. @@ -149,9 +151,10 @@ rec { src rev doGetKernelCheck + pythonDeps + backendPythonDeps ; kernelName = buildToml.general.name; - pythonDeps = buildToml.general.python-depends or [ ]; } else extension.mkExtension { @@ -163,10 +166,11 @@ rec { src stripRPath rev + pythonDeps + backendPythonDeps ; kernelName = buildToml.general.name; - pythonDeps = buildToml.general.python-depends or [ ]; doAbiCheck = true; }; diff --git a/lib/deps.nix b/lib/deps.nix index d99206da..d52cbcad 100644 --- a/lib/deps.nix +++ b/lib/deps.nix @@ -32,14 +32,38 @@ let pkgs.metal-cpp.dev ]; }; - pythonDeps = with pkgs.python3.pkgs; { - "einops" = [ einops ]; - "nvidia-cutlass-dsl" = [ nvidia-cutlass-dsl ]; - }; + + pythonDeps = + let + depsJson = builtins.fromJSON (builtins.readFile ../build2cmake/src/python_dependencies.json); + # Map the Nix package names to actual Nix packages. + updatePackage = _name: dep: dep // { nix = map (pkg: pkgs.python3.pkgs.${pkg}) dep.nix; }; + updateBackend = _backend: backendDeps: lib.mapAttrs updatePackage backendDeps; + in + depsJson + // { + general = lib.mapAttrs updatePackage depsJson.general; + backends = lib.mapAttrs updateBackend depsJson.backends; + }; + getCppDep = dep: cppDeps.${dep} or (throw "Unknown dependency: ${dep}"); - getPythonDep = dep: pythonDeps.${dep} or (throw "Unknown Python dependency: ${dep}"); + getPythonDep = + dep: lib.attrByPath [ "general" dep "nix" ] (throw "Unknown Python dependency: ${dep}") pythonDeps; + getBackendPythonDep = + backend: dep: + let + backendDeps = lib.attrByPath [ + "backends" + backend + ] (throw "Unknown backend: ${backend}") pythonDeps; + in + lib.attrByPath [ + dep + "nix" + ] (throw "Unknown Python dependency for backend `${backend}`: ${dep}") backendDeps; in { resolveCppDeps = deps: lib.flatten (map getCppDep deps); resolvePythonDeps = deps: lib.flatten (map getPythonDep deps); + resolveBackendPythonDeps = backend: deps: lib.flatten (map (getBackendPythonDep backend) deps); } diff --git a/lib/torch-extension/arch.nix b/lib/torch-extension/arch.nix index 301353fb..3d4c58af 100644 --- a/lib/torch-extension/arch.nix +++ b/lib/torch-extension/arch.nix @@ -56,6 +56,8 @@ # the output. pythonDeps, + backendPythonDeps, + # Wheter to strip rpath for non-nix use. stripRPath ? false, @@ -72,9 +74,12 @@ assert (buildConfig ? xpuVersion) -> xpuSupport; assert (buildConfig.metal or false) -> stdenv.hostPlatform.isDarwin; let - inherit (import ../deps.nix { inherit lib pkgs torch; }) resolvePythonDeps; + inherit (import ../deps.nix { inherit lib pkgs torch; }) resolvePythonDeps resolveBackendPythonDeps; - dependencies = resolvePythonDeps pythonDeps ++ [ torch ]; + dependencies = + resolvePythonDeps pythonDeps + ++ resolveBackendPythonDeps buildConfig.backend backendPythonDeps + ++ [ torch ]; moduleName = builtins.replaceStrings [ "-" ] [ "_" ] kernelName; diff --git a/lib/torch-extension/no-arch.nix b/lib/torch-extension/no-arch.nix index 56e291f7..36c5d87c 100644 --- a/lib/torch-extension/no-arch.nix +++ b/lib/torch-extension/no-arch.nix @@ -34,6 +34,8 @@ # list of derivations, but we also need to write the dependencies to # the output. pythonDeps, + + backendPythonDeps, }: # Extra validation - the environment should correspind to the build config. @@ -43,8 +45,11 @@ assert (buildConfig ? xpuVersion) -> xpuSupport; assert (buildConfig.metal or false) -> stdenv.hostPlatform.isDarwin; let - inherit (import ../deps.nix { inherit lib pkgs torch; }) resolvePythonDeps; - dependencies = resolvePythonDeps pythonDeps ++ [ torch ]; + inherit (import ../deps.nix { inherit lib pkgs torch; }) resolvePythonDeps resolveBackendPythonDeps; + dependencies = + resolvePythonDeps pythonDeps + ++ resolveBackendPythonDeps buildConfig.backend backendPythonDeps + ++ [ torch ]; moduleName = builtins.replaceStrings [ "-" ] [ "_" ] kernelName; metadata = builtins.toJSON { python-depends = pythonDeps; @@ -68,7 +73,9 @@ stdenv.mkDerivation (prevAttrs: { remove-bytecode-hook ] ++ lib.optionals doGetKernelCheck [ - (get-kernel-check.override { python3 = python3.withPackages (ps: dependencies); }) + (get-kernel-check.override { + python3 = python3.withPackages (_: dependencies); + }) ]; dontBuild = true; diff --git a/pkgs/build2cmake/default.nix b/pkgs/build2cmake/default.nix index 4a7fa5bf..846a3882 100644 --- a/pkgs/build2cmake/default.nix +++ b/pkgs/build2cmake/default.nix @@ -22,6 +22,7 @@ rustPlatform.buildRustPackage { || file.name == "pyproject.toml" || file.name == "pyproject_universal.toml" || file.name == "cuda_supported_archs.json" + || file.name == "python_dependencies.json" || (builtins.any file.hasExt [ "cmake" "h"