new: Added jina clip text embedding #408

Status: Open. Wants to merge 10 commits into base: main.
11 changes: 11 additions & 0 deletions fastembed/image/onnx_embedding.py
@@ -53,6 +53,17 @@
},
"model_file": "model.onnx",
},
{
"model": "jinaai/jina-clip-v1",
"dim": 768,
"description": "Image embeddings, Multimodal (text&image), 2024 year",
"license": "apache-2.0",
"size_in_GB": 0.34,
"sources": {
"hf": "jinaai/jina-clip-v1",
},
"model_file": "onnx/vision_model.onnx",
},
]


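For context, a minimal usage sketch of the newly registered vision model through fastembed's existing ImageEmbedding entry point (the image path is a placeholder):

    from fastembed import ImageEmbedding

    # The registry entry above makes the ONNX vision tower loadable by name.
    model = ImageEmbedding(model_name="jinaai/jina-clip-v1")

    # embed() accepts image paths or PIL images and yields numpy vectors.
    embeddings = list(model.embed(["path/to/image.png"]))
    print(embeddings[0].shape)  # (768,) per the "dim" field above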
27 changes: 21 additions & 6 deletions fastembed/image/transform/functional.py
@@ -1,4 +1,4 @@
-from typing import Sized, Union
+from typing import Sized, Union, Optional

import numpy as np
from PIL import Image
@@ -62,8 +62,8 @@ def center_crop(

def normalize(
image: np.ndarray,
-    mean=Union[float, np.ndarray],
-    std=Union[float, np.ndarray],
+    mean: Union[float, np.ndarray],
+    std: Union[float, np.ndarray],
) -> np.ndarray:
if not isinstance(image, np.ndarray):
raise ValueError("image must be a numpy array")
@@ -96,10 +96,10 @@ def normalize(


def resize(
-    image: Image,
+    image: Image.Image,
     size: Union[int, tuple[int, int]],
-    resample: Image.Resampling = Image.Resampling.BILINEAR,
-) -> Image:
+    resample: Union[int, Image.Resampling] = Image.Resampling.BILINEAR,
+) -> Image.Image:
if isinstance(size, tuple):
return image.resize(size, resample)

@@ -122,3 +122,18 @@ def pil2ndarray(image: Union[Image.Image, np.ndarray]):
if isinstance(image, Image.Image):
return np.asarray(image).transpose((2, 0, 1))
return image


def resize2square(
image: Image.Image,
size: int,
fill_color: Optional[Union[str, int, tuple[int, ...]]] = None,
resample: Union[Image.Resampling, int] = Image.Resampling.BICUBIC,
) -> Image.Image:
resized_image = resize(image=image, size=size, resample=resample)

new_image = Image.new(mode="RGB", size=(size, size), color=fill_color)
left = (size - resized_image.size[0]) // 2
top = (size - resized_image.size[1]) // 2
new_image.paste(resized_image, (left, top))
return new_image
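A quick sketch of the invariant the new helper provides: whatever the input aspect ratio, the result is always a size-by-size RGB canvas with the resized image pasted centrally and the remaining border filled with fill_color. The int branch of resize is elided above, so the 200x100 input here is illustrative:

    from PIL import Image

    img = Image.new("RGB", (200, 100), color="red")  # hypothetical landscape input

    square = resize2square(img, size=224, fill_color=0)
    print(square.size, square.mode)  # (224, 224) RGB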
89 changes: 85 additions & 4 deletions fastembed/image/transform/operators.py
@@ -1,4 +1,4 @@
-from typing import Any, Union
+from typing import Any, Union, Optional

import numpy as np
from PIL import Image
@@ -10,6 +10,7 @@
pil2ndarray,
rescale,
resize,
resize2square,
)


@@ -37,7 +38,9 @@ def __init__(self, mean: Union[float, list[float]], std: Union[float, list[float]]):
self.std = std

def __call__(self, images: list[np.ndarray]) -> list[np.ndarray]:
return [normalize(image, mean=self.mean, std=self.std) for image in images]
return [
normalize(image, mean=np.array(self.mean), std=np.array(self.std)) for image in images
]


class Resize(Transform):
@@ -66,6 +69,26 @@ def __call__(self, images: list[Union[Image.Image, np.ndarray]]) -> list[np.ndarray]:
return [pil2ndarray(image) for image in images]


class ResizetoSquare(Transform):
def __init__(
self,
size: int,
fill_color: Optional[Union[str, int, tuple[int, ...]]] = None,
resample: Union[Image.Resampling, int] = Image.Resampling.BICUBIC,
):
self.size = size
self.fill_color = fill_color
self.resample = resample

def __call__(self, images: list[Image.Image]) -> list[Image.Image]:
return [
resize2square(
image=image, size=self.size, fill_color=self.fill_color, resample=self.resample
)
for image in images
]


class Compose:
def __init__(self, transforms: list[Transform]):
self.transforms = transforms
@@ -85,14 +108,20 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":

Valid keys:
- do_resize
- resize_mode
- size
- fill_color
- do_center_crop
- crop_size
- do_rescale
- rescale_factor
- do_normalize
- image_mean
- mean
- image_std
- std
- resample
- interpolation
Valid size keys (nested):
- {"height", "width"}
- {"shortest_edge"}
@@ -103,6 +132,7 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":
transforms = []
cls._get_convert_to_rgb(transforms, config)
cls._get_resize(transforms, config)
cls._get_resize2square(transforms, config)
cls._get_center_crop(transforms, config)
cls._get_pil2ndarray(transforms, config)
cls._get_rescale(transforms, config)
@@ -157,6 +187,10 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]):
resample=config.get("resample", Image.Resampling.BICUBIC),
)
)
elif mode == "JinaCLIPImageProcessor":
pass
else:
raise ValueError(f"Preprocessor {mode} is not supported")

@staticmethod
def _get_center_crop(transforms: list[Transform], config: dict[str, Any]):
@@ -173,6 +207,8 @@ def _get_center_crop(transforms: list[Transform], config: dict[str, Any]):
transforms.append(CenterCrop(size=crop_size))
elif mode == "ConvNextFeatureExtractor":
pass
elif mode == "JinaCLIPImageProcessor":
pass
else:
raise ValueError(f"Preprocessor {mode} is not supported")

@@ -188,5 +224,50 @@ def _get_rescale(transforms: list[Transform], config: dict[str, Any]):

@staticmethod
def _get_normalize(transforms: list[Transform], config: dict[str, Any]):
if config.get("do_normalize", False):
transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"]))
if config.get("do_normalize", False) or ("mean" in config and "std" in config):
transforms.append(
Normalize(
mean=config.get("image_mean", config.get("mean")),
std=config.get("image_std", config.get("std")),
)
)

Review comment (Member) on lines +227 to +233. Suggested change:

-        if config.get("do_normalize", False) or ("mean" in config and "std" in config):
-            transforms.append(
-                Normalize(
-                    mean=config.get("image_mean", config.get("mean")),
-                    std=config.get("image_std", config.get("std")),
-                )
-            )
+        if config.get("do_normalize", False):
+            transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"]))
+        elif "mean" in config and "std" in config:
+            transforms.append(Normalize(mean=config["mean"], std=config["std"]))

@staticmethod
def _get_resize2square(transforms: list[Transform], config: dict[str, Any]):
mode = config.get("image_processor_type", "CLIPImageProcessor")
if mode == "CLIPImageProcessor":
pass
elif mode == "ConvNextFeatureExtractor":
pass
elif mode == "JinaCLIPImageProcessor":
resample = (
Compose._interpolation_resolver(config.get("interpolation"))
if isinstance(config.get("interpolation"), str)
else config.get("interpolation") or Image.Resampling.BICUBIC
)
if "size" in config:
resize_mode = config.get("resize_mode", "shortest")
if resize_mode == "shortest":
transforms.append(
ResizetoSquare(
size=config["size"],
fill_color=config.get("fill_color", 0),
resample=resample,
)
)

@staticmethod
def _interpolation_resolver(resample: Optional[str] = None) -> Image.Resampling:
interpolation_map = {
"nearest": Image.Resampling.NEAREST,
"lanczos": Image.Resampling.LANCZOS,
"bilinear": Image.Resampling.BILINEAR,
"bicubic": Image.Resampling.BICUBIC,
"box": Image.Resampling.BOX,
"hamming": Image.Resampling.HAMMING,
}

if resample and (method := interpolation_map.get(resample.lower())):
return method

raise ValueError(f"Unknown interpolation method: {resample}")
4 changes: 4 additions & 0 deletions fastembed/text/pooled_embedding.py
@@ -50,6 +50,10 @@ def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:

@classmethod
def mean_pooling(cls, model_output: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
if model_output.ndim == 2: # (batch, embedding_dim)
seq_length = attention_mask.shape[1]
# (batch, seq_length, embedding_dim)
model_output = np.tile(np.expand_dims(model_output, axis=1), (1, seq_length, 1))
token_embeddings = model_output
input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
input_mask_expanded = np.tile(input_mask_expanded, (1, 1, token_embeddings.shape[-1]))
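The new 2D branch makes mean pooling a no-op for models whose ONNX graph already returns a pooled (batch, embedding_dim) output, presumably the jina-clip text tower added in this PR: tiling the pooled vector across the sequence axis and then averaging over attended positions returns it unchanged. A small numpy check of that invariant (shapes arbitrary):

    import numpy as np

    pooled = np.random.rand(2, 768)          # (batch, embedding_dim)
    mask = np.array([[1, 1, 0], [1, 0, 0]])  # (batch, seq_length)

    tiled = np.tile(np.expand_dims(pooled, axis=1), (1, 3, 1))  # (batch, seq, dim)
    masked_mean = (tiled * mask[..., None]).sum(axis=1) / mask.sum(axis=1, keepdims=True)

    assert np.allclose(masked_mean, pooled)  # mean of identical rows is the row itself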
11 changes: 11 additions & 0 deletions fastembed/text/pooled_normalized_embedding.py
@@ -75,6 +75,17 @@
"sources": {"hf": "jinaai/jina-embeddings-v2-base-es"},
"model_file": "onnx/model.onnx",
},
{
"model": "jinaai/jina-clip-v1",
"dim": 768,
"description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
"license": "apache-2.0",
"size_in_GB": 0.55,
"sources": {
"hf": "jinaai/jina-clip-v1",
},
"model_file": "onnx/text_model.onnx",
},
]


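And the text tower, matching the PR title, becomes available through fastembed's existing TextEmbedding entry point (query text illustrative):

    from fastembed import TextEmbedding

    model = TextEmbedding(model_name="jinaai/jina-clip-v1")
    embeddings = list(model.embed(["A photo of a cat"]))
    print(embeddings[0].shape)  # (768,)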
3 changes: 3 additions & 0 deletions tests/test_image_onnx_embeddings.py
@@ -21,6 +21,9 @@
"Qdrant/Unicom-ViT-B-32": np.array(
[0.0418, 0.0550, 0.0003, 0.0253, -0.0185, 0.0016, -0.0368, -0.0402, -0.0891, -0.0186]
),
"jinaai/jina-clip-v1": np.array(
[-0.029, 0.0216, 0.0396, 0.0283, -0.0023, 0.0151, 0.011, -0.0235, 0.0251, -0.0343]
),
}


1 change: 1 addition & 0 deletions tests/test_text_onnx_embeddings.py
@@ -64,6 +64,7 @@
),
"snowflake/snowflake-arctic-embed-l": np.array([0.0189, -0.0673, 0.0183, 0.0124, 0.0146]),
"Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
"jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
}

