Direct Solve (#1)
rileyhales authored Aug 16, 2023
1 parent adf1719 commit 8012c32
Showing 2 changed files with 108 additions and 118 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -2,8 +2,8 @@ fastparquet
 networkx
 numpy
 pandas
-petsc4py
 pyyaml
 scipy
+tqdm
 xarray
 zarr
224 changes: 107 additions & 117 deletions river_route/_route_muskingum.py
@@ -10,17 +10,19 @@
import numpy as np
import pandas as pd
import scipy
import tqdm
import xarray as xr
import yaml
from petsc4py import PETSc


class RouteMuskingum:
sim_config_file: str

conf: dict

N: np.array
G: nx.DiGraph

A: np.array
c1: np.array
c2: np.array
c3: np.array
@@ -34,16 +36,13 @@ def __init__(self,
sim_config_file: str,
**kwargs, ) -> None:
"""
Read config files and prepare to execute simulation
Args:
sim_config_file (str): path to the simulation config file
Read config files to initialize routing class
"""
self.sim_config_file = sim_config_file
self.read_and_validate_configs()
self.read_configs()
return

def read_and_validate_configs(self) -> None:
def read_configs(self) -> None:
"""
Validate simulation conf
"""
@@ -56,6 +55,21 @@ def read_and_validate_configs(self) -> None:
else:
raise RuntimeError('Unrecognized simulation config file type')

# start a logger
log_basic_configs = {
'stream': sys.stdout,
'level': logging.INFO,
'format': '%(asctime)s %(levelname)s %(message)s',
}
if self.conf.get('log_file', ''):
log_basic_configs['filename'] = self.conf['log_file']
log_basic_configs['filemode'] = 'w'
log_basic_configs.pop('stream')
logging.basicConfig(**log_basic_configs)
return

def validate_configs(self) -> None:
logging.info('Validating configs file')
# check that required file paths are given and exist
required_file_paths = ['routing_params_file',
'connectivity_file',
@@ -83,18 +97,6 @@ def read_and_validate_configs(self) -> None:
if not os.path.exists(self.conf[arg]):
raise FileNotFoundError(f'{arg} not found at given path')

# start a logger
log_basic_configs = {
'stream': sys.stdout,
'level': logging.INFO,
'format': '%(asctime)s %(levelname)s %(message)s',
}
if self.conf.get('log_file', ''):
log_basic_configs['filename'] = self.conf['log_file']
log_basic_configs['filemode'] = 'w'
log_basic_configs.pop('stream')
logging.basicConfig(**log_basic_configs)

# check that time options have the correct sizes
logging.info('Validating time parameters')
assert self.conf['dt_total'] >= self.conf['dt_inflows'], 'dt_total !>= dt_inflows'
@@ -114,51 +116,54 @@ def read_and_validate_configs(self) -> None:
# set derived datetime parameters for computation cycles later
self.num_outflow_steps = int(self.conf.get('dt_total') / self.conf.get('dt_outflows'))
self.num_routing_substeps_per_outflow = int(self.conf.get('dt_outflows') / self.conf.get('dt_routing'))
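# for example (hypothetical values): dt_total=864000 s (10 days),
# dt_outflows=86400 s (1 day), and dt_routing=900 s (15 min) give
# num_outflow_steps = 10 and num_routing_substeps_per_outflow = 96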

return

def get_connectivity(self) -> pd.DataFrame:
def read_connectivity(self) -> pd.DataFrame:
"""
Reads connectivity matrix from parquet given in config file
"""
return pd.read_parquet(self.conf['connectivity_file'])

def get_riverids(self) -> np.array:
def read_riverids(self) -> np.array:
"""
Reads riverids vector from parquet given in config file
Reads river ids vector from parquet given in config file
"""
return pd.read_parquet(self.conf['routing_params_file'], columns=['id', ]).values.flatten()
return pd.read_parquet(self.conf['routing_params_file'], columns=['rivid', ]).values.flatten()

def get_k(self) -> np.array:
def read_k(self) -> np.array:
"""
Reads K vector from parquet given in config file
"""
return pd.read_parquet(self.conf['routing_params_file'], columns=['k', ]).values.flatten()

def get_x(self) -> np.array:
def read_x(self) -> np.array:
"""
Reads X vector from parquet given in config file
"""
return pd.read_parquet(self.conf['routing_params_file'], columns=['x', ]).values.flatten()

def get_qinit(self) -> np.array:
def read_qinit(self) -> np.array:
qinit = self.conf.get('qinit_file', None)
if qinit is None or qinit == '':
return np.zeros(self.N.shape[0])
return np.zeros(self.A.shape[0])
return pd.read_parquet(self.conf['qinit_file']).values.flatten()

def make_adjacency_array(self) -> None:
def make_adjacency_matrix(self) -> None:
"""
Calculate the connections array from the connectivity file
Calculate the adjacency array from the connectivity file
"""
logging.info('Calculating Network Adjacency Array (N)')
df = self.get_connectivity()
if hasattr(self, 'A') and hasattr(self, 'G'):
return

logging.info('Calculating Network Adjacency Matrix (A)')
df = self.read_connectivity()
G = nx.DiGraph()
G.add_edges_from(df[df.columns[:2]].values)
self.G = G
sorted_order = list(nx.topological_sort(G))
if -1 in sorted_order:
sorted_order.remove(-1)
self.N = scipy.sparse.csc_matrix(nx.to_numpy_array(G, nodelist=sorted_order).T)
self.A = scipy.sparse.csc_matrix(nx.to_numpy_array(G, nodelist=sorted_order).T)
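# under this construction (note the transpose above), A[i, j] is 1 when reach j
# flows directly into reach i, so A @ q maps each reach's outflow onto its
# downstream neighbor's inflow; the topological sort orders reaches
# upstream-to-downstream, and the id -1 marks terminal basin outlets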
return

def calculate_muskingum_coefficients(self) -> None:
@@ -167,8 +172,8 @@ def calculate_muskingum_coefficients(self) -> None:
"""
logging.info('Calculating Muskingum coefficients')

k = self.get_k()
x = self.get_x()
k = self.read_k()
x = self.read_x()

dt_route_half = self.conf['dt_routing'] / 2
kx = k * x
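# the lines collapsed between these diff hunks compute c1, c2, c3; in the
# common Muskingum form (an assumption here, since those lines are hidden by
# the diff view) they share the denominator d = k*(1 - x) + dt/2 and sum to 1
# by construction, e.g. (dt/2 + k*x)/d, (dt/2 - k*x)/d, (k*(1 - x) - dt/2)/d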
@@ -180,11 +185,7 @@

# sum of Muskingum coefficients should be 1 for all segments
a = self.c1 + self.c2 + self.c3
assert np.allclose(a, 1), 'Muskingum coefficients are not approximately equal to 1'

self.c1 = scipy.sparse.csc_matrix(np.diag(self.c1))
self.c2 = scipy.sparse.csc_matrix(np.diag(self.c2))
self.c3 = scipy.sparse.csc_matrix(np.diag(self.c3))
assert np.allclose(a, 1), 'Muskingum coefficients do not approximately sum to 1'
return

def define_flow_temporal_aggregation_function(self) -> None:
@@ -211,74 +212,57 @@ def define_flow_temporal_aggregation_function(self) -> None:
raise RuntimeError('Unrecognized aggregation method bypassed config file validation')

def route(self) -> None:
self.make_adjacency_array()
"""
Performs time-iterative runoff routing through the river network
"""
self.validate_configs()
self.make_adjacency_matrix()
self.calculate_muskingum_coefficients()
self.define_flow_temporal_aggregation_function()

logging.info('Reading Inflow Data')
with xr.open_dataset(self.conf['inflow_file']) as inflow_ds:
# read dates from the netcdf
dates = inflow_ds['time'].values
# read inflows from the netcdf
dates = inflow_ds['time'].values.astype('datetime64[s]')
inflows = inflow_ds['m3_riv'].values
# convert to m3/s in each routing time step
inflows = inflows / self.num_routing_substeps_per_outflow / self.conf['dt_routing']
inflows[inflows < 0] = np.nan
inflows = np.nan_to_num(inflows, nan=0.0)
inflows = inflows / self.conf['dt_inflows'] # volume to volume/time
# inflows = inflows * 2

logging.info('Scaffolding Outflow File')
self.scaffold_outflow_file(dates)

logging.info('Preparing Arrays')
lhs = scipy.sparse.csc_matrix(np.eye(self.N.shape[0]) - (self.N @ self.c1))
outflow_array = np.zeros((self.num_outflow_steps, self.N.shape[0]))
interval_flows = np.zeros((self.num_routing_substeps_per_outflow, self.N.shape[0]))
q_t = self.get_qinit()

logging.info('Init PETSc Objects and Options')
A = PETSc.Mat().createAIJ(size=lhs.shape, csr=(lhs.indptr, lhs.indices, lhs.data))
x = PETSc.Vec().createSeq(size=lhs.shape[0])
b = PETSc.Vec().createSeq(size=lhs.shape[0])
ksp = PETSc.KSP().create()
ksp.setType('bicg')
ksp.setTolerances(rtol=1e-2)
ksp.setOperators(A)
logging.info('Preparing initialization arrays')
outflow_array = np.zeros((self.num_outflow_steps, self.A.shape[0]))
interval_flows = np.zeros((self.num_routing_substeps_per_outflow, self.A.shape[0]))
q_t = self.read_qinit()
q_ro = np.zeros(self.A.shape[0])
inflow_t = (self.A @ q_t) + q_ro

logging.info('Inverting LHS Matrix')
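# direct solve: each reach's new outflow depends on its upstream neighbors'
# new outflows, i.e. (I - C2·A) @ q_new = c1*inflow_old + c2*q_ro + c3*q_old,
# so the left-hand side is inverted once here and reused in every substep
# below, replacing the per-step iterative PETSc KSP solve removed by this commit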
lhs = scipy.sparse.linalg.inv(
scipy.sparse.csc_matrix(np.eye(self.A.shape[0])) - scipy.sparse.csc_matrix(np.diag(self.c2)) @ self.A
)

logging.info('Performing routing computation iterations')
time_start = datetime.datetime.now()
for inflow_time_step, inflow_end_date in enumerate(dates):
t1 = datetime.datetime.now()
for inflow_time_step, inflow_end_date in enumerate(tqdm.tqdm(dates, desc='Inflows Routed')):
q_ro = inflows[inflow_time_step, :]
interval_flows = np.zeros((self.num_routing_substeps_per_outflow, self.N.shape[0]))
c1_matmul_q_ro = self.c1 @ q_ro

interval_flows = np.zeros((self.num_routing_substeps_per_outflow, self.A.shape[0]))
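# march the Muskingum update num_routing_substeps_per_outflow times within
# this inflow interval; inflow_tnext is evaluated from q_t before the update
# so the pre-update flows provide the lagged inflow term for the next iteration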
for routing_substep_iteration in range(self.num_routing_substeps_per_outflow):
# solve the right hand side of the equation
rhs = c1_matmul_q_ro + \
(self.c2 @ (self.N @ q_t + q_ro)) + \
(self.c3 @ q_t)

# set current iteration values in PETSc objects and solve
b.setArray(rhs)
x.setArray(q_t)
ksp.solve(b, x)
q_t = x.getArray()

# remove negatives before other iterations
q_t[q_t < 0] = 0
inflow_tnext = (self.A @ q_t) + q_ro
q_t = lhs @ ((self.c1 * inflow_t) + (self.c2 * q_ro) + (self.c3 * q_t))
interval_flows[routing_substep_iteration, :] = q_t

interval_flows = self.flow_temporal_aggregation_function(np.array(interval_flows), axis=0)
inflow_t = inflow_tnext
interval_flows = np.mean(np.array(interval_flows), axis=0)
interval_flows = np.round(interval_flows, decimals=2)
outflow_array[inflow_time_step, :] = interval_flows
time_end = datetime.datetime.now()
logging.info(f'Routing completed in {(time_end - time_start).total_seconds()} seconds')

logging.info('Cleaning up PETSc objects')
A.destroy()
x.destroy()
b.destroy()
ksp.destroy()
t2 = datetime.datetime.now()
logging.info(f'Routing completed in {(t2 - t1).total_seconds()} seconds')

logging.info('Writing Outflow Array to File')
outflow_array = np.round(outflow_array, decimals=2)
self.write_outflows(outflow_array)
self.write_outflows(dates, outflow_array)

# write the final outflows to disc
if self.conf.get('qfinal_file', False):
@@ -289,48 +273,54 @@ def route(self) -> None:
return q_t

def scaffold_outflow_file(self, dates) -> None:
xr.Dataset(
data_vars={
'Qout': (['time', 'rivid'], np.zeros((self.num_outflow_steps, self.N.shape[0]))),
},
coords={
'time': dates,
'rivid': self.get_riverids(),
},
attrs={
'long_name': 'Discharge at the outlet of each river reach',
'units': 'm3 s-1',
'standard_name': 'discharge',
'aggregation_method': 'mean',
},
).to_netcdf(
self.conf['outflow_file'],
mode='w',
)
return
time_zero = datetime.datetime.utcfromtimestamp(dates[0].astype(int))
with nc.Dataset(self.conf['outflow_file'], mode='w') as ds:
ds.createDimension('time', size=dates.shape[0])
ds.createDimension('rivid', size=self.A.shape[0])

time_var = ds.createVariable('time', 'f8', ('time',))
time_var.units = f'seconds since {time_zero.strftime("%Y-%m-%d %H:%M:%S")}'
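# CF-style relative time units; write_outflows later encodes the dates
# against this same attribute via nc.date2num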

rivid_var = ds.createVariable('rivid', 'i4', ('rivid',))
rivid_var[:] = self.read_riverids()

def write_outflows(self, outflow_array) -> None:
qout_var = ds.createVariable('Qout', 'f4', ('time', 'rivid'))
qout_var.units = 'm3 s-1'
qout_var.long_name = 'Discharge at the outlet of each river reach'
qout_var.standard_name = 'discharge'
qout_var.aggregation_method = self.conf.get('aggregation_method', 'mean')

def write_outflows(self, dates, outflow_array) -> None:
pydates = list(map(datetime.datetime.utcfromtimestamp, dates.astype(int)))
with nc.Dataset(self.conf['outflow_file'], mode='a') as ds:
ds['time'][:] = nc.date2num(pydates, units=ds['time'].units)
ds['Qout'][:] = outflow_array
ds.sync()
return

def append_outflows(self, dates, date_index, outflow_array) -> None:
python_dates = list(map(datetime.datetime.utcfromtimestamp, dates.astype(int)))[date_index]
with nc.Dataset(self.conf['outflow_file'], mode='a') as ds:
ds['time'][:] = nc.date2num(python_dates, ds['time'].units)
ds['Qout'][:] = outflow_array

def plot(self, rivid: int) -> None:
rivid_index = np.where(self.get_riverids() == rivid)[0][0]
with xr.open_dataset(self.conf['outflow_file']) as ds:
ds['Qout'][:, rivid_index].to_dataframe()['Qout'].plot()
ds['Qout'].sel(rivid=rivid).to_dataframe()['Qout'].plot()
plt.show()
return

def mass_balance(self) -> None:
outlet_ids = self.get_connectivity().values
outlet_ids = outlet_ids[outlet_ids[:, 1] == -1, 0]
def mass_balance(self, rivid: int) -> None:
self.validate_configs()
self.make_adjacency_matrix()

upstream_ids = nx.ancestors(self.G, rivid)

with xr.open_dataset(self.conf['outflow_file']) as ds:
out_df = ds.sel(rivid=outlet_ids).to_dataframe()[['Qout', ]].groupby('time').sum().cumsum()
out_df = ds.sel(rivid=rivid).to_dataframe()[['Qout', ]].groupby('time').sum().cumsum()
out_df = out_df * self.conf['dt_routing'] * self.num_routing_substeps_per_outflow
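# cumulative discharge (m3/s) times the seconds per outflow step
# (dt_routing * substeps = dt_outflows) gives cumulative outflow volume in
# m3, comparable to the cumulative upstream runoff volumes summed below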
with xr.open_dataset(self.conf['inflow_file']) as ds:
in_df = ds.sel(rivid=outlet_ids).to_dataframe()[['m3_riv', ]].groupby('time').sum().cumsum()
in_df = ds.sel(rivid=list(upstream_ids)).to_dataframe()[['m3_riv', ]].groupby('time').sum().cumsum()

df = out_df.merge(in_df, left_index=True, right_index=True)
logging.info(f'\n{df.sum()}')
