diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 000000000..c7808e534
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,126 @@
+# This CMake config hopefully makes it easier to compile.
+# Ensure the CUDA Toolkit is available on your path. Then run:
+#   For GCC:  `cmake -B build . && cmake --build build`
+#   For MSVC: `cmake -B build . && cmake --build build --config Release`
+# You can also use the following options:
+#   - BUILD_CUDA: Default ON, will build with CUDA
+#   - NO_CUBLASLT: Default OFF, will skip building/linking cuBLASLt support
+#   - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
+#     is whatever CMake finds on your path.
+#   - COMPUTE_CAPABILITY: Which GPU arch/compute codes to provide to NVCC.
+#     Separate by semicolons, e.g. `-DCOMPUTE_CAPABILITY=89;90`
+#     Check your compute capability here: https://developer.nvidia.com/cuda-gpus
+#   - PTXAS_VERBOSE: Pass the `-v` option to the PTX assembler
+cmake_minimum_required(VERSION 3.18)
+
+project(bitsandbytes LANGUAGES C CXX)
+
+option(BUILD_CUDA "Build bitsandbytes with CUDA support" ON)
+option(NO_CUBLASLT "Disable CUBLASLT" OFF)
+option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)
+
+list(APPEND SRC_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
+list(APPEND CUDA_FILES csrc/ops.cu csrc/kernels.cu)
+
+message(STATUS "BUILD_CUDA := ${BUILD_CUDA}")
+message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
+
+set(BNB_OUTPUT_NAME "libbitsandbytes")
+
+if(BUILD_CUDA)
+    enable_language(CUDA) # This will fail if CUDA is not found
+
+    # Convert the CUDA version from X.Y.z to XY, e.g. 11.7.99 -> 117.
+    # There's probably a shorter way of doing this.
+    string(REGEX MATCH "^[0-9]+\\.[0-9]+" _CUDA_VERSION_FIRST_TWO "${CMAKE_CUDA_COMPILER_VERSION}")
+    string(REPLACE "." "" CUDA_VERSION_SHORT "${_CUDA_VERSION_FIRST_TWO}")
+
+    # Expose a cache variable that the user can set to ensure the correct version of CUDA is found
+    set(CUDA_VERSION "${CUDA_VERSION_SHORT}" CACHE STRING "Expected CUDA Version Shortcode")
+
+    message(STATUS "CUDA Version: ${CUDA_VERSION_SHORT} (${CMAKE_CUDA_COMPILER_VERSION})")
+    message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}")
+
+    # It should match the discovered version
+    if(NOT CUDA_VERSION STREQUAL "${CUDA_VERSION_SHORT}")
+        message(FATAL_ERROR "You've specified CUDA version ${CUDA_VERSION}, but the CUDA compiler found is ${CUDA_VERSION_SHORT}."
+            " Ensure the desired CUDA compiler is the first one available on your PATH."
+        )
+    endif()
+
+    if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS "11.0")
+        message(FATAL_ERROR "CUDA Version < 11 is not supported")
+    elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0")
+        message(FATAL_ERROR "CUDA Version > 12 is not supported")
+    endif()
+
+    string(APPEND CMAKE_CUDA_FLAGS " --use_fast_math")
+    if(PTXAS_VERBOSE)
+        # Verbose: outputs per-kernel register usage and other resource information
+        string(APPEND CMAKE_CUDA_FLAGS " -Xptxas=-v")
+    endif()
+
+    foreach(capability ${CMAKE_CUDA_ARCHITECTURES_ALL})
+        # Most of the items here are like `xx-real`, so we just extract the `xx` portion
+        string(REGEX MATCH "[0-9]+" capability_id "${capability}")
+        if(capability_id GREATER 0)
+            list(APPEND POSSIBLE_CAPABILITIES ${capability_id})
+        endif()
+    endforeach()
+
+    # This can be changed via -D argument to CMake
+    # By default all possible capabilities are compiled
+    set(COMPUTE_CAPABILITY "${POSSIBLE_CAPABILITIES}" CACHE STRING "Compute Capabilities Targeted")
+
+    message(STATUS "CUDA Capabilities Available: ${POSSIBLE_CAPABILITIES}")
+    message(STATUS "CUDA Capabilities Selected: ${COMPUTE_CAPABILITY}")
+
+    foreach(capability ${COMPUTE_CAPABILITY})
+        string(APPEND CMAKE_CUDA_FLAGS " -gencode arch=compute_${capability},code=sm_${capability}")
+    endforeach()
+
+    message(STATUS "CUDA NVCC Flags: ${CMAKE_CUDA_FLAGS}")
+
+    list(APPEND SRC_FILES ${CUDA_FILES})
+
+    string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
+    if(NO_CUBLASLT)
+        string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
+    endif()
+else()
+    message(STATUS "Building CPU Only")
+    string(APPEND BNB_OUTPUT_NAME "_cpu")
+    if(NO_CUBLASLT)
+        message(WARNING "We're building in CPU only mode but NO_CUBLASLT is enabled. It will have no effect.")
+    endif()
+endif()
+
+add_library(libbitsandbytes SHARED ${SRC_FILES})
+target_include_directories(libbitsandbytes PUBLIC csrc include)
+target_compile_features(libbitsandbytes PUBLIC cxx_std_14)
+
+
+if(BUILD_CUDA)
+    target_compile_definitions(libbitsandbytes PUBLIC BUILD_CUDA)
+    target_link_libraries(libbitsandbytes PUBLIC cudart cublas cusparse)
+    if(NO_CUBLASLT)
+        target_compile_definitions(libbitsandbytes PUBLIC NO_CUBLASLT)
+    else()
+        target_link_libraries(libbitsandbytes PUBLIC cublasLt)
+    endif()
+
+    set_target_properties(libbitsandbytes
+        PROPERTIES
+            CUDA_SEPARABLE_COMPILATION ON
+    )
+endif()
+
+set_target_properties(libbitsandbytes
+    PROPERTIES
+        OUTPUT_NAME ${BNB_OUTPUT_NAME}
+        # We have to use a generator expression to prevent MSVC Debug/Release subdirs being made
+        RUNTIME_OUTPUT_DIRECTORY "$<1:${CMAKE_SOURCE_DIR}/bitsandbytes>"
+        # .dll outputs obey RUNTIME_OUTPUT_DIRECTORY, but .so outputs obey LIBRARY_OUTPUT_DIRECTORY
+        LIBRARY_OUTPUT_DIRECTORY "$<1:${CMAKE_SOURCE_DIR}/bitsandbytes>"
+        POSITION_INDEPENDENT_CODE ON  # The `-fPIC` option for non-Windows compilers
+        WINDOWS_EXPORT_ALL_SYMBOLS ON # On Windows, export all C functions as DLL exports
+)
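For reference, the version short-code mapping performed by the two `string()` calls above amounts to the following Python sketch (illustrative only, not part of the patch):

```python
import re

def cuda_version_shortcode(compiler_version: str) -> str:
    # "11.7.99" -> "11.7" -> "117"; this suffix ends up in the
    # library name, e.g. libbitsandbytes_cuda117.
    first_two = re.match(r"[0-9]+\.[0-9]+", compiler_version).group(0)
    return first_two.replace(".", "")

assert cuda_version_shortcode("11.7.99") == "117"
assert cuda_version_shortcode("12.0.140") == "120"
```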
diff --git a/README.md b/README.md
index 727a86cb5..6e445b1f9 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,8 @@ CUDA_VERSION=117 make cuda11x
 python setup.py install
 ```
 
+On Windows you *must* compile it from source. See [compile_from_source](./compile_from_source.md).
+
 **Using Int8 inference with HuggingFace Transformers**
 
 ```python
diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py
index f3edf4c73..04c21796e 100644
--- a/bitsandbytes/cuda_setup/main.py
+++ b/bitsandbytes/cuda_setup/main.py
@@ -116,8 +116,16 @@ def manual_override(self):
     def run_cuda_setup(self):
         self.initialized = True
         self.cuda_setup_log = []
-
+
+        package_dir = Path(__file__).parent.parent
         binary_name, cudart_path, cc, cuda_version_string = evaluate_cuda_setup()
+        # Find the platform-specific suffix (.so or .dll) by checking which file exists
+        for suffix in (".so", ".dll"):
+            binary_path = package_dir / f"{binary_name}{suffix}"
+            if binary_path.exists():
+                binary_name = f"{binary_name}{suffix}"
+                break
+
         self.cudart_path = cudart_path
         self.cuda_available = torch.cuda.is_available()
         self.cc = cc
@@ -125,7 +133,6 @@ def run_cuda_setup(self):
         self.binary_name = binary_name
         self.manual_override()
 
-        package_dir = Path(__file__).parent.parent
         binary_path = package_dir / self.binary_name
 
         try:
@@ -150,10 +157,10 @@ def run_cuda_setup(self):
                 self.add_log_entry('')
                 self.generate_instructions()
                 raise Exception('CUDA SETUP: Setup Failed!')
-            self.lib = ct.cdll.LoadLibrary(binary_path)
+            self.lib = ct.cdll.LoadLibrary(str(binary_path))
         else:
-            self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
-            self.lib = ct.cdll.LoadLibrary(binary_path)
+            self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path!s}...")
+            self.lib = ct.cdll.LoadLibrary(str(binary_path))
 
         except Exception as ex:
             self.add_log_entry(str(ex))
@@ -332,7 +339,7 @@ def evaluate_cuda_setup():
     cuda_setup.add_log_entry(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'),
                              ('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues'))
     cuda_setup.add_log_entry('='*80)
-    if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None
+    if not torch.cuda.is_available(): return 'libbitsandbytes_cpu', None, None, None
 
     cudart_path = determine_cuda_runtime_lib_path()
     ccs = get_compute_capabilities()
@@ -356,9 +363,9 @@ def evaluate_cuda_setup():
     # since most installations will have the libcudart.so installed, but not the compiler
 
     if has_cublaslt:
-        binary_name = f"libbitsandbytes_cuda{cuda_version_string}.so"
+        binary_name = f"libbitsandbytes_cuda{cuda_version_string}"
     else:
-        "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
-        binary_name = f"libbitsandbytes_cuda{cuda_version_string}_nocublaslt.so"
+        # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
+        binary_name = f"libbitsandbytes_cuda{cuda_version_string}_nocublaslt"
 
     return binary_name, cudart_path, cc, cuda_version_string
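The suffix probing added to `run_cuda_setup` boils down to this standalone sketch (hypothetical helper name, purely illustrative):

```python
from pathlib import Path

def resolve_binary_name(package_dir: Path, binary_name: str) -> str:
    # Try the platform-specific shared-library suffixes in order and keep
    # the first one that exists on disk; otherwise return the bare name
    # so the caller can report a useful error.
    for suffix in (".so", ".dll"):
        if (package_dir / f"{binary_name}{suffix}").exists():
            return f"{binary_name}{suffix}"
    return binary_name

# e.g. resolve_binary_name(Path("bitsandbytes"), "libbitsandbytes_cuda117")
# would yield "libbitsandbytes_cuda117.dll" on a Windows build.
```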
diff --git a/compile_from_source.md b/compile_from_source.md
index f5de4db74..56e6d71cc 100644
--- a/compile_from_source.md
+++ b/compile_from_source.md
@@ -1,6 +1,6 @@
 # Compiling from source
 
-Basic steps.
+Basic steps for Unix (see Windows steps below):
 1. `CUDA_VERSION=XXX make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cuda12x, cpuonly`
 2. `python setup.py install`
 
@@ -38,3 +38,21 @@ If you have problems compiling the library with these instructions from source,
 Since 0.39.1 bitsandbytes installed via pip no longer provides Kepler binaries and these need to be compiled from source. Follow the steps above and instead of `cuda11x_nomatmul` etc use `cuda11x_nomatmul_kepler`
 
+# Compilation on Windows
+
+We'll use CMake to do all the heavy lifting for us here; CUDA and the MSVC compiler can be finicky.
+
+- Install [Microsoft Visual Studio](https://visualstudio.microsoft.com/)
+- Install the CUDA Toolkit that matches your PyTorch CUDA version
+  - This will install `CUDA xx.y.props` files to `BuildCustomizations` (see some documentation [here](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html#sample-projects))
+  - e.g. for Visual Studio 2022 and CUDA 11.7, there should be some `CUDA 11.7...` files in here: `C:\Program Files\Microsoft Visual Studio\2022\Professional\MSBuild\Microsoft\VC\v170\BuildCustomizations`
+- Install CMake, at least 3.18 (the latest version is usually fine)
+- [Optional] Look up your GPU's [CUDA Compute Capability](https://developer.nvidia.com/cuda-gpus)
+  - If you don't do this, optimized code is compiled for all possible compute capabilities, which takes much longer
+  - Insert it into the command below (e.g. `8.6` -> `86`)
+- Configure the CMake project:
+  - `cmake -B build . "-DCOMPUTE_CAPABILITY=86"`
+- Build the project:
+  - `cmake --build build --config Release`
+- Install bitsandbytes:
+  - `pip install .`
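A quick way to check which CUDA version your PyTorch build expects, so the toolkit you install matches it:

```python
import torch

# Prints e.g. "11.7" for a cu117 build of PyTorch; install that toolkit.
print(torch.version.cuda)
```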
diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp
index e28e7b2c2..18795f87d 100644
--- a/csrc/cpu_ops.cpp
+++ b/csrc/cpu_ops.cpp
@@ -1,5 +1,9 @@
 #include <BinSearch.h>
+#ifdef _WIN32
+#include <thread>
+#else
 #include <pthread.h>
+#endif
 #include <common.h>
 
 using namespace BinSearch;
@@ -31,7 +35,11 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
     for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
     {
         long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
-        pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
+#ifdef _WIN32
+        std::thread *threads = (std::thread *) malloc(sizeof(std::thread) * valid_chunks);
+#else
+        pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
+#endif
 
         struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *));
 
@@ -55,14 +63,23 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
             arg->threadidx = block_idx / blocksize;
             arg->blocksize = blocksize;
 
+#ifdef _WIN32
+            new (&threads[chunks_processed]) std::thread(quantize_block, arg);
+#else
             pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
+#endif
             chunks_processed += 1;
             if(chunks_processed == valid_chunks){ break; }
         }
 
         for (int i = 0; i < valid_chunks; i++)
+        {
+#ifdef _WIN32
+            threads[i].join();
+#else
             int err = pthread_join(threads[i], NULL);
-
+#endif
+        }
         free(threads);
         for (int i = 0; i < valid_chunks; i++) free(args[i]);
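The Windows branch above swaps `pthread_create`/`pthread_join` for `std::thread`; the placement-`new` is needed because a `std::thread` must be constructed, not just `malloc`'d. The control flow itself is a simple wave pattern — start up to `thread_wave_size` workers, join them all, repeat — which this Python sketch mirrors (illustrative only, with a dummy worker):

```python
import threading

def quantize_block(arg):
    pass  # stand-in for the real per-block quantization worker

def run_in_waves(all_args, thread_wave_size=256):
    for offset in range(0, len(all_args), thread_wave_size):
        wave = [threading.Thread(target=quantize_block, args=(a,))
                for a in all_args[offset:offset + thread_wave_size]]
        for t in wave:
            t.start()
        for t in wave:
            t.join()  # the std::thread::join / pthread_join step

run_in_waves(list(range(1000)))
```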
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 1ab8aa242..2e5ddb12c 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -3816,12 +3816,12 @@ template __global__ void kgemm_4bit_inference_naive(int M, int N
 template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 
-template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
 
 template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
 template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
@@ -3847,6 +3847,9 @@ template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
                 const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
 
+MAKE_Optimizer32bit1State(ADAM, half)
+MAKE_Optimizer32bit1State(ADAM, float)
+MAKE_Optimizer32bit1State(ADAM, __nv_bfloat16)
 MAKE_Optimizer32bit1State(MOMENTUM, half)
 MAKE_Optimizer32bit1State(MOMENTUM, float)
 MAKE_Optimizer32bit1State(RMSPROP, half)
 MAKE_Optimizer32bit1State(RMSPROP, float)
@@ -3880,14 +3886,45 @@ template __global__ void kPreconditionOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
-                const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
-                const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+MAKE_PreconditionOptimizer32bit2State(MOMENTUM, float)
+MAKE_PreconditionOptimizer32bit2State(MOMENTUM, half)
+MAKE_PreconditionOptimizer32bit2State(MOMENTUM, __nv_bfloat16)
+MAKE_PreconditionOptimizer32bit2State(RMSPROP, float)
+MAKE_PreconditionOptimizer32bit2State(RMSPROP, half)
+MAKE_PreconditionOptimizer32bit2State(RMSPROP, __nv_bfloat16)
+MAKE_PreconditionOptimizer32bit2State(LARS, float)
+MAKE_PreconditionOptimizer32bit2State(LARS, half)
+MAKE_PreconditionOptimizer32bit2State(LARS, __nv_bfloat16)
+MAKE_PreconditionOptimizer32bit2State(ADAGRAD, float)
+MAKE_PreconditionOptimizer32bit2State(ADAGRAD, half)
+MAKE_PreconditionOptimizer32bit2State(ADAGRAD, __nv_bfloat16)
+MAKE_PreconditionOptimizer32bit2State(LION, float)
+MAKE_PreconditionOptimizer32bit2State(LION, half)
+MAKE_PreconditionOptimizer32bit2State(LION, __nv_bfloat16)
+
+#define MAKE_Optimizer32bit2State(oname, gtype) \
+template __global__ void kOptimizer32bit2State<gtype, oname>(gtype* g, gtype* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, \
                 const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 
+MAKE_Optimizer32bit2State(ADAM, float)
+MAKE_Optimizer32bit2State(ADAM, half)
+MAKE_Optimizer32bit2State(ADAM, __nv_bfloat16)
+MAKE_Optimizer32bit2State(MOMENTUM, float)
+MAKE_Optimizer32bit2State(MOMENTUM, half)
+MAKE_Optimizer32bit2State(MOMENTUM, __nv_bfloat16)
+MAKE_Optimizer32bit2State(RMSPROP, float)
+MAKE_Optimizer32bit2State(RMSPROP, half)
+MAKE_Optimizer32bit2State(RMSPROP, __nv_bfloat16)
+MAKE_Optimizer32bit2State(LARS, float)
+MAKE_Optimizer32bit2State(LARS, half)
+MAKE_Optimizer32bit2State(LARS, __nv_bfloat16)
+MAKE_Optimizer32bit2State(ADAGRAD, float)
+MAKE_Optimizer32bit2State(ADAGRAD, half)
+MAKE_Optimizer32bit2State(ADAGRAD, __nv_bfloat16)
+MAKE_Optimizer32bit2State(LION, float)
+MAKE_Optimizer32bit2State(LION, half)
+MAKE_Optimizer32bit2State(LION, __nv_bfloat16)
+
 #define MAKE_PreconditionStatic8bit1State(oname, gtype) \
 template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
                 float *unorm, \
@@ -3900,10 +3937,16 @@ template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gt
                 const float gnorm_scale, \
                 const int n); \
 
+MAKE_PreconditionStatic8bit1State(ADAM, half)
+MAKE_PreconditionStatic8bit1State(ADAM, float)
 MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
 MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
 MAKE_PreconditionStatic8bit1State(RMSPROP, half)
 MAKE_PreconditionStatic8bit1State(RMSPROP, float)
+MAKE_PreconditionStatic8bit1State(LARS, half)
+MAKE_PreconditionStatic8bit1State(LARS, float)
+MAKE_PreconditionStatic8bit1State(ADAGRAD, half)
+MAKE_PreconditionStatic8bit1State(ADAGRAD, float)
 MAKE_PreconditionStatic8bit1State(LION, half)
 MAKE_PreconditionStatic8bit1State(LION, float)
 
@@ -3919,10 +3962,16 @@ template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtyp
                 const float gnorm_scale, \
                 const int n); \
 
+MAKE_optimizerStatic8bit1State(ADAM, half)
+MAKE_optimizerStatic8bit1State(ADAM, float)
 MAKE_optimizerStatic8bit1State(MOMENTUM, half)
 MAKE_optimizerStatic8bit1State(MOMENTUM, float)
 MAKE_optimizerStatic8bit1State(RMSPROP, half)
 MAKE_optimizerStatic8bit1State(RMSPROP, float)
+MAKE_optimizerStatic8bit1State(LARS, half)
+MAKE_optimizerStatic8bit1State(LARS, float)
+MAKE_optimizerStatic8bit1State(ADAGRAD, half)
+MAKE_optimizerStatic8bit1State(ADAGRAD, float)
 MAKE_optimizerStatic8bit1State(LION, half)
 MAKE_optimizerStatic8bit1State(LION, float)
 
@@ -3938,6 +3987,16 @@ template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gt
 MAKE_PreconditionStatic8bit2State(ADAM, half)
 MAKE_PreconditionStatic8bit2State(ADAM, float)
+MAKE_PreconditionStatic8bit2State(MOMENTUM, half)
+MAKE_PreconditionStatic8bit2State(MOMENTUM, float)
+MAKE_PreconditionStatic8bit2State(RMSPROP, half)
+MAKE_PreconditionStatic8bit2State(RMSPROP, float)
+MAKE_PreconditionStatic8bit2State(LARS, half)
+MAKE_PreconditionStatic8bit2State(LARS, float)
+MAKE_PreconditionStatic8bit2State(ADAGRAD, half)
+MAKE_PreconditionStatic8bit2State(ADAGRAD, float)
+MAKE_PreconditionStatic8bit2State(LION, half)
+MAKE_PreconditionStatic8bit2State(LION, float)
 
 #define MAKE_optimizerStatic8bit2State(oname, gtype) \
 template __global__ void kOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \
@@ -3952,6 +4011,16 @@ template __global__ void kOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtyp
 MAKE_optimizerStatic8bit2State(ADAM, half)
 MAKE_optimizerStatic8bit2State(ADAM, float)
+MAKE_optimizerStatic8bit2State(MOMENTUM, half)
+MAKE_optimizerStatic8bit2State(MOMENTUM, float)
+MAKE_optimizerStatic8bit2State(RMSPROP, half)
+MAKE_optimizerStatic8bit2State(RMSPROP, float)
+MAKE_optimizerStatic8bit2State(LARS, half)
+MAKE_optimizerStatic8bit2State(LARS, float)
+MAKE_optimizerStatic8bit2State(ADAGRAD, half)
+MAKE_optimizerStatic8bit2State(ADAGRAD, float)
+MAKE_optimizerStatic8bit2State(LION, half)
+MAKE_optimizerStatic8bit2State(LION, float)
 
 template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
 template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
@@ -4049,6 +4118,21 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise
 #include
-#include
 #include
 #include
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.cpp
similarity index 99%
rename from csrc/pythonInterface.c
rename to csrc/pythonInterface.cpp
index 865e4b6d5..babccc9cc 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.cpp
@@ -45,7 +45,6 @@ MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
 MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
 MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
 
-
 #define MAKE_FUNC32(fname, oname, gtype, gbits) \
 void fname##32bit_grad_##gbits(gtype *g, gtype *p, \
                float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
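Renaming `pythonInterface.c` to `pythonInterface.cpp` lets MSVC compile it as C++, and together with `WINDOWS_EXPORT_ALL_SYMBOLS` in the CMake config the `extern "C"` wrappers stay visible to `ctypes`. A hedged sketch of what loading the result looks like from Python (both the library file name and the symbol name depend on your build and version; treat them as illustrative):

```python
import ctypes
from pathlib import Path

# A CUDA 11.7 Windows build would produce this name; adjust for your build.
lib = ctypes.cdll.LoadLibrary(str(Path("bitsandbytes") / "libbitsandbytes_cuda117.dll"))

# MAKE_FUNC32 generates names like fname##32bit_grad_##gbits, e.g.
# "adam32bit_grad_fp32"; the exported C wrappers are what ctypes resolves.
print(hasattr(lib, "cadam32bit_grad_fp32"))  # symbol name is illustrative
```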
diff --git a/include/Type.h b/include/Type.h
index 720bfb86f..03328f749 100644
--- a/include/Type.h
+++ b/include/Type.h
@@ -201,13 +201,30 @@ struct CondData<T,false>
     FORCE_INLINE operator const T() const { return 0;}
 };
 
+#ifdef _WIN32
+// The `IsComplete` build-time check doesn't work on Windows.
+// Given how the BinAlgoBase class is used, `I != Scalar` should be equivalent
+// to the Unix check below, `Details::IsComplete<Details::AlgoVecBase<I,T,A>>::value`.
+template <InstrSet I, typename T, Algos A>
+struct WouldAlgoVecBaseBeComplete
+{
+    static constexpr bool value{I != Scalar};
+};
+#else
+template <InstrSet I, typename T, Algos A>
+struct WouldAlgoVecBaseBeComplete : public Details::IsComplete<Details::AlgoVecBase<I,T,A>>
+{
+
+};
+#endif
+
 template <InstrSet I, typename T, Algos A>
-struct BinAlgoBase : Details::conditional< Details::IsComplete<Details::AlgoVecBase<I,T,A>>::value
+struct BinAlgoBase : Details::conditional< WouldAlgoVecBaseBeComplete<I,T,A>::value
     , Details::AlgoVecBase<I,T,A>
     , Details::AlgoScalarToVec<T,A>
     >::type
 {
-    typedef typename Details::conditional< Details::IsComplete<Details::AlgoVecBase<I,T,A>>::value
+    typedef typename Details::conditional< WouldAlgoVecBaseBeComplete<I,T,A>::value
         , Details::AlgoVecBase<I,T,A>
         , Details::AlgoScalarToVec<T,A>
         >::type base_t;
diff --git a/setup.py b/setup.py
index d6267088e..7b033a910 100644
--- a/setup.py
+++ b/setup.py
@@ -7,8 +7,9 @@
 
 from setuptools import find_packages, setup
 
-libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.so"))
-libs = [os.path.basename(p) for p in libs]
+libs_so = list(glob.glob("./bitsandbytes/libbitsandbytes*.so"))
+libs_dll = list(glob.glob("./bitsandbytes/libbitsandbytes*.dll"))
+libs = [os.path.basename(p) for p in (*libs_so, *libs_dll)]
 print("libs:", libs)
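After `pip install .`, a short smoke test (assumes a CUDA build and a visible GPU; `python -m bitsandbytes` prints the fuller diagnostic):

```python
import torch
import bitsandbytes as bnb

p = torch.nn.Parameter(torch.rand(256, 256, device="cuda"))
opt = bnb.optim.Adam8bit([p], lr=1e-3)
p.sum().backward()
opt.step()  # runs an optimizer step through the compiled CUDA kernels
print("bitsandbytes OK")
```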