
feat: Support weight-stripped engine and REFIT_IDENTICAL flag #3167

Open · wants to merge 14 commits into base: main

Conversation

@zewenli98 (Collaborator) commented Sep 19, 2024

Description

  1. Supported weight-stripped engine
  2. Added REFIT_IDENTICAL flag

Fixes #3146

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@zewenli98 zewenli98 self-assigned this Sep 19, 2024
@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: api [Python] Issues re: Python API component: runtime component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Sep 19, 2024
Comment on lines 79 to 82
name: str = "",
settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed
weight_name_map: Optional[dict[Any, Any]] = None,
graph_module: torch.fx.GraphModule = None,
@zewenli98 (Collaborator, Author), Sep 19, 2024:
@narendasan I tried to do refitting for the C++ runtime like for the Python runtime, but it didn't work. Any suggestions? Should I do it in C++ or Python?

Collaborator:

Doesn't refit already work on both APIs?

Collaborator:

Also why do we need the graph module in this module?

@zewenli98 (Collaborator, Author):

  1. In this PR I moved the refitting part into TRTModule, so it only works for the Python runtime.

  2. The graph module is used for refitting.

@@ -619,27 +609,32 @@ def run(
builder_config, self.compilation_settings.timing_cache_path
)

serialized_engine = self.builder.build_serialized_network(
# if strip_engine_weights is true, the serialized engine need to be refitted before using
maybe_unrefitted_serialized_engine = self.builder.build_serialized_network(
Collaborator:

Why is this a "maybe unrefitted" engine?

@zewenli98 (Collaborator, Author):

Please see the design in the comment below. If compilation_settings.strip_engine_weights is true, the engine needs to be refitted; otherwise it doesn't. So it's "maybe".

), "weight-stripped engines must be refittable, please set make_refittable=True"

# Refit the weights
refitter = trt.Refitter(self.engine, TRT_LOGGER)
Collaborator:

Can you use this function?

def _refit_single_trt_engine_with_gm(

@zewenli98 (Collaborator, Author):

The function requires input_list, which is not provided in the caller.
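
For context, a minimal sketch of a direct named-weights refit that avoids the input_list-based helper (illustrative only; it assumes engine is an already deserialized weight-stripped engine and that weight_name_map maps engine weight names to torch tensors, which may differ from the actual map in this PR):

import numpy as np
import tensorrt as trt

# Illustrative sketch: set each named weight on the refitter directly.
# `engine`, `TRT_LOGGER`, and `weight_name_map` are assumed to exist.
refitter = trt.Refitter(engine, TRT_LOGGER)
for engine_weight_name, tensor in weight_name_map.items():
    weights = trt.Weights(np.ascontiguousarray(tensor.detach().cpu().numpy()))
    refitter.set_named_weights(engine_weight_name, weights)

assert refitter.refit_cuda_engine(), "refit_cuda_engine() failed"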

@@ -121,6 +124,52 @@ def setup_engine(self) -> None:
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
self.context = self.engine.create_execution_context()

if self.settings.strip_engine_weights:
@narendasan (Collaborator), Sep 19, 2024:

We likely shouldn't be doing the refit in these modules.

I think for weight stripping there are 3 workflows.

  1. A user just wants a weight-stripped engine. They should use convert_exported_program_to_trt_engine with settings strip_weights (see the sketch after this list). The choice of make_refittable can be used to decide between kREFIT and kREFIT_IDENTICAL (though it might not be entirely clear, so we might want to think about that setting).
  2. We want to utilize weight stripping to have a lighter-weight cache. Here this choice is opaque to the user. The user's choice of make_refittable controls whether we use kREFIT or kREFIT_IDENTICAL. But once the engine is loaded or we pull from cache, we immediately refit (prior to passing the engine to the TRTModule). Same as we do today.
  3. The user wants a stripped-weights compiled program (I'm not sure why, or if this is a real use case). Here, this is basically the same as lazy engine loading. We would require that users run refit_engine_weights before executing.
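
For illustration, workflow 1 might look roughly like this (a sketch only; the serialized-engine entry point and the strip/refit kwarg names are taken from this discussion and may not match the final API, and model/inputs are placeholders):

import torch
import torch_tensorrt

# Export the model, then ask for a weight-stripped serialized engine
exp_program = torch.export.export(model, tuple(inputs))
serialized_engine = torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
    exp_program,
    inputs=inputs,
    make_refittable=True,       # kREFIT vs. kREFIT_IDENTICAL per the discussion
    strip_engine_weights=True,  # return a weight-stripped plan
)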

@zewenli98 (Collaborator, Author):

Got it. The original idea/design is in the comment below. I'll move the refitting part back to TRTInterpreter.run().

The choice of make_refittable can be used to decide between kREFIT and kREFIT_IDENTICAL

Do you mean we use make_refittable to control both kREFIT and kREFIT_IDENTICAL?

@narendasan (Collaborator) left a comment:

@zewenli98 do you have a design for this feature?

@zewenli98 (Collaborator, Author):

@narendasan OK, at first the overall design was like this:

In TRTInterpreter.run():

if compilation_settings.strip_engine_weights is True:
    if engine_cache not hit:
        1. build a weight-stripped engine
        2. save the weight-stripped engine if engine_cache is set
        3. return the weight-stripped engine (not yet refitted)
    else:
        load and return the weight-stripped engine (not yet refitted)
else:
    if engine_cache not hit:
        1. build a weight-included engine
        2. save the weight-included engine if engine_cache is set
        3. return the weight-included engine (no refit needed)
    else:
        load and return the weight-included engine (not yet refitted)

Then, in TRTModule, refit if necessary before inference.
The reason that I didn't put the refitting part into TRTInterpreter.run() is that I want to avoid repeated de/serializations of TRT engines: (1) deserialize in TRTInterpreter.run() for refitting and then serialize (2) deserialize in TRTModule again.

@narendasan narendasan closed this Sep 20, 2024
@narendasan narendasan reopened this Sep 20, 2024
@zewenli98 (Collaborator, Author):

@narendasan The design was updated.

From the user's perspective, they can set make_refittable and refit_identical_engine_weights: make_refittable enables general refitting, while refit_identical_engine_weights is only for refitting with identical weights.

if self.compilation_settings.make_refittable:
    if version.parse(trt.__version__) >= version.parse("10.0"):
        if self.compilation_settings.refit_identical_engine_weights:
            builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
        else:
            builder_config.set_flag(trt.BuilderFlag.REFIT)
    else:
        builder_config.set_flag(trt.BuilderFlag.REFIT)

Besides, users can specify strip_engine_weights. If strip_engine_weights is True, TRTInterpreter.run() will return a weight-stripped engine. Otherwise, it returns a general engine (with weights).

For the 3 workflows mentioned above,

  1. By controlling the args above, users can call convert_exported_program_to_trt_engine with strip_engine_weights=True to get a weight-stripped engine.

  2. For engine caching, the implementation of weight-stripped engines is opaque to users: the engine caching mechanism will (1) save a weight-stripped engine no matter what settings users specify (make_refittable is required to be true) and then (2) load and refit the weight-stripped engine when reusing cached engines.
    If strip_engine_weights is True, the engine will not be refitted; instead, the weight-stripped engine is returned directly.

  3. If users specify strip_engine_weights=True, calling torch.compile() or torch_trt.dynamo.compile() will return a weight-stripped compiled program. If the compiled program is run with inputs, all results will be zeros. Then, calling refit_module_weights will bring the weights back, e.g.:

from torch_tensorrt.dynamo._refit import refit_module_weights
refitted_trt_gm = refit_module_weights(trt_gm, exp_program)
refitted_output = refitted_trt_gm(*inputs)

Please see more details in the tests.

@narendasan (Collaborator):

The reason that I didn't put the refitting part into TRTInterpreter.run() is that I want to avoid repeated de/serializations of TRT engines: (1) deserialize in TRTInterpreter.run() for refitting and then serialize (2) deserialize in TRTModule again.

I think we need to separate the runtime and the compiler, so I'm willing to spend the time serializing and deserializing.

I think we should frame this PR around moving TRTInterpreter to default to building weight-stripped engines. There will be 3 kinds of engines now:

  1. weight strip + refittable (strip_weights + kREFIT) - should move towards this being the default
  2. weight strip + refittable with original weights (strip_weights + kREFIT_INDIVIDUAL)
  3. non_refittable

The first 2 need separate cache entries, so we need to be able to hash on the weights in the case that the model is being built with kREFIT_INDIVIDUAL.

We should look to prefer case 1 in the long term as it allows us to reuse the most work; case 2 would be the next preference. Case 2 should produce faster engines than case 1, so there remains a need to support kREFIT_IDENTICAL.

Do you mean we use make_refittable to control both kREFIT and kREFIT_IDENTICAL?

The case for type 3 engines now is only valid if building a non-refittable engine is faster than building a refit_identical engine and then refitting the weights. If it is not by a significant enough margin, I propose we remove that workflow and just have refit or refit_identical engines.

So assuming that we can remove type 3 engines, make_refittable really means "allows the weights to be changed" (we can change the name here if needed), since now both engines are refittable; they just have different weight constraints.

@narendasan (Collaborator):

Some of the open questions are:

  • how we determine if the weights have been refit prior to running the engine. Can TRT tell us without an error?
  • How can we tell if a user is trying to refit an engine built with REFIT_IDENTICAL using different weights?
  • Is building strip weights + refit identical + refit slower than just building?

@zewenli98 (Collaborator, Author):

weight strip + refittable with original weights (strip_weights + kREFIT_INDIVIDUAL) [...] we need to be able to hash on the weights in the case that the model is being built with kREFIT_INDIVIDUAL [...] Case 2 should produce faster engines than case 1, so there remains a need to support kREFIT_IDENTICAL

Are you referring to kREFIT_IDENTICAL or kREFIT_INDIVIDUAL? The updated design only considered kREFIT_IDENTICAL. kREFIT_INDIVIDUAL is for fine-grained control, which is not yet being considered.

@zewenli98 (Collaborator, Author) commented Sep 20, 2024:

  • how we determine if the weights have been refit prior to running the engine. Can TRT tell us without an error?

My current design is: if users specify strip_engine_weights=True in compile, the weights will not be refitted and they will get a weight-stripped engine.
However, if they get an engine from somewhere else, they can call get_missing_weights() to check whether any weights have not been refitted.
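
For illustration, a minimal sketch of that check (assuming serialized_engine holds a weight-stripped plan; variable names are placeholders):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
engine = runtime.deserialize_cuda_engine(serialized_engine)

# Ask the refitter which weights are still missing from the plan
refitter = trt.Refitter(engine, logger)
missing = refitter.get_missing_weights()
if missing:
    print(f"{len(missing)} weights have not been refitted, e.g. {missing[:3]}")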

  • How can we tell if a user is trying to refit an engine built with REFIT_IDENTICAL using different weights?

I also thought about this earlier. The TRT doc says: "if the refit weights are not identical to the build-time weights, behavior is undefined... This enables use of a single set of weights with different inference backends, or with TensorRT plans for multiple GPU architectures."
My understanding is that we cannot tell whether the build-time and refit weights are identical from the perspective of the engine itself, because a weight-stripped engine doesn't compare build-time and refit-time weights or give any prompt. So users need to be clear about what they are refitting.

  • Is building strip weights + refit identical + refit slower than just building?

Will investigate.

@zewenli98 (Collaborator, Author) commented Sep 23, 2024:

The case for type 3 engines now is only valid if building a non-refittable engine is faster than building a refit_identical engine and then refitting the weights. If it is not by a significant enough margin, I propose we remove that workflow and just have refit or refit_identical engines.

@narendasan I tested building ResNet18 and VGG16 via the two paths: (1) strip weights + refit_identical + refit, and (2) non-refittable. The build times of the two are almost the same (diff < 1%), and the engine sizes are also almost the same (diff < 0.1%). I'm not sure if there are other benefits to non-refittable engines even though build time, engine size, and performance are the same, e.g., in deployment, weights not being allowed to change for safety?

@zewenli98 (Collaborator, Author):

@narendasan I just confirmed with the TRT team; the conclusion is that an engine built with STRIP_PLAN + REFIT_IDENTICAL + refit is almost the same as a non-refittable engine. Do you prefer to remove the non-refittable engine path?
If yes, the paths would be:

  1. weight strip + refittable (strip_weights + kREFIT) - default
  2. weight strip + refittable with original weights (strip_weights + kREFIT_IDENTICAL)

So assuming that we can remove type 3 engines, make_refittable really means "allows the weights to be changed" (we can change the name here if needed), since now both engines are refittable; they just have different weight constraints.

I think we can rename make_refittable to refit_mode: Literal["general", "identical"] (may be easier to extend in the future?) or refit_identical_weights: bool. Then we can remove the refit_identical_engine_weights arg that has been committed in this PR.

On top of this, STRIP_PLAN will always be on while building engines. The strip_engine_weights arg lets users control whether they get weight-stripped engines back.
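
A minimal sketch of that flow (assuming TensorRT >= 10; builder, network, builder_config, TRT_LOGGER, and the actual refit step are assumed to exist and are elided):

import tensorrt as trt

# Always build weight-stripped; only materialize weights on the way out
builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN)
builder_config.set_flag(trt.BuilderFlag.REFIT)  # or trt.BuilderFlag.REFIT_IDENTICAL

stripped_plan = builder.build_serialized_network(network, builder_config)

if not strip_engine_weights:
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(stripped_plan)
    # ... refit `engine` with the original weights here ...
    config = engine.create_serialization_config()
    config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
    final_plan = engine.serialize_with_config(config)
else:
    final_plan = stripped_plan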

In summary, the 3 workflows mentioned above would be:

  1. Users just want a weight-stripped engine. They can call convert_exported_program_to_trt_engine with strip_engine_weights=True to get a weight-stripped engine. This is also supported when the engine is loaded from the engine cache.

  2. We want to utilize weight stripping to have a lighter-weight engine cache. The implementation of weight-stripped engines is opaque to users. However, if users specify kREFIT or kREFIT_IDENTICAL, the engines are considered different and cached twice.

  3. Users want a stripped-weights compiled program. They just need to call torch.compile() or torch_trt.dynamo.compile() with strip_engine_weights=True. If the compiled program is run with inputs immediately, all results will be zeros. Calling refit_module_weights() will bring the weights back.

@narendasan (Collaborator):

I think we should remove non-refittable then, and we can add it back as a non-default workflow later if there's some reason to.

Users want a stripped-weights compiled program. They just need to call torch.compile() or torch_trt.dynamo.compile() with strip_engine_weights=True. If the compiled program is run with inputs immediately, all results will be zeros. Calling refit_module_weights() will bring the weights back.

I still don't know what the use case for this is.

@narendasan (Collaborator):

How can we tell if a user is trying to refit an engine built with REFIT_IDENTICAL using different weights?

I also thought about this earlier. The TRT doc says: "if the refit weights are not identical to the build-time weights, behavior is undefined..." My understanding is that we cannot tell whether the build-time and refit weights are identical from the perspective of the engine itself, because a weight-stripped engine doesn't compare build-time and refit-time weights or give any prompt. So users need to be clear about what they are refitting.

We should think about a solution for this, since the behavior is undefined.
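
One possible direction (purely illustrative, not part of this PR): record a fingerprint of the build-time weights alongside the plan and compare it at refit time, refusing to refit a REFIT_IDENTICAL engine when the digests differ:

import hashlib

def weights_fingerprint(state_dict) -> str:
    # Hash weight names and raw bytes in a stable order so identical
    # weights always produce the same digest (assumes torch tensors)
    h = hashlib.sha256()
    for name in sorted(state_dict):
        h.update(name.encode())
        h.update(state_dict[name].detach().cpu().numpy().tobytes())
    return h.hexdigest()

# At build time: store weights_fingerprint(module.state_dict()) with the plan.
# At refit time: recompute the digest and refuse to refit if it differs.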


@@ -414,6 +410,10 @@ def refit_module_weights(
"The type of graph module is not supported for refitting or two compiled modules do not match."
)

assert (
Collaborator:

Weight-stripped engines can only be refit once?

# clear EXCLUDE_WEIGHTS flag
serialization_config = engine.create_serialization_config()
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
serialized_engine = engine.serialize_with_config(serialization_config)
Collaborator:

Why do we serialize then immediately deserialize here?

@zewenli98 (Collaborator, Author):

Because we want the engine to clear the EXCLUDE_WEIGHTS flag. Is there a way to clear the flag without doing serialization?

new_engine_info = list(engine_info)
new_engine_info[ENGINE_IDX] = serialized_engine
new_engine_info[ENGINE_IDX] = bytes(serialized_engine)
@narendasan (Collaborator), Oct 8, 2024:

Seems like we only need to deserialize in a PythonTorchTensorRTModule, and we should probably use setup_engine instead. The standard interface should be: provide the serialized engine, then setup_engine sets the module up properly for both Python and C++.

@@ -532,7 +548,10 @@ def run(
# self.engine_cache could be None if:
Collaborator:

We are kind of avoiding this, but I feel like we might want to completely restructure run to assume refit and check whether the user wants immutable weights. Also, we might want to pull the cache pulling and inserting code into helpers just to make it easier to understand.

@narendasan (Collaborator), Oct 8, 2024:

Lines 551-598 and 671-711 should probably be helpers that pull and insert the weight-stripped engine. We should have one for single-module refit as well.

@narendasan (Collaborator), Oct 8, 2024:

Something like

if self.reuse_cached_engines:
    weight_stripped_serialized_engine = _pull_cached_engine(hash, settings, inputs)

def fit_weights(self, weight_stripped_serialized_engine):
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(weight_stripped_serialized_engine)

    from torch_tensorrt.dynamo._refit import (
        _refit_single_trt_engine_with_gm,
    )

    _refit_single_trt_engine_with_gm(
        new_gm=self.module,
        old_engine=engine,
        input_list=self.input_specs,
        settings=self.compilation_settings,
        weight_name_map=self.weight_name_map,
    )

    # Serialize the refitted engine; the EXCLUDE_WEIGHTS flag must be cleared
    serialization_config = engine.create_serialization_config()
    serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
    serialized_engine = engine.serialize_with_config(serialization_config)
    return serialized_engine

serialized_engine = fit_weights(self, weight_stripped_serialized_engine)

@@ -629,35 +657,68 @@ def run(
assert serialized_engine

_LOGGER.info(
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
f"Build weight-stripped TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
Collaborator:

Probably just leave this as is; if the user requests immutable_weights, it wouldn't apply.

weight_stripped_serialized_engine = serialized_engine
else:
# Serialize the refitted engine where the EXCLUDE_WEIGHTS flag must be cleared
runtime = trt.Runtime(TRT_LOGGER)
Collaborator:

Why do we refit and then strip the weights again? If refit is enabled, shouldn't the builder always use a weight-stripped engine?

@narendasan (Collaborator) left a comment:

Make sure to add the refit settings to the _SETTINGS_TO_BE_ENGINE_INVARIANT
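
For reference, a sketch of what that might look like (illustrative only; the actual tuple and its existing entries live in the engine-cache code):

# Settings that must participate in the engine-cache hash so that
# differently configured builds get separate cache entries
_SETTINGS_TO_BE_ENGINE_INVARIANT = (
    "make_refittable",
    "refit_identical_engine_weights",
    "strip_engine_weights",
    # ... existing invariant settings ...
)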

Successfully merging this pull request may close these issues.

✨[Feature] Weight specific engine caching