Skip to content

Commit

Permalink
update create_model to allow user to specify included or excluded col… (
Browse files Browse the repository at this point in the history
googleapis#16)

* update create_model to allow user to specify included or excluded columns

* made minor changes stylistically and with added ValueError outputs
  • Loading branch information
jonathan1920 authored and Lars Wander committed Jul 19, 2019
1 parent 162fcbc commit c45eea1
Showing 1 changed file with 42 additions and 9 deletions.
51 changes: 42 additions & 9 deletions automl/google/cloud/automl_v1beta1/helper/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,8 +1043,8 @@ def list_models(self, project=None, region=None):
def create_model(self, model_display_name, dataset=None,
dataset_display_name=None, dataset_name=None,
train_budget_milli_node_hours=None, project=None,
region=None):
"""Create a model. This will train your model on the given dataset.
region=None, input_feature_column_specs_included=None, input_feature_column_specs_excluded=None):
"""Create a model. This will train your model on the given dataset.
Example:
>>> from google.cloud import automl_v1beta1
Expand All @@ -1057,7 +1057,6 @@ def create_model(self, model_display_name, dataset=None,
>>>
>>> m.result() # blocks on result
>>>
Args:
project (Optional[string]):
If you have initialized the client with a value for `project`
Expand Down Expand Up @@ -1085,11 +1084,15 @@ def create_model(self, model_display_name, dataset=None,
The `Dataset` instance you want to train your model on. This
must be supplied if `dataset_display_name` or `dataset_name`
are not supplied.
input_feature_column_specs_included(Optional[string]):
The list of the names of the columns you want to include to train
your model on.
input_feature_column_specs_excluded(Optional[string]):
The list of the names of the columns you want to exclude and
not train your model on.
Returns:
A :class:`~google.cloud.automl_v1beta1.types._OperationFuture`
instance.
Raises:
google.api_core.exceptions.GoogleAPICallError: If the request
failed for any reason.
Expand All @@ -1101,26 +1104,56 @@ def create_model(self, model_display_name, dataset=None,
raise ValueError('\'train_budget_milli_node_hours\' must be a '
'value between 1,000 and 72,000 inclusive')

if input_feature_column_specs_excluded not in [None, []] and input_feature_column_specs_included not in [None, []]:
raise ValueError('\'cannot set both input_feature_column_specs_excluded\' and '
'\'input_feature_column_specs_included\'')


dataset_name = self.__dataset_name_from_args(dataset=dataset,
dataset_name=dataset_name,
dataset_display_name=dataset_display_name,
project=project,
region=region)

tables_model_metadata = {
'train_budget_milli_node_hours': train_budget_milli_node_hours
}
dataset_id = dataset_name.rsplit('/', 1)[-1]
columns = [s for s in self.list_column_specs(dataset=dataset, dataset_name = dataset_name, dataset_display_name=dataset_display_name)]

final_columns = []
if input_feature_column_specs_included:
column_names = [a.display_name for a in columns]
if not (all (name in column_names for name in input_feature_column_specs_included)):
raise ValueError('invalid name in the list' '\'input_feature_column_specs_included\'')
for a in columns:
if a.display_name in input_feature_column_specs_included:
final_columns.append(a)

tables_model_metadata.update(
{'input_feature_column_specs': final_columns}
)
elif input_feature_column_specs_excluded:
for a in columns:
if a.display_name not in input_feature_column_specs_excluded:
final_columns.append(a)

tables_model_metadata.update(
{'input_feature_column_specs': final_columns}
)

request = {
'display_name': model_display_name,
'dataset_id': dataset_id,
'tables_model_metadata': {
'train_budget_milli_node_hours': train_budget_milli_node_hours
}
'tables_model_metadata': tables_model_metadata
}


return self.client.create_model(
self.__location_path(project=project, region=region),
request
)


def delete_model(self, model=None, model_display_name=None,
model_name=None, project=None, region=None):
"""Deletes a model. Note this will not delete any datasets associated
Expand Down

0 comments on commit c45eea1

Please sign in to comment.