Skip to content

Commit

Permalink
Split fn_ctxt/adjust_fulfillment_errors from fn_ctxt/checks
Browse files Browse the repository at this point in the history
  • Loading branch information
Nathan-Fenner committed Feb 7, 2023
1 parent e1eaa2d commit 4c05366
Show file tree
Hide file tree
Showing 2 changed files with 374 additions and 370 deletions.
374 changes: 373 additions & 1 deletion compiler/rustc_hir_typeck/src/fn_ctxt/adjust_fulfillment_errors.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,382 @@
use crate::FnCtxt;
use rustc_hir as hir;
use rustc_hir::def::Res;
use rustc_middle::ty::{self, DefIdTree, Ty};
use rustc_hir::def_id::DefId;
use rustc_infer::traits::ObligationCauseCode;
use rustc_middle::ty::{self, DefIdTree, Ty, TypeSuperVisitable, TypeVisitable, TypeVisitor};
use rustc_span::{self, Span};
use rustc_trait_selection::traits;

use std::ops::ControlFlow;

impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
pub fn adjust_fulfillment_error_for_expr_obligation(
&self,
error: &mut traits::FulfillmentError<'tcx>,
) -> bool {
let (traits::ExprItemObligation(def_id, hir_id, idx) | traits::ExprBindingObligation(def_id, _, hir_id, idx))
= *error.obligation.cause.code().peel_derives() else { return false; };
let hir = self.tcx.hir();
let hir::Node::Expr(expr) = hir.get(hir_id) else { return false; };

let Some(unsubstituted_pred) =
self.tcx.predicates_of(def_id).instantiate_identity(self.tcx).predicates.into_iter().nth(idx)
else { return false; };

let generics = self.tcx.generics_of(def_id);
let predicate_substs = match unsubstituted_pred.kind().skip_binder() {
ty::PredicateKind::Clause(ty::Clause::Trait(pred)) => pred.trait_ref.substs,
ty::PredicateKind::Clause(ty::Clause::Projection(pred)) => pred.projection_ty.substs,
_ => ty::List::empty(),
};

let find_param_matching = |matches: &dyn Fn(&ty::ParamTy) -> bool| {
predicate_substs.types().find_map(|ty| {
ty.walk().find_map(|arg| {
if let ty::GenericArgKind::Type(ty) = arg.unpack()
&& let ty::Param(param_ty) = ty.kind()
&& matches(param_ty)
{
Some(arg)
} else {
None
}
})
})
};

// Prefer generics that are local to the fn item, since these are likely
// to be the cause of the unsatisfied predicate.
let mut param_to_point_at = find_param_matching(&|param_ty| {
self.tcx.parent(generics.type_param(param_ty, self.tcx).def_id) == def_id
});
// Fall back to generic that isn't local to the fn item. This will come
// from a trait or impl, for example.
let mut fallback_param_to_point_at = find_param_matching(&|param_ty| {
self.tcx.parent(generics.type_param(param_ty, self.tcx).def_id) != def_id
&& param_ty.name != rustc_span::symbol::kw::SelfUpper
});
// Finally, the `Self` parameter is possibly the reason that the predicate
// is unsatisfied. This is less likely to be true for methods, because
// method probe means that we already kinda check that the predicates due
// to the `Self` type are true.
let mut self_param_to_point_at =
find_param_matching(&|param_ty| param_ty.name == rustc_span::symbol::kw::SelfUpper);

// Finally, for ambiguity-related errors, we actually want to look
// for a parameter that is the source of the inference type left
// over in this predicate.
if let traits::FulfillmentErrorCode::CodeAmbiguity = error.code {
fallback_param_to_point_at = None;
self_param_to_point_at = None;
param_to_point_at =
self.find_ambiguous_parameter_in(def_id, error.root_obligation.predicate);
}

if self.closure_span_overlaps_error(error, expr.span) {
return false;
}

match &expr.kind {
hir::ExprKind::Path(qpath) => {
if let hir::Node::Expr(hir::Expr {
kind: hir::ExprKind::Call(callee, args),
hir_id: call_hir_id,
span: call_span,
..
}) = hir.get_parent(expr.hir_id)
&& callee.hir_id == expr.hir_id
{
if self.closure_span_overlaps_error(error, *call_span) {
return false;
}

for param in
[param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]
.into_iter()
.flatten()
{
if self.blame_specific_arg_if_possible(
error,
def_id,
param,
*call_hir_id,
callee.span,
None,
args,
)
{
return true;
}
}
}
// Notably, we only point to params that are local to the
// item we're checking, since those are the ones we are able
// to look in the final `hir::PathSegment` for. Everything else
// would require a deeper search into the `qpath` than I think
// is worthwhile.
if let Some(param_to_point_at) = param_to_point_at
&& self.point_at_path_if_possible(error, def_id, param_to_point_at, qpath)
{
return true;
}
}
hir::ExprKind::MethodCall(segment, receiver, args, ..) => {
for param in [param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]
.into_iter()
.flatten()
{
if self.blame_specific_arg_if_possible(
error,
def_id,
param,
hir_id,
segment.ident.span,
Some(receiver),
args,
) {
return true;
}
}
if let Some(param_to_point_at) = param_to_point_at
&& self.point_at_generic_if_possible(error, def_id, param_to_point_at, segment)
{
return true;
}
}
hir::ExprKind::Struct(qpath, fields, ..) => {
if let Res::Def(
hir::def::DefKind::Struct | hir::def::DefKind::Variant,
variant_def_id,
) = self.typeck_results.borrow().qpath_res(qpath, hir_id)
{
for param in
[param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]
{
if let Some(param) = param {
let refined_expr = self.point_at_field_if_possible(
def_id,
param,
variant_def_id,
fields,
);

match refined_expr {
None => {}
Some((refined_expr, _)) => {
error.obligation.cause.span = refined_expr
.span
.find_ancestor_in_same_ctxt(error.obligation.cause.span)
.unwrap_or(refined_expr.span);
return true;
}
}
}
}
}
if let Some(param_to_point_at) = param_to_point_at
&& self.point_at_path_if_possible(error, def_id, param_to_point_at, qpath)
{
return true;
}
}
_ => {}
}

false
}

fn point_at_path_if_possible(
&self,
error: &mut traits::FulfillmentError<'tcx>,
def_id: DefId,
param: ty::GenericArg<'tcx>,
qpath: &hir::QPath<'tcx>,
) -> bool {
match qpath {
hir::QPath::Resolved(_, path) => {
if let Some(segment) = path.segments.last()
&& self.point_at_generic_if_possible(error, def_id, param, segment)
{
return true;
}
}
hir::QPath::TypeRelative(_, segment) => {
if self.point_at_generic_if_possible(error, def_id, param, segment) {
return true;
}
}
_ => {}
}

false
}

fn point_at_generic_if_possible(
&self,
error: &mut traits::FulfillmentError<'tcx>,
def_id: DefId,
param_to_point_at: ty::GenericArg<'tcx>,
segment: &hir::PathSegment<'tcx>,
) -> bool {
let own_substs = self
.tcx
.generics_of(def_id)
.own_substs(ty::InternalSubsts::identity_for_item(self.tcx, def_id));
let Some((index, _)) = own_substs
.iter()
.filter(|arg| matches!(arg.unpack(), ty::GenericArgKind::Type(_)))
.enumerate()
.find(|(_, arg)| **arg == param_to_point_at) else { return false };
let Some(arg) = segment
.args()
.args
.iter()
.filter(|arg| matches!(arg, hir::GenericArg::Type(_)))
.nth(index) else { return false; };
error.obligation.cause.span = arg
.span()
.find_ancestor_in_same_ctxt(error.obligation.cause.span)
.unwrap_or(arg.span());
true
}

fn find_ambiguous_parameter_in<T: TypeVisitable<'tcx>>(
&self,
item_def_id: DefId,
t: T,
) -> Option<ty::GenericArg<'tcx>> {
struct FindAmbiguousParameter<'a, 'tcx>(&'a FnCtxt<'a, 'tcx>, DefId);
impl<'tcx> TypeVisitor<'tcx> for FindAmbiguousParameter<'_, 'tcx> {
type BreakTy = ty::GenericArg<'tcx>;
fn visit_ty(&mut self, ty: Ty<'tcx>) -> std::ops::ControlFlow<Self::BreakTy> {
if let Some(origin) = self.0.type_var_origin(ty)
&& let rustc_infer::infer::type_variable::TypeVariableOriginKind::TypeParameterDefinition(_, Some(def_id)) =
origin.kind
&& let generics = self.0.tcx.generics_of(self.1)
&& let Some(index) = generics.param_def_id_to_index(self.0.tcx, def_id)
&& let Some(subst) = ty::InternalSubsts::identity_for_item(self.0.tcx, self.1)
.get(index as usize)
{
ControlFlow::Break(*subst)
} else {
ty.super_visit_with(self)
}
}
}
t.visit_with(&mut FindAmbiguousParameter(self, item_def_id)).break_value()
}

fn closure_span_overlaps_error(
&self,
error: &traits::FulfillmentError<'tcx>,
span: Span,
) -> bool {
if let traits::FulfillmentErrorCode::CodeSelectionError(
traits::SelectionError::OutputTypeParameterMismatch(_, expected, _),
) = error.code
&& let ty::Closure(def_id, _) | ty::Generator(def_id, ..) = expected.skip_binder().self_ty().kind()
&& span.overlaps(self.tcx.def_span(*def_id))
{
true
} else {
false
}
}

fn point_at_field_if_possible(
&self,
def_id: DefId,
param_to_point_at: ty::GenericArg<'tcx>,
variant_def_id: DefId,
expr_fields: &[hir::ExprField<'tcx>],
) -> Option<(&'tcx hir::Expr<'tcx>, Ty<'tcx>)> {
let def = self.tcx.adt_def(def_id);

let identity_substs = ty::InternalSubsts::identity_for_item(self.tcx, def_id);
let fields_referencing_param: Vec<_> = def
.variant_with_id(variant_def_id)
.fields
.iter()
.filter(|field| {
let field_ty = field.ty(self.tcx, identity_substs);
Self::find_param_in_ty(field_ty.into(), param_to_point_at)
})
.collect();

if let [field] = fields_referencing_param.as_slice() {
for expr_field in expr_fields {
// Look for the ExprField that matches the field, using the
// same rules that check_expr_struct uses for macro hygiene.
if self.tcx.adjust_ident(expr_field.ident, variant_def_id) == field.ident(self.tcx)
{
return Some((expr_field.expr, self.tcx.type_of(field.did)));
}
}
}

None
}

/// - `blame_specific_*` means that the function will recursively traverse the expression,
/// looking for the most-specific-possible span to blame.
///
/// - `point_at_*` means that the function will only go "one level", pointing at the specific
/// expression mentioned.
///
/// `blame_specific_arg_if_possible` will find the most-specific expression anywhere inside
/// the provided function call expression, and mark it as responsible for the fullfillment
/// error.
fn blame_specific_arg_if_possible(
&self,
error: &mut traits::FulfillmentError<'tcx>,
def_id: DefId,
param_to_point_at: ty::GenericArg<'tcx>,
call_hir_id: hir::HirId,
callee_span: Span,
receiver: Option<&'tcx hir::Expr<'tcx>>,
args: &'tcx [hir::Expr<'tcx>],
) -> bool {
let ty = self.tcx.type_of(def_id);
if !ty.is_fn() {
return false;
}
let sig = ty.fn_sig(self.tcx).skip_binder();
let args_referencing_param: Vec<_> = sig
.inputs()
.iter()
.enumerate()
.filter(|(_, ty)| Self::find_param_in_ty((**ty).into(), param_to_point_at))
.collect();
// If there's one field that references the given generic, great!
if let [(idx, _)] = args_referencing_param.as_slice()
&& let Some(arg) = receiver
.map_or(args.get(*idx), |rcvr| if *idx == 0 { Some(rcvr) } else { args.get(*idx - 1) }) {

error.obligation.cause.span = arg.span.find_ancestor_in_same_ctxt(error.obligation.cause.span).unwrap_or(arg.span);

if let hir::Node::Expr(arg_expr) = self.tcx.hir().get(arg.hir_id) {
// This is more specific than pointing at the entire argument.
self.blame_specific_expr_if_possible(error, arg_expr)
}

error.obligation.cause.map_code(|parent_code| {
ObligationCauseCode::FunctionArgumentObligation {
arg_hir_id: arg.hir_id,
call_hir_id,
parent_code,
}
});
return true;
} else if args_referencing_param.len() > 0 {
// If more than one argument applies, then point to the callee span at least...
// We have chance to fix this up further in `point_at_generics_if_possible`
error.obligation.cause.span = callee_span;
}

false
}

/**
* Recursively searches for the most-specific blamable expression.
* For example, if you have a chain of constraints like:
Expand Down
Loading

0 comments on commit 4c05366

Please sign in to comment.