Skip to content

Commit 8bdbe5c

Browse files
telpirionsofisl
authored andcommitted
feat: adds enhancements to library (#22)
* feat: adds enhancements to library * chore: changes to synth.py * fix: broken pack n' play test * fix: add enhanced types to ts compiler * fix: project enabled * fix: adds docstrings to toValue(), fromValue() functions * fix: removing any * fix: edits to synth.py per reviewer * fix: add more test coverage * chore: added comment about conversion interface Co-authored-by: Sofia Leon <[email protected]>
1 parent 09c6769 commit 8bdbe5c

8 files changed

+413
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Copyright 2020 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
function main(
20+
datasetId,
21+
modelDisplayName,
22+
trainingPipelineDisplayName,
23+
project,
24+
location = 'us-central1'
25+
) {
26+
// [START aiplatform_create_training_pipeline_image_classification]
27+
/**
28+
* TODO(developer): Uncomment these variables before running the sample.
29+
* (Not necessary if passing values as arguments)
30+
*/
31+
/*
32+
const datasetId = 'YOUR DATASET';
33+
const modelDisplayName = 'NEW MODEL NAME;
34+
const trainingPipelineDisplayName = 'NAME FOR TRAINING PIPELINE';
35+
const project = 'YOUR PROJECT ID';
36+
const location = 'us-central1';
37+
*/
38+
// Imports the Google Cloud Pipeline Service Client library
39+
const aiplatform = require('@google-cloud/aiplatform');
40+
41+
const {
42+
definition,
43+
} = aiplatform.protos.google.cloud.aiplatform.v1beta1.schema.trainingjob;
44+
const ModelType = definition.AutoMlImageClassificationInputs.ModelType;
45+
46+
// Specifies the location of the api endpoint
47+
const clientOptions = {
48+
apiEndpoint: 'us-central1-aiplatform.googleapis.com',
49+
};
50+
51+
// Instantiates a client
52+
const pipelineServiceClient = new aiplatform.PipelineServiceClient(
53+
clientOptions
54+
);
55+
56+
async function createTrainingPipelineImageClassification() {
57+
// Configure the parent resource
58+
const parent = `projects/${project}/locations/${location}`;
59+
60+
// Values should match the input expected by your model.
61+
const trainingTaskInputsMessage = new definition.AutoMlImageClassificationInputs(
62+
{
63+
multiLabel: true,
64+
modelType: ModelType.CLOUD,
65+
budgetMilliNodeHours: 8000,
66+
disableEarlyStopping: false,
67+
}
68+
);
69+
70+
const trainingTaskInputs = trainingTaskInputsMessage.toValue();
71+
72+
const trainingTaskDefinition =
73+
'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml';
74+
75+
const modelToUpload = {displayName: modelDisplayName};
76+
const inputDataConfig = {datasetId: datasetId};
77+
const trainingPipeline = {
78+
displayName: trainingPipelineDisplayName,
79+
trainingTaskDefinition,
80+
trainingTaskInputs,
81+
inputDataConfig: inputDataConfig,
82+
modelToUpload: modelToUpload,
83+
};
84+
const request = {
85+
parent,
86+
trainingPipeline,
87+
};
88+
89+
// Create training pipeline request
90+
const [response] = await pipelineServiceClient.createTrainingPipeline(
91+
request
92+
);
93+
94+
console.log('Create training pipeline image classification response');
95+
console.log(`\tName : ${response.name}`);
96+
console.log(`\tDisplay Name : ${response.displayName}`);
97+
console.log(
98+
`\tTraining task definition : ${response.trainingTaskDefinition}`
99+
);
100+
console.log(
101+
`\tTraining task inputs : \
102+
${JSON.stringify(response.trainingTaskInputs)}`
103+
);
104+
console.log(
105+
`\tTraining task metadata : \
106+
${JSON.stringify(response.trainingTaskMetadata)}`
107+
);
108+
console.log(`\tState ; ${response.state}`);
109+
console.log(`\tCreate time : ${JSON.stringify(response.createTime)}`);
110+
console.log(`\tStart time : ${JSON.stringify(response.startTime)}`);
111+
console.log(`\tEnd time : ${JSON.stringify(response.endTime)}`);
112+
console.log(`\tUpdate time : ${JSON.stringify(response.updateTime)}`);
113+
console.log(`\tLabels : ${JSON.stringify(response.labels)}`);
114+
115+
const error = response.error;
116+
console.log('\tError');
117+
if (error === null) {
118+
console.log('\t\tCode : {}');
119+
console.log('\t\tMessage : {}');
120+
} else {
121+
console.log(`\t\tCode : ${error.code}`);
122+
console.log(`\t\tMessage : ${error.message}`);
123+
}
124+
}
125+
126+
createTrainingPipelineImageClassification();
127+
// [END aiplatform_create_training_pipeline_image_classification]
128+
}
129+
130+
process.on('unhandledRejection', err => {
131+
console.error(err.message);
132+
process.exitCode = 1;
133+
});
134+
135+
main(...process.argv.slice(2));
+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/**
2+
* Copyright 2020, Google, LLC.
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
'use strict';
17+
18+
function main(projectId, location = 'us-central1') {
19+
// [START aiplatform_list_endpoints]
20+
/**
21+
* TODO(developer): Uncomment these variables before running the sample.
22+
*/
23+
// const projectId = 'YOUR_PROJECT_ID';
24+
// const location = 'YOUR_PROJECT_LOCATION';
25+
26+
const {EndpointServiceClient} = require('@google-cloud/aiplatform');
27+
28+
// Specifies the location of the api endpoint
29+
const clientOptions = {
30+
apiEndpoint: 'us-central1-aiplatform.googleapis.com',
31+
};
32+
const client = new EndpointServiceClient(clientOptions);
33+
34+
async function listEndpoints() {
35+
// Configure the parent resource
36+
const parent = `projects/${projectId}/locations/${location}`;
37+
const request = {
38+
parent,
39+
};
40+
41+
// Get and print out a list of all the endpoints for this resource
42+
const [result] = await client.listEndpoints(request);
43+
for (const endpoint of result) {
44+
console.log(`\nEndpoint name: ${endpoint.name}`);
45+
console.log(`Display name: ${endpoint.displayName}`);
46+
if (endpoint.deployedModels[0]) {
47+
console.log(
48+
`First deployed model: ${endpoint.deployedModels[0].model}`
49+
);
50+
}
51+
}
52+
}
53+
54+
listEndpoints();
55+
// [END aiplatform_list_endpoints]
56+
}
57+
58+
main(...process.argv.slice(2)).catch(err => {
59+
console.error(err);
60+
process.exitCode = 1;
61+
});

ai-platform/snippets/package.json

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
"@google-cloud/aiplatform": "^1.0.0"
1717
},
1818
"devDependencies": {
19-
"mocha": "^8.0.0"
19+
"chai": "^4.2.0",
20+
"mocha": "^8.0.0",
21+
"uuid": "^8.3.1"
2022
}
2123
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Copyright 2020 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
function main(filename, endpointId, project, location = 'us-central1') {
20+
// [START aiplatform_predict_image_classification]
21+
/**
22+
* TODO(developer): Uncomment these variables before running the sample.\
23+
* (Not necessary if passing values as arguments)
24+
*/
25+
26+
// const filename = "YOUR_PREDICTION_FILE_NAME";
27+
// const endpointId = "YOUR_ENDPOINT_ID";
28+
// const project = 'YOUR_PROJECT_ID';
29+
// const location = 'YOUR_PROJECT_LOCATION';
30+
const aiplatform = require('@google-cloud/aiplatform');
31+
const {
32+
instance,
33+
params,
34+
prediction,
35+
} = aiplatform.protos.google.cloud.aiplatform.v1beta1.schema.predict;
36+
37+
// Imports the Google Cloud Prediction Service Client library
38+
const {PredictionServiceClient} = aiplatform;
39+
40+
// Specifies the location of the api endpoint
41+
const clientOptions = {
42+
apiEndpoint: 'us-central1-prediction-aiplatform.googleapis.com',
43+
};
44+
45+
// Instantiates a client
46+
const predictionServiceClient = new PredictionServiceClient(clientOptions);
47+
48+
async function predictImageClassification() {
49+
// Configure the endpoint resource
50+
const endpoint = `projects/${project}/locations/${location}/endpoints/${endpointId}`;
51+
52+
const parametersObj = new params.ImageClassificationPredictionParams({
53+
confidenceThreshold: 0.5,
54+
maxPredictions: 5,
55+
});
56+
const parameters = parametersObj.toValue();
57+
58+
const fs = require('fs');
59+
const image = fs.readFileSync(filename, 'base64');
60+
const instanceObj = new instance.ImageClassificationPredictionInstance({
61+
content: image,
62+
});
63+
const instanceValue = instanceObj.toValue();
64+
65+
const instances = [instanceValue];
66+
const request = {
67+
endpoint,
68+
instances,
69+
parameters,
70+
};
71+
72+
// Predict request
73+
const [response] = await predictionServiceClient.predict(request);
74+
75+
console.log('Predict image classification response');
76+
console.log(`\tDeployed model id : ${response.deployedModelId}`);
77+
const predictions = response.predictions;
78+
console.log('\tPredictions :');
79+
for (const predictionValue of predictions) {
80+
const predictionResultObj = prediction.ClassificationPredictionResult.fromValue(
81+
predictionValue
82+
);
83+
for (const [i, label] of predictionResultObj.displayNames.entries()) {
84+
console.log(`\tDisplay name: ${label}`);
85+
console.log(`\tConfidences: ${predictionResultObj.confidences[i]}`);
86+
console.log(`\tIDs: ${predictionResultObj.ids[i]}\n\n`);
87+
}
88+
}
89+
}
90+
predictImageClassification();
91+
// [END aiplatform_predict_image_classification]
92+
}
93+
94+
process.on('unhandledRejection', err => {
95+
console.error(err.message);
96+
process.exitCode = 1;
97+
});
98+
99+
main(...process.argv.slice(2));
74.3 KB
Loading
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright 2020 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
const {assert} = require('chai');
20+
const {after, describe, it} = require('mocha');
21+
22+
const uuid = require('uuid').v4;
23+
const cp = require('child_process');
24+
const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});
25+
26+
const aiplatform = require('@google-cloud/aiplatform');
27+
const clientOptions = {
28+
apiEndpoint: 'us-central1-aiplatform.googleapis.com',
29+
};
30+
31+
const pipelineServiceClient = new aiplatform.PipelineServiceClient(
32+
clientOptions
33+
);
34+
35+
const datasetId = process.env.TRAINING_PIPELINE_IMAGE_CLASS_DATASET_ID;
36+
const modelDisplayName = `temp_create_training_pipeline_image_classification_model_test${uuid()}`;
37+
const trainingPipelineDisplayName = `temp_create_training_pipeline_image_classification_test_${uuid()}`;
38+
const project = process.env.CAIP_PROJECT_ID;
39+
const location = process.env.LOCATION;
40+
41+
let trainingPipelineId;
42+
43+
describe('AI platform create training pipeline image classification', () => {
44+
it('should create a new image classification training pipeline', async () => {
45+
const stdout = execSync(
46+
`node ./create-training-pipeline-image-classification.js ${datasetId} ${modelDisplayName} ${trainingPipelineDisplayName} ${project} ${location}`
47+
);
48+
assert.match(stdout, /\/locations\/us-central1\/trainingPipelines\//);
49+
trainingPipelineId = stdout
50+
.split('/locations/us-central1/trainingPipelines/')[1]
51+
.split('\n')[0];
52+
});
53+
54+
after('should cancel the training pipeline and delete it', async () => {
55+
const name = pipelineServiceClient.trainingPipelinePath(
56+
project,
57+
location,
58+
trainingPipelineId
59+
);
60+
61+
const cancelRequest = {
62+
name,
63+
};
64+
65+
pipelineServiceClient.cancelTrainingPipeline(cancelRequest).then(() => {
66+
const deleteRequest = {
67+
name,
68+
};
69+
70+
return pipelineServiceClient.deleteTrainingPipeline(deleteRequest);
71+
});
72+
});
73+
});

0 commit comments

Comments
 (0)