diff --git a/oneflow/api/python/framework/autocast.cpp b/oneflow/api/python/framework/autocast.cpp index 506718e6dd8..f47e5fb91cc 100644 --- a/oneflow/api/python/framework/autocast.cpp +++ b/oneflow/api/python/framework/autocast.cpp @@ -16,7 +16,9 @@ limitations under the License. #include #include "oneflow/api/python/of_api_registry.h" +#include "oneflow/core/common/device_type.pb.h" #include "oneflow/core/common/throw.h" +#include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/framework/autocast.h" namespace py = pybind11; @@ -36,7 +38,7 @@ class AutoCastMode { public: OF_DISALLOW_COPY_AND_MOVE(AutoCastMode); - AutoCastMode(const std::string& device_type, Symbol dtype, bool enabled, + AutoCastMode(const std::string& device_name, Symbol dtype, bool enabled, bool cache_enabled) : prev_enabled_(autocast::is_enabled()), prev_cache_enabled_(autocast::is_autocast_cache_enabled()), @@ -48,16 +50,23 @@ class AutoCastMode { increase_nested_count(); autocast::set_enabled(enabled); autocast::set_autocast_cache_enabled(cache_enabled); - if (device_type == "cpu") { - autocast::set_autocast_device_type(kCPU); - autocast::set_autocast_dtype(dtype); - autocast::set_autocast_cpu_dtype(dtype); - } else if (device_type == "cuda") { - autocast::set_autocast_device_type(kCUDA); - autocast::set_autocast_dtype(dtype); - autocast::set_autocast_gpu_dtype(dtype); - } else { - THROW(RuntimeError) << "User specified autocast device_type must be 'cuda' or 'cpu'"; + auto device_type = ep::DeviceManagerRegistry::GetDeviceTypeByDeviceTypeName(device_name); + switch (device_type) { + case kCPU: + autocast::set_autocast_device_type(device_type); + autocast::set_autocast_dtype(dtype); + autocast::set_autocast_cpu_dtype(dtype); + break; + case kCUDA: + case kMLU: + case kNPU: + autocast::set_autocast_device_type(device_type); + autocast::set_autocast_dtype(dtype); + autocast::set_autocast_gpu_dtype(dtype); + break; + default: + THROW(RuntimeError) + << "User specified autocast device_type must be 'cuda' or 'cpu' or 'mlu' or 'npu'"; } } diff --git a/oneflow/core/framework/autocast.cpp b/oneflow/core/framework/autocast.cpp index f3ce8320183..1817467d843 100644 --- a/oneflow/core/framework/autocast.cpp +++ b/oneflow/core/framework/autocast.cpp @@ -181,7 +181,7 @@ std::shared_ptr MakeAutoCastMeta( // autocast only supports the following device type(s) and low precision type(s): // - device type: CUDA // - low precision type: half, bfloat16 - static std::vector autocast_device_types{kCUDA}; + static std::vector autocast_device_types{kCUDA, kMLU, kNPU}; static std::vector> autocast_dtypes{DType::Float16(), DType::BFloat16()}; if (autocast_meta->autocast_color() != kBlack) { diff --git a/oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp b/oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp index 0a3a64b330b..8fca38bd681 100644 --- a/oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp +++ b/oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp @@ -20,7 +20,9 @@ namespace oneflow { const AMPList& AutoMixedPrecisionLists::WhiteList() { static AMPList white_list = {"matmul", "batch_matmul", + "conv1d", "conv2d", + "conv3d", "conv_data_grad", "conv_filter_grad", "conv_bias_grad", @@ -137,7 +139,8 @@ const AMPList& AutoMixedPrecisionLists::GrayList() { "group_norm_grad", "silu", "silu_grad", - "fused_weighted_sum"}; + "fused_weighted_sum", + "cast"}; return gray_list; } diff --git a/python/oneflow/amp/autocast_mode.py b/python/oneflow/amp/autocast_mode.py index 3fbaf429566..a23be0ba5a0 100644 --- a/python/oneflow/amp/autocast_mode.py +++ b/python/oneflow/amp/autocast_mode.py @@ -153,10 +153,10 @@ def __init__( cache_enabled: Optional[bool] = None, ): self.device = device_type - if self.device == "cuda": - self.fast_dtype = flow.get_autocast_gpu_dtype() - elif self.device == "cpu": + if self.device == "cpu": self.fast_dtype = flow.get_autocast_cpu_dtype() + elif self.device in ["cuda", "mlu", "npu"]: + self.fast_dtype = flow.get_autocast_gpu_dtype() else: raise RuntimeError( "User specified autocast device_type must be 'cuda' or 'cpu'"