Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ msbuild.wrn

# Test files
*.trx
.nuget/
216 changes: 216 additions & 0 deletions src/ImageBuilder.Tests/CachingTokenCredentialTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable enable

using System;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Moq;
using Xunit;

namespace Microsoft.DotNet.ImageBuilder.Tests;

public class CachingTokenCredentialTests
{
private static readonly string[] TestScopes = ["https://test.scope/.default"];
private static readonly TokenRequestContext TestRequestContext = new(TestScopes);

[Fact]
public void GetToken_CachesTokenOnFirstCall()
{
// Arrange
var expectedToken = CreateToken(expiresInMinutes: 60);
var innerCredential = new Mock<TokenCredential>();
innerCredential
.Setup(c => c.GetToken(It.IsAny<TokenRequestContext>(), It.IsAny<CancellationToken>()))
.Returns(expectedToken);

var cachingCredential = new CachingTokenCredential(innerCredential.Object);

// Act
var token1 = cachingCredential.GetToken(TestRequestContext, CancellationToken.None);
var token2 = cachingCredential.GetToken(TestRequestContext, CancellationToken.None);

// Assert
Assert.Equal(expectedToken.Token, token1.Token);
Assert.Equal(expectedToken.Token, token2.Token);

// Verify the inner credential was only called once (token was cached)
innerCredential.Verify(
c => c.GetToken(It.IsAny<TokenRequestContext>(), It.IsAny<CancellationToken>()),
Times.Once);
}

[Fact]
public void GetToken_RefreshesExpiredToken()
{
// Arrange
var expiredToken = CreateToken(expiresInMinutes: 4); // Expires in 4 minutes (within 5-minute buffer)
var freshToken = CreateToken(expiresInMinutes: 60);

var innerCredential = new Mock<TokenCredential>();
innerCredential
.SetupSequence(c => c.GetToken(It.IsAny<TokenRequestContext>(), It.IsAny<CancellationToken>()))
.Returns(expiredToken)
.Returns(freshToken);

var cachingCredential = new CachingTokenCredential(innerCredential.Object);

// Act - First call returns expired token
var token1 = cachingCredential.GetToken(TestRequestContext, CancellationToken.None);

// Second call should refresh since token is about to expire
var token2 = cachingCredential.GetToken(TestRequestContext, CancellationToken.None);

// Assert
Assert.Equal(expiredToken.Token, token1.Token);
Assert.Equal(freshToken.Token, token2.Token);

// Verify the inner credential was called twice (once for initial, once for refresh)
innerCredential.Verify(
c => c.GetToken(It.IsAny<TokenRequestContext>(), It.IsAny<CancellationToken>()),
Times.Exactly(2));
}

[Fact]
public async Task GetTokenAsync_CachesTokenOnFirstCall()
{
// Arrange
var expectedToken = CreateToken(expiresInMinutes: 60);
var innerCredential = new Mock<TokenCredential>();
innerCredential
.Setup(c => c.GetTokenAsync(It.IsAny<TokenRequestContext>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(expectedToken);

var cachingCredential = new CachingTokenCredential(innerCredential.Object);

// Act
var token1 = await cachingCredential.GetTokenAsync(TestRequestContext, CancellationToken.None);
var token2 = await cachingCredential.GetTokenAsync(TestRequestContext, CancellationToken.None);

// Assert
Assert.Equal(expectedToken.Token, token1.Token);
Assert.Equal(expectedToken.Token, token2.Token);

// Verify the inner credential was only called once (token was cached)
innerCredential.Verify(
c => c.GetTokenAsync(It.IsAny<TokenRequestContext>(), It.IsAny<CancellationToken>()),
Times.Once);
}

[Fact]
public async Task GetTokenAsync_RefreshesExpiredToken()
{
// Arrange
var expiredToken = CreateToken(expiresInMinutes: 4); // Expires in 4 minutes (within 5-minute buffer)
var freshToken = CreateToken(expiresInMinutes: 60);

var innerCredential = new Mock<TokenCredential>();
innerCredential
.SetupSequence(c => c.GetTokenAsync(It.IsAny<TokenRequestContext>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(expiredToken)
.ReturnsAsync(freshToken);

var cachingCredential = new CachingTokenCredential(innerCredential.Object);

// Act - First call returns expired token
var token1 = await cachingCredential.GetTokenAsync(TestRequestContext, CancellationToken.None);

// Second call should refresh since token is about to expire
var token2 = await cachingCredential.GetTokenAsync(TestRequestContext, CancellationToken.None);

// Assert
Assert.Equal(expiredToken.Token, token1.Token);
Assert.Equal(freshToken.Token, token2.Token);

// Verify the inner credential was called twice (once for initial, once for refresh)
innerCredential.Verify(
c => c.GetTokenAsync(It.IsAny<TokenRequestContext>(), It.IsAny<CancellationToken>()),
Times.Exactly(2));
}

[Fact]
public void GetToken_ValidTokenNotRefreshed()
{
// Arrange
var validToken = CreateToken(expiresInMinutes: 30); // Expires in 30 minutes (well beyond 5-minute buffer)

var innerCredential = new Mock<TokenCredential>();
innerCredential
.Setup(c => c.GetToken(It.IsAny<TokenRequestContext>(), It.IsAny<CancellationToken>()))
.Returns(validToken);

var cachingCredential = new CachingTokenCredential(innerCredential.Object);

// Act - Multiple calls
var token1 = cachingCredential.GetToken(TestRequestContext, CancellationToken.None);
var token2 = cachingCredential.GetToken(TestRequestContext, CancellationToken.None);
var token3 = cachingCredential.GetToken(TestRequestContext, CancellationToken.None);

// Assert - All tokens should be the same
Assert.Equal(validToken.Token, token1.Token);
Assert.Equal(validToken.Token, token2.Token);
Assert.Equal(validToken.Token, token3.Token);

// Verify the inner credential was only called once
innerCredential.Verify(
c => c.GetToken(It.IsAny<TokenRequestContext>(), It.IsAny<CancellationToken>()),
Times.Once);
}

[Fact]
public void Constructor_ThrowsOnNullCredential()
{
// Act & Assert
Assert.Throws<ArgumentNullException>(() => new CachingTokenCredential(null!));
}

[Fact]
public async Task GetTokenAsync_ConcurrentCalls_OnlyFetchesOnce()
{
// Arrange
var expectedToken = CreateToken(expiresInMinutes: 60);
var callCount = 0;

var innerCredential = new Mock<TokenCredential>();
innerCredential
.Setup(c => c.GetTokenAsync(It.IsAny<TokenRequestContext>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(() =>
{
Interlocked.Increment(ref callCount);
// Simulate a slow token fetch
Thread.Sleep(100);
return expectedToken;
});

var cachingCredential = new CachingTokenCredential(innerCredential.Object);

// Act - Start multiple concurrent token requests
var tasks = new Task<AccessToken>[10];
for (int i = 0; i < tasks.Length; i++)
{
tasks[i] = cachingCredential.GetTokenAsync(TestRequestContext, CancellationToken.None).AsTask();
}

var results = await Task.WhenAll(tasks);

// Assert - All results should be the same token
foreach (var result in results)
{
Assert.Equal(expectedToken.Token, result.Token);
}

// Verify the inner credential was only called once despite concurrent requests
Assert.Equal(1, callCount);
}

private static AccessToken CreateToken(int expiresInMinutes)
{
return new AccessToken(
accessToken: Guid.NewGuid().ToString(),
expiresOn: DateTimeOffset.UtcNow.AddMinutes(expiresInMinutes));
}
}
6 changes: 5 additions & 1 deletion src/ImageBuilder/AzureTokenCredentialProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ Attempted to get Service Connection credential but SYSTEM_ACCESSTOKEN environmen
);
}

return credential;
// Wrap the credential with CachingTokenCredential to ensure tokens are cached.
// AzurePipelinesCredential does not cache tokens internally, so each call to
// GetToken would make a new request to Azure, which is slow. The caching wrapper
// caches the token and refreshes it only when it's close to expiration.
return new CachingTokenCredential(credential);
});
}
}
91 changes: 91 additions & 0 deletions src/ImageBuilder/CachingTokenCredential.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;

namespace Microsoft.DotNet.ImageBuilder;

#nullable enable

/// <summary>
/// A <see cref="TokenCredential"/> wrapper that caches access tokens and refreshes them
/// when they are close to expiration. This is necessary for credentials like
/// <see cref="Azure.Identity.AzurePipelinesCredential"/> that do not cache tokens internally.
/// </summary>
/// <remarks>
/// This class uses <see cref="SemaphoreSlim"/> to provide thread-safe access for both synchronous
/// and asynchronous callers, preventing concurrent token fetches that would waste network resources.
/// The <see cref="SemaphoreSlim"/> is used instead of a regular lock because it supports both
/// <see cref="SemaphoreSlim.Wait(CancellationToken)"/> and
/// <see cref="SemaphoreSlim.WaitAsync(CancellationToken)"/> methods, allowing coordination between
/// the sync <see cref="GetToken"/> and async <see cref="GetTokenAsync"/> methods.
/// </remarks>
internal class CachingTokenCredential : TokenCredential
{
private readonly TokenCredential _innerCredential;
private readonly SemaphoreSlim _semaphore = new(1, 1);
private AccessToken? _cachedToken;

/// <summary>
/// The amount of time before token expiration at which a new token should be fetched.
/// This ensures we don't use a token that's about to expire.
/// </summary>
private static readonly TimeSpan TokenRefreshBuffer = TimeSpan.FromMinutes(5);

public CachingTokenCredential(TokenCredential innerCredential)
{
_innerCredential = innerCredential ?? throw new ArgumentNullException(nameof(innerCredential));
}

public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken)
{
_semaphore.Wait(cancellationToken);
try
{
if (IsTokenValid(_cachedToken))
{
return _cachedToken!.Value;
}

_cachedToken = _innerCredential.GetToken(requestContext, cancellationToken);
return _cachedToken.Value;
}
finally
{
_semaphore.Release();
}
}

public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken)
{
await _semaphore.WaitAsync(cancellationToken);
try
{
if (IsTokenValid(_cachedToken))
{
return _cachedToken!.Value;
}

_cachedToken = await _innerCredential.GetTokenAsync(requestContext, cancellationToken);
return _cachedToken.Value;
}
finally
{
_semaphore.Release();
}
}

private static bool IsTokenValid(AccessToken? token)
{
if (!token.HasValue)
{
return false;
}

// Token is valid if it's not expired and won't expire within the refresh buffer
return token.Value.ExpiresOn > DateTimeOffset.UtcNow.Add(TokenRefreshBuffer);
}
}
4 changes: 4 additions & 0 deletions src/ImageBuilder/Microsoft.DotNet.ImageBuilder.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
<SignAssembly>false</SignAssembly>
</PropertyGroup>

<ItemGroup>
<InternalsVisibleTo Include="Microsoft.DotNet.ImageBuilder.Tests" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Azure.Containers.ContainerRegistry" Version="1.2.0" />
<PackageReference Include="Azure.Identity" Version="1.13.2" />
Expand Down
Loading