diff --git a/docs/zh/api/arch.md b/docs/zh/api/arch.md index aee5410d3..bacbc56ad 100644 --- a/docs/zh/api/arch.md +++ b/docs/zh/api/arch.md @@ -27,5 +27,6 @@ - DGMR - ChipDeepONets - AutoEncoder + - CuboidTransformer show_root_heading: true heading_level: 3 diff --git a/docs/zh/api/data/dataset.md b/docs/zh/api/data/dataset.md index e884ae140..c5873a144 100644 --- a/docs/zh/api/data/dataset.md +++ b/docs/zh/api/data/dataset.md @@ -23,6 +23,8 @@ - MeshAirfoilDataset - MeshCylinderDataset - RadarDataset + - ENSODataset + - SEVIRDataset - build_dataset - DGMRDataset show_root_heading: true diff --git a/examples/earthformer/conf/earthformer_enso_pretrain.yaml b/examples/earthformer/conf/earthformer_enso_pretrain.yaml new file mode 100644 index 000000000..b5ea679b4 --- /dev/null +++ b/examples/earthformer/conf/earthformer_enso_pretrain.yaml @@ -0,0 +1,153 @@ +hydra: + run: + # dynamic output directory according to running time and override name + dir: outputs_earthformer_pretrain + job: + name: ${mode} # name of logfile + chdir: false # keep current working direcotry unchaned + config: + override_dirname: + exclude_keys: + - TRAIN.checkpoint_path + - TRAIN.pretrained_model_path + - EVAL.pretrained_model_path + - mode + - output_dir + - log_freq + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: train # running mode: train/eval/export/infer +seed: 0 +output_dir: ${hydra:run.dir} +log_freq: 20 + +# set train and evaluate data path +FILE_PATH: ./datasets/enso/enso_round1_train_20210201 + +# dataset setting +DATASET: + label_keys: ["sst_target","nino_target"] + in_len: 12 + out_len: 14 + nino_window_t: 3 + in_stride: 1 + out_stride: 1 + train_samples_gap: 2 + eval_samples_gap: 1 + normalize_sst: true + +# model settings +MODEL: + input_keys: ["sst_data"] + output_keys: ["sst_target","nino_target"] + input_shape: [12, 24, 48, 1] + target_shape: [14, 24, 48, 1] + base_units: 64 + scale_alpha: 1.0 + + enc_depth: [1, 1] + dec_depth: [1, 1] + enc_use_inter_ffn: true + dec_use_inter_ffn: true + dec_hierarchical_pos_embed: false + + downsample: 2 + downsample_type: "patch_merge" + upsample_type: "upsample" + + num_global_vectors: 0 + use_dec_self_global: false + dec_self_update_global: true + use_dec_cross_global: false + use_global_vector_ffn: false + use_global_self_attn: false + separate_global_qkv: false + global_dim_ratio: 1 + + self_pattern: "axial" + cross_self_pattern: "axial" + cross_pattern: "cross_1x1" + dec_cross_last_n_frames: null + + attn_drop: 0.1 + proj_drop: 0.1 + ffn_drop: 0.1 + num_heads: 4 + + ffn_activation: "gelu" + gated_ffn: false + norm_layer: "layer_norm" + padding_type: "zeros" + pos_embed_type: "t+h+w" + use_relative_pos: true + self_attn_use_final_proj: true + dec_use_first_self_attn: false + + z_init_method: "zeros" + initial_downsample_type: "conv" + initial_downsample_activation: "leaky_relu" + initial_downsample_scale: [1, 1, 2] + initial_downsample_conv_layers: 2 + final_upsample_conv_layers: 1 + checkpoint_level: 2 + + attn_linear_init_mode: "0" + ffn_linear_init_mode: "0" + conv_init_mode: "0" + down_up_linear_init_mode: "0" + norm_init_mode: "0" + + +# training settings +TRAIN: + epochs: 100 + save_freq: 20 + eval_during_train: true + eval_freq: 10 + lr_scheduler: + epochs: ${TRAIN.epochs} + learning_rate: 0.0002 + by_epoch: true + min_lr_ratio: 1.0e-3 + wd: 1.0e-5 + batch_size: 16 + pretrained_model_path: null + checkpoint_path: null 
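+  # NOTE: earthformer_enso_train.py builds a Cosine lr scheduler from the values above,
+  # setting eta_min = min_lr_ratio * learning_rate and a warmup of 20% of TRAIN.epochs.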
+ + +# evaluation settings +EVAL: + pretrained_model_path: ./checkpoint/enso/earthformer_enso.pdparams + compute_metric_by_batch: false + eval_with_no_grad: true + batch_size: 1 + +INFER: + pretrained_model_path: ./checkpoint/enso/earthformer_enso.pdparams + export_path: ./inference/earthformer/enso + pdmodel_path: ${INFER.export_path}.pdmodel + pdpiparams_path: ${INFER.export_path}.pdiparams + device: gpu + engine: native + precision: fp32 + onnx_path: ${INFER.export_path}.onnx + ir_optim: true + min_subgraph_size: 10 + gpu_mem: 4000 + gpu_id: 0 + max_batch_size: 16 + num_cpu_threads: 4 + batch_size: 1 + data_path: ./datasets/enso/infer/SODA_train.nc + in_len: 12 + in_stride: 1 + out_len: 14 + out_stride: 1 + samples_gap: 1 diff --git a/examples/earthformer/conf/earthformer_sevir_pretrain.yaml b/examples/earthformer/conf/earthformer_sevir_pretrain.yaml new file mode 100644 index 000000000..73d437dc9 --- /dev/null +++ b/examples/earthformer/conf/earthformer_sevir_pretrain.yaml @@ -0,0 +1,185 @@ +hydra: + run: + # dynamic output directory according to running time and override name + dir: outputs_earthformer_pretrain + job: + name: ${mode} # name of logfile + chdir: false # keep current working direcotry unchaned + config: + override_dirname: + exclude_keys: + - TRAIN.checkpoint_path + - TRAIN.pretrained_model_path + - EVAL.pretrained_model_path + - mode + - output_dir + - log_freq + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: train # running mode: train/eval/export/infer +seed: 0 +output_dir: ${hydra:run.dir} +log_freq: 20 + +# set train and evaluate data path +FILE_PATH: ./datasets/sevir/sevir_data + +# SEVIR dataset:raw_seq_len: 49,interval_real_time:5, img_height = 384,img_width = 384 +# SEVIR_lr dataset:raw_seq_len: 25,interval_real_time:10, img_height = 128,img_width = 128 + +# dataset setting +DATASET: + label_keys: ["vil"] + data_types: ["vil"] + seq_len: 25 + raw_seq_len: 49 + sample_mode: "sequent" + stride: 12 + batch_size: 2 + layout: "NTHWC" + in_len: 13 + out_len: 12 + split_mode: "uneven" + + shuffle_seed: 1 + rescale_method: "01" + downsample_dict: null + verbose: false + preprocess: true + +# model settings +MODEL: + input_keys: ["input"] + output_keys: ["vil"] + input_shape: [13, 384, 384, 1] + target_shape: [12, 384, 384, 1] + base_units: 128 + scale_alpha: 1.0 + + enc_depth: [1, 1] + dec_depth: [1, 1] + enc_use_inter_ffn: true + dec_use_inter_ffn: true + dec_hierarchical_pos_embed: false + + downsample: 2 + downsample_type: "patch_merge" + upsample_type: "upsample" + + num_global_vectors: 8 + use_dec_self_global: false + dec_self_update_global: true + use_dec_cross_global: false + use_global_vector_ffn: false + use_global_self_attn: true + separate_global_qkv: true + global_dim_ratio: 1 + + self_pattern: "axial" + cross_self_pattern: "axial" + cross_pattern: "cross_1x1" + dec_cross_last_n_frames: null + + attn_drop: 0.1 + proj_drop: 0.1 + ffn_drop: 0.1 + num_heads: 4 + + ffn_activation: "gelu" + gated_ffn: false + norm_layer: "layer_norm" + padding_type: "zeros" + pos_embed_type: "t+h+w" + use_relative_pos: true + self_attn_use_final_proj: true + dec_use_first_self_attn: false + + z_init_method: "zeros" + initial_downsample_type: "stack_conv" + initial_downsample_activation: "leaky_relu" + initial_downsample_stack_conv_num_layers: 3 + initial_downsample_stack_conv_dim_list: [16, 64, 128] + 
initial_downsample_stack_conv_downscale_list: [3, 2, 2] + initial_downsample_stack_conv_num_conv_list: [2, 2, 2] + checkpoint_level: 2 + + attn_linear_init_mode: "0" + ffn_linear_init_mode: "0" + conv_init_mode: "0" + down_up_linear_init_mode: "0" + norm_init_mode: "0" + + +# training settings +TRAIN: + epochs: 100 + save_freq: 20 + eval_during_train: true + eval_freq: 10 + lr_scheduler: + epochs: ${TRAIN.epochs} + learning_rate: 0.001 + by_epoch: true + min_lr_ratio: 1.0e-3 + wd: 0.0 + batch_size: 1 + pretrained_model_path: null + checkpoint_path: null + start_date: null + end_date: [2019, 1, 1] + + +# evaluation settings +EVAL: + pretrained_model_path: ./checkpoint/sevir/earthformer_sevir.pdparams + compute_metric_by_batch: false + eval_with_no_grad: true + batch_size: 1 + end_date: [2019, 6, 1] + + metrics_mode: "0" + metrics_list: ["csi", "pod", "sucr", "bias"] + threshold_list: [16, 74, 133, 160, 181, 219] + + +TEST: + pretrained_model_path: ./checkpoint/sevir/earthformer_sevir.pdparams + compute_metric_by_batch: true + eval_with_no_grad: true + batch_size: 1 + start_date: [2019, 6, 1] + end_date: null + +INFER: + pretrained_model_path: ./checkpoint/sevir/earthformer_sevir.pdparams + export_path: ./inference/earthformer/sevir + pdmodel_path: ${INFER.export_path}.pdmodel + pdpiparams_path: ${INFER.export_path}.pdiparams + device: gpu + engine: native + precision: fp32 + onnx_path: ${INFER.export_path}.onnx + ir_optim: true + min_subgraph_size: 10 + gpu_mem: 4000 + gpu_id: 0 + max_batch_size: 16 + num_cpu_threads: 4 + batch_size: 1 + data_path: ./datasets/sevir/vil/2019/SEVIR_VIL_STORMEVENTS_2019_0701_1231.h5 + in_len: 13 + out_len: 12 + sevir_vis_save: ./inference/earthformer/sevir/vis + layout: "NTHWC" + plot_stride: 2 + logging_prefix: "Cuboid_SEVIR" + interval_real_time: 5 + data_type: "vil" + rescale_method: "01" diff --git a/examples/earthformer/earthformer_enso_train.py b/examples/earthformer/earthformer_enso_train.py new file mode 100644 index 000000000..120654c70 --- /dev/null +++ b/examples/earthformer/earthformer_enso_train.py @@ -0,0 +1,282 @@ +from os import path as osp + +import hydra +import numpy as np +import paddle +from omegaconf import DictConfig +from paddle import nn + +import examples.earthformer.enso_metric as enso_metric +import ppsci +from ppsci.data.dataset import enso_dataset +from ppsci.utils import logger + +try: + import xarray as xr +except ModuleNotFoundError: + raise ModuleNotFoundError("Please install xarray with `pip install xarray`.") + + +def get_parameter_names(model, forbidden_layer_types): + result = [] + for name, child in model.named_children(): + result += [ + f"{name}.{n}" + for n in get_parameter_names(child, forbidden_layer_types) + if not isinstance(child, tuple(forbidden_layer_types)) + ] + # Add model specific parameters (defined with nn.Parameter) since they are not in any child. 
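+    # NOTE: train() below calls this with forbidden_layer_types=[nn.LayerNorm] and
+    # additionally filters out "bias" names, so only the remaining parameters
+    # receive weight decay in the AdamW parameter groups.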
+ result += list(model._parameters.keys()) + return result + + +def train(cfg: DictConfig): + # set train dataloader config + train_dataloader_cfg = { + "dataset": { + "name": "ENSODataset", + "data_dir": cfg.FILE_PATH, + "input_keys": cfg.MODEL.input_keys, + "label_keys": cfg.DATASET.label_keys, + "in_len": cfg.DATASET.in_len, + "out_len": cfg.DATASET.out_len, + "in_stride": cfg.DATASET.in_stride, + "out_stride": cfg.DATASET.out_stride, + "train_samples_gap": cfg.DATASET.train_samples_gap, + "eval_samples_gap": cfg.DATASET.eval_samples_gap, + "normalize_sst": cfg.DATASET.normalize_sst, + }, + "sampler": { + "name": "BatchSampler", + "drop_last": True, + "shuffle": True, + }, + "batch_size": cfg.TRAIN.batch_size, + "num_workers": 8, + } + + # set constraint + sup_constraint = ppsci.constraint.SupervisedConstraint( + train_dataloader_cfg, + loss=ppsci.loss.FunctionalLoss(enso_metric.train_mse_func), + name="Sup", + ) + constraint = {sup_constraint.name: sup_constraint} + + # set iters_per_epoch by dataloader length + ITERS_PER_EPOCH = len(sup_constraint.data_loader) + # set eval dataloader config + eval_dataloader_cfg = { + "dataset": { + "name": "ENSODataset", + "data_dir": cfg.FILE_PATH, + "input_keys": cfg.MODEL.input_keys, + "label_keys": cfg.DATASET.label_keys, + "in_len": cfg.DATASET.in_len, + "out_len": cfg.DATASET.out_len, + "in_stride": cfg.DATASET.in_stride, + "out_stride": cfg.DATASET.out_stride, + "train_samples_gap": cfg.DATASET.train_samples_gap, + "eval_samples_gap": cfg.DATASET.eval_samples_gap, + "normalize_sst": cfg.DATASET.normalize_sst, + "training": "eval", + }, + "batch_size": cfg.EVAL.batch_size, + } + + sup_validator = ppsci.validate.SupervisedValidator( + eval_dataloader_cfg, + loss=ppsci.loss.FunctionalLoss(enso_metric.train_mse_func), + metric={ + "rmse": ppsci.metric.FunctionalMetric(enso_metric.eval_rmse_func), + }, + name="Sup_Validator", + ) + validator = {sup_validator.name: sup_validator} + + model = ppsci.arch.CuboidTransformer( + **cfg.MODEL, + ) + + decay_parameters = get_parameter_names(model, [nn.LayerNorm]) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if n in decay_parameters], + "weight_decay": cfg.TRAIN.wd, + }, + { + "params": [ + p for n, p in model.named_parameters() if n not in decay_parameters + ], + "weight_decay": 0.0, + }, + ] + + # # init optimizer and lr scheduler + lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler) + lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine( + **lr_scheduler_cfg, + iters_per_epoch=ITERS_PER_EPOCH, + eta_min=cfg.TRAIN.min_lr_ratio * cfg.TRAIN.lr_scheduler.learning_rate, + warmup_epoch=int(0.2 * cfg.TRAIN.epochs), + )() + optimizer = paddle.optimizer.AdamW( + lr_scheduler, parameters=optimizer_grouped_parameters + ) + + # initialize solver + solver = ppsci.solver.Solver( + model, + constraint, + cfg.output_dir, + optimizer, + lr_scheduler, + cfg.TRAIN.epochs, + ITERS_PER_EPOCH, + eval_during_train=cfg.TRAIN.eval_during_train, + seed=cfg.seed, + validator=validator, + compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + ) + # train model + solver.train() + # evaluate after finished training + solver.eval() + + +def evaluate(cfg: DictConfig): + # set eval dataloader config + eval_dataloader_cfg = { + "dataset": { + "name": "ENSODataset", + "data_dir": cfg.FILE_PATH, + "input_keys": cfg.MODEL.input_keys, + "label_keys": cfg.DATASET.label_keys, 
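+            # same windowing options as the training dataloader; the "training"
+            # flag below switches the dataset to its test split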
+ "in_len": cfg.DATASET.in_len, + "out_len": cfg.DATASET.out_len, + "in_stride": cfg.DATASET.in_stride, + "out_stride": cfg.DATASET.out_stride, + "train_samples_gap": cfg.DATASET.train_samples_gap, + "eval_samples_gap": cfg.DATASET.eval_samples_gap, + "normalize_sst": cfg.DATASET.normalize_sst, + "training": "test", + }, + "batch_size": cfg.EVAL.batch_size, + } + + sup_validator = ppsci.validate.SupervisedValidator( + eval_dataloader_cfg, + loss=ppsci.loss.FunctionalLoss(enso_metric.train_mse_func), + metric={ + "rmse": ppsci.metric.FunctionalMetric(enso_metric.eval_rmse_func), + }, + name="Sup_Validator", + ) + validator = {sup_validator.name: sup_validator} + + model = ppsci.arch.CuboidTransformer( + **cfg.MODEL, + ) + + # initialize solver + solver = ppsci.solver.Solver( + model, + output_dir=cfg.output_dir, + log_freq=cfg.log_freq, + seed=cfg.seed, + validator=validator, + pretrained_model_path=cfg.EVAL.pretrained_model_path, + compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + ) + # evaluate + solver.eval() + + +def export(cfg: DictConfig): + # set model + model = ppsci.arch.CuboidTransformer( + **cfg.MODEL, + ) + + # initialize solver + solver = ppsci.solver.Solver( + model, + pretrained_model_path=cfg.INFER.pretrained_model_path, + ) + # export model + from paddle.static import InputSpec + + input_spec = [ + { + key: InputSpec([1, 12, 24, 48, 1], "float32", name=key) + for key in model.input_keys + }, + ] + solver.export(input_spec, cfg.INFER.export_path) + + +def inference(cfg: DictConfig): + import predictor + + predictor = predictor.EarthformerPredictor(cfg) + + train_cmip = xr.open_dataset(cfg.INFER.data_path).transpose( + "year", "month", "lat", "lon" + ) + # select longitudes + lon = train_cmip.lon.values + lon = lon[np.logical_and(lon >= 95, lon <= 330)] + train_cmip = train_cmip.sel(lon=lon) + data = train_cmip.sst.values + data = enso_dataset.fold(data) + + idx_sst = enso_dataset.prepare_inputs_targets( + len_time=data.shape[0], + input_length=cfg.INFER.in_len, + input_gap=cfg.INFER.in_stride, + pred_shift=cfg.INFER.out_len * cfg.INFER.out_stride, + pred_length=cfg.INFER.out_len, + samples_gap=cfg.INFER.samples_gap, + ) + data = data[idx_sst].astype("float32") + + sst_data = data[..., np.newaxis] + idx = np.random.choice(len(data), None, False) + in_seq = sst_data[idx, : cfg.INFER.in_len, ...] # ( in_len, lat, lon, 1) + in_seq = in_seq[np.newaxis, ...] + target_seq = sst_data[idx, cfg.INFER.in_len :, ...] # ( out_len, lat, lon, 1) + target_seq = target_seq[np.newaxis, ...] 
+ + pred_data = predictor.predict(in_seq, cfg.INFER.batch_size) + + # save predict data + save_path = osp.join(cfg.output_dir, "result_enso_pred.npy") + np.save(save_path, pred_data) + logger.info(f"Save output to {save_path}") + + +@hydra.main( + version_base=None, + config_path="./conf", + config_name="earthformer_enso_pretrain.yaml", +) +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + elif cfg.mode == "export": + export(cfg) + elif cfg.mode == "infer": + inference(cfg) + else: + raise ValueError( + f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/earthformer/earthformer_sevir_train.py b/examples/earthformer/earthformer_sevir_train.py new file mode 100644 index 000000000..fdd20128a --- /dev/null +++ b/examples/earthformer/earthformer_sevir_train.py @@ -0,0 +1,354 @@ +import h5py +import hydra +import numpy as np +import paddle +import sevir_metric +import sevir_vis_seq +from omegaconf import DictConfig +from paddle import nn + +import ppsci + + +def get_parameter_names(model, forbidden_layer_types): + result = [] + for name, child in model.named_children(): + result += [ + f"{name}.{n}" + for n in get_parameter_names(child, forbidden_layer_types) + if not isinstance(child, tuple(forbidden_layer_types)) + ] + # Add model specific parameters (defined with nn.Parameter) since they are not in any child. + result += list(model._parameters.keys()) + return result + + +def train(cfg: DictConfig): + # set train dataloader config + train_dataloader_cfg = { + "dataset": { + "name": "SEVIRDataset", + "data_dir": cfg.FILE_PATH, + "input_keys": cfg.MODEL.input_keys, + "label_keys": cfg.DATASET.label_keys, + "data_types": cfg.DATASET.data_types, + "seq_len": cfg.DATASET.seq_len, + "raw_seq_len": cfg.DATASET.raw_seq_len, + "sample_mode": cfg.DATASET.sample_mode, + "stride": cfg.DATASET.stride, + "batch_size": cfg.DATASET.batch_size, + "layout": cfg.DATASET.layout, + "in_len": cfg.DATASET.in_len, + "out_len": cfg.DATASET.out_len, + "split_mode": cfg.DATASET.split_mode, + "start_date": cfg.TRAIN.start_date, + "end_date": cfg.TRAIN.end_date, + "preprocess": cfg.DATASET.preprocess, + "rescale_method": cfg.DATASET.rescale_method, + "shuffle": True, + "verbose": False, + "training": True, + }, + "sampler": { + "name": "BatchSampler", + "drop_last": True, + "shuffle": True, + }, + "batch_size": cfg.TRAIN.batch_size, + "num_workers": 8, + } + + # set constraint + sup_constraint = ppsci.constraint.SupervisedConstraint( + train_dataloader_cfg, + loss=ppsci.loss.FunctionalLoss(sevir_metric.train_mse_func), + name="Sup", + ) + constraint = {sup_constraint.name: sup_constraint} + + # set iters_per_epoch by dataloader length + ITERS_PER_EPOCH = len(sup_constraint.data_loader) + # set eval dataloader config + eval_dataloader_cfg = { + "dataset": { + "name": "SEVIRDataset", + "data_dir": cfg.FILE_PATH, + "input_keys": cfg.MODEL.input_keys, + "label_keys": cfg.DATASET.label_keys, + "data_types": cfg.DATASET.data_types, + "seq_len": cfg.DATASET.seq_len, + "raw_seq_len": cfg.DATASET.raw_seq_len, + "sample_mode": cfg.DATASET.sample_mode, + "stride": cfg.DATASET.stride, + "batch_size": cfg.DATASET.batch_size, + "layout": cfg.DATASET.layout, + "in_len": cfg.DATASET.in_len, + "out_len": cfg.DATASET.out_len, + "split_mode": cfg.DATASET.split_mode, + "start_date": cfg.TRAIN.end_date, + "end_date": cfg.EVAL.end_date, + "preprocess": cfg.DATASET.preprocess, + 
"rescale_method": cfg.DATASET.rescale_method, + "shuffle": False, + "verbose": False, + "training": False, + }, + "batch_size": cfg.EVAL.batch_size, + } + + sup_validator = ppsci.validate.SupervisedValidator( + eval_dataloader_cfg, + loss=ppsci.loss.MSELoss(), + metric={ + "rmse": ppsci.metric.FunctionalMetric( + sevir_metric.eval_rmse_func( + out_len=cfg.DATASET.seq_len, + layout=cfg.DATASET.layout, + metrics_mode=cfg.EVAL.metrics_mode, + metrics_list=cfg.EVAL.metrics_list, + threshold_list=cfg.EVAL.threshold_list, + ) + ), + }, + name="Sup_Validator", + ) + validator = {sup_validator.name: sup_validator} + + model = ppsci.arch.CuboidTransformer( + **cfg.MODEL, + ) + + decay_parameters = get_parameter_names(model, [nn.LayerNorm]) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if n in decay_parameters], + "weight_decay": cfg.TRAIN.wd, + }, + { + "params": [ + p for n, p in model.named_parameters() if n not in decay_parameters + ], + "weight_decay": 0.0, + }, + ] + + # init optimizer and lr scheduler + lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler) + lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine( + **lr_scheduler_cfg, + iters_per_epoch=ITERS_PER_EPOCH, + eta_min=cfg.TRAIN.min_lr_ratio * cfg.TRAIN.lr_scheduler.learning_rate, + warmup_epoch=int(0.2 * cfg.TRAIN.epochs), + )() + optimizer = paddle.optimizer.AdamW( + lr_scheduler, parameters=optimizer_grouped_parameters + ) + + # initialize solver + solver = ppsci.solver.Solver( + model, + constraint, + cfg.output_dir, + optimizer, + lr_scheduler, + cfg.TRAIN.epochs, + ITERS_PER_EPOCH, + eval_during_train=cfg.TRAIN.eval_during_train, + seed=cfg.seed, + validator=validator, + compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + ) + # train model + solver.train() + # evaluate after finished training + metric = sevir_metric.eval_rmse_func( + out_len=cfg.DATASET.seq_len, + layout=cfg.DATASET.layout, + metrics_mode=cfg.EVAL.metrics_mode, + metrics_list=cfg.EVAL.metrics_list, + threshold_list=cfg.EVAL.threshold_list, + ) + + with solver.no_grad_context_manager(True): + for index, (input_, label, _) in enumerate(sup_validator.data_loader): + truefield = label["vil"].squeeze(0) + prefield = model(input_)["vil"].squeeze(0) + metric.sevir_score.update(prefield, truefield) + + metric_dict = metric.sevir_score.compute() + print(metric_dict) + + +def evaluate(cfg: DictConfig): + # set eval dataloader config + eval_dataloader_cfg = { + "dataset": { + "name": "SEVIRDataset", + "data_dir": cfg.FILE_PATH, + "input_keys": cfg.MODEL.input_keys, + "label_keys": cfg.DATASET.label_keys, + "data_types": cfg.DATASET.data_types, + "seq_len": cfg.DATASET.seq_len, + "raw_seq_len": cfg.DATASET.raw_seq_len, + "sample_mode": cfg.DATASET.sample_mode, + "stride": cfg.DATASET.stride, + "batch_size": cfg.DATASET.batch_size, + "layout": cfg.DATASET.layout, + "in_len": cfg.DATASET.in_len, + "out_len": cfg.DATASET.out_len, + "split_mode": cfg.DATASET.split_mode, + "start_date": cfg.TEST.start_date, + "end_date": cfg.TEST.end_date, + "preprocess": cfg.DATASET.preprocess, + "rescale_method": cfg.DATASET.rescale_method, + "shuffle": False, + "verbose": False, + "training": False, + }, + "batch_size": cfg.EVAL.batch_size, + } + + sup_validator = ppsci.validate.SupervisedValidator( + eval_dataloader_cfg, + loss=ppsci.loss.MSELoss(), + metric={ + "rmse": ppsci.metric.FunctionalMetric( + 
sevir_metric.eval_rmse_func( + out_len=cfg.DATASET.seq_len, + layout=cfg.DATASET.layout, + metrics_mode=cfg.EVAL.metrics_mode, + metrics_list=cfg.EVAL.metrics_list, + threshold_list=cfg.EVAL.threshold_list, + ) + ), + }, + name="Sup_Validator", + ) + validator = {sup_validator.name: sup_validator} + + model = ppsci.arch.CuboidTransformer( + **cfg.MODEL, + ) + + # initialize solver + solver = ppsci.solver.Solver( + model, + output_dir=cfg.output_dir, + log_freq=cfg.log_freq, + seed=cfg.seed, + validator=validator, + pretrained_model_path=cfg.EVAL.pretrained_model_path, + compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + ) + # evaluate + metric = sevir_metric.eval_rmse_func( + out_len=cfg.DATASET.seq_len, + layout=cfg.DATASET.layout, + metrics_mode=cfg.EVAL.metrics_mode, + metrics_list=cfg.EVAL.metrics_list, + threshold_list=cfg.EVAL.threshold_list, + ) + + with solver.no_grad_context_manager(True): + for index, (input_, label, _) in enumerate(sup_validator.data_loader): + truefield = label["vil"].reshape([-1, *label["vil"].shape[2:]]) + prefield = model(input_)["vil"].reshape([-1, *label["vil"].shape[2:]]) + metric.sevir_score.update(prefield, truefield) + + metric_dict = metric.sevir_score.compute() + print(metric_dict) + + +def export(cfg: DictConfig): + # set model + model = ppsci.arch.CuboidTransformer( + **cfg.MODEL, + ) + + # initialize solver + solver = ppsci.solver.Solver( + model, + pretrained_model_path=cfg.INFER.pretrained_model_path, + ) + # export model + from paddle.static import InputSpec + + input_spec = [ + { + key: InputSpec([1, 13, 384, 384, 1], "float32", name=key) + for key in model.input_keys + }, + ] + solver.export(input_spec, cfg.INFER.export_path) + + +def inference(cfg: DictConfig): + import predictor + + from ppsci.data.dataset import sevir_dataset + + predictor = predictor.EarthformerPredictor(cfg) + + if cfg.INFER.rescale_method == "sevir": + scale_dict = sevir_dataset.PREPROCESS_SCALE_SEVIR + offset_dict = sevir_dataset.PREPROCESS_OFFSET_SEVIR + elif cfg.INFER.rescale_method == "01": + scale_dict = sevir_dataset.PREPROCESS_SCALE_01 + offset_dict = sevir_dataset.PREPROCESS_OFFSET_01 + else: + raise ValueError(f"Invalid rescale option: {cfg.INFER.rescale_method}.") + + # read h5 data + h5data = h5py.File(cfg.INFER.data_path, "r") + data = np.array(h5data[cfg.INFER.data_type]).transpose([0, 3, 1, 2]) + + idx = np.random.choice(len(data), None, False) + data = ( + scale_dict[cfg.INFER.data_type] * data[idx] + offset_dict[cfg.INFER.data_type] + ) + + input_data = data[: cfg.INFER.in_len, ...] + input_data = input_data.reshape(1, *input_data.shape, 1).astype(np.float32) + target_data = data[cfg.INFER.in_len : cfg.INFER.in_len + cfg.INFER.out_len, ...] 
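+    # input_data was reshaped above to (1, in_len, H, W, 1) = (1, 13, 384, 384, 1),
+    # matching the InputSpec used in export(); target_data gets the same layout below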
+    target_data = target_data.reshape(1, *target_data.shape, 1).astype(np.float32)
+
+    pred_data = predictor.predict(input_data, cfg.INFER.batch_size)
+
+    sevir_vis_seq.save_example_vis_results(
+        save_dir=cfg.INFER.sevir_vis_save,
+        save_prefix=f"data_{idx}",
+        in_seq=input_data,
+        target_seq=target_data,
+        pred_seq=pred_data,
+        layout=cfg.INFER.layout,
+        plot_stride=cfg.INFER.plot_stride,
+        label=cfg.INFER.logging_prefix,
+        interval_real_time=cfg.INFER.interval_real_time,
+    )
+
+
+@hydra.main(
+    version_base=None,
+    config_path="./conf",
+    config_name="earthformer_sevir_pretrain.yaml",
+)
+def main(cfg: DictConfig):
+    if cfg.mode == "train":
+        train(cfg)
+    elif cfg.mode == "eval":
+        evaluate(cfg)
+    elif cfg.mode == "export":
+        export(cfg)
+    elif cfg.mode == "infer":
+        inference(cfg)
+    else:
+        raise ValueError(
+            f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/earthformer/enso_metric.py b/examples/earthformer/enso_metric.py
new file mode 100644
index 000000000..e1e1d7cf3
--- /dev/null
+++ b/examples/earthformer/enso_metric.py
@@ -0,0 +1,126 @@
+from typing import Dict
+from typing import Optional
+from typing import Union
+
+import numpy as np
+import paddle
+from paddle.nn import functional as F
+
+from ppsci.data.dataset.enso_dataset import NINO_WINDOW_T
+from ppsci.data.dataset.enso_dataset import scale_back_sst
+
+
+def compute_enso_score(
+    y_pred: paddle.Tensor,
+    y_true: paddle.Tensor,
+    acc_weight: Optional[Union[str, np.ndarray, paddle.Tensor]] = None,
+):
+    """Compute the accuracy (correlation skill) and root mean squared error (RMSE) on the ENSO dataset.
+
+    Args:
+        y_pred (paddle.Tensor): The predicted data.
+        y_true (paddle.Tensor): The label data.
+        acc_weight (Optional[Union[str, np.ndarray, paddle.Tensor]], optional): The weight for the accuracy term.
+            Defaults to None (no weighting). Pass "default" to use the weights specified at
+            https://tianchi.aliyun.com/competition/entrance/531871/information.
+    """
+
+    pred = y_pred - y_pred.mean(axis=0, keepdim=True)  # (N, 24)
+    true = y_true - y_true.mean(axis=0, keepdim=True)  # (N, 24)
+    cor = (pred * true).sum(axis=0) / (
+        paddle.sqrt(paddle.sum(pred**2, axis=0) * paddle.sum(true**2, axis=0))
+        + 1e-6
+    )
+
+    if acc_weight is None:
+        acc = cor.sum()
+    else:
+        nino_out_len = y_true.shape[-1]
+        if acc_weight == "default":
+            acc_weight = paddle.to_tensor(
+                [1.5] * 4 + [2] * 7 + [3] * 7 + [4] * (nino_out_len - 18)
+            )[:nino_out_len] * paddle.log(paddle.arange(nino_out_len) + 1)
+        elif isinstance(acc_weight, np.ndarray):
+            acc_weight = paddle.to_tensor(acc_weight[:nino_out_len])
+        elif isinstance(acc_weight, paddle.Tensor):
+            acc_weight = acc_weight[:nino_out_len]
+        else:
+            raise ValueError(f"Invalid acc_weight {acc_weight}!")
+        acc_weight = acc_weight.to(y_pred)
+        acc = (acc_weight * cor).sum()
+    rmse = paddle.mean((y_pred - y_true) ** 2, axis=0).sqrt().sum()
+    return acc, rmse
+
+
+def sst_to_nino(sst: paddle.Tensor, normalize_sst: bool = True, detach: bool = True):
+    """Convert SST to the nino index.
+
+    Args:
+        sst (paddle.Tensor): The predicted SST data. Shape = (N, T, H, W)
+        normalize_sst (bool, optional): Whether the SST data is normalized (and should be scaled back). Defaults to True.
+        detach (bool, optional): Whether to detach the tensor. Defaults to True.
+
+    Returns:
+        nino_index (paddle.Tensor): The nino index. Shape = (N, T-NINO_WINDOW_T+1)
+    """
+
+    if detach:
+        nino_index = sst.detach()
+    else:
+        nino_index = sst
+    if normalize_sst:
+        nino_index = scale_back_sst(nino_index)
+    nino_index = nino_index[:, :, 10:13, 19:30].mean(axis=[2, 3])  # (N, 26)
+    nino_index = nino_index.unfold(axis=1, size=NINO_WINDOW_T, step=1).mean(
+        axis=2
+    )  # (N, 24)
+
+    return nino_index
+
+
+def train_mse_func(
+    output_dict: Dict[str, "paddle.Tensor"],
+    label_dict: Dict[str, "paddle.Tensor"],
+    *args,
+) -> paddle.Tensor:
+    return F.mse_loss(output_dict["sst_target"], label_dict["sst_target"])
+
+
+def eval_rmse_func(
+    output_dict: Dict[str, "paddle.Tensor"],
+    label_dict: Dict[str, "paddle.Tensor"],
+    nino_out_len: int = 12,
+    *args,
+) -> Dict[str, paddle.Tensor]:
+    pred = output_dict["sst_target"]
+    sst_target = label_dict["sst_target"]
+    nino_target = label_dict["nino_target"].astype("float32")
+    # mae
+    mae = F.l1_loss(pred, sst_target)
+    # mse
+    mse = F.mse_loss(pred, sst_target)
+    # nino index and correlation/rmse scores
+    nino_preds = sst_to_nino(sst=pred[..., 0])
+    nino_preds_list, nino_target_list = map(list, zip((nino_preds, nino_target)))
+    nino_preds_list = paddle.concat(nino_preds_list, axis=0)
+    nino_target_list = paddle.concat(nino_target_list, axis=0)
+
+    valid_acc, valid_nino_rmse = compute_enso_score(
+        y_pred=nino_preds_list, y_true=nino_target_list, acc_weight=None
+    )
+    valid_weighted_acc, _ = compute_enso_score(
+        y_pred=nino_preds_list, y_true=nino_target_list, acc_weight="default"
+    )
+    valid_acc /= nino_out_len
+    valid_nino_rmse /= nino_out_len
+    valid_weighted_acc /= nino_out_len
+    valid_loss = -valid_acc
+
+    return {
+        "valid_loss_epoch": valid_loss,
+        "mse": mse,
+        "mae": mae,
+        "rmse": valid_nino_rmse,
+        "corr_nino3.4_epoch": valid_acc,
+        "corr_nino3.4_weighted_epoch": valid_weighted_acc,
+    }
diff --git a/examples/earthformer/predictor.py b/examples/earthformer/predictor.py
new file mode 100644
index 000000000..d553355b6
--- /dev/null
+++ b/examples/earthformer/predictor.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+from omegaconf import DictConfig
+
+from deploy.python_infer import base
+
+
+class EarthformerPredictor(base.Predictor):
+    """General predictor for the Earthformer model.
+
+    Args:
+        cfg (DictConfig): Running configuration.
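+
+    Examples:
+        >>> # Illustrative usage only (not part of the original example); `cfg` is
+        >>> # assumed to be the Hydra config loaded from one of the
+        >>> # earthformer_*_pretrain.yaml files, and `in_seq` an array shaped like
+        >>> # the exported InputSpec.
+        >>> predictor = EarthformerPredictor(cfg)  # doctest: +SKIP
+        >>> pred_seq = predictor.predict(in_seq, batch_size=1)  # doctest: +SKIP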
+ """ + + def __init__( + self, + cfg: DictConfig, + ): + super().__init__( + cfg.INFER.pdmodel_path, + cfg.INFER.pdpiparams_path, + device=cfg.INFER.device, + engine=cfg.INFER.engine, + precision=cfg.INFER.precision, + onnx_path=cfg.INFER.onnx_path, + ir_optim=cfg.INFER.ir_optim, + min_subgraph_size=cfg.INFER.min_subgraph_size, + gpu_mem=cfg.INFER.gpu_mem, + gpu_id=cfg.INFER.gpu_id, + max_batch_size=cfg.INFER.max_batch_size, + num_cpu_threads=cfg.INFER.num_cpu_threads, + ) + self.log_freq = cfg.log_freq + + # get input names and data handles + self.input_names = self.predictor.get_input_names() + self.input_data_handle = self.predictor.get_input_handle(self.input_names[0]) + + # get output names and data handles + self.output_names = self.predictor.get_output_names() + self.output_handle = self.predictor.get_output_handle(self.output_names[0]) + + def predict( + self, + input_data: np.ndarray, + batch_size: int = 1, + ) -> np.ndarray: + """Predicts the output of the yinglong model for the given input. + + Args: + input_data (np.ndarray): Input data of shape (N, T, H, W). + batch_size (int, optional): Batch size, now only support 1. Defaults to 1. + Returns: + np.ndarray: Prediction. + """ + if batch_size != 1: + raise ValueError( + f"EarthformerPredictor only support batch_size=1, but got {batch_size}" + ) + # prepare input handle(s) + input_handles = {self.input_names[0]: self.input_data_handle} + # prepare output handle(s) + output_handles = {self.output_names[0]: self.output_handle} + + # prepare batch input dict + batch_input_dict = { + self.input_names[0]: input_data, + } + # send batch input data to input handle(s) + for name, handle in input_handles.items(): + handle.copy_from_cpu(batch_input_dict[name]) + + # run predictor + self.predictor.run() + + # receive batch output data from output handle(s) + pred = output_handles[self.output_names[0]].copy_to_cpu() + + return pred diff --git a/examples/earthformer/sevir_cmap.py b/examples/earthformer/sevir_cmap.py new file mode 100644 index 000000000..27f8ac903 --- /dev/null +++ b/examples/earthformer/sevir_cmap.py @@ -0,0 +1,334 @@ +"""Code is adapted from https://github.com/MIT-AI-Accelerator/neurips-2020-sevir. 
Their license is MIT License.""" + +from copy import deepcopy + +import numpy as np +from matplotlib.colors import BoundaryNorm +from matplotlib.colors import ListedColormap + +VIL_COLORS = [ + [0, 0, 0], + [0.30196078431372547, 0.30196078431372547, 0.30196078431372547], + [0.1568627450980392, 0.7450980392156863, 0.1568627450980392], + [0.09803921568627451, 0.5882352941176471, 0.09803921568627451], + [0.0392156862745098, 0.4117647058823529, 0.0392156862745098], + [0.0392156862745098, 0.29411764705882354, 0.0392156862745098], + [0.9607843137254902, 0.9607843137254902, 0.0], + [0.9294117647058824, 0.6745098039215687, 0.0], + [0.9411764705882353, 0.43137254901960786, 0.0], + [0.6274509803921569, 0.0, 0.0], + [0.9058823529411765, 0.0, 1.0], +] + +VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0] + + +def get_cmap(type, encoded=True): + if type.lower() == "vis": + cmap, norm = vis_cmap(encoded) + vmin, vmax = (0, 10000) if encoded else (0, 1) + elif type.lower() == "vil": + cmap, norm = vil_cmap(encoded) + vmin, vmax = None, None + elif type.lower() == "ir069": + cmap, norm = c09_cmap(encoded) + vmin, vmax = (-8000, -1000) if encoded else (-80, -10) + elif type.lower() == "lght": + cmap, norm = "hot", None + vmin, vmax = 0, 5 + else: + cmap, norm = "jet", None + vmin, vmax = (-7000, 2000) if encoded else (-70, 20) + return cmap, norm, vmin, vmax + + +def vil_cmap(encoded=True): + cols = deepcopy(VIL_COLORS) + lev = deepcopy(VIL_LEVELS) + # Exactly the same error occurs in the original implementation (https://github.com/MIT-AI-Accelerator/neurips-2020-sevir/blob/master/src/display/display.py). + # ValueError: There are 10 color bins including extensions, but ncolors = 9; ncolors must equal or exceed the number of bins + # We can not replicate the visualization in notebook (https://github.com/MIT-AI-Accelerator/neurips-2020-sevir/blob/master/notebooks/AnalyzeNowcast.ipynb) without error. 
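+    # Workaround used here: the last color is kept in `cols` (`over = cols[-1]`
+    # below instead of `cols.pop()`), so the ListedColormap keeps 10 colors for
+    # the 10 bins defined by the 11 boundaries in VIL_LEVELS.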
+ nil = cols.pop(0) + under = cols[0] + # over = cols.pop() + over = cols[-1] + cmap = ListedColormap(cols) + cmap.set_bad(nil) + cmap.set_under(under) + cmap.set_over(over) + norm = BoundaryNorm(lev, cmap.N) + return cmap, norm + + +def vis_cmap(encoded=True): + cols = [ + [0, 0, 0], + [0.0392156862745098, 0.0392156862745098, 0.0392156862745098], + [0.0784313725490196, 0.0784313725490196, 0.0784313725490196], + [0.11764705882352941, 0.11764705882352941, 0.11764705882352941], + [0.1568627450980392, 0.1568627450980392, 0.1568627450980392], + [0.19607843137254902, 0.19607843137254902, 0.19607843137254902], + [0.23529411764705882, 0.23529411764705882, 0.23529411764705882], + [0.27450980392156865, 0.27450980392156865, 0.27450980392156865], + [0.3137254901960784, 0.3137254901960784, 0.3137254901960784], + [0.35294117647058826, 0.35294117647058826, 0.35294117647058826], + [0.39215686274509803, 0.39215686274509803, 0.39215686274509803], + [0.43137254901960786, 0.43137254901960786, 0.43137254901960786], + [0.47058823529411764, 0.47058823529411764, 0.47058823529411764], + [0.5098039215686274, 0.5098039215686274, 0.5098039215686274], + [0.5490196078431373, 0.5490196078431373, 0.5490196078431373], + [0.5882352941176471, 0.5882352941176471, 0.5882352941176471], + [0.6274509803921569, 0.6274509803921569, 0.6274509803921569], + [0.6666666666666666, 0.6666666666666666, 0.6666666666666666], + [0.7058823529411765, 0.7058823529411765, 0.7058823529411765], + [0.7450980392156863, 0.7450980392156863, 0.7450980392156863], + [0.7843137254901961, 0.7843137254901961, 0.7843137254901961], + [0.8235294117647058, 0.8235294117647058, 0.8235294117647058], + [0.8627450980392157, 0.8627450980392157, 0.8627450980392157], + [0.9019607843137255, 0.9019607843137255, 0.9019607843137255], + [0.9411764705882353, 0.9411764705882353, 0.9411764705882353], + [0.9803921568627451, 0.9803921568627451, 0.9803921568627451], + [0.9803921568627451, 0.9803921568627451, 0.9803921568627451], + ] + lev = np.array( + [ + 0.0, + 0.02, + 0.04, + 0.06, + 0.08, + 0.1, + 0.12, + 0.14, + 0.16, + 0.2, + 0.24, + 0.28, + 0.32, + 0.36, + 0.4, + 0.44, + 0.48, + 0.52, + 0.56, + 0.6, + 0.64, + 0.68, + 0.72, + 0.76, + 0.8, + 0.9, + 1.0, + ] + ) + if encoded: + lev *= 1e4 + nil = cols.pop(0) + under = cols[0] + over = cols.pop() + cmap = ListedColormap(cols) + cmap.set_bad(nil) + cmap.set_under(under) + cmap.set_over(over) + norm = BoundaryNorm(lev, cmap.N) + return cmap, norm + + +def ir_cmap(encoded=True): + cols = [ + [0, 0, 0], + [1.0, 1.0, 1.0], + [0.9803921568627451, 0.9803921568627451, 0.9803921568627451], + [0.9411764705882353, 0.9411764705882353, 0.9411764705882353], + [0.9019607843137255, 0.9019607843137255, 0.9019607843137255], + [0.8627450980392157, 0.8627450980392157, 0.8627450980392157], + [0.8235294117647058, 0.8235294117647058, 0.8235294117647058], + [0.7843137254901961, 0.7843137254901961, 0.7843137254901961], + [0.7450980392156863, 0.7450980392156863, 0.7450980392156863], + [0.7058823529411765, 0.7058823529411765, 0.7058823529411765], + [0.6666666666666666, 0.6666666666666666, 0.6666666666666666], + [0.6274509803921569, 0.6274509803921569, 0.6274509803921569], + [0.5882352941176471, 0.5882352941176471, 0.5882352941176471], + [0.5490196078431373, 0.5490196078431373, 0.5490196078431373], + [0.5098039215686274, 0.5098039215686274, 0.5098039215686274], + [0.47058823529411764, 0.47058823529411764, 0.47058823529411764], + [0.43137254901960786, 0.43137254901960786, 0.43137254901960786], + [0.39215686274509803, 0.39215686274509803, 
0.39215686274509803], + [0.35294117647058826, 0.35294117647058826, 0.35294117647058826], + [0.3137254901960784, 0.3137254901960784, 0.3137254901960784], + [0.27450980392156865, 0.27450980392156865, 0.27450980392156865], + [0.23529411764705882, 0.23529411764705882, 0.23529411764705882], + [0.19607843137254902, 0.19607843137254902, 0.19607843137254902], + [0.1568627450980392, 0.1568627450980392, 0.1568627450980392], + [0.11764705882352941, 0.11764705882352941, 0.11764705882352941], + [0.0784313725490196, 0.0784313725490196, 0.0784313725490196], + [0.0392156862745098, 0.0392156862745098, 0.0392156862745098], + [0.0, 0.803921568627451, 0.803921568627451], + ] + lev = np.array( + [ + -110.0, + -105.2, + -95.2, + -85.2, + -75.2, + -65.2, + -55.2, + -45.2, + -35.2, + -28.2, + -23.2, + -18.2, + -13.2, + -8.2, + -3.2, + 1.8, + 6.8, + 11.8, + 16.8, + 21.8, + 26.8, + 31.8, + 36.8, + 41.8, + 46.8, + 51.8, + 90.0, + 100.0, + ] + ) + if encoded: + lev *= 1e2 + nil = cols.pop(0) + under = cols[0] + over = cols.pop() + cmap = ListedColormap(cols) + cmap.set_bad(nil) + cmap.set_under(under) + cmap.set_over(over) + norm = BoundaryNorm(lev, cmap.N) + return cmap, norm + + +def c09_cmap(encoded=True): + cols = [ + [1.000000, 0.000000, 0.000000], + [1.000000, 0.031373, 0.000000], + [1.000000, 0.062745, 0.000000], + [1.000000, 0.094118, 0.000000], + [1.000000, 0.125490, 0.000000], + [1.000000, 0.156863, 0.000000], + [1.000000, 0.188235, 0.000000], + [1.000000, 0.219608, 0.000000], + [1.000000, 0.250980, 0.000000], + [1.000000, 0.282353, 0.000000], + [1.000000, 0.313725, 0.000000], + [1.000000, 0.349020, 0.003922], + [1.000000, 0.380392, 0.003922], + [1.000000, 0.411765, 0.003922], + [1.000000, 0.443137, 0.003922], + [1.000000, 0.474510, 0.003922], + [1.000000, 0.505882, 0.003922], + [1.000000, 0.537255, 0.003922], + [1.000000, 0.568627, 0.003922], + [1.000000, 0.600000, 0.003922], + [1.000000, 0.631373, 0.003922], + [1.000000, 0.666667, 0.007843], + [1.000000, 0.698039, 0.007843], + [1.000000, 0.729412, 0.007843], + [1.000000, 0.760784, 0.007843], + [1.000000, 0.792157, 0.007843], + [1.000000, 0.823529, 0.007843], + [1.000000, 0.854902, 0.007843], + [1.000000, 0.886275, 0.007843], + [1.000000, 0.917647, 0.007843], + [1.000000, 0.949020, 0.007843], + [1.000000, 0.984314, 0.011765], + [0.968627, 0.952941, 0.031373], + [0.937255, 0.921569, 0.050980], + [0.901961, 0.886275, 0.074510], + [0.870588, 0.854902, 0.094118], + [0.835294, 0.823529, 0.117647], + [0.803922, 0.788235, 0.137255], + [0.772549, 0.756863, 0.160784], + [0.737255, 0.725490, 0.180392], + [0.705882, 0.690196, 0.200000], + [0.670588, 0.658824, 0.223529], + [0.639216, 0.623529, 0.243137], + [0.607843, 0.592157, 0.266667], + [0.572549, 0.560784, 0.286275], + [0.541176, 0.525490, 0.309804], + [0.509804, 0.494118, 0.329412], + [0.474510, 0.462745, 0.349020], + [0.752941, 0.749020, 0.909804], + [0.800000, 0.800000, 0.929412], + [0.850980, 0.847059, 0.945098], + [0.898039, 0.898039, 0.964706], + [0.949020, 0.949020, 0.980392], + [1.000000, 1.000000, 1.000000], + [0.964706, 0.980392, 0.964706], + [0.929412, 0.960784, 0.929412], + [0.890196, 0.937255, 0.890196], + [0.854902, 0.917647, 0.854902], + [0.815686, 0.894118, 0.815686], + [0.780392, 0.874510, 0.780392], + [0.745098, 0.850980, 0.745098], + [0.705882, 0.831373, 0.705882], + [0.670588, 0.807843, 0.670588], + [0.631373, 0.788235, 0.631373], + [0.596078, 0.764706, 0.596078], + [0.560784, 0.745098, 0.560784], + [0.521569, 0.721569, 0.521569], + [0.486275, 0.701961, 0.486275], + [0.447059, 0.678431, 
0.447059], + [0.411765, 0.658824, 0.411765], + [0.376471, 0.635294, 0.376471], + [0.337255, 0.615686, 0.337255], + [0.301961, 0.592157, 0.301961], + [0.262745, 0.572549, 0.262745], + [0.227451, 0.549020, 0.227451], + [0.192157, 0.529412, 0.192157], + [0.152941, 0.505882, 0.152941], + [0.117647, 0.486275, 0.117647], + [0.078431, 0.462745, 0.078431], + [0.043137, 0.443137, 0.043137], + [0.003922, 0.419608, 0.003922], + [0.003922, 0.431373, 0.027451], + [0.003922, 0.447059, 0.054902], + [0.003922, 0.462745, 0.082353], + [0.003922, 0.478431, 0.109804], + [0.003922, 0.494118, 0.137255], + [0.003922, 0.509804, 0.164706], + [0.003922, 0.525490, 0.192157], + [0.003922, 0.541176, 0.215686], + [0.003922, 0.556863, 0.243137], + [0.007843, 0.568627, 0.270588], + [0.007843, 0.584314, 0.298039], + [0.007843, 0.600000, 0.325490], + [0.007843, 0.615686, 0.352941], + [0.007843, 0.631373, 0.380392], + [0.007843, 0.647059, 0.403922], + [0.007843, 0.662745, 0.431373], + [0.007843, 0.678431, 0.458824], + [0.007843, 0.694118, 0.486275], + [0.011765, 0.705882, 0.513725], + [0.011765, 0.721569, 0.541176], + [0.011765, 0.737255, 0.568627], + [0.011765, 0.752941, 0.596078], + [0.011765, 0.768627, 0.619608], + [0.011765, 0.784314, 0.647059], + [0.011765, 0.800000, 0.674510], + [0.011765, 0.815686, 0.701961], + [0.011765, 0.831373, 0.729412], + [0.015686, 0.843137, 0.756863], + [0.015686, 0.858824, 0.784314], + [0.015686, 0.874510, 0.807843], + [0.015686, 0.890196, 0.835294], + [0.015686, 0.905882, 0.862745], + [0.015686, 0.921569, 0.890196], + [0.015686, 0.937255, 0.917647], + [0.015686, 0.952941, 0.945098], + [0.015686, 0.968627, 0.972549], + [1.000000, 1.000000, 1.000000], + ] + return ListedColormap(cols), None diff --git a/examples/earthformer/sevir_metric.py b/examples/earthformer/sevir_metric.py new file mode 100644 index 000000000..c1e0c945f --- /dev/null +++ b/examples/earthformer/sevir_metric.py @@ -0,0 +1,281 @@ +from typing import Dict +from typing import Optional +from typing import Sequence + +import numpy as np +import paddle +from paddle.nn import functional as F + +from ppsci.data.dataset import sevir_dataset + + +def _threshold(target, pred, T): + """ + Returns binary tensors t,p the same shape as target & pred. t = 1 wherever + target > t. p =1 wherever pred > t. p and t are set to 0 wherever EITHER + t or p are nan. + This is useful for counts that don't involve correct rejections. + + Args: + target (paddle.Tensor): label + pred (paddle.Tensor): predict + T (numeric_type): threshold + Returns: + t + p + """ + + t = (target >= T).astype("float32") + p = (pred >= T).astype("float32") + is_nan = paddle.logical_or(paddle.isnan(target), paddle.isnan(pred)) + t[is_nan] = 0 + p[is_nan] = 0 + return t, p + + +class SEVIRSkillScore: + r""" + The calculation of skill scores in SEVIR challenge is slightly different: + `mCSI = sum(mCSI_t) / T` + See https://github.com/MIT-AI-Accelerator/sevir_challenges/blob/dev/radar_nowcasting/RadarNowcastBenchmarks.ipynb for more details. 
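+
+    Scores are accumulated as hits / misses / false alarms (fas) per threshold and
+    then combined, e.g. POD = hits / (hits + misses), SUCR = hits / (hits + fas),
+    CSI = hits / (hits + misses + fas); a small `eps` is added to every denominator
+    for numerical stability (see the static methods below).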
+ + Args: + seq_len (int): sequence length + layout (str): layout mode + mode (str): Should be in ("0", "1", "2") + "0": + cumulates hits/misses/fas of all test pixels + score_avg takes average over all thresholds + return + score_thresh shape = (1, ) + score_avg shape = (1, ) + "1": + cumulates hits/misses/fas of each step + score_avg takes average over all thresholds while keeps the seq_len dim + return + score_thresh shape = (seq_len, ) + score_avg shape = (seq_len, ) + "2": + cumulates hits/misses/fas of each step + score_avg takes average over all thresholds, then takes average over the seq_len dim + return + score_thresh shape = (1, ) + score_avg shape = (1, ) + preprocess_type (str): prepprocess type + threshold_list (Sequence[int]): threshold list + """ + + full_state_update: bool = True + + def __init__( + self, + layout: str = "NHWT", + mode: str = "0", + seq_len: Optional[int] = None, + preprocess_type: str = "sevir", + threshold_list: Sequence[int] = (16, 74, 133, 160, 181, 219), + metrics_list: Sequence[str] = ("csi", "bias", "sucr", "pod"), + eps: float = 1e-4, + dist_sync_on_step: bool = False, + ): + super().__init__() + self.layout = layout + self.preprocess_type = preprocess_type + self.threshold_list = threshold_list + self.metrics_list = metrics_list + self.eps = eps + self.mode = mode + self.seq_len = seq_len + + self.hits = paddle.zeros(shape=[len(self.threshold_list)]) + self.misses = paddle.zeros(shape=[len(self.threshold_list)]) + self.fas = paddle.zeros(shape=[len(self.threshold_list)]) + + if mode in ("0",): + self.keep_seq_len_dim = False + elif mode in ("1", "2"): + self.keep_seq_len_dim = True + assert isinstance( + self.seq_len, int + ), "seq_len must be provided when we need to keep seq_len dim." + + else: + raise NotImplementedError(f"mode {mode} not supported!") + + @staticmethod + def pod(hits, misses, fas, eps): + return hits / (hits + misses + eps) + + @staticmethod + def sucr(hits, misses, fas, eps): + return hits / (hits + fas + eps) + + @staticmethod + def csi(hits, misses, fas, eps): + return hits / (hits + misses + fas + eps) + + @staticmethod + def bias(hits, misses, fas, eps): + bias = (hits + fas) / (hits + misses + eps) + logbias = paddle.pow(bias / paddle.log(paddle.to_tensor(2.0)), 2.0) + return logbias + + @property + def hits_misses_fas_reduce_dims(self): + if not hasattr(self, "_hits_misses_fas_reduce_dims"): + seq_dim = self.layout.find("T") + self._hits_misses_fas_reduce_dims = list(range(len(self.layout))) + if self.keep_seq_len_dim: + self._hits_misses_fas_reduce_dims.pop(seq_dim) + return self._hits_misses_fas_reduce_dims + + def calc_seq_hits_misses_fas(self, pred, target, threshold): + """ + Args: + pred (paddle.Tensor): Predict data. + target (paddle.Tensor): True data. + threshold (int): The threshold to calculate hits, misses and fas. + + Returns: + hits (paddle.Tensor): Number of hits. + misses (paddle.Tensor): Number of misses. + fas (paddle.Tensor): Number of false positives. 
+ each has shape (seq_len, ) + """ + + with paddle.no_grad(): + t, p = _threshold(target, pred, threshold) + hits = paddle.sum(t * p, axis=self.hits_misses_fas_reduce_dims).astype( + "int32" + ) + misses = paddle.sum( + t * (1 - p), axis=self.hits_misses_fas_reduce_dims + ).astype("int32") + fas = paddle.sum((1 - t) * p, axis=self.hits_misses_fas_reduce_dims).astype( + "int32" + ) + return hits, misses, fas + + def preprocess(self, pred, target): + if self.preprocess_type == "sevir": + pred = sevir_dataset.SEVIRDataset.process_data_dict_back( + data_dict={"vil": pred.detach().astype("float32")} + )["vil"] + target = sevir_dataset.SEVIRDataset.process_data_dict_back( + data_dict={"vil": target.detach().astype("float32")} + )["vil"] + else: + raise NotImplementedError(f"{self.preprocess_type} not supported") + return pred, target + + def update(self, pred: paddle.Tensor, target: paddle.Tensor): + pred, target = self.preprocess(pred, target) + for i, threshold in enumerate(self.threshold_list): + hits, misses, fas = self.calc_seq_hits_misses_fas(pred, target, threshold) + self.hits[i] += hits + self.misses[i] += misses + self.fas[i] += fas + + def compute(self, pred: paddle.Tensor, target: paddle.Tensor): + metrics_dict = { + "pod": self.pod, + "csi": self.csi, + "sucr": self.sucr, + "bias": self.bias, + } + ret = {} + for threshold in self.threshold_list: + ret[threshold] = {} + ret["avg"] = {} + for metrics in self.metrics_list: + if self.keep_seq_len_dim: + score_avg = np.zeros((self.seq_len,)) + else: + score_avg = 0 + # shape = (len(threshold_list), seq_len) if self.keep_seq_len_dim, + # else shape = (len(threshold_list),) + scores = metrics_dict[metrics](self.hits, self.misses, self.fas, self.eps) + scores = scores.detach().cpu().numpy() + for i, threshold in enumerate(self.threshold_list): + if self.keep_seq_len_dim: + score = scores[i] # shape = (seq_len, ) + else: + score = scores[i].item() # shape = (1, ) + if self.mode in ("0", "1"): + ret[threshold][metrics] = score + elif self.mode in ("2",): + ret[threshold][metrics] = np.mean(score).item() + else: + raise NotImplementedError(f"{self.mode} is invalid.") + score_avg += score + score_avg /= len(self.threshold_list) + if self.mode in ("0", "1"): + ret["avg"][metrics] = score_avg + elif self.mode in ("2",): + ret["avg"][metrics] = np.mean(score_avg).item() + else: + raise NotImplementedError(f"{self.mode} is invalid.") + + metrics = {} + metrics["csi_avg_loss"] = 0 + for metric in self.metrics_list: + for th in self.threshold_list: + metrics[f"{metric}_{th}"] = ret[th][metric] + metrics[f"{metric}_avg"] = ret["avg"][metric] + + metrics["csi_avg_loss"] = -metrics["csi_avg"] + return metrics + + +class eval_rmse_func: + def __init__( + self, + out_len=12, + layout="NTHWC", + metrics_mode="0", + metrics_list=["csi", "pod", "sucr", "bias"], + threshold_list=[16, 74, 133, 160, 181, 219], + *args, + ) -> Dict[str, paddle.Tensor]: + super().__init__() + self.out_len = out_len + self.layout = layout + self.metrics_mode = metrics_mode + self.metrics_list = metrics_list + self.threshold_list = threshold_list + + self.sevir_score = SEVIRSkillScore( + layout=self.layout, + mode=self.metrics_mode, + seq_len=self.out_len, + threshold_list=self.threshold_list, + metrics_list=self.metrics_list, + ) + + def __call__( + self, + output_dict: Dict[str, "paddle.Tensor"], + label_dict: Dict[str, "paddle.Tensor"], + ): + pred = output_dict["vil"] + vil_target = label_dict["vil"] + vil_target = vil_target.reshape([-1, *vil_target.shape[2:]]) + # mse + mae 
= F.l1_loss(pred, vil_target, "none") + mae = mae.mean(axis=tuple(range(1, mae.ndim))) + # mse + mse = F.mse_loss(pred, vil_target, "none") + mse = mse.mean(axis=tuple(range(1, mse.ndim))) + + return {"mse": mse, "mae": mae} + + +def train_mse_func( + output_dict: Dict[str, "paddle.Tensor"], + label_dict: Dict[str, "paddle.Tensor"], + *args, +) -> paddle.Tensor: + pred = output_dict["vil"] + vil_target = label_dict["vil"] + target = vil_target.reshape([-1, *vil_target.shape[2:]]) + return F.mse_loss(pred, target) diff --git a/examples/earthformer/sevir_vis_seq.py b/examples/earthformer/sevir_vis_seq.py new file mode 100644 index 000000000..3bacfb747 --- /dev/null +++ b/examples/earthformer/sevir_vis_seq.py @@ -0,0 +1,247 @@ +import os +from typing import List + +import numpy as np +import sevir_cmap +from matplotlib import pyplot as plt +from matplotlib.colors import ListedColormap +from matplotlib.patches import Patch + +from ppsci.data.dataset import sevir_dataset + +HMF_COLORS = ( + np.array([[82, 82, 82], [252, 141, 89], [255, 255, 191], [145, 191, 219]]) / 255 +) + +THRESHOLDS = (0, 16, 74, 133, 160, 181, 219, 255) + + +def plot_hit_miss_fa(ax, y_true, y_pred, thres): + mask = np.zeros_like(y_true) + mask[np.logical_and(y_true >= thres, y_pred >= thres)] = 4 + mask[np.logical_and(y_true >= thres, y_pred < thres)] = 3 + mask[np.logical_and(y_true < thres, y_pred >= thres)] = 2 + mask[np.logical_and(y_true < thres, y_pred < thres)] = 1 + cmap = ListedColormap(HMF_COLORS) + ax.imshow(mask, cmap=cmap) + + +def plot_hit_miss_fa_all_thresholds(ax, y_true, y_pred, **unused_kwargs): + fig = np.zeros(y_true.shape) + y_true_idx = np.searchsorted(THRESHOLDS, y_true) + y_pred_idx = np.searchsorted(THRESHOLDS, y_pred) + fig[y_true_idx == y_pred_idx] = 4 + fig[y_true_idx > y_pred_idx] = 3 + fig[y_true_idx < y_pred_idx] = 2 + # do not count results in these not challenging areas. 
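+    # (pixels where both truth and prediction fall below the lowest threshold,
+    # THRESHOLDS[1], are assigned class 1 and excluded from hit/miss/false-alarm
+    # coloring)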
+ fig[np.logical_and(y_true < THRESHOLDS[1], y_pred < THRESHOLDS[1])] = 1 + cmap = ListedColormap(HMF_COLORS) + ax.imshow(fig, cmap=cmap) + + +def get_cmap_dict(s): + return { + "cmap": sevir_cmap.get_cmap(s, encoded=True)[0], + "norm": sevir_cmap.get_cmap(s, encoded=True)[1], + "vmin": sevir_cmap.get_cmap(s, encoded=True)[2], + "vmax": sevir_cmap.get_cmap(s, encoded=True)[3], + } + + +def visualize_result( + in_seq: np.array, + target_seq: np.array, + pred_seq_list: List[np.array], + label_list: List[str], + interval_real_time: float = 10.0, + idx=0, + norm=None, + plot_stride=2, + figsize=(24, 8), + fs=10, + vis_thresh=THRESHOLDS[2], + vis_hits_misses_fas=True, +): + """ + Args: + in_seq (np.array): + target_seq (np.array): + interval_real_time (float): The minutes of each plot interval + """ + + if norm is None: + norm = {"scale": 255, "shift": 0} + in_len = in_seq.shape[-1] + out_len = target_seq.shape[-1] + max_len = max(in_len, out_len) + ncols = (max_len - 1) // plot_stride + 1 + if vis_hits_misses_fas: + fig, ax = plt.subplots( + nrows=2 + 3 * len(pred_seq_list), ncols=ncols, figsize=figsize + ) + else: + fig, ax = plt.subplots( + nrows=2 + len(pred_seq_list), ncols=ncols, figsize=figsize + ) + + ax[0][0].set_ylabel("Inputs", fontsize=fs) + for i in range(0, max_len, plot_stride): + if i < in_len: + xt = in_seq[idx, :, :, i] * norm["scale"] + norm["shift"] + ax[0][i // plot_stride].imshow(xt, **get_cmap_dict("vil")) + else: + ax[0][i // plot_stride].axis("off") + + ax[1][0].set_ylabel("Target", fontsize=fs) + for i in range(0, max_len, plot_stride): + if i < out_len: + xt = target_seq[idx, :, :, i] * norm["scale"] + norm["shift"] + ax[1][i // plot_stride].imshow(xt, **get_cmap_dict("vil")) + else: + ax[1][i // plot_stride].axis("off") + + target_seq = target_seq[idx : idx + 1] * norm["scale"] + norm["shift"] + y_preds = [ + pred_seq[idx : idx + 1] * norm["scale"] + norm["shift"] + for pred_seq in pred_seq_list + ] + + # Plot model predictions + if vis_hits_misses_fas: + for k in range(len(pred_seq_list)): + for i in range(0, max_len, plot_stride): + if i < out_len: + ax[2 + 3 * k][i // plot_stride].imshow( + y_preds[k][0, :, :, i], **get_cmap_dict("vil") + ) + plot_hit_miss_fa( + ax[2 + 1 + 3 * k][i // plot_stride], + target_seq[0, :, :, i], + y_preds[k][0, :, :, i], + vis_thresh, + ) + plot_hit_miss_fa_all_thresholds( + ax[2 + 2 + 3 * k][i // plot_stride], + target_seq[0, :, :, i], + y_preds[k][0, :, :, i], + ) + else: + ax[2 + 3 * k][i // plot_stride].axis("off") + ax[2 + 1 + 3 * k][i // plot_stride].axis("off") + ax[2 + 2 + 3 * k][i // plot_stride].axis("off") + + ax[2 + 3 * k][0].set_ylabel(label_list[k] + "\nPrediction", fontsize=fs) + ax[2 + 1 + 3 * k][0].set_ylabel( + label_list[k] + f"\nScores\nThresh={vis_thresh}", fontsize=fs + ) + ax[2 + 2 + 3 * k][0].set_ylabel( + label_list[k] + "\nScores\nAll Thresh", fontsize=fs + ) + else: + for k in range(len(pred_seq_list)): + for i in range(0, max_len, plot_stride): + if i < out_len: + ax[2 + k][i // plot_stride].imshow( + y_preds[k][0, :, :, i], **get_cmap_dict("vil") + ) + else: + ax[2 + k][i // plot_stride].axis("off") + + ax[2 + k][0].set_ylabel(label_list[k] + "\nPrediction", fontsize=fs) + + for i in range(0, max_len, plot_stride): + if i < out_len: + ax[-1][i // plot_stride].set_title( + f"{int(interval_real_time * (i + plot_stride))} Minutes", y=-0.25 + ) + + for j in range(len(ax)): + for i in range(len(ax[j])): + ax[j][i].xaxis.set_ticks([]) + ax[j][i].yaxis.set_ticks([]) + + # Legend of thresholds + num_thresh_legend = 
len(sevir_cmap.VIL_LEVELS) - 1 + legend_elements = [ + Patch( + facecolor=sevir_cmap.VIL_COLORS[i], + label=f"{int(sevir_cmap.VIL_LEVELS[i - 1])}-{int(sevir_cmap.VIL_LEVELS[i])}", + ) + for i in range(1, num_thresh_legend + 1) + ] + ax[0][0].legend( + handles=legend_elements, + loc="center left", + bbox_to_anchor=(-1.2, -0.0), + borderaxespad=0, + frameon=False, + fontsize="10", + ) + if vis_hits_misses_fas: + # Legend of Hit, Miss and False Alarm + legend_elements = [ + Patch(facecolor=HMF_COLORS[3], edgecolor="k", label="Hit"), + Patch(facecolor=HMF_COLORS[2], edgecolor="k", label="Miss"), + Patch(facecolor=HMF_COLORS[1], edgecolor="k", label="False Alarm"), + ] + + ax[3][0].legend( + handles=legend_elements, + loc="center left", + bbox_to_anchor=(-2.2, -0.0), + borderaxespad=0, + frameon=False, + fontsize="16", + ) + + plt.subplots_adjust(hspace=0.05, wspace=0.05) + return fig, ax + + +def save_example_vis_results( + save_dir, + save_prefix, + in_seq, + target_seq, + pred_seq, + label, + layout="NHWT", + interval_real_time: float = 10.0, + idx=0, + plot_stride=2, + fs=10, + norm=None, +): + """ + Args: + in_seq (np.array): float value 0-1 + target_seq (np.array): float value 0-1 + pred_seq (np.array): float value 0-1 + interval_real_time (float): The minutes of each plot interval + """ + + in_seq = sevir_dataset.change_layout_np(in_seq, in_layout=layout).astype(np.float32) + target_seq = sevir_dataset.change_layout_np(target_seq, in_layout=layout).astype( + np.float32 + ) + pred_seq = sevir_dataset.change_layout_np(pred_seq, in_layout=layout).astype( + np.float32 + ) + fig_path = os.path.join(save_dir, f"{save_prefix}.png") + fig, ax = visualize_result( + in_seq=in_seq, + target_seq=target_seq, + pred_seq_list=[ + pred_seq, + ], + label_list=[ + label, + ], + interval_real_time=interval_real_time, + idx=idx, + plot_stride=plot_stride, + fs=fs, + norm=norm, + ) + plt.savefig(fig_path) + plt.close(fig) diff --git a/jointContribution/XPINNs/XPINN_2D_PoissonsEqn.py b/jointContribution/XPINNs/XPINN_2D_PoissonsEqn.py old mode 100755 new mode 100644 diff --git a/jointContribution/graphGalerkin/pycamotk b/jointContribution/graphGalerkin/pycamotk deleted file mode 160000 index 83e96baf2..000000000 --- a/jointContribution/graphGalerkin/pycamotk +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 83e96baf2701c096ab1604bb74ef64455fc87b33 diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py index c959c2c07..e59cac085 100644 --- a/ppsci/arch/__init__.py +++ b/ppsci/arch/__init__.py @@ -42,6 +42,7 @@ from ppsci.arch.cfdgcn import CFDGCN # isort:skip from ppsci.arch.dgmr import DGMR # isort:skip from ppsci.arch.vae import AutoEncoder # isort:skip +from ppsci.arch.cuboid_transformer import CuboidTransformer # isort:skip from ppsci.utils import logger # isort:skip @@ -55,6 +56,7 @@ "LorenzEmbedding", "RosslerEmbedding", "CylinderEmbedding", + "CuboidTransformer", "Generator", "Discriminator", "PhysformerGPT2", diff --git a/ppsci/arch/cuboid_transformer.py b/ppsci/arch/cuboid_transformer.py new file mode 100644 index 000000000..82b5e6902 --- /dev/null +++ b/ppsci/arch/cuboid_transformer.py @@ -0,0 +1,961 @@ +from typing import Sequence +from typing import Tuple +from typing import Union + +import paddle +from paddle import nn + +import ppsci.arch.cuboid_transformer_decoder as cuboid_decoder +import ppsci.arch.cuboid_transformer_encoder as cuboid_encoder +import ppsci.arch.cuboid_transformer_utils as cuboid_utils +from ppsci.arch import activation as act_mod +from ppsci.arch import base +from 
ppsci.arch.cuboid_transformer_encoder import NEGATIVE_SLOPE +from ppsci.utils import initializer + +"""A space-time Transformer with Cuboid Attention""" + + +class InitialEncoder(paddle.nn.Layer): + def __init__( + self, + dim, + out_dim, + downsample_scale: Union[int, Sequence[int]], + num_conv_layers: int = 2, + activation: str = "leaky", + padding_type: str = "nearest", + conv_init_mode: str = "0", + linear_init_mode: str = "0", + norm_init_mode: str = "0", + ): + super(InitialEncoder, self).__init__() + self.num_conv_layers = num_conv_layers + self.conv_init_mode = conv_init_mode + self.linear_init_mode = linear_init_mode + self.norm_init_mode = norm_init_mode + conv_block = [] + for i in range(num_conv_layers): + if i == 0: + conv_block.append( + paddle.nn.Conv2D( + kernel_size=(3, 3), + padding=(1, 1), + in_channels=dim, + out_channels=out_dim, + ) + ) + conv_block.append( + paddle.nn.GroupNorm(num_groups=16, num_channels=out_dim) + ) + conv_block.append( + act_mod.get_activation(activation) + if activation != "leaky_relu" + else nn.LeakyReLU(NEGATIVE_SLOPE) + ) + else: + conv_block.append( + paddle.nn.Conv2D( + kernel_size=(3, 3), + padding=(1, 1), + in_channels=out_dim, + out_channels=out_dim, + ) + ) + conv_block.append( + paddle.nn.GroupNorm(num_groups=16, num_channels=out_dim) + ) + conv_block.append( + act_mod.get_activation(activation) + if activation != "leaky_relu" + else nn.LeakyReLU(NEGATIVE_SLOPE) + ) + self.conv_block = paddle.nn.Sequential(*conv_block) + if isinstance(downsample_scale, int): + patch_merge_downsample = (1, downsample_scale, downsample_scale) + elif len(downsample_scale) == 2: + patch_merge_downsample = (1, *downsample_scale) + elif len(downsample_scale) == 3: + patch_merge_downsample = tuple(downsample_scale) + else: + raise NotImplementedError( + f"downsample_scale {downsample_scale} format not supported!" 
+ ) + self.patch_merge = cuboid_encoder.PatchMerging3D( + dim=out_dim, + out_dim=out_dim, + padding_type=padding_type, + downsample=patch_merge_downsample, + linear_init_mode=linear_init_mode, + norm_init_mode=norm_init_mode, + ) + self.reset_parameters() + + def reset_parameters(self): + for m in self.children(): + cuboid_utils.apply_initialization( + m, + conv_mode=self.conv_init_mode, + linear_mode=self.linear_init_mode, + norm_mode=self.norm_init_mode, + ) + + def forward(self, x): + """x --> [K x Conv2D] --> PatchMerge + + Args: + x : (B, T, H, W, C) + + Returns: + out : (B, T, H_new, W_new, C_out) + """ + + B, T, H, W, C = x.shape + + if self.num_conv_layers > 0: + x = x.reshape([B * T, H, W, C]).transpose(perm=[0, 3, 1, 2]) + x = self.conv_block(x).transpose(perm=[0, 2, 3, 1]) + x = self.patch_merge(x.reshape([B, T, H, W, -1])) + else: + x = self.patch_merge(x) + return x + + +class FinalDecoder(paddle.nn.Layer): + def __init__( + self, + target_thw: Tuple[int, ...], + dim: int, + num_conv_layers: int = 2, + activation: str = "leaky", + conv_init_mode: str = "0", + linear_init_mode: str = "0", + norm_init_mode: str = "0", + ): + super(FinalDecoder, self).__init__() + self.target_thw = target_thw + self.dim = dim + self.num_conv_layers = num_conv_layers + self.conv_init_mode = conv_init_mode + self.linear_init_mode = linear_init_mode + self.norm_init_mode = norm_init_mode + conv_block = [] + for i in range(num_conv_layers): + conv_block.append( + paddle.nn.Conv2D( + kernel_size=(3, 3), + padding=(1, 1), + in_channels=dim, + out_channels=dim, + ) + ) + conv_block.append(paddle.nn.GroupNorm(num_groups=16, num_channels=dim)) + conv_block.append( + act_mod.get_activation(activation) + if activation != "leaky_relu" + else nn.LeakyReLU(NEGATIVE_SLOPE) + ) + self.conv_block = paddle.nn.Sequential(*conv_block) + self.upsample = cuboid_decoder.Upsample3DLayer( + dim=dim, + out_dim=dim, + target_size=target_thw, + kernel_size=3, + conv_init_mode=conv_init_mode, + ) + self.reset_parameters() + + def reset_parameters(self): + for m in self.children(): + cuboid_utils.apply_initialization( + m, + conv_mode=self.conv_init_mode, + linear_mode=self.linear_init_mode, + norm_mode=self.norm_init_mode, + ) + + def forward(self, x): + """x --> Upsample --> [K x Conv2D] + + Args: + x : (B, T, H, W, C) + + Returns: + out : (B, T, H_new, W_new, C) + """ + + x = self.upsample(x) + if self.num_conv_layers > 0: + B, T, H, W, C = x.shape + x = x.reshape([B * T, H, W, C]).transpose(perm=[0, 3, 1, 2]) + x = ( + self.conv_block(x) + .transpose(perm=[0, 2, 3, 1]) + .reshape([B, T, H, W, -1]) + ) + return x + + +class InitialStackPatchMergingEncoder(paddle.nn.Layer): + def __init__( + self, + num_merge: int, + in_dim: int, + out_dim_list: Tuple[int, ...], + downsample_scale_list: Tuple[float, ...], + num_conv_per_merge_list: Tuple[int, ...] 
= None, + activation: str = "leaky", + padding_type: str = "nearest", + conv_init_mode: str = "0", + linear_init_mode: str = "0", + norm_init_mode: str = "0", + ): + super(InitialStackPatchMergingEncoder, self).__init__() + self.conv_init_mode = conv_init_mode + self.linear_init_mode = linear_init_mode + self.norm_init_mode = norm_init_mode + self.num_merge = num_merge + self.in_dim = in_dim + self.out_dim_list = out_dim_list[:num_merge] + self.downsample_scale_list = downsample_scale_list[:num_merge] + self.num_conv_per_merge_list = num_conv_per_merge_list + self.num_group_list = [max(1, out_dim // 4) for out_dim in self.out_dim_list] + self.conv_block_list = paddle.nn.LayerList() + self.patch_merge_list = paddle.nn.LayerList() + for i in range(num_merge): + if i == 0: + in_dim = in_dim + else: + in_dim = self.out_dim_list[i - 1] + out_dim = self.out_dim_list[i] + downsample_scale = self.downsample_scale_list[i] + conv_block = [] + for j in range(self.num_conv_per_merge_list[i]): + if j == 0: + conv_in_dim = in_dim + else: + conv_in_dim = out_dim + conv_block.append( + paddle.nn.Conv2D( + kernel_size=(3, 3), + padding=(1, 1), + in_channels=conv_in_dim, + out_channels=out_dim, + ) + ) + conv_block.append( + paddle.nn.GroupNorm( + num_groups=self.num_group_list[i], num_channels=out_dim + ) + ) + conv_block.append( + act_mod.get_activation(activation) + if activation != "leaky_relu" + else nn.LeakyReLU(NEGATIVE_SLOPE) + ) + conv_block = paddle.nn.Sequential(*conv_block) + self.conv_block_list.append(conv_block) + patch_merge = cuboid_encoder.PatchMerging3D( + dim=out_dim, + out_dim=out_dim, + padding_type=padding_type, + downsample=(1, downsample_scale, downsample_scale), + linear_init_mode=linear_init_mode, + norm_init_mode=norm_init_mode, + ) + self.patch_merge_list.append(patch_merge) + self.reset_parameters() + + def reset_parameters(self): + for m in self.children(): + cuboid_utils.apply_initialization( + m, + conv_mode=self.conv_init_mode, + linear_mode=self.linear_init_mode, + norm_mode=self.norm_init_mode, + ) + + def get_out_shape_list(self, input_shape): + out_shape_list = [] + for patch_merge in self.patch_merge_list: + input_shape = patch_merge.get_out_shape(input_shape) + out_shape_list.append(input_shape) + return out_shape_list + + def forward(self, x): + """x --> [K x Conv2D] --> PatchMerge --> ... --> [K x Conv2D] --> PatchMerge + + Args: + x : (B, T, H, W, C) + + Returns: + out : (B, T, H_new, W_new, C_out) + """ + + for i, (conv_block, patch_merge) in enumerate( + zip(self.conv_block_list, self.patch_merge_list) + ): + B, T, H, W, C = x.shape + if self.num_conv_per_merge_list[i] > 0: + x = x.reshape([B * T, H, W, C]).transpose(perm=[0, 3, 1, 2]) + x = conv_block(x).transpose(perm=[0, 2, 3, 1]).reshape([B, T, H, W, -1]) + x = patch_merge(x) + return x + + +class FinalStackUpsamplingDecoder(paddle.nn.Layer): + def __init__( + self, + target_shape_list: Tuple[Tuple[int, ...]], + in_dim: int, + num_conv_per_up_list: Tuple[int, ...] 
= None, + activation: str = "leaky", + conv_init_mode: str = "0", + linear_init_mode: str = "0", + norm_init_mode: str = "0", + ): + super(FinalStackUpsamplingDecoder, self).__init__() + self.conv_init_mode = conv_init_mode + self.linear_init_mode = linear_init_mode + self.norm_init_mode = norm_init_mode + self.target_shape_list = target_shape_list + self.out_dim_list = [ + target_shape[-1] for target_shape in self.target_shape_list + ] + self.num_upsample = len(target_shape_list) + self.in_dim = in_dim + self.num_conv_per_up_list = num_conv_per_up_list + self.num_group_list = [max(1, out_dim // 4) for out_dim in self.out_dim_list] + self.conv_block_list = paddle.nn.LayerList() + self.upsample_list = paddle.nn.LayerList() + for i in range(self.num_upsample): + if i == 0: + in_dim = in_dim + else: + in_dim = self.out_dim_list[i - 1] + out_dim = self.out_dim_list[i] + upsample = cuboid_decoder.Upsample3DLayer( + dim=in_dim, + out_dim=in_dim, + target_size=target_shape_list[i][:-1], + kernel_size=3, + conv_init_mode=conv_init_mode, + ) + self.upsample_list.append(upsample) + conv_block = [] + for j in range(num_conv_per_up_list[i]): + if j == 0: + conv_in_dim = in_dim + else: + conv_in_dim = out_dim + conv_block.append( + paddle.nn.Conv2D( + kernel_size=(3, 3), + padding=(1, 1), + in_channels=conv_in_dim, + out_channels=out_dim, + ) + ) + conv_block.append( + paddle.nn.GroupNorm( + num_groups=self.num_group_list[i], num_channels=out_dim + ) + ) + conv_block.append( + act_mod.get_activation(activation) + if activation != "leaky_relu" + else nn.LeakyReLU(NEGATIVE_SLOPE) + ) + conv_block = paddle.nn.Sequential(*conv_block) + self.conv_block_list.append(conv_block) + self.reset_parameters() + + def reset_parameters(self): + for m in self.children(): + cuboid_utils.apply_initialization( + m, + conv_mode=self.conv_init_mode, + linear_mode=self.linear_init_mode, + norm_mode=self.norm_init_mode, + ) + + @staticmethod + def get_init_params(enc_input_shape, enc_out_shape_list, large_channel=False): + dec_target_shape_list = list(enc_out_shape_list[:-1])[::-1] + [ + tuple(enc_input_shape) + ] + if large_channel: + dec_target_shape_list_large_channel = [] + for i, enc_out_shape in enumerate(enc_out_shape_list[::-1]): + dec_target_shape_large_channel = list(dec_target_shape_list[i]) + dec_target_shape_large_channel[-1] = enc_out_shape[-1] + dec_target_shape_list_large_channel.append( + tuple(dec_target_shape_large_channel) + ) + dec_target_shape_list = dec_target_shape_list_large_channel + dec_in_dim = enc_out_shape_list[-1][-1] + return dec_target_shape_list, dec_in_dim + + def forward(self, x): + """x --> Upsample --> [K x Conv2D] --> ... --> Upsample --> [K x Conv2D] + + Args: + x : Shape (B, T, H, W, C) + + Returns: + out : Shape (B, T, H_new, W_new, C) + """ + for i, (conv_block, upsample) in enumerate( + zip(self.conv_block_list, self.upsample_list) + ): + x = upsample(x) + if self.num_conv_per_up_list[i] > 0: + B, T, H, W, C = x.shape + x = x.reshape([B * T, H, W, C]).transpose(perm=[0, 3, 1, 2]) + x = conv_block(x).transpose(perm=[0, 2, 3, 1]).reshape([B, T, H, W, -1]) + return x + + +class CuboidTransformer(base.Arch): + """Cuboid Transformer for spatiotemporal forecasting + + We adopt the Non-autoregressive encoder-decoder architecture. + The decoder takes the multi-scale memory output from the encoder. 
+ + The initial downsampling / upsampling layers will be + Downsampling: [K x Conv2D --> PatchMerge] + Upsampling: [Nearest Interpolation-based Upsample --> K x Conv2D] + + x --> downsample (optional) ---> (+pos_embed) ---> enc --> mem_l initial_z (+pos_embed) ---> FC + | | + |------------| + | + | + y <--- upsample (optional) <--- dec <---------- + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). + input_shape (Tuple[int, ...]): The shape of the input data. + target_shape (Tuple[int, ...]): The shape of the target data. + base_units (int, optional): The base units. Defaults to 128. + block_units (int, optional): The block units. Defaults to None. + scale_alpha (float, optional): We scale up the channels based on the formula: + - round_to(base_units * max(downsample_scale) ** units_alpha, 4). Defaults to 1.0. + num_heads (int, optional): The number of heads. Defaults to 4. + attn_drop (float, optional): The attention dropout. Defaults to 0.0. + proj_drop (float, optional): The projection dropout. Defaults to 0.0. + ffn_drop (float, optional): The FFN dropout. Defaults to 0.0. + downsample (int, optional): The rate of downsample. Defaults to 2. + downsample_type (str, optional): The type of downsample. Defaults to "patch_merge". + upsample_type (str, optional): The type of upsample. Defaults to "upsample". + upsample_kernel_size (int, optional): The kernel size of upsample. Defaults to 3. + enc_depth (list, optional): The depth of encoder. Defaults to [4, 4, 4]. + enc_attn_patterns (str, optional): The pattern of encoder attention. Defaults to None. + enc_cuboid_size (list, optional): The cuboid size of encoder. Defaults to [(4, 4, 4), (4, 4, 4)]. + enc_cuboid_strategy (list, optional): The cuboid strategy of encoder. Defaults to [("l", "l", "l"), ("d", "d", "d")]. + enc_shift_size (list, optional): The shift size of encoder. Defaults to [(0, 0, 0), (0, 0, 0)]. + enc_use_inter_ffn (bool, optional): Whether to use intermediate FFN for encoder. Defaults to True. + dec_depth (list, optional): The depth of decoder. Defaults to [2, 2]. + dec_cross_start (int, optional): The cross start of decoder. Defaults to 0. + dec_self_attn_patterns (str, optional): The self attention patterns of decoder. Defaults to None. + dec_self_cuboid_size (list, optional): The cuboid size of decoder. Defaults to [(4, 4, 4), (4, 4, 4)]. + dec_self_cuboid_strategy (list, optional): The strategy of decoder. Defaults to [("l", "l", "l"), ("d", "d", "d")]. + dec_self_shift_size (list, optional): The shift size of decoder. Defaults to [(1, 1, 1), (0, 0, 0)]. + dec_cross_attn_patterns (str, optional): The cross attention patterns of decoder. Defaults to None. + dec_cross_cuboid_hw (list, optional): The cuboid_hw of decoder. Defaults to [(4, 4), (4, 4)]. + dec_cross_cuboid_strategy (list, optional): The cuboid strategy of decoder. Defaults to [("l", "l", "l"), ("d", "l", "l")]. + dec_cross_shift_hw (list, optional): The shift_hw of decoder. Defaults to [(0, 0), (0, 0)]. + dec_cross_n_temporal (list, optional): The cross_n_temporal of decoder. Defaults to [1, 2]. + dec_cross_last_n_frames (int, optional): The cross_last_n_frames of decoder. Defaults to None. + dec_use_inter_ffn (bool, optional): Whether to use intermediate FFN for decoder. Defaults to True. + dec_hierarchical_pos_embed (bool, optional): Whether to use hierarchical pos_embed for decoder. Defaults to False.
+ num_global_vectors (int, optional): The number of global vectors. Defaults to 4. + use_dec_self_global (bool, optional): Whether to use global vector for decoder. Defaults to True. + dec_self_update_global (bool, optional): Whether to update global vector for decoder. Defaults to True. + use_dec_cross_global (bool, optional): Whether to use cross global vector for decoder. Defaults to True. + use_global_vector_ffn (bool, optional): Whether to use global vector FFN. Defaults to True. + use_global_self_attn (bool, optional): Whether to use global self-attention. Defaults to False. + separate_global_qkv (bool, optional): Whether to separate global qkv. Defaults to False. + global_dim_ratio (int, optional): The ratio of global dim. Defaults to 1. + self_pattern (str, optional): The self attention pattern of the encoder. Defaults to "axial". + cross_self_pattern (str, optional): The self attention pattern of the decoder. Defaults to "axial". + cross_pattern (str, optional): The cross attention pattern of the decoder. Defaults to "cross_1x1". + z_init_method (str, optional): How the initial input to the decoder is initialized. Defaults to "nearest_interp". + initial_downsample_type (str, optional): The type of initial downsample. Defaults to "conv". + initial_downsample_activation (str, optional): The activation of initial downsample. Defaults to "leaky". + initial_downsample_scale (int, optional): The scale of initial downsample. Defaults to 1. + initial_downsample_conv_layers (int, optional): The number of conv layers of initial downsample. Defaults to 2. + final_upsample_conv_layers (int, optional): The number of conv layers of final upsample. Defaults to 2. + initial_downsample_stack_conv_num_layers (int, optional): The number of stacked conv layers of initial downsample. Defaults to 1. + initial_downsample_stack_conv_dim_list (list, optional): The dim list of stack conv of initial downsample. Defaults to None. + initial_downsample_stack_conv_downscale_list (list, optional): The downscale list of stack conv of initial downsample. Defaults to [1]. + initial_downsample_stack_conv_num_conv_list (list, optional): The number of conv layers for each stack of initial downsample. Defaults to [2]. + ffn_activation (str, optional): The activation of FFN. Defaults to "leaky". + gated_ffn (bool, optional): Whether to use gated FFN. Defaults to False. + norm_layer (str, optional): The type of normalization. Defaults to "layer_norm". + padding_type (str, optional): The type of padding. Defaults to "ignore". + pos_embed_type (str, optional): The type of positional embedding. Defaults to "t+hw". + checkpoint_level (bool, optional): Whether to use gradient checkpointing. Defaults to True. + use_relative_pos (bool, optional): Whether to use relative position. Defaults to True. + self_attn_use_final_proj (bool, optional): Whether to use final projection. Defaults to True. + dec_use_first_self_attn (bool, optional): Whether to use first self attention for decoder. Defaults to False. + attn_linear_init_mode (str, optional): The mode of attention linear init. Defaults to "0". + ffn_linear_init_mode (str, optional): The mode of FFN linear init. Defaults to "0". + conv_init_mode (str, optional): The mode of conv init. Defaults to "0". + down_up_linear_init_mode (str, optional): The mode of downsample and upsample linear init. Defaults to "0". + norm_init_mode (str, optional): The mode of normalization init. Defaults to "0".
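+
+    Examples:
+        A minimal construction sketch; the values below are illustrative
+        (ENSO-like shapes and depths), not the only valid configuration:
+
+        >>> import ppsci
+        >>> model = ppsci.arch.CuboidTransformer(
+        ...     input_keys=("sst_data",),
+        ...     output_keys=("sst_target",),
+        ...     input_shape=(12, 24, 48, 1),
+        ...     target_shape=(14, 24, 48, 1),
+        ...     base_units=64,
+        ...     enc_depth=(1, 1),
+        ...     dec_depth=(1, 1),
+        ...     num_global_vectors=0,
+        ... )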
+ """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + input_shape: Tuple[int, ...], + target_shape: Tuple[int, ...], + base_units: int = 128, + block_units: int = None, + scale_alpha: float = 1.0, + num_heads: int = 4, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ffn_drop: float = 0.0, + downsample: int = 2, + downsample_type: str = "patch_merge", + upsample_type: str = "upsample", + upsample_kernel_size: int = 3, + enc_depth: Tuple[int, ...] = [4, 4, 4], + enc_attn_patterns: str = None, + enc_cuboid_size: Tuple[Tuple[int, ...], ...] = [(4, 4, 4), (4, 4, 4)], + enc_cuboid_strategy: Tuple[Tuple[str, ...], ...] = [ + ("l", "l", "l"), + ("d", "d", "d"), + ], + enc_shift_size: Tuple[Tuple[int, ...], ...] = [(0, 0, 0), (0, 0, 0)], + enc_use_inter_ffn: str = True, + dec_depth: Tuple[int, ...] = [2, 2], + dec_cross_start: int = 0, + dec_self_attn_patterns: str = None, + dec_self_cuboid_size: Tuple[Tuple[int, ...], ...] = [(4, 4, 4), (4, 4, 4)], + dec_self_cuboid_strategy: Tuple[Tuple[str, ...], ...] = [ + ("l", "l", "l"), + ("d", "d", "d"), + ], + dec_self_shift_size: Tuple[Tuple[int, ...], ...] = [(1, 1, 1), (0, 0, 0)], + dec_cross_attn_patterns: str = None, + dec_cross_cuboid_hw: Tuple[Tuple[int, ...], ...] = [(4, 4), (4, 4)], + dec_cross_cuboid_strategy: Tuple[Tuple[str, ...], ...] = [ + ("l", "l", "l"), + ("d", "l", "l"), + ], + dec_cross_shift_hw: Tuple[Tuple[int, ...], ...] = [(0, 0), (0, 0)], + dec_cross_n_temporal: Tuple[int, ...] = [1, 2], + dec_cross_last_n_frames: int = None, + dec_use_inter_ffn: bool = True, + dec_hierarchical_pos_embed: bool = False, + num_global_vectors: int = 4, + use_dec_self_global: bool = True, + dec_self_update_global: bool = True, + use_dec_cross_global: bool = True, + use_global_vector_ffn: bool = True, + use_global_self_attn: bool = False, + separate_global_qkv: bool = False, + global_dim_ratio: int = 1, + self_pattern: str = "axial", + cross_self_pattern: str = "axial", + cross_pattern: str = "cross_1x1", + z_init_method: str = "nearest_interp", + initial_downsample_type: str = "conv", + initial_downsample_activation: str = "leaky", + initial_downsample_scale: int = 1, + initial_downsample_conv_layers: int = 2, + final_upsample_conv_layers: int = 2, + initial_downsample_stack_conv_num_layers: int = 1, + initial_downsample_stack_conv_dim_list: Tuple[int, ...] = None, + initial_downsample_stack_conv_downscale_list: Tuple[int, ...] = [1], + initial_downsample_stack_conv_num_conv_list: Tuple[int, ...] 
= [2], + ffn_activation: str = "leaky", + gated_ffn: bool = False, + norm_layer: str = "layer_norm", + padding_type: str = "ignore", + pos_embed_type: str = "t+hw", + checkpoint_level: bool = True, + use_relative_pos: bool = True, + self_attn_use_final_proj: bool = True, + dec_use_first_self_attn: bool = False, + attn_linear_init_mode: str = "0", + ffn_linear_init_mode: str = "0", + conv_init_mode: str = "0", + down_up_linear_init_mode: str = "0", + norm_init_mode: str = "0", + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.conv_init_mode = conv_init_mode + self.down_up_linear_init_mode = down_up_linear_init_mode + self.norm_init_mode = norm_init_mode + assert len(enc_depth) == len(dec_depth) + self.base_units = base_units + self.num_global_vectors = num_global_vectors + + num_blocks = len(enc_depth) + if isinstance(self_pattern, str): + enc_attn_patterns = [self_pattern] * num_blocks + + if isinstance(cross_self_pattern, str): + dec_self_attn_patterns = [cross_self_pattern] * num_blocks + + if isinstance(cross_pattern, str): + dec_cross_attn_patterns = [cross_pattern] * num_blocks + + if global_dim_ratio != 1: + assert ( + separate_global_qkv is True + ), "Setting global_dim_ratio != 1 requires separate_global_qkv == True." + self.global_dim_ratio = global_dim_ratio + self.z_init_method = z_init_method + assert self.z_init_method in ["zeros", "nearest_interp", "last", "mean"] + self.input_shape = input_shape + self.target_shape = target_shape + T_in, H_in, W_in, C_in = input_shape + T_out, H_out, W_out, C_out = target_shape + assert H_in == H_out and W_in == W_out + if self.num_global_vectors > 0: + init_data = paddle.zeros( + (self.num_global_vectors, global_dim_ratio * base_units) + ) + self.init_global_vectors = paddle.create_parameter( + shape=init_data.shape, + dtype=init_data.dtype, + default_initializer=nn.initializer.Constant(0.0), + ) + + self.init_global_vectors.stop_gradient = not True + new_input_shape = self.get_initial_encoder_final_decoder( + initial_downsample_scale=initial_downsample_scale, + initial_downsample_type=initial_downsample_type, + activation=initial_downsample_activation, + initial_downsample_conv_layers=initial_downsample_conv_layers, + final_upsample_conv_layers=final_upsample_conv_layers, + padding_type=padding_type, + initial_downsample_stack_conv_num_layers=initial_downsample_stack_conv_num_layers, + initial_downsample_stack_conv_dim_list=initial_downsample_stack_conv_dim_list, + initial_downsample_stack_conv_downscale_list=initial_downsample_stack_conv_downscale_list, + initial_downsample_stack_conv_num_conv_list=initial_downsample_stack_conv_num_conv_list, + ) + T_in, H_in, W_in, _ = new_input_shape + self.encoder = cuboid_encoder.CuboidTransformerEncoder( + input_shape=(T_in, H_in, W_in, base_units), + base_units=base_units, + block_units=block_units, + scale_alpha=scale_alpha, + depth=enc_depth, + downsample=downsample, + downsample_type=downsample_type, + block_attn_patterns=enc_attn_patterns, + block_cuboid_size=enc_cuboid_size, + block_strategy=enc_cuboid_strategy, + block_shift_size=enc_shift_size, + num_heads=num_heads, + attn_drop=attn_drop, + proj_drop=proj_drop, + ffn_drop=ffn_drop, + gated_ffn=gated_ffn, + ffn_activation=ffn_activation, + norm_layer=norm_layer, + use_inter_ffn=enc_use_inter_ffn, + padding_type=padding_type, + use_global_vector=num_global_vectors > 0, + 
use_global_vector_ffn=use_global_vector_ffn, + use_global_self_attn=use_global_self_attn, + separate_global_qkv=separate_global_qkv, + global_dim_ratio=global_dim_ratio, + checkpoint_level=checkpoint_level, + use_relative_pos=use_relative_pos, + self_attn_use_final_proj=self_attn_use_final_proj, + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=ffn_linear_init_mode, + conv_init_mode=conv_init_mode, + down_linear_init_mode=down_up_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + self.enc_pos_embed = cuboid_decoder.PosEmbed( + embed_dim=base_units, typ=pos_embed_type, maxH=H_in, maxW=W_in, maxT=T_in + ) + mem_shapes = self.encoder.get_mem_shapes() + self.z_proj = paddle.nn.Linear( + in_features=mem_shapes[-1][-1], out_features=mem_shapes[-1][-1] + ) + self.dec_pos_embed = cuboid_decoder.PosEmbed( + embed_dim=mem_shapes[-1][-1], + typ=pos_embed_type, + maxT=T_out, + maxH=mem_shapes[-1][1], + maxW=mem_shapes[-1][2], + ) + self.decoder = cuboid_decoder.CuboidTransformerDecoder( + target_temporal_length=T_out, + mem_shapes=mem_shapes, + cross_start=dec_cross_start, + depth=dec_depth, + upsample_type=upsample_type, + block_self_attn_patterns=dec_self_attn_patterns, + block_self_cuboid_size=dec_self_cuboid_size, + block_self_shift_size=dec_self_shift_size, + block_self_cuboid_strategy=dec_self_cuboid_strategy, + block_cross_attn_patterns=dec_cross_attn_patterns, + block_cross_cuboid_hw=dec_cross_cuboid_hw, + block_cross_shift_hw=dec_cross_shift_hw, + block_cross_cuboid_strategy=dec_cross_cuboid_strategy, + block_cross_n_temporal=dec_cross_n_temporal, + cross_last_n_frames=dec_cross_last_n_frames, + num_heads=num_heads, + attn_drop=attn_drop, + proj_drop=proj_drop, + ffn_drop=ffn_drop, + upsample_kernel_size=upsample_kernel_size, + ffn_activation=ffn_activation, + gated_ffn=gated_ffn, + norm_layer=norm_layer, + use_inter_ffn=dec_use_inter_ffn, + max_temporal_relative=T_in + T_out, + padding_type=padding_type, + hierarchical_pos_embed=dec_hierarchical_pos_embed, + pos_embed_type=pos_embed_type, + use_self_global=num_global_vectors > 0 and use_dec_self_global, + self_update_global=dec_self_update_global, + use_cross_global=num_global_vectors > 0 and use_dec_cross_global, + use_global_vector_ffn=use_global_vector_ffn, + use_global_self_attn=use_global_self_attn, + separate_global_qkv=separate_global_qkv, + global_dim_ratio=global_dim_ratio, + checkpoint_level=checkpoint_level, + use_relative_pos=use_relative_pos, + self_attn_use_final_proj=self_attn_use_final_proj, + use_first_self_attn=dec_use_first_self_attn, + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=ffn_linear_init_mode, + conv_init_mode=conv_init_mode, + up_linear_init_mode=down_up_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + self.reset_parameters() + + def get_initial_encoder_final_decoder( + self, + initial_downsample_type, + activation, + initial_downsample_scale, + initial_downsample_conv_layers, + final_upsample_conv_layers, + padding_type, + initial_downsample_stack_conv_num_layers, + initial_downsample_stack_conv_dim_list, + initial_downsample_stack_conv_downscale_list, + initial_downsample_stack_conv_num_conv_list, + ): + T_in, H_in, W_in, C_in = self.input_shape + T_out, H_out, W_out, C_out = self.target_shape + self.initial_downsample_type = initial_downsample_type + if self.initial_downsample_type == "conv": + if isinstance(initial_downsample_scale, int): + initial_downsample_scale = ( + 1, + initial_downsample_scale, + initial_downsample_scale, + ) + elif 
len(initial_downsample_scale) == 2: + initial_downsample_scale = 1, *initial_downsample_scale + elif len(initial_downsample_scale) == 3: + initial_downsample_scale = tuple(initial_downsample_scale) + else: + raise NotImplementedError( + f"initial_downsample_scale {initial_downsample_scale} format not supported!" + ) + self.initial_encoder = InitialEncoder( + dim=C_in, + out_dim=self.base_units, + downsample_scale=initial_downsample_scale, + num_conv_layers=initial_downsample_conv_layers, + padding_type=padding_type, + activation=activation, + conv_init_mode=self.conv_init_mode, + linear_init_mode=self.down_up_linear_init_mode, + norm_init_mode=self.norm_init_mode, + ) + + self.final_decoder = FinalDecoder( + dim=self.base_units, + target_thw=(T_out, H_out, W_out), + num_conv_layers=final_upsample_conv_layers, + activation=activation, + conv_init_mode=self.conv_init_mode, + linear_init_mode=self.down_up_linear_init_mode, + norm_init_mode=self.norm_init_mode, + ) + new_input_shape = self.initial_encoder.patch_merge.get_out_shape( + self.input_shape + ) + self.dec_final_proj = paddle.nn.Linear( + in_features=self.base_units, out_features=C_out + ) + elif self.initial_downsample_type == "stack_conv": + if initial_downsample_stack_conv_dim_list is None: + initial_downsample_stack_conv_dim_list = [ + self.base_units + ] * initial_downsample_stack_conv_num_layers + self.initial_encoder = InitialStackPatchMergingEncoder( + num_merge=initial_downsample_stack_conv_num_layers, + in_dim=C_in, + out_dim_list=initial_downsample_stack_conv_dim_list, + downsample_scale_list=initial_downsample_stack_conv_downscale_list, + num_conv_per_merge_list=initial_downsample_stack_conv_num_conv_list, + padding_type=padding_type, + activation=activation, + conv_init_mode=self.conv_init_mode, + linear_init_mode=self.down_up_linear_init_mode, + norm_init_mode=self.norm_init_mode, + ) + initial_encoder_out_shape_list = self.initial_encoder.get_out_shape_list( + self.target_shape + ) + ( + dec_target_shape_list, + dec_in_dim, + ) = FinalStackUpsamplingDecoder.get_init_params( + enc_input_shape=self.target_shape, + enc_out_shape_list=initial_encoder_out_shape_list, + large_channel=True, + ) + self.final_decoder = FinalStackUpsamplingDecoder( + target_shape_list=dec_target_shape_list, + in_dim=dec_in_dim, + num_conv_per_up_list=initial_downsample_stack_conv_num_conv_list[::-1], + activation=activation, + conv_init_mode=self.conv_init_mode, + linear_init_mode=self.down_up_linear_init_mode, + norm_init_mode=self.norm_init_mode, + ) + self.dec_final_proj = paddle.nn.Linear( + in_features=dec_target_shape_list[-1][-1], out_features=C_out + ) + new_input_shape = self.initial_encoder.get_out_shape_list(self.input_shape)[ + -1 + ] + else: + raise NotImplementedError(f"{self.initial_downsample_type} is invalid.") + self.input_shape_after_initial_downsample = new_input_shape + T_in, H_in, W_in, _ = new_input_shape + return new_input_shape + + def reset_parameters(self): + if self.num_global_vectors > 0: + self.init_global_vectors = initializer.trunc_normal_( + self.init_global_vectors, std=0.02 + ) + if hasattr(self.initial_encoder, "reset_parameters"): + self.initial_encoder.reset_parameters() + else: + cuboid_utils.apply_initialization( + self.initial_encoder, + conv_mode=self.conv_init_mode, + linear_mode=self.down_up_linear_init_mode, + norm_mode=self.norm_init_mode, + ) + if hasattr(self.final_decoder, "reset_parameters"): + self.final_decoder.reset_parameters() + else: + cuboid_utils.apply_initialization( + self.final_decoder, 
+ conv_mode=self.conv_init_mode, + linear_mode=self.down_up_linear_init_mode, + norm_mode=self.norm_init_mode, + ) + cuboid_utils.apply_initialization( + self.dec_final_proj, linear_mode=self.down_up_linear_init_mode + ) + self.encoder.reset_parameters() + self.enc_pos_embed.reset_parameters() + self.decoder.reset_parameters() + self.dec_pos_embed.reset_parameters() + cuboid_utils.apply_initialization(self.z_proj, linear_mode="0") + + def get_initial_z(self, final_mem, T_out): + B = final_mem.shape[0] + if self.z_init_method == "zeros": + z_shape = list((1, T_out)) + final_mem.shape[2:] + initial_z = paddle.zeros(shape=z_shape, dtype=final_mem.dtype) + initial_z = self.z_proj(self.dec_pos_embed(initial_z)).expand( + shape=[B, -1, -1, -1, -1] + ) + elif self.z_init_method == "nearest_interp": + initial_z = paddle.nn.functional.interpolate( + x=final_mem.transpose(perm=[0, 4, 1, 2, 3]), + size=(T_out, final_mem.shape[2], final_mem.shape[3]), + ).transpose(perm=[0, 2, 3, 4, 1]) + initial_z = self.z_proj(initial_z) + elif self.z_init_method == "last": + initial_z = paddle.broadcast_to( + x=final_mem[:, -1:, :, :, :], shape=(B, T_out) + final_mem.shape[2:] + ) + initial_z = self.z_proj(initial_z) + elif self.z_init_method == "mean": + initial_z = paddle.broadcast_to( + x=final_mem.mean(axis=1, keepdims=True), + shape=(B, T_out) + final_mem.shape[2:], + ) + initial_z = self.z_proj(initial_z) + else: + raise NotImplementedError + return initial_z + + def forward(self, x, verbose=False): + """ + Args: + x : Shape (B, T, H, W, C) + verbose : if True, print intermediate shapes + Returns: + out : The output Shape (B, T_out, H, W, C_out) + """ + + x = self.concat_to_tensor(x, self.input_keys) + flag_ndim = x.ndim + if flag_ndim == 6: + x = x.reshape([-1, *x.shape[2:]]) + B, _, _, _, _ = x.shape + + T_out = self.target_shape[0] + x = self.initial_encoder(x) + x = self.enc_pos_embed(x) + + if self.num_global_vectors > 0: + init_global_vectors = self.init_global_vectors.expand( + shape=[ + B, + self.num_global_vectors, + self.global_dim_ratio * self.base_units, + ] + ) + mem_l, mem_global_vector_l = self.encoder(x, init_global_vectors) + else: + mem_l = self.encoder(x) + + if verbose: + for i, mem in enumerate(mem_l): + print(f"mem[{i}].shape = {mem.shape}") + initial_z = self.get_initial_z(final_mem=mem_l[-1], T_out=T_out) + + if self.num_global_vectors > 0: + dec_out = self.decoder(initial_z, mem_l, mem_global_vector_l) + else: + dec_out = self.decoder(initial_z, mem_l) + + dec_out = self.final_decoder(dec_out) + + out = self.dec_final_proj(dec_out) + if flag_ndim == 6: + out = out.reshape([-1, *out.shape]) + return {key: out for key in self.output_keys} diff --git a/ppsci/arch/cuboid_transformer_decoder.py b/ppsci/arch/cuboid_transformer_decoder.py new file mode 100644 index 000000000..ff736e30a --- /dev/null +++ b/ppsci/arch/cuboid_transformer_decoder.py @@ -0,0 +1,1257 @@ +from functools import lru_cache +from typing import Tuple + +import numpy as np +import paddle +import paddle.nn.functional as F +from paddle import nn +from paddle.distributed import fleet + +import ppsci.arch.cuboid_transformer_encoder as cuboid_encoder +import ppsci.arch.cuboid_transformer_utils as cuboid_utils +from ppsci.utils import initializer + + +class PosEmbed(paddle.nn.Layer): + """Position embedding. + + Args: + embed_dim (int): The dimension of the embedding. + maxT (int): The maximum time length of the embedding. + maxH (int): The maximum height of the embedding. + maxW (int): The maximum width of the embedding. + typ (str): + The type of the positional embedding.
+ - t+h+w: + Embed the spatial position to embeddings + - t+hw: + Embed the spatial position to embeddings + """ + + def __init__(self, embed_dim, maxT, maxH, maxW, typ: str = "t+h+w"): + super(PosEmbed, self).__init__() + self.typ = typ + assert self.typ in ["t+h+w", "t+hw"] + self.maxT = maxT + self.maxH = maxH + self.maxW = maxW + self.embed_dim = embed_dim + if self.typ == "t+h+w": + self.T_embed = paddle.nn.Embedding( + num_embeddings=maxT, embedding_dim=embed_dim + ) + self.H_embed = paddle.nn.Embedding( + num_embeddings=maxH, embedding_dim=embed_dim + ) + self.W_embed = paddle.nn.Embedding( + num_embeddings=maxW, embedding_dim=embed_dim + ) + elif self.typ == "t+hw": + self.T_embed = paddle.nn.Embedding( + num_embeddings=maxT, embedding_dim=embed_dim + ) + self.HW_embed = paddle.nn.Embedding( + num_embeddings=maxH * maxW, embedding_dim=embed_dim + ) + else: + raise NotImplementedError(f"{self.typ} is invalid.") + self.reset_parameters() + + def reset_parameters(self): + for m in self.children(): + cuboid_utils.apply_initialization(m, embed_mode="0") + + def forward(self, x): + """ + Args: + x : Shape (B, T, H, W, C) + + Returns: + out : the x + positional embeddings + """ + + _, T, H, W, _ = x.shape + t_idx = paddle.arange(end=T) + h_idx = paddle.arange(end=H) + w_idx = paddle.arange(end=W) + if self.typ == "t+h+w": + return ( + x + + self.T_embed(t_idx).reshape([T, 1, 1, self.embed_dim]) + + self.H_embed(h_idx).reshape([1, H, 1, self.embed_dim]) + + self.W_embed(w_idx).reshape([1, 1, W, self.embed_dim]) + ) + elif self.typ == "t+hw": + spatial_idx = h_idx.unsqueeze(axis=-1) * self.maxW + w_idx + return ( + x + + self.T_embed(t_idx).reshape([T, 1, 1, self.embed_dim]) + + self.HW_embed(spatial_idx) + ) + else: + raise NotImplementedError(f"{self.typ} is invalid.") + + +@lru_cache() +def compute_cuboid_cross_attention_mask( + T_x, T_mem, H, W, n_temporal, cuboid_hw, shift_hw, strategy, padding_type, device +): + pad_t_mem = (n_temporal - T_mem % n_temporal) % n_temporal + pad_t_x = (n_temporal - T_x % n_temporal) % n_temporal + pad_h = (cuboid_hw[0] - H % cuboid_hw[0]) % cuboid_hw[0] + pad_w = (cuboid_hw[1] - W % cuboid_hw[1]) % cuboid_hw[1] + mem_cuboid_size = ((T_mem + pad_t_mem) // n_temporal,) + cuboid_hw + x_cuboid_size = ((T_x + pad_t_x) // n_temporal,) + cuboid_hw + if pad_t_mem > 0 or pad_h > 0 or pad_w > 0: + if padding_type == "ignore": + mem_mask = paddle.ones(shape=(1, T_mem, H, W, 1), dtype="bool") + mem_mask = F.pad( + mem_mask, [0, 0, 0, pad_w, 0, pad_h, pad_t_mem, 0], data_format="NDHWC" + ) + else: + mem_mask = paddle.ones( + shape=(1, T_mem + pad_t_mem, H + pad_h, W + pad_w, 1), dtype="bool" + ) + if pad_t_x > 0 or pad_h > 0 or pad_w > 0: + if padding_type == "ignore": + x_mask = paddle.ones(shape=(1, T_x, H, W, 1), dtype="bool") + x_mask = F.pad( + x_mask, [0, 0, 0, pad_w, 0, pad_h, 0, pad_t_x], data_format="NDHWC" + ) + else: + x_mask = paddle.ones( + shape=(1, T_x + pad_t_x, H + pad_h, W + pad_w, 1), dtype="bool" + ) + if any(i > 0 for i in shift_hw): + if padding_type == "ignore": + x_mask = paddle.roll( + x=x_mask, shifts=(-shift_hw[0], -shift_hw[1]), axis=(2, 3) + ) + mem_mask = paddle.roll( + x=mem_mask, shifts=(-shift_hw[0], -shift_hw[1]), axis=(2, 3) + ) + x_mask = cuboid_encoder.cuboid_reorder(x_mask, x_cuboid_size, strategy=strategy) + x_mask = x_mask.squeeze(axis=-1).squeeze(axis=0) + num_cuboids, x_cuboid_volume = x_mask.shape + mem_mask = cuboid_encoder.cuboid_reorder( + mem_mask, mem_cuboid_size, strategy=strategy + ) + mem_mask = 
mem_mask.squeeze(axis=-1).squeeze(axis=0) + _, mem_cuboid_volume = mem_mask.shape + shift_mask = np.zeros(shape=(1, n_temporal, H + pad_h, W + pad_w, 1)) + cnt = 0 + for h in ( + slice(-cuboid_hw[0]), + slice(-cuboid_hw[0], -shift_hw[0]), + slice(-shift_hw[0], None), + ): + for w in ( + slice(-cuboid_hw[1]), + slice(-cuboid_hw[1], -shift_hw[1]), + slice(-shift_hw[1], None), + ): + shift_mask[:, :, h, w, :] = cnt + cnt += 1 + shift_mask = paddle.to_tensor(shift_mask) + shift_mask = cuboid_encoder.cuboid_reorder( + shift_mask, (1,) + cuboid_hw, strategy=strategy + ) + shift_mask = shift_mask.squeeze(axis=-1).squeeze(axis=0) + shift_mask = shift_mask.unsqueeze(axis=1) - shift_mask.unsqueeze(axis=2) == 0 + bh_bw = cuboid_hw[0] * cuboid_hw[1] + attn_mask = ( + shift_mask.reshape((num_cuboids, 1, bh_bw, 1, bh_bw)) + * x_mask.reshape((num_cuboids, -1, bh_bw, 1, 1)) + * mem_mask.reshape([num_cuboids, 1, 1, -1, bh_bw]) + ) + attn_mask = attn_mask.reshape([num_cuboids, x_cuboid_volume, mem_cuboid_volume]) + return attn_mask + + +class CuboidCrossAttentionLayer(paddle.nn.Layer): + """Implements the cuboid cross attention. + + The idea of Cuboid Cross Attention is to extend the idea of cuboid self attention to work for the + encoder-decoder-type cross attention. + + Assume that there is a memory tensor with shape (T1, H, W, C) and another query tensor with shape (T2, H, W, C). + + Here, we decompose the query tensor and the memory tensor into the same number of cuboids and attend each cuboid in + the query tensor to the corresponding cuboid in the memory tensor. + + For the height and width axes, we reuse the grid decomposition techniques described in the cuboid self-attention. + For the temporal axis, the layer supports the "n_temporal" parameter, which controls the number of cuboids we can + get after cutting the tensors. For example, if the temporal dilation is 2, both the query and + memory will be decomposed into 2 cuboids along the temporal axis. Like in the Cuboid Self-attention, + we support "local" and "dilated" decomposition strategies. + + The complexity of the layer is O((T2 / n_t * Bh * Bw) * (T1 / n_t * Bh * Bw) * n_t (H / Bh) (W / Bw)) = O(T2 * T1 / n_t H W Bh Bw) + + Args: + dim (int): The dimension of the input tensor. + num_heads (int): The number of attention heads. + n_temporal (int, optional): The number of cuboids along the temporal axis. Defaults to 1. + cuboid_hw (tuple, optional): The height and width of cuboid. Defaults to (7, 7). + shift_hw (tuple, optional): The height and width of shift. Defaults to (0, 0). + strategy (tuple, optional): The strategy. Defaults to ("d", "l", "l"). + padding_type (str, optional): The type of padding. Defaults to "ignore". + cross_last_n_frames (int, optional): If not None, only the last `cross_last_n_frames` frames of the memory are attended to. Defaults to None. + qkv_bias (bool, optional): Whether to enable bias in calculating qkv attention. Defaults to False. + qk_scale (float, optional): The scale factor applied when calculating the attention. If None, head_dim ** -0.5 is used. Defaults to None. + attn_drop (float, optional): The attention dropout. Defaults to 0.0. + proj_drop (float, optional): The projection dropout. Defaults to 0.0. + max_temporal_relative (int, optional): The maximum temporal relative distance. Defaults to 50. + norm_layer (str, optional): The normalization layer. Defaults to "layer_norm". + use_global_vector (bool, optional): Whether to use the global vector or not. Defaults to True. + separate_global_qkv (bool, optional): Whether to use different network to calc q_global, k_global, v_global. Defaults to False.
+ global_dim_ratio (int, optional): The dim (channels) of global vectors is `global_dim_ratio*dim`. Defaults to 1. + checkpoint_level (int, optional): Whether to enable gradient checkpointing. Defaults to 1. + use_relative_pos (bool, optional): Whether to use relative pos. Defaults to True. + attn_linear_init_mode (str, optional): The mode of attention linear initialization. Defaults to "0". + ffn_linear_init_mode (str, optional): The mode of FFN linear initialization. Defaults to "0". + norm_init_mode (str, optional): The mode of normalization initialization. Defaults to "0". + """ + + def __init__( + self, + dim: int, + num_heads: int, + n_temporal: int = 1, + cuboid_hw: Tuple[int, ...] = (7, 7), + shift_hw: Tuple[int, ...] = (0, 0), + strategy: Tuple[str, ...] = ("d", "l", "l"), + padding_type: str = "ignore", + cross_last_n_frames: int = None, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + max_temporal_relative: int = 50, + norm_layer: str = "layer_norm", + use_global_vector: bool = True, + separate_global_qkv: bool = False, + global_dim_ratio: int = 1, + checkpoint_level: int = 1, + use_relative_pos: bool = True, + attn_linear_init_mode: str = "0", + ffn_linear_init_mode: str = "0", + norm_init_mode: str = "0", + ): + super(CuboidCrossAttentionLayer, self).__init__() + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.norm_init_mode = norm_init_mode + self.dim = dim + self.num_heads = num_heads + self.n_temporal = n_temporal + assert n_temporal > 0 + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + shift_hw = list(shift_hw) + if strategy[1] == "d": + shift_hw[0] = 0 + if strategy[2] == "d": + shift_hw[1] = 0 + self.cuboid_hw = cuboid_hw + self.shift_hw = tuple(shift_hw) + self.strategy = strategy + self.padding_type = padding_type + self.max_temporal_relative = max_temporal_relative + self.cross_last_n_frames = cross_last_n_frames + self.use_relative_pos = use_relative_pos + self.use_global_vector = use_global_vector + self.separate_global_qkv = separate_global_qkv + if global_dim_ratio != 1 and separate_global_qkv is False: + raise ValueError( + "Setting global_dim_ratio != 1 requires separate_global_qkv == True." 
+ ) + self.global_dim_ratio = global_dim_ratio + if self.padding_type not in ["ignore", "zeros", "nearest"]: + raise ValueError('padding_type should be ["ignore", "zeros", "nearest"]') + if use_relative_pos: + init_data = paddle.zeros( + ( + (2 * max_temporal_relative - 1) + * (2 * cuboid_hw[0] - 1) + * (2 * cuboid_hw[1] - 1), + num_heads, + ) + ) + self.relative_position_bias_table = paddle.create_parameter( + shape=init_data.shape, + dtype=init_data.dtype, + default_initializer=nn.initializer.Constant(0.0), + ) + self.relative_position_bias_table.stop_gradient = not True + self.relative_position_bias_table = initializer.trunc_normal_( + self.relative_position_bias_table, std=0.02 + ) + + coords_t = paddle.arange(end=max_temporal_relative) + coords_h = paddle.arange(end=self.cuboid_hw[0]) + coords_w = paddle.arange(end=self.cuboid_hw[1]) + coords = paddle.stack(x=paddle.meshgrid(coords_t, coords_h, coords_w)) + coords_flatten = paddle.flatten(x=coords, start_axis=1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.transpose(perm=[1, 2, 0]) + relative_coords[:, :, 0] += max_temporal_relative - 1 + relative_coords[:, :, 1] += self.cuboid_hw[0] - 1 + relative_coords[:, :, 2] += self.cuboid_hw[1] - 1 + relative_position_index = ( + relative_coords[:, :, 0] + * (2 * self.cuboid_hw[0] - 1) + * (2 * self.cuboid_hw[1] - 1) + + relative_coords[:, :, 1] * (2 * self.cuboid_hw[1] - 1) + + relative_coords[:, :, 2] + ) + self.register_buffer( + name="relative_position_index", tensor=relative_position_index + ) + self.q_proj = paddle.nn.Linear( + in_features=dim, out_features=dim, bias_attr=qkv_bias + ) + self.kv_proj = paddle.nn.Linear( + in_features=dim, out_features=dim * 2, bias_attr=qkv_bias + ) + self.attn_drop = paddle.nn.Dropout(p=attn_drop) + self.proj = paddle.nn.Linear(in_features=dim, out_features=dim) + self.proj_drop = paddle.nn.Dropout(p=proj_drop) + if self.use_global_vector: + if self.separate_global_qkv: + self.l2g_q_net = paddle.nn.Linear( + in_features=dim, out_features=dim, bias_attr=qkv_bias + ) + self.l2g_global_kv_net = paddle.nn.Linear( + in_features=global_dim_ratio * dim, + out_features=dim * 2, + bias_attr=qkv_bias, + ) + self.norm = cuboid_utils.get_norm_layer(norm_layer, in_channels=dim) + self._checkpoint_level = checkpoint_level + self.reset_parameters() + + def reset_parameters(self): + cuboid_utils.apply_initialization( + self.q_proj, linear_mode=self.attn_linear_init_mode + ) + cuboid_utils.apply_initialization( + self.kv_proj, linear_mode=self.attn_linear_init_mode + ) + cuboid_utils.apply_initialization( + self.proj, linear_mode=self.ffn_linear_init_mode + ) + cuboid_utils.apply_initialization(self.norm, norm_mode=self.norm_init_mode) + if self.use_global_vector: + if self.separate_global_qkv: + cuboid_utils.apply_initialization( + self.l2g_q_net, linear_mode=self.attn_linear_init_mode + ) + cuboid_utils.apply_initialization( + self.l2g_global_kv_net, linear_mode=self.attn_linear_init_mode + ) + + def forward(self, x, mem, mem_global_vectors=None): + """Calculate the forward + + Along the temporal axis, we pad the mem tensor from the left and the x tensor from the right so that the + relative position encoding can be calculated correctly. 
For example: + + mem: 0, 1, 2, 3, 4 + x: 0, 1, 2, 3, 4, 5 + + n_temporal = 1 + mem: 0, 1, 2, 3, 4 x: 0, 1, 2, 3, 4, 5 + + n_temporal = 2 + mem: pad, 1, 3 x: 0, 2, 4 + mem: 0, 2, 4 x: 1, 3, 5 + + n_temporal = 3 + mem: pad, 2 dec: 0, 3 + mem: 0, 3 dec: 1, 4 + mem: 1, 4 dec: 2, 5 + + Args: + x (paddle.Tensor): The input of the layer. It will have shape (B, T, H, W, C) + mem (paddle.Tensor): The memory. It will have shape (B, T_mem, H, W, C) + mem_global_vectors (paddle.Tensor): The global vectors from the memory. It will have shape (B, N, C) + + Returns: + out (paddle.Tensor): Output tensor should have shape (B, T, H, W, C_out) + """ + + if self.cross_last_n_frames is not None: + cross_last_n_frames = int(min(self.cross_last_n_frames, mem.shape[1])) + mem = mem[:, -cross_last_n_frames:, ...] + if self.use_global_vector: + _, num_global, _ = mem_global_vectors.shape + x = self.norm(x) + B, T_x, H, W, C_in = x.shape + B_mem, T_mem, H_mem, W_mem, C_mem = mem.shape + assert T_x < self.max_temporal_relative and T_mem < self.max_temporal_relative + cuboid_hw = self.cuboid_hw + n_temporal = self.n_temporal + shift_hw = self.shift_hw + assert ( + B_mem == B and H == H_mem and W == W_mem and C_in == C_mem + ), f"Shape of memory and the input tensor does not match. x.shape={x.shape}, mem.shape={mem.shape}" + pad_t_mem = (n_temporal - T_mem % n_temporal) % n_temporal + pad_t_x = (n_temporal - T_x % n_temporal) % n_temporal + pad_h = (cuboid_hw[0] - H % cuboid_hw[0]) % cuboid_hw[0] + pad_w = (cuboid_hw[1] - W % cuboid_hw[1]) % cuboid_hw[1] + mem = cuboid_utils.generalize_padding( + mem, pad_t_mem, pad_h, pad_w, self.padding_type, t_pad_left=True + ) + + x = cuboid_utils.generalize_padding( + x, pad_t_x, pad_h, pad_w, self.padding_type, t_pad_left=False + ) + + if any(i > 0 for i in shift_hw): + shifted_x = paddle.roll( + x=x, shifts=(-shift_hw[0], -shift_hw[1]), axis=(2, 3) + ) + shifted_mem = paddle.roll( + x=mem, shifts=(-shift_hw[0], -shift_hw[1]), axis=(2, 3) + ) + else: + shifted_x = x + shifted_mem = mem + mem_cuboid_size = (mem.shape[1] // n_temporal,) + cuboid_hw + x_cuboid_size = (x.shape[1] // n_temporal,) + cuboid_hw + reordered_mem = cuboid_encoder.cuboid_reorder( + shifted_mem, cuboid_size=mem_cuboid_size, strategy=self.strategy + ) + reordered_x = cuboid_encoder.cuboid_reorder( + shifted_x, cuboid_size=x_cuboid_size, strategy=self.strategy + ) + _, num_cuboids_mem, mem_cuboid_volume, _ = reordered_mem.shape + _, num_cuboids, x_cuboid_volume, _ = reordered_x.shape + assert ( + num_cuboids_mem == num_cuboids + ), f"Number of cuboids do not match. 
num_cuboids={num_cuboids}, num_cuboids_mem={num_cuboids_mem}" + attn_mask = compute_cuboid_cross_attention_mask( + T_x, + T_mem, + H, + W, + n_temporal, + cuboid_hw, + shift_hw, + strategy=self.strategy, + padding_type=self.padding_type, + device=x.place, + ) + head_C = C_in // self.num_heads + kv = ( + self.kv_proj(reordered_mem) + .reshape([B, num_cuboids, mem_cuboid_volume, 2, self.num_heads, head_C]) + .transpose(perm=[3, 0, 4, 1, 2, 5]) + ) + k, v = kv[0], kv[1] + q = ( + self.q_proj(reordered_x) + .reshape([B, num_cuboids, x_cuboid_volume, self.num_heads, head_C]) + .transpose(perm=[0, 3, 1, 2, 4]) + ) + q = q * self.scale + perm_4 = list(range(k.ndim)) + perm_4[-2] = -1 + perm_4[-1] = -2 + attn_score = q @ k.transpose(perm=perm_4) + if self.use_relative_pos: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[ + :x_cuboid_volume, :mem_cuboid_volume + ].reshape([-1]) + ].reshape([x_cuboid_volume, mem_cuboid_volume, -1]) + relative_position_bias = relative_position_bias.transpose( + perm=[2, 0, 1] + ).unsqueeze(axis=1) + attn_score = attn_score + relative_position_bias + if self.use_global_vector: + if self.separate_global_qkv: + l2g_q = ( + self.l2g_q_net(reordered_x) + .reshape([B, num_cuboids, x_cuboid_volume, self.num_heads, head_C]) + .transpose(perm=[0, 3, 1, 2, 4]) + ) + l2g_q = l2g_q * self.scale + l2g_global_kv = ( + self.l2g_global_kv_net(mem_global_vectors) + .reshape([B, 1, num_global, 2, self.num_heads, head_C]) + .transpose(perm=[3, 0, 4, 1, 2, 5]) + ) + l2g_global_k, l2g_global_v = l2g_global_kv[0], l2g_global_kv[1] + else: + kv_global = ( + self.kv_proj(mem_global_vectors) + .reshape([B, 1, num_global, 2, self.num_heads, head_C]) + .transpose(perm=[3, 0, 4, 1, 2, 5]) + ) + l2g_global_k, l2g_global_v = kv_global[0], kv_global[1] + l2g_q = q + perm_5 = list(range(l2g_global_k.ndim)) + perm_5[-2] = -1 + perm_5[-1] = -2 + l2g_attn_score = l2g_q @ l2g_global_k.transpose(perm=perm_5) + attn_score_l2l_l2g = paddle.concat(x=(attn_score, l2g_attn_score), axis=-1) + if attn_mask.ndim == 5: + attn_mask_l2l_l2g = F.pad( + attn_mask, [0, num_global], "constant", 1, data_format="NDHWC" + ) + else: + attn_mask_l2l_l2g = F.pad(attn_mask, [0, num_global], "constant", 1) + v_l_g = paddle.concat( + x=( + v, + l2g_global_v.expand( + shape=[B, self.num_heads, num_cuboids, num_global, head_C] + ), + ), + axis=3, + ) + attn_score_l2l_l2g = cuboid_encoder.masked_softmax( + attn_score_l2l_l2g, mask=attn_mask_l2l_l2g + ) + attn_score_l2l_l2g = self.attn_drop(attn_score_l2l_l2g) + reordered_x = ( + (attn_score_l2l_l2g @ v_l_g) + .transpose(perm=[0, 2, 3, 1, 4]) + .reshape(B, num_cuboids, x_cuboid_volume, self.dim) + ) + else: + attn_score = cuboid_encoder.masked_softmax(attn_score, mask=attn_mask) + attn_score = self.attn_drop(attn_score) + reordered_x = ( + (attn_score @ v) + .transpose(perm=[0, 2, 3, 1, 4]) + .reshape([B, num_cuboids, x_cuboid_volume, self.dim]) + ) + reordered_x = paddle.cast(reordered_x, dtype="float32") + reordered_x = self.proj_drop(self.proj(reordered_x)) + shifted_x = cuboid_encoder.cuboid_reorder_reverse( + reordered_x, + cuboid_size=x_cuboid_size, + strategy=self.strategy, + orig_data_shape=(x.shape[1], x.shape[2], x.shape[3]), + ) + if any(i > 0 for i in shift_hw): + x = paddle.roll(x=shifted_x, shifts=(shift_hw[0], shift_hw[1]), axis=(2, 3)) + else: + x = shifted_x + x = cuboid_utils.generalize_unpadding( + x, pad_t=pad_t_x, pad_h=pad_h, pad_w=pad_w, padding_type=self.padding_type + ) + return x + + +class 
StackCuboidCrossAttentionBlock(paddle.nn.Layer): + """A stack of cuboid cross attention layers. + + The advantage of cuboid attention is that we can combine cuboid attention building blocks with different + hyper-parameters to mimic a broad range of space-time correlation patterns. + + - "use_inter_ffn" is True + x, mem --> attn1 -----+-------> ffn1 ---+---> attn2 --> ... --> ffn_k --> out + | ^ | ^ + | | | | + |-------------|----|-------------| + - "use_inter_ffn" is False + x, mem --> attn1 -----+------> attn2 --> ... attnk --+----> ffnk ---+---> out, mem + | ^ | ^ ^ | ^ + | | | | | | | + |-------------|----|------------|-- ----------|--|-----------| + + Args: + dim (int): The dimension of the input. + num_heads (int): The number of head. + block_cuboid_hw (list, optional): The height and width of block cuboid.Defaults to [(4, 4), (4, 4)]. + block_shift_hw (list, optional): The height and width of shift cuboid . Defaults to [(0, 0), (2, 2)]. + block_n_temporal (list, optional): The length of block temporal. Defaults to [1, 2]. + block_strategy (list, optional): The strategy of block. Defaults to [("d", "d", "d"), ("l", "l", "l")]. + padding_type (str, optional): The type of paddling. Defaults to "ignore". + cross_last_n_frames (int, optional): The num of cross_last_n_frames. Defaults to None. + qkv_bias (bool, optional): Whether to enable bias in calculating qkv attention. Defaults to False. + qk_scale (float, optional): Whether to enable scale factor when calculating the attention. Defaults to None. + attn_drop (float, optional): The attention dropout. Defaults to 0.0. + proj_drop (float, optional): The projection dropout. Defaults to 0.0. + ffn_drop (float, optional): The ratio of FFN dropout. Defaults to 0.0. + activation (str, optional): The activation. Defaults to "leaky". + gated_ffn (bool, optional): Whether to use gate FFN. Defaults to False. + norm_layer (str, optional): The normalization layer. Defaults to "layer_norm". + use_inter_ffn (bool, optional): Whether to use inter FFN. Defaults to True. + max_temporal_relative (int, optional): The max temporal. Defaults to 50. + checkpoint_level (int, optional): Whether to enable gradient checkpointing. Defaults to 1. + use_relative_pos (bool, optional): Whether to use relative pos. Defaults to True. + use_global_vector (bool, optional): Whether to use the global vector or not. Defaults to False. + separate_global_qkv (bool, optional): Whether to use different network to calc q_global, k_global, v_global. Defaults to False. + global_dim_ratio (int, optional): The dim (channels) of global vectors is `global_dim_ratio*dim`. Defaults to 1. + attn_linear_init_mode (str, optional): The mode of attention linear initialization. Defaults to "0". + ffn_linear_init_mode (str, optional): The mode of FFN linear initialization. Defaults to "0". + norm_init_mode (str, optional): The mode of normalization. Defaults to "0". + """ + + def __init__( + self, + dim: int, + num_heads: int, + block_cuboid_hw: Tuple[Tuple[int, ...], ...] = [(4, 4), (4, 4)], + block_shift_hw: Tuple[Tuple[int, ...], ...] = [(0, 0), (2, 2)], + block_n_temporal: Tuple[int, ...] = [1, 2], + block_strategy: Tuple[Tuple[str, ...], ...] 
= [ + ("d", "d", "d"), + ("l", "l", "l"), + ], + padding_type: str = "ignore", + cross_last_n_frames: int = None, + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ffn_drop: float = 0.0, + activation: str = "leaky", + gated_ffn: bool = False, + norm_layer: str = "layer_norm", + use_inter_ffn: bool = True, + max_temporal_relative: int = 50, + checkpoint_level: int = 1, + use_relative_pos: bool = True, + use_global_vector: bool = False, + separate_global_qkv: bool = False, + global_dim_ratio: int = 1, + attn_linear_init_mode: str = "0", + ffn_linear_init_mode: str = "0", + norm_init_mode: str = "0", + ): + super(StackCuboidCrossAttentionBlock, self).__init__() + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.norm_init_mode = norm_init_mode + if ( + len(block_cuboid_hw[0]) <= 0 + or len(block_shift_hw) <= 0 + or len(block_strategy) <= 0 + ): + raise ValueError( + "Incorrect format.The lengths of block_cuboid_hw[0], block_shift_hw, and block_strategy must be greater than zero." + ) + if len(block_cuboid_hw) != len(block_shift_hw) and len(block_shift_hw) == len( + block_strategy + ): + raise ValueError( + "The lengths of block_cuboid_size, block_shift_size, and block_strategy must be equal." + ) + + self.num_attn = len(block_cuboid_hw) + self.checkpoint_level = checkpoint_level + self.use_inter_ffn = use_inter_ffn + self.use_global_vector = use_global_vector + if self.use_inter_ffn: + self.ffn_l = paddle.nn.LayerList( + sublayers=[ + cuboid_encoder.PositionwiseFFN( + units=dim, + hidden_size=4 * dim, + activation_dropout=ffn_drop, + dropout=ffn_drop, + gated_proj=gated_ffn, + activation=activation, + normalization=norm_layer, + pre_norm=True, + linear_init_mode=ffn_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(self.num_attn) + ] + ) + else: + self.ffn_l = paddle.nn.LayerList( + sublayers=[ + cuboid_encoder.PositionwiseFFN( + units=dim, + hidden_size=4 * dim, + activation_dropout=ffn_drop, + dropout=ffn_drop, + gated_proj=gated_ffn, + activation=activation, + normalization=norm_layer, + pre_norm=True, + linear_init_mode=ffn_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + ] + ) + self.attn_l = paddle.nn.LayerList( + sublayers=[ + CuboidCrossAttentionLayer( + dim=dim, + num_heads=num_heads, + cuboid_hw=ele_cuboid_hw, + shift_hw=ele_shift_hw, + strategy=ele_strategy, + n_temporal=ele_n_temporal, + cross_last_n_frames=cross_last_n_frames, + padding_type=padding_type, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + max_temporal_relative=max_temporal_relative, + use_global_vector=use_global_vector, + separate_global_qkv=separate_global_qkv, + global_dim_ratio=global_dim_ratio, + checkpoint_level=checkpoint_level, + use_relative_pos=use_relative_pos, + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=ffn_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for ele_cuboid_hw, ele_shift_hw, ele_strategy, ele_n_temporal in zip( + block_cuboid_hw, block_shift_hw, block_strategy, block_n_temporal + ) + ] + ) + + def reset_parameters(self): + for m in self.ffn_l: + m.reset_parameters() + for m in self.attn_l: + m.reset_parameters() + + def forward(self, x, mem, mem_global_vector=None): + """ + Args: + x (paddle.Tensor): Shape (B, T_x, H, W, C) + mem (paddle.Tensor): Shape (B, T_mem, H, W, C) + mem_global_vector (paddle.Tensor): Shape (B, N_global, C) + + Returns: + out 
(paddle.Tensor): (B, T_x, H, W, C_out) + """ + + if self.use_inter_ffn: + for attn, ffn in zip(self.attn_l, self.ffn_l): + if self.checkpoint_level >= 2 and self.training: + x = x + fleet.utils.recompute(attn, x, mem, mem_global_vector) + else: + x = x + attn(x, mem, mem_global_vector) + if self.checkpoint_level >= 1 and self.training: + x = fleet.utils.recompute(ffn, x) + else: + x = ffn(x) + return x + else: + for attn in self.attn_l: + if self.checkpoint_level >= 2 and self.training: + x = x + fleet.utils.recompute(attn, x, mem, mem_global_vector) + else: + x = x + attn(x, mem, mem_global_vector) + if self.checkpoint_level >= 1 and self.training: + x = fleet.utils.recompute(self.ffn_l[0], x) + else: + x = self.ffn_l[0](x) + return x + + +class Upsample3DLayer(paddle.nn.Layer): + """Upsampling based on nn.UpSampling and Conv3x3. + + If the temporal dimension remains the same: + x --> interpolation-2d (nearest) --> conv3x3(dim, out_dim) + Else: + x --> interpolation-3d (nearest) --> conv3x3x3(dim, out_dim) + + Args: + dim (int): The dimension of the input tensor. + out_dim (int): The dimension of the output tensor. + target_size (Tuple[int,...]): The size of output tensor. + temporal_upsample (bool, optional): Whether the temporal axis will go through upsampling. Defaults to False. + kernel_size (int, optional): The kernel size of the Conv2D layer. Defaults to 3. + layout (str, optional): The layout of the inputs. Defaults to "THWC". + conv_init_mode (str, optional): The mode of conv initialization. Defaults to "0". + """ + + def __init__( + self, + dim: int, + out_dim: int, + target_size: Tuple[int, ...], + temporal_upsample: bool = False, + kernel_size: int = 3, + layout: str = "THWC", + conv_init_mode: str = "0", + ): + super(Upsample3DLayer, self).__init__() + self.conv_init_mode = conv_init_mode + self.target_size = target_size + self.out_dim = out_dim + self.temporal_upsample = temporal_upsample + if temporal_upsample: + self.up = paddle.nn.Upsample(size=target_size, mode="nearest") + else: + self.up = paddle.nn.Upsample( + size=(target_size[1], target_size[2]), mode="nearest" + ) + self.conv = paddle.nn.Conv2D( + in_channels=dim, + out_channels=out_dim, + kernel_size=(kernel_size, kernel_size), + padding=(kernel_size // 2, kernel_size // 2), + ) + assert layout in ["THWC", "CTHW"] + self.layout = layout + self.reset_parameters() + + def reset_parameters(self): + for m in self.children(): + cuboid_utils.apply_initialization(m, conv_mode=self.conv_init_mode) + + def forward(self, x): + """ + + Args: + x : (B, T, H, W, C) or (B, C, T, H, W) + + Returns: + out : (B, T, H_new, W_out, C_out) or (B, C, T, H_out, W_out) + """ + + if self.layout == "THWC": + B, T, H, W, C = x.shape + if self.temporal_upsample: + x = x.transpose(perm=[0, 4, 1, 2, 3]) + return self.conv(self.up(x)).transpose(perm=[0, 2, 3, 4, 1]) + else: + assert self.target_size[0] == T + x = x.reshape([B * T, H, W, C]).transpose(perm=[0, 3, 1, 2]) + x = self.up(x) + return ( + self.conv(x) + .transpose(perm=[0, 2, 3, 1]) + .reshape(list((B,) + self.target_size + (self.out_dim,))) + ) + elif self.layout == "CTHW": + B, C, T, H, W = x.shape + if self.temporal_upsample: + return self.conv(self.up(x)) + else: + assert self.output_size[0] == T + x = x.transpose(perm=[0, 2, 1, 3, 4]) + x = x.reshape([B * T, C, H, W]) + return ( + self.conv(self.up(x)) + .reshape( + [ + B, + self.target_size[0], + self.out_dim, + self.target_size[1], + self.target_size[2], + ] + ) + .transpose(perm=[0, 2, 1, 3, 4]) + ) + + +class 
CuboidTransformerDecoder(paddle.nn.Layer): + """Decoder of the CuboidTransformer. + + For each block, we first apply the StackCuboidSelfAttention and then apply the StackCuboidCrossAttention + + Repeat the following structure K times + + x --> StackCuboidSelfAttention --> | + |----> StackCuboidCrossAttention (If used) --> out + mem --> | + + Args: + target_temporal_length (int): The temporal length of the target. + mem_shapes (Tuple[int,...]): The mem shapes of the decoder. + cross_start (int, optional): The block to start cross attention. Defaults to 0. + depth (list, optional): The number of layers for each block. Defaults to [2, 2]. + upsample_type (str, optional): The type of upsample. Defaults to "upsample". + upsample_kernel_size (int, optional): The kernel size of upsample. Defaults to 3. + block_self_attn_patterns (str, optional): The patterns of block attention. Defaults to None. + block_self_cuboid_size (list, optional): The size of cuboid block. Defaults to [(4, 4, 4), (4, 4, 4)]. + block_self_cuboid_strategy (list, optional): The strategy of cuboid. Defaults to [("l", "l", "l"), ("d", "d", "d")]. + block_self_shift_size (list, optional): The size of shift. Defaults to [(1, 1, 1), (0, 0, 0)]. + block_cross_attn_patterns (str, optional): The patterns of cross attentions. Defaults to None. + block_cross_cuboid_hw (list, optional): The height and width of cross cuboid. Defaults to [(4, 4), (4, 4)]. + block_cross_cuboid_strategy (list, optional): The strategy of cross cuboid. Defaults to [("l", "l", "l"), ("d", "l", "l")]. + block_cross_shift_hw (list, optional): The height and width of cross shift. Defaults to [(0, 0), (0, 0)]. + block_cross_n_temporal (list, optional): The cross temporal of block. Defaults to [1, 2]. + cross_last_n_frames (int, optional): The num of cross last frames. Defaults to None. + num_heads (int, optional): The num of head. Defaults to 4. + attn_drop (float, optional): The ratio of attention dropout. Defaults to 0.0. + proj_drop (float, optional): The ratio of projection dropout. Defaults to 0.0. + ffn_drop (float, optional): The ratio of FFN dropout. Defaults to 0.0. + ffn_activation (str, optional): The activation layer of FFN. Defaults to "leaky". + gated_ffn (bool, optional): Whether to use gate FFN. Defaults to False. + norm_layer (str, optional): The normalization layer. Defaults to "layer_norm". + use_inter_ffn (bool, optional): Whether to use inter FFN. Defaults to False. + hierarchical_pos_embed (bool, optional): Whether to use hierarchical pos_embed. Defaults to False. + pos_embed_type (str, optional): The type of pos embeding. Defaults to "t+hw". + max_temporal_relative (int, optional): The max number of teemporal relative. Defaults to 50. + padding_type (str, optional): The type of padding. Defaults to "ignore". + checkpoint_level (bool, optional): Whether to enable gradient checkpointing. Defaults to True. + use_relative_pos (bool, optional): Whether to use relative pos. Defaults to True. + self_attn_use_final_proj (bool, optional): Whether to use self attention for final projection. Defaults to True. + use_first_self_attn (bool, optional): Whether to use first self attention. Defaults to False. + use_self_global (bool, optional): Whether to use self global vector. Defaults to False. + self_update_global (bool, optional): Whether to update global vector. Defaults to True. + use_cross_global (bool, optional): Whether to use cross global vector. Defaults to False. + use_global_vector_ffn (bool, optional): Whether to use FFN global vectors. 
Defaults to True. + use_global_self_attn (bool, optional): Whether to use global self attention. Defaults to False. + separate_global_qkv (bool, optional): Whether to use different network to calc q_global, k_global, v_global. Defaults to False. + global_dim_ratio (int, optional): The dim (channels) of global vectors is `global_dim_ratio*dim`. Defaults to 1. + attn_linear_init_mode (str, optional): The mode of attention linear initialization. Defaults to "0". + ffn_linear_init_mode (str, optional): The mode of FFN linear initialization. Defaults to "0". + conv_init_mode (str, optional): The mode of conv initialization. Defaults to "0". + up_linear_init_mode (str, optional): The mode of up linear initialization. Defaults to "0". + norm_init_mode (str, optional): The mode of normalization initialization. Defaults to "0". + """ + + def __init__( + self, + target_temporal_length: int, + mem_shapes: Tuple[int, ...], + cross_start: int = 0, + depth: Tuple[int, ...] = [2, 2], + upsample_type: str = "upsample", + upsample_kernel_size: int = 3, + block_self_attn_patterns: str = None, + block_self_cuboid_size: Tuple[Tuple[int, ...], ...] = [(4, 4, 4), (4, 4, 4)], + block_self_cuboid_strategy: Tuple[Tuple[str, ...], ...] = [ + ("l", "l", "l"), + ("d", "d", "d"), + ], + block_self_shift_size: Tuple[Tuple[int, ...], ...] = [(1, 1, 1), (0, 0, 0)], + block_cross_attn_patterns: str = None, + block_cross_cuboid_hw: Tuple[Tuple[int, ...], ...] = [(4, 4), (4, 4)], + block_cross_cuboid_strategy: Tuple[Tuple[str, ...], ...] = [ + ("l", "l", "l"), + ("d", "l", "l"), + ], + block_cross_shift_hw: Tuple[Tuple[int, ...], ...] = [(0, 0), (0, 0)], + block_cross_n_temporal: Tuple[int, ...] = [1, 2], + cross_last_n_frames: int = None, + num_heads: int = 4, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ffn_drop: float = 0.0, + ffn_activation: str = "leaky", + gated_ffn: bool = False, + norm_layer: str = "layer_norm", + use_inter_ffn: bool = False, + hierarchical_pos_embed: bool = False, + pos_embed_type: str = "t+hw", + max_temporal_relative: int = 50, + padding_type: str = "ignore", + checkpoint_level: bool = True, + use_relative_pos: bool = True, + self_attn_use_final_proj: bool = True, + use_first_self_attn: bool = False, + use_self_global: bool = False, + self_update_global: bool = True, + use_cross_global: bool = False, + use_global_vector_ffn: bool = True, + use_global_self_attn: bool = False, + separate_global_qkv: bool = False, + global_dim_ratio: int = 1, + attn_linear_init_mode: str = "0", + ffn_linear_init_mode: str = "0", + conv_init_mode: str = "0", + up_linear_init_mode: str = "0", + norm_init_mode: str = "0", + ): + super(CuboidTransformerDecoder, self).__init__() + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.conv_init_mode = conv_init_mode + self.up_linear_init_mode = up_linear_init_mode + self.norm_init_mode = norm_init_mode + assert len(depth) == len(mem_shapes) + self.target_temporal_length = target_temporal_length + self.num_blocks = len(mem_shapes) + self.cross_start = cross_start + self.mem_shapes = mem_shapes + self.depth = depth + self.upsample_type = upsample_type + self.hierarchical_pos_embed = hierarchical_pos_embed + self.checkpoint_level = checkpoint_level + self.use_self_global = use_self_global + self.self_update_global = self_update_global + self.use_cross_global = use_cross_global + self.use_global_vector_ffn = use_global_vector_ffn + self.use_first_self_attn = use_first_self_attn + if block_self_attn_patterns is 
not None: + if isinstance(block_self_attn_patterns, (tuple, list)): + assert len(block_self_attn_patterns) == self.num_blocks + else: + block_self_attn_patterns = [ + block_self_attn_patterns for _ in range(self.num_blocks) + ] + block_self_cuboid_size = [] + block_self_cuboid_strategy = [] + block_self_shift_size = [] + for idx, key in enumerate(block_self_attn_patterns): + func = cuboid_utils.CuboidSelfAttentionPatterns.get(key) + cuboid_size, strategy, shift_size = func(mem_shapes[idx]) + block_self_cuboid_size.append(cuboid_size) + block_self_cuboid_strategy.append(strategy) + block_self_shift_size.append(shift_size) + else: + if not isinstance(block_self_cuboid_size[0][0], (list, tuple)): + block_self_cuboid_size = [ + block_self_cuboid_size for _ in range(self.num_blocks) + ] + else: + assert ( + len(block_self_cuboid_size) == self.num_blocks + ), f"Incorrect input format! Received block_self_cuboid_size={block_self_cuboid_size}" + if not isinstance(block_self_cuboid_strategy[0][0], (list, tuple)): + block_self_cuboid_strategy = [ + block_self_cuboid_strategy for _ in range(self.num_blocks) + ] + else: + assert ( + len(block_self_cuboid_strategy) == self.num_blocks + ), f"Incorrect input format! Received block_self_cuboid_strategy={block_self_cuboid_strategy}" + if not isinstance(block_self_shift_size[0][0], (list, tuple)): + block_self_shift_size = [ + block_self_shift_size for _ in range(self.num_blocks) + ] + else: + assert ( + len(block_self_shift_size) == self.num_blocks + ), f"Incorrect input format! Received block_self_shift_size={block_self_shift_size}" + self_blocks = [] + for i in range(self.num_blocks): + if not self.use_first_self_attn and i == self.num_blocks - 1: + ele_depth = depth[i] - 1 + else: + ele_depth = depth[i] + stack_cuboid_blocks = [ + cuboid_encoder.StackCuboidSelfAttentionBlock( + dim=self.mem_shapes[i][-1], + num_heads=num_heads, + block_cuboid_size=block_self_cuboid_size[i], + block_strategy=block_self_cuboid_strategy[i], + block_shift_size=block_self_shift_size[i], + attn_drop=attn_drop, + proj_drop=proj_drop, + ffn_drop=ffn_drop, + activation=ffn_activation, + gated_ffn=gated_ffn, + norm_layer=norm_layer, + use_inter_ffn=use_inter_ffn, + padding_type=padding_type, + use_global_vector=use_self_global, + use_global_vector_ffn=use_global_vector_ffn, + use_global_self_attn=use_global_self_attn, + separate_global_qkv=separate_global_qkv, + global_dim_ratio=global_dim_ratio, + checkpoint_level=checkpoint_level, + use_relative_pos=use_relative_pos, + use_final_proj=self_attn_use_final_proj, + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=ffn_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(ele_depth) + ] + self_blocks.append(paddle.nn.LayerList(sublayers=stack_cuboid_blocks)) + self.self_blocks = paddle.nn.LayerList(sublayers=self_blocks) + if block_cross_attn_patterns is not None: + if isinstance(block_cross_attn_patterns, (tuple, list)): + assert len(block_cross_attn_patterns) == self.num_blocks + else: + block_cross_attn_patterns = [ + block_cross_attn_patterns for _ in range(self.num_blocks) + ] + block_cross_cuboid_hw = [] + block_cross_cuboid_strategy = [] + block_cross_shift_hw = [] + block_cross_n_temporal = [] + for idx, key in enumerate(block_cross_attn_patterns): + if key == "last_frame_dst": + cuboid_hw = None + shift_hw = None + strategy = None + n_temporal = None + else: + func = cuboid_utils.CuboidCrossAttentionPatterns.get(key) + cuboid_hw, shift_hw, strategy, n_temporal = func(mem_shapes[idx]) + 
block_cross_cuboid_hw.append(cuboid_hw) + block_cross_cuboid_strategy.append(strategy) + block_cross_shift_hw.append(shift_hw) + block_cross_n_temporal.append(n_temporal) + else: + if not isinstance(block_cross_cuboid_hw[0][0], (list, tuple)): + block_cross_cuboid_hw = [ + block_cross_cuboid_hw for _ in range(self.num_blocks) + ] + else: + assert ( + len(block_cross_cuboid_hw) == self.num_blocks + ), f"Incorrect input format! Received block_cross_cuboid_hw={block_cross_cuboid_hw}" + if not isinstance(block_cross_cuboid_strategy[0][0], (list, tuple)): + block_cross_cuboid_strategy = [ + block_cross_cuboid_strategy for _ in range(self.num_blocks) + ] + else: + assert ( + len(block_cross_cuboid_strategy) == self.num_blocks + ), f"Incorrect input format! Received block_cross_cuboid_strategy={block_cross_cuboid_strategy}" + if not isinstance(block_cross_shift_hw[0][0], (list, tuple)): + block_cross_shift_hw = [ + block_cross_shift_hw for _ in range(self.num_blocks) + ] + else: + assert ( + len(block_cross_shift_hw) == self.num_blocks + ), f"Incorrect input format! Received block_cross_shift_hw={block_cross_shift_hw}" + if not isinstance(block_cross_n_temporal[0], (list, tuple)): + block_cross_n_temporal = [ + block_cross_n_temporal for _ in range(self.num_blocks) + ] + else: + assert ( + len(block_cross_n_temporal) == self.num_blocks + ), f"Incorrect input format! Received block_cross_n_temporal={block_cross_n_temporal}" + self.cross_blocks = paddle.nn.LayerList() + for i in range(self.cross_start, self.num_blocks): + cross_block = paddle.nn.LayerList( + sublayers=[ + StackCuboidCrossAttentionBlock( + dim=self.mem_shapes[i][-1], + num_heads=num_heads, + block_cuboid_hw=block_cross_cuboid_hw[i], + block_strategy=block_cross_cuboid_strategy[i], + block_shift_hw=block_cross_shift_hw[i], + block_n_temporal=block_cross_n_temporal[i], + cross_last_n_frames=cross_last_n_frames, + attn_drop=attn_drop, + proj_drop=proj_drop, + ffn_drop=ffn_drop, + gated_ffn=gated_ffn, + norm_layer=norm_layer, + use_inter_ffn=use_inter_ffn, + activation=ffn_activation, + max_temporal_relative=max_temporal_relative, + padding_type=padding_type, + use_global_vector=use_cross_global, + separate_global_qkv=separate_global_qkv, + global_dim_ratio=global_dim_ratio, + checkpoint_level=checkpoint_level, + use_relative_pos=use_relative_pos, + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=ffn_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(depth[i]) + ] + ) + self.cross_blocks.append(cross_block) + if self.num_blocks > 1: + if self.upsample_type == "upsample": + self.upsample_layers = paddle.nn.LayerList( + sublayers=[ + Upsample3DLayer( + dim=self.mem_shapes[i + 1][-1], + out_dim=self.mem_shapes[i][-1], + target_size=(target_temporal_length,) + + self.mem_shapes[i][1:3], + kernel_size=upsample_kernel_size, + temporal_upsample=False, + conv_init_mode=conv_init_mode, + ) + for i in range(self.num_blocks - 1) + ] + ) + else: + raise NotImplementedError(f"{self.upsample_type} is invalid.") + if self.hierarchical_pos_embed: + self.hierarchical_pos_embed_l = paddle.nn.LayerList( + sublayers=[ + PosEmbed( + embed_dim=self.mem_shapes[i][-1], + typ=pos_embed_type, + maxT=target_temporal_length, + maxH=self.mem_shapes[i][1], + maxW=self.mem_shapes[i][2], + ) + for i in range(self.num_blocks - 1) + ] + ) + self.reset_parameters() + + def reset_parameters(self): + for ms in self.self_blocks: + for m in ms: + m.reset_parameters() + for ms in self.cross_blocks: + for m in ms: + m.reset_parameters() 
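+        # Upsample layers (and the optional hierarchical position embeddings) are only
+        # created for hierarchical decoders with more than one block, so their
+        # parameters are re-initialized only in that case below.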
+ if self.num_blocks > 1: + for m in self.upsample_layers: + m.reset_parameters() + if self.hierarchical_pos_embed: + for m in self.hierarchical_pos_embed_l: + m.reset_parameters() + + def forward(self, x, mem_l, mem_global_vector_l=None): + """ + Args: + x : Shape (B, T_top, H_top, W_top, C). + mem_l : A list of memory tensors. + """ + + B, T_top, H_top, W_top, C = x.shape + assert T_top == self.target_temporal_length + assert (H_top, W_top) == (self.mem_shapes[-1][1], self.mem_shapes[-1][2]) + for i in range(self.num_blocks - 1, -1, -1): + mem_global_vector = ( + None if mem_global_vector_l is None else mem_global_vector_l[i] + ) + if not self.use_first_self_attn and i == self.num_blocks - 1: + if i >= self.cross_start: + x = self.cross_blocks[i - self.cross_start][0]( + x, mem_l[i], mem_global_vector + ) + for idx in range(self.depth[i] - 1): + if self.use_self_global: + if self.self_update_global: + x, mem_global_vector = self.self_blocks[i][idx]( + x, mem_global_vector + ) + else: + x, _ = self.self_blocks[i][idx](x, mem_global_vector) + else: + x = self.self_blocks[i][idx](x) + if i >= self.cross_start: + x = self.cross_blocks[i - self.cross_start][idx + 1]( + x, mem_l[i], mem_global_vector + ) + else: + for idx in range(self.depth[i]): + if self.use_self_global: + if self.self_update_global: + x, mem_global_vector = self.self_blocks[i][idx]( + x, mem_global_vector + ) + else: + x, _ = self.self_blocks[i][idx](x, mem_global_vector) + else: + x = self.self_blocks[i][idx](x) + if i >= self.cross_start: + x = self.cross_blocks[i - self.cross_start][idx]( + x, mem_l[i], mem_global_vector + ) + if i > 0: + x = self.upsample_layers[i - 1](x) + if self.hierarchical_pos_embed: + x = self.hierarchical_pos_embed_l[i - 1](x) + return x diff --git a/ppsci/arch/cuboid_transformer_encoder.py b/ppsci/arch/cuboid_transformer_encoder.py new file mode 100644 index 000000000..e4af19f9f --- /dev/null +++ b/ppsci/arch/cuboid_transformer_encoder.py @@ -0,0 +1,1520 @@ +from collections import OrderedDict +from functools import lru_cache +from typing import Tuple + +import numpy as np +import paddle +import paddle.nn.functional as F +from paddle import nn +from paddle.distributed import fleet + +import ppsci.arch.cuboid_transformer_utils as cuboid_utils +from ppsci.arch import activation as act_mod +from ppsci.utils import initializer + +NEGATIVE_SLOPE = 0.1 + + +class PatchMerging3D(paddle.nn.Layer): + """Patch Merging Layer + + Args: + dim (int): Number of input channels. + out_dim (int, optional): The dim of output. Defaults to None. + downsample (tuple, optional): Downsample factor. Defaults to (1, 2, 2). + norm_layer (str, optional): The normalization layer. Defaults to "layer_norm". + padding_type (str, optional): The type of padding. Defaults to "nearest". + linear_init_mode (str, optional): The mode of linear init. Defaults to "0". + norm_init_mode (str, optional): The mode of normalization init. Defaults to "0". + + """ + + def __init__( + self, + dim: int, + out_dim: int = None, + downsample: Tuple[int, ...] 
= (1, 2, 2),
+        norm_layer: str = "layer_norm",
+        padding_type: str = "nearest",
+        linear_init_mode: str = "0",
+        norm_init_mode: str = "0",
+    ):
+        super().__init__()
+        self.linear_init_mode = linear_init_mode
+        self.norm_init_mode = norm_init_mode
+        self.dim = dim
+        if out_dim is None:
+            out_dim = max(downsample) * dim
+        self.out_dim = out_dim
+        self.downsample = downsample
+        self.padding_type = padding_type
+        self.reduction = paddle.nn.Linear(
+            in_features=downsample[0] * downsample[1] * downsample[2] * dim,
+            out_features=out_dim,
+            bias_attr=False,
+        )
+        self.norm = cuboid_utils.get_norm_layer(
+            norm_layer, in_channels=downsample[0] * downsample[1] * downsample[2] * dim
+        )
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for m in self.children():
+            cuboid_utils.apply_initialization(
+                m, linear_mode=self.linear_init_mode, norm_mode=self.norm_init_mode
+            )
+
+    def get_out_shape(self, data_shape):
+        T, H, W, C_in = data_shape
+        pad_t = (self.downsample[0] - T % self.downsample[0]) % self.downsample[0]
+        pad_h = (self.downsample[1] - H % self.downsample[1]) % self.downsample[1]
+        pad_w = (self.downsample[2] - W % self.downsample[2]) % self.downsample[2]
+        return (
+            (T + pad_t) // self.downsample[0],
+            (H + pad_h) // self.downsample[1],
+            (W + pad_w) // self.downsample[2],
+            self.out_dim,
+        )
+
+    def forward(self, x):
+        """
+
+        Args:
+            x : (B, T, H, W, C)
+
+        Returns:
+            out : Shape (B, T // downsample[0], H // downsample[1], W // downsample[2], out_dim)
+        """
+
+        B, T, H, W, C = x.shape
+        pad_t = (self.downsample[0] - T % self.downsample[0]) % self.downsample[0]
+        pad_h = (self.downsample[1] - H % self.downsample[1]) % self.downsample[1]
+        pad_w = (self.downsample[2] - W % self.downsample[2]) % self.downsample[2]
+        if pad_t or pad_h or pad_w:
+            T += pad_t
+            H += pad_h
+            W += pad_w
+            x = cuboid_utils.generalize_padding(
+                x, pad_t, pad_h, pad_w, padding_type=self.padding_type
+            )
+        x = (
+            x.reshape(
+                (
+                    B,
+                    T // self.downsample[0],
+                    self.downsample[0],
+                    H // self.downsample[1],
+                    self.downsample[1],
+                    W // self.downsample[2],
+                    self.downsample[2],
+                    C,
+                )
+            )
+            .transpose(perm=[0, 1, 3, 5, 2, 4, 6, 7])
+            .reshape(
+                [
+                    B,
+                    T // self.downsample[0],
+                    H // self.downsample[1],
+                    W // self.downsample[2],
+                    self.downsample[0] * self.downsample[1] * self.downsample[2] * C,
+                ]
+            )
+        )
+        x = self.norm(x)
+        x = self.reduction(x)
+        return x
+
+
+class PositionwiseFFN(paddle.nn.Layer):
+    """The Position-wise FFN layer used in Transformer-like architectures.
+
+    If pre_norm is True:
+        norm(data) -> fc1 -> act -> act_dropout -> fc2 -> dropout -> res(+data)
+    Else:
+        data -> fc1 -> act -> act_dropout -> fc2 -> dropout -> norm(res(+data))
+    Also, if we use a gated projection, we will use
+        fc1_1 * act(fc1_2(data)) to map the data.
+
+    Args:
+        units (int, optional): The number of input/output units. Defaults to 512.
+        hidden_size (int, optional): The size of the hidden layer. Defaults to 2048.
+        activation_dropout (float, optional): The dropout rate applied after the activation. Defaults to 0.0.
+        dropout (float, optional): The dropout rate applied after the second linear layer. Defaults to 0.1.
+        gated_proj (bool, optional): Whether to use the gated projection. Defaults to False.
+        activation (str, optional): The activation function. Defaults to "relu".
+        normalization (str, optional): The normalization layer. Defaults to "layer_norm".
+        layer_norm_eps (float, optional): The epsilon of layer normalization. Defaults to 1e-05.
+ pre_norm (bool): Pre-layer normalization as proposed in the paper: + "[ACL2018] The Best of Both Worlds: Combining Recent Advances in Neural Machine Translation" This will stabilize the training of Transformers. + You may also refer to "[Arxiv2020] Understanding the Difficulty of Training Transformers". Defaults to False. + linear_init_mode (str, optional): The mode of linear initialization. Defaults to "0". + norm_init_mode (str, optional): The mode of normalization initialization. Defaults to "0". + """ + + def __init__( + self, + units: int = 512, + hidden_size: int = 2048, + activation_dropout: float = 0.0, + dropout: float = 0.1, + gated_proj: bool = False, + activation: str = "relu", + normalization: str = "layer_norm", + layer_norm_eps: float = 1e-05, + pre_norm: bool = False, + linear_init_mode: str = "0", + norm_init_mode: str = "0", + ): + super().__init__() + self.linear_init_mode = linear_init_mode + self.norm_init_mode = norm_init_mode + self._pre_norm = pre_norm + self._gated_proj = gated_proj + self._kwargs = OrderedDict( + [ + ("units", units), + ("hidden_size", hidden_size), + ("activation_dropout", activation_dropout), + ("activation", activation), + ("dropout", dropout), + ("normalization", normalization), + ("layer_norm_eps", layer_norm_eps), + ("gated_proj", gated_proj), + ("pre_norm", pre_norm), + ] + ) + self.dropout_layer = paddle.nn.Dropout(p=dropout) + self.activation_dropout_layer = paddle.nn.Dropout(p=activation_dropout) + self.ffn_1 = paddle.nn.Linear( + in_features=units, out_features=hidden_size, bias_attr=True + ) + if self._gated_proj: + self.ffn_1_gate = paddle.nn.Linear( + in_features=units, out_features=hidden_size, bias_attr=True + ) + if activation == "leaky_relu": + self.activation = nn.LeakyReLU(NEGATIVE_SLOPE) + else: + self.activation = act_mod.get_activation(activation) + self.ffn_2 = paddle.nn.Linear( + in_features=hidden_size, out_features=units, bias_attr=True + ) + self.layer_norm = cuboid_utils.get_norm_layer( + normalization=normalization, in_channels=units, epsilon=layer_norm_eps + ) + self.reset_parameters() + + def reset_parameters(self): + cuboid_utils.apply_initialization(self.ffn_1, linear_mode=self.linear_init_mode) + if self._gated_proj: + cuboid_utils.apply_initialization( + self.ffn_1_gate, linear_mode=self.linear_init_mode + ) + cuboid_utils.apply_initialization(self.ffn_2, linear_mode=self.linear_init_mode) + cuboid_utils.apply_initialization( + self.layer_norm, norm_mode=self.norm_init_mode + ) + + def forward(self, data): + """ + Args: + x : Shape (B, seq_length, C_in) + + Returns: + out : Shape (B, seq_length, C_out) + """ + + residual = data + if self._pre_norm: + data = self.layer_norm(data) + if self._gated_proj: + out = self.activation(self.ffn_1_gate(data)) * self.ffn_1(data) + else: + out = self.activation(self.ffn_1(data)) + out = self.activation_dropout_layer(out) + out = self.ffn_2(out) + out = self.dropout_layer(out) + out = out + residual + if not self._pre_norm: + out = self.layer_norm(out) + return out + + +def update_cuboid_size_shift_size(data_shape, cuboid_size, shift_size, strategy): + """Update the cuboid_size and shift_size + + Args: + data_shape (Tuple[int,...]): The shape of the data. + cuboid_size (Tuple[int,...]): Size of the cuboid. + shift_size (Tuple[int,...]): Size of the shift. + strategy (str): The strategy of attention. + + Returns: + new_cuboid_size (Tuple[int,...]): Size of the cuboid. + new_shift_size (Tuple[int,...]): Size of the shift. 
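+
+    Examples:
+        >>> # Illustrative values only: a cuboid longer than the temporal axis is
+        >>> # clipped to that axis and its shift is dropped.
+        >>> from ppsci.arch.cuboid_transformer_encoder import update_cuboid_size_shift_size
+        >>> update_cuboid_size_shift_size(
+        ...     (5, 16, 16), (10, 4, 4), (5, 2, 2), ("l", "l", "l")
+        ... )
+        ((5, 4, 4), (0, 2, 2))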
+ """ + + new_cuboid_size = list(cuboid_size) + new_shift_size = list(shift_size) + for i in range(len(data_shape)): + if strategy[i] == "d": + new_shift_size[i] = 0 + if data_shape[i] <= cuboid_size[i]: + new_cuboid_size[i] = data_shape[i] + new_shift_size[i] = 0 + return tuple(new_cuboid_size), tuple(new_shift_size) + + +def cuboid_reorder(data, cuboid_size, strategy): + """Reorder the tensor into (B, num_cuboids, bT * bH * bW, C) + We assume that the tensor shapes are divisible to the cuboid sizes. + + Args: + data (paddle.Tensor): The input data. + cuboid_size (Tuple[int,...]): The size of the cuboid. + strategy (Tuple[int,...]): The cuboid strategy. + + Returns: + reordered_data (paddle.Tensor): Shape will be (B, num_cuboids, bT * bH * bW, C). + num_cuboids = T / bT * H / bH * W / bW + """ + + B, T, H, W, C = data.shape + num_cuboids = T // cuboid_size[0] * H // cuboid_size[1] * W // cuboid_size[2] + cuboid_volume = cuboid_size[0] * cuboid_size[1] * cuboid_size[2] + intermediate_shape = [] + nblock_axis = [] + block_axis = [] + for i, (block_size, total_size, ele_strategy) in enumerate( + zip(cuboid_size, (T, H, W), strategy) + ): + if ele_strategy == "l": + intermediate_shape.extend([total_size // block_size, block_size]) + nblock_axis.append(2 * i + 1) + block_axis.append(2 * i + 2) + elif ele_strategy == "d": + intermediate_shape.extend([block_size, total_size // block_size]) + nblock_axis.append(2 * i + 2) + block_axis.append(2 * i + 1) + else: + raise NotImplementedError(f"{ele_strategy} is invalid.") + data = data.reshape(list((B,) + tuple(intermediate_shape) + (C,))) + reordered_data = data.transpose( + perm=(0,) + tuple(nblock_axis) + tuple(block_axis) + (7,) + ) + reordered_data = reordered_data.reshape((B, num_cuboids, cuboid_volume, C)) + return reordered_data + + +@lru_cache() +def compute_cuboid_self_attention_mask( + data_shape, cuboid_size, shift_size, strategy, padding_type, device +): + """Compute the shift window attention mask + + Args: + data_shape (Tuple[int,....]): Should be (T, H, W). + cuboid_size (Tuple[int,....]): Size of the cuboid. + shift_size (Tuple[int,....]): The shift size. + strategy (str): The decomposition strategy. + padding_type (str): Type of the padding. + device (str): The device. + + Returns: + attn_mask (paddle.Tensor): Mask with shape (num_cuboid, cuboid_vol, cuboid_vol). + The padded values will always be masked. The other masks will ensure that the shifted windows + will only attend to those in the shifted windows. 
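+            For example (illustrative values only), with data_shape (4, 8, 8),
+            cuboid_size (2, 4, 4), no shift and no padding required, the input is split
+            into (4 / 2) * (8 / 4) * (8 / 4) = 8 cuboids of volume 2 * 4 * 4 = 32, so the
+            returned mask has shape (8, 32, 32) with no position masked.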
+ """ + T, H, W = data_shape + pad_t = (cuboid_size[0] - T % cuboid_size[0]) % cuboid_size[0] + pad_h = (cuboid_size[1] - H % cuboid_size[1]) % cuboid_size[1] + pad_w = (cuboid_size[2] - W % cuboid_size[2]) % cuboid_size[2] + data_mask = None + if pad_t > 0 or pad_h > 0 or pad_w > 0: + if padding_type == "ignore": + data_mask = paddle.ones(shape=(1, T, H, W, 1), dtype="bool") + data_mask = F.pad( + data_mask, [0, 0, 0, pad_w, 0, pad_h, 0, pad_t], data_format="NDHWC" + ) + else: + data_mask = paddle.ones( + shape=(1, T + pad_t, H + pad_h, W + pad_w, 1), dtype="bool" + ) + if any(i > 0 for i in shift_size): + if padding_type == "ignore": + data_mask = paddle.roll( + x=data_mask, + shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), + axis=(1, 2, 3), + ) + if padding_type == "ignore": + data_mask = cuboid_reorder(data_mask, cuboid_size, strategy=strategy) + data_mask = data_mask.squeeze(axis=-1).squeeze(axis=0) + shift_mask = np.zeros(shape=(1, T + pad_t, H + pad_h, W + pad_w, 1)) + cnt = 0 + for t in ( + slice(-cuboid_size[0]), + slice(-cuboid_size[0], -shift_size[0]), + slice(-shift_size[0], None), + ): + for h in ( + slice(-cuboid_size[1]), + slice(-cuboid_size[1], -shift_size[1]), + slice(-shift_size[1], None), + ): + for w in ( + slice(-cuboid_size[2]), + slice(-cuboid_size[2], -shift_size[2]), + slice(-shift_size[2], None), + ): + shift_mask[:, t, h, w, :] = cnt + cnt += 1 + shift_mask = paddle.to_tensor(shift_mask) + shift_mask = cuboid_reorder(shift_mask, cuboid_size, strategy=strategy) + shift_mask = shift_mask.squeeze(axis=-1).squeeze(axis=0) + attn_mask = shift_mask.unsqueeze(axis=1) - shift_mask.unsqueeze(axis=2) == 0 + if padding_type == "ignore": + attn_mask = ( + data_mask.unsqueeze(axis=1) * data_mask.unsqueeze(axis=2) * attn_mask + ) + return attn_mask + + +def masked_softmax(att_score, mask, axis: int = -1): + """Ignore the masked elements when calculating the softmax. + The mask can be broadcastable. + + Args: + att_score (paddle.Tensor): Shape (..., length, ...) + mask (paddle.Tensor): Shape (..., length, ...) + 1 --> The element is not masked + 0 --> The element is masked + axis (int): The axis to calculate the softmax. att_score.shape[axis] must be the same as mask.shape[axis] + + Returns: + att_weights (paddle.Tensor): Shape (..., length, ...). + """ + + if mask is not None: + if att_score.dtype == paddle.float16: + att_score = att_score.masked_fill(paddle.logical_not(mask), -1e4) + else: + att_score = att_score.masked_fill(paddle.logical_not(mask), -1e18) + att_weights = paddle.nn.functional.softmax(x=att_score, axis=axis) * mask + else: + att_weights = paddle.nn.functional.softmax(x=att_score, axis=axis) + return att_weights + + +def cuboid_reorder_reverse(data, cuboid_size, strategy, orig_data_shape): + """Reverse the reordered cuboid back to the original space + + Args: + data (paddle.Tensor): The input data. + cuboid_size (Tuple[int,...]): The size of cuboid. + strategy (str): The strategy of reordering. + orig_data_shape (Tuple[int,...]): The original shape of the data. 
+ + Returns: + data (paddle.Tensor): The recovered data + """ + + B, num_cuboids, cuboid_volume, C = data.shape + T, H, W = orig_data_shape + permutation_axis = [0] + for i, (block_size, total_size, ele_strategy) in enumerate( + zip(cuboid_size, (T, H, W), strategy) + ): + if ele_strategy == "l": + permutation_axis.append(i + 1) + permutation_axis.append(i + 4) + elif ele_strategy == "d": + permutation_axis.append(i + 4) + permutation_axis.append(i + 1) + else: + raise NotImplementedError((f"{ele_strategy} is invalid.")) + permutation_axis.append(7) + data = data.reshape( + [ + B, + T // cuboid_size[0], + H // cuboid_size[1], + W // cuboid_size[2], + cuboid_size[0], + cuboid_size[1], + cuboid_size[2], + C, + ] + ) + data = data.transpose(perm=permutation_axis) + data = data.reshape((B, T, H, W, C)) + return data + + +class CuboidSelfAttentionLayer(paddle.nn.Layer): + """Implements the cuboid self attention. + + The idea of Cuboid Self Attention is to divide the input tensor (T, H, W) into several non-overlapping cuboids. + We apply self-attention inside each cuboid and all cuboid-level self attentions are executed in parallel. + + We adopt two mechanisms for decomposing the input tensor into cuboids: + + 1) local: + We group the tensors within a local window, e.g., X[t:(t+b_t), h:(h+b_h), w:(w+b_w)]. We can also apply the + shifted window strategy proposed in "[ICCV2021] Swin Transformer: Hierarchical Vision Transformer using Shifted Windows". + 2) dilated: + Inspired by the success of dilated convolution "[ICLR2016] Multi-Scale Context Aggregation by Dilated Convolutions", + we split the tensor with dilation factors that are tied to the size of the cuboid. For example, for a cuboid that has width `b_w`, + we sample the elements starting from 0 as 0, w / b_w, 2 * w / b_w, ..., (b_w - 1) * w / b_w. + + The cuboid attention can be viewed as a generalization of the attention mechanism proposed in Video Swin Transformer, https://arxiv.org/abs/2106.13230. + The computational complexity of CuboidAttention can be simply calculated as O(T H W * b_t b_h b_w). To cover multiple correlation patterns, + we are able to combine multiple CuboidAttention layers with different configurations such as cuboid size, shift size, and local / global decomposing strategy. + + In addition, it is straight-forward to extend the cuboid attention to other types of spatiotemporal data that are not described + as regular tensors. We need to define alternative approaches to partition the data into "cuboids". + + In addition, inspired by "[NeurIPS2021] Do Transformers Really Perform Badly for Graph Representation?", + "[NeurIPS2020] Big Bird: Transformers for Longer Sequences", "[EMNLP2021] Longformer: The Long-Document Transformer", we keep + $K$ global vectors to record the global status of the spatiotemporal system. These global vectors will attend to the whole tensor and + the vectors inside each individual cuboids will also attend to the global vectors so that they can peep into the global status of the system. + + Args: + dim (int): The dimension of the input tensor. + num_heads (int): The number of heads. + cuboid_size (tuple, optional): The size of cuboid. Defaults to (2, 7, 7). + shift_size (tuple, optional): The size of shift. Defaults to (0, 0, 0). + strategy (tuple, optional): The strategy. Defaults to ("l", "l", "l"). + padding_type (str, optional): The type of padding. Defaults to "ignore". + qkv_bias (bool, optional): Whether to enable bias in calculating qkv attention. Defaults to False. 
+ qk_scale (float, optional): Whether to enable scale factor when calculating the attention. Defaults to None. + attn_drop (float, optional): The attention dropout. Defaults to 0.0. + proj_drop (float, optional): The projection dropout. Defaults to 0.0. + use_final_proj (bool, optional): Whether to use the final projection. Defaults to True. + norm_layer (str, optional): The normalization layer. Defaults to "layer_norm". + use_global_vector (bool, optional): Whether to use the global vector or not. Defaults to False. + use_global_self_attn (bool, optional): Whether to use self attention among global vectors. Defaults to False. + separate_global_qkv (bool, optional): Whether to use different network to calc q_global, k_global, v_global. Defaults to False. + global_dim_ratio (int, optional): The dim (channels) of global vectors is `global_dim_ratio*dim`. Defaults to 1. + checkpoint_level (bool, optional): Whether to enable gradient checkpointing. Defaults to True. + use_relative_pos (bool, optional): Whether to use relative pos. Defaults to True. + attn_linear_init_mode (str, optional): The mode of attention linear initialization. Defaults to "0". + ffn_linear_init_mode (str, optional): The mode of FFN linear initialization. Defaults to "0". + norm_init_mode (str, optional): The mode of normalization initialization. Defaults to "0". + """ + + def __init__( + self, + dim: int, + num_heads: int, + cuboid_size: Tuple[int, ...] = (2, 7, 7), + shift_size: Tuple[int, ...] = (0, 0, 0), + strategy: Tuple[str, ...] = ("l", "l", "l"), + padding_type: str = "ignore", + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + use_final_proj: bool = True, + norm_layer: str = "layer_norm", + use_global_vector: bool = False, + use_global_self_attn: bool = False, + separate_global_qkv: bool = False, + global_dim_ratio: int = 1, + checkpoint_level: bool = True, + use_relative_pos: bool = True, + attn_linear_init_mode: str = "0", + ffn_linear_init_mode: str = "0", + norm_init_mode: str = "0", + ): + super(CuboidSelfAttentionLayer, self).__init__() + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.norm_init_mode = norm_init_mode + assert dim % num_heads == 0 + self.num_heads = num_heads + self.dim = dim + self.cuboid_size = cuboid_size + self.shift_size = shift_size + self.strategy = strategy + self.padding_type = padding_type + self.use_final_proj = use_final_proj + self.use_relative_pos = use_relative_pos + self.use_global_vector = use_global_vector + self.use_global_self_attn = use_global_self_attn + self.separate_global_qkv = separate_global_qkv + if global_dim_ratio != 1: + assert ( + separate_global_qkv is True + ), "Setting global_dim_ratio != 1 requires separate_global_qkv == True." 
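+        # Attention logits are scaled by head_dim ** -0.5 unless qk_scale is given. When
+        # use_relative_pos is enabled, a learnable per-head relative position bias is
+        # added, indexed by the relative (t, h, w) offset between two positions inside a
+        # cuboid, so the table below holds (2*bT - 1) * (2*bH - 1) * (2*bW - 1) entries
+        # per head, following the shifted-window scheme of Swin Transformer.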
+ self.global_dim_ratio = global_dim_ratio + assert self.padding_type in ["ignore", "zeros", "nearest"] + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + if use_relative_pos: + init_data = paddle.zeros( + ( + (2 * cuboid_size[0] - 1) + * (2 * cuboid_size[1] - 1) + * (2 * cuboid_size[2] - 1), + num_heads, + ) + ) + self.relative_position_bias_table = paddle.create_parameter( + shape=init_data.shape, + dtype=init_data.dtype, + default_initializer=nn.initializer.Constant(0.0), + ) + self.relative_position_bias_table.stop_gradient = not True + self.relative_position_bias_table = initializer.trunc_normal_( + self.relative_position_bias_table, std=0.02 + ) + + coords_t = paddle.arange(end=self.cuboid_size[0]) + coords_h = paddle.arange(end=self.cuboid_size[1]) + coords_w = paddle.arange(end=self.cuboid_size[2]) + coords = paddle.stack(x=paddle.meshgrid(coords_t, coords_h, coords_w)) + coords_flatten = paddle.flatten(x=coords, start_axis=1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.transpose(perm=[1, 2, 0]) + relative_coords[:, :, 0] += self.cuboid_size[0] - 1 + relative_coords[:, :, 1] += self.cuboid_size[1] - 1 + relative_coords[:, :, 2] += self.cuboid_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.cuboid_size[1] - 1) * ( + 2 * self.cuboid_size[2] - 1 + ) + relative_coords[:, :, 1] *= 2 * self.cuboid_size[2] - 1 + relative_position_index = relative_coords.sum(axis=-1) + self.register_buffer( + name="relative_position_index", tensor=relative_position_index + ) + self.qkv = paddle.nn.Linear( + in_features=dim, out_features=dim * 3, bias_attr=qkv_bias + ) + self.attn_drop = paddle.nn.Dropout(p=attn_drop) + if self.use_global_vector: + if self.separate_global_qkv: + self.l2g_q_net = paddle.nn.Linear( + in_features=dim, out_features=dim, bias_attr=qkv_bias + ) + self.l2g_global_kv_net = paddle.nn.Linear( + in_features=global_dim_ratio * dim, + out_features=dim * 2, + bias_attr=qkv_bias, + ) + self.g2l_global_q_net = paddle.nn.Linear( + in_features=global_dim_ratio * dim, + out_features=dim, + bias_attr=qkv_bias, + ) + self.g2l_k_net = paddle.nn.Linear( + in_features=dim, out_features=dim, bias_attr=qkv_bias + ) + self.g2l_v_net = paddle.nn.Linear( + in_features=dim, + out_features=global_dim_ratio * dim, + bias_attr=qkv_bias, + ) + if self.use_global_self_attn: + self.g2g_global_qkv_net = paddle.nn.Linear( + in_features=global_dim_ratio * dim, + out_features=global_dim_ratio * dim * 3, + bias_attr=qkv_bias, + ) + else: + self.global_qkv = paddle.nn.Linear( + in_features=dim, out_features=dim * 3, bias_attr=qkv_bias + ) + self.global_attn_drop = paddle.nn.Dropout(p=attn_drop) + if use_final_proj: + self.proj = paddle.nn.Linear(in_features=dim, out_features=dim) + self.proj_drop = paddle.nn.Dropout(p=proj_drop) + if self.use_global_vector: + self.global_proj = paddle.nn.Linear( + in_features=global_dim_ratio * dim, + out_features=global_dim_ratio * dim, + ) + self.norm = cuboid_utils.get_norm_layer(norm_layer, in_channels=dim) + if self.use_global_vector: + self.global_vec_norm = cuboid_utils.get_norm_layer( + norm_layer, in_channels=global_dim_ratio * dim + ) + self.checkpoint_level = checkpoint_level + self.reset_parameters() + + def reset_parameters(self): + cuboid_utils.apply_initialization( + self.qkv, linear_mode=self.attn_linear_init_mode + ) + if self.use_final_proj: + cuboid_utils.apply_initialization( + self.proj, linear_mode=self.ffn_linear_init_mode + ) + 
cuboid_utils.apply_initialization(self.norm, norm_mode=self.norm_init_mode) + if self.use_global_vector: + if self.separate_global_qkv: + cuboid_utils.apply_initialization( + self.l2g_q_net, linear_mode=self.attn_linear_init_mode + ) + cuboid_utils.apply_initialization( + self.l2g_global_kv_net, linear_mode=self.attn_linear_init_mode + ) + cuboid_utils.apply_initialization( + self.g2l_global_q_net, linear_mode=self.attn_linear_init_mode + ) + cuboid_utils.apply_initialization( + self.g2l_k_net, linear_mode=self.attn_linear_init_mode + ) + cuboid_utils.apply_initialization( + self.g2l_v_net, linear_mode=self.attn_linear_init_mode + ) + if self.use_global_self_attn: + cuboid_utils.apply_initialization( + self.g2g_global_qkv_net, linear_mode=self.attn_linear_init_mode + ) + else: + cuboid_utils.apply_initialization( + self.global_qkv, linear_mode=self.attn_linear_init_mode + ) + cuboid_utils.apply_initialization( + self.global_vec_norm, norm_mode=self.norm_init_mode + ) + + def forward(self, x, global_vectors=None): + x = self.norm(x) + + B, T, H, W, C_in = x.shape + assert C_in == self.dim + if self.use_global_vector: + _, num_global, _ = global_vectors.shape + global_vectors = self.global_vec_norm(global_vectors) + cuboid_size, shift_size = update_cuboid_size_shift_size( + (T, H, W), self.cuboid_size, self.shift_size, self.strategy + ) + + pad_t = (cuboid_size[0] - T % cuboid_size[0]) % cuboid_size[0] + pad_h = (cuboid_size[1] - H % cuboid_size[1]) % cuboid_size[1] + pad_w = (cuboid_size[2] - W % cuboid_size[2]) % cuboid_size[2] + x = cuboid_utils.generalize_padding(x, pad_t, pad_h, pad_w, self.padding_type) + + if any(i > 0 for i in shift_size): + shifted_x = paddle.roll( + x=x, + shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), + axis=(1, 2, 3), + ) + else: + shifted_x = x + + reordered_x = cuboid_reorder( + shifted_x, cuboid_size=cuboid_size, strategy=self.strategy + ) + + _, num_cuboids, cuboid_volume, _ = reordered_x.shape + attn_mask = compute_cuboid_self_attention_mask( + (T, H, W), + cuboid_size, + shift_size=shift_size, + strategy=self.strategy, + padding_type=self.padding_type, + device=x.place, + ) + head_C = C_in // self.num_heads + qkv = ( + self.qkv(reordered_x) + .reshape([B, num_cuboids, cuboid_volume, 3, self.num_heads, head_C]) + .transpose(perm=[3, 0, 4, 1, 2, 5]) + ) + + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + perm_0 = list(range(k.ndim)) + perm_0[-2] = -1 + perm_0[-1] = -2 + attn_score = q @ k.transpose(perm=perm_0) + + if self.use_relative_pos: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:cuboid_volume, :cuboid_volume].reshape( + [-1] + ) + ].reshape([cuboid_volume, cuboid_volume, -1]) + relative_position_bias = relative_position_bias.transpose( + perm=[2, 0, 1] + ).unsqueeze(axis=1) + attn_score = attn_score + relative_position_bias + + if self.use_global_vector: + global_head_C = self.global_dim_ratio * head_C + if self.separate_global_qkv: + l2g_q = ( + self.l2g_q_net(reordered_x) + .reshape([B, num_cuboids, cuboid_volume, self.num_heads, head_C]) + .transpose(perm=[0, 3, 1, 2, 4]) + ) + l2g_q = l2g_q * self.scale + l2g_global_kv = ( + self.l2g_global_kv_net(global_vectors) + .reshape([B, 1, num_global, 2, self.num_heads, head_C]) + .transpose(perm=[3, 0, 4, 1, 2, 5]) + ) + l2g_global_k, l2g_global_v = l2g_global_kv[0], l2g_global_kv[1] + g2l_global_q = ( + self.g2l_global_q_net(global_vectors) + .reshape([B, num_global, self.num_heads, head_C]) + .transpose(perm=[0, 2, 1, 3]) + ) + 
g2l_global_q = g2l_global_q * self.scale + g2l_k = ( + self.g2l_k_net(reordered_x) + .reshape([B, num_cuboids, cuboid_volume, self.num_heads, head_C]) + .transpose(perm=[0, 3, 1, 2, 4]) + ) + g2l_v = ( + self.g2l_v_net(reordered_x) + .reshape( + [B, num_cuboids, cuboid_volume, self.num_heads, global_head_C] + ) + .transpose(perm=[0, 3, 1, 2, 4]) + ) + if self.use_global_self_attn: + g2g_global_qkv = ( + self.g2g_global_qkv_net(global_vectors) + .reshape([B, 1, num_global, 3, self.num_heads, global_head_C]) + .transpose(perm=[3, 0, 4, 1, 2, 5]) + ) + g2g_global_q, g2g_global_k, g2g_global_v = ( + g2g_global_qkv[0], + g2g_global_qkv[1], + g2g_global_qkv[2], + ) + g2g_global_q = g2g_global_q.squeeze(axis=2) * self.scale + else: + q_global, k_global, v_global = ( + self.global_qkv(global_vectors) + .reshape([B, 1, num_global, 3, self.num_heads, head_C]) + .transpose(perm=[3, 0, 4, 1, 2, 5]) + ) + q_global = q_global.squeeze(axis=2) * self.scale + l2g_q, g2l_k, g2l_v = q, k, v + g2l_global_q, l2g_global_k, l2g_global_v = ( + q_global, + k_global, + v_global, + ) + if self.use_global_self_attn: + g2g_global_q, g2g_global_k, g2g_global_v = ( + q_global, + k_global, + v_global, + ) + + perm_1 = list(range(l2g_global_k.ndim)) + perm_1[-2] = -1 + perm_1[-1] = -2 + l2g_attn_score = l2g_q @ l2g_global_k.transpose(perm=perm_1) + attn_score_l2l_l2g = paddle.concat(x=(attn_score, l2g_attn_score), axis=-1) + + if attn_mask.ndim == 5: + attn_mask_l2l_l2g = F.pad( + attn_mask, [0, num_global], "constant", 1, data_format="NDHWC" + ) + elif attn_mask.ndim == 3: + attn_mask = attn_mask.astype("float32") + attn_mask_l2l_l2g = F.pad( + attn_mask, [0, num_global], "constant", 1, data_format="NCL" + ) + attn_mask_l2l_l2g = attn_mask_l2l_l2g.astype("bool") + else: + attn_mask_l2l_l2g = F.pad(attn_mask, [0, num_global], "constant", 1) + + v_l_g = paddle.concat( + x=( + v, + l2g_global_v.expand( + shape=[B, self.num_heads, num_cuboids, num_global, head_C] + ), + ), + axis=3, + ) + attn_score_l2l_l2g = masked_softmax( + attn_score_l2l_l2g, mask=attn_mask_l2l_l2g + ) + attn_score_l2l_l2g = self.attn_drop(attn_score_l2l_l2g) + reordered_x = ( + (attn_score_l2l_l2g @ v_l_g) + .transpose(perm=[0, 2, 3, 1, 4]) + .reshape([B, num_cuboids, cuboid_volume, self.dim]) + ) + if self.padding_type == "ignore": + g2l_attn_mask = paddle.ones(shape=(1, T, H, W, 1)) + if pad_t > 0 or pad_h > 0 or pad_w > 0: + g2l_attn_mask = F.pad( + g2l_attn_mask, + [0, 0, 0, pad_w, 0, pad_h, 0, pad_t], + data_format="NDHWC", + ) + if any(i > 0 for i in shift_size): + g2l_attn_mask = paddle.roll( + x=g2l_attn_mask, + shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), + axis=(1, 2, 3), + ) + g2l_attn_mask = g2l_attn_mask.reshape((-1,)) + else: + g2l_attn_mask = None + temp = g2l_k.reshape( + [B, self.num_heads, num_cuboids * cuboid_volume, head_C] + ) + perm_2 = list(range(temp.ndim)) + perm_2[-2] = -1 + perm_2[-1] = -2 + g2l_attn_score = g2l_global_q @ temp.transpose(perm=perm_2) + if self.use_global_self_attn: + temp = g2g_global_k.squeeze(axis=2) + perm_3 = list(range(temp.ndim)) + perm_3[-2] = -1 + perm_3[-1] = -2 + g2g_attn_score = g2g_global_q @ temp.transpose(perm=perm_3) + g2all_attn_score = paddle.concat( + x=(g2l_attn_score, g2g_attn_score), axis=-1 + ) + if g2l_attn_mask is not None: + g2all_attn_mask = F.pad( + g2l_attn_mask, + [0, num_global], + "constant", + 1, + data_format="NDHWC", + ) + else: + g2all_attn_mask = None + new_v = paddle.concat( + x=( + g2l_v.reshape( + [ + B, + self.num_heads, + num_cuboids * cuboid_volume, + 
global_head_C, + ] + ), + g2g_global_v.reshape( + [B, self.num_heads, num_global, global_head_C] + ), + ), + axis=2, + ) + else: + g2all_attn_score = g2l_attn_score + g2all_attn_mask = g2l_attn_mask + new_v = g2l_v.reshape( + [B, self.num_heads, num_cuboids * cuboid_volume, global_head_C] + ) + g2all_attn_score = masked_softmax(g2all_attn_score, mask=g2all_attn_mask) + g2all_attn_score = self.global_attn_drop(g2all_attn_score) + new_global_vector = ( + (g2all_attn_score @ new_v) + .transpose(perm=[0, 2, 1, 3]) + .reshape([B, num_global, self.global_dim_ratio * self.dim]) + ) + else: + attn_score = masked_softmax(attn_score, mask=attn_mask) + attn_score = self.attn_drop(attn_score) + reordered_x = ( + (attn_score @ v) + .transpose(perm=[0, 2, 3, 1, 4]) + .reshape([B, num_cuboids, cuboid_volume, self.dim]) + ) + + if self.use_final_proj: + reordered_x = paddle.cast(reordered_x, dtype="float32") + reordered_x = self.proj_drop(self.proj(reordered_x)) + if self.use_global_vector: + new_global_vector = self.proj_drop(self.global_proj(new_global_vector)) + shifted_x = cuboid_reorder_reverse( + reordered_x, + cuboid_size=cuboid_size, + strategy=self.strategy, + orig_data_shape=(T + pad_t, H + pad_h, W + pad_w), + ) + if any(i > 0 for i in shift_size): + x = paddle.roll( + x=shifted_x, + shifts=(shift_size[0], shift_size[1], shift_size[2]), + axis=(1, 2, 3), + ) + else: + x = shifted_x + x = cuboid_utils.generalize_unpadding( + x, pad_t=pad_t, pad_h=pad_h, pad_w=pad_w, padding_type=self.padding_type + ) + if self.use_global_vector: + return x, new_global_vector + else: + return x + + +class StackCuboidSelfAttentionBlock(paddle.nn.Layer): + """ + - "use_inter_ffn" is True + x --> attn1 -----+-------> ffn1 ---+---> attn2 --> ... --> ffn_k --> out + | ^ | ^ + | | | | + |-------------| |-------------| + - "use_inter_ffn" is False + x --> attn1 -----+------> attn2 --> ... attnk --+----> ffnk ---+---> out + | ^ | ^ ^ | ^ + | | | | | | | + |-------------| |------------| ----------| |-----------| + If we have enabled global memory vectors, each attention will be a + + Args: + dim (int): The dimension of the input tensor. + num_heads (int): The number of heads. + block_cuboid_size (list, optional): The size of block cuboid . Defaults to [(4, 4, 4), (4, 4, 4)]. + block_shift_size (list, optional): The shift size of block. Defaults to [(0, 0, 0), (2, 2, 2)]. + block_strategy (list, optional): The strategy of block. Defaults to [("d", "d", "d"), ("l", "l", "l")]. + padding_type (str, optional): The type of padding. Defaults to "ignore". + qkv_bias (bool, optional): Whether to enable bias in calculating qkv attention. Defaults to False. + qk_scale (float, optional): Whether to enable scale factor when calculating the attention. Defaults to None. + attn_drop (float, optional): The attention dropout. Defaults to 0.0. + proj_drop (float, optional): The projection dropout. Defaults to 0.0. + use_final_proj (bool, optional): Whether to use the final projection. Defaults to True. + norm_layer (str, optional): The normalization layer. Defaults to "layer_norm". + use_global_vector (bool, optional): Whether to use the global vector or not. Defaults to False. + use_global_self_attn (bool, optional): Whether to use self attention among global vectors. Defaults to False. + separate_global_qkv (bool, optional): Whether to use different network to calc q_global, k_global, v_global. + Defaults to False. + global_dim_ratio (int, optional): The dim (channels) of global vectors is `global_dim_ratio*dim`. + Defaults to 1. 
+ checkpoint_level (bool, optional): Whether to enable gradient checkpointing. Defaults to True. + use_relative_pos (bool, optional): Whether to use relative pos. Defaults to True. + attn_linear_init_mode (str, optional): The mode of attention linear initialization. Defaults to "0". + ffn_linear_init_mode (str, optional): The mode of FFN linear initialization. Defaults to "0". + norm_init_mode (str, optional): The mode of normalization initialization. Defaults to "0". + + """ + + def __init__( + self, + dim: int, + num_heads: int, + block_cuboid_size: Tuple[Tuple[int, ...], ...] = [(4, 4, 4), (4, 4, 4)], + block_shift_size: Tuple[Tuple[int, ...], ...] = [(0, 0, 0), (2, 2, 2)], + block_strategy: Tuple[Tuple[str, ...], ...] = [ + ("d", "d", "d"), + ("l", "l", "l"), + ], + padding_type: str = "ignore", + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ffn_drop: float = 0.0, + activation: str = "leaky", + gated_ffn: bool = False, + norm_layer: str = "layer_norm", + use_inter_ffn: bool = False, + use_global_vector: bool = False, + use_global_vector_ffn: bool = True, + use_global_self_attn: bool = False, + separate_global_qkv: bool = False, + global_dim_ratio: int = 1, + checkpoint_level: bool = True, + use_relative_pos: bool = True, + use_final_proj: bool = True, + attn_linear_init_mode: str = "0", + ffn_linear_init_mode: str = "0", + norm_init_mode: str = "0", + ): + super(StackCuboidSelfAttentionBlock, self).__init__() + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.norm_init_mode = norm_init_mode + if ( + len(block_cuboid_size[0]) <= 0 + or len(block_shift_size) <= 0 + or len(block_strategy) <= 0 + ): + raise ValueError( + f"Format of the block cuboid size is not correct. block_cuboid_size={block_cuboid_size}" + ) + if len(block_cuboid_size) != len(block_shift_size) or len( + block_cuboid_size + ) != len(block_strategy): + raise ValueError( + "The lengths of block_cuboid_size, block_shift_size, and block_strategy must be equal."
+ ) + + self.num_attn = len(block_cuboid_size) + self.checkpoint_level = checkpoint_level + self.use_inter_ffn = use_inter_ffn + self.use_global_vector = use_global_vector + self.use_global_vector_ffn = use_global_vector_ffn + self.use_global_self_attn = use_global_self_attn + self.global_dim_ratio = global_dim_ratio + if self.use_inter_ffn: + self.ffn_l = paddle.nn.LayerList( + sublayers=[ + PositionwiseFFN( + units=dim, + hidden_size=4 * dim, + activation_dropout=ffn_drop, + dropout=ffn_drop, + gated_proj=gated_ffn, + activation=activation, + normalization=norm_layer, + pre_norm=True, + linear_init_mode=ffn_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(self.num_attn) + ] + ) + if self.use_global_vector_ffn and self.use_global_vector: + self.global_ffn_l = paddle.nn.LayerList( + sublayers=[ + PositionwiseFFN( + units=global_dim_ratio * dim, + hidden_size=global_dim_ratio * 4 * dim, + activation_dropout=ffn_drop, + dropout=ffn_drop, + gated_proj=gated_ffn, + activation=activation, + normalization=norm_layer, + pre_norm=True, + linear_init_mode=ffn_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(self.num_attn) + ] + ) + else: + self.ffn_l = paddle.nn.LayerList( + sublayers=[ + PositionwiseFFN( + units=dim, + hidden_size=4 * dim, + activation_dropout=ffn_drop, + dropout=ffn_drop, + gated_proj=gated_ffn, + activation=activation, + normalization=norm_layer, + pre_norm=True, + linear_init_mode=ffn_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + ] + ) + if self.use_global_vector_ffn and self.use_global_vector: + self.global_ffn_l = paddle.nn.LayerList( + sublayers=[ + PositionwiseFFN( + units=global_dim_ratio * dim, + hidden_size=global_dim_ratio * 4 * dim, + activation_dropout=ffn_drop, + dropout=ffn_drop, + gated_proj=gated_ffn, + activation=activation, + normalization=norm_layer, + pre_norm=True, + linear_init_mode=ffn_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + ] + ) + self.attn_l = paddle.nn.LayerList( + sublayers=[ + CuboidSelfAttentionLayer( + dim=dim, + num_heads=num_heads, + cuboid_size=ele_cuboid_size, + shift_size=ele_shift_size, + strategy=ele_strategy, + padding_type=padding_type, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + use_global_vector=use_global_vector, + use_global_self_attn=use_global_self_attn, + separate_global_qkv=separate_global_qkv, + global_dim_ratio=global_dim_ratio, + checkpoint_level=checkpoint_level, + use_relative_pos=use_relative_pos, + use_final_proj=use_final_proj, + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=ffn_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for ele_cuboid_size, ele_shift_size, ele_strategy in zip( + block_cuboid_size, block_shift_size, block_strategy + ) + ] + ) + + def reset_parameters(self): + for m in self.ffn_l: + m.reset_parameters() + if self.use_global_vector_ffn and self.use_global_vector: + for m in self.global_ffn_l: + m.reset_parameters() + for m in self.attn_l: + m.reset_parameters() + + def forward(self, x, global_vectors=None): + if self.use_inter_ffn: + if self.use_global_vector: + for idx, (attn, ffn) in enumerate(zip(self.attn_l, self.ffn_l)): + if self.checkpoint_level >= 2 and self.training: + x_out, global_vectors_out = fleet.utils.recompute( + attn, x, global_vectors + ) + else: + x_out, global_vectors_out = attn(x, global_vectors) + x = x + x_out + global_vectors = global_vectors + global_vectors_out + if self.checkpoint_level >= 1 and 
self.training: + x = fleet.utils.recompute(ffn, x) + if self.use_global_vector_ffn: + global_vectors = fleet.utils.recompute( + self.global_ffn_l[idx], global_vectors + ) + else: + x = ffn(x) + if self.use_global_vector_ffn: + global_vectors = self.global_ffn_l[idx](global_vectors) + return x, global_vectors + else: + for idx, (attn, ffn) in enumerate(zip(self.attn_l, self.ffn_l)): + if self.checkpoint_level >= 2 and self.training: + x = x + fleet.utils.recompute(attn, x) + else: + x = x + attn(x) + if self.checkpoint_level >= 1 and self.training: + x = fleet.utils.recompute(ffn, x) + else: + x = ffn(x) + return x + elif self.use_global_vector: + for idx, attn in enumerate(self.attn_l): + if self.checkpoint_level >= 2 and self.training: + x_out, global_vectors_out = fleet.utils.recompute( + attn, x, global_vectors + ) + else: + x_out, global_vectors_out = attn(x, global_vectors) + x = x + x_out + global_vectors = global_vectors + global_vectors_out + if self.checkpoint_level >= 1 and self.training: + x = fleet.utils.recompute(self.ffn_l[0], x) + if self.use_global_vector_ffn: + global_vectors = fleet.utils.recompute( + self.global_ffn_l[0], global_vectors + ) + else: + x = self.ffn_l[0](x) + if self.use_global_vector_ffn: + global_vectors = self.global_ffn_l[0](global_vectors) + return x, global_vectors + else: + for idx, attn in enumerate(self.attn_l): + if self.checkpoint_level >= 2 and self.training: + out = fleet.utils.recompute(attn, x) + else: + out = attn(x) + x = x + out + if self.checkpoint_level >= 1 and self.training: + x = fleet.utils.recompute(self.ffn_l[0], x) + else: + x = self.ffn_l[0](x) + return x + + +class CuboidTransformerEncoder(paddle.nn.Layer): + """Encoder of the CuboidTransformer + + x --> attn_block --> patch_merge --> attn_block --> patch_merge --> ... --> out + + Args: + input_shape (Tuple[int,...]): The shape of the input. Contains T, H, W, C + base_units (int, optional): The number of units. Defaults to 128. + block_units (int, optional): The number of block units. Defaults to None. + scale_alpha (float, optional): We scale up the channels based on the formula: + - round_to(base_units * max(downsample_scale) ** units_alpha, 4). Defaults to 1.0. + depth (list, optional): The number of layers for each block. Defaults to [4, 4, 4]. + downsample (int, optional): The downsample ratio. Defaults to 2. + downsample_type (str, optional): The type of downsample. Defaults to "patch_merge". + block_attn_patterns (str, optional): Attention pattern for the cuboid attention for each block. Defaults to None. + block_cuboid_size (list, optional): A list of cuboid size parameters. Defaults to [(4, 4, 4), (4, 4, 4)]. + block_strategy (list, optional): A list of cuboid strategies. Defaults to [("l", "l", "l"), ("d", "d", "d")]. + block_shift_size (list, optional): A list of shift sizes. Defaults to [(0, 0, 0), (0, 0, 0)]. + num_heads (int, optional): The number of heads. Defaults to 4. + attn_drop (float, optional): The ratio of attention dropout. Defaults to 0.0. + proj_drop (float, optional): The ratio of projection dropout. Defaults to 0.0. + ffn_drop (float, optional): The ratio of FFN dropout. Defaults to 0.0. + ffn_activation (str, optional): The FFN activation. Defaults to "leaky". + gated_ffn (bool, optional): Whether to use gate FFN. Defaults to False. + norm_layer (str, optional): The normalization layer. Defaults to "layer_norm". + use_inter_ffn (bool, optional): Whether to use inter FFN. Defaults to True. + padding_type (str, optional): The type of padding. 
Defaults to "ignore". + checkpoint_level (bool, optional): Whether to enable gradient checkpointing. Defaults to True. + use_relative_pos (bool, optional): Whether to use relative pos. Defaults to True. + self_attn_use_final_proj (bool, optional): Whether to use self attention for final projection. Defaults to True. + use_global_vector (bool, optional): Whether to use the global vector or not. Defaults to False. + use_global_vector_ffn (bool, optional): Whether to use FFN global vectors. Defaults to False. + use_global_self_attn (bool, optional): Whether to use global self attention. Defaults to False. + separate_global_qkv (bool, optional): Whether to use different network to calc q_global, k_global, v_global. + Defaults to False. + global_dim_ratio (int, optional): The dim (channels) of global vectors is `global_dim_ratio*dim`. + Defaults to 1. + attn_linear_init_mode (str, optional): The mode of attention linear initialization. Defaults to "0". + ffn_linear_init_mode (str, optional): The mode of FFN linear initialization. Defaults to "0". + conv_init_mode (str, optional): The mode of conv initialization. Defaults to "0". + down_linear_init_mode (str, optional): The mode of downsample linear initialization. Defaults to "0". + norm_init_mode (str, optional): The mode of normalization. Defaults to "0". + + """ + + def __init__( + self, + input_shape: Tuple[int, ...], + base_units: int = 128, + block_units: int = None, + scale_alpha: float = 1.0, + depth: Tuple[int, ...] = [4, 4, 4], + downsample: int = 2, + downsample_type: str = "patch_merge", + block_attn_patterns: str = None, + block_cuboid_size: Tuple[Tuple[int, ...], ...] = [(4, 4, 4), (4, 4, 4)], + block_strategy: Tuple[Tuple[str, ...], ...] = [ + ("l", "l", "l"), + ("d", "d", "d"), + ], + block_shift_size: Tuple[Tuple[int, ...], ...] 
= [(0, 0, 0), (0, 0, 0)], + num_heads: int = 4, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ffn_drop: float = 0.0, + ffn_activation: str = "leaky", + gated_ffn: bool = False, + norm_layer: str = "layer_norm", + use_inter_ffn: bool = True, + padding_type: str = "ignore", + checkpoint_level: bool = True, + use_relative_pos: bool = True, + self_attn_use_final_proj: bool = True, + use_global_vector: bool = False, + use_global_vector_ffn: bool = True, + use_global_self_attn: bool = False, + separate_global_qkv: bool = False, + global_dim_ratio: int = 1, + attn_linear_init_mode: str = "0", + ffn_linear_init_mode: str = "0", + conv_init_mode: str = "0", + down_linear_init_mode: str = "0", + norm_init_mode: str = "0", + ): + super(CuboidTransformerEncoder, self).__init__() + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.conv_init_mode = conv_init_mode + self.down_linear_init_mode = down_linear_init_mode + self.norm_init_mode = norm_init_mode + self.input_shape = input_shape + self.depth = depth + self.num_blocks = len(depth) + self.base_units = base_units + self.scale_alpha = scale_alpha + if not isinstance(downsample, (tuple, list)): + downsample = 1, downsample, downsample + self.downsample = downsample + self.downsample_type = downsample_type + self.num_heads = num_heads + self.use_global_vector = use_global_vector + self.checkpoint_level = checkpoint_level + if block_units is None: + block_units = [ + cuboid_utils.round_to( + base_units * int((max(downsample) ** scale_alpha) ** i), 4 + ) + for i in range(self.num_blocks) + ] + else: + assert len(block_units) == self.num_blocks and block_units[0] == base_units + self.block_units = block_units + if self.num_blocks > 1: + if downsample_type == "patch_merge": + self.down_layers = paddle.nn.LayerList( + sublayers=[ + PatchMerging3D( + dim=self.block_units[i], + downsample=downsample, + padding_type=padding_type, + out_dim=self.block_units[i + 1], + linear_init_mode=down_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for i in range(self.num_blocks - 1) + ] + ) + else: + raise NotImplementedError(f"{downsample_type} is invalid.") + if self.use_global_vector: + self.down_layer_global_proj = paddle.nn.LayerList( + sublayers=[ + paddle.nn.Linear( + in_features=global_dim_ratio * self.block_units[i], + out_features=global_dim_ratio * self.block_units[i + 1], + ) + for i in range(self.num_blocks - 1) + ] + ) + if block_attn_patterns is not None: + mem_shapes = self.get_mem_shapes() + if isinstance(block_attn_patterns, (tuple, list)): + assert len(block_attn_patterns) == self.num_blocks + else: + block_attn_patterns = [ + block_attn_patterns for _ in range(self.num_blocks) + ] + block_cuboid_size = [] + block_strategy = [] + block_shift_size = [] + for idx, key in enumerate(block_attn_patterns): + func = cuboid_utils.CuboidSelfAttentionPatterns.get(key) + cuboid_size, strategy, shift_size = func(mem_shapes[idx]) + block_cuboid_size.append(cuboid_size) + block_strategy.append(strategy) + block_shift_size.append(shift_size) + else: + if not isinstance(block_cuboid_size[0][0], (list, tuple)): + block_cuboid_size = [block_cuboid_size for _ in range(self.num_blocks)] + else: + assert ( + len(block_cuboid_size) == self.num_blocks + ), f"Incorrect input format! 
Received block_cuboid_size={block_cuboid_size}" + if not isinstance(block_strategy[0][0], (list, tuple)): + block_strategy = [block_strategy for _ in range(self.num_blocks)] + else: + assert ( + len(block_strategy) == self.num_blocks + ), f"Incorrect input format! Received block_strategy={block_strategy}" + if not isinstance(block_shift_size[0][0], (list, tuple)): + block_shift_size = [block_shift_size for _ in range(self.num_blocks)] + else: + assert ( + len(block_shift_size) == self.num_blocks + ), f"Incorrect input format! Received block_shift_size={block_shift_size}" + self.block_cuboid_size = block_cuboid_size + self.block_strategy = block_strategy + self.block_shift_size = block_shift_size + self.blocks = paddle.nn.LayerList( + sublayers=[ + paddle.nn.Sequential( + *[ + StackCuboidSelfAttentionBlock( + dim=self.block_units[i], + num_heads=num_heads, + block_cuboid_size=block_cuboid_size[i], + block_strategy=block_strategy[i], + block_shift_size=block_shift_size[i], + attn_drop=attn_drop, + proj_drop=proj_drop, + ffn_drop=ffn_drop, + activation=ffn_activation, + gated_ffn=gated_ffn, + norm_layer=norm_layer, + use_inter_ffn=use_inter_ffn, + padding_type=padding_type, + use_global_vector=use_global_vector, + use_global_vector_ffn=use_global_vector_ffn, + use_global_self_attn=use_global_self_attn, + separate_global_qkv=separate_global_qkv, + global_dim_ratio=global_dim_ratio, + checkpoint_level=checkpoint_level, + use_relative_pos=use_relative_pos, + use_final_proj=self_attn_use_final_proj, + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=ffn_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(depth[i]) + ] + ) + for i in range(self.num_blocks) + ] + ) + self.reset_parameters() + + def reset_parameters(self): + if self.num_blocks > 1: + for m in self.down_layers: + m.reset_parameters() + if self.use_global_vector: + cuboid_utils.apply_initialization( + self.down_layer_global_proj, linear_mode=self.down_linear_init_mode + ) + for ms in self.blocks: + for m in ms: + m.reset_parameters() + + def get_mem_shapes(self): + """Get the shape of the output memory based on the input shape. This can be used for constructing the decoder. + + Returns: + mem_shapes : A list of shapes of the output memory + """ + + if self.num_blocks == 1: + return [self.input_shape] + else: + mem_shapes = [self.input_shape] + curr_shape = self.input_shape + for down_layer in self.down_layers: + curr_shape = down_layer.get_out_shape(curr_shape) + mem_shapes.append(curr_shape) + return mem_shapes + + def forward(self, x, global_vectors=None): + """ + Args: + x : Shape (B, T, H, W, C) + + Returns: + out (List[paddle.Tensor,..]): A list of tensors from the bottom layer to the top layer of the encoder. For + example, it can have shape + - (B, T, H, W, C1) + - (B, T, H // 2, W // 2, 2 * C1) + - (B, T, H // 4, W // 4, 4 * C1) + ... + global_mem_out (List,Optional): The output of the global vector. 
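+ Only returned when `use_global_vector` is True.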
+ """ + + B, T, H, W, C_in = x.shape + assert (T, H, W, C_in) == self.input_shape + + if self.use_global_vector: + out = [] + global_mem_out = [] + for i in range(self.num_blocks): + for l in self.blocks[i]: + x, global_vectors = l(x, global_vectors) + out.append(x) + global_mem_out.append(global_vectors) + if self.num_blocks > 1 and i < self.num_blocks - 1: + x = self.down_layers[i](x) + global_vectors = self.down_layer_global_proj[i](global_vectors) + return out, global_mem_out + else: + out = [] + for i in range(self.num_blocks): + x = self.blocks[i](x) + out.append(x) + if self.num_blocks > 1 and i < self.num_blocks - 1: + x = self.down_layers[i](x) + return out diff --git a/ppsci/arch/cuboid_transformer_utils.py b/ppsci/arch/cuboid_transformer_utils.py new file mode 100644 index 000000000..456e975cf --- /dev/null +++ b/ppsci/arch/cuboid_transformer_utils.py @@ -0,0 +1,349 @@ +import functools +from typing import Tuple + +import paddle +import paddle.nn.functional as F +from paddle import nn + +from ppsci.utils import initializer + + +def round_to(dat, c): + return dat + (dat - dat % c) % c + + +class RMSNorm(paddle.nn.Layer): + """Root Mean Square Layer Normalization proposed in "[NeurIPS2019] Root Mean Square Layer Normalization" + + Args: + d (Optional[int]): The model size. + p (float, optional): The partial RMSNorm, valid value [0, 1]. Defaults to -1.0. + eps (float, optional): The epsilon value. Defaults to 1e-08. + bias (bool, optional): Whether use bias term for RMSNorm, + because RMSNorm doesn't enforce re-centering invariance.Defaults to False. + """ + + def __init__( + self, + d: Tuple[int, ...], + p: float = -1.0, + eps: float = 1e-08, + bias: bool = False, + ): + super().__init__() + self.eps = eps + self.d = d + self.p = p + self.bias = bias + init_data = paddle.ones(d) + self.scale = paddle.create_parameter( + shape=init_data.shape, + dtype=init_data.dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + self.scale.stop_gradient = False + self.add_parameter(name="scale", parameter=self.scale) + if self.bias: + init_data = paddle.zeros(d) + self.offset = paddle.create_parameter( + shape=init_data.shape, + dtype=init_data.dtype, + default_initializer=nn.initializer.Constant(0.0), + ) + self.offset.stop_gradient = False + self.add_parameter(name="offset", parameter=self.offset) + + def forward(self, x): + if self.p < 0.0 or self.p > 1.0: + norm_x = x.norm(p=2, axis=-1, keepdim=True) + d_x = self.d + else: + partial_size = int(self.d * self.p) + partial_x, _ = paddle.split( + x=x, num_or_sections=[partial_size, self.d - partial_size], axis=-1 + ) + norm_x = partial_x.norm(p=2, axis=-1, keepdim=True) + d_x = partial_size + rms_x = norm_x * d_x ** (-1.0 / 2) + x_normed = x / (rms_x + self.eps) + if self.bias: + return self.scale * x_normed + self.offset + return self.scale * x_normed + + +def get_norm_layer( + normalization: str = "layer_norm", + axis: int = -1, + epsilon: float = 1e-05, + in_channels: int = 0, + **kwargs, +): + """Get the normalization layer based on the provided type + + Args: + normalization (str): The type of the layer normalization from ['layer_norm']. + axis (float): The axis to normalize the. + epsilon (float): The epsilon of the normalization layer. + in_channels (int): Input channel. + + Returns: + norm_layer (norm): The layer normalization layer. 
+ """ + + if isinstance(normalization, str): + if normalization == "layer_norm": + assert in_channels > 0 + assert axis == -1 + norm_layer = paddle.nn.LayerNorm( + normalized_shape=in_channels, epsilon=epsilon, **kwargs + ) + elif normalization == "rms_norm": + assert axis == -1 + norm_layer = RMSNorm(d=in_channels, epsilon=epsilon, **kwargs) + else: + raise NotImplementedError( + "normalization={} is not supported".format(normalization) + ) + return norm_layer + elif normalization is None: + return paddle.nn.Identity() + else: + raise NotImplementedError("The type of normalization must be str") + + +def generalize_padding(x, pad_t, pad_h, pad_w, padding_type, t_pad_left=False): + if pad_t == 0 and pad_h == 0 and pad_w == 0: + return x + assert padding_type in ["zeros", "ignore", "nearest"] + B, T, H, W, C = x.shape + if padding_type == "nearest": + return paddle.nn.functional.interpolate( + x=x.transpose(perm=[0, 4, 1, 2, 3]), size=(T + pad_t, H + pad_h, W + pad_w) + ).transpose(perm=[0, 2, 3, 4, 1]) + elif t_pad_left: + return F.pad(x, [0, 0, 0, pad_w, 0, pad_h, pad_t, 0], data_format="NDHWC") + else: + data_pad = F.pad( + x, [0, 0, pad_t, 0, pad_h, 0, pad_w, 0, 0, 0], data_format="NDHWC" + ) + data_pad = paddle.concat( + [data_pad[:, pad_t:, ...], data_pad[:, :pad_t, ...]], axis=1 + ) + return data_pad + + +def generalize_unpadding(x, pad_t, pad_h, pad_w, padding_type): + assert padding_type in ["zeros", "ignore", "nearest"] + B, T, H, W, C = x.shape + if pad_t == 0 and pad_h == 0 and pad_w == 0: + return x + if padding_type == "nearest": + return paddle.nn.functional.interpolate( + x=x.transpose(perm=[0, 4, 1, 2, 3]), size=(T - pad_t, H - pad_h, W - pad_w) + ).transpose(perm=[0, 2, 3, 4, 1]) + else: + return x[:, : T - pad_t, : H - pad_h, : W - pad_w, :] + + +def apply_initialization( + m: paddle.nn.Layer, + linear_mode: str = "0", + conv_mode: str = "0", + norm_mode: str = "0", + embed_mode: str = "0", +): + if isinstance(m, paddle.nn.Linear): + if linear_mode in ("0",): + m.weight = initializer.kaiming_normal_(m.weight, nonlinearity="linear") + elif linear_mode in ("1",): + m.weight = initializer.kaiming_normal_( + m.weight, a=0.1, mode="fan_out", nonlinearity="leaky_relu" + ) + else: + raise NotImplementedError(f"{linear_mode} is invalid.") + if hasattr(m, "bias") and m.bias is not None: + m.bias = initializer.zeros_(m.bias) + elif isinstance( + m, + ( + paddle.nn.Conv2D, + paddle.nn.Conv3D, + paddle.nn.Conv2DTranspose, + paddle.nn.Conv3DTranspose, + ), + ): + if conv_mode in ("0",): + m.weight = initializer.kaiming_normal_( + m.weight, a=0.1, mode="fan_out", nonlinearity="leaky_relu" + ) + else: + raise NotImplementedError(f"{conv_mode} is invalid.") + if hasattr(m, "bias") and m.bias is not None: + m.bias = initializer.zeros_(m.bias) + elif isinstance(m, paddle.nn.LayerNorm): + if norm_mode in ("0",): + m.weight = initializer.zeros_(m.weight) + m.bias = initializer.zeros_(m.bias) + else: + raise NotImplementedError(f"{norm_mode} is invalid.") + elif isinstance(m, paddle.nn.GroupNorm): + if norm_mode in ("0",): + m.weight = initializer.ones_(m.weight) + m.bias = initializer.zeros_(m.bias) + else: + raise NotImplementedError(f"{norm_mode} is invalid.") + elif isinstance(m, paddle.nn.Embedding): + if embed_mode in ("0",): + m.weight.data = initializer.trunc_normal_(m.weight.data, std=0.02) + else: + raise NotImplementedError(f"{embed_mode} is invalid.") + + else: + pass + + +class CuboidSelfAttentionPatterns: + def __init__(self): + super().__init__() + self.patterns = {} + 
self.patterns = { + "full": self.full_attention, + "axial": self.axial, + "divided_st": self.divided_space_time, + } + for p in [1, 2, 4, 8, 10]: + for m in [1, 2, 4, 8, 16, 32]: + key = f"video_swin_{p}x{m}" + self.patterns[key] = functools.partial(self.video_swin, P=p, M=m) + + for m in [1, 2, 4, 8, 16, 32]: + key = f"spatial_lg_{m}" + self.patterns[key] = functools.partial(self.spatial_lg_v1, M=m) + + for k in [2, 4, 8]: + key = f"axial_space_dilate_{k}" + self.patterns[key] = functools.partial(self.axial_space_dilate_K, K=k) + + def get(self, pattern_name): + return self.patterns[pattern_name] + + def full_attention(self, input_shape): + T, H, W, _ = input_shape + cuboid_size = [(T, H, W)] + strategy = [("l", "l", "l")] + shift_size = [(0, 0, 0)] + return cuboid_size, strategy, shift_size + + def axial(self, input_shape): + """Axial attention proposed in https://arxiv.org/abs/1912.12180 + + Args: + input_shape (Tuple[int,...]): The shape of the input tensor, T H W. + + Returns: + cuboid_size (Tuple[int,...]): The size of cuboid. + strategy (Tuple[str,...]): The strategy of the attention. + shift_size (Tuple[int,...]): The shift size of the attention. + """ + + T, H, W, _ = input_shape + cuboid_size = [(T, 1, 1), (1, H, 1), (1, 1, W)] + strategy = [("l", "l", "l"), ("l", "l", "l"), ("l", "l", "l")] + shift_size = [(0, 0, 0), (0, 0, 0), (0, 0, 0)] + return cuboid_size, strategy, shift_size + + def divided_space_time(self, input_shape): + T, H, W, _ = input_shape + cuboid_size = [(T, 1, 1), (1, H, W)] + strategy = [("l", "l", "l"), ("l", "l", "l")] + shift_size = [(0, 0, 0), (0, 0, 0)] + return cuboid_size, strategy, shift_size + + def video_swin(self, input_shape, P=2, M=4): + """Adopt the strategy in Video SwinTransformer https://arxiv.org/pdf/2106.13230.pdf""" + T, H, W, _ = input_shape + P = min(P, T) + M = min(M, H, W) + cuboid_size = [(P, M, M), (P, M, M)] + strategy = [("l", "l", "l"), ("l", "l", "l")] + shift_size = [(0, 0, 0), (P // 2, M // 2, M // 2)] + return cuboid_size, strategy, shift_size + + def spatial_lg_v1(self, input_shape, M=4): + T, H, W, _ = input_shape + if H <= M and W <= M: + cuboid_size = [(T, 1, 1), (1, H, W)] + strategy = [("l", "l", "l"), ("l", "l", "l")] + shift_size = [(0, 0, 0), (0, 0, 0)] + else: + cuboid_size = [(T, 1, 1), (1, M, M), (1, M, M)] + strategy = [("l", "l", "l"), ("l", "l", "l"), ("d", "d", "d")] + shift_size = [(0, 0, 0), (0, 0, 0), (0, 0, 0)] + return cuboid_size, strategy, shift_size + + def axial_space_dilate_K(self, input_shape, K=2): + T, H, W, _ = input_shape + K = min(K, H, W) + cuboid_size = [ + (T, 1, 1), + (1, H // K, 1), + (1, H // K, 1), + (1, 1, W // K), + (1, 1, W // K), + ] + strategy = [ + ("l", "l", "l"), + ("d", "d", "d"), + ("l", "l", "l"), + ("d", "d", "d"), + ("l", "l", "l"), + ] + shift_size = [(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0)] + return cuboid_size, strategy, shift_size + + +class CuboidCrossAttentionPatterns: + def __init__(self): + super().__init__() + self.patterns = {} + for k in [1, 2, 4, 8]: + key1 = f"cross_{k}x{k}" + key2 = f"cross_{k}x{k}_lg" + key3 = f"cross_{k}x{k}_heter" + self.patterns[key1] = functools.partial(self.cross_KxK, K=k) + self.patterns[key2] = functools.partial(self.cross_KxK_lg, K=k) + self.patterns[key3] = functools.partial(self.cross_KxK_heter, K=k) + + def get(self, pattern_name): + return self.patterns[pattern_name] + + def cross_KxK(self, mem_shape, K): + T_mem, H, W, _ = mem_shape + K = min(K, H, W) + cuboid_hw = [(K, K)] + shift_hw = [(0, 0)] + strategy = [("l", "l", 
"l")] + n_temporal = [1] + return cuboid_hw, shift_hw, strategy, n_temporal + + def cross_KxK_lg(self, mem_shape, K): + T_mem, H, W, _ = mem_shape + K = min(K, H, W) + cuboid_hw = [(K, K), (K, K)] + shift_hw = [(0, 0), (0, 0)] + strategy = [("l", "l", "l"), ("d", "d", "d")] + n_temporal = [1, 1] + return cuboid_hw, shift_hw, strategy, n_temporal + + def cross_KxK_heter(self, mem_shape, K): + T_mem, H, W, _ = mem_shape + K = min(K, H, W) + cuboid_hw = [(K, K), (K, K), (K, K)] + shift_hw = [(0, 0), (0, 0), (K // 2, K // 2)] + strategy = [("l", "l", "l"), ("d", "d", "d"), ("l", "l", "l")] + n_temporal = [1, 1, 1] + return cuboid_hw, shift_hw, strategy, n_temporal + + +CuboidSelfAttentionPatterns = CuboidSelfAttentionPatterns() +CuboidCrossAttentionPatterns = CuboidCrossAttentionPatterns() diff --git a/ppsci/data/dataset/__init__.py b/ppsci/data/dataset/__init__.py index 9979ac801..c0eebe860 100644 --- a/ppsci/data/dataset/__init__.py +++ b/ppsci/data/dataset/__init__.py @@ -24,6 +24,7 @@ from ppsci.data.dataset.csv_dataset import IterableCSVDataset from ppsci.data.dataset.cylinder_dataset import MeshCylinderDataset from ppsci.data.dataset.dgmr_dataset import DGMRDataset +from ppsci.data.dataset.enso_dataset import ENSODataset from ppsci.data.dataset.era5_dataset import ERA5Dataset from ppsci.data.dataset.era5_dataset import ERA5SampledDataset from ppsci.data.dataset.mat_dataset import IterableMatDataset @@ -33,6 +34,7 @@ from ppsci.data.dataset.npz_dataset import IterableNPZDataset from ppsci.data.dataset.npz_dataset import NPZDataset from ppsci.data.dataset.radar_dataset import RadarDataset +from ppsci.data.dataset.sevir_dataset import SEVIRDataset from ppsci.data.dataset.trphysx_dataset import CylinderDataset from ppsci.data.dataset.trphysx_dataset import LorenzDataset from ppsci.data.dataset.trphysx_dataset import RosslerDataset @@ -66,6 +68,8 @@ "DGMRDataset", "MeshAirfoilDataset", "MeshCylinderDataset", + "ENSODataset", + "SEVIRDataset", "build_dataset", ] diff --git a/ppsci/data/dataset/enso_dataset.py b/ppsci/data/dataset/enso_dataset.py new file mode 100644 index 000000000..891c3a0df --- /dev/null +++ b/ppsci/data/dataset/enso_dataset.py @@ -0,0 +1,421 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import importlib +from typing import Dict +from typing import Optional +from typing import Tuple + +import numpy as np +from paddle import io + +try: + from pathlib import Path + + import xarray as xr +except ModuleNotFoundError: + pass + +NINO_WINDOW_T = 3 # Nino index is the sliding average over sst, window size is 3 +CMIP6_SST_MAX = 10.198975563049316 +CMIP6_SST_MIN = -16.549121856689453 +CMIP5_SST_MAX = 8.991744995117188 +CMIP5_SST_MIN = -9.33076286315918 +CMIP6_NINO_MAX = 4.138188362121582 +CMIP6_NINO_MIN = -3.5832221508026123 +CMIP5_NINO_MAX = 3.8253555297851562 +CMIP5_NINO_MIN = -2.691682815551758 +SST_MAX = max(CMIP6_SST_MAX, CMIP5_SST_MAX) +SST_MIN = min(CMIP6_SST_MIN, CMIP5_SST_MIN) + + +def scale_sst(sst): + return (sst - SST_MIN) / (SST_MAX - SST_MIN) + + +def scale_back_sst(sst): + return (SST_MAX - SST_MIN) * sst + SST_MIN + + +def prepare_inputs_targets( + len_time, input_gap, input_length, pred_shift, pred_length, samples_gap +): + """Prepares the input and target indices for training. + + Args: + len_time (int): The total number of time steps in the dataset. + input_gap (int): time gaps between two consecutive input frames. + input_length (int): the number of input frames. + pred_shift (int): the lead_time of the last target to be predicted. + pred_length (int): the number of frames to be predicted. + samples_gap (int): stride of seq sampling. + + """ + + if pred_shift < pred_length: + raise ValueError("pred_shift should be small than pred_length") + input_span = input_gap * (input_length - 1) + 1 + pred_gap = pred_shift // pred_length + input_ind = np.arange(0, input_span, input_gap) + target_ind = np.arange(0, pred_shift, pred_gap) + input_span + pred_gap - 1 + ind = np.concatenate([input_ind, target_ind]).reshape(1, input_length + pred_length) + max_n_sample = len_time - (input_span + pred_shift - 1) + ind = ind + np.arange(max_n_sample)[:, np.newaxis] @ np.ones( + (1, input_length + pred_length), dtype=int + ) + return ind[::samples_gap] + + +def fold(data, size=36, stride=12): + """inverse of unfold/sliding window operation + only applicable to the case where the size of the sliding windows is n*stride + + Args: + data (tuple[int,...]): The input data.(N, size, *). + size (int, optional): The size of a single datum.The Defaults to 36. + stride (int, optional): The step.Defaults to 12. + + Returns: + outdata (np.array): (N_, *).N/size is the number/width of sliding blocks + """ + + if size % stride != 0: + raise ValueError("size modulo stride should be zero") + times = size // stride + remain = (data.shape[0] - 1) % times + if remain > 0: + ls = list(data[::times]) + [data[-1, -(remain * stride) :]] + outdata = np.concatenate(ls, axis=0) # (36*(151//3+1)+remain*stride, *, 15) + else: + outdata = np.concatenate(data[::times], axis=0) # (36*(151/3+1), *, 15) + assert ( + outdata.shape[0] == size * ((data.shape[0] - 1) // times + 1) + remain * stride + ) + return outdata + + +def data_transform(data, num_years_per_model): + """The transform of the input data. + + Args: + data (Tuple[list,...]): The input data.Shape of (N, 36, *). + num_years_per_model (int): The number of years associated with each model.151/140. 
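+ Returns: + outdata (np.array): The folded output with the model dimension kept last, e.g. cmip6 sst has shape (1836, 24, 48, 15).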
+ + """ + + length = data.shape[0] + assert length % num_years_per_model == 0 + num_models = length // num_years_per_model + outdata = np.stack( + np.split(data, length / num_years_per_model, axis=0), axis=-1 + ) # (151, 36, *, 15) + # cmip6sst outdata.shape = (151, 36, 24, 48, 15) = (year, month, lat, lon, model) + # cmip5sst outdata.shape = (140, 36, 24, 48, 17) + # cmip6nino outdata.shape = (151, 36, 15) + # cmip5nino outdata.shape = (140, 36, 17) + outdata = fold(outdata, size=36, stride=12) + # cmip6sst outdata.shape = (1836, 24, 48, 15), 1836 == 151 * 12 + 24 + # cmip5sst outdata.shape = (1704, 24, 48, 17) + # cmip6nino outdata.shape = (1836, 15) + # cmip5nino outdata.shape = (1704, 17) + + # check output data + assert outdata.shape[-1] == num_models + assert not np.any(np.isnan(outdata)) + return outdata + + +def read_raw_data(ds_dir, out_dir=None): + """read and process raw cmip data from CMIP_train.nc and CMIP_label.nc + + Args: + ds_dir (str): the path of the dataset. + out_dir (str): the path of output. Defaults to None. + + """ + + train_cmip = xr.open_dataset(Path(ds_dir) / "CMIP_train.nc").transpose( + "year", "month", "lat", "lon" + ) + label_cmip = xr.open_dataset(Path(ds_dir) / "CMIP_label.nc").transpose( + "year", "month" + ) + # train_cmip.sst.values.shape = (4645, 36, 24, 48) + + # select longitudes + lon = train_cmip.lon.values + lon = lon[np.logical_and(lon >= 95, lon <= 330)] + train_cmip = train_cmip.sel(lon=lon) + + cmip6sst = data_transform( + data=train_cmip.sst.values[:2265], num_years_per_model=151 + ) + cmip5sst = data_transform( + data=train_cmip.sst.values[2265:], num_years_per_model=140 + ) + cmip6nino = data_transform( + data=label_cmip.nino.values[:2265], num_years_per_model=151 + ) + cmip5nino = data_transform( + data=label_cmip.nino.values[2265:], num_years_per_model=140 + ) + + # cmip6sst.shape = (1836, 24, 48, 15) + # cmip5sst.shape = (1704, 24, 48, 17) + assert len(cmip6sst.shape) == 4 + assert len(cmip5sst.shape) == 4 + assert len(cmip6nino.shape) == 2 + assert len(cmip5nino.shape) == 2 + # store processed data for faster data access + if out_dir is not None: + ds_cmip6 = xr.Dataset( + { + "sst": (["month", "lat", "lon", "model"], cmip6sst), + "nino": (["month", "model"], cmip6nino), + }, + coords={ + "month": np.repeat( + np.arange(1, 13)[None], cmip6nino.shape[0] // 12, axis=0 + ).flatten(), + "lat": train_cmip.lat.values, + "lon": train_cmip.lon.values, + "model": np.arange(15) + 1, + }, + ) + ds_cmip6.to_netcdf(Path(out_dir) / "cmip6.nc") + ds_cmip5 = xr.Dataset( + { + "sst": (["month", "lat", "lon", "model"], cmip5sst), + "nino": (["month", "model"], cmip5nino), + }, + coords={ + "month": np.repeat( + np.arange(1, 13)[None], cmip5nino.shape[0] // 12, axis=0 + ).flatten(), + "lat": train_cmip.lat.values, + "lon": train_cmip.lon.values, + "model": np.arange(17) + 1, + }, + ) + ds_cmip5.to_netcdf(Path(out_dir) / "cmip5.nc") + train_cmip.close() + label_cmip.close() + return cmip6sst, cmip5sst, cmip6nino, cmip5nino + + +def cat_over_last_dim(data): + """treat different models (15 from CMIP6, 17 from CMIP5) as batch_size + e.g., cmip6sst.shape = (178, 38, 24, 48, 15), converted_cmip6sst.shape = (2670, 38, 24, 48) + e.g., cmip5sst.shape = (165, 38, 24, 48, 15), converted_cmip6sst.shape = (2475, 38, 24, 48) + + """ + + return np.concatenate(np.moveaxis(data, -1, 0), axis=0) + + +class ENSODataset(io.Dataset): + """The El NiƱo/Southern Oscillation dataset. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). 
+ label_keys (Tuple[str, ...]): Name of label keys, such as ("output",). + data_dir (str): The directory of data. + weight_dict (Optional[Dict[str, Union[Callable, float]]]): Define the weight of each constraint variable. Defaults to None. + in_len (int, optional): The length of input data. Defaults to 12. + out_len (int, optional): The length of output data. Defaults to 26. + in_stride (int, optional): The stride of input data. Defaults to 1. + out_stride (int, optional): The stride of output data. Defaults to 1. + train_samples_gap (int, optional): The stride of sequence sampling during training. Defaults to 10. + e.g., samples_gap = 10, the first seq contains [0, 1, ..., T-1] frame indices, the second seq contains [10, 11, .., T+9] + eval_samples_gap (int, optional): The stride of sequence sampling during eval. Defaults to 11. + normalize_sst (bool, optional): Whether to use normalization. Defaults to True. + batch_size (int, optional): Batch size. Defaults to 1. + num_workers (int, optional): The number of workers. Defaults to 1. + training (str, optional): Training phase, one of 'train', 'eval' or 'test'. Defaults to "train". + + """ + + # Whether to support batch indexing for speeding up the fetching process. + batch_index: bool = False + + def __init__( + self, + input_keys: Tuple[str, ...], + label_keys: Tuple[str, ...], + data_dir: str, + weight_dict: Optional[Dict[str, float]] = None, + in_len=12, + out_len=26, + in_stride=1, + out_stride=1, + train_samples_gap=10, + eval_samples_gap=11, + normalize_sst=True, + # datamodule_only + batch_size=1, + num_workers=1, + training="train", + ): + super(ENSODataset, self).__init__() + if importlib.util.find_spec("xarray") is None: + raise ModuleNotFoundError( + "To use ENSODataset, please install 'xarray' via: `pip install " + "xarray` first." + ) + if importlib.util.find_spec("pathlib") is None: + raise ModuleNotFoundError( + "To use ENSODataset, please install 'pathlib' via: `pip install " + "pathlib` first." + ) + self.input_keys = input_keys + self.label_keys = label_keys + self.data_dir = data_dir + self.weight_dict = {} if weight_dict is None else weight_dict + if weight_dict is not None: + self.weight_dict = {key: 1.0 for key in self.label_keys} + self.weight_dict.update(weight_dict) + + self.in_len = in_len + self.out_len = out_len + self.in_stride = in_stride + self.out_stride = out_stride + self.train_samples_gap = train_samples_gap + self.eval_samples_gap = eval_samples_gap + self.normalize_sst = normalize_sst + # datamodule_only + self.batch_size = batch_size + if num_workers != 1: + raise ValueError( + "Current implementation does not support `num_workers != 1`!"
+ ) + self.num_workers = num_workers + self.training = training + + # pre-data + cmip6sst, cmip5sst, cmip6nino, cmip5nino = read_raw_data(self.data_dir) + # TODO: more flexible train/val/test split + self.sst_train = [cmip6sst, cmip5sst[..., :-2]] + self.nino_train = [cmip6nino, cmip5nino[..., :-2]] + self.sst_eval = [cmip5sst[..., -2:-1]] + self.nino_eval = [cmip5nino[..., -2:-1]] + self.sst_test = [cmip5sst[..., -1:]] + self.nino_test = [cmip5nino[..., -1:]] + + self.sst, self.target_nino = self.create_data() + + def create_data( + self, + ): + if self.training == "train": + sst_cmip6 = self.sst_train[0] + nino_cmip6 = self.nino_train[0] + sst_cmip5 = self.sst_train[1] + nino_cmip5 = self.nino_train[1] + samples_gap = self.train_samples_gap + elif self.training == "eval": + sst_cmip6 = None + nino_cmip6 = None + sst_cmip5 = self.sst_eval[0] + nino_cmip5 = self.nino_eval[0] + samples_gap = self.eval_samples_gap + elif self.training == "test": + sst_cmip6 = None + nino_cmip6 = None + sst_cmip5 = self.sst_test[0] + nino_cmip5 = self.nino_test[0] + samples_gap = self.eval_samples_gap + + # cmip6 (N, *, 15) + # cmip5 (N, *, 17) + sst = [] + target_nino = [] + + nino_idx_slice = slice( + self.in_len, self.in_len + self.out_len - NINO_WINDOW_T + 1 + ) # e.g., 12:36 + if sst_cmip6 is not None: + assert len(sst_cmip6.shape) == 4 + assert len(nino_cmip6.shape) == 2 + idx_sst = prepare_inputs_targets( + len_time=sst_cmip6.shape[0], + input_length=self.in_len, + input_gap=self.in_stride, + pred_shift=self.out_len * self.out_stride, + pred_length=self.out_len, + samples_gap=samples_gap, + ) + + sst.append(cat_over_last_dim(sst_cmip6[idx_sst])) + target_nino.append( + cat_over_last_dim(nino_cmip6[idx_sst[:, nino_idx_slice]]) + ) + if sst_cmip5 is not None: + assert len(sst_cmip5.shape) == 4 + assert len(nino_cmip5.shape) == 2 + idx_sst = prepare_inputs_targets( + len_time=sst_cmip5.shape[0], + input_length=self.in_len, + input_gap=self.in_stride, + pred_shift=self.out_len * self.out_stride, + pred_length=self.out_len, + samples_gap=samples_gap, + ) + sst.append(cat_over_last_dim(sst_cmip5[idx_sst])) + target_nino.append( + cat_over_last_dim(nino_cmip5[idx_sst[:, nino_idx_slice]]) + ) + + # sst data containing both the input and target + self.sst = np.concatenate(sst, axis=0) # (N, in_len+out_len, lat, lon) + if self.normalize_sst: + self.sst = scale_sst(self.sst) + # nino data containing the target only + self.target_nino = np.concatenate( + target_nino, axis=0 + ) # (N, out_len+NINO_WINDOW_T-1) + assert self.sst.shape[0] == self.target_nino.shape[0] + assert self.sst.shape[1] == self.in_len + self.out_len + assert self.target_nino.shape[1] == self.out_len - NINO_WINDOW_T + 1 + return self.sst, self.target_nino + + def get_datashape(self): + return {"sst": self.sst.shape, "nino target": self.target_nino.shape} + + def __len__(self): + return self.sst.shape[0] + + def __getitem__(self, idx): + sst_data = self.sst[idx].astype("float32") + sst_data = sst_data[..., np.newaxis] + in_seq = sst_data[: self.in_len, ...] # ( in_len, lat, lon, 1) + target_seq = sst_data[self.in_len :, ...] 
# ( in_len, lat, lon, 1) + weight_item = self.weight_dict + + if self.training == "train": + input_item = {self.input_keys[0]: in_seq} + label_item = { + self.label_keys[0]: target_seq, + } + + return input_item, label_item, weight_item + else: + input_item = {self.input_keys[0]: in_seq} + label_item = { + self.label_keys[0]: target_seq, + self.label_keys[1]: self.target_nino[idx], + } + + return input_item, label_item, weight_item diff --git a/ppsci/data/dataset/sevir_dataset.py b/ppsci/data/dataset/sevir_dataset.py new file mode 100644 index 000000000..63ee225c4 --- /dev/null +++ b/ppsci/data/dataset/sevir_dataset.py @@ -0,0 +1,806 @@ +import datetime +import os +from copy import deepcopy +from typing import Dict +from typing import Optional +from typing import Sequence +from typing import Tuple + +import h5py +import numpy as np +import paddle +import paddle.nn.functional as F +import pandas as pd +from paddle import io + +# SEVIR Dataset constants +SEVIR_DATA_TYPES = ["vis", "ir069", "ir107", "vil", "lght"] +SEVIR_RAW_DTYPES = { + "vis": np.int16, + "ir069": np.int16, + "ir107": np.int16, + "vil": np.uint8, + "lght": np.int16, +} +LIGHTING_FRAME_TIMES = np.arange(-120.0, 125.0, 5) * 60 +SEVIR_DATA_SHAPE = { + "lght": (48, 48), +} +PREPROCESS_SCALE_SEVIR = { + "vis": 1, # Not utilized in original paper + "ir069": 1 / 1174.68, + "ir107": 1 / 2562.43, + "vil": 1 / 47.54, + "lght": 1 / 0.60517, +} +PREPROCESS_OFFSET_SEVIR = { + "vis": 0, # Not utilized in original paper + "ir069": 3683.58, + "ir107": 1552.80, + "vil": -33.44, + "lght": -0.02990, +} +PREPROCESS_SCALE_01 = { + "vis": 1, + "ir069": 1, + "ir107": 1, + "vil": 1 / 255, # currently the only one implemented + "lght": 1, +} +PREPROCESS_OFFSET_01 = { + "vis": 0, + "ir069": 0, + "ir107": 0, + "vil": 0, # currently the only one implemented + "lght": 0, +} + + +def change_layout_np(data, in_layout="NHWT", out_layout="NHWT", ret_contiguous=False): + # first convert to 'NHWT' + if in_layout == "NHWT": + pass + elif in_layout == "NTHW": + data = np.transpose(data, axes=(0, 2, 3, 1)) + elif in_layout == "NWHT": + data = np.transpose(data, axes=(0, 2, 1, 3)) + elif in_layout == "NTCHW": + data = data[:, :, 0, :, :] + data = np.transpose(data, axes=(0, 2, 3, 1)) + elif in_layout == "NTHWC": + data = data[:, :, :, :, 0] + data = np.transpose(data, axes=(0, 2, 3, 1)) + elif in_layout == "NTWHC": + data = data[:, :, :, :, 0] + data = np.transpose(data, axes=(0, 3, 2, 1)) + elif in_layout == "TNHW": + data = np.transpose(data, axes=(1, 2, 3, 0)) + elif in_layout == "TNCHW": + data = data[:, :, 0, :, :] + data = np.transpose(data, axes=(1, 2, 3, 0)) + else: + raise NotImplementedError(f"{in_layout} is invalid.") + + if out_layout == "NHWT": + pass + elif out_layout == "NTHW": + data = np.transpose(data, axes=(0, 3, 1, 2)) + elif out_layout == "NWHT": + data = np.transpose(data, axes=(0, 2, 1, 3)) + elif out_layout == "NTCHW": + data = np.transpose(data, axes=(0, 3, 1, 2)) + data = np.expand_dims(data, axis=2) + elif out_layout == "NTHWC": + data = np.transpose(data, axes=(0, 3, 1, 2)) + data = np.expand_dims(data, axis=-1) + elif out_layout == "NTWHC": + data = np.transpose(data, axes=(0, 3, 2, 1)) + data = np.expand_dims(data, axis=-1) + elif out_layout == "TNHW": + data = np.transpose(data, axes=(3, 0, 1, 2)) + elif out_layout == "TNCHW": + data = np.transpose(data, axes=(3, 0, 1, 2)) + data = np.expand_dims(data, axis=2) + else: + raise NotImplementedError(f"{out_layout} is invalid.") + if ret_contiguous: + data = data.ascontiguousarray() + 
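# at this point `data` has been routed through the canonical 'NHWT' intermediate layout and reshaped to the requested `out_layout` +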
return data + + +def change_layout_paddle( + data, in_layout="NHWT", out_layout="NHWT", ret_contiguous=False +): + # first convert to 'NHWT' + if in_layout == "NHWT": + pass + elif in_layout == "NTHW": + data = data.transpose(perm=[0, 2, 3, 1]) + elif in_layout == "NTCHW": + data = data[:, :, 0, :, :] + data = data.transpose(perm=[0, 2, 3, 1]) + elif in_layout == "NTHWC": + data = data[:, :, :, :, 0] + data = data.transpose(perm=[0, 2, 3, 1]) + elif in_layout == "TNHW": + data = data.transpose(perm=[1, 2, 3, 0]) + elif in_layout == "TNCHW": + data = data[:, :, 0, :, :] + data = data.transpose(perm=[1, 2, 3, 0]) + else: + raise NotImplementedError(f"{in_layout} is invalid.") + + if out_layout == "NHWT": + pass + elif out_layout == "NTHW": + data = data.transpose(perm=[0, 3, 1, 2]) + elif out_layout == "NTCHW": + data = data.transpose(perm=[0, 3, 1, 2]) + data = paddle.unsqueeze(data, axis=2) + elif out_layout == "NTHWC": + data = data.transpose(perm=[0, 3, 1, 2]) + data = paddle.unsqueeze(data, axis=-1) + elif out_layout == "TNHW": + data = data.transpose(perm=[3, 0, 1, 2]) + elif out_layout == "TNCHW": + data = data.transpose(perm=[3, 0, 1, 2]) + data = paddle.unsqueeze(data, axis=2) + else: + raise NotImplementedError(f"{out_layout} is invalid.") + return data + + +def path_splitall(path): + allparts = [] + while 1: + parts = os.path.split(path) + if parts[0] == path: # sentinel for absolute paths + allparts.insert(0, parts[0]) + break + elif parts[1] == path: # sentinel for relative paths + allparts.insert(0, parts[1]) + break + else: + path = parts[0] + allparts.insert(0, parts[1]) + return allparts + + +class SEVIRDataset(io.Dataset): + """The Storm EVent ImagRy dataset. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + label_keys (Tuple[str, ...]): Name of label keys, such as ("output",). + data_dir (str): The path of the dataset. + weight_dict (Optional[Dict[str, Union[Callable, float]]]): Define the weight of each constraint variable. Defaults to None. + data_types (Sequence[str], optional): A subset of SEVIR_DATA_TYPES. Defaults to [ "vil", ]. + seq_len (int, optional): The length of the data sequences. Should be smaller than the max length raw_seq_len. Defaults to 49. + raw_seq_len (int, optional): The length of the raw data sequences. Defaults to 49. + sample_mode (str, optional): The mode of sampling, eg.'random' or 'sequent'. Defaults to "sequent". + stride (int, optional): Useful when sample_mode == 'sequent' + stride must not be smaller than out_len to prevent data leakage in testing. Defaults to 12. + batch_size (int, optional): The batch size. Defaults to 1. + layout (str, optional): consists of batch_size 'N', seq_len 'T', channel 'C', height 'H', width 'W' + The layout of sampled data. Raw data layout is 'NHWT'. + valid layout: 'NHWT', 'NTHW', 'NTCHW', 'TNHW', 'TNCHW'. Defaults to "NHWT". + in_len (int, optional): The length of input data. Defaults to 13. + out_len (int, optional): The length of output data. Defaults to 12. + num_shard (int, optional): Split the whole dataset into num_shard parts for distributed training. Defaults to 1. + rank (int, optional): Rank of the current process within num_shard. Defaults to 0. + split_mode (str, optional): if 'ceil', all `num_shard` dataloaders have the same length = ceil(total_len / num_shard). + Different dataloaders may have some duplicated data batches, if the total size of datasets is not divided by num_shard. 
+ if 'floor', all `num_shard` dataloaders have the same length = floor(total_len / num_shard). + The last several data batches may be wasted, if the total size of datasets is not divided by num_shard. + if 'uneven', the last dataset has a larger length when the total length is not divided by num_shard. + The uneven split leads to synchronization error in dist.all_reduce() or dist.barrier(). + See related issue: https://github.com/pytorch/pytorch/issues/33148 + Notice: this also affects the behavior of `self.use_up`. Defaults to "uneven". + start_date (datetime.datetime, optional): Start time of SEVIR samples to generate. Defaults to None. + end_date (datetime.datetime, optional): End time of SEVIR samples to generate. Defaults to None. + datetime_filter (function, optional): Mask function applied to time_utc column of catalog (return true to keep the row). + Pass function of the form lambda t : COND(t) + Example: lambda t: np.logical_and(t.dt.hour>=13,t.dt.hour<=21) # Generate only day-time events. Defaults to None. + catalog_filter (function, optional): function or None or 'default' + Mask function applied to entire catalog dataframe (return true to keep row). + Pass function of the form lambda catalog: COND(catalog) + Example: lambda c: [s[0]=='S' for s in c.id] # Generate only the 'S' events + shuffle (bool, optional): If True, data samples are shuffled before each epoch. Defaults to False. + shuffle_seed (int, optional): Seed to use for shuffling. Defaults to 1. + output_type (np.dtype, optional): The type of generated tensors. Defaults to np.float32. + preprocess (bool, optional): If True, self.preprocess_data_dict(data_dict) is called before each sample is generated. Defaults to True. + rescale_method (str, optional): The method of rescaling. Defaults to "01". + downsample_dict (Dict[str, Sequence[int]], optional): downsample_dict.keys() == data_types. downsample_dict[key] is a Sequence of + (t_factor, h_factor, w_factor), representing the downsampling factors of all dimensions. Defaults to None. + verbose (bool, optional): Verbose when opening raw data files. Defaults to False. + training (str, optional): Training phase. Defaults to "train". + """ + + # Whether to support batch indexing for speeding up the fetching process.
+ batch_index: bool = False + + def __init__( + self, + input_keys: Tuple[str, ...], + label_keys: Tuple[str, ...], + data_dir: str, + weight_dict: Optional[Dict[str, float]] = None, + data_types: Sequence[str] = [ + "vil", + ], + seq_len: int = 49, + raw_seq_len: int = 49, + sample_mode: str = "sequent", + stride: int = 12, + batch_size: int = 1, + layout: str = "NHWT", + in_len: int = 13, + out_len: int = 12, + num_shard: int = 1, + rank: int = 0, + split_mode: str = "uneven", + start_date: datetime.datetime = None, + end_date: datetime.datetime = None, + datetime_filter=None, + catalog_filter="default", + shuffle: bool = False, + shuffle_seed: int = 1, + output_type=np.float32, + preprocess: bool = True, + rescale_method: str = "01", + downsample_dict: Dict[str, Sequence[int]] = None, + verbose: bool = False, + training="train", + ): + super(SEVIRDataset, self).__init__() + self.input_keys = input_keys + self.label_keys = label_keys + self.data_dir = data_dir + self.weight_dict = {} if weight_dict is None else weight_dict + if weight_dict is not None: + self.weight_dict = {key: 1.0 for key in self.label_keys} + self.weight_dict.update(weight_dict) + + # sevir + SEVIR_ROOT_DIR = os.path.join(self.data_dir, "sevir") + sevir_catalog = os.path.join(SEVIR_ROOT_DIR, "CATALOG.csv") + sevir_data_dir = os.path.join(SEVIR_ROOT_DIR, "data") + # sevir-lr + # SEVIR_ROOT_DIR = os.path.join(self.data_dir, "sevir_lr") + # SEVIR_CATALOG = os.path.join(SEVIR_ROOT_DIR, "CATALOG.csv") + # SEVIR_DATA_DIR = os.path.join(SEVIR_ROOT_DIR, "data") + + if data_types is None: + data_types = SEVIR_DATA_TYPES + else: + assert set(data_types).issubset(SEVIR_DATA_TYPES) + + # configs which should not be modified + self._dtypes = SEVIR_RAW_DTYPES + self.lght_frame_times = LIGHTING_FRAME_TIMES + self.data_shape = SEVIR_DATA_SHAPE + + self.raw_seq_len = raw_seq_len + self.seq_len = seq_len + + if seq_len > raw_seq_len: + raise ValueError("seq_len must be small than raw_seq_len") + + if sample_mode not in ["random", "sequent"]: + raise ValueError("sample_mode must be 'random' or 'sequent'.") + + self.sample_mode = sample_mode + self.stride = stride + self.batch_size = batch_size + valid_layout = ("NHWT", "NTHW", "NTCHW", "NTHWC", "TNHW", "TNCHW") + if layout not in valid_layout: + raise ValueError( + f"Invalid layout = {layout}! Must be one of {valid_layout}." + ) + self.layout = layout + self.in_len = in_len + self.out_len = out_len + + self.num_shard = num_shard + self.rank = rank + valid_split_mode = ("ceil", "floor", "uneven") + if split_mode not in valid_split_mode: + raise ValueError( + f"Invalid split_mode: {split_mode}! Must be one of {valid_split_mode}." 
+ ) + self.split_mode = split_mode + self._samples = None + self._hdf_files = {} + self.data_types = data_types + if isinstance(sevir_catalog, str): + self.catalog = pd.read_csv( + sevir_catalog, parse_dates=["time_utc"], low_memory=False + ) + else: + self.catalog = sevir_catalog + self.sevir_data_dir = sevir_data_dir + self.datetime_filter = datetime_filter + self.catalog_filter = catalog_filter + self.start_date = start_date + self.end_date = end_date + # train val test split + self.start_date = ( + datetime.datetime(*start_date) if start_date is not None else None + ) + self.end_date = datetime.datetime(*end_date) if end_date is not None else None + + self.shuffle = shuffle + self.shuffle_seed = int(shuffle_seed) + self.output_type = output_type + self.preprocess = preprocess + self.downsample_dict = downsample_dict + self.rescale_method = rescale_method + self.verbose = verbose + + if self.start_date is not None: + self.catalog = self.catalog[self.catalog.time_utc > self.start_date] + if self.end_date is not None: + self.catalog = self.catalog[self.catalog.time_utc <= self.end_date] + if self.datetime_filter: + self.catalog = self.catalog[self.datetime_filter(self.catalog.time_utc)] + + if self.catalog_filter is not None: + if self.catalog_filter == "default": + self.catalog_filter = lambda c: c.pct_missing == 0 + self.catalog = self.catalog[self.catalog_filter(self.catalog)] + + self._compute_samples() + self._open_files(verbose=self.verbose) + + def _compute_samples(self): + """ + Computes the list of samples in catalog to be used. This sets self._samples + """ + # locate all events containing colocated data_types + imgt = self.data_types + imgts = set(imgt) + filtcat = self.catalog[ + np.logical_or.reduce([self.catalog.img_type == i for i in imgt]) + ] + # remove rows missing one or more requested img_types + filtcat = filtcat.groupby("id").filter( + lambda x: imgts.issubset(set(x["img_type"])) + ) + # If there are repeated IDs, remove them (this is a bug in SEVIR) + # TODO: is it necessary to keep one of them instead of deleting them all + filtcat = filtcat.groupby("id").filter(lambda x: x.shape[0] == len(imgt)) + self._samples = filtcat.groupby("id").apply( + lambda df: self._df_to_series(df, imgt) + ) + if self.shuffle: + self.shuffle_samples() + + def shuffle_samples(self): + self._samples = self._samples.sample(frac=1, random_state=self.shuffle_seed) + + def _df_to_series(self, df, imgt): + d = {} + df = df.set_index("img_type") + for i in imgt: + s = df.loc[i] + idx = s.file_index if i != "lght" else s.id + d.update({f"{i}_filename": [s.file_name], f"{i}_index": [idx]}) + + return pd.DataFrame(d) + + def _open_files(self, verbose=True): + """ + Opens HDF files + """ + imgt = self.data_types + hdf_filenames = [] + for t in imgt: + hdf_filenames += list(np.unique(self._samples[f"{t}_filename"].values)) + self._hdf_files = {} + for f in hdf_filenames: + if verbose: + print("Opening HDF5 file for reading", f) + self._hdf_files[f] = h5py.File(self.sevir_data_dir + "/" + f, "r") + + def close(self): + """ + Closes all open file handles + """ + for f in self._hdf_files: + self._hdf_files[f].close() + self._hdf_files = {} + + @property + def num_seq_per_event(self): + return 1 + (self.raw_seq_len - self.seq_len) // self.stride + + @property + def total_num_seq(self): + """ + The total number of sequences within each shard. + Notice that it is not the product of `self.num_seq_per_event` and `self.total_num_event`. 
+ """ + return int(self.num_seq_per_event * self.num_event) + + @property + def total_num_event(self): + """ + The total number of events in the whole dataset, before split into different shards. + """ + return int(self._samples.shape[0]) + + @property + def start_event_idx(self): + """ + The event idx used in certain rank should satisfy event_idx >= start_event_idx + """ + return self.total_num_event // self.num_shard * self.rank + + @property + def end_event_idx(self): + """ + The event idx used in certain rank should satisfy event_idx < end_event_idx + + """ + if self.split_mode == "ceil": + _last_start_event_idx = ( + self.total_num_event // self.num_shard * (self.num_shard - 1) + ) + _num_event = self.total_num_event - _last_start_event_idx + return self.start_event_idx + _num_event + elif self.split_mode == "floor": + return self.total_num_event // self.num_shard * (self.rank + 1) + else: # self.split_mode == 'uneven': + if self.rank == self.num_shard - 1: # the last process + return self.total_num_event + else: + return self.total_num_event // self.num_shard * (self.rank + 1) + + @property + def num_event(self): + """ + The number of events split into each rank + """ + return self.end_event_idx - self.start_event_idx + + def __len__(self): + """ + Used only when self.sample_mode == 'sequent' + """ + return self.total_num_seq // self.batch_size + + def _read_data(self, row, data): + """ + Iteratively read data into data dict. Finally data[imgt] gets shape (batch_size, height, width, raw_seq_len). + + Args: + row (Dict,optional): A series with fields IMGTYPE_filename, IMGTYPE_index, IMGTYPE_time_index. + data (Dict,optional): , data[imgt] is a data tensor with shape = (tmp_batch_size, height, width, raw_seq_len). + + Returns: + data (np.array): Updated data. Updated shape = (tmp_batch_size + 1, height, width, raw_seq_len). 
+ """ + + imgtyps = np.unique([x.split("_")[0] for x in list(row.keys())]) + for t in imgtyps: + fname = row[f"{t}_filename"] + idx = row[f"{t}_index"] + t_slice = slice(0, None) + # Need to bin lght counts into grid + if t == "lght": + lght_data = self._hdf_files[fname][idx][:] + data_i = self._lght_to_grid(lght_data, t_slice) + else: + data_i = self._hdf_files[fname][t][idx : idx + 1, :, :, t_slice] + data[t] = ( + np.concatenate((data[t], data_i), axis=0) if (t in data) else data_i + ) + return data + + def _lght_to_grid(self, data, t_slice=slice(0, None)): + """ + Converts Nx5 lightning data matrix into a 2D grid of pixel counts + """ + # out_size = (48,48,len(self.lght_frame_times)-1) if isinstance(t_slice,(slice,)) else (48,48) + out_size = ( + (*self.data_shape["lght"], len(self.lght_frame_times)) + if t_slice.stop is None + else (*self.data_shape["lght"], 1) + ) + if data.shape[0] == 0: + return np.zeros((1,) + out_size, dtype=np.float32) + + # filter out points outside the grid + x, y = data[:, 3], data[:, 4] + m = np.logical_and.reduce([x >= 0, x < out_size[0], y >= 0, y < out_size[1]]) + data = data[m, :] + if data.shape[0] == 0: + return np.zeros((1,) + out_size, dtype=np.float32) + + # Filter/separate times + t = data[:, 0] + if t_slice.stop is not None: # select only one time bin + if t_slice.stop > 0: + if t_slice.stop < len(self.lght_frame_times): + tm = np.logical_and( + t >= self.lght_frame_times[t_slice.stop - 1], + t < self.lght_frame_times[t_slice.stop], + ) + else: + tm = t >= self.lght_frame_times[-1] + else: # special case: frame 0 uses lght from frame 1 + tm = np.logical_and( + t >= self.lght_frame_times[0], t < self.lght_frame_times[1] + ) + # tm=np.logical_and( (t>=FRAME_TIMES[t_slice],t self.end_event_idx: + pad_size = event_idx_slice_end - self.end_event_idx + event_idx_slice_end = self.end_event_idx + pd_batch = self._samples.iloc[event_idx:event_idx_slice_end] + data = {} + for index, row in pd_batch.iterrows(): + data = self._read_data(row, data) + if pad_size > 0: + event_batch = [] + for t in self.data_types: + pad_shape = [ + pad_size, + ] + list(data[t].shape[1:]) + data_pad = np.concatenate( + ( + data[t].astype(self.output_type), + np.zeros(pad_shape, dtype=self.output_type), + ), + axis=0, + ) + event_batch.append(data_pad) + else: + event_batch = [data[t].astype(self.output_type) for t in self.data_types] + return event_batch + + def __iter__(self): + return self + + @staticmethod + def preprocess_data_dict(data_dict, data_types=None, layout="NHWT", rescale="01"): + """The preprocess of data dict. + Args: + data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): The dict of data. + data_types (Sequence[str]) : The data types that we want to rescale. This mainly excludes "mask" from preprocessing. + layout (str) : consists of batch_size 'N', seq_len 'T', channel 'C', height 'H', width 'W'. + rescale (str): + 'sevir': use the offsets and scale factors in original implementation. + '01': scale all values to range 0 to 1, currently only supports 'vil'. + Returns: + data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]) : preprocessed data. 
+ """ + + if rescale == "sevir": + scale_dict = PREPROCESS_SCALE_SEVIR + offset_dict = PREPROCESS_OFFSET_SEVIR + elif rescale == "01": + scale_dict = PREPROCESS_SCALE_01 + offset_dict = PREPROCESS_OFFSET_01 + else: + raise ValueError(f"Invalid rescale option: {rescale}.") + if data_types is None: + data_types = data_dict.keys() + for key, data in data_dict.items(): + if key in data_types: + if isinstance(data, np.ndarray): + data = scale_dict[key] * ( + data.astype(np.float32) + offset_dict[key] + ) + data = change_layout_np( + data=data, in_layout="NHWT", out_layout=layout + ) + elif isinstance(data, paddle.Tensor): + data = scale_dict[key] * (data.astype("float32") + offset_dict[key]) + data = change_layout_paddle( + data=data, in_layout="NHWT", out_layout=layout + ) + data_dict[key] = data + return data_dict + + @staticmethod + def process_data_dict_back(data_dict, data_types=None, rescale="01"): + if rescale == "sevir": + scale_dict = PREPROCESS_SCALE_SEVIR + offset_dict = PREPROCESS_OFFSET_SEVIR + elif rescale == "01": + scale_dict = PREPROCESS_SCALE_01 + offset_dict = PREPROCESS_OFFSET_01 + else: + raise ValueError(f"Invalid rescale option: {rescale}.") + if data_types is None: + data_types = data_dict.keys() + for key in data_types: + data = data_dict[key] + data = data.astype("float32") / scale_dict[key] - offset_dict[key] + data_dict[key] = data + return data_dict + + @staticmethod + def data_dict_to_tensor(data_dict, data_types=None): + """ + Convert each element in data_dict to paddle.Tensor (copy without grad). + """ + ret_dict = {} + if data_types is None: + data_types = data_dict.keys() + for key, data in data_dict.items(): + if key in data_types: + if isinstance(data, paddle.Tensor): + ret_dict[key] = data.detach().clone() + elif isinstance(data, np.ndarray): + ret_dict[key] = paddle.to_tensor(data) + else: + raise ValueError( + f"Invalid data type: {type(data)}. Should be paddle.Tensor or np.ndarray" + ) + else: # key == "mask" + ret_dict[key] = data + return ret_dict + + @staticmethod + def downsample_data_dict( + data_dict, data_types=None, factors_dict=None, layout="NHWT" + ): + """The downsample of data. + + Args: + data_dict (Dict[str, Union[np.array, paddle.Tensor]]): The dict of data. 
+            factors_dict (Optional[Dict[str, Sequence[int]]]): Each element `factors` is a Sequence of int, representing
+                (t_factor, h_factor, w_factor), the downsampling factors of each dimension.
+
+        Returns:
+            downsampled_data_dict (Dict[str, paddle.Tensor]): The downsampled data, computed on a deep copy of data_dict
+                instead of directly modifying the original data_dict.
+        """
+
+        if factors_dict is None:
+            factors_dict = {}
+        if data_types is None:
+            data_types = data_dict.keys()
+        downsampled_data_dict = SEVIRDataset.data_dict_to_tensor(
+            data_dict=data_dict, data_types=data_types
+        )  # make a copy
+        for key, data in data_dict.items():
+            factors = factors_dict.get(key, None)
+            if factors is not None:
+                downsampled_data_dict[key] = change_layout_paddle(
+                    data=downsampled_data_dict[key], in_layout=layout, out_layout="NTHW"
+                )
+                # downsample t dimension
+                t_slice = [
+                    slice(None, None),
+                ] * 4
+                t_slice[1] = slice(None, None, factors[0])
+                downsampled_data_dict[key] = downsampled_data_dict[key][tuple(t_slice)]
+                # downsample spatial dimensions
+                downsampled_data_dict[key] = F.avg_pool2d(
+                    downsampled_data_dict[key],
+                    kernel_size=(factors[1], factors[2]),
+                )
+
+                downsampled_data_dict[key] = change_layout_paddle(
+                    data=downsampled_data_dict[key], in_layout="NTHW", out_layout=layout
+                )
+
+        return downsampled_data_dict
+
+    def layout_to_in_out_slice(
+        self,
+    ):
+        t_axis = self.layout.find("T")
+        num_axes = len(self.layout)
+        in_slice = [
+            slice(None, None),
+        ] * num_axes
+        out_slice = deepcopy(in_slice)
+        in_slice[t_axis] = slice(None, self.in_len)
+        if self.out_len is None:
+            out_slice[t_axis] = slice(self.in_len, None)
+        else:
+            out_slice[t_axis] = slice(self.in_len, self.in_len + self.out_len)
+        return in_slice, out_slice
+
+    def __getitem__(self, index):
+        event_idx = (index * self.batch_size) // self.num_seq_per_event
+        seq_idx = (index * self.batch_size) % self.num_seq_per_event
+        num_sampled = 0
+        sampled_idx_list = []  # list of (event_idx, seq_idx) records
+        while num_sampled < self.batch_size:
+            sampled_idx_list.append({"event_idx": event_idx, "seq_idx": seq_idx})
+            seq_idx += 1
+            if seq_idx >= self.num_seq_per_event:
+                event_idx += 1
+                seq_idx = 0
+            num_sampled += 1
+
+        start_event_idx = sampled_idx_list[0]["event_idx"]
+        event_batch_size = sampled_idx_list[-1]["event_idx"] - start_event_idx + 1
+
+        event_batch = self._load_event_batch(
+            event_idx=start_event_idx, event_batch_size=event_batch_size
+        )
+        ret_dict = {}
+        for sampled_idx in sampled_idx_list:
+            batch_slice = [
+                sampled_idx["event_idx"] - start_event_idx,
+            ]  # use [] to keepdim
+            seq_slice = slice(
+                sampled_idx["seq_idx"] * self.stride,
+                sampled_idx["seq_idx"] * self.stride + self.seq_len,
+            )
+            for imgt_idx, imgt in enumerate(self.data_types):
+                sampled_seq = event_batch[imgt_idx][batch_slice, :, :, seq_slice]
+                if imgt in ret_dict:
+                    ret_dict[imgt] = np.concatenate(
+                        (ret_dict[imgt], sampled_seq), axis=0
+                    )
+                else:
+                    ret_dict.update({imgt: sampled_seq})
+
+        ret_dict = self.data_dict_to_tensor(
+            data_dict=ret_dict, data_types=self.data_types
+        )
+        if self.preprocess:
+            ret_dict = self.preprocess_data_dict(
+                data_dict=ret_dict,
+                data_types=self.data_types,
+                layout=self.layout,
+                rescale=self.rescale_method,
+            )
+
+        if self.downsample_dict is not None:
+            ret_dict = self.downsample_data_dict(
+                data_dict=ret_dict,
+                data_types=self.data_types,
+                factors_dict=self.downsample_dict,
+                layout=self.layout,
+            )
+        in_slice, out_slice = self.layout_to_in_out_slice()
+        data_seq = ret_dict["vil"]
+        if isinstance(data_seq, paddle.Tensor):
+            data_seq = data_seq.numpy()
+        x = data_seq[in_slice[0], in_slice[1], in_slice[2], in_slice[3], in_slice[4]]
+        y = data_seq[
+            out_slice[0], out_slice[1], out_slice[2], out_slice[3], out_slice[4]
+        ]
+
+        weight_item = self.weight_dict
+        input_item = {self.input_keys[0]: x}
+        label_item = {
+            self.label_keys[0]: y,
+        }
+
+        return input_item, label_item, weight_item
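
Taken together, `__getitem__` stitches `batch_size` sequences from consecutive events, converts them to tensors, rescales them, and slices the `vil` sequence into an input window of `in_len` frames and a target window of `out_len` frames along the `T` axis of `layout`. The snippet below is a minimal usage sketch, not part of this patch: the data path and key names are illustrative assumptions, and it presumes the SEVIR archive (CATALOG.csv plus the data/ folder) has been extracted under `<data_dir>/sevir`.

```python
import ppsci

# Minimal usage sketch (illustrative assumptions: path, input/label key names).
# `data_dir` must contain a `sevir/` folder with CATALOG.csv and data/.
dataset = ppsci.data.dataset.SEVIRDataset(
    input_keys=("input",),
    label_keys=("vil",),
    data_dir="./datasets/sevir_data",
    data_types=["vil"],
    seq_len=25,  # in_len + out_len
    raw_seq_len=49,
    sample_mode="sequent",
    stride=12,
    batch_size=1,
    layout="NTHWC",  # 5-axis layout, matching the 5-way slicing in __getitem__
    in_len=13,
    out_len=12,
    start_date=(2019, 1, 1),  # tuples, converted via datetime.datetime(*start_date)
    end_date=(2019, 6, 1),
    preprocess=True,
    rescale_method="01",
)

input_item, label_item, weight_item = dataset[0]
x = input_item["input"]  # np.ndarray whose T axis has length in_len
y = label_item["vil"]    # np.ndarray whose T axis has length out_len
print(len(dataset), x.shape, y.shape)
```

Note that the 5-index slicing at the end of `__getitem__` assumes a five-character layout such as "NTHWC", which is also what the Earthformer SEVIR configuration in this example uses.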