From 479e3df279570dbfc015a52e022f6292d90ef843 Mon Sep 17 00:00:00 2001 From: Richard Plevin Date: Thu, 3 Jan 2019 18:14:11 -0800 Subject: [PATCH] Various improvements for subclassing parameterized types --- src/Classes.jl | 149 ++++++++++++++++++++++++++++++------------- test/test_classes.jl | 66 +++++-------------- 2 files changed, 121 insertions(+), 94 deletions(-) diff --git a/src/Classes.jl b/src/Classes.jl index 9498e5c..a412cf7 100644 --- a/src/Classes.jl +++ b/src/Classes.jl @@ -2,19 +2,84 @@ module Classes using DataStructures using MacroTools +using MacroTools:combinedef, combinestructdef using InteractiveUtils: subtypes export @class, @method, Class, AbstractClass, isclass, classof, superclass, superclasses, issubclass, subclasses, absclass +# +# Functional interface to MacroTools' dict-based expression creation functions +# TBD: suggest adding these (without requiring dict) to MacroTools +# +function emit_struct(name::Symbol, supertype::Symbol, mutable::Bool, params::Vector, fields::Vector, ctors::Vector) + fieldtups = [(tup[1], tup[2]) for tup in map(splitarg, fields)] + d = Dict(:name=>name, :supertype=>supertype, :mutable=>mutable, :params=>params, :fields=>fieldtups, :constructors=>ctors) + combinestructdef(d) +end + +function emit_function(name::Symbol, body; args::Vector=[], kwargs::Vector=[], rtype=nothing, params::Vector=[], wparams::Vector=[]) + d = Dict(:name=>name, :args=>args, :kwargs=>kwargs, :body=>body, :params=>params, :whereparams=>wparams) + if rtype !== nothing + d[:rtype] = rtype + end + combinedef(d) +end + + abstract type AbstractClass end # supertype of all shadow class types abstract type Class <: AbstractClass end # superclass of all concrete classes abs_symbol(cls::Symbol) = Symbol("Abstract", cls) +# Since nameof() doesn't cover all the cases we need, we define our own +typename(t) = t +typename(t::TypeVar) = t.name + + +# fieldnames(DataType) +# (:name, :super, :parameters, :types, :names, :instance, :layout, :size, :ninitialized, :uid, :abstract, :mutable, :hasfreetypevars, +# :isconcretetype, :isdispatchtuple, :isbitstype, :zeroinit, :isinlinealloc, Symbol("llvm::StructType"), Symbol("llvm::DIType")) +# +# if dtype.hasfreetypevars, dtype.types is like svec(XYZ<:ABC,...) + +function _translate_ivar(d::Dict, ivar) + if ! @capture(ivar, vname_::vtype_ | vname_) + error("Expected field definition, got $ivar") + end + + if vtype === nothing + return ivar # no type, nothing to translate + end + + vtype = get(d, vtype, vtype) # translate parameterized types + return :($vname::$vtype) +end + +function _translate_where(d::Dict, wparam::TypeVar) + # supname = :Any + supname = wparam.ub # TBD: not sure this suffices + name = wparam.name + name = get(d, name, name) # translate, if a type parameter, else pass through + + return :($name <: $supname) +end + +# If a symbol is already a gensym, extract the symbol and re-gensym with it +regensym(s) = MacroTools.isgensym(s) ? gensym(Symbol(MacroTools.gensymname(s))) : gensym(s) + # Return info about a class in a named tuple function _class_info(::Type{T}) where {T <: AbstractClass} - ivars = (isabstracttype(T) ? Expr[] : [:($vname::$vtype) for (vname, vtype) in zip(fieldnames(T), T.types)]) - return (modname=T.name.module, mutable=T.mutable, parameters=T.parameters, ivars=ivars, super=superclass(T)) + typ = (typeof(T) === UnionAll ? Base.unwrap_unionall(T) : T) + + # note: must extract symbol from type to create required expression + ivars = (isabstracttype(typ) ? Expr[] : [:($vname::$(typename(vtype))) for (vname, vtype) in zip(fieldnames(typ), typ.types)]) + wheres = typ.parameters + + d = Dict(t.name=>regensym(t.name) for t in wheres) # create mapping of type params to gen'd symbols + ivars = [_translate_ivar(d, iv) for iv in ivars] # translate types to use gensyms + wheres = [_translate_where(d, w) for w in wheres] + + return (wheres=wheres, ivars=ivars, super=superclass(typ)) end """ @@ -44,7 +109,9 @@ end Return `true` if `X` is a concrete subclass of `AbstractClass`, or is `Class`, which is abstract. """ isclass(any) = false -isclass(::Type{T}) where {T <: AbstractClass} = (T === Class || isconcretetype(T)) + +# Note that !isabstracttype(T) != isconcretetype(T): parameterized types return false for both +isclass(::Type{T}) where {T <: AbstractClass} = (T === Class || !isabstracttype(T)) """ issubclass(t1::DataType, t2::DataType) @@ -113,9 +180,10 @@ end function _initializer(class, fields, wheres) args = _argnames(fields) assigns = [:(_self.$arg = $arg) for arg in args] + T = gensym(class) funcdef = :( - function $class(_self::T, $(fields...)) where {T <: $(abs_symbol(class)), $(wheres...)} + function $class(_self::$T, $(fields...)) where {$T <: $(abs_symbol(class)), $(wheres...)} $(assigns...) _self end @@ -124,53 +192,48 @@ function _initializer(class, fields, wheres) return funcdef end -function _constructors(clsname, super, local_fields, all_fields, wheres) - args = _argnames(all_fields) - params = [clause.args[1] for clause in wheres] # extract parameter names from where clauses - - dflt = length(params) > 0 ? :( - function $clsname{$(params...)}($(all_fields...)) where {$(wheres...)} - new{$(params...)}($(args...)) - end) : :( - function $clsname($(all_fields...)) - new($(args...)) - end) - - init_all = _initializer(clsname, all_fields, wheres) - methods = [dflt, init_all] +function _constructors(clsname, super, super_info, local_fields, all_fields, wheres) + all_wheres = [super_info.wheres; wheres] + init_all = _initializer(clsname, all_fields, all_wheres) + inits = [init_all] # If clsname is a direct subclasses of Classes.Class, it has no fields # other than those defined locally, so the two methods would be identical. # In this case, we emit only one of them. if all_fields != local_fields init_local = _initializer(clsname, local_fields, wheres) - push!(methods, init_local) + push!(inits, init_local) end + params = [clause.args[1] for clause in all_wheres] # extract parameter names from where clauses + has_params = length(params) != 0 + + args = _argnames(all_fields) + body = has_params ? :(new{$(params...)}($(args...))) : :(new($(args...))) + dflt = emit_function(clsname, body, args=all_fields, params=params, wparams=all_wheres, rtype=clsname) + + methods = [dflt] + # Primarily for immutable classes, we emit a constructor that takes an instance - # of the direct superclass and copies values when creating a new object. - super_info = _class_info(super) + # of the direct superclass and copies values when creating a new object. super_fields = super_info.ivars if length(super_fields) != 0 super_args = [:(_super.$arg) for arg in _argnames(super_fields)] local_args = _argnames(local_fields) all_args = [super_args; local_args] - immut_init = length(params) > 0 ? :( - function $clsname{$(params...)}(_super::$super, $(local_fields...)) where {$(wheres...)} - new{$(params...)}($(all_args...)) - end) : :( - function $clsname(_super::$super, $(local_fields...)) - new($(all_args...)) - end) + body = has_params ? :(new{$(params...)}($(all_args...))) : :(new($(all_args...))) + args = [:(_super::$super); local_fields] + immut_init = emit_function(clsname, body; args=args, params=params, wparams=all_wheres, rtype=clsname) push!(methods, immut_init) end - return methods + return methods, inits end function _defclass(clsname, supercls, mutable, wheres, exprs) wheres = (wheres === nothing ? [] : wheres) + # @info "clsname:$clsname supercls:$supercls mutable:$mutable wheres:$wheres exprs:$exprs" # partition expressions into constructors and field defs ctors = Vector{Expr}() @@ -184,22 +247,18 @@ function _defclass(clsname, supercls, mutable, wheres, exprs) end end - superinfo = _class_info(supercls) - all_fields = copy(superinfo.ivars) - append!(all_fields, fields) + super_info = _class_info(supercls) + all_fields = [super_info.ivars; fields] + all_wheres = [super_info.wheres; wheres] # add default constructors - append!(ctors, _constructors(clsname, supercls, fields, all_fields, wheres)) + inner, outer = _constructors(clsname, supercls, super_info, fields, all_fields, wheres) + append!(ctors, inner) abs_class = abs_symbol(clsname) - abs_super = absclass(supercls) + abs_super = nameof(absclass(supercls)) - struct_def = :( - struct $clsname{$(wheres...)} <: $abs_class - $(all_fields...) - $(ctors...) - end - ) + struct_def = emit_struct(clsname, abs_class, mutable, all_wheres, all_fields, ctors) # set mutability flag struct_def.args[1] = mutable @@ -207,7 +266,7 @@ function _defclass(clsname, supercls, mutable, wheres, exprs) result = quote abstract type $abs_class <: $abs_super end $struct_def - + $(outer...) Classes.superclass(::Type{$clsname}) = $supercls $clsname # return the struct type end @@ -229,6 +288,8 @@ macro class(elements...) error("Unrecognized form for @class definition: $elements") end + # @info "name_expr: $name_expr, definition: $definition" + # initialize the "captured" vars to avoid "unknown var" warnings cls = clsname = exprs = wheres = nothing @@ -264,14 +325,14 @@ macro method(funcdef) error("First argument of method $name must be explicitly typed") end - type_symbol = gensym("T") # gensym avoids conflict with user's type params + type_symbol = gensym() # avoids conflict with user's type params abs_super = abs_symbol(T) # Redefine the function to accept any first arg that's a subclass of abstype parts[:whereparams] = (:($type_symbol <: $abs_super), whereparams...) args[1] = :($arg1::$type_symbol) - expr = MacroTools.combinedef(parts) - return esc(expr) + expr = combinedef(parts) + return esc(expr) end end # module diff --git a/test/test_classes.jl b/test/test_classes.jl index 81ad3eb..0e509e4 100644 --- a/test/test_classes.jl +++ b/test/test_classes.jl @@ -1,14 +1,12 @@ using Test using Classes -using Suppressor -using MacroTools: striplines @test superclass(Class) === nothing @test isclass(AbstractClass) == false # not concrete @test isclass(Int) == false # not <: AbstractClass -@class Foo <: Class begin +@class Foo begin foo::Int Foo() = Foo(0) @@ -26,11 +24,6 @@ end @test superclass(Foo) == Class @test_throws Exception superclass(AbstractFoo) -function clean_str(s) - s = replace(s, r"\n" => " ") - s = replace(s, r"\s\s+" => " ") -end - @class mutable Bar <: Foo begin bar::Int @@ -45,7 +38,7 @@ end @class mutable Baz <: Bar begin baz::Int - function Baz(self::Union{Nothing, absclass(Baz)}=nothing) + function Baz(self::Union{Nothing, AbstractBaz}=nothing) self = (self === nothing ? new() : self) superclass(Baz)(self) Baz(self, 0) @@ -144,49 +137,22 @@ xyz = SubTupleHolder(sub, 10, 20, 30, (foo=111, bar=222)) # "First argument of method whatever must be explicitly typed" @test_throws(LoadError, eval(Meta.parse("@method whatever(i) = i"))) -expr = quote - abstract type AbstractX <: AbstractClass end - struct X{} <: AbstractX - function X() - #= /Users/rjp/.julia/dev/Classes/src/Classes.jl:136 =# - new() - end - function X(_self::T) where T <: AbstractX - _self - end - end - Classes.superclass(::Type{X}) = begin - Class - end - X +@class Parameterized{T1 <: Foo, T2 <: Foo} begin + one::T1 + two::T2 end -expected1 = striplines(expr) -emitted1 = striplines(Classes._defclass(:X, Class, false, nothing, [])) - -@test string(expected1) == string(emitted1) - -expr = quote - abstract type AbstractX <: AbstractClass end - mutable struct X{NT <: NamedTuple} <: AbstractX - i::Int - j::Int - function X{NT}(i::Int, j::Int) where NT <: NamedTuple - new{NT}(i, j) - end - function X(_self::T, i::Int, j::Int) where {T <: AbstractX, NT <: NamedTuple} - _self.i = i - _self.j = j - _self - end - end - Classes.superclass(::Type{X}) = begin - Class - end - X +@class ParameterizedSub{T3 <: TupleHolder, T4 <: TupleHolder} <: Parameterized begin + x::Float64 + y::Float64 end -expected2 = striplines(expr) -emitted2 = striplines(Classes._defclass(:X, Class, true, [:(NT <: NamedTuple)], [:(i::Int), :(j::Int)])) +# TBD: add tests on these -@test string(expected2) == string(emitted2) +# Generated: needs work on parameterized types (T1, T2 are not defined) +# - convert param names to gensyms to avoid collisions +# +# function (ParameterizedSub{T3, T4}(one::T1, two::T2, x::Float64, y::Float64; )::Any) where {T3 <: TupleHolder, T4 <: TupleHolder} +# #= /Users/rjp/.julia/packages/MacroTools/4AjBS/src/utils.jl:302 =# +# new{T3, T4}(one, two, x, y) +# end \ No newline at end of file