author    MohamedBassem <me@mbassem.com>  2024-03-27 16:30:27 +0000
committer MohamedBassem <me@mbassem.com>  2024-03-27 16:34:29 +0000
commit    9986746aa890f2490ff18fd4fc79be4de0e4dbe2
tree      094054ddebd76e155eac798ca7ca7fc93fe6c2c5
parent    5cbce67fdae7ef697dd999b0f1e3cc6ed9c53e3f
download  karakeep-9986746aa890f2490ff18fd4fc79be4de0e4dbe2.tar.zst
fix: Attempt to increase the reliability of the ollama inference
-rw-r--r--  apps/workers/inference.ts     | 45
-rw-r--r--  apps/workers/openaiWorker.ts  | 11
-rw-r--r--  docs/docs/02-installation.md  |  2
-rw-r--r--  packages/shared/queues.ts     |  8
4 files changed, 49 insertions, 17 deletions
diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts
index c622dd54..3b0b5943 100644
--- a/apps/workers/inference.ts
+++ b/apps/workers/inference.ts
@@ -2,6 +2,7 @@ import { Ollama } from "ollama";
import OpenAI from "openai";
import serverConfig from "@hoarder/shared/config";
+import logger from "@hoarder/shared/logger";
export interface InferenceResponse {
response: string;
@@ -96,16 +97,41 @@ class OllamaInferenceClient implements InferenceClient {
});
}
- async inferFromText(prompt: string): Promise<InferenceResponse> {
+ async runModel(model: string, prompt: string, image?: string) {
const chatCompletion = await this.ollama.chat({
- model: serverConfig.inference.textModel,
+ model: model,
format: "json",
- messages: [{ role: "system", content: prompt }],
+ stream: true,
+ messages: [
+ { role: "user", content: prompt, images: image ? [image] : undefined },
+ ],
});
- const response = chatCompletion.message.content;
+ let totalTokens = 0;
+ let response = "";
+ try {
+ for await (const part of chatCompletion) {
+ response += part.message.content;
+ if (!isNaN(part.eval_count)) {
+ totalTokens += part.eval_count;
+ }
+ if (!isNaN(part.prompt_eval_count)) {
+ totalTokens += part.prompt_eval_count;
+ }
+ }
+ } catch (e) {
+ // There seems to be a bug in ollama where you can get a successful response, but it still throws an error.
+ // Using stream + accumulating the response so far is a workaround.
+ // https://github.com/ollama/ollama-js/issues/72
+ totalTokens = NaN;
+ logger.warn(`Got an exception from ollama, will still attempt to deserialize the response we got so far: ${e}`);
+ }
+
+ return { response, totalTokens };
+ }
- return { response, totalTokens: chatCompletion.eval_count };
+ async inferFromText(prompt: string): Promise<InferenceResponse> {
+ return await this.runModel(serverConfig.inference.textModel, prompt);
}
async inferFromImage(
@@ -113,13 +139,6 @@ class OllamaInferenceClient implements InferenceClient {
_contentType: string,
image: string,
): Promise<InferenceResponse> {
- const chatCompletion = await this.ollama.chat({
- model: serverConfig.inference.imageModel,
- format: "json",
- messages: [{ role: "user", content: prompt, images: [`${image}`] }],
- });
-
- const response = chatCompletion.message.content;
- return { response, totalTokens: chatCompletion.eval_count };
+ return await this.runModel(serverConfig.inference.imageModel, prompt, image);
}
}
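For context on the new return shape: when ollama throws mid-stream, `runModel` still returns whatever content it accumulated, but reports `totalTokens` as `NaN`. A minimal caller-side sketch of how that can be consumed (illustrative only; the helper name and inline client type are not part of this commit):

```typescript
import { z } from "zod";

// Hypothetical caller; names here are illustrative, not the project's actual worker code.
const tagsSchema = z.object({ tags: z.array(z.string()) });

async function tagWithOllama(
  client: { inferFromText(prompt: string): Promise<{ response: string; totalTokens: number }> },
  prompt: string,
) {
  const { response, totalTokens } = await client.inferFromText(prompt);

  // Token accounting is best-effort: runModel reports NaN when the stream was cut short.
  if (Number.isFinite(totalTokens)) {
    console.log(`inference consumed ~${totalTokens} tokens`);
  }

  // The accumulated text can still be valid JSON even if ollama errored late in the stream.
  return tagsSchema.parse(JSON.parse(response));
}
```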
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts
index b706fb90..9b2934e3 100644
--- a/apps/workers/openaiWorker.ts
+++ b/apps/workers/openaiWorker.ts
@@ -14,7 +14,7 @@ import {
zOpenAIRequestSchema,
} from "@hoarder/shared/queues";
-import { InferenceClientFactory, InferenceClient } from "./inference";
+import { InferenceClient, InferenceClientFactory } from "./inference";
const openAIResponseSchema = z.object({
tags: z.array(z.string()),
@@ -36,7 +36,7 @@ async function attemptMarkTaggingStatus(
})
.where(eq(bookmarks.id, request.bookmarkId));
} catch (e) {
- console.log(`Something went wrong when marking the tagging status: ${e}`);
+ logger.error(`Something went wrong when marking the tagging status: ${e}`);
}
}
@@ -196,8 +196,9 @@ async function inferTags(
return tags;
} catch (e) {
+ const responseSneak = response.response.substr(0, 20);
throw new Error(
- `[inference][${jobId}] Failed to parse JSON response from inference client: ${e}`,
+ `[inference][${jobId}] The model ignored our prompt and didn't respond with the expected JSON: ${JSON.stringify(e)}. Here's a sneak peek from the response: ${responseSneak}`,
);
}
}
@@ -285,6 +286,10 @@ async function runOpenAI(job: Job<ZOpenAIRequest, void>) {
);
}
+ logger.info(
+ `[inference][${jobId}] Starting an inference job for bookmark with id "${bookmark.id}"`,
+ );
+
const tags = await inferTags(jobId, bookmark, inferenceClient);
await connectTags(bookmarkId, tags, bookmark.userId);
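The reworked error message above embeds a 20-character preview of the raw model output. In isolation, the parse-and-report pattern looks roughly like this (hedged sketch; the helper name is illustrative, only the schema and the preview idea come from the diff):

```typescript
import { z } from "zod";

const openAIResponseSchema = z.object({ tags: z.array(z.string()) });

// Illustrative helper, not the actual worker function.
function parseInferenceResponse(jobId: string, raw: string) {
  try {
    return openAIResponseSchema.parse(JSON.parse(raw));
  } catch (e) {
    // Keep only the first 20 characters so logs stay readable while still
    // showing whether the model answered with prose instead of JSON.
    const preview = raw.slice(0, 20);
    throw new Error(
      `[inference][${jobId}] The model didn't respond with the expected JSON: ${JSON.stringify(e)}. Response preview: ${preview}`,
    );
  }
}
```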
diff --git a/docs/docs/02-installation.md b/docs/docs/02-installation.md
index 70fc3bb1..94d44f5d 100644
--- a/docs/docs/02-installation.md
+++ b/docs/docs/02-installation.md
@@ -55,7 +55,7 @@ Learn more about the costs of using openai [here](/openai).
- Make sure ollama is running.
- Set the `OLLAMA_BASE_URL` env variable to the address of the ollama API.
- - Set `INFERENCE_TEXT_MODEL` to the model you want to use for text inference in ollama (for example: `llama2`)
+ - Set `INFERENCE_TEXT_MODEL` to the model you want to use for text inference in ollama (for example: `mistral`)
- Set `INFERENCE_IMAGE_MODEL` to the model you want to use for image inference in ollama (for example: `llava`)
- Make sure that you have `ollama pull`-ed the models that you want to use (a quick connectivity check is sketched below).
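A quick way to verify this setup is to ask the ollama server which models it has pulled. This is a sanity-check sketch, not part of the documented install steps; it assumes the `ollama` npm package (already used by the workers) and the env variable names above:

```typescript
import { Ollama } from "ollama";

// Sanity check you can run yourself; not shipped with the project.
async function checkOllamaSetup() {
  const ollama = new Ollama({ host: process.env.OLLAMA_BASE_URL });
  const { models } = await ollama.list();
  const pulled = models.map((m) => m.name);

  for (const wanted of [
    process.env.INFERENCE_TEXT_MODEL, // e.g. "mistral"
    process.env.INFERENCE_IMAGE_MODEL, // e.g. "llava"
  ]) {
    if (!wanted) continue;
    const found = pulled.some((name) => name === wanted || name.startsWith(`${wanted}:`));
    if (!found) {
      console.warn(`"${wanted}" is not pulled yet; run: ollama pull ${wanted}`);
    }
  }
}

checkOllamaSetup();
```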
diff --git a/packages/shared/queues.ts b/packages/shared/queues.ts
index b264e2c4..146c19c6 100644
--- a/packages/shared/queues.ts
+++ b/packages/shared/queues.ts
@@ -1,5 +1,6 @@
import { Queue } from "bullmq";
import { z } from "zod";
+
import serverConfig from "./config";
export const queueConnectionDetails = {
@@ -27,6 +28,13 @@ export type ZOpenAIRequest = z.infer<typeof zOpenAIRequestSchema>;
export const OpenAIQueue = new Queue<ZOpenAIRequest, void>("openai_queue", {
connection: queueConnectionDetails,
+ defaultJobOptions: {
+ attempts: 3,
+ backoff: {
+ type: "exponential",
+ delay: 500,
+ },
+ },
});
// Search Indexing Worker
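With these `defaultJobOptions`, any error thrown by the OpenAI worker's processor (such as the JSON-parse failure above) causes the job to be attempted up to three times in total, with exponentially growing delays between attempts starting around 500 ms. A hedged sketch of the consuming side (the processor body and event handler are illustrative, not the project's actual worker):

```typescript
import { Worker } from "bullmq";
import { queueConnectionDetails } from "@hoarder/shared/queues";

// Illustrative consumer: throwing from the processor marks the attempt as
// failed, and BullMQ re-enqueues the job until `attempts` (3) is exhausted.
const worker = new Worker(
  "openai_queue",
  async (job) => {
    // Simulate a transient failure (e.g. the model returned non-JSON).
    throw new Error(`[inference][${job.id}] transient failure; BullMQ will retry`);
  },
  { connection: queueConnectionDetails },
);

worker.on("failed", (job, err) => {
  console.warn(`job ${job?.id} failed on attempt ${job?.attemptsMade}: ${err.message}`);
});
```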