diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 7b821b7777f3..dbc8ce842b4d 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -224,9 +224,9 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; // Use CUSPARSE_SPMV_COO_ALG2 and CUSPARSE_SPMV_CSR_ALG2 for SPMV and // use CUSPARSE_SPMM_COO_ALG2 and CUSPARSE_SPMM_CSR_ALG3 for SPMM, which // provide deterministic (bit-wise) results for each run. These indexing modes -// are available in CUSPARSE 11.4 and newer (which was released as part of -// CUDA 11.2.1) -#if CUSPARSE_VERSION >= 11400 +// are fully supported (both row- and column-major inputs) in CUSPARSE 11.7.1 +// and newer (which was released as part of CUDA 11.8) +#if CUSPARSE_VERSION > 11700 #define GPUSPARSE_SPMV_COO_ALG CUSPARSE_SPMV_COO_ALG2 #define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_SPMV_CSR_ALG2 #define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_COO_ALG2