Push Message formatting into _gen_model_input (#1295)
* Revert Generate Behavior for non-Flamingo Models

* Simplify
Jack-Khuu authored Oct 14, 2024
1 parent c867660 commit d1ab6e0
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torchchat/generate.py
@@ -797,6 +797,10 @@ def _gen_model_input(
             max_new_tokens is not None
         ), "max_new_tokens must be specified for Flamingo models"
 
+        # Wrap string prompts into a list
+        if isinstance(prompt, str):
+            prompt = [{"role": "user", "content": prompt}]
+
         image_found = False
         messages = []
         for message in prompt:
@@ -959,8 +963,9 @@ def chat(
             max_seq_length = (
                 text_transformer_args.max_seq_length if text_transformer_args else 2048
             )
+
             encoded, batch = self._gen_model_input(
-                [{"role": "user", "content": generator_args.prompt}],
+                generator_args.prompt,
                 generator_args.image_prompts,
                 generator_args.max_new_tokens,
                 max_seq_length,
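In effect, the first hunk lets _gen_model_input accept either a raw string or a pre-formatted list of chat messages, so call sites such as chat() no longer need to wrap the prompt themselves (the second hunk). A minimal standalone sketch of that normalization follows; the function name normalize_prompt and the example values are illustrative only, not part of the commit:

    from typing import Union


    def normalize_prompt(prompt: Union[str, list]) -> list:
        """Mirror the wrapping added to _gen_model_input: a bare string
        becomes a single-message chat list; lists pass through unchanged."""
        if isinstance(prompt, str):
            return [{"role": "user", "content": prompt}]
        return prompt


    # Both call styles now yield the same message list:
    assert normalize_prompt("What is in this image?") == [
        {"role": "user", "content": "What is in this image?"}
    ]
    assert normalize_prompt([{"role": "user", "content": "Hi"}]) == [
        {"role": "user", "content": "Hi"}
    ]

Centralizing the wrapping in _gen_model_input keeps string prompts working for non-Flamingo models while still supporting message-list input, which is what the "Revert Generate Behavior for non-Flamingo Models" bullet refers to.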
