-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rewrite mul! to dispatch based on memory layout, not matrix type #25558
Changes from 22 commits
e86e8bb
1d80d44
81c2f92
84fc2ee
c65a568
f12129d
f467664
1387afc
4ff991d
9bbdc8f
c4d93e5
9d23927
cac81bc
c745821
fe55aad
807644d
30d5ad8
a67eebe
725ab0e
3f2528d
15238b6
3e1e4c4
64e8609
ce99b1b
1a454fd
33f4e48
37c44d5
c5ddd01
8c4d4cd
482939a
01047c8
3618a39
0b0eb44
bad7814
6110ccb
dfcc856
8ad8a35
5d55e48
74c7f67
d723b1a
a0cd467
681f73b
9d0f50b
3413fba
ed88786
5535185
53c2879
7b19e38
09aa094
f2f1b8f
e19f1a5
196b040
d386ef3
f2bc361
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,9 +16,10 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as | |
getproperty, imag, inv, isapprox, isone, IndexStyle, kron, length, log, map, ndims, | ||
oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech, | ||
setindex!, show, similar, sin, sincos, sinh, size, size_to_strides, sqrt, StridedReinterpretArray, | ||
StridedReshapedArray, strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec | ||
StridedReshapedArray, ReshapedArray, ReinterpretArray, strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec, | ||
MemoryLayout, UnknownLayout | ||
using Base: hvcat_fill, iszero, IndexLinear, _length, promote_op, promote_typeof, | ||
@propagate_inbounds, @pure, reduce, typed_vcat | ||
@propagate_inbounds, @pure, reduce, typed_vcat, AbstractCartesianIndex, RangeIndex, Slice | ||
# We use `_length` because of non-1 indices; releases after julia 0.5 | ||
# can go back to `length`. `_length(A)` is equivalent to `length(linearindices(A))`. | ||
|
||
|
@@ -166,6 +167,115 @@ else | |
const BlasInt = Int32 | ||
end | ||
|
||
## Traits for memory layouts ## | ||
abstract type AbstractStridedLayout{T} <: MemoryLayout{T} end | ||
abstract type DenseColumns{T} <: AbstractStridedLayout{T} end | ||
struct DenseColumnMajor{T} <: DenseColumns{T} end | ||
struct DenseColumnsStridedRows{T} <: DenseColumns{T} end | ||
abstract type DenseRows{T} <: AbstractStridedLayout{T} end | ||
struct DenseRowMajor{T} <: DenseRows{T} end | ||
struct DenseRowsStridedColumns{T} <: DenseRows{T} end | ||
struct StridedLayout{T} <: AbstractStridedLayout{T} end | ||
|
||
""" | ||
AbstractStridedLayout{T} | ||
|
||
is an abstract type whose subtypes are returned by `MemoryLayout(A)` | ||
if a matrix or vector `A` have storage laid out at regular offsets in memory, | ||
and which can therefore be passed to external C and Fortran functions expecting | ||
this memory layout. | ||
""" | ||
AbstractStridedLayout | ||
|
||
""" | ||
DenseColumnMajor{T}() | ||
|
||
is returned by `MemoryLayout(A)` if a vector or matrix `A` has storage in memory | ||
equivalent to an `Array`, so that `stride(A,1) == 1` and `stride(A,2) == size(A,1)`. | ||
Arrays with `DenseColumnMajor` must conform to the `DenseArray` interface. | ||
""" | ||
DenseColumnMajor | ||
|
||
""" | ||
DenseColumnsStridedRows{T}() | ||
|
||
is returned by `MemoryLayout(A)` if a vector or matrix `A` has storage in memory | ||
as a column major matrix. In other words, the columns are stored in memory with | ||
offsets of one, while the rows are stored with offsets given by `stride(A,2)`. | ||
Arrays with `DenseColumnsStridedRows` must conform to the `DenseArray` interface. | ||
""" | ||
DenseColumnsStridedRows | ||
|
||
""" | ||
DenseRowMajor{T}() | ||
|
||
is returned by `MemoryLayout(A)` if a vector or matrix `A` has storage in memory | ||
equivalent to the transpose of an `Array`, so that `stride(A,1) == size(A,1)` and | ||
`stride(A,2) == 1`. Arrays with `DenseRowMajor` must conform to the | ||
`DenseArray` interface. | ||
""" | ||
DenseRowMajor | ||
|
||
""" | ||
DenseRowsStridedColumns{T}() | ||
|
||
is returned by `MemoryLayout(A)` if a matrix `A` has storage in memory | ||
as a row major matrix. In other words, the rows are stored in memory with | ||
offsets of one, while the columns are stored with offsets given by `stride(A,1)`. | ||
`Array`s with `DenseRowsStridedColumns` must conform to the `DenseArray` interface, | ||
and `transpose(A)` should return a matrix whose layout is `DenseColumnsStridedRows{T}()`. | ||
""" | ||
DenseRowsStridedColumns | ||
|
||
""" | ||
StridedLayout{T}() | ||
|
||
is returned by `MemoryLayout(A)` if a vector or matrix `A` has storage laid out at regular | ||
offsets in memory. In other words, the columns are stored with offsets given | ||
by `stride(A,1)` and for matrices the rows are stored in memory with offsets | ||
of `stride(A,2)`. `Array`s with `StridedLayout` must conform to the `DenseArray` interface. | ||
""" | ||
StridedLayout | ||
|
||
MemoryLayout(A::Vector{T}) where T = DenseColumnMajor{T}() | ||
MemoryLayout(A::Matrix{T}) where T = DenseColumnMajor{T}() | ||
MemoryLayout(A::DenseArray{T}) where T = StridedLayout{T}() | ||
|
||
MemoryLayout(A::SubArray) = submemorylayout(MemoryLayout(parent(A)), parentindices(A)) | ||
submemorylayout(::MemoryLayout{T}, _) where T = UnknownLayout{T}() | ||
submemorylayout(::DenseColumns{T}, ::Tuple{I}) where {T,I<:Union{AbstractUnitRange{Int},Int,AbstractCartesianIndex}} = | ||
DenseColumnMajor{T}() | ||
submemorylayout(::AbstractStridedLayout{T}, ::Tuple{I}) where {T,I<:Union{RangeIndex,AbstractCartesianIndex}} = | ||
StridedLayout{T}() | ||
submemorylayout(::DenseColumns{T}, ::Tuple{I,Int}) where {T,I<:Union{AbstractUnitRange{Int},Int,AbstractCartesianIndex}} = | ||
DenseColumnMajor{T}() | ||
submemorylayout(::DenseColumns{T}, ::Tuple{I,Int}) where {T,I<:Slice} = | ||
DenseColumnMajor{T}() | ||
submemorylayout(::DenseRows{T}, ::Tuple{Int,I}) where {T,I<:Union{AbstractUnitRange{Int},Int,AbstractCartesianIndex}} = | ||
DenseColumnMajor{T}() | ||
submemorylayout(::DenseRows{T}, ::Tuple{Int,I}) where {T,I<:Slice} = | ||
DenseColumnMajor{T}() | ||
submemorylayout(::DenseColumnMajor{T}, ::Tuple{I1,I2}) where {T,I1<:Slice,I2<:AbstractUnitRange{Int}} = | ||
DenseColumnMajor{T}() | ||
submemorylayout(::DenseColumnMajor{T}, ::Tuple{I1,I2}) where {T,I1<:AbstractUnitRange{Int},I2<:AbstractUnitRange{Int}} = | ||
DenseColumnsStridedRows{T}() | ||
submemorylayout(::DenseColumns{T}, ::Tuple{I1,I2}) where {T,I1<:AbstractUnitRange{Int},I2<:AbstractUnitRange{Int}} = | ||
DenseColumnsStridedRows{T}() | ||
submemorylayout(::DenseRows{T}, ::Tuple{I1,I2}) where {T,I1<:AbstractUnitRange{Int},I2<:Slice} = | ||
DenseRowMajor{T}() | ||
submemorylayout(::DenseRows{T}, ::Tuple{I1,I2}) where {T,I1<:AbstractUnitRange{Int},I2<:AbstractUnitRange{Int}} = | ||
DenseRowsStridedColumns{T}() | ||
submemorylayout(::AbstractStridedLayout{T}, ::Tuple{I1,I2}) where {T,I1<:Union{RangeIndex,AbstractCartesianIndex},I2<:Union{RangeIndex,AbstractCartesianIndex}} = | ||
StridedLayout{T}() | ||
|
||
MemoryLayout(A::ReshapedArray) = reshapedmemorylayout(MemoryLayout(parent(A))) | ||
reshapedmemorylayout(::MemoryLayout{T}) where T = UnknownLayout{T}() | ||
reshapedmemorylayout(::DenseColumnMajor{T}) where T = DenseColumnMajor{T}() | ||
|
||
MemoryLayout(A::ReinterpretArray{V}) where V = reinterpretedmemorylayout(MemoryLayout(parent(A)), V) | ||
reinterpretedmemorylayout(::MemoryLayout{T}, ::Type{V}) where {T,V} = UnknownLayout{V}() | ||
reinterpretedmemorylayout(::DenseColumnMajor{T}, ::Type{V}) where {T,V} = DenseColumnMajor{V}() | ||
|
||
# Check that stride of matrix/vector is 1 | ||
# Writing like this to avoid splatting penalty when called with multiple arguments, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All this stuff on memory layouts should be moved back to Base. |
||
# see PR 16416 | ||
|
@@ -197,8 +307,9 @@ julia> LinearAlgebra.stride1(B) | |
``` | ||
""" | ||
stride1(x) = stride(x,1) | ||
stride1(x::Array) = 1 | ||
stride1(x::DenseArray) = stride(x, 1)::Int | ||
stride1(x::AbstractArray) = _stride1(x, MemoryLayout(x)) | ||
_stride1(x, _) = stride(x, 1)::Int | ||
_stride1(x, ::DenseColumns) = 1 | ||
|
||
@inline chkstride1(A...) = _chkstride1(true, A...) | ||
@noinline _chkstride1(ok::Bool) = ok || error("matrix does not have contiguous columns") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
# This file is a part of Julia. License is MIT: https://julialang.org/license | ||
|
||
using Base: @propagate_inbounds, _return_type, _default_type, @_inline_meta | ||
import Base: length, size, axes, IndexStyle, getindex, setindex!, parent, vec, convert, similar | ||
import Base: length, size, axes, IndexStyle, getindex, setindex!, parent, vec, convert, similar, conj | ||
|
||
### basic definitions (types, aliases, constructors, abstractarray interface, sundry similar) | ||
|
||
|
@@ -94,6 +94,8 @@ const AdjOrTransAbsMat{T} = AdjOrTrans{T,<:AbstractMatrix} | |
wrapperop(A::Adjoint) = adjoint | ||
wrapperop(A::Transpose) = transpose | ||
|
||
|
||
|
||
# AbstractArray interface, basic definitions | ||
length(A::AdjOrTrans) = length(A.parent) | ||
size(v::AdjOrTransAbsVec) = (1, length(v.parent)) | ||
|
@@ -102,6 +104,36 @@ axes(v::AdjOrTransAbsVec) = (Base.OneTo(1), axes(v.parent)...) | |
axes(A::AdjOrTransAbsMat) = reverse(axes(A.parent)) | ||
IndexStyle(::Type{<:AdjOrTransAbsVec}) = IndexLinear() | ||
IndexStyle(::Type{<:AdjOrTransAbsMat}) = IndexCartesian() | ||
|
||
|
||
# MemoryLayout of transposed and adjoint matrices | ||
struct ConjLayout{T<:Complex, ML<:MemoryLayout} <: MemoryLayout{T} | ||
layout::ML | ||
end | ||
ConjLayout(layout::MemoryLayout{T}) where T<:Complex = ConjLayout{T, typeof(layout)}(layout) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With triangular dispatch, you can now define |
||
conj(::UnknownLayout{T}) where T = UnknownLayout{T}() | ||
conj(c::ConjLayout) = c.layout | ||
conj(layout::MemoryLayout{T}) where T<:Complex = ConjLayout(layout) | ||
|
||
|
||
MemoryLayout(A::Adjoint) = adjoint(MemoryLayout(parent(A))) | ||
MemoryLayout(A::Transpose) = transpose(MemoryLayout(parent(A))) | ||
transpose(::MemoryLayout{T}) where T = UnknownLayout{T}() | ||
transpose(::StridedLayout{T}) where T = StridedLayout{T}() | ||
transpose(::DenseColumnsStridedRows{T}) where T = DenseRowsStridedColumns{T}() | ||
transpose(::DenseRowsStridedColumns{T}) where T = DenseColumnsStridedRows{T}() | ||
transpose(::DenseColumnMajor{T}) where T = DenseRowMajor{T}() | ||
transpose(::DenseRowMajor{T}) where T = DenseColumnMajor{T}() | ||
adjoint(::MemoryLayout{T}) where T = UnknownLayout{T}() | ||
adjoint(M::MemoryLayout{T}) where T<:Real = transpose(M) | ||
adjoint(M::ConjLayout{T}) where T<:Complex = transpose(conj(M)) | ||
adjoint(M::MemoryLayout{T}) where T<:Complex = conj(transpose(M)) | ||
submemorylayout(M::ConjLayout{T}, t::Tuple) where T<:Complex = conj(submemorylayout(conj(M), t)) | ||
|
||
# Adjoints and transposes conform to the strided array interface if their parent does | ||
Base.unsafe_convert(::Type{Ptr{T}}, A::AdjOrTrans{T,S}) where {T,S} = Base.unsafe_convert(Ptr{T}, parent(A)) | ||
strides(A::AdjOrTrans) = (stride(parent(A),2), stride(parent(A),1)) | ||
|
||
@propagate_inbounds getindex(v::AdjOrTransAbsVec, i::Int) = wrapperop(v)(v.parent[i]) | ||
@propagate_inbounds getindex(A::AdjOrTransAbsMat, i::Int, j::Int) = wrapperop(A)(A.parent[j, i]) | ||
@propagate_inbounds setindex!(v::AdjOrTransAbsVec, x, i::Int) = (setindex!(v.parent, wrapperop(v)(x), i); v) | ||
|
@@ -169,10 +201,7 @@ broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs... | |
# Adjoint/Transpose-vector * vector | ||
*(u::AdjointAbsVec, v::AbstractVector) = dot(u.parent, v) | ||
*(u::TransposeAbsVec{T}, v::AbstractVector{T}) where {T<:Real} = dot(u.parent, v) | ||
function *(u::TransposeAbsVec, v::AbstractVector) | ||
@boundscheck length(u) == length(v) || throw(DimensionMismatch()) | ||
return sum(@inbounds(return u[k]*v[k]) for k in 1:length(u)) | ||
end | ||
*(u::TransposeAbsVec, v::AbstractVector) = dotu(u.parent, v) | ||
# vector * Adjoint/Transpose-vector | ||
*(u::AbstractVector, v::AdjOrTransAbsVec) = broadcast(*, u, v) | ||
# Adjoint/Transpose-vector * Adjoint/Transpose-vector | ||
|
@@ -203,7 +232,7 @@ pinv(v::TransposeAbsVec, tol::Real = 0) = pinv(conj(v.parent)).parent | |
/(u::AdjointAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = adjoint(conj(A.parent) \ u.parent) # technically should be adjoint(copy(adjoint(copy(A))) \ u.parent) | ||
/(u::TransposeAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = transpose(conj(A.parent) \ u.parent) # technically should be transpose(copy(transpose(copy(A))) \ u.parent) | ||
|
||
# dismabiguation methods | ||
# disambiguation methods | ||
*(A::AdjointAbsVec, B::Transpose{<:Any,<:AbstractMatrix}) = A * copy(B) | ||
*(A::TransposeAbsVec, B::Adjoint{<:Any,<:AbstractMatrix}) = A * copy(B) | ||
*(A::Transpose{<:Any,<:AbstractMatrix}, B::Adjoint{<:Any,<:AbstractMatrix}) = copy(A) * B | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd say this part in the
AbstractStridedLayout
section