Skip to content

[BEEEP] Lazy load the current user in the CurrentContext #5605

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weโ€™ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/Api/Billing/Controllers/AccountsBillingController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Bit.Api.Billing.Models.Responses;
using Bit.Core.Billing.Models.Api.Requests.Accounts;
using Bit.Core.Billing.Services;
using Bit.Core.Context;
using Bit.Core.Services;
using Bit.Core.Utilities;
using Microsoft.AspNetCore.Authorization;
Expand All @@ -12,15 +13,15 @@
[Route("accounts/billing")]
[Authorize("Application")]
public class AccountsBillingController(
ICurrentContext currentContext,

Check warning on line 16 in src/Api/Billing/Controllers/AccountsBillingController.cs

View check run for this annotation

Codecov / codecov/patch

src/Api/Billing/Controllers/AccountsBillingController.cs#L16

Added line #L16 was not covered by tests
IPaymentService paymentService,
IUserService userService,
IPaymentHistoryService paymentHistoryService) : Controller
{
[HttpGet("history")]
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<BillingHistoryResponseModel> GetBillingHistoryAsync()
{
var user = await userService.GetUserByPrincipalAsync(User);
var user = await currentContext.UserAsync.Value;

Check warning on line 24 in src/Api/Billing/Controllers/AccountsBillingController.cs

View check run for this annotation

Codecov / codecov/patch

src/Api/Billing/Controllers/AccountsBillingController.cs#L24

Added line #L24 was not covered by tests
if (user == null)
{
throw new UnauthorizedAccessException();
Expand All @@ -34,7 +35,7 @@
[SelfHosted(NotSelfHostedOnly = true)]
public async Task<BillingPaymentResponseModel> GetPaymentMethodAsync()
{
var user = await userService.GetUserByPrincipalAsync(User);
var user = await currentContext.UserAsync.Value;

Check warning on line 38 in src/Api/Billing/Controllers/AccountsBillingController.cs

View check run for this annotation

Codecov / codecov/patch

src/Api/Billing/Controllers/AccountsBillingController.cs#L38

Added line #L38 was not covered by tests
if (user == null)
{
throw new UnauthorizedAccessException();
Expand All @@ -47,7 +48,7 @@
[HttpGet("invoices")]
public async Task<IResult> GetInvoicesAsync([FromQuery] string? status = null, [FromQuery] string? startAfter = null)
{
var user = await userService.GetUserByPrincipalAsync(User);
var user = await currentContext.UserAsync.Value;

Check warning on line 51 in src/Api/Billing/Controllers/AccountsBillingController.cs

View check run for this annotation

Codecov / codecov/patch

src/Api/Billing/Controllers/AccountsBillingController.cs#L51

Added line #L51 was not covered by tests
if (user == null)
{
throw new UnauthorizedAccessException();
Expand All @@ -65,7 +66,7 @@
[HttpGet("transactions")]
public async Task<IResult> GetTransactionsAsync([FromQuery] DateTime? startAfter = null)
{
var user = await userService.GetUserByPrincipalAsync(User);
var user = await currentContext.UserAsync.Value;

Check warning on line 69 in src/Api/Billing/Controllers/AccountsBillingController.cs

View check run for this annotation

Codecov / codecov/patch

src/Api/Billing/Controllers/AccountsBillingController.cs#L69

Added line #L69 was not covered by tests
if (user == null)
{
throw new UnauthorizedAccessException();
Expand All @@ -82,7 +83,7 @@
[HttpPost("preview-invoice")]
public async Task<IResult> PreviewInvoiceAsync([FromBody] PreviewIndividualInvoiceRequestBody model)
{
var user = await userService.GetUserByPrincipalAsync(User);
var user = await currentContext.UserAsync.Value;

Check warning on line 86 in src/Api/Billing/Controllers/AccountsBillingController.cs

View check run for this annotation

Codecov / codecov/patch

src/Api/Billing/Controllers/AccountsBillingController.cs#L86

Added line #L86 was not covered by tests
if (user == null)
{
throw new UnauthorizedAccessException();
Expand Down
21 changes: 10 additions & 11 deletions src/Core/Auth/Identity/UserStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,31 +44,30 @@

public async Task<User> FindByEmailAsync(string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken))
{
if (_currentContext?.User != null && _currentContext.User.Email == normalizedEmail)
var currentUser = await _currentContext.UserAsync.Value;
if (currentUser != null && currentUser.Email == normalizedEmail)
{
return _currentContext.User;
return currentUser;

Check warning on line 50 in src/Core/Auth/Identity/UserStore.cs

View check run for this annotation

Codecov / codecov/patch

src/Core/Auth/Identity/UserStore.cs#L50

Added line #L50 was not covered by tests
}

_currentContext.User = await _userRepository.GetByEmailAsync(normalizedEmail);
return _currentContext.User;
return await _userRepository.GetByEmailAsync(normalizedEmail);
}

public async Task<User> FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken))
{
if (_currentContext?.User != null &&
string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
var currentUser = await _currentContext.UserAsync.Value;
if (currentUser != null &&
string.Equals(currentUser.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
{
return _currentContext.User;
return currentUser;

Check warning on line 62 in src/Core/Auth/Identity/UserStore.cs

View check run for this annotation

Codecov / codecov/patch

src/Core/Auth/Identity/UserStore.cs#L62

Added line #L62 was not covered by tests
}

Guid userIdGuid;
if (!Guid.TryParse(userId, out userIdGuid))
if (!Guid.TryParse(userId, out var userIdGuid))
{
return null;
}

_currentContext.User = await _userRepository.GetByIdAsync(userIdGuid);
return _currentContext.User;
return await _userRepository.GetByIdAsync(userIdGuid);
}

public async Task<User> FindByNameAsync(string normalizedUserName, CancellationToken cancellationToken = default(CancellationToken))
Expand Down
9 changes: 7 additions & 2 deletions src/Core/Context/CurrentContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@ public class CurrentContext : ICurrentContext
{
private readonly IProviderOrganizationRepository _providerOrganizationRepository;
private readonly IProviderUserRepository _providerUserRepository;
private readonly IUserRepository _userRepository;

private bool _builtHttpContext;
private bool _builtClaimsPrincipal;
private IEnumerable<ProviderOrganizationProviderDetails> _providerOrganizationProviderDetails;
private IEnumerable<ProviderUserOrganizationDetails> _providerUserOrganizations;

public virtual HttpContext HttpContext { get; set; }
public virtual Guid? UserId { get; set; }
public virtual User User { get; set; }
public virtual Lazy<Task<User>> UserAsync { get; private set; } = new(() => Task.FromResult<User>(null));
public virtual string DeviceIdentifier { get; set; }
public virtual DeviceType? DeviceType { get; set; }
public virtual string IpAddress { get; set; }
Expand All @@ -47,10 +49,12 @@ public class CurrentContext : ICurrentContext

public CurrentContext(
IProviderOrganizationRepository providerOrganizationRepository,
IProviderUserRepository providerUserRepository)
IProviderUserRepository providerUserRepository,
IUserRepository userRepository)
{
_providerOrganizationRepository = providerOrganizationRepository;
_providerUserRepository = providerUserRepository;
_userRepository = userRepository;
}

public async virtual Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings)
Expand Down Expand Up @@ -138,6 +142,7 @@ public virtual Task SetContextAsync(ClaimsPrincipal user)
if (Guid.TryParse(subject, out var subIdGuid))
{
UserId = subIdGuid;
UserAsync = new Lazy<Task<User>>(() => _userRepository.GetByIdAsync(UserId.Value));
}

ClientId = GetClaimValue(claimsDict, "client_id");
Expand Down
2 changes: 1 addition & 1 deletion src/Core/Context/ICurrentContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public interface ICurrentContext
{
HttpContext HttpContext { get; set; }
Guid? UserId { get; set; }
User User { get; set; }
Lazy<Task<User>> UserAsync { get; }
string DeviceIdentifier { get; set; }
DeviceType? DeviceType { get; set; }
string IpAddress { get; set; }
Expand Down
18 changes: 9 additions & 9 deletions src/Core/Services/Implementations/UserService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,30 +173,30 @@

public async Task<User> GetUserByIdAsync(string userId)
{
if (_currentContext?.User != null &&
string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))
var currentUser = await _currentContext.UserAsync.Value;

Check warning on line 176 in src/Core/Services/Implementations/UserService.cs

View check run for this annotation

Codecov / codecov/patch

src/Core/Services/Implementations/UserService.cs#L176

Added line #L176 was not covered by tests
if (currentUser != null &&
string.Equals(currentUser.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase))

Check warning on line 178 in src/Core/Services/Implementations/UserService.cs

View check run for this annotation

Codecov / codecov/patch

src/Core/Services/Implementations/UserService.cs#L178

Added line #L178 was not covered by tests
{
return _currentContext.User;
return currentUser;

Check warning on line 180 in src/Core/Services/Implementations/UserService.cs

View check run for this annotation

Codecov / codecov/patch

src/Core/Services/Implementations/UserService.cs#L180

Added line #L180 was not covered by tests
}

if (!Guid.TryParse(userId, out var userIdGuid))
{
return null;
}

_currentContext.User = await _userRepository.GetByIdAsync(userIdGuid);
return _currentContext.User;
return await _userRepository.GetByIdAsync(userIdGuid);

Check warning on line 188 in src/Core/Services/Implementations/UserService.cs

View check run for this annotation

Codecov / codecov/patch

src/Core/Services/Implementations/UserService.cs#L188

Added line #L188 was not covered by tests
}

public async Task<User> GetUserByIdAsync(Guid userId)
{
if (_currentContext?.User != null && _currentContext.User.Id == userId)
var currentUser = await _currentContext.UserAsync.Value;
if (currentUser != null && currentUser.Id == userId)
{
return _currentContext.User;
return currentUser;
}

_currentContext.User = await _userRepository.GetByIdAsync(userId);
return _currentContext.User;
return await _userRepository.GetByIdAsync(userId);
}

public async Task<User> GetUserByPrincipalAsync(ClaimsPrincipal principal)
Expand Down Expand Up @@ -1045,7 +1045,7 @@
}
catch when (!_globalSettings.SelfHosted)
{
await paymentService.CancelAndRecoverChargesAsync(user);

Check warning on line 1048 in src/Core/Services/Implementations/UserService.cs

View workflow job for this annotation

GitHub Actions / Quality scan

'paymentService' is null on at least one execution path. (https://rules.sonarsource.com/csharp/RSPEC-2259)
throw;
}

Expand Down
4 changes: 2 additions & 2 deletions src/Notifications/NotificationsHub.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

public override async Task OnConnectedAsync()
{
var currentContext = new CurrentContext(null, null);
var currentContext = new CurrentContext(null, null, null);

Check warning on line 23 in src/Notifications/NotificationsHub.cs

View check run for this annotation

Codecov / codecov/patch

src/Notifications/NotificationsHub.cs#L23

Added line #L23 was not covered by tests
await currentContext.BuildAsync(Context.User, _globalSettings);

var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
Expand Down Expand Up @@ -57,7 +57,7 @@

public override async Task OnDisconnectedAsync(Exception exception)
{
var currentContext = new CurrentContext(null, null);
var currentContext = new CurrentContext(null, null, null);

Check warning on line 60 in src/Notifications/NotificationsHub.cs

View check run for this annotation

Codecov / codecov/patch

src/Notifications/NotificationsHub.cs#L60

Added line #L60 was not covered by tests
await currentContext.BuildAsync(Context.User, _globalSettings);

var clientType = DeviceTypes.ToClientType(currentContext.DeviceType);
Expand Down
Loading