Skip to content
This repository was archived by the owner on Jan 27, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/relu-specific-torch/flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
{
torchVersion = "2.9";
cudaVersion = "12.8";
cxx11Abi = true;
systems = [
"x86_64-linux"
"aarch64-linux"
Expand Down
185 changes: 66 additions & 119 deletions lib/build-sets.nix
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,22 @@ let
in
builtins.map (buildConfig: buildConfig // { backend = backend buildConfig; }) systemBuildConfigs;

cudaVersions =
let
withCuda = builtins.filter (torchVersion: torchVersion ? cudaVersion) torchVersions;
in
lib.unique (builtins.map (torchVersion: torchVersion.cudaVersion) withCuda);

rocmVersions =
let
withRocm = builtins.filter (torchVersion: torchVersion ? rocmVersion) torchVersions;
in
lib.unique (builtins.map (torchVersion: torchVersion.rocmVersion) withRocm);

xpuVersions =
let
withXpu = builtins.filter (torchVersion: torchVersion ? xpuVersion) torchVersions;
in
lib.unique (builtins.map (torchVersion: torchVersion.xpuVersion) withXpu);

flattenVersion = version: lib.replaceStrings [ "." ] [ "_" ] (lib.versions.pad 2 version);

overlayForTorchVersion = torchVersion: sourceBuild: self: super: {
pythonPackagesExtensions = super.pythonPackagesExtensions ++ [
(
python-self: python-super: with python-self; {
torch =
if sourceBuild then
python-self."torch_${flattenVersion torchVersion}"
else
python-self."torch-bin_${flattenVersion torchVersion}";
}
)
];
};

# An overlay that overides CUDA to the given version.
overlayForCudaVersion = cudaVersion: self: super: {
cudaPackages = super."cudaPackages_${flattenVersion cudaVersion}";
Expand All @@ -57,6 +53,38 @@ let
overlayForXpuVersion = xpuVersion: self: super: {
xpuPackages = super."xpuPackages_${flattenVersion xpuVersion}";
};

backendConfig = {
cpu = {
allowUnfree = true;
};

cuda = {
allowUnfree = true;
cudaSupport = true;
};

metal = {
allowUnfree = true;
metalSupport = true;
};

rocm = {
allowUnfree = true;
rocmSupport = true;
};

xpu = {
allowUnfree = true;
xpuSupport = true;
};
};

xpuConfig = {
allowUnfree = true;
xpuSupport = true;
};

# Construct the nixpkgs package set for the given versions.
mkBuildSet =
buildConfig@{
Expand All @@ -67,34 +95,38 @@ let
rocmVersion ? null,
xpuVersion ? null,
torchVersion,
cxx11Abi,
system,
bundleBuild ? false,
sourceBuild ? false,
}:
let
pkgs =
backendOverlay =
if buildConfig.backend == "cpu" then
pkgsForCpu
[ ]
else if buildConfig.backend == "cuda" then
pkgsByCudaVer.${cudaVersion}
[ (overlayForCudaVersion buildConfig.cudaVersion) ]
else if buildConfig.backend == "rocm" then
pkgsByRocmVer.${rocmVersion}
[ (overlayForRocmVersion buildConfig.rocmVersion) ]
else if buildConfig.backend == "metal" then
pkgsForMetal
[ ]
else if buildConfig.backend == "xpu" then
pkgsByXpuVer.${xpuVersion}
[ (overlayForXpuVersion buildConfig.xpuVersion) ]
else
throw "No compute framework set in Torch version";
torch =
if sourceBuild then
pkgs.python3.pkgs."torch_${flattenVersion torchVersion}".override {
inherit cxx11Abi;
}
else
pkgs.python3.pkgs."torch-bin_${flattenVersion torchVersion}".override {
inherit cxx11Abi;
};
config =
backendConfig.${buildConfig.backend} or (throw "No backend config for ${buildConfig.backend}");

pkgs = import nixpkgs {
inherit config system;
overlays = [
overlay
]
++ backendOverlay
++ [ (overlayForTorchVersion torchVersion sourceBuild) ];
};

torch = pkgs.python3.pkgs.torch;

extension = pkgs.callPackage ./torch-extension { inherit torch; };
in
{
Expand All @@ -106,90 +138,5 @@ let
bundleBuild
;
};
pkgsForXpuVersions =
xpuVersions:
builtins.listToAttrs (
map (xpuVersion: {
name = xpuVersion;
value = import nixpkgs {
inherit system;
config = {
allowUnfree = true;
xpuSupport = true;
};
overlays = [
overlay
(overlayForXpuVersion xpuVersion)
];
};
}) xpuVersions
);
pkgsByXpuVer = pkgsForXpuVersions xpuVersions;

pkgsForMetal = import nixpkgs {
inherit system;
config = {
allowUnfree = true;
metalSupport = true;
};
overlays = [
overlay
];
};

pkgsForCpu = import nixpkgs {
inherit system;
config = {
allowUnfree = true;
};
overlays = [
overlay
];
};

# Instantiate nixpkgs for the given CUDA versions. Returns
# an attribute set like `{ "12.4" = <nixpkgs with 12.4>; ... }`.
pkgsForCudaVersions =
cudaVersions:
builtins.listToAttrs (
map (cudaVersion: {
name = cudaVersion;
value = import nixpkgs {
inherit system;
config = {
allowUnfree = true;
cudaSupport = true;
};
overlays = [
overlay
(overlayForCudaVersion cudaVersion)
];
};
}) cudaVersions
);

pkgsByCudaVer = pkgsForCudaVersions cudaVersions;

pkgsForRocmVersions =
rocmVersions:
builtins.listToAttrs (
map (rocmVersion: {
name = rocmVersion;
value = import nixpkgs {
inherit system;
config = {
allowUnfree = true;
rocmSupport = true;
};
overlays = [
overlay
(overlayForRocmVersion rocmVersion)
];
};
}) rocmVersions
);

pkgsByRocmVer = pkgsForRocmVersions rocmVersions;

in
map mkBuildSet (buildConfigs system)
4 changes: 2 additions & 2 deletions lib/build-variants.nix
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ rec {

buildName =
let
inherit (import ./version-utils.nix { inherit lib; }) abiString flattenVersion;
inherit (import ./version-utils.nix { inherit lib; }) flattenVersion;
computeString =
version:
if backend version == "cpu" then
Expand All @@ -29,7 +29,7 @@ rec {
if version.system == "aarch64-darwin" then
"torch${flattenVersion version.torchVersion}-${computeString version}-${version.system}"
else
"torch${flattenVersion version.torchVersion}-${abiString version.cxx11Abi}-${computeString version}-${version.system}";
"torch${flattenVersion version.torchVersion}-cxx11-${computeString version}-${version.system}";

# Build variants included in bundle builds.
buildVariants =
Expand Down
1 change: 0 additions & 1 deletion lib/build.nix
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
}:

let
abi = torch: if torch.passthru.cxx11Abi then "cxx11" else "cxx98";
supportedCudaCapabilities = builtins.fromJSON (
builtins.readFile ../build2cmake/src/cuda_supported_archs.json
);
Expand Down
1 change: 0 additions & 1 deletion lib/version-utils.nix
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,4 @@ in
{
flattenVersion =
version: lib.replaceStrings [ "." ] [ "" ] (versions.majorMinor (versions.pad 2 version));
abiString = cxx11Abi: if cxx11Abi then "cxx11" else "cxx98";
}
4 changes: 0 additions & 4 deletions pkgs/python-modules/torch/binary/generic.nix
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@
url,
hash,
version,
# Remove, needed for compat.
cxx11Abi ? true,

effectiveStdenv ? if cudaSupport then cudaPackages.backendStdenv else stdenv,
}:
Expand Down Expand Up @@ -322,7 +320,6 @@ buildPythonPackage {
inherit
cudaSupport
cudaPackages
cxx11Abi
rocmSupport
rocmPackages
xpuSupport
Expand All @@ -333,7 +330,6 @@ buildPythonPackage {
rocmArchs = if rocmSupport then supportedTorchRocmArchs else [ ];
}
// (callPackage ../variant.nix {
inherit cxx11Abi;
torchVersion = version;
});

Expand Down
16 changes: 0 additions & 16 deletions pkgs/python-modules/torch/binary/torch-versions.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,98 +2,82 @@
{
"torchVersion": "2.8.0",
"cudaVersion": "12.6",
"cxx11Abi": true,
"systems": ["x86_64-linux"]
},
{
"torchVersion": "2.8.0",
"cudaVersion": "12.8",
"cxx11Abi": true,
"systems": ["x86_64-linux"]
},
{
"torchVersion": "2.8.0",
"cudaVersion": "12.9",
"cxx11Abi": true,
"systems": ["x86_64-linux", "aarch64-linux"]
},
{
"torchVersion": "2.8.0",
"rocmVersion": "6.3",
"cxx11Abi": true,
"systems": ["x86_64-linux"]
},
{
"torchVersion": "2.8.0",
"rocmVersion": "6.4",
"cxx11Abi": true,
"systems": ["x86_64-linux"]
},
{
"torchVersion": "2.8.0",
"cxx11Abi": true,
"metal": true,
"systems": ["aarch64-darwin"]
},
{
"torchVersion": "2.8.0",
"cxx11Abi": true,
"cpu": true,
"systems": ["aarch64-linux", "x86_64-linux"]
},
{
"torchVersion": "2.8.0",
"xpuVersion": "2025.1.3",
"cxx11Abi": true,
"systems": ["x86_64-linux"]
},

{
"torchVersion": "2.9.0",
"cudaVersion": "12.6",
"cxx11Abi": true,
"systems": ["x86_64-linux", "aarch64-linux"]
},
{
"torchVersion": "2.9.0",
"cudaVersion": "12.8",
"cxx11Abi": true,
"systems": ["x86_64-linux", "aarch64-linux"]
},
{
"torchVersion": "2.9.0",
"cudaVersion": "13.0",
"cxx11Abi": true,
"systems": ["x86_64-linux", "aarch64-linux"]
},
{
"torchVersion": "2.9.0",
"rocmVersion": "6.3",
"cxx11Abi": true,
"systems": ["x86_64-linux"]
},
{
"torchVersion": "2.9.0",
"rocmVersion": "6.4",
"cxx11Abi": true,
"systems": ["x86_64-linux"]
},
{
"torchVersion": "2.9.0",
"cxx11Abi": true,
"metal": true,
"systems": ["aarch64-darwin"]
},
{
"torchVersion": "2.9.0",
"cxx11Abi": true,
"cpu": true,
"systems": ["aarch64-linux", "x86_64-linux"]
},
{
"torchVersion": "2.9.0",
"xpuVersion": "2025.2.1",
"cxx11Abi": true,
"systems": ["x86_64-linux"]
}
]
Loading
Loading