18
18
19
19
// [START aiplatform_create_training_pipeline_tabular_regression_sample]
20
20
21
+ import com .google .cloud .aiplatform .util .ValueConverter ;
21
22
import com .google .cloud .aiplatform .v1beta1 .DeployedModelRef ;
22
23
import com .google .cloud .aiplatform .v1beta1 .EnvVar ;
23
24
import com .google .cloud .aiplatform .v1beta1 .ExplanationMetadata ;
37
38
import com .google .cloud .aiplatform .v1beta1 .SampledShapleyAttribution ;
38
39
import com .google .cloud .aiplatform .v1beta1 .TimestampSplit ;
39
40
import com .google .cloud .aiplatform .v1beta1 .TrainingPipeline ;
41
+ import com .google .cloud .aiplatform .v1beta1 .schema .trainingjob .definition .AutoMlTables ;
42
+ import com .google .cloud .aiplatform .v1beta1 .schema .trainingjob .definition .AutoMlTablesInputs ;
43
+ import com .google .cloud .aiplatform .v1beta1 .schema .trainingjob .definition .AutoMlTablesInputs .Transformation ;
44
+ import com .google .cloud .aiplatform .v1beta1 .schema .trainingjob .definition .AutoMlTablesInputs .Transformation .AutoTransformation ;
45
+ import com .google .cloud .aiplatform .v1beta1 .schema .trainingjob .definition .AutoMlTablesInputs .Transformation .TimestampTransformation ;
40
46
import com .google .protobuf .Value ;
41
47
import com .google .protobuf .util .JsonFormat ;
42
48
import com .google .rpc .Status ;
43
49
import java .io .IOException ;
50
+ import java .util .ArrayList ;
44
51
45
52
public class CreateTrainingPipelineTabularRegressionSample {
46
53
@@ -50,18 +57,15 @@ public static void main(String[] args) throws IOException {
50
57
String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME" ;
51
58
String datasetId = "YOUR_DATASET_ID" ;
52
59
String targetColumn = "TARGET_COLUMN" ;
53
- String transformation =
54
- "[{TRANSFORMATION_TYPE: {columnName : COLUMN_NAME, invalidValuesAllowed : TRUE/FALSE }}]" ;
55
60
createTrainingPipelineTableRegression (
56
- project , modelDisplayName , datasetId , targetColumn , transformation );
61
+ project , modelDisplayName , datasetId , targetColumn );
57
62
}
58
63
59
64
static void createTrainingPipelineTableRegression (
60
65
String project ,
61
66
String modelDisplayName ,
62
67
String datasetId ,
63
- String targetColumn ,
64
- String transformation )
68
+ String targetColumn )
65
69
throws IOException {
66
70
PipelineServiceSettings pipelineServiceSettings =
67
71
PipelineServiceSettings .newBuilder ()
@@ -77,14 +81,82 @@ static void createTrainingPipelineTableRegression(
77
81
LocationName locationName = LocationName .of (project , location );
78
82
String trainingTaskDefinition =
79
83
"gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml" ;
80
- String jsonString =
81
- "{\" targetColumn\" : \" "
82
- + targetColumn
83
- + "\" ,\" predictionType\" : \" regression\" ,\" transformations\" : "
84
- + transformation
85
- + ",\" trainBudgetMilliNodeHours\" : 8000}" ;
86
- Value .Builder trainingTaskInputs = Value .newBuilder ();
87
- JsonFormat .parser ().merge (jsonString , trainingTaskInputs );
84
+
85
+ // Set the columns used for training and their data types
86
+ ArrayList <Transformation > tranformations = new ArrayList <>();
87
+ tranformations .add (Transformation .newBuilder ()
88
+ .setAuto (AutoTransformation .newBuilder ().setColumnName ("STRING_5000unique_NULLABLE" ))
89
+ .build ());
90
+ tranformations .add (Transformation .newBuilder ()
91
+ .setAuto (AutoTransformation .newBuilder ().setColumnName ("INTEGER_5000unique_NULLABLE" ))
92
+ .build ());
93
+ tranformations .add (Transformation .newBuilder ()
94
+ .setAuto (AutoTransformation .newBuilder ().setColumnName ("FLOAT_5000unique_NULLABLE" ))
95
+ .build ());
96
+ tranformations .add (Transformation .newBuilder ()
97
+ .setAuto (AutoTransformation .newBuilder ().setColumnName ("FLOAT_5000unique_REPEATED" ))
98
+ .build ());
99
+ tranformations .add (Transformation .newBuilder ()
100
+ .setAuto (AutoTransformation .newBuilder ().setColumnName ("NUMERIC_5000unique_NULLABLE" ))
101
+ .build ());
102
+ tranformations .add (Transformation .newBuilder ()
103
+ .setAuto (AutoTransformation .newBuilder ().setColumnName ("BOOLEAN_2unique_NULLABLE" ))
104
+ .build ());
105
+ tranformations .add (Transformation .newBuilder ()
106
+ .setTimestamp (TimestampTransformation .newBuilder ()
107
+ .setColumnName ("TIMESTAMP_1unique_NULLABLE" )
108
+ .setInvalidValuesAllowed (true ))
109
+ .build ());
110
+ tranformations .add (Transformation .newBuilder ()
111
+ .setAuto (AutoTransformation .newBuilder ().setColumnName ("DATE_1unique_NULLABLE" ))
112
+ .build ());
113
+ tranformations .add (Transformation .newBuilder ()
114
+ .setAuto (AutoTransformation .newBuilder ().setColumnName ("TIME_1unique_NULLABLE" ))
115
+ .build ());
116
+ tranformations .add (Transformation .newBuilder ()
117
+ .setTimestamp (TimestampTransformation .newBuilder ()
118
+ .setColumnName ("DATETIME_1unique_NULLABLE" )
119
+ .setInvalidValuesAllowed (true ))
120
+ .build ());
121
+ tranformations .add (Transformation .newBuilder ()
122
+ .setAuto (AutoTransformation .newBuilder ()
123
+ .setColumnName ("STRUCT_NULLABLE.STRING_5000unique_NULLABLE" ))
124
+ .build ());
125
+ tranformations .add (Transformation .newBuilder ()
126
+ .setAuto (AutoTransformation .newBuilder ()
127
+ .setColumnName ("STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE" ))
128
+ .build ());
129
+ tranformations .add (Transformation .newBuilder ()
130
+ .setAuto (AutoTransformation .newBuilder ()
131
+ .setColumnName ("STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE" ))
132
+ .build ());
133
+ tranformations .add (Transformation .newBuilder ()
134
+ .setAuto (AutoTransformation .newBuilder ()
135
+ .setColumnName ("STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED" ))
136
+ .build ());
137
+ tranformations .add (Transformation .newBuilder ()
138
+ .setAuto (AutoTransformation .newBuilder ()
139
+ .setColumnName ("STRUCT_NULLABLE.FLOAT_5000unique_REPEATED" ))
140
+ .build ());
141
+ tranformations .add (Transformation .newBuilder ()
142
+ .setAuto (AutoTransformation .newBuilder ()
143
+ .setColumnName ("STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE" ))
144
+ .build ());
145
+ tranformations .add (Transformation .newBuilder ()
146
+ .setAuto (AutoTransformation .newBuilder ()
147
+ .setColumnName ("STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE" ))
148
+ .build ());
149
+
150
+ AutoMlTablesInputs trainingTaskInputs = AutoMlTablesInputs .newBuilder ()
151
+ .addAllTransformations (tranformations )
152
+ .setTargetColumn (targetColumn )
153
+ .setPredictionType ("regression" )
154
+ .setTrainBudgetMilliNodeHours (8000 )
155
+ .setDisableEarlyStopping (false )
156
+ // Supported regression optimization objectives: minimize-rmse,
157
+ // minimize-mae, minimize-rmsle
158
+ .setOptimizationObjective ("minimize-rmse" )
159
+ .build ();
88
160
89
161
FractionSplit fractionSplit =
90
162
FractionSplit .newBuilder ()
@@ -104,7 +176,7 @@ static void createTrainingPipelineTableRegression(
104
176
TrainingPipeline .newBuilder ()
105
177
.setDisplayName (modelDisplayName )
106
178
.setTrainingTaskDefinition (trainingTaskDefinition )
107
- .setTrainingTaskInputs (trainingTaskInputs )
179
+ .setTrainingTaskInputs (ValueConverter . toValue ( trainingTaskInputs ) )
108
180
.setInputDataConfig (inputDataConfig )
109
181
.setModelToUpload (modelToUpload )
110
182
.build ();
0 commit comments