Skip to content

Commit

Permalink
fix merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
elineii committed Feb 19, 2024
2 parents 6435586 + f5a6a38 commit dcbdb15
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tsururu/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,6 @@ def __init__(
elif column_type == "datetime":
data[column_name] = pd.to_datetime(data[column_name])

self.seq_data = data
self.columns_and_features_params = columns_and_features_params
self.history = history
self.step = step
Expand All @@ -406,6 +405,7 @@ def __init__(
self.date_column = columns_and_features_params["date"]["column"][0]
self.delta = delta
self.print_freq_period_info()
self.seq_data = data.sort_values(["id", "date"])

def make_padded_test(
self,
Expand Down
8 changes: 8 additions & 0 deletions tsururu/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
self.get_num_iterations = get_num_iterations
if self.get_num_iterations:
self.num_iterations = []
self.columns = None

def initialize_validator(self):
"""Initialization of the sample generator for training the model
Expand Down Expand Up @@ -102,6 +103,10 @@ def __init__(
super().__init__(get_num_iterations, validation_params, model_params)

def fit(self, X: pd.DataFrame, y: NDArray[np.floating]) -> None:
# Initialize columns' order and reorder columns
self.columns = sorted(X.columns)
X = X[self.columns]

# Initialize cv object
cv = self.initialize_validator()

Expand Down Expand Up @@ -152,6 +157,9 @@ def fit(self, X: pd.DataFrame, y: NDArray[np.floating]) -> None:
print(f"Std: {np.std(self.scores).round(4)}")

def predict(self, X: pd.DataFrame) -> NDArray[np.floating]:
# Reorder columns
X = X[self.columns]

models_preds = [model.predict(X) for model in self.models]
y_pred = np.mean(models_preds, axis=0)
return y_pred
Expand Down
5 changes: 4 additions & 1 deletion tsururu/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,12 @@ def _generate_X_y(
)
if role == "id":
self.id_feature_column = current_X
final_X = pd.concat((final_X, current_X), axis=1)
if not column_params["drop_raw_feature"]:
final_X = pd.concat((final_X, current_X), axis=1)
else:
for transformer_name, transformer_params in column_params["features"].items():
assert not (role != "target" and transformer_params.get("transform_target")), "It is not possible to use transform_target=True with transformers for exogenous variables"

transformer_init_params = {
"transformer_name": transformer_name,
"transformer_params": {
Expand Down
12 changes: 10 additions & 2 deletions tsururu/transformers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import annotations
from typing import List, Union, Tuple, Optional
from numpy.typing import NDArray
import re

import numpy as np
import pandas as pd
import holidays

from .dataset import IndexSlicer

LAG_TRANSFORMER_MASK = r"lag_\d+__"
SEASON_TRANSFORMER_MASK = r"season_\w+__"

date_attrs = {
"y": "year",
"m": "month",
Expand Down Expand Up @@ -81,10 +85,14 @@ def fit(
Fitted transformer.
"""
self.columns = raw_ts_X.columns[
np.hstack(
[raw_ts_X.columns.str.contains(raw_column_name) for raw_column_name in columns]
np.any(
[raw_ts_X.columns.str.contains(
fr"{LAG_TRANSFORMER_MASK}{re.escape(raw_column_name)}$|{SEASON_TRANSFORMER_MASK}{re.escape(raw_column_name)}$|^{re.escape(raw_column_name)}$"
) for raw_column_name in columns],
axis=0
)
]

self.id_column = id_column
self.transform_train = transform_train
self.transform_target = transform_target
Expand Down

0 comments on commit dcbdb15

Please sign in to comment.