This repository has been archived by the owner on Nov 21, 2022. It is now read-only.
```python
import torch
from lightning_transformers.task.nlp.token_classification import TokenClassificationTransformer

# ckpt_fn: path to a trained token-classification checkpoint
model = TokenClassificationTransformer.load_from_checkpoint(ckpt_fn).to('cuda:0')
with torch.no_grad():
    model.hf_predict('this is a test sentence.')
```
Running this, you get a device mismatch because the model inputs are placed on the CPU. After reading the pipeline docs on HF, I tried passing `device='cuda:0'` to `model.hf_predict`, yet I get the following error:
```
    predictions = model.hf_predict(x, device='cuda:0')
  File "/root/lightning-transformers/lightning_transformers/core/model.py", line 183, in hf_predict
    return self.hf_pipeline(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/pipelines/token_classification.py", line 189, in __call__
    return super().__call__(inputs, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/pipelines/base.py", line 987, in __call__
    preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(**kwargs)
TypeError: _sanitize_parameters() got an unexpected keyword argument 'device'
```
Any advice? This seems like a pretty standard use case, so I think there should be an easy fix.
Thanks
Hey @griff4692, sorry for the late response; this is definitely a bug in the predict API. We'll be getting rid of the CLI soon, so we will also need to update the docs!
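The traceback shows that keyword arguments given to `hf_predict` are forwarded unchanged to the pipeline's `__call__`, which hands them to `_sanitize_parameters`, and `device` is not a parameter that method accepts. Below is a simplified, hypothetical mimic of that routing (the class `ToyPipeline` is invented for illustration and is not the transformers source) showing why a call-time `device` is rejected while a construction-time one is accepted.

```python
# Hypothetical, simplified mimic of how a Hugging Face-style pipeline routes
# arguments. It is NOT the transformers implementation; it only reproduces the
# pattern visible in the traceback above.


class ToyPipeline:
    def __init__(self, device="cpu"):
        # The device is fixed once, when the pipeline object is built.
        self.device = device

    def _sanitize_parameters(self, aggregation_strategy=None):
        # Only task-specific options are accepted here (the option name is
        # borrowed from the token-classification pipeline). Any other keyword,
        # such as device, makes Python raise TypeError.
        preprocess_params, forward_params, postprocess_params = {}, {}, {}
        if aggregation_strategy is not None:
            postprocess_params["aggregation_strategy"] = aggregation_strategy
        return preprocess_params, forward_params, postprocess_params

    def __call__(self, inputs, **kwargs):
        # Call-time kwargs are routed through _sanitize_parameters.
        _, _, postprocess_params = self._sanitize_parameters(**kwargs)
        return {"text": inputs, "device": self.device, **postprocess_params}


# Passing device at call time fails, mirroring the reported TypeError.
call_time_error = None
try:
    ToyPipeline()("this is a test sentence.", device="cuda:0")
except TypeError as err:
    call_time_error = str(err)

# Passing it to the constructor instead is accepted.
result = ToyPipeline(device="cuda:0")("this is a test sentence.")
print(call_time_error)
print(result)
```

With the real library, the analogous step is passing `device` when the pipeline is constructed, e.g. `transformers.pipeline("token-classification", model=..., tokenizer=..., device=0)`, where an integer index such as `0` selects a CUDA device. Because `hf_predict` builds its pipeline internally, whether that argument can be forwarded depends on the lightning-transformers version; building a pipeline directly from the underlying Hugging Face model and tokenizer is one possible workaround, but that is an assumption, not something confirmed in this thread.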