
T5 support more model types #127

Closed
VictorM-PS opened this issue Oct 25, 2023 · 7 comments

@VictorM-PS

Hi,

I am following the enc_dec example. After adapting your code a bit to my case, I am hitting a wall when loading the weights. I see the example runs t5-small while my model is t5-large, and I expected it to run after updating config.ini with my values.

The error:

Traceback (most recent call last):
  File "/code/tensorrt_llm/TensorRT-LLM/examples/enc_dec/build.py", line 365, in <module>
    run_build(component='encoder')
  File "/code/tensorrt_llm/TensorRT-LLM/examples/enc_dec/build.py", line 357, in run_build
    build(0, args)
  File "/code/tensorrt_llm/TensorRT-LLM/examples/enc_dec/build.py", line 329, in build
    engine = build_rank_engine(builder, builder_config, engine_name,
  File "/code/tensorrt_llm/TensorRT-LLM/examples/enc_dec/build.py", line 228, in build_rank_engine
    load_t5_from_pytorch(tllm_model,
  File "/code/tensorrt_llm/TensorRT-LLM/examples/enc_dec/weight.py", line 180, in load_t5_from_pytorch
    layer.mlp.fc.weight.value = pytorch_model[
KeyError: 'encoder.block.0.layer.1.DenseReluDense.wi.weight'
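
To compare against the key the loader expects, the checkpoint's actual weight names can be dumped with a minimal snippet like this ('t5-large.ckpt' standing in for whatever converted checkpoint file weight.py loads):

    import torch

    # List the encoder weight names in the converted checkpoint, to compare
    # against the key the loader expects ('encoder.block.0.layer.1...').
    state_dict = torch.load("t5-large.ckpt", map_location="cpu")
    for name in sorted(state_dict):
        if name.startswith("encoder."):
            print(name)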

I currently use fastertransformer_backend to optimize T5, but I see its maintenance has been dropped. Also, not being able to select the GPU in the instance group config for Triton server is not great...

So I have two questions:

  • Do you plan to add support for T5 as an alternative to the fastertransformer repo?
  • If so, as a feature request, could we get more general optimization for different T5 models? Or an explanation of how to adapt the current t5-small example?

Thank you for the amazing work!

@byshiue byshiue self-assigned this Oct 26, 2023
@byshiue byshiue added the triaged Issue has been triaged by maintainers label Oct 26, 2023
@symphonylyh
Collaborator

symphonylyh commented Oct 26, 2023

Hi @VictorM-PS ,
We're in the process of adding full support for general T5, BART, and similar models. With that, simply specifying the model name t5-xx will automatically handle the config setup, without hardcoded values or manual changes.

Until then, please make the two changes below to get t5-large working:

  1. In load_t5_from_pytorch() in weight.py, change the hardcoded line to pytorch_ckpt = torch.load(os.path.join(pytorch_ckpt_path, 't5-large.ckpt')). Again, this will soon be made flexible instead of hardcoded.
  2. In models/config.ini, change these fields in both [encoder] and [decoder] to match t5-large (see the sketch after this list for deriving these values programmatically):
    n_layer = 24
    n_head = 16
    hidden_size = 1024
    ffn_hidden_size = 4096
    ...
    n_positions = 512
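
Rather than hand-editing, these values can also be pulled from the Hugging Face config. A minimal sketch, assuming models/config.ini already contains the [encoder] and [decoder] sections from the t5-small example (the attribute names are from transformers' T5Config; n_positions is pinned to 512 as above, since modern T5 configs no longer carry it):

    import configparser
    from transformers import AutoConfig

    hf_config = AutoConfig.from_pretrained("t5-large")

    cfg = configparser.ConfigParser()
    cfg.read("models/config.ini")  # start from the existing t5-small config

    # Encoder and decoder layer counts come from different T5Config fields
    # (they happen to both be 24 for t5-large).
    layer_field = {"encoder": "num_layers", "decoder": "num_decoder_layers"}

    for section in ("encoder", "decoder"):
        cfg.set(section, "n_layer", str(getattr(hf_config, layer_field[section])))
        cfg.set(section, "n_head", str(hf_config.num_heads))
        cfg.set(section, "hidden_size", str(hf_config.d_model))
        cfg.set(section, "ffn_hidden_size", str(hf_config.d_ff))
        cfg.set(section, "n_positions", "512")  # absent from modern T5 configs

    with open("models/config.ini", "w") as f:
        cfg.write(f)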

As for the Triton server question, are you aware of the corresponding TRT-LLM backend, released at the same time as the main repo: https://github.com/triton-inference-server/tensorrtllm_backend?
T5 is not supported in the Triton TRT-LLM backend yet (we're working on it), but it can give you a flavor of the TRT-LLM equivalent of the previous Triton FT path.

@symphonylyh
Collaborator

Hi @VictorM-PS , more general T5 and Flan-T5 support is done and scheduled for TRT-LLM's upcoming 0.6.0 monthly release. Please stay tuned.

@VictorM-PS
Author

Hi @symphonylyh, thank you very much for the heads-up and the amazing work! :)

@symphonylyh
Collaborator

@VictorM-PS
Actually, it has already been released to the dev branch (i.e., main)! Please check #424.

@symphonylyh
Collaborator

Let me close this issue as the support has been released. Feel free to reopen if needed.

@yuanze1024

@symphonylyh Thank you for your efforts to support T5.
I'm building a TRT engine for Flan-T5-xxl, which I want to use in BLIP2's pipeline. The build raised the exception below.

Traceback (most recent call last):
  File "/usr/lib/python3.10/configparser.py", line 791, in get
    value = d[option]
  File "/usr/lib/python3.10/collections/__init__.py", line 986, in __getitem__
    return self.__missing__(key)            # support subclasses that define __missing__
  File "/usr/lib/python3.10/collections/__init__.py", line 978, in __missing__
    raise KeyError(key)
KeyError: 'n_positions'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/code/TensorRT-LLM/examples/enc_dec/build.py", line 515, in <module>
    run_build(component='encoder')
  File "/code/TensorRT-LLM/examples/enc_dec/build.py", line 489, in run_build
    args = parse_arguments(component)
  File "/code/TensorRT-LLM/examples/enc_dec/build.py", line 235, in parse_arguments
    args = parse_config(
  File "/code/TensorRT-LLM/examples/enc_dec/build.py", line 40, in parse_config
    args = globals()[f'parse_{model_type}_config'](config, component, args)
  File "/code/TensorRT-LLM/examples/enc_dec/t5/weight.py", line 31, in parse_t5_config
    args.n_positions = config.getint(component, 'n_positions')
  File "/usr/lib/python3.10/configparser.py", line 820, in getint
    return self._get_conv(section, option, int, raw=raw, vars=vars,
  File "/usr/lib/python3.10/configparser.py", line 810, in _get_conv
    return self._get(section, conv, option, raw=raw, vars=vars,
  File "/usr/lib/python3.10/configparser.py", line 805, in _get
    return conv(self.get(section, option, **kwargs))
  File "/usr/lib/python3.10/configparser.py", line 794, in get
    raise NoOptionError(option, section)
configparser.NoOptionError: No option 'n_positions' in section: 'encoder'

According to huggingface/transformers#8047, the parameter 'n_positions' seems to have been removed a long time ago. So I just added n_positions = 512 to config.ini in both the [encoder] and [decoder] sections, and it works. So I think this may be a bug?
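
If it is, one possible fix on the parsing side would be a fallback in parse_t5_config() instead of requiring the key (the weight.py line is quoted from the traceback, not the current source). A self-contained sketch of the configparser behavior:

    import configparser

    # Reproduce the failure mode: an [encoder] section with no n_positions key.
    cfg = configparser.ConfigParser()
    cfg.read_string("[encoder]\nn_layer = 24\n")

    # cfg.getint("encoder", "n_positions") raises NoOptionError, as in the
    # traceback above; fallback= sidesteps it with T5's usual default of 512.
    n_positions = cfg.getint("encoder", "n_positions", fallback=512)
    print(n_positions)  # 512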

@symphonylyh
Collaborator

Hi @yuanze1024 , thanks. Yes, it seems the other Flan-T5 variants all carry an n_positions field (unused, though), but flan-t5-xxl does not.
