Showing 12 changed files with 478 additions and 13 deletions.
@@ -0,0 +1,49 @@
# Keras 3 Overview

[Keras](https://keras.io/about/) is a deep learning API written in Python that can run on top of JAX, TensorFlow, or PyTorch. Both the JAX and TensorFlow backends compile the model with XLA, which generally delivers the best training and prediction performance on GPU, although results vary from model to model and non-XLA TensorFlow is occasionally faster on GPU. The following image shows how ITEX works with XLA, the Keras 3 TensorFlow backend, and legacy Keras.

<p align="center">
  <img src="images/keras3.png" alt="keras3" />
</p>

## Use Cases with Different Performance

There are several use cases that can lead to different performance; illustrative configuration sketches for each case follow the list below.

* Default
  If users use Keras 3 and the model supports JIT, the model will run with XLA.
  If the user script does not contain Keras-related code and does not enable XLA in TensorFlow, there will be a performance regression; set the environment variable `ITEX_DISABLE_XLA=1` to avoid it. Once ITEX XLA is disabled, users can choose whether to use NPD (the default) or the stream executor for better performance via the environment variable `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE`.

* Legacy Keras
  To continue using Keras 2, do the following:
  1. Install `tf-keras` via `pip install tf-keras`.
  2. Switch `tf.keras` to Keras 2 (`tf-keras`) by setting the environment variable `TF_USE_LEGACY_KERAS=1`, either directly in the shell or in your Python program with `import os; os.environ["TF_USE_LEGACY_KERAS"] = "1"`. Note that this sets it for all packages in your Python runtime.
  3. Change the Keras imports: replace `import keras` with `import tf_keras as keras`, and update any `from keras import ...` to `from tf_keras import ...`.

  Users can choose whether to use NPD (the default) or the stream executor for better performance via the environment variable `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE`.

* Keras 3 with jit_compile disabled
  Users can disable jit_compile with `model.jit_compile = False` or `model.compile(..., jit_compile=False)`. Using the ITEX ops override can also disable jit_compile. In this case, `ITEX_DISABLE_XLA=1` must be set.

* Enable XLA through TensorFlow
  Users can enable XLA through TensorFlow by setting the environment variable `TF_XLA_FLAGS="--tf_xla_auto_jit=1"`. Use `tf_xla_auto_jit=1` to auto-cluster TF ops into XLA, or `tf_xla_auto_jit=2` to compile everything into XLA. Set `model.jit_compile = False` if a Keras model is used. If ITEX custom ops are used or `ITEX_OPS_OVERRIDE` is set, use `tf_xla_auto_jit=1` to avoid errors.
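
For the *Default* case, a minimal sketch of the environment-variable workaround, assuming the variables are set before TensorFlow is imported; the `"0"`/`"1"` values for `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE` are an assumption based on the description above:

```python
import os

# Set before `import tensorflow` so that ITEX picks the flags up.
os.environ["ITEX_DISABLE_XLA"] = "1"                  # avoid the regression for non-Keras, non-XLA scripts
os.environ["ITEX_ENABLE_NEXTPLUGGABLE_DEVICE"] = "0"  # assumed: "1" = NPD (default), "0" = stream executor

import tensorflow as tf  # import only after the environment is configured
```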
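
For the *Legacy Keras* case, a minimal sketch of steps 2 and 3, assuming `tf-keras` is already installed via `pip install tf-keras`; the `layers` import is only an example of rewriting `from keras import ...`:

```python
import os

# Set before TensorFlow is imported so that tf.keras resolves to Keras 2 (tf-keras).
os.environ["TF_USE_LEGACY_KERAS"] = "1"  # applies to every package in this Python process

import tensorflow as tf
import tf_keras as keras      # replaces `import keras`
from tf_keras import layers   # replaces `from keras import layers`
```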
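
For the *Keras 3 with jit_compile disabled* case, a sketch of the two ways to turn JIT off; the one-layer model is only a placeholder:

```python
import os

os.environ["ITEX_DISABLE_XLA"] = "1"  # required in this case, per the note above

import keras

model = keras.Sequential([keras.layers.Dense(1)])  # placeholder model

# Either disable JIT at compile time ...
model.compile(optimizer="adam", loss="mse", jit_compile=False)
# ... or flip the attribute after compiling.
model.jit_compile = False
```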
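
For the *Enable XLA through TensorFlow* case, a sketch of the auto-clustering flag; note that `jit_compile=False` stays on the Keras model, as described above:

```python
import os

# Auto-cluster TF ops into XLA; use "--tf_xla_auto_jit=2" to compile everything into XLA.
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=1"

import tensorflow as tf
import keras

model = keras.Sequential([keras.layers.Dense(1)])               # placeholder model
model.compile(optimizer="adam", loss="mse", jit_compile=False)  # let TensorFlow, not Keras, drive XLA
```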

## Situations That Lead to Warnings or Errors

All invalid cases are listed below. A Keras version of 0 means the model script does not use Keras.

In all cases, running `import keras` before `import tensorflow` causes an error due to a circular import in ITEX, so import TensorFlow first, as sketched below.
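
A minimal illustration of the safe import order:

```python
# Import TensorFlow before Keras to avoid the ITEX circular-import error.
import tensorflow as tf
import keras
```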

| OPS_OVERRIDE | TF_AUTO_JIT_FLAG | Keras version | NPD | Jit Compile | Warning | Error | Solution |
|--------------|------------------|---------------|-----|-------------|---------|-------|----------|
| Any | 0 | 0 | 0 | NA | | PluggableDevice cannot work with the latest Keras. | `ITEX_DISABLE_XLA=1` |
| Any | 0 | 0 | 1 | NA | Perf regression | | `ITEX_DISABLE_XLA=1` |
| Any | Any | 2 | Any | 1 | | Unknown behavior, not supported. | Use `TF_AUTO_JIT_FLAG="--tf_xla_auto_jit=1"` or `2` to enable XLA |
| Any | 0 | 3 | 0 | Any | | Cannot close NPD with Keras 3 | `ITEX_DISABLE_XLA=1` |
| Any | 0 | 3 | 1 | 0 | | Perf regression | `ITEX_DISABLE_XLA=1` |
| Any | 1 | Any | 0 | Any | | Cannot close NPD | `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE=1` |
| Any | 2 | Any | 0 | Any | | Cannot close NPD | `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE=1` |
| 1 | 2 | Any | 1 | Any | Custom op not supported by XLA | | `ITEX_OPS_OVERRIDE=0` |

@@ -0,0 +1,268 @@
# Copyright (c) 2023 Intel Corporation
#
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ITEX optimization for some TensorFlow API.""" | ||
import logging | ||
import os | ||
import types | ||
import tensorflow as tf | ||
|
||
|
||
from keras import ops | ||
|
||
|
||
from intel_extension_for_tensorflow.python.ops.layer_norm_k3 import _layer_norm | ||
|
||
format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' | ||
logging.basicConfig(level=logging.INFO, format=format_str) | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def copy_func(f, name=None): | ||
''' | ||
return a function with same code, globals, defaults, closure, and | ||
name (or provide a new name) | ||
''' | ||
fn = types.FunctionType(f.__code__, f.__globals__, name or f.__name__, | ||
f.__defaults__, f.__closure__) | ||
# in case f was given attrs (note this dict is a shallow copy): | ||
fn.__dict__.update(f.__dict__) | ||
return fn | ||
|
||
|
||


def _can_use_onednn_layer_norm(self, ndims):
    """Return false if Itex layernorm implementation cannot be used.

    Check if the axis is contiguous and can be collapsed into the last axis.
    The self.axis is assumed to have no duplicates.
    """
    self._data_format = "NHWC"  # pylint: disable=protected-access
    self._is_one_axis_len = None  # pylint: disable=protected-access
    can_use_onednn_layer_norm = True
    axis = sorted(self.axis)
    if axis[-1] != ndims - 1 or ndims < 2 or ndims > 4 or axis[-1] - axis[0] != len(axis) - 1:  # pylint: disable=line-too-long
        can_use_onednn_layer_norm = False

    if can_use_onednn_layer_norm and (axis[-1] == 3 or self.axis[-1] == -1):
        self.data_format = 'NHWC'

    if len(axis) == 1:
        self._is_one_axis_len = True  # pylint: disable=protected-access
    else:
        self._is_one_axis_len = False  # pylint: disable=protected-access

    if self.dtype == 'float64':
        raise ValueError(
            'Itex Layernorm only support float32, bfloat16 and float16.')  # pylint: disable=line-too-long

    return can_use_onednn_layer_norm


def experimental_ops_override():
    '''
    using itex api in some tf and keras functions.
    '''
    try:
        from pkg_resources import packaging  # pylint: disable=import-outside-toplevel
        version = packaging.version.parse
        if version(tf.__version__) < version("2.16.1"):
            return

        from keras.src import backend  # pylint: disable=import-outside-toplevel
        from keras.src.utils import tf_utils  # pylint: disable=import-outside-toplevel

        import keras
        tf_ln_call = copy_func(keras.layers.LayerNormalization.call)
        tf_gn_call = copy_func(keras.layers.GroupNormalization.call)
        tf_gn_build = copy_func(keras.layers.GroupNormalization.build)

    except BaseException:  # pylint: disable=broad-except
        return

    def itex_layer_norm_build(self, input_shape):
        self.supports_jit = False
        if self.compute_dtype == "float16" or self.compute_dtype == "bfloat16":  # pylint: disable=no-else-return
            self._param_dtype = "float32"
        else:
            self._param_dtype = self.dtype or dtypes.float32
        ndims = len(input_shape)
        if ndims is None:
            raise ValueError(
                'Input shape %s has undefined rank.' % input_shape)
        if isinstance(self.axis, list):
            shape = tuple([input_shape[dim] for dim in self.axis])
        else:
            shape = (input_shape[self.axis],)
            self.axis = [self.axis]
        for idx, x in enumerate(self.axis):
            if x < 0:
                self.axis[idx] = ndims + x
        param_shape = [input_shape[dim] for dim in self.axis]
        if self.scale or self.rms_scaling:
            self.gamma = self.add_weight(
                name="gamma",
                shape=shape,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                trainable=True,
                dtype=self._param_dtype,
            )
        else:
            self.gamma = None
            self._gamma_const = ops.ones(
                dtype=self._param_dtype, shape=param_shape)

        if self.center and not self.rms_scaling:
            self.beta = self.add_weight(
                name="beta",
                shape=shape,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                trainable=True,
                dtype=self._param_dtype,
            )
        else:
            self.beta = None
            self._beta_const = ops.zeros(
                dtype=self._param_dtype, shape=param_shape)
        self._use_layernorm = _can_use_onednn_layer_norm(self, ndims)
        self.built = True

    def _layer_norm_inference_or_training(self, inputs, gamma, beta, training):
        """Returns the output of layer norm."""
        def _layer_norm_training():
            return _layer_norm(
                inputs,
                scale=gamma,
                offset=beta,
                epsilon=self.epsilon,
                is_training=True,
                data_format=self._data_format)

        def _layer_norm_inference():
            return _layer_norm(
                inputs,
                scale=gamma,
                offset=beta,
                epsilon=self.epsilon,
                is_training=False,
                data_format=self._data_format)

        output, _, _ = tf.__internal__.smart_cond.smart_cond(
            training, _layer_norm_training, _layer_norm_inference)
        return output

    def itex_layer_norm_call(self, inputs, training=None):
        if not self._use_layernorm:  # pylint: disable=protected-access
            return tf_ln_call(self, inputs)  # pylint: disable=not-callable
        if self.rms_scaling:  # pylint: disable=protected-access
            return tf_ln_call(self, inputs)  # pylint: disable=not-callable
        # Default to training mode when no flag is given; bools/ints are
        # normalized, anything else (e.g. a tensor) is passed through to
        # smart_cond unchanged.
        is_training = True if training is None else training
        if isinstance(training, int):
            is_training = bool(training)
        if not self.trainable:
            # When the layer is not trainable, it overrides the value passed
            # from the model.
            is_training = False
        inputs = ops.cast(inputs, self.compute_dtype)
        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)

        # Broadcasting is only necessary for norm when the axis is not just
        # the last dimension
        broadcast_shape = [1] * ndims
        for dim in self.axis:
            broadcast_shape[dim] = input_shape[dim]

        def _broadcast(v):
            if (
                v is not None
                and len(v.shape) != ndims
                and self.axis != [ndims - 1]
            ):
                return ops.reshape(v, broadcast_shape)
            return v

        input_dtype = inputs.dtype
        if input_dtype in (tf.float16, tf.bfloat16) and self.dtype == "float32" and not self._use_layernorm:  # pylint: disable=line-too-long
            # If mixed precision is used, cast inputs to float32 so that
            # this is at least as numerically stable as the fused version.
            inputs = ops.cast(inputs, "float32")

        beta = self.beta if self.beta is not None else self._beta_const
        gamma = self.gamma if self.gamma is not None else self._gamma_const
        if self._is_one_axis_len:
            outputs = _layer_norm_inference_or_training(self, inputs, gamma, beta,
                                                        is_training)
            return outputs
        else:
            # Collapse dims before self.axis, and dims in self.axis
            pre_dim, in_dim = (1, 1)
            axis = sorted(self.axis)
            tensor_shape = inputs.shape
            for dim in range(0, ndims):
                dim_tensor = tensor_shape[dim]
                if dim < axis[0]:
                    pre_dim = pre_dim * dim_tensor
                else:
                    assert dim in axis
                    in_dim = in_dim * dim_tensor

            squeezed_shape = [1, pre_dim, in_dim]
            inputs = ops.reshape(inputs, squeezed_shape)

            # self.gamma and self.beta have the wrong shape for layer_norm, so
            # we cannot pass them as the scale and offset parameters. Therefore,
            # we create two constant tensors in correct shapes for layer_norm
            # and later construct a separate calculation on the scale and offset.
            scale = ops.ones([in_dim], dtype="float32")
            offset = ops.zeros([in_dim], dtype="float32")

            # Compute layer normalization.
            outputs = _layer_norm_inference_or_training(self, inputs, scale,
                                                        offset, is_training)
            outputs = ops.reshape(outputs, tensor_shape)
            scale, offset = _broadcast(
                self.gamma), _broadcast(self.beta)

            if scale is not None:
                outputs = outputs * ops.cast(scale, outputs.dtype)
            if offset is not None:
                outputs = outputs + ops.cast(offset, outputs.dtype)
            return outputs

    try:
        keras.layers.LayerNormalization.call = itex_layer_norm_call
        keras.layers.LayerNormalization.build = itex_layer_norm_build
        logger.info("itex experimental ops override is enabled.")
    except BaseException:  # pylint: disable=broad-except
        logger.error("Cannot override itex ops.")
    try:
        import keras  # pylint: disable=import-outside-toplevel
        keras.src.layers.normalization.layer_normalization.LayerNormalization.call = itex_layer_norm_call  # pylint: disable=line-too-long
        keras.src.layers.normalization.layer_normalization.LayerNormalization.build = itex_layer_norm_build  # pylint: disable=line-too-long
    except BaseException:  # pylint: disable=broad-except
        logger.warning(
            "itex experimental ops override: Keras is not installed.")  # pylint: disable=line-too-long