Skip to content

Commit 87b2dd7

Browse files
committed
Check for nullptr in raii wrappers
1 parent 0c84563 commit 87b2dd7

File tree

10 files changed

+71
-55
lines changed

10 files changed

+71
-55
lines changed

source/adapters/level_zero/v2/common.hpp

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -107,17 +107,9 @@ using ze_command_list_handle_t =
107107
template <typename URHandle, ur_result_t (*retain)(URHandle),
108108
ur_result_t (*release)(URHandle)>
109109
struct ref_counted {
110-
ref_counted(URHandle handle) : handle(handle) {
111-
if (handle) {
112-
retain(handle);
113-
}
114-
}
110+
ref_counted(URHandle handle) : handle(handle) { retain(handle); }
115111

116-
~ref_counted() {
117-
if (handle) {
118-
release(handle);
119-
}
120-
}
112+
~ref_counted() { release(handle); }
121113

122114
operator URHandle() const { return handle; }
123115
URHandle operator->() const { return handle; }
@@ -178,9 +170,16 @@ template <typename URHandle> struct ref_counted_traits;
178170

179171
#define DECLARE_REF_COUNTER_TRAITS(URHandle, retainFn, releaseFn) \
180172
template <> struct ref_counted_traits<URHandle> { \
181-
static ur_result_t retain(URHandle handle) { return retainFn(handle); } \
182-
static ur_result_t release(URHandle handle) { return releaseFn(handle); } \
173+
static ur_result_t retain(URHandle handle) { \
174+
assert(handle); \
175+
return retainFn(handle); \
176+
} \
177+
static ur_result_t release(URHandle handle) { \
178+
assert(handle); \
179+
return releaseFn(handle); \
180+
} \
183181
static ur_result_t validate([[maybe_unused]] URHandle handle) { \
182+
assert(handle); \
184183
assert(reinterpret_cast<_ur_object *>(handle)->RefCount.load() != 0); \
185184
return UR_RESULT_SUCCESS; \
186185
} \
@@ -200,13 +199,18 @@ using rc_val_only =
200199
ref_counted<URHandle, ref_counted_traits<URHandle>::validate,
201200
ref_counted_traits<URHandle>::validate>;
202201

203-
DECLARE_REF_COUNTER_TRAITS(::ur_device_handle_t, urDeviceRetain,
204-
urDeviceRelease);
205-
DECLARE_REF_COUNTER_TRAITS(::ur_context_handle_t, urContextRetain,
206-
urContextRelease);
207-
DECLARE_REF_COUNTER_TRAITS(::ur_mem_handle_t, urMemRetain, urMemRelease);
208-
DECLARE_REF_COUNTER_TRAITS(::ur_program_handle_t, urProgramRetain,
209-
urProgramRelease);
202+
DECLARE_REF_COUNTER_TRAITS(::ur_device_handle_t, ur::level_zero::urDeviceRetain,
203+
ur::level_zero::urDeviceRelease);
204+
DECLARE_REF_COUNTER_TRAITS(::ur_context_handle_t,
205+
ur::level_zero::urContextRetain,
206+
ur::level_zero::urContextRelease);
207+
DECLARE_REF_COUNTER_TRAITS(::ur_mem_handle_t, ur::level_zero::urMemRetain,
208+
ur::level_zero::urMemRelease);
209+
DECLARE_REF_COUNTER_TRAITS(::ur_program_handle_t,
210+
ur::level_zero::urProgramRetain,
211+
ur::level_zero::urProgramRelease);
212+
DECLARE_REF_COUNTER_TRAITS(::ur_queue_handle_t, ur::level_zero::urQueueRetain,
213+
ur::level_zero::urQueueRelease);
210214

211215
} // namespace raii
212216

source/adapters/level_zero/v2/event.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,20 +93,20 @@ ur_event_handle_t_::ur_event_handle_t_(
9393
: hContext(hContext), hZeEvent(hZeEvent), flags(flags),
9494
profilingData(hZeEvent) {}
9595

96-
void ur_event_handle_t_::resetQueueAndCommand(ur_queue_handle_t hQueue,
97-
ur_command_t commandType) {
96+
void ur_event_handle_t_::resetQueueAndCommand(
97+
v2::raii::weak<ur_queue_handle_t> hQueue, ur_command_t commandType) {
9898
this->hQueue = hQueue;
9999
this->commandType = commandType;
100100
profilingData = event_profiling_data_t(hZeEvent);
101101
}
102102

103103
void ur_event_handle_t_::recordStartTimestamp() {
104-
assert(hQueue); // queue must be set before calling this
104+
assert(hQueue.has_value()); // queue must be set before calling this
105105

106106
ur_device_handle_t hDevice;
107-
UR_CALL_THROWS(hQueue->queueGetInfo(UR_QUEUE_INFO_DEVICE, sizeof(hDevice),
108-
reinterpret_cast<void *>(&hDevice),
109-
nullptr));
107+
UR_CALL_THROWS(hQueue.value()->queueGetInfo(
108+
UR_QUEUE_INFO_DEVICE, sizeof(hDevice), reinterpret_cast<void *>(&hDevice),
109+
nullptr));
110110

111111
profilingData.recordStartTimestamp(hDevice);
112112
}
@@ -157,8 +157,8 @@ ur_result_t ur_event_handle_t_::release() {
157157

158158
// If this code is being executed, queue has to be valid (queue cannot
159159
// be released before all operations complete).
160-
assert(hQueue);
161-
hQueue->deferEventFree(this);
160+
assert(hQueue.has_value());
161+
hQueue.value()->deferEventFree(this);
162162
return UR_RESULT_SUCCESS;
163163
}
164164

@@ -182,7 +182,9 @@ ur_event_handle_t_::getEventEndTimestampAndHandle() {
182182
return {profilingData.eventEndTimestampAddr(), hZeEvent};
183183
}
184184

185-
ur_queue_handle_t ur_event_handle_t_::getQueue() const { return hQueue; }
185+
ur_queue_handle_t ur_event_handle_t_::getQueue() const {
186+
return hQueue.has_value() ? hQueue.value().get() : nullptr;
187+
}
186188

187189
ur_context_handle_t ur_event_handle_t_::getContext() const { return hContext; }
188190

source/adapters/level_zero/v2/event.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ struct ur_event_handle_t_ : _ur_object {
5151
ze_event_handle_t hZeEvent, v2::event_flags_t flags);
5252

5353
// Set the queue and command that this event is associated with
54-
void resetQueueAndCommand(ur_queue_handle_t hQueue, ur_command_t commandType);
54+
void resetQueueAndCommand(v2::raii::weak<ur_queue_handle_t> hQueue,
55+
ur_command_t commandType);
5556

5657
// releases event immediately
5758
virtual ur_result_t forceRelease() = 0;
@@ -104,8 +105,8 @@ struct ur_event_handle_t_ : _ur_object {
104105
const ze_event_handle_t hZeEvent;
105106

106107
// queue and commandType that this event is associated with, set by enqueue
107-
// commands. DO NOT ref count hQueue here to avoid circular references.
108-
ur_queue_handle_t hQueue = nullptr;
108+
// commands.
109+
std::optional<v2::raii::weak<ur_queue_handle_t>> hQueue = std::nullopt;
109110
ur_command_t commandType = UR_COMMAND_FORCE_UINT32;
110111

111112
v2::event_flags_t flags;

source/adapters/level_zero/v2/event_pool.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,6 @@ ur_pooled_event_t *event_pool::allocate() {
3434
auto event = freelist.back();
3535
freelist.pop_back();
3636

37-
#ifndef NDEBUG
38-
// Set the command type to an invalid value to catch any misuses in tests
39-
event->resetQueueAndCommand(nullptr, UR_COMMAND_FORCE_UINT32);
40-
#endif
41-
4237
return event;
4338
}
4439

source/adapters/level_zero/v2/event_provider_normal.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
namespace v2 {
2525
static constexpr int EVENTS_BURST = 64;
2626

27-
provider_pool::provider_pool(ur_context_handle_t context, queue_type queue,
28-
event_flags_t flags) {
27+
provider_pool::provider_pool(raii::weak<ur_context_handle_t> context,
28+
queue_type queue, event_flags_t flags) {
2929
ZeStruct<ze_event_pool_desc_t> desc;
3030
desc.count = EVENTS_BURST;
3131
desc.flags = ZE_EVENT_POOL_FLAG_HOST_VISIBLE;

source/adapters/level_zero/v2/event_provider_normal.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ enum queue_type {
3434

3535
class provider_pool {
3636
public:
37-
provider_pool(ur_context_handle_t, queue_type, event_flags_t flags);
37+
provider_pool(raii::weak<ur_context_handle_t>, queue_type,
38+
event_flags_t flags);
3839

3940
raii::cache_borrowed_event allocate();
4041
size_t nfree() const;

source/adapters/level_zero/v2/memory.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(
209209
device_access_mode_t accessMode)
210210
: ur_mem_handle_t_(hContext, size, accessMode),
211211
deviceAllocations(hContext->getPlatform()->getNumDevices()),
212-
activeAllocationDevice(nullptr), hostAllocations() {
212+
activeAllocationDevice(std::nullopt), hostAllocations() {
213213
if (hostPtr) {
214214
auto initialDevice = hContext->getDevices()[0];
215215
UR_CALL_THROWS(migrateBufferTo(initialDevice, hostPtr, size));
@@ -247,9 +247,9 @@ ur_discrete_mem_handle_t::~ur_discrete_mem_handle_t() {
247247
return;
248248

249249
auto srcPtr = ur_cast<char *>(
250-
deviceAllocations[activeAllocationDevice->Id.value()].get());
251-
synchronousZeCopy(hContext, activeAllocationDevice, writeBackPtr, srcPtr,
252-
getSize());
250+
deviceAllocations[activeAllocationDevice.value()->Id.value()].get());
251+
synchronousZeCopy(hContext, activeAllocationDevice.value(), writeBackPtr,
252+
srcPtr, getSize());
253253
}
254254

255255
void *ur_discrete_mem_handle_t::getDevicePtr(
@@ -269,7 +269,7 @@ void *ur_discrete_mem_handle_t::getDevicePtr(
269269
}
270270

271271
if (!hDevice) {
272-
hDevice = activeAllocationDevice;
272+
hDevice = activeAllocationDevice.value().get();
273273
}
274274

275275
char *ptr;
@@ -289,7 +289,8 @@ void *ur_discrete_mem_handle_t::getDevicePtr(
289289

290290
// TODO: see if it's better to migrate the memory to the specified device
291291
return ur_cast<char *>(
292-
deviceAllocations[activeAllocationDevice->Id.value()].get()) +
292+
deviceAllocations[activeAllocationDevice.value()->Id.value()]
293+
.get()) +
293294
offset;
294295
}
295296

@@ -308,7 +309,8 @@ void *ur_discrete_mem_handle_t::mapHostPtr(
308309
if (activeAllocationDevice && (flags & UR_MAP_FLAG_READ)) {
309310
auto srcPtr =
310311
ur_cast<char *>(
311-
deviceAllocations[activeAllocationDevice->Id.value()].get()) +
312+
deviceAllocations[activeAllocationDevice.value()->Id.value()]
313+
.get()) +
312314
offset;
313315
migrate(srcPtr, hostAllocations.back().ptr, size);
314316
}
@@ -327,7 +329,8 @@ void ur_discrete_mem_handle_t::unmapHostPtr(
327329
if (activeAllocationDevice) {
328330
devicePtr =
329331
ur_cast<char *>(
330-
deviceAllocations[activeAllocationDevice->Id.value()].get()) +
332+
deviceAllocations[activeAllocationDevice.value()->Id.value()]
333+
.get()) +
331334
hostAllocation.offset;
332335
} else if (!(hostAllocation.flags &
333336
UR_MAP_FLAG_WRITE_INVALIDATE_REGION)) {

source/adapters/level_zero/v2/memory.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ struct ur_discrete_mem_handle_t : public ur_mem_handle_t_ {
140140
std::vector<usm_unique_ptr_t> deviceAllocations;
141141

142142
// Specifies device on which the latest allocation resides.
143-
// If null, there is no allocation.
144-
v2::raii::rc_val_only<ur_device_handle_t> activeAllocationDevice = nullptr;
143+
std::optional<v2::raii::rc_val_only<ur_device_handle_t>>
144+
activeAllocationDevice = std::nullopt;
145145

146146
// If not null, copy the buffer content back to this memory on release.
147147
void *writeBackPtr = nullptr;

source/adapters/level_zero/v2/queue_immediate_in_order.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ ur_queue_immediate_in_order_t::getSignalEvent(ur_event_handle_t *hUserEvent,
106106
ur_command_t commandType) {
107107
if (hUserEvent) {
108108
*hUserEvent = eventPool->allocate();
109-
(*hUserEvent)->resetQueueAndCommand(this, commandType);
109+
(*hUserEvent)
110+
->resetQueueAndCommand(raii::rc_val_only<ur_queue_handle_t>(this),
111+
commandType);
110112
return *hUserEvent;
111113
} else {
112114
return nullptr;

test/adapters/level_zero/v2/event_pool_test.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,9 @@ TEST_P(EventPoolTest, Basic) {
165165
auto pool = cache->borrow(device->Id.value(), getParam().flags);
166166

167167
first = pool->allocate();
168-
first->resetQueueAndCommand(queue, UR_COMMAND_KERNEL_LAUNCH);
168+
first->resetQueueAndCommand(
169+
v2::raii::rc_val_only<ur_queue_handle_t>(queue),
170+
UR_COMMAND_KERNEL_LAUNCH);
169171
zeFirst = first->getZeEvent();
170172

171173
urEventRelease(first);
@@ -176,7 +178,9 @@ TEST_P(EventPoolTest, Basic) {
176178
auto pool = cache->borrow(device->Id.value(), getParam().flags);
177179

178180
second = pool->allocate();
179-
first->resetQueueAndCommand(queue, UR_COMMAND_KERNEL_LAUNCH);
181+
first->resetQueueAndCommand(
182+
v2::raii::rc_val_only<ur_queue_handle_t>(queue),
183+
UR_COMMAND_KERNEL_LAUNCH);
180184
zeSecond = second->getZeEvent();
181185

182186
urEventRelease(second);
@@ -197,7 +201,8 @@ TEST_P(EventPoolTest, Threaded) {
197201
for (int i = 0; i < 100; ++i) {
198202
events.push_back(pool->allocate());
199203
events.back()->resetQueueAndCommand(
200-
queue, UR_COMMAND_KERNEL_LAUNCH);
204+
v2::raii::rc_val_only<ur_queue_handle_t>(queue),
205+
UR_COMMAND_KERNEL_LAUNCH);
201206
}
202207
for (int i = 0; i < 100; ++i) {
203208
urEventRelease(events[i]);
@@ -216,7 +221,9 @@ TEST_P(EventPoolTest, ProviderNormalUseMostFreePool) {
216221
std::list<ur_event_handle_t> events;
217222
for (int i = 0; i < 128; ++i) {
218223
auto event = pool->allocate();
219-
event->resetQueueAndCommand(queue, UR_COMMAND_KERNEL_LAUNCH);
224+
event->resetQueueAndCommand(
225+
v2::raii::rc_val_only<ur_queue_handle_t>(queue),
226+
UR_COMMAND_KERNEL_LAUNCH);
220227
events.push_back(event);
221228
}
222229
auto frontZeHandle = events.front()->getZeEvent();
@@ -226,7 +233,8 @@ TEST_P(EventPoolTest, ProviderNormalUseMostFreePool) {
226233
}
227234
for (int i = 0; i < 8; ++i) {
228235
auto e = pool->allocate();
229-
e->resetQueueAndCommand(queue, UR_COMMAND_KERNEL_LAUNCH);
236+
e->resetQueueAndCommand(v2::raii::rc_val_only<ur_queue_handle_t>(queue),
237+
UR_COMMAND_KERNEL_LAUNCH);
230238
events.push_back(e);
231239
}
232240

0 commit comments

Comments
 (0)