diff --git a/credentials/credentials.ts b/credentials/credentials.ts index b350178f..cb425f65 100644 --- a/credentials/credentials.ts +++ b/credentials/credentials.ts @@ -45,6 +45,7 @@ export const DEFAULT_TOKEN_ENDPOINT_PATH = "oauth/token"; export class Credentials { private accessToken?: string; private accessTokenExpiryDate?: Date; + private refreshAccessTokenPromise?: Promise; public static init(configuration: { credentials: AuthCredentialsConfig, telemetry: TelemetryConfiguration, baseOptions?: any }, axios: AxiosInstance = globalAxios): Credentials { return new Credentials(configuration.credentials, axios, configuration.telemetry, configuration.baseOptions); @@ -141,7 +142,13 @@ export class Credentials { return this.accessToken; } - return this.refreshAccessToken(); + if (!this.refreshAccessTokenPromise) { + this.refreshAccessTokenPromise = this.refreshAccessToken().finally(() => { + this.refreshAccessTokenPromise = undefined; + }); + } + + return this.refreshAccessTokenPromise; } } diff --git a/tests/credentials.test.ts b/tests/credentials.test.ts index a465c794..c5396927 100644 --- a/tests/credentials.test.ts +++ b/tests/credentials.test.ts @@ -537,5 +537,93 @@ describe("Credentials", () => { expect(scope.isDone()).toBe(true); }); + + test("should send a single token request for concurrent access token reads", async () => { + const apiTokenIssuer = "issuer.fga.example"; + const expectedBaseUrl = "https://issuer.fga.example"; + const expectedPath = `/${DEFAULT_TOKEN_ENDPOINT_PATH}`; + + const scope = nock(expectedBaseUrl) + .post(expectedPath) + .once() + .delay(20) + .reply(200, { + access_token: "shared-token", + expires_in: 300, + }); + + const credentials = new Credentials( + { + method: CredentialsMethod.ClientCredentials, + config: { + apiTokenIssuer, + apiAudience: OPENFGA_API_AUDIENCE, + clientId: OPENFGA_CLIENT_ID, + clientSecret: OPENFGA_CLIENT_SECRET, + }, + } as AuthCredentialsConfig, + undefined, + mockTelemetryConfig, + ); + + const headers = await Promise.all( + Array.from({ length: 5 }, () => credentials.getAccessTokenHeader()) + ); + + headers.forEach(header => { + expect(header?.value).toBe("Bearer shared-token"); + }); + expect(scope.isDone()).toBe(true); + }); + + test("should clear shared refresh promise after failure and retry on the next call", async () => { + const apiTokenIssuer = "issuer.fga.example"; + const expectedBaseUrl = "https://issuer.fga.example"; + const expectedPath = `/${DEFAULT_TOKEN_ENDPOINT_PATH}`; + + const scope = nock(expectedBaseUrl) + .post(expectedPath) + .once() + .reply(404, { + code: "not_found", + message: "token exchange failed", + }) + .post(expectedPath) + .once() + .reply(200, { + access_token: "recovered-token", + expires_in: 300, + }); + + const credentials = new Credentials( + { + method: CredentialsMethod.ClientCredentials, + config: { + apiTokenIssuer, + apiAudience: OPENFGA_API_AUDIENCE, + clientId: OPENFGA_CLIENT_ID, + clientSecret: OPENFGA_CLIENT_SECRET, + }, + } as AuthCredentialsConfig, + undefined, + mockTelemetryConfig, + ); + + const results = await Promise.allSettled( + Array.from({ length: 5 }, () => credentials.getAccessTokenHeader()) + ); + const rejected = results.filter((result): result is PromiseRejectedResult => result.status === "rejected"); + + expect(rejected).toHaveLength(5); + expect(rejected[0].reason).toBe(rejected[1].reason); + expect(rejected[1].reason).toBe(rejected[2].reason); + expect(rejected[2].reason).toBe(rejected[3].reason); + expect(rejected[3].reason).toBe(rejected[4].reason); + + const header = await credentials.getAccessTokenHeader(); + + expect(header?.value).toBe("Bearer recovered-token"); + expect(scope.isDone()).toBe(true); + }); }); });