Skip to content

Commit

Permalink
fix miss input param for _fuse_prepare_qkv (PaddlePaddle#655)
Browse files Browse the repository at this point in the history
  • Loading branch information
GuoxiaWang authored Aug 23, 2022
1 parent 0b73295 commit 865abf4
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion fleetx/models/gpt_model/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(self,
self.out_proj = nn.Linear(
embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)

def _fuse_prepare_qkv(self, query):
def _fuse_prepare_qkv(self, query, use_cache=False, cache=None):
mix_layer = self.qkv_proj(query)
mix_layer = paddle.reshape_(mix_layer,
[0, 0, self.num_heads, 3 * self.head_dim])
Expand Down
2 changes: 1 addition & 1 deletion fleetx/models/gpt_model/modeling_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __init__(self,
input_is_parallel=True,
fuse_matmul_bias=fused_linear)

def _fuse_prepare_qkv(self, query):
def _fuse_prepare_qkv(self, query, use_cache=False, cache=None):
mix_layer = self.qkv_proj(query)
mix_layer = paddle.reshape_(mix_layer,
[0, 0, self.num_heads, 3 * self.head_dim])
Expand Down

0 comments on commit 865abf4

Please sign in to comment.