diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc index 8e86100d793b..ca012271f021 100644 --- a/tests/cpp/operator/mkldnn_operator_test.cc +++ b/tests/cpp/operator/mkldnn_operator_test.cc @@ -163,7 +163,7 @@ OpAttrs GetPoolingOp(int kernel, int dim, int stride, int pad) { OpAttrs attrs; attrs.attrs.op = Op::Get("Pooling"); attrs.num_inputs = 1; - attrs.num_outputs = dim == 2 ? 2 : 1; + attrs.num_outputs = (dim == 2 || dim == 3) ? 2 : 1; attrs.attrs.dict.insert({"kernel" , CreateShapeString(kernel, dim)}); attrs.attrs.dict.insert({"stride" , CreateShapeString(stride, dim)}); attrs.attrs.dict.insert({"pad" , CreateShapeString(pad, dim)}); @@ -175,7 +175,7 @@ OpAttrs GetPoolingOp(int kernel, int dim, int stride, int pad) { OpAttrs GetPoolingBackwardsOp(int kernel, int dim, int stride, int pad) { OpAttrs attrs; attrs.attrs.op = Op::Get("_backward_Pooling"); - attrs.num_inputs = dim == 2 ? 5 : 3; + attrs.num_inputs = (dim == 2 || dim == 3) ? 5 : 3; attrs.num_outputs = 1; attrs.attrs.dict.insert({"kernel", CreateShapeString(kernel, dim)}); attrs.attrs.dict.insert({"stride", CreateShapeString(stride, dim)});