diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 7b958eb89ebf4..2f7ea4edfe4c2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -51,7 +51,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m t.join(); } - if (tensor->type == GGML_TYPE_F32) { + if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) { ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float)); } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) { GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0); @@ -233,6 +233,10 @@ static bool ggml_is_view_op(enum ggml_op op) { struct test_case { virtual ~test_case() {} + virtual std::string op_desc(ggml_tensor * t) { + return ggml_op_desc(t); + } + virtual std::string vars() { return ""; } @@ -240,7 +244,7 @@ struct test_case { virtual ggml_tensor * build_graph(ggml_context * ctx) = 0; virtual double max_nmse_err() { - return 1e-6; + return 1e-7; } virtual void initialize_tensors(ggml_context * ctx) { @@ -270,13 +274,13 @@ struct test_case { ggml_tensor * out = build_graph(ctx); - if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) { - //printf(" %s: skipping\n", ggml_op_desc(out)); + if (op_name != nullptr && op_desc(out) != op_name) { + //printf(" %s: skipping\n", op_desc(out).c_str()); ggml_free(ctx); return true; } - printf(" %s(%s): ", ggml_op_desc(out), vars().c_str()); + printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str()); fflush(stdout); // check if backends support op @@ -317,7 +321,7 @@ struct test_case { for (size_t i = 0; i < f1.size(); i++) { // check for nans if (std::isnan(f1[i]) || std::isnan(f2[i])) { - printf("NaN at index %zu ", i); + printf("[%s] NaN at index %zu ", ggml_op_desc(t1), i); ud->ok = false; return true; } @@ -325,21 +329,32 @@ struct test_case { if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) { if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) { if (std::signbit(f1[i]) != std::signbit(f2[i])) { - printf("inf sign mismatch: %f %f ", f1[i], f2[i]); + printf("[%s] inf sign mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]); ud->ok = false; return true; } } else { - printf("inf mismatch: %f %f ", f1[i], f2[i]); + printf("[%s] inf mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]); ud->ok = false; return true; } } } + //if (t1->op == GGML_OP_SOFT_MAX) { + // printf("[%s] ", ggml_op_desc(t1)); + // for (int i = 0; i < f1.size(); i++) { + // printf("(%x, %x) ", *(uint32_t*)&f1[i], *(uint32_t*)&f2[i]); + // } + // printf("\n"); + //} double err = nmse(f1.data(), f2.data(), f1.size()); if (err > ud->max_err) { - printf("NMSE = %f ", err); + printf("[%s] NMSE = %f ", ggml_op_desc(t1), err); + //for (int i = 0; i < f1.size(); i++) { + // printf("(%f, %f) ", f1[i], f2[i]); + //} + //printf("\n"); ud->ok = false; } return true; @@ -374,13 +389,13 @@ struct test_case { ggml_tensor * out = build_graph(ctx); - if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) { - //printf(" %s: skipping\n", ggml_op_desc(out)); + if (op_name != nullptr && op_desc(out) != op_name) { + //printf(" %s: skipping\n", op_desc(out).c_str()); ggml_free(ctx); return true; } - int len = printf(" %s(%s): ", ggml_op_desc(out), vars().c_str()); + int len = printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str()); fflush(stdout); // check if backends support op @@ -1122,6 +1137,91 @@ struct test_sum_rows : public test_case { } }; +struct test_moe : public test_case { + const int n_experts = 8; + const int n_experts_per_tok = 2; + const int n_tokens = 1; + const int n_embd = 4096; + const int n_ff = 14336; + + std::string op_desc(ggml_tensor * t) override { + return "MOE"; + GGML_UNUSED(t); + } + + std::string vars() override { + return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff); + } + + test_moe() { + } + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts); + + std::vector ffn_up_exp(n_experts); + std::vector ffn_gate_exp(n_experts); + std::vector ffn_down_exp(n_experts); + + for (int i = 0; i < n_experts; ++i) { + ffn_up_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); + ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); + ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd); + } + + ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens); + + ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur); // [n_tokens, num_experts] + ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_tokens, num_experts] + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok] + + ggml_tensor * weights = ggml_get_rows(ctx, + ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts); + printf("get rows args %ld %ld %ld %ld, %ld %ld %ld %ld\n", + weights->src[0]->ne[0], weights->src[0]->ne[1], weights->src[0]->ne[2], weights->src[0]->ne[3], + weights->src[1]->ne[0], weights->src[1]->ne[1], weights->src[1]->ne[2], weights->src[1]->ne[3]); + + + weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok] + + ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); + + weights = ggml_div(ctx, weights, weights_sum); // [n_tokens, num_experts_per_tok] + + // compute expert outputs + ggml_tensor * moe_out = nullptr; + + for (int i = 0; i < n_experts_per_tok; ++i) { + ggml_tensor * cur_expert; + + ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur); + + ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur); + + cur_gate = ggml_silu(ctx, cur_gate); + + cur_expert = ggml_mul(ctx, cur_up, cur_gate); // [n_tokens, n_embd] + + cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert); // [n_tokens, n_embd] + + cur_expert = ggml_mul(ctx, cur_expert, + ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx, moe_out, cur_expert); + } + } + + cur = moe_out; + + return cur; + } +}; + enum test_mode { MODE_TEST, MODE_PERF, @@ -1140,11 +1240,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op GGML_TYPE_Q6_K }; + test_cases.emplace_back(new test_moe()); + // unary ops for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) { test_cases.emplace_back(new test_unary((ggml_unary_op) op)); } + test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false)); for (ggml_type type : all_types) { for (int b : {1, 7}) { for (bool v : {false, true}) { @@ -1265,6 +1368,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_concat()); for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) { + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); }