Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions cmd/nvidia-cdi-hook/cudacompat/cuda-elf-header.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"encoding/json"
"fmt"
"os"
"slices"
"strings"

"golang.org/x/mod/semver"
Expand Down Expand Up @@ -118,14 +119,31 @@ func getCUDAFwdCompatibilitySection(lib *elf.File) *elf.Section {

// UseCompat checks whether the CUDA compat libraries with the specified elf
// header should be used given the specified host versions.
// This is done by comparing the host CUDA version with the CUDA version
// specified in the ELF header.
func (h *compatElfHeader) UseCompat(hostCUDAVersion string) bool {
// If the host driver version is specified, we check if the driver version
// is supported in the ELF header. If no host driver version is provided, we
// fall back to checking the CUDA version specified in the ELF header.
func (h *compatElfHeader) UseCompat(compatDriverVersion string, hostDriverVersion string, hostCUDAVersion string) bool {
	// A nil header carries no compatibility information to evaluate.
	if h == nil {
		return false
	}

	// The driver-based check needs both driver versions. When either one is
	// missing, fall back to the CUDA version comparison, which in turn needs
	// a host CUDA version.
	if compatDriverVersion == "" || hostDriverVersion == "" {
		if hostCUDAVersion == "" {
			return false
		}
		return h.CUDAVersion.UseCompat(hostCUDAVersion)
	}

	hostDriverMajor, err := extractMajorVersion(hostDriverVersion)
	if err != nil {
		return false
	}

	// The major version of the host driver must be listed in the header.
	if !slices.Contains(h.Driver, hostDriverMajor) {
		return false
	}

	// Use the compat libraries only when their driver version is strictly
	// greater than the host driver version.
	return compareVersions(compatDriverVersion, hostDriverVersion) > 0
}

type cudaVersion string
Expand All @@ -137,9 +155,28 @@ func (containerVersion cudaVersion) UseCompat(hostVersion string) bool {
return false
}

return semver.Compare(normalizeVersion(containerVersion), normalizeVersion(hostVersion)) > 0
return compareVersions(containerVersion, hostVersion) > 0
}

// compareVersions normalizes both version strings with normalizeVersion and
// returns the result of semver.Compare on the normalized forms.
func compareVersions[T string | cudaVersion, O string | cudaVersion](this T, other O) int {
	lhs, rhs := normalizeVersion(this), normalizeVersion(other)
	return semver.Compare(lhs, rhs)
}

// normalizeVersion converts the given version into a valid semantic version.
// This function will always return a string in the format of vMAJOR.MINOR.PATCH.
// It accounts for version strings that have leading zeros, which is common
// in NVIDIA driver version strings. For example, 570.211.01 will be converted to
// v570.211.1.
Comment on lines +167 to +169
Copy link
Copy Markdown
Member

@elezar elezar Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we comment on why dropping leading zeros is important here? My assumption is that we need to ensure that 570.211.1 and 570.211.01 compare as equal, meaning that we may want to add a compareVersions function which calls normalizeVersion and has the relevant tests:

func compareVersions[T string | cudaVersion, O string | cudaVersion](this T, other O) int {
    return semver.Compare(normalizeVersion(this), normalizeVersion(other))
}

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Version strings with leading zeros are not valid semantic versions. And from https://pkg.go.dev/golang.org/x/mod/semver#Compare:

An invalid semantic version string is considered less than a valid one. All invalid semantic version strings compare equal to each other.

As a result, calling semver.Compare("575.57.08", "575.10.10") would incorrectly return -1 because the first argument was not a valid semantic version. I wanted to continue leveraging semver.Compare() so I decided to expand the logic in our normalizeVersion() method.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added a compareVersions() method as you have suggested and added some unit tests.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One question / comment from my side. Does it make sense to add this fix for incorrect version comparisons as a separate commit / PR independently of the program flow changes?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. I have separated the commits in this PR.

// normalizeVersion rewrites v as a string of the form vMAJOR.MINOR.PATCH.
// A leading "v" is optional in the input, components that are absent become
// "0", and redundant leading zeros are removed from each component so that,
// for example, 570.211.01 becomes v570.211.1.
func normalizeVersion[T string | cudaVersion](v T) string {
	// Components that the input does not provide keep this default.
	majorMinorPatch := []string{"0", "0", "0"}

	// SplitN caps the result at three parts, so anything after the second dot
	// (including a pre-release or build suffix) stays attached to the patch.
	versionParts := strings.SplitN(strings.TrimPrefix(string(v), "v"), ".", 3)
	for i, versionPart := range versionParts {
		majorMinorPatch[i] = trimLeadingZeros(versionPart)
	}

	return "v" + strings.Join(majorMinorPatch, ".")
}

// trimLeadingZeros removes leading zeros that are followed by another digit,
// and returns "0" for an empty part. A zero that is followed by a non-digit
// is kept, so "0-4" is left unchanged while "08" becomes "8".
func trimLeadingZeros(part string) string {
	if part == "" {
		return "0"
	}

	start := 0
	for start+1 < len(part) && part[start] == '0' && part[start+1] >= '0' && part[start+1] <= '9' {
		start++
	}
	return part[start:]
}
211 changes: 207 additions & 4 deletions cmd/nvidia-cdi-hook/cudacompat/cuda-elf-header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ func TestGetCUDACompatElfHeader(t *testing.T) {
expected *compatElfHeader
}{
{
description: "wip",
filename: "libcuda.so.575.57.08",
description: "575.57.08",
filename: "575.57.08/libcuda.so.575.57.08",
expected: &compatElfHeader{
Format: 1,
CUDAVersion: "12.9",
Expand All @@ -48,8 +48,8 @@ func TestGetCUDACompatElfHeader(t *testing.T) {
},
},
{
description: "wip",
filename: "libcuda.so.590.44.01",
description: "590.44.01",
filename: "590.44.01/libcuda.so.590.44.01",
expected: &compatElfHeader{
Format: 1,
CUDAVersion: "13.1",
Expand All @@ -70,3 +70,206 @@ func TestGetCUDACompatElfHeader(t *testing.T) {
})
}
}

// TestUseCompat runs compatElfHeader.UseCompat against a fixed header with
// different combinations of compat driver, host driver and host CUDA versions.
func TestUseCompat(t *testing.T) {
	// Every scenario uses the same header contents; each one gets its own
	// freshly constructed value.
	newHeader := func() *compatElfHeader {
		return &compatElfHeader{
			Format:      1,
			CUDAVersion: "12.9",
			Driver:      []int{535, 550, 560, 565, 570, 575},
			Device:      []int{1, 2, 7, 8, 9, 10, 11, 12, 13, 14},
		}
	}

	type scenario struct {
		name         string
		header       *compatElfHeader
		compatDriver string
		hostDriver   string
		hostCUDA     string
		want         bool
	}

	scenarios := []scenario{
		// Only the host CUDA version is known.
		{
			name:     "container cuda version greater than host cuda version",
			header:   newHeader(),
			hostCUDA: "12.8",
			want:     true,
		},
		{
			name:     "container cuda version same as host cuda version",
			header:   newHeader(),
			hostCUDA: "12.9",
			want:     false,
		},
		{
			name:     "container cuda version less than host cuda version",
			header:   newHeader(),
			hostCUDA: "12.10",
			want:     false,
		},
		// Both driver versions are known.
		{
			name:         "host driver branch not supported in compat elf header",
			header:       newHeader(),
			compatDriver: "575.57.08",
			hostDriver:   "590.44.01",
			want:         false,
		},
		{
			name:         "host driver branch supported in compat elf header, host driver branch < compat driver branch",
			header:       newHeader(),
			compatDriver: "575.57.08",
			hostDriver:   "570.211.01",
			want:         true,
		},
		{
			name:         "host driver branch same as compat driver branch, compat driver > host driver",
			header:       newHeader(),
			compatDriver: "575.57.08",
			hostDriver:   "575.10.10",
			want:         true,
		},
		{
			name:         "host driver branch same as compat driver branch, compat driver = host driver",
			header:       newHeader(),
			compatDriver: "575.57.08",
			hostDriver:   "575.57.08",
			want:         false,
		},
		{
			name:         "host driver branch same as compat driver branch, compat driver < host driver",
			header:       newHeader(),
			compatDriver: "575.57.08",
			hostDriver:   "575.99.99",
			want:         false,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			got := sc.header.UseCompat(sc.compatDriver, sc.hostDriver, sc.hostCUDA)
			require.EqualValues(t, sc.want, got)
		})
	}
}

// TestCompareVersions checks the ordering reported by compareVersions,
// including zero-padded components, missing components and multi-digit
// components that must be compared numerically.
func TestCompareVersions(t *testing.T) {
	testCases := []struct {
		description string
		a           string
		b           string
		expected    int
	}{
		{
			description: "empty",
			expected:    0,
		},
		{
			description: "less than",
			a:           "1.2.3",
			b:           "2.4.5",
			expected:    -1,
		},
		{
			description: "equal",
			a:           "1.1.1",
			b:           "1.1.1",
			expected:    0,
		},
		{
			description: "equal with leading zeros in version string",
			a:           "1.1.1",
			b:           "1.01.1",
			expected:    0,
		},
		{
			description: "greater than",
			a:           "2.4.5",
			b:           "2.4.4",
			expected:    1,
		},
		{
			// A zero-padded patch must not make the version compare as invalid.
			description: "greater than with zero-padded patch",
			a:           "575.57.08",
			b:           "575.10.10",
			expected:    1,
		},
		{
			description: "missing components are treated as zero",
			a:           "12.9",
			b:           "12.9.0",
			expected:    0,
		},
		{
			// 9 < 10 numerically even though "9" > "10" lexically.
			description: "minor versions compare numerically",
			a:           "12.9",
			b:           "12.10",
			expected:    -1,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.description, func(t *testing.T) {
			require.EqualValues(t, tc.expected, compareVersions(tc.a, tc.b))
		})
	}
}

func TestNormalizeVersion(t *testing.T) {
testCases := []struct {
description string
input string
expected string
}{
{
description: "empty",
input: "",
expected: "v0.0.0",
},
{
description: "major is 0",
input: "v0.1.2",
expected: "v0.1.2",
},
{
description: "major only",
input: "1",
expected: "v1.0.0",
},
{
description: "major and minor only",
input: "1.1",
expected: "v1.1.0",
},
{
description: "zero-padded version",
input: "01.02.03",
expected: "v1.2.3",
},
{
description: "valid semantic version",
input: "v1.2.3-4+567",
expected: "v1.2.3-4+567",
},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
output := normalizeVersion(tc.input)
require.EqualValues(t, tc.expected, output)
})
}
}
21 changes: 13 additions & 8 deletions cmd/nvidia-cdi-hook/cudacompat/cudacompat.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,19 @@ func (m command) getContainerForwardCompatDir(containerRoot containerRoot, o *op
}

func (m command) useCompatLibraries(libcudaCompatPath string, hostDriverVersion string, hostCUDAVersion string) (bool, error) {
// First check the ELF header of the libcuda.so included in the compat directory.
// If this is present, we use the ELF header to determine whether the CUDA compat
// libraries in the container should be used over the host driver libraries.
compatDriverVersion := strings.TrimPrefix(filepath.Base(libcudaCompatPath), "libcuda.so.")
cudaCompatHeader, _ := GetCUDACompatElfHeader(libcudaCompatPath)
if cudaCompatHeader != nil {
return cudaCompatHeader.UseCompat(compatDriverVersion, hostDriverVersion, hostCUDAVersion), nil
}

// If the host CUDA version is specified, we need to inspect the ELF header
// of the compat libraries in the container to determine whether these
// should be used.
// should be used. Return early if we cannot read the ELF header.
if hostCUDAVersion != "" {
cudaCompatHeader, _ := GetCUDACompatElfHeader(libcudaCompatPath)
if cudaCompatHeader != nil {
return cudaCompatHeader.UseCompat(hostCUDAVersion), nil
}
// If we were unable to read the CUDA header, we do not use the compat
// libraries.
return false, nil
}

Expand All @@ -196,12 +199,14 @@ func (m command) useCompatLibraries(libcudaCompatPath string, hostDriverVersion
return false, nil
}

// If we reach this point, it means we could not read the ELF header but
// the host driver version is specified. We fall back to comparing the major
// versions of the host driver and compat driver.
driverMajor, err := extractMajorVersion(hostDriverVersion)
if err != nil {
return false, fmt.Errorf("failed to extract major version from %q: %v", hostDriverVersion, err)
}

compatDriverVersion := strings.TrimPrefix(filepath.Base(libcudaCompatPath), "libcuda.so.")
compatMajor, err := extractMajorVersion(compatDriverVersion)
if err != nil {
return false, fmt.Errorf("failed to extract major version from %q: %v", compatDriverVersion, err)
Expand Down
Loading
Loading