Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hackathon 6th Code Camp No.15] support earthformer #816

Closed
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
4673353
add-earthformer
Mar 21, 2024
b1d026b
add-earthformer
Mar 21, 2024
cc17151
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleScien…
Mar 21, 2024
5b8ec36
add-earthformer
Mar 21, 2024
fd7aadb
add-earthformer
Mar 25, 2024
074a32d
add-earthformer
Yang-Changhui Mar 25, 2024
6f08e82
add-earthformer
Yang-Changhui Mar 26, 2024
2dfbabb
Merge branch 'develop' into add_earthformer
Yang-Changhui Mar 26, 2024
457b8a7
add-earthformer
Yang-Changhui Mar 27, 2024
54463dd
add-earthformer
Yang-Changhui Mar 27, 2024
0eebf42
Merge branch 'develop' into add_earthformer
Yang-Changhui Mar 27, 2024
541ae87
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleScien…
Yang-Changhui Mar 27, 2024
ede9618
Merge branch 'add_earthformer' of https://github.com/Yang-Changhui/Pa…
Yang-Changhui Mar 27, 2024
a58869b
add-earthformer
Yang-Changhui Mar 28, 2024
872c763
add-earthformer
Yang-Changhui Mar 28, 2024
ff41984
add-earthformer
Yang-Changhui Mar 28, 2024
a98e284
add-earthformer
Yang-Changhui Mar 28, 2024
5111ab8
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleScien…
Yang-Changhui Mar 28, 2024
83f221e
add-earthformer
Yang-Changhui Mar 28, 2024
18eb8aa
add-earthfromer
Yang-Changhui Mar 28, 2024
9ca51c9
Merge branch 'develop' into add_earthformer
Yang-Changhui Mar 29, 2024
2d4d682
add-earthformer
Yang-Changhui Mar 30, 2024
f946a1f
Merge branch 'add_earthformer' of https://github.com/Yang-Changhui/Pa…
Yang-Changhui Mar 30, 2024
f4cb1ca
add-earthformer
Yang-Changhui Mar 30, 2024
e4ae8f3
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleScien…
Yang-Changhui Mar 30, 2024
4979d89
'add-earthfromer'
Yang-Changhui Mar 30, 2024
b3ec330
add-earthfromer
Yang-Changhui Mar 30, 2024
d64e3e6
add-earthformer
Yang-Changhui Apr 2, 2024
5d37aa1
Merge branch 'develop' into add_earthformer
Yang-Changhui Apr 2, 2024
90d343d
add-earthformer
Yang-Changhui Apr 2, 2024
4109f12
Merge branch 'add_earthformer' of https://github.com/Yang-Changhui/Pa…
Yang-Changhui Apr 2, 2024
6ad448c
Merge branch 'develop' into add_earthformer
Yang-Changhui Apr 2, 2024
ca8d402
add-earthformer
Yang-Changhui Apr 3, 2024
fc10355
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleScien…
Yang-Changhui Apr 3, 2024
c3a2ed6
Merge branch 'add_earthformer' of https://github.com/Yang-Changhui/Pa…
Yang-Changhui Apr 3, 2024
bc19623
add-earthformer
Yang-Changhui Apr 7, 2024
0ae6c9c
add-earthformer
Yang-Changhui Apr 7, 2024
43f7a46
Merge branch 'develop' into add_earthformer
Yang-Changhui Apr 7, 2024
149ac74
add-earthfromer
Yang-Changhui Apr 13, 2024
ec87043
Merge branch 'add_earthformer' of https://github.com/Yang-Changhui/Pa…
Yang-Changhui Apr 13, 2024
25c7bc8
Merge branch 'develop' into add_earthformer
Yang-Changhui Apr 13, 2024
79aed10
add-earthformer
Yang-Changhui Apr 13, 2024
bddc5e4
add-earthformer
Yang-Changhui Apr 16, 2024
b44c2a5
add-earthformer
Yang-Changhui Apr 16, 2024
b15ca3a
Merge branch 'develop' into add_earthformer
Yang-Changhui Apr 16, 2024
50d5137
add-earthformer
Yang-Changhui Apr 16, 2024
cdf1216
Merge branch 'add_earthformer' of https://github.com/Yang-Changhui/Pa…
Yang-Changhui Apr 16, 2024
8d06a9b
Merge branch 'develop' into add_earthformer
Yang-Changhui Apr 16, 2024
884e8f5
add-earthformer
Yang-Changhui Apr 24, 2024
0d3deb1
Merge branch 'add_earthformer' of https://github.com/Yang-Changhui/Pa…
Yang-Changhui Apr 24, 2024
87e158f
Merge branch 'develop' into add_earthformer
Yang-Changhui Apr 24, 2024
e240608
Merge branch 'develop' into add_earthformer
Yang-Changhui Apr 24, 2024
3dccb22
Merge branch 'develop' into add_earthformer
zhiminzhang0830 Apr 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/zh/api/arch.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@
- HEDeepONets
- ChipDeepONets
- AutoEncoder
- CuboidTransformerModel
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CuboidTransformerModel建议改为CuboidTransformer

show_root_heading: true
heading_level: 3
2 changes: 2 additions & 0 deletions docs/zh/api/data/dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,6 @@
- MeshCylinderDataset
- RadarDataset
- build_dataset
- ENSODataset
- SEVIRDataset
show_root_heading: true
157 changes: 157 additions & 0 deletions examples/earthformer/enso/conf/earthformer_enso_pretrain.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
hydra:
run:
# dynamic output directory according to running time and override name
dir: outputs_earthformer_pretrain
job:
name: ${mode} # name of logfile
chdir: false # keep current working direcotry unchaned
config:
override_dirname:
exclude_keys:
- TRAIN.checkpoint_path
- TRAIN.pretrained_model_path
- EVAL.pretrained_model_path
- mode
- output_dir
- log_freq
sweep:
HydrogenSulfate marked this conversation as resolved.
Show resolved Hide resolved
# output directory for multirun
dir: ${hydra.run.dir}
subdir: ./

# general settings
mode: train # running mode: train/eval/export/infer
seed: 0
output_dir: ${hydra:run.dir}
log_freq: 20

USE_SAMPLED_DATA: false
# set train and evaluate data path
FILE_PATH: /home/aistudio/data/data260191/enso_round1_train_20210201
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可以改成相对路径


# dataset setting
DATASET:
label_keys: ["sst_target","nino_target"]
in_len: 12
out_len: 14
nino_window_t: 3
in_stride: 1
out_stride: 1
train_samples_gap: 2
eval_samples_gap: 1
normalize_sst: true

# model settings
MODEL:
self_pattern: "axial"
cross_self_pattern: "axial"
cross_pattern: "cross_1x1"
afno:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

afno模型时FourCastNet的模型设置,此处直接使用即可,对应训练代码中也需要修改使用方式cfg.MODEL

MODEL:
    input_keys: ["sst_data"]
    output_keys: ["sst_target","nino_target"]
    input_shape: [12, 24, 48, 1]
    target_shape: [14, 24, 48, 1]
    base_units: 64
    scale_alpha: 1.0
    ....

input_keys: ["sst_data"]
output_keys: ["sst_target","nino_target"]
input_shape: [12, 24, 48, 1]
target_shape: [14, 24, 48, 1]
base_units: 64
# block_units: null
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

block_units这个参数是需要的吗,不需要的话可以删除?

scale_alpha: 1.0

enc_depth: [1, 1]
dec_depth: [1, 1]
enc_use_inter_ffn: true
dec_use_inter_ffn: true
dec_hierarchical_pos_embed: false

downsample: 2
downsample_type: "patch_merge"
upsample_type: "upsample"

num_global_vectors: 0
use_dec_self_global: false
dec_self_update_global: true
use_dec_cross_global: false
use_global_vector_ffn: false
use_global_self_attn: false
separate_global_qkv: false
global_dim_ratio: 1

dec_cross_last_n_frames: null

attn_drop: 0.1
proj_drop: 0.1
ffn_drop: 0.1
num_heads: 4

ffn_activation: "gelu"
gated_ffn: false
norm_layer: "layer_norm"
padding_type: "zeros"
pos_embed_type: "t+h+w"
use_relative_pos: true
self_attn_use_final_proj: true
dec_use_first_self_attn: false

z_init_method: "zeros"
initial_downsample_type: "conv"
initial_downsample_activation: "leaky"
initial_downsample_scale: [1, 1, 2]
initial_downsample_conv_layers: 2
final_upsample_conv_layers: 1
checkpoint_level: 2

attn_linear_init_mode: "0"
ffn_linear_init_mode: "0"
conv_init_mode: "0"
down_up_linear_init_mode: "0"
norm_init_mode: "0"


# training settings
TRAIN:
epochs: 100
save_freq: 20
eval_during_train: true
eval_freq: 10
lr_scheduler:
epochs: ${TRAIN.epochs}
learning_rate: 0.0002
by_epoch: True
min_lr_ratio: 1.0e-3
wd: 1.0e-5
batch_size: 16
pretrained_model_path: null
checkpoint_path: null


# evaluation settings
EVAL:
pretrained_model_path: /home/aistudio/best_model.pdparams
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

相对路径

compute_metric_by_batch: False
eval_with_no_grad: true
batch_size: 16

INFER:
pretrained_model_path: /home/aistudio/best_model.pdparams
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

相对路径

export_path: ./inference/earthformer/enso
pdmodel_path: ${INFER.export_path}.pdmodel
pdpiparams_path: ${INFER.export_path}.pdiparams
device: gpu
engine: native
precision: fp32
onnx_path: ${INFER.export_path}.onnx
ir_optim: true
min_subgraph_size: 10
gpu_mem: 4000
gpu_id: 0
max_batch_size: 16
num_cpu_threads: 4
batch_size: 1
data_path: /home/aistudio/data/data260191/enso_round1_train_20210201/SODA_train.nc
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

相对路径

in_len: 12
in_stride: 1
out_len: 14
out_stride: 1
samples_gap: 1




Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

多余空行删除

130 changes: 130 additions & 0 deletions examples/earthformer/enso/helps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import paddle
import numpy as np
from typing import Optional, Union, Dict
from ppsci.data.dataset.enso_dataset import NINO_WINDOW_T, scale_back_sst
from paddle.nn import functional as F


def compute_enso_score(
y_pred, y_true,
acc_weight: Optional[Union[str, np.ndarray, paddle.Tensor]] = None):
r"""

Parameters
----------
y_pred: paddle.Tensor
y_true: paddle.Tensor
acc_weight: Optional[Union[str, np.ndarray, paddle.Tensor]]
None: not used
default: use default acc_weight specified at https://tianchi.aliyun.com/competition/entrance/531871/information
np.ndarray: custom weights

Returns
-------
acc
rmse
"""
pred = y_pred - y_pred.mean(axis=0, keepdim=True) # (N, 24)
true = y_true - y_true.mean(axis=0, keepdim=True) # (N, 24)
cor = (pred * true).sum(axis=0) / (
paddle.sqrt(paddle.sum(pred ** 2, axis=0) * paddle.sum(true ** 2, axis=0)) + 1e-6)

if acc_weight is None:
acc = cor.sum()
else:
nino_out_len = y_true.shape[-1]
if acc_weight == "default":
acc_weight = paddle.to_tensor([1.5] * 4 + [2] * 7 + [3] * 7 + [4] * (nino_out_len - 18))[:nino_out_len] \
* paddle.log(paddle.arange(nino_out_len) + 1)
elif isinstance(acc_weight, np.ndarray):
acc_weight = paddle.to_tensor(acc_weight[:nino_out_len])
elif isinstance(acc_weight, paddle.Tensor):
acc_weight = acc_weight[:nino_out_len]
else:
raise ValueError(f"Invalid acc_weight {acc_weight}!")
acc_weight = acc_weight.to(y_pred)
acc = (acc_weight * cor).sum()
rmse = paddle.mean((y_pred - y_true) ** 2, axis=0).sqrt().sum()
return acc, rmse


def sst_to_nino(sst: paddle.Tensor,
normalize_sst: bool = True,
detach: bool = True):
r"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

一般带r的是防止含方程的转义字符被识别错误,如果没有方程感觉docstring开头的r可以删除?


Parameters
----------
sst: paddle.Tensor
Shape = (N, T, H, W)

Returns
-------
nino_index: paddle.Tensor
Shape = (N, T-NINO_WINDOW_T+1)
"""
if detach:
nino_index = sst.detach()
else:
nino_index = sst
if normalize_sst:
nino_index = scale_back_sst(nino_index)
nino_index = nino_index[:, :, 10:13, 19:30].mean(axis=[2, 3]) # (N, 26)
nino_index = nino_index.unfold(axis=1, size=NINO_WINDOW_T, step=1).mean(axis=2) # (N, 24)

return nino_index


def train_mse_func(
output_dict: Dict[str, "paddle.Tensor"],
label_dict: Dict[str, "paddle.Tensor"],
*args,
) -> paddle.Tensor:
return F.mse_loss(output_dict["sst_target"], label_dict["sst_target"])


def eval_rmse_func(
output_dict: Dict[str, "paddle.Tensor"],
label_dict: Dict[str, "paddle.Tensor"],
nino_out_len=12,
*args,
) -> Dict[str, paddle.Tensor]:
pred = output_dict["sst_target"]
sst_target = label_dict["sst_target"]
nino_target = label_dict["nino_target"].astype('float32')
# mse
mae = F.l1_loss(pred, sst_target)
# mse
mse = F.mse_loss(pred, sst_target)
# rmse
nino_preds = sst_to_nino(sst=pred[..., 0])
nino_preds_list, nino_target_list = map(list, zip((nino_preds, nino_target)))
nino_preds_list = paddle.concat(nino_preds_list, axis=0)
nino_target_list = paddle.concat(nino_target_list, axis=0)

valid_acc, valid_nino_rmse = compute_enso_score(
y_pred=nino_preds_list, y_true=nino_target_list,
acc_weight=None)
valid_weighted_acc, _ = compute_enso_score(
y_pred=nino_preds_list, y_true=nino_target_list,
acc_weight="default")
valid_acc /= nino_out_len
valid_nino_rmse /= nino_out_len
valid_weighted_acc /= nino_out_len
valid_loss = -valid_acc

return {"valid_loss_epoch": valid_loss, "mse": mse, "mae": mae, "rmse": valid_nino_rmse,
"corr_nino3.4_epoch": valid_acc, "corr_nino3.4_weighted_epoch": valid_weighted_acc, }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

末尾逗号可以看删除



def get_parameter_names(model, forbidden_layer_types):
result = []
for name, child in model.named_children():
result += [
f"{name}.{n}"
for n in get_parameter_names(child, forbidden_layer_types)
if not isinstance(child, tuple(forbidden_layer_types))
]
# Add model specific parameters (defined with nn.Parameter) since they are not in any child.
result += list(model._parameters.keys())
return result
Loading