Skip to content

Commit

Permalink
additional changes from comments
Browse files Browse the repository at this point in the history
  • Loading branch information
achew010 committed Jul 29, 2024
1 parent 8895cad commit f6848a7
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
4 changes: 2 additions & 2 deletions plugins/framework/src/fms_acceleration/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __post_init__(self):
self.forward is not None,
self.forward_builder is not None,
self.import_and_maybe_reload is not None,
]) > 1:
]) != 1:
raise ValueError(
f"Rule '{self.rule_id}' must only have only one of forward, "
"foward builder, or import_and_maybe_reload, specified."
Expand Down Expand Up @@ -332,7 +332,7 @@ def _import_and_reload(model: torch.nn.Module):

# If there are multiple reload targets,
# ensure that their paths do not conflict as reloading same module might reset patches
if len(_with_reload)>1:
if len(_with_reload) > 1:
# sort ascending target path length
_with_reload = sorted(
_with_reload,
Expand Down
7 changes: 4 additions & 3 deletions scripts/benchmarks/compare_with_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,10 @@ def main(
columns={"self": "new", "other": "ref"}, level=-1
)
diff = diff[diff.index.isin([outlier for outlier in outliers])]
outliers_df = outliers_df.set_index(indices).merge(
diff, left_index=True, right_index=True
)
if not diff.empty:
outliers_df = outliers_df.set_index(indices).merge(
diff, left_index=True, right_index=True
)
outliers_df.to_csv(os.path.join(result_dir, OUTLIERS_FILENAME))
for chart, filename in charts:
chart.figure.savefig(os.path.join(result_dir, filename))
Expand Down
4 changes: 4 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ deps =
-e {toxinidir}/plugins/framework # install the framework here as the flash attention deps requires torch
passenv = * # will pass the parent env, otherwise there are too many envs e.g. TRANSFORMERS that need to be set
setenv =
# Need to be set in new versions of triton that don't allow for access to global variable in the JIT compile
# Subsequently, consider changing triton kernels to access global variables that are annotated as constexpr
# source: https://github.com/triton-lang/triton/blob/7b617bcc35c4cf06f61dd267fc049fe33b2851f9/python/triton/compiler/code_generator.py#L280
# Tracking this as an issue here # https://github.com/foundation-model-stack/fms-acceleration/issues/56
TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1
commands =
# need a version of fms-hf-tuning that has integrated the framework
Expand Down

0 comments on commit f6848a7

Please sign in to comment.