Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed instruction::replace() logic. #3553

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from

Conversation

tcgu-amd
Copy link
Contributor

@tcgu-amd tcgu-amd commented Oct 24, 2024

The previous fix with BFS doesn't fully work in more complex cases (e.g. it will fail in the newly added test case check_replace_dag). This fix implements topological sorting to replace instructions in topological order which should work for all cases.

More details:

In a dummy scenario of add2(reduce(x), add1(abs(reduce(x)), sin(reduce(x)))), we will have a dependency tree looking like

reduce _
        \_abs__
         \_sin__\_add1_
          \_____________\_add2

If we call reduce.replace(), BFS will visit the instructions in the following order:

reduce -> abs -> sin -> add2 -> add1

This will causes an error of shape mismatch at add2 because it is called before its input add1.

Topological sorting the instruction tree will yield:

reduce -> sin -> abs -> add1 -> add2

Which is the correct order to process the instructions.

This should be able to extend to more complex cases.

… fully work in more complex cases (e.g. it will fail in the newly added test case check_replace_dag). This fix implements topological sorting to replace instruction in topological order which should work for all cases.
Copy link

codecov bot commented Oct 24, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.16%. Comparing base (1e1a229) to head (92ebe7f).

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #3553   +/-   ##
========================================
  Coverage    92.16%   92.16%           
========================================
  Files          512      512           
  Lines        21401    21408    +7     
========================================
+ Hits         19724    19731    +7     
  Misses        1677     1677           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@pfultz2
Copy link
Collaborator

pfultz2 commented Oct 24, 2024

This seems like it will be really slow since it needs topologically sort until end of the model instead of just until the shapes no longer change.

@tcgu-amd
Copy link
Contributor Author

tcgu-amd commented Oct 24, 2024

This seems like it will be really slow since it needs topologically sort until end of the model instead of just until the shapes no longer change.

Yes unfortunately I think this is definitely going to be slower than the previous implementations. I am not quite sure if there's potentially a better approach since we don't know the dependencies of instructions beforehand until after the sort.

One way I can think of is to take an optimistic approach and perform BFS assuming everything is going to be fine, and on shape mismatch just push the instruction to the back of the queue. Only return the error if all instructions in the queue are shape mismatches. This is a little bit unconventional so I will need to test it to make sure it is going to generate correct results.

Edit: Actually, upon further consideration, I think this problem can be solved easily by using a modified version of Kahn's algorithm. I will update the code and try it out.

@pfultz2
Copy link
Collaborator

pfultz2 commented Oct 24, 2024

There might be a way to traverse up the inputs to check for dependencies. I would need to think about it more.

…ly based on Khan's algorithm.

This version avoids sorting the entire graph, and will terminate when no more changes are requried like old versions
@tcgu-amd
Copy link
Contributor Author

tcgu-amd commented Oct 25, 2024

Hi @pfultz2, I have created a new version of the algorithm that should have the same performance as the old versions.

This is loosely based on Khan's algorithm in that we only process nodes that has been visited by all its children that needs to be replaced.

To achieve this, we perform a BFS from the base instruction as usual, but keep a map counting the number of arguments for each instruction we encounter. If it an instruction is unary, then we can directly process the current instruction. If there's more than one argument, we subtract one from the number of arguments in the map and check to see if the number reaches zero, in which case all of the arguments must have been replaced and we can replace this instruction; otherwise some arguments may still need to be replaced, and we can just skip replacing this instruction for now and wait for it to be encounter again when one of its arguments ultimately adds it back to the queue.

For instructions that have more than one child, but only one of them needs to be replaced and the other ones are from unrelated sub-graphs, we can add them from the map to the queue when it empties, and try to process them. If this ends up generates a shape mismatch it will error out as normal.

Edit:

For instructions that have more than one child, but only one of them needs to be replaced and the other ones are from unrelated sub-graphs, we can add them from the map to the queue when it empties, and try to process them. If this ends up generates a shape mismatch it will error out as normal.

I just realized that there might still be a dependency between the instructions that needs to be partially replaced, and the current version may not be able to capture that..

@pfultz2
Copy link
Collaborator

pfultz2 commented Oct 25, 2024

I would think instead you would check if the inputs reaches the instruction and then add that to a revisit queue:

void instruction::replace(const shape& r)
{
    if(r != result)
    {
        result = r;
        std::deque<instruction_ref> q(output.begin(), output.end());
        std::deque<instruction_ref> revisit;
        std::unordered_set<instruction_ref> visited;
        while(not q.empty())
        {
            instruction_ref ins = q.front();
            q.pop_front();
            if(not visited.insert(ins).second)
                continue;
            assert(ins->name() == "@return" or ins->name().front() != '@');
            shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args);
            if(new_r != ins->result)
            {
                ins->result = new_r;
                for(auto out:ins->outputs())
                {
                    if(any_of(out->inputs(), [&](instruction_ref x) { return x != ins and reaches(ins, x); }))
                    {
                        revisit.push_back(out);
                    }
                    else
                    {
                        q.push_back(ins);
                    }
                }
            }
            if(q.empty())
            {
                q.insert(q.end(), revisit.begin(), revisit.end());
                revisit.clear();
            }
        }
    }
}

This would fix the simple case you presented but I am not sure it would handle more complicated cases.

@pfultz2
Copy link
Collaborator

pfultz2 commented Oct 25, 2024

Actually, I think it might be much simpler if we just use the order in the instruction list as that should already be in order. So we could just use a priority_queue instead:

struct replace_shape_order
{
    instruction_ref start;

    std::size_t location(instruction_ref x) const
    {
        return std::distance(start, x);
    }

    bool operator()(instruction_ref x, instruction_ref y) const
    {
        return location(x) > location(y);
    }
};

void instruction::replace(const shape& r)
{
    if(r != result)
    {
        result = r;
        auto start = std::find_if(output.front()->inputs().begin(), output.front()->inputs().end(), [&](instruction_ref x) {
            return this == as_address(x);
        });
        assert(as_address(*start) == this);
        std::priority_queue<instruction_ref, std::vector<instruction_ref>, replace_shape_order> q(output, replace_shape_order{*start});
        while(not q.empty())
        {
            instruction_ref ins = q.top();
            q.pop();
            assert(ins->name() == "@return" or ins->name().front() != '@');
            shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args);
            if(new_r != ins->result)
            {
                ins->result = new_r;
                std::copy(ins->output.begin(), ins->output.end(), push_inserter(q));
            }
        }
    }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants