Skip to content

Commit 0443041

Browse files
committed
Refactor nvml CDI spec generation for consistency
Signed-off-by: Evan Lezar <[email protected]>
1 parent a1d48b6 commit 0443041

File tree

14 files changed

+486
-229
lines changed

14 files changed

+486
-229
lines changed

cmd/nvidia-ctk/cdi/generate/generate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
312312
return nil, fmt.Errorf("failed to create CDI library: %v", err)
313313
}
314314

315-
deviceSpecs, err := cdilib.GetAllDeviceSpecs()
315+
deviceSpecs, err := cdilib.GetDeviceSpecsByID("all")
316316
if err != nil {
317317
return nil, fmt.Errorf("failed to create device CDI specs: %v", err)
318318
}

pkg/nvcdi/api.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ type SpecGenerator interface {
3939
GetSpec(...string) (spec.Interface, error)
4040
}
4141

42+
// A DeviceSpecGenerator is used to generate the specs for one or more devices.
43+
type DeviceSpecGenerator interface {
44+
GetDeviceSpecs() ([]specs.Device, error)
45+
}
46+
4247
// A HookName represents one of the predefined NVIDIA CDI hooks.
4348
type HookName = discover.HookName
4449

pkg/nvcdi/full-gpu-nvml.go

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,41 +19,74 @@ package nvcdi
1919
import (
2020
"fmt"
2121

22-
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2322
"tags.cncf.io/container-device-interface/pkg/cdi"
2423
"tags.cncf.io/container-device-interface/specs-go"
2524

25+
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
26+
"github.com/NVIDIA/go-nvml/pkg/nvml"
27+
2628
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2729
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2830
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/dgpu"
2931
)
3032

31-
// GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'.
32-
func (l *nvmllib) GetGPUDeviceSpecs(i int, d device.Device) ([]specs.Device, error) {
33-
edits, err := l.GetGPUDeviceEdits(d)
33+
// A fullGPUDeviceSpecGenerator generates the CDI device specifications for a
34+
// single full GPU.
35+
type fullGPUDeviceSpecGenerator struct {
36+
*nvmllib
37+
id string
38+
index int
39+
device device.Device
40+
}
41+
42+
var _ DeviceSpecGenerator = (*fullGPUDeviceSpecGenerator)(nil)
43+
44+
func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromNVMLDevice(id string, nvmlDevice nvml.Device) (DeviceSpecGenerator, error) {
45+
device, err := l.devicelib.NewDevice(nvmlDevice)
3446
if err != nil {
35-
return nil, fmt.Errorf("failed to get edits for device: %v", err)
47+
return nil, err
3648
}
3749

38-
var deviceSpecs []specs.Device
39-
names, err := l.deviceNamers.GetDeviceNames(i, convert{d})
50+
index, ret := nvmlDevice.GetIndex()
51+
if ret != nvml.SUCCESS {
52+
return nil, fmt.Errorf("failed to get device index: %w", ret)
53+
}
54+
55+
e := &fullGPUDeviceSpecGenerator{
56+
nvmllib: l,
57+
id: id,
58+
index: index,
59+
device: device,
60+
}
61+
return e, nil
62+
}
63+
64+
func (l *fullGPUDeviceSpecGenerator) GetDeviceSpecs() ([]specs.Device, error) {
65+
deviceEdits, err := l.getDeviceEdits()
66+
if err != nil {
67+
return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", l.id, err)
68+
}
69+
70+
names, err := l.getNames()
4071
if err != nil {
41-
return nil, fmt.Errorf("failed to get device name: %v", err)
72+
return nil, fmt.Errorf("failed to get device names: %w", err)
4273
}
74+
75+
var deviceSpecs []specs.Device
4376
for _, name := range names {
44-
spec := specs.Device{
77+
deviceSpec := specs.Device{
4578
Name: name,
46-
ContainerEdits: *edits.ContainerEdits,
79+
ContainerEdits: *deviceEdits.ContainerEdits,
4780
}
48-
deviceSpecs = append(deviceSpecs, spec)
81+
deviceSpecs = append(deviceSpecs, deviceSpec)
4982
}
5083

5184
return deviceSpecs, nil
5285
}
5386

5487
// GetGPUDeviceEdits returns the CDI edits for the full GPU represented by 'device'.
55-
func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error) {
56-
device, err := l.newFullGPUDiscoverer(d)
88+
func (l *fullGPUDeviceSpecGenerator) getDeviceEdits() (*cdi.ContainerEdits, error) {
89+
device, err := l.newFullGPUDiscoverer(l.device)
5790
if err != nil {
5891
return nil, fmt.Errorf("failed to create device discoverer: %v", err)
5992
}
@@ -66,8 +99,12 @@ func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error
6699
return editsForDevice, nil
67100
}
68101

102+
func (l *fullGPUDeviceSpecGenerator) getNames() ([]string, error) {
103+
return l.deviceNamers.GetDeviceNames(l.index, convert{l.device})
104+
}
105+
69106
// newFullGPUDiscoverer creates a discoverer for the full GPU defined by the specified device.
70-
func (l *nvmllib) newFullGPUDiscoverer(d device.Device) (discover.Discover, error) {
107+
func (l *fullGPUDeviceSpecGenerator) newFullGPUDiscoverer(d device.Device) (discover.Discover, error) {
71108
deviceNodes, err := dgpu.NewForDevice(d,
72109
dgpu.WithDevRoot(l.devRoot),
73110
dgpu.WithLogger(l.logger),

pkg/nvcdi/gds.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,14 @@ import (
2828

2929
type gdslib nvcdilib
3030

31-
var _ wrapped = (*gdslib)(nil)
31+
var _ deviceSpecGeneratorFactory = (*gdslib)(nil)
3232

33-
// GetDeviceSpecsByID returns the device specs for the specified devices.
34-
func (l *gdslib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
33+
func (l *gdslib) DeviceSpecGenerators(...string) (DeviceSpecGenerator, error) {
34+
return l, nil
35+
}
36+
37+
// GetDeviceSpecs returns the CDI device specs for a single all device.
38+
func (l *gdslib) GetDeviceSpecs() ([]specs.Device, error) {
3539
discoverer, err := discover.NewGDSDiscoverer(l.logger, l.driverRoot, l.devRoot)
3640
if err != nil {
3741
return nil, fmt.Errorf("failed to create GPUDirect Storage discoverer: %v", err)

pkg/nvcdi/lib-csv.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,9 @@ import (
2929

3030
type csvlib nvcdilib
3131

32-
var _ wrapped = (*csvlib)(nil)
32+
var _ deviceSpecGeneratorFactory = (*csvlib)(nil)
3333

34-
// GetDeviceSpecsByID returns the device specs for the specified devices.
35-
func (l *csvlib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
34+
func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
3635
for _, id := range ids {
3736
switch id {
3837
case "all":
@@ -42,6 +41,11 @@ func (l *csvlib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
4241
}
4342
}
4443

44+
return l, nil
45+
}
46+
47+
// GetDeviceSpecs returns the CDI device specs for a single device.
48+
func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
4549
d, err := tegra.New(
4650
tegra.WithLogger(l.logger),
4751
tegra.WithDriverRoot(l.driverRoot),

pkg/nvcdi/lib-imex.go

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ import (
3131

3232
type imexlib nvcdilib
3333

34-
var _ wrapped = (*imexlib)(nil)
34+
type imexChannel struct {
35+
id string
36+
devRoot string
37+
}
38+
39+
var _ deviceSpecGeneratorFactory = (*imexlib)(nil)
3540

3641
const (
3742
classImexChannel = "imex-channel"
@@ -42,29 +47,24 @@ func (l *imexlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
4247
return edits.FromDiscoverer(discover.None{})
4348
}
4449

45-
// GetDeviceSpecsByID returns the CDI device specs for the IMEX channels specified.
46-
func (l *imexlib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
50+
// DeviceSpecGenerators returns the CDI device spec generators for the specified
51+
// imex channel IDs.
52+
// Valid IDs are:
53+
// * numeric channel IDs
54+
// * channel<numericChannelID>
55+
// * the special ID 'all'
56+
func (l *imexlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
4757
channelsIDs, err := l.getChannelIDs(ids...)
4858
if err != nil {
4959
return nil, err
5060
}
51-
var deviceSpecs []specs.Device
61+
62+
var deviceSpecGenerators DeviceSpecGenerators
5263
for _, id := range channelsIDs {
53-
path := "/dev/nvidia-caps-imex-channels/channel" + id
54-
deviceSpec := specs.Device{
55-
Name: id,
56-
ContainerEdits: specs.ContainerEdits{
57-
DeviceNodes: []*specs.DeviceNode{
58-
{
59-
Path: path,
60-
HostPath: filepath.Join(l.devRoot, path),
61-
},
62-
},
63-
},
64-
}
65-
deviceSpecs = append(deviceSpecs, deviceSpec)
64+
deviceSpecGenerators = append(deviceSpecGenerators, &imexChannel{id: id, devRoot: l.devRoot})
6665
}
67-
return deviceSpecs, nil
66+
67+
return deviceSpecGenerators, nil
6868
}
6969

7070
func (l *imexlib) getChannelIDs(ids ...string) ([]string, error) {
@@ -104,3 +104,20 @@ func (l *imexlib) getAllChannelIDs() ([]string, error) {
104104

105105
return channelIDs, nil
106106
}
107+
108+
// GetDeviceSpecs returns the CDI device specs the specified IMEX channel.
109+
func (l *imexChannel) GetDeviceSpecs() ([]specs.Device, error) {
110+
path := "/dev/nvidia-caps-imex-channels/channel" + l.id
111+
deviceSpec := specs.Device{
112+
Name: l.id,
113+
ContainerEdits: specs.ContainerEdits{
114+
DeviceNodes: []*specs.DeviceNode{
115+
{
116+
Path: path,
117+
HostPath: filepath.Join(l.devRoot, path),
118+
},
119+
},
120+
},
121+
}
122+
return []specs.Device{deviceSpec}, nil
123+
}

0 commit comments

Comments
 (0)