keras 3 op override (#2690)
jianyizh authored May 15, 2024
1 parent ad3c198 commit 55a4394
Showing 12 changed files with 478 additions and 13 deletions.
8 changes: 6 additions & 2 deletions docs/README.md
@@ -70,8 +70,9 @@
<tr>
<td colspan="3" align="center"><a href="guide/practice_guide.md#cpu-practice-guide">CPU practice guide</a></td>
<td colspan="3" align="center"><a href="guide/practice_guide.md#gpu-practice-guide">GPU practice guide</a></td>
<td colspan="3" align="center"><a href="install/install_for_cpp.md">C++ API support</a></td>
<td colspan="3" align="center"><a href="guide/OpenXLA.md">OpenXLA</a></td>
<td colspan="2" align="center"><a href="install/install_for_cpp.md">C++ API support</a></td>
<td colspan="2" align="center"><a href="guide/OpenXLA.md">OpenXLA</a></td>
<td colspan="2" align="center"><a href="guide/Keras3_support.md">Keras 3</a></td>
</tr>
</tbody>
<thead>
@@ -132,3 +133,6 @@
* OpenXLA

Intel® Extension for TensorFlow\* adopts a uniform Device API PJRT as the supported device plugin mechanism to implement Intel GPU backend for OpenXLA support on TensorFlow frontend.

* Keras 3
Keras 3 with TensorFlow enables Just-In-Time (JIT) compilation by default, which uses the XLA (Accelerated Linear Algebra) compiler to optimize TensorFlow computations. See <a href="guide/Keras3_support.md">Keras 3</a> to avoid possible performance issues and errors.
Binary file added docs/guide/images/keras3.png
49 changes: 49 additions & 0 deletions docs/guide/keras3_support.md
@@ -0,0 +1,49 @@
# Keras 3 Overview

[Keras](https://keras.io/about/) is a deep learning API written in Python that can run on top of JAX, TensorFlow, or PyTorch. Both the JAX and TensorFlow backends compile the model with XLA, which usually delivers the best training and prediction performance on GPU. Results vary from model to model, however, and non-XLA TensorFlow is occasionally faster on GPU. The following image shows how ITEX works with XLA, the Keras 3 TensorFlow backend, and legacy Keras.

<p align="center">
<img src="images/keras3.png" alt="keras3" />
</p>


## Use cases with different performance
Several use cases can lead to different performance.

* Default
If the script uses Keras 3 and the model supports JIT compilation, the model runs through XLA.
If the script does not contain any Keras-related code and does not enable XLA in TensorFlow, there will be a performance regression. Set the environment variable `ITEX_DISABLE_XLA=1` to avoid it. With ITEX XLA disabled, users can choose between NextPluggableDevice (NPD, the default) and stream executor for better performance through the environment variable `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE`.

* Legacy Keras
To continue using Keras 2, do the following.
1. Install `tf-keras` via `pip install tf-keras`.
2. To switch `tf.keras` to Keras 2 (`tf-keras`), set the environment variable `TF_USE_LEGACY_KERAS=1`, either directly or in your Python program with `import os;os.environ["TF_USE_LEGACY_KERAS"]="1"`. Note that this applies to all packages in your Python runtime.
3. Change the Keras import: replace `import keras` with `import tf_keras as keras`, and update any `from keras import ...` to `from tf_keras import ...`.

Users can choose between NPD (the default) and stream executor for better performance through the environment variable `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE`.

* Keras 3 with jit_compile disabled
Users can disable JIT compilation with `model.jit_compile=False` or `model.compile(..., jit_compile=False)`. Using the ITEX ops override can also disable `jit_compile`. In this case, `ITEX_DISABLE_XLA=1` must be set.

* Enable XLA through TensorFlow
Users can enable XLA through TensorFlow by setting the environment variable `TF_XLA_FLAGS="--tf_xla_auto_jit=1"`. Use `tf_xla_auto_jit=1` to auto-cluster TF ops into XLA, or `tf_xla_auto_jit=2` to compile everything into XLA. Set `model.jit_compile=False` if a Keras model is used. If ITEX custom ops are used or `ITEX_OPS_OVERRIDE` is set, use `tf_xla_auto_jit=1` to avoid errors.
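
The environment variables named in the use cases above have to be in place before TensorFlow (and therefore ITEX) is imported. A minimal sketch of collecting them in one place, assuming a hypothetical helper `itex_env` that is not part of ITEX; only the variable names and values come from this guide:

```python
import os


def itex_env(legacy_keras=False, disable_itex_xla=False, tf_auto_jit=None):
  """Build the environment variables described in this guide.

  Hypothetical helper for illustration only; it is not an ITEX API.
  """
  env = {}
  if legacy_keras:
    # Switch tf.keras to Keras 2 (requires `pip install tf-keras`).
    env["TF_USE_LEGACY_KERAS"] = "1"
  if disable_itex_xla:
    # Avoid the regression when the model does not run through XLA.
    env["ITEX_DISABLE_XLA"] = "1"
  if tf_auto_jit is not None:
    # 1: auto-cluster TF ops into XLA, 2: compile everything into XLA.
    if tf_auto_jit not in (1, 2):
      raise ValueError("tf_auto_jit must be 1 or 2")
    env["TF_XLA_FLAGS"] = f"--tf_xla_auto_jit={tf_auto_jit}"
  return env


# Apply the settings before `import tensorflow`.
os.environ.update(itex_env(legacy_keras=True))
```

The same dictionary could instead be passed as `env=` to a subprocess, which keeps the settings out of the current interpreter.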





## Situations that lead to warnings or errors
All invalid cases are listed here. A Keras version of 0 means the model script does not use Keras.

Note that in all cases, running `import keras` before `import tensorflow` causes an error due to a circular import in ITEX.

| OPS_OVERRIDE | TF_AUTO_JIT_FLAG | Keras version | NPD | Jit Compile | Warning | Error | Solution |
|--------------|------------------|---------------|-----|-------------|---------|-------|----------|
| Any | 0 | 0 | 0 | NA | | PluggableDevice cannot work with latest Keras. | `ITEX_DISABLE_XLA=1` |
| Any | 0 | 0 | 1 | NA | Perf Regression | | `ITEX_DISABLE_XLA=1` |
| Any | Any | 2 | Any | 1 | | | Unknown behavior, not supported. Use `TF_XLA_FLAGS="--tf_xla_auto_jit=1"` or `2` to enable XLA |
| Any | 0 | 3 | 0 | Any | | Cannot close NPD when keras 3 | `ITEX_DISABLE_XLA=1` |
| Any | 0 | 3 | 1 | 0 | | perf regression | `ITEX_DISABLE_XLA=1` |
| Any | 1 | Any | 0 | Any | | Cannot close NPD | `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE=1` |
| Any | 2 | Any | 0 | Any | | Cannot close NPD | `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE=1` |
| 1 | 2 | Any | 1 | Any | custom op not supported by XLA | | `ITEX_OPS_OVERRIDE=0` |
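
The table can be read as a first-match lookup. The sketch below transcribes the rows into data. `ANY`, `ROWS`, and `diagnose` are hypothetical names, the NA cells are treated as wildcards, and the messages are copied from the table rather than produced by ITEX:

```python
ANY = object()  # wildcard for the "Any" and "NA" cells

# (ops_override, auto_jit, keras_version, npd, jit_compile) -> (issue, solution)
ROWS = [
  ((ANY, 0, 0, 0, ANY),
   ("PluggableDevice cannot work with latest Keras", "ITEX_DISABLE_XLA=1")),
  ((ANY, 0, 0, 1, ANY),
   ("Perf regression", "ITEX_DISABLE_XLA=1")),
  ((ANY, ANY, 2, ANY, 1),
   ("Unknown behavior, not supported", "enable XLA via TF_XLA_FLAGS")),
  ((ANY, 0, 3, 0, ANY),
   ("Cannot close NPD when keras 3", "ITEX_DISABLE_XLA=1")),
  ((ANY, 0, 3, 1, 0),
   ("Perf regression", "ITEX_DISABLE_XLA=1")),
  ((ANY, 1, ANY, 0, ANY),
   ("Cannot close NPD", "ITEX_ENABLE_NEXTPLUGGABLE_DEVICE=1")),
  ((ANY, 2, ANY, 0, ANY),
   ("Cannot close NPD", "ITEX_ENABLE_NEXTPLUGGABLE_DEVICE=1")),
  ((1, 2, ANY, 1, ANY),
   ("custom op not supported by XLA", "ITEX_OPS_OVERRIDE=0")),
]


def diagnose(ops_override, auto_jit, keras_version, npd, jit_compile=None):
  """Return (issue, solution) for the first matching row, else None."""
  config = (ops_override, auto_jit, keras_version, npd, jit_compile)
  for pattern, result in ROWS:
    if all(p is ANY or p == c for p, c in zip(pattern, config)):
      return result
  return None  # not listed as an invalid case
```

For example, `diagnose(0, 0, 3, 1, 0)` matches the fifth row (Keras 3, NPD on, `jit_compile` off), while the default Keras 3 setup with `jit_compile` on matches no row.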
25 changes: 23 additions & 2 deletions itex/core/kernels/xpu_kernel.cc
@@ -85,9 +85,30 @@ void TF_InitKernel() {
bool ops_override = false;
ITEX_CHECK_OK(
itex::ReadBoolFromEnvVar("ITEX_OPS_OVERRIDE", false, &ops_override));
// clang-format off
if (ops_override) {
PyRun_SimpleString("import intel_extension_for_tensorflow as itex;\n");
PyRun_SimpleString("itex.experimental_ops_override();\n");
PyRun_SimpleString(
"try:\n"
" import os;\n"
" if os.environ.get('TF_USE_LEGACY_KERAS', None) in ('true', 'True', '1'):\n" // NOLINT(whitespace/line_length)
" from intel_extension_for_tensorflow.python.experimental_ops_override import experimental_ops_override;\n" // NOLINT(whitespace/line_length)
" else:\n"
" from intel_extension_for_tensorflow.python.experimental_ops_override_k3 import experimental_ops_override;\n" // NOLINT(whitespace/line_length)
" from intel_extension_for_tensorflow.python.override_keras3 import override_keras3;\n" // NOLINT(whitespace/line_length)
" experimental_ops_override();\n"
" override_keras3();\n"
"except BaseException:\n"
" print('please import ITEX or tensorflow before keras')\n"
" quit()\n");
} else {
PyRun_SimpleString(
"try:\n"
" from intel_extension_for_tensorflow.python.override_keras3 import override_keras3;\n" // NOLINT(whitespace/line_length)
" override_keras3();\n"
"except BaseException:\n"
" print('please import ITEX or tensorflow before keras')\n"
" quit()\n");
}
// clang-format on
#endif // CC_BUILD
}
4 changes: 4 additions & 0 deletions itex/python/base_init.py
@@ -33,3 +33,7 @@

if os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"):
  from intel_extension_for_tensorflow.python.experimental_ops_override import experimental_ops_override
else:
  from intel_extension_for_tensorflow.python.experimental_ops_override_k3 import experimental_ops_override

from intel_extension_for_tensorflow.python.override_keras3 import override_keras3
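
Both the startup hook in `xpu_kernel.cc` and `base_init.py` pick the override module from `TF_USE_LEGACY_KERAS`. A small sketch of that selection as a pure function, so the accepted values are easy to see; `pick_override_module` is a hypothetical name used only for illustration, while the module paths are the ones in this commit:

```python
import os

_PKG = "intel_extension_for_tensorflow.python"


def pick_override_module(environ=os.environ):
  """Return the module that provides experimental_ops_override.

  Hypothetical helper; mirrors the condition used in this commit.
  """
  if environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"):
    # Keras 2 (tf-keras) code path.
    return _PKG + ".experimental_ops_override"
  # Keras 3 code path, used whenever the variable is unset or has any
  # other value.
  return _PKG + ".experimental_ops_override_k3"
```

Note that only `"true"`, `"True"`, and `"1"` select the legacy path; a value such as `"TRUE"` falls through to the Keras 3 module.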
268 changes: 268 additions & 0 deletions itex/python/experimental_ops_override_k3.py
@@ -0,0 +1,268 @@
# Copyright (c) 2023 Intel Corporation
#
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ITEX optimization for some TensorFlow API."""
import logging
import os
import types
import tensorflow as tf


from keras import ops


from intel_extension_for_tensorflow.python.ops.layer_norm_k3 import _layer_norm

format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=format_str)
logger = logging.getLogger(__name__)


def copy_func(f, name=None):
  '''
  return a function with same code, globals, defaults, closure, and
  name (or provide a new name)
  '''
  fn = types.FunctionType(f.__code__, f.__globals__, name or f.__name__,
                          f.__defaults__, f.__closure__)
  # in case f was given attrs (note this dict is a shallow copy):
  fn.__dict__.update(f.__dict__)
  return fn
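
A short illustration of how `copy_func` supports the override pattern used later in this file: the original method is copied first, so the replacement can still fall back to it. The `Layer` class and `patched_call` below are hypothetical stand-ins, not Keras or ITEX code:

```python
import types


def copy_func(f, name=None):
  """Same helper as in this file, repeated to keep the sketch runnable."""
  fn = types.FunctionType(f.__code__, f.__globals__, name or f.__name__,
                          f.__defaults__, f.__closure__)
  fn.__dict__.update(f.__dict__)
  return fn


class Layer:  # hypothetical stand-in for a Keras layer
  def call(self, x):
    return x + 1


# Copy the original implementation before replacing it.
original_call = copy_func(Layer.call)


def patched_call(self, x):
  if x < 0:
    # Unsupported case: fall back to the original behavior.
    return original_call(self, x)
  return x * 10


Layer.call = patched_call
```

Keeping a plain reference to `Layer.call` would behave the same in this small case; the copy additionally yields an independent function object.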


def _can_use_onednn_layer_norm(self, ndims):
  """Return false if Itex layernorm implementation cannot be used.
  Check if the axis is contiguous and can be collapsed into the last axis.
  The self.axis is assumed to have no duplicates.
  """
  self._data_format = "NHWC"  # pylint: disable=protected-access
  self._is_one_axis_len = None  # pylint: disable=protected-access
  can_use_onednn_layer_norm = True
  axis = sorted(self.axis)
  if axis[-1] != ndims - 1 or ndims < 2 or ndims > 4 or axis[-1] - axis[0] != len(axis) - 1:  # pylint: disable=line-too-long
    can_use_onednn_layer_norm = False

  if can_use_onednn_layer_norm and (axis[-1] == 3 or self.axis[-1] == -1):
    self.data_format = 'NHWC'

  if len(axis) == 1:
    self._is_one_axis_len = True  # pylint: disable=protected-access
  else:
    self._is_one_axis_len = False  # pylint: disable=protected-access

  if self.dtype == 'float64':
    raise ValueError(
        'Itex Layernorm only support float32, bfloat16 and float16.')  # pylint: disable=line-too-long

  return can_use_onednn_layer_norm
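
The axis condition above can be isolated as a pure function to see which configurations qualify. This is a sketch under the same assumptions as the docstring (non-negative axes, no duplicates); `axes_collapse_to_last` is a hypothetical name:

```python
def axes_collapse_to_last(axis, ndims):
  """True if `axis` is a contiguous run ending at the last dimension.

  Mirrors the rank and axis test in _can_use_onednn_layer_norm:
  rank must be 2 to 4, and the sorted axes must be consecutive and
  end at ndims - 1.
  """
  axis = sorted(axis)
  if ndims < 2 or ndims > 4:
    return False
  if axis[-1] != ndims - 1:
    return False
  return axis[-1] - axis[0] == len(axis) - 1
```

For a rank-4 input, axes `[2, 3]` qualify, while `[1, 3]` do not because the run is not contiguous.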


def experimental_ops_override():
  '''
  using itex api in some tf and keras functions.
  '''
  try:
    from pkg_resources import packaging  # pylint: disable=import-outside-toplevel
    version = packaging.version.parse
    if version(tf.__version__) < version("2.16.1"):
      return

    from keras.src import backend  # pylint: disable=import-outside-toplevel
    from keras.src.utils import tf_utils  # pylint: disable=import-outside-toplevel

    import keras
    tf_ln_call = copy_func(keras.layers.LayerNormalization.call)
    tf_gn_call = copy_func(keras.layers.GroupNormalization.call)
    tf_gn_build = copy_func(keras.layers.GroupNormalization.build)

  except BaseException:  # pylint: disable=broad-except
    return

  def itex_layer_norm_build(self, input_shape):
    self.supports_jit = False
    if self.compute_dtype == "float16" or self.compute_dtype == "bfloat16":  # pylint: disable=no-else-return
      self._param_dtype = "float32"
    else:
      # `dtypes` is not imported in this module, so use the dtype name.
      self._param_dtype = self.dtype or "float32"
    ndims = len(input_shape)
    if ndims is None:
      raise ValueError(
          'Input shape %s has undefined rank.' % input_shape)
    if isinstance(self.axis, list):
      shape = tuple([input_shape[dim] for dim in self.axis])
    else:
      shape = (input_shape[self.axis],)
      self.axis = [self.axis]
    # Normalize negative axes to their non-negative equivalents.
    for idx, x in enumerate(self.axis):
      if x < 0:
        self.axis[idx] = ndims + x
    param_shape = [input_shape[dim] for dim in self.axis]
    if self.scale or self.rms_scaling:
      self.gamma = self.add_weight(
          name="gamma",
          shape=shape,
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True,
          dtype=self._param_dtype,
      )
    else:
      self.gamma = None
      self._gamma_const = ops.ones(
          dtype=self._param_dtype, shape=param_shape)

    if self.center and not self.rms_scaling:
      self.beta = self.add_weight(
          name="beta",
          shape=shape,
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True,
          dtype=self._param_dtype,
      )
    else:
      self.beta = None
      self._beta_const = ops.zeros(
          dtype=self._param_dtype, shape=param_shape)
    self._use_layernorm = _can_use_onednn_layer_norm(self, ndims)
    self.built = True

  def _layer_norm_inference_or_training(self, inputs, gamma, beta, training):
    """Returns the output of layer norm."""
    def _layer_norm_training():
      return _layer_norm(
          inputs,
          scale=gamma,
          offset=beta,
          epsilon=self.epsilon,
          is_training=True,
          data_format=self._data_format)

    def _layer_norm_inference():
      return _layer_norm(
          inputs,
          scale=gamma,
          offset=beta,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    output, _, _ = tf.__internal__.smart_cond.smart_cond(
        training, _layer_norm_training, _layer_norm_inference)
    return output

  def itex_layer_norm_call(self, inputs, training=None):
    if not self._use_layernorm:  # pylint: disable=protected-access
      return tf_ln_call(self, inputs)  # pylint: disable=not-callable
    if self.rms_scaling:  # pylint: disable=protected-access
      return tf_ln_call(self, inputs)  # pylint: disable=not-callable
    # Default to training mode; otherwise accept a Python bool/int or pass
    # the value through, so that `is_training` is always defined.
    if training is None:
      is_training = True
    elif isinstance(training, int):
      is_training = bool(training)
    else:
      is_training = training
    if not self.trainable:
      # When the layer is not trainable, it overrides the value passed from
      # model.
      is_training = False
    # Cast the inputs to the layer's compute dtype.
    inputs = ops.cast(inputs, self.compute_dtype)
    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)

    # Broadcasting only necessary for norm when the axis is not just
    # the last dimension
    broadcast_shape = [1] * ndims
    for dim in self.axis:
      broadcast_shape[dim] = input_shape[dim]

    def _broadcast(v):
      if (
          v is not None
          and len(v.shape) != ndims
          and self.axis != [ndims - 1]
      ):
        return ops.reshape(v, broadcast_shape)
      return v

    input_dtype = inputs.dtype
    if input_dtype in (tf.float16, tf.bfloat16) and self.dtype == "float32" and not self._use_layernorm:
      # If mixed precision is used, cast inputs to float32 so that
      # this is at least as numerically stable as the fused version.
      inputs = ops.cast(inputs, "float32")

    beta = self.beta if self.beta is not None else self._beta_const
    gamma = self.gamma if self.gamma is not None else self._gamma_const
    if self._is_one_axis_len:
      outputs = _layer_norm_inference_or_training(self, inputs, gamma, beta,
                                                  is_training)
      return outputs
    else:
      # Collapse dims before self.axis, and dims in self.axis
      pre_dim, in_dim = (1, 1)
      axis = sorted(self.axis)
      tensor_shape = inputs.shape
      for dim in range(0, ndims):
        dim_tensor = tensor_shape[dim]
        if dim < axis[0]:
          pre_dim = pre_dim * dim_tensor
        else:
          assert dim in axis
          in_dim = in_dim * dim_tensor

      squeezed_shape = [1, pre_dim, in_dim]
      inputs = ops.reshape(inputs, squeezed_shape)

      # self.gamma and self.beta have the wrong shape for layer_norm, so
      # we cannot pass them as the scale and offset parameters. Therefore, we
      # create two constant tensors in correct shapes for layer_norm and
      # later construct a separate calculation on the scale and offset.
      scale = ops.ones([in_dim], dtype="float32")
      offset = ops.zeros([in_dim], dtype="float32")

      # Compute layer normalization.
      outputs = _layer_norm_inference_or_training(self, inputs, scale,
                                                  offset, is_training)
      outputs = ops.reshape(outputs, tensor_shape)
      scale, offset = _broadcast(
          self.gamma), _broadcast(self.beta)

      if scale is not None:
        outputs = outputs * ops.cast(scale, outputs.dtype)
      if offset is not None:
        outputs = outputs + ops.cast(offset, outputs.dtype)
      return outputs

  try:
    keras.layers.LayerNormalization.call = itex_layer_norm_call
    keras.layers.LayerNormalization.build = itex_layer_norm_build
    logger.info("itex experimental ops override is enabled.")
  except BaseException:  # pylint: disable=broad-except
    logger.error("Cannot override itex ops.")
  try:
    import keras  # pylint: disable=import-outside-toplevel
    keras.src.layers.normalization.layer_normalization.LayerNormalization.call = itex_layer_norm_call
    keras.src.layers.normalization.layer_normalization.LayerNormalization.build = itex_layer_norm_build
  except BaseException:  # pylint: disable=broad-except
    logger.warning(
        "itex experimental ops override: Keras is not installed.")  # pylint: disable=line-too-long
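
When more than one axis is normalized, `itex_layer_norm_call` reshapes the input to `[1, pre_dim, in_dim]` before calling the fused kernel. A sketch of just that shape arithmetic in plain Python, assuming static, fully known dimensions; `collapsed_shape` is a hypothetical name:

```python
def collapsed_shape(shape, axis):
  """Collapse leading dims into pre_dim and normalized dims into in_dim.

  Assumes `axis` is a contiguous run ending at the last dimension, as
  checked by _can_use_onednn_layer_norm.
  """
  pre_dim, in_dim = 1, 1
  axis = sorted(axis)
  for dim, size in enumerate(shape):
    if dim < axis[0]:
      pre_dim *= size   # dimensions kept as independent rows
    else:
      in_dim *= size    # dimensions reduced by the normalization
  return [1, pre_dim, in_dim]
```

With an input of shape `(2, 3, 4, 5)` normalized over axes `[2, 3]`, the kernel sees 6 rows of 20 elements each; the result is reshaped back afterwards and `gamma`/`beta` are applied separately.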