diff --git a/examples/train_horovod/resnet50/README.md b/examples/train_horovod/resnet50/README.md
deleted file mode 100644
index 5fd7f1f88..000000000
--- a/examples/train_horovod/resnet50/README.md
+++ /dev/null
@@ -1,92 +0,0 @@
-# Distributed Training Example with Intel® Optimization for Horovod*
-
-## Model Information
-
-| Use Case | Framework | Model Repo | Branch/Commit/Tag | Optional Patch |
-| :---: | :---: | :---: | :---: | :---: |
-| Training | TensorFlow | [Tensorflow-Models](https://github.com/tensorflow/models) | v2.8.0 | itex.yaml<br>itex_dummy.yaml<br>hvd_support_light.patch<br>or hvd_support.patch |
-
-
-
-## Dependency
-- [Tensorflow](https://pypi.org/project/tensorflow/)
-- [Intel® Extension for TensorFlow*](https://pypi.org/project/intel-extension-for-tensorflow/)
-- [Intel® Optimization for Horovod*](https://pypi.org/project/intel-optimization-for-Horovod/)
-- Others, installed as shown below:
-```
-pip install gin gin-config tensorflow-addons tensorflow-model-optimization tensorflow-datasets
-```
-
-## Model examples preparation
-
-### Model Repo
-```
-WORKSPACE=xxxx # set your workspace folder
-cd $WORKSPACE
-git clone -b v2.8.0 https://github.com/tensorflow/models.git tensorflow-models
-cd tensorflow-models
-git apply path/to/hvd_support_light.patch # or path/to/hvd_support.patch
-```
-**hvd_support_light.patch** contains the minimal set of changes.
-- hvd.init(): initializes Horovod, including resource allocation.
-- tf.config.experimental.set_memory_growth(): when memory growth is enabled, runtime initialization does not allocate all memory on the device.
-- tf.config.experimental.set_visible_devices(): sets the list of visible devices.
-- strategy_scope: removes the native distributed strategy.
-- hvd.DistributedOptimizer(): uses the Horovod distributed optimizer.
-- dataset.shard(): multiple workers run the same code but on different data; the dataset is split equally among workers by rank index.
-
-**hvd_support.patch** additionally adds the LARS optimizer ([paper](https://arxiv.org/abs/1708.03888)).
-
-### Download Dataset
-Download the ImageNet dataset from https://image-net.org/download-images.php
-
-
-**Note:** For non-commercial research and/or educational purposes only.
-
-
-
-## Execution
-### Set Model Parameters
-Export these parameters in your script or environment.
-```
-export PYTHONPATH=${WORKSPACE}/tensorflow-models
-MODEL_DIR=${WORKSPACE}/output
-DATA_DIR=${WORKSPACE}/imagenet_data/imagenet
-
-CONFIG_FILE=path/to/itex.yaml
-NUMBER_OF_PROCESS=2
-PROCESS_PER_NODE=2
-```
-- Download `itex.yaml` or `itex_dummy.yaml` and set one of them as CONFIG_FILE; the model then runs with `real data` or `dummy data` accordingly. The default is itex.yaml.
-- Set `NUMBER_OF_PROCESS` and `PROCESS_PER_NODE` according to the number of Horovod ranks you need. The default is a 2-rank task.
-### HVD command
-
-```
-if [ ! -d "$MODEL_DIR" ]; then
- mkdir -p $MODEL_DIR
-else
- rm -rf $MODEL_DIR && mkdir -p $MODEL_DIR
-fi
-
-mpirun -np $NUMBER_OF_PROCESS -ppn $PROCESS_PER_NODE --prepend-rank \
-python ${PYTHONPATH}/official/vision/image_classification/classifier_trainer.py \
---mode=train_and_eval \
---model_type=resnet \
---dataset=imagenet \
---model_dir=$MODEL_DIR \
---data_dir=$DATA_DIR \
---config_file=$CONFIG_FILE
-```
-
-
-
-## Output
-### Performance Data
-```
-[1] I0909 03:33:23.323099 140645511436096 keras_utils.py:145] TimeHistory: xxxx seconds, xxxx examples/second between steps 0 and 100
-[0] I0909 03:33:23.324534 140611700504384 keras_utils.py:145] TimeHistory: xxxx seconds, xxxx examples/second between steps 0 and 100
-[0] I0909 03:33:43.037004 140611700504384 keras_utils.py:145] TimeHistory: xxxx seconds, xxxx examples/second between steps 100 and 200
-[1] I0909 03:33:43.037142 140645511436096 keras_utils.py:145] TimeHistory: xxxx seconds, xxxx examples/second between steps 100 and 200
-[1] I0909 03:34:03.213994 140645511436096 keras_utils.py:145] TimeHistory: xxxx seconds, xxxx examples/second between steps 200 and 300
-[0] I0909 03:34:03.214127 140611700504384 keras_utils.py:145] TimeHistory: xxxx seconds, xxxx examples/second between steps 200 and 300
-```
diff --git a/examples/train_horovod/resnet50/hvd_support_light.patch b/examples/train_horovod/resnet50/hvd_support_light.patch
deleted file mode 100644
index 6cd993952..000000000
--- a/examples/train_horovod/resnet50/hvd_support_light.patch
+++ /dev/null
@@ -1,233 +0,0 @@
-diff --git a/official/vision/image_classification/classifier_trainer.py b/official/vision/image_classification/classifier_trainer.py
-index ab6fbaea9..f0f2cb2c5 100644
---- a/official/vision/image_classification/classifier_trainer.py
-+++ b/official/vision/image_classification/classifier_trainer.py
-@@ -37,6 +37,14 @@ from official.vision.image_classification.efficientnet import efficientnet_model
- from official.vision.image_classification.resnet import common
- from official.vision.image_classification.resnet import resnet_model
-
-+global is_mpi
-+try:
-+ import horovod.tensorflow.keras as hvd
-+ hvd.init()
-+ is_mpi = hvd.size()
-+except ImportError:
-+ is_mpi = 0
-+ print("No MPI horovod support, this is running in no-MPI mode!")
-
- def get_models() -> Mapping[str, tf.keras.Model]:
- """Returns the mapping from model type name to Keras model."""
-@@ -289,6 +297,12 @@ def train_and_eval(
- """Runs the train and eval path using compile/fit."""
- logging.info('Running train and eval.')
-
-+ if is_mpi:
-+ gpus = tf.config.experimental.list_physical_devices('XPU')
-+ for gpu in gpus:
-+ tf.config.experimental.set_memory_growth(gpu, True)
-+ tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'XPU')
-+
- distribute_utils.configure_cluster(params.runtime.worker_hosts,
- params.runtime.task_index)
-
-@@ -299,7 +313,7 @@ def train_and_eval(
- num_gpus=params.runtime.num_gpus,
- tpu_address=params.runtime.tpu)
-
-- strategy_scope = distribute_utils.get_strategy_scope(strategy)
-+ #strategy_scope = distribute_utils.get_strategy_scope(strategy)
-
- logging.info('Detected %d devices.',
- strategy.num_replicas_in_sync if strategy else 1)
-@@ -324,56 +338,74 @@ def train_and_eval(
-
- logging.info('Global batch size: %d', train_builder.global_batch_size)
-
-- with strategy_scope:
-- model_params = params.model.model_params.as_dict()
-- model = get_models()[params.model.name](**model_params)
-- learning_rate = optimizer_factory.build_learning_rate(
-- params=params.model.learning_rate,
-- batch_size=train_builder.global_batch_size,
-- train_epochs=train_epochs,
-- train_steps=train_steps)
-- optimizer = optimizer_factory.build_optimizer(
-- optimizer_name=params.model.optimizer.name,
-- base_learning_rate=learning_rate,
-- params=params.model.optimizer.as_dict(),
-- model=model)
-- optimizer = performance.configure_optimizer(
-- optimizer,
-- use_float16=train_builder.dtype == 'float16',
-- loss_scale=get_loss_scale(params))
--
-- metrics_map = _get_metrics(one_hot)
-- metrics = [metrics_map[metric] for metric in params.train.metrics]
-- steps_per_loop = train_steps if params.train.set_epoch_loop else 1
--
-- if one_hot:
-- loss_obj = tf.keras.losses.CategoricalCrossentropy(
-- label_smoothing=params.model.loss.label_smoothing)
-- else:
-- loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
-- model.compile(
-- optimizer=optimizer,
-- loss=loss_obj,
-- metrics=metrics,
-- steps_per_execution=steps_per_loop)
--
-- initial_epoch = 0
-- if params.train.resume_checkpoint:
-- initial_epoch = resume_from_checkpoint(
-- model=model, model_dir=params.model_dir, train_steps=train_steps)
-+ model_params = params.model.model_params.as_dict()
-+ model = get_models()[params.model.name](**model_params)
-+ learning_rate = optimizer_factory.build_learning_rate(
-+ params=params.model.learning_rate,
-+ batch_size=train_builder.global_batch_size * hvd.size(),
-+ train_epochs=train_epochs,
-+ train_steps=train_steps)
-+ optimizer = optimizer_factory.build_optimizer(
-+ optimizer_name=params.model.optimizer.name,
-+ base_learning_rate=learning_rate,
-+ params=params.model.optimizer.as_dict(),
-+ model=model)
-+ optimizer = performance.configure_optimizer(
-+ optimizer,
-+ use_float16=train_builder.dtype == 'float16',
-+ loss_scale=get_loss_scale(params))
-+
-+ metrics_map = _get_metrics(one_hot)
-+ metrics = [metrics_map[metric] for metric in params.train.metrics]
-+ steps_per_loop = train_steps if params.train.set_epoch_loop else 1
-
-- callbacks = custom_callbacks.get_callbacks(
-- model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
-- include_tensorboard=params.train.callbacks.enable_tensorboard,
-- time_history=params.train.callbacks.enable_time_history,
-- track_lr=params.train.tensorboard.track_lr,
-- write_model_weights=params.train.tensorboard.write_model_weights,
-- initial_step=initial_epoch * train_steps,
-- batch_size=train_builder.global_batch_size,
-- log_steps=params.train.time_history.log_steps,
-- model_dir=params.model_dir,
-- backup_and_restore=params.train.callbacks.enable_backup_and_restore)
-+ if one_hot:
-+ loss_obj = tf.keras.losses.CategoricalCrossentropy(
-+ label_smoothing=params.model.loss.label_smoothing)
-+ else:
-+ loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
-
-+ hvd_optimizer = hvd.DistributedOptimizer(optimizer, num_groups=1)
-+ model.compile(
-+ optimizer=hvd_optimizer,
-+ loss=loss_obj,
-+ metrics=metrics,
-+ steps_per_execution=steps_per_loop)
-+
-+ initial_epoch = 0
-+ if params.train.resume_checkpoint:
-+ initial_epoch = resume_from_checkpoint(
-+ model=model, model_dir=params.model_dir, train_steps=train_steps)
-+
-+ # Add broadcast callback for rank0
-+ callbacks = []
-+
-+ if hvd.local_rank() == 0:
-+ callbacks = custom_callbacks.get_callbacks(
-+ model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
-+ include_tensorboard=params.train.callbacks.enable_tensorboard,
-+ time_history=params.train.callbacks.enable_time_history,
-+ track_lr=params.train.tensorboard.track_lr,
-+ write_model_weights=params.train.tensorboard.write_model_weights,
-+ initial_step=initial_epoch * train_steps,
-+ batch_size=train_builder.global_batch_size,
-+ log_steps=params.train.time_history.log_steps,
-+ model_dir=params.model_dir,
-+ backup_and_restore=params.train.callbacks.enable_backup_and_restore)
-+ else:
-+ callbacks = custom_callbacks.get_callbacks(
-+ model_checkpoint=False,
-+ include_tensorboard=params.train.callbacks.enable_tensorboard,
-+ time_history=params.train.callbacks.enable_time_history,
-+ track_lr=params.train.tensorboard.track_lr,
-+ write_model_weights=params.train.tensorboard.write_model_weights,
-+ initial_step=initial_epoch * train_steps,
-+ batch_size=train_builder.global_batch_size,
-+ log_steps=params.train.time_history.log_steps,
-+ model_dir=params.model_dir,
-+ backup_and_restore=params.train.callbacks.enable_backup_and_restore)
-+
-+ callbacks.append(hvd.callbacks.BroadcastGlobalVariablesCallback(0))
- serialize_config(params=params, model_dir=params.model_dir)
-
- if params.evaluation.skip_eval:
-diff --git a/official/vision/image_classification/dataset_factory.py b/official/vision/image_classification/dataset_factory.py
-index a0458eccc..e6dfb39f6 100644
---- a/official/vision/image_classification/dataset_factory.py
-+++ b/official/vision/image_classification/dataset_factory.py
-@@ -29,6 +29,7 @@ import tensorflow_datasets as tfds
- from official.modeling.hyperparams import base_config
- from official.vision.image_classification import augment
- from official.vision.image_classification import preprocessing
-+import horovod.tensorflow as hvd
-
- AUGMENTERS = {
- 'autoaugment': augment.AutoAugment,
-@@ -207,7 +208,7 @@ class DatasetBuilder:
- def num_steps(self) -> int:
- """The number of steps (batches) to exhaust this dataset."""
- # Always divide by the global batch size to get the correct # of steps
-- return self.num_examples // self.global_batch_size
-+ return self.num_examples // (self.global_batch_size * hvd.size())
-
- @property
- def dtype(self) -> tf.dtypes.DType:
-@@ -403,14 +404,10 @@ class DatasetBuilder:
- Returns:
- A TensorFlow dataset outputting batched images and labels.
- """
-- if (self.config.builder != 'tfds' and self.input_context and
-- self.input_context.num_input_pipelines > 1):
-- dataset = dataset.shard(self.input_context.num_input_pipelines,
-- self.input_context.input_pipeline_id)
-+ if self.is_training:
-+ dataset = dataset.shard(hvd.size(), hvd.rank())
- logging.info(
-- 'Sharding the dataset: input_pipeline_id=%d '
-- 'num_input_pipelines=%d', self.input_context.num_input_pipelines,
-- self.input_context.input_pipeline_id)
-+ 'Sharding the dataset: total size: %d ', hvd.size(), " local rank: %d ", hvd.rank())
-
- if self.is_training and self.config.builder == 'records':
- # Shuffle the input files.
-diff --git a/official/vision/image_classification/learning_rate.py b/official/vision/image_classification/learning_rate.py
-index 72f7e9518..e7edd90a2 100644
---- a/official/vision/image_classification/learning_rate.py
-+++ b/official/vision/image_classification/learning_rate.py
-@@ -22,10 +22,12 @@ from typing import Any, Mapping, Optional
-
- import numpy as np
- import tensorflow as tf
-+from tensorflow.python.util.tf_export import keras_export
-
- BASE_LEARNING_RATE = 0.1
-
-
-+@tf.keras.utils.register_keras_serializable(package='Custom', name='WarmupDeacySchedule')
- class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
- """A wrapper for LearningRateSchedule that includes warmup steps."""
-
-@@ -66,10 +68,11 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
- return lr
-
- def get_config(self) -> Mapping[str, Any]:
-- config = self._lr_schedule.get_config()
-+ config = {}
- config.update({
- "warmup_steps": self._warmup_steps,
- "warmup_lr": self._warmup_lr,
-+ "lr_schedule": self._lr_schedule,
- })
- return config
-
diff --git a/examples/train_resnet50/README.md b/examples/train_resnet50/README.md
new file mode 100644
index 000000000..42efac776
--- /dev/null
+++ b/examples/train_resnet50/README.md
@@ -0,0 +1,130 @@
+# ResNet50 Training on Intel GPU
+
+## Introduction
+
+Intel® Extension for TensorFlow* is compatible with stock TensorFlow*.
+This example shows ResNet50 training.
+
+## Hardware Requirements
+
+Verified Hardware Platforms:
+ - Intel® Data Center GPU Max Series
+
+## Prerequisites
+
+### Model Code Change
+The bf16 optimizations live in `resnet50.patch`; Horovod and LARS support live in `hvd_support.patch`. Clone the model repo first, then apply the appropriate patch as described in [Apply Patch](#apply-patch) below:
+```
+git clone -b v2.8.0 https://github.com/tensorflow/models.git tensorflow-models
+```
+
+### Prepare for GPU (Skip this step for CPU)
+
+Refer to [Prepare](../common_guide_running.md#prepare)
+
+### Setup Running Environment
+
+* Setup for GPU
+```bash
+./pip_set_env.sh
+```
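+
+The TensorFlow-Models reference code also needs a few extra pip packages. This list is carried over from the dependency section of the previous Horovod example and may need adjusting for your setup; install them inside the activated virtual environment:
+
+```bash
+# Extra dependencies used by the tensorflow/models v2.8.0 reference code
+pip install gin gin-config tensorflow-addons tensorflow-model-optimization tensorflow-datasets
+```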
+
+### Enable Running Environment
+
+Enable the oneAPI running environment (GPU only) and the virtual running environment.
+
+ * For GPU, refer to [Running](../common_guide_running.md#running)
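+
+In practice this usually means sourcing the oneAPI environment script and the virtual environment created above; a minimal sketch, assuming the default oneAPI install location:
+
+```bash
+source /opt/intel/oneapi/setvars.sh  # oneAPI runtime (GPU only)
+source ./env_itex/bin/activate       # virtual env created by pip_set_env.sh
+```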
+
+### Apply Patch
+
+#### Without Horovod
+```
+git apply path/to/configure/resnet50.patch
+```
+
+#### With Horovod
+```
+git apply path/to/hvd_configure/hvd_support.patch
+```
+#### Prepare ImageNet dataset
+**Using TFDS**
+
+`classifier_trainer.py` supports ImageNet with [TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/overview).
+
+Please see the following [example snippet](https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/scripts/download_and_prepare.py) for more information on how to use TFDS to download and prepare datasets, and specifically the [TFDS ImageNet readme](https://github.com/tensorflow/datasets/blob/master/docs/catalog/imagenet2012.md) for manual download instructions.
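+
+A minimal sketch of preparing `imagenet2012` with TFDS, assuming the ImageNet tar files were already downloaded manually (all paths below are placeholders):
+
+```bash
+# ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar must already sit in --manual_dir
+python -m tensorflow_datasets.scripts.download_and_prepare \
+  --datasets=imagenet2012 \
+  --data_dir=/the/path/to/tfds_data \
+  --manual_dir=/the/path/to/tfds_data/downloads/manual
+```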
+
+**Legacy TFRecords**
+Download the ImageNet dataset and convert it to TFRecord format. The following [script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py) and [README](https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy) provide a few options.
+
+Note that the legacy ResNet runners, e.g. [resnet/resnet_ctl_imagenet_main.py](https://github.com/tensorflow/models/blob/v2.8.0/official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py), require TFRecords, whereas `classifier_trainer.py` can use either format by setting the builder to 'records' or 'tfds' in the configuration files.
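+
+For the TFRecord route, a hedged example of converting a local ImageNet copy without uploading to GCS (flag names follow the linked `imagenet_to_gcs.py` script; paths are placeholders):
+
+```bash
+python imagenet_to_gcs.py \
+  --raw_data_dir=/the/path/to/imagenet_raw \
+  --local_scratch_dir=/the/path/to/imagenet_tfrecords \
+  --nogcs_upload
+```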
+
+## Execution
+### Set Model Parameters
+There are several config YAML files in the `configure` and `hvd_configure` folders. Set one of them as `CONFIG_FILE`; the model then runs with `real data` or `dummy data` accordingly. For single-tile training, use a YAML file from the `configure` folder. For distributed training, use a YAML file from the `hvd_configure` folder: `itex_bf16_lars.yaml`/`itex_fp32_lars.yaml` for Horovod with real data, and `itex_dummy_bf16_lars.yaml`/`itex_dummy_fp32_lars.yaml` for Horovod with dummy data.
+Export these parameters in your script or environment.
+```
+export PYTHONPATH=/the/path/to/tensorflow-models
+MODEL_DIR=/the/path/to/output
+DATA_DIR=/the/path/to/imagenet
+CONFIG_FILE=path/to/itex_xx.yaml # itex_bf16.yaml/itex_fp32.yaml for accuracy, itex_dummy_bf16.yaml/itex_dummy_fp32.yaml for benchmark
+
+```
+
+### Command
+
+```
+if [ ! -d "$MODEL_DIR" ]; then
+ mkdir -p $MODEL_DIR
+else
+ rm -rf $MODEL_DIR && mkdir -p $MODEL_DIR
+fi
+
+python ${PYTHONPATH}/official/vision/image_classification/classifier_trainer.py \
+--mode=train_and_eval \
+--model_type=resnet \
+--dataset=imagenet \
+--model_dir=$MODEL_DIR \
+--data_dir=$DATA_DIR \
+--config_file=$CONFIG_FILE
+```
+
+### Command with Horovod
+Set `NUMBER_OF_PROCESS` and `PROCESS_PER_NODE` according to the number of Horovod ranks you need. The default below is a 2-rank task.
+
+```
+if [ ! -d "$MODEL_DIR" ]; then
+ mkdir -p $MODEL_DIR
+else
+ rm -rf $MODEL_DIR && mkdir -p $MODEL_DIR
+fi
+
+NUMBER_OF_PROCESS=2
+PROCESS_PER_NODE=2
+
+mpirun -np $NUMBER_OF_PROCESS -ppn $PROCESS_PER_NODE --prepend-rank \
+python ${PYTHONPATH}/official/vision/image_classification/classifier_trainer.py \
+--mode=train_and_eval \
+--model_type=resnet \
+--dataset=imagenet \
+--model_dir=$MODEL_DIR \
+--data_dir=$DATA_DIR \
+--config_file=$CONFIG_FILE
+```
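+
+For multi-node runs, the same launch line can be pointed at a host file. A sketch using Intel MPI's hydra syntax (adjust `-f hostfile` and the rank counts for your cluster):
+
+```bash
+# 4 ranks over the nodes listed in ./hostfile, 2 ranks per node
+mpirun -np 4 -ppn 2 -f hostfile --prepend-rank \
+python ${PYTHONPATH}/official/vision/image_classification/classifier_trainer.py \
+--mode=train_and_eval --model_type=resnet --dataset=imagenet \
+--model_dir=$MODEL_DIR --data_dir=$DATA_DIR --config_file=$CONFIG_FILE
+```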
+
+## Example Output without Horovod
+```
+I0203 02:48:01.006297 139660941027136 keras_utils.py:145] TimeHistory: xx seconds, xxxx examples/second between steps 1900 and 2000
+I0203 02:48:16.590331 139660941027136 keras_utils.py:145] TimeHistory: xx seconds, xxxx examples/second between steps 2000 and 2100
+I0203 02:48:32.178206 139660941027136 keras_utils.py:145] TimeHistory: xx seconds, xxxx examples/second between steps 2100 and 2200
+I0203 02:48:47.790128 139660941027136 keras_utils.py:145] TimeHistory: xx seconds, xxxx examples/second between steps 2200 and 2300
+I0203 02:49:03.408512 139660941027136 keras_utils.py:145] TimeHistory: xx seconds, xxxx examples/second between steps 2300 and 2400
+```
+## Example Output with Horovod
+```
+[0] I0817 00:09:07.602742 139898862851904 keras_utils.py:145] TimeHistory: xx seconds, xxxx examples/second between steps 400 and 600
+[1] I0817 00:09:07.603262 140612319840064 keras_utils.py:145] TimeHistory: xx seconds, xxxx examples/second between steps 400 and 600
+[0] I0817 00:10:07.917546 139898862851904 keras_utils.py:145] TimeHistory: xx seconds, xxxx examples/second between steps 600 and 800
+[1] I0817 00:10:07.917738 140612319840064 keras_utils.py:145] TimeHistory: xx seconds, xxxx examples/second between steps 600 and 800
+[0] I0817 00:11:08.277716 139898862851904 keras_utils.py:145] TimeHistory: xx seconds, xxxx examples/second between steps 800 and 1000
+[1] I0817 00:11:08.277811 140612319840064 keras_utils.py:145] TimeHistory: xx seconds, xxxx examples/second between steps 800 and 1000
+[0] I0817 00:12:08.555174 139898862851904 keras_utils.py:145] TimeHistory: xx seconds, xxxx examples/second between steps 1000 and 1200
+[1] I0817 00:12:08.555221 140612319840064 keras_utils.py:145] TimeHistory: xx seconds, xxxx examples/second between steps 1000 and 1200
\ No newline at end of file
diff --git a/examples/train_horovod/resnet50/itex.yaml b/examples/train_resnet50/configure/itex_bf16.yaml
similarity index 100%
rename from examples/train_horovod/resnet50/itex.yaml
rename to examples/train_resnet50/configure/itex_bf16.yaml
diff --git a/examples/train_horovod/resnet50/itex_dummy.yaml b/examples/train_resnet50/configure/itex_dummy_bf16.yaml
similarity index 98%
rename from examples/train_horovod/resnet50/itex_dummy.yaml
rename to examples/train_resnet50/configure/itex_dummy_bf16.yaml
index d826f5a88..f8c678026 100644
--- a/examples/train_horovod/resnet50/itex_dummy.yaml
+++ b/examples/train_resnet50/configure/itex_dummy_bf16.yaml
@@ -50,6 +50,6 @@ train:
callbacks:
enable_checkpoint_and_export: True
resume_checkpoint: True
- epochs: 90
+ epochs: 1
evaluation:
epochs_between_evals: 1
diff --git a/examples/train_resnet50/configure/itex_dummy_fp32.yaml b/examples/train_resnet50/configure/itex_dummy_fp32.yaml
new file mode 100644
index 000000000..d360c41c9
--- /dev/null
+++ b/examples/train_resnet50/configure/itex_dummy_fp32.yaml
@@ -0,0 +1,55 @@
+# Training configuration for ResNet trained on ImageNet on TPUs.
+# Takes ~4 minutes, 30 seconds per epoch for a v3-32.
+# Reaches > 76.1% within 90 epochs.
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'mirrored'
+ num_gpus: 1
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: None
+ builder: 'synthetic'
+ split: 'train'
+ one_hot: False
+ image_size: 224
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 256
+ use_per_replica_batch_size: True
+ mean_subtract: False
+ standardize: False
+ dtype: 'float32'
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: None
+ builder: 'synthetic'
+ split: 'validation'
+ one_hot: False
+ image_size: 224
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 256
+ use_per_replica_batch_size: True
+ mean_subtract: False
+ standardize: False
+ dtype: 'float32'
+model:
+ name: 'resnet'
+ model_params:
+ rescale_inputs: True
+ optimizer:
+ name: 'momentum'
+ momentum: 0.9
+ decay: 0.9
+ epsilon: 0.001
+ moving_average_decay: 0.
+ lookahead: False
+ loss:
+ label_smoothing: 0.1
+train:
+ callbacks:
+ enable_checkpoint_and_export: True
+ resume_checkpoint: True
+ epochs: 1
+evaluation:
+ epochs_between_evals: 1
diff --git a/examples/train_resnet50/configure/itex_fp32.yaml b/examples/train_resnet50/configure/itex_fp32.yaml
new file mode 100644
index 000000000..27c0b707d
--- /dev/null
+++ b/examples/train_resnet50/configure/itex_fp32.yaml
@@ -0,0 +1,55 @@
+# Training configuration for ResNet trained on ImageNet on TPUs.
+# Takes ~4 minutes, 30 seconds per epoch for a v3-32.
+# Reaches > 76.1% within 90 epochs.
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'mirrored'
+ num_gpus: 1
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: None
+ builder: 'records'
+ split: 'train'
+ one_hot: False
+ image_size: 224
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 256
+ use_per_replica_batch_size: True
+ mean_subtract: False
+ standardize: False
+ dtype: 'float32'
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: None
+ builder: 'records'
+ split: 'validation'
+ one_hot: False
+ image_size: 224
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 256
+ use_per_replica_batch_size: True
+ mean_subtract: False
+ standardize: False
+ dtype: 'float32'
+model:
+ name: 'resnet'
+ model_params:
+ rescale_inputs: True
+ optimizer:
+ name: 'momentum'
+ momentum: 0.9
+ decay: 0.9
+ epsilon: 0.001
+ moving_average_decay: 0.
+ lookahead: False
+ loss:
+ label_smoothing: 0.1
+train:
+ callbacks:
+ enable_checkpoint_and_export: True
+ resume_checkpoint: True
+ epochs: 90
+evaluation:
+ epochs_between_evals: 1
diff --git a/examples/train_resnet50/configure/resnet50.patch b/examples/train_resnet50/configure/resnet50.patch
new file mode 100644
index 000000000..d5f946aad
--- /dev/null
+++ b/examples/train_resnet50/configure/resnet50.patch
@@ -0,0 +1,43 @@
+diff --git a/official/vision/image_classification/augment.py b/official/vision/image_classification/augment.py
+index f322d31da..294fcd543 100644
+--- a/official/vision/image_classification/augment.py
++++ b/official/vision/image_classification/augment.py
+@@ -25,7 +25,10 @@ from __future__ import print_function
+ import math
+ from typing import Any, Dict, List, Optional, Text, Tuple
+
+-from keras.layers.preprocessing import image_preprocessing as image_ops
++try:
++ from keras.src.layers.preprocessing import image_preprocessing as image_ops
++except ImportError:
++ from keras.layers.preprocessing import image_preprocessing as image_ops
+ import tensorflow as tf
+
+
+diff --git a/official/vision/image_classification/classifier_trainer.py b/official/vision/image_classification/classifier_trainer.py
+index ab6fbaea9..0347bc8c9 100644
+--- a/official/vision/image_classification/classifier_trainer.py
++++ b/official/vision/image_classification/classifier_trainer.py
+@@ -283,6 +283,14 @@ def serialize_config(params: base_configs.ExperimentConfig, model_dir: str):
+ hyperparams.save_params_dict_to_yaml(params, params_save_path)
+
+
++class dummy_context:
++ def __init__(self):
++ pass
++ def __enter__(self):
++ pass
++ def __exit__(self, exc_type, exc_value, traceback):
++ pass
++
+ def train_and_eval(
+ params: base_configs.ExperimentConfig,
+ strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
+@@ -323,6 +331,7 @@ def train_and_eval(
+ initialize(params, train_builder)
+
+ logging.info('Global batch size: %d', train_builder.global_batch_size)
++ strategy_scope = dummy_context()
+
+ with strategy_scope:
+ model_params = params.model.model_params.as_dict()
diff --git a/examples/train_horovod/resnet50/hvd_support.patch b/examples/train_resnet50/hvd_configure/hvd_support.patch
similarity index 75%
rename from examples/train_horovod/resnet50/hvd_support.patch
rename to examples/train_resnet50/hvd_configure/hvd_support.patch
index a79c71129..860d2adfc 100644
--- a/examples/train_horovod/resnet50/hvd_support.patch
+++ b/examples/train_resnet50/hvd_configure/hvd_support.patch
@@ -1,17 +1,47 @@
+diff --git a/official/legacy/image_classification/callbacks.py b/official/legacy/image_classification/callbacks.py
+index a4934ed88..a7eaad9fb 100644
+--- a/official/legacy/image_classification/callbacks.py
++++ b/official/legacy/image_classification/callbacks.py
+@@ -149,8 +149,8 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
+ def _calculate_metrics(self) -> MutableMapping[str, Any]:
+ logs = {}
+ # TODO(b/149030439): disable LR reporting.
+- # if self._track_lr:
+- # logs['learning_rate'] = self._calculate_lr()
++ if self._track_lr:
++ logs['learning_rate'] = self._calculate_lr()
+ return logs
+
+ def _calculate_lr(self) -> int:
+diff --git a/official/vision/image_classification/augment.py b/official/vision/image_classification/augment.py
+index f322d31da..e668ff731 100644
+--- a/official/vision/image_classification/augment.py
++++ b/official/vision/image_classification/augment.py
+@@ -25,7 +25,10 @@ from __future__ import print_function
+ import math
+ from typing import Any, Dict, List, Optional, Text, Tuple
+
+-from keras.layers.preprocessing import image_preprocessing as image_ops
++try:
++ from keras.src.layers.preprocessing import image_preprocessing as image_ops
++except ImportError:
++ from keras.layers.preprocessing import image_preprocessing as image_ops
+ import tensorflow as tf
+
+
diff --git a/official/vision/image_classification/callbacks.py b/official/vision/image_classification/callbacks.py
-index a4934ed88..2befe5e18 100644
+index a4934ed88..7eaaf0bea 100644
--- a/official/vision/image_classification/callbacks.py
+++ b/official/vision/image_classification/callbacks.py
-@@ -78,6 +78,8 @@ def get_callbacks(
+@@ -78,6 +78,7 @@ def get_callbacks(
save_weights_only=True,
verbose=1))
callbacks.append(MovingAverageCallback())
-+ callbacks.append(ThresholdStopping(
-+ monitor = "val_accuracy", threshold = 0.759))
++ #callbacks.append(ThresholdStopping(monitor = "val_accuracy", threshold = 0.759))
return callbacks
-@@ -254,3 +256,38 @@ class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
+@@ -254,3 +255,38 @@ class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
result = super()._save_model(epoch, logs) # pytype: disable=attribute-error # typed-keras
self.model.set_weights(non_avg_weights)
return result
@@ -51,12 +81,14 @@ index a4934ed88..2befe5e18 100644
+ return monitor_value
+
diff --git a/official/vision/image_classification/classifier_trainer.py b/official/vision/image_classification/classifier_trainer.py
-index ab6fbaea9..f0f2cb2c5 100644
+index ab6fbaea9..4655da449 100644
--- a/official/vision/image_classification/classifier_trainer.py
+++ b/official/vision/image_classification/classifier_trainer.py
-@@ -37,6 +37,14 @@ from official.vision.image_classification.efficientnet import efficientnet_model
+@@ -36,7 +36,16 @@ from official.vision.image_classification.configs import configs
+ from official.vision.image_classification.efficientnet import efficientnet_model
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import resnet_model
++import time
+global is_mpi
+try:
@@ -69,7 +101,16 @@ index ab6fbaea9..f0f2cb2c5 100644
def get_models() -> Mapping[str, tf.keras.Model]:
"""Returns the mapping from model type name to Keras model."""
-@@ -289,6 +297,12 @@ def train_and_eval(
+@@ -271,7 +280,7 @@ def define_classifier_flags():
+ help='The name of the dataset, e.g. ImageNet, etc.')
+ flags.DEFINE_integer(
+ 'log_steps',
+- default=100,
++ default=200,
+ help='The interval of steps between logging of batch level stats.')
+
+
+@@ -289,6 +298,12 @@ def train_and_eval(
"""Runs the train and eval path using compile/fit."""
logging.info('Running train and eval.')
@@ -82,7 +123,7 @@ index ab6fbaea9..f0f2cb2c5 100644
distribute_utils.configure_cluster(params.runtime.worker_hosts,
params.runtime.task_index)
-@@ -299,7 +313,7 @@ def train_and_eval(
+@@ -299,7 +314,7 @@ def train_and_eval(
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
@@ -91,7 +132,7 @@ index ab6fbaea9..f0f2cb2c5 100644
logging.info('Detected %d devices.',
strategy.num_replicas_in_sync if strategy else 1)
-@@ -324,56 +338,74 @@ def train_and_eval(
+@@ -324,56 +339,74 @@ def train_and_eval(
logging.info('Global batch size: %d', train_builder.global_batch_size)
@@ -169,14 +210,14 @@ index ab6fbaea9..f0f2cb2c5 100644
+ label_smoothing=params.model.loss.label_smoothing)
+ else:
+ loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
-
++
+ hvd_optimizer = hvd.DistributedOptimizer(optimizer, num_groups=1)
+ model.compile(
+ optimizer=hvd_optimizer,
+ loss=loss_obj,
+ metrics=metrics,
+ steps_per_execution=steps_per_loop)
-+
+
+ initial_epoch = 0
+ if params.train.resume_checkpoint:
+ initial_epoch = resume_from_checkpoint(
@@ -214,6 +255,29 @@ index ab6fbaea9..f0f2cb2c5 100644
serialize_config(params=params, model_dir=params.model_dir)
if params.evaluation.skip_eval:
+@@ -384,7 +417,9 @@ def train_and_eval(
+ 'validation_steps': validation_steps,
+ 'validation_freq': params.evaluation.epochs_between_evals,
+ }
+-
++ print('[info] Training steps = ', train_steps)
++ print('[info] Validation steps = ', validation_steps)
++ global_start_time = time.time()
+ history = model.fit(
+ train_dataset,
+ epochs=train_epochs,
+@@ -394,6 +429,11 @@ def train_and_eval(
+ verbose=2,
+ **validation_kwargs)
+
++ global_end_time = time.time()
++ print('[info] Global start time = ', time.asctime(time.localtime(global_start_time)))
++ print('[info] Global end time = ', time.asctime(time.localtime(global_end_time)))
++ print('[info] Global consume time = ', ((global_end_time - global_start_time) / (60.0)), ' mins')
++
+ validation_output = None
+ if not params.evaluation.skip_eval:
+ validation_output = model.evaluate(
diff --git a/official/vision/image_classification/configs/base_configs.py b/official/vision/image_classification/configs/base_configs.py
index 760b3dce0..3939249b6 100644
--- a/official/vision/image_classification/configs/base_configs.py
@@ -227,7 +291,7 @@ index 760b3dce0..3939249b6 100644
@dataclasses.dataclass
diff --git a/official/vision/image_classification/dataset_factory.py b/official/vision/image_classification/dataset_factory.py
-index a0458eccc..9ff76333b 100644
+index a0458eccc..275c6d5fb 100644
--- a/official/vision/image_classification/dataset_factory.py
+++ b/official/vision/image_classification/dataset_factory.py
@@ -29,6 +29,7 @@ import tensorflow_datasets as tfds
@@ -238,16 +302,25 @@ index a0458eccc..9ff76333b 100644
AUGMENTERS = {
'autoaugment': augment.AutoAugment,
-@@ -207,7 +208,7 @@ class DatasetBuilder:
+@@ -207,7 +208,16 @@ class DatasetBuilder:
def num_steps(self) -> int:
"""The number of steps (batches) to exhaust this dataset."""
# Always divide by the global batch size to get the correct # of steps
- return self.num_examples // self.global_batch_size
-+ return self.num_examples // (self.global_batch_size * hvd.size())
++ distributed_size = 1
++ if self.config.split == 'train':
++ distributed_size = hvd.size()
++ divide_steps = self.num_examples // (self.global_batch_size * distributed_size)
++ remain_steps = self.num_examples % (self.global_batch_size * distributed_size)
++ if remain_steps == 0:
++ return divide_steps
++ else:
++ return divide_steps + 1
++ #return self.num_examples // (self.global_batch_size * hvd.size())
@property
def dtype(self) -> tf.dtypes.DType:
-@@ -403,14 +404,9 @@ class DatasetBuilder:
+@@ -403,14 +413,10 @@ class DatasetBuilder:
Returns:
A TensorFlow dataset outputting batched images and labels.
"""
@@ -255,22 +328,35 @@ index a0458eccc..9ff76333b 100644
- self.input_context.num_input_pipelines > 1):
- dataset = dataset.shard(self.input_context.num_input_pipelines,
- self.input_context.input_pipeline_id)
-- logging.info(
++ if self.is_training:
++ dataset = dataset.shard(hvd.size(), hvd.rank())
+ logging.info(
- 'Sharding the dataset: input_pipeline_id=%d '
- 'num_input_pipelines=%d', self.input_context.num_input_pipelines,
- self.input_context.input_pipeline_id)
-+ if self.is_training:
-+ dataset = dataset.shard(hvd.size(), hvd.rank())
-+ print("Sharding the dataset: total size: ", hvd.size(), " local rank: ", hvd.rank())
++ 'Sharding the dataset: total size: %d, local rank: %d', hvd.size(), hvd.rank())
if self.is_training and self.config.builder == 'records':
# Shuffle the input files.
+@@ -455,10 +461,10 @@ class DatasetBuilder:
+ # replicas automatically when strategy.distribute_datasets_from_function
+ # is called, so we use local batch size here.
+ dataset = dataset.batch(
+- self.local_batch_size, drop_remainder=self.is_training)
++ self.local_batch_size, drop_remainder=False)
+ else:
+ dataset = dataset.batch(
+- self.global_batch_size, drop_remainder=self.is_training)
++ self.global_batch_size, drop_remainder=False)
+
+ # Prefetch overlaps in-feed with training
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
diff --git a/official/vision/image_classification/lars_optimizer.py b/official/vision/image_classification/lars_optimizer.py
new file mode 100644
-index 000000000..e54a9738c
+index 000000000..029ae654e
--- /dev/null
+++ b/official/vision/image_classification/lars_optimizer.py
-@@ -0,0 +1,239 @@
+@@ -0,0 +1,248 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
@@ -304,13 +390,21 @@ index 000000000..e54a9738c
+# from tensorflow.python.keras import backend_config
+# from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+tf_minor = int(tf.__version__.split('.')[1])
-+if tf_minor >= 7:
-+ if tf_minor >= 12:
-+ from keras.optimizers.legacy import optimizer_v2
-+ elif tf_minor >= 9:
-+ from keras.optimizers.optimizer_v2 import optimizer_v2
-+ else:
-+ from keras.optimizer_v2 import optimizer_v2
++if tf_minor >=13:
++ from keras.src.optimizers.legacy import optimizer_v2
++ from keras.src import backend_config
++elif tf_minor >= 12:
++ from keras.optimizers.legacy import optimizer_v2
++ from keras import backend_config
++elif tf_minor >= 9:
++ from keras.optimizers.optimizer_v2 import optimizer_v2
++ from keras import backend_config
++elif tf_minor >= 6:
++ from keras.optimizer_v2 import optimizer_v2
++ from keras import backend_config
++else:
++ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
++ from tensorflow.python.keras import backend_config
+
+# class LARSOptimizer(optimizer_v2modified.OptimizerV2Modified):
+class LARSOptimizer(optimizer_v2.OptimizerV2):
@@ -508,10 +602,11 @@ index 000000000..e54a9738c
+ "eeta": self.eeta,
+ "epsilon": self.epsilon,
+ "use_nesterov": self.use_nesterov,
++ "skip_list": self._skip_list,
+ })
+ return config
diff --git a/official/vision/image_classification/learning_rate.py b/official/vision/image_classification/learning_rate.py
-index 72f7e9518..bb5621f04 100644
+index 72f7e9518..19207291d 100644
--- a/official/vision/image_classification/learning_rate.py
+++ b/official/vision/image_classification/learning_rate.py
@@ -22,10 +22,14 @@ from typing import Any, Mapping, Optional
@@ -621,9 +716,9 @@ index 72f7e9518..bb5621f04 100644
+
+ def _get_learning_rate(self, step):
+ with ops.name_scope_v2(self.name or 'PolynomialDecayWithWarmup') as name:
-+ initial_learning_rate = ops.convert_to_tensor_v2(
++ initial_learning_rate = ops.convert_to_tensor(
+ self.initial_learning_rate, name='initial_learning_rate')
-+ warmup_steps = ops.convert_to_tensor_v2(
++ warmup_steps = ops.convert_to_tensor(
+ self.warmup_steps, name='warmup_steps')
+ step = tf.cast(step, tf.float32)
+ warmup_rate = (
@@ -648,7 +743,7 @@ index 72f7e9518..bb5621f04 100644
+ }
+
diff --git a/official/vision/image_classification/optimizer_factory.py b/official/vision/image_classification/optimizer_factory.py
-index 48a4512ee..5f98744aa 100644
+index 48a4512ee..0ed4cb914 100644
--- a/official/vision/image_classification/optimizer_factory.py
+++ b/official/vision/image_classification/optimizer_factory.py
@@ -26,6 +26,7 @@ import tensorflow_addons as tfa
@@ -659,6 +754,39 @@ index 48a4512ee..5f98744aa 100644
# pylint: disable=protected-access
+@@ -61,12 +62,12 @@ def build_optimizer(
+ if optimizer_name == 'sgd':
+ logging.info('Using SGD optimizer')
+ nesterov = params.get('nesterov', False)
+- optimizer = tf.keras.optimizers.SGD(
++ optimizer = tf.keras.optimizers.legacy.SGD(
+ learning_rate=base_learning_rate, nesterov=nesterov)
+ elif optimizer_name == 'momentum':
+ logging.info('Using momentum optimizer')
+ nesterov = params.get('nesterov', False)
+- optimizer = tf.keras.optimizers.SGD(
++ optimizer = tf.keras.optimizers.legacy.SGD(
+ learning_rate=base_learning_rate,
+ momentum=params['momentum'],
+ nesterov=nesterov)
+@@ -75,7 +76,7 @@ def build_optimizer(
+ rho = params.get('decay', None) or params.get('rho', 0.9)
+ momentum = params.get('momentum', 0.9)
+ epsilon = params.get('epsilon', 1e-07)
+- optimizer = tf.keras.optimizers.RMSprop(
++ optimizer = tf.keras.optimizers.legacy.RMSprop(
+ learning_rate=base_learning_rate,
+ rho=rho,
+ momentum=momentum,
+@@ -85,7 +86,7 @@ def build_optimizer(
+ beta_1 = params.get('beta_1', 0.9)
+ beta_2 = params.get('beta_2', 0.999)
+ epsilon = params.get('epsilon', 1e-07)
+- optimizer = tf.keras.optimizers.Adam(
++ optimizer = tf.keras.optimizers.legacy.Adam(
+ learning_rate=base_learning_rate,
+ beta_1=beta_1,
+ beta_2=beta_2,
@@ -102,6 +103,16 @@ def build_optimizer(
beta_1=beta_1,
beta_2=beta_2,
@@ -676,6 +804,15 @@ index 48a4512ee..5f98744aa 100644
else:
raise ValueError('Unknown optimizer %s' % optimizer_name)
+@@ -139,7 +150,7 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
+ else:
+ warmup_steps = 0
+
+- lr_multiplier = params.scale_by_batch_size
++ lr_multiplier = 0 #params.scale_by_batch_size
+
+ if lr_multiplier and lr_multiplier > 0:
+ # Scale the learning rate based on the batch size and a multiplier
@@ -172,6 +183,14 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
batch_size=batch_size,
total_steps=train_epochs * train_steps,
@@ -691,3 +828,38 @@ index 48a4512ee..5f98744aa 100644
if warmup_steps > 0:
if decay_type not in ['cosine_with_warmup']:
logging.info('Applying %d warmup steps to the learning rate',
+diff --git a/official/vision/image_classification/resnet/imagenet_preprocessing.py b/official/vision/image_classification/resnet/imagenet_preprocessing.py
+index 86ba3ed98..298a4a412 100644
+--- a/official/vision/image_classification/resnet/imagenet_preprocessing.py
++++ b/official/vision/image_classification/resnet/imagenet_preprocessing.py
+@@ -113,7 +113,7 @@ def process_record_dataset(dataset,
+ dataset = dataset.map(
+ lambda value: parse_record_fn(value, is_training, dtype),
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+- dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
++ dataset = dataset.batch(batch_size, drop_remainder=False)
+
+ # Operations between the final prefetch and the get_next call to the iterator
+ # will happen synchronously during run time. We prefetch here again to
+@@ -350,7 +350,7 @@ def input_fn(is_training,
+ parse_record_fn=parse_record_fn,
+ dtype=dtype,
+ datasets_num_private_threads=datasets_num_private_threads,
+- drop_remainder=drop_remainder,
++ drop_remainder=False,
+ tf_data_experimental_slack=tf_data_experimental_slack,
+ )
+
+diff --git a/official/vision/image_classification/resnet/resnet_runnable.py b/official/vision/image_classification/resnet/resnet_runnable.py
+index fe3059f77..c521e5992 100644
+--- a/official/vision/image_classification/resnet/resnet_runnable.py
++++ b/official/vision/image_classification/resnet/resnet_runnable.py
+@@ -100,7 +100,7 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
+ datasets_num_private_threads=self.flags_obj
+ .datasets_num_private_threads,
+ dtype=self.dtype,
+- drop_remainder=True)
++ drop_remainder=False)
+ orbit.StandardTrainer.__init__(
+ self,
+ train_dataset,
diff --git a/examples/train_resnet50/hvd_configure/itex_bf16_lars.yaml b/examples/train_resnet50/hvd_configure/itex_bf16_lars.yaml
new file mode 100644
index 000000000..4930b7d42
--- /dev/null
+++ b/examples/train_resnet50/hvd_configure/itex_bf16_lars.yaml
@@ -0,0 +1,55 @@
+evaluation:
+ epochs_between_evals: 1
+model:
+ learning_rate:
+ initial_lr: 10.8
+ name: polynomial
+ warmup_epochs: 5
+ loss:
+ label_smoothing: 0.1
+ model_params:
+ rescale_inputs: true
+ name: resnet
+ optimizer:
+ epsilon: 0
+ lookahead: false
+ momentum: 0.9
+ name: lars
+ weight_decay: 1.25e-05
+model_dir: models
+runtime:
+ distribution_strategy: mirrored
+ num_gpus: 1
+train:
+ callbacks:
+ enable_checkpoint_and_export: true
+ epochs: 42
+ resume_checkpoint: true
+train_dataset:
+ batch_size: 512
+ builder: records
+ data_dir: None
+ dtype: bfloat16
+ image_size: 224
+ mean_subtract: false
+ name: imagenet2012
+ num_classes: 1000
+ num_examples: 1281167
+ one_hot: false
+ split: train
+ standardize: false
+ use_per_replica_batch_size: true
+validation_dataset:
+ batch_size: 512
+ builder: records
+ data_dir: None
+ dtype: bfloat16
+ image_size: 224
+ mean_subtract: false
+ name: imagenet2012
+ num_classes: 1000
+ num_examples: 50000
+ one_hot: false
+ split: validation
+ standardize: false
+ use_per_replica_batch_size: true
diff --git a/examples/train_resnet50/hvd_configure/itex_dummy_bf16_lars.yaml b/examples/train_resnet50/hvd_configure/itex_dummy_bf16_lars.yaml
new file mode 100644
index 000000000..95fd678d9
--- /dev/null
+++ b/examples/train_resnet50/hvd_configure/itex_dummy_bf16_lars.yaml
@@ -0,0 +1,55 @@
+evaluation:
+ epochs_between_evals: 1
+model:
+ learning_rate:
+ initial_lr: 10.8
+ name: polynomial
+ warmup_epochs: 5
+ loss:
+ label_smoothing: 0.1
+ model_params:
+ rescale_inputs: true
+ name: resnet
+ optimizer:
+ epsilon: 0
+ lookahead: false
+ momentum: 0.9
+ name: lars
+ weight_decay: 1.25e-05
+model_dir: models
+runtime:
+ distribution_strategy: mirrored
+ num_gpus: 1
+train:
+ callbacks:
+ enable_checkpoint_and_export: true
+ epochs: 1
+ resume_checkpoint: true
+train_dataset:
+ batch_size: 512
+ builder: synthetic
+ data_dir: None
+ dtype: bfloat16
+ image_size: 224
+ mean_subtract: false
+ name: imagenet2012
+ num_classes: 1000
+ num_examples: 1281167
+ one_hot: false
+ split: train
+ standardize: false
+ use_per_replica_batch_size: true
+validation_dataset:
+ batch_size: 512
+ builder: synthetic
+ data_dir: None
+ dtype: bfloat16
+ image_size: 224
+ mean_subtract: false
+ name: imagenet2012
+ num_classes: 1000
+ num_examples: 50000
+ one_hot: false
+ split: validation
+ standardize: false
+ use_per_replica_batch_size: true
diff --git a/examples/train_resnet50/hvd_configure/itex_dummy_fp32_lars.yaml b/examples/train_resnet50/hvd_configure/itex_dummy_fp32_lars.yaml
new file mode 100644
index 000000000..7577d12c6
--- /dev/null
+++ b/examples/train_resnet50/hvd_configure/itex_dummy_fp32_lars.yaml
@@ -0,0 +1,55 @@
+evaluation:
+ epochs_between_evals: 2
+model:
+ learning_rate:
+ initial_lr: 9.0
+ name: polynomial
+ warmup_epochs: 3
+ loss:
+ label_smoothing: 0.1
+ model_params:
+ rescale_inputs: true
+ name: resnet
+ optimizer:
+ epsilon: 0
+ lookahead: false
+ momentum: 0.9
+ name: lars
+ weight_decay: 1.25e-05
+model_dir: models
+runtime:
+ distribution_strategy: mirrored
+ num_gpus: 1
+train:
+ callbacks:
+ enable_checkpoint_and_export: true
+ epochs: 1
+ resume_checkpoint: true
+train_dataset:
+ batch_size: 256
+ builder: synthetic
+ data_dir: None
+ dtype: float32
+ image_size: 224
+ mean_subtract: false
+ name: imagenet2012
+ num_classes: 1000
+ num_examples: 1281167
+ one_hot: false
+ split: train
+ standardize: false
+ use_per_replica_batch_size: true
+validation_dataset:
+ batch_size: 256
+ builder: synthetic
+ data_dir: None
+ dtype: float32
+ image_size: 224
+ mean_subtract: false
+ name: imagenet2012
+ num_classes: 1000
+ num_examples: 50000
+ one_hot: false
+ split: validation
+ standardize: false
+ use_per_replica_batch_size: true
diff --git a/examples/train_resnet50/hvd_configure/itex_fp32_lars.yaml b/examples/train_resnet50/hvd_configure/itex_fp32_lars.yaml
new file mode 100644
index 000000000..18ed07b16
--- /dev/null
+++ b/examples/train_resnet50/hvd_configure/itex_fp32_lars.yaml
@@ -0,0 +1,55 @@
+evaluation:
+ epochs_between_evals: 2
+model:
+ learning_rate:
+ initial_lr: 9.0
+ name: polynomial
+ warmup_epochs: 3
+ loss:
+ label_smoothing: 0.1
+ model_params:
+ rescale_inputs: true
+ name: resnet
+ optimizer:
+ epsilon: 0
+ lookahead: false
+ momentum: 0.9
+ name: lars
+ weight_decay: 1.25e-05
+model_dir: models
+runtime:
+ distribution_strategy: mirrored
+ num_gpus: 1
+train:
+ callbacks:
+ enable_checkpoint_and_export: true
+ epochs: 39
+ resume_checkpoint: true
+train_dataset:
+ batch_size: 256
+ builder: records
+ data_dir: None
+ dtype: float32
+ image_size: 224
+ mean_subtract: false
+ name: imagenet2012
+ num_classes: 1000
+ num_examples: 1281167
+ one_hot: false
+ split: train
+ standardize: false
+ use_per_replica_batch_size: true
+validation_dataset:
+ batch_size: 512
+ builder: records
+ data_dir: None
+ dtype: float32
+ image_size: 224
+ mean_subtract: false
+ name: imagenet2012
+ num_classes: 1000
+ num_examples: 50000
+ one_hot: false
+ split: validation
+ standardize: false
+ use_per_replica_batch_size: true
diff --git a/examples/train_resnet50/pip_set_env.sh b/examples/train_resnet50/pip_set_env.sh
new file mode 100755
index 000000000..221307782
--- /dev/null
+++ b/examples/train_resnet50/pip_set_env.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+#
+# Copyright (c) 2021-2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+ENV_NAME=env_itex
+deactivate 2>/dev/null  # ignore the error if no virtual environment is active
+rm -rf $ENV_NAME
+python -m venv $ENV_NAME
+source $ENV_NAME/bin/activate
+pip install --upgrade pip
+pip install scikit-image
+pip install --upgrade intel-extension-for-tensorflow[xpu]
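+
+# To reuse this environment in a later shell (hypothetical follow-up, not part of the setup itself):
+#   source env_itex/bin/activate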