Skip to content

Commit 66b449b

Browse files
committed
[L0 v2] introduce raii wrapper for UR handles
Some entities (e.g. devices) do not need to be retained as they are owned by the platform. For such cases, only validate RefCount instead of actually increasing/decreasing it.
1 parent 0125279 commit 66b449b

File tree

11 files changed

+99
-48
lines changed

11 files changed

+99
-48
lines changed

source/adapters/level_zero/v2/common.hpp

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <ze_api.h>
1515

1616
#include "../common.hpp"
17+
#include "../ur_interface_loader.hpp"
1718
#include "logger/ur_logger.hpp"
1819

1920
namespace v2 {
@@ -54,8 +55,7 @@ struct ze_handle_wrapper {
5455
try {
5556
reset();
5657
} catch (...) {
57-
// TODO: add appropriate logging or pass the error
58-
// to the caller (make the dtor noexcept(false) or use tls?)
58+
// logging already done in reset
5959
}
6060
}
6161

@@ -104,5 +104,83 @@ using ze_context_handle_t =
104104
using ze_command_list_handle_t =
105105
ze_handle_wrapper<::ze_command_list_handle_t, zeCommandListDestroy>;
106106

107+
// Move-only wrapper that holds one reference to a UR handle for as long as
// the wrapper lives.
//
// Template parameters:
//   URHandle - raw UR handle type; must be testable in a boolean context and
//              assignable from nullptr.
//   retain   - invoked once when the wrapper is constructed from a non-null
//              handle.
//   release  - invoked once when a non-null handle is dropped, i.e. on
//              destruction or when move-assignment overwrites it.
//
// A moved-from wrapper holds nullptr and performs no release.
template <typename URHandle, ur_result_t (*retain)(URHandle),
          ur_result_t (*release)(URHandle)>
struct ref_counted {
  // Deliberately implicit so that raw handles (and nullptr) convert directly,
  // e.g. in default member initializers.
  //
  // Throws the ur_result_t returned by `retain` if it is not
  // UR_RESULT_SUCCESS. Because the constructor does not complete in that
  // case, the destructor never runs and no unmatched `release` is issued.
  // NOTE(review): assumes callers already translate thrown ur_result_t values
  // at the API boundary (as code using UR_CALL_THROWS does) - confirm.
  ref_counted(URHandle handle) : handle(handle) {
    if (handle) {
      if (ur_result_t result = retain(handle); result != UR_RESULT_SUCCESS) {
        throw result;
      }
    }
  }

  // Drops the held reference, if any. The result of `release` cannot be
  // propagated from a destructor and is therefore ignored.
  ~ref_counted() {
    if (handle) {
      release(handle);
    }
  }

  // Non-owning access to the underlying raw handle.
  operator URHandle() const { return handle; }
  URHandle operator->() const { return handle; }

  // Copying is disabled; use move to transfer the held reference.
  ref_counted(const ref_counted &) = delete;
  ref_counted &operator=(const ref_counted &) = delete;

  // Takes over the reference held by `other` without calling retain/release;
  // `other` is left holding nullptr.
  ref_counted(ref_counted &&other) noexcept : handle(other.handle) {
    other.handle = nullptr;
  }

  // Releases the currently held handle (if any), then takes over the
  // reference held by `other`. Self-assignment is a no-op.
  ref_counted &operator=(ref_counted &&other) {
    if (this == &other) {
      return *this;
    }

    if (handle) {
      release(handle);
    }

    handle = other.handle;
    other.handle = nullptr;
    return *this;
  }

  // Returns the raw handle without affecting the reference held.
  URHandle get() const { return handle; }

private:
  URHandle handle;
};
152+
153+
template <typename URHandle>
154+
ur_result_t validateRetain([[maybe_unused]] URHandle handle) {
155+
assert(reinterpret_cast<_ur_object *>(handle)->RefCount.load() != 0);
156+
return UR_RESULT_SUCCESS;
157+
}
158+
159+
template <typename URHandle>
160+
ur_result_t validateRelease([[maybe_unused]] URHandle handle) {
161+
assert(reinterpret_cast<_ur_object *>(handle)->RefCount.load() != 0);
162+
return UR_RESULT_SUCCESS;
163+
}
164+
165+
// Devices are owned by the platform, so we don't need to retain/release them
166+
// as long as the platform is alive.
167+
using ur_device_handle_t =
168+
ref_counted<::ur_device_handle_t, validateRetain<::ur_device_handle_t>,
169+
validateRelease<::ur_device_handle_t>>;
170+
171+
// Spec requires that the context is not destroyed until all entities
172+
// using the context are destroyed.
173+
using ur_context_handle_t =
174+
ref_counted<::ur_context_handle_t, validateRetain<::ur_context_handle_t>,
175+
validateRelease<::ur_context_handle_t>>;
176+
177+
using ur_mem_handle_t =
178+
ref_counted<::ur_mem_handle_t, ur::level_zero::urMemRetain, urMemRelease>;
179+
180+
using ur_program_handle_t =
181+
ref_counted<::ur_program_handle_t, ur::level_zero::urProgramRetain,
182+
urProgramRelease>;
183+
107184
} // namespace raii
185+
108186
} // namespace v2

source/adapters/level_zero/v2/event.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ ur_result_t ur_event_handle_t_::release() {
154154
if (isTimestamped() && !getEventEndTimestamp()) {
155155
// L0 will write end timestamp to this event some time in the future,
156156
// so we can't release it yet.
157+
158+
// If this code is being executed, queue has to be valid (queue cannot
159+
// be released before all operations complete).
157160
assert(hQueue);
158161
hQueue->deferEventFree(this);
159162
return UR_RESULT_SUCCESS;

source/adapters/level_zero/v2/event.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,14 @@ struct ur_event_handle_t_ : _ur_object {
9898
uint64_t getEventEndTimestamp();
9999

100100
protected:
101+
// Do not use ref counting on context to avoid circular dependency.
101102
ur_context_handle_t hContext;
102103

103104
// non-owning handle to the L0 event
104105
const ze_event_handle_t hZeEvent;
105106

106107
// queue and commandType that this event is associated with, set by enqueue
107-
// commands
108+
// commands. DO NOT ref count hQueue here to avoid circular references.
108109
ur_queue_handle_t hQueue = nullptr;
109110
ur_command_t commandType = UR_COMMAND_FORCE_UINT32;
110111

source/adapters/level_zero/v2/event_pool.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class event_pool {
5050
event_flags_t getFlags() const;
5151

5252
private:
53+
// Do not use ref counting on context to avoid circular dependency.
5354
ur_context_handle_t hContext;
5455
std::unique_ptr<event_provider> provider;
5556

source/adapters/level_zero/v2/event_pool_cache.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class event_pool_cache {
4141
raii::cache_borrowed_event_pool borrow(DeviceId, event_flags_t flags);
4242

4343
private:
44+
// Do not use ref counting on context to avoid circular dependency.
4445
ur_context_handle_t hContext;
4546
ur_mutex mutex;
4647
ProviderCreateFunc providerCreate;

source/adapters/level_zero/v2/kernel.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ ur_kernel_handle_t_::ur_kernel_handle_t_(ur_program_handle_t hProgram,
4040
const char *kernelName)
4141
: hProgram(hProgram),
4242
deviceKernels(hProgram->Context->getPlatform()->getNumDevices()) {
43-
ur::level_zero::urProgramRetain(hProgram);
44-
4543
for (auto &Dev : hProgram->AssociatedDevices) {
4644
auto zeDevice = Dev->ZeDevice;
4745
// Program may be associated with all devices from the context but built
@@ -75,8 +73,6 @@ ur_kernel_handle_t_::ur_kernel_handle_t_(
7573
const ur_kernel_native_properties_t *pProperties)
7674
: hProgram(hProgram),
7775
deviceKernels(context ? context->getPlatform()->getNumDevices() : 0) {
78-
ur::level_zero::urProgramRetain(hProgram);
79-
8076
auto ownZeHandle = pProperties ? pProperties->isNativeHandleOwned : false;
8177

8278
ze_kernel_handle_t zeKernel = ur_cast<ze_kernel_handle_t>(hNativeKernel);
@@ -94,19 +90,6 @@ ur_kernel_handle_t_::ur_kernel_handle_t_(
9490
completeInitialization();
9591
}
9692

97-
ur_result_t ur_kernel_handle_t_::release() {
98-
// manually release kernels to allow errors to be propagated
99-
for (auto &singleDeviceKernelOpt : deviceKernels) {
100-
if (singleDeviceKernelOpt.has_value()) {
101-
singleDeviceKernelOpt.value().hKernel.reset();
102-
}
103-
}
104-
105-
UR_CALL_THROWS(ur::level_zero::urProgramRelease(hProgram));
106-
107-
return UR_RESULT_SUCCESS;
108-
}
109-
11093
void ur_kernel_handle_t_::completeInitialization() {
11194
// Cache kernel name. Should be the same for all devices
11295
assert(deviceKernels.size() > 0);
@@ -365,7 +348,6 @@ ur_result_t urKernelRelease(
365348
if (!hKernel->RefCount.decrementAndTest())
366349
return UR_RESULT_SUCCESS;
367350

368-
hKernel->release();
369351
delete hKernel;
370352

371353
return UR_RESULT_SUCCESS;

source/adapters/level_zero/v2/kernel.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct ur_single_device_kernel_t {
2020
ze_kernel_handle_t hKernel, bool ownZeHandle);
2121
ur_result_t release();
2222

23-
ur_device_handle_t hDevice;
23+
v2::raii::ur_device_handle_t hDevice;
2424
v2::raii::ze_kernel_handle_t hKernel;
2525
mutable ZeCache<ZeStruct<ze_kernel_properties_t>> zeKernelProperties;
2626
};
@@ -74,9 +74,6 @@ struct ur_kernel_handle_t_ : _ur_object {
7474

7575
std::vector<char> getSourceAttributes() const;
7676

77-
// Perform cleanup.
78-
ur_result_t release();
79-
8077
// Add a pending memory allocation for which device is not yet known.
8178
ur_result_t
8279
addPendingMemoryAllocation(pending_memory_allocation_t allocation);
@@ -92,7 +89,7 @@ struct ur_kernel_handle_t_ : _ur_object {
9289

9390
private:
9491
// Keep the program of the kernel.
95-
const ur_program_handle_t hProgram;
92+
const v2::raii::ur_program_handle_t hProgram;
9693

9794
// Vector of ur_single_device_kernel_t indexed by deviceIndex().
9895
std::vector<std::optional<ur_single_device_kernel_t>> deviceKernels;

source/adapters/level_zero/v2/memory.cpp

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ void *ur_discrete_mem_handle_t::allocateOnDevice(ur_device_handle_t hDevice,
178178
hContext, hDevice, nullptr, UR_USM_TYPE_DEVICE, size, &ptr));
179179

180180
deviceAllocations[id] =
181-
usm_unique_ptr_t(ptr, [hContext = this->hContext](void *ptr) {
181+
usm_unique_ptr_t(ptr, [hContext = this->hContext.get()](void *ptr) {
182182
auto ret = hContext->getDefaultUSMPool()->free(ptr);
183183
if (ret != UR_RESULT_SUCCESS) {
184184
logger::error("Failed to free device memory: {}", ret);
@@ -230,7 +230,7 @@ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(
230230
devicePtr = allocateOnDevice(hDevice, size);
231231
} else {
232232
deviceAllocations[hDevice->Id.value()] = usm_unique_ptr_t(
233-
devicePtr, [hContext = this->hContext, ownZePtr](void *ptr) {
233+
devicePtr, [hContext = this->hContext.get(), ownZePtr](void *ptr) {
234234
if (!ownZePtr) {
235235
return;
236236
}
@@ -361,22 +361,11 @@ static bool useHostBuffer(ur_context_handle_t hContext) {
361361
ZE_DEVICE_PROPERTY_FLAG_INTEGRATED;
362362
}
363363

364-
namespace ur::level_zero {
365-
ur_result_t urMemRetain(ur_mem_handle_t hMem);
366-
ur_result_t urMemRelease(ur_mem_handle_t hMem);
367-
} // namespace ur::level_zero
368-
369364
ur_mem_sub_buffer_t::ur_mem_sub_buffer_t(ur_mem_handle_t hParent, size_t offset,
370365
size_t size,
371366
device_access_mode_t accessMode)
372367
: ur_mem_handle_t_(hParent->getContext(), size, accessMode),
373-
hParent(hParent), offset(offset), size(size) {
374-
ur::level_zero::urMemRetain(hParent);
375-
}
376-
377-
ur_mem_sub_buffer_t::~ur_mem_sub_buffer_t() {
378-
ur::level_zero::urMemRelease(hParent);
379-
}
368+
hParent(hParent), offset(offset), size(size) {}
380369

381370
void *ur_mem_sub_buffer_t::getDevicePtr(
382371
ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,

source/adapters/level_zero/v2/memory.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ struct ur_mem_handle_t_ : private _ur_object {
4848

4949
protected:
5050
const device_access_mode_t accessMode;
51-
const ur_context_handle_t hContext;
51+
const v2::raii::ur_context_handle_t hContext;
5252
const size_t size;
5353
};
5454

@@ -141,7 +141,7 @@ struct ur_discrete_mem_handle_t : public ur_mem_handle_t_ {
141141

142142
// Specifies device on which the latest allocation resides.
143143
// If null, there is no allocation.
144-
ur_device_handle_t activeAllocationDevice = nullptr;
144+
v2::raii::ur_device_handle_t activeAllocationDevice = nullptr;
145145

146146
// If not null, copy the buffer content back to this memory on release.
147147
void *writeBackPtr = nullptr;
@@ -157,7 +157,6 @@ struct ur_discrete_mem_handle_t : public ur_mem_handle_t_ {
157157
struct ur_mem_sub_buffer_t : public ur_mem_handle_t_ {
158158
ur_mem_sub_buffer_t(ur_mem_handle_t hParent, size_t offset, size_t size,
159159
device_access_mode_t accesMode);
160-
~ur_mem_sub_buffer_t();
161160

162161
void *
163162
getDevicePtr(ur_device_handle_t, device_access_mode_t, size_t offset,
@@ -172,7 +171,7 @@ struct ur_mem_sub_buffer_t : public ur_mem_handle_t_ {
172171
ur_shared_mutex &getMutex() override;
173172

174173
private:
175-
ur_mem_handle_t hParent;
174+
v2::raii::ur_mem_handle_t hParent;
176175
size_t offset;
177176
size_t size;
178177
};

source/adapters/level_zero/v2/queue_immediate_in_order.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ ur_queue_immediate_in_order_t::queueGetInfo(ur_queue_info_t propName,
121121
// TODO: consider support for queue properties and size
122122
switch ((uint32_t)propName) { // cast to avoid warnings on EXT enum values
123123
case UR_QUEUE_INFO_CONTEXT:
124-
return ReturnValue(hContext);
124+
return ReturnValue(hContext.get());
125125
case UR_QUEUE_INFO_DEVICE:
126-
return ReturnValue(hDevice);
126+
return ReturnValue(hDevice.get());
127127
case UR_QUEUE_INFO_REFERENCE_COUNT:
128128
return ReturnValue(uint32_t{RefCount.load()});
129129
case UR_QUEUE_INFO_FLAGS:

source/adapters/level_zero/v2/queue_immediate_in_order.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ struct ur_command_list_handler_t {
3636

3737
struct ur_queue_immediate_in_order_t : _ur_object, public ur_queue_handle_t_ {
3838
private:
39-
ur_context_handle_t hContext;
40-
ur_device_handle_t hDevice;
39+
v2::raii::ur_context_handle_t hContext;
40+
v2::raii::ur_device_handle_t hDevice;
4141
ur_queue_flags_t flags;
4242

4343
raii::cache_borrowed_event_pool eventPool;

0 commit comments

Comments
 (0)