Skip to content

Commit

Permalink
detect bug cholesky failure and call dgels
Browse files Browse the repository at this point in the history
  • Loading branch information
ShadenSmith committed Oct 20, 2016
1 parent 399c431 commit 17c26e9
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 41 deletions.
3 changes: 2 additions & 1 deletion src/cpd.c
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,8 @@ double cpd_als_iterate(
}
}
}
if(it > 0 && fabs(fit - oldfit) < opts[SPLATT_OPTION_TOLERANCE]) {
if(fit == 1. ||
(it > 0 && fabs(fit - oldfit) < opts[SPLATT_OPTION_TOLERANCE])) {
break;
}
oldfit = fit;
Expand Down
160 changes: 120 additions & 40 deletions src/matrix.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@
void sgetrf_(int *, int *, float *, int *, int *, int *);
void sgetrs_(char *, int *, int *, float *, int *, int *, float *, int *, int *);

/* least-squares solve (single-precision counterpart of dgels_) */
void sgels_(char *, int *, int *, int *, float *, int *, float *, int *, float *, int *, int *);


#define LAPACK_DPOTRF spotrf_
#define LAPACK_DPOTRS spotrs_
#define LAPACK_DSYRK ssyrk_
#define LAPACK_DGETRF sgetrf_
#define LAPACK_DGETRS sgetrs_
#define LAPACK_DGELS sgels_
#else
void dpotrf_(char *, int *, double *, int *, int *);
void dpotrs_(char *, int *, int *, double *, int *, double *, int *, int *);
Expand All @@ -33,11 +37,15 @@
void dgetrf_(int *, int *, double *, int *, int *, int *);
void dgetrs_(char *, int *, int *, double *, int *, int *, double *, int *, int *);

/* QR solve */
void dgels_(char *, int *, int *, int *, double *, int *, double *, int *, double *, int *, int *);

#define LAPACK_DPOTRF dpotrf_
#define LAPACK_DPOTRS dpotrs_
#define LAPACK_DSYRK dsyrk_
#define LAPACK_DGETRF dgetrf_
#define LAPACK_DGETRS dgetrs_
#define LAPACK_DGELS dgels_
#endif


Expand All @@ -46,6 +54,74 @@
* PRIVATE FUNCTIONS
*****************************************************************************/


/**
 * @brief Form the regularized Gram matrix of the normal equations.
 *
 *        The result is the element-wise (Hadamard) product of every aTa[m]
 *        with m != mode, with `reg` added to each diagonal entry afterwards.
 *        Both triangles of the N x N output are filled, where N = aTa[0]->J.
 *
 * @param[out] neq_matrix The matrix to fill. All N*N entries of its `vals`
 *                        array are overwritten.
 * @param aTa The individual Gram matrices. Only the entries at offsets
 *            j + (i*N) with j >= i are read from each of them.
 * @param mode Which mode we are computing for (it is left out of the product).
 * @param nmodes How many total modes.
 * @param reg Regularization parameter (to add to the diagonal).
 */
static void p_form_gram(
    matrix_t * neq_matrix,
    matrix_t * * aTa,
    idx_t const mode,
    idx_t const nmodes,
    val_t const reg)
{
  /* nfactors */
  int N = aTa[0]->J;

  /* form upper-triangular normal equations, then mirror them */
  val_t * const restrict neqs = neq_matrix->vals;
  #pragma omp parallel
  {
    /*
     * First initialize every entry with 1, the identity of the Hadamard
     * product. The regularization is NOT applied here: writing `1. + reg` to
     * the diagonal before this fill would simply be overwritten, and scaling
     * it through the product would not be an additive shift anyway.
     */
    #pragma omp for schedule(static, 1)
    for(int i=0; i < N; ++i) {
      for(int j=0; j < N; ++j) {
        neqs[j+(i*N)] = 1.;
      }
    }

    /* now Hadamard product all (A^T * A) matrices */
    for(idx_t m=0; m < nmodes; ++m) {
      if(m == mode) {
        continue;
      }

      val_t const * const restrict mat = aTa[m]->vals;
      #pragma omp for schedule(static, 1)
      for(int i=0; i < N; ++i) {
        /*
         * `mat` is symmetric but stored upper right triangular, so be careful
         * to only access that.
         */
        for(int j=i; j < N; ++j) {
          neqs[j+(i*N)] *= mat[j+(i*N)];
        }
      }
    } /* foreach mode */

    /* all products must be complete before they are mirrored below */
    #pragma omp barrier

    /* now copy lower triangular and regularize the diagonal */
    #pragma omp for schedule(static, 1)
    for(int i=0; i < N; ++i) {
      for(int j=0; j < i; ++j) {
        neqs[j+(i*N)] = neqs[i+(j*N)];
      }

      /*
       * Row i only reads strictly off-diagonal entries of other rows, so
       * updating its own diagonal here cannot race with another iteration.
       */
      neqs[i+(i*N)] += reg;
    }
  } /* omp parallel */
}



static void p_mat_2norm(
matrix_t * const A,
val_t * const restrict lambda,
Expand Down Expand Up @@ -484,6 +560,7 @@ void mat_normalize(
}



void mat_solve_normals(
idx_t const mode,
idx_t const nmodes,
Expand All @@ -496,55 +573,58 @@ void mat_solve_normals(
/* nfactors */
int N = aTa[0]->J;

/* form upper-triangual normal equations */
val_t * const restrict neqs = aTa[MAX_NMODES]->vals;
#pragma omp parallel
{
/* first initialize */
#pragma omp for schedule(static, 1)
for(int i=0; i < N; ++i) {
neqs[i+(i*N)] = 1. + reg;
for(int j=0; j < N; ++j) {
neqs[j+(i*N)] = 1.;
}
}
p_form_gram(aTa[MAX_NMODES], aTa, mode, nmodes, reg);

for(idx_t m=0; m < nmodes; ++m) {
if(m == mode) {
continue;
}
int info;
char uplo = 'L';
int lda = N;
int ldb = N;
int order = N;
int nrhs = (int) rhs->I;

val_t const * const restrict mat = aTa[m]->vals;
#pragma omp for schedule(static, 1) nowait
for(int i=0; i < N; ++i) {
for(int j=0; j < N; ++j) {
neqs[j+(i*N)] *= mat[j+(i*N)];
}
}
}
} /* omp parallel */
val_t * const neqs = aTa[MAX_NMODES]->vals;

bool is_spd = true;

/* LU factorization */
int * ipiv = splatt_malloc(N * sizeof(*ipiv));
int info;
LAPACK_DGETRF(&N, &N, neqs, &N, ipiv, &info);
/* Cholesky factorization */
LAPACK_DPOTRF(&uplo, &order, neqs, &lda, &info);
if(info) {
fprintf(stderr, "SPLATT: DGETRF returned %d\n", info);
fprintf(stderr, "SPLATT: Gram matrix is not SPD. Trying *GELS().\n");
is_spd = false;
}


/* solve system of equations */
int nrhs = rhs->I;
char trans = 'N';
LAPACK_DGETRS(&trans, &N, &nrhs, neqs, &N, ipiv, rhs->vals, &N, &info);
if(info) {
fprintf(stderr, "SPLATT: DGETRS returned %d\n", info);
/* Continue with Cholesky */
if(is_spd) {
/* Solve against rhs */
LAPACK_DPOTRS(&uplo, &order, &nrhs, neqs, &lda, rhs->vals, &ldb, &info);
if(info) {
fprintf(stderr, "SPLATT: DPOTRS returned %d\n", info);
}
} else {
/* restore gram matrix */
p_form_gram(aTa[MAX_NMODES], aTa, mode, nmodes, reg);

char trans = 'N';

/* query worksize */
int lwork = -1;
val_t work_query;
LAPACK_DGELS(&trans, &N, &N, &nrhs, neqs, &lda, rhs->vals, &ldb, &work_query,
&lwork, &info);
lwork = (int) work_query;

/* setup workspace */
val_t * work = splatt_malloc(lwork * sizeof(*work));

/* Use a LS solver */
LAPACK_DGELS(&trans, &N, &N, &nrhs, neqs, &lda, rhs->vals, &ldb, work, &lwork,
&info);
if(info) {
printf("SPLATT: DGELS returned %d\n", info);
}
splatt_free(work);
}

/* cleanup pivot */
splatt_free(ipiv);

timer_stop(&timers[TIMER_INV]);
}

Expand Down

0 comments on commit 17c26e9

Please sign in to comment.