-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【prim】Matmul double grad composite api #50452
【prim】Matmul double grad composite api #50452
Conversation
…807/Paddle into matmul_double_grad
你的PR提交成功,感谢你对开源项目的贡献! |
for (int64_t i = 0; i < axis_size; i++) { | ||
if (axis[i] < 0) { | ||
axis_[i] = axis[i] + x_dim_size; | ||
std::cout << "axis_[" << i << "] = " << axis[i]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删掉打印相关的代码
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
desired=dx_, | ||
rtol=TOLERANCE[d_type]['rtol'], | ||
atol=TOLERANCE[d_type]['atol'], | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
请添加三阶测试保证拆解后反向的正确性
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
some coments
@@ -78,11 +81,38 @@ static phi::DDim get_reduce_dims(const phi::DDim& x_dims, | |||
return get_reduce_dims_from_out(out_dims, x_dims); | |||
} | |||
|
|||
static std::vector<int> get_reduce_dims(const Tensor& dx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
when matmul used broadcast [1, 4, 3] * [2, 3, 4] =[2, 4, 4] ,x will be broadcast to [2, 4, 3] we need to compute the true dims [1, 4, 3] reduce dims will be [0]
@param.parameterized_class( | ||
('primal0', 'primal1', 'primal2', 'trans_0', 'trans_1', 'dtype'), | ||
[ | ||
# ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add to do here to uncomment it when we fix static mode
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -93,7 +93,7 @@ class StaticCompositeContext { | |||
StaticCompositeContext() | |||
: current_block_desc_(nullptr), | |||
generator_(new UniqueNameGenerator()), | |||
skip_comp_ops_({"matmul_v2"}) {} | |||
skip_comp_ops_({"matmul_v2", "matmul_v2_grad"}) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
leave todo here to remove this when we fix static
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for run_unittests.sh
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
8cd6842
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
Others
Describe
matmul 二阶组合算子实现