From aece95346d70744f79c91dc4e3103326e0b8c1fe Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Tue, 19 Nov 2024 13:26:52 +0200 Subject: [PATCH 01/15] WIP: Added jina clip text embedding --- fastembed/text/clip_embedding.py | 11 +++ tests/test_text_onnx_embeddings.py | 116 ++++++++++++++++------------- 2 files changed, 75 insertions(+), 52 deletions(-) diff --git a/fastembed/text/clip_embedding.py b/fastembed/text/clip_embedding.py index a757d875..2d71c282 100644 --- a/fastembed/text/clip_embedding.py +++ b/fastembed/text/clip_embedding.py @@ -18,6 +18,17 @@ }, "model_file": "model.onnx", }, + { + "model": "jinaai/jina-clip-v1", + "dim": 768, + "description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year", + "license": "apache-2.0", + "size_in_GB": 0.55, + "sources": { + "hf": "jinaai/jina-clip-v1", + }, + "model_file": "onnx/text_model.onnx", + }, ] diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index f576330c..74fd921d 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -1,7 +1,6 @@ import os import numpy as np -import pytest from fastembed.text.text_embedding import TextEmbedding from tests.utils import delete_model_cache @@ -64,19 +63,32 @@ ), "snowflake/snowflake-arctic-embed-l": np.array([0.0189, -0.0673, 0.0183, 0.0124, 0.0146]), "Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]), + "jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]), } def test_embedding(): is_ci = os.getenv("CI") - for model_desc in TextEmbedding.list_supported_models(): + for model_desc in [ + { + "model": "jinaai/jina-clip-v1", + "dim": 768, + "description": "Text embeddings, Multimodal (text&image), English, 77 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year", + "license": "apache-2.0", + "size_in_GB": 0.55, + "sources": { + "hf": "jinaai/jina-clip-v1", + }, + "model_file": "onnx/text_model.onnx", + } + ]: if not is_ci and model_desc["size_in_GB"] > 1: continue dim = model_desc["dim"] - model = TextEmbedding(model_name=model_desc["model"]) + model = TextEmbedding(model_name=model_desc["model"], cache_dir="models") docs = ["hello world", "flag embedding"] embeddings = list(model.embed(docs)) embeddings = np.stack(embeddings, axis=0) @@ -90,66 +102,66 @@ def test_embedding(): delete_model_cache(model.model._model_dir) -@pytest.mark.parametrize( - "n_dims,model_name", - [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")], -) -def test_batch_embedding(n_dims, model_name): - is_ci = os.getenv("CI") - model = TextEmbedding(model_name=model_name) +# @pytest.mark.parametrize( +# "n_dims,model_name", +# [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")], +# ) +# def test_batch_embedding(n_dims, model_name): +# is_ci = os.getenv("CI") +# model = TextEmbedding(model_name=model_name) - docs = ["hello world", "flag embedding"] * 100 - embeddings = list(model.embed(docs, batch_size=10)) - embeddings = np.stack(embeddings, axis=0) +# docs = ["hello world", "flag embedding"] * 100 +# embeddings = list(model.embed(docs, batch_size=10)) +# embeddings = np.stack(embeddings, axis=0) - assert embeddings.shape == (200, n_dims) - if is_ci: - delete_model_cache(model.model._model_dir) +# assert embeddings.shape == (200, n_dims) +# if is_ci: +# delete_model_cache(model.model._model_dir) -@pytest.mark.parametrize( - "n_dims,model_name", - [(384, 
"BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")], -) -def test_parallel_processing(n_dims, model_name): - is_ci = os.getenv("CI") - model = TextEmbedding(model_name=model_name) +# @pytest.mark.parametrize( +# "n_dims,model_name", +# [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")], +# ) +# def test_parallel_processing(n_dims, model_name): +# is_ci = os.getenv("CI") +# model = TextEmbedding(model_name=model_name) - docs = ["hello world", "flag embedding"] * 100 - embeddings = list(model.embed(docs, batch_size=10, parallel=2)) - embeddings = np.stack(embeddings, axis=0) +# docs = ["hello world", "flag embedding"] * 100 +# embeddings = list(model.embed(docs, batch_size=10, parallel=2)) +# embeddings = np.stack(embeddings, axis=0) - embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) - embeddings_2 = np.stack(embeddings_2, axis=0) +# embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) +# embeddings_2 = np.stack(embeddings_2, axis=0) - embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) - embeddings_3 = np.stack(embeddings_3, axis=0) +# embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) +# embeddings_3 = np.stack(embeddings_3, axis=0) - assert embeddings.shape == (200, n_dims) - assert np.allclose(embeddings, embeddings_2, atol=1e-3) - assert np.allclose(embeddings, embeddings_3, atol=1e-3) +# assert embeddings.shape == (200, n_dims) +# assert np.allclose(embeddings, embeddings_2, atol=1e-3) +# assert np.allclose(embeddings, embeddings_3, atol=1e-3) - if is_ci: - delete_model_cache(model.model._model_dir) +# if is_ci: +# delete_model_cache(model.model._model_dir) -@pytest.mark.parametrize( - "model_name", - ["BAAI/bge-small-en-v1.5"], -) -def test_lazy_load(model_name): - is_ci = os.getenv("CI") - model = TextEmbedding(model_name=model_name, lazy_load=True) - assert not hasattr(model.model, "model") - docs = ["hello world", "flag embedding"] - list(model.embed(docs)) - assert hasattr(model.model, "model") +# @pytest.mark.parametrize( +# "model_name", +# ["BAAI/bge-small-en-v1.5"], +# ) +# def test_lazy_load(model_name): +# is_ci = os.getenv("CI") +# model = TextEmbedding(model_name=model_name, lazy_load=True) +# assert not hasattr(model.model, "model") +# docs = ["hello world", "flag embedding"] +# list(model.embed(docs)) +# assert hasattr(model.model, "model") - model = TextEmbedding(model_name=model_name, lazy_load=True) - list(model.query_embed(docs)) +# model = TextEmbedding(model_name=model_name, lazy_load=True) +# list(model.query_embed(docs)) - model = TextEmbedding(model_name=model_name, lazy_load=True) - list(model.passage_embed(docs)) +# model = TextEmbedding(model_name=model_name, lazy_load=True) +# list(model.passage_embed(docs)) - if is_ci: - delete_model_cache(model.model._model_dir) +# if is_ci: +# delete_model_cache(model.model._model_dir) From 5b2cf1ad492d78714ca93035f50d5d141d38d172 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Tue, 19 Nov 2024 13:42:12 +0200 Subject: [PATCH 02/15] WIP: Added preprocess for jina clip --- fastembed/image/transform/functional.py | 17 ++++++- fastembed/image/transform/operators.py | 61 +++++++++++++++++++++++-- 2 files changed, 74 insertions(+), 4 deletions(-) diff --git a/fastembed/image/transform/functional.py b/fastembed/image/transform/functional.py index 70da2a22..17b103c7 100644 --- a/fastembed/image/transform/functional.py +++ b/fastembed/image/transform/functional.py @@ -1,7 +1,7 @@ from typing import Sized, Union 
import numpy as np -from PIL import Image +from PIL import Image, ImageOps def convert_to_rgb(image: Image.Image) -> Image.Image: @@ -122,3 +122,18 @@ def pil2ndarray(image: Union[Image.Image, np.ndarray]): if isinstance(image, Image.Image): return np.asarray(image).transpose((2, 0, 1)) return image + + +def pad2square( + image: Image, + fill_color: str | int | tuple[int, ...] | None = None, + resample: Image.Resampling = Image.Resampling.BILINEAR, +): + width, height = image.size + max_dim = max(width, height) + return ImageOps.pad( + image=image, + size=(max_dim, max_dim), + color=fill_color, + method=resample, + ) diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py index 2b943dbb..56a83888 100644 --- a/fastembed/image/transform/operators.py +++ b/fastembed/image/transform/operators.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any, Union, Optional import numpy as np from PIL import Image @@ -10,6 +10,7 @@ pil2ndarray, rescale, resize, + pad2sqaure, ) @@ -66,6 +67,22 @@ def __call__(self, images: list[Union[Image.Image, np.ndarray]]) -> list[np.ndar return [pil2ndarray(image) for image in images] +class PadtoSquare(Transform): + def __init__( + self, + fill_color: Optional[Union[str, int, tuple[int, ...]]] = None, + resample: Image.Resampling = Image.Resampling.BICUBIC, + ): + self.fill_color = fill_color + self.resample = resample + + def __call__(self, images: list[np.ndarray]) -> list[np.ndarray]: + return [ + pad2sqaure(image=image, fill_color=self.fill_color, resample=self.resample) + for image in images + ] + + class Compose: def __init__(self, transforms: list[Transform]): self.transforms = transforms @@ -85,14 +102,20 @@ def from_config(cls, config: dict[str, Any]) -> "Compose": Valid keys: - do_resize + - resize_mode - size + - fill_color - do_center_crop - crop_size - do_rescale - rescale_factor - do_normalize - image_mean + - mean - image_std + - std + - resample + - interpolation Valid size keys (nested): - {"height", "width"} - {"shortest_edge"} @@ -103,6 +126,7 @@ def from_config(cls, config: dict[str, Any]) -> "Compose": transforms = [] cls._get_convert_to_rgb(transforms, config) cls._get_resize(transforms, config) + cls._get_padtosquare(transforms, config) cls._get_center_crop(transforms, config) cls._get_pil2ndarray(transforms, config) cls._get_rescale(transforms, config) @@ -157,6 +181,18 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]): resample=config.get("resample", Image.Resampling.BICUBIC), ) ) + elif mode == "JinaCLIPImageProcessor": + if "size" in config: + resize_mode = config.get("resize_mode", "shortest") + if resize_mode == "shortest": + transforms.append( + Resize( + size=config["size"], + resample=config.get("interpolation", Image.Resampling.BICUBIC), + ) + ) + else: + raise ValueError(f"Preprocessor {mode} is not supported") @staticmethod def _get_center_crop(transforms: list[Transform], config: dict[str, Any]): @@ -173,6 +209,8 @@ def _get_center_crop(transforms: list[Transform], config: dict[str, Any]): transforms.append(CenterCrop(size=crop_size)) elif mode == "ConvNextFeatureExtractor": pass + elif mode == "JinaCLIPImageProcessor": + pass else: raise ValueError(f"Preprocessor {mode} is not supported") @@ -188,5 +226,22 @@ def _get_rescale(transforms: list[Transform], config: dict[str, Any]): @staticmethod def _get_normalize(transforms: list[Transform], config: dict[str, Any]): - if config.get("do_normalize", False): - 
transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"]))
+        if config.get("do_normalize", False) or ("mean" in config and "std" in config):
+            transforms.append(
+                Normalize(
+                    mean=config["image_mean"] or config["mean"],
+                    std=config["image_std"] or config["std"],
+                )
+            )
+
+    @staticmethod
+    def _get_padtosquare(transforms: list[Transform], config: dict[str, Any]):
+        if config.get("do_pad_to_square", False):
+            transforms.append(
+                PadtoSquare(
+                    fill_color=config["fill_color"],
+                    resample=config.get("interpolation")
+                    or config.get("resample")
+                    or Image.Resampling.BICUBIC,
+                )
+            )

From 42d46d7b37a435c4cdb65a11ee58f97b282b5a36 Mon Sep 17 00:00:00 2001
From: hh-space-invader
Date: Tue, 19 Nov 2024 13:42:48 +0200
Subject: [PATCH 03/15] WIP: Added jina clip vision (not sure if it works yet)

---
 fastembed/image/onnx_embedding.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/fastembed/image/onnx_embedding.py b/fastembed/image/onnx_embedding.py
index 5647c2ff..47beae22 100644
--- a/fastembed/image/onnx_embedding.py
+++ b/fastembed/image/onnx_embedding.py
@@ -53,6 +53,17 @@
         },
         "model_file": "model.onnx",
     },
+    {
+        "model": "jinaai/jina-clip-v1",
+        "dim": 768,
+        "description": "Image embeddings, Multimodal (text&image), 2024 year",
+        "license": "apache-2.0",
+        "size_in_GB": 0.34,
+        "sources": {
+            "hf": "jinaai/jina-clip-v1",
+        },
+        "model_file": "onnx/vision_model.onnx",
+    },
 ]


From 6399682aa74d535db391d6d8a761109e394d6add Mon Sep 17 00:00:00 2001
From: hh-space-invader
Date: Wed, 20 Nov 2024 13:28:44 +0200
Subject: [PATCH 04/15] improve: Improved mean pooling if the output doesn't have seq length

---
 fastembed/text/pooled_embedding.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/fastembed/text/pooled_embedding.py b/fastembed/text/pooled_embedding.py
index 526f2df5..069396e9 100644
--- a/fastembed/text/pooled_embedding.py
+++ b/fastembed/text/pooled_embedding.py
@@ -50,6 +50,10 @@ def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:

     @classmethod
     def mean_pooling(cls, model_output: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
+        if model_output.ndim == 2:  # (batch, embedding_dim)
+            seq_length = attention_mask.shape[1]
+            # (batch, seq_length, embedding_dim)
+            model_output = np.tile(np.expand_dims(model_output, axis=1), (1, seq_length, 1))
         token_embeddings = model_output
         input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
         input_mask_expanded = np.tile(input_mask_expanded, (1, 1, token_embeddings.shape[-1]))

From e6c3f5138803ad292bdd01bd13b84c8f2096826b Mon Sep 17 00:00:00 2001
From: hh-space-invader
Date: Wed, 20 Nov 2024 13:29:22 +0200
Subject: [PATCH 05/15] fix: Fixed jina clip text

---
 fastembed/text/clip_embedding.py              | 11 -----------
 fastembed/text/pooled_normalized_embedding.py | 11 +++++++++++
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/fastembed/text/clip_embedding.py b/fastembed/text/clip_embedding.py
index 2d71c282..a757d875 100644
--- a/fastembed/text/clip_embedding.py
+++ b/fastembed/text/clip_embedding.py
@@ -18,17 +18,6 @@
         },
         "model_file": "model.onnx",
     },
-    {
-        "model": "jinaai/jina-clip-v1",
-        "dim": 768,
-        "description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
-        "license": "apache-2.0",
-        "size_in_GB": 0.55,
-        "sources": {
-            "hf": "jinaai/jina-clip-v1",
-        },
-        "model_file": "onnx/text_model.onnx",
-    },
 ]


diff --git a/fastembed/text/pooled_normalized_embedding.py b/fastembed/text/pooled_normalized_embedding.py
index 
70660420..a938d47b 100644 --- a/fastembed/text/pooled_normalized_embedding.py +++ b/fastembed/text/pooled_normalized_embedding.py @@ -75,6 +75,17 @@ "sources": {"hf": "jinaai/jina-embeddings-v2-base-es"}, "model_file": "onnx/model.onnx", }, + { + "model": "jinaai/jina-clip-v1", + "dim": 768, + "description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year", + "license": "apache-2.0", + "size_in_GB": 0.55, + "sources": { + "hf": "jinaai/jina-clip-v1", + }, + "model_file": "onnx/text_model.onnx", + }, ] From 6418b90f0068f11c5fe5c3d1a13474b0b668d10b Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Wed, 20 Nov 2024 13:32:54 +0200 Subject: [PATCH 06/15] nit --- tests/test_text_onnx_embeddings.py | 113 +++++++++++++---------------- 1 file changed, 51 insertions(+), 62 deletions(-) diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 74fd921d..0db3f727 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -1,6 +1,7 @@ import os import numpy as np +import pytest from fastembed.text.text_embedding import TextEmbedding from tests.utils import delete_model_cache @@ -70,19 +71,7 @@ def test_embedding(): is_ci = os.getenv("CI") - for model_desc in [ - { - "model": "jinaai/jina-clip-v1", - "dim": 768, - "description": "Text embeddings, Multimodal (text&image), English, 77 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year", - "license": "apache-2.0", - "size_in_GB": 0.55, - "sources": { - "hf": "jinaai/jina-clip-v1", - }, - "model_file": "onnx/text_model.onnx", - } - ]: + for model_desc in TextEmbedding.list_supported_models(): if not is_ci and model_desc["size_in_GB"] > 1: continue @@ -102,66 +91,66 @@ def test_embedding(): delete_model_cache(model.model._model_dir) -# @pytest.mark.parametrize( -# "n_dims,model_name", -# [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")], -# ) -# def test_batch_embedding(n_dims, model_name): -# is_ci = os.getenv("CI") -# model = TextEmbedding(model_name=model_name) +@pytest.mark.parametrize( + "n_dims,model_name", + [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")], +) +def test_batch_embedding(n_dims, model_name): + is_ci = os.getenv("CI") + model = TextEmbedding(model_name=model_name) -# docs = ["hello world", "flag embedding"] * 100 -# embeddings = list(model.embed(docs, batch_size=10)) -# embeddings = np.stack(embeddings, axis=0) + docs = ["hello world", "flag embedding"] * 100 + embeddings = list(model.embed(docs, batch_size=10)) + embeddings = np.stack(embeddings, axis=0) -# assert embeddings.shape == (200, n_dims) -# if is_ci: -# delete_model_cache(model.model._model_dir) + assert embeddings.shape == (200, n_dims) + if is_ci: + delete_model_cache(model.model._model_dir) -# @pytest.mark.parametrize( -# "n_dims,model_name", -# [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")], -# ) -# def test_parallel_processing(n_dims, model_name): -# is_ci = os.getenv("CI") -# model = TextEmbedding(model_name=model_name) +@pytest.mark.parametrize( + "n_dims,model_name", + [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")], +) +def test_parallel_processing(n_dims, model_name): + is_ci = os.getenv("CI") + model = TextEmbedding(model_name=model_name) -# docs = ["hello world", "flag embedding"] * 100 -# embeddings = list(model.embed(docs, batch_size=10, parallel=2)) -# embeddings = np.stack(embeddings, 
axis=0) + docs = ["hello world", "flag embedding"] * 100 + embeddings = list(model.embed(docs, batch_size=10, parallel=2)) + embeddings = np.stack(embeddings, axis=0) -# embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) -# embeddings_2 = np.stack(embeddings_2, axis=0) + embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) + embeddings_2 = np.stack(embeddings_2, axis=0) -# embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) -# embeddings_3 = np.stack(embeddings_3, axis=0) + embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) + embeddings_3 = np.stack(embeddings_3, axis=0) -# assert embeddings.shape == (200, n_dims) -# assert np.allclose(embeddings, embeddings_2, atol=1e-3) -# assert np.allclose(embeddings, embeddings_3, atol=1e-3) + assert embeddings.shape == (200, n_dims) + assert np.allclose(embeddings, embeddings_2, atol=1e-3) + assert np.allclose(embeddings, embeddings_3, atol=1e-3) -# if is_ci: -# delete_model_cache(model.model._model_dir) + if is_ci: + delete_model_cache(model.model._model_dir) -# @pytest.mark.parametrize( -# "model_name", -# ["BAAI/bge-small-en-v1.5"], -# ) -# def test_lazy_load(model_name): -# is_ci = os.getenv("CI") -# model = TextEmbedding(model_name=model_name, lazy_load=True) -# assert not hasattr(model.model, "model") -# docs = ["hello world", "flag embedding"] -# list(model.embed(docs)) -# assert hasattr(model.model, "model") +@pytest.mark.parametrize( + "model_name", + ["BAAI/bge-small-en-v1.5"], +) +def test_lazy_load(model_name): + is_ci = os.getenv("CI") + model = TextEmbedding(model_name=model_name, lazy_load=True) + assert not hasattr(model.model, "model") + docs = ["hello world", "flag embedding"] + list(model.embed(docs)) + assert hasattr(model.model, "model") -# model = TextEmbedding(model_name=model_name, lazy_load=True) -# list(model.query_embed(docs)) + model = TextEmbedding(model_name=model_name, lazy_load=True) + list(model.query_embed(docs)) -# model = TextEmbedding(model_name=model_name, lazy_load=True) -# list(model.passage_embed(docs)) + model = TextEmbedding(model_name=model_name, lazy_load=True) + list(model.passage_embed(docs)) -# if is_ci: -# delete_model_cache(model.model._model_dir) + if is_ci: + delete_model_cache(model.model._model_dir) From 5635c9cb53e485102e9f1a785e0f3d7d6ded3615 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Wed, 20 Nov 2024 16:09:02 +0200 Subject: [PATCH 07/15] fix: Fixed jina clip image preprocessor --- fastembed/image/transform/functional.py | 6 +-- fastembed/image/transform/operators.py | 55 ++++++++++++++++++------- 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/fastembed/image/transform/functional.py b/fastembed/image/transform/functional.py index 17b103c7..d8e43529 100644 --- a/fastembed/image/transform/functional.py +++ b/fastembed/image/transform/functional.py @@ -1,4 +1,4 @@ -from typing import Sized, Union +from typing import Sized, Union, Optional import numpy as np from PIL import Image, ImageOps @@ -126,8 +126,8 @@ def pil2ndarray(image: Union[Image.Image, np.ndarray]): def pad2square( image: Image, - fill_color: str | int | tuple[int, ...] 
| None = None, - resample: Image.Resampling = Image.Resampling.BILINEAR, + fill_color: Optional[Union[str, int, tuple[int, ...]]] = None, + resample: Union[Image.Resampling, int] = Image.Resampling.BILINEAR, ): width, height = image.size max_dim = max(width, height) diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py index 56a83888..ad5e0afa 100644 --- a/fastembed/image/transform/operators.py +++ b/fastembed/image/transform/operators.py @@ -10,7 +10,7 @@ pil2ndarray, rescale, resize, - pad2sqaure, + pad2square, ) @@ -71,14 +71,14 @@ class PadtoSquare(Transform): def __init__( self, fill_color: Optional[Union[str, int, tuple[int, ...]]] = None, - resample: Image.Resampling = Image.Resampling.BICUBIC, + resample: Union[Image.Resampling, int] = Image.Resampling.BICUBIC, ): self.fill_color = fill_color self.resample = resample def __call__(self, images: list[np.ndarray]) -> list[np.ndarray]: return [ - pad2sqaure(image=image, fill_color=self.fill_color, resample=self.resample) + pad2square(image=image, fill_color=self.fill_color, resample=self.resample) for image in images ] @@ -125,8 +125,8 @@ def from_config(cls, config: dict[str, Any]) -> "Compose": """ transforms = [] cls._get_convert_to_rgb(transforms, config) + cls._get_pad2square(transforms, config) cls._get_resize(transforms, config) - cls._get_padtosquare(transforms, config) cls._get_center_crop(transforms, config) cls._get_pil2ndarray(transforms, config) cls._get_rescale(transforms, config) @@ -188,7 +188,11 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]): transforms.append( Resize( size=config["size"], - resample=config.get("interpolation", Image.Resampling.BICUBIC), + resample=( + Compose._interpolation_resolver(config.get("interpolation")) + if isinstance(config.get("interpolation"), str) + else config.get("interpolation") or Image.Resampling.BICUBIC + ), ) ) else: @@ -229,19 +233,38 @@ def _get_normalize(transforms: list[Transform], config: dict[str, Any]): if config.get("do_normalize", False) or ("mean" in config and "std" in config): transforms.append( Normalize( - mean=config["image_mean"] or config["mean"], - std=config["image_std"] or config["std"], + mean=config.get("image_mean", config.get("mean")), + std=config.get("image_std", config.get("std")), ) ) @staticmethod - def _get_padtosquare(transforms: list[Transform], config: dict[str, Any]): - if config.get("do_pad_to_square", False): - transforms.append( - PadtoSquare( - fill_color=config["fill_color"], - resample=config.get("interpolation") - or config.get("resample") - or Image.Resampling.BICUBIC, - ) + def _get_pad2square(transforms: list[Transform], config: dict[str, Any]): + mode = config.get("image_processor_type", "CLIPImageProcessor") + if mode == "CLIPImageProcessor": + pass + elif mode == "ConvNextFeatureExtractor": + pass + elif mode == "JinaCLIPImageProcessor": + resample = ( + Compose._interpolation_resolver(config.get("interpolation")) + if isinstance(config.get("interpolation"), str) + else config.get("interpolation") or Image.Resampling.BICUBIC ) + transforms.append(PadtoSquare(fill_color=config["fill_color"], resample=resample)) + + @staticmethod + def _interpolation_resolver(resample: Optional[str] = None) -> Image.Resampling: + interpolation_map = { + "nearest": Image.Resampling.NEAREST, + "lanczos": Image.Resampling.LANCZOS, + "bilinear": Image.Resampling.BILINEAR, + "bicubic": Image.Resampling.BICUBIC, + "box": Image.Resampling.BOX, + "hamming": Image.Resampling.HAMMING, + } + + if 
resample and (method := interpolation_map.get(resample.lower())):
+            return method
+
+        raise ValueError(f"Unknown interpolation method: {resample}")

From f402c7fce8deb0dca6f7ee4269d8a5fea04003eb Mon Sep 17 00:00:00 2001
From: hh-space-invader
Date: Thu, 21 Nov 2024 01:41:07 +0200
Subject: [PATCH 08/15] fix: Fix type hints; new: added resize2square

---
 fastembed/image/transform/functional.py | 36 +++++++++----------
 fastembed/image/transform/operators.py  | 48 +++++++++++++------------
 2 files changed, 44 insertions(+), 40 deletions(-)

diff --git a/fastembed/image/transform/functional.py b/fastembed/image/transform/functional.py
index d8e43529..553639f8 100644
--- a/fastembed/image/transform/functional.py
+++ b/fastembed/image/transform/functional.py
@@ -1,7 +1,7 @@
 from typing import Sized, Union, Optional

 import numpy as np
-from PIL import Image, ImageOps
+from PIL import Image


 def convert_to_rgb(image: Image.Image) -> Image.Image:
@@ -62,8 +62,8 @@ def center_crop(

 def normalize(
     image: np.ndarray,
-    mean=Union[float, np.ndarray],
-    std=Union[float, np.ndarray],
+    mean: Union[float, np.ndarray],
+    std: Union[float, np.ndarray],
 ) -> np.ndarray:
     if not isinstance(image, np.ndarray):
         raise ValueError("image must be a numpy array")
@@ -96,10 +96,10 @@ def normalize(


 def resize(
-    image: Image,
+    image: Image.Image,
     size: Union[int, tuple[int, int]],
-    resample: Image.Resampling = Image.Resampling.BILINEAR,
-) -> Image:
+    resample: Union[int, Image.Resampling] = Image.Resampling.BILINEAR,
+) -> Image.Image:
     if isinstance(size, tuple):
         return image.resize(size, resample)

@@ -124,16 +124,16 @@ def pil2ndarray(image: Union[Image.Image, np.ndarray]):
     return image


-def pad2square(
-    image: Image,
+def resize2square(
+    image: Image.Image,
+    size: int,
     fill_color: Optional[Union[str, int, tuple[int, ...]]] = None,
-    resample: Union[Image.Resampling, int] = Image.Resampling.BILINEAR,
-):
-    width, height = image.size
-    max_dim = max(width, height)
-    return ImageOps.pad(
-        image=image,
-        size=(max_dim, max_dim),
-        color=fill_color,
-        method=resample,
-    )
+    resample: Union[Image.Resampling, int] = Image.Resampling.BICUBIC,
+) -> Image.Image:
+    resized_image = resize(image=image, size=size, resample=resample)
+
+    new_image = Image.new(mode="RGB", size=(size, size), color=fill_color)
+    left = (size - resized_image.size[0]) // 2
+    top = (size - resized_image.size[1]) // 2
+    new_image.paste(resized_image, (left, top))
+    return new_image
diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py
index ad5e0afa..87df88c4 100644
--- a/fastembed/image/transform/operators.py
+++ b/fastembed/image/transform/operators.py
@@ -10,7 +10,7 @@
     pil2ndarray,
     rescale,
     resize,
-    pad2square,
+    resize2square,
 )


@@ -38,7 +38,9 @@ def __init__(self, mean: Union[float, list[float]], std: Union[float, list[float
         self.std = std

     def __call__(self, images: list[np.ndarray]) -> list[np.ndarray]:
-        return [normalize(image, mean=self.mean, std=self.std) for image in images]
+        return [
+            normalize(image, mean=np.array(self.mean), std=np.array(self.std)) for image in images
+        ]


 class Resize(Transform):
@@ -67,18 +69,22 @@ def __call__(self, images: list[Union[Image.Image, np.ndarray]]) -> list[np.ndar
         return [pil2ndarray(image) for image in images]


-class PadtoSquare(Transform):
+class ResizetoSquare(Transform):
     def __init__(
         self,
+        size: int,
         fill_color: Optional[Union[str, int, tuple[int, ...]]] = None,
         resample: Union[Image.Resampling, int] = Image.Resampling.BICUBIC,
     ):
+        self.size = size 
self.fill_color = fill_color self.resample = resample - def __call__(self, images: list[np.ndarray]) -> list[np.ndarray]: + def __call__(self, images: list[Image.Image]) -> list[Image.Image]: return [ - pad2square(image=image, fill_color=self.fill_color, resample=self.resample) + resize2square( + image=image, size=self.size, fill_color=self.fill_color, resample=self.resample + ) for image in images ] @@ -125,8 +131,8 @@ def from_config(cls, config: dict[str, Any]) -> "Compose": """ transforms = [] cls._get_convert_to_rgb(transforms, config) - cls._get_pad2square(transforms, config) cls._get_resize(transforms, config) + cls._get_resize2square(transforms, config) cls._get_center_crop(transforms, config) cls._get_pil2ndarray(transforms, config) cls._get_rescale(transforms, config) @@ -182,19 +188,7 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]): ) ) elif mode == "JinaCLIPImageProcessor": - if "size" in config: - resize_mode = config.get("resize_mode", "shortest") - if resize_mode == "shortest": - transforms.append( - Resize( - size=config["size"], - resample=( - Compose._interpolation_resolver(config.get("interpolation")) - if isinstance(config.get("interpolation"), str) - else config.get("interpolation") or Image.Resampling.BICUBIC - ), - ) - ) + pass else: raise ValueError(f"Preprocessor {mode} is not supported") @@ -224,7 +218,8 @@ def _get_pil2ndarray(transforms: list[Transform], config: dict[str, Any]): @staticmethod def _get_rescale(transforms: list[Transform], config: dict[str, Any]): - if config.get("do_rescale", True): + # mode = config.get("image_processor_type", "CLIPImageProcessor") + if config.get("do_rescale", True): # or (mode == "JinaCLIPImageProcessor"): rescale_factor = config.get("rescale_factor", 1 / 255) transforms.append(Rescale(scale=rescale_factor)) @@ -239,7 +234,7 @@ def _get_normalize(transforms: list[Transform], config: dict[str, Any]): ) @staticmethod - def _get_pad2square(transforms: list[Transform], config: dict[str, Any]): + def _get_resize2square(transforms: list[Transform], config: dict[str, Any]): mode = config.get("image_processor_type", "CLIPImageProcessor") if mode == "CLIPImageProcessor": pass @@ -251,7 +246,16 @@ def _get_pad2square(transforms: list[Transform], config: dict[str, Any]): if isinstance(config.get("interpolation"), str) else config.get("interpolation") or Image.Resampling.BICUBIC ) - transforms.append(PadtoSquare(fill_color=config["fill_color"], resample=resample)) + if "size" in config: + resize_mode = config.get("resize_mode", "shortest") + if resize_mode == "shortest": + transforms.append( + ResizetoSquare( + size=config["size"], + fill_color=config.get("fill_color", 0), + resample=resample, + ) + ) @staticmethod def _interpolation_resolver(resample: Optional[str] = None) -> Image.Resampling: From d9cbbce7ea937d7fd0ddb5c1dca097d033cf89ce Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Thu, 21 Nov 2024 01:49:39 +0200 Subject: [PATCH 09/15] tests: Add jina clip vision test case --- tests/test_image_onnx_embeddings.py | 3 +++ tests/test_text_onnx_embeddings.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_image_onnx_embeddings.py b/tests/test_image_onnx_embeddings.py index 78194caf..a5fb8e36 100644 --- a/tests/test_image_onnx_embeddings.py +++ b/tests/test_image_onnx_embeddings.py @@ -21,6 +21,9 @@ "Qdrant/Unicom-ViT-B-32": np.array( [0.0418, 0.0550, 0.0003, 0.0253, -0.0185, 0.0016, -0.0368, -0.0402, -0.0891, -0.0186] ), + "jinaai/jina-clip-v1": np.array( + [-0.029, 0.0216, 
0.0396, 0.0283, -0.0023, 0.0151, 0.011, -0.0235, 0.0251, -0.0343] + ), } diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 0db3f727..a13d0caa 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -77,7 +77,7 @@ def test_embedding(): dim = model_desc["dim"] - model = TextEmbedding(model_name=model_desc["model"], cache_dir="models") + model = TextEmbedding(model_name=model_desc["model"]) docs = ["hello world", "flag embedding"] embeddings = list(model.embed(docs)) embeddings = np.stack(embeddings, axis=0) From e78c76cc8939b1f03444e734d255762ab3f2fd78 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Thu, 21 Nov 2024 02:21:45 +0200 Subject: [PATCH 10/15] nit --- fastembed/image/transform/operators.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py index 87df88c4..59683124 100644 --- a/fastembed/image/transform/operators.py +++ b/fastembed/image/transform/operators.py @@ -218,8 +218,7 @@ def _get_pil2ndarray(transforms: list[Transform], config: dict[str, Any]): @staticmethod def _get_rescale(transforms: list[Transform], config: dict[str, Any]): - # mode = config.get("image_processor_type", "CLIPImageProcessor") - if config.get("do_rescale", True): # or (mode == "JinaCLIPImageProcessor"): + if config.get("do_rescale", True): rescale_factor = config.get("rescale_factor", 1 / 255) transforms.append(Rescale(scale=rescale_factor)) From eeeaa77f4bc7808827f97746151f34fadb4d974e Mon Sep 17 00:00:00 2001 From: Hossam Hagag <90828745+hh-space-invader@users.noreply.github.com> Date: Mon, 25 Nov 2024 00:38:14 +0200 Subject: [PATCH 11/15] refactor: Update fastembed/image/transform/operators.py Co-authored-by: George --- fastembed/image/transform/operators.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py index 59683124..2ed108d4 100644 --- a/fastembed/image/transform/operators.py +++ b/fastembed/image/transform/operators.py @@ -224,13 +224,10 @@ def _get_rescale(transforms: list[Transform], config: dict[str, Any]): @staticmethod def _get_normalize(transforms: list[Transform], config: dict[str, Any]): - if config.get("do_normalize", False) or ("mean" in config and "std" in config): - transforms.append( - Normalize( - mean=config.get("image_mean", config.get("mean")), - std=config.get("image_std", config.get("std")), - ) - ) +if config.get("do_normalize", False): + transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"])) +elif "mean" in config and "std" in config: + transforms.append(Normalize(mean=config["mean"], std=config["std"])) @staticmethod def _get_resize2square(transforms: list[Transform], config: dict[str, Any]): From 0fc67615da08c5944aaf5dd650041e7ee1bd3fb0 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 25 Nov 2024 00:41:35 +0200 Subject: [PATCH 12/15] fix: Fix indentation --- fastembed/image/transform/operators.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py index 2ed108d4..7d43ecf9 100644 --- a/fastembed/image/transform/operators.py +++ b/fastembed/image/transform/operators.py @@ -224,10 +224,10 @@ def _get_rescale(transforms: list[Transform], config: dict[str, Any]): @staticmethod def _get_normalize(transforms: list[Transform], config: dict[str, Any]): -if 
config.get("do_normalize", False): - transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"])) -elif "mean" in config and "std" in config: - transforms.append(Normalize(mean=config["mean"], std=config["std"])) + if config.get("do_normalize", False): + transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"])) + elif "mean" in config and "std" in config: + transforms.append(Normalize(mean=config["mean"], std=config["std"])) @staticmethod def _get_resize2square(transforms: list[Transform], config: dict[str, Any]): From 64829ab8e6d9c3bc85d5924e9d61696d8fc2db00 Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 25 Nov 2024 00:50:08 +0200 Subject: [PATCH 13/15] refactor: Refactored how we call padding for image --- fastembed/image/transform/functional.py | 11 +++--- fastembed/image/transform/operators.py | 49 ++++++++++++------------- 2 files changed, 29 insertions(+), 31 deletions(-) diff --git a/fastembed/image/transform/functional.py b/fastembed/image/transform/functional.py index 553639f8..ddad8498 100644 --- a/fastembed/image/transform/functional.py +++ b/fastembed/image/transform/functional.py @@ -124,16 +124,15 @@ def pil2ndarray(image: Union[Image.Image, np.ndarray]): return image -def resize2square( +def pad2square( image: Image.Image, size: int, fill_color: Optional[Union[str, int, tuple[int, ...]]] = None, - resample: Union[Image.Resampling, int] = Image.Resampling.BICUBIC, ) -> Image.Image: - resized_image = resize(image=image, size=size, resample=resample) + height, width = image.height, image.width new_image = Image.new(mode="RGB", size=(size, size), color=fill_color) - left = (size - resized_image.size[0]) // 2 - top = (size - resized_image.size[1]) // 2 - new_image.paste(resized_image, (left, top)) + left = (size - height) // 2 + top = (size - width) // 2 + new_image.paste(image, (left, top)) return new_image diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py index 7d43ecf9..46ff1a73 100644 --- a/fastembed/image/transform/operators.py +++ b/fastembed/image/transform/operators.py @@ -10,7 +10,7 @@ pil2ndarray, rescale, resize, - resize2square, + pad2square, ) @@ -69,23 +69,18 @@ def __call__(self, images: list[Union[Image.Image, np.ndarray]]) -> list[np.ndar return [pil2ndarray(image) for image in images] -class ResizetoSquare(Transform): +class PadtoSquare(Transform): def __init__( self, size: int, fill_color: Optional[Union[str, int, tuple[int, ...]]] = None, - resample: Union[Image.Resampling, int] = Image.Resampling.BICUBIC, ): self.size = size self.fill_color = fill_color - self.resample = resample def __call__(self, images: list[Image.Image]) -> list[Image.Image]: return [ - resize2square( - image=image, size=self.size, fill_color=self.fill_color, resample=self.resample - ) - for image in images + pad2square(image=image, size=self.size, fill_color=self.fill_color) for image in images ] @@ -132,7 +127,7 @@ def from_config(cls, config: dict[str, Any]) -> "Compose": transforms = [] cls._get_convert_to_rgb(transforms, config) cls._get_resize(transforms, config) - cls._get_resize2square(transforms, config) + cls._get_pad2square(transforms, config) cls._get_center_crop(transforms, config) cls._get_pil2ndarray(transforms, config) cls._get_rescale(transforms, config) @@ -188,7 +183,20 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]): ) ) elif mode == "JinaCLIPImageProcessor": - pass + resample = ( + Compose._interpolation_resolver(config.get("interpolation")) + if 
isinstance(config.get("interpolation"), str)
+                else config.get("interpolation") or Image.Resampling.BICUBIC
+            )
+            if "size" in config:
+                resize_mode = config.get("resize_mode", "shortest")
+                if resize_mode == "shortest":
+                    transforms.append(
+                        Resize(
+                            size=config["size"],
+                            resample=resample,
+                        )
+                    )
         else:
             raise ValueError(f"Preprocessor {mode} is not supported")

@@ -230,28 +238,19 @@ def _get_normalize(transforms: list[Transform], config: dict[str, Any]):
         transforms.append(Normalize(mean=config["mean"], std=config["std"]))

     @staticmethod
-    def _get_resize2square(transforms: list[Transform], config: dict[str, Any]):
+    def _get_pad2square(transforms: list[Transform], config: dict[str, Any]):
         mode = config.get("image_processor_type", "CLIPImageProcessor")
         if mode == "CLIPImageProcessor":
             pass
         elif mode == "ConvNextFeatureExtractor":
             pass
         elif mode == "JinaCLIPImageProcessor":
-            resample = (
-                Compose._interpolation_resolver(config.get("interpolation"))
-                if isinstance(config.get("interpolation"), str)
-                else config.get("interpolation") or Image.Resampling.BICUBIC
+            transforms.append(
+                PadtoSquare(
+                    size=config["size"],
+                    fill_color=config.get("fill_color", 0),
+                )
             )
-            if "size" in config:
-                resize_mode = config.get("resize_mode", "shortest")
-                if resize_mode == "shortest":
-                    transforms.append(
-                        ResizetoSquare(
-                            size=config["size"],
-                            fill_color=config.get("fill_color", 0),
-                            resample=resample,
-                        )
-                    )

     @staticmethod
     def _interpolation_resolver(resample: Optional[str] = None) -> Image.Resampling:

From 2de91d5daf6a650e8c6669b6b1506351ab55464a Mon Sep 17 00:00:00 2001
From: hh-space-invader
Date: Mon, 25 Nov 2024 02:10:31 +0200
Subject: [PATCH 14/15] fix: Fix pad2square when the resized image is larger than the new square canvas

---
 fastembed/image/transform/functional.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/fastembed/image/transform/functional.py b/fastembed/image/transform/functional.py
index ddad8498..47724515 100644
--- a/fastembed/image/transform/functional.py
+++ b/fastembed/image/transform/functional.py
@@ -131,8 +131,17 @@ def pad2square(
 ) -> Image.Image:
     height, width = image.height, image.width

+    # if the size is larger than the new canvas
+    if width > size or height > size:
+        left = (width - size) // 2
+        top = (height - size) // 2
+        right = left + size
+        bottom = top + size
+        image = image.crop((left, top, right, bottom))
+        return image
+
     new_image = Image.new(mode="RGB", size=(size, size), color=fill_color)
-    left = (size - height) // 2
-    top = (size - width) // 2
+    left = (size - width) // 2
+    top = (size - height) // 2
     new_image.paste(image, (left, top))
     return new_image

From 33de341ebad13a99a7f329cee46d63bb7fd97d45 Mon Sep 17 00:00:00 2001
From: hh-space-invader
Date: Wed, 27 Nov 2024 00:22:32 +0200
Subject: [PATCH 15/15] refactor: minor refactor

---
 fastembed/image/transform/operators.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py
index 46ff1a73..0d08963d 100644
--- a/fastembed/image/transform/operators.py
+++ b/fastembed/image/transform/operators.py
@@ -186,7 +186,7 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]):
             resample = (
                 Compose._interpolation_resolver(config.get("interpolation"))
                 if isinstance(config.get("interpolation"), str)
-                else config.get("interpolation") or Image.Resampling.BICUBIC
+                else config.get("interpolation", Image.Resampling.BICUBIC)
            )
            if "size" in config:
                resize_mode = 
config.get("resize_mode", "shortest")
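

Editor's note (not part of the patch series): taken together, these patches register jinaai/jina-clip-v1 on both the text side (onnx/text_model.onnx, mean-pooled and normalized, 768-dim, PATCH 05) and the image side (onnx/vision_model.onnx with JinaCLIPImageProcessor-style preprocessing, PATCH 03/13). A minimal usage sketch follows; it assumes fastembed's public TextEmbedding/ImageEmbedding entry points, that weights are downloaded on first use, and uses "image.jpeg" as a placeholder path. The canonical values are the ones pinned in the tests above.

    import numpy as np

    from fastembed import ImageEmbedding, TextEmbedding

    # Text side: 768-dim embeddings, mean-pooled and normalized (PATCH 05).
    text_model = TextEmbedding(model_name="jinaai/jina-clip-v1")
    text_embeddings = np.stack(list(text_model.embed(["hello world", "flag embedding"])))
    assert text_embeddings.shape == (2, 768)

    # The leading dimensions should be close to the canonical vector pinned
    # in tests/test_text_onnx_embeddings.py.
    canonical = np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472])
    assert np.allclose(text_embeddings[0, : canonical.shape[0]], canonical, atol=1e-3)

    # Image side: the same model name resolves to onnx/vision_model.onnx (PATCH 03).
    # "image.jpeg" is a hypothetical local file, not one shipped with the repo.
    image_model = ImageEmbedding(model_name="jinaai/jina-clip-v1")
    image_embeddings = np.stack(list(image_model.embed(["image.jpeg"])))
    assert image_embeddings.shape == (1, 768)

The subtlest part of the series is the final pad2square behavior (PATCH 13 plus the PATCH 14 fix): an image smaller than the target canvas is centered on a size-by-size canvas filled with fill_color, while an image larger than the canvas is center-cropped instead. A standalone restatement of that logic, mirroring fastembed/image/transform/functional.py as it stands after PATCH 14:

    from PIL import Image

    def pad2square_demo(image: Image.Image, size: int, fill_color=0) -> Image.Image:
        height, width = image.height, image.width
        if width > size or height > size:
            # center-crop images that exceed the target canvas (PATCH 14)
            left, top = (width - size) // 2, (height - size) // 2
            return image.crop((left, top, left + size, top + size))
        # otherwise center the image on a square canvas (PATCH 13)
        canvas = Image.new(mode="RGB", size=(size, size), color=fill_color)
        canvas.paste(image, ((size - width) // 2, (size - height) // 2))
        return canvas

    assert pad2square_demo(Image.new("RGB", (200, 100)), 224).size == (224, 224)  # padded
    assert pad2square_demo(Image.new("RGB", (400, 300)), 224).size == (224, 224)  # cropped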