Skip to content

Commit

Permalink
minimal fix to support Windows
Browse files Browse the repository at this point in the history
based on @Jamezo97 and @acpopescu work

manually cherry-picked from PR bitsandbytes-foundation#788 and PR bitsandbytes-foundation#229 and cleanup by wkpark

Signed-off-by: Won-Kyu Park <[email protected]>
  • Loading branch information
Jamezo97 authored and wkpark committed Jan 31, 2024
1 parent b90db7e commit 6955b5e
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 8 deletions.
19 changes: 18 additions & 1 deletion csrc/cpu_ops.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#include <BinSearch.h>
#ifdef _WIN32
#include <thread>
#else
#include <pthread.h>
#endif
#include <common.h>

using namespace BinSearch;
Expand Down Expand Up @@ -31,7 +35,11 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
{
long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
#ifdef _WIN32
std::thread *threads = (std::thread *) malloc(sizeof(std::thread) * valid_chunks);
#else
pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
#endif

struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *));

Expand All @@ -55,14 +63,23 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
arg->threadidx = block_idx / blocksize;
arg->blocksize = blocksize;

#ifdef _WIN32
new (&threads[chunks_processed]) std::thread(quantize_block, arg);
#else
pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
#endif
chunks_processed += 1;
if(chunks_processed == valid_chunks){ break; }
}

for (int i = 0; i < valid_chunks; i++)
{
#ifdef _WIN32
threads[i].join();
#else
int err = pthread_join(threads[i], NULL);

#endif
}
free(threads);
for (int i = 0; i < valid_chunks; i++)
free(args[i]);
Expand Down
12 changes: 6 additions & 6 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3821,12 +3821,12 @@ template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);

template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
Expand Down
1 change: 0 additions & 1 deletion csrc/ops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

#include <stdio.h>
#include <iostream>
#include <unistd.h>
#include <assert.h>

#include <cuda_runtime_api.h>
Expand Down
10 changes: 10 additions & 0 deletions include/SIMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ template <> struct InstrFloatTraits<SSE, double>
typedef __m128d vec_t;
};

template <> struct InstrFloatTraits<Scalar, float>
{
typedef float vec_t;
};

template <> struct InstrFloatTraits<Scalar, double>
{
typedef double vec_t;
};

template <InstrSet I, typename T>
struct FTOITraits
{
Expand Down

0 comments on commit 6955b5e

Please sign in to comment.