[Not for land] Added changes for GPT-2 perf #533
Draft · awgu wants to merge 5 commits into gh/awgu/15/base from gh/awgu/15/head
Commits on Aug 19, 2024
- 6ad9afa
Update on "[Not for land] Added changes for GPT-2 perf"
Credit: felipemello1 for most of the work here (especially around chunked cross entropy)

Running on 4xH100s:

Without these changes (`torch.compile`), the max local batch size is 5:

```
[rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:2024-08-19 11:10:33,811 - root - INFO - step: 1 loss: 12.2365 memory: 81.67GiB(85.93%) wps: 5,380 mfu: 1.09%
[rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10 loss: 12.1951 memory: 81.67GiB(85.93%) wps: 111,770 mfu: 22.68%
[rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20 loss: 11.9455 memory: 81.67GiB(85.93%) wps: 111,714 mfu: 22.67%
[rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30 loss: 11.0407 memory: 81.67GiB(85.93%) wps: 112,194 mfu: 22.76%
[rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40 loss: 9.9520 memory: 81.67GiB(85.93%) wps: 112,109 mfu: 22.75%
[rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50 loss: 9.3392 memory: 81.67GiB(85.93%) wps: 112,218 mfu: 22.77%
[rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60 loss: 8.7255 memory: 81.67GiB(85.93%) wps: 112,198 mfu: 22.77%
[rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70 loss: 8.1659 memory: 81.67GiB(85.93%) wps: 112,234 mfu: 22.77%
[rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80 loss: 7.8037 memory: 81.67GiB(85.93%) wps: 111,802 mfu: 22.68%
[rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90 loss: 7.5327 memory: 81.67GiB(85.93%) wps: 111,937 mfu: 22.71%
[rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100 loss: 7.3730 memory: 81.67GiB(85.93%) wps: 111,803 mfu: 22.69%
```

Without these changes (no `torch.compile`), local batch size 5:

```
[rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:2024-08-19 14:24:38,558 - root - INFO - step: 1 loss: 12.2581 memory: 86.47GiB(90.99%) wps: 6,393 mfu: 1.30%
[rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10 loss: 12.2099 memory: 86.48GiB(90.99%) wps: 98,305 mfu: 19.95%
[rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20 loss: 11.9421 memory: 86.48GiB(90.99%) wps: 98,230 mfu: 19.93%
[rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30 loss: 11.0090 memory: 86.48GiB(90.99%) wps: 98,435 mfu: 19.97%
[rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40 loss: 9.9780 memory: 86.48GiB(90.99%) wps: 99,064 mfu: 20.10%
[rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50 loss: 9.3572 memory: 86.48GiB(90.99%) wps: 98,813 mfu: 20.05%
[rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60 loss: 8.7479 memory: 86.48GiB(90.99%) wps: 96,567 mfu: 19.59%
[rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70 loss: 8.1769 memory: 86.48GiB(90.99%) wps: 98,604 mfu: 20.01%
[rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80 loss: 7.8070 memory: 86.48GiB(90.99%) wps: 98,579 mfu: 20.00%
[rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90 loss: 7.5329 memory: 86.48GiB(90.99%) wps: 98,743 mfu: 20.04%
[rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100 loss: 7.3700 memory: 86.48GiB(90.99%) wps: 98,818 mfu: 20.05%
```

With these changes, we can use local batch size 16:

```
[rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:2024-08-19 11:16:15,523 - root - INFO - step: 1 loss: 12.2386 memory: 72.29GiB(76.06%) wps: 21,887 mfu: 4.44%
[rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10 loss: 12.1966 memory: 72.30GiB(76.07%) wps: 168,174 mfu: 34.12%
[rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20 loss: 11.9229 memory: 72.30GiB(76.07%) wps: 168,196 mfu: 34.13%
[rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30 loss: 10.9399 memory: 72.30GiB(76.07%) wps: 168,144 mfu: 34.12%
[rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40 loss: 9.8742 memory: 72.30GiB(76.07%) wps: 167,898 mfu: 34.07%
[rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50 loss: 9.2517 memory: 72.30GiB(76.07%) wps: 168,130 mfu: 34.11%
[rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60 loss: 8.6441 memory: 72.30GiB(76.07%) wps: 168,435 mfu: 34.18%
[rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70 loss: 8.0827 memory: 72.30GiB(76.07%) wps: 168,927 mfu: 34.28%
[rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80 loss: 7.7330 memory: 72.30GiB(76.07%) wps: 168,772 mfu: 34.24%
[rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90 loss: 7.4835 memory: 72.30GiB(76.07%) wps: 162,008 mfu: 32.87%
[rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100 loss: 7.3274 memory: 72.30GiB(76.07%) wps: 167,963 mfu: 34.08%
```

22.7% MFU -> 34.1% MFU

[ghstack-poisoned]
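The message above credits chunked cross entropy for much of the memory savings. The PR's actual implementation is not visible on this page, so the following is only a minimal sketch of the general technique, with all names and shapes chosen for illustration: rather than materializing the full `[num_tokens, vocab_size]` logits tensor (and its float32 upcast) at once, the output projection and the loss are computed one chunk of tokens at a time.

```python
import torch
import torch.nn.functional as F

def chunked_cross_entropy(
    hidden: torch.Tensor,    # [num_tokens, hidden_dim] final hidden states
    out_proj: torch.Tensor,  # [vocab_size, hidden_dim] output projection weight
    labels: torch.Tensor,    # [num_tokens] target token ids
    num_chunks: int = 8,     # illustrative default, not taken from this PR
) -> torch.Tensor:
    # Only one [chunk, vocab_size] logits slice (plus its fp32 upcast inside
    # cross_entropy) is live per iteration, instead of the full logits tensor.
    loss_sum = hidden.new_zeros((), dtype=torch.float32)
    for h, y in zip(hidden.chunk(num_chunks), labels.chunk(num_chunks)):
        logits = h @ out_proj.t()  # [chunk, vocab_size]
        loss_sum = loss_sum + F.cross_entropy(logits.float(), y, reduction="sum")
    return loss_sum / labels.numel()
```

How much peak memory this saves in backward depends on what autograd keeps around per chunk; note that all of the "with these changes" runs above also have `torch.compile` enabled, which can fuse the per-chunk projection and log-softmax.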
Commits on Sep 7, 2024
- b6a84e4: Update on "[Not for land] Added changes for GPT-2 perf"
Update on "[Not for land] Added changes for GPT-2 perf"
Credit: felipemello1 for most of the work here (especially around chunked cross entropy)

Running on 4xH100s:

Without these changes (`torch.compile`), the max local batch size is 5:

```
[rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:2024-08-19 11:10:33,811 - root - INFO - step: 1 loss: 12.2365 memory: 81.67GiB(85.93%) wps: 5,380 mfu: 1.09%
[rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10 loss: 12.1951 memory: 81.67GiB(85.93%) wps: 111,770 mfu: 22.68%
[rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20 loss: 11.9455 memory: 81.67GiB(85.93%) wps: 111,714 mfu: 22.67%
[rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30 loss: 11.0407 memory: 81.67GiB(85.93%) wps: 112,194 mfu: 22.76%
[rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40 loss: 9.9520 memory: 81.67GiB(85.93%) wps: 112,109 mfu: 22.75%
[rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50 loss: 9.3392 memory: 81.67GiB(85.93%) wps: 112,218 mfu: 22.77%
[rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60 loss: 8.7255 memory: 81.67GiB(85.93%) wps: 112,198 mfu: 22.77%
[rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70 loss: 8.1659 memory: 81.67GiB(85.93%) wps: 112,234 mfu: 22.77%
[rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80 loss: 7.8037 memory: 81.67GiB(85.93%) wps: 111,802 mfu: 22.68%
[rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90 loss: 7.5327 memory: 81.67GiB(85.93%) wps: 111,937 mfu: 22.71%
[rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100 loss: 7.3730 memory: 81.67GiB(85.93%) wps: 111,803 mfu: 22.69%
```

<details>
<summary> Without these changes, no compile </summary>

Without these changes (no `torch.compile`), local batch size 5:

```
[rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:2024-08-19 14:24:38,558 - root - INFO - step: 1 loss: 12.2581 memory: 86.47GiB(90.99%) wps: 6,393 mfu: 1.30%
[rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10 loss: 12.2099 memory: 86.48GiB(90.99%) wps: 98,305 mfu: 19.95%
[rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20 loss: 11.9421 memory: 86.48GiB(90.99%) wps: 98,230 mfu: 19.93%
[rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30 loss: 11.0090 memory: 86.48GiB(90.99%) wps: 98,435 mfu: 19.97%
[rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40 loss: 9.9780 memory: 86.48GiB(90.99%) wps: 99,064 mfu: 20.10%
[rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50 loss: 9.3572 memory: 86.48GiB(90.99%) wps: 98,813 mfu: 20.05%
[rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60 loss: 8.7479 memory: 86.48GiB(90.99%) wps: 96,567 mfu: 19.59%
[rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70 loss: 8.1769 memory: 86.48GiB(90.99%) wps: 98,604 mfu: 20.01%
[rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80 loss: 7.8070 memory: 86.48GiB(90.99%) wps: 98,579 mfu: 20.00%
[rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90 loss: 7.5329 memory: 86.48GiB(90.99%) wps: 98,743 mfu: 20.04%
[rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100 loss: 7.3700 memory: 86.48GiB(90.99%) wps: 98,818 mfu: 20.05%
```

</details>

With these changes (`torch.compile`), local batch size 32:

```
[rank0]:2024-09-06 19:48:58,342 - root - INFO - Training starts at step 1, with local batch size 32, global batch size 128, sequence length 8192, total steps 50 (warmup 200)
[rank0]:2024-09-06 19:49:08,904 - root - INFO - step: 1 loss: 12.2442 memory: 79.40GiB(83.54%) wps: 24,819 mfu: 5.04%
[rank0]:2024-09-06 19:49:08,904 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-09-06 19:49:23,127 - root - INFO - step: 10 loss: 12.1998 memory: 80.81GiB(85.03%) wps: 165,880 mfu: 33.66%
[rank0]:2024-09-06 19:49:38,946 - root - INFO - step: 20 loss: 11.9284 memory: 80.81GiB(85.03%) wps: 165,732 mfu: 33.63%
[rank0]:2024-09-06 19:49:54,764 - root - INFO - step: 30 loss: 10.9587 memory: 80.81GiB(85.03%) wps: 165,733 mfu: 33.63%
[rank0]:2024-09-06 19:50:10,566 - root - INFO - step: 40 loss: 9.8493 memory: 80.81GiB(85.03%) wps: 165,904 mfu: 33.66%
[rank0]:2024-09-06 19:50:26,973 - root - INFO - step: 50 loss: 9.2317 memory: 80.81GiB(85.03%) wps: 159,786 mfu: 32.42%
```

<details>
<summary> Old Results </summary>

With these changes, we can use local batch size 16:

```
[rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:2024-08-19 11:16:15,523 - root - INFO - step: 1 loss: 12.2386 memory: 72.29GiB(76.06%) wps: 21,887 mfu: 4.44%
[rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10 loss: 12.1966 memory: 72.30GiB(76.07%) wps: 168,174 mfu: 34.12%
[rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20 loss: 11.9229 memory: 72.30GiB(76.07%) wps: 168,196 mfu: 34.13%
[rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30 loss: 10.9399 memory: 72.30GiB(76.07%) wps: 168,144 mfu: 34.12%
[rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40 loss: 9.8742 memory: 72.30GiB(76.07%) wps: 167,898 mfu: 34.07%
[rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50 loss: 9.2517 memory: 72.30GiB(76.07%) wps: 168,130 mfu: 34.11%
[rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60 loss: 8.6441 memory: 72.30GiB(76.07%) wps: 168,435 mfu: 34.18%
[rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70 loss: 8.0827 memory: 72.30GiB(76.07%) wps: 168,927 mfu: 34.28%
[rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80 loss: 7.7330 memory: 72.30GiB(76.07%) wps: 168,772 mfu: 34.24%
[rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90 loss: 7.4835 memory: 72.30GiB(76.07%) wps: 162,008 mfu: 32.87%
[rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100 loss: 7.3274 memory: 72.30GiB(76.07%) wps: 167,963 mfu: 34.08%
```

22.7% MFU -> 34.1% MFU

</details>

[ghstack-poisoned]
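As a sanity check on the metrics above, the logged wps is consistent with per-rank tokens per second: local batch size times sequence length, divided by the step time recoverable from the timestamps. A back-of-envelope check using the local-batch-size-32 run:

```python
# Steps 10 -> 20 of the batch-size-32 run span 19:49:23,127 -> 19:49:38,946,
# i.e. 15.819 s for 10 steps.
tokens_per_step = 32 * 8192            # local batch size * sequence length = 262,144
step_time = (38.946 - 23.127) / 10     # ~1.582 s per step
print(tokens_per_step / step_time)     # ~165,715, matching the logged wps of 165,732
```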
Update on "[Not for land] Added changes for GPT-2 perf"
Credit: felipemello1 for the previous token chunked cross entropy
Credit: Chillee for the new token chunked cross entropy

Running on 4xH100s:

Without these changes (`torch.compile`), the max local batch size is 5:

```
[rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:2024-08-19 11:10:33,811 - root - INFO - step: 1 loss: 12.2365 memory: 81.67GiB(85.93%) wps: 5,380 mfu: 1.09%
[rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10 loss: 12.1951 memory: 81.67GiB(85.93%) wps: 111,770 mfu: 22.68%
[rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20 loss: 11.9455 memory: 81.67GiB(85.93%) wps: 111,714 mfu: 22.67%
[rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30 loss: 11.0407 memory: 81.67GiB(85.93%) wps: 112,194 mfu: 22.76%
[rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40 loss: 9.9520 memory: 81.67GiB(85.93%) wps: 112,109 mfu: 22.75%
[rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50 loss: 9.3392 memory: 81.67GiB(85.93%) wps: 112,218 mfu: 22.77%
[rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60 loss: 8.7255 memory: 81.67GiB(85.93%) wps: 112,198 mfu: 22.77%
[rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70 loss: 8.1659 memory: 81.67GiB(85.93%) wps: 112,234 mfu: 22.77%
[rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80 loss: 7.8037 memory: 81.67GiB(85.93%) wps: 111,802 mfu: 22.68%
[rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90 loss: 7.5327 memory: 81.67GiB(85.93%) wps: 111,937 mfu: 22.71%
[rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100 loss: 7.3730 memory: 81.67GiB(85.93%) wps: 111,803 mfu: 22.69%
```

<details>
<summary> Without these changes, no compile </summary>

Without these changes (no `torch.compile`), local batch size 5:

```
[rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:2024-08-19 14:24:38,558 - root - INFO - step: 1 loss: 12.2581 memory: 86.47GiB(90.99%) wps: 6,393 mfu: 1.30%
[rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10 loss: 12.2099 memory: 86.48GiB(90.99%) wps: 98,305 mfu: 19.95%
[rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20 loss: 11.9421 memory: 86.48GiB(90.99%) wps: 98,230 mfu: 19.93%
[rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30 loss: 11.0090 memory: 86.48GiB(90.99%) wps: 98,435 mfu: 19.97%
[rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40 loss: 9.9780 memory: 86.48GiB(90.99%) wps: 99,064 mfu: 20.10%
[rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50 loss: 9.3572 memory: 86.48GiB(90.99%) wps: 98,813 mfu: 20.05%
[rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60 loss: 8.7479 memory: 86.48GiB(90.99%) wps: 96,567 mfu: 19.59%
[rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70 loss: 8.1769 memory: 86.48GiB(90.99%) wps: 98,604 mfu: 20.01%
[rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80 loss: 7.8070 memory: 86.48GiB(90.99%) wps: 98,579 mfu: 20.00%
[rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90 loss: 7.5329 memory: 86.48GiB(90.99%) wps: 98,743 mfu: 20.04%
[rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100 loss: 7.3700 memory: 86.48GiB(90.99%) wps: 98,818 mfu: 20.05%
```

</details>

With these changes (`torch.compile`), local batch size 32:

```
[rank0]:2024-09-06 19:48:58,342 - root - INFO - Training starts at step 1, with local batch size 32, global batch size 128, sequence length 8192, total steps 50 (warmup 200)
[rank0]:2024-09-06 19:49:08,904 - root - INFO - step: 1 loss: 12.2442 memory: 79.40GiB(83.54%) wps: 24,819 mfu: 5.04%
[rank0]:2024-09-06 19:49:08,904 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-09-06 19:49:23,127 - root - INFO - step: 10 loss: 12.1998 memory: 80.81GiB(85.03%) wps: 165,880 mfu: 33.66%
[rank0]:2024-09-06 19:49:38,946 - root - INFO - step: 20 loss: 11.9284 memory: 80.81GiB(85.03%) wps: 165,732 mfu: 33.63%
[rank0]:2024-09-06 19:49:54,764 - root - INFO - step: 30 loss: 10.9587 memory: 80.81GiB(85.03%) wps: 165,733 mfu: 33.63%
[rank0]:2024-09-06 19:50:10,566 - root - INFO - step: 40 loss: 9.8493 memory: 80.81GiB(85.03%) wps: 165,904 mfu: 33.66%
[rank0]:2024-09-06 19:50:26,973 - root - INFO - step: 50 loss: 9.2317 memory: 80.81GiB(85.03%) wps: 159,786 mfu: 32.42%
```

<details>
<summary> Old Results </summary>

With these changes, we can use local batch size 16:

```
[rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:2024-08-19 11:16:15,523 - root - INFO - step: 1 loss: 12.2386 memory: 72.29GiB(76.06%) wps: 21,887 mfu: 4.44%
[rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10 loss: 12.1966 memory: 72.30GiB(76.07%) wps: 168,174 mfu: 34.12%
[rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20 loss: 11.9229 memory: 72.30GiB(76.07%) wps: 168,196 mfu: 34.13%
[rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30 loss: 10.9399 memory: 72.30GiB(76.07%) wps: 168,144 mfu: 34.12%
[rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40 loss: 9.8742 memory: 72.30GiB(76.07%) wps: 167,898 mfu: 34.07%
[rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50 loss: 9.2517 memory: 72.30GiB(76.07%) wps: 168,130 mfu: 34.11%
[rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60 loss: 8.6441 memory: 72.30GiB(76.07%) wps: 168,435 mfu: 34.18%
[rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70 loss: 8.0827 memory: 72.30GiB(76.07%) wps: 168,927 mfu: 34.28%
[rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80 loss: 7.7330 memory: 72.30GiB(76.07%) wps: 168,772 mfu: 34.24%
[rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90 loss: 7.4835 memory: 72.30GiB(76.07%) wps: 162,008 mfu: 32.87%
[rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100 loss: 7.3274 memory: 72.30GiB(76.07%) wps: 167,963 mfu: 34.08%
```

22.7% MFU -> 34.1% MFU

</details>

[ghstack-poisoned]
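The mfu column is model FLOPS utilization. The exact formula this trainer uses is not shown on this page; a common accounting (PaLM-style) estimates FLOPs per token as roughly 6*N for the dense matmuls plus an attention term, then divides achieved FLOP/s by the accelerator's peak. A sketch under those assumptions, where the ~989 TFLOP/s H100 dense BF16 peak and any model-shape inputs are assumptions, not values from this PR:

```python
def flops_per_token(num_params: float, num_layers: int, hidden_dim: int, seq_len: int) -> float:
    # ~6*N covers forward+backward dense matmuls; 12*L*h*T approximates attention.
    return 6 * num_params + 12 * num_layers * hidden_dim * seq_len

def mfu(wps: float, fpt: float, peak_flops: float = 989e12) -> float:
    # wps in these logs behaves like per-rank tokens/second, so this is per-GPU MFU.
    return wps * fpt / peak_flops
```

Since the hardware and model are fixed across runs, the headline 22.7% -> 34.1% MFU gain is exactly the ~1.5x gain in wps (about 112,000 -> 168,000 tokens/second per rank).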