Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix Aqua's reported piracies and method ambiguities #85

Merged
merged 2 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/QSymbolicsBase/basic_ops_homogeneous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ arguments(x::SScaled) = [x.coeff,x.obj]
operation(x::SScaled) = *
head(x::SScaled) = :*
children(x::SScaled) = [:*,x.coeff,x.obj]
function Base.:(*)(c, x::Symbolic{T}) where {T<:QObj}
function Base.:(*)(c::U, x::Symbolic{T}) where {U<:Union{Number, Symbolic{<:Number}},T<:QObj}
if (isa(c, Number) && iszero(c)) || iszero(x)
SZero{T}()
elseif _isone(c)
Expand All @@ -40,9 +40,9 @@ function Base.:(*)(c, x::Symbolic{T}) where {T<:QObj}
SScaled{T}(c, x)
end
end
Base.:(*)(x::Symbolic{T}, c) where {T<:QObj} = c*x
Base.:(*)(x::Symbolic{T}, c::Number) where {T<:QObj} = c*x
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Motivation here was to avoid ambiguities with other packages,

But I'm not sure if it should be ::Number or ::{Union{Number, Symbolic{Number}}} to allow p*tr(q) just as tr(q)*p is allowed. The test suits for QuantumSymbolics and QuantumSavory pass right now, but I wasn't sure if that is a valid operation since it was allowed before my changes.

Base.:(*)(x::Symbolic{T}, y::Symbolic{S}) where {T<:QObj,S<:QObj} = throw(ArgumentError("multiplication between $(typeof(x)) and $(typeof(y)) is not defined; maybe you are looking for a tensor product `tensor`"))
Base.:(/)(x::Symbolic{T}, c) where {T<:QObj} = iszero(c) ? throw(DomainError(c,"cannot divide QSymbolics expressions by zero")) : (1/c)*x
Base.:(/)(x::Symbolic{T}, c::Number) where {T<:QObj} = iszero(c) ? throw(DomainError(c,"cannot divide QSymbolics expressions by zero")) : (1/c)*x
basis(x::SScaled) = basis(x.obj)

const SScaledKet = SScaled{AbstractKet}
Expand Down Expand Up @@ -94,13 +94,13 @@ arguments(x::SAdd) = x._arguments_precomputed
operation(x::SAdd) = +
head(x::SAdd) = :+
children(x::SAdd) = [:+; x._arguments_precomputed]
function Base.:(+)(xs::Vararg{Symbolic{T},N}) where {T<:QObj,N}
function Base.:(+)(x::Symbolic{T}, xs::Vararg{Symbolic{T}, N}) where {T<:QObj, N}
xs = (x, xs...)
xs = collect(xs)
f = first(xs)
nonzero_terms = filter!(x->!iszero(x),xs)
isempty(nonzero_terms) ? f : SAdd{T}(countmap_flatten(nonzero_terms, SScaled{T}))
end
Base.:(+)(xs::Vararg{Symbolic{<:QObj},0}) = 0 # to avoid undefined type parameters issue in the above method
basis(x::SAdd) = basis(first(x.dict).first)

const SAddBra = SAdd{AbstractBra}
Expand Down Expand Up @@ -137,7 +137,8 @@ arguments(x::SMulOperator) = x.terms
operation(x::SMulOperator) = *
head(x::SMulOperator) = :*
children(x::SMulOperator) = [:*;x.terms]
function Base.:(*)(xs::Symbolic{AbstractOperator}...)
function Base.:(*)(x::Symbolic{AbstractOperator}, xs::Vararg{Symbolic{AbstractOperator}, N}) where {N}
xs = (x, xs...)
zero_ind = findfirst(x->iszero(x), xs)
if isnothing(zero_ind)
if any(x->!(samebases(basis(x),basis(first(xs)))),xs)
Expand Down
2 changes: 2 additions & 0 deletions src/QSymbolicsBase/basic_superops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
basis(x::KrausRepr) = basis(first(x.krausops))
Base.:(*)(sop::KrausRepr, op::Symbolic{AbstractOperator}) = (+)((i*op*dagger(i) for i in sop.krausops)...)
Base.:(*)(sop::KrausRepr, k::Symbolic{AbstractKet}) = (+)((i*SProjector(k)*dagger(i) for i in sop.krausops)...)
Base.:(*)(sop::KrausRepr, k::SZeroOperator) = SZeroOperator()
Base.:(*)(sop::KrausRepr, k::SZeroKet) = SZeroOperator()

Check warning on line 33 in src/QSymbolicsBase/basic_superops.jl

View check run for this annotation

Codecov / codecov/patch

src/QSymbolicsBase/basic_superops.jl#L32-L33

Added lines #L32 - L33 were not covered by tests
Base.show(io::IO, x::KrausRepr) = print(io, "𝒦("*join([symbollabel(i) for i in x.krausops], ",")*")")

##
Expand Down
43 changes: 39 additions & 4 deletions test/test_aqua.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,42 @@
@testitem "Aqua" tags=[:aqua] begin
using Aqua
Aqua.test_all(QuantumSymbolics,
ambiguities=(;broken=true),
piracies=(;broken=true),
)

# Add any new types needed to QObj, or here if QObj if not appropriate.
# Add types from elsewhere in the ecosystem here or preferably to QObj
own_types = [Base.uniontypes(QObj)...,]
own_types_union = Union{SymQObj,}

Aqua.test_all(QuantumSymbolics, piracies=(;treat_as_own=own_types))

function normalize_arguments(method)
args = Base.unwrap_unionall(method.sig).types[2:end]
normalized_args = []
# handle few edge cases specific to our analysis
for arg in args
# mutation and order of if-conditions is intedtional here
if (arg isa UnionAll) && (arg.body <: Type) arg = arg.body.parameters[1] end
if (arg isa Core.TypeofVararg) arg = arg.T end
if (arg isa TypeVar) arg = arg.ub end
push!(normalized_args, arg)
end
return normalized_args
end

# Custom type-piracy detection, to catch uses of QuantumInterface types without a Symbolic
filtered_piracies = filter(Aqua.Piracy.hunt(QuantumSymbolics)) do m
!any(normalize_arguments(m) .<: own_types_union)
end

aqua_piracies = Aqua.Piracy.hunt(QuantumSymbolics, treat_as_own=own_types)
internally_detected_piracies = setdiff(filtered_piracies, aqua_piracies)
if !isempty(internally_detected_piracies)
printstyled(
stderr,
"Internally flagged possible type-piracy:\n";
color = Base.warn_color()
)
show(stderr, MIME"text/plain"(), internally_detected_piracies)
println(stderr, "\n")
end
@test isempty(internally_detected_piracies)
end
Loading