Skip to content

Commit

Permalink
Add support for deduction of most kinds of type constant. (#4389)
Browse files Browse the repository at this point in the history
This adds deduction in all the cases where we can match the instruction
fields of the parameter against the corresponding instruction fields of
the argument. This handles all current type constants except for struct
types, for which we would want to match by field name.

---------

Co-authored-by: Jon Ross-Perkins <[email protected]>
  • Loading branch information
zygoloid and jonmeow authored Oct 9, 2024
1 parent 9fadfb5 commit 6410d6e
Show file tree
Hide file tree
Showing 10 changed files with 3,152 additions and 179 deletions.
176 changes: 134 additions & 42 deletions toolchain/check/deduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,31 @@ class DeductionWorklist {
{.param = param, .arg = arg, .needs_substitution = needs_substitution});
}

// Adds a single (param, arg) type deduction.
auto Add(SemIR::TypeId param, SemIR::TypeId arg, bool needs_substitution)
-> void {
Add(context_.types().GetInstId(param), context_.types().GetInstId(arg),
needs_substitution);
}

// Adds a single (param, arg) deduction of a specific.
auto Add(SemIR::SpecificId param, SemIR::SpecificId arg,
bool needs_substitution) -> void {
auto& param_specific = context_.specifics().Get(param);
auto& arg_specific = context_.specifics().Get(arg);
if (param_specific.generic_id != arg_specific.generic_id) {
// TODO: Decide whether to error on this or just treat the specific as
// non-deduced. For now we treat it as non-deduced.
return;
}
AddAll(param_specific.args_id, arg_specific.args_id, needs_substitution);
}

// Adds a list of (param, arg) deductions. These are added in reverse order so
// they are popped in forward order.
auto AddAll(llvm::ArrayRef<SemIR::InstId> params,
llvm::ArrayRef<SemIR::InstId> args, bool needs_substitution)
-> void {
template <typename ElementId>
auto AddAll(llvm::ArrayRef<ElementId> params, llvm::ArrayRef<ElementId> args,
bool needs_substitution) -> void {
if (params.size() != args.size()) {
// TODO: Decide whether to error on this or just treat the parameter list
// as non-deduced. For now we treat it as non-deduced.
Expand All @@ -65,6 +85,44 @@ class DeductionWorklist {
needs_substitution);
}

auto AddAll(SemIR::TypeBlockId params, SemIR::TypeBlockId args,
bool needs_substitution) -> void {
AddAll(context_.type_blocks().Get(params), context_.type_blocks().Get(args),
needs_substitution);
}

// Adds a (param, arg) pair for an instruction argument, given its kind.
auto AddInstArg(SemIR::IdKind kind, int32_t param, int32_t arg,
bool needs_substitution) -> void {
switch (kind) {
case SemIR::IdKind::None:
case SemIR::IdKind::For<SemIR::ClassId>:
case SemIR::IdKind::For<SemIR::InterfaceId>:
case SemIR::IdKind::For<SemIR::IntKind>:
break;
case SemIR::IdKind::For<SemIR::InstId>:
Add(SemIR::InstId(param), SemIR::InstId(arg), needs_substitution);
break;
case SemIR::IdKind::For<SemIR::TypeId>:
Add(SemIR::TypeId(param), SemIR::TypeId(arg), needs_substitution);
break;
case SemIR::IdKind::For<SemIR::InstBlockId>:
AddAll(SemIR::InstBlockId(param), SemIR::InstBlockId(arg),
needs_substitution);
break;
case SemIR::IdKind::For<SemIR::TypeBlockId>:
AddAll(SemIR::TypeBlockId(param), SemIR::TypeBlockId(arg),
needs_substitution);
break;
case SemIR::IdKind::For<SemIR::SpecificId>:
Add(SemIR::SpecificId(param), SemIR::SpecificId(arg),
needs_substitution);
break;
default:
CARBON_FATAL("unexpected argument kind");
}
}

// Returns whether we have completed all deductions.
auto Done() -> bool { return deductions_.empty(); }

Expand Down Expand Up @@ -208,65 +266,99 @@ auto DeductionContext::Deduce() -> bool {
}
}

// If the parameter is a symbolic constant, deduce against it.
// If the parameter is a symbolic constant, deduce against it. Otherwise, we
// assume there is nothing to deduce.
// TODO: This won't do the right thing in a template deduction.
auto param_const_id = context().constant_values().Get(param_id);
if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
continue;
}

// If we've not yet substituted into the parameter, do so now.
if (needs_substitution) {
param_const_id = SubstConstant(context(), param_const_id, substitutions_);
if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
continue;
}
needs_substitution = false;
}

CARBON_KIND_SWITCH(context().insts().Get(
context().constant_values().GetInstId(
param_const_id))) {
// Attempt to match `param_inst` against `arg_id`. If the match succeeds,
// this should `continue` the outer loop. On `break`, we will try to desugar
// the parameter to continue looking for a match.
auto param_inst = context().insts().Get(
context().constant_values().GetInstId(param_const_id));
CARBON_KIND_SWITCH(param_inst) {
// Deducing a symbolic binding from an argument with a constant value
// deduces the binding as having that constant value.
case CARBON_KIND(SemIR::BindSymbolicName bind): {
auto& entity_name = context().entity_names().Get(bind.entity_name_id);
auto index = entity_name.bind_index;
if (index.is_valid() && index >= first_deduced_index_) {
CARBON_CHECK(
static_cast<size_t>(index.index) < result_arg_ids_.size(),
"Deduced value for unexpected index {0}; expected to "
"deduce {1} arguments.",
index, result_arg_ids_.size());
auto arg_const_inst_id =
context().constant_values().GetConstantInstId(arg_id);
if (arg_const_inst_id.is_valid()) {
if (result_arg_ids_[index.index].is_valid() &&
result_arg_ids_[index.index] != arg_const_inst_id) {
if (diagnose_) {
// TODO: Include the two different deduced values.
CARBON_DIAGNOSTIC(
DeductionInconsistent, Error,
"inconsistent deductions for value of generic "
"parameter `{0}`",
SemIR::NameId);
auto diag = context().emitter().Build(
loc_id_, DeductionInconsistent, entity_name.name_id);
NoteGenericHere(context(), generic_id_, diag);
diag.Emit();
}
return false;
if (!index.is_valid() || index < first_deduced_index_) {
break;
}

CARBON_CHECK(static_cast<size_t>(index.index) < result_arg_ids_.size(),
"Deduced value for unexpected index {0}; expected to "
"deduce {1} arguments.",
index, result_arg_ids_.size());
auto arg_const_inst_id =
context().constant_values().GetConstantInstId(arg_id);
if (arg_const_inst_id.is_valid()) {
if (result_arg_ids_[index.index].is_valid() &&
result_arg_ids_[index.index] != arg_const_inst_id) {
if (diagnose_) {
// TODO: Include the two different deduced values.
CARBON_DIAGNOSTIC(DeductionInconsistent, Error,
"inconsistent deductions for value of generic "
"parameter `{0}`",
SemIR::NameId);
auto diag = context().emitter().Build(
loc_id_, DeductionInconsistent, entity_name.name_id);
NoteGenericHere(context(), generic_id_, diag);
diag.Emit();
}
result_arg_ids_[index.index] = arg_const_inst_id;
return false;
}
result_arg_ids_[index.index] = arg_const_inst_id;
}
break;
continue;
}

// Various kinds of parameter should match an argument of the same form,
// if the operands all match.
case SemIR::ArrayType::Kind:
case SemIR::ClassType::Kind:
case SemIR::ConstType::Kind:
case SemIR::FloatType::Kind:
case SemIR::InterfaceType::Kind:
case SemIR::IntType::Kind:
case SemIR::PointerType::Kind:
case SemIR::TupleType::Kind:
case SemIR::TupleValue::Kind: {
auto arg_inst = context().insts().Get(arg_id);
if (arg_inst.kind() != param_inst.kind()) {
break;
}
auto [kind0, kind1] = param_inst.ArgKinds();
worklist_.AddInstArg(kind0, param_inst.arg0(), arg_inst.arg0(),
needs_substitution);
worklist_.AddInstArg(kind1, param_inst.arg1(), arg_inst.arg1(),
needs_substitution);
continue;
}

case SemIR::StructType::Kind:
case SemIR::StructValue::Kind:
// TODO: Match field name order between param and arg.
break;

// TODO: Handle more cases.

default:
break;
}

// If we've not yet substituted into the parameter, do so now and try again.
if (needs_substitution) {
param_const_id = SubstConstant(context(), param_const_id, substitutions_);
if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
continue;
}
Add(context().constant_values().GetInstId(param_const_id), arg_id,
/*needs_substitution=*/false);
}
}

return true;
Expand Down
15 changes: 7 additions & 8 deletions toolchain/check/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1433,15 +1433,14 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
// `const (const T)` evaluates to `const T`. Otherwise, `const T` evaluates
// to itself.
case CARBON_KIND(SemIR::ConstType typed_inst): {
auto inner_id = eval_context.GetConstantValue(typed_inst.inner_id);
if (inner_id.is_constant() &&
eval_context.insts()
.Get(eval_context.constant_values().GetInstId(inner_id))
.Is<SemIR::ConstType>()) {
return inner_id;
auto phase = Phase::Template;
auto inner_id =
GetConstantValue(eval_context, typed_inst.inner_id, &phase);
if (eval_context.context().types().Is<SemIR::ConstType>(inner_id)) {
return eval_context.context().types().GetConstantId(inner_id);
}
return MakeConstantResult(eval_context.context(), inst,
GetPhase(inner_id));
typed_inst.inner_id = inner_id;
return MakeConstantResult(eval_context.context(), typed_inst, phase);
}

// These cases are either not expressions or not constant.
Expand Down
Loading

0 comments on commit 6410d6e

Please sign in to comment.