Skip to content

Commit

Permalink
Make updates to reduce number of load instructions
Browse files Browse the repository at this point in the history
  • Loading branch information
Srihari-mcw authored and Srihari-mcw committed Sep 4, 2024
1 parent 364dc96 commit c950fc3
Showing 1 changed file with 24 additions and 32 deletions.
56 changes: 24 additions & 32 deletions ggml/src/ggml-aarch64.c
Original file line number Diff line number Diff line change
Expand Up @@ -2504,22 +2504,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
for (int rp = 0; rp < 4; rp++) {
// Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
// Loaded as set of 128 bit vectors and repeated into a 256 bit vector
__m256i lhs_mat_01_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs)));
lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_01_0, lhs_mat_01_0, 0);
__m256i lhs_mat_23_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 16)));
lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_23_0, lhs_mat_23_0, 0);
__m256i lhs_mat_01_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 32)));
lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_01_1, lhs_mat_01_1, 0);
__m256i lhs_mat_23_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 48)));
lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_23_1, lhs_mat_23_1, 0);
__m256i lhs_mat_01_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 64)));
lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_01_2, lhs_mat_01_2, 0);
__m256i lhs_mat_23_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 80)));
lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_23_2, lhs_mat_23_2, 0);
__m256i lhs_mat_01_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 96)));
lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_01_3, lhs_mat_01_3, 0);
__m256i lhs_mat_23_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 112)));
lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_23_3, lhs_mat_23_3, 0);
__m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
__m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
__m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
__m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
__m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
__m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
__m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
__m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
__m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
__m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
__m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
__m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);

// Shuffle pattern one - left side input
const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
Expand Down Expand Up @@ -2670,22 +2666,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *

// Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
// Loaded as set of 128 bit vectors and repeated into a 256 bit vector
__m256i lhs_mat_01_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs)));
lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_01_0, lhs_mat_01_0, 0);
__m256i lhs_mat_23_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16)));
lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_23_0, lhs_mat_23_0, 0);
__m256i lhs_mat_01_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32)));
lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_01_1, lhs_mat_01_1, 0);
__m256i lhs_mat_23_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48)));
lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_23_1, lhs_mat_23_1, 0);
__m256i lhs_mat_01_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 64)));
lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_01_2, lhs_mat_01_2, 0);
__m256i lhs_mat_23_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 80)));
lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_23_2, lhs_mat_23_2, 0);
__m256i lhs_mat_01_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 96)));
lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_01_3, lhs_mat_01_3, 0);
__m256i lhs_mat_23_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 112)));
lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_23_3, lhs_mat_23_3, 0);
__m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
__m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
__m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
__m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
__m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
__m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
__m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
__m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
__m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
__m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
__m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
__m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);

// Shuffle pattern one - left side input

Expand Down

0 comments on commit c950fc3

Please sign in to comment.