diff --git a/build2cmake/src/torch/cuda.rs b/build2cmake/src/torch/cuda.rs index e96432fe..5b821b9b 100644 --- a/build2cmake/src/torch/cuda.rs +++ b/build2cmake/src/torch/cuda.rs @@ -173,7 +173,7 @@ fn write_cmake( cmake_writer, )?; - render_deps(env, build, cmake_writer)?; + render_deps(env, backend, build, cmake_writer)?; render_binding(env, torch, name, cmake_writer)?; @@ -213,10 +213,19 @@ pub fn render_binding( Ok(()) } -fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Result<()> { +fn render_deps( + env: &Environment, + backend: Backend, + build: &Build, + write: &mut impl Write, +) -> Result<()> { let mut deps = HashSet::new(); - for kernel in build.kernels.values() { + for kernel in build + .kernels + .values() + .filter(|kernel| kernel.backend() == backend) + { deps.extend(kernel.depends()); } diff --git a/build2cmake/src/torch/xpu.rs b/build2cmake/src/torch/xpu.rs index 3ef285b1..9790e4d1 100644 --- a/build2cmake/src/torch/xpu.rs +++ b/build2cmake/src/torch/xpu.rs @@ -8,7 +8,7 @@ use minijinja::{context, Environment}; use super::common::write_pyproject_toml; use super::kernel_ops_identifier; -use crate::config::{Build, Dependency, Kernel, Torch}; +use crate::config::{Backend, Build, Dependency, Kernel, Torch}; use crate::version::Version; use crate::FileSet; @@ -144,7 +144,7 @@ fn write_cmake( cmake_writer, )?; - render_deps(env, build, cmake_writer)?; + render_deps(env, Backend::Xpu, build, cmake_writer)?; render_binding(env, torch, name, cmake_writer)?; @@ -184,10 +184,19 @@ fn render_binding( Ok(()) } -fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Result<()> { +fn render_deps( + env: &Environment, + backend: Backend, + build: &Build, + write: &mut impl Write, +) -> Result<()> { let mut deps = HashSet::new(); - for kernel in build.kernels.values() { + for kernel in build + .kernels + .values() + .filter(|kernel| kernel.backend() == backend) + { deps.extend(kernel.depends()); }