Skip to content

Commit

Permalink
update the requirements and fix the _expand_mask import issue
Browse files Browse the repository at this point in the history
  • Loading branch information
maulikmadhavi committed May 19, 2024
1 parent b28106d commit f0317e0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
15 changes: 14 additions & 1 deletion lavis/models/blip_diffusion_models/modeling_ctx_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,22 @@
from transformers.models.clip.modeling_clip import (
CLIPEncoder,
CLIPPreTrainedModel,
_expand_mask,
)

def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len

expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

inverted_mask = 1.0 - expanded_mask

return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)



class CtxCLIPTextModel(CLIPPreTrainedModel):
config_class = CLIPTextConfig
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ scikit-image
sentencepiece
spacy
streamlit
timm==0.4.12
timm
torch>=1.10.0
torchvision
tqdm
transformers==4.33.2
transformers
webdataset
wheel
torchaudio
Expand All @@ -35,5 +35,5 @@ peft

easydict==1.9
pyyaml_env_tag==0.1
open3d==0.13.0
open3d
h5py

0 comments on commit f0317e0

Please sign in to comment.