-
Notifications
You must be signed in to change notification settings - Fork 213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Not for land] Added changes for GPT-2 perf #533
base: gh/awgu/15/base
Are you sure you want to change the base?
Conversation
[ghstack-poisoned]
ghstack-source-id: 82808b1e55456ddc3df041231d965a5666b5b465 Pull Request resolved: #533
Credit: felipemello1 for most of the work here (especially around chunked cross entropy) Running on 4xH100s: Without these changes (`torch.compile`), the max local batch size is 5: ``` [rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200) [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager. [rank0]: warnings.warn( [rank0]:2024-08-19 11:10:33,811 - root - INFO - step: 1 loss: 12.2365 memory: 81.67GiB(85.93%) wps: 5,380 mfu: 1.09% [rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10 loss: 12.1951 memory: 81.67GiB(85.93%) wps: 111,770 mfu: 22.68% [rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20 loss: 11.9455 memory: 81.67GiB(85.93%) wps: 111,714 mfu: 22.67% [rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30 loss: 11.0407 memory: 81.67GiB(85.93%) wps: 112,194 mfu: 22.76% [rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40 loss: 9.9520 memory: 81.67GiB(85.93%) wps: 112,109 mfu: 22.75% [rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50 loss: 9.3392 memory: 81.67GiB(85.93%) wps: 112,218 mfu: 22.77% [rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60 loss: 8.7255 memory: 81.67GiB(85.93%) wps: 112,198 mfu: 22.77% [rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70 loss: 8.1659 memory: 81.67GiB(85.93%) wps: 112,234 mfu: 22.77% [rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80 loss: 7.8037 memory: 81.67GiB(85.93%) wps: 111,802 mfu: 22.68% [rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90 loss: 7.5327 memory: 81.67GiB(85.93%) wps: 111,937 mfu: 22.71% [rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100 loss: 7.3730 memory: 81.67GiB(85.93%) wps: 111,803 mfu: 22.69% ``` Without these changes (no `torch.compile`), local batch size 5: ``` [rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200) [rank0]:2024-08-19 14:24:38,558 - root - INFO - step: 1 loss: 12.2581 memory: 86.47GiB(90.99%) wps: 6,393 mfu: 1.30% [rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10 loss: 12.2099 memory: 86.48GiB(90.99%) wps: 98,305 mfu: 19.95% [rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20 loss: 11.9421 memory: 86.48GiB(90.99%) wps: 98,230 mfu: 19.93% [rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30 loss: 11.0090 memory: 86.48GiB(90.99%) wps: 98,435 mfu: 19.97% [rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40 loss: 9.9780 memory: 86.48GiB(90.99%) wps: 99,064 mfu: 20.10% [rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50 loss: 9.3572 memory: 86.48GiB(90.99%) wps: 98,813 mfu: 20.05% [rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60 loss: 8.7479 memory: 86.48GiB(90.99%) wps: 96,567 mfu: 19.59% [rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70 loss: 8.1769 memory: 86.48GiB(90.99%) wps: 98,604 mfu: 20.01% [rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80 loss: 7.8070 memory: 86.48GiB(90.99%) wps: 98,579 mfu: 20.00% [rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90 loss: 7.5329 memory: 86.48GiB(90.99%) wps: 98,743 mfu: 20.04% [rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100 loss: 7.3700 memory: 86.48GiB(90.99%) wps: 98,818 mfu: 20.05% ``` With these changes, we can use local batch size 16: ``` [rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200) [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager. [rank0]: warnings.warn( [rank0]:2024-08-19 11:16:15,523 - root - INFO - step: 1 loss: 12.2386 memory: 72.29GiB(76.06%) wps: 21,887 mfu: 4.44% [rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10 loss: 12.1966 memory: 72.30GiB(76.07%) wps: 168,174 mfu: 34.12% [rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20 loss: 11.9229 memory: 72.30GiB(76.07%) wps: 168,196 mfu: 34.13% [rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30 loss: 10.9399 memory: 72.30GiB(76.07%) wps: 168,144 mfu: 34.12% [rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40 loss: 9.8742 memory: 72.30GiB(76.07%) wps: 167,898 mfu: 34.07% [rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50 loss: 9.2517 memory: 72.30GiB(76.07%) wps: 168,130 mfu: 34.11% [rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60 loss: 8.6441 memory: 72.30GiB(76.07%) wps: 168,435 mfu: 34.18% [rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70 loss: 8.0827 memory: 72.30GiB(76.07%) wps: 168,927 mfu: 34.28% [rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80 loss: 7.7330 memory: 72.30GiB(76.07%) wps: 168,772 mfu: 34.24% [rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90 loss: 7.4835 memory: 72.30GiB(76.07%) wps: 162,008 mfu: 32.87% [rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100 loss: 7.3274 memory: 72.30GiB(76.07%) wps: 167,963 mfu: 34.08% ``` 22.7% MFU -> 34.1% MFU [ghstack-poisoned]
ghstack-source-id: af1b1c31ed203910bb6a431296097b2c8fe0534e Pull Request resolved: #533
FYI, compiling loss + model together should yield much better results than compiling the model alone, if this is whats happening. instead of doing: do something like:
What we found is that using torch.compile on the cross entropy loss alone has great memory benefits (but not better than chunked): https://fb.workplace.com/groups/257735836456307/permalink/708422718054281/ But the best results for us is compiling only the model + using the chunked cross entropy. If we compile everything, then the results of chunked cross entropy are lost. |
If I try to compile both the output linear and cross entropy loss together instead of just compiling the cross entropy loss, I get OOMs at the same batch size. |
My uneducated guess is that the optimizations they made for CrossEntropyLoss accounts only for the loss being compiled on its own. Details of their implementation here: https://fb.workplace.com/groups/257735836456307/permalink/708422718054281/ |
Llama3-8B With these changes:
Baseline:
Moving
|
Credit: felipemello1 for most of the work here (especially around chunked cross entropy) Running on 4xH100s: Without these changes (`torch.compile`), the max local batch size is 5: ``` [rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200) [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager. [rank0]: warnings.warn( [rank0]:2024-08-19 11:10:33,811 - root - INFO - step: 1 loss: 12.2365 memory: 81.67GiB(85.93%) wps: 5,380 mfu: 1.09% [rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10 loss: 12.1951 memory: 81.67GiB(85.93%) wps: 111,770 mfu: 22.68% [rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20 loss: 11.9455 memory: 81.67GiB(85.93%) wps: 111,714 mfu: 22.67% [rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30 loss: 11.0407 memory: 81.67GiB(85.93%) wps: 112,194 mfu: 22.76% [rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40 loss: 9.9520 memory: 81.67GiB(85.93%) wps: 112,109 mfu: 22.75% [rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50 loss: 9.3392 memory: 81.67GiB(85.93%) wps: 112,218 mfu: 22.77% [rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60 loss: 8.7255 memory: 81.67GiB(85.93%) wps: 112,198 mfu: 22.77% [rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70 loss: 8.1659 memory: 81.67GiB(85.93%) wps: 112,234 mfu: 22.77% [rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80 loss: 7.8037 memory: 81.67GiB(85.93%) wps: 111,802 mfu: 22.68% [rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90 loss: 7.5327 memory: 81.67GiB(85.93%) wps: 111,937 mfu: 22.71% [rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100 loss: 7.3730 memory: 81.67GiB(85.93%) wps: 111,803 mfu: 22.69% ``` Without these changes (no `torch.compile`), local batch size 5: ``` [rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200) [rank0]:2024-08-19 14:24:38,558 - root - INFO - step: 1 loss: 12.2581 memory: 86.47GiB(90.99%) wps: 6,393 mfu: 1.30% [rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10 loss: 12.2099 memory: 86.48GiB(90.99%) wps: 98,305 mfu: 19.95% [rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20 loss: 11.9421 memory: 86.48GiB(90.99%) wps: 98,230 mfu: 19.93% [rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30 loss: 11.0090 memory: 86.48GiB(90.99%) wps: 98,435 mfu: 19.97% [rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40 loss: 9.9780 memory: 86.48GiB(90.99%) wps: 99,064 mfu: 20.10% [rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50 loss: 9.3572 memory: 86.48GiB(90.99%) wps: 98,813 mfu: 20.05% [rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60 loss: 8.7479 memory: 86.48GiB(90.99%) wps: 96,567 mfu: 19.59% [rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70 loss: 8.1769 memory: 86.48GiB(90.99%) wps: 98,604 mfu: 20.01% [rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80 loss: 7.8070 memory: 86.48GiB(90.99%) wps: 98,579 mfu: 20.00% [rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90 loss: 7.5329 memory: 86.48GiB(90.99%) wps: 98,743 mfu: 20.04% [rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100 loss: 7.3700 memory: 86.48GiB(90.99%) wps: 98,818 mfu: 20.05% ``` With these changes, we can use local batch size 16: ``` [rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200) [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager. [rank0]: warnings.warn( [rank0]:2024-08-19 11:16:15,523 - root - INFO - step: 1 loss: 12.2386 memory: 72.29GiB(76.06%) wps: 21,887 mfu: 4.44% [rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10 loss: 12.1966 memory: 72.30GiB(76.07%) wps: 168,174 mfu: 34.12% [rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20 loss: 11.9229 memory: 72.30GiB(76.07%) wps: 168,196 mfu: 34.13% [rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30 loss: 10.9399 memory: 72.30GiB(76.07%) wps: 168,144 mfu: 34.12% [rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40 loss: 9.8742 memory: 72.30GiB(76.07%) wps: 167,898 mfu: 34.07% [rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50 loss: 9.2517 memory: 72.30GiB(76.07%) wps: 168,130 mfu: 34.11% [rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60 loss: 8.6441 memory: 72.30GiB(76.07%) wps: 168,435 mfu: 34.18% [rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70 loss: 8.0827 memory: 72.30GiB(76.07%) wps: 168,927 mfu: 34.28% [rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80 loss: 7.7330 memory: 72.30GiB(76.07%) wps: 168,772 mfu: 34.24% [rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90 loss: 7.4835 memory: 72.30GiB(76.07%) wps: 162,008 mfu: 32.87% [rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100 loss: 7.3274 memory: 72.30GiB(76.07%) wps: 167,963 mfu: 34.08% ``` 22.7% MFU -> 34.1% MFU [ghstack-poisoned]
ghstack-source-id: 39b91d06c8c1c6398e58a7d8841c4432ba4532c7 Pull Request resolved: #533
Credit: felipemello1 for most of the work here (especially around chunked cross entropy) Running on 4xH100s: Without these changes (`torch.compile`), the max local batch size is 5: ``` [rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200) [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager. [rank0]: warnings.warn( [rank0]:2024-08-19 11:10:33,811 - root - INFO - step: 1 loss: 12.2365 memory: 81.67GiB(85.93%) wps: 5,380 mfu: 1.09% [rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10 loss: 12.1951 memory: 81.67GiB(85.93%) wps: 111,770 mfu: 22.68% [rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20 loss: 11.9455 memory: 81.67GiB(85.93%) wps: 111,714 mfu: 22.67% [rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30 loss: 11.0407 memory: 81.67GiB(85.93%) wps: 112,194 mfu: 22.76% [rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40 loss: 9.9520 memory: 81.67GiB(85.93%) wps: 112,109 mfu: 22.75% [rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50 loss: 9.3392 memory: 81.67GiB(85.93%) wps: 112,218 mfu: 22.77% [rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60 loss: 8.7255 memory: 81.67GiB(85.93%) wps: 112,198 mfu: 22.77% [rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70 loss: 8.1659 memory: 81.67GiB(85.93%) wps: 112,234 mfu: 22.77% [rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80 loss: 7.8037 memory: 81.67GiB(85.93%) wps: 111,802 mfu: 22.68% [rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90 loss: 7.5327 memory: 81.67GiB(85.93%) wps: 111,937 mfu: 22.71% [rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100 loss: 7.3730 memory: 81.67GiB(85.93%) wps: 111,803 mfu: 22.69% ``` <details> <summary> Without these changes, no compile </summary> Without these changes (no `torch.compile`), local batch size 5: ``` [rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200) [rank0]:2024-08-19 14:24:38,558 - root - INFO - step: 1 loss: 12.2581 memory: 86.47GiB(90.99%) wps: 6,393 mfu: 1.30% [rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10 loss: 12.2099 memory: 86.48GiB(90.99%) wps: 98,305 mfu: 19.95% [rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20 loss: 11.9421 memory: 86.48GiB(90.99%) wps: 98,230 mfu: 19.93% [rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30 loss: 11.0090 memory: 86.48GiB(90.99%) wps: 98,435 mfu: 19.97% [rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40 loss: 9.9780 memory: 86.48GiB(90.99%) wps: 99,064 mfu: 20.10% [rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50 loss: 9.3572 memory: 86.48GiB(90.99%) wps: 98,813 mfu: 20.05% [rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60 loss: 8.7479 memory: 86.48GiB(90.99%) wps: 96,567 mfu: 19.59% [rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70 loss: 8.1769 memory: 86.48GiB(90.99%) wps: 98,604 mfu: 20.01% [rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80 loss: 7.8070 memory: 86.48GiB(90.99%) wps: 98,579 mfu: 20.00% [rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90 loss: 7.5329 memory: 86.48GiB(90.99%) wps: 98,743 mfu: 20.04% [rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100 loss: 7.3700 memory: 86.48GiB(90.99%) wps: 98,818 mfu: 20.05% ``` </details> With these changes (`torch.compile`), local batch size 32: ``` [rank0]:2024-09-06 19:48:58,342 - root - INFO - Training starts at step 1, with local batch size 32, global batch size 128, sequence length 8192, total steps 50 (warmup 200) [rank0]:2024-09-06 19:49:08,904 - root - INFO - step: 1 loss: 12.2442 memory: 79.40GiB(83.54%) wps: 24,819 mfu: 5.04% [rank0]:2024-09-06 19:49:08,904 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-09-06 19:49:23,127 - root - INFO - step: 10 loss: 12.1998 memory: 80.81GiB(85.03%) wps: 165,880 mfu: 33.66% [rank0]:2024-09-06 19:49:38,946 - root - INFO - step: 20 loss: 11.9284 memory: 80.81GiB(85.03%) wps: 165,732 mfu: 33.63% [rank0]:2024-09-06 19:49:54,764 - root - INFO - step: 30 loss: 10.9587 memory: 80.81GiB(85.03%) wps: 165,733 mfu: 33.63% [rank0]:2024-09-06 19:50:10,566 - root - INFO - step: 40 loss: 9.8493 memory: 80.81GiB(85.03%) wps: 165,904 mfu: 33.66% [rank0]:2024-09-06 19:50:26,973 - root - INFO - step: 50 loss: 9.2317 memory: 80.81GiB(85.03%) wps: 159,786 mfu: 32.42% ``` <details> <summary> Old Results </summary> With these changes, we can use local batch size 16: ``` [rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200) [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager. [rank0]: warnings.warn( [rank0]:2024-08-19 11:16:15,523 - root - INFO - step: 1 loss: 12.2386 memory: 72.29GiB(76.06%) wps: 21,887 mfu: 4.44% [rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10 loss: 12.1966 memory: 72.30GiB(76.07%) wps: 168,174 mfu: 34.12% [rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20 loss: 11.9229 memory: 72.30GiB(76.07%) wps: 168,196 mfu: 34.13% [rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30 loss: 10.9399 memory: 72.30GiB(76.07%) wps: 168,144 mfu: 34.12% [rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40 loss: 9.8742 memory: 72.30GiB(76.07%) wps: 167,898 mfu: 34.07% [rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50 loss: 9.2517 memory: 72.30GiB(76.07%) wps: 168,130 mfu: 34.11% [rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60 loss: 8.6441 memory: 72.30GiB(76.07%) wps: 168,435 mfu: 34.18% [rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70 loss: 8.0827 memory: 72.30GiB(76.07%) wps: 168,927 mfu: 34.28% [rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80 loss: 7.7330 memory: 72.30GiB(76.07%) wps: 168,772 mfu: 34.24% [rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90 loss: 7.4835 memory: 72.30GiB(76.07%) wps: 162,008 mfu: 32.87% [rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100 loss: 7.3274 memory: 72.30GiB(76.07%) wps: 167,963 mfu: 34.08% ``` 22.7% MFU -> 34.1% MFU </details> [ghstack-poisoned]
ghstack-source-id: 0cdcc964f2012f1b0c00e3eeba7eaca14e768629 Pull Request resolved: #533
Credit: felipemello1 for the previous token chunked cross entropy Credit: Chillee for the new token chunked cross entropy Running on 4xH100s: Without these changes (`torch.compile`), the max local batch size is 5: ``` [rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200) [rank0]:2024-08-19 11:10:33,811 - root - INFO - step: 1 loss: 12.2365 memory: 81.67GiB(85.93%) wps: 5,380 mfu: 1.09% [rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10 loss: 12.1951 memory: 81.67GiB(85.93%) wps: 111,770 mfu: 22.68% [rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20 loss: 11.9455 memory: 81.67GiB(85.93%) wps: 111,714 mfu: 22.67% [rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30 loss: 11.0407 memory: 81.67GiB(85.93%) wps: 112,194 mfu: 22.76% [rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40 loss: 9.9520 memory: 81.67GiB(85.93%) wps: 112,109 mfu: 22.75% [rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50 loss: 9.3392 memory: 81.67GiB(85.93%) wps: 112,218 mfu: 22.77% [rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60 loss: 8.7255 memory: 81.67GiB(85.93%) wps: 112,198 mfu: 22.77% [rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70 loss: 8.1659 memory: 81.67GiB(85.93%) wps: 112,234 mfu: 22.77% [rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80 loss: 7.8037 memory: 81.67GiB(85.93%) wps: 111,802 mfu: 22.68% [rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90 loss: 7.5327 memory: 81.67GiB(85.93%) wps: 111,937 mfu: 22.71% [rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100 loss: 7.3730 memory: 81.67GiB(85.93%) wps: 111,803 mfu: 22.69% ``` <details> <summary> Without these changes, no compile </summary> Without these changes (no `torch.compile`), local batch size 5: ``` [rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200) [rank0]:2024-08-19 14:24:38,558 - root - INFO - step: 1 loss: 12.2581 memory: 86.47GiB(90.99%) wps: 6,393 mfu: 1.30% [rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10 loss: 12.2099 memory: 86.48GiB(90.99%) wps: 98,305 mfu: 19.95% [rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20 loss: 11.9421 memory: 86.48GiB(90.99%) wps: 98,230 mfu: 19.93% [rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30 loss: 11.0090 memory: 86.48GiB(90.99%) wps: 98,435 mfu: 19.97% [rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40 loss: 9.9780 memory: 86.48GiB(90.99%) wps: 99,064 mfu: 20.10% [rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50 loss: 9.3572 memory: 86.48GiB(90.99%) wps: 98,813 mfu: 20.05% [rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60 loss: 8.7479 memory: 86.48GiB(90.99%) wps: 96,567 mfu: 19.59% [rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70 loss: 8.1769 memory: 86.48GiB(90.99%) wps: 98,604 mfu: 20.01% [rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80 loss: 7.8070 memory: 86.48GiB(90.99%) wps: 98,579 mfu: 20.00% [rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90 loss: 7.5329 memory: 86.48GiB(90.99%) wps: 98,743 mfu: 20.04% [rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100 loss: 7.3700 memory: 86.48GiB(90.99%) wps: 98,818 mfu: 20.05% ``` </details> With these changes (`torch.compile`), local batch size 32: ``` [rank0]:2024-09-06 19:48:58,342 - root - INFO - Training starts at step 1, with local batch size 32, global batch size 128, sequence length 8192, total steps 50 (warmup 200) [rank0]:2024-09-06 19:49:08,904 - root - INFO - step: 1 loss: 12.2442 memory: 79.40GiB(83.54%) wps: 24,819 mfu: 5.04% [rank0]:2024-09-06 19:49:08,904 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-09-06 19:49:23,127 - root - INFO - step: 10 loss: 12.1998 memory: 80.81GiB(85.03%) wps: 165,880 mfu: 33.66% [rank0]:2024-09-06 19:49:38,946 - root - INFO - step: 20 loss: 11.9284 memory: 80.81GiB(85.03%) wps: 165,732 mfu: 33.63% [rank0]:2024-09-06 19:49:54,764 - root - INFO - step: 30 loss: 10.9587 memory: 80.81GiB(85.03%) wps: 165,733 mfu: 33.63% [rank0]:2024-09-06 19:50:10,566 - root - INFO - step: 40 loss: 9.8493 memory: 80.81GiB(85.03%) wps: 165,904 mfu: 33.66% [rank0]:2024-09-06 19:50:26,973 - root - INFO - step: 50 loss: 9.2317 memory: 80.81GiB(85.03%) wps: 159,786 mfu: 32.42% ``` <details> <summary> Old Results </summary> With these changes, we can use local batch size 16: ``` [rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200) [rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager. [rank0]: warnings.warn( [rank0]:2024-08-19 11:16:15,523 - root - INFO - step: 1 loss: 12.2386 memory: 72.29GiB(76.06%) wps: 21,887 mfu: 4.44% [rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10 loss: 12.1966 memory: 72.30GiB(76.07%) wps: 168,174 mfu: 34.12% [rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20 loss: 11.9229 memory: 72.30GiB(76.07%) wps: 168,196 mfu: 34.13% [rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30 loss: 10.9399 memory: 72.30GiB(76.07%) wps: 168,144 mfu: 34.12% [rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40 loss: 9.8742 memory: 72.30GiB(76.07%) wps: 167,898 mfu: 34.07% [rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50 loss: 9.2517 memory: 72.30GiB(76.07%) wps: 168,130 mfu: 34.11% [rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60 loss: 8.6441 memory: 72.30GiB(76.07%) wps: 168,435 mfu: 34.18% [rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70 loss: 8.0827 memory: 72.30GiB(76.07%) wps: 168,927 mfu: 34.28% [rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80 loss: 7.7330 memory: 72.30GiB(76.07%) wps: 168,772 mfu: 34.24% [rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90 loss: 7.4835 memory: 72.30GiB(76.07%) wps: 162,008 mfu: 32.87% [rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100 loss: 7.3274 memory: 72.30GiB(76.07%) wps: 167,963 mfu: 34.08% ``` 22.7% MFU -> 34.1% MFU </details> [ghstack-poisoned]
ghstack-source-id: ddfb8a972f0332ca9c7bd7ca6072b02df4e1792c Pull Request resolved: #533
Stack from ghstack (oldest at bottom):
Credit: @felipemello1 for the previous token chunked cross entropy
Credit: @Chillee for the new token chunked cross entropy
Running on 4xH100s:
Without these changes (
torch.compile
), the max local batch size is 5:Without these changes, no compile
Without these changes (no
torch.compile
), local batch size 5:With these changes (
torch.compile
), local batch size 32:Old Results
With these changes, we can use local batch size 16:
22.7% MFU -> 34.1% MFU