Skip to content

Commit

Permalink
Add generic restrictions
Browse files Browse the repository at this point in the history
  • Loading branch information
Malte Voos committed Jun 23, 2019
1 parent 2595905 commit 4735140
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 16 deletions.
37 changes: 37 additions & 0 deletions src/compiler/crystal/semantic/top_level_visitor.cr
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ class Crystal::TopLevelVisitor < Crystal::SemanticVisitor
node.raise "type vars must be #{type_type_vars.join ", "}, not #{type_vars.join ", "}"
end
end
if restriction = node.generic_restriction
type_restriction = type.restriction
if type_restriction
if restriction != type_restriction
restriction.raise "generic restriction must be '#{type_restriction}', not '#{restriction}'"
end
else
restriction.raise "#{type.type_desc} #{name} doesn't have a generic restriction"
end
end
else
node.raise "#{name} is not a generic #{type.type_desc}"
end
Expand All @@ -81,6 +91,7 @@ class Crystal::TopLevelVisitor < Crystal::SemanticVisitor
if type_vars = node.type_vars
type = GenericClassType.new @program, scope, name, nil, type_vars, false
type.splat_index = node.splat_index
type.restriction = node.generic_restriction
else
type = NonGenericClassType.new @program, scope, name, nil, false
end
Expand Down Expand Up @@ -221,10 +232,36 @@ class Crystal::TopLevelVisitor < Crystal::SemanticVisitor
end

type = type.as(ModuleType)

if type_vars = node.type_vars
if type.is_a?(GenericType)
type_type_vars = type.type_vars
if type_vars != type_type_vars
if type_type_vars.size == 1
node.raise "type var must be #{type_type_vars.join ", "}, not #{type_vars.join ", "}"
else
node.raise "type vars must be #{type_type_vars.join ", "}, not #{type_vars.join ", "}"
end
end
if restriction = node.generic_restriction
type_restriction = type.restriction
if type_restriction
if restriction != type_restriction
restriction.raise "generic restriction must be '#{type_restriction}', not '#{restriction}'"
end
else
restriction.raise "#{type.type_desc} #{name} doesn't have a generic restriction"
end
end
else
node.raise "#{name} is not a generic #{type.type_desc}"
end
end
else
if type_vars = node.type_vars
type = GenericModuleType.new @program, scope, name, type_vars
type.splat_index = node.splat_index
type.restriction = node.generic_restriction
else
type = NonGenericModuleType.new @program, scope, name
end
Expand Down
18 changes: 10 additions & 8 deletions src/compiler/crystal/syntax/ast.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1284,7 +1284,7 @@ module Crystal

# Class definition:
#
# 'class' name [ '<' superclass ]
# 'class' name [ '<' superclass ] [ 'where' generic restriction ]
# body
# 'end'
#
Expand All @@ -1293,14 +1293,15 @@ module Crystal
property body : ASTNode
property superclass : ASTNode?
property type_vars : Array(String)?
property generic_restriction : ASTNode?
property name_location : Location?
property doc : String?
property splat_index : Int32?
property? abstract : Bool
property? struct : Bool
property visibility = Visibility::Public

def initialize(@name, body = nil, @superclass = nil, @type_vars = nil, @abstract = false, @struct = false, @splat_index = nil)
def initialize(@name, body = nil, @superclass = nil, @type_vars = nil, @generic_restriction = nil, @abstract = false, @struct = false, @splat_index = nil)
@body = Expressions.from body
end

Expand All @@ -1310,30 +1311,31 @@ module Crystal
end

def clone_without_location
clone = ClassDef.new(@name, @body.clone, @superclass.clone, @type_vars.clone, @abstract, @struct, @splat_index)
clone = ClassDef.new(@name, @body.clone, @superclass.clone, @type_vars.clone, @generic_restriction.clone, @abstract, @struct, @splat_index)
clone.name_location = name_location
clone
end

def_equals_and_hash @name, @body, @superclass, @type_vars, @abstract, @struct, @splat_index
def_equals_and_hash @name, @body, @superclass, @type_vars, @generic_restriction, @abstract, @struct, @splat_index
end

# Module definition:
#
# 'module' name
# 'module' name [ 'where' generic restriction ]
# body
# 'end'
#
class ModuleDef < ASTNode
property name : Path
property body : ASTNode
property type_vars : Array(String)?
property generic_restriction : ASTNode?
property splat_index : Int32?
property name_location : Location?
property doc : String?
property visibility = Visibility::Public

def initialize(@name, body = nil, @type_vars = nil, @splat_index = nil)
def initialize(@name, body = nil, @type_vars = nil, @generic_restriction = nil, @splat_index = nil)
@body = Expressions.from body
end

Expand All @@ -1342,12 +1344,12 @@ module Crystal
end

def clone_without_location
clone = ModuleDef.new(@name, @body.clone, @type_vars.clone, @splat_index)
clone = ModuleDef.new(@name, @body.clone, @type_vars.clone, @generic_restriction.clone, @splat_index)
clone.name_location = name_location
clone
end

def_equals_and_hash @name, @body, @type_vars, @splat_index
def_equals_and_hash @name, @body, @type_vars, @generic_restriction, @splat_index
end

# Annotation definition:
Expand Down
7 changes: 6 additions & 1 deletion src/compiler/crystal/syntax/lexer.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1139,8 +1139,13 @@ module Crystal
when 'h'
case next_char
when 'e'
if next_char == 'n'
case next_char
when 'n'
return check_ident_or_keyword(:when, start)
when 'r'
if next_char == 'e'
return check_ident_or_keyword(:where, start)
end
end
when 'i'
if next_char == 'l' && next_char == 'e'
Expand Down
26 changes: 20 additions & 6 deletions src/compiler/crystal/syntax/parser.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1566,9 +1566,16 @@ module Crystal
superclass = parse_ident
end
end

generic_restriction = nil
if type_vars && @token.keyword?(:where)
next_token_skip_space_or_newline
generic_restriction = parse_expression
end

skip_statement_end

body = push_visbility { parse_expressions }
body = push_visibility { parse_expressions }

end_location = token_end_location
check_ident :end
Expand All @@ -1578,7 +1585,7 @@ module Crystal

@type_nest -= 1

class_def = ClassDef.new name, body, superclass, type_vars, is_abstract, is_struct, splat_index
class_def = ClassDef.new name, body, superclass, type_vars, generic_restriction, is_abstract, is_struct, splat_index
class_def.doc = doc
class_def.name_location = name_location
class_def.end_location = end_location
Expand Down Expand Up @@ -1641,9 +1648,16 @@ module Crystal
skip_space

type_vars, splat_index = parse_type_vars

generic_restriction = nil
if type_vars && @token.keyword?(:where)
next_token_skip_space_or_newline
generic_restriction = parse_expression
end

skip_statement_end

body = push_visbility { parse_expressions }
body = push_visibility { parse_expressions }

end_location = token_end_location
check_ident :end
Expand All @@ -1653,7 +1667,7 @@ module Crystal

@type_nest -= 1

module_def = ModuleDef.new name, body, type_vars, splat_index
module_def = ModuleDef.new name, body, type_vars, generic_restriction, splat_index
module_def.doc = doc
module_def.name_location = name_location
module_def.end_location = end_location
Expand Down Expand Up @@ -5164,7 +5178,7 @@ module Crystal
name_location = @token.location
next_token_skip_statement_end

body = push_visbility { parse_lib_body_expressions }
body = push_visibility { parse_lib_body_expressions }

check_ident :end
end_location = token_end_location
Expand Down Expand Up @@ -5844,7 +5858,7 @@ module Crystal
name == "self" || @def_vars.last.includes?(name)
end

def push_visbility
def push_visibility
old_visibility = @visibility
@visibility = nil
value = yield
Expand Down
30 changes: 29 additions & 1 deletion src/compiler/crystal/types.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1456,6 +1456,9 @@ module Crystal
# The type variable names (K and V in Hash).
getter type_vars : Array(String)

# The optional generic restriction specified with `where`.
property restriction : ASTNode?

# The index of the `*` in the type variables.
property splat_index : Int32?

Expand All @@ -1477,7 +1480,10 @@ module Crystal
return instance
end

# Used for instantiating the generic type
instance_type_vars = {} of String => ASTNode
# Passed to the macro interpreter checking the generic restriction
restriction_check_vars = {} of String => TypeVar
type_var_index = 0
self.type_vars.each_with_index do |name, index|
if splat_index == index
Expand All @@ -1486,9 +1492,11 @@ module Crystal
types << type_vars[type_var_index]
type_var_index += 1
end
var = Var.new(name, program.tuple_of(types))
type_var = program.tuple_of(types)
var = Var.new(name, type_var)
var.bind_to(var)
instance_type_vars[name] = var
restriction_check_vars[name] = type_var
else
type_var = type_vars[type_var_index]
case type_var
Expand All @@ -1499,11 +1507,24 @@ module Crystal
when ASTNode
instance_type_vars[name] = type_var
end
restriction_check_vars[name] = type_var
type_var_index += 1
end
end

instance = self.new_generic_instance(program, self, instance_type_vars)

# Don't check the generic restriction unless the type is fully instantiated
if restriction && !instance.unbound?
unless macro_expression_truthy?(program, restriction.not_nil!, self, free_vars: restriction_check_vars)
if type_vars.size == 1
raise TypeException.new "type var #{type_vars.first} doesn't satisfy restriction '#{restriction}' of #{type_desc} #{name}"
else
raise TypeException.new "type vars #{type_vars.join ", "} don't satisfy restriction '#{restriction}' of #{type_desc} #{name}"
end
end
end

generic_types[type_vars] = instance

if instance.is_a?(GenericClassInstanceType) && !instance.superclass
Expand Down Expand Up @@ -3333,3 +3354,10 @@ private def add_instance_var_initializer(including_types, name, value, meta_vars
end
end
end

private def macro_expression_truthy?(program, node : Crystal::ASTNode, scope : Crystal::Type, free_vars = nil)
interpreter = Crystal::MacroInterpreter.new program, scope, scope, node.location, in_macro: false
interpreter.free_vars = free_vars
node.accept interpreter
interpreter.last.truthy?
end

0 comments on commit 4735140

Please sign in to comment.