-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
24 changed files
with
6,101 additions
and
689 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Tue Sep 15 14:15:59 2020 | ||
@author: wang.zife | ||
""" | ||
|
||
#%% | ||
import os | ||
import json | ||
import numpy as np | ||
import pandas as pd | ||
import utils | ||
|
||
#mode_ind = 0 # 1 #0 | ||
#station_ind = 1 | ||
def main(lag,window):
    """Fit yearly river-flow regressions for the selected stations and modes.

    For each of the first two entries of ``modes`` and each station in the
    hard-coded station list, the monthly record is aligned to whole calendar
    years, summed to yearly volumes, optionally moving-averaged, split into
    train/valid/test sets by ``utils.separatedata``, scaled by
    ``utils.scaledata`` and passed to ``regressions.regressions``. The merged
    settings and results are written as JSON to ``configs.txt`` in the
    per-station save path; when ``verbose`` is set, ``myplots`` is asked to
    plot the result.

    Args:
        lag: Autoregressive lag in months; used as ``lag//12`` years.
        window: Moving-average window in years; 1 means no averaging.
    """
    verbose = True
    num_min_years = 40+lag//12 # number of min years to be considered, grows with the AR lag
    min_volumes = 100 # minimum of yearly volume (km3/yr) to be considered
    Nvalid = 10 # number of valid data
    Ntest = 20 # number of test data
    scaleback = False # scale back results or not
    modes = ['ARonly','ARenso','ENSO']
    column = ['Nino12','Nino3','Nino4'] # ENSO index columns used as input features
    savepath_root = '../results/analysis/Regression/riversAR/MovingAveraged/window_{}/'.format(window)
    txtsavepath = savepath_root+'ARlag_{}/'.format(lag//12)
    ## settings shared by every station/mode; per-run entries are added inside the loops
    mainparas = {}
    mainparas['num_min_years'] = num_min_years
    mainparas['min_volumes'] = min_volumes
    mainparas['lag'] = int(lag)
    mainparas['feature'] = ' '.join(column)
    mainparas['Nvalid'] = int(Nvalid)
    mainparas['Ntest'] = int(Ntest)
    mainparas['scaleback'] = scaleback
    mainparas['txtsavepath'] = str(txtsavepath)
    mainparas['verbose'] = verbose

    #%% read river flow
    riverflow_df = pd.read_csv('../data/RiverFlow/processed/riverflow.csv',index_col=0,header=0)
    info_df = pd.read_csv('../data/RiverFlow/processed/info.csv',index_col=0,header=0)
    times_df = pd.read_csv('../data/RiverFlow/processed/times.csv',index_col=0,header=0)
    ## 201901 is appended to the month list; presumably so that an exclusive stop
    ## index one past the last record can still be looked up -- TODO confirm
    times = list(np.asarray(times_df,dtype=int).reshape((-1,)))+[201901]

    #%% read ENSO index
    indices_df, _ = utils.read_enso() # 187001 to 201912
    indices_df = indices_df[column] # select input feature

    for mode in modes[0:2]:
        mainparas['mode'] = str(mode)
        for station_ind in [13]: # use range(925) to sweep over all stations
            ## determine start and stop yearmonth for data
            rivername = info_df['river_name'].iloc[station_ind]
            start_month_ind = info_df['start_month_ind'].iloc[station_ind] # data including start_month_ind
            stop_month_ind = info_df['stop_month_ind'].iloc[station_ind] # data not including stop_month_ind
            ## start from January and stop at December
            while times[start_month_ind]%100!=1:
                start_month_ind += 1
            while times[stop_month_ind]%100!=1:
                stop_month_ind -= 1
            assert (stop_month_ind-start_month_ind)%12==0, 'number of months must be a multiplier of 12'
            if stop_month_ind-start_month_ind<num_min_years*12:
                print('Error! too few data points to be considered!')
                print('Skipping {}-th river {}'.format(station_ind,rivername))
                continue

            volumes = info_df['station_volume_km3/yr'].iloc[station_ind]
            if volumes<min_volumes:
                print('Error! too small volume to be considered!')
                print('Skipping {}-th river {}'.format(station_ind,rivername))
                continue
            else:
                print('Processing {}-th river {}'.format(station_ind,rivername))

            mainparas['station_ind'] = int(station_ind)
            mainparas['rivername'] = str(rivername)
            ## NOTE(review): start_month_ind-lag can become negative for stations that start
            ## fewer than `lag` months into the record; the list index then wraps around
            ## to the end of `times` -- confirm whether that can happen for the data used
            lag_start_index,start_index,end_index = times[start_month_ind-lag],times[start_month_ind],times[stop_month_ind-1] # end_index included
            mainparas['lag_start_index'],mainparas['start_index'],mainparas['end_index'] = int(lag_start_index),int(start_index),int(end_index)

            riverflow = riverflow_df[str(station_ind)].loc[start_index:end_index]
            riverflow = np.asarray(riverflow,dtype=float).reshape((-1,12)) # one row per year
            flow_year = np.sum(riverflow,axis=1) # yearly flow, unit: m^3/s
            flow_year = flow_year*30*24*3600/1e9 # unit: km^3/year, every month counted as 30 days

            ## moving average
            if window>1:
                flow_year = np.asarray([np.mean(flow_year[i:i+window]) for i in range(len(flow_year)-window+1)])

            flow_year_anom = flow_year-np.mean(flow_year[0:30]) # anomaly based on first 30 years, unit: km^3/year

            indice = indices_df.loc[start_index:end_index]
            enso = indice.to_numpy() # [N,D]
            enso_annual = np.asarray([np.mean(enso[i:i+12],axis=0) for i in range(0,len(enso),12)]) # mean over a year

            ## the ENSO series is not averaged, only trimmed to the length of the averaged flow
            if window>1:
                enso_annual = enso_annual[window-1:]

            start_year, stop_year = start_index//100+lag//12, end_index//100+1 # stop_year not included

            ## shorter years after moving average
            start_year = start_year+window-1

            Ntotal = stop_year-start_year # number of total years available
            Ntrain = Ntotal-Nvalid-Ntest
            mainparas['Ntotal'] = int(Ntotal)
            mainparas['Ntrain'] = int(Ntrain)
            years_test_start,years_test_stop = stop_year-Ntest,stop_year
            years_all_start,years_all_stop = start_year,stop_year
            year_mid = years_test_start-0.5
            mainparas['years_all_start'],mainparas['years_test_start'],mainparas['years_all_stop'] = int(years_all_start),int(years_test_start),int(years_all_stop)
            mainparas['years_test_stop'] = int(years_test_stop)
            mainparas['year_mid'] = float(year_mid)
            ## separate data
            (X_train,y_train), (X_valid,y_valid), (X_test,y_test) = utils.separatedata(mode,flow_year_anom,enso_annual,lag=lag//12,Ntrain=Ntrain,Nvalid=Nvalid,Ntest=Ntest)
            ## scale data
            (X_train,y_train), (X_valid,y_valid), (X_test,y_test), (X_scaler,y_scaler) = utils.scaledata(X_train=X_train,y_train=y_train, X_valid=X_valid,y_valid=y_valid, X_test=X_test,y_test=y_test,scalemethod='StandardScaler')
            X_scaler_mean = X_scaler.mean_
            X_scaler_std = np.sqrt(X_scaler.var_)
            y_scaler_mean = y_scaler.mean_
            y_scaler_std = np.sqrt(y_scaler.var_)
            mainparas['X_scaler_mean'] = list(X_scaler_mean)
            mainparas['X_scaler_std'] = list(X_scaler_std)
            mainparas['y_scaler_mean'] = list(y_scaler_mean)
            mainparas['y_scaler_std'] = list(y_scaler_std)

            savepath = savepath_root+'ARlag_{}/stations/station_ind_{}/{}/'.format(lag//12,station_ind,mode)
            mainparas['savepath'] = savepath
            ## exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
            os.makedirs(savepath,exist_ok=True)

            import regressions
            hyparas,res = regressions.regressions(X_train,y_train, X_valid,y_valid, X_test,y_test, savepath, mainparas)

            configs = {**mainparas,**hyparas,**res}
            with open(savepath+'configs.txt', 'w') as file:
                file.write(json.dumps(configs,indent=0)) # use `json.loads` to do the reverse

            if verbose:
                import myplots
                myplots.plot_test_all_data(station_ind,savepath,mainparas)
|
||
|
||
## Run main() for every AR lag (listed in years, passed on in months)
## and every moving-average window.
for lag in (12*years for years in [9]):
    for window in [1]:
        main(lag=lag,window=window)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
#%% | ||
def inference(hyparas,paras):
    """Evaluate a trained ``models.YNet`` checkpoint on all ``.npz`` test files.

    Every ``.npz`` file in ``paras['test_datapath']`` is run through the model,
    predictions and targets are mapped back from the normalised range
    (``expm1(5*x)`` for ``'ppt'``, ``50*x`` for ``'tmax'``/``'tmin'``) and
    MSE/RMSE/MAE are computed over all files. Results are saved to
    ``savepath`` when it is non-empty, and a comparison figure of the last test
    sample is written when ``paras['verbose']`` is set.

    Args:
        hyparas: Keyword arguments forwarded to ``models.YNet``; must contain
            ``'use_climatology'``.
        paras: Run settings with the keys ``'savepath'``, ``'variable'``,
            ``'test_datapath'``, ``'device'``, ``'checkpoint_name'`` and
            ``'verbose'``.

    Returns:
        Tuple ``(mae, mse, rmse)``.

    Raises:
        ValueError: If ``test_datapath`` contains no ``.npz`` file.
    """
    import os
    import torch
    import numpy as np
    from skimage.transform import resize
    import models

    use_climatology = hyparas['use_climatology']
    savepath = paras['savepath']
    variable = paras['variable']
    test_datapath = paras['test_datapath']
    device = paras['device']
    checkpoint_name = paras['checkpoint_name']
    verbose = paras['verbose']
    torch.backends.cudnn.benchmark = True
    checkpoint_pathname = savepath+checkpoint_name
    filenames = sorted(f for f in os.listdir(test_datapath) if f.endswith('.npz'))

    #%% restore the model on CPU first, then move it to the requested device
    model = models.YNet(**hyparas)
    checkpoint = torch.load(checkpoint_pathname, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    if not filenames:
        # np.stack on an empty list would fail later with a far less helpful message
        raise ValueError('no .npz test files found in {}'.format(test_datapath))

    ys,y_preds = [],[]
    for filename in filenames:
        # os.path.join works with or without a trailing separator on test_datapath;
        # the context manager closes the underlying zip archive after reading
        with np.load(os.path.join(test_datapath,filename)) as data:
            X = data['gcms'] # tmax tmin:[-1,1], ppt: [0.0,1.0], [Ngcm,Nlat,Nlon]
            y = data['prism'] # [1,Nlat,Nlon], tmax/tmin:[-1.0,1.0], ppt:[0.0,1.0]
            Xaux = None
            if use_climatology:
                Xaux = np.concatenate((data['climatology'],data['elevation']),axis=0) # [2,Nlat,Nlon]

        input1 = torch.from_numpy(X[np.newaxis,...]).float() # [Ngcm,Nlat,Nlon]-->[1,Ngcm,Nlat,Nlon]
        ## second input: GCM fields resized (order=1) to the target grid
        X2 = resize(np.transpose(X,axes=(1,2,0)),y.shape[1:],order=1,preserve_range=True) # [Nlat,Nlon,Ngcm]
        X2 = np.transpose(X2,axes=(2,0,1)) # [Ngcm,Nlat,Nlon]
        input2 = torch.from_numpy(X2[np.newaxis,...]).float() # [1,Ngcm,Nlat,Nlon]
        inputs = [input1,input2]
        if Xaux is not None:
            input3 = torch.from_numpy(Xaux[np.newaxis,...]).float() # [2,Nlat,Nlon]-->[1,2,Nlat,Nlon]
            inputs.append(input3)
        inputs = [e.to(device) for e in inputs]
        with torch.no_grad():
            y_pred = model(*inputs) # [1,1,Nlat,Nlon]

        y_preds.append(np.squeeze(y_pred.cpu().detach().numpy())) # [1,1,Nlat,Nlon]-->[Nlat,Nlon]
        ys.append(np.squeeze(y)) # [1,Nlat,Nlon]-->[Nlat,Nlon]

    y_preds = np.stack(y_preds,axis=0) # [Ntest,Nlat,Nlon]
    ys = np.stack(ys,axis=0) # [Ntest,Nlat,Nlon]
    ## X still holds the GCM input of the last test file; its mean is used for plotting only
    if variable=='ppt':
        y_preds = np.expm1(y_preds*5.0) # [0.0,1.0]-->[0.0,R], unit: mm/day
        ys = np.expm1(ys*5.0) # [0.0,1.0]-->[0.0,R], unit: mm/day
        X_mean = np.mean(np.expm1(X*5.0),axis=0) # [Nlat,Nlon]
    elif variable=='tmax' or variable=='tmin':
        y_preds = y_preds*50.0 # unit: Celsius
        ys = ys*50.0 # unit: Celsius
        X_mean = np.mean(X*50.0,axis=0) # [Nlat,Nlon]
    else:
        print('Error! variable not recognized!')
        # keep going on normalised values, but define X_mean so plotting cannot hit a NameError
        X_mean = np.mean(X,axis=0) # [Nlat,Nlon]
    mse = np.mean((y_preds-ys)**2)
    rmse = np.sqrt(mse)
    mae = np.mean(np.abs(y_preds-ys))
    print('test data MSE={},\nRMSE={}\nMAE={}'.format(mse,rmse,mae))
    if savepath:
        np.savez(savepath+'pred_results_MSE{}.npz'.format(mse),y_preds=y_preds,mse=mse,rmse=rmse,mae=mae)

    #%% plot figures
    if verbose:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        ## plot the last test sample in the same (rescaled) units as X_mean and the metrics
        y_last,y_pred_last = ys[-1],y_preds[-1]
        fig,axs = plt.subplots(2,2)
        axs[0,0].imshow(y_last)
        axs[0,0].set_title('y')
        axs[1,0].imshow(y_pred_last)
        axs[1,0].set_title('Prediction')
        axs[0,1].imshow(X_mean)
        axs[0,1].set_title('Input GCM mean')
        diff_y_pred = np.abs(y_last-y_pred_last)
        diff1 = axs[1,1].imshow(diff_y_pred)
        axs[1,1].set_title('Abs(y-pred)')
        fig.colorbar(diff1,ax=axs[1,1],fraction=0.05)
        ## hide x labels and tick labels for top plots and y ticks for right plots
        for ax in axs.flat:
            ax.label_outer()
        if savepath:
            fig.savefig(savepath+'pred_vs_groundtruth.png',dpi=1200,bbox_inches='tight')
        plt.close(fig) # release the figure so repeated calls do not accumulate open figures

    return mae,mse,rmse
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Mon Feb 1 21:52:07 2021 | ||
@author: wang.zife | ||
""" | ||
|
||
#%% | ||
#import os | ||
import torch | ||
import numpy as np | ||
import torch.backends.cudnn as cudnn | ||
import models | ||
#import mydatasets | ||
|
||
#%% | ||
#def myinfernece(savepath,test_loader,lag,month,hyparas): | ||
def myinference(hyparas,paras,test_loader):
    """Evaluate a trained ``models.CAM`` checkpoint on a test loader.

    The checkpoint ``paras['savepath']+paras['checkpoint_name']`` is loaded,
    the model is run over ``test_loader`` and RMSE/MAE/RMAE are printed.
    Results are saved to ``savepath`` when it is non-empty, and two figures
    (prediction vs. ground truth, and their difference) are written when
    ``paras['verbose']`` is set.

    Args:
        hyparas: Keyword arguments forwarded to ``models.CAM``.
        paras: Run settings with the keys ``'savepath'``,
            ``'checkpoint_name'`` and ``'verbose'``.
        test_loader: Iterable yielding ``(inputs, y)`` pairs; ``inputs`` is a
            tensor or a list/tuple of tensors. Batches of any size are
            supported; samples are counted along the leading dimension of ``y``.

    Returns:
        Tuple ``(mae, mse, rmse)``.

    Raises:
        ValueError: If ``test_loader`` yields no batch.
    """
    savepath = paras['savepath']
    checkpoint_name = paras['checkpoint_name']
    checkpoint_pathname = savepath+checkpoint_name
    verbose = paras['verbose']
    cudnn.benchmark = True
    ## NOTE(review): the device is always auto-detected here and paras['device'] is
    ## not used -- confirm whether the configured device should be honoured instead
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    #%% restore the model on CPU first, then move it to the device
    model = models.CAM(**hyparas)
    checkpoint = torch.load(checkpoint_pathname, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    mse,nsamples = 0,0
    preds,ys = [],[]
    for inputs,y in test_loader:
        ## number of samples in this batch (1 for a target without batch dimension)
        batch = y.shape[0] if y.dim()>0 else 1
        with torch.no_grad():
            if isinstance(inputs,(list,tuple)):
                inputs = [e.to(device) for e in inputs]
                pred = model(*inputs)
            else:
                inputs = inputs.to(device)
                pred,_ = model(inputs) # second output is not used here
        pred = np.squeeze(pred.float().cpu().numpy())
        y = np.squeeze(y.float().cpu().numpy())
        if batch==1:
            ## squeeze dropped the batch axis; restore it so batches can be concatenated
            pred,y = pred[np.newaxis,...],y[np.newaxis,...]
        mse += np.sum((pred-y)**2)
        nsamples += batch # count samples, not batches

        preds.append(pred)
        ys.append(y)

    if nsamples==0:
        raise ValueError('test_loader yielded no samples')

    ## concatenate (not stack) so that a smaller last batch is handled as well
    preds = np.concatenate(preds,axis=0) # [Ntest,]
    ys = np.concatenate(ys,axis=0) # [Ntest,]

    mse = mse/nsamples
    rmse = np.sqrt(mse)
    mae = np.sum(np.abs(ys-preds))/nsamples
    rmae = np.sum(np.abs(ys-preds)/(np.mean(ys)+1e-5))/nsamples # MAE relative to the mean target
    print('test RMSE:{}, MAE={}, RMAE={}'.format(rmse,mae,rmae))
    ## cross-check of the accumulated RMSE against a direct computation
    mse2 = np.mean((ys-preds)**2)
    rmse2 = np.sqrt(mse2)
    print('test RMSE2:{}'.format(rmse2))

    #%%
    if savepath:
        np.savez(savepath+'pred_results_RMSE{}.npz'.format(rmse),rmse=rmse,mae=mae,rmae=rmae,preds=preds)

    #%% plot figures
    if verbose:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        fig = plt.figure()
        plt.plot(ys,'--k',label='Groundtruth')
        plt.plot(preds,'-r',label='Prediction')
        plt.title('Groundtruth vs Prediction')
        plt.xlabel('Month')
        plt.ylabel('River flow')
        plt.legend()
        if savepath:
            plt.savefig(savepath+'pred_vs_time.png',dpi=1200,bbox_inches='tight')
        plt.close(fig) # release the figure so repeated calls do not accumulate open figures

        fig = plt.figure()
        diff_y_pred = ys-preds
        plt.plot(diff_y_pred)
        plt.title('Groundtruth-prediction')
        plt.xlabel('Month')
        plt.ylabel('River flow')
        if savepath:
            plt.savefig(savepath+'diff_y_pred.png',dpi=1200,bbox_inches='tight')
        plt.close(fig)

    return mae,mse,rmse
|
||
|
Oops, something went wrong.