diff --git a/manywheel/build_rocm.sh b/manywheel/build_rocm.sh index 77942898d..4ee9449ef 100755 --- a/manywheel/build_rocm.sh +++ b/manywheel/build_rocm.sh @@ -71,6 +71,7 @@ else echo "Unhandled ROCM_VERSION ${ROCM_VERSION}" exit 1 fi +ROCM_VERSION_WITH_PATCH=rocm${ROCM_VERSION_MAJOR}.${ROCM_VERSION_MINOR}.${ROCM_VERSION_PATCH} ROCM_INT=$(($ROCM_VERSION_MAJOR * 10000 + $ROCM_VERSION_MINOR * 100 + $ROCM_VERSION_PATCH)) # Required ROCm libraries @@ -277,6 +278,25 @@ if [[ $ROCM_INT -ge 50600 ]]; then DEPS_AUX_DSTLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_DST/}) fi +# Add triton install dependency +# No triton dependency for now on 3.12 since we don't have binaries for it +# and torch.compile doesn't work. +PYTORCH_VERSION=$(cat $PYTORCH_ROOT/version.txt | grep -oP "[0-9]+\.[0-9]+\.[0-9]+") +# Assuming PYTORCH_VERSION=x.y.z, if x >= 2 +if [ ${PYTORCH_VERSION%%\.*} -ge 2 ]; then + if [[ $(uname) == "Linux" && "$DESIRED_PYTHON" != "3.12" ]]; then + TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-rocm.txt) + TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt) + + if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then + export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="pytorch-triton-rocm==${TRITON_VERSION}+${ROCM_VERSION_WITH_PATCH}.${TRITON_SHORTHASH}" + else + export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | pytorch-triton-rocm==${TRITON_VERSION}+${ROCM_VERSION_WITH_PATCH}.${TRITON_SHORTHASH}" + fi + fi +fi + + echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH}" SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"