
convergence problem #9

Open
yangpc615 opened this issue Dec 18, 2019 · 6 comments
Labels: question (Further information is requested)


yangpc615 commented Dec 18, 2019

        checkpoint = (j < checkpoint_stop)
        if checkpoint:
            chk = Checkpointing(partition, batch)
            task = Task(streams[i], compute=chk.checkpoint, finalize=chk.recompute)
            del chk

        else:
            def compute(batch: Batch = batch, partition: nn.Sequential = partition) -> Batch:
                return batch.call(partition)
            task = Task(streams[i], compute=compute, finalize=None)
            del compute

When I use the second method (plain compute, not checkpoint), my network performs worse, and the degradation grows with the number of splits.
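For reference, checkpointing itself only trades memory for recomputation; for a deterministic network it should produce the same gradients as the plain compute path. A minimal sketch using `torch.utils.checkpoint` (the standard utility, not torchgpipe's internal `Checkpointing`, but the same idea) that checks this:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
x = torch.randn(8, 4)

# Plain compute path: keep all activations, backprop normally.
net.zero_grad()
net(x).sum().backward()
plain_grads = [p.grad.clone() for p in net.parameters()]

# Checkpointed path: discard activations, recompute them in backward.
net.zero_grad()
out = checkpoint(net, x, use_reentrant=False)
out.sum().backward()
ckpt_grads = [p.grad.clone() for p in net.parameters()]

# For a deterministic network the gradients match.
assert all(torch.allclose(a, b) for a, b in zip(plain_grads, ckpt_grads))
```

If the gradients match like this but convergence still differs, the cause is usually something batch-size dependent (such as BatchNorm over micro-batches) rather than checkpointing itself.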

@sublee sublee self-assigned this Dec 18, 2019
@sublee sublee added the question Further information is requested label Dec 18, 2019

sublee commented Dec 18, 2019

Hi @yangpc615, thanks for the report.

I have a few questions to understand the case.

  • Does your network depend heavily on BatchNorm, or on any other algorithm that operates over the batch dimension?
  • Is there no convergence problem when you enable checkpointing?
  • Can you share any other details about your network?


yangpc615 commented Dec 18, 2019

Thanks for your reply. Do you know mmdetection?

  • I tried to apply torchgpipe to the HybridTaskCascade class of mmdetection.
  • There is a BatchNorm after each convolution layer.
  • Right now my network doesn't work with checkpointing either.

@yangpc615
Author

And I want to know how the network is updated through the plain compute path (not checkpoint) in torchgpipe.


sublee commented Dec 19, 2019

@yangpc615 Did you mean that your network doesn't converge either with or without checkpointing?

Anyway, if the network relies heavily on BatchNorm, a large number of micro-batches may affect training, just as small per-GPU batches do under DataParallel. See the documentation on the trade-off in the number of micro-batches. GPipe provides an option for this case: see "Deferred Batch Normalization" for more details.
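The idea behind deferred batch normalization (exposed in torchgpipe as the `deferred_batch_norm=True` option on `GPipe`, per its documentation) can be sketched in plain PyTorch. Instead of letting each micro-batch normalize with its own statistics, the sums are accumulated across micro-batches and applied once per mini-batch, recovering the full mini-batch statistics. The shapes and chunk count below are illustrative:

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 8)   # one mini-batch, 8 features

# The mean single-device training would use.
full_mean = x.mean(0)

# GPipe splits the mini-batch into micro-batches (chunks=4 here).
chunks = x.chunk(4)

# Plain BatchNorm would normalize each chunk with its own statistics,
# shrinking the effective batch size. Deferred BatchNorm instead
# accumulates sums over all micro-batches and applies one update
# per mini-batch:
total = torch.zeros(8)
count = 0
for c in chunks:
    total += c.sum(0)
    count += c.size(0)
deferred_mean = total / count

# The accumulated statistics match the full mini-batch statistics.
assert torch.allclose(deferred_mean, full_mean, atol=1e-6)
```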

@yangpc615
Author

Thank you. In addition, I don't understand the following code:

def depend(fork_from: Batch, join_to: Batch) -> None:
    fork_from[0], phony = fork(fork_from[0])
    join_to[0] = join(join_to[0], phony)

What do these functions do, and how do they relate to the following code?

    def recompute(self, batch: Batch) -> None:
        """Applies :class:`Recompute` to the batch in place."""
        input_atomic = self.batch.atomic
        input = tuple(self.batch)

        # batch[0] is always requiring grad, because it has been passed
        # checkpoint with a phony requiring grad.
        batch[0], phony = fork(batch[0])
        phony = Recompute.apply(phony, self.recomputed, self.rng_states,
                                self.function, input_atomic, *input)
        batch[0] = join(batch[0], phony) 


sublee commented Dec 25, 2019

@yangpc615 That is a good question. However, I recommend opening a separate issue for a new question that isn't related to the convergence problem.

fork and join create an arbitrary dependency in the autograd graph through an empty tensor called phony. This forces the autograd engine to follow our desired execution order. Recompute must run at a specific moment in the backward pass, but it is not part of the actual gradient flow. That is where phony comes in: an empty tensor of size 0, used so that no unnecessary gradient is accumulated.

           +-----------------------------+
           |                             |
... ---> fork - - - > Recompute - - - > join ---> ...
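To make the role of phony concrete, here is a self-contained sketch in the spirit of the diagram above. Fork, Join, and SideEffect here are illustrative stand-ins, not torchgpipe's exact implementations; SideEffect plays the role of Recompute, doing its work in backward without contributing a real gradient, while join forces it to run before fork's backward:

```python
import torch

order = []  # records the backward execution order

class Fork(torch.autograd.Function):
    """Pass x through and emit an empty 'phony' tensor for ordering."""
    @staticmethod
    def forward(ctx, x):
        phony = torch.empty(0, device=x.device)
        return x.detach(), phony

    @staticmethod
    def backward(ctx, grad_x, grad_phony):
        order.append('fork')
        return grad_x

class Join(torch.autograd.Function):
    """Return x unchanged, but record a dependency on phony."""
    @staticmethod
    def forward(ctx, x, phony):
        return x.detach()

    @staticmethod
    def backward(ctx, grad_x):
        order.append('join')
        return grad_x, None

class SideEffect(torch.autograd.Function):
    """Stands in for Recompute: works in backward, carries no gradient."""
    @staticmethod
    def forward(ctx, phony):
        return phony.detach()

    @staticmethod
    def backward(ctx, grad_phony):
        order.append('side effect')  # recomputation would happen here
        return grad_phony

x = torch.ones(3, requires_grad=True)
y, phony = Fork.apply(x * 2)
phony = SideEffect.apply(phony)
y = Join.apply(y, phony)
y.sum().backward()

# The phony branch ran during backward, between join and fork,
# even though it carried no real gradient:
assert order == ['join', 'side effect', 'fork']
assert torch.equal(x.grad, torch.full((3,), 2.0))
```

Because the phony tensor has size 0, scheduling this extra branch adds no meaningful gradient accumulation; it only pins the execution order.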
