Skip to content

Commit

Permalink
perf: Tile 8x8 covariance matrix multiplication
Browse files Browse the repository at this point in the history
Currently, we are multiplying an 8x8 covariance matrix with an 8x8
transport matrix, and we see that Eigen is failing to optimize this
properly, because it is calling a generalized GEMM method rather than an
optimized small matrix method. In order to resolve this, we change the
code to use a tiled multiplication method which splits the matrices into
4x4 sub-matrices which can be multiplied and added to achieve the
desired effect. This has two advantages:

  1. It allows Eigen to use its hand-rolled optimized 4x4 matrix
     multiplication methods.
  2. It allows us to perform some trickery with matrix identities to
     reduce the number of floating point operations.
  • Loading branch information
stephenswat committed Mar 3, 2022
1 parent 988e2ca commit a19f558
Showing 1 changed file with 44 additions and 2 deletions.
46 changes: 44 additions & 2 deletions Core/include/Acts/Propagator/EigenStepper.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,50 @@ Acts::Result<double> Acts::EigenStepper<E, A>::step(
return EigenStepperError::StepInvalid;
}

// for moment, only update the transport part
state.stepping.jacTransport = D * state.stepping.jacTransport;
// Here, we want to calculate the expression: K = DJ, which is an 8×8
// matrix multiplication operation. Eigen, natively, doesn't do a great job
// at this, and applies a slow GEMM operation. We apply a blocked matrix
// multiplication operation, relying on the fact that:
//
// ┌ ┐ ┌ ┐ ┌ ┐ K₁₁ = D₁₁ * J₁₁ + D₁₂ * J₂₁
// │ K₁₁ K₁₂ │ = │ D₁₁ D₁₂ │ │ J₁₁ J₁₂ │ K₁₂ = D₁₁ * J₁₂ + D₁₂ * J₂₂
// │ K₂₁ K₂₂ │ = │ D₂₁ D₂₂ │ │ J₂₁ J₂₂ │ K₂₁ = D₂₁ * J₁₁ + D₂₂ * J₂₁
// └ ┘ └ ┘ └ ┘ K₂₂ = D₂₁ * J₁₂ + D₂₂ * J₂₂
//
// All of these sub-matrices are 4×4, and Eigen does a much better job
// optimizing these operations. However, we can go one further. Let's
// assume that some of these sub-matrices are zero matrices 0₈ and
// identity matrices I₈, namely:
//
// D₁₁ = I₈, J₁₁ = I₈, D₂₁ = 0₈, J₂₁ = 0₈
//
// Which gives:
//
// K₁₁ = I₈ * I₈ + D₁₂ * 0₈ = I₈
// K₁₂ = I₈ * J₁₂ + D₁₂ * J₂₂ = J₁₂ + D₁₂ * J₂₂
// K₂₁ = 0₈ * I₈ + D₂₂ * 0₈ = 0₈
// K₂₂ = 0₈ * J₁₂ + D₂₂ * J₂₂ = D₂₂ * J₂₂
//
// Furthermore, we're constructing K in place of J, and since
// K₁₁ = I₈ = J₁₁ and K₂₁ = 0₈ = D₂₁, we don't actually need to touch those
// sub-matrices at all!
if ((D.topLeftCorner<4, 4>().isIdentity()) &&
(D.bottomLeftCorner<4, 4>().isZero()) &&
(state.stepping.jacTransport.template topLeftCorner<4, 4>()
.isIdentity()) &&
(state.stepping.jacTransport.template bottomLeftCorner<4, 4>()
.isZero())) {
state.stepping.jacTransport.template topRightCorner<4, 4>() +=
D.topRightCorner<4, 4>() *
state.stepping.jacTransport.template bottomRightCorner<4, 4>();
state.stepping.jacTransport.template bottomRightCorner<4, 4>() =
D.bottomRightCorner<4, 4>() *
state.stepping.jacTransport.template bottomRightCorner<4, 4>();
} else {
// For safety purposes, we provide a full matrix multiplication as a
// backup strategy.
state.stepping.jacTransport = D * state.stepping.jacTransport;
}
} else {
if (!state.stepping.extension.finalize(state, *this, h)) {
return EigenStepperError::StepInvalid;
Expand Down

0 comments on commit a19f558

Please sign in to comment.