Skip to content

Commit e66d834

Browse files
Sergio0694jkotas
andauthored
Implement 'ConditionalWeakTable<TKey,TValue>.GetOrAdd' APIs (#111204)
* Add '[EditorBrowsable(Never)]' to APIs * Add 'GetOrAdd' API * Add 'GetOrAdd' API * Add 'GetOrAdd' API * Update ref assembly * Add unit tests * Add XML docs for new APIs * Remove 'Atomically' to clarify docs * Convert uses to new APIs * Remove leftover unused method * Switch 'GetOrCreateComInterfaceForObject' to local type * Apply suggestions from code review * Lower threshold time for new tests --------- Co-authored-by: Jan Kotas <[email protected]>
1 parent d9b7515 commit e66d834

File tree

10 files changed

+344
-49
lines changed

10 files changed

+344
-49
lines changed

src/coreclr/nativeaot/System.Private.CoreLib/src/System/Reflection/Runtime/Dispensers/DispenserThatReusesAsLongAsKeyIsAlive.cs

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,16 @@ internal sealed class DispenserThatReusesAsLongAsKeyIsAlive<K, [DynamicallyAcces
1414
{
1515
public DispenserThatReusesAsLongAsKeyIsAlive(Func<K, V> factory)
1616
{
17-
_createValueCallback = CreateValue;
1817
_conditionalWeakTable = new ConditionalWeakTable<K, V>();
1918
_factory = factory;
2019
}
2120

2221
public sealed override V GetOrAdd(K key)
2322
{
24-
return _conditionalWeakTable.GetValue(key, _createValueCallback);
25-
}
26-
27-
private V CreateValue(K key)
28-
{
29-
return _factory(key);
23+
return _conditionalWeakTable.GetOrAdd(key, _factory);
3024
}
3125

3226
private readonly Func<K, V> _factory;
3327
private readonly ConditionalWeakTable<K, V> _conditionalWeakTable;
34-
private readonly ConditionalWeakTable<K, V>.CreateValueCallback _createValueCallback;
3528
}
3629
}

src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,13 @@ public void DisconnectTracker()
701701
}
702702
}
703703

704+
// Custom type instead of a value tuple to avoid rooting 'ITuple' and other value tuple stuff
705+
private struct GetOrCreateComInterfaceForObjectParameters
706+
{
707+
public ComWrappers? This;
708+
public CreateComInterfaceFlags Flags;
709+
}
710+
704711
/// <summary>
705712
/// Create a COM representation of the supplied object that can be passed to a non-managed environment.
706713
/// </summary>
@@ -716,18 +723,12 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom
716723
{
717724
ArgumentNullException.ThrowIfNull(instance);
718725

719-
ManagedObjectWrapperHolder? managedObjectWrapper;
720-
if (_managedObjectWrapperTable.TryGetValue(instance, out managedObjectWrapper))
726+
ManagedObjectWrapperHolder managedObjectWrapper = _managedObjectWrapperTable.GetOrAdd(instance, static (c, items) =>
721727
{
722-
managedObjectWrapper.AddRef();
723-
return managedObjectWrapper.ComIp;
724-
}
725-
726-
managedObjectWrapper = _managedObjectWrapperTable.GetValue(instance, (c) =>
727-
{
728-
ManagedObjectWrapper* value = CreateManagedObjectWrapper(c, flags);
728+
ManagedObjectWrapper* value = items.This!.CreateManagedObjectWrapper(c, items.Flags);
729729
return new ManagedObjectWrapperHolder(value, c);
730-
});
730+
}, new GetOrCreateComInterfaceForObjectParameters { This = this, Flags = flags });
731+
731732
managedObjectWrapper.AddRef();
732733
return managedObjectWrapper.ComIp;
733734
}
@@ -1069,15 +1070,11 @@ private void RegisterWrapperForObject(NativeObjectWrapper wrapper, object comPro
10691070
Debug.Assert(wrapper.ProxyHandle.Target == comProxy);
10701071
Debug.Assert(wrapper.IsUniqueInstance || _rcwCache.FindProxyForComInstance(wrapper.ExternalComObject) == comProxy);
10711072

1072-
if (s_nativeObjectWrapperTable.TryGetValue(comProxy, out NativeObjectWrapper? registeredWrapper)
1073-
&& registeredWrapper != wrapper)
1074-
{
1075-
Debug.Assert(registeredWrapper.ExternalComObject != wrapper.ExternalComObject);
1076-
wrapper.Release();
1077-
throw new NotSupportedException();
1078-
}
1073+
// Add the input wrapper bound to the COM proxy, if there isn't one already. If another thread raced
1074+
// against this one and this lost, we'd get the wrapper added from that thread instead.
1075+
NativeObjectWrapper registeredWrapper = s_nativeObjectWrapperTable.GetOrAdd(comProxy, wrapper);
10791076

1080-
registeredWrapper = GetValueFromRcwTable(comProxy, wrapper);
1077+
// We lost the race, so we cannot register the incoming wrapper with the target object
10811078
if (registeredWrapper != wrapper)
10821079
{
10831080
Debug.Assert(registeredWrapper.ExternalComObject != wrapper.ExternalComObject);
@@ -1091,9 +1088,6 @@ private void RegisterWrapperForObject(NativeObjectWrapper wrapper, object comPro
10911088
// TrackerObjectManager and we could end up missing a section of the object graph.
10921089
// This cache deduplicates, so it is okay that the wrapper will be registered multiple times.
10931090
AddWrapperToReferenceTrackerHandleCache(registeredWrapper);
1094-
1095-
// Separate out into a local function to avoid the closure and delegate allocation unless we need it.
1096-
static NativeObjectWrapper GetValueFromRcwTable(object userObject, NativeObjectWrapper newWrapper) => s_nativeObjectWrapperTable.GetValue(userObject, _ => newWrapper);
10971091
}
10981092

10991093
private static void AddWrapperToReferenceTrackerHandleCache(NativeObjectWrapper wrapper)

src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ObjectiveCMarshal.NativeAot.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ private static IntPtr CreateReferenceTrackingHandleInternal(
136136
throw new InvalidOperationException(SR.InvalidOperation_ObjectiveCTypeNoFinalizer);
137137
}
138138

139-
var trackerInfo = s_objects.GetValue(obj, static o => new ObjcTrackingInformation());
139+
var trackerInfo = s_objects.GetOrAdd(obj, static o => new ObjcTrackingInformation());
140140
trackerInfo.EnsureInitialized(obj);
141141
trackerInfo.GetTaggedMemory(out memInSizeT, out mem);
142142
return RuntimeImports.RhHandleAllocRefCounted(obj);

src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/PInvokeMarshal.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,15 @@ public static unsafe IntPtr GetFunctionPointerForDelegate(Delegate del)
7070
//
7171
// Marshalling a managed delegate created from managed code into a native function pointer
7272
//
73-
return GetPInvokeDelegates().GetValue(del, s_AllocateThunk ??= AllocateThunk).Thunk;
73+
return GetPInvokeDelegates().GetOrAdd(del, s_AllocateThunk ??= AllocateThunk).Thunk;
7474
}
7575
}
7676

7777
/// <summary>
7878
/// Used to lookup whether a delegate already has thunk allocated for it
7979
/// </summary>
8080
private static ConditionalWeakTable<Delegate, PInvokeDelegateThunk> s_pInvokeDelegates;
81-
private static ConditionalWeakTable<Delegate, PInvokeDelegateThunk>.CreateValueCallback s_AllocateThunk;
81+
private static Func<Delegate, PInvokeDelegateThunk> s_AllocateThunk;
8282

8383
private static ConditionalWeakTable<Delegate, PInvokeDelegateThunk> GetPInvokeDelegates()
8484
{

src/coreclr/nativeaot/System.Private.CoreLib/src/System/Threading/Monitor.NativeAot.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ public static partial class Monitor
2525
#region Object->Lock/Condition mapping
2626

2727
private static readonly ConditionalWeakTable<object, Condition> s_conditionTable = new ConditionalWeakTable<object, Condition>();
28-
private static readonly ConditionalWeakTable<object, Condition>.CreateValueCallback s_createCondition = (o) => new Condition(ObjectHeader.GetLockObject(o));
28+
private static readonly Func<object, Condition> s_createCondition = (o) => new Condition(ObjectHeader.GetLockObject(o));
2929

3030
private static Condition GetCondition(object obj)
3131
{
3232
Debug.Assert(
3333
!(obj is Condition),
3434
"Do not use Monitor.Pulse or Wait on a Condition instance; use the methods on Condition instead.");
35-
return s_conditionTable.GetValue(obj, s_createCondition);
35+
return s_conditionTable.GetOrAdd(obj, s_createCondition);
3636
}
3737
#endregion
3838

src/libraries/System.Net.Sockets/src/System/Net/Sockets/IOControlKeepAlive.Windows.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public static SocketError Set(SafeSocketHandle handle, SocketOptionName optionNa
6161

6262
public static SocketError Set(SafeSocketHandle handle, SocketOptionName optionName, int optionValueSeconds)
6363
{
64-
IOControlKeepAlive ioControlKeepAlive = s_socketKeepAliveTable.GetValue(handle, (SafeSocketHandle handle) => new IOControlKeepAlive());
64+
IOControlKeepAlive ioControlKeepAlive = s_socketKeepAliveTable.GetOrAdd(handle, (SafeSocketHandle handle) => new IOControlKeepAlive());
6565
if (optionName == SocketOptionName.TcpKeepAliveTime)
6666
{
6767
ioControlKeepAlive._timeMs = SecondsToMilliseconds(optionValueSeconds);

src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/ConditionalWeakTable.cs

Lines changed: 108 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System.Collections;
55
using System.Collections.Generic;
6+
using System.ComponentModel;
67
using System.Diagnostics;
78
using System.Diagnostics.CodeAnalysis;
89
using System.Numerics;
@@ -188,36 +189,125 @@ public void Clear()
188189
}
189190

190191
/// <summary>
191-
/// Atomically searches for a specified key in the table and returns the corresponding value.
192-
/// If the key does not exist in the table, the method invokes a callback method to create a
193-
/// value that is bound to the specified key.
192+
/// Searches for a specified key in the table and returns the corresponding value. If the key does
193+
/// not exist in the table, the method adds the given value and binds it to the specified key.
194+
/// </summary>
195+
/// <param name="key">The key of the value to find. It cannot be <see langword="null"/>.</param>
196+
/// <param name="value">The value to add and bind to <typeparamref name="TKey"/>, if one does not exist already.</param>
197+
/// <returns>The value bound to <typeparamref name="TKey"/> in the current <see cref="ConditionalWeakTable{TKey, TValue}"/> instance, after the method completes.</returns>
198+
/// <exception cref="ArgumentNullException"><paramref name="key"/> is <see langword="null"/>.</exception>
199+
public TValue GetOrAdd(TKey key, TValue value)
200+
{
201+
// key is validated by TryGetValue
202+
if (TryGetValue(key, out TValue? existingValue))
203+
{
204+
return existingValue;
205+
}
206+
207+
return GetOrAddLocked(key, value);
208+
}
209+
210+
/// <summary>
211+
/// Searches for a specified key in the table and returns the corresponding value. If the key does not exist
212+
/// in the table, the method invokes the supplied factory to create a value that is bound to the specified key.
213+
/// </summary>
214+
/// <param name="key">The key of the value to find. It cannot be <see langword="null"/>.</param>
215+
/// <param name="valueFactory">The callback that creates a value for key, if one does not exist already. It cannot be <see langword="null"/>.</param>
216+
/// <returns>The value bound to <typeparamref name="TKey"/> in the current <see cref="ConditionalWeakTable{TKey, TValue}"/> instance, after the method completes.</returns>
217+
/// <exception cref="ArgumentNullException"><paramref name="key"/> or <paramref name="valueFactory"/> are <see langword="null"/>.</exception>
218+
/// <remarks>
219+
/// If multiple threads try to initialize the same key, the table may invoke <paramref name="valueFactory"/> multiple times
220+
/// with the same key. Exactly one of these calls will succeed and the returned value of that call will be the one added to
221+
/// the table and returned by all the racing <see cref="GetOrAdd(TKey, Func{TKey, TValue})"/> calls. This rule permits the
222+
/// table to invoke <paramref name="valueFactory"/> outside the internal table lock, to prevent deadlocks.
223+
/// </remarks>
224+
public TValue GetOrAdd(TKey key, Func<TKey, TValue> valueFactory)
225+
{
226+
ArgumentNullException.ThrowIfNull(valueFactory);
227+
228+
// key is validated by TryGetValue
229+
if (TryGetValue(key, out TValue? existingValue))
230+
{
231+
return existingValue;
232+
}
233+
234+
// create the value outside of the lock
235+
TValue value = valueFactory(key);
236+
237+
return GetOrAddLocked(key, value);
238+
}
239+
240+
/// <summary>
241+
/// Searches for a specified key in the table and returns the corresponding value. If the key does not exist
242+
/// in the table, the method invokes the supplied factory to create a value that is bound to the specified key.
243+
/// </summary>
244+
/// <typeparam name="TArg">The type of the additional argument to use with the value factory.</typeparam>
245+
/// <param name="key">The key of the value to find. It cannot be <see langword="null"/>.</param>
246+
/// <param name="valueFactory">The callback that creates a value for key, if one does not exist already. It cannot be <see langword="null"/>.</param>
247+
/// <param name="factoryArgument">The additional argument to supply to <paramref name="valueFactory"/> upon invocation.</param>
248+
/// <returns>The value bound to <typeparamref name="TKey"/> in the current <see cref="ConditionalWeakTable{TKey, TValue}"/> instance, after the method completes.</returns>
249+
/// <exception cref="ArgumentNullException"><paramref name="key"/> or <paramref name="valueFactory"/> are <see langword="null"/>.</exception>
250+
/// <remarks>
251+
/// If multiple threads try to initialize the same key, the table may invoke <paramref name="valueFactory"/> multiple times with the
252+
/// same key. Exactly one of these calls will succeed and the returned value of that call will be the one added to the table and
253+
/// returned by all the racing <see cref="GetOrAdd{TArg}(TKey, Func{TKey, TArg, TValue}, TArg)"/> calls. This rule permits the
254+
/// table to invoke <paramref name="valueFactory"/> outside the internal table lock, to prevent deadlocks.
255+
/// </remarks>
256+
public TValue GetOrAdd<TArg>(TKey key, Func<TKey, TArg, TValue> valueFactory, TArg factoryArgument)
257+
where TArg : allows ref struct
258+
{
259+
ArgumentNullException.ThrowIfNull(valueFactory);
260+
261+
// key is validated by TryGetValue
262+
if (TryGetValue(key, out TValue? existingValue))
263+
{
264+
return existingValue;
265+
}
266+
267+
// create the value outside of the lock
268+
TValue value = valueFactory(key, factoryArgument);
269+
270+
return GetOrAddLocked(key, value);
271+
}
272+
273+
/// <summary>
274+
/// Searches for a specified key in the table and returns the corresponding value. If the key does not exist
275+
/// in the table, the method invokes a callback method to create a value that is bound to the specified key.
194276
/// </summary>
195277
/// <param name="key">key of the value to find. Cannot be null.</param>
196278
/// <param name="createValueCallback">callback that creates value for key. Cannot be null.</param>
197279
/// <returns></returns>
198280
/// <remarks>
281+
/// <para>
199282
/// If multiple threads try to initialize the same key, the table may invoke createValueCallback
200283
/// multiple times with the same key. Exactly one of these calls will succeed and the returned
201284
/// value of that call will be the one added to the table and returned by all the racing GetValue() calls.
202285
/// This rule permits the table to invoke createValueCallback outside the internal table lock
203286
/// to prevent deadlocks.
287+
/// </para>
288+
/// <para>
289+
/// Consider using <see cref="GetOrAdd(TKey, Func{TKey, TValue})"/> (or one of its overloads) instead.
290+
/// </para>
204291
/// </remarks>
292+
[EditorBrowsable(EditorBrowsableState.Never)]
205293
public TValue GetValue(TKey key, CreateValueCallback createValueCallback)
206294
{
207295
ArgumentNullException.ThrowIfNull(createValueCallback);
208296

209297
// key is validated by TryGetValue
210-
return TryGetValue(key, out TValue? existingValue) ?
211-
existingValue :
212-
GetValueLocked(key, createValueCallback);
298+
if (TryGetValue(key, out TValue? existingValue))
299+
{
300+
return existingValue;
301+
}
302+
303+
// create the value outside of the lock
304+
TValue value = createValueCallback(key);
305+
306+
return GetOrAddLocked(key, value);
213307
}
214308

215-
private TValue GetValueLocked(TKey key, CreateValueCallback createValueCallback)
309+
private TValue GetOrAddLocked(TKey key, TValue value)
216310
{
217-
// If we got here, the key was not in the table. Invoke the callback (outside the lock)
218-
// to generate the new value for the key.
219-
TValue newValue = createValueCallback(key);
220-
221311
lock (_lock)
222312
{
223313
// Now that we've taken the lock, must recheck in case we lost a race to add the key.
@@ -228,8 +318,8 @@ private TValue GetValueLocked(TKey key, CreateValueCallback createValueCallback)
228318
else
229319
{
230320
// Verified in-lock that we won the race to add the key. Add it now.
231-
CreateEntry(key, newValue);
232-
return newValue;
321+
CreateEntry(key, value);
322+
return value;
233323
}
234324
}
235325
}
@@ -239,8 +329,13 @@ private TValue GetValueLocked(TKey key, CreateValueCallback createValueCallback)
239329
/// to create new instances as needed. If TValue does not have a default constructor, this will throw.
240330
/// </summary>
241331
/// <param name="key">key of the value to find. Cannot be null.</param>
332+
/// <remarks>
333+
/// Consider using <see cref="GetOrAdd(TKey, Func{TKey, TValue})"/> (or one of its overloads) instead.
334+
/// </remarks>
335+
[EditorBrowsable(EditorBrowsableState.Never)]
242336
public TValue GetOrCreateValue(TKey key) => GetValue(key, _ => Activator.CreateInstance<TValue>());
243337

338+
[EditorBrowsable(EditorBrowsableState.Never)]
244339
public delegate TValue CreateValueCallback(TKey key);
245340

246341
/// <summary>Gets an enumerator for the table.</summary>

src/libraries/System.Reflection.DispatchProxy/src/System/Reflection/DispatchProxyGenerator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ internal static object CreateProxyInstance(
7171
AssemblyLoadContext? alc = AssemblyLoadContext.GetLoadContext(baseType.Assembly);
7272
Debug.Assert(alc != null);
7373

74-
ProxyAssembly proxyAssembly = s_alcProxyAssemblyMap.GetValue(alc, static x => new ProxyAssembly(x));
74+
ProxyAssembly proxyAssembly = s_alcProxyAssemblyMap.GetOrAdd(alc, static x => new ProxyAssembly(x));
7575
GeneratedTypeInfo proxiedType = proxyAssembly.GetProxyType(baseType, interfaceType, interfaceParameter, proxyParameter);
7676
return Activator.CreateInstance(proxiedType.GeneratedType, new object[] { proxiedType.MethodInfos })!;
7777
}

src/libraries/System.Runtime/ref/System.Runtime.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13276,13 +13276,19 @@ public ConditionalWeakTable() { }
1327613276
public void Add(TKey key, TValue value) { }
1327713277
public void AddOrUpdate(TKey key, TValue value) { }
1327813278
public void Clear() { }
13279+
public TValue GetOrAdd(TKey key, TValue value) { throw null; }
13280+
public TValue GetOrAdd(TKey key, System.Func<TKey, TValue> valueFactory) { throw null; }
13281+
public TValue GetOrAdd<TArg>(TKey key, System.Func<TKey, TArg, TValue> valueFactory, TArg factoryArgument) where TArg : allows ref struct { throw null; }
13282+
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
1327913283
public TValue GetOrCreateValue(TKey key) { throw null; }
13284+
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
1328013285
public TValue GetValue(TKey key, System.Runtime.CompilerServices.ConditionalWeakTable<TKey, TValue>.CreateValueCallback createValueCallback) { throw null; }
1328113286
public bool Remove(TKey key) { throw null; }
1328213287
System.Collections.Generic.IEnumerator<System.Collections.Generic.KeyValuePair<TKey, TValue>> System.Collections.Generic.IEnumerable<System.Collections.Generic.KeyValuePair<TKey, TValue>>.GetEnumerator() { throw null; }
1328313288
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { throw null; }
1328413289
public bool TryAdd(TKey key, TValue value) { throw null; }
1328513290
public bool TryGetValue(TKey key, [System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TValue value) { throw null; }
13291+
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
1328613292
public delegate TValue CreateValueCallback(TKey key);
1328713293
}
1328813294
public readonly partial struct ConfiguredAsyncDisposable

0 commit comments

Comments
 (0)