diff --git a/src/cpd.c b/src/cpd.c index 456b18b..6b709cc 100644 --- a/src/cpd.c +++ b/src/cpd.c @@ -355,7 +355,8 @@ double cpd_als_iterate( } } } - if(it > 0 && fabs(fit - oldfit) < opts[SPLATT_OPTION_TOLERANCE]) { + if(fit == 1. || /* stop when fit is exactly 1, or (after the first iteration) when the change in fit is below the tolerance option */ + (it > 0 && fabs(fit - oldfit) < opts[SPLATT_OPTION_TOLERANCE])) { break; } oldfit = fit; diff --git a/src/matrix.c b/src/matrix.c index fec9ce1..452660d 100644 --- a/src/matrix.c +++ b/src/matrix.c @@ -19,11 +19,15 @@ void sgetrf_(int *, int *, float *, int *, int *, int *); void sgetrs_(char *, int *, int *, float *, int *, int *, float *, int *, int *); + void sgels_(char *, int *, int *, int *, float *, int *, float *, int *, float *, int *, int *); /* float-typed counterpart of the dgels_ prototype declared in the other branch */ + + #define LAPACK_DPOTRF spotrf_ #define LAPACK_DPOTRS spotrs_ #define LAPACK_DSYRK ssyrk_ #define LAPACK_DGETRF sgetrf_ #define LAPACK_DGETRS sgetrs_ + #define LAPACK_DGELS sgels_ #else void dpotrf_(char *, int *, double *, int *, int *); void dpotrs_(char *, int *, int *, double *, int *, double *, int *, int *); @@ -33,11 +37,15 @@ void dgetrf_(int *, int *, double *, int *, int *, int *); void dgetrs_(char *, int *, int *, double *, int *, int *, double *, int *, int *); + /* QR solve */ + void dgels_(char *, int *, int *, int *, double *, int *, double *, int *, double *, int *, int *); + #define LAPACK_DPOTRF dpotrf_ #define LAPACK_DPOTRS dpotrs_ #define LAPACK_DSYRK dsyrk_ #define LAPACK_DGETRF dgetrf_ #define LAPACK_DGETRS dgetrs_ + #define LAPACK_DGELS dgels_ #endif @@ -46,6 +54,74 @@ * PRIVATE FUNCTIONS *****************************************************************************/ + +/** +* @brief Form the Gram matrix from A^T * A: the element-wise (Hadamard) product of aTa[m] over every m != mode, written out as a full N x N matrix. +* +* @param[out] neq_matrix The matrix to fill. +* @param aTa The individual Gram matrices (aTa[0]->J gives N; only entries with j >= i are read from each). +* @param mode Which mode we are computing for (this mode is skipped in the product). +* @param nmodes How many total modes. +* @param reg Regularization parameter (to add to the diagonal). NOTE(review): the initialization loop stores 1. + reg on the diagonal and then overwrites the whole row, diagonal included, with 1., so reg does not appear to reach the output -- confirm intent. 
+*/ +static void p_form_gram( + matrix_t * neq_matrix, + matrix_t * * aTa, + idx_t const mode, + idx_t const nmodes, + val_t const reg) +{ + /* nfactors: the dimension N of the N x N matrices, read from aTa[0]->J */ + int N = aTa[0]->J; + + /* form the normal equations: entries with j >= i first, the rest mirrored at the end */ + val_t * const restrict neqs = neq_matrix->vals; + #pragma omp parallel + { + /* first initialize with 1s */ + #pragma omp for schedule(static, 1) + for(int i=0; i < N; ++i) { + neqs[i+(i*N)] = 1. + reg; /* NOTE(review): overwritten with 1. by the j == i iteration of the loop that follows */ + for(int j=0; j < N; ++j) { + neqs[j+(i*N)] = 1.; + } + } + + /* now Hadamard product all (A^T * A) matrices */ + for(idx_t m=0; m < nmodes; ++m) { + if(m == mode) { + continue; + } + + val_t const * const restrict mat = aTa[m]->vals; + #pragma omp for schedule(static, 1) + for(int i=0; i < N; ++i) { + /* + * `mat` is symmetric but stored upper right triangular, so be careful + * to only access that. + */ + + /* multiply into the upper triangle only (j >= i) */ + for(int j=i; j < N; ++j) { + neqs[j+(i*N)] *= mat[j+(i*N)]; + } + } + } /* foreach mode */ + + #pragma omp barrier + + /* now copy lower triangular: entries with j < i are filled from index i+(j*N) */ + #pragma omp for schedule(static, 1) + for(int i=0; i < N; ++i) { + for(int j=0; j < i; ++j) { + neqs[j+(i*N)] = neqs[i+(j*N)]; + } + } + } /* omp parallel */ +} + + + static void p_mat_2norm( matrix_t * const A, val_t * const restrict lambda, @@ -484,6 +560,7 @@ void mat_normalize( } + void mat_solve_normals( idx_t const mode, idx_t const nmodes, @@ -496,55 +573,58 @@ void mat_solve_normals( /* nfactors */ int N = aTa[0]->J; - /* form upper-triangual normal equations */ - val_t * const restrict neqs = aTa[MAX_NMODES]->vals; - #pragma omp parallel - { - /* first initialize */ - #pragma omp for schedule(static, 1) - for(int i=0; i < N; ++i) { - neqs[i+(i*N)] = 1. 
+ reg; - for(int j=0; j < N; ++j) { - neqs[j+(i*N)] = 1.; - } - } + p_form_gram(aTa[MAX_NMODES], aTa, mode, nmodes, reg); - for(idx_t m=0; m < nmodes; ++m) { - if(m == mode) { - continue; - } + int info; + char uplo = 'L'; + int lda = N; + int ldb = N; + int order = N; + int nrhs = (int) rhs->I; - val_t const * const restrict mat = aTa[m]->vals; - #pragma omp for schedule(static, 1) nowait - for(int i=0; i < N; ++i) { - for(int j=0; j < N; ++j) { - neqs[j+(i*N)] *= mat[j+(i*N)]; - } - } - } - } /* omp parallel */ + val_t * const neqs = aTa[MAX_NMODES]->vals; + bool is_spd = true; - /* LU factorization */ - int * ipiv = splatt_malloc(N * sizeof(*ipiv)); - int info; - LAPACK_DGETRF(&N, &N, neqs, &N, ipiv, &info); + /* Cholesky factorization; a nonzero info is treated below as the matrix not being SPD */ + LAPACK_DPOTRF(&uplo, &order, neqs, &lda, &info); if(info) { - fprintf(stderr, "SPLATT: DGETRF returned %d\n", info); + fprintf(stderr, "SPLATT: Gram matrix is not SPD. Trying *GELS().\n"); + is_spd = false; } - - /* solve system of equations */ - int nrhs = rhs->I; - char trans = 'N'; - LAPACK_DGETRS(&trans, &N, &nrhs, neqs, &N, ipiv, rhs->vals, &N, &info); - if(info) { - fprintf(stderr, "SPLATT: DGETRS returned %d\n", info); + /* Continue with Cholesky */ + if(is_spd) { + /* Solve against rhs (nrhs = rhs->I right-hand sides, leading dimension N) */ + LAPACK_DPOTRS(&uplo, &order, &nrhs, neqs, &lda, rhs->vals, &ldb, &info); + if(info) { + fprintf(stderr, "SPLATT: DPOTRS returned %d\n", info); + } + } else { + /* restore gram matrix: neqs was handed to the factorization attempt above (and presumably modified in place), so rebuild it before the least-squares solve */ + p_form_gram(aTa[MAX_NMODES], aTa, mode, nmodes, reg); + + char trans = 'N'; + + /* query worksize: lwork = -1 asks GELS to report the workspace size in work_query instead of solving (see LAPACK docs). NOTE(review): info from this query call is not checked. */ + int lwork = -1; + val_t work_query; + LAPACK_DGELS(&trans, &N, &N, &nrhs, neqs, &lda, rhs->vals, &ldb, &work_query, + &lwork, &info); + lwork = (int) work_query; + + /* setup workspace of lwork values, released after the solve */ + val_t * work = splatt_malloc(lwork * sizeof(*work)); + + /* Use a LS solver */ + LAPACK_DGELS(&trans, &N, &N, &nrhs, neqs, &lda, rhs->vals, &ldb, work, &lwork, + &info); + if(info) { + printf("SPLATT: DGELS returned %d\n", info); /* NOTE(review): the other diagnostics in this function go through fprintf(stderr, ...) */ + } + splatt_free(work); } - 
/* cleanup pivot */ - splatt_free(ipiv); - timer_stop(&timers[TIMER_INV]); }