diff --git a/.circleci/regenerate.py b/.circleci/regenerate.py index f1720979fb..61120a1f85 100755 --- a/.circleci/regenerate.py +++ b/.circleci/regenerate.py @@ -21,7 +21,7 @@ PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"] -CU_VERSIONS_DICT = {"linux": ["cpu", "cu102", "cu111"], +CU_VERSIONS_DICT = {"linux": ["cpu", "cu102", "cu111","rocm4.1"], "windows": ["cpu", "cu102", "cu111"], "macos": ["cpu"]} @@ -124,8 +124,10 @@ def generate_base_workflow(base_workflow_name, python_version, cu_version, filte d['requires'] = ['download_third_parties_nix'] if btype == 'conda': d['conda_docker_image'] = f'pytorch/conda-builder:{cu_version.replace("cu1","cuda1")}' - elif cu_version != 'cpu': + elif cu_version.startswith('cu'): d['wheel_docker_image'] = f'pytorch/manylinux-{cu_version.replace("cu1","cuda1")}' + elif cu_version.startswith('rocm'): + d["wheel_docker_image"] = f"pytorch/manylinux-rocm:{cu_version[len('rocm'):]}" if filter_branch: d["filters"] = gen_filter_branch_tree(filter_branch) diff --git a/packaging/pkg_helpers.bash b/packaging/pkg_helpers.bash index 173aa4ca07..1626938863 100644 --- a/packaging/pkg_helpers.bash +++ b/packaging/pkg_helpers.bash @@ -107,6 +107,10 @@ setup_cuda() { export FORCE_CUDA=1 export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0" ;; + rocm*) + export FORCE_CUDA=1 + export USE_ROCM=1 + ;; cpu) ;; *)