From 306323c61b935ba0e1f8acf6d08652fead8ec280 Mon Sep 17 00:00:00 2001 From: alex Date: Fri, 8 Mar 2024 13:56:10 -0600 Subject: [PATCH] add gemini --- src/main/app/backend/routers/llm.py | 2 +- src/main/app/backend/tools/llm.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/main/app/backend/routers/llm.py b/src/main/app/backend/routers/llm.py index 8c98cd98..e74ae44c 100644 --- a/src/main/app/backend/routers/llm.py +++ b/src/main/app/backend/routers/llm.py @@ -63,7 +63,7 @@ async def get_response(question: Question, background_tasks: BackgroundTasks) -> context = reader.retrieve_context_documents(question_embedding=question_embedding, number_of_context_documents=question.number_of_documents) # print(context) print("context retrieved...") - llm = LLM(llm_type="GPT-4 8k", temperature=question.temperature) + llm = LLM(llm_type=question.llm_type, temperature=question.temperature) print("llm initialized...") llm_response = llm.get_response(question=question.question, context=context) print("response retrieved...") diff --git a/src/main/app/backend/tools/llm.py b/src/main/app/backend/tools/llm.py index 8403b709..59e627ba 100644 --- a/src/main/app/backend/tools/llm.py +++ b/src/main/app/backend/tools/llm.py @@ -2,7 +2,9 @@ from typing import List, Dict, Tuple import openai -from langchain_community.chat_models import ChatVertexAI, AzureChatOpenAI +from langchain_community.chat_models import AzureChatOpenAI +# from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_google_vertexai import ChatVertexAI from langchain_openai import OpenAI # from langchain.chains import ConversationChain import pandas as pd @@ -51,6 +53,8 @@ def _init_llm(self, llm_type: str, temperature: float): top_p=0.95, # default is 0.95 top_k = 40 # default is 40 ) + case "Gemini": + self.llm_instance = ChatVertexAI(model_name="gemini-pro") case "GPT-4 8k": # Tokens per Minute Rate Limit (thousands): 10 # Rate limit (Tokens per minute): 10000