| field | value |
|---|---|
| author | MohamedBassem <me@mbassem.com>, 2024-03-27 16:30:27 +0000 |
| committer | MohamedBassem <me@mbassem.com>, 2024-03-27 16:34:29 +0000 |
| commit | 9986746aa890f2490ff18fd4fc79be4de0e4dbe2 |
| tree | 094054ddebd76e155eac798ca7ca7fc93fe6c2c5 /apps/workers/inference.ts |
| parent | 5cbce67fdae7ef697dd999b0f1e3cc6ed9c53e3f |
| download | karakeep-9986746aa890f2490ff18fd4fc79be4de0e4dbe2.tar.zst |
fix: Attempt to increase the reliability of the ollama inference
Diffstat (limited to 'apps/workers/inference.ts')
| mode | path | lines |
|---|---|---|
| -rw-r--r-- | apps/workers/inference.ts | 45 |
1 file changed, 32 insertions, 13 deletions
```diff
diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts
index c622dd54..3b0b5943 100644
--- a/apps/workers/inference.ts
+++ b/apps/workers/inference.ts
@@ -2,6 +2,7 @@ import { Ollama } from "ollama";
 import OpenAI from "openai";
 
 import serverConfig from "@hoarder/shared/config";
+import logger from "@hoarder/shared/logger";
 
 export interface InferenceResponse {
   response: string;
@@ -96,16 +97,41 @@ class OllamaInferenceClient implements InferenceClient {
     });
   }
 
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
+  async runModel(model: string, prompt: string, image?: string) {
     const chatCompletion = await this.ollama.chat({
-      model: serverConfig.inference.textModel,
+      model: model,
       format: "json",
-      messages: [{ role: "system", content: prompt }],
+      stream: true,
+      messages: [
+        { role: "user", content: prompt, images: image ? [image] : undefined },
+      ],
     });
 
-    const response = chatCompletion.message.content;
+    let totalTokens = 0;
+    let response = "";
+    try {
+      for await (const part of chatCompletion) {
+        response += part.message.content;
+        if (!isNaN(part.eval_count)) {
+          totalTokens += part.eval_count;
+        }
+        if (!isNaN(part.prompt_eval_count)) {
+          totalTokens += part.prompt_eval_count;
+        }
+      }
+    } catch (e) {
+      // There seems to be a bug in ollama where you can get a successful response but still have an error thrown.
+      // Using stream + accumulating the response so far is a workaround.
+      // https://github.com/ollama/ollama-js/issues/72
+      totalTokens = NaN;
+      logger.warn(`Got an exception from ollama, will still attempt to deserialize the response we got so far: ${e}`);
+    }
+
+    return { response, totalTokens };
+  }
 
-    return { response, totalTokens: chatCompletion.eval_count };
+  async inferFromText(prompt: string): Promise<InferenceResponse> {
+    return await this.runModel(serverConfig.inference.textModel, prompt);
   }
 
   async inferFromImage(
@@ -113,13 +139,6 @@ class OllamaInferenceClient implements InferenceClient {
     _contentType: string,
     image: string,
   ): Promise<InferenceResponse> {
-    const chatCompletion = await this.ollama.chat({
-      model: serverConfig.inference.imageModel,
-      format: "json",
-      messages: [{ role: "user", content: prompt, images: [`${image}`] }],
-    });
-
-    const response = chatCompletion.message.content;
-    return { response, totalTokens: chatCompletion.eval_count };
+    return await this.runModel(serverConfig.inference.imageModel, prompt, image);
   }
 }
```
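The core of the change is the stream-and-accumulate workaround for ollama-js occasionally throwing after it has already produced a usable response. The sketch below restates that pattern in isolation; it is illustrative only, not the project's actual `runModel` implementation, and the host URL and the `llama3` model name are placeholder assumptions.

```typescript
// Illustrative sketch of the workaround, not the code from this commit.
// Assumes a local Ollama server; "llama3" is a placeholder model name.
import { Ollama } from "ollama";

async function chatWithPartialRecovery(prompt: string): Promise<string> {
  const ollama = new Ollama({ host: "http://127.0.0.1:11434" });

  // stream: true makes ollama-js return an async iterable of message chunks.
  const stream = await ollama.chat({
    model: "llama3",
    format: "json",
    stream: true,
    messages: [{ role: "user", content: prompt }],
  });

  let response = "";
  try {
    for await (const part of stream) {
      response += part.message.content;
    }
  } catch (e) {
    // As the commit comment notes, the stream can throw even after yielding a
    // usable response (https://github.com/ollama/ollama-js/issues/72), so keep
    // whatever was accumulated instead of failing the whole inference.
    console.warn(`Ollama stream failed, returning partial response: ${e}`);
  }
  return response;
}
```

In the commit itself this logic lives in `runModel`, so `inferFromText` and `inferFromImage` share the same recovery path, and the chunks' `eval_count`/`prompt_eval_count` values are summed into `totalTokens` when the stream completes cleanly.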
