diff --git a/csrc/common.h b/csrc/common.h index c99034e78..bcc979968 100644 --- a/csrc/common.h +++ b/csrc/common.h @@ -7,8 +7,16 @@ using namespace BinSearch; #define BLOCK_SIZE 16384 +#if defined(USE_AVX) || defined(USE_AVX2) +#define INSTR_SET AVX +#elif defined(USE_SSE41) || defined(USE_SSE42) +#define INSTR_SET SSE +#else +#define INSTR_SET Scalar +#endif + struct quantize_block_args { - BinAlgo *bin_searcher; + BinAlgo *bin_searcher; float *code; float *A; float *absmax; diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index e28e7b2c2..2090e3dd3 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -1,5 +1,9 @@ #include +#ifdef _WIN32 +#include +#else #include +#endif #include using namespace BinSearch; @@ -23,7 +27,7 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long num_blocks += n % blocksize == 0 ? 0 : 1; const uint32 elements_code = 256; - BinAlgo bin_searcher(code, elements_code); + BinAlgo bin_searcher(code, elements_code); int thread_wave_size = 256; // we chunk the thresds into waves of 256 since the max limit is @@ -31,7 +35,11 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size) { long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset; +#ifdef _WIN32 + std::thread *threads = (std::thread *) malloc(sizeof(std::thread) * valid_chunks); +#else pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks); +#endif struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *)); @@ -55,14 +63,23 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long arg->threadidx = block_idx / blocksize; arg->blocksize = blocksize; +#ifdef _WIN32 + new (&threads[chunks_processed]) std::thread(quantize_block, arg); +#else pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg); +#endif chunks_processed += 1; if(chunks_processed == valid_chunks){ break; } } for (int i = 0; i < valid_chunks; i++) + { +#ifdef _WIN32 + threads[i].join(); +#else int err = pthread_join(threads[i], NULL); - +#endif + } free(threads); for (int i = 0; i < valid_chunks; i++) free(args[i]); diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 1ab8aa242..c2e2d7da7 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3816,12 +3816,12 @@ template __global__ void kgemm_4bit_inference_naive(int M, int N template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index f37b3b3af..da9df6af0 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -9,7 +9,6 @@ #include #include -#include #include #include diff --git a/include/SIMD.h b/include/SIMD.h index a2ac1a9ae..d559e9f55 100644 --- a/include/SIMD.h +++ b/include/SIMD.h @@ -64,6 +64,16 @@ template <> struct InstrFloatTraits typedef __m128d vec_t; }; +template <> struct InstrFloatTraits +{ + typedef float vec_t; +}; + +template <> struct InstrFloatTraits +{ + typedef double vec_t; +}; + template struct FTOITraits {