Skip to content

Commit

Permalink
Change in generate_instances.py so that this scripts can be called fr…
Browse files Browse the repository at this point in the history
…om flexible location
  • Loading branch information
qianfengz committed Aug 20, 2024
1 parent 1a73f34 commit d293caf
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions xformers/csrc/attention/hip_fmha/generate_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
"{has_or_no_bias_str}_{has_or_no_biasgrad_str}_{has_or_no_dropout_str}_{max_k_str}.cpp"
)

FMHA_INSTANCE_REF_FNAME = "instances/fmha_{mode}_{function}_{dtype}_instances_ref.h"
FMHA_INSTANCE_REF_FNAME = "fmha_{mode}_{function}_{dtype}_instances_ref.h"

BOOL_MAP = {True: "true", False: "false"}

Expand Down Expand Up @@ -174,11 +174,12 @@ def create_infer_instances_ref(instance_dir: Path, headdims: List) -> None:
function="infer",
dtype=dtype,
)
ref_fname_path = instance_dir / ref_fname
infer_instance_inc = FMHA_INFER_INSTANCE_TEMPLATE_INC.format(
mode=mode,
dtype_file=TYPE_FNAME_MAP[dtype],
)
with open(ref_fname, "a") as file:
with open(ref_fname_path, "a") as file:
file.write(FMHA_COPYRIGHT_HEADER)
file.write(infer_instance_inc)
for max_k in headdims:
Expand Down Expand Up @@ -246,11 +247,12 @@ def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None:
function="forward",
dtype=dtype,
)
ref_fname_path = instance_dir / ref_fname
forward_instance_inc = FMHA_FORWARD_INSTANCE_TEMPLATE_INC.format(
mode=mode,
dtype_file=TYPE_FNAME_MAP[dtype],
)
with open(ref_fname, "a") as file:
with open(ref_fname_path, "a") as file:
file.write(FMHA_COPYRIGHT_HEADER)
file.write(forward_instance_inc)
for max_k in headdims:
Expand Down Expand Up @@ -326,11 +328,12 @@ def create_backward_instances_ref(instance_dir: Path, headdims: List) -> None:
function="backward",
dtype=dtype,
)
ref_fname_path = instance_dir / ref_fname
backward_instance_inc = FMHA_BACKWARD_INSTANCE_TEMPLATE_INC.format(
mode=mode,
dtype_file=TYPE_FNAME_MAP[dtype],
)
with open(ref_fname, "a") as file:
with open(ref_fname_path, "a") as file:
file.write(FMHA_COPYRIGHT_HEADER)
file.write(backward_instance_inc)
for max_k in headdims:
Expand Down

0 comments on commit d293caf

Please sign in to comment.