Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Gram-Schmidt process #11 #15

Merged
merged 21 commits into from
Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ include/Eigen/*
.idea/
cmake-build*/
build/*
.vscode
*~
98 changes: 98 additions & 0 deletions include/Spectra/LinAlg/Orthogonalization.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (C) 2020 Netherlands eScience Center <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at https://mozilla.org/MPL/2.0/.

#ifndef SPECTRA_ORTHOGONALIZATION_H
#define SPECTRA_ORTHOGONALIZATION_H

#include <Eigen/Core>
#include <Eigen/Dense>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#include <Eigen/QR> seems sufficient. #include <Eigen/Dense> will also include many heavy solvers such as LU, SVD, etc.


namespace Spectra {

template <typename Matrix>
Eigen::Index sanity_check(Matrix& in_output, Eigen::Index leftColsToSkip = 0)
{
assert(in_output.cols() > leftColsToSkip && "leftColsToSkip is larger than columns of matrix");
assert(leftColsToSkip >= 0 && "leftColsToSkip is negative");
if (leftColsToSkip == 0)
{
in_output.col(0).normalize();
leftColsToSkip = 1;
}
return leftColsToSkip;
}
NicoRenaud marked this conversation as resolved.
Show resolved Hide resolved

template <typename Matrix>
void QR_orthogonalisation(Matrix& in_output)
{
using InternalMatrix= Eigen::Matrix<typename Matrix::Scalar,Eigen::Dynamic, Eigen::Dynamic>;
Eigen::Index nrows = in_output.rows();
Eigen::Index ncols = in_output.cols();
ncols = std::min(nrows, ncols);
InternalMatrix I = InternalMatrix::Identity(nrows, ncols);
Eigen::HouseholderQR<Matrix> qr(in_output);
in_output = qr.householderQ() * I;
}

template <typename Matrix>
void MGS_orthogonalisation(Matrix& in_output, Eigen::Index leftColsToSkip = 0)
{
leftColsToSkip = sanity_check(in_output, leftColsToSkip);

for (Eigen::Index k = leftColsToSkip; k < in_output.cols(); ++k)
{
for (Eigen::Index j = 0; j < k; j++)
{
in_output.col(k) -= in_output.col(j).dot(in_output.col(k)) / (in_output.col(j).dot(in_output.col(j))) * in_output.col(j);
NicoRenaud marked this conversation as resolved.
Show resolved Hide resolved
}
in_output.col(k).normalize();
}
}

template <typename Matrix>
void GS_orthogonalisation(Matrix& in_output, Eigen::Index leftColsToSkip = 0)
{
leftColsToSkip = sanity_check(in_output, leftColsToSkip);

for (Eigen::Index j = leftColsToSkip; j < in_output.cols(); ++j)
{
in_output.col(j) -= in_output.leftCols(j) * (in_output.leftCols(j).transpose() * in_output.col(j));
in_output.col(j).normalize();
}
}

template <typename Matrix>
void twice_is_enough_orthogonalisation(Matrix& in_output, Eigen::Index leftColsToSkip = 0)
{
GS_orthogonalisation(in_output, leftColsToSkip);
GS_orthogonalisation(in_output, leftColsToSkip);
}

template <typename Matrix>
void partial_orthogonalisation(Matrix& in_output, Eigen::Index leftColsToSkip = 0)
NicoRenaud marked this conversation as resolved.
Show resolved Hide resolved
{
leftColsToSkip = sanity_check(in_output, leftColsToSkip);

Eigen::Index rightColToOrtho = in_output.cols() - leftColsToSkip;
in_output.rightCols(rightColToOrtho) -= (in_output.leftCols(leftColsToSkip) * in_output.leftCols(leftColsToSkip).transpose()) * in_output.rightCols(rightColToOrtho);
in_output.rightCols(rightColToOrtho).colwise().normalize();
}

template <typename Matrix>
void JensWehner_orthogonalisation(Matrix& in_output, Eigen::Index leftColsToSkip = 0)
{
leftColsToSkip = sanity_check(in_output, leftColsToSkip);

Eigen::Index rightColToOrtho = in_output.cols() - leftColsToSkip;
partial_orthogonalisation(in_output, leftColsToSkip);
Eigen::Ref<Matrix> right_cols = in_output.rightCols(rightColToOrtho);
QR_orthogonalisation(right_cols);
in_output.rightCols(rightColToOrtho) = right_cols;
}

} // namespace Spectra

#endif //SPECTRA_ORTHOGONALIZATION_H
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ list(APPEND test_target_sources
GenEigs.cpp
GenEigsRealShift.cpp
GenEigsComplexShift.cpp
Orthogonalization.cpp
SymGEigsCholesky.cpp
SymGEigsRegInv.cpp
SVD.cpp
Expand Down
99 changes: 99 additions & 0 deletions test/Orthogonalization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#include <Eigen/Core>
#include <Spectra/LinAlg/Orthogonalization.h>
#include <iostream>
using namespace Spectra;

#define CATCH_CONFIG_MAIN
#include "catch.hpp"

using Eigen::MatrixXd;
using Eigen::VectorXd;
using Eigen::Index;

template <typename Matrix>
void check_orthogonality(const Matrix& basis)
{
const double tol = 1e-12;
Matrix xs = basis.transpose() * basis;
INFO("The orthonormalized basis must fulfill that basis.T * basis = I");
INFO("Matrix is\n " << basis);
INFO("Overlap is\n " << xs);
CHECK(xs.isIdentity(tol));
}

TEST_CASE("complete orthonormalization", "[orthogonalisation]")
{
std::srand(123);
const Index n = 20;

MatrixXd mat = MatrixXd::Random(n, n);

SECTION("MGS")
{
MGS_orthogonalisation(mat);
check_orthogonality(mat);
}

SECTION("GS")
{
GS_orthogonalisation(mat);
check_orthogonality(mat);
}

SECTION("QR")
{
QR_orthogonalisation(mat);
check_orthogonality(mat);
}

SECTION("twice_is_enough")
{
twice_is_enough_orthogonalisation(mat);
check_orthogonality(mat);
}

SECTION("JensWehner")
{
JensWehner_orthogonalisation(mat);
check_orthogonality(mat);
}
}

TEST_CASE("Partial orthonormalization", "[orthogonalisation]")
{
std::srand(123);
const Index n = 20;
const Index sub = 5;
Index start = n - sub;

// Create a n x 20 orthonormal basis
MatrixXd mat = MatrixXd::Random(n, start);
QR_orthogonalisation(mat);

mat.conservativeResize(Eigen::NoChange, n);
mat.rightCols(sub) = MatrixXd::Random(n, sub);

SECTION("MGS")
{
MGS_orthogonalisation(mat, start);
check_orthogonality(mat);
}

SECTION("GS")
{
GS_orthogonalisation(mat, start);
check_orthogonality(mat);
}

SECTION("twice_is_enough")
{
twice_is_enough_orthogonalisation(mat, start);
check_orthogonality(mat);
}

SECTION("JensWehner")
{
JensWehner_orthogonalisation(mat, start);
check_orthogonality(mat);
}
}