
Add Q8_0 quantization for intermediate results #951

Merged · ggerganov merged 7 commits into master from q8_0 on Apr 15, 2023

Conversation

@ggerganov (Owner) commented Apr 13, 2023

ref #909

This is an implementation of mode (E) from the referenced issue.

Basically, we quantize the intermediate results to 8 bits instead of 4 bits, to gain accuracy without any performance degradation.

As a positive side effect, we will also have full 8-bit quantization support, although I don't think it will be significantly better than the proposed 4-bit quantization with 8-bit intermediate results.
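
For context, here is a minimal sketch of what the Q8_0 block and its reference (scalar) quantization might look like, following the conventions of the existing Q4_0/Q4_1 code in ggml.c (QK = 32, one f32 scale per block). The exact layout and rounding used in the PR may differ.

#include <math.h>
#include <stdint.h>

#define QK 32

// 8-bit quantization block for intermediate results (sketch)
typedef struct {
    float  d;       // scale (delta)
    int8_t qs[QK];  // quantized values
} block_q8_0;

// Reference (scalar) quantization of one row of k floats into Q8_0 blocks
static void quantize_row_q8_0_sketch(const float * x, block_q8_0 * y, int k) {
    const int nb = k / QK;
    for (int i = 0; i < nb; i++) {
        float amax = 0.0f; // absolute max within the block
        for (int j = 0; j < QK; j++) {
            const float v = fabsf(x[i*QK + j]);
            if (v > amax) amax = v;
        }
        const float d  = amax / 127.0f;
        const float id = d != 0.0f ? 1.0f/d : 0.0f;
        y[i].d = d;
        for (int j = 0; j < QK; j++) {
            y[i].qs[j] = (int8_t) roundf(x[i*QK + j]*id);
        }
    }
}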

Currently:

  • Reference
  • ARM NEON
  • AVX
  • WASM

PRs into this PR's branch to implement the missing SIMD routines are welcome.

Perplexity results

Q4_0 M1 Pro (with BLAS) [655]6.2838 (i.e. reference)
$  make clean && make -j perplexity && time ./perplexity -m ./models/7B/ggml-model-q4_0.bin -f ./build/wiki.test.raw -t 8
I llama.cpp build info: 
I UNAME_S:  Darwin
I UNAME_P:  arm
I UNAME_M:  arm64
I CFLAGS:   -I.              -O3 -DNDEBUG -std=c11   -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function -pthread -DGGML_USE_ACCELERATE
I CXXFLAGS: -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread
I LDFLAGS:   -framework Accelerate
I CC:       Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX:      Apple clang version 14.0.3 (clang-1403.0.22.14.1)

rm -vf *.o main quantize quantize-stats perplexity embedding benchmark-q4_0-matmult
common.o
ggml.o
llama.o
main
quantize
perplexity
embedding
I llama.cpp build info: 
I UNAME_S:  Darwin
I UNAME_P:  arm
I UNAME_M:  arm64
I CFLAGS:   -I.              -O3 -DNDEBUG -std=c11   -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function -pthread -DGGML_USE_ACCELERATE
I CXXFLAGS: -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread
I LDFLAGS:   -framework Accelerate
I CC:       Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX:      Apple clang version 14.0.3 (clang-1403.0.22.14.1)

cc  -I.              -O3 -DNDEBUG -std=c11   -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function -pthread -DGGML_USE_ACCELERATE   -c ggml.c -o ggml.o
c++ -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread -c llama.cpp -o llama.o
c++ -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread -c examples/common.cpp -o common.o
c++ -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread examples/perplexity/perplexity.cpp ggml.o llama.o common.o -o perplexity  -framework Accelerate
main: seed = 1681463663
llama.cpp: loading model from ./models/7B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v1 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 4096
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 32
llama_model_load_internal: n_layer    = 32
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 11008
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size =  59.11 KB
llama_model_load_internal: mem required  = 5809.32 MB (+ 1026.00 MB per state)
llama_init_from_file: kv self size  =  256.00 MB

system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
perplexity : calculating perplexity over 655 chunks, batch_size=512
10.60 seconds per pass - ETA 1.93 hours
[1]4.3802,[2]4.9555,[3]5.8269,[4]6.4692,[5]6.5435,[6]6.5411,[7]6.7174,[8]6.8069,[9]7.1756,[10]7.4121,[11]7.6567,[12]7.6957,[13]7.6058,[14]7.6820,[15]7.9366,[16]7.5419,[17]7.4189,[18]7.3798,[19]7.0077,[20]6.9948,[21]6.8969,[22]6.7125,[23]6.6744,[24]6.5868,[25]6.5871,[26]6.4149,[27]6.2349,[28]6.1341,[29]6.0498,[30]5.8938,[31]5.8659,[32]5.8839,[33]5.8189,[34]5.8537,[35]5.8795,[36]5.9232,[37]5.9272,[38]5.9443,[39]5.9825,[40]6.0412,[41]6.0482,[42]6.0826,[43]6.0397,[44]6.0944,[45]6.0989,[46]6.0729,[47]6.0967,[48]6.0674,[49]6.0745,[50]6.0351,[51]6.0309,[52]6.0200,[53]6.0641,[54]6.0476,[55]6.0250,[56]6.0593,[57]6.0824,[58]6.1043,[59]6.1182,[60]6.1647,[61]6.1536,[62]6.2166,[63]6.2502,[64]6.2653,[65]6.3119,[66]6.3220,[67]6.3401,[68]6.3541,[69]6.3790,[70]6.4113,[71]6.4327,[72]6.4625,[73]6.5276,[74]6.5330,[75]6.5474,[76]6.5637,[77]6.5770,[78]6.5618,[79]6.5914,[80]6.5839,[81]6.5967,[82]6.6005,[83]6.5468,[84]6.5322,[85]6.5208,[86]6.4997,[87]6.4344,[88]6.4059,[89]6.3853,[90]6.3687,[91]6.3948,[92]6.3909,[93]6.3935,[94]6.3910,[95]6.4198,[96]6.4177,[97]6.4105,[98]6.4035,[99]6.3895,[100]6.3895,[101]6.4154,[102]6.4091,[103]6.4308,[104]6.4376,[105]6.4361,[106]6.4538,[107]6.4525,[108]6.4648,[109]6.4595,[110]6.4550,[111]6.4779,[112]6.4969,[113]6.4983,[114]6.4949,[115]6.5032,[116]6.4958,[117]6.5014,[118]6.5298,[119]6.5507,[120]6.5872,[121]6.6035,[122]6.6282,[123]6.6672,[124]6.6850,[125]6.6762,[126]6.7153,[127]6.7524,[128]6.7798,[129]6.7629,[130]6.7725,[131]6.7672,[132]6.7584,[133]6.7456,[134]6.7568,[135]6.7534,[136]6.7402,[137]6.7322,[138]6.7151,[139]6.7035,[140]6.7005,[141]6.6707,[142]6.6658,[143]6.6379,[144]6.6178,[145]6.6092,[146]6.5957,[147]6.6031,[148]6.6054,[149]6.5994,[150]6.5953,[151]6.5965,[152]6.5870,[153]6.5703,[154]6.5613,[155]6.5680,[156]6.5630,[157]6.5813,[158]6.5849,[159]6.5890,[160]6.5916,[161]6.6041,[162]6.5739,[163]6.5619,[164]6.5357,[165]6.5039,[166]6.4751,[167]6.4377,[168]6.4051,[169]6.3916,[170]6.3791,[171]6.3502,[172]6.3322,[173]6.3136,[174]6.2829,[175]6.2607,[176]6.2505,[177]6.2295,[178]6.2059,[179]6.1887,[180]6.1798,[181]6.1574,[182]6.1382,[183]6.1239,[184]6.1238,[185]6.1165,[186]6.1182,[187]6.1236,[188]6.1200,[189]6.1384,[190]6.1393,[191]6.1597,[192]6.1760,[193]6.1938,[194]6.2054,[195]6.2263,[196]6.2434,[197]6.2655,[198]6.2810,[199]6.2840,[200]6.2885,[201]6.2844,[202]6.3049,[203]6.3115,[204]6.3114,[205]6.3224,[206]6.3302,[207]6.3262,[208]6.3346,[209]6.3398,[210]6.3449,[211]6.3547,[212]6.3620,[213]6.3727,[214]6.3762,[215]6.3802,[216]6.3951,[217]6.4129,[218]6.4264,[219]6.4267,[220]6.4231,[221]6.4168,[222]6.4133,[223]6.4024,[224]6.3958,[225]6.3910,[226]6.4125,[227]6.4212,[228]6.4271,[229]6.4338,[230]6.4294,[231]6.4462,[232]6.4332,[233]6.4160,[234]6.4004,[235]6.3845,[236]6.3768,[237]6.3664,[238]6.3698,[239]6.3536,[240]6.3433,[241]6.3466,[242]6.3503,[243]6.3488,[244]6.3368,[245]6.3342,[246]6.3221,[247]6.3097,[248]6.3030,[249]6.3010,[250]6.3057,[251]6.2980,[252]6.2946,[253]6.2844,[254]6.2804,[255]6.2688,[256]6.2496,[257]6.2385,[258]6.2299,[259]6.2279,[260]6.2197,[261]6.2154,[262]6.2095,[263]6.2050,[264]6.1858,[265]6.1850,[266]6.1835,[267]6.1766,[268]6.1862,[269]6.1843,[270]6.1850,[271]6.1928,[272]6.1974,[273]6.1969,[274]6.1983,[275]6.2073,[276]6.2128,[277]6.2288,[278]6.2397,[279]6.2483,[280]6.2518,[281]6.2617,[282]6.2678,[283]6.2825,[284]6.2902,[285]6.2997,[286]6.3144,[287]6.3138,[288]6.3198,[289]6.3107,[290]6.2956,[291]6.2802,[292]6.2644,[293]6.2505,[294]6.2530,[295]6.2524,[296]6.2567,[297]6.2553,[298]6.2579,[299]6.2551,[300]6.2439,[301]6.2440,[302]6.2359,[303]6.2282,[304]6.2204,[305]6.2180,[30
6]6.2047,[307]6.2072,[308]6.2104,[309]6.1941,[310]6.1880,[311]6.1816,[312]6.1838,[313]6.1782,[314]6.1769,[315]6.1604,[316]6.1562,[317]6.1395,[318]6.1179,[319]6.1298,[320]6.1428,[321]6.1466,[322]6.1421,[323]6.1355,[324]6.1331,[325]6.1431,[326]6.1430,[327]6.1451,[328]6.1494,[329]6.1554,[330]6.1579,[331]6.1703,[332]6.1671,[333]6.1741,[334]6.1682,[335]6.1617,[336]6.1655,[337]6.1625,[338]6.1612,[339]6.1555,[340]6.1511,[341]6.1589,[342]6.1613,[343]6.1669,[344]6.1668,[345]6.1667,[346]6.1638,[347]6.1686,[348]6.1727,[349]6.1746,[350]6.1712,[351]6.1717,[352]6.1717,[353]6.1665,[354]6.1664,[355]6.1718,[356]6.1749,[357]6.1712,[358]6.1802,[359]6.1833,[360]6.1795,[361]6.1791,[362]6.1858,[363]6.1970,[364]6.2035,[365]6.2093,[366]6.2100,[367]6.2188,[368]6.2165,[369]6.2175,[370]6.2185,[371]6.2125,[372]6.2178,[373]6.2234,[374]6.2220,[375]6.2217,[376]6.2301,[377]6.2252,[378]6.2277,[379]6.2338,[380]6.2254,[381]6.2211,[382]6.2154,[383]6.2144,[384]6.2137,[385]6.2124,[386]6.2119,[387]6.2111,[388]6.2066,[389]6.2012,[390]6.1943,[391]6.1862,[392]6.1821,[393]6.1802,[394]6.1828,[395]6.1812,[396]6.1738,[397]6.1814,[398]6.1852,[399]6.1935,[400]6.1931,[401]6.1944,[402]6.1950,[403]6.1969,[404]6.2032,[405]6.1937,[406]6.1903,[407]6.1895,[408]6.1905,[409]6.2029,[410]6.2139,[411]6.2264,[412]6.2427,[413]6.2542,[414]6.2618,[415]6.2670,[416]6.2750,[417]6.2881,[418]6.2916,[419]6.2990,[420]6.3076,[421]6.3197,[422]6.3255,[423]6.3326,[424]6.3446,[425]6.3537,[426]6.3602,[427]6.3647,[428]6.3730,[429]6.3775,[430]6.3865,[431]6.4011,[432]6.4054,[433]6.4041,[434]6.3995,[435]6.4002,[436]6.4026,[437]6.4121,[438]6.4200,[439]6.4164,[440]6.4158,[441]6.4108,[442]6.4099,[443]6.4112,[444]6.4115,[445]6.4095,[446]6.4118,[447]6.4147,[448]6.4190,[449]6.4164,[450]6.4167,[451]6.4124,[452]6.4005,[453]6.3922,[454]6.3862,[455]6.3869,[456]6.3917,[457]6.3934,[458]6.3912,[459]6.3922,[460]6.4009,[461]6.3981,[462]6.3965,[463]6.4015,[464]6.4006,[465]6.3976,[466]6.3895,[467]6.3898,[468]6.3897,[469]6.3919,[470]6.3924,[471]6.3876,[472]6.3922,[473]6.3866,[474]6.3880,[475]6.3821,[476]6.3844,[477]6.3773,[478]6.3764,[479]6.3827,[480]6.3879,[481]6.3899,[482]6.3854,[483]6.3813,[484]6.3835,[485]6.3818,[486]6.3763,[487]6.3763,[488]6.3744,[489]6.3694,[490]6.3667,[491]6.3637,[492]6.3579,[493]6.3548,[494]6.3530,[495]6.3528,[496]6.3493,[497]6.3440,[498]6.3422,[499]6.3372,[500]6.3275,[501]6.3206,[502]6.3204,[503]6.3202,[504]6.3109,[505]6.3133,[506]6.3142,[507]6.3081,[508]6.3038,[509]6.3027,[510]6.3066,[511]6.3113,[512]6.3148,[513]6.3166,[514]6.3232,[515]6.3177,[516]6.3169,[517]6.3180,[518]6.3181,[519]6.3211,[520]6.3238,[521]6.3255,[522]6.3283,[523]6.3294,[524]6.3357,[525]6.3393,[526]6.3405,[527]6.3426,[528]6.3372,[529]6.3376,[530]6.3329,[531]6.3319,[532]6.3367,[533]6.3390,[534]6.3372,[535]6.3395,[536]6.3341,[537]6.3318,[538]6.3366,[539]6.3378,[540]6.3417,[541]6.3426,[542]6.3433,[543]6.3447,[544]6.3459,[545]6.3437,[546]6.3443,[547]6.3398,[548]6.3343,[549]6.3345,[550]6.3318,[551]6.3280,[552]6.3260,[553]6.3217,[554]6.3195,[555]6.3166,[556]6.3163,[557]6.3186,[558]6.3146,[559]6.3142,[560]6.3137,[561]6.3139,[562]6.3120,[563]6.3120,[564]6.3163,[565]6.3180,[566]6.3177,[567]6.3155,[568]6.3160,[569]6.3144,[570]6.3170,[571]6.3176,[572]6.3186,[573]6.3188,[574]6.3151,[575]6.3147,[576]6.3145,[577]6.3135,[578]6.3114,[579]6.3122,[580]6.3056,[581]6.3018,[582]6.3008,[583]6.3016,[584]6.3020,[585]6.2943,[586]6.2875,[587]6.2877,[588]6.2927,[589]6.2985,[590]6.3015,[591]6.3037,[592]6.3022,[593]6.2985,[594]6.2996,[595]6.2973,[596]6.3010,[597]6.2987,[598]6.2949,[599]6.2971,[600]6.2969,[601]6.2954,[602]6
.2972,[603]6.3001,[604]6.3012,[605]6.3044,[606]6.3065,[607]6.3048,[608]6.3013,[609]6.3019,[610]6.3056,[611]6.3037,[612]6.3062,[613]6.3026,[614]6.2975,[615]6.2898,[616]6.2928,[617]6.2865,[618]6.2814,[619]6.2757,[620]6.2615,[621]6.2542,[622]6.2525,[623]6.2540,[624]6.2545,[625]6.2544,[626]6.2529,[627]6.2550,[628]6.2555,[629]6.2552,[630]6.2586,[631]6.2650,[632]6.2704,[633]6.2687,[634]6.2721,[635]6.2726,[636]6.2694,[637]6.2659,[638]6.2686,[639]6.2657,[640]6.2667,[641]6.2669,[642]6.2738,[643]6.2760,[644]6.2772,[645]6.2751,[646]6.2793,[647]6.2755,[648]6.2762,[649]6.2763,[650]6.2801,[651]6.2858,[652]6.2865,[653]6.2908,[654]6.2844,[655]6.2838,

llama_print_timings:        load time = 11216.03 ms
llama_print_timings:      sample time =     0.00 ms /     1 runs   (    0.00 ms per run)
llama_print_timings: prompt eval time = 4989892.61 ms / 335360 tokens (   14.88 ms per token)
llama_print_timings:        eval time =     0.00 ms /     1 runs   (    0.00 ms per run)
llama_print_timings:       total time = 5024616.43 ms

real	83m45.024s
user	126m54.284s
sys	4m10.884s
Q4_0 M1 Pro (without BLAS) [655]6.2897 (impl 1)
make clean && LLAMA_NO_ACCELERATE=1 make -j perplexity && time ./perplexity -m ./models/7B/ggml-model-q4_0.bin -f ./build/wiki.test.raw -t 8
I llama.cpp build info: 
I UNAME_S:  Darwin
I UNAME_P:  arm
I UNAME_M:  arm64
I CFLAGS:   -I.              -O3 -DNDEBUG -std=c11   -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function -pthread -DGGML_USE_ACCELERATE
I CXXFLAGS: -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread
I LDFLAGS:   -framework Accelerate
I CC:       Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX:      Apple clang version 14.0.3 (clang-1403.0.22.14.1)

rm -vf *.o main quantize quantize-stats perplexity embedding benchmark-q4_0-matmult
common.o
ggml.o
llama.o
perplexity
I llama.cpp build info: 
I UNAME_S:  Darwin
I UNAME_P:  arm
I UNAME_M:  arm64
I CFLAGS:   -I.              -O3 -DNDEBUG -std=c11   -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function -pthread
I CXXFLAGS: -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread
I LDFLAGS:  
I CC:       Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX:      Apple clang version 14.0.3 (clang-1403.0.22.14.1)

cc  -I.              -O3 -DNDEBUG -std=c11   -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function -pthread   -c ggml.c -o ggml.o
c++ -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread -c llama.cpp -o llama.o
c++ -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread -c examples/common.cpp -o common.o
c++ -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread examples/perplexity/perplexity.cpp ggml.o llama.o common.o -o perplexity 
main: seed = 1681424328
llama.cpp: loading model from ./models/7B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v1 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 4096
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 32
llama_model_load_internal: n_layer    = 32
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 11008
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size =  59.11 KB
llama_model_load_internal: mem required  = 5809.32 MB (+ 1026.00 MB per state)
llama_init_from_file: kv self size  =  256.00 MB

system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 0 | VSX = 0 | 
perplexity : calculating perplexity over 655 chunks
21.53 seconds per pass - ETA 3.92 hours
[1]4.3858,[2]4.9662,[3]5.8382,[4]6.4833,[5]6.5555,[6]6.5501,[7]6.7244,[8]6.8139,[9]7.1831,[10]7.4180,[11]7.6620,[12]7.7028,[13]7.6127,[14]7.6873,[15]7.9419,[16]7.5479,[17]7.4241,[18]7.3855,[19]7.0131,[20]7.0003,[21]6.9022,[22]6.7175,[23]6.6793,[24]6.5912,[25]6.5917,[26]6.4198,[27]6.2395,[28]6.1386,[29]6.0541,[30]5.8977,[31]5.8703,[32]5.8883,[33]5.8232,[34]5.8586,[35]5.8843,[36]5.9276,[37]5.9310,[38]5.9481,[39]5.9864,[40]6.0454,[41]6.0524,[42]6.0870,[43]6.0440,[44]6.0991,[45]6.1037,[46]6.0781,[47]6.1018,[48]6.0726,[49]6.0798,[50]6.0408,[51]6.0364,[52]6.0252,[53]6.0690,[54]6.0526,[55]6.0301,[56]6.0646,[57]6.0877,[58]6.1096,[59]6.1235,[60]6.1705,[61]6.1590,[62]6.2224,[63]6.2561,[64]6.2711,[65]6.3178,[66]6.3278,[67]6.3458,[68]6.3598,[69]6.3848,[70]6.4170,[71]6.4384,[72]6.4684,[73]6.5335,[74]6.5388,[75]6.5536,[76]6.5697,[77]6.5828,[78]6.5677,[79]6.5973,[80]6.5898,[81]6.6026,[82]6.6063,[83]6.5526,[84]6.5381,[85]6.5266,[86]6.5055,[87]6.4398,[88]6.4112,[89]6.3906,[90]6.3741,[91]6.4001,[92]6.3963,[93]6.3990,[94]6.3964,[95]6.4252,[96]6.4229,[97]6.4157,[98]6.4089,[99]6.3949,[100]6.3949,[101]6.4209,[102]6.4145,[103]6.4366,[104]6.4436,[105]6.4421,[106]6.4597,[107]6.4582,[108]6.4706,[109]6.4653,[110]6.4607,[111]6.4837,[112]6.5027,[113]6.5043,[114]6.5012,[115]6.5094,[116]6.5021,[117]6.5076,[118]6.5361,[119]6.5569,[120]6.5934,[121]6.6098,[122]6.6346,[123]6.6736,[124]6.6913,[125]6.6824,[126]6.7216,[127]6.7587,[128]6.7862,[129]6.7693,[130]6.7789,[131]6.7735,[132]6.7647,[133]6.7519,[134]6.7631,[135]6.7596,[136]6.7465,[137]6.7384,[138]6.7213,[139]6.7096,[140]6.7065,[141]6.6768,[142]6.6719,[143]6.6440,[144]6.6239,[145]6.6154,[146]6.6019,[147]6.6096,[148]6.6118,[149]6.6056,[150]6.6015,[151]6.6026,[152]6.5932,[153]6.5765,[154]6.5675,[155]6.5742,[156]6.5692,[157]6.5876,[158]6.5912,[159]6.5953,[160]6.5979,[161]6.6103,[162]6.5801,[163]6.5680,[164]6.5418,[165]6.5099,[166]6.4812,[167]6.4437,[168]6.4110,[169]6.3975,[170]6.3851,[171]6.3561,[172]6.3380,[173]6.3193,[174]6.2885,[175]6.2663,[176]6.2561,[177]6.2350,[178]6.2114,[179]6.1942,[180]6.1853,[181]6.1628,[182]6.1436,[183]6.1293,[184]6.1291,[185]6.1219,[186]6.1236,[187]6.1290,[188]6.1253,[189]6.1438,[190]6.1447,[191]6.1652,[192]6.1816,[193]6.1993,[194]6.2109,[195]6.2318,[196]6.2488,[197]6.2708,[198]6.2864,[199]6.2893,[200]6.2938,[201]6.2896,[202]6.3103,[203]6.3169,[204]6.3168,[205]6.3278,[206]6.3357,[207]6.3317,[208]6.3401,[209]6.3453,[210]6.3504,[211]6.3604,[212]6.3678,[213]6.3785,[214]6.3821,[215]6.3863,[216]6.4012,[217]6.4190,[218]6.4325,[219]6.4328,[220]6.4292,[221]6.4228,[222]6.4192,[223]6.4083,[224]6.4016,[225]6.3968,[226]6.4183,[227]6.4270,[228]6.4329,[229]6.4397,[230]6.4352,[231]6.4521,[232]6.4390,[233]6.4218,[234]6.4061,[235]6.3904,[236]6.3826,[237]6.3721,[238]6.3755,[239]6.3593,[240]6.3489,[241]6.3522,[242]6.3559,[243]6.3544,[244]6.3425,[245]6.3397,[246]6.3276,[247]6.3152,[248]6.3085,[249]6.3065,[250]6.3112,[251]6.3035,[252]6.3001,[253]6.2898,[254]6.2857,[255]6.2741,[256]6.2550,[257]6.2438,[258]6.2352,[259]6.2332,[260]6.2250,[261]6.2207,[262]6.2147,[263]6.2103,[264]6.1913,[265]6.1904,[266]6.1890,[267]6.1820,[268]6.1917,[269]6.1897,[270]6.1904,[271]6.1982,[272]6.2029,[273]6.2023,[274]6.2038,[275]6.2128,[276]6.2183,[277]6.2343,[278]6.2452,[279]6.2538,[280]6.2574,[281]6.2672,[282]6.2733,[283]6.2880,[284]6.2958,[285]6.3053,[286]6.3200,[287]6.3194,[288]6.3255,[289]6.3163,[290]6.3012,[291]6.2857,[292]6.2699,[293]6.2561,[294]6.2586,[295]6.2580,[296]6.2624,[297]6.2610,[298]6.2636,[299]6.2608,[300]6.2495,[301]6.2497,[302]6.2416,[303]6.2339,[304]6.2261,[305]6.2236,[30
6]6.2104,[307]6.2128,[308]6.2161,[309]6.1997,[310]6.1936,[311]6.1873,[312]6.1895,[313]6.1839,[314]6.1826,[315]6.1661,[316]6.1618,[317]6.1451,[318]6.1234,[319]6.1353,[320]6.1483,[321]6.1521,[322]6.1476,[323]6.1410,[324]6.1386,[325]6.1485,[326]6.1485,[327]6.1506,[328]6.1549,[329]6.1609,[330]6.1634,[331]6.1758,[332]6.1726,[333]6.1796,[334]6.1737,[335]6.1672,[336]6.1710,[337]6.1680,[338]6.1667,[339]6.1609,[340]6.1565,[341]6.1643,[342]6.1668,[343]6.1723,[344]6.1722,[345]6.1721,[346]6.1692,[347]6.1740,[348]6.1782,[349]6.1800,[350]6.1766,[351]6.1771,[352]6.1771,[353]6.1720,[354]6.1718,[355]6.1773,[356]6.1803,[357]6.1767,[358]6.1857,[359]6.1888,[360]6.1850,[361]6.1846,[362]6.1912,[363]6.2024,[364]6.2089,[365]6.2148,[366]6.2155,[367]6.2242,[368]6.2220,[369]6.2229,[370]6.2239,[371]6.2179,[372]6.2231,[373]6.2287,[374]6.2274,[375]6.2270,[376]6.2354,[377]6.2305,[378]6.2330,[379]6.2391,[380]6.2307,[381]6.2264,[382]6.2206,[383]6.2196,[384]6.2189,[385]6.2176,[386]6.2172,[387]6.2163,[388]6.2118,[389]6.2064,[390]6.1995,[391]6.1914,[392]6.1873,[393]6.1854,[394]6.1880,[395]6.1864,[396]6.1790,[397]6.1866,[398]6.1904,[399]6.1987,[400]6.1983,[401]6.1996,[402]6.2002,[403]6.2021,[404]6.2084,[405]6.1989,[406]6.1955,[407]6.1947,[408]6.1957,[409]6.2081,[410]6.2191,[411]6.2316,[412]6.2479,[413]6.2594,[414]6.2670,[415]6.2722,[416]6.2803,[417]6.2934,[418]6.2969,[419]6.3043,[420]6.3130,[421]6.3251,[422]6.3308,[423]6.3380,[424]6.3500,[425]6.3592,[426]6.3657,[427]6.3701,[428]6.3785,[429]6.3829,[430]6.3920,[431]6.4066,[432]6.4109,[433]6.4097,[434]6.4050,[435]6.4057,[436]6.4081,[437]6.4176,[438]6.4255,[439]6.4220,[440]6.4214,[441]6.4163,[442]6.4154,[443]6.4167,[444]6.4170,[445]6.4150,[446]6.4174,[447]6.4203,[448]6.4247,[449]6.4220,[450]6.4223,[451]6.4180,[452]6.4062,[453]6.3979,[454]6.3918,[455]6.3925,[456]6.3972,[457]6.3990,[458]6.3968,[459]6.3978,[460]6.4064,[461]6.4036,[462]6.4020,[463]6.4071,[464]6.4062,[465]6.4032,[466]6.3951,[467]6.3954,[468]6.3952,[469]6.3974,[470]6.3980,[471]6.3932,[472]6.3978,[473]6.3921,[474]6.3935,[475]6.3876,[476]6.3900,[477]6.3828,[478]6.3818,[479]6.3882,[480]6.3934,[481]6.3955,[482]6.3909,[483]6.3868,[484]6.3890,[485]6.3873,[486]6.3817,[487]6.3818,[488]6.3799,[489]6.3749,[490]6.3722,[491]6.3692,[492]6.3634,[493]6.3603,[494]6.3585,[495]6.3583,[496]6.3548,[497]6.3495,[498]6.3477,[499]6.3427,[500]6.3330,[501]6.3260,[502]6.3259,[503]6.3257,[504]6.3163,[505]6.3188,[506]6.3198,[507]6.3137,[508]6.3094,[509]6.3083,[510]6.3122,[511]6.3168,[512]6.3204,[513]6.3222,[514]6.3288,[515]6.3232,[516]6.3225,[517]6.3235,[518]6.3236,[519]6.3267,[520]6.3294,[521]6.3311,[522]6.3339,[523]6.3350,[524]6.3412,[525]6.3449,[526]6.3461,[527]6.3481,[528]6.3428,[529]6.3432,[530]6.3385,[531]6.3375,[532]6.3423,[533]6.3446,[534]6.3428,[535]6.3452,[536]6.3398,[537]6.3374,[538]6.3422,[539]6.3434,[540]6.3474,[541]6.3483,[542]6.3490,[543]6.3504,[544]6.3516,[545]6.3494,[546]6.3501,[547]6.3456,[548]6.3401,[549]6.3403,[550]6.3375,[551]6.3338,[552]6.3317,[553]6.3275,[554]6.3253,[555]6.3224,[556]6.3221,[557]6.3244,[558]6.3204,[559]6.3200,[560]6.3195,[561]6.3196,[562]6.3178,[563]6.3178,[564]6.3221,[565]6.3239,[566]6.3235,[567]6.3213,[568]6.3218,[569]6.3201,[570]6.3228,[571]6.3234,[572]6.3244,[573]6.3245,[574]6.3209,[575]6.3204,[576]6.3203,[577]6.3193,[578]6.3172,[579]6.3180,[580]6.3114,[581]6.3076,[582]6.3066,[583]6.3074,[584]6.3077,[585]6.3001,[586]6.2933,[587]6.2935,[588]6.2985,[589]6.3043,[590]6.3074,[591]6.3095,[592]6.3081,[593]6.3044,[594]6.3054,[595]6.3031,[596]6.3069,[597]6.3045,[598]6.3007,[599]6.3029,[600]6.3027,[601]6.3013,[602]6
.3031,[603]6.3061,[604]6.3071,[605]6.3104,[606]6.3125,[607]6.3108,[608]6.3073,[609]6.3079,[610]6.3115,[611]6.3097,[612]6.3122,[613]6.3086,[614]6.3035,[615]6.2957,[616]6.2988,[617]6.2925,[618]6.2873,[619]6.2817,[620]6.2674,[621]6.2602,[622]6.2585,[623]6.2600,[624]6.2604,[625]6.2603,[626]6.2588,[627]6.2610,[628]6.2615,[629]6.2612,[630]6.2646,[631]6.2710,[632]6.2764,[633]6.2747,[634]6.2780,[635]6.2786,[636]6.2754,[637]6.2719,[638]6.2746,[639]6.2717,[640]6.2727,[641]6.2730,[642]6.2798,[643]6.2820,[644]6.2832,[645]6.2810,[646]6.2853,[647]6.2815,[648]6.2821,[649]6.2822,[650]6.2861,[651]6.2917,[652]6.2924,[653]6.2967,[654]6.2903,[655]6.2897,
Q4_0 M1 Pro (without BLAS) [655]6.2895 (impl 2)
$  make clean && LLAMA_NO_ACCELERATE=1 make -j perplexity && time ./perplexity -m ./models/7B/ggml-model-q4_0.bin -f ./build/wiki.test.raw -t 8
I llama.cpp build info: 
I UNAME_S:  Darwin
I UNAME_P:  arm
I UNAME_M:  arm64
I CFLAGS:   -I.              -O3 -DNDEBUG -std=c11   -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function -pthread -DGGML_USE_ACCELERATE
I CXXFLAGS: -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread
I LDFLAGS:   -framework Accelerate
I CC:       Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX:      Apple clang version 14.0.3 (clang-1403.0.22.14.1)

rm -vf *.o main quantize quantize-stats perplexity embedding benchmark-q4_0-matmult
I llama.cpp build info: 
I UNAME_S:  Darwin
I UNAME_P:  arm
I UNAME_M:  arm64
I CFLAGS:   -I.              -O3 -DNDEBUG -std=c11   -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function -pthread
I CXXFLAGS: -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread
I LDFLAGS:  
I CC:       Apple clang version 14.0.3 (clang-1403.0.22.14.1)
I CXX:      Apple clang version 14.0.3 (clang-1403.0.22.14.1)

cc  -I.              -O3 -DNDEBUG -std=c11   -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function -pthread   -c ggml.c -o ggml.o
c++ -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread -c llama.cpp -o llama.o
c++ -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread -c examples/common.cpp -o common.o
c++ -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread examples/perplexity/perplexity.cpp ggml.o llama.o common.o -o perplexity 
main: seed = 1681502601
llama.cpp: loading model from ./models/7B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v1 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 4096
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 32
llama_model_load_internal: n_layer    = 32
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 11008
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size =  59.11 KB
llama_model_load_internal: mem required  = 5809.32 MB (+ 1026.00 MB per state)
llama_init_from_file: kv self size  =  256.00 MB

system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 0 | VSX = 0 | 
perplexity : calculating perplexity over 655 chunks, batch_size=512
16.61 seconds per pass - ETA 3.02 hours
[1]4.3830,[2]4.9584,[3]5.8322,[4]6.4747,[5]6.5500,[6]6.5456,[7]6.7229,[8]6.8116,[9]7.1823,[10]7.4177,[11]7.6627,[12]7.7040,[13]7.6137,[14]7.6886,[15]7.9441,[16]7.5493,[17]7.4257,[18]7.3873,[19]7.0143,[20]7.0007,[21]6.9029,[22]6.7178,[23]6.6799,[24]6.5919,[25]6.5925,[26]6.4205,[27]6.2399,[28]6.1389,[29]6.0543,[30]5.8981,[31]5.8706,[32]5.8888,[33]5.8238,[34]5.8593,[35]5.8851,[36]5.9284,[37]5.9321,[38]5.9490,[39]5.9871,[40]6.0456,[41]6.0527,[42]6.0872,[43]6.0444,[44]6.0994,[45]6.1042,[46]6.0786,[47]6.1023,[48]6.0731,[49]6.0802,[50]6.0411,[51]6.0367,[52]6.0256,[53]6.0694,[54]6.0531,[55]6.0305,[56]6.0652,[57]6.0880,[58]6.1098,[59]6.1237,[60]6.1703,[61]6.1591,[62]6.2223,[63]6.2562,[64]6.2711,[65]6.3176,[66]6.3278,[67]6.3461,[68]6.3601,[69]6.3854,[70]6.4176,[71]6.4391,[72]6.4690,[73]6.5340,[74]6.5395,[75]6.5540,[76]6.5703,[77]6.5834,[78]6.5683,[79]6.5980,[80]6.5905,[81]6.6033,[82]6.6072,[83]6.5535,[84]6.5390,[85]6.5275,[86]6.5063,[87]6.4411,[88]6.4125,[89]6.3919,[90]6.3753,[91]6.4013,[92]6.3973,[93]6.3999,[94]6.3974,[95]6.4262,[96]6.4240,[97]6.4168,[98]6.4097,[99]6.3956,[100]6.3954,[101]6.4214,[102]6.4150,[103]6.4371,[104]6.4438,[105]6.4423,[106]6.4599,[107]6.4585,[108]6.4709,[109]6.4656,[110]6.4609,[111]6.4838,[112]6.5028,[113]6.5042,[114]6.5011,[115]6.5094,[116]6.5019,[117]6.5074,[118]6.5360,[119]6.5570,[120]6.5934,[121]6.6097,[122]6.6346,[123]6.6735,[124]6.6914,[125]6.6824,[126]6.7216,[127]6.7587,[128]6.7862,[129]6.7693,[130]6.7787,[131]6.7733,[132]6.7644,[133]6.7517,[134]6.7628,[135]6.7593,[136]6.7462,[137]6.7382,[138]6.7209,[139]6.7092,[140]6.7061,[141]6.6763,[142]6.6715,[143]6.6435,[144]6.6235,[145]6.6149,[146]6.6013,[147]6.6090,[148]6.6112,[149]6.6051,[150]6.6009,[151]6.6020,[152]6.5926,[153]6.5758,[154]6.5668,[155]6.5735,[156]6.5686,[157]6.5870,[158]6.5906,[159]6.5948,[160]6.5974,[161]6.6099,[162]6.5797,[163]6.5675,[164]6.5413,[165]6.5094,[166]6.4807,[167]6.4432,[168]6.4104,[169]6.3970,[170]6.3845,[171]6.3555,[172]6.3374,[173]6.3188,[174]6.2880,[175]6.2658,[176]6.2555,[177]6.2344,[178]6.2109,[179]6.1937,[180]6.1849,[181]6.1624,[182]6.1431,[183]6.1289,[184]6.1287,[185]6.1214,[186]6.1232,[187]6.1286,[188]6.1249,[189]6.1433,[190]6.1442,[191]6.1647,[192]6.1811,[193]6.1988,[194]6.2104,[195]6.2314,[196]6.2484,[197]6.2705,[198]6.2861,[199]6.2891,[200]6.2937,[201]6.2895,[202]6.3099,[203]6.3165,[204]6.3164,[205]6.3274,[206]6.3352,[207]6.3311,[208]6.3396,[209]6.3448,[210]6.3499,[211]6.3597,[212]6.3671,[213]6.3778,[214]6.3814,[215]6.3856,[216]6.4005,[217]6.4184,[218]6.4320,[219]6.4322,[220]6.4287,[221]6.4225,[222]6.4188,[223]6.4079,[224]6.4012,[225]6.3965,[226]6.4181,[227]6.4268,[228]6.4327,[229]6.4393,[230]6.4349,[231]6.4518,[232]6.4387,[233]6.4215,[234]6.4059,[235]6.3901,[236]6.3823,[237]6.3719,[238]6.3752,[239]6.3590,[240]6.3487,[241]6.3520,[242]6.3557,[243]6.3542,[244]6.3422,[245]6.3395,[246]6.3274,[247]6.3150,[248]6.3083,[249]6.3062,[250]6.3109,[251]6.3032,[252]6.2998,[253]6.2895,[254]6.2855,[255]6.2739,[256]6.2548,[257]6.2436,[258]6.2349,[259]6.2328,[260]6.2247,[261]6.2204,[262]6.2145,[263]6.2100,[264]6.1910,[265]6.1902,[266]6.1887,[267]6.1817,[268]6.1914,[269]6.1895,[270]6.1901,[271]6.1979,[272]6.2025,[273]6.2020,[274]6.2034,[275]6.2124,[276]6.2179,[277]6.2340,[278]6.2449,[279]6.2535,[280]6.2571,[281]6.2670,[282]6.2731,[283]6.2878,[284]6.2956,[285]6.3051,[286]6.3197,[287]6.3191,[288]6.3252,[289]6.3161,[290]6.3009,[291]6.2855,[292]6.2697,[293]6.2558,[294]6.2583,[295]6.2577,[296]6.2621,[297]6.2607,[298]6.2634,[299]6.2606,[300]6.2494,[301]6.2495,[302]6.2414,[303]6.2337,[304]6.2259,[305]6.2234,[30
6]6.2102,[307]6.2126,[308]6.2158,[309]6.1994,[310]6.1934,[311]6.1871,[312]6.1893,[313]6.1837,[314]6.1824,[315]6.1659,[316]6.1616,[317]6.1450,[318]6.1233,[319]6.1351,[320]6.1482,[321]6.1520,[322]6.1475,[323]6.1409,[324]6.1384,[325]6.1484,[326]6.1484,[327]6.1505,[328]6.1547,[329]6.1607,[330]6.1632,[331]6.1756,[332]6.1725,[333]6.1795,[334]6.1736,[335]6.1671,[336]6.1709,[337]6.1679,[338]6.1666,[339]6.1608,[340]6.1564,[341]6.1642,[342]6.1666,[343]6.1722,[344]6.1721,[345]6.1721,[346]6.1692,[347]6.1739,[348]6.1781,[349]6.1799,[350]6.1765,[351]6.1770,[352]6.1771,[353]6.1719,[354]6.1718,[355]6.1773,[356]6.1803,[357]6.1767,[358]6.1857,[359]6.1888,[360]6.1851,[361]6.1847,[362]6.1913,[363]6.2025,[364]6.2090,[365]6.2149,[366]6.2156,[367]6.2243,[368]6.2220,[369]6.2229,[370]6.2239,[371]6.2180,[372]6.2232,[373]6.2289,[374]6.2276,[375]6.2272,[376]6.2356,[377]6.2307,[378]6.2332,[379]6.2393,[380]6.2308,[381]6.2266,[382]6.2208,[383]6.2198,[384]6.2190,[385]6.2178,[386]6.2173,[387]6.2165,[388]6.2119,[389]6.2065,[390]6.1995,[391]6.1914,[392]6.1874,[393]6.1855,[394]6.1881,[395]6.1865,[396]6.1791,[397]6.1866,[398]6.1904,[399]6.1987,[400]6.1983,[401]6.1997,[402]6.2003,[403]6.2021,[404]6.2085,[405]6.1990,[406]6.1956,[407]6.1948,[408]6.1958,[409]6.2082,[410]6.2192,[411]6.2317,[412]6.2480,[413]6.2595,[414]6.2671,[415]6.2723,[416]6.2804,[417]6.2935,[418]6.2970,[419]6.3043,[420]6.3130,[421]6.3251,[422]6.3308,[423]6.3380,[424]6.3500,[425]6.3591,[426]6.3656,[427]6.3701,[428]6.3784,[429]6.3829,[430]6.3919,[431]6.4065,[432]6.4108,[433]6.4095,[434]6.4049,[435]6.4056,[436]6.4080,[437]6.4175,[438]6.4254,[439]6.4218,[440]6.4213,[441]6.4162,[442]6.4153,[443]6.4166,[444]6.4169,[445]6.4149,[446]6.4173,[447]6.4202,[448]6.4246,[449]6.4218,[450]6.4221,[451]6.4179,[452]6.4060,[453]6.3977,[454]6.3917,[455]6.3924,[456]6.3972,[457]6.3989,[458]6.3967,[459]6.3977,[460]6.4063,[461]6.4035,[462]6.4019,[463]6.4070,[464]6.4061,[465]6.4030,[466]6.3950,[467]6.3952,[468]6.3951,[469]6.3973,[470]6.3978,[471]6.3930,[472]6.3976,[473]6.3920,[474]6.3933,[475]6.3875,[476]6.3898,[477]6.3826,[478]6.3817,[479]6.3881,[480]6.3933,[481]6.3953,[482]6.3907,[483]6.3866,[484]6.3888,[485]6.3871,[486]6.3816,[487]6.3816,[488]6.3797,[489]6.3748,[490]6.3721,[491]6.3690,[492]6.3632,[493]6.3602,[494]6.3584,[495]6.3582,[496]6.3547,[497]6.3494,[498]6.3476,[499]6.3426,[500]6.3329,[501]6.3259,[502]6.3258,[503]6.3255,[504]6.3162,[505]6.3187,[506]6.3196,[507]6.3135,[508]6.3092,[509]6.3081,[510]6.3121,[511]6.3167,[512]6.3202,[513]6.3220,[514]6.3287,[515]6.3231,[516]6.3224,[517]6.3234,[518]6.3235,[519]6.3266,[520]6.3293,[521]6.3310,[522]6.3339,[523]6.3349,[524]6.3412,[525]6.3449,[526]6.3461,[527]6.3482,[528]6.3428,[529]6.3432,[530]6.3385,[531]6.3375,[532]6.3423,[533]6.3446,[534]6.3428,[535]6.3452,[536]6.3398,[537]6.3374,[538]6.3422,[539]6.3434,[540]6.3474,[541]6.3483,[542]6.3490,[543]6.3504,[544]6.3516,[545]6.3494,[546]6.3501,[547]6.3455,[548]6.3400,[549]6.3402,[550]6.3375,[551]6.3337,[552]6.3316,[553]6.3274,[554]6.3252,[555]6.3223,[556]6.3220,[557]6.3243,[558]6.3204,[559]6.3199,[560]6.3194,[561]6.3195,[562]6.3177,[563]6.3177,[564]6.3221,[565]6.3238,[566]6.3234,[567]6.3212,[568]6.3217,[569]6.3201,[570]6.3227,[571]6.3234,[572]6.3244,[573]6.3245,[574]6.3209,[575]6.3204,[576]6.3203,[577]6.3192,[578]6.3172,[579]6.3180,[580]6.3113,[581]6.3075,[582]6.3066,[583]6.3073,[584]6.3077,[585]6.3000,[586]6.2932,[587]6.2934,[588]6.2984,[589]6.3042,[590]6.3073,[591]6.3094,[592]6.3080,[593]6.3043,[594]6.3054,[595]6.3030,[596]6.3068,[597]6.3044,[598]6.3007,[599]6.3028,[600]6.3026,[601]6.3011,[602]6
.3029,[603]6.3059,[604]6.3070,[605]6.3103,[606]6.3123,[607]6.3106,[608]6.3071,[609]6.3077,[610]6.3114,[611]6.3095,[612]6.3120,[613]6.3084,[614]6.3033,[615]6.2955,[616]6.2985,[617]6.2923,[618]6.2871,[619]6.2815,[620]6.2672,[621]6.2599,[622]6.2582,[623]6.2597,[624]6.2602,[625]6.2601,[626]6.2585,[627]6.2606,[628]6.2612,[629]6.2609,[630]6.2643,[631]6.2707,[632]6.2761,[633]6.2744,[634]6.2777,[635]6.2783,[636]6.2751,[637]6.2717,[638]6.2743,[639]6.2714,[640]6.2724,[641]6.2727,[642]6.2796,[643]6.2817,[644]6.2829,[645]6.2808,[646]6.2850,[647]6.2812,[648]6.2819,[649]6.2820,[650]6.2859,[651]6.2915,[652]6.2922,[653]6.2965,[654]6.2902,[655]6.2895,

llama_print_timings:        load time = 17160.99 ms
llama_print_timings:      sample time =     0.00 ms /     1 runs   (    0.00 ms per run)
llama_print_timings: prompt eval time = 12080168.91 ms / 335360 tokens (   36.02 ms per token)
llama_print_timings:        eval time =     0.00 ms /     1 runs   (    0.00 ms per run)
llama_print_timings:       total time = 12113760.51 ms

real	201m54.092s
user	1602m45.695s
sys	2m57.719s

@ggerganov ggerganov added help wanted Extra attention is needed high priority Very important issue generation quality Quality of model output labels Apr 13, 2023
-        sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
-        sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
+        sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
+        sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
@ggerganov (Owner Author):
These casts are to avoid the edge case of 128 + 128 overflowing uint8_t
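
A tiny standalone illustration (not code from the PR) of the wraparound the casts guard against: if the sum of two vaddvq_u8 results is kept in 8 bits it wraps around, while widening to uint16_t first preserves the full value.

#include <stdint.h>
#include <stdio.h>

int main(void) {
    uint8_t a = 128, b = 128;

    uint8_t  narrow = (uint8_t)(a + b);          // truncated back to 8 bits: 0
    uint16_t wide   = (uint16_t)a + (uint16_t)b; // widened first: 256

    printf("narrow = %u, wide = %u\n", (unsigned)narrow, (unsigned)wide);
    return 0;
}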

@ggerganov ggerganov requested a review from sw April 13, 2023 20:14
@ggerganov (Owner Author)

I'm still waiting on the final perplexity results for this approach on my machine, but if we confirm that this is the right way to go, I am wondering if we should just remove all SIMD implementations of quantize_row_q4_0 and quantize_row_q4_1, as they will no longer be needed at run-time. Moreover, @ikawrakow reported in #896 that the vectorized implementations of these calls are only about 10% faster than the reference. Considering that they take a relatively small part of the total time and that the code would be significantly simplified, maybe it's a good idea.

It's possible that the same holds true for dequantize_row_q4_0 and dequantize_row_q4_1.
These are used during prompt evaluation with BLAS enabled. A timing analysis is needed to decide if it is worth maintaining the cumbersome SIMD implementations.

@howard0su (Collaborator) commented Apr 14, 2023

Add AVX support

diff --git a/ggml.c b/ggml.c
index db8babb..cb92adb 100644
--- a/ggml.c
+++ b/ggml.c
@@ -2578,6 +2578,53 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     }

     sumf = sum0 + sum1;
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    const __m128i off_8 = _mm_set1_epi8(8);
+    const __m128i off_128 = _mm_set1_epi8(128);
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m128i i32[2];
+        for (int j = 0; j < 2; ++j) {
+            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
+            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+            __m128i by = _mm_loadu_epi8( y[i].qs + 16*j);
+
+            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+            bx = _mm_sub_epi8( bx, off_8 );
+            by = _mm_sub_epi8( by, off_128 );
+
+            // Get absolute values of x vectors
+            const __m128i ax = _mm_sign_epi8(bx, bx);
+
+            // Sign the values of the y vectors
+            const __m128i sy = _mm_sign_epi8(by, bx);
+
+            // Perform multiplication and create 16-bit values
+            const __m128i dot = _mm_maddubs_epi16(ax, sy);
+
+            const __m128i ones = _mm_set1_epi16(1);
+            i32[j] = _mm_madd_epi16(ones, dot);
+        }
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
+        // Apply the scale, and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
 #else
     // scalar
     for (int i = 0; i < nb; i++) {

@slaren (Collaborator) commented Apr 14, 2023

@ggerganov The SIMD quantize functions could still be useful for using LoRA with quantized models: either by modifying the quantized model directly, in which case ggml_add needs to dequantize, add, and re-quantize the matrices, or by supporting on-the-fly quantization from F16, in which case we would take the F16 layers from a base model, apply the LoRA adapter, and then quantize them. Without the SIMD implementations this may be too slow to do on the fly.
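
A rough sketch of the first option (illustrative only: the function and loop are hypothetical, and it assumes access to ggml.c's dequantize_row_q4_0/quantize_row_q4_0 routines or equivalents, which are static in ggml.c):

// Hypothetical helper: apply an f32 LoRA delta to one quantized row by
// dequantizing, adding, and re-quantizing (the re-quantization is lossy).
static void add_lora_delta_q4_0(void * row, const float * delta,
                                float * scratch, int n) {
    dequantize_row_q4_0(row, scratch, n);   // Q4_0 -> f32
    for (int i = 0; i < n; i++) {
        scratch[i] += delta[i];             // apply the LoRA update
    }
    quantize_row_q4_0(scratch, row, n);     // f32 -> Q4_0
}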

@ggerganov ggerganov self-assigned this Apr 14, 2023
@sw (Contributor) commented Apr 14, 2023

@howard0su

_mm_loadu_epi8

Intel's guide tells me this is AVX512? _mm_loadu_si128 should be equivalent, though.

Anyway I've already started doing AVX and AVX2 on a branch mentioned here #909 (comment). I'll try to get a PR against this PR ready.
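
For reference, a minimal sketch of the SSE2 equivalent (assuming the same unaligned int8 source as in the diff above):

#include <immintrin.h>
#include <stdint.h>

// _mm_loadu_si128 is SSE2 and is available wherever the AVX path runs,
// unlike _mm_loadu_epi8, which requires AVX-512VL + AVX-512BW.
static inline __m128i load_16_i8(const int8_t * p) {
    return _mm_loadu_si128((const __m128i *)p);
}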

@ggerganov (Owner Author)

Rebased on latest master

Will run perplexity "without BLAS" once more to confirm we didn't break something with the rounding.
Will try to implement the WASM path for Q8_0 and merge this to master.

The Q8_0 model file creation will be done later in another PR (low-prio).

Question: should we add Q8_1? It might not be worth it, since we are already very close to the optimal result (i.e. "with BLAS").

Thinking that after we merge this, we should focus on implementing the Q4_0 variant that uses 2x F16 factors instead of 1x F32. If we manage to implement an efficient SIMD dot product for that (ideally without a performance regression compared to the existing Q4_0 x Q8_0 call), I believe we should have a 4-bit quantization with quality very similar to the full F16 mode.
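
For illustration only, such a block could keep the same 20-byte footprint as today's Q4_0. The struct names below are hypothetical, ggml_fp16_t is the half-precision type from ggml.h, and the exact meaning of the second factor was still being worked out at this point (see 679e1cb):

#include <stdint.h>
#include "ggml.h"   // for ggml_fp16_t

#define QK 32

// Current Q4_0: one f32 factor per 32 quants -> 4 + 16 = 20 bytes per block
typedef struct {
    float   d;          // delta
    uint8_t qs[QK / 2]; // nibbles / quants
} block_q4_0;

// Sketched variant: two f16 factors per block -> 2 + 2 + 16 = 20 bytes
typedef struct {
    ggml_fp16_t d0;     // first factor
    ggml_fp16_t d1;     // second factor
    uint8_t     qs[QK / 2];
} block_q4_0_f16x2;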

@MillionthOdin16 MillionthOdin16 mentioned this pull request Apr 14, 2023
@howard0su (Collaborator)

@howard0su

_mm_loadu_epi8

Intel's guide tells me this is AVX512? _mm_loadu_si128 should be equivalent, though.

Anyway I've already started doing AVX and AVX2 on a branch mentioned here #909 (comment). I'll try to get a PR against this PR ready.

You are right, but I noticed that it works on MSVC, so I didn't change it.

@ggerganov (Owner Author)

Ok, so I've decided on the following plan:

  • postpone the WASM implementation from this PR for later
  • merge to master
  • put high priority on the idea of 2x F16 Q4_0 (see 679e1cb)
  • measure speed and perplexity and decide how to proceed

If this idea is viable in terms of speed and accuracy, we can think about changing the Q4_0 format and supporting only that, in order to reduce the complexity of the implementation.

Will merge this in about an hour if no objections.

@dfyz (Collaborator) commented Apr 15, 2023

For an efficient AVX-512 implementation, would it be acceptable for quantize_row_q8_0() to store the quantized data in the format @unbounded proposed here ("all the nibbles followed by all the scales")? I want to load two blocks at once, and two block_q8_0's now consume 72 bytes, which is larger than the 64 bytes an AVX-512 register can handle. With the format from @unbounded, it is possible to load the block data and the scales separately.

This might be potentially confusing, since the AVX-512 implementation will use a different format to store the intermediate 8-bit results (as opposed to block_q8_0 for all other implementations). On the other hand, quantize_row_q8_0() is specifically used only as part of ggml_compute_forward_mul_mat_q_f32() (at least for now), so you could say that the intermediate 8-bit format is an implementation detail which doesn't matter as long as quantize_row_q8_0() and ggml_vec_dot_q4_0_q8_0() agree on what it is.
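
To make the size argument concrete, here is a sketch of the two layouts; the second one is hypothetical (the PR keeps the plain array of block_q8_0), and the field names are illustrative only:

#include <stdint.h>

#define QK 32

// Layout used by the PR: 4 + 32 = 36 bytes per block, so two blocks
// (72 bytes) do not fit into a single 64-byte AVX-512 register.
typedef struct {
    float  d;       // scale
    int8_t qs[QK];  // quants
} block_q8_0;

// Hypothetical "quants first, scales after" layout for a row of nb blocks:
// nb*QK contiguous quant bytes (load-friendly), followed by nb scales.
typedef struct {
    int8_t * qs;    // nb * QK quantized values, contiguous
    float  * d;     // nb scales
} row_q8_0_separated;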

@ggerganov (Owner Author)

@dfyz
Yes, that should be possible.
We can do this after merging this PR, if you can demonstrate that the AVX-512 implementation benefits from it.

@dfyz (Collaborator) commented Apr 15, 2023

@dfyz Yes, that should be possible. We can do this after merging this PR and if you can demonstrate that the AVX512 implementation benefits from it.

Thanks! I was just thinking out loud, definitely wasn't proposing an AVX-512 implementation to integrate into this PR (#933 with the Q4_0 x Q4_0 multiplication should be reviewed and merged first, anyway).

Continuing the "thinking out loud" part: on second thought, changing the format doesn't really buy us anything. Instead of loading the block data and the scales with two separate load instructions, we can just load as many bytes of a block_q8_0 as possible, permute the bytes with VPERMB, then issue a second load for leftovers (masked AVX-512 loads should handle this just fine).

Still, it is nice that changing format is possible in theory. I will try implementing a Q4_0 x Q8_0 AVX-512 routine and get some measurements when the dust settles (i.e., when this PR and #933 are both merged).

@sw (Contributor) commented Apr 15, 2023

For an efficient AVX-512 implementation, would it be acceptable for quantize_row_q8_0() to store the quantized data in the format @unbounded proposed here ("all the nibbles followed by all the scales")?

Other SIMD optimizations may benefit as well, because the quant blocks would be aligned to 32 bytes. On AVX2, you could use _mm256_load_si256 instead of _mm256_loadu_si256, but I don't know if that has a large impact. On the other hand, cache locality may be worse and you'll have to juggle two pointers.

Or how about this?

typedef struct {
    float   d0;      // delta for the first sub-block
    float   d1;      // delta for the second sub-block
    int8_t  qs0[QK]; // quants for the first sub-block
    int8_t  qs1[QK]; // quants for the second sub-block
} block_q8_0;

Depending on your register width, you could choose to do just one at a time.

@ggerganov ggerganov merged commit e95b655 into master Apr 15, 2023
@ggerganov ggerganov deleted the q8_0 branch April 15, 2023 14:53
@unbounded (Contributor)

Great!

If we change the layout of Q8_0, I would look at also changing Q4_0 to "all the nibbles followed by all the scales", with a block of 32 bytes holding 32 consecutive nibbles in the low halves of the bytes and 32 more nibbles in the high halves. You could then AND with 0x0f0f0f... to get a register of consecutive bytes ready to be dot-producted with the Q8 bytes, and then shift + AND to get another register of consecutive bytes. The main loop should then only need a permutation for the scales, I think.

Need to verify for other architectures, but I think this would make the SIMD implementations more straightforward in most cases, and it shouldn't hurt anything.
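
A minimal AVX2 sketch of the extraction described above, assuming the hypothetical reordered layout where a 32-byte block holds 64 consecutive nibbles (low nibbles first, high nibbles second):

#include <immintrin.h>
#include <stdint.h>

// Given 32 bytes holding 64 consecutive 4-bit quants, produce two registers
// of 32 consecutive bytes each.
static inline void unpack_64_nibbles(const uint8_t * p,
                                     __m256i * lo, __m256i * hi) {
    const __m256i bytes = _mm256_loadu_si256((const __m256i *)p);
    const __m256i mask  = _mm256_set1_epi8(0x0F);

    *lo = _mm256_and_si256(bytes, mask);                       // quants 0..31
    *hi = _mm256_and_si256(_mm256_srli_epi16(bytes, 4), mask); // quants 32..63
}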

The AVX2 implementation of ggml_vec_dot_q4_0 already assumes the number of blocks is divisible by 8 (see https://github.com/ggerganov/llama.cpp/blob/master/ggml.c#L2169), so evidently that works for the LLaMA models.

@ggerganov (Owner Author)

A few thoughts on the data formats:

  • The selected data structures should be optimized for ARM NEON with the highest priority. After that I would probably put WASM SIMD, with AVX maybe equally important as WASM. Unfortunately, the WASM toolchain is a bit painful, so it is not very realistic to treat WASM as the second priority for now. Both ARM NEON and WASM SIMD work with 128-bit registers, so we have to keep this in mind.
  • Originally, I packed all the scales at the start, followed by all the nibbles for a row. This didn't work out well, because the LLaMA models were split into shards and some rows were split across shards, which made merging the quantized shards very difficult. Therefore, I changed the data to the current format. We no longer have this problem, since the merging of the shards is now done by the Python conversion script. Still, going back to the old format would require demonstrating a significant benefit, especially for ARM NEON. For Q8_0 it's easier to make the change, since we are not using it to store data yet, but again, we need a good reason.
  • The AVX2 assumption of blocks divisible by 8 should be "asserted" in some way, and we should start handling the "leftovers" from the division. Currently this works for LLaMA, but these routines will be used for other models where the assumption might not hold (see the sketch below).
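
A minimal sketch of what asserting the assumption and handling leftovers could look like (illustrative only; the function is hypothetical and not the PR's code):

#include <assert.h>

#define QK 32

// Illustrative prologue for a hypothetical vec_dot routine: make the
// divisibility assumption explicit, or split the work into a vectorized
// bulk part and a scalar tail.
static void vec_dot_blocks_sketch(int n) {
    const int nb = n / QK;   // number of blocks in the row

    // Option 1: assert the assumption until leftovers are handled.
    assert(nb % 8 == 0 && "AVX2 path processes 8 blocks per iteration");

    // Option 2: vectorized bulk, scalar tail for the remaining nb % 8 blocks.
    const int nb8 = nb - nb % 8;
    for (int i = 0; i < nb8; i += 8) {
        // ... process 8 blocks with AVX2 ...
    }
    for (int i = nb8; i < nb; ++i) {
        // ... process one block with the scalar code ...
    }
}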

@sw sw mentioned this pull request Apr 15, 2023
@dfyz (Collaborator) commented Apr 15, 2023

On AVX2, you could use _mm256_load_si256 instead of _mm256_loadu_si256, but I don't know if that has a large impact

This could make a non-negligible difference, but if you change the data ordering row-wise (as @unbounded suggested), you have to somehow ensure that the length of each row in bytes is also divisible by 32. E.g., two Q4_0 blocks are 40 bytes, so if a row has two blocks, you have to allocate an additional 24 bytes at the end of the row. I think this might get complicated quickly.
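
The arithmetic behind that example, spelled out (assuming the standard block_q4_0 of 4 + 16 = 20 bytes):

#include <stddef.h>
#include <stdio.h>

int main(void) {
    const size_t block_q4_0_size = 4 + 32/2;                    // f32 scale + 16 nibble bytes = 20
    const size_t row_bytes       = 2 * block_q4_0_size;         // 40 bytes for a 2-block row
    const size_t padded          = (row_bytes + 31) / 32 * 32;  // round up to a multiple of 32

    printf("row = %zu bytes, padded = %zu, padding = %zu\n",
           row_bytes, padded, padded - row_bytes);              // 40, 64, 24
    return 0;
}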

I personally will try to avoid changing formats in the AVX-512 implementation unless absolutely necessary (for the reasons @ggerganov provided above).

@SebastianApel (Contributor)

If we change the layout of q8_0, I would look at also changing q4_0 to "all the nibbles followed by all the scales",
with a block of 32 bytes containing 32 consecutive nibbles in the lower half, and 32 more nibbles in the higher half

I agree with the hypothesis that reordering the Q4_0 format would be beneficial. On an AVX2 CPU, the current format (f32, 32x q4) makes it hard for any implementation to load the f32 scales as a vector. A scheme that enables vector-based loading would provide more flexibility.

Caveat: I experimented with a "reordered" byte sequence based on a "superblock" of 8 Q4_0 blocks about 2-3 weeks ago, and there was a slight speedup. As far as I remember, however, it was not a major step change, especially since you need to divide the single-thread improvement by the number of threads you are going to use.

@unbounded (Contributor)

👍 I will probably do some experiments with continuous layouts - I understand that we would need to see a significant benefit to change the format. It will be hard to speed things up much if we're already close to memory bandwidth, of course.

Labels: generation quality · help wanted · high priority · Review Complexity: High · Tensor Encoding Scheme (https://github.com/ggerganov/llama.cpp/wiki/Tensor-Encoding-Schemes)

Successfully merging this pull request may close these issues: Investigate alternative ggml_compute_forward_mul_mat_q_f32() implementation