Skip to content

Commit 0c84563

Browse files
committed
Weak ptr only from rc
1 parent fe4922a commit 0c84563

File tree

11 files changed

+67
-40
lines changed

11 files changed

+67
-40
lines changed

source/adapters/level_zero/v2/common.hpp

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,36 @@ struct ref_counted {
150150
URHandle handle;
151151
};
152152

153+
// This version of ref_counted does not call retain/release functions.
154+
// It is used to avoid circular references, most notably to ur_context_handle_t.
155+
// This is equivalent to just using URHandle but makes it clear that no ref
156+
// counting is expected.
157+
template <typename URHandle> struct weak {
158+
template <ur_result_t (*retain)(URHandle), ur_result_t (*release)(URHandle)>
159+
weak(const ref_counted<URHandle, retain, release> &handle)
160+
: handle(handle.get()) {}
161+
162+
operator URHandle() const { return handle; }
163+
URHandle operator->() const { return handle; }
164+
165+
weak(const weak &) = default;
166+
weak &operator=(const weak &) = default;
167+
168+
weak(weak &&other) = default;
169+
weak &operator=(weak &&other) = default;
170+
171+
URHandle get() const { return handle; }
172+
173+
private:
174+
URHandle handle;
175+
};
176+
153177
template <typename URHandle> struct ref_counted_traits;
154178

155179
#define DECLARE_REF_COUNTER_TRAITS(URHandle, retainFn, releaseFn) \
156180
template <> struct ref_counted_traits<URHandle> { \
157181
static ur_result_t retain(URHandle handle) { return retainFn(handle); } \
158182
static ur_result_t release(URHandle handle) { return releaseFn(handle); } \
159-
static ur_result_t nop(URHandle) { return UR_RESULT_SUCCESS; } \
160183
static ur_result_t validate([[maybe_unused]] URHandle handle) { \
161184
assert(reinterpret_cast<_ur_object *>(handle)->RefCount.load() != 0); \
162185
return UR_RESULT_SUCCESS; \
@@ -168,14 +191,6 @@ template <typename URHandle>
168191
using rc = ref_counted<URHandle, ref_counted_traits<URHandle>::retain,
169192
ref_counted_traits<URHandle>::release>;
170193

171-
// This version of ref_counted does not call retain/release functions.
172-
// It is used to avoid circular references, most notably to ur_context_handle_t.
173-
// This is equivalent to just using URHandle but makes it clear that no ref
174-
// counting is expected.
175-
template <typename URHandle>
176-
using weak = ref_counted<URHandle, ref_counted_traits<URHandle>::nop,
177-
ref_counted_traits<URHandle>::nop>;
178-
179194
// This version of ref_counted validates that the ref count is not zero on every
180195
// release and retain in debug mode, and does nothing in the release mode.
181196
// Used for types that should always be alibe during the adapter lifetime (e.g.

source/adapters/level_zero/v2/context.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
4949
const ur_device_handle_t *phDevices,
5050
bool ownZeContext)
5151
: commandListCache(hContext),
52-
eventPoolCache(this, phDevices[0]->Platform->getNumDevices(),
52+
eventPoolCache(v2::raii::rc_val_only<ur_context_handle_t>(this),
53+
phDevices[0]->Platform->getNumDevices(),
5354
[context = this, platform = phDevices[0]->Platform](
5455
DeviceId deviceId, v2::event_flags_t flags)
5556
-> std::unique_ptr<v2::event_provider> {
@@ -60,11 +61,14 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
6061

6162
// TODO: just use per-context id?
6263
return std::make_unique<v2::provider_normal>(
63-
context, v2::QUEUE_IMMEDIATE, flags);
64+
v2::raii::rc_val_only<ur_context_handle_t>(context),
65+
v2::QUEUE_IMMEDIATE, flags);
6466
}),
65-
nativeEventsPool(this, std::make_unique<v2::provider_normal>(
66-
this, v2::QUEUE_IMMEDIATE,
67-
v2::EVENT_FLAGS_PROFILING_ENABLED)),
67+
nativeEventsPool(v2::raii::rc_val_only<ur_context_handle_t>(this),
68+
std::make_unique<v2::provider_normal>(
69+
v2::raii::rc_val_only<ur_context_handle_t>(this),
70+
v2::QUEUE_IMMEDIATE,
71+
v2::EVENT_FLAGS_PROFILING_ENABLED)),
6872
hContext(hContext, ownZeContext),
6973
hDevices(phDevices, phDevices + numDevices),
7074
p2pAccessDevices(populateP2PDevices(

source/adapters/level_zero/v2/event.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ uint64_t *event_profiling_data_t::eventEndTimestampAddr() {
8787
return &recordEventEndTimestamp;
8888
}
8989

90-
ur_event_handle_t_::ur_event_handle_t_(ur_context_handle_t hContext,
91-
ze_event_handle_t hZeEvent,
92-
v2::event_flags_t flags)
90+
ur_event_handle_t_::ur_event_handle_t_(
91+
v2::raii::weak<ur_context_handle_t> hContext, ze_event_handle_t hZeEvent,
92+
v2::event_flags_t flags)
9393
: hContext(hContext), hZeEvent(hZeEvent), flags(flags),
9494
profilingData(hZeEvent) {}
9595

@@ -189,7 +189,7 @@ ur_context_handle_t ur_event_handle_t_::getContext() const { return hContext; }
189189
ur_command_t ur_event_handle_t_::getCommandType() const { return commandType; }
190190

191191
ur_pooled_event_t::ur_pooled_event_t(
192-
ur_context_handle_t hContext,
192+
v2::raii::weak<ur_context_handle_t> hContext,
193193
v2::raii::cache_borrowed_event eventAllocation, v2::event_pool *pool)
194194
: ur_event_handle_t_(hContext, eventAllocation.get(), pool->getFlags()),
195195
zeEvent(std::move(eventAllocation)), pool(pool) {}
@@ -200,7 +200,8 @@ ur_result_t ur_pooled_event_t::forceRelease() {
200200
}
201201

202202
ur_native_event_t::ur_native_event_t(
203-
ur_native_handle_t hNativeEvent, ur_context_handle_t hContext,
203+
ur_native_handle_t hNativeEvent,
204+
v2::raii::weak<ur_context_handle_t> hContext,
204205
const ur_event_native_properties_t *pProperties)
205206
: ur_event_handle_t_(
206207
hContext,
@@ -390,7 +391,9 @@ urEventCreateWithNativeHandle(ur_native_handle_t hNativeEvent,
390391
*phEvent = hContext->nativeEventsPool.allocate();
391392
ZE2UR_CALL(zeEventHostSignal, ((*phEvent)->getZeEvent()));
392393
} else {
393-
*phEvent = new ur_native_event_t(hNativeEvent, hContext, pProperties);
394+
*phEvent = new ur_native_event_t(
395+
hNativeEvent, v2::raii::rc_val_only<ur_context_handle_t>(hContext),
396+
pProperties);
394397
}
395398
return UR_RESULT_SUCCESS;
396399
} catch (...) {

source/adapters/level_zero/v2/event.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ struct event_profiling_data_t {
4747

4848
struct ur_event_handle_t_ : _ur_object {
4949
public:
50-
ur_event_handle_t_(ur_context_handle_t hContext, ze_event_handle_t hZeEvent,
51-
v2::event_flags_t flags);
50+
ur_event_handle_t_(v2::raii::weak<ur_context_handle_t> hContext,
51+
ze_event_handle_t hZeEvent, v2::event_flags_t flags);
5252

5353
// Set the queue and command that this event is associated with
5454
void resetQueueAndCommand(ur_queue_handle_t hQueue, ur_command_t commandType);
@@ -113,7 +113,7 @@ struct ur_event_handle_t_ : _ur_object {
113113
};
114114

115115
struct ur_pooled_event_t : ur_event_handle_t_ {
116-
ur_pooled_event_t(ur_context_handle_t hContext,
116+
ur_pooled_event_t(v2::raii::weak<ur_context_handle_t> hContext,
117117
v2::raii::cache_borrowed_event eventAllocation,
118118
v2::event_pool *pool);
119119

@@ -126,7 +126,7 @@ struct ur_pooled_event_t : ur_event_handle_t_ {
126126

127127
struct ur_native_event_t : ur_event_handle_t_ {
128128
ur_native_event_t(ur_native_handle_t hNativeEvent,
129-
ur_context_handle_t hContext,
129+
v2::raii::weak<ur_context_handle_t> hContext,
130130
const ur_event_native_properties_t *pProperties);
131131

132132
ur_result_t forceRelease() override;

source/adapters/level_zero/v2/event_pool.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,9 @@ namespace v2 {
2828

2929
class event_pool {
3030
public:
31-
// store weak reference to the queue as event_pool is part of the queue
32-
event_pool(ur_context_handle_t hContext,
31+
event_pool(raii::weak<ur_context_handle_t> hContext,
3332
std::unique_ptr<event_provider> Provider)
34-
: hContext(hContext), provider(std::move(Provider)),
33+
: hContext(std::move(hContext)), provider(std::move(Provider)),
3534
mutex(std::make_unique<std::mutex>()){};
3635

3736
event_pool(event_pool &&other) = default;

source/adapters/level_zero/v2/event_pool_cache.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313

1414
namespace v2 {
1515

16-
event_pool_cache::event_pool_cache(ur_context_handle_t hContext,
16+
event_pool_cache::event_pool_cache(raii::weak<ur_context_handle_t> hContext,
1717
size_t max_devices,
1818
ProviderCreateFunc ProviderCreate)
19-
: hContext(hContext), providerCreate(ProviderCreate) {
19+
: hContext(std::move(hContext)), providerCreate(ProviderCreate) {
2020
pools.resize(max_devices * (1ULL << EVENT_FLAGS_USED_BITS));
2121
}
2222

source/adapters/level_zero/v2/event_pool_cache.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class event_pool_cache {
3535
using ProviderCreateFunc = std::function<std::unique_ptr<event_provider>(
3636
DeviceId, event_flags_t flags)>;
3737

38-
event_pool_cache(ur_context_handle_t hContext, size_t max_devices,
38+
event_pool_cache(raii::weak<ur_context_handle_t> hContext, size_t max_devices,
3939
ProviderCreateFunc);
4040

4141
raii::cache_borrowed_event_pool borrow(DeviceId, event_flags_t flags);

source/adapters/level_zero/v2/event_provider_counter.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
namespace v2 {
2222

2323
provider_counter::provider_counter(ur_platform_handle_t platform,
24-
ur_context_handle_t context,
25-
ur_device_handle_t device) {
24+
raii::weak<ur_context_handle_t> context,
25+
raii::rc_val_only<ur_device_handle_t> device)
26+
: device(std::move(device)) {
2627
ZE2UR_CALL_THROWS(zeDriverGetExtensionFunctionAddress,
2728
(platform->ZeDriver, "zexCounterBasedEventCreate",
2829
(void **)&this->eventCreateFunc));

source/adapters/level_zero/v2/event_provider_counter.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@ typedef ze_result_t (*zexCounterBasedEventCreate)(
3434

3535
class provider_counter : public event_provider {
3636
public:
37-
provider_counter(ur_platform_handle_t platform, ur_context_handle_t,
38-
ur_device_handle_t);
37+
provider_counter(ur_platform_handle_t platform,
38+
raii::weak<ur_context_handle_t> context,
39+
raii::rc_val_only<ur_device_handle_t> device);
3940

4041
raii::cache_borrowed_event allocate() override;
4142
event_flags_t eventFlags() const override;
4243

4344
private:
45+
raii::rc_val_only<ur_device_handle_t> device;
4446
ze_context_handle_t translatedContext;
4547
ze_device_handle_t translatedDevice;
4648

source/adapters/level_zero/v2/event_provider_normal.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ class provider_pool {
4747
// supplies multi-device events for a given context
4848
class provider_normal : public event_provider {
4949
public:
50-
provider_normal(ur_context_handle_t context, queue_type qtype,
50+
provider_normal(raii::weak<ur_context_handle_t> context, queue_type qtype,
5151
event_flags_t flags)
52-
: queueType(qtype), urContext(context), flags(flags) {}
52+
: queueType(qtype), urContext(std::move(context)), flags(flags) {}
5353

5454
raii::cache_borrowed_event allocate() override;
5555
event_flags_t eventFlags() const override;

test/adapters/level_zero/v2/event_pool_test.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,17 +110,20 @@ struct EventPoolTest : public uur::urQueueTestWithParam<ProviderParams> {
110110
mockVec.push_back(device);
111111

112112
cache = std::unique_ptr<event_pool_cache>(new event_pool_cache(
113-
nullptr, MAX_DEVICES,
113+
v2::raii::rc_val_only<ur_context_handle_t>(context), MAX_DEVICES,
114114
[this, params](DeviceId, event_flags_t flags)
115115
-> std::unique_ptr<event_provider> {
116116
// normally id would be used to find the appropriate device to create the provider
117117
switch (params.provider) {
118118
case TEST_PROVIDER_COUNTER:
119-
return std::make_unique<provider_counter>(platform, context,
120-
device);
119+
return std::make_unique<provider_counter>(
120+
platform,
121+
v2::raii::rc_val_only<ur_context_handle_t>(context),
122+
device);
121123
case TEST_PROVIDER_NORMAL:
122124
return std::make_unique<provider_normal>(
123-
context, params.queue, flags);
125+
v2::raii::rc_val_only<ur_context_handle_t>(context),
126+
params.queue, flags);
124127
}
125128
return nullptr;
126129
}));

0 commit comments

Comments
 (0)