Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Applying the jetson fixes #847

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

ivansmith7795
Copy link

No description provided.

@ivansmith7795
Copy link
Author

Merging changes from jetson branch

@TimDettmers
Copy link
Collaborator

Did this work out for you? It seems a straightforward fix and a good contribution if this would make the library jetson compatible.

@TimDettmers TimDettmers reopened this Jan 1, 2024
@TimDettmers TimDettmers added high priority (first issues that will be worked on) Low Risk Risk of bugs in transformers and other libraries labels Jan 1, 2024
@Titus-von-Koeller
Copy link
Collaborator

@rickardp @younesbelkada

Do you have opinions on this PR? Could one of you two do the review?

@@ -41,14 +41,15 @@ CC_KEPLER += -gencode arch=compute_37,code=sm_37 # Kepler
CC_CUDA11x := -gencode arch=compute_75,code=sm_75
CC_CUDA11x += -gencode arch=compute_80,code=sm_80
CC_CUDA11x += -gencode arch=compute_86,code=sm_86

CC_CUDA11x += -gencode arch=compute_87,code=sm_87
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we confirm that the cmake file works with the Jetson devices? It compiles, but I do not have a device to test with.

Wheels can be taken from the latest build from here
https://github.com/TimDettmers/bitsandbytes/actions/workflows/python-package.yml

@@ -2409,7 +2409,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
}


template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols)
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, int8_t *out_col_normed, int8_t *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why this is needed, but as long as it compiles on all platforms (looking at you, MSVC :) ), I don't see a problem with the change either .IIRC, int8_t is exactly 8 bits, while char is at least 8 bits

@@ -28,6 +28,9 @@ FORCE_INLINE int popcnt32(int x32)

#if defined(USE_AVX) || defined(USE_AVX2)
#include <immintrin.h>
#elif defined __aarch64__
#warning "--- THIS IS AARCH64"
#include <sse2neon.h>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are going to need to support Neon one way or the other. I am pondering if this is the right approach though, or if we should implement the Neon intrinsics directly? If it saves us time in the short run, maybe a viable option?

@@ -1,3 +1,4 @@
#!/usr/bin/python3
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want this, /usr/bin/env python3 is more portable.

Also, the file is not executable. Need to chmod 755 and commit if this is to make sense

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
high priority (first issues that will be worked on) Low Risk Risk of bugs in transformers and other libraries
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants