Skip to content

Commit 0101689

Browse files
klucke authored and tensorflower-gardener committed
Reverts 30b2ecd
PiperOrigin-RevId: 689151499
1 parent 25e1991 commit 0101689

File tree

5 files changed

+49
-8
lines changed

5 files changed

+49
-8
lines changed

third_party/xla/xla/stream_executor/cuda/cuda_driver.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,13 @@ int GpuDriver::GetDeviceCount() {
632632
return device_count;
633633
}
634634

635+
// Queries the driver version through cuDriverGetVersion.
//
// Returns the integer reported by the driver on success. If the driver call
// fails, returns the error status produced by cuda::ToStatus, annotated with
// "Could not get driver version".
absl::StatusOr<int32_t> GpuDriver::GetDriverVersion() {
  // Zero-initialize the out-parameter so it never holds an indeterminate
  // value, matching the call-site code this helper replaces.
  int32_t version = 0;
  TF_RETURN_IF_ERROR(cuda::ToStatus(cuDriverGetVersion(&version),
                                    "Could not get driver version"));
  return version;
}
641+
635642
absl::StatusOr<size_t> GpuDriver::GraphGetNodeCount(GpuGraphHandle graph) {
636643
size_t num_nodes;
637644
TF_RETURN_IF_ERROR(

third_party/xla/xla/stream_executor/cuda/cuda_executor.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,11 +1166,9 @@ CudaExecutor::CreateDeviceDescription(int device_ordinal) {
11661166

11671167
DeviceDescription desc;
11681168

1169-
int32_t driver_version = 0;
1170-
TF_RETURN_IF_ERROR(cuda::ToStatus(cuDriverGetVersion(&driver_version),
1171-
"Could not get driver version"));
11721169
desc.set_driver_version(
1173-
ParseCudaVersion(driver_version).value_or(SemanticVersion{0, 0, 0}));
1170+
ParseCudaVersion(GpuDriver::GetDriverVersion().value_or(0))
1171+
.value_or(SemanticVersion{0, 0, 0}));
11741172
desc.set_runtime_version(
11751173
ParseCudaVersion(CudaRuntime::GetRuntimeVersion().value_or(0))
11761174
.value_or(SemanticVersion{0, 0, 0}));

third_party/xla/xla/stream_executor/gpu/gpu_driver.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,43 @@ class GpuDriver {
282282
GpuGraphNodeHandle node,
283283
GpuGraphHandle child);
284284

285+
// The CUDA stream callback type signature.
286+
// The data passed to AddStreamCallback is subsequently passed to this
287+
// callback when it fires.
288+
//
289+
// Some notable things:
290+
// * Callbacks must not make any CUDA API calls.
291+
// * Callbacks from independent streams execute in an undefined order and may
292+
// be serialized.
293+
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gab95a78143bae7f21eebb978f91e7f3f
294+
typedef void (*StreamCallback)(void* data);
295+
296+
// Blocks the calling thread until the operations enqueued onto stream have
297+
// been completed, via cuStreamSynchronize.
298+
//
299+
// TODO(leary) if a pathological thread enqueues operations onto the stream
300+
// while another thread blocks like this, can you wind up waiting an unbounded
301+
// amount of time?
302+
//
303+
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g15e49dd91ec15991eb7c0a741beb7dad
304+
static absl::Status SynchronizeStream(Context* context,
305+
GpuStreamHandle stream);
306+
285307
// -- Context- and device-independent calls.
286308

287309
// Returns the number of visible CUDA devices via cuDeviceGetCount.
288310
// This should correspond to the set of device ordinals available.
289311
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g52b5ce05cb8c5fb6831b2c0ff2887c74
290312
static int GetDeviceCount();
313+
314+
// Returns the driver version number via cuDriverGetVersion.
315+
// This is, surprisingly, NOT the actual driver version (e.g. 331.79) but,
316+
// instead, the CUDA toolkit release number that this driver is compatible
317+
// with; e.g. 6000 (for a CUDA 6.0 compatible driver) or 6050 (for a CUDA 6.5
318+
// compatible driver).
319+
//
320+
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VERSION.html#group__CUDA__VERSION_1g8b7a10395392e049006e61bcdc8ebe71
321+
static absl::StatusOr<int32_t> GetDriverVersion();
291322
};
292323

293324
} // namespace gpu

third_party/xla/xla/stream_executor/rocm/rocm_driver.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,4 +494,11 @@ int GpuDriver::GetDeviceCount() {
494494
return device_count;
495495
}
496496

497+
// Queries the driver version through wrap::hipDriverGetVersion.
//
// Returns the integer reported by the driver on success. If the driver call
// fails, returns the error status produced by ToStatus, annotated with
// "Could not get driver version".
absl::StatusOr<int32_t> GpuDriver::GetDriverVersion() {
  // Zero-initialize the out-parameter so it never holds an indeterminate
  // value, matching the call-site code this helper replaces.
  int32_t version = 0;
  TF_RETURN_IF_ERROR(ToStatus(wrap::hipDriverGetVersion(&version),
                              "Could not get driver version"));
  return version;
}
503+
497504
} // namespace stream_executor::gpu

third_party/xla/xla/stream_executor/rocm/rocm_executor.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,11 +1054,9 @@ RocmExecutor::CreateDeviceDescription(int device_ordinal) {
10541054
desc.set_runtime_version(
10551055
ParseRocmVersion(RocmRuntime::GetRuntimeVersion().value_or(0))
10561056
.value_or(SemanticVersion{0, 0, 0}));
1057-
int32_t driver_version = 0;
1058-
TF_RETURN_IF_ERROR(ToStatus(wrap::hipDriverGetVersion(&driver_version),
1059-
"Could not get driver version"));
10601057
desc.set_driver_version(
1061-
ParseRocmVersion(driver_version).value_or(SemanticVersion{0, 0, 0}));
1058+
ParseRocmVersion(GpuDriver::GetDriverVersion().value_or(0))
1059+
.value_or(SemanticVersion{0, 0, 0}));
10621060

10631061
// It would be better to use the PCI device ID or some other truly unique
10641062
// identifier for the GPU model. But getting this requires using NVML or

0 commit comments

Comments
 (0)