[BUG] The three q, k, v MatMuls in the decoder are automatically concatenated into a single unnamed MatMul operator #343

Open
1826133674 opened this issue Aug 28, 2024 · 0 comments


1826133674 commented Aug 28, 2024

Describe the bug
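When using onnxsim to optimize the ONNX model of Qwen1.5-1.8B, we ran into a problem: the three q, k, v MatMul operators in the decoder are concatenated into a single MatMul, whose output is then separated back into three tensors by a Split. Neither the newly created fused MatMul nor the Split operator has a name.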
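The rewrite is numerically valid because multiplying by the column-wise concatenation [W_q | W_k | W_v] and then splitting the result along the last axis gives exactly the three separate products. The pure-Python sketch below compares the two forms on small list-of-lists matrices standing in for the 2048x2048 weights; the function names are hypothetical and chosen only for illustration. It illustrates the equivalence, not onnxsim's actual output: the exact node layout onnxsim emits may differ (for example, in whether the bias Add happens before or after the Split).

```python
def matmul(a, b):
    """Plain-Python product of an (n x k) matrix and a (k x m) matrix."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add_bias(m, bias):
    """Broadcast-add a bias vector to every row of a matrix."""
    return [[value + bias[j] for j, value in enumerate(row)] for row in m]


def separate_qkv(x, weights, biases):
    """Original graph shape: three independent MatMul + Add branches."""
    return [add_bias(matmul(x, w), b) for w, b in zip(weights, biases)]


def fused_qkv(x, weights, biases):
    """Rewritten graph shape: one MatMul on [W_q | W_k | W_v], then a Split."""
    # Concatenate the weights column-wise: row r of the fused matrix is
    # row r of W_q, followed by row r of W_k, followed by row r of W_v.
    fused_w = [sum((w[r] for w in weights), []) for r in range(len(weights[0]))]
    fused_b = sum(biases, [])
    fused_out = add_bias(matmul(x, fused_w), fused_b)

    # Split the last axis back into one chunk per original projection.
    parts, start = [], 0
    for w in weights:
        width = len(w[0])
        parts.append([row[start:start + width] for row in fused_out])
        start += width
    return parts
```

Because each output element is computed from the same input row and the same weight column in both forms, the two functions return identical results; only the graph structure (and therefore the node names) changes.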

Model

Dependency versions
onnxsim 0.4.36
torch 2.3.1

Code that reliably reproduces the issue

```python
import torch
import torch.nn as nn
import torch.onnx


class MultiMatMulAddModel(nn.Module):
    def __init__(self):
        super().__init__()

        # Three 2048x2048 weight matrices and three 2048-dim bias vectors,
        # mirroring the q, k, v projections of the decoder
        self.weights = nn.ParameterList([nn.Parameter(torch.randn(2048, 2048)) for _ in range(3)])
        self.biases = nn.ParameterList([nn.Parameter(torch.randn(2048)) for _ in range(3)])

    def forward(self, x):
        outputs = []
        for i in range(3):
            # Multiply the input by the i-th weight matrix
            matmul_result = torch.matmul(x.squeeze(0), self.weights[i])

            # Add the i-th bias
            add_result = matmul_result + self.biases[i]

            # Collect the result in the output list
            outputs.append(add_result)

        # Return all three outputs
        return tuple(outputs)


# Create the model instance
model = MultiMatMulAddModel()

# Create the input tensor
input_tensor = torch.randn(1, 1024, 2048)

# Output file name
output_file = "multi_matmul_add_model.onnx"

# Export the model to ONNX
torch.onnx.export(model,
                  input_tensor,
                  output_file,
                  export_params=True,        # store the trained parameter weights
                  opset_version=11,          # ONNX opset version
                  do_constant_folding=True,  # apply constant-folding optimization
                  input_names=['input'],     # input node name
                  output_names=['output1', 'output2', 'output3'],  # output node names
                  dynamic_axes={'input': {0: 'batch_size'},  # dynamic axis information
                                'output1': {0: 'batch_size'},
                                'output2': {0: 'batch_size'},
                                'output3': {0: 'batch_size'}})

print(f"Model has been exported to {output_file}")
```
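The script above only performs the export; the unnamed nodes appear once the exported file is passed through onnxsim. As a possible stopgap until the fused nodes are named upstream, the simplified graph could be post-processed to give every node with an empty `name` a deterministic one. The helper below is hypothetical (`name_unnamed_nodes` and the `onnxsim/` prefix are our own choices, not part of onnx or onnxsim); it only assumes the graph exposes a `node` sequence whose items have `name` and `op_type` fields, as `onnx.GraphProto` does.

```python
def name_unnamed_nodes(graph, prefix="onnxsim"):
    """Assign a unique, deterministic name to every node whose name is empty.

    `graph` is assumed to look like an onnx GraphProto: it has a `node`
    sequence whose items expose writable `name` and `op_type` attributes.
    Returns the list of names that were assigned, in graph order.
    """
    # Names already present in the graph must not be reused.
    used = {node.name for node in graph.node if node.name}
    assigned = []
    for index, node in enumerate(graph.node):
        if node.name:
            continue
        # Base the name on the op type and the node's position in the graph.
        candidate = f"{prefix}/{node.op_type}_{index}"
        suffix = 0
        while candidate in used:
            # Extremely unlikely, but never collide with an existing name.
            suffix += 1
            candidate = f"{prefix}/{node.op_type}_{index}_{suffix}"
        node.name = candidate
        used.add(candidate)
        assigned.append(candidate)
    return assigned
```

With the real libraries it would be applied along the lines of `model_simp, check = onnxsim.simplify(onnx.load(output_file))`, then `name_unnamed_nodes(model_simp.graph)` and `onnx.save(model_simp, ...)`. The generated names are stable for a given graph, but they are positional: they do not recover which of q, k, v a fused node came from.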