diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index 47597c97..542fb133 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -5,7 +5,7 @@ module SymbolicUtils using DocStringExtensions -export @syms, term, showraw, hasmetadata, getmetadata, setmetadata, name +export @syms, term, showraw, hasmetadata, getmetadata, setmetadata, name, coeff using TermInterface using DataStructures diff --git a/src/inspect.jl b/src/inspect.jl index cef10216..961c9f0f 100644 --- a/src/inspect.jl +++ b/src/inspect.jl @@ -12,10 +12,10 @@ function AbstractTrees.nodevalue(x::BasicSymbolic) string(x.impl.val) elseif isadd(x) string(exprtype(x), - (scalar = x.impl.coeff, coeffs = Tuple(k => v for (k, v) in x.impl.dict))) + (scalar = coeff(x), coeffs = Tuple(k => v for (k, v) in x.impl.dict))) elseif ismul(x) string(exprtype(x), - (scalar = x.impl.coeff, powers = Tuple(k => v for (k, v) in x.impl.dict))) + (scalar = coeff(x), powers = Tuple(k => v for (k, v) in x.impl.dict))) elseif isdiv(x) || ispow(x) string(exprtype(x)) else diff --git a/src/polyform.jl b/src/polyform.jl index 3db1dd16..6969d20b 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -514,7 +514,7 @@ function quick_mulpow(x, y) den = _Pow(symtype(y), y.impl.base, y.impl.exp-d[y.impl.base]) delete!(d, y.impl.base) end - return _Mul(symtype(x), x.impl.coeff, d), den + return _Mul(symtype(x), coeff(x), d), den else return x, y end diff --git a/src/types.jl b/src/types.jl index 97115557..0434f090 100644 --- a/src/types.jl +++ b/src/types.jl @@ -68,6 +68,10 @@ function name(x::BasicSymbolic) x.impl.name end +function coeff(x::BasicSymbolic) + x.impl.coeff +end + # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") @@ -293,7 +297,7 @@ function _isequal(a, b, E) if E === SYM nameof(a) === nameof(b) elseif E === ADD || E === MUL - coeff_isequal(a.impl.coeff, b.impl.coeff) && isequal(a.impl.dict, b.impl.dict) + coeff_isequal(coeff(a), coeff(b)) && isequal(a.impl.dict, b.impl.dict) elseif E === DIV isequal(a.impl.num, b.impl.num) && isequal(a.impl.den, b.impl.den) elseif E === POW @@ -337,7 +341,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt h = s.hash[] !iszero(h) && return h hashoffset = isadd(s) ? ADD_SALT : SUB_SALT - h′ = hash(hashoffset, hash(s.impl.coeff, hash(s.impl.dict, salt))) + h′ = hash(hashoffset, hash(coeff(s), hash(s.impl.dict, salt))) s.hash[] = h′ return h′ elseif E === DIV @@ -444,7 +448,7 @@ const Rat = Union{Rational, Integer} function ratcoeff(x) if ismul(x) - ratcoeff(x.impl.coeff) + ratcoeff(coeff(x)) elseif x isa Rat (true, x) else @@ -455,7 +459,7 @@ ratio(x::Integer,y::Integer) = iszero(rem(x,y)) ? div(x,y) : x//y ratio(x::Rat,y::Rat) = x//y function maybe_intcoeff(x) if ismul(x) - coeff = x.impl.coeff + coeff = coeff(x) if coeff isa Rational && isone(denominator(coeff)) _Mul(symtype(x), coeff.num, x.impl.dict; metadata = x.metadata) else @@ -537,7 +541,7 @@ function toterm(t::BasicSymbolic{T}) where {T} return t elseif E === ADD || E === MUL args = BasicSymbolic[] - push!(args, t.impl.coeff) + push!(args, coeff(t)) for (k, coeff) in t.impl.dict push!( args, coeff == 1 ? k : _Term(T, E === MUL ? (^) : (*), [_Const(coeff), k])) @@ -562,7 +566,7 @@ function makeadd(sign, coeff, xs...) d = Dict{BasicSymbolic, Any}() for x in xs if isadd(x) - coeff += x.impl.coeff + coeff += coeff(x) _merge!(+, d, x.impl.dict, filter = _iszero) continue end @@ -572,7 +576,7 @@ function makeadd(sign, coeff, xs...) end if ismul(x) k = _Mul(symtype(x), 1, x.impl.dict) - v = sign * x.impl.coeff + get(d, k, 0) + v = sign * coeff(x) + get(d, k, 0) else k = x v = sign + get(d, x, 0) @@ -593,7 +597,7 @@ function makemul(coeff, xs...; d = Dict{BasicSymbolic, Any}()) elseif x isa Number coeff *= x elseif ismul(x) - coeff *= x.impl.coeff + coeff *= coeff(x) _merge!(+, d, x.impl.dict, filter = _iszero) else v = 1 + get(d, x, 0) @@ -1219,10 +1223,10 @@ function +(a::SN, b::SN) !issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata if isadd(a) && isadd(b) return _Add( - add_t(a, b), a.impl.coeff + b.impl.coeff, _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) + add_t(a, b), coeff(a) + coeff(b), _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) elseif isadd(a) coeff, dict = makeadd(1, 0, b) - return _Add(add_t(a, b), a.impl.coeff + coeff, _merge(+, a.impl.dict, dict, filter = _iszero)) + return _Add(add_t(a, b), coeff(a) + coeff, _merge(+, a.impl.dict, dict, filter = _iszero)) elseif isadd(b) return b + a end @@ -1236,7 +1240,7 @@ function +(a::Number, b::SN) !issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata iszero(a) && return b if isadd(b) - _Add(add_t(a, b), a + b.impl.coeff, b.impl.dict) + _Add(add_t(a, b), a + coeff(b), b.impl.dict) else _Add(add_t(a, b), makeadd(1, a, b)...) end @@ -1254,7 +1258,7 @@ function -(a::SN) return term(-, a) end if isadd(a) - _Add(sub_t(a), -a.impl.coeff, mapvalues((_, v) -> -v, a.impl.dict)) + _Add(sub_t(a), -coeff(a), mapvalues((_, v) -> -v, a.impl.dict)) else _Add(sub_t(a), makeadd(-1, 0, a)...) end @@ -1262,7 +1266,7 @@ end function -(a::SN, b::SN) (!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b) if isadd(a) && isadd(b) - _Add(sub_t(a, b), a.impl.coeff - b.impl.coeff, _merge(-, a.impl.dict, b.impl.dict, filter = _iszero)) + _Add(sub_t(a, b), coeff(a) - coeff(b), _merge(-, a.impl.dict, b.impl.dict, filter = _iszero)) else a + (-b) end @@ -1289,16 +1293,16 @@ function *(a::SN, b::SN) elseif isdiv(b) _Div(a * b.impl.num, b.impl.den) elseif ismul(a) && ismul(b) - _Mul(mul_t(a, b), a.impl.coeff * b.impl.coeff, + _Mul(mul_t(a, b), coeff(a) * coeff(b), _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) elseif ismul(a) && ispow(b) if b.impl.exp isa Number _Mul(mul_t(a, b), - a.impl.coeff, + coeff(a), _merge(+, a.impl.dict, Base.ImmutableDict(b.impl.base => b.impl.exp), filter = _iszero)) else - _Mul(mul_t(a, b), a.impl.coeff, + _Mul(mul_t(a, b), coeff(a), _merge(+, a.impl.dict, Base.ImmutableDict(b => 1), filter = _iszero)) end elseif ispow(a) && ismul(b) @@ -1321,7 +1325,7 @@ function *(a::Number, b::SN) elseif isone(-a) && isadd(b) # -1(a+b) -> -a - b T = promote_symtype(+, typeof(a), symtype(b)) - _Add(T, b.impl.coeff * a, + _Add(T, coeff(b) * a, Dict{BasicSymbolic, Any}(k => v * a for (k, v) in b.impl.dict)) else _Mul(mul_t(a, b), makemul(a, b)...) @@ -1346,7 +1350,7 @@ function ^(a::SN, b) elseif b isa Number && b < 0 _Div(1, a^(-b)) elseif ismul(a) && b isa Number - coeff = unstable_pow(a.impl.coeff, b) + coeff = unstable_pow(coeff(a), b) _Mul(promote_symtype(^, symtype(a), symtype(b)), coeff, mapvalues((k, v) -> b * v, a.impl.dict)) else diff --git a/test/basics.jl b/test/basics.jl index eeb927ec..a7db995f 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -344,18 +344,18 @@ end @testset "div" begin @syms x::SafeReal y::Real @test issym((2x / 2y).impl.num) - @test (2x / 3y).impl.num.impl.coeff == 2 - @test (2x / 3y).impl.den.impl.coeff == 3 - @test (2x / -3x).impl.num.impl.coeff == -2 - @test (2x / -3x).impl.den.impl.coeff == 3 - @test (2.5x / 3x).impl.num.impl.coeff == 2.5 - @test (2.5x / 3x).impl.den.impl.coeff == 3 - @test (x / 3x).impl.den.impl.coeff == 3 + @test coeff((2x / 3y).impl.num) == 2 + @test coeff((2x / 3y).impl.den) == 3 + @test coeff((2x / -3x).impl.num) == -2 + @test coeff((2x / -3x).impl.den) == 3 + @test coeff((2.5x / 3x).impl.num) == 2.5 + @test coeff((2.5x / 3x).impl.den) == 3 + @test coeff((x / 3x).impl.den) == 3 @syms x y @test issym((2x / 2y).impl.num) - @test (2x / 3y).impl.num.impl.coeff == 2 - @test (2x / 3y).impl.den.impl.coeff == 3 + @test coeff((2x / 3y).impl.num) == 2 + @test coeff((2x / 3y).impl.den) == 3 @test (2x / -3x) == -2 // 3 @test (2.5x / 3x).impl.num == 2.5 @test (2.5x / 3x).impl.den == 3