From c46c9e5082deeb71668d45b7683cc4ac2dc40b06 Mon Sep 17 00:00:00 2001
From: Kobby Panford-Quainoo
Date: Fri, 3 Nov 2023 12:10:30 +0000
Subject: [PATCH] Do training on Vertex AI TPU (#7)

* add tpu training
* add docker container for TPU
* add docker container for TPU
* add docker container for TPU
* add job args and fix linting
* fix linting issues
* fix bug in encoding function
* modify conflicting accelerator type name
* Add train on tpu parts
* fix tpu issues
* modify xm parallel trial runs
* separate functions into separate modules
* update unit test
* update unit test
* update unit test and xm runs
* apply reversible encoding
* fix type annotation
* add unit testing for DataEncoder
* add description to functions
* Restructure code and cleanup
* Add method description to encoding methods
---
 src/skai/model/data.py                       | 200 ++++++++++++++++++
 src/skai/model/data_test.py                  | 137 +++++++++++-
 src/skai/model/docker_instructions.py        |  34 +--
 src/skai/model/train.py                      |  23 +-
 src/skai/model/train_lib.py                  |  33 +--
 src/skai/model/train_lib_test.py             |   3 +
 src/skai/model/train_strategy.py             |  36 ++++
 .../model/xm_launch_single_model_vertex.py   |  53 ++---
 8 files changed, 454 insertions(+), 65 deletions(-)
 create mode 100644 src/skai/model/train_strategy.py

diff --git a/src/skai/model/data.py b/src/skai/model/data.py
index 9930b0f3..96a00f99 100644
--- a/src/skai/model/data.py
+++ b/src/skai/model/data.py
@@ -177,6 +177,206 @@ def apply_batch(dataloader, batch_size):
   return dataloader
 
 
+class DataEncoder:
+  def __init__(self):
+    string_label_categories: dict[str, int] = {b'bad_example': 0,
+                                               b'destroyed': 1,
+                                               b'major_damage': 2,
+                                               b'minor_damage': 3,
+                                               b'no_damage': 4}
+    self.label_to_int_table = tf.lookup.StaticHashTable(
+        tf.lookup.KeyValueTensorInitializer(list(string_label_categories.keys()),
+                                            list(string_label_categories.values())),
+        default_value=-1
+    )
+    self.int_to_label_table = tf.lookup.StaticHashTable(
+        tf.lookup.KeyValueTensorInitializer(list(string_label_categories.values()),
+                                            list(string_label_categories.keys())),
+        default_value='unknown'
+    )
+
+  def encode_example_ids(self, dataloader: Dataloader) -> Dataloader:
+    """Encodes example IDs from hexadecimal strings to integers in a Dataloader.
+
+    Description:
+      example_id values are hexadecimal strings, e.g.
+      b0b947f423a1c77ac948c76f63fa8209. Each string is parsed as a base-16
+      integer, which yields a very long integer such as
+      125613911306676688688906949689977127817181202292590253 that cannot be
+      stored in a TensorFlow integer tensor. The long integer is therefore
+      broken into smaller segments, e.g. [2323, 9023, 3403], using reversible
+      integer division and modulo operations. The segments are pre-padded to
+      the same length for all examples in a batch, and the number of segments
+      before padding is appended, e.g. [0, 0, 2323, 9023, 3403, 3].
+
+    Args:
+      dataloader: The Dataloader containing example IDs to be encoded.
+
+    Returns:
+      The modified Dataloader with encoded example IDs.
+    """
+    return self._apply_map_to_features(dataloader,
+                                       self._convert_hex_strings_to_int,
+                                       'example_id')
+
+  def encode_string_labels(self, dataloader: Dataloader) -> Dataloader:
+    """Encodes string labels as integer values.
+
+    A StaticHashTable maps each unique string label to an integer for lookup.
+
+    Args:
+      dataloader: The dataloader.
+
+    Returns:
+      The dataloader with string labels encoded.
+    """
+    return self._apply_map_to_features(dataloader,
+                                       self._convert_label_to_int,
+                                       'string_label')
+
+  def decode_example_ids(self, inputs: tf.Tensor | Dataloader):
+    """Decodes example IDs from integers back to hexadecimal strings.
+
+    Args:
+      inputs: A batch of data or a Dataloader containing encoded example IDs.
+
+    Returns:
+      The modified batch or Dataloader with decoded example IDs.
+    """
+    if isinstance(inputs, Dataloader):
+      return self._apply_map_to_features(
+          inputs,
+          self._convert_int_to_hex_strings,
+          'example_id')
+    else:
+      return self._convert_int_to_hex_strings(inputs)
+
+  def decode_string_labels(self, inputs: tf.Tensor | Dataloader):
+    """Decodes integer labels back to their string labels.
+
+    Args:
+      inputs: A batch of data or a Dataloader containing encoded string labels.
+
+    Returns:
+      The modified batch or Dataloader with decoded string labels.
+    """
+    if isinstance(inputs, Dataloader):
+      return self._apply_map_to_features(
+          inputs,
+          self._convert_int_to_label,
+          'string_label')
+    else:
+      return self._convert_int_to_label(inputs)
+
+  def _convert_hex_strings_to_int(self, hex_strings):
+    """Converts hex strings to sequences of integer segments.
+
+    The integer value of a hex string is typically too large to fit into a
+    TensorFlow integer tensor, so it is broken into segments using integer
+    division and modulo, and the segments are pre-padded to the same length.
+    """
+    segment_size = 4
+    def split_long_integer(number):
+      segments = []
+      while number > 0:
+        segment = number % (10 ** segment_size)  # Extract the last `segment_size` digits.
+        segments.append(segment)
+        number //= 10 ** segment_size  # Remove the last `segment_size` digits.
+      return segments
+
+    output = []
+    for hex_string in hex_strings:
+      integer = int(hex_string.numpy(), 16)
+      short_integers = split_long_integer(integer)
+      short_integers += [len(short_integers)]
+      output.append(short_integers)
+    padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
+        output, padding='pre')
+    return padded_sequences
+
+  def _convert_int_to_hex_strings(self, segments):
+    """Recombines integer segments into the long integer they encode and
+    converts it back to its hex string representation.
+    """
+    def combine_segments(segments, segment_size=4):
+      list_size = segments[-1]
+      segments_to_decode = segments[-(list_size + 1):-1]
+
+      number = 0
+      for i, segment in enumerate(segments_to_decode):
+        number += segment * (10 ** (i * segment_size))
+      return number
+
+    def long_integer_to_string(integer):
+      strings = f'{integer:032x}'
+      return tf.compat.as_bytes(
+          strings, encoding='utf-8'
+      )
+
+    output = []
+    segment_size = 4
+    for segment in segments:
+      long_integer = combine_segments(
+          segment.numpy().tolist(), segment_size)
+      output.append(long_integer_to_string(long_integer))
+    return tf.convert_to_tensor(output)
+
+  def _convert_label_to_int(self, string_labels):
+    """Looks up integer values for string labels."""
+    return self.label_to_int_table.lookup(string_labels)
+
+  def _convert_int_to_label(self, int_labels):
+    """Looks up string labels for integer keys."""
+    return self.int_to_label_table.lookup(int_labels)
+
+  def _process_per_batch(self, batch, map_fn, feature):
+    """Applies a map function to each batch of data."""
+    for idx, examples in enumerate(batch):
+      processed = map_fn(examples[feature])
+      examples[feature] = processed
+
+      if idx == 0:
+        transformed_batch = tf.data.Dataset.from_tensor_slices(examples)
+        continue
+      transformed_batch = transformed_batch.concatenate(
+          tf.data.Dataset.from_tensor_slices(examples))
+    return transformed_batch
+
+  def _apply_map_to_features(self, dataloader: Dataloader,
+                             map_fn: collections.abc.Callable[[tf.Tensor], tf.Tensor],
+                             feature: str):
+    """Applies a map function to one feature of a Dataloader and returns it.
+
+    Args:
+      dataloader: The Dataloader to apply the map function to.
+      map_fn: The mapping function to apply.
+      feature: Name of the feature to which map_fn is applied.
+
+    Returns:
+      The modified Dataloader.
+    """
+    batch_size = dataloader.train_splits[0]._batch_size.numpy()
+
+    dataloader.train_splits = [
+        self._process_per_batch(data, map_fn, feature) for data in dataloader.train_splits
+    ]
+    dataloader.val_splits = [
+        self._process_per_batch(data, map_fn, feature) for data in dataloader.val_splits
+    ]
+    num_splits = len(dataloader.train_splits)
+    train_ds = gather_data_splits(
+        list(range(num_splits)), dataloader.train_splits)
+    val_ds = gather_data_splits(list(range(num_splits)), dataloader.val_splits)
+    dataloader.train_ds = train_ds
+    dataloader.eval_ds['val'] = val_ds
+    for (k, v) in dataloader.eval_ds.items():
+      if k != 'val':
+        dataloader.eval_ds[k] = self._process_per_batch(v, map_fn, feature)
+    dataloader = apply_batch(dataloader, batch_size)
+    return dataloader
+
+
 def gather_data_splits(
     slice_idx: list[int],
     dataset: tf.data.Dataset | list[tf.data.Dataset]) -> tf.data.Dataset:
diff --git a/src/skai/model/data_test.py b/src/skai/model/data_test.py
index db326e85..a4ff58b1 100644
--- a/src/skai/model/data_test.py
+++ b/src/skai/model/data_test.py
@@ -17,7 +17,7 @@
 import os
 import tempfile
 from typing import List
-
+from absl.testing import absltest
 import numpy as np
 from skai.model import data
 import tensorflow as tf
@@ -135,5 +135,140 @@ def setUpClass(cls):
     ]
 
 
+def _create_test_data_with_hex_strings():
+  examples_dir = _make_temp_dir()
+  labeled_train_path = os.path.join(
+      examples_dir, 'train_labeled_examples.tfrecord')
+  labeled_test_path = os.path.join(
+      examples_dir, 'test_labeled_examples.tfrecord')
+  unlabeled_path = os.path.join(
+      examples_dir, 'unlabeled_examples.tfrecord')
+
+  _write_tfrecord([
+      _make_example('b0b947f423a1c77ac948c76f63fa8209', 0, 0, 'A0', 0, 'no_damage', 64, 256),
+      _make_example('5fb3fc48db76805c169e8dc667c3f266', 0, 1, 'A1', 0, 'no_damage', 64, 256),
+      _make_example('21bdfdb3f65974473d4a19f05871449d', 0, 2, 'A2', 1, 'major_damage', 64, 256),
+  ], labeled_train_path)
+
+  _write_tfrecord([
+      _make_example('a564b943bdebd4936ce0fd135cc19fbf', 1, 0, 'B0', 0, 'no_damage', 64, 256),
+  ], labeled_test_path)
+
+  _write_tfrecord([
+      _make_example('3a8e68680d3ec6d1013d11f492a2d7d5', 2, 0, 'C0', -1, 'bad_example', 64, 256),
+      _make_example('1004dc994ff1888052aa3ff4be5e55cf', 2, 1, 'C1', -1, 'bad_example', 64, 256),
+      _make_example('4b49276f4f10856b9e8a57fad78ee593', 2, 2, 'C2', -1, 'bad_example', 64, 256),
+      _make_example('97a9600f1e418132af93ea03f4264ad2', 2, 3, 'C3', -1, 'bad_example', 64, 256),
+  ], unlabeled_path)
+
+  return labeled_train_path, labeled_test_path, unlabeled_path
+class TestDataEncoder(absltest.TestCase):
+  def setUp(self):
+    self.data_encoder = data.DataEncoder()
+    labeled_train_path, labeled_test_path, unlabeled_path = _create_test_data_with_hex_strings()
+    self.labeled_train_path = labeled_train_path
+    self.labeled_test_path = labeled_test_path
+    self.unlabeled_path = unlabeled_path
+
+    dataset_builder = data.get_dataset('skai')
+    kwargs = {
+        'labeled_train_pattern': self.labeled_train_path,
+        'unlabeled_train_pattern': self.unlabeled_path,
+        'validation_pattern': self.labeled_test_path,
+        
'use_post_disaster_only': False, + 'load_small_images': True, + 'data_dir': _make_temp_dir(), + } + + dataloader = dataset_builder( + 1, + initial_sample_proportion=1, + subgroup_ids=(), + subgroup_proportions=(), + **kwargs + ) + self.dataloader = data.apply_batch(dataloader, 2) + + def test_encode_example_ids_returns_dataloader(self): + # Check if encode_example_id method correctly returns a dataloader + encoded_dataloader = self.data_encoder.encode_example_ids(self.dataloader) + self.assertIsInstance(encoded_dataloader, data.Dataloader) + + def test_encode_example_ids_encodes_strings_to_int(self): + # Check if the example IDs are correctly encoded to ints + encoded_dataloader = self.data_encoder.encode_example_ids(self.dataloader) + dataset = encoded_dataloader.train_splits[0] + encoded_example_ids = list(dataset.map(lambda x: x['example_id'] + ).as_numpy_iterator())[0] + self.assertIsInstance(encoded_example_ids, np.ndarray) + self.assertTrue(np.issubdtype(encoded_example_ids.dtype, np.integer)) + + def test_encode_string_labels_returns_dataloader(self): + # Check if encode_string_label method correctly returns a dataloader + encoded_dataloader = self.data_encoder.encode_string_labels(self.dataloader) + self.assertIsInstance(encoded_dataloader, data.Dataloader) + + def test_encode_string_labels_encodes_strings_to_int(self): + # Check if encode_string_label method correctly returns a dataloader + encoded_dataloader = self.data_encoder.encode_string_labels(self.dataloader) + dataset = encoded_dataloader.train_splits[0] #pick one example and evaluate + encoded_string_label = list(dataset.map(lambda x: x['string_label'] + ).as_numpy_iterator())[0] + self.assertIsInstance(encoded_string_label, np.ndarray) + self.assertTrue(np.issubdtype(encoded_string_label.dtype, np.integer)) + + def test_decode_example_ids_returns_dataloader(self): + encoded_dataloader = self.data_encoder.encode_example_ids(self.dataloader) + decoded_data = self.data_encoder.decode_example_ids(encoded_dataloader) + self.assertIsInstance(decoded_data, data.Dataloader) + + def test_decode_int_label_decodes_int_to_string(self): + # Check if the example IDs are correctly encoded + encoded_dataloader = self.data_encoder.encode_string_labels(self.dataloader) + decoded_dataloader = self.data_encoder.decode_string_labels(encoded_dataloader) + dataset = decoded_dataloader.train_splits[0] + decoded_int_label = list(dataset.map(lambda x: x['string_label'] + ).as_numpy_iterator())[0] + self.assertIsInstance(decoded_int_label, np.ndarray) + self.assertTrue(np.issubdtype(decoded_int_label.dtype, np.str_) or + np.issubdtype(decoded_int_label.dtype, object)) + + def test_decode_example_id_outputs_matches_inputs(self): + all_example_ids = [] + dataset_true = self.dataloader.train_splits[0] + true_id_list = list(dataset_true.map(lambda x: x['example_id']).as_numpy_iterator()) + for string_label in true_id_list: + all_example_ids += string_label.tolist() + + encoded_dataloader = self.data_encoder.encode_example_ids(self.dataloader) + decoded_dataloader = self.data_encoder.decode_example_ids(encoded_dataloader) + + all_decoded_ids = [] + dataset_decoded = decoded_dataloader.train_splits[0] + decoded_id_list = list(dataset_decoded.map(lambda x: x['example_id']).as_numpy_iterator()) + for string_label in decoded_id_list: + all_decoded_ids += string_label.tolist() + self.assertItemsEqual(all_example_ids[:len(all_decoded_ids)], + all_decoded_ids) + + def test_decode_string_label_outputs_matches_inputs(self): + all_string_labels = [] + 
dataset_true = self.dataloader.train_splits[0] + true_labels_list = list(dataset_true.map(lambda x: x['string_label']).as_numpy_iterator()) + for string_label in true_labels_list: + all_string_labels += string_label.tolist() + + encoded_dataloader = self.data_encoder.encode_string_labels(self.dataloader) + decoded_dataloader = self.data_encoder.decode_string_labels(encoded_dataloader) + + all_decoded_labels = [] + dataset_decoded = decoded_dataloader.train_splits[0] + decoded_labels_list = list(dataset_decoded.map(lambda x: x['string_label']).as_numpy_iterator()) + for string_label in decoded_labels_list: + all_decoded_labels += string_label.tolist() + self.assertItemsEqual(all_string_labels[:len(all_decoded_labels)], + all_decoded_labels) + + if __name__ == '__main__': tfds.testing.test_main() diff --git a/src/skai/model/docker_instructions.py b/src/skai/model/docker_instructions.py index fd94a26d..e6e5c06d 100644 --- a/src/skai/model/docker_instructions.py +++ b/src/skai/model/docker_instructions.py @@ -17,21 +17,22 @@ GPU_ACCELERATORS = ['P100', 'V100', 'P4', 'T4', 'A100'] TPU_ACCELERATORS = ['TPU_V2', 'TPU_V3'] -CPU_BASE_IMAGE = 'tensorflow/tensorflow:2.13.0' -GPU_BASE_IMAGE = 'nvcr.io/nvidia/tensorflow:23.08-tf2-py3' -TPU_BASE_IMAGE = 'ubuntu:20.04' +CPU_BASE_IMAGE = 'tensorflow/tensorflow:2.14.0' +GPU_BASE_IMAGE = 'tensorflow/tensorflow:2.14.0-gpu' +TPU_BASE_IMAGE = 'ubuntu:22.04' def tpuvm_docker_instructions() -> list[str]: - """Returns a list of docker instructions necessary to use TF 2.9.1 on TPUs.""" - tpu_shared_object_url = 'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.3.0/libtpu.so' - tf_wheel_url = 'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.9.1/tensorflow-2.9.1-cp38-cp38-linux_x86_64.whl' + """Returns a list of docker instructions necessary to use TensorFlow on TPUs.""" + tf_wheel_name = 'tensorflow-2.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl' + tf_wheel_url= 'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.14.0/' + tf_wheel_name + tpu_shared_object_url = 'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.8.0/libtpu.so' return [ f'RUN wget {tpu_shared_object_url} -O /lib/libtpu.so', 'RUN chmod 700 /lib/libtpu.so', f'RUN wget {tf_wheel_url}', - 'RUN pip3 install tensorflow-2.9.1-cp38-cp38-linux_x86_64.whl', - 'RUN rm tensorflow-2.9.1-cp38-cp38-linux_x86_64.whl', + f'RUN pip3 install {tf_wheel_name}', + f'RUN rm {tf_wheel_name}', ] @@ -55,22 +56,7 @@ def get_docker_instructions(accelerator: str) -> tuple[str, list[str]]: # https://cloud.google.com/deep-learning-containers/docs/choosing-container base_image = GPU_BASE_IMAGE docker_instructions = [ - # Add deadsnakes repo - 'RUN apt update', - 'RUN apt-get install software-properties-common -y', - 'RUN add-apt-repository ppa:deadsnakes/ppa -y', - - # Install Python 3.10', - 'RUN apt update && apt install -y python3.10 python3.10-distutils', - 'RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10', - - # Replace python shell with python3.10', - 'RUN unlink /usr/bin/python', - 'RUN ln -s /usr/bin/python3.10 /usr/bin/python', - - 'RUN python -m pip install --pre --extra-index-url ' + - 'https://developer.download.nvidia.com/compute/redist/jp/v50 ' + - 'tensorflow==2.13' + 'RUN apt-get update && apt-get install -y libcairo2-dev libjpeg-dev libgif-dev' ] else: diff --git a/src/skai/model/train.py b/src/skai/model/train.py index 91f22ac0..992858c1 100644 --- a/src/skai/model/train.py +++ 
b/src/skai/model/train.py @@ -18,7 +18,6 @@ # pylint: enable=line-too-long """ -import datetime import logging as native_logging import os @@ -33,6 +32,7 @@ from skai.model import sampling_policies from skai.model import train_lib from skai.model.configs import base_config +from skai.model.train_strategy import get_strategy import tensorflow as tf @@ -43,6 +43,12 @@ flags.DEFINE_bool( 'is_vertex', False, 'True if the training job will be executed on VertexAI.' ) +flags.DEFINE_enum( + 'accelerator_type', + default='cpu', + help='Accelerator to use for computations', + enum_values=['cpu', 'gpu', 'tpu'] + ) flags.DEFINE_string('ensemble_dir', '', 'If specified, loads the models at ' 'this directory to consider the ensemble.') flags.DEFINE_string( @@ -111,7 +117,7 @@ def main(_) -> None: dataloader.train_ds.filter( generate_bias_table_lib.filter_ids_fn(ids_tab)) for ids_tab in sampling_policies.convert_ids_to_table(config.ids_dir)] - + print("Ids dir: ", config.ids_dir) model_params = models.ModelTrainingParameters( model_name=config.model.name, train_bias=config.train_bias, @@ -132,7 +138,7 @@ def main(_) -> None: reweighting_signal=config.reweighting.signal ) model_params.train_bias = config.train_bias - output_dir = config.output_dir + if FLAGS.is_vertex: # TODO: go/skai-instadeep - Create output_dir specific to job. start_time = datetime.datetime.now() @@ -203,6 +209,14 @@ def main(_) -> None: # Apply batching (must apply batching only after filtering) dataloader = data.apply_batch(dataloader, config.data.batch_size) + if FLAGS.accelerator_type == 'tpu': + # Encode string data components as numerical + # This is useful when using TPU which does not accept string datatype + dataloader = data.DataEncoder().encode_string_labels(dataloader) + dataloader = data.DataEncoder().encode_example_ids(dataloader) + + strategy = get_strategy(accelerator_type=FLAGS.accelerator_type) + _ = train_lib.train_and_evaluate( train_as_ensemble=config.train_stage_2_as_ensemble, dataloader=dataloader, @@ -218,8 +232,9 @@ def main(_) -> None: example_id_to_bias_table=example_id_to_bias_table, vizier_trial_name=FLAGS.trial_name, is_vertex=FLAGS.is_vertex, + strategy=strategy ) if __name__ == '__main__': - app.run(main) + app.run(main) \ No newline at end of file diff --git a/src/skai/model/train_lib.py b/src/skai/model/train_lib.py index ebc0aaf8..ac09d92d 100644 --- a/src/skai/model/train_lib.py +++ b/src/skai/model/train_lib.py @@ -509,8 +509,9 @@ def run_train( val_ds: tf.data.Dataset, model_params: models.ModelTrainingParameters, experiment_name: str, + strategy: tf.distribute.Strategy, callbacks: Optional[List[tf.keras.callbacks.Callback]] = None, - example_id_to_bias_table: Optional[tf.lookup.StaticHashTable] = None + example_id_to_bias_table: Optional[tf.lookup.StaticHashTable] = None, ) -> tf.keras.Model: """Initializes and trains model on given training and validation data. @@ -519,17 +520,20 @@ def run_train( val_ds: Evaluation dataset. model_params: Dataclass object containing model and training parameters. experiment_name: String to describe model being trained. + strategy: Strategy for distributed training. callbacks: Keras Callbacks, like saving checkpoints or early stopping. example_id_to_bias_table: Hash table mapping example ID to bias label. Returns: Trained model. 
""" - two_head_model = init_model( - model_params=model_params, - experiment_name=experiment_name, - example_id_to_bias_table=example_id_to_bias_table - ) + with strategy.scope(): + two_head_model = init_model( + model_params=model_params, + experiment_name=experiment_name, + example_id_to_bias_table=example_id_to_bias_table + ) + two_head_model.fit( train_ds, @@ -545,10 +549,11 @@ def train_ensemble( num_splits: int, ood_ratio: float, output_dir: str, + strategy: tf.distribute.Strategy, save_model_checkpoints: bool = True, early_stopping: bool = True, example_id_to_bias_table: Optional[tf.lookup.StaticHashTable] = None, - is_vertex: bool = False, + is_vertex: bool = False ) -> List[tf.keras.Model]: """Trains an ensemble of models, locally. See xm_launch.py for parallelized. @@ -559,11 +564,12 @@ def train_ensemble( ood_ratio: Float for the ratio of slices that will be considered out-of-distribution. output_dir: String for directory path where checkpoints will be saved. + strategy: Strategy for distributed training. save_model_checkpoints: Boolean for saving checkpoints during training. early_stopping: Boolean for early stopping during training. example_id_to_bias_table: Hash table mapping example ID to bias label. is_vertex: Set to true if training on VertexAI. - + Returns: List of trained models and, optionally, predictions. """ @@ -587,7 +593,8 @@ def train_ensemble( model_params=model_params, experiment_name=combo_name, callbacks=combo_callbacks, - example_id_to_bias_table=example_id_to_bias_table) + example_id_to_bias_table=example_id_to_bias_table, + strategy=strategy) ensemble.append(combo_model) return ensemble @@ -932,10 +939,11 @@ def train_and_evaluate( save_model_checkpoints: bool, save_best_model: bool, early_stopping: bool, + strategy: tf.distribute.Strategy, ensemble_dir: Optional[str] = '', example_id_to_bias_table: Optional[tf.lookup.StaticHashTable] = None, vizier_trial_name: str | None = None, - is_vertex: bool = False + is_vertex: bool = False, ): """Performs the operations of training, optionally ensembling, and evaluation. 
@@ -988,7 +996,8 @@
         model_params=model_params,
         experiment_name=experiment_name,
         callbacks=callbacks,
-        example_id_to_bias_table=example_id_to_bias_table)
+        example_id_to_bias_table=example_id_to_bias_table,
+        strategy=strategy)
 
   evaluate_model(two_head_model, output_dir, dataloader.eval_ds,
                  save_model_checkpoints, save_best_model)
-  return two_head_model
+  return two_head_model
\ No newline at end of file
diff --git a/src/skai/model/train_lib_test.py b/src/skai/model/train_lib_test.py
index ac5ebd8a..5dfc6cff 100644
--- a/src/skai/model/train_lib_test.py
+++ b/src/skai/model/train_lib_test.py
@@ -187,6 +187,7 @@ def setUp(self):
     log_metrics_callback.LogMetricsCallback = mock.Mock(
         return_value=tf.keras.callbacks.Callback()
     )
+    self.strategy = tf.distribute.get_strategy()
 
   @parameterized.named_parameters(
       (model_name, model_name)
@@ -264,6 +265,7 @@ def test_train_and_load_model_from_checkpoint(self, model_name):
         self.model_params_one_head,
         'test_model_eval',
         callbacks=callbacks,
+        strategy=self.strategy,
     )
     checkpoint_dir = os.path.join(self.output_dir, 'checkpoints')
     self.assertNotEmpty(tf.io.gfile.listdir(checkpoint_dir))
@@ -308,6 +310,7 @@ def test_train_and_load_entire_model(self, model_name):
         self.model_params_one_head,
         'test_model_eval',
         callbacks=callbacks,
+        strategy=self.strategy,
     )
 
     model_dir = self.output_dir
diff --git a/src/skai/model/train_strategy.py b/src/skai/model/train_strategy.py
new file mode 100644
index 00000000..b847c49d
--- /dev/null
+++ b/src/skai/model/train_strategy.py
@@ -0,0 +1,36 @@
+from typing import Union
+import tensorflow as tf
+
+
+_Strategy = Union[
+    tf.distribute.Strategy,
+    tf.distribute.MirroredStrategy,
+    tf.distribute.TPUStrategy
+    ]
+
+
+def get_tpu_resolver():
+  """Creates a cluster resolver for Cloud TPUs and initializes the TPU system."""
+  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
+  tf.config.experimental_connect_to_cluster(resolver)
+  tf.tpu.experimental.initialize_tpu_system(resolver)
+  return resolver
+
+
+def get_strategy(accelerator_type: str) -> _Strategy:
+  """Gets the distributed training strategy for the given accelerator type.
+
+  Args:
+    accelerator_type: The accelerator type, one of 'cpu', 'gpu' or 'tpu'.
+
+  Returns:
+    MirroredStrategy if accelerator_type is 'gpu',
+    TPUStrategy if accelerator_type is 'tpu',
+    otherwise the default Strategy.
+  """
+  if accelerator_type == 'gpu':
+    return tf.distribute.MirroredStrategy()
+  elif accelerator_type == 'tpu':
+    resolver = get_tpu_resolver()
+    return tf.distribute.TPUStrategy(resolver)
+  return tf.distribute.get_strategy()
\ No newline at end of file
diff --git a/src/skai/model/xm_launch_single_model_vertex.py b/src/skai/model/xm_launch_single_model_vertex.py
index 5d3159f8..3892c81b 100644
--- a/src/skai/model/xm_launch_single_model_vertex.py
+++ b/src/skai/model/xm_launch_single_model_vertex.py
@@ -47,26 +47,25 @@
 FLAGS = flags.FLAGS
 flags.DEFINE_string('project_path', '.', 'Path to project')
 flags.DEFINE_string(
-    'experiment_name',
-    '',
-    'Label for XManager experiment to make it easier to find.',
+    'experiment_name',
+    '',
+    'Label for XManager experiment to make it easier to find.',
 )
 flags.DEFINE_bool(
-    'use_vizier', False, 'Finds the best hyperparameters using Vizier.'
+    'use_vizier', False, 'Finds the best hyperparameters using Vizier.'
 )
 flags.DEFINE_bool(
-    'train_as_ensemble',
-    False,
-    'Trains an ensemble of single '
-    'models, as we would for Stage 1 in Introspective Self-Play.',
+    'train_as_ensemble',
+    False,
+    'Trains an ensemble of single '
+    'models, as we would for Stage 1 in Introspective Self-Play.',
 )
 flags.DEFINE_bool('eval_only', False, 'Only runs evaluation, no training.')
 flags.DEFINE_integer(
-    'ram',
-    32,
-    'Fixed amount of RAM for the work unit in GB',
+    'ram',
+    32,
+    'Fixed amount of RAM for the work unit in GB',
 )
-
 flags.DEFINE_integer(
     'cpu',
     4,
@@ -85,7 +84,6 @@
     help='Accelerator to use for faster computations.',
     enum_values=['P100', 'V100', 'P4', 'T4', 'A100', 'TPU_V2', 'TPU_V3']
 )
-
 flags.DEFINE_integer(
     'accelerator_count',
     1,
@@ -95,6 +93,7 @@
     'https://github.com/deepmind/xmanager/blob/main/docs/executors.md'
     ),
 )
+
 config_flags.DEFINE_config_file('config')
 
 
@@ -164,21 +163,26 @@ def main(_) -> None:
         ]),
         use_deep_module=True,
     )
+
   if FLAGS.accelerator is not None:
-    if (
-        FLAGS.accelerator in ['TPU_V3', 'TPU_V2']
-        and FLAGS.accelerator_count != 8
-    ):
-      raise ValueError(
-          f'The accelerator {FLAGS.accelerator} only support 8 devices.'
-      )
+    if FLAGS.accelerator in ['TPU_V3', 'TPU_V2']:
+      if FLAGS.accelerator_count != 8:
+        raise ValueError(
+            f'The accelerator {FLAGS.accelerator} only supports 8 devices.'
+        )
+      accelerator_type = 'tpu'
+    else:
+      accelerator_type = 'gpu'
+
     resources_args = {
-        FLAGS.accelerator: FLAGS.accelerator_count,
-        'RAM': FLAGS.ram * xm.GiB,
-        'CPU': FLAGS.cpu * xm.vCPU,
+        FLAGS.accelerator: FLAGS.accelerator_count,
+        'RAM': FLAGS.ram * xm.GiB,
+        'CPU': FLAGS.cpu * xm.vCPU,
     }
   else:
     resources_args = {'RAM': FLAGS.ram * xm.GiB, 'CPU': FLAGS.cpu * xm.vCPU}
+    accelerator_type = 'cpu'
+
   executor = xm_local.Vertex(
       requirements=xm.JobRequirements(
           service_tier=xm.ServiceTier.PROD, **resources_args
@@ -191,7 +195,8 @@ def main(_) -> None:
           executor_spec=xm_local.Vertex.Spec(),
          args={
              'config': config_path,
-              'is_vertex': 'vertex' in str(executor.Spec()).lower()
+              'is_vertex': True,
+              'accelerator_type': accelerator_type
          },
      ),
  ])
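
For reference, the reversible example_id encoding that this patch adds in data.py reduces to the following standalone round trip. This is a minimal sketch, not part of the patch: the helper names below are hypothetical mirrors of DataEncoder._convert_hex_strings_to_int and _convert_int_to_hex_strings, it assumes 32-character hex IDs, and it omits the per-batch pre-padding done with tf.keras.preprocessing.sequence.pad_sequences.

    def split_long_integer(number: int, segment_size: int = 4) -> list[int]:
      """Splits a long integer into base-10**segment_size segments, least significant first."""
      segments = []
      while number > 0:
        segments.append(number % 10 ** segment_size)
        number //= 10 ** segment_size
      return segments

    def combine_segments(segments: list[int], segment_size: int = 4) -> int:
      """Inverse of split_long_integer: recombines segments into the original integer."""
      number = 0
      for i, segment in enumerate(segments):
        number += segment * 10 ** (i * segment_size)
      return number

    example_id = 'b0b947f423a1c77ac948c76f63fa8209'  # 32-character hex string
    segments = split_long_integer(int(example_id, 16))
    restored = f'{combine_segments(segments):032x}'  # format back as 32 hex digits
    assert restored == example_id

In the DataEncoder itself, the segment count is appended to each sequence before padding so that the decoder can drop the leading pad zeros and recover exactly the original segments.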