From 57951e4e386ebda59802b654728e9445d3f43827 Mon Sep 17 00:00:00 2001 From: Cao Dong <87467313+woaixiaoxiao@users.noreply.github.com> Date: Tue, 21 May 2024 06:28:36 +0800 Subject: [PATCH] add_npu_support for goup_norm (#10496) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 原来的实现中只考虑了cuda作为底层实现,而没考虑npu,mlu等其他硬件 --- python/oneflow/nn/modules/normalization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/oneflow/nn/modules/normalization.py b/python/oneflow/nn/modules/normalization.py index 592cd5502dd..ab6fa1158d1 100644 --- a/python/oneflow/nn/modules/normalization.py +++ b/python/oneflow/nn/modules/normalization.py @@ -44,7 +44,7 @@ def group_norm( ), "The channels of input tensor must equal num_channels" affine = weight is not None and bias is not None - if input.is_cuda: + if not input.is_cpu: return flow._C.group_norm(input, weight, bias, affine, num_groups, eps) else: origin_shape = input.shape