diff --git a/src/frontends/pytorch/src/op/copy.cpp b/src/frontends/pytorch/src/op/copy.cpp index 5f011ce5a7a64c..4494f34b2b2f84 100644 --- a/src/frontends/pytorch/src/op/copy.cpp +++ b/src/frontends/pytorch/src/op/copy.cpp @@ -29,6 +29,22 @@ OutputVector translate_copy_(const NodeContext& context) { return {res}; }; +OutputVector translate_copy_fx(const NodeContext& context) { + // copy = torch.ops.aten.copy.default(slice_4); + // copy = torch.ops.aten.copy.default(slice_4, clone); + num_inputs_check(context, 1, 2); + auto self = context.get_input(0); + if (context.input_is_none(1)) { + return {self}; + } else { + auto src = context.get_input(1); + auto src_converted = context.mark_node(std::make_shared(src, self)); + auto self_shape = context.mark_node(std::make_shared(self)); + Output res = context.mark_node(std::make_shared(src_converted, self_shape)); + return {res}; + } +}; + OutputVector translate_alias_copy(const NodeContext& context) { // aten::alias_copy(Tensor self) -> Tensor // aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) diff --git a/src/frontends/pytorch/src/op/rand.cpp b/src/frontends/pytorch/src/op/rand.cpp index 6348182a934492..74173ff77b24e6 100644 --- a/src/frontends/pytorch/src/op/rand.cpp +++ b/src/frontends/pytorch/src/op/rand.cpp @@ -45,7 +45,7 @@ OutputVector make_random_normal(const NodeContext& context, }; // namespace OutputVector translate_rand(const NodeContext& context) { - num_inputs_check(context, 2, 6); + num_inputs_check(context, 1, 6); auto sizes = context.get_input(0); if (context.get_input_type(0).is()) { sizes = concat_list_construct(sizes); @@ -57,14 +57,15 @@ OutputVector translate_rand(const NodeContext& context) { size_t out_id = 1; if (context.get_input_size() == 3) { PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(1), - "aten::randn conversion with generator does not supported"); + "aten::rand conversion with generator does not supported"); out_id = 2; } // aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) // aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) - if (context.get_input_size() == 2 || context.get_input_size() == 3) { + if (context.get_input_size() <= 3) { auto res = context.mark_node(std::make_shared(sizes, low, high, dtype)); - context.mutate_input(out_id, res); + if (context.get_input_size() >= 2) + context.mutate_input(out_id, res); return {res}; } // aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 917e99f804bfc1..aa10e8d602e752 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -253,6 +253,7 @@ OP_CONVERTER(translate_batch_norm_legit_no_training_fx); OP_CONVERTER(translate_batch_norm_legit_no_stats_fx); OP_CONVERTER(translate_cat_fx); OP_CONVERTER(translate_constant_pad_nd_fx); +OP_CONVERTER(translate_copy_fx); OP_CONVERTER(translate_cumsum_fx); OP_CONVERTER(translate_chunk_fx); OP_CONVERTER(translate_div_fx); @@ -781,7 +782,7 @@ const std::map get_supported_ops_fx() { {"aten.clone.default", op::skip_node}, // ignore clone operators that are inserted by PyTorch autograd {"aten.constant_pad_nd.default", op::translate_constant_pad_nd_fx}, {"aten.convolution.default", op::translate_convolution}, - {"aten.copy.default", op::skip_node}, + {"aten.copy.default", op::translate_copy_fx}, {"aten.copy_.default", op::translate_copy_}, {"aten.cos.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten.cosh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, @@ -888,6 +889,7 @@ const std::map get_supported_ops_fx() { {"aten.pow.Tensor_Tensor", op::translate_pow}, {"aten.pixel_shuffle.default", op::translate_pixel_shuffle}, {"aten.pixel_unshuffle.default", op::translate_pixel_unshuffle}, + {"aten.rand.default", op::translate_rand}, {"aten.reciprocal.default", op::translate_reciprocal}, {"aten.reflection_pad1d.default", op::translate_reflection_pad_nd_fx}, {"aten.reflection_pad2d.default", op::translate_reflection_pad_nd_fx}, diff --git a/tests/layer_tests/pytorch_tests/test_aliases.py b/tests/layer_tests/pytorch_tests/test_aliases.py index e6ce36ec88f18a..c90d2b929839c9 100644 --- a/tests/layer_tests/pytorch_tests/test_aliases.py +++ b/tests/layer_tests/pytorch_tests/test_aliases.py @@ -14,6 +14,15 @@ def forward(self, x): return y +class aten_alias_tensor(torch.nn.Module): + def forward(self, x): + y = x.clone() + n,c,h,w = x.shape + ones = torch.ones([2,h,w]).to(x.dtype) + y[:, 1:, :, :] = ones + return y + + class aten_loop_alias(torch.nn.Module): def forward(self, x): y = x.clone() @@ -36,6 +45,14 @@ def test_alias(self, ie_device, precision, ir_version): "aten::copy_"], ie_device, precision, ir_version) + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.precommit_torch_export + def test_alias_tensor(self, ie_device, precision, ir_version): + self._test(aten_alias_tensor(), None, ["aten::slice", + "aten::copy_"], + ie_device, precision, ir_version, freeze_model=False) + @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.precommit_torch_export diff --git a/tests/model_hub_tests/pytorch/hf_transformers_models b/tests/model_hub_tests/pytorch/hf_transformers_models index b0e9000548a769..adbbfc468b371e 100644 --- a/tests/model_hub_tests/pytorch/hf_transformers_models +++ b/tests/model_hub_tests/pytorch/hf_transformers_models @@ -19,6 +19,7 @@ anugunj/omnivore-swinL-in21k,omnivore,skip,Load problem apple/mobilevitv2-1.0-imagenet1k-256,mobilevitv2,xfail,Unsupported op aten::col2im ArthurZ/jukebox_prior_0,jukebox_prior,skip,Load problem ArthurZ/jukebox-vqvae,jukebox_vqvae,skip,Load problem +asapp/sew-d-base-plus-400k-ft-ls100h,sew-d ashishpatel26/span-marker-bert-base-fewnerd-coarse-super,span-marker,skip,Load problem asi/albert-act-tiny,albert_act,skip,Load problem BAAI/AltCLIP,altclip @@ -63,6 +64,8 @@ EleutherAI/enformer-official-rough,enformer,skip,Load problem EleutherAI/gpt-neo-125m,gpt_neo EleutherAI/pythia-6.9b,gpt_neox facebook/bart-large-mnli,bart +facebook/blenderbot-400M-distill,blenderbot +facebook/blenderbot_small-90M,blenderbot-small facebook/convnextv2-tiny-22k-384,convnextv2 facebook/detr-resnet-50,detr facebook/dinov2-base,dinov2 @@ -71,17 +74,20 @@ facebook/encodec_24khz,encodec facebook/esm2_t6_8M_UR50D,esm facebook/flava-full,flava,xfail,Tracing problem facebook/flava-image-codebook,flava_image_codebook,skip,Load problem +facebook/levit-128S,levit,xfail,Tracing problem facebook/m2m100_418M,m2m_100 facebook/mask2former-swin-base-coco-panoptic,mask2former facebook/maskformer-swin-base-coco,maskformer +facebook/mbart-large-50-many-to-many-mmt,mbart facebook/mms-tts-eng,vits,xfail,Accuracy failed: results cannot be broadcasted facebook/musicgen-small,musicgen facebook/opt-125m,opt facebook/rag-token-nq,rag,skip,Load problem facebook/sam-vit-large,sam,xfail,No node with name original_sizes facebook/timesformer-base-finetuned-k400,timesformer -facebook/vit-mae-base,vit_mae,xfail,Accuracy validation failed +facebook/vit-mae-base,vit_mae facebook/wmt19-ru-en,fsmt,xfail,Tracing problem +facebook/xglm-7.5B,xglm facebook/xlm-roberta-xl,xlm-roberta-xl facebook/xmod-base,xmod flax-community/ft5-cnn-dm,f_t5,skip,Load problem @@ -103,6 +109,7 @@ google/fnet-base,fnet,xfail,Unsupported op aten::fft_fftn aten::real google/mobilebert-uncased,mobilebert google/mobilenet_v1_0.75_192,mobilenet_v1 google/mt5-base,mt5 +google/owlv2-base-patch16-ensemble,owlv2 google/owlvit-base-patch32,owlvit google/pix2struct-docvqa-base,pix2struct google/realm-orqa-nq-openqa,realm,skip,Load problem @@ -116,7 +123,6 @@ Graphcore/groupbert-base-uncased,groupbert,skip,Load problem haoranzhao419/saffu-100M-0.1,saffu-100M-0.1,skip,Load problem Helsinki-NLP/opus-mt-fr-en,marian #hf-internal-testing/random-nllb-moe-2-experts,nllb-moe,skip,Load problem -hf-internal-testing/tiny-random-BlenderbotModel,blenderbot,skip,Load problem hf-internal-testing/tiny-random-CodeGenModel,codegen hf-internal-testing/tiny-random-convnext,convnext hf-internal-testing/tiny-random-CvtModel,cvt @@ -133,32 +139,24 @@ hf-internal-testing/tiny-random-GPTJModel,gptj hf-internal-testing/tiny-random-groupvit,groupvit hf-internal-testing/tiny-random-IBertModel,ibert hf-internal-testing/tiny-random-ImageGPTModel,imagegpt -hf-internal-testing/tiny-random-LevitModel,levit,skip,Load problem hf-internal-testing/tiny-random-LiltModel,lilt hf-internal-testing/tiny-random-LongT5Model,longt5,skip,Load problem -hf-internal-testing/tiny-random-mbart,mbart,xfail,Compile error: CPU plug-in doesnt support Squeeze operation with dynamic rank hf-internal-testing/tiny-random-MobileNetV2Model,mobilenet_v2 hf-internal-testing/tiny-random-mobilevit,mobilevit hf-internal-testing/tiny-random-MPNetModel,mpnet hf-internal-testing/tiny-random-MptForCausalLM,mpt hf-internal-testing/tiny-random-NllbMoeForConditionalGeneration,nllb_moe,skip,Load problem hf-internal-testing/tiny-random-NystromformerModel,nystromformer -hf-internal-testing/tiny-random-PegasusModel,pegasus,skip,Load problem -hf-internal-testing/tiny-random-PoolFormerModel,poolformer,skip,Load problem hf-internal-testing/tiny-random-RegNetModel,regnet hf-internal-testing/tiny-random-RemBertModel,rembert -hf-internal-testing/tiny-random-RoCBertModel,roc_bert,skip,Load problem hf-internal-testing/tiny-random-RoFormerModel,roformer hf-internal-testing/tiny-random-SegformerModel,segformer -hf-internal-testing/tiny-random-SEWDModel,sew-d,skip,Load problem hf-internal-testing/tiny-random-SEWModel,sew,skip,Load problem hf-internal-testing/tiny-random-Speech2TextModel,speech_to_text,skip,Load problem hf-internal-testing/tiny-random-speech-encoder-decoder,speech-encoder-decoder,skip,Load problem hf-internal-testing/tiny-random-SplinterModel,splinter hf-internal-testing/tiny-random-SqueezeBertModel,squeezebert hf-internal-testing/tiny-random-SwinModel,swin -hf-internal-testing/tiny-random-unispeech,unispeech,skip,Load problem -hf-internal-testing/tiny-random-UniSpeechSatModel,unispeech-sat,skip,Load problem hf-internal-testing/tiny-random-vision_perceiver_conv,perceiver hf-internal-testing/tiny-random-ViTMSNModel,vit_msn hf-internal-testing/tiny-random-wav2vec2-conformer,wav2vec2-conformer @@ -200,8 +198,6 @@ KBLab/megatron-bert-large-swedish-cased-110k,megatron-bert kiddothe2b/hierarchical-transformer-base-4096-v2,hat,skip,Load problem k-l-lambda/clip-text-generator,clip_text_generator,skip,Load problem k-l-lambda/stable-diffusion-v1-4-inv-embed,inv_word_embed,skip,Load problem -KoboldAI/fairseq-dense-13B-Janeway,xglm,skip,Large Model -konverner/qdq-camembert-apolliner,qdqbert,xfail,Repository not found krasserm/perceiver-ar-clm-base,perceiver-ar-causal-language-model,skip,Load problem krasserm/perceiver-ar-sam-giant-midi,perceiver-ar-symbolic-audio-model,skip,Load problem krasserm/perceiver-io-img-clf,perceiver-io-image-classifier,skip,Load problem @@ -243,20 +239,25 @@ microsoft/biogpt,biogpt microsoft/conditional-detr-resnet-50,conditional_detr microsoft/deberta-base,deberta microsoft/git-large-coco,git,xfail,Tracing error: Please check correctness of provided example_input (but eval was correct) +microsoft/kosmos-2-patch14-224,kosmos-2 microsoft/layoutlm-base-uncased,layoutlm microsoft/layoutlmv2-base-uncased,layoutlmv2,xfail,Tracing error: Please check correctness of provided example_input (but eval was correct) microsoft/layoutlmv3-base,layoutlmv3 microsoft/markuplm-base,markuplm +microsoft/prophetnet-large-uncased-squad-qg,prophetnet microsoft/resnet-50,resnet microsoft/speecht5_hifigan,hifigan,skip,Load problem microsoft/speecht5_tts,speecht5,xfail,Tracing error: hangs with no error (probably because of infinite while inside generate) microsoft/swinv2-tiny-patch4-window8-256,swinv2 microsoft/table-transformer-detection,table-transformer +microsoft/unispeech-1350-en-17h-ky-ft-1h,unispeech +microsoft/unispeech-sat-base-100h-libri-ft,unispeech-sat microsoft/wavlm-large,wavlm,skip,Load problem microsoft/xclip-base-patch32,xclip microsoft/xprophetnet-large-wiki100-cased,xlm-prophetnet miguelvictor/python-fromzero-lstmlm,lstmlm,skip,Load problem mingzi151/test-hf-wav2vec2bert,wav2vec2bert,skip,Load problem +mistralai/Mistral-7B-v0.1,mistral MIT/ast-finetuned-audioset-10-10-0.4593,audio-spectrogram-transformer Mizuiro-sakura/luke-japanese-large-sentiment-analysis-wrime,luke mlml-chip/thyme2_colon_e2e,cnlpt,skip,Load problem @@ -306,6 +307,7 @@ pleisto/yuren-baichuan-7b,multimodal_llama,skip,Load problem predictia/europe_reanalysis_downscaler_convbaseline,convbilinear,skip,Load problem predictia/europe_reanalysis_downscaler_convswin2sr,conv_swin2sr,skip,Load problem pszemraj/led-large-book-summary,led +pszemraj/pegasus-x-large-book-summary,pegasus_x qmeeus/whisper-small-ner-combined,whisper_for_slu,skip,Load problem raman-ai/pcqv2-tokengt-lap16,tokengt,skip,Load problem range3/pegasus-gpt2-medium,pegasusgpt2,skip,Load problem @@ -316,6 +318,7 @@ RUCAIBox/mass-base-uncased,mass,skip,Load problem RWKV/rwkv-4-169m-pile,rwkv sahasrarjn/interbert,BERT,skip,Load problem saibo/genkalm-medium-gpt2,genkalm,skip,Load problem +sail/poolformer_m36,poolformer SajjadAyoubi/clip-fa-vision,clip_vision_model Salesforce/blip2-flan-t5-xl:vision_model,blip-2 Salesforce/blip2-flan-t5-xl:qformer,blip-2 @@ -354,6 +357,7 @@ SteveZhan/my-resnet50d,resnet_steve,skip,Load problem suno/bark,bark,skip,Load problem surajnair/r3m-50,r3m,skip,Load problem susnato/clvp_dev,clvp,skip,Load problem +susnato/phi-1_5_dev,phi Tanrei/GPTSAN-japanese,gptsan-japanese,xfail,Unsupported op aten::index_put_ prim::TupleConstruct tau/bart-large-sled-govreport,tau/sled,skip,Load problem taufeeque/best-cb-model,codebook,skip,Load problem @@ -375,12 +379,11 @@ transZ/tforge_v1.9,Transformer_Forge,skip,Load problem trl-internal-testing/tiny-random-BigBirdPegasusForConditionalGeneration,bigbird_pegasus trl-internal-testing/tiny-random-BlenderbotSmallForConditionalGeneration,blenderbot-small,skip,Load problem trl-internal-testing/tiny-random-MvpForConditionalGeneration,mvp -trl-internal-testing/tiny-random-PegasusXForConditionalGeneration,pegasus_x,skip,Load problem -trl-internal-testing/tiny-random-PLBartForConditionalGeneration,plbart,xfail,Compile error: CPU plug-in doesnt support Squeeze operation with dynamic rank -trl-internal-testing/tiny-random-ProphetNetForConditionalGeneration,prophetnet,skip,Load problem trl-internal-testing/tiny-random-SwitchTransformersForConditionalGeneration,switch_transformers,skip,Load problem +tuner007/pegasus_paraphrase,pegasus turing-motors/heron-chat-blip-ja-stablelm-base-7b-v0,video_blip,skip,Load problem turing-motors/heron-chat-git-ELYZA-fast-7b-v0,git_llama,skip,Load problem +uclanlp/plbart-base,plbart uclanlp/visualbert-vqa-coco-pre,visual_bert ummagumm-a/samolet_room_classifier,AirModelHF,skip,Load problem ummagumm-a/samolet-room-classifier,gru,skip,Load problem @@ -388,7 +391,7 @@ UNCANNY69/Misinfo-BERT-LSTM,BertLSTMForSequenceClassification,skip,Load problem UNCANNY69/Miss-BERT-CNN,BertCNNForSequenceClassification,skip,Load problem unc-nlp/lxmert-base-uncased,lxmert,skip,Load problem uw-madison/mra-base-512-4,mra -uw-madison/yoso-4096,yoso,xfail,Compile error: CPU plug-in doesnt support Squeeze operation with dynamic rank +uw-madison/yoso-4096,yoso valhalla/cogview-gpt2-test,cog_view,skip,Load problem valhalla/s2t_mustc_multilinguial_medium,speech_to_text_transformer,skip,Load problem vblagoje/greaselm-csqa,greaselm,skip,Load problem @@ -399,12 +402,12 @@ visualjoyce/transformers4vl-vilbert-mt,vilbert,skip,Load problem vumichien/nonsemantic-speech-trillsson3,trillsson_efficientnet,skip,Load problem vumichien/trillsson3-ft-keyword-spotting-12,trillsson_efficient,skip,Load problem wangruiai2023/nougat,nougat,skip,Load problem +weiweishi/roc-bert-base-zh,roc_bert WENGSYX/CoNN_Parity,conn,skip,Load problem xlm-roberta-base,xlm-roberta xlnet-base-cased,xlnet ybelkada/focusondepth,focusondepth,skip,Load problem ybelkada/random-tiny-BertGenerationModel,bert-generation -ydshieh/temp-testing-kosmos-2,kosmos-2,skip,Load problem YituTech/conv-bert-base,convbert yjernite/retribert-base-uncased,retribert,xfail,Unsupported op aten::cross_entropy_loss ylacombe/hf-seamless-m4t-medium,seamless_m4t,skip,Load problem diff --git a/tests/model_hub_tests/pytorch/test_hf_transformers.py b/tests/model_hub_tests/pytorch/test_hf_transformers.py index 1c4908d3f51994..94759c05265ee9 100644 --- a/tests/model_hub_tests/pytorch/test_hf_transformers.py +++ b/tests/model_hub_tests/pytorch/test_hf_transformers.py @@ -424,6 +424,13 @@ def forward(self, flattened_patches, attention_mask): pad_token_id = model.generation_config.pad_token_id example["decoder_input_ids"] = torch.ones( (inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long) * pad_token_id + elif 'kosmos-2' in mi.tags: + from transformers import AutoProcessor + processor = AutoProcessor.from_pretrained(name) + + prompt = "An image of" + inputs = processor(text=prompt, images=self.image, return_tensors="pt") + example = dict(inputs) else: try: if auto_model == "AutoModelForCausalLM": @@ -530,6 +537,9 @@ def forward(self, flattened_patches, attention_mask): else: example = (torch.randint(1, 1000, [1, 100]),) self.example = filter_example(model, example) + if "vit_mae" in mi.tags: + # vit-mae by default will generate random noise + self.example["noise"] = torch.rand(1, 192) model.eval() # do first inference if isinstance(self.example, dict):