Skip to content

Commit

Permalink
Fix doctests for TFVisionTextDualEncoder (#21910)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 authored Mar 3, 2023
1 parent 9f5bfe1 commit 37e0974
Showing 1 changed file with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,10 @@ def get_text_features(
```python
>>> from transformers import TFVisionTextDualEncoderModel, AutoTokenizer
>>> model = TFVisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian")
>>> model = TFVisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian", from_pt=True)
>>> tokenizer = AutoTokenizer.from_pretrained("clip-italian/clip-italian")
>>> inputs = tokenizer(["una foto di un gatto", "una foto di un cane"], padding=True, return_tensors="pt")
>>> inputs = tokenizer(["una foto di un gatto", "una foto di un cane"], padding=True, return_tensors="np")
>>> text_features = model.get_text_features(**inputs)
```"""
text_outputs = self.text_model(
Expand Down Expand Up @@ -313,7 +313,7 @@ def get_image_features(
>>> import requests
>>> from transformers import TFVisionTextDualEncoderModel, AutoImageProcessor
>>> model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian")
>>> model = TFVisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian", from_pt=True)
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
Expand Down Expand Up @@ -380,7 +380,7 @@ def call(
... ]
>>> images = [Image.open(requests.get(url, stream=True).raw) for url in urls]
>>> inputs = processor(
... text=["a photo of a cat", "a photo of a dog"], images=images, return_tensors="pt", padding=True
... text=["a photo of a cat", "a photo of a dog"], images=images, return_tensors="np", padding=True
... )
>>> outputs = model(
... input_ids=inputs.input_ids,
Expand Down Expand Up @@ -587,6 +587,8 @@ def from_vision_text_pretrained(
if text_model.name != "text_model":
raise ValueError("text model must be created with the name `text_model`.")

model(model.dummy_inputs) # Ensure model is fully built

return model

@property
Expand Down

0 comments on commit 37e0974

Please sign in to comment.