Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support initialization of specific classes from struct literals #4320

Merged
merged 4 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions toolchain/check/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "toolchain/sem_ir/builtin_inst_kind.h"
#include "toolchain/sem_ir/file.h"
#include "toolchain/sem_ir/formatter.h"
#include "toolchain/sem_ir/generic.h"
#include "toolchain/sem_ir/ids.h"
#include "toolchain/sem_ir/import_ir.h"
#include "toolchain/sem_ir/inst.h"
Expand Down Expand Up @@ -873,7 +874,7 @@ class TypeCompleter {
if (inst.specific_id.is_valid()) {
ResolveSpecificDefinition(context_, inst.specific_id);
}
Push(class_info.object_repr_id);
Push(class_info.GetObjectRepr(context_.sem_ir(), inst.specific_id));
break;
}
case CARBON_KIND(SemIR::ConstType inst): {
Expand Down Expand Up @@ -1051,14 +1052,19 @@ class TypeCompleter {
// The value representation of an adapter is the value representation of
// its adapted type.
if (class_info.adapt_id.is_valid()) {
return GetNestedValueRepr(class_info.object_repr_id);
return GetNestedValueRepr(SemIR::GetTypeInSpecific(
context_.sem_ir(), inst.specific_id,
context_.insts()
.GetAs<SemIR::AdaptDecl>(class_info.adapt_id)
.adapted_type_id));
}
// Otherwise, the value representation for a class is a pointer to the
// object representation.
// TODO: Support customized value representations for classes.
// TODO: Pick a better value representation when possible.
return MakePointerValueRepr(class_info.object_repr_id,
SemIR::ValueRepr::ObjectAggregate);
return MakePointerValueRepr(
class_info.GetObjectRepr(context_.sem_ir(), inst.specific_id),
SemIR::ValueRepr::ObjectAggregate);
}

template <typename InstT>
Expand Down
35 changes: 21 additions & 14 deletions toolchain/check/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,8 @@ static auto ConvertStructToClass(Context& context, SemIR::StructType src_type,
SemIR::InstId value_id,
ConversionTarget target) -> SemIR::InstId {
PendingBlock target_block(context);
auto& class_info = context.classes().Get(dest_type.class_id);
if (class_info.inheritance_kind == SemIR::Class::Abstract) {
auto& dest_class_info = context.classes().Get(dest_type.class_id);
if (dest_class_info.inheritance_kind == SemIR::Class::Abstract) {
CARBON_DIAGNOSTIC(ConstructionOfAbstractClass, Error,
"Cannot construct instance of abstract class. "
"Consider using `partial {0}` instead.",
Expand All @@ -542,11 +542,13 @@ static auto ConvertStructToClass(Context& context, SemIR::StructType src_type,
target.type_id);
return SemIR::InstId::BuiltinError;
}
if (class_info.object_repr_id == SemIR::TypeId::Error) {
auto object_repr_id =
dest_class_info.GetObjectRepr(context.sem_ir(), dest_type.specific_id);
if (object_repr_id == SemIR::TypeId::Error) {
return SemIR::InstId::BuiltinError;
}
auto dest_struct_type =
context.types().GetAs<SemIR::StructType>(class_info.object_repr_id);
context.types().GetAs<SemIR::StructType>(object_repr_id);

// If we're trying to create a class value, form a temporary for the value to
// point to.
Expand All @@ -571,9 +573,10 @@ static auto ConvertStructToClass(Context& context, SemIR::StructType src_type,
return result_id;
}

// An inheritance path is a sequence of `BaseDecl`s in order from derived to
// base.
using InheritancePath = llvm::SmallVector<SemIR::InstId>;
// An inheritance path is a sequence of `BaseDecl`s and corresponding base types
// in order from derived to base.
using InheritancePath =
llvm::SmallVector<std::pair<SemIR::InstId, SemIR::TypeId>>;

// Computes the inheritance path from class `derived_id` to class `base_id`.
// Returns nullopt if `derived_id` is not a class derived from `base_id`.
Expand Down Expand Up @@ -602,10 +605,13 @@ static auto ComputeInheritancePath(Context& context, SemIR::TypeId derived_id,
result = std::nullopt;
break;
}
result->push_back(derived_class.base_id);
derived_id = context.insts()
.GetAs<SemIR::BaseDecl>(derived_class.base_id)
.base_type_id;
auto base_decl =
context.insts().GetAs<SemIR::BaseDecl>(derived_class.base_id);
auto base_type_id = SemIR::GetTypeInSpecific(
context.sem_ir(), derived_class_type->specific_id,
base_decl.base_type_id);
result->push_back({derived_class.base_id, base_type_id});
derived_id = base_type_id;
}
return result;
}
Expand All @@ -619,10 +625,10 @@ static auto ConvertDerivedToBase(Context& context, SemIR::LocId loc_id,
value_id = ConvertToValueOrRefExpr(context, value_id);

// Add a series of `.base` accesses.
for (auto base_id : path) {
for (auto [base_id, base_type_id] : path) {
auto base_decl = context.insts().GetAs<SemIR::BaseDecl>(base_id);
value_id = context.AddInst<SemIR::ClassElementAccess>(
loc_id, {.type_id = base_decl.base_type_id,
loc_id, {.type_id = base_type_id,
.base_id = value_id,
.index = base_decl.index});
}
Expand Down Expand Up @@ -677,7 +683,8 @@ static auto GetCompatibleBaseType(Context& context, SemIR::TypeId type_id)
if (auto class_type = context.types().TryGetAs<SemIR::ClassType>(type_id)) {
auto& class_info = context.classes().Get(class_type->class_id);
if (class_info.adapt_id.is_valid()) {
return class_info.object_repr_id;
return class_info.GetObjectRepr(context.sem_ir(),
class_type->specific_id);
}
}

Expand Down
3 changes: 3 additions & 0 deletions toolchain/check/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,9 @@ auto TryEvalInstInContext(EvalContext& eval_context, SemIR::InstId inst_id,
case SemIR::ClassType::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::ClassType::specific_id);
case SemIR::CompleteTypeWitness::Kind:
return RebuildIfFieldsAreConstant(
eval_context, inst, &SemIR::CompleteTypeWitness::object_repr_id);
case SemIR::FunctionType::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::FunctionType::specific_id);
Expand Down
120 changes: 77 additions & 43 deletions toolchain/check/handle_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ static auto MergeClassRedecl(Context& context, SemIRLoc new_loc,
prev_class.body_block_id = new_class.body_block_id;
prev_class.adapt_id = new_class.adapt_id;
prev_class.base_id = new_class.base_id;
prev_class.object_repr_id = new_class.object_repr_id;
prev_class.complete_type_witness_id = new_class.complete_type_witness_id;
}

if ((prev_import_ir_id.is_valid() && !new_is_import) ||
Expand Down Expand Up @@ -561,55 +561,89 @@ auto HandleParseNode(Context& context, Parse::BaseDeclId node_id) -> bool {
return true;
}

auto HandleParseNode(Context& context, Parse::ClassDefinitionId /*node_id*/)
// Checks that the specified finished adapter definition is valid and builds and
// returns a corresponding complete type witness instruction.
static auto CheckCompleteAdapterClassType(Context& context,
Parse::NodeId node_id,
SemIR::ClassId class_id,
SemIR::InstBlockId fields_id)
-> SemIR::InstId {
const auto& class_info = context.classes().Get(class_id);
if (class_info.base_id.is_valid()) {
CARBON_DIAGNOSTIC(AdaptWithBase, Error,
"Adapter cannot have a base class.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Getting to the diagnostic phrasing discussion, should this be "Adapter has a base class"?

Note I don't have a strong opinion, just asking to try to get a similar voice as we add diagnostics.

Oh, but more seriously... you can probably remove the capital and period now. The related PR is still pending but I think it'd make sense to stop writing those into new diagnostics.

CARBON_DIAGNOSTIC(AdaptBaseHere, Note, "`base` declaration is here.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW with this and AdaptFieldsHere, had you considered "AdaptWithBaseHere", using the main error as a prefix to tie the note more closely together? (or, AdaptBase/AdaptFields, dropping "With")

context.emitter()
.Build(class_info.adapt_id, AdaptWithBase)
.Note(class_info.base_id, AdaptBaseHere)
.Emit();
return SemIR::InstId::BuiltinError;
}

if (!context.inst_blocks().Get(fields_id).empty()) {
auto first_field_id = context.inst_blocks().Get(fields_id).front();
CARBON_DIAGNOSTIC(AdaptWithFields, Error, "Adapter cannot have fields.");
CARBON_DIAGNOSTIC(AdaptFieldHere, Note, "First field declaration is here.");
context.emitter()
.Build(class_info.adapt_id, AdaptWithFields)
.Note(first_field_id, AdaptFieldHere)
.Emit();
return SemIR::InstId::BuiltinError;
}

// The object representation of the adapter is the object representation
// of the adapted type. This is the adapted type itself unless it's a class
// type.
//
// TODO: The object representation of `const T` should also be the object
// representation of `T`.
auto adapted_type_id = context.insts()
.GetAs<SemIR::AdaptDecl>(class_info.adapt_id)
.adapted_type_id;
if (auto adapted_class =
context.types().TryGetAs<SemIR::ClassType>(adapted_type_id)) {
auto& adapted_class_info = context.classes().Get(adapted_class->class_id);
if (adapted_class_info.adapt_id.is_valid()) {
return adapted_class_info.complete_type_witness_id;
}
}

return context.AddInst<SemIR::CompleteTypeWitness>(
node_id,
{.type_id = context.GetBuiltinType(SemIR::BuiltinInstKind::WitnessType),
.object_repr_id = adapted_type_id});
}

// Checks that the specified finished class definition is valid and builds and
// returns a corresponding complete type witness instruction.
static auto CheckCompleteClassType(Context& context, Parse::NodeId node_id,
SemIR::ClassId class_id,
SemIR::InstBlockId fields_id)
-> SemIR::InstId {
auto& class_info = context.classes().Get(class_id);
if (class_info.adapt_id.is_valid()) {
return CheckCompleteAdapterClassType(context, node_id, class_id, fields_id);
}

return context.AddInst<SemIR::CompleteTypeWitness>(
node_id,
{.type_id = context.GetBuiltinType(SemIR::BuiltinInstKind::WitnessType),
.object_repr_id = context.GetStructType(fields_id)});
}

auto HandleParseNode(Context& context, Parse::ClassDefinitionId node_id)
-> bool {
auto fields_id = context.args_type_info_stack().Pop();
auto class_id =
context.node_stack().Pop<Parse::NodeKind::ClassDefinitionStart>();
context.inst_block_stack().Pop();

// The class type is now fully defined. Compute its object representation.
auto complete_type_witness_id =
CheckCompleteClassType(context, node_id, class_id, fields_id);
auto& class_info = context.classes().Get(class_id);
if (class_info.adapt_id.is_valid()) {
class_info.object_repr_id = SemIR::TypeId::Error;
if (class_info.base_id.is_valid()) {
CARBON_DIAGNOSTIC(AdaptWithBase, Error,
"Adapter cannot have a base class.");
CARBON_DIAGNOSTIC(AdaptBaseHere, Note, "`base` declaration is here.");
context.emitter()
.Build(class_info.adapt_id, AdaptWithBase)
.Note(class_info.base_id, AdaptBaseHere)
.Emit();
} else if (!context.inst_blocks().Get(fields_id).empty()) {
auto first_field_id = context.inst_blocks().Get(fields_id).front();
CARBON_DIAGNOSTIC(AdaptWithFields, Error, "Adapter cannot have fields.");
CARBON_DIAGNOSTIC(AdaptFieldHere, Note,
"First field declaration is here.");
context.emitter()
.Build(class_info.adapt_id, AdaptWithFields)
.Note(first_field_id, AdaptFieldHere)
.Emit();
} else {
// The object representation of the adapter is the object representation
// of the adapted type.
auto adapted_type_id = context.insts()
.GetAs<SemIR::AdaptDecl>(class_info.adapt_id)
.adapted_type_id;
// If we adapt an adapter, directly track the non-adapter type we're
// adapting so that we have constant-time access to it.
if (auto adapted_class =
context.types().TryGetAs<SemIR::ClassType>(adapted_type_id)) {
auto& adapted_class_info =
context.classes().Get(adapted_class->class_id);
if (adapted_class_info.adapt_id.is_valid()) {
adapted_type_id = adapted_class_info.object_repr_id;
}
}
class_info.object_repr_id = adapted_type_id;
}
} else {
class_info.object_repr_id = context.GetStructType(fields_id);
}
class_info.complete_type_witness_id = complete_type_witness_id;

context.inst_block_stack().Pop();

FinishGenericDefinition(context, class_info.generic_id);

Expand Down
34 changes: 26 additions & 8 deletions toolchain/check/import_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "toolchain/sem_ir/import_ir.h"
#include "toolchain/sem_ir/inst.h"
#include "toolchain/sem_ir/inst_kind.h"
#include "toolchain/sem_ir/type_info.h"
#include "toolchain/sem_ir/typed_insts.h"

namespace Carbon::Check {
Expand Down Expand Up @@ -1011,6 +1012,9 @@ class ImportRefResolver {
case CARBON_KIND(SemIR::ClassType inst): {
return TryResolveTypedInst(inst);
}
case CARBON_KIND(SemIR::CompleteTypeWitness inst): {
return TryResolveTypedInst(inst);
}
case CARBON_KIND(SemIR::ConstType inst): {
return TryResolveTypedInst(inst);
}
Expand Down Expand Up @@ -1228,12 +1232,11 @@ class ImportRefResolver {
// Fills out the class definition for an incomplete class.
auto AddClassDefinition(const SemIR::Class& import_class,
SemIR::Class& new_class,
SemIR::ConstantId object_repr_const_id,
SemIR::InstId complete_type_witness_id,
SemIR::InstId base_id) -> void {
new_class.definition_id = new_class.first_owning_decl_id;

new_class.object_repr_id =
context_.GetTypeIdForTypeConstant(object_repr_const_id);
new_class.complete_type_witness_id = complete_type_witness_id;

new_class.scope_id = context_.name_scopes().Add(
new_class.first_owning_decl_id, SemIR::NameId::Invalid,
Expand Down Expand Up @@ -1312,10 +1315,10 @@ class ImportRefResolver {
auto param_const_ids = GetLocalParamConstantIds(import_class.param_refs_id);
auto generic_data = GetLocalGenericData(import_class.generic_id);
auto self_const_id = GetLocalConstantId(import_class.self_type_id);
auto object_repr_const_id =
import_class.object_repr_id.is_valid()
? GetLocalConstantId(import_class.object_repr_id)
: SemIR::ConstantId::Invalid;
auto complete_type_witness_id =
import_class.complete_type_witness_id.is_valid()
? GetLocalConstantInstId(import_class.complete_type_witness_id)
: SemIR::InstId::Invalid;
auto base_id = import_class.base_id.is_valid()
? GetLocalConstantInstId(import_class.base_id)
: SemIR::InstId::Invalid;
Expand All @@ -1334,7 +1337,7 @@ class ImportRefResolver {
new_class.self_type_id = context_.GetTypeIdForTypeConstant(self_const_id);

if (import_class.is_defined()) {
AddClassDefinition(import_class, new_class, object_repr_const_id,
AddClassDefinition(import_class, new_class, complete_type_witness_id,
base_id);
}

Expand Down Expand Up @@ -1368,6 +1371,21 @@ class ImportRefResolver {
}
}

auto TryResolveTypedInst(SemIR::CompleteTypeWitness inst) -> ResolveResult {
CARBON_CHECK(import_ir_.types().GetInstId(inst.type_id) ==
SemIR::InstId::BuiltinWitnessType);
auto object_repr_const_id = GetLocalConstantId(inst.object_repr_id);
if (HasNewWork()) {
return Retry();
}
auto object_repr_id =
context_.GetTypeIdForTypeConstant(object_repr_const_id);
return ResolveAs<SemIR::CompleteTypeWitness>(
{.type_id =
context_.GetBuiltinType(SemIR::BuiltinInstKind::WitnessType),
.object_repr_id = object_repr_id});
}

auto TryResolveTypedInst(SemIR::ConstType inst) -> ResolveResult {
CARBON_CHECK(inst.type_id == SemIR::TypeId::TypeType);
auto inner_const_id = GetLocalConstantId(inst.inner_id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ let d: c = {};
// CHECK:STDOUT: constants {
// CHECK:STDOUT: %C: type = class_type @C [template]
// CHECK:STDOUT: %.1: type = struct_type {} [template]
// CHECK:STDOUT: %.2: type = tuple_type () [template]
// CHECK:STDOUT: %.3: type = ptr_type %.1 [template]
// CHECK:STDOUT: %.2: <witness> = complete_type_witness %.1 [template]
// CHECK:STDOUT: %.3: type = tuple_type () [template]
// CHECK:STDOUT: %.4: type = ptr_type %.1 [template]
// CHECK:STDOUT: %struct: %C = struct_value () [template]
// CHECK:STDOUT: }
// CHECK:STDOUT:
Expand All @@ -43,6 +44,8 @@ let d: c = {};
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: class @C {
// CHECK:STDOUT: %.loc11: <witness> = complete_type_witness %.1 [template = constants.%.2]
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%C
// CHECK:STDOUT: }
Expand Down
Loading
Loading