Skip to content

Commit

Permalink
Improve adhoc arithmetic for UnivPoly, fix broken coeff method (#1877)
Browse files Browse the repository at this point in the history
* Improve adhoc arithmetic for UnivPoly
* Fix broken `coeff` method: If the asked for monomial contains
  vars that are not present in the data-parent, the whole coeff is
  clearly zero. This was not caught before, since the slow adhoc
  functions always re-created all elements with the current
  largest poly ring as data parent.

---------

Co-authored-by: Lars Göttgens <[email protected]>
  • Loading branch information
fingolfin and lgoettgens authored Oct 27, 2024
1 parent 129efd6 commit 814f22b
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 45 deletions.
116 changes: 72 additions & 44 deletions src/generic/UnivPoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,8 @@ function coeff(p::UnivPoly{T}, vars::Vector{Int}, exps::Vector{Int}) where {T}
if vars[i] <= num
push!(vars2, vars[i])
push!(exps2, exps[i])
elseif exps[i] > 0
return zero(S)
end
end
return UnivPoly{T}(coeff(data(p), vars2, exps2), S)
Expand Down Expand Up @@ -361,6 +363,7 @@ end
###############################################################################

function univ_promote(x::UnivPoly{T}, y::UnivPoly{T}) where {T <: RingElement}
check_parent(x, y)
nx = nvars(parent(data(x)))
ny = nvars(parent(data(y)))
if nx == ny
Expand All @@ -371,19 +374,16 @@ function univ_promote(x::UnivPoly{T}, y::UnivPoly{T}) where {T <: RingElement}
end

function +(a::UnivPoly{T}, b::UnivPoly{T}) where {T}
check_parent(a, b)
a, b = univ_promote(a, b)
return UnivPoly{T}(data(a) + data(b), parent(a))
end

function -(a::UnivPoly{T}, b::UnivPoly{T}) where {T}
check_parent(a, b)
a, b = univ_promote(a, b)
return UnivPoly{T}(data(a) - data(b), parent(a))
end

function *(a::UnivPoly{T}, b::UnivPoly{T}) where {T}
check_parent(a, b)
a, b = univ_promote(a, b)
return UnivPoly{T}(data(a)*data(b), parent(a))
end
Expand Down Expand Up @@ -504,21 +504,25 @@ end
#
###############################################################################

function *(p::UnivPoly{T}, n::Union{Integer, Rational, AbstractFloat}) where {T}
S = parent(p)
return UnivPoly{T}(data(p)*n, S)
end
for op in (:+, :-, :*)
@eval begin
function $op(p::UnivPoly{T}, n::Union{Integer, Rational, AbstractFloat}) where {T}
S = parent(p)
return UnivPoly{T}($op(data(p),n), S)
end

function *(p::UnivPoly{T}, n::T) where {T <: RingElem}
S = parent(p)
return UnivPoly{T}(data(p)*n, S)
end
function $op(p::UnivPoly{T}, n::T) where {T <: RingElem}
S = parent(p)
return UnivPoly{T}($op(data(p),n), S)
end

*(n::Union{Integer, Rational, AbstractFloat}, p::UnivPoly) = p*n
$op(n::Union{Integer, Rational, AbstractFloat}, p::UnivPoly) = $op(p,n)

*(n::T, p::UnivPoly{T}) where {T <: RingElem} = p*n
$op(n::T, p::UnivPoly{T}) where {T <: RingElem} = $op(p,n)
end
end

function divexact(p::UnivPoly{T}, n::Union{Integer, Rational, BigFloat}; check::Bool=true) where {T}
function divexact(p::UnivPoly{T}, n::Union{Integer, Rational, AbstractFloat}; check::Bool=true) where {T}
S = parent(p)
return UnivPoly{T}(divexact(data(p), n; check=check), S)
end
Expand Down Expand Up @@ -674,13 +678,11 @@ end
###############################################################################

function divexact(a::UnivPoly{T}, b::UnivPoly{T}; check::Bool=true) where {T}
check_parent(a, b)
a, b = univ_promote(a, b)
return UnivPoly{T}(divexact(data(a), data(b); check=check), parent(a))
end

function divides(a::UnivPoly{T}, b::UnivPoly{T}) where {T}
check_parent(a, b)
a, b = univ_promote(a, b)
flag, q = divides(data(a), data(b))
return flag, UnivPoly{T}(q, parent(a))
Expand All @@ -693,13 +695,11 @@ end
###############################################################################

function Base.div(a::UnivPoly{T}, b::UnivPoly{T}) where {T}
check_parent(a, b)
a, b = univ_promote(a, b)
return UnivPoly{T}(div(data(a), data(b)), parent(a))
end

function Base.divrem(a::UnivPoly{T}, b::UnivPoly{T}) where {T}
check_parent(a, b)
a, b = univ_promote(a, b)
q, r = divrem(data(a), data(b))
return UnivPoly{T}(q, parent(a)), UnivPoly{T}(r, parent(a))
Expand Down Expand Up @@ -730,11 +730,9 @@ end
###############################################################################

function remove(z::UnivPoly{T}, p::UnivPoly{T}) where {T}
check_parent(z, p)
S = parent(z)
z, p = univ_promote(z, p)
val, q = remove(data(z), data(p))
return val, UnivPoly{T}(q, S)
return val, UnivPoly{T}(q, parent(z))
end

function valuation(z::UnivPoly{T}, p::UnivPoly{T}) where {T}
Expand Down Expand Up @@ -866,13 +864,11 @@ end
###############################################################################

function gcd(a::UnivPoly{T}, b::UnivPoly{T}) where {T <: RingElement}
check_parent(a, b)
a, b = univ_promote(a, b)
return UnivPoly{T}(gcd(data(a), data(b)), parent(a))
end

function lcm(a::UnivPoly{T}, b::UnivPoly{T}) where {T <: RingElement}
check_parent(a, b)
a, b = univ_promote(a, b)
return UnivPoly{T}(lcm(data(a), data(b)), parent(a))
end
Expand Down Expand Up @@ -1005,59 +1001,91 @@ end
###############################################################################

function zero!(a::UnivPoly{T}) where {T <: RingElement}
a.p = zero!(a.p)
return a
a.p = zero!(a.p)
return a
end

function one!(a::UnivPoly{T}) where {T <: RingElement}
a.p = one!(a.p)
return a
a.p = one!(a.p)
return a
end

function neg!(z::UnivPoly{T}, a::UnivPoly{T}) where {T <: RingElement}
z.p = neg!(z.p, a.p)
return z
if parent(data(z)) == parent(data(a))
z.p = neg!(z.p, a.p)
else
z.p = -a.p
end
return z
end

function fit!(a::UnivPoly, n::Int)
fit!(data(a), n)
fit!(data(a), n)
end

function add!(a::UnivPoly{T}, b::UnivPoly{T}, c::UnivPoly{T}) where {T <: RingElement}
a.p = add!(a.p, b.p, c.p)
return a
if parent(data(a)) == parent(data(b)) == parent(data(c))
a.p = add!(data(a), data(b), data(c))
else
a.p = data(b + c)
end
return a
end

function add!(a::UnivPoly{T}, b::UnivPoly{T}, c::RingElement) where {T <: RingElement}
a.p = add!(a.p, b.p, c)
return a
if parent(data(a)) == parent(data(b))
a.p = add!(data(a), data(b), c)
else
a.p = data(b + c)
end
return a
end

add!(a::UnivPoly{T}, b::RingElement, c::UnivPoly{T}) where {T <: RingElement} = add!(a, c, b)

function sub!(a::UnivPoly{T}, b::UnivPoly{T}, c::UnivPoly{T}) where {T <: RingElement}
a.p = sub!(a.p, b.p, c.p)
return a
if parent(data(a)) == parent(data(b)) == parent(data(c))
a.p = sub!(data(a), data(b), data(c))
else
a.p = data(b - c)
end
return a
end

function sub!(a::UnivPoly{T}, b::UnivPoly{T}, c::RingElement) where {T <: RingElement}
a.p = sub!(a.p, b.p, c)
return a
if parent(data(a)) == parent(data(b))
a.p = sub!(data(a), data(b), c)
else
a.p = data(b - c)
end
return a
end

function sub!(a::UnivPoly{T}, b::RingElement, c::UnivPoly{T}) where {T <: RingElement}
a.p = sub!(a.p, b, c.p)
return a
if parent(data(a)) == parent(data(c))
a.p = sub!(data(a), b, data(c))
else
a.p = data(b - c)
end
return a
end

function mul!(a::UnivPoly{T}, b::UnivPoly{T}, c::UnivPoly{T}) where {T <: RingElement}
a.p = mul!(a.p, b.p, c.p)
return a
if parent(data(a)) == parent(data(b)) == parent(data(c))
a.p = mul!(data(a), data(b), data(c))
else
a.p = data(b * c)
end
return a
end

function mul!(a::UnivPoly{T}, b::UnivPoly{T}, c::RingElement) where {T <: RingElement}
a.p = mul!(a.p, b.p, c)
return a
if parent(data(a)) == parent(data(b))
a.p = mul!(data(a), data(b), c)
else
a.p = data(b * c)
end
return a
end

mul!(a::UnivPoly{T}, b::RingElement, c::UnivPoly{T}) where {T <: RingElement} = mul!(a, c, b)
Expand Down
2 changes: 1 addition & 1 deletion test/generic/UnivPoly-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ end
g = 3x^3*y^2 + 2x^3*y*z + 2x^2*y*z + 3x + 2y + 1

@test collect(coefficients(f)) == [R(v) for v in [3, 2, 1, 4]]
@test collect(exponent_vectors(f)) == [[3, 0, 0], [2, 0, 0], [1, 0, 0], [0, 0, 0]]
@test collect(exponent_vectors(f)) == [[3], [2], [1], [0]]
@test sum(terms(f)) == f
@test collect(monomials(f)) == [x^3, x^2, x, S(1)]

Expand Down

0 comments on commit 814f22b

Please sign in to comment.