Skip to content

Commit

Permalink
add tri forms for woodbury (Xt_A_X, X_A_Xt) (#27)
Browse files Browse the repository at this point in the history
Adds Xt_A_X and X_A_Xt for WoodburyPDMat

Co-authored-by: William J. Wilkinson <[email protected]>
  • Loading branch information
AlexRobson and William J. Wilkinson authored Jun 10, 2022
1 parent 79d73fa commit 6ddd37d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PDMatsExtras"
uuid = "2c7acb1b-7338-470f-b38f-951d2bcb9193"
authors = ["Invenia Technical Computing"]
version = "2.5.2"
version = "2.6.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
4 changes: 4 additions & 0 deletions src/woodbury_pd_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ function PDMats.unwhiten!(r::DenseVecOrMat, W::WoodburyPDMat{<:Real}, x::DenseVe
return unwhiten!(r, PDMat(Symmetric(Matrix(W))), x)
end



# NOTE: the parameterisation to scale up the Woodbury matrix is not unique. Here we
# implement one way to scale it.
PDMats.Xt_A_X(a::WoodburyPDMat, x::Diagonal) = WoodburyPDMat(x * a.A, a.D, x * a.S * x)
PDMats.X_A_Xt(a::WoodburyPDMat, x::Diagonal) = PDMats.Xt_A_X(a, x)
*(a::WoodburyPDMat, c::Real) = WoodburyPDMat(a.A, a.D * c, a.S * c)
*(c::Real, a::WoodburyPDMat) = a * c
8 changes: 8 additions & 0 deletions test/woodbury_pd_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
A = randn(4, 2)
D = Diagonal(randn(2).^2 .+ 1)
S = Diagonal(randn(4).^2 .+ 1)
σ = Diagonal(rand(4,))
x = randn(size(A, 1))

W = WoodburyPDMat(A, D, S)
Expand Down Expand Up @@ -50,6 +51,13 @@
@test (c * W) isa WoodburyPDMat
end

@testset "$f X::Diagonal" for f in [Xt_A_X, X_A_Xt]
@test f(W, σ) σ * W_dense * σ atol=1e-6
@test f(W, σ) σ * W * σ atol=1e-6
@test Matrix(f(W, σ)) f(W_dense, σ) atol=1e-6
@test f(W, σ) isa WoodburyPDMat
end

@testset "MvNormal logpdf" begin
m = randn(size(A, 1))
@test logpdf(MvNormal(m, W), x) logpdf(MvNormal(m, Symmetric(Matrix(W))), x)
Expand Down

2 comments on commit 6ddd37d

@AlexRobson
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/62143

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.6.0 -m "<description of version>" 6ddd37d6b1a1d42fb67acc090bb2ae2dcd3e63be
git push origin v2.6.0

Please sign in to comment.