Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/modifier/cdi.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ func (f *Factory) newCDIModifier(isJitCDI bool) (oci.SpecModifier, error) {
defaultKind,
)
devices := deviceRequestor.DeviceRequests()

// Run before the empty-device return so NVIDIA_REQUIRE_* is still enforced when
// len(devices)==0 (e.g. CRI CDI injection without matching spec signals). When
// there are no requirements, checkRequirements returns immediately.
if err := checkRequirements(f.logger, f.image, f.driver); err != nil {
return nil, fmt.Errorf("requirements not met: %w", err)
}

if len(devices) == 0 {
f.logger.Debugf("No devices requested; no modification required.")
return nil, nil
Expand Down
37 changes: 1 addition & 36 deletions internal/modifier/csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ import (
"fmt"

"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
)

// newCSVModifier creates a modifier that applies modications to an OCI spec if required by the runtime wrapper.
Expand All @@ -36,45 +33,13 @@ func (f *Factory) newCSVModifier() (oci.SpecModifier, error) {
}
f.logger.Infof("Constructing modifier from config: %+v", *f.cfg)

if err := checkRequirements(f.logger, f.image); err != nil {
if err := checkRequirements(f.logger, f.image, f.driver); err != nil {
return nil, fmt.Errorf("requirements not met: %v", err)
}

return f.newAutomaticCDISpecModifier(devices)
}

func checkRequirements(logger logger.Interface, image *image.CUDA) error {
if image == nil || image.HasDisableRequire() {
// TODO: We could print the real value here instead
logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", true)
return nil
}

imageRequirements, err := image.GetRequirements()
if err != nil {
// TODO: Should we treat this as a failure, or just issue a warning?
return fmt.Errorf("failed to get image requirements: %v", err)
}

r := requirements.New(logger, imageRequirements)

cudaVersion, err := cuda.Version()
if err != nil {
logger.Warningf("Failed to get CUDA version: %v", err)
} else {
r.AddVersionProperty(requirements.CUDA, cudaVersion)
}

compteCapability, err := cuda.ComputeCapability(0)
if err != nil {
logger.Warningf("Failed to get CUDA Compute Capability: %v", err)
} else {
r.AddVersionProperty(requirements.ARCH, compteCapability)
}

return r.Assert()
}

type csvDevices image.CUDA

func (d csvDevices) DeviceRequests() []string {
Expand Down
200 changes: 200 additions & 0 deletions internal/modifier/image_requirements.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package modifier

import (
"fmt"
"strconv"
"strings"

"github.com/NVIDIA/go-nvml/pkg/nvml"
"golang.org/x/mod/semver"

"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
)

// checkRequirements evaluates NVIDIA_REQUIRE_* constraints using the host
// CUDA driver API version from libcuda, the NVIDIA display driver version from
// the driver root (libcuda / libnvidia-ml soname), the compute capability of
// CUDA device 0, and (when requirements reference brand) the GPU product brand
// from NVML. It is used for CSV and CDI / JIT-CDI modes.
Comment on lines +34 to +38
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that there are cases where libcuda.so is not applicable (if we're not injecting actuall GPU devices, for example).

func checkRequirements(logger logger.Interface, image *image.CUDA, driver *root.Driver) error {
if image == nil || image.HasDisableRequire() {
logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", true)
return nil
}

imageRequirements, err := image.GetRequirements()
if err != nil {
return fmt.Errorf("failed to get image requirements: %v", err)
}
if len(imageRequirements) == 0 {
return nil
}

r := requirements.New(logger, imageRequirements)

cudaVersion, err := cuda.Version()
if err != nil {
logger.Warningf("Failed to get CUDA version: %v", err)
} else {
r.AddVersionProperty(requirements.CUDA, cudaVersion)
}

compteCapability, err := cuda.ComputeCapability(0)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we're always using the first device (which was fine for older Tegra-based systems), but this does not map to multi-device systems especially if they're heterogeneous.

if err != nil {
logger.Warningf("Failed to get CUDA Compute Capability: %v", err)
} else {
r.AddVersionProperty(requirements.ARCH, compteCapability)
}

driverVersion, err := driver.Version()
if err != nil {
logger.Warningf("Failed to get NVIDIA driver version: %v", err)
} else {
normalized, normErr := normalizeDriverVersionForSemver(driverVersion)
if normErr != nil {
logger.Warningf("NVIDIA driver version %q is not semver-normalizable: %v", driverVersion, normErr)
} else {
r.AddVersionProperty(requirements.DRIVER, normalized)
}
}

brand, err := getBrandFromNVML(driver)
if err != nil {
logger.Warningf("Failed to get GPU brand from NVML: %v", err)
} else {
r.AddStringProperty(requirements.BRAND, brand)
}

return r.Assert()
}

// normalizeDriverVersionForSemver converts a driver version taken from a
// libcuda / libnvidia-ml soname suffix into a form accepted by
// golang.org/x/mod/semver (no leading zeros in numeric segments)
func normalizeDriverVersionForSemver(raw string) (string, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", fmt.Errorf("empty driver version")
}
parts := strings.Split(raw, ".")
out := make([]string, 0, len(parts))
for _, p := range parts {
if p == "" {
return "", fmt.Errorf("empty version segment in %q", raw)
}
if strings.TrimLeft(p, "0123456789") != "" {
return "", fmt.Errorf("non-numeric version segment %q in %q", p, raw)
}
n, err := strconv.ParseUint(p, 10, 64)
if err != nil {
return "", fmt.Errorf("invalid version segment %q in %q: %w", p, raw, err)
}
out = append(out, strconv.FormatUint(n, 10))
}
normalized := strings.Join(out, ".")
if !semver.IsValid("v" + normalized) {
return "", fmt.Errorf("normalized driver version %q is not valid semver", normalized)
}
return normalized, nil
}

// getBrandFromNVML returns a lowercase brand token for the first visible GPU
// (index 0), using NVML. When driver is non-nil, NVML is loaded from the
// versioned libnvidia-ml under the driver root when possible.
func getBrandFromNVML(driver *root.Driver) (string, error) {
var lib nvml.Interface
var opts []nvml.LibraryOption
v, err := driver.Version()
if err == nil && v != "" && v != "*.*" {
paths, err := driver.Libraries().Locate("libnvidia-ml.so." + v)
if err == nil && len(paths) > 0 {
opts = append(opts, nvml.WithLibraryPath(paths[0]))
}
}

lib = nvml.New(opts...)
if ret := lib.Init(); ret != nvml.SUCCESS {
return "", fmt.Errorf("nvml.Init: %s", lib.ErrorString(ret))
}
defer func() {
_ = lib.Shutdown()
}()

device, ret := lib.DeviceGetHandleByIndex(0)
if ret != nvml.SUCCESS {
return "", fmt.Errorf("nvml.DeviceGetHandleByIndex(0): %s", lib.ErrorString(ret))
}

brandType, ret := lib.DeviceGetBrand(device)
if ret != nvml.SUCCESS {
return "", fmt.Errorf("nvml.DeviceGetBrand: %s", lib.ErrorString(ret))
}
brand, ok := brandTypeToRequirementString(brandType)
if !ok {
return "", fmt.Errorf("unknown NVML brand type %v", brandType)
}
return brand, nil
}

// brandTypeToRequirementString maps NVML brand enums to lowercase tokens
// consistent with typical NVIDIA_REQUIRE_* image constraints.
func brandTypeToRequirementString(b nvml.BrandType) (string, bool) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: is this something that we already have access to in go-nvlib?

switch b {
case nvml.BRAND_UNKNOWN:
return "", false
case nvml.BRAND_QUADRO:
return "quadro", true
case nvml.BRAND_TESLA:
return "tesla", true
case nvml.BRAND_NVS:
return "nvs", true
case nvml.BRAND_GRID:
return "grid", true
case nvml.BRAND_GEFORCE:
return "geforce", true
case nvml.BRAND_TITAN:
return "titan", true
case nvml.BRAND_NVIDIA_VAPPS:
return "nvidiavapps", true
case nvml.BRAND_NVIDIA_VPC:
return "nvidiavpc", true
case nvml.BRAND_NVIDIA_VCS:
return "nvidiavcs", true
case nvml.BRAND_NVIDIA_VWS:
return "nvidiavws", true
case nvml.BRAND_NVIDIA_CLOUD_GAMING:
return "nvidiacloudgaming", true
case nvml.BRAND_QUADRO_RTX:
return "quadrortx", true
case nvml.BRAND_NVIDIA_RTX:
return "nvidiartx", true
case nvml.BRAND_NVIDIA:
return "nvidia", true
case nvml.BRAND_GEFORCE_RTX:
return "geforcertx", true
case nvml.BRAND_TITAN_RTX:
return "titanrtx", true
default:
return "", false
}
}
Loading