path: root/apps/workers/inference.ts
author     MohamedBassem <me@mbassem.com>            2024-03-27 16:30:27 +0000
committer  MohamedBassem <me@mbassem.com>            2024-03-27 16:34:29 +0000
commit     9986746aa890f2490ff18fd4fc79be4de0e4dbe2 (patch)
tree       094054ddebd76e155eac798ca7ca7fc93fe6c2c5 /apps/workers/inference.ts
parent     5cbce67fdae7ef697dd999b0f1e3cc6ed9c53e3f (diff)
download   karakeep-9986746aa890f2490ff18fd4fc79be4de0e4dbe2.tar.zst
fix: Attempt to increase the reliability of the ollama inference
Diffstat (limited to 'apps/workers/inference.ts')
-rw-r--r--  apps/workers/inference.ts  45
1 file changed, 32 insertions(+), 13 deletions(-)
diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts
index c622dd54..3b0b5943 100644
--- a/apps/workers/inference.ts
+++ b/apps/workers/inference.ts
@@ -2,6 +2,7 @@ import { Ollama } from "ollama";
import OpenAI from "openai";
import serverConfig from "@hoarder/shared/config";
+import logger from "@hoarder/shared/logger";
export interface InferenceResponse {
  response: string;
@@ -96,16 +97,41 @@ class OllamaInferenceClient implements InferenceClient {
    });
  }
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
+  async runModel(model: string, prompt: string, image?: string) {
    const chatCompletion = await this.ollama.chat({
-      model: serverConfig.inference.textModel,
+      model: model,
      format: "json",
-      messages: [{ role: "system", content: prompt }],
+      stream: true,
+      messages: [
+        { role: "user", content: prompt, images: image ? [image] : undefined },
+      ],
    });
-    const response = chatCompletion.message.content;
+    let totalTokens = 0;
+    let response = "";
+    try {
+      for await (const part of chatCompletion) {
+        response += part.message.content;
+        if (!isNaN(part.eval_count)) {
+          totalTokens += part.eval_count;
+        }
+        if (!isNaN(part.prompt_eval_count)) {
+          totalTokens += part.prompt_eval_count;
+        }
+      }
+    } catch (e) {
+      // There seems to be a bug in ollama where you can get a successful response but still have an error thrown.
+      // Using streaming and accumulating the response received so far is a workaround.
+      // https://github.com/ollama/ollama-js/issues/72
+      totalTokens = NaN;
+      logger.warn(`Got an exception from ollama, will still attempt to deserialize the response we got so far: ${e}`);
+    }
+
+    return { response, totalTokens };
+  }
-    return { response, totalTokens: chatCompletion.eval_count };
+  async inferFromText(prompt: string): Promise<InferenceResponse> {
+    return await this.runModel(serverConfig.inference.textModel, prompt);
  }
async inferFromImage(
@@ -113,13 +139,6 @@ class OllamaInferenceClient implements InferenceClient {
    _contentType: string,
    image: string,
  ): Promise<InferenceResponse> {
-    const chatCompletion = await this.ollama.chat({
-      model: serverConfig.inference.imageModel,
-      format: "json",
-      messages: [{ role: "user", content: prompt, images: [`${image}`] }],
-    });
-
-    const response = chatCompletion.message.content;
-    return { response, totalTokens: chatCompletion.eval_count };
+    return await this.runModel(serverConfig.inference.imageModel, prompt, image);
  }
}
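
For context, below is a minimal standalone sketch of the streaming-accumulation workaround this commit introduces, written against the `ollama` npm package's streaming chat API. The host URL, the function name runModelSketch, and the model name "llama3" are illustrative assumptions and are not values taken from this repository, which reads its model names from serverConfig.inference.

// Sketch only: accumulate streamed chunks so that a late ollama-js error
// (see https://github.com/ollama/ollama-js/issues/72) does not discard the
// partial response that was already generated.
import { Ollama } from "ollama";

async function runModelSketch(model: string, prompt: string) {
  // Host and model are illustrative assumptions, not repo configuration.
  const ollama = new Ollama({ host: "http://localhost:11434" });
  const stream = await ollama.chat({
    model,
    format: "json",
    stream: true,
    messages: [{ role: "user", content: prompt }],
  });

  let totalTokens = 0;
  let response = "";
  try {
    for await (const part of stream) {
      response += part.message.content;
      // Token counts are only present on the final chunk; isNaN(undefined)
      // is true, so intermediate chunks are skipped.
      if (!isNaN(part.eval_count)) {
        totalTokens += part.eval_count;
      }
      if (!isNaN(part.prompt_eval_count)) {
        totalTokens += part.prompt_eval_count;
      }
    }
  } catch (e) {
    // Keep whatever was accumulated and mark the token count as unknown.
    totalTokens = NaN;
    console.warn(`ollama stream errored, using partial response: ${e}`);
  }

  return { response, totalTokens };
}

// Example usage (hypothetical model name):
// runModelSketch("llama3", "Summarize this bookmark as JSON").then(console.log);

The design point is that streaming lets the caller keep the partial output it has already received, whereas the non-streaming call loses everything if the client throws at the very end of generation.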