Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add preliminary support of OpenVINO as Keras 3 backend #19727

Open
wants to merge 62 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
76d4c1a
[POC][OV] Support OpenVINO as Keras 3 backend
rkazants May 17, 2024
af254e9
Merge remote-tracking branch 'upstream/master' into rkazants/poc_open…
rkazants Sep 23, 2024
3e0772f
Mark all unsupported ops from numpy space
rkazants Sep 23, 2024
c2c1808
Mark unsupported ops in core, image, and linalg spaces
rkazants Sep 23, 2024
d26c22c
Mark unsupported ops in math, nn, random, and rnn spaces
rkazants Sep 23, 2024
9e29a73
Fix sorting imports
rkazants Sep 24, 2024
716a0bb
Format imports
rkazants Sep 24, 2024
c097126
Fix sorting imports
rkazants Sep 24, 2024
65fe256
Fix sorting imports
rkazants Sep 24, 2024
fe29eb9
Fix inference
rkazants Sep 24, 2024
f57847b
Remove openvino specific code in common part
rkazants Sep 24, 2024
8fb2dc5
Fix typo
rkazants Sep 24, 2024
3d8b41a
Clean-up code
rkazants Sep 25, 2024
08b74dc
Recover imports
rkazants Sep 25, 2024
52408b7
Sort imports properly
rkazants Sep 25, 2024
f41b454
Format source code
rkazants Sep 25, 2024
42364c2
Format the rest of source code
rkazants Sep 25, 2024
954ed1f
Continue format adjustment
rkazants Sep 25, 2024
facde41
Add OpenVINO dependency
rkazants Sep 25, 2024
22d5c17
Merge remote-tracking branch 'upstream/master' into rkazants/poc_open…
rkazants Oct 31, 2024
fa6b461
Fix inference using OV backend
rkazants Nov 1, 2024
0b1537d
Support bert_base_en_uncased and mobilenet_v3_small from Keras Hub
rkazants Nov 25, 2024
64681ab
Merge remote-tracking branch 'upstream/master' into rkazants/poc_open…
rkazants Nov 25, 2024
9d31b56
Remove extra openvino specific code from layer.py
rkazants Nov 25, 2024
9785746
Apply code-style formatting
rkazants Nov 25, 2024
43bd76d
Apply code-style formatting
rkazants Nov 25, 2024
4f68322
Fix remained code-style issue
rkazants Nov 25, 2024
34549b8
Run tests for OpenVINO backend in GHA
rkazants Nov 25, 2024
a418ede
Add config file for openvino backend validation
rkazants Nov 25, 2024
4dae81d
Add import test for openvino backend
rkazants Nov 25, 2024
1559ea0
Fix error in import_test.py
rkazants Nov 25, 2024
e72b32e
Add import_test for openvino backend
rkazants Nov 25, 2024
f086752
Add openvino specific integration tests in GHA
rkazants Nov 25, 2024
389c3f5
Exclude coverage for OpenVINO
rkazants Nov 25, 2024
770ba91
remove coverage for openvino backend
rkazants Nov 25, 2024
9162c15
Try layer tests for openvino backend
rkazants Nov 26, 2024
b5d7413
Run layer tests for openvino backend selectively
rkazants Nov 27, 2024
7ae168f
Mark enabled tests for openvino backend in a different way
rkazants Nov 27, 2024
45a5ed8
Update .github/workflows/actions.yml
rkazants Nov 27, 2024
702a719
Merge remote-tracking branch 'upstream/master' into rkazants/poc_open…
rkazants Nov 27, 2024
cf6542a
Fix import for BackendVariable
rkazants Nov 27, 2024
6026728
Fix errors in layer tests for openvino backend
rkazants Nov 27, 2024
e3e2aed
Add test for Elu via openvino backend
rkazants Nov 28, 2024
796f509
Fix sorted imports
rkazants Nov 28, 2024
758e774
Extend testing for attention
rkazants Nov 28, 2024
34dbfe0
Update keras/src/layers/attention/attention_test.py
rkazants Nov 28, 2024
c1f7604
Switch on activation tests for openvino backend
rkazants Nov 29, 2024
a2bfa71
Merge remote-tracking branch 'origin/rkazants/poc_openvino_backend' i…
rkazants Nov 29, 2024
8bf05d7
Switch on attention tests for openvino backend
rkazants Nov 29, 2024
b2024cf
Update keras/src/layers/attention/additive_attention_test.py
rkazants Nov 29, 2024
1ffc4af
Update keras/src/layers/attention/grouped_query_attention_test.py
rkazants Nov 29, 2024
95f4207
Run conv tests for openvino backend
rkazants Nov 29, 2024
cb74641
Merge remote-tracking branch 'origin/rkazants/poc_openvino_backend' i…
rkazants Nov 29, 2024
dce9f33
Fix convolution in openvino backend
rkazants Nov 29, 2024
39cce17
Work around constant creation for tuple
rkazants Nov 29, 2024
ffc4e22
Work around constant creation in reshape
rkazants Nov 29, 2024
437193a
Run depthwise conv tests for openvino backend
rkazants Nov 30, 2024
a316c7a
Fix get_ov_output for other x types
rkazants Dec 1, 2024
779235a
Fix elu translation
rkazants Dec 1, 2024
15fc0d7
Fix softmax and log_softmax for None axis
rkazants Dec 2, 2024
cba959a
Run nn tests for openvino backend
rkazants Dec 2, 2024
9ab6412
Merge remote-tracking branch 'upstream/master' into rkazants/poc_open…
rkazants Dec 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions keras/src/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
elif backend() == "numpy":
from keras.src.backend.numpy import * # noqa: F403

distribution_lib = None
elif backend() == "openvino":
from keras.src.backend.openvino import * # noqa: F403

distribution_lib = None
else:
raise ValueError(f"Unable to import backend : {backend()}")
5 changes: 5 additions & 0 deletions keras/src/backend/exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@

BackendVariable = NumpyVariable
backend_name_scope = backend.common.name_scope.name_scope
elif backend.backend() == "openvino":
from keras.src.backend.openvino.core import Variable as OpenVINOVariable

BackendVariable = OpenVINOVariable
backend_name_scope = backend.common.name_scope.name_scope
else:
raise RuntimeError(f"Invalid backend: {backend.backend()}")

Expand Down
21 changes: 21 additions & 0 deletions keras/src/backend/openvino/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from keras.src.backend.openvino import core
from keras.src.backend.openvino import image
from keras.src.backend.openvino import linalg
from keras.src.backend.openvino import math
from keras.src.backend.openvino import nn
from keras.src.backend.openvino import numpy
from keras.src.backend.openvino import random
from keras.src.backend.openvino.core import SUPPORTS_SPARSE_TENSORS
from keras.src.backend.openvino.core import Variable
from keras.src.backend.openvino.core import cast
from keras.src.backend.openvino.core import compute_output_spec
from keras.src.backend.openvino.core import cond
from keras.src.backend.openvino.core import convert_to_numpy
from keras.src.backend.openvino.core import convert_to_tensor
from keras.src.backend.openvino.core import is_tensor
from keras.src.backend.openvino.core import shape
from keras.src.backend.openvino.core import vectorized_map
from keras.src.backend.openvino.rnn import cudnn_ok
from keras.src.backend.openvino.rnn import gru
from keras.src.backend.openvino.rnn import lstm
from keras.src.backend.openvino.rnn import rnn
258 changes: 258 additions & 0 deletions keras/src/backend/openvino/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
import contextlib

import numpy as np

from keras.src.backend.common import global_state
from keras.src import tree
from keras.src.backend.common import KerasVariable
from keras.src.backend.common import standardize_dtype
from keras.src.backend.common.dtypes import result_type
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.backend.common.stateless_scope import StatelessScope
import openvino as ov

SUPPORTS_SPARSE_TENSORS = False

OPENVINO_DTYPES = {
"float16": ov.Type.f16,
"float32": ov.Type.f32,
"float64": ov.Type.f64,
"uint8": ov.Type.u8,
"uint16": ov.Type.u16,
"uint32": ov.Type.u32,
"int8": ov.Type.i8,
"int16": ov.Type.i16,
"int32": ov.Type.i32,
"int64": ov.Type.i64,
"bfloat16": ov.Type.bf16,
"bool": ov.Type.boolean,
"float8_e4m3fn": ov.Type.f8e4m3,
"float8_e5m2": ov.Type.f8e5m2,
}


def ov_to_keras_type(ov_type):
for _keras_type, _ov_type in OPENVINO_DTYPES.items():
if ov_type == _ov_type:
return _keras_type
raise ValueError(
f"Requested OpenVINO type has no keras analogue '{ov_type.to_string()}'"
)

@contextlib.contextmanager
def device_scope(device_name):
current_device = _parse_device_input(device_name)
global_state.set_global_attribute("openvino_device", current_device)


def get_device():
device = global_state.get_global_attribute("openvino_device", None)
if device is None:
return "CPU"
return device


def _parse_device_input(device_name):
if isinstance(device_name, str):
# We support string value like "cpu:0", "gpu:1", and need to convert
# "gpu" to "cuda"
device_name = device_name.upper()
device_type, _ = device_name.split(":")
return device_type
else:
raise ValueError(
"Invalid value for argument `device_name`. "
"Expected a string like 'gpu:0' or 'cpu'. "
f"Received: device_name='{device_name}'"
)
return device_name


class Variable(KerasVariable):
def _initialize(self, value):
self._value = np.array(value, dtype=self._dtype)

def _direct_assign(self, value):
self._value = np.array(value, dtype=self._dtype)

def _convert_to_tensor(self, value, dtype=None):
return convert_to_tensor(value, dtype=dtype)

# Overload native accessor.
def __array__(self):
return self.value


def convert_to_tensor(x, dtype=None, sparse=None):
if sparse:
raise ValueError("`sparse=True` is not supported with numpy backend")
if dtype is not None:
dtype = standardize_dtype(dtype)
if isinstance(x, Variable):
if dtype and dtype != x.dtype:
return x.value.astype(dtype)
return x.value
if not is_tensor(x) and standardize_dtype(dtype) == "bfloat16":
# Can't create bfloat16 arrays on the fly (e.g. from a h5 Dataset).
# Instead we convert "as is" (to stored dtype) and cast.
return np.asarray(x).astype(dtype)
if dtype is None:
dtype = result_type(
*[getattr(item, "dtype", type(item)) for item in tree.flatten(x)]
)
return np.array(x, dtype=dtype)


def convert_to_numpy(x):
return np.array(x)


def is_tensor(x):
if isinstance(x, (np.generic, np.ndarray)):
return True
return False


def shape(x):
return x.shape


def cast(x, dtype):
raise NotImplementedError(
"`cast` is not supported with openvino backend"
)


def cond(pred, true_fn, false_fn):
raise NotImplementedError(
"`cond` is not supported with openvino backend"
)


def vectorized_map(function, elements):
raise NotImplementedError(
"`vectorized_map` is not supported with openvino backend"
)


# Shape / dtype inference util
def compute_output_spec(fn, *args, **kwargs):
with StatelessScope():

def has_none_shape(x):
if isinstance(x, KerasTensor):
return None in x.shape
return False

none_in_shape = any(map(has_none_shape, tree.flatten((args, kwargs))))

def convert_keras_tensor_to_numpy(x, fill_value=None):
if isinstance(x, KerasTensor):
shape = list(x.shape)
if fill_value:
for i, e in enumerate(shape):
if e is None:
shape[i] = fill_value
return np.empty(
shape=shape,
dtype=x.dtype,
)
return x

args_1, kwargs_1 = tree.map_structure(
lambda x: convert_keras_tensor_to_numpy(x, fill_value=83),
(args, kwargs),
)
outputs_1 = fn(*args_1, **kwargs_1)

outputs = outputs_1

if none_in_shape:
args_2, kwargs_2 = tree.map_structure(
lambda x: convert_keras_tensor_to_numpy(x, fill_value=89),
(args, kwargs),
)
outputs_2 = fn(*args_2, **kwargs_2)

flat_out_1 = tree.flatten(outputs_1)
flat_out_2 = tree.flatten(outputs_2)

flat_out = []
for x1, x2 in zip(flat_out_1, flat_out_2):
shape = list(x1.shape)
for i, e in enumerate(x2.shape):
if e != shape[i]:
shape[i] = None
flat_out.append(KerasTensor(shape, standardize_dtype(x1.dtype)))
outputs = tree.pack_sequence_as(outputs_1, flat_out)

def convert_numpy_to_keras_tensor(x):
if is_tensor(x):
return KerasTensor(x.shape, standardize_dtype(x.dtype))
return x

output_spec = tree.map_structure(convert_numpy_to_keras_tensor, outputs)
return output_spec


def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
raise NotImplementedError(
"`scan` is not supported with openvino backend"
)


def scatter(indices, values, shape):
raise NotImplementedError(
"`scatter` is not supported with openvino backend"
)


def scatter_update(inputs, indices, updates):
raise NotImplementedError(
"`scatter_update` is not supported with openvino backend"
)


def slice(inputs, start_indices, lengths):
raise NotImplementedError(
"`slice` is not supported with openvino backend"
)


def slice_update(inputs, start_indices, updates):
raise NotImplementedError(
"`slice_update` is not supported with openvino backend"
)


def while_loop(
cond,
body,
loop_vars,
maximum_iterations=None,
):
raise NotImplementedError(
"`while_loop` is not supported with openvino backend"
)


def fori_loop(lower, upper, body_fun, init_val):
raise NotImplementedError(
"`fori_loop` is not supported with openvino backend"
)


def stop_gradient(x):
return x


def unstack(x, num=None, axis=0):
raise NotImplementedError(
"`unstack` is not supported with openvino backend"
)


def custom_gradient(fun):
raise NotImplementedError(
"`custom_gradient` is not supported with numpy backend"
)
Loading
Loading