author    MohamedBassem <me@mbassem.com>    2024-03-27 11:46:29 +0000
committer Mohamed Bassem <me@mbassem.com>   2024-03-27 11:56:13 +0000
commit    d24e50950a36df12b7149b66762a231ac1da14d2 (patch)
tree      1e0a9c072bce68c7f4f2bd98c8dd306a02da3cfc /apps
parent    ff00ebca308f445785096611c47beed0c2c46c9c (diff)
feature: Add support for local models using ollama
Diffstat (limited to 'apps')
-rw-r--r--  apps/workers/inference.ts     | 125
-rw-r--r--  apps/workers/openaiWorker.ts  |  97
-rw-r--r--  apps/workers/package.json     |   1
-rw-r--r--  apps/workers/searchWorker.ts  |  21
4 files changed, 168 insertions(+), 76 deletions(-)
diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts
new file mode 100644
index 00000000..c622dd54
--- /dev/null
+++ b/apps/workers/inference.ts
@@ -0,0 +1,125 @@
+import { Ollama } from "ollama";
+import OpenAI from "openai";
+
+import serverConfig from "@hoarder/shared/config";
+
+export interface InferenceResponse {
+ response: string;
+ totalTokens: number | undefined;
+}
+
+export interface InferenceClient {
+ inferFromText(prompt: string): Promise<InferenceResponse>;
+ inferFromImage(
+ prompt: string,
+ contentType: string,
+ image: string,
+ ): Promise<InferenceResponse>;
+}
+
+export class InferenceClientFactory {
+ static build(): InferenceClient | null {
+ if (serverConfig.inference.openAIApiKey) {
+ return new OpenAIInferenceClient();
+ }
+
+ if (serverConfig.inference.ollamaBaseUrl) {
+ return new OllamaInferenceClient();
+ }
+ return null;
+ }
+}
+
+class OpenAIInferenceClient implements InferenceClient {
+ openAI: OpenAI;
+
+ constructor() {
+ this.openAI = new OpenAI({
+ apiKey: serverConfig.inference.openAIApiKey,
+ baseURL: serverConfig.inference.openAIBaseUrl,
+ });
+ }
+
+ async inferFromText(prompt: string): Promise<InferenceResponse> {
+ const chatCompletion = await this.openAI.chat.completions.create({
+ messages: [{ role: "system", content: prompt }],
+ model: serverConfig.inference.textModel,
+ response_format: { type: "json_object" },
+ });
+
+ const response = chatCompletion.choices[0].message.content;
+ if (!response) {
+ throw new Error(`Got no message content from OpenAI`);
+ }
+ return { response, totalTokens: chatCompletion.usage?.total_tokens };
+ }
+
+ async inferFromImage(
+ prompt: string,
+ contentType: string,
+ image: string,
+ ): Promise<InferenceResponse> {
+ const chatCompletion = await this.openAI.chat.completions.create({
+ model: serverConfig.inference.imageModel,
+ messages: [
+ {
+ role: "user",
+ content: [
+ { type: "text", text: prompt },
+ {
+ type: "image_url",
+ image_url: {
+ url: `data:${contentType};base64,${image}`,
+ detail: "low",
+ },
+ },
+ ],
+ },
+ ],
+ max_tokens: 2000,
+ });
+
+ const response = chatCompletion.choices[0].message.content;
+ if (!response) {
+ throw new Error(`Got no message content from OpenAI`);
+ }
+ return { response, totalTokens: chatCompletion.usage?.total_tokens };
+ }
+}
+
+class OllamaInferenceClient implements InferenceClient {
+ ollama: Ollama;
+
+ constructor() {
+ this.ollama = new Ollama({
+ host: serverConfig.inference.ollamaBaseUrl,
+ });
+ }
+
+ async inferFromText(prompt: string): Promise<InferenceResponse> {
+ const chatCompletion = await this.ollama.chat({
+ model: serverConfig.inference.textModel,
+ format: "json",
+ messages: [{ role: "system", content: prompt }],
+ });
+
+ const response = chatCompletion.message.content;
+
+ return { response, totalTokens: chatCompletion.eval_count };
+ }
+
+ async inferFromImage(
+ prompt: string,
+ _contentType: string,
+ image: string,
+ ): Promise<InferenceResponse> {
+ const chatCompletion = await this.ollama.chat({
+ model: serverConfig.inference.imageModel,
+ format: "json",
+ messages: [{ role: "user", content: prompt, images: [`${image}`] }],
+ });
+
+ const response = chatCompletion.message.content;
+ return { response, totalTokens: chatCompletion.eval_count };
+ }
+}
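
The factory above returns the first configured backend, so OpenAI takes precedence when both an API key and an Ollama base URL are set. A minimal usage sketch of the new interface (the calling helper and its name are illustrative, not part of this commit):

import { InferenceClientFactory } from "./inference";

// Hypothetical caller showing how the factory is consumed.
async function runPrompt(prompt: string): Promise<string | null> {
  const client = InferenceClientFactory.build();
  if (!client) {
    // Neither inference.openAIApiKey nor inference.ollamaBaseUrl is configured.
    return null;
  }
  const { response, totalTokens } = await client.inferFromText(prompt);
  console.log(`inference consumed ${totalTokens ?? "unknown"} tokens`);
  return response;
}
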
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts
index 5f785f2f..b706fb90 100644
--- a/apps/workers/openaiWorker.ts
+++ b/apps/workers/openaiWorker.ts
@@ -1,13 +1,11 @@
import { Job, Worker } from "bullmq";
import { and, eq, inArray } from "drizzle-orm";
-import OpenAI from "openai";
import { z } from "zod";
import { db } from "@hoarder/db";
import { bookmarks, bookmarkTags, tagsOnBookmarks } from "@hoarder/db/schema";
-import serverConfig from "@hoarder/shared/config";
-import logger from "@hoarder/shared/logger";
import { readAsset } from "@hoarder/shared/assetdb";
+import logger from "@hoarder/shared/logger";
import {
OpenAIQueue,
queueConnectionDetails,
@@ -16,6 +14,8 @@ import {
zOpenAIRequestSchema,
} from "@hoarder/shared/queues";
+import { InferenceClientFactory, InferenceClient } from "./inference";
+
const openAIResponseSchema = z.object({
tags: z.array(z.string()),
});
@@ -41,8 +41,8 @@ async function attemptMarkTaggingStatus(
}
export class OpenAiWorker {
- static build() {
- logger.info("Starting openai worker ...");
+ static async build() {
+ logger.info("Starting inference worker ...");
const worker = new Worker<ZOpenAIRequest, void>(
OpenAIQueue.name,
runOpenAI,
@@ -54,13 +54,13 @@ export class OpenAiWorker {
worker.on("completed", async (job): Promise<void> => {
const jobId = job?.id ?? "unknown";
- logger.info(`[openai][${jobId}] Completed successfully`);
+ logger.info(`[inference][${jobId}] Completed successfully`);
await attemptMarkTaggingStatus(job?.data, "success");
});
worker.on("failed", async (job, error): Promise<void> => {
const jobId = job?.id ?? "unknown";
- logger.error(`[openai][${jobId}] openai job failed: ${error}`);
+ logger.error(`[inference][${jobId}] inference job failed: ${error}`);
await attemptMarkTaggingStatus(job?.data, "failure");
});
@@ -138,82 +138,52 @@ async function fetchBookmark(linkId: string) {
async function inferTagsFromImage(
jobId: string,
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
- openai: OpenAI,
+ inferenceClient: InferenceClient,
) {
-
const { asset, metadata } = await readAsset({
userId: bookmark.userId,
assetId: bookmark.asset.assetId,
});
if (!asset) {
- throw new Error(`[openai][${jobId}] AssetId ${bookmark.asset.assetId} for bookmark ${bookmark.id} not found`);
+ throw new Error(
+ `[inference][${jobId}] AssetId ${bookmark.asset.assetId} for bookmark ${bookmark.id} not found`,
+ );
}
- const base64 = asset.toString('base64');
-
- const chatCompletion = await openai.chat.completions.create({
- model: serverConfig.inference.imageModel,
- messages: [
- {
- role: "user",
- content: [
- { type: "text", text: IMAGE_PROMPT_BASE },
- {
- type: "image_url",
- image_url: {
- url: `data:${metadata.contentType};base64,${base64}`,
- detail: "low",
- },
- },
- ],
- },
- ],
- max_tokens: 2000,
- });
+ const base64 = asset.toString("base64");
- const response = chatCompletion.choices[0].message.content;
- if (!response) {
- throw new Error(`[openai][${jobId}] Got no message content from OpenAI`);
- }
- return { response, totalTokens: chatCompletion.usage?.total_tokens };
+ return await inferenceClient.inferFromImage(
+ IMAGE_PROMPT_BASE,
+ metadata.contentType,
+ base64,
+ );
}
async function inferTagsFromText(
- jobId: string,
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
- openai: OpenAI,
+ inferenceClient: InferenceClient,
) {
- const chatCompletion = await openai.chat.completions.create({
- messages: [{ role: "system", content: buildPrompt(bookmark) }],
- model: serverConfig.inference.textModel,
- response_format: { type: "json_object" },
- });
-
- const response = chatCompletion.choices[0].message.content;
- if (!response) {
- throw new Error(`[openai][${jobId}] Got no message content from OpenAI`);
- }
- return { response, totalTokens: chatCompletion.usage?.total_tokens };
+ return await inferenceClient.inferFromText(buildPrompt(bookmark));
}
async function inferTags(
jobId: string,
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
- openai: OpenAI,
+ inferenceClient: InferenceClient,
) {
let response;
if (bookmark.link || bookmark.text) {
- response = await inferTagsFromText(jobId, bookmark, openai);
+ response = await inferTagsFromText(bookmark, inferenceClient);
} else if (bookmark.asset) {
- response = await inferTagsFromImage(jobId, bookmark, openai);
+ response = await inferTagsFromImage(jobId, bookmark, inferenceClient);
} else {
- throw new Error(`[openai][${jobId}] Unsupported bookmark type`);
+ throw new Error(`[inference][${jobId}] Unsupported bookmark type`);
}
try {
let tags = openAIResponseSchema.parse(JSON.parse(response.response)).tags;
logger.info(
- `[openai][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${response.totalTokens} tokens and inferred: ${tags}`,
+ `[inference][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${response.totalTokens} tokens and inferred: ${tags}`,
);
// Sometimes the tags contain the hashtag symbol, let's strip them out if they do.
@@ -227,7 +197,7 @@ async function inferTags(
return tags;
} catch (e) {
throw new Error(
- `[openai][${jobId}] Failed to parse JSON response from OpenAI: ${e}`,
+ `[inference][${jobId}] Failed to parse JSON response from inference client: ${e}`,
);
}
}
@@ -292,23 +262,18 @@ async function connectTags(
async function runOpenAI(job: Job<ZOpenAIRequest, void>) {
const jobId = job.id ?? "unknown";
- const { inference } = serverConfig;
-
- if (!inference.openAIApiKey) {
+ const inferenceClient = InferenceClientFactory.build();
+ if (!inferenceClient) {
logger.debug(
- `[openai][${jobId}] OpenAI is not configured, nothing to do now`,
+ `[inference][${jobId}] No inference client configured, nothing to do now`,
);
return;
}
- const openai = new OpenAI({
- apiKey: inference.openAIApiKey,
- });
-
const request = zOpenAIRequestSchema.safeParse(job.data);
if (!request.success) {
throw new Error(
- `[openai][${jobId}] Got malformed job request: ${request.error.toString()}`,
+ `[inference][${jobId}] Got malformed job request: ${request.error.toString()}`,
);
}
@@ -316,11 +281,11 @@ async function runOpenAI(job: Job<ZOpenAIRequest, void>) {
const bookmark = await fetchBookmark(bookmarkId);
if (!bookmark) {
throw new Error(
- `[openai][${jobId}] bookmark with id ${bookmarkId} was not found`,
+ `[inference][${jobId}] bookmark with id ${bookmarkId} was not found`,
);
}
- const tags = await inferTags(jobId, bookmark, openai);
+ const tags = await inferTags(jobId, bookmark, inferenceClient);
await connectTags(bookmarkId, tags, bookmark.userId);
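
Both backends are asked for JSON output (response_format: { type: "json_object" } for OpenAI, format: "json" for Ollama), which is why inferTags can parse the reply with the same zod schema regardless of provider. A small sketch of that parsing step, with a hand-written reply standing in for a real model response, and the hashtag stripping shown as one plausible implementation rather than the commit's exact code:

import { z } from "zod";

const openAIResponseSchema = z.object({
  tags: z.array(z.string()),
});

// Illustrative model reply; real ones come from inferFromText/inferFromImage.
const raw = '{"tags": ["#selfhosting", "bookmarks"]}';
let tags = openAIResponseSchema.parse(JSON.parse(raw)).tags;
// Strip a leading hashtag if the model added one.
tags = tags.map((t) => t.replace(/^#/, ""));
console.log(tags); // ["selfhosting", "bookmarks"]
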
diff --git a/apps/workers/package.json b/apps/workers/package.json
index f6d58eb4..27a02f88 100644
--- a/apps/workers/package.json
+++ b/apps/workers/package.json
@@ -24,6 +24,7 @@
"metascraper-title": "^5.43.4",
"metascraper-twitter": "^5.43.4",
"metascraper-url": "^5.43.4",
+ "ollama": "^0.5.0",
"openai": "^4.29.0",
"puppeteer": "^22.0.0",
"puppeteer-extra": "^3.3.6",
diff --git a/apps/workers/searchWorker.ts b/apps/workers/searchWorker.ts
index 618e7c89..b24777d7 100644
--- a/apps/workers/searchWorker.ts
+++ b/apps/workers/searchWorker.ts
@@ -1,16 +1,17 @@
+import type { Job } from "bullmq";
+import { Worker } from "bullmq";
+import { eq } from "drizzle-orm";
+
+import type { ZSearchIndexingRequest } from "@hoarder/shared/queues";
import { db } from "@hoarder/db";
+import { bookmarks } from "@hoarder/db/schema";
import logger from "@hoarder/shared/logger";
-import { getSearchIdxClient } from "@hoarder/shared/search";
import {
- SearchIndexingQueue,
- ZSearchIndexingRequest,
queueConnectionDetails,
+ SearchIndexingQueue,
zSearchIndexingRequestSchema,
} from "@hoarder/shared/queues";
-import { Job } from "bullmq";
-import { Worker } from "bullmq";
-import { bookmarks } from "@hoarder/db/schema";
-import { eq } from "drizzle-orm";
+import { getSearchIdxClient } from "@hoarder/shared/search";
export class SearchIndexingWorker {
static async build() {
@@ -25,12 +26,12 @@ export class SearchIndexingWorker {
);
worker.on("completed", (job) => {
- const jobId = job?.id || "unknown";
+ const jobId = job?.id ?? "unknown";
logger.info(`[search][${jobId}] Completed successfully`);
});
worker.on("failed", (job, error) => {
- const jobId = job?.id || "unknown";
+ const jobId = job?.id ?? "unknown";
logger.error(`[search][${jobId}] openai job failed: ${error}`);
});
@@ -85,7 +86,7 @@ async function runDelete(
}
async function runSearchIndexing(job: Job<ZSearchIndexingRequest, void>) {
- const jobId = job.id || "unknown";
+ const jobId = job.id ?? "unknown";
const request = zSearchIndexingRequestSchema.safeParse(job.data);
if (!request.success) {
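
The searchWorker changes are import reordering plus replacing || with the nullish-coalescing operator when defaulting job ids. The two operators differ only on falsy-but-present values, as this quick sketch shows:

// `||` falls back on any falsy value; `??` only on null or undefined.
const id = 0;
console.log(id || "unknown"); // "unknown" (0 is falsy)
console.log(id ?? "unknown"); // 0 (0 is neither null nor undefined)
// Job ids here are strings or undefined, so both behave identically,
// but ?? expresses the intent (default only when missing) precisely.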