From 5d55792ba2b1bb807f75a43432c519278fc5d22a Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Sat, 30 Mar 2024 18:30:28 +0100 Subject: [PATCH] spirv: clz, ctz for opencl This instruction seems common in compiler_rt. --- src/codegen/spirv.zig | 80 ++++++++++++++++++++++++++++++++++++++++++ test/behavior/math.zig | 3 -- 2 files changed, 80 insertions(+), 3 deletions(-) diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 6b13f2623aa3..9113d72d927d 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -2332,6 +2332,9 @@ const DeclGen = struct { .mul_add => try self.airMulAdd(inst), + .ctz => try self.airClzCtz(inst, .ctz), + .clz => try self.airClzCtz(inst, .clz), + .splat => try self.airSplat(inst), .reduce, .reduce_optimized => try self.airReduce(inst), .shuffle => try self.airShuffle(inst), @@ -3029,6 +3032,83 @@ const DeclGen = struct { return try wip.finalize(); } + fn airClzCtz(self: *DeclGen, inst: Air.Inst.Index, op: enum { clz, ctz }) !?IdRef { + if (self.liveness.isUnused(inst)) return null; + + const mod = self.module; + const target = self.getTarget(); + const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; + const result_ty = self.typeOfIndex(inst); + const operand_ty = self.typeOf(ty_op.operand); + const operand = try self.resolve(ty_op.operand); + + const info = self.arithmeticTypeInfo(operand_ty); + switch (info.class) { + .composite_integer => unreachable, // TODO + .integer, .strange_integer => {}, + .float, .bool => unreachable, + } + + var wip = try self.elementWise(result_ty, false); + defer wip.deinit(); + + const elem_ty = if (wip.is_array) operand_ty.scalarType(mod) else operand_ty; + const elem_ty_ref = try self.resolveType(elem_ty, .direct); + const elem_ty_id = self.typeId(elem_ty_ref); + + for (wip.results, 0..) |*result_id, i| { + const elem = try wip.elementAt(operand_ty, operand, i); + + switch (target.os.tag) { + .opencl => { + const set = try self.spv.importInstructionSet(.@"OpenCL.std"); + const ext_inst: u32 = switch (op) { + .clz => 151, // clz + .ctz => 152, // ctz + }; + + // Note: result of OpenCL ctz/clz returns operand_ty, and we want result_ty. + // result_ty is always large enough to hold the result, so we might have to down + // cast it. + const tmp = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ + .id_result_type = elem_ty_id, + .id_result = tmp, + .set = set, + .instruction = .{ .inst = ext_inst }, + .id_ref_4 = &.{elem}, + }); + + if (wip.ty_id == elem_ty_id) { + result_id.* = tmp; + continue; + } + + result_id.* = self.spv.allocId(); + if (result_ty.scalarType(mod).isSignedInt(mod)) { + assert(elem_ty.scalarType(mod).isSignedInt(mod)); + try self.func.body.emit(self.spv.gpa, .OpSConvert, .{ + .id_result_type = wip.ty_id, + .id_result = result_id.*, + .signed_value = tmp, + }); + } else { + assert(elem_ty.scalarType(mod).isUnsignedInt(mod)); + try self.func.body.emit(self.spv.gpa, .OpUConvert, .{ + .id_result_type = wip.ty_id, + .id_result = result_id.*, + .unsigned_value = tmp, + }); + } + }, + .vulkan => unreachable, // TODO + else => unreachable, + } + } + + return try wip.finalize(); + } + fn airSplat(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; const operand_id = try self.resolve(ty_op.operand); diff --git a/test/behavior/math.zig b/test/behavior/math.zig index 092924af06ed..fbd8369219c6 100644 --- a/test/behavior/math.zig +++ b/test/behavior/math.zig @@ -65,7 +65,6 @@ test "@clz" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try testClz(); try comptime testClz(); @@ -148,7 +147,6 @@ test "@ctz" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try testCtz(); try comptime testCtz(); @@ -1752,7 +1750,6 @@ test "@clz works on both vector and scalar inputs" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; var x: u32 = 0x1; _ = &x;