aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMohamedBassem <me@mbassem.com>2024-03-27 11:46:29 +0000
committerMohamed Bassem <me@mbassem.com>2024-03-27 11:56:13 +0000
commitd24e50950a36df12b7149b66762a231ac1da14d2 (patch)
tree1e0a9c072bce68c7f4f2bd98c8dd306a02da3cfc
parentff00ebca308f445785096611c47beed0c2c46c9c (diff)
downloadkarakeep-d24e50950a36df12b7149b66762a231ac1da14d2.tar.zst
feature: Add support for local models using ollama
-rw-r--r--README.md2
-rw-r--r--apps/workers/inference.ts125
-rw-r--r--apps/workers/openaiWorker.ts97
-rw-r--r--apps/workers/package.json1
-rw-r--r--apps/workers/searchWorker.ts21
-rw-r--r--docs/docs/01-intro.md2
-rw-r--r--docs/docs/02-installation.md11
-rw-r--r--docs/docs/03-configuration.md15
-rw-r--r--packages/shared/config.ts2
-rw-r--r--pnpm-lock.yaml11
10 files changed, 206 insertions, 81 deletions
diff --git a/README.md b/README.md
index fd1b8e26..cd8dea70 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ A self-hostable bookmark-everything app with a touch of AI for the data hoarders
- ⬇️ Automatic fetching for link titles, descriptions and images.
- 📋 Sort your bookmarks into lists.
- 🔎 Full text search of all the content stored.
-- ✨ AI-based (aka chatgpt) automatic tagging.
+- ✨ AI-based (aka chatgpt) automatic tagging. With supports for local models using ollama!
- 🔖 [Chrome plugin](https://chromewebstore.google.com/detail/hoarder/kgcjekpmcjjogibpjebkhaanilehneje) for quick bookmarking.
- 📱 An iOS app that's pending apple's review.
- 🌙 Dark mode support (web only so far).
diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts
new file mode 100644
index 00000000..c622dd54
--- /dev/null
+++ b/apps/workers/inference.ts
@@ -0,0 +1,125 @@
+import { Ollama } from "ollama";
+import OpenAI from "openai";
+
+import serverConfig from "@hoarder/shared/config";
+
+export interface InferenceResponse {
+ response: string;
+ totalTokens: number | undefined;
+}
+
+export interface InferenceClient {
+ inferFromText(prompt: string): Promise<InferenceResponse>;
+ inferFromImage(
+ prompt: string,
+ contentType: string,
+ image: string,
+ ): Promise<InferenceResponse>;
+}
+
+export class InferenceClientFactory {
+ static build(): InferenceClient | null {
+ if (serverConfig.inference.openAIApiKey) {
+ return new OpenAIInferenceClient();
+ }
+
+ if (serverConfig.inference.ollamaBaseUrl) {
+ return new OllamaInferenceClient();
+ }
+ return null;
+ }
+}
+
+class OpenAIInferenceClient implements InferenceClient {
+ openAI: OpenAI;
+
+ constructor() {
+ this.openAI = new OpenAI({
+ apiKey: serverConfig.inference.openAIApiKey,
+ baseURL: serverConfig.inference.openAIBaseUrl,
+ });
+ }
+
+ async inferFromText(prompt: string): Promise<InferenceResponse> {
+ const chatCompletion = await this.openAI.chat.completions.create({
+ messages: [{ role: "system", content: prompt }],
+ model: serverConfig.inference.textModel,
+ response_format: { type: "json_object" },
+ });
+
+ const response = chatCompletion.choices[0].message.content;
+ if (!response) {
+ throw new Error(`Got no message content from OpenAI`);
+ }
+ return { response, totalTokens: chatCompletion.usage?.total_tokens };
+ }
+
+ async inferFromImage(
+ prompt: string,
+ contentType: string,
+ image: string,
+ ): Promise<InferenceResponse> {
+ const chatCompletion = await this.openAI.chat.completions.create({
+ model: serverConfig.inference.imageModel,
+ messages: [
+ {
+ role: "user",
+ content: [
+ { type: "text", text: prompt },
+ {
+ type: "image_url",
+ image_url: {
+ url: `data:${contentType};base64,${image}`,
+ detail: "low",
+ },
+ },
+ ],
+ },
+ ],
+ max_tokens: 2000,
+ });
+
+ const response = chatCompletion.choices[0].message.content;
+ if (!response) {
+ throw new Error(`Got no message content from OpenAI`);
+ }
+ return { response, totalTokens: chatCompletion.usage?.total_tokens };
+ }
+}
+
+class OllamaInferenceClient implements InferenceClient {
+ ollama: Ollama;
+
+ constructor() {
+ this.ollama = new Ollama({
+ host: serverConfig.inference.ollamaBaseUrl,
+ });
+ }
+
+ async inferFromText(prompt: string): Promise<InferenceResponse> {
+ const chatCompletion = await this.ollama.chat({
+ model: serverConfig.inference.textModel,
+ format: "json",
+ messages: [{ role: "system", content: prompt }],
+ });
+
+ const response = chatCompletion.message.content;
+
+ return { response, totalTokens: chatCompletion.eval_count };
+ }
+
+ async inferFromImage(
+ prompt: string,
+ _contentType: string,
+ image: string,
+ ): Promise<InferenceResponse> {
+ const chatCompletion = await this.ollama.chat({
+ model: serverConfig.inference.imageModel,
+ format: "json",
+ messages: [{ role: "user", content: prompt, images: [`${image}`] }],
+ });
+
+ const response = chatCompletion.message.content;
+ return { response, totalTokens: chatCompletion.eval_count };
+ }
+}
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts
index 5f785f2f..b706fb90 100644
--- a/apps/workers/openaiWorker.ts
+++ b/apps/workers/openaiWorker.ts
@@ -1,13 +1,11 @@
import { Job, Worker } from "bullmq";
import { and, eq, inArray } from "drizzle-orm";
-import OpenAI from "openai";
import { z } from "zod";
import { db } from "@hoarder/db";
import { bookmarks, bookmarkTags, tagsOnBookmarks } from "@hoarder/db/schema";
-import serverConfig from "@hoarder/shared/config";
-import logger from "@hoarder/shared/logger";
import { readAsset } from "@hoarder/shared/assetdb";
+import logger from "@hoarder/shared/logger";
import {
OpenAIQueue,
queueConnectionDetails,
@@ -16,6 +14,8 @@ import {
zOpenAIRequestSchema,
} from "@hoarder/shared/queues";
+import { InferenceClientFactory, InferenceClient } from "./inference";
+
const openAIResponseSchema = z.object({
tags: z.array(z.string()),
});
@@ -41,8 +41,8 @@ async function attemptMarkTaggingStatus(
}
export class OpenAiWorker {
- static build() {
- logger.info("Starting openai worker ...");
+ static async build() {
+ logger.info("Starting inference worker ...");
const worker = new Worker<ZOpenAIRequest, void>(
OpenAIQueue.name,
runOpenAI,
@@ -54,13 +54,13 @@ export class OpenAiWorker {
worker.on("completed", async (job): Promise<void> => {
const jobId = job?.id ?? "unknown";
- logger.info(`[openai][${jobId}] Completed successfully`);
+ logger.info(`[inference][${jobId}] Completed successfully`);
await attemptMarkTaggingStatus(job?.data, "success");
});
worker.on("failed", async (job, error): Promise<void> => {
const jobId = job?.id ?? "unknown";
- logger.error(`[openai][${jobId}] openai job failed: ${error}`);
+ logger.error(`[inference][${jobId}] inference job failed: ${error}`);
await attemptMarkTaggingStatus(job?.data, "failure");
});
@@ -138,82 +138,52 @@ async function fetchBookmark(linkId: string) {
async function inferTagsFromImage(
jobId: string,
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
- openai: OpenAI,
+ inferenceClient: InferenceClient,
) {
-
const { asset, metadata } = await readAsset({
userId: bookmark.userId,
assetId: bookmark.asset.assetId,
});
if (!asset) {
- throw new Error(`[openai][${jobId}] AssetId ${bookmark.asset.assetId} for bookmark ${bookmark.id} not found`);
+ throw new Error(
+ `[inference][${jobId}] AssetId ${bookmark.asset.assetId} for bookmark ${bookmark.id} not found`,
+ );
}
- const base64 = asset.toString('base64');
-
- const chatCompletion = await openai.chat.completions.create({
- model: serverConfig.inference.imageModel,
- messages: [
- {
- role: "user",
- content: [
- { type: "text", text: IMAGE_PROMPT_BASE },
- {
- type: "image_url",
- image_url: {
- url: `data:${metadata.contentType};base64,${base64}`,
- detail: "low",
- },
- },
- ],
- },
- ],
- max_tokens: 2000,
- });
+ const base64 = asset.toString("base64");
- const response = chatCompletion.choices[0].message.content;
- if (!response) {
- throw new Error(`[openai][${jobId}] Got no message content from OpenAI`);
- }
- return { response, totalTokens: chatCompletion.usage?.total_tokens };
+ return await inferenceClient.inferFromImage(
+ IMAGE_PROMPT_BASE,
+ metadata.contentType,
+ base64,
+ );
}
async function inferTagsFromText(
- jobId: string,
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
- openai: OpenAI,
+ inferenceClient: InferenceClient,
) {
- const chatCompletion = await openai.chat.completions.create({
- messages: [{ role: "system", content: buildPrompt(bookmark) }],
- model: serverConfig.inference.textModel,
- response_format: { type: "json_object" },
- });
-
- const response = chatCompletion.choices[0].message.content;
- if (!response) {
- throw new Error(`[openai][${jobId}] Got no message content from OpenAI`);
- }
- return { response, totalTokens: chatCompletion.usage?.total_tokens };
+ return await inferenceClient.inferFromText(buildPrompt(bookmark));
}
async function inferTags(
jobId: string,
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
- openai: OpenAI,
+ inferenceClient: InferenceClient,
) {
let response;
if (bookmark.link || bookmark.text) {
- response = await inferTagsFromText(jobId, bookmark, openai);
+ response = await inferTagsFromText(bookmark, inferenceClient);
} else if (bookmark.asset) {
- response = await inferTagsFromImage(jobId, bookmark, openai);
+ response = await inferTagsFromImage(jobId, bookmark, inferenceClient);
} else {
- throw new Error(`[openai][${jobId}] Unsupported bookmark type`);
+ throw new Error(`[inference][${jobId}] Unsupported bookmark type`);
}
try {
let tags = openAIResponseSchema.parse(JSON.parse(response.response)).tags;
logger.info(
- `[openai][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${response.totalTokens} tokens and inferred: ${tags}`,
+ `[inference][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${response.totalTokens} tokens and inferred: ${tags}`,
);
// Sometimes the tags contain the hashtag symbol, let's strip them out if they do.
@@ -227,7 +197,7 @@ async function inferTags(
return tags;
} catch (e) {
throw new Error(
- `[openai][${jobId}] Failed to parse JSON response from OpenAI: ${e}`,
+ `[inference][${jobId}] Failed to parse JSON response from inference client: ${e}`,
);
}
}
@@ -292,23 +262,18 @@ async function connectTags(
async function runOpenAI(job: Job<ZOpenAIRequest, void>) {
const jobId = job.id ?? "unknown";
- const { inference } = serverConfig;
-
- if (!inference.openAIApiKey) {
+ const inferenceClient = InferenceClientFactory.build();
+ if (!inferenceClient) {
logger.debug(
- `[openai][${jobId}] OpenAI is not configured, nothing to do now`,
+ `[inference][${jobId}] No inference client configured, nothing to do now`,
);
return;
}
- const openai = new OpenAI({
- apiKey: inference.openAIApiKey,
- });
-
const request = zOpenAIRequestSchema.safeParse(job.data);
if (!request.success) {
throw new Error(
- `[openai][${jobId}] Got malformed job request: ${request.error.toString()}`,
+ `[inference][${jobId}] Got malformed job request: ${request.error.toString()}`,
);
}
@@ -316,11 +281,11 @@ async function runOpenAI(job: Job<ZOpenAIRequest, void>) {
const bookmark = await fetchBookmark(bookmarkId);
if (!bookmark) {
throw new Error(
- `[openai][${jobId}] bookmark with id ${bookmarkId} was not found`,
+ `[inference][${jobId}] bookmark with id ${bookmarkId} was not found`,
);
}
- const tags = await inferTags(jobId, bookmark, openai);
+ const tags = await inferTags(jobId, bookmark, inferenceClient);
await connectTags(bookmarkId, tags, bookmark.userId);
diff --git a/apps/workers/package.json b/apps/workers/package.json
index f6d58eb4..27a02f88 100644
--- a/apps/workers/package.json
+++ b/apps/workers/package.json
@@ -24,6 +24,7 @@
"metascraper-title": "^5.43.4",
"metascraper-twitter": "^5.43.4",
"metascraper-url": "^5.43.4",
+ "ollama": "^0.5.0",
"openai": "^4.29.0",
"puppeteer": "^22.0.0",
"puppeteer-extra": "^3.3.6",
diff --git a/apps/workers/searchWorker.ts b/apps/workers/searchWorker.ts
index 618e7c89..b24777d7 100644
--- a/apps/workers/searchWorker.ts
+++ b/apps/workers/searchWorker.ts
@@ -1,16 +1,17 @@
+import type { Job } from "bullmq";
+import { Worker } from "bullmq";
+import { eq } from "drizzle-orm";
+
+import type { ZSearchIndexingRequest } from "@hoarder/shared/queues";
import { db } from "@hoarder/db";
+import { bookmarks } from "@hoarder/db/schema";
import logger from "@hoarder/shared/logger";
-import { getSearchIdxClient } from "@hoarder/shared/search";
import {
- SearchIndexingQueue,
- ZSearchIndexingRequest,
queueConnectionDetails,
+ SearchIndexingQueue,
zSearchIndexingRequestSchema,
} from "@hoarder/shared/queues";
-import { Job } from "bullmq";
-import { Worker } from "bullmq";
-import { bookmarks } from "@hoarder/db/schema";
-import { eq } from "drizzle-orm";
+import { getSearchIdxClient } from "@hoarder/shared/search";
export class SearchIndexingWorker {
static async build() {
@@ -25,12 +26,12 @@ export class SearchIndexingWorker {
);
worker.on("completed", (job) => {
- const jobId = job?.id || "unknown";
+ const jobId = job?.id ?? "unknown";
logger.info(`[search][${jobId}] Completed successfully`);
});
worker.on("failed", (job, error) => {
- const jobId = job?.id || "unknown";
+ const jobId = job?.id ?? "unknown";
logger.error(`[search][${jobId}] openai job failed: ${error}`);
});
@@ -85,7 +86,7 @@ async function runDelete(
}
async function runSearchIndexing(job: Job<ZSearchIndexingRequest, void>) {
- const jobId = job.id || "unknown";
+ const jobId = job.id ?? "unknown";
const request = zSearchIndexingRequestSchema.safeParse(job.data);
if (!request.success) {
diff --git a/docs/docs/01-intro.md b/docs/docs/01-intro.md
index e5eac1dc..7413e163 100644
--- a/docs/docs/01-intro.md
+++ b/docs/docs/01-intro.md
@@ -15,7 +15,7 @@ Hoarder is an open source "Bookmark Everything" app that uses AI for automatical
- ⬇️ Automatic fetching for link titles, descriptions and images.
- 📋 Sort your bookmarks into lists.
- 🔎 Full text search of all the content stored.
-- ✨ AI-based (aka chatgpt) automatic tagging.
+- ✨ AI-based (aka chatgpt) automatic tagging. With supports for local models using ollama!
- 🔖 [Chrome plugin](https://chromewebstore.google.com/detail/hoarder/kgcjekpmcjjogibpjebkhaanilehneje) for quick bookmarking.
- 📱 An iOS app that's pending apple's review.
- 🌙 Dark mode support.
diff --git a/docs/docs/02-installation.md b/docs/docs/02-installation.md
index 0a25c7bf..50069e31 100644
--- a/docs/docs/02-installation.md
+++ b/docs/docs/02-installation.md
@@ -46,9 +46,18 @@ To enable automatic tagging, you'll need to configure OpenAI. This is optional t
Learn more about the costs of using openai [here](/openai).
+<details>
+ <summary>If you want to use Ollama (https://ollama.com/) instead for local inference.</summary>
-### 5. Start the service
+ - Make sure ollama is running.
+ - Set the `OLLAMA_BASE_URL` env variable to the address of the ollama API.
+ - Set `INFERENCE_TEXT_MODEL` to the model you want to use for text inference in ollama (for example: `llama2`)
+ - Set `INFERENCE_IMAGE_MODEL` to the model you want to use for image inference in ollama (for example: `llava`)
+ - Make sure that you `ollama pull`-ed the models that you want to use.
+
+</details>
+### 5. Start the service
Start the service by running:
diff --git a/docs/docs/03-configuration.md b/docs/docs/03-configuration.md
index 585d25b5..bba81b70 100644
--- a/docs/docs/03-configuration.md
+++ b/docs/docs/03-configuration.md
@@ -8,6 +8,17 @@ The app is mainly configured by environment variables. All the used environment
| NEXTAUTH_SECRET | Yes | Not set | Random string used to sign the JWT tokens. Generate one with `openssl rand -base64 36`. |
| REDIS_HOST | Yes | localhost | The address of redis used by background jobs |
| REDIS_POST | Yes | 6379 | The port of redis used by background jobs |
-| OPENAI_API_KEY | No | Not set | The OpenAI key used for automatic tagging. If not set, automatic tagging won't be enabled. More on that in [here](/openai). |
-| MEILI_ADDR | No | Not set | The address of meilisearch. If not set, Search will be disabled. E.g. (`http://meilisearch:7700`) |
+| MEILI_ADDR | No | Not set | The address of meilisearch. If not set, Search will be disabled. E.g. (`http://meilisearch:7700`) |
| MEILI_MASTER_KEY | Only in Prod and if search is enabled | Not set | The master key configured for meilisearch. Not needed in development environment. Generate one with `openssl rand -base64 36` |
+
+## Inference Configs (For automatic tagging)
+
+Either `OPENAI_API_KEY` or `OLLAMA_BASE_URL` need to be set for automatic tagging to be enabled. Otherwise, automatic tagging will be skipped.
+
+| Name | Required | Default | Description |
+| --------------------- | -------- | -------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| OPENAI_API_KEY | No | Not set | The OpenAI key used for automatic tagging. More on that in [here](/openai). |
+| OPENAI_BASE_URL | No | Not set | If you just want to use OpenAI you don't need to pass this variable. If, however, you want to use some other openai compatible API (e.g. azure openai service), set this to the url of the API. |
+| OLLAMA_BASE_URL | No | Not set | If you want to use ollama for local inference, set the address of ollama API here. |
+| INFERENCE_TEXT_MODEL | No | gpt-3.5-turbo-0125 | The model to use for text inference. You'll need to change this to some other model if you're using ollama. |
+| INFERENCE_IMAGE_MODEL | No | gpt-4-vision-preview | The model to use for image inference. You'll need to change this to some other model if you're using ollama and that model needs to support vision APIs (e.g. llava). |
diff --git a/packages/shared/config.ts b/packages/shared/config.ts
index 5d83b4f0..ff077147 100644
--- a/packages/shared/config.ts
+++ b/packages/shared/config.ts
@@ -12,6 +12,7 @@ const allEnv = z.object({
DISABLE_SIGNUPS: stringBool("false"),
OPENAI_API_KEY: z.string().optional(),
OPENAI_BASE_URL: z.string().url().optional(),
+ OLLAMA_BASE_URL: z.string().url().optional(),
INFERENCE_TEXT_MODEL: z.string().default("gpt-3.5-turbo-0125"),
INFERENCE_IMAGE_MODEL: z.string().default("gpt-4-vision-preview"),
REDIS_HOST: z.string().default("localhost"),
@@ -38,6 +39,7 @@ const serverConfigSchema = allEnv.transform((val) => {
inference: {
openAIApiKey: val.OPENAI_API_KEY,
openAIBaseUrl: val.OPENAI_BASE_URL,
+ ollamaBaseUrl: val.OLLAMA_BASE_URL,
textModel: val.INFERENCE_TEXT_MODEL,
imageModel: val.INFERENCE_IMAGE_MODEL,
},
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index d5dd4f75..155c1e32 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -578,6 +578,9 @@ importers:
metascraper-url:
specifier: ^5.43.4
version: 5.45.0
+ ollama:
+ specifier: ^0.5.0
+ version: 0.5.0
openai:
specifier: ^4.29.0
version: 4.29.0
@@ -8765,6 +8768,9 @@ packages:
resolution: {integrity: sha512-IF4PcGgzAr6XXSff26Sk/+P4KZFJVuHAJZj3wgO3vX2bMdNVp/QXTP3P7CEm9V1IdG8lDLY3HhiqpsE/nOwpPw==}
engines: {node: ^10.13.0 || >=12.0.0}
+ ollama@0.5.0:
+ resolution: {integrity: sha512-CRtRzsho210EGdK52GrUMohA2pU+7NbgEaBG3DcYeRmvQthDO7E2LHOkLlUUeaYUlNmEd8icbjC02ug9meSYnw==}
+
on-finished@2.3.0:
resolution: {integrity: sha512-ikqdkGAAyf/X/gPhXGvfgAytDZtDbr+bkNUJ0N9h5MI/dmdgCs3l6hoHrcUv41sRKew3jIwrp4qQDXiK99Utww==}
engines: {node: '>= 0.8'}
@@ -24367,6 +24373,11 @@ snapshots:
oidc-token-hash@5.0.3:
dev: false
+ ollama@0.5.0:
+ dependencies:
+ whatwg-fetch: 3.6.20
+ dev: false
+
on-finished@2.3.0:
dependencies:
ee-first: 1.1.1