-
Notifications
You must be signed in to change notification settings - Fork 346
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Clean-up unnecessary warnings (including update to PyTorch 2.0) (#581)
Summary: This PR is a collection of smaller fixes that will save us some deprecation issues in the future ## 1. Updating to PyTorch 2.0 **Key files: grad_sample/functorch.py, requirements.txt** `functorch` has been a part of core PyTorch since 1.13. Now they're going a step further and changing the API, while deprecating the old one. There's a [guide](https://pytorch.org/docs/master/func.migrating.html) on how to migrate. TL;DR - `make_functional` will no longer be part of the API, with `torch.func.functional_call()` being (non drop-in) replacement. They key difference for us is `make_functional()` creates a fresh copy of the module, while `functional_call()` uses existing module. As a matter of fact, we need the fresh copy (otherwise all the hooks start firing and you enter nested madness), so I've copy-pasted a [gist](https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf) from the official guide on how to get a full replacement for `make_functional`. ## 2. New mechanism for gradient accumulation detection **Key file: privacy_engine.py, grad_sample_module.py** As [reported](https://discuss.pytorch.org/t/gan-raises-userwarning-using-a-non-full-backward-hook-when-the-forward-contains-multiple/175638/2) on the forum, clients are still getting "non-full backward hook" warning even when using `grad_sample_mode="ew"`. Naturally, `functorch` and `hooks` modes rely on backward hooks and can't be migrated to full hooks because [reasons](#328 (comment)). However, `ew` doesn't rely on hooks and it's unclear why the message should appear. The reason, however, is simple. If the client is using poisson sampling we add an extra check to prohibit gradient accumulation (two poisson batches combined is not a poisson batch), and we do that by the means of backward hooks. ~In this case, backward hook serves a simple purpose and there shouldn't be any problems with migrating to the new method, however that involved changing the checking method. That's because `register_backward_hook` is called *after* hooks on submodule, but `register_full_backward_hook` is called before.~ Strikethrough solution didn't work, because hook order execution is weird for complex graphs, e.g. for GANs. For example, if your forward call looks like this: ``` Discriminator(Generator(x)) ``` then top-level module hook will precede submodule's hooks for `Generator`, but not for `Discriminator` As such, I've realised that gradient accumulation is not even supported in `ExpandedWeights`, so we don't have to worry about that. And the other two modes are both hooks-based, so we can just check the accumulation in the existing backward hook, no need for an extra hook. Deleted some code, profit. ## 3. Refactoring `wrap_collate_with_empty` to please pickle Now here're two facts I didn't know before 1) You can't pickle a nested function, e.g. you can't do the following ```python def foo(): def bar(): <...> return bar pickle.dump(foo(), ...) ``` 2) Whether or not `multiprocessing` uses pickle is python- and platform- dependant. This affects our tests when we test `DataLoader` with multiple workers. As such, our data loaders tests: * Pass on CircleCI with python3.9 * Fail on my local machine with python3.9 * Pass on my local machine with python3.7 I'm not sure how cow common the issue is, but it's safer to just refactor `wrap_collate_with_empty` to avoid nested functions. ## 4. Fix benchmark tests We don't really run `benchmarks/tests` on a regular basis, and some of them were broken since we've upgraded to PyTorch 1.13 (`API_CUTOFF_VERSION` doesn't exist anymore) ## 4. Fix flake8 config Flake8 config no [longer support](https://flake8.pycqa.org/en/latest/user/configuration.html) inline comments, fix is due Pull Request resolved: #581 Reviewed By: alexandresablayrolles Differential Revision: D44749760 Pulled By: ffuuugor fbshipit-source-id: cf225f4134c049da4ee2eef53e1af3ef54d090bf
- Loading branch information
1 parent
fc3fd6b
commit e8bc932
Showing
13 changed files
with
277 additions
and
177 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.