Skip to content

Commit

Permalink
Various improvements for subclassing parameterized types
Browse files Browse the repository at this point in the history
  • Loading branch information
rjplevin committed Jan 4, 2019
1 parent b10aae1 commit 479e3df
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 94 deletions.
149 changes: 105 additions & 44 deletions src/Classes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,84 @@ module Classes

using DataStructures
using MacroTools
using MacroTools:combinedef, combinestructdef
using InteractiveUtils: subtypes

export @class, @method, Class, AbstractClass, isclass, classof, superclass, superclasses, issubclass, subclasses, absclass

#
# Functional interface to MacroTools' dict-based expression creation functions
# TBD: suggest adding these (without requiring dict) to MacroTools
#
function emit_struct(name::Symbol, supertype::Symbol, mutable::Bool, params::Vector, fields::Vector, ctors::Vector)
fieldtups = [(tup[1], tup[2]) for tup in map(splitarg, fields)]
d = Dict(:name=>name, :supertype=>supertype, :mutable=>mutable, :params=>params, :fields=>fieldtups, :constructors=>ctors)
combinestructdef(d)
end

function emit_function(name::Symbol, body; args::Vector=[], kwargs::Vector=[], rtype=nothing, params::Vector=[], wparams::Vector=[])
d = Dict(:name=>name, :args=>args, :kwargs=>kwargs, :body=>body, :params=>params, :whereparams=>wparams)
if rtype !== nothing
d[:rtype] = rtype
end
combinedef(d)
end


abstract type AbstractClass end # supertype of all shadow class types
abstract type Class <: AbstractClass end # superclass of all concrete classes

abs_symbol(cls::Symbol) = Symbol("Abstract", cls)

# Since nameof() doesn't cover all the cases we need, we define our own
typename(t) = t
typename(t::TypeVar) = t.name


# fieldnames(DataType)
# (:name, :super, :parameters, :types, :names, :instance, :layout, :size, :ninitialized, :uid, :abstract, :mutable, :hasfreetypevars,
# :isconcretetype, :isdispatchtuple, :isbitstype, :zeroinit, :isinlinealloc, Symbol("llvm::StructType"), Symbol("llvm::DIType"))
#
# if dtype.hasfreetypevars, dtype.types is like svec(XYZ<:ABC,...)

function _translate_ivar(d::Dict, ivar)
if ! @capture(ivar, vname_::vtype_ | vname_)
error("Expected field definition, got $ivar")
end

if vtype === nothing
return ivar # no type, nothing to translate
end

vtype = get(d, vtype, vtype) # translate parameterized types
return :($vname::$vtype)
end

function _translate_where(d::Dict, wparam::TypeVar)
# supname = :Any
supname = wparam.ub # TBD: not sure this suffices
name = wparam.name
name = get(d, name, name) # translate, if a type parameter, else pass through

return :($name <: $supname)
end

# If a symbol is already a gensym, extract the symbol and re-gensym with it
regensym(s) = MacroTools.isgensym(s) ? gensym(Symbol(MacroTools.gensymname(s))) : gensym(s)

# Return info about a class in a named tuple
function _class_info(::Type{T}) where {T <: AbstractClass}
ivars = (isabstracttype(T) ? Expr[] : [:($vname::$vtype) for (vname, vtype) in zip(fieldnames(T), T.types)])
return (modname=T.name.module, mutable=T.mutable, parameters=T.parameters, ivars=ivars, super=superclass(T))
typ = (typeof(T) === UnionAll ? Base.unwrap_unionall(T) : T)

# note: must extract symbol from type to create required expression
ivars = (isabstracttype(typ) ? Expr[] : [:($vname::$(typename(vtype))) for (vname, vtype) in zip(fieldnames(typ), typ.types)])
wheres = typ.parameters

d = Dict(t.name=>regensym(t.name) for t in wheres) # create mapping of type params to gen'd symbols
ivars = [_translate_ivar(d, iv) for iv in ivars] # translate types to use gensyms
wheres = [_translate_where(d, w) for w in wheres]

return (wheres=wheres, ivars=ivars, super=superclass(typ))
end

"""
Expand Down Expand Up @@ -44,7 +109,9 @@ end
Return `true` if `X` is a concrete subclass of `AbstractClass`, or is `Class`, which is abstract.
"""
isclass(any) = false
isclass(::Type{T}) where {T <: AbstractClass} = (T === Class || isconcretetype(T))

# Note that !isabstracttype(T) != isconcretetype(T): parameterized types return false for both
isclass(::Type{T}) where {T <: AbstractClass} = (T === Class || !isabstracttype(T))

"""
issubclass(t1::DataType, t2::DataType)
Expand Down Expand Up @@ -113,9 +180,10 @@ end
function _initializer(class, fields, wheres)
args = _argnames(fields)
assigns = [:(_self.$arg = $arg) for arg in args]
T = gensym(class)

funcdef = :(
function $class(_self::T, $(fields...)) where {T <: $(abs_symbol(class)), $(wheres...)}
function $class(_self::$T, $(fields...)) where {$T <: $(abs_symbol(class)), $(wheres...)}
$(assigns...)
_self
end
Expand All @@ -124,53 +192,48 @@ function _initializer(class, fields, wheres)
return funcdef
end

function _constructors(clsname, super, local_fields, all_fields, wheres)
args = _argnames(all_fields)
params = [clause.args[1] for clause in wheres] # extract parameter names from where clauses

dflt = length(params) > 0 ? :(
function $clsname{$(params...)}($(all_fields...)) where {$(wheres...)}
new{$(params...)}($(args...))
end) : :(
function $clsname($(all_fields...))
new($(args...))
end)

init_all = _initializer(clsname, all_fields, wheres)
methods = [dflt, init_all]
function _constructors(clsname, super, super_info, local_fields, all_fields, wheres)
all_wheres = [super_info.wheres; wheres]
init_all = _initializer(clsname, all_fields, all_wheres)
inits = [init_all]

# If clsname is a direct subclasses of Classes.Class, it has no fields
# other than those defined locally, so the two methods would be identical.
# In this case, we emit only one of them.
if all_fields != local_fields
init_local = _initializer(clsname, local_fields, wheres)
push!(methods, init_local)
push!(inits, init_local)
end

params = [clause.args[1] for clause in all_wheres] # extract parameter names from where clauses
has_params = length(params) != 0

args = _argnames(all_fields)
body = has_params ? :(new{$(params...)}($(args...))) : :(new($(args...)))
dflt = emit_function(clsname, body, args=all_fields, params=params, wparams=all_wheres, rtype=clsname)

methods = [dflt]

# Primarily for immutable classes, we emit a constructor that takes an instance
# of the direct superclass and copies values when creating a new object.
super_info = _class_info(super)
# of the direct superclass and copies values when creating a new object.
super_fields = super_info.ivars
if length(super_fields) != 0
super_args = [:(_super.$arg) for arg in _argnames(super_fields)]
local_args = _argnames(local_fields)
all_args = [super_args; local_args]

immut_init = length(params) > 0 ? :(
function $clsname{$(params...)}(_super::$super, $(local_fields...)) where {$(wheres...)}
new{$(params...)}($(all_args...))
end) : :(
function $clsname(_super::$super, $(local_fields...))
new($(all_args...))
end)
body = has_params ? :(new{$(params...)}($(all_args...))) : :(new($(all_args...)))
args = [:(_super::$super); local_fields]
immut_init = emit_function(clsname, body; args=args, params=params, wparams=all_wheres, rtype=clsname)
push!(methods, immut_init)
end

return methods
return methods, inits
end

function _defclass(clsname, supercls, mutable, wheres, exprs)
wheres = (wheres === nothing ? [] : wheres)
# @info "clsname:$clsname supercls:$supercls mutable:$mutable wheres:$wheres exprs:$exprs"

# partition expressions into constructors and field defs
ctors = Vector{Expr}()
Expand All @@ -184,30 +247,26 @@ function _defclass(clsname, supercls, mutable, wheres, exprs)
end
end

superinfo = _class_info(supercls)
all_fields = copy(superinfo.ivars)
append!(all_fields, fields)
super_info = _class_info(supercls)
all_fields = [super_info.ivars; fields]
all_wheres = [super_info.wheres; wheres]

# add default constructors
append!(ctors, _constructors(clsname, supercls, fields, all_fields, wheres))
inner, outer = _constructors(clsname, supercls, super_info, fields, all_fields, wheres)
append!(ctors, inner)

abs_class = abs_symbol(clsname)
abs_super = absclass(supercls)
abs_super = nameof(absclass(supercls))

struct_def = :(
struct $clsname{$(wheres...)} <: $abs_class
$(all_fields...)
$(ctors...)
end
)
struct_def = emit_struct(clsname, abs_class, mutable, all_wheres, all_fields, ctors)

# set mutability flag
struct_def.args[1] = mutable

result = quote
abstract type $abs_class <: $abs_super end
$struct_def

$(outer...)
Classes.superclass(::Type{$clsname}) = $supercls
$clsname # return the struct type
end
Expand All @@ -229,6 +288,8 @@ macro class(elements...)
error("Unrecognized form for @class definition: $elements")
end

# @info "name_expr: $name_expr, definition: $definition"

# initialize the "captured" vars to avoid "unknown var" warnings
cls = clsname = exprs = wheres = nothing

Expand Down Expand Up @@ -264,14 +325,14 @@ macro method(funcdef)
error("First argument of method $name must be explicitly typed")
end

type_symbol = gensym("T") # gensym avoids conflict with user's type params
type_symbol = gensym() # avoids conflict with user's type params
abs_super = abs_symbol(T)

# Redefine the function to accept any first arg that's a subclass of abstype
parts[:whereparams] = (:($type_symbol <: $abs_super), whereparams...)
args[1] = :($arg1::$type_symbol)
expr = MacroTools.combinedef(parts)
return esc(expr)
expr = combinedef(parts)
return esc(expr)
end

end # module
66 changes: 16 additions & 50 deletions test/test_classes.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
using Test
using Classes
using Suppressor
using MacroTools: striplines

@test superclass(Class) === nothing

@test isclass(AbstractClass) == false # not concrete
@test isclass(Int) == false # not <: AbstractClass

@class Foo <: Class begin
@class Foo begin
foo::Int

Foo() = Foo(0)
Expand All @@ -26,11 +24,6 @@ end
@test superclass(Foo) == Class
@test_throws Exception superclass(AbstractFoo)

function clean_str(s)
s = replace(s, r"\n" => " ")
s = replace(s, r"\s\s+" => " ")
end

@class mutable Bar <: Foo begin
bar::Int

Expand All @@ -45,7 +38,7 @@ end
@class mutable Baz <: Bar begin
baz::Int

function Baz(self::Union{Nothing, absclass(Baz)}=nothing)
function Baz(self::Union{Nothing, AbstractBaz}=nothing)
self = (self === nothing ? new() : self)
superclass(Baz)(self)
Baz(self, 0)
Expand Down Expand Up @@ -144,49 +137,22 @@ xyz = SubTupleHolder(sub, 10, 20, 30, (foo=111, bar=222))
# "First argument of method whatever must be explicitly typed"
@test_throws(LoadError, eval(Meta.parse("@method whatever(i) = i")))

expr = quote
abstract type AbstractX <: AbstractClass end
struct X{} <: AbstractX
function X()
#= /Users/rjp/.julia/dev/Classes/src/Classes.jl:136 =#
new()
end
function X(_self::T) where T <: AbstractX
_self
end
end
Classes.superclass(::Type{X}) = begin
Class
end
X
@class Parameterized{T1 <: Foo, T2 <: Foo} begin
one::T1
two::T2
end

expected1 = striplines(expr)
emitted1 = striplines(Classes._defclass(:X, Class, false, nothing, []))

@test string(expected1) == string(emitted1)

expr = quote
abstract type AbstractX <: AbstractClass end
mutable struct X{NT <: NamedTuple} <: AbstractX
i::Int
j::Int
function X{NT}(i::Int, j::Int) where NT <: NamedTuple
new{NT}(i, j)
end
function X(_self::T, i::Int, j::Int) where {T <: AbstractX, NT <: NamedTuple}
_self.i = i
_self.j = j
_self
end
end
Classes.superclass(::Type{X}) = begin
Class
end
X
@class ParameterizedSub{T3 <: TupleHolder, T4 <: TupleHolder} <: Parameterized begin
x::Float64
y::Float64
end

expected2 = striplines(expr)
emitted2 = striplines(Classes._defclass(:X, Class, true, [:(NT <: NamedTuple)], [:(i::Int), :(j::Int)]))
# TBD: add tests on these

@test string(expected2) == string(emitted2)
# Generated: needs work on parameterized types (T1, T2 are not defined)
# - convert param names to gensyms to avoid collisions
#
# function (ParameterizedSub{T3, T4}(one::T1, two::T2, x::Float64, y::Float64; )::Any) where {T3 <: TupleHolder, T4 <: TupleHolder}
# #= /Users/rjp/.julia/packages/MacroTools/4AjBS/src/utils.jl:302 =#
# new{T3, T4}(one, two, x, y)
# end

0 comments on commit 479e3df

Please sign in to comment.