Invalid HLSL generated for bwd_diff function #5776

Closed
tunabrain opened this issue Dec 5, 2024 · 0 comments · Fixed by #6000
Labels: goal:client support (Feature or fix needed for a current slang user.), kind:bug (something doesn't work like it should)

@tunabrain

I'm running into an issue where Slang generates invalid HLSL for the following snippet:

struct Foo
{
    float foo(int idx) { return 5.0f; }
}
struct WrappedBuffer<T, let D : int>
{
    StructuredBuffer<T> buffer;
    int[D] shape;

    T get(int idx) { return buffer[idx]; }
}
struct GradInBuffer<T : IDifferentiable, let D : int>
{
    RWStructuredBuffer<T> primal;
    WrappedBuffer<T.Differential, D> grad_in;
}

[Differentiable, BackwardDerivative(set_bwd)]
void set(GradInBuffer<float[1], 1> t, int idx, float[1] value)
{
    t.primal[idx] = detach(value);
}
void set_bwd(GradInBuffer<float[1], 1> t, int idx, inout DifferentialPair<float[1]> grad)
{
    // Generates invalid HLSL, "candidate function not viable: no known conversion from 'WrappedBuffer_0' to 'WrappedBuffer_1' for 1st argument"
    grad = diffPair(grad.p, t.grad_in.get(idx));

    // This works:
    //grad = diffPair(grad.p, t.grad_in.buffer[idx]);
}

struct CallData
{
    Foo weights;
    RWStructuredBuffer<float> dOut;
    GradInBuffer<float[1], 1> result;
}
ParameterBlock<CallData> call_data;

[shader("compute")]
[numthreads(32, 1, 1)]
void main(uint3 dispatchThreadID: SV_DispatchThreadID)
{
    int idx = dispatchThreadID.x;

    float[1] primal = { call_data.weights.foo(idx) };
    DifferentialPair<float[1]> dResult = diffPair(primal);
    bwd_diff(set)(call_data.result, idx, dResult);

    call_data.dOut[idx] = dResult.d[0];
}

Slang appears to generate two identical specializations of WrappedBuffer (WrappedBuffer_0 and WrappedBuffer_1) and then calls the get method of WrappedBuffer_1 with an instance of WrappedBuffer_0. This happens only for the backward pass, and it is very sensitive to small changes: for example, removing the int[D] shape; field, or changing T from float[1] to float, makes the error go away.

saipraveenb25 added the goal:client support and kind:bug labels on Dec 6, 2024
saipraveenb25 added this to the Q4 2024 (Fall) milestone on Dec 6, 2024
saipraveenb25 self-assigned this on Dec 6, 2024
bmillsNV assigned kaizhangNV and unassigned saipraveenb25 on Dec 11, 2024
kaizhangNV pushed a commit to kaizhangNV/slang that referenced this issue Jan 3, 2025
close shader-slang#5776

When we start specializing a "specialize" IR instruction, we should
make sure all of its elements are fully specialized, but
we missed checking the elements of an array. This change
adds that check.
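The commit message describes the fix as making a "fully specialized" check recurse into array elements. The following is a minimal C++ sketch of that idea over a toy type tree, assuming a simplified model in which a type is a basic type, a generic parameter, an array, or a specialization with arguments. The names Kind, Type, basic, genericParam, arrayOf, specialize, and isFullySpecialized are hypothetical and are not Slang's actual IR classes or functions; the sketch only shows why a check that does not descend into an array's element type can wrongly report something like T[1] as fully specialized.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical, heavily simplified type tree; these are NOT Slang's real IR classes.
enum class Kind { Basic, GenericParam, Array, Specialize };

struct Type {
    Kind kind;
    std::shared_ptr<Type> element;            // element type, used only by Kind::Array
    std::vector<std::shared_ptr<Type>> args;  // generic arguments, used only by Kind::Specialize
};

using TypePtr = std::shared_ptr<Type>;

// Small constructors for building example types.
TypePtr basic() { return std::make_shared<Type>(Type{Kind::Basic, nullptr, {}}); }
TypePtr genericParam() { return std::make_shared<Type>(Type{Kind::GenericParam, nullptr, {}}); }
TypePtr arrayOf(TypePtr elem) {
    return std::make_shared<Type>(Type{Kind::Array, std::move(elem), {}});
}
TypePtr specialize(std::vector<TypePtr> args) {
    return std::make_shared<Type>(Type{Kind::Specialize, nullptr, std::move(args)});
}

// Returns true only if no generic parameter remains anywhere in the type tree.
bool isFullySpecialized(const TypePtr& t) {
    switch (t->kind) {
    case Kind::Basic:
        return true;
    case Kind::GenericParam:
        return false;
    case Kind::Array:
        // Recurse into the element type: the kind of check the commit message says was missing.
        return isFullySpecialized(t->element);
    case Kind::Specialize:
        for (const auto& arg : t->args)
            if (!isFullySpecialized(arg)) return false;
        return true;
    }
    return false;
}
```

In this sketch, replacing the Array case with an unconditional return true would make specialize({arrayOf(genericParam())}) look concrete, which is the kind of gap the commit message describes.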