author     MohamedBassem <me@mbassem.com>  2024-10-12 17:25:01 +0000
committer  MohamedBassem <me@mbassem.com>  2024-10-12 17:37:42 +0000
commit     1b09682685f54f29957163be9b9f9fc2de3b49cc (patch)
tree       7f10a7635cf984acd45147c24ec3e1d35798e8ba /apps
parent     c16173ea0fdbf6cc47b13756c0a77e8399669055 (diff)
download   karakeep-1b09682685f54f29957163be9b9f9fc2de3b49cc.tar.zst
feature: Allow customizing the inference's context length
Diffstat (limited to 'apps')
-rw-r--r--  apps/web/components/dashboard/settings/AISettings.tsx   1
-rw-r--r--  apps/workers/inference.ts                                3
-rw-r--r--  apps/workers/openaiWorker.ts                            16
-rw-r--r--  apps/workers/package.json                                2
-rw-r--r--  apps/workers/utils.ts                                    9
5 files changed, 12 insertions(+), 19 deletions(-)
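
The commit threads a new serverConfig.inference.contextLength setting through prompt building and the Ollama request options. The setting itself is defined outside apps/ and is not part of this diff; the sketch below shows one way such a field could be exposed, assuming a zod-parsed server config and a hypothetical INFERENCE_CONTEXT_LENGTH environment variable with an assumed default of 2048.

// Hypothetical config sketch: the env var name and the 2048 default are
// assumptions, not values shown in this diff.
import { z } from "zod";

const inferenceConfigSchema = z.object({
  // Maximum context window (in tokens) handed to the inference backend.
  contextLength: z.coerce.number().positive().default(2048),
});

export const inferenceConfig = inferenceConfigSchema.parse({
  contextLength: process.env.INFERENCE_CONTEXT_LENGTH,
});
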
diff --git a/apps/web/components/dashboard/settings/AISettings.tsx b/apps/web/components/dashboard/settings/AISettings.tsx
index 12f656ba..0a8db147 100644
--- a/apps/web/components/dashboard/settings/AISettings.tsx
+++ b/apps/web/components/dashboard/settings/AISettings.tsx
@@ -291,6 +291,7 @@ export function PromptDemo() {
.filter((p) => p.appliesTo == "text" || p.appliesTo == "all")
.map((p) => p.text),
"\n<CONTENT_HERE>\n",
+ /* context length */ 1024 /* The value here doesn't matter */,
).trim()}
</code>
<p>Image Prompt</p>
diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts
index 071f4742..41ceffd6 100644
--- a/apps/workers/inference.ts
+++ b/apps/workers/inference.ts
@@ -104,6 +104,9 @@ class OllamaInferenceClient implements InferenceClient {
format: "json",
stream: true,
keep_alive: serverConfig.inference.ollamaKeepAlive,
+ options: {
+ num_ctx: serverConfig.inference.contextLength,
+ },
messages: [
{ role: "user", content: prompt, images: image ? [image] : undefined },
],
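
The hunk above adds Ollama's num_ctx request option, which sizes the model's context window per request. For reference, a standalone sketch of the resulting call shape, assuming the ollama npm client (bumped to ^0.5.9 below) and placeholder host, model, and keep_alive values that are not taken from this repository:

// Sketch only: host, model, and keep_alive are placeholders.
import { Ollama } from "ollama";

async function demo() {
  const ollama = new Ollama({ host: "http://localhost:11434" });
  const stream = await ollama.chat({
    model: "llama3.1",
    format: "json",
    stream: true,
    keep_alive: "5m",
    // num_ctx is the knob that serverConfig.inference.contextLength now feeds.
    options: { num_ctx: 2048 },
    messages: [{ role: "user", content: "Respond with an empty JSON object." }],
  });
  for await (const chunk of stream) {
    process.stdout.write(chunk.message.content);
  }
}

demo().catch(console.error);
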
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts
index 6c6104f3..d51771b2 100644
--- a/apps/workers/openaiWorker.ts
+++ b/apps/workers/openaiWorker.ts
@@ -23,7 +23,7 @@ import {
import type { InferenceClient } from "./inference";
import { InferenceClientFactory } from "./inference";
-import { readPDFText, truncateContent } from "./utils";
+import { readPDFText } from "./utils";
const openAIResponseSchema = z.object({
tags: z.array(z.string()),
@@ -102,10 +102,7 @@ async function buildPrompt(
);
}
- let content = bookmark.link.content;
- if (content) {
- content = truncateContent(content);
- }
+ const content = bookmark.link.content;
return buildTextPrompt(
serverConfig.inference.inferredTagLang,
prompts,
@@ -113,16 +110,16 @@ async function buildPrompt(
Title: ${bookmark.link.title ?? ""}
Description: ${bookmark.link.description ?? ""}
Content: ${content ?? ""}`,
+ serverConfig.inference.contextLength,
);
}
if (bookmark.text) {
- const content = truncateContent(bookmark.text.text ?? "");
- // TODO: Ensure that the content doesn't exceed the context length of openai
return buildTextPrompt(
serverConfig.inference.inferredTagLang,
prompts,
- content,
+ bookmark.text.text ?? "",
+ serverConfig.inference.contextLength,
);
}
@@ -215,7 +212,8 @@ async function inferTagsFromPDF(
const prompt = buildTextPrompt(
serverConfig.inference.inferredTagLang,
await fetchCustomPrompts(bookmark.userId, "text"),
- `Content: ${truncateContent(pdfParse.text)}`,
+ `Content: ${pdfParse.text}`,
+ serverConfig.inference.contextLength,
);
return inferenceClient.inferFromText(prompt);
}
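
buildTextPrompt itself lives in a shared package and is not touched by this diff, so the truncation behavior behind the new contextLength argument is not visible here. A hypothetical sketch of what a function with this signature could do, using a rough characters-per-token heuristic that is an assumption rather than the project's actual logic:

// Hypothetical illustration only -- not the project's real buildTextPrompt.
function buildTextPrompt(
  lang: string,
  customPrompts: string[],
  content: string,
  contextLength: number,
): string {
  // Assume ~4 characters per token and reserve part of the window for the
  // instructions themselves (both numbers are placeholders).
  const charBudget = Math.max(contextLength - 512, 256) * 4;
  const truncated =
    content.length > charBudget ? content.slice(0, charBudget) : content;
  return [
    `You are a bot that tags bookmarks. Respond in "${lang}".`,
    ...customPrompts,
    truncated,
  ].join("\n");
}

In the AISettings.tsx demo above, the content is just a "<CONTENT_HERE>" placeholder, which is why the hard-coded 1024 there is annotated as not mattering.
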
diff --git a/apps/workers/package.json b/apps/workers/package.json
index 35217c96..b8077954 100644
--- a/apps/workers/package.json
+++ b/apps/workers/package.json
@@ -26,7 +26,7 @@
"metascraper-title": "^5.43.4",
"metascraper-twitter": "^5.43.4",
"metascraper-url": "^5.43.4",
- "ollama": "^0.5.0",
+ "ollama": "^0.5.9",
"openai": "^4.67.1",
"pdf2json": "^3.0.5",
"pdfjs-dist": "^4.0.379",
diff --git a/apps/workers/utils.ts b/apps/workers/utils.ts
index 2372684e..8d297e05 100644
--- a/apps/workers/utils.ts
+++ b/apps/workers/utils.ts
@@ -36,12 +36,3 @@ export async function readPDFText(buffer: Buffer): Promise<{
pdfParser.parseBuffer(buffer);
});
}
-
-export function truncateContent(content: string, length = 1500) {
- let words = content.split(" ");
- if (words.length > length) {
- words = words.slice(0, length);
- content = words.join(" ");
- }
- return content;
-}