Skip to content

Commit 23e0798

Browse files
Fix missing ServiceCallSite.Key causing an unkeyed cache entry to be overwritten by a keyed instance (#113343)
1 parent deee462 commit 23e0798

File tree

8 files changed

+87
-15
lines changed

8 files changed

+87
-15
lines changed

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs

+5-6
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ void AddCallSite(ServiceCallSite callSite, int index)
362362
ResultCache resultCache = (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root)
363363
? new ResultCache(cacheLocation, callSiteKey)
364364
: new ResultCache(CallSiteResultCacheLocation.None, callSiteKey);
365-
return _callSiteCache[callSiteKey] = new IEnumerableCallSite(resultCache, itemType, callSites);
365+
return _callSiteCache[callSiteKey] = new IEnumerableCallSite(resultCache, itemType, callSites, serviceIdentifier.ServiceKey);
366366
}
367367
finally
368368
{
@@ -415,7 +415,7 @@ private static CallSiteResultCacheLocation GetCommonCacheLocation(CallSiteResult
415415
var lifetime = new ResultCache(descriptor.Lifetime, serviceIdentifier, slot);
416416
if (descriptor.HasImplementationInstance())
417417
{
418-
callSite = new ConstantCallSite(descriptor.ServiceType, descriptor.GetImplementationInstance());
418+
callSite = new ConstantCallSite(descriptor.ServiceType, descriptor.GetImplementationInstance(), descriptor.ServiceKey);
419419
}
420420
else if (!descriptor.IsKeyedService && descriptor.ImplementationFactory != null)
421421
{
@@ -433,7 +433,6 @@ private static CallSiteResultCacheLocation GetCommonCacheLocation(CallSiteResult
433433
{
434434
throw new InvalidOperationException(SR.InvalidServiceDescriptor);
435435
}
436-
callSite.Key = descriptor.ServiceKey;
437436

438437
return _callSiteCache[callSiteKey] = callSite;
439438
}
@@ -512,7 +511,7 @@ private ConstructorCallSite CreateConstructorCallSite(
512511
ParameterInfo[] parameters = constructor.GetParameters();
513512
if (parameters.Length == 0)
514513
{
515-
return new ConstructorCallSite(lifetime, serviceIdentifier.ServiceType, constructor);
514+
return new ConstructorCallSite(lifetime, serviceIdentifier.ServiceType, constructor, serviceIdentifier.ServiceKey);
516515
}
517516

518517
parameterCallSites = CreateArgumentCallSites(
@@ -522,7 +521,7 @@ private ConstructorCallSite CreateConstructorCallSite(
522521
parameters,
523522
throwIfCallSiteNotFound: true)!;
524523

525-
return new ConstructorCallSite(lifetime, serviceIdentifier.ServiceType, constructor, parameterCallSites);
524+
return new ConstructorCallSite(lifetime, serviceIdentifier.ServiceType, constructor, parameterCallSites, serviceIdentifier.ServiceKey);
526525
}
527526

528527
Array.Sort(constructors,
@@ -586,7 +585,7 @@ private ConstructorCallSite CreateConstructorCallSite(
586585
else
587586
{
588587
Debug.Assert(parameterCallSites != null);
589-
return new ConstructorCallSite(lifetime, serviceIdentifier.ServiceType, bestConstructor, parameterCallSites);
588+
return new ConstructorCallSite(lifetime, serviceIdentifier.ServiceType, bestConstructor, parameterCallSites, serviceIdentifier.ServiceKey);
590589
}
591590
}
592591
finally

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ internal sealed class ConstantCallSite : ServiceCallSite
1010
private readonly Type _serviceType;
1111
internal object? DefaultValue => Value;
1212

13-
public ConstantCallSite(Type serviceType, object? defaultValue) : base(ResultCache.None(serviceType))
13+
public ConstantCallSite(Type serviceType, object? defaultValue, object? serviceKey = null) : base(ResultCache.None(serviceType), serviceKey)
1414
{
1515
_serviceType = serviceType ?? throw new ArgumentNullException(nameof(serviceType));
1616
if (defaultValue != null && !serviceType.IsInstanceOfType(defaultValue))

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstructorCallSite.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ internal sealed class ConstructorCallSite : ServiceCallSite
1111
internal ConstructorInfo ConstructorInfo { get; }
1212
internal ServiceCallSite[] ParameterCallSites { get; }
1313

14-
public ConstructorCallSite(ResultCache cache, Type serviceType, ConstructorInfo constructorInfo) : this(cache, serviceType, constructorInfo, Array.Empty<ServiceCallSite>())
14+
public ConstructorCallSite(ResultCache cache, Type serviceType, ConstructorInfo constructorInfo, object? serviceKey) : this(cache, serviceType, constructorInfo, Array.Empty<ServiceCallSite>(), serviceKey)
1515
{
1616
}
1717

18-
public ConstructorCallSite(ResultCache cache, Type serviceType, ConstructorInfo constructorInfo, ServiceCallSite[] parameterCallSites) : base(cache)
18+
public ConstructorCallSite(ResultCache cache, Type serviceType, ConstructorInfo constructorInfo, ServiceCallSite[] parameterCallSites, object? serviceKey) : base(cache, serviceKey)
1919
{
2020
if (!serviceType.IsAssignableFrom(constructorInfo.DeclaringType))
2121
{

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/FactoryCallSite.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ internal sealed class FactoryCallSite : ServiceCallSite
99
{
1010
public Func<IServiceProvider, object> Factory { get; }
1111

12-
public FactoryCallSite(ResultCache cache, Type serviceType, Func<IServiceProvider, object> factory) : base(cache)
12+
public FactoryCallSite(ResultCache cache, Type serviceType, Func<IServiceProvider, object> factory) : base(cache, null)
1313
{
1414
Factory = factory;
1515
ServiceType = serviceType;
1616
}
1717

18-
public FactoryCallSite(ResultCache cache, Type serviceType, object serviceKey, Func<IServiceProvider, object, object> factory) : base(cache)
18+
public FactoryCallSite(ResultCache cache, Type serviceType, object serviceKey, Func<IServiceProvider, object, object> factory) : base(cache, serviceKey)
1919
{
2020
Factory = sp => factory(sp, serviceKey);
2121
ServiceType = serviceType;

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/IEnumerableCallSite.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ internal sealed class IEnumerableCallSite : ServiceCallSite
1313
internal Type ItemType { get; }
1414
internal ServiceCallSite[] ServiceCallSites { get; }
1515

16-
public IEnumerableCallSite(ResultCache cache, Type itemType, ServiceCallSite[] serviceCallSites) : base(cache)
16+
public IEnumerableCallSite(ResultCache cache, Type itemType, ServiceCallSite[] serviceCallSites, object? serviceKey = null) : base(cache, serviceKey)
1717
{
1818
Debug.Assert(!ServiceProvider.VerifyAotCompatibility || !itemType.IsValueType, "If VerifyAotCompatibility=true, an IEnumerableCallSite should not be created with a ValueType.");
1919

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCallSite.cs

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,18 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
1010
/// </summary>
1111
internal abstract class ServiceCallSite
1212
{
13-
protected ServiceCallSite(ResultCache cache)
13+
protected ServiceCallSite(ResultCache cache, object? key)
1414
{
1515
Cache = cache;
16+
Key = key;
1617
}
1718

1819
public abstract Type ServiceType { get; }
1920
public abstract Type? ImplementationType { get; }
2021
public abstract CallSiteKind Kind { get; }
2122
public ResultCache Cache { get; }
2223
public object? Value { get; set; }
23-
public object? Key { get; set; }
24+
public object? Key { get; }
2425

2526
public bool CaptureDisposable =>
2627
ImplementationType == null ||

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderCallSite.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
77
{
88
internal sealed class ServiceProviderCallSite : ServiceCallSite
99
{
10-
public ServiceProviderCallSite() : base(ResultCache.None(typeof(IServiceProvider)))
10+
public ServiceProviderCallSite() : base(ResultCache.None(typeof(IServiceProvider)), null)
1111
{
1212
}
1313

src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/CallSiteTests.cs

+72
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ internal static ServiceCallSite GetCallSite(this CallSiteFactory callSiteFactory
1717
{
1818
return callSiteFactory.GetCallSite(ServiceIdentifier.FromServiceType(type), callSiteChain);
1919
}
20+
21+
internal static ServiceCallSite GetKeyedCallSite(this CallSiteFactory callSiteFactory, Type type, object? serviceKey, CallSiteChain callSiteChain)
22+
{
23+
return callSiteFactory.GetCallSite(new ServiceIdentifier(serviceKey, type), callSiteChain);
24+
}
2025
}
2126

2227
public class CallSiteTests
@@ -300,6 +305,61 @@ public void CallSiteFactoryResolvesIEnumerableOfOpenGenericServiceAfterResolving
300305
Assert.Equal(typeof(FakeOpenGenericService<int>), implementationTypes[1]);
301306
}
302307

308+
[Fact]
309+
public void ServiceCallSite_ShouldHaveKey_WhenResolvingKeyedService()
310+
{
311+
// Arrange
312+
IServiceCollection services = new ServiceCollection();
313+
314+
services.Add(ServiceDescriptor.Transient(typeof(SomeService), typeof(SomeService)));
315+
services.Add(ServiceDescriptor.KeyedTransient(typeof(SomeService), "someKey", typeof(SomeOtherService)));
316+
317+
using var serviceProvider = services.BuildServiceProvider();
318+
319+
// Act
320+
var callSite = serviceProvider.CallSiteFactory.GetKeyedCallSite(typeof(SomeService), "someKey", new CallSiteChain());
321+
322+
// Assert
323+
Assert.NotNull(callSite.Key);
324+
}
325+
326+
[Fact]
327+
public void ServiceCallSite_ShouldHaveKey_WhenResolvingKeyedClosedImplementationOfOpenGenericService()
328+
{
329+
// Arrange
330+
IServiceCollection services = new ServiceCollection();
331+
332+
services.Add(ServiceDescriptor.Transient(typeof(IGenericService<>), typeof(UnkeyedGenericService<>)));
333+
services.Add(ServiceDescriptor.KeyedTransient(typeof(IGenericService<>), "someKey", typeof(PrimaryKeyedGenericService<>)));
334+
335+
using var serviceProvider = services.BuildServiceProvider();
336+
337+
// Act
338+
var callSite = serviceProvider.CallSiteFactory.GetKeyedCallSite(typeof(IGenericService<object>), "someKey", new CallSiteChain());
339+
340+
// Assert
341+
Assert.NotNull(callSite.Key);
342+
}
343+
344+
[Fact]
345+
public void ServiceCallSite_ShouldHaveKey_WhenResolvingKeyedIEnumerableOfClosedImplementationOfOpenGenericService()
346+
{
347+
// Arrange
348+
IServiceCollection services = new ServiceCollection();
349+
350+
services.Add(ServiceDescriptor.Transient(typeof(IGenericService<>), typeof(UnkeyedGenericService<>)));
351+
services.Add(ServiceDescriptor.KeyedTransient(typeof(IGenericService<>), "someKey", typeof(PrimaryKeyedGenericService<>)));
352+
services.Add(ServiceDescriptor.KeyedTransient(typeof(IGenericService<>), "someKey", typeof(SecondaryKeyedGenericService<>)));
353+
354+
using var serviceProvider = services.BuildServiceProvider();
355+
356+
// Act
357+
var callSite = serviceProvider.CallSiteFactory.GetKeyedCallSite(typeof(IEnumerable<IGenericService<object>>), "someKey", new CallSiteChain());
358+
359+
// Assert
360+
Assert.NotNull(callSite.Key);
361+
}
362+
303363
private class FakeIntService : IFakeOpenGenericService<int>
304364
{
305365
public int Value => 0;
@@ -395,6 +455,18 @@ public void Dispose()
395455
}
396456
}
397457

458+
private interface IGenericService<T>;
459+
460+
private class UnkeyedGenericService<T> : IGenericService<T>;
461+
462+
private class PrimaryKeyedGenericService<T> : IGenericService<T>;
463+
464+
private class SecondaryKeyedGenericService<T> : IGenericService<T>;
465+
466+
private class SomeService;
467+
468+
private class SomeOtherService : SomeService;
469+
398470
private static object Invoke(ServiceCallSite callSite, ServiceProviderEngineScope scope)
399471
{
400472
return CallSiteRuntimeResolver.Instance.Resolve(callSite, scope);

0 commit comments

Comments
 (0)