Skip to content

Commit 3375782

Browse files
committed
driver: request gpu when creating container builder
Signed-off-by: CrazyMax <[email protected]>
1 parent 0d708c0 commit 3375782

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

driver/docker-container/driver.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ type Driver struct {
5656
restartPolicy container.RestartPolicy
5757
env []string
5858
defaultLoad bool
59+
gpus []container.DeviceRequest
5960
}
6061

6162
func (d *Driver) IsMobyDriver() bool {
@@ -158,6 +159,9 @@ func (d *Driver) create(ctx context.Context, l progress.SubLogger) error {
158159
if d.cpusetMems != "" {
159160
hc.Resources.CpusetMems = d.cpusetMems
160161
}
162+
if len(d.gpus) > 0 && d.hasGPUCapability(ctx, cfg.Image, d.gpus) {
163+
hc.Resources.DeviceRequests = d.gpus
164+
}
161165
if info, err := d.DockerAPI.Info(ctx); err == nil {
162166
if info.CgroupDriver == "cgroupfs" {
163167
// Place all buildkit containers inside this cgroup by default so limits can be attached
@@ -429,6 +433,31 @@ func (d *Driver) HostGatewayIP(ctx context.Context) (net.IP, error) {
429433
return nil, errors.New("host-gateway is not supported by the docker-container driver")
430434
}
431435

436+
// hasGPUCapability checks if docker daemon has GPU capability. We need to run
437+
// a dummy container with GPU device to check if the daemon has this capability
438+
// because there is no API to check it yet.
439+
func (d *Driver) hasGPUCapability(ctx context.Context, image string, gpus []container.DeviceRequest) bool {
440+
cfg := &container.Config{
441+
Image: image,
442+
Entrypoint: []string{"/bin/true"},
443+
}
444+
hc := &container.HostConfig{
445+
NetworkMode: container.NetworkMode(container.IPCModeNone),
446+
AutoRemove: true,
447+
Resources: container.Resources{
448+
DeviceRequests: gpus,
449+
},
450+
}
451+
resp, err := d.DockerAPI.ContainerCreate(ctx, cfg, hc, &network.NetworkingConfig{}, nil, "")
452+
if err != nil {
453+
return false
454+
}
455+
if err := d.DockerAPI.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil {
456+
return false
457+
}
458+
return true
459+
}
460+
432461
func demuxConn(c net.Conn) net.Conn {
433462
pr, pw := io.Pipe()
434463
// TODO: rewrite parser with Reader() to avoid goroutine switch

driver/docker-container/factory.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ func (f *factory) New(ctx context.Context, cfg driver.InitConfig) (driver.Driver
5151
InitConfig: cfg,
5252
restartPolicy: rp,
5353
}
54+
var gpus dockeropts.GpuOpts
55+
if err := gpus.Set("all"); err == nil {
56+
if v := gpus.Value(); len(v) > 0 {
57+
d.gpus = v
58+
}
59+
}
5460
for k, v := range cfg.DriverOpts {
5561
switch {
5662
case k == "network":

0 commit comments

Comments
 (0)