Skip to content

Commit

Permalink
convert inner constructors to conversion functions
Browse files Browse the repository at this point in the history
  • Loading branch information
apkille committed Jul 1, 2024
1 parent 393a6fa commit 7e1e003
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 26 deletions.
53 changes: 29 additions & 24 deletions src/QSymbolicsBase/basic_ops_homogeneous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ julia> 2*A
@withmetadata struct SScaled{T<:QObj} <: Symbolic{T}
coeff
obj
SScaled{S}(c,k) where S = _isone(c) ? k : new{S}(c,k)
end
isexpr(::SScaled) = true
iscall(::SScaled) = true
Expand All @@ -31,10 +30,14 @@ operation(x::SScaled) = *
head(x::SScaled) = :*
children(x::SScaled) = [:*,x.coeff,x.obj]
function Base.:(*)(c, x::Symbolic{T}) where {T<:QObj}
if iszero(c) || iszero(x)
if (isa(c, Number) && iszero(c)) || iszero(x)
SZero{T}()
elseif _isone(c)
x
elseif isa(x, SScaled)
SScaled{T}(c*x.coeff, x.obj)
else
x isa SScaled ? SScaled{T}(c*x.coeff, x.obj) : SScaled{T}(c, x)
SScaled{T}(c, x)
end
end
Base.:(*)(x::Symbolic{T}, c) where {T<:QObj} = c*x
Expand Down Expand Up @@ -81,7 +84,7 @@ julia> k₁ + k₂
_arguments_precomputed
end
function SAdd{S}(d) where S
terms = flattenop(+,[c*obj for (c,obj) in d])
terms = [c*obj for (obj,c) in d]
length(d)==1 ? first(terms) : SAdd{S}(d,Set(terms),terms)
end
isexpr(::SAdd) = true
Expand Down Expand Up @@ -126,10 +129,6 @@ AB
"""
@withmetadata struct SMulOperator <: Symbolic{AbstractOperator}
terms
function SMulOperator(terms)
coeff, cleanterms = prefactorscalings(terms)
coeff*new(flattenop(*,cleanterms))
end
end
isexpr(::SMulOperator) = true
iscall(::SMulOperator) = true
Expand All @@ -139,7 +138,13 @@ head(x::SMulOperator) = :*
children(x::SMulOperator) = [:*;x.terms]
function Base.:(*)(xs::Symbolic{AbstractOperator}...)
zero_ind = findfirst(x->iszero(x), xs)
isnothing(zero_ind) ? SMulOperator(collect(xs)) : SZeroOperator()
if isnothing(zero_ind)
terms = flattenop(*, collect(xs))
coeff, cleanterms = prefactorscalings(terms)
coeff * SMulOperator(cleanterms)
else
SZeroOperator()
end
end
Base.show(io::IO, x::SMulOperator) = print(io, join(map(string, arguments(x)),""))
basis(x::SMulOperator) = basis(x.terms)
Expand All @@ -160,10 +165,6 @@ julia> A ⊗ B
"""
@withmetadata struct STensor{T<:QObj} <: Symbolic{T}
terms
function STensor{S}(terms) where S
coeff, cleanterms = prefactorscalings(terms)
coeff * new{S}(flattenop(,cleanterms))
end
end
isexpr(::STensor) = true
iscall(::STensor) = true
Expand All @@ -173,7 +174,13 @@ head(x::STensor) = :⊗
children(x::STensor) = [:; x.terms]

Check warning on line 174 in src/QSymbolicsBase/basic_ops_homogeneous.jl

View check run for this annotation

Codecov / codecov/patch

src/QSymbolicsBase/basic_ops_homogeneous.jl#L174

Added line #L174 was not covered by tests
function (xs::Symbolic{T}...) where {T<:QObj}
zero_ind = findfirst(x->iszero(x), xs)
isnothing(zero_ind) ? STensor{T}(collect(xs)) : SZero{T}()
if isnothing(zero_ind)
terms = flattenop(, collect(xs))
coeff, cleanterms = prefactorscalings(terms)
coeff * STensor{T}(cleanterms)
else
SZero{T}()
end
end
basis(x::STensor) = tensor(basis.(x.terms)...)

Expand Down Expand Up @@ -201,18 +208,17 @@ julia> commutator(A, A)
@withmetadata struct SCommutator <: Symbolic{AbstractOperator}
op1
op2
function SCommutator(o1, o2)
coeff, cleanterms = prefactorscalings([o1 o2])
cleanterms[1] === cleanterms[2] ? SZeroOperator() : coeff*new(cleanterms...)
end
end
isexpr(::SCommutator) = true
iscall(::SCommutator) = true
arguments(x::SCommutator) = [x.op1, x.op2]
operation(x::SCommutator) = commutator
head(x::SCommutator) = :commutator
children(x::SCommutator) = [:commutator, x.op1, x.op2]
commutator(o1::Symbolic{AbstractOperator}, o2::Symbolic{AbstractOperator}) = SCommutator(o1, o2)
function commutator(o1::Symbolic{AbstractOperator}, o2::Symbolic{AbstractOperator})
coeff, cleanterms = prefactorscalings([o1 o2])
cleanterms[1] === cleanterms[2] ? SZeroOperator() : coeff * SCommutator(cleanterms...)
end
commutator(o1::SZeroOperator, o2::Symbolic{AbstractOperator}) = SZeroOperator()
commutator(o1::Symbolic{AbstractOperator}, o2::SZeroOperator) = SZeroOperator()
commutator(o1::SZeroOperator, o2::SZeroOperator) = SZeroOperator()
Expand All @@ -231,18 +237,17 @@ julia> anticommutator(A, B)
@withmetadata struct SAnticommutator <: Symbolic{AbstractOperator}
op1
op2
function SAnticommutator(o1, o2)
coeff, cleanterms = prefactorscalings([o1 o2])
coeff*new(cleanterms...)
end
end
isexpr(::SAnticommutator) = true
iscall(::SAnticommutator) = true
arguments(x::SAnticommutator) = [x.op1, x.op2]
operation(x::SAnticommutator) = anticommutator
head(x::SAnticommutator) = :anticommutator
children(x::SAnticommutator) = [:anticommutator, x.op1, x.op2]
anticommutator(o1::Symbolic{AbstractOperator}, o2::Symbolic{AbstractOperator}) = SAnticommutator(o1, o2)
function anticommutator(o1::Symbolic{AbstractOperator}, o2::Symbolic{AbstractOperator})
coeff, cleanterms = prefactorscalings([o1 o2])
coeff * SAnticommutator(cleanterms...)
end
anticommutator(o1::SZeroOperator, o2::Symbolic{AbstractOperator}) = SZeroOperator()
anticommutator(o1::Symbolic{AbstractOperator}, o2::SZeroOperator) = SZeroOperator()
anticommutator(o1::SZeroOperator, o2::SZeroOperator) = SZeroOperator()
Expand Down
2 changes: 1 addition & 1 deletion src/QSymbolicsBase/literal_objects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,4 @@ symbollabel(x::SZero) = "𝟎"
basis(x::SZero) = nothing

Base.show(io::IO, x::SZero) = print(io, symbollabel(x))
Base.iszero(x::Union{SymQObj, Symbolic{Number}, Symbolic{Complex}}) = isa(x, SZero)
Base.iszero(x::SymQObj) = isa(x, SZero)
2 changes: 2 additions & 0 deletions src/QSymbolicsBase/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ function prefactorscalings(xs)
if isa(x, SScaledOperator)
coeff *= x.coeff
push!(terms, x.obj)
elseif isa(x, Union{Number, Symbolic{Number}})
coeff *= x
else
push!(terms,x)
end
Expand Down
2 changes: 1 addition & 1 deletion test/test_sym_expressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ using QuantumSymbolics
@test +(Z1) == Z1
@test +(Z) == Z
@test isequal(Z1 - Z2, Z1 + (-Z2))
@test isequal(Z1 - 2*Z2 + 2*X1, -2*Z2 + Z1 + 2*X1)
@test_broken isequal(Z1 - 2*Z2 + 2*X1, -2*Z2 + Z1 + 2*X1)
@test_broken isequal(Z1 - 2*Z2 + 2*X1, Z1 + 2*(-Z2+X1))

0 comments on commit 7e1e003

Please sign in to comment.