Skip to content

Commit

Permalink
spirv: avoid copying operands in dedup pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Snektron committed Mar 30, 2024
1 parent 284eb4c commit 6209d85
Showing 1 changed file with 88 additions and 123 deletions.
211 changes: 88 additions & 123 deletions src/link/SpirV/deduplicate.zig
Original file line number Diff line number Diff line change
Expand Up @@ -35,72 +35,54 @@ const ModuleInfo = struct {
/// The type that this entity represents. This is just
/// the instruction opcode.
kind: Opcode,
/// Offset of first child result-id, stored in entity_children.
/// These are the shallow entities appearing directly in the
/// type's instruction.
first_child: u32,
/// Offset to the first word of extra-data: Data in the instruction
/// that must be considered for uniqueness, but doesn't include
/// any IDs.
first_extra_data: u32,
/// The offset of this entity's operands, in
/// `binary.instructions`.
first_operand: u32,
/// The number of operands in this entity
num_operands: u16,
/// The (first_operand-relative) offset of the result-id,
/// or the entity that is affected by this entity if this entity
/// is a decoration.
result_id_index: u16,
};

/// Maps each result-id to its Entity.
entities: std.AutoArrayHashMapUnmanaged(ResultId, Entity),
/// The list of children per instruction.
entity_children: []const ResultId,
/// The list of extra data per instruction.
/// TODO: This is a bit awkward, maybe we need to store it some
/// other way?
extra_data: []const u32,
/// A bit set that keeps track of which operands are result-ids.
/// Note: This also includes any result-id!
/// Because we need these values when recoding the module anyway,
/// it contains the status of ALL operands in the module.
operand_is_id: std.DynamicBitSetUnmanaged,

pub fn parse(
arena: Allocator,
parser: *BinaryModule.Parser,
binary: BinaryModule,
) !ModuleInfo {
var entities = std.AutoArrayHashMap(ResultId, Entity).init(arena);
var entity_children = std.ArrayList(ResultId).init(arena);
var extra_data = std.ArrayList(u32).init(arena);
var id_offsets = std.ArrayList(u16).init(arena);
var operand_is_id = try std.DynamicBitSetUnmanaged.initEmpty(arena, binary.instructions.len);

var it = binary.iterateInstructions();
while (it.next()) |inst| {
if (inst.opcode == .OpFunction) break; // No more declarations are possible
if (!canDeduplicate(inst.opcode)) continue;

id_offsets.items.len = 0;
try parser.parseInstructionResultIds(binary, inst, &id_offsets);

const result_id_index: u32 = switch (inst.opcode.class()) {
const first_operand_offset: u32 = @intCast(inst.offset + 1);
for (id_offsets.items) |offset| {
operand_is_id.set(first_operand_offset + offset);
}

if (!canDeduplicate(inst.opcode)) continue;

const result_id_index: u16 = switch (inst.opcode.class()) {
.TypeDeclaration, .Annotation, .Debug => 0,
.ConstantCreation => 1,
else => unreachable,
};

const result_id: ResultId = @enumFromInt(inst.operands[id_offsets.items[result_id_index]]);

const first_child: u32 = @intCast(entity_children.items.len);
const first_extra_data: u32 = @intCast(extra_data.items.len);

try entity_children.ensureUnusedCapacity(id_offsets.items.len - 1);
try extra_data.ensureUnusedCapacity(inst.operands.len - id_offsets.items.len);

var id_i: usize = 0;
for (inst.operands, 0..) |operand, i| {
assert(id_i == id_offsets.items.len or id_offsets.items[id_i] >= i);
if (id_i != id_offsets.items.len and id_offsets.items[id_i] == i) {
// Skip .IdResult / .IdResultType.
if (id_i != result_id_index) {
entity_children.appendAssumeCapacity(@enumFromInt(operand));
}
id_i += 1;
} else {
// Non-id operand, add it to extra data.
extra_data.appendAssumeCapacity(operand);
}
}

switch (inst.opcode.class()) {
.Annotation, .Debug => {
// TODO
Expand All @@ -113,8 +95,9 @@ const ModuleInfo = struct {
}
entry.value_ptr.* = .{
.kind = inst.opcode,
.first_child = first_child,
.first_extra_data = first_extra_data,
.first_operand = first_operand_offset,
.num_operands = @intCast(inst.operands.len),
.result_id_index = result_id_index,
};
},
else => unreachable,
Expand All @@ -123,48 +106,17 @@ const ModuleInfo = struct {

return ModuleInfo{
.entities = entities.unmanaged,
.entity_children = entity_children.items,
.extra_data = extra_data.items,
.operand_is_id = operand_is_id,
};
}

/// Fetch a slice of children for the index corresponding to an entity.
fn childrenByIndex(self: ModuleInfo, index: usize) []const ResultId {
const values = self.entities.values();
const first_child = values[index].first_child;
if (index == values.len - 1) {
return self.entity_children[first_child..];
} else {
const next_first_child = values[index + 1].first_child;
return self.entity_children[first_child..next_first_child];
}
}

/// Fetch the slice of extra-data for the index corresponding to an entity.
fn extraDataByIndex(self: ModuleInfo, index: usize) []const u32 {
const values = self.entities.values();
const first_extra_data = values[index].first_extra_data;
if (index == values.len - 1) {
return self.extra_data[first_extra_data..];
} else {
const next_extra_data = values[index + 1].first_extra_data;
return self.extra_data[first_extra_data..next_extra_data];
}
}
};

const EntityContext = struct {
a: Allocator,
ptr_map_a: std.AutoArrayHashMapUnmanaged(ResultId, void) = .{},
ptr_map_b: std.AutoArrayHashMapUnmanaged(ResultId, void) = .{},
info: *const ModuleInfo,

fn init(a: Allocator, info: *const ModuleInfo) EntityContext {
return .{
.a = a,
.info = info,
};
}
binary: *const BinaryModule,

fn deinit(self: *EntityContext) void {
self.ptr_map_a.deinit(self.a);
Expand Down Expand Up @@ -203,14 +155,19 @@ const EntityContext = struct {
}
}

// Hash extra data
for (self.info.extraDataByIndex(index)) |data| {
std.hash.autoHash(hasher, data);
}

// Hash children
for (self.info.childrenByIndex(index)) |child| {
try self.hashInner(hasher, child);
// Process operands
const operands = self.binary.instructions[entity.first_operand..][0..entity.num_operands];
for (operands, 0..) |operand, i| {
if (i == entity.result_id_index) {
// Not relevant, skip...
continue;
} else if (self.info.operand_is_id.isSet(entity.first_operand + i)) {
// Operand is ID
try self.hashInner(hasher, @enumFromInt(operand));
} else {
// Operand is merely data
std.hash.autoHash(hasher, operand);
}
}
}

Expand All @@ -228,7 +185,11 @@ const EntityContext = struct {
const entity_a = self.info.entities.values()[index_a];
const entity_b = self.info.entities.values()[index_b];

if (entity_a.kind != entity_b.kind) return false;
// Two entities can only be equal when they are the same kind of
// instruction and their result-id sits at the same operand index.
// Compare a against b: comparing a field against itself is always
// equal and would turn this check into dead code.
if (entity_a.kind != entity_b.kind) {
return false;
} else if (entity_a.result_id_index != entity_b.result_id_index) {
return false;
}

if (entity_a.kind == .OpTypePointer) {
// May be a forward reference, or should be saved as a potential
Expand All @@ -246,18 +207,28 @@ const EntityContext = struct {
}
}

// Check if extra data is the same.
if (!std.mem.eql(u32, self.info.extraDataByIndex(index_a), self.info.extraDataByIndex(index_b))) {
const operands_a = self.binary.instructions[entity_a.first_operand..][0..entity_a.num_operands];
const operands_b = self.binary.instructions[entity_b.first_operand..][0..entity_b.num_operands];

// Note: returns false for operands that have explicit defaults in optional operands... oh well
if (operands_a.len != operands_b.len) {
return false;
}

// Recursively check if children are the same
const children_a = self.info.childrenByIndex(index_a);
const children_b = self.info.childrenByIndex(index_b);
if (children_a.len != children_b.len) return false;

for (children_a, children_b) |child_a, child_b| {
if (!try self.eqlInner(child_a, child_b)) {
for (operands_a, operands_b, 0..) |operand_a, operand_b, i| {
const a_is_id = self.info.operand_is_id.isSet(entity_a.first_operand + i);
const b_is_id = self.info.operand_is_id.isSet(entity_b.first_operand + i);
if (a_is_id != b_is_id) {
return false;
} else if (i == entity_a.result_id_index) {
// result-id for both...
continue;
} else if (a_is_id) {
// Both are IDs, so recurse.
if (!try self.eqlInner(@enumFromInt(operand_a), @enumFromInt(operand_b))) {
return false;
}
} else if (operand_a != operand_b) {
return false;
}
}
Expand Down Expand Up @@ -290,11 +261,13 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void {

const info = try ModuleInfo.parse(a, parser, binary.*);
log.info("added {} entities", .{info.entities.count()});
log.info("children size: {}", .{info.entity_children.len});
log.info("extra data size: {}", .{info.extra_data.len});

// Hash all keys once so that the maps can be allocated the right size.
var ctx = EntityContext.init(a, &info);
var ctx = EntityContext{
.a = a,
.info = &info,
.binary = binary,
};
for (info.entities.keys()) |id| {
_ = try ctx.hash(id);
}
Expand All @@ -318,7 +291,6 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void {
// Now process the module, and replace instructions where needed.
var section = Section{};
var it = binary.iterateInstructions();
var id_offsets = std.ArrayList(u16).init(a);
var new_functions_section: ?usize = null;
var new_operands = std.ArrayList(u32).init(a);
var emitted_ptrs = std.AutoHashMap(ResultId, void).init(a);
Expand Down Expand Up @@ -347,38 +319,31 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void {

// Re-emit the instruction, but replace all the IDs.

id_offsets.items.len = 0;
try parser.parseInstructionResultIds(binary.*, inst, &id_offsets);

new_operands.items.len = 0;
try new_operands.appendSlice(inst.operands);
for (id_offsets.items) |offset| {
{
const id: ResultId = @enumFromInt(inst.operands[offset]);
if (replace.get(id)) |new_id| {
new_operands.items[offset] = @intFromEnum(new_id);
}

for (new_operands.items, 0..) |*operand, i| {
const is_id = info.operand_is_id.isSet(inst.offset + 1 + i);
if (!is_id) continue;

if (replace.get(@enumFromInt(operand.*))) |new_id| {
operand.* = @intFromEnum(new_id);
}

// TODO: Does this logic work? Maybe it will emit an OpTypeForwardPointer to
// something that's not a struct...
// It seems to work correctly on behavior.zig at least
const id: ResultId = @enumFromInt(new_operands.items[offset]);
const id: ResultId = @enumFromInt(operand.*);
// TODO: This test is a little janky. Check the offset instead?
if (maybe_result_id == null or maybe_result_id.? != id) {
const index = info.entities.getIndex(id) orelse continue;
const entity = info.entities.values()[index];
if (entity.kind == .OpTypePointer) {
if (!emitted_ptrs.contains(id)) {
// The storage class is in the extra data
// TODO: This is kind of hacky...
const extra_data = info.extraDataByIndex(index);
const storage_class: spec.StorageClass = @enumFromInt(extra_data[0]);
try section.emit(a, .OpTypeForwardPointer, .{
.pointer_type = id,
.storage_class = storage_class,
});
try emitted_ptrs.put(id, {});
}
if (entity.kind == .OpTypePointer and !emitted_ptrs.contains(id)) {
// Grab the pointer's storage class from its operands in the original
// module.
const storage_class: spec.StorageClass = @enumFromInt(binary.instructions[entity.first_operand + 1]);
try section.emit(a, .OpTypeForwardPointer, .{
.pointer_type = id,
.storage_class = storage_class,
});
try emitted_ptrs.put(id, {});
}
}
}
Expand Down

0 comments on commit 6209d85

Please sign in to comment.