diff --git a/src/embeddings.test.ts b/src/embeddings.test.ts new file mode 100644 index 0000000..e84185f --- /dev/null +++ b/src/embeddings.test.ts @@ -0,0 +1,141 @@ +import { assertEquals } from "https://deno.land/std@0.224.0/assert/mod.ts"; +import "./mocks/syscalls.ts"; +import { aiSettings, initializeOpenAI } from "./init.ts"; +import { + canIndexPage, + indexEmbeddings, + shouldIndexEmbeddings, + shouldIndexSummaries, +} from "./embeddings.ts"; + +const settingsPageSample = ` +\`\`\`yaml +ai: + indexEmbeddings: true + indexEmbeddingsExcludePages: + - passwords + indexEmbeddingsExcludeStrings: + - foo + chat: + bakeMessages: false + customEnrichFunctions: + - enrichWithURL + textModels: + - name: mock-t1 + provider: mock + modelName: mock-t1 + imageModels: + - name: mock-i1 + provider: mock + modelName: mock-i1 + embeddingModels: + - name: mock-e1 + modelName: mock-e1 + provider: mock + baseUrl: http://localhost:11434 + requireAuth: false +\`\`\` + `; + +const settingsPageSampleNoEmbeddings = ` +\`\`\`yaml +ai: + indexEmbeddings: true + indexSummaries: true +\`\`\` + `; + +const secretsPageSample = ` +\`\`\`yaml +OPENAI_API_KEY: bar +\`\`\` + `; + +Deno.test("canIndexPage respects aiSettings.indexEmbeddingsExcludePages", async () => { + await syscall("mock.setPage", "SETTINGS", settingsPageSample); + await syscall("mock.setPage", "SECRETS", secretsPageSample); + await initializeOpenAI(); + aiSettings.indexEmbeddingsExcludePages = ["ExcludedPage"]; + assertEquals(canIndexPage("RegularPage"), true); + assertEquals(canIndexPage("ExcludedPage"), false); + assertEquals(canIndexPage("_HiddenPage"), false); + assertEquals(canIndexPage("Library/SomePage"), false); +}); + +Deno.test("shouldIndexEmbeddings returns true when conditions are met", async () => { + await syscall("mock.setPage", "SETTINGS", settingsPageSample); + await syscall("mock.setPage", "SECRETS", secretsPageSample); + await initializeOpenAI(); + await syscall("mock.setEnv", "server"); + + const result = await shouldIndexEmbeddings(); + assertEquals(result, true); +}); + +Deno.test("shouldIndexEmbeddings returns false when not on server", async () => { + await syscall("mock.setPage", "SETTINGS", settingsPageSample); + await syscall("mock.setPage", "SECRETS", secretsPageSample); + await initializeOpenAI(); + await syscall("mock.setEnv", "client"); + + const result = await shouldIndexEmbeddings(); + assertEquals(result, false); +}); + +Deno.test("shouldIndexEmbeddings returns false when indexEmbeddings is disabled", async () => { + const modifiedSettings = settingsPageSample.replace( + "indexEmbeddings: true", + "indexEmbeddings: false", + ); + await syscall("mock.setPage", "SETTINGS", modifiedSettings); + await syscall("mock.setPage", "SECRETS", secretsPageSample); + await initializeOpenAI(); + await syscall("mock.setEnv", "server"); + + const result = await shouldIndexEmbeddings(); + assertEquals(result, false); +}); + +Deno.test("shouldIndexSummaries returns true when conditions are met", async () => { + const modifiedSettings = settingsPageSample.replace( + "indexEmbeddings: true", + "indexEmbeddings: true\n indexSummary: true", + ); + await syscall("mock.setPage", "SETTINGS", modifiedSettings); + await syscall("mock.setPage", "SECRETS", secretsPageSample); + await initializeOpenAI(); + await syscall("mock.setEnv", "server"); + + const result = await shouldIndexSummaries(); + assertEquals(result, true); +}); + +Deno.test("shouldIndexSummaries returns false when indexSummary is disabled", async () => { + await syscall("mock.setPage", "SETTINGS", settingsPageSample); + await syscall("mock.setPage", "SECRETS", secretsPageSample); + await initializeOpenAI(); + await syscall("mock.setEnv", "server"); + + const result = await shouldIndexSummaries(); + assertEquals(result, false); +}); + +Deno.test("shouldIndexEmbeddings returns false when no embedding models are configured", async () => { + await syscall("mock.setPage", "SETTINGS", settingsPageSampleNoEmbeddings); + await syscall("mock.setPage", "SECRETS", secretsPageSample); + await initializeOpenAI(); + await syscall("mock.setEnv", "server"); + + const result = await shouldIndexEmbeddings(); + assertEquals(result, false); +}); + +Deno.test("shouldIndexSummaries returns false when no embedding models are configured", async () => { + await syscall("mock.setPage", "SETTINGS", settingsPageSampleNoEmbeddings); + await syscall("mock.setPage", "SECRETS", secretsPageSample); + await initializeOpenAI(); + await syscall("mock.setEnv", "server"); + + const result = await shouldIndexSummaries(); + assertEquals(result, false); +}); diff --git a/src/embeddings.ts b/src/embeddings.ts index 95216e3..4a62929 100644 --- a/src/embeddings.ts +++ b/src/embeddings.ts @@ -7,7 +7,11 @@ import type { } from "./types.ts"; import { indexObjects, queryObjects } from "$sbplugs/index/plug_api.ts"; import { renderToText } from "$sb/lib/tree.ts"; -import { currentEmbeddingProvider, initIfNeeded } from "../src/init.ts"; +import { + currentEmbeddingModel, + currentEmbeddingProvider, + initIfNeeded, +} from "../src/init.ts"; import { log, supportsServerProxyCall } from "./utils.ts"; import { editor, system } from "$sb/syscalls.ts"; import { aiSettings, configureSelectedModel } from "./init.ts"; @@ -18,7 +22,7 @@ const searchPrefix = "🤖 "; /** * Check whether a page is allowed to be indexed or not. */ -function canIndexPage(pageName: string): boolean { +export function canIndexPage(pageName: string): boolean { // Only index pages if the user enabled it, and skip anything they want to exclude const excludePages = [ "SETTINGS", @@ -36,17 +40,39 @@ function canIndexPage(pageName: string): boolean { return true; } +// Logic for whether or not to index something: +// - On server +// - With embeddings enabled +// - With a valid embedding model and provider + +export async function shouldIndexEmbeddings() { + await initIfNeeded(); + return aiSettings.indexEmbeddings && + currentEmbeddingProvider !== undefined && + currentEmbeddingModel !== undefined && + aiSettings.embeddingModels.length > 0 && + (await system.getEnv()) === "server"; +} + +export async function shouldIndexSummaries() { + await initIfNeeded(); + return aiSettings.indexEmbeddings && + aiSettings.indexSummary && + currentEmbeddingProvider !== undefined && + currentEmbeddingModel !== undefined && + aiSettings.embeddingModels.length > 0 && + (await system.getEnv()) === "server"; +} + /** * Generate embeddings for each paragraph in a page, and then indexes * them. */ export async function indexEmbeddings({ name: page, tree }: IndexTreeEvent) { - if (await system.getEnv() !== "server") { + if (!await shouldIndexEmbeddings()) { return; } - await initIfNeeded(); - if (!canIndexPage(page)) { return; } @@ -117,12 +143,7 @@ export async function indexEmbeddings({ name: page, tree }: IndexTreeEvent) { * Generate a summary for a page, and then indexes it. */ export async function indexSummary({ name: page, tree }: IndexTreeEvent) { - if (await system.getEnv() !== "server") { - return; - } - await initIfNeeded(); - - if (!aiSettings.indexSummary) { + if (!await shouldIndexSummaries()) { return; } @@ -214,6 +235,9 @@ export async function getAllAISummaries(): Promise { export async function generateEmbeddings(text: string): Promise { await initIfNeeded(); + if (!currentEmbeddingProvider || !currentEmbeddingModel) { + throw new Error("No embedding provider found"); + } return await currentEmbeddingProvider.generateEmbeddings({ text }); } diff --git a/src/init.ts b/src/init.ts index 915a7d9..c691641 100644 --- a/src/init.ts +++ b/src/init.ts @@ -19,6 +19,9 @@ import type { PromptInstructions, } from "./types.ts"; import { EmbeddingProvider, ImageProvider, Provider } from "./types.ts"; +import { MockImageProvider } from "./mocks/mockproviders.ts"; +import { MockProvider } from "./mocks/mockproviders.ts"; +import { MockEmbeddingProvider } from "./mocks/mockproviders.ts"; export let apiKey: string; export let aiSettings: AISettings; @@ -146,6 +149,12 @@ function setupImageProvider(model: ImageModelConfig) { model.baseUrl || aiSettings.dallEBaseUrl, ); break; + case ImageProvider.Mock: + currentImageProvider = new MockImageProvider( + apiKey, + model.modelName, + ); + break; default: throw new Error( `Unsupported image provider: ${model.provider}. Please configure a supported provider.`, @@ -167,6 +176,13 @@ function setupAIProvider(model: ModelConfig) { case Provider.Gemini: currentAIProvider = new GeminiProvider(apiKey, model.modelName); break; + case Provider.Mock: + currentAIProvider = new MockProvider( + apiKey, + model.modelName, + model.baseUrl, + ); + break; default: throw new Error( `Unsupported AI provider: ${model.provider}. Please configure a supported provider.`, @@ -200,6 +216,13 @@ function setupEmbeddingProvider(model: EmbeddingModelConfig) { model.requireAuth, ); break; + case EmbeddingProvider.Mock: + currentEmbeddingProvider = new MockEmbeddingProvider( + apiKey, + model.modelName, + model.baseUrl, + ); + break; default: throw new Error( `Unsupported embedding provider: ${model.provider}. Please configure a supported provider.`, diff --git a/src/mocks/mockproviders.ts b/src/mocks/mockproviders.ts new file mode 100644 index 0000000..01c8bca --- /dev/null +++ b/src/mocks/mockproviders.ts @@ -0,0 +1,68 @@ +import { AbstractProvider } from "../interfaces/Provider.ts"; +import { AbstractImageProvider } from "../interfaces/ImageProvider.ts"; +import { AbstractEmbeddingProvider } from "../interfaces/EmbeddingProvider.ts"; +import { + EmbeddingGenerationOptions, + ImageGenerationOptions, + StreamChatOptions, +} from "../types.ts"; + +export class MockProvider extends AbstractProvider { + constructor( + apiKey: string, + modelName: string, + baseUrl: string = "http://localhost", + ) { + super(apiKey, baseUrl, "mock", modelName); + } + + async chatWithAI(options: StreamChatOptions): Promise { + const mockResponse = "This is a mock response from the AI."; + if (options.onDataReceived) { + for (const char of mockResponse) { + await new Promise((resolve) => setTimeout(resolve, 50)); + options.onDataReceived(char); + } + } + return mockResponse; + } +} + +export class MockImageProvider extends AbstractImageProvider { + constructor( + apiKey: string, + modelName: string, + baseUrl: string = "http://localhost", + ) { + super(apiKey, baseUrl, "mock", modelName); + } + + generateImage(options: ImageGenerationOptions): Promise { + return new Promise((resolve) => { + setTimeout(() => { + resolve("https://example.com/mock-image.jpg"); + }, 5); + }); + } +} + +export class MockEmbeddingProvider extends AbstractEmbeddingProvider { + constructor( + apiKey: string, + modelName: string, + baseUrl: string = "http://localhost", + ) { + super(apiKey, baseUrl, "mock", modelName); + } + + _generateEmbeddings( + options: EmbeddingGenerationOptions, + ): Promise> { + return new Promise>((resolve) => { + setTimeout(() => { + const mockEmbedding = Array(1536).fill(0).map(() => Math.random()); + resolve(mockEmbedding); + }, 5); + }); + } +} diff --git a/src/mocks/syscalls.ts b/src/mocks/syscalls.ts index 492e47a..89c84e1 100644 --- a/src/mocks/syscalls.ts +++ b/src/mocks/syscalls.ts @@ -11,6 +11,12 @@ let pages: { [key: string]: string } = {}; let currentEnv: string = "server"; (globalThis as any).currentEnv; +let clientStore: { [key: string]: string } = {}; +(globalThis as any).clientStore; + +// let indexedObjects: { [key: string]: string } = {}; +// (globalThis as any).indexedObjects; + globalThis.syscall = async (name: string, ...args: readonly any[]) => { switch (name) { // I tried a lot of things to get this working differently, but @@ -40,7 +46,29 @@ globalThis.syscall = async (name: string, ...args: readonly any[]) => { return await Promise.resolve(parse(extendedMarkdownLanguage, args[0])); case "yaml.parse": return await Promise.resolve(YAML.parse(args[0])); + + case "system.invokeFunctionOnServer": + return invokeFunctionMock(args); + case "system.invokeFunction": + return invokeFunctionMock(args); + + case "clientStore.set": + clientStore[args[0]] = args[1]; + break; + case "clientStore.get": + return clientStore[args[0]]; + default: throw Error(`Missing mock for: ${name}`); } }; + +function invokeFunctionMock(args: readonly any[]) { + switch (args[0]) { + case "index.indexObjects": + return true; + default: + console.log("system.invokeFunctionOnServer", args); + throw Error(`Missing invokeFunction mock for ${args[0]}`); + } +} diff --git a/src/types.ts b/src/types.ts index b3c6ad2..eef1652 100644 --- a/src/types.ts +++ b/src/types.ts @@ -62,16 +62,22 @@ export type ChatMessage = { export enum Provider { OpenAI = "openai", Gemini = "gemini", + + Mock = "mock", } export enum ImageProvider { DallE = "dalle", + + Mock = "mock", } export enum EmbeddingProvider { OpenAI = "openai", Gemini = "gemini", Ollama = "ollama", + + Mock = "mock", } export type ChatSettings = {