Skip to content

Commit 060d985

Browse files
samples: adds enhanced library samples (#134)
* feat: adds enhanced library samples * feat: update pom.xml * update prot version * removed unused packages * fix: resetting text sentiment analysis due to proto issues Co-authored-by: Mike Ganbold <[email protected]>
1 parent 41cf967 commit 060d985

File tree

34 files changed

+290
-186
lines changed

34 files changed

+290
-186
lines changed

aiplatform/snippets/pom.xml

+5
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@
5757
<version>1.1.2</version>
5858
<scope>test</scope>
5959
</dependency>
60+
<dependency>
61+
<groupId>com.google.api.grpc</groupId>
62+
<artifactId>proto-google-cloud-aiplatform-v1beta1</artifactId>
63+
<version>0.2.0</version>
64+
</dependency>
6065

6166
</dependencies>
6267
</project>

aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java

+11-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
// [START aiplatform_create_training_pipeline_image_object_detection_sample]
2020

21+
import com.google.cloud.aiplatform.util.ValueConverter;
2122
import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
2223
import com.google.cloud.aiplatform.v1beta1.EnvVar;
2324
import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
@@ -38,6 +39,9 @@
3839
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
3940
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
4041
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
42+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassification;
43+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageObjectDetectionInputs;
44+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageObjectDetectionInputs.ModelType;
4145
import com.google.protobuf.Value;
4246
import com.google.protobuf.util.JsonFormat;
4347
import com.google.rpc.Status;
@@ -74,11 +78,12 @@ static void createTrainingPipelineImageObjectDetectionSample(
7478
+ "automl_image_object_detection_1.0.0.yaml";
7579
LocationName locationName = LocationName.of(project, location);
7680

77-
String jsonString =
78-
"{\"modelType\": \"CLOUD_HIGH_ACCURACY_1\", \"budgetMilliNodeHours\": 20000,"
79-
+ " \"disableEarlyStopping\": false}";
80-
Value.Builder trainingTaskInputs = Value.newBuilder();
81-
JsonFormat.parser().merge(jsonString, trainingTaskInputs);
81+
AutoMlImageObjectDetectionInputs autoMlImageObjectDetectionInputs =
82+
AutoMlImageObjectDetectionInputs.newBuilder()
83+
.setModelType(ModelType.CLOUD_HIGH_ACCURACY_1)
84+
.setBudgetMilliNodeHours(20000)
85+
.setDisableEarlyStopping(false)
86+
.build();
8287

8388
InputDataConfig trainingInputDataConfig =
8489
InputDataConfig.newBuilder().setDatasetId(datasetId).build();
@@ -87,7 +92,7 @@ static void createTrainingPipelineImageObjectDetectionSample(
8792
TrainingPipeline.newBuilder()
8893
.setDisplayName(trainingPipelineDisplayName)
8994
.setTrainingTaskDefinition(trainingTaskDefinition)
90-
.setTrainingTaskInputs(trainingTaskInputs)
95+
.setTrainingTaskInputs(ValueConverter.toValue(autoMlImageObjectDetectionInputs))
9196
.setInputDataConfig(trainingInputDataConfig)
9297
.setModelToUpload(model)
9398
.build();

aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java

+35-14
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
// [START aiplatform_create_training_pipeline_tabular_classification_sample]
2020

21+
import com.google.cloud.aiplatform.util.ValueConverter;
2122
import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
2223
import com.google.cloud.aiplatform.v1beta1.EnvVar;
2324
import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
@@ -37,10 +38,14 @@
3738
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
3839
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
3940
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
41+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs;
42+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation;
43+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation;
4044
import com.google.protobuf.Value;
4145
import com.google.protobuf.util.JsonFormat;
4246
import com.google.rpc.Status;
4347
import java.io.IOException;
48+
import java.util.ArrayList;
4449

4550
public class CreateTrainingPipelineTabularClassificationSample {
4651

@@ -50,18 +55,15 @@ public static void main(String[] args) throws IOException {
5055
String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
5156
String datasetId = "YOUR_DATASET_ID";
5257
String targetColumn = "TARGET_COLUMN";
53-
String transformation =
54-
"[{TRANSFORMATION_TYPE: {columnName : COLUMN_NAME, invalidValuesAllowed : TRUE/FALSE }}]";
5558
createTrainingPipelineTableClassification(
56-
project, modelDisplayName, datasetId, targetColumn, transformation);
59+
project, modelDisplayName, datasetId, targetColumn);
5760
}
5861

5962
static void createTrainingPipelineTableClassification(
6063
String project,
6164
String modelDisplayName,
6265
String datasetId,
63-
String targetColumn,
64-
String transformation)
66+
String targetColumn)
6567
throws IOException {
6668
PipelineServiceSettings pipelineServiceSettings =
6769
PipelineServiceSettings.newBuilder()
@@ -77,15 +79,34 @@ static void createTrainingPipelineTableClassification(
7779
LocationName locationName = LocationName.of(project, location);
7880
String trainingTaskDefinition =
7981
"gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";
80-
String jsonString =
81-
"{\"targetColumn\": \""
82-
+ targetColumn
83-
+ "\",\"predictionType\": \"classification\",\"transformations\": "
84-
+ transformation
85-
+ ",\"trainBudgetMilliNodeHours\": 8000}";
8682

87-
Value.Builder trainingTaskInputs = Value.newBuilder();
88-
JsonFormat.parser().merge(jsonString, trainingTaskInputs);
83+
// Set the columns used for training and their data types
84+
Transformation transformation1 = Transformation.newBuilder()
85+
.setAuto(AutoTransformation.newBuilder().setColumnName("sepal_width").build())
86+
.build();
87+
Transformation transformation2 = Transformation.newBuilder()
88+
.setAuto(AutoTransformation.newBuilder().setColumnName("sepal_length").build())
89+
.build();
90+
Transformation transformation3 = Transformation.newBuilder()
91+
.setAuto(AutoTransformation.newBuilder().setColumnName("petal_length").build())
92+
.build();
93+
Transformation transformation4 = Transformation.newBuilder()
94+
.setAuto(AutoTransformation.newBuilder().setColumnName("petal_width").build())
95+
.build();
96+
97+
ArrayList<Transformation> transformationArrayList = new ArrayList<>();
98+
transformationArrayList.add(transformation1);
99+
transformationArrayList.add(transformation2);
100+
transformationArrayList.add(transformation3);
101+
transformationArrayList.add(transformation4);
102+
103+
AutoMlTablesInputs autoMlTablesInputs =
104+
AutoMlTablesInputs.newBuilder()
105+
.setTargetColumn(targetColumn)
106+
.setPredictionType("classification")
107+
.addAllTransformations(transformationArrayList)
108+
.setTrainBudgetMilliNodeHours(8000)
109+
.build();
89110

90111
FractionSplit fractionSplit =
91112
FractionSplit.newBuilder()
@@ -105,7 +126,7 @@ static void createTrainingPipelineTableClassification(
105126
TrainingPipeline.newBuilder()
106127
.setDisplayName(modelDisplayName)
107128
.setTrainingTaskDefinition(trainingTaskDefinition)
108-
.setTrainingTaskInputs(trainingTaskInputs)
129+
.setTrainingTaskInputs(ValueConverter.toValue(autoMlTablesInputs))
109130
.setInputDataConfig(inputDataConfig)
110131
.setModelToUpload(modelToUpload)
111132
.build();

aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java

+86-14
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
// [START aiplatform_create_training_pipeline_tabular_regression_sample]
2020

21+
import com.google.cloud.aiplatform.util.ValueConverter;
2122
import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
2223
import com.google.cloud.aiplatform.v1beta1.EnvVar;
2324
import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
@@ -37,10 +38,16 @@
3738
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
3839
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
3940
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
41+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTables;
42+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs;
43+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation;
44+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation;
45+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.TimestampTransformation;
4046
import com.google.protobuf.Value;
4147
import com.google.protobuf.util.JsonFormat;
4248
import com.google.rpc.Status;
4349
import java.io.IOException;
50+
import java.util.ArrayList;
4451

4552
public class CreateTrainingPipelineTabularRegressionSample {
4653

@@ -50,18 +57,15 @@ public static void main(String[] args) throws IOException {
5057
String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
5158
String datasetId = "YOUR_DATASET_ID";
5259
String targetColumn = "TARGET_COLUMN";
53-
String transformation =
54-
"[{TRANSFORMATION_TYPE: {columnName : COLUMN_NAME, invalidValuesAllowed : TRUE/FALSE }}]";
5560
createTrainingPipelineTableRegression(
56-
project, modelDisplayName, datasetId, targetColumn, transformation);
61+
project, modelDisplayName, datasetId, targetColumn);
5762
}
5863

5964
static void createTrainingPipelineTableRegression(
6065
String project,
6166
String modelDisplayName,
6267
String datasetId,
63-
String targetColumn,
64-
String transformation)
68+
String targetColumn)
6569
throws IOException {
6670
PipelineServiceSettings pipelineServiceSettings =
6771
PipelineServiceSettings.newBuilder()
@@ -77,14 +81,82 @@ static void createTrainingPipelineTableRegression(
7781
LocationName locationName = LocationName.of(project, location);
7882
String trainingTaskDefinition =
7983
"gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";
80-
String jsonString =
81-
"{\"targetColumn\": \""
82-
+ targetColumn
83-
+ "\",\"predictionType\": \"regression\",\"transformations\": "
84-
+ transformation
85-
+ ",\"trainBudgetMilliNodeHours\": 8000}";
86-
Value.Builder trainingTaskInputs = Value.newBuilder();
87-
JsonFormat.parser().merge(jsonString, trainingTaskInputs);
84+
85+
// Set the columns used for training and their data types
86+
ArrayList<Transformation> tranformations = new ArrayList<>();
87+
tranformations.add(Transformation.newBuilder()
88+
.setAuto(AutoTransformation.newBuilder().setColumnName("STRING_5000unique_NULLABLE"))
89+
.build());
90+
tranformations.add(Transformation.newBuilder()
91+
.setAuto(AutoTransformation.newBuilder().setColumnName("INTEGER_5000unique_NULLABLE"))
92+
.build());
93+
tranformations.add(Transformation.newBuilder()
94+
.setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_NULLABLE"))
95+
.build());
96+
tranformations.add(Transformation.newBuilder()
97+
.setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_REPEATED"))
98+
.build());
99+
tranformations.add(Transformation.newBuilder()
100+
.setAuto(AutoTransformation.newBuilder().setColumnName("NUMERIC_5000unique_NULLABLE"))
101+
.build());
102+
tranformations.add(Transformation.newBuilder()
103+
.setAuto(AutoTransformation.newBuilder().setColumnName("BOOLEAN_2unique_NULLABLE"))
104+
.build());
105+
tranformations.add(Transformation.newBuilder()
106+
.setTimestamp(TimestampTransformation.newBuilder()
107+
.setColumnName("TIMESTAMP_1unique_NULLABLE")
108+
.setInvalidValuesAllowed(true))
109+
.build());
110+
tranformations.add(Transformation.newBuilder()
111+
.setAuto(AutoTransformation.newBuilder().setColumnName("DATE_1unique_NULLABLE"))
112+
.build());
113+
tranformations.add(Transformation.newBuilder()
114+
.setAuto(AutoTransformation.newBuilder().setColumnName("TIME_1unique_NULLABLE"))
115+
.build());
116+
tranformations.add(Transformation.newBuilder()
117+
.setTimestamp(TimestampTransformation.newBuilder()
118+
.setColumnName("DATETIME_1unique_NULLABLE")
119+
.setInvalidValuesAllowed(true))
120+
.build());
121+
tranformations.add(Transformation.newBuilder()
122+
.setAuto(AutoTransformation.newBuilder()
123+
.setColumnName("STRUCT_NULLABLE.STRING_5000unique_NULLABLE"))
124+
.build());
125+
tranformations.add(Transformation.newBuilder()
126+
.setAuto(AutoTransformation.newBuilder()
127+
.setColumnName("STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE"))
128+
.build());
129+
tranformations.add(Transformation.newBuilder()
130+
.setAuto(AutoTransformation.newBuilder()
131+
.setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE"))
132+
.build());
133+
tranformations.add(Transformation.newBuilder()
134+
.setAuto(AutoTransformation.newBuilder()
135+
.setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED"))
136+
.build());
137+
tranformations.add(Transformation.newBuilder()
138+
.setAuto(AutoTransformation.newBuilder()
139+
.setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REPEATED"))
140+
.build());
141+
tranformations.add(Transformation.newBuilder()
142+
.setAuto(AutoTransformation.newBuilder()
143+
.setColumnName("STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE"))
144+
.build());
145+
tranformations.add(Transformation.newBuilder()
146+
.setAuto(AutoTransformation.newBuilder()
147+
.setColumnName("STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE"))
148+
.build());
149+
150+
AutoMlTablesInputs trainingTaskInputs = AutoMlTablesInputs.newBuilder()
151+
.addAllTransformations(tranformations)
152+
.setTargetColumn(targetColumn)
153+
.setPredictionType("regression")
154+
.setTrainBudgetMilliNodeHours(8000)
155+
.setDisableEarlyStopping(false)
156+
// supported regression optimisation objectives: minimize-rmse,
157+
// minimize-mae, minimize-rmsle
158+
.setOptimizationObjective("minimize-rmse")
159+
.build();
88160

89161
FractionSplit fractionSplit =
90162
FractionSplit.newBuilder()
@@ -104,7 +176,7 @@ static void createTrainingPipelineTableRegression(
104176
TrainingPipeline.newBuilder()
105177
.setDisplayName(modelDisplayName)
106178
.setTrainingTaskDefinition(trainingTaskDefinition)
107-
.setTrainingTaskInputs(trainingTaskInputs)
179+
.setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs))
108180
.setInputDataConfig(inputDataConfig)
109181
.setModelToUpload(modelToUpload)
110182
.build();

aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java

+7-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
// [START aiplatform_create_training_pipeline_text_classification_sample]
2020

21+
import com.google.cloud.aiplatform.util.ValueConverter;
2122
import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
2223
import com.google.cloud.aiplatform.v1beta1.EnvVar;
2324
import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
@@ -38,8 +39,7 @@
3839
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
3940
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
4041
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
41-
import com.google.protobuf.Value;
42-
import com.google.protobuf.util.JsonFormat;
42+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTextClassificationInputs;
4343
import com.google.rpc.Status;
4444
import java.io.IOException;
4545

@@ -73,12 +73,13 @@ static void createTrainingPipelineTextClassificationSample(
7373
String trainingTaskDefinition =
7474
"gs://google-cloud-aiplatform/schema/trainingjob/definition/"
7575
+ "automl_text_classification_1.0.0.yaml";
76-
String jsonString = "{\"multiLabel\": false}";
7776

7877
LocationName locationName = LocationName.of(project, location);
7978

80-
Value.Builder trainingTaskInputs = Value.newBuilder();
81-
JsonFormat.parser().merge(jsonString, trainingTaskInputs);
79+
AutoMlTextClassificationInputs trainingTaskInputs =
80+
AutoMlTextClassificationInputs.newBuilder()
81+
.setMultiLabel(false)
82+
.build();
8283

8384
InputDataConfig trainingInputDataConfig =
8485
InputDataConfig.newBuilder().setDatasetId(datasetId).build();
@@ -87,7 +88,7 @@ static void createTrainingPipelineTextClassificationSample(
8788
TrainingPipeline.newBuilder()
8889
.setDisplayName(trainingPipelineDisplayName)
8990
.setTrainingTaskDefinition(trainingTaskDefinition)
90-
.setTrainingTaskInputs(trainingTaskInputs)
91+
.setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs))
9192
.setInputDataConfig(trainingInputDataConfig)
9293
.setModelToUpload(model)
9394
.build();

aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java

+2-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
// [START aiplatform_create_training_pipeline_text_entity_extraction_sample]
2020

21+
import com.google.cloud.aiplatform.util.ValueConverter;
2122
import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
2223
import com.google.cloud.aiplatform.v1beta1.EnvVar;
2324
import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
@@ -73,22 +74,17 @@ static void createTrainingPipelineTextEntityExtractionSample(
7374
String trainingTaskDefinition =
7475
"gs://google-cloud-aiplatform/schema/trainingjob/definition/"
7576
+ "automl_text_extraction_1.0.0.yaml";
76-
String jsonString = "{}";
7777

7878
LocationName locationName = LocationName.of(project, location);
7979

80-
// Training task inputs are empty for text entity extraction
81-
Value.Builder trainingTaskInputs = Value.newBuilder();
82-
JsonFormat.parser().merge(jsonString, trainingTaskInputs);
83-
8480
InputDataConfig trainingInputDataConfig =
8581
InputDataConfig.newBuilder().setDatasetId(datasetId).build();
8682
Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
8783
TrainingPipeline trainingPipeline =
8884
TrainingPipeline.newBuilder()
8985
.setDisplayName(trainingPipelineDisplayName)
9086
.setTrainingTaskDefinition(trainingTaskDefinition)
91-
.setTrainingTaskInputs(trainingTaskInputs)
87+
.setTrainingTaskInputs(ValueConverter.EMPTY_VALUE)
9288
.setInputDataConfig(trainingInputDataConfig)
9389
.setModelToUpload(model)
9490
.build();

0 commit comments

Comments
 (0)