
Chronos: add regression functions for forecast metrics #5402

Merged (8 commits) on Sep 1, 2022
9 changes: 7 additions & 2 deletions python/chronos/src/bigdl/chronos/autots/tspipeline.py
@@ -501,9 +501,14 @@ def quantize(self,
        calib_data = preprocess_quantize_data(self, calib_data)

        # map metric str to function
-        from bigdl.chronos.metric.forecast_metrics import TORCHMETRICS_REGRESSION_MAP
+        from bigdl.chronos.metric.forecast_metrics import REGRESSION_MAP
        if isinstance(metric, str):
-            metric = TORCHMETRICS_REGRESSION_MAP[metric]
+            metric_func = REGRESSION_MAP[metric]
+
+            def metric(y_label, y_predict):
+                y_label = y_label.numpy()
+                y_predict = y_predict.numpy()
+                return metric_func(y_label, y_predict)

        # init acc criterion
        accuracy_criterion = None
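The pattern in this hunk is worth a note: `REGRESSION_MAP` now holds plain numpy functions, while the quantization accuracy check feeds torch tensors, so the string metric is rebuilt as a closure that converts tensors before scoring. A minimal standalone sketch (`np_mae` is a hypothetical stand-in for a `REGRESSION_MAP` entry):

```python
import numpy as np
import torch

def np_mae(y_label, y_predict):
    # hypothetical stand-in for REGRESSION_MAP['mae']
    return np.mean(np.abs(y_label - y_predict))

metric_func = np_mae  # captured before the name `metric` is rebound

def metric(y_label, y_predict):
    # callers pass torch tensors; convert to numpy before scoring
    y_label = y_label.numpy()
    y_predict = y_predict.numpy()
    return metric_func(y_label, y_predict)

print(metric(torch.tensor([1.0, 2.0]), torch.tensor([1.5, 2.0])))  # 0.25
```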
@@ -438,6 +438,13 @@ def save(self, checkpoint_file):
def _str2metric(metric):
    # map metric str to function
    if isinstance(metric, str):
-        from bigdl.chronos.metric.forecast_metrics import TORCHMETRICS_REGRESSION_MAP
-        metric = TORCHMETRICS_REGRESSION_MAP[metric]
+        metric_name = metric
+        from bigdl.chronos.metric.forecast_metrics import REGRESSION_MAP
+        metric_func = REGRESSION_MAP[metric_name]
+
+        def metric(y_label, y_predict):
+            y_label = y_label.numpy()
+            y_predict = y_predict.numpy()
+            return metric_func(y_label, y_predict)
+        metric.__name__ = metric_name
    return metric
11 changes: 9 additions & 2 deletions python/chronos/src/bigdl/chronos/forecaster/base_forecaster.py
@@ -1161,6 +1161,13 @@ def check_time_steps(tsdataset, past_seq_len, future_seq_len):
def _str2metric(metric):
    # map metric str to function
    if isinstance(metric, str):
-        from bigdl.chronos.metric.forecast_metrics import TORCHMETRICS_REGRESSION_MAP
-        metric = TORCHMETRICS_REGRESSION_MAP[metric]
+        metric_name = metric
+        from bigdl.chronos.metric.forecast_metrics import REGRESSION_MAP
+        metric_func = REGRESSION_MAP[metric_name]
+
+        def metric(y_label, y_predict):
+            y_label = y_label.numpy()
+            y_predict = y_predict.numpy()
+            return metric_func(y_label, y_predict)
+        metric.__name__ = metric_name
    return metric
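Since the same `_str2metric` helper lands in two forecaster modules, here is a self-contained sketch of its behavior (the one-entry `REGRESSION_MAP` below is a stand-in for the real table):

```python
import numpy as np
import torch

REGRESSION_MAP = {'mae': lambda yt, yp: np.mean(np.abs(yt - yp))}  # stand-in

def _str2metric(metric):
    # map metric str to function
    if isinstance(metric, str):
        metric_name = metric
        metric_func = REGRESSION_MAP[metric_name]

        def metric(y_label, y_predict):
            y_label = y_label.numpy()
            y_predict = y_predict.numpy()
            return metric_func(y_label, y_predict)
        metric.__name__ = metric_name
    return metric

m = _str2metric('mae')
print(m.__name__)  # 'mae' -- the wrapper keeps a readable metric name
print(m(torch.tensor([1.0, 3.0]), torch.tensor([2.0, 3.0])))  # 0.5
```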
4 changes: 2 additions & 2 deletions python/chronos/src/bigdl/chronos/forecaster/utils_hpo.py
@@ -303,8 +303,8 @@ def _format_metric_str(prefix, metric):
            metrics.append(_format_metric_str(prefix, target_metric))
        return metrics
    if isinstance(metric, str):
-        from bigdl.chronos.metric.forecast_metrics import TORCHMETRICS_REGRESSION_MAP
-        metric_func = TORCHMETRICS_REGRESSION_MAP.get(metric, None)
+        from bigdl.chronos.metric.forecast_metrics import REGRESSION_MAP
+        metric_func = REGRESSION_MAP.get(metric, None)
        invalidInputError(metric_func is not None,
                          "{} is not found in available metrics.".format(metric))
        return _format_metric(prefix, metric_func)
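Unlike the call sites above, the HPO path validates the name instead of indexing directly: `.get(metric, None)` plus `invalidInputError` turns an unknown metric into a readable error. A sketch with the error helper stubbed out:

```python
REGRESSION_MAP = {'mae': lambda yt, yp: abs(yt - yp)}  # stand-in entries

def invalidInputError(condition, msg):
    # stand-in for bigdl.nano.utils.log4Error.invalidInputError
    if not condition:
        raise RuntimeError(msg)

metric = 'mdape'  # a name not in the map
metric_func = REGRESSION_MAP.get(metric, None)
invalidInputError(metric_func is not None,
                  "{} is not found in available metrics.".format(metric))
# RuntimeError: mdape is not found in available metrics.
```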
147 changes: 109 additions & 38 deletions python/chronos/src/bigdl/chronos/metric/forecast_metrics.py
@@ -14,36 +14,117 @@
# limitations under the License.
#

-import torch
import numpy as np
-from torch import Tensor
from numpy import ndarray
-from functools import partial
-from torchmetrics.functional import mean_squared_error, mean_absolute_error,\
-    mean_absolute_percentage_error, r2_score
from bigdl.nano.utils.log4Error import invalidInputError
from timeit import repeat


EPSILON = 1e-10


-# implemented this metric to keep up with orca.automl
-def symmetric_mean_absolute_percentage_error(preds: Tensor, target: Tensor) -> Tensor:
-    abs_diff = torch.abs(preds - target)
-    abs_per_error = abs_diff / (torch.abs(preds) + torch.abs(target) + EPSILON)
-    sum_abs_per_error = 100 * torch.sum(abs_per_error)
-    num_obs = target.numel()
-    return sum_abs_per_error / num_obs
+def mae(y_label, y_predict):
+    """
+    Calculate the mean absolute error (MAE).
+
+    .. math::
+        \\text{MAE} = \\frac{1}{n}\\sum_{t=1}^n |y_t-\\hat{y_t}|
+
+    :param y_label: Array-like of shape = (n_samples, \*).
+        Ground truth (correct) target values.
+    :param y_predict: Array-like of shape = (n_samples, \*).
+        Estimated target values.
+    :return: Ndarray of floats.
+        An array of non-negative floating point values (the best value is 0.0).
+    """
+    result = np.mean(np.abs(y_label - y_predict))
+    return result


+def mse(y_label, y_predict):
+    """
+    Calculate the mean squared error (MSE).
+
+    .. math::
+        \\text{MSE} = \\frac{1}{n}\\sum_{t=1}^n (y_t-\\hat{y_t})^2
+
+    :param y_label: Array-like of shape = (n_samples, \*).
+        Ground truth (correct) target values.
+    :param y_predict: Array-like of shape = (n_samples, \*).
+        Estimated target values.
+    :return: Ndarray of floats.
+        An array of non-negative floating point values (the best value is 0.0).
+    """
+    result = np.mean((y_label - y_predict) ** 2)
+    return result


+def rmse(y_label, y_predict):
+    """
+    Calculate the square root of the mean squared error (RMSE).
+
+    .. math::
+        \\text{RMSE} = \\sqrt{\\frac{1}{n}\\sum_{t=1}^n (y_t-\\hat{y_t})^2}
+
+    :param y_label: Array-like of shape = (n_samples, \*).
+        Ground truth (correct) target values.
+    :param y_predict: Array-like of shape = (n_samples, \*).
+        Estimated target values.
+    :return: Ndarray of floats.
+        An array of non-negative floating point values (the best value is 0.0).
+    """
+    return np.sqrt(mse(y_label, y_predict))


+def mape(y_label, y_predict):
+    """
+    Calculate the mean absolute percentage error (MAPE).
+
+    .. math::
+        \\text{MAPE} = \\frac{100\%}{n}\\sum_{t=1}^n |\\frac{y_t-\\hat{y_t}}{y_t}|
+
+    :param y_label: Array-like of shape = (n_samples, \*).
+        Ground truth (correct) target values.
+    :param y_predict: Array-like of shape = (n_samples, \*).
+        Estimated target values.
+    :return: Ndarray of floats.
+        An array of non-negative floating point values (the best value is 0.0).
+    """
+    return np.mean(np.abs((y_label - y_predict) / (y_label + EPSILON)))


+def smape(y_label, y_predict):
+    """
+    Calculate the symmetric mean absolute percentage error (sMAPE).
+
+    .. math::
+        \\text{sMAPE} = \\frac{100\%}{n} \\sum_{t=1}^n \\frac{|y_t-\\hat{y_t}|}{|y_t|+|\\hat{y_t}|}
+
+    :param y_label: Array-like of shape = (n_samples, \*).
+        Ground truth (correct) target values.
+    :param y_predict: Array-like of shape = (n_samples, \*).
+        Estimated target values.
+    :return: Ndarray of floats.
+        An array of non-negative floating point values (the best value is 0.0).
+    """
+    abs_diff = np.abs(y_predict - y_label)
+    abs_per_error = abs_diff / (np.abs(y_predict) + np.abs(y_label) + EPSILON)
+    sum_abs_per_error = np.mean(abs_per_error)
+    return sum_abs_per_error * 100
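A quick sanity check of the numpy metrics above, with values worked out by hand. Note that, as written, `mape` returns a plain fraction while `smape` is scaled to a percentage:

```python
import numpy as np
from bigdl.chronos.metric.forecast_metrics import mae, mse, rmse, mape, smape

y_true = np.array([1.0, 2.0, 4.0])
y_pred = np.array([2.0, 2.0, 3.0])

print(mae(y_true, y_pred))    # (1 + 0 + 1) / 3            = 0.6667
print(mse(y_true, y_pred))    # (1 + 0 + 1) / 3            = 0.6667
print(rmse(y_true, y_pred))   # sqrt(2/3)                  = 0.8165
print(mape(y_true, y_pred))   # (1/1 + 0/2 + 1/4) / 3      = 0.4167 (fraction)
print(smape(y_true, y_pred))  # (1/3 + 0/4 + 1/7)/3 * 100  = 15.87  (percent)
```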


-TORCHMETRICS_REGRESSION_MAP = {
-    'mae': mean_absolute_error,
-    'mse': mean_squared_error,
-    'rmse': partial(mean_squared_error, squared=False),
-    'mape': mean_absolute_percentage_error,
-    'smape': symmetric_mean_absolute_percentage_error,
-    'r2': r2_score,
-}
+def r2(y_label, y_predict):
+    """
+    Calculate the r2 score.
+
+    .. math::
+        R^2 = 1-\\frac{\\sum_{t=1}^n (y_t-\\hat{y_t})^2}{\\sum_{t=1}^n (y_t-\\bar{y})^2}
+
+    :param y_label: Array-like of shape = (n_samples, \*).
+        Ground truth (correct) target values.
+    :param y_predict: Array-like of shape = (n_samples, \*).
+        Estimated target values.
+    :return: Ndarray of floats.
+        An array of floating point values (the best value is 1.0; the score can be negative).
+    """
+    return 1 - np.sum((y_label - y_predict)**2) / np.sum((y_label - np.mean(y_label))**2)
+
+
+REGRESSION_MAP = {
+    'mae': mae,
+    'mse': mse,
+    'rmse': rmse,
+    'mape': mape,
+    'smape': smape,
+    'r2': r2,
+}
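With the table in place, every call site resolves a metric name the same way; a usage sketch:

```python
import numpy as np
from bigdl.chronos.metric.forecast_metrics import REGRESSION_MAP

y_true = np.array([1.0, 2.0, 4.0])
y_pred = np.array([2.0, 2.0, 3.0])

for name in ('mae', 'rmse', 'r2'):
    metric_func = REGRESSION_MAP[name]  # name -> numpy implementation
    print(name, metric_func(y_true, y_pred))
```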


@@ -57,13 +138,12 @@ def _standard_input(metrics, y_true, y_pred):
        metrics = [metrics]
    if isinstance(metrics[0], str):
        metrics = list(map(lambda x: x.lower(), metrics))
-        invalidInputError(all(metric in TORCHMETRICS_REGRESSION_MAP.keys() for metric in metrics),
-                          f"metric should be one of {TORCHMETRICS_REGRESSION_MAP.keys()},"
+        invalidInputError(all(metric in REGRESSION_MAP.keys() for metric in metrics),
+                          f"metric should be one of {REGRESSION_MAP.keys()},"
                          f" but get {metrics}.")
    invalidInputError(type(y_true) is type(y_pred) and isinstance(y_pred, ndarray),
                      "y_pred and y_true type must be numpy.ndarray,"
                      f" but found {type(y_pred)} and {type(y_true)}.")
-    y_true, y_pred = torch.from_numpy(y_true), torch.from_numpy(y_pred)

    invalidInputError(y_true.shape == y_pred.shape,
                      "y_true and y_pred should have the same shape, "
@@ -91,7 +171,6 @@ class Evaluator(object):
    def evaluate(metrics, y_true, y_pred, aggregate='mean'):
        """
        Evaluate specific metrics for y_true and y_pred.
-
        :param metrics: String or list in ['mae', 'mse', 'rmse', 'r2', 'mape', 'smape'] for built-in
            metrics. If callable function, its signature should be func(y_true, y_pred), where
            y_true and y_pred are numpy ndarray.
@@ -100,7 +179,6 @@ def evaluate(metrics, y_true, y_pred, aggregate='mean'):
        :param aggregate: aggregation method. Currently, "mean" and None are supported,
            'mean' represents aggregating by mean, while None will return the element-wise
            result. The value defaults to 'mean'.
-
        :return: Float or ndarray of floats.
            A floating point value, or an
            array of floating point values, one for each individual target.
@@ -112,23 +190,16 @@ def evaluate(metrics, y_true, y_pred, aggregate='mean'):
            if callable(metric):

Contributor comment: Why remove this code? Evaluator supports input of custom callable metric functions.

                metric_func = metric
            else:
-                metric_func = TORCHMETRICS_REGRESSION_MAP[metric]
+                metric_func = REGRESSION_MAP[metric]
            if len(original_shape) in [2, 3] and aggregate is None:
-                res = torch.zeros(y_true.shape[-1])
+                res = np.zeros(y_true.shape[-1])
                for i in range(y_true.shape[-1]):
-                    if callable(metric):
-                        res[i] = torch.from_numpy(metric_func(y_true[..., i], y_pred[..., i]))
-                    else:
-                        res[i] = metric_func(y_pred[..., i], y_true[..., i])
+                    res[i] = metric_func(y_true[..., i], y_pred[..., i])
                res = res.reshape(original_shape[1:])
-                res_list.append(res.numpy())
+                res_list.append(res)
            else:
-                if callable(metric):
-                    res = metric_func(y_true, y_pred)
-                    res_list.append(res)
-                else:
-                    res = metric_func(y_pred, y_true)
-                    res_list.append(res.numpy())
+                res = metric_func(y_true, y_pred)
+                res_list.append(res)
        return res_list
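For reference, a usage sketch of the updated `Evaluator.evaluate`; the 3-D shape below is assumed to be (n_samples, horizon, n_targets), matching how `_standard_input` records `original_shape` before flattening:

```python
import numpy as np
from bigdl.chronos.metric.forecast_metrics import Evaluator

y_true = np.random.rand(16, 5, 2)
y_pred = np.random.rand(16, 5, 2)

# aggregate='mean' -> one scalar per requested metric
print(Evaluator.evaluate(['mae', 'rmse'], y_true, y_pred))

# aggregate=None -> element-wise result, reshaped to original_shape[1:]
res = Evaluator.evaluate('smape', y_true, y_pred, aggregate=None)
print(res[0].shape)  # (5, 2)
```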

    def get_latency(func, *args, num_running=100, **kwargs):
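The diff truncates before the body of `get_latency`; given the `from timeit import repeat` import kept above, here is a plausible sketch of such a helper (an assumption, not the PR's actual code):

```python
import numpy as np
from timeit import repeat

def get_latency(func, *args, num_running=100, **kwargs):
    # hypothetical sketch: time num_running calls of func and report
    # percentile latencies in milliseconds
    times = repeat(lambda: func(*args, **kwargs), number=1, repeat=num_running)
    t = np.asarray(times) * 1000.0
    return {p: round(float(np.percentile(t, q)), 3)
            for p, q in (("p50", 50), ("p90", 90), ("p95", 95), ("p99", 99))}
```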