Skip to content

Commit

Permalink
add updates
Browse files Browse the repository at this point in the history
  • Loading branch information
TheaperDeng committed Sep 15, 2022
1 parent 3f5b76d commit 47c5de1
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
12 changes: 6 additions & 6 deletions python/chronos/src/bigdl/chronos/model/autoformer/Autoformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, configs):
self.lr = configs.lr
self.lr_scheduler_milestones = configs.lr_scheduler_milestones
self.loss = loss_creator(configs.loss)
self.configs = configs
self.c_out = configs.c_out

# Decomp
# change kernei_size to odd
Expand Down Expand Up @@ -155,23 +155,23 @@ def training_step(self, batch, batch_idx):
batch_x, batch_y, batch_x_mark, batch_y_mark = map(lambda x: x.float(), batch)
outputs = self(batch_x, batch_x_mark, batch_y, batch_y_mark)

outputs = outputs[:, -self.pred_len:, -self.configs.c_out:]
batch_y = batch_y[:, -self.pred_len:, -self.configs.c_out:]
outputs = outputs[:, -self.pred_len:, -self.c_out:]
batch_y = batch_y[:, -self.pred_len:, -self.c_out:]
return self.loss(outputs, batch_y)

def validation_step(self, batch, batch_idx):
batch_x, batch_y, batch_x_mark, batch_y_mark = map(lambda x: x.float(), batch)
outputs = self(batch_x, batch_x_mark, batch_y, batch_y_mark)

outputs = outputs[:, -self.pred_len:, -self.configs.c_out:]
batch_y = batch_y[:, -self.pred_len:, -self.configs.c_out:]
outputs = outputs[:, -self.pred_len:, -self.c_out:]
batch_y = batch_y[:, -self.pred_len:, -self.c_out:]
self.log("val_loss", self.loss(outputs, batch_y))

def predict_step(self, batch, batch_idx):
batch_x, batch_y, batch_x_mark, batch_y_mark = map(lambda x: x.float(), batch)
outputs = self(batch_x, batch_x_mark, batch_y, batch_y_mark)

outputs = outputs[:, -self.pred_len:, -self.configs.c_out:]
outputs = outputs[:, -self.pred_len:, -self.c_out:]
return outputs

def configure_optimizers(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,16 @@ def get_ts_df():
return train_df


def create_data(loader=False):
def create_data(loader=False, extra_feature=False):
df = get_ts_df()
target = ["value", "extra feature"]
if extra_feature:
target = ["value"]
extra = ["extra feature"]
else:
target = ["value", "extra feature"]
extra = []
tsdata_train, tsdata_val, tsdata_test =\
TSDataset.from_pandas(df, dt_col="datetime", target_col=target,
TSDataset.from_pandas(df, dt_col="datetime", target_col=target, extra_feature_col=extra,
with_split=True, test_ratio=0.1, val_ratio=0.1)
if loader:
train_loader = tsdata_train.to_torch_data_loader(lookback=24, horizon=5,
Expand Down Expand Up @@ -240,3 +245,19 @@ def test_autoformer_forecaster_even_kernel(self):
evaluate = forecaster.evaluate(val_loader)
pred = forecaster.predict(test_loader)
evaluate_list.append(evaluate)

def test_autoformer_forecaster_diff_input_output_dim(self):
train_loader, val_loader, test_loader = create_data(loader=True, extra_feature=True)
evaluate_list = []
forecaster = AutoformerForecaster(past_seq_len=24,
future_seq_len=5,
input_feature_num=2,
output_feature_num=1,
label_len=12,
freq='s',
seed=0,
moving_avg=20) # even
forecaster.fit(train_loader, epochs=3, batch_size=32)
evaluate = forecaster.evaluate(val_loader)
pred = forecaster.predict(test_loader)
evaluate_list.append(evaluate)

0 comments on commit 47c5de1

Please sign in to comment.