diff --git a/torchchat/generate.py b/torchchat/generate.py index fcbe5513b..a9094aa40 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -797,6 +797,10 @@ def _gen_model_input( max_new_tokens is not None ), "max_new_tokens must be specified for Flamingo models" + # Wrap string prompts into a list + if isinstance(prompt, str): + prompt = [{"role": "user", "content": prompt}] + image_found = False messages = [] for message in prompt: @@ -959,8 +963,9 @@ def chat( max_seq_length = ( text_transformer_args.max_seq_length if text_transformer_args else 2048 ) + encoded, batch = self._gen_model_input( - [{"role": "user", "content": generator_args.prompt}], + generator_args.prompt, generator_args.image_prompts, generator_args.max_new_tokens, max_seq_length,