Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce private and protected access modifiers for class member access #4248

Merged
merged 3 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 137 additions & 16 deletions toolchain/check/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "toolchain/check/context.h"

#include <optional>
#include <string>
#include <utility>

Expand Down Expand Up @@ -292,7 +293,8 @@ auto Context::LookupNameInDecl(SemIR::LocId loc_id, SemIR::NameId name_id,
// // Error, no `F` in `B`.
// fn B.F() {}
return LookupNameInExactScope(loc_id, name_id, scope_id,
name_scopes().Get(scope_id));
name_scopes().Get(scope_id))
.first;
}
}

Expand Down Expand Up @@ -334,37 +336,141 @@ auto Context::LookupUnqualifiedName(Parse::NodeId node_id,
auto Context::LookupNameInExactScope(SemIRLoc loc, SemIR::NameId name_id,
SemIR::NameScopeId scope_id,
const SemIR::NameScope& scope)
-> SemIR::InstId {
-> std::pair<SemIR::InstId, SemIR::AccessKind> {
if (auto lookup = scope.name_map.Lookup(name_id)) {
auto inst_id = scope.names[lookup.value()].inst_id;
LoadImportRef(*this, inst_id);
return inst_id;
auto entry = scope.names[lookup.value()];
LoadImportRef(*this, entry.inst_id);
return {entry.inst_id, entry.access_kind};
}

if (!scope.import_ir_scopes.empty()) {
return ImportNameFromOtherPackage(*this, loc, scope_id,
scope.import_ir_scopes, name_id);
// TODO: Enforce other access modifiers for imports.
return {ImportNameFromOtherPackage(*this, loc, scope_id,
scope.import_ir_scopes, name_id),
SemIR::AccessKind::Public};
}
return {SemIR::InstId::Invalid, SemIR::AccessKind::Public};
}

// Prints diagnostics on invalid qualified name access.
static auto DiagnoseInvalidQualifiedNameAccess(Context& context, SemIRLoc loc,
SemIR::InstId scope_result_id,
SemIR::NameId name_id,
SemIR::AccessKind access_kind,
bool is_parent_access,
AccessInfo access_info) -> void {
auto class_type = context.insts().TryGetAs<SemIR::ClassType>(
context.constant_values().GetInstId(access_info.constant_id));
if (!class_type) {
return;
}

// TODO: Support scoped entities other than just classes.
auto class_info = context.classes().Get(class_type->class_id);

// TODO: Support passing AccessKind to diagnostics.
CARBON_DIAGNOSTIC(ClassInvalidMemberAccess, Error,
"Cannot access {0} member `{1}` of type `{2}`.",
llvm::StringLiteral, SemIR::NameId, SemIR::TypeId);
CARBON_DIAGNOSTIC(ClassMemberDefinition, Note,
"The {0} member `{1}` is defined here.",
llvm::StringLiteral, SemIR::NameId);

auto parent_type_id = class_info.self_type_id;
auto access_desc = access_kind == SemIR::AccessKind::Private
? llvm::StringLiteral("private")
: llvm::StringLiteral("protected");

if (access_kind == SemIR::AccessKind::Private && is_parent_access) {
if (auto base_decl = context.insts().TryGetAsIfValid<SemIR::BaseDecl>(
class_info.base_id)) {
parent_type_id = base_decl->base_type_id;
} else if (auto adapt_decl =
context.insts().TryGetAsIfValid<SemIR::AdaptDecl>(
class_info.adapt_id)) {
parent_type_id = adapt_decl->adapted_type_id;
} else {
CARBON_FATAL() << "Expected parent for parent access";
}
}
return SemIR::InstId::Invalid;

context.emitter()
.Build(loc, ClassInvalidMemberAccess, access_desc, name_id,
parent_type_id)
.Note(scope_result_id, ClassMemberDefinition, access_desc, name_id)
.Emit();
}

// Returns whether the access is prohibited by the access modifiers.
static auto IsAccessProhibited(std::optional<AccessInfo> access_info,
SemIR::AccessKind access_kind,
bool is_parent_access) -> bool {
if (!access_info) {
return false;
}

switch (access_kind) {
case SemIR::AccessKind::Public:
return false;
case SemIR::AccessKind::Protected:
return access_info->highest_allowed_access == SemIR::AccessKind::Public;
case SemIR::AccessKind::Private:
return access_info->highest_allowed_access !=
SemIR::AccessKind::Private ||
is_parent_access;
}
}

// Information regarding a prohibited access.
struct ProhibitedAccessInfo {
// The resulting inst of the lookup.
SemIR::InstId scope_result_id;
// The access kind of the lookup.
SemIR::AccessKind access_kind;
// If the lookup is from an extended scope. For example, if this is a base
// class member access from a class that extends it.
bool is_parent_access;
};

auto Context::LookupQualifiedName(SemIRLoc loc, SemIR::NameId name_id,
LookupScope scope, bool required)
LookupScope scope, bool required,
std::optional<AccessInfo> access_info)
-> LookupResult {
llvm::SmallVector<LookupScope> scopes = {scope};

// TODO: Support reporting of multiple prohibited access.
llvm::SmallVector<ProhibitedAccessInfo> prohibited_accesses;

LookupResult result = {.specific_id = SemIR::SpecificId::Invalid,
.inst_id = SemIR::InstId::Invalid};
bool has_error = false;
bool is_parent_access = false;

// Walk this scope and, if nothing is found here, the scopes it extends.
while (!scopes.empty()) {
auto [scope_id, specific_id] = scopes.pop_back_val();
const auto& name_scope = name_scopes().Get(scope_id);
has_error |= name_scope.has_error;

auto scope_result_id =
auto [scope_result_id, access_kind] =
LookupNameInExactScope(loc, name_id, scope_id, name_scope);
if (!scope_result_id.is_valid()) {
// Nothing found in this scope: also look in its extended scopes.

auto is_access_prohibited =
IsAccessProhibited(access_info, access_kind, is_parent_access);

// Keep track of prohibited accesses, this will be useful for reporting
// multiple prohibited accesses if we can't find a suitable lookup.
if (is_access_prohibited) {
prohibited_accesses.push_back({
.scope_result_id = scope_result_id,
.access_kind = access_kind,
.is_parent_access = is_parent_access,
});
}

if (!scope_result_id.is_valid() || is_access_prohibited) {
// If nothing is found in this scope or if we encountered an invalid
// access, look in its extended scopes.
auto extended = name_scope.extended_scopes;
scopes.reserve(scopes.size() + extended.size());
for (auto extended_id : llvm::reverse(extended)) {
Expand All @@ -373,6 +479,7 @@ auto Context::LookupQualifiedName(SemIRLoc loc, SemIR::NameId name_id,
scopes.push_back({.name_scope_id = extended_id,
.specific_id = SemIR::SpecificId::Invalid});
}
is_parent_access |= !extended.empty();
continue;
}

Expand All @@ -397,8 +504,22 @@ auto Context::LookupQualifiedName(SemIRLoc loc, SemIR::NameId name_id,

if (required && !result.inst_id.is_valid()) {
if (!has_error) {
DiagnoseNameNotFound(loc, name_id);
if (prohibited_accesses.empty()) {
DiagnoseNameNotFound(loc, name_id);
} else {
// TODO: We should report multiple prohibited accesses in case we don't
// find a valid lookup. Reporting the last one should suffice for now.
auto [scope_result_id, access_kind, is_parent_access] =
prohibited_accesses.back();

// Note, `access_info` is guaranteed to have a value here, since
// `prohibited_accesses` is non-empty.
DiagnoseInvalidQualifiedNameAccess(*this, loc, scope_result_id, name_id,
access_kind, is_parent_access,
*access_info);
}
}

return {.specific_id = SemIR::SpecificId::Invalid,
.inst_id = SemIR::InstId::BuiltinError};
}
Expand All @@ -420,7 +541,7 @@ static auto GetCorePackage(Context& context, SemIRLoc loc)
auto core_name_id = SemIR::NameId::ForIdentifier(core_ident_id);

// Look up `package.Core`.
auto core_inst_id = context.LookupNameInExactScope(
auto [core_inst_id, _] = context.LookupNameInExactScope(
loc, core_name_id, SemIR::NameScopeId::Package,
context.name_scopes().Get(SemIR::NameScopeId::Package));
if (core_inst_id.is_valid()) {
Expand All @@ -447,8 +568,8 @@ auto Context::LookupNameInCore(SemIRLoc loc, llvm::StringRef name)
}

auto name_id = SemIR::NameId::ForIdentifier(identifiers().Add(name));
auto inst_id = LookupNameInExactScope(loc, name_id, core_package_id,
name_scopes().Get(core_package_id));
auto [inst_id, _] = LookupNameInExactScope(
loc, name_id, core_package_id, name_scopes().Get(core_package_id));
if (!inst_id.is_valid()) {
CARBON_DIAGNOSTIC(
CoreNameNotFound, Error,
Expand Down
17 changes: 15 additions & 2 deletions toolchain/check/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "toolchain/sem_ir/ids.h"
#include "toolchain/sem_ir/import_ir.h"
#include "toolchain/sem_ir/inst.h"
#include "toolchain/sem_ir/name_scope.h"
#include "toolchain/sem_ir/typed_insts.h"

namespace Carbon::Check {
Expand All @@ -46,6 +47,16 @@ struct LookupResult {
SemIR::InstId inst_id;
};

// Information about an access.
struct AccessInfo {
// The constant being accessed.
SemIR::ConstantId constant_id;

// The highest allowed access for a lookup. For example, `Protected` allows
// access to `Public` and `Protected` names, but not `Private`.
SemIR::AccessKind highest_allowed_access;
};

// Context and shared functionality for semantics handlers.
class Context {
public:
Expand Down Expand Up @@ -173,12 +184,14 @@ class Context {
// instruction if the name is not found.
auto LookupNameInExactScope(SemIRLoc loc, SemIR::NameId name_id,
SemIR::NameScopeId scope_id,
const SemIR::NameScope& scope) -> SemIR::InstId;
const SemIR::NameScope& scope)
-> std::pair<SemIR::InstId, SemIR::AccessKind>;

// Performs a qualified name lookup in a specified scope and in scopes that
// it extends, returning the referenced instruction.
auto LookupQualifiedName(SemIRLoc loc, SemIR::NameId name_id,
LookupScope scope, bool required = true)
LookupScope scope, bool required = true,
std::optional<AccessInfo> access_info = std::nullopt)
-> LookupResult;

// Returns the instruction corresponding to a name in the core package, or
Expand Down
2 changes: 1 addition & 1 deletion toolchain/check/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ static auto BuildInterfaceWitness(
CARBON_FATAL() << "Unexpected type: " << type_inst;
}
auto& fn = context.functions().Get(fn_type->function_id);
auto impl_decl_id = context.LookupNameInExactScope(
auto [impl_decl_id, _] = context.LookupNameInExactScope(
decl_id, fn.name_id, impl.scope_id, impl_scope);
if (impl_decl_id.is_valid()) {
used_decl_ids.push_back(impl_decl_id);
Expand Down
92 changes: 91 additions & 1 deletion toolchain/check/member_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "toolchain/check/member_access.h"

#include <optional>

#include "llvm/ADT/STLExtras.h"
#include "toolchain/base/kind_switch.h"
#include "toolchain/check/context.h"
Expand All @@ -13,6 +15,7 @@
#include "toolchain/sem_ir/generic.h"
#include "toolchain/sem_ir/ids.h"
#include "toolchain/sem_ir/inst.h"
#include "toolchain/sem_ir/name_scope.h"
#include "toolchain/sem_ir/typed_insts.h"

namespace Carbon::Check {
Expand Down Expand Up @@ -98,6 +101,82 @@ static auto IsInstanceMethod(const SemIR::File& sem_ir,
return false;
}

// Returns the FunctionId of the current function if it exists.
static auto GetCurrentFunction(Context& context)
-> std::optional<SemIR::FunctionId> {
if (context.return_scope_stack().empty()) {
return std::nullopt;
}

return context.insts()
.GetAs<SemIR::FunctionDecl>(context.return_scope_stack().back().decl_id)
.function_id;
}

// Returns the highest allowed access. For example, if this returns `Protected`
// then only `Public` and `Protected` accesses are allowed--not `Private`.
static auto GetHighestAllowedAccess(Context& context, SemIRLoc loc,
SemIR::ConstantId name_scope_const_id)
-> SemIR::AccessKind {
// TODO: Maybe use LookupUnqualifiedName for `Self` to support things like
// `var x: Self.ParentProtectedType`?
auto current_function = GetCurrentFunction(context);
brymer-meneses marked this conversation as resolved.
Show resolved Hide resolved
// If `current_function` is a `nullopt` then we're accessing from a global
// variable.
if (!current_function) {
return SemIR::AccessKind::Public;
}

auto scope_id = context.functions().Get(*current_function).parent_scope_id;
if (!scope_id.is_valid()) {
return SemIR::AccessKind::Public;
}
auto scope = context.name_scopes().Get(scope_id);

// Lookup the inst for `Self` in the parent scope of the current function.
auto [self_type_inst_id, _] = context.LookupNameInExactScope(
loc, SemIR::NameId::SelfType, scope_id, scope);
if (!self_type_inst_id.is_valid()) {
return SemIR::AccessKind::Public;
}

// TODO: Support other types for `Self`.
auto self_class_type =
context.insts().TryGetAs<SemIR::ClassType>(self_type_inst_id);
if (!self_class_type) {
return SemIR::AccessKind::Public;
}

auto self_class_info = context.classes().Get(self_class_type->class_id);

// TODO: Support other types.
if (auto class_type = context.insts().TryGetAs<SemIR::ClassType>(
context.constant_values().GetInstId(name_scope_const_id))) {
auto class_info = context.classes().Get(class_type->class_id);

if (self_class_info.self_type_id == class_info.self_type_id) {
return SemIR::AccessKind::Private;
}

// If the `type_id` of `Self` does not match with the one we're currently
// accessing, try checking if this class is of the parent type of `Self`.
if (auto base_decl = context.insts().TryGetAsIfValid<SemIR::BaseDecl>(
self_class_info.base_id)) {
if (base_decl->base_type_id == class_info.self_type_id) {
return SemIR::AccessKind::Protected;
}
} else if (auto adapt_decl =
context.insts().TryGetAsIfValid<SemIR::AdaptDecl>(
self_class_info.adapt_id)) {
if (adapt_decl->adapted_type_id == class_info.self_type_id) {
return SemIR::AccessKind::Protected;
}
}
}

return SemIR::AccessKind::Public;
}

// Returns whether `scope` is a scope for which impl lookup should be performed
// if we find an associated entity.
static auto ScopeNeedsImplLookup(Context& context, LookupScope scope) -> bool {
Expand Down Expand Up @@ -224,7 +303,18 @@ static auto LookupMemberNameInScope(Context& context, SemIR::LocId loc_id,
LookupResult result = {.specific_id = SemIR::SpecificId::Invalid,
.inst_id = SemIR::InstId::BuiltinError};
if (lookup_scope.name_scope_id.is_valid()) {
result = context.LookupQualifiedName(loc_id, name_id, lookup_scope);
AccessInfo access_info = {
.constant_id = name_scope_const_id,
.highest_allowed_access =
GetHighestAllowedAccess(context, loc_id, name_scope_const_id),
};

result = context.LookupQualifiedName(loc_id, name_id, lookup_scope,
/*required=*/true, access_info);

if (!result.inst_id.is_valid()) {
return SemIR::InstId::BuiltinError;
}
}

// TODO: This duplicates the work that HandleNameAsExpr does. Factor this out.
Expand Down
Loading
Loading