Skip to content

Commit

Permalink
Merge pull request #1564 from hylo-lang/union-switch
Browse files Browse the repository at this point in the history
Add an IR instruction to switch over the contents of a union
  • Loading branch information
kyouko-taiga authored Aug 22, 2024
2 parents 211d4e0 + 29bf165 commit 0392eff
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 62 deletions.
41 changes: 34 additions & 7 deletions Sources/CodeGen/LLVM/Transpilation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,8 @@ extension SwiftyLLVM.Module {
insert(switch: i)
case is IR.UnionDiscriminator:
insert(unionDiscriminator: i)
case is IR.UnionSwitch:
insert(unionSwitch: i)
case is IR.Unreachable:
insert(unreachable: i)
case is IR.WrapExistentialAddr:
Expand Down Expand Up @@ -1199,14 +1201,28 @@ extension SwiftyLLVM.Module {
/// Inserts the transpilation of `i` at `insertionPoint`.
func insert(unionDiscriminator i: IR.InstructionID) {
let s = m[i] as! UnionDiscriminator
let t = UnionType(m.type(of: s.container).ast)!
register[.register(i)] = discriminator(s.container)
}

let baseType = ir.llvm(unionType: t, in: &self)
let container = llvm(s.container)
let indices = [i32.constant(0), i32.constant(1)]
let discriminator = insertGetElementPointerInBounds(
of: container, typed: baseType, indices: indices, at: insertionPoint)
register[.register(i)] = insertLoad(word(), from: discriminator, at: insertionPoint)
/// Inserts the transpilation of `i` at `insertionPoint`.
func insert(unionSwitch i: IR.InstructionID) {
let s = m[i] as! UnionSwitch

if let (_, b) = s.targets.elements.uniqueElement {
insertBr(to: block[b]!, at: insertionPoint)
} else {
let d = discriminator(s.scrutinee)
let t = UnionType(m.type(of: s.scrutinee).ast)!
let e = m.program.discriminatorToElement(in: t)
let branches = s.targets.map { (t, b) in
(word().constant(e.firstIndex(of: t)!), block[b]!)
}

// The last branch is the "default".
insertSwitch(
on: d, cases: branches.dropLast(), default: branches.last!.1,
at: insertionPoint)
}
}

/// Inserts the transpilation of `i` at `insertionPoint`.
Expand Down Expand Up @@ -1292,6 +1308,17 @@ extension SwiftyLLVM.Module {
v = insertInsertValue(llvm(table), at: 1, into: v, at: insertionPoint)
return v
}

/// Returns the value of `container`'s discriminator.
func discriminator(_ container: IR.Operand) -> SwiftyLLVM.Instruction {
let union = UnionType(m.type(of: container).ast)!
let baseType = ir.llvm(unionType: union, in: &self)
let container = llvm(container)
let indices = [i32.constant(0), i32.constant(1)]
let discriminator = insertGetElementPointerInBounds(
of: container, typed: baseType, indices: indices, at: insertionPoint)
return insertLoad(word(), from: discriminator, at: insertionPoint)
}
}

/// Inserts the prologue of the subscript `transpilation` at the end of its entry and returns
Expand Down
2 changes: 2 additions & 0 deletions Sources/IR/Analysis/Module+NormalizeObjectStates.swift
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ extension Module {
pc = interpret(subfieldView: user, in: &context)
case is UnionDiscriminator:
pc = interpret(unionDiscriminator: user, in: &context)
case is UnionSwitch:
pc = successor(of: user)
case is Unreachable:
pc = successor(of: user)
case is WrapExistentialAddr:
Expand Down
92 changes: 37 additions & 55 deletions Sources/IR/Emitter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -737,20 +737,15 @@ struct Emitter {
}

// Otherwise, use a switch to select the correct move-initialization.
let elements = program.discriminatorToElement(in: t)
var successors: [Block.ID] = []
for _ in t.elements {
successors.append(appendBlock())
}

let n = emitUnionDiscriminator(argument, at: site)
insert(module.makeSwitch(on: n, toOneOf: successors, at: site))
let targets = UnionSwitch.Targets(
t.elements.map({ (e) in (key: e, value: appendBlock()) }),
uniquingKeysWith: { (a, _) in a })
insert(module.makeUnionSwitch(on: receiver, toOneOf: targets, at: site))

let tail = appendBlock()
for i in 0 ..< elements.count {
insertionPoint = .end(of: successors[i])
emitMoveInitUnionPayload(
of: receiver, consuming: argument, containing: elements[i], at: site)
for (u, b) in targets {
insertionPoint = .end(of: b)
emitMoveInitUnionPayload(of: receiver, consuming: argument, containing: u, at: site)
insert(module.makeBranch(to: tail, at: site))
}

Expand Down Expand Up @@ -897,20 +892,16 @@ struct Emitter {
return
}

// Otherwise, use a switch to select the correct move-initialization.
let elements = program.discriminatorToElement(in: t)
var successors: [Block.ID] = []
for _ in t.elements {
successors.append(appendBlock())
}

let n = emitUnionDiscriminator(source, at: site)
insert(module.makeSwitch(on: n, toOneOf: successors, at: site))
// Otherwise, use a switch to select the correct copy method.
let targets = UnionSwitch.Targets(
t.elements.map({ (e) in (key: e, value: appendBlock()) }),
uniquingKeysWith: { (a, _) in a })
insert(module.makeUnionSwitch(on: source, toOneOf: targets, at: site))

let tail = appendBlock()
for i in 0 ..< elements.count {
insertionPoint = .end(of: successors[i])
emitCopyUnionPayload(from: source, containing: elements[i], to: target, at: site)
for (u, b) in targets {
insertionPoint = .end(of: b)
emitCopyUnionPayload(from: source, containing: u, to: target, at: site)
insert(module.makeBranch(to: tail, at: site))
}

Expand Down Expand Up @@ -2391,24 +2382,22 @@ struct Emitter {
///
/// This method method implements conditional narrowing for union types.
private mutating func emitConditionalNarrowing(
_ subject: Operand, typed subjectType: UnionType,
_ subject: Operand, typed union: UnionType,
as pattern: BindingPattern.ID, typed patternType: AnyType,
to storage: Operand,
else failure: Block.ID, in scope: AnyScopeID
) -> Block.ID {
// TODO: Implement narrowing to an arbitrary subtype.
guard subjectType.elements.contains(patternType) else { UNIMPLEMENTED() }
guard union.elements.contains(patternType) else { UNIMPLEMENTED() }
let site = ast[pattern].site

let i = program.discriminatorToElement(in: subjectType).firstIndex(of: patternType)!
let expected = IntegerConstant(i, bitWidth: 64) // FIXME: should be width of 'word'
let actual = emitUnionDiscriminator(subject, at: site)

let test = insert(
module.makeLLVM(applying: .icmp(.eq, .word), to: [.constant(expected), actual], at: site))!
let next = appendBlock(in: scope)
insert(module.makeCondBranch(if: test, then: next, else: failure, at: site))
var targets = UnionSwitch.Targets(
union.elements.map({ (e) in (key: e, value: failure) }),
uniquingKeysWith: { (a, _) in a })
targets[patternType] = next

insert(module.makeUnionSwitch(on: subject, toOneOf: targets, at: site))
insertionPoint = .end(of: next)
let x0 = insert(module.makeOpenUnion(subject, as: patternType, at: site))!
pushing(Frame()) { (this) in
Expand Down Expand Up @@ -3085,19 +3074,15 @@ struct Emitter {
}

// One successor per member in the union, ordered by their mangled representation.
let elements = program.discriminatorToElement(in: t)
var successors: [Block.ID] = []
for _ in t.elements {
successors.append(appendBlock())
}

let n = emitUnionDiscriminator(storage, at: site)
insert(module.makeSwitch(on: n, toOneOf: successors, at: site))
let targets = UnionSwitch.Targets(
t.elements.map({ (e) in (key: e, value: appendBlock()) }),
uniquingKeysWith: { (a, _) in a })
insert(module.makeUnionSwitch(on: storage, toOneOf: targets, at: site))

let tail = appendBlock()
for i in 0 ..< elements.count {
insertionPoint = .end(of: successors[i])
emitDeinitUnionPayload(of: storage, containing: elements[i], at: site)
for (u, b) in targets {
insertionPoint = .end(of: b)
emitDeinitUnionPayload(of: storage, containing: u, at: site)
insert(module.makeBranch(to: tail, at: site))
}

Expand Down Expand Up @@ -3188,13 +3173,10 @@ struct Emitter {
}

// Otherwise, compare their payloads.
let elements = program.discriminatorToElement(in: union)

let same = appendBlock()
var successors: [Block.ID] = []
for _ in elements {
successors.append(appendBlock())
}
let targets = UnionSwitch.Targets(
union.elements.map({ (e) in (key: e, value: appendBlock()) }),
uniquingKeysWith: { (a, _) in a })
let fail = appendBlock()
let tail = appendBlock()

Expand All @@ -3205,11 +3187,11 @@ struct Emitter {
insert(module.makeCondBranch(if: x0, then: same, else: fail, at: site))

insertionPoint = .end(of: same)
insert(module.makeSwitch(on: dl, toOneOf: successors, at: site))
for i in 0 ..< elements.count {
insertionPoint = .end(of: successors[i])
let y0 = insert(module.makeOpenUnion(lhs, as: elements[i], at: site))!
let y1 = insert(module.makeOpenUnion(rhs, as: elements[i], at: site))!
insert(module.makeUnionSwitch(on: lhs, toOneOf: targets, at: site))
for (u, b) in targets {
insertionPoint = .end(of: b)
let y0 = insert(module.makeOpenUnion(lhs, as: u, at: site))!
let y1 = insert(module.makeOpenUnion(rhs, as: u, at: site))!
emitStoreEquality(y0, y1, to: target, at: site)
insert(module.makeCloseUnion(y1, at: site))
insert(module.makeCloseUnion(y0, at: site))
Expand Down
9 changes: 9 additions & 0 deletions Sources/IR/InstructionTransformer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,15 @@ extension IR.Program {
target.makeUnionDiscriminator(x0, at: s.site)
}

case let s as UnionSwitch:
let x0 = t.transform(s.scrutinee, in: &self)
let x1 = s.targets.reduce(into: UnionSwitch.Targets()) { (d, kv) in
_ = d[t.transform(kv.key, in: &self)].setIfNil(t.transform(kv.value, in: &self))
}
return insert(at: p, in:n) { (target) in
target.makeUnionSwitch(on: x0, toOneOf: x1, at: s.site)
}

case let s as Unreachable:
return modules[n]!.insert(s, at: p)

Expand Down
82 changes: 82 additions & 0 deletions Sources/IR/Operands/Instruction/UnionSwitch.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import FrontEnd
import OrderedCollections

/// Branches to one of several basic blocks based on the discriminator of a union.
public struct UnionSwitch: Terminator {

/// The type of a map from payload type to its target.
public typealias Targets = OrderedDictionary<AnyType, Block.ID>

/// The union container whose discriminator is read.
public private(set) var scrutinee: Operand

/// A map from payload type to its target.
public private(set) var targets: Targets

/// The site of the code corresponding to that instruction.
public let site: SourceRange

/// Creates an instance with the given properties.
fileprivate init(scrutinee: Operand, targets: Targets, site: SourceRange) {
self.scrutinee = scrutinee
self.targets = targets
self.site = site
}



public var operands: [Operand] {
[scrutinee]
}

public var successors: [Block.ID] {
Array(targets.values)
}

public mutating func replaceOperand(at i: Int, with new: Operand) {
precondition(i == 0)
scrutinee = new
}

mutating func replaceSuccessor(_ old: Block.ID, with new: Block.ID) -> Bool {
precondition(new.function == successors[0].function)
for (t, b) in targets {
if b == old { targets[t] = b; return true }
}
return false
}

}

extension UnionSwitch: CustomStringConvertible {

public var description: String {
var s = "union_switch \(scrutinee)"
for (t, b) in targets {
s.write(", \(t) => \(b)")
}
return s
}

}

extension Module {

/// Creates a `union_switch` anchored at `site` that jumps to the block assigned to the type of
/// `scrutinee`'s payload in `targets`.
///
/// - Requires: `scrutinee` is a union container and `targets` has a key defined for each of the
/// elements in scrutinee's type.
func makeUnionSwitch(
on scrutinee: Operand, toOneOf targets: UnionSwitch.Targets, at site: SourceRange
) -> UnionSwitch {
let t = type(of: scrutinee)
guard t.isAddress, let u = UnionType(t.ast) else {
preconditionFailure("invalid type '\(t)'")
}
precondition(u.elements.allSatisfy({ (e) in targets[e] != nil }))

return .init(scrutinee: scrutinee, targets: targets, site: site)
}

}
8 changes: 8 additions & 0 deletions Tests/EndToEndTests/TestCases/UnionNarrowing.hylo
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
//- compileAndRun expecting: .success

fun f<T: Regular, U: Regular>(_ u: sink Union<T, U>) -> Bool {
if let _: T = u { true } else { false }
}

public fun main() {
var x: Union<{a: Bool}, {b: Int}> = (b: 42)
if let y: {b: _} = x {
precondition(y.0 == 42)
} else {
fatal_error()
}

precondition(f<Int, Bool>(42 as _))
precondition(f<Int, Int>(42 as _))
precondition(!f<Int, Bool>(true as _))
}

0 comments on commit 0392eff

Please sign in to comment.