Skip to content

Commit

Permalink
Simplify and optimize the sgemm code
Browse files Browse the repository at this point in the history
  • Loading branch information
jart committed Apr 20, 2024
1 parent dd49f26 commit b52304e
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 27 deletions.
3 changes: 1 addition & 2 deletions llamafile/sgemm_q0q0s_dotprod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ class GEMMERQ0ARM {
mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
mnpack(mp, m, n0, np);
mnpack(m0, mp, np, n);
mnpack(mp, m, np, n);
mnpack(m0, m, np, n);
}

dontinline void gemm3x3(int m0, int m, int n0, int n) {
Expand Down
149 changes: 135 additions & 14 deletions llamafile/sgemmer.inc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>

#include "llama.cpp/ggml.h"

#include "hsum.h"
Expand All @@ -34,36 +36,155 @@ class SGEMMER {
}

private:
dontinline void mnpack(int m0, int m, int n0, int n) {
void mnpack(int m0, int m, int n0, int n) {
int mc, nc, mp, np;
if (m - m0 <= 0 || n - n0 <= 0)
return;
if (VECTOR_REGISTERS >= 32 && m - m0 >= 8 && n - n0 >= 3) {
mc = 8;
switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) {
#if VECTOR_REGISTERS == 32
case 0x55:
mc = 5;
nc = 5;
gemm<5, 5>(m0, m, n0, n);
break;
case 0x45:
mc = 4;
nc = 5;
gemm<4, 5>(m0, m, n0, n);
break;
case 0x54:
mc = 5;
nc = 4;
gemm<5, 4>(m0, m, n0, n);
break;
case 0x44:
mc = 4;
nc = 4;
gemm<4, 4>(m0, m, n0, n);
break;
case 0x53:
mc = 5;
nc = 3;
gemm<8, 3>(m0, m, n0, n);
} else if (m - m0 >= 4 && n - n0 >= 3) {
gemm<5, 3>(m0, m, n0, n);
break;
case 0x35:
mc = 3;
nc = 5;
gemm<3, 5>(m0, m, n0, n);
break;
case 0x43:
mc = 4;
nc = 3;
gemm<4, 3>(m0, m, n0, n);
} else if (n - n0 >= 4) {
mc = 1;
break;
#else
case 0x55:
case 0x54:
case 0x53:
case 0x45:
case 0x44:
case 0x43:
mc = 4;
nc = 3;
gemm<4, 3>(m0, m, n0, n);
break;
case 0x35:
#endif
case 0x34:
mc = 3;
nc = 4;
gemm<1, 4>(m0, m, n0, n);
} else if (m - m0 >= 4) {
gemm<3, 4>(m0, m, n0, n);
break;
case 0x52:
mc = 5;
nc = 2;
gemm<5, 2>(m0, m, n0, n);
break;
case 0x33:
mc = 3;
nc = 3;
gemm<3, 3>(m0, m, n0, n);
break;
case 0x25:
mc = 2;
nc = 5;
gemm<2, 5>(m0, m, n0, n);
break;
case 0x42:
mc = 4;
nc = 2;
gemm<4, 2>(m0, m, n0, n);
break;
case 0x24:
mc = 2;
nc = 4;
gemm<2, 4>(m0, m, n0, n);
break;
case 0x32:
mc = 3;
nc = 2;
gemm<3, 2>(m0, m, n0, n);
break;
case 0x23:
mc = 2;
nc = 3;
gemm<2, 3>(m0, m, n0, n);
break;
case 0x51:
mc = 5;
nc = 1;
gemm<5, 1>(m0, m, n0, n);
break;
case 0x41:
mc = 4;
nc = 1;
gemm<4, 1>(m0, m, n0, n);
} else {
break;
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2>(m0, m, n0, n);
break;
case 0x15:
mc = 1;
nc = 5;
gemm<1, 5>(m0, m, n0, n);
break;
case 0x14:
mc = 1;
nc = 4;
gemm<1, 4>(m0, m, n0, n);
break;
case 0x31:
mc = 3;
nc = 1;
gemm<3, 1>(m0, m, n0, n);
break;
case 0x13:
mc = 1;
nc = 3;
gemm<1, 3>(m0, m, n0, n);
break;
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1>(m0, m, n0, n);
break;
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1>(m0, m, n0, n);
break;
default:
return;
}
mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
mnpack(mp, m, n0, np);
mnpack(m0, mp, np, n);
mnpack(mp, m, np, n);
mnpack(m0, m, np, n);
}

template <int RM, int RN> dontinline void gemm(int m0, int m, int n0, int n) {
Expand Down
101 changes: 92 additions & 9 deletions llamafile/sgemmer0.inc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>

#include "llama.cpp/ggml-impl.h"
#include "llama.cpp/ggml.h"

Expand All @@ -35,32 +37,113 @@ class SGEMMER0 {
}

private:
dontinline void mnpack(int m0, int m, int n0, int n) {
if (m - m0 <= 0 || n - n0 <= 0)
return;
void mnpack(int m0, int m, int n0, int n) {
int mc, nc, mp, np;
if (m - m0 >= 4 && n - n0 >= 3) {
switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) {
#if VECTOR_REGISTERS == 32
case 0x44:
mc = 4;
nc = 4;
gemm<4, 4>(m0, m, n0, n);
break;
case 0x43:
mc = 4;
nc = 3;
gemm<4, 3>(m0, m, n0, n);
} else if (m - m0 >= 4 && n - n0 >= 1) {
break;
case 0x34:
mc = 3;
nc = 4;
gemm<3, 4>(m0, m, n0, n);
break;
case 0x33:
mc = 3;
nc = 3;
gemm<3, 3>(m0, m, n0, n);
break;
case 0x42:
mc = 4;
nc = 2;
gemm<4, 2>(m0, m, n0, n);
break;
case 0x24:
mc = 2;
nc = 4;
gemm<2, 4>(m0, m, n0, n);
break;
#else
case 0x44:
case 0x43:
case 0x42:
mc = 4;
nc = 2;
gemm<4, 2>(m0, m, n0, n);
break;
case 0x34:
case 0x24:
mc = 2;
nc = 4;
gemm<2, 4>(m0, m, n0, n);
break;
case 0x33:
#endif
case 0x32:
mc = 3;
nc = 2;
gemm<3, 2>(m0, m, n0, n);
break;
case 0x23:
mc = 2;
nc = 3;
gemm<2, 3>(m0, m, n0, n);
break;
case 0x41:
mc = 4;
nc = 1;
gemm<4, 1>(m0, m, n0, n);
} else if (m - m0 >= 1 && n - n0 >= 4) {
break;
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2>(m0, m, n0, n);
break;
case 0x14:
mc = 1;
nc = 4;
gemm<1, 4>(m0, m, n0, n);
} else {
break;
case 0x31:
mc = 3;
nc = 1;
gemm<3, 1>(m0, m, n0, n);
break;
case 0x13:
mc = 1;
nc = 3;
gemm<1, 3>(m0, m, n0, n);
break;
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1>(m0, m, n0, n);
break;
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1>(m0, m, n0, n);
break;
default:
return;
}
mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
mnpack(mp, m, n0, np);
mnpack(m0, mp, np, n);
mnpack(mp, m, np, n);
mnpack(m0, m, np, n);
}

template <int RM, int RN> dontinline void gemm(int m0, int m, int n0, int n) {
Expand Down
3 changes: 1 addition & 2 deletions llamafile/sgemmer1.inc
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ class SGEMMER1 {
mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
mnpack(mp, m, n0, np);
mnpack(m0, mp, np, n);
mnpack(mp, m, np, n);
mnpack(m0, m, np, n);
}

dontinline void gemm4x2(int m0, int m, int n0, int n) {
Expand Down

0 comments on commit b52304e

Please sign in to comment.