Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

got dot access working \o/ #3

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
name = "OOPMacro"
uuid = "90c472d1-064c-5c63-af2e-229f1fdb5f26"
authors = ["Shih-Ming Wang <[email protected]>", "Marius Kruger <[email protected]>"]
version = "0.3.0"
version = "0.4.0"

[compat]
julia = ">= 0.7"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "BenchmarkTools", "Statistics"]
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,5 @@ cvalue = c.cfield

# Future Work
- write unit test for each function in clsUtil and OOPMacroImpl
- override getproperty() to make more natural usage of
'methods'
- maybe don't require manually setting the self arg when declaring methods; rathre specify @static if it is not a object method
- maybe don't require manually setting the self arg when declaring methods; rather specify @static if it is not a object method
- Type generic parameter ??
90 changes: 67 additions & 23 deletions src/OOPMacroImpl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,36 @@ include("clsUtil.jl")
#= ClsMethods = Dict{Symbol, Dict{Expr, Expr}}() =#
ClsMethods = Dict(:Any=>Dict{Expr, Expr}())
ClsFields = Dict(:Any=>Vector{Expr}())
validOptions = (:nodotoperator,)

OOPMacroModule=@__MODULE__
clsFnDefDict=WeakKeyDict{Any,Dict{Symbol,Function}}()
macro class(args...)
if length(args) < 2
error("At least a class name and body must be specified.")
end
options = args[1:end-2]
clsName = args[end-1]
cBody = args[end]
for o in options
if o ∉ validOptions
error("$o is not a valid option. Valid options are: $(join(validOptions, ", "))")
end
end
supportDotOperator = :nodotoperator ∉ options

macro class(ClsName, Cbody)
ClsName, ParentClsNameLst = getCAndP(ClsName)
AbsClsName = getAbstractCls(ClsName)
clsName, ParentClsNameLst = getCAndP(clsName)
AbsClsName = getAbstractCls(clsName)
AbsParentClsName = getAbstractCls(ParentClsNameLst)

ClsFields[ClsName] = fields = copyFields(ParentClsNameLst, ClsFields)
ClsMethods[ClsName] = methods = Dict{Expr,Expr}()

ClsFields[clsName] = fields = copyFields(ParentClsNameLst, ClsFields)
ClsMethods[clsName] = methods = Dict{Expr,Expr}()

cons = Any[]
hasInit = false

# record fields and methods separately
for (i, block) in enumerate(Cbody.args)
for (i, block) in enumerate(cBody.args)
if isa(block, Symbol)
union!(fields, [:($block::Any)])
elseif isa(block, LineNumberNode)
Expand All @@ -30,37 +44,36 @@ macro class(ClsName, Cbody)
continue
elseif block.head == :(=) || block.head == :function
fname = getFnName(block, withoutGeneric=true)
if fname == ClsName
if fname == clsName
append!(cons, [block])
elseif fname == :__init__
hasInit = true
setFnName!(block, ClsName)
self = findFnSelfArgNameSymbol(block, ClsName)
setFnName!(block, clsName)
self = findFnSelfArgNameSymbol(block, clsName)
deleteFnSelf!(block)
prepend!(block.args[2].args, [:($self = $ClsName(()))])
prepend!(block.args[2].args, [:($self = $clsName(()))])
append!(block.args[2].args, [:($self)])
append!(cons, [block])
else
fn = copy(block)
setFnSelfArgType!(fn, ClsName)
setFnSelfArgType!(fn, clsName)
methods[findFnCall(fn)] = fn
end
else
error("@class: Case not handled")
end
end


ClsFnCalls = Set(keys(methods))
for parent in ParentClsNameLst
for pfn in values(ClsMethods[parent])
fn = copy(pfn)
setFnSelfArgType!(fn, ClsName)
fnCall = findFnCall(fn)
setFnSelfArgType!(fn, clsName)
fnCall = findFncons_strCall(fn)
if haskey(methods, fnCall)
fName = getFnName(fn, withoutGeneric=true)
if !(fnCall in ClsFnCalls)
error("Ambiguious Function Definition: Multiple definition of function $fName while $ClsName does not overwtie this function!!")
error("Ambiguious Function Definition: Multiple definition of function $fName while $clsName does not overwrite this function!!")
end
setFnName!(fn, Symbol(string("super_", parent, fName)), withoutGeneric=true)
methods[fnCall] = fn
Expand All @@ -71,18 +84,49 @@ macro class(ClsName, Cbody)
end

cons_str = join(cons,"\n") * "\n"
@show hasInit
if hasInit
cons_str *= "$ClsName(::Tuple{}) = new()\n"
cons_str *= "$clsName(::Tuple{}) = new()\n"
end

# println("cons_str ",con_str)
clsDefStr = """
mutable struct $ClsName
mutable struct $clsName
$(join(fields,"\n"))
""" * cons_str * """
end
"""
# Escape here because we want ClsName and the methods be defined in user scope instead of OOPMacro module scope.
esc(Expr(:block, Meta.parse(clsDefStr), values(methods)...))
# fun1::Function
end"""
println("clsDefStr ",clsDefStr)
# this allows calling functions on the class..
clsFnNames = map(fn->"$(getFnName(fn, withoutGeneric=true))", collect(values(methods)))
clsFnNameList = join(map(name->":$name,", clsFnNames),"")
# dotAccStr = """
# function Base.getproperty(self::$clsName, nameSymbol::Symbol)
# if isdefined(self, nameSymbol) #|| nameSymbol ∉ ($clsFnNameList)
# getfield(self, nameSymbol)
# else
# $(OOPMacroModule).clsFnDefDict[self][nameSymbol]
# end
# end
# """
dotAccStr2 = """
function __initOOPMacroFunctions(self::$clsName)
#fnDict=$(OOPMacroModule).clsFnDefDict[self] = Dict{Symbol,Function}()
#for nameSymbol in ($clsFnNameList)
#self.fun1=(args...; kwargs...)->eval(:(\$nameSymbol(\$self, \$args...; \$kwargs...)))
self.fun1=(args...; kwargs...)->eval(:(fun1(\$self, \$args...; \$kwargs...)))
#end
self
end
"""
blockSections = [Meta.parse(clsDefStr), values(methods)...]
if supportDotOperator
push!(blockSections, #Meta.parse(dotAccStr),
Meta.parse(dotAccStr2)
)
end

# Escape here because we want clsName and the methods be defined in user scope instead of OOPMacro module scope.
esc(Expr(:block, blockSections...))
end

macro super(ParentClsName, FCall)
Expand Down
2 changes: 2 additions & 0 deletions src/clsUtil.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@

""" Get the class name and parents """
function getCAndP(cls)
if isa(cls, Symbol)
C, P = cls, [:Any]
Expand Down
10 changes: 10 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ using Test
field0
field1::Int
field2::Int
SimpleCls(field0, field1::Int, field2::Int) = begin
self = __initOOPMacroFunctions(new(field0, field1, field2))
end

#= Supports different style of function declaration =#
fun0(self::SimpleCls, x) = self.field0 + x
Expand All @@ -30,18 +33,25 @@ s = SimpleCls(0,1,2)
@test s.field2 == 2
@test fun0(s, 2.) == 2
@test fun1(s, 2., 3) == 3
@test s.fun1(2., 3) == 3
@test fun2(s, 2.) == 4
@test fun2(s, 2) == 4
@test_throws(MethodError, fun2(s,"a"))

@class SimpleCls1 begin
field0::Int
SimpleCls1(field0::Int) = begin
self = __initOOPMacroFunctions(new(field0))
end
fun0(self, x, y=1) = self.field0 + x + y
fun1(self, x, y=1; z=2) = self.field0 + x + y + z
end
s1 = SimpleCls1(0)
@test fun0(s1, 1) == 2
@test s1.fun0(1) == 2
@test fun0(s1, 1, 2) == 3
@test s1.fun0(1, 2) == 3
@test fun1(s1, 1) == 4
@test fun1(s1, 1, 2) == 5
@test fun1(s1, 1, 2, z=3) == 6
@test s1.fun1(1, 2, z=3) == 6
48 changes: 48 additions & 0 deletions test/dotoperator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using Test

@class BasicDotOpr begin
field0::Int
fun0(self::SimpleCls, x) = self.field0 + x
end

@class nodotoperator BasicNoDotOpr begin
field0::Int
fun0(self::SimpleCls, x) = self.field0 + x
end

@testset "dot operator tests" begin

@testset "invalid option validation" begin
@test_throws(LoadError, @macroexpand @class invalidoption MyCls begin end)
end

@testset "tests with dot operator" begin
@testset "basic test" begin
@show bdo = BasicDotOpr(1)
@time bdo.field0
@time bdo.field0
@time bdo.field0
@test_throws(ErrorException, bdo.invalidfield)
@test fun0(bdo, 1) == 2
@time fun0(bdo, 1) == 2
@time fun0(bdo, 1) == 2
@time fun0(bdo, 1) == 2
@test bdo.fun0(1) == 2
@time bdo.fun0(1) == 2
@time bdo.fun0(1) == 2
@time bdo.fun0(1) == 2
end
end

@testset "tests without dot operator" begin
@testset "basic test" begin
@show bndo = BasicNoDotOpr(1)
@time bndo.field0
@time bndo.field0
@time bndo.field0
@test_throws(ErrorException, bndo.invalidfield)
@test fun0(bndo, 1) == 2
@test_throws(ErrorException, bndo.fun0(1))
end
end
end
58 changes: 58 additions & 0 deletions test/dotoperator_benchmark.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
using Test
using BenchmarkTools
using Statistics

if !isdefined(Main, :B1DotOpr)
@class B1DotOpr begin
field1::Int
fun1#::Function
fun1(self::SimpleCls, x) = self.field1 + x
B1DotOpr(field1::Int) = begin
self = __initOOPMacroFunctions(new(field1,nothing))
#self.field1=field1
end
end
end
if !isdefined(Main, :B1NoDotOpr)
@class nodotoperator B1NoDotOpr begin
field1::Int
fun1(self::SimpleCls, x) = self.field1 + x
# B1NoDotOpr(field1::Int) = begin
# self = new()
# self.field1=field1
# end
end
end

bg = BenchmarkGroup()
bg["DotOpr"] = BenchmarkGroup()
bg["NoDotOpr"] = BenchmarkGroup()

bg["DotOpr"]["get_field1"] = @benchmarkable o.field1 setup=(o = B1DotOpr(rand(Int)))
bg["NoDotOpr"]["get_field1"] = @benchmarkable o.field1 setup=(o = B1NoDotOpr(rand(Int)))

bg["DotOpr"]["call_normal_fun1"] = @benchmarkable fun1(o, 1) setup=(o = B1DotOpr(rand(Int)))
bg["NoDotOpr"]["call_normal_fun1"] = @benchmarkable fun1(o, 1) setup=(o = B1NoDotOpr(rand(Int)))

bg["DotOpr"]["call_fun1"] = @benchmarkable o.fun1(1) setup=(o = B1DotOpr(rand(Int)))
bg["NoDotOpr"]["call_fun1"] = @benchmarkable fun1(o, 1) setup=(o = B1NoDotOpr(rand(Int)))

# bg["DotOpr"]["call_warm_fun1"] = @benchmarkable o.fun1(1) setup=(o = B1DotOpr(rand(Int)); o.fun1(1))
# bg["NoDotOpr"]["call_warm_fun1"] = @benchmarkable fun1(o, 1) setup=(o = B1NoDotOpr(rand(Int)); fun1(o, 1))

tune!(bg)
results = run(bg, verbose = true, seconds = 2)

@testset "tests if there are regressions" begin
for t in ("get_field1", "call_normal_fun1",
"call_fun1" #, "call_warm_fun1"
)
@testset "regressions in $t" begin
med = median(results)
#println(t, med)
j = judge(med["DotOpr"][t], med["NoDotOpr"][t])
println(t, ": ", j)
@test !isregression(j)
end
end
end
8 changes: 8 additions & 0 deletions test/inheritence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,21 @@ pvalue = p.pfield
pvalue2 = c.pfield2
cvalue = c.cfield
@test pfun(p) == pvalue
@test p.pfun() == pvalue
@test pfun(c) == cvalue
@test c.pfun() == cvalue
@test pfunAdd(p,1) == pvalue + 1
@test p.pfunAdd(1) == pvalue + 1
@test pfunAdd(c,1) == cvalue + 1
@test c.pfunAdd(1) == cvalue + 1
@test_throws(MethodError, pfunAdd(c,"a"))

@test cfunSuper(c) == pvalue
@test c.cfunSuper() == pvalue
@test cfunAddSuper(c,1) == pvalue+1
@test c.cfunAddSuper(1) == pvalue+1

@test cfunSuper2(c) == pvalue2
@test c.cfunSuper2() == pvalue2
@test cfunAddSuper2(c,1) == pvalue2+1
@test c.cfunAddSuper2(1) == pvalue2+1
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ include("fnUtil.jl")
include("basic.jl")
include("constructor.jl")
include("inheritence.jl")
include("dotoperator.jl")
include("dotoperator_benchmark.jl")