diff --git a/kafka-java-auth/src/main/java/com/google/cloud/hosted/kafka/auth/GcpLoginCallbackHandler.java b/kafka-java-auth/src/main/java/com/google/cloud/hosted/kafka/auth/GcpLoginCallbackHandler.java index 11b9ad8..17fc48c 100755 --- a/kafka-java-auth/src/main/java/com/google/cloud/hosted/kafka/auth/GcpLoginCallbackHandler.java +++ b/kafka-java-auth/src/main/java/com/google/cloud/hosted/kafka/auth/GcpLoginCallbackHandler.java @@ -70,6 +70,10 @@ abstract static class StubGoogleCredentials extends GoogleCredentials { abstract String getAccount(); } + String getPrincipalFromEnvironmentVariable() { + return System.getenv("GOOGLE_MANAGED_KAFKA_AUTH_PRINCIPAL"); + } + private static final String HEADER = new Gson().toJson(ImmutableMap.of("typ", "JWT", "alg", "GOOG_OAUTH2_TOKEN")); @@ -138,9 +142,20 @@ private void handleTokenCallback(OAuthBearerTokenCallback callback) throws IOExc subject = ((StubGoogleCredentials) credentials).getAccount(); } else if (credentials instanceof IdTokenProvider) { subject = parseGoogleIdToken((IdTokenProvider) credentials).getEmail(); - } else { - throw new IOException("Unknown credentials type: " + credentials.getClass().getName()); } + + // Allow overriding the principal via an environment variable. This is useful for + // credentials that do not support the getAccount() or similar method, such as + // Workforce Identity Federation + String envSubject = getPrincipalFromEnvironmentVariable(); + if (envSubject != null && !envSubject.isEmpty()) { + subject = envSubject; + } + if (subject == null || subject.isEmpty()) { + throw new IOException("Unable to determine principal for credentials type: " + credentials.getClass().getName() + + ". Please set the GOOGLE_MANAGED_KAFKA_AUTH_PRINCIPAL environment variable."); + } + credentials.refreshIfExpired(); AccessToken googleAccessToken = credentials.getAccessToken(); String kafkaToken = getKafkaAccessToken(googleAccessToken, subject); diff --git a/kafka-java-auth/src/test/java/com/google/cloud/hosted/kafka/auth/GcpLoginCallbackHandlerTest.java b/kafka-java-auth/src/test/java/com/google/cloud/hosted/kafka/auth/GcpLoginCallbackHandlerTest.java index 7da15e5..10a6907 100755 --- a/kafka-java-auth/src/test/java/com/google/cloud/hosted/kafka/auth/GcpLoginCallbackHandlerTest.java +++ b/kafka-java-auth/src/test/java/com/google/cloud/hosted/kafka/auth/GcpLoginCallbackHandlerTest.java @@ -56,10 +56,15 @@ public String getAccount() { } } - static class UnsupportedCredentials extends GoogleCredentials {} + static class UnsupportedCredentials extends GoogleCredentials { + @Override + public AccessToken refreshAccessToken() throws IOException { + return new AccessToken("fake-access-token", Date.from(Instant.now().plusSeconds(3600))); + } + } - public GcpLoginCallbackHandler createHandler(GoogleCredentials credentials) throws Exception { - GcpLoginCallbackHandler gcpLoginCallbackHandler = new GcpLoginCallbackHandler(credentials); + public GcpLoginCallbackHandlerWithEnv createHandler(GoogleCredentials credentials, String principal) throws Exception { + GcpLoginCallbackHandlerWithEnv gcpLoginCallbackHandler = new GcpLoginCallbackHandlerWithEnv(credentials, principal); HashMap configs = new HashMap(); ArrayList jaasConfig = new ArrayList(); jaasConfig.add( @@ -72,13 +77,27 @@ public GcpLoginCallbackHandler createHandler(GoogleCredentials credentials) thro return gcpLoginCallbackHandler; } + static class GcpLoginCallbackHandlerWithEnv extends GcpLoginCallbackHandler { + private final String principal; + + GcpLoginCallbackHandlerWithEnv(GoogleCredentials credentials, String principal) { + super(credentials); + this.principal = principal; + } + + @Override + String getPrincipalFromEnvironmentVariable() { + return principal; + } + } + @Test - public void success() throws Exception { + public void success_withSupportedCredentials() throws Exception { Instant now = Instant.now(); OAuthBearerTokenCallback oauthBearerTokenCallback = new OAuthBearerTokenCallback(); Callback[] callbacks = {oauthBearerTokenCallback}; - GcpLoginCallbackHandler gcpOAuthBearerLoginCallbackHandler = createHandler(new FakeGoogleCredentials()); + GcpLoginCallbackHandler gcpOAuthBearerLoginCallbackHandler = createHandler(new FakeGoogleCredentials(), null); gcpOAuthBearerLoginCallbackHandler.handle(callbacks); OAuthBearerToken oauthBearerToken = oauthBearerTokenCallback.token(); @@ -102,11 +121,70 @@ public void success() throws Exception { } @Test - public void fail_unsupportedCredentialType() throws Exception { + // Test that we can override the principal when using a supported credentials type. + public void success_withSupportedCredentialsEnvPrincipalOverride() throws Exception { + Instant now = Instant.now(); + OAuthBearerTokenCallback oauthBearerTokenCallback = new OAuthBearerTokenCallback(); + Callback[] callbacks = {oauthBearerTokenCallback}; + + GcpLoginCallbackHandler gcpOAuthBearerLoginCallbackHandler = createHandler(new FakeGoogleCredentials(), "fake-environment-account@google.com"); + gcpOAuthBearerLoginCallbackHandler.handle(callbacks); + + OAuthBearerToken oauthBearerToken = oauthBearerTokenCallback.token(); + SerializedJwt jwtToken = new SerializedJwt(oauthBearerToken.value()); + + Map header = OAuthBearerUnsecuredJws.toMap(jwtToken.getHeader()); + Map payload = OAuthBearerUnsecuredJws.toMap(jwtToken.getPayload()); + + // Validate the JWT token is as expected by our server. + assertThat(header.get("typ")).isEqualTo("JWT"); + assertThat(header.get("alg")).isEqualTo("GOOG_OAUTH2_TOKEN"); + + assertThat(payload.get("exp")).isInstanceOf(Integer.class); + assertThat(((int) payload.get("exp"))).isGreaterThan((int) now.getEpochSecond()); + + // The signature is the base64 encoded Google OAuth token. + assertThat(new String(Base64.getUrlDecoder().decode(jwtToken.getSignature()), UTF_8)) + .isEqualTo("fake-access-token"); + assertThat(oauthBearerToken.scope()).isEqualTo(ImmutableSet.of("kafka")); + assertThat(oauthBearerToken.principalName()).isEqualTo("fake-environment-account@google.com"); + } + + @Test + public void fail_withUnsupportedCredentialsNoEnvPrincipal() throws Exception { OAuthBearerTokenCallback oauthBearerTokenCallback = new OAuthBearerTokenCallback(); Callback[] callbacks = {oauthBearerTokenCallback}; - GcpLoginCallbackHandler gcpOAuthBearerLoginCallbackHandler = createHandler(new UnsupportedCredentials()); + GcpLoginCallbackHandler gcpOAuthBearerLoginCallbackHandler = createHandler(new UnsupportedCredentials(), null); assertThrows(IOException.class, () -> gcpOAuthBearerLoginCallbackHandler.handle(callbacks)); } + + @Test + public void success_withUnsupportedCredentialsEnvPrincipal() throws Exception { + Instant now = Instant.now(); + OAuthBearerTokenCallback oauthBearerTokenCallback = new OAuthBearerTokenCallback(); + Callback[] callbacks = {oauthBearerTokenCallback}; + + GcpLoginCallbackHandler gcpOAuthBearerLoginCallbackHandler = createHandler(new UnsupportedCredentials(), "fake-environment-account@google.com"); + gcpOAuthBearerLoginCallbackHandler.handle(callbacks); + + OAuthBearerToken oauthBearerToken = oauthBearerTokenCallback.token(); + SerializedJwt jwtToken = new SerializedJwt(oauthBearerToken.value()); + + Map header = OAuthBearerUnsecuredJws.toMap(jwtToken.getHeader()); + Map payload = OAuthBearerUnsecuredJws.toMap(jwtToken.getPayload()); + + // Validate the JWT token is as expected by our server. + assertThat(header.get("typ")).isEqualTo("JWT"); + assertThat(header.get("alg")).isEqualTo("GOOG_OAUTH2_TOKEN"); + + assertThat(payload.get("exp")).isInstanceOf(Integer.class); + assertThat(((int) payload.get("exp"))).isGreaterThan((int) now.getEpochSecond()); + + // The signature is the base64 encoded Google OAuth token. + assertThat(new String(Base64.getUrlDecoder().decode(jwtToken.getSignature()), UTF_8)) + .isEqualTo("fake-access-token"); + assertThat(oauthBearerToken.scope()).isEqualTo(ImmutableSet.of("kafka")); + assertThat(oauthBearerToken.principalName()).isEqualTo("fake-environment-account@google.com"); + } }