Skip to content

Commit e5c2d22

Browse files
klucketensorflower-gardener
authored andcommitted
Move GpuDriver::GetPointerMemorySpace to the appropriate Executor classes.
PiperOrigin-RevId: 685292532
1 parent 15bdfa2 commit e5c2d22

File tree

7 files changed

+43
-48
lines changed

7 files changed

+43
-48
lines changed

third_party/xla/xla/stream_executor/cuda/cuda_driver.cc

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -944,22 +944,6 @@ int GpuDriver::GetDeviceCount() {
944944
return device_count;
945945
}
946946

947-
absl::StatusOr<MemoryType> GpuDriver::GetPointerMemorySpace(
948-
CUdeviceptr pointer) {
949-
unsigned int value;
950-
TF_RETURN_IF_ERROR(cuda::ToStatus(cuPointerGetAttribute(
951-
&value, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer)));
952-
switch (value) {
953-
case CU_MEMORYTYPE_DEVICE:
954-
return MemoryType::kDevice;
955-
case CU_MEMORYTYPE_HOST:
956-
return MemoryType::kHost;
957-
default:
958-
return absl::InternalError(
959-
absl::StrCat("unknown memory space provided by CUDA API: ", value));
960-
}
961-
}
962-
963947
absl::Status GpuDriver::GetPointerAddressRange(CUdeviceptr dptr,
964948
CUdeviceptr* base,
965949
size_t* size) {

third_party/xla/xla/stream_executor/cuda/cuda_executor.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,5 +1171,22 @@ CudaExecutor::CreateDeviceDescription(int device_ordinal) {
11711171
return std::make_unique<DeviceDescription>(std::move(desc));
11721172
}
11731173

1174+
absl::StatusOr<MemoryType> CudaExecutor::GetPointerMemorySpace(
1175+
const void* ptr) {
1176+
CUdeviceptr pointer = reinterpret_cast<CUdeviceptr>(const_cast<void*>(ptr));
1177+
unsigned int value;
1178+
TF_RETURN_IF_ERROR(cuda::ToStatus(cuPointerGetAttribute(
1179+
&value, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer)));
1180+
switch (value) {
1181+
case CU_MEMORYTYPE_DEVICE:
1182+
return MemoryType::kDevice;
1183+
case CU_MEMORYTYPE_HOST:
1184+
return MemoryType::kHost;
1185+
default:
1186+
return absl::InternalError(
1187+
absl::StrCat("unknown memory space provided by CUDA API: ", value));
1188+
}
1189+
}
1190+
11741191
} // namespace gpu
11751192
} // namespace stream_executor

third_party/xla/xla/stream_executor/cuda/cuda_executor.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,7 @@ class CudaExecutor : public GpuExecutor {
139139
bool HostMemoryRegister(void* location, uint64_t size) override;
140140
bool HostMemoryUnregister(void* location) override;
141141

142-
absl::StatusOr<MemoryType> GetPointerMemorySpace(const void* ptr) override {
143-
return GpuDriver::GetPointerMemorySpace(
144-
reinterpret_cast<GpuDevicePtr>(const_cast<void*>(ptr)));
145-
}
142+
absl::StatusOr<MemoryType> GetPointerMemorySpace(const void* ptr) override;
146143

147144
Stream* FindAllocatedStream(void* gpu_stream) override {
148145
absl::MutexLock lock(&alive_gpu_streams_mu_);

third_party/xla/xla/stream_executor/gpu/gpu_driver.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -421,9 +421,6 @@ class GpuDriver {
421421

422422
// -- Pointer-specific calls.
423423

424-
// Returns the memory space addressed by pointer.
425-
static absl::StatusOr<MemoryType> GetPointerMemorySpace(GpuDevicePtr pointer);
426-
427424
// Returns the base address and size of the device pointer dptr.
428425
static absl::Status GetPointerAddressRange(GpuDevicePtr dptr,
429426
GpuDevicePtr* base, size_t* size);

third_party/xla/xla/stream_executor/rocm/rocm_driver.cc

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -766,27 +766,6 @@ absl::Status GpuDriver::GetPointerAddressRange(hipDeviceptr_t dptr,
766766
reinterpret_cast<void*>(dptr), ToString(result).c_str()));
767767
}
768768

769-
absl::StatusOr<MemoryType> GpuDriver::GetPointerMemorySpace(
770-
hipDeviceptr_t pointer) {
771-
unsigned int value;
772-
hipError_t result = wrap::hipPointerGetAttribute(
773-
&value, HIP_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer);
774-
if (result == hipSuccess) {
775-
switch (value) {
776-
case hipMemoryTypeDevice:
777-
return MemoryType::kDevice;
778-
case hipMemoryTypeHost:
779-
return MemoryType::kHost;
780-
default:
781-
return absl::InternalError(
782-
absl::StrCat("unknown memory space provided by ROCM API: ", value));
783-
}
784-
}
785-
786-
return absl::InternalError(absl::StrCat(
787-
"failed to query device pointer for memory space: ", ToString(result)));
788-
}
789-
790769
absl::StatusOr<int32_t> GpuDriver::GetDriverVersion() {
791770
int32_t version;
792771
TF_RETURN_IF_ERROR(ToStatus(wrap::hipDriverGetVersion(&version),

third_party/xla/xla/stream_executor/rocm/rocm_executor.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ limitations under the License.
3737
#include "absl/synchronization/mutex.h"
3838
#include "absl/synchronization/notification.h"
3939
#include "absl/types/span.h"
40+
#include "rocm/include/hip/driver_types.h"
4041
#include "rocm/include/hip/hip_runtime.h"
4142
#include "rocm/include/hip/hip_version.h"
4243
#include "rocm/rocm_config.h"
@@ -975,6 +976,29 @@ RocmExecutor::CreateDeviceDescription(int device_ordinal) {
975976
return std::make_unique<DeviceDescription>(std::move(desc));
976977
}
977978

979+
absl::StatusOr<MemoryType> RocmExecutor::GetPointerMemorySpace(
980+
const void* ptr) {
981+
hipDeviceptr_t pointer =
982+
reinterpret_cast<hipDeviceptr_t>(const_cast<void*>(ptr));
983+
unsigned int value;
984+
hipError_t result = wrap::hipPointerGetAttribute(
985+
&value, HIP_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer);
986+
if (result == hipSuccess) {
987+
switch (value) {
988+
case hipMemoryTypeDevice:
989+
return MemoryType::kDevice;
990+
case hipMemoryTypeHost:
991+
return MemoryType::kHost;
992+
default:
993+
return absl::InternalError(
994+
absl::StrCat("unknown memory space provided by ROCM API: ", value));
995+
}
996+
}
997+
998+
return absl::InternalError(absl::StrCat(
999+
"failed to query device pointer for memory space: ", ToString(result)));
1000+
}
1001+
9781002
} // namespace gpu
9791003

9801004
} // namespace stream_executor

third_party/xla/xla/stream_executor/rocm/rocm_executor.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,7 @@ class RocmExecutor : public GpuExecutor {
128128
return GpuDriver::HostDeallocate(gpu_context(), location);
129129
}
130130

131-
absl::StatusOr<MemoryType> GetPointerMemorySpace(const void* ptr) override {
132-
return GpuDriver::GetPointerMemorySpace(
133-
reinterpret_cast<GpuDevicePtr>(const_cast<void*>(ptr)));
134-
}
131+
absl::StatusOr<MemoryType> GetPointerMemorySpace(const void* ptr) override;
135132

136133
Stream* FindAllocatedStream(void* gpu_stream) override {
137134
absl::MutexLock lock(&alive_gpu_streams_mu_);

0 commit comments

Comments
 (0)