diff --git a/src/types.jl b/src/types.jl index 81220ebd..f2b0cfa9 100644 --- a/src/types.jl +++ b/src/types.jl @@ -98,6 +98,7 @@ function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple @compactified obj::BasicSymbolic begin Sym => Sym{T}(nt_new.name; nt_new...) Term => Term{T}(nt_new.f, nt_new.arguments; nt_new...) + Add => Add(T, nt_new.coeff, nt_new.dict; nt_new...) _ => Unityper.rt_constructor(obj){T}(;nt_new...) end end @@ -418,7 +419,8 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T end end - Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...) + s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...) + BasicSymbolic(s) end function Mul(T, a, b; metadata=NO_METADATA, kw...) diff --git a/test/hash_consing.jl b/test/hash_consing.jl index 82385736..d739dbbc 100644 --- a/test/hash_consing.jl +++ b/test/hash_consing.jl @@ -1,5 +1,5 @@ using SymbolicUtils, Test -using SymbolicUtils: Term +using SymbolicUtils: Term, Add struct Ctx1 end struct Ctx2 end @@ -40,3 +40,17 @@ end tm1 = setmetadata(t1, Ctx1, "meta_1") @test t1 !== tm1 end + +@testset "Add" begin + d1 = a + b + d2 = b + a + @test d1 === d2 + d3 = b - 2 + a + d4 = a + b - 2 + @test d3 === d4 + d5 = Add(Int, 0, Dict(a => 1, b => 1)) + @test d5 !== d1 + + dm1 = setmetadata(d1,Ctx1,"meta_1") + @test d1 !== dm1 +end