Skip to content

Commit d470139

Browse files
authored
Add option in OAuthCred to load authUrlV2. (#3777)
1 parent cdeec01 commit d470139

12 files changed

+187
-38
lines changed

src/Runner.Listener/BrokerMessageListener.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public async Task<CreateSessionResult> CreateSessionAsync(CancellationToken toke
6565

6666
// Create connection.
6767
Trace.Info("Loading Credentials");
68-
_creds = _credMgr.LoadCredentials();
68+
_creds = _credMgr.LoadCredentials(allowAuthUrlV2: false);
6969

7070
var agent = new TaskAgentReference
7171
{
@@ -434,7 +434,7 @@ ex is AccessDeniedException ||
434434
private async Task RefreshBrokerConnectionAsync()
435435
{
436436
Trace.Info("Reload credentials.");
437-
_creds = _credMgr.LoadCredentials();
437+
_creds = _credMgr.LoadCredentials(allowAuthUrlV2: false); // TODO: change to `true` in the next PR.
438438
await _brokerServer.ConnectAsync(new Uri(_settings.ServerUrlV2), _creds);
439439
Trace.Info("Connection to Broker Server recreated.");
440440
}

src/Runner.Listener/Configuration/ConfigurationManager.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ public async Task ConfigureAsync(CommandSettings command)
127127
runnerSettings.ServerUrl = inputUrl;
128128
// Get the credentials
129129
credProvider = GetCredentialProvider(command, runnerSettings.ServerUrl);
130-
creds = credProvider.GetVssCredentials(HostContext);
130+
creds = credProvider.GetVssCredentials(HostContext, allowAuthUrlV2: false);
131131
Trace.Info("legacy vss cred retrieved");
132132
}
133133
else
@@ -384,7 +384,7 @@ public async Task ConfigureAsync(CommandSettings command)
384384
if (!runnerSettings.UseV2Flow)
385385
{
386386
var credMgr = HostContext.GetService<ICredentialManager>();
387-
VssCredentials credential = credMgr.LoadCredentials();
387+
VssCredentials credential = credMgr.LoadCredentials(allowAuthUrlV2: false);
388388
try
389389
{
390390
await _runnerServer.ConnectAsync(new Uri(runnerSettings.ServerUrl), credential);
@@ -519,7 +519,7 @@ public async Task UnconfigureAsync(CommandSettings command)
519519
if (string.IsNullOrEmpty(settings.GitHubUrl))
520520
{
521521
var credProvider = GetCredentialProvider(command, settings.ServerUrl);
522-
creds = credProvider.GetVssCredentials(HostContext);
522+
creds = credProvider.GetVssCredentials(HostContext, allowAuthUrlV2: false);
523523
Trace.Info("legacy vss cred retrieved");
524524
}
525525
else

src/Runner.Listener/Configuration/CredentialManager.cs

+5-10
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace GitHub.Runner.Listener.Configuration
1313
public interface ICredentialManager : IRunnerService
1414
{
1515
ICredentialProvider GetCredentialProvider(string credType);
16-
VssCredentials LoadCredentials();
16+
VssCredentials LoadCredentials(bool allowAuthUrlV2);
1717
}
1818

1919
public class CredentialManager : RunnerService, ICredentialManager
@@ -40,7 +40,7 @@ public ICredentialProvider GetCredentialProvider(string credType)
4040
return creds;
4141
}
4242

43-
public VssCredentials LoadCredentials()
43+
public VssCredentials LoadCredentials(bool allowAuthUrlV2)
4444
{
4545
IConfigurationStore store = HostContext.GetService<IConfigurationStore>();
4646

@@ -51,21 +51,16 @@ public VssCredentials LoadCredentials()
5151

5252
CredentialData credData = store.GetCredentials();
5353
var migratedCred = store.GetMigratedCredentials();
54-
if (migratedCred != null)
54+
if (migratedCred != null &&
55+
migratedCred.Scheme == Constants.Configuration.OAuth)
5556
{
5657
credData = migratedCred;
57-
58-
// Re-write .credentials with Token URL
59-
store.SaveCredential(credData);
60-
61-
// Delete .credentials_migrated
62-
store.DeleteMigratedCredential();
6358
}
6459

6560
ICredentialProvider credProv = GetCredentialProvider(credData.Scheme);
6661
credProv.CredentialData = credData;
6762

68-
VssCredentials creds = credProv.GetVssCredentials(HostContext);
63+
VssCredentials creds = credProv.GetVssCredentials(HostContext, allowAuthUrlV2);
6964

7065
return creds;
7166
}

src/Runner.Listener/Configuration/CredentialProvider.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using System;
2-
using GitHub.Services.Common;
32
using GitHub.Runner.Common;
43
using GitHub.Runner.Sdk;
4+
using GitHub.Services.Common;
55
using GitHub.Services.OAuth;
66

77
namespace GitHub.Runner.Listener.Configuration
@@ -10,7 +10,7 @@ public interface ICredentialProvider
1010
{
1111
Boolean RequireInteractive { get; }
1212
CredentialData CredentialData { get; set; }
13-
VssCredentials GetVssCredentials(IHostContext context);
13+
VssCredentials GetVssCredentials(IHostContext context, bool allowAuthUrlV2);
1414
void EnsureCredential(IHostContext context, CommandSettings command, string serverUrl);
1515
}
1616

@@ -25,15 +25,15 @@ public CredentialProvider(string scheme)
2525
public virtual Boolean RequireInteractive => false;
2626
public CredentialData CredentialData { get; set; }
2727

28-
public abstract VssCredentials GetVssCredentials(IHostContext context);
28+
public abstract VssCredentials GetVssCredentials(IHostContext context, bool allowAuthUrlV2);
2929
public abstract void EnsureCredential(IHostContext context, CommandSettings command, string serverUrl);
3030
}
3131

3232
public sealed class OAuthAccessTokenCredential : CredentialProvider
3333
{
3434
public OAuthAccessTokenCredential() : base(Constants.Configuration.OAuthAccessToken) { }
3535

36-
public override VssCredentials GetVssCredentials(IHostContext context)
36+
public override VssCredentials GetVssCredentials(IHostContext context, bool allowAuthUrlV2)
3737
{
3838
ArgUtil.NotNull(context, nameof(context));
3939
Tracing trace = context.GetTrace(nameof(OAuthAccessTokenCredential));

src/Runner.Listener/Configuration/OAuthCredential.cs

+9-1
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,18 @@ public override void EnsureCredential(
2222
// Nothing to verify here
2323
}
2424

25-
public override VssCredentials GetVssCredentials(IHostContext context)
25+
public override VssCredentials GetVssCredentials(IHostContext context, bool allowAuthUrlV2)
2626
{
2727
var clientId = this.CredentialData.Data.GetValueOrDefault("clientId", null);
2828
var authorizationUrl = this.CredentialData.Data.GetValueOrDefault("authorizationUrl", null);
29+
var authorizationUrlV2 = this.CredentialData.Data.GetValueOrDefault("authorizationUrlV2", null);
30+
31+
if (allowAuthUrlV2 &&
32+
!string.IsNullOrEmpty(authorizationUrlV2) &&
33+
context.AllowAuthMigration)
34+
{
35+
authorizationUrl = authorizationUrlV2;
36+
}
2937

3038
// For back compat with .credential file that doesn't has 'oauthEndpointUrl' section
3139
var oauthEndpointUrl = this.CredentialData.Data.GetValueOrDefault("oauthEndpointUrl", authorizationUrl);

src/Runner.Listener/MessageListener.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public async Task<CreateSessionResult> CreateSessionAsync(CancellationToken toke
8080

8181
// Create connection.
8282
Trace.Info("Loading Credentials");
83-
_creds = _credMgr.LoadCredentials();
83+
_creds = _credMgr.LoadCredentials(allowAuthUrlV2: false);
8484

8585
var agent = new TaskAgentReference
8686
{
@@ -415,6 +415,7 @@ public async Task DeleteMessageAsync(TaskAgentMessage message)
415415
public async Task RefreshListenerTokenAsync()
416416
{
417417
await _runnerServer.RefreshConnectionAsync(RunnerConnectionType.MessageQueue, TimeSpan.FromSeconds(60));
418+
_creds = _credMgr.LoadCredentials(allowAuthUrlV2: false); // TODO: change to `true` in next PR
418419
await _brokerServer.ForceRefreshConnection(_creds);
419420
}
420421

src/Runner.Listener/Runner.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ private async Task<int> RunAsync(RunnerSettings settings, bool runOnce = false)
570570

571571
// Create connection
572572
var credMgr = HostContext.GetService<ICredentialManager>();
573-
var creds = credMgr.LoadCredentials();
573+
var creds = credMgr.LoadCredentials(allowAuthUrlV2: false);
574574

575575
if (string.IsNullOrEmpty(messageRef.RunServiceUrl))
576576
{

src/Runner.Listener/RunnerConfigUpdater.cs

+10
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,16 @@ private async Task UpdateRunnerCredentialsAsync(string serviceType, string confi
197197
await ReportTelemetryAsync($"Credential clientId in refreshed config '{refreshedClientId ?? "Empty"}' does not match the current credential clientId '{clientId}'.");
198198
return;
199199
}
200+
201+
// make sure the credential authorizationUrl in the refreshed config match the current credential authorizationUrl for OAuth auth scheme
202+
var authorizationUrl = _credData.Data.GetValueOrDefault("authorizationUrl", null);
203+
var refreshedAuthorizationUrl = refreshedCredConfig.Data.GetValueOrDefault("authorizationUrl", null);
204+
if (authorizationUrl != refreshedAuthorizationUrl)
205+
{
206+
Trace.Error($"Credential authorizationUrl in refreshed config '{refreshedAuthorizationUrl ?? "Empty"}' does not match the current credential authorizationUrl '{authorizationUrl}'.");
207+
await ReportTelemetryAsync($"Credential authorizationUrl in refreshed config '{refreshedAuthorizationUrl ?? "Empty"}' does not match the current credential authorizationUrl '{authorizationUrl}'.");
208+
return;
209+
}
200210
}
201211

202212
// save the refreshed runner credentials as a separate file

src/Test/L0/Listener/BrokerMessageListenerL0.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public async void CreatesSession()
5050
tokenSource.Token))
5151
.Returns(Task.FromResult(expectedSession));
5252

53-
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
53+
_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
5454
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
5555
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));
5656

src/Test/L0/Listener/Configuration/RunnerCredentialL0.cs

+88-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
1-
using GitHub.Runner.Listener;
1+
using System.Collections.Generic;
2+
using System.Security.Cryptography;
3+
using GitHub.Runner.Listener;
24
using GitHub.Runner.Listener.Configuration;
35
using GitHub.Services.Common;
46
using GitHub.Services.OAuth;
7+
using Moq;
8+
using Xunit;
59

610
namespace GitHub.Runner.Common.Tests.Listener.Configuration
711
{
812
public class TestRunnerCredential : CredentialProvider
913
{
1014
public TestRunnerCredential() : base("TEST") { }
11-
public override VssCredentials GetVssCredentials(IHostContext context)
15+
public override VssCredentials GetVssCredentials(IHostContext context, bool allowAuthUrlV2)
1216
{
1317
Tracing trace = context.GetTrace("OuthAccessToken");
1418
trace.Info("GetVssCredentials()");
@@ -23,4 +27,85 @@ public override void EnsureCredential(IHostContext context, CommandSettings comm
2327
{
2428
}
2529
}
26-
}
30+
31+
public class OAuthCredentialTestsL0
32+
{
33+
private Mock<IRSAKeyManager> _rsaKeyManager = new Mock<IRSAKeyManager>();
34+
35+
[Fact]
36+
[Trait("Level", "L0")]
37+
[Trait("Category", "OAuthCredential")]
38+
public void NotUseAuthV2Url()
39+
{
40+
using (TestHostContext hc = new(this))
41+
{
42+
// Arrange.
43+
var oauth = new OAuthCredential();
44+
oauth.CredentialData = new CredentialData()
45+
{
46+
Scheme = Constants.Configuration.OAuth
47+
};
48+
oauth.CredentialData.Data.Add("clientId", "someClientId");
49+
oauth.CredentialData.Data.Add("authorizationUrl", "http://myserver/");
50+
oauth.CredentialData.Data.Add("authorizationUrlV2", "http://myserverv2/");
51+
52+
_rsaKeyManager.Setup(x => x.GetKey()).Returns(RSA.Create(2048));
53+
hc.SetSingleton<IRSAKeyManager>(_rsaKeyManager.Object);
54+
55+
// Act.
56+
var cred = oauth.GetVssCredentials(hc, false); // not allow auth v2
57+
58+
var cred2 = oauth.GetVssCredentials(hc, true); // use auth v2 but hostcontext doesn't
59+
60+
hc.EnableAuthMigration("L0Test");
61+
var cred3 = oauth.GetVssCredentials(hc, false); // not use auth v2 but hostcontext does
62+
63+
oauth.CredentialData.Data.Remove("authorizationUrlV2");
64+
var cred4 = oauth.GetVssCredentials(hc, true); // v2 url is not there
65+
66+
// Assert.
67+
Assert.Equal("http://myserver/", (cred.Federated as VssOAuthCredential).AuthorizationUrl.AbsoluteUri);
68+
Assert.Equal("someClientId", (cred.Federated as VssOAuthCredential).ClientCredential.ClientId);
69+
70+
Assert.Equal("http://myserver/", (cred2.Federated as VssOAuthCredential).AuthorizationUrl.AbsoluteUri);
71+
Assert.Equal("someClientId", (cred2.Federated as VssOAuthCredential).ClientCredential.ClientId);
72+
73+
Assert.Equal("http://myserver/", (cred3.Federated as VssOAuthCredential).AuthorizationUrl.AbsoluteUri);
74+
Assert.Equal("someClientId", (cred3.Federated as VssOAuthCredential).ClientCredential.ClientId);
75+
76+
Assert.Equal("http://myserver/", (cred4.Federated as VssOAuthCredential).AuthorizationUrl.AbsoluteUri);
77+
Assert.Equal("someClientId", (cred4.Federated as VssOAuthCredential).ClientCredential.ClientId);
78+
}
79+
}
80+
81+
[Fact]
82+
[Trait("Level", "L0")]
83+
[Trait("Category", "OAuthCredential")]
84+
public void UseAuthV2Url()
85+
{
86+
using (TestHostContext hc = new(this))
87+
{
88+
// Arrange.
89+
var oauth = new OAuthCredential();
90+
oauth.CredentialData = new CredentialData()
91+
{
92+
Scheme = Constants.Configuration.OAuth
93+
};
94+
oauth.CredentialData.Data.Add("clientId", "someClientId");
95+
oauth.CredentialData.Data.Add("authorizationUrl", "http://myserver/");
96+
oauth.CredentialData.Data.Add("authorizationUrlV2", "http://myserverv2/");
97+
98+
_rsaKeyManager.Setup(x => x.GetKey()).Returns(RSA.Create(2048));
99+
hc.SetSingleton<IRSAKeyManager>(_rsaKeyManager.Object);
100+
101+
// Act.
102+
hc.EnableAuthMigration("L0Test");
103+
var cred = oauth.GetVssCredentials(hc, true);
104+
105+
// Assert.
106+
Assert.Equal("http://myserverv2/", (cred.Federated as VssOAuthCredential).AuthorizationUrl.AbsoluteUri);
107+
Assert.Equal("someClientId", (cred.Federated as VssOAuthCredential).ClientCredential.ClientId);
108+
}
109+
}
110+
}
111+
}

src/Test/L0/Listener/MessageListenerL0.cs

+8-8
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public async void CreatesSession()
6767
tokenSource.Token))
6868
.Returns(Task.FromResult(expectedSession));
6969

70-
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
70+
_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
7171
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
7272
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));
7373

@@ -127,7 +127,7 @@ public async void CreatesSessionWithBrokerMigration()
127127
tokenSource.Token))
128128
.Returns(Task.FromResult(expectedBrokerSession));
129129

130-
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
130+
_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
131131
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
132132
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));
133133

@@ -177,7 +177,7 @@ public async void DeleteSession()
177177
tokenSource.Token))
178178
.Returns(Task.FromResult(expectedSession));
179179

180-
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
180+
_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
181181
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
182182
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));
183183

@@ -237,7 +237,7 @@ public async void DeleteSessionWithBrokerMigration()
237237
tokenSource.Token))
238238
.Returns(Task.FromResult(expectedBrokerSession));
239239

240-
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
240+
_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
241241
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
242242
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));
243243

@@ -301,7 +301,7 @@ public async void GetNextMessage()
301301
tokenSource.Token))
302302
.Returns(Task.FromResult(expectedSession));
303303

304-
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
304+
_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
305305
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
306306
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));
307307

@@ -382,7 +382,7 @@ public async void GetNextMessageWithBrokerMigration()
382382
tokenSource.Token))
383383
.Returns(Task.FromResult(expectedSession));
384384

385-
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
385+
_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
386386
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
387387
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));
388388

@@ -484,7 +484,7 @@ public async void CreateSessionWithOriginalCredential()
484484
tokenSource.Token))
485485
.Returns(Task.FromResult(expectedSession));
486486

487-
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
487+
_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
488488

489489
var originalCred = new CredentialData() { Scheme = Constants.Configuration.OAuth };
490490
originalCred.Data["authorizationUrl"] = "https://s.server";
@@ -533,7 +533,7 @@ public async void SkipDeleteSession_WhenGetNextMessageGetTaskAgentAccessTokenExp
533533
tokenSource.Token))
534534
.Returns(Task.FromResult(expectedSession));
535535

536-
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
536+
_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
537537
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
538538
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));
539539

0 commit comments

Comments
 (0)