diff --git a/include/shady/ir.h b/include/shady/ir.h index 216102eaa..ab7e684e5 100644 --- a/include/shady/ir.h +++ b/include/shady/ir.h @@ -209,6 +209,10 @@ typedef struct CompilerConfig_ { bool simt_to_explicit_simd; } lower; + struct { + bool spv_shuffle_instead_of_broadcast_first; + } hacks; + struct { bool memory_accesses; bool stack_accesses; diff --git a/src/runtime/runtime_program.c b/src/runtime/runtime_program.c index b31f87ff7..dfd1a7aad 100644 --- a/src/runtime/runtime_program.c +++ b/src/runtime/runtime_program.c @@ -94,8 +94,14 @@ static CompilerConfig get_compiler_config_for_device(Device* device) { if (!device->caps.features.subgroup_extended_types.shaderSubgroupExtendedTypes) config.lower.emulate_subgroup_ops_extended_types = true; - if (device->caps.implementation.is_moltenvk) + if (device->caps.implementation.is_moltenvk) { + warn_print("Hack: MoltenVK says they supported subgroup extended types, but it's a lie. 64-bit types are unaccounted for !\n"); config.lower.emulate_subgroup_ops_extended_types = true; + } + if (device->caps.base_properties.vendorID == 0x10de) { + warn_print("Hack: NVidia somehow has unreliable broadcast_first. Emulating it with shuffles seemingly fixes the issue.\n"); + config.hacks.spv_shuffle_instead_of_broadcast_first = true; + } config.logging.skip_generated = true; config.logging.skip_builtin = true; diff --git a/src/shady/emit/spirv/emit_spv.c b/src/shady/emit/spirv/emit_spv.c index c6ebeecca..b0bee0254 100644 --- a/src/shady/emit/spirv/emit_spv.c +++ b/src/shady/emit/spirv/emit_spv.c @@ -471,6 +471,10 @@ void emit_spirv(CompilerConfig* config, Module* mod, size_t* output_size, char** spvb_capability(file_builder, SpvCapabilityGroupNonUniformBallot); spvb_capability(file_builder, SpvCapabilityGroupNonUniformArithmetic); + // TODO track capabilities properly + if (emitter.configuration->hacks.spv_shuffle_instead_of_broadcast_first) + spvb_capability(file_builder, SpvCapabilityGroupNonUniformShuffle); + spvb_finish(file_builder, words); // cleanup the emitter diff --git a/src/shady/emit/spirv/emit_spv_instructions.c b/src/shady/emit/spirv/emit_spv_instructions.c index 48969900d..9f6c77d1e 100644 --- a/src/shady/emit/spirv/emit_spv_instructions.c +++ b/src/shady/emit/spirv/emit_spv_instructions.c @@ -131,11 +131,11 @@ static SpvOp get_opcode(Emitter* emitter, struct IselTableEntry entry, Nodes arg } static void emit_primop(Emitter* emitter, FnBuilder fn_builder, BBBuilder bb_builder, const Node* instr, size_t results_count, SpvId results[]) { - PrimOp prim_op = instr->payload.prim_op; - Nodes args = prim_op.operands; - Nodes type_arguments = prim_op.type_arguments; + PrimOp the_op = instr->payload.prim_op; + Nodes args = the_op.operands; + Nodes type_arguments = the_op.type_arguments; - struct IselTableEntry entry = isel_table[prim_op.op]; + struct IselTableEntry entry = isel_table[the_op.op]; if (entry.class != Custom) { LARRAY(SpvId, emitted_args, args.count); for (size_t i = 0; i < args.count; i++) @@ -176,7 +176,7 @@ static void emit_primop(Emitter* emitter, FnBuilder fn_builder, BBBuilder bb_bui return; } - switch (prim_op.op) { + switch (the_op.op) { case subgroup_ballot_op: { const Type* i32x4 = pack_type(emitter->arena, (PackType) { .width = 4, .element_type = int32_type(emitter->arena) }); SpvId scope_subgroup = emit_value(emitter, bb_builder, int32_literal(emitter->arena, SpvScopeSubgroup)); @@ -193,7 +193,16 @@ static void emit_primop(Emitter* emitter, FnBuilder fn_builder, BBBuilder bb_bui } case subgroup_broadcast_first_op: { SpvId scope_subgroup = emit_value(emitter, bb_builder, int32_literal(emitter->arena, SpvScopeSubgroup)); - SpvId result = spvb_broadcast_first(bb_builder, emit_type(emitter, get_unqualified_type(first(args)->type)), emit_value(emitter, bb_builder, first(args)), scope_subgroup); + SpvId result; + + if (emitter->configuration->hacks.spv_shuffle_instead_of_broadcast_first) { + SpvId local_id; + emit_primop(emitter, fn_builder, bb_builder, prim_op(emitter->arena, (PrimOp) { .op = subgroup_local_id_op }), 1, &local_id); + result = spvb_shuffle(bb_builder, emit_type(emitter, get_unqualified_type(first(args)->type)), scope_subgroup, emit_value(emitter, bb_builder, first(args)), local_id); + } else { + result = spvb_broadcast_first(bb_builder, emit_type(emitter, get_unqualified_type(first(args)->type)), emit_value(emitter, bb_builder, first(args)), scope_subgroup); + } + assert(results_count == 1); results[0] = result; return; diff --git a/src/shady/emit/spirv/spirv_builder.c b/src/shady/emit/spirv/spirv_builder.c index ae2e85f13..b9b868d00 100644 --- a/src/shady/emit/spirv/spirv_builder.c +++ b/src/shady/emit/spirv/spirv_builder.c @@ -302,6 +302,17 @@ SpvId spvb_broadcast_first(struct SpvBasicBlockBuilder* bb_builder, SpvId result return id; } +SpvId spvb_shuffle(struct SpvBasicBlockBuilder* bb_builder, SpvId result_type, SpvId scope, SpvId value, SpvId id) { + op(SpvOpGroupNonUniformShuffle, 6); + SpvId rid = spvb_fresh_id(bb_builder->fn_builder->file_builder); + ref_id(result_type); + ref_id(rid); + ref_id(scope); + ref_id(value); + ref_id(id); + return rid; +} + SpvId spvb_non_uniform_iadd(struct SpvBasicBlockBuilder* bb_builder, SpvId result_type, SpvId value, SpvId scope, SpvGroupOperation group_op, SpvId* cluster_size) { op(SpvOpGroupNonUniformIAdd, cluster_size ? 7 : 6); SpvId id = spvb_fresh_id(bb_builder->fn_builder->file_builder); diff --git a/src/shady/emit/spirv/spirv_builder.h b/src/shady/emit/spirv/spirv_builder.h index 1e1d470fd..139f6c0d8 100644 --- a/src/shady/emit/spirv/spirv_builder.h +++ b/src/shady/emit/spirv/spirv_builder.h @@ -29,6 +29,7 @@ SpvId spvb_binop(struct SpvBasicBlockBuilder* bb_builder, SpvOp op, SpvId result SpvId spvb_unop(struct SpvBasicBlockBuilder* bb_builder, SpvOp op, SpvId result_type, SpvId value); SpvId spvb_elect(struct SpvBasicBlockBuilder* bb_builder, SpvId result_type, SpvId scope); SpvId spvb_ballot(struct SpvBasicBlockBuilder*, SpvId result_t, SpvId predicate, SpvId scope); +SpvId spvb_shuffle(struct SpvBasicBlockBuilder* bb_builder, SpvId result_type, SpvId scope, SpvId value, SpvId id); SpvId spvb_broadcast_first(struct SpvBasicBlockBuilder*, SpvId result_t, SpvId value, SpvId scope); SpvId spvb_non_uniform_iadd(struct SpvBasicBlockBuilder*, SpvId result_t, SpvId value, SpvId scope, SpvGroupOperation group_op, SpvId* cluster_size); diff --git a/src/shady/passes/lower_tailcalls.c b/src/shady/passes/lower_tailcalls.c index aa1378317..24c4c6d31 100644 --- a/src/shady/passes/lower_tailcalls.c +++ b/src/shady/passes/lower_tailcalls.c @@ -219,7 +219,7 @@ void generate_top_level_dispatch_fn(Context* ctx) { if (ctx->config->printf_trace.god_function) { if (count_iterations) - bind_instruction(loop_body_builder, prim_op(dst_arena, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(dst_arena, string_lit(dst_arena, (StringLiteral) { .string = "trace: top loop, lid=%d iteration=%d next_fn=%d next_mask=%x\n" }), local_id, iterations_count_param, next_function, next_mask) })); + bind_instruction(loop_body_builder, prim_op(dst_arena, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(dst_arena, string_lit(dst_arena, (StringLiteral) { .string = "trace: top loop, lid=%d iteration=%d next_fn=%d next_mask=%lx\n" }), local_id, iterations_count_param, next_function, next_mask) })); else bind_instruction(loop_body_builder, prim_op(dst_arena, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(dst_arena, string_lit(dst_arena, (StringLiteral) { .string = "trace: top loop, lid=%d next_fn=%d next_mask=%x\n" }), local_id, next_function, next_mask) })); }