Skip to content

Commit 8aac606

Browse files
committed
Support CDI devices in --device flag
Signed-off-by: Evan Lezar <[email protected]>
1 parent 88924b1 commit 8aac606

File tree

2 files changed

+104
-34
lines changed

2 files changed

+104
-34
lines changed

cli/command/container/opts.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"strings"
1414
"time"
1515

16+
cdi "github.com/container-orchestrated-devices/container-device-interface/pkg/parser"
1617
"github.com/docker/cli/cli/compose/loader"
1718
"github.com/docker/cli/opts"
1819
"github.com/docker/docker/api/types/container"
@@ -443,12 +444,17 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
443444
// parsing flags, we haven't yet sent a _ping to the daemon to determine
444445
// what operating system it is.
445446
deviceMappings := []container.DeviceMapping{}
447+
var cdiDeviceNames []string
446448
for _, device := range copts.devices.GetAll() {
447449
var (
448450
validated string
449451
deviceMapping container.DeviceMapping
450452
err error
451453
)
454+
if cdi.IsQualifiedName(device) {
455+
cdiDeviceNames = append(cdiDeviceNames, device)
456+
continue
457+
}
452458
validated, err = validateDevice(device, serverOS)
453459
if err != nil {
454460
return nil, err
@@ -553,6 +559,16 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
553559
}
554560
}
555561

562+
deviceRequests := copts.gpus.Value()
563+
if len(cdiDeviceNames) > 0 {
564+
cdiDeviceRequest := container.DeviceRequest{
565+
Driver: "cdi",
566+
Capabilities: [][]string{{"cdi"}},
567+
DeviceIDs: cdiDeviceNames,
568+
}
569+
deviceRequests = append(deviceRequests, cdiDeviceRequest)
570+
}
571+
556572
resources := container.Resources{
557573
CgroupParent: copts.cgroupParent,
558574
Memory: copts.memory.Value(),
@@ -583,7 +599,7 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
583599
Ulimits: copts.ulimits.GetList(),
584600
DeviceCgroupRules: copts.deviceCgroupRules.GetAll(),
585601
Devices: deviceMappings,
586-
DeviceRequests: copts.gpus.Value(),
602+
DeviceRequests: deviceRequests,
587603
}
588604

589605
config := &container.Config{

cli/command/container/opts_test.go

Lines changed: 87 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -417,39 +417,93 @@ func TestParseWithExpose(t *testing.T) {
417417

418418
func TestParseDevice(t *testing.T) {
419419
skip.If(t, runtime.GOOS != "linux") // Windows and macOS validate server-side
420-
valids := map[string]container.DeviceMapping{
421-
"/dev/snd": {
422-
PathOnHost: "/dev/snd",
423-
PathInContainer: "/dev/snd",
424-
CgroupPermissions: "rwm",
425-
},
426-
"/dev/snd:rw": {
427-
PathOnHost: "/dev/snd",
428-
PathInContainer: "/dev/snd",
429-
CgroupPermissions: "rw",
430-
},
431-
"/dev/snd:/something": {
432-
PathOnHost: "/dev/snd",
433-
PathInContainer: "/something",
434-
CgroupPermissions: "rwm",
435-
},
436-
"/dev/snd:/something:rw": {
437-
PathOnHost: "/dev/snd",
438-
PathInContainer: "/something",
439-
CgroupPermissions: "rw",
440-
},
441-
}
442-
for device, deviceMapping := range valids {
443-
_, hostconfig, _, err := parseRun([]string{fmt.Sprintf("--device=%v", device), "img", "cmd"})
444-
if err != nil {
445-
t.Fatal(err)
446-
}
447-
if len(hostconfig.Devices) != 1 {
448-
t.Fatalf("Expected 1 devices, got %v", hostconfig.Devices)
449-
}
450-
if hostconfig.Devices[0] != deviceMapping {
451-
t.Fatalf("Expected %v, got %v", deviceMapping, hostconfig.Devices)
452-
}
420+
testCases := []struct {
421+
devices []string
422+
deviceMapping *container.DeviceMapping
423+
deviceRequests []container.DeviceRequest
424+
}{
425+
{
426+
devices: []string{"/dev/snd"},
427+
deviceMapping: &container.DeviceMapping{
428+
PathOnHost: "/dev/snd",
429+
PathInContainer: "/dev/snd",
430+
CgroupPermissions: "rwm",
431+
},
432+
},
433+
{
434+
devices: []string{"/dev/snd:rw"},
435+
deviceMapping: &container.DeviceMapping{
436+
PathOnHost: "/dev/snd",
437+
PathInContainer: "/dev/snd",
438+
CgroupPermissions: "rw",
439+
},
440+
},
441+
{
442+
devices: []string{"/dev/snd:/something"},
443+
deviceMapping: &container.DeviceMapping{
444+
PathOnHost: "/dev/snd",
445+
PathInContainer: "/something",
446+
CgroupPermissions: "rwm",
447+
},
448+
},
449+
{
450+
devices: []string{"/dev/snd:/something:rw"},
451+
deviceMapping: &container.DeviceMapping{
452+
PathOnHost: "/dev/snd",
453+
PathInContainer: "/something",
454+
CgroupPermissions: "rw",
455+
},
456+
},
457+
{
458+
devices: []string{"vendor.com/class=name"},
459+
deviceMapping: nil,
460+
deviceRequests: []container.DeviceRequest{
461+
{
462+
Driver: "cdi",
463+
Capabilities: [][]string{{"cdi"}},
464+
DeviceIDs: []string{"vendor.com/class=name"},
465+
},
466+
},
467+
},
468+
{
469+
devices: []string{"vendor.com/class=name", "/dev/snd:/something:rw"},
470+
deviceMapping: &container.DeviceMapping{
471+
PathOnHost: "/dev/snd",
472+
PathInContainer: "/something",
473+
CgroupPermissions: "rw",
474+
},
475+
deviceRequests: []container.DeviceRequest{
476+
{
477+
Driver: "cdi",
478+
Capabilities: [][]string{{"cdi"}},
479+
DeviceIDs: []string{"vendor.com/class=name"},
480+
},
481+
},
482+
},
483+
}
484+
485+
for _, tc := range testCases {
486+
t.Run(fmt.Sprintf("%s", tc.devices), func(t *testing.T) {
487+
var args []string
488+
for _, d := range tc.devices {
489+
args = append(args, fmt.Sprintf("--device=%v", d))
490+
}
491+
args = append(args, "img", "cmd")
492+
493+
_, hostconfig, _, err := parseRun(args)
494+
495+
assert.NilError(t, err)
496+
497+
if tc.deviceMapping != nil {
498+
if assert.Check(t, is.Len(hostconfig.Devices, 1)) {
499+
assert.Check(t, is.DeepEqual(*tc.deviceMapping, hostconfig.Devices[0]))
500+
}
501+
} else {
502+
assert.Check(t, is.Len(hostconfig.Devices, 0))
503+
}
504+
505+
assert.Check(t, is.DeepEqual(tc.deviceRequests, hostconfig.DeviceRequests))
506+
})
453507
}
454508
}
455509

0 commit comments

Comments
 (0)