Skip to content

Commit

Permalink
Merge pull request #59 from vicentefelipechile/master
Browse files Browse the repository at this point in the history
Added Flux model support & some changes/fixes
  • Loading branch information
Zain-ul-din authored Oct 19, 2024
2 parents 16223f7 + c14adea commit a188206
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 145 deletions.
10 changes: 8 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
DEBUG=True
PROCESSING="Thinking..."

# Model Services
API_KEY_OPENAI=ADD_YOUR_KEY
Expand All @@ -21,11 +22,16 @@ OPENAI_ENABLED=False
OPENAI_ICON_PREFIX=🤖

DALLE_PREFIX=!dalle
DALLE_ENABLED=True
DALLE_ENABLED=False
DALLE_ICON_PREFIX=🎨
DALLE_USE_3=False

# # Google Gemini
GEMINI_PREFIX=!gemini
GEMINI_ENABLED=True
GEMINI_ICON_PREFIX=🔮
GEMINI_ICON_PREFIX=🔮

# # Hugging Face Flux
HF_PREFIX=!flux
HF_ENABLED=False
HF_ICON_PREFIX=🤗
10 changes: 9 additions & 1 deletion src/baileys/env.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ interface EnvInterface {
GEMINI_PREFIX?: string;
GEMINI_ENABLED: boolean;
GEMINI_ICON_PREFIX?: string;

// // Hugging Face
HF_PREFIX?: string;
HF_ENABLED: boolean;
HF_ICON_PREFIX?: string;
}

export const ENV: EnvInterface = {
Expand All @@ -56,5 +61,8 @@ export const ENV: EnvInterface = {
DALLE_USE_3: process.env.DALLE_USE_3 === 'True',
GEMINI_PREFIX: process.env.GEMINI_PREFIX,
GEMINI_ENABLED: process.env.GEMINI_ENABLED === 'True',
GEMINI_ICON_PREFIX: process.env.GEMINI_ICON_PREFIX
GEMINI_ICON_PREFIX: process.env.GEMINI_ICON_PREFIX,
HF_PREFIX: process.env.HF_PREFIX,
HF_ENABLED: process.env.HF_ENABLED === 'True',
HF_ICON_PREFIX: process.env.HF_ICON_PREFIX
};
17 changes: 11 additions & 6 deletions src/baileys/handlers/message.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
/* Third-party modules */
import { AnyRegularMessageContent, downloadMediaMessage } from '@whiskeysockets/baileys';

/* Local modules */
import MessageHandlerParams from './../aimodeltype';
import { AIModels } from './../../types/AiModels';
Expand All @@ -16,7 +13,7 @@ import { ENV } from '../env';
const modelTable: Record<AIModels, any> = {
ChatGPT: ENV.OPENAI_ENABLED ? new ChatGPTModel() : null,
Gemini: ENV.GEMINI_ENABLED ? new GeminiModel() : null,
FLUX: new FluxModel()
FLUX: ENV.HF_ENABLED ? new FluxModel() : null
};

// handles message
Expand Down Expand Up @@ -51,9 +48,17 @@ export async function handleMessage({ client, msg, metadata }: MessageHandlerPar
console.error(err);
return;
}
res.edit = messageResponse?.key;

client.sendMessage(metadata.remoteJid, res);
if (res.image) {
// delete the old message
if (messageResponse?.key) {
client.sendMessage(metadata.remoteJid, { delete: messageResponse.key });
}
client.sendMessage(metadata.remoteJid, res, { quoted: msg });
} else {
res.edit = messageResponse?.key;
client.sendMessage(metadata.remoteJid, res);
}
}
);
}
Expand Down
2 changes: 0 additions & 2 deletions src/hooks/useSpinner.ts

This file was deleted.

77 changes: 29 additions & 48 deletions src/models/FluxModel.ts
Original file line number Diff line number Diff line change
@@ -1,62 +1,43 @@
import { useSpinner } from '../hooks/useSpinner';
/* Local modules */
import { ENV } from '../baileys/env';
import { MessageTemplates } from '../util/MessageTemplates';
import { AIModel } from './BaseAiModel';
import { AIArguments, AIHandle, AIModel } from './BaseAiModel';

interface FluxAiModelParams {
sender: string;
prompt: string;
}

type HandleType = (res: Buffer, err?: string) => Promise<void>;
/* Flux Model */
class FluxModel extends AIModel<AIArguments, AIHandle> {
public endPointAPI: string = 'https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev';
private headers;

class FluxModel extends AIModel<FluxAiModelParams, HandleType> {
constructor() {
super(ENV.HFKey, 'FLUX');
super(ENV.API_KEY_HF, 'FLUX');

this.headers = {
Authorization: `Bearer ${this.getApiKey()}`,
'Content-Type': 'application/json'
};
}

async sendMessage({ sender, prompt }: FluxAiModelParams, handle: HandleType): Promise<any> {
const spinner = useSpinner(MessageTemplates.requestStr(this.aiModelName, sender, prompt));
spinner.start();
try {
const startTime = Date.now();
public async generateImage(prompt: string): Promise<any> {
const response = await fetch(this.endPointAPI, {
headers: this.headers,
method: 'POST',
body: JSON.stringify({
inputs: prompt
})
});

const response = await fetch(
'https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell',
{
headers: {
Authorization: `Bearer ${this.apiKey}`,
'Content-Type': 'application/json'
},
method: 'POST',
body: JSON.stringify({
inputs: prompt
})
}
);
const buffer = await (await response.blob()).arrayBuffer();
const base64Img = Buffer.from(buffer);

const buffer = await (await response.blob()).arrayBuffer();
const base64Img = Buffer.from(buffer);
return base64Img;
}

await handle(base64Img);
async sendMessage({ prompt }: AIArguments, handle: AIHandle): Promise<any> {
try {
const imageData = await this.generateImage(prompt);

spinner.succeed(
MessageTemplates.reqSucceedStr(
this.aiModelName,
sender,
'<Image Buffer>',
Date.now() - startTime
)
);
await handle({ image: imageData });
} catch (err) {
spinner.fail(
MessageTemplates.reqFailStr(
this.aiModelName,
'at FluxModel.ts sendMessage(prompt, msg)',
err
)
);
await handle(Buffer.from(''), 'An error occur please see console for more information.');
await handle('', 'An error occur please see console for more information.');
}
}
}
Expand Down
8 changes: 1 addition & 7 deletions src/models/GeminiModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@ import {
GenerativeModel,
GoogleGenerativeAI
} from '@google/generative-ai';
import {
AnyMessageContent,
downloadMediaMessage,
generateWAMessage
} from '@whiskeysockets/baileys';
import { downloadMediaMessage } from '@whiskeysockets/baileys';

/* Local modules */
import { AIModel, AIArguments, AIHandle, AIMetaData } from './BaseAiModel';
Expand Down Expand Up @@ -72,8 +68,6 @@ class GeminiModel extends AIModel<AIArguments, AIHandle> {
async sendMessage({ sender, prompt, metadata }: AIArguments, handle: AIHandle) {
try {
let message = '';
console.log(metadata.quoteMetaData);

if (metadata.isQuoted) {
if (metadata.quoteMetaData.type === 'image') {
message = this.iconPrefix + (await this.generateImageCompletion(prompt, metadata));
Expand Down
65 changes: 0 additions & 65 deletions src/models/GeminiVisionModel.ts

This file was deleted.

15 changes: 5 additions & 10 deletions src/models/OpenAIModel.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
/* Third-party modules */
import {
ChatCompletionMessageParam,
ChatCompletionMessage
} from 'openai/resources/chat/completions';
import { ChatCompletionMessage } from 'openai/resources/chat/completions';
import OpenAI from 'openai';

/* Local modules */
import { AIModel, AIArguments } from './BaseAiModel';
import { AIModel, AIArguments, AIHandle } from './BaseAiModel';
import { ENV } from '../baileys/env';
import config from '../whatsapp-ai.config';

/* Util */
type HandleType = (res: string, error?: string) => Promise<void>;

interface BotImageResponse {
url: string;
caption: string;
}

/* ChatGPT Model */
class ChatGPTModel extends AIModel<AIArguments, HandleType> {
class ChatGPTModel extends AIModel<AIArguments, AIHandle> {
/* Variables */
private Dalle3: boolean;
private Dalle: OpenAI;
Expand Down Expand Up @@ -65,7 +60,7 @@ class ChatGPTModel extends AIModel<AIArguments, HandleType> {
return { url: resInfo.url as string, caption: resInfo.revised_prompt as string };
}

public async sendMessage({ sender, prompt }: AIArguments, handle: HandleType): Promise<any> {
public async sendMessage({ sender, prompt }: AIArguments, handle: AIHandle): Promise<any> {
try {
if (!this.sessionExists(sender)) {
this.sessionCreate(sender);
Expand All @@ -74,7 +69,7 @@ class ChatGPTModel extends AIModel<AIArguments, HandleType> {

const completion = await this.generateCompletion(sender);
const res = completion.content || '';
await handle(res);
await handle({ text: res });
} catch (err) {
await handle('', 'An error occur please see console for more information.');
}
Expand Down
8 changes: 4 additions & 4 deletions src/whatsapp-ai.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ const config: Config = {
}
},
Gemini: {
prefix: '!gemini',
enable: true
prefix: ENV.GEMINI_PREFIX,
enable: ENV.GEMINI_ENABLED
},
FLUX: {
prefix: '!flux',
enable: true
prefix: ENV.HF_PREFIX,
enable: ENV.HF_ENABLED
}
/*
Custom: [
Expand Down

0 comments on commit a188206

Please sign in to comment.