Skip to content

Commit

Permalink
Added the changes regarding regularization that were already present …
Browse files Browse the repository at this point in the history
…in the master (v2.3.1)
  • Loading branch information
ilariagabusi committed Oct 15, 2024
1 parent 8b07d45 commit 5b7f03f
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 85 deletions.
89 changes: 48 additions & 41 deletions commit/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ cdef class Evaluation :
cdef public A
cdef public regularisation_params
cdef public x
cdef public x_nnls
cdef public CONFIG
cdef public temp_data
cdef public confidence_map_img
Expand All @@ -105,6 +106,7 @@ cdef class Evaluation :
self.regularisation_params = None # set by "set_regularisation" method
self.x = None # set by "fit" method
self.confidence_map_img = None # set by "fit" method
self.x_nnls = None # set by "fit" method (coefficients of IC compartment estimated without regularization)
self.debias_mask = None # set by "fit" method
self.verbose = 3

Expand Down Expand Up @@ -284,7 +286,7 @@ cdef class Evaluation :
self.set_config('ATOMS_path', pjoin( self.get_config('study_path'), 'kernels', self.model.id ))


def generate_kernels( self, regenerate=False, lmax=12, ndirs=500):
def generate_kernels( self, regenerate=False, lmax=12, ndirs=500 ) :
"""Generate the high-resolution response functions for each compartment.
Dispatch to the proper function, depending on the model.
Expand Down Expand Up @@ -848,7 +850,7 @@ cdef class Evaluation :
This field can be specified only if regularisers[0] is 'group_lasso' or 'sparse_group_lasso'.
NB: this array must have the same size as the number of groups in the IC compartment and contain only non-negative values.
'coeff_weights' - np.array(np.float64) :
weights associated to each individual element of the compartment (implemented for all compartments).
weights associated to each individual element of the compartment (for the moment implemented only for IC compartment).
This field can be specified only if the chosen regulariser is 'lasso' or 'sparse_group_lasso'.
NB: this array must have the same size as the number of elements in the compartment and contain only non-negative values.
Expand Down Expand Up @@ -886,10 +888,8 @@ cdef class Evaluation :
regularisation['sizeIC'] = int( self.DICTIONARY['IC']['nF'] * self.KERNELS['wmr'].shape[0] * self.KERNELS['wmc'].shape[0])
regularisation['startEC'] = int( regularisation['sizeIC'] )
regularisation['sizeEC'] = int( self.DICTIONARY['EC']['nE'] * self.KERNELS['wmh'].shape[0] )
regularisation['sizeEC'] = int( self.DICTIONARY['EC']['nE'] * self.KERNELS['wmh'].shape[0] )
regularisation['startISO'] = int( regularisation['sizeIC'] + regularisation['sizeEC'] )
regularisation['sizeISO'] = int( self.DICTIONARY['nV'] * self.KERNELS['iso'].shape[0] )
regularisation['sizeISO'] = int( self.DICTIONARY['nV'] * self.KERNELS['iso'].shape[0] )

regularisation['regIC'] = regularisers[0]
regularisation['regEC'] = regularisers[1]
Expand Down Expand Up @@ -962,7 +962,7 @@ cdef class Evaluation :
logger.error('All coefficients weights must be non-negative')
if dictIC_params['coeff_weights'].size != len(self.DICTIONARY['TRK']['kept']):
logger.error(f'"coeff_weights" must have the same size as the number of elements in the IC compartment (got {dictIC_params["coeff_weights"].size} but {len(self.DICTIONARY["TRK"]["kept"])} expected)')
dictIC_params['coeff_weights'] = dictIC_params['coeff_weights'][self.DICTIONARY['TRK']['kept']==1]
dictIC_params['coeff_weights_kept'] = dictIC_params['coeff_weights'][self.DICTIONARY['TRK']['kept']==1]

# check if group parameters are consistent with the regularisation
if regularisation['regIC'] not in ['group_lasso', 'sparse_group_lasso'] and dictIC_params is not None:
Expand Down Expand Up @@ -1006,6 +1006,9 @@ cdef class Evaluation :

# check if group_weights_extra is consistent with the number of groups
if (regularisation['regIC'] == 'group_lasso' or regularisation['regIC'] == 'sparse_group_lasso') and 'group_weights_extra' in dictIC_params:
if type(dictIC_params['group_weights_extra']) not in [list, np.ndarray]:
logger.error('"group_weights_extra" must be a list or a numpy array')
dictIC_params['group_weights_extra'] = np.array(dictIC_params['group_weights_extra'], dtype=np.float64)
if np.any(dictIC_params['group_weights_extra'] < 0):
logger.error('All group weights must be non-negative')
if dictIC_params['group_weights_extra'].size != dictIC_params['group_idx'].size:
Expand Down Expand Up @@ -1039,46 +1042,50 @@ cdef class Evaluation :
newweightsIC_group.append(weightsIC_group[count])

newweightsIC_group = np.array(newweightsIC_group, dtype=np.float64)
dictIC_params['group_idx'] = np.array(newICgroup_idx, dtype=np.object_)
dictIC_params['group_idx_kept'] = np.array(newICgroup_idx, dtype=np.object_)
if weightsIC_group.size != newweightsIC_group.size:
logger.warning(f"""\
Not all the original groups are kept.
{weightsIC_group.size - newweightsIC_group.size} groups have been removed because their streamlines didn't satify the criteria set in trk2dictionary.""")
{weightsIC_group.size - newweightsIC_group.size} groups have been removed because their streamlines didn't satisfy the criteria set in trk2dictionary.""")
else:
newweightsIC_group = weightsIC_group
dictIC_params['group_idx_kept'] = dictIC_params['group_idx']

# compute group weights
if regularisation['regIC'] == 'group_lasso' or regularisation['regIC'] == 'sparse_group_lasso':
if dictIC_params['group_weights_cardinality']:
group_size = np.array([g.size for g in dictIC_params['group_idx']], dtype=np.int32)
group_size = np.array([g.size for g in dictIC_params['group_idx_kept']], dtype=np.int32)
newweightsIC_group *= np.sqrt(group_size)
if dictIC_params['group_weights_adaptive']:
if self.x is None or self.regularisation_params['regIC'] is not None:
if self.x_nnls is None: #or self.regularisation_params['regIC'] is not None:
logger.error('Group weights cannot be computed if the fit without regularisation has not been performed before')
x_nnls, _, _ = self.get_coeffs(get_normalized=False)
group_x_norm = np.array([np.linalg.norm(x_nnls[g])+1e-12 for g in dictIC_params['group_idx']], dtype=np.float64)
# x_nnls, _, _ = self.get_coeffs(get_normalized=False)
group_x_norm = np.array([np.linalg.norm(self.x_nnls[g])+1e-12 for g in dictIC_params['group_idx_kept']], dtype=np.float64)
newweightsIC_group /= group_x_norm
dictIC_params['group_weights'] = newweightsIC_group

regularisation['dictIC_params'] = dictIC_params

# update lambdas using lambda_max
if regularisation['regIC'] == 'lasso':
if dictIC_params is not None and 'coeff_weights' in dictIC_params:
regularisation['lambdaIC_max'] = compute_lambda_max_lasso(regularisation['startIC'], regularisation['sizeIC'], dictIC_params['coeff_weights'])
if dictIC_params is not None and 'coeff_weights_kept' in dictIC_params:
regularisation['lambdaIC_max'] = compute_lambda_max_lasso(regularisation['startIC'], regularisation['sizeIC'], dictIC_params['coeff_weights_kept'])
else:
regularisation['lambdaIC_max'] = compute_lambda_max_lasso(regularisation['startIC'], regularisation['sizeIC'], np.ones(regularisation['sizeIC'], dtype=np.float64))
regularisation['lambdaIC'] = regularisation['lambdaIC_perc'] * regularisation['lambdaIC_max']
if regularisation['regIC'] == 'group_lasso':
regularisation['lambdaIC_max'] = compute_lambda_max_group(dictIC_params['group_weights'], dictIC_params['group_idx'])
regularisation['lambdaIC_max'] = compute_lambda_max_group(dictIC_params['group_weights'], dictIC_params['group_idx_kept'])
regularisation['lambdaIC'] = regularisation['lambdaIC_perc'] * regularisation['lambdaIC_max']
if regularisation['regIC'] == 'sparse_group_lasso':
regularisation['lambdaIC_max'] = ( compute_lambda_max_lasso(regularisation['startIC'], regularisation['sizeIC']), compute_lambda_max_group(dictIC_params['group_weights'], dictIC_params['group_idx']) )
if 'coeff_weights_kept' in dictIC_params:
regularisation['lambdaIC_max'] = ( compute_lambda_max_lasso(regularisation['startIC'], regularisation['sizeIC'], dictIC_params['coeff_weights_kept']), compute_lambda_max_group(dictIC_params['group_weights'], dictIC_params['group_idx_kept']) )
else:
regularisation['lambdaIC_max'] = ( compute_lambda_max_lasso(regularisation['startIC'], regularisation['sizeIC'], np.ones(regularisation['sizeIC'], dtype=np.float64)), compute_lambda_max_group(dictIC_params['group_weights'], dictIC_params['group_idx_kept']) )
regularisation['lambdaIC'] = ( regularisation['lambdaIC_perc'][0] * regularisation['lambdaIC_max'][0], regularisation['lambdaIC_perc'][1] * regularisation['lambdaIC_max'][1] )

# print
if regularisation['regIC'] is not None:
if (regularisation['regIC'] == 'lasso' or regularisation['regIC'] == 'sparse_group_lasso') and dictIC_params is not None and 'coeff_weights' in dictIC_params:
if (regularisation['regIC'] == 'lasso' or regularisation['regIC'] == 'sparse_group_lasso') and dictIC_params is not None and 'coeff_weights_kept' in dictIC_params:
logger.subinfo( f'Regularisation type: {regularisation["regIC"]} (weighted version)', indent_lvl=2, indent_char='-' )
else:
logger.subinfo( f'Regularisation type: {regularisation["regIC"]}', indent_lvl=2, indent_char='-' )
Expand All @@ -1089,8 +1096,8 @@ cdef class Evaluation :
logger.debug( f'% lambda: {regularisation["lambdaIC_perc"]}' )
logger.debug( f'Lambda used: {regularisation["lambdaIC"]}' )
if regularisation['regIC'] == 'group_lasso' or regularisation['regIC'] == 'sparse_group_lasso':
logger.debug( f'Number of groups: {len(dictIC_params["group_idx"])}' )
if dictIC_params['group_weights_cardinality']==False and dictIC_params['group_weights_adaptive']==False and dictIC_params['group_weights_extra'] is None:
logger.debug( f'Number of groups: {len(dictIC_params["group_idx_kept"])}' )
if dictIC_params['group_weights_cardinality']==False and dictIC_params['group_weights_adaptive']==False and not ('group_weights_extra' in dictIC_params):
logger.debug( 'Group weights are not considered (all ones)' )
else:
str_weights = 'Group weights computed using '
Expand Down Expand Up @@ -1120,9 +1127,9 @@ cdef class Evaluation :
regularisation['lambdaEC_perc'] = lambdas[1]
else:
regularisation['lambdaEC_perc'] = lambdas[1]
if dictEC_params is not None and 'coeff_weights' in dictEC_params:
if dictEC_params['coeff_weights'].size != regularisation['sizeEC']:
logger.error(f'"coeff_weights" must have the same size as the number of elements in the EC compartment (got {dictEC_params["coeff_weights"].size} but {regularisation["sizeEC"]} expected)')
# if dictEC_params is not None and 'coeff_weights' in dictEC_params:
# if dictEC_params['coeff_weights'].size != regularisation['sizeEC']:
# logger.error(f'"coeff_weights" must have the same size as the number of elements in the EC compartment (got {dictEC_params["coeff_weights"].size} but {regularisation["sizeEC"]} expected)')
elif regularisation['regEC'] == 'smoothness':
logger.error('Not yet implemented')
elif regularisation['regEC'] == 'group_lasso':
Expand All @@ -1136,18 +1143,18 @@ cdef class Evaluation :

# update lambdas using lambda_max
if regularisation['regEC'] == 'lasso':
if dictEC_params is not None and 'coeff_weights' in dictEC_params:
regularisation['lambdaEC_max'] = compute_lambda_max_lasso(regularisation['startEC'], regularisation['sizeEC'], dictEC_params['coeff_weights'])
else:
regularisation['lambdaEC_max'] = compute_lambda_max_lasso(regularisation['startEC'], regularisation['sizeEC'], np.ones(regularisation['sizeEC'], dtype=np.float64))
# if dictEC_params is not None and 'coeff_weights' in dictEC_params:
# regularisation['lambdaEC_max'] = compute_lambda_max_lasso(regularisation['startEC'], regularisation['sizeEC'], dictEC_params['coeff_weights'])
# else:
regularisation['lambdaEC_max'] = compute_lambda_max_lasso(regularisation['startEC'], regularisation['sizeEC'], np.ones(regularisation['sizeEC'], dtype=np.float64))
regularisation['lambdaEC'] = regularisation['lambdaEC_perc'] * regularisation['lambdaEC_max']

# print
if regularisation['regEC'] is not None:
if regularisation['regEC'] == 'lasso' and dictEC_params is not None and 'coeff_weights' in dictEC_params:
logger.subinfo( f'Regularisation type: {regularisation["regEC"]} (weighted version)', indent_lvl=2, indent_char='-' )
else:
logger.subinfo( f'Regularisation type: {regularisation["regEC"]}', indent_lvl=2, indent_char='-' )
# if regularisation['regEC'] == 'lasso' and dictEC_params is not None and 'coeff_weights' in dictEC_params:
# logger.subinfo( f'Regularisation type: {regularisation["regEC"]} (weighted version)', indent_lvl=2, indent_char='-' )
# else:
logger.subinfo( f'Regularisation type: {regularisation["regEC"]}', indent_lvl=2, indent_char='-' )

logger.subinfo( f'Non-negativity constraint: {regularisation["nnEC"]}', indent_char='-', indent_lvl=2 )

Expand Down Expand Up @@ -1175,9 +1182,9 @@ cdef class Evaluation :
regularisation['lambdaISO_perc'] = lambdas[2]
else:
regularisation['lambdaISO_perc'] = lambdas[2]
if dictISO_params is not None and 'coeff_weights' in dictISO_params:
if dictISO_params['coeff_weights'].size != regularisation['sizeISO']:
logger.error(f'"coeff_weights" must have the same size as the number of elements in the ISO compartment (got {dictISO_params["coeff_weights"].size} but {regularisation["sizeISO"]} expected)')
# if dictISO_params is not None and 'coeff_weights' in dictISO_params:
# if dictISO_params['coeff_weights'].size != regularisation['sizeISO']:
# logger.error(f'"coeff_weights" must have the same size as the number of elements in the ISO compartment (got {dictISO_params["coeff_weights"].size} but {regularisation["sizeISO"]} expected)')
elif regularisation['regISO'] == 'smoothness':
logger.error('Not yet implemented')
elif regularisation['regISO'] == 'group_lasso':
Expand All @@ -1191,18 +1198,18 @@ cdef class Evaluation :

# update lambdas using lambda_max
if regularisation['regISO'] == 'lasso':
if dictISO_params is not None and 'coeff_weights' in dictISO_params:
regularisation['lambdaISO_max'] = compute_lambda_max_lasso(regularisation['startISO'], regularisation['sizeISO'], dictISO_params['coeff_weights'])
else:
regularisation['lambdaISO_max'] = compute_lambda_max_lasso(regularisation['startISO'], regularisation['sizeISO'], np.ones(regularisation['sizeISO'], dtype=np.float64))
# if dictISO_params is not None and 'coeff_weights' in dictISO_params:
# regularisation['lambdaISO_max'] = compute_lambda_max_lasso(regularisation['startISO'], regularisation['sizeISO'], dictISO_params['coeff_weights'])
# else:
regularisation['lambdaISO_max'] = compute_lambda_max_lasso(regularisation['startISO'], regularisation['sizeISO'], np.ones(regularisation['sizeISO'], dtype=np.float64))
regularisation['lambdaISO'] = regularisation['lambdaISO_perc'] * regularisation['lambdaISO_max']

# print
if regularisation['regISO'] is not None:
if regularisation['regISO'] == 'lasso' and dictISO_params is not None and 'coeff_weights' in dictISO_params:
logger.subinfo( f'Regularisation type: {regularisation["regISO"]} (weighted version)', indent_lvl=2, indent_char='-' )
else:
logger.subinfo( f'Regularisation type: {regularisation["regISO"]}', indent_lvl=2, indent_char='-' )
# if regularisation['regISO'] == 'lasso' and dictISO_params is not None and 'coeff_weights' in dictISO_params:
# logger.subinfo( f'Regularisation type: {regularisation["regISO"]} (weighted version)', indent_lvl=2, indent_char='-' )
# else:
logger.subinfo( f'Regularisation type: {regularisation["regISO"]}', indent_lvl=2, indent_char='-' )

logger.subinfo( f'Non-negativity constraint: {regularisation["nnISO"]}', indent_char='-', indent_lvl=2 )
if regularisation['regISO'] is not None:
Expand Down Expand Up @@ -1901,7 +1908,7 @@ cdef class Evaluation :
with open( pjoin(RESULTS_path,'results.pickle'), 'wb+' ) as fid :
self.CONFIG['optimization']['regularisation'].pop('omega', None)
self.CONFIG['optimization']['regularisation'].pop('prox', None)
pickle.dump( [self.CONFIG, x, self.x, rmse], fid, protocol=2 )
pickle.dump( [self.CONFIG, self.x, x, rmse], fid, protocol=2 )

if save_est_dwi :
logger.subinfo('Estimated signal:', indent_char='-', indent_lvl=2, with_progress=True)
Expand Down
Loading

0 comments on commit 5b7f03f

Please sign in to comment.