diff options
| author | MohamedBassem <me@mbassem.com> | 2024-10-12 17:25:01 +0000 |
|---|---|---|
| committer | MohamedBassem <me@mbassem.com> | 2024-10-12 17:37:42 +0000 |
| commit | 1b09682685f54f29957163be9b9f9fc2de3b49cc (patch) | |
| tree | 7f10a7635cf984acd45147c24ec3e1d35798e8ba /apps | |
| parent | c16173ea0fdbf6cc47b13756c0a77e8399669055 (diff) | |
| download | karakeep-1b09682685f54f29957163be9b9f9fc2de3b49cc.tar.zst | |
feature: Allow customizing the inference's context length
Diffstat (limited to 'apps')
| -rw-r--r-- | apps/web/components/dashboard/settings/AISettings.tsx | 1 | ||||
| -rw-r--r-- | apps/workers/inference.ts | 3 | ||||
| -rw-r--r-- | apps/workers/openaiWorker.ts | 16 | ||||
| -rw-r--r-- | apps/workers/package.json | 2 | ||||
| -rw-r--r-- | apps/workers/utils.ts | 9 |
5 files changed, 12 insertions, 19 deletions
diff --git a/apps/web/components/dashboard/settings/AISettings.tsx b/apps/web/components/dashboard/settings/AISettings.tsx index 12f656ba..0a8db147 100644 --- a/apps/web/components/dashboard/settings/AISettings.tsx +++ b/apps/web/components/dashboard/settings/AISettings.tsx @@ -291,6 +291,7 @@ export function PromptDemo() { .filter((p) => p.appliesTo == "text" || p.appliesTo == "all") .map((p) => p.text), "\n<CONTENT_HERE>\n", + /* context length */ 1024 /* The value here doesn't matter */, ).trim()} </code> <p>Image Prompt</p> diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts index 071f4742..41ceffd6 100644 --- a/apps/workers/inference.ts +++ b/apps/workers/inference.ts @@ -104,6 +104,9 @@ class OllamaInferenceClient implements InferenceClient { format: "json", stream: true, keep_alive: serverConfig.inference.ollamaKeepAlive, + options: { + num_ctx: serverConfig.inference.contextLength, + }, messages: [ { role: "user", content: prompt, images: image ? [image] : undefined }, ], diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts index 6c6104f3..d51771b2 100644 --- a/apps/workers/openaiWorker.ts +++ b/apps/workers/openaiWorker.ts @@ -23,7 +23,7 @@ import { import type { InferenceClient } from "./inference"; import { InferenceClientFactory } from "./inference"; -import { readPDFText, truncateContent } from "./utils"; +import { readPDFText } from "./utils"; const openAIResponseSchema = z.object({ tags: z.array(z.string()), @@ -102,10 +102,7 @@ async function buildPrompt( ); } - let content = bookmark.link.content; - if (content) { - content = truncateContent(content); - } + const content = bookmark.link.content; return buildTextPrompt( serverConfig.inference.inferredTagLang, prompts, @@ -113,16 +110,16 @@ async function buildPrompt( Title: ${bookmark.link.title ?? ""} Description: ${bookmark.link.description ?? ""} Content: ${content ?? ""}`, + serverConfig.inference.contextLength, ); } if (bookmark.text) { - const content = truncateContent(bookmark.text.text ?? ""); - // TODO: Ensure that the content doesn't exceed the context length of openai return buildTextPrompt( serverConfig.inference.inferredTagLang, prompts, - content, + bookmark.text.text ?? "", + serverConfig.inference.contextLength, ); } @@ -215,7 +212,8 @@ async function inferTagsFromPDF( const prompt = buildTextPrompt( serverConfig.inference.inferredTagLang, await fetchCustomPrompts(bookmark.userId, "text"), - `Content: ${truncateContent(pdfParse.text)}`, + `Content: ${pdfParse.text}`, + serverConfig.inference.contextLength, ); return inferenceClient.inferFromText(prompt); } diff --git a/apps/workers/package.json b/apps/workers/package.json index 35217c96..b8077954 100644 --- a/apps/workers/package.json +++ b/apps/workers/package.json @@ -26,7 +26,7 @@ "metascraper-title": "^5.43.4", "metascraper-twitter": "^5.43.4", "metascraper-url": "^5.43.4", - "ollama": "^0.5.0", + "ollama": "^0.5.9", "openai": "^4.67.1", "pdf2json": "^3.0.5", "pdfjs-dist": "^4.0.379", diff --git a/apps/workers/utils.ts b/apps/workers/utils.ts index 2372684e..8d297e05 100644 --- a/apps/workers/utils.ts +++ b/apps/workers/utils.ts @@ -36,12 +36,3 @@ export async function readPDFText(buffer: Buffer): Promise<{ pdfParser.parseBuffer(buffer); }); } - -export function truncateContent(content: string, length = 1500) { - let words = content.split(" "); - if (words.length > length) { - words = words.slice(0, length); - content = words.join(" "); - } - return content; -} |
