Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions cmd/nvidia-cdi-hook/check-requirements/check-requirements.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/**
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package checkrequirements

import (
"context"
"encoding/json"
"fmt"
"os"

"github.com/opencontainers/runtime-spec/specs-go"
"github.com/urfave/cli/v3"

"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
)

type command struct {
logger logger.Interface
}

type options struct {
containerSpec string
driverRoot string
}

// NewCommand constructs a check-requirements command with the specified logger.
func NewCommand(logger logger.Interface) *cli.Command {
c := command{
logger: logger,
}
return c.build()
}

func (m command) build() *cli.Command {
cfg := options{}

return &cli.Command{
Name: "check-requirements",
Usage: "Check NVIDIA_REQUIRE_* constraints from the container image",
Action: func(_ context.Context, _ *cli.Command) error {
return m.run(&cfg)
},
Flags: []cli.Flag{
&cli.StringFlag{
Name: "driver-root",
Usage: "Specify the NVIDIA GPU driver root to use when detecting host properties",
Destination: &cfg.driverRoot,
},
&cli.StringFlag{
Name: "container-spec",
Hidden: true,
Category: "testing-only",
Usage: "Specify the path to the OCI container state. If empty or '-' the state will be read from STDIN",
Destination: &cfg.containerSpec,
},
},
}
}

func (m command) run(cfg *options) error {
cudaImage, err := loadCUDAImageFromState(cfg.containerSpec, m.logger)
if err != nil {
return fmt.Errorf("failed to load CUDA image from container state: %w", err)
}

driver := root.New(
root.WithLogger(m.logger),
root.WithDriverRoot(cfg.driverRoot),
)
if err := requirements.CheckImage(m.logger, cudaImage, driver); err != nil {
return fmt.Errorf("requirements not met: %w", err)
}
return nil
}

func loadCUDAImageFromState(containerStatePath string, logger logger.Interface) (*image.CUDA, error) {
state, err := oci.LoadContainerState(containerStatePath)
if err != nil {
return nil, fmt.Errorf("failed to load container state: %w", err)
}

specFilePath := oci.GetSpecFilePath(state.Bundle)
specFile, err := os.Open(specFilePath)
if err != nil {
return nil, fmt.Errorf("failed to open OCI spec file: %w", err)
}
defer specFile.Close()

var spec specs.Spec
if err := json.NewDecoder(specFile).Decode(&spec); err != nil {
return nil, fmt.Errorf("failed to decode OCI spec: %w", err)
}

cudaImage, err := image.NewCUDAImageFromSpec(
&spec,
image.WithLogger(logger),
)
if err != nil {
return nil, err
}
return &cudaImage, nil
}
2 changes: 2 additions & 0 deletions cmd/nvidia-cdi-hook/commands/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

"github.com/urfave/cli/v3"

"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/chmod"

Check failure on line 25 in cmd/nvidia-cdi-hook/commands/commands.go

View workflow job for this annotation

GitHub Actions / check

File is not properly formatted (gofmt)
checkrequirements "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/check-requirements"
symlinks "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/create-symlinks"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/cudacompat"
disabledevicenodemodification "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/disable-device-node-modification"
Expand Down Expand Up @@ -85,6 +86,7 @@

// Define the supported hooks.
base.Commands = []*cli.Command{
checkrequirements.NewCommand(logger),
ldcache.NewCommand(logger),
symlinks.NewCommand(logger),
chmod.NewCommand(logger),
Expand Down
3 changes: 3 additions & 0 deletions internal/discover/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ const (
//
// Deprecated: The chmod hook is deprecated and will be removed in a future release.
ChmodHook = HookName("chmod")
// A CheckRequirementsHook is used to enforce NVIDIA_REQUIRE_* constraints
// from the container image.
CheckRequirementsHook = HookName("check-requirements")
// A CreateSymlinksHook is used to create symlinks in the container.
CreateSymlinksHook = HookName("create-symlinks")
// DisableDeviceNodeModificationHook refers to the hook used to ensure that
Expand Down
43 changes: 43 additions & 0 deletions internal/discover/requirements.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/**
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package discover

// CheckRequirementsHookOptions defines the options that can be specified when
// creating the check-requirements hook.
type CheckRequirementsHookOptions struct {
DriverRoot string
}

// NewCheckRequirementsHookDiscoverer creates a discoverer for a
// check-requirements hook.
func NewCheckRequirementsHookDiscoverer(hookCreator HookCreator, o *CheckRequirementsHookOptions) Discover {
hook := hookCreator.Create(CheckRequirementsHook, o.args()...)
if hook == nil {
return None{}
}
return hook
}

func (o *CheckRequirementsHookOptions) args() []string {
if o == nil {
return nil
}
if o.DriverRoot == "" || o.DriverRoot == "/" {
return nil
}
return []string{"--driver-root=" + o.DriverRoot}
}
9 changes: 9 additions & 0 deletions internal/modifier/cdi.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"

Check failure on line 29 in internal/modifier/cdi.go

View workflow job for this annotation

GitHub Actions / check

File is not properly formatted (gofmt)
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
)
Expand All @@ -51,6 +52,14 @@
defaultKind,
)
devices := deviceRequestor.DeviceRequests()

// Run before the empty-device return so NVIDIA_REQUIRE_* is still enforced when
// len(devices)==0 (e.g. CRI CDI injection without matching spec signals). When
// there are no requirements, checkRequirements returns immediately.
if err := requirements.CheckImage(f.logger, f.image, f.driver); err != nil {
return nil, fmt.Errorf("requirements not met: %w", err)
}

if len(devices) == 0 {
f.logger.Debugf("No devices requested; no modification required.")
return nil, nil
Expand Down
36 changes: 1 addition & 35 deletions internal/modifier/csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import (
"fmt"

"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
)
Expand All @@ -36,45 +34,13 @@ func (f *Factory) newCSVModifier() (oci.SpecModifier, error) {
}
f.logger.Infof("Constructing modifier from config: %+v", *f.cfg)

if err := checkRequirements(f.logger, f.image); err != nil {
if err := requirements.CheckImage(f.logger, f.image, f.driver); err != nil {
return nil, fmt.Errorf("requirements not met: %v", err)
}

return f.newAutomaticCDISpecModifier(devices)
}

func checkRequirements(logger logger.Interface, image *image.CUDA) error {
if image == nil || image.HasDisableRequire() {
// TODO: We could print the real value here instead
logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", true)
return nil
}

imageRequirements, err := image.GetRequirements()
if err != nil {
// TODO: Should we treat this as a failure, or just issue a warning?
return fmt.Errorf("failed to get image requirements: %v", err)
}

r := requirements.New(logger, imageRequirements)

cudaVersion, err := cuda.Version()
if err != nil {
logger.Warningf("Failed to get CUDA version: %v", err)
} else {
r.AddVersionProperty(requirements.CUDA, cudaVersion)
}

compteCapability, err := cuda.ComputeCapability(0)
if err != nil {
logger.Warningf("Failed to get CUDA Compute Capability: %v", err)
} else {
r.AddVersionProperty(requirements.ARCH, compteCapability)
}

return r.Assert()
}

type csvDevices image.CUDA

func (d csvDevices) DeviceRequests() []string {
Expand Down
Loading
Loading