This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

Running inference on GPU #258

Closed
griff4692 opened this issue Jun 13, 2022 · 3 comments · Fixed by #309
Labels
bug / fix (Something isn't working), help wanted (Extra attention is needed)

Comments

@griff4692

🐛 Bug

import torch
from lightning_transformers.task.nlp.token_classification import TokenClassificationTransformer

# ckpt_fn: path to a checkpoint saved during training
model = TokenClassificationTransformer.load_from_checkpoint(ckpt_fn).to('cuda:0')
with torch.no_grad():
    model.hf_predict('this is a test sentence.')

Running this raises a device-mismatch error: the model is on the GPU, but the pipeline puts the model inputs on the CPU. Following the Hugging Face pipeline docs, I tried passing device='cuda:0' to model.hf_predict, but that fails with the following error:

  predictions = model.hf_predict(x, device='cuda:0')
  File "/root/lightning-transformers/lightning_transformers/core/model.py", line 183, in hf_predict
    return self.hf_pipeline(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/pipelines/token_classification.py", line 189, in __call__
    return super().__call__(inputs, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/pipelines/base.py", line 987, in __call__
    preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(**kwargs)
TypeError: _sanitize_parameters() got an unexpected keyword argument 'device'
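The TypeError comes from how Hugging Face pipelines handle keyword arguments: Pipeline.__call__ forwards its extra kwargs to _sanitize_parameters, which only accepts the per-call options that the specific pipeline defines. In transformers releases from this period, the device is a constructor argument of transformers.pipeline(...) (for example device=0 for the first GPU) rather than a per-call option, so passing it through hf_predict cannot work. The sketch below is a hypothetical, simplified toy (ToyPipeline and its single aggregation_strategy option are illustrative, not the real transformers classes) that reproduces the same failure mode using only the standard library:

```python
from typing import Any, Dict, Optional


class ToyPipeline:
    """Hypothetical, simplified stand-in for how a Hugging Face pipeline routes kwargs."""

    def __init__(self, device: str = "cpu") -> None:
        # The device is chosen once, when the pipeline object is built.
        self.device = device

    def _sanitize_parameters(self, aggregation_strategy: Optional[str] = None) -> Dict[str, Any]:
        # Only named per-call options are accepted; an unknown keyword such as
        # device=... raises TypeError before any inference happens.
        return {"aggregation_strategy": aggregation_strategy}

    def __call__(self, text: str, **kwargs: Any) -> Dict[str, Any]:
        params = self._sanitize_parameters(**kwargs)
        # A real pipeline would tokenize, move tensors to self.device, and run the model here.
        return {"text": text, "device": self.device, **params}


if __name__ == "__main__":
    gpu_pipe = ToyPipeline(device="cuda:0")  # device set at construction: works
    print(gpu_pipe("this is a test sentence."))
    try:
        ToyPipeline()("this is a test sentence.", device="cuda:0")  # device per call: fails
    except TypeError as err:
        print(err)
```

In this toy, the device is fixed when the object is built and every call reuses it; the same reasoning suggests the fix belongs wherever lightning-transformers constructs its pipeline, not in the arguments to hf_predict.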

Any advice? This seems like a fairly standard use case, so I expect there is an easy fix.

Thanks

@griff4692 griff4692 added bug / fix Something isn't working help wanted Extra attention is needed labels Jun 13, 2022
@griff4692
Author

It looks like you can run

https://github.com/PyTorchLightning/lightning-transformers/blob/master/lightning_transformers/cli/predict.py

for inference. However, I can't find example call arguments for it in the docs.

@SeanNaren
Contributor

Hey @griff4692, sorry for the late response; this is definitely a bug in the predict API. We'll be removing the CLI soon, so we'll also need to update the docs!

@Borda
Member

Borda commented Sep 14, 2022

@SeanNaren, I believe we have finished the API cleanup, so all that remains is polishing the docs?
cc: @rohitgr7

3 participants