diff --git a/lib/gen-flake-outputs.nix b/lib/gen-flake-outputs.nix index 449b13c8..4144a311 100644 --- a/lib/gen-flake-outputs.nix +++ b/lib/gen-flake-outputs.nix @@ -40,6 +40,21 @@ let applicableBuildSets = build.applicableBuildSets { inherit path buildSets; }; + buildConfigBackend = + buildConfig: + if buildConfig.cpu or false then + "cpu" + else if buildConfig ? cudaVersion then + "cuda" + else if buildConfig ? rocmVersion then + "rocm" + else if buildConfig ? xpuVersion then + "xpu" + else if buildConfig.metal or false then + "metal" + else + throw "Cannot determine framework for build set"; + # For picking a default shell, etc. we want to use the following logic: # # - Prefer bundle builds over non-bundle builds. @@ -60,19 +75,7 @@ let buildConfig // { bundleBuild = buildConfig.bundleBuild or false; - framework = - if buildConfig.cpu or false then - "cpu" - else if buildConfig ? cudaVersion then - "cuda" - else if buildConfig ? rocmVersion then - "rocm" - else if buildConfig ? xpuVersion then - "xpu" - else if buildConfig.metal or false then - "metal" - else - throw "Cannot determine framework for build set"; + framework = buildConfigBackend buildConfig; frameworkOrder = if buildConfig ? cudaVersion then 0 else 1; frameworkVersion = buildConfig.cudaVersion or buildConfig.rocmVersion or buildConfig.xpuVersion or "0.0"; @@ -138,6 +141,24 @@ in { inherit bundle; + # Bundles by backend. + backendBundle = + let + backends = lib.unique (map (set: buildConfigBackend set.buildConfig) applicableBuildSets); + in + builtins.listToAttrs ( + builtins.map (backend: { + name = backend; + value = build.mkTorchExtensionBundle { + inherit path doGetKernelCheck; + buildSets = builtins.filter ( + set: buildConfigBackend set.buildConfig == backend + ) applicableBuildSets; + rev = revUnderscored; + }; + }) backends + ); + default = bundle; build-and-copy = writeScriptBin "build-and-copy" ''