Skip to content

Commit 3083a3c

Browse files
authored
DefaultTokenAcquirerFactoryImplementation use ConcurrentDictionary (#2764)
1 parent 0c3b9d5 commit 3083a3c

File tree

2 files changed

+123
-46
lines changed

2 files changed

+123
-46
lines changed

src/Microsoft.Identity.Web.TokenAcquisition/DefaultTokenAcquirerFactoryImplementation.cs

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
// Licensed under the MIT License.
33

44
using System;
5+
using System.Collections.Concurrent;
56
using System.Collections.Generic;
67
using Microsoft.Extensions.DependencyInjection;
7-
using Microsoft.Extensions.Options;
88
using Microsoft.Identity.Abstractions;
99

1010
namespace Microsoft.Identity.Web
@@ -17,7 +17,7 @@ public DefaultTokenAcquirerFactoryImplementation(IServiceProvider serviceProvide
1717
}
1818
private IServiceProvider ServiceProvider { get; set; }
1919

20-
readonly Dictionary<string, ITokenAcquirer> _authSchemes = new Dictionary<string, ITokenAcquirer>();
20+
readonly ConcurrentDictionary<string, ITokenAcquirer> _authSchemes = new();
2121

2222
/// <inheritdoc/>
2323
public ITokenAcquirer GetTokenAcquirer(
@@ -26,40 +26,38 @@ public ITokenAcquirer GetTokenAcquirer(
2626
IEnumerable<CredentialDescription> clientCredentials,
2727
string? region = null)
2828
{
29-
CheckServiceProviderNotNull();
29+
string key = GetKey(authority, clientId, region);
3030

31-
ITokenAcquirer? tokenAcquirer;
32-
// Compute the key
33-
string key = GetKey(authority, clientId);
34-
if (!_authSchemes.TryGetValue(key, out tokenAcquirer))
35-
{
36-
MicrosoftIdentityApplicationOptions MicrosoftIdentityApplicationOptions = new MicrosoftIdentityApplicationOptions
37-
{
38-
ClientId = clientId,
39-
Authority = authority,
40-
ClientCredentials = clientCredentials,
41-
SendX5C = true
42-
};
43-
if (region != null)
31+
// GetOrAdd ONLY synchronizes the outcome. So, the factory might still be invoked multiple times.
32+
// Therefore, all side-effects within this block must remain idempotent.
33+
return _authSchemes.GetOrAdd(key, (key) =>
4434
{
45-
MicrosoftIdentityApplicationOptions.AzureRegion = region;
46-
}
35+
MicrosoftIdentityApplicationOptions MicrosoftIdentityApplicationOptions = new()
36+
{
37+
ClientId = clientId,
38+
Authority = authority,
39+
ClientCredentials = clientCredentials,
40+
SendX5C = true
41+
};
4742

48-
var optionsMonitor = ServiceProvider.GetRequiredService<IMergedOptionsStore>();
49-
var mergedOptions = optionsMonitor.Get(key);
50-
MergedOptions.UpdateMergedOptionsFromMicrosoftIdentityApplicationOptions(MicrosoftIdentityApplicationOptions, mergedOptions);
51-
tokenAcquirer = GetTokenAcquirer(key);
52-
}
53-
return tokenAcquirer;
43+
if (region != null)
44+
{
45+
MicrosoftIdentityApplicationOptions.AzureRegion = region;
46+
}
47+
48+
IMergedOptionsStore optionsMonitor = ServiceProvider.GetRequiredService<IMergedOptionsStore>();
49+
MergedOptions mergedOptions = optionsMonitor.Get(key);
50+
MergedOptions.UpdateMergedOptionsFromMicrosoftIdentityApplicationOptions(MicrosoftIdentityApplicationOptions, mergedOptions);
51+
52+
return MakeTokenAcquirer(key);
53+
});
5454
}
5555

5656
/// <inheritdoc/>
5757
public ITokenAcquirer GetTokenAcquirer(IdentityApplicationOptions IdentityApplicationOptions)
5858
{
5959
_ = Throws.IfNull(IdentityApplicationOptions);
6060

61-
CheckServiceProviderNotNull();
62-
6361
// Compute the Azure region if the option is a MicrosoftIdentityApplicationOptions.
6462
MicrosoftIdentityApplicationOptions? MicrosoftIdentityApplicationOptions = IdentityApplicationOptions as MicrosoftIdentityApplicationOptions;
6563
if (MicrosoftIdentityApplicationOptions == null)
@@ -77,33 +75,36 @@ public ITokenAcquirer GetTokenAcquirer(IdentityApplicationOptions IdentityApplic
7775
};
7876
}
7977

80-
// Compute the key
81-
ITokenAcquirer? tokenAcquirer;
82-
string key = GetKey(IdentityApplicationOptions.Authority, IdentityApplicationOptions.ClientId);
83-
if (!_authSchemes.TryGetValue(key, out tokenAcquirer))
78+
string key = GetKey(IdentityApplicationOptions.Authority, IdentityApplicationOptions.ClientId, MicrosoftIdentityApplicationOptions.AzureRegion);
79+
80+
return _authSchemes.GetOrAdd(key, (key) =>
8481
{
85-
var optionsMonitor = ServiceProvider!.GetRequiredService<IMergedOptionsStore>();
86-
var mergedOptions = optionsMonitor.Get(key);
82+
IMergedOptionsStore optionsMonitor = ServiceProvider!.GetRequiredService<IMergedOptionsStore>();
83+
MergedOptions mergedOptions = optionsMonitor.Get(key);
84+
85+
8786
MergedOptions.UpdateMergedOptionsFromMicrosoftIdentityApplicationOptions(MicrosoftIdentityApplicationOptions, mergedOptions);
88-
tokenAcquirer = GetTokenAcquirer(key);
89-
}
90-
return tokenAcquirer;
87+
return MakeTokenAcquirer(key);
88+
});
9189
}
9290

9391
/// <inheritdoc/>
9492
public ITokenAcquirer GetTokenAcquirer(string authenticationScheme = "")
93+
{
94+
return _authSchemes.GetOrAdd(authenticationScheme, (key) =>
95+
{
96+
return MakeTokenAcquirer(authenticationScheme);
97+
});
98+
}
99+
100+
private ITokenAcquirer MakeTokenAcquirer(string authenticationScheme = "")
95101
{
96102
CheckServiceProviderNotNull();
97103

98-
ITokenAcquirer? acquirer;
99-
if (!_authSchemes.TryGetValue(authenticationScheme, out acquirer))
100-
{
101-
var tokenAcquisition = ServiceProvider!.GetRequiredService<ITokenAcquisition>();
102-
acquirer = new TokenAcquirer(tokenAcquisition, authenticationScheme);
103-
_authSchemes.Add(authenticationScheme, acquirer);
104-
}
105-
return acquirer;
104+
ITokenAcquisition tokenAcquisition = ServiceProvider!.GetRequiredService<ITokenAcquisition>();
105+
return new TokenAcquirer(tokenAcquisition, authenticationScheme);
106106
}
107+
107108
private void CheckServiceProviderNotNull()
108109
{
109110
if (ServiceProvider == null)
@@ -112,10 +113,9 @@ private void CheckServiceProviderNotNull()
112113
}
113114
}
114115

115-
116-
private static string GetKey(string? authority, string? clientId)
116+
public static string GetKey(string? authority, string? clientId, string? region)
117117
{
118-
return $"{authority}{clientId}";
118+
return $"{authority}{clientId}{region}";
119119
}
120120
}
121121
}

tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
// Licensed under the MIT License.
33

44
using System;
5+
using System.Collections.Concurrent;
56
using System.Linq;
67
using System.Net.Http;
78
using System.Security.Cryptography;
89
using System.Security.Cryptography.X509Certificates;
10+
using System.Threading;
911
using System.Threading.Tasks;
1012
using Microsoft.AspNetCore.Http;
1113
using Microsoft.Extensions.DependencyInjection;
@@ -46,6 +48,81 @@ public void TokenAcquirerFactoryDoesNotUseAspNetCoreHost()
4648
Assert.Equal("Microsoft.Identity.Web.Hosts.DefaultTokenAcquisitionHost", service.GetType().FullName);
4749
}
4850

51+
[Fact]
52+
public void DefaultTokenAcquirer_GetKeyHandlesNulls()
53+
{
54+
var res = DefaultTokenAcquirerFactoryImplementation.GetKey("1", "2", "3");
55+
Assert.Equal("123", res);
56+
57+
var no_region = DefaultTokenAcquirerFactoryImplementation.GetKey("1", "2", null);
58+
Assert.Equal("12", no_region);
59+
}
60+
61+
[Fact]
62+
public void AcquireToken_WithMultipleRegions()
63+
{
64+
var tokenAcquirerFactory = TokenAcquirerFactory.GetDefaultInstance();
65+
_ = tokenAcquirerFactory.Build();
66+
67+
ITokenAcquirer tokenAcquirerA = tokenAcquirerFactory.GetTokenAcquirer(
68+
authority: "https://login.microsoftonline.com/msidentitysamplestesting.onmicrosoft.com",
69+
clientId: "6af093f3-b445-4b7a-beae-046864468ad6",
70+
clientCredentials: s_clientCredentials,
71+
"US");
72+
73+
ITokenAcquirer tokenAcquirerB = tokenAcquirerFactory.GetTokenAcquirer(
74+
authority: "https://login.microsoftonline.com/msidentitysamplestesting.onmicrosoft.com",
75+
clientId: "6af093f3-b445-4b7a-beae-046864468ad6",
76+
clientCredentials: s_clientCredentials,
77+
"US");
78+
79+
ITokenAcquirer tokenAcquirerC = tokenAcquirerFactory.GetTokenAcquirer(
80+
authority: "https://login.microsoftonline.com/msidentitysamplestesting.onmicrosoft.com",
81+
clientId: "6af093f3-b445-4b7a-beae-046864468ad6",
82+
clientCredentials: s_clientCredentials,
83+
"EU");
84+
85+
Assert.Equal(tokenAcquirerA, tokenAcquirerB);
86+
Assert.NotEqual(tokenAcquirerA, tokenAcquirerC);
87+
}
88+
89+
[Fact]
90+
public void AcquireToken_SafeFromMultipleThreads()
91+
{
92+
var tokenAcquirerFactory = TokenAcquirerFactory.GetDefaultInstance();
93+
_ = tokenAcquirerFactory.Build();
94+
95+
var count = new ConcurrentDictionary<ITokenAcquirer, bool>();
96+
97+
var action = () =>
98+
{
99+
for (int i = 0; i < 1000; i++)
100+
{
101+
ITokenAcquirer res = tokenAcquirerFactory.GetTokenAcquirer(
102+
authority: "https://login.microsoftonline.com/msidentitysamplestesting.onmicrosoft.com",
103+
clientId: "6af093f3-b445-4b7a-beae-046864468ad6",
104+
clientCredentials: s_clientCredentials,
105+
"" + (i%11));
106+
107+
count.TryAdd(res, true);
108+
}
109+
};
110+
111+
Thread[] threads = new Thread[16];
112+
for (int i = 0; i < 16; i++)
113+
{
114+
threads[i] = new Thread(() => action());
115+
threads[i].Start();
116+
}
117+
118+
foreach (Thread thread in threads)
119+
{
120+
thread.Join();
121+
}
122+
123+
Assert.Equal(11, count.Count);
124+
}
125+
49126
[IgnoreOnAzureDevopsFact]
50127
//[Theory]
51128
//[InlineData(false)]

0 commit comments

Comments
 (0)