Refactored Model Patcher Class #55
Conversation
model = patch_model(model, base_type=self._base_layer)
# wrapper function to register foak patches
from fms_acceleration_foak.models import load_foak_patches
load_foak_patches(base_type=self._base_layer)
is there a better name than load_foak_patches? You need to have more comments:
- for what reason patching is required.
- have a more descriptive name like registering_model_patches_for_blah_blah
don't import here, as there is no need to guard this import; it can go at the top of the module.
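Putting both comments together, the call site might look like the sketch below; `register_foak_model_patch_rules` is a hypothetical rename, and the surrounding method is illustrative only:

```python
# Top of the module: a plain import, since nothing here needs guarding.
from fms_acceleration_foak.models import register_foak_model_patch_rules  # hypothetical name


def augmentation(self, model):
    # FOAK patches swap selected layers for fused/optimized kernels, so the
    # rules must be registered before patch_model walks the model tree.
    register_foak_model_patch_rules(base_type=self._base_layer)
    return patch_model(model, base_type=self._base_layer)
```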
@fabianlim This is the test plan that I'm using to write the unit test cases.

Test Plan

Objective: cover the behavior of
1. ModelPatcherHistory
2. ModelPatcherTrigger
3. ModelPatcherRule
4. ModelPatcher
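As one illustration of the kind of case this plan implies, a pytest sketch for ModelPatcherTrigger might look like the following; the import path and the `check`/`is_triggered` names are assumptions about the framework-plugin API, not the final interface:

```python
import torch

# Assumed import path once ModelPatcher lives in the framework plugin.
from fms_acceleration.model_patcher import ModelPatcherTrigger


def test_trigger_fires_only_on_matching_module_type():
    # Assumption: a trigger built from a module class reports a match
    # via an is_triggered-style check.
    trigger = ModelPatcherTrigger(check=torch.nn.Linear)
    assert trigger.is_triggered(torch.nn.Linear(4, 4))
    assert not trigger.is_triggered(torch.nn.LayerNorm(4))
```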
Force-pushed from 785b971 to a31bf6e
Force-pushed from 373e341 to 9438aba
Force-pushed from 02e0998 to 8c825d9
- rename test_model_patcher.py -> test_model_patcher_helpers.py
- rename test_model_patcher2.py -> test_model_patcher.py
- run an unassisted bench on 7b for related bench scenarios that use the model patcher.
… to allow triton access to global constexpr
Force-pushed from aff75fd to ac31192
assert len(_with_reload) <= 1, "can have at most one rule with reload"
# If there are multiple reload targets,
# ensure that their paths do not conflict, as reloading the same module might reset patches
if len(_with_reload)>1:
spacing
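On the substance of the check itself, the comment's intent (reloading a parent module can reset patches already applied inside a child) could be enforced with something like this sketch; `check_no_reload_conflicts` is a hypothetical helper, and `reload_paths` stands in for the module paths carried by `_with_reload`:

```python
def check_no_reload_conflicts(reload_paths):
    """Raise if one reload target contains another, since re-importing a
    parent module can reset patches already applied inside a child."""
    for i, path_a in enumerate(reload_paths):
        for path_b in reload_paths[i + 1:]:
            shorter, longer = sorted((path_a, path_b), key=len)
            if longer == shorter or longer.startswith(shorter + "."):
                raise ValueError(
                    f"conflicting reload targets: '{shorter}' overlaps '{longer}'"
                )


# e.g. ['pkg.models', 'pkg.models.llama'] raises; disjoint paths pass.
```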
packaging # this is required for the flash-attn dep, as fms_hf_tuning did not specify it
-e {toxinidir}/plugins/framework # install the framework here, as the flash attention deps require torch
passenv = * # pass the parent env; otherwise there are too many envs (e.g. TRANSFORMERS) that need to be set
setenv =
    TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1
can you put a link to where the documentation says this must be set?
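For background on why the variable is needed (a documentation link would still be better): recent Triton versions refuse to read plain Python globals from inside a @triton.jit function unless TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1 is set; annotating the global as tl.constexpr avoids the flag. A minimal sketch of the pattern, with an illustrative kernel and names:

```python
import triton
import triton.language as tl

BLOCK: tl.constexpr = 128  # annotated as constexpr, so no env var is needed
# BLOCK = 128  # a plain global like this is what trips the restriction


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)
```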
Force-pushed from c96c4c3 to 02a92e5
Force-pushed from 02a92e5 to f6848a7
Description
This PR addresses #44: it moves the ModelPatcher class to the Framework plugin so that all plugin patches can be managed and maintained under a common framework.
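Conceptually, each plugin would then register its patch rules against the shared class instead of keeping a private patcher; the sketch below assumes the import path and a register/rule-id style API (all names are illustrative, not the final interface):

```python
import torch

from fms_acceleration.model_patcher import (  # assumed post-move location
    ModelPatcher,
    ModelPatcherRule,
    ModelPatcherTrigger,
)

# A plugin declares what to patch (trigger) and how (forward); the
# framework applies all registered rules in one place.
ModelPatcher.register(
    ModelPatcherRule(
        rule_id="foak.fused_linear",  # hypothetical rule id
        trigger=ModelPatcherTrigger(check=torch.nn.Linear),
        forward=lambda self, x: torch.nn.functional.linear(x, self.weight, self.bias),
    )
)
```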
Items
- Move ModelPatcher from fuse-ops-and-kernels to fms-acceleration
Benchmark Tests
We observe a general decrease in memory usage in this PR's benchmark results (the reference memory plots are higher); this allowed some experiments that previously ran out of memory to run to completion.