Skip to content

Revert "Add option in OAuthCred to load authUrlV2. (#3777)" #3779

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Runner.Listener/BrokerMessageListener.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public async Task<CreateSessionResult> CreateSessionAsync(CancellationToken toke

// Create connection.
Trace.Info("Loading Credentials");
_creds = _credMgr.LoadCredentials(allowAuthUrlV2: false);
_creds = _credMgr.LoadCredentials();

var agent = new TaskAgentReference
{
Expand Down Expand Up @@ -434,7 +434,7 @@ ex is AccessDeniedException ||
private async Task RefreshBrokerConnectionAsync()
{
Trace.Info("Reload credentials.");
_creds = _credMgr.LoadCredentials(allowAuthUrlV2: false); // TODO: change to `true` in the next PR.
_creds = _credMgr.LoadCredentials();
await _brokerServer.ConnectAsync(new Uri(_settings.ServerUrlV2), _creds);
Trace.Info("Connection to Broker Server recreated.");
}
Expand Down
6 changes: 3 additions & 3 deletions src/Runner.Listener/Configuration/ConfigurationManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public async Task ConfigureAsync(CommandSettings command)
runnerSettings.ServerUrl = inputUrl;
// Get the credentials
credProvider = GetCredentialProvider(command, runnerSettings.ServerUrl);
creds = credProvider.GetVssCredentials(HostContext, allowAuthUrlV2: false);
creds = credProvider.GetVssCredentials(HostContext);
Trace.Info("legacy vss cred retrieved");
}
else
Expand Down Expand Up @@ -384,7 +384,7 @@ public async Task ConfigureAsync(CommandSettings command)
if (!runnerSettings.UseV2Flow)
{
var credMgr = HostContext.GetService<ICredentialManager>();
VssCredentials credential = credMgr.LoadCredentials(allowAuthUrlV2: false);
VssCredentials credential = credMgr.LoadCredentials();
try
{
await _runnerServer.ConnectAsync(new Uri(runnerSettings.ServerUrl), credential);
Expand Down Expand Up @@ -519,7 +519,7 @@ public async Task UnconfigureAsync(CommandSettings command)
if (string.IsNullOrEmpty(settings.GitHubUrl))
{
var credProvider = GetCredentialProvider(command, settings.ServerUrl);
creds = credProvider.GetVssCredentials(HostContext, allowAuthUrlV2: false);
creds = credProvider.GetVssCredentials(HostContext);
Trace.Info("legacy vss cred retrieved");
}
else
Expand Down
15 changes: 10 additions & 5 deletions src/Runner.Listener/Configuration/CredentialManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace GitHub.Runner.Listener.Configuration
public interface ICredentialManager : IRunnerService
{
ICredentialProvider GetCredentialProvider(string credType);
VssCredentials LoadCredentials(bool allowAuthUrlV2);
VssCredentials LoadCredentials();
}

public class CredentialManager : RunnerService, ICredentialManager
Expand All @@ -40,7 +40,7 @@ public ICredentialProvider GetCredentialProvider(string credType)
return creds;
}

public VssCredentials LoadCredentials(bool allowAuthUrlV2)
public VssCredentials LoadCredentials()
{
IConfigurationStore store = HostContext.GetService<IConfigurationStore>();

Expand All @@ -51,16 +51,21 @@ public VssCredentials LoadCredentials(bool allowAuthUrlV2)

CredentialData credData = store.GetCredentials();
var migratedCred = store.GetMigratedCredentials();
if (migratedCred != null &&
migratedCred.Scheme == Constants.Configuration.OAuth)
if (migratedCred != null)
{
credData = migratedCred;

// Re-write .credentials with Token URL
store.SaveCredential(credData);

// Delete .credentials_migrated
store.DeleteMigratedCredential();
}

ICredentialProvider credProv = GetCredentialProvider(credData.Scheme);
credProv.CredentialData = credData;

VssCredentials creds = credProv.GetVssCredentials(HostContext, allowAuthUrlV2);
VssCredentials creds = credProv.GetVssCredentials(HostContext);

return creds;
}
Expand Down
8 changes: 4 additions & 4 deletions src/Runner.Listener/Configuration/CredentialProvider.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using System;
using GitHub.Services.Common;
using GitHub.Runner.Common;
using GitHub.Runner.Sdk;
using GitHub.Services.Common;
using GitHub.Services.OAuth;

namespace GitHub.Runner.Listener.Configuration
Expand All @@ -10,7 +10,7 @@ public interface ICredentialProvider
{
Boolean RequireInteractive { get; }
CredentialData CredentialData { get; set; }
VssCredentials GetVssCredentials(IHostContext context, bool allowAuthUrlV2);
VssCredentials GetVssCredentials(IHostContext context);
void EnsureCredential(IHostContext context, CommandSettings command, string serverUrl);
}

Expand All @@ -25,15 +25,15 @@ public CredentialProvider(string scheme)
public virtual Boolean RequireInteractive => false;
public CredentialData CredentialData { get; set; }

public abstract VssCredentials GetVssCredentials(IHostContext context, bool allowAuthUrlV2);
public abstract VssCredentials GetVssCredentials(IHostContext context);
public abstract void EnsureCredential(IHostContext context, CommandSettings command, string serverUrl);
}

public sealed class OAuthAccessTokenCredential : CredentialProvider
{
public OAuthAccessTokenCredential() : base(Constants.Configuration.OAuthAccessToken) { }

public override VssCredentials GetVssCredentials(IHostContext context, bool allowAuthUrlV2)
public override VssCredentials GetVssCredentials(IHostContext context)
{
ArgUtil.NotNull(context, nameof(context));
Tracing trace = context.GetTrace(nameof(OAuthAccessTokenCredential));
Expand Down
10 changes: 1 addition & 9 deletions src/Runner.Listener/Configuration/OAuthCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,10 @@ public override void EnsureCredential(
// Nothing to verify here
}

public override VssCredentials GetVssCredentials(IHostContext context, bool allowAuthUrlV2)
public override VssCredentials GetVssCredentials(IHostContext context)
{
var clientId = this.CredentialData.Data.GetValueOrDefault("clientId", null);
var authorizationUrl = this.CredentialData.Data.GetValueOrDefault("authorizationUrl", null);
var authorizationUrlV2 = this.CredentialData.Data.GetValueOrDefault("authorizationUrlV2", null);

if (allowAuthUrlV2 &&
!string.IsNullOrEmpty(authorizationUrlV2) &&
context.AllowAuthMigration)
{
authorizationUrl = authorizationUrlV2;
}

// For back compat with .credential file that doesn't has 'oauthEndpointUrl' section
var oauthEndpointUrl = this.CredentialData.Data.GetValueOrDefault("oauthEndpointUrl", authorizationUrl);
Expand Down
3 changes: 1 addition & 2 deletions src/Runner.Listener/MessageListener.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public async Task<CreateSessionResult> CreateSessionAsync(CancellationToken toke

// Create connection.
Trace.Info("Loading Credentials");
_creds = _credMgr.LoadCredentials(allowAuthUrlV2: false);
_creds = _credMgr.LoadCredentials();

var agent = new TaskAgentReference
{
Expand Down Expand Up @@ -415,7 +415,6 @@ public async Task DeleteMessageAsync(TaskAgentMessage message)
public async Task RefreshListenerTokenAsync()
{
await _runnerServer.RefreshConnectionAsync(RunnerConnectionType.MessageQueue, TimeSpan.FromSeconds(60));
_creds = _credMgr.LoadCredentials(allowAuthUrlV2: false); // TODO: change to `true` in next PR
await _brokerServer.ForceRefreshConnection(_creds);
}

Expand Down
2 changes: 1 addition & 1 deletion src/Runner.Listener/Runner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ private async Task<int> RunAsync(RunnerSettings settings, bool runOnce = false)

// Create connection
var credMgr = HostContext.GetService<ICredentialManager>();
var creds = credMgr.LoadCredentials(allowAuthUrlV2: false);
var creds = credMgr.LoadCredentials();

if (string.IsNullOrEmpty(messageRef.RunServiceUrl))
{
Expand Down
10 changes: 0 additions & 10 deletions src/Runner.Listener/RunnerConfigUpdater.cs
Original file line number Diff line number Diff line change
Expand Up @@ -197,16 +197,6 @@ private async Task UpdateRunnerCredentialsAsync(string serviceType, string confi
await ReportTelemetryAsync($"Credential clientId in refreshed config '{refreshedClientId ?? "Empty"}' does not match the current credential clientId '{clientId}'.");
return;
}

// make sure the credential authorizationUrl in the refreshed config match the current credential authorizationUrl for OAuth auth scheme
var authorizationUrl = _credData.Data.GetValueOrDefault("authorizationUrl", null);
var refreshedAuthorizationUrl = refreshedCredConfig.Data.GetValueOrDefault("authorizationUrl", null);
if (authorizationUrl != refreshedAuthorizationUrl)
{
Trace.Error($"Credential authorizationUrl in refreshed config '{refreshedAuthorizationUrl ?? "Empty"}' does not match the current credential authorizationUrl '{authorizationUrl}'.");
await ReportTelemetryAsync($"Credential authorizationUrl in refreshed config '{refreshedAuthorizationUrl ?? "Empty"}' does not match the current credential authorizationUrl '{authorizationUrl}'.");
return;
}
}

// save the refreshed runner credentials as a separate file
Expand Down
2 changes: 1 addition & 1 deletion src/Test/L0/Listener/BrokerMessageListenerL0.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public async void CreatesSession()
tokenSource.Token))
.Returns(Task.FromResult(expectedSession));

_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));

Expand Down
91 changes: 3 additions & 88 deletions src/Test/L0/Listener/Configuration/RunnerCredentialL0.cs
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
using System.Collections.Generic;
using System.Security.Cryptography;
using GitHub.Runner.Listener;
using GitHub.Runner.Listener;
using GitHub.Runner.Listener.Configuration;
using GitHub.Services.Common;
using GitHub.Services.OAuth;
using Moq;
using Xunit;

namespace GitHub.Runner.Common.Tests.Listener.Configuration
{
public class TestRunnerCredential : CredentialProvider
{
public TestRunnerCredential() : base("TEST") { }
public override VssCredentials GetVssCredentials(IHostContext context, bool allowAuthUrlV2)
public override VssCredentials GetVssCredentials(IHostContext context)
{
Tracing trace = context.GetTrace("OuthAccessToken");
trace.Info("GetVssCredentials()");
Expand All @@ -27,85 +23,4 @@ public override void EnsureCredential(IHostContext context, CommandSettings comm
{
}
}

public class OAuthCredentialTestsL0
{
private Mock<IRSAKeyManager> _rsaKeyManager = new Mock<IRSAKeyManager>();

[Fact]
[Trait("Level", "L0")]
[Trait("Category", "OAuthCredential")]
public void NotUseAuthV2Url()
{
using (TestHostContext hc = new(this))
{
// Arrange.
var oauth = new OAuthCredential();
oauth.CredentialData = new CredentialData()
{
Scheme = Constants.Configuration.OAuth
};
oauth.CredentialData.Data.Add("clientId", "someClientId");
oauth.CredentialData.Data.Add("authorizationUrl", "http://myserver/");
oauth.CredentialData.Data.Add("authorizationUrlV2", "http://myserverv2/");

_rsaKeyManager.Setup(x => x.GetKey()).Returns(RSA.Create(2048));
hc.SetSingleton<IRSAKeyManager>(_rsaKeyManager.Object);

// Act.
var cred = oauth.GetVssCredentials(hc, false); // not allow auth v2

var cred2 = oauth.GetVssCredentials(hc, true); // use auth v2 but hostcontext doesn't

hc.EnableAuthMigration("L0Test");
var cred3 = oauth.GetVssCredentials(hc, false); // not use auth v2 but hostcontext does

oauth.CredentialData.Data.Remove("authorizationUrlV2");
var cred4 = oauth.GetVssCredentials(hc, true); // v2 url is not there

// Assert.
Assert.Equal("http://myserver/", (cred.Federated as VssOAuthCredential).AuthorizationUrl.AbsoluteUri);
Assert.Equal("someClientId", (cred.Federated as VssOAuthCredential).ClientCredential.ClientId);

Assert.Equal("http://myserver/", (cred2.Federated as VssOAuthCredential).AuthorizationUrl.AbsoluteUri);
Assert.Equal("someClientId", (cred2.Federated as VssOAuthCredential).ClientCredential.ClientId);

Assert.Equal("http://myserver/", (cred3.Federated as VssOAuthCredential).AuthorizationUrl.AbsoluteUri);
Assert.Equal("someClientId", (cred3.Federated as VssOAuthCredential).ClientCredential.ClientId);

Assert.Equal("http://myserver/", (cred4.Federated as VssOAuthCredential).AuthorizationUrl.AbsoluteUri);
Assert.Equal("someClientId", (cred4.Federated as VssOAuthCredential).ClientCredential.ClientId);
}
}

[Fact]
[Trait("Level", "L0")]
[Trait("Category", "OAuthCredential")]
public void UseAuthV2Url()
{
using (TestHostContext hc = new(this))
{
// Arrange.
var oauth = new OAuthCredential();
oauth.CredentialData = new CredentialData()
{
Scheme = Constants.Configuration.OAuth
};
oauth.CredentialData.Data.Add("clientId", "someClientId");
oauth.CredentialData.Data.Add("authorizationUrl", "http://myserver/");
oauth.CredentialData.Data.Add("authorizationUrlV2", "http://myserverv2/");

_rsaKeyManager.Setup(x => x.GetKey()).Returns(RSA.Create(2048));
hc.SetSingleton<IRSAKeyManager>(_rsaKeyManager.Object);

// Act.
hc.EnableAuthMigration("L0Test");
var cred = oauth.GetVssCredentials(hc, true);

// Assert.
Assert.Equal("http://myserverv2/", (cred.Federated as VssOAuthCredential).AuthorizationUrl.AbsoluteUri);
Assert.Equal("someClientId", (cred.Federated as VssOAuthCredential).ClientCredential.ClientId);
}
}
}
}
}
16 changes: 8 additions & 8 deletions src/Test/L0/Listener/MessageListenerL0.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public async void CreatesSession()
tokenSource.Token))
.Returns(Task.FromResult(expectedSession));

_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));

Expand Down Expand Up @@ -127,7 +127,7 @@ public async void CreatesSessionWithBrokerMigration()
tokenSource.Token))
.Returns(Task.FromResult(expectedBrokerSession));

_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));

Expand Down Expand Up @@ -177,7 +177,7 @@ public async void DeleteSession()
tokenSource.Token))
.Returns(Task.FromResult(expectedSession));

_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));

Expand Down Expand Up @@ -237,7 +237,7 @@ public async void DeleteSessionWithBrokerMigration()
tokenSource.Token))
.Returns(Task.FromResult(expectedBrokerSession));

_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));

Expand Down Expand Up @@ -301,7 +301,7 @@ public async void GetNextMessage()
tokenSource.Token))
.Returns(Task.FromResult(expectedSession));

_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));

Expand Down Expand Up @@ -382,7 +382,7 @@ public async void GetNextMessageWithBrokerMigration()
tokenSource.Token))
.Returns(Task.FromResult(expectedSession));

_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));

Expand Down Expand Up @@ -484,7 +484,7 @@ public async void CreateSessionWithOriginalCredential()
tokenSource.Token))
.Returns(Task.FromResult(expectedSession));

_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());

var originalCred = new CredentialData() { Scheme = Constants.Configuration.OAuth };
originalCred.Data["authorizationUrl"] = "https://s.server";
Expand Down Expand Up @@ -533,7 +533,7 @@ public async void SkipDeleteSession_WhenGetNextMessageGetTaskAgentAccessTokenExp
tokenSource.Token))
.Returns(Task.FromResult(expectedSession));

_credMgr.Setup(x => x.LoadCredentials(It.IsAny<bool>())).Returns(new VssCredentials());
_credMgr.Setup(x => x.LoadCredentials()).Returns(new VssCredentials());
_store.Setup(x => x.GetCredentials()).Returns(new CredentialData() { Scheme = Constants.Configuration.OAuthAccessToken });
_store.Setup(x => x.GetMigratedCredentials()).Returns(default(CredentialData));

Expand Down
Loading