Skip to content

Commit

Permalink
Try fix ck decoder compilation in fbcode (2/n) (#1012)
Browse files Browse the repository at this point in the history
* fix includes in decoder and splitk decoder

* resolve compilation error by adding -Werror to cpp extensions compiler flags

and fixing a hidden override
  • Loading branch information
tenpercent authored Apr 2, 2024
1 parent 7fffd3d commit 09773f8
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 8 deletions.
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ def get_extensions():
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-DCK_FMHA_FWD_FAST_EXP2=1",
"-fgpu-flush-denormals-to-zero",
"-Werror",
"-Woverloaded-virtual",
]
+ generator_flag
+ cc_flag,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ at::Tensor& efficient_attention_forward_decoder_ck_out_impl(
lds_bytes);

auto invoker = device_op_t::Invoker{};
(void)invoker.Run(arg, {stream});
(void)invoker.Run(&arg, {stream});
});

return O;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ at::Tensor& efficient_attention_forward_decoder_splitk_ck_out_impl(
lds_bytes);

auto invoker = device_op_t::Invoker{};
(void)invoker.Run(arg, {stream});
(void)invoker.Run(&arg, {stream});
});

return O;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,9 +432,11 @@ struct FMHADecoderSeqlen1DeviceOp : public BaseOperator {

struct Invoker : public BaseInvoker {
using Argument = DeviceOp::Argument;
float Run(
const Argument& arg,
const StreamConfig& stream_config = StreamConfig{}) {
virtual float Run(
const BaseArgument* base_arg,
const StreamConfig& stream_config = StreamConfig{}) override {
// copy so it's alive while being used
auto arg = *dynamic_cast<const Argument*>(base_arg);
auto threads_per_wavefront = arg.block_dim.x;

auto Q_size_k_alignment_necessary = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,12 @@ struct FMHADecoderSplitKDeviceOp : public BaseOperator {

struct Invoker : public BaseInvoker {
using Argument = DeviceOp::Argument;
float Run(
const Argument& arg,
const StreamConfig& stream_config = StreamConfig{}) {
virtual float Run(
const BaseArgument* base_arg,
const StreamConfig& stream_config = StreamConfig{}) override {
// copy so it's alive while being used
auto arg = *dynamic_cast<const Argument*>(base_arg);

auto threads_per_wavefront = arg.block_dim.x;
auto Q_size_k_alignment_necessary = 0;

Expand Down

0 comments on commit 09773f8

Please sign in to comment.