Skip to content

Commit 8f016bd

Browse files
jonathan1920 authored and Lars Wander committed
Update create_model to allow the user to specify included or excluded columns (googleapis#16)
* Update create_model to allow the user to specify included or excluded columns * Make minor stylistic changes and raise ValueError for invalid column arguments
1 parent a8c7a72 commit 8f016bd

File tree

1 file changed

+42
-9
lines changed
  • automl/google/cloud/automl_v1beta1/helper

1 file changed

+42
-9
lines changed

automl/google/cloud/automl_v1beta1/helper/tables.py

+42-9
Original file line numberDiff line numberDiff line change
@@ -1043,8 +1043,8 @@ def list_models(self, project=None, region=None):
10431043
def create_model(self, model_display_name, dataset=None,
10441044
dataset_display_name=None, dataset_name=None,
10451045
train_budget_milli_node_hours=None, project=None,
1046-
region=None):
1047-
"""Create a model. This will train your model on the given dataset.
1046+
region=None, input_feature_column_specs_included=None, input_feature_column_specs_excluded=None):
1047+
"""Create a model. This will train your model on the given dataset.
10481048
10491049
Example:
10501050
>>> from google.cloud import automl_v1beta1
@@ -1057,7 +1057,6 @@ def create_model(self, model_display_name, dataset=None,
10571057
>>>
10581058
>>> m.result() # blocks on result
10591059
>>>
1060-
10611060
Args:
10621061
project (Optional[string]):
10631062
If you have initialized the client with a value for `project`
@@ -1085,11 +1084,15 @@ def create_model(self, model_display_name, dataset=None,
10851084
The `Dataset` instance you want to train your model on. This
10861085
must be supplied if `dataset_display_name` or `dataset_name`
10871086
are not supplied.
1088-
1087+
input_feature_column_specs_included (Optional[List[string]]):
1088+
The list of the names of the columns you want to include to train
1089+
your model on.
1090+
input_feature_column_specs_excluded (Optional[List[string]]):
1091+
The list of the names of the columns you want to exclude and
1092+
not train your model on.
10891093
Returns:
10901094
A :class:`~google.cloud.automl_v1beta1.types._OperationFuture`
10911095
instance.
1092-
10931096
Raises:
10941097
google.api_core.exceptions.GoogleAPICallError: If the request
10951098
failed for any reason.
@@ -1101,26 +1104,56 @@ def create_model(self, model_display_name, dataset=None,
11011104
raise ValueError('\'train_budget_milli_node_hours\' must be a '
11021105
'value between 1,000 and 72,000 inclusive')
11031106

1107+
if input_feature_column_specs_excluded not in [None, []] and input_feature_column_specs_included not in [None, []]:
1108+
raise ValueError('\'cannot set both input_feature_column_specs_excluded\' and '
1109+
'\'input_feature_column_specs_included\'')
1110+
1111+
11041112
dataset_name = self.__dataset_name_from_args(dataset=dataset,
11051113
dataset_name=dataset_name,
11061114
dataset_display_name=dataset_display_name,
11071115
project=project,
11081116
region=region)
1109-
1117+
tables_model_metadata = {
1118+
'train_budget_milli_node_hours': train_budget_milli_node_hours
1119+
}
11101120
dataset_id = dataset_name.rsplit('/', 1)[-1]
1121+
columns = [s for s in self.list_column_specs(dataset=dataset, dataset_name = dataset_name, dataset_display_name=dataset_display_name)]
1122+
1123+
final_columns = []
1124+
if input_feature_column_specs_included:
1125+
column_names = [a.display_name for a in columns]
1126+
if not (all (name in column_names for name in input_feature_column_specs_included)):
1127+
raise ValueError('invalid name in the list' '\'input_feature_column_specs_included\'')
1128+
for a in columns:
1129+
if a.display_name in input_feature_column_specs_included:
1130+
final_columns.append(a)
1131+
1132+
tables_model_metadata.update(
1133+
{'input_feature_column_specs': final_columns}
1134+
)
1135+
elif input_feature_column_specs_excluded:
1136+
for a in columns:
1137+
if a.display_name not in input_feature_column_specs_excluded:
1138+
final_columns.append(a)
1139+
1140+
tables_model_metadata.update(
1141+
{'input_feature_column_specs': final_columns}
1142+
)
1143+
11111144
request = {
11121145
'display_name': model_display_name,
11131146
'dataset_id': dataset_id,
1114-
'tables_model_metadata': {
1115-
'train_budget_milli_node_hours': train_budget_milli_node_hours
1116-
}
1147+
'tables_model_metadata': tables_model_metadata
11171148
}
11181149

1150+
11191151
return self.client.create_model(
11201152
self.__location_path(project=project, region=region),
11211153
request
11221154
)
11231155

1156+
11241157
def delete_model(self, model=None, model_display_name=None,
11251158
model_name=None, project=None, region=None):
11261159
"""Deletes a model. Note this will not delete any datasets associated

0 commit comments

Comments
 (0)