Skip to content

Commit c06b030

Browse files
added tests
1 parent a64035d commit c06b030

File tree

4 files changed

+102
-15
lines changed

4 files changed

+102
-15
lines changed

src/main/java/net/snowflake/client/jdbc/cloud/storage/GCSAccessStrategyAwsSdk.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class GCSAccessStrategyAwsSdk implements GCSAccessStrategy {
5858
if (endpoint.startsWith("https://")) {
5959
endpoint = endpoint.replaceFirst("https://", "");
6060
}
61-
if (endpoint.startsWith(stage.getStorageAccount())) {
61+
if (stage.getStorageAccount() != null && endpoint.startsWith(stage.getStorageAccount())) {
6262
endpoint = endpoint.replaceFirst(stage.getStorageAccount() + ".", "");
6363
}
6464

src/main/java/net/snowflake/client/jdbc/cloud/storage/SnowflakeGCSClient.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ public void renew(Map<?, ?> stageCredentials) throws SnowflakeSQLException {
158158

159159
@Override
160160
public void shutdown() {
161-
if(this.gcsAccessStrategy != null) {
161+
if (this.gcsAccessStrategy != null) {
162162
this.gcsAccessStrategy.shutdown();
163163
}
164164
}
@@ -1220,7 +1220,9 @@ private void setupGCSClient(
12201220
logger.debug("Setting up the GCS client ", false);
12211221

12221222
try {
1223-
if (stage.getUseVirtualUrl()) {
1223+
boolean overrideAwsAccessStrategy =
1224+
Boolean.valueOf(System.getenv("SNOWFLAKE_GCS_FORCE_VIRTUAL_STYLE_DOMAINS"));
1225+
if (stage.getUseVirtualUrl() || overrideAwsAccessStrategy) {
12241226
this.gcsAccessStrategy = new GCSAccessStrategyAwsSdk(stage, session);
12251227
} else {
12261228
this.gcsAccessStrategy = new GCSDefaultAccessStrategy(stage, session);

src/test/java/net/snowflake/client/jdbc/SnowflakeDriverIT.java

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@
5858
import org.junit.jupiter.api.Tag;
5959
import org.junit.jupiter.api.Test;
6060
import org.junit.jupiter.api.io.TempDir;
61+
import org.junit.jupiter.params.ParameterizedTest;
62+
import org.junit.jupiter.params.provider.ValueSource;
6163

6264
/** General integration tests */
6365
@Tag(TestTags.OTHERS)
@@ -730,9 +732,14 @@ public void testDBMetadata() throws Throwable {
730732
}
731733
}
732734

733-
@Test
735+
@ParameterizedTest
736+
@ValueSource(booleans = {true, false})
734737
@DontRunOnGithubActions
735-
public void testPutWithWildcardGCP() throws Throwable {
738+
public void testPutWithWildcardGCP(boolean useAwsSDKStrategy) throws Throwable {
739+
if (useAwsSDKStrategy) {
740+
SnowflakeUtil.systemSetEnv("SNOWFLAKE_GCS_FORCE_VIRTUAL_STYLE_DOMAINS", "true");
741+
}
742+
736743
Properties _connectionProperties = new Properties();
737744
_connectionProperties.put("inject_wait_in_put", 5);
738745
_connectionProperties.put("ssl", "off");
@@ -784,6 +791,8 @@ public void testPutWithWildcardGCP() throws Throwable {
784791
} finally {
785792
statement.execute("DROP STAGE IF EXISTS wildcard_stage");
786793
}
794+
} finally {
795+
SnowflakeUtil.systemSetEnv("SNOWFLAKE_GCS_FORCE_VIRTUAL_STYLE_DOMAINS", "false");
787796
}
788797
}
789798

@@ -805,9 +814,13 @@ private void copyContentFrom(File file1, File file2) throws Exception {
805814
}
806815
}
807816

808-
@Test
817+
@ParameterizedTest
818+
@ValueSource(booleans = {true, false})
809819
@DontRunOnGithubActions
810-
public void testPutGetLargeFileGCP() throws Throwable {
820+
public void testPutGetLargeFileGCP(boolean useAwsSDKStrategy) throws Throwable {
821+
if (useAwsSDKStrategy) {
822+
SnowflakeUtil.systemSetEnv("SNOWFLAKE_GCS_FORCE_VIRTUAL_STYLE_DOMAINS", "true");
823+
}
811824
try (Connection connection = getConnection("gcpaccount");
812825
Statement statement = connection.createStatement()) {
813826
try {
@@ -882,6 +895,8 @@ public void testPutGetLargeFileGCP() throws Throwable {
882895
statement.execute("DROP STAGE IF EXISTS extra_stage");
883896
statement.execute("DROP TABLE IF EXISTS large_table");
884897
}
898+
} finally {
899+
SnowflakeUtil.systemSetEnv("SNOWFLAKE_GCS_FORCE_VIRTUAL_STYLE_DOMAINS", "false");
885900
}
886901
}
887902

@@ -909,9 +924,15 @@ public void testPutOverwrite() throws Throwable {
909924
String destFolderCanonicalPath = destFolder.getCanonicalPath();
910925
String destFolderCanonicalPathWithSeparator = destFolderCanonicalPath + File.separator;
911926

912-
List<String> accounts = Arrays.asList(null, "s3testaccount", "azureaccount", "gcpaccount");
927+
List<String> accounts =
928+
Arrays.asList(null, "s3testaccount", "azureaccount", "gcpaccount", "gcpaccount_awssdk");
913929
for (int i = 0; i < accounts.size(); i++) {
914-
try (Connection connection = getConnection(accounts.get(i));
930+
String accountName = accounts.get(i);
931+
if (accounts.get(i) != null && accounts.get(i).equals("gcpaccount_awssdk")) {
932+
accountName = "gcpaccount";
933+
SnowflakeUtil.systemSetEnv("SNOWFLAKE_GCS_FORCE_VIRTUAL_STYLE_DOMAINS", "true");
934+
}
935+
try (Connection connection = getConnection(accountName);
915936
Statement statement = connection.createStatement()) {
916937
try {
917938
statement.execute("alter session set ENABLE_GCP_PUT_EXCEPTION_FOR_OLD_DRIVERS=false");
@@ -954,6 +975,8 @@ public void testPutOverwrite() throws Throwable {
954975
} finally {
955976
statement.execute("DROP TABLE IF EXISTS testLoadToLocalFS");
956977
}
978+
} finally {
979+
SnowflakeUtil.systemSetEnv("SNOWFLAKE_GCS_FORCE_VIRTUAL_STYLE_DOMAINS", "false");
957980
}
958981
}
959982
}
@@ -962,9 +985,18 @@ public void testPutOverwrite() throws Throwable {
962985
@DontRunOnGithubActions
963986
public void testPut() throws Throwable {
964987

965-
List<String> accounts = Arrays.asList(null, "s3testaccount", "azureaccount", "gcpaccount");
988+
List<String> accounts =
989+
Arrays.asList(null, "s3testaccount", "azureaccount", "gcpaccount", "gcpaccount_awssdk");
966990
for (int i = 0; i < accounts.size(); i++) {
967-
try (Connection connection = getConnection(accounts.get(i));
991+
String accountName = accounts.get(i);
992+
if (accounts.get(i) != null && accounts.get(i).equals("gcpaccount_awssdk")) {
993+
accountName = "gcpaccount";
994+
SnowflakeUtil.systemSetEnv("SNOWFLAKE_GCS_FORCE_VIRTUAL_STYLE_DOMAINS", "true");
995+
}
996+
if (accounts.get(i) == null || !accounts.get(i).startsWith(("gcp"))) {
997+
continue;
998+
}
999+
try (Connection connection = getConnection(accountName);
9681000
Statement statement = connection.createStatement()) {
9691001
try {
9701002
// load file test
@@ -1023,6 +1055,8 @@ public void testPut() throws Throwable {
10231055
} finally {
10241056
statement.execute("DROP TABLE IF EXISTS testLoadToLocalFS");
10251057
}
1058+
} finally {
1059+
SnowflakeUtil.systemSetEnv("SNOWFLAKE_GCS_FORCE_VIRTUAL_STYLE_DOMAINS", "false");
10261060
}
10271061
}
10281062
}
@@ -2613,9 +2647,15 @@ public void testSnow31104() throws Throwable {
26132647
@DontRunOnGithubActions
26142648
public void testPutGet() throws Throwable {
26152649

2616-
List<String> accounts = Arrays.asList(null, "s3testaccount", "azureaccount", "gcpaccount");
2650+
List<String> accounts =
2651+
Arrays.asList(null, "s3testaccount", "azureaccount", "gcpaccount", "gcpaccount_awssdk");
26172652
for (int i = 0; i < accounts.size(); i++) {
2618-
try (Connection connection = getConnection(accounts.get(i));
2653+
String accountName = accounts.get(i);
2654+
if (accounts.get(i) != null && accounts.get(i).equals("gcpaccount_awssdk")) {
2655+
accountName = "gcpaccount";
2656+
SnowflakeUtil.systemSetEnv("SNOWFLAKE_GCS_FORCE_VIRTUAL_STYLE_DOMAINS", "true");
2657+
}
2658+
try (Connection connection = getConnection(accountName);
26192659
Statement statement = connection.createStatement()) {
26202660
try {
26212661
String sourceFilePath = getFullPathFileInResource(TEST_DATA_FILE);
@@ -2655,6 +2695,8 @@ public void testPutGet() throws Throwable {
26552695
} finally {
26562696
statement.execute("DROP STAGE IF EXISTS testGetPut_stage");
26572697
}
2698+
} finally {
2699+
SnowflakeUtil.systemSetEnv("SNOWFLAKE_GCS_FORCE_VIRTUAL_STYLE_DOMAINS", "false");
26582700
}
26592701
}
26602702
}
@@ -2669,9 +2711,15 @@ public void testPutGet() throws Throwable {
26692711
@DontRunOnGithubActions
26702712
public void testPutGetToUnencryptedStage() throws Throwable {
26712713

2672-
List<String> accounts = Arrays.asList(null, "s3testaccount", "azureaccount", "gcpaccount");
2714+
List<String> accounts =
2715+
Arrays.asList(null, "s3testaccount", "azureaccount", "gcpaccount", "gcpaccount_awssdk");
26732716
for (int i = 0; i < accounts.size(); i++) {
2674-
try (Connection connection = getConnection(accounts.get(i));
2717+
String accountName = accounts.get(i);
2718+
if (accounts.get(i) != null && accounts.get(i).equals("gcpaccount_awssdk")) {
2719+
accountName = "gcpaccount";
2720+
SnowflakeUtil.systemSetEnv("SNOWFLAKE_GCS_FORCE_VIRTUAL_STYLE_DOMAINS", "true");
2721+
}
2722+
try (Connection connection = getConnection(accountName);
26752723
Statement statement = connection.createStatement()) {
26762724
try {
26772725
String sourceFilePath = getFullPathFileInResource(TEST_DATA_FILE);
@@ -2712,6 +2760,8 @@ public void testPutGetToUnencryptedStage() throws Throwable {
27122760
} finally {
27132761
statement.execute("DROP STAGE IF EXISTS testPutGet_unencstage");
27142762
}
2763+
} finally {
2764+
SnowflakeUtil.systemSetEnv("SNOWFLAKE_GCS_FORCE_VIRTUAL_STYLE_DOMAINS", "false");
27152765
}
27162766
}
27172767
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package net.snowflake.client.jdbc.cloud.storage;
2+
3+
import static org.junit.Assert.assertTrue;
4+
5+
import com.amazonaws.DefaultRequest;
6+
import com.amazonaws.auth.AWSCredentials;
7+
import com.amazonaws.auth.BasicAWSCredentials;
8+
import com.amazonaws.http.HttpMethodName;
9+
import java.util.HashMap;
10+
import org.junit.jupiter.api.Test;
11+
12+
class AwsSdkGCPSignerTest {
13+
14+
@Test
15+
void testSign() {
16+
AWSCredentials creds = new BasicAWSCredentials("access_key", "");
17+
AwsSdkGCPSigner signer = new AwsSdkGCPSigner();
18+
19+
DefaultRequest request = new DefaultRequest("S3");
20+
21+
HashMap<String, String> headers = new HashMap<>();
22+
headers.put("x-amz-storage-class", "storage_class");
23+
headers.put("x-amz-meta-custom", "custom_meta");
24+
25+
request.setHttpMethod(HttpMethodName.GET);
26+
request.setHeaders(headers);
27+
28+
signer.sign(request, creds);
29+
30+
assertTrue(request.getHeaders().get("Authorization").equals("Bearer access_key"));
31+
assertTrue(request.getHeaders().get("Accept-Encoding").equals("gzip,deflate"));
32+
assertTrue(request.getHeaders().get("x-goog-storage-class").equals("storage_class"));
33+
assertTrue(request.getHeaders().get("x-goog-meta-custom").equals("custom_meta"));
34+
}
35+
}

0 commit comments

Comments
 (0)