[MaxIters(N)] causes variables to not be saved for backwards pass #6048

Open
tunabrain opened this issue Jan 9, 2025 · 0 comments · May be fixed by #6054
Labels
goal:quality & productivity Quality issues and issues that impact our productivity coding day to day inside slang

Comments

@tunabrain

This is possibly related to #6039: the behavior is similar, but the trigger is different. In the following example, marking the first loop in store with [MaxIters(2)] causes the context parameter of store not to be saved during primal execution. As a result, write_bwd in the backward pass only sees a value of 0 for idx[0] and reads from the wrong memory location. Replacing [MaxIters] with [ForceUnroll] fixes the issue.

struct GradientBuffer<let D : int>
{
    RWStructuredBuffer<float> primal;
    StructuredBuffer<float> grad;
    int strides[D];

    int toIndex(int idx[D]) {
        int result = 0;
        for (int i = 0; i < D; ++i)
            result += strides[i] * idx[i];
        return result;
    }

    [Differentiable]
    void write(int[D] idx, float v) { primal[toIndex(idx)] = detach(v); }

    [BackwardDerivativeOf(write)]
    void write_bwd(int[D] idx, inout DifferentialPair<float> d) { d = diffPair(d.p, grad[toIndex(idx)]); }

    [Differentiable]
    void store<let N : int>(int context[D - 1], in float value[N])
    {
        int idx[D];
        //[ForceUnroll] /* Using ForceUnroll instead of MaxIters makes it work */
        [MaxIters(2)]
        for (int i = 0; i < D - 1; ++i)
            idx[i] = context[i];
        [ForceUnroll]
        for (int i = 0; i < N; i++) {
            idx[D - 1] = i;
            write(idx, value[i]);
        }
    }
}

[Differentiable]
void test(GradientBuffer<2> buf, int[1] base, float[3] value)
{
    buf.store(base, value);
}

[shader("compute")]
float3 repro(RWStructuredBuffer<float> primal, StructuredBuffer<float> grad)
{
    float input[3];
    input[0] = input[1] = input[2] = 1.0f;
    var result = diffPair(input);
    GradientBuffer<2> buf = { primal, grad, {3, 1} };
    bwd_diff(test)(buf, { 1 }, result);
    return float3(result.d[0], result.d[1], result.d[2]);
}
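
To make the symptom concrete, here is a rough host-side model of the index arithmetic in toIndex, written in plain Python rather than Slang. It only mirrors the stride computation with the values from the repro (strides {3, 1}, base { 1 }, N = 3) and assumes, as reported above, that idx[0] is seen as 0 in the backward pass. The names to_index and grad_offsets are made up for this sketch, and it does not reproduce the compiler behavior itself.

```python
def to_index(strides, idx):
    """Model of GradientBuffer.toIndex: sum of strides[i] * idx[i]."""
    return sum(s * i for s, i in zip(strides, idx))


STRIDES = [3, 1]  # from the repro: GradientBuffer<2> buf = { primal, grad, {3, 1} }
BASE = [1]        # from the repro: bwd_diff(test)(buf, { 1 }, result)
N = 3             # length of the value array passed to store


def grad_offsets(context):
    """Offsets into grad that write_bwd would read for a given context."""
    return [to_index(STRIDES, context + [i]) for i in range(N)]


expected = grad_offsets(BASE)  # idx[0] == 1, as passed by the caller
observed = grad_offsets([0])   # idx[0] == 0, as reported for the backward pass

print("expected grad offsets:", expected)  # [3, 4, 5]
print("observed grad offsets:", observed)  # [0, 1, 2]
```

Under these assumptions, result.d would be filled from grad[0..2] instead of grad[3..5], which is one way to confirm the bug from the output buffer.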

bmillsNV added this to the Q1 2025 (Winter) milestone Jan 9, 2025
bmillsNV added the goal:quality & productivity (Quality issues and issues that impact our productivity coding day to day inside slang) label Jan 9, 2025
csyonghe linked a pull request Jan 10, 2025 that will close this issue
3 participants