diff --git a/docs/openclip_results.csv b/docs/openclip_results.csv index f901e0ecc..8d27ffd32 100644 --- a/docs/openclip_results.csv +++ b/docs/openclip_results.csv @@ -1,35 +1,50 @@ name,pretrained,Average perf. on 38 datasets,ImageNet 1k,Caltech-101,CIFAR-10,CIFAR-100,CLEVR Counts,CLEVR Distance,Country211,Describable Textures,EuroSAT,FGVC Aircraft,Food-101,GTSRB,ImageNet Sketch,ImageNet v2,ImageNet-A,ImageNet-O,ImageNet-R,KITTI Vehicle Distance,MNIST,ObjectNet,Oxford Flowers-102,Oxford-IIIT Pet,Pascal VOC 2007,PatchCamelyon,Rendered SST2,RESISC45,Stanford Cars,STL-10,SUN397,SVHN,Flickr,MSCOCO,WinoGAViL,iWildCam,Camelyon17,FMoW,Dollar Street,GeoDE EVA02-E-14-plus,laion2b_s9b_b144k,0.6930,0.8201,0.9535,0.9934,0.9316,0.2991,0.1998,0.3564,0.6777,0.7574,0.5360,0.9496,0.6740,0.7162,0.7564,0.8223,0.3540,0.9456,0.1842,0.7463,0.7937,0.8433,0.9567,0.8569,0.6442,0.6271,0.7490,0.9457,0.9926,0.7510,0.7560,0.8648,0.5991,0.4403,0.2591,0.6948,0.2668,0.6951,0.9244 +ViT-SO400M-14-SigLIP-384,webli,0.6921,0.8308,0.9599,0.9672,0.8357,0.4071,0.2246,0.3645,0.7303,0.6354,0.6069,0.9635,0.6429,0.7454,0.7717,0.8247,0.2775,0.9575,0.2082,0.8862,0.7694,0.9114,0.9680,0.7171,0.5268,0.7002,0.7211,0.9521,0.9930,0.7541,0.5151,0.8863,0.6331,0.5754,0.2294,0.6149,0.3309,0.7301,0.9328 +ViT-SO400M-14-SigLIP,webli,0.6808,0.8203,0.9600,0.9679,0.8417,0.4210,0.2213,0.3243,0.7106,0.6274,0.6029,0.9556,0.6382,0.7402,0.7607,0.7185,0.2960,0.9506,0.2489,0.8929,0.7061,0.8982,0.9522,0.7034,0.5057,0.6936,0.7257,0.9032,0.9939,0.7436,0.5670,0.8313,0.6071,0.5665,0.1915,0.6215,0.3163,0.7173,0.9278 +ViT-bigG-14-CLIPA-336,datacomp1b,0.6789,0.8230,0.9516,0.9901,0.9099,0.1593,0.2033,0.4041,0.7362,0.6398,0.5401,0.9591,0.6093,0.7359,0.7654,0.8536,0.3025,0.9489,0.3024,0.8513,0.7940,0.8715,0.9545,0.8210,0.5471,0.5261,0.7076,0.9517,0.9966,0.7524,0.6740,0.8450,0.5682,0.4600,0.2363,0.5842,0.1820,0.6998,0.9424 EVA02-E-14,laion2b_s4b_b115k,0.6690,0.8196,0.9541,0.9925,0.9258,0.1632,0.2499,0.3482,0.6878,0.7446,0.4892,0.9523,0.6729,0.7151,0.7566,0.8044,0.3340,0.9407,0.1294,0.7581,0.7674,0.8210,0.9569,0.8136,0.4972,0.5859,0.7324,0.9438,0.9926,0.7658,0.6381,0.8515,0.5892,0.4429,0.2289,0.4894,0.2801,0.6682,0.9182 +ViT-L-16-SigLIP-384,webli,0.6683,0.8207,0.9611,0.9605,0.8188,0.3275,0.2077,0.2470,0.7080,0.5817,0.5312,0.9564,0.6385,0.7360,0.7593,0.7663,0.3130,0.9507,0.2222,0.8525,0.7284,0.8934,0.9681,0.7172,0.5466,0.5634,0.6789,0.9493,0.9924,0.7250,0.5672,0.8756,0.6290,0.5550,0.2236,0.6637,0.1489,0.6916,0.9207 +ViT-H-14-CLIPA-336,datacomp1b,0.6677,0.8180,0.9467,0.9890,0.8968,0.1326,0.2254,0.3551,0.7197,0.6604,0.4718,0.9572,0.5816,0.7282,0.7562,0.8275,0.3115,0.9438,0.2574,0.8245,0.7742,0.8463,0.9573,0.8134,0.4979,0.6052,0.7114,0.9483,0.9955,0.7635,0.6599,0.8356,0.5822,0.4587,0.2239,0.4357,0.2500,0.6822,0.9278 ViT-H-14-quickgelu,metaclip_fullcc,0.6671,0.8051,0.9536,0.9804,0.8634,0.2115,0.1881,0.3716,0.7271,0.6450,0.5114,0.9423,0.6257,0.7052,0.7417,0.7533,0.3040,0.9342,0.2771,0.7266,0.7642,0.8448,0.9561,0.7495,0.6222,0.6925,0.7024,0.8990,0.9944,0.7440,0.5910,0.8507,0.5752,0.5312,0.1680,0.5782,0.2314,0.6811,0.9077 ViT-bigG-14,laion2b_s39b_b160k,0.6667,0.8009,0.9484,0.9824,0.8752,0.2989,0.2002,0.3379,0.6867,0.6919,0.4953,0.9309,0.6244,0.6894,0.7359,0.6933,0.3785,0.9213,0.1308,0.7157,0.7284,0.8163,0.9529,0.8077,0.6364,0.6535,0.7235,0.9460,0.9850,0.7450,0.6961,0.8623,0.5938,0.4488,0.1760,0.5905,0.2352,0.6857,0.9127 +ViT-H-14-CLIPA,datacomp1b,0.6653,0.8152,0.9458,0.9888,0.8991,0.1513,0.2255,0.3401,0.7090,0.7146,0.4751,0.9554,0.5538,0.7272,0.7498,0.7701,0.3135,0.9426,0.2461,0.8189,0.7423,0.8437,0.9559,0.8170,0.4958,0.6189,0.7098,0.9458,0.9948,0.7608,0.6622,0.8344,0.5804,0.4578,0.2160,0.4415,0.2684,0.6694,0.9236 ViT-L-14,datacomp_xl_s13b_b90k,0.6627,0.7921,0.9465,0.9824,0.8736,0.3555,0.2443,0.3157,0.6649,0.7124,0.4750,0.9452,0.5853,0.6795,0.7205,0.6959,0.3255,0.9083,0.2785,0.8661,0.7425,0.8262,0.9506,0.8247,0.5118,0.6101,0.6941,0.9305,0.9925,0.7427,0.6769,0.8119,0.5451,0.4666,0.1614,0.5089,0.2403,0.6624,0.9152 EVA01-g-14-plus,merged2b_s11b_b114k,0.6624,0.7933,0.9506,0.9910,0.9008,0.2302,0.2293,0.3087,0.6734,0.7280,0.3947,0.9366,0.6644,0.6814,0.7214,0.7416,0.3415,0.9246,0.1491,0.7176,0.7491,0.7959,0.9490,0.8285,0.6244,0.5854,0.7079,0.9073,0.9949,0.7426,0.5951,0.8535,0.5925,0.4684,0.1882,0.7100,0.2283,0.6589,0.9148 ViT-L-14-quickgelu,metaclip_fullcc,0.6592,0.7917,0.9527,0.9759,0.8410,0.3107,0.2260,0.3394,0.6862,0.5894,0.4537,0.9352,0.5623,0.6896,0.7256,0.7231,0.3010,0.9205,0.2785,0.6444,0.7457,0.8143,0.9461,0.8030,0.6197,0.6678,0.7360,0.8868,0.9933,0.7355,0.4681,0.8326,0.5576,0.5357,0.1581,0.7551,0.2592,0.6752,0.9140 EVA02-L-14-336,merged2b_s6b_b61k,0.6583,0.8039,0.9525,0.9892,0.8980,0.3635,0.2485,0.3354,0.6473,0.7139,0.3758,0.9421,0.5759,0.6891,0.7380,0.8289,0.2850,0.9324,0.2377,0.6421,0.7789,0.7645,0.9424,0.8267,0.5487,0.6463,0.6910,0.9158,0.9966,0.7480,0.4575,0.8381,0.5605,0.5053,0.2105,0.5691,0.2198,0.6811,0.9136 +ViT-L-14-CLIPA-336,datacomp1b,0.6570,0.8026,0.9439,0.9864,0.8826,0.1566,0.2439,0.3066,0.6856,0.5811,0.4281,0.9456,0.5695,0.7087,0.7346,0.7771,0.3290,0.9329,0.1997,0.7667,0.7317,0.8100,0.9495,0.7979,0.6028,0.5316,0.6884,0.9407,0.9929,0.7560,0.6290,0.8251,0.5640,0.4449,0.1937,0.6783,0.2500,0.6752,0.9240 +ViT-L-16-SigLIP-256,webli,0.6557,0.8045,0.9593,0.9619,0.8191,0.4065,0.2150,0.2141,0.7027,0.5598,0.5259,0.9463,0.6115,0.7209,0.7376,0.6213,0.3265,0.9396,0.1983,0.8499,0.6526,0.8827,0.9604,0.7409,0.5458,0.6172,0.6817,0.9386,0.9911,0.7253,0.5211,0.8542,0.6154,0.5748,0.1796,0.5757,0.1296,0.6904,0.9173 +ViT-L-14-CLIPA,datacomp1b,0.6536,0.7957,0.9453,0.9866,0.8850,0.1857,0.2449,0.2941,0.6963,0.6044,0.4299,0.9415,0.5906,0.7061,0.7305,0.7125,0.3370,0.9288,0.1927,0.7374,0.6988,0.8101,0.9497,0.8067,0.5915,0.5387,0.6843,0.9366,0.9919,0.7528,0.6390,0.8188,0.5604,0.4388,0.1724,0.6760,0.2457,0.6647,0.9152 convnext_xxlarge,laion2b_s34b_b82k_augreg_soup,0.6530,0.7947,0.9448,0.9822,0.8687,0.1454,0.2365,0.3170,0.7053,0.6128,0.4434,0.9321,0.5508,0.6840,0.7260,0.6719,0.4060,0.9160,0.2363,0.8277,0.7273,0.8241,0.9445,0.8090,0.5142,0.6952,0.7190,0.9409,0.9810,0.7458,0.6254,0.8521,0.5867,0.4702,0.1730,0.6071,0.0000,0.6764,0.9215 convnext_xxlarge,laion2b_s34b_b82k_augreg_rewind,0.6521,0.7931,0.9452,0.9823,0.8686,0.1651,0.2534,0.3155,0.7016,0.6331,0.4398,0.9308,0.5491,0.6825,0.7228,0.6657,0.3975,0.9139,0.2419,0.7930,0.7252,0.8241,0.9438,0.8100,0.5014,0.6897,0.7168,0.9406,0.9801,0.7459,0.6137,0.8498,0.5871,0.4741,0.1735,0.6071,0.0000,0.6799,0.9228 xlm-roberta-large-ViT-H-14,frozen_laion5b_s13b_b90k,0.6515,0.7695,0.9422,0.9718,0.8430,0.3358,0.2050,0.3172,0.6926,0.6793,0.4673,0.9236,0.6239,0.6581,0.6944,0.5935,0.3390,0.8940,0.1364,0.7804,0.6911,0.7532,0.9431,0.7995,0.5792,0.6436,0.6825,0.9362,0.9889,0.7551,0.5950,0.8461,0.5758,0.5206,0.1392,0.6749,0.2098,0.6460,0.9111 ViT-L-14,commonpool_xl_clip_s13b_b90k,0.6501,0.7637,0.9502,0.9797,0.8615,0.2547,0.2451,0.2984,0.6521,0.6681,0.3860,0.9355,0.5980,0.6538,0.6953,0.6197,0.3525,0.8924,0.2982,0.9040,0.7165,0.8006,0.9424,0.8336,0.5688,0.6178,0.6978,0.9352,0.9875,0.7351,0.6853,0.7768,0.5156,0.4728,0.1439,0.5100,0.1705,0.6776,0.9056 EVA02-L-14,merged2b_s4b_b131k,0.6488,0.7977,0.9512,0.9908,0.9071,0.3176,0.2462,0.3091,0.6319,0.6994,0.3638,0.9340,0.5718,0.6813,0.7295,0.7619,0.2880,0.9272,0.2518,0.6729,0.7489,0.7631,0.9398,0.8220,0.5431,0.6150,0.6968,0.9055,0.9961,0.7410,0.4793,0.8351,0.5556,0.5081,0.1886,0.5124,0.2017,0.6624,0.9073 convnext_xxlarge,laion2b_s34b_b82k_augreg,0.6479,0.7907,0.9429,0.9816,0.8677,0.1399,0.1195,0.3127,0.7096,0.6030,0.4250,0.9295,0.5454,0.6806,0.7223,0.6692,0.4025,0.9131,0.2616,0.8687,0.7235,0.8091,0.9455,0.8116,0.5340,0.6782,0.7100,0.9399,0.9824,0.7436,0.6379,0.8531,0.5834,0.4536,0.1616,0.5719,0.0000,0.6729,0.9228 +ViT-B-16-SigLIP-512,webli,0.6459,0.7914,0.9516,0.9265,0.7146,0.2411,0.2226,0.1927,0.6793,0.4007,0.4521,0.9394,0.5171,0.6990,0.7283,0.6769,0.3615,0.9264,0.3924,0.8288,0.6764,0.8677,0.9499,0.7139,0.6615,0.5722,0.6538,0.9249,0.9853,0.7152,0.5444,0.8578,0.5963,0.5696,0.1925,0.6606,0.1411,0.6928,0.9244 +ViT-H-14-CLIPA-336,laion2b,0.6439,0.7910,0.9438,0.9826,0.8643,0.1835,0.2158,0.3111,0.7160,0.6393,0.3437,0.9303,0.5007,0.6994,0.7241,0.7213,0.3655,0.9269,0.1561,0.6365,0.7022,0.8009,0.9444,0.7723,0.5787,0.6178,0.7029,0.9476,0.9894,0.7567,0.6255,0.8522,0.5883,0.4878,0.1853,0.5001,0.1666,0.6706,0.9257 ViT-g-14,laion2b_s34b_b88k,0.6427,0.7847,0.9452,0.9815,0.8465,0.3768,0.1870,0.3091,0.6856,0.6530,0.4441,0.9241,0.4964,0.6754,0.7158,0.6092,0.3705,0.9020,0.2700,0.7191,0.6908,0.8010,0.9379,0.8166,0.5384,0.5678,0.6960,0.9394,0.9893,0.7411,0.5611,0.8456,0.5758,0.4104,0.1524,0.4771,0.2090,0.6671,0.9090 ViT-H-14,laion2b_s32b_b79k,0.6419,0.7796,0.9421,0.9745,0.8473,0.2676,0.2358,0.2986,0.6782,0.7278,0.4265,0.9273,0.5832,0.6657,0.7090,0.5935,0.3825,0.8934,0.1097,0.7284,0.6941,0.7982,0.9438,0.7768,0.5430,0.6392,0.6995,0.9338,0.9848,0.7521,0.5252,0.8417,0.5770,0.4247,0.1528,0.5638,0.2264,0.6343,0.9086 convnext_large_d_320,laion2b_s29b_b131k_ft_soup,0.6387,0.7685,0.9348,0.9659,0.8304,0.4293,0.2010,0.2654,0.6830,0.7161,0.3621,0.9162,0.5822,0.6504,0.6944,0.6044,0.4410,0.8862,0.1027,0.7434,0.6898,0.7755,0.9358,0.8129,0.4814,0.5585,0.7078,0.9369,0.9856,0.7376,0.6712,0.8467,0.5665,0.4549,0.1786,0.4088,0.1901,0.6449,0.9094 +ViT-B-16-SigLIP-384,webli,0.6379,0.7849,0.9507,0.9276,0.7147,0.2195,0.2239,0.1858,0.6718,0.4307,0.4522,0.9362,0.5196,0.6955,0.7211,0.6233,0.3640,0.9214,0.3333,0.8088,0.6342,0.8624,0.9515,0.7162,0.7010,0.5607,0.6579,0.9245,0.9863,0.7096,0.5285,0.8559,0.5882,0.5719,0.1719,0.5931,0.1365,0.6846,0.9194 ViT-L-14,commonpool_xl_laion_s13b_b90k,0.6360,0.7545,0.9352,0.9796,0.8585,0.3819,0.2489,0.2503,0.6191,0.7378,0.2869,0.9200,0.6018,0.6352,0.6851,0.5747,0.3730,0.8708,0.1378,0.7740,0.6846,0.7435,0.9308,0.8107,0.5069,0.5986,0.7065,0.8912,0.9903,0.7327,0.5730,0.8130,0.5513,0.4966,0.1421,0.5671,0.2337,0.6600,0.9115 EVA01-g-14,laion400m_s11b_b41k,0.6358,0.7852,0.9477,0.9829,0.8865,0.1966,0.2467,0.2862,0.6144,0.7237,0.3226,0.9345,0.4913,0.6730,0.7152,0.7359,0.3285,0.9250,0.2405,0.6218,0.7200,0.7427,0.9414,0.8325,0.4987,0.5832,0.6976,0.9171,0.9889,0.7416,0.5889,0.8037,0.5293,0.4640,0.1975,0.4999,0.1859,0.6741,0.8969 convnext_large_d_320,laion2b_s29b_b131k_ft,0.6345,0.7660,0.9341,0.9647,0.8313,0.3688,0.1999,0.2673,0.6846,0.7131,0.3770,0.9160,0.5688,0.6472,0.6929,0.5933,0.4400,0.8823,0.1027,0.7695,0.6813,0.7696,0.9346,0.8002,0.4576,0.5623,0.6989,0.9348,0.9854,0.7355,0.6496,0.8415,0.5599,0.4558,0.1664,0.4342,0.1782,0.6355,0.9090 -coca_ViT-L-14,laion2b_s13b_b90k,0.6305,0.7564,0.9433,0.9717,0.8318,0.3565,0.2365,0.2546,0.6271,0.6850,0.3622,0.9045,0.5572,0.6459,0.6794,0.5345,0.3540,0.8819,0.1899,0.7567,0.6414,0.7628,0.9400,0.8112,0.5278,0.6661,0.6883,0.9282,0.9905,0.7394,0.6205,0.8155,0.5431,0.4701,0.1348,0.4125,0.1917,0.6495,0.8969 +coca_ViT-L-14,laion2b_s13b_b90k,0.6327,0.7561,0.9430,0.9722,0.8318,0.3781,0.2446,0.2551,0.6239,0.6752,0.3590,0.9038,0.5624,0.6453,0.6798,0.5336,0.3540,0.8812,0.1899,0.7790,0.6405,0.7643,0.9402,0.8096,0.5500,0.6634,0.6878,0.9276,0.9894,0.7406,0.6237,0.8134,0.5428,0.4739,0.1375,0.4268,0.1932,0.6542,0.8960 ViT-g-14,laion2b_s12b_b42k,0.6299,0.7663,0.9415,0.9706,0.8392,0.3317,0.2225,0.2878,0.6824,0.6469,0.3768,0.9155,0.4985,0.6516,0.6956,0.5716,0.3785,0.8869,0.1350,0.6840,0.6761,0.7800,0.9431,0.8108,0.5624,0.6425,0.7176,0.9292,0.9865,0.7541,0.3930,0.8366,0.5647,0.4427,0.1486,0.4948,0.2040,0.6542,0.9132 convnext_large_d,laion2b_s26b_b102k_augreg,0.6294,0.7591,0.9365,0.9655,0.8309,0.3461,0.1997,0.2525,0.6739,0.6959,0.3610,0.9055,0.5299,0.6430,0.6826,0.5352,0.4425,0.8767,0.1027,0.8063,0.6618,0.7667,0.9282,0.7891,0.5309,0.5612,0.6768,0.9316,0.9829,0.7307,0.6812,0.8384,0.5550,0.4646,0.1549,0.3964,0.1793,0.6402,0.9019 ViT-L-14-336,openai,0.6284,0.7656,0.9225,0.9493,0.7436,0.2003,0.1895,0.3445,0.5559,0.6144,0.3346,0.9386,0.5239,0.6100,0.7089,0.7748,0.3265,0.8905,0.2616,0.7916,0.7183,0.7852,0.9369,0.7815,0.6073,0.7057,0.6379,0.7932,0.9943,0.6865,0.5560,0.7730,0.4751,0.4145,0.1490,0.6456,0.2325,0.6390,0.9015 ViT-L-14-quickgelu,metaclip_400m,0.6252,0.7620,0.9464,0.9544,0.7727,0.2271,0.2514,0.3085,0.6245,0.6033,0.3983,0.9073,0.4755,0.6505,0.6977,0.6640,0.2895,0.8889,0.2419,0.6186,0.6923,0.7648,0.9381,0.7440,0.7039,0.6551,0.6848,0.8477,0.9928,0.7073,0.3239,0.7981,0.5191,0.5175,0.1408,0.6916,0.1874,0.6741,0.8931 +ViT-B-16-SigLIP,webli,0.6232,0.7604,0.9518,0.9234,0.7223,0.2373,0.2409,0.1594,0.6468,0.4428,0.4377,0.9162,0.5164,0.6792,0.6893,0.4541,0.3815,0.9030,0.4093,0.8354,0.5510,0.8549,0.9420,0.7212,0.5953,0.5244,0.6454,0.9081,0.9821,0.7001,0.5586,0.8189,0.5676,0.5738,0.1309,0.6045,0.1265,0.6589,0.9106 +ViT-B-16-SigLIP-256,webli,0.6226,0.7653,0.9496,0.9334,0.7327,0.2276,0.2340,0.1581,0.6574,0.4606,0.4473,0.9200,0.4940,0.6810,0.6920,0.4877,0.3785,0.9076,0.3685,0.8457,0.5723,0.8521,0.9424,0.7254,0.5657,0.5739,0.6440,0.9106,0.9818,0.7026,0.5399,0.8272,0.5724,0.5715,0.1493,0.4966,0.1253,0.6589,0.9061 ViT-L-14,commonpool_xl_s13b_b90k,0.6207,0.7229,0.9327,0.9801,0.8410,0.1985,0.2461,0.2962,0.6202,0.6889,0.1957,0.9107,0.5467,0.6118,0.6511,0.5625,0.2855,0.8594,0.3390,0.9084,0.7022,0.6966,0.9060,0.8076,0.5248,0.5953,0.5756,0.8939,0.9890,0.7103,0.6589,0.7339,0.4652,0.5072,0.1229,0.5246,0.1948,0.6811,0.8990 ViT-L-14,laion2b_s32b_b82k,0.6205,0.7525,0.9388,0.9662,0.8332,0.3123,0.2234,0.2631,0.6293,0.6459,0.3652,0.9100,0.5618,0.6328,0.6780,0.5385,0.3870,0.8742,0.2293,0.5410,0.6529,0.7479,0.9309,0.8053,0.5641,0.5925,0.6687,0.9263,0.9885,0.7434,0.4087,0.8251,0.5493,0.4385,0.1257,0.5972,0.2007,0.6402,0.8919 ViT-L-14,openai,0.6173,0.7554,0.9249,0.9559,0.7582,0.1943,0.2021,0.3187,0.5537,0.6263,0.3181,0.9305,0.5055,0.5959,0.6983,0.7075,0.3235,0.8784,0.2180,0.7634,0.6889,0.7923,0.9323,0.7828,0.5204,0.6881,0.6337,0.7788,0.9936,0.6756,0.5840,0.7508,0.4642,0.4136,0.1211,0.6741,0.2229,0.6297,0.8839 +coca_ViT-L-14,mscoco_finetuned_laion2b_s13b_b90k,0.6159,0.7204,0.9420,0.9630,0.7965,0.3765,0.2501,0.1800,0.6213,0.5867,0.2329,0.8436,0.5453,0.6114,0.6475,0.4548,0.3865,0.8574,0.3797,0.8292,0.6253,0.7074,0.9115,0.8106,0.4943,0.6107,0.6267,0.8865,0.9861,0.7398,0.5564,0.8373,0.6028,0.5146,0.1303,0.4294,0.1678,0.6636,0.8772 ViT-B-16,datacomp_xl_s13b_b90k,0.6147,0.7349,0.9380,0.9624,0.8212,0.3267,0.2461,0.2215,0.5793,0.5883,0.2970,0.9047,0.5523,0.6044,0.6598,0.4840,0.4285,0.8362,0.2883,0.7649,0.6350,0.7701,0.9254,0.8178,0.6002,0.5162,0.6535,0.8883,0.9811,0.7051,0.6272,0.7633,0.4880,0.4832,0.1181,0.4799,0.1504,0.6168,0.8990 -coca_ViT-L-14,mscoco_finetuned_laion2b_s13b_b90k,0.6138,0.7210,0.9459,0.9626,0.7966,0.3649,0.2488,0.1810,0.6218,0.5904,0.2344,0.8449,0.5532,0.6116,0.6486,0.4568,0.3905,0.8579,0.3502,0.8220,0.6257,0.7078,0.9104,0.8127,0.4687,0.6134,0.6232,0.8875,0.9864,0.7377,0.5317,0.8373,0.6038,0.5178,0.1309,0.4097,0.1682,0.6729,0.8768 ViT-B-32-256,datacomp_s34b_b86k,0.6087,0.7281,0.9348,0.9653,0.8287,0.2489,0.2271,0.1968,0.6064,0.6469,0.3645,0.8909,0.5152,0.6065,0.6481,0.3757,0.4635,0.8344,0.2658,0.7939,0.5960,0.7822,0.9115,0.7880,0.5880,0.5294,0.6505,0.8990,0.9731,0.7021,0.6708,0.7486,0.4892,0.4300,0.0910,0.6252,0.0000,0.6238,0.8923 +ViT-B-16-SigLIP-i18n-256,webli,0.6068,0.7513,0.9475,0.9118,0.7216,0.2552,0.1976,0.1593,0.6426,0.3826,0.3325,0.9171,0.5276,0.6588,0.6814,0.4585,0.3685,0.8920,0.3826,0.8301,0.5976,0.8387,0.9387,0.7536,0.5381,0.5700,0.5737,0.8926,0.9764,0.6978,0.4272,0.8088,0.5470,0.5710,0.1451,0.4899,0.1064,0.6472,0.9186 RN50x64,openai,0.6061,0.7391,0.9026,0.8510,0.5985,0.2254,0.1994,0.2981,0.5314,0.5765,0.3103,0.9205,0.4792,0.5593,0.6706,0.7077,0.3830,0.8441,0.3094,0.8583,0.6820,0.7745,0.9360,0.7398,0.5387,0.7106,0.6265,0.7581,0.9829,0.6661,0.6044,0.7794,0.4683,0.3936,0.1469,0.5280,0.1939,0.6472,0.8898 ViT-B-16-quickgelu,metaclip_fullcc,0.6041,0.7212,0.9328,0.9572,0.7891,0.2935,0.2260,0.2271,0.6223,0.5265,0.3059,0.8882,0.4659,0.6016,0.6505,0.4953,0.4150,0.8423,0.1871,0.6610,0.6138,0.7358,0.9175,0.7818,0.5915,0.5898,0.6744,0.8302,0.9841,0.6879,0.3909,0.7811,0.5035,0.5221,0.1227,0.6993,0.1932,0.6402,0.8868 ViT-L-14,laion400m_e32,0.5971,0.7277,0.9266,0.9464,0.7741,0.2421,0.2452,0.2302,0.6053,0.6233,0.2490,0.9007,0.4989,0.5964,0.6545,0.4647,0.4190,0.8467,0.1997,0.7612,0.5969,0.7306,0.9170,0.7561,0.4968,0.5601,0.6741,0.8962,0.9808,0.7258,0.4955,0.7891,0.5137,0.3932,0.1254,0.4555,0.1708,0.6168,0.8839 @@ -53,7 +68,7 @@ ViT-B-16,laion400m_e32,0.5621,0.6705,0.9131,0.9172,0.7116,0.2869,0.2451,0.1810,0 ViT-B-16,laion400m_e31,0.5617,0.6698,0.9159,0.9169,0.7130,0.2889,0.2451,0.1804,0.5138,0.5033,0.1742,0.8587,0.4353,0.5233,0.5943,0.3327,0.5035,0.7777,0.1997,0.6531,0.5128,0.6693,0.8911,0.7678,0.5925,0.5459,0.5849,0.8365,0.9703,0.6958,0.3388,0.7451,0.4674,0.4225,0.1056,0.5976,0.1546,0.5946,0.8534 ViT-B-32-quickgelu,metaclip_fullcc,0.5577,0.6766,0.9290,0.9518,0.7767,0.1871,0.2307,0.1764,0.5883,0.4991,0.2705,0.8309,0.3922,0.5599,0.5957,0.2993,0.4825,0.7805,0.1871,0.4272,0.5286,0.6935,0.9087,0.7652,0.5596,0.5310,0.6124,0.7738,0.9630,0.6689,0.3447,0.7295,0.4662,0.5238,0.0915,0.5656,0.1588,0.6051,0.8610 convnext_base,laion400m_s13b_b51k,0.5576,0.6627,0.9151,0.8899,0.6462,0.2386,0.2209,0.1700,0.5404,0.4850,0.1556,0.8515,0.4551,0.5196,0.5859,0.3092,0.4925,0.7575,0.2925,0.6114,0.5058,0.6900,0.8853,0.7528,0.6116,0.5376,0.5683,0.8409,0.9656,0.6845,0.4038,0.7438,0.4615,0.4045,0.1095,0.6565,0.1589,0.5537,0.8530 -coca_ViT-B-32,laion2b_s13b_b90k,0.5547,0.6359,0.9115,0.9389,0.7396,0.1889,0.2057,0.1444,0.5388,0.4615,0.1882,0.7901,0.4474,0.5139,0.5569,0.2160,0.4995,0.7352,0.2686,0.7148,0.4518,0.6296,0.8875,0.7805,0.5974,0.5772,0.6010,0.8414,0.9634,0.6751,0.5519,0.7297,0.4560,0.4588,0.0943,0.5609,0.1088,0.5736,0.8447 +coca_ViT-B-32,laion2b_s13b_b90k,0.5533,0.6331,0.9078,0.9387,0.7378,0.1831,0.2175,0.1450,0.5367,0.4602,0.1783,0.7893,0.4532,0.5121,0.5522,0.2149,0.4920,0.7376,0.2644,0.7097,0.4470,0.6226,0.8875,0.7832,0.5938,0.5766,0.5994,0.8397,0.9626,0.6736,0.5503,0.7248,0.4537,0.4698,0.0876,0.5749,0.1010,0.5724,0.8430 ViT-B-32,laion2b_e16,0.5483,0.6565,0.9104,0.9403,0.7544,0.1923,0.2310,0.1652,0.5383,0.5030,0.2298,0.8166,0.3655,0.5287,0.5739,0.2615,0.5030,0.7588,0.1758,0.6347,0.4877,0.6732,0.8903,0.7877,0.5072,0.5437,0.6190,0.8437,0.9653,0.6851,0.4164,0.7539,0.4768,0.4602,0.0971,0.4648,0.0000,0.5724,0.8526 roberta-ViT-B-32,laion2b_s12b_b32k,0.5411,0.6171,0.9039,0.9325,0.7505,0.1472,0.2007,0.1472,0.5920,0.5215,0.1725,0.7812,0.4082,0.4912,0.5331,0.2120,0.5075,0.7224,0.3854,0.6636,0.4499,0.5893,0.8670,0.7804,0.4985,0.5420,0.6117,0.8315,0.9564,0.6627,0.4526,0.7302,0.4590,0.4583,0.0606,0.4098,0.1161,0.5549,0.8426 ViT-B-32-quickgelu,metaclip_400m,0.5387,0.6558,0.9171,0.9125,0.7006,0.2175,0.2448,0.1716,0.5255,0.5239,0.2680,0.8106,0.3576,0.5330,0.5760,0.2863,0.4680,0.7477,0.2588,0.4144,0.5046,0.6811,0.8877,0.7081,0.6426,0.5338,0.5954,0.7060,0.9543,0.6345,0.2056,0.7007,0.4386,0.5097,0.0819,0.6443,0.0000,0.5970,0.8539 @@ -66,12 +81,12 @@ ViT-B-32-quickgelu,openai,0.5245,0.6332,0.8758,0.8983,0.6423,0.2320,0.2335,0.172 RN50x4,openai,0.5188,0.6627,0.8661,0.7943,0.4514,0.2045,0.0905,0.2039,0.4862,0.3354,0.2102,0.8640,0.3622,0.4468,0.5944,0.4145,0.4955,0.7274,0.2335,0.4903,0.5141,0.6766,0.8829,0.6814,0.5675,0.6716,0.5338,0.6673,0.9658,0.6089,0.3190,0.7234,0.4318,0.3912,0.0870,0.5435,0.1130,0.5654,0.8376 ViT-B-32,laion400m_e31,0.5077,0.6022,0.8916,0.8825,0.6781,0.1549,0.2261,0.1356,0.5218,0.4694,0.1437,0.7814,0.4082,0.4648,0.5234,0.1957,0.5085,0.7079,0.1224,0.4108,0.4281,0.6319,0.8541,0.7312,0.5495,0.5162,0.5108,0.7436,0.9494,0.6508,0.2891,0.6890,0.4327,0.4262,0.0745,0.4975,0.1076,0.5491,0.8328 ViT-B-32,laion400m_e32,0.5074,0.6024,0.8918,0.8840,0.6773,0.1536,0.2261,0.1349,0.5229,0.4754,0.1467,0.7817,0.4070,0.4646,0.5237,0.1953,0.5080,0.7084,0.1181,0.4000,0.4292,0.6323,0.8513,0.7328,0.5490,0.5206,0.5094,0.7454,0.9498,0.6509,0.2759,0.6866,0.4337,0.4265,0.0741,0.5084,0.1068,0.5444,0.8326 -RN101-quickgelu,openai,0.5033,0.6228,0.8527,0.8078,0.4764,0.2437,0.0923,0.1693,0.4335,0.3131,0.1853,0.8367,0.3753,0.4106,0.5612,0.2944,0.5085,0.6817,0.2644,0.5254,0.4515,0.6532,0.8652,0.6512,0.5819,0.6403,0.5476,0.6100,0.9680,0.5803,0.3185,0.6852,0.4025,0.4130,0.0888,0.4723,0.1615,0.5631,0.8164 RN101,openai,0.5033,0.6228,0.8527,0.8078,0.4764,0.2437,0.0923,0.1693,0.4335,0.3131,0.1853,0.8367,0.3753,0.4106,0.5612,0.2944,0.5085,0.6817,0.2644,0.5254,0.4515,0.6532,0.8652,0.6512,0.5819,0.6403,0.5476,0.6100,0.9680,0.5803,0.3185,0.6852,0.4025,0.4130,0.0888,0.4723,0.1615,0.5631,0.8164 +RN101-quickgelu,openai,0.5033,0.6228,0.8527,0.8078,0.4764,0.2437,0.0923,0.1693,0.4335,0.3131,0.1853,0.8367,0.3753,0.4106,0.5612,0.2944,0.5085,0.6817,0.2644,0.5254,0.4515,0.6532,0.8652,0.6512,0.5819,0.6403,0.5476,0.6100,0.9680,0.5803,0.3185,0.6852,0.4025,0.4130,0.0888,0.4723,0.1615,0.5631,0.8164 ViT-B-16,commonpool_l_laion_s1b_b8k,0.5011,0.5526,0.8766,0.9296,0.7184,0.2681,0.2173,0.1119,0.4144,0.4115,0.0714,0.7661,0.3296,0.4315,0.4790,0.2004,0.4930,0.6501,0.3432,0.4753,0.4638,0.5023,0.7769,0.7686,0.5158,0.5228,0.5314,0.6760,0.9409,0.6278,0.4301,0.6447,0.3924,0.4476,0.0490,0.5127,0.1026,0.5514,0.8463 ViT-B-16,commonpool_l_image_s1b_b8k,0.4812,0.5719,0.8856,0.9321,0.6955,0.2143,0.2453,0.1308,0.4170,0.3193,0.0735,0.7797,0.2514,0.4343,0.4872,0.2143,0.4725,0.6356,0.3826,0.2219,0.4793,0.4817,0.7784,0.7841,0.5002,0.4986,0.4622,0.6627,0.9489,0.6335,0.2673,0.6026,0.3622,0.4787,0.0424,0.5000,0.0000,0.5946,0.8422 -RN50-quickgelu,openai,0.4810,0.5982,0.8329,0.7157,0.4030,0.2171,0.1623,0.1542,0.4154,0.4081,0.1703,0.8080,0.3510,0.3544,0.5284,0.2327,0.5720,0.6073,0.1730,0.5755,0.4141,0.6522,0.8529,0.6510,0.6393,0.5645,0.4521,0.5453,0.9419,0.5994,0.2883,0.6868,0.3869,0.3622,0.0623,0.5624,0.0000,0.5222,0.8129 RN50,openai,0.4810,0.5982,0.8329,0.7157,0.4030,0.2171,0.1623,0.1542,0.4154,0.4081,0.1703,0.8080,0.3510,0.3544,0.5284,0.2327,0.5720,0.6073,0.1730,0.5755,0.4141,0.6522,0.8529,0.6510,0.6393,0.5645,0.4521,0.5453,0.9419,0.5994,0.2883,0.6868,0.3869,0.3622,0.0623,0.5624,0.0000,0.5222,0.8129 +RN50-quickgelu,openai,0.4810,0.5982,0.8329,0.7157,0.4030,0.2171,0.1623,0.1542,0.4154,0.4081,0.1703,0.8080,0.3510,0.3544,0.5284,0.2327,0.5720,0.6073,0.1730,0.5755,0.4141,0.6522,0.8529,0.6510,0.6393,0.5645,0.4521,0.5453,0.9419,0.5994,0.2883,0.6868,0.3869,0.3622,0.0623,0.5624,0.0000,0.5222,0.8129 ViT-B-16,commonpool_l_text_s1b_b8k,0.4760,0.5605,0.8720,0.9391,0.7054,0.1843,0.2373,0.0995,0.3941,0.3830,0.0451,0.7724,0.2317,0.4437,0.4835,0.2220,0.4770,0.6708,0.2686,0.2593,0.4911,0.5164,0.7049,0.7669,0.4857,0.4931,0.4663,0.6525,0.9523,0.6088,0.2122,0.6078,0.3730,0.4570,0.0623,0.5697,0.0000,0.5643,0.8564 ViT-B-16,commonpool_l_basic_s1b_b8k,0.4585,0.5155,0.8444,0.8289,0.5251,0.2061,0.2277,0.1173,0.4133,0.3820,0.0481,0.7461,0.2021,0.3932,0.4325,0.1913,0.4600,0.6087,0.3333,0.2809,0.4493,0.4357,0.6956,0.7151,0.5899,0.5387,0.4313,0.7216,0.9373,0.5974,0.1173,0.6015,0.3583,0.4812,0.0436,0.5712,0.0000,0.5421,0.8384 ViT-B-16,commonpool_l_s1b_b8k,0.4370,0.4593,0.8089,0.9133,0.6421,0.1594,0.2203,0.1177,0.3383,0.3348,0.0316,0.6735,0.2766,0.3448,0.3914,0.1592,0.4335,0.5265,0.2686,0.3603,0.4126,0.3681,0.5587,0.7093,0.5516,0.5118,0.4154,0.6060,0.9339,0.5713,0.3047,0.4948,0.2855,0.4777,0.0399,0.5102,0.0000,0.5654,0.8305 @@ -92,9 +107,9 @@ RN50-quickgelu,yfcc15m,0.2776,0.3275,0.5089,0.4919,0.2033,0.1305,0.1990,0.0637,0 ViT-B-32,commonpool_m_s128m_b4k,0.2580,0.1755,0.5231,0.7459,0.4391,0.1263,0.2265,0.0362,0.1606,0.2537,0.0115,0.2342,0.0869,0.0952,0.1440,0.0388,0.2780,0.1983,0.2743,0.0933,0.1574,0.1128,0.1676,0.5448,0.5048,0.5003,0.1810,0.1332,0.7690,0.3066,0.0933,0.1599,0.0974,0.3983,0.0127,0.5015,0.0000,0.4276,0.5942 ViT-B-32,commonpool_s_clip_s13m_b4k,0.1731,0.0505,0.2483,0.4768,0.1937,0.1529,0.2313,0.0119,0.0782,0.2067,0.0083,0.0801,0.0732,0.0200,0.0380,0.0181,0.1380,0.0655,0.2785,0.0874,0.0506,0.0539,0.0796,0.3379,0.6367,0.5014,0.0806,0.0276,0.5353,0.1126,0.1166,0.0343,0.0224,0.2994,0.0004,0.6874,0.0000,0.2605,0.2827 ViT-B-32,commonpool_s_text_s13m_b4k,0.1573,0.0460,0.2231,0.4679,0.1844,0.1350,0.1899,0.0121,0.0670,0.0896,0.0139,0.0618,0.0411,0.0175,0.0398,0.0187,0.1270,0.0606,0.3980,0.0771,0.0494,0.0428,0.0581,0.2942,0.5027,0.5008,0.1029,0.0204,0.5019,0.1051,0.0933,0.0424,0.0214,0.3120,0.0015,0.5000,0.0000,0.2745,0.2843 -ViT-B-32,commonpool_s_image_s13m_b4k,0.1449,0.0392,0.2238,0.3176,0.1329,0.1121,0.2217,0.0109,0.0521,0.1593,0.0120,0.0604,0.0579,0.0186,0.0308,0.0155,0.1055,0.0578,0.2883,0.0991,0.0436,0.0528,0.0474,0.2666,0.5273,0.4646,0.0794,0.0173,0.4601,0.0725,0.1305,0.0171,0.0130,0.2525,0.0033,0.5425,0.0085,0.2150,0.2752 ViT-B-32,datacomp_s_s13m_b4k,0.1449,0.0392,0.2238,0.3176,0.1329,0.1121,0.2217,0.0109,0.0521,0.1593,0.0120,0.0604,0.0579,0.0186,0.0308,0.0155,0.1055,0.0578,0.2883,0.0991,0.0436,0.0528,0.0474,0.2666,0.5273,0.4646,0.0794,0.0173,0.4601,0.0725,0.1305,0.0171,0.0130,0.2525,0.0033,0.5425,0.0085,0.2150,0.2752 +ViT-B-32,commonpool_s_image_s13m_b4k,0.1449,0.0392,0.2238,0.3176,0.1329,0.1121,0.2217,0.0109,0.0521,0.1593,0.0120,0.0604,0.0579,0.0186,0.0308,0.0155,0.1055,0.0578,0.2883,0.0991,0.0436,0.0528,0.0474,0.2666,0.5273,0.4646,0.0794,0.0173,0.4601,0.0725,0.1305,0.0171,0.0130,0.2525,0.0033,0.5425,0.0085,0.2150,0.2752 ViT-B-32,commonpool_s_basic_s13m_b4k,0.1423,0.0377,0.1806,0.2664,0.1154,0.1245,0.2335,0.0120,0.0553,0.0587,0.0103,0.0588,0.0638,0.0151,0.0319,0.0203,0.0985,0.0499,0.3390,0.1085,0.0440,0.0351,0.0488,0.3081,0.5096,0.4986,0.0795,0.0200,0.4659,0.0879,0.0810,0.0328,0.0168,0.3033,0.0003,0.5001,0.0000,0.2325,0.2643 ViT-B-32,commonpool_s_s13m_b4k,0.1420,0.0270,0.1564,0.4079,0.1296,0.1305,0.2233,0.0126,0.0574,0.1487,0.0081,0.0473,0.0654,0.0108,0.0234,0.0141,0.1000,0.0404,0.3460,0.0708,0.0360,0.0338,0.0443,0.2235,0.5268,0.5008,0.0698,0.0143,0.4266,0.0766,0.1121,0.0257,0.0132,0.3126,0.0002,0.5124,0.0000,0.2290,0.2167 ViT-B-32,commonpool_s_laion_s13m_b4k,0.1332,0.0305,0.1549,0.3364,0.1347,0.1309,0.1299,0.0098,0.0553,0.1578,0.0134,0.0501,0.0538,0.0125,0.0271,0.0147,0.1015,0.0443,0.2518,0.1387,0.0369,0.0244,0.0399,0.3030,0.4216,0.4992,0.0583,0.0155,0.4874,0.0659,0.1473,0.0223,0.0121,0.2410,0.0017,0.3703,0.0000,0.2079,0.2580 -coca_ViT-B-32,mscoco_finetuned_laion2b_s13b_b90k,0.1064,0.0091,0.0441,0.2002,0.0173,0.1315,0.2019,0.0047,0.0452,0.0844,0.0139,0.0177,0.0298,0.0034,0.0086,0.0091,0.0230,0.0158,0.2714,0.1442,0.0159,0.0131,0.0438,0.1247,0.5183,0.4992,0.0589,0.0058,0.2913,0.0211,0.1519,0.0104,0.0061,0.2375,0.0003,0.5140,0.0000,0.1729,0.0814 +coca_ViT-B-32,mscoco_finetuned_laion2b_s13b_b90k,0.1108,0.0079,0.0320,0.2564,0.0193,0.1245,0.2027,0.0044,0.0303,0.1157,0.0064,0.0159,0.0146,0.0028,0.0067,0.0121,0.0220,0.0199,0.3010,0.1506,0.0144,0.0054,0.0416,0.2023,0.5713,0.4992,0.0478,0.0056,0.2579,0.0204,0.1529,0.0092,0.0060,0.2329,0.0004,0.5681,0.0000,0.1729,0.0589 diff --git a/scripts/clipav1_vit_l16_i37_t8.sh b/scripts/clipav1_vit_l16_i37_t8.sh new file mode 100644 index 000000000..d3ff0901e --- /dev/null +++ b/scripts/clipav1_vit_l16_i37_t8.sh @@ -0,0 +1,6 @@ +# eval on a single gpu +CUDA_VISIBLE_DEVICES=2 TORCH_CUDNN_V8_API_ENABLED=1 TFDS_PREFETCH_SIZE=8192 python3 -m training.main \ + --model ViT-L-16-CL32-GAP \ + --pretrained "/path/to/clipa_vit_l16_i37_t8.pt" \ + --seed 0 \ + --imagenet-val '/path/to/ImageNet/val' \ No newline at end of file diff --git a/scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh b/scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh new file mode 100644 index 000000000..7f22386c3 --- /dev/null +++ b/scripts/clipav2_vit_h14_i84_224_336_cl32_gap_datacomp1b.sh @@ -0,0 +1,10 @@ +CUDA_VISIBLE_DEVICES=1 python3 -m training.main \ + --model ViT-H-14-CL32-GAP-BigVision \ + --pretrained "/path/to/vit_h14_i84_224_336_cl32_gap_datacomp1b.pt" \ + --force-image-size 336 \ + --square-resize-only \ + --interpolation 'bilinear' \ + --image-mean 0.485 0.456 0.406 \ + --image-std 0.229 0.224 0.225 \ + --seed 0 \ + --imagenet-val '/path/to/ImageNet/val' diff --git a/src/open_clip/__init__.py b/src/open_clip/__init__.py index fdb1199b8..23856a3f1 100644 --- a/src/open_clip/__init__.py +++ b/src/open_clip/__init__.py @@ -4,7 +4,8 @@ from .factory import list_models, add_model_config, get_model_config, load_checkpoint from .loss import ClipLoss, DistillClipLoss, CoCaLoss from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ - convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype + convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \ + get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg from .openai import load_openai_model, list_openai_models from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained diff --git a/src/open_clip/big_vision.py b/src/open_clip/big_vision.py new file mode 100644 index 000000000..0d7eaf3fa --- /dev/null +++ b/src/open_clip/big_vision.py @@ -0,0 +1,136 @@ +import torch +import numpy as np + +from .model import CustomTextCLIP +from .transformer import TextTransformer, Transformer + + +@torch.no_grad() +def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): + """ Load weights from .npz checkpoints for official Google big_vision image-text models + + Currently the SigLIP source models are supported and a CustomTextCLIP destination model + w/ timm image encoder. + """ + from timm.layers import resample_patch_embed, resample_abs_pos_embed + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + interpolation = 'bilinear' + antialias = False + + def _convert_timm_img(module, prefix): + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]: + embed_conv_w = resample_patch_embed( + embed_conv_w, + module.patch_embed.proj.weight.shape[-2:], + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + module.patch_embed.proj.weight.copy_(embed_conv_w) + module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + + if module.cls_token is not None: + module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + + pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) + if pos_embed_w.shape != module.pos_embed.shape: + assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}' + num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1) + pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, + new_size=module.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + module.pos_embed.copy_(pos_embed_w) + + mha_sub, b_sub, ln1_sub = (0, 0, 1) + for i, block in enumerate(module.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) + + module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + + if module.attn_pool is not None: + block_prefix = f'{prefix}MAPHead_0/' + mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' + module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) + module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) + module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) + module.attn_pool.kv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) + module.attn_pool.kv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) + module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + for r in range(2): + getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) + getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) + + def _convert_openclip_transformer(module: Transformer, prefix): + for i, block in enumerate(module.resblocks.children()): + block_prefix = f'{prefix}encoderblock_{i}/' + mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' + block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.in_proj_weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.in_proj_bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'])) + block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'])) + block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'])) + block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'])) + block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'])) + block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'])) + + def _convert_openclip_txt(module: TextTransformer, prefix): + module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0) + module.positional_embedding.copy_(pos_embed_w) + _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/') + module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale'])) + module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias'])) + module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + + _convert_timm_img(model.visual.trunk, 'params/img/') + _convert_openclip_txt(model.text, 'params/txt/') + model.logit_bias.copy_(_n2p(w['params/b'])[0]) + model.logit_scale.copy_(_n2p(w['params/t'])[0]) + + diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index ad81fb665..485de1bd5 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -123,35 +123,46 @@ def __init__( self.pad_id = pad_id @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True): self.visual.set_grad_checkpointing(enable) self.text.set_grad_checkpointing(enable) self.text_decoder.set_grad_checkpointing(enable) - def _encode_image(self, images, normalize=True): + def _encode_image(self, images, normalize: bool = True): image_latent, tokens_embs = self.visual(images) image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent return image_latent, tokens_embs - def _encode_text(self, text, normalize=True, embed_cls=True): + def _encode_text(self, text, normalize: bool = True, embed_cls: bool = True): text = text[:, :-1] if embed_cls else text # make space for CLS token text_latent, token_emb = self.text(text) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent return text_latent, token_emb - def encode_image(self, images, normalize=True): + def encode_image(self, images, normalize: bool = True): image_latent, _ = self._encode_image(images, normalize=normalize) return image_latent - def encode_text(self, text, normalize=True, embed_cls=True): + def encode_text(self, text, normalize: bool = True, embed_cls: bool = True): text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) return text_latent - def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None): - text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) + def forward( + self, + image, + text: Optional[torch.Tensor] = None, + embed_cls: bool = True, + image_latent: Optional[torch.Tensor] = None, + image_embs: Optional[torch.Tensor] = None, + ): if image_latent is None or image_embs is None: image_latent, image_embs = self._encode_image(image) + if text is None: + return {"image_features": image_latent, "image_embs": image_embs} + + text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) + # TODO: add assertion to avoid bugs? labels = text[:, -token_embs.shape[1]:] diff --git a/src/open_clip/constants.py b/src/open_clip/constants.py index a670bb3fa..599c48c03 100644 --- a/src/open_clip/constants.py +++ b/src/open_clip/constants.py @@ -1,2 +1,6 @@ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) +INCEPTION_MEAN = (0.5, 0.5, 0.5) +INCEPTION_STD = (0.5, 0.5, 0.5) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 7268522e4..ef94b51f8 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -1,25 +1,24 @@ import json import logging import os -import pathlib import re from copy import deepcopy +from dataclasses import asdict from pathlib import Path from typing import Any, Dict, Optional, Tuple, Union -from functools import partial import torch from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ - resize_pos_embed, get_cast_dtype, resize_text_pos_embed + resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg from .coca_model import CoCa from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss from .openai import load_openai_model from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\ list_pretrained_tags_by_model, download_pretrained_from_hf -from .transform import image_transform, AugmentationCfg -from .tokenizer import HFTokenizer, tokenize, syntax_mask_tokenize, random_mask_tokenize, block_mask_tokenize +from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs +from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH HF_HUB_PREFIX = 'hf-hub:' @@ -75,24 +74,54 @@ def get_model_config(model_name): return None -def get_tokenizer(model_name): +def _get_hf_config(model_id, cache_dir=None): + config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + return config + + +def get_tokenizer( + model_name: str = '', + context_length: Optional[int] = None, + **kwargs, +): if model_name.startswith(HF_HUB_PREFIX): - tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):]) + model_name = model_name[len(HF_HUB_PREFIX):] + try: + config = _get_hf_config(model_name)['model_cfg'] + except Exception: + tokenizer = HFTokenizer( + model_name, + context_length=context_length or DEFAULT_CONTEXT_LENGTH, + **kwargs, + ) + return tokenizer else: config = get_model_config(model_name) - if 'hf_tokenizer_name' in config['text_cfg']: - tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) - elif 'text_mask' in config['text_cfg'] and config['text_cfg']['text_mask'] == 'syntax': - tokenizer = syntax_mask_tokenize - elif 'text_mask' in config['text_cfg'] and config['text_cfg']['text_mask'] == 'random': - tokenizer = random_mask_tokenize - elif 'text_mask' in config['text_cfg'] and config['text_cfg']['text_mask'] == 'block': - tokenizer = block_mask_tokenize - else: - tokenizer = tokenize - if 'context_length' in config['text_cfg'].keys(): - context_length = config['text_cfg']['context_length'] - tokenizer = partial(tokenizer, context_length=context_length) + assert config is not None, f"No valid model config found for {model_name}." + + text_config = config.get('text_cfg', {}) + if 'tokenizer_kwargs' in text_config: + tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs) + else: + tokenizer_kwargs = kwargs + + if context_length is None: + context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH) + + if 'hf_tokenizer_name' in text_config: + tokenizer = HFTokenizer( + text_config['hf_tokenizer_name'], + context_length=context_length, + **tokenizer_kwargs, + ) + else: + tokenizer = SimpleTokenizer( + context_length=context_length, + **tokenizer_kwargs, + ) + return tokenizer @@ -112,6 +141,11 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'): def load_checkpoint(model, checkpoint_path, strict=True): + if Path(checkpoint_path).suffix in ('.npz', '.npy'): + from .big_vision import load_big_vision_weights + load_big_vision_weights(model, checkpoint_path) + return {} + state_dict = load_state_dict(checkpoint_path) # detect old format and make compatible with new format if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): @@ -136,6 +170,7 @@ def create_model( force_custom_text: bool = False, force_patch_dropout: Optional[float] = None, force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + force_preprocess_cfg: Optional[Dict[str, Any]] = None, pretrained_image: bool = False, pretrained_hf: bool = True, cache_dir: Optional[str] = None, @@ -143,20 +178,19 @@ def create_model( require_pretrained: bool = False, **model_kwargs, ): + force_preprocess_cfg = force_preprocess_cfg or {} + preprocess_cfg = asdict(PreprocessCfg()) has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) if has_hf_hub_prefix: model_id = model_name[len(HF_HUB_PREFIX):] checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) - config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) - - with open(config_path, 'r', encoding='utf-8') as f: - config = json.load(f) - pretrained_cfg = config['preprocess_cfg'] + config = _get_hf_config(model_id, cache_dir) + preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg']) model_cfg = config['model_cfg'] + pretrained_hf = False # override, no need to load original HF text weights else: model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names checkpoint_path = None - pretrained_cfg = {} model_cfg = None if isinstance(device, str): @@ -201,11 +235,12 @@ def create_model( # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes cast_dtype = get_cast_dtype(precision) is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) + if is_hf_model: + # load pretrained weights for HF text model IFF no CLIP weights being loaded + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model if custom_text: - if is_hf_model: - model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf if "multimodal_cfg" in model_cfg: model = CoCa(**model_cfg, **model_kwargs, cast_dtype=cast_dtype) else: @@ -222,6 +257,7 @@ def create_model( # Why? The convert_weights_to_lp fn only works with native models. model.to(device=device, dtype=dtype) from .transformer import LayerNormFp32 + def _convert_ln(m): if isinstance(m, LayerNormFp32): m.weight.data = m.weight.data.to(torch.float32) @@ -242,6 +278,7 @@ def _convert_ln(m): pretrained_cfg = get_pretrained_cfg(model_name, pretrained) if pretrained_cfg: checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg) elif os.path.exists(pretrained): checkpoint_path = pretrained @@ -256,7 +293,7 @@ def _convert_ln(m): raise RuntimeError(error_str) pretrained_loaded = True elif has_hf_hub_prefix: - logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).') load_checkpoint(model, checkpoint_path) pretrained_loaded = True @@ -265,16 +302,18 @@ def _convert_ln(m): raise RuntimeError( f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') - # set image / mean metadata from pretrained_cfg if available, or use default - model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN - model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD - if output_dict and hasattr(model, "output_dict"): model.output_dict = True if jit: model = torch.jit.script(model) + # set image preprocessing configuration in model attributes for convenience + if getattr(model.visual, 'image_size', None) is not None: + # use image_size set on model creation (via config or force_image_size arg) + force_preprocess_cfg['size'] = model.visual.image_size + set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg)) + return model @@ -325,15 +364,20 @@ def create_model_and_transforms( force_custom_text: bool = False, force_patch_dropout: Optional[float] = None, force_image_size: Optional[Union[int, Tuple[int, int]]] = None, - pretrained_image: bool = False, - pretrained_hf: bool = True, image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None, + image_interpolation: Optional[str] = None, + image_resize_mode: Optional[str] = None, # only effective for inference aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, cache_dir: Optional[str] = None, output_dict: Optional[bool] = None, **model_kwargs, ): + force_preprocess_cfg = merge_preprocess_kwargs( + {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode) + model = create_model( model_name, pretrained, @@ -344,6 +388,7 @@ def create_model_and_transforms( force_custom_text=force_custom_text, force_patch_dropout=force_patch_dropout, force_image_size=force_image_size, + force_preprocess_cfg=force_preprocess_cfg, pretrained_image=pretrained_image, pretrained_hf=pretrained_hf, cache_dir=cache_dir, @@ -351,20 +396,16 @@ def create_model_and_transforms( **model_kwargs, ) - image_mean = image_mean or getattr(model.visual, 'image_mean', None) - image_std = image_std or getattr(model.visual, 'image_std', None) - preprocess_train = image_transform( - model.visual.image_size, + pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg) + + preprocess_train = image_transform_v2( + pp_cfg, is_train=True, - mean=image_mean, - std=image_std, aug_cfg=aug_cfg, ) - preprocess_val = image_transform( - model.visual.image_size, + preprocess_val = image_transform_v2( + pp_cfg, is_train=False, - mean=image_mean, - std=image_std, ) return model, preprocess_train, preprocess_val @@ -379,12 +420,17 @@ def create_model_from_pretrained( force_quick_gelu: bool = False, force_custom_text: bool = False, force_image_size: Optional[Union[int, Tuple[int, int]]] = None, - return_transform: bool = True, image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None, + image_interpolation: Optional[str] = None, + image_resize_mode: Optional[str] = None, # only effective for inference + return_transform: bool = True, cache_dir: Optional[str] = None, **model_kwargs, ): + force_preprocess_cfg = merge_preprocess_kwargs( + {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode) + model = create_model( model_name, pretrained, @@ -394,6 +440,7 @@ def create_model_from_pretrained( force_quick_gelu=force_quick_gelu, force_custom_text=force_custom_text, force_image_size=force_image_size, + force_preprocess_cfg=force_preprocess_cfg, cache_dir=cache_dir, require_pretrained=True, **model_kwargs, @@ -402,13 +449,9 @@ def create_model_from_pretrained( if not return_transform: return model - image_mean = image_mean or getattr(model.visual, 'image_mean', None) - image_std = image_std or getattr(model.visual, 'image_std', None) - preprocess = image_transform( - model.visual.image_size, + preprocess = image_transform_v2( + PreprocessCfg(**model.visual.preprocess_cfg), is_train=False, - mean=image_mean, - std=image_std, ) return model, preprocess diff --git a/src/open_clip/hf_model.py b/src/open_clip/hf_model.py index 08dbdbcde..281a06cc5 100644 --- a/src/open_clip/hf_model.py +++ b/src/open_clip/hf_model.py @@ -103,7 +103,7 @@ def __init__( output_dim: int, config: PretrainedConfig = None, pooler_type: str = None, - proj: str = None, + proj_type: str = None, pretrained: bool = True, output_tokens: bool = False, ): @@ -139,11 +139,11 @@ def __init__( self.pooler = _POOLERS[pooler_type]() d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"]) - if (d_model == output_dim) and (proj is None): # do we always need a proj? + if (d_model == output_dim) and (proj_type is None): # do we always need a proj? self.proj = nn.Identity() - elif proj == 'linear': + elif proj_type == 'linear': self.proj = nn.Linear(d_model, output_dim, bias=False) - elif proj == 'mlp': + elif proj_type == 'mlp': hidden_size = (d_model + output_dim) // 2 self.proj = nn.Sequential( nn.Linear(d_model, hidden_size, bias=False), diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 0ccf01bca..0310ee560 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -2,21 +2,24 @@ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. """ -from dataclasses import dataclass +import copy import logging import math -from typing import Optional, Tuple, Union +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from torch import nn from torch.utils.checkpoint import checkpoint +from functools import partial from .hf_model import HFTextEncoder from .modified_resnet import ModifiedResNet from .timm_model import TimmModel -from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\ + text_global_pool from .utils import to_2tuple @@ -31,14 +34,18 @@ class CLIPVisionCfg: ls_init_value: Optional[float] = None # layer scale initial value patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results - input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design - global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) - attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer - n_queries: int = 256 # n_queries for attentional pooler + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type) + attn_pooler_queries: int = 256 # n_queries for attentional pooler attn_pooler_heads: int = 8 # n heads for attentional_pooling + no_ln_pre: bool = False # disable pre transformer LayerNorm + pos_embed_type: str = 'learnable' + final_ln_after_pool: bool = False # apply final LayerNorm after pooling + pool_type: str = 'tok' output_tokens: bool = False + act_kwargs: Optional[dict] = None + norm_kwargs: Optional[dict] = None - timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') @@ -51,19 +58,29 @@ class CLIPVisionCfg: class CLIPTextCfg: context_length: int = 77 vocab_size: int = 49408 + hf_tokenizer_name: Optional[str] = None + tokenizer_kwargs: Optional[dict] = None + width: int = 512 heads: int = 8 layers: int = 12 + mlp_ratio: float = 4.0 ls_init_value: Optional[float] = None # layer scale initial value - hf_model_name: str = None - hf_tokenizer_name: str = None - hf_model_pretrained: bool = True - proj: str = 'mlp' - pooler_type: str = 'mean_pooler' embed_cls: bool = False pad_id: int = 0 + no_causal_mask: bool = False # disable causal masking + final_ln_after_pool: bool = False # apply final LayerNorm after pooling + pool_type: str = 'argmax' + proj_bias: bool = False output_tokens: bool = False - text_mask: str = 'first' # default first truncate in bpe_tokenizer + act_kwargs: dict = None + norm_kwargs: dict = None + + # HuggingFace specific text tower config + hf_model_name: Optional[str] = None + hf_model_pretrained: bool = True + hf_proj_type: str = 'mlp' + hf_pooler_type: str = 'mean_pooler' # attentional pooling for HF models def get_cast_dtype(precision: str): @@ -123,6 +140,11 @@ def _build_vision_tower( else: vision_heads = vision_cfg.width // vision_cfg.head_width norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + if vision_cfg.norm_kwargs: + norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs) + if vision_cfg.act_kwargs is not None: + act_layer = partial(act_layer, **vision_cfg.act_kwargs) + visual = VisionTransformer( image_size=vision_cfg.image_size, patch_size=vision_cfg.patch_size, @@ -132,11 +154,13 @@ def _build_vision_tower( mlp_ratio=vision_cfg.mlp_ratio, ls_init_value=vision_cfg.ls_init_value, patch_dropout=vision_cfg.patch_dropout, - input_patchnorm=vision_cfg.input_patchnorm, - global_average_pool=vision_cfg.global_average_pool, attentional_pool=vision_cfg.attentional_pool, - n_queries=vision_cfg.n_queries, + attn_pooler_queries=vision_cfg.attn_pooler_queries, attn_pooler_heads=vision_cfg.attn_pooler_heads, + pos_embed_type=vision_cfg.pos_embed_type, + no_ln_pre=vision_cfg.no_ln_pre, + final_ln_after_pool=vision_cfg.final_ln_after_pool, + pool_type=vision_cfg.pool_type, output_tokens=vision_cfg.output_tokens, output_dim=embed_dim, act_layer=act_layer, @@ -159,14 +183,18 @@ def _build_text_tower( text = HFTextEncoder( text_cfg.hf_model_name, output_dim=embed_dim, - proj=text_cfg.proj, - pooler_type=text_cfg.pooler_type, + proj_type=text_cfg.hf_proj_type, + pooler_type=text_cfg.hf_pooler_type, pretrained=text_cfg.hf_model_pretrained, output_tokens=text_cfg.output_tokens, ) else: act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + if text_cfg.norm_kwargs: + norm_layer = partial(norm_layer, **text_cfg.norm_kwargs) + if text_cfg.act_kwargs is not None: + act_layer = partial(act_layer, **text_cfg.act_kwargs) text = TextTransformer( context_length=text_cfg.context_length, @@ -174,11 +202,15 @@ def _build_text_tower( width=text_cfg.width, heads=text_cfg.heads, layers=text_cfg.layers, + mlp_ratio=text_cfg.mlp_ratio, ls_init_value=text_cfg.ls_init_value, output_dim=embed_dim, embed_cls=text_cfg.embed_cls, - output_tokens=text_cfg.output_tokens, + no_causal_mask=text_cfg.no_causal_mask, pad_id=text_cfg.pad_id, + pool_type=text_cfg.pool_type, + proj_bias=text_cfg.proj_bias, + output_tokens=text_cfg.output_tokens, act_layer=act_layer, norm_layer=norm_layer, ) @@ -201,6 +233,7 @@ def __init__( ): super().__init__() self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) @@ -211,6 +244,7 @@ def __init__( self.positional_embedding = text.positional_embedding self.ln_final = text.ln_final self.text_projection = text.text_projection + self.text_pool_type = text.pool_type self.register_buffer('attn_mask', text.attn_mask, persistent=False) self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) @@ -242,8 +276,13 @@ def encode_text(self, text, normalize: bool = False): x = self.transformer(x, attn_mask=self.attn_mask) x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + x, _ = text_global_pool(x, text, self.text_pool_type) + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + x = self.text_projection(x) + else: + x = x @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x def forward( @@ -528,4 +567,40 @@ def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', anti old_pos_embed = old_pos_embed.permute(0, 2, 1)[0] new_pos_embed = old_pos_embed - state_dict['positional_embedding'] = new_pos_embed \ No newline at end of file + state_dict['positional_embedding'] = new_pos_embed + + +def get_model_preprocess_cfg(model): + module = getattr(model, 'visual', model) + preprocess_cfg = getattr(module, 'preprocess_cfg', {}) + if not preprocess_cfg: + # use separate legacy attributes if preprocess_cfg dict not found + size = getattr(module, 'image_size') + if size is not None: + preprocess_cfg['size'] = size + mean = getattr(module, 'image_mean', None) + if mean is not None: + preprocess_cfg['mean'] = getattr(module, 'mean') + std = getattr(module, 'image_std', None) + if std is not None: + preprocess_cfg['std'] = getattr(module, 'std') + return preprocess_cfg + + +def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]): + module = getattr(model, 'visual', model) + module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat + module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat + module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict + + +def get_model_tokenize_cfg(model): + module = getattr(model, 'text', model) + cfg = {} + context_length = getattr(module, 'context_length', None) + if context_length is not None: + cfg['context_length'] = context_length + vocab_size = getattr(module, 'vocab_size', None) + if vocab_size is not None: + cfg['vocab_size'] = vocab_size + return cfg diff --git a/src/open_clip/model_configs/ViT-B-16-CL16.json b/src/open_clip/model_configs/ViT-B-16-CL16.json deleted file mode 100644 index 829f8c40a..000000000 --- a/src/open_clip/model_configs/ViT-B-16-CL16.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 768, - "patch_size": 16 - }, - "text_cfg": { - "context_length": 16, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-B-16-SigLIP-256.json b/src/open_clip/model_configs/ViT-B-16-SigLIP-256.json new file mode 100644 index 000000000..d7ad3acba --- /dev/null +++ b/src/open_clip/model_configs/ViT-B-16-SigLIP-256.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 256, + "timm_model_name": "vit_base_patch16_siglip_256", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 768, + "heads": 12, + "layers": 12, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-B-16-SigLIP-384.json b/src/open_clip/model_configs/ViT-B-16-SigLIP-384.json new file mode 100644 index 000000000..df9a25cdc --- /dev/null +++ b/src/open_clip/model_configs/ViT-B-16-SigLIP-384.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 384, + "timm_model_name": "vit_base_patch16_siglip_384", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 768, + "heads": 12, + "layers": 12, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-B-16-SigLIP-512.json b/src/open_clip/model_configs/ViT-B-16-SigLIP-512.json new file mode 100644 index 000000000..88b018528 --- /dev/null +++ b/src/open_clip/model_configs/ViT-B-16-SigLIP-512.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 512, + "timm_model_name": "vit_base_patch16_siglip_512", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 768, + "heads": 12, + "layers": 12, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json b/src/open_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json new file mode 100644 index 000000000..7a28797a7 --- /dev/null +++ b/src/open_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 256, + "timm_model_name": "vit_base_patch16_siglip_256", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 250000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 768, + "heads": 12, + "layers": 12, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-B-16-SigLIP.json b/src/open_clip/model_configs/ViT-B-16-SigLIP.json new file mode 100644 index 000000000..a9f2b654a --- /dev/null +++ b/src/open_clip/model_configs/ViT-B-16-SigLIP.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "vit_base_patch16_siglip_224", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 768, + "heads": 12, + "layers": 12, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-H-14-CL32-GAP.json b/src/open_clip/model_configs/ViT-H-14-CL32-GAP.json deleted file mode 100644 index 26f91cf50..000000000 --- a/src/open_clip/model_configs/ViT-H-14-CL32-GAP.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "embed_dim": 1024, - "vision_cfg": { - "image_size": 224, - "layers": 32, - "width": 1280, - "head_width": 80, - "patch_size": 14, - "global_average_pool": true - }, - "text_cfg": { - "context_length": 32, - "vocab_size": 49408, - "width": 1024, - "heads": 16, - "layers": 24 - } -} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-H-14-CL8-SyntaxMask-GAP.json b/src/open_clip/model_configs/ViT-H-14-CL8-SyntaxMask-GAP.json deleted file mode 100644 index 7e28b6173..000000000 --- a/src/open_clip/model_configs/ViT-H-14-CL8-SyntaxMask-GAP.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "embed_dim": 1024, - "vision_cfg": { - "image_size": 224, - "layers": 32, - "width": 1280, - "head_width": 80, - "patch_size": 14, - "global_average_pool": true - }, - "text_cfg": { - "context_length": 8, - "vocab_size": 49408, - "width": 1024, - "heads": 16, - "layers": 24, - "text_mask": "syntax" - } -} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-H-14-CLIPA-336.json b/src/open_clip/model_configs/ViT-H-14-CLIPA-336.json new file mode 100644 index 000000000..01fabb29d --- /dev/null +++ b/src/open_clip/model_configs/ViT-H-14-CLIPA-336.json @@ -0,0 +1,26 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 336, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14, + "no_ln_pre": true, + "pool_type": "avg", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 32000, + "hf_tokenizer_name": "bert-base-uncased", + "tokenizer_kwargs": { + "strip_sep_token": true + }, + "width": 1024, + "heads": 16, + "layers": 24, + "pool_type": "last", + "no_causal_mask": true + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-H-14-CLIPA.json b/src/open_clip/model_configs/ViT-H-14-CLIPA.json new file mode 100644 index 000000000..7df033884 --- /dev/null +++ b/src/open_clip/model_configs/ViT-H-14-CLIPA.json @@ -0,0 +1,26 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14, + "no_ln_pre": true, + "pool_type": "avg", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 32000, + "hf_tokenizer_name": "bert-base-uncased", + "tokenizer_kwargs": { + "strip_sep_token": true + }, + "width": 1024, + "heads": 16, + "layers": 24, + "pool_type": "last", + "no_causal_mask": true + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-L-14-CLIPA-336.json b/src/open_clip/model_configs/ViT-L-14-CLIPA-336.json new file mode 100644 index 000000000..60a4df589 --- /dev/null +++ b/src/open_clip/model_configs/ViT-L-14-CLIPA-336.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "patch_size": 14, + "no_ln_pre": true, + "pool_type": "avg", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 32000, + "hf_tokenizer_name": "bert-base-uncased", + "tokenizer_kwargs": { + "strip_sep_token": true + }, + "width": 768, + "heads": 12, + "layers": 12, + "pool_type": "last", + "no_causal_mask": true + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-L-14-CLIPA.json b/src/open_clip/model_configs/ViT-L-14-CLIPA.json new file mode 100644 index 000000000..b4dde7b54 --- /dev/null +++ b/src/open_clip/model_configs/ViT-L-14-CLIPA.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14, + "no_ln_pre": true, + "pool_type": "avg", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 32000, + "hf_tokenizer_name": "bert-base-uncased", + "tokenizer_kwargs": { + "strip_sep_token": true + }, + "width": 768, + "heads": 12, + "layers": 12, + "pool_type": "last", + "no_causal_mask": true + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-L-16-CL16-GAP.json b/src/open_clip/model_configs/ViT-L-16-CL16-GAP.json deleted file mode 100644 index a4262daf6..000000000 --- a/src/open_clip/model_configs/ViT-L-16-CL16-GAP.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 224, - "layers": 24, - "width": 1024, - "patch_size": 16, - "global_average_pool": true - }, - "text_cfg": { - "context_length": 16, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-L-16-CL8-Syntax-GAP.json b/src/open_clip/model_configs/ViT-L-16-CL8-Syntax-GAP.json deleted file mode 100644 index 3569fdbe8..000000000 --- a/src/open_clip/model_configs/ViT-L-16-CL8-Syntax-GAP.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 224, - "layers": 24, - "width": 1024, - "patch_size": 16, - "global_average_pool": true - }, - "text_cfg": { - "context_length": 8, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12, - "text_mask": "syntax" - } -} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-L-16-SigLIP-256.json b/src/open_clip/model_configs/ViT-L-16-SigLIP-256.json new file mode 100644 index 000000000..5ba8f7abb --- /dev/null +++ b/src/open_clip/model_configs/ViT-L-16-SigLIP-256.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 1024, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 256, + "timm_model_name": "vit_large_patch16_siglip_256", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 1024, + "heads": 16, + "layers": 24, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-L-16-SigLIP-384.json b/src/open_clip/model_configs/ViT-L-16-SigLIP-384.json new file mode 100644 index 000000000..fd2cc2e34 --- /dev/null +++ b/src/open_clip/model_configs/ViT-L-16-SigLIP-384.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 1024, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 384, + "timm_model_name": "vit_large_patch16_siglip_384", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 1024, + "heads": 16, + "layers": 24, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-384.json b/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-384.json new file mode 100644 index 000000000..4c527f581 --- /dev/null +++ b/src/open_clip/model_configs/ViT-SO400M-14-SigLIP-384.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 1152, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 384, + "timm_model_name": "vit_so400m_patch14_siglip_384", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 64, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 1152, + "heads": 16, + "layers": 27, + "mlp_ratio": 3.7362, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-SO400M-14-SigLIP.json b/src/open_clip/model_configs/ViT-SO400M-14-SigLIP.json new file mode 100644 index 000000000..564eb78a4 --- /dev/null +++ b/src/open_clip/model_configs/ViT-SO400M-14-SigLIP.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 1152, + "init_logit_bias": -10, + "custom_text": true, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "vit_so400m_patch14_siglip_224", + "timm_model_pretrained": false, + "timm_pool": "map", + "timm_proj": "none" + }, + "text_cfg": { + "context_length": 16, + "vocab_size": 32000, + "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", + "tokenizer_kwargs": { + "clean": "canonicalize" + }, + "width": 1152, + "heads": 16, + "layers": 27, + "mlp_ratio": 3.7362, + "no_causal_mask": true, + "proj_bias": true, + "pool_type": "last", + "norm_kwargs":{ + "eps": 1e-6 + } + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViT-bigG-14-CLIPA-336.json b/src/open_clip/model_configs/ViT-bigG-14-CLIPA-336.json new file mode 100644 index 000000000..75ba7675c --- /dev/null +++ b/src/open_clip/model_configs/ViT-bigG-14-CLIPA-336.json @@ -0,0 +1,27 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 336, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14, + "no_ln_pre": true, + "pool_type": "avg", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 32, + "vocab_size": 32000, + "hf_tokenizer_name": "bert-base-uncased", + "tokenizer_kwargs": { + "strip_sep_token": true + }, + "width": 1280, + "heads": 20, + "layers": 32, + "pool_type": "last", + "no_causal_mask": true + } +} \ No newline at end of file diff --git a/src/open_clip/model_configs/coca_roberta-ViT-B-32.json b/src/open_clip/model_configs/coca_roberta-ViT-B-32.json index fb46354b9..aa9d3f562 100644 --- a/src/open_clip/model_configs/coca_roberta-ViT-B-32.json +++ b/src/open_clip/model_configs/coca_roberta-ViT-B-32.json @@ -10,7 +10,7 @@ "text_cfg": { "hf_model_name": "roberta-base", "hf_tokenizer_name": "roberta-base", - "proj": "linear", + "hf_proj_type": "linear", "width": 768, "output_tokens": true }, diff --git a/src/open_clip/model_configs/mt5-base-ViT-B-32.json b/src/open_clip/model_configs/mt5-base-ViT-B-32.json index 58cad89cf..e22366897 100644 --- a/src/open_clip/model_configs/mt5-base-ViT-B-32.json +++ b/src/open_clip/model_configs/mt5-base-ViT-B-32.json @@ -9,7 +9,6 @@ "text_cfg": { "hf_model_name": "google/mt5-base", "hf_tokenizer_name": "google/mt5-base", - "proj": "mlp", - "pooler_type": "mean_pooler" + "hf_pooler_type": "mean_pooler" } } diff --git a/src/open_clip/model_configs/mt5-xl-ViT-H-14.json b/src/open_clip/model_configs/mt5-xl-ViT-H-14.json index b43281077..f58717cdd 100644 --- a/src/open_clip/model_configs/mt5-xl-ViT-H-14.json +++ b/src/open_clip/model_configs/mt5-xl-ViT-H-14.json @@ -10,7 +10,6 @@ "text_cfg": { "hf_model_name": "google/mt5-xl", "hf_tokenizer_name": "google/mt5-xl", - "proj": "mlp", - "pooler_type": "mean_pooler" + "hf_pooler_type": "mean_pooler" } } diff --git a/src/open_clip/model_configs/nllb-clip-base.json b/src/open_clip/model_configs/nllb-clip-base.json index 8b85d0df5..57265b33f 100644 --- a/src/open_clip/model_configs/nllb-clip-base.json +++ b/src/open_clip/model_configs/nllb-clip-base.json @@ -9,7 +9,7 @@ "text_cfg": { "hf_model_name": "facebook/nllb-200-distilled-600M", "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", - "proj": "linear", - "pooler_type": "cls_pooler" + "hf_proj_type": "linear", + "hf_pooler_type": "cls_pooler" } } \ No newline at end of file diff --git a/src/open_clip/model_configs/nllb-clip-large.json b/src/open_clip/model_configs/nllb-clip-large.json index 4e5bc14a8..72d04a733 100644 --- a/src/open_clip/model_configs/nllb-clip-large.json +++ b/src/open_clip/model_configs/nllb-clip-large.json @@ -10,7 +10,7 @@ "text_cfg": { "hf_model_name": "facebook/nllb-200-distilled-1.3B", "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B", - "proj": "linear", - "pooler_type": "cls_pooler" + "hf_proj_type": "linear", + "hf_pooler_type": "cls_pooler" } } \ No newline at end of file diff --git a/src/open_clip/model_configs/roberta-ViT-B-32.json b/src/open_clip/model_configs/roberta-ViT-B-32.json index ed687d472..c0c7a5599 100644 --- a/src/open_clip/model_configs/roberta-ViT-B-32.json +++ b/src/open_clip/model_configs/roberta-ViT-B-32.json @@ -10,7 +10,6 @@ "text_cfg": { "hf_model_name": "roberta-base", "hf_tokenizer_name": "roberta-base", - "proj": "mlp", - "pooler_type": "mean_pooler" + "hf_pooler_type": "mean_pooler" } } diff --git a/src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json b/src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json index 751bccc2c..375fa9e12 100644 --- a/src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json +++ b/src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json @@ -9,7 +9,6 @@ "text_cfg": { "hf_model_name": "xlm-roberta-base", "hf_tokenizer_name": "xlm-roberta-base", - "proj": "mlp", - "pooler_type": "mean_pooler" + "hf_pooler_type": "mean_pooler" } } diff --git a/src/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json b/src/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json index 31f271faa..c56b4e898 100644 --- a/src/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json +++ b/src/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json @@ -10,7 +10,6 @@ "text_cfg": { "hf_model_name": "xlm-roberta-large", "hf_tokenizer_name": "xlm-roberta-large", - "proj": "mlp", - "pooler_type": "mean_pooler" + "hf_pooler_type": "mean_pooler" } } diff --git a/src/open_clip/pos_embed.py b/src/open_clip/pos_embed.py new file mode 100644 index 000000000..5c8082b34 --- /dev/null +++ b/src/open_clip/pos_embed.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np + +import torch + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed diff --git a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py index 59961f986..2454f5797 100644 --- a/src/open_clip/pretrained.py +++ b/src/open_clip/pretrained.py @@ -7,6 +7,8 @@ from tqdm import tqdm +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, INCEPTION_MEAN, INCEPTION_STD, \ + IMAGENET_MEAN, IMAGENET_STD from .version import __version__ try: @@ -18,13 +20,43 @@ _has_hf_hub = False -def _pcfg(url='', hf_hub='', mean=None, std=None): - return dict( - url=url, - hf_hub=hf_hub, - mean=mean, - std=std, - ) +def _pcfg(url='', hf_hub='', **kwargs): + # OpenAI / OpenCLIP defaults + return { + 'url': url, + 'hf_hub': hf_hub, + 'mean': OPENAI_DATASET_MEAN, + 'std': OPENAI_DATASET_STD, + 'interpolation': 'bicubic', + 'resize_mode': 'shortest', + **kwargs, + } + + +def _slpcfg(url='', hf_hub='', **kwargs): + # SiGLIP defaults + return { + 'url': url, + 'hf_hub': hf_hub, + 'mean': INCEPTION_MEAN, + 'std': INCEPTION_STD, + 'interpolation': 'bicubic', + 'resize_mode': 'squash', + **kwargs, + } + + +def _apcfg(url='', hf_hub='', **kwargs): + # CLIPA defaults + return { + 'url': url, + 'hf_hub': hf_hub, + 'mean': IMAGENET_MEAN, + 'std': IMAGENET_STD, + 'interpolation': 'bilinear', + 'resize_mode': 'squash', + **kwargs, + } _RN50 = dict( @@ -164,7 +196,7 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), laion2b_s32b_b82k=_pcfg( hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + mean=INCEPTION_MEAN, std=INCEPTION_STD), # DataComp-XL models datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'), commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'), @@ -263,6 +295,7 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): "RN50x4": _RN50x4, "RN50x16": _RN50x16, "RN50x64": _RN50x64, + "ViT-B-32": _VITB32, "ViT-B-32-256": _VITB32_256, "ViT-B-32-quickgelu": _VITB32_quickgelu, @@ -276,17 +309,21 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): "ViT-H-14-quickgelu": _VITH14_quickgelu, "ViT-g-14": _VITg14, "ViT-bigG-14": _VITbigG14, + "roberta-ViT-B-32": _robertaViTB32, "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, + "convnext_base": _convnext_base, "convnext_base_w": _convnext_base_w, "convnext_base_w_320": _convnext_base_w_320, "convnext_large_d": _convnext_large_d, "convnext_large_d_320": _convnext_large_d_320, "convnext_xxlarge": _convnext_xxlarge, + "coca_ViT-B-32": _coca_VITB32, "coca_ViT-L-14": _coca_VITL14, + "EVA01-g-14": dict( # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'), @@ -315,6 +352,56 @@ def _pcfg(url='', hf_hub='', mean=None, std=None): # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'), ), + + "ViT-B-16-SigLIP": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'), + ), + "ViT-B-16-SigLIP-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'), + ), + "ViT-B-16-SigLIP-i18n-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'), + ), + "ViT-B-16-SigLIP-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'), + ), + "ViT-B-16-SigLIP-512": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'), + ), + "ViT-L-16-SigLIP-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'), + ), + "ViT-L-16-SigLIP-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'), + ), + "ViT-SO400M-14-SigLIP": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'), + ), + "ViT-SO400M-14-SigLIP-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), + ), + + # FIXME update CLIPA pretrained to final home, rwightman/ is temporary for testing + "ViT-L-14-CLIPA": dict( + datacomp1b=_apcfg(hf_hub='rwightman/ViT-L-14-CLIPA-datacomp1B/'), + ), + "ViT-L-14-CLIPA-336": dict( + datacomp1b=_apcfg(hf_hub='rwightman/ViT-L-14-CLIPA-336-datacomp1B/'), + ), + "ViT-H-14-CLIPA": dict( + datacomp1b=_apcfg(hf_hub='rwightman/ViT-H-14-CLIPA-datacomp1B/'), + ), + "ViT-H-14-CLIPA-336": dict( + laion2b=_apcfg(hf_hub='rwightman/ViT-H-14-CLIPA-336-laion2B/'), + datacomp1b=_apcfg(hf_hub='rwightman/ViT-H-14-CLIPA-336-datacomp1B/'), + ), + # "ViT-bigG-14-CLIPA": dict( + # datacomp1b=_apcfg(hf_hub='rwightman/ViT-bigG-14-CLIPA-datacomp1B/'), + # ), + "ViT-bigG-14-CLIPA-336": dict( + datacomp1b=_apcfg(hf_hub='rwightman/ViT-bigG-14-CLIPA-336-datacomp1B/'), + ), + "nllb-clip-base": dict( v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'), ), diff --git a/src/open_clip/push_to_hf_hub.py b/src/open_clip/push_to_hf_hub.py index 6e6271da1..dcb8a78b5 100644 --- a/src/open_clip/push_to_hf_hub.py +++ b/src/open_clip/push_to_hf_hub.py @@ -36,6 +36,7 @@ HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version HF_CONFIG_NAME = 'open_clip_config.json' + def save_config_for_hf( model, config_path: str, @@ -45,6 +46,11 @@ def save_config_for_hf( 'mean': model.visual.image_mean, 'std': model.visual.image_std, } + other_pp = getattr(model.visual, 'preprocess_cfg', {}) + if 'interpolation' in other_pp: + preprocess_cfg['interpolation'] = other_pp['interpolation'] + if 'resize_mode' in other_pp: + preprocess_cfg['resize_mode'] = other_pp['resize_mode'] hf_config = { 'model_cfg': model_config, 'preprocess_cfg': preprocess_cfg, @@ -59,7 +65,7 @@ def save_for_hf( tokenizer: HFTokenizer, model_config: dict, save_directory: str, - safe_serialization: Union[bool, str] = False, + safe_serialization: Union[bool, str] = 'both', skip_weights : bool = False, ): config_filename = HF_CONFIG_NAME @@ -95,6 +101,7 @@ def push_to_hf_hub( safe_serialization: Union[bool, str] = False, ): if not isinstance(tokenizer, HFTokenizer): + # FIXME this makes it awkward to push models with new tokenizers, come up with better soln. # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 tokenizer = HFTokenizer('openai/clip-vit-large-patch14') @@ -157,12 +164,15 @@ def push_pretrained_to_hf_hub( precision: str = 'fp32', image_mean: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None, + image_interpolation: Optional[str] = None, + image_resize_mode: Optional[str] = None, # only effective for inference commit_message: str = 'Add model', token: Optional[str] = None, revision: Optional[str] = None, private: bool = False, create_pr: bool = False, model_card: Optional[dict] = None, + hf_tokenizer_self: bool = False, ): model, preprocess_eval = create_model_from_pretrained( model_name, @@ -170,12 +180,16 @@ def push_pretrained_to_hf_hub( precision=precision, image_mean=image_mean, image_std=image_std, + image_interpolation=image_interpolation, + image_resize_mode=image_resize_mode, ) - model_config = get_model_config(model_name) assert model_config tokenizer = get_tokenizer(model_name) + if hf_tokenizer_self: + # make hf tokenizer config in the uploaded model point to self instead of original location + model_config['text']['hf_tokenizer_name'] = repo_id push_to_hf_hub( model=model, @@ -193,10 +207,15 @@ def push_pretrained_to_hf_hub( def generate_readme(model_card: dict, model_name: str): + tags = model_card.pop('tags', ('clip',)) + pipeline_tag = model_card.pop('pipeline_tag', 'zero-shot-image-classification') readme_text = "---\n" - readme_text += "tags:\n- clip\n" + if tags: + readme_text += "tags:\n" + for t in tags: + readme_text += f"- {t}\n" readme_text += "library_name: open_clip\n" - readme_text += "pipeline_tag: zero-shot-image-classification\n" + readme_text += f"pipeline_tag: {pipeline_tag}\n" readme_text += f"license: {model_card.get('license', 'mit')}\n" if 'details' in model_card and 'Dataset' in model_card['details']: readme_text += 'datasets:\n' @@ -262,6 +281,22 @@ def generate_readme(model_card: dict, model_name: str): parser.add_argument( '--image-std', type=float, nargs='+', default=None, metavar='STD', help='Override default image std deviation of of dataset') + parser.add_argument( + '--image-interpolation', + default=None, type=str, choices=['bicubic', 'bilinear', 'random'], + help="image resize interpolation" + ) + parser.add_argument( + '--image-resize-mode', + default=None, type=str, choices=['shortest', 'longest', 'squash'], + help="image resize mode during inference" + ) + parser.add_argument( + "--hf-tokenizer-self", + default=False, + action="store_true", + help="make hf_tokenizer_name point in uploaded config point to itself" + ) args = parser.parse_args() print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') @@ -275,6 +310,8 @@ def generate_readme(model_card: dict, model_name: str): precision=args.precision, image_mean=args.image_mean, # override image mean/std if trained w/ non defaults image_std=args.image_std, + image_interpolation=args.image_interpolation, + image_resize_mode=args.image_resize_mode, ) print(f'{args.model} saved.') diff --git a/src/open_clip/timm_model.py b/src/open_clip/timm_model.py index 3d3f595d6..5ddb9a76b 100644 --- a/src/open_clip/timm_model.py +++ b/src/open_clip/timm_model.py @@ -55,11 +55,16 @@ def __init__( timm_kwargs['patch_drop_rate'] = patch_drop custom_pool = pool in ('abs_attn', 'rot_attn') - if not proj and not custom_pool: + if proj: + assert proj in ("linear", "mlp", "none") + extra_proj = proj in ("linear", "mlp") + if not extra_proj and not custom_pool: # use network classifier head as projection if no proj specified and no custom pooling used + # if projection is explicitly set to "none" will be pass through from network trunk + proj_dim = 0 if proj == 'none' else embed_dim self.trunk = timm.create_model( model_name, - num_classes=embed_dim, + num_classes=proj_dim, global_pool=pool, pretrained=pretrained, **timm_kwargs, @@ -99,8 +104,6 @@ def __init__( head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) elif proj == 'mlp': head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) - else: - assert not proj, f'Unknown projection type {proj}.' self.head = nn.Sequential(head_layers) diff --git a/src/open_clip/tokenizer.py b/src/open_clip/tokenizer.py index 3e651aed5..985c0e030 100644 --- a/src/open_clip/tokenizer.py +++ b/src/open_clip/tokenizer.py @@ -5,25 +5,21 @@ import gzip import html import os -from functools import lru_cache -from typing import Union, List +import random +import string +from functools import lru_cache, partial +from typing import Callable, Optional, List, Union import ftfy +import numpy as np import regex as re import torch -import numpy as np # https://stackoverflow.com/q/62691279 -import os os.environ["TOKENIZERS_PARALLELISM"] = "false" +_nltk_init = False -try: - import nltk - # run them for the first time - nltk.download('punkt') - nltk.download('averaged_perceptron_tagger') -except: - nltk = None +DEFAULT_CONTEXT_LENGTH = 77 # default context length for OpenAI CLIP @lru_cache() @@ -78,8 +74,64 @@ def whitespace_clean(text): return text +def _clean_canonicalize(x): + # basic, remove whitespace, remove punctuation, lower case + return canonicalize_text(basic_clean(x)) + + +def _clean_lower(x): + # basic, remove whitespace, lower case + return whitespace_clean(basic_clean(x)).lower() + + +def _clean_whitespace(x): + # basic, remove whitespace + return whitespace_clean(basic_clean(x)) + + +def get_clean_fn(type: str): + if type == 'canonicalize': + return _clean_canonicalize + elif type == 'lower': + return _clean_lower + elif type == 'whitespace': + return _clean_whitespace + else: + assert False, f"Invalid clean function ({type})." + + +def canonicalize_text(text, *, keep_punctuation_exact_string=None): + """Returns canonicalized `text` (lowercase and punctuation removed). + + From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94 + + Args: + text: string to be canonicalized. + keep_punctuation_exact_string: If provided, then this exact string kept. + For example providing '{}' will keep any occurrences of '{}' (but will + still remove '{' and '}' that appear separately). + """ + text = text.replace("_", " ") + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans("", "", string.punctuation)) + for part in text.split(keep_punctuation_exact_string)) + else: + text = text.translate(str.maketrans("", "", string.punctuation)) + text = text.lower() + text = re.sub(r"\s+", " ", text) + return text.strip() + + class SimpleTokenizer(object): - def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + def __init__( + self, + bpe_path: str = default_bpe(), + additional_special_tokens: Optional[List[str]] = None, + context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, + clean: str = 'lower', + reduction_mask: str = '' + ): self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') @@ -89,20 +141,26 @@ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): vocab = vocab + [v+'' for v in vocab] for merge in merges: vocab.append(''.join(merge)) - if not special_tokens: - special_tokens = ['', ''] - else: - special_tokens = ['', ''] + special_tokens + special_tokens = ['', ''] + if additional_special_tokens: + special_tokens += additional_special_tokens vocab.extend(special_tokens) self.encoder = dict(zip(vocab, range(len(vocab)))) self.decoder = {v: k for k, v in self.encoder.items()} self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = {t:t for t in special_tokens} special = "|".join(special_tokens) - self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) - + self.pat = re.compile( + special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) self.vocab_size = len(self.encoder) self.all_special_ids = [self.encoder[t] for t in special_tokens] + self.sot_token_id = self.all_special_ids[0] + self.eot_token_id = self.all_special_ids[1] + self.context_length = context_length + self.clean_fn = get_clean_fn(clean) + self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None def bpe(self, token): if token in self.cache: @@ -147,7 +205,7 @@ def bpe(self, token): def encode(self, text): bpe_tokens = [] - text = whitespace_clean(basic_clean(text)).lower() + text = self.clean_fn(text) for token in re.findall(self.pat, text): token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) @@ -158,157 +216,128 @@ def decode(self, tokens): text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') return text + def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor: + """ Returns the tokenized representation of given input string(s) -_tokenizer = SimpleTokenizer() + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length -def decode(output_ids: torch.Tensor): - output_ids = output_ids.cpu().numpy() - return _tokenizer.decode(output_ids) + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] -def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: - """ - Returns the tokenized representation of given input string(s) - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - context_length : int - The context length to use; all CLIP models use 77 as the context length - - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] - """ - if isinstance(texts, str): - texts = [texts] + context_length = context_length or self.context_length + assert context_length, 'Please set a valid context length' - sot_token = _tokenizer.encoder[""] - eot_token = _tokenizer.encoder[""] - all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + if self.reduction_fn is not None: + # use reduction strategy for tokenize if set, otherwise default to truncation below + return self.reduction_fn( + texts, + context_length=context_length, + sot_token_id=self.sot_token_id, + eot_token_id=self.eot_token_id, + encode_fn=self.encode, + ) - for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length: - tokens = tokens[:context_length] # Truncate - tokens[-1] = eot_token - result[i, :len(tokens)] = torch.tensor(tokens) + all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - return result + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = self.eot_token_id + result[i, :len(tokens)] = torch.tensor(tokens) + return result -class HFTokenizer: - """HuggingFace tokenizer wrapper""" - def __init__(self, tokenizer_name: str): - from transformers import AutoTokenizer - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) +_tokenizer = SimpleTokenizer() - def save_pretrained(self, dest): - self.tokenizer.save_pretrained(dest) - def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: - # same cleaning as for default tokenizer, except lowercasing - # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance - if isinstance(texts, str): - texts = [texts] - texts = [whitespace_clean(basic_clean(text)) for text in texts] - input_ids = self.tokenizer( - texts, - return_tensors='pt', - max_length=context_length, - padding='max_length', - truncation=True, - ).input_ids - return input_ids +def decode(output_ids: torch.Tensor): + output_ids = output_ids.cpu().numpy() + return _tokenizer.decode(output_ids) -def random_mask_tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: - """ - Returns the tokenized representation of given input string(s) - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - context_length : int - The context length to use; all CLIP models use 77 as the context length - - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] - """ - if isinstance(texts, str): - texts = [texts] +def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor: + return _tokenizer(texts, context_length=context_length) + - sot_token = _tokenizer.encoder[""] - eot_token = _tokenizer.encoder[""] - all_tokens = [_tokenizer.encode(text) for text in texts] +def random_mask_tokenize( + texts: Union[str, List[str]], + context_length: int, + sot_token_id: int, + eot_token_id: int, + encode_fn: Callable, + shuffle: bool = False, +): + all_tokens = [encode_fn(text) for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length - 2: # 2 for sot and eot token - indices = np.random.permutation(len(tokens)).tolist() - indices = indices[:context_length - 2] + tokens = torch.tensor(tokens) + num_tokens = len(tokens) + if num_tokens > context_length - 2: # 2 for sot and eot token + num_keep = context_length - 2 + indices = torch.randperm(len(tokens)) + indices = indices[:num_keep] + if not shuffle: + indices = indices.msort() tokens = tokens[indices] - tokens = [sot_token,] + tokens + [eot_token,] - result[i, :len(tokens)] = torch.tensor(tokens) + num_tokens = num_keep + result[i, 0] = sot_token_id + result[i, 1:num_tokens + 1] = tokens + result[i, num_tokens + 1] = eot_token_id return result -def block_mask_tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: - """ - Returns the tokenized representation of given input string(s) - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - context_length : int - The context length to use; all CLIP models use 77 as the context length - - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] - """ - if isinstance(texts, str): - texts = [texts] - - sot_token = _tokenizer.encoder[""] - eot_token = _tokenizer.encoder[""] - all_tokens = [_tokenizer.encode(text) for text in texts] +def simple_mask_tokenize( + texts: Union[str, List[str]], + context_length: int, + sot_token_id: int, + eot_token_id: int, + encode_fn: Callable, +): + all_tokens = [encode_fn(text) for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length - 2: # 2 for sot and eot token - start_index = np.random.randint(len(tokens) - context_length + 3) - tokens = tokens[start_index : start_index + context_length - 2] - tokens = [sot_token,] + tokens + [eot_token,] + num_tokens = len(tokens) + if num_tokens > context_length - 2: # 2 for sot and eot token + num_keep = context_length - 2 + start_index = random.randint(0, num_tokens - num_keep) # high is incl + tokens = tokens[start_index: start_index + num_keep] + tokens = [sot_token_id] + tokens + [eot_token_id] result[i, :len(tokens)] = torch.tensor(tokens) return result -def syntax_mask_tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: - """ - Returns the tokenized representation of given input string(s). +def syntax_mask_tokenize( + texts: Union[str, List[str]], + context_length: int, + sot_token_id: int, + eot_token_id: int, + encode_fn: Callable, +) -> torch.LongTensor: + """ Returns the tokenized representation of given input string(s). Apply syntax masking before tokenize. - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - context_length : int - The context length to use; all CLIP models use 77 as the context length - - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] """ - assert nltk is not None - if isinstance(texts, str): - texts = [texts] + import nltk + global _nltk_init + if not _nltk_init: + # run them for the first time + nltk.download('punkt') + nltk.download('averaged_perceptron_tagger') + _nltk_init = True def get_order(x): if x.startswith('NN'): @@ -319,6 +348,7 @@ def get_order(x): return 3 else: return 4 + # syntax masking new_texts = [] for text in texts: @@ -328,8 +358,7 @@ def get_order(x): order_list = [get_order(tag) for _, tag in pos_tags] sorted_ids = np.argsort(np.array(order_list)) sampled_ids = sorted(sorted_ids[:context_length - 2]) # need 2 slots for sot and eot tokens - # sample the tokens and convert to tf.tensor - sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0) + sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0) # sample the tokens new_text = '' for token in sampled_tokens: @@ -338,16 +367,130 @@ def get_order(x): new_texts.append(new_text) texts = new_texts - sot_token = _tokenizer.encoder[""] - eot_token = _tokenizer.encoder[""] - all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + all_tokens = [[sot_token_id] + encode_fn(text) + [eot_token_id] for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): # still need first truncate because some words produces two tokens if len(tokens) > context_length: tokens = tokens[:context_length] # Truncate - tokens[-1] = eot_token + tokens[-1] = eot_token_id result[i, :len(tokens)] = torch.tensor(tokens) - return result \ No newline at end of file + return result + + +def get_reduction_mask_fn(type: str): + """ Choose strategy for dropping (masking) tokens to achieve target context length""" + assert type in ('simple', 'random', 'shuffle', 'syntax') + if type == 'simple': + return simple_mask_tokenize # randomly select block [start:end] + elif type == 'random': + return random_mask_tokenize # randomly drop tokens (keep order) + elif type == 'shuffle': + return partial(random_mask_tokenize, shuffle=True) # randomly drop tokens (shuffle order) + elif type == 'syntax': + return syntax_mask_tokenize # randomly drop prioritized by syntax + + +class HFTokenizer: + """HuggingFace tokenizer wrapper""" + + def __init__( + self, + tokenizer_name: str, + context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, + clean: str = 'whitespace', + strip_sep_token: bool = False, + ): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + self.context_length = context_length + self.clean_fn = get_clean_fn(clean) + self.strip_sep_token = strip_sep_token + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + + context_length = context_length or self.context_length + assert context_length, 'Please set a valid context length in class init or call.' + + texts = [self.clean_fn(text) for text in texts] + input_ids = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids + + if self.strip_sep_token: + input_ids = torch.where( + input_ids == self.tokenizer.sep_token_id, + torch.zeros_like(input_ids), + input_ids, + ) + + return input_ids + + +class SigLipTokenizer: + """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs + """ + VOCAB_FILES = { + # english, vocab_size=32_000 + "c4-en": "http://storage.googleapis.com/t5-data/vocabs/cc_en.32000/sentencepiece.model", + # used in multilingual models (mT5, PaLI), vocab_size=250_000 + "mc4": "http://storage.googleapis.com/t5-data/vocabs/mc4.250000.100extra/sentencepiece.model", + } + + def __init__( + self, + tokenizer_name: str, + context_length: Optional[int] = 64, + ): + from transformers import T5TokenizerFast + + if tokenizer_name in self.VOCAB_FILES: + # FIXME temporary hack? + import fsspec + import tempfile + vocab_file = self.VOCAB_FILES[tokenizer_name] + with tempfile.NamedTemporaryFile('wb') as dst: + with fsspec.open(vocab_file, 'rb') as src: + dst.write(src.read()) + self.tokenizer = T5TokenizerFast(dst.name, legacy=False) + else: + self.tokenizer = T5TokenizerFast(tokenizer_name, legacy=False) + + self.tokenizer.pad_token_id = 1 + self.tokenizer.eos_token_id = 1 + self.context_length = context_length + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + + context_length = context_length or self.context_length + assert context_length, 'Please set a valid context length in class init or call.' + + texts = [canonicalize_text(basic_clean(text)) for text in texts] + output = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ) + return output.input_ids diff --git a/src/open_clip/transform.py b/src/open_clip/transform.py index 59f13bb59..45a8e5428 100644 --- a/src/open_clip/transform.py +++ b/src/open_clip/transform.py @@ -1,16 +1,61 @@ +import numbers +import random import warnings from dataclasses import dataclass, asdict -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union -import random import torch -import torch.nn as nn import torchvision.transforms.functional as F - from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ CenterCrop, ColorJitter, Grayscale from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .utils import to_2tuple + + +@dataclass +class PreprocessCfg: + size: Union[int, Tuple[int, int]] = 224 + mode: str = 'RGB' + mean: Tuple[float, ...] = OPENAI_DATASET_MEAN + std: Tuple[float, ...] = OPENAI_DATASET_STD + interpolation: str = 'bicubic' + resize_mode: str = 'shortest' + fill_color: int = 0 + + def __post_init__(self): + assert self.mode in ('RGB') + + @property + def num_channels(self): + return 3 + + @property + def input_size(self): + return (self.num_channels(),) + to_2tuple(self.size) + +_PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys()) + + +def merge_preprocess_dict( + base: Union[PreprocessCfg, Dict], + overlay: Dict, +): + """ Merge overlay key-value pairs on top of base preprocess cfg or dict. + Input dicts are filtered based on PreprocessCfg fields. + """ + if isinstance(base, PreprocessCfg): + base_clean = asdict(base) + else: + base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS} + if overlay: + overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None} + base_clean.update(overlay_clean) + return base_clean + + +def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs): + return merge_preprocess_dict(base, kwargs) @dataclass @@ -18,41 +63,177 @@ class AugmentationCfg: scale: Tuple[float, float] = (0.9, 1.0) ratio: Optional[Tuple[float, float]] = None color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None - interpolation: Optional[str] = None re_prob: Optional[float] = None re_count: Optional[int] = None use_timm: bool = False + # params for simclr_jitter_gray color_jitter_prob: float = None gray_scale_prob: float = None -class ResizeMaxSize(nn.Module): +def _setup_size(size, error_msg): + if isinstance(size, numbers.Number): + return int(size), int(size) - def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): - super().__init__() - if not isinstance(max_size, int): - raise TypeError(f"Size should be int. Got {type(max_size)}") - self.max_size = max_size + if isinstance(size, Sequence) and len(size) == 1: + return size[0], size[0] + + if len(size) != 2: + raise ValueError(error_msg) + + return size + + +class ResizeKeepRatio: + """ Resize and Keep Ratio + + Copy & paste from `timm` + """ + + def __init__( + self, + size, + longest=0., + interpolation=InterpolationMode.BICUBIC, + random_scale_prob=0., + random_scale_range=(0.85, 1.05), + random_aspect_prob=0., + random_aspect_range=(0.9, 1.11) + ): + if isinstance(size, (list, tuple)): + self.size = tuple(size) + else: + self.size = (size, size) self.interpolation = interpolation - self.fn = min if fn == 'min' else min - self.fill = fill + self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest + self.random_scale_prob = random_scale_prob + self.random_scale_range = random_scale_range + self.random_aspect_prob = random_aspect_prob + self.random_aspect_range = random_aspect_range - def forward(self, img): - if isinstance(img, torch.Tensor): - height, width = img.shape[:2] + @staticmethod + def get_params( + img, + target_size, + longest, + random_scale_prob=0., + random_scale_range=(0.85, 1.05), + random_aspect_prob=0., + random_aspect_range=(0.9, 1.11) + ): + """Get parameters + """ + source_size = img.size[::-1] # h, w + h, w = source_size + target_h, target_w = target_size + ratio_h = h / target_h + ratio_w = w / target_w + ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest) + if random_scale_prob > 0 and random.random() < random_scale_prob: + ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1]) + ratio_factor = (ratio_factor, ratio_factor) else: - width, height = img.size - scale = self.max_size / float(max(height, width)) - new_size = tuple(round(dim * scale) for dim in (height, width)) - if scale != 1.0: - img = F.resize(img, new_size, self.interpolation) - if not width == height: - pad_h = self.max_size - new_size[0] - pad_w = self.max_size - new_size[1] - img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) + ratio_factor = (1., 1.) + if random_aspect_prob > 0 and random.random() < random_aspect_prob: + aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1]) + ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor) + size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)] + return size + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size + """ + size = self.get_params( + img, self.size, self.longest, + self.random_scale_prob, self.random_scale_range, + self.random_aspect_prob, self.random_aspect_range + ) + img = F.resize(img, size, self.interpolation) return img + def __repr__(self): + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += f', interpolation={self.interpolation})' + format_string += f', longest={self.longest:.3f})' + return format_string + + +def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor: + """Center crops and/or pads the given image. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. + + Args: + img (PIL Image or Tensor): Image to be cropped. + output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int, + it is used for both directions. + fill (int, Tuple[int]): Padding color + + Returns: + PIL Image or Tensor: Cropped image. + """ + if isinstance(output_size, numbers.Number): + output_size = (int(output_size), int(output_size)) + elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: + output_size = (output_size[0], output_size[0]) + + _, image_height, image_width = F.get_dimensions(img) + crop_height, crop_width = output_size + + if crop_width > image_width or crop_height > image_height: + padding_ltrb = [ + (crop_width - image_width) // 2 if crop_width > image_width else 0, + (crop_height - image_height) // 2 if crop_height > image_height else 0, + (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, + (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, + ] + img = F.pad(img, padding_ltrb, fill=fill) + _, image_height, image_width = F.get_dimensions(img) + if crop_width == image_width and crop_height == image_height: + return img + + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) + return F.crop(img, crop_top, crop_left, crop_height, crop_width) + + +class CenterCropOrPad(torch.nn.Module): + """Crops the given image at the center. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + """ + + def __init__(self, size, fill=0): + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.fill = fill + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. + """ + return center_crop_or_pad(img, self.size, fill=self.fill) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + def _convert_to_rgb(image): return image.convert('RGB') @@ -89,12 +270,14 @@ def __call__(self, img): else: return img + def image_transform( - image_size: int, + image_size: Union[int, Tuple[int, int]], is_train: bool, mean: Optional[Tuple[float, ...]] = None, std: Optional[Tuple[float, ...]] = None, - resize_longest_max: bool = False, + resize_mode: Optional[str] = None, + interpolation: Optional[str] = None, fill_color: int = 0, aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, ): @@ -106,15 +289,21 @@ def image_transform( if not isinstance(std, (list, tuple)): std = (std,) * 3 - if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: - # for square size, pass size as int so that Resize() uses aspect preserving shortest edge - image_size = image_size[0] + interpolation = interpolation or 'bicubic' + assert interpolation in ['bicubic', 'bilinear', 'random'] + # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set + interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC + + resize_mode = resize_mode or 'shortest' + assert resize_mode in ('shortest', 'longest', 'squash') if isinstance(aug_cfg, dict): aug_cfg = AugmentationCfg(**aug_cfg) else: aug_cfg = aug_cfg or AugmentationCfg() + normalize = Normalize(mean=mean, std=std) + if is_train: aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} use_timm = aug_cfg_dict.pop('use_timm', False) @@ -125,13 +314,11 @@ def image_transform( input_size = (3,) + image_size[-2:] else: input_size = (3, image_size, image_size) - # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time - aug_cfg_dict.setdefault('interpolation', 'random') - aug_cfg_dict.setdefault('color_jitter', None) # disable by default - # drop extra item - aug_cfg_dict.pop('color_jitter_prob', False) - aug_cfg_dict.pop('gray_scale_prob', False) + aug_cfg_dict.setdefault('color_jitter', None) # disable by default + # drop extra non-timm items + aug_cfg_dict.pop('color_jitter_prob', None) + aug_cfg_dict.pop('gray_scale_prob', None) train_transform = create_transform( input_size=input_size, @@ -140,6 +327,7 @@ def image_transform( mean=mean, std=std, re_mode='pixel', + interpolation=interpolation, **aug_cfg_dict, ) else: @@ -169,18 +357,50 @@ def image_transform( warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') return train_transform else: - if resize_longest_max: + if resize_mode == 'longest': transforms = [ - ResizeMaxSize(image_size, fill=fill_color) + ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1), + CenterCropOrPad(image_size, fill=fill_color) ] - else: + elif resize_mode == 'squash': + if isinstance(image_size, int): + image_size = (image_size, image_size) transforms = [ - Resize(image_size, interpolation=InterpolationMode.BICUBIC), - CenterCrop(image_size), + Resize(image_size, interpolation=interpolation_mode), ] + else: + assert resize_mode == 'shortest' + if not isinstance(image_size, (tuple, list)): + image_size = (image_size, image_size) + if image_size[0] == image_size[1]: + # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg) + transforms = [ + Resize(image_size[0], interpolation=interpolation_mode) + ] + else: + # resize shortest edge to matching target dim for non-square target + transforms = [ResizeKeepRatio(image_size)] + transforms += [CenterCrop(image_size)] + transforms.extend([ _convert_to_rgb, ToTensor(), normalize, ]) return Compose(transforms) + + +def image_transform_v2( + cfg: PreprocessCfg, + is_train: bool, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +): + return image_transform( + image_size=cfg.size, + is_train=is_train, + mean=cfg.mean, + std=cfg.std, + interpolation=cfg.interpolation, + resize_mode=cfg.resize_mode, + fill_color=cfg.fill_color, + ) \ No newline at end of file diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index ce5e0d3f7..6d4e604d8 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -1,6 +1,7 @@ from collections import OrderedDict import math from typing import Callable, Optional, Sequence, Tuple +from functools import partial import torch from torch import nn @@ -8,6 +9,7 @@ from torch.utils.checkpoint import checkpoint from .utils import to_2tuple +from .pos_embed import get_2d_sincos_pos_embed class LayerNormFp32(nn.LayerNorm): @@ -179,12 +181,9 @@ def forward(self, x: torch.Tensor): x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND N = x.shape[1] q = self.ln_q(self.query) - out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] + out = self.attn(q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False)[0] return out.permute(1, 0, 2) # LND -> NLD - def _repeat(self, query, N: int): - return query.unsqueeze(1).repeat(1, N, 1) - class ResidualAttentionBlock(nn.Module): def __init__( @@ -273,8 +272,8 @@ def __init__( mlp_width = int(d_model * mlp_ratio) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, mlp_width)), - ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), ("gelu", act_layer()), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), ("c_proj", nn.Linear(mlp_width, d_model)) ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() @@ -285,6 +284,10 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): return x +def _expand_token(token, batch_size: int): + return token.view(1, 1, -1).expand(batch_size, -1, -1) + + class Transformer(nn.Module): def __init__( self, @@ -334,44 +337,51 @@ def __init__( heads: int, mlp_ratio: float, ls_init_value: float = None, - global_average_pool: bool = False, attentional_pool: bool = False, - n_queries: int = 256, + attn_pooler_queries: int = 256, attn_pooler_heads: int = 8, output_dim: int = 512, patch_dropout: float = 0., - input_patchnorm: bool = False, + no_ln_pre: bool = False, + pos_embed_type: str = 'learnable', + pool_type: str = 'tok', + final_ln_after_pool: bool = False, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - output_tokens: bool = False + output_tokens: bool = False, ): super().__init__() + assert pool_type in ('tok', 'avg', 'none') self.output_tokens = output_tokens image_height, image_width = self.image_size = to_2tuple(image_size) patch_height, patch_width = self.patch_size = to_2tuple(patch_size) self.grid_size = (image_height // patch_height, image_width // patch_width) + self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled self.output_dim = output_dim - # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1 - self.input_patchnorm = input_patchnorm - - if input_patchnorm: - patch_input_dim = patch_height * patch_width * 3 - self.patchnorm_pre_ln = LayerNorm(patch_input_dim) - self.conv1 = nn.Linear(patch_input_dim, width) - else: - self.patchnorm_pre_ln = nn.Identity() - self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) # class embeddings and positional embeddings scale = width ** -0.5 self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + if pos_embed_type == 'learnable': + self.positional_embedding = nn.Parameter( + scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + elif pos_embed_type == 'sin_cos_2d': + # fixed sin-cos embedding + assert self.grid_size[0] == self.grid_size[1],\ + 'currently sin cos 2d pos embedding only supports square input' + self.positional_embedding = nn.Parameter( + torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False) + pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True) + self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float()) + else: + raise ValueError # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() - self.ln_pre = norm_layer(width) + self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width) self.transformer = Transformer( width, layers, @@ -382,15 +392,43 @@ def __init__( norm_layer=norm_layer, ) - self.global_average_pool = global_average_pool if attentional_pool: - self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) - self.ln_post = norm_layer(output_dim) - self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) + if isinstance(attentional_pool, str): + self.attn_pool_type = attentional_pool + self.pool_type = 'none' + if attentional_pool in ('parallel', 'cascade'): + self.attn_pool = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=attn_pooler_queries, + ) + self.attn_pool_contrastive = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=1, + ) + else: + assert False + else: + self.attn_pool_type = '' + self.pool_type = pool_type + self.attn_pool = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=attn_pooler_queries, + ) + self.attn_pool_contrastive = None + pool_dim = output_dim else: self.attn_pool = None - self.ln_post = norm_layer(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + pool_dim = width + self.pool_type = pool_type + + self.ln_post = norm_layer(pool_dim) + self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim)) self.init_parameters() @@ -452,33 +490,25 @@ def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - if self.global_average_pool: - return x.mean(dim=1), x + if self.pool_type == 'avg': + pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:] + elif self.pool_type == 'tok': + pooled, tokens = x[:, 0], x[:, 1:] else: - return x[:, 0], x[:, 1:] + pooled = tokens = x - def forward(self, x: torch.Tensor): + return pooled, tokens - # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 - if self.input_patchnorm: - # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') - x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1]) - x = x.permute(0, 2, 4, 1, 3, 5) - x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) - x = self.patchnorm_pre_ln(x) - x = self.conv1(x) - else: - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] # class embeddings and positional embeddings - x = torch.cat( - [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), - x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) + # shape = [*, grid ** 2 + 1, width] x = x + self.positional_embedding.to(x.dtype) - # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in x = self.patch_dropout(x) x = self.ln_pre(x) @@ -487,12 +517,26 @@ def forward(self, x: torch.Tensor): x = x.permute(1, 0, 2) # LND -> NLD if self.attn_pool is not None: - x = self.attn_pool(x) - x = self.ln_post(x) + if self.attn_pool_contrastive is not None: + # This is untested, WIP pooling that should match paper + x = self.ln_post(x) # TBD LN first or separate one after each pool? + tokens = self.attn_pool(x) + if self.attn_pool_type == 'parallel': + pooled = self.attn_pool_contrastive(x) + else: + assert self.attn_pool_type == 'cascade' + pooled = self.attn_pool_contrastive(tokens) + else: + # this is the original OpenCLIP CoCa setup, does not match paper + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + elif self.final_ln_after_pool: pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) else: + x = self.ln_post(x) pooled, tokens = self._global_pool(x) - pooled = self.ln_post(pooled) if self.proj is not None: pooled = pooled @ self.proj @@ -503,6 +547,21 @@ def forward(self, x: torch.Tensor): return pooled +def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'): + if pool_type == 'first': + pooled, tokens = x[:, 0], x[:, 1:] + elif pool_type == 'last': + pooled, tokens = x[:, -1], x[:, :-1] + elif pool_type == 'argmax': + # take features from the eot embedding (eot_token is the highest number in each sequence) + assert text is not None + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + else: + pooled = tokens = x + + return pooled, tokens + + class TextTransformer(nn.Module): output_tokens: torch.jit.Final[bool] @@ -513,15 +572,20 @@ def __init__( width: int = 512, heads: int = 8, layers: int = 12, + mlp_ratio: float = 4.0, ls_init_value: float = None, output_dim: int = 512, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, embed_cls: bool = False, + no_causal_mask: bool = False, pad_id: int = 0, + pool_type: str = 'argmax', + proj_bias: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, output_tokens: bool = False, ): super().__init__() + assert pool_type in ('first', 'last', 'argmax', 'none') self.output_tokens = output_tokens self.num_pos = self.context_length = context_length self.vocab_size = vocab_size @@ -529,28 +593,35 @@ def __init__( self.output_dim = output_dim self.heads = heads self.pad_id = pad_id + self.pool_type = pool_type - self.text_projection = nn.Parameter(torch.empty(width, output_dim)) - + self.token_embedding = nn.Embedding(vocab_size, width) if embed_cls: self.cls_emb = nn.Parameter(torch.empty(width)) self.num_pos += 1 else: self.cls_emb = None - - self.token_embedding = nn.Embedding(vocab_size, width) self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) self.transformer = Transformer( width=width, layers=layers, heads=heads, + mlp_ratio=mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, ) self.ln_final = norm_layer(width) - self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + if no_causal_mask: + self.attn_mask = None + else: + self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False) + + if proj_bias: + self.text_projection = nn.Linear(width, output_dim) + else: + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) self.init_parameters() @@ -570,13 +641,18 @@ def init_parameters(self): nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + if isinstance(self.text_projection, nn.Linear): + nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5) + if self.text_projection.bias is not None: + nn.init.zeros_(self.text_projection.bias) + else: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable - def build_attention_mask(self): + def build_causal_mask(self): # lazily create causal attention mask, with full attention between the tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.num_pos, self.num_pos) @@ -593,9 +669,6 @@ def build_cls_mask(self, text, cast_dtype: torch.dtype): additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) return additive_mask - def _repeat(self, t, N: int): - return t.reshape(1, 1, -1).repeat(N, 1, 1) - def forward(self, text): cast_dtype = self.transformer.get_cast_dtype() seq_len = text.shape[1] @@ -604,9 +677,10 @@ def forward(self, text): attn_mask = self.attn_mask if self.cls_emb is not None: seq_len += 1 - x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) + x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1) cls_mask = self.build_cls_mask(text, cast_dtype) - attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + if attn_mask is not None: + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] x = x + self.positional_embedding[:seq_len].to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND @@ -614,16 +688,19 @@ def forward(self, text): x = x.permute(1, 0, 2) # LND -> NLD # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) if self.cls_emb is not None: - pooled, tokens = x[:, -1], x[:, :-1] - pooled = self.ln_final(pooled) + # presence of appended cls embed (CoCa) overrides pool_type, always take last token + pooled, tokens = text_global_pool(x, pool_type='last') + pooled = self.ln_final(pooled) # final LN applied after pooling in this case else: x = self.ln_final(x) - pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type) if self.text_projection is not None: - pooled = pooled @ self.text_projection + if isinstance(self.text_projection, nn.Linear): + pooled = self.text_projection(pooled) + else: + pooled = pooled @ self.text_projection if self.output_tokens: return pooled, tokens diff --git a/src/training/main.py b/src/training/main.py index 08d2412e2..94496999f 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -6,6 +6,7 @@ import sys import random from datetime import datetime +from functools import partial import numpy as np import torch @@ -229,10 +230,12 @@ def main(args): force_custom_text=args.force_custom_text, force_patch_dropout=args.force_patch_dropout, force_image_size=args.force_image_size, - pretrained_image=args.pretrained_image, image_mean=args.image_mean, image_std=args.image_std, + image_interpolation=args.image_interpolation, + image_resize_mode=args.image_resize_mode, # only effective for inference aug_cfg=args.aug_cfg, + pretrained_image=args.pretrained_image, output_dict=True, **model_kwargs, ) @@ -350,7 +353,13 @@ def main(args): logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") # initialize datasets - data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) + tokenizer = get_tokenizer(args.model) + data = get_data( + args, + (preprocess_train, preprocess_val), + epoch=start_epoch, + tokenizer=tokenizer, + ) assert len(data), 'At least one train or eval dataset must be specified.' # create scheduler if train @@ -415,7 +424,7 @@ def main(args): from open_clip.utils import convert_int8_model_to_inference_mode convert_int8_model_to_inference_mode(model) # Evaluate. - evaluate(model, data, start_epoch, args, writer) + evaluate(model, data, start_epoch, args, tb_writer=writer, tokenizer=tokenizer) return loss = create_loss(args) @@ -428,7 +437,7 @@ def main(args): completed_epoch = epoch + 1 if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')): - evaluate(model, data, completed_epoch, args, writer) + evaluate(model, data, completed_epoch, args, tb_writer=writer, tokenizer=tokenizer) # Saving checkpoints. if args.save_logs: diff --git a/src/training/params.py b/src/training/params.py index 345382e57..3ea5a8f3b 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -234,6 +234,16 @@ def parse_args(args): parser.add_argument( '--image-std', type=float, nargs='+', default=None, metavar='STD', help='Override default image std deviation of of dataset') + parser.add_argument( + '--image-interpolation', + default=None, type=str, choices=['bicubic', 'bilinear', 'random'], + help="Override default image resize interpolation" + ) + parser.add_argument( + '--image-resize-mode', + default=None, type=str, choices=['shortest', 'longest', 'squash'], + help="Override default image resize (& crop) mode during inference" + ) parser.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs) parser.add_argument( "--grad-checkpointing", @@ -442,6 +452,7 @@ def parse_args(args): action="store_true", help='Use SigLip (sigmoid) loss.' ) + args = parser.parse_args(args) # If some params are not passed, we use the default values based on model name. diff --git a/src/training/profile.py b/src/training/profiler.py similarity index 100% rename from src/training/profile.py rename to src/training/profiler.py diff --git a/src/training/train.py b/src/training/train.py index 902fbe36a..a48a34593 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -248,14 +248,14 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist # end for -def evaluate(model, data, epoch, args, tb_writer=None): +def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None): metrics = {} if not is_master(args): return metrics device = torch.device(args.device) model.eval() - zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + zero_shot_metrics = zero_shot_eval(model, data, epoch, args, tokenizer=tokenizer) metrics.update(zero_shot_metrics) autocast = get_autocast(args.precision) diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py index 8265b424b..06ce7ac09 100644 --- a/src/training/zero_shot.py +++ b/src/training/zero_shot.py @@ -1,7 +1,6 @@ import logging import torch -import torch.nn.functional as F from tqdm import tqdm from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \ @@ -42,7 +41,7 @@ def run(model, classifier, dataloader, args): return top1, top5 -def zero_shot_eval(model, data, epoch, args): +def zero_shot_eval(model, data, epoch, args, tokenizer=None): if 'imagenet-val' not in data and 'imagenet-v2' not in data: return {} if args.zeroshot_frequency == 0: @@ -53,11 +52,12 @@ def zero_shot_eval(model, data, epoch, args): model = model.module logging.info('Starting zero-shot imagenet.') + if tokenizer is None: + tokenizer = get_tokenizer(args.model) logging.info('Building zero-shot classifier') autocast = get_autocast(args.precision) with autocast(): - tokenizer = get_tokenizer(args.model) classifier = build_zero_shot_classifier( model, tokenizer=tokenizer, diff --git a/tests/test_hf_model.py b/tests/test_hf_model.py index f9191f1f4..1deb00da8 100644 --- a/tests/test_hf_model.py +++ b/tests/test_hf_model.py @@ -21,7 +21,7 @@ def test_poolers(): def test_pretrained_text_encoder(model_id): bs, sl, d = 2, 10, 64 cfg = AutoConfig.from_pretrained(model_id) - model = HFTextEncoder(model_id, d, proj='linear') + model = HFTextEncoder(model_id, d, proj_type='linear') x = torch.randint(0, cfg.vocab_size, (bs, sl)) with torch.no_grad(): emb = model(x)