
How to use generate() with inputs_embeds #70

Closed
liechtym opened this issue Dec 28, 2023 · 2 comments

@liechtym

I hope this is the right place to ask this question. Let me know if I need to move to another repo.

Currently I'm using NeuronModelForCausalLM which uses LlamaForSampling under the hood.

I have a use case where I need to be able to do the following:

  1. Compute token embeddings
  2. Modify the embeddings
  3. Run inference from the modified embeddings
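The three steps can be sketched with plain PyTorch stand-ins (the sizes, ids, and the modification itself are illustrative, not the real Llama-2 values):

```python
import torch

# Stand-in for the model's embedding layer (real Llama-2-7b uses
# vocab_size=32000, hidden_size=4096; tiny sizes here for illustration).
embed_tokens = torch.nn.Embedding(num_embeddings=100, embedding_dim=16)

token_ids = torch.tensor([[1, 5, 7]])  # (batch, seq_len)

# Step 1: look up the token embeddings -> (batch, seq_len, hidden)
inputs_embeds = embed_tokens(token_ids).detach()

# Step 2: modify them, e.g. nudge the first position
# (hypothetical modification; the real use case goes here).
inputs_embeds[:, 0, :] += 0.1

# Step 3 would be model.generate(inputs_embeds=inputs_embeds, ...),
# which is the part that fails on NeuronModelForCausalLM.
```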

I am able to do steps 1 & 2 currently using the following:

from optimum.neuron import NeuronModelForCausalLM

llama_model = NeuronModelForCausalLM.from_pretrained('aws-neuron/Llama-2-7b-chat-hf-seqlen-2048-bs-1')

# token_ids: a (batch, seq_len) LongTensor of input token ids
embedded_tokens = llama_model.model.chkpt_model.model.embed_tokens(token_ids)

### Code to modify embedded_tokens

However, as far as I can tell, generating from these modified embeddings is not possible with llama_model.generate().

When I pass the `inputs_embeds` keyword argument and set input_ids=None, I get the following error:

ValueError: The following `model_kwargs` are not used by the model: ['inputs_embeds']
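For comparison, a stock transformers causal LM does accept `inputs_embeds` in generate(); a sketch using a tiny randomly initialized GPT-2 (illustrative only, no pretrained weights, recent transformers assumed):

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny random GPT-2 so the sketch runs without downloading weights.
model = GPT2LMHeadModel(
    GPT2Config(n_layer=2, n_head=2, n_embd=32, vocab_size=100)
)
model.eval()

token_ids = torch.tensor([[1, 2, 3]])
inputs_embeds = model.transformer.wte(token_ids)  # the embedding lookup

# Vanilla transformers routes this through generate() without complaint;
# NeuronModelForCausalLM raises the ValueError above instead.
out = model.generate(
    inputs_embeds=inputs_embeds,
    max_new_tokens=5,
    do_sample=False,
    pad_token_id=0,
)
print(out.shape)
```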

If this is not currently possible with NeuronModelForCausalLM.generate(), is there a way to work around it manually? If so, could you provide an example?

Thanks very much for your help!

@aws-taylor
Contributor

Hello @liechtym,

I think this may be more appropriate for https://github.com/huggingface/optimum-neuron.

-T

@liechtym
Author

Thanks. Moved to huggingface/optimum-neuron#395
