Skip to content

Commit

Permalink
Feat: Add google.ClassificationMetrics, google.RegressionMetrics, and…
Browse files Browse the repository at this point in the history
… google.Forecasting Metrics (#1549)

* Add google.ClassificationMetrics, google.RegressionMetrics, and google.ForecastingMetrics Artifact types to metadata schema with unit tests.

* fix implicit false

* Fix typo

* Running nox -s blacken and nox -s lint

* fix typo in unit test

Co-authored-by: sina chavoshi <[email protected]>
  • Loading branch information
KevinBNaughton and SinaChavoshi authored Aug 2, 2022
1 parent caebb47 commit 3526b3e
Show file tree
Hide file tree
Showing 2 changed files with 393 additions and 0 deletions.
258 changes: 258 additions & 0 deletions google/cloud/aiplatform/metadata/schema/google/artifact_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,261 @@ def __init__(
metadata=extended_metadata,
state=state,
)


class ClassificationMetrics(base_artifact.BaseArtifactSchema):
"""A Google artifact representing evaluation Classification Metrics."""

schema_title = "google.ClassificationMetrics"

def __init__(
self,
*,
au_prc: Optional[float] = None,
au_roc: Optional[float] = None,
log_loss: Optional[float] = None,
artifact_id: Optional[str] = None,
uri: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Args:
au_prc (float):
Optional. The Area Under Precision-Recall Curve metric.
Micro-averaged for the overall evaluation.
au_roc (float):
Optional. The Area Under Receiver Operating Characteristic curve metric.
Micro-averaged for the overall evaluation.
log_loss (float):
Optional. The Log Loss metric.
artifact_id (str):
Optional. The <resource_id> portion of the Artifact name with
the format. This is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
uri (str):
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
artifact file.
display_name (str):
Optional. The user-defined name of the Artifact.
schema_version (str):
Optional. schema_version specifies the version used by the Artifact.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Artifact to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Artifact.
state (google.cloud.gapic.types.Artifact.State):
Optional. The state of this Artifact. This is a
property of the Artifact, and does not imply or
capture any ongoing process. This property is
managed by clients (such as Vertex AI
Pipelines), and the system does not prescribe or
check the validity of state transitions.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
if au_prc:
extended_metadata["auPrc"] = au_prc
if au_roc:
extended_metadata["auRoc"] = au_roc
if log_loss:
extended_metadata["logLoss"] = log_loss

super(ClassificationMetrics, self).__init__(
uri=uri,
artifact_id=artifact_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
state=state,
)


class RegressionMetrics(base_artifact.BaseArtifactSchema):
"""A Google artifact representing evaluation Regression Metrics."""

schema_title = "google.RegressionMetrics"

def __init__(
self,
*,
root_mean_squared_error: Optional[float] = None,
mean_absolute_error: Optional[float] = None,
mean_absolute_percentage_error: Optional[float] = None,
r_squared: Optional[float] = None,
root_mean_squared_log_error: Optional[float] = None,
artifact_id: Optional[str] = None,
uri: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Args:
root_mean_squared_error (float):
Optional. Root Mean Squared Error (RMSE).
mean_absolute_error (float):
Optional. Mean Absolute Error (MAE).
mean_absolute_percentage_error (float):
Optional. Mean absolute percentage error.
r_squared (float):
Optional. Coefficient of determination as Pearson correlation coefficient.
root_mean_squared_log_error (float):
Optional. Root mean squared log error.
artifact_id (str):
Optional. The <resource_id> portion of the Artifact name with
the format. This is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
uri (str):
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
artifact file.
display_name (str):
Optional. The user-defined name of the Artifact.
schema_version (str):
Optional. schema_version specifies the version used by the Artifact.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Artifact to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Artifact.
state (google.cloud.gapic.types.Artifact.State):
Optional. The state of this Artifact. This is a
property of the Artifact, and does not imply or
capture any ongoing process. This property is
managed by clients (such as Vertex AI
Pipelines), and the system does not prescribe or
check the validity of state transitions.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
if root_mean_squared_error:
extended_metadata["rootMeanSquaredError"] = root_mean_squared_error
if mean_absolute_error:
extended_metadata["meanAbsoluteError"] = mean_absolute_error
if mean_absolute_percentage_error:
extended_metadata[
"meanAbsolutePercentageError"
] = mean_absolute_percentage_error
if r_squared:
extended_metadata["rSquared"] = r_squared
if root_mean_squared_log_error:
extended_metadata["rootMeanSquaredLogError"] = root_mean_squared_log_error

super(RegressionMetrics, self).__init__(
uri=uri,
artifact_id=artifact_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
state=state,
)


class ForecastingMetrics(base_artifact.BaseArtifactSchema):
"""A Google artifact representing evaluation Forecasting Metrics."""

schema_title = "google.ForecastingMetrics"

def __init__(
self,
*,
root_mean_squared_error: Optional[float] = None,
mean_absolute_error: Optional[float] = None,
mean_absolute_percentage_error: Optional[float] = None,
r_squared: Optional[float] = None,
root_mean_squared_log_error: Optional[float] = None,
weighted_absolute_percentage_error: Optional[float] = None,
root_mean_squared_percentage_error: Optional[float] = None,
symmetric_mean_absolute_percentage_error: Optional[float] = None,
artifact_id: Optional[str] = None,
uri: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Args:
root_mean_squared_error (float):
Optional. Root Mean Squared Error (RMSE).
mean_absolute_error (float):
Optional. Mean Absolute Error (MAE).
mean_absolute_percentage_error (float):
Optional. Mean absolute percentage error.
r_squared (float):
Optional. Coefficient of determination as Pearson correlation coefficient.
root_mean_squared_log_error (float):
Optional. Root mean squared log error.
weighted_absolute_percentage_error (float):
Optional. Weighted Absolute Percentage Error.
Does not use weights, this is just what the metric is called.
Undefined if actual values sum to zero.
Will be very large if actual values sum to a very small number.
root_mean_squared_percentage_error (float):
Optional. Root Mean Square Percentage Error. Square root of MSPE.
Undefined/imaginary when MSPE is negative.
symmetric_mean_absolute_percentage_error (float):
Optional. Symmetric Mean Absolute Percentage Error.
artifact_id (str):
Optional. The <resource_id> portion of the Artifact name with
the format. This is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
uri (str):
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
artifact file.
display_name (str):
Optional. The user-defined name of the Artifact.
schema_version (str):
Optional. schema_version specifies the version used by the Artifact.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Artifact to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Artifact.
state (google.cloud.gapic.types.Artifact.State):
Optional. The state of this Artifact. This is a
property of the Artifact, and does not imply or
capture any ongoing process. This property is
managed by clients (such as Vertex AI
Pipelines), and the system does not prescribe or
check the validity of state transitions.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
if root_mean_squared_error:
extended_metadata["rootMeanSquaredError"] = root_mean_squared_error
if mean_absolute_error:
extended_metadata["meanAbsoluteError"] = mean_absolute_error
if mean_absolute_percentage_error:
extended_metadata[
"meanAbsolutePercentageError"
] = mean_absolute_percentage_error
if r_squared:
extended_metadata["rSquared"] = r_squared
if root_mean_squared_log_error:
extended_metadata["rootMeanSquaredLogError"] = root_mean_squared_log_error
if weighted_absolute_percentage_error:
extended_metadata[
"weightedAbsolutePercentageError"
] = weighted_absolute_percentage_error
if root_mean_squared_percentage_error:
extended_metadata[
"rootMeanSquaredPercentageError"
] = root_mean_squared_percentage_error
if symmetric_mean_absolute_percentage_error:
extended_metadata[
"symmetricMeanAbsolutePercentageError"
] = symmetric_mean_absolute_percentage_error

super(ForecastingMetrics, self).__init__(
uri=uri,
artifact_id=artifact_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
state=state,
)
Loading

0 comments on commit 3526b3e

Please sign in to comment.