This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[BUGFIX]try avoid the error in operator/tensor/amp_cast.h #20188

Merged 3 commits on Apr 30, 2021
3rdparty/mshadow/mshadow/base.h (3 changes: 2 additions & 1 deletion)

@@ -1364,7 +1364,8 @@ struct minimum {
   default: \
     LOG(FATAL) << "Unknown type enum " << type; \
   }
-
+// amp_cast.h uses this MSHADOW_TYPE_SWITCH_WITH_BOOL in order to
+// avoid the 'Unsupport enum type 12' error.
 #define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, ...) \
   switch (type) { \
     case mshadow::kFloat32: \
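For context, the dispatch pattern behind this macro can be sketched in a few lines. The block below is a minimal, self-contained illustration, not the actual mshadow definition: the enum values are hypothetical and the real macro covers many more types (float64, float16, the integer types, and so on). The point is that the _WITH_BOOL variant has a kBool case, so dispatching on a bool tensor runs the caller's block instead of falling through to the fatal default branch.

#include <cstdio>
#include <cstdlib>

// Illustrative type flags only; mshadow defines the real values in base.h.
enum TypeFlag { kFloat32 = 0, kBool = 7 };

// Sketch of the switch-with-bool pattern: each supported flag typedefs DType
// and runs the caller's block; any other flag is a fatal error.
#define TYPE_SWITCH_WITH_BOOL(type, DType, ...)                     \
  switch (type) {                                                   \
    case kFloat32: { typedef float DType; { __VA_ARGS__ } } break;  \
    case kBool:    { typedef bool  DType; { __VA_ARGS__ } } break;  \
    default:                                                        \
      std::fprintf(stderr, "Unknown type enum %d\n", type);         \
      std::abort();                                                 \
  }

int main() {
  int type_flag = kBool;  // a bool tensor: the case that used to hit the default branch
  TYPE_SWITCH_WITH_BOOL(type_flag, DType, {
    std::printf("dispatched, sizeof(DType) = %zu\n", sizeof(DType));
  });
  return 0;
}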
src/operator/tensor/amp_cast.h (8 changes: 4 additions & 4 deletions)

@@ -133,9 +133,9 @@ void AMPCastCompute(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   using namespace mshadow::expr;
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DstDType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DstDType, {
     Tensor<xpu, 1, DstDType> out = outputs[0].FlatTo1D<xpu, DstDType>(s);
-    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, SrcDType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, SrcDType, {
       Tensor<xpu, 1, SrcDType> data = inputs[0].FlatTo1D<xpu, SrcDType>(s);
       if (outputs[0].type_flag_ != inputs[0].type_flag_ ||
           req[0] != kWriteInplace) {

@@ -155,9 +155,9 @@ void AMPMultiCastCompute(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::expr;
   Stream<xpu> *s = ctx.get_stream<xpu>();
   for (size_t i = 0; i < outputs.size(); ++i) {
-    MSHADOW_TYPE_SWITCH(outputs[i].type_flag_, DstDType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[i].type_flag_, DstDType, {
       Tensor<xpu, 1, DstDType> out = outputs[i].FlatTo1D<xpu, DstDType>(s);
-      MSHADOW_TYPE_SWITCH(inputs[i].type_flag_, SrcDType, {
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[i].type_flag_, SrcDType, {
         Tensor<xpu, 1, SrcDType> data = inputs[i].FlatTo1D<xpu, SrcDType>(s);
         if (outputs[i].type_flag_ != inputs[i].type_flag_ ||
             req[i] != kWriteInplace) {
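To see why the fix touches both switches in each function, here is a hedged, standalone sketch of the nested dispatch that AMPCastCompute performs. It is not the real operator: plain buffers and the same reduced macro as in the sketch above stand in for mshadow tensors, streams, and the full type list. The outer switch fixes DstDType from the output's type flag, the inner one fixes SrcDType from the input's, so a missing kBool case in either switch would abort a bool <-> float32 cast.

#include <cstdio>
#include <cstdlib>

enum TypeFlag { kFloat32 = 0, kBool = 7 };  // illustrative values only

#define TYPE_SWITCH_WITH_BOOL(type, DType, ...)                     \
  switch (type) {                                                   \
    case kFloat32: { typedef float DType; { __VA_ARGS__ } } break;  \
    case kBool:    { typedef bool  DType; { __VA_ARGS__ } } break;  \
    default:                                                        \
      std::fprintf(stderr, "Unknown type enum %d\n", type);         \
      std::abort();                                                 \
  }

int main() {
  bool  src_buf[3] = {true, false, true};
  float dst_buf[3] = {0.f, 0.f, 0.f};
  // Type-erased pointers plus runtime type flags, the way a TBlob carries data.
  const void* src = src_buf;
  void*       dst = dst_buf;
  int src_flag = kBool;      // input is a bool tensor
  int dst_flag = kFloat32;   // output is a float32 tensor

  TYPE_SWITCH_WITH_BOOL(dst_flag, DstDType, {    // outer switch: output type
    TYPE_SWITCH_WITH_BOOL(src_flag, SrcDType, {  // inner switch: input type
      const SrcDType* in  = static_cast<const SrcDType*>(src);
      DstDType*       out = static_cast<DstDType*>(dst);
      for (int i = 0; i < 3; ++i)
        out[i] = static_cast<DstDType>(in[i]);   // element-wise cast
    });
  });

  std::printf("%g %g %g\n", dst_buf[0], dst_buf[1], dst_buf[2]);  // prints: 1 0 1
  return 0;
}

As the if conditions in the diff show, the real operator additionally skips the copy when the input and output types match and the request is kWriteInplace, since the output then aliases the input and no work is needed.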