forked from borisdayma/dalle-mini
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
1742 lines (1578 loc) · 63.8 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021-2022 The HuggingFace & DALL·E Mini team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training DALL·E Mini.
Script adapted from run_summarization_flax.py
"""
import io
import logging
import os
import sys
import tempfile
import time
from dataclasses import asdict, dataclass, field
from functools import partial
from pathlib import Path
from typing import Any, Callable, NamedTuple, Optional
import datasets
import flax
import jax
import jax.numpy as jnp
import jaxlib
import numpy as np
import optax
import transformers
import wandb
from datasets import Dataset
from flax import core, struct, traverse_util
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.training.common_utils import onehot
from jax.experimental import PartitionSpec, maps
from jax.experimental.compilation_cache import compilation_cache as cc
from jax.experimental.pjit import pjit, with_sharding_constraint
from scalable_shampoo.distributed_shampoo import GraftingType, distributed_shampoo
from tqdm import tqdm
from transformers import HfArgumentParser
import dalle_mini
from dalle_mini.data import Dataset
from dalle_mini.model import (
DalleBart,
DalleBartConfig,
DalleBartTokenizer,
set_partitions,
)
# google-cloud-storage is an optional dependency: `storage` stays None when the
# package is not installed, and code paths that need it check for None first.
try:
    from google.cloud import storage
except ImportError:
    # only a missing package is tolerated; a bare `except:` would also
    # swallow KeyboardInterrupt / SystemExit raised during the import
    storage = None
# module-level logger; main() sets its level depending on the process index
logger = logging.getLogger(__name__)
# initialize the JAX compilation cache at import time, using "jax_cache" as its
# location (a relative path, so presumably resolved against the working directory)
cc.initialize_cache("jax_cache")
@dataclass
class ModelArguments:
    """
    Options selecting the model, config and tokenizer to fine-tune or to train from scratch,
    plus helpers to retrieve the metadata and optimizer state attached to a checkpoint.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch. "
            "W&B artifact references are supported in addition to the sources supported by `PreTrainedModel`."
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name_or_path"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
        },
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
        },
    )
    restore_state: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Restore optimizer and training state. Can be True (will retrieve associated wandb artifact), a local directory or a Google bucket path."
        },
    )
    dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Dropout rate. Overwrites config."},
    )
    activation_dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Activation dropout rate. Overwrites config."},
    )
    attention_dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Attention dropout rate. Overwrites config."},
    )

    def __post_init__(self):
        """Default the tokenizer to the model path and validate a restore request."""
        # without an explicit tokenizer, reuse the model checkpoint reference
        self.tokenizer_name = (
            self.model_name_or_path
            if self.tokenizer_name is None
            else self.tokenizer_name
        )
        assert (
            self.tokenizer_name is not None
        ), "Tokenizer name or model name/path needs to be specified"

        if not self.restore_state:
            return
        model_ref = self.model_name_or_path
        assert (
            model_ref is not None and "/model-" in model_ref
        ), "Restoring state only available with W&B artifact reference"

    def get_metadata(self):
        """Return the metadata of the W&B artifact the model points to.

        An empty dict is returned when no model is set or when its reference
        contains no ":" (i.e. it is not treated as an artifact reference).
        """
        reference = self.model_name_or_path
        if reference is None or ":" not in reference:
            return {}
        # process 0 registers the artifact as an input of the active run,
        # the other processes only look it up through the public API
        if jax.process_index() == 0:
            artifact = wandb.run.use_artifact(reference)
        else:
            artifact = wandb.Api().artifact(reference)
        return artifact.metadata

    def get_opt_state(self):
        """Return the content of `opt_state.msgpack` as bytes.

        When `restore_state` is True it is first resolved (and overwritten on
        the instance) from the "state-" artifact matching the model artifact.
        A location starting with "gs://" is read from a Google bucket, any
        other location is opened as a local file.
        """
        # download target for the artifact, avoids multiple artifact copies
        with tempfile.TemporaryDirectory() as tmp_dir:
            if self.restore_state is True:
                # wandb artifact: same reference as the model with "state-" replacing "model-"
                state_ref = self.model_name_or_path.replace("/model-", "/state-", 1)
                artifact = (
                    wandb.run.use_artifact(state_ref)
                    if jax.process_index() == 0
                    else wandb.Api().artifact(state_ref)
                )
                if artifact.metadata.get("bucket_path"):
                    # file contents will be read directly from the bucket
                    self.restore_state = artifact.metadata["bucket_path"]
                else:
                    download_dir = artifact.download(tmp_dir)
                    self.restore_state = str(Path(download_dir) / "opt_state.msgpack")

            if not self.restore_state.startswith("gs://"):
                # local file, read before the temporary directory is removed
                with Path(self.restore_state).open("rb") as state_file:
                    return state_file.read()

            # "gs://<bucket>/<prefix>" -> bucket name + blob name of the state file
            gcs_object = Path(self.restore_state[5:]) / "opt_state.msgpack"
            bucket_name, blob_name = str(gcs_object).split("/", 1)
            assert (
                storage is not None
            ), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
            blob = storage.Client().bucket(bucket_name).blob(blob_name)
            return blob.download_as_bytes()
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Raises a ValueError at construction time when `dataset_repo_or_path` is not provided.
    """

    text_column: Optional[str] = field(
        default="caption",
        metadata={
            "help": "The name of the column in the datasets containing the full texts (for summarization)."
        },
    )
    encoding_column: Optional[str] = field(
        default="encoding",
        metadata={
            "help": "The name of the column in the datasets containing the image encodings."
        },
    )
    # no usable default: None is rejected in __post_init__
    dataset_repo_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The dataset repository containing encoded files."},
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "The input training data file (glob & braceexpand acceptable)."
        },
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file (glob & braceexpand acceptable)."
        },
    )
    # data loading should not be a bottleneck so we use "streaming" mode by default
    streaming: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to stream the dataset."},
    )
    use_auth_token: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use the authentication token for private datasets."
        },
    )
    shard_by_host: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to shard data files by host in multi-host environments."
        },
    )
    blank_caption_prob: Optional[float] = field(
        default=0.0,
        metadata={
            "help": "Probability of removing some captions for classifier-free guidance."
        },
    )
    clip_score_column: Optional[str] = field(
        default="clip_score",
        metadata={"help": "Column that contains clip score for filtering."},
    )
    min_clip_score: Optional[float] = field(
        default=None,
        metadata={"help": "Minimum clip score required."},
    )
    max_clip_score: Optional[float] = field(
        default=None,
        metadata={"help": "Maximum clip score required."},
    )
    filter_column: Optional[str] = field(
        default=None,
        metadata={"help": "Column that contains classes to be filtered."},
    )
    filter_value: Optional[str] = field(
        default=None,
        metadata={"help": "Class value to be kept during filtering."},
    )
    multi_eval_ds: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to look for multiple validation datasets (local support only)."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use for the preprocessing. Not used in streaming mode."
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={
            "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
        },
    )
    # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
    seed_dataset: Optional[int] = field(
        default=None,
        metadata={
            "help": "Random seed for the dataset that will be set at the beginning of training."
        },
    )

    def __post_init__(self):
        """Reject a configuration that does not name a dataset repository or path."""
        if self.dataset_repo_or_path is None:
            raise ValueError("Need a dataset repository or path.")
@dataclass
class TrainingArguments:
    """
    Arguments pertaining to training parameters.

    `__post_init__` validates the optimizer, grafting, learning rate decay and
    sharding choices, fills in dependent defaults and derives `dp_devices`
    from the number of available devices and `mp_devices`.
    """

    output_dir: str = field(
        metadata={
            "help": "The output directory where the model predictions and checkpoints will be written."
        },
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(
        default=False, metadata={"help": "Whether to run eval on the validation set."}
    )
    per_device_train_batch_size: int = field(
        default=8,
        metadata={"help": "Batch size per data parallel device for training."},
    )
    per_device_eval_batch_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "Batch size per data parallel device for evaluation. Same as training batch size if not set."
        },
    )
    gradient_accumulation_steps: int = field(
        default=1,
        metadata={
            "help": "Number of updates steps to accumulate before performing an update pass."
        },
    )
    gradient_checkpointing: bool = field(
        default=False, metadata={"help": "Use gradient checkpointing."}
    )
    learning_rate: float = field(
        default=5e-5, metadata={"help": "The initial learning rate."}
    )
    optim: str = field(
        default="distributed_shampoo",
        metadata={
            "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
        },
    )
    weight_decay: float = field(
        default=0.0, metadata={"help": "Weight decay applied to parameters."}
    )
    beta1: float = field(
        default=0.9,
        metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
    )
    beta2: float = field(
        default=0.999,
        metadata={"help": "Beta2 for Adam & Distributed Shampoo."},
    )
    adam_epsilon: float = field(
        default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
    )
    max_grad_norm: float = field(
        default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
    )
    block_size: int = field(
        default=1024,
        metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
    )
    preconditioning_compute_steps: int = field(
        default=10, metadata={"help": "Number of steps to update preconditioner."}
    )
    skip_preconditioning_dim_size_gt: int = field(
        default=4096,
        metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
    )
    graft_type: str = field(
        default="rmsprop_normalized",
        metadata={
            "help": "The type of grafting to use. Can be 'rmsprop_normalized' (default), 'rmsprop', 'adagrad', 'adagrad_normalized', 'sgd' or 'sqrt_n'"
        },
    )
    nesterov: bool = field(
        default=False,
        metadata={"help": "Use Nesterov momentum for Distributed Shampoo."},
    )
    optim_quantized: bool = field(
        default=False,
        metadata={
            "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
        },
    )
    shard_shampoo_across: str = field(
        default="dp",
        metadata={
            "help": "Whether to shard the optimizer across data devices (dp), model devices (mp) or both (2d)."
        },
    )
    num_train_epochs: int = field(
        default=3, metadata={"help": "Total number of training epochs to perform."}
    )
    warmup_steps: int = field(
        default=0, metadata={"help": "Linear warmup over warmup_steps."}
    )
    lr_decay: Optional[str] = field(
        default=None,
        metadata={
            "help": "Decay to be used in the learning rate scheduler. Can be None (default), linear or exponential."
        },
    )
    lr_transition_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of transition steps associated with learning rate decay when using exponential decay."
        },
    )
    lr_decay_rate: Optional[float] = field(
        default=None,
        metadata={
            "help": "Decay rate associated with learning rate when using exponential decay."
        },
    )
    lr_staircase: bool = field(
        default=False,
        metadata={
            "help": "Whether to use staircase or continuous learning rate when using exponential decay."
        },
    )
    lr_offset: int = field(
        default=0,
        metadata={"help": "Number of steps to offset learning rate and keep it at 0."},
    )
    logging_steps: int = field(
        default=40, metadata={"help": "Log every X updates steps."}
    )
    eval_steps: int = field(
        default=400, metadata={"help": "Run an evaluation every X steps."}
    )
    save_steps: int = field(
        default=4000, metadata={"help": "Save checkpoint every X updates steps."}
    )
    log_model: bool = field(
        default=False,
        metadata={"help": "Log model to wandb at `save_steps` frequency."},
    )
    # the default True is a sentinel replaced by `logging_steps` in __post_init__
    log_norm_steps: int = field(
        default=True,
        metadata={"help": "Log parameters and gradients norm at this frequency."},
    )
    log_histogram_steps: int = field(
        default=False,
        metadata={
            "help": "Log parameters and gradients histograms at this frequency. Slows down training."
        },
    )
    seed_model: int = field(
        default=42,
        metadata={
            "help": "Random seed for the model that will be set at the beginning of training."
        },
    )
    embeddings_only: bool = field(
        default=False, metadata={"help": "Train only embedding layers."}
    )
    init_embeddings: bool = field(
        default=False,
        metadata={"help": "When training embedding layers, initialize them."},
    )
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": "The wandb entity to use (for teams)."},
    )
    wandb_project: str = field(
        default="dalle-mini",
        metadata={"help": "The name of the wandb project."},
    )
    wandb_job_type: str = field(
        default="Seq2Seq",
        metadata={"help": "The name of the wandb job type."},
    )
    assert_TPU_available: bool = field(
        default=False,
        metadata={"help": "Verify that TPU is not in use."},
    )
    # the help text used to be a copy of the one of `assert_TPU_available`
    use_vmap_trick: bool = field(
        default=True,
        metadata={"help": "Whether to use the vmap trick."},
    )
    mp_devices: Optional[int] = field(
        default=1,
        metadata={
            "help": "Number of devices required for model parallelism. The other dimension of available devices is used for data parallelism."
        },
    )
    # derived in __post_init__, cannot be passed by the user
    dp_devices: int = field(init=False)

    def __post_init__(self):
        """Validate the arguments and fill in the values derived from them.

        Raises:
            AssertionError: on an unsupported option value, an unusable device
                layout or a missing google-cloud-storage package for a
                "gs://" output directory.
            ValueError: when training would write into an existing, non-empty
                output directory without `overwrite_output_dir`.
        """
        if self.assert_TPU_available:
            assert (
                jax.local_device_count() == 8
            ), "TPUs in use, please check running processes"
        if self.output_dir.startswith("gs://"):
            assert (
                storage is not None
            ), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
        assert self.optim in [
            "distributed_shampoo",
            "adam",
            "adafactor",
        ], f"Selected optimizer not supported: {self.optim}"
        # adafactor gets None instead of a zero weight decay
        if self.optim == "adafactor" and self.weight_decay == 0:
            self.weight_decay = None
        assert self.graft_type in [
            "rmsprop_normalized",
            "rmsprop",
            "adagrad",
            "adagrad_normalized",
            "sgd",
            "sqrt_n",
        ], f"Selected graft type not supported: {self.graft_type}"
        assert self.lr_decay in [
            None,
            "linear",
            "exponential",
        ], f"Selected learning rate decay not supported: {self.lr_decay}"
        if self.per_device_eval_batch_size is None:
            self.per_device_eval_batch_size = self.per_device_train_batch_size
        if self.log_norm_steps is True:
            self.log_norm_steps = self.logging_steps
        if not self.do_train:
            self.num_train_epochs = 1
        if (
            os.path.exists(self.output_dir)
            and os.listdir(self.output_dir)
            and self.do_train
            and not self.overwrite_output_dir
        ):
            raise ValueError(
                f"Output directory ({self.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        assert self.shard_shampoo_across in [
            "dp",
            "mp",
            "2d",
        ], f"Shard shampoo across {self.shard_shampoo_across} not supported."
        assert (
            self.mp_devices > 0
        ), "Number of devices for model parallelism must be > 0"
        assert (
            jax.device_count() % self.mp_devices == 0
        ), f"Number of available devices ({jax.device_count()}) must be divisible by number of devices used for model parallelism ({self.mp_devices})."
        # devices not used for model parallelism form the data parallel dimension
        self.dp_devices = jax.device_count() // self.mp_devices
def split_params(data):
    """Group params into standard, scanned encoder and scanned decoder params.

    The params tree is flattened and each entry is assigned to a group based
    on its key path; only groups with at least one entry are returned, each
    one rebuilt as a frozen nested dict.
    """
    groups = {"standard": {}, "scanned_encoder": {}, "scanned_decoder": {}}
    for path, value in traverse_util.flatten_dict(unfreeze(data)).items():
        # the key path is a tuple of names, scanned layers are identified by name
        if "FlaxBartEncoderLayers" in path:
            group = "scanned_encoder"
        elif "FlaxBartDecoderLayers" in path:
            group = "scanned_decoder"
        else:
            group = "standard"
        groups[group][path] = value
    # empty groups are dropped
    return {
        name: freeze(traverse_util.unflatten_dict(members))
        for name, members in groups.items()
        if members
    }
def unsplit_params(data):
    """Merge the groups created by split_params back into one frozen params tree."""
    merged = {}
    for name in ("standard", "scanned_encoder", "scanned_decoder"):
        if name not in data:
            # this group was empty and got dropped when splitting
            continue
        merged.update(traverse_util.flatten_dict(unfreeze(data[name])))
    return freeze(traverse_util.unflatten_dict(merged))
def trainable_params(data, embeddings_only):
    """Return the part of the params that is trained.

    With `embeddings_only` unset, `data` is returned untouched. Otherwise only
    `lm_head` and a fixed list of decoder layers are kept, in a frozen dict.
    """
    if not embeddings_only:
        return data
    params = unfreeze(data)
    lm_head = params["lm_head"]
    decoder = params["model"]["decoder"]
    kept_layers = (
        "embed_positions",
        "embed_tokens",
        "final_ln",
        "layernorm_embedding",
    )
    subset = {
        "lm_head": lm_head,
        "model": {"decoder": {name: decoder[name] for name in kept_layers}},
    }
    return freeze(subset)
def init_embeddings(model, params):
    """Reinitialize the trainable embeddings of `model`.

    The key paths below are registered on the model as missing keys, then the
    result of `model.init_weights` called with the current `params` is returned.
    """
    # Must match params in trainable_params() above
    keypaths = (
        "lm_head.kernel",
        "model.decoder.embed_positions.embedding",
        "model.decoder.embed_tokens.embedding",
        "model.decoder.final_ln.bias",
        "model.decoder.layernorm_embedding.bias",
        "model.decoder.layernorm_embedding.scale",
    )
    # Note: this relies on the private _missing_keys attribute of the model
    model._missing_keys = {tuple(path.split(".")) for path in keypaths}
    return model.init_weights(model.key, model.input_shape, params=params)
def main():
# See all possible arguments by passing the --help flag to this script.
parser = HfArgumentParser(
(ModelArguments, DataTrainingArguments, TrainingArguments)
)
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(
json_file=os.path.abspath(sys.argv[1])
)
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# check arguments
if training_args.mp_devices > jax.local_device_count():
assert (
data_args.seed_dataset is not None
), "Seed dataset must be provided when model is split over multiple hosts"
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# Setup logging, we only want one process per machine to log things on the screen.
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
if jax.process_index() == 0:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# Set the verbosity to info of the Transformers logger (on main process only):
logger.info(f"Training/evaluation parameters {training_args}")
# Load dataset
dataset = Dataset(
**asdict(data_args),
do_train=training_args.do_train,
do_eval=training_args.do_eval,
)
logger.info(f"Local TPUs: {jax.local_device_count()}")
logger.info(f"Global TPUs: {jax.device_count()}")
# Set up wandb run
if jax.process_index() == 0:
wandb.init(
entity=training_args.wandb_entity,
project=training_args.wandb_project,
job_type=training_args.wandb_job_type,
config=parser.parse_args(),
)
# Set up our new model config
config_args = {
k: getattr(model_args, k)
for k in ["dropout", "activation_dropout", "attention_dropout"]
if getattr(model_args, k) is not None
}
config_args["gradient_checkpointing"] = training_args.gradient_checkpointing
if model_args.config_name:
config = DalleBartConfig.from_pretrained(model_args.config_name)
else:
config = None
# Load or create new model
if model_args.model_name_or_path:
model, params = DalleBart.from_pretrained(
model_args.model_name_or_path,
config=config,
seed=training_args.seed_model,
dtype=getattr(jnp, model_args.dtype),
_do_init=False,
)
if training_args.embeddings_only and training_args.init_embeddings:
params = init_embeddings(model, params)
else:
model = DalleBart(
config,
seed=training_args.seed_model,
dtype=getattr(jnp, model_args.dtype),
_do_init=False,
)
params = None
for k, v in config_args.items():
setattr(model.config, k, v)
params_shape = model.params_shape_tree
# get model metadata
model_metadata = model_args.get_metadata()
# get PartitionSpec for model params (required to be a dict)
param_spec = set_partitions(params_shape, model.config.use_scan)
params_shape = freeze(params_shape)
if params is not None:
params = freeze(params)
# Load tokenizer
tokenizer = DalleBartTokenizer.from_pretrained(
model_args.tokenizer_name, use_fast=True
)
# Preprocessing the datasets.
# We need to normalize and tokenize inputs and targets.
dataset.preprocess(tokenizer=tokenizer, config=model.config)
# Initialize our training
dropout_rng = jax.random.PRNGKey(training_args.seed_model)
# Store some constant
num_epochs = training_args.num_train_epochs
# batch size
batch_size_per_node_per_grad_step = (
training_args.per_device_train_batch_size
* jax.local_device_count()
// training_args.mp_devices
)
batch_size_per_node = (
batch_size_per_node_per_grad_step * training_args.gradient_accumulation_steps
)
batch_size_per_step = batch_size_per_node * jax.process_count()
eval_batch_size_per_node = (
training_args.per_device_eval_batch_size
* jax.local_device_count()
// training_args.mp_devices
)
eval_batch_size_per_step = eval_batch_size_per_node * jax.process_count()
len_train_dataset, len_eval_dataset = dataset.length
steps_per_epoch = (
len_train_dataset // batch_size_per_node
if len_train_dataset is not None
else None
)
num_train_steps = (
steps_per_epoch * num_epochs if steps_per_epoch is not None else None
)
num_params = model.num_params(params_shape)
logger.info("***** Running training *****")
logger.info(f" Num examples = {len_train_dataset}")
logger.info(f" Num Epochs = {num_epochs}")
logger.info(
f" Batch size per dp device = {training_args.per_device_train_batch_size}"
)
logger.info(f" Number of devices = {jax.device_count()}")
logger.info(
f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
)
logger.info(f" Batch size per update = {batch_size_per_step}")
logger.info(f" Model parameters = {num_params:,}")
# set up wandb run
if jax.process_index() == 0:
# set default x-axis as 'train/step'
wandb.define_metric("*", step_metric="train/step")
# add interesting config parameters
wandb.config.update(
{
"len_train_dataset": len_train_dataset,
"len_eval_dataset": len_eval_dataset,
"batch_size_per_step": batch_size_per_step,
"num_params": num_params,
"model_config": model.config.to_dict(),
"num_devices": jax.device_count(),
"versions": {
"jax": jax.__version__,
"jaxlib": jaxlib.__version__,
"flax": flax.__version__,
"transformers": transformers.__version__,
"datasets": datasets.__version__,
"wandb": wandb.__version__,
"dalle_mini": dalle_mini.__version__,
},
}
)
# Create learning rate schedule
def create_learning_rate_fn() -> Callable[[int], jnp.array]:
"""Create the learning rate function."""
warmup_fn = optax.linear_schedule(
init_value=0.0,
end_value=training_args.learning_rate,
transition_steps=training_args.warmup_steps + 1, # ensure not 0
)
last_boundary = training_args.warmup_steps
# offset step when resuming
if training_args.lr_offset:
warmup_fn = optax.join_schedules(
schedules=[optax.constant_schedule(0.0), warmup_fn],
boundaries=[training_args.lr_offset],
)
last_boundary += training_args.lr_offset
if training_args.lr_decay is None:
return warmup_fn
elif training_args.lr_decay == "linear":
assert (
num_train_steps is not None
), "linear decay requires knowing the dataset length"
decay_fn = optax.linear_schedule(
init_value=training_args.learning_rate,
end_value=0,
transition_steps=num_train_steps - training_args.warmup_steps,
)
elif training_args.lr_decay == "exponential":
decay_fn = optax.exponential_decay(
init_value=training_args.learning_rate,
transition_steps=training_args.lr_transition_steps,
decay_rate=training_args.lr_decay_rate,
staircase=training_args.lr_staircase,
)
schedule_fn = optax.join_schedules(
schedules=[warmup_fn, decay_fn],
boundaries=[last_boundary],
)
return schedule_fn
learning_rate_fn = create_learning_rate_fn()
# create optimizer
trainable_params_shape = trainable_params(
params_shape, training_args.embeddings_only
)
if training_args.optim == "distributed_shampoo":
# parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
graft_type = {
"sgd": GraftingType.SGD,
"adagrad": GraftingType.ADAGRAD,
"rmsprop": GraftingType.RMSPROP,
"rmsprop_normalized": GraftingType.RMSPROP_NORMALIZED,
"sqrt_n": GraftingType.SQRT_N,
"adagrad_normalized": GraftingType.ADAGRAD_NORMALIZED,
}[training_args.graft_type]
statistics_partition_spec = (
PartitionSpec(None, training_args.shard_shampoo_across, None)
if training_args.shard_shampoo_across != "2d"
else PartitionSpec(None, "dp", "mp")
)
opt = distributed_shampoo(
learning_rate_fn,
block_size=training_args.block_size,
beta1=training_args.beta1,
beta2=training_args.beta2,
diagonal_epsilon=1e-10,
matrix_epsilon=1e-6,
weight_decay=training_args.weight_decay,
start_preconditioning_step=max(
training_args.preconditioning_compute_steps + 1, 101
),
preconditioning_compute_steps=training_args.preconditioning_compute_steps,
statistics_compute_steps=1,
best_effort_shape_interpretation=True,
graft_type=graft_type,
nesterov=training_args.nesterov,
exponent_override=0,
statistics_partition_spec=statistics_partition_spec,
preconditioner_partition_spec=PartitionSpec(
training_args.shard_shampoo_across, None, None
)
if training_args.shard_shampoo_across != "2d"
else PartitionSpec(
"mp" if training_args.mp_devices > training_args.dp_devices else "dp",
None,
None,
),
num_devices_for_pjit=training_args.dp_devices,
shard_optimizer_states=True,
inverse_failure_threshold=0.1,
moving_average_for_momentum=True,
skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
clip_by_scaled_gradient_norm=None,
precision=jax.lax.Precision.HIGHEST,
best_effort_memory_usage_reduction=training_args.optim_quantized,
)
# get the real optimizer and helper functions
update_fn = opt.update
optimizer = {}
opt_fn = {}
for k, p in split_params(trainable_params_shape).items():
if "scanned" in k:
p = jax.eval_shape(
lambda x: jax.tree_util.tree_map(lambda y: y[0], x), p
)
optimizer[k] = opt.init(p)
opt_fn[k] = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
optimizer[k].pspec_fn, optimizer[k].shape_and_dtype_fn
)
optimizer[k] = optax.GradientTransformation(optimizer[k].init_fn, update_fn)
elif training_args.optim == "adam":
optimizer = optax.adamw(
learning_rate=learning_rate_fn,
b1=training_args.beta1,
b2=training_args.beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
)
optimizer = {k: optimizer for k in split_params(trainable_params_shape)}
elif training_args.optim == "adafactor":
# We use the default parameters here to initialize adafactor,
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
optimizer = optax.adafactor(
learning_rate=learning_rate_fn,
clipping_threshold=training_args.max_grad_norm,
weight_decay_rate=training_args.weight_decay,
)
optimizer = {k: optimizer for k in split_params(trainable_params_shape)}
# get PartitionSpec for optimizer state
def get_opt_state_spec_and_shape():
# get opt_state shape without actual init
opt_state_shape = {}
for k, p in split_params(trainable_params_shape).items():
if "scanned" not in k:
opt_state_shape[k] = jax.eval_shape(optimizer[k].init, p)
else:
opt_state_shape[k] = jax.eval_shape(jax.vmap(optimizer[k].init), p)
if training_args.optim == "adafactor":
# factorized state must be replicated (rank different than params)
opt_state_spec = {k: None for k in split_params(trainable_params_shape)}
elif training_args.optim in ["adam", "distributed_shampoo"]:
def _opt_state_spec_per_leaf(x, spec):
if isinstance(x, FrozenDict):
# variables with same structure as params
return spec
else:
# other variables such as count
return None
split_spec = split_params(set_partitions(trainable_params_shape, False))
opt_state_spec = {}
for k, p in split_params(trainable_params_shape).items():
if "scanned" in k:
p = jax.eval_shape(
lambda x: jax.tree_util.tree_map(lambda y: y[0], x), p
)
if training_args.optim == "adam":
opt_state_spec[k] = jax.tree_util.tree_map(
partial(_opt_state_spec_per_leaf, spec=split_spec[k]),
opt_state_shape[k],
# return None spec for empty elements
is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
)
elif training_args.optim == "distributed_shampoo":
opt_state_spec[k] = opt_fn[k].pspec_fn(
p,
split_spec[k],
statistics_partition_spec,
)
# add dimension for scanned params
if "scanned" in k:
opt_state_spec[k] = jax.tree_util.tree_map(
lambda x: PartitionSpec(*(None,) + x)
if x is not None
else None,