26
26
#include < utility>
27
27
#include < vector>
28
28
29
- #include < cuda.h>
30
-
31
29
#include " arrow/gpu/cuda_internal.h"
32
30
#include " arrow/gpu/cuda_memory.h"
33
31
#include " arrow/util/checked_cast.h"
32
+ #include " arrow/util/logging.h"
34
33
35
34
namespace arrow {
36
35
@@ -273,6 +272,35 @@ bool IsCudaDevice(const Device& device) {
273
272
return device.type_name () == kCudaDeviceTypeName ;
274
273
}
275
274
275
+ Result<std::shared_ptr<Device::Stream>> CudaDevice::MakeStream (unsigned int flags) {
276
+ ARROW_ASSIGN_OR_RAISE (auto context, GetContext ());
277
+ ContextSaver set_temporary (reinterpret_cast <CUcontext>(context.get ()->handle ()));
278
+
279
+ CUstream stream;
280
+ CU_RETURN_NOT_OK (" cuStreamCreate" , cuStreamCreate (&stream, flags));
281
+ return std::shared_ptr<Device::Stream>(
282
+ new CudaDevice::Stream (context, new CUstream (stream), [](void * st) {
283
+ auto typed_stream = reinterpret_cast <CUstream*>(st);
284
+ // DCHECK_OK still evaluates its argument in release mode
285
+ // but in debug mode it'll also throw if it fails
286
+ DCHECK_OK (
287
+ internal::StatusFromCuda (cuStreamDestroy (*typed_stream), " cuStreamDestroy" ));
288
+ delete typed_stream;
289
+ }));
290
+ }
291
+
292
+ Result<std::shared_ptr<Device::Stream>> CudaDevice::WrapStream (
293
+ void * stream, Device::Stream::release_fn_t release_fn) {
294
+ if (!release_fn) {
295
+ release_fn = [](void *) {};
296
+ }
297
+
298
+ auto cu_stream = reinterpret_cast <CUstream*>(stream);
299
+ ARROW_ASSIGN_OR_RAISE (auto context, GetContext ());
300
+ return std::shared_ptr<Device::Stream>(
301
+ new CudaDevice::Stream (context, cu_stream, release_fn));
302
+ }
303
+
276
304
Result<std::shared_ptr<CudaDevice>> AsCudaDevice (const std::shared_ptr<Device>& device) {
277
305
if (IsCudaDevice (*device)) {
278
306
return checked_pointer_cast<CudaDevice>(device);
@@ -281,6 +309,48 @@ Result<std::shared_ptr<CudaDevice>> AsCudaDevice(const std::shared_ptr<Device>&
281
309
}
282
310
}
283
311
312
+ Status CudaDevice::Stream::WaitEvent (const Device::SyncEvent& event) {
313
+ auto cuda_event =
314
+ checked_cast<const CudaDevice::SyncEvent*, const Device::SyncEvent*>(&event);
315
+ if (!cuda_event) {
316
+ return Status::Invalid (" CudaDevice::Stream cannot Wait on non-cuda event" );
317
+ }
318
+
319
+ auto cu_event = cuda_event->value ();
320
+ if (!cu_event) {
321
+ return Status::Invalid (" Cuda Stream cannot wait on null event" );
322
+ }
323
+
324
+ ContextSaver set_temporary (reinterpret_cast <CUcontext>(context_.get ()->handle ()));
325
+ CU_RETURN_NOT_OK (" cuStreamWaitEvent" ,
326
+ cuStreamWaitEvent (value (), cu_event, CU_EVENT_WAIT_DEFAULT));
327
+ return Status::OK ();
328
+ }
329
+
330
+ Status CudaDevice::Stream::Synchronize () const {
331
+ ContextSaver set_temporary (reinterpret_cast <CUcontext>(context_.get ()->handle ()));
332
+ CU_RETURN_NOT_OK (" cuStreamSynchronize" , cuStreamSynchronize (value ()));
333
+ return Status::OK ();
334
+ }
335
+
336
+ Status CudaDevice::SyncEvent::Wait () {
337
+ ContextSaver set_temporary (reinterpret_cast <CUcontext>(context_.get ()->handle ()));
338
+ CU_RETURN_NOT_OK (" cuEventSynchronize" , cuEventSynchronize (value ()));
339
+ return Status::OK ();
340
+ }
341
+
342
+ Status CudaDevice::SyncEvent::Record (const Device::Stream& st, const unsigned int flags) {
343
+ auto cuda_stream = checked_cast<const CudaDevice::Stream*, const Device::Stream*>(&st);
344
+ if (!cuda_stream) {
345
+ return Status::Invalid (" CudaDevice::Event cannot record on non-cuda stream" );
346
+ }
347
+
348
+ ContextSaver set_temporary (reinterpret_cast <CUcontext>(context_.get ()->handle ()));
349
+ CU_RETURN_NOT_OK (" cuEventRecordWithFlags" ,
350
+ cuEventRecordWithFlags (value (), cuda_stream->value (), flags));
351
+ return Status::OK ();
352
+ }
353
+
284
354
// ----------------------------------------------------------------------
285
355
// CudaMemoryManager implementation
286
356
@@ -293,11 +363,35 @@ std::shared_ptr<CudaDevice> CudaMemoryManager::cuda_device() const {
293
363
return checked_pointer_cast<CudaDevice>(device_);
294
364
}
295
365
366
+ Result<std::shared_ptr<Device::SyncEvent>> CudaMemoryManager::MakeDeviceSyncEvent () {
367
+ ARROW_ASSIGN_OR_RAISE (auto context, cuda_device ()->GetContext ());
368
+ ContextSaver set_temporary (reinterpret_cast <CUcontext>(context.get ()->handle ()));
369
+
370
+ // TODO: event creation flags
371
+ CUevent ev;
372
+ CU_RETURN_NOT_OK (" cuEventCreate" , cuEventCreate (&ev, CU_EVENT_DEFAULT));
373
+
374
+ return std::shared_ptr<Device::SyncEvent>(
375
+ new CudaDevice::SyncEvent (context, new CUevent (ev), [](void * ev) {
376
+ auto typed_event = reinterpret_cast <CUevent*>(ev);
377
+ // DCHECK_OK still evaluates its argument in release mode
378
+ // but in debug mode it'll also throw if it fails
379
+ DCHECK_OK (
380
+ internal::StatusFromCuda (cuEventDestroy (*typed_event), " cuEventDestroy" ));
381
+ delete typed_event;
382
+ }));
383
+ }
384
+
296
385
Result<std::shared_ptr<Device::SyncEvent>> CudaMemoryManager::WrapDeviceSyncEvent (
297
386
void * sync_event, Device::SyncEvent::release_fn_t release_sync_event) {
298
- return nullptr ;
299
- // auto ev = reinterpret_cast<CUstream*>(sync_event);
300
- // return std::make_shared<CudaDeviceSync>(ev);
387
+ if (!release_sync_event) {
388
+ release_sync_event = [](void *) {};
389
+ }
390
+
391
+ auto ev = reinterpret_cast <CUevent*>(sync_event);
392
+ ARROW_ASSIGN_OR_RAISE (auto context, cuda_device ()->GetContext ());
393
+ return std::shared_ptr<Device::SyncEvent>(
394
+ new CudaDevice::SyncEvent (context, ev, release_sync_event));
301
395
}
302
396
303
397
Result<std::shared_ptr<io::RandomAccessFile>> CudaMemoryManager::GetBufferReader (
@@ -440,7 +534,7 @@ class CudaDeviceManager::Impl {
440
534
Status AllocateHost (int device_number, int64_t nbytes, uint8_t ** out) {
441
535
RETURN_NOT_OK (CheckDeviceNum (device_number));
442
536
ARROW_ASSIGN_OR_RAISE (auto ctx, GetContext (device_number));
443
- ContextSaver set_temporary (( CUcontext) (ctx.get ()->handle ()));
537
+ ContextSaver set_temporary (reinterpret_cast < CUcontext> (ctx.get ()->handle ()));
444
538
CU_RETURN_NOT_OK (" cuMemHostAlloc" , cuMemHostAlloc (reinterpret_cast <void **>(out),
445
539
static_cast <size_t >(nbytes),
446
540
CU_MEMHOSTALLOC_PORTABLE));
0 commit comments