From 84f3e99965864bd96af6167036104863e78e0a7b Mon Sep 17 00:00:00 2001
From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com>
Date: Tue, 19 Mar 2024 18:04:35 +0800
Subject: [PATCH] [LLM] Add `TransformersBgeEmbeddings` class in
 `bigdl.llm.langchain.embeddings` (#10459)

* Add TransformersBgeEmbeddings class in bigdl.llm.langchain.embeddings

* Small fixes
---
 .../src/bigdl/llm/langchain/embeddings/__init__.py |  5 +++--
 .../langchain/embeddings/transformersembeddings.py | 12 ++++++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py b/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py
index d001919c976..e6ec52acf8d 100644
--- a/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py
+++ b/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py
@@ -20,7 +20,7 @@
 # only search the first bigdl package and end up finding only one sub-package.
 
 from .bigdlllm import *
-from .transformersembeddings import TransformersEmbeddings
+from .transformersembeddings import TransformersEmbeddings, TransformersBgeEmbeddings
 
 __all__ = [
     "BigdlNativeEmbeddings",
@@ -28,5 +28,6 @@
     "BloomEmbeddings",
     "GptneoxEmbeddings",
     "StarcoderEmbeddings",
-    "TransformersEmbeddings"
+    "TransformersEmbeddings",
+    "TransformersBgeEmbeddings"
 ]
diff --git a/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py b/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
index c52a8adf285..9c69f4744c3 100644
--- a/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
+++ b/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
@@ -45,6 +45,7 @@
 # THE SOFTWARE.
 """Wrapper around BigdlLLM embedding models."""
 
+import torch
 from typing import Any, Dict, List, Optional
 
 import numpy as np
@@ -181,3 +182,14 @@ def embed_query(self, text: str) -> List[float]:
         text = text.replace("\n", " ")
         embedding = self.embed(text, **self.encode_kwargs)
         return embedding.tolist()
+
+# fit specific encode method for langchain.embeddings.HuggingFaceBgeEmbeddings
+# TODO: directly support HuggingFaceBgeEmbeddings
+class TransformersBgeEmbeddings(TransformersEmbeddings):
+
+    def embed(self, text: str, **kwargs):
+        input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs)
+        input_ids = input_ids.to(self.model.device)
+        embeddings = self.model(input_ids, return_dict=False)[0].cpu()
+        embeddings = torch.nn.functional.normalize(embeddings[:, 0], p=2, dim=1)
+        return embeddings[0]
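
For reference, a minimal usage sketch of the class this patch adds. It assumes the constructor inherited from `TransformersEmbeddings` (the other bigdl.llm langchain wrappers expose a `from_model_id` classmethod); the model id `BAAI/bge-small-en-v1.5` and the query string are purely illustrative:

    # Hypothetical usage sketch; from_model_id is assumed to be inherited
    # from TransformersEmbeddings, and the BGE model id is illustrative.
    from bigdl.llm.langchain.embeddings import TransformersBgeEmbeddings

    bge = TransformersBgeEmbeddings.from_model_id(model_id="BAAI/bge-small-en-v1.5")

    # embed_query() replaces newlines with spaces, then calls the embed()
    # overridden above, which returns the L2-normalized [CLS] embedding.
    vec = bge.embed_query("What is BigDL-LLM?")
    print(len(vec))  # embedding dimension of the chosen BGE model

Note the design choice in the overridden embed(): to match what langchain.embeddings.HuggingFaceBgeEmbeddings produces, it takes the hidden state of the first ([CLS]) token (embeddings[:, 0]) instead of mean-pooling over the sequence, then L2-normalizes it (p=2), so cosine similarity between queries and documents reduces to a dot product.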