Skip to content

Commit ca55888

Browse files
SNOW-2032706 Implement AWS SDK strategy for GCS
1 parent 4086086 commit ca55888

File tree

4 files changed

+359
-3
lines changed

4 files changed

+359
-3
lines changed
src/main/java/net/snowflake/client/jdbc/cloud/storage/AwsSdkGCPSigner.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package net.snowflake.client.jdbc.cloud.storage;
2+
3+
import com.amazonaws.SignableRequest;
4+
import com.amazonaws.auth.AWS4Signer;
5+
import com.amazonaws.auth.AWSCredentials;
6+
import com.amazonaws.http.HttpMethodName;
7+
import java.util.HashMap;
8+
import java.util.Map;
9+
import java.util.stream.Collectors;
10+
11+
public class AwsSdkGCPSigner extends AWS4Signer {
12+
private static final Map<String, String> headerMap =
13+
new HashMap<String, String>() {
14+
{
15+
put("x-amz-storage-class", "x-goog-storage-class");
16+
put("x-amz-acl", "x-goog-acl");
17+
put("x-amz-date", "x-goog-date");
18+
put("x-amz-copy-source", "x-goog-copy-source");
19+
put("x-amz-metadata-directive", "x-goog-metadata-directive");
20+
put("x-amz-copy-source-if-match", "x-goog-copy-source-if-match");
21+
put("x-amz-copy-source-if-none-match", "x-goog-copy-source-if-none-match");
22+
put("x-amz-copy-source-if-unmodified-since", "x-goog-copy-source-if-unmodified-since");
23+
put("x-amz-copy-source-if-modified-since", "x-goog-copy-source-if-modified-since");
24+
}
25+
};
26+
27+
@Override
28+
public void sign(SignableRequest<?> request, AWSCredentials credentials) {
29+
if (credentials.getAWSAccessKeyId() != null && !"".equals(credentials.getAWSAccessKeyId())) {
30+
request.addHeader("Authorization", "Bearer " + credentials.getAWSAccessKeyId());
31+
}
32+
33+
if (request.getHttpMethod() == HttpMethodName.GET) {
34+
request.addHeader("Accept-Encoding", "gzip,deflate");
35+
}
36+
37+
Map<String, String> headerCopy =
38+
request.getHeaders().entrySet().stream()
39+
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
40+
41+
for (Map.Entry<String, String> entry : headerCopy.entrySet()) {
42+
String entryKey = entry.getKey().toLowerCase();
43+
if (headerMap.containsKey(entryKey)) {
44+
request.addHeader(headerMap.get(entryKey), entry.getValue());
45+
} else if (entryKey.startsWith("x-amz-meta-")) {
46+
request.addHeader(entryKey.replace("x-amz-meta-", "x-goog-meta-"), entry.getValue());
47+
}
48+
}
49+
}
50+
}
src/main/java/net/snowflake/client/jdbc/cloud/storage/GCSAccessStrategyAwsSdk.java

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
package net.snowflake.client.jdbc.cloud.storage;
2+
3+
import static net.snowflake.client.core.Constants.CLOUD_STORAGE_CREDENTIALS_EXPIRED;
4+
import static net.snowflake.client.jdbc.SnowflakeUtil.createDefaultExecutorService;
5+
import static net.snowflake.client.jdbc.cloud.storage.SnowflakeS3Client.EXPIRED_AWS_TOKEN_ERROR_CODE;
6+
7+
import com.amazonaws.AmazonClientException;
8+
import com.amazonaws.AmazonServiceException;
9+
import com.amazonaws.ClientConfiguration;
10+
import com.amazonaws.auth.AWSStaticCredentialsProvider;
11+
import com.amazonaws.auth.BasicAWSCredentials;
12+
import com.amazonaws.auth.SignerFactory;
13+
import com.amazonaws.client.builder.AwsClientBuilder;
14+
import com.amazonaws.services.s3.AmazonS3;
15+
import com.amazonaws.services.s3.AmazonS3Client;
16+
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
17+
import com.amazonaws.services.s3.model.AmazonS3Exception;
18+
import com.amazonaws.services.s3.model.ObjectListing;
19+
import com.amazonaws.services.s3.model.ObjectMetadata;
20+
import com.amazonaws.services.s3.model.S3Object;
21+
import com.amazonaws.services.s3.transfer.Download;
22+
import com.amazonaws.services.s3.transfer.TransferManager;
23+
import com.amazonaws.services.s3.transfer.TransferManagerBuilder;
24+
import com.amazonaws.services.s3.transfer.Upload;
25+
import java.io.File;
26+
import java.io.InputStream;
27+
import java.util.Map;
28+
import java.util.Optional;
29+
import java.util.stream.Collectors;
30+
import net.snowflake.client.core.SFSession;
31+
import net.snowflake.client.jdbc.ErrorCode;
32+
import net.snowflake.client.jdbc.SnowflakeFileTransferAgent;
33+
import net.snowflake.client.jdbc.SnowflakeSQLException;
34+
import net.snowflake.client.jdbc.SnowflakeSQLLoggedException;
35+
import net.snowflake.client.jdbc.SnowflakeUtil;
36+
import net.snowflake.client.log.SFLogger;
37+
import net.snowflake.client.log.SFLoggerFactory;
38+
import net.snowflake.client.util.SFPair;
39+
import net.snowflake.common.core.SqlState;
40+
import org.apache.http.HttpStatus;
41+
42+
/**
 * {@link GCSAccessStrategy} implementation that performs its storage operations through an AWS
 * SDK {@code AmazonS3} client. The client is built against a GCS endpoint (default {@code
 * storage.googleapis.com}) and configured with the {@code AwsSdkGCPSigner} signer override.
 */
public class GCSAccessStrategyAwsSdk implements GCSAccessStrategy {
43+
private static final SFLogger logger = SFLoggerFactory.getLogger(GCSAccessStrategyAwsSdk.class);
44+
private final AmazonS3 amazonClient;
45+
46+
GCSAccessStrategyAwsSdk(StageInfo stage) {
47+
String accessToken = (String) stage.getCredentials().get("GCS_ACCESS_TOKEN");
48+
49+
// todo: fix client configuration needed
50+
/*
51+
clientConfig.getApacheHttpClientConfig().setSslSocketFactory(getSSLConnectionSocketFactory());
52+
if (session != null) {
53+
S3HttpUtil.setProxyForS3(session.getHttpClientKey(), clientConfig);
54+
} else {
55+
S3HttpUtil.setSessionlessProxyForS3(proxyProperties, clientConfig);
56+
}*/
57+
58+
Optional<String> oEndpoint = stage.gcsCustomEndpoint();
59+
String endpoint = "storage.googleapis.com";
60+
if (oEndpoint.isPresent()) {
61+
endpoint = oEndpoint.get();
62+
}
63+
if (endpoint.startsWith("https://")) {
64+
endpoint = endpoint.replaceFirst("https://", "");
65+
}
66+
if (endpoint.startsWith(stage.getStorageAccount())) {
67+
endpoint = endpoint.replaceFirst(stage.getStorageAccount() + ".", "");
68+
}
69+
70+
AmazonS3ClientBuilder amazonS3Builder =
71+
AmazonS3Client.builder()
72+
.withPathStyleAccessEnabled(false)
73+
.withEndpointConfiguration(
74+
new AwsClientBuilder.EndpointConfiguration(endpoint, "auto"));
75+
76+
ClientConfiguration clientConfig = new ClientConfiguration();
77+
78+
SignerFactory.registerSigner(
79+
"net.snowflake.client.jdbc.cloud.storage.AwsSdkGCPSigner",
80+
net.snowflake.client.jdbc.cloud.storage.AwsSdkGCPSigner.class);
81+
clientConfig.setSignerOverride("net.snowflake.client.jdbc.cloud.storage.AwsSdkGCPSigner");
82+
83+
if (accessToken != null) {
84+
amazonS3Builder.withCredentials(
85+
new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessToken, "")));
86+
} else {
87+
logger.debug("no credentials provided, configuring bucket client without credentials");
88+
amazonS3Builder.withCredentials(
89+
new AWSStaticCredentialsProvider(new BasicAWSCredentials("", "")));
90+
}
91+
92+
amazonClient = amazonS3Builder.withClientConfiguration(clientConfig).build();
93+
}
94+
95+
@Override
96+
public StorageObjectSummaryCollection listObjects(String remoteStorageLocation, String prefix) {
97+
ObjectListing objListing = amazonClient.listObjects(remoteStorageLocation, prefix);
98+
99+
return new StorageObjectSummaryCollection(objListing.getObjectSummaries());
100+
}
101+
102+
@Override
103+
public StorageObjectMetadata getObjectMetadata(String remoteStorageLocation, String prefix) {
104+
ObjectMetadata meta = amazonClient.getObjectMetadata(remoteStorageLocation, prefix);
105+
106+
Map<String, String> userMetadata =
107+
meta.getRawMetadata().entrySet().stream()
108+
.filter(entry -> entry.getKey().startsWith("x-goog-meta-"))
109+
.collect(
110+
Collectors.toMap(
111+
e -> e.getKey().replaceFirst("x-goog-meta-", ""),
112+
e -> e.getValue().toString()));
113+
114+
meta.setUserMetadata(userMetadata);
115+
return new S3ObjectMetadata(meta);
116+
}
117+
118+
@Override
119+
public Map<String, String> download(
120+
int parallelism, String remoteStorageLocation, String stageFilePath, File localFile)
121+
throws InterruptedException {
122+
123+
logger.debug(
124+
"Staring download of file from S3 stage path: {} to {}",
125+
stageFilePath,
126+
localFile.getAbsolutePath());
127+
TransferManager tx;
128+
129+
logger.debug("Creating executor service for transfer manager with {} threads", parallelism);
130+
131+
// download files from s3
132+
tx =
133+
TransferManagerBuilder.standard()
134+
.withS3Client(amazonClient)
135+
.withDisableParallelDownloads(true)
136+
.withExecutorFactory(
137+
() -> createDefaultExecutorService("s3-transfer-manager-downloader-", parallelism))
138+
.build();
139+
140+
Download myDownload = tx.download(remoteStorageLocation, stageFilePath, localFile);
141+
142+
// Pull object metadata from S3
143+
StorageObjectMetadata meta = this.getObjectMetadata(remoteStorageLocation, stageFilePath);
144+
145+
Map<String, String> metaMap = SnowflakeUtil.createCaseInsensitiveMap(meta.getUserMetadata());
146+
myDownload.waitForCompletion();
147+
148+
return metaMap;
149+
}
150+
151+
@Override
152+
public SFPair<InputStream, Map<String, String>> downloadToStream(
153+
String remoteStorageLocation, String stageFilePath, boolean isEncrypting) {
154+
S3Object file = amazonClient.getObject(remoteStorageLocation, stageFilePath);
155+
ObjectMetadata meta = amazonClient.getObjectMetadata(remoteStorageLocation, stageFilePath);
156+
InputStream stream = file.getObjectContent();
157+
158+
Map<String, String> metaMap = SnowflakeUtil.createCaseInsensitiveMap(meta.getUserMetadata());
159+
160+
return SFPair.of(stream, metaMap);
161+
}
162+
163+
@Override
164+
public void uploadWithDownScopedToken(
165+
int parallelism,
166+
String remoteStorageLocation,
167+
String destFileName,
168+
String contentEncoding,
169+
Map<String, String> metadata,
170+
long contentLength,
171+
InputStream content,
172+
String queryId)
173+
throws InterruptedException {
174+
// we need to assemble an ObjectMetadata object here, as we are not using S3ObjectMatadata for
175+
// GCS
176+
ObjectMetadata s3Meta = new ObjectMetadata();
177+
if (contentEncoding != null) {
178+
s3Meta.setContentEncoding(contentEncoding);
179+
}
180+
s3Meta.setContentLength(contentLength);
181+
s3Meta.setUserMetadata(metadata);
182+
183+
TransferManager tx;
184+
logger.debug("Creating executor service for transfer" + "manager with {} threads", parallelism);
185+
186+
// upload files to s3
187+
tx =
188+
TransferManagerBuilder.standard()
189+
.withS3Client(amazonClient)
190+
.withExecutorFactory(
191+
() -> createDefaultExecutorService("s3-transfer-manager-uploader-", parallelism))
192+
.build();
193+
194+
final Upload myUpload;
195+
196+
myUpload = tx.upload(remoteStorageLocation, destFileName, content, s3Meta);
197+
myUpload.waitForCompletion();
198+
199+
logger.info("Uploaded data from input stream to S3 location: {}.", destFileName);
200+
}
201+
202+
@Override
203+
public boolean handleStorageException(
204+
Exception ex,
205+
int retryCount,
206+
String operation,
207+
SFSession session,
208+
String command,
209+
String queryId,
210+
SnowflakeGCSClient gcsClient)
211+
throws SnowflakeSQLException {
212+
if (ex instanceof AmazonClientException) {
213+
logger.debug("GCSAccessStrategyAwsSdk: " + ex.getMessage());
214+
if (retryCount > gcsClient.getMaxRetries()
215+
|| SnowflakeS3Client.isClientException400Or404(ex)) {
216+
String extendedRequestId = "none";
217+
218+
if (ex instanceof AmazonS3Exception) {
219+
AmazonS3Exception ex1 = (AmazonS3Exception) ex;
220+
extendedRequestId = ex1.getExtendedRequestId();
221+
}
222+
223+
if (ex instanceof AmazonServiceException) {
224+
AmazonServiceException ex1 = (AmazonServiceException) ex;
225+
226+
// The AWS credentials might have expired when server returns error 400 and
227+
// does not return the ExpiredToken error code.
228+
// If session is null we cannot renew the token so throw the exception
229+
if (ex1.getStatusCode() == HttpStatus.SC_BAD_REQUEST && session != null) {
230+
SnowflakeFileTransferAgent.renewExpiredToken(session, command, gcsClient);
231+
} else {
232+
throw new SnowflakeSQLLoggedException(
233+
queryId,
234+
session,
235+
SqlState.SYSTEM_ERROR,
236+
ErrorCode.S3_OPERATION_ERROR.getMessageCode(),
237+
ex1,
238+
operation,
239+
ex1.getErrorType().toString(),
240+
ex1.getErrorCode(),
241+
ex1.getMessage(),
242+
ex1.getRequestId(),
243+
extendedRequestId);
244+
}
245+
246+
} else {
247+
throw new SnowflakeSQLLoggedException(
248+
queryId,
249+
session,
250+
SqlState.SYSTEM_ERROR,
251+
ErrorCode.AWS_CLIENT_ERROR.getMessageCode(),
252+
ex,
253+
operation,
254+
ex.getMessage());
255+
}
256+
} else {
257+
logger.debug(
258+
"Encountered exception ({}) during {}, retry count: {}",
259+
ex.getMessage(),
260+
operation,
261+
retryCount);
262+
logger.debug("Stack trace: ", ex);
263+
264+
// exponential backoff up to a limit
265+
int backoffInMillis = gcsClient.getRetryBackoffMin();
266+
267+
if (retryCount > 1) {
268+
backoffInMillis <<= (Math.min(retryCount - 1, gcsClient.getRetryBackoffMaxExponent()));
269+
}
270+
271+
try {
272+
logger.debug("Sleep for {} milliseconds before retry", backoffInMillis);
273+
274+
Thread.sleep(backoffInMillis);
275+
} catch (InterruptedException ex1) {
276+
// ignore
277+
}
278+
279+
// If the exception indicates that the AWS token has expired,
280+
// we need to refresh our S3 client with the new token
281+
if (ex instanceof AmazonS3Exception) {
282+
AmazonS3Exception s3ex = (AmazonS3Exception) ex;
283+
if (s3ex.getErrorCode().equalsIgnoreCase(EXPIRED_AWS_TOKEN_ERROR_CODE)) {
284+
// If session is null we cannot renew the token so throw the ExpiredToken exception
285+
if (session != null) {
286+
SnowflakeFileTransferAgent.renewExpiredToken(session, command, gcsClient);
287+
} else {
288+
throw new SnowflakeSQLException(
289+
queryId,
290+
s3ex.getErrorCode(),
291+
CLOUD_STORAGE_CREDENTIALS_EXPIRED,
292+
"S3 credentials have expired");
293+
}
294+
}
295+
}
296+
}
297+
return true;
298+
} else {
299+
return false;
300+
}
301+
}
302+
}

src/main/java/net/snowflake/client/jdbc/cloud/storage/SnowflakeGCSClient.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,11 @@ private void setupGCSClient(
12211221
logger.debug("Setting up the GCS client ", false);
12221222

12231223
try {
1224-
this.gcsAccessStrategy = new GCSDefaultAccessStrategy(stage, session);
1224+
if (stage.getUseVirtualUrl()) {
1225+
this.gcsAccessStrategy = new GCSAccessStrategyAwsSdk(stage);
1226+
} else {
1227+
this.gcsAccessStrategy = new GCSDefaultAccessStrategy(stage, session);
1228+
}
12251229

12261230
if (encMat != null) {
12271231
byte[] decodedKey = Base64.getDecoder().decode(encMat.getQueryStageMasterKey());

src/main/java/net/snowflake/client/jdbc/cloud/storage/SnowflakeS3Client.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public class SnowflakeS3Client implements SnowflakeStorageClient {
8686
private static final String S3_STREAMING_INGEST_CLIENT_KEY = "ingestclientkey";
8787

8888
// expired AWS token error code
89-
private static final String EXPIRED_AWS_TOKEN_ERROR_CODE = "ExpiredToken";
89+
protected static final String EXPIRED_AWS_TOKEN_ERROR_CODE = "ExpiredToken";
9090

9191
private int encryptionKeySize = 0; // used for PUTs
9292
private AmazonS3 amazonClient = null;
@@ -953,7 +953,7 @@ private static void handleS3Exception(
953953
* @param ex exception
954954
* @return true if it's a 400 or 404 status code
955955
*/
956-
public boolean isClientException400Or404(Exception ex) {
956+
public static boolean isClientException400Or404(Exception ex) {
957957
if (ex instanceof AmazonServiceException) {
958958
AmazonServiceException asEx = (AmazonServiceException) (ex);
959959
return asEx.getStatusCode() == HttpStatus.SC_NOT_FOUND

0 commit comments

Comments
 (0)