Skip to content

Commit 923fbc1

Browse files
beckerhe authored and Google-ML-Automation committed
Remove AsGpuStreamValue usage from the CUDA blas_plugin
`AsGpuStreamValue` is deprecated and needs to be inlined.

PiperOrigin-RevId: 742618884
1 parent 047108c commit 923fbc1

File tree

4 files changed

+23
-15
lines changed

4 files changed

+23
-15
lines changed

xla/stream_executor/cuda/BUILD

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -346,6 +346,7 @@ cc_library(
346346
visibility = ["//visibility:public"],
347347
deps = [
348348
":cuda_blas_utils",
349+
":cuda_compute_capability",
349350
":cuda_executor",
350351
":cuda_helpers",
351352
":cuda_platform_id",
@@ -367,13 +368,17 @@ cc_library(
367368
"//xla/stream_executor:stream_executor_h",
368369
"//xla/stream_executor/gpu:gpu_blas_lt",
369370
"//xla/stream_executor/gpu:gpu_helpers_header",
370-
"//xla/stream_executor/gpu:gpu_stream_header",
371371
"//xla/stream_executor/platform:initialize",
372372
"//xla/tsl/cuda:cublas",
373373
"//xla/tsl/cuda:cublas_lt",
374+
"//xla/tsl/platform:errors",
375+
"//xla/tsl/platform:logging",
376+
"//xla/tsl/platform:statusor",
374377
"//xla/tsl/protobuf:dnn_proto_cc",
378+
"@com_google_absl//absl/base",
375379
"@com_google_absl//absl/base:core_headers",
376380
"@com_google_absl//absl/log",
381+
"@com_google_absl//absl/log:check",
377382
"@com_google_absl//absl/status",
378383
"@com_google_absl//absl/status:statusor",
379384
"@com_google_absl//absl/strings",

xla/stream_executor/cuda/cuda_blas.cc

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,11 @@ limitations under the License.
2222
#include <string>
2323
#include <vector>
2424

25+
#include "absl/base/casts.h"
26+
#include "absl/log/check.h"
27+
#include "absl/log/log.h"
2528
#include "absl/status/status.h"
29+
#include "absl/status/statusor.h"
2630
#include "absl/strings/str_cat.h"
2731
#include "absl/strings/str_format.h"
2832
#include "absl/strings/string_view.h"
@@ -40,22 +44,21 @@ limitations under the License.
4044
#include "xla/stream_executor/activate_context.h"
4145
#include "xla/stream_executor/blas.h"
4246
#include "xla/stream_executor/cuda/cuda_blas_utils.h"
47+
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
4348
#include "xla/stream_executor/cuda/cuda_helpers.h"
4449
#include "xla/stream_executor/cuda/cuda_platform_id.h"
45-
#include "xla/stream_executor/device_description.h"
4650
#include "xla/stream_executor/device_memory.h"
4751
#include "xla/stream_executor/event_based_timer.h"
4852
#include "xla/stream_executor/gpu/gpu_helpers.h"
49-
#include "xla/stream_executor/gpu/gpu_stream.h"
5053
#include "xla/stream_executor/numeric_options.h"
5154
#include "xla/stream_executor/platform/initialize.h"
5255
#include "xla/stream_executor/plugin_registry.h"
5356
#include "xla/stream_executor/scratch_allocator.h"
5457
#include "xla/stream_executor/stream_executor.h"
58+
#include "xla/tsl/platform/errors.h"
59+
#include "xla/tsl/platform/logging.h"
60+
#include "xla/tsl/platform/statusor.h"
5561
#include "xla/tsl/protobuf/dnn.pb.h"
56-
#include "tsl/platform/errors.h"
57-
#include "tsl/platform/logging.h"
58-
#include "tsl/platform/statusor.h"
5962
#include "tsl/platform/tensor_float_32_utils.h"
6063

6164
namespace stream_executor {
@@ -229,7 +232,10 @@ bool CUDABlas::SetStream(Stream *stream) {
229232
CHECK(blas_ != nullptr);
230233
std::unique_ptr<ActivateContext> activation = parent_->Activate();
231234

232-
auto handle = (stream != nullptr) ? gpu::AsGpuStreamValue(stream) : nullptr;
235+
auto handle =
236+
(stream != nullptr)
237+
? absl::bit_cast<CUstream>(stream->platform_specific_handle().stream)
238+
: nullptr;
233239
if (auto ret = cublasSetStream(blas_, handle); ret != CUBLAS_STATUS_SUCCESS) {
234240
LOG(ERROR) << "failed to set stream for cuBLAS calls: " << ToString(ret);
235241
return false;

xla/stream_executor/cuda/cuda_blas_lt.cc

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ limitations under the License.
2727
#include <utility>
2828
#include <vector>
2929

30+
#include "absl/base/casts.h"
3031
#include "absl/log/log.h"
3132
#include "absl/status/status.h"
3233
#include "absl/status/statusor.h"
@@ -40,20 +41,17 @@ limitations under the License.
4041
#include "xla/status_macros.h"
4142
#include "xla/stream_executor/activate_context.h"
4243
#include "xla/stream_executor/blas.h"
43-
#include "xla/stream_executor/cuda/cuda_blas.h"
4444
#include "xla/stream_executor/cuda/cuda_blas_utils.h"
4545
#include "xla/stream_executor/device_memory.h"
4646
#include "xla/stream_executor/event_based_timer.h"
4747
#include "xla/stream_executor/gpu/gpu_blas_lt.h"
4848
#include "xla/stream_executor/gpu/gpu_helpers.h"
49-
#include "xla/stream_executor/gpu/gpu_stream.h"
5049
#include "xla/stream_executor/stream.h"
50+
#include "xla/tsl/platform/errors.h"
51+
#include "xla/tsl/platform/statusor.h"
5152
#include "xla/types.h"
5253
#include "xla/util.h"
5354
#include "xla/xla_data.pb.h"
54-
#include "tsl/platform/errors.h"
55-
#include "tsl/platform/ml_dtypes.h"
56-
#include "tsl/platform/statusor.h"
5755

5856
#define SET_ATTR(setter, handle, attr, value) \
5957
ToStatus(setter(handle, attr, &value, sizeof(decltype(value))), #setter)
@@ -462,7 +460,8 @@ absl::Status BlasLt::MatmulPlan::DoMatmul(
462460
blas_lt->blas_lt_.get(), op_desc_.get(), alpha, a.opaque(),
463461
a_desc_.get(), b.opaque(), b_desc_.get(), beta, args.c.opaque(),
464462
c_desc_.get(), args.d.opaque(), d_desc_.get(), palgo, workspace_addr,
465-
workspace_size, gpu::AsGpuStreamValue(stream)));
463+
workspace_size,
464+
absl::bit_cast<CUstream>(stream->platform_specific_handle().stream)));
466465
} else {
467466
return absl::InternalError("cublaslt: Invalid algorithm type");
468467
}

xla/stream_executor/cuda/cuda_blas_lt.h

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@ limitations under the License.
1818

1919
#include <cstddef>
2020
#include <memory>
21-
#include <optional>
2221
#include <type_traits>
2322
#include <utility>
2423
#include <vector>
@@ -31,7 +30,6 @@ limitations under the License.
3130
#include "third_party/gpus/cuda/include/cublas_v2.h"
3231
#include "third_party/gpus/cuda/include/library_types.h"
3332
#include "xla/stream_executor/blas.h"
34-
#include "xla/stream_executor/device_memory.h"
3533
#include "xla/stream_executor/gpu/gpu_blas_lt.h"
3634
#include "xla/stream_executor/scratch_allocator.h"
3735
#include "xla/stream_executor/stream_executor.h"

0 commit comments

Comments
 (0)