Skip to content

Commit

Permalink
[PT FE] Fix issue with aten.copy in FX graph (#23711)
Browse files Browse the repository at this point in the history
### Details:
 - *Fix translation for `aten.copy.default`*
 - *Add support for `aten.rand.default`*
 - *Support `vit-mae` model*
 - *Update hf model list with newly added models since list creation*

### Tickets:
 - *CVS-133732*
  • Loading branch information
mvafin authored Mar 28, 2024
1 parent 1c57a1c commit fba9385
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 22 deletions.
16 changes: 16 additions & 0 deletions src/frontends/pytorch/src/op/copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ OutputVector translate_copy_(const NodeContext& context) {
return {res};
};

OutputVector translate_copy_fx(const NodeContext& context) {
// copy = torch.ops.aten.copy.default(slice_4);
// copy = torch.ops.aten.copy.default(slice_4, clone);
num_inputs_check(context, 1, 2);
auto self = context.get_input(0);
if (context.input_is_none(1)) {
return {self};
} else {
auto src = context.get_input(1);
auto src_converted = context.mark_node(std::make_shared<v1::ConvertLike>(src, self));
auto self_shape = context.mark_node(std::make_shared<v3::ShapeOf>(self));
Output<Node> res = context.mark_node(std::make_shared<v3::Broadcast>(src_converted, self_shape));
return {res};
}
};

OutputVector translate_alias_copy(const NodeContext& context) {
// aten::alias_copy(Tensor self) -> Tensor
// aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
Expand Down
9 changes: 5 additions & 4 deletions src/frontends/pytorch/src/op/rand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ OutputVector make_random_normal(const NodeContext& context,
}; // namespace

OutputVector translate_rand(const NodeContext& context) {
num_inputs_check(context, 2, 6);
num_inputs_check(context, 1, 6);
auto sizes = context.get_input(0);
if (context.get_input_type(0).is<type::List>()) {
sizes = concat_list_construct(sizes);
Expand All @@ -57,14 +57,15 @@ OutputVector translate_rand(const NodeContext& context) {
size_t out_id = 1;
if (context.get_input_size() == 3) {
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(1),
"aten::randn conversion with generator does not supported");
"aten::rand conversion with generator does not supported");
out_id = 2;
}
// aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
// aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
if (context.get_input_size() == 2 || context.get_input_size() == 3) {
if (context.get_input_size() <= 3) {
auto res = context.mark_node(std::make_shared<v8::RandomUniform>(sizes, low, high, dtype));
context.mutate_input(out_id, res);
if (context.get_input_size() >= 2)
context.mutate_input(out_id, res);
return {res};
}
// aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool?
Expand Down
4 changes: 3 additions & 1 deletion src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ OP_CONVERTER(translate_batch_norm_legit_no_training_fx);
OP_CONVERTER(translate_batch_norm_legit_no_stats_fx);
OP_CONVERTER(translate_cat_fx);
OP_CONVERTER(translate_constant_pad_nd_fx);
OP_CONVERTER(translate_copy_fx);
OP_CONVERTER(translate_cumsum_fx);
OP_CONVERTER(translate_chunk_fx);
OP_CONVERTER(translate_div_fx);
Expand Down Expand Up @@ -781,7 +782,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.clone.default", op::skip_node}, // ignore clone operators that are inserted by PyTorch autograd
{"aten.constant_pad_nd.default", op::translate_constant_pad_nd_fx},
{"aten.convolution.default", op::translate_convolution},
{"aten.copy.default", op::skip_node},
{"aten.copy.default", op::translate_copy_fx},
{"aten.copy_.default", op::translate_copy_},
{"aten.cos.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Cos>},
{"aten.cosh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Cosh>},
Expand Down Expand Up @@ -888,6 +889,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.pow.Tensor_Tensor", op::translate_pow},
{"aten.pixel_shuffle.default", op::translate_pixel_shuffle},
{"aten.pixel_unshuffle.default", op::translate_pixel_unshuffle},
{"aten.rand.default", op::translate_rand},
{"aten.reciprocal.default", op::translate_reciprocal},
{"aten.reflection_pad1d.default", op::translate_reflection_pad_nd_fx},
{"aten.reflection_pad2d.default", op::translate_reflection_pad_nd_fx},
Expand Down
17 changes: 17 additions & 0 deletions tests/layer_tests/pytorch_tests/test_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ def forward(self, x):
return y


class aten_alias_tensor(torch.nn.Module):
def forward(self, x):
y = x.clone()
n,c,h,w = x.shape
ones = torch.ones([2,h,w]).to(x.dtype)
y[:, 1:, :, :] = ones
return y


class aten_loop_alias(torch.nn.Module):
def forward(self, x):
y = x.clone()
Expand All @@ -36,6 +45,14 @@ def test_alias(self, ie_device, precision, ir_version):
"aten::copy_"],
ie_device, precision, ir_version)

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_torch_export
def test_alias_tensor(self, ie_device, precision, ir_version):
self._test(aten_alias_tensor(), None, ["aten::slice",
"aten::copy_"],
ie_device, precision, ir_version, freeze_model=False)

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_torch_export
Expand Down
37 changes: 20 additions & 17 deletions tests/model_hub_tests/pytorch/hf_transformers_models
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ anugunj/omnivore-swinL-in21k,omnivore,skip,Load problem
apple/mobilevitv2-1.0-imagenet1k-256,mobilevitv2,xfail,Unsupported op aten::col2im
ArthurZ/jukebox_prior_0,jukebox_prior,skip,Load problem
ArthurZ/jukebox-vqvae,jukebox_vqvae,skip,Load problem
asapp/sew-d-base-plus-400k-ft-ls100h,sew-d
ashishpatel26/span-marker-bert-base-fewnerd-coarse-super,span-marker,skip,Load problem
asi/albert-act-tiny,albert_act,skip,Load problem
BAAI/AltCLIP,altclip
Expand Down Expand Up @@ -63,6 +64,8 @@ EleutherAI/enformer-official-rough,enformer,skip,Load problem
EleutherAI/gpt-neo-125m,gpt_neo
EleutherAI/pythia-6.9b,gpt_neox
facebook/bart-large-mnli,bart
facebook/blenderbot-400M-distill,blenderbot
facebook/blenderbot_small-90M,blenderbot-small
facebook/convnextv2-tiny-22k-384,convnextv2
facebook/detr-resnet-50,detr
facebook/dinov2-base,dinov2
Expand All @@ -71,17 +74,20 @@ facebook/encodec_24khz,encodec
facebook/esm2_t6_8M_UR50D,esm
facebook/flava-full,flava,xfail,Tracing problem
facebook/flava-image-codebook,flava_image_codebook,skip,Load problem
facebook/levit-128S,levit,xfail,Tracing problem
facebook/m2m100_418M,m2m_100
facebook/mask2former-swin-base-coco-panoptic,mask2former
facebook/maskformer-swin-base-coco,maskformer
facebook/mbart-large-50-many-to-many-mmt,mbart
facebook/mms-tts-eng,vits,xfail,Accuracy failed: results cannot be broadcasted
facebook/musicgen-small,musicgen
facebook/opt-125m,opt
facebook/rag-token-nq,rag,skip,Load problem
facebook/sam-vit-large,sam,xfail,No node with name original_sizes
facebook/timesformer-base-finetuned-k400,timesformer
facebook/vit-mae-base,vit_mae,xfail,Accuracy validation failed
facebook/vit-mae-base,vit_mae
facebook/wmt19-ru-en,fsmt,xfail,Tracing problem
facebook/xglm-7.5B,xglm
facebook/xlm-roberta-xl,xlm-roberta-xl
facebook/xmod-base,xmod
flax-community/ft5-cnn-dm,f_t5,skip,Load problem
Expand All @@ -103,6 +109,7 @@ google/fnet-base,fnet,xfail,Unsupported op aten::fft_fftn aten::real
google/mobilebert-uncased,mobilebert
google/mobilenet_v1_0.75_192,mobilenet_v1
google/mt5-base,mt5
google/owlv2-base-patch16-ensemble,owlv2
google/owlvit-base-patch32,owlvit
google/pix2struct-docvqa-base,pix2struct
google/realm-orqa-nq-openqa,realm,skip,Load problem
Expand All @@ -116,7 +123,6 @@ Graphcore/groupbert-base-uncased,groupbert,skip,Load problem
haoranzhao419/saffu-100M-0.1,saffu-100M-0.1,skip,Load problem
Helsinki-NLP/opus-mt-fr-en,marian
#hf-internal-testing/random-nllb-moe-2-experts,nllb-moe,skip,Load problem
hf-internal-testing/tiny-random-BlenderbotModel,blenderbot,skip,Load problem
hf-internal-testing/tiny-random-CodeGenModel,codegen
hf-internal-testing/tiny-random-convnext,convnext
hf-internal-testing/tiny-random-CvtModel,cvt
Expand All @@ -133,32 +139,24 @@ hf-internal-testing/tiny-random-GPTJModel,gptj
hf-internal-testing/tiny-random-groupvit,groupvit
hf-internal-testing/tiny-random-IBertModel,ibert
hf-internal-testing/tiny-random-ImageGPTModel,imagegpt
hf-internal-testing/tiny-random-LevitModel,levit,skip,Load problem
hf-internal-testing/tiny-random-LiltModel,lilt
hf-internal-testing/tiny-random-LongT5Model,longt5,skip,Load problem
hf-internal-testing/tiny-random-mbart,mbart,xfail,Compile error: CPU plug-in doesnt support Squeeze operation with dynamic rank
hf-internal-testing/tiny-random-MobileNetV2Model,mobilenet_v2
hf-internal-testing/tiny-random-mobilevit,mobilevit
hf-internal-testing/tiny-random-MPNetModel,mpnet
hf-internal-testing/tiny-random-MptForCausalLM,mpt
hf-internal-testing/tiny-random-NllbMoeForConditionalGeneration,nllb_moe,skip,Load problem
hf-internal-testing/tiny-random-NystromformerModel,nystromformer
hf-internal-testing/tiny-random-PegasusModel,pegasus,skip,Load problem
hf-internal-testing/tiny-random-PoolFormerModel,poolformer,skip,Load problem
hf-internal-testing/tiny-random-RegNetModel,regnet
hf-internal-testing/tiny-random-RemBertModel,rembert
hf-internal-testing/tiny-random-RoCBertModel,roc_bert,skip,Load problem
hf-internal-testing/tiny-random-RoFormerModel,roformer
hf-internal-testing/tiny-random-SegformerModel,segformer
hf-internal-testing/tiny-random-SEWDModel,sew-d,skip,Load problem
hf-internal-testing/tiny-random-SEWModel,sew,skip,Load problem
hf-internal-testing/tiny-random-Speech2TextModel,speech_to_text,skip,Load problem
hf-internal-testing/tiny-random-speech-encoder-decoder,speech-encoder-decoder,skip,Load problem
hf-internal-testing/tiny-random-SplinterModel,splinter
hf-internal-testing/tiny-random-SqueezeBertModel,squeezebert
hf-internal-testing/tiny-random-SwinModel,swin
hf-internal-testing/tiny-random-unispeech,unispeech,skip,Load problem
hf-internal-testing/tiny-random-UniSpeechSatModel,unispeech-sat,skip,Load problem
hf-internal-testing/tiny-random-vision_perceiver_conv,perceiver
hf-internal-testing/tiny-random-ViTMSNModel,vit_msn
hf-internal-testing/tiny-random-wav2vec2-conformer,wav2vec2-conformer
Expand Down Expand Up @@ -200,8 +198,6 @@ KBLab/megatron-bert-large-swedish-cased-110k,megatron-bert
kiddothe2b/hierarchical-transformer-base-4096-v2,hat,skip,Load problem
k-l-lambda/clip-text-generator,clip_text_generator,skip,Load problem
k-l-lambda/stable-diffusion-v1-4-inv-embed,inv_word_embed,skip,Load problem
KoboldAI/fairseq-dense-13B-Janeway,xglm,skip,Large Model
konverner/qdq-camembert-apolliner,qdqbert,xfail,Repository not found
krasserm/perceiver-ar-clm-base,perceiver-ar-causal-language-model,skip,Load problem
krasserm/perceiver-ar-sam-giant-midi,perceiver-ar-symbolic-audio-model,skip,Load problem
krasserm/perceiver-io-img-clf,perceiver-io-image-classifier,skip,Load problem
Expand Down Expand Up @@ -243,20 +239,25 @@ microsoft/biogpt,biogpt
microsoft/conditional-detr-resnet-50,conditional_detr
microsoft/deberta-base,deberta
microsoft/git-large-coco,git,xfail,Tracing error: Please check correctness of provided example_input (but eval was correct)
microsoft/kosmos-2-patch14-224,kosmos-2
microsoft/layoutlm-base-uncased,layoutlm
microsoft/layoutlmv2-base-uncased,layoutlmv2,xfail,Tracing error: Please check correctness of provided example_input (but eval was correct)
microsoft/layoutlmv3-base,layoutlmv3
microsoft/markuplm-base,markuplm
microsoft/prophetnet-large-uncased-squad-qg,prophetnet
microsoft/resnet-50,resnet
microsoft/speecht5_hifigan,hifigan,skip,Load problem
microsoft/speecht5_tts,speecht5,xfail,Tracing error: hangs with no error (probably because of infinite while inside generate)
microsoft/swinv2-tiny-patch4-window8-256,swinv2
microsoft/table-transformer-detection,table-transformer
microsoft/unispeech-1350-en-17h-ky-ft-1h,unispeech
microsoft/unispeech-sat-base-100h-libri-ft,unispeech-sat
microsoft/wavlm-large,wavlm,skip,Load problem
microsoft/xclip-base-patch32,xclip
microsoft/xprophetnet-large-wiki100-cased,xlm-prophetnet
miguelvictor/python-fromzero-lstmlm,lstmlm,skip,Load problem
mingzi151/test-hf-wav2vec2bert,wav2vec2bert,skip,Load problem
mistralai/Mistral-7B-v0.1,mistral
MIT/ast-finetuned-audioset-10-10-0.4593,audio-spectrogram-transformer
Mizuiro-sakura/luke-japanese-large-sentiment-analysis-wrime,luke
mlml-chip/thyme2_colon_e2e,cnlpt,skip,Load problem
Expand Down Expand Up @@ -306,6 +307,7 @@ pleisto/yuren-baichuan-7b,multimodal_llama,skip,Load problem
predictia/europe_reanalysis_downscaler_convbaseline,convbilinear,skip,Load problem
predictia/europe_reanalysis_downscaler_convswin2sr,conv_swin2sr,skip,Load problem
pszemraj/led-large-book-summary,led
pszemraj/pegasus-x-large-book-summary,pegasus_x
qmeeus/whisper-small-ner-combined,whisper_for_slu,skip,Load problem
raman-ai/pcqv2-tokengt-lap16,tokengt,skip,Load problem
range3/pegasus-gpt2-medium,pegasusgpt2,skip,Load problem
Expand All @@ -316,6 +318,7 @@ RUCAIBox/mass-base-uncased,mass,skip,Load problem
RWKV/rwkv-4-169m-pile,rwkv
sahasrarjn/interbert,BERT,skip,Load problem
saibo/genkalm-medium-gpt2,genkalm,skip,Load problem
sail/poolformer_m36,poolformer
SajjadAyoubi/clip-fa-vision,clip_vision_model
Salesforce/blip2-flan-t5-xl:vision_model,blip-2
Salesforce/blip2-flan-t5-xl:qformer,blip-2
Expand Down Expand Up @@ -354,6 +357,7 @@ SteveZhan/my-resnet50d,resnet_steve,skip,Load problem
suno/bark,bark,skip,Load problem
surajnair/r3m-50,r3m,skip,Load problem
susnato/clvp_dev,clvp,skip,Load problem
susnato/phi-1_5_dev,phi
Tanrei/GPTSAN-japanese,gptsan-japanese,xfail,Unsupported op aten::index_put_ prim::TupleConstruct
tau/bart-large-sled-govreport,tau/sled,skip,Load problem
taufeeque/best-cb-model,codebook,skip,Load problem
Expand All @@ -375,20 +379,19 @@ transZ/tforge_v1.9,Transformer_Forge,skip,Load problem
trl-internal-testing/tiny-random-BigBirdPegasusForConditionalGeneration,bigbird_pegasus
trl-internal-testing/tiny-random-BlenderbotSmallForConditionalGeneration,blenderbot-small,skip,Load problem
trl-internal-testing/tiny-random-MvpForConditionalGeneration,mvp
trl-internal-testing/tiny-random-PegasusXForConditionalGeneration,pegasus_x,skip,Load problem
trl-internal-testing/tiny-random-PLBartForConditionalGeneration,plbart,xfail,Compile error: CPU plug-in doesnt support Squeeze operation with dynamic rank
trl-internal-testing/tiny-random-ProphetNetForConditionalGeneration,prophetnet,skip,Load problem
trl-internal-testing/tiny-random-SwitchTransformersForConditionalGeneration,switch_transformers,skip,Load problem
tuner007/pegasus_paraphrase,pegasus
turing-motors/heron-chat-blip-ja-stablelm-base-7b-v0,video_blip,skip,Load problem
turing-motors/heron-chat-git-ELYZA-fast-7b-v0,git_llama,skip,Load problem
uclanlp/plbart-base,plbart
uclanlp/visualbert-vqa-coco-pre,visual_bert
ummagumm-a/samolet_room_classifier,AirModelHF,skip,Load problem
ummagumm-a/samolet-room-classifier,gru,skip,Load problem
UNCANNY69/Misinfo-BERT-LSTM,BertLSTMForSequenceClassification,skip,Load problem
UNCANNY69/Miss-BERT-CNN,BertCNNForSequenceClassification,skip,Load problem
unc-nlp/lxmert-base-uncased,lxmert,skip,Load problem
uw-madison/mra-base-512-4,mra
uw-madison/yoso-4096,yoso,xfail,Compile error: CPU plug-in doesnt support Squeeze operation with dynamic rank
uw-madison/yoso-4096,yoso
valhalla/cogview-gpt2-test,cog_view,skip,Load problem
valhalla/s2t_mustc_multilinguial_medium,speech_to_text_transformer,skip,Load problem
vblagoje/greaselm-csqa,greaselm,skip,Load problem
Expand All @@ -399,12 +402,12 @@ visualjoyce/transformers4vl-vilbert-mt,vilbert,skip,Load problem
vumichien/nonsemantic-speech-trillsson3,trillsson_efficientnet,skip,Load problem
vumichien/trillsson3-ft-keyword-spotting-12,trillsson_efficient,skip,Load problem
wangruiai2023/nougat,nougat,skip,Load problem
weiweishi/roc-bert-base-zh,roc_bert
WENGSYX/CoNN_Parity,conn,skip,Load problem
xlm-roberta-base,xlm-roberta
xlnet-base-cased,xlnet
ybelkada/focusondepth,focusondepth,skip,Load problem
ybelkada/random-tiny-BertGenerationModel,bert-generation
ydshieh/temp-testing-kosmos-2,kosmos-2,skip,Load problem
YituTech/conv-bert-base,convbert
yjernite/retribert-base-uncased,retribert,xfail,Unsupported op aten::cross_entropy_loss
ylacombe/hf-seamless-m4t-medium,seamless_m4t,skip,Load problem
Expand Down
10 changes: 10 additions & 0 deletions tests/model_hub_tests/pytorch/test_hf_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,13 @@ def forward(self, flattened_patches, attention_mask):
pad_token_id = model.generation_config.pad_token_id
example["decoder_input_ids"] = torch.ones(
(inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long) * pad_token_id
elif 'kosmos-2' in mi.tags:
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained(name)

prompt = "<grounding>An image of"
inputs = processor(text=prompt, images=self.image, return_tensors="pt")
example = dict(inputs)
else:
try:
if auto_model == "AutoModelForCausalLM":
Expand Down Expand Up @@ -530,6 +537,9 @@ def forward(self, flattened_patches, attention_mask):
else:
example = (torch.randint(1, 1000, [1, 100]),)
self.example = filter_example(model, example)
if "vit_mae" in mi.tags:
# vit-mae by default will generate random noise
self.example["noise"] = torch.rand(1, 192)
model.eval()
# do first inference
if isinstance(self.example, dict):
Expand Down

0 comments on commit fba9385

Please sign in to comment.