diff --git a/Project.toml b/Project.toml index 695016b..0c046ea 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,15 @@ name = "OOPMacro" uuid = "90c472d1-064c-5c63-af2e-229f1fdb5f26" authors = ["Shih-Ming Wang ", "Marius Kruger "] -version = "0.3.0" +version = "0.4.0" [compat] julia = ">= 0.7" [extras] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Test", "BenchmarkTools", "Statistics"] diff --git a/README.md b/README.md index d8ea598..ae296de 100644 --- a/README.md +++ b/README.md @@ -136,7 +136,5 @@ cvalue = c.cfield # Future Work - write unit test for each function in clsUtil and OOPMacroImpl -- override getproperty() to make more natural usage of - 'methods' -- maybe don't require manually setting the self arg when declaring methods; rathre specify @static if it is not a object method +- maybe don't require manually setting the self arg when declaring methods; rather specify @static if it is not a object method - Type generic parameter ?? diff --git a/src/OOPMacroImpl.jl b/src/OOPMacroImpl.jl index cd71cc4..59560e5 100644 --- a/src/OOPMacroImpl.jl +++ b/src/OOPMacroImpl.jl @@ -4,22 +4,36 @@ include("clsUtil.jl") #= ClsMethods = Dict{Symbol, Dict{Expr, Expr}}() =# ClsMethods = Dict(:Any=>Dict{Expr, Expr}()) ClsFields = Dict(:Any=>Vector{Expr}()) +validOptions = (:nodotoperator,) +OOPMacroModule=@__MODULE__ +clsFnDefDict=WeakKeyDict{Any,Dict{Symbol,Function}}() +macro class(args...) + if length(args) < 2 + error("At least a class name and body must be specified.") + end + options = args[1:end-2] + clsName = args[end-1] + cBody = args[end] + for o in options + if o ∉ validOptions + error("$o is not a valid option. Valid options are: $(join(validOptions, ", "))") + end + end + supportDotOperator = :nodotoperator ∉ options -macro class(ClsName, Cbody) - ClsName, ParentClsNameLst = getCAndP(ClsName) - AbsClsName = getAbstractCls(ClsName) + clsName, ParentClsNameLst = getCAndP(clsName) + AbsClsName = getAbstractCls(clsName) AbsParentClsName = getAbstractCls(ParentClsNameLst) - ClsFields[ClsName] = fields = copyFields(ParentClsNameLst, ClsFields) - ClsMethods[ClsName] = methods = Dict{Expr,Expr}() - + ClsFields[clsName] = fields = copyFields(ParentClsNameLst, ClsFields) + ClsMethods[clsName] = methods = Dict{Expr,Expr}() cons = Any[] hasInit = false # record fields and methods separately - for (i, block) in enumerate(Cbody.args) + for (i, block) in enumerate(cBody.args) if isa(block, Symbol) union!(fields, [:($block::Any)]) elseif isa(block, LineNumberNode) @@ -30,19 +44,19 @@ macro class(ClsName, Cbody) continue elseif block.head == :(=) || block.head == :function fname = getFnName(block, withoutGeneric=true) - if fname == ClsName + if fname == clsName append!(cons, [block]) elseif fname == :__init__ hasInit = true - setFnName!(block, ClsName) - self = findFnSelfArgNameSymbol(block, ClsName) + setFnName!(block, clsName) + self = findFnSelfArgNameSymbol(block, clsName) deleteFnSelf!(block) - prepend!(block.args[2].args, [:($self = $ClsName(()))]) + prepend!(block.args[2].args, [:($self = $clsName(()))]) append!(block.args[2].args, [:($self)]) append!(cons, [block]) else fn = copy(block) - setFnSelfArgType!(fn, ClsName) + setFnSelfArgType!(fn, clsName) methods[findFnCall(fn)] = fn end else @@ -50,17 +64,16 @@ macro class(ClsName, Cbody) end end - ClsFnCalls = Set(keys(methods)) for parent in ParentClsNameLst for pfn in values(ClsMethods[parent]) fn = copy(pfn) - setFnSelfArgType!(fn, ClsName) - fnCall = findFnCall(fn) + setFnSelfArgType!(fn, clsName) + fnCall = findFncons_strCall(fn) if haskey(methods, fnCall) fName = getFnName(fn, withoutGeneric=true) if !(fnCall in ClsFnCalls) - error("Ambiguious Function Definition: Multiple definition of function $fName while $ClsName does not overwtie this function!!") + error("Ambiguious Function Definition: Multiple definition of function $fName while $clsName does not overwrite this function!!") end setFnName!(fn, Symbol(string("super_", parent, fName)), withoutGeneric=true) methods[fnCall] = fn @@ -71,18 +84,49 @@ macro class(ClsName, Cbody) end cons_str = join(cons,"\n") * "\n" + @show hasInit if hasInit - cons_str *= "$ClsName(::Tuple{}) = new()\n" + cons_str *= "$clsName(::Tuple{}) = new()\n" end - +# println("cons_str ",con_str) clsDefStr = """ - mutable struct $ClsName + mutable struct $clsName $(join(fields,"\n")) """ * cons_str * """ - end - """ - # Escape here because we want ClsName and the methods be defined in user scope instead of OOPMacro module scope. - esc(Expr(:block, Meta.parse(clsDefStr), values(methods)...)) + # fun1::Function + end""" +println("clsDefStr ",clsDefStr) + # this allows calling functions on the class.. + clsFnNames = map(fn->"$(getFnName(fn, withoutGeneric=true))", collect(values(methods))) + clsFnNameList = join(map(name->":$name,", clsFnNames),"") + # dotAccStr = """ + # function Base.getproperty(self::$clsName, nameSymbol::Symbol) + # if isdefined(self, nameSymbol) #|| nameSymbol ∉ ($clsFnNameList) + # getfield(self, nameSymbol) + # else + # $(OOPMacroModule).clsFnDefDict[self][nameSymbol] + # end + # end + # """ + dotAccStr2 = """ + function __initOOPMacroFunctions(self::$clsName) + #fnDict=$(OOPMacroModule).clsFnDefDict[self] = Dict{Symbol,Function}() + #for nameSymbol in ($clsFnNameList) + #self.fun1=(args...; kwargs...)->eval(:(\$nameSymbol(\$self, \$args...; \$kwargs...))) + self.fun1=(args...; kwargs...)->eval(:(fun1(\$self, \$args...; \$kwargs...))) + #end + self + end + """ + blockSections = [Meta.parse(clsDefStr), values(methods)...] + if supportDotOperator + push!(blockSections, #Meta.parse(dotAccStr), + Meta.parse(dotAccStr2) + ) + end + + # Escape here because we want clsName and the methods be defined in user scope instead of OOPMacro module scope. + esc(Expr(:block, blockSections...)) end macro super(ParentClsName, FCall) diff --git a/src/clsUtil.jl b/src/clsUtil.jl index b218bd0..67d8375 100644 --- a/src/clsUtil.jl +++ b/src/clsUtil.jl @@ -1,3 +1,5 @@ + +""" Get the class name and parents """ function getCAndP(cls) if isa(cls, Symbol) C, P = cls, [:Any] diff --git a/test/basic.jl b/test/basic.jl index 9227c81..4b49077 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -4,6 +4,9 @@ using Test field0 field1::Int field2::Int + SimpleCls(field0, field1::Int, field2::Int) = begin + self = __initOOPMacroFunctions(new(field0, field1, field2)) + end #= Supports different style of function declaration =# fun0(self::SimpleCls, x) = self.field0 + x @@ -30,18 +33,25 @@ s = SimpleCls(0,1,2) @test s.field2 == 2 @test fun0(s, 2.) == 2 @test fun1(s, 2., 3) == 3 +@test s.fun1(2., 3) == 3 @test fun2(s, 2.) == 4 @test fun2(s, 2) == 4 @test_throws(MethodError, fun2(s,"a")) @class SimpleCls1 begin field0::Int + SimpleCls1(field0::Int) = begin + self = __initOOPMacroFunctions(new(field0)) + end fun0(self, x, y=1) = self.field0 + x + y fun1(self, x, y=1; z=2) = self.field0 + x + y + z end s1 = SimpleCls1(0) @test fun0(s1, 1) == 2 +@test s1.fun0(1) == 2 @test fun0(s1, 1, 2) == 3 +@test s1.fun0(1, 2) == 3 @test fun1(s1, 1) == 4 @test fun1(s1, 1, 2) == 5 @test fun1(s1, 1, 2, z=3) == 6 +@test s1.fun1(1, 2, z=3) == 6 diff --git a/test/dotoperator.jl b/test/dotoperator.jl new file mode 100644 index 0000000..a0f9e4d --- /dev/null +++ b/test/dotoperator.jl @@ -0,0 +1,48 @@ +using Test + +@class BasicDotOpr begin + field0::Int + fun0(self::SimpleCls, x) = self.field0 + x +end + +@class nodotoperator BasicNoDotOpr begin + field0::Int + fun0(self::SimpleCls, x) = self.field0 + x +end + +@testset "dot operator tests" begin + + @testset "invalid option validation" begin + @test_throws(LoadError, @macroexpand @class invalidoption MyCls begin end) + end + + @testset "tests with dot operator" begin + @testset "basic test" begin + @show bdo = BasicDotOpr(1) + @time bdo.field0 + @time bdo.field0 + @time bdo.field0 + @test_throws(ErrorException, bdo.invalidfield) + @test fun0(bdo, 1) == 2 + @time fun0(bdo, 1) == 2 + @time fun0(bdo, 1) == 2 + @time fun0(bdo, 1) == 2 + @test bdo.fun0(1) == 2 + @time bdo.fun0(1) == 2 + @time bdo.fun0(1) == 2 + @time bdo.fun0(1) == 2 + end + end + + @testset "tests without dot operator" begin + @testset "basic test" begin + @show bndo = BasicNoDotOpr(1) + @time bndo.field0 + @time bndo.field0 + @time bndo.field0 + @test_throws(ErrorException, bndo.invalidfield) + @test fun0(bndo, 1) == 2 + @test_throws(ErrorException, bndo.fun0(1)) + end + end +end diff --git a/test/dotoperator_benchmark.jl b/test/dotoperator_benchmark.jl new file mode 100644 index 0000000..c920c6b --- /dev/null +++ b/test/dotoperator_benchmark.jl @@ -0,0 +1,58 @@ +using Test +using BenchmarkTools +using Statistics + +if !isdefined(Main, :B1DotOpr) + @class B1DotOpr begin + field1::Int + fun1#::Function + fun1(self::SimpleCls, x) = self.field1 + x + B1DotOpr(field1::Int) = begin + self = __initOOPMacroFunctions(new(field1,nothing)) + #self.field1=field1 + end + end +end +if !isdefined(Main, :B1NoDotOpr) + @class nodotoperator B1NoDotOpr begin + field1::Int + fun1(self::SimpleCls, x) = self.field1 + x + # B1NoDotOpr(field1::Int) = begin + # self = new() + # self.field1=field1 + # end + end +end + +bg = BenchmarkGroup() +bg["DotOpr"] = BenchmarkGroup() +bg["NoDotOpr"] = BenchmarkGroup() + +bg["DotOpr"]["get_field1"] = @benchmarkable o.field1 setup=(o = B1DotOpr(rand(Int))) +bg["NoDotOpr"]["get_field1"] = @benchmarkable o.field1 setup=(o = B1NoDotOpr(rand(Int))) + +bg["DotOpr"]["call_normal_fun1"] = @benchmarkable fun1(o, 1) setup=(o = B1DotOpr(rand(Int))) +bg["NoDotOpr"]["call_normal_fun1"] = @benchmarkable fun1(o, 1) setup=(o = B1NoDotOpr(rand(Int))) + +bg["DotOpr"]["call_fun1"] = @benchmarkable o.fun1(1) setup=(o = B1DotOpr(rand(Int))) +bg["NoDotOpr"]["call_fun1"] = @benchmarkable fun1(o, 1) setup=(o = B1NoDotOpr(rand(Int))) + +# bg["DotOpr"]["call_warm_fun1"] = @benchmarkable o.fun1(1) setup=(o = B1DotOpr(rand(Int)); o.fun1(1)) +# bg["NoDotOpr"]["call_warm_fun1"] = @benchmarkable fun1(o, 1) setup=(o = B1NoDotOpr(rand(Int)); fun1(o, 1)) + +tune!(bg) +results = run(bg, verbose = true, seconds = 2) + +@testset "tests if there are regressions" begin + for t in ("get_field1", "call_normal_fun1", + "call_fun1" #, "call_warm_fun1" + ) + @testset "regressions in $t" begin + med = median(results) + #println(t, med) + j = judge(med["DotOpr"][t], med["NoDotOpr"][t]) + println(t, ": ", j) + @test !isregression(j) + end + end +end diff --git a/test/inheritence.jl b/test/inheritence.jl index 38eb8a1..f446d00 100644 --- a/test/inheritence.jl +++ b/test/inheritence.jl @@ -31,13 +31,21 @@ pvalue = p.pfield pvalue2 = c.pfield2 cvalue = c.cfield @test pfun(p) == pvalue +@test p.pfun() == pvalue @test pfun(c) == cvalue +@test c.pfun() == cvalue @test pfunAdd(p,1) == pvalue + 1 +@test p.pfunAdd(1) == pvalue + 1 @test pfunAdd(c,1) == cvalue + 1 +@test c.pfunAdd(1) == cvalue + 1 @test_throws(MethodError, pfunAdd(c,"a")) @test cfunSuper(c) == pvalue +@test c.cfunSuper() == pvalue @test cfunAddSuper(c,1) == pvalue+1 +@test c.cfunAddSuper(1) == pvalue+1 @test cfunSuper2(c) == pvalue2 +@test c.cfunSuper2() == pvalue2 @test cfunAddSuper2(c,1) == pvalue2+1 +@test c.cfunAddSuper2(1) == pvalue2+1 diff --git a/test/runtests.jl b/test/runtests.jl index ec27401..f34edbe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,3 +5,5 @@ include("fnUtil.jl") include("basic.jl") include("constructor.jl") include("inheritence.jl") +include("dotoperator.jl") +include("dotoperator_benchmark.jl")