Skip to content

Commit df08956

Browse files
authored
feat(gax): add API key authentication to ClientSettings (#3137)
Allow gax client libraries to authenticate using API key via setApiKey method exposed from ClientSettings. Also added deduping to GRPC calls for api key headers. Tested using LanguageServiceSettings cc @westarle
1 parent e08906c commit df08956

File tree

12 files changed

+717
-119
lines changed

12 files changed

+717
-119
lines changed

gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java

+32-2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import com.google.api.gax.rpc.TransportChannelProvider;
4444
import com.google.api.gax.rpc.internal.EnvironmentProvider;
4545
import com.google.api.gax.rpc.mtls.MtlsProvider;
46+
import com.google.auth.ApiKeyCredentials;
4647
import com.google.auth.Credentials;
4748
import com.google.auth.oauth2.ComputeEngineCredentials;
4849
import com.google.common.annotations.VisibleForTesting;
@@ -63,6 +64,8 @@
6364
import java.nio.charset.StandardCharsets;
6465
import java.security.GeneralSecurityException;
6566
import java.security.KeyStore;
67+
import java.util.HashMap;
68+
import java.util.List;
6669
import java.util.Map;
6770
import java.util.concurrent.Executor;
6871
import java.util.concurrent.ScheduledExecutorService;
@@ -123,6 +126,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
123126
@Nullable private final Boolean allowNonDefaultServiceAccount;
124127
@VisibleForTesting final ImmutableMap<String, ?> directPathServiceConfig;
125128
@Nullable private final MtlsProvider mtlsProvider;
129+
@VisibleForTesting final Map<String, String> headersWithDuplicatesRemoved = new HashMap<>();
126130

127131
@Nullable
128132
private final ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;
@@ -408,7 +412,8 @@ ChannelCredentials createMtlsChannelCredentials() throws IOException, GeneralSec
408412

409413
private ManagedChannel createSingleChannel() throws IOException {
410414
GrpcHeaderInterceptor headerInterceptor =
411-
new GrpcHeaderInterceptor(headerProvider.getHeaders());
415+
new GrpcHeaderInterceptor(headersWithDuplicatesRemoved);
416+
412417
GrpcMetadataHandlerInterceptor metadataHandlerInterceptor =
413418
new GrpcMetadataHandlerInterceptor();
414419

@@ -496,6 +501,28 @@ private ManagedChannel createSingleChannel() throws IOException {
496501
return managedChannel;
497502
}
498503

504+
/* Remove provided headers that will also get set by {@link com.google.auth.ApiKeyCredentials}. They will be added as part of the grpc call when performing auth
505+
* {@link io.grpc.auth.GoogleAuthLibraryCallCredentials#applyRequestMetadata}. GRPC does not dedup headers {@link https://github.com/grpc/grpc-java/blob/a140e1bb0cfa662bcdb7823d73320eb8d49046f1/api/src/main/java/io/grpc/Metadata.java#L504} so we must before initiating the call.
506+
*
507+
* Note: This is specific for ApiKeyCredentials as duplicate API key headers causes a failure on the back end. At this time we are not sure of the behavior for other credentials.
508+
*/
509+
private void removeApiKeyCredentialDuplicateHeaders() {
510+
if (headerProvider != null) {
511+
headersWithDuplicatesRemoved.putAll(headerProvider.getHeaders());
512+
}
513+
if (credentials != null && credentials instanceof ApiKeyCredentials) {
514+
try {
515+
Map<String, List<String>> credentialRequestMetatData = credentials.getRequestMetadata();
516+
if (credentialRequestMetatData != null) {
517+
headersWithDuplicatesRemoved.keySet().removeAll(credentialRequestMetatData.keySet());
518+
}
519+
} catch (IOException e) {
520+
// unreachable, there is no scenario that getRequestMetatData for ApiKeyCredentials will
521+
// throw an IOException
522+
}
523+
}
524+
}
525+
499526
/**
500527
* Marked as Internal Api and intended for internal use. DirectPath must be enabled via the
501528
* settings and a few other configurations/settings must also be valid for the request to go
@@ -883,7 +910,10 @@ public Builder setDirectPathServiceConfig(Map<String, ?> serviceConfig) {
883910
}
884911

885912
public InstantiatingGrpcChannelProvider build() {
886-
return new InstantiatingGrpcChannelProvider(this);
913+
InstantiatingGrpcChannelProvider instantiatingGrpcChannelProvider =
914+
new InstantiatingGrpcChannelProvider(this);
915+
instantiatingGrpcChannelProvider.removeApiKeyCredentialDuplicateHeaders();
916+
return instantiatingGrpcChannelProvider;
887917
}
888918

889919
/**

gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java

+103
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,21 @@
3434
import static com.google.common.base.Preconditions.checkArgument;
3535
import static com.google.common.truth.Truth.assertThat;
3636
import static org.junit.jupiter.api.Assertions.assertEquals;
37+
import static org.junit.jupiter.api.Assertions.assertNull;
3738
import static org.junit.jupiter.api.Assertions.assertThrows;
3839

3940
import com.google.api.core.ApiFunction;
4041
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
42+
import com.google.api.gax.rpc.FixedHeaderProvider;
4143
import com.google.api.gax.rpc.HeaderProvider;
4244
import com.google.api.gax.rpc.TransportChannel;
4345
import com.google.api.gax.rpc.TransportChannelProvider;
4446
import com.google.api.gax.rpc.internal.EnvironmentProvider;
4547
import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest;
4648
import com.google.api.gax.rpc.mtls.MtlsProvider;
49+
import com.google.auth.ApiKeyCredentials;
4750
import com.google.auth.Credentials;
51+
import com.google.auth.http.AuthHttpConstants;
4852
import com.google.auth.oauth2.CloudShellCredentials;
4953
import com.google.auth.oauth2.ComputeEngineCredentials;
5054
import com.google.common.collect.ImmutableList;
@@ -79,6 +83,8 @@
7983

8084
class InstantiatingGrpcChannelProviderTest extends AbstractMtlsTransportChannelTest {
8185
private static final String DEFAULT_ENDPOINT = "test.googleapis.com:443";
86+
private static final String API_KEY_HEADER_VALUE = "fake_api_key_2";
87+
private static final String API_KEY_AUTH_HEADER_KEY = "x-goog-api-key";
8288
private static String originalOSName;
8389
private ComputeEngineCredentials computeEngineCredentials;
8490

@@ -877,6 +883,103 @@ public void canUseDirectPath_nonGDUUniverseDomain() {
877883
Truth.assertThat(provider.canUseDirectPath()).isFalse();
878884
}
879885

886+
@Test
887+
void providerInitializedWithNonConflictingHeaders_retainsHeaders() {
888+
InstantiatingGrpcChannelProvider.Builder builder =
889+
InstantiatingGrpcChannelProvider.newBuilder()
890+
.setHeaderProvider(getHeaderProviderWithApiKeyHeader())
891+
.setEndpoint("test.random.com:443");
892+
893+
InstantiatingGrpcChannelProvider provider = builder.build();
894+
895+
assertEquals(1, provider.headersWithDuplicatesRemoved.size());
896+
assertEquals(
897+
API_KEY_HEADER_VALUE, provider.headersWithDuplicatesRemoved.get(API_KEY_AUTH_HEADER_KEY));
898+
}
899+
900+
@Test
901+
void providersInitializedWithConflictingApiKeyCredentialHeaders_removesDuplicates() {
902+
String correctApiKey = "fake_api_key";
903+
ApiKeyCredentials apiKeyCredentials = ApiKeyCredentials.create(correctApiKey);
904+
InstantiatingGrpcChannelProvider.Builder builder =
905+
InstantiatingGrpcChannelProvider.newBuilder()
906+
.setCredentials(apiKeyCredentials)
907+
.setHeaderProvider(getHeaderProviderWithApiKeyHeader())
908+
.setEndpoint("test.random.com:443");
909+
910+
InstantiatingGrpcChannelProvider provider = builder.build();
911+
912+
assertEquals(0, provider.headersWithDuplicatesRemoved.size());
913+
assertNull(provider.headersWithDuplicatesRemoved.get(API_KEY_AUTH_HEADER_KEY));
914+
}
915+
916+
@Test
917+
void providersInitializedWithConflictingNonApiKeyCredentialHeaders_doesNotRemoveDuplicates() {
918+
String authProvidedHeader = "Bearer token";
919+
Map<String, String> header = new HashMap<>();
920+
header.put(AuthHttpConstants.AUTHORIZATION, authProvidedHeader);
921+
InstantiatingGrpcChannelProvider.Builder builder =
922+
InstantiatingGrpcChannelProvider.newBuilder()
923+
.setCredentials(computeEngineCredentials)
924+
.setHeaderProvider(FixedHeaderProvider.create(header))
925+
.setEndpoint("test.random.com:443");
926+
927+
InstantiatingGrpcChannelProvider provider = builder.build();
928+
929+
assertEquals(1, provider.headersWithDuplicatesRemoved.size());
930+
assertEquals(
931+
authProvidedHeader,
932+
provider.headersWithDuplicatesRemoved.get(AuthHttpConstants.AUTHORIZATION));
933+
}
934+
935+
@Test
936+
void buildProvider_handlesNullHeaderProvider() {
937+
InstantiatingGrpcChannelProvider.Builder builder =
938+
InstantiatingGrpcChannelProvider.newBuilder().setEndpoint("test.random.com:443");
939+
940+
InstantiatingGrpcChannelProvider provider = builder.build();
941+
942+
assertEquals(0, provider.headersWithDuplicatesRemoved.size());
943+
}
944+
945+
@Test
946+
void buildProvider_handlesNullCredentialsMetadataRequest() throws IOException {
947+
Credentials credentials = Mockito.mock(Credentials.class);
948+
Mockito.when(credentials.getRequestMetadata()).thenReturn(null);
949+
InstantiatingGrpcChannelProvider.Builder builder =
950+
InstantiatingGrpcChannelProvider.newBuilder()
951+
.setHeaderProvider(getHeaderProviderWithApiKeyHeader())
952+
.setEndpoint("test.random.com:443");
953+
954+
InstantiatingGrpcChannelProvider provider = builder.build();
955+
956+
assertEquals(1, provider.headersWithDuplicatesRemoved.size());
957+
assertEquals(
958+
API_KEY_HEADER_VALUE, provider.headersWithDuplicatesRemoved.get(API_KEY_AUTH_HEADER_KEY));
959+
}
960+
961+
@Test
962+
void buildProvider_handlesErrorRetrievingCredentialsMetadataRequest() throws IOException {
963+
Credentials credentials = Mockito.mock(Credentials.class);
964+
Mockito.when(credentials.getRequestMetadata())
965+
.thenThrow(new IOException("Error getting request metadata"));
966+
InstantiatingGrpcChannelProvider.Builder builder =
967+
InstantiatingGrpcChannelProvider.newBuilder()
968+
.setHeaderProvider(getHeaderProviderWithApiKeyHeader())
969+
.setEndpoint("test.random.com:443");
970+
InstantiatingGrpcChannelProvider provider = builder.build();
971+
972+
assertEquals(1, provider.headersWithDuplicatesRemoved.size());
973+
assertEquals(
974+
API_KEY_HEADER_VALUE, provider.headersWithDuplicatesRemoved.get(API_KEY_AUTH_HEADER_KEY));
975+
}
976+
977+
private FixedHeaderProvider getHeaderProviderWithApiKeyHeader() {
978+
Map<String, String> header = new HashMap<>();
979+
header.put(API_KEY_AUTH_HEADER_KEY, API_KEY_HEADER_VALUE);
980+
return FixedHeaderProvider.create(header);
981+
}
982+
880983
private static class FakeLogHandler extends Handler {
881984

882985
List<LogRecord> records = new ArrayList<>();

gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientContext.java

+42-18
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@
4343
import com.google.api.gax.rpc.internal.QuotaProjectIdHidingCredentials;
4444
import com.google.api.gax.tracing.ApiTracerFactory;
4545
import com.google.api.gax.tracing.BaseApiTracerFactory;
46+
import com.google.auth.ApiKeyCredentials;
4647
import com.google.auth.Credentials;
4748
import com.google.auth.oauth2.GdchCredentials;
4849
import com.google.auto.value.AutoValue;
50+
import com.google.common.annotations.VisibleForTesting;
4951
import com.google.common.collect.ImmutableList;
5052
import com.google.common.collect.ImmutableMap;
5153
import com.google.common.collect.Sets;
@@ -175,9 +177,9 @@ public static ClientContext create(StubSettings settings) throws IOException {
175177
// A valid EndpointContext should have been created in the StubSettings
176178
EndpointContext endpointContext = settings.getEndpointContext();
177179
String endpoint = endpointContext.resolvedEndpoint();
178-
180+
Credentials credentials = getCredentials(settings);
181+
// check if need to adjust credentials/endpoint/endpointContext for GDC-H
179182
String settingsGdchApiAudience = settings.getGdchApiAudience();
180-
Credentials credentials = settings.getCredentialsProvider().getCredentials();
181183
boolean usingGDCH = credentials instanceof GdchCredentials;
182184
if (usingGDCH) {
183185
// Can only determine if the GDC-H is being used via the Credentials. The Credentials object
@@ -187,22 +189,7 @@ public static ClientContext create(StubSettings settings) throws IOException {
187189
// Resolve the new endpoint with the GDC-H flow
188190
endpoint = endpointContext.resolvedEndpoint();
189191
// We recompute the GdchCredentials with the audience
190-
String audienceString;
191-
if (!Strings.isNullOrEmpty(settingsGdchApiAudience)) {
192-
audienceString = settingsGdchApiAudience;
193-
} else if (!Strings.isNullOrEmpty(endpoint)) {
194-
audienceString = endpoint;
195-
} else {
196-
throw new IllegalArgumentException("Could not infer GDCH api audience from settings");
197-
}
198-
199-
URI gdchAudienceUri;
200-
try {
201-
gdchAudienceUri = URI.create(audienceString);
202-
} catch (IllegalArgumentException ex) { // thrown when passing a malformed uri string
203-
throw new IllegalArgumentException("The GDC-H API audience string is not a valid URI", ex);
204-
}
205-
credentials = ((GdchCredentials) credentials).createWithGdchAudience(gdchAudienceUri);
192+
credentials = getGdchCredentials(settingsGdchApiAudience, endpoint, credentials);
206193
} else if (!Strings.isNullOrEmpty(settingsGdchApiAudience)) {
207194
throw new IllegalArgumentException(
208195
"GDC-H API audience can only be set when using GdchCredentials");
@@ -291,6 +278,43 @@ public static ClientContext create(StubSettings settings) throws IOException {
291278
.build();
292279
}
293280

281+
/** Determines which credentials to use. API key overrides credentials provided by provider. */
282+
private static Credentials getCredentials(StubSettings settings) throws IOException {
283+
Credentials credentials;
284+
if (settings.getApiKey() != null) {
285+
// if API key exists it becomes the default credential
286+
credentials = ApiKeyCredentials.create(settings.getApiKey());
287+
} else {
288+
credentials = settings.getCredentialsProvider().getCredentials();
289+
}
290+
return credentials;
291+
}
292+
293+
/**
294+
* Constructs a new {@link com.google.auth.Credentials} object based on credentials provided with
295+
* a GDC-H audience
296+
*/
297+
@VisibleForTesting
298+
static GdchCredentials getGdchCredentials(
299+
String settingsGdchApiAudience, String endpoint, Credentials credentials) throws IOException {
300+
String audienceString;
301+
if (!Strings.isNullOrEmpty(settingsGdchApiAudience)) {
302+
audienceString = settingsGdchApiAudience;
303+
} else if (!Strings.isNullOrEmpty(endpoint)) {
304+
audienceString = endpoint;
305+
} else {
306+
throw new IllegalArgumentException("Could not infer GDCH api audience from settings");
307+
}
308+
309+
URI gdchAudienceUri;
310+
try {
311+
gdchAudienceUri = URI.create(audienceString);
312+
} catch (IllegalArgumentException ex) { // thrown when passing a malformed uri string
313+
throw new IllegalArgumentException("The GDC-H API audience string is not a valid URI", ex);
314+
}
315+
return ((GdchCredentials) credentials).createWithGdchAudience(gdchAudienceUri);
316+
}
317+
294318
/**
295319
* Getting a header map from HeaderProvider and InternalHeaderProvider from settings with Quota
296320
* Project Id.

gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientSettings.java

+27
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@ public final WatchdogProvider getWatchdogProvider() {
112112
return stubSettings.getStreamWatchdogProvider();
113113
}
114114

115+
/** Gets the API Key that should be used for authentication. */
116+
public final String getApiKey() {
117+
return stubSettings.getApiKey();
118+
}
119+
115120
/** This method is obsolete. Use {@link #getWatchdogCheckIntervalDuration()} instead. */
116121
@Nonnull
117122
@ObsoleteApi("Use getWatchdogCheckIntervalDuration() instead")
@@ -144,6 +149,7 @@ public String toString() {
144149
.add("watchdogProvider", getWatchdogProvider())
145150
.add("watchdogCheckInterval", getWatchdogCheckInterval())
146151
.add("gdchApiAudience", getGdchApiAudience())
152+
.add("apiKey", getApiKey())
147153
.toString();
148154
}
149155

@@ -302,6 +308,21 @@ public B setGdchApiAudience(@Nullable String gdchApiAudience) {
302308
return self();
303309
}
304310

311+
/**
312+
* Sets the API key. The API key will get translated to an {@link
313+
* com.google.auth.ApiKeyCredentials} and stored in {@link ClientContext}.
314+
*
315+
* <p>API Key authorization is not supported for every product. Please check the documentation
316+
* for each product to confirm if it is supported.
317+
*
318+
* <p>Note: If you set an API key and {@link CredentialsProvider} in the same ClientSettings the
319+
* API key will override any credentials provided.
320+
*/
321+
public B setApiKey(String apiKey) {
322+
stubSettings.setApiKey(apiKey);
323+
return self();
324+
}
325+
305326
/**
306327
* Gets the ExecutorProvider that was previously set on this Builder. This ExecutorProvider is
307328
* to use for running asynchronous API call logic (such as retries and long-running operations),
@@ -364,6 +385,11 @@ public WatchdogProvider getWatchdogProvider() {
364385
return stubSettings.getStreamWatchdogProvider();
365386
}
366387

388+
/** Gets the API Key that was previously set on this Builder. */
389+
public String getApiKey() {
390+
return stubSettings.getApiKey();
391+
}
392+
367393
/** This method is obsolete. Use {@link #getWatchdogCheckIntervalDuration()} instead */
368394
@Nullable
369395
@ObsoleteApi("Use getWatchdogCheckIntervalDuration() instead")
@@ -405,6 +431,7 @@ public String toString() {
405431
.add("watchdogProvider", getWatchdogProvider())
406432
.add("watchdogCheckInterval", getWatchdogCheckIntervalDuration())
407433
.add("gdchApiAudience", getGdchApiAudience())
434+
.add("apiKey", getApiKey())
408435
.toString();
409436
}
410437
}

0 commit comments

Comments
 (0)