Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[common] fix bug for dtype convertion function #79

Closed
wants to merge 1 commit into from

Conversation

abenmao
Copy link
Contributor

@abenmao abenmao commented Nov 24, 2023

No description provided.

@changqi1
Copy link
Contributor

changqi1 commented Nov 24, 2023

@abenmao I have some concerns about the kernel perf because there is a conditional judgment in every loop. Would use the following logic to fix this bug.

    constexpr int kStep = 16;
    int blockSize = size / kStep;
    int remainder = size % kStep;

    for (int i = 0; i < blockSize; ++i) {
        __m512 input_vector = _mm512_loadu_ps(src + i * 16);
        __m256i output_vector = cvt_fp32_to_bf16(input_vector);
        _mm256_mask_storeu_epi16(dst + i * 16, 0xffff, output_vector);
    }

    if (remainder != 0) {
        __mmask16 mask = 0xFFFF >> (16 - remainder);
        __m512 input_vector = _mm512_maskz_loadu_ps(mask, src + size - remainder);
        __m256i output_vector = cvt_fp32_to_bf16(input_vector);
        _mm256_mask_storeu_epi16(dst + size - remainder, mask, output_vector);
    }

@changqi1 changqi1 closed this Nov 24, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants