Skip to content

Commit

Permalink
hack: use shuffles for broadcast_first on Nvidia
Browse files Browse the repository at this point in the history
  • Loading branch information
Hugobros3 committed Jan 8, 2023
1 parent 6a9ad5f commit f3ef83d
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 8 deletions.
4 changes: 4 additions & 0 deletions include/shady/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ typedef struct CompilerConfig_ {
bool simt_to_explicit_simd;
} lower;

struct {
bool spv_shuffle_instead_of_broadcast_first;
} hacks;

struct {
bool memory_accesses;
bool stack_accesses;
Expand Down
8 changes: 7 additions & 1 deletion src/runtime/runtime_program.c
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,14 @@ static CompilerConfig get_compiler_config_for_device(Device* device) {
if (!device->caps.features.subgroup_extended_types.shaderSubgroupExtendedTypes)
config.lower.emulate_subgroup_ops_extended_types = true;

if (device->caps.implementation.is_moltenvk)
if (device->caps.implementation.is_moltenvk) {
warn_print("Hack: MoltenVK says they supported subgroup extended types, but it's a lie. 64-bit types are unaccounted for !\n");
config.lower.emulate_subgroup_ops_extended_types = true;
}
if (device->caps.base_properties.vendorID == 0x10de) {
warn_print("Hack: NVidia somehow has unreliable broadcast_first. Emulating it with shuffles seemingly fixes the issue.\n");
config.hacks.spv_shuffle_instead_of_broadcast_first = true;
}

config.logging.skip_generated = true;
config.logging.skip_builtin = true;
Expand Down
4 changes: 4 additions & 0 deletions src/shady/emit/spirv/emit_spv.c
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ void emit_spirv(CompilerConfig* config, Module* mod, size_t* output_size, char**
spvb_capability(file_builder, SpvCapabilityGroupNonUniformBallot);
spvb_capability(file_builder, SpvCapabilityGroupNonUniformArithmetic);

// TODO track capabilities properly
if (emitter.configuration->hacks.spv_shuffle_instead_of_broadcast_first)
spvb_capability(file_builder, SpvCapabilityGroupNonUniformShuffle);

spvb_finish(file_builder, words);

// cleanup the emitter
Expand Down
21 changes: 15 additions & 6 deletions src/shady/emit/spirv/emit_spv_instructions.c
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ static SpvOp get_opcode(Emitter* emitter, struct IselTableEntry entry, Nodes arg
}

static void emit_primop(Emitter* emitter, FnBuilder fn_builder, BBBuilder bb_builder, const Node* instr, size_t results_count, SpvId results[]) {
PrimOp prim_op = instr->payload.prim_op;
Nodes args = prim_op.operands;
Nodes type_arguments = prim_op.type_arguments;
PrimOp the_op = instr->payload.prim_op;
Nodes args = the_op.operands;
Nodes type_arguments = the_op.type_arguments;

struct IselTableEntry entry = isel_table[prim_op.op];
struct IselTableEntry entry = isel_table[the_op.op];
if (entry.class != Custom) {
LARRAY(SpvId, emitted_args, args.count);
for (size_t i = 0; i < args.count; i++)
Expand Down Expand Up @@ -176,7 +176,7 @@ static void emit_primop(Emitter* emitter, FnBuilder fn_builder, BBBuilder bb_bui

return;
}
switch (prim_op.op) {
switch (the_op.op) {
case subgroup_ballot_op: {
const Type* i32x4 = pack_type(emitter->arena, (PackType) { .width = 4, .element_type = int32_type(emitter->arena) });
SpvId scope_subgroup = emit_value(emitter, bb_builder, int32_literal(emitter->arena, SpvScopeSubgroup));
Expand All @@ -193,7 +193,16 @@ static void emit_primop(Emitter* emitter, FnBuilder fn_builder, BBBuilder bb_bui
}
case subgroup_broadcast_first_op: {
SpvId scope_subgroup = emit_value(emitter, bb_builder, int32_literal(emitter->arena, SpvScopeSubgroup));
SpvId result = spvb_broadcast_first(bb_builder, emit_type(emitter, get_unqualified_type(first(args)->type)), emit_value(emitter, bb_builder, first(args)), scope_subgroup);
SpvId result;

if (emitter->configuration->hacks.spv_shuffle_instead_of_broadcast_first) {
SpvId local_id;
emit_primop(emitter, fn_builder, bb_builder, prim_op(emitter->arena, (PrimOp) { .op = subgroup_local_id_op }), 1, &local_id);
result = spvb_shuffle(bb_builder, emit_type(emitter, get_unqualified_type(first(args)->type)), scope_subgroup, emit_value(emitter, bb_builder, first(args)), local_id);
} else {
result = spvb_broadcast_first(bb_builder, emit_type(emitter, get_unqualified_type(first(args)->type)), emit_value(emitter, bb_builder, first(args)), scope_subgroup);
}

assert(results_count == 1);
results[0] = result;
return;
Expand Down
11 changes: 11 additions & 0 deletions src/shady/emit/spirv/spirv_builder.c
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,17 @@ SpvId spvb_broadcast_first(struct SpvBasicBlockBuilder* bb_builder, SpvId result
return id;
}

// Emits an OpGroupNonUniformShuffle instruction through the given basic block builder
// and returns the id allocated for its result.
//
//   result_type: id of the type of the shuffled value (written first)
//   scope:       id of the execution scope operand
//   value:       id of the value operand
//   id:          id of the operand that, per the SPIR-V specification, selects the
//                invocation whose `value` is read
//
// The operands are written in the order Result Type, Result, Execution, Value, Id,
// which is the operand layout the SPIR-V specification gives for this opcode.
// Note that the parameter order here (scope before value) differs from
// spvb_broadcast_first, which takes value before scope.
// NOTE(review): `op` and `ref_id` are defined outside this hunk; presumably they append
// words to bb_builder's instruction stream in call order — confirm before reordering anything.
SpvId spvb_shuffle(struct SpvBasicBlockBuilder* bb_builder, SpvId result_type, SpvId scope, SpvId value, SpvId id) {
// 6 is presumably the total word count: the opcode word plus the five operands below — confirm against `op`.
op(SpvOpGroupNonUniformShuffle, 6);
// The result id comes from the file builder that owns this block's function builder.
SpvId rid = spvb_fresh_id(bb_builder->fn_builder->file_builder);
ref_id(result_type);
ref_id(rid);
ref_id(scope);
ref_id(value);
ref_id(id);
return rid;
}

SpvId spvb_non_uniform_iadd(struct SpvBasicBlockBuilder* bb_builder, SpvId result_type, SpvId value, SpvId scope, SpvGroupOperation group_op, SpvId* cluster_size) {
op(SpvOpGroupNonUniformIAdd, cluster_size ? 7 : 6);
SpvId id = spvb_fresh_id(bb_builder->fn_builder->file_builder);
Expand Down
1 change: 1 addition & 0 deletions src/shady/emit/spirv/spirv_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ SpvId spvb_binop(struct SpvBasicBlockBuilder* bb_builder, SpvOp op, SpvId result
SpvId spvb_unop(struct SpvBasicBlockBuilder* bb_builder, SpvOp op, SpvId result_type, SpvId value);
SpvId spvb_elect(struct SpvBasicBlockBuilder* bb_builder, SpvId result_type, SpvId scope);
SpvId spvb_ballot(struct SpvBasicBlockBuilder*, SpvId result_t, SpvId predicate, SpvId scope);
SpvId spvb_shuffle(struct SpvBasicBlockBuilder* bb_builder, SpvId result_type, SpvId scope, SpvId value, SpvId id);
SpvId spvb_broadcast_first(struct SpvBasicBlockBuilder*, SpvId result_t, SpvId value, SpvId scope);
SpvId spvb_non_uniform_iadd(struct SpvBasicBlockBuilder*, SpvId result_t, SpvId value, SpvId scope, SpvGroupOperation group_op, SpvId* cluster_size);

Expand Down
2 changes: 1 addition & 1 deletion src/shady/passes/lower_tailcalls.c
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ void generate_top_level_dispatch_fn(Context* ctx) {

if (ctx->config->printf_trace.god_function) {
if (count_iterations)
bind_instruction(loop_body_builder, prim_op(dst_arena, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(dst_arena, string_lit(dst_arena, (StringLiteral) { .string = "trace: top loop, lid=%d iteration=%d next_fn=%d next_mask=%x\n" }), local_id, iterations_count_param, next_function, next_mask) }));
bind_instruction(loop_body_builder, prim_op(dst_arena, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(dst_arena, string_lit(dst_arena, (StringLiteral) { .string = "trace: top loop, lid=%d iteration=%d next_fn=%d next_mask=%lx\n" }), local_id, iterations_count_param, next_function, next_mask) }));
else
bind_instruction(loop_body_builder, prim_op(dst_arena, (PrimOp) { .op = debug_printf_op, .operands = mk_nodes(dst_arena, string_lit(dst_arena, (StringLiteral) { .string = "trace: top loop, lid=%d next_fn=%d next_mask=%x\n" }), local_id, next_function, next_mask) }));
}
Expand Down

0 comments on commit f3ef83d

Please sign in to comment.