@@ -444,7 +444,7 @@ driver_get_device_1_svc(ptr_t ctxptr, u_int idx, driver_get_device_res *res, may
444
444
{
445
445
struct driver * ctx = (struct driver * )ctxptr ;
446
446
int domainid , deviceid , busid ;
447
- char buf [NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE ];
447
+ char buf [NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE + 1 ];
448
448
449
449
memset (res , 0 , sizeof (* res ));
450
450
if (idx >= MAX_DEVICES ) {
@@ -459,8 +459,8 @@ driver_get_device_1_svc(ptr_t ctxptr, u_int idx, driver_get_device_res *res, may
459
459
goto fail ;
460
460
if (call_cuda (ctx , cuDeviceGetAttribute , & deviceid , CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID , device_handles [idx ].cuda ) < 0 )
461
461
goto fail ;
462
- snprintf (buf , sizeof (buf ), "%04x :%02x:%02x.0" , domainid , busid , deviceid );
463
- if (call_nvml (ctx , nvmlDeviceGetHandleByPciBusId , buf , & device_handles [idx ].nvml ) < 0 )
462
+ snprintf (buf , sizeof (buf ), "%08x :%02x:%02x.0" , domainid , busid , deviceid );
463
+ if (call_nvml (ctx , nvmlDeviceGetHandleByPciBusId_v2 , buf , & device_handles [idx ].nvml ) < 0 )
464
464
goto fail ;
465
465
466
466
res -> driver_get_device_res_u .dev = (ptr_t )& device_handles [idx ];
@@ -527,12 +527,16 @@ driver_get_device_busid_1_svc(ptr_t ctxptr, ptr_t dev, driver_get_device_busid_r
527
527
{
528
528
struct driver * ctx = (struct driver * )ctxptr ;
529
529
struct driver_device * handle = (struct driver_device * )dev ;
530
- nvmlPciInfo_t pci ;
530
+ int domainid , deviceid , busid ;
531
531
532
532
memset (res , 0 , sizeof (* res ));
533
- if (call_nvml (ctx , nvmlDeviceGetPciInfo_v2 , handle -> nvml , & pci ) < 0 )
533
+ if (call_cuda (ctx , cuDeviceGetAttribute , & domainid , CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID , handle -> cuda ) < 0 )
534
+ goto fail ;
535
+ if (call_cuda (ctx , cuDeviceGetAttribute , & busid , CU_DEVICE_ATTRIBUTE_PCI_BUS_ID , handle -> cuda ) < 0 )
536
+ goto fail ;
537
+ if (call_cuda (ctx , cuDeviceGetAttribute , & deviceid , CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID , handle -> cuda ) < 0 )
534
538
goto fail ;
535
- if (( res -> driver_get_device_busid_res_u .busid = xstrdup ( ctx -> err , pci . busId )) == NULL )
539
+ if (xasprintf ( ctx -> err , & res -> driver_get_device_busid_res_u .busid , "%08x:%02x:%02x.0" , domainid , busid , deviceid ) < 0 )
536
540
goto fail ;
537
541
return (true);
538
542
0 commit comments