Skip to content

Commit b06f65a

Browse files
SNOW-2032706 Implement AWS SDK strategy for GCS
1 parent 9acf849 commit b06f65a

File tree

4 files changed

+381
-3
lines changed

4 files changed

+381
-3
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package net.snowflake.client.jdbc.cloud.storage;
2+
3+
import com.amazonaws.SignableRequest;
4+
import com.amazonaws.auth.AWS4Signer;
5+
import com.amazonaws.auth.AWSCredentials;
6+
import com.amazonaws.http.HttpMethodName;
7+
import java.util.HashMap;
8+
import java.util.Map;
9+
import java.util.stream.Collectors;
10+
11+
public class AwsSdkGCPSigner extends AWS4Signer {
12+
private static final Map<String, String> headerMap =
13+
new HashMap<String, String>() {
14+
{
15+
put("x-amz-storage-class", "x-goog-storage-class");
16+
put("x-amz-acl", "x-goog-acl");
17+
put("x-amz-date", "x-goog-date");
18+
put("x-amz-copy-source", "x-goog-copy-source");
19+
put("x-amz-metadata-directive", "x-goog-metadata-directive");
20+
put("x-amz-copy-source-if-match", "x-goog-copy-source-if-match");
21+
put("x-amz-copy-source-if-none-match", "x-goog-copy-source-if-none-match");
22+
put("x-amz-copy-source-if-unmodified-since", "x-goog-copy-source-if-unmodified-since");
23+
put("x-amz-copy-source-if-modified-since", "x-goog-copy-source-if-modified-since");
24+
}
25+
};
26+
27+
@Override
28+
public void sign(SignableRequest<?> request, AWSCredentials credentials) {
29+
if (credentials.getAWSAccessKeyId() != null && !"".equals(credentials.getAWSAccessKeyId())) {
30+
request.addHeader("Authorization", "Bearer " + credentials.getAWSAccessKeyId());
31+
}
32+
33+
if (request.getHttpMethod() == HttpMethodName.GET) {
34+
request.addHeader("Accept-Encoding", "gzip,deflate");
35+
}
36+
37+
Map<String, String> headerCopy =
38+
request.getHeaders().entrySet().stream()
39+
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
40+
41+
for (Map.Entry<String, String> entry : headerCopy.entrySet()) {
42+
String entryKey = entry.getKey().toLowerCase();
43+
if (headerMap.containsKey(entryKey)) {
44+
request.addHeader(headerMap.get(entryKey), entry.getValue());
45+
} else if (entryKey.startsWith("x-amz-meta-")) {
46+
request.addHeader(entryKey.replace("x-amz-meta-", "x-goog-meta-"), entry.getValue());
47+
}
48+
}
49+
}
50+
}
Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
package net.snowflake.client.jdbc.cloud.storage;
2+
3+
import static net.snowflake.client.core.Constants.CLOUD_STORAGE_CREDENTIALS_EXPIRED;
4+
import static net.snowflake.client.jdbc.SnowflakeUtil.createDefaultExecutorService;
5+
import static net.snowflake.client.jdbc.cloud.storage.SnowflakeS3Client.EXPIRED_AWS_TOKEN_ERROR_CODE;
6+
7+
import com.amazonaws.AmazonClientException;
8+
import com.amazonaws.AmazonServiceException;
9+
import com.amazonaws.ClientConfiguration;
10+
import com.amazonaws.auth.AWSStaticCredentialsProvider;
11+
import com.amazonaws.auth.BasicAWSCredentials;
12+
import com.amazonaws.auth.SignerFactory;
13+
import com.amazonaws.client.builder.AwsClientBuilder;
14+
import com.amazonaws.services.s3.AmazonS3;
15+
import com.amazonaws.services.s3.AmazonS3Client;
16+
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
17+
import com.amazonaws.services.s3.model.AmazonS3Exception;
18+
import com.amazonaws.services.s3.model.ObjectListing;
19+
import com.amazonaws.services.s3.model.ObjectMetadata;
20+
import com.amazonaws.services.s3.model.S3Object;
21+
import com.amazonaws.services.s3.transfer.Download;
22+
import com.amazonaws.services.s3.transfer.TransferManager;
23+
import com.amazonaws.services.s3.transfer.TransferManagerBuilder;
24+
import com.amazonaws.services.s3.transfer.Upload;
25+
import java.io.File;
26+
import java.io.InputStream;
27+
import java.util.List;
28+
import java.util.Map;
29+
import java.util.Optional;
30+
import java.util.stream.Collectors;
31+
import net.snowflake.client.core.HeaderCustomizerHttpRequestInterceptor;
32+
import net.snowflake.client.core.SFBaseSession;
33+
import net.snowflake.client.core.SFSession;
34+
import net.snowflake.client.jdbc.ErrorCode;
35+
import net.snowflake.client.jdbc.HttpHeadersCustomizer;
36+
import net.snowflake.client.jdbc.SnowflakeFileTransferAgent;
37+
import net.snowflake.client.jdbc.SnowflakeSQLException;
38+
import net.snowflake.client.jdbc.SnowflakeSQLLoggedException;
39+
import net.snowflake.client.jdbc.SnowflakeUtil;
40+
import net.snowflake.client.log.SFLogger;
41+
import net.snowflake.client.log.SFLoggerFactory;
42+
import net.snowflake.client.util.SFPair;
43+
import net.snowflake.common.core.SqlState;
44+
import org.apache.http.HttpStatus;
45+
46+
public class GCSAccessStrategyAwsSdk implements GCSAccessStrategy {
47+
private static final SFLogger logger = SFLoggerFactory.getLogger(GCSAccessStrategyAwsSdk.class);
48+
private final AmazonS3 amazonClient;
49+
50+
GCSAccessStrategyAwsSdk(StageInfo stage, SFBaseSession session) throws SnowflakeSQLException {
51+
String accessToken = (String) stage.getCredentials().get("GCS_ACCESS_TOKEN");
52+
53+
Optional<String> oEndpoint = stage.gcsCustomEndpoint();
54+
String endpoint = "storage.googleapis.com";
55+
if (oEndpoint.isPresent()) {
56+
endpoint = oEndpoint.get();
57+
}
58+
if (endpoint.startsWith("https://")) {
59+
endpoint = endpoint.replaceFirst("https://", "");
60+
}
61+
if (endpoint.startsWith(stage.getStorageAccount())) {
62+
endpoint = endpoint.replaceFirst(stage.getStorageAccount() + ".", "");
63+
}
64+
65+
AmazonS3ClientBuilder amazonS3Builder =
66+
AmazonS3Client.builder()
67+
.withPathStyleAccessEnabled(false)
68+
.withEndpointConfiguration(
69+
new AwsClientBuilder.EndpointConfiguration(endpoint, "auto"));
70+
71+
ClientConfiguration clientConfig = new ClientConfiguration();
72+
73+
SignerFactory.registerSigner(
74+
"net.snowflake.client.jdbc.cloud.storage.AwsSdkGCPSigner",
75+
net.snowflake.client.jdbc.cloud.storage.AwsSdkGCPSigner.class);
76+
clientConfig.setSignerOverride("net.snowflake.client.jdbc.cloud.storage.AwsSdkGCPSigner");
77+
78+
clientConfig
79+
.getApacheHttpClientConfig()
80+
.setSslSocketFactory(SnowflakeS3Client.getSSLConnectionSocketFactory());
81+
if (session != null) {
82+
S3HttpUtil.setProxyForS3(session.getHttpClientKey(), clientConfig);
83+
} else {
84+
S3HttpUtil.setSessionlessProxyForS3(stage.getProxyProperties(), clientConfig);
85+
}
86+
87+
if (session instanceof SFSession) {
88+
List<HttpHeadersCustomizer> headersCustomizers =
89+
((SFSession) session).getHttpHeadersCustomizers();
90+
if (headersCustomizers != null && !headersCustomizers.isEmpty()) {
91+
amazonS3Builder.withRequestHandlers(
92+
new HeaderCustomizerHttpRequestInterceptor(headersCustomizers));
93+
}
94+
}
95+
96+
if (accessToken != null) {
97+
amazonS3Builder.withCredentials(
98+
new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessToken, "")));
99+
} else {
100+
logger.debug("no credentials provided, configuring bucket client without credentials");
101+
amazonS3Builder.withCredentials(
102+
new AWSStaticCredentialsProvider(new BasicAWSCredentials("", "")));
103+
}
104+
105+
amazonClient = amazonS3Builder.withClientConfiguration(clientConfig).build();
106+
}
107+
108+
@Override
109+
public StorageObjectSummaryCollection listObjects(String remoteStorageLocation, String prefix) {
110+
ObjectListing objListing = amazonClient.listObjects(remoteStorageLocation, prefix);
111+
112+
return new StorageObjectSummaryCollection(objListing.getObjectSummaries());
113+
}
114+
115+
@Override
116+
public StorageObjectMetadata getObjectMetadata(String remoteStorageLocation, String prefix) {
117+
ObjectMetadata meta = amazonClient.getObjectMetadata(remoteStorageLocation, prefix);
118+
119+
Map<String, String> userMetadata =
120+
meta.getRawMetadata().entrySet().stream()
121+
.filter(entry -> entry.getKey().startsWith("x-goog-meta-"))
122+
.collect(
123+
Collectors.toMap(
124+
e -> e.getKey().replaceFirst("x-goog-meta-", ""),
125+
e -> e.getValue().toString()));
126+
127+
meta.setUserMetadata(userMetadata);
128+
return new S3ObjectMetadata(meta);
129+
}
130+
131+
@Override
132+
public Map<String, String> download(
133+
int parallelism, String remoteStorageLocation, String stageFilePath, File localFile)
134+
throws InterruptedException {
135+
136+
logger.debug(
137+
"Staring download of file from S3 stage path: {} to {}",
138+
stageFilePath,
139+
localFile.getAbsolutePath());
140+
TransferManager tx;
141+
142+
logger.debug("Creating executor service for transfer manager with {} threads", parallelism);
143+
144+
// download files from s3
145+
tx =
146+
TransferManagerBuilder.standard()
147+
.withS3Client(amazonClient)
148+
.withDisableParallelDownloads(true)
149+
.withExecutorFactory(
150+
() -> createDefaultExecutorService("s3-transfer-manager-downloader-", parallelism))
151+
.build();
152+
153+
Download myDownload = tx.download(remoteStorageLocation, stageFilePath, localFile);
154+
155+
// Pull object metadata from S3
156+
StorageObjectMetadata meta = this.getObjectMetadata(remoteStorageLocation, stageFilePath);
157+
158+
Map<String, String> metaMap = SnowflakeUtil.createCaseInsensitiveMap(meta.getUserMetadata());
159+
myDownload.waitForCompletion();
160+
161+
return metaMap;
162+
}
163+
164+
@Override
165+
public SFPair<InputStream, Map<String, String>> downloadToStream(
166+
String remoteStorageLocation, String stageFilePath, boolean isEncrypting) {
167+
S3Object file = amazonClient.getObject(remoteStorageLocation, stageFilePath);
168+
ObjectMetadata meta = amazonClient.getObjectMetadata(remoteStorageLocation, stageFilePath);
169+
InputStream stream = file.getObjectContent();
170+
171+
Map<String, String> metaMap = SnowflakeUtil.createCaseInsensitiveMap(meta.getUserMetadata());
172+
173+
return SFPair.of(stream, metaMap);
174+
}
175+
176+
@Override
177+
public void uploadWithDownScopedToken(
178+
int parallelism,
179+
String remoteStorageLocation,
180+
String destFileName,
181+
String contentEncoding,
182+
Map<String, String> metadata,
183+
long contentLength,
184+
InputStream content,
185+
String queryId)
186+
throws InterruptedException {
187+
// we need to assemble an ObjectMetadata object here, as we are not using S3ObjectMatadata for
188+
// GCS
189+
ObjectMetadata s3Meta = new ObjectMetadata();
190+
if (contentEncoding != null) {
191+
s3Meta.setContentEncoding(contentEncoding);
192+
}
193+
s3Meta.setContentLength(contentLength);
194+
s3Meta.setUserMetadata(metadata);
195+
196+
TransferManager tx;
197+
logger.debug("Creating executor service for transfer" + "manager with {} threads", parallelism);
198+
199+
// upload files to s3
200+
tx =
201+
TransferManagerBuilder.standard()
202+
.withS3Client(amazonClient)
203+
.withExecutorFactory(
204+
() -> createDefaultExecutorService("s3-transfer-manager-uploader-", parallelism))
205+
.build();
206+
207+
final Upload myUpload;
208+
209+
myUpload = tx.upload(remoteStorageLocation, destFileName, content, s3Meta);
210+
myUpload.waitForCompletion();
211+
212+
logger.info("Uploaded data from input stream to S3 location: {}.", destFileName);
213+
}
214+
215+
private static boolean isClientException400Or404(Exception ex) {
216+
if (ex instanceof AmazonServiceException) {
217+
AmazonServiceException asEx = (AmazonServiceException) (ex);
218+
return asEx.getStatusCode() == HttpStatus.SC_NOT_FOUND
219+
|| asEx.getStatusCode() == HttpStatus.SC_BAD_REQUEST;
220+
}
221+
return false;
222+
}
223+
224+
@Override
225+
public boolean handleStorageException(
226+
Exception ex,
227+
int retryCount,
228+
String operation,
229+
SFSession session,
230+
String command,
231+
String queryId,
232+
SnowflakeGCSClient gcsClient)
233+
throws SnowflakeSQLException {
234+
if (ex instanceof AmazonClientException) {
235+
logger.debug("GCSAccessStrategyAwsSdk: " + ex.getMessage());
236+
237+
if (retryCount > gcsClient.getMaxRetries() || isClientException400Or404(ex)) {
238+
String extendedRequestId = "none";
239+
240+
if (ex instanceof AmazonS3Exception) {
241+
AmazonS3Exception ex1 = (AmazonS3Exception) ex;
242+
extendedRequestId = ex1.getExtendedRequestId();
243+
}
244+
245+
if (ex instanceof AmazonServiceException) {
246+
AmazonServiceException ex1 = (AmazonServiceException) ex;
247+
248+
// The AWS credentials might have expired when server returns error 400 and
249+
// does not return the ExpiredToken error code.
250+
// If session is null we cannot renew the token so throw the exception
251+
if (ex1.getStatusCode() == HttpStatus.SC_BAD_REQUEST && session != null) {
252+
SnowflakeFileTransferAgent.renewExpiredToken(session, command, gcsClient);
253+
} else {
254+
throw new SnowflakeSQLLoggedException(
255+
queryId,
256+
session,
257+
SqlState.SYSTEM_ERROR,
258+
ErrorCode.S3_OPERATION_ERROR.getMessageCode(),
259+
ex1,
260+
operation,
261+
ex1.getErrorType().toString(),
262+
ex1.getErrorCode(),
263+
ex1.getMessage(),
264+
ex1.getRequestId(),
265+
extendedRequestId);
266+
}
267+
268+
} else {
269+
throw new SnowflakeSQLLoggedException(
270+
queryId,
271+
session,
272+
SqlState.SYSTEM_ERROR,
273+
ErrorCode.AWS_CLIENT_ERROR.getMessageCode(),
274+
ex,
275+
operation,
276+
ex.getMessage());
277+
}
278+
} else {
279+
logger.debug(
280+
"Encountered exception ({}) during {}, retry count: {}",
281+
ex.getMessage(),
282+
operation,
283+
retryCount);
284+
logger.debug("Stack trace: ", ex);
285+
286+
// exponential backoff up to a limit
287+
int backoffInMillis = gcsClient.getRetryBackoffMin();
288+
289+
if (retryCount > 1) {
290+
backoffInMillis <<= (Math.min(retryCount - 1, gcsClient.getRetryBackoffMaxExponent()));
291+
}
292+
293+
try {
294+
logger.debug("Sleep for {} milliseconds before retry", backoffInMillis);
295+
296+
Thread.sleep(backoffInMillis);
297+
} catch (InterruptedException ex1) {
298+
// ignore
299+
}
300+
301+
// If the exception indicates that the AWS token has expired,
302+
// we need to refresh our S3 client with the new token
303+
if (ex instanceof AmazonS3Exception) {
304+
AmazonS3Exception s3ex = (AmazonS3Exception) ex;
305+
if (s3ex.getErrorCode().equalsIgnoreCase(EXPIRED_AWS_TOKEN_ERROR_CODE)) {
306+
// If session is null we cannot renew the token so throw the ExpiredToken exception
307+
if (session != null) {
308+
SnowflakeFileTransferAgent.renewExpiredToken(session, command, gcsClient);
309+
} else {
310+
throw new SnowflakeSQLException(
311+
queryId,
312+
s3ex.getErrorCode(),
313+
CLOUD_STORAGE_CREDENTIALS_EXPIRED,
314+
"S3 credentials have expired");
315+
}
316+
}
317+
}
318+
}
319+
return true;
320+
} else {
321+
return false;
322+
}
323+
}
324+
}

src/main/java/net/snowflake/client/jdbc/cloud/storage/SnowflakeGCSClient.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1218,7 +1218,11 @@ private void setupGCSClient(
12181218
logger.debug("Setting up the GCS client ", false);
12191219

12201220
try {
1221-
this.gcsAccessStrategy = new GCSDefaultAccessStrategy(stage, session);
1221+
if (stage.getUseVirtualUrl()) {
1222+
this.gcsAccessStrategy = new GCSAccessStrategyAwsSdk(stage, session);
1223+
} else {
1224+
this.gcsAccessStrategy = new GCSDefaultAccessStrategy(stage, session);
1225+
}
12221226

12231227
if (encMat != null) {
12241228
byte[] decodedKey = Base64.getDecoder().decode(encMat.getQueryStageMasterKey());

0 commit comments

Comments
 (0)