Add matmul int4 for CUDA #17526

Closed
wants to merge 28 commits into from

yufenglee
Member

Description

Motivation and Context

int32_t k = k_block_idx * block_size;
const BlockwiseQuantBlock<T, block_size, bits>* blob_ptr = src_blob + task_idx;
if (nullptr != zero_points) {
  blob_ptr->dequant(dst + n * K + k, scale[task_idx], zero_points[task_idx], k, K);

Check warning

Code scanning / PREfast

Arithmetic overflow: Using operator '*' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '*' to avoid overflow (io.2).
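
One way to address the warning above, sketched under the assumption that `n`, `K`, and `k` are all 32-bit integers (only `k` is visibly `int32_t` in the snippet) and that the build targets a platform with a 64-bit `std::size_t`. The helper name `element_offset` is hypothetical and does not come from this PR; it only illustrates widening an operand before the multiplication so the product is never computed in 32 bits.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper (not part of the PR): flat offset of element (n, k)
// in a row-major buffer with K columns.
// Assumes non-negative indices and a 64-bit std::size_t.
inline std::size_t element_offset(int32_t n, int32_t K, int32_t k) {
  assert(n >= 0 && K >= 0 && k >= 0);
  // Cast BEFORE multiplying: the product is then evaluated in size_t,
  // so it cannot wrap at 32 bits the way `n * K` can.
  return static_cast<std::size_t>(n) * static_cast<std::size_t>(K) +
         static_cast<std::size_t>(k);
}
```

With such a helper, the flagged argument `dst + n * K + k` could be written as `dst + element_offset(n, K, k)`. Casting inline, as in `dst + static_cast<size_t>(n) * K + k`, would have the same effect.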
if (nullptr != zero_points) {
  blob_ptr->dequant(dst + n * K + k, scale[task_idx], zero_points[task_idx], k, K);
} else {
  blob_ptr->dequant(dst + n * K + k, scale[task_idx], k, K);

Check warning

Code scanning / PREfast

Arithmetic overflow: Using operator '*' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '*' to avoid overflow (io.2).
@yufenglee
Member Author

A clean one: #17890
