From 46368e4e734938bffe7978cb4077923343095abc Mon Sep 17 00:00:00 2001 From: Tianjian Lu Date: Thu, 3 Nov 2022 21:39:10 -0700 Subject: [PATCH] [sparse] Update the guard of cusparse SpMM and SpMv algorithms to cusparse version 11.7.1 onwards. PiperOrigin-RevId: 486051658 --- jaxlib/gpu/vendor.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 7b821b7777f3..dbc8ce842b4d 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -224,9 +224,9 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; // Use CUSPARSE_SPMV_COO_ALG2 and CUSPARSE_SPMV_CSR_ALG2 for SPMV and // use CUSPARSE_SPMM_COO_ALG2 and CUSPARSE_SPMM_CSR_ALG3 for SPMM, which // provide deterministic (bit-wise) results for each run. These indexing modes -// are available in CUSPARSE 11.4 and newer (which was released as part of -// CUDA 11.2.1) -#if CUSPARSE_VERSION >= 11400 +// are fully supported (both row- and column-major inputs) in CUSPARSE 11.7.1 +// and newer (which was released as part of CUDA 11.8) +#if CUSPARSE_VERSION > 11700 #define GPUSPARSE_SPMV_COO_ALG CUSPARSE_SPMV_COO_ALG2 #define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_SPMV_CSR_ALG2 #define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_COO_ALG2