Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NDTensors] Unwrap module for dispatching on unwrapped types #1220

Merged
merged 93 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
93 commits
Select commit Hold shift + click to select a range
401b9d4
Create Unwraped folder and move most of the `iswrappedarray` function…
kmp5VT Oct 26, 2023
0afcc9c
Make a prototype for permutedims in NDTensors
kmp5VT Oct 26, 2023
fffd862
format
kmp5VT Oct 26, 2023
61fbe32
add permutedims! to `Unwrap`
kmp5VT Oct 26, 2023
6adee77
format
kmp5VT Oct 26, 2023
0af661d
permutedims!! calls Base.permutedim! with expose
kmp5VT Oct 26, 2023
993d512
Call base permutedims not ndtensors
kmp5VT Oct 26, 2023
7915a8e
Arrays calls permutedims without leaf_parenttype
kmp5VT Oct 26, 2023
144c57d
write a `LinearAlgebra.mul!` for `Exposed`
kmp5VT Oct 26, 2023
9b4fa78
format
kmp5VT Oct 26, 2023
4db2297
Remove file
kmp5VT Oct 26, 2023
86b9e3b
Fix metal mul wrapper
kmp5VT Oct 26, 2023
23f9ec1
Create `copyto` function for Exposed struct
kmp5VT Oct 26, 2023
72d00ac
Properly export `Exposed`
kmp5VT Oct 26, 2023
8f41fcf
format
kmp5VT Oct 26, 2023
1c2e2c5
Some fixes to metal
kmp5VT Oct 27, 2023
08e92b1
fix permutedims!! signature
kmp5VT Oct 27, 2023
4d38994
format
kmp5VT Oct 27, 2023
7728b8c
Fix permute call only call when dest is a reshapedarray
kmp5VT Oct 27, 2023
5a6d524
format
kmp5VT Oct 27, 2023
c58e034
Don't use NDTensors mtl because 64 bit precision is not supported
kmp5VT Oct 27, 2023
134909c
Allow blockspare tests to support metal
kmp5VT Oct 27, 2023
1a90927
Creat unexpose function
kmp5VT Oct 27, 2023
3f6dd75
Make parent function for Exposed
kmp5VT Oct 27, 2023
0f02339
Favor unxpose over .object and return unexposed
kmp5VT Oct 27, 2023
58473b3
Revert blocksparse changes
kmp5VT Oct 27, 2023
2606602
remove AbstractArray copyto.jl
kmp5VT Oct 27, 2023
6d0b353
Add back mul!! function
kmp5VT Oct 27, 2023
db6edd1
Include mul.jl in NDTensors.jl
kmp5VT Oct 27, 2023
39ac932
Call mul!! in case we need to mutate
kmp5VT Oct 27, 2023
8398394
format
kmp5VT Oct 27, 2023
469fad7
Merge branch 'main' into kmp5/enhancements/unwrap
kmp5VT Oct 27, 2023
57b330f
Remove NDTensors custom permutedims
kmp5VT Oct 27, 2023
e0a425c
Remove files from NDTensors
kmp5VT Oct 27, 2023
f69ece3
alphabetical order
kmp5VT Oct 27, 2023
c5ebc5d
Use linearalgebra with Expose
kmp5VT Oct 27, 2023
5cad7ae
import permutedims!
kmp5VT Oct 27, 2023
61163c6
leaf_parentype -> unwrap_type
kmp5VT Oct 30, 2023
532a519
per matts comment
kmp5VT Oct 30, 2023
aad36ed
format
kmp5VT Oct 30, 2023
a4838ca
Merge branch 'main' into kmp5/enhancements/unwrap
kmp5VT Oct 30, 2023
c81a23b
import unwrap_type
kmp5VT Oct 30, 2023
df9d194
Move metal examples test to end of test suite
kmp5VT Oct 30, 2023
6ae4816
remove `Base.` and `LinearAlgebra.`
kmp5VT Oct 30, 2023
acaa9c6
add comment
kmp5VT Oct 30, 2023
3b9817a
format
kmp5VT Oct 30, 2023
e32a674
formatting
kmp5VT Oct 30, 2023
9c8cfde
second p
kmp5VT Oct 30, 2023
475e150
Make transpose of expose
kmp5VT Oct 30, 2023
0666dc8
using not import
kmp5VT Oct 30, 2023
c4f618e
Unwrap.unwrap_type
kmp5VT Oct 30, 2023
abeb9fa
get copyto working
kmp5VT Oct 30, 2023
488bce8
Use full path name
kmp5VT Oct 30, 2023
57f8bed
Fix mul!
kmp5VT Oct 30, 2023
b477058
Add some functions to expose
kmp5VT Oct 30, 2023
d648f1e
typo
kmp5VT Oct 30, 2023
0249427
use NDTensors.mtl
kmp5VT Oct 30, 2023
c9cc991
typo
kmp5VT Oct 30, 2023
3e93d7e
Working linear algebra qr, ql, and SVD with Expose
kmp5VT Oct 30, 2023
5501ff6
Remove deleted file
kmp5VT Oct 30, 2023
e8e562a
remove leaf_parenttype and using
kmp5VT Oct 30, 2023
4eda671
If dev is metal use Float32
kmp5VT Oct 30, 2023
d8d0349
Don't do Float64 on metal
kmp5VT Oct 30, 2023
9c1fad3
use unwrap cpu
kmp5VT Oct 30, 2023
c5cdd65
format
kmp5VT Oct 30, 2023
f09689b
Rename util, abstractarray
kmp5VT Oct 31, 2023
faafbfb
Fix issue in linearalgebra
kmp5VT Oct 31, 2023
d66f907
Update get/set index for full stack
kmp5VT Oct 31, 2023
26daedf
Skip Float64 on mtl
kmp5VT Oct 31, 2023
1d383d6
format
kmp5VT Oct 31, 2023
fa039ef
format
kmp5VT Oct 31, 2023
5fcf642
Fix scalar indexing issues
kmp5VT Oct 31, 2023
d3a7e7c
When A, B and C transpose just change order to `B * A = C`. There was…
kmp5VT Oct 31, 2023
56be813
format
kmp5VT Oct 31, 2023
57f9cb9
Merge remote-tracking branch 'origin/main' into kmp5/enhancements/unwrap
kmp5VT Oct 31, 2023
b96dd4e
Fix permutedims! for arraystorage
kmp5VT Oct 31, 2023
abe78eb
format
kmp5VT Oct 31, 2023
7df11ae
Remove `Base.` and others from inline function calls
kmp5VT Oct 31, 2023
aa8e3f9
Remove leaf_parenttype
kmp5VT Oct 31, 2023
d27ec70
Add todo comments
kmp5VT Oct 31, 2023
34cbaff
Remove deleted file
kmp5VT Oct 31, 2023
13aa27d
Add back expose in SVD
kmp5VT Oct 31, 2023
5b38f13
create a mul for all transpose
kmp5VT Nov 1, 2023
f3e91a2
per some comments
kmp5VT Nov 1, 2023
76ccbfa
Merge branch 'main' into kmp5/enhancements/unwrap
mtfishman Nov 1, 2023
50a452f
Merge commit '70967cd1c1f0c9236450485446c2a10ad627bc07' into kmp5/enh…
kmp5VT Nov 1, 2023
7fb5f91
Add some scalar index work arounds for CUDA.
kmp5VT Nov 1, 2023
d82057a
Expose version wasn't properly passing kwargs. adding that back in fi…
kmp5VT Nov 1, 2023
b9f2980
format
kmp5VT Nov 1, 2023
130d4f7
Merge branch 'kmp5/enhancements/unwrap' of github.com:kmp5VT/ITensors…
kmp5VT Nov 1, 2023
a7e2ea4
Use CPU
kmp5VT Nov 1, 2023
96a8334
I do need this function
kmp5VT Nov 1, 2023
816ae64
format
kmp5VT Nov 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ include("SortedSets/src/SortedSets.jl")
using .SortedSets
include("TagSets/src/TagSets.jl")
using .TagSets

include("Unwrap/src/Unwrap.jl")
using .Unwrap
using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo

using Base.Cartesian: @nexprs
Expand Down
3 changes: 3 additions & 0 deletions NDTensors/src/Unwrap/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Unwrap

A module to unwrap complex array types to assist in the generic programming of array-type based functions.
1 change: 1 addition & 0 deletions NDTensors/src/Unwrap/TODO.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Replace all `leaf_parenttype` calls by wrapping the arrays in this `expose` type
16 changes: 16 additions & 0 deletions NDTensors/src/Unwrap/src/Unwrap.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module Unwrap
using SimpleTraits
using LinearAlgebra
using Base: ReshapedArray
using Strided: StridedView
kmp5VT marked this conversation as resolved.
Show resolved Hide resolved

include("iswrappedarray.jl")
include("expose.jl")

export IsWrappedArray, is_wrapped_array, parenttype, unwrap_type, expose

## TODO Create functions which take the `Expose` type and launch functions
## using that type
## TODO write exposed based functions in the NDTensors Extensions when necessary

end
5 changes: 5 additions & 0 deletions NDTensors/src/Unwrap/src/expose.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
struct Expose{Unwraped,Object}
object::Object
end

expose(object) = Expose{unwrap_type(object),typeof(object)}(object)
57 changes: 57 additions & 0 deletions NDTensors/src/Unwrap/src/iswrappedarray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Trait indicating if the AbstractArray type is an array wrapper.
# Assumes that it implements `NDTensors.parenttype`.
@traitdef IsWrappedArray{ArrayT}

#! format: off
@traitimpl IsWrappedArray{ArrayT} <- is_wrapped_array(ArrayT)
#! format: on

is_wrapped_array(arraytype::Type{<:AbstractArray}) = (parenttype(arraytype) ≠ arraytype)

# TODO: This is only defined because the current design
# of `Diag` using a `Number` as the data type if it
# is a uniform diagonal type. Delete this when it is
# replaced by `DiagonalArray`.
is_wrapped_array(arraytype::Type{<:Number}) = false

# For working with instances, not used by
# `SimpleTraits.jl` traits dispatch.
is_wrapped_array(array::AbstractArray) = is_wrapped_array(typeof(array))

# By default, the `parentype` of an array type is itself
parenttype(arraytype::Type{<:AbstractArray}) = arraytype

# TODO: Use `SetParameters` here.
parenttype(::Type{<:Base.ReshapedArray{<:Any,<:Any,P}}) where {P} = P
kmp5VT marked this conversation as resolved.
Show resolved Hide resolved
parenttype(::Type{<:Transpose{<:Any,P}}) where {P} = P
parenttype(::Type{<:Adjoint{<:Any,P}}) where {P} = P
parenttype(::Type{<:Symmetric{<:Any,P}}) where {P} = P
parenttype(::Type{<:Hermitian{<:Any,P}}) where {P} = P
parenttype(::Type{<:UpperTriangular{<:Any,P}}) where {P} = P
parenttype(::Type{<:LowerTriangular{<:Any,P}}) where {P} = P
parenttype(::Type{<:UnitUpperTriangular{<:Any,P}}) where {P} = P
parenttype(::Type{<:UnitLowerTriangular{<:Any,P}}) where {P} = P
parenttype(::Type{<:Diagonal{<:Any,P}}) where {P} = P
parenttype(::Type{<:SubArray{<:Any,<:Any,P}}) where {P} = P
parenttype(::Type{<:StridedView{<:Any,<:Any,P}}) where {P} = P

# For working with instances, not used by
# `SimpleTraits.jl` traits dispatch.
parenttype(array::AbstractArray) = parenttype(typeof(array))

## These functions will be used in place of leaf_parenttype but will be
## call indirectly through the expose function.
@traitfn function unwrap_type(
arraytype::Type{ArrayT}
) where {ArrayT; IsWrappedArray{ArrayT}}
return leaf_parenttype(parenttype(arraytype))
end

@traitfn function unwrap_type(
arraytype::Type{ArrayT}
) where {ArrayT; !IsWrappedArray{ArrayT}}
return arraytype
end

# For working with instances.
unwrap_type(array::AbstractArray) = unwrap_type(typeof(array))
41 changes: 0 additions & 41 deletions NDTensors/src/abstractarray/iswrappedarray.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,3 @@
# Trait indicating if the AbstractArray type is an array wrapper.
# Assumes that it implements `NDTensors.parenttype`.
@traitdef IsWrappedArray{ArrayT}

#! format: off
@traitimpl IsWrappedArray{ArrayT} <- is_wrapped_array(ArrayT)
#! format: on

is_wrapped_array(arraytype::Type{<:AbstractArray}) = (parenttype(arraytype) ≠ arraytype)

# TODO: This is only defined because the current design
# of `Diag` using a `Number` as the data type if it
# is a uniform diagonal type. Delete this when it is
# replaced by `DiagonalArray`.
is_wrapped_array(arraytype::Type{<:Number}) = false

# For working with instances, not used by
# `SimpleTraits.jl` traits dispatch.
is_wrapped_array(array::AbstractArray) = is_wrapped_array(typeof(array))

# By default, the `parentype` of an array type is itself
parenttype(arraytype::Type{<:AbstractArray}) = arraytype

# TODO: Use `SetParameters` here.
parenttype(::Type{<:ReshapedArray{<:Any,<:Any,P}}) where {P} = P
parenttype(::Type{<:Transpose{<:Any,P}}) where {P} = P
parenttype(::Type{<:Adjoint{<:Any,P}}) where {P} = P
parenttype(::Type{<:Symmetric{<:Any,P}}) where {P} = P
parenttype(::Type{<:Hermitian{<:Any,P}}) where {P} = P
parenttype(::Type{<:UpperTriangular{<:Any,P}}) where {P} = P
parenttype(::Type{<:LowerTriangular{<:Any,P}}) where {P} = P
parenttype(::Type{<:UnitUpperTriangular{<:Any,P}}) where {P} = P
parenttype(::Type{<:UnitLowerTriangular{<:Any,P}}) where {P} = P
parenttype(::Type{<:Diagonal{<:Any,P}}) where {P} = P
parenttype(::Type{<:SubArray{<:Any,<:Any,P}}) where {P} = P
parenttype(::Type{<:StridedView{<:Any,<:Any,P}}) where {P} = P

# For working with instances, not used by
# `SimpleTraits.jl` traits dispatch.
parenttype(array::AbstractArray) = parenttype(typeof(array))

@traitfn function leaf_parenttype(
arraytype::Type{ArrayT}
) where {ArrayT; IsWrappedArray{ArrayT}}
Expand Down
Loading