Skip to content

Commit e6a4c3f

Browse files
feat(tpu): add tpu vm create spot sample. (#9610)
* Changed package, added information to CODEOWNERS * Added information to CODEOWNERS * Added timeout * Fixed parameters for test * Fixed DeleteTpuVm and naming * Added comment, created Util class * Fixed naming * Fixed whitespace * Split PR into smaller, deleted redundant code * Implemented tpu_vm_create_spot sample, created test * changed zone * Changed zone * Fixed empty lines and tests, deleted cleanup method * Changed zone * Deleted redundant test class * Increased timeout * Fixed test
1 parent d8d6253 commit e6a4c3f

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed
+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
//[START tpu_vm_create_spot]
20+
import com.google.cloud.tpu.v2.CreateNodeRequest;
21+
import com.google.cloud.tpu.v2.Node;
22+
import com.google.cloud.tpu.v2.SchedulingConfig;
23+
import com.google.cloud.tpu.v2.TpuClient;
24+
import java.io.IOException;
25+
import java.util.concurrent.ExecutionException;
26+
27+
public class CreateSpotTpuVm {
28+
public static void main(String[] args)
29+
throws IOException, ExecutionException, InterruptedException {
30+
// TODO(developer): Replace these variables before running the sample.
31+
// Project ID or project number of the Google Cloud project you want to create a node.
32+
String projectId = "YOUR_PROJECT_ID";
33+
// The zone in which to create the TPU.
34+
// For more information about supported TPU types for specific zones,
35+
// see https://cloud.google.com/tpu/docs/regions-zones
36+
String zone = "us-central1-f";
37+
// The name for your TPU.
38+
String nodeName = "YOUR_TPY_NAME";
39+
// The accelerator type that specifies the version and size of the Cloud TPU you want to create.
40+
// For more information about supported accelerator types for each TPU version,
41+
// see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions.
42+
String tpuType = "v2-8";
43+
// Software version that specifies the version of the TPU runtime to install.
44+
// For more information see https://cloud.google.com/tpu/docs/runtimes
45+
String tpuSoftwareVersion = "tpu-vm-tf-2.14.1";
46+
47+
createSpotTpuVm(projectId, zone, nodeName, tpuType, tpuSoftwareVersion);
48+
}
49+
50+
// Creates a preemptible TPU VM with the specified name, zone, accelerator type, and version.
51+
public static Node createSpotTpuVm(
52+
String projectId, String zone, String nodeName, String tpuType, String tpuSoftwareVersion)
53+
throws IOException, ExecutionException, InterruptedException {
54+
// Initialize client that will be used to send requests. This client only needs to be created
55+
// once, and can be reused for multiple requests.
56+
try (TpuClient tpuClient = TpuClient.create()) {
57+
String parent = String.format("projects/%s/locations/%s", projectId, zone);
58+
// TODO: Wait for update of library to change preemptible to spot=True
59+
SchedulingConfig schedulingConfig = SchedulingConfig.newBuilder()
60+
.setPreemptible(true)
61+
.build();
62+
63+
Node tpuVm = Node.newBuilder()
64+
.setName(nodeName)
65+
.setAcceleratorType(tpuType)
66+
.setRuntimeVersion(tpuSoftwareVersion)
67+
.setSchedulingConfig(schedulingConfig)
68+
.build();
69+
70+
CreateNodeRequest request = CreateNodeRequest.newBuilder()
71+
.setParent(parent)
72+
.setNodeId(nodeName)
73+
.setNode(tpuVm)
74+
.build();
75+
76+
return tpuClient.createNodeAsync(request).get();
77+
}
78+
}
79+
}
80+
//[END tpu_vm_create_spot]

tpu/src/test/java/tpu/TpuVmIT.java

+23
Original file line numberDiff line numberDiff line change
@@ -210,4 +210,27 @@ public void testStopTpuVm() throws IOException, ExecutionException, InterruptedE
210210
assertEquals(returnedNode, mockNode);
211211
}
212212
}
213+
214+
@Test
215+
public void testCreateSpotTpuVm() throws Exception {
216+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
217+
Node mockNode = mock(Node.class);
218+
TpuClient mockTpuClient = mock(TpuClient.class);
219+
OperationFuture mockFuture = mock(OperationFuture.class);
220+
221+
mockedTpuClient.when(TpuClient::create).thenReturn(mockTpuClient);
222+
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
223+
.thenReturn(mockFuture);
224+
when(mockFuture.get()).thenReturn(mockNode);
225+
226+
Node returnedNode = CreateSpotTpuVm.createSpotTpuVm(
227+
PROJECT_ID, ZONE, NODE_NAME,
228+
TPU_TYPE, TPU_SOFTWARE_VERSION);
229+
230+
verify(mockTpuClient, times(1))
231+
.createNodeAsync(any(CreateNodeRequest.class));
232+
verify(mockFuture, times(1)).get();
233+
assertEquals(returnedNode, mockNode);
234+
}
235+
}
213236
}

0 commit comments

Comments
 (0)