diff --git a/.env.example b/.env.example
index fa9dce8..8539ee3 100644
--- a/.env.example
+++ b/.env.example
@@ -1,4 +1,5 @@
 DEBUG=True
+PROCESSING="Thinking..."
 
 # Model Services
 API_KEY_OPENAI=ADD_YOUR_KEY
@@ -21,11 +22,16 @@ OPENAI_ENABLED=False
 OPENAI_ICON_PREFIX=🤖
 
 DALLE_PREFIX=!dalle
-DALLE_ENABLED=True
+DALLE_ENABLED=False
 DALLE_ICON_PREFIX=🎨
 DALLE_USE_3=False
 
 # # Google Gemini
 GEMINI_PREFIX=!gemini
 GEMINI_ENABLED=True
-GEMINI_ICON_PREFIX=🔮
\ No newline at end of file
+GEMINI_ICON_PREFIX=🔮
+
+# # Hugging Face Flux
+HF_PREFIX=!flux
+HF_ENABLED=False
+HF_ICON_PREFIX=🤗
\ No newline at end of file
diff --git a/src/baileys/env.ts b/src/baileys/env.ts
index d69b6b7..9543232 100644
--- a/src/baileys/env.ts
+++ b/src/baileys/env.ts
@@ -33,6 +33,11 @@ interface EnvInterface {
   GEMINI_PREFIX?: string;
   GEMINI_ENABLED: boolean;
   GEMINI_ICON_PREFIX?: string;
+
+  // // Hugging Face
+  HF_PREFIX?: string;
+  HF_ENABLED: boolean;
+  HF_ICON_PREFIX?: string;
 }
 
 export const ENV: EnvInterface = {
@@ -56,5 +61,8 @@ export const ENV: EnvInterface = {
   DALLE_USE_3: process.env.DALLE_USE_3 === 'True',
   GEMINI_PREFIX: process.env.GEMINI_PREFIX,
   GEMINI_ENABLED: process.env.GEMINI_ENABLED === 'True',
-  GEMINI_ICON_PREFIX: process.env.GEMINI_ICON_PREFIX
+  GEMINI_ICON_PREFIX: process.env.GEMINI_ICON_PREFIX,
+  HF_PREFIX: process.env.HF_PREFIX,
+  HF_ENABLED: process.env.HF_ENABLED === 'True',
+  HF_ICON_PREFIX: process.env.HF_ICON_PREFIX
 };
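One thing to keep in mind when copying `.env.example`: the loader above gates each flag with a strict string comparison, so only the exact literal `True` enables a model. A tiny sketch of the pitfall:

```ts
// env.ts gates each model with `process.env.X === 'True'`, so only the
// exact string 'True' enables it; 'true', 'TRUE', or '1' all read as off.
process.env.HF_ENABLED = 'true';
console.log(process.env.HF_ENABLED === 'True'); // false -> FLUX entry stays null
```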
diff --git a/src/baileys/handlers/message.ts b/src/baileys/handlers/message.ts
index db94390..785407f 100644
--- a/src/baileys/handlers/message.ts
+++ b/src/baileys/handlers/message.ts
@@ -1,6 +1,3 @@
-/* Third-party modules */
-import { AnyRegularMessageContent, downloadMediaMessage } from '@whiskeysockets/baileys';
-
 /* Local modules */
 import MessageHandlerParams from './../aimodeltype';
 import { AIModels } from './../../types/AiModels';
@@ -16,7 +13,7 @@ import { ENV } from '../env';
 const modelTable: Record<AIModels, AIModel | null> = {
   ChatGPT: ENV.OPENAI_ENABLED ? new ChatGPTModel() : null,
   Gemini: ENV.GEMINI_ENABLED ? new GeminiModel() : null,
-  FLUX: new FluxModel()
+  FLUX: ENV.HF_ENABLED ? new FluxModel() : null
 };
 
 // handles message
@@ -51,9 +48,17 @@ export async function handleMessage({ client, msg, metadata }: MessageHandlerPar
         console.error(err);
         return;
       }
 
-      res.edit = messageResponse?.key;
-      client.sendMessage(metadata.remoteJid, res);
+      if (res.image) {
+        // delete the old message
+        if (messageResponse?.key) {
+          client.sendMessage(metadata.remoteJid, { delete: messageResponse.key });
+        }
+        client.sendMessage(metadata.remoteJid, res, { quoted: msg });
+      } else {
+        res.edit = messageResponse?.key;
+        client.sendMessage(metadata.remoteJid, res);
+      }
     }
   );
 }
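The `res.image` branch above deletes and resends rather than editing, presumably because WhatsApp edits only rewrite a message's text: an image result cannot be patched onto the "Thinking..." placeholder. A rough sketch of the three payloads involved, assuming the stock Baileys content types this handler already relies on (`placeholderKey` and `imageBuffer` are illustrative stand-ins, not names from the repo):

```ts
import type { AnyMessageContent, WAMessageKey } from '@whiskeysockets/baileys';

declare const placeholderKey: WAMessageKey; // key of the "Thinking..." message
declare const imageBuffer: Buffer; // e.g. the Buffer FluxModel produces

// Text result: overwrite the placeholder in place.
const textPayload: AnyMessageContent = { text: '🔮 ...', edit: placeholderKey };

// Image result: remove the placeholder, then send the image as a new message.
const deletePayload: AnyMessageContent = { delete: placeholderKey };
const imagePayload: AnyMessageContent = { image: imageBuffer };
```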
diff --git a/src/hooks/useSpinner.ts b/src/hooks/useSpinner.ts
deleted file mode 100644
index a003e9e..0000000
--- a/src/hooks/useSpinner.ts
+++ /dev/null
@@ -1,2 +0,0 @@
-import ora from 'ora';
-export const useSpinner = ora;
diff --git a/src/models/FluxModel.ts b/src/models/FluxModel.ts
index fb3386b..e902ab7 100644
--- a/src/models/FluxModel.ts
+++ b/src/models/FluxModel.ts
@@ -1,62 +1,43 @@
-import { useSpinner } from '../hooks/useSpinner';
+/* Local modules */
 import { ENV } from '../baileys/env';
-import { MessageTemplates } from '../util/MessageTemplates';
-import { AIModel } from './BaseAiModel';
+import { AIArguments, AIHandle, AIModel } from './BaseAiModel';
 
-interface FluxAiModelParams {
-  sender: string;
-  prompt: string;
-}
-
-type HandleType = (res: Buffer, err?: string) => Promise<void>;
+/* Flux Model */
+class FluxModel extends AIModel<AIArguments, AIHandle> {
+  public endPointAPI: string = 'https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev';
+  private headers;
 
-class FluxModel extends AIModel {
   constructor() {
-    super(ENV.HFKey, 'FLUX');
+    super(ENV.API_KEY_HF, 'FLUX');
+
+    this.headers = {
+      Authorization: `Bearer ${this.getApiKey()}`,
+      'Content-Type': 'application/json'
+    };
   }
 
-  async sendMessage({ sender, prompt }: FluxAiModelParams, handle: HandleType): Promise<void> {
-    const spinner = useSpinner(MessageTemplates.requestStr(this.aiModelName, sender, prompt));
-    spinner.start();
-    try {
-      const startTime = Date.now();
+  public async generateImage(prompt: string): Promise<Buffer> {
+    const response = await fetch(this.endPointAPI, {
+      headers: this.headers,
+      method: 'POST',
+      body: JSON.stringify({
+        inputs: prompt
+      })
+    });
 
-      const response = await fetch(
-        'https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell',
-        {
-          headers: {
-            Authorization: `Bearer ${this.apiKey}`,
-            'Content-Type': 'application/json'
-          },
-          method: 'POST',
-          body: JSON.stringify({
-            inputs: prompt
-          })
-        }
-      );
+    const buffer = await (await response.blob()).arrayBuffer();
+    const base64Img = Buffer.from(buffer);
 
-      const buffer = await (await response.blob()).arrayBuffer();
-      const base64Img = Buffer.from(buffer);
+    return base64Img;
+  }
 
-      await handle(base64Img);
+  async sendMessage({ prompt }: AIArguments, handle: AIHandle): Promise<void> {
+    try {
+      const imageData = await this.generateImage(prompt);
 
-      spinner.succeed(
-        MessageTemplates.reqSucceedStr(
-          this.aiModelName,
-          sender,
-          '',
-          Date.now() - startTime
-        )
-      );
+      await handle({ image: imageData });
     } catch (err) {
-      spinner.fail(
-        MessageTemplates.reqFailStr(
-          this.aiModelName,
-          'at FluxModel.ts sendMessage(prompt, msg)',
-          err
-        )
-      );
-      await handle(Buffer.from(''), 'An error occur please see console for more information.');
+      await handle('', 'An error occurred, please see the console for more information.');
     }
   }
 }
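On success this Hugging Face inference endpoint responds with raw image bytes, not JSON, which is why `generateImage` converts the response via `blob()` and `arrayBuffer()` into a `Buffer` (despite its name, `base64Img` holds raw bytes, which is what the handler's `image` field expects). A self-contained sketch for testing the endpoint outside the bot, assuming Node 18+ for global `fetch`; `HF_TOKEN` and the output filename are placeholders:

```ts
import { writeFile } from 'node:fs/promises';

const ENDPOINT =
  'https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev';

async function fluxToFile(prompt: string): Promise<void> {
  const response = await fetch(ENDPOINT, {
    method: 'POST',
    headers: {
      Authorization: `Bearer ${process.env.HF_TOKEN}`,
      'Content-Type': 'application/json'
    },
    body: JSON.stringify({ inputs: prompt })
  });

  // Errors (401 on a bad token, 503 while the model loads) come back as JSON.
  if (!response.ok) {
    throw new Error(`HF inference failed: ${response.status} ${await response.text()}`);
  }

  await writeFile('flux-output.png', Buffer.from(await response.arrayBuffer()));
}

fluxToFile('a lighthouse at dawn').catch(console.error);
```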
diff --git a/src/models/GeminiModel.ts b/src/models/GeminiModel.ts
index 66b274e..734046f 100644
--- a/src/models/GeminiModel.ts
+++ b/src/models/GeminiModel.ts
@@ -5,11 +5,7 @@ import {
   GenerativeModel,
   GoogleGenerativeAI
 } from '@google/generative-ai';
-import {
-  AnyMessageContent,
-  downloadMediaMessage,
-  generateWAMessage
-} from '@whiskeysockets/baileys';
+import { downloadMediaMessage } from '@whiskeysockets/baileys';
 
 /* Local modules */
 import { AIModel, AIArguments, AIHandle, AIMetaData } from './BaseAiModel';
@@ -72,8 +68,6 @@ class GeminiModel extends AIModel {
   async sendMessage({ sender, prompt, metadata }: AIArguments, handle: AIHandle) {
     try {
       let message = '';
-      console.log(metadata.quoteMetaData);
-
       if (metadata.isQuoted) {
         if (metadata.quoteMetaData.type === 'image') {
           message = this.iconPrefix + (await this.generateImageCompletion(prompt, metadata));
diff --git a/src/models/GeminiVisionModel.ts b/src/models/GeminiVisionModel.ts
deleted file mode 100644
index cea61de..0000000
--- a/src/models/GeminiVisionModel.ts
+++ /dev/null
@@ -1,65 +0,0 @@
-import { AIModel } from './BaseAiModel';
-import { ENV } from '../baileys/env';
-import { GoogleGenerativeAI } from '@google/generative-ai';
-import { useSpinner } from '../hooks/useSpinner';
-import { MessageTemplates } from '../util/MessageTemplates';
-
-interface GeminiVisionModelParams {
-  sender: string;
-  prompt: { prompt: string; buffer: Buffer; mimeType: string };
-}
-
-type HandleType = (res: string, error?: string) => Promise<void>;
-
-class GeminiVisionModel extends AIModel {
-  public constructor() {
-    super(ENV.geminiKey, 'GeminiVision', 'Image');
-    this.genAI = new GoogleGenerativeAI(this.apiKey);
-  }
-
-  async sendMessage({ sender, prompt }: GeminiVisionModelParams, handle: HandleType): Promise<void> {
-    const spinner = useSpinner(
-      MessageTemplates.requestStr(this.aiModelName, sender, '')
-    );
-    spinner.start();
-
-    try {
-      const startTime = Date.now();
-
-      // check out more at: https://ai.google.dev/tutorials/node_quickstart
-      const model = this.genAI.getGenerativeModel({ model: 'gemini-1.5-flash' });
-
-      const imageParts = [this.toGenerativePart(prompt.buffer, prompt.mimeType)];
-      const result = await model.generateContent([prompt.prompt, ...imageParts]);
-      const resText = result.response.text();
-
-      await handle(resText);
-
-      spinner.succeed(
-        MessageTemplates.reqSucceedStr(this.aiModelName, sender, resText, Date.now() - startTime)
-      );
-    } catch (err) {
-      spinner.fail(
-        MessageTemplates.reqFailStr(
-          this.aiModelName,
-          'at GeminiVisionModel.ts sendMessage(prompt, msg)',
-          err
-        )
-      );
-      await handle('', 'An error occur please see console for more information.');
-    }
-  }
-
-  private toGenerativePart(buffer: Buffer, mimeType: string) {
-    return {
-      inlineData: {
-        data: buffer.toString('base64'),
-        mimeType
-      }
-    };
-  }
-
-  private genAI: GoogleGenerativeAI;
-}
-
-export { GeminiVisionModel };
diff --git a/src/models/OpenAIModel.ts b/src/models/OpenAIModel.ts
index c19c9b2..070cdb2 100644
--- a/src/models/OpenAIModel.ts
+++ b/src/models/OpenAIModel.ts
@@ -1,25 +1,20 @@
 /* Third-party modules */
-import {
-  ChatCompletionMessageParam,
-  ChatCompletionMessage
-} from 'openai/resources/chat/completions';
+import { ChatCompletionMessage } from 'openai/resources/chat/completions';
 import OpenAI from 'openai';
 
 /* Local modules */
-import { AIModel, AIArguments } from './BaseAiModel';
+import { AIModel, AIArguments, AIHandle } from './BaseAiModel';
 import { ENV } from '../baileys/env';
 import config from '../whatsapp-ai.config';
 
 /* Util */
-type HandleType = (res: string, error?: string) => Promise<void>;
-
 interface BotImageResponse {
   url: string;
   caption: string;
 }
 
 /* ChatGPT Model */
-class ChatGPTModel extends AIModel {
+class ChatGPTModel extends AIModel<AIArguments, AIHandle> {
   /* Variables */
   private Dalle3: boolean;
   private Dalle: OpenAI;
@@ -65,7 +60,7 @@ class ChatGPTModel extends AIModel {
     return { url: resInfo.url as string, caption: resInfo.revised_prompt as string };
   }
 
-  public async sendMessage({ sender, prompt }: AIArguments, handle: HandleType): Promise<void> {
+  public async sendMessage({ sender, prompt }: AIArguments, handle: AIHandle): Promise<void> {
     try {
       if (!this.sessionExists(sender)) {
         this.sessionCreate(sender);
@@ -74,7 +69,7 @@ class ChatGPTModel extends AIModel {
       const completion = await this.generateCompletion(sender);
       const res = completion.content || '';
 
-      await handle(res);
+      await handle({ text: res });
     } catch (err) {
       await handle('', 'An error occur please see console for more information.');
     }
diff --git a/src/whatsapp-ai.config.ts b/src/whatsapp-ai.config.ts
index 3cf48d6..e65be69 100644
--- a/src/whatsapp-ai.config.ts
+++ b/src/whatsapp-ai.config.ts
@@ -16,12 +16,12 @@ const config: Config = {
     }
   },
   Gemini: {
-    prefix: '!gemini',
-    enable: true
+    prefix: ENV.GEMINI_PREFIX,
+    enable: ENV.GEMINI_ENABLED
   },
   FLUX: {
-    prefix: '!flux',
-    enable: true
+    prefix: ENV.HF_PREFIX,
+    enable: ENV.HF_ENABLED
   }
   /* Custom: [