Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

spirv: deduplication pass #19490

Merged
merged 5 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions src/codegen/spirv.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2332,6 +2332,9 @@ const DeclGen = struct {

.mul_add => try self.airMulAdd(inst),

.ctz => try self.airClzCtz(inst, .ctz),
.clz => try self.airClzCtz(inst, .clz),

.splat => try self.airSplat(inst),
.reduce, .reduce_optimized => try self.airReduce(inst),
.shuffle => try self.airShuffle(inst),
Expand Down Expand Up @@ -3029,6 +3032,83 @@ const DeclGen = struct {
return try wip.finalize();
}

fn airClzCtz(self: *DeclGen, inst: Air.Inst.Index, op: enum { clz, ctz }) !?IdRef {
if (self.liveness.isUnused(inst)) return null;

const mod = self.module;
const target = self.getTarget();
const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
const result_ty = self.typeOfIndex(inst);
const operand_ty = self.typeOf(ty_op.operand);
const operand = try self.resolve(ty_op.operand);

const info = self.arithmeticTypeInfo(operand_ty);
switch (info.class) {
.composite_integer => unreachable, // TODO
.integer, .strange_integer => {},
.float, .bool => unreachable,
}

var wip = try self.elementWise(result_ty, false);
defer wip.deinit();

const elem_ty = if (wip.is_array) operand_ty.scalarType(mod) else operand_ty;
const elem_ty_ref = try self.resolveType(elem_ty, .direct);
const elem_ty_id = self.typeId(elem_ty_ref);

for (wip.results, 0..) |*result_id, i| {
const elem = try wip.elementAt(operand_ty, operand, i);

switch (target.os.tag) {
.opencl => {
const set = try self.spv.importInstructionSet(.@"OpenCL.std");
const ext_inst: u32 = switch (op) {
.clz => 151, // clz
.ctz => 152, // ctz
};

// Note: result of OpenCL ctz/clz returns operand_ty, and we want result_ty.
// result_ty is always large enough to hold the result, so we might have to down
// cast it.
const tmp = self.spv.allocId();
try self.func.body.emit(self.spv.gpa, .OpExtInst, .{
.id_result_type = elem_ty_id,
.id_result = tmp,
.set = set,
.instruction = .{ .inst = ext_inst },
.id_ref_4 = &.{elem},
});

if (wip.ty_id == elem_ty_id) {
result_id.* = tmp;
continue;
}

result_id.* = self.spv.allocId();
if (result_ty.scalarType(mod).isSignedInt(mod)) {
assert(elem_ty.scalarType(mod).isSignedInt(mod));
try self.func.body.emit(self.spv.gpa, .OpSConvert, .{
.id_result_type = wip.ty_id,
.id_result = result_id.*,
.signed_value = tmp,
});
} else {
assert(elem_ty.scalarType(mod).isUnsignedInt(mod));
try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
.id_result_type = wip.ty_id,
.id_result = result_id.*,
.unsigned_value = tmp,
});
}
},
.vulkan => unreachable, // TODO
else => unreachable,
}
}

return try wip.finalize();
}

fn airSplat(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
const operand_id = try self.resolve(ty_op.operand);
Expand Down
2 changes: 2 additions & 0 deletions src/link/SpirV.zig
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,15 @@ fn linkModule(self: *SpirV, a: Allocator, module: []Word) ![]Word {

const lower_invocation_globals = @import("SpirV/lower_invocation_globals.zig");
const prune_unused = @import("SpirV/prune_unused.zig");
const dedup = @import("SpirV/deduplicate.zig");

var parser = try BinaryModule.Parser.init(a);
defer parser.deinit();
var binary = try parser.parse(module);

try lower_invocation_globals.run(&parser, &binary);
try prune_unused.run(&parser, &binary);
try dedup.run(&parser, &binary);

return binary.finalize(a);
}
Expand Down
2 changes: 2 additions & 0 deletions src/link/SpirV/BinaryModule.zig
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ pub const ParseError = error{
DuplicateId,
/// Some ID did not resolve.
InvalidId,
/// This opcode or instruction is not supported yet.
UnsupportedOperation,
/// Parser ran out of memory.
OutOfMemory,
};
Expand Down
Loading