Skip to content

Commit

Permalink
Merge pull request #53 from justyns/fix-undefined-error
Browse files Browse the repository at this point in the history
Fix bug causing generateEmbeddings to run when it shouldn't
  • Loading branch information
justyns authored Jul 31, 2024
2 parents d8a845a + e0d6ec0 commit d9d6027
Show file tree
Hide file tree
Showing 6 changed files with 301 additions and 11 deletions.
141 changes: 141 additions & 0 deletions src/embeddings.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import { assertEquals } from "https://deno.land/std@0.224.0/assert/mod.ts";
import "./mocks/syscalls.ts";
import { aiSettings, initializeOpenAI } from "./init.ts";
import {
canIndexPage,
indexEmbeddings,
shouldIndexEmbeddings,
shouldIndexSummaries,
} from "./embeddings.ts";

const settingsPageSample = `
\`\`\`yaml
ai:
indexEmbeddings: true
indexEmbeddingsExcludePages:
- passwords
indexEmbeddingsExcludeStrings:
- foo
chat:
bakeMessages: false
customEnrichFunctions:
- enrichWithURL
textModels:
- name: mock-t1
provider: mock
modelName: mock-t1
imageModels:
- name: mock-i1
provider: mock
modelName: mock-i1
embeddingModels:
- name: mock-e1
modelName: mock-e1
provider: mock
baseUrl: http://localhost:11434
requireAuth: false
\`\`\`
`;

const settingsPageSampleNoEmbeddings = `
\`\`\`yaml
ai:
indexEmbeddings: true
indexSummaries: true
\`\`\`
`;

const secretsPageSample = `
\`\`\`yaml
OPENAI_API_KEY: bar
\`\`\`
`;

Deno.test("canIndexPage respects aiSettings.indexEmbeddingsExcludePages", async () => {
await syscall("mock.setPage", "SETTINGS", settingsPageSample);
await syscall("mock.setPage", "SECRETS", secretsPageSample);
await initializeOpenAI();
aiSettings.indexEmbeddingsExcludePages = ["ExcludedPage"];
assertEquals(canIndexPage("RegularPage"), true);
assertEquals(canIndexPage("ExcludedPage"), false);
assertEquals(canIndexPage("_HiddenPage"), false);
assertEquals(canIndexPage("Library/SomePage"), false);
});

Deno.test("shouldIndexEmbeddings returns true when conditions are met", async () => {
await syscall("mock.setPage", "SETTINGS", settingsPageSample);
await syscall("mock.setPage", "SECRETS", secretsPageSample);
await initializeOpenAI();
await syscall("mock.setEnv", "server");

const result = await shouldIndexEmbeddings();
assertEquals(result, true);
});

Deno.test("shouldIndexEmbeddings returns false when not on server", async () => {
await syscall("mock.setPage", "SETTINGS", settingsPageSample);
await syscall("mock.setPage", "SECRETS", secretsPageSample);
await initializeOpenAI();
await syscall("mock.setEnv", "client");

const result = await shouldIndexEmbeddings();
assertEquals(result, false);
});

Deno.test("shouldIndexEmbeddings returns false when indexEmbeddings is disabled", async () => {
const modifiedSettings = settingsPageSample.replace(
"indexEmbeddings: true",
"indexEmbeddings: false",
);
await syscall("mock.setPage", "SETTINGS", modifiedSettings);
await syscall("mock.setPage", "SECRETS", secretsPageSample);
await initializeOpenAI();
await syscall("mock.setEnv", "server");

const result = await shouldIndexEmbeddings();
assertEquals(result, false);
});

Deno.test("shouldIndexSummaries returns true when conditions are met", async () => {
const modifiedSettings = settingsPageSample.replace(
"indexEmbeddings: true",
"indexEmbeddings: true\n indexSummary: true",
);
await syscall("mock.setPage", "SETTINGS", modifiedSettings);
await syscall("mock.setPage", "SECRETS", secretsPageSample);
await initializeOpenAI();
await syscall("mock.setEnv", "server");

const result = await shouldIndexSummaries();
assertEquals(result, true);
});

Deno.test("shouldIndexSummaries returns false when indexSummary is disabled", async () => {
await syscall("mock.setPage", "SETTINGS", settingsPageSample);
await syscall("mock.setPage", "SECRETS", secretsPageSample);
await initializeOpenAI();
await syscall("mock.setEnv", "server");

const result = await shouldIndexSummaries();
assertEquals(result, false);
});

Deno.test("shouldIndexEmbeddings returns false when no embedding models are configured", async () => {
await syscall("mock.setPage", "SETTINGS", settingsPageSampleNoEmbeddings);
await syscall("mock.setPage", "SECRETS", secretsPageSample);
await initializeOpenAI();
await syscall("mock.setEnv", "server");

const result = await shouldIndexEmbeddings();
assertEquals(result, false);
});

Deno.test("shouldIndexSummaries returns false when no embedding models are configured", async () => {
await syscall("mock.setPage", "SETTINGS", settingsPageSampleNoEmbeddings);
await syscall("mock.setPage", "SECRETS", secretsPageSample);
await initializeOpenAI();
await syscall("mock.setEnv", "server");

const result = await shouldIndexSummaries();
assertEquals(result, false);
});
46 changes: 35 additions & 11 deletions src/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ import type {
} from "./types.ts";
import { indexObjects, queryObjects } from "$sbplugs/index/plug_api.ts";
import { renderToText } from "$sb/lib/tree.ts";
import { currentEmbeddingProvider, initIfNeeded } from "../src/init.ts";
import {
currentEmbeddingModel,
currentEmbeddingProvider,
initIfNeeded,
} from "../src/init.ts";
import { log, supportsServerProxyCall } from "./utils.ts";
import { editor, system } from "$sb/syscalls.ts";
import { aiSettings, configureSelectedModel } from "./init.ts";
Expand All @@ -18,7 +22,7 @@ const searchPrefix = "🤖 ";
/**
* Check whether a page is allowed to be indexed or not.
*/
function canIndexPage(pageName: string): boolean {
export function canIndexPage(pageName: string): boolean {
// Only index pages if the user enabled it, and skip anything they want to exclude
const excludePages = [
"SETTINGS",
Expand All @@ -36,17 +40,39 @@ function canIndexPage(pageName: string): boolean {
return true;
}

// Logic for whether or not to index something:
// - On server
// - With embeddings enabled
// - With a valid embedding model and provider

export async function shouldIndexEmbeddings() {
await initIfNeeded();
return aiSettings.indexEmbeddings &&
currentEmbeddingProvider !== undefined &&
currentEmbeddingModel !== undefined &&
aiSettings.embeddingModels.length > 0 &&
(await system.getEnv()) === "server";
}

export async function shouldIndexSummaries() {
await initIfNeeded();
return aiSettings.indexEmbeddings &&
aiSettings.indexSummary &&
currentEmbeddingProvider !== undefined &&
currentEmbeddingModel !== undefined &&
aiSettings.embeddingModels.length > 0 &&
(await system.getEnv()) === "server";
}

/**
* Generate embeddings for each paragraph in a page, and then indexes
* them.
*/
export async function indexEmbeddings({ name: page, tree }: IndexTreeEvent) {
if (await system.getEnv() !== "server") {
if (!await shouldIndexEmbeddings()) {
return;
}

await initIfNeeded();

if (!canIndexPage(page)) {
return;
}
Expand Down Expand Up @@ -117,12 +143,7 @@ export async function indexEmbeddings({ name: page, tree }: IndexTreeEvent) {
* Generate a summary for a page, and then indexes it.
*/
export async function indexSummary({ name: page, tree }: IndexTreeEvent) {
if (await system.getEnv() !== "server") {
return;
}
await initIfNeeded();

if (!aiSettings.indexSummary) {
if (!await shouldIndexSummaries()) {
return;
}

Expand Down Expand Up @@ -214,6 +235,9 @@ export async function getAllAISummaries(): Promise<AISummaryObject[]> {

export async function generateEmbeddings(text: string): Promise<number[]> {
await initIfNeeded();
if (!currentEmbeddingProvider || !currentEmbeddingModel) {
throw new Error("No embedding provider found");
}
return await currentEmbeddingProvider.generateEmbeddings({ text });
}

Expand Down
23 changes: 23 additions & 0 deletions src/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ import type {
PromptInstructions,
} from "./types.ts";
import { EmbeddingProvider, ImageProvider, Provider } from "./types.ts";
import { MockImageProvider } from "./mocks/mockproviders.ts";
import { MockProvider } from "./mocks/mockproviders.ts";
import { MockEmbeddingProvider } from "./mocks/mockproviders.ts";

export let apiKey: string;
export let aiSettings: AISettings;
Expand Down Expand Up @@ -146,6 +149,12 @@ function setupImageProvider(model: ImageModelConfig) {
model.baseUrl || aiSettings.dallEBaseUrl,
);
break;
case ImageProvider.Mock:
currentImageProvider = new MockImageProvider(
apiKey,
model.modelName,
);
break;
default:
throw new Error(
`Unsupported image provider: ${model.provider}. Please configure a supported provider.`,
Expand All @@ -167,6 +176,13 @@ function setupAIProvider(model: ModelConfig) {
case Provider.Gemini:
currentAIProvider = new GeminiProvider(apiKey, model.modelName);
break;
case Provider.Mock:
currentAIProvider = new MockProvider(
apiKey,
model.modelName,
model.baseUrl,
);
break;
default:
throw new Error(
`Unsupported AI provider: ${model.provider}. Please configure a supported provider.`,
Expand Down Expand Up @@ -200,6 +216,13 @@ function setupEmbeddingProvider(model: EmbeddingModelConfig) {
model.requireAuth,
);
break;
case EmbeddingProvider.Mock:
currentEmbeddingProvider = new MockEmbeddingProvider(
apiKey,
model.modelName,
model.baseUrl,
);
break;
default:
throw new Error(
`Unsupported embedding provider: ${model.provider}. Please configure a supported provider.`,
Expand Down
68 changes: 68 additions & 0 deletions src/mocks/mockproviders.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import { AbstractProvider } from "../interfaces/Provider.ts";
import { AbstractImageProvider } from "../interfaces/ImageProvider.ts";
import { AbstractEmbeddingProvider } from "../interfaces/EmbeddingProvider.ts";
import {
EmbeddingGenerationOptions,
ImageGenerationOptions,
StreamChatOptions,
} from "../types.ts";

export class MockProvider extends AbstractProvider {
constructor(
apiKey: string,
modelName: string,
baseUrl: string = "http://localhost",
) {
super(apiKey, baseUrl, "mock", modelName);
}

async chatWithAI(options: StreamChatOptions): Promise<any> {
const mockResponse = "This is a mock response from the AI.";
if (options.onDataReceived) {
for (const char of mockResponse) {
await new Promise((resolve) => setTimeout(resolve, 50));
options.onDataReceived(char);
}
}
return mockResponse;
}
}

export class MockImageProvider extends AbstractImageProvider {
constructor(
apiKey: string,
modelName: string,
baseUrl: string = "http://localhost",
) {
super(apiKey, baseUrl, "mock", modelName);
}

generateImage(options: ImageGenerationOptions): Promise<string> {
return new Promise<string>((resolve) => {
setTimeout(() => {
resolve("https://example.com/mock-image.jpg");
}, 5);
});
}
}

export class MockEmbeddingProvider extends AbstractEmbeddingProvider {
constructor(
apiKey: string,
modelName: string,
baseUrl: string = "http://localhost",
) {
super(apiKey, baseUrl, "mock", modelName);
}

_generateEmbeddings(
options: EmbeddingGenerationOptions,
): Promise<Array<number>> {
return new Promise<Array<number>>((resolve) => {
setTimeout(() => {
const mockEmbedding = Array(1536).fill(0).map(() => Math.random());
resolve(mockEmbedding);
}, 5);
});
}
}
28 changes: 28 additions & 0 deletions src/mocks/syscalls.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ let pages: { [key: string]: string } = {};
let currentEnv: string = "server";
(globalThis as any).currentEnv;

let clientStore: { [key: string]: string } = {};
(globalThis as any).clientStore;

// let indexedObjects: { [key: string]: string } = {};
// (globalThis as any).indexedObjects;

globalThis.syscall = async (name: string, ...args: readonly any[]) => {
switch (name) {
// I tried a lot of things to get this working differently, but
Expand Down Expand Up @@ -40,7 +46,29 @@ globalThis.syscall = async (name: string, ...args: readonly any[]) => {
return await Promise.resolve(parse(extendedMarkdownLanguage, args[0]));
case "yaml.parse":
return await Promise.resolve(YAML.parse(args[0]));

case "system.invokeFunctionOnServer":
return invokeFunctionMock(args);
case "system.invokeFunction":
return invokeFunctionMock(args);

case "clientStore.set":
clientStore[args[0]] = args[1];
break;
case "clientStore.get":
return clientStore[args[0]];

default:
throw Error(`Missing mock for: ${name}`);
}
};

function invokeFunctionMock(args: readonly any[]) {
switch (args[0]) {
case "index.indexObjects":
return true;
default:
console.log("system.invokeFunctionOnServer", args);
throw Error(`Missing invokeFunction mock for ${args[0]}`);
}
}
Loading

0 comments on commit d9d6027

Please sign in to comment.