Different values when using nki.simulate_kernel #1051

Closed

iamsalil opened this issue Dec 4, 2024 · 37 comments
iamsalil commented Dec 4, 2024

I've written a kernel that has no compile-time or run-time errors. However, when I wrap it with nki.simulate_kernel(), it gives me a different output than when I run it directly.

What could possibly cause this?

I've changed all my affine_range() calls to sequential_range() in case this discrepancy was being caused by some parallelism issue. However, this did not fix my issue at all.

Also, how do I even debug this? I can't look at intermediate values, because nki.language.device_print() does not work unless I'm doing kernel simulation.
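
For reference, a minimal sketch of the comparison being described (not the actual assignment harness; the kernel name is hypothetical, and the shapes match the ones reported later in this thread):

    import numpy as np
    from neuronxcc import nki

    # Hypothetical inputs: (N, C_in, H, W) images, (C_out, C_in, KH, KW) filters, zero bias.
    X = np.random.rand(4, 128, 30, 14).astype(np.float32)
    W = np.random.rand(256, 128, 3, 3).astype(np.float32)
    bias = np.zeros(256, dtype=np.float32)

    out_sim = nki.simulate_kernel(conv2d_kernel, X, W, bias)  # CPU simulation
    out_hw = conv2d_kernel(X, W, bias)  # on-device run (assuming the kernel is set up for baremetal execution)
    print(np.abs(out_sim - out_hw).max())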

@aws-zhehongb

Could you paste your kernel and the nki.simulate_kernel code in a GitHub secret gist and share it with me?

iamsalil commented Dec 5, 2024

> Could you paste your kernel and the nki.simulate_kernel code in a GitHub secret gist and share it with me?

Here it is:
https://gist.github.com/iamsalil/94eae632141203488ffc63eec14ecec6

@aws-zhehongb

Could you try using nl.device_print("value of y:", y) to print some intermediate values? It will work under nki.simulate_kernel, but it will not work if you are running baremetal.
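
A minimal sketch of how that might look inside the kernel body (hypothetical tensor and slice names):

    import neuronxcc.nki.language as nl

    # Print an intermediate tile; this only produces output under
    # nki.simulate_kernel and does not print during baremetal runs.
    y = nl.load(in_tensor[0:128, 0:64])
    nl.device_print("value of y:", y)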

iamsalil commented Dec 5, 2024

I can do that to print intermediate values, but that doesn't actually help me figure out what is going wrong when I don't use simulate. The simulate outputs are correct, while the non-simulate outputs are not.

How do you suggest I use nl.device_print()?

@aws-zhehongb

You can probe the intermediate values by storing them to HBM and returning them from the kernel:

return X_out, some_other_tensor
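
A hedged sketch of that pattern (hypothetical names; assumes the kernel returns tensors allocated in shared HBM):

    # Allocate an extra HBM output tensor for debugging and copy an
    # intermediate tile into it so it can be inspected on the host.
    dbg = nl.ndarray((128, out_width), dtype=nl.float32, buffer=nl.shared_hbm)
    nl.store(dbg, value=some_intermediate_tile)
    # ... rest of the kernel ...
    return X_out, dbg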

@aws-zhehongb

High-level hints for working around compiler bugs:

  1. Use sequential_range on all loops.
  2. Try to write NKI code in "single assignment" form, as in a functional programming language.
  3. conv2d requires padding of the input. Don't actually pad the tensor; use masking instead (a sketch follows this list): https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/api/nki.api.shared.html#nki-api-masking Be careful that the matmul masking syntax is special. Also, try to use nisa.nc_matmul instead of nl.matmul.
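
A hedged sketch of the masking idea from hint 3 (hypothetical tensor and size names):

    # Restrict a load to the valid image region instead of padding the input;
    # elements outside the mask are simply not accessed.
    i_p, i_f = nl.mgrid[0:128, 0:512]
    tile = nl.load(in_tensor[i_p, i_f], mask=(i_p < height) & (i_f < width))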

iamsalil commented Dec 5, 2024

I can try probing values output by the kernel.

However, I'm confused by your comment about compiler bugs. Is this a compiler bug? No compile-time error is thrown, and in fact no run-time error is thrown either.

AWSNB commented Dec 5, 2024

@iamsalil Using res_psum += nl.matmul is a fine choice. += will work and perform well when the right-hand side is a matmul output, because of a special hardware circuit in PSUM that allows accumulation in place.
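
A hedged sketch of the accumulation pattern being described (hypothetical tile names and sizes):

    # Successive matmul outputs accumulate in place in the same PSUM buffer via +=.
    res_psum = nl.zeros((128, out_width), dtype=nl.float32, buffer=nl.psum)
    for k in nl.affine_range(num_k_tiles):
        res_psum += nl.matmul(w_tiles[k], x_tiles[k], transpose_x=True)
    res_sbuf = nl.copy(res_psum, dtype=out_dtype)  # move out of PSUM before storing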

iamsalil commented Dec 5, 2024

@AWSNB Thank you for the suggestion but unfortunately, that did not fix it. My code no longer has any other += in it.

JonathanHenson commented Dec 5, 2024

@iamsalil I've asked some questions on the gist you shared. Feel free to dialog there if you'd like and I'll be sure to summarize the sharable parts back here when we're done.

JonathanHenson commented Dec 5, 2024

@iamsalil I found at least one thing of interest, and commented on the gist. Give it a try and let us know if that unblocks you or not. I was incorrect in my hint as I was using a different decorator to invoke the kernel. I will rerun in the morning and see what else I can find.

ggumen added the NKI label Dec 5, 2024
iamsalil commented Dec 5, 2024

Hi @JonathanHenson. Thanks so much for taking the time to look at this.

I unfortunately don't see any comments on the gist other than a suggestion to share the outputs. What was your suggestion?

In the meantime, here are the outputs. The inputs were a (4, 128, 30, 14) batch of images convolved with 256 filters of shape (256, 128, 3, 3), with a bias of shape (256,) that is all zeros (effectively no bias) and a pool_size of 1 (so no pooling occurs). What I've output is the top 5x5 square of the second channel of the first image (i.e., X_out[0, 1, :5, :5]). Channel index 1 was not chosen for any particular reason.

The output numbers are not just slightly off. Some elements are wildly off (element [1, 0] is off by ~4%), so it does not seem to be just a rounding issue. Additionally, I know that the simulate results are the correct ones (I am doing this for a homework assignment at my university, and the simulate results match the expected results), so the baremetal results are incorrect.

----- 128 256 -----
doing simulate kernel...
(4, 256, 30, 14)
[[300.0873 286.62927 291.14615 289.68362 294.90714]
[284.847 296.27835 290.41602 296.70108 292.8576 ]
[300.0873 286.62927 291.14615 289.68362 294.90714]
[284.847 296.27835 290.41602 296.70108 292.8576 ]
[300.0873 286.62927 291.14615 289.68362 294.90714]]
doing nonsimulate kernel...
(4, 256, 30, 14)
[[300.4392, 285.8078, 296.8551, 293.7848, 290.9920],
[296.4594, 291.2758, 299.6873, 288.6553, 288.3279],
[300.4162, 291.1158, 300.2389, 290.1125, 291.5898],
[297.3030, 293.3142, 294.2658, 295.0752, 289.5477],
[294.5771, 293.7183, 291.7239, 292.0972, 288.8704]]

@aws-serina-tan

Hello, can you also share the test inputs and their dtypes? If this is BF16, could you give FP32 a shot? Feel free to link to the exact test case in your assignment as well: https://github.com/stanford-cs149/asst4-test/blob/main/part2/test_harness.py.

@aws-zhehongb

I can reproduce the issue by copying https://gist.github.com/iamsalil/94eae632141203488ffc63eec14ecec6 next to test_harness.py, then importing the kernel and using it in test_harness.

andxalex commented Dec 5, 2024

I'm also facing the same discrepancy issue when not simulating.

I've already replaced the ranges with sequential_range, though this did not fix the error. Unfortunately, changing precisions is not an option.

Repo: https://github.com/andxalex/CS149/blob/main/asst4-trainium/part2/conv2d.py

Reproduce:

Run in simulation mode (python3 test_harness.py --test_maxpool --simulate): passes correctness, then errors out (the latter error is expected)
Run normally (python3 test_harness.py --test_maxpool): fails correctness, then errors out (the latter error is expected)

andxalex commented Dec 5, 2024

I ended up figuring it out. The issue was the only call to reshape() in the kernel, which was used to decompose a free dimension.

Replacing this with a for loop as shown below fixed the problem:

Bad:

    for out_i in nl.sequential_range(n_tiles_c_out):
        weights[out_i] = nl.load(W[out_i * c_out_pmax:(out_i + 1)*c_out_pmax,:,:,:])
    weights = weights.reshape((n_tiles_c_out, nl.par_dim(c_out_pmax), n_tiles_c_in, c_in_pmax, filter_height, filter_width))

Good:

    for out_i in nl.sequential_range(n_tiles_c_out):
        for in_i in nl.sequential_range(n_tiles_c_in):
            weights[out_i,:, in_i] = nl.load(W[out_i * c_out_pmax:(out_i + 1)*c_out_pmax,in_i * c_in_pmax:(in_i + 1)*c_in_pmax,:,:])
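
For reference, a hedged sketch of how the weights buffer in the "Good" version might be declared (hypothetical; it reuses the tile-size names from the snippets above):

    # SBUF buffer already shaped with the decomposed free dimensions, so no
    # reshape() is needed after loading.
    weights = nl.ndarray(
        (n_tiles_c_out, nl.par_dim(c_out_pmax), n_tiles_c_in, c_in_pmax,
         filter_height, filter_width),
        dtype=W.dtype, buffer=nl.sbuf
    )

Declaring the decomposed layout up front avoids calling reshape() on an SBUF tensor, which is what triggered the mismatch.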

@aws-zhehongb

Is this issue resolved?

JonathanHenson commented Dec 6, 2024

@iamsalil I'm glad you got it sorted. My suggestion last night was based on an error in my configuration, so I didn't want to lead you down the wrong path and deleted it. Thanks for sharing the details, and I'm glad we were able to help. Feel free to reach out if you encounter any further issues.

iamsalil commented Dec 6, 2024

@aws-zhehongb This issue was not resolved. I don't know who andxalex is (I presume another student in the class). It seems they had a similar issue to mine, posted in this thread, and then were able to resolve it. However, my original issue was never resolved, unfortunately.

I am not sure how much longer I will pursue this issue, though. I may try digging around a little more and will post any updates if I find them. Thank you NKI team for your support.

@aws-zhehongb

It looks like reusing bias_tile triggers a compiler bug.

Could you try this?

In the code, remove:

    bias_tile = nl.ndarray(
        (nl.par_dim(O_tile), 1),
        dtype=datatype, buffer=nl.sbuf
    )

and change the producer to

            bias_tile = nl.load(bias[o_start:o_end])

iamsalil commented Dec 6, 2024

Doing that didn't seem to fix the error, unfortunately. The error occurs even when the bias is all 0s.

allpan3 commented Dec 7, 2024

Hi @aws-zhehongb, I've also faced the same problem. Specifically, the small-image test with this configuration failed on hardware but passed in simulation:

Output mismatch for input_channels: 128,
output_channels: 256, kernel_size: 3, batch_size: 4,
image_dims: (32, 16), use_bias: False, use_maxpool: False

When OC=128, the same image size passed in both cases.
For the large image with bias, the code passed in simulation but failed on hardware in the following test:

Output mismatch for input_channels: 256,              
output_channels: 256, kernel_size: 3, batch_size: 4,
image_dims: (224, 224), use_bias: True, use_maxpool: False

This only occurred after I modified the code to transpose the weights outside of the loop, which helped performance a lot. The change can be found in the preload-weights branch. Prior to this change, the code ran correctly on both hardware and in simulation (part2 branch).

I've added you to the private repo. I would really appreciate it if you could take a look or give us some advice on how to debug this kind of issue (i.e., passing in simulation but failing on hardware).
Thanks very much 🙏

@aws-zhehongb

@allpan3, quick note:

                            X_tensor[ih, ic, :, :] = nl.load(X_reshaped[b, ic, :, ih, :])

is inside the b and fx loops, but the destination of the load, X_tensor[ih, ic, :, :], is not indexed by b or fx; this may trigger a compiler bug.

@aws-zhehongb

@allpan3 I sent you a pull request.

Also, you initialize psum with

            for oc in nl.affine_range(NUM_TILE_OC):
                psum_tensor[oc] = nl.copy(B_tensor[oc].broadcast_to((TILE_OC, out_width)))

and then += on psum. That is not currently supported. += on psum is only supported when the RHS is a matmul, because the += is done by a special circuit in PSUM when the input is a matmul.

allpan3 commented Dec 7, 2024

@aws-zhehongb Thanks for the help!

> is inside the b and fx loops, but the destination of the load, X_tensor[ih, ic, :, :], is not indexed by b or fx; this may trigger a compiler bug.

I wasn't aware this was a requirement. I was thinking that since each X row gets reused filter_height times, I'd allocate space for all rows and hope the recently loaded ones stay in SBUF. It probably also needs direct allocation to work as intended. I will explore other options.

> and then += on psum. That is not currently supported. += on psum is only supported when the RHS is a matmul, because the += is done by a special circuit in PSUM when the input is a matmul.

I'm a bit confused. So are you saying I cannot do = nl.copy here and then do += later? How else should I preload the bias?
I was doing something similar before and it worked.

AWSNB commented Dec 7, 2024 via email

allpan3 commented Dec 7, 2024

@AWSNB Thanks for the prompt response. I tried simply changing the code based on what you said, and now the small test passes without using @aws-zhehongb's changes, which were not performant due to duplicated loading of input rows.

I'm still facing an issue with the larger image size.

Running correctness test for conv2d kernel with larger images + bias...Output mismatch for 
input_channels: 256,                         
output_channels: 256, kernel_size: 3, batch_size: 4,                        
image_dims: (224, 224), use_bias: True, use_maxpool: False
Failed 😢

I'm still not entirely sure how allocating an array inside loops without using those indices on the LHS affects correctness.
I will try a few things and see if that solves the problem.

allpan3 commented Dec 7, 2024

I think the main error still comes from how the bias is loaded (even after the suggested modification). I tried the smaller image-size tests with use_bias=True, and they start to fail on hardware.

AWSNB commented Dec 7, 2024

@allpan3 Also note that there may be compiler errors being truncated/hidden by the test harness. We suggest you try running the conv2d kernel directly, as @aws-serina-tan mentioned in this comment. Can you try that and see if you get a more useful error?

allpan3 commented Dec 7, 2024

> @allpan3 Also note that there may be compiler errors being truncated/hidden by the test harness. We suggest you try running the conv2d kernel directly, as @aws-serina-tan mentioned in this comment. Can you try that and see if you get a more useful error?

I just tried that but don't see anything being printed out, even with bias turned on.

AWSNB commented Dec 7, 2024

We will likely need to wait for others to chime in in the morning.

Just to point out one common issue we saw with indices, usually during load/store/copy, that hongbin commented about above, in case that helps.

allpan3 commented Dec 7, 2024

> We will likely need to wait for others to chime in in the morning.

> Just to point out one common issue we saw with indices, usually during load/store/copy, that hongbin commented about above, in case that helps.

@AWSNB Yeah, although I don't exactly understand this, I've both tried hongbin's fix and implemented a more performant version myself based on the suggestion. I'm still failing the same tests, which seem to be related to bias loading/adding.

I've added you to the repo if you want to see the code, but huge thanks for staying online for us from Friday evening to midnight!

@aws-zhehongb

The bias test fails because you cannot initialize the psum and then += onto it. Currently, you must initialize psum to 0.

@aws-zhehongb

OK, more context on the psum behavior.

When the matmul instruction writes to psum, it has two modes: overwrite mode and accumulation mode.

  • Overwrite mode means: psum[idx] = matmul
  • Accumulation mode means: psum[idx] += matmul

There is actually an internal flag in the matmul instruction to control overwrite mode vs accumulate mode.

When we write

for i in range(N):
  psum += matmul

it is actually like:

for i in range(N):
  psum += matmul, overwrite="i == 0"

Because we currently cannot explicitly control this overwrite/accumulate flag, the compiler will always set overwrite=True for the first matmul in the accumulation loop.

As a result, any pre-existing value in the psum will be ignored when we enter the psum accumulation loop.

@aws-zhehongb

In your latest code, you need to avoid the extra indices on X_tensor to work around the compiler bug, so that it can pass "Running correctness test for conv2d kernel with smaller images":

                        X_tensor = nl.load(X_reshaped[b, ic, :, ih, :])
                        for fx in nl.affine_range(filter_width):
                            # for each filter weight, compute (OW x TILE_IC) x (TILE_IC x TILE_OC) => (OW x TILE_OC).
                            # then complete all NUM_TILE_IC, they contribute to the same (OW x TILE_OC) psum
                            # then complete NUM_TILE_OC, they contribute to another TILE_OC of the same OW
                            # slice the input pixels that contribute to ow output width tile. shape = TILE_IC x OW

                            for oc in nl.affine_range(NUM_TILE_OC):
                                # matmul, need to transpose the LHS
                                # result shape = (TILE_OC x OW)
                                psum_tensor[oc, :, oh1, :] += nl.matmul(W_tensor[fy, fx, oc, ic, :, :], X_tensor[:, fx:fx+out_width], transpose_x=True)

For the bias test, you need to do the bias addition without relying on psum accumulation.
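
A hedged sketch of that (hypothetical names, reusing the tiles from the snippet above): finish the matmul accumulation in PSUM first, then add the bias as a separate step and store the result.

    # After the fy/fx/ic accumulation loops have finished for this (oc, oh1):
    out_tile = nl.copy(psum_tensor[oc, :, oh1, :], dtype=X_out.dtype)
    out_tile = nl.add(out_tile, B_tensor[oc].broadcast_to((TILE_OC, out_width)))
    nl.store(X_out[b, oc, :, oh1, :], value=out_tile)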

allpan3 commented Dec 8, 2024

> As a result, any pre-existing value in the psum will be ignored when we enter the psum accumulation loop.

I see, this explains why tests fail with bias. I think I understand what needs to be done now, but I have to move on to other tasks. I may revisit this if I get a chance in the future.

Thanks for your help!
