From 177746235a347e1468ade07a7ef425d4011a0bc2 Mon Sep 17 00:00:00 2001 From: ItzCrazyKns <95534749+ItzCrazyKns@users.noreply.github.com> Date: Thu, 28 Nov 2024 20:47:18 +0530 Subject: [PATCH] feat(providers): add gemini --- package.json | 1 + sample.config.toml | 1 + src/config.ts | 3 ++ src/lib/providers/gemini.ts | 69 ++++++++++++++++++++++++++++++++ src/lib/providers/index.ts | 3 ++ src/routes/config.ts | 5 ++- ui/components/SettingsDialog.tsx | 17 ++++++++ yarn.lock | 53 ++++++++++++++++++++++++ 8 files changed, 151 insertions(+), 1 deletion(-) create mode 100644 src/lib/providers/gemini.ts diff --git a/package.json b/package.json index 6a677468..0368b21c 100644 --- a/package.json +++ b/package.json @@ -31,6 +31,7 @@ "@langchain/anthropic": "^0.2.3", "@langchain/community": "^0.2.16", "@langchain/openai": "^0.0.25", + "@langchain/google-genai": "^0.0.23", "@xenova/transformers": "^2.17.1", "axios": "^1.6.8", "better-sqlite3": "^11.0.0", diff --git a/sample.config.toml b/sample.config.toml index dddcc039..50ba95de 100644 --- a/sample.config.toml +++ b/sample.config.toml @@ -7,6 +7,7 @@ KEEP_ALIVE = "5m" # How long to keep Ollama models loaded into memory. (Instead OPENAI = "" # OpenAI API key - sk-1234567890abcdef1234567890abcdef GROQ = "" # Groq API key - gsk_1234567890abcdef1234567890abcdef ANTHROPIC = "" # Anthropic API key - sk-ant-1234567890abcdef1234567890abcdef +GEMINI = "" # Gemini API key - sk-1234567890abcdef1234567890abcdef [API_ENDPOINTS] SEARXNG = "http://localhost:32768" # SearxNG API URL diff --git a/src/config.ts b/src/config.ts index 8624e7f8..001c2590 100644 --- a/src/config.ts +++ b/src/config.ts @@ -14,6 +14,7 @@ interface Config { OPENAI: string; GROQ: string; ANTHROPIC: string; + GEMINI: string; }; API_ENDPOINTS: { SEARXNG: string; @@ -43,6 +44,8 @@ export const getGroqApiKey = () => loadConfig().API_KEYS.GROQ; export const getAnthropicApiKey = () => loadConfig().API_KEYS.ANTHROPIC; +export const getGeminiApiKey = () => loadConfig().API_KEYS.GEMINI; + export const getSearxngApiEndpoint = () => process.env.SEARXNG_API_URL || loadConfig().API_ENDPOINTS.SEARXNG; diff --git a/src/lib/providers/gemini.ts b/src/lib/providers/gemini.ts new file mode 100644 index 00000000..95764cfc --- /dev/null +++ b/src/lib/providers/gemini.ts @@ -0,0 +1,69 @@ +import { + ChatGoogleGenerativeAI, + GoogleGenerativeAIEmbeddings, +} from '@langchain/google-genai'; +import { getGeminiApiKey } from '../../config'; +import logger from '../../utils/logger'; + +export const loadGeminiChatModels = async () => { + const geminiApiKey = getGeminiApiKey(); + + if (!geminiApiKey) return {}; + + try { + const chatModels = { + 'gemini-1.5-flash': { + displayName: 'Gemini 1.5 Flash', + model: new ChatGoogleGenerativeAI({ + modelName: 'gemini-1.5-flash', + temperature: 0.7, + apiKey: geminiApiKey, + }), + }, + 'gemini-1.5-flash-8b': { + displayName: 'Gemini 1.5 Flash 8B', + model: new ChatGoogleGenerativeAI({ + modelName: 'gemini-1.5-flash-8b', + temperature: 0.7, + apiKey: geminiApiKey, + }), + }, + 'gemini-1.5-pro': { + displayName: 'Gemini 1.5 Pro', + model: new ChatGoogleGenerativeAI({ + modelName: 'gemini-1.5-pro', + temperature: 0.7, + apiKey: geminiApiKey, + }), + }, + }; + + return chatModels; + } catch (err) { + logger.error(`Error loading Gemini models: ${err}`); + return {}; + } +}; + +export const loadGeminiEmbeddingsModels = async () => { + const geminiApiKey = getGeminiApiKey(); + + if (!geminiApiKey) return {}; + + try { + const embeddingModels = { + 'text-embedding-004': { + displayName: 'Text Embedding', + model: new GoogleGenerativeAIEmbeddings({ + apiKey: geminiApiKey, + modelName: 'text-embedding-004', + }), + }, + }; + + return embeddingModels; + } catch (err) { + logger.error(`Error loading Gemini embeddings model: ${err}`); + return {}; + } +}; diff --git a/src/lib/providers/index.ts b/src/lib/providers/index.ts index d919fd4a..98846e76 100644 --- a/src/lib/providers/index.ts +++ b/src/lib/providers/index.ts @@ -3,18 +3,21 @@ import { loadOllamaChatModels, loadOllamaEmbeddingsModels } from './ollama'; import { loadOpenAIChatModels, loadOpenAIEmbeddingsModels } from './openai'; import { loadAnthropicChatModels } from './anthropic'; import { loadTransformersEmbeddingsModels } from './transformers'; +import { loadGeminiChatModels, loadGeminiEmbeddingsModels } from './gemini'; const chatModelProviders = { openai: loadOpenAIChatModels, groq: loadGroqChatModels, ollama: loadOllamaChatModels, anthropic: loadAnthropicChatModels, + gemini: loadGeminiChatModels, }; const embeddingModelProviders = { openai: loadOpenAIEmbeddingsModels, local: loadTransformersEmbeddingsModels, ollama: loadOllamaEmbeddingsModels, + gemini: loadGeminiEmbeddingsModels, }; export const getAvailableChatModelProviders = async () => { diff --git a/src/routes/config.ts b/src/routes/config.ts index f635e4b8..38192b70 100644 --- a/src/routes/config.ts +++ b/src/routes/config.ts @@ -7,6 +7,7 @@ import { getGroqApiKey, getOllamaApiEndpoint, getAnthropicApiKey, + getGeminiApiKey, getOpenaiApiKey, updateConfig, } from '../config'; @@ -52,7 +53,8 @@ router.get('/', async (_, res) => { config['ollamaApiUrl'] = getOllamaApiEndpoint(); config['anthropicApiKey'] = getAnthropicApiKey(); config['groqApiKey'] = getGroqApiKey(); - + config['geminiApiKey'] = getGeminiApiKey(); + res.status(200).json(config); } catch (err: any) { res.status(500).json({ message: 'An error has occurred.' }); @@ -68,6 +70,7 @@ router.post('/', async (req, res) => { OPENAI: config.openaiApiKey, GROQ: config.groqApiKey, ANTHROPIC: config.anthropicApiKey, + GEMINI: config.geminiApiKey, }, API_ENDPOINTS: { OLLAMA: config.ollamaApiUrl, diff --git a/ui/components/SettingsDialog.tsx b/ui/components/SettingsDialog.tsx index 716dd7d4..163857bf 100644 --- a/ui/components/SettingsDialog.tsx +++ b/ui/components/SettingsDialog.tsx @@ -63,6 +63,7 @@ interface SettingsType { openaiApiKey: string; groqApiKey: string; anthropicApiKey: string; + geminiApiKey: string; ollamaApiUrl: string; } @@ -476,6 +477,22 @@ const SettingsDialog = ({ } /> +
+ Gemini API Key +
+ + setConfig({ + ...config, + geminiApiKey: e.target.value, + }) + } + /> +