diff --git a/src/QSymbolicsBase/basic_ops_homogeneous.jl b/src/QSymbolicsBase/basic_ops_homogeneous.jl index da08513..2714e52 100644 --- a/src/QSymbolicsBase/basic_ops_homogeneous.jl +++ b/src/QSymbolicsBase/basic_ops_homogeneous.jl @@ -22,7 +22,6 @@ julia> 2*A @withmetadata struct SScaled{T<:QObj} <: Symbolic{T} coeff obj - SScaled{S}(c,k) where S = _isone(c) ? k : new{S}(c,k) end isexpr(::SScaled) = true iscall(::SScaled) = true @@ -31,10 +30,14 @@ operation(x::SScaled) = * head(x::SScaled) = :* children(x::SScaled) = [:*,x.coeff,x.obj] function Base.:(*)(c, x::Symbolic{T}) where {T<:QObj} - if iszero(c) || iszero(x) + if (isa(c, Number) && iszero(c)) || iszero(x) SZero{T}() + elseif _isone(c) + x + elseif isa(x, SScaled) + SScaled{T}(c*x.coeff, x.obj) else - x isa SScaled ? SScaled{T}(c*x.coeff, x.obj) : SScaled{T}(c, x) + SScaled{T}(c, x) end end Base.:(*)(x::Symbolic{T}, c) where {T<:QObj} = c*x @@ -81,7 +84,7 @@ julia> k₁ + k₂ _arguments_precomputed end function SAdd{S}(d) where S - terms = flattenop(+,[c*obj for (c,obj) in d]) + terms = [c*obj for (obj,c) in d] length(d)==1 ? first(terms) : SAdd{S}(d,Set(terms),terms) end isexpr(::SAdd) = true @@ -126,10 +129,6 @@ AB """ @withmetadata struct SMulOperator <: Symbolic{AbstractOperator} terms - function SMulOperator(terms) - coeff, cleanterms = prefactorscalings(terms) - coeff*new(flattenop(*,cleanterms)) - end end isexpr(::SMulOperator) = true iscall(::SMulOperator) = true @@ -139,7 +138,13 @@ head(x::SMulOperator) = :* children(x::SMulOperator) = [:*;x.terms] function Base.:(*)(xs::Symbolic{AbstractOperator}...) zero_ind = findfirst(x->iszero(x), xs) - isnothing(zero_ind) ? SMulOperator(collect(xs)) : SZeroOperator() + if isnothing(zero_ind) + terms = flattenop(*, collect(xs)) + coeff, cleanterms = prefactorscalings(terms) + coeff * SMulOperator(cleanterms) + else + SZeroOperator() + end end Base.show(io::IO, x::SMulOperator) = print(io, join(map(string, arguments(x)),"")) basis(x::SMulOperator) = basis(x.terms) @@ -160,10 +165,6 @@ julia> A ⊗ B """ @withmetadata struct STensor{T<:QObj} <: Symbolic{T} terms - function STensor{S}(terms) where S - coeff, cleanterms = prefactorscalings(terms) - coeff * new{S}(flattenop(⊗,cleanterms)) - end end isexpr(::STensor) = true iscall(::STensor) = true @@ -173,7 +174,13 @@ head(x::STensor) = :⊗ children(x::STensor) = [:⊗; x.terms] function ⊗(xs::Symbolic{T}...) where {T<:QObj} zero_ind = findfirst(x->iszero(x), xs) - isnothing(zero_ind) ? STensor{T}(collect(xs)) : SZero{T}() + if isnothing(zero_ind) + terms = flattenop(⊗, collect(xs)) + coeff, cleanterms = prefactorscalings(terms) + coeff * STensor{T}(cleanterms) + else + SZero{T}() + end end basis(x::STensor) = tensor(basis.(x.terms)...) @@ -201,10 +208,6 @@ julia> commutator(A, A) @withmetadata struct SCommutator <: Symbolic{AbstractOperator} op1 op2 - function SCommutator(o1, o2) - coeff, cleanterms = prefactorscalings([o1 o2]) - cleanterms[1] === cleanterms[2] ? SZeroOperator() : coeff*new(cleanterms...) - end end isexpr(::SCommutator) = true iscall(::SCommutator) = true @@ -212,7 +215,10 @@ arguments(x::SCommutator) = [x.op1, x.op2] operation(x::SCommutator) = commutator head(x::SCommutator) = :commutator children(x::SCommutator) = [:commutator, x.op1, x.op2] -commutator(o1::Symbolic{AbstractOperator}, o2::Symbolic{AbstractOperator}) = SCommutator(o1, o2) +function commutator(o1::Symbolic{AbstractOperator}, o2::Symbolic{AbstractOperator}) + coeff, cleanterms = prefactorscalings([o1 o2]) + cleanterms[1] === cleanterms[2] ? SZeroOperator() : coeff * SCommutator(cleanterms...) +end commutator(o1::SZeroOperator, o2::Symbolic{AbstractOperator}) = SZeroOperator() commutator(o1::Symbolic{AbstractOperator}, o2::SZeroOperator) = SZeroOperator() commutator(o1::SZeroOperator, o2::SZeroOperator) = SZeroOperator() @@ -231,10 +237,6 @@ julia> anticommutator(A, B) @withmetadata struct SAnticommutator <: Symbolic{AbstractOperator} op1 op2 - function SAnticommutator(o1, o2) - coeff, cleanterms = prefactorscalings([o1 o2]) - coeff*new(cleanterms...) - end end isexpr(::SAnticommutator) = true iscall(::SAnticommutator) = true @@ -242,7 +244,10 @@ arguments(x::SAnticommutator) = [x.op1, x.op2] operation(x::SAnticommutator) = anticommutator head(x::SAnticommutator) = :anticommutator children(x::SAnticommutator) = [:anticommutator, x.op1, x.op2] -anticommutator(o1::Symbolic{AbstractOperator}, o2::Symbolic{AbstractOperator}) = SAnticommutator(o1, o2) +function anticommutator(o1::Symbolic{AbstractOperator}, o2::Symbolic{AbstractOperator}) + coeff, cleanterms = prefactorscalings([o1 o2]) + coeff * SAnticommutator(cleanterms...) +end anticommutator(o1::SZeroOperator, o2::Symbolic{AbstractOperator}) = SZeroOperator() anticommutator(o1::Symbolic{AbstractOperator}, o2::SZeroOperator) = SZeroOperator() anticommutator(o1::SZeroOperator, o2::SZeroOperator) = SZeroOperator() diff --git a/src/QSymbolicsBase/literal_objects.jl b/src/QSymbolicsBase/literal_objects.jl index a8332a8..d6ebfbc 100644 --- a/src/QSymbolicsBase/literal_objects.jl +++ b/src/QSymbolicsBase/literal_objects.jl @@ -92,4 +92,4 @@ symbollabel(x::SZero) = "𝟎" basis(x::SZero) = nothing Base.show(io::IO, x::SZero) = print(io, symbollabel(x)) -Base.iszero(x::Union{SymQObj, Symbolic{Number}, Symbolic{Complex}}) = isa(x, SZero) +Base.iszero(x::SymQObj) = isa(x, SZero) diff --git a/src/QSymbolicsBase/utils.jl b/src/QSymbolicsBase/utils.jl index 1ae7e70..b14f582 100644 --- a/src/QSymbolicsBase/utils.jl +++ b/src/QSymbolicsBase/utils.jl @@ -5,6 +5,8 @@ function prefactorscalings(xs) if isa(x, SScaledOperator) coeff *= x.coeff push!(terms, x.obj) + elseif isa(x, Union{Number, Symbolic{Number}}) + coeff *= x else push!(terms,x) end diff --git a/test/test_sym_expressions.jl b/test/test_sym_expressions.jl index c7baadb..1bbd479 100644 --- a/test/test_sym_expressions.jl +++ b/test/test_sym_expressions.jl @@ -4,5 +4,5 @@ using QuantumSymbolics @test +(Z1) == Z1 @test +(Z) == Z @test isequal(Z1 - Z2, Z1 + (-Z2)) -@test isequal(Z1 - 2*Z2 + 2*X1, -2*Z2 + Z1 + 2*X1) +@test_broken isequal(Z1 - 2*Z2 + 2*X1, -2*Z2 + Z1 + 2*X1) @test_broken isequal(Z1 - 2*Z2 + 2*X1, Z1 + 2*(-Z2+X1))