Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check yield terminator's resume type in borrowck #119563

Merged
merged 1 commit into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 12 additions & 21 deletions compiler/rustc_borrowck/src/type_check/input_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,31 +94,22 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
);
}

debug!(
"equate_inputs_and_outputs: body.yield_ty {:?}, universal_regions.yield_ty {:?}",
body.yield_ty(),
universal_regions.yield_ty
);

// We will not have a universal_regions.yield_ty if we yield (by accident)
// outside of a coroutine and return an `impl Trait`, so emit a span_delayed_bug
// because we don't want to panic in an assert here if we've already got errors.
if body.yield_ty().is_some() != universal_regions.yield_ty.is_some() {
self.tcx().dcx().span_delayed_bug(
body.span,
format!(
"Expected body to have yield_ty ({:?}) iff we have a UR yield_ty ({:?})",
body.yield_ty(),
universal_regions.yield_ty,
),
if let Some(mir_yield_ty) = body.yield_ty() {
let yield_span = body.local_decls[RETURN_PLACE].source_info.span;
self.equate_normalized_input_or_output(
universal_regions.yield_ty.unwrap(),
mir_yield_ty,
yield_span,
);
}

if let (Some(mir_yield_ty), Some(ur_yield_ty)) =
(body.yield_ty(), universal_regions.yield_ty)
{
if let Some(mir_resume_ty) = body.resume_ty() {
let yield_span = body.local_decls[RETURN_PLACE].source_info.span;
self.equate_normalized_input_or_output(ur_yield_ty, mir_yield_ty, yield_span);
self.equate_normalized_input_or_output(
universal_regions.resume_ty.unwrap(),
mir_resume_ty,
yield_span,
);
}

// Return types are a bit more complex. They may contain opaque `impl Trait` types.
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_borrowck/src/type_check/liveness/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ impl<'cx, 'tcx> Visitor<'tcx> for LiveVariablesVisitor<'cx, 'tcx> {
match ty_context {
TyContext::ReturnTy(SourceInfo { span, .. })
| TyContext::YieldTy(SourceInfo { span, .. })
| TyContext::ResumeTy(SourceInfo { span, .. })
| TyContext::UserTy(span)
| TyContext::LocalDecl { source_info: SourceInfo { span, .. }, .. } => {
span_bug!(span, "should not be visiting outside of the CFG: {:?}", ty_context);
Expand Down
26 changes: 24 additions & 2 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1450,13 +1450,13 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
}
TerminatorKind::Yield { value, .. } => {
TerminatorKind::Yield { value, resume_arg, .. } => {
self.check_operand(value, term_location);

let value_ty = value.ty(body, tcx);
match body.yield_ty() {
None => span_mirbug!(self, term, "yield in non-coroutine"),
Some(ty) => {
let value_ty = value.ty(body, tcx);
if let Err(terr) = self.sub_types(
value_ty,
ty,
Expand All @@ -1474,6 +1474,28 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
}

match body.resume_ty() {
None => span_mirbug!(self, term, "yield in non-coroutine"),
Some(ty) => {
let resume_ty = resume_arg.ty(body, tcx);
if let Err(terr) = self.sub_types(
ty,
resume_ty.ty,
term_location.to_locations(),
ConstraintCategory::Yield,
) {
span_mirbug!(
self,
term,
"type of resume place is {:?}, but the resume type is {:?}: {:?}",
resume_ty,
ty,
terr
);
}
}
}
}
}
}
Expand Down
12 changes: 9 additions & 3 deletions compiler/rustc_borrowck/src/universal_regions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ pub struct UniversalRegions<'tcx> {
pub unnormalized_input_tys: &'tcx [Ty<'tcx>],

pub yield_ty: Option<Ty<'tcx>>,

pub resume_ty: Option<Ty<'tcx>>,
}

/// The "defining type" for this MIR. The key feature of the "defining
Expand Down Expand Up @@ -525,9 +527,12 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
debug!("build: extern regions = {}..{}", first_extern_index, first_local_index);
debug!("build: local regions = {}..{}", first_local_index, num_universals);

let yield_ty = match defining_ty {
DefiningTy::Coroutine(_, args) => Some(args.as_coroutine().yield_ty()),
_ => None,
let (resume_ty, yield_ty) = match defining_ty {
DefiningTy::Coroutine(_, args) => {
let tys = args.as_coroutine();
(Some(tys.resume_ty()), Some(tys.yield_ty()))
}
_ => (None, None),
};

UniversalRegions {
Expand All @@ -541,6 +546,7 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
unnormalized_output_ty: *unnormalized_output_ty,
unnormalized_input_tys,
yield_ty,
resume_ty,
}
}

Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ pub struct CoroutineInfo<'tcx> {
/// The yield type of the function, if it is a coroutine.
pub yield_ty: Option<Ty<'tcx>>,

/// The resume type of the function, if it is a coroutine.
pub resume_ty: Option<Ty<'tcx>>,

/// Coroutine drop glue.
pub coroutine_drop: Option<Body<'tcx>>,

Expand Down Expand Up @@ -385,6 +388,7 @@ impl<'tcx> Body<'tcx> {
coroutine: coroutine_kind.map(|coroutine_kind| {
Box::new(CoroutineInfo {
yield_ty: None,
resume_ty: None,
coroutine_drop: None,
coroutine_layout: None,
coroutine_kind,
Expand Down Expand Up @@ -551,6 +555,11 @@ impl<'tcx> Body<'tcx> {
self.coroutine.as_ref().and_then(|coroutine| coroutine.yield_ty)
}

#[inline]
pub fn resume_ty(&self) -> Option<Ty<'tcx>> {
self.coroutine.as_ref().and_then(|coroutine| coroutine.resume_ty)
}

#[inline]
pub fn coroutine_layout(&self) -> Option<&CoroutineLayout<'tcx>> {
self.coroutine.as_ref().and_then(|coroutine| coroutine.coroutine_layout.as_ref())
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_middle/src/mir/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,12 @@ macro_rules! super_body {
TyContext::YieldTy(SourceInfo::outermost(span))
);
}
if let Some(resume_ty) = $(& $mutability)? gen.resume_ty {
$self.visit_ty(
resume_ty,
TyContext::ResumeTy(SourceInfo::outermost(span))
);
}
}

for (bb, data) in basic_blocks_iter!($body, $($mutability, $invalidate)?) {
Expand Down Expand Up @@ -1244,6 +1250,8 @@ pub enum TyContext {

YieldTy(SourceInfo),

ResumeTy(SourceInfo),

/// A type found at some location.
Location(Location),
}
Expand Down
27 changes: 17 additions & 10 deletions compiler/rustc_mir_build/src/build/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,17 +488,17 @@ fn construct_fn<'tcx>(

let arguments = &thir.params;

let (yield_ty, return_ty) = if coroutine_kind.is_some() {
let (resume_ty, yield_ty, return_ty) = if coroutine_kind.is_some() {
let coroutine_ty = arguments[thir::UPVAR_ENV_PARAM].ty;
let coroutine_sig = match coroutine_ty.kind() {
ty::Coroutine(_, gen_args, ..) => gen_args.as_coroutine().sig(),
_ => {
span_bug!(span, "coroutine w/o coroutine type: {:?}", coroutine_ty)
}
};
(Some(coroutine_sig.yield_ty), coroutine_sig.return_ty)
(Some(coroutine_sig.resume_ty), Some(coroutine_sig.yield_ty), coroutine_sig.return_ty)
} else {
(None, fn_sig.output())
(None, None, fn_sig.output())
};

if let Some(custom_mir_attr) =
Expand Down Expand Up @@ -562,9 +562,12 @@ fn construct_fn<'tcx>(
} else {
None
};
if yield_ty.is_some() {

if coroutine_kind.is_some() {
body.coroutine.as_mut().unwrap().yield_ty = yield_ty;
body.coroutine.as_mut().unwrap().resume_ty = resume_ty;
}

body
}

Expand Down Expand Up @@ -631,28 +634,29 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
let hir_id = tcx.local_def_id_to_hir_id(def_id);
let coroutine_kind = tcx.coroutine_kind(def_id);

let (inputs, output, yield_ty) = match tcx.def_kind(def_id) {
let (inputs, output, resume_ty, yield_ty) = match tcx.def_kind(def_id) {
DefKind::Const
| DefKind::AssocConst
| DefKind::AnonConst
| DefKind::InlineConst
| DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None),
| DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None, None),
DefKind::Ctor(..) | DefKind::Fn | DefKind::AssocFn => {
let sig = tcx.liberate_late_bound_regions(
def_id.to_def_id(),
tcx.fn_sig(def_id).instantiate_identity(),
);
(sig.inputs().to_vec(), sig.output(), None)
(sig.inputs().to_vec(), sig.output(), None, None)
}
DefKind::Closure if coroutine_kind.is_some() => {
let coroutine_ty = tcx.type_of(def_id).instantiate_identity();
let ty::Coroutine(_, args) = coroutine_ty.kind() else {
bug!("expected type of coroutine-like closure to be a coroutine")
};
let args = args.as_coroutine();
let resume_ty = args.resume_ty();
let yield_ty = args.yield_ty();
let return_ty = args.return_ty();
(vec![coroutine_ty, args.resume_ty()], return_ty, Some(yield_ty))
(vec![coroutine_ty, args.resume_ty()], return_ty, Some(resume_ty), Some(yield_ty))
}
DefKind::Closure => {
let closure_ty = tcx.type_of(def_id).instantiate_identity();
Expand All @@ -666,7 +670,7 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
ty::ClosureKind::FnMut => Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty),
ty::ClosureKind::FnOnce => closure_ty,
};
([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None)
([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None, None)
}
dk => bug!("{:?} is not a body: {:?}", def_id, dk),
};
Expand Down Expand Up @@ -705,7 +709,10 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
Some(guar),
);

body.coroutine.as_mut().map(|gen| gen.yield_ty = yield_ty);
body.coroutine.as_mut().map(|gen| {
gen.yield_ty = yield_ty;
gen.resume_ty = resume_ty;
});

body
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1733,6 +1733,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
}

body.coroutine.as_mut().unwrap().yield_ty = None;
body.coroutine.as_mut().unwrap().resume_ty = None;
body.coroutine.as_mut().unwrap().coroutine_layout = Some(layout);

// Insert `drop(coroutine_struct)` which is used to drop upvars for coroutines in
Expand Down
35 changes: 35 additions & 0 deletions tests/ui/coroutine/check-resume-ty-lifetimes-2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#![feature(coroutine_trait)]
#![feature(coroutines)]

use std::ops::Coroutine;

struct Contravariant<'a>(fn(&'a ()));
struct Covariant<'a>(fn() -> &'a ());

fn bad1<'short, 'long: 'short>() -> impl Coroutine<Covariant<'short>> {
|_: Covariant<'short>| {
let a: Covariant<'long> = yield ();
//~^ ERROR lifetime may not live long enough
}
}

fn bad2<'short, 'long: 'short>() -> impl Coroutine<Contravariant<'long>> {
|_: Contravariant<'long>| {
let a: Contravariant<'short> = yield ();
//~^ ERROR lifetime may not live long enough
}
}

fn good1<'short, 'long: 'short>() -> impl Coroutine<Covariant<'long>> {
|_: Covariant<'long>| {
let a: Covariant<'short> = yield ();
}
}

fn good2<'short, 'long: 'short>() -> impl Coroutine<Contravariant<'short>> {
|_: Contravariant<'short>| {
let a: Contravariant<'long> = yield ();
}
}

fn main() {}
36 changes: 36 additions & 0 deletions tests/ui/coroutine/check-resume-ty-lifetimes-2.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
error: lifetime may not live long enough
--> $DIR/check-resume-ty-lifetimes-2.rs:11:16
|
LL | fn bad1<'short, 'long: 'short>() -> impl Coroutine<Covariant<'short>> {
| ------ ----- lifetime `'long` defined here
| |
| lifetime `'short` defined here
LL | |_: Covariant<'short>| {
LL | let a: Covariant<'long> = yield ();
| ^^^^^^^^^^^^^^^^ type annotation requires that `'short` must outlive `'long`
|
= help: consider adding the following bound: `'short: 'long`
help: consider adding 'move' keyword before the nested closure
|
LL | move |_: Covariant<'short>| {
| ++++

error: lifetime may not live long enough
--> $DIR/check-resume-ty-lifetimes-2.rs:18:40
|
LL | fn bad2<'short, 'long: 'short>() -> impl Coroutine<Contravariant<'long>> {
| ------ ----- lifetime `'long` defined here
| |
| lifetime `'short` defined here
LL | |_: Contravariant<'long>| {
LL | let a: Contravariant<'short> = yield ();
| ^^^^^^^^ yielding this value requires that `'short` must outlive `'long`
|
= help: consider adding the following bound: `'short: 'long`
help: consider adding 'move' keyword before the nested closure
|
LL | move |_: Contravariant<'long>| {
| ++++

error: aborting due to 2 previous errors

27 changes: 27 additions & 0 deletions tests/ui/coroutine/check-resume-ty-lifetimes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#![feature(coroutine_trait)]
#![feature(coroutines)]
#![allow(unused)]

use std::ops::Coroutine;
use std::ops::CoroutineState;
use std::pin::pin;

fn mk_static(s: &str) -> &'static str {
let mut storage: Option<&'static str> = None;

let mut coroutine = pin!(|_: &str| {
let x: &'static str = yield ();
//~^ ERROR lifetime may not live long enough
storage = Some(x);
});

coroutine.as_mut().resume(s);
coroutine.as_mut().resume(s);

storage.unwrap()
}

fn main() {
let s = mk_static(&String::from("hello, world"));
println!("{s}");
}
11 changes: 11 additions & 0 deletions tests/ui/coroutine/check-resume-ty-lifetimes.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
error: lifetime may not live long enough
--> $DIR/check-resume-ty-lifetimes.rs:13:16
|
LL | fn mk_static(s: &str) -> &'static str {
| - let's call the lifetime of this reference `'1`
...
LL | let x: &'static str = yield ();
| ^^^^^^^^^^^^ type annotation requires that `'1` must outlive `'static`

error: aborting due to 1 previous error

Loading