Skip to content

Commit

Permalink
Properly handle async blocks and fns in if exprs without else
Browse files Browse the repository at this point in the history
When encountering a tail expression in the then arm of an `if` expression
without an `else` arm, account for `async fn` and `async` blocks to
suggest `return`ing the value and pointing at the return type of the
`async fn`.

We now also account for AFIT when looking for the return type to point at.

Fix #115405.
  • Loading branch information
estebank committed Feb 6, 2024
1 parent bf3c6c5 commit e753c31
Show file tree
Hide file tree
Showing 11 changed files with 237 additions and 47 deletions.
38 changes: 29 additions & 9 deletions compiler/rustc_hir_typeck/src/coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,16 @@ impl<'a, 'tcx> Deref for Coerce<'a, 'tcx> {

type CoerceResult<'tcx> = InferResult<'tcx, (Vec<Adjustment<'tcx>>, Ty<'tcx>)>;

struct CollectRetsVisitor<'tcx> {
ret_exprs: Vec<&'tcx hir::Expr<'tcx>>,
pub struct CollectRetsVisitor<'tcx> {
pub ret_exprs: Vec<&'tcx hir::Expr<'tcx>>,
}

impl<'tcx> Visitor<'tcx> for CollectRetsVisitor<'tcx> {
fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
if let hir::ExprKind::Ret(_) = expr.kind {
self.ret_exprs.push(expr);
match expr.kind {
hir::ExprKind::Ret(_) => self.ret_exprs.push(expr),
hir::ExprKind::Closure(_) => return,
_ => {}
}
intravisit::walk_expr(self, expr);
}
Expand Down Expand Up @@ -1856,13 +1858,31 @@ impl<'tcx, 'exprs, E: AsCoercionSite> CoerceMany<'tcx, 'exprs, E> {
}

let parent_id = fcx.tcx.hir().get_parent_item(id);
let parent_item = fcx.tcx.hir_node_by_def_id(parent_id.def_id);
let mut parent_item = fcx.tcx.hir_node_by_def_id(parent_id.def_id);
// When suggesting return, we need to account for closures and async blocks, not just items.
for (_, node) in fcx.tcx.hir().parent_iter(id) {
match node {
hir::Node::Expr(&hir::Expr {
kind: hir::ExprKind::Closure(hir::Closure { .. }),
..
}) => {
parent_item = node;
break;
}
hir::Node::Item(_) | hir::Node::TraitItem(_) | hir::Node::ImplItem(_) => break,
_ => {}
}
}

if let (Some(expr), Some(_), Some((fn_id, fn_decl, _, _))) =
(expression, blk_id, fcx.get_node_fn_decl(parent_item))
{
if let (Some(expr), Some(_), Some(fn_decl)) = (expression, blk_id, parent_item.fn_decl()) {
fcx.suggest_missing_break_or_return_expr(
&mut err, expr, fn_decl, expected, found, id, fn_id,
&mut err,
expr,
fn_decl,
expected,
found,
id,
parent_id.into(),
);
}

Expand Down
37 changes: 29 additions & 8 deletions compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,14 +955,35 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
owner_id,
..
}) => Some((hir::HirId::make_owner(owner_id.def_id), &sig.decl, ident, false)),
Node::Expr(&hir::Expr { hir_id, kind: hir::ExprKind::Closure(..), .. })
if let Some(Node::Item(&hir::Item {
ident,
kind: hir::ItemKind::Fn(ref sig, ..),
owner_id,
..
})) = self.tcx.hir().find_parent(hir_id) =>
{
Node::Expr(&hir::Expr {
hir_id,
kind:
hir::ExprKind::Closure(hir::Closure {
kind: hir::ClosureKind::Coroutine(..), ..
}),
..
}) => {
let (ident, sig, owner_id) = match self.tcx.hir().find_parent(hir_id) {
Some(Node::Item(&hir::Item {
ident,
kind: hir::ItemKind::Fn(ref sig, ..),
owner_id,
..
})) => (ident, sig, owner_id),
Some(Node::TraitItem(&hir::TraitItem {
ident,
kind: hir::TraitItemKind::Fn(ref sig, ..),
owner_id,
..
})) => (ident, sig, owner_id),
Some(Node::ImplItem(&hir::ImplItem {
ident,
kind: hir::ImplItemKind::Fn(ref sig, ..),
owner_id,
..
})) => (ident, sig, owner_id),
_ => return None,
};
Some((
hir::HirId::make_owner(owner_id.def_id),
&sig.decl,
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1726,7 +1726,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
}

/// Given a function block's `HirId`, returns its `FnDecl` if it exists, or `None` otherwise.
fn get_parent_fn_decl(&self, blk_id: hir::HirId) -> Option<(&'tcx hir::FnDecl<'tcx>, Ident)> {
pub(crate) fn get_parent_fn_decl(
&self,
blk_id: hir::HirId,
) -> Option<(&'tcx hir::FnDecl<'tcx>, Ident)> {
let parent = self.tcx.hir_node_by_def_id(self.tcx.hir().get_parent_item(blk_id).def_id);
self.get_node_fn_decl(parent).map(|(_, fn_decl, ident, _)| (fn_decl, ident))
}
Expand Down
103 changes: 78 additions & 25 deletions compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::FnCtxt;

use crate::coercion::CollectRetsVisitor;
use crate::errors;
use crate::fluent_generated as fluent;
use crate::fn_ctxt::rustc_span::BytePos;
Expand All @@ -16,6 +17,7 @@ use rustc_errors::{Applicability, Diagnostic, MultiSpan};
use rustc_hir as hir;
use rustc_hir::def::Res;
use rustc_hir::def::{CtorKind, CtorOf, DefKind};
use rustc_hir::intravisit::{Map, Visitor};
use rustc_hir::lang_items::LangItem;
use rustc_hir::{
CoroutineDesugaring, CoroutineKind, CoroutineSource, Expr, ExprKind, GenericBound, HirId, Node,
Expand Down Expand Up @@ -826,6 +828,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
}
hir::FnRetTy::Return(hir_ty) => {
if let hir::TyKind::OpaqueDef(item_id, ..) = hir_ty.kind
// FIXME: account for RPITIT.
&& let hir::Node::Item(hir::Item {
kind: hir::ItemKind::OpaqueTy(op_ty), ..
}) = self.tcx.hir_node(item_id.hir_id())
Expand Down Expand Up @@ -1037,33 +1040,83 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
return;
}

if let hir::FnRetTy::Return(ty) = fn_decl.output {
let ty = self.astconv().ast_ty_to_ty(ty);
let bound_vars = self.tcx.late_bound_vars(fn_id);
let ty = self
.tcx
.instantiate_bound_regions_with_erased(Binder::bind_with_vars(ty, bound_vars));
let ty = match self.tcx.asyncness(fn_id.owner) {
ty::Asyncness::Yes => self.get_impl_future_output_ty(ty).unwrap_or_else(|| {
span_bug!(fn_decl.output.span(), "failed to get output type of async function")
}),
ty::Asyncness::No => ty,
};
let ty = self.normalize(expr.span, ty);
if self.can_coerce(found, ty) {
if let Some(node) = self.tcx.opt_hir_node(fn_id)
&& let Some(owner_node) = node.as_owner()
&& let Some(span) = expr.span.find_ancestor_inside(owner_node.span())
let in_closure = matches!(
self.tcx
.hir()
.parent_iter(id)
.filter(|(_, node)| {
matches!(
node,
Node::Expr(Expr { kind: ExprKind::Closure(..), .. })
| Node::Item(_)
| Node::TraitItem(_)
| Node::ImplItem(_)
)
})
.next(),
Some((_, Node::Expr(Expr { kind: ExprKind::Closure(..), .. })))
);

let can_return = match fn_decl.output {
hir::FnRetTy::Return(ty) => {
let ty = self.astconv().ast_ty_to_ty(ty);
let bound_vars = self.tcx.late_bound_vars(fn_id);
let ty = self
.tcx
.instantiate_bound_regions_with_erased(Binder::bind_with_vars(ty, bound_vars));
let ty = match self.tcx.asyncness(fn_id.owner) {
ty::Asyncness::Yes => self.get_impl_future_output_ty(ty).unwrap_or_else(|| {
span_bug!(
fn_decl.output.span(),
"failed to get output type of async function"
)
}),
ty::Asyncness::No => ty,
};
let ty = self.normalize(expr.span, ty);
self.can_coerce(found, ty)
}
hir::FnRetTy::DefaultReturn(_) if in_closure => {
let mut rets = vec![];
if let Some(ret_coercion) = self.ret_coercion.as_ref() {
let ret_ty = ret_coercion.borrow().expected_ty();
rets.push(ret_ty);
}
let mut visitor = CollectRetsVisitor { ret_exprs: vec![] };
if let Some(item) = self.tcx.hir().find(id)
&& let Node::Expr(expr) = item
{
err.multipart_suggestion(
"you might have meant to return this value",
vec![
(span.shrink_to_lo(), "return ".to_string()),
(span.shrink_to_hi(), ";".to_string()),
],
Applicability::MaybeIncorrect,
);
visitor.visit_expr(expr);
for expr in visitor.ret_exprs {
if let Some(ty) = self.typeck_results.borrow().node_type_opt(expr.hir_id) {
rets.push(ty);
}
}
if let hir::ExprKind::Block(hir::Block { expr: Some(expr), .. }, _) = expr.kind
{
if let Some(ty) = self.typeck_results.borrow().node_type_opt(expr.hir_id) {
rets.push(ty);
}
}
}
info!(?rets);
rets.into_iter().all(|ty| self.can_coerce(found, ty))
}
_ => false,
};
if can_return {
if let Some(node) = self.tcx.opt_hir_node(fn_id)
&& let Some(owner_node) = node.as_owner()
&& let Some(span) = expr.span.find_ancestor_inside(owner_node.span())
{
err.multipart_suggestion(
"you might have meant to return this value",
vec![
(span.shrink_to_lo(), "return ".to_string()),
(span.shrink_to_hi(), ";".to_string()),
],
Applicability::MaybeIncorrect,
);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/hir/map/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ impl<'hir> Map<'hir> {
Node::Item(_)
| Node::ForeignItem(_)
| Node::TraitItem(_)
| Node::Expr(Expr { kind: ExprKind::Closure { .. }, .. })
| Node::Expr(Expr { kind: ExprKind::Closure(_), .. })
| Node::ImplItem(_)
// The input node `id` must be enclosed in the method's body as opposed
// to some other place such as its return type (fixes #114918).
Expand Down
22 changes: 22 additions & 0 deletions tests/ui/async-await/missing-return-in-async-block.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// run-rustfix
// edition:2021
use std::future::Future;
use std::pin::Pin;
pub struct S;
pub fn foo() {
let _ = Box::pin(async move {
if true {
return Ok(S); //~ ERROR mismatched types
}
Err(())
});
}
pub fn bar() -> Pin<Box<dyn Future<Output = Result<S, ()>> + 'static>> {
Box::pin(async move {
if true {
return Ok(S); //~ ERROR mismatched types
}
Err(())
})
}
fn main() {}
22 changes: 22 additions & 0 deletions tests/ui/async-await/missing-return-in-async-block.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// run-rustfix
// edition:2021
use std::future::Future;
use std::pin::Pin;
pub struct S;
pub fn foo() {
let _ = Box::pin(async move {
if true {
Ok(S) //~ ERROR mismatched types
}
Err(())
});
}
pub fn bar() -> Pin<Box<dyn Future<Output = Result<S, ()>> + 'static>> {
Box::pin(async move {
if true {
Ok(S) //~ ERROR mismatched types
}
Err(())
})
}
fn main() {}
35 changes: 35 additions & 0 deletions tests/ui/async-await/missing-return-in-async-block.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
error[E0308]: mismatched types
--> $DIR/missing-return-in-async-block.rs:9:13
|
LL | / if true {
LL | | Ok(S)
| | ^^^^^ expected `()`, found `Result<S, _>`
LL | | }
| |_________- expected this to be `()`
|
= note: expected unit type `()`
found enum `Result<S, _>`
help: you might have meant to return this value
|
LL | return Ok(S);
| ++++++ +

error[E0308]: mismatched types
--> $DIR/missing-return-in-async-block.rs:17:13
|
LL | / if true {
LL | | Ok(S)
| | ^^^^^ expected `()`, found `Result<S, _>`
LL | | }
| |_________- expected this to be `()`
|
= note: expected unit type `()`
found enum `Result<S, _>`
help: you might have meant to return this value
|
LL | return Ok(S);
| ++++++ +

error: aborting due to 2 previous errors

For more information about this error, try `rustc --explain E0308`.
2 changes: 2 additions & 0 deletions tests/ui/impl-trait/in-trait/default-body-type-err-2.stderr
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
error[E0308]: mismatched types
--> $DIR/default-body-type-err-2.rs:7:9
|
LL | async fn woopsie_async(&self) -> String {
| ------ expected `String` because of return type
LL | 42
| ^^- help: try using a conversion method: `.to_string()`
| |
Expand Down
2 changes: 2 additions & 0 deletions tests/ui/loops/dont-suggest-break-thru-item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ fn closure() {
if true {
Err(1)
//~^ ERROR mismatched types
//~| HELP you might have meant to return this value
}

Ok(())
Expand All @@ -21,6 +22,7 @@ fn async_block() {
if true {
Err(1)
//~^ ERROR mismatched types
//~| HELP you might have meant to return this value
}

Ok(())
Expand Down
Loading

0 comments on commit e753c31

Please sign in to comment.