diff --git a/.gitignore b/.gitignore index 9412f5be..652a2084 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,4 @@ msbuild.wrn # Test files *.trx +.nuget/ diff --git a/src/ImageBuilder.Tests/CachingTokenCredentialTests.cs b/src/ImageBuilder.Tests/CachingTokenCredentialTests.cs new file mode 100644 index 00000000..c333b20e --- /dev/null +++ b/src/ImageBuilder.Tests/CachingTokenCredentialTests.cs @@ -0,0 +1,216 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core; +using Moq; +using Xunit; + +namespace Microsoft.DotNet.ImageBuilder.Tests; + +public class CachingTokenCredentialTests +{ + private static readonly string[] TestScopes = ["https://test.scope/.default"]; + private static readonly TokenRequestContext TestRequestContext = new(TestScopes); + + [Fact] + public void GetToken_CachesTokenOnFirstCall() + { + // Arrange + var expectedToken = CreateToken(expiresInMinutes: 60); + var innerCredential = new Mock(); + innerCredential + .Setup(c => c.GetToken(It.IsAny(), It.IsAny())) + .Returns(expectedToken); + + var cachingCredential = new CachingTokenCredential(innerCredential.Object); + + // Act + var token1 = cachingCredential.GetToken(TestRequestContext, CancellationToken.None); + var token2 = cachingCredential.GetToken(TestRequestContext, CancellationToken.None); + + // Assert + Assert.Equal(expectedToken.Token, token1.Token); + Assert.Equal(expectedToken.Token, token2.Token); + + // Verify the inner credential was only called once (token was cached) + innerCredential.Verify( + c => c.GetToken(It.IsAny(), It.IsAny()), + Times.Once); + } + + [Fact] + public void GetToken_RefreshesExpiredToken() + { + // Arrange + var expiredToken = CreateToken(expiresInMinutes: 4); // Expires in 4 minutes (within 5-minute buffer) + var freshToken = CreateToken(expiresInMinutes: 60); + + var innerCredential = new Mock(); + innerCredential + .SetupSequence(c => c.GetToken(It.IsAny(), It.IsAny())) + .Returns(expiredToken) + .Returns(freshToken); + + var cachingCredential = new CachingTokenCredential(innerCredential.Object); + + // Act - First call returns expired token + var token1 = cachingCredential.GetToken(TestRequestContext, CancellationToken.None); + + // Second call should refresh since token is about to expire + var token2 = cachingCredential.GetToken(TestRequestContext, CancellationToken.None); + + // Assert + Assert.Equal(expiredToken.Token, token1.Token); + Assert.Equal(freshToken.Token, token2.Token); + + // Verify the inner credential was called twice (once for initial, once for refresh) + innerCredential.Verify( + c => c.GetToken(It.IsAny(), It.IsAny()), + Times.Exactly(2)); + } + + [Fact] + public async Task GetTokenAsync_CachesTokenOnFirstCall() + { + // Arrange + var expectedToken = CreateToken(expiresInMinutes: 60); + var innerCredential = new Mock(); + innerCredential + .Setup(c => c.GetTokenAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(expectedToken); + + var cachingCredential = new CachingTokenCredential(innerCredential.Object); + + // Act + var token1 = await cachingCredential.GetTokenAsync(TestRequestContext, CancellationToken.None); + var token2 = await cachingCredential.GetTokenAsync(TestRequestContext, CancellationToken.None); + + // Assert + Assert.Equal(expectedToken.Token, token1.Token); + Assert.Equal(expectedToken.Token, token2.Token); + + // Verify the inner credential was only called once (token was cached) + innerCredential.Verify( + c => c.GetTokenAsync(It.IsAny(), It.IsAny()), + Times.Once); + } + + [Fact] + public async Task GetTokenAsync_RefreshesExpiredToken() + { + // Arrange + var expiredToken = CreateToken(expiresInMinutes: 4); // Expires in 4 minutes (within 5-minute buffer) + var freshToken = CreateToken(expiresInMinutes: 60); + + var innerCredential = new Mock(); + innerCredential + .SetupSequence(c => c.GetTokenAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(expiredToken) + .ReturnsAsync(freshToken); + + var cachingCredential = new CachingTokenCredential(innerCredential.Object); + + // Act - First call returns expired token + var token1 = await cachingCredential.GetTokenAsync(TestRequestContext, CancellationToken.None); + + // Second call should refresh since token is about to expire + var token2 = await cachingCredential.GetTokenAsync(TestRequestContext, CancellationToken.None); + + // Assert + Assert.Equal(expiredToken.Token, token1.Token); + Assert.Equal(freshToken.Token, token2.Token); + + // Verify the inner credential was called twice (once for initial, once for refresh) + innerCredential.Verify( + c => c.GetTokenAsync(It.IsAny(), It.IsAny()), + Times.Exactly(2)); + } + + [Fact] + public void GetToken_ValidTokenNotRefreshed() + { + // Arrange + var validToken = CreateToken(expiresInMinutes: 30); // Expires in 30 minutes (well beyond 5-minute buffer) + + var innerCredential = new Mock(); + innerCredential + .Setup(c => c.GetToken(It.IsAny(), It.IsAny())) + .Returns(validToken); + + var cachingCredential = new CachingTokenCredential(innerCredential.Object); + + // Act - Multiple calls + var token1 = cachingCredential.GetToken(TestRequestContext, CancellationToken.None); + var token2 = cachingCredential.GetToken(TestRequestContext, CancellationToken.None); + var token3 = cachingCredential.GetToken(TestRequestContext, CancellationToken.None); + + // Assert - All tokens should be the same + Assert.Equal(validToken.Token, token1.Token); + Assert.Equal(validToken.Token, token2.Token); + Assert.Equal(validToken.Token, token3.Token); + + // Verify the inner credential was only called once + innerCredential.Verify( + c => c.GetToken(It.IsAny(), It.IsAny()), + Times.Once); + } + + [Fact] + public void Constructor_ThrowsOnNullCredential() + { + // Act & Assert + Assert.Throws(() => new CachingTokenCredential(null!)); + } + + [Fact] + public async Task GetTokenAsync_ConcurrentCalls_OnlyFetchesOnce() + { + // Arrange + var expectedToken = CreateToken(expiresInMinutes: 60); + var callCount = 0; + + var innerCredential = new Mock(); + innerCredential + .Setup(c => c.GetTokenAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(() => + { + Interlocked.Increment(ref callCount); + // Simulate a slow token fetch + Thread.Sleep(100); + return expectedToken; + }); + + var cachingCredential = new CachingTokenCredential(innerCredential.Object); + + // Act - Start multiple concurrent token requests + var tasks = new Task[10]; + for (int i = 0; i < tasks.Length; i++) + { + tasks[i] = cachingCredential.GetTokenAsync(TestRequestContext, CancellationToken.None).AsTask(); + } + + var results = await Task.WhenAll(tasks); + + // Assert - All results should be the same token + foreach (var result in results) + { + Assert.Equal(expectedToken.Token, result.Token); + } + + // Verify the inner credential was only called once despite concurrent requests + Assert.Equal(1, callCount); + } + + private static AccessToken CreateToken(int expiresInMinutes) + { + return new AccessToken( + accessToken: Guid.NewGuid().ToString(), + expiresOn: DateTimeOffset.UtcNow.AddMinutes(expiresInMinutes)); + } +} diff --git a/src/ImageBuilder/AzureTokenCredentialProvider.cs b/src/ImageBuilder/AzureTokenCredentialProvider.cs index c65fe2ea..c5b65007 100644 --- a/src/ImageBuilder/AzureTokenCredentialProvider.cs +++ b/src/ImageBuilder/AzureTokenCredentialProvider.cs @@ -81,7 +81,11 @@ Attempted to get Service Connection credential but SYSTEM_ACCESSTOKEN environmen ); } - return credential; + // Wrap the credential with CachingTokenCredential to ensure tokens are cached. + // AzurePipelinesCredential does not cache tokens internally, so each call to + // GetToken would make a new request to Azure, which is slow. The caching wrapper + // caches the token and refreshes it only when it's close to expiration. + return new CachingTokenCredential(credential); }); } } diff --git a/src/ImageBuilder/CachingTokenCredential.cs b/src/ImageBuilder/CachingTokenCredential.cs new file mode 100644 index 00000000..7beaae41 --- /dev/null +++ b/src/ImageBuilder/CachingTokenCredential.cs @@ -0,0 +1,91 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Azure.Core; + +namespace Microsoft.DotNet.ImageBuilder; + +#nullable enable + +/// +/// A wrapper that caches access tokens and refreshes them +/// when they are close to expiration. This is necessary for credentials like +/// that do not cache tokens internally. +/// +/// +/// This class uses to provide thread-safe access for both synchronous +/// and asynchronous callers, preventing concurrent token fetches that would waste network resources. +/// The is used instead of a regular lock because it supports both +/// and +/// methods, allowing coordination between +/// the sync and async methods. +/// +internal class CachingTokenCredential : TokenCredential +{ + private readonly TokenCredential _innerCredential; + private readonly SemaphoreSlim _semaphore = new(1, 1); + private AccessToken? _cachedToken; + + /// + /// The amount of time before token expiration at which a new token should be fetched. + /// This ensures we don't use a token that's about to expire. + /// + private static readonly TimeSpan TokenRefreshBuffer = TimeSpan.FromMinutes(5); + + public CachingTokenCredential(TokenCredential innerCredential) + { + _innerCredential = innerCredential ?? throw new ArgumentNullException(nameof(innerCredential)); + } + + public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) + { + _semaphore.Wait(cancellationToken); + try + { + if (IsTokenValid(_cachedToken)) + { + return _cachedToken!.Value; + } + + _cachedToken = _innerCredential.GetToken(requestContext, cancellationToken); + return _cachedToken.Value; + } + finally + { + _semaphore.Release(); + } + } + + public override async ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) + { + await _semaphore.WaitAsync(cancellationToken); + try + { + if (IsTokenValid(_cachedToken)) + { + return _cachedToken!.Value; + } + + _cachedToken = await _innerCredential.GetTokenAsync(requestContext, cancellationToken); + return _cachedToken.Value; + } + finally + { + _semaphore.Release(); + } + } + + private static bool IsTokenValid(AccessToken? token) + { + if (!token.HasValue) + { + return false; + } + + // Token is valid if it's not expired and won't expire within the refresh buffer + return token.Value.ExpiresOn > DateTimeOffset.UtcNow.Add(TokenRefreshBuffer); + } +} diff --git a/src/ImageBuilder/Microsoft.DotNet.ImageBuilder.csproj b/src/ImageBuilder/Microsoft.DotNet.ImageBuilder.csproj index bb9fad0d..31a1aa91 100644 --- a/src/ImageBuilder/Microsoft.DotNet.ImageBuilder.csproj +++ b/src/ImageBuilder/Microsoft.DotNet.ImageBuilder.csproj @@ -14,6 +14,10 @@ false + + + +