diff --git a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h index 4bc0c29e8cf54..22ac0e60a7444 100644 --- a/include/swift/AST/Stmt.h +++ b/include/swift/AST/Stmt.h @@ -1055,15 +1055,26 @@ class CaseStmt final return UnknownAttrLoc.isValid(); } - Optional> getCaseBodyVariables() const { - if (!CaseBodyVariables) - return None; + /// Return an ArrayRef containing the case body variables of this CaseStmt. + /// + /// Asserts if case body variables was not explicitly initialized. In contexts + /// where one wants a non-asserting version, \see + /// getCaseBodyVariablesOrEmptyArray. + ArrayRef getCaseBodyVariables() const { ArrayRef a = *CaseBodyVariables; return a; } - Optional> getCaseBodyVariables() { - return CaseBodyVariables; + bool hasCaseBodyVariables() const { return CaseBodyVariables.hasValue(); } + + /// Return an MutableArrayRef containing the case body variables of this + /// CaseStmt. + /// + /// Asserts if case body variables was not explicitly initialized. In contexts + /// where one wants a non-asserting version, \see + /// getCaseBodyVariablesOrEmptyArray. + MutableArrayRef getCaseBodyVariables() { + return *CaseBodyVariables; } ArrayRef getCaseBodyVariablesOrEmptyArray() const { diff --git a/lib/AST/ASTDumper.cpp b/lib/AST/ASTDumper.cpp index d069df102c639..e589f7d4c56a6 100644 --- a/lib/AST/ASTDumper.cpp +++ b/lib/AST/ASTDumper.cpp @@ -1622,13 +1622,13 @@ class PrintStmt : public StmtVisitor { if (S->hasUnknownAttr()) OS << " @unknown"; - if (auto caseBodyVars = S->getCaseBodyVariables()) { + if (S->hasCaseBodyVariables()) { OS << '\n'; OS.indent(Indent + 2); PrintWithColorRAII(OS, ParenthesisColor) << '('; PrintWithColorRAII(OS, StmtColor) << "case_body_variables"; OS << '\n'; - for (auto *vd : *caseBodyVars) { + for (auto *vd : S->getCaseBodyVariables()) { OS.indent(2); // TODO: Printing a var decl does an Indent ... dump(vd) ... '\n'. We // should see if we can factor this dumping so that the caller of diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 991a0cf36c057..330dce61e2f6e 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -5094,7 +5094,7 @@ NullablePtr VarDecl::getCorrespondingCaseBodyVariable() const { // A var decl associated with a case stmt implies that the case stmt has body // var decls. So we can access the optional value here without worry. - auto caseBodyVars = *caseStmt->getCaseBodyVariables(); + auto caseBodyVars = caseStmt->getCaseBodyVariables(); auto result = llvm::find_if(caseBodyVars, [&](VarDecl *caseBodyVar) { return caseBodyVar->getName() == name; }); diff --git a/lib/AST/NameLookup.cpp b/lib/AST/NameLookup.cpp index 01980d12a2f71..84cc9dbc5b29d 100644 --- a/lib/AST/NameLookup.cpp +++ b/lib/AST/NameLookup.cpp @@ -2349,10 +2349,8 @@ void FindLocalVal::visitCaseStmt(CaseStmt *S) { } if (!inPatterns && !items.empty()) { - if (auto caseBodyVars = S->getCaseBodyVariables()) { - for (auto *vd : *caseBodyVars) { - checkValueDecl(vd, DeclVisibilityKind::LocalVariable); - } + for (auto *vd : S->getCaseBodyVariablesOrEmptyArray()) { + checkValueDecl(vd, DeclVisibilityKind::LocalVariable); } } visit(S->getBody()); diff --git a/lib/SILGen/SILGenPattern.cpp b/lib/SILGen/SILGenPattern.cpp index aa7870cfbfbbf..19448f5776cb4 100644 --- a/lib/SILGen/SILGenPattern.cpp +++ b/lib/SILGen/SILGenPattern.cpp @@ -2332,12 +2332,8 @@ void PatternMatchEmission::initSharedCaseBlockDest(CaseStmt *caseBlock, auto *block = SGF.createBasicBlock(); result.first->second.first = block; - // Add args for any pattern variables - auto caseBodyVars = caseBlock->getCaseBodyVariables(); - if (!caseBodyVars) - return; - - for (auto *vd : *caseBodyVars) { + // Add args for any pattern variables if we have any. + for (auto *vd : caseBlock->getCaseBodyVariablesOrEmptyArray()) { if (!vd->hasName()) continue; @@ -2365,14 +2361,10 @@ void PatternMatchEmission::emitAddressOnlyAllocations() { for (auto &entry : SharedCases) { CaseStmt *caseBlock = entry.first; - auto caseBodyVars = caseBlock->getCaseBodyVariables(); - if (!caseBodyVars) - continue; - // If we have a shared case with bound decls, setup the arguments for the // shared block by emitting the temporary allocation used for the arguments // of the shared block. - for (auto *vd : *caseBodyVars) { + for (auto *vd : caseBlock->getCaseBodyVariablesOrEmptyArray()) { if (!vd->hasName()) continue; @@ -2436,8 +2428,7 @@ void PatternMatchEmission::emitSharedCaseBlocks() { assert(SGF.getCleanupsDepth() == PatternMatchStmtDepth); SWIFT_DEFER { assert(SGF.getCleanupsDepth() == PatternMatchStmtDepth); }; - auto caseBodyVars = caseBlock->getCaseBodyVariables(); - if (!caseBodyVars) { + if (!caseBlock->hasCaseBodyVariables()) { emitCaseBody(caseBlock); continue; } @@ -2448,7 +2439,7 @@ void PatternMatchEmission::emitSharedCaseBlocks() { // args needing Cleanup will get that as well. Scope scope(SGF.Cleanups, CleanupLocation(caseBlock)); unsigned argIndex = 0; - for (auto *vd : *caseBodyVars) { + for (auto *vd : caseBlock->getCaseBodyVariables()) { if (!vd->hasName()) continue; @@ -2606,14 +2597,14 @@ static void switchCaseStmtSuccessCallback(SILGenFunction &SGF, if (!row.hasFallthroughTo() && caseBlock->getCaseLabelItems().size() == 1) { // If we have case body vars, set them up to point at the matching var // decls. - if (auto caseBodyVars = caseBlock->getCaseBodyVariables()) { + if (caseBlock->hasCaseBodyVariables()) { // Since we know that we only have one case label item, grab its pattern // vars and use that to update expected with the right SILValue. // // TODO: Do we need a copy here? SmallVector patternVars; row.getCasePattern()->collectVariables(patternVars); - for (auto *expected : *caseBodyVars) { + for (auto *expected : caseBlock->getCaseBodyVariables()) { if (!expected->hasName()) continue; for (auto *vd : patternVars) { @@ -2622,7 +2613,8 @@ static void switchCaseStmtSuccessCallback(SILGenFunction &SGF, } // Ok, we found a match. Update the VarLocs for the case block. - SGF.VarLocs[expected] = SGF.VarLocs[vd]; + auto v = SGF.VarLocs[vd]; + SGF.VarLocs[expected] = v; } } } @@ -2639,8 +2631,7 @@ static void switchCaseStmtSuccessCallback(SILGenFunction &SGF, // If we do not have any bound decls, we do not need to setup any // variables. Just jump to the shared destination. - auto caseBodyVars = caseBlock->getCaseBodyVariables(); - if (!caseBodyVars) { + if (!caseBlock->hasCaseBodyVariables()) { // Don't emit anything yet, we emit it at the cleanup level of the switch // statement. JumpDest sharedDest = emission.getSharedCaseBlockDest(caseBlock); @@ -2657,7 +2648,7 @@ static void switchCaseStmtSuccessCallback(SILGenFunction &SGF, SILModule &M = SGF.F.getModule(); SmallVector patternVars; row.getCasePattern()->collectVariables(patternVars); - for (auto *expected : *caseBodyVars) { + for (auto *expected : caseBlock->getCaseBodyVariables()) { if (!expected->hasName()) continue; for (auto *var : patternVars) { @@ -2844,8 +2835,7 @@ void SILGenFunction::emitSwitchFallthrough(FallthroughStmt *S) { // If our destination case doesn't have any bound decls, there is no rebinding // to do. Just jump to the shared dest. - auto destCaseBodyVars = destCaseStmt->getCaseBodyVariables(); - if (!destCaseBodyVars) { + if (!destCaseStmt->hasCaseBodyVariables()) { Cleanups.emitBranchAndCleanups(sharedDest, S); return; } @@ -2855,13 +2845,13 @@ void SILGenFunction::emitSwitchFallthrough(FallthroughStmt *S) { SmallVector args; CaseStmt *fallthroughSourceStmt = S->getFallthroughSource(); - for (auto *expected : *destCaseBodyVars) { + for (auto *expected : destCaseStmt->getCaseBodyVariables()) { if (!expected->hasName()) continue; // The type checker enforces that if our destination case has variables then // our fallthrough source must as well. - for (auto *var : *fallthroughSourceStmt->getCaseBodyVariables()) { + for (auto *var : fallthroughSourceStmt->getCaseBodyVariables()) { if (!var->hasName() || var->getName() != expected->getName()) { continue; } diff --git a/lib/Sema/MiscDiagnostics.cpp b/lib/Sema/MiscDiagnostics.cpp index f2a3d4c1f78b1..6ceaf46d5f06d 100644 --- a/lib/Sema/MiscDiagnostics.cpp +++ b/lib/Sema/MiscDiagnostics.cpp @@ -2414,10 +2414,8 @@ class VarDeclUsageChecker : public ASTWalker { // Make sure that we setup our case body variables. if (auto *caseStmt = dyn_cast(S)) { - if (auto caseBoundDecls = caseStmt->getCaseBodyVariables()) { - for (auto *vd : *caseBoundDecls) { - VarDecls[vd] |= 0; - } + for (auto *vd : caseStmt->getCaseBodyVariablesOrEmptyArray()) { + VarDecls[vd] |= 0; } } diff --git a/test/SILGen/switch.swift b/test/SILGen/switch.swift index d24593775d1dd..c5c14ff761ffa 100644 --- a/test/SILGen/switch.swift +++ b/test/SILGen/switch.swift @@ -1500,3 +1500,86 @@ func nonTrivialLoadableFallthroughCallee2(_ e : MultipleNonTrivialCaseEnum) { } } +// Make sure that we do not crash while emitting this code. +// +// DISCUSSION: The original crash was due to us performing an assignment/lookup +// on the VarLocs DenseMap in the same statement. This was caught be an +// asanified compiler. This test is just to make sure we do not regress. +enum Storage { + case empty + case single(Int) + case pair(Int, Int) + case array([Int]) + + subscript(range: [Int]) -> Storage { + get { + return .empty + } + set { + switch self { + case .empty: + break + case .single(let index): + break + case .pair(let first, let second): + switch (range[0], range[1]) { + case (0, 0): + switch newValue { + case .empty: + break + case .single(let other): + break + case .pair(let otherFirst, let otherSecond): + break + case .array(let other): + break + } + break + case (0, 1): + switch newValue { + case .empty: + break + case .single(let other): + break + case .pair(let otherFirst, let otherSecond): + break + case .array(let other): + break + } + break + case (0, 2): + break + case (1, 2): + switch newValue { + case .empty: + break + case .single(let other): + break + case .pair(let otherFirst, let otherSecond): + break + case .array(let other): + self = .array([first] + other) + } + break + case (2, 2): + switch newValue { + case .empty: + break + case .single(let other): + break + case .pair(let otherFirst, let otherSecond): + break + case .array(let other): + self = .array([first, second] + other) + } + break + default: + let r = range + } + case .array(let indexes): + break + } + } + } +} +