Skip to content

Commit 7bbf238

Browse files
zeroshade authored and loicalleyne committed
apacheGH-37364: [C++][GPU] Add CUDA impl of Device Event/Stream (apache#37365)
### What changes are included in this PR?

Adding `CudaDevice::SyncEvent` and `CudaDevice::Stream` implementations which provide more idiomatic handling of Events and Streams.

### Are these changes tested?

Simple SyncEvent test added. More stream tests still being added.

* Closes: apache#37364

Authored-by: Matt Topol <[email protected]>
Signed-off-by: Matt Topol <[email protected]>
1 parent d504b84 commit 7bbf238

File tree

5 files changed

+299
-13
lines changed

5 files changed

+299
-13
lines changed

cpp/src/arrow/c/bridge_test.cc

+3-1
Original file line numberDiff line numberDiff line change
@@ -1222,7 +1222,9 @@ class MyDevice : public Device {
12221222

12231223
virtual ~MySyncEvent() = default;
12241224
Status Wait() override { return Status::OK(); }
1225-
Status Record(const Device::Stream&) override { return Status::OK(); }
1225+
Status Record(const Device::Stream&, const unsigned int) override {
1226+
return Status::OK();
1227+
}
12261228
};
12271229

12281230
protected:

cpp/src/arrow/device.h

+40-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#pragma once
1919

2020
#include <cstdint>
21+
#include <functional>
2122
#include <memory>
2223
#include <string>
2324

@@ -109,23 +110,54 @@ class ARROW_EXPORT Device : public std::enable_shared_from_this<Device>,
109110
/// should be trivially constructible from its device-specific counterparts.
110111
class ARROW_EXPORT Stream {
111112
public:
112-
virtual const void* get_raw() const { return NULLPTR; }
113+
using release_fn_t = std::function<void(void*)>;
114+
115+
virtual ~Stream() = default;
116+
117+
virtual const void* get_raw() const { return stream_.get(); }
113118

114119
/// \brief Make the stream wait on the provided event.
115120
///
116121
/// Tells the stream that it should wait until the synchronization
117122
/// event is completed without blocking the CPU.
118123
virtual Status WaitEvent(const SyncEvent&) = 0;
119124

125+
/// \brief Blocks the current thread until a stream's remaining tasks are completed
126+
virtual Status Synchronize() const = 0;
127+
120128
protected:
121-
Stream() = default;
122-
virtual ~Stream() = default;
129+
explicit Stream(void* stream, release_fn_t release_stream)
130+
: stream_{stream, release_stream} {}
131+
132+
std::unique_ptr<void, release_fn_t> stream_;
123133
};
124134

135+
virtual Result<std::shared_ptr<Stream>> MakeStream() { return NULLPTR; }
136+
137+
/// \brief Create a new device stream
138+
///
139+
/// This should create the appropriate stream type for the device,
140+
/// derived from Device::Stream to allow for stream ordered events
141+
/// and memory allocations.
142+
virtual Result<std::shared_ptr<Stream>> MakeStream(unsigned int flags) {
143+
return NULLPTR;
144+
}
145+
146+
/// @brief Wrap an existing device stream alongside a release function
147+
///
148+
/// @param device_stream a pointer to the stream to wrap
149+
/// @param release_fn a function to call during destruction, `nullptr` or
150+
/// a no-op function can be passed to indicate ownership is maintained
151+
/// externally
152+
virtual Result<std::shared_ptr<Stream>> WrapStream(void* device_stream,
153+
Stream::release_fn_t release_fn) {
154+
return NULLPTR;
155+
}
156+
125157
/// \brief EXPERIMENTAL: An object that provides event/stream sync primitives
126158
class ARROW_EXPORT SyncEvent {
127159
public:
128-
using release_fn_t = void (*)(void*);
160+
using release_fn_t = std::function<void(void*)>;
129161

130162
virtual ~SyncEvent() = default;
131163

@@ -134,9 +166,11 @@ class ARROW_EXPORT Device : public std::enable_shared_from_this<Device>,
134166
/// @brief Block until sync event is completed.
135167
virtual Status Wait() = 0;
136168

169+
inline Status Record(const Stream& st) { return Record(st, 0); }
170+
137171
/// @brief Record the wrapped event on the stream so it triggers
138172
/// the event when the stream gets to that point in its queue.
139-
virtual Status Record(const Stream&) = 0;
173+
virtual Status Record(const Stream&, const unsigned int flags) = 0;
140174

141175
protected:
142176
/// If creating this with a passed in event, the caller must ensure
@@ -225,7 +259,7 @@ class ARROW_EXPORT MemoryManager : public std::enable_shared_from_this<MemoryMan
225259

226260
/// \brief Wrap an event into a SyncEvent.
227261
///
228-
/// @param sync_event passed in sync_event from the imported device array.
262+
/// @param sync_event passed in sync_event (should be a pointer to the appropriate type)
229263
/// @param release_sync_event destructor to free sync_event. `nullptr` may be
230264
/// passed to indicate that no destruction/freeing is necessary
231265
virtual Result<std::shared_ptr<Device::SyncEvent>> WrapDeviceSyncEvent(

cpp/src/arrow/gpu/cuda_context.cc

+100-6
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,10 @@
2626
#include <utility>
2727
#include <vector>
2828

29-
#include <cuda.h>
30-
3129
#include "arrow/gpu/cuda_internal.h"
3230
#include "arrow/gpu/cuda_memory.h"
3331
#include "arrow/util/checked_cast.h"
32+
#include "arrow/util/logging.h"
3433

3534
namespace arrow {
3635

@@ -273,6 +272,35 @@ bool IsCudaDevice(const Device& device) {
273272
return device.type_name() == kCudaDeviceTypeName;
274273
}
275274

275+
Result<std::shared_ptr<Device::Stream>> CudaDevice::MakeStream(unsigned int flags) {
276+
ARROW_ASSIGN_OR_RAISE(auto context, GetContext());
277+
ContextSaver set_temporary(reinterpret_cast<CUcontext>(context.get()->handle()));
278+
279+
CUstream stream;
280+
CU_RETURN_NOT_OK("cuStreamCreate", cuStreamCreate(&stream, flags));
281+
return std::shared_ptr<Device::Stream>(
282+
new CudaDevice::Stream(context, new CUstream(stream), [](void* st) {
283+
auto typed_stream = reinterpret_cast<CUstream*>(st);
284+
// DCHECK_OK still evaluates its argument in release mode
285+
// but in debug mode it'll also throw if it fails
286+
DCHECK_OK(
287+
internal::StatusFromCuda(cuStreamDestroy(*typed_stream), "cuStreamDestroy"));
288+
delete typed_stream;
289+
}));
290+
}
291+
292+
Result<std::shared_ptr<Device::Stream>> CudaDevice::WrapStream(
293+
void* stream, Device::Stream::release_fn_t release_fn) {
294+
if (!release_fn) {
295+
release_fn = [](void*) {};
296+
}
297+
298+
auto cu_stream = reinterpret_cast<CUstream*>(stream);
299+
ARROW_ASSIGN_OR_RAISE(auto context, GetContext());
300+
return std::shared_ptr<Device::Stream>(
301+
new CudaDevice::Stream(context, cu_stream, release_fn));
302+
}
303+
276304
Result<std::shared_ptr<CudaDevice>> AsCudaDevice(const std::shared_ptr<Device>& device) {
277305
if (IsCudaDevice(*device)) {
278306
return checked_pointer_cast<CudaDevice>(device);
@@ -281,6 +309,48 @@ Result<std::shared_ptr<CudaDevice>> AsCudaDevice(const std::shared_ptr<Device>&
281309
}
282310
}
283311

312+
Status CudaDevice::Stream::WaitEvent(const Device::SyncEvent& event) {
313+
auto cuda_event =
314+
checked_cast<const CudaDevice::SyncEvent*, const Device::SyncEvent*>(&event);
315+
if (!cuda_event) {
316+
return Status::Invalid("CudaDevice::Stream cannot Wait on non-cuda event");
317+
}
318+
319+
auto cu_event = cuda_event->value();
320+
if (!cu_event) {
321+
return Status::Invalid("Cuda Stream cannot wait on null event");
322+
}
323+
324+
ContextSaver set_temporary(reinterpret_cast<CUcontext>(context_.get()->handle()));
325+
CU_RETURN_NOT_OK("cuStreamWaitEvent",
326+
cuStreamWaitEvent(value(), cu_event, CU_EVENT_WAIT_DEFAULT));
327+
return Status::OK();
328+
}
329+
330+
Status CudaDevice::Stream::Synchronize() const {
331+
ContextSaver set_temporary(reinterpret_cast<CUcontext>(context_.get()->handle()));
332+
CU_RETURN_NOT_OK("cuStreamSynchronize", cuStreamSynchronize(value()));
333+
return Status::OK();
334+
}
335+
336+
Status CudaDevice::SyncEvent::Wait() {
337+
ContextSaver set_temporary(reinterpret_cast<CUcontext>(context_.get()->handle()));
338+
CU_RETURN_NOT_OK("cuEventSynchronize", cuEventSynchronize(value()));
339+
return Status::OK();
340+
}
341+
342+
Status CudaDevice::SyncEvent::Record(const Device::Stream& st, const unsigned int flags) {
343+
auto cuda_stream = checked_cast<const CudaDevice::Stream*, const Device::Stream*>(&st);
344+
if (!cuda_stream) {
345+
return Status::Invalid("CudaDevice::Event cannot record on non-cuda stream");
346+
}
347+
348+
ContextSaver set_temporary(reinterpret_cast<CUcontext>(context_.get()->handle()));
349+
CU_RETURN_NOT_OK("cuEventRecordWithFlags",
350+
cuEventRecordWithFlags(value(), cuda_stream->value(), flags));
351+
return Status::OK();
352+
}
353+
284354
// ----------------------------------------------------------------------
285355
// CudaMemoryManager implementation
286356

@@ -293,11 +363,35 @@ std::shared_ptr<CudaDevice> CudaMemoryManager::cuda_device() const {
293363
return checked_pointer_cast<CudaDevice>(device_);
294364
}
295365

366+
Result<std::shared_ptr<Device::SyncEvent>> CudaMemoryManager::MakeDeviceSyncEvent() {
367+
ARROW_ASSIGN_OR_RAISE(auto context, cuda_device()->GetContext());
368+
ContextSaver set_temporary(reinterpret_cast<CUcontext>(context.get()->handle()));
369+
370+
// TODO: event creation flags
371+
CUevent ev;
372+
CU_RETURN_NOT_OK("cuEventCreate", cuEventCreate(&ev, CU_EVENT_DEFAULT));
373+
374+
return std::shared_ptr<Device::SyncEvent>(
375+
new CudaDevice::SyncEvent(context, new CUevent(ev), [](void* ev) {
376+
auto typed_event = reinterpret_cast<CUevent*>(ev);
377+
// DCHECK_OK still evaluates its argument in release mode
378+
// but in debug mode it'll also throw if it fails
379+
DCHECK_OK(
380+
internal::StatusFromCuda(cuEventDestroy(*typed_event), "cuEventDestroy"));
381+
delete typed_event;
382+
}));
383+
}
384+
296385
Result<std::shared_ptr<Device::SyncEvent>> CudaMemoryManager::WrapDeviceSyncEvent(
297386
void* sync_event, Device::SyncEvent::release_fn_t release_sync_event) {
298-
return nullptr;
299-
// auto ev = reinterpret_cast<CUstream*>(sync_event);
300-
// return std::make_shared<CudaDeviceSync>(ev);
387+
if (!release_sync_event) {
388+
release_sync_event = [](void*) {};
389+
}
390+
391+
auto ev = reinterpret_cast<CUevent*>(sync_event);
392+
ARROW_ASSIGN_OR_RAISE(auto context, cuda_device()->GetContext());
393+
return std::shared_ptr<Device::SyncEvent>(
394+
new CudaDevice::SyncEvent(context, ev, release_sync_event));
301395
}
302396

303397
Result<std::shared_ptr<io::RandomAccessFile>> CudaMemoryManager::GetBufferReader(
@@ -440,7 +534,7 @@ class CudaDeviceManager::Impl {
440534
Status AllocateHost(int device_number, int64_t nbytes, uint8_t** out) {
441535
RETURN_NOT_OK(CheckDeviceNum(device_number));
442536
ARROW_ASSIGN_OR_RAISE(auto ctx, GetContext(device_number));
443-
ContextSaver set_temporary((CUcontext)(ctx.get()->handle()));
537+
ContextSaver set_temporary(reinterpret_cast<CUcontext>(ctx.get()->handle()));
444538
CU_RETURN_NOT_OK("cuMemHostAlloc", cuMemHostAlloc(reinterpret_cast<void**>(out),
445539
static_cast<size_t>(nbytes),
446540
CU_MEMHOSTALLOC_PORTABLE));

cpp/src/arrow/gpu/cuda_context.h

+97
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include <memory>
2222
#include <string>
2323

24+
#include <cuda.h>
25+
2426
#include "arrow/device.h"
2527
#include "arrow/result.h"
2628
#include "arrow/util/visibility.h"
@@ -140,6 +142,90 @@ class ARROW_EXPORT CudaDevice : public Device {
140142
/// \param[in] size The buffer size in bytes
141143
Result<std::shared_ptr<CudaHostBuffer>> AllocateHostBuffer(int64_t size);
142144

145+
/// \brief EXPERIMENTAL: Wrapper for CUstreams
146+
///
147+
/// Does not *own* the CUstream object which must be separately constructed
148+
/// and freed using cuStreamCreate and cuStreamDestroy (or equivalent).
149+
/// Default construction will use the cuda default stream, and does not allow
150+
/// construction from literal 0 or nullptr.
151+
class ARROW_EXPORT Stream : public Device::Stream {
152+
public:
153+
~Stream() = default;
154+
155+
[[nodiscard]] inline CUstream value() const noexcept {
156+
if (!stream_) {
157+
return CUstream{};
158+
}
159+
return *reinterpret_cast<CUstream*>(stream_.get());
160+
}
161+
operator CUstream() const noexcept { return value(); }
162+
163+
const void* get_raw() const noexcept override { return stream_.get(); }
164+
Status WaitEvent(const Device::SyncEvent&) override;
165+
Status Synchronize() const override;
166+
167+
protected:
168+
friend class CudaDevice;
169+
170+
explicit Stream(std::shared_ptr<CudaContext> ctx, CUstream* st,
171+
Device::Stream::release_fn_t release_fn)
172+
: Device::Stream(reinterpret_cast<void*>(st), release_fn),
173+
context_{std::move(ctx)} {}
174+
175+
// disable construction from literal 0
176+
explicit Stream(std::shared_ptr<CudaContext>, int,
177+
Device::Stream::release_fn_t) = delete; // Prevent cast from 0
178+
explicit Stream(std::shared_ptr<CudaContext>, std::nullptr_t,
179+
Device::Stream::release_fn_t) = delete; // Prevent cast from nullptr
180+
181+
private:
182+
std::shared_ptr<CudaContext> context_;
183+
};
184+
185+
Result<std::shared_ptr<Device::Stream>> MakeStream() override { return MakeStream(0); }
186+
187+
/// \brief Create a CUstream wrapper in the current context
188+
Result<std::shared_ptr<Device::Stream>> MakeStream(unsigned int flags) override;
189+
190+
/// @brief Wrap a pointer to an existing stream
191+
///
192+
/// @param device_stream passed in stream (should be a CUstream*)
193+
/// @param release_fn destructor to free the stream. `nullptr` may be passed
194+
/// to indicate there is no destruction/freeing necessary.
195+
Result<std::shared_ptr<Device::Stream>> WrapStream(
196+
void* device_stream, Stream::release_fn_t release_fn) override;
197+
198+
class ARROW_EXPORT SyncEvent : public Device::SyncEvent {
199+
public:
200+
[[nodiscard]] CUevent value() const {
201+
if (sync_event_) {
202+
return *static_cast<CUevent*>(sync_event_.get());
203+
}
204+
return CUevent{};
205+
}
206+
operator CUevent() const noexcept { return value(); }
207+
208+
/// @brief Block until the sync event is marked completed
209+
Status Wait() override;
210+
211+
/// @brief Record the wrapped event on the stream
212+
///
213+
/// Once the stream completes the tasks previously added to it,
214+
/// it will trigger the event.
215+
Status Record(const Device::Stream&, const unsigned int) override;
216+
217+
protected:
218+
friend class CudaMemoryManager;
219+
220+
explicit SyncEvent(std::shared_ptr<CudaContext> ctx, CUevent* ev,
221+
Device::SyncEvent::release_fn_t release_ev)
222+
: Device::SyncEvent(reinterpret_cast<void*>(ev), release_ev),
223+
context_{std::move(ctx)} {}
224+
225+
private:
226+
std::shared_ptr<CudaContext> context_;
227+
};
228+
143229
protected:
144230
struct Impl;
145231

@@ -179,6 +265,17 @@ class ARROW_EXPORT CudaMemoryManager : public MemoryManager {
179265
/// having to cast the `device()` result.
180266
std::shared_ptr<CudaDevice> cuda_device() const;
181267

268+
/// \brief Creates a wrapped CUevent.
269+
///
270+
/// Calls cuEventCreate to create the event; cuEventDestroy is called
271+
/// automatically when the returned event is destroyed.
272+
Result<std::shared_ptr<Device::SyncEvent>> MakeDeviceSyncEvent() override;
273+
274+
/// \brief Wraps an existing event into a sync event.
275+
///
276+
/// @param sync_event the event to wrap, must be a CUevent*
277+
/// @param release_sync_event a function to call during destruction, `nullptr` or
278+
/// a no-op function can be passed to indicate ownership is maintained externally
182279
Result<std::shared_ptr<Device::SyncEvent>> WrapDeviceSyncEvent(
183280
void* sync_event, Device::SyncEvent::release_fn_t release_sync_event) override;
184281

0 commit comments

Comments
 (0)