Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ArrayInterface.Size instead of ArrayInterface.size #241

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,8 @@ end

abstract type AbstractArray2{T,N} <: AbstractArray{T,N} end

Base.size(A::AbstractArray2) = map(Int, ArrayInterface.size(A))
Base.size(A::AbstractArray2, dim) = Int(ArrayInterface.size(A, dim))
Base.size(A::AbstractArray2) = Base.size(Size(A))
Base.size(A::AbstractArray2, dim) = length(Size(A, dim))

function Base.axes(A::AbstractArray2)
!(parent_type(A) <: typeof(A)) && return ArrayInterface.axes(parent(A))
Expand Down Expand Up @@ -731,13 +731,13 @@ function __init__()
@generated function axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S}
return Tuple{[StaticArrays.SOneTo{s} for s in S.parameters]...}
end
@generated function size(A::StaticArrays.StaticArray{S}) where {S}
@generated function ArrayInterface.Size(A::StaticArrays.StaticArray{S}) where {S}
t = Expr(:tuple)
Sp = S.parameters
for n = 1:length(Sp)
push!(t.args, Expr(:call, Expr(:curly, :StaticInt, Sp[n])))
end
return t
return :(ArrayInterface.Size($t))
end
@generated function strides(A::StaticArrays.StaticArray{S}) where {S}
t = Expr(:tuple, Expr(:call, Expr(:curly, :StaticInt, 1)))
Expand Down
140 changes: 90 additions & 50 deletions src/size.jl
Original file line number Diff line number Diff line change
@@ -1,79 +1,115 @@

"""
size(A) -> Tuple
size(A, dim) -> Union{Int,StaticInt}
Size(s::Tuple{Vararg{Union{Int,StaticInt}})
Size(A) -> Size(size(A))

Returns the size of each dimension of `A` or along dimension `dim` of `A`. If the size of
any axes are known at compile time, these should be returned as `Static` numbers. Otherwise,
`ArrayInterface.size(A)` is identical to `Base.size(A)`
Type that represents statically sized dimensions as `StaticInt`s.
"""
struct Size{S<:Tuple}
size::S

```julia
julia> using StaticArrays, ArrayInterface
Size{S}(s::Tuple{Vararg{<:CanonicalInt}}) where {S} = new{S}(s::S)
Size(s::Tuple{Vararg{<:CanonicalInt}}) = Size{typeof(s)}(s)
end

julia> A = @SMatrix rand(3,4);
"""
Length(x::Union{Int,StaticInt})
Length(A) = Length(length(A))

julia> ArrayInterface.size(A)
(static(3), static(4))
```
Type that represents statically sized dimensions as `StaticInt`s.
"""
size(a::A) where {A} = _maybe_size(Base.IteratorSize(A), a)
const Length{L} = Size{Tuple{L}}
Length(x::CanonicalInt) = Size((x,))
@inline function Length(x)
len = known_length(x)
if len === missing
return Length(length(x))
else
return Length(static(len))
end
end

Base.ndims(@nospecialize(s::Size)) = ndims(typeof(s))
Base.ndims(::Type{<:Size{S}}) where {S} = known_length(S)
Base.size(s::Size{Tuple{Vararg{Int}}}) = getfield(s, :size)
Base.size(s::Size) = map(Int, s.size)
function Base.size(s::Size{S}, dim::CanonicalInt) where {S}
if dim > known_length(S)
return 1
else
return Int(getfield(s.size, Int(dim)))
end
end

Base.:(==)(x::Size, y::Size) = getfield(x, :size) == getfield(y, :size)

static_length(x::Length) = getfield(getfield(x, :size), 1)
static_length(x::Size) = prod(getfield(x, :size))
Base.length(x::Size) = Int(static_length(x))

Base.show(io::IO, ::MIME"text/plain", @nospecialize(x::Size)) = print(io, "Size($(x.size))")

# default constructors
Size(s::Size) = s
Size(a::A) where {A} = Size(_maybe_size(Base.IteratorSize(A), a))
_maybe_size(::Base.HasShape{N}, a::A) where {N,A} = map(static_length, axes(a))
_maybe_size(::Base.HasLength, a::A) where {A} = (static_length(a),)
size(x::SubArray) = eachop(_sub_size, to_parent_dims(x), x.indices)

# type specific constructors
Size(x::SubArray) = Size(eachop(_sub_size, to_parent_dims(x), x.indices))
_sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = static_length(getfield(x, dim))
@inline size(B::VecAdjTrans) = (One(), length(parent(B)))
@inline size(B::MatAdjTrans) = permute(size(parent(B)), to_parent_dims(B))
@inline function size(B::PermutedDimsArray{T,N,I1}) where {T,N,I1}
permute(size(parent(B)), static(I1))
@inline Size(A::VecAdjTrans) = Size((One(), static_length(parent(A))))
@inline function Size(A::MatAdjTrans)
Size(permute(getfield(Size(parent(A)), :size), (static(2), static(1))))
end
Size(A::Union{Array,ReshapedArray}) = Size(Base.size(A))
@inline function Size(A::PermutedDimsArray{T,N,I1}) where {T,N,I1}
Size(permute(getfield(Size(parent(A)), :size), static(I1)))
end
Size(A::AbstractRange) = Size((static_length(A),))
Size(x::Base.Generator) = Size(getfield(x, :iter))
Size(x::Iterators.Reverse) = Size(getfield(x, :itr))
Size(x::Iterators.Enumerate) = Size(getfield(x, :itr))
Size(x::Iterators.Accumulate) = Size(getfield(x, :itr))
Size(x::Iterators.Pairs) = Size(getfield(x, :itr))
@inline function Size(x::Iterators.ProductIterator)
Size(eachop(_sub_size, nstatic(Val(ndims(x))), getfield(x, :iterators)))
end
function size(a::ReinterpretArray{T,N,S,A}) where {T,N,S,A}
psize = size(parent(a))
Size(x::Iterators.Zip) = Size(Static.reduce_tup(promote_shape, map(size, getfield(x, :is))))

function Size(a::ReinterpretArray{T,N,S,A}) where {T,N,S,A}
if _is_reshaped(typeof(a))
if sizeof(S) === sizeof(T)
return psize
return Size(parent(a))
elseif sizeof(S) > sizeof(T)
return (static(div(sizeof(S), sizeof(T))), psize...)
return Size((static(div(sizeof(S), sizeof(T))), getfield(Size(parent(a)), :size)...))
else
return tail(psize)
return Size(tail(getfield(Size(parent(a)), :size)))
end
else
return (div(first(psize) * static(sizeof(S)), static(sizeof(T))), tail(psize)...,)
psize = getfield(Size(parent(a)), :size)
return Size((div(first(psize) * static(sizeof(S)), static(sizeof(T))), tail(psize)...,))
end
end
size(A::ReshapedArray) = Base.size(A)
size(A::AbstractRange) = (static_length(A),)
size(x::Base.Generator) = size(getfield(x, :iter))
size(x::Iterators.Reverse) = size(getfield(x, :itr))
size(x::Iterators.Enumerate) = size(getfield(x, :itr))
size(x::Iterators.Accumulate) = size(getfield(x, :itr))
size(x::Iterators.Pairs) = size(getfield(x, :itr))
@inline function size(x::Iterators.ProductIterator)
eachop(_sub_size, nstatic(Val(ndims(x))), getfield(x, :iterators))
end

size(a, dim) = size(a, to_dims(a, dim))
size(a::Array, dim::Integer) = Base.arraysize(a, convert(Int, dim))
function size(a::A, dim::Integer) where {A}
if parent_type(A) <: A
len = known_size(A, dim)
if len === missing
return Int(length(axes(a, dim)))
else
return StaticInt(len)
end
## size of individual dimensions
Size(x, dim) = Size(x, to_dims(x, dim))
function Size(x, dim::Int)
sz = known_size(x, dim)
if sz === missing
return Length(Int(getfield(getfield(Size(x), :size), dim)))
else
return size(a)[dim]
return Length(Int(sz))
end
end
function size(A::SubArray, dim::Integer)
pdim = to_parent_dims(A, dim)
if pdim > ndims(parent_type(A))
return size(parent(A), pdim)
@inline function Size(x, ::StaticInt{dim}) where {dim}
sz = known_size(x, dim)
if sz === missing
return Length(getfield(getfield(Size(x), :size), dim))
else
return static_length(A.indices[pdim])
return Length(static(sz))
end
end
size(x::Iterators.Zip) = Static.reduce_tup(promote_shape, map(size, getfield(x, :is)))

"""
known_size(::Type{T}) -> Tuple
Expand All @@ -87,6 +123,7 @@ known_size(x) = known_size(typeof(x))
function known_size(::Type{T}) where {T<:AbstractRange}
(_range_length(known_first(T), known_step(T), known_last(T)),)
end
known_size(::Type{<:Size{S}}) where {S} = known(S)
known_size(::Type{<:Base.Generator{I}}) where {I} = known_size(I)
known_size(::Type{<:Iterators.Reverse{I}}) where {I} = known_size(I)
known_size(::Type{<:Iterators.Enumerate{I}}) where {I} = known_size(I)
Expand Down Expand Up @@ -123,3 +160,6 @@ _known_size(::Type{T}, dim::StaticInt) where {T} = known_length(field_type(T, di
end
end

size(x) = getfield(Size(x), :size)
size(x, dim) = static_length(Size(x, dim))

1 change: 1 addition & 0 deletions test/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ ArrayInterface.dimnames(x::NamedDimsWrapper) = getfield(x, :dimnames)
function ArrayInterface.known_dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L}}
ArrayInterface.Static.known(L)
end
ArrayInterface.Size(x::NamedDimsWrapper) = ArrayInterface.Size(parent(x))

Base.parent(x::NamedDimsWrapper) = x.parent

Expand Down