Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split fn_ctxt/adjust_fulfillment_errors from fn_ctxt/checks #107746

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
374 changes: 373 additions & 1 deletion compiler/rustc_hir_typeck/src/fn_ctxt/adjust_fulfillment_errors.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,382 @@
use crate::FnCtxt;
use rustc_hir as hir;
use rustc_hir::def::Res;
use rustc_middle::ty::{self, DefIdTree, Ty};
use rustc_hir::def_id::DefId;
use rustc_infer::traits::ObligationCauseCode;
use rustc_middle::ty::{self, DefIdTree, Ty, TypeSuperVisitable, TypeVisitable, TypeVisitor};
use rustc_span::{self, Span};
use rustc_trait_selection::traits;

use std::ops::ControlFlow;

impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
pub fn adjust_fulfillment_error_for_expr_obligation(
&self,
error: &mut traits::FulfillmentError<'tcx>,
) -> bool {
let (traits::ExprItemObligation(def_id, hir_id, idx) | traits::ExprBindingObligation(def_id, _, hir_id, idx))
= *error.obligation.cause.code().peel_derives() else { return false; };
let hir = self.tcx.hir();
let hir::Node::Expr(expr) = hir.get(hir_id) else { return false; };

let Some(unsubstituted_pred) =
self.tcx.predicates_of(def_id).instantiate_identity(self.tcx).predicates.into_iter().nth(idx)
else { return false; };

let generics = self.tcx.generics_of(def_id);
let predicate_substs = match unsubstituted_pred.kind().skip_binder() {
ty::PredicateKind::Clause(ty::Clause::Trait(pred)) => pred.trait_ref.substs,
ty::PredicateKind::Clause(ty::Clause::Projection(pred)) => pred.projection_ty.substs,
_ => ty::List::empty(),
};

let find_param_matching = |matches: &dyn Fn(&ty::ParamTy) -> bool| {
predicate_substs.types().find_map(|ty| {
ty.walk().find_map(|arg| {
if let ty::GenericArgKind::Type(ty) = arg.unpack()
&& let ty::Param(param_ty) = ty.kind()
&& matches(param_ty)
{
Some(arg)
} else {
None
}
})
})
};

// Prefer generics that are local to the fn item, since these are likely
// to be the cause of the unsatisfied predicate.
let mut param_to_point_at = find_param_matching(&|param_ty| {
self.tcx.parent(generics.type_param(param_ty, self.tcx).def_id) == def_id
});
// Fall back to generic that isn't local to the fn item. This will come
// from a trait or impl, for example.
let mut fallback_param_to_point_at = find_param_matching(&|param_ty| {
self.tcx.parent(generics.type_param(param_ty, self.tcx).def_id) != def_id
&& param_ty.name != rustc_span::symbol::kw::SelfUpper
});
// Finally, the `Self` parameter is possibly the reason that the predicate
// is unsatisfied. This is less likely to be true for methods, because
// method probe means that we already kinda check that the predicates due
// to the `Self` type are true.
let mut self_param_to_point_at =
find_param_matching(&|param_ty| param_ty.name == rustc_span::symbol::kw::SelfUpper);

// Finally, for ambiguity-related errors, we actually want to look
// for a parameter that is the source of the inference type left
// over in this predicate.
if let traits::FulfillmentErrorCode::CodeAmbiguity = error.code {
fallback_param_to_point_at = None;
self_param_to_point_at = None;
param_to_point_at =
self.find_ambiguous_parameter_in(def_id, error.root_obligation.predicate);
}

if self.closure_span_overlaps_error(error, expr.span) {
return false;
}

match &expr.kind {
hir::ExprKind::Path(qpath) => {
if let hir::Node::Expr(hir::Expr {
kind: hir::ExprKind::Call(callee, args),
hir_id: call_hir_id,
span: call_span,
..
}) = hir.get_parent(expr.hir_id)
&& callee.hir_id == expr.hir_id
{
if self.closure_span_overlaps_error(error, *call_span) {
return false;
}

for param in
[param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]
.into_iter()
.flatten()
{
if self.blame_specific_arg_if_possible(
error,
def_id,
param,
*call_hir_id,
callee.span,
None,
args,
)
{
return true;
}
}
}
// Notably, we only point to params that are local to the
// item we're checking, since those are the ones we are able
// to look in the final `hir::PathSegment` for. Everything else
// would require a deeper search into the `qpath` than I think
// is worthwhile.
if let Some(param_to_point_at) = param_to_point_at
&& self.point_at_path_if_possible(error, def_id, param_to_point_at, qpath)
{
return true;
}
}
hir::ExprKind::MethodCall(segment, receiver, args, ..) => {
for param in [param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]
.into_iter()
.flatten()
{
if self.blame_specific_arg_if_possible(
error,
def_id,
param,
hir_id,
segment.ident.span,
Some(receiver),
args,
) {
return true;
}
}
if let Some(param_to_point_at) = param_to_point_at
&& self.point_at_generic_if_possible(error, def_id, param_to_point_at, segment)
{
return true;
}
}
hir::ExprKind::Struct(qpath, fields, ..) => {
if let Res::Def(
hir::def::DefKind::Struct | hir::def::DefKind::Variant,
variant_def_id,
) = self.typeck_results.borrow().qpath_res(qpath, hir_id)
{
for param in
[param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]
{
if let Some(param) = param {
let refined_expr = self.point_at_field_if_possible(
def_id,
param,
variant_def_id,
fields,
);

match refined_expr {
None => {}
Some((refined_expr, _)) => {
error.obligation.cause.span = refined_expr
.span
.find_ancestor_in_same_ctxt(error.obligation.cause.span)
.unwrap_or(refined_expr.span);
return true;
}
}
}
}
}
if let Some(param_to_point_at) = param_to_point_at
&& self.point_at_path_if_possible(error, def_id, param_to_point_at, qpath)
{
return true;
}
}
_ => {}
}

false
}

fn point_at_path_if_possible(
&self,
error: &mut traits::FulfillmentError<'tcx>,
def_id: DefId,
param: ty::GenericArg<'tcx>,
qpath: &hir::QPath<'tcx>,
) -> bool {
match qpath {
hir::QPath::Resolved(_, path) => {
if let Some(segment) = path.segments.last()
&& self.point_at_generic_if_possible(error, def_id, param, segment)
{
return true;
}
}
hir::QPath::TypeRelative(_, segment) => {
if self.point_at_generic_if_possible(error, def_id, param, segment) {
return true;
}
}
_ => {}
}

false
}

fn point_at_generic_if_possible(
&self,
error: &mut traits::FulfillmentError<'tcx>,
def_id: DefId,
param_to_point_at: ty::GenericArg<'tcx>,
segment: &hir::PathSegment<'tcx>,
) -> bool {
let own_substs = self
.tcx
.generics_of(def_id)
.own_substs(ty::InternalSubsts::identity_for_item(self.tcx, def_id));
let Some((index, _)) = own_substs
.iter()
.filter(|arg| matches!(arg.unpack(), ty::GenericArgKind::Type(_)))
.enumerate()
.find(|(_, arg)| **arg == param_to_point_at) else { return false };
let Some(arg) = segment
.args()
.args
.iter()
.filter(|arg| matches!(arg, hir::GenericArg::Type(_)))
.nth(index) else { return false; };
error.obligation.cause.span = arg
.span()
.find_ancestor_in_same_ctxt(error.obligation.cause.span)
.unwrap_or(arg.span());
true
}

fn find_ambiguous_parameter_in<T: TypeVisitable<'tcx>>(
&self,
item_def_id: DefId,
t: T,
) -> Option<ty::GenericArg<'tcx>> {
struct FindAmbiguousParameter<'a, 'tcx>(&'a FnCtxt<'a, 'tcx>, DefId);
impl<'tcx> TypeVisitor<'tcx> for FindAmbiguousParameter<'_, 'tcx> {
type BreakTy = ty::GenericArg<'tcx>;
fn visit_ty(&mut self, ty: Ty<'tcx>) -> std::ops::ControlFlow<Self::BreakTy> {
if let Some(origin) = self.0.type_var_origin(ty)
&& let rustc_infer::infer::type_variable::TypeVariableOriginKind::TypeParameterDefinition(_, Some(def_id)) =
origin.kind
&& let generics = self.0.tcx.generics_of(self.1)
&& let Some(index) = generics.param_def_id_to_index(self.0.tcx, def_id)
&& let Some(subst) = ty::InternalSubsts::identity_for_item(self.0.tcx, self.1)
.get(index as usize)
{
ControlFlow::Break(*subst)
} else {
ty.super_visit_with(self)
}
}
}
t.visit_with(&mut FindAmbiguousParameter(self, item_def_id)).break_value()
}

fn closure_span_overlaps_error(
&self,
error: &traits::FulfillmentError<'tcx>,
span: Span,
) -> bool {
if let traits::FulfillmentErrorCode::CodeSelectionError(
traits::SelectionError::OutputTypeParameterMismatch(_, expected, _),
) = error.code
&& let ty::Closure(def_id, _) | ty::Generator(def_id, ..) = expected.skip_binder().self_ty().kind()
&& span.overlaps(self.tcx.def_span(*def_id))
{
true
} else {
false
}
}

fn point_at_field_if_possible(
&self,
def_id: DefId,
param_to_point_at: ty::GenericArg<'tcx>,
variant_def_id: DefId,
expr_fields: &[hir::ExprField<'tcx>],
) -> Option<(&'tcx hir::Expr<'tcx>, Ty<'tcx>)> {
let def = self.tcx.adt_def(def_id);

let identity_substs = ty::InternalSubsts::identity_for_item(self.tcx, def_id);
let fields_referencing_param: Vec<_> = def
.variant_with_id(variant_def_id)
.fields
.iter()
.filter(|field| {
let field_ty = field.ty(self.tcx, identity_substs);
Self::find_param_in_ty(field_ty.into(), param_to_point_at)
})
.collect();

if let [field] = fields_referencing_param.as_slice() {
for expr_field in expr_fields {
// Look for the ExprField that matches the field, using the
// same rules that check_expr_struct uses for macro hygiene.
if self.tcx.adjust_ident(expr_field.ident, variant_def_id) == field.ident(self.tcx)
{
return Some((expr_field.expr, self.tcx.type_of(field.did)));
}
}
}

None
}

/// - `blame_specific_*` means that the function will recursively traverse the expression,
/// looking for the most-specific-possible span to blame.
///
/// - `point_at_*` means that the function will only go "one level", pointing at the specific
/// expression mentioned.
///
/// `blame_specific_arg_if_possible` will find the most-specific expression anywhere inside
/// the provided function call expression, and mark it as responsible for the fullfillment
/// error.
fn blame_specific_arg_if_possible(
&self,
error: &mut traits::FulfillmentError<'tcx>,
def_id: DefId,
param_to_point_at: ty::GenericArg<'tcx>,
call_hir_id: hir::HirId,
callee_span: Span,
receiver: Option<&'tcx hir::Expr<'tcx>>,
args: &'tcx [hir::Expr<'tcx>],
) -> bool {
let ty = self.tcx.type_of(def_id);
if !ty.is_fn() {
return false;
}
let sig = ty.fn_sig(self.tcx).skip_binder();
let args_referencing_param: Vec<_> = sig
.inputs()
.iter()
.enumerate()
.filter(|(_, ty)| Self::find_param_in_ty((**ty).into(), param_to_point_at))
.collect();
// If there's one field that references the given generic, great!
if let [(idx, _)] = args_referencing_param.as_slice()
&& let Some(arg) = receiver
.map_or(args.get(*idx), |rcvr| if *idx == 0 { Some(rcvr) } else { args.get(*idx - 1) }) {

error.obligation.cause.span = arg.span.find_ancestor_in_same_ctxt(error.obligation.cause.span).unwrap_or(arg.span);

if let hir::Node::Expr(arg_expr) = self.tcx.hir().get(arg.hir_id) {
// This is more specific than pointing at the entire argument.
self.blame_specific_expr_if_possible(error, arg_expr)
}

error.obligation.cause.map_code(|parent_code| {
ObligationCauseCode::FunctionArgumentObligation {
arg_hir_id: arg.hir_id,
call_hir_id,
parent_code,
}
});
return true;
} else if args_referencing_param.len() > 0 {
// If more than one argument applies, then point to the callee span at least...
// We have chance to fix this up further in `point_at_generics_if_possible`
error.obligation.cause.span = callee_span;
}

false
}

/**
* Recursively searches for the most-specific blamable expression.
* For example, if you have a chain of constraints like:
Expand Down
Loading