Skip to content

Commit 5aefd2b

Browse files
authored
[EBPF] gpu: handle runtime changes of CUDA_VISIBLE_DEVICES (#38312)
1 parent 9679219 commit 5aefd2b

22 files changed

+213
-49
lines changed

pkg/ebpf/cgo/genpost.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ func processFile(rdr io.Reader, out io.Writer) error {
5656
"Topic_name",
5757
"Trigger_comm",
5858
"Victim_comm",
59+
"Devices",
5960
}
6061

6162
// Convert []int8 to []byte in multiple generated fields from the kernel, to simplify

pkg/ebpf/uprobes/attacher_test.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,7 @@ func (s *SharedLibrarySuite) TestMultipleLibsets() {
10521052
// Create test files for different libsets
10531053
cryptoLibPath, _ := createTempTestFile(t, "foo-libssl.so")
10541054
gpuLibPath, _ := createTempTestFile(t, "foo-libcudart.so")
1055+
libcLibPath, _ := createTempTestFile(t, "foo-libc.so")
10551056

10561057
attachCfg := AttacherConfig{
10571058
Rules: []*AttachRule{
@@ -1063,9 +1064,13 @@ func (s *SharedLibrarySuite) TestMultipleLibsets() {
10631064
LibraryNameRegex: regexp.MustCompile(`foo-libcudart\.so`),
10641065
Targets: AttachToSharedLibraries,
10651066
},
1067+
{
1068+
LibraryNameRegex: regexp.MustCompile(`foo-libc\.so`),
1069+
Targets: AttachToSharedLibraries,
1070+
},
10661071
},
10671072
EbpfConfig: ebpfCfg,
1068-
SharedLibsLibsets: []sharedlibraries.Libset{sharedlibraries.LibsetCrypto, sharedlibraries.LibsetGPU},
1073+
SharedLibsLibsets: []sharedlibraries.Libset{sharedlibraries.LibsetCrypto, sharedlibraries.LibsetGPU, sharedlibraries.LibsetLibc},
10691074
EnablePeriodicScanNewProcesses: false,
10701075
}
10711076

@@ -1094,6 +1099,7 @@ func (s *SharedLibrarySuite) TestMultipleLibsets() {
10941099
testCases := []testCase{
10951100
{cryptoLibPath, "foo-libssl.so", "crypto library"},
10961101
{gpuLibPath, "foo-libcudart.so", "GPU library"},
1102+
{libcLibPath, "foo-libc.so", "libc library"},
10971103
}
10981104

10991105
var commands []*exec.Cmd

pkg/ebpf/uprobes/testutil.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,12 @@ func waitAndRetryIfFail(t *testing.T, setupFunc func(), testFunc func() bool, re
250250
}
251251
}
252252

253-
require.Fail(t, "condition not met after %d retries", maxRetries, msgAndArgs)
253+
extraFmt := ""
254+
if len(msgAndArgs) > 0 {
255+
extraFmt = fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...) + ": "
256+
}
257+
258+
require.Fail(t, "condition not met", "%scondition not met after %d retries", extraFmt, maxRetries)
254259
}
255260

256261
// processMonitorProxy is a wrapper around a ProcessMonitor that stores the

pkg/gpu/consumer.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
"time"
1515
"unsafe"
1616

17+
"golang.org/x/sys/unix"
18+
1719
"github.com/DataDog/datadog-agent/comp/core/telemetry"
1820
ddebpf "github.com/DataDog/datadog-agent/pkg/ebpf"
1921
"github.com/DataDog/datadog-agent/pkg/gpu/config"
@@ -163,7 +165,7 @@ func (c *cudaEventConsumer) Start() {
163165
}
164166

165167
func isStreamSpecificEvent(eventType gpuebpf.CudaEventType) bool {
166-
return eventType != gpuebpf.CudaEventTypeSetDevice
168+
return eventType != gpuebpf.CudaEventTypeSetDevice && eventType != gpuebpf.CudaEventTypeVisibleDevicesSet
167169
}
168170

169171
func (c *cudaEventConsumer) handleEvent(header *gpuebpf.CudaEventHeader, dataPtr unsafe.Pointer, dataLen int) error {
@@ -222,11 +224,19 @@ func (c *cudaEventConsumer) handleSetDevice(csde *gpuebpf.CudaSetDeviceEvent) {
222224
c.sysCtx.setDeviceSelection(int(pid), int(tid), csde.Device)
223225
}
224226

227+
func (c *cudaEventConsumer) handleVisibleDevicesSet(vds *gpuebpf.CudaVisibleDevicesSetEvent) {
228+
pid, _ := getPidTidFromHeader(&vds.Header)
229+
230+
c.sysCtx.setUpdatedVisibleDevicesEnvVar(int(pid), unix.ByteSliceToString(vds.Devices[:]))
231+
}
232+
225233
func (c *cudaEventConsumer) handleGlobalEvent(header *gpuebpf.CudaEventHeader, data unsafe.Pointer, dataLen int) error {
226234
eventType := gpuebpf.CudaEventType(header.Type)
227235
switch eventType {
228236
case gpuebpf.CudaEventTypeSetDevice:
229237
return handleTypedEvent(c, c.handleSetDevice, eventType, data, dataLen, gpuebpf.SizeofCudaSetDeviceEvent)
238+
case gpuebpf.CudaEventTypeVisibleDevicesSet:
239+
return handleTypedEvent(c, c.handleVisibleDevicesSet, eventType, data, dataLen, gpuebpf.SizeofCudaVisibleDevicesSetEvent)
230240
default:
231241
c.telemetry.eventErrors.Inc(telemetryEventTypeUnknown, telemetryEventErrorUnknownType)
232242
return fmt.Errorf("unknown event type: %d", header.Type)

pkg/gpu/context.go

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/DataDog/datadog-agent/pkg/gpu/cuda"
1818
ddnvml "github.com/DataDog/datadog-agent/pkg/gpu/safenvml"
1919
gpuutil "github.com/DataDog/datadog-agent/pkg/util/gpu"
20+
"github.com/DataDog/datadog-agent/pkg/util/kernel"
2021
"github.com/DataDog/datadog-agent/pkg/util/ktime"
2122
)
2223

@@ -35,6 +36,12 @@ type systemContext struct {
3536
// be modified by the CUDA_VISIBLE_DEVICES environment variable later
3637
selectedDeviceByPIDAndTID map[int]map[int]int32
3738

39+
// cudaVisibleDevicesPerProcess maps each process ID to the latest visible
40+
// devices environment variable that was set by the process. This is used to
41+
// keep track of updates during process runtime, which aren't visible in
42+
// /proc/pid/environ.
43+
cudaVisibleDevicesPerProcess map[int]string
44+
3845
// deviceCache is a cache of GPU devices on the system
3946
deviceCache ddnvml.DeviceCache
4047

@@ -104,10 +111,11 @@ func getSystemContext(optList ...systemContextOption) (*systemContext, error) {
104111
opts := newSystemContextOptions(optList...)
105112

106113
ctx := &systemContext{
107-
procRoot: opts.procRoot,
108-
selectedDeviceByPIDAndTID: make(map[int]map[int]int32),
109-
visibleDevicesCache: make(map[int][]ddnvml.Device),
110-
workloadmeta: opts.wmeta,
114+
procRoot: opts.procRoot,
115+
selectedDeviceByPIDAndTID: make(map[int]map[int]int32),
116+
visibleDevicesCache: make(map[int][]ddnvml.Device),
117+
cudaVisibleDevicesPerProcess: make(map[int]string),
118+
workloadmeta: opts.wmeta,
111119
}
112120

113121
var err error
@@ -135,6 +143,7 @@ func getSystemContext(optList ...systemContextOption) (*systemContext, error) {
135143
func (ctx *systemContext) removeProcess(pid int) {
136144
delete(ctx.selectedDeviceByPIDAndTID, pid)
137145
delete(ctx.visibleDevicesCache, pid)
146+
delete(ctx.cudaVisibleDevicesPerProcess, pid)
138147

139148
if ctx.cudaKernelCache != nil {
140149
ctx.cudaKernelCache.CleanProcessData(pid)
@@ -251,7 +260,15 @@ func (ctx *systemContext) getCurrentActiveGpuDevice(pid int, tid int, containerI
251260
return nil, fmt.Errorf("error filtering devices for container %s: %w", containerID, err)
252261
}
253262

254-
visibleDevices, err = cuda.GetVisibleDevicesForProcess(visibleDevices, pid, ctx.procRoot)
263+
envVar, ok := ctx.cudaVisibleDevicesPerProcess[pid]
264+
if !ok {
265+
envVar, err = kernel.GetProcessEnvVariable(pid, ctx.procRoot, cuda.CudaVisibleDevicesEnvVar)
266+
if err != nil {
267+
return nil, fmt.Errorf("error getting env var %s for process %d: %w", cuda.CudaVisibleDevicesEnvVar, pid, err)
268+
}
269+
}
270+
271+
visibleDevices, err = cuda.ParseVisibleDevices(visibleDevices, envVar)
255272
if err != nil {
256273
return nil, fmt.Errorf("error getting visible devices for process %d: %w", pid, err)
257274
}
@@ -284,3 +301,10 @@ func (ctx *systemContext) setDeviceSelection(pid int, tid int, deviceIndex int32
284301

285302
ctx.selectedDeviceByPIDAndTID[pid][tid] = deviceIndex
286303
}
304+
305+
func (ctx *systemContext) setUpdatedVisibleDevicesEnvVar(pid int, envVar string) {
306+
ctx.cudaVisibleDevicesPerProcess[pid] = envVar
307+
308+
// Invalidate the visible devices cache to force a re-scan of the devices
309+
delete(ctx.visibleDevicesCache, pid)
310+
}

pkg/gpu/context_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ func TestGetCurrentActiveGpuDevice(t *testing.T) {
176176
containerID string
177177
configuredDeviceIdx []int32
178178
expectedDeviceIdx []int32
179+
updatedEnvVar string
179180
}{
180181
{
181182
name: "NoContainer",
@@ -205,10 +206,29 @@ func TestGetCurrentActiveGpuDevice(t *testing.T) {
205206
configuredDeviceIdx: []int32{1, 2},
206207
expectedDeviceIdx: []int32{containerDeviceIndexes[envVisibleDevices[1]], containerDeviceIndexes[envVisibleDevices[2]]},
207208
},
209+
{
210+
name: "NoContainerAndRuntimeEnvVar",
211+
pid: pidNoContainer,
212+
configuredDeviceIdx: []int32{0},
213+
expectedDeviceIdx: []int32{1},
214+
updatedEnvVar: "1",
215+
},
216+
{
217+
name: "NoContainerAndRuntimeUpdatedEnvVar",
218+
pid: pidNoContainerButEnv,
219+
configuredDeviceIdx: []int32{0},
220+
expectedDeviceIdx: []int32{1},
221+
updatedEnvVar: "1",
222+
},
208223
}
209224

210225
for _, c := range cases {
211226
t.Run(c.name, func(t *testing.T) {
227+
if c.updatedEnvVar != "" {
228+
sysCtx.setUpdatedVisibleDevicesEnvVar(c.pid, c.updatedEnvVar)
229+
require.NotContains(t, sysCtx.visibleDevicesCache, c.pid, "cache not invalidated for process %d", c.pid)
230+
}
231+
212232
for i, idx := range c.configuredDeviceIdx {
213233
sysCtx.setDeviceSelection(c.pid, c.pid+i, idx)
214234
}
@@ -218,6 +238,10 @@ func TestGetCurrentActiveGpuDevice(t *testing.T) {
218238
require.NoError(t, err)
219239
nvmltestutil.RequireDevicesEqual(t, sysCtx.deviceCache.All()[idx], activeDevice, "invalid device at index %d (real index is %d, selected index is %d)", i, idx, c.configuredDeviceIdx[i])
220240
}
241+
242+
// Note: we're explicitly not resetting the caches, as we want to test
243+
// whether the functions correctly invalidate the caches when the
244+
// environment variable is updated.
221245
})
222246
}
223247
}

pkg/gpu/cuda/env.go

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,14 @@ import (
1212
"strconv"
1313
"strings"
1414

15-
"github.com/DataDog/datadog-agent/pkg/util/kernel"
16-
1715
ddnvml "github.com/DataDog/datadog-agent/pkg/gpu/safenvml"
1816
)
1917

20-
const cudaVisibleDevicesEnvVar = "CUDA_VISIBLE_DEVICES"
18+
// CudaVisibleDevicesEnvVar is the name of the environment variable that controls the visible GPUs for CUDA applications
19+
const CudaVisibleDevicesEnvVar = "CUDA_VISIBLE_DEVICES"
2120

22-
// GetVisibleDevicesForProcess modifies the list of GPU devices according to the
23-
// value of the CUDA_VISIBLE_DEVICES environment variable for the specified
24-
// process. Reference:
21+
// ParseVisibleDevices modifies the list of GPU devices according to the
22+
// value of the CUDA_VISIBLE_DEVICES environment variable. Reference:
2523
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars.
2624
//
2725
// As a summary, the CUDA_VISIBLE_DEVICES environment variable should be a comma
@@ -36,18 +34,7 @@ const cudaVisibleDevicesEnvVar = "CUDA_VISIBLE_DEVICES"
3634
// devices whose index precedes the invalid index are visible to CUDA
3735
// applications." If an invalid index is found, an error is returned together
3836
// with the list of valid devices found up until that point.
39-
func GetVisibleDevicesForProcess(devices []ddnvml.Device, pid int, procfs string) ([]ddnvml.Device, error) {
40-
cudaVisibleDevicesForProcess, err := kernel.GetProcessEnvVariable(pid, procfs, cudaVisibleDevicesEnvVar)
41-
if err != nil {
42-
return nil, fmt.Errorf("cannot get env var %s for process %d: %w", cudaVisibleDevicesEnvVar, pid, err)
43-
}
44-
45-
return getVisibleDevices(devices, cudaVisibleDevicesForProcess)
46-
}
47-
48-
// getVisibleDevices processes the list of GPU devices according to the value of
49-
// the CUDA_VISIBLE_DEVICES environment variable
50-
func getVisibleDevices(devices []ddnvml.Device, cudaVisibleDevicesForProcess string) ([]ddnvml.Device, error) {
37+
func ParseVisibleDevices(devices []ddnvml.Device, cudaVisibleDevicesForProcess string) ([]ddnvml.Device, error) {
5138
// First, we adjust the list of devices to take into account how CUDA presents MIG devices in order. This
5239
// list will not be used when searching by prefix because prefix matching is done against *all* devices,
5340
// but index filtering is done against the adjusted list where devices with MIG children are replaced by

pkg/gpu/cuda/env_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func TestGetVisibleDevices(t *testing.T) {
9797

9898
for _, tc := range cases {
9999
t.Run(tc.name, func(t *testing.T) {
100-
devices, err := getVisibleDevices(devList, tc.visibleDevices)
100+
devices, err := ParseVisibleDevices(devList, tc.visibleDevices)
101101
if tc.expectsError {
102102
require.Error(t, err)
103103
} else {
@@ -319,7 +319,7 @@ func TestGetVisibleDevicesWithMIG(t *testing.T) {
319319

320320
for _, tc := range cases {
321321
t.Run(tc.name, func(t *testing.T) {
322-
devices, err := getVisibleDevices(tc.systemDevices, tc.visibleDevices)
322+
devices, err := ParseVisibleDevices(tc.systemDevices, tc.visibleDevices)
323323
if tc.expectsError {
324324
require.Error(t, err)
325325
} else {

pkg/gpu/e2e_events_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/stretchr/testify/require"
1818

1919
"github.com/DataDog/datadog-agent/pkg/gpu/config"
20+
"github.com/DataDog/datadog-agent/pkg/gpu/ebpf"
2021
ddnvml "github.com/DataDog/datadog-agent/pkg/gpu/safenvml"
2122
nvmltestutil "github.com/DataDog/datadog-agent/pkg/gpu/safenvml/testutil"
2223
"github.com/DataDog/datadog-agent/pkg/gpu/testutil"
@@ -71,7 +72,7 @@ func TestPytorchBatchedKernels(t *testing.T) {
7172

7273
telemetryMetrics, err := telemetryMock.GetCountMetric("gpu__consumer", "events")
7374
require.NoError(t, err)
74-
require.Equal(t, 4, len(telemetryMetrics)) // one for each event type
75+
require.Equal(t, int(ebpf.CudaEventTypeCount), len(telemetryMetrics)) // one for each event type
7576
expectedEventsByType := testutil.DataSampleInfos[testutil.DataSamplePytorchBatchedKernels].EventByType
7677
for _, metric := range telemetryMetrics {
7778
eventTypeTag := metric.Tags()["event_type"]

pkg/gpu/ebpf/c/runtime/gpu.c

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,4 +366,43 @@ int BPF_URETPROBE(uretprobe__cudaMemcpy) {
366366
return 0;
367367
}
368368

369+
// uprobe__setenv is attached to setenv(name, value, overwrite). When name is
// exactly "CUDA_VISIBLE_DEVICES" it copies value into a
// cuda_visible_devices_set_t event and sends it through the cuda_events ring
// buffer. Calls for any other variable are ignored. Always returns 0.
//
// NOTE(review): the overwrite argument is not inspected. Per setenv(3), a call
// with overwrite == 0 leaves an already-set variable untouched, so such a call
// would still emit an event here — confirm whether that case needs handling.
SEC("uprobe/setenv")
int BPF_UPROBE(uprobe__setenv, const char *name, const char *value, int overwrite) {
    // Check if the env var is CUDA_VISIBLE_DEVICES. No string comparison
    // helper is usable on our minimum kernel, so the name is copied to the
    // stack and compared manually.
    const char cuda_visible_devices[] = "CUDA_VISIBLE_DEVICES";

    // The buffer is one byte larger than the expected name (including its NUL
    // terminator). bpf_probe_read_user_str truncates strings that don't fit,
    // so with an exactly-sized buffer a longer name that merely starts with
    // CUDA_VISIBLE_DEVICES would be cut down to the expected name and match.
    // With the extra byte, longer names yield a larger return value instead.
    char name_buf[sizeof(cuda_visible_devices) + 1];

    // bpf_probe_read_user_str is available from kernel 5.5, our minimum kernel version is 5.8.0
    int res = bpf_probe_read_user_str_with_telemetry(name_buf, sizeof(name_buf), name);
    if (res < 0) {
        return 0;
    }

    // The return value is the length of the string read, including the NUL
    // byte. Anything other than the exact length of CUDA_VISIBLE_DEVICES
    // (shorter, or longer and possibly truncated) is a different variable.
    if (res != sizeof(cuda_visible_devices)) {
        return 0;
    }

    // bpf_strncmp is available in kernel 5.17, our minimum kernel version is 5.8.0
    // so we need to do a manual comparison
    for (int i = 0; i < sizeof(cuda_visible_devices); i++) {
        if (name_buf[i] != cuda_visible_devices[i]) {
            return 0;
        }
    }

    cuda_visible_devices_set_t event = { 0 };

    // Values longer than the event buffer are truncated by the helper.
    if (bpf_probe_read_user_str_with_telemetry(event.visible_devices, sizeof(event.visible_devices), value) < 0) {
        return 0;
    }

    fill_header(&event.header, 0, cuda_visible_devices_set);

    bpf_ringbuf_output_with_telemetry(&cuda_events, &event, sizeof(event), 0);
    return 0;
}
407+
369408
char __license[] SEC("license") = "GPL";

0 commit comments

Comments
 (0)