Skip to content

Commit

Permalink
spirv: fix some recursive pointers edge cases in dedup pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Snektron committed Apr 6, 2024
1 parent 242e196 commit e4236bd
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 9 deletions.
2 changes: 0 additions & 2 deletions src/codegen/spirv.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1341,8 +1341,6 @@ const DeclGen = struct {

const child_ty_id = try self.resolveType(child_ty, .indirect);

assert(self.wip_pointers.remove(key));

try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpTypePointer, .{
.id_result = result_id,
.storage_class = storage_class,
Expand Down
69 changes: 62 additions & 7 deletions src/link/SpirV/deduplicate.zig
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ const ModuleInfo = struct {
result_id_index: u16,
/// The first decoration in `self.decorations`.
first_decoration: u32,

fn operands(self: Entity, binary: *const BinaryModule) []const Word {
return binary.instructions[self.first_operand..][0..self.num_operands];
}
};

/// Maps result-id to Entity's
Expand Down Expand Up @@ -210,10 +214,41 @@ const EntityContext = struct {

const entity = self.info.entities.values()[index];

// If the current pointer is recursive, don't immediately add it to the map. This is to ensure that
// if the current pointer is already recursive, it gets the same hash a pointer that points to the
// same child but has a different result-id.
if (entity.kind == .OpTypePointer) {
// This may be either a pointer that is forward-referenced in the future,
// or a forward reference to a pointer.
const entry = try self.ptr_map_a.getOrPut(self.a, id);
// Note: We use the **struct** here instead of the pointer itself, to avoid an edge case like this:
//
// A - C*'
// \
// C - C*'
// /
// B - C*"
//
// In this case, hashing A goes like
// A -> C*' -> C -> C*' recursion
// And hashing B goes like
// B -> C*" -> C -> C*' -> C -> C*' recursion
// The are several calls to ptrType in codegen that may C*' and C*" to be generated as separate
// types. This is not a problem for C itself though - this can only be generated through resolveType()
// and so ensures equality by Zig's type system. Technically the above problem is still present, but it
// would only be present in a structure such as
//
// A - C*' - C'
// \
// C*" - C - C*
// /
// B
//
// where there is a duplicate definition of struct C. Resolving this requires a much more time consuming
// algorithm though, and because we don't expect any correctness issues with it, we leave that for now.

// TODO: Do we need to mind the storage class here? Its going to be recursive regardless, right?
const struct_id: ResultId = @enumFromInt(entity.operands(self.binary)[2]);
const entry = try self.ptr_map_a.getOrPut(self.a, struct_id);
if (entry.found_existing) {
// Pointer already seen. Hash the index instead of recursing into its children.
std.hash.autoHash(hasher, entry.index);
Expand All @@ -228,12 +263,17 @@ const EntityContext = struct {
for (decorations) |decoration| {
try self.hashEntity(hasher, decoration);
}

if (entity.kind == .OpTypePointer) {
const struct_id: ResultId = @enumFromInt(entity.operands(self.binary)[2]);
assert(self.ptr_map_a.swapRemove(struct_id));
}
}

fn hashEntity(self: *EntityContext, hasher: *std.hash.Wyhash, entity: ModuleInfo.Entity) !void {
std.hash.autoHash(hasher, entity.kind);
// Process operands
const operands = self.binary.instructions[entity.first_operand..][0..entity.num_operands];
const operands = entity.operands(self.binary);
for (operands, 0..) |operand, i| {
if (i == entity.result_id_index) {
// Not relevant, skip...
Expand Down Expand Up @@ -273,12 +313,19 @@ const EntityContext = struct {
const entity_a = self.info.entities.values()[index_a];
const entity_b = self.info.entities.values()[index_b];

if (entity_a.kind != entity_b.kind) {
return false;
}

if (entity_a.kind == .OpTypePointer) {
// May be a forward reference, or should be saved as a potential
// forward reference in the future. Whatever the case, it should
// be the same for both a and b.
const entry_a = try self.ptr_map_a.getOrPut(self.a, id_a);
const entry_b = try self.ptr_map_b.getOrPut(self.a, id_b);
const struct_id_a: ResultId = @enumFromInt(entity_a.operands(self.binary)[2]);
const struct_id_b: ResultId = @enumFromInt(entity_b.operands(self.binary)[2]);

const entry_a = try self.ptr_map_a.getOrPut(self.a, struct_id_a);
const entry_b = try self.ptr_map_b.getOrPut(self.a, struct_id_b);

if (entry_a.found_existing != entry_b.found_existing) return false;
if (entry_a.index != entry_b.index) return false;
Expand Down Expand Up @@ -306,6 +353,14 @@ const EntityContext = struct {
}
}

if (entity_a.kind == .OpTypePointer) {
const struct_id_a: ResultId = @enumFromInt(entity_a.operands(self.binary)[2]);
const struct_id_b: ResultId = @enumFromInt(entity_b.operands(self.binary)[2]);

assert(self.ptr_map_a.swapRemove(struct_id_a));
assert(self.ptr_map_b.swapRemove(struct_id_b));
}

return true;
}

Expand All @@ -316,8 +371,8 @@ const EntityContext = struct {
return false;
}

const operands_a = self.binary.instructions[entity_a.first_operand..][0..entity_a.num_operands];
const operands_b = self.binary.instructions[entity_b.first_operand..][0..entity_b.num_operands];
const operands_a = entity_a.operands(self.binary);
const operands_b = entity_b.operands(self.binary);

// Note: returns false for operands that have explicit defaults in optional operands... oh well
if (operands_a.len != operands_b.len) {
Expand Down Expand Up @@ -463,7 +518,7 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule, progress: *std.P
if (entity.kind == .OpTypePointer and !emitted_ptrs.contains(id)) {
// Grab the pointer's storage class from its operands in the original
// module.
const storage_class: spec.StorageClass = @enumFromInt(binary.instructions[entity.first_operand + 1]);
const storage_class: spec.StorageClass = @enumFromInt(entity.operands(binary)[1]);
try section.emit(a, .OpTypeForwardPointer, .{
.pointer_type = id,
.storage_class = storage_class,
Expand Down

0 comments on commit e4236bd

Please sign in to comment.