Refactored Model Patcher Class #55

Merged
merged 28 commits into from
Jul 29, 2024

@achew010 achew010 commented Jul 18, 2024

Description

This PR addresses #44. It moves the ModelPatcher class to the Framework plugin so that all plugin patches can be managed and maintained under a common framework.

Items

  • Move all ModelPatcher functionality from fuse-ops-and-kernels to fms-acceleration
  • Replace local plugin patches with ModelPatcher functions
  • Unit tests to test ModelPatcher functionality

Benchmark Tests

We observe a general decrease in memory usage in this PR's benchmark results (the reference memory plots are higher). As a result, some experiments that previously ran out of memory now run to completion.

model = patch_model(model, base_type=self._base_layer)
# wrapper function to register foak patches
from fms_acceleration_foak.models import load_foak_patches
load_foak_patches(base_type = self._base_layer)

@fabianlim fabianlim Jul 18, 2024

is there a better name than load_foak_patches? You need to have more comments,

  • for what reason patching is required.
  • have a more descriptive name like registering_model_patches_for_blah_blah

Don't import here, as there is no need to guard this import.
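
For context on the comment above: an import is usually guarded (deferred or wrapped in a check) only when the dependency is optional, and a dependency that is always installed can be imported at module top level. A minimal sketch of that distinction follows. The helper names `is_available` and `load_optional` and the module name `optional_accel_lib` are hypothetical and do not come from this PR.

```python
import importlib.util
import json  # always available, so a plain top-level import is enough


def is_available(module_name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(module_name) is not None


def load_optional(module_name: str):
    """Guarded import: import the module only if it is installed."""
    if not is_available(module_name):
        return None
    return importlib.import_module(module_name)


# Hypothetical optional dependency; load_optional returns None when it is absent.
accel = load_optional("optional_accel_lib")
```

The guard only pays off when the module may be missing at runtime; otherwise it hides the dependency from readers and tooling.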

@achew010

@fabianlim This is the test plan that I'm using to write the unit test cases.

Test Plan

Objective:

  • Ensure ModelPatcher features produce the right outputs
  • Test the constraints of ModelPatcher, ModelPatcherRule, ModelPatcherTrigger and exceptions raised

1. ModelPatcherHistory

  • Test ModelPatcherHistory Construction
    • test_mp_history_constructs_successfully()

2. ModelPatcherTrigger

  • Test Trigger Construction
    • test_mp_trigger_constructs_successfully()
  • is_triggered
    • Test Trigger checks respond as intended
      • test_mp_trigger_returns_correct_response()
  • combine_triggers
    • test_correct_output_combine_mp_triggers()

3. ModelPatcherRule

  • Test ModelPatcherRule Construction
    • test_mp_rule_constructs_successfully()

4. ModelPatcher

  • combine_functions
    • test_correct_outputs_mp_combine_functions()
  • patch_target_module
    • test_standalone_import_and_reload_function_replaces_indirect_module
  • register behaviour
    • test_mp_registers_only_one_unique_rule()
  • patch scenarios
    • Simple forward replacement
      • test_mp_rule_patches_forward
    • import_and_reload replacement
      • test_MP_rule_import_and_reload_patches_downstream_module()
    • forward builder replacement
      • test_MP_rule_patches_forward_with_builder_and_args
  • summary behaviour
    • test_mp_history_converts_to_dataframe
  • load_patches behaviour
    • test_registration_of_rules_through_package_imports
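
To make the plan above concrete, here is a rough, self-contained sketch of the trigger, rule and patcher relationship that the "register behaviour" and "simple forward replacement" cases exercise. The classes `Trigger`, `Rule`, `Patcher` and `Layer` are hypothetical stand-ins written for illustration. They are not the actual ModelPatcher API from this PR, and they avoid torch so the sketch runs anywhere.

```python
from dataclasses import dataclass
from types import MethodType
from typing import Callable, Dict, List, Type


class Layer:
    """Stand-in for a model submodule (hypothetical)."""

    def forward(self, x):
        return x + 1


@dataclass
class Trigger:
    """Fires when a module is an instance of `check`."""
    check: Type

    def is_triggered(self, module) -> bool:
        return isinstance(module, self.check)


@dataclass
class Rule:
    rule_id: str
    trigger: Trigger
    forward: Callable


class Patcher:
    def __init__(self):
        self.rules: Dict[str, Rule] = {}

    def register(self, rule: Rule) -> None:
        # Rule ids must be unique, so a duplicate registration is rejected.
        if rule.rule_id in self.rules:
            raise ValueError(f"rule '{rule.rule_id}' already registered")
        self.rules[rule.rule_id] = rule

    def patch(self, modules) -> List[str]:
        """Replace forward on every triggered module; return applied rule ids."""
        history = []
        for module in modules:
            for rule in self.rules.values():
                if rule.trigger.is_triggered(module):
                    module.forward = MethodType(rule.forward, module)
                    history.append(rule.rule_id)
        return history


def patched_forward(self, x):
    return x * 2


patcher = Patcher()
patcher.register(Rule("double", Trigger(Layer), patched_forward))
layer = Layer()
applied = patcher.patch([layer, object()])  # object() does not trigger the rule
```

In this sketch the patch is bound per instance, so unpatched `Layer` objects keep their original `forward`. The real class may patch at a different level.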

@achew010 achew010 changed the title Refectored Model Patcher Class Refactored Model Patcher Class Jul 22, 2024
@achew010 achew010 force-pushed the refactor/model-patcher branch from 785b971 to a31bf6e Compare July 23, 2024 11:48
@fabianlim fabianlim force-pushed the refactor/model-patcher branch from 373e341 to 9438aba Compare July 25, 2024 03:45
@achew010 achew010 marked this pull request as ready for review July 25, 2024 07:20
@achew010 achew010 force-pushed the refactor/model-patcher branch from 02e0998 to 8c825d9 Compare July 25, 2024 07:51

@fabianlim fabianlim left a comment

  1. rename test_model_pactcher.py -> test_model_patcher_helpers.py
  2. rename test_model_patcher2.py -> test_model_patcher.py
  3. run an unassisted bench on 7b for related bench scenarios that use model patcher.

@achew010 achew010 force-pushed the refactor/model-patcher branch from aff75fd to ac31192 Compare July 29, 2024 03:34
assert len(_with_reload) <= 1, "cannot have have at most one rule with reload"
# If there are multiple reload targets,
# ensure that their paths do not conflict as reloading same module might reset patches
if len(_with_reload)>1:

spacing
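
The body of the conflict check is not shown in the snippet above. As an illustration only, the sketch below assumes that two reload targets "conflict" when their dotted module paths are equal or one is a parent package of the other, since reloading the parent could reset patches applied to the child. The function names `paths_conflict` and `find_reload_conflicts` are hypothetical.

```python
from itertools import combinations
from typing import Iterable, List, Tuple


def paths_conflict(a: str, b: str) -> bool:
    """True if the dotted paths are equal or one is a parent of the other."""
    parts_a, parts_b = a.split("."), b.split(".")
    shorter = min(len(parts_a), len(parts_b))
    # Compare whole path components, so "pkg.ab" does not match "pkg.a".
    return parts_a[:shorter] == parts_b[:shorter]


def find_reload_conflicts(targets: Iterable[str]) -> List[Tuple[str, str]]:
    """Return every pair of reload targets whose paths conflict."""
    return [
        (a, b)
        for a, b in combinations(list(targets), 2)
        if paths_conflict(a, b)
    ]
```

A caller could then assert that `find_reload_conflicts(...)` is empty whenever more than one rule requests a reload.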

packaging # this is required for flash-attn dep as fms_hf_tuning did not specify
-e {toxinidir}/plugins/framework # install the framework here as the flash attention deps requires torch
passenv = * # will pass the parent env, otherwise there are too many envs e.g. TRANSFORMERS that need to be set
setenv =
TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1

Can you put a link to where the documentation says this must be set?

@achew010 achew010 force-pushed the refactor/model-patcher branch from c96c4c3 to 02a92e5 Compare July 29, 2024 05:02
@achew010 achew010 force-pushed the refactor/model-patcher branch from 02a92e5 to f6848a7 Compare July 29, 2024 05:19
@fabianlim fabianlim merged commit b6c1455 into foundation-model-stack:main Jul 29, 2024
4 checks passed