|
34 | 34 | import static com.google.common.base.Preconditions.checkArgument;
|
35 | 35 | import static com.google.common.truth.Truth.assertThat;
|
36 | 36 | import static org.junit.jupiter.api.Assertions.assertEquals;
|
| 37 | +import static org.junit.jupiter.api.Assertions.assertNull; |
37 | 38 | import static org.junit.jupiter.api.Assertions.assertThrows;
|
38 | 39 |
|
39 | 40 | import com.google.api.core.ApiFunction;
|
40 | 41 | import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
|
| 42 | +import com.google.api.gax.rpc.FixedHeaderProvider; |
41 | 43 | import com.google.api.gax.rpc.HeaderProvider;
|
42 | 44 | import com.google.api.gax.rpc.TransportChannel;
|
43 | 45 | import com.google.api.gax.rpc.TransportChannelProvider;
|
44 | 46 | import com.google.api.gax.rpc.internal.EnvironmentProvider;
|
45 | 47 | import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest;
|
46 | 48 | import com.google.api.gax.rpc.mtls.MtlsProvider;
|
| 49 | +import com.google.auth.ApiKeyCredentials; |
47 | 50 | import com.google.auth.Credentials;
|
| 51 | +import com.google.auth.http.AuthHttpConstants; |
48 | 52 | import com.google.auth.oauth2.CloudShellCredentials;
|
49 | 53 | import com.google.auth.oauth2.ComputeEngineCredentials;
|
50 | 54 | import com.google.common.collect.ImmutableList;
|
|
79 | 83 |
|
80 | 84 | class InstantiatingGrpcChannelProviderTest extends AbstractMtlsTransportChannelTest {
|
81 | 85 | private static final String DEFAULT_ENDPOINT = "test.googleapis.com:443";
|
| 86 | + private static final String API_KEY_HEADER_VALUE = "fake_api_key_2"; |
| 87 | + private static final String API_KEY_AUTH_HEADER_KEY = "x-goog-api-key"; |
82 | 88 | private static String originalOSName;
|
83 | 89 | private ComputeEngineCredentials computeEngineCredentials;
|
84 | 90 |
|
@@ -877,6 +883,103 @@ public void canUseDirectPath_nonGDUUniverseDomain() {
|
877 | 883 | Truth.assertThat(provider.canUseDirectPath()).isFalse();
|
878 | 884 | }
|
879 | 885 |
|
| 886 | + @Test |
| 887 | + void providerInitializedWithNonConflictingHeaders_retainsHeaders() { |
| 888 | + InstantiatingGrpcChannelProvider.Builder builder = |
| 889 | + InstantiatingGrpcChannelProvider.newBuilder() |
| 890 | + .setHeaderProvider(getHeaderProviderWithApiKeyHeader()) |
| 891 | + .setEndpoint("test.random.com:443"); |
| 892 | + |
| 893 | + InstantiatingGrpcChannelProvider provider = builder.build(); |
| 894 | + |
| 895 | + assertEquals(1, provider.headersWithDuplicatesRemoved.size()); |
| 896 | + assertEquals( |
| 897 | + API_KEY_HEADER_VALUE, provider.headersWithDuplicatesRemoved.get(API_KEY_AUTH_HEADER_KEY)); |
| 898 | + } |
| 899 | + |
| 900 | + @Test |
| 901 | + void providersInitializedWithConflictingApiKeyCredentialHeaders_removesDuplicates() { |
| 902 | + String correctApiKey = "fake_api_key"; |
| 903 | + ApiKeyCredentials apiKeyCredentials = ApiKeyCredentials.create(correctApiKey); |
| 904 | + InstantiatingGrpcChannelProvider.Builder builder = |
| 905 | + InstantiatingGrpcChannelProvider.newBuilder() |
| 906 | + .setCredentials(apiKeyCredentials) |
| 907 | + .setHeaderProvider(getHeaderProviderWithApiKeyHeader()) |
| 908 | + .setEndpoint("test.random.com:443"); |
| 909 | + |
| 910 | + InstantiatingGrpcChannelProvider provider = builder.build(); |
| 911 | + |
| 912 | + assertEquals(0, provider.headersWithDuplicatesRemoved.size()); |
| 913 | + assertNull(provider.headersWithDuplicatesRemoved.get(API_KEY_AUTH_HEADER_KEY)); |
| 914 | + } |
| 915 | + |
| 916 | + @Test |
| 917 | + void providersInitializedWithConflictingNonApiKeyCredentialHeaders_doesNotRemoveDuplicates() { |
| 918 | + String authProvidedHeader = "Bearer token"; |
| 919 | + Map<String, String> header = new HashMap<>(); |
| 920 | + header.put(AuthHttpConstants.AUTHORIZATION, authProvidedHeader); |
| 921 | + InstantiatingGrpcChannelProvider.Builder builder = |
| 922 | + InstantiatingGrpcChannelProvider.newBuilder() |
| 923 | + .setCredentials(computeEngineCredentials) |
| 924 | + .setHeaderProvider(FixedHeaderProvider.create(header)) |
| 925 | + .setEndpoint("test.random.com:443"); |
| 926 | + |
| 927 | + InstantiatingGrpcChannelProvider provider = builder.build(); |
| 928 | + |
| 929 | + assertEquals(1, provider.headersWithDuplicatesRemoved.size()); |
| 930 | + assertEquals( |
| 931 | + authProvidedHeader, |
| 932 | + provider.headersWithDuplicatesRemoved.get(AuthHttpConstants.AUTHORIZATION)); |
| 933 | + } |
| 934 | + |
| 935 | + @Test |
| 936 | + void buildProvider_handlesNullHeaderProvider() { |
| 937 | + InstantiatingGrpcChannelProvider.Builder builder = |
| 938 | + InstantiatingGrpcChannelProvider.newBuilder().setEndpoint("test.random.com:443"); |
| 939 | + |
| 940 | + InstantiatingGrpcChannelProvider provider = builder.build(); |
| 941 | + |
| 942 | + assertEquals(0, provider.headersWithDuplicatesRemoved.size()); |
| 943 | + } |
| 944 | + |
| 945 | + @Test |
| 946 | + void buildProvider_handlesNullCredentialsMetadataRequest() throws IOException { |
| 947 | + Credentials credentials = Mockito.mock(Credentials.class); |
| 948 | + Mockito.when(credentials.getRequestMetadata()).thenReturn(null); |
| 949 | + InstantiatingGrpcChannelProvider.Builder builder = |
| 950 | + InstantiatingGrpcChannelProvider.newBuilder() |
| 951 | + .setHeaderProvider(getHeaderProviderWithApiKeyHeader()) |
| 952 | + .setEndpoint("test.random.com:443"); |
| 953 | + |
| 954 | + InstantiatingGrpcChannelProvider provider = builder.build(); |
| 955 | + |
| 956 | + assertEquals(1, provider.headersWithDuplicatesRemoved.size()); |
| 957 | + assertEquals( |
| 958 | + API_KEY_HEADER_VALUE, provider.headersWithDuplicatesRemoved.get(API_KEY_AUTH_HEADER_KEY)); |
| 959 | + } |
| 960 | + |
| 961 | + @Test |
| 962 | + void buildProvider_handlesErrorRetrievingCredentialsMetadataRequest() throws IOException { |
| 963 | + Credentials credentials = Mockito.mock(Credentials.class); |
| 964 | + Mockito.when(credentials.getRequestMetadata()) |
| 965 | + .thenThrow(new IOException("Error getting request metadata")); |
| 966 | + InstantiatingGrpcChannelProvider.Builder builder = |
| 967 | + InstantiatingGrpcChannelProvider.newBuilder() |
| 968 | + .setHeaderProvider(getHeaderProviderWithApiKeyHeader()) |
| 969 | + .setEndpoint("test.random.com:443"); |
| 970 | + InstantiatingGrpcChannelProvider provider = builder.build(); |
| 971 | + |
| 972 | + assertEquals(1, provider.headersWithDuplicatesRemoved.size()); |
| 973 | + assertEquals( |
| 974 | + API_KEY_HEADER_VALUE, provider.headersWithDuplicatesRemoved.get(API_KEY_AUTH_HEADER_KEY)); |
| 975 | + } |
| 976 | + |
| 977 | + private FixedHeaderProvider getHeaderProviderWithApiKeyHeader() { |
| 978 | + Map<String, String> header = new HashMap<>(); |
| 979 | + header.put(API_KEY_AUTH_HEADER_KEY, API_KEY_HEADER_VALUE); |
| 980 | + return FixedHeaderProvider.create(header); |
| 981 | + } |
| 982 | + |
880 | 983 | private static class FakeLogHandler extends Handler {
|
881 | 984 |
|
882 | 985 | List<LogRecord> records = new ArrayList<>();
|
|
0 commit comments