-
Notifications
You must be signed in to change notification settings - Fork 179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Hackathon 6th Code Camp No.15] support earthformer #816
Changes from 4 commits
4673353
b1d026b
cc17151
5b8ec36
fd7aadb
074a32d
6f08e82
2dfbabb
457b8a7
54463dd
0eebf42
541ae87
ede9618
a58869b
872c763
ff41984
a98e284
5111ab8
83f221e
18eb8aa
9ca51c9
2d4d682
f946a1f
f4cb1ca
e4ae8f3
4979d89
b3ec330
d64e3e6
5d37aa1
90d343d
4109f12
6ad448c
ca8d402
fc10355
c3a2ed6
bc19623
0ae6c9c
43f7a46
149ac74
ec87043
25c7bc8
79aed10
bddc5e4
b44c2a5
b15ca3a
50d5137
cdf1216
8d06a9b
884e8f5
0d3deb1
87e158f
e240608
3dccb22
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,5 +26,6 @@ | |
- HEDeepONets | ||
- ChipDeepONets | ||
- AutoEncoder | ||
- CuboidTransformerModel | ||
show_root_heading: true | ||
heading_level: 3 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,4 +23,6 @@ | |
- MeshCylinderDataset | ||
- RadarDataset | ||
- build_dataset | ||
- ENSODataset | ||
- SEVIRDataset | ||
show_root_heading: true |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
hydra: | ||
run: | ||
# dynamic output directory according to running time and override name | ||
dir: outputs_earthformer_pretrain | ||
job: | ||
name: ${mode} # name of logfile | ||
chdir: false # keep current working direcotry unchaned | ||
config: | ||
override_dirname: | ||
exclude_keys: | ||
- TRAIN.checkpoint_path | ||
- TRAIN.pretrained_model_path | ||
- EVAL.pretrained_model_path | ||
- mode | ||
- output_dir | ||
- log_freq | ||
sweep: | ||
HydrogenSulfate marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# output directory for multirun | ||
dir: ${hydra.run.dir} | ||
subdir: ./ | ||
|
||
# general settings | ||
mode: train # running mode: train/eval/export/infer | ||
seed: 0 | ||
output_dir: ${hydra:run.dir} | ||
log_freq: 20 | ||
|
||
USE_SAMPLED_DATA: false | ||
# set train and evaluate data path | ||
FILE_PATH: /home/aistudio/data/data260191/enso_round1_train_20210201 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里可以改成相对路径 |
||
|
||
# dataset setting | ||
DATASET: | ||
label_keys: ["sst_target","nino_target"] | ||
in_len: 12 | ||
out_len: 14 | ||
nino_window_t: 3 | ||
in_stride: 1 | ||
out_stride: 1 | ||
train_samples_gap: 2 | ||
eval_samples_gap: 1 | ||
normalize_sst: true | ||
|
||
# model settings | ||
MODEL: | ||
self_pattern: "axial" | ||
cross_self_pattern: "axial" | ||
cross_pattern: "cross_1x1" | ||
afno: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. afno模型时FourCastNet的模型设置,此处直接使用即可,对应训练代码中也需要修改使用方式cfg.MODEL MODEL:
input_keys: ["sst_data"]
output_keys: ["sst_target","nino_target"]
input_shape: [12, 24, 48, 1]
target_shape: [14, 24, 48, 1]
base_units: 64
scale_alpha: 1.0
.... |
||
input_keys: ["sst_data"] | ||
output_keys: ["sst_target","nino_target"] | ||
input_shape: [12, 24, 48, 1] | ||
target_shape: [14, 24, 48, 1] | ||
base_units: 64 | ||
# block_units: null | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. block_units这个参数是需要的吗,不需要的话可以删除? |
||
scale_alpha: 1.0 | ||
|
||
enc_depth: [1, 1] | ||
dec_depth: [1, 1] | ||
enc_use_inter_ffn: true | ||
dec_use_inter_ffn: true | ||
dec_hierarchical_pos_embed: false | ||
|
||
downsample: 2 | ||
downsample_type: "patch_merge" | ||
upsample_type: "upsample" | ||
|
||
num_global_vectors: 0 | ||
use_dec_self_global: false | ||
dec_self_update_global: true | ||
use_dec_cross_global: false | ||
use_global_vector_ffn: false | ||
use_global_self_attn: false | ||
separate_global_qkv: false | ||
global_dim_ratio: 1 | ||
|
||
dec_cross_last_n_frames: null | ||
|
||
attn_drop: 0.1 | ||
proj_drop: 0.1 | ||
ffn_drop: 0.1 | ||
num_heads: 4 | ||
|
||
ffn_activation: "gelu" | ||
gated_ffn: false | ||
norm_layer: "layer_norm" | ||
padding_type: "zeros" | ||
pos_embed_type: "t+h+w" | ||
use_relative_pos: true | ||
self_attn_use_final_proj: true | ||
dec_use_first_self_attn: false | ||
|
||
z_init_method: "zeros" | ||
initial_downsample_type: "conv" | ||
initial_downsample_activation: "leaky" | ||
initial_downsample_scale: [1, 1, 2] | ||
initial_downsample_conv_layers: 2 | ||
final_upsample_conv_layers: 1 | ||
checkpoint_level: 2 | ||
|
||
attn_linear_init_mode: "0" | ||
ffn_linear_init_mode: "0" | ||
conv_init_mode: "0" | ||
down_up_linear_init_mode: "0" | ||
norm_init_mode: "0" | ||
|
||
|
||
# training settings | ||
TRAIN: | ||
epochs: 100 | ||
save_freq: 20 | ||
eval_during_train: true | ||
eval_freq: 10 | ||
lr_scheduler: | ||
epochs: ${TRAIN.epochs} | ||
learning_rate: 0.0002 | ||
by_epoch: True | ||
min_lr_ratio: 1.0e-3 | ||
wd: 1.0e-5 | ||
batch_size: 16 | ||
pretrained_model_path: null | ||
checkpoint_path: null | ||
|
||
|
||
# evaluation settings | ||
EVAL: | ||
pretrained_model_path: /home/aistudio/best_model.pdparams | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 相对路径 |
||
compute_metric_by_batch: False | ||
eval_with_no_grad: true | ||
batch_size: 16 | ||
|
||
INFER: | ||
pretrained_model_path: /home/aistudio/best_model.pdparams | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 相对路径 |
||
export_path: ./inference/earthformer/enso | ||
pdmodel_path: ${INFER.export_path}.pdmodel | ||
pdpiparams_path: ${INFER.export_path}.pdiparams | ||
device: gpu | ||
engine: native | ||
precision: fp32 | ||
onnx_path: ${INFER.export_path}.onnx | ||
ir_optim: true | ||
min_subgraph_size: 10 | ||
gpu_mem: 4000 | ||
gpu_id: 0 | ||
max_batch_size: 16 | ||
num_cpu_threads: 4 | ||
batch_size: 1 | ||
data_path: /home/aistudio/data/data260191/enso_round1_train_20210201/SODA_train.nc | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 相对路径 |
||
in_len: 12 | ||
in_stride: 1 | ||
out_len: 14 | ||
out_stride: 1 | ||
samples_gap: 1 | ||
|
||
|
||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 多余空行删除 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import paddle | ||
import numpy as np | ||
from typing import Optional, Union, Dict | ||
from ppsci.data.dataset.enso_dataset import NINO_WINDOW_T, scale_back_sst | ||
from paddle.nn import functional as F | ||
|
||
|
||
def compute_enso_score( | ||
y_pred, y_true, | ||
acc_weight: Optional[Union[str, np.ndarray, paddle.Tensor]] = None): | ||
r""" | ||
|
||
Parameters | ||
---------- | ||
y_pred: paddle.Tensor | ||
y_true: paddle.Tensor | ||
acc_weight: Optional[Union[str, np.ndarray, paddle.Tensor]] | ||
None: not used | ||
default: use default acc_weight specified at https://tianchi.aliyun.com/competition/entrance/531871/information | ||
np.ndarray: custom weights | ||
|
||
Returns | ||
------- | ||
acc | ||
rmse | ||
""" | ||
pred = y_pred - y_pred.mean(axis=0, keepdim=True) # (N, 24) | ||
true = y_true - y_true.mean(axis=0, keepdim=True) # (N, 24) | ||
cor = (pred * true).sum(axis=0) / ( | ||
paddle.sqrt(paddle.sum(pred ** 2, axis=0) * paddle.sum(true ** 2, axis=0)) + 1e-6) | ||
|
||
if acc_weight is None: | ||
acc = cor.sum() | ||
else: | ||
nino_out_len = y_true.shape[-1] | ||
if acc_weight == "default": | ||
acc_weight = paddle.to_tensor([1.5] * 4 + [2] * 7 + [3] * 7 + [4] * (nino_out_len - 18))[:nino_out_len] \ | ||
* paddle.log(paddle.arange(nino_out_len) + 1) | ||
elif isinstance(acc_weight, np.ndarray): | ||
acc_weight = paddle.to_tensor(acc_weight[:nino_out_len]) | ||
elif isinstance(acc_weight, paddle.Tensor): | ||
acc_weight = acc_weight[:nino_out_len] | ||
else: | ||
raise ValueError(f"Invalid acc_weight {acc_weight}!") | ||
acc_weight = acc_weight.to(y_pred) | ||
acc = (acc_weight * cor).sum() | ||
rmse = paddle.mean((y_pred - y_true) ** 2, axis=0).sqrt().sum() | ||
return acc, rmse | ||
|
||
|
||
def sst_to_nino(sst: paddle.Tensor, | ||
normalize_sst: bool = True, | ||
detach: bool = True): | ||
r""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 一般带r的是防止含方程的转义字符被识别错误,如果没有方程感觉docstring开头的r可以删除? |
||
|
||
Parameters | ||
---------- | ||
sst: paddle.Tensor | ||
Shape = (N, T, H, W) | ||
|
||
Returns | ||
------- | ||
nino_index: paddle.Tensor | ||
Shape = (N, T-NINO_WINDOW_T+1) | ||
""" | ||
if detach: | ||
nino_index = sst.detach() | ||
else: | ||
nino_index = sst | ||
if normalize_sst: | ||
nino_index = scale_back_sst(nino_index) | ||
nino_index = nino_index[:, :, 10:13, 19:30].mean(axis=[2, 3]) # (N, 26) | ||
nino_index = nino_index.unfold(axis=1, size=NINO_WINDOW_T, step=1).mean(axis=2) # (N, 24) | ||
|
||
return nino_index | ||
|
||
|
||
def train_mse_func( | ||
output_dict: Dict[str, "paddle.Tensor"], | ||
label_dict: Dict[str, "paddle.Tensor"], | ||
*args, | ||
) -> paddle.Tensor: | ||
return F.mse_loss(output_dict["sst_target"], label_dict["sst_target"]) | ||
|
||
|
||
def eval_rmse_func( | ||
output_dict: Dict[str, "paddle.Tensor"], | ||
label_dict: Dict[str, "paddle.Tensor"], | ||
nino_out_len=12, | ||
*args, | ||
) -> Dict[str, paddle.Tensor]: | ||
pred = output_dict["sst_target"] | ||
sst_target = label_dict["sst_target"] | ||
nino_target = label_dict["nino_target"].astype('float32') | ||
# mse | ||
mae = F.l1_loss(pred, sst_target) | ||
# mse | ||
mse = F.mse_loss(pred, sst_target) | ||
# rmse | ||
nino_preds = sst_to_nino(sst=pred[..., 0]) | ||
nino_preds_list, nino_target_list = map(list, zip((nino_preds, nino_target))) | ||
nino_preds_list = paddle.concat(nino_preds_list, axis=0) | ||
nino_target_list = paddle.concat(nino_target_list, axis=0) | ||
|
||
valid_acc, valid_nino_rmse = compute_enso_score( | ||
y_pred=nino_preds_list, y_true=nino_target_list, | ||
acc_weight=None) | ||
valid_weighted_acc, _ = compute_enso_score( | ||
y_pred=nino_preds_list, y_true=nino_target_list, | ||
acc_weight="default") | ||
valid_acc /= nino_out_len | ||
valid_nino_rmse /= nino_out_len | ||
valid_weighted_acc /= nino_out_len | ||
valid_loss = -valid_acc | ||
|
||
return {"valid_loss_epoch": valid_loss, "mse": mse, "mae": mae, "rmse": valid_nino_rmse, | ||
"corr_nino3.4_epoch": valid_acc, "corr_nino3.4_weighted_epoch": valid_weighted_acc, } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 末尾逗号可以看删除 |
||
|
||
|
||
def get_parameter_names(model, forbidden_layer_types): | ||
result = [] | ||
for name, child in model.named_children(): | ||
result += [ | ||
f"{name}.{n}" | ||
for n in get_parameter_names(child, forbidden_layer_types) | ||
if not isinstance(child, tuple(forbidden_layer_types)) | ||
] | ||
# Add model specific parameters (defined with nn.Parameter) since they are not in any child. | ||
result += list(model._parameters.keys()) | ||
return result |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CuboidTransformerModel建议改为CuboidTransformer