Skip to content

Commit f4a4639

Browse files
beckerhetensorflower-gardener
authored andcommitted
Split GpuTimer into CUDA and ROCm specific implementations
This requires the following changes: - Move GpuEvent::Record as `RecordEvent` into `CudaStream` and `RocmStream` - Move `GpuStream::WaitFor` into `CudaExecutor` and `RocmExecutor` - `CudaStream` and `RocmStream` get a factory function instead of having an init function. - The corresponding GpuDriver functions move into the .cc files where they get called. PiperOrigin-RevId: 683837514
1 parent 7de0fbe commit f4a4639

27 files changed

+702
-410
lines changed

third_party/xla/xla/service/gpu/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2892,7 +2892,6 @@ xla_test(
28922892
"//xla/stream_executor:device_description",
28932893
"//xla/stream_executor:platform",
28942894
"//xla/stream_executor:stream_executor_h",
2895-
"//xla/stream_executor/gpu:gpu_timer",
28962895
"//xla/stream_executor/gpu:mock_gpu_executor",
28972896
"//xla/tests:filecheck",
28982897
"//xla/tests:hlo_test_base",

third_party/xla/xla/service/gpu/determinism_test.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ limitations under the License.
3232
#include "xla/service/gpu/tests/gpu_codegen_test.h"
3333
#include "xla/service/platform_util.h"
3434
#include "xla/stream_executor/device_description.h"
35-
#include "xla/stream_executor/gpu/gpu_timer.h"
3635
#include "xla/stream_executor/gpu/mock_gpu_executor.h"
3736
#include "xla/stream_executor/platform.h"
3837
#include "xla/stream_executor/stream_executor.h"

third_party/xla/xla/stream_executor/cuda/BUILD

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -647,10 +647,15 @@ cc_library(
647647
"gpu",
648648
],
649649
deps = [
650-
":cuda_driver",
650+
":cuda_status",
651651
"//xla/stream_executor:event",
652+
"//xla/stream_executor/gpu:context",
653+
"//xla/stream_executor/gpu:gpu_driver_header",
652654
"//xla/stream_executor/gpu:gpu_event",
655+
"//xla/stream_executor/gpu:gpu_types_header",
653656
"//xla/stream_executor/gpu:scoped_activate_context",
657+
"@com_google_absl//absl/base",
658+
"@com_google_absl//absl/status",
654659
"@local_config_cuda//cuda:cuda_headers",
655660
],
656661
)
@@ -964,12 +969,14 @@ cc_library(
964969
],
965970
deps = [
966971
":cuda_collectives",
972+
":cuda_driver",
967973
":cuda_event", # buildcleaner: keep
968974
":cuda_kernel", # buildcleaner: keep
969975
":cuda_platform_id",
970976
":cuda_runtime",
971977
":cuda_status",
972978
":cuda_stream",
979+
":cuda_timer",
973980
":cuda_version_parser",
974981
":delay_kernel_cuda",
975982
"//xla/stream_executor",
@@ -994,7 +1001,6 @@ cc_library(
9941001
"//xla/stream_executor/gpu:gpu_kernel_header",
9951002
"//xla/stream_executor/gpu:gpu_semaphore",
9961003
"//xla/stream_executor/gpu:gpu_stream_header",
997-
"//xla/stream_executor/gpu:gpu_timer",
9981004
"//xla/stream_executor/gpu:gpu_types_header",
9991005
"//xla/stream_executor/gpu:read_numa_node",
10001006
"//xla/stream_executor/gpu:scoped_activate_context",
@@ -1146,5 +1152,48 @@ cc_library(
11461152
"cuda-only",
11471153
"gpu",
11481154
],
1149-
deps = ["//xla/stream_executor/gpu:gpu_stream"],
1155+
deps = [
1156+
":cuda_event",
1157+
":cuda_status",
1158+
"//xla/stream_executor:event",
1159+
"//xla/stream_executor:platform",
1160+
"//xla/stream_executor:stream",
1161+
"//xla/stream_executor/gpu:context",
1162+
"//xla/stream_executor/gpu:gpu_event",
1163+
"//xla/stream_executor/gpu:gpu_executor_header",
1164+
"//xla/stream_executor/gpu:gpu_stream",
1165+
"//xla/stream_executor/gpu:scoped_activate_context",
1166+
"@com_google_absl//absl/log",
1167+
"@com_google_absl//absl/status",
1168+
"@com_google_absl//absl/status:statusor",
1169+
"@local_config_cuda//cuda:cuda_headers",
1170+
"@local_tsl//tsl/platform:errors",
1171+
"@local_tsl//tsl/platform:statusor",
1172+
],
1173+
)
1174+
1175+
cc_library(
1176+
name = "cuda_timer",
1177+
srcs = ["cuda_timer.cc"],
1178+
hdrs = ["cuda_timer.h"],
1179+
tags = [
1180+
"cuda-only",
1181+
"gpu",
1182+
],
1183+
deps = [
1184+
":cuda_status",
1185+
"//xla/stream_executor:event_based_timer",
1186+
"//xla/stream_executor/gpu:context",
1187+
"//xla/stream_executor/gpu:gpu_event",
1188+
"//xla/stream_executor/gpu:gpu_semaphore",
1189+
"//xla/stream_executor/gpu:gpu_stream",
1190+
"//xla/stream_executor/gpu:scoped_activate_context",
1191+
"@com_google_absl//absl/log",
1192+
"@com_google_absl//absl/status",
1193+
"@com_google_absl//absl/status:statusor",
1194+
"@com_google_absl//absl/time",
1195+
"@local_config_cuda//cuda:cuda_headers",
1196+
"@local_tsl//tsl/platform:errors",
1197+
"@local_tsl//tsl/platform:statusor",
1198+
],
11501199
)

third_party/xla/xla/stream_executor/cuda/cuda_driver.cc

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,25 +1005,6 @@ absl::Status GpuDriver::AddStreamCallback(Context* context, CUstream stream,
10051005
return cuda::ToStatus(cuLaunchHostFunc(stream, callback, data));
10061006
}
10071007

1008-
absl::StatusOr<GpuStreamHandle> GpuDriver::CreateStream(Context* context,
1009-
int priority) {
1010-
ScopedActivateContext activated(context);
1011-
GpuStreamHandle stream;
1012-
// If the priority is 0, then use the previous api to create the stream with
1013-
// the default priority for backward compatibility. Probably there is no
1014-
// difference in using the new api call but leaving it as is for now.
1015-
if (priority == 0) {
1016-
TF_RETURN_IF_ERROR(
1017-
cuda::ToStatus(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)));
1018-
} else {
1019-
TF_RETURN_IF_ERROR(cuda::ToStatus(
1020-
cuStreamCreateWithPriority(&stream, CU_STREAM_NON_BLOCKING, priority)));
1021-
}
1022-
1023-
VLOG(2) << "successfully created stream " << stream << " for context "
1024-
<< context << " on thread";
1025-
return stream;
1026-
}
10271008

10281009
void GpuDriver::DestroyStream(Context* context, GpuStreamHandle stream) {
10291010
if (stream == nullptr) {
@@ -1155,23 +1136,6 @@ bool GpuDriver::HostUnregister(Context* context, void* location) {
11551136
return true;
11561137
}
11571138

1158-
int GpuDriver::GetGpuStreamPriority(
1159-
Context* context, stream_executor::StreamPriority stream_priority) {
1160-
ScopedActivateContext activation(context);
1161-
if (stream_priority == stream_executor::StreamPriority::Default) {
1162-
return 0;
1163-
}
1164-
int lowest, highest;
1165-
auto status = cuda::ToStatus(cuCtxGetStreamPriorityRange(&lowest, &highest));
1166-
if (!status.ok()) {
1167-
LOG(ERROR)
1168-
<< "Could not query stream priority range. Returning default priority.";
1169-
return 0;
1170-
}
1171-
return stream_priority == stream_executor::StreamPriority::Highest ? highest
1172-
: lowest;
1173-
}
1174-
11751139
absl::Status GpuDriver::DestroyEvent(Context* context, CUevent* event) {
11761140
if (*event == nullptr) {
11771141
return absl::InvalidArgumentError("input event cannot be null");
@@ -1181,39 +1145,6 @@ absl::Status GpuDriver::DestroyEvent(Context* context, CUevent* event) {
11811145
return cuda::ToStatus(cuEventDestroy(*event), "Error destroying CUDA event");
11821146
}
11831147

1184-
absl::Status GpuDriver::RecordEvent(Context* context, CUevent event,
1185-
CUstream stream) {
1186-
ScopedActivateContext activated{context};
1187-
return cuda::ToStatus(cuEventRecord(event, stream),
1188-
"Error recording CUDA event");
1189-
}
1190-
1191-
absl::StatusOr<float> GpuDriver::GetEventElapsedTime(Context* context,
1192-
CUevent start,
1193-
CUevent stop) {
1194-
ScopedActivateContext activated{context};
1195-
// The stop event must have completed in order for cuEventElapsedTime to
1196-
// work.
1197-
auto status = cuda::ToStatus(cuEventSynchronize(stop));
1198-
if (!status.ok()) {
1199-
LOG(ERROR) << "failed to synchronize the stop event: " << status;
1200-
return false;
1201-
}
1202-
1203-
float elapsed_milliseconds;
1204-
1205-
TF_RETURN_IF_ERROR(
1206-
cuda::ToStatus(cuEventElapsedTime(&elapsed_milliseconds, start, stop)));
1207-
1208-
return elapsed_milliseconds;
1209-
}
1210-
1211-
absl::Status GpuDriver::WaitStreamOnEvent(Context* context, CUstream stream,
1212-
CUevent event) {
1213-
ScopedActivateContext activation(context);
1214-
return cuda::ToStatus(cuStreamWaitEvent(stream, event, 0 /* = flags */));
1215-
}
1216-
12171148
absl::Status GpuDriver::SynchronizeStream(Context* context, CUstream stream) {
12181149
ScopedActivateContext activated{context};
12191150
CHECK(stream != nullptr);

third_party/xla/xla/stream_executor/cuda/cuda_event.cc

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,25 @@ limitations under the License.
1515

1616
#include "xla/stream_executor/cuda/cuda_event.h"
1717

18+
#include <cstdint>
19+
20+
#include "absl/base/casts.h"
21+
#include "absl/status/status.h"
1822
#include "third_party/gpus/cuda/include/cuda.h"
19-
#include "xla/stream_executor/cuda/cuda_driver.h"
23+
#include "xla/stream_executor/cuda/cuda_status.h"
2024
#include "xla/stream_executor/event.h"
25+
#include "xla/stream_executor/gpu/context.h"
2126
#include "xla/stream_executor/gpu/scoped_activate_context.h"
2227

2328
namespace stream_executor {
2429
namespace gpu {
30+
namespace {
31+
absl::Status WaitStreamOnEvent(Context* context, CUstream stream,
32+
CUevent event) {
33+
ScopedActivateContext activation(context);
34+
return cuda::ToStatus(cuStreamWaitEvent(stream, event, 0 /* = flags */));
35+
}
36+
} // namespace
2537

2638
Event::Status CudaEvent::PollForStatus() {
2739
ScopedActivateContext activated(context());
@@ -34,5 +46,10 @@ Event::Status CudaEvent::PollForStatus() {
3446
return Event::Status::kError;
3547
}
3648

49+
absl::Status CudaEvent::WaitForEventOnExternalStream(std::intptr_t stream) {
50+
return WaitStreamOnEvent(context(), absl::bit_cast<CUstream>(stream),
51+
gpu_event());
52+
}
53+
3754
} // namespace gpu
3855
} // namespace stream_executor

third_party/xla/xla/stream_executor/cuda/cuda_event.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@ limitations under the License.
1616
#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_
1717
#define XLA_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_
1818

19+
#include <cstdint>
20+
21+
#include "absl/status/status.h"
1922
#include "xla/stream_executor/event.h"
23+
#include "xla/stream_executor/gpu/context.h"
2024
#include "xla/stream_executor/gpu/gpu_event.h"
2125

2226
namespace stream_executor::gpu {
@@ -29,6 +33,8 @@ class CudaEvent : public GpuEvent {
2933
explicit CudaEvent(Context *context) : GpuEvent(context) {}
3034

3135
Event::Status PollForStatus() override;
36+
37+
absl::Status WaitForEventOnExternalStream(std::intptr_t stream) override;
3238
};
3339

3440
} // namespace stream_executor::gpu

third_party/xla/xla/stream_executor/cuda/cuda_executor.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ limitations under the License.
5252
#include "xla/stream_executor/cuda/cuda_runtime.h"
5353
#include "xla/stream_executor/cuda/cuda_status.h"
5454
#include "xla/stream_executor/cuda/cuda_stream.h"
55+
#include "xla/stream_executor/cuda/cuda_timer.h"
5556
#include "xla/stream_executor/cuda/cuda_version_parser.h"
5657
#include "xla/stream_executor/cuda/delay_kernel.h"
5758
#include "xla/stream_executor/device_description.h"
@@ -67,7 +68,6 @@ limitations under the License.
6768
#include "xla/stream_executor/gpu/gpu_kernel.h"
6869
#include "xla/stream_executor/gpu/gpu_semaphore.h"
6970
#include "xla/stream_executor/gpu/gpu_stream.h"
70-
#include "xla/stream_executor/gpu/gpu_timer.h"
7171
#include "xla/stream_executor/gpu/gpu_types.h"
7272
#include "xla/stream_executor/gpu/read_numa_node.h"
7373
#include "xla/stream_executor/gpu/scoped_activate_context.h"
@@ -427,10 +427,10 @@ CudaExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) {
427427
}
428428
TF_ASSIGN_OR_RETURN(auto start_event, CreateGpuEvent(/*allow_timing=*/true));
429429
TF_ASSIGN_OR_RETURN(auto stop_event, CreateGpuEvent(/*allow_timing=*/true));
430-
TF_RETURN_IF_ERROR(start_event->Record(stream->gpu_stream()));
431-
return std::make_unique<GpuTimer>(gpu_context(), std::move(start_event),
432-
std::move(stop_event), stream,
433-
std::move(semaphore));
430+
TF_RETURN_IF_ERROR(stream->RecordEvent(start_event.get()));
431+
return std::make_unique<CudaTimer>(gpu_context(), std::move(start_event),
432+
std::move(stop_event), stream,
433+
std::move(semaphore));
434434
}
435435

436436
bool CudaExecutor::UnloadGpuBinary(const void* gpu_binary) {
@@ -811,9 +811,9 @@ absl::StatusOr<std::unique_ptr<Event>> CudaExecutor::CreateEvent() {
811811
absl::StatusOr<std::unique_ptr<Stream>> CudaExecutor::CreateStream(
812812
std::optional<std::variant<StreamPriority, int>> priority) {
813813
TF_ASSIGN_OR_RETURN(auto event, CreateGpuEvent(/*allow_timing=*/false));
814-
auto stream = std::make_unique<CudaStream>(this, std::move(event), priority);
814+
TF_ASSIGN_OR_RETURN(auto stream,
815+
CudaStream::Create(this, std::move(event), priority));
815816
absl::MutexLock l(&alive_gpu_streams_mu_);
816-
TF_RETURN_IF_ERROR(stream->Init());
817817
auto gpu_stream = stream->gpu_stream();
818818
alive_gpu_streams_[gpu_stream] = stream.get();
819819
return std::move(stream);

0 commit comments

Comments
 (0)