Skip to content

Commit 6c67a19

Browse files
committed
Add support for device architecture requirement
1 parent c95f17c commit 6c67a19

File tree

8 files changed

+173
-62
lines changed

8 files changed

+173
-62
lines changed

src/driver.c

Lines changed: 77 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,19 @@
3232
#define SONAME_LIBCUDA "libcuda.so.1"
3333
#define SONAME_LIBNVML "libnvidia-ml.so.1"
3434

35+
#define MAX_DEVICES 64
3536
#define REAP_TIMEOUT_MS 10
3637

3738
static int reset_cuda_environment(struct error *);
3839
static int setup_rpc_client(struct driver *);
3940
static noreturn void setup_rpc_service(struct driver *, uid_t, gid_t, pid_t);
4041
static int reap_process(struct error *, pid_t, int, bool);
4142

43+
static struct driver_device {
44+
nvmlDevice_t nvml;
45+
CUdevice cuda;
46+
} device_handles[MAX_DEVICES];
47+
4248
#define call_nvml(ctx, sym, ...) __extension__ ({ \
4349
union {void *ptr; __typeof__(&sym) fn;} u_; \
4450
nvmlReturn_t r_; \
@@ -83,7 +89,7 @@ reset_cuda_environment(struct error *err)
8389
const struct { const char *name, *value; } env[] = {
8490
{"CUDA_DISABLE_UNIFIED_MEMORY", "1"},
8591
{"CUDA_CACHE_DISABLE", "1"},
86-
{"CUDA_DEVICE_ORDER", "FASTEST_FIRST"},
92+
{"CUDA_DEVICE_ORDER", "PCI_BUS_ID"},
8793
{"CUDA_VISIBLE_DEVICES", NULL},
8894
{"CUDA_MPS_PIPE_DIRECTORY", "/dev/null"},
8995
};
@@ -418,49 +424,46 @@ driver_get_device_count_1_svc(ptr_t ctxptr, driver_get_device_count_res *res, ma
418424
}
419425

420426
int
421-
driver_get_device_handle(struct driver *ctx, unsigned int idx, driver_device_handle *dev, bool pci_order)
427+
driver_get_device(struct driver *ctx, unsigned int idx, struct driver_device **dev)
422428
{
423-
struct driver_get_device_handle_res res = {0};
429+
struct driver_get_device_res res = {0};
424430
int rv = -1;
425431

426-
if (call_rpc(ctx, &res, driver_get_device_handle_1, idx, pci_order) < 0)
432+
if (call_rpc(ctx, &res, driver_get_device_1, idx) < 0)
427433
goto fail;
428-
*dev = (driver_device_handle)res.driver_get_device_handle_res_u.handle;
434+
*dev = (struct driver_device *)res.driver_get_device_res_u.dev;
429435
rv = 0;
430436

431437
fail:
432-
xdr_free((xdrproc_t)xdr_driver_get_device_handle_res, (caddr_t)&res);
438+
xdr_free((xdrproc_t)xdr_driver_get_device_res, (caddr_t)&res);
433439
return (rv);
434440
}
435441

436442
bool_t
437-
driver_get_device_handle_1_svc(ptr_t ctxptr, u_int idx, bool_t pci_order, driver_get_device_handle_res *res, maybe_unused struct svc_req *req)
443+
driver_get_device_1_svc(ptr_t ctxptr, u_int idx, driver_get_device_res *res, maybe_unused struct svc_req *req)
438444
{
439445
struct driver *ctx = (struct driver *)ctxptr;
440-
driver_device_handle handle;
441-
CUdevice cudev;
442446
int domainid, deviceid, busid;
443447
char buf[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE];
444448

445449
memset(res, 0, sizeof(*res));
446-
if (pci_order) {
447-
if (call_nvml(ctx, nvmlDeviceGetHandleByIndex, idx, &handle) < 0)
448-
goto fail;
449-
} else {
450-
if (call_cuda(ctx, cuDeviceGet, &cudev, (int)idx) < 0)
451-
goto fail;
452-
if (call_cuda(ctx, cuDeviceGetAttribute, &domainid, CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID, cudev) < 0)
453-
goto fail;
454-
if (call_cuda(ctx, cuDeviceGetAttribute, &busid, CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, cudev) < 0)
455-
goto fail;
456-
if (call_cuda(ctx, cuDeviceGetAttribute, &deviceid, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, cudev) < 0)
457-
goto fail;
458-
snprintf(buf, sizeof(buf), "%04x:%02x:%02x.0", domainid, busid, deviceid);
459-
460-
if (call_nvml(ctx, nvmlDeviceGetHandleByPciBusId, buf, &handle) < 0)
461-
goto fail;
450+
if (idx >= MAX_DEVICES) {
451+
error_setx(ctx->err, "too many devices");
452+
goto fail;
462453
}
463-
res->driver_get_device_handle_res_u.handle = (ptr_t)handle;
454+
if (call_cuda(ctx, cuDeviceGet, &device_handles[idx].cuda, (int)idx) < 0)
455+
goto fail;
456+
if (call_cuda(ctx, cuDeviceGetAttribute, &domainid, CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID, device_handles[idx].cuda) < 0)
457+
goto fail;
458+
if (call_cuda(ctx, cuDeviceGetAttribute, &busid, CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, device_handles[idx].cuda) < 0)
459+
goto fail;
460+
if (call_cuda(ctx, cuDeviceGetAttribute, &deviceid, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device_handles[idx].cuda) < 0)
461+
goto fail;
462+
snprintf(buf, sizeof(buf), "%04x:%02x:%02x.0", domainid, busid, deviceid);
463+
if (call_nvml(ctx, nvmlDeviceGetHandleByPciBusId, buf, &device_handles[idx].nvml) < 0)
464+
goto fail;
465+
466+
res->driver_get_device_res_u.dev = (ptr_t)&device_handles[idx];
464467
return (true);
465468

466469
fail:
@@ -469,7 +472,7 @@ driver_get_device_handle_1_svc(ptr_t ctxptr, u_int idx, bool_t pci_order, driver
469472
}
470473

471474
int
472-
driver_get_device_minor(struct driver *ctx, driver_device_handle dev, unsigned int *minor)
475+
driver_get_device_minor(struct driver *ctx, struct driver_device *dev, unsigned int *minor)
473476
{
474477
struct driver_get_device_minor_res res = {0};
475478
int rv = -1;
@@ -488,10 +491,11 @@ bool_t
488491
driver_get_device_minor_1_svc(ptr_t ctxptr, ptr_t dev, driver_get_device_minor_res *res, maybe_unused struct svc_req *req)
489492
{
490493
struct driver *ctx = (struct driver *)ctxptr;
494+
struct driver_device *handle = (struct driver_device *)dev;
491495
unsigned int minor;
492496

493497
memset(res, 0, sizeof(*res));
494-
if (call_nvml(ctx, nvmlDeviceGetMinorNumber, (nvmlDevice_t)dev, &minor) < 0)
498+
if (call_nvml(ctx, nvmlDeviceGetMinorNumber, handle->nvml, &minor) < 0)
495499
goto fail;
496500
res->driver_get_device_minor_res_u.minor = minor;
497501
return (true);
@@ -502,7 +506,7 @@ driver_get_device_minor_1_svc(ptr_t ctxptr, ptr_t dev, driver_get_device_minor_r
502506
}
503507

504508
int
505-
driver_get_device_busid(struct driver *ctx, driver_device_handle dev, char **busid)
509+
driver_get_device_busid(struct driver *ctx, struct driver_device *dev, char **busid)
506510
{
507511
struct driver_get_device_busid_res res = {0};
508512
int rv = -1;
@@ -522,10 +526,11 @@ bool_t
522526
driver_get_device_busid_1_svc(ptr_t ctxptr, ptr_t dev, driver_get_device_busid_res *res, maybe_unused struct svc_req *req)
523527
{
524528
struct driver *ctx = (struct driver *)ctxptr;
529+
struct driver_device *handle = (struct driver_device *)dev;
525530
nvmlPciInfo_t pci;
526531

527532
memset(res, 0, sizeof(*res));
528-
if (call_nvml(ctx, nvmlDeviceGetPciInfo_v2, (nvmlDevice_t)dev, &pci) < 0)
533+
if (call_nvml(ctx, nvmlDeviceGetPciInfo_v2, handle->nvml, &pci) < 0)
529534
goto fail;
530535
if ((res->driver_get_device_busid_res_u.busid = xstrdup(ctx->err, pci.busId)) == NULL)
531536
goto fail;
@@ -537,7 +542,7 @@ driver_get_device_busid_1_svc(ptr_t ctxptr, ptr_t dev, driver_get_device_busid_r
537542
}
538543

539544
int
540-
driver_get_device_uuid(struct driver *ctx, driver_device_handle dev, char **uuid)
545+
driver_get_device_uuid(struct driver *ctx, struct driver_device *dev, char **uuid)
541546
{
542547
struct driver_get_device_uuid_res res = {0};
543548
int rv = -1;
@@ -557,10 +562,11 @@ bool_t
557562
driver_get_device_uuid_1_svc(ptr_t ctxptr, ptr_t dev, driver_get_device_uuid_res *res, maybe_unused struct svc_req *req)
558563
{
559564
struct driver *ctx = (struct driver *)ctxptr;
565+
struct driver_device *handle = (struct driver_device *)dev;
560566
char buf[NVML_DEVICE_UUID_BUFFER_SIZE];
561567

562568
memset(res, 0, sizeof(*res));
563-
if (call_nvml(ctx, nvmlDeviceGetUUID, (nvmlDevice_t)dev, buf, sizeof(buf)) < 0)
569+
if (call_nvml(ctx, nvmlDeviceGetUUID, handle->nvml, buf, sizeof(buf)) < 0)
564570
goto fail;
565571
if ((res->driver_get_device_uuid_res_u.uuid = xstrdup(ctx->err, buf)) == NULL)
566572
goto fail;
@@ -570,3 +576,42 @@ driver_get_device_uuid_1_svc(ptr_t ctxptr, ptr_t dev, driver_get_device_uuid_res
570576
error_to_xdr(ctx->err, res);
571577
return (true);
572578
}
579+
580+
int
581+
driver_get_device_arch(struct driver *ctx, struct driver_device *dev, char **arch)
582+
{
583+
struct driver_get_device_arch_res res = {0};
584+
int rv = -1;
585+
586+
if (call_rpc(ctx, &res, driver_get_device_arch_1, (ptr_t)dev) < 0)
587+
goto fail;
588+
if (xasprintf(ctx->err, arch, "%u.%u", res.driver_get_device_arch_res_u.arch.major,
589+
res.driver_get_device_arch_res_u.arch.minor) < 0)
590+
goto fail;
591+
rv = 0;
592+
593+
fail:
594+
xdr_free((xdrproc_t)xdr_driver_get_device_arch_res, (caddr_t)&res);
595+
return (rv);
596+
}
597+
598+
bool_t
599+
driver_get_device_arch_1_svc(ptr_t ctxptr, ptr_t dev, driver_get_device_arch_res *res, maybe_unused struct svc_req *req)
600+
{
601+
struct driver *ctx = (struct driver *)ctxptr;
602+
struct driver_device *handle = (struct driver_device *)dev;
603+
int major, minor;
604+
605+
memset(res, 0, sizeof(*res));
606+
if (call_cuda(ctx, cuDeviceGetAttribute, &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, handle->cuda) < 0)
607+
goto fail;
608+
if (call_cuda(ctx, cuDeviceGetAttribute, &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, handle->cuda) < 0)
609+
goto fail;
610+
res->driver_get_device_arch_res_u.arch.major = (unsigned int)major;
611+
res->driver_get_device_arch_res_u.arch.minor = (unsigned int)minor;
612+
return (true);
613+
614+
fail:
615+
error_to_xdr(ctx->err, res);
616+
return (true);
617+
}

src/driver.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ SVCXPRT *svcunixfd_create(int, u_int, u_int);
2121
#define SOCK_CLT 0
2222
#define SOCK_SVC 1
2323

24+
struct driver_device;
25+
2426
struct driver {
2527
struct error *err;
2628
void *cuda_dl;
@@ -31,18 +33,17 @@ struct driver {
3133
CLIENT *rpc_clt;
3234
};
3335

34-
typedef struct nvmlDevice_st *driver_device_handle;
35-
3636
void driver_program_1(struct svc_req *, register SVCXPRT *);
3737

3838
int driver_init(struct driver *, struct error *, uid_t, gid_t);
3939
int driver_shutdown(struct driver *);
4040
int driver_get_rm_version(struct driver *, char **);
4141
int driver_get_cuda_version(struct driver *, char **);
4242
int driver_get_device_count(struct driver *, unsigned int *);
43-
int driver_get_device_handle(struct driver *, unsigned int, driver_device_handle *, bool);
44-
int driver_get_device_minor(struct driver *, driver_device_handle, unsigned int *);
45-
int driver_get_device_busid(struct driver *, driver_device_handle, char **);
46-
int driver_get_device_uuid(struct driver *, driver_device_handle, char **);
43+
int driver_get_device(struct driver *, unsigned int, struct driver_device **);
44+
int driver_get_device_minor(struct driver *, struct driver_device *, unsigned int *);
45+
int driver_get_device_busid(struct driver *, struct driver_device *, char **);
46+
int driver_get_device_uuid(struct driver *, struct driver_device *, char **);
47+
int driver_get_device_arch(struct driver *, struct driver_device *, char **);
4748

4849
#endif /* HEADER_DRIVER_H */

src/driver_rpc.x

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,28 @@ union driver_get_cuda_version_res switch (int errcode) {
4141
string errmsg<>;
4242
};
4343

44+
struct driver_device_arch {
45+
unsigned int major;
46+
unsigned int minor;
47+
};
48+
49+
union driver_get_device_arch_res switch (int errcode) {
50+
case 0:
51+
driver_device_arch arch;
52+
default:
53+
string errmsg<>;
54+
};
55+
4456
union driver_get_device_count_res switch (int errcode) {
4557
case 0:
4658
unsigned int count;
4759
default:
4860
string errmsg<>;
4961
};
5062

51-
union driver_get_device_handle_res switch (int errcode) {
63+
union driver_get_device_res switch (int errcode) {
5264
case 0:
53-
ptr_t handle;
65+
ptr_t dev;
5466
default:
5567
string errmsg<>;
5668
};
@@ -83,9 +95,10 @@ program DRIVER_PROGRAM {
8395
driver_get_rm_version_res DRIVER_GET_RM_VERSION(ptr_t) = 3;
8496
driver_get_cuda_version_res DRIVER_GET_CUDA_VERSION(ptr_t) = 4;
8597
driver_get_device_count_res DRIVER_GET_DEVICE_COUNT(ptr_t) = 5;
86-
driver_get_device_handle_res DRIVER_GET_DEVICE_HANDLE(ptr_t, unsigned int, bool) = 6;
98+
driver_get_device_res DRIVER_GET_DEVICE(ptr_t, unsigned int) = 6;
8799
driver_get_device_minor_res DRIVER_GET_DEVICE_MINOR(ptr_t, ptr_t) = 7;
88100
driver_get_device_busid_res DRIVER_GET_DEVICE_BUSID(ptr_t, ptr_t) = 8;
89101
driver_get_device_uuid_res DRIVER_GET_DEVICE_UUID(ptr_t, ptr_t) = 9;
102+
driver_get_device_arch_res DRIVER_GET_DEVICE_ARCH(ptr_t, ptr_t) = 10;
90103
} = 1;
91104
} = 0x1;

src/dsl.c

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,31 +116,39 @@ evaluate_rule(char *buf, char *expr, void *ctx, const struct dsl_rule rules[], s
116116
}
117117

118118
int
119-
dsl_evaluate(struct error *err, char *expr, void *ctx, const struct dsl_rule rules[], size_t size)
119+
dsl_evaluate(struct error *err, const char *predicate, void *ctx, const struct dsl_rule rules[], size_t size)
120120
{
121+
char *ptr, *expr = NULL;
121122
char *or_expr, *and_expr;
122123
int ret = true;
124+
int rv = -1;
123125
char buf[EXPR_MAX];
124126

125-
while ((or_expr = strsep(&expr, " ")) != NULL) {
127+
if ((expr = ptr = xstrdup(err, predicate)) == NULL)
128+
goto fail;
129+
while ((or_expr = strsep(&ptr, " ")) != NULL) {
126130
if (*or_expr == '\0')
127131
continue;
128132
while ((and_expr = strsep(&or_expr, ",")) != NULL) {
129133
if (*and_expr == '\0')
130134
continue;
131135
if ((ret = evaluate_rule(buf, and_expr, ctx, rules, size)) < 0) {
132136
error_setx(err, "invalid expression");
133-
return (-1);
137+
goto fail;
134138
}
135139
if (!ret)
136140
break;
137141
}
138142
if (and_expr == NULL)
139-
return (0);
143+
break;
140144
}
141145
if (!ret) {
142146
error_setx(err, "unsatisfied condition: %s", buf);
143-
return (-1);
147+
goto fail;
144148
}
145-
return (0);
149+
rv = 0;
150+
151+
fail:
152+
free(expr);
153+
return (rv);
146154
}

src/dsl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@ struct dsl_rule {
2424
};
2525

2626
int dsl_compare_version(const char *, enum dsl_comparator, const char *);
27-
int dsl_evaluate(struct error *, char *, void *, const struct dsl_rule [], size_t);
27+
int dsl_evaluate(struct error *, const char *, void *, const struct dsl_rule [], size_t);
2828

2929
#endif /* HEADER_DSL_H */

src/nvc.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ struct nvc_driver_info {
5959
struct nvc_device {
6060
char *uuid;
6161
char *busid;
62+
char *arch;
6263
struct nvc_device_node node;
6364
};
6465

0 commit comments

Comments
 (0)